adaptive_rag / reranker.py
lanny xu
modify reranker
be297c2
"""
向量重排模块
实现多种重排策略以提高检索质量
支持 CrossEncoder 深度重排
"""
import torch
import numpy as np
from typing import List, Tuple, Dict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import re
from collections import Counter
import math
# CrossEncoder support
try:
from sentence_transformers import CrossEncoder as SentenceTransformerCrossEncoder
CROSSENCODER_AVAILABLE = True
except ImportError:
CROSSENCODER_AVAILABLE = False
print("⚠️ sentence-transformers not available. CrossEncoder reranking disabled.")
class DocumentReranker:
    """Common interface for all document rerankers in this module.

    Concrete rerankers overwrite ``name`` and implement :meth:`rerank`.
    """

    def __init__(self):
        # Human-readable identifier of the reranker implementation.
        self.name = "BaseReranker"

    def rerank(
        self,
        query: str,
        documents: List[dict],
        top_k: int = 5,
    ) -> List[Tuple[dict, float]]:
        """Reorder ``documents`` for ``query`` and return the best ``top_k``.

        Args:
            query: Query text.
            documents: Candidate documents.
            top_k: Number of results to return.

        Returns:
            A list of ``(document, score)`` tuples.

        Raises:
            NotImplementedError: Always; the base class has no strategy.
        """
        raise NotImplementedError()
class TFIDFReranker(DocumentReranker):
    """Reranker that scores documents by TF-IDF cosine similarity to the query.

    The vectorizer is re-fitted on ``[query] + documents`` at every call, so
    term statistics come from the candidate list only.
    """

    def __init__(self):
        super().__init__()
        self.name = "TFIDFReranker"
        # No stop-word list and character n-grams (2-4 chars, within word
        # boundaries) so that text without whitespace-separated words, such
        # as Chinese, still produces features.
        self.vectorizer = TfidfVectorizer(
            analyzer='char_wb',
            ngram_range=(2, 4),
            max_features=5000,
        )

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rank ``documents`` by TF-IDF cosine similarity to ``query``.

        Args:
            query: Query text.
            documents: Candidates; ``page_content`` is used when present,
                otherwise ``str(doc)``.
            top_k: Number of results to return.

        Returns:
            Up to ``top_k`` ``(document, similarity)`` tuples, best first.
            Documents with equal similarity keep their incoming order.
        """
        if not documents:
            return []

        doc_texts = [
            doc.page_content if hasattr(doc, 'page_content') else str(doc)
            for doc in documents
        ]

        try:
            # Row 0 is the query, rows 1.. are the documents.
            tfidf_matrix = self.vectorizer.fit_transform([query] + doc_texts)
        except ValueError:
            # fit_transform raises ValueError when no feature can be extracted
            # (e.g. every text is empty or whitespace). There is no signal to
            # rank on, so keep the incoming order with neutral scores.
            return [(doc, 0.0) for doc in documents[:top_k]]

        similarities = cosine_similarity(tfidf_matrix[0], tfidf_matrix[1:]).flatten()

        # Stable descending sort: argsort(...)[::-1] would reverse the
        # incoming order of tied documents.
        ranked_indices = np.argsort(-similarities, kind="stable")
        return [
            (documents[i], float(similarities[i]))
            for i in ranked_indices[:top_k]
        ]
class BM25Reranker(DocumentReranker):
    """Reranker that scores documents against the query with BM25.

    Document frequencies and the average document length are computed over
    the candidate list passed to :meth:`rerank`, not over a global corpus.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """
        Args:
            k1: Scales the term-frequency component of the score.
            b: Weight of the ``doc_len / avg_doc_len`` length normalisation.
        """
        super().__init__()
        self.name = "BM25Reranker"
        self.k1 = k1
        self.b = b

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into lower-cased tokens.

        Text containing at least one character in U+4E00..U+9FFF is split
        into single characters plus adjacent-character bigrams (tokens that
        are pure whitespace are dropped). Any other text is split into
        ``\\w+`` words.

        NOTE: the choice is made per text, so a query and a document may be
        tokenized by different branches when only one of them contains
        Chinese characters.
        """
        lowered = text.lower()
        if any('\u4e00' <= char <= '\u9fff' for char in text):
            chars = list(lowered)
            bigrams = [first + second for first, second in zip(chars, chars[1:])]
            return [token for token in chars + bigrams if token.strip()]
        return re.findall(r'\b\w+\b', lowered)

    def _idf_from_token_sets(self, doc_token_sets: List[set], query_terms: List[str]) -> Dict[str, float]:
        """Compute IDF for each query term from pre-tokenized documents."""
        total = len(doc_token_sets)
        idf = {}
        for term in query_terms:
            if term in idf:
                continue  # repeated query term, already computed
            df = sum(1 for tokens in doc_token_sets if term in tokens)
            # "1 +" keeps the IDF non-negative. Without it a term present in
            # more than half of the candidates gets a negative weight and
            # documents that match the query are pushed down.
            idf[term] = math.log(1 + (total - df + 0.5) / (df + 0.5))
        return idf

    def _score_tokens(self, query_terms: List[str], doc_tokens: List[str],
                      avg_doc_len: float, idf: Dict[str, float]) -> float:
        """Compute the BM25 score of one pre-tokenized document."""
        term_freq = Counter(doc_tokens)
        doc_len = len(doc_tokens)
        if avg_doc_len:
            length_norm = 1 - self.b + self.b * doc_len / avg_doc_len
        else:
            # Only reachable when the caller passes a zero average length;
            # fall back to no length normalisation instead of dividing by 0.
            length_norm = 1.0

        score = 0.0
        for term in query_terms:
            tf = term_freq.get(term, 0)
            if tf:
                score += idf.get(term, 0) * (tf * (self.k1 + 1)) / (
                    tf + self.k1 * length_norm
                )
        return score

    def _compute_idf(self, documents: List[str], query_terms: List[str]) -> Dict[str, float]:
        """Compute the IDF of every query term over ``documents`` (raw texts).

        Each document is tokenized exactly once.
        """
        token_sets = [set(self._tokenize(doc)) for doc in documents]
        return self._idf_from_token_sets(token_sets, query_terms)

    def _bm25_score(self, query_terms: List[str], document: str, avg_doc_len: float, idf: Dict[str, float]) -> float:
        """Compute the BM25 score of one raw document text.

        Args:
            query_terms: Tokenized query.
            document: Raw document text.
            avg_doc_len: Average token count of the candidate documents.
            idf: Term -> IDF mapping; terms missing from it contribute 0.
        """
        return self._score_tokens(query_terms, self._tokenize(document), avg_doc_len, idf)

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rank ``documents`` by BM25 score for ``query``.

        Args:
            query: Query text.
            documents: Candidates; ``page_content`` is used when present,
                otherwise ``str(doc)``.
            top_k: Number of results to return.

        Returns:
            Up to ``top_k`` ``(document, score)`` tuples, best first.
            Documents with equal score keep their incoming order.
        """
        if not documents:
            return []

        query_terms = self._tokenize(query)
        # Tokenize every document once and reuse the tokens for the average
        # length, the document frequencies and the scoring.
        doc_tokens = [
            self._tokenize(doc.page_content if hasattr(doc, 'page_content') else str(doc))
            for doc in documents
        ]

        avg_doc_len = sum(len(tokens) for tokens in doc_tokens) / len(doc_tokens)
        idf = self._idf_from_token_sets([set(tokens) for tokens in doc_tokens], query_terms)
        scores = [
            self._score_tokens(query_terms, tokens, avg_doc_len, idf)
            for tokens in doc_tokens
        ]

        # sorted() is stable even with reverse=True, so ties keep input order.
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        return [(documents[i], float(scores[i])) for i in order[:top_k]]
class SemanticReranker(DocumentReranker):
    """Reranker that scores documents by embedding cosine similarity."""

    def __init__(self, embeddings_model):
        """
        Args:
            embeddings_model: Object providing ``embed_query(text)`` and
                ``embed_documents(list_of_texts)``.
        """
        super().__init__()
        self.name = "SemanticReranker"
        self.embeddings_model = embeddings_model

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rank ``documents`` by cosine similarity between embeddings.

        Args:
            query: Query text.
            documents: Candidates; ``page_content`` is used when present,
                otherwise ``str(doc)``.
            top_k: Number of results to return.

        Returns:
            Up to ``top_k`` ``(document, similarity)`` tuples, best first.
            Documents with equal similarity keep their incoming order.
        """
        if not documents:
            return []

        query_embedding = self.embeddings_model.embed_query(query)
        doc_texts = [
            doc.page_content if hasattr(doc, 'page_content') else str(doc)
            for doc in documents
        ]
        doc_embeddings = self.embeddings_model.embed_documents(doc_texts)

        # One batched call (1 x n_docs result) instead of one sklearn call
        # per document.
        similarities = cosine_similarity([query_embedding], doc_embeddings)[0]

        # Stable descending sort: argsort(...)[::-1] would reverse the
        # incoming order of tied documents.
        ranked_indices = np.argsort(-similarities, kind="stable")
        return [
            (documents[i], float(similarities[i]))
            for i in ranked_indices[:top_k]
        ]
class CrossEncoderReranker(DocumentReranker):
    """Reranker backed by a sentence-transformers CrossEncoder model.

    Every ``[query, document]`` pair is scored by the model's ``predict``
    call. The original notes recommend it for the fine-ranking stage
    (roughly the top 20-100 candidates).
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", max_length: int = 512):
        """Load the CrossEncoder model.

        Args:
            model_name: Model identifier. Options listed by the original
                author: "cross-encoder/ms-marco-MiniLM-L-6-v2" (lightweight,
                default), "cross-encoder/ms-marco-MiniLM-L-12-v2" (balanced),
                "BAAI/bge-reranker-base" (Chinese-oriented),
                "BAAI/bge-reranker-large" (higher accuracy).
            max_length: Maximum input length passed to the model.

        Raises:
            ImportError: If sentence-transformers could not be imported.
            Exception: Whatever the model constructor raises is logged and
                re-raised.
        """
        super().__init__()
        self.name = "CrossEncoderReranker"
        self.model_name, self.max_length = model_name, max_length

        # Fail early when the optional dependency is missing.
        if not CROSSENCODER_AVAILABLE:
            raise ImportError(
                "CrossEncoder requires sentence-transformers. "
                "Install with: pip install sentence-transformers"
            )

        try:
            print(f"🔧 加载 CrossEncoder 模型: {model_name}...")
            self.model = SentenceTransformerCrossEncoder(model_name, max_length=max_length)
            print("✅ CrossEncoder 模型加载成功")
        except Exception as e:
            print(f"❌ CrossEncoder 模型加载失败: {e}")
            raise

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rank ``documents`` with the CrossEncoder.

        Args:
            query: Query text.
            documents: Candidates; ``page_content`` is used when present,
                otherwise ``str(doc)``.
            top_k: Number of results to return.

        Returns:
            Up to ``top_k`` ``(document, score)`` tuples, best first. If
            scoring fails, the first ``top_k`` documents are returned in
            their incoming order with score 0.0.
        """
        if not documents:
            return []

        # One [query, text] pair per candidate, scored jointly by the model.
        pairs = [
            [query, doc.page_content if hasattr(doc, 'page_content') else str(doc)]
            for doc in documents
        ]

        try:
            scores = self.model.predict(pairs)
            best_first = np.argsort(scores)[::-1][:top_k]
            return [(documents[i], float(scores[i])) for i in best_first]
        except Exception as e:
            print(f"⚠️ CrossEncoder 重排失败: {e}")
            # Best effort: fall back to the incoming order.
            return [(doc, 0.0) for doc in documents[:top_k]]
class HybridReranker(DocumentReranker):
    """Reranker that blends TF-IDF, BM25 and semantic scores."""

    def __init__(self, embeddings_model, weights: Dict[str, float] = None):
        """
        Args:
            embeddings_model: Embedding model handed to the semantic reranker.
            weights: Mapping with keys ``'tfidf'``, ``'bm25'`` and
                ``'semantic'``. A missing key counts as weight 0. When
                omitted (or empty) the defaults 0.3 / 0.3 / 0.4 are used.
        """
        super().__init__()
        self.name = "HybridReranker"

        # Component rerankers.
        self.tfidf_reranker = TFIDFReranker()
        self.bm25_reranker = BM25Reranker()
        self.semantic_reranker = SemanticReranker(embeddings_model)

        self.weights = weights or {
            'tfidf': 0.3,
            'bm25': 0.3,
            'semantic': 0.4
        }

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rank ``documents`` by a weighted sum of the component scores.

        Each component's scores are divided by that component's maximum
        (only when the maximum is positive) before being weighted.

        Args:
            query: Query text.
            documents: Candidate documents.
            top_k: Number of results to return.

        Returns:
            Up to ``top_k`` ``(document, combined_score)`` tuples, best
            first. Documents with equal score keep their incoming order.
        """
        if not documents:
            return []

        count = len(documents)
        components = (
            ('tfidf', self.tfidf_reranker),
            ('bm25', self.bm25_reranker),
            ('semantic', self.semantic_reranker),
        )

        # Scores are accumulated per input position rather than per id(doc),
        # so the same object appearing twice in ``documents`` is not merged
        # into a single result.
        combined = [0.0] * count
        for score_type, reranker in components:
            # A custom weights dict may omit a component; treat it as 0
            # instead of raising KeyError here.
            weight = self.weights.get(score_type, 0.0)

            ranked = reranker.rerank(query, documents, count)
            # Identical objects get identical scores, so an id lookup is
            # enough to map the ranked output back to input positions.
            score_by_id = {id(doc): score for doc, score in ranked}
            column = [score_by_id.get(id(doc), 0) for doc in documents]

            max_score = max(column)
            if max_score > 0:
                column = [score / max_score for score in column]

            for position, score in enumerate(column):
                combined[position] += weight * score

        # sorted() is stable even with reverse=True, so ties keep input order.
        order = sorted(range(count), key=combined.__getitem__, reverse=True)
        return [(documents[i], combined[i]) for i in order[:top_k]]
class DiversityReranker(DocumentReranker):
    """Reranker that trades relevance against redundancy (MMR selection)."""

    def __init__(self, embeddings_model, diversity_lambda: float = 0.5):
        """
        Args:
            embeddings_model: Object providing ``embed_query(text)`` and
                ``embed_documents(list_of_texts)``.
            diversity_lambda: Weight of relevance in the MMR score; the
                redundancy penalty is weighted by ``1 - diversity_lambda``.
        """
        super().__init__()
        self.name = "DiversityReranker"
        self.embeddings_model = embeddings_model
        self.diversity_lambda = diversity_lambda

    def _calculate_diversity_penalty(self, candidate_doc: str, selected_docs: List[str]) -> float:
        """Return the highest cosine similarity between a candidate and the
        already selected texts, floored at 0.0 (0.0 when nothing is selected).

        :meth:`rerank` no longer calls this; it is kept for existing callers.
        """
        if not selected_docs:
            return 0.0
        candidate_emb = self.embeddings_model.embed_documents([candidate_doc])[0]
        selected_embs = self.embeddings_model.embed_documents(selected_docs)
        similarities = cosine_similarity([candidate_emb], selected_embs)[0]
        return max(0.0, float(np.max(similarities)))

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Select up to ``top_k`` documents with Maximal Marginal Relevance.

        At every round the remaining document with the highest
        ``lambda * relevance - (1 - lambda) * penalty`` is selected, where
        ``relevance`` is the cosine similarity to the query and ``penalty``
        is the highest similarity to any already selected document
        (never below 0).

        Args:
            query: Query text.
            documents: Candidates; ``page_content`` is used when present,
                otherwise ``str(doc)``.
            top_k: Number of results to return.

        Returns:
            ``(document, mmr_score)`` tuples in selection order.
        """
        if not documents:
            return []

        doc_texts = [
            doc.page_content if hasattr(doc, 'page_content') else str(doc)
            for doc in documents
        ]

        # Embed everything once. Re-embedding the candidate and the selected
        # texts for every candidate in every round costs O(top_k * n)
        # embedding calls.
        query_embedding = self.embeddings_model.embed_query(query)
        doc_embeddings = np.asarray(
            self.embeddings_model.embed_documents(doc_texts), dtype=float
        )
        relevance = cosine_similarity([query_embedding], doc_embeddings)[0]
        # n x n document-to-document similarities; n is the candidate count.
        pairwise = cosine_similarity(doc_embeddings)

        # Candidates are visited by descending relevance so that an MMR tie
        # is resolved in favour of the more relevant document.
        remaining = [int(i) for i in np.argsort(-relevance, kind="stable")]
        # penalty[i]: highest similarity of document i to anything selected.
        penalty = np.zeros(len(documents))
        lam = self.diversity_lambda

        selected = []
        while remaining and len(selected) < top_k:
            best_pos = None
            # Start from -inf: an MMR score can be <= -1 (e.g. lambda = 1 and
            # relevance -1, or lambda outside [0, 1]); a -1 starting value
            # would then select nothing and the loop would never end.
            best_score = float('-inf')
            for pos, idx in enumerate(remaining):
                mmr_score = lam * relevance[idx] - (1 - lam) * penalty[idx]
                if best_pos is None or mmr_score > best_score:
                    best_pos, best_score = pos, mmr_score

            chosen = remaining.pop(best_pos)
            selected.append((documents[chosen], float(best_score)))
            penalty = np.maximum(penalty, pairwise[chosen])

        return selected
def create_reranker(reranker_type: str, embeddings_model=None, **kwargs) -> DocumentReranker:
    """Factory: build the reranker selected by ``reranker_type``.

    Args:
        reranker_type: Case-insensitive type name:
            - 'tfidf': TF-IDF reranking
            - 'bm25': BM25 reranking (``kwargs`` go to the constructor)
            - 'semantic': embedding-based reranking
            - 'crossencoder' / 'cross_encoder' / 'cross-encoder':
              CrossEncoder reranking
            - 'hybrid': blended reranking (``kwargs`` go to the constructor)
            - 'diversity': diversity reranking (``kwargs`` go to the
              constructor)
        embeddings_model: Embedding model; required for 'semantic',
            'hybrid' and 'diversity'.
        **kwargs: Extra options, e.g. ``model_name`` / ``max_length`` for
            the CrossEncoder or ``weights`` for the hybrid reranker.

    Returns:
        The reranker instance.

    Raises:
        ValueError: Unknown type, or a required ``embeddings_model`` is
            missing.
    """
    kind = reranker_type.lower()

    # Types that work without an embedding model.
    if kind == 'tfidf':
        return TFIDFReranker()
    if kind == 'bm25':
        return BM25Reranker(**kwargs)
    if kind in ('crossencoder', 'cross_encoder', 'cross-encoder'):
        # The CrossEncoder loads its own model; only these two options are
        # taken from kwargs.
        return CrossEncoderReranker(
            model_name=kwargs.get('model_name', 'cross-encoder/ms-marco-MiniLM-L-6-v2'),
            max_length=kwargs.get('max_length', 512),
        )

    # Types that need an embedding model: (builder, error when it is missing).
    embedding_based = {
        'semantic': (
            lambda: SemanticReranker(embeddings_model),
            "SemanticReranker requires embeddings_model",
        ),
        'hybrid': (
            lambda: HybridReranker(embeddings_model, **kwargs),
            "HybridReranker requires embeddings_model",
        ),
        'diversity': (
            lambda: DiversityReranker(embeddings_model, **kwargs),
            "DiversityReranker requires embeddings_model",
        ),
    }

    if kind not in embedding_based:
        raise ValueError(
            f"Unknown reranker type: {reranker_type}. "
            f"Available types: tfidf, bm25, semantic, crossencoder, hybrid, diversity"
        )

    build, missing_model_message = embedding_based[kind]
    if embeddings_model is None:
        raise ValueError(missing_model_message)
    return build()
# Usage example
if __name__ == "__main__":
    class MockDoc:
        """Minimal stand-in exposing the ``page_content`` attribute."""

        def __init__(self, content):
            self.page_content = content

    # Sample corpus: four AI-related sentences and one unrelated one.
    sample_texts = (
        "人工智能是计算机科学的一个分支",
        "机器学习是人工智能的子领域",
        "深度学习使用神经网络",
        "自然语言处理处理文本数据",
        "今天天气很好",
    )
    docs = [MockDoc(text) for text in sample_texts]
    query = "什么是人工智能?"

    # Try the TF-IDF reranker and print its top 3 results.
    results = TFIDFReranker().rerank(query, docs, top_k=3)
    print("TF-IDF重排结果:")
    for doc, score in results:
        print(f"分数: {score:.4f} - 内容: {doc.page_content}")