# financial_qa_rag/utils/retriever.py
# (source: commit 38e41eb, "Update utils/retriever.py")
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from typing import List, Dict, Union
from sentence_transformers import CrossEncoder # NEW
class HybridRetriever:
    """Hybrid (dense + sparse) retriever with cross-encoder re-ranking.

    First-stage retrieval blends cosine similarity over dense embeddings
    (from the caller-supplied ``embedder``) with cosine similarity over a
    TF-IDF matrix fitted on the corpus.  The top candidates are then
    re-ranked by a cross-encoder that scores each (query, chunk) pair jointly.
    """

    def __init__(self, chunks: List[Union[str, Dict]], embedder,
                 cross_encoder_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """
        Args:
            chunks: Corpus chunks — either plain strings or dicts carrying the
                chunk text under the 'text' key.
            embedder: Object exposing ``encode(list[str]) -> array`` of dense
                embeddings (e.g. a SentenceTransformer).
            cross_encoder_model: Model name/path for the re-ranking CrossEncoder.

        Raises:
            ValueError: If ``chunks`` is empty — TF-IDF cannot be fitted on an
                empty corpus, and the sklearn error it would raise is obscure.
        """
        if not chunks:
            raise ValueError("HybridRetriever requires at least one chunk")
        self.chunks = chunks
        self.embedder = embedder
        # Accept both plain-string chunks and {'text': ...} dict chunks.
        if isinstance(chunks[0], dict):
            self.texts = [c['text'] for c in chunks]
        else:
            self.texts = chunks
        # Precompute the dense embeddings and the sparse TF-IDF matrix once,
        # so each retrieve() call only has to encode the query.
        self.embeddings = embedder.encode(self.texts)
        self.tfidf = TfidfVectorizer(stop_words='english')
        self.tfidf_matrix = self.tfidf.fit_transform(self.texts)
        # Cross-encoder for second-stage re-ranking of candidate chunks.
        self.cross_encoder = CrossEncoder(cross_encoder_model)

    def retrieve(self, query: str, top_k: int = 5, candidate_k: int = 20,
                 dense_weight: float = 0.7, sparse_weight: float = 0.3) -> List[str]:
        """
        Retrieve the most relevant chunks using hybrid scoring plus
        cross-encoder re-ranking.

        Args:
            query: Search query.
            top_k: Number of final chunks to return.
            candidate_k: Number of first-stage candidates kept for re-ranking.
            dense_weight: Weight of the dense-embedding similarity in the blend.
            sparse_weight: Weight of the TF-IDF similarity in the blend.
                Defaults preserve the original fixed 0.7/0.3 mix.

        Returns:
            Up to ``top_k`` text chunks, most relevant first.
        """
        # Dense similarity of the query against the precomputed embeddings.
        query_embedding = self.embedder.encode([query])
        dense_scores = cosine_similarity(query_embedding, self.embeddings)[0]
        # Sparse TF-IDF similarity.
        sparse_query = self.tfidf.transform([query])
        sparse_scores = cosine_similarity(sparse_query, self.tfidf_matrix)[0]
        # Weighted blend of the two first-stage signals.
        combined_scores = dense_weight * dense_scores + sparse_weight * sparse_scores
        # First-stage candidates, best combined score first.  Slicing past the
        # end is safe when candidate_k exceeds the corpus size.
        top_indices = np.argsort(combined_scores)[::-1][:candidate_k]
        candidate_chunks = [self.texts[i] for i in top_indices]
        # Second stage: the cross-encoder scores each (query, chunk) pair
        # jointly, which is more accurate than the bi-encoder similarities.
        pairs = [(query, chunk) for chunk in candidate_chunks]
        relevance_scores = self.cross_encoder.predict(pairs)
        reranked_indices = np.argsort(relevance_scores)[::-1][:top_k]
        return [candidate_chunks[i] for i in reranked_indices]