# trm-cgar-sudoku / code / trm_cgar.py
# Uploaded by Kaleemullah (commit 05c69e6, verified)
"""
TRM-CGAR: TRM with Curriculum-Guided Adaptive Recursion
Adds progressive depth curriculum for faster, more stable training
"""
import math
from typing import Tuple, List, Dict, Optional
import torch
from models.recursive_reasoning.trm import (
TinyRecursiveReasoningModel_ACTV1,
TinyRecursiveReasoningModel_ACTV1Config,
TinyRecursiveReasoningModel_ACTV1_Inner,
TinyRecursiveReasoningModel_ACTV1Carry,
TinyRecursiveReasoningModel_ACTV1InnerCarry
)
class TinyRecursiveReasoningModel_ACTV1_Inner_CGAR(TinyRecursiveReasoningModel_ACTV1_Inner):
    """TRM inner module with a curriculum knob on recursion depth.

    During training the number of H/L recursion cycles can be reduced for a
    fast shallow warmup and ramped back up to the configured depth as
    training progresses; evaluation always runs at full configured depth.
    """
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__(config)
        # Configured (maximum) depths; never mutated after construction.
        self.base_H_cycles = config.H_cycles
        self.base_L_cycles = config.L_cycles
        # Effective depths used while training; updated by set_curriculum_depth.
        self.current_H_cycles = config.H_cycles
        self.current_L_cycles = config.L_cycles
    def set_curriculum_depth(self, progress: float):
        """Choose the effective recursion depth for the given training progress.

        Args:
            progress: fraction of training completed, from 0.0 (start) to 1.0 (end).
        """
        # Stage targets: roughly 1/3 and 2/3 of the configured depth, with
        # floors of 1 H-cycle and 2 L-cycles for the shallow stage.
        shallow_h = max(1, math.ceil(self.base_H_cycles / 3))
        shallow_l = max(2, math.ceil(self.base_L_cycles / 3))
        medium_h = max(shallow_h, math.ceil(2 * self.base_H_cycles / 3))
        medium_l = max(shallow_l, math.ceil(2 * self.base_L_cycles / 3))
        if progress >= 0.6:
            # Final 40% of training: full configured depth.
            self.current_H_cycles = self.base_H_cycles
            self.current_L_cycles = self.base_L_cycles
        elif progress >= 0.3:
            # Middle 30%: medium depth, capped at the configured maximum.
            self.current_H_cycles = min(self.base_H_cycles, medium_h)
            self.current_L_cycles = min(self.base_L_cycles, medium_l)
        else:
            # First 30%: shallow warmup depth.
            self.current_H_cycles = shallow_h
            self.current_L_cycles = shallow_l
    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the recursive reasoning loop at the current curriculum depth.

        Returns a detached carry, the LM output (puzzle-embedding prefix
        stripped), and the two Q-head logits taken from position 0.
        """
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )
        # Encode the batch inputs together with their puzzle identifiers.
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
        z_H, z_L = carry.z_H, carry.z_L
        # Curriculum depth applies only in training mode; eval uses full depth.
        if self.training:
            h_steps, l_steps = self.current_H_cycles, self.current_L_cycles
        else:
            h_steps, l_steps = self.base_H_cycles, self.base_L_cycles
        # All but the final H-cycle run without gradient tracking.
        with torch.no_grad():
            for _ in range(h_steps - 1):
                for _ in range(l_steps):
                    z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                z_H = self.L_level(z_H, z_L, **seq_info)
        # Final H-cycle with gradients enabled.
        for _ in range(l_steps):
            z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.L_level(z_H, z_L, **seq_info)
        # Detach the carry so gradients do not flow across ACT steps.
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class TinyRecursiveReasoningModel_ACTV1_CGAR(TinyRecursiveReasoningModel_ACTV1):
    """ACT wrapper that plugs in the curriculum-aware (CGAR) inner model."""
    def __init__(self, config_dict: dict):
        # Deliberately bypass the immediate parent's __init__ (which would
        # build the stock inner model) and jump to the grandparent init,
        # then construct the CGAR inner ourselves.
        # NOTE(review): any other state the parent __init__ sets is skipped
        # here — presumably only config/inner; verify against the parent class.
        super(TinyRecursiveReasoningModel_ACTV1, self).__init__()
        cfg = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.config = cfg
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner_CGAR(cfg)
    def set_curriculum_depth(self, progress: float):
        """Relay the curriculum progress to the inner model."""
        self.inner.set_curriculum_depth(progress)