Batch_RAG / search /reranker.py
DolAr1610
add new logic
a5c9fa3
raw
history blame contribute delete
727 Bytes
from sentence_transformers import CrossEncoder
# Lightweight, popular cross-encoder checkpoint used for reranking.
_MODEL_NAME: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Cached CrossEncoder instance: None until rerank() creates it on first call.
_ce: "CrossEncoder | None" = None
def rerank(query: str, chunks: list[dict], top_k: int = 5) -> list[dict]:
    """Re-score candidate chunks against ``query`` with a cross-encoder.

    Each chunk's text is paired with the query and scored by the module-level
    CrossEncoder (model ``_MODEL_NAME``), which is created on the first call
    and reused afterwards.

    Args:
        query: The search query text.
        chunks: Candidate chunks, e.g.
            ``[{"chunk_id": ..., "chunk_text": ..., "metadata": ..., ...}, ...]``.
            Only ``"chunk_text"`` is read; a missing or ``None`` text is
            scored as an empty string.
        top_k: Maximum number of chunks to return (applied as a slice).

    Returns:
        Up to ``top_k`` of the SAME dict objects, ordered by descending
        ``"rerank_score"``. Every input dict gains a float ``"rerank_score"``
        key in place; the caller's ``chunks`` list itself is not reordered.
        An empty ``chunks`` yields ``[]`` without loading the model.
    """
    global _ce

    # Nothing to score: skip the (expensive) model load and the predict call,
    # which may fail on an empty batch in some sentence-transformers versions.
    if not chunks:
        return []

    if _ce is None:
        # Load the model on first use rather than at import time.
        _ce = CrossEncoder(_MODEL_NAME)

    # `or ""` also covers chunks whose text is present but None.
    pairs = [(query, c.get("chunk_text") or "") for c in chunks]
    scores = _ce.predict(pairs)
    for c, s in zip(chunks, scores):
        # float() converts the model's numeric score to a plain Python float.
        c["rerank_score"] = float(s)

    # sorted() instead of list.sort(): return a ranked copy without
    # reordering the caller's list as a side effect.
    ranked = sorted(chunks, key=lambda c: c.get("rerank_score", 0.0), reverse=True)
    return ranked[:top_k]