import torch
import torch.nn as nn
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2TokenizerFast,
    PretrainedConfig,
    PreTrainedModel,
    Trainer,
    TrainingArguments,
)
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from datasets import load_dataset


class TRMConfig(PretrainedConfig):
    """Configuration for a weight-shared ("recursive") GPT-2-style model.

    The model stores only ``n_physical_layers`` transformer blocks but applies
    the whole stack ``n_loops`` times in sequence, giving an effective depth of
    ``n_physical_layers * n_loops`` with the parameter count of the physical
    blocks alone.
    """

    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_head=8,
        n_physical_layers=2,
        n_loops=6,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # GPT-2 style hyperparameters.
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn
        # Compatibility aliases: GPT2Attention / GPT2MLP read the standard HF
        # attribute names (hidden_size, num_attention_heads, ...) rather than
        # the GPT-2 style n_* names used above.
        self.max_position_embeddings = n_positions
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        self.n_inner = None  # None -> GPT-2 convention of 4 * hidden_size


class RecursiveBlock(nn.Module):
    """A single pre-LayerNorm transformer block (self-attention + MLP).

    Shaped like ``GPT2Block``: LN -> attention -> residual, then
    LN -> MLP -> residual.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # FIX: honour config.n_inner (previously the MLP width was hard-wired
        # to n_embd, silently ignoring n_inner and contradicting the config's
        # documented default of 4 * hidden_size).
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(self, x, layer_past=None, attention_mask=None):
        """Apply the block to ``x`` (batch, seq, n_embd) and return same shape.

        ``attention_mask`` is an additive 4-D mask (0 for keep, large negative
        for masked). Caching is disabled (use_cache=False) to keep the outer
        recursion loop simple.
        """
        residual = x
        x = self.ln_1(x)
        attn_outputs = self.attn(
            x,
            layer_past=layer_past,
            attention_mask=attention_mask,
            use_cache=False,
        )
        x = residual + attn_outputs[0]
        residual = x
        x = self.ln_2(x)
        x = residual + self.mlp(x)
        return x


class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """Causal LM that re-applies a small stack of blocks ``n_loops`` times.

    Token and position embeddings feed a recursive core of
    ``config.n_physical_layers`` blocks executed ``config.n_loops`` times,
    followed by a final LayerNorm and a weight-tied LM head.
    """

    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]  # lm_head shares wte's weight

    def __init__(self, config):
        super().__init__(config)
        # 1. Embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        # 2. The logic core: only n_physical_layers blocks are stored; they
        # are reused n_loops times in forward().
        self.physical_blocks = nn.ModuleList(
            [RecursiveBlock(config, layer_idx=i) for i in range(config.n_physical_layers)]
        )
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying between input embeddings and output projection.
        self.lm_head.weight = self.wte.weight
        self.post_init()

    # HF API hooks: required for tie_weights(), resize_token_embeddings()
    # and correct save/load of the tied lm_head.
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None,
        **kwargs,
    ):
        """Run the recursive stack and return logits (and loss if labels given).

        Args:
            input_ids: (batch, seq) token ids. Required.
            attention_mask: (batch, seq) with 1 for real tokens, 0 for padding.
                Defaults to all-ones.
            labels: (batch, seq) targets; shifted internally for next-token loss.
            return_dict: return a CausalLMOutputWithCrossAttentions when True.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        # Default to config setting; generation requires dict outputs.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        device = input_ids.device
        b, t = input_ids.size()

        # Positions & embeddings.
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.wte(input_ids)
        pos_emb = self.wpe(pos)
        hidden_states = self.drop(tok_emb + pos_emb)

        # Attention-mask handling: broadcast to (batch, 1, 1, seq) and convert
        # to an additive mask. Cast to the activation dtype first so the
        # arithmetic stays fp16/bf16-safe.
        if attention_mask is None:
            attention_mask = torch.ones((b, t), device=device)
        extended_attention_mask = attention_mask.view(b, 1, 1, t).to(dtype=hidden_states.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # =========================================================
        # THE RECURSIVE LOOP: the same physical blocks are applied
        # n_loops times, sharing weights across iterations.
        # =========================================================
        for loop_i in range(self.config.n_loops):
            for block in self.physical_blocks:
                hidden_states = block(hidden_states, attention_mask=extended_attention_mask)

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=None,  # no KV-cache: the recursive loop reuses blocks
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # FIX: forward the attention mask so batched (padded) generation does
        # not attend to padding tokens. Previously it was dropped.
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }