pstjohn committed (verified)
Commit 9f94220 · 1 Parent(s): d81c2e5

Upload folder using huggingface_hub

Files changed (4):
  1. config.json +4 -1
  2. esm_nv.py +30 -8
  3. model.safetensors +2 -2
  4. tokenizer_config.json +4 -46
config.json CHANGED
@@ -1,4 +1,5 @@
 {
+  "add_cross_attention": false,
   "architectures": [
     "NVEsmForMaskedLM"
   ],
@@ -22,6 +23,7 @@
   "hidden_size": 1280,
   "initializer_range": 0.02,
   "intermediate_size": 5120,
+  "is_decoder": false,
   "is_folding_model": false,
   "layer_norm_eps": 1e-05,
   "mask_token_id": 32,
@@ -35,8 +37,9 @@
   "padded_vocab_size": 64,
   "position_embedding_type": "rotary",
   "qkv_weight_interleaved": true,
+  "tie_word_embeddings": true,
   "token_dropout": true,
-  "transformers_version": "4.57.6",
+  "transformers_version": "5.0.0",
   "use_cache": true,
   "vocab_list": null,
   "vocab_size": 33
esm_nv.py CHANGED
@@ -22,7 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

-from typing import Literal, Optional, Unpack
+from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
 # installation and the nvidia docker container.
@@ -256,10 +256,34 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel):
         # Meta-device init seems to break weight tying, so we re-tie the weights here.
         self.tie_weights()

-    @classmethod
-    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        """Override the default get_init_context method to allow for fp8 model initialization."""
-        return []
+    def _init_weights(self, module):
+        """Initialize module weights.
+
+        We only use this method for standard pytorch modules; TE modules handle their own weight
+        initialization through `init_method` parameters and the `reset_parameters` method.
+        """
+        if module.__module__.startswith("transformer_engine.pytorch"):
+            # Notably, we need to avoid calling the parent method for TE modules, since the default
+            # _init_weights will assume any class with `LayerNorm` in the name should have weights
+            # initialized to 1.0, breaking `LayerNormLinear` and `LayerNormMLP` modules that use
+            # `weight` for the linear layer and `layer_norm_weight` for the layer norm. Instead, we
+            # call `reset_parameters` if the module has it and the weights are not in fp8. We still
+            # need to figure out why this raises an error if we're using `quantized_model_init`.
+            if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False):
+                module.reset_parameters()
+            return
+
+        super()._init_weights(module)
+
+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with
+        HuggingFace v5 model loading. These are filtered out to ensure checkpoints can be
+        loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


 class NVEsmModel(NVEsmPreTrainedModel):
@@ -367,7 +391,7 @@ class NVEsmModel(NVEsmPreTrainedModel):
 class NVEsmForMaskedLM(NVEsmPreTrainedModel):
     """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""

-    _tied_weights_keys = ("lm_head.decoder.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}

     def __init__(self, config: NVEsmConfig):
         """Initialize a NVEsmForMaskedLM.
@@ -386,7 +410,6 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
         self.esm = NVEsmModel(config, add_pooling_layer=False)
         self.lm_head = NVEsmLMHead(config)

-        self.init_weights()
         self.post_init()

     def get_output_embeddings(self):
@@ -614,7 +637,6 @@ class NVEsmForTokenClassification(NVEsmPreTrainedModel):
             init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
         )

-        self.init_weights()
         self.post_init()

     def forward(
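
The `state_dict` override is the key compatibility fix here, and `_tied_weights_keys` changes from a tuple to an explicit target-to-source mapping, apparently matching the v5 convention. Below is a standalone sketch of the same filtering pattern on a toy module; the buffer name is illustrative only (real TE layers serialize this metadata through extra-state hooks, not a plain buffer):

```python
import torch

class ToyTELayer(torch.nn.Module):
    """Stand-in for a TransformerEngine layer that serializes extra quantization state."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # Emulate TE's serialized metadata with a plain buffer for illustration.
        self.register_buffer("fp8_extra_state", torch.zeros(1))

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = ToyTELayer()

    def state_dict(self, *args, **kwargs):
        # Same pattern as NVEsmPreTrainedModel.state_dict above: drop keys that
        # from_pretrained() would otherwise fail to match against parameters.
        state_dict = super().state_dict(*args, **kwargs)
        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}

print(list(ToyModel().state_dict()))  # ['layer.linear.weight', 'layer.linear.bias']
```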
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6a3cbb14dad1e96e8abd94b0cbc5bdebf600e4aab21d90af9954feea4d2da881
-size 2604396920
+oid sha256:504f43d8ef30f80d1c8603b9839602a9b298a0819b0d569cdb3e9f77f03c735a
+size 2604379992
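
The new weights file is roughly 17 KB smaller, consistent with the `_extra_state` entries no longer being serialized. A quick sketch to confirm on a downloaded copy (the local path is an assumption):

```python
from safetensors import safe_open

# Assumes model.safetensors has been downloaded locally.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    keys = list(f.keys())

assert not any(k.endswith("_extra_state") for k in keys)
print(f"{len(keys)} tensors, no TransformerEngine _extra_state entries")
```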
tokenizer_config.json CHANGED
@@ -1,60 +1,18 @@
 {
-  "add_bos_token": true,
-  "add_eos_token": true,
-  "added_tokens_decoder": {
-    "0": {
-      "content": "<cls>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "1": {
-      "content": "<pad>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "2": {
-      "content": "<eos>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "3": {
-      "content": "<unk>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "32": {
-      "content": "<mask>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    }
-  },
+  "backend": "tokenizers",
   "bos_token": "<cls>",
   "clean_up_tokenization_spaces": false,
   "cls_token": "<cls>",
   "eos_token": "<eos>",
-  "extra_special_tokens": {},
+  "is_local": true,
   "mask_token": "<mask>",
   "model_input_names": [
     "input_ids",
     "attention_mask"
   ],
   "model_max_length": 1000000000000000019884624838656,
+  "model_specific_special_tokens": {},
   "pad_token": "<pad>",
-  "tokenizer_class": "PreTrainedTokenizerFast",
+  "tokenizer_class": "TokenizersBackend",
   "unk_token": "<unk>"
 }
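
Under the v5 `tokenizers` backend, the `added_tokens_decoder` table removed above presumably lives in the companion `tokenizer.json` instead of `tokenizer_config.json`, so the special tokens should still resolve to the same ids. A sketch, again with a placeholder repo id:

```python
from transformers import AutoTokenizer

# Placeholder repo id; substitute the actual checkpoint for this commit.
tok = AutoTokenizer.from_pretrained("<org>/<esm-nv-checkpoint>")

# Ids previously spelled out in added_tokens_decoder:
# <cls>=0, <pad>=1, <eos>=2, <unk>=3, <mask>=32.
print(tok.cls_token_id, tok.pad_token_id, tok.eos_token_id, tok.unk_token_id, tok.mask_token_id)
```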