"""
Rerank model implementation.

This module provides the RerankModel class for reranking
documents using sentence-transformers.
"""
| from typing import List, Optional | |
| from sentence_transformers import CrossEncoder | |
| from loguru import logger | |
| from src.config.settings import get_settings | |
| from src.core.config import ModelConfig | |
| from src.core.exceptions import ModelLoadError, RerankingDocumentError | |
class RerankModel:
    """
    Cross-encoder model wrapper using sentence-transformers.

    Wraps a sentence-transformers ``CrossEncoder`` model and uses it to
    rerank candidate documents against a query, returning min-max
    normalized relevance scores.

    Attributes:
        config: ModelConfig instance describing the model.
        model: Lazily created CrossEncoder instance (None until load()).
        settings: Application settings (device, trust_remote_code, ...).
        _loaded: Flag indicating whether the model is currently loaded.
    """

    def __init__(self, config: ModelConfig):
        """
        Initialize the rerank model wrapper.

        The underlying CrossEncoder is NOT loaded here; call ``load()``
        (or let ``rank_document()`` load it lazily).

        Args:
            config: ModelConfig instance with model configuration.
        """
        self.config = config
        self._loaded = False
        self.model: Optional[CrossEncoder] = None
        self.settings = get_settings()

    @property
    def model_id(self) -> str:
        """
        Model identifier from the configuration.

        Defined as a property (not a method) because the rest of the class
        reads it attribute-style, e.g. ``f"Model {self.model_id}"``.
        """
        return self.config.id

    @property
    def model_type(self) -> str:
        """Model type string from the configuration."""
        return self.config.type

    @property
    def is_loaded(self) -> bool:
        """True if the model is currently loaded in memory."""
        return self._loaded

    def load(self) -> None:
        """
        Load the cross-encoder model into memory (idempotent).

        Raises:
            ModelLoadError: If the model fails to load.
        """
        if self._loaded:
            # Fixed: previously fell through and reloaded the model;
            # an already-loaded model is now left untouched.
            logger.debug(f"Model {self.model_id} already loaded")
            return

        logger.info(f"Loading rerank model: {self.config.name}")
        try:
            self.model = CrossEncoder(
                self.config.name,
                device=self.settings.DEVICE,
                trust_remote_code=self.settings.TRUST_REMOTE_CODE,
            )
            self._loaded = True
            # Fixed copy-paste: this is the rerank model, not a dense model.
            logger.success(f"Loaded rerank model: {self.model_id}")
        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            logger.error(error_msg)
            # Chain the original exception so the root cause is preserved.
            raise ModelLoadError(self.model_id, error_msg) from e

    def unload(self) -> None:
        """
        Unload the model from memory and free resources.

        Safely releases the model reference; a no-op when nothing is
        loaded. Errors during unload are logged, not raised (best-effort).
        """
        if not self._loaded:
            logger.debug(f"Model {self.model_id} not loaded, nothing to unload")
            return
        try:
            if self.model is not None:
                # Drop the reference so the interpreter/GC can reclaim memory.
                del self.model
                self.model = None
            self._loaded = False
            logger.info(f"Unloaded model: {self.model_id}")
        except Exception as e:
            logger.error(f"Error unloading model {self.model_id}: {e}")

    def rank_document(
        self,
        query: str,
        documents: List[str],
        top_k: int,
        **kwargs,
    ) -> List[float]:
        """
        Rerank documents using the CrossEncoder model.

        The model is loaded lazily on first use.

        Args:
            query: The search query string.
            documents: List of documents to be reranked.
            top_k: Number of top-ranked documents to keep.
            **kwargs: Extra keyword arguments forwarded to CrossEncoder.rank.

        Returns:
            List of min-max normalized relevance scores (one per ranked
            document, in ranked order).

        Raises:
            RerankingDocumentError: If reranking fails.
        """
        if not self._loaded or self.model is None:
            self.load()
        try:
            rankings = self.model.rank(query, documents, top_k=top_k, **kwargs)
            return self._normalize_rerank_scores(rankings)
        except Exception as e:
            error_msg = f"Reranking documents failed: {str(e)}"
            logger.error(error_msg)
            raise RerankingDocumentError(self.model_id, error_msg) from e

    def _normalize_rerank_scores(
        self, rankings: List[dict], target_range: tuple = (0, 1)
    ) -> List[float]:
        """
        Min-max normalize reranking scores into ``target_range``.

        Args:
            rankings: Ranking dicts from the cross-encoder; each must
                carry a "score" key.
            target_range: (min, max) bounds of the normalized scores.

        Returns:
            List of normalized scores, preserving input order. Empty
            input yields an empty list; when all scores are equal, every
            document gets the range maximum (no spread to normalize).
        """
        raw_scores = [ranking["score"] for ranking in rankings]
        # Fixed: guard empty input — min()/max() on [] would raise ValueError.
        if not raw_scores:
            return []
        min_score = min(raw_scores)
        max_score = max(raw_scores)
        if max_score == min_score:
            return [target_range[1]] * len(raw_scores)  # all scores identical
        target_min, target_max = target_range
        return [
            target_min
            + (score - min_score) * (target_max - target_min) / (max_score - min_score)
            for score in raw_scores
        ]

    def __repr__(self) -> str:
        """String representation of the model."""
        return (
            f"{self.__class__.__name__}("
            f"id={self.model_id}, "
            f"type={self.model_type}, "
            f"loaded={self.is_loaded})"
        )