Upload folder using huggingface_hub

Files changed:
- esm_nv.py (+26 -2)
- model-00001-of-00013.safetensors (+2 -2)
- model-00013-of-00013.safetensors (+2 -2)
- model.safetensors.index.json (+2 -2)
esm_nv.py
CHANGED

@@ -259,9 +259,13 @@ class NVEsmPreTrainedModel(PreTrainedModel):
         "EsmEmbeddings",
     )

-    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module: nn.Module):
-        """Initialize
+        """Initialize model weights.
+
+        This method ensures that models with randomly-initialized weights get the correct initial value distribution,
+        which can be critical for training stability. We also call this method directly when using meta-device init, as
+        the `to_empty` method does not initialize the weights. While the base Transformers model has a similar method,
+        we need to extend it to handle TE-specific modules.

         Args:
             module (nn.Module): The module to initialize the weights for.
@@ -282,9 +286,29 @@ class NVEsmPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
         if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
+            if module.layer_norm_bias is not None:
+                module.layer_norm_bias.data.zero_()
             module.layer_norm_weight.data.fill_(1.0)
             if module.layer_norm_bias is not None:
                 module.layer_norm_bias.data.zero_()
+        if isinstance(module, transformer_engine.pytorch.LayerNormMLP):
+            if module.layer_norm_bias is not None:
+                module.layer_norm_bias.data.zero_()
+            module.layer_norm_weight.data.fill_(1.0)
+            if hasattr(module, "fc1_weight") and module.fc1_weight is not None:
+                module.fc1_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if hasattr(module, "fc2_weight") and module.fc2_weight is not None:
+                module.fc2_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if hasattr(module, "fc1_bias") and module.fc1_bias is not None and module.fc1_bias.numel() > 0:
+                module.fc1_bias.data.zero_()
+            if hasattr(module, "fc2_bias") and module.fc2_bias is not None and module.fc2_bias.numel() > 0:
+                module.fc2_bias.data.zero_()
+        if isinstance(module, RotaryPositionEmbedding) and hasattr(module, "inv_freq"):
+            # When we initialize the model with `to_empty`, the `inv_freq` attribute is not initialized, so we need to
+            # re-initialize it here with the correct values.
+            module.inv_freq = RotaryPositionEmbedding(
+                self.config.hidden_size // self.config.num_attention_heads
+            ).inv_freq.to(module.inv_freq.device)

     @classmethod
     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
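For context on the docstring above: the meta-device pattern it references builds modules without allocating real storage, and `to_empty` then allocates uninitialized memory, so every parameter and buffer must be re-initialized by hand. A minimal sketch of that pitfall, using a plain PyTorch module (the `nn.LayerNorm` here is illustrative, not the TE modules from the diff):

    import torch
    import torch.nn as nn

    # Construct on the meta device: parameters have shape and dtype but no data.
    with torch.device("meta"):
        layer = nn.LayerNorm(1024)

    # to_empty() allocates real but UNINITIALIZED storage on the target device;
    # unlike to(), it runs no initializer, so values are arbitrary memory.
    layer = layer.to_empty(device="cpu")

    # This is why _init_weights must be called explicitly afterwards, and why
    # the diff also rebuilds non-parameter buffers such as the rotary
    # embedding's inv_freq, which to_empty() leaves uninitialized as well.
    with torch.no_grad():
        layer.weight.fill_(1.0)
        layer.bias.zero_()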
model-00001-of-00013.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:172497ffc270184c31badaf14e0132e234c16112ddd2eee373838faaead6d573
+size 4616120040
model-00013-of-00013.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:0956689eebee89d381080f4605597b90e5e33a82ff7a581ac9ae8f6f2f067f0a
+size 3880596984
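Both shard entries above are Git LFS pointer files: three lines giving the spec version, the SHA-256 oid of the real payload, and its size in bytes. As an aside, a downloaded shard can be checked against its pointer with a sketch like the following (the streamed read is a standard hashlib pattern; the file name and oid are the ones from this commit):

    import hashlib

    def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
        # Stream in 1 MiB chunks so multi-gigabyte shards need not fit in memory.
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    expected = "172497ffc270184c31badaf14e0132e234c16112ddd2eee373838faaead6d573"
    assert sha256_of("model-00001-of-00013.safetensors") == expected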
model.safetensors.index.json
CHANGED

@@ -1,11 +1,11 @@
 {
   "metadata": {
     "total_parameters": 15129257024,
-    "total_size":
+    "total_size": 60517028352
   },
   "weight_map": {
     "esm.embeddings.word_embeddings.weight": "model-00001-of-00013.safetensors",
-    "esm.encoder.emb_layer_norm_after._extra_state": "model-
+    "esm.encoder.emb_layer_norm_after._extra_state": "model-00001-of-00013.safetensors",
     "esm.encoder.emb_layer_norm_after.bias": "model-00013-of-00013.safetensors",
     "esm.encoder.emb_layer_norm_after.weight": "model-00013-of-00013.safetensors",
     "esm.encoder.layers.0.layernorm_mlp._extra_state": "model-00001-of-00013.safetensors",
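The index file is what tells loaders which of the 13 shards holds each tensor; this commit fills in the `total_size` metadata and updates the shard assignment for the `_extra_state` key. A small sketch for inspecting a local copy of the index, using only the standard library:

    import json
    from collections import Counter

    with open("model.safetensors.index.json") as f:
        index = json.load(f)

    # Look up the shard for the key updated in this commit.
    key = "esm.encoder.emb_layer_norm_after._extra_state"
    print(key, "->", index["weight_map"][key])  # model-00001-of-00013.safetensors

    # Count how many entries each shard holds.
    print(Counter(index["weight_map"].values()))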