fahmiaziz98
[FIX TYPO]: fix typo param
3371d97
raw
history blame
5.5 kB
"""
Rerank model implementation.
This module provides the RerankModel class for reranking
documents using sentence-transformers.
"""
from typing import List, Optional
from sentence_transformers import CrossEncoder
from loguru import logger
from src.config.settings import get_settings
from src.core.config import ModelConfig
from src.core.exceptions import ModelLoadError, RerankingDocumentError
class RerankModel:
    """
    Cross-encoder model wrapper using sentence-transformers.

    This class wraps a sentence-transformers ``CrossEncoder`` and uses it
    to score and rank documents against a query.

    Attributes:
        config: ModelConfig instance (``name``, ``id`` and ``type`` are read)
        model: CrossEncoder instance, or None while no model is loaded
        settings: Settings object returned by ``get_settings()``
        _loaded: Flag indicating if the model is loaded
    """

    def __init__(self, config: "ModelConfig"):
        """
        Initialize the rerank model wrapper without loading any weights.

        Args:
            config: ModelConfig instance with model configuration
        """
        self.config = config
        self._loaded = False
        self.model: Optional[CrossEncoder] = None
        self.settings = get_settings()

    def load(self) -> None:
        """
        Load the cross-encoder model into memory.

        Calling this on an already loaded model is a no-op.

        Raises:
            ModelLoadError: If model fails to load
        """
        if self._loaded:
            logger.debug(f"Model {self.model_id} already loaded")
            # Without this return the model would be re-instantiated on
            # every call, defeating the "already loaded" check.
            return

        logger.info(f"Loading rerank model: {self.config.name}")
        try:
            self.model = CrossEncoder(
                self.config.name,
                device=self.settings.DEVICE,
                trust_remote_code=self.settings.TRUST_REMOTE_CODE,
            )
            self._loaded = True
            logger.success(f"✓ Loaded rerank model: {self.model_id}")
        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            logger.error(f"✗ {error_msg}")
            # Chain the original exception so its traceback is preserved.
            raise ModelLoadError(self.model_id, error_msg) from e

    def unload(self) -> None:
        """
        Unload the model and release this wrapper's reference to it.

        Only the reference held by this instance is dropped; the memory is
        reclaimed once nothing else references the model.
        """
        if not self._loaded:
            logger.debug(f"Model {self.model_id} not loaded, nothing to unload")
            return

        # Rebinding to None drops the reference, same as del + reassign.
        self.model = None
        self._loaded = False
        logger.info(f"✓ Unloaded model: {self.model_id}")

    def rank_document(
        self,
        query: str,
        documents: List[str],
        top_k: int,
        **kwargs,
    ) -> List[float]:
        """
        Rerank documents using the CrossEncoder model.

        The model is loaded on demand if it is not loaded yet.

        Args:
            query (str): The search query string.
            documents (List[str]): List of documents to be reranked.
            top_k (int): Forwarded to ``CrossEncoder.rank`` as ``top_k``.
            **kwargs: Extra keyword arguments forwarded to
                ``CrossEncoder.rank``.

        Returns:
            List[float]: Min-max normalized scores, in the order returned by
            ``CrossEncoder.rank``. An empty ``documents`` list yields ``[]``.
            NOTE(review): per the sentence-transformers docs that order is
            highest score first and limited to ``top_k`` entries, so it is
            not the input order -- confirm callers account for this.

        Raises:
            ModelLoadError: If the model has to be loaded and loading fails.
            RerankingDocumentError: If reranking fails.
        """
        # Nothing to score; skip loading the model and the model call.
        if not documents:
            return []

        if not self._loaded or self.model is None:
            self.load()

        try:
            rankings = self.model.rank(query, documents, top_k=top_k, **kwargs)
            return self._normalize_rerank_scores(rankings)
        except Exception as e:
            error_msg = f"Reranking documents failed: {str(e)}"
            logger.error(error_msg)
            raise RerankingDocumentError(self.model_id, error_msg) from e

    def _normalize_rerank_scores(
        self, rankings: List[dict], target_range: tuple = (0, 1)
    ) -> List[float]:
        """
        Normalize reranking scores with min-max scaling.

        Args:
            rankings: List of ranking dictionaries from the cross-encoder;
                each must contain a "score" entry
            target_range: Target range for min-max normalization (min, max)

        Returns:
            List of normalized scores as plain floats, in the same order as
            ``rankings``. Empty input returns an empty list. If every score
            is identical, each one maps to the upper bound of
            ``target_range``.
        """
        # Coerce to builtin float so callers never receive array scalars.
        raw_scores = [float(ranking["score"]) for ranking in rankings]
        if not raw_scores:
            return []

        target_min, target_max = target_range
        min_score = min(raw_scores)
        max_score = max(raw_scores)

        # A zero-width range would divide by zero below.
        if max_score == min_score:
            return [float(target_max)] * len(raw_scores)

        scale = (target_max - target_min) / (max_score - min_score)
        return [target_min + (score - min_score) * scale for score in raw_scores]

    @property
    def is_loaded(self) -> bool:
        """
        Check if the model is currently loaded.

        Returns:
            True if model is loaded, False otherwise
        """
        return self._loaded

    @property
    def model_id(self) -> str:
        """
        Get the model identifier.

        Returns:
            Model ID string taken from ``config.id``
        """
        return self.config.id

    @property
    def model_type(self) -> str:
        """
        Get the model type.

        Returns:
            Model type string taken from ``config.type``
        """
        return self.config.type

    def __repr__(self) -> str:
        """String representation of the model."""
        return (
            f"{self.__class__.__name__}("
            f"id={self.model_id}, "
            f"type={self.model_type}, "
            f"loaded={self.is_loaded})"
        )