File size: 6,169 Bytes
53b7a9c
 
 
85e4754
 
 
 
 
53b7a9c
 
 
 
 
 
 
 
 
85e4754
 
c9a9460
 
53b7a9c
c9a9460
 
 
 
 
 
53b7a9c
 
85e4754
53b7a9c
 
 
85e4754
53b7a9c
 
c9a9460
 
53b7a9c
c9a9460
 
 
 
 
53b7a9c
85e4754
 
 
53b7a9c
85e4754
53b7a9c
85e4754
53b7a9c
 
 
85e4754
53b7a9c
 
 
 
 
 
 
 
 
 
85e4754
c9a9460
85e4754
c9a9460
53b7a9c
c9a9460
53b7a9c
 
85e4754
 
53b7a9c
 
85e4754
 
 
 
 
 
c9a9460
85e4754
 
 
 
 
c9a9460
85e4754
 
 
c9a9460
85e4754
 
 
 
 
c9a9460
85e4754
 
 
 
 
 
c9a9460
 
 
 
 
 
85e4754
c9a9460
 
 
 
85e4754
 
 
 
 
c9a9460
 
 
85e4754
c9a9460
85e4754
c9a9460
53b7a9c
 
 
85e4754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GPT2TokenizerFast, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.generation import GenerationMixin  # <--- FIXED: Import this explicitly
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from datasets import load_dataset
class TRMConfig(PretrainedConfig):
    """Configuration for the tiny recursive (weight-shared) GPT model.

    A stack of ``n_physical_layers`` transformer blocks is re-applied
    ``n_loops`` times, so the effective depth is
    ``n_physical_layers * n_loops`` while the parameter count stays at the
    physical-layer size.
    """

    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_head=8,
        n_physical_layers=2,
        n_loops=6,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Each custom hyperparameter is stored under its GPT-2-style name,
        # immediately followed by the alias that the stock GPT2Attention /
        # GPT2MLP modules read, so the two can never drift apart.
        self.vocab_size = vocab_size

        self.n_positions = n_positions
        self.max_position_embeddings = n_positions  # alias read by GPT2Attention

        self.n_embd = n_embd
        self.hidden_size = n_embd                   # alias read by GPT2Attention

        self.n_head = n_head
        self.num_attention_heads = n_head           # alias read by GPT2Attention

        self.n_physical_layers = n_physical_layers
        self.num_hidden_layers = n_physical_layers  # HF convention for layer count
        self.n_loops = n_loops

        # Standard GPT-2 knobs, passed straight through to the stock modules.
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # None means "use the GPT-2 default of 4 * hidden_size" for the MLP.
        self.n_inner = None

class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """GPT-2-style causal LM that reuses a few physical blocks recursively.

    The ``n_physical_layers`` blocks are applied ``n_loops`` times in
    sequence, giving the effective depth of a deeper model with the
    parameter count of a shallow one. No KV cache is kept: every
    generation step recomputes the full forward pass.
    """

    config_class = TRMConfig
    # lm_head shares its weight tensor with wte; this tells the HF
    # save/load machinery the tie is intentional.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # 1. Token + learned absolute position embeddings.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # 2. The logic core: the only physically-instantiated blocks,
        #    looped over n_loops times in forward().
        self.physical_blocks = nn.ModuleList([
            RecursiveBlock(config, layer_idx=i) for i in range(config.n_physical_layers)
        ])

        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the output projection reuses the input embedding matrix.
        self.lm_head.weight = self.wte.weight
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        """Run the recursive stack and optionally compute the LM loss.

        Args:
            input_ids: (batch, seq) token ids. Required at runtime; the
                ``None`` default only mirrors the HF signature convention.
            attention_mask: optional (batch, seq) mask, 1 = attend, 0 = pad.
            labels: optional (batch, seq) targets; shifted internally by one.
            return_dict: return a ModelOutput instead of a tuple. Defaults
                to ``config.use_return_dict`` (generation requires True).

        Returns:
            CausalLMOutputWithCrossAttentions, or a tuple
            ``(loss?, logits)`` when ``return_dict`` is falsy.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            # FIX: fail loudly instead of the opaque AttributeError the
            # .device access below would otherwise raise.
            raise ValueError("input_ids must be provided")

        device = input_ids.device
        b, t = input_ids.size()

        # Positions & embeddings.
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.wte(input_ids)
        pos_emb = self.wpe(pos)
        hidden_states = self.drop(tok_emb + pos_emb)

        if attention_mask is None:
            attention_mask = torch.ones((b, t), device=device)

        # Broadcast the padding mask to (batch, 1, 1, seq) and turn it into
        # an additive bias: 0 where attended, -1e4 where masked.
        extended_attention_mask = attention_mask.view(b, 1, 1, t)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # FIX: match the activations' dtype so fp16/bf16 runs don't mix
        # float32 bias into half-precision attention scores.
        extended_attention_mask = extended_attention_mask.to(dtype=hidden_states.dtype)

        # =========================================================
        # THE RECURSIVE LOOP: the same physical blocks are applied
        # n_loops times, sharing weights across iterations.
        # =========================================================
        for _loop_i in range(self.config.n_loops):
            for block in self.physical_blocks:
                hidden_states = block(hidden_states, attention_mask=extended_attention_mask)

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position i predicts token i+1.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=None,  # no KV cache in this recursive setup
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # FIX: the attention mask was previously dropped here, so batched
        # generation with padding attended to pad tokens. Forward it through;
        # since forward() keeps no cache, the full input_ids are passed each step.
        model_inputs = {"input_ids": input_ids}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        return model_inputs

class RecursiveBlock(nn.Module):
    """One pre-LN GPT-2 transformer block (attention + MLP).

    Instances are shared across loop iterations by the enclosing model,
    so the same weights are applied repeatedly.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # FIX: GPT2MLP's first argument is the *intermediate* (inner) size.
        # The original passed n_embd, silently building a 1x-width MLP even
        # though the config documents n_inner=None as "4 * hidden_size".
        # Mirror GPT2Block's own inner-dim computation.
        inner_dim = getattr(config, "n_inner", None)
        if inner_dim is None:
            inner_dim = 4 * config.n_embd
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(self, x, layer_past=None, attention_mask=None):
        """Pre-LN residual block: x + attn(ln_1(x)), then x + mlp(ln_2(x))."""
        residual = x
        x = self.ln_1(x)
        # use_cache=False: a per-layer KV cache would be invalid across
        # recursive loop iterations anyway, so caching is disabled.
        attn_outputs = self.attn(x, layer_past=layer_past, attention_mask=attention_mask, use_cache=False)
        x = residual + attn_outputs[0]  # attn returns a tuple; [0] is the output tensor

        residual = x
        x = self.ln_2(x)
        x = residual + self.mlp(x)
        return x