Spaces:
Paused
Paused
File size: 7,463 Bytes
399f3c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
#!/usr/bin/env python3
"""
重排功能测试脚本
演示不同重排策略的效果
"""
import sys
import os
sys.path.append(os.path.dirname(__file__))
from document_processor import DocumentProcessor
from reranker import *
from langchain.schema import Document
import time
def create_test_documents():
"""创建测试文档"""
return [
Document(
page_content="人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。",
metadata={"source": "ai_intro.txt", "category": "AI基础"}
),
Document(
page_content="机器学习是人工智能的一个重要子领域,通过算法让计算机从数据中学习模式和规律。",
metadata={"source": "ml_basics.txt", "category": "机器学习"}
),
Document(
page_content="深度学习是机器学习的一个分支,使用多层神经网络来模拟人脑的学习过程。",
metadata={"source": "dl_guide.txt", "category": "深度学习"}
),
Document(
page_content="自然语言处理(NLP)是人工智能领域的一个重要分支,专注于使计算机理解和处理人类语言。",
metadata={"source": "nlp_overview.txt", "category": "自然语言处理"}
),
Document(
page_content="计算机视觉是人工智能的另一个重要领域,使计算机能够识别和理解图像和视频内容。",
metadata={"source": "cv_intro.txt", "category": "计算机视觉"}
),
Document(
page_content="强化学习是机器学习的一种类型,通过与环境交互来学习最优的行为策略。",
metadata={"source": "rl_basics.txt", "category": "强化学习"}
),
Document(
page_content="今天的天气非常好,阳光明媚,适合外出游玩和运动。",
metadata={"source": "weather.txt", "category": "天气"}
),
Document(
page_content="区块链是一种分布式账本技术,具有去中心化、不可篡改等特点。",
metadata={"source": "blockchain.txt", "category": "区块链"}
)
]
def test_reranker_comparison():
"""比较不同重排器的效果"""
print("🔍 重排器效果比较测试")
print("=" * 60)
# 创建测试数据
query = "什么是人工智能和机器学习?"
documents = create_test_documents()
# 创建一个简单的嵌入模型(用于测试)
try:
from langchain_community.embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2",
model_kwargs={'device': 'cpu'}
)
print("✅ 成功加载嵌入模型")
except Exception as e:
print(f"❌ 嵌入模型加载失败: {e}")
print("将使用基础重排器进行测试")
embeddings = None
# 测试不同的重排器
rerankers = []
# TF-IDF重排器
rerankers.append(("TF-IDF", TFIDFReranker()))
# BM25重排器
rerankers.append(("BM25", BM25Reranker()))
if embeddings:
# 语义重排器
rerankers.append(("语义相似度", SemanticReranker(embeddings)))
# 混合重排器
rerankers.append(("混合策略", HybridReranker(embeddings)))
# 多样性重排器
rerankers.append(("多样性优化", DiversityReranker(embeddings)))
# 执行测试
for name, reranker in rerankers:
print(f"\n📊 {name} 重排结果:")
print("-" * 40)
start_time = time.time()
try:
results = reranker.rerank(query, documents, top_k=5)
end_time = time.time()
print(f"⏱️ 处理时间: {(end_time - start_time)*1000:.2f}ms")
for i, (doc, score) in enumerate(results, 1):
content = doc.page_content[:80] + "..." if len(doc.page_content) > 80 else doc.page_content
category = doc.metadata.get('category', '未知')
print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
except Exception as e:
print(f"❌ 重排失败: {e}")
def test_reranking_with_embeddings():
"""测试带嵌入的重排功能"""
print("\n\n🧠 嵌入模型重排测试")
print("=" * 60)
try:
# 创建文档处理器
processor = DocumentProcessor()
# 创建测试文档
test_docs = create_test_documents()
# 测试查询
queries = [
"人工智能的定义是什么?",
"机器学习和深度学习的区别",
"自然语言处理的应用",
"今天天气怎么样?"
]
for query in queries:
print(f"\n🔍 查询: {query}")
print("-" * 30)
if processor.reranker:
# 使用重排功能
results = processor.reranker.rerank(query, test_docs, top_k=3)
for i, (doc, score) in enumerate(results, 1):
content = doc.page_content[:60] + "..." if len(doc.page_content) > 60 else doc.page_content
category = doc.metadata.get('category', '未知')
print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
else:
print("❌ 重排器未初始化")
except Exception as e:
print(f"❌ 测试失败: {e}")
def test_performance_comparison():
"""性能对比测试"""
print("\n\n⚡ 性能对比测试")
print("=" * 60)
documents = create_test_documents() * 10 # 增加文档数量
query = "人工智能技术的发展趋势"
# 测试不同重排器的性能
rerankers_config = [
("无重排", None),
("TF-IDF", TFIDFReranker()),
("BM25", BM25Reranker())
]
for name, reranker in rerankers_config:
times = []
# 多次测试取平均值
for _ in range(5):
start_time = time.time()
if reranker:
results = reranker.rerank(query, documents, top_k=5)
else:
# 模拟无重排的情况
results = documents[:5]
end_time = time.time()
times.append((end_time - start_time) * 1000)
avg_time = sum(times) / len(times)
print(f"{name}: 平均处理时间 {avg_time:.2f}ms (文档数: {len(documents)})")
def main():
"""主测试函数"""
print("🚀 向量重排功能综合测试")
print("=" * 80)
try:
# 基础重排器比较
test_reranker_comparison()
# 嵌入模型重排测试
test_reranking_with_embeddings()
# 性能对比测试
test_performance_comparison()
print("\n\n✅ 所有测试完成!")
print("=" * 80)
except KeyboardInterrupt:
print("\n❌ 测试被用户中断")
except Exception as e:
print(f"\n❌ 测试过程中发生错误: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main() |