""" Chess Tiny Recursive Model v2 (TRM2) with Deep Recursion and Latent Updates. Advanced architecture featuring: - Causal self-attention with RoPE - Deep recursion steps with progressive refinement - Recursive latent state updates - Adaptive computation with learned halting - Cross-recursion attention for information flow Target: <1M parameters with superior recursive reasoning """ from __future__ import annotations import math from dataclasses import dataclass from typing import Optional, Tuple, Union, List import torch import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast class ChessConfig(PretrainedConfig): """ Configuration for Chess TRM2 model with deep recursion. Optimized for ~950K parameters with advanced recursion. """ # model_type = "chess_trm2" model_type = "chess_transformer" def __init__( self, vocab_size: int = 800, n_embd: int = 192, n_layer: int = 2, n_head: int = 4, n_ctx: int = 256, n_inner: Optional[int] = None, dropout: float = 0.1, attention_dropout: float = 0.1, layer_norm_epsilon: float = 1e-5, tie_weights: bool = True, # Deep recursion parameters n_recursions: int = 6, # Deep recursion steps latent_dim: int = 64, # Latent state dimension use_adaptive_depth: bool = True, # Learn when to stop recursion halting_threshold: float = 0.9, # Threshold for adaptive halting # RoPE parameters use_rope: bool = True, rope_theta: float = 10000.0, # Training parameters label_smoothing: float = 0.1, auxiliary_loss_weight: float = 0.1, # Weight for auxiliary losses pad_token_id: int = 0, bos_token_id: int = 1, eos_token_id: int = 2, **kwargs, ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs, ) self.vocab_size = vocab_size self.n_embd = n_embd self.n_layer = n_layer self.n_head = n_head self.n_ctx = n_ctx self.n_inner = n_inner if n_inner is not None else int(2.33 * n_embd) self.dropout = dropout self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon self.tie_weights = tie_weights self.n_recursions = n_recursions self.latent_dim = latent_dim self.use_adaptive_depth = use_adaptive_depth self.halting_threshold = halting_threshold self.use_rope = use_rope self.rope_theta = rope_theta self.label_smoothing = label_smoothing self.auxiliary_loss_weight = auxiliary_loss_weight # ============================================================================ # Core Building Blocks # ============================================================================ class RotaryEmbedding(nn.Module): """Rotary Position Embedding (RoPE) for position-aware attention.""" def __init__(self, dim: int, max_seq_len: int = 512, theta: float = 10000.0): super().__init__() self.dim = dim self.max_seq_len = max_seq_len self.theta = theta inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._build_cache(max_seq_len) def _build_cache(self, seq_len: int): t = torch.arange(seq_len, device=self.inv_freq.device) freqs = torch.outer(t, self.inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer("cos_cached", emb.cos(), persistent=False) self.register_buffer("sin_cached", emb.sin(), persistent=False) def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: if seq_len > self.max_seq_len: self._build_cache(seq_len) return ( self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype), ) def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate half the hidden dims.""" x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to query and key tensors.""" cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class CausalSelfAttention(nn.Module): """Causal self-attention with RoPE support.""" def __init__(self, config: ChessConfig): super().__init__() self.n_head = config.n_head self.n_embd = config.n_embd self.head_dim = config.n_embd // config.n_head assert config.n_embd % config.n_head == 0 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.attn_dropout = nn.Dropout(config.attention_dropout) self.resid_dropout = nn.Dropout(config.dropout) if config.use_rope: self.rope = RotaryEmbedding(self.head_dim, config.n_ctx, config.rope_theta) else: self.rope = None self.register_buffer( "causal_mask", torch.tril(torch.ones(config.n_ctx, config.n_ctx, dtype=torch.bool)), persistent=False, ) def forward( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: B, T, C = x.size() qkv = self.c_attn(x) q, k, v = qkv.split(self.n_embd, dim=-1) q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) if self.rope is not None: cos, sin = self.rope(x, T) q, k = apply_rotary_pos_emb(q, k, cos, sin) scale = 1.0 / math.sqrt(self.head_dim) attn = torch.matmul(q, k.transpose(-2, -1)) * scale causal_mask = self.causal_mask[:T, :T] attn = attn.masked_fill(~causal_mask, float("-inf")) if attention_mask is not None: attn_mask = attention_mask.unsqueeze(1).unsqueeze(2) attn = attn.masked_fill(attn_mask == 0, float("-inf")) attn = F.softmax(attn, dim=-1) attn = self.attn_dropout(attn) y = torch.matmul(attn, v) y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.resid_dropout(self.c_proj(y)) return y class SwiGLUFFN(nn.Module): """SwiGLU Feed-Forward Network for improved performance.""" def __init__(self, config: ChessConfig): super().__init__() self.c_fc = nn.Linear(config.n_embd, config.n_inner, bias=False) self.c_gate = nn.Linear(config.n_embd, config.n_inner, bias=False) self.c_proj = nn.Linear(config.n_inner, config.n_embd, bias=False) self.dropout = nn.Dropout(config.dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: gate = torch.sigmoid(self.c_gate(x)) h = F.silu(self.c_fc(x)) * gate return self.dropout(self.c_proj(h)) # ============================================================================ # Recursive Latent Update Components # ============================================================================ class LatentStateEncoder(nn.Module): """ Encodes sequence hidden states into a compact latent representation. This latent captures global context for recursive refinement. """ def __init__(self, config: ChessConfig): super().__init__() self.n_embd = config.n_embd self.latent_dim = config.latent_dim # Attention pooling to create sequence-level latent self.query = nn.Parameter(torch.randn(1, 1, config.n_embd) * 0.02) self.attn = nn.MultiheadAttention( config.n_embd, num_heads=config.n_head, dropout=config.attention_dropout, batch_first=True, ) # Project to latent space self.proj = nn.Sequential( nn.Linear(config.n_embd, config.latent_dim), nn.SiLU(), nn.Linear(config.latent_dim, config.latent_dim), ) self.ln = nn.LayerNorm(config.latent_dim, eps=config.layer_norm_epsilon) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Args: hidden_states: [B, T, n_embd] Returns: latent: [B, latent_dim] """ B = hidden_states.size(0) query = self.query.expand(B, -1, -1) # Cross-attention to pool sequence pooled, _ = self.attn(query, hidden_states, hidden_states) pooled = pooled.squeeze(1) # [B, n_embd] # Project to latent space latent = self.proj(pooled) return self.ln(latent) class LatentStateUpdater(nn.Module): """ Updates the latent state recursively. Implements a GRU-like update mechanism for stable recursion. """ def __init__(self, config: ChessConfig): super().__init__() self.latent_dim = config.latent_dim # GRU-style update gates self.update_gate = nn.Linear(config.latent_dim * 2, config.latent_dim) self.reset_gate = nn.Linear(config.latent_dim * 2, config.latent_dim) self.candidate = nn.Linear(config.latent_dim * 2, config.latent_dim) # Layer norm for stability self.ln = nn.LayerNorm(config.latent_dim, eps=config.layer_norm_epsilon) def forward( self, latent: torch.Tensor, new_info: torch.Tensor ) -> torch.Tensor: """ GRU-style update of latent state. Args: latent: Previous latent state [B, latent_dim] new_info: New information from current recursion [B, latent_dim] Returns: updated_latent: [B, latent_dim] """ combined = torch.cat([latent, new_info], dim=-1) # Compute gates z = torch.sigmoid(self.update_gate(combined)) # Update gate r = torch.sigmoid(self.reset_gate(combined)) # Reset gate # Compute candidate reset_latent = torch.cat([r * latent, new_info], dim=-1) h_tilde = torch.tanh(self.candidate(reset_latent)) # Update latent updated = (1 - z) * latent + z * h_tilde return self.ln(updated) class LatentConditioner(nn.Module): """ Conditions the hidden states using the latent representation. Injects global context into local representations. """ def __init__(self, config: ChessConfig): super().__init__() # FiLM-style conditioning (Feature-wise Linear Modulation) self.scale = nn.Linear(config.latent_dim, config.n_embd) self.shift = nn.Linear(config.latent_dim, config.n_embd) def forward( self, hidden_states: torch.Tensor, latent: torch.Tensor ) -> torch.Tensor: """ Apply FiLM conditioning. Args: hidden_states: [B, T, n_embd] latent: [B, latent_dim] Returns: conditioned: [B, T, n_embd] """ # Compute scale and shift from latent gamma = self.scale(latent).unsqueeze(1) # [B, 1, n_embd] beta = self.shift(latent).unsqueeze(1) # [B, 1, n_embd] # Apply FiLM: y = gamma * x + beta return gamma * hidden_states + beta # ============================================================================ # Deep Recursive Transformer Block # ============================================================================ class DeepRecursiveBlock(nn.Module): """ A transformer block designed for deep recursion with latent conditioning. Each recursion step: 1. Conditions hidden states with current latent 2. Applies attention and FFN 3. Updates latent based on new hidden states """ def __init__(self, config: ChessConfig): super().__init__() # Pre-norm layers self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) # Core attention and FFN self.attn = CausalSelfAttention(config) self.ffn = SwiGLUFFN(config) # Latent conditioning self.conditioner = LatentConditioner(config) # Learnable residual gates (per-block, shared across recursions) self.gate_attn = nn.Parameter(torch.ones(1) * 0.5) self.gate_ffn = nn.Parameter(torch.ones(1) * 0.5) self.gate_latent = nn.Parameter(torch.ones(1) * 0.3) def forward( self, x: torch.Tensor, latent: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: x: Hidden states [B, T, n_embd] latent: Latent state [B, latent_dim] attention_mask: Optional attention mask Returns: Updated hidden states [B, T, n_embd] """ # Condition with latent (soft injection) gate_l = torch.sigmoid(self.gate_latent) x_cond = x + gate_l * (self.conditioner(x, latent) - x) # Attention with gated residual gate_a = torch.sigmoid(self.gate_attn) h = x_cond + gate_a * self.attn(self.ln_1(x_cond), attention_mask) # FFN with gated residual gate_f = torch.sigmoid(self.gate_ffn) h = h + gate_f * self.ffn(self.ln_2(h)) return h class AdaptiveHaltingModule(nn.Module): """ Learns when to stop recursion (Adaptive Computation Time inspired). Outputs a halting probability at each recursion step. """ def __init__(self, config: ChessConfig): super().__init__() self.threshold = config.halting_threshold # Halting predictor from latent state self.halt_predictor = nn.Sequential( nn.Linear(config.latent_dim, config.latent_dim // 2), nn.SiLU(), nn.Linear(config.latent_dim // 2, 1), nn.Sigmoid(), ) def forward(self, latent: torch.Tensor) -> torch.Tensor: """ Predict halting probability. Args: latent: [B, latent_dim] Returns: halt_prob: [B, 1] """ return self.halt_predictor(latent) # ============================================================================ # Main Model # ============================================================================ class ChessForCausalLM(PreTrainedModel): """ Chess Tiny Recursive Model v2 with Deep Recursion and Latent Updates. Architecture Overview: 1. Embed input tokens 2. Initialize latent state 3. For each recursion step: a. Condition hidden states with latent b. Apply transformer blocks c. Encode new latent from hidden states d. Update latent with GRU-style mechanism e. (Optional) Check adaptive halting 4. Final prediction Key Features: - Deep recursion (6+ steps) with shared weights - Recursive latent state that accumulates global context - FiLM conditioning for latent injection - Optional adaptive computation depth - Auxiliary losses for better latent learning """ config_class = ChessConfig base_model_prefix = "trm2" supports_gradient_checkpointing = True _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: ChessConfig): super().__init__(config) # Token embeddings self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.drop = nn.Dropout(config.dropout) # Position embeddings (fallback if RoPE disabled) if not config.use_rope: self.wpe = nn.Embedding(config.n_ctx, config.n_embd) # Initial latent state (learned) self.init_latent = nn.Parameter(torch.randn(config.latent_dim) * 0.02) # Latent processing modules self.latent_encoder = LatentStateEncoder(config) self.latent_updater = LatentStateUpdater(config) # Deep recursive transformer blocks (shared across recursions) self.blocks = nn.ModuleList([ DeepRecursiveBlock(config) for _ in range(config.n_layer) ]) # Adaptive halting (optional) if config.use_adaptive_depth: self.halting = AdaptiveHaltingModule(config) else: self.halting = None # Recursion step embeddings (helps differentiate recursion stages) self.recursion_emb = nn.Embedding(config.n_recursions, config.latent_dim) # Final layers self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # Auxiliary prediction head (predicts from latent for regularization) self.aux_head = nn.Sequential( nn.Linear(config.latent_dim, config.n_embd), nn.SiLU(), nn.Linear(config.n_embd, config.vocab_size, bias=False), ) # Initialize weights self.post_init() def get_input_embeddings(self): return self.wte def set_input_embeddings(self, value): self.wte = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def _init_weights(self, module: nn.Module): """Initialize weights with careful scaling for deep recursion.""" if isinstance(module, nn.Linear): # Smaller init for stable deep recursion std = 0.02 / math.sqrt(2 * self.config.n_layer * self.config.n_recursions) torch.nn.init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, nn.LayerNorm): torch.nn.init.ones_(module.weight) torch.nn.init.zeros_(module.bias) def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_recursion_states: bool = False, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Forward pass with deep recursion and latent updates. Args: input_ids: Input token IDs [B, T] attention_mask: Attention mask [B, T] position_ids: Position IDs [B, T] labels: Target labels for loss computation return_dict: Whether to return a dict output_recursion_states: Whether to output intermediate states """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_len = input_ids.size() device = input_ids.device # Get token embeddings hidden_states = self.wte(input_ids) # Add position embeddings if not using RoPE if not self.config.use_rope: if position_ids is None: position_ids = torch.arange(seq_len, device=device).unsqueeze(0) hidden_states = hidden_states + self.wpe(position_ids) hidden_states = self.drop(hidden_states) # Initialize latent state (broadcast to batch) latent = self.init_latent.unsqueeze(0).expand(batch_size, -1) # Track recursion states for analysis/aux loss recursion_states = [] halting_probs = [] cumulative_halt = torch.zeros(batch_size, 1, device=device) # Deep recursion loop for r in range(self.config.n_recursions): # Add recursion step embedding to latent rec_emb = self.recursion_emb(torch.tensor(r, device=device)) latent_r = latent + rec_emb # Apply transformer blocks with latent conditioning for block in self.blocks: hidden_states = block(hidden_states, latent_r, attention_mask) # Encode new information from hidden states new_info = self.latent_encoder(hidden_states) # Update latent state recursively latent = self.latent_updater(latent, new_info) # Track state if output_recursion_states: recursion_states.append(hidden_states.clone()) # Adaptive halting check if self.halting is not None and self.training: halt_prob = self.halting(latent) halting_probs.append(halt_prob) cumulative_halt = cumulative_halt + halt_prob * (1 - cumulative_halt) # Early stopping during inference if not self.training and (cumulative_halt > self.config.halting_threshold).all(): break # Final layer norm and prediction hidden_states = self.ln_f(hidden_states) logits = self.lm_head(hidden_states) # Compute loss loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Main cross-entropy loss loss_fct = nn.CrossEntropyLoss( ignore_index=-100, label_smoothing=self.config.label_smoothing if self.training else 0.0, ) main_loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ) # Auxiliary loss: predict from final latent (regularization) aux_logits = self.aux_head(latent) # [B, vocab_size] # Use the last token as target for aux prediction last_token_mask = (labels != -100).sum(dim=-1) - 1 last_tokens = labels[torch.arange(batch_size, device=device), last_token_mask.clamp(min=0)] aux_loss = F.cross_entropy(aux_logits, last_tokens, ignore_index=-100) # Ponder cost (ACT regularization) - encourages early halting ponder_loss = torch.tensor(0.0, device=device) if self.halting is not None and len(halting_probs) > 0: ponder_cost = sum(p.mean() for p in halting_probs) / len(halting_probs) ponder_loss = (1.0 - ponder_cost) # Penalize late halting # Combined loss loss = main_loss + self.config.auxiliary_loss_weight * (aux_loss + ponder_loss) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=None, hidden_states=recursion_states if output_recursion_states else None, attentions=None, ) @torch.no_grad() def generate_move( self, input_ids: torch.LongTensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, ) -> int: """Generate the next move with deep recursive reasoning.""" self.eval() outputs = self(input_ids) logits = outputs.logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = float("-inf") if top_p is not None: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter( dim=-1, index=sorted_indices, src=sorted_indices_to_remove ) logits[indices_to_remove] = float("-inf") probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) return next_token.item() @torch.no_grad() def get_recursion_analysis( self, input_ids: torch.LongTensor, ) -> dict: """ Analyze the recursion process for interpretability. Returns intermediate states and halting probabilities. """ self.eval() batch_size, seq_len = input_ids.size() device = input_ids.device hidden_states = self.wte(input_ids) hidden_states = self.drop(hidden_states) latent = self.init_latent.unsqueeze(0).expand(batch_size, -1) analysis = { "latent_states": [latent.clone()], "hidden_norms": [], "halting_probs": [], } for r in range(self.config.n_recursions): rec_emb = self.recursion_emb(torch.tensor(r, device=device)) latent_r = latent + rec_emb for block in self.blocks: hidden_states = block(hidden_states, latent_r, None) new_info = self.latent_encoder(hidden_states) latent = self.latent_updater(latent, new_info) analysis["latent_states"].append(latent.clone()) analysis["hidden_norms"].append(hidden_states.norm(dim=-1).mean().item()) if self.halting is not None: halt_prob = self.halting(latent) analysis["halting_probs"].append(halt_prob.mean().item()) return analysis # Register the model from transformers import AutoConfig, AutoModelForCausalLM AutoConfig.register("chess_transformer", ChessConfig) AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)