ainz committed on
Commit
53b7a9c
·
verified ·
1 Parent(s): 82fce22

Upload modeling_tiny_recursive.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tiny_recursive.py +64 -0
modeling_tiny_recursive.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import PreTrainedModel, PretrainedConfig
3
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
4
+ from transformers.generation import GenerationMixin
5
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ class TRMConfig(PretrainedConfig):
10
+ model_type = "recursive_gpt"
11
+
12
+ def __init__(
13
+ self,
14
+ vocab_size=50257,
15
+ n_positions=1024,
16
+ n_embd=512,
17
+ n_physical_layers=3,
18
+ n_loops=8,
19
+ n_head=8,
20
+ embd_pdrop=0.1,
21
+ **kwargs
22
+ ):
23
+ super().__init__(**kwargs)
24
+ self.vocab_size = vocab_size
25
+ self.n_positions = n_positions
26
+ self.n_embd = n_embd
27
+ self.n_physical_layers = n_physical_layers
28
+ self.n_loops = n_loops
29
+ self.n_head = n_head
30
+ self.embd_pdrop = embd_pdrop
31
+
32
+ # Required for transformers compatibility
33
+ self.hidden_size = n_embd
34
+ self.num_attention_heads = n_head
35
+ self.num_hidden_layers = n_physical_layers
36
+ self.n_inner = None
37
+
38
+ class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
39
+ config_class = TRMConfig
40
+ _tied_weights_keys = ["lm_head.weight"]
41
+
42
+ def __init__(self, config):
43
+ super().__init__(config)
44
+ self.config = config
45
+
46
+ # 1. Embeddings
47
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
48
+ self.wpe = nn.Embedding(config.n_positions, config.n_embd)
49
+ self.drop = nn.Dropout(config.embd_pdrop)
50
+
51
+ # 2. The Logic Core - Add your recursive layers here
52
+ # [Your recursive implementation from the notebook]
53
+
54
+ # 3. Language modeling head
55
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
56
+
57
+ self.post_init()
58
+
59
+ def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
60
+ # Add your forward pass implementation
61
+ pass
62
+
63
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
64
+ return {"input_ids": input_ids}