Text Ranking
sentence-transformers
Safetensors
English
modernbert
ecommerce
e-commerce
retail
marketplace
shopping
amazon
ebay
alibaba
google
rakuten
bestbuy
walmart
flipkart
wayfair
shein
target
etsy
shopify
taobao
asos
carrefour
costco
overstock
pretraining
encoder
language-modeling
foundation-model
text-embeddings-inference
File size: 4,846 Bytes
c57c572 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
"""
RexReranker Inference Utilities.
This module provides helper functions for converting model logits to relevance scores.
The model outputs logits for 11 bins representing a distribution over [0, 1].
To get a relevance score, apply softmax and compute the expected value.
Example usage:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from utils import logits_to_relevance, logits_to_relevance_with_uncertainty
import torch
model = AutoModelForSequenceClassification.from_pretrained("path/to/model")
tokenizer = AutoTokenizer.from_pretrained("path/to/model")
inputs = tokenizer(
"Query: best laptop",
"Title: MacBook Pro\nDescription: Great laptop for developers",
return_tensors="pt",
truncation=True,
)
with torch.no_grad():
outputs = model(**inputs)
# Simple relevance score
relevance = logits_to_relevance(outputs.logits)
print(f"Relevance: {relevance.item():.3f}")
# With uncertainty estimates
result = logits_to_relevance_with_uncertainty(outputs.logits)
print(f"Relevance: {result['relevance'].item():.3f}")
print(f"Variance: {result['variance'].item():.4f}")
print(f"Entropy: {result['entropy'].item():.3f}")
"""
from typing import Dict, List, Optional

import torch
# Configuration
NUM_BINS = 11  # number of discrete relevance bins the classification head predicts over
# Midpoint value of each bin: 0.0, 0.1, ..., 1.0 — used to take the expected
# value of the predicted bin distribution.
BIN_CENTERS = torch.linspace(0.0, 1.0, NUM_BINS)
def logits_to_relevance(logits: torch.Tensor) -> torch.Tensor:
    """
    Collapse bin logits into a single relevance score per example.

    Args:
        logits: Raw model outputs of shape [B, 11]

    Returns:
        relevance: Tensor of shape [B] with expected relevance in [0, 1]
    """
    # Softmax over the bin axis yields a categorical distribution per example.
    distribution = logits.softmax(dim=-1)
    # Expected value under that distribution: sum_i p_i * center_i.
    centers = BIN_CENTERS.to(logits.device).view(1, -1)
    return (distribution * centers).sum(dim=-1)
def logits_to_relevance_with_uncertainty(logits: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Collapse bin logits into relevance scores plus uncertainty estimates.

    Args:
        logits: Raw model outputs of shape [B, 11]

    Returns:
        dict with:
            - relevance: [B] expected relevance scores in [0, 1]
            - variance: [B] variance of the bin distribution (higher = more uncertain)
            - entropy: [B] Shannon entropy of the bin distribution (higher = more uncertain)
            - probs: [B, 11] full probability distribution over bins
    """
    distribution = logits.softmax(dim=-1)
    centers = BIN_CENTERS.to(logits.device).view(1, -1)

    # Mean of the distribution over bin centers.
    expected = (distribution * centers).sum(dim=-1)
    # Second central moment around the expected value (broadcast [1,11] - [B,1]).
    deviation = centers - expected.unsqueeze(-1)
    variance = (distribution * deviation.pow(2)).sum(dim=-1)
    # Clamp before log to avoid -inf on (near-)zero probabilities.
    entropy = -(distribution * distribution.clamp(min=1e-9).log()).sum(dim=-1)

    return {
        "relevance": expected,
        "variance": variance,
        "entropy": entropy,
        "probs": distribution,
    }
def batch_rerank(
    model,
    tokenizer,
    query: str,
    documents: List[Dict],
    max_length: int = 2048,
    batch_size: int = 32,
    device: Optional[str] = None,
) -> List[Dict]:
    """
    Rerank a list of documents for a given query.

    Args:
        model: The RexReranker model
        tokenizer: The tokenizer
        query: The search query
        documents: List of dicts with 'title' and 'description' keys
            (missing keys fall back to empty strings)
        max_length: Maximum sequence length
        batch_size: Batch size for inference
        device: Device to use (default: auto-detect CUDA, else CPU)

    Returns:
        List of dicts, sorted by relevance (descending), containing the
        original document fields plus 'relevance', 'variance', 'entropy'
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    # The query side of every pair is identical; format it once, outside the loop.
    query_text = f"Query: {query}"

    results = []
    for start in range(0, len(documents), batch_size):
        batch_docs = documents[start:start + batch_size]
        # Pair the (repeated) query text with each formatted document text.
        texts_a = [query_text] * len(batch_docs)
        texts_b = [
            f"Title: {doc.get('title', '')}\nDescription: {doc.get('description', '')}"
            for doc in batch_docs
        ]
        inputs = tokenizer(
            texts_a,
            texts_b,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        batch_results = logits_to_relevance_with_uncertainty(outputs.logits)
        for j, doc in enumerate(batch_docs):
            results.append({
                **doc,
                "relevance": batch_results["relevance"][j].item(),
                "variance": batch_results["variance"][j].item(),
                "entropy": batch_results["entropy"][j].item(),
            })

    # Highest relevance first.
    results.sort(key=lambda item: item["relevance"], reverse=True)
    return results
|