fahmiaziz98 commited on
Commit
fb8f5fc
·
1 Parent(s): 9958d9a
src/api/routers/rerank.py CHANGED
@@ -80,37 +80,52 @@ async def rerank_documents(
80
  detail=f"Model '{request.model_id}' is not a rerank model. Type: {config.type}",
81
  )
82
 
 
 
 
 
 
 
 
83
  start = time.time()
84
 
85
- # Call rank_document with clean kwargs
86
- scores = model.rank_document(
 
 
 
87
  query=request.query,
88
- documents=[doc for _, doc in valid_docs], # Use filtered documents
89
  top_k=request.top_k,
90
  **kwargs,
91
  )
92
 
93
  processing_time = time.time() - start
94
 
95
- # Sebelum memanggil rank_document, tambahkan:
96
- logger.debug(f"Rerank request - Query: '{request.query}'")
97
- logger.debug(f"Documents to rank: {len(request.documents)}")
98
- logger.debug(f"First document: {request.documents[-1][:100]}...")
99
- logger.debug(f"Top K: {request.top_k}")
100
 
101
- # Setelah rank_document, tambahkan:
102
- logger.debug(f"Ranking returned {len(scores)} scores")
103
- logger.debug(f"Sample scores: {scores[:5] if scores else 'None'}")
104
-
105
- # Build results with original indices
106
- original_indices, documents_list = zip(*valid_docs)
107
  results = []
108
-
109
- for i, (orig_idx, doc) in enumerate(zip(original_indices, documents_list)):
110
- results.append(RerankResult(text=doc, score=scores[i], index=orig_idx))
111
-
112
- # Sort results by score in descending order
113
- results.sort(key=lambda x: x.score, reverse=True)
 
 
 
 
 
 
 
 
 
 
114
 
115
  logger.info(
116
  f"Reranked {len(results)} documents in {processing_time:.3f}s "
@@ -135,4 +150,4 @@ async def rerank_documents(
135
  raise HTTPException(
136
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
137
  detail=f"Failed to rerank documents: {str(e)}",
138
- )
 
80
  detail=f"Model '{request.model_id}' is not a rerank model. Type: {config.type}",
81
  )
82
 
83
+ # Debug logs BEFORE calling rank_document
84
+ logger.debug(f"Rerank request - Query: '{request.query}'")
85
+ logger.debug(f"Documents to rank: {len(valid_docs)}")
86
+ if valid_docs:
87
+ logger.debug(f"First document: {valid_docs[0][1][:100]}...")
88
+ logger.debug(f"Top K: {request.top_k}")
89
+
90
  start = time.time()
91
 
92
+ # Extract documents for ranking
93
+ documents_list = [doc for _, doc in valid_docs]
94
+
95
+ # Call rank_document - returns only top_k results
96
+ ranking_results = model.rank_document(
97
  query=request.query,
98
+ documents=documents_list,
99
  top_k=request.top_k,
100
  **kwargs,
101
  )
102
 
103
  processing_time = time.time() - start
104
 
105
+ # Debug logs AFTER rank_document
106
+ logger.debug(f"Ranking returned {len(ranking_results)} results")
107
+ if ranking_results:
108
+ logger.debug(f"Top result score: {ranking_results[0]}")
 
109
 
110
+ # Build results from ranking_results
111
+ # ranking_results already contains top_k items with scores
 
 
 
 
112
  results = []
113
+
114
+ for rank_result in ranking_results:
115
+ # Get original index from valid_docs
116
+ doc_idx = rank_result.get('corpus_id', 0) # Index in filtered list
117
+ if doc_idx < len(valid_docs):
118
+ original_idx = valid_docs[doc_idx][0] # Original index
119
+ doc_text = documents_list[doc_idx]
120
+ score = rank_result['score']
121
+
122
+ results.append(
123
+ RerankResult(
124
+ text=doc_text,
125
+ score=score,
126
+ index=original_idx
127
+ )
128
+ )
129
 
130
  logger.info(
131
  f"Reranked {len(results)} documents in {processing_time:.3f}s "
 
150
  raise HTTPException(
151
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
152
  detail=f"Failed to rerank documents: {str(e)}",
153
+ )
src/models/embeddings/rank.py CHANGED
@@ -5,7 +5,7 @@ This module provides the RerankModel class for reranking
5
  documents using sentence-transformers.
6
  """
7
 
8
- from typing import List, Optional
9
  from sentence_transformers import CrossEncoder
10
  from loguru import logger
11
 
@@ -18,18 +18,18 @@ class RerankModel:
18
  """
19
  Cross-encoder model wrapper using sentence-transformers.
20
 
21
- This class wraps sentence-transformers SentenceTransformer models
22
- to ranking documents
23
 
24
  Attributes:
25
  config: ModelConfig instance
26
- model: SentenceTransformer instance
27
  _loaded: Flag indicating if the model is loaded
28
  """
29
 
30
  def __init__(self, config: ModelConfig):
31
  """
32
- Initialize the dense embedding model.
33
 
34
  Args:
35
  config: ModelConfig instance with model configuration
@@ -48,6 +48,7 @@ class RerankModel:
48
  """
49
  if self._loaded:
50
  logger.debug(f"Model {self.model_id} already loaded")
 
51
 
52
  logger.info(f"Loading rerank model: {self.config.name}")
53
 
@@ -58,7 +59,7 @@ class RerankModel:
58
  trust_remote_code=self.settings.TRUST_REMOTE_CODE,
59
  )
60
  self._loaded = True
61
- logger.success(f"✓ Loaded dense model: {self.model_id}")
62
 
63
  except Exception as e:
64
  error_msg = f"Failed to load model: {str(e)}"
@@ -93,28 +94,44 @@ class RerankModel:
93
  documents: List[str],
94
  top_k: int,
95
  **kwargs,
96
- ) -> List[float]:
97
  """
98
  Rerank documents using the CrossEncoder model.
99
 
100
  Args:
101
  query (str): The search query string.
102
  documents (List[str]): List of documents to be reranked.
103
- top_k (int): top n documents
104
- **kwargs
105
 
106
  Returns:
107
- List[float]: List of relevance scores for each document.
 
108
 
109
- Raises:.
110
- Exception: If reranking fails.
111
  """
112
  if not self._loaded or self.model is None:
113
  self.load()
 
114
  try:
115
- scores = self.model.rank(query, documents, top_k=top_k, **kwargs)
116
- normalized_score = self._normalize_rerank_scores(scores)
117
- return normalized_score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  except Exception as e:
120
  error_msg = f"Reranking documents failed: {str(e)}"
@@ -122,34 +139,57 @@ class RerankModel:
122
  raise RerankingDocumentError(self.model_id, error_msg)
123
 
124
  def _normalize_rerank_scores(
125
- self, rankings: List[dict], target_range: tuple = (0, 1)
126
- ) -> List[float]:
 
 
127
  """
128
- Normalize reranking scores menggunakan berbagai metode.
129
 
130
  Args:
131
- rankings: List of ranking dictionaries dari cross-encoder
132
- target_range: Target range untuk minmax normalization (min, max)
 
133
 
134
  Returns:
135
- List of normalized scores
136
  """
 
 
 
 
137
  raw_scores = [ranking["score"] for ranking in rankings]
138
-
139
- # Min-Max normalization ke target range
140
  min_score = min(raw_scores)
141
  max_score = max(raw_scores)
142
-
 
143
  if max_score == min_score:
144
- return [target_range[1]] * len(raw_scores) # All same score
145
-
 
 
 
 
 
 
 
146
  target_min, target_max = target_range
147
- normalized = [
148
- target_min
149
- + (score - min_score) * (target_max - target_min) / (max_score - min_score)
150
- for score in raw_scores
151
- ]
152
- return normalized
 
 
 
 
 
 
 
 
153
 
154
  @property
155
  def is_loaded(self) -> bool:
@@ -177,7 +217,7 @@ class RerankModel:
177
  Get the model type.
178
 
179
  Returns:
180
- Model type ('embeddings' or 'sparse-embeddings')
181
  """
182
  return self.config.type
183
 
@@ -188,4 +228,4 @@ class RerankModel:
188
  f"id={self.model_id}, "
189
  f"type={self.model_type}, "
190
  f"loaded={self.is_loaded})"
191
- )
 
5
  documents using sentence-transformers.
6
  """
7
 
8
+ from typing import List, Optional, Dict
9
  from sentence_transformers import CrossEncoder
10
  from loguru import logger
11
 
 
18
  """
19
  Cross-encoder model wrapper using sentence-transformers.
20
 
21
+ This class wraps sentence-transformers CrossEncoder models
22
+ for ranking documents
23
 
24
  Attributes:
25
  config: ModelConfig instance
26
+ model: CrossEncoder instance
27
  _loaded: Flag indicating if the model is loaded
28
  """
29
 
30
  def __init__(self, config: ModelConfig):
31
  """
32
+ Initialize the rerank model.
33
 
34
  Args:
35
  config: ModelConfig instance with model configuration
 
48
  """
49
  if self._loaded:
50
  logger.debug(f"Model {self.model_id} already loaded")
51
+ return
52
 
53
  logger.info(f"Loading rerank model: {self.config.name}")
54
 
 
59
  trust_remote_code=self.settings.TRUST_REMOTE_CODE,
60
  )
61
  self._loaded = True
62
+ logger.success(f"✓ Loaded rerank model: {self.model_id}")
63
 
64
  except Exception as e:
65
  error_msg = f"Failed to load model: {str(e)}"
 
94
  documents: List[str],
95
  top_k: int,
96
  **kwargs,
97
+ ) -> List[Dict]:
98
  """
99
  Rerank documents using the CrossEncoder model.
100
 
101
  Args:
102
  query (str): The search query string.
103
  documents (List[str]): List of documents to be reranked.
104
+ top_k (int): Number of top documents to return
105
+ **kwargs: Additional arguments passed to model.rank()
106
 
107
  Returns:
108
+ List[Dict]: List of ranking results with 'corpus_id' and 'score'.
109
+ Returns top_k results sorted by score (highest first).
110
 
111
+ Raises:
112
+ RerankingDocumentError: If reranking fails.
113
  """
114
  if not self._loaded or self.model is None:
115
  self.load()
116
+
117
  try:
118
+ # model.rank returns List[Dict] with 'corpus_id' and 'score'
119
+ # Already sorted by score (highest first) and limited to top_k
120
+ ranking_results = self.model.rank(
121
+ query,
122
+ documents,
123
+ top_k=top_k,
124
+ **kwargs
125
+ )
126
+
127
+ # Normalize scores to 0-1 range for consistency
128
+ normalized_results = self._normalize_rerank_scores(ranking_results)
129
+
130
+ logger.debug(
131
+ f"Reranked {len(documents)} docs, returned top {len(normalized_results)}"
132
+ )
133
+
134
+ return normalized_results
135
 
136
  except Exception as e:
137
  error_msg = f"Reranking documents failed: {str(e)}"
 
139
  raise RerankingDocumentError(self.model_id, error_msg)
140
 
141
  def _normalize_rerank_scores(
142
+ self,
143
+ rankings: List[Dict],
144
+ target_range: tuple = (0, 1)
145
+ ) -> List[Dict]:
146
  """
147
+ Normalize reranking scores using min-max normalization.
148
 
149
  Args:
150
+ rankings: List of ranking dictionaries from cross-encoder
151
+ Format: [{'corpus_id': int, 'score': float}, ...]
152
+ target_range: Target range for normalization (min, max)
153
 
154
  Returns:
155
+ List[Dict]: Rankings with normalized scores
156
  """
157
+ if not rankings:
158
+ return []
159
+
160
+ # Extract raw scores
161
  raw_scores = [ranking["score"] for ranking in rankings]
162
+
163
+ # Min-Max normalization
164
  min_score = min(raw_scores)
165
  max_score = max(raw_scores)
166
+
167
+ # If all scores are the same, return max target value
168
  if max_score == min_score:
169
+ return [
170
+ {
171
+ "corpus_id": r["corpus_id"],
172
+ "score": target_range[1]
173
+ }
174
+ for r in rankings
175
+ ]
176
+
177
+ # Normalize to target range
178
  target_min, target_max = target_range
179
+ normalized_rankings = []
180
+
181
+ for ranking in rankings:
182
+ score = ranking["score"]
183
+ normalized_score = (
184
+ target_min +
185
+ (score - min_score) * (target_max - target_min) / (max_score - min_score)
186
+ )
187
+ normalized_rankings.append({
188
+ "corpus_id": ranking["corpus_id"],
189
+ "score": float(normalized_score)
190
+ })
191
+
192
+ return normalized_rankings
193
 
194
  @property
195
  def is_loaded(self) -> bool:
 
217
  Get the model type.
218
 
219
  Returns:
220
+ Model type ('rerank')
221
  """
222
  return self.config.type
223
 
 
228
  f"id={self.model_id}, "
229
  f"type={self.model_type}, "
230
  f"loaded={self.is_loaded})"
231
+ )