Spaces:
Paused
Paused
| """ | |
| Kaggle简化多模态测试脚本 | |
| 用于在Kaggle环境中直接处理已上传的PDF和图片文件 | |
| """ | |
| import os | |
| import sys | |
| import subprocess | |
| import time | |
| from typing import List, Dict, Any | |
| # 添加项目路径 | |
| sys.path.insert(0, '/kaggle/working/adaptive_RAG') | |
| # 导入项目模块 | |
| from document_processor import DocumentProcessor | |
| from main import AdaptiveRAGSystem | |
| from config import ENABLE_MULTIMODAL, SUPPORTED_IMAGE_FORMATS | |
def setup_kaggle_environment():
    """Install the extra dependencies this script needs in a Kaggle notebook.

    Best-effort: a failed install is reported but does not abort the run,
    so an already-provisioned environment keeps working.
    """
    print("🔧 设置Kaggle环境...")
    # -q keeps the notebook output clean; capture stderr so a failure can
    # be surfaced instead of silently ignoring the pip exit status (the
    # previous version discarded the CompletedProcess entirely).
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q',
         'PyPDF2', 'pdfplumber', 'Pillow'],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        print(f"⚠️ 依赖安装失败 (exit {result.returncode}): {result.stderr.strip()}")
    print("✅ 环境设置完成")
def process_uploaded_files(pdf_path: str = None, image_paths: List[str] = None):
    """
    Process already-uploaded files, vectorize them, and persist the vector
    store under the project directory.

    Supports de-duplication via per-file MD5 hashes recorded in a JSON
    metadata file, so the same PDF is not re-embedded on a second run.

    Args:
        pdf_path: Path to a PDF file to index (optional).
        image_paths: List of image paths.  NOTE(review): accepted but never
            used anywhere in this function — image vectorization appears
            unimplemented here; confirm against the multimodal pipeline.

    Returns:
        (rag_system, doc_processor) on success, or (None, None) on failure.
    """
    import hashlib
    import json
    # Persistence locations are relative to this script's own directory,
    # not the CWD, so reruns find the same store regardless of launch dir.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    persist_dir = os.path.join(current_dir, 'chroma_db')
    metadata_file = os.path.join(current_dir, 'document_metadata.json')
    os.makedirs(persist_dir, exist_ok=True)
    print(f"💾 向量数据库持久化目录: {persist_dir}")
    # Load the metadata of previously processed files (used for de-dup).
    processed_files = {}
    if os.path.exists(metadata_file):
        try:
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
                processed_files = metadata.get('processed_files', {})
            print(f"📊 已加载元数据,发现 {len(processed_files)} 个已处理的文件")
        except Exception as e:
            # Corrupt/unreadable metadata is not fatal: fall back to an
            # empty registry and re-process everything.
            print(f"⚠️ 加载元数据失败: {e}")

    def get_file_hash(file_path: str) -> str:
        """Return the MD5 hex digest of the file, or None if unreadable."""
        if not os.path.exists(file_path):
            return None
        try:
            with open(file_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
                return file_hash
        except Exception as e:
            print(f"⚠️ 计算文件哈希失败: {e}")
            return None

    # If a vector DB already exists on disk, load it instead of rebuilding.
    if os.path.exists(persist_dir) and os.listdir(persist_dir):
        print("✅ 检测到已存在的向量数据库,加载中...")
        try:
            # Reload the persisted Chroma collection with the same embedding
            # model the project config specifies.
            from langchain_community.embeddings import HuggingFaceEmbeddings
            from langchain_community.vectorstores import Chroma
            from config import EMBEDDING_MODEL, COLLECTION_NAME
            embeddings = HuggingFaceEmbeddings(
                model_name=EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'}
            )
            vectorstore = Chroma(
                persist_directory=persist_dir,
                embedding_function=embeddings,
                collection_name=COLLECTION_NAME
            )
            retriever = vectorstore.as_retriever()
            # NOTE(review): _collection is a private Chroma attribute; used
            # here only to report the document-chunk count.
            print(f"✅ 已加载持久化的向量数据库,共 {vectorstore._collection.count()} 个文档块")
            # Initialize the document processor.
            doc_processor = DocumentProcessor()
            # Incrementally add the PDF only if it has not been seen before.
            if pdf_path and os.path.exists(pdf_path):
                file_hash = get_file_hash(pdf_path)
                if file_hash and file_hash in processed_files:
                    print(f"⏭️ PDF文件已处理过({pdf_path}),跳过")
                else:
                    print(f"🆕 检测到新PDF文件,正在添加: {pdf_path}")
                    try:
                        from langchain_community.document_loaders import PyPDFLoader
                        loader = PyPDFLoader(pdf_path)
                        docs = loader.load()
                        doc_splits = doc_processor.split_documents(docs)
                        # Append the new chunks to the existing vector store.
                        vectorstore.add_documents(doc_splits)
                        print(f"✅ 已添加 {len(doc_splits)} 个新文档块")
                        # Record the file in the de-dup metadata.
                        if file_hash:
                            processed_files[file_hash] = {
                                'path': pdf_path,
                                'type': 'pdf',
                                'chunks': len(doc_splits),
                                'processed_at': time.time()
                            }
                            with open(metadata_file, 'w', encoding='utf-8') as f:
                                json.dump({'processed_files': processed_files}, f, ensure_ascii=False, indent=2)
                            print(f"💾 元数据已更新")
                    except Exception as e:
                        print(f"⚠️ 添加新PDF失败: {e}")
        except Exception as e:
            # Loading failed: fall through to the fresh-creation path below.
            print(f"⚠️ 加载向量数据库失败: {e},将重新创建")
            vectorstore, retriever, doc_processor = None, None, None
    else:
        vectorstore, retriever, doc_processor = None, None, None

    # If nothing was loaded, create a brand-new vector database.
    if vectorstore is None:
        print("🔧 正在创建新的向量数据库...")
        # Initialize the document processor.
        doc_processor = DocumentProcessor()
        # Process the PDF file when one was supplied and exists.
        if pdf_path and os.path.exists(pdf_path):
            print(f"📄 处理PDF文件: {pdf_path}")
            try:
                from langchain_community.document_loaders import PyPDFLoader
                loader = PyPDFLoader(pdf_path)
                docs = loader.load()
                # Split the document into chunks.
                doc_splits = doc_processor.split_documents(docs)
                # Build the vector store (with persistence enabled).
                from langchain_community.embeddings import HuggingFaceEmbeddings
                from langchain_community.vectorstores import Chroma
                from config import EMBEDDING_MODEL, COLLECTION_NAME
                embeddings = HuggingFaceEmbeddings(
                    model_name=EMBEDDING_MODEL,
                    model_kwargs={'device': 'cpu'}
                )
                vectorstore = Chroma.from_documents(
                    documents=doc_splits,
                    embedding=embeddings,
                    collection_name=COLLECTION_NAME,
                    persist_directory=persist_dir  # persistence directory
                )
                retriever = vectorstore.as_retriever()
                print(f"✅ PDF处理完成,共 {len(doc_splits)} 个文档块")
                print(f"💾 向量数据库已持久化到: {persist_dir}")
                # Save the de-dup metadata for this freshly-indexed PDF.
                file_hash = get_file_hash(pdf_path)
                if file_hash:
                    processed_files[file_hash] = {
                        'path': pdf_path,
                        'type': 'pdf',
                        'chunks': len(doc_splits),
                        'processed_at': time.time()
                    }
                    with open(metadata_file, 'w', encoding='utf-8') as f:
                        json.dump({'processed_files': processed_files}, f, ensure_ascii=False, indent=2)
                    print(f"💾 元数据已保存")
            except Exception as e:
                print(f"❌ PDF处理失败: {e}")
                return None, None
        else:
            # No PDF available: fall back to the project's default knowledge base.
            print("📄 使用默认知识库...")
            try:
                vectorstore, retriever, doc_splits = doc_processor.setup_knowledge_base()
                # NOTE(review): mutating _persist_directory after creation
                # likely does NOT re-persist the store — verify this actually
                # writes anything to persist_dir.
                if vectorstore and hasattr(vectorstore, '_persist_directory'):
                    vectorstore._persist_directory = persist_dir
                    print(f"💾 默认知识库已持久化到: {persist_dir}")
            except Exception as e:
                print(f"❌ 默认知识库加载失败: {e}")
                return None, None

    # Initialize the adaptive RAG system and wire in our retriever/processor
    # (both on the system itself and on its workflow nodes).
    print("🤖 正在初始化自适应RAG系统...")
    rag_system = AdaptiveRAGSystem()
    rag_system.retriever = retriever
    rag_system.doc_processor = doc_processor
    rag_system.workflow_nodes.retriever = retriever
    rag_system.workflow_nodes.doc_processor = doc_processor
    return rag_system, doc_processor
def query_with_multimodal(rag_system: AdaptiveRAGSystem, query: str, image_paths: List[str] = None):
    """Run one query through the RAG system and pretty-print the outcome.

    Args:
        rag_system: Initialized RAG system instance.
        query: Question text to submit.
        image_paths: List of image paths (not forwarded to the query call).

    Returns:
        The raw result dict from ``rag_system.query``, or None on failure.
    """
    print(f"🔍 查询: {query}")
    try:
        # Submit the query and show the answer.
        result = rag_system.query(query)
        print("\n🎯 答案:")
        print(result['answer'])
        # Report retrieval-quality metrics when the system provided them.
        metrics = result.get('retrieval_metrics')
        if metrics:
            print("\n📊 检索评估:")
            print(f" - 检索耗时: {metrics.get('latency', 0):.4f}秒")
            print(f" - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
            print(f" - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
            print(f" - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
            print(f" - MAP: {metrics.get('map_score', 0):.4f}")
        return result
    except Exception as e:
        print(f"❌ 查询失败: {e}")
        return None
def scan_and_copy_files():
    """Scan /kaggle/input/ and copy PDF/image files into /kaggle/working/."""
    import shutil
    source_root = '/kaggle/input'
    target_root = '/kaggle/working'
    if not os.path.exists(source_root):
        print("⚠️ /kaggle/input/ 目录不存在,跳过文件扫描")
        return
    print("📂 扫描 /kaggle/input/ 目录...")
    pdf_names = []
    image_names = []
    image_exts = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')
    # Walk the whole input tree so files in dataset subfolders are found too.
    for folder, _subdirs, names in os.walk(source_root):
        for name in names:
            # Hidden files and empty names are ignored.
            if not name or name.startswith('.'):
                continue
            # Debug: show every candidate file encountered.
            print(f" 🔍 扫描到: {name}")
            source = os.path.join(folder, name)
            target = os.path.join(target_root, name)
            lowered = name.lower()
            try:
                # Case-insensitive extension match (.pdf, .PDF, .Pdf, ...).
                if lowered.endswith('.pdf'):
                    shutil.copy(source, target)
                    pdf_names.append(name)
                    print(f" ✅ 复制 PDF: {name}")
                elif lowered.endswith(image_exts):
                    shutil.copy(source, target)
                    image_names.append(name)
                    print(f" ✅ 复制图片: {name}")
                else:
                    print(f" ⚪ 跳过非目标文件: {name}")
            except Exception as e:
                print(f" ⚠️ 复制文件失败 {name}: {e}")
    if pdf_names or image_names:
        print(f"\n📁 复制完成: {len(pdf_names)} 个 PDF, {len(image_names)} 张图片")
    else:
        print("⚠️ 未找到 PDF 或图片文件")
        print("\n🔍 请检查:")
        print(" 1. 文件是否已上传到 Kaggle")
        print(" 2. 文件是否在 /kaggle/input/ 目录下")
        print(" 3. 文件扩展名是否正确 (.pdf, .jpg, .png 等)")
def main():
    """Entry point: prepare the environment, collect uploaded files, build
    the RAG system, and run sample text / multimodal queries.
    """
    print("🚀 Kaggle简化多模态测试")
    print("="*50)
    # Install dependencies.
    setup_kaggle_environment()
    # Copy files from /kaggle/input/ into /kaggle/working/.
    scan_and_copy_files()
    # Inspect the working directory for usable inputs.
    working_dir = '/kaggle/working'
    # Filter to valid PDFs / images (skip hidden entries and directories).
    try:
        all_files = os.listdir(working_dir)
        pdf_files = [
            f for f in all_files
            if f.lower().endswith('.pdf')  # case-insensitive extension
            and not f.startswith('.')  # exclude hidden files
            and os.path.isfile(os.path.join(working_dir, f))  # real file
        ]
        image_files = [
            f for f in all_files
            if any(f.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp'])
            and not f.startswith('.')  # exclude hidden files
            and os.path.isfile(os.path.join(working_dir, f))  # real file
        ]
    except Exception as e:
        print(f"❌ 扫描文件时出错: {e}")
        pdf_files = []
        image_files = []
        all_files = []
    print(f"\n📁 /kaggle/working/ 中的文件:")
    # Debug: list every directory entry with its classification.
    print("\n🔍 详细调试信息:")
    print(f" 目录中总共 {len(all_files)} 个项目")
    for f in all_files:
        f_path = os.path.join(working_dir, f)
        is_file = os.path.isfile(f_path)
        is_dir = os.path.isdir(f_path)
        f_lower = f.lower()
        if f_lower.endswith('.pdf'):
            file_size = os.path.getsize(f_path) if is_file else 0
            print(f" 📄 {f}: 是文件={is_file}, 大小={file_size/1024:.1f}KB, 长度={len(f)}")
        elif any(f_lower.endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']):
            file_size = os.path.getsize(f_path) if is_file else 0
            print(f" 🖼️ {f}: 是文件={is_file}, 大小={file_size/1024:.1f}KB")
        else:
            print(f" ⚪ {f}: 类型={'[目录]' if is_dir else '[文件]'}")
    print(f"\n📊 过滤结果:")
    print(f" - PDF文件: {len(pdf_files)} 个")
    for pdf in pdf_files:
        pdf_path = os.path.join(working_dir, pdf)
        file_size = os.path.getsize(pdf_path) if os.path.exists(pdf_path) else 0
        print(f" * {pdf} ({file_size/1024:.1f} KB)")
    print(f" - 图片文件: {len(image_files)} 个")
    for img in image_files:
        img_path = os.path.join(working_dir, img)
        file_size = os.path.getsize(img_path) if os.path.exists(img_path) else 0
        print(f" * {img} ({file_size/1024:.1f} KB)")
    if not pdf_files and not image_files:
        # Nothing to process: print upload instructions and bail out.
        print("\n💡 使用说明:")
        print(" 1. 在 Kaggle Notebook 右侧点击 '+ Add data'")
        print(" 2. 选择 'Upload' 标签")
        print(" 3. 上传你的 PDF 和图片文件")
        print(" 4. 重新运行此脚本")
        print("\n🔍 当前目录内容:")
        try:
            print(f" {os.listdir(working_dir)}")
        except Exception:
            # Fix: was a bare `except:` which also swallowed SystemExit and
            # KeyboardInterrupt; this listing is best-effort debug output.
            pass
        return
    # Validate the first PDF path before using it.
    if pdf_files:
        pdf_path = os.path.join(working_dir, pdf_files[0])
        if not os.path.exists(pdf_path) or not os.path.isfile(pdf_path):
            print(f"❌ PDF 文件路径无效: {pdf_path}")
            pdf_path = None
    else:
        pdf_path = None
    # Keep only image paths that still exist as regular files.
    if image_files:
        image_paths = []
        for img in image_files:
            img_path = os.path.join(working_dir, img)
            if os.path.exists(img_path) and os.path.isfile(img_path):
                image_paths.append(img_path)
        image_paths = image_paths if image_paths else None
    else:
        image_paths = None
    rag_system, doc_processor = process_uploaded_files(pdf_path, image_paths)
    if not rag_system:
        print("❌ 系统初始化失败")
        return
    # Sample queries.
    print("\n" + "="*50)
    print("🧪 示例查询测试")
    print("="*50)
    # Text-only query example.
    query1 = "请总结文档的主要内容"
    query_with_multimodal(rag_system, query1, image_paths)
    # If images exist and multimodal mode is enabled, run a multimodal query.
    if image_paths and ENABLE_MULTIMODAL:
        print("\n" + "="*50)
        print("🖼️ 多模态查询测试")
        print("="*50)
        query2 = "请结合图片内容,解释文档中的相关概念"
        query_with_multimodal(rag_system, query2, image_paths)
    print("\n" + "="*50)
    print("✅ 测试完成")
    print("="*50)
    print("\n💡 您可以继续使用以下代码进行自定义查询:")
    print("```python")
    print("# 自定义查询")
    print("custom_query = '您的问题'")
    print("query_with_multimodal(rag_system, custom_query, image_paths)")
    print("```")
# Run the full demo only when executed as a script (not on import).
if __name__ == "__main__":
    main()