# safe_rag/retriever/reranker.py
# Tairun Meng
# Initial commit: SafeRAG project ready for HF Spaces
# db06013
from sentence_transformers import CrossEncoder
from typing import List
import numpy as np
import logging
logger = logging.getLogger(__name__)
class Reranker:
    """Score passages against a query with a cross-encoder model.

    Wraps a ``CrossEncoder`` instance; every scoring method ultimately calls
    its ``predict`` on ``(query, passage)`` pairs.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        device: str = "cuda",
    ):
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier handed to ``CrossEncoder``.
            device: Device string handed to ``CrossEncoder``. A ``"cuda..."``
                request is downgraded to ``"cpu"`` (with a warning) when torch
                reports that CUDA is unavailable, so the default also works
                on CPU-only hosts.
        """
        if isinstance(device, str) and device.startswith("cuda"):
            # Local import: torch is needed only for this availability probe.
            import torch

            if not torch.cuda.is_available():
                logger.warning(
                    "CUDA device %r requested but CUDA is unavailable; "
                    "falling back to CPU",
                    device,
                )
                device = "cpu"

        self.model_name = model_name
        # Stores the device actually used, which may differ from the request.
        self.device = device
        self.model = CrossEncoder(model_name, device=device)
        logger.info("Loaded reranker model: %s", model_name)

    def rerank(self, query: str, passages: List[str], batch_size: int = 32) -> List[float]:
        """Score each passage against ``query``.

        Despite the name, passages are NOT reordered; use ``get_top_k`` for a
        sorted result.

        Args:
            query: Query text paired with every passage.
            passages: Passages to score.
            batch_size: Forwarded to the model's ``predict``.

        Returns:
            The model's scores as a plain Python list, or ``[]`` when
            ``passages`` is empty (the model is not called in that case).
        """
        if not passages:
            return []

        pairs = [(query, passage) for passage in passages]
        scores = self.model.predict(pairs, batch_size=batch_size)
        # asarray makes the conversion independent of the container type
        # predict() hands back (ndarray, list, ...).
        return np.asarray(scores).tolist()

    def rerank_batch(
        self,
        queries: List[str],
        passages_list: List[List[str]],
        batch_size: int = 32,
    ) -> List[List[float]]:
        """Score several queries, each against its own passage list.

        Args:
            queries: One query per entry of ``passages_list``.
            passages_list: ``passages_list[i]`` holds the passages for
                ``queries[i]``.
            batch_size: Forwarded to ``rerank`` for every query.

        Returns:
            One score list per query, in the order of ``queries``.

        Raises:
            ValueError: If ``queries`` and ``passages_list`` differ in length.
        """
        # zip() would silently drop the unmatched tail; fail loudly instead.
        if len(queries) != len(passages_list):
            raise ValueError(
                f"queries and passages_list must have the same length "
                f"(got {len(queries)} and {len(passages_list)})"
            )

        return [
            self.rerank(query, passages, batch_size)
            for query, passages in zip(queries, passages_list)
        ]

    def get_top_k(
        self,
        query: str,
        passages: List[str],
        k: int = 5,
        batch_size: int = 32,
    ) -> List[tuple]:
        """Return the ``k`` highest-scoring passages for ``query``.

        Args:
            query: Query text.
            passages: Candidate passages.
            k: Maximum number of results; ``0`` yields an empty list.
            batch_size: Forwarded to ``rerank``.

        Returns:
            ``(passage, score)`` tuples sorted by descending score, at most
            ``k`` of them. Ties keep their original relative order because
            ``sorted`` is stable.

        Raises:
            ValueError: If ``k`` is negative.
        """
        # A negative k would make the slice drop items from the END of the
        # ranking, which is never what a "top-k" caller wants.
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")

        scores = self.rerank(query, passages, batch_size)
        ranked = sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)
        return ranked[:k]