| | from transformers import AutoTokenizer, TFAutoModelForSequenceClassification |
| | import tensorflow as tf |
| | from typing import List |
| |
|
| | from logger_config import config_logger |
| | logger = config_logger(__name__) |
| |
|
class CrossEncoderReranker:
    """
    Cross-Encoder Re-Ranker. Takes (query, candidate) pairs and outputs a relevance score [0...1].

    The tokenizer and the TensorFlow sequence-classification model are loaded once in
    ``__init__`` and reused by every ``rerank`` call.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Init the cross-encoder with a pretrained model.

        Args:
            model_name: Name of a HF cross-encoder model. Must be compatible with
                TFAutoModelForSequenceClassification.
        """
        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        logger.info("Cross encoder model loaded successfully.")

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Compute relevance scores for each candidate w.r.t. query.

        All (query, candidate) pairs are tokenized together and scored in a single
        forward pass; the sigmoid of the model logits is returned.

        Args:
            query: User's query text.
            candidates: List of candidate response texts. May be empty.
            max_length: Max token length for each (query, candidate) pair.

        Returns:
            A list of float scores [0...1]. One per candidate, indicating model's
            predicted relevance. An empty list when there are no candidates.
        """
        pair_texts = [(query, candidate) for candidate in candidates]

        # Nothing to score: return early so an empty batch is never handed to the
        # tokenizer or the model.
        if not pair_texts:
            return []

        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf",
            verbose=False
        )

        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            # .get() yields None when the tokenizer does not emit token_type_ids.
            token_type_ids=encodings.get("token_type_ids")
        )

        # Squash raw logits into the [0...1] range.
        logits = outputs.logits
        scores = tf.nn.sigmoid(logits)

        # Flatten to 1-D. NOTE(review): this is one score per candidate only if the
        # model emits a single logit per pair -- confirm for non-default checkpoints.
        scores = tf.reshape(scores, [-1])
        scores = scores.numpy().astype(float)

        return scores.tolist()