Spaces:
Sleeping
Sleeping
| from sentence_transformers import CrossEncoder | |
| from typing import List | |
| import numpy as np | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class Reranker:
    """Score and rank passages against a query with a cross-encoder model.

    The wrapped model is a ``sentence_transformers.CrossEncoder``; every
    scoring method funnels through :meth:`rerank`, which calls
    ``self.model.predict`` on ``(query, passage)`` pairs.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: str = "cuda"):
        """Load the cross-encoder.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``.
            device: Device string passed to ``CrossEncoder`` (defaults to
                ``"cuda"``; pass ``"cpu"`` explicitly on hosts without a GPU).
        """
        self.model_name = model_name
        self.device = device
        self.model = CrossEncoder(model_name, device=device)
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Loaded reranker model: %s", model_name)

    def rerank(self, query: str, passages: List[str], batch_size: int = 32) -> List[float]:
        """Return one relevance score per passage, in input order.

        Args:
            query: The query text.
            passages: Candidate passages to score against ``query``.
            batch_size: Forwarded to ``self.model.predict``.

        Returns:
            ``[]`` when ``passages`` is empty (the model is not called);
            otherwise the model's scores converted with ``.tolist()``.
        """
        if not passages:
            return []

        # The cross-encoder scores (query, passage) pairs jointly.
        pairs = [(query, passage) for passage in passages]

        # NOTE(review): relies on predict() returning an object with
        # .tolist() (a NumPy array with CrossEncoder's defaults) - confirm
        # if predict options are ever changed.
        scores = self.model.predict(pairs, batch_size=batch_size)
        return scores.tolist()

    def rerank_batch(self, queries: List[str], passages_list: List[List[str]],
                     batch_size: int = 32) -> List[List[float]]:
        """Score passages for several queries.

        ``passages_list[i]`` holds the candidates for ``queries[i]``. Each
        query is scored with its own call to :meth:`rerank`.

        Args:
            queries: Query texts.
            passages_list: One list of passages per query.
            batch_size: Forwarded to :meth:`rerank` for every query.

        Returns:
            A list aligned with ``queries``; element ``i`` is the score list
            for ``passages_list[i]``.

        Raises:
            ValueError: If ``queries`` and ``passages_list`` differ in length
                (previously the extra entries were silently dropped by
                ``zip``).
        """
        if len(queries) != len(passages_list):
            raise ValueError(
                f"queries and passages_list must have the same length "
                f"(got {len(queries)} and {len(passages_list)})"
            )

        return [
            self.rerank(query, passages, batch_size)
            for query, passages in zip(queries, passages_list)
        ]

    def get_top_k(self, query: str, passages: List[str], k: int = 5,
                  batch_size: int = 32) -> List[tuple]:
        """Return the ``k`` highest-scoring passages with their scores.

        Args:
            query: The query text.
            passages: Candidate passages.
            k: Maximum number of results. Non-positive values yield ``[]``
                (a negative ``k`` used to slice from the end of the ranking).
            batch_size: Forwarded to :meth:`rerank`.

        Returns:
            Up to ``k`` ``(passage, score)`` tuples sorted by descending
            score. ``sorted`` is stable, so passages with equal scores keep
            their input order.
        """
        if k <= 0:
            return []

        scores = self.rerank(query, passages, batch_size)

        # Higher score is treated as more relevant.
        ranked = sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)
        return ranked[:k]