import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from dataclasses import dataclass

from transformers import PreTrainedModel, PretrainedConfig, LlamaModel, LlamaConfig
from transformers.modeling_outputs import ModelOutput


class ArmoRMConfig(PretrainedConfig):
    model_type = "armorm"

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=131072,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        rope_theta=500000.0,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        num_objectives=5,
        objective_names=None,
        gating_hidden_dim=1024,
        gating_num_layers=4,
        temperature=10.0,
        **kwargs,
    ):
        # Llama backbone hyperparameters
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        # Multi-objective reward head hyperparameters
        self.num_objectives = num_objectives
        self.objective_names = objective_names or [
            "statute_reference",
            "legal_accuracy",
            "case_law_reference",
            "linguistic_coherence",
            "depth_coverage",
        ]
        self.gating_hidden_dim = gating_hidden_dim
        self.gating_num_layers = gating_num_layers
        self.temperature = temperature
        super().__init__(**kwargs)


@dataclass
class ArmoRMOutput(ModelOutput):
    logits: Optional[torch.FloatTensor] = None  # alias of `score` for HF pipelines
    score: Optional[torch.FloatTensor] = None  # final scalar reward, shape (batch, 1)
    rewards: Optional[torch.FloatTensor] = None  # per-objective rewards, shape (batch, num_objectives)
    gating_output: Optional[torch.FloatTensor] = None  # mixing weights, shape (batch, num_objectives)


class GatingNetwork(nn.Module):
    """MLP that maps a pooled hidden state to a softmax over objectives."""

    def __init__(self, in_features, out_features, hidden_dim=1024, num_layers=4, temperature=10.0):
        super().__init__()
        self.temperature = temperature
        layers = []
        current_dim = in_features
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(current_dim, hidden_dim))
            current_dim = hidden_dim
        layers.append(nn.Linear(current_dim, out_features))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = F.relu(x)
        # Temperature-scaled softmax over objectives: a temperature above 1
        # flattens the distribution so no single objective dominates.
        return F.softmax(x / self.temperature, dim=-1)


class ArmoRMForSequenceClassification(PreTrainedModel):
    config_class = ArmoRMConfig
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # LlamaModel as the base encoder
        llama_config = LlamaConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            hidden_act=config.hidden_act,
            max_position_embeddings=config.max_position_embeddings,
            initializer_range=config.initializer_range,
            rms_norm_eps=config.rms_norm_eps,
            use_cache=config.use_cache,
            rope_theta=config.rope_theta,
            attention_bias=config.attention_bias,
            attention_dropout=config.attention_dropout,
            mlp_bias=config.mlp_bias,
        )
        self.model = LlamaModel(llama_config)

        # Regression layer: one reward per objective
        self.regression_layer = nn.Linear(config.hidden_size, config.num_objectives, bias=False)

        # Gating network: per-example mixing weights over objectives
        self.gating = GatingNetwork(
            config.hidden_size,
            config.num_objectives,
            hidden_dim=config.gating_hidden_dim,
            num_layers=config.gating_num_layers,
            temperature=config.temperature,
        )

        # Fixed (non-trainable) reward transform matrix; the identity by
        # default, so rewards pass through unchanged unless it is overwritten.
        self.reward_transform_matrix = nn.Parameter(
            torch.eye(config.num_objectives), requires_grad=False
        )

        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        hidden_states = outputs.last_hidden_state
        device = hidden_states.device

        # Last-token pooling: take the hidden state of the final non-padding
        # token of each sequence (assumes right padding).
        if attention_mask is not None:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            sequence_lengths = sequence_lengths.clamp(min=0).to(device)
            batch_size = hidden_states.size(0)
            batch_indices = torch.arange(batch_size, device=device)
            pooled = hidden_states[batch_indices, sequence_lengths]
        else:
            pooled = hidden_states[:, -1, :]

        # Multi-objective rewards (keep the same dtype as pooled)
        rewards = self.regression_layer(pooled)

        # Gating weights
        gate_weights = self.gating(pooled)

        # Apply the transform and compute the final scalar score in float32
        # for numerical stability under bf16/fp16 backbones, keeping all
        # tensors on the same device.
        rewards_f32 = rewards.float()
        gate_f32 = gate_weights.float()
        transform_f32 = self.reward_transform_matrix.to(device).float()
        coeffs = gate_f32 @ transform_f32.T
        score = (rewards_f32 * coeffs).sum(dim=-1, keepdim=True)

        return ArmoRMOutput(
            logits=score.to(pooled.dtype),
            score=score.to(pooled.dtype),
            rewards=rewards,
            gating_output=gate_weights,
        )
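

# ---------------------------------------------------------------------------
# Minimal smoke test: an illustrative sketch, not part of the original module.
# It assumes only `torch` and `transformers` are installed, and uses a
# deliberately tiny config (hypothetical values chosen for a quick CPU run);
# the Llama-3-8B-scale defaults above would not fit in memory on most machines.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tiny_config = ArmoRMConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        max_position_embeddings=128,
        num_objectives=5,
        gating_hidden_dim=32,
        gating_num_layers=2,
    )
    model = ArmoRMForSequenceClassification(tiny_config)
    model.eval()

    # Batch of 2 right-padded sequences; the attention mask marks real tokens,
    # matching the right-padding assumption of the last-token pooling above.
    input_ids = torch.randint(0, tiny_config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[1, 10:] = 0  # second sequence is shorter (right-padded)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    print("score:", out.score.shape)  # (2, 1) scalar reward per sequence
    print("rewards:", out.rewards.shape)  # (2, num_objectives) per-objective rewards
    print("gates:", out.gating_output.shape)  # (2, num_objectives) mixing weights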