jayyd committed on
Commit
38e41eb
·
verified ·
1 Parent(s): ea20f00

Update utils/retriever.py

Browse files
Files changed (1) hide show
  1. utils/retriever.py +59 -0
utils/retriever.py CHANGED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.feature_extraction.text import TfidfVectorizer
2
+ from sklearn.metrics.pairwise import cosine_similarity
3
+ import numpy as np
4
+ from typing import List, Dict, Union
5
+ from sentence_transformers import CrossEncoder # NEW
6
+
7
class HybridRetriever:
    """Two-stage retriever over a fixed list of text chunks.

    Stage 1 scores every chunk against the query with a weighted sum of
    (a) cosine similarity between embedder vectors and (b) cosine similarity
    between TF-IDF vectors. Stage 2 re-scores the best stage-1 candidates
    with a cross-encoder and returns the highest-scoring ones.
    """

    def __init__(
        self,
        chunks: List[Union[str, Dict]],
        embedder,
        cross_encoder_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    ):
        """Index ``chunks`` for dense and sparse retrieval.

        Args:
            chunks: Corpus entries. Each entry is either a plain string or a
                dict holding the chunk text under the ``'text'`` key; the two
                forms may be mixed in one list.
            embedder: Object exposing ``encode(list_of_texts)``; its output is
                passed to ``cosine_similarity`` at query time.
            cross_encoder_model: Model name or path handed to ``CrossEncoder``.

        Raises:
            KeyError: If a dict chunk has no ``'text'`` key.
        """
        self.chunks = chunks
        self.embedder = embedder

        # Resolve the text per element (not just from chunks[0]) so mixed
        # str/dict corpora work, and build a fresh list so that later mutation
        # of the caller's list cannot desynchronize texts from the
        # precomputed embeddings / TF-IDF matrix.
        self.texts = [
            chunk['text'] if isinstance(chunk, dict) else chunk
            for chunk in chunks
        ]

        self.tfidf = TfidfVectorizer(stop_words='english')
        if self.texts:
            self.embeddings = embedder.encode(self.texts)
            self.tfidf_matrix = self.tfidf.fit_transform(self.texts)
        else:
            # Fitting TF-IDF on an empty corpus raises ValueError; an empty
            # retriever is kept usable instead and simply returns no results.
            self.embeddings = None
            self.tfidf_matrix = None

        # Load cross-encoder for re-ranking
        self.cross_encoder = CrossEncoder(cross_encoder_model)

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        candidate_k: int = 20,
        *,
        dense_weight: float = 0.7,
        sparse_weight: float = 0.3,
    ) -> List[str]:
        """Return the chunks most relevant to ``query``.

        Args:
            query: Search query.
            top_k: Number of final chunks to return.
            candidate_k: Number of first-stage candidates handed to the
                cross-encoder. Raised to ``top_k`` when smaller, so that the
                re-ranker can always fill the requested result size.
            dense_weight: Weight of the embedding cosine score in the
                first-stage combined score.
            sparse_weight: Weight of the TF-IDF cosine score in the
                first-stage combined score.

        Returns:
            Up to ``top_k`` chunk texts ordered by descending cross-encoder
            score. Empty when the corpus is empty or ``top_k`` is not
            positive.
        """
        # Nothing to rank, or nothing requested. This also guards against a
        # negative top_k, which as a slice bound would drop trailing items
        # rather than yield an empty result.
        if top_k <= 0 or not self.texts:
            return []

        candidate_k = max(candidate_k, top_k)

        # Dense scores: cosine similarity between query and chunk embeddings
        query_embedding = self.embedder.encode([query])
        dense_scores = cosine_similarity(query_embedding, self.embeddings)[0]

        # Sparse scores: cosine similarity in TF-IDF space
        sparse_query = self.tfidf.transform([query])
        sparse_scores = cosine_similarity(sparse_query, self.tfidf_matrix)[0]

        # Weighted sum of both signals
        combined_scores = dense_weight * dense_scores + sparse_weight * sparse_scores

        # Best candidate_k chunks by combined score, highest first
        top_indices = np.argsort(combined_scores)[::-1][:candidate_k]
        candidate_chunks = [self.texts[i] for i in top_indices]

        # Cross-encoder re-ranking of (query, chunk) pairs
        pairs = [(query, chunk) for chunk in candidate_chunks]
        relevance_scores = self.cross_encoder.predict(pairs)

        reranked_indices = np.argsort(relevance_scores)[::-1][:top_k]
        return [candidate_chunks[i] for i in reranked_indices]