# safe_rag/retriever/retriever.py
# SafeRAG retriever: embeds queries, searches a FAISS index, and optionally
# reranks the retrieved passages.
from typing import List, Dict, Any, Tuple
import numpy as np
from .embedder import Embedder
from .faiss_index import FAISSIndex
from .reranker import Reranker
import logging
logger = logging.getLogger(__name__)
class Retriever:
    """Dense retriever: embeds queries, searches a FAISS index, and
    optionally reranks the hits with a configured reranker."""

    def __init__(self, embedder: Embedder, index: FAISSIndex, reranker: Reranker = None):
        """
        Args:
            embedder: Encodes query strings into embedding vectors.
            index: FAISS index exposing `search`, `id_to_text`, `id_to_metadata`.
            reranker: Optional reranker; when None, raw index scores are final.
        """
        self.embedder = embedder
        self.index = index
        self.reranker = reranker

    def retrieve(self, queries: List[str], k: int = 20,
                 rerank_k: int = 10) -> List[List[Dict[str, Any]]]:
        """Retrieve (and optionally rerank) passages for each query.

        Args:
            queries: Query strings; an empty list short-circuits to [].
            k: Number of candidates fetched from the index per query.
            rerank_k: Max results kept per query after reranking
                (capped at k; ignored when no reranker is configured).

        Returns:
            One list of result dicts per query, each with keys
            'text', 'score', 'rank', 'metadata', 'id' — plus
            'rerank_score' when a reranker is configured.
        """
        if not queries:
            return []

        # Encode queries and search the index in one batch.
        query_embeddings = self.embedder.encode_queries(queries)
        scores, indices = self.index.search(query_embeddings, k)

        # Format raw hits, dropping FAISS's -1 padding for missing results.
        results = []
        for i in range(len(queries)):
            query_results = []
            rank = 0
            for score, idx in zip(scores[i], indices[i]):
                if idx == -1:  # invalid / padded hit
                    continue
                # Bug fix: ranks are now dense over *valid* hits; the original
                # used the raw enumeration index, leaving gaps after skips.
                rank += 1
                query_results.append({
                    'text': self.index.id_to_text.get(idx, ""),
                    'score': float(score),
                    'rank': rank,
                    'metadata': self.index.id_to_metadata.get(idx, {}),
                    'id': idx
                })
            results.append(query_results)

        # Bug fix: the original required rerank_k < k, which silently disabled
        # the reranker for retrieve_single/batch_retrieve defaults (k=10,
        # rerank_k=10). Now rerank whenever a reranker exists, keeping at most
        # min(rerank_k, k) results.
        if self.reranker is not None:
            keep = min(rerank_k, k)
            reranked_results = []
            for query, query_results in zip(queries, results):
                if not query_results:
                    # Nothing to rerank (all hits were invalid); avoid calling
                    # the reranker with an empty passage list.
                    reranked_results.append([])
                    continue
                passages = [r['text'] for r in query_results]
                rerank_scores = self.reranker.rerank(query, passages)
                reranked = sorted(
                    zip(query_results, rerank_scores),
                    key=lambda pair: pair[1],
                    reverse=True
                )
                reranked_results.append([
                    {**result, 'rerank_score': rscore, 'rank': pos + 1}
                    for pos, (result, rscore) in enumerate(reranked[:keep])
                ])
            results = reranked_results

        return results

    def retrieve_single(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
        """Retrieve for a single query; returns that query's result list."""
        results = self.retrieve([query], k)
        return results[0] if results else []

    def batch_retrieve(self, queries: List[str], batch_size: int = 32,
                       k: int = 10) -> List[List[Dict[str, Any]]]:
        """Retrieve for many queries in batches of `batch_size`.

        Results are concatenated in the original query order.
        """
        all_results = []
        for start in range(0, len(queries), batch_size):
            batch_queries = queries[start:start + batch_size]
            all_results.extend(self.retrieve(batch_queries, k))
        return all_results

    def get_retrieval_stats(self, queries: List[str], k: int = 10) -> Dict[str, Any]:
        """Summary statistics over retrieval scores for a batch of queries.

        Returns a dict with query count and mean/std/min/max of all scores
        plus the mean number of results per query (0 for empty input).
        """
        results = self.retrieve(queries, k)
        scores = [r['score'] for query_results in results for r in query_results]
        return {
            'num_queries': len(queries),
            'avg_scores': np.mean(scores) if scores else 0,
            'std_scores': np.std(scores) if scores else 0,
            'min_scores': np.min(scores) if scores else 0,
            'max_scores': np.max(scores) if scores else 0,
            # Bug fix: np.mean([]) is nan (with a RuntimeWarning) when there
            # are no queries; guard like the score stats above.
            'avg_results_per_query': np.mean([len(r) for r in results]) if results else 0
        }