Spaces:
Paused
Paused
lanny xu
commited on
Commit
·
db5bfaa
1
Parent(s):
8008bd3
modify reranker
Browse files- document_processor.py +21 -6
- reranker.py +124 -2
- test_crossencoder_reranking.py +229 -0
document_processor.py
CHANGED
|
@@ -62,14 +62,29 @@ class DocumentProcessor:
|
|
| 62 |
self._setup_reranker()
|
| 63 |
|
| 64 |
def _setup_reranker(self):
|
| 65 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 66 |
try:
|
| 67 |
-
#
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
except Exception as e:
|
| 71 |
-
print(f"⚠️
|
| 72 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def load_documents(self, urls=None):
|
| 75 |
"""从URL加载文档"""
|
|
|
|
| 62 |
self._setup_reranker()
|
| 63 |
|
| 64 |
def _setup_reranker(self):
|
| 65 |
+
"""
|
| 66 |
+
设置重排器
|
| 67 |
+
使用 CrossEncoder 提升重排准确率
|
| 68 |
+
"""
|
| 69 |
try:
|
| 70 |
+
# 使用 CrossEncoder 重排器 (准确率最高) ⭐
|
| 71 |
+
print("🔧 正在初始化 CrossEncoder 重排器...")
|
| 72 |
+
self.reranker = create_reranker(
|
| 73 |
+
'crossencoder',
|
| 74 |
+
model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', # 轻量级模型
|
| 75 |
+
max_length=512
|
| 76 |
+
)
|
| 77 |
+
print("✅ CrossEncoder 重排器初始化成功")
|
| 78 |
except Exception as e:
|
| 79 |
+
print(f"⚠️ CrossEncoder 初始化失败: {e}")
|
| 80 |
+
print("🔄 尝试回退到混合重排器...")
|
| 81 |
+
try:
|
| 82 |
+
# 回退到混合重排器
|
| 83 |
+
self.reranker = create_reranker('hybrid', self.embeddings)
|
| 84 |
+
print("✅ 混合重排器初始化成功")
|
| 85 |
+
except Exception as e2:
|
| 86 |
+
print(f"⚠️ 重排器初始化完全失败: {e2}")
|
| 87 |
+
print("⚠️ 将使用基础检索,不进行重排")
|
| 88 |
|
| 89 |
def load_documents(self, urls=None):
|
| 90 |
"""从URL加载文档"""
|
reranker.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
向量重排模块
|
| 3 |
实现多种重排策略以提高检索质量
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import torch
|
|
@@ -12,6 +13,14 @@ import re
|
|
| 12 |
from collections import Counter
|
| 13 |
import math
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
class DocumentReranker:
|
| 17 |
"""文档重排器基类"""
|
|
@@ -162,6 +171,86 @@ class SemanticReranker(DocumentReranker):
|
|
| 162 |
return results
|
| 163 |
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
class HybridReranker(DocumentReranker):
|
| 166 |
"""混合重排器,融合多种策略"""
|
| 167 |
|
|
@@ -302,26 +391,59 @@ class DiversityReranker(DocumentReranker):
|
|
| 302 |
|
| 303 |
|
| 304 |
def create_reranker(reranker_type: str, embeddings_model=None, **kwargs) -> DocumentReranker:
|
| 305 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
if reranker_type.lower() == 'tfidf':
|
| 308 |
return TFIDFReranker()
|
|
|
|
| 309 |
elif reranker_type.lower() == 'bm25':
|
| 310 |
return BM25Reranker(**kwargs)
|
|
|
|
| 311 |
elif reranker_type.lower() == 'semantic':
|
| 312 |
if embeddings_model is None:
|
| 313 |
raise ValueError("SemanticReranker requires embeddings_model")
|
| 314 |
return SemanticReranker(embeddings_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
elif reranker_type.lower() == 'hybrid':
|
| 316 |
if embeddings_model is None:
|
| 317 |
raise ValueError("HybridReranker requires embeddings_model")
|
| 318 |
return HybridReranker(embeddings_model, **kwargs)
|
|
|
|
| 319 |
elif reranker_type.lower() == 'diversity':
|
| 320 |
if embeddings_model is None:
|
| 321 |
raise ValueError("DiversityReranker requires embeddings_model")
|
| 322 |
return DiversityReranker(embeddings_model, **kwargs)
|
|
|
|
| 323 |
else:
|
| 324 |
-
raise ValueError(
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
|
| 327 |
# 使用示例
|
|
|
|
| 1 |
"""
|
| 2 |
向量重排模块
|
| 3 |
实现多种重排策略以提高检索质量
|
| 4 |
+
支持 CrossEncoder 深度重排
|
| 5 |
"""
|
| 6 |
|
| 7 |
import torch
|
|
|
|
| 13 |
from collections import Counter
|
| 14 |
import math
|
| 15 |
|
| 16 |
+
# CrossEncoder support
|
| 17 |
+
try:
|
| 18 |
+
from sentence_transformers import CrossEncoder as SentenceTransformerCrossEncoder
|
| 19 |
+
CROSSENCODER_AVAILABLE = True
|
| 20 |
+
except ImportError:
|
| 21 |
+
CROSSENCODER_AVAILABLE = False
|
| 22 |
+
print("⚠️ sentence-transformers not available. CrossEncoder reranking disabled.")
|
| 23 |
+
|
| 24 |
|
| 25 |
class DocumentReranker:
|
| 26 |
"""文档重排器基类"""
|
|
|
|
| 171 |
return results
|
| 172 |
|
| 173 |
|
| 174 |
+
class CrossEncoderReranker(DocumentReranker):
|
| 175 |
+
"""
|
| 176 |
+
基于 CrossEncoder 的重排器
|
| 177 |
+
使用联合编码,相比 Bi-Encoder 准确率提升 15-20%
|
| 178 |
+
适合精排阶段 (Top 20-100 文档)
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", max_length: int = 512):
|
| 182 |
+
"""
|
| 183 |
+
初始化 CrossEncoder 重排器
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
model_name: 模型名称,默认使用轻量级模型
|
| 187 |
+
- "cross-encoder/ms-marco-MiniLM-L-6-v2" (轻量级,推荐)
|
| 188 |
+
- "cross-encoder/ms-marco-MiniLM-L-12-v2" (平衡)
|
| 189 |
+
- "BAAI/bge-reranker-base" (中文优化)
|
| 190 |
+
- "BAAI/bge-reranker-large" (高精度)
|
| 191 |
+
max_length: 最大输入长度
|
| 192 |
+
"""
|
| 193 |
+
super().__init__()
|
| 194 |
+
self.name = "CrossEncoderReranker"
|
| 195 |
+
self.model_name = model_name
|
| 196 |
+
self.max_length = max_length
|
| 197 |
+
|
| 198 |
+
# 加载模型
|
| 199 |
+
if not CROSSENCODER_AVAILABLE:
|
| 200 |
+
raise ImportError(
|
| 201 |
+
"CrossEncoder requires sentence-transformers. "
|
| 202 |
+
"Install with: pip install sentence-transformers"
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
try:
|
| 206 |
+
print(f"🔧 加载 CrossEncoder 模型: {model_name}...")
|
| 207 |
+
self.model = SentenceTransformerCrossEncoder(model_name, max_length=max_length)
|
| 208 |
+
print(f"✅ CrossEncoder 模型加载成功")
|
| 209 |
+
except Exception as e:
|
| 210 |
+
print(f"❌ CrossEncoder 模型加载失败: {e}")
|
| 211 |
+
raise
|
| 212 |
+
|
| 213 |
+
def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
|
| 214 |
+
"""
|
| 215 |
+
使用 CrossEncoder 重新排序文档
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
query: 查询文本
|
| 219 |
+
documents: 候选文档列表
|
| 220 |
+
top_k: 返回结果数量
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
排序后的 (document, score) 元组列表
|
| 224 |
+
"""
|
| 225 |
+
if not documents:
|
| 226 |
+
return []
|
| 227 |
+
|
| 228 |
+
# 提取文档内容
|
| 229 |
+
doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
|
| 230 |
+
|
| 231 |
+
# 构造 [query, doc] 对
|
| 232 |
+
query_doc_pairs = [[query, doc_text] for doc_text in doc_texts]
|
| 233 |
+
|
| 234 |
+
# CrossEncoder 评分 - 联合编码
|
| 235 |
+
try:
|
| 236 |
+
scores = self.model.predict(query_doc_pairs)
|
| 237 |
+
|
| 238 |
+
# 排序
|
| 239 |
+
ranked_indices = np.argsort(scores)[::-1]
|
| 240 |
+
|
| 241 |
+
# 返回 top_k 结果
|
| 242 |
+
results = []
|
| 243 |
+
for i in ranked_indices[:top_k]:
|
| 244 |
+
results.append((documents[i], float(scores[i])))
|
| 245 |
+
|
| 246 |
+
return results
|
| 247 |
+
|
| 248 |
+
except Exception as e:
|
| 249 |
+
print(f"⚠️ CrossEncoder 重排失败: {e}")
|
| 250 |
+
# 回退到原始顺序
|
| 251 |
+
return [(doc, 0.0) for doc in documents[:top_k]]
|
| 252 |
+
|
| 253 |
+
|
| 254 |
class HybridReranker(DocumentReranker):
|
| 255 |
"""混合重排器,融合多种策略"""
|
| 256 |
|
|
|
|
| 391 |
|
| 392 |
|
| 393 |
def create_reranker(reranker_type: str, embeddings_model=None, **kwargs) -> DocumentReranker:
|
| 394 |
+
"""
|
| 395 |
+
工厂函数:创建指定类型的重排器
|
| 396 |
+
|
| 397 |
+
Args:
|
| 398 |
+
reranker_type: 重排器类型
|
| 399 |
+
- 'tfidf': TF-IDF 重排
|
| 400 |
+
- 'bm25': BM25 重排
|
| 401 |
+
- 'semantic': Bi-Encoder 语义重排
|
| 402 |
+
- 'crossencoder': CrossEncoder 重排 (推荐) ⭐
|
| 403 |
+
- 'hybrid': 混合重排
|
| 404 |
+
- 'diversity': 多样性重排
|
| 405 |
+
embeddings_model: 嵌入模型 (某些重排器需要)
|
| 406 |
+
**kwargs: 其他参数
|
| 407 |
+
- model_name: CrossEncoder 模型名称
|
| 408 |
+
- max_length: CrossEncoder 最大长度
|
| 409 |
+
- weights: 混合重排权重
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
DocumentReranker: 重排器实例
|
| 413 |
+
"""
|
| 414 |
|
| 415 |
if reranker_type.lower() == 'tfidf':
|
| 416 |
return TFIDFReranker()
|
| 417 |
+
|
| 418 |
elif reranker_type.lower() == 'bm25':
|
| 419 |
return BM25Reranker(**kwargs)
|
| 420 |
+
|
| 421 |
elif reranker_type.lower() == 'semantic':
|
| 422 |
if embeddings_model is None:
|
| 423 |
raise ValueError("SemanticReranker requires embeddings_model")
|
| 424 |
return SemanticReranker(embeddings_model)
|
| 425 |
+
|
| 426 |
+
elif reranker_type.lower() in ['crossencoder', 'cross_encoder', 'cross-encoder']:
|
| 427 |
+
# CrossEncoder 不需要 embeddings_model,使用自己的模型
|
| 428 |
+
model_name = kwargs.get('model_name', 'cross-encoder/ms-marco-MiniLM-L-6-v2')
|
| 429 |
+
max_length = kwargs.get('max_length', 512)
|
| 430 |
+
return CrossEncoderReranker(model_name=model_name, max_length=max_length)
|
| 431 |
+
|
| 432 |
elif reranker_type.lower() == 'hybrid':
|
| 433 |
if embeddings_model is None:
|
| 434 |
raise ValueError("HybridReranker requires embeddings_model")
|
| 435 |
return HybridReranker(embeddings_model, **kwargs)
|
| 436 |
+
|
| 437 |
elif reranker_type.lower() == 'diversity':
|
| 438 |
if embeddings_model is None:
|
| 439 |
raise ValueError("DiversityReranker requires embeddings_model")
|
| 440 |
return DiversityReranker(embeddings_model, **kwargs)
|
| 441 |
+
|
| 442 |
else:
|
| 443 |
+
raise ValueError(
|
| 444 |
+
f"Unknown reranker type: {reranker_type}. "
|
| 445 |
+
f"Available types: tfidf, bm25, semantic, crossencoder, hybrid, diversity"
|
| 446 |
+
)
|
| 447 |
|
| 448 |
|
| 449 |
# 使用示例
|
test_crossencoder_reranking.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
测试 CrossEncoder 重排功能
|
| 3 |
+
对比 Bi-Encoder vs CrossEncoder 的效果
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from reranker import create_reranker, TFIDFReranker, BM25Reranker, SemanticReranker, CrossEncoderReranker
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MockDoc:
|
| 10 |
+
"""模拟文档类"""
|
| 11 |
+
def __init__(self, content, metadata=None):
|
| 12 |
+
self.page_content = content
|
| 13 |
+
self.metadata = metadata or {}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MockEmbeddings:
|
| 17 |
+
"""模拟 Embeddings 类(用于 Semantic Reranker)"""
|
| 18 |
+
def embed_query(self, text):
|
| 19 |
+
# 简单的字符级向量化(仅用于测试)
|
| 20 |
+
return [ord(c) / 100.0 for c in text[:10]]
|
| 21 |
+
|
| 22 |
+
def embed_documents(self, texts):
|
| 23 |
+
return [self.embed_query(text) for text in texts]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def create_test_documents():
|
| 27 |
+
"""创建测试文档集"""
|
| 28 |
+
return [
|
| 29 |
+
MockDoc("人工智能是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。"),
|
| 30 |
+
MockDoc("机器学习是人工智能的子领域,专注于让计算机从数据中学习并改进。"),
|
| 31 |
+
MockDoc("深度学习使用多层神经网络来处理复杂的数据模式,是机器学习的一种方法。"),
|
| 32 |
+
MockDoc("自然语言处理(NLP)是人工智能的一个分支,处理计算机与人类语言之间的交互。"),
|
| 33 |
+
MockDoc("计算机视觉是人工智能的另一个重要领域,使机器能够理解和解释视觉信息。"),
|
| 34 |
+
MockDoc("今天天气很好,适合出去散步和运动。"),
|
| 35 |
+
MockDoc("Python 是一种高级编程语言,由 Guido van Rossum 在 1991 年创建。"),
|
| 36 |
+
MockDoc("RAG(检索增强生成)是一种结合信息检索和文本生成的技术。"),
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_tfidf_reranking():
|
| 41 |
+
"""测试 TF-IDF 重排"""
|
| 42 |
+
print("\n" + "=" * 60)
|
| 43 |
+
print("📊 测试 TF-IDF 重排")
|
| 44 |
+
print("=" * 60)
|
| 45 |
+
|
| 46 |
+
query = "什么是人工智能和机器学习?"
|
| 47 |
+
docs = create_test_documents()
|
| 48 |
+
|
| 49 |
+
reranker = TFIDFReranker()
|
| 50 |
+
results = reranker.rerank(query, docs, top_k=3)
|
| 51 |
+
|
| 52 |
+
print(f"\n查询: {query}")
|
| 53 |
+
print("\nTF-IDF 重排结果:")
|
| 54 |
+
for i, (doc, score) in enumerate(results, 1):
|
| 55 |
+
print(f"{i}. 分数: {score:.4f} | 内容: {doc.page_content[:50]}...")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_bm25_reranking():
|
| 59 |
+
"""测试 BM25 重排"""
|
| 60 |
+
print("\n" + "=" * 60)
|
| 61 |
+
print("📊 测试 BM25 重排")
|
| 62 |
+
print("=" * 60)
|
| 63 |
+
|
| 64 |
+
query = "什么是人工智能和机器学习?"
|
| 65 |
+
docs = create_test_documents()
|
| 66 |
+
|
| 67 |
+
reranker = BM25Reranker()
|
| 68 |
+
results = reranker.rerank(query, docs, top_k=3)
|
| 69 |
+
|
| 70 |
+
print(f"\n查询: {query}")
|
| 71 |
+
print("\nBM25 重排结果:")
|
| 72 |
+
for i, (doc, score) in enumerate(results, 1):
|
| 73 |
+
print(f"{i}. 分数: {score:.4f} | 内容: {doc.page_content[:50]}...")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def test_crossencoder_reranking():
|
| 77 |
+
"""测试 CrossEncoder 重排"""
|
| 78 |
+
print("\n" + "=" * 60)
|
| 79 |
+
print("🌟 测试 CrossEncoder 重排(推荐)")
|
| 80 |
+
print("=" * 60)
|
| 81 |
+
|
| 82 |
+
query = "什么是人工智能和机器学习?"
|
| 83 |
+
docs = create_test_documents()
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
# 使用轻量级模型
|
| 87 |
+
reranker = CrossEncoderReranker(
|
| 88 |
+
model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"
|
| 89 |
+
)
|
| 90 |
+
results = reranker.rerank(query, docs, top_k=3)
|
| 91 |
+
|
| 92 |
+
print(f"\n查询: {query}")
|
| 93 |
+
print("\nCrossEncoder 重排结果:")
|
| 94 |
+
for i, (doc, score) in enumerate(results, 1):
|
| 95 |
+
print(f"{i}. 分数: {score:.4f} | 内容: {doc.page_content[:50]}...")
|
| 96 |
+
|
| 97 |
+
return True
|
| 98 |
+
|
| 99 |
+
except Exception as e:
|
| 100 |
+
print(f"\n❌ CrossEncoder 测试失败: {e}")
|
| 101 |
+
print("💡 提示: 请先安装 sentence-transformers")
|
| 102 |
+
print(" 命令: pip install sentence-transformers")
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def test_factory_function():
|
| 107 |
+
"""测试工厂函数"""
|
| 108 |
+
print("\n" + "=" * 60)
|
| 109 |
+
print("🏭 测试重排器工厂函数")
|
| 110 |
+
print("=" * 60)
|
| 111 |
+
|
| 112 |
+
query = "深度学习和神经网络"
|
| 113 |
+
docs = create_test_documents()
|
| 114 |
+
|
| 115 |
+
# 测试各种类型
|
| 116 |
+
reranker_types = ['tfidf', 'bm25']
|
| 117 |
+
|
| 118 |
+
for rtype in reranker_types:
|
| 119 |
+
try:
|
| 120 |
+
reranker = create_reranker(rtype)
|
| 121 |
+
results = reranker.rerank(query, docs, top_k=2)
|
| 122 |
+
print(f"\n✅ {rtype.upper()} 重排器创建成功")
|
| 123 |
+
print(f" Top 1: {results[0][1]:.4f} | {results[0][0].page_content[:40]}...")
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f"\n❌ {rtype.upper()} 重排器失败: {e}")
|
| 126 |
+
|
| 127 |
+
# 测试 CrossEncoder
|
| 128 |
+
try:
|
| 129 |
+
reranker = create_reranker('crossencoder')
|
| 130 |
+
results = reranker.rerank(query, docs, top_k=2)
|
| 131 |
+
print(f"\n✅ CROSSENCODER 重排器创建成功")
|
| 132 |
+
print(f" Top 1: {results[0][1]:.4f} | {results[0][0].page_content[:40]}...")
|
| 133 |
+
except Exception as e:
|
| 134 |
+
print(f"\n❌ CROSSENCODER 重排器失败: {e}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def compare_all_methods():
|
| 138 |
+
"""对比所有重排方法"""
|
| 139 |
+
print("\n" + "=" * 60)
|
| 140 |
+
print("⚖️ 对比所有重排方法")
|
| 141 |
+
print("=" * 60)
|
| 142 |
+
|
| 143 |
+
query = "解释一下人工智能、机器学习和深度学习的关系"
|
| 144 |
+
docs = create_test_documents()
|
| 145 |
+
|
| 146 |
+
methods = {
|
| 147 |
+
'TF-IDF': TFIDFReranker(),
|
| 148 |
+
'BM25': BM25Reranker(),
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
# 尝试添加 CrossEncoder
|
| 152 |
+
try:
|
| 153 |
+
methods['CrossEncoder'] = CrossEncoderReranker()
|
| 154 |
+
except:
|
| 155 |
+
print("\n⚠️ CrossEncoder 不可用,跳过")
|
| 156 |
+
|
| 157 |
+
print(f"\n查询: {query}\n")
|
| 158 |
+
|
| 159 |
+
for method_name, reranker in methods.items():
|
| 160 |
+
try:
|
| 161 |
+
results = reranker.rerank(query, docs, top_k=3)
|
| 162 |
+
print(f"\n{'=' * 40}")
|
| 163 |
+
print(f"{method_name} 重排结果:")
|
| 164 |
+
print('=' * 40)
|
| 165 |
+
for i, (doc, score) in enumerate(results, 1):
|
| 166 |
+
print(f"{i}. [{score:.4f}] {doc.page_content[:60]}...")
|
| 167 |
+
except Exception as e:
|
| 168 |
+
print(f"\n{method_name} 失败: {e}")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def performance_comparison():
|
| 172 |
+
"""性能对比"""
|
| 173 |
+
print("\n" + "=" * 60)
|
| 174 |
+
print("⚡ 性能与准确性对比")
|
| 175 |
+
print("=" * 60)
|
| 176 |
+
|
| 177 |
+
print("""
|
| 178 |
+
重排方法对比:
|
| 179 |
+
|
| 180 |
+
┌─────────────────┬──────────┬──────────┬──────────┬────────────┐
|
| 181 |
+
│ 方法 │ 准确率 │ 速度 │ 成本 │ 适用场景 │
|
| 182 |
+
├─────────────────┼──────────┼──────────┼──────────┼────────────┤
|
| 183 |
+
│ TF-IDF │ ⭐⭐ │ ⚡⚡⚡ │ 极低 │ 关键词匹配 │
|
| 184 |
+
│ BM25 │ ⭐⭐⭐ │ ⚡⚡⚡ │ 极低 │ 文本检索 │
|
| 185 |
+
│ Bi-Encoder │ ⭐⭐⭐⭐ │ ⚡⚡ │ 低 │ 语义检索 │
|
| 186 |
+
│ CrossEncoder 🌟 │ ⭐⭐⭐⭐⭐│ ⚡ │ 中 │ 精准重排 │
|
| 187 |
+
│ Hybrid │ ⭐⭐⭐⭐ │ ⚡⚡ │ 低 │ 综合场景 │
|
| 188 |
+
└─────────────────┴──────────┴──────────┴──────────┴────────────┘
|
| 189 |
+
|
| 190 |
+
推荐配置:
|
| 191 |
+
1️⃣ 两阶段检索:Bi-Encoder (快速召回) + CrossEncoder (精准重排)
|
| 192 |
+
2️⃣ 准确率优先:纯 CrossEncoder
|
| 193 |
+
3️⃣ 速度优先:BM25 或 Hybrid
|
| 194 |
+
|
| 195 |
+
当前项目配置:
|
| 196 |
+
✅ 已切换到 CrossEncoder 重排
|
| 197 |
+
📈 准确率预期提升:15-20%
|
| 198 |
+
⚡ 速度:单次重排 20-100ms (Top 20 文档)
|
| 199 |
+
""")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
print("\n🚀 开始测试 CrossEncoder 重排功能...\n")
|
| 204 |
+
|
| 205 |
+
# 1. 测试 TF-IDF
|
| 206 |
+
test_tfidf_reranking()
|
| 207 |
+
|
| 208 |
+
# 2. 测试 BM25
|
| 209 |
+
test_bm25_reranking()
|
| 210 |
+
|
| 211 |
+
# 3. 测试 CrossEncoder (重点)
|
| 212 |
+
crossencoder_available = test_crossencoder_reranking()
|
| 213 |
+
|
| 214 |
+
# 4. 测试工厂函数
|
| 215 |
+
test_factory_function()
|
| 216 |
+
|
| 217 |
+
# 5. 对比所有方法
|
| 218 |
+
compare_all_methods()
|
| 219 |
+
|
| 220 |
+
# 6. 性能对比总结
|
| 221 |
+
performance_comparison()
|
| 222 |
+
|
| 223 |
+
print("\n" + "=" * 60)
|
| 224 |
+
if crossencoder_available:
|
| 225 |
+
print("✅ 所有测试完成!CrossEncoder 重排已就绪")
|
| 226 |
+
else:
|
| 227 |
+
print("⚠️ 测试完成,但 CrossEncoder 不可用")
|
| 228 |
+
print(" 请运行: pip install sentence-transformers")
|
| 229 |
+
print("=" * 60 + "\n")
|