| | """ |
| | Chess Tiny Recursive Model v2 (TRM2) with Deep Recursion and Latent Updates. |
| | |
| | Advanced architecture featuring: |
| | - Causal self-attention with RoPE |
| | - Deep recursion steps with progressive refinement |
| | - Recursive latent state updates |
| | - Adaptive computation with learned halting |
| | - Cross-recursion attention for information flow |
| | |
| | Target: <1M parameters with superior recursive reasoning |
| | """ |
| | from __future__ import annotations |
| | import math |
| | from dataclasses import dataclass |
| | from typing import Optional, Tuple, Union, List |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PretrainedConfig, PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
| |
|
class ChessConfig(PretrainedConfig):
    """
    Configuration for the Chess TRM2 model with deep recursion.

    Defaults target roughly 950K parameters with advanced recursion.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 800,
        n_embd: int = 192,
        n_layer: int = 2,
        n_head: int = 4,
        n_ctx: int = 256,
        n_inner: Optional[int] = None,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        tie_weights: bool = True,

        n_recursions: int = 6,
        latent_dim: int = 64,
        use_adaptive_depth: bool = True,
        halting_threshold: float = 0.9,

        use_rope: bool = True,
        rope_theta: float = 10000.0,

        label_smoothing: float = 0.1,
        auxiliary_loss_weight: float = 0.1,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        # Core transformer shape.
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        # FFN width defaults to ~2.33x the model width.
        self.n_inner = int(2.33 * n_embd) if n_inner is None else n_inner
        # Regularization.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        # NOTE(review): this flag is only stored here; HF weight tying is
        # driven by `tie_word_embeddings` — confirm the two stay in sync.
        self.tie_weights = tie_weights
        # Recursion / latent-state settings.
        self.n_recursions = n_recursions
        self.latent_dim = latent_dim
        self.use_adaptive_depth = use_adaptive_depth
        self.halting_threshold = halting_threshold
        # Positional encoding (RoPE).
        self.use_rope = use_rope
        self.rope_theta = rope_theta
        # Loss shaping.
        self.label_smoothing = label_smoothing
        self.auxiliary_loss_weight = auxiliary_loss_weight
| |
|
| |
|
| | |
| | |
| | |
| |
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) for position-aware attention.

    Precomputes cos/sin tables for positions [0, max_seq_len) and lazily
    grows the cache when a longer sequence is seen.
    """

    def __init__(self, dim: int, max_seq_len: int = 512, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Standard RoPE inverse frequencies: theta^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build cos/sin tables of shape [seq_len, dim]."""
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)
        # BUGFIX: record the new cache size. The original never updated
        # max_seq_len, so once a sequence exceeded the initial cache the
        # tables were rebuilt on every forward call.
        self.max_seq_len = seq_len

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape [seq_len, dim] cast to x's dtype."""
        if seq_len > self.max_seq_len:
            self._build_cache(seq_len)
        return (
            self.cos_cached[:seq_len].to(x.dtype),
            self.sin_cached[:seq_len].to(x.dtype),
        )
| |
|
| |
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims: [a, b] -> [-b, a] on the last axis."""
    first, second = torch.tensor_split(x, 2, dim=-1)
    return torch.cat((-second, first), dim=-1)
| |
|
| |
|
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to query and key tensors.

    cos/sin are [T, head_dim] tables, broadcast over the batch and head
    dimensions of q/k ([B, H, T, head_dim]).
    """
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Inlined rotate_half: [a, b] -> [-b, a] on the last dimension.
        t1, t2 = t.chunk(2, dim=-1)
        return (t * cos) + (torch.cat([-t2, t1], dim=-1) * sin)

    return _rotate(q), _rotate(k)
| |
|
| |
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional RoPE."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        assert config.n_embd % config.n_head == 0

        # Fused QKV projection and output projection (both bias-free).
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.rope = (
            RotaryEmbedding(self.head_dim, config.n_ctx, config.rope_theta)
            if config.use_rope
            else None
        )

        # Lower-triangular boolean mask, sliced to the runtime sequence length.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(config.n_ctx, config.n_ctx, dtype=torch.bool)),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend causally over x ([B, T, C]); returns [B, T, C]."""
        bsz, seqlen, dim = x.size()

        def _heads(t: torch.Tensor) -> torch.Tensor:
            # [B, T, C] -> [B, n_head, T, head_dim]
            return t.view(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = (_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=-1))

        if self.rope is not None:
            cos, sin = self.rope(x, seqlen)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product scores with the causal mask applied.
        scores = torch.matmul(q, k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        scores = scores.masked_fill(~self.causal_mask[:seqlen, :seqlen], float("-inf"))

        if attention_mask is not None:
            # NOTE(review): a fully-masked row softmaxes over all -inf and
            # yields NaN; right-padded batches avoid this via the causal
            # pattern, left padding would not — confirm padding convention.
            pad = attention_mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(pad == 0, float("-inf"))

        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(weights, v).transpose(1, 2).contiguous().view(bsz, seqlen, dim)
        return self.resid_dropout(self.c_proj(out))
| |
|
| |
|
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network (Shazeer, "GLU Variants Improve Transformer").

    Computes c_proj(SiLU(c_gate(x)) * c_fc(x)).

    The original implementation squashed the gate branch through a plain
    sigmoid while applying SiLU to the value branch, double-gating the
    activation (SiLU already contains a sigmoid) and not matching the
    SwiGLU formulation the class name advertises. Parameter shapes and
    the module interface are unchanged.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner, bias=False)    # value branch
        self.c_gate = nn.Linear(config.n_embd, config.n_inner, bias=False)  # gate branch
        self.c_proj = nn.Linear(config.n_inner, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """[B, T, n_embd] -> [B, T, n_embd]."""
        h = F.silu(self.c_gate(x)) * self.c_fc(x)
        return self.dropout(self.c_proj(h))
| |
|
| |
|
| | |
| | |
| | |
| |
|
class LatentStateEncoder(nn.Module):
    """
    Pools sequence hidden states into a compact latent vector.

    A single learned query token attends over the whole sequence, and the
    pooled result is projected down to latent_dim. The latent captures
    global context for recursive refinement.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.n_embd = config.n_embd
        self.latent_dim = config.latent_dim

        # One learned query token that attends over every position.
        self.query = nn.Parameter(torch.randn(1, 1, config.n_embd) * 0.02)
        self.attn = nn.MultiheadAttention(
            config.n_embd,
            num_heads=config.n_head,
            dropout=config.attention_dropout,
            batch_first=True,
        )

        # Down-projection into the latent space.
        self.proj = nn.Sequential(
            nn.Linear(config.n_embd, config.latent_dim),
            nn.SiLU(),
            nn.Linear(config.latent_dim, config.latent_dim),
        )
        self.ln = nn.LayerNorm(config.latent_dim, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, T, n_embd]
        Returns:
            latent: [B, latent_dim]
        """
        batch = hidden_states.size(0)
        probe = self.query.expand(batch, -1, -1)
        pooled, _ = self.attn(probe, hidden_states, hidden_states)
        return self.ln(self.proj(pooled.squeeze(1)))
| |
|
| |
|
class LatentStateUpdater(nn.Module):
    """
    GRU-style recurrent update of the latent state across recursions,
    keeping the recursion numerically stable.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.latent_dim = config.latent_dim

        # All three gates operate on [latent ; new_info].
        self.update_gate = nn.Linear(config.latent_dim * 2, config.latent_dim)
        self.reset_gate = nn.Linear(config.latent_dim * 2, config.latent_dim)
        self.candidate = nn.Linear(config.latent_dim * 2, config.latent_dim)

        self.ln = nn.LayerNorm(config.latent_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        latent: torch.Tensor,
        new_info: torch.Tensor
    ) -> torch.Tensor:
        """
        GRU-style update of the latent state.

        Args:
            latent: Previous latent state [B, latent_dim]
            new_info: New information from the current recursion [B, latent_dim]
        Returns:
            updated_latent: [B, latent_dim]
        """
        both = torch.cat((latent, new_info), dim=-1)
        z = self.update_gate(both).sigmoid()  # update gate
        r = self.reset_gate(both).sigmoid()   # reset gate

        # Candidate state from the reset-scaled latent plus the new info.
        h_new = self.candidate(torch.cat((r * latent, new_info), dim=-1)).tanh()

        # Convex blend of old state and candidate, then normalize.
        blended = (1 - z) * latent + z * h_new
        return self.ln(blended)
| |
|
| |
|
class LatentConditioner(nn.Module):
    """
    FiLM conditioning: modulates per-token hidden states with a
    latent-derived scale (gamma) and shift (beta), injecting global
    context into local representations.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.scale = nn.Linear(config.latent_dim, config.n_embd)
        self.shift = nn.Linear(config.latent_dim, config.n_embd)

    def forward(
        self,
        hidden_states: torch.Tensor,
        latent: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply FiLM conditioning.

        Args:
            hidden_states: [B, T, n_embd]
            latent: [B, latent_dim]
        Returns:
            conditioned: [B, T, n_embd]
        """
        # Broadcast the per-example gamma/beta over the time dimension.
        gamma = self.scale(latent)[:, None, :]
        beta = self.shift(latent)[:, None, :]
        return hidden_states * gamma + beta
| |
|
| |
|
| | |
| | |
| | |
| |
|
class DeepRecursiveBlock(nn.Module):
    """
    Transformer block designed for deep recursion with latent conditioning.

    Per recursion step: FiLM-condition the hidden states on the latent,
    then run gated attention and FFN sublayers in pre-norm residual form.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.attn = CausalSelfAttention(config)
        self.ffn = SwiGLUFFN(config)
        self.conditioner = LatentConditioner(config)

        # Learnable scalar gates, squashed through sigmoid at run time.
        self.gate_attn = nn.Parameter(torch.ones(1) * 0.5)
        self.gate_ffn = nn.Parameter(torch.ones(1) * 0.5)
        self.gate_latent = nn.Parameter(torch.ones(1) * 0.3)

    def forward(
        self,
        x: torch.Tensor,
        latent: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Hidden states [B, T, n_embd]
            latent: Latent state [B, latent_dim]
            attention_mask: Optional attention mask
        Returns:
            Updated hidden states [B, T, n_embd]
        """
        # Interpolate between x and its latent-conditioned version.
        g_lat = torch.sigmoid(self.gate_latent)
        x = x + g_lat * (self.conditioner(x, latent) - x)

        # Gated attention sublayer (pre-norm residual).
        x = x + torch.sigmoid(self.gate_attn) * self.attn(self.ln_1(x), attention_mask)

        # Gated feed-forward sublayer (pre-norm residual).
        x = x + torch.sigmoid(self.gate_ffn) * self.ffn(self.ln_2(x))
        return x
| |
|
| |
|
class AdaptiveHaltingModule(nn.Module):
    """
    ACT-inspired halting head: maps the latent state to a halting
    probability in (0, 1) so recursion can stop early.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.threshold = config.halting_threshold

        # Small MLP bottlenecked to half the latent width, sigmoid output.
        layers = [
            nn.Linear(config.latent_dim, config.latent_dim // 2),
            nn.SiLU(),
            nn.Linear(config.latent_dim // 2, 1),
            nn.Sigmoid(),
        ]
        self.halt_predictor = nn.Sequential(*layers)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Predict halting probability.

        Args:
            latent: [B, latent_dim]
        Returns:
            halt_prob: [B, 1]
        """
        return self.halt_predictor(latent)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ChessForCausalLM(PreTrainedModel):
    """
    Chess Tiny Recursive Model v2 with Deep Recursion and Latent Updates.

    Architecture Overview:
        1. Embed input tokens
        2. Initialize latent state
        3. For each recursion step:
           a. Condition hidden states with latent
           b. Apply transformer blocks
           c. Encode new latent from hidden states
           d. Update latent with GRU-style mechanism
           e. (Optional) Check adaptive halting
        4. Final prediction

    Key Features:
        - Deep recursion with shared block weights
        - Recursive latent state that accumulates global context
        - FiLM conditioning for latent injection
        - Optional adaptive computation depth (now active at inference;
          the original guarded halting behind `self.training`, making the
          early-exit branch unreachable)
        - Auxiliary losses for better latent learning
    """
    config_class = ChessConfig
    base_model_prefix = "trm2"
    supports_gradient_checkpointing = True
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ChessConfig):
        super().__init__(config)

        # Token embedding (+ learned positions only when RoPE is disabled).
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        if not config.use_rope:
            self.wpe = nn.Embedding(config.n_ctx, config.n_embd)

        # Learned initial latent state, shared across the batch.
        self.init_latent = nn.Parameter(torch.randn(config.latent_dim) * 0.02)

        # Latent read (encoder) and write (GRU-style updater) machinery.
        self.latent_encoder = LatentStateEncoder(config)
        self.latent_updater = LatentStateUpdater(config)

        # Shared transformer blocks, applied at every recursion step.
        self.blocks = nn.ModuleList([
            DeepRecursiveBlock(config) for _ in range(config.n_layer)
        ])

        # Optional adaptive-depth halting head.
        if config.use_adaptive_depth:
            self.halting = AdaptiveHaltingModule(config)
        else:
            self.halting = None

        # Per-recursion-step embedding added to the latent as a depth tag.
        self.recursion_emb = nn.Embedding(config.n_recursions, config.latent_dim)

        # Output head.
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Auxiliary head: predicts the final supervised token from the latent alone.
        self.aux_head = nn.Sequential(
            nn.Linear(config.latent_dim, config.n_embd),
            nn.SiLU(),
            nn.Linear(config.n_embd, config.vocab_size, bias=False),
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _init_weights(self, module: nn.Module):
        """Initialize weights, down-scaling linears for effective depth."""
        if isinstance(module, nn.Linear):
            # Residual-stream scaling: n_layer blocks are applied
            # n_recursions times with shared weights.
            std = 0.02 / math.sqrt(2 * self.config.n_layer * self.config.n_recursions)
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_recursion_states: bool = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass with deep recursion and latent updates.

        Args:
            input_ids: Input token IDs [B, T]
            attention_mask: Attention mask [B, T]
            position_ids: Position IDs [B, T] (only used when RoPE is off)
            labels: Target labels for loss computation (-100 = ignore)
            return_dict: Whether to return a ModelOutput
            output_recursion_states: Whether to collect per-recursion hidden states
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        hidden_states = self.wte(input_ids)
        if not self.config.use_rope:
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            hidden_states = hidden_states + self.wpe(position_ids)
        hidden_states = self.drop(hidden_states)

        latent = self.init_latent.unsqueeze(0).expand(batch_size, -1)

        recursion_states = []
        halting_probs = []
        cumulative_halt = torch.zeros(batch_size, 1, device=device)

        for r in range(self.config.n_recursions):
            # Tag the latent with the current recursion depth.
            rec_emb = self.recursion_emb(torch.tensor(r, device=device))
            latent_r = latent + rec_emb

            for block in self.blocks:
                hidden_states = block(hidden_states, latent_r, attention_mask)

            # Read out new global info and fold it into the latent.
            new_info = self.latent_encoder(hidden_states)
            latent = self.latent_updater(latent, new_info)

            if output_recursion_states:
                recursion_states.append(hidden_states.clone())

            if self.halting is not None:
                halt_prob = self.halting(latent)
                # Probability of having halted by step r (geometric-style
                # accumulation: halt now given we have not halted yet).
                cumulative_halt = cumulative_halt + halt_prob * (1 - cumulative_halt)
                if self.training:
                    # Keep per-step probabilities for the ponder loss.
                    halting_probs.append(halt_prob)
                elif (cumulative_halt > self.config.halting_threshold).all():
                    # BUGFIX: the original guarded this entire branch with
                    # `self.training`, so inference could never halt early.
                    break

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels, latent, halting_probs, device)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=recursion_states if output_recursion_states else None,
            attentions=None,
        )

    def _compute_loss(self, logits, labels, latent, halting_probs, device):
        """Main next-token CE + auxiliary latent CE + ponder regularizer."""
        batch_size, seq_len = labels.size()

        # Standard shifted next-token cross-entropy.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = nn.CrossEntropyLoss(
            ignore_index=-100,
            label_smoothing=self.config.label_smoothing if self.training else 0.0,
        )
        main_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        # Auxiliary loss: predict the last supervised token from the latent.
        aux_logits = self.aux_head(latent)
        positions = torch.arange(seq_len, device=device).expand_as(labels)
        # BUGFIX: the original used `(labels != -100).sum(-1) - 1`, which
        # only finds the last valid label when valid labels form a
        # contiguous prefix. Take the max position of a valid label instead;
        # all-ignored rows fall back to position 0, whose -100 label is then
        # skipped via ignore_index.
        last_idx = torch.where(
            labels != -100, positions, torch.full_like(positions, -1)
        ).max(dim=-1).values
        last_tokens = labels[
            torch.arange(batch_size, device=device), last_idx.clamp(min=0)
        ]
        aux_loss = F.cross_entropy(aux_logits, last_tokens, ignore_index=-100)

        # Ponder cost: encourage confident (early) halting during training.
        ponder_loss = torch.tensor(0.0, device=device)
        if self.halting is not None and halting_probs:
            ponder_cost = sum(p.mean() for p in halting_probs) / len(halting_probs)
            ponder_loss = 1.0 - ponder_cost

        return main_loss + self.config.auxiliary_loss_weight * (aux_loss + ponder_loss)

    @torch.no_grad()
    def generate_move(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> int:
        """Sample the next move token id with optional top-k / nucleus filtering."""
        self.eval()

        outputs = self(input_ids)
        logits = outputs.logits[:, -1, :] / temperature

        if top_k is not None:
            # Keep only the k largest logits.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")

        if top_p is not None:
            # Nucleus filtering: drop tokens beyond cumulative mass top_p.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = float("-inf")

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        return next_token.item()

    @torch.no_grad()
    def get_recursion_analysis(
        self,
        input_ids: torch.LongTensor,
    ) -> dict:
        """
        Run the recursion loop while recording latent states, hidden-state
        norms, and halting probabilities for interpretability.
        """
        self.eval()

        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        # NOTE(review): unlike forward(), learned positional embeddings
        # (wpe) are not added here when use_rope=False — confirm intended.
        hidden_states = self.wte(input_ids)
        hidden_states = self.drop(hidden_states)

        latent = self.init_latent.unsqueeze(0).expand(batch_size, -1)

        analysis = {
            "latent_states": [latent.clone()],
            "hidden_norms": [],
            "halting_probs": [],
        }

        for r in range(self.config.n_recursions):
            rec_emb = self.recursion_emb(torch.tensor(r, device=device))
            latent_r = latent + rec_emb

            for block in self.blocks:
                hidden_states = block(hidden_states, latent_r, None)

            new_info = self.latent_encoder(hidden_states)
            latent = self.latent_updater(latent, new_info)

            analysis["latent_states"].append(latent.clone())
            analysis["hidden_norms"].append(hidden_states.norm(dim=-1).mean().item())

            if self.halting is not None:
                halt_prob = self.halting(latent)
                analysis["halting_probs"].append(halt_prob.mean().item())

        return analysis
| |
|
| |
|
| | |
# Register with the Hugging Face Auto* factories so that
# AutoConfig / AutoModelForCausalLM.from_pretrained can resolve the
# custom "chess_transformer" model type to these classes.
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
| |
|