pstjohn committed (verified)
Commit c103398 · 1 Parent(s): 8c5e247

Upload folder using huggingface_hub
Files changed (2):
  1. esm_nv.py (+26 -2)
  2. model.safetensors (+2 -2)
esm_nv.py CHANGED
@@ -259,9 +259,13 @@ class NVEsmPreTrainedModel(PreTrainedModel):
         "EsmEmbeddings",
     )
 
-    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module: nn.Module):
-        """Initialize the weights.
+        """Initialize model weights.
+
+        This method ensures that models with randomly-initialized weights get the correct initial value distribution,
+        which can be critical for training stability. We also call this method directly when using meta-device init, as
+        the `to_empty` method does not initialize the weights. While the base Transformers model has a similar method,
+        we need to extend it to handle TE-specific modules.
 
         Args:
             module (nn.Module): The module to initialize the weights for.
@@ -282,9 +286,29 @@ class NVEsmPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
         if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
+            if module.layer_norm_bias is not None:
+                module.layer_norm_bias.data.zero_()
             module.layer_norm_weight.data.fill_(1.0)
             if module.layer_norm_bias is not None:
                 module.layer_norm_bias.data.zero_()
+        if isinstance(module, transformer_engine.pytorch.LayerNormMLP):
+            if module.layer_norm_bias is not None:
+                module.layer_norm_bias.data.zero_()
+            module.layer_norm_weight.data.fill_(1.0)
+            if hasattr(module, "fc1_weight") and module.fc1_weight is not None:
+                module.fc1_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if hasattr(module, "fc2_weight") and module.fc2_weight is not None:
+                module.fc2_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if hasattr(module, "fc1_bias") and module.fc1_bias is not None and module.fc1_bias.numel() > 0:
+                module.fc1_bias.data.zero_()
+            if hasattr(module, "fc2_bias") and module.fc2_bias is not None and module.fc2_bias.numel() > 0:
+                module.fc2_bias.data.zero_()
+        if isinstance(module, RotaryPositionEmbedding) and hasattr(module, "inv_freq"):
+            # When we initialize the model with `to_empty`, the `inv_freq` attribute is not initialized, so we need to
+            # re-initialize it here with the correct values.
+            module.inv_freq = RotaryPositionEmbedding(
+                self.config.hidden_size // self.config.num_attention_heads
+            ).inv_freq.to(module.inv_freq.device)
 
     @classmethod
     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
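
For context on why the extended `_init_weights` matters: when a model is built on the meta device and materialized with `to_empty`, its parameters get real storage but no values, so an explicit initialization pass is required afterwards. The sketch below is a minimal, hypothetical illustration of that pattern; `TinyModel` and `init_weights` are stand-ins for this pattern, not code from this repository, and the real class additionally handles the Transformer Engine modules and the rotary `inv_freq` buffer shown in the diff above.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # Hypothetical stand-in for NVEsmPreTrainedModel, used only to show the init pattern.
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)
        self.norm = nn.LayerNorm(hidden)


def init_weights(module: nn.Module) -> None:
    # Mirrors the structure of _init_weights above, restricted to plain torch.nn modules.
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


with torch.device("meta"):
    model = TinyModel()  # parameters have shapes and dtypes but no storage

model = model.to_empty(device="cpu")  # allocates real storage; values are uninitialized
model.apply(init_weights)  # without this pass, the model would run on garbage weights
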
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6786954a44700bf22b6f3bb0eb1d88c559c2a7b627bb7de239cb38071acb0130
-size 134088917
+oid sha256:7051a31593e7e047a65a1576288f119cd0dc521b7eec9c8a3e7074d7f7e0b3b1
+size 134088912
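
The model.safetensors change only swaps the Git LFS pointer (oid and size); the weight file itself lives in LFS storage. As a minimal sketch, assuming standard Git LFS semantics (the oid is the SHA-256 of the file contents), a fetched copy can be checked against the new pointer; the helper name below is illustrative, not part of this repository.

import hashlib
from pathlib import Path

# Values copied from the new LFS pointer above.
EXPECTED_OID = "7051a31593e7e047a65a1576288f119cd0dc521b7eec9c8a3e7074d7f7e0b3b1"
EXPECTED_SIZE = 134088912


def matches_lfs_pointer(path: Path) -> bool:
    # Hypothetical helper: compare the downloaded file's SHA-256 and size to the pointer.
    data = path.read_bytes()
    return hashlib.sha256(data).hexdigest() == EXPECTED_OID and len(data) == EXPECTED_SIZE


# matches_lfs_pointer(Path("model.safetensors"))  # True once the real LFS object has been fetched
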