Upload JambaForCausalLM
Files changed:
- config.json +6 -6
- generation_config.json +1 -1
- model-00001-of-00002.safetensors +2 -2
- model-00002-of-00002.safetensors +2 -2
- model.safetensors.index.json +2 -2
- modeling_jamba.py +2 -1
config.json
CHANGED

@@ -130,7 +130,7 @@
   "mamba_attnaug_config": null,
   "mamba_conv_bias": true,
   "mamba_d_conv": 4,
-  "mamba_d_state":
+  "mamba_d_state": 128,
   "mamba_dt_rank": 192,
   "mamba_expand": 2,
   "mamba_inner_layernorms": true,
@@ -138,7 +138,7 @@
   "mamba_multihead_config": null,
   "mamba_proj_bias": false,
   "mamba_reuse_every_i_layer": -1,
-  "max_position_embeddings":
+  "max_position_embeddings": 22528,
   "memory_tokens_interspersed_every": 0,
   "mlp_hidden_act": "silu",
   "mod_topk": 2,
@@ -168,7 +168,7 @@
   "num_key_value_heads": 6,
   "num_mamba": 1,
   "num_memory_tokens": 256,
-  "orig_max_position_embeddings":
+  "orig_max_position_embeddings": 4096,
   "other_args": null,
   "output_router_logits": false,
   "pad_token_id": 0,
@@ -180,11 +180,11 @@
   "rms_norm_eps": 1e-06,
   "rope": true,
   "rope_theta": 10000.0,
-  "rope_type":
+  "rope_type": "ntk",
   "router_aux_loss_coef": 0.001,
   "save_input_output": false,
   "self_attn_type": null,
-  "seq_length":
+  "seq_length": 1024,
   "sequential_jamba": false,
   "share_kv": false,
   "shared_module_attn": "",
@@ -195,7 +195,7 @@
   "swa_full_head": false,
   "tie_word_embeddings": true,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.
+  "transformers_version": "4.45.0",
   "use_cache": false,
   "use_mamba2": false,
   "use_mamba_kernels": true,
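The updated config extends the usable context window: "max_position_embeddings" becomes 22528 with "rope_type" set to "ntk" (presumably NTK-style RoPE scaling) over an "orig_max_position_embeddings" of 4096. A minimal sketch for inspecting these fields after the commit; the repo id is a placeholder, and trust_remote_code is needed because the modeling/configuration code ships with the repository:

    from transformers import AutoConfig

    # Sketch: inspect the context-length fields changed in this commit.
    # "your-org/your-jamba-repo" is a placeholder for this repository's id.
    config = AutoConfig.from_pretrained("your-org/your-jamba-repo", trust_remote_code=True)
    print(config.max_position_embeddings)       # 22528 after this change
    print(config.orig_max_position_embeddings)  # 4096
    print(config.rope_type)                     # "ntk"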
generation_config.json
CHANGED

@@ -3,6 +3,6 @@
   "bos_token_id": 1,
   "eos_token_id": 2,
   "pad_token_id": 0,
-  "transformers_version": "4.
+  "transformers_version": "4.45.0",
   "use_cache": false
 }
model-00001-of-00002.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5e8a0875ed4decf5cbbf676868cbba137f3248a5a592a85597f31614080a25c6
+size 4987939472
model-00002-of-00002.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:9f361fe0361ab0e101d95f5161ee1a724501ae664e0d27c28496d9288b71ebc3
+size 512102640
model.safetensors.index.json
CHANGED

@@ -1,6 +1,6 @@
 {
   "metadata": {
-    "total_size":
+    "total_size": 5500015680
   },
   "weight_map": {
     "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
@@ -174,7 +174,7 @@
     "model.layers.31.pre_moe_layernorm.weight": "model-00001-of-00002.safetensors",
     "model.layers.32.gla.b_proj.weight": "model-00002-of-00002.safetensors",
     "model.layers.32.gla.k_conv1d.weight": "model-00002-of-00002.safetensors",
-    "model.layers.32.gla.k_proj.weight": "model-
+    "model.layers.32.gla.k_proj.weight": "model-00002-of-00002.safetensors",
     "model.layers.32.gla.o_norm.weight": "model-00002-of-00002.safetensors",
     "model.layers.32.gla.o_proj.weight": "model-00002-of-00002.safetensors",
     "model.layers.32.gla.q_conv1d.weight": "model-00002-of-00002.safetensors",
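A small sketch for sanity-checking the updated index against the two shards above, assuming the files have been downloaded locally. Note that "total_size" counts tensor bytes only, so it is expected to be slightly smaller than the on-disk shard sizes, which also include the safetensors headers:

    import json
    from safetensors import safe_open

    # Sketch: recompute total_size from the shards listed in the weight map
    # and compare it with the value recorded in model.safetensors.index.json.
    with open("model.safetensors.index.json") as fh:
        index = json.load(fh)

    total = 0
    for shard in sorted(set(index["weight_map"].values())):
        with safe_open(shard, framework="pt", device="cpu") as f:
            for name in f.keys():
                t = f.get_tensor(name)
                total += t.numel() * t.element_size()

    print(total, index["metadata"]["total_size"])  # expected to match: 5500015680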
modeling_jamba.py
CHANGED

@@ -59,6 +59,7 @@ from transformers.utils import (
 from transformers.utils.import_utils import is_torch_fx_available
 from .configuration_jamba import JambaConfig
 from torch.utils.checkpoint import checkpoint
+from transformers.generation.utils import GenerationMixin


 # try except block so it'll work with trust_remote_code. Later we can have `if is_flash_attn_2_available():`
@@ -3664,7 +3665,7 @@ class JambaModel(JambaPreTrainedModel):


 # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
-class JambaForCausalLM(JambaPreTrainedModel):
+class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: JambaConfig):
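The GenerationMixin change matters for newer transformers releases (the config now records 4.45.0): recent versions warn that custom model classes which do not explicitly inherit from GenerationMixin will lose access to .generate(), which is what the class-definition change above addresses. A minimal usage sketch; the repo id is a placeholder:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Sketch: load the custom JambaForCausalLM via trust_remote_code and generate.
    # "your-org/your-jamba-repo" is a placeholder for this repository's id.
    repo = "your-org/your-jamba-repo"
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo, torch_dtype=torch.bfloat16, trust_remote_code=True
    )

    inputs = tokenizer("Hello, Jamba!", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))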