Text Ranking
sentence-transformers
Safetensors
English
modernbert
ecommerce
e-commerce
retail
marketplace
shopping
amazon
ebay
alibaba
google
rakuten
bestbuy
walmart
flipkart
wayfair
shein
target
etsy
shopify
taobao
asos
carrefour
costco
overstock
pretraining
encoder
language-modeling
foundation-model
text-embeddings-inference
"""
RexReranker Inference Utilities.

This module provides helper functions for converting model logits to relevance scores.

The model outputs logits for 11 bins representing a distribution over [0, 1].
To get a relevance score, apply softmax and compute the expected value.

Example usage:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from utils import logits_to_relevance, logits_to_relevance_with_uncertainty
    import torch

    model = AutoModelForSequenceClassification.from_pretrained("path/to/model")
    tokenizer = AutoTokenizer.from_pretrained("path/to/model")

    inputs = tokenizer(
        "Query: best laptop",
        "Title: MacBook Pro\nDescription: Great laptop for developers",
        return_tensors="pt",
        truncation=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Simple relevance score
    relevance = logits_to_relevance(outputs.logits)
    print(f"Relevance: {relevance.item():.3f}")

    # With uncertainty estimates
    result = logits_to_relevance_with_uncertainty(outputs.logits)
    print(f"Relevance: {result['relevance'].item():.3f}")
    print(f"Variance: {result['variance'].item():.4f}")
    print(f"Entropy: {result['entropy'].item():.3f}")
"""
from typing import Dict, Optional

import torch
# Configuration: the model's classification head predicts a categorical
# distribution over NUM_BINS evenly spaced relevance values in [0, 1].
NUM_BINS = 11
BIN_CENTERS = torch.linspace(0.0, 1.0, NUM_BINS)
def logits_to_relevance(logits: torch.Tensor) -> torch.Tensor:
    """
    Convert model logits to relevance scores.

    The score is the expected value of the bin centers under the softmax
    distribution implied by the logits.

    Args:
        logits: Model output logits [B, num_bins] (num_bins is 11 for the
            default checkpoint).

    Returns:
        relevance: Relevance scores [B] in range [0, 1]
    """
    probs = torch.softmax(logits, dim=-1)
    # Derive the bin centers from the logit width so the function also works
    # for heads trained with a different number of bins; for the default
    # 11-bin head this equals the module-level BIN_CENTERS. Matching dtype
    # avoids an implicit promotion when logits are not float32.
    bin_centers = torch.linspace(
        0.0, 1.0, logits.size(-1), device=logits.device, dtype=probs.dtype
    )
    # [B, num_bins] * [num_bins] broadcasts; reduce over the bin axis.
    return (probs * bin_centers).sum(dim=-1)
def logits_to_relevance_with_uncertainty(logits: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Convert model logits to relevance scores with uncertainty estimates.

    Args:
        logits: Model output logits [B, num_bins] (num_bins is 11 for the
            default checkpoint).

    Returns:
        dict with:
            - relevance: [B] predicted relevance scores in [0, 1]
            - variance: [B] prediction variance (higher = more uncertain)
            - entropy: [B] distribution entropy (higher = more uncertain)
            - probs: [B, num_bins] full probability distribution over bins
    """
    probs = torch.softmax(logits, dim=-1)
    # Derive the bin centers from the logit width (see logits_to_relevance);
    # identical to BIN_CENTERS for the default 11-bin head.
    bin_centers = torch.linspace(
        0.0, 1.0, logits.size(-1), device=logits.device, dtype=probs.dtype
    )
    # Expected value of the bin distribution.
    relevance = (probs * bin_centers).sum(dim=-1)
    # Second central moment of the same distribution.
    variance = (probs * (bin_centers - relevance.unsqueeze(-1)) ** 2).sum(dim=-1)
    # Shannon entropy (nats); clamp avoids log(0) on fully confident bins.
    entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1)
    return {
        "relevance": relevance,
        "variance": variance,
        "entropy": entropy,
        "probs": probs,
    }
def batch_rerank(
    model,
    tokenizer,
    query: str,
    documents: list,
    max_length: int = 2048,
    batch_size: int = 32,
    device: Optional[str] = None,
) -> list:
    """
    Rerank a list of documents for a given query.

    Args:
        model: The RexReranker model (a sequence-classification model whose
            logits are consumed by ``logits_to_relevance_with_uncertainty``).
        tokenizer: The tokenizer matching the model.
        query: The search query.
        documents: List of dicts with 'title' and 'description' keys.
            Missing keys are treated as empty strings.
        max_length: Maximum sequence length for truncation.
        batch_size: Batch size for inference.
        device: Device to use (default: CUDA when available, else CPU).

    Returns:
        List of dicts with the original document fields plus 'relevance',
        'variance' and 'entropy', sorted by relevance (descending).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)
    model.eval()

    results = []
    for start in range(0, len(documents), batch_size):
        batch_docs = documents[start:start + batch_size]

        # Query/document pair formatting — must match the format the model
        # was trained with (see the module docstring example).
        texts_a = [f"Query: {query}" for _ in batch_docs]
        texts_b = [
            f"Title: {doc.get('title', '')}\nDescription: {doc.get('description', '')}"
            for doc in batch_docs
        ]

        inputs = tokenizer(
            texts_a,
            texts_b,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        scores = logits_to_relevance_with_uncertainty(outputs.logits)
        for j, doc in enumerate(batch_docs):
            results.append({
                **doc,
                "relevance": scores["relevance"][j].item(),
                "variance": scores["variance"][j].item(),
                "entropy": scores["entropy"][j].item(),
            })

    # Sort by relevance, best first; list.sort is stable, so ties keep
    # their original input order.
    results.sort(key=lambda item: item["relevance"], reverse=True)
    return results