# RexReranker-base / utils.py
# thebajajra's picture
# Upload folder using huggingface_hub
# 461601f verified
"""
RexReranker Inference Utilities.
This module provides helper functions for converting model logits to relevance scores.
The model outputs logits for 11 bins representing a distribution over [0, 1].
To get a relevance score, apply softmax and compute the expected value.
Example usage:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from utils import logits_to_relevance, logits_to_relevance_with_uncertainty
    import torch

    model = AutoModelForSequenceClassification.from_pretrained("path/to/model")
    tokenizer = AutoTokenizer.from_pretrained("path/to/model")

    inputs = tokenizer(
        "Query: best laptop",
        "Title: MacBook Pro\nDescription: Great laptop for developers",
        return_tensors="pt",
        truncation=True,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Simple relevance score
    relevance = logits_to_relevance(outputs.logits)
    print(f"Relevance: {relevance.item():.3f}")

    # With uncertainty estimates
    result = logits_to_relevance_with_uncertainty(outputs.logits)
    print(f"Relevance: {result['relevance'].item():.3f}")
    print(f"Variance: {result['variance'].item():.4f}")
    print(f"Entropy: {result['entropy'].item():.3f}")
"""
from typing import Dict, Optional

import torch
# Configuration: the model head emits one logit per bin; bin centers are
# evenly spaced over [0, 1] (0.0, 0.1, ..., 1.0).
NUM_BINS = 11
BIN_CENTERS = torch.linspace(0.0, 1.0, NUM_BINS)


def logits_to_relevance(logits: torch.Tensor) -> torch.Tensor:
    """
    Convert model logits to relevance scores.

    Softmax over the bin dimension yields a categorical distribution over
    the bin centers; the relevance score is that distribution's expectation.

    Args:
        logits: Model output logits [B, 11]

    Returns:
        relevance: Relevance scores [B] in range [0, 1]
    """
    distribution = logits.softmax(dim=-1)
    centers = BIN_CENTERS.to(logits.device)
    # Expected value of the bin-center distribution, reduced over bins.
    return torch.sum(distribution * centers, dim=-1)
def logits_to_relevance_with_uncertainty(logits: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Convert model logits to relevance scores with uncertainty estimates.

    The bin centers are derived from the last dimension of ``logits``
    (evenly spaced over [0, 1]), so the function works for any bin count
    and is numerically identical to the 11-bin default for [B, 11] input.

    Args:
        logits: Model output logits [B, 11]

    Returns:
        dict with:
            - relevance: [B] predicted relevance scores in [0, 1]
            - variance: [B] prediction variance (higher = more uncertain)
            - entropy: [B] distribution entropy (higher = more uncertain)
            - probs: [B, 11] full probability distribution over bins
    """
    probs = torch.softmax(logits, dim=-1)
    # Derive centers from the logit width instead of the module-level
    # constant so arbitrary bin counts are supported.
    bin_centers = torch.linspace(
        0.0, 1.0, logits.size(-1), device=logits.device, dtype=probs.dtype
    )
    # Expectation of the bin-center distribution.
    relevance = (probs * bin_centers).sum(dim=-1)
    # Spread around the expectation (second central moment).
    variance = (probs * (bin_centers - relevance.unsqueeze(-1)) ** 2).sum(dim=-1)
    # Shannon entropy; clamp avoids log(0) on one-hot distributions.
    entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1)
    return {
        "relevance": relevance,
        "variance": variance,
        "entropy": entropy,
        "probs": probs,
    }
def batch_rerank(
    model,
    tokenizer,
    query: str,
    documents: list,
    max_length: int = 2048,
    batch_size: int = 32,
    device: Optional[str] = None,
) -> list:
    """
    Rerank a list of documents for a given query.

    Args:
        model: The RexReranker model (sequence-classification model whose
            output exposes a ``.logits`` attribute)
        tokenizer: The tokenizer paired with the model
        query: The search query
        documents: List of dicts with 'title' and 'description' keys
        max_length: Maximum sequence length (pairs are truncated to fit)
        batch_size: Batch size for inference
        device: Device to use (default: auto-detect CUDA, else CPU)

    Returns:
        List of dicts with original document info plus 'relevance',
        'variance', 'entropy', sorted by relevance (descending)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    # The query side is identical for every pair; format it once up front.
    query_text = f"Query: {query}"

    results = []
    for start in range(0, len(documents), batch_size):
        batch_docs = documents[start:start + batch_size]
        texts_a = [query_text] * len(batch_docs)
        texts_b = [
            f"Title: {doc.get('title', '')}\nDescription: {doc.get('description', '')}"
            for doc in batch_docs
        ]
        inputs = tokenizer(
            texts_a,
            texts_b,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        scores = logits_to_relevance_with_uncertainty(outputs.logits)
        for j, doc in enumerate(batch_docs):
            results.append({
                **doc,
                "relevance": scores["relevance"][j].item(),
                "variance": scores["variance"][j].item(),
                "entropy": scores["entropy"][j].item(),
            })
    # Sort by relevance (descending). Python's sort is stable, so ties
    # preserve the original document order.
    results.sort(key=lambda r: r["relevance"], reverse=True)
    return results