# OmniScore-deberta-v3 / modeling_score_predictor.py
# Author: gagan3012 — "Update modeling_score_predictor.py" (commit 7d99a97, verified)
"""
ScorePredictorModel - Multi-output regression model for conversation scoring.
Compatible with Hugging Face's AutoModel with trust_remote_code=True.
Encoder-only architecture with explainability features.
Architecture Improvements:
- Multi-head attention pooling for better sequence representation
- Shared MLP backbone with task-specific heads
- Layer normalization for stability
- Residual connections in deeper heads
- Optional auxiliary loss for correlation between scores
"""
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoConfig,
AutoModel,
PreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass, field
import math
from .configuration_score_predictor import ScorePredictorConfig
from .explain_score_predictor import ScorePredictorExplainer
@dataclass
class ScorePredictorOutput(ModelOutput):
    """
    Output container returned by ``ScorePredictorModel.forward``.

    Args:
        loss: Combined (weighted mean) Smooth-L1 loss; only set when labels
            were passed to ``forward``.
        predictions: Predicted scores in the [1.0, 5.0] range,
            shape [batch_size, num_scores].
        hidden_states: Backbone hidden states; only set when
            ``output_hidden_states=True``.
        attentions: Backbone attention weights; only set when
            ``output_attentions=True``.
        per_score_loss: Mean loss per score name (for monitoring); only set
            when labels were provided.
    """
    loss: Optional[torch.FloatTensor] = None
    # Always populated by forward; the None default keeps the dataclass
    # constructible with no arguments (ModelOutput convention).
    predictions: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    per_score_loss: Optional[Dict[str, float]] = None
@dataclass
class ExplainabilityOutput:
    """
    Output container for explainability methods (produced by the explainer,
    not by this file's forward pass).

    Args:
        predictions: Predicted scores [batch_size, num_scores]
        token_attributions: Attribution scores per token per output score,
            Dict[score_name, Tensor[batch_size, seq_len]]
        attention_weights: Aggregated attention weights [batch_size, seq_len]
        layer_attention_weights: Per-layer attention
            [num_layers, batch_size, seq_len]
        head_importance: Importance of each attention head,
            Dict[score_name, Tensor]
        token_importance_ranking: Ranked token indices by importance
        input_tokens: List of input tokens (if a tokenizer was provided)
        score_contributions: Contribution breakdown per score
        confidence_scores: Confidence/uncertainty estimates per score
    """
    predictions: Optional[torch.FloatTensor] = None
    # Mutable containers use default_factory so instances never share state.
    token_attributions: Dict[str, torch.FloatTensor] = field(default_factory=dict)
    attention_weights: Optional[torch.FloatTensor] = None
    layer_attention_weights: Optional[torch.FloatTensor] = None
    head_importance: Dict[str, torch.FloatTensor] = field(default_factory=dict)
    token_importance_ranking: Dict[str, List[int]] = field(default_factory=dict)
    input_tokens: Optional[List[List[str]]] = None
    score_contributions: Dict[str, Dict[str, float]] = field(default_factory=dict)
    confidence_scores: Dict[str, float] = field(default_factory=dict)
class AttentionPooling(nn.Module):
    """
    Multi-head attention pooling layer.

    A single learnable query vector attends over the token sequence and
    produces one pooled vector per example, letting the model learn which
    tokens matter instead of relying on the CLS token or a plain mean.

    Args:
        hidden_size: Dimensionality of the token representations. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads used for pooling.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``
            (previously this surfaced only as a confusing reshape error
            in ``forward``).
    """

    def __init__(self, hidden_size: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        # Learnable query vector for pooling; small init keeps early logits tame
        self.pool_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Pool a token sequence into a single vector per example.

        Args:
            hidden_states: Token representations [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len] with 1 for real tokens and
                0 for padding, or None (all tokens attended).

        Returns:
            Pooled representation [batch_size, hidden_size].
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        # Expand the shared pool query for the batch: [B, 1, H]
        query = self.pool_query.expand(batch_size, -1, -1)
        query = self.query(query)           # [B, 1, H]
        key = self.key(hidden_states)       # [B, L, H]
        value = self.value(hidden_states)   # [B, L, H]
        # Split heads: [B, num_heads, 1|L, head_dim]
        query = query.view(batch_size, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention logits: [B, num_heads, 1, L]
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if attention_mask is not None:
            attn_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, L]
            # Use the dtype's minimum instead of -inf so that a fully-masked
            # row softmaxes to a uniform distribution instead of NaN.
            attn_weights = attn_weights.masked_fill(
                attn_mask == 0, torch.finfo(attn_weights.dtype).min
            )
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Weighted sum of values: [B, num_heads, 1, head_dim]
        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, 1, hidden_size)
        return self.out_proj(attn_output).squeeze(1)  # [B, hidden_size]
class ScoreHead(nn.Module):
    """
    Score prediction head: a small MLP ending in a scalar projection.

    Fix: the original built the residual machinery (``use_residual`` /
    ``residual_proj``) in ``__init__`` but ``forward`` never used it, so
    residual connections were silently dead. The residual is now applied:
    the input is added to the last hidden representation before the final
    scalar projection, with a linear shortcut projection when
    ``input_size != hidden_size``. Parameter names (``layers``,
    ``residual_proj``) are unchanged so existing state dicts still load,
    and ``use_residual=False`` computes exactly as before.

    Args:
        input_size: Dimensionality of the incoming features.
        hidden_size: Width of the hidden layers.
        num_layers: Total layer count; ``num_layers - 1`` hidden blocks
            (Linear -> LayerNorm -> GELU -> Dropout) plus the final Linear.
        dropout: Dropout probability in the hidden blocks.
        use_residual: Whether to add a shortcut around the hidden blocks.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1,
        use_residual: bool = True
    ):
        super().__init__()
        # Identity shortcut is only possible when the shapes already match.
        self.use_residual = use_residual and (input_size == hidden_size)
        layers = []
        current_size = input_size
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(current_size, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            current_size = hidden_size
        # Final projection to a scalar score
        layers.append(nn.Linear(current_size, 1))
        self.layers = nn.Sequential(*layers)
        # Shortcut projection when sizes don't match
        if use_residual and input_size != hidden_size:
            self.residual_proj = nn.Linear(input_size, hidden_size)
        else:
            self.residual_proj = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features [B, input_size] to a raw score [B, 1]."""
        # Residual only makes sense with at least one hidden block; with
        # num_layers == 1 a shortcut around nothing would just double x.
        residual_active = (
            (self.use_residual or self.residual_proj is not None)
            and len(self.layers) > 1
        )
        if not residual_active:
            return self.layers(x)
        hidden = self.layers[:-1](x)  # hidden MLP stack (Sequential slice)
        shortcut = x if self.residual_proj is None else self.residual_proj(x)
        return self.layers[-1](hidden + shortcut)
class SharedEncoder(nn.Module):
    """
    Shared MLP encoder applied before the task-specific heads.

    Captures patterns common to all scoring dimensions. Stacks
    ``num_layers`` blocks of Linear -> LayerNorm -> GELU -> Dropout; every
    block maps to ``hidden_size`` except the last, which maps to
    ``output_size``.

    Args:
        input_size: Dimensionality of the incoming pooled features.
        hidden_size: Width of the intermediate blocks.
        output_size: Width of the final block (fed to the score heads).
        num_layers: Number of Linear/LayerNorm/GELU/Dropout blocks.
        dropout: Dropout probability per block.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int = 512,
        output_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        # Widths of successive blocks: input -> hidden ... hidden -> output.
        widths = [input_size] + [hidden_size] * (num_layers - 1) + [output_size]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.LayerNorm(fan_out))
            modules.append(nn.GELU())
            modules.append(nn.Dropout(dropout))
        self.layers = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features [B, input_size] to [B, output_size]."""
        return self.layers(x)
class ScorePredictorModel(PreTrainedModel):
    """
    Multi-output regression model for encoder backbones.

    Predicts multiple scores (default: 4) for conversation quality assessment.

    Architecture:
        1. Backbone encoder (BERT, RoBERTa, etc.)
        2. Multi-pooling: CLS + Mean + Attention Pooling (concatenated)
        3. Shared encoder MLP for common feature extraction
        4. Task-specific score heads with LayerNorm and GELU
        5. Sigmoid scaling to the [1.0, 5.0] range

    Scores: Informativeness, Clarity, Plausibility, Faithfulness
    Output range: [1.0, 5.0] via sigmoid activation

    Explainability (attention-based importance, gradient attribution,
    rollout, confidence estimation) is provided by the companion explainer;
    see ``get_explainer``.
    """
    config_class = ScorePredictorConfig
    base_model_prefix = "backbone"
    supports_gradient_checkpointing = True

    # Encoder model types this wrapper supports.
    # NOTE(review): informational only — nothing in this file validates the
    # backbone's model_type against this set; confirm whether it should.
    ENCODER_MODEL_TYPES = {
        'bert', 'roberta', 'distilbert', 'albert', 'electra',
        'deberta', 'deberta-v2', 'xlm-roberta', 'camembert',
        'flaubert', 'xlm', 'longformer', 'funnel', 'modernbert',
        'qwen3', 'gemma3_text', 'qwen2'
    }

    def __init__(self, config: ScorePredictorConfig):
        """
        Build backbone, pooling, shared encoder and score heads.

        Args:
            config: ScorePredictorConfig carrying the backbone name, number
                of scores, score names, and optional architecture
                hyperparameters (attn_implementation, head sizes, dropout).
        """
        super().__init__(config)
        self.config = config
        self.num_scores = config.num_scores
        # Load the backbone *config* only; weights are loaded later by
        # from_pretrained / load_state_dict on the wrapper.
        backbone_config = AutoConfig.from_pretrained(config.backbone_model_name, trust_remote_code=True)
        attn_implementation = getattr(config, 'attn_implementation', None)
        model_kwargs = {}
        self._dtype = torch.float32  # Default dtype
        if attn_implementation is not None:
            model_kwargs['attn_implementation'] = attn_implementation
        if backbone_config.model_type in ['bert']:
            # Force eager for BERT due to compatibility.
            # NOTE(review): model_kwargs may still carry the originally
            # requested implementation here — harmless because from_config
            # below never receives model_kwargs, but confirm the intent.
            attn_implementation = "eager"
        # Flash Attention 2 requires half precision
        if attn_implementation == 'flash_attention_2':
            model_kwargs['torch_dtype'] = torch.bfloat16
            self._dtype = torch.bfloat16
        if attn_implementation is not None:
            backbone_config._attn_implementation = attn_implementation
        # Instantiate from config (random weights) so the wrapper's own
        # from_pretrained controls state-dict loading.
        self.backbone = AutoModel.from_config(backbone_config, trust_remote_code=True)
        if self._dtype == torch.bfloat16:
            self.backbone = self.backbone.to(torch.bfloat16)
        # Mirror backbone geometry onto our config for downstream consumers.
        self.hidden_size = backbone_config.hidden_size
        config.hidden_size = self.hidden_size
        self.max_position_embeddings = getattr(backbone_config, 'max_position_embeddings', 512)
        config.max_position_embeddings = self.max_position_embeddings
        # Architecture hyperparameters (overridable via config attributes).
        self.use_attention_pooling = getattr(config, 'use_attention_pooling', True)
        self.use_shared_encoder = getattr(config, 'use_shared_encoder', True)
        dropout_prob = getattr(config, 'hidden_dropout_prob', 0.1)
        head_hidden_size = getattr(config, 'head_hidden_size', 256)
        shared_hidden_size = getattr(config, 'shared_hidden_size', 512)
        # Pooling layers.
        # Combined width: CLS (H) + Mean (H) + Attention (H) = 3H
        self.attention_pooling = AttentionPooling(
            self.hidden_size,
            num_heads=4,
            dropout=dropout_prob
        ) if self.use_attention_pooling else None
        # Pooled feature width depends on whether attention pooling is on.
        pooled_size = self.hidden_size * (3 if self.use_attention_pooling else 2)
        # Shared encoder for patterns common to all score dimensions.
        if self.use_shared_encoder:
            self.shared_encoder = SharedEncoder(
                input_size=pooled_size,
                hidden_size=shared_hidden_size,
                output_size=head_hidden_size,
                num_layers=2,
                dropout=dropout_prob
            )
            head_input_size = head_hidden_size
        else:
            self.shared_encoder = None
            head_input_size = pooled_size
        # One independent head per score dimension.
        self.score_heads = nn.ModuleList([
            ScoreHead(
                input_size=head_input_size,
                hidden_size=head_hidden_size // 2,
                num_layers=2,
                dropout=dropout_prob,
                use_residual=False
            ) for _ in range(self.num_scores)
        ])
        # Huber (SmoothL1) loss is more robust to outliers than MSE;
        # reduction='none' keeps per-element losses for per-score reporting.
        self.loss_fn = nn.SmoothL1Loss(beta=0.5, reduction='none')
        # Fixed (non-trained) per-score loss weights; adjust for imbalance.
        self.score_loss_weights = nn.Parameter(
            torch.ones(self.num_scores),
            requires_grad=False
        )
        # Explainability caches — presumably filled by hooks registered in
        # the explainer module; nothing in this file writes to them.
        self._activations = {}
        self._gradients = {}
        # Initialize weights for heads (HF post-init hook).
        self.post_init()

    def _init_weights(self, module):
        """Initialize Linear layers (normal std=0.02, zero bias)."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def _last_token_pool(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Return the hidden state of the last non-padding token per sequence.
        Works for both left- and right-padding.

        Args:
            hidden_states: [B, L, H]
            attention_mask: [B, L] with 1 for tokens, 0 for padding

        Returns:
            [B, H] last-real-token representations.
        """
        # If every sequence has a real token in the last slot, the batch is
        # left-padded (or unpadded — same answer either way).
        left_padded = (attention_mask[:, -1].sum() == attention_mask.size(0))
        if left_padded:
            return hidden_states[:, -1, :]  # [B, H]
        # Right-padded: last real token index = sum(mask) - 1
        # (clamp guards against an all-zero mask row)
        idx = attention_mask.sum(dim=1).clamp_min(1) - 1  # [B]
        b = torch.arange(hidden_states.size(0), device=hidden_states.device)
        return hidden_states[b, idx, :]  # [B, H]

    def _mean_pool(
        self,
        last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Mean pooling over valid (non-padding) tokens only.

        Args:
            last_hidden_states: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]

        Returns:
            Pooled representation [batch_size, hidden_size]
        """
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        # Guard against division by zero for fully-masked rows.
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        output_vectors = sum_embeddings / sum_mask
        return output_vectors

    def _pool_hidden_states(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Multi-strategy pooling: CLS + Mean + Attention (optional).

        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]

        Returns:
            Pooled representation [batch_size, pooled_size], where
            pooled_size = hidden_size * 3 (with attention pooling) or * 2.
        """
        # "CLS" slot: decoder-style qwen3 backbones have no CLS token, so a
        # causal qwen3 uses the last real token and a bidirectional qwen3
        # uses a masked mean; everything else uses position 0.
        if self.backbone.config.model_type in ['qwen3'] and not getattr(self.backbone.config, 'use_bidirectional_attention', False):
            cls_output = self._last_token_pool(hidden_states, attention_mask)  # [B, H]
        elif self.backbone.config.model_type in ['qwen3'] and getattr(self.backbone.config, 'use_bidirectional_attention', False):
            cls_output = self._mean_pool(hidden_states, attention_mask)  # [B, H]
        else:
            cls_output = hidden_states[:, 0, :]  # [B, H]
        # Masked mean pooling.
        # NOTE(review): duplicates _mean_pool (with clamp_min(1) instead of
        # 1e-9) — consider reusing _mean_pool here.
        masked = hidden_states * attention_mask.unsqueeze(-1)
        mean_output = masked.sum(1) / attention_mask.sum(1, keepdim=True).clamp_min(1)  # [B, H]
        # Attention pooling (if enabled)
        if self.attention_pooling is not None:
            attn_output = self.attention_pooling(hidden_states, attention_mask)  # [B, H]
            pooled = torch.cat([cls_output, mean_output, attn_output], dim=-1)  # [B, 3H]
        else:
            pooled = torch.cat([cls_output, mean_output], dim=-1)  # [B, 2H]
        return pooled

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[ScorePredictorOutput, Tuple]:
        """
        Forward pass for score prediction.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]; defaults
                to all-ones when omitted.
            labels: Ground-truth scores [batch_size, num_scores] (optional)
            output_hidden_states: Whether to return backbone hidden states
            output_attentions: Whether to return backbone attention weights
            return_dict: Whether to return a ModelOutput or a tuple
            **kwargs: Forwarded verbatim to the backbone.

        Returns:
            ScorePredictorOutput, or a tuple (loss?, predictions,
            hidden_states?, attentions?) when return_dict is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Create attention mask if not provided
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # Forward through backbone
        backbone_outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
            **kwargs
        )
        # Last-layer token representations
        hidden_states = backbone_outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        # Pool hidden states using multiple strategies
        pooled_output = self._pool_hidden_states(hidden_states, attention_mask)
        # Ensure pooled output matches head dtype (handles mixed precision,
        # e.g. a bfloat16 backbone feeding float32 heads).
        target_dtype = next(self.score_heads[0].parameters()).dtype
        pooled_output = pooled_output.to(target_dtype)
        # Apply shared encoder if enabled
        if self.shared_encoder is not None:
            shared_features = self.shared_encoder(pooled_output)  # [B, head_hidden_size]
        else:
            shared_features = pooled_output
        # Apply task-specific score heads with sigmoid scaling to [1, 5]
        predictions_list = []
        for head in self.score_heads:
            # Output = 1 + 4 * sigmoid(x) maps to range [1, 5]
            raw_score = head(shared_features)  # [batch_size, 1]
            score = 1.0 + 4.0 * torch.sigmoid(raw_score)
            predictions_list.append(score)
        # Concatenate predictions: [batch_size, num_scores]
        predictions = torch.cat(predictions_list, dim=-1)
        # Calculate loss if labels provided
        loss = None
        per_score_loss = None
        if labels is not None:
            # Per-element SmoothL1 losses, kept unreduced: [B, num_scores]
            per_score_losses = self.loss_fn(predictions, labels.float())
            # Weighted average across scores
            weighted_losses = per_score_losses * self.score_loss_weights.unsqueeze(0)
            loss = weighted_losses.mean()
            # Store per-score loss for monitoring (plain floats, detached)
            per_score_loss = {
                name: per_score_losses[:, i].mean().item()
                for i, name in enumerate(self.config.score_names)
            }
        if not return_dict:
            output = (predictions,)
            if output_hidden_states:
                output += (backbone_outputs.hidden_states,)
            if output_attentions:
                output += (backbone_outputs.attentions,)
            return ((loss,) + output) if loss is not None else output
        return ScorePredictorOutput(
            loss=loss,
            predictions=predictions,
            hidden_states=backbone_outputs.hidden_states if output_hidden_states else None,
            attentions=backbone_outputs.attentions if output_attentions else None,
            per_score_loss=per_score_loss,
        )

    # =========================================================================
    # UTILITY METHODS
    # =========================================================================
    def get_input_embeddings(self):
        """Get input embeddings from the backbone."""
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set input embeddings on the backbone."""
        self.backbone.set_input_embeddings(value)

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        """Resize the backbone's token embeddings and return the new table."""
        return self.backbone.resize_token_embeddings(new_num_tokens)

    @classmethod
    def from_backbone(
        cls,
        backbone_model_name: str,
        num_scores: int = 4,
        score_names: Optional[List[str]] = None,
        **kwargs
    ) -> "ScorePredictorModel":
        """
        Create a ScorePredictorModel from a backbone model name.

        Args:
            backbone_model_name: HuggingFace encoder model name or path
            num_scores: Number of regression outputs
            score_names: Names of the scores (config default when None)
            **kwargs: Additional ScorePredictorConfig arguments

        Returns:
            Freshly initialized ScorePredictorModel (untrained heads).
        """
        config = ScorePredictorConfig(
            backbone_model_name=backbone_model_name,
            num_scores=num_scores,
            score_names=score_names,
            **kwargs
        )
        return cls(config)

    def get_explainer(self, tokenizer=None):
        """
        Return a ``ScorePredictorExplainer`` bound to this model.

        The tokenizer is loaded automatically from
        ``config.backbone_model_name`` if not provided.

        Usage::

            model = AutoModel.from_pretrained(
                "QCRI/OmniScore-deberta-v3", trust_remote_code=True
            )
            explainer = model.get_explainer()
            result = explainer.explain("Task: qa\\n Output: The answer is 42.")
            print(explainer.format(result))

        Parameters
        ----------
        tokenizer : optional
            A pre-loaded tokenizer. When ``None`` (default), one is
            created from ``self.config.backbone_model_name``.

        Returns
        -------
        ScorePredictorExplainer
        """
        if tokenizer is None:
            # Lazy import keeps tokenizer deps out of plain inference paths.
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                self.config.backbone_model_name, trust_remote_code=True
            )
        # Bind the explainer to whatever device the model currently lives on.
        device = next(self.parameters()).device
        return ScorePredictorExplainer(self, tokenizer, device)

    def predict_scores(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Convenience method for inference (switches to eval mode, no grad).

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Optional mask [batch_size, seq_len]

        Returns:
            Dictionary mapping each configured score name to its predicted
            values, one tensor of shape [batch_size] per score.
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
        predictions = outputs.predictions
        score_names = self.config.score_names
        return {
            name: predictions[:, i]
            for i, name in enumerate(score_names)
        }