File size: 6,169 Bytes
53b7a9c 85e4754 53b7a9c 85e4754 c9a9460 53b7a9c c9a9460 53b7a9c 85e4754 53b7a9c 85e4754 53b7a9c c9a9460 53b7a9c c9a9460 53b7a9c 85e4754 53b7a9c 85e4754 53b7a9c 85e4754 53b7a9c 85e4754 53b7a9c 85e4754 c9a9460 85e4754 c9a9460 53b7a9c c9a9460 53b7a9c 85e4754 53b7a9c 85e4754 c9a9460 85e4754 c9a9460 85e4754 c9a9460 85e4754 c9a9460 85e4754 c9a9460 85e4754 c9a9460 85e4754 c9a9460 85e4754 c9a9460 85e4754 c9a9460 53b7a9c 85e4754 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GPT2TokenizerFast, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.generation import GenerationMixin # <--- FIXED: Import this explicitly
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from datasets import load_dataset
class TRMConfig(PretrainedConfig):
    """Configuration for TinyRecursiveModel, a GPT-2-style LM that reuses its blocks.

    Args:
        vocab_size: Number of token ids.
        n_positions: Maximum sequence length (size of the position-embedding table).
        n_embd: Hidden size.
        n_head: Number of attention heads.
        n_physical_layers: Number of distinct blocks that own weights.
        n_loops: How many times the stack of physical blocks is re-applied.
        activation_function, resid_pdrop, embd_pdrop, attn_pdrop, layer_norm_epsilon,
        scale_attn_weights, scale_attn_by_inverse_layer_idx, reorder_and_upcast_attn:
            GPT-2 options, kept under GPT-2's names for the reused
            GPT2Attention / GPT2MLP layers.
        n_inner: MLP inner width. ``None`` is meant to select GPT-2's default of
            4 * n_embd. The value is only stored here; the block that builds the
            MLP interprets it.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "recursive_gpt"
    # GPT2Attention (and other HF utilities) read these GPT-2 style names.
    # Aliasing them to the canonical fields keeps both views in sync, even when a
    # field is changed after construction (e.g. ``from_pretrained(path, n_embd=...)``).
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_physical_layers",
    }

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_head=8,
        n_physical_layers=2,
        n_loops=6,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        n_inner=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Model shape
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_inner = n_inner
        # GPT-2 layer options
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn
class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """Causal language model that re-applies a small stack of shared blocks.

    Only ``config.n_physical_layers`` RecursiveBlock instances hold weights.
    The forward pass runs that stack ``config.n_loops`` times, so the effective
    depth is ``n_physical_layers * n_loops``. The output head shares its weight
    with the token embedding.
    """

    config_class = TRMConfig
    # lm_head.weight is the same tensor as wte.weight, so it is not saved separately.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # 1. Embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        # 2. Logic core: the only blocks with their own weights; reused every loop.
        self.physical_blocks = nn.ModuleList(
            [RecursiveBlock(config, layer_idx=i) for i in range(config.n_physical_layers)]
        )
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying. The embedding accessors below also let
        # PreTrainedModel.tie_weights() / resize_token_embeddings() find both ends.
        self.lm_head.weight = self.wte.weight
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.wte = new_embeddings

    def get_output_embeddings(self):
        """Return the output projection (tied to the token embedding)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection."""
        self.lm_head = new_embeddings

    @staticmethod
    def _build_attention_mask(attention_mask, batch_size, seq_len, dtype, device):
        """Return an additive (batch, 1, seq, seq) mask combining causality and padding.

        Allowed positions get 0.0 and blocked positions get -10000.0. The mask is
        created in ``dtype`` so it can be added directly to attention scores.
        """
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device)
        # Query i may only attend to keys j <= i. Depending on the transformers
        # version / attention backend, GPT2Attention may not apply its own causal
        # mask once an attention_mask is supplied, so causality is encoded here.
        causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device))
        keep = causal[None, None, :, :] & attention_mask[:, None, None, :].bool()
        mask = torch.zeros(keep.shape, dtype=dtype, device=device)
        return mask.masked_fill(~keep, -10000.0)

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        """Run the recursive block stack and return next-token logits (and loss).

        Args:
            input_ids: Token ids indexed as (batch, seq). Required.
            attention_mask: Optional (batch, seq) mask, nonzero for real tokens and
                0 for padding. Defaults to all ones.
            labels: Optional (batch, seq) targets. They are shifted left by one
                here. Positions equal to -100 are ignored (CrossEntropyLoss default).
            return_dict: When true, return ``CausalLMOutputWithCrossAttentions``;
                otherwise return ``(loss, logits)`` or ``(logits,)``. Defaults to
                ``config.use_return_dict``.
            **kwargs: Extra arguments from Trainer / generate(); accepted and ignored.

        Raises:
            ValueError: If ``input_ids`` is missing or longer than ``config.n_positions``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is None:
            raise ValueError("TinyRecursiveModel.forward requires input_ids")
        device = input_ids.device
        b, t = input_ids.size()
        if t > self.config.n_positions:
            # Without this check, wpe lookup fails with an opaque index error.
            raise ValueError(
                f"Sequence length {t} exceeds the maximum of {self.config.n_positions} positions"
            )

        # Positions & embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        hidden_states = self.drop(self.wte(input_ids) + self.wpe(pos))

        extended_attention_mask = self._build_attention_mask(
            attention_mask, b, t, hidden_states.dtype, device
        )

        # The recursive loop: the same physical blocks are applied n_loops times.
        for _ in range(self.config.n_loops):
            for block in self.physical_blocks:
                hidden_states = block(hidden_states, attention_mask=extended_attention_mask)

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Position i predicts token i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=None,  # no KV cache in this recursive setup
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Build forward() inputs for one generate() step.

        No KV cache is kept, so the full sequence is re-encoded every step. The
        attention mask is passed through so padded prompts stay masked.
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}
class RecursiveBlock(nn.Module):
    """One pre-LayerNorm transformer block built from GPT-2's attention and MLP layers.

    Args:
        config: Model config. Reads n_embd, layer_norm_epsilon and (optionally)
            n_inner, plus whatever GPT2Attention / GPT2MLP read from it.
        layer_idx: Index handed to GPT2Attention.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # GPT2MLP's first argument is the MLP *inner* width, not the hidden size.
        # Mirror GPT2Block: use config.n_inner when set, otherwise 4 * n_embd.
        # The width determines the MLP weight shapes, so checkpoints must match it.
        n_inner = getattr(config, "n_inner", None)
        inner_dim = n_inner if n_inner is not None else 4 * config.n_embd
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(self, x, layer_past=None, attention_mask=None):
        """Apply attention, then the MLP, each as a pre-LN residual branch.

        Args:
            x: Hidden states; presumably (batch, seq, n_embd). The LayerNorms
                require the last dimension to be n_embd.
            layer_past: Passed through to GPT2Attention.
            attention_mask: Additive attention mask, passed to GPT2Attention unchanged.

        Returns:
            Hidden states with the same shape as ``x``.
        """
        residual = x
        # Caching is disabled (use_cache=False) to keep the recursion loop simple.
        attn_outputs = self.attn(
            self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, use_cache=False
        )
        x = residual + attn_outputs[0]
        x = x + self.mlp(self.ln_2(x))
        return x
|