# modeling_rex_reranker.py
"""RexReranker Model for HuggingFace.
Compatible with:
- Transformers: AutoModel.from_pretrained(..., trust_remote_code=True)
- Sentence Transformers: CrossEncoder(..., trust_remote_code=True)
"""
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

@dataclass
class RexRerankerOutput(SequenceClassifierOutput):
    """Output class for RexReranker with additional distributional information."""

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None  # Single relevance score [B, 1] for CrossEncoder compatibility
    distribution_logits: Optional[torch.Tensor] = None  # Full distribution [B, num_bins]
    relevance: Optional[torch.Tensor] = None  # Convenience: same as logits.squeeze(-1)
    variance: Optional[torch.Tensor] = None  # Prediction variance
    entropy: Optional[torch.Tensor] = None  # Distribution entropy

class RexRerankerConfig(PretrainedConfig):
"""Configuration for RexReranker model."""
model_type = "rex_reranker"
def __init__(
self,
backbone_name: str = "thebajajra/RexBERT-mini",
num_bins: int = 11,
dropout: float = 0.0,
pooling_strategy: str = "mean",
        hidden_size: Optional[int] = None,
num_labels: int = 1, # CrossEncoder compatibility
        transitions: Optional[List[float]] = None,
sigma_min: float = 0.04,
sigma_max: float = 0.12,
sigma_delta: float = 0.08,
**kwargs,
):
super().__init__(**kwargs)
self.backbone_name = backbone_name
self.num_bins = num_bins
self.dropout = dropout
self.pooling_strategy = pooling_strategy
self.hidden_size = hidden_size
self.num_labels = num_labels
self.transitions = transitions or [0.2, 0.5, 0.8]
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.sigma_delta = sigma_delta
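
# Configuration sketch (illustrative; only the defaults declared in __init__ above
# are guaranteed): num_bins sets the resolution of the score distribution, e.g.
# num_bins=11 yields bin centers 0.0, 0.1, ..., 1.0.
#
#   config = RexRerankerConfig(backbone_name="thebajajra/RexBERT-mini", num_bins=11)
#   model = RexRerankerModel(config)
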
class RexRerankerModel(PreTrainedModel):
"""
RexBERT-based distributional reranker.
Predicts a categorical distribution over K bins in [0, 1] representing
relevance scores. The output logits contain a single relevance score
for CrossEncoder compatibility, while the full distribution is available
via distribution_logits or predict_with_uncertainty().
Compatible with:
- sentence_transformers.CrossEncoder
- transformers.AutoModelForSequenceClassification
"""
config_class = RexRerankerConfig
base_model_prefix = "rex_reranker"
supports_gradient_checkpointing = True
def __init__(self, config: RexRerankerConfig):
super().__init__(config)
        if config.pooling_strategy not in ("cls", "mean"):
            raise ValueError(
                f"pooling_strategy must be 'cls' or 'mean', got {config.pooling_strategy!r}"
            )
self.pooling_strategy = config.pooling_strategy
self.num_bins = config.num_bins
self.backbone = AutoModel.from_pretrained(
config.backbone_name,
trust_remote_code=True,
)
if hasattr(self.backbone, "config") and hasattr(self.backbone.config, "use_cache"):
self.backbone.config.use_cache = False
hidden_size = config.hidden_size or getattr(self.backbone.config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError(
                "Could not infer hidden_size from the backbone config; "
                "set RexRerankerConfig.hidden_size explicitly."
            )
self.dropout = nn.Dropout(config.dropout)
self.score_head = nn.Linear(hidden_size, config.num_bins)
self.register_buffer(
"bin_centers",
torch.linspace(0.0, 1.0, config.num_bins),
persistent=False,
)
self.post_init()
def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
output_distribution: bool = False,
**kwargs, # Accept extra kwargs for CrossEncoder compatibility
) -> Union[RexRerankerOutput, tuple]:
"""
Forward pass.
Args:
input_ids: Token IDs [B, T]
attention_mask: Attention mask [B, T]
labels: Optional relevance labels [B]
return_dict: Whether to return a dataclass
output_distribution: If True, include full distribution info in output
        Returns:
            RexRerankerOutput with:
                - logits: [B, 1] single relevance score (CrossEncoder compatible)
                - relevance: [B] expected relevance (always populated)
                - distribution_logits: [B, num_bins] full distribution (if output_distribution=True)
                - variance, entropy: uncertainty estimates (if output_distribution=True)
        """
out = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
last_hidden = out.last_hidden_state
        if self.pooling_strategy == "cls":
            pooled = last_hidden[:, 0, :]  # [CLS] token representation
        else:
            # Mean pooling: average token embeddings over non-padding positions
            mask = attention_mask.unsqueeze(-1).float()
            summed = (last_hidden * mask).sum(dim=1)
            lengths = mask.sum(dim=1).clamp(min=1e-9)
            pooled = summed / lengths
# Get distribution logits
dist_logits = self.score_head(self.dropout(pooled)) # [B, num_bins]
# Convert to single relevance score (expected value)
probs = F.softmax(dist_logits, dim=-1)
relevance = (probs * self.bin_centers.view(1, -1)).sum(dim=-1) # [B]
# Output single score as logits for CrossEncoder compatibility [B, 1]
logits = relevance.unsqueeze(-1)
loss = None
if labels is not None:
loss = F.mse_loss(relevance, labels.float())
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
# Compute additional stats if requested
variance = None
entropy = None
if output_distribution:
variance = (probs * (self.bin_centers.view(1, -1) - relevance.unsqueeze(-1)) ** 2).sum(dim=-1)
entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1)
return RexRerankerOutput(
loss=loss,
logits=logits,
distribution_logits=dist_logits if output_distribution else None,
relevance=relevance,
variance=variance,
entropy=entropy,
)
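
    # Worked example of the expected-value scoring in forward() (illustrative
    # numbers, not real model output): with num_bins=11 the bin centers are
    # [0.0, 0.1, ..., 1.0]; if softmax places probability 0.7 on bin 0.8 and
    # 0.3 on bin 0.9, then relevance = 0.7 * 0.8 + 0.3 * 0.9 = 0.83 and
    # variance = 0.7 * (0.8 - 0.83)**2 + 0.3 * (0.9 - 0.83)**2 = 0.0021.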
def predict_relevance(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Get relevance scores directly. Returns [B] tensor."""
outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask)
return outputs.relevance
def predict_with_uncertainty(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> dict:
"""
Get relevance prediction with full uncertainty estimates.
Returns:
dict with:
- relevance: [B] predicted relevance scores
- variance: [B] prediction variance (higher = more uncertain)
- entropy: [B] distribution entropy (higher = more uncertain)
- probs: [B, num_bins] full probability distribution
- distribution_logits: [B, num_bins] raw logits
"""
outputs = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_distribution=True,
)
probs = F.softmax(outputs.distribution_logits, dim=-1)
return {
"relevance": outputs.relevance,
"variance": outputs.variance,
"entropy": outputs.entropy,
"probs": probs,
"distribution_logits": outputs.distribution_logits,
}
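
# Example: scoring a query-document pair with uncertainty (a minimal sketch; the
# checkpoint id, pair text, and max_length below are assumptions, not pinned by
# this file):
#
#   import torch
#   from transformers import AutoModel, AutoTokenizer
#
#   repo = "thebajajra/RexReranker-micro"
#   tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
#   model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()
#   enc = tok("what is bm25?", "BM25 is a bag-of-words ranking function.",
#             truncation=True, max_length=512, return_tensors="pt")
#   with torch.no_grad():
#       out = model.predict_with_uncertainty(enc["input_ids"], enc["attention_mask"])
#   print(out["relevance"], out["variance"], out["entropy"])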