import torch
from typing import List, Optional

from loguru import logger
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM

from .base import RerankerModel


class SentenceTransformersReranker(RerankerModel):
    """
    Reranker using a sentence-transformers CrossEncoder.

    This class leverages the CrossEncoder model from the sentence-transformers
    library to score the relevance of documents given a query. It is suitable
    for reranking tasks in information retrieval pipelines.

    Attributes:
        model_name (str): Name or path of the model to load.
        model (CrossEncoder): The loaded CrossEncoder model instance.
        loaded (bool): Whether the model has been loaded.
        model_id (str): Unique identifier for the model instance.
    """

    def load(self):
        """
        Load the sentence-transformers CrossEncoder model.

        Loads the CrossEncoder specified by ``self.model_name`` and sets
        ``self.loaded`` to True on success.

        Raises:
            Exception: If the model fails to load.
        """
        try:
            logger.info(f"Loading SentenceTransformers model: {self.model_name}")
            self.model = CrossEncoder(
                self.model_name,
                model_kwargs={"torch_dtype": "auto"},
                trust_remote_code=True
            )
            self.loaded = True
            logger.success(f"Successfully loaded {self.model_id}")
        except Exception as e:
            logger.error(f"Failed to load {self.model_id}: {e}")
            raise

    def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
        """
        Rerank documents using the CrossEncoder model.

        Args:
            query (str): The search query string.
            documents (List[str]): List of documents to be reranked.
            instruction (Optional[str]): Additional instruction for reranking
                (not used in this implementation).

        Returns:
            List[float]: Relevance score for each document, in input order.

        Raises:
            RuntimeError: If the model is not loaded.
            Exception: If reranking fails.
        """
        if not self.loaded:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            rankings = self.model.rank(query, documents, convert_to_tensor=True)
            # rank() returns results sorted by score; map them back to input order.
            scores = [0.0] * len(documents)
            for ranking in rankings:
                scores[ranking['corpus_id']] = float(ranking['score'])
            return scores
        except Exception as e:
            logger.error(f"Reranking failed with {self.model_id}: {e}")
            raise
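
# Illustrative usage sketch (comment only, not part of the module's API). It
# assumes the RerankerModel base class accepts a model_name argument in its
# constructor (the actual signature lives in .base and is not shown here) and
# uses a public cross-encoder checkpoint purely as an example:
#
#     reranker = SentenceTransformersReranker(
#         model_name="cross-encoder/ms-marco-MiniLM-L6-v2"
#     )
#     reranker.load()
#     scores = reranker.rerank(
#         "how do rerankers work?",
#         ["Rerankers score query-document pairs.", "Unrelated text."],
#     )
#     # scores[i] is the relevance of documents[i] to the query.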

class QwenReranker(RerankerModel):
    """
    Reranker using a Qwen3-Reranker model (LLM-based).

    This class uses a Qwen LLM to judge the relevance of documents to a query
    and instruction. For each document, the model outputs the probability that
    it is relevant ("yes") rather than irrelevant ("no").

    Attributes:
        model_name (str): Name or path of the Qwen model.
        tokenizer (AutoTokenizer): Tokenizer for the Qwen model.
        model (AutoModelForCausalLM): Loaded Qwen model instance.
        loaded (bool): Whether the model has been loaded.
        model_id (str): Unique identifier for the model instance.
        token_false_id (int): Token ID for "no".
        token_true_id (int): Token ID for "yes".
        max_length (int): Maximum input token length.
        prefix (str): Prompt prefix for the system message.
        suffix (str): Prompt suffix for the assistant message.
        prefix_tokens (List[int]): Tokenized prefix.
        suffix_tokens (List[int]): Tokenized suffix.
    """

    def load(self):
        """
        Load the Qwen reranker model and tokenizer, and initialize prompt
        templates and special tokens.

        Raises:
            Exception: If the model or tokenizer fails to load.
        """
        try:
            logger.info(f"Loading Qwen model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side='left'
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name
            ).eval()

            # Set up Qwen-specific yes/no token IDs
            self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
            self.max_length = 8192

            # Set up prompt templates
            self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
            self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

            self.loaded = True
            logger.success(f"Successfully loaded {self.model_id}")
        except Exception as e:
            logger.error(f"Failed to load {self.model_id}: {e}")
            raise

    def _format_instruction(self, instruction: Optional[str], query: str, doc: str) -> str:
        """
        Format the prompt string for the Qwen model.

        Args:
            instruction (Optional[str]): The instruction for the reranker. If
                None, a default web-search instruction is used.
            query (str): The search query string.
            doc (str): The document to be evaluated.

        Returns:
            str: Formatted prompt string for the model.
        """
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
            instruction=instruction, query=query, doc=doc
        )

    def _process_inputs(self, pairs: List[str]):
        """
        Tokenize and prepare input pairs for the Qwen model.

        Each formatted pair is tokenized without padding, wrapped with the
        pre-tokenized prefix and suffix, then padded into a single batch and
        moved to the model's device.

        Args:
            pairs (List[str]): Formatted prompt strings, one per document.

        Returns:
            dict: Tokenized and padded input tensors for the model.
        """
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    def _compute_logits(self, inputs):
        """
        Compute relevance scores from the model logits.

        Takes the logits at the final position, keeps only the "no"/"yes"
        token logits, and applies a softmax over the pair so the "yes"
        probability can serve as the relevance score.

        Args:
            inputs (dict): Tokenized and padded input tensors for the model.

        Returns:
            List[float]: Probability that each document is relevant ("yes").
        """
        # Inference only: disable gradient tracking to save memory.
        with torch.no_grad():
            batch_scores = self.model(**inputs).logits[:, -1, :]
            true_vector = batch_scores[:, self.token_true_id]
            false_vector = batch_scores[:, self.token_false_id]
            batch_scores = torch.stack([false_vector, true_vector], dim=1)
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            scores = batch_scores[:, 1].exp().tolist()
        return scores

    def rerank(self, query: str, documents: List[str], instruction: Optional[str] = None) -> List[float]:
        """
        Rerank documents using the Qwen model.

        Args:
            query (str): The search query string.
            documents (List[str]): List of documents to be reranked.
            instruction (Optional[str]): Additional instruction for reranking.

        Returns:
            List[float]: Relevance score for each document, in input order.

        Raises:
            RuntimeError: If the model is not loaded.
            Exception: If reranking fails.
        """
        if not self.loaded:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            pairs = [
                self._format_instruction(instruction, query, doc)
                for doc in documents
            ]
            inputs = self._process_inputs(pairs)
            scores = self._compute_logits(inputs)
            return scores
        except Exception as e:
            logger.error(f"Reranking failed with {self.model_id}: {e}")
            raise
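
# Illustrative usage sketch (comment only, not part of the module's API). It
# assumes the RerankerModel base class accepts a model_name argument in its
# constructor, and uses the public Qwen/Qwen3-Reranker-0.6B checkpoint purely
# as an example:
#
#     reranker = QwenReranker(model_name="Qwen/Qwen3-Reranker-0.6B")
#     reranker.load()
#     scores = reranker.rerank(
#         "What is the capital of France?",
#         ["Paris is the capital of France.", "Bananas are yellow."],
#         instruction="Given a web search query, retrieve relevant passages that answer the query",
#     )
#     # Each score is the model's probability of answering "yes" (relevant).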