davidhd committed on
Commit
74d44ea
·
verified ·
1 Parent(s): c47b641

Upload esm2-flash-3B (ESM2 with flash attention)

Browse files
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Esm2FlashModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_esm2_flash.Esm2FlashConfig",
8
+ "AutoModel": "modeling_esm2_flash.Esm2FlashModel",
9
+ "AutoModelForMaskedLM": "modeling_esm2_flash.Esm2FlashForMaskedLM"
10
+ },
11
+ "emb_layer_norm_before": false,
12
+ "esmfold_config": null,
13
+ "hidden_dropout_prob": 0.0,
14
+ "hidden_size": 2560,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 10240,
17
+ "is_folding_model": false,
18
+ "layer_norm_eps": 1e-05,
19
+ "mask_token_id": 32,
20
+ "max_position_embeddings": 1026,
21
+ "model_type": "esm2_flash",
22
+ "num_attention_heads": 40,
23
+ "num_hidden_layers": 36,
24
+ "pad_token_id": 1,
25
+ "position_embedding_type": "rotary",
26
+ "token_dropout": true,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.49.0",
29
+ "use_cache": true,
30
+ "vocab_list": null,
31
+ "vocab_size": 33
32
+ }
configuration_esm2_flash.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ESM2-Flash model configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+
6
class Esm2FlashConfig(PretrainedConfig):
    r"""
    Configuration class for ESM2-Flash, an ESM2 model with flash attention
    and packed-sequence support.

    All architectural parameters mirror EsmConfig exactly so that pretrained
    ESM2 weights can be loaded with zero conversion. This config declares no
    extra parameters of its own; any additional options (e.g. attention
    backend selection) are forwarded via ``**kwargs`` to ``PretrainedConfig``.

    Args:
        vocab_size (`int`, *optional*):
            Vocabulary size of the ESM model.
        mask_token_id (`int`, *optional*):
            Index of the mask token in the vocabulary.
        pad_token_id (`int`, *optional*):
            Index of the padding token in the vocabulary.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward layer.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for fully connected layers.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout ratio for attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 1026):
            Maximum sequence length the model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Std of the truncated_normal_initializer for weight init.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            Epsilon for layer normalization.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. One of "absolute", "relative_key",
            "relative_key_query", "rotary".
        use_cache (`bool`, *optional*, defaults to `True`):
            Stored for EsmConfig compatibility; the model built from this
            config is encoder-only.
        emb_layer_norm_before (`bool`, *optional*):
            Whether to apply layer normalization after embeddings.
        token_dropout (`bool`, *optional*, defaults to `False`):
            When enabled, masked tokens are zeroed and embeddings are rescaled.
        is_folding_model (`bool`, *optional*, defaults to `False`):
            Stored for EsmConfig compatibility; unused by ESM2-Flash.
        esmfold_config (*optional*):
            Stored for EsmConfig compatibility; unused by ESM2-Flash.
        vocab_list (*optional*):
            Stored for EsmConfig compatibility; unused by ESM2-Flash.
    """

    model_type = "esm2_flash"

    def __init__(
        self,
        vocab_size=None,
        mask_token_id=None,
        pad_token_id=None,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1026,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="absolute",
        use_cache=True,
        emb_layer_norm_before=None,
        token_dropout=False,
        is_folding_model=False,
        esmfold_config=None,
        vocab_list=None,
        **kwargs,
    ):
        # pad/mask token ids are handled by the base class so that the
        # standard special-token plumbing in transformers sees them.
        super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.emb_layer_norm_before = emb_layer_norm_before
        self.token_dropout = token_dropout
        self.is_folding_model = is_folding_model
        self.esmfold_config = esmfold_config
        self.vocab_list = vocab_list

        # Encoder-only: these are kept for config compatibility but unused
        self.is_decoder = False
        self.add_cross_attention = False
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad1f96196bb1a6ac4510c5985d60021143496c58bfa3bafdba9f6fc65cf95ee1
3
+ size 4941283704
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:570d4618d0bc0ffab9a8c03b08c1ca3f3d279e6d6f7f87ceaed7055cc95c6ed7
3
+ size 4930429456
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae6ca76ffa01e0c1a13d0b974c64e1f2a76bfda734ea40149c46a9fdc42b4de9
3
+ size 1494870972
model.safetensors.index.json ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 11366512772
4
+ },
5
+ "weight_map": {
6
+ "contact_head.regression.bias": "model-00003-of-00003.safetensors",
7
+ "contact_head.regression.weight": "model-00003-of-00003.safetensors",
8
+ "embeddings.position_embeddings.weight": "model-00001-of-00003.safetensors",
9
+ "embeddings.word_embeddings.weight": "model-00001-of-00003.safetensors",
10
+ "encoder.emb_layer_norm_after.bias": "model-00003-of-00003.safetensors",
11
+ "encoder.emb_layer_norm_after.weight": "model-00003-of-00003.safetensors",
12
+ "encoder.layer.0.LayerNorm.bias": "model-00001-of-00003.safetensors",
13
+ "encoder.layer.0.LayerNorm.weight": "model-00001-of-00003.safetensors",
14
+ "encoder.layer.0.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
15
+ "encoder.layer.0.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
16
+ "encoder.layer.0.attention.output.dense.bias": "model-00001-of-00003.safetensors",
17
+ "encoder.layer.0.attention.output.dense.weight": "model-00001-of-00003.safetensors",
18
+ "encoder.layer.0.attention.self.key.bias": "model-00001-of-00003.safetensors",
19
+ "encoder.layer.0.attention.self.key.weight": "model-00001-of-00003.safetensors",
20
+ "encoder.layer.0.attention.self.query.bias": "model-00001-of-00003.safetensors",
21
+ "encoder.layer.0.attention.self.query.weight": "model-00001-of-00003.safetensors",
22
+ "encoder.layer.0.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
23
+ "encoder.layer.0.attention.self.value.bias": "model-00001-of-00003.safetensors",
24
+ "encoder.layer.0.attention.self.value.weight": "model-00001-of-00003.safetensors",
25
+ "encoder.layer.0.intermediate.dense.bias": "model-00001-of-00003.safetensors",
26
+ "encoder.layer.0.intermediate.dense.weight": "model-00001-of-00003.safetensors",
27
+ "encoder.layer.0.output.dense.bias": "model-00001-of-00003.safetensors",
28
+ "encoder.layer.0.output.dense.weight": "model-00001-of-00003.safetensors",
29
+ "encoder.layer.1.LayerNorm.bias": "model-00001-of-00003.safetensors",
30
+ "encoder.layer.1.LayerNorm.weight": "model-00001-of-00003.safetensors",
31
+ "encoder.layer.1.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
32
+ "encoder.layer.1.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
33
+ "encoder.layer.1.attention.output.dense.bias": "model-00001-of-00003.safetensors",
34
+ "encoder.layer.1.attention.output.dense.weight": "model-00001-of-00003.safetensors",
35
+ "encoder.layer.1.attention.self.key.bias": "model-00001-of-00003.safetensors",
36
+ "encoder.layer.1.attention.self.key.weight": "model-00001-of-00003.safetensors",
37
+ "encoder.layer.1.attention.self.query.bias": "model-00001-of-00003.safetensors",
38
+ "encoder.layer.1.attention.self.query.weight": "model-00001-of-00003.safetensors",
39
+ "encoder.layer.1.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
40
+ "encoder.layer.1.attention.self.value.bias": "model-00001-of-00003.safetensors",
41
+ "encoder.layer.1.attention.self.value.weight": "model-00001-of-00003.safetensors",
42
+ "encoder.layer.1.intermediate.dense.bias": "model-00001-of-00003.safetensors",
43
+ "encoder.layer.1.intermediate.dense.weight": "model-00001-of-00003.safetensors",
44
+ "encoder.layer.1.output.dense.bias": "model-00001-of-00003.safetensors",
45
+ "encoder.layer.1.output.dense.weight": "model-00001-of-00003.safetensors",
46
+ "encoder.layer.10.LayerNorm.bias": "model-00001-of-00003.safetensors",
47
+ "encoder.layer.10.LayerNorm.weight": "model-00001-of-00003.safetensors",
48
+ "encoder.layer.10.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
49
+ "encoder.layer.10.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
50
+ "encoder.layer.10.attention.output.dense.bias": "model-00001-of-00003.safetensors",
51
+ "encoder.layer.10.attention.output.dense.weight": "model-00001-of-00003.safetensors",
52
+ "encoder.layer.10.attention.self.key.bias": "model-00001-of-00003.safetensors",
53
+ "encoder.layer.10.attention.self.key.weight": "model-00001-of-00003.safetensors",
54
+ "encoder.layer.10.attention.self.query.bias": "model-00001-of-00003.safetensors",
55
+ "encoder.layer.10.attention.self.query.weight": "model-00001-of-00003.safetensors",
56
+ "encoder.layer.10.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
57
+ "encoder.layer.10.attention.self.value.bias": "model-00001-of-00003.safetensors",
58
+ "encoder.layer.10.attention.self.value.weight": "model-00001-of-00003.safetensors",
59
+ "encoder.layer.10.intermediate.dense.bias": "model-00001-of-00003.safetensors",
60
+ "encoder.layer.10.intermediate.dense.weight": "model-00001-of-00003.safetensors",
61
+ "encoder.layer.10.output.dense.bias": "model-00001-of-00003.safetensors",
62
+ "encoder.layer.10.output.dense.weight": "model-00001-of-00003.safetensors",
63
+ "encoder.layer.11.LayerNorm.bias": "model-00001-of-00003.safetensors",
64
+ "encoder.layer.11.LayerNorm.weight": "model-00001-of-00003.safetensors",
65
+ "encoder.layer.11.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
66
+ "encoder.layer.11.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
67
+ "encoder.layer.11.attention.output.dense.bias": "model-00001-of-00003.safetensors",
68
+ "encoder.layer.11.attention.output.dense.weight": "model-00001-of-00003.safetensors",
69
+ "encoder.layer.11.attention.self.key.bias": "model-00001-of-00003.safetensors",
70
+ "encoder.layer.11.attention.self.key.weight": "model-00001-of-00003.safetensors",
71
+ "encoder.layer.11.attention.self.query.bias": "model-00001-of-00003.safetensors",
72
+ "encoder.layer.11.attention.self.query.weight": "model-00001-of-00003.safetensors",
73
+ "encoder.layer.11.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
74
+ "encoder.layer.11.attention.self.value.bias": "model-00001-of-00003.safetensors",
75
+ "encoder.layer.11.attention.self.value.weight": "model-00001-of-00003.safetensors",
76
+ "encoder.layer.11.intermediate.dense.bias": "model-00001-of-00003.safetensors",
77
+ "encoder.layer.11.intermediate.dense.weight": "model-00001-of-00003.safetensors",
78
+ "encoder.layer.11.output.dense.bias": "model-00001-of-00003.safetensors",
79
+ "encoder.layer.11.output.dense.weight": "model-00001-of-00003.safetensors",
80
+ "encoder.layer.12.LayerNorm.bias": "model-00001-of-00003.safetensors",
81
+ "encoder.layer.12.LayerNorm.weight": "model-00001-of-00003.safetensors",
82
+ "encoder.layer.12.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
83
+ "encoder.layer.12.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
84
+ "encoder.layer.12.attention.output.dense.bias": "model-00001-of-00003.safetensors",
85
+ "encoder.layer.12.attention.output.dense.weight": "model-00001-of-00003.safetensors",
86
+ "encoder.layer.12.attention.self.key.bias": "model-00001-of-00003.safetensors",
87
+ "encoder.layer.12.attention.self.key.weight": "model-00001-of-00003.safetensors",
88
+ "encoder.layer.12.attention.self.query.bias": "model-00001-of-00003.safetensors",
89
+ "encoder.layer.12.attention.self.query.weight": "model-00001-of-00003.safetensors",
90
+ "encoder.layer.12.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
91
+ "encoder.layer.12.attention.self.value.bias": "model-00001-of-00003.safetensors",
92
+ "encoder.layer.12.attention.self.value.weight": "model-00001-of-00003.safetensors",
93
+ "encoder.layer.12.intermediate.dense.bias": "model-00001-of-00003.safetensors",
94
+ "encoder.layer.12.intermediate.dense.weight": "model-00001-of-00003.safetensors",
95
+ "encoder.layer.12.output.dense.bias": "model-00001-of-00003.safetensors",
96
+ "encoder.layer.12.output.dense.weight": "model-00001-of-00003.safetensors",
97
+ "encoder.layer.13.LayerNorm.bias": "model-00001-of-00003.safetensors",
98
+ "encoder.layer.13.LayerNorm.weight": "model-00001-of-00003.safetensors",
99
+ "encoder.layer.13.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
100
+ "encoder.layer.13.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
101
+ "encoder.layer.13.attention.output.dense.bias": "model-00001-of-00003.safetensors",
102
+ "encoder.layer.13.attention.output.dense.weight": "model-00001-of-00003.safetensors",
103
+ "encoder.layer.13.attention.self.key.bias": "model-00001-of-00003.safetensors",
104
+ "encoder.layer.13.attention.self.key.weight": "model-00001-of-00003.safetensors",
105
+ "encoder.layer.13.attention.self.query.bias": "model-00001-of-00003.safetensors",
106
+ "encoder.layer.13.attention.self.query.weight": "model-00001-of-00003.safetensors",
107
+ "encoder.layer.13.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
108
+ "encoder.layer.13.attention.self.value.bias": "model-00001-of-00003.safetensors",
109
+ "encoder.layer.13.attention.self.value.weight": "model-00001-of-00003.safetensors",
110
+ "encoder.layer.13.intermediate.dense.bias": "model-00001-of-00003.safetensors",
111
+ "encoder.layer.13.intermediate.dense.weight": "model-00001-of-00003.safetensors",
112
+ "encoder.layer.13.output.dense.bias": "model-00001-of-00003.safetensors",
113
+ "encoder.layer.13.output.dense.weight": "model-00001-of-00003.safetensors",
114
+ "encoder.layer.14.LayerNorm.bias": "model-00001-of-00003.safetensors",
115
+ "encoder.layer.14.LayerNorm.weight": "model-00001-of-00003.safetensors",
116
+ "encoder.layer.14.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
117
+ "encoder.layer.14.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
118
+ "encoder.layer.14.attention.output.dense.bias": "model-00001-of-00003.safetensors",
119
+ "encoder.layer.14.attention.output.dense.weight": "model-00001-of-00003.safetensors",
120
+ "encoder.layer.14.attention.self.key.bias": "model-00001-of-00003.safetensors",
121
+ "encoder.layer.14.attention.self.key.weight": "model-00001-of-00003.safetensors",
122
+ "encoder.layer.14.attention.self.query.bias": "model-00001-of-00003.safetensors",
123
+ "encoder.layer.14.attention.self.query.weight": "model-00001-of-00003.safetensors",
124
+ "encoder.layer.14.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
125
+ "encoder.layer.14.attention.self.value.bias": "model-00001-of-00003.safetensors",
126
+ "encoder.layer.14.attention.self.value.weight": "model-00001-of-00003.safetensors",
127
+ "encoder.layer.14.intermediate.dense.bias": "model-00001-of-00003.safetensors",
128
+ "encoder.layer.14.intermediate.dense.weight": "model-00001-of-00003.safetensors",
129
+ "encoder.layer.14.output.dense.bias": "model-00001-of-00003.safetensors",
130
+ "encoder.layer.14.output.dense.weight": "model-00001-of-00003.safetensors",
131
+ "encoder.layer.15.LayerNorm.bias": "model-00002-of-00003.safetensors",
132
+ "encoder.layer.15.LayerNorm.weight": "model-00002-of-00003.safetensors",
133
+ "encoder.layer.15.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
134
+ "encoder.layer.15.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
135
+ "encoder.layer.15.attention.output.dense.bias": "model-00001-of-00003.safetensors",
136
+ "encoder.layer.15.attention.output.dense.weight": "model-00001-of-00003.safetensors",
137
+ "encoder.layer.15.attention.self.key.bias": "model-00001-of-00003.safetensors",
138
+ "encoder.layer.15.attention.self.key.weight": "model-00001-of-00003.safetensors",
139
+ "encoder.layer.15.attention.self.query.bias": "model-00001-of-00003.safetensors",
140
+ "encoder.layer.15.attention.self.query.weight": "model-00001-of-00003.safetensors",
141
+ "encoder.layer.15.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
142
+ "encoder.layer.15.attention.self.value.bias": "model-00001-of-00003.safetensors",
143
+ "encoder.layer.15.attention.self.value.weight": "model-00001-of-00003.safetensors",
144
+ "encoder.layer.15.intermediate.dense.bias": "model-00001-of-00003.safetensors",
145
+ "encoder.layer.15.intermediate.dense.weight": "model-00001-of-00003.safetensors",
146
+ "encoder.layer.15.output.dense.bias": "model-00002-of-00003.safetensors",
147
+ "encoder.layer.15.output.dense.weight": "model-00002-of-00003.safetensors",
148
+ "encoder.layer.16.LayerNorm.bias": "model-00002-of-00003.safetensors",
149
+ "encoder.layer.16.LayerNorm.weight": "model-00002-of-00003.safetensors",
150
+ "encoder.layer.16.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
151
+ "encoder.layer.16.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
152
+ "encoder.layer.16.attention.output.dense.bias": "model-00002-of-00003.safetensors",
153
+ "encoder.layer.16.attention.output.dense.weight": "model-00002-of-00003.safetensors",
154
+ "encoder.layer.16.attention.self.key.bias": "model-00002-of-00003.safetensors",
155
+ "encoder.layer.16.attention.self.key.weight": "model-00002-of-00003.safetensors",
156
+ "encoder.layer.16.attention.self.query.bias": "model-00002-of-00003.safetensors",
157
+ "encoder.layer.16.attention.self.query.weight": "model-00002-of-00003.safetensors",
158
+ "encoder.layer.16.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
159
+ "encoder.layer.16.attention.self.value.bias": "model-00002-of-00003.safetensors",
160
+ "encoder.layer.16.attention.self.value.weight": "model-00002-of-00003.safetensors",
161
+ "encoder.layer.16.intermediate.dense.bias": "model-00002-of-00003.safetensors",
162
+ "encoder.layer.16.intermediate.dense.weight": "model-00002-of-00003.safetensors",
163
+ "encoder.layer.16.output.dense.bias": "model-00002-of-00003.safetensors",
164
+ "encoder.layer.16.output.dense.weight": "model-00002-of-00003.safetensors",
165
+ "encoder.layer.17.LayerNorm.bias": "model-00002-of-00003.safetensors",
166
+ "encoder.layer.17.LayerNorm.weight": "model-00002-of-00003.safetensors",
167
+ "encoder.layer.17.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
168
+ "encoder.layer.17.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
169
+ "encoder.layer.17.attention.output.dense.bias": "model-00002-of-00003.safetensors",
170
+ "encoder.layer.17.attention.output.dense.weight": "model-00002-of-00003.safetensors",
171
+ "encoder.layer.17.attention.self.key.bias": "model-00002-of-00003.safetensors",
172
+ "encoder.layer.17.attention.self.key.weight": "model-00002-of-00003.safetensors",
173
+ "encoder.layer.17.attention.self.query.bias": "model-00002-of-00003.safetensors",
174
+ "encoder.layer.17.attention.self.query.weight": "model-00002-of-00003.safetensors",
175
+ "encoder.layer.17.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
176
+ "encoder.layer.17.attention.self.value.bias": "model-00002-of-00003.safetensors",
177
+ "encoder.layer.17.attention.self.value.weight": "model-00002-of-00003.safetensors",
178
+ "encoder.layer.17.intermediate.dense.bias": "model-00002-of-00003.safetensors",
179
+ "encoder.layer.17.intermediate.dense.weight": "model-00002-of-00003.safetensors",
180
+ "encoder.layer.17.output.dense.bias": "model-00002-of-00003.safetensors",
181
+ "encoder.layer.17.output.dense.weight": "model-00002-of-00003.safetensors",
182
+ "encoder.layer.18.LayerNorm.bias": "model-00002-of-00003.safetensors",
183
+ "encoder.layer.18.LayerNorm.weight": "model-00002-of-00003.safetensors",
184
+ "encoder.layer.18.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
185
+ "encoder.layer.18.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
186
+ "encoder.layer.18.attention.output.dense.bias": "model-00002-of-00003.safetensors",
187
+ "encoder.layer.18.attention.output.dense.weight": "model-00002-of-00003.safetensors",
188
+ "encoder.layer.18.attention.self.key.bias": "model-00002-of-00003.safetensors",
189
+ "encoder.layer.18.attention.self.key.weight": "model-00002-of-00003.safetensors",
190
+ "encoder.layer.18.attention.self.query.bias": "model-00002-of-00003.safetensors",
191
+ "encoder.layer.18.attention.self.query.weight": "model-00002-of-00003.safetensors",
192
+ "encoder.layer.18.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
193
+ "encoder.layer.18.attention.self.value.bias": "model-00002-of-00003.safetensors",
194
+ "encoder.layer.18.attention.self.value.weight": "model-00002-of-00003.safetensors",
195
+ "encoder.layer.18.intermediate.dense.bias": "model-00002-of-00003.safetensors",
196
+ "encoder.layer.18.intermediate.dense.weight": "model-00002-of-00003.safetensors",
197
+ "encoder.layer.18.output.dense.bias": "model-00002-of-00003.safetensors",
198
+ "encoder.layer.18.output.dense.weight": "model-00002-of-00003.safetensors",
199
+ "encoder.layer.19.LayerNorm.bias": "model-00002-of-00003.safetensors",
200
+ "encoder.layer.19.LayerNorm.weight": "model-00002-of-00003.safetensors",
201
+ "encoder.layer.19.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
202
+ "encoder.layer.19.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
203
+ "encoder.layer.19.attention.output.dense.bias": "model-00002-of-00003.safetensors",
204
+ "encoder.layer.19.attention.output.dense.weight": "model-00002-of-00003.safetensors",
205
+ "encoder.layer.19.attention.self.key.bias": "model-00002-of-00003.safetensors",
206
+ "encoder.layer.19.attention.self.key.weight": "model-00002-of-00003.safetensors",
207
+ "encoder.layer.19.attention.self.query.bias": "model-00002-of-00003.safetensors",
208
+ "encoder.layer.19.attention.self.query.weight": "model-00002-of-00003.safetensors",
209
+ "encoder.layer.19.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
210
+ "encoder.layer.19.attention.self.value.bias": "model-00002-of-00003.safetensors",
211
+ "encoder.layer.19.attention.self.value.weight": "model-00002-of-00003.safetensors",
212
+ "encoder.layer.19.intermediate.dense.bias": "model-00002-of-00003.safetensors",
213
+ "encoder.layer.19.intermediate.dense.weight": "model-00002-of-00003.safetensors",
214
+ "encoder.layer.19.output.dense.bias": "model-00002-of-00003.safetensors",
215
+ "encoder.layer.19.output.dense.weight": "model-00002-of-00003.safetensors",
216
+ "encoder.layer.2.LayerNorm.bias": "model-00001-of-00003.safetensors",
217
+ "encoder.layer.2.LayerNorm.weight": "model-00001-of-00003.safetensors",
218
+ "encoder.layer.2.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
219
+ "encoder.layer.2.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
220
+ "encoder.layer.2.attention.output.dense.bias": "model-00001-of-00003.safetensors",
221
+ "encoder.layer.2.attention.output.dense.weight": "model-00001-of-00003.safetensors",
222
+ "encoder.layer.2.attention.self.key.bias": "model-00001-of-00003.safetensors",
223
+ "encoder.layer.2.attention.self.key.weight": "model-00001-of-00003.safetensors",
224
+ "encoder.layer.2.attention.self.query.bias": "model-00001-of-00003.safetensors",
225
+ "encoder.layer.2.attention.self.query.weight": "model-00001-of-00003.safetensors",
226
+ "encoder.layer.2.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
227
+ "encoder.layer.2.attention.self.value.bias": "model-00001-of-00003.safetensors",
228
+ "encoder.layer.2.attention.self.value.weight": "model-00001-of-00003.safetensors",
229
+ "encoder.layer.2.intermediate.dense.bias": "model-00001-of-00003.safetensors",
230
+ "encoder.layer.2.intermediate.dense.weight": "model-00001-of-00003.safetensors",
231
+ "encoder.layer.2.output.dense.bias": "model-00001-of-00003.safetensors",
232
+ "encoder.layer.2.output.dense.weight": "model-00001-of-00003.safetensors",
233
+ "encoder.layer.20.LayerNorm.bias": "model-00002-of-00003.safetensors",
234
+ "encoder.layer.20.LayerNorm.weight": "model-00002-of-00003.safetensors",
235
+ "encoder.layer.20.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
236
+ "encoder.layer.20.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
237
+ "encoder.layer.20.attention.output.dense.bias": "model-00002-of-00003.safetensors",
238
+ "encoder.layer.20.attention.output.dense.weight": "model-00002-of-00003.safetensors",
239
+ "encoder.layer.20.attention.self.key.bias": "model-00002-of-00003.safetensors",
240
+ "encoder.layer.20.attention.self.key.weight": "model-00002-of-00003.safetensors",
241
+ "encoder.layer.20.attention.self.query.bias": "model-00002-of-00003.safetensors",
242
+ "encoder.layer.20.attention.self.query.weight": "model-00002-of-00003.safetensors",
243
+ "encoder.layer.20.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
244
+ "encoder.layer.20.attention.self.value.bias": "model-00002-of-00003.safetensors",
245
+ "encoder.layer.20.attention.self.value.weight": "model-00002-of-00003.safetensors",
246
+ "encoder.layer.20.intermediate.dense.bias": "model-00002-of-00003.safetensors",
247
+ "encoder.layer.20.intermediate.dense.weight": "model-00002-of-00003.safetensors",
248
+ "encoder.layer.20.output.dense.bias": "model-00002-of-00003.safetensors",
249
+ "encoder.layer.20.output.dense.weight": "model-00002-of-00003.safetensors",
250
+ "encoder.layer.21.LayerNorm.bias": "model-00002-of-00003.safetensors",
251
+ "encoder.layer.21.LayerNorm.weight": "model-00002-of-00003.safetensors",
252
+ "encoder.layer.21.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
253
+ "encoder.layer.21.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
254
+ "encoder.layer.21.attention.output.dense.bias": "model-00002-of-00003.safetensors",
255
+ "encoder.layer.21.attention.output.dense.weight": "model-00002-of-00003.safetensors",
256
+ "encoder.layer.21.attention.self.key.bias": "model-00002-of-00003.safetensors",
257
+ "encoder.layer.21.attention.self.key.weight": "model-00002-of-00003.safetensors",
258
+ "encoder.layer.21.attention.self.query.bias": "model-00002-of-00003.safetensors",
259
+ "encoder.layer.21.attention.self.query.weight": "model-00002-of-00003.safetensors",
260
+ "encoder.layer.21.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
261
+ "encoder.layer.21.attention.self.value.bias": "model-00002-of-00003.safetensors",
262
+ "encoder.layer.21.attention.self.value.weight": "model-00002-of-00003.safetensors",
263
+ "encoder.layer.21.intermediate.dense.bias": "model-00002-of-00003.safetensors",
264
+ "encoder.layer.21.intermediate.dense.weight": "model-00002-of-00003.safetensors",
265
+ "encoder.layer.21.output.dense.bias": "model-00002-of-00003.safetensors",
266
+ "encoder.layer.21.output.dense.weight": "model-00002-of-00003.safetensors",
267
+ "encoder.layer.22.LayerNorm.bias": "model-00002-of-00003.safetensors",
268
+ "encoder.layer.22.LayerNorm.weight": "model-00002-of-00003.safetensors",
269
+ "encoder.layer.22.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
270
+ "encoder.layer.22.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
271
+ "encoder.layer.22.attention.output.dense.bias": "model-00002-of-00003.safetensors",
272
+ "encoder.layer.22.attention.output.dense.weight": "model-00002-of-00003.safetensors",
273
+ "encoder.layer.22.attention.self.key.bias": "model-00002-of-00003.safetensors",
274
+ "encoder.layer.22.attention.self.key.weight": "model-00002-of-00003.safetensors",
275
+ "encoder.layer.22.attention.self.query.bias": "model-00002-of-00003.safetensors",
276
+ "encoder.layer.22.attention.self.query.weight": "model-00002-of-00003.safetensors",
277
+ "encoder.layer.22.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
278
+ "encoder.layer.22.attention.self.value.bias": "model-00002-of-00003.safetensors",
279
+ "encoder.layer.22.attention.self.value.weight": "model-00002-of-00003.safetensors",
280
+ "encoder.layer.22.intermediate.dense.bias": "model-00002-of-00003.safetensors",
281
+ "encoder.layer.22.intermediate.dense.weight": "model-00002-of-00003.safetensors",
282
+ "encoder.layer.22.output.dense.bias": "model-00002-of-00003.safetensors",
283
+ "encoder.layer.22.output.dense.weight": "model-00002-of-00003.safetensors",
284
+ "encoder.layer.23.LayerNorm.bias": "model-00002-of-00003.safetensors",
285
+ "encoder.layer.23.LayerNorm.weight": "model-00002-of-00003.safetensors",
286
+ "encoder.layer.23.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
287
+ "encoder.layer.23.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
288
+ "encoder.layer.23.attention.output.dense.bias": "model-00002-of-00003.safetensors",
289
+ "encoder.layer.23.attention.output.dense.weight": "model-00002-of-00003.safetensors",
290
+ "encoder.layer.23.attention.self.key.bias": "model-00002-of-00003.safetensors",
291
+ "encoder.layer.23.attention.self.key.weight": "model-00002-of-00003.safetensors",
292
+ "encoder.layer.23.attention.self.query.bias": "model-00002-of-00003.safetensors",
293
+ "encoder.layer.23.attention.self.query.weight": "model-00002-of-00003.safetensors",
294
+ "encoder.layer.23.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
295
+ "encoder.layer.23.attention.self.value.bias": "model-00002-of-00003.safetensors",
296
+ "encoder.layer.23.attention.self.value.weight": "model-00002-of-00003.safetensors",
297
+ "encoder.layer.23.intermediate.dense.bias": "model-00002-of-00003.safetensors",
298
+ "encoder.layer.23.intermediate.dense.weight": "model-00002-of-00003.safetensors",
299
+ "encoder.layer.23.output.dense.bias": "model-00002-of-00003.safetensors",
300
+ "encoder.layer.23.output.dense.weight": "model-00002-of-00003.safetensors",
301
+ "encoder.layer.24.LayerNorm.bias": "model-00002-of-00003.safetensors",
302
+ "encoder.layer.24.LayerNorm.weight": "model-00002-of-00003.safetensors",
303
+ "encoder.layer.24.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
304
+ "encoder.layer.24.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
305
+ "encoder.layer.24.attention.output.dense.bias": "model-00002-of-00003.safetensors",
306
+ "encoder.layer.24.attention.output.dense.weight": "model-00002-of-00003.safetensors",
307
+ "encoder.layer.24.attention.self.key.bias": "model-00002-of-00003.safetensors",
308
+ "encoder.layer.24.attention.self.key.weight": "model-00002-of-00003.safetensors",
309
+ "encoder.layer.24.attention.self.query.bias": "model-00002-of-00003.safetensors",
310
+ "encoder.layer.24.attention.self.query.weight": "model-00002-of-00003.safetensors",
311
+ "encoder.layer.24.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
312
+ "encoder.layer.24.attention.self.value.bias": "model-00002-of-00003.safetensors",
313
+ "encoder.layer.24.attention.self.value.weight": "model-00002-of-00003.safetensors",
314
+ "encoder.layer.24.intermediate.dense.bias": "model-00002-of-00003.safetensors",
315
+ "encoder.layer.24.intermediate.dense.weight": "model-00002-of-00003.safetensors",
316
+ "encoder.layer.24.output.dense.bias": "model-00002-of-00003.safetensors",
317
+ "encoder.layer.24.output.dense.weight": "model-00002-of-00003.safetensors",
318
+ "encoder.layer.25.LayerNorm.bias": "model-00002-of-00003.safetensors",
319
+ "encoder.layer.25.LayerNorm.weight": "model-00002-of-00003.safetensors",
320
+ "encoder.layer.25.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
321
+ "encoder.layer.25.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
322
+ "encoder.layer.25.attention.output.dense.bias": "model-00002-of-00003.safetensors",
323
+ "encoder.layer.25.attention.output.dense.weight": "model-00002-of-00003.safetensors",
324
+ "encoder.layer.25.attention.self.key.bias": "model-00002-of-00003.safetensors",
325
+ "encoder.layer.25.attention.self.key.weight": "model-00002-of-00003.safetensors",
326
+ "encoder.layer.25.attention.self.query.bias": "model-00002-of-00003.safetensors",
327
+ "encoder.layer.25.attention.self.query.weight": "model-00002-of-00003.safetensors",
328
+ "encoder.layer.25.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
329
+ "encoder.layer.25.attention.self.value.bias": "model-00002-of-00003.safetensors",
330
+ "encoder.layer.25.attention.self.value.weight": "model-00002-of-00003.safetensors",
331
+ "encoder.layer.25.intermediate.dense.bias": "model-00002-of-00003.safetensors",
332
+ "encoder.layer.25.intermediate.dense.weight": "model-00002-of-00003.safetensors",
333
+ "encoder.layer.25.output.dense.bias": "model-00002-of-00003.safetensors",
334
+ "encoder.layer.25.output.dense.weight": "model-00002-of-00003.safetensors",
335
+ "encoder.layer.26.LayerNorm.bias": "model-00002-of-00003.safetensors",
336
+ "encoder.layer.26.LayerNorm.weight": "model-00002-of-00003.safetensors",
337
+ "encoder.layer.26.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
338
+ "encoder.layer.26.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
339
+ "encoder.layer.26.attention.output.dense.bias": "model-00002-of-00003.safetensors",
340
+ "encoder.layer.26.attention.output.dense.weight": "model-00002-of-00003.safetensors",
341
+ "encoder.layer.26.attention.self.key.bias": "model-00002-of-00003.safetensors",
342
+ "encoder.layer.26.attention.self.key.weight": "model-00002-of-00003.safetensors",
343
+ "encoder.layer.26.attention.self.query.bias": "model-00002-of-00003.safetensors",
344
+ "encoder.layer.26.attention.self.query.weight": "model-00002-of-00003.safetensors",
345
+ "encoder.layer.26.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
346
+ "encoder.layer.26.attention.self.value.bias": "model-00002-of-00003.safetensors",
347
+ "encoder.layer.26.attention.self.value.weight": "model-00002-of-00003.safetensors",
348
+ "encoder.layer.26.intermediate.dense.bias": "model-00002-of-00003.safetensors",
349
+ "encoder.layer.26.intermediate.dense.weight": "model-00002-of-00003.safetensors",
350
+ "encoder.layer.26.output.dense.bias": "model-00002-of-00003.safetensors",
351
+ "encoder.layer.26.output.dense.weight": "model-00002-of-00003.safetensors",
352
+ "encoder.layer.27.LayerNorm.bias": "model-00002-of-00003.safetensors",
353
+ "encoder.layer.27.LayerNorm.weight": "model-00002-of-00003.safetensors",
354
+ "encoder.layer.27.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
355
+ "encoder.layer.27.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
356
+ "encoder.layer.27.attention.output.dense.bias": "model-00002-of-00003.safetensors",
357
+ "encoder.layer.27.attention.output.dense.weight": "model-00002-of-00003.safetensors",
358
+ "encoder.layer.27.attention.self.key.bias": "model-00002-of-00003.safetensors",
359
+ "encoder.layer.27.attention.self.key.weight": "model-00002-of-00003.safetensors",
360
+ "encoder.layer.27.attention.self.query.bias": "model-00002-of-00003.safetensors",
361
+ "encoder.layer.27.attention.self.query.weight": "model-00002-of-00003.safetensors",
362
+ "encoder.layer.27.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
363
+ "encoder.layer.27.attention.self.value.bias": "model-00002-of-00003.safetensors",
364
+ "encoder.layer.27.attention.self.value.weight": "model-00002-of-00003.safetensors",
365
+ "encoder.layer.27.intermediate.dense.bias": "model-00002-of-00003.safetensors",
366
+ "encoder.layer.27.intermediate.dense.weight": "model-00002-of-00003.safetensors",
367
+ "encoder.layer.27.output.dense.bias": "model-00002-of-00003.safetensors",
368
+ "encoder.layer.27.output.dense.weight": "model-00002-of-00003.safetensors",
369
+ "encoder.layer.28.LayerNorm.bias": "model-00002-of-00003.safetensors",
370
+ "encoder.layer.28.LayerNorm.weight": "model-00002-of-00003.safetensors",
371
+ "encoder.layer.28.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
372
+ "encoder.layer.28.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
373
+ "encoder.layer.28.attention.output.dense.bias": "model-00002-of-00003.safetensors",
374
+ "encoder.layer.28.attention.output.dense.weight": "model-00002-of-00003.safetensors",
375
+ "encoder.layer.28.attention.self.key.bias": "model-00002-of-00003.safetensors",
376
+ "encoder.layer.28.attention.self.key.weight": "model-00002-of-00003.safetensors",
377
+ "encoder.layer.28.attention.self.query.bias": "model-00002-of-00003.safetensors",
378
+ "encoder.layer.28.attention.self.query.weight": "model-00002-of-00003.safetensors",
379
+ "encoder.layer.28.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
380
+ "encoder.layer.28.attention.self.value.bias": "model-00002-of-00003.safetensors",
381
+ "encoder.layer.28.attention.self.value.weight": "model-00002-of-00003.safetensors",
382
+ "encoder.layer.28.intermediate.dense.bias": "model-00002-of-00003.safetensors",
383
+ "encoder.layer.28.intermediate.dense.weight": "model-00002-of-00003.safetensors",
384
+ "encoder.layer.28.output.dense.bias": "model-00002-of-00003.safetensors",
385
+ "encoder.layer.28.output.dense.weight": "model-00002-of-00003.safetensors",
386
+ "encoder.layer.29.LayerNorm.bias": "model-00002-of-00003.safetensors",
387
+ "encoder.layer.29.LayerNorm.weight": "model-00002-of-00003.safetensors",
388
+ "encoder.layer.29.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
389
+ "encoder.layer.29.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
390
+ "encoder.layer.29.attention.output.dense.bias": "model-00002-of-00003.safetensors",
391
+ "encoder.layer.29.attention.output.dense.weight": "model-00002-of-00003.safetensors",
392
+ "encoder.layer.29.attention.self.key.bias": "model-00002-of-00003.safetensors",
393
+ "encoder.layer.29.attention.self.key.weight": "model-00002-of-00003.safetensors",
394
+ "encoder.layer.29.attention.self.query.bias": "model-00002-of-00003.safetensors",
395
+ "encoder.layer.29.attention.self.query.weight": "model-00002-of-00003.safetensors",
396
+ "encoder.layer.29.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
397
+ "encoder.layer.29.attention.self.value.bias": "model-00002-of-00003.safetensors",
398
+ "encoder.layer.29.attention.self.value.weight": "model-00002-of-00003.safetensors",
399
+ "encoder.layer.29.intermediate.dense.bias": "model-00002-of-00003.safetensors",
400
+ "encoder.layer.29.intermediate.dense.weight": "model-00002-of-00003.safetensors",
401
+ "encoder.layer.29.output.dense.bias": "model-00002-of-00003.safetensors",
402
+ "encoder.layer.29.output.dense.weight": "model-00002-of-00003.safetensors",
403
+ "encoder.layer.3.LayerNorm.bias": "model-00001-of-00003.safetensors",
404
+ "encoder.layer.3.LayerNorm.weight": "model-00001-of-00003.safetensors",
405
+ "encoder.layer.3.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
406
+ "encoder.layer.3.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
407
+ "encoder.layer.3.attention.output.dense.bias": "model-00001-of-00003.safetensors",
408
+ "encoder.layer.3.attention.output.dense.weight": "model-00001-of-00003.safetensors",
409
+ "encoder.layer.3.attention.self.key.bias": "model-00001-of-00003.safetensors",
410
+ "encoder.layer.3.attention.self.key.weight": "model-00001-of-00003.safetensors",
411
+ "encoder.layer.3.attention.self.query.bias": "model-00001-of-00003.safetensors",
412
+ "encoder.layer.3.attention.self.query.weight": "model-00001-of-00003.safetensors",
413
+ "encoder.layer.3.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
414
+ "encoder.layer.3.attention.self.value.bias": "model-00001-of-00003.safetensors",
415
+ "encoder.layer.3.attention.self.value.weight": "model-00001-of-00003.safetensors",
416
+ "encoder.layer.3.intermediate.dense.bias": "model-00001-of-00003.safetensors",
417
+ "encoder.layer.3.intermediate.dense.weight": "model-00001-of-00003.safetensors",
418
+ "encoder.layer.3.output.dense.bias": "model-00001-of-00003.safetensors",
419
+ "encoder.layer.3.output.dense.weight": "model-00001-of-00003.safetensors",
420
+ "encoder.layer.30.LayerNorm.bias": "model-00002-of-00003.safetensors",
421
+ "encoder.layer.30.LayerNorm.weight": "model-00002-of-00003.safetensors",
422
+ "encoder.layer.30.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
423
+ "encoder.layer.30.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
424
+ "encoder.layer.30.attention.output.dense.bias": "model-00002-of-00003.safetensors",
425
+ "encoder.layer.30.attention.output.dense.weight": "model-00002-of-00003.safetensors",
426
+ "encoder.layer.30.attention.self.key.bias": "model-00002-of-00003.safetensors",
427
+ "encoder.layer.30.attention.self.key.weight": "model-00002-of-00003.safetensors",
428
+ "encoder.layer.30.attention.self.query.bias": "model-00002-of-00003.safetensors",
429
+ "encoder.layer.30.attention.self.query.weight": "model-00002-of-00003.safetensors",
430
+ "encoder.layer.30.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
431
+ "encoder.layer.30.attention.self.value.bias": "model-00002-of-00003.safetensors",
432
+ "encoder.layer.30.attention.self.value.weight": "model-00002-of-00003.safetensors",
433
+ "encoder.layer.30.intermediate.dense.bias": "model-00002-of-00003.safetensors",
434
+ "encoder.layer.30.intermediate.dense.weight": "model-00002-of-00003.safetensors",
435
+ "encoder.layer.30.output.dense.bias": "model-00002-of-00003.safetensors",
436
+ "encoder.layer.30.output.dense.weight": "model-00002-of-00003.safetensors",
437
+ "encoder.layer.31.LayerNorm.bias": "model-00003-of-00003.safetensors",
438
+ "encoder.layer.31.LayerNorm.weight": "model-00003-of-00003.safetensors",
439
+ "encoder.layer.31.attention.LayerNorm.bias": "model-00002-of-00003.safetensors",
440
+ "encoder.layer.31.attention.LayerNorm.weight": "model-00002-of-00003.safetensors",
441
+ "encoder.layer.31.attention.output.dense.bias": "model-00002-of-00003.safetensors",
442
+ "encoder.layer.31.attention.output.dense.weight": "model-00002-of-00003.safetensors",
443
+ "encoder.layer.31.attention.self.key.bias": "model-00002-of-00003.safetensors",
444
+ "encoder.layer.31.attention.self.key.weight": "model-00002-of-00003.safetensors",
445
+ "encoder.layer.31.attention.self.query.bias": "model-00002-of-00003.safetensors",
446
+ "encoder.layer.31.attention.self.query.weight": "model-00002-of-00003.safetensors",
447
+ "encoder.layer.31.attention.self.rotary_embeddings.inv_freq": "model-00002-of-00003.safetensors",
448
+ "encoder.layer.31.attention.self.value.bias": "model-00002-of-00003.safetensors",
449
+ "encoder.layer.31.attention.self.value.weight": "model-00002-of-00003.safetensors",
450
+ "encoder.layer.31.intermediate.dense.bias": "model-00003-of-00003.safetensors",
451
+ "encoder.layer.31.intermediate.dense.weight": "model-00003-of-00003.safetensors",
452
+ "encoder.layer.31.output.dense.bias": "model-00003-of-00003.safetensors",
453
+ "encoder.layer.31.output.dense.weight": "model-00003-of-00003.safetensors",
454
+ "encoder.layer.32.LayerNorm.bias": "model-00003-of-00003.safetensors",
455
+ "encoder.layer.32.LayerNorm.weight": "model-00003-of-00003.safetensors",
456
+ "encoder.layer.32.attention.LayerNorm.bias": "model-00003-of-00003.safetensors",
457
+ "encoder.layer.32.attention.LayerNorm.weight": "model-00003-of-00003.safetensors",
458
+ "encoder.layer.32.attention.output.dense.bias": "model-00003-of-00003.safetensors",
459
+ "encoder.layer.32.attention.output.dense.weight": "model-00003-of-00003.safetensors",
460
+ "encoder.layer.32.attention.self.key.bias": "model-00003-of-00003.safetensors",
461
+ "encoder.layer.32.attention.self.key.weight": "model-00003-of-00003.safetensors",
462
+ "encoder.layer.32.attention.self.query.bias": "model-00003-of-00003.safetensors",
463
+ "encoder.layer.32.attention.self.query.weight": "model-00003-of-00003.safetensors",
464
+ "encoder.layer.32.attention.self.rotary_embeddings.inv_freq": "model-00003-of-00003.safetensors",
465
+ "encoder.layer.32.attention.self.value.bias": "model-00003-of-00003.safetensors",
466
+ "encoder.layer.32.attention.self.value.weight": "model-00003-of-00003.safetensors",
467
+ "encoder.layer.32.intermediate.dense.bias": "model-00003-of-00003.safetensors",
468
+ "encoder.layer.32.intermediate.dense.weight": "model-00003-of-00003.safetensors",
469
+ "encoder.layer.32.output.dense.bias": "model-00003-of-00003.safetensors",
470
+ "encoder.layer.32.output.dense.weight": "model-00003-of-00003.safetensors",
471
+ "encoder.layer.33.LayerNorm.bias": "model-00003-of-00003.safetensors",
472
+ "encoder.layer.33.LayerNorm.weight": "model-00003-of-00003.safetensors",
473
+ "encoder.layer.33.attention.LayerNorm.bias": "model-00003-of-00003.safetensors",
474
+ "encoder.layer.33.attention.LayerNorm.weight": "model-00003-of-00003.safetensors",
475
+ "encoder.layer.33.attention.output.dense.bias": "model-00003-of-00003.safetensors",
476
+ "encoder.layer.33.attention.output.dense.weight": "model-00003-of-00003.safetensors",
477
+ "encoder.layer.33.attention.self.key.bias": "model-00003-of-00003.safetensors",
478
+ "encoder.layer.33.attention.self.key.weight": "model-00003-of-00003.safetensors",
479
+ "encoder.layer.33.attention.self.query.bias": "model-00003-of-00003.safetensors",
480
+ "encoder.layer.33.attention.self.query.weight": "model-00003-of-00003.safetensors",
481
+ "encoder.layer.33.attention.self.rotary_embeddings.inv_freq": "model-00003-of-00003.safetensors",
482
+ "encoder.layer.33.attention.self.value.bias": "model-00003-of-00003.safetensors",
483
+ "encoder.layer.33.attention.self.value.weight": "model-00003-of-00003.safetensors",
484
+ "encoder.layer.33.intermediate.dense.bias": "model-00003-of-00003.safetensors",
485
+ "encoder.layer.33.intermediate.dense.weight": "model-00003-of-00003.safetensors",
486
+ "encoder.layer.33.output.dense.bias": "model-00003-of-00003.safetensors",
487
+ "encoder.layer.33.output.dense.weight": "model-00003-of-00003.safetensors",
488
+ "encoder.layer.34.LayerNorm.bias": "model-00003-of-00003.safetensors",
489
+ "encoder.layer.34.LayerNorm.weight": "model-00003-of-00003.safetensors",
490
+ "encoder.layer.34.attention.LayerNorm.bias": "model-00003-of-00003.safetensors",
491
+ "encoder.layer.34.attention.LayerNorm.weight": "model-00003-of-00003.safetensors",
492
+ "encoder.layer.34.attention.output.dense.bias": "model-00003-of-00003.safetensors",
493
+ "encoder.layer.34.attention.output.dense.weight": "model-00003-of-00003.safetensors",
494
+ "encoder.layer.34.attention.self.key.bias": "model-00003-of-00003.safetensors",
495
+ "encoder.layer.34.attention.self.key.weight": "model-00003-of-00003.safetensors",
496
+ "encoder.layer.34.attention.self.query.bias": "model-00003-of-00003.safetensors",
497
+ "encoder.layer.34.attention.self.query.weight": "model-00003-of-00003.safetensors",
498
+ "encoder.layer.34.attention.self.rotary_embeddings.inv_freq": "model-00003-of-00003.safetensors",
499
+ "encoder.layer.34.attention.self.value.bias": "model-00003-of-00003.safetensors",
500
+ "encoder.layer.34.attention.self.value.weight": "model-00003-of-00003.safetensors",
501
+ "encoder.layer.34.intermediate.dense.bias": "model-00003-of-00003.safetensors",
502
+ "encoder.layer.34.intermediate.dense.weight": "model-00003-of-00003.safetensors",
503
+ "encoder.layer.34.output.dense.bias": "model-00003-of-00003.safetensors",
504
+ "encoder.layer.34.output.dense.weight": "model-00003-of-00003.safetensors",
505
+ "encoder.layer.35.LayerNorm.bias": "model-00003-of-00003.safetensors",
506
+ "encoder.layer.35.LayerNorm.weight": "model-00003-of-00003.safetensors",
507
+ "encoder.layer.35.attention.LayerNorm.bias": "model-00003-of-00003.safetensors",
508
+ "encoder.layer.35.attention.LayerNorm.weight": "model-00003-of-00003.safetensors",
509
+ "encoder.layer.35.attention.output.dense.bias": "model-00003-of-00003.safetensors",
510
+ "encoder.layer.35.attention.output.dense.weight": "model-00003-of-00003.safetensors",
511
+ "encoder.layer.35.attention.self.key.bias": "model-00003-of-00003.safetensors",
512
+ "encoder.layer.35.attention.self.key.weight": "model-00003-of-00003.safetensors",
513
+ "encoder.layer.35.attention.self.query.bias": "model-00003-of-00003.safetensors",
514
+ "encoder.layer.35.attention.self.query.weight": "model-00003-of-00003.safetensors",
515
+ "encoder.layer.35.attention.self.rotary_embeddings.inv_freq": "model-00003-of-00003.safetensors",
516
+ "encoder.layer.35.attention.self.value.bias": "model-00003-of-00003.safetensors",
517
+ "encoder.layer.35.attention.self.value.weight": "model-00003-of-00003.safetensors",
518
+ "encoder.layer.35.intermediate.dense.bias": "model-00003-of-00003.safetensors",
519
+ "encoder.layer.35.intermediate.dense.weight": "model-00003-of-00003.safetensors",
520
+ "encoder.layer.35.output.dense.bias": "model-00003-of-00003.safetensors",
521
+ "encoder.layer.35.output.dense.weight": "model-00003-of-00003.safetensors",
522
+ "encoder.layer.4.LayerNorm.bias": "model-00001-of-00003.safetensors",
523
+ "encoder.layer.4.LayerNorm.weight": "model-00001-of-00003.safetensors",
524
+ "encoder.layer.4.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
525
+ "encoder.layer.4.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
526
+ "encoder.layer.4.attention.output.dense.bias": "model-00001-of-00003.safetensors",
527
+ "encoder.layer.4.attention.output.dense.weight": "model-00001-of-00003.safetensors",
528
+ "encoder.layer.4.attention.self.key.bias": "model-00001-of-00003.safetensors",
529
+ "encoder.layer.4.attention.self.key.weight": "model-00001-of-00003.safetensors",
530
+ "encoder.layer.4.attention.self.query.bias": "model-00001-of-00003.safetensors",
531
+ "encoder.layer.4.attention.self.query.weight": "model-00001-of-00003.safetensors",
532
+ "encoder.layer.4.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
533
+ "encoder.layer.4.attention.self.value.bias": "model-00001-of-00003.safetensors",
534
+ "encoder.layer.4.attention.self.value.weight": "model-00001-of-00003.safetensors",
535
+ "encoder.layer.4.intermediate.dense.bias": "model-00001-of-00003.safetensors",
536
+ "encoder.layer.4.intermediate.dense.weight": "model-00001-of-00003.safetensors",
537
+ "encoder.layer.4.output.dense.bias": "model-00001-of-00003.safetensors",
538
+ "encoder.layer.4.output.dense.weight": "model-00001-of-00003.safetensors",
539
+ "encoder.layer.5.LayerNorm.bias": "model-00001-of-00003.safetensors",
540
+ "encoder.layer.5.LayerNorm.weight": "model-00001-of-00003.safetensors",
541
+ "encoder.layer.5.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
542
+ "encoder.layer.5.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
543
+ "encoder.layer.5.attention.output.dense.bias": "model-00001-of-00003.safetensors",
544
+ "encoder.layer.5.attention.output.dense.weight": "model-00001-of-00003.safetensors",
545
+ "encoder.layer.5.attention.self.key.bias": "model-00001-of-00003.safetensors",
546
+ "encoder.layer.5.attention.self.key.weight": "model-00001-of-00003.safetensors",
547
+ "encoder.layer.5.attention.self.query.bias": "model-00001-of-00003.safetensors",
548
+ "encoder.layer.5.attention.self.query.weight": "model-00001-of-00003.safetensors",
549
+ "encoder.layer.5.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
550
+ "encoder.layer.5.attention.self.value.bias": "model-00001-of-00003.safetensors",
551
+ "encoder.layer.5.attention.self.value.weight": "model-00001-of-00003.safetensors",
552
+ "encoder.layer.5.intermediate.dense.bias": "model-00001-of-00003.safetensors",
553
+ "encoder.layer.5.intermediate.dense.weight": "model-00001-of-00003.safetensors",
554
+ "encoder.layer.5.output.dense.bias": "model-00001-of-00003.safetensors",
555
+ "encoder.layer.5.output.dense.weight": "model-00001-of-00003.safetensors",
556
+ "encoder.layer.6.LayerNorm.bias": "model-00001-of-00003.safetensors",
557
+ "encoder.layer.6.LayerNorm.weight": "model-00001-of-00003.safetensors",
558
+ "encoder.layer.6.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
559
+ "encoder.layer.6.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
560
+ "encoder.layer.6.attention.output.dense.bias": "model-00001-of-00003.safetensors",
561
+ "encoder.layer.6.attention.output.dense.weight": "model-00001-of-00003.safetensors",
562
+ "encoder.layer.6.attention.self.key.bias": "model-00001-of-00003.safetensors",
563
+ "encoder.layer.6.attention.self.key.weight": "model-00001-of-00003.safetensors",
564
+ "encoder.layer.6.attention.self.query.bias": "model-00001-of-00003.safetensors",
565
+ "encoder.layer.6.attention.self.query.weight": "model-00001-of-00003.safetensors",
566
+ "encoder.layer.6.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
567
+ "encoder.layer.6.attention.self.value.bias": "model-00001-of-00003.safetensors",
568
+ "encoder.layer.6.attention.self.value.weight": "model-00001-of-00003.safetensors",
569
+ "encoder.layer.6.intermediate.dense.bias": "model-00001-of-00003.safetensors",
570
+ "encoder.layer.6.intermediate.dense.weight": "model-00001-of-00003.safetensors",
571
+ "encoder.layer.6.output.dense.bias": "model-00001-of-00003.safetensors",
572
+ "encoder.layer.6.output.dense.weight": "model-00001-of-00003.safetensors",
573
+ "encoder.layer.7.LayerNorm.bias": "model-00001-of-00003.safetensors",
574
+ "encoder.layer.7.LayerNorm.weight": "model-00001-of-00003.safetensors",
575
+ "encoder.layer.7.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
576
+ "encoder.layer.7.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
577
+ "encoder.layer.7.attention.output.dense.bias": "model-00001-of-00003.safetensors",
578
+ "encoder.layer.7.attention.output.dense.weight": "model-00001-of-00003.safetensors",
579
+ "encoder.layer.7.attention.self.key.bias": "model-00001-of-00003.safetensors",
580
+ "encoder.layer.7.attention.self.key.weight": "model-00001-of-00003.safetensors",
581
+ "encoder.layer.7.attention.self.query.bias": "model-00001-of-00003.safetensors",
582
+ "encoder.layer.7.attention.self.query.weight": "model-00001-of-00003.safetensors",
583
+ "encoder.layer.7.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
584
+ "encoder.layer.7.attention.self.value.bias": "model-00001-of-00003.safetensors",
585
+ "encoder.layer.7.attention.self.value.weight": "model-00001-of-00003.safetensors",
586
+ "encoder.layer.7.intermediate.dense.bias": "model-00001-of-00003.safetensors",
587
+ "encoder.layer.7.intermediate.dense.weight": "model-00001-of-00003.safetensors",
588
+ "encoder.layer.7.output.dense.bias": "model-00001-of-00003.safetensors",
589
+ "encoder.layer.7.output.dense.weight": "model-00001-of-00003.safetensors",
590
+ "encoder.layer.8.LayerNorm.bias": "model-00001-of-00003.safetensors",
591
+ "encoder.layer.8.LayerNorm.weight": "model-00001-of-00003.safetensors",
592
+ "encoder.layer.8.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
593
+ "encoder.layer.8.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
594
+ "encoder.layer.8.attention.output.dense.bias": "model-00001-of-00003.safetensors",
595
+ "encoder.layer.8.attention.output.dense.weight": "model-00001-of-00003.safetensors",
596
+ "encoder.layer.8.attention.self.key.bias": "model-00001-of-00003.safetensors",
597
+ "encoder.layer.8.attention.self.key.weight": "model-00001-of-00003.safetensors",
598
+ "encoder.layer.8.attention.self.query.bias": "model-00001-of-00003.safetensors",
599
+ "encoder.layer.8.attention.self.query.weight": "model-00001-of-00003.safetensors",
600
+ "encoder.layer.8.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
601
+ "encoder.layer.8.attention.self.value.bias": "model-00001-of-00003.safetensors",
602
+ "encoder.layer.8.attention.self.value.weight": "model-00001-of-00003.safetensors",
603
+ "encoder.layer.8.intermediate.dense.bias": "model-00001-of-00003.safetensors",
604
+ "encoder.layer.8.intermediate.dense.weight": "model-00001-of-00003.safetensors",
605
+ "encoder.layer.8.output.dense.bias": "model-00001-of-00003.safetensors",
606
+ "encoder.layer.8.output.dense.weight": "model-00001-of-00003.safetensors",
607
+ "encoder.layer.9.LayerNorm.bias": "model-00001-of-00003.safetensors",
608
+ "encoder.layer.9.LayerNorm.weight": "model-00001-of-00003.safetensors",
609
+ "encoder.layer.9.attention.LayerNorm.bias": "model-00001-of-00003.safetensors",
610
+ "encoder.layer.9.attention.LayerNorm.weight": "model-00001-of-00003.safetensors",
611
+ "encoder.layer.9.attention.output.dense.bias": "model-00001-of-00003.safetensors",
612
+ "encoder.layer.9.attention.output.dense.weight": "model-00001-of-00003.safetensors",
613
+ "encoder.layer.9.attention.self.key.bias": "model-00001-of-00003.safetensors",
614
+ "encoder.layer.9.attention.self.key.weight": "model-00001-of-00003.safetensors",
615
+ "encoder.layer.9.attention.self.query.bias": "model-00001-of-00003.safetensors",
616
+ "encoder.layer.9.attention.self.query.weight": "model-00001-of-00003.safetensors",
617
+ "encoder.layer.9.attention.self.rotary_embeddings.inv_freq": "model-00001-of-00003.safetensors",
618
+ "encoder.layer.9.attention.self.value.bias": "model-00001-of-00003.safetensors",
619
+ "encoder.layer.9.attention.self.value.weight": "model-00001-of-00003.safetensors",
620
+ "encoder.layer.9.intermediate.dense.bias": "model-00001-of-00003.safetensors",
621
+ "encoder.layer.9.intermediate.dense.weight": "model-00001-of-00003.safetensors",
622
+ "encoder.layer.9.output.dense.bias": "model-00001-of-00003.safetensors",
623
+ "encoder.layer.9.output.dense.weight": "model-00001-of-00003.safetensors",
624
+ "pooler.dense.bias": "model-00003-of-00003.safetensors",
625
+ "pooler.dense.weight": "model-00003-of-00003.safetensors"
626
+ }
627
+ }
modeling_esm2_flash.py ADDED
@@ -0,0 +1,965 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ESM2-Flash: ESM2 with flash attention and packed-sequence support.
3
+
4
+ Drop-in replacement for HuggingFace's EsmModel / EsmForMaskedLM with three
5
+ attention backends:
6
+ - flash_attn_varlen_func (packed sequences via cu_seqlens)
7
+ - scaled_dot_product_attention (default for padded sequences)
8
+ - eager matmul (when output_attentions=True)
9
+
10
+ Weight names are identical to the original ESM2 so pretrained checkpoints
11
+ load with strict=True.
12
+ """
13
+
14
+ import math
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.utils.checkpoint
19
+ from torch import nn
20
+ from torch.nn import CrossEntropyLoss
21
+ from torch.nn.functional import scaled_dot_product_attention
22
+
23
+ from transformers.modeling_outputs import (
24
+ BaseModelOutputWithPastAndCrossAttentions,
25
+ BaseModelOutputWithPoolingAndCrossAttentions,
26
+ MaskedLMOutput,
27
+ )
28
+ from transformers.modeling_utils import PreTrainedModel
29
+
30
+ try:
31
+ from .configuration_esm2_flash import Esm2FlashConfig
32
+ except ImportError:
33
+ from configuration_esm2_flash import Esm2FlashConfig
34
+
35
+ try:
36
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
37
+
38
+ FLASH_ATTN_AVAILABLE = True
39
+ except ImportError:
40
+ FLASH_ATTN_AVAILABLE = False
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Helper functions (matching original ESM2 exactly)
45
+ # ---------------------------------------------------------------------------
46
+
47
+
48
def rotate_half(x):
    """Rotate the last dimension by half: (a, b) -> (-b, a).

    Standard RoPE helper used by apply_rotary_pos_emb.
    """
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
51
+
52
+
53
def apply_rotary_pos_emb(x, cos, sin):
    """Apply rotary embeddings. Supports two shape conventions:

    Standard (original ESM2):
        x:   (batch, heads, seq, dim)
        cos: (1, 1, seq, dim)
        sin: (1, 1, seq, dim)

    Packed:
        x:   (total_tokens, heads, dim)
        cos: (total_tokens, 1, dim)
        sin: (total_tokens, 1, dim)
    """
    if x.dim() == 4:
        # Standard path: the cached tables may be longer than x, so trim
        # them to the actual sequence length before broadcasting.
        seq_len = x.shape[-2]
        cos = cos[:, :, :seq_len, :]
        sin = sin[:, :, :seq_len, :]
    # Inlined rotate_half: (a, b) -> (-b, a) along the last axis.
    lo, hi = x.chunk(2, dim=-1)
    rotated = torch.cat((-hi, lo), dim=-1)
    return (x * cos) + (rotated * sin)
71
+
72
+
73
def gelu(x):
    """Exact erf-based GELU as used in the original ESM codebase.

    Kept as an explicit formula: F.gelu's default yields subtly different
    numerics, which matters when reproducing pretrained ESM2 outputs.
    """
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
76
+
77
+
78
def symmetrize(x):
    """Symmetrize the final two dimensions (x + x^T); used for contact prediction."""
    transposed = x.transpose(-1, -2)
    return x + transposed
81
+
82
+
83
def average_product_correct(x):
    """Average-product correction (APC) over the last two dims.

    Standard post-processing for attention-derived contact maps:
    x - (row_sums * col_sums) / total_sum.
    """
    row_sums = x.sum(-1, keepdims=True)
    col_sums = x.sum(-2, keepdims=True)
    total = x.sum((-1, -2), keepdims=True)
    return x - (row_sums * col_sums) / total
92
+
93
+
94
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Build position ids from token ids, ignoring padding.

    Non-padding tokens are numbered starting at padding_idx + 1 (the
    ESM/RoBERTa convention); padding tokens keep position padding_idx.
    """
    not_pad = input_ids.ne(padding_idx).int()
    running = torch.cumsum(not_pad, dim=1).type_as(not_pad)
    positions = (running + past_key_values_length) * not_pad
    return positions.long() + padding_idx
102
+
103
+
104
+ # ---------------------------------------------------------------------------
105
+ # Rotary embeddings (extended with position_ids support for packing)
106
+ # ---------------------------------------------------------------------------
107
+
108
+
109
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on RoFormer. Extended to accept explicit
    position_ids for packed-sequence support.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Standard RoPE inverse frequencies: 1 / 10000^(2i/dim) for each
        # channel pair i in [0, dim/2).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Lazy cache for the standard (non-packed) path; rebuilt whenever the
        # sequence length or device changes.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        # Returns cached cos/sin tables of shape (1, 1, seq, dim).
        seq_len = x.shape[seq_dimension]

        # The seq_len check must come first: on the very first call
        # _cos_cached is None, and short-circuiting avoids touching .device.
        # NOTE(review): cache is keyed on length and device only, not dtype;
        # tables are always built in inv_freq's dtype (float32).
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            # Duplicate the frequencies so cos/sin span the full head dim.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def _compute_from_position_ids(self, position_ids, device, dtype):
        """Compute cos/sin tables from explicit position_ids (for packed sequences).

        Args:
            position_ids: (total_tokens,) int tensor, 0-indexed per sub-sequence
            device: target device
            dtype: target dtype for inv_freq
                (NOTE(review): currently unused — tables are computed in float32)

        Returns:
            cos: (total_tokens, 1, dim)
            sin: (total_tokens, 1, dim)
        """
        t = position_ids.float()
        freqs = torch.outer(t, self.inv_freq.to(device=device))
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().unsqueeze(1)  # (total_tokens, 1, dim)
        sin = emb.sin().unsqueeze(1)
        return cos, sin

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rotate q and k by their positions; returns (rotated_q, rotated_k).

        Args:
            q, k: query/key tensors.
                Standard: (batch, heads, seq, dim)
                Packed: (total_tokens, heads, dim)
            position_ids: optional (total_tokens,) for packed mode
        """
        if position_ids is not None:
            # Packed path: positions restart at 0 for each packed sub-sequence,
            # so tables must be gathered per token rather than per offset.
            cos, sin = self._compute_from_position_ids(position_ids, q.device, q.dtype)
        else:
            # Standard path (original ESM2 behaviour): sequential positions,
            # cached across calls.
            cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, cos, sin),
            apply_rotary_pos_emb(k, cos, sin),
        )
181
+
182
+
183
+ # ---------------------------------------------------------------------------
184
+ # Contact prediction head (unchanged from ESM2)
185
+ # ---------------------------------------------------------------------------
186
+
187
+
188
class EsmContactPredictionHead(nn.Module):
    """Symmetrize + APC over stacked attention maps, then a per-pair logistic regression."""

    def __init__(self, in_features: int, bias=True, eos_idx: int = 2):
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # Zero attention to/from EOS tokens, then drop the first and last
        # positions (BOS/EOS) from the maps.
        keep = tokens.ne(self.eos_idx).to(attentions)
        pair_keep = keep.unsqueeze(1) * keep.unsqueeze(2)
        maps = attentions * pair_keep[:, None, None, :, :]
        maps = maps[..., :-1, :-1][..., 1:, 1:]

        # Collapse (layers, heads) into one feature axis per residue pair.
        bsz, n_layers, n_heads, seqlen, _ = maps.size()
        maps = maps.view(bsz, n_layers * n_heads, seqlen, seqlen)

        # Symmetrize, then apply average-product correction (inlined):
        # m - (rowsum * colsum) / totalsum.
        maps = maps + maps.transpose(-1, -2)
        row_sums = maps.sum(-1, keepdims=True)
        col_sums = maps.sum(-2, keepdims=True)
        total = maps.sum((-1, -2), keepdims=True)
        maps = maps - (row_sums * col_sums) / total

        # Logistic regression over the feature axis for each pair.
        maps = maps.permute(0, 2, 3, 1)
        return self.activation(self.regression(maps).squeeze(3))
210
+
211
+
212
+ # ---------------------------------------------------------------------------
213
+ # Embeddings
214
+ # ---------------------------------------------------------------------------
215
+
216
+
217
class Esm2FlashEmbeddings(nn.Module):
    """
    Same as EsmEmbeddings with packed-sequence support for token_dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Optional pre-embedding LayerNorm (off for standard ESM2 checkpoints).
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = config.pad_token_id
        # Only consulted when position_embedding_type == "absolute"; ESM2
        # checkpoints use "rotary", so this stays unused there.
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
        cu_seqlens=None,
    ):
        # When cu_seqlens is set, input_ids is expected to be (1, total_tokens)
        # with several sequences packed back to back.
        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.token_dropout:
            # ESM token dropout: zero out <mask> embeddings, then rescale the
            # remaining tokens so the expected magnitude matches training,
            # where 15% * 80% of tokens were replaced by <mask>.
            # NOTE(review): this branch reads input_ids directly, so it
            # assumes input_ids is not None when token_dropout is enabled.
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8

            if cu_seqlens is not None:
                # Packed sequences: compute src_lengths from cu_seqlens
                seq_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).float()  # (num_seqs,)
                # Count mask tokens per sequence
                # NOTE(review): Python loop per packed sub-sequence; fine for
                # modest packing factors but a GPU sync point per sequence.
                mask_counts = []
                for i in range(len(seq_lengths)):
                    start, end = cu_seqlens[i], cu_seqlens[i + 1]
                    mask_counts.append((input_ids[0, start:end] == self.mask_token_id).sum().float())
                mask_counts = torch.stack(mask_counts)
                mask_ratio_observed = mask_counts / seq_lengths

                # Build per-token scale factor
                scale = (1 - mask_ratio_train) / (1 - mask_ratio_observed)  # (num_seqs,)
                # Expand to per-token
                per_token_scale = torch.zeros(
                    embeddings.shape[1], device=embeddings.device, dtype=embeddings.dtype
                )
                for i in range(len(seq_lengths)):
                    start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
                    per_token_scale[start:end] = scale[i]
                embeddings = (embeddings * per_token_scale[None, :, None]).to(embeddings.dtype)
            else:
                # Padded path: one observed mask ratio per batch row.
                # NOTE(review): requires attention_mask to be non-None; the
                # model's padded path always supplies one.
                src_lengths = attention_mask.sum(-1)
                mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
                embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                    embeddings.dtype
                )

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Zero the embeddings at padding positions.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)

        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        # Without token ids there is no padding information, so positions are
        # simply sequential starting at padding_idx + 1.
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]
        position_ids = torch.arange(
            self.padding_idx + 1,
            sequence_length + self.padding_idx + 1,
            dtype=torch.long,
            device=inputs_embeds.device,
        )
        return position_ids.unsqueeze(0).expand(input_shape)
317
+
318
+
319
+ # ---------------------------------------------------------------------------
320
+ # Attention
321
+ # ---------------------------------------------------------------------------
322
+
323
+
324
class Esm2FlashSelfAttention(nn.Module):
    """Self-attention with three backends: flash, SDPA, and eager."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        elif self.position_embedding_type == "rotary":
            # ESM2 checkpoints use rotary embeddings.
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, hidden) -> (batch, heads, seq, dim)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        position_ids: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """Dispatch to flash (packed), eager (output_attentions), or SDPA.

        Args:
            attention_mask: additive extended mask for the padded paths
                (0 where attended, large negative where masked).
            cu_seqlens/max_seqlen: packed-sequence metadata; selects flash.

        Returns:
            (context_layer,) or (context_layer, attention_probs) when
            output_attentions is True.
        """
        batch_size, seq_len, _ = hidden_states.shape

        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # ESM2-specific: scale query before rotary (not the scores)
        query_layer = query_layer * self.attention_head_size**-0.5

        # --- Flash attention path (packed sequences) ---
        if cu_seqlens is not None:
            assert FLASH_ATTN_AVAILABLE, (
                "flash_attn is required for packed sequences. "
                "Install with: pip install flash-attn --no-build-isolation"
            )
            assert not output_attentions, "output_attentions is not supported with packed sequences."
            assert batch_size == 1, "Packed sequences require batch_size=1."

            # Reshape to (total_tokens, heads, dim) for flash_attn_varlen
            q = query_layer.squeeze(0).transpose(0, 1)  # (heads, seq, dim) -> (seq, heads, dim)
            k = key_layer.squeeze(0).transpose(0, 1)
            v = value_layer.squeeze(0).transpose(0, 1)

            # Apply rotary with explicit position_ids
            if self.rotary_embeddings is not None:
                # position_ids: (1, total_tokens) -> (total_tokens,)
                pos_ids = position_ids.squeeze(0) if position_ids is not None else None
                q, k = self.rotary_embeddings(q, k, position_ids=pos_ids)

            # Flash attention requires fp16 or bf16
            input_dtype = q.dtype
            if input_dtype == torch.float32:
                q = q.to(torch.bfloat16)
                k = k.to(torch.bfloat16)
                v = v.to(torch.bfloat16)

            context_layer = flash_attn_varlen_func(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=self.dropout.p if self.training else 0.0,
                causal=False,
                softmax_scale=1.0,  # Q is already scaled
            )

            # Cast back to input dtype
            if input_dtype == torch.float32:
                context_layer = context_layer.to(input_dtype)

            # (total_tokens, heads, dim) -> (1, total_tokens, hidden_size)
            context_layer = context_layer.reshape(1, seq_len, self.all_head_size)
            return (context_layer,)

        # --- Standard paths (padded sequences) ---

        # Apply rotary with sequential positions (original ESM2 behaviour)
        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        # --- Eager path (output_attentions=True) ---
        if output_attentions:
            # No extra scaling here: the query was pre-scaled above.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

            if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
                seq_length = hidden_states.size()[1]
                position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
                position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
                distance = position_ids_l - position_ids_r
                # Shift distances into [0, 2 * max_position_embeddings - 2] for the lookup.
                positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
                positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

                if self.position_embedding_type == "relative_key":
                    relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                    attention_scores = attention_scores + relative_position_scores
                elif self.position_embedding_type == "relative_key_query":
                    relative_position_scores_query = torch.einsum(
                        "bhld,lrd->bhlr", query_layer, positional_embedding
                    )
                    relative_position_scores_key = torch.einsum(
                        "bhrd,lrd->bhlr", key_layer, positional_embedding
                    )
                    attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

            if attention_mask is not None:
                # Additive mask: large negative values suppress masked positions.
                attention_scores = attention_scores + attention_mask

            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
            attention_probs = self.dropout(attention_probs)

            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(new_context_layer_shape)
            return (context_layer, attention_probs)

        # --- SDPA path (default for padded sequences) ---
        context_layer = scaled_dot_product_attention(
            query=query_layer,
            key=key_layer,
            value=value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
            scale=1.0,  # Q is already scaled
        )
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return (context_layer,)
486
+
487
+
488
class EsmSelfOutput(nn.Module):
    """Projects attention output back to hidden size, applies dropout, adds the residual."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dense(hidden_states)
        projected = self.dropout(projected)
        return projected + input_tensor
499
+
500
+
501
class Esm2FlashAttention(nn.Module):
    """Pre-LayerNorm attention sublayer: LN -> self-attention -> residual projection."""

    def __init__(self, config):
        super().__init__()
        self.self = Esm2FlashSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        position_ids=None,
        cu_seqlens=None,
        max_seqlen=None,
    ):
        # ESM2 is pre-LN: normalize before attention, but take the residual
        # from the un-normalized input.
        normed = self.LayerNorm(hidden_states)
        attn_results = self.self(
            normed,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        attn_output = self.output(attn_results[0], hidden_states)
        # Trailing entries carry attention probabilities when requested.
        return (attn_output,) + attn_results[1:]
532
+
533
+
534
+ # ---------------------------------------------------------------------------
535
+ # Feed-forward
536
+ # ---------------------------------------------------------------------------
537
+
538
+
539
class EsmIntermediate(nn.Module):
    """Feed-forward up-projection followed by ESM's exact erf-based GELU."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        # Exact GELU (inlined from the module-level `gelu` helper):
        # x * 0.5 * (1 + erf(x / sqrt(2))).
        return projected * 0.5 * (1.0 + torch.erf(projected / math.sqrt(2.0)))
548
+
549
+
550
class EsmOutput(nn.Module):
    """Feed-forward down-projection with dropout and residual connection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
561
+
562
+
563
+ # ---------------------------------------------------------------------------
564
+ # Transformer layer
565
+ # ---------------------------------------------------------------------------
566
+
567
+
568
class Esm2FlashLayer(nn.Module):
    """One transformer block: pre-LN attention sublayer plus pre-LN feed-forward sublayer."""

    def __init__(self, config):
        super().__init__()
        self.attention = Esm2FlashAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        position_ids=None,
        cu_seqlens=None,
        max_seqlen=None,
    ):
        attn_results = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        # Trailing entries hold attention probabilities when requested.
        extras = attn_results[1:]
        layer_output = self.feed_forward_chunk(attn_results[0])
        return (layer_output,) + extras

    def feed_forward_chunk(self, attention_output):
        # Pre-LN feed-forward; residual taken from the un-normalized input.
        normed = self.LayerNorm(attention_output)
        return self.output(self.intermediate(normed), attention_output)
607
+
608
+
609
+ # ---------------------------------------------------------------------------
610
+ # Encoder (stack of layers)
611
+ # ---------------------------------------------------------------------------
612
+
613
+
614
class Esm2FlashEncoder(nn.Module):
    """Stack of Esm2FlashLayer blocks with ESM2's final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Esm2FlashLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        position_ids=None,
        cu_seqlens=None,
        max_seqlen=None,
    ):
        """Run all layers; collects per-layer hidden states / attentions on request."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Recorded before the layer runs, so entry i is the INPUT to layer i.
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional argument order must match Esm2FlashLayer.forward.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                    position_ids,
                    cu_seqlens,
                    max_seqlen,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=attention_mask,
                    head_mask=layer_head_mask,
                    output_attentions=output_attentions,
                    position_ids=position_ids,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # NOTE(review): an nn.Module is always truthy, so this branch always
        # runs; the check mirrors upstream ESM's phrasing.
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
682
+
683
+
684
+ # ---------------------------------------------------------------------------
685
+ # Pooler
686
+ # ---------------------------------------------------------------------------
687
+
688
+
689
class EsmPooler(nn.Module):
    """Pools the sequence via the first token's hidden state through a tanh-activated dense layer."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # BERT-style pooling: only the first (<cls>-like) token is used.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
700
+
701
+
702
+ # ---------------------------------------------------------------------------
703
+ # LM Head
704
+ # ---------------------------------------------------------------------------
705
+
706
+
707
class EsmLMHead(nn.Module):
    """ESM head for masked language modeling: dense -> GELU -> LayerNorm -> vocab projection + bias."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Decoder weight is tied to the input embeddings elsewhere; the output
        # bias is kept as a separate parameter.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        hidden = self.dense(features)
        # Exact GELU (inlined from the module-level `gelu` helper).
        hidden = hidden * 0.5 * (1.0 + torch.erf(hidden / math.sqrt(2.0)))
        hidden = self.layer_norm(hidden)
        return self.decoder(hidden) + self.bias
723
+
724
+
725
+ # ---------------------------------------------------------------------------
726
+ # PreTrainedModel base
727
+ # ---------------------------------------------------------------------------
728
+
729
+
730
class Esm2FlashPreTrainedModel(PreTrainedModel):
    """Shared base: config binding, ESM2 weight init, and gradient-checkpointing support."""

    config_class = Esm2FlashConfig
    base_model_prefix = "esm"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Esm2FlashLayer", "Esm2FlashEmbeddings"]

    def _init_weights(self, module):
        """Initialize module weights with the original ESM2 scheme."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # The padding embedding stays exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
748
+
749
+
750
+ # ---------------------------------------------------------------------------
751
+ # Esm2FlashModel
752
+ # ---------------------------------------------------------------------------
753
+
754
+
755
class Esm2FlashModel(Esm2FlashPreTrainedModel):
    """
    ESM2 encoder with flash attention and packed-sequence support.

    Accepts the same inputs as EsmModel, plus:
        cu_seqlens: int32 tensor of cumulative sequence lengths for packing
        max_seqlen: maximum sequence length in the packed batch
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = Esm2FlashEmbeddings(config)
        self.encoder = Esm2FlashEncoder(config)

        self.pooler = EsmPooler(config) if add_pooling_layer else None

        # Contact head consumes all layers' attention maps stacked on one axis.
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Encode tokens; dispatches to the packed (flash) or padded path.

        When cu_seqlens is given, inputs must be a single packed row
        (batch_size == 1) and output_attentions must be False.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # --- Packed sequence path ---
        if cu_seqlens is not None:
            assert max_seqlen is not None, "max_seqlen must be provided when cu_seqlens is not None"
            assert batch_size == 1, "Packed sequences require batch_size=1"
            assert not output_attentions, "output_attentions is not supported with packed sequences"

            # Compute rotary-compatible position_ids if not provided
            # For packed sequences, position_ids should be 0-indexed per sub-sequence
            if position_ids is None:
                position_ids = torch.zeros(1, seq_length, dtype=torch.long, device=device)
                for i in range(cu_seqlens.shape[0] - 1):
                    start = cu_seqlens[i].item()
                    end = cu_seqlens[i + 1].item()
                    position_ids[0, start:end] = torch.arange(end - start, device=device)

            # No attention_mask here: packing implies there is no padding.
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                cu_seqlens=cu_seqlens,
            )

            head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

            encoder_outputs = self.encoder(
                embedding_output,
                head_mask=head_mask,
                output_attentions=False,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )
        else:
            # --- Standard padded path ---
            if attention_mask is None:
                attention_mask = torch.ones(((batch_size, seq_length)), device=device)

            # Converts the (batch, seq) 0/1 mask into an additive broadcastable mask.
            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

            head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
            )
            encoder_outputs = self.encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Run a full forward with output_attentions and feed all attention maps
        to the contact head. Uses the padded (non-flash) path by construction,
        since output_attentions requires the eager backend.
        """
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        # Stack per-layer maps: (batch, layers, heads, seq, seq).
        attns = torch.stack(attns, dim=1)
        # Zero out attention involving padding positions on both axes.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(tokens, attns)
894
+
895
+
896
+ # ---------------------------------------------------------------------------
897
+ # Esm2FlashForMaskedLM
898
+ # ---------------------------------------------------------------------------
899
+
900
+
901
class Esm2FlashForMaskedLM(Esm2FlashPreTrainedModel):
    """ESM2-Flash backbone with a masked-language-modeling head on top.

    Thin HF-style wrapper: runs :class:`Esm2FlashModel` (without a pooler),
    projects hidden states through ``EsmLMHead``, and optionally computes a
    cross-entropy loss over the vocabulary when ``labels`` are supplied.
    """

    # The LM head's decoder weight is tied to the input embeddings.
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        # No pooling layer: the MLM head consumes per-token hidden states.
        self.esm = Esm2FlashModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.init_weights()

    def get_output_embeddings(self):
        """Return the output projection so HF can tie/resize embeddings."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection (used by embedding resizing)."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Compute MLM logits (and loss, when ``labels`` is given).

        ``cu_seqlens``/``max_seqlen`` are forwarded to the backbone to enable
        the packed-sequence flash-attention path. ``labels`` should contain
        ``-100`` at positions to ignore (CrossEntropyLoss default).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = backbone_outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Flatten to (batch * seq, vocab) vs (batch * seq,); move labels
            # onto the logits' device first (supports model parallelism).
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size),
                labels.to(logits.device).view(-1),
            )

        if not return_dict:
            # Tuple output: skip the pooler slot at index 1.
            tail = (logits,) + backbone_outputs[2:]
            return ((loss,) + tail) if loss is not None else tail

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Delegate attention-based contact prediction to the backbone."""
        return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "<cls>",
3
+ "eos_token": "<eos>",
4
+ "mask_token": "<mask>",
5
+ "pad_token": "<pad>",
6
+ "unk_token": "<unk>"
7
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<cls>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "32": {
36
+ "content": "<mask>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "<cls>",
46
+ "eos_token": "<eos>",
47
+ "extra_special_tokens": {},
48
+ "mask_token": "<mask>",
49
+ "model_max_length": 1000000000000000019884624838656,
50
+ "pad_token": "<pad>",
51
+ "tokenizer_class": "EsmTokenizer",
52
+ "unk_token": "<unk>"
53
+ }
vocab.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <cls>
2
+ <pad>
3
+ <eos>
4
+ <unk>
5
+ L
6
+ A
7
+ G
8
+ V
9
+ S
10
+ E
11
+ R
12
+ T
13
+ I
14
+ D
15
+ P
16
+ K
17
+ Q
18
+ N
19
+ F
20
+ Y
21
+ M
22
+ H
23
+ W
24
+ C
25
+ X
26
+ B
27
+ U
28
+ Z
29
+ O
30
+ .
31
+ -
32
+ <null_1>
33
+ <mask>