minhvtt committed on
Commit
1009897
verified
1 Parent(s): ca36499

Update advanced_rag.py

Browse files
Files changed (1) hide show
  1. advanced_rag.py +12 -7
advanced_rag.py CHANGED
@@ -187,20 +187,25 @@ Alternative queries (one per line):"""
187
  # Get Cross-Encoder scores
188
  ce_scores = self.cross_encoder.predict(pairs)
189
 
190
- # Create reranked documents with new scores
 
 
 
 
 
 
 
191
  reranked = []
192
- for doc, ce_score in zip(documents, ce_scores):
193
- # Combine CE score with original confidence (weighted)
194
- combined_score = 0.7 * float(ce_score) + 0.3 * doc.confidence
195
-
196
  reranked.append(RetrievedDocument(
197
  id=doc.id,
198
  text=doc.text,
199
- confidence=float(combined_score),
200
  metadata=doc.metadata
201
  ))
202
 
203
- # Sort by new combined score
204
  reranked.sort(key=lambda x: x.confidence, reverse=True)
205
  return reranked[:top_k]
206
 
 
187
  # Get Cross-Encoder scores
188
  ce_scores = self.cross_encoder.predict(pairs)
189
 
190
+ # Normalize CE scores using sigmoid (convert logits to 0-1 range)
191
+ import math
192
+ def sigmoid(x):
193
+ return 1 / (1 + math.exp(-x))
194
+
195
+ ce_scores_normalized = [sigmoid(float(score)) for score in ce_scores]
196
+
197
+ # Create reranked documents with normalized scores
198
  reranked = []
199
+ for doc, ce_score_norm in zip(documents, ce_scores_normalized):
200
+ # Use ONLY Cross-Encoder score (it's more accurate than cosine similarity)
 
 
201
  reranked.append(RetrievedDocument(
202
  id=doc.id,
203
  text=doc.text,
204
+ confidence=float(ce_score_norm),
205
  metadata=doc.metadata
206
  ))
207
 
208
+ # Sort by Cross-Encoder score
209
  reranked.sort(key=lambda x: x.confidence, reverse=True)
210
  return reranked[:top_k]
211