ainz committed on
Commit
c9a9460
·
verified ·
1 Parent(s): c887597

Update modeling file with complete recursive implementation

Browse files
Files changed (1) hide show
  1. modeling_tiny_recursive.py +69 -6
modeling_tiny_recursive.py CHANGED
@@ -17,8 +17,15 @@ class TRMConfig(PretrainedConfig):
17
  n_physical_layers=3,
18
  n_loops=8,
19
  n_head=8,
 
 
20
  embd_pdrop=0.1,
21
- **kwargs
 
 
 
 
 
22
  ):
23
  super().__init__(**kwargs)
24
  self.vocab_size = vocab_size
@@ -27,7 +34,14 @@ class TRMConfig(PretrainedConfig):
27
  self.n_physical_layers = n_physical_layers
28
  self.n_loops = n_loops
29
  self.n_head = n_head
 
 
30
  self.embd_pdrop = embd_pdrop
 
 
 
 
 
31
 
32
  # Required for transformers compatibility
33
  self.hidden_size = n_embd
@@ -48,17 +62,66 @@ class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
48
  self.wpe = nn.Embedding(config.n_positions, config.n_embd)
49
  self.drop = nn.Dropout(config.embd_pdrop)
50
 
51
- # 2. The Logic Core - Add your recursive layers here
52
- # [Your recursive implementation from the notebook]
 
 
 
 
 
 
 
53
 
54
- # 3. Language modeling head
 
 
 
55
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
56
 
 
57
  self.post_init()
58
 
59
  def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
60
- # Add your forward pass implementation
61
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def prepare_inputs_for_generation(self, input_ids, **kwargs):
64
  return {"input_ids": input_ids}
 
17
  n_physical_layers=3,
18
  n_loops=8,
19
  n_head=8,
20
+ activation_function="gelu_new",
21
+ resid_pdrop=0.1,
22
  embd_pdrop=0.1,
23
+ attn_pdrop=0.1,
24
+ layer_norm_epsilon=1e-5,
25
+ scale_attn_weights=True,
26
+ scale_attn_by_inverse_layer_idx=False,
27
+ reorder_and_upcast_attn=False,
28
+ **kwargs,
29
  ):
30
  super().__init__(**kwargs)
31
  self.vocab_size = vocab_size
 
34
  self.n_physical_layers = n_physical_layers
35
  self.n_loops = n_loops
36
  self.n_head = n_head
37
+ self.activation_function = activation_function
38
+ self.resid_pdrop = resid_pdrop
39
  self.embd_pdrop = embd_pdrop
40
+ self.attn_pdrop = attn_pdrop
41
+ self.layer_norm_epsilon = layer_norm_epsilon
42
+ self.scale_attn_weights = scale_attn_weights
43
+ self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
44
+ self.reorder_and_upcast_attn = reorder_and_upcast_attn
45
 
46
  # Required for transformers compatibility
47
  self.hidden_size = n_embd
 
62
  self.wpe = nn.Embedding(config.n_positions, config.n_embd)
63
  self.drop = nn.Dropout(config.embd_pdrop)
64
 
65
+ # 2. The Logic Core - Physical transformer blocks
66
+ self.physical_blocks = nn.ModuleList([
67
+ nn.ModuleDict({
68
+ "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
69
+ "attn": GPT2Attention(config, layer_idx=i),
70
+ "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
71
+ "mlp": GPT2MLP(4 * config.n_embd, config)
72
+ }) for i in range(config.n_physical_layers)
73
+ ])
74
 
75
+ # 3. Final layer norm
76
+ self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
77
+
78
+ # 4. Language modeling head
79
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
80
 
81
+ # Initialize weights
82
  self.post_init()
83
 
84
  def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
85
+ batch_size, seq_len = input_ids.shape
86
+
87
+ # Get embeddings
88
+ token_embeds = self.wte(input_ids)
89
+ pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device)
90
+ pos_embeds = self.wpe(pos_ids)
91
+ hidden_states = self.drop(token_embeds + pos_embeds)
92
+
93
+ # Apply recursive loops through physical blocks
94
+ for loop in range(self.config.n_loops):
95
+ block_idx = loop % self.config.n_physical_layers
96
+ block = self.physical_blocks[block_idx]
97
+
98
+ # Attention
99
+ ln_output = block["ln_1"](hidden_states)
100
+ attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
101
+ hidden_states = hidden_states + attn_output
102
+
103
+ # MLP
104
+ ln_output = block["ln_2"](hidden_states)
105
+ mlp_output = block["mlp"](ln_output)
106
+ hidden_states = hidden_states + mlp_output
107
+
108
+ # Final layer norm and projection
109
+ hidden_states = self.ln_f(hidden_states)
110
+ logits = self.lm_head(hidden_states)
111
+
112
+ loss = None
113
+ if labels is not None:
114
+ shift_logits = logits[..., :-1, :].contiguous()
115
+ shift_labels = labels[..., 1:].contiguous()
116
+ loss_fct = nn.CrossEntropyLoss()
117
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
118
+
119
+ return CausalLMOutputWithCrossAttentions(
120
+ loss=loss,
121
+ logits=logits,
122
+ hidden_states=None,
123
+ attentions=None
124
+ )
125
 
126
  def prepare_inputs_for_generation(self, input_ids, **kwargs):
127
  return {"input_ids": input_ids}