|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel, PretrainedConfig, GPT2TokenizerFast, Trainer, TrainingArguments, DataCollatorForLanguageModeling |
|
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP |
|
|
from transformers.generation import GenerationMixin |
|
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
|
|
from datasets import load_dataset |
|
|
class TRMConfig(PretrainedConfig):
    """Configuration for a tiny recursive GPT-style causal language model.

    A small stack of ``n_physical_layers`` transformer blocks is applied
    ``n_loops`` times in sequence, emulating a deeper network while sharing
    the physical blocks' parameters across depth.
    """

    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_head=8,
        n_physical_layers=2,
        n_loops=6,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # --- Core architecture -------------------------------------------
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops

        # --- Regularisation and numerics ---------------------------------
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon

        # --- GPT2Attention-specific switches -----------------------------
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # Aliases expected by transformers internals (GPT2Attention/GPT2MLP
        # read the GPT-2 style names; generic utilities read the canonical
        # hidden_size / num_attention_heads / ... names).
        self.max_position_embeddings = n_positions
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        # n_inner=None makes GPT2MLP-style code fall back to its default
        # inner dimension convention.
        self.n_inner = None
|
|
|
|
|
class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """GPT-style causal LM that reuses a few physical blocks recursively.

    The ``n_physical_layers`` blocks are applied ``n_loops`` times in
    sequence, giving an effective depth of ``n_physical_layers * n_loops``
    while only holding parameters for the physical blocks.
    """

    config_class = TRMConfig
    # lm_head.weight is shared with wte.weight (tied in __init__), so it is
    # excluded from state-dict duplication checks.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and learned absolute-position embeddings (GPT-2 layout).
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # The small set of physical transformer blocks that are looped over
        # in forward(); weight sharing across depth happens by reuse, not
        # by parameter tying.
        self.physical_blocks = nn.ModuleList([
            RecursiveBlock(config, layer_idx=i) for i in range(config.n_physical_layers)
        ])

        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Tie the output projection to the input embedding matrix.
        self.lm_head.weight = self.wte.weight
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        """Run the recursive stack and (optionally) compute the LM loss.

        Args:
            input_ids: ``(batch, seq_len)`` token ids. Required.
            attention_mask: optional ``(batch, seq_len)`` mask; 1 = attend,
                0 = padding. Defaults to all-ones.
            labels: optional ``(batch, seq_len)`` target ids; shifted by one
                position internally (standard causal-LM loss).
            return_dict: whether to return a ``ModelOutput`` instead of a tuple.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` (or a tuple when
            ``return_dict`` is falsy) with ``logits`` of shape
            ``(batch, seq_len, vocab_size)`` and optional ``loss``.

        Raises:
            ValueError: if ``input_ids`` is missing or longer than
                ``config.n_positions``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("input_ids must be provided")

        device = input_ids.device
        b, t = input_ids.size()
        if t > self.config.n_positions:
            raise ValueError(
                f"Sequence length {t} exceeds the maximum n_positions "
                f"({self.config.n_positions})"
            )

        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.wte(input_ids)
        pos_emb = self.wpe(pos)
        hidden_states = self.drop(tok_emb + pos_emb)

        if attention_mask is None:
            attention_mask = torch.ones((b, t), device=device)

        # Broadcastable additive mask: 0.0 where attending, -10000.0 at
        # padding positions. Cast to the activation dtype so half-precision
        # runs do not mix dtypes. (The causal mask itself is applied inside
        # GPT2Attention.)
        extended_attention_mask = attention_mask.view(b, 1, 1, t)
        extended_attention_mask = extended_attention_mask.to(dtype=hidden_states.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Recursive application: the same physical blocks are reused n_loops
        # times, emulating a deeper network with shared weights.
        for _ in range(self.config.n_loops):
            for block in self.physical_blocks:
                hidden_states = block(hidden_states, attention_mask=extended_attention_mask)

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Assemble the model inputs for one generation step.

        No KV cache is used (the full prefix is re-run every step), so the
        entire ``input_ids`` is passed. The attention mask is forwarded so
        padded batches decode correctly; previously it was dropped, which
        made generation attend to padding tokens.
        """
        model_inputs = {"input_ids": input_ids}
        attention_mask = kwargs.get("attention_mask")
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        return model_inputs
|
|
|
|
|
class RecursiveBlock(nn.Module):
    """One pre-norm transformer block in the GPT-2 layout.

    LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        # Attribute names mirror GPT2Block so state-dict keys stay familiar.
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # NOTE(review): the MLP inner size is n_embd (not the usual 4*n_embd)
        # — presumably intentional for a "tiny" model; confirm.
        self.mlp = GPT2MLP(config.n_embd, config)

    def forward(self, x, layer_past=None, attention_mask=None):
        # Attention sub-layer with residual connection; GPT2Attention returns
        # a tuple whose first element is the attention output.
        attn_out = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            use_cache=False,
        )[0]
        x = x + attn_out
        # Feed-forward sub-layer with residual connection.
        return x + self.mlp(self.ln_2(x))
|
|
|