File size: 813 Bytes
5a3b322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from __future__ import annotations

import re
from typing import Optional

from sentence_transformers import CrossEncoder


class CrossEncoderReranker:
    """Cross-encoder reranker for query-document scoring.

    Wraps a ``CrossEncoder`` model and exposes :meth:`score`, which
    normalizes both texts and returns the model's score for the pair.
    """

    # Per-text character cap applied by ``_trim`` before scoring. Kept as a
    # class attribute so a subclass or instance can override it.
    MAX_CHARS = 2000

    # Compiled once; matches any run of whitespace so it can be collapsed.
    _WHITESPACE_RUN = re.compile(r"\s+")

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", max_length: int = 384) -> None:
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier handed to ``CrossEncoder``.
            max_length: Forwarded to ``CrossEncoder`` as ``max_length``.
        """
        self.model_name = model_name
        self.max_length = max_length
        self.model = CrossEncoder(model_name, max_length=max_length)

    def _trim(self, text: str) -> str:
        """Collapse whitespace runs to single spaces, strip, and cap the length.

        Args:
            text: Raw text to normalize.

        Returns:
            The normalized text, at most ``MAX_CHARS`` characters long.
        """
        collapsed = self._WHITESPACE_RUN.sub(" ", text).strip()
        # Slicing is a no-op when the text is already within the cap.
        return collapsed[: self.MAX_CHARS]

    def score(self, query: str, doc: str) -> float:
        """Score a single query-document pair.

        Args:
            query: Query text; normalized with ``_trim`` before scoring.
            doc: Document text; normalized with ``_trim`` before scoring.

        Returns:
            The model's score for the pair as a plain Python ``float``.
        """
        pair = [self._trim(query), self._trim(doc)]
        scores = self.model.predict([pair])
        # A one-pair batch yields a length-1 result. Pick the element
        # explicitly: float() on a non-0-d array is deprecated in recent
        # NumPy releases and is slated to become an error.
        return float(scores[0])