Spaces:
Running
Running
File size: 5,499 Bytes
90528a8 3371d97 90528a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
"""
Rerank model implementation.
This module provides the RerankModel class for reranking
documents using sentence-transformers.
"""
from typing import List, Optional
from sentence_transformers import CrossEncoder
from loguru import logger
from src.config.settings import get_settings
from src.core.config import ModelConfig
from src.core.exceptions import ModelLoadError, RerankingDocumentError
class RerankModel:
    """
    Cross-encoder reranker built on sentence-transformers.

    Wraps a sentence-transformers ``CrossEncoder`` and exposes lazy loading,
    unloading, and query/document reranking with min-max normalized scores.

    Attributes:
        config: ModelConfig instance; its ``id``, ``name`` and ``type``
            attributes are read by this class.
        model: The loaded ``CrossEncoder``, or ``None`` while unloaded.
        settings: Object returned by ``get_settings()``; ``DEVICE`` and
            ``TRUST_REMOTE_CODE`` are read from it when loading.
        _loaded: Flag indicating if the model is loaded.
    """

    def __init__(self, config: "ModelConfig"):
        """
        Initialize the rerank model wrapper without loading any weights.

        Args:
            config: ModelConfig instance with model configuration
        """
        self.config = config
        self._loaded = False
        self.model: Optional[CrossEncoder] = None
        self.settings = get_settings()

    def load(self) -> None:
        """
        Load the cross-encoder model into memory.

        Does nothing when the model is already loaded.

        Raises:
            ModelLoadError: If model fails to load
        """
        # Require an actual model object as well as the flag, so that an
        # inconsistent state (flag set, model missing) still triggers a load.
        if self._loaded and self.model is not None:
            logger.debug(f"Model {self.model_id} already loaded")
            return

        logger.info(f"Loading rerank model: {self.config.name}")
        try:
            self.model = CrossEncoder(
                self.config.name,
                device=self.settings.DEVICE,
                trust_remote_code=self.settings.TRUST_REMOTE_CODE,
            )
        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            logger.error(f"✗ {error_msg}")
            # Chain the original exception so the root cause stays visible.
            raise ModelLoadError(self.model_id, error_msg) from e
        else:
            self._loaded = True
            logger.success(f"✓ Loaded rerank model: {self.model_id}")

    def unload(self) -> None:
        """
        Unload the model by dropping this wrapper's reference to it.

        The model object becomes eligible for garbage collection once no
        other references remain; no explicit device-cache cleanup is done.
        """
        if not self._loaded:
            logger.debug(f"Model {self.model_id} not loaded, nothing to unload")
            return

        self.model = None
        self._loaded = False
        logger.info(f"✓ Unloaded model: {self.model_id}")

    def rank_document(
        self,
        query: str,
        documents: List[str],
        top_k: int,
        **kwargs,
    ) -> List[float]:
        """
        Rerank documents against a query using the CrossEncoder model.

        The model is loaded on demand if it is not loaded yet.

        Args:
            query (str): The search query string.
            documents (List[str]): List of documents to be reranked.
            top_k (int): Number of top-ranked documents to keep; forwarded to
                ``CrossEncoder.rank``.
            **kwargs: Extra keyword arguments forwarded to
                ``CrossEncoder.rank``.

        Returns:
            List[float]: Min-max normalized relevance scores, in the order of
            the rankings returned by ``CrossEncoder.rank`` (documented by
            sentence-transformers as sorted by descending score), NOT in the
            order of ``documents``. The ``corpus_id`` of each ranking is not
            returned. An empty list is returned when ``documents`` is empty.

        Raises:
            ModelLoadError: If the model has to be loaded and loading fails.
            RerankingDocumentError: If reranking fails.
        """
        # Nothing to rank: avoid loading the model and min()/max() on no data.
        if not documents:
            return []

        if not self._loaded or self.model is None:
            self.load()

        try:
            rankings = self.model.rank(query, documents, top_k=top_k, **kwargs)
            return self._normalize_rerank_scores(rankings)
        except Exception as e:
            error_msg = f"Reranking documents failed: {str(e)}"
            logger.error(error_msg)
            raise RerankingDocumentError(self.model_id, error_msg) from e

    def _normalize_rerank_scores(
        self, rankings: List[dict], target_range: tuple = (0, 1)
    ) -> List[float]:
        """
        Min-max normalize the ``"score"`` values of ranking dictionaries.

        Args:
            rankings: List of ranking dictionaries from the cross-encoder;
                each must contain a ``"score"`` key.
            target_range: Target ``(min, max)`` range for the normalization.

        Returns:
            List of normalized scores as built-in floats, in the same order as
            ``rankings``. Empty input yields an empty list. When every score
            is identical, each entry is the upper bound of ``target_range``.
        """
        raw_scores = [ranking["score"] for ranking in rankings]
        if not raw_scores:
            return []

        target_min, target_max = target_range
        min_score = min(raw_scores)
        max_score = max(raw_scores)

        # All scores equal: the min-max formula would divide by zero.
        if max_score == min_score:
            return [float(target_max)] * len(raw_scores)

        span = target_max - target_min
        score_range = max_score - min_score
        # float() turns library scalar types into plain Python floats.
        return [
            float(target_min + (score - min_score) * span / score_range)
            for score in raw_scores
        ]

    @property
    def is_loaded(self) -> bool:
        """
        Check if the model is currently loaded.

        Returns:
            True if model is loaded, False otherwise
        """
        return self._loaded

    @property
    def model_id(self) -> str:
        """
        Get the model identifier.

        Returns:
            The ``id`` value of the model configuration
        """
        return self.config.id

    @property
    def model_type(self) -> str:
        """
        Get the model type.

        Returns:
            The ``type`` value of the model configuration
        """
        return self.config.type

    def __repr__(self) -> str:
        """String representation of the model."""
        return (
            f"{self.__class__.__name__}("
            f"id={self.model_id}, "
            f"type={self.model_type}, "
            f"loaded={self.is_loaded})"
        )
|