adaptive_rag / test_reranking.py
lanny xu
Initial commit
399f3c6
raw
history blame
7.46 kB
#!/usr/bin/env python3
"""
重排功能测试脚本
演示不同重排策略的效果
"""
import sys
import os
sys.path.append(os.path.dirname(__file__))
from document_processor import DocumentProcessor
from reranker import *
from langchain.schema import Document
import time
def create_test_documents():
"""创建测试文档"""
return [
Document(
page_content="人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。",
metadata={"source": "ai_intro.txt", "category": "AI基础"}
),
Document(
page_content="机器学习是人工智能的一个重要子领域,通过算法让计算机从数据中学习模式和规律。",
metadata={"source": "ml_basics.txt", "category": "机器学习"}
),
Document(
page_content="深度学习是机器学习的一个分支,使用多层神经网络来模拟人脑的学习过程。",
metadata={"source": "dl_guide.txt", "category": "深度学习"}
),
Document(
page_content="自然语言处理(NLP)是人工智能领域的一个重要分支,专注于使计算机理解和处理人类语言。",
metadata={"source": "nlp_overview.txt", "category": "自然语言处理"}
),
Document(
page_content="计算机视觉是人工智能的另一个重要领域,使计算机能够识别和理解图像和视频内容。",
metadata={"source": "cv_intro.txt", "category": "计算机视觉"}
),
Document(
page_content="强化学习是机器学习的一种类型,通过与环境交互来学习最优的行为策略。",
metadata={"source": "rl_basics.txt", "category": "强化学习"}
),
Document(
page_content="今天的天气非常好,阳光明媚,适合外出游玩和运动。",
metadata={"source": "weather.txt", "category": "天气"}
),
Document(
page_content="区块链是一种分布式账本技术,具有去中心化、不可篡改等特点。",
metadata={"source": "blockchain.txt", "category": "区块链"}
)
]
def test_reranker_comparison():
"""比较不同重排器的效果"""
print("🔍 重排器效果比较测试")
print("=" * 60)
# 创建测试数据
query = "什么是人工智能和机器学习?"
documents = create_test_documents()
# 创建一个简单的嵌入模型(用于测试)
try:
from langchain_community.embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2",
model_kwargs={'device': 'cpu'}
)
print("✅ 成功加载嵌入模型")
except Exception as e:
print(f"❌ 嵌入模型加载失败: {e}")
print("将使用基础重排器进行测试")
embeddings = None
# 测试不同的重排器
rerankers = []
# TF-IDF重排器
rerankers.append(("TF-IDF", TFIDFReranker()))
# BM25重排器
rerankers.append(("BM25", BM25Reranker()))
if embeddings:
# 语义重排器
rerankers.append(("语义相似度", SemanticReranker(embeddings)))
# 混合重排器
rerankers.append(("混合策略", HybridReranker(embeddings)))
# 多样性重排器
rerankers.append(("多样性优化", DiversityReranker(embeddings)))
# 执行测试
for name, reranker in rerankers:
print(f"\n📊 {name} 重排结果:")
print("-" * 40)
start_time = time.time()
try:
results = reranker.rerank(query, documents, top_k=5)
end_time = time.time()
print(f"⏱️ 处理时间: {(end_time - start_time)*1000:.2f}ms")
for i, (doc, score) in enumerate(results, 1):
content = doc.page_content[:80] + "..." if len(doc.page_content) > 80 else doc.page_content
category = doc.metadata.get('category', '未知')
print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
except Exception as e:
print(f"❌ 重排失败: {e}")
def test_reranking_with_embeddings():
"""测试带嵌入的重排功能"""
print("\n\n🧠 嵌入模型重排测试")
print("=" * 60)
try:
# 创建文档处理器
processor = DocumentProcessor()
# 创建测试文档
test_docs = create_test_documents()
# 测试查询
queries = [
"人工智能的定义是什么?",
"机器学习和深度学习的区别",
"自然语言处理的应用",
"今天天气怎么样?"
]
for query in queries:
print(f"\n🔍 查询: {query}")
print("-" * 30)
if processor.reranker:
# 使用重排功能
results = processor.reranker.rerank(query, test_docs, top_k=3)
for i, (doc, score) in enumerate(results, 1):
content = doc.page_content[:60] + "..." if len(doc.page_content) > 60 else doc.page_content
category = doc.metadata.get('category', '未知')
print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
else:
print("❌ 重排器未初始化")
except Exception as e:
print(f"❌ 测试失败: {e}")
def test_performance_comparison():
"""性能对比测试"""
print("\n\n⚡ 性能对比测试")
print("=" * 60)
documents = create_test_documents() * 10 # 增加文档数量
query = "人工智能技术的发展趋势"
# 测试不同重排器的性能
rerankers_config = [
("无重排", None),
("TF-IDF", TFIDFReranker()),
("BM25", BM25Reranker())
]
for name, reranker in rerankers_config:
times = []
# 多次测试取平均值
for _ in range(5):
start_time = time.time()
if reranker:
results = reranker.rerank(query, documents, top_k=5)
else:
# 模拟无重排的情况
results = documents[:5]
end_time = time.time()
times.append((end_time - start_time) * 1000)
avg_time = sum(times) / len(times)
print(f"{name}: 平均处理时间 {avg_time:.2f}ms (文档数: {len(documents)})")
def main():
"""主测试函数"""
print("🚀 向量重排功能综合测试")
print("=" * 80)
try:
# 基础重排器比较
test_reranker_comparison()
# 嵌入模型重排测试
test_reranking_with_embeddings()
# 性能对比测试
test_performance_comparison()
print("\n\n✅ 所有测试完成!")
print("=" * 80)
except KeyboardInterrupt:
print("\n❌ 测试被用户中断")
except Exception as e:
print(f"\n❌ 测试过程中发生错误: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()