# Samsung-TRM / models / recursive_reasoning / transformers_baseline.py
# Uploaded by wtfmahe ("Upload 34 files"), commit d12141a (verified).
"""
HRM ACT V2: Transformer Baseline for Architecture Ablation
This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
Key changes from V1:
1. REMOVED hierarchical split (no separate H and L levels)
2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
3. KEPT ACT outer loop structure intact
4. KEPT all data preprocessing, embeddings, and evaluation infrastructure
Architecture: Single-level transformer that processes the full 30x30 grid as a
900-token sequence, with the same positional encodings and sparse embeddings as V1.
"""
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
@dataclass
class Model_ACTV2InnerCarry:
    """Recurrent state that the inner model hands from one ACT step to the next."""

    # Hidden state of the single reasoning level. The inner model allocates it as
    # (batch, seq_len + puzzle_emb_len, hidden_size) in the configured forward dtype.
    z_H: torch.Tensor
@dataclass
class Model_ACTV2Carry:
    """State of the ACT outer loop that persists across forward calls."""

    # Recurrent hidden state owned by the inner model.
    inner_carry: Model_ACTV2InnerCarry
    # Number of ACT steps each sequence has taken since it was last reset.
    steps: torch.Tensor
    # Boolean mask; a True entry means the slot is finished and will be refilled
    # from the incoming batch on the next forward call.
    halted: torch.Tensor
    # The batch tensors currently occupying each slot, keyed like the input batch.
    current_data: Dict[str, torch.Tensor]
class Model_ACTV2Config(BaseModel):
    """Validated hyper-parameters for the ACT V2 transformer baseline."""

    # --- Data / embedding layout ---
    batch_size: int  # forwarded to the sparse puzzle embedding
    seq_len: int  # token positions per example, excluding the puzzle-embedding prefix
    puzzle_emb_ndim: int = 0  # width of the per-puzzle embedding; 0 disables it
    num_puzzle_identifiers: int  # number of rows in the puzzle embedding table
    vocab_size: int

    # --- Depth ---
    # NOTE(review): H_cycles is not read anywhere in this file; presumably it is
    # kept so V1 configs still validate — confirm before removing.
    H_cycles: int
    H_layers: int  # number of transformer blocks in the reasoning module

    # --- Transformer config ---
    hidden_size: int
    expansion: float  # passed to the SwiGLU MLP
    num_heads: int
    pos_encodings: str  # "rope" or "learned"; anything else raises NotImplementedError

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # --- Halting Q-learning config ---
    halt_max_steps: int  # hard upper bound on ACT steps per sequence
    halt_exploration_prob: float  # chance of forcing a random minimum step count in training
    act_enabled: bool = True  # If False, always run halt_max_steps (no early stopping during training)
    act_inference: bool = False  # If True, use adaptive computation during inference

    # Name of a torch dtype attribute, resolved with getattr(torch, ...).
    forward_dtype: str = "bfloat16"
class Model_ACTV2Block(nn.Module):
    """Bidirectional transformer block: self-attention then SwiGLU MLP, both post-norm."""

    def __init__(self, config: Model_ACTV2Config) -> None:
        super().__init__()

        per_head_dim = config.hidden_size // config.num_heads
        # Full multi-head attention (as many KV heads as query heads), not causal.
        self.self_attn = Attention(
            hidden_size=config.hidden_size,
            head_dim=per_head_dim,
            num_heads=config.num_heads,
            num_key_value_heads=config.num_heads,
            causal=False,
        )
        self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, normalising after each residual add."""
        eps = self.norm_eps

        # Attention sub-layer: residual add first, RMS norm afterwards (post-norm).
        attn_out = self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states)
        after_attn = rms_norm(hidden_states + attn_out, variance_epsilon=eps)

        # Feed-forward sub-layer, same post-norm pattern.
        mlp_out = self.mlp(after_attn)
        return rms_norm(after_attn + mlp_out, variance_epsilon=eps)
class Model_ACTV2ReasoningModule(nn.Module):
    """Stack of blocks applied in order after the input is added to the state."""

    def __init__(self, layers: List["Model_ACTV2Block"]):
        super().__init__()
        # ModuleList registers every block so its parameters are tracked.
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        """Add ``input_injection`` to ``hidden_states`` and run the result through each block.

        Extra keyword arguments are forwarded unchanged to every block.
        """
        state = hidden_states + input_injection
        for block in self.layers:
            state = block(hidden_states=state, **kwargs)
        return state
class Model_ACTV2_Inner(nn.Module):
    """Single-level transformer core that is executed once per ACT step.

    The module embeds the input tokens (optionally prefixed by a per-puzzle
    embedding), adds them to the carried hidden state, runs the block stack once
    and produces token logits plus two Q-values (halt / continue).
    """

    def __init__(self, config: Model_ACTV2Config) -> None:
        super().__init__()
        self.config = config
        # Resolve the dtype name (e.g. "bfloat16") to the torch dtype object.
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        # Embeddings are initialised with std 1/sqrt(hidden) and multiplied by
        # sqrt(hidden) when they are used.
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            init_std=embed_init_std,
            cast_to=self.forward_dtype,
        )
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # Two outputs per sequence: index 0 = Q(halt), index 1 = Q(continue).
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        # Number of sequence positions the puzzle embedding occupies.
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(
                self.config.num_puzzle_identifiers,
                self.config.puzzle_emb_ndim,
                batch_size=self.config.batch_size,
                init_std=0,
                cast_to=self.forward_dtype,
            )

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(
                dim=self.config.hidden_size // self.config.num_heads,
                max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                base=self.config.rope_theta,
            )
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(
                self.config.seq_len + self.puzzle_emb_len,
                self.config.hidden_size,
                init_std=embed_init_std,
                cast_to=self.forward_dtype,
            )
        else:
            # Same exception type as before; the message now names the bad value.
            raise NotImplementedError(
                f"Unsupported pos_encodings: {self.config.pos_encodings!r} (expected 'rope' or 'learned')"
            )

        # Reasoning Layers
        self.H_level = Model_ACTV2ReasoningModule(
            layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)]
        )

        # Initial states: one hidden_size vector, broadcast over batch and
        # sequence whenever a sequence is reset.
        self.H_init = nn.Buffer(
            trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Embed tokens, prepend the puzzle embedding if enabled, and scale.

        Args:
            input: Token ids; they are cast to int32 before the lookup.
            puzzle_identifiers: Indices into the sparse puzzle embedding table
                (only read when ``puzzle_emb_ndim > 0``).

        Returns:
            The embedded sequence multiplied by ``embed_scale``; it has
            ``puzzle_emb_len`` extra leading positions when puzzle embeddings
            are enabled.
        """
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            # Right-pad so the vector splits evenly into puzzle_emb_len rows of
            # hidden_size each.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat(
                (puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
            )

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        """Allocate an uninitialised carry for ``batch_size`` sequences.

        The contents are garbage until ``reset_carry`` overwrites them. The
        tensor is created on the device that holds ``H_init`` so it always
        matches the module, regardless of the ambient default device.
        """
        return Model_ACTV2InnerCarry(
            z_H=torch.empty(
                batch_size,
                self.config.seq_len + self.puzzle_emb_len,
                self.config.hidden_size,
                dtype=self.forward_dtype,
                device=self.H_init.device,
            ),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry):
        """Return a carry in which flagged sequences are replaced by ``H_init``.

        Args:
            reset_flag: Boolean tensor with one entry per sequence; it is viewed
                as (-1, 1, 1) so it broadcasts over positions and features.
            carry: The current carry; unflagged sequences keep their state.
        """
        return Model_ACTV2InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
        )

    def forward(
        self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor]
    ) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run one reasoning step.

        Args:
            carry: Hidden state from the previous step.
            batch: Must contain ``"inputs"`` and ``"puzzle_identifiers"``.

        Returns:
            A tuple ``(new_carry, logits, (q_halt, q_continue))``. ``new_carry``
            holds the detached hidden state, ``logits`` covers only the token
            positions (the puzzle-embedding prefix is sliced off), and the two
            Q-values are float32 tensors read from sequence position 0.
        """
        seq_info = dict(
            # Rotary tables exist only when pos_encodings == "rope".
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # 1-step grad: the carry stored for the next step is detached below, so
        # gradients never flow back through more than this one pass.
        z_H = self.H_level(carry.z_H, input_embeddings, **seq_info)

        # LM Outputs
        new_carry = Model_ACTV2InnerCarry(
            z_H=z_H.detach(),
        )  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len :]

        # Q head
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class Model_ACTV2(nn.Module):
    """ACT wrapper.

    Owns the inner transformer and implements the outer adaptive-computation
    loop: each forward call advances every sequence by one step, decides which
    sequences halt, and refills halted slots from the incoming batch on the
    following call.
    """

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = Model_ACTV2Config(**config_dict)
        self.inner = Model_ACTV2_Inner(self.config)

    @property
    def puzzle_emb(self):
        """The inner model's sparse puzzle embedding.

        The inner model only creates this attribute when ``puzzle_emb_ndim > 0``;
        otherwise the access raises ``AttributeError``.
        """
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Build the carry for the first forward call.

        Every sequence starts as halted, so the first forward pass resets the
        hidden state and loads all data from ``batch``. ``steps`` and ``halted``
        are created on the device of ``batch["inputs"]`` so they match
        ``current_data`` and the batches later combined with them in ``forward``.
        """
        inputs = batch["inputs"]
        batch_size = inputs.shape[0]
        device = inputs.device

        return Model_ACTV2Carry(
            inner_carry=self.inner.empty_carry(
                batch_size
            ),  # Empty is expected, it will be reset in the first pass as all sequences are halted.
            steps=torch.zeros((batch_size,), dtype=torch.int32, device=device),
            halted=torch.ones((batch_size,), dtype=torch.bool, device=device),  # Default to halted
            current_data={k: torch.empty_like(v) for k, v in batch.items()},
        )

    def forward(
        self,
        carry: Model_ACTV2Carry,
        batch: Dict[str, torch.Tensor],
        compute_target_q: bool = False,
    ) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]:
        """Advance every sequence by one ACT step.

        Args:
            carry: State returned by ``initial_carry`` or a previous call.
            batch: Fresh data; only slots whose sequence is halted take it.
            compute_target_q: When True (and the module is training with
                adaptive computation active), run the inner model a second time
                to produce the bootstrapped ``"target_q_continue"`` output.

        Returns:
            ``(new_carry, outputs)``. ``outputs`` always has ``"logits"``,
            ``"q_halt_logits"`` and ``"q_continue_logits"``; it may additionally
            hold ``"actual_steps"`` (adaptive inference) and
            ``"target_q_continue"`` (training, see above).
        """
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
        new_steps = torch.where(carry.halted, 0, carry.steps)
        # The halted mask is reshaped to (B, 1, ..., 1) so it broadcasts over
        # every trailing dimension of each batch tensor.
        new_current_data = {
            k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
            for k, v in carry.current_data.items()
        }

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
            new_inner_carry, new_current_data
        )
        outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps
            halted = is_last_step

            # Check if adaptive computation should be used
            use_adaptive = (self.config.halt_max_steps > 1) and (
                (self.training and self.config.act_enabled)
                or (not self.training and self.config.act_inference)
            )

            if use_adaptive:
                # Halt signal based on Q-values (but always halt at max steps)
                q_halt_signal = q_halt_logits > q_continue_logits
                halted = halted | q_halt_signal

                # Store actual steps used for logging (only during inference)
                if not self.training:
                    outputs["actual_steps"] = new_steps.float()

                # Exploration (only during training)
                if self.training:
                    # With probability halt_exploration_prob a sequence gets a random
                    # minimum step count in [2, halt_max_steps]; otherwise the product
                    # is 0 and imposes no constraint.
                    min_halt_steps = (
                        torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
                    ) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                    halted = halted & (new_steps >= min_halt_steps)

                # Compute target Q (only during training)
                # NOTE: No replay buffer and target networks for computing target Q-value.
                # As batch_size is large, there're many parallel envs.
                # Similar concept as PQN https://arxiv.org/abs/2407.04811
                if self.training and compute_target_q:
                    next_q_halt_logits, next_q_continue_logits = self.inner(
                        new_inner_carry, new_current_data
                    )[-1]
                    # On the last step only halting is possible, so the target uses
                    # Q(halt); otherwise it uses the better of the two next actions.
                    outputs["target_q_continue"] = torch.sigmoid(
                        torch.where(
                            is_last_step,
                            next_q_halt_logits,
                            torch.maximum(next_q_halt_logits, next_q_continue_logits),
                        )
                    )

        return Model_ACTV2Carry(
            new_inner_carry, new_steps, halted, new_current_data
        ), outputs