from __future__ import annotations

import re
from typing import Optional

from sentence_transformers import CrossEncoder

# Compiled once at import time; matches any run of whitespace characters.
_WHITESPACE_RE = re.compile(r"\s+")

# Default character cap applied to each input before it reaches the model.
_DEFAULT_MAX_CHARS = 2000


class CrossEncoderReranker:
    """Cross-encoder reranker for query-document scoring.

    Wraps a ``sentence_transformers.CrossEncoder`` and exposes a single-pair
    ``score`` method. Inputs are whitespace-normalized and capped at
    ``max_chars`` characters before being handed to the model.

    Attributes:
        model_name: Name/path passed to ``CrossEncoder``.
        max_length: Value forwarded to ``CrossEncoder(max_length=...)``.
        max_chars: Maximum number of characters kept from each input text.
        model: The underlying ``CrossEncoder`` instance.
    """

    # Class-level default so instances created without __init__ still trim.
    max_chars: int = _DEFAULT_MAX_CHARS

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        max_length: int = 384,
        max_chars: int = _DEFAULT_MAX_CHARS,
    ) -> None:
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``.
            max_length: Forwarded to ``CrossEncoder`` as ``max_length``.
            max_chars: Character cap applied to the query and the document
                by ``_trim`` (defaults to 2000, the previous fixed value).
        """
        self.model_name = model_name
        self.max_length = max_length
        self.max_chars = max_chars
        self.model = CrossEncoder(model_name, max_length=max_length)

    def _trim(self, text: str) -> str:
        """Collapse whitespace runs to single spaces and cap the length.

        Args:
            text: Raw input text.

        Returns:
            ``text`` with every whitespace run replaced by one space,
            leading/trailing whitespace removed, truncated to at most
            ``self.max_chars`` characters.
        """
        normalized = _WHITESPACE_RE.sub(" ", text).strip()
        # Slicing is a no-op when the text is already short enough.
        return normalized[: self.max_chars]

    def score(self, query: str, doc: str) -> float:
        """Score a single query-document pair.

        Args:
            query: The search query.
            doc: The candidate document text.

        Returns:
            The model's score for the pair as a Python ``float``.
        """
        pair = [self._trim(query), self._trim(doc)]
        scores = self.model.predict([pair])
        # predict() is given a batch of one pair, so take the single score
        # explicitly. Calling float() on the whole batch result relies on the
        # implicit conversion of a length-1 array to a scalar, which NumPy
        # has deprecated.
        return float(scores[0])