Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| 重排功能测试脚本 | |
| 演示不同重排策略的效果 | |
| """ | |
| import sys | |
| import os | |
| sys.path.append(os.path.dirname(__file__)) | |
| from document_processor import DocumentProcessor | |
| from reranker import * | |
| from langchain.schema import Document | |
| import time | |
def create_test_documents():
    """Build the fixed corpus used by every test in this script.

    The first six entries cover AI-related topics (AI basics, machine
    learning, deep learning, NLP, computer vision, reinforcement learning).
    The last two (weather, blockchain) are unrelated to AI and appear to be
    included as off-topic distractors for the AI-themed queries.

    Returns:
        list[Document]: a new list of eight ``Document`` objects on every
        call, each with ``source`` and ``category`` metadata keys.
    """
    # (page_content, source file name, category) in output order.
    corpus = (
        (
            "人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。",
            "ai_intro.txt",
            "AI基础",
        ),
        (
            "机器学习是人工智能的一个重要子领域,通过算法让计算机从数据中学习模式和规律。",
            "ml_basics.txt",
            "机器学习",
        ),
        (
            "深度学习是机器学习的一个分支,使用多层神经网络来模拟人脑的学习过程。",
            "dl_guide.txt",
            "深度学习",
        ),
        (
            "自然语言处理(NLP)是人工智能领域的一个重要分支,专注于使计算机理解和处理人类语言。",
            "nlp_overview.txt",
            "自然语言处理",
        ),
        (
            "计算机视觉是人工智能的另一个重要领域,使计算机能够识别和理解图像和视频内容。",
            "cv_intro.txt",
            "计算机视觉",
        ),
        (
            "强化学习是机器学习的一种类型,通过与环境交互来学习最优的行为策略。",
            "rl_basics.txt",
            "强化学习",
        ),
        (
            "今天的天气非常好,阳光明媚,适合外出游玩和运动。",
            "weather.txt",
            "天气",
        ),
        (
            "区块链是一种分布式账本技术,具有去中心化、不可篡改等特点。",
            "blockchain.txt",
            "区块链",
        ),
    )
    return [
        Document(
            page_content=text,
            metadata={"source": source_name, "category": label},
        )
        for text, source_name, label in corpus
    ]
def test_reranker_comparison():
    """Compare the top-5 output of several rerankers on one fixed query.

    The TF-IDF and BM25 rerankers always run. If the HuggingFace embedding
    model loads, the semantic, hybrid and diversity rerankers (which need
    embeddings) run as well. For each reranker, prints the elapsed time in
    milliseconds and the ranked ``(document, score)`` pairs. A failure in one
    reranker is reported and the remaining rerankers still run.
    """
    print("🔍 重排器效果比较测试")
    print("=" * 60)

    # Test data.
    query = "什么是人工智能和机器学习?"
    documents = create_test_documents()

    # Load a small embedding model. If that fails for any reason, fall back
    # to the rerankers that need no embeddings.
    # NOTE(review): all-MiniLM-L6-v2 appears to be an English-oriented model,
    # while the query and corpus here are Chinese; semantic scores may be
    # weak. Consider a multilingual model — confirm against the model card.
    try:
        from langchain_community.embeddings import HuggingFaceEmbeddings
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'}
        )
        print("✅ 成功加载嵌入模型")
    except Exception as e:
        print(f"❌ 嵌入模型加载失败: {e}")
        print("将使用基础重排器进行测试")
        embeddings = None

    # Lexical rerankers need no embeddings.
    rerankers = [
        ("TF-IDF", TFIDFReranker()),
        ("BM25", BM25Reranker()),
    ]
    if embeddings is not None:
        # Semantic similarity, hybrid strategy, and diversity optimization.
        rerankers += [
            ("语义相似度", SemanticReranker(embeddings)),
            ("混合策略", HybridReranker(embeddings)),
            ("多样性优化", DiversityReranker(embeddings)),
        ]

    for name, reranker in rerankers:
        print(f"\n📊 {name} 重排结果:")
        print("-" * 40)
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which makes it the appropriate clock for elapsed-time measurement.
        start_time = time.perf_counter()
        try:
            results = reranker.rerank(query, documents, top_k=5)
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            print(f"⏱️ 处理时间: {elapsed_ms:.2f}ms")
            for i, (doc, score) in enumerate(results, 1):
                text = doc.page_content
                content = text[:80] + "..." if len(text) > 80 else text
                category = doc.metadata.get('category', '未知')
                print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
        except Exception as e:
            # Report and continue with the next reranker.
            print(f"❌ 重排失败: {e}")
def test_reranking_with_embeddings():
    """Run the reranker owned by a ``DocumentProcessor`` over several queries.

    For each query, prints the top-3 ``(document, score)`` pairs returned by
    ``processor.reranker.rerank``. If the processor's ``reranker`` is falsy,
    reports this once and returns, rather than repeating the same message
    under every query. Any exception is reported and ends this test.
    """
    print("\n\n🧠 嵌入模型重排测试")
    print("=" * 60)
    try:
        # Create the document processor.
        processor = DocumentProcessor()

        # The reranker does not depend on the query, so check it once up
        # front instead of inside the loop.
        reranker = processor.reranker
        if not reranker:
            print("❌ 重排器未初始化")
            return

        test_docs = create_test_documents()
        queries = [
            "人工智能的定义是什么?",
            "机器学习和深度学习的区别",
            "自然语言处理的应用",
            "今天天气怎么样?"
        ]

        for query in queries:
            print(f"\n🔍 查询: {query}")
            print("-" * 30)
            results = reranker.rerank(query, test_docs, top_k=3)
            for i, (doc, score) in enumerate(results, 1):
                text = doc.page_content
                content = text[:60] + "..." if len(text) > 60 else text
                category = doc.metadata.get('category', '未知')
                print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
    except Exception as e:
        print(f"❌ 测试失败: {e}")
def test_performance_comparison(repeats=5):
    """Time the lexical rerankers against a no-reranking baseline.

    The corpus from ``create_test_documents`` is replicated 10 times so that
    timings are measurable. Each configuration runs ``repeats`` times, and
    the mean elapsed time per run is printed in milliseconds.

    Args:
        repeats: number of timed runs per configuration. Defaults to 5,
            the previously hard-coded value.

    Raises:
        ValueError: if ``repeats`` is less than 1, since no mean can be
            computed.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    print("\n\n⚡ 性能对比测试")
    print("=" * 60)
    documents = create_test_documents() * 10  # enlarge the corpus
    query = "人工智能技术的发展趋势"

    # None marks the no-reranking baseline.
    rerankers_config = [
        ("无重排", None),
        ("TF-IDF", TFIDFReranker()),
        ("BM25", BM25Reranker())
    ]

    for name, reranker in rerankers_config:
        times = []
        for _ in range(repeats):
            # perf_counter is monotonic and high-resolution, so it suits
            # benchmarking better than time.time().
            start_time = time.perf_counter()
            if reranker is not None:
                reranker.rerank(query, documents, top_k=5)
            else:
                # Baseline: simulate "no rerank" by taking the first 5 docs.
                _ = documents[:5]
            times.append((time.perf_counter() - start_time) * 1000)
        # NOTE(review): the first run may include one-off warm-up cost;
        # consider min() or discarding it if timings look noisy.
        avg_time = sum(times) / len(times)
        print(f"{name}: 平均处理时间 {avg_time:.2f}ms (文档数: {len(documents)})")
def main():
    """Run every reranking test in sequence and report the overall outcome.

    The individual test functions report their own recoverable errors, so
    this function only sees exceptions that escape them.

    Returns:
        int: process exit status. 0 when all tests ran to completion,
        130 (the conventional shell status for SIGINT) on Ctrl-C, and 1 when
        an unexpected exception escaped a test function.
    """
    print("🚀 向量重排功能综合测试")
    print("=" * 80)
    try:
        # Basic reranker comparison.
        test_reranker_comparison()
        # Reranking through the DocumentProcessor's embedding-based reranker.
        test_reranking_with_embeddings()
        # Performance comparison.
        test_performance_comparison()
    except KeyboardInterrupt:
        print("\n❌ 测试被用户中断")
        return 130
    except Exception as e:
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
        return 1
    print("\n\n✅ 所有测试完成!")
    print("=" * 80)
    return 0
if __name__ == "__main__":
    # Propagate main()'s status to the shell so CI can detect failures;
    # sys.exit(None) exits with status 0, so this is also safe if main()
    # returns nothing.
    sys.exit(main())