r"""
RexReranker Inference Utilities.

This module provides helper functions for converting model logits to
relevance scores.

The model outputs logits for 11 bins representing a distribution over [0, 1].
To get a relevance score, apply softmax and compute the expected value.

Example usage:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from utils import logits_to_relevance, logits_to_relevance_with_uncertainty
    import torch

    model = AutoModelForSequenceClassification.from_pretrained("path/to/model")
    tokenizer = AutoTokenizer.from_pretrained("path/to/model")

    inputs = tokenizer(
        "Query: best laptop",
        "Title: MacBook Pro\nDescription: Great laptop for developers",
        return_tensors="pt",
        truncation=True,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Simple relevance score
    relevance = logits_to_relevance(outputs.logits)
    print(f"Relevance: {relevance.item():.3f}")

    # With uncertainty estimates
    result = logits_to_relevance_with_uncertainty(outputs.logits)
    print(f"Relevance: {result['relevance'].item():.3f}")
    print(f"Variance: {result['variance'].item():.4f}")
    print(f"Entropy: {result['entropy'].item():.3f}")
"""

from typing import Any, Dict, List, Optional, Tuple

import torch

# Configuration
NUM_BINS = 11
BIN_CENTERS = torch.linspace(0.0, 1.0, NUM_BINS)


def _bin_distribution(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(probs, bin_centers)`` for ``logits``, ready for broadcasting.

    Raises:
        ValueError: if the last dimension of ``logits`` does not match the
            number of bin centers.
    """
    expected_bins = BIN_CENTERS.numel()
    if logits.dim() == 0 or logits.shape[-1] != expected_bins:
        # Without this check a last dimension of 1 would silently broadcast
        # against the bin centers and yield meaningless scores.
        raise ValueError(
            f"Expected logits with last dimension {expected_bins}, "
            f"got shape {tuple(logits.shape)}"
        )

    # Do the math in float32 unless the caller already supplied float32/64.
    # In float16 the 1e-9 entropy clamp underflows to 0, so log(0) * 0 = NaN.
    if logits.dtype not in (torch.float32, torch.float64):
        logits = logits.float()

    probs = torch.softmax(logits, dim=-1)
    bin_centers = BIN_CENTERS.to(device=probs.device, dtype=probs.dtype)
    return probs, bin_centers


def _format_document(doc: Dict[str, Any]) -> str:
    """Build the document-side model input; missing or None fields become ''."""
    title = doc.get("title")
    description = doc.get("description")
    return (
        f"Title: {'' if title is None else title}\n"
        f"Description: {'' if description is None else description}"
    )


def logits_to_relevance(logits: torch.Tensor) -> torch.Tensor:
    """
    Convert model logits to relevance scores.

    The score is the expected value of the bin centers under the softmax
    distribution over the bins.

    Args:
        logits: Model output logits [B, 11]

    Returns:
        relevance: Relevance scores [B] in range [0, 1]

    Raises:
        ValueError: if the last dimension of ``logits`` is not the bin count.
    """
    probs, bin_centers = _bin_distribution(logits)
    return (probs * bin_centers).sum(dim=-1)


def logits_to_relevance_with_uncertainty(logits: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Convert model logits to relevance scores with uncertainty estimates.

    Half-precision (float16 / bfloat16) logits are upcast to float32 before
    the softmax so that the entropy stays finite; in that case the returned
    ``probs`` are float32 as well.

    Args:
        logits: Model output logits [B, 11]

    Returns:
        dict with:
            - relevance: [B] predicted relevance scores in [0, 1]
            - variance: [B] prediction variance (higher = more uncertain)
            - entropy: [B] distribution entropy (higher = more uncertain)
            - probs: [B, 11] full probability distribution over bins

    Raises:
        ValueError: if the last dimension of ``logits`` is not the bin count.
    """
    probs, bin_centers = _bin_distribution(logits)

    relevance = (probs * bin_centers).sum(dim=-1)
    deviation = bin_centers - relevance.unsqueeze(-1)
    variance = (probs * deviation ** 2).sum(dim=-1)
    # Clamp before the log so that zero-probability bins contribute 0, not NaN.
    entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1)

    return {
        "relevance": relevance,
        "variance": variance,
        "entropy": entropy,
        "probs": probs,
    }


def batch_rerank(
    model,
    tokenizer,
    query: str,
    documents: List[Dict[str, Any]],
    max_length: int = 2048,
    batch_size: int = 32,
    device: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank a list of documents for a given query.

    Note that the model is moved to ``device`` and switched to eval mode as
    a side effect of this call.

    Args:
        model: The RexReranker model
        tokenizer: The tokenizer
        query: The search query
        documents: List of dicts with 'title' and 'description' keys
            (missing or None values are treated as empty strings)
        max_length: Maximum sequence length
        batch_size: Batch size for inference (must be >= 1)
        device: Device to use (default: auto-detect)

    Returns:
        List of dicts with original document info plus 'relevance',
        'variance', 'entropy', sorted by relevance in descending order

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    # A negative step would make range() empty and silently return [].
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)
    model.eval()

    query_text = f"Query: {query}"
    results = []

    for start in range(0, len(documents), batch_size):
        batch_docs = documents[start:start + batch_size]

        # Format inputs
        inputs = tokenizer(
            [query_text] * len(batch_docs),
            [_format_document(doc) for doc in batch_docs],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        scores = logits_to_relevance_with_uncertainty(outputs.logits)

        # One host transfer per tensor instead of one .item() per element.
        relevances = scores["relevance"].tolist()
        variances = scores["variance"].tolist()
        entropies = scores["entropy"].tolist()

        for doc, relevance, variance, entropy in zip(
            batch_docs, relevances, variances, entropies
        ):
            results.append({
                **doc,
                "relevance": relevance,
                "variance": variance,
                "entropy": entropy,
            })

    # Sort by relevance (descending)
    results.sort(key=lambda item: item["relevance"], reverse=True)

    return results