pstjohn committed on
Commit
6547ce1
·
verified ·
1 Parent(s): d90b4db

Upload folder using huggingface_hub

Browse files
esm_nv.py CHANGED
@@ -259,9 +259,13 @@ class NVEsmPreTrainedModel(PreTrainedModel):
259
  "EsmEmbeddings",
260
  )
261
 
262
- # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
263
  def _init_weights(self, module: nn.Module):
264
- """Initialize the weights.
 
 
 
 
 
265
 
266
  Args:
267
  module (nn.Module): The module to initialize the weights for.
@@ -282,9 +286,29 @@ class NVEsmPreTrainedModel(PreTrainedModel):
282
  module.bias.data.zero_()
283
  module.weight.data.fill_(1.0)
284
  if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
 
 
285
  module.layer_norm_weight.data.fill_(1.0)
286
  if module.layer_norm_bias is not None:
287
  module.layer_norm_bias.data.zero_()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  @classmethod
290
  def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
 
259
  "EsmEmbeddings",
260
  )
261
 
 
262
  def _init_weights(self, module: nn.Module):
263
+ """Initialize model weights.
264
+
265
+ This method ensures that models with randomly-initialized weights get the correct initial value distribution,
266
+ which can be critical for training stability. We also call this method directly when using meta-device init, as
267
+ the `to_empty` method does not initialize the weights. While the base Transformers model has a similar method,
268
+ we need to extend it to handle TE-specific modules.
269
 
270
  Args:
271
  module (nn.Module): The module to initialize the weights for.
 
286
  module.bias.data.zero_()
287
  module.weight.data.fill_(1.0)
288
  if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
289
+ if module.layer_norm_bias is not None:
290
+ module.layer_norm_bias.data.zero_()
291
  module.layer_norm_weight.data.fill_(1.0)
292
  if module.layer_norm_bias is not None:
293
  module.layer_norm_bias.data.zero_()
294
+ if isinstance(module, transformer_engine.pytorch.LayerNormMLP):
295
+ if module.layer_norm_bias is not None:
296
+ module.layer_norm_bias.data.zero_()
297
+ module.layer_norm_weight.data.fill_(1.0)
298
+ if hasattr(module, "fc1_weight") and module.fc1_weight is not None:
299
+ module.fc1_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
300
+ if hasattr(module, "fc2_weight") and module.fc2_weight is not None:
301
+ module.fc2_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
302
+ if hasattr(module, "fc1_bias") and module.fc1_bias is not None and module.fc1_bias.numel() > 0:
303
+ module.fc1_bias.data.zero_()
304
+ if hasattr(module, "fc2_bias") and module.fc2_bias is not None and module.fc2_bias.numel() > 0:
305
+ module.fc2_bias.data.zero_()
306
+ if isinstance(module, RotaryPositionEmbedding) and hasattr(module, "inv_freq"):
307
+ # When we initialize the model with `to_empty`, the `inv_freq` attribute is not initialized, so we need to
308
+ # re-initialize it here with the correct values.
309
+ module.inv_freq = RotaryPositionEmbedding(
310
+ self.config.hidden_size // self.config.num_attention_heads
311
+ ).inv_freq.to(module.inv_freq.device)
312
 
313
  @classmethod
314
  def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
model-00001-of-00013.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:02007bbf2c42108ab281ff3fea39f6a495181b2390b05e4e469d54e51eb161b2
3
- size 4616119920
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:172497ffc270184c31badaf14e0132e234c16112ddd2eee373838faaead6d573
3
+ size 4616120040
model-00013-of-00013.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ada1f6ab3d781e8848fa39585d06bedaed2372b97c0e80f479648e2ac617d34
3
- size 3880597101
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0956689eebee89d381080f4605597b90e5e33a82ff7a581ac9ae8f6f2f067f0a
3
+ size 3880596984
model.safetensors.index.json CHANGED
@@ -1,11 +1,11 @@
1
  {
2
  "metadata": {
3
  "total_parameters": 15129257024,
4
- "total_size": 60517028357
5
  },
6
  "weight_map": {
7
  "esm.embeddings.word_embeddings.weight": "model-00001-of-00013.safetensors",
8
- "esm.encoder.emb_layer_norm_after._extra_state": "model-00013-of-00013.safetensors",
9
  "esm.encoder.emb_layer_norm_after.bias": "model-00013-of-00013.safetensors",
10
  "esm.encoder.emb_layer_norm_after.weight": "model-00013-of-00013.safetensors",
11
  "esm.encoder.layers.0.layernorm_mlp._extra_state": "model-00001-of-00013.safetensors",
 
1
  {
2
  "metadata": {
3
  "total_parameters": 15129257024,
4
+ "total_size": 60517028352
5
  },
6
  "weight_map": {
7
  "esm.embeddings.word_embeddings.weight": "model-00001-of-00013.safetensors",
8
+ "esm.encoder.emb_layer_norm_after._extra_state": "model-00001-of-00013.safetensors",
9
  "esm.encoder.emb_layer_norm_after.bias": "model-00013-of-00013.safetensors",
10
  "esm.encoder.emb_layer_norm_after.weight": "model-00013-of-00013.safetensors",
11
  "esm.encoder.layers.0.layernorm_mlp._extra_state": "model-00001-of-00013.safetensors",