Johnblick187 committed on
Commit
8b9a191
·
verified ·
1 Parent(s): 73f53c5

Update modeling_smartcoder_moe.py

Browse files
Files changed (1) hide show
  1. modeling_smartcoder_moe.py +26 -3
modeling_smartcoder_moe.py CHANGED
@@ -1,5 +1,4 @@
1
-
2
- # modeling_smartcoder_moe.py
3
 
4
  #Architecture (from tensor inspection):
5
  #- vocab_size: 65536, hidden: 2048, layers: 40
@@ -17,6 +16,7 @@ import math
17
  import torch
18
  import torch.nn as nn
19
  import torch.nn.functional as F
 
20
  from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
21
  from transformers.modeling_outputs import CausalLMOutputWithPast
22
 
@@ -223,11 +223,22 @@ class SmartCoderMoEModel(nn.Module):
223
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
224
  self.layers = nn.ModuleList([SmartCoderDecoderLayer(config) for _ in range(config.num_hidden_layers)])
225
  self.norm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
 
 
 
 
 
 
226
 
227
  def forward(self, input_ids, attention_mask=None, **kwargs):
228
  hidden_states = self.embed_tokens(input_ids)
229
  for layer in self.layers:
230
- hidden_states = layer(hidden_states, attention_mask=attention_mask)
 
 
 
 
 
231
  return self.norm(hidden_states)
232
 
233
 
@@ -254,6 +265,18 @@ class SmartCoderMoEForCausalLM(PreTrainedModel, GenerationMixin):
254
  def get_input_embeddings(self): return self.model.embed_tokens
255
  def get_output_embeddings(self): return self.lm_head
256
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  def forward(
258
  self,
259
  input_ids=None,
 
1
+ #modeling_smartcoder_moe.py
 
2
 
3
  #Architecture (from tensor inspection):
4
  #- vocab_size: 65536, hidden: 2048, layers: 40
 
16
  import torch
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
+ import torch.utils.checkpoint
20
  from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
21
  from transformers.modeling_outputs import CausalLMOutputWithPast
22
 
 
223
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
224
  self.layers = nn.ModuleList([SmartCoderDecoderLayer(config) for _ in range(config.num_hidden_layers)])
225
  self.norm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
226
+ # Required for transformers' _set_gradient_checkpointing() to have
227
+ # something to toggle. Without this attribute + the checkpoint()
228
+ # call in forward(), declaring supports_gradient_checkpointing=True
229
+ # at the PreTrainedModel level is a lie transformers will catch and
230
+ # raise on -- which is exactly the ValueError this fixes.
231
+ self.gradient_checkpointing = True
232
 
233
  def forward(self, input_ids, attention_mask=None, **kwargs):
234
  hidden_states = self.embed_tokens(input_ids)
235
  for layer in self.layers:
236
+ if self.gradient_checkpointing and self.training:
237
+ hidden_states = torch.utils.checkpoint.checkpoint(
238
+ layer, hidden_states, attention_mask, use_reentrant=True
239
+ )
240
+ else:
241
+ hidden_states = layer(hidden_states, attention_mask=attention_mask)
242
  return self.norm(hidden_states)
243
 
244
 
 
265
  def get_input_embeddings(self): return self.model.embed_tokens
266
  def get_output_embeddings(self): return self.lm_head
267
 
268
+ # transformers' _set_gradient_checkpointing (called by Unsloth/Trainer)
269
+ # looks for this attribute on the *PreTrainedModel* root, finds the
270
+ # submodule that has it, and toggles it. Exposing it here as a property
271
+ # delegating to self.model keeps both objects in sync.
272
+ @property
273
+ def gradient_checkpointing(self):
274
+ return self.model.gradient_checkpointing
275
+
276
+ @gradient_checkpointing.setter
277
+ def gradient_checkpointing(self, value):
278
+ self.model.gradient_checkpointing = value
279
+
280
  def forward(
281
  self,
282
  input_ids=None,