Spaces:
Sleeping
Sleeping
| from typing import List, Dict, Any, Tuple | |
| import numpy as np | |
| from .embedder import Embedder | |
| from .faiss_index import FAISSIndex | |
| from .reranker import Reranker | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class Retriever:
    """Two-stage passage retriever: vector index search, then optional reranking.

    The class only orchestrates its collaborators: queries are encoded by the
    embedder, looked up in the index, formatted into plain dictionaries and,
    when a reranker was supplied, reordered by the reranker's scores.
    """

    def __init__(self, embedder: "Embedder", index: "FAISSIndex",
                 reranker: "Reranker" = None):
        """Store the collaborators used by the retrieval pipeline.

        Args:
            embedder: Object whose ``encode_queries(queries)`` result is passed
                unchanged to ``index.search``.
            index: Object exposing ``search(embeddings, k)`` as well as the
                ``id_to_text`` and ``id_to_metadata`` mappings.
            reranker: Optional object exposing ``rerank(query, passages)``.
                ``None`` (the default) disables the reranking stage.
        """
        self.embedder = embedder
        self.index = index
        self.reranker = reranker

    def retrieve(self, queries: List[str], k: int = 20,
                 rerank_k: int = 10) -> List[List[Dict[str, Any]]]:
        """Retrieve and, if a reranker is configured, rerank passages.

        Args:
            queries: Query strings; an empty list short-circuits to ``[]``.
            k: Number of candidates requested from the index per query.
            rerank_k: Maximum number of results kept per query after
                reranking. Ignored when no reranker is configured.

        Returns:
            One list per query, in query order. Each entry is a dict with the
            keys ``text``, ``score``, ``rank``, ``metadata`` and ``id``;
            reranked entries additionally carry ``rerank_score`` and have
            ``rank`` renumbered to their position after reranking.
        """
        if not queries:
            return []

        query_embeddings = self.embedder.encode_queries(queries)
        scores, indices = self.index.search(query_embeddings, k)

        results = [
            self._format_hits(scores[i], indices[i])
            for i in range(len(queries))
        ]

        if self.reranker is None:
            return results

        # Rerank whenever a reranker is configured, including when
        # rerank_k >= k: the candidates are still reordered and annotated
        # with 'rerank_score', the slice simply keeps all of them.
        return [
            self._rerank(query, hits, rerank_k)
            for query, hits in zip(queries, results)
        ]

    def _format_hits(self, scores, indices) -> List[Dict[str, Any]]:
        """Convert one query's raw search row into a list of result dicts."""
        hits = []
        for position, (score, idx) in enumerate(zip(scores, indices), start=1):
            if idx == -1:  # Invalid index
                continue
            # Hand back native Python values rather than NumPy scalars so the
            # result dictionaries can be serialised as-is.
            doc_id = idx.item() if isinstance(idx, np.generic) else idx
            hits.append({
                'text': self.index.id_to_text.get(doc_id, ""),
                'score': float(score),
                # Position within the raw search row, so skipped entries
                # leave gaps in the numbering.
                'rank': position,
                'metadata': self.index.id_to_metadata.get(doc_id, {}),
                'id': doc_id,
            })
        return hits

    def _rerank(self, query: str, hits: List[Dict[str, Any]],
                rerank_k: int) -> List[Dict[str, Any]]:
        """Reorder one query's hits by reranker score and keep ``rerank_k``."""
        if not hits:
            # Nothing to score; do not call the reranker with no passages.
            return []

        passages = [hit['text'] for hit in hits]
        rerank_scores = self.reranker.rerank(query, passages)

        # sorted() is stable, so hits with equal rerank scores keep their
        # first-stage order.
        ordered = sorted(
            zip(hits, rerank_scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [
            {**hit, 'rerank_score': float(score), 'rank': rank}
            for rank, (hit, score) in enumerate(ordered[:rerank_k], start=1)
        ]

    def retrieve_single(self, query: str, k: int = 10,
                        rerank_k: int = 10) -> List[Dict[str, Any]]:
        """Retrieve for a single query.

        Args:
            query: The query string.
            k: Number of candidates requested from the index.
            rerank_k: Maximum number of results kept after reranking.

        Returns:
            The result list for ``query`` (see ``retrieve``).
        """
        results = self.retrieve([query], k, rerank_k)
        return results[0] if results else []

    def batch_retrieve(self, queries: List[str], batch_size: int = 32,
                       k: int = 10,
                       rerank_k: int = 10) -> List[List[Dict[str, Any]]]:
        """Retrieve for multiple queries in batches.

        Args:
            queries: Query strings.
            batch_size: Number of queries passed to ``retrieve`` per call.
            k: Number of candidates requested from the index per query.
            rerank_k: Maximum number of results kept per query after
                reranking.

        Returns:
            One result list per query, in the order of ``queries``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A negative step would make range() empty and silently drop
            # every query.
            raise ValueError("batch_size must be a positive integer")

        all_results = []
        for start in range(0, len(queries), batch_size):
            batch_queries = queries[start:start + batch_size]
            all_results.extend(self.retrieve(batch_queries, k, rerank_k))
        return all_results

    def get_retrieval_stats(self, queries: List[str], k: int = 10,
                            rerank_k: int = 10) -> Dict[str, Any]:
        """Get retrieval statistics.

        Args:
            queries: Query strings to run through ``retrieve``.
            k: Number of candidates requested from the index per query.
            rerank_k: Maximum number of results kept per query after
                reranking.

        Returns:
            A dict with ``num_queries``, the mean / standard deviation /
            minimum / maximum of the first-stage ``score`` over all returned
            results, and ``avg_results_per_query``. Every statistic is ``0.0``
            when there is nothing to aggregate.
        """
        results = self.retrieve(queries, k, rerank_k)
        scores = [hit['score'] for hits in results for hit in hits]
        return {
            'num_queries': len(queries),
            'avg_scores': float(np.mean(scores)) if scores else 0.0,
            'std_scores': float(np.std(scores)) if scores else 0.0,
            'min_scores': float(np.min(scores)) if scores else 0.0,
            'max_scores': float(np.max(scores)) if scores else 0.0,
            # Guarded like the fields above: np.mean([]) would be NaN.
            'avg_results_per_query': (
                float(np.mean([len(hits) for hits in results]))
                if results else 0.0
            ),
        }