|
|
""" |
|
|
TRM-CGAR: TRM with Curriculum-Guided Adaptive Recursion |
|
|
Adds progressive depth curriculum for faster, more stable training |
|
|
""" |
|
|
import math |
|
|
from typing import Tuple, List, Dict, Optional |
|
|
import torch |
|
|
from models.recursive_reasoning.trm import ( |
|
|
TinyRecursiveReasoningModel_ACTV1, |
|
|
TinyRecursiveReasoningModel_ACTV1Config, |
|
|
TinyRecursiveReasoningModel_ACTV1_Inner, |
|
|
TinyRecursiveReasoningModel_ACTV1Carry, |
|
|
TinyRecursiveReasoningModel_ACTV1InnerCarry |
|
|
) |
|
|
|
|
|
|
|
|
class TinyRecursiveReasoningModel_ACTV1_Inner_CGAR(TinyRecursiveReasoningModel_ACTV1_Inner):
    """TRM inner model with dynamic depth control for a progressive curriculum.

    Extends the base inner model so the number of H/L recursion cycles can
    be reduced early in training and ramped up to the configured depth as
    training progresses (Curriculum-Guided Adaptive Recursion).
    """

    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__(config)

        # Full (target) recursion depths taken from the config.
        self.base_H_cycles = config.H_cycles
        self.base_L_cycles = config.L_cycles

        # Depths currently in effect during training; start at full depth
        # until the curriculum scheduler lowers them.
        self.current_H_cycles = config.H_cycles
        self.current_L_cycles = config.L_cycles

    def set_curriculum_depth(self, progress: float) -> None:
        """Set recursion depth based on training progress.

        Three-stage schedule: roughly one third of the full depth before 30%
        progress, two thirds before 60%, and the full configured depth after.

        Args:
            progress: Training progress from 0.0 (start) to 1.0 (end).
        """
        # Stage depths, clamped so a stage never exceeds the configured base
        # depth. Fix: stage 1 previously lacked the min(base, ...) clamp, so
        # the floor of 2 L-cycles could exceed base_L_cycles when the config
        # requested fewer than 2 L-cycles.
        stage1_H = min(self.base_H_cycles, max(1, math.ceil(self.base_H_cycles / 3)))
        stage1_L = min(self.base_L_cycles, max(2, math.ceil(self.base_L_cycles / 3)))
        stage2_H = min(self.base_H_cycles, max(stage1_H, math.ceil(2 * self.base_H_cycles / 3)))
        stage2_L = min(self.base_L_cycles, max(stage1_L, math.ceil(2 * self.base_L_cycles / 3)))

        if progress < 0.3:
            # Stage 1: shallow recursion for fast, stable early training.
            self.current_H_cycles = stage1_H
            self.current_L_cycles = stage1_L
        elif progress < 0.6:
            # Stage 2: intermediate depth.
            self.current_H_cycles = stage2_H
            self.current_L_cycles = stage2_L
        else:
            # Stage 3: full configured depth.
            self.current_H_cycles = self.base_H_cycles
            self.current_L_cycles = self.base_L_cycles

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run one recursion pass at the currently active depth.

        Args:
            carry: Inner carry holding the z_H / z_L latent states.
            batch: Must contain "inputs" and "puzzle_identifiers" tensors.

        Returns:
            Tuple of (new detached carry, LM output logits,
            (q_halt_logits, q_continue_logits)).
        """
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        z_H, z_L = carry.z_H, carry.z_L

        # Curriculum depths apply only while training; evaluation always
        # runs at the full configured depth.
        H_cycles = self.current_H_cycles if self.training else self.base_H_cycles
        L_cycles = self.current_L_cycles if self.training else self.base_L_cycles

        # All but the final H-cycle run without gradients; only the last
        # cycle below is backpropagated through (1-step gradient scheme).
        with torch.no_grad():
            for _H_step in range(H_cycles - 1):
                for _L_step in range(L_cycles):
                    z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                z_H = self.L_level(z_H, z_L, **seq_info)

        # Final H-cycle with gradients enabled.
        for _L_step in range(L_cycles):
            z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.L_level(z_H, z_L, **seq_info)

        # Detach the carry so gradients do not flow across ACT steps.
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
        # Strip puzzle-embedding positions from the LM output.
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        # Q-head reads the first position; float32 for numerically stable loss.
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
|
|
|
|
|
|
|
class TinyRecursiveReasoningModel_ACTV1_CGAR(TinyRecursiveReasoningModel_ACTV1):
    """ACT wrapper that pairs the base model with the CGAR inner module.

    Identical interface to the base ACT model, plus a hook for the
    training loop to adjust recursion depth as training progresses.
    """

    def __init__(self, config_dict: dict):
        # Deliberately start the MRO walk past the immediate parent so its
        # __init__ is skipped (it presumably builds the standard inner
        # model — TODO confirm) and the CGAR inner is installed instead.
        super(TinyRecursiveReasoningModel_ACTV1, self).__init__()

        parsed_config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.config = parsed_config
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner_CGAR(parsed_config)

    def set_curriculum_depth(self, progress: float):
        """Delegate the curriculum progress update to the inner model."""
        self.inner.set_curriculum_depth(progress)
|
|
|