"""RexReranker Model for HuggingFace.

Compatible with:
- Transformers: AutoModel.from_pretrained(..., trust_remote_code=True)
- Sentence Transformers: CrossEncoder(..., trust_remote_code=True)
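
Example (a minimal sketch; the Hub id below is a placeholder for wherever this
file is published):

    from sentence_transformers import CrossEncoder

    model = CrossEncoder("your-org/rex-reranker", trust_remote_code=True)
    scores = model.predict([("wireless earbuds", "Bluetooth 5.3 in-ear headphones")])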
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Union
from dataclasses import dataclass
from transformers import PretrainedConfig, PreTrainedModel, AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput


@dataclass
class RexRerankerOutput(SequenceClassifierOutput):
    """Output class for RexReranker with additional distributional information."""

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None  # Single relevance score [B, 1] for CrossEncoder compatibility
    distribution_logits: Optional[torch.Tensor] = None  # Full distribution [B, num_bins]
    relevance: Optional[torch.Tensor] = None  # Convenience: same as logits.squeeze(-1)
    variance: Optional[torch.Tensor] = None  # Prediction variance
    entropy: Optional[torch.Tensor] = None  # Distribution entropy


class RexRerankerConfig(PretrainedConfig):
    """Configuration for RexReranker model."""

    model_type = "rex_reranker"

    def __init__(
        self,
        backbone_name: str = "thebajajra/RexBERT-mini",
        num_bins: int = 11,
        dropout: float = 0.0,
        pooling_strategy: str = "mean",
        hidden_size: Optional[int] = None,
        num_labels: int = 1,  # CrossEncoder compatibility
        transitions: Optional[List[float]] = None,
sigma_min: float = 0.04,
sigma_max: float = 0.12,
sigma_delta: float = 0.08,
**kwargs,
):
super().__init__(**kwargs)
self.backbone_name = backbone_name
self.num_bins = num_bins
self.dropout = dropout
self.pooling_strategy = pooling_strategy
self.hidden_size = hidden_size
self.num_labels = num_labels
self.transitions = transitions or [0.2, 0.5, 0.8]
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.sigma_delta = sigma_delta
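        # Note: transitions and the sigma_* values are not referenced anywhere
        # in this file; they appear to be carried on the config for external
        # training code and are unused at inference time.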


class RexRerankerModel(PreTrainedModel):
    """
    RexBERT-based distributional reranker.

    Predicts a categorical distribution over K bins in [0, 1] representing
    relevance scores. The output logits contain a single relevance score
    for CrossEncoder compatibility, while the full distribution is available
    via distribution_logits or predict_with_uncertainty().

    Compatible with:
        - sentence_transformers.CrossEncoder
        - transformers.AutoModelForSequenceClassification
    """
config_class = RexRerankerConfig
base_model_prefix = "rex_reranker"
supports_gradient_checkpointing = True

    def __init__(self, config: RexRerankerConfig):
super().__init__(config)
assert config.pooling_strategy in ("cls", "mean")
self.pooling_strategy = config.pooling_strategy
self.num_bins = config.num_bins
self.backbone = AutoModel.from_pretrained(
config.backbone_name,
trust_remote_code=True,
)
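        # Caching past key/values only helps autoregressive decoding; for
        # one-shot encoder scoring it just wastes memory, so disable it when
        # the backbone exposes the flag.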
if hasattr(self.backbone, "config") and hasattr(self.backbone.config, "use_cache"):
self.backbone.config.use_cache = False
hidden_size = config.hidden_size or getattr(self.backbone.config, "hidden_size", None)
if hidden_size is None:
            raise ValueError("Could not infer hidden_size from the backbone config; set config.hidden_size explicitly.")
self.dropout = nn.Dropout(config.dropout)
self.score_head = nn.Linear(hidden_size, config.num_bins)
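        # Fixed grid of bin centers over [0, 1]; with the default num_bins=11
        # this is 0.0, 0.1, ..., 1.0. Registered as a non-persistent buffer so
        # it moves with the module across devices but is not saved in checkpoints.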
self.register_buffer(
"bin_centers",
torch.linspace(0.0, 1.0, config.num_bins),
persistent=False,
)
self.post_init()

    def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()

    def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
output_distribution: bool = False,
**kwargs, # Accept extra kwargs for CrossEncoder compatibility
) -> Union[RexRerankerOutput, tuple]:
"""
Forward pass.
Args:
input_ids: Token IDs [B, T]
attention_mask: Attention mask [B, T]
labels: Optional relevance labels [B]
return_dict: Whether to return a dataclass
output_distribution: If True, include full distribution info in output
Returns:
RexRerankerOutput with:
- logits: [B, 1] single relevance score (CrossEncoder compatible)
- distribution_logits: [B, num_bins] full distribution (if output_distribution=True)
- relevance, variance, entropy: convenience fields (if output_distribution=True)
"""
out = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
last_hidden = out.last_hidden_state
if self.pooling_strategy == "cls":
pooled = last_hidden[:, 0, :]
else:
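            # Mean pooling over real tokens only: padding positions are zeroed
            # by the mask and excluded from the length normalization.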
mask = attention_mask.unsqueeze(-1).float()
summed = (last_hidden * mask).sum(dim=1)
lengths = mask.sum(dim=1).clamp(min=1e-9)
pooled = summed / lengths
# Get distribution logits
dist_logits = self.score_head(self.dropout(pooled)) # [B, num_bins]
# Convert to single relevance score (expected value)
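        # relevance = E[score] = sum_k p_k * c_k, the mean of the categorical
        # distribution over bin centers c_k; bounded in [0, 1] by construction.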
probs = F.softmax(dist_logits, dim=-1)
relevance = (probs * self.bin_centers.view(1, -1)).sum(dim=-1) # [B]
# Output single score as logits for CrossEncoder compatibility [B, 1]
logits = relevance.unsqueeze(-1)
loss = None
if labels is not None:
loss = F.mse_loss(relevance, labels.float())
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
# Compute additional stats if requested
variance = None
entropy = None
if output_distribution:
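            # Var[score] = sum_k p_k * (c_k - E[score])^2; entropy = -sum_k p_k * log p_k.
            # Both grow as probability mass spreads across bins (lower confidence).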
variance = (probs * (self.bin_centers.view(1, -1) - relevance.unsqueeze(-1)) ** 2).sum(dim=-1)
entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1)
return RexRerankerOutput(
loss=loss,
logits=logits,
distribution_logits=dist_logits if output_distribution else None,
relevance=relevance,
variance=variance,
entropy=entropy,
)

    def predict_relevance(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Get relevance scores directly. Returns [B] tensor."""
outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask)
return outputs.relevance

    def predict_with_uncertainty(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> dict:
"""
Get relevance prediction with full uncertainty estimates.
Returns:
dict with:
- relevance: [B] predicted relevance scores
- variance: [B] prediction variance (higher = more uncertain)
- entropy: [B] distribution entropy (higher = more uncertain)
- probs: [B, num_bins] full probability distribution
- distribution_logits: [B, num_bins] raw logits
"""
outputs = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_distribution=True,
)
probs = F.softmax(outputs.distribution_logits, dim=-1)
return {
"relevance": outputs.relevance,
"variance": outputs.variance,
"entropy": outputs.entropy,
"probs": probs,
"distribution_logits": outputs.distribution_logits,
}
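

if __name__ == "__main__":
    # Minimal smoke-test sketch: builds the model from the default config, which
    # downloads the backbone named in RexRerankerConfig. With a freshly
    # initialized score head the outputs are not meaningful; load fine-tuned
    # weights via RexRerankerModel.from_pretrained(...) for real use. The
    # query/product strings below are illustrative only.
    from transformers import AutoTokenizer

    config = RexRerankerConfig()
    model = RexRerankerModel(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.backbone_name, trust_remote_code=True)
    batch = tokenizer(
        ["wireless earbuds"],
        ["Bluetooth 5.3 in-ear headphones with charging case"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        result = model.predict_with_uncertainty(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
    print("relevance:", result["relevance"])  # [B] expected scores in [0, 1]
    print("variance: ", result["variance"])   # higher = more uncertain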