mineself2016 committed on
Commit 27c9e1e · verified · 1 Parent(s): 13144d7

Fix Mamba2 loading mismatch and sync model definition files

config.json CHANGED
@@ -1,28 +1,10 @@
 {
-  "model_type": "genemamba",
   "architectures": [
-    "GeneMambaModel"
+    "MambaModel"
   ],
-  "vocab_size": 25426,
-  "max_position_embeddings": 2048,
-  "hidden_size": 512,
-  "num_hidden_layers": 24,
-  "intermediate_size": 2048,
-  "hidden_dropout_prob": 0.1,
-  "initializer_range": 0.02,
-  "mamba_mode": "gate",
-  "embedding_pooling": "mean",
-  "num_labels": 2,
-  "pad_token_id": 1,
-  "eos_token_id": 2,
-  "bos_token_id": 0,
-  "use_cache": true,
+  "d_model": 512,
+  "mamba_layer": 24,
   "torch_dtype": "float32",
   "transformers_version": "4.40.2",
-  "auto_map": {
-    "AutoConfig": "configuration_genemamba.GeneMambaConfig",
-    "AutoModel": "modeling_genemamba.GeneMambaModel",
-    "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
-    "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
-  }
-}
+  "vocab_size": 25426
+}
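
For quick inspection, a minimal sketch (not part of this commit) that loads the slimmed-down config from a local clone of the repo and prints the remaining fields; the file path is an assumption:

import json

# Read the reduced config.json shipped with this commit (local path assumed).
with open("config.json") as f:
    cfg = json.load(f)

print(cfg["architectures"])                                   # ['MambaModel']
print(cfg["d_model"], cfg["mamba_layer"], cfg["vocab_size"])  # 512 24 25426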
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:07a8347e2037f04f81aa44c66249be1a046ddb99a880d66005d8e4e64a099689
+oid sha256:ccb1fcb0ee4b3ea2013099b9b187455e160d3b66b76c606715231b70b13c2784
 size 262998656
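
Since only the LFS pointer changed, a downloaded checkpoint can be checked against the new oid. A minimal sketch, assuming a local model.safetensors:

import hashlib

def sha256_of(path: str, chunk: int = 1 << 20) -> str:
    """Stream the file so large checkpoints need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk):
            h.update(block)
    return h.hexdigest()

# Must match the oid in the updated LFS pointer above.
expected = "ccb1fcb0ee4b3ea2013099b9b187455e160d3b66b76c606715231b70b13c2784"
assert sha256_of("model.safetensors") == expected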
modeling_genemamba.py CHANGED
@@ -22,7 +22,10 @@ except ImportError:
         return cls
     return wrapper
 
-from mamba_ssm import Mamba
+try:
+    from mamba_ssm import Mamba2 as MambaBlock
+except ImportError:
+    from mamba_ssm import Mamba as MambaBlock
 from mamba_ssm.ops.triton.layer_norm import RMSNorm
 
 from .configuration_genemamba import GeneMambaConfig
@@ -46,7 +49,7 @@ class EncoderLayer(nn.Module):
 
     def __init__(self, hidden_size: int):
        super(EncoderLayer, self).__init__()
-        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)
+        self.mamba = MambaBlock(d_model=hidden_size, d_state=64, d_conv=4, expand=2)
 
     def forward(self, X: torch.Tensor) -> torch.Tensor:
         """
trainer_state.json ADDED
@@ -0,0 +1,196 @@
+{
+  "best_metric": null,
+  "best_model_checkpoint": null,
+  "epoch": 1.0,
+  "eval_steps": 500,
+  "global_step": 31250,
+  "is_hyper_param_search": false,
+  "is_local_process_zero": true,
+  "is_world_process_zero": true,
+  "log_history": [
+    {
+      "epoch": 0.04,
+      "grad_norm": 0.0007822011830285192,
+      "learning_rate": 4.8e-05,
+      "loss": 0.0016,
+      "step": 1250
+    },
+    {
+      "epoch": 0.08,
+      "grad_norm": 0.0016089362325146794,
+      "learning_rate": 4.600000000000001e-05,
+      "loss": 0.0006,
+      "step": 2500
+    },
+    {
+      "epoch": 0.12,
+      "grad_norm": 7.695078238612041e-05,
+      "learning_rate": 4.4000000000000006e-05,
+      "loss": 0.0,
+      "step": 3750
+    },
+    {
+      "epoch": 0.16,
+      "grad_norm": 0.0024884792510420084,
+      "learning_rate": 4.2e-05,
+      "loss": 0.0,
+      "step": 5000
+    },
+    {
+      "epoch": 0.2,
+      "grad_norm": 5.6645851145731285e-05,
+      "learning_rate": 4e-05,
+      "loss": 0.0001,
+      "step": 6250
+    },
+    {
+      "epoch": 0.24,
+      "grad_norm": 0.00015396725211758167,
+      "learning_rate": 3.8e-05,
+      "loss": 0.0001,
+      "step": 7500
+    },
+    {
+      "epoch": 0.28,
+      "grad_norm": 0.0011548411566764116,
+      "learning_rate": 3.6e-05,
+      "loss": 0.0,
+      "step": 8750
+    },
+    {
+      "epoch": 0.32,
+      "grad_norm": 7.002039637882262e-05,
+      "learning_rate": 3.4000000000000007e-05,
+      "loss": 0.0001,
+      "step": 10000
+    },
+    {
+      "epoch": 0.36,
+      "grad_norm": 0.00010721544094849378,
+      "learning_rate": 3.2000000000000005e-05,
+      "loss": 0.0,
+      "step": 11250
+    },
+    {
+      "epoch": 0.4,
+      "grad_norm": 4.72808642371092e-05,
+      "learning_rate": 3e-05,
+      "loss": 0.0001,
+      "step": 12500
+    },
+    {
+      "epoch": 0.44,
+      "grad_norm": 1.643113137106411e-05,
+      "learning_rate": 2.8000000000000003e-05,
+      "loss": 0.0,
+      "step": 13750
+    },
+    {
+      "epoch": 0.48,
+      "grad_norm": 1.0432020644657314e-05,
+      "learning_rate": 2.6000000000000002e-05,
+      "loss": 0.0,
+      "step": 15000
+    },
+    {
+      "epoch": 0.52,
+      "grad_norm": 3.795513839577325e-05,
+      "learning_rate": 2.4e-05,
+      "loss": 0.0,
+      "step": 16250
+    },
+    {
+      "epoch": 0.56,
+      "grad_norm": 4.7567787987645715e-05,
+      "learning_rate": 2.2000000000000003e-05,
+      "loss": 0.0002,
+      "step": 17500
+    },
+    {
+      "epoch": 0.6,
+      "grad_norm": 2.121076249750331e-05,
+      "learning_rate": 2e-05,
+      "loss": 0.0,
+      "step": 18750
+    },
+    {
+      "epoch": 0.64,
+      "grad_norm": 1.4232242392608896e-05,
+      "learning_rate": 1.8e-05,
+      "loss": 0.0,
+      "step": 20000
+    },
+    {
+      "epoch": 0.68,
+      "grad_norm": 1.8679733329918236e-05,
+      "learning_rate": 1.6000000000000003e-05,
+      "loss": 0.0,
+      "step": 21250
+    },
+    {
+      "epoch": 0.72,
+      "grad_norm": 1.4709683455293998e-05,
+      "learning_rate": 1.4000000000000001e-05,
+      "loss": 0.0,
+      "step": 22500
+    },
+    {
+      "epoch": 0.76,
+      "grad_norm": 0.0004699587298091501,
+      "learning_rate": 1.2e-05,
+      "loss": 0.0,
+      "step": 23750
+    },
+    {
+      "epoch": 0.8,
+      "grad_norm": 7.580141755170189e-06,
+      "learning_rate": 1e-05,
+      "loss": 0.0,
+      "step": 25000
+    },
+    {
+      "epoch": 0.84,
+      "grad_norm": 1.317455644311849e-05,
+      "learning_rate": 8.000000000000001e-06,
+      "loss": 0.0,
+      "step": 26250
+    },
+    {
+      "epoch": 0.88,
+      "grad_norm": 0.00012563263589981943,
+      "learning_rate": 6e-06,
+      "loss": 0.0,
+      "step": 27500
+    },
+    {
+      "epoch": 0.92,
+      "grad_norm": 6.097168807173148e-06,
+      "learning_rate": 4.000000000000001e-06,
+      "loss": 0.0,
+      "step": 28750
+    },
+    {
+      "epoch": 0.96,
+      "grad_norm": 7.088618986017536e-06,
+      "learning_rate": 2.0000000000000003e-06,
+      "loss": 0.0,
+      "step": 30000
+    },
+    {
+      "epoch": 1.0,
+      "grad_norm": 1.4892546460032463e-05,
+      "learning_rate": 0.0,
+      "loss": 0.0,
+      "step": 31250
+    }
+  ],
+  "logging_steps": 1250,
+  "max_steps": 31250,
+  "num_input_tokens_seen": 0,
+  "num_train_epochs": 1,
+  "save_steps": 1250,
+  "total_flos": 1.0770363130241352e+18,
+  "train_batch_size": 16,
+  "trial_name": null,
+  "trial_params": null
+}
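
The new trainer_state.json follows the standard Hugging Face Trainer layout, so the loss and learning-rate schedule above can be tabulated directly; a minimal sketch:

import json

with open("trainer_state.json") as f:
    state = json.load(f)

# Each log_history entry carries the fields shown in the diff above.
for e in state["log_history"]:
    print(f"step {e['step']:>6}  epoch {e['epoch']:.2f}  "
          f"lr {e['learning_rate']:.2e}  loss {e['loss']:.4f}")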
training_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b24d88962c72332e2768938125ad26c90d11842469bcf29e7f9130fa40f8ca3
+size 5048
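
training_args.bin as written by the HF Trainer is a pickled TrainingArguments object. A hedged sketch of restoring it, assuming torch and transformers are importable (recent torch versions need weights_only=False to unpickle arbitrary objects):

import torch

# Restore the TrainingArguments saved alongside the checkpoint.
args = torch.load("training_args.bin", weights_only=False)
print(args.per_device_train_batch_size, args.learning_rate, args.num_train_epochs)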