File size: 17,279 Bytes
ddd99a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116d9c5
 
ddd99a5
 
 
 
 
116d9c5
 
ddd99a5
116d9c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd99a5
116d9c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd99a5
116d9c5
 
ddd99a5
116d9c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd99a5
 
116d9c5
 
ddd99a5
116d9c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd99a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae2e9ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94f5b16
 
de81fb9
df14376
 
 
de81fb9
ae2e9ee
 
 
de81fb9
df14376
 
de81fb9
 
 
 
 
 
 
df14376
 
de81fb9
 
ae2e9ee
 
 
 
 
df14376
 
 
 
ae2e9ee
ddd99a5
 
 
 
 
 
 
 
ae2e9ee
 
 
 
ddd99a5
de81fb9
94f5b16
de81fb9
 
df14376
94f5b16
de81fb9
 
94f5b16
de81fb9
 
 
 
 
 
 
 
 
 
 
 
 
94f5b16
ddd99a5
ae2e9ee
94f5b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd99a5
 
de81fb9
 
 
ddd99a5
 
 
de81fb9
 
 
ddd99a5
ae2e9ee
 
 
 
 
 
de81fb9
 
 
 
 
ae2e9ee
 
de81fb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd99a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
"""
Kaggle简化多模态测试脚本
用于在Kaggle环境中直接处理已上传的PDF和图片文件
"""

import os
import sys
import subprocess
import time
from typing import List, Dict, Any

# 添加项目路径
sys.path.insert(0, '/kaggle/working/adaptive_RAG')

# 导入项目模块
from document_processor import DocumentProcessor
from main import AdaptiveRAGSystem
from config import ENABLE_MULTIMODAL, SUPPORTED_IMAGE_FORMATS

def setup_kaggle_environment():
    """设置Kaggle环境"""
    print("🔧 设置Kaggle环境...")
    
    # 安装必要的依赖
    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 
                   'PyPDF2', 'pdfplumber', 'Pillow'])
    
    print("✅ 环境设置完成")

def process_uploaded_files(pdf_path: str = None, image_paths: List[str] = None):
    """
    处理已上传的文件,向量化并持久化到项目目录
    支持文件去重,避免重复处理
    
    Args:
        pdf_path: PDF文件路径
        image_paths: 图片路径列表
    """
    import hashlib
    import json
    
    # 设置向量数据库持久化目录(相对路径)
    # 获取当前脚本所在目录
    current_dir = os.path.dirname(os.path.abspath(__file__))
    persist_dir = os.path.join(current_dir, 'chroma_db')
    metadata_file = os.path.join(current_dir, 'document_metadata.json')
    os.makedirs(persist_dir, exist_ok=True)
    
    print(f"💾 向量数据库持久化目录: {persist_dir}")
    
    # 加载已处理文件的元数据(用于去重)
    processed_files = {}
    if os.path.exists(metadata_file):
        try:
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
                processed_files = metadata.get('processed_files', {})
                print(f"📊 已加载元数据,发现 {len(processed_files)} 个已处理的文件")
        except Exception as e:
            print(f"⚠️  加载元数据失败: {e}")
    
    # 计算文件哈希值(用于去重检测)
    def get_file_hash(file_path: str) -> str:
        """计算文件的MD5哈希值"""
        if not os.path.exists(file_path):
            return None
        try:
            with open(file_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
            return file_hash
        except Exception as e:
            print(f"⚠️  计算文件哈希失败: {e}")
            return None
    
    # 检查是否已存在向量数据库
    if os.path.exists(persist_dir) and os.listdir(persist_dir):
        print("✅ 检测到已存在的向量数据库,加载中...")
        try:
            # 加载已存在的向量数据库
            from langchain_community.embeddings import HuggingFaceEmbeddings
            from langchain_community.vectorstores import Chroma
            from config import EMBEDDING_MODEL, COLLECTION_NAME
            
            embeddings = HuggingFaceEmbeddings(
                model_name=EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'}
            )
            
            vectorstore = Chroma(
                persist_directory=persist_dir,
                embedding_function=embeddings,
                collection_name=COLLECTION_NAME
            )
            
            retriever = vectorstore.as_retriever()
            print(f"✅ 已加载持久化的向量数据库,共 {vectorstore._collection.count()} 个文档块")
            
            # 初始化文档处理器
            doc_processor = DocumentProcessor()
            
            # 检查PDF文件是否需要处理
            if pdf_path and os.path.exists(pdf_path):
                file_hash = get_file_hash(pdf_path)
                if file_hash and file_hash in processed_files:
                    print(f"⏭️  PDF文件已处理过({pdf_path}),跳过")
                else:
                    print(f"🆕 检测到新PDF文件,正在添加: {pdf_path}")
                    try:
                        from langchain_community.document_loaders import PyPDFLoader
                        loader = PyPDFLoader(pdf_path)
                        docs = loader.load()
                        doc_splits = doc_processor.split_documents(docs)
                        
                        # 添加到现有向量数据库
                        vectorstore.add_documents(doc_splits)
                        print(f"✅ 已添加 {len(doc_splits)} 个新文档块")
                        
                        # 更新元数据
                        if file_hash:
                            processed_files[file_hash] = {
                                'path': pdf_path,
                                'type': 'pdf',
                                'chunks': len(doc_splits),
                                'processed_at': time.time()
                            }
                            with open(metadata_file, 'w', encoding='utf-8') as f:
                                json.dump({'processed_files': processed_files}, f, ensure_ascii=False, indent=2)
                            print(f"💾 元数据已更新")
                    except Exception as e:
                        print(f"⚠️  添加新PDF失败: {e}")
            
        except Exception as e:
            print(f"⚠️  加载向量数据库失败: {e},将重新创建")
            vectorstore, retriever, doc_processor = None, None, None
    else:
        vectorstore, retriever, doc_processor = None, None, None
    
    # 如果没有加载成功,则创建新的向量数据库
    if vectorstore is None:
        print("🔧 正在创建新的向量数据库...")
        
        # 初始化文档处理器
        doc_processor = DocumentProcessor()
        
        # 处理PDF文件
        if pdf_path and os.path.exists(pdf_path):
            print(f"📄 处理PDF文件: {pdf_path}")
            try:
                from langchain_community.document_loaders import PyPDFLoader
                loader = PyPDFLoader(pdf_path)
                docs = loader.load()
                
                # 分割文档
                doc_splits = doc_processor.split_documents(docs)
                
                # 创建向量数据库(带持久化)
                from langchain_community.embeddings import HuggingFaceEmbeddings
                from langchain_community.vectorstores import Chroma
                from config import EMBEDDING_MODEL, COLLECTION_NAME
                
                embeddings = HuggingFaceEmbeddings(
                    model_name=EMBEDDING_MODEL,
                    model_kwargs={'device': 'cpu'}
                )
                
                vectorstore = Chroma.from_documents(
                    documents=doc_splits,
                    embedding=embeddings,
                    collection_name=COLLECTION_NAME,
                    persist_directory=persist_dir  # 持久化目录
                )
                
                retriever = vectorstore.as_retriever()
                
                print(f"✅ PDF处理完成,共 {len(doc_splits)} 个文档块")
                print(f"💾 向量数据库已持久化到: {persist_dir}")
                
                # 保存元数据
                file_hash = get_file_hash(pdf_path)
                if file_hash:
                    processed_files[file_hash] = {
                        'path': pdf_path,
                        'type': 'pdf',
                        'chunks': len(doc_splits),
                        'processed_at': time.time()
                    }
                    with open(metadata_file, 'w', encoding='utf-8') as f:
                        json.dump({'processed_files': processed_files}, f, ensure_ascii=False, indent=2)
                    print(f"💾 元数据已保存")
                
            except Exception as e:
                print(f"❌ PDF处理失败: {e}")
                return None, None
        else:
            # 使用默认知识库
            print("📄 使用默认知识库...")
            try:
                vectorstore, retriever, doc_splits = doc_processor.setup_knowledge_base()
                
                # 将默认知识库也持久化
                if vectorstore and hasattr(vectorstore, '_persist_directory'):
                    vectorstore._persist_directory = persist_dir
                    print(f"💾 默认知识库已持久化到: {persist_dir}")
                    
            except Exception as e:
                print(f"❌ 默认知识库加载失败: {e}")
                return None, None
    
    # 初始化RAG系统
    print("🤖 正在初始化自适应RAG系统...")
    rag_system = AdaptiveRAGSystem()
    
    # 更新RAG系统的检索器
    rag_system.retriever = retriever
    rag_system.doc_processor = doc_processor
    rag_system.workflow_nodes.retriever = retriever
    rag_system.workflow_nodes.doc_processor = doc_processor
    
    return rag_system, doc_processor

def query_with_multimodal(rag_system: AdaptiveRAGSystem, query: str, image_paths: List[str] = None):
    """
    执行多模态查询
    
    Args:
        rag_system: RAG系统实例
        query: 查询字符串
        image_paths: 图片路径列表
    """
    print(f"🔍 查询: {query}")
    
    try:
        # 执行查询
        result = rag_system.query(query)
        
        # 显示结果
        print("\n🎯 答案:")
        print(result['answer'])
        
        # 显示评估指标
        if result.get('retrieval_metrics'):
            metrics = result['retrieval_metrics']
            print("\n📊 检索评估:")
            print(f"   - 检索耗时: {metrics.get('latency', 0):.4f}秒")
            print(f"   - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
            print(f"   - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
            print(f"   - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
            print(f"   - MAP: {metrics.get('map_score', 0):.4f}")
        
        return result
    except Exception as e:
        print(f"❌ 查询失败: {e}")
        return None

def scan_and_copy_files():
    """扫描 /kaggle/input/ 并复制文件到 /kaggle/working/"""
    import shutil
    
    input_dir = '/kaggle/input'
    working_dir = '/kaggle/working'
    
    if not os.path.exists(input_dir):
        print("⚠️  /kaggle/input/ 目录不存在,跳过文件扫描")
        return
    
    print("📂 扫描 /kaggle/input/ 目录...")
    
    copied_pdfs = []
    copied_images = []
    
    # 递归扫描所有文件
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            # 跳过隐藏文件和空文件名
            if not file or file.startswith('.'):
                continue
            
            # 调试:显示所有文件
            print(f"   🔍 扫描到: {file}")
                
            src = os.path.join(root, file)
            dst = os.path.join(working_dir, file)
            
            try:
                # 修复:使用小写比较,支持 .pdf, .PDF, .Pdf 等
                if file.lower().endswith('.pdf'):
                    shutil.copy(src, dst)
                    copied_pdfs.append(file)
                    print(f"   ✅ 复制 PDF: {file}")
                elif any(file.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']):
                    shutil.copy(src, dst)
                    copied_images.append(file)
                    print(f"   ✅ 复制图片: {file}")
                else:
                    print(f"   ⚪ 跳过非目标文件: {file}")
            except Exception as e:
                print(f"   ⚠️  复制文件失败 {file}: {e}")
    
    if copied_pdfs or copied_images:
        print(f"\n📁 复制完成: {len(copied_pdfs)} 个 PDF, {len(copied_images)} 张图片")
    else:
        print("⚠️  未找到 PDF 或图片文件")
        print("\n🔍 请检查:")
        print("   1. 文件是否已上传到 Kaggle")
        print("   2. 文件是否在 /kaggle/input/ 目录下")
        print("   3. 文件扩展名是否正确 (.pdf, .jpg, .png 等)")

def main():
    """主函数"""
    print("🚀 Kaggle简化多模态测试")
    print("="*50)
    
    # 设置环境
    setup_kaggle_environment()
    
    # 从 /kaggle/input/ 复制文件到 /kaggle/working/
    scan_and_copy_files()
    
    # 检查文件
    working_dir = '/kaggle/working'
    
    # 过滤有效的PDF文件(排除隐藏文件)
    try:
        all_files = os.listdir(working_dir)
        
        # 修复:移除文件名长度限制,支持 .pdf 等短文件名
        pdf_files = [
            f for f in all_files 
            if f.lower().endswith('.pdf')  # 小写比较
            and not f.startswith('.')  # 排除隐藏文件
            and os.path.isfile(os.path.join(working_dir, f))  # 确保是文件
        ]
        image_files = [
            f for f in all_files 
            if any(f.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp'])
            and not f.startswith('.')  # 排除隐藏文件
            and os.path.isfile(os.path.join(working_dir, f))  # 确保是文件
        ]
    except Exception as e:
        print(f"❌ 扫描文件时出错: {e}")
        pdf_files = []
        image_files = []
        all_files = []
    
    print(f"\n📁 /kaggle/working/ 中的文件:")
    
    # 调试:详细显示所有文件和过滤过程
    print("\n🔍 详细调试信息:")
    print(f"   目录中总共 {len(all_files)} 个项目")
    for f in all_files:
        f_path = os.path.join(working_dir, f)
        is_file = os.path.isfile(f_path)
        is_dir = os.path.isdir(f_path)
        f_lower = f.lower()
        
        # 检查 PDF
        if f_lower.endswith('.pdf'):
            file_size = os.path.getsize(f_path) if is_file else 0
            print(f"   📄 {f}: 是文件={is_file}, 大小={file_size/1024:.1f}KB, 长度={len(f)}")
        # 检查图片
        elif any(f_lower.endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']):
            file_size = os.path.getsize(f_path) if is_file else 0
            print(f"   🖼️ {f}: 是文件={is_file}, 大小={file_size/1024:.1f}KB")
        else:
            print(f"   ⚪ {f}: 类型={'[目录]' if is_dir else '[文件]'}")
    
    print(f"\n📊 过滤结果:")
    print(f"   - PDF文件: {len(pdf_files)} 个")
    for pdf in pdf_files:
        pdf_path = os.path.join(working_dir, pdf)
        file_size = os.path.getsize(pdf_path) if os.path.exists(pdf_path) else 0
        print(f"     * {pdf} ({file_size/1024:.1f} KB)")
    
    print(f"   - 图片文件: {len(image_files)} 个")
    for img in image_files:
        img_path = os.path.join(working_dir, img)
        file_size = os.path.getsize(img_path) if os.path.exists(img_path) else 0
        print(f"     * {img} ({file_size/1024:.1f} KB)")
    
    if not pdf_files and not image_files:
        print("\n💡 使用说明:")
        print("   1. 在 Kaggle Notebook 右侧点击 '+ Add data'")
        print("   2. 选择 'Upload' 标签")
        print("   3. 上传你的 PDF 和图片文件")
        print("   4. 重新运行此脚本")
        print("\n🔍 当前目录内容:")
        try:
            print(f"   {os.listdir(working_dir)}")
        except:
            pass
        return
    
    # 处理文件(添加路径验证)
    if pdf_files:
        pdf_path = os.path.join(working_dir, pdf_files[0])
        if not os.path.exists(pdf_path) or not os.path.isfile(pdf_path):
            print(f"❌ PDF 文件路径无效: {pdf_path}")
            pdf_path = None
    else:
        pdf_path = None
    
    if image_files:
        image_paths = []
        for img in image_files:
            img_path = os.path.join(working_dir, img)
            if os.path.exists(img_path) and os.path.isfile(img_path):
                image_paths.append(img_path)
        image_paths = image_paths if image_paths else None
    else:
        image_paths = None
    
    rag_system, doc_processor = process_uploaded_files(pdf_path, image_paths)
    
    if not rag_system:
        print("❌ 系统初始化失败")
        return
    
    # 示例查询
    print("\n" + "="*50)
    print("🧪 示例查询测试")
    print("="*50)
    
    # 文本查询示例
    query1 = "请总结文档的主要内容"
    query_with_multimodal(rag_system, query1, image_paths)
    
    # 如果有图片,执行多模态查询
    if image_paths and ENABLE_MULTIMODAL:
        print("\n" + "="*50)
        print("🖼️ 多模态查询测试")
        print("="*50)
        
        query2 = "请结合图片内容,解释文档中的相关概念"
        query_with_multimodal(rag_system, query2, image_paths)
    
    print("\n" + "="*50)
    print("✅ 测试完成")
    print("="*50)
    print("\n💡 您可以继续使用以下代码进行自定义查询:")
    print("```python")
    print("# 自定义查询")
    print("custom_query = '您的问题'")
    print("query_with_multimodal(rag_system, custom_query, image_paths)")
    print("```")

if __name__ == "__main__":
    main()