"""RexReranker Model for HuggingFace. Compatible with: - Transformers: AutoModel.from_pretrained(..., trust_remote_code=True) - Sentence Transformers: CrossEncoder(..., trust_remote_code=True) """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, List, Union from dataclasses import dataclass from transformers import PretrainedConfig, PreTrainedModel, AutoModel from transformers.modeling_outputs import SequenceClassifierOutput @dataclass class RexRerankerOutput(SequenceClassifierOutput): """Output class for RexReranker with additional distributional information.""" loss: Optional[torch.Tensor] = None logits: torch.Tensor = None # Single relevance score [B, 1] for CrossEncoder compatibility distribution_logits: torch.Tensor = None # Full distribution [B, num_bins] relevance: torch.Tensor = None # Convenience: same as logits.squeeze(-1) variance: torch.Tensor = None # Prediction variance entropy: torch.Tensor = None # Distribution entropy class RexRerankerConfig(PretrainedConfig): """Configuration for RexReranker model.""" model_type = "rex_reranker" def __init__( self, backbone_name: str = "thebajajra/RexBERT-mini", num_bins: int = 11, dropout: float = 0.0, pooling_strategy: str = "mean", hidden_size: int = None, num_labels: int = 1, # CrossEncoder compatibility transitions: List[float] = None, sigma_min: float = 0.04, sigma_max: float = 0.12, sigma_delta: float = 0.08, **kwargs, ): super().__init__(**kwargs) self.backbone_name = backbone_name self.num_bins = num_bins self.dropout = dropout self.pooling_strategy = pooling_strategy self.hidden_size = hidden_size self.num_labels = num_labels self.transitions = transitions or [0.2, 0.5, 0.8] self.sigma_min = sigma_min self.sigma_max = sigma_max self.sigma_delta = sigma_delta class RexRerankerModel(PreTrainedModel): """ RexBERT-based distributional reranker. Predicts a categorical distribution over K bins in [0, 1] representing relevance scores. The output logits contain a single relevance score for CrossEncoder compatibility, while the full distribution is available via distribution_logits or predict_with_uncertainty(). 
class RexRerankerModel(PreTrainedModel):
    """RexBERT-based distributional reranker.

    Predicts a categorical distribution over ``num_bins`` bins in [0, 1]
    representing relevance scores. The output ``logits`` contain a single
    relevance score for CrossEncoder compatibility, while the full
    distribution is available via ``distribution_logits`` or
    ``predict_with_uncertainty()``.

    Compatible with:
    - sentence_transformers.CrossEncoder
    - transformers.AutoModelForSequenceClassification
    """

    config_class = RexRerankerConfig
    base_model_prefix = "rex_reranker"
    supports_gradient_checkpointing = True

    def __init__(self, config: RexRerankerConfig):
        super().__init__(config)
        assert config.pooling_strategy in ("cls", "mean")
        self.pooling_strategy = config.pooling_strategy
        self.num_bins = config.num_bins

        self.backbone = AutoModel.from_pretrained(
            config.backbone_name,
            trust_remote_code=True,
        )
        # KV caching is only useful for generation; disable it for encoding.
        if hasattr(self.backbone, "config") and hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False

        hidden_size = config.hidden_size or getattr(self.backbone.config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from the backbone; set config.hidden_size explicitly.")

        self.dropout = nn.Dropout(config.dropout)
        self.score_head = nn.Linear(hidden_size, config.num_bins)
        # Evenly spaced bin centers in [0, 1]; non-persistent so they are
        # rebuilt on load rather than stored in the checkpoint.
        self.register_buffer(
            "bin_centers",
            torch.linspace(0.0, 1.0, config.num_bins),
            persistent=False,
        )
        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        output_distribution: bool = False,
        **kwargs,  # Accept extra kwargs (e.g. token_type_ids) for CrossEncoder compatibility
    ) -> Union[RexRerankerOutput, tuple]:
        """Forward pass.

        Args:
            input_ids: Token IDs [B, T]
            attention_mask: Attention mask [B, T]
            labels: Optional relevance labels [B]
            return_dict: Whether to return a dataclass
            output_distribution: If True, include full distribution info in the output

        Returns:
            RexRerankerOutput with:
            - logits: [B, 1] single relevance score (CrossEncoder compatible)
            - relevance: [B] convenience field, same as logits.squeeze(-1)
            - distribution_logits, variance, entropy: distribution info
              (populated only if output_distribution=True)
        """
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        last_hidden = out.last_hidden_state

        if self.pooling_strategy == "cls":
            pooled = last_hidden[:, 0, :]
        else:
            # Mean pooling over non-padding tokens.
            mask = attention_mask.unsqueeze(-1).float()
            summed = (last_hidden * mask).sum(dim=1)
            lengths = mask.sum(dim=1).clamp(min=1e-9)
            pooled = summed / lengths

        # Distribution over relevance bins.
        dist_logits = self.score_head(self.dropout(pooled))  # [B, num_bins]

        # Collapse to a single relevance score: the expected value over bin centers.
        probs = F.softmax(dist_logits, dim=-1)
        relevance = (probs * self.bin_centers.view(1, -1)).sum(dim=-1)  # [B]

        # Expose the single score as logits for CrossEncoder compatibility [B, 1].
        logits = relevance.unsqueeze(-1)

        loss = None
        if labels is not None:
            loss = F.mse_loss(relevance, labels.float())

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        # Compute additional distribution statistics only when requested.
        variance = None
        entropy = None
        if output_distribution:
            variance = (probs * (self.bin_centers.view(1, -1) - relevance.unsqueeze(-1)) ** 2).sum(dim=-1)
            entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1)

        return RexRerankerOutput(
            loss=loss,
            logits=logits,
            distribution_logits=dist_logits if output_distribution else None,
            relevance=relevance,
            variance=variance,
            entropy=entropy,
        )

    def predict_relevance(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Get relevance scores directly. Returns a [B] tensor."""
        outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.relevance
    def predict_with_uncertainty(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> dict:
        """Get relevance prediction with full uncertainty estimates.

        Returns:
            dict with:
            - relevance: [B] predicted relevance scores
            - variance: [B] prediction variance (higher = more uncertain)
            - entropy: [B] distribution entropy (higher = more uncertain)
            - probs: [B, num_bins] full probability distribution
            - distribution_logits: [B, num_bins] raw logits
        """
        outputs = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_distribution=True,
        )
        probs = F.softmax(outputs.distribution_logits, dim=-1)
        return {
            "relevance": outputs.relevance,
            "variance": outputs.variance,
            "entropy": outputs.entropy,
            "probs": probs,
            "distribution_logits": outputs.distribution_logits,
        }
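
# Minimal smoke test (an illustrative sketch, not part of the model API). It
# builds the model from the default config, which downloads the RexBERT-mini
# backbone named above; the tokenizer is assumed to ship with that backbone.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = RexRerankerConfig()
    model = RexRerankerModel(config).eval()
    tokenizer = AutoTokenizer.from_pretrained(config.backbone_name, trust_remote_code=True)

    # Query/passage pairs are encoded together, cross-encoder style.
    batch = tokenizer(
        ["what is the capital of France?"],
        ["Paris is the capital and largest city of France."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        result = model.predict_with_uncertainty(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
    print("relevance:", result["relevance"])
    print("variance:", result["variance"])
    print("entropy:", result["entropy"])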