|
|
|
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP |
|
|
from transformers.generation import GenerationMixin |
|
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
class TRMConfig(PretrainedConfig):
    """Configuration for TinyRecursiveModel ("recursive GPT").

    A small stack of ``n_physical_layers`` transformer blocks is re-applied
    ``n_loops`` times, giving an effective depth of
    ``n_physical_layers * n_loops`` while storing only ``n_physical_layers``
    sets of weights.

    Args:
        vocab_size: Token vocabulary size (default 50257, GPT-2 BPE).
        n_positions: Maximum sequence length / size of the position table.
        n_embd: Hidden (embedding) dimension.
        n_physical_layers: Number of distinct, weight-holding blocks.
        n_loops: Number of times the physical stack is re-applied.
        n_head: Number of attention heads.
        embd_pdrop: Dropout probability on the summed token+position embeddings.
        **kwargs: Forwarded to ``PretrainedConfig`` (extra keys become
            attributes, so GPT-2 fields such as ``attn_pdrop`` may be
            overridden from here).
    """

    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_physical_layers=3,
        n_loops=8,
        n_head=8,
        embd_pdrop=0.1,
        **kwargs,
    ):
        # PretrainedConfig sets any leftover kwargs as attributes first, so
        # user-supplied GPT-2 fields win over the defaults filled in below.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_head = n_head
        self.embd_pdrop = embd_pdrop

        # GPT-2-style aliases so the reused GPT-2 modules imported in this
        # file (GPT2Attention / GPT2MLP) can read this config directly.
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        # Bug fix: the previous code unconditionally overwrote a
        # user-supplied ``n_inner`` kwarg with None. None keeps the GPT2MLP
        # convention of an inner dim of 4 * hidden_size.
        self.n_inner = getattr(self, "n_inner", None)

        # Remaining fields GPT2Attention / GPT2MLP are known to read
        # (GPT-2 defaults); set only when not already provided via kwargs,
        # so this is backward compatible. NOTE(review): exact field set is
        # transformers-version dependent — confirm against the installed
        # modeling_gpt2.py.
        for name, default in (
            ("max_position_embeddings", n_positions),
            ("resid_pdrop", embd_pdrop),
            ("attn_pdrop", embd_pdrop),
            ("scale_attn_weights", True),
            ("scale_attn_by_inverse_layer_idx", False),
            ("reorder_and_upcast_attn", False),
            ("activation_function", "gelu_new"),
        ):
            if not hasattr(self, name):
                setattr(self, name, default)
|
|
|
|
|
class _TRMBlock(nn.Module):
    """One physical pre-LayerNorm transformer block (GPT-2 style).

    Residual layout: x + attn(ln_1(x)), then x + mlp(ln_2(x)).
    """

    def __init__(self, config):
        super().__init__()
        # GPT2MLP convention: n_inner of None means 4 * hidden_size.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
        self.ln_1 = nn.LayerNorm(config.hidden_size)
        # NOTE(review): GPT2Attention's constructor/forward signatures vary
        # across transformers versions — confirm against the pinned version.
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(config.hidden_size)
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(self, hidden_states, attention_mask=None):
        # GPT2Attention returns a tuple; element 0 is the attended states.
        attn_output = self.attn(self.ln_1(hidden_states), attention_mask=attention_mask)[0]
        hidden_states = hidden_states + attn_output
        hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
        return hidden_states


class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """Weight-shared ("recursive") causal language model.

    A stack of ``config.n_physical_layers`` transformer blocks is applied
    ``config.n_loops`` times in sequence, so the effective depth is
    ``n_physical_layers * n_loops`` with the parameter count of just the
    physical stack. No KV cache is implemented: generation re-encodes the
    full prefix at every step.
    """

    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Defensive GPT-2 compatibility: GPT2Attention / GPT2MLP read these
        # config fields; fill in GPT-2 defaults when absent so a bare
        # TRMConfig works. Values already on the config are never overridden.
        for name, default in (
            ("max_position_embeddings", config.n_positions),
            ("resid_pdrop", config.embd_pdrop),
            ("attn_pdrop", config.embd_pdrop),
            ("scale_attn_weights", True),
            ("scale_attn_by_inverse_layer_idx", False),
            ("reorder_and_upcast_attn", False),
            ("activation_function", "gelu_new"),
        ):
            if not hasattr(config, name):
                setattr(config, name, default)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # The "tiny" trunk: only n_physical_layers weight-holding blocks;
        # forward() re-applies them n_loops times.
        self.h = nn.ModuleList(
            _TRMBlock(config) for _ in range(config.n_physical_layers)
        )
        self.ln_f = nn.LayerNorm(config.n_embd)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # post_init() runs weight init and ties lm_head.weight to wte.weight
        # (via the embedding accessors below and _tied_weights_keys).
        self.post_init()

    def get_input_embeddings(self):
        # Required for PreTrainedModel weight tying and embedding resizing.
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the recursive trunk and compute LM logits (and loss).

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) mask, 1 = attend, 0 = pad.
            labels: optional (batch, seq) targets; -100 positions are ignored
                (standard HF causal-LM convention).
            **kwargs: absorbs extra arguments passed by GenerationMixin
                (e.g. return_dict, use_cache); no cache is implemented.

        Returns:
            CausalLMOutputWithCrossAttentions with ``logits`` of shape
            (batch, seq, vocab_size) and ``loss`` when labels are given.
        """
        seq_len = input_ids.size(1)
        device = input_ids.device

        position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        hidden_states = self.drop(self.wte(input_ids) + self.wpe(position_ids))

        # Build a (batch, 1, seq, seq) additive mask: 0 where attention is
        # allowed, a large negative number where it is not (causal + padding).
        neg = torch.finfo(hidden_states.dtype).min
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), neg, dtype=hidden_states.dtype, device=device),
            diagonal=1,
        )[None, None, :, :]
        if attention_mask is not None:
            pad_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * neg
            # clamp keeps doubly-masked positions at `neg` instead of -inf.
            full_mask = torch.clamp(causal_mask + pad_mask, min=neg)
        else:
            full_mask = causal_mask

        # The recursion: re-apply the same physical stack n_loops times
        # (weight sharing across "virtual" layers). Assumes loop-major order
        # (whole stack per loop) — TODO confirm against the TRM recipe.
        for _ in range(self.config.n_loops):
            for block in self.h:
                hidden_states = block(hidden_states, attention_mask=full_mask)

        hidden_states = self.ln_f(hidden_states)
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache: re-encode the full sequence each step. Forward the
        # attention mask (previously dropped) so padded batches decode right.
        return {"input_ids": input_ids, "attention_mask": attention_mask}
|
|
|