Spaces:
Sleeping
Sleeping
| """ | |
| Retrievers Module - BM25, Semantic, and Hybrid Retrieval | |
| """ | |
| from typing import List, Tuple, Optional | |
| import numpy as np | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
class BM25Retriever:
    """Keyword retriever backed by a BM25 (Okapi) index."""

    def __init__(self, chunks: List[str]):
        """
        Build a BM25 index over the given chunks.

        Args:
            chunks: List of text chunks to search over
        """
        self.chunks = chunks
        # Lowercased whitespace tokenization; the query side uses the
        # identical scheme so terms can actually match.
        corpus_tokens = [text.lower().split() for text in chunks]
        self.bm25 = BM25Okapi(corpus_tokens)

    def retrieve(self, query: str, top_k: int = 20) -> List[Tuple[int, float]]:
        """
        Score every chunk against the query and return the best matches.

        Args:
            query: Search query
            top_k: Number of results to return

        Returns:
            List of (chunk_index, score) tuples in descending score order;
            chunks with a zero BM25 score are dropped.
        """
        scores = self.bm25.get_scores(query.lower().split())
        # Descending sort of all indices, truncated to the top-k.
        ranked = np.argsort(scores)[::-1][:top_k]
        hits: List[Tuple[int, float]] = []
        for pos in ranked:
            value = float(scores[pos])
            if value > 0:
                hits.append((int(pos), value))
        return hits
class SemanticRetriever:
    """Semantic/Dense retriever using sentence transformers."""

    def __init__(
        self,
        chunks: List[str],
        model_name: str = "all-MiniLM-L6-v2",
        device: Optional[str] = None
    ):
        """
        Initialize semantic retriever and pre-encode all chunks.

        Args:
            chunks: List of text chunks to search over
            model_name: Name of sentence transformer model
            device: Device to run model on (cuda/cpu), None for auto
        """
        self.chunks = chunks
        self.model = SentenceTransformer(model_name, device=device)
        # Encode all chunks once up front.
        print("Encoding chunks for semantic search...")
        self.chunk_embeddings = self.model.encode(
            chunks,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        print("Encoding complete!")
        # Chunk norms are query-invariant — compute them once here instead
        # of on every retrieve() call.
        self.chunk_norms = np.linalg.norm(self.chunk_embeddings, axis=1)

    def retrieve(self, query: str, top_k: int = 20) -> List[Tuple[int, float]]:
        """
        Retrieve top-k chunks for query using cosine similarity.

        Args:
            query: Search query
            top_k: Number of results to return

        Returns:
            List of (chunk_index, cosine_similarity_score) tuples in
            descending similarity order.
        """
        # Encode the query into the same embedding space as the chunks.
        query_embedding = self.model.encode([query], convert_to_numpy=True)[0]
        # Cosine similarity: dot product divided by the product of norms.
        # NOTE(review): a zero-norm embedding would yield NaN here, same as
        # the pre-existing behavior; sentence-transformer outputs are
        # assumed non-degenerate.
        similarities = (self.chunk_embeddings @ query_embedding) / (
            self.chunk_norms * np.linalg.norm(query_embedding)
        )
        # Indices of the k most similar chunks, best first.
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [(int(idx), float(similarities[idx])) for idx in top_indices]
class HybridRetriever:
    """Hybrid retriever combining BM25 and semantic search."""

    def __init__(
        self,
        chunks: List[str],
        semantic_model: str = "all-MiniLM-L6-v2",
        bm25_weight: float = 0.5,
        semantic_weight: float = 0.5,
        device: Optional[str] = None
    ):
        """
        Initialize hybrid retriever.

        Args:
            chunks: List of text chunks
            semantic_model: Sentence transformer model name
            bm25_weight: Weight for BM25 scores (0-1)
            semantic_weight: Weight for semantic scores (0-1)
            device: Device for semantic model
        """
        self.chunks = chunks
        self.bm25_weight = bm25_weight
        self.semantic_weight = semantic_weight
        # Initialize both underlying retrievers.
        self.bm25_retriever = BM25Retriever(chunks)
        self.semantic_retriever = SemanticRetriever(chunks, semantic_model, device)

    @staticmethod
    def _minmax_normalize(scores: dict) -> dict:
        """
        Min-max normalize a {index: score} dict into [0, 1].

        Bug fix vs. the original: when all scores are equal (e.g. a single
        result), the raw scores used to leak through un-normalized, letting
        an arbitrarily large BM25 score dominate the weighted sum. Such
        degenerate lists now map to 1.0 ("best available match").
        """
        if not scores:
            return scores
        lo = min(scores.values())
        hi = max(scores.values())
        if hi > lo:
            span = hi - lo
            return {idx: (score - lo) / span for idx, score in scores.items()}
        return {idx: 1.0 for idx in scores}

    def retrieve(self, query: str, top_k: int = 20) -> List[Tuple[int, float]]:
        """
        Retrieve using hybrid approach.

        Args:
            query: Search query
            top_k: Number of results to return

        Returns:
            List of (chunk_index, combined_score) tuples, combined score in
            [0, 1], sorted descending.
        """
        # Over-fetch from both retrievers so their top-k sets can disagree
        # and still leave enough candidates after merging.
        bm25_results = self.bm25_retriever.retrieve(query, top_k * 2)
        semantic_results = self.semantic_retriever.retrieve(query, top_k * 2)
        # Both score distributions are min-max normalized to [0, 1] so the
        # weights mix comparable quantities (BM25 is unbounded; cosine
        # similarity lies in [-1, 1]).
        bm25_scores = self._minmax_normalize(dict(bm25_results))
        semantic_scores = self._minmax_normalize(dict(semantic_results))
        # Weighted sum over the union of candidates; an index missing from
        # one retriever contributes 0 for that component.
        combined_scores = {}
        for idx in bm25_scores.keys() | semantic_scores.keys():
            combined_scores[idx] = (
                self.bm25_weight * bm25_scores.get(idx, 0.0) +
                self.semantic_weight * semantic_scores.get(idx, 0.0)
            )
        # Sort by combined score and return the top-k.
        sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_results[:top_k]