| from sentence_transformers import CrossEncoder |
|
|
| |
# Model identifier passed to CrossEncoder(...) when the reranker is first needed.
_MODEL_NAME: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# Module-wide cache for the CrossEncoder instance; stays None until rerank()
# constructs it on first use.
_ce: "CrossEncoder | None" = None
|
|
def _get_cross_encoder():
    """Return the shared CrossEncoder, constructing it on the first call."""
    global _ce
    if _ce is None:
        _ce = CrossEncoder(_MODEL_NAME)
    return _ce


def rerank(query: str, chunks: list[dict], top_k: int = 5) -> list[dict]:
    """Rerank retrieved chunks by cross-encoder relevance to ``query``.

    Args:
        query: The search query text.
        chunks: Candidate chunks, e.g.
            ``[{"chunk_id": ..., "chunk_text": ..., "metadata": ..., ...}, ...]``.
            Only ``"chunk_text"`` is read; a missing or ``None`` value is
            scored as the empty string.
        top_k: Maximum number of chunks to return. A value <= 0 returns
            ``[]``; ``None`` returns every chunk.

    Returns:
        Up to ``top_k`` of the input dicts (the same objects, not copies),
        each annotated in place with a float ``"rerank_score"``, ordered from
        highest to lowest score. Equal scores keep their input order.

    The caller's ``chunks`` list itself is not reordered.
    """
    # Nothing to return: skip loading the model. This also avoids calling
    # predict() on an empty list, which may raise in some library versions.
    if not chunks or (top_k is not None and top_k <= 0):
        return []

    model = _get_cross_encoder()
    # `or ""` also covers an explicit None value, not just a missing key.
    pairs = [(query, chunk.get("chunk_text") or "") for chunk in chunks]
    scores = model.predict(pairs)

    for chunk, score in zip(chunks, scores):
        chunk["rerank_score"] = float(score)

    # sorted() is stable and, unlike list.sort(), leaves the caller's list order untouched.
    ranked = sorted(chunks, key=lambda c: c.get("rerank_score", 0.0), reverse=True)
    return ranked[:top_k]
|
|