Text Ranking
sentence-transformers
Safetensors
English
modernbert
ecommerce
e-commerce
retail
marketplace
shopping
amazon
ebay
alibaba
google
rakuten
bestbuy
walmart
flipkart
wayfair
shein
target
etsy
shopify
taobao
asos
carrefour
costco
overstock
pretraining
encoder
language-modeling
foundation-model
text-embeddings-inference
| """RexReranker Model for HuggingFace. | |
| Compatible with: | |
| - Transformers: AutoModel.from_pretrained(..., trust_remote_code=True) | |
| - Sentence Transformers: CrossEncoder(..., trust_remote_code=True) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, List, Union | |
| from dataclasses import dataclass | |
| from transformers import PretrainedConfig, PreTrainedModel, AutoModel | |
| from transformers.modeling_outputs import SequenceClassifierOutput | |
| class RexRerankerOutput(SequenceClassifierOutput): | |
| """Output class for RexReranker with additional distributional information.""" | |
| loss: Optional[torch.Tensor] = None | |
| logits: torch.Tensor = None # Single relevance score [B, 1] for CrossEncoder compatibility | |
| distribution_logits: torch.Tensor = None # Full distribution [B, num_bins] | |
| relevance: torch.Tensor = None # Convenience: same as logits.squeeze(-1) | |
| variance: torch.Tensor = None # Prediction variance | |
| entropy: torch.Tensor = None # Distribution entropy | |
| class RexRerankerConfig(PretrainedConfig): | |
| """Configuration for RexReranker model.""" | |
| model_type = "rex_reranker" | |
| def __init__( | |
| self, | |
| backbone_name: str = "thebajajra/RexBERT-mini", | |
| num_bins: int = 11, | |
| dropout: float = 0.0, | |
| pooling_strategy: str = "mean", | |
| hidden_size: int = None, | |
| num_labels: int = 1, # CrossEncoder compatibility | |
| transitions: List[float] = None, | |
| sigma_min: float = 0.04, | |
| sigma_max: float = 0.12, | |
| sigma_delta: float = 0.08, | |
| **kwargs, | |
| ): | |
| super().__init__(**kwargs) | |
| self.backbone_name = backbone_name | |
| self.num_bins = num_bins | |
| self.dropout = dropout | |
| self.pooling_strategy = pooling_strategy | |
| self.hidden_size = hidden_size | |
| self.num_labels = num_labels | |
| self.transitions = transitions or [0.2, 0.5, 0.8] | |
| self.sigma_min = sigma_min | |
| self.sigma_max = sigma_max | |
| self.sigma_delta = sigma_delta | |
| class RexRerankerModel(PreTrainedModel): | |
| """ | |
| RexBERT-based distributional reranker. | |
| Predicts a categorical distribution over K bins in [0, 1] representing | |
| relevance scores. The output logits contain a single relevance score | |
| for CrossEncoder compatibility, while the full distribution is available | |
| via distribution_logits or predict_with_uncertainty(). | |
| Compatible with: | |
| - sentence_transformers.CrossEncoder | |
| - transformers.AutoModelForSequenceClassification | |
| """ | |
| config_class = RexRerankerConfig | |
| base_model_prefix = "rex_reranker" | |
| supports_gradient_checkpointing = True | |
| def __init__(self, config: RexRerankerConfig): | |
| super().__init__(config) | |
| assert config.pooling_strategy in ("cls", "mean") | |
| self.pooling_strategy = config.pooling_strategy | |
| self.num_bins = config.num_bins | |
| self.backbone = AutoModel.from_pretrained( | |
| config.backbone_name, | |
| trust_remote_code=True, | |
| ) | |
| if hasattr(self.backbone, "config") and hasattr(self.backbone.config, "use_cache"): | |
| self.backbone.config.use_cache = False | |
| hidden_size = config.hidden_size or getattr(self.backbone.config, "hidden_size", None) | |
| if hidden_size is None: | |
| raise ValueError("Could not infer hidden_size.") | |
| self.dropout = nn.Dropout(config.dropout) | |
| self.score_head = nn.Linear(hidden_size, config.num_bins) | |
| self.register_buffer( | |
| "bin_centers", | |
| torch.linspace(0.0, 1.0, config.num_bins), | |
| persistent=False, | |
| ) | |
| self.post_init() | |
| def _init_weights(self, module): | |
| if isinstance(module, nn.Linear): | |
| module.weight.data.normal_(mean=0.0, std=0.02) | |
| if module.bias is not None: | |
| module.bias.data.zero_() | |
| def forward( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor, | |
| labels: Optional[torch.Tensor] = None, | |
| return_dict: bool = True, | |
| output_distribution: bool = False, | |
| **kwargs, # Accept extra kwargs for CrossEncoder compatibility | |
| ) -> Union[RexRerankerOutput, tuple]: | |
| """ | |
| Forward pass. | |
| Args: | |
| input_ids: Token IDs [B, T] | |
| attention_mask: Attention mask [B, T] | |
| labels: Optional relevance labels [B] | |
| return_dict: Whether to return a dataclass | |
| output_distribution: If True, include full distribution info in output | |
| Returns: | |
| RexRerankerOutput with: | |
| - logits: [B, 1] single relevance score (CrossEncoder compatible) | |
| - distribution_logits: [B, num_bins] full distribution (if output_distribution=True) | |
| - relevance, variance, entropy: convenience fields (if output_distribution=True) | |
| """ | |
| out = self.backbone( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| return_dict=True, | |
| ) | |
| last_hidden = out.last_hidden_state | |
| if self.pooling_strategy == "cls": | |
| pooled = last_hidden[:, 0, :] | |
| else: | |
| mask = attention_mask.unsqueeze(-1).float() | |
| summed = (last_hidden * mask).sum(dim=1) | |
| lengths = mask.sum(dim=1).clamp(min=1e-9) | |
| pooled = summed / lengths | |
| # Get distribution logits | |
| dist_logits = self.score_head(self.dropout(pooled)) # [B, num_bins] | |
| # Convert to single relevance score (expected value) | |
| probs = F.softmax(dist_logits, dim=-1) | |
| relevance = (probs * self.bin_centers.view(1, -1)).sum(dim=-1) # [B] | |
| # Output single score as logits for CrossEncoder compatibility [B, 1] | |
| logits = relevance.unsqueeze(-1) | |
| loss = None | |
| if labels is not None: | |
| loss = F.mse_loss(relevance, labels.float()) | |
| if not return_dict: | |
| output = (logits,) | |
| return ((loss,) + output) if loss is not None else output | |
| # Compute additional stats if requested | |
| variance = None | |
| entropy = None | |
| if output_distribution: | |
| variance = (probs * (self.bin_centers.view(1, -1) - relevance.unsqueeze(-1)) ** 2).sum(dim=-1) | |
| entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1) | |
| return RexRerankerOutput( | |
| loss=loss, | |
| logits=logits, | |
| distribution_logits=dist_logits if output_distribution else None, | |
| relevance=relevance, | |
| variance=variance, | |
| entropy=entropy, | |
| ) | |
| def predict_relevance( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """Get relevance scores directly. Returns [B] tensor.""" | |
| outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask) | |
| return outputs.relevance | |
| def predict_with_uncertainty( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor, | |
| ) -> dict: | |
| """ | |
| Get relevance prediction with full uncertainty estimates. | |
| Returns: | |
| dict with: | |
| - relevance: [B] predicted relevance scores | |
| - variance: [B] prediction variance (higher = more uncertain) | |
| - entropy: [B] distribution entropy (higher = more uncertain) | |
| - probs: [B, num_bins] full probability distribution | |
| - distribution_logits: [B, num_bins] raw logits | |
| """ | |
| outputs = self.forward( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| output_distribution=True, | |
| ) | |
| probs = F.softmax(outputs.distribution_logits, dim=-1) | |
| return { | |
| "relevance": outputs.relevance, | |
| "variance": outputs.variance, | |
| "entropy": outputs.entropy, | |
| "probs": probs, | |
| "distribution_logits": outputs.distribution_logits, | |
| } | |
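For reference, here is a minimal scoring sketch through plain `transformers`, following the compatibility note in the module docstring. The repo id `thebajajra/RexReranker` is a placeholder (only the backbone default `thebajajra/RexBERT-mini` appears above); substitute the actual checkpoint path, and note this assumes the checkpoint's `config.json` maps `AutoModel` to `RexRerankerModel` via `auto_map`.

```python
import torch
from transformers import AutoTokenizer, AutoModel

repo = "thebajajra/RexReranker"  # placeholder repo id -- substitute the real checkpoint

tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()

query = "wireless noise cancelling headphones"
docs = [
    "Sony WH-1000XM5 Wireless Noise Cancelling Headphones",
    "USB-C charging cable, 6 ft",
]

# Cross-encoder input: each (query, document) pair is encoded as one sequence.
batch = tokenizer(
    [query] * len(docs),
    docs,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    out = model(**batch)

print(out.logits.squeeze(-1))  # expected relevance in [0, 1], one score per document
```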
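Continuing from the same objects, a sketch of uncertainty-aware scoring via `predict_with_uncertainty`. Pairs with high variance or entropy are ones where the 11-bin distribution is spread out, so the point estimate is less trustworthy and can be flagged for fallback handling.

```python
with torch.no_grad():
    stats = model.predict_with_uncertainty(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )

for doc, score, var, ent in zip(
    docs,
    stats["relevance"].tolist(),
    stats["variance"].tolist(),
    stats["entropy"].tolist(),
):
    # A spread-out bin distribution yields high variance/entropy: the model
    # is signalling that its relevance estimate for this pair is less reliable.
    print(f"relevance={score:.3f}  variance={var:.4f}  entropy={ent:.3f}  {doc}")
```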
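And the equivalent through `sentence_transformers.CrossEncoder`, again per the module docstring (the repo id is the same placeholder). One caveat worth hedging: the model's single logit is already an expected relevance in [0, 1], and some sentence-transformers versions apply a default sigmoid activation to single-label cross-encoders, so check your version's activation handling before interpreting the scores.

```python
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("thebajajra/RexReranker", trust_remote_code=True)  # placeholder repo id

scores = reranker.predict([
    ("wireless noise cancelling headphones", "Sony WH-1000XM5 Wireless Headphones"),
    ("wireless noise cancelling headphones", "USB-C charging cable, 6 ft"),
])
print(scores)
```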