YongganFu committed
Commit bf1039a · verified · 1 parent: 0401e1c

Upload JambaForCausalLM

config.json CHANGED
@@ -130,7 +130,7 @@
   "mamba_attnaug_config": null,
   "mamba_conv_bias": true,
   "mamba_d_conv": 4,
-  "mamba_d_state": 16,
+  "mamba_d_state": 128,
   "mamba_dt_rank": 192,
   "mamba_expand": 2,
   "mamba_inner_layernorms": true,
@@ -138,7 +138,7 @@
   "mamba_multihead_config": null,
   "mamba_proj_bias": false,
   "mamba_reuse_every_i_layer": -1,
-  "max_position_embeddings": 2048,
+  "max_position_embeddings": 22528,
   "memory_tokens_interspersed_every": 0,
   "mlp_hidden_act": "silu",
   "mod_topk": 2,
@@ -168,7 +168,7 @@
   "num_key_value_heads": 6,
   "num_mamba": 1,
   "num_memory_tokens": 256,
-  "orig_max_position_embeddings": 2048,
+  "orig_max_position_embeddings": 4096,
   "other_args": null,
   "output_router_logits": false,
   "pad_token_id": 0,
@@ -180,11 +180,11 @@
   "rms_norm_eps": 1e-06,
   "rope": true,
   "rope_theta": 10000.0,
-  "rope_type": null,
+  "rope_type": "ntk",
   "router_aux_loss_coef": 0.001,
   "save_input_output": false,
   "self_attn_type": null,
-  "seq_length": 2048,
+  "seq_length": 1024,
   "sequential_jamba": false,
   "share_kv": false,
   "shared_module_attn": "",
@@ -195,7 +195,7 @@
   "swa_full_head": false,
   "tie_word_embeddings": true,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.48.2",
+  "transformers_version": "4.45.0",
   "use_cache": false,
   "use_mamba2": false,
   "use_mamba_kernels": true,
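Note on the context-extension settings above: switching rope_type from null to "ntk" while raising max_position_embeddings to 22528 (with orig_max_position_embeddings at 4096) points to NTK-aware RoPE scaling. A minimal sketch of the common base-stretching variant follows; the head_dim of 128 is a hypothetical value for illustration and is not in the hunks above, and the authoritative formula lives in modeling_jamba.py:

```python
import torch

def ntk_scaled_theta(base_theta: float, orig_max_pos: int,
                     new_max_pos: int, head_dim: int) -> float:
    # "NTK-aware" RoPE: stretch the base so the lowest frequency spans the
    # extended window, instead of interpolating position ids.
    scale = new_max_pos / orig_max_pos  # 22528 / 4096 = 5.5
    return base_theta * scale ** (head_dim / (head_dim - 2))

def rope_inv_freq(theta: float, head_dim: int) -> torch.Tensor:
    # Standard RoPE inverse frequencies for a given base theta.
    return 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

# head_dim=128 is an assumption, not a value from this config.
theta = ntk_scaled_theta(10000.0, 4096, 22528, head_dim=128)
inv_freq = rope_inv_freq(theta, head_dim=128)
```

Under this variant the effective base grows from 10000.0 by roughly 5.6x, so the low-frequency dimensions rotate slowly enough to cover the 22528-token window without retraining position embeddings.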
generation_config.json CHANGED
@@ -3,6 +3,6 @@
   "bos_token_id": 1,
   "eos_token_id": 2,
   "pad_token_id": 0,
-  "transformers_version": "4.48.2",
+  "transformers_version": "4.45.0",
   "use_cache": false
 }
model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:73eabedbcef6e7dd78560bb95aced0f5a86de6c067b738503213f634c96f4fbd
-size 4995785984
+oid sha256:5e8a0875ed4decf5cbbf676868cbba137f3248a5a592a85597f31614080a25c6
+size 4987939472
model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:af6a0fc3b40b1c9d005467296052d6039bf6736955a54e0d2427445f791042b8
-size 491849664
+oid sha256:9f361fe0361ab0e101d95f5161ee1a724501ae664e0d27c28496d9288b71ebc3
+size 512102640
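The two .safetensors entries above are git-LFS pointer files: oid is the SHA-256 of the actual shard and size is its byte count. A small sketch for verifying a downloaded shard against its pointer, using the values from this commit and assuming you run it from a local checkout:

```python
import hashlib

def lfs_digest(path: str, chunk_bytes: int = 1 << 20) -> tuple[str, int]:
    # Stream the file and return (sha256 hex digest, size in bytes),
    # matching the oid/size fields of a git-LFS pointer.
    h = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        while chunk := f.read(chunk_bytes):
            h.update(chunk)
            size += len(chunk)
    return h.hexdigest(), size

digest, size = lfs_digest("model-00001-of-00002.safetensors")
assert digest == "5e8a0875ed4decf5cbbf676868cbba137f3248a5a592a85597f31614080a25c6"
assert size == 4987939472
```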
model.safetensors.index.json CHANGED
@@ -1,6 +1,6 @@
 {
   "metadata": {
-    "total_size": 5487609216
+    "total_size": 5500015680
   },
   "weight_map": {
     "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
@@ -174,7 +174,7 @@
   "model.layers.31.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
   "model.layers.32.gla.b_proj.weight": "model-00002-of-00002.safetensors",
   "model.layers.32.gla.k_conv1d.weight": "model-00002-of-00002.safetensors",
-  "model.layers.32.gla.k_proj.weight": "model-00001-of-00002.safetensors",
+  "model.layers.32.gla.k_proj.weight": "model-00002-of-00002.safetensors",
   "model.layers.32.gla.o_norm.weight": "model-00002-of-00002.safetensors",
   "model.layers.32.gla.o_proj.weight": "model-00002-of-00002.safetensors",
   "model.layers.32.gla.q_conv1d.weight": "model-00002-of-00002.safetensors",
modeling_jamba.py CHANGED
@@ -59,6 +59,7 @@ from transformers.utils import (
 from transformers.utils.import_utils import is_torch_fx_available
 from .configuration_jamba import JambaConfig
 from torch.utils.checkpoint import checkpoint
+from transformers.generation.utils import GenerationMixin


 # try except block so it'll work with trust_remote_code. Later we can have `if is_flash_attn_2_available():`
@@ -3664,7 +3665,7 @@ class JambaModel(JambaPreTrainedModel):


 # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
-class JambaForCausalLM(JambaPreTrainedModel):
+class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: JambaConfig):
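Adding GenerationMixin to the class bases is the forward-compatible pattern: recent transformers releases no longer have PreTrainedModel provide generate() implicitly, so models shipped via trust_remote_code must inherit the mixin explicitly. A quick check that the change took effect, assuming "." is a local checkout of this repo:

```python
from transformers import AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin

# trust_remote_code is required because JambaForCausalLM is defined in
# this repo's modeling_jamba.py, not in the transformers library itself.
model = AutoModelForCausalLM.from_pretrained(".", trust_remote_code=True)

assert isinstance(model, GenerationMixin)  # holds after this commit
assert model.can_generate()                # GenerationMixin-backed capability check
```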