# adaptive_rag/kaggle_simple_multimodal.py
# (GitHub page residue removed: author "lanny xu", commit "delete vectara", 116d9c5)
"""
Kaggle简化多模态测试脚本
用于在Kaggle环境中直接处理已上传的PDF和图片文件
"""
import os
import sys
import subprocess
import time
from typing import List, Dict, Any
# Add the project directory to the import path
sys.path.insert(0, '/kaggle/working/adaptive_RAG')
# Import project modules
from document_processor import DocumentProcessor
from main import AdaptiveRAGSystem
from config import ENABLE_MULTIMODAL, SUPPORTED_IMAGE_FORMATS
def setup_kaggle_environment():
    """Install the Python packages this script needs in a Kaggle kernel.

    Best-effort: a failed ``pip install`` is reported as a warning instead of
    raising, so the rest of the script can still run with whatever packages
    are already available in the environment.
    """
    print("🔧 设置Kaggle环境...")
    # Original code discarded the result, so a failed install was silently
    # reported as success. Inspect the return code and warn explicitly.
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q',
         'PyPDF2', 'pdfplumber', 'Pillow'],
        check=False,
    )
    if result.returncode != 0:
        print(f"⚠️ 依赖安装失败 (exit code {result.returncode})")
    print("✅ 环境设置完成")
def process_uploaded_files(pdf_path: str = None, image_paths: List[str] = None):
    """
    Vectorize uploaded files and persist them under the project directory.

    Uses an MD5-hash-based metadata file to deduplicate, so the same file is
    not embedded twice across runs. If a persisted Chroma store already
    exists it is reloaded and only new PDFs are appended; otherwise a fresh
    store is built (from the PDF, or from the default knowledge base).

    Args:
        pdf_path: Path to a PDF file, or None.
        image_paths: List of image paths, or None.
            NOTE(review): this parameter is never read inside this function —
            image ingestion appears unimplemented here; confirm against callers.

    Returns:
        (AdaptiveRAGSystem, DocumentProcessor) on success, or (None, None)
        when PDF processing / default knowledge-base loading fails.
    """
    import hashlib
    import json
    # Vector-store persistence directory, relative to this script's location.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    persist_dir = os.path.join(current_dir, 'chroma_db')
    metadata_file = os.path.join(current_dir, 'document_metadata.json')
    os.makedirs(persist_dir, exist_ok=True)
    print(f"💾 向量数据库持久化目录: {persist_dir}")
    # Load metadata about previously processed files (used for dedup).
    processed_files = {}
    if os.path.exists(metadata_file):
        try:
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
                processed_files = metadata.get('processed_files', {})
            print(f"📊 已加载元数据,发现 {len(processed_files)} 个已处理的文件")
        except Exception as e:
            # Corrupt/unreadable metadata is non-fatal: fall back to an empty map.
            print(f"⚠️ 加载元数据失败: {e}")

    def get_file_hash(file_path: str) -> str:
        """Return the MD5 hex digest of the file, or None if unreadable/missing."""
        if not os.path.exists(file_path):
            return None
        try:
            # NOTE: reads the whole file into memory — fine for typical PDFs,
            # but could be chunked for very large files.
            with open(file_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
            return file_hash
        except Exception as e:
            print(f"⚠️ 计算文件哈希失败: {e}")
            return None

    # Reuse an existing persisted vector store when one is present.
    # (persist_dir always exists after makedirs above; the listdir check
    # is what actually detects prior content.)
    if os.path.exists(persist_dir) and os.listdir(persist_dir):
        print("✅ 检测到已存在的向量数据库,加载中...")
        try:
            # Lazy imports: langchain is only needed once we touch the store.
            from langchain_community.embeddings import HuggingFaceEmbeddings
            from langchain_community.vectorstores import Chroma
            from config import EMBEDDING_MODEL, COLLECTION_NAME
            embeddings = HuggingFaceEmbeddings(
                model_name=EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'}
            )
            vectorstore = Chroma(
                persist_directory=persist_dir,
                embedding_function=embeddings,
                collection_name=COLLECTION_NAME
            )
            retriever = vectorstore.as_retriever()
            # NOTE(review): `_collection` is a private Chroma attribute — may
            # break across langchain/chromadb versions.
            print(f"✅ 已加载持久化的向量数据库,共 {vectorstore._collection.count()} 个文档块")
            doc_processor = DocumentProcessor()
            # Append the PDF only if its hash has not been seen before.
            if pdf_path and os.path.exists(pdf_path):
                file_hash = get_file_hash(pdf_path)
                if file_hash and file_hash in processed_files:
                    print(f"⏭️ PDF文件已处理过({pdf_path}),跳过")
                else:
                    print(f"🆕 检测到新PDF文件,正在添加: {pdf_path}")
                    try:
                        from langchain_community.document_loaders import PyPDFLoader
                        loader = PyPDFLoader(pdf_path)
                        docs = loader.load()
                        doc_splits = doc_processor.split_documents(docs)
                        # Add to the existing vector store (incremental update).
                        vectorstore.add_documents(doc_splits)
                        print(f"✅ 已添加 {len(doc_splits)} 个新文档块")
                        # Record the file in the dedup metadata.
                        if file_hash:
                            processed_files[file_hash] = {
                                'path': pdf_path,
                                'type': 'pdf',
                                'chunks': len(doc_splits),
                                'processed_at': time.time()
                            }
                            with open(metadata_file, 'w', encoding='utf-8') as f:
                                json.dump({'processed_files': processed_files}, f, ensure_ascii=False, indent=2)
                            print(f"💾 元数据已更新")
                    except Exception as e:
                        # Failure to add the new PDF keeps the loaded store usable.
                        print(f"⚠️ 添加新PDF失败: {e}")
        except Exception as e:
            # Loading failed: fall through to the fresh-build path below.
            print(f"⚠️ 加载向量数据库失败: {e},将重新创建")
            vectorstore, retriever, doc_processor = None, None, None
    else:
        vectorstore, retriever, doc_processor = None, None, None

    # Build a new vector store when none was loaded above.
    if vectorstore is None:
        print("🔧 正在创建新的向量数据库...")
        doc_processor = DocumentProcessor()
        if pdf_path and os.path.exists(pdf_path):
            print(f"📄 处理PDF文件: {pdf_path}")
            try:
                from langchain_community.document_loaders import PyPDFLoader
                loader = PyPDFLoader(pdf_path)
                docs = loader.load()
                # Split into chunks before embedding.
                doc_splits = doc_processor.split_documents(docs)
                from langchain_community.embeddings import HuggingFaceEmbeddings
                from langchain_community.vectorstores import Chroma
                from config import EMBEDDING_MODEL, COLLECTION_NAME
                embeddings = HuggingFaceEmbeddings(
                    model_name=EMBEDDING_MODEL,
                    model_kwargs={'device': 'cpu'}
                )
                vectorstore = Chroma.from_documents(
                    documents=doc_splits,
                    embedding=embeddings,
                    collection_name=COLLECTION_NAME,
                    persist_directory=persist_dir  # persistence directory
                )
                retriever = vectorstore.as_retriever()
                print(f"✅ PDF处理完成,共 {len(doc_splits)} 个文档块")
                print(f"💾 向量数据库已持久化到: {persist_dir}")
                # Save dedup metadata for this freshly processed PDF.
                file_hash = get_file_hash(pdf_path)
                if file_hash:
                    processed_files[file_hash] = {
                        'path': pdf_path,
                        'type': 'pdf',
                        'chunks': len(doc_splits),
                        'processed_at': time.time()
                    }
                    with open(metadata_file, 'w', encoding='utf-8') as f:
                        json.dump({'processed_files': processed_files}, f, ensure_ascii=False, indent=2)
                    print(f"💾 元数据已保存")
            except Exception as e:
                print(f"❌ PDF处理失败: {e}")
                return None, None
        else:
            # No PDF supplied: fall back to the project's default knowledge base.
            print("📄 使用默认知识库...")
            try:
                vectorstore, retriever, doc_splits = doc_processor.setup_knowledge_base()
                # NOTE(review): setting the private `_persist_directory` after
                # construction likely does NOT persist the store — confirm.
                if vectorstore and hasattr(vectorstore, '_persist_directory'):
                    vectorstore._persist_directory = persist_dir
                    print(f"💾 默认知识库已持久化到: {persist_dir}")
            except Exception as e:
                print(f"❌ 默认知识库加载失败: {e}")
                return None, None

    # Initialize the adaptive RAG system and wire in our retriever/processor.
    print("🤖 正在初始化自适应RAG系统...")
    rag_system = AdaptiveRAGSystem()
    rag_system.retriever = retriever
    rag_system.doc_processor = doc_processor
    rag_system.workflow_nodes.retriever = retriever
    rag_system.workflow_nodes.doc_processor = doc_processor
    return rag_system, doc_processor
def query_with_multimodal(rag_system: AdaptiveRAGSystem, query: str, image_paths: List[str] = None):
    """Run one query through the RAG system and print answer + retrieval metrics.

    Args:
        rag_system: An initialized RAG system instance.
        query: The question to ask.
        image_paths: Optional list of image paths.
            NOTE(review): not read inside this function — multimodal input
            appears unwired here; confirm against the RAG system API.

    Returns:
        The raw result dict from ``rag_system.query``, or None on any failure.
    """
    print(f"🔍 查询: {query}")
    try:
        result = rag_system.query(query)
        print("\n🎯 答案:")
        print(result['answer'])
        # Retrieval metrics are optional; only print the section when present.
        metrics = result.get('retrieval_metrics')
        if metrics:
            print("\n📊 检索评估:")
            print(f" - 检索耗时: {metrics.get('latency', 0):.4f}秒")
            print(f" - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
            print(f" - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
            print(f" - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
            print(f" - MAP: {metrics.get('map_score', 0):.4f}")
        return result
    except Exception as e:
        # Any failure (query itself or result formatting) is reported, not raised.
        print(f"❌ 查询失败: {e}")
        return None
def scan_and_copy_files():
    """Recursively scan /kaggle/input/ and copy PDFs and images into /kaggle/working/.

    Matching is case-insensitive on the extension. Hidden files are skipped.
    Returns None; progress and failures are reported via print.
    """
    import shutil
    input_dir = '/kaggle/input'
    working_dir = '/kaggle/working'
    if not os.path.exists(input_dir):
        print("⚠️ /kaggle/input/ 目录不存在,跳过文件扫描")
        return
    print("📂 扫描 /kaggle/input/ 目录...")
    image_exts = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')
    copied_pdfs, copied_images = [], []
    # Walk every nested dataset directory under the input root.
    for dirpath, _dirnames, filenames in os.walk(input_dir):
        for name in filenames:
            # Skip hidden files and empty names.
            if not name or name.startswith('.'):
                continue
            # Debug: show every candidate file.
            print(f" 🔍 扫描到: {name}")
            source = os.path.join(dirpath, name)
            target = os.path.join(working_dir, name)
            lowered = name.lower()
            try:
                if lowered.endswith('.pdf'):
                    shutil.copy(source, target)
                    copied_pdfs.append(name)
                    print(f" ✅ 复制 PDF: {name}")
                elif lowered.endswith(image_exts):
                    shutil.copy(source, target)
                    copied_images.append(name)
                    print(f" ✅ 复制图片: {name}")
                else:
                    print(f" ⚪ 跳过非目标文件: {name}")
            except Exception as e:
                print(f" ⚠️ 复制文件失败 {name}: {e}")
    if copied_pdfs or copied_images:
        print(f"\n📁 复制完成: {len(copied_pdfs)} 个 PDF, {len(copied_images)} 张图片")
    else:
        print("⚠️ 未找到 PDF 或图片文件")
        print("\n🔍 请检查:")
        print(" 1. 文件是否已上传到 Kaggle")
        print(" 2. 文件是否在 /kaggle/input/ 目录下")
        print(" 3. 文件扩展名是否正确 (.pdf, .jpg, .png 等)")
def main():
    """Main entry point: set up the environment, gather uploaded files,
    build/load the RAG system, and run example (multimodal) queries."""
    print("🚀 Kaggle简化多模态测试")
    print("="*50)
    # Set up the environment (installs dependencies).
    setup_kaggle_environment()
    # Copy files from /kaggle/input/ into /kaggle/working/.
    scan_and_copy_files()
    # Inspect the working directory for candidate files.
    working_dir = '/kaggle/working'
    # Filter valid PDFs/images (excluding hidden files).
    try:
        all_files = os.listdir(working_dir)
        # No filename-length restriction: short names like "a.pdf" are fine.
        pdf_files = [
            f for f in all_files
            if f.lower().endswith('.pdf')  # case-insensitive extension match
            and not f.startswith('.')  # exclude hidden files
            and os.path.isfile(os.path.join(working_dir, f))  # must be a regular file
        ]
        image_files = [
            f for f in all_files
            if any(f.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp'])
            and not f.startswith('.')  # exclude hidden files
            and os.path.isfile(os.path.join(working_dir, f))  # must be a regular file
        ]
    except Exception as e:
        print(f"❌ 扫描文件时出错: {e}")
        pdf_files = []
        image_files = []
        all_files = []
    print(f"\n📁 /kaggle/working/ 中的文件:")
    # Debug output: show every entry and how it was classified.
    print("\n🔍 详细调试信息:")
    print(f" 目录中总共 {len(all_files)} 个项目")
    for f in all_files:
        f_path = os.path.join(working_dir, f)
        is_file = os.path.isfile(f_path)
        is_dir = os.path.isdir(f_path)
        f_lower = f.lower()
        # PDF entries
        if f_lower.endswith('.pdf'):
            file_size = os.path.getsize(f_path) if is_file else 0
            print(f" 📄 {f}: 是文件={is_file}, 大小={file_size/1024:.1f}KB, 长度={len(f)}")
        # Image entries
        elif any(f_lower.endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']):
            file_size = os.path.getsize(f_path) if is_file else 0
            print(f" 🖼️ {f}: 是文件={is_file}, 大小={file_size/1024:.1f}KB")
        else:
            print(f" ⚪ {f}: 类型={'[目录]' if is_dir else '[文件]'}")
    print(f"\n📊 过滤结果:")
    print(f" - PDF文件: {len(pdf_files)} 个")
    for pdf in pdf_files:
        pdf_path = os.path.join(working_dir, pdf)
        file_size = os.path.getsize(pdf_path) if os.path.exists(pdf_path) else 0
        print(f" * {pdf} ({file_size/1024:.1f} KB)")
    print(f" - 图片文件: {len(image_files)} 个")
    for img in image_files:
        img_path = os.path.join(working_dir, img)
        file_size = os.path.getsize(img_path) if os.path.exists(img_path) else 0
        print(f" * {img} ({file_size/1024:.1f} KB)")
    # Nothing usable found: print upload instructions and bail out.
    if not pdf_files and not image_files:
        print("\n💡 使用说明:")
        print(" 1. 在 Kaggle Notebook 右侧点击 '+ Add data'")
        print(" 2. 选择 'Upload' 标签")
        print(" 3. 上传你的 PDF 和图片文件")
        print(" 4. 重新运行此脚本")
        print("\n🔍 当前目录内容:")
        try:
            print(f" {os.listdir(working_dir)}")
        except:
            # NOTE(review): bare except — best-effort debug print; consider
            # narrowing to `except OSError`.
            pass
        return
    # Resolve the first PDF (with path validation); only one PDF is processed.
    if pdf_files:
        pdf_path = os.path.join(working_dir, pdf_files[0])
        if not os.path.exists(pdf_path) or not os.path.isfile(pdf_path):
            print(f"❌ PDF 文件路径无效: {pdf_path}")
            pdf_path = None
    else:
        pdf_path = None
    # Resolve all valid image paths; normalize an empty list to None.
    if image_files:
        image_paths = []
        for img in image_files:
            img_path = os.path.join(working_dir, img)
            if os.path.exists(img_path) and os.path.isfile(img_path):
                image_paths.append(img_path)
        image_paths = image_paths if image_paths else None
    else:
        image_paths = None
    rag_system, doc_processor = process_uploaded_files(pdf_path, image_paths)
    if not rag_system:
        print("❌ 系统初始化失败")
        return
    # Example queries.
    print("\n" + "="*50)
    print("🧪 示例查询测试")
    print("="*50)
    # Text-only query example.
    query1 = "请总结文档的主要内容"
    query_with_multimodal(rag_system, query1, image_paths)
    # Multimodal query example, only when images exist and the feature is on.
    if image_paths and ENABLE_MULTIMODAL:
        print("\n" + "="*50)
        print("🖼️ 多模态查询测试")
        print("="*50)
        query2 = "请结合图片内容,解释文档中的相关概念"
        query_with_multimodal(rag_system, query2, image_paths)
    print("\n" + "="*50)
    print("✅ 测试完成")
    print("="*50)
    print("\n💡 您可以继续使用以下代码进行自定义查询:")
    print("```python")
    print("# 自定义查询")
    print("custom_query = '您的问题'")
    print("query_with_multimodal(rag_system, custom_query, image_paths)")
    print("```")
if __name__ == "__main__":
main()