# modeling_armorm.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from dataclasses import dataclass
from transformers import PreTrainedModel, PretrainedConfig, LlamaModel, LlamaConfig
from transformers.modeling_outputs import ModelOutput


class ArmoRMConfig(PretrainedConfig):
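    """Configuration for the ArmoRM reward model.

    Wraps a Llama-style backbone configuration and adds the reward-model
    fields: the number of objectives, their names, the gating-network
    dimensions, and the softmax temperature used by the gating network.
    """
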
model_type = "armorm"

    def __init__(
self,
vocab_size=128256,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=131072,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
rope_theta=500000.0,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
num_objectives=5,
objective_names=None,
gating_hidden_dim=1024,
gating_num_layers=4,
temperature=10.0,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.num_objectives = num_objectives
self.objective_names = objective_names or [
"statute_reference",
"legal_accuracy",
"case_law_reference",
"linguistic_coherence",
"depth_coverage"
]
self.gating_hidden_dim = gating_hidden_dim
self.gating_num_layers = gating_num_layers
self.temperature = temperature
super().__init__(**kwargs)


@dataclass
class ArmoRMOutput(ModelOutput):
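    """Output of ArmoRMForSequenceClassification.

    `score` (mirrored in `logits`) is the gated scalar reward, `rewards`
    holds the raw per-objective predictions, and `gating_output` holds
    the softmax mixing weights over objectives.
    """
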
logits: Optional[torch.FloatTensor] = None
score: Optional[torch.FloatTensor] = None
rewards: Optional[torch.FloatTensor] = None
gating_output: Optional[torch.FloatTensor] = None


class GatingNetwork(nn.Module):
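    """MLP that maps a pooled hidden state to weights over objectives.

    `num_layers - 1` Linear+ReLU blocks are followed by a final Linear
    layer and a temperature-scaled softmax, so the output is a convex
    combination over the `out_features` reward objectives.
    """
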
def __init__(self, in_features, out_features, hidden_dim=1024, num_layers=4, temperature=10.0):
super().__init__()
self.temperature = temperature
layers = []
current_dim = in_features
        for _ in range(num_layers - 1):
layers.append(nn.Linear(current_dim, hidden_dim))
current_dim = hidden_dim
layers.append(nn.Linear(current_dim, out_features))
self.layers = nn.ModuleList(layers)

    def forward(self, x):
for i, layer in enumerate(self.layers):
x = layer(x)
if i < len(self.layers) - 1:
x = F.relu(x)
x = F.softmax(x / self.temperature, dim=-1)
return x
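

# A minimal sketch (not part of the original upload) illustrating the
# GatingNetwork contract: outputs are non-negative and sum to 1 per row,
# so they act as mixing weights over objectives. Dimensions are arbitrary.
def _demo_gating_weights():
    gate = GatingNetwork(in_features=16, out_features=5, hidden_dim=32, num_layers=3)
    weights = gate(torch.randn(2, 16))
    assert weights.shape == (2, 5)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(2), atol=1e-6)
    return weights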


class ArmoRMForSequenceClassification(PreTrainedModel):
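    """Multi-objective reward model with a gated scalar head.

    A LlamaModel backbone encodes the sequence; the hidden state of the
    last non-padding token is pooled, a linear head predicts one reward
    per objective, and a gating network predicts per-objective weights.
    The scalar score is the gate-weighted sum of the transformed rewards.
    """
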
config_class = ArmoRMConfig
base_model_prefix = "model"

    def __init__(self, config):
super().__init__(config)
self.config = config
# LlamaModel as base
llama_config = LlamaConfig(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
hidden_act=config.hidden_act,
max_position_embeddings=config.max_position_embeddings,
initializer_range=config.initializer_range,
rms_norm_eps=config.rms_norm_eps,
use_cache=config.use_cache,
rope_theta=config.rope_theta,
attention_bias=config.attention_bias,
attention_dropout=config.attention_dropout,
mlp_bias=config.mlp_bias,
)
self.model = LlamaModel(llama_config)
# Regression layer for multi-objective rewards
self.regression_layer = nn.Linear(config.hidden_size, config.num_objectives, bias=False)
# Gating network
self.gating = GatingNetwork(
config.hidden_size,
config.num_objectives,
hidden_dim=config.gating_hidden_dim,
num_layers=config.gating_num_layers,
temperature=config.temperature
)
        # Frozen reward transform matrix (identity by default); stored as a
        # non-trainable Parameter so it is saved with the checkpoint.
self.reward_transform_matrix = nn.Parameter(
torch.eye(config.num_objectives), requires_grad=False
)
self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
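        """Encode the sequence, pool the last token, and score it.

        Returns an ArmoRMOutput with the scalar `score` (duplicated in
        `logits`), the per-objective `rewards`, and the gating weights.
        """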
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
hidden_states = outputs.last_hidden_state
device = hidden_states.device
# Last token pooling
if attention_mask is not None:
sequence_lengths = attention_mask.sum(dim=1) - 1
sequence_lengths = sequence_lengths.clamp(min=0).to(device)
batch_size = hidden_states.size(0)
batch_indices = torch.arange(batch_size, device=device)
pooled = hidden_states[batch_indices, sequence_lengths]
else:
pooled = hidden_states[:, -1, :]
# Multi-objective rewards (keep same dtype as pooled)
rewards = self.regression_layer(pooled)
# Gating weights
gate_weights = self.gating(pooled)
        # Apply the transform and compute the final score in float32 for
        # numerical stability; `device` from above already matches `pooled`.
rewards_f32 = rewards.float()
gate_f32 = gate_weights.float()
transform_f32 = self.reward_transform_matrix.to(device).float()
coeffs = gate_f32 @ transform_f32.T
score = (rewards_f32 * coeffs).sum(dim=-1, keepdim=True)
return ArmoRMOutput(
logits=score.to(pooled.dtype),
score=score.to(pooled.dtype),
rewards=rewards,
gating_output=gate_weights,
)
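

# A minimal smoke test (not part of the original upload). The tiny dimensions
# below are arbitrary and only exercise the forward pass; they do not match
# the released checkpoint, which uses the Llama-3-style defaults above.
if __name__ == "__main__":
    tiny_config = ArmoRMConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        max_position_embeddings=128,
        gating_hidden_dim=32,
        gating_num_layers=2,
    )
    model = ArmoRMForSequenceClassification(tiny_config)
    input_ids = torch.randint(0, tiny_config.vocab_size, (2, 10))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    print("score:", tuple(out.score.shape))           # (2, 1)
    print("rewards:", tuple(out.rewards.shape))       # (2, num_objectives)
    print("gating:", tuple(out.gating_output.shape))  # (2, num_objectives)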