File size: 813 Bytes
5a3b322 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from __future__ import annotations
import re
from typing import Optional
from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
    """Cross-encoder reranker for query-document scoring.

    Wraps a ``sentence_transformers.CrossEncoder``. Before scoring, the query
    and the document each have their whitespace collapsed. They are then
    capped at ``max_chars`` characters and passed to the model as a pair.

    NOTE(review): assumes a single-output model (one score per pair), which
    matches the default ms-marco model. Confirm before using multi-label
    cross-encoders.
    """

    # Collapses any run of whitespace (spaces, tabs, newlines) to one space.
    # Compiled once here instead of being looked up on every ``_trim`` call.
    _WHITESPACE = re.compile(r"\s+")

    # Class-level default so ``_trim`` also works on instances that never ran
    # ``__init__``; ``__init__`` overrides it per instance.
    max_chars: Optional[int] = 2000

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        max_length: int = 384,
        max_chars: Optional[int] = 2000,
    ) -> None:
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``.
            max_length: Forwarded to ``CrossEncoder`` as ``max_length``
                (its maximum input length in tokens).
            max_chars: Character cap applied separately to the query and the
                document before scoring; ``None`` disables the cap.

        Raises:
            ValueError: If ``max_chars`` is not ``None`` and is less than 1.
        """
        # Validate before the (potentially slow) model load.
        if max_chars is not None and max_chars < 1:
            raise ValueError(
                f"max_chars must be a positive int or None, got {max_chars!r}"
            )
        self.model_name = model_name
        self.max_length = max_length
        self.max_chars = max_chars
        self.model = CrossEncoder(model_name, max_length=max_length)

    def _trim(self, text: str) -> str:
        """Collapse whitespace runs to single spaces, strip, and cap the length.

        Args:
            text: Raw query or document text.

        Returns:
            The normalised text, at most ``max_chars`` characters long (when
            ``max_chars`` is not ``None``).
        """
        text = self._WHITESPACE.sub(" ", text).strip()
        if self.max_chars is not None:
            # Slicing is a no-op for shorter strings, so no length check needed.
            text = text[: self.max_chars]
        return text

    def score(self, query: str, doc: str) -> float:
        """Return the model's score for a single (query, doc) pair.

        Both texts are normalised with ``_trim`` first. The value is the raw
        model output; no further normalisation is applied here.
        """
        scores = self.model.predict([[self._trim(query), self._trim(doc)]])
        # ``predict`` yields one score per input pair. Index the single element
        # explicitly: calling ``float()`` on a 1-element array is deprecated
        # in NumPy >= 1.25.
        return float(scores[0])

    def score_batch(self, query: str, docs: list[str]) -> list[float]:
        """Score several documents against one query in a single model call.

        Args:
            query: The query text (normalised once with ``_trim``).
            docs: Documents to score; each is normalised with ``_trim``.

        Returns:
            One float per document, in the same order as ``docs``. An empty
            ``docs`` returns ``[]`` without invoking the model.
        """
        trimmed_query = self._trim(query)
        pairs = [[trimmed_query, self._trim(doc)] for doc in docs]
        if not pairs:
            return []
        return [float(value) for value in self.model.predict(pairs)]
|