anas tomeras1 committed on
Commit
7c46cea
0 Parent(s):

Duplicate from ai21labs/Jamba-v0.1


Co-authored-by: Tomer Asida <tomeras1@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,169 @@
---
library_name: transformers
license: apache-2.0
tags:
- jamba
- mamba
- moe
---

This is the base version of the Jamba model. We’ve since released a better, instruct-tuned version, [Jamba-1.5-Mini](https://huggingface.co/ai21labs/AI21-Jamba-1.5-Mini). For even greater performance, check out the scaled-up [Jamba-1.5-Large](https://huggingface.co/ai21labs/AI21-Jamba-1.5-Large).

# Model Card for Jamba

Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It delivers throughput gains over traditional Transformer-based models, while outperforming or matching the leading models of its size class on most common benchmarks.

Jamba is the first production-scale Mamba implementation, which opens up interesting research and application opportunities. While this initial experimentation shows encouraging gains, we expect these to be further enhanced with future optimizations and explorations.

This model card is for the base version of Jamba. It’s a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length, and can fit up to 140K tokens on a single 80GB GPU.
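The gap between 12B active and 52B total parameters comes from sparse MoE routing: each token is sent to only 2 of 16 experts (per this repo's config), so only the selected experts' weights participate per token. A minimal sketch of such top-2 routing in plain Python — illustrative only, not the model's actual implementation:

```python
import math

def top2_route(router_logits, k=2):
    """Pick the top-k experts for one token and renormalize their
    softmax weights so the selected weights sum to 1."""
    # numerically stable softmax over all expert logits
    m = max(router_logits)
    exps = [math.exp(x - m) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # indices of the k largest probabilities
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    weight_sum = sum(probs[i] for i in top)
    return [(i, probs[i] / weight_sum) for i in top]

# 16 experts, 2 per token, as in Jamba's config
logits = [0.1] * 16
logits[3], logits[11] = 2.0, 1.5
routed = top2_route(logits)
print(routed)  # experts 3 and 11, with weights summing to 1
```

The renormalization step means the two chosen experts' outputs are mixed with weights that sum to one, regardless of how much probability mass the other 14 experts received.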
For full details of this model please read the [white paper](https://arxiv.org/abs/2403.19887) and the [release blog post](https://www.ai21.com/blog/announcing-jamba).

## Model Details

- **Developed by:** [AI21](https://www.ai21.com)
- **Model type:** Joint Attention and Mamba (Jamba)
- **License:** Apache 2.0
- **Context length:** 256K
- **Knowledge cutoff date:** March 5, 2024

## Usage
### Prerequisites
To use Jamba, it is recommended you use `transformers` version 4.40.0 or higher (version 4.39.0 is the minimum required):
```bash
pip install "transformers>=4.40.0"
```

In order to run the optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`:
```bash
pip install mamba-ssm "causal-conv1d>=1.2.0"
```
The model must also be on a CUDA device.

You can run the model without the optimized Mamba kernels, but it is **not** recommended, as it will result in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model.

### Run the model
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
```

Please note that if you're using `transformers<4.40.0`, `trust_remote_code=True` is required for running the new Jamba architecture.

<details>
<summary><strong>Loading the model in half precision</strong></summary>

The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you need to specify `torch_dtype`:

```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16)  # you can also use torch_dtype=torch.float16
```

When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. In order to use it, you also need the model on a CUDA device. Since in this precision the model is too big to fit on a single 80GB GPU, you'll also need to parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             device_map="auto")
```

</details>
<details><summary><strong>Load the model in 8-bit</strong></summary>

**Using 8-bit precision, it is possible to fit up to 140K sequence lengths on a single 80GB GPU.** You can easily quantize the model to 8-bit using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index). In order not to degrade model quality, we recommend excluding the Mamba blocks from the quantization:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
                                         llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             quantization_config=quantization_config)
```
</details>

### Fine-tuning example
Jamba is a base model that can be fine-tuned for custom solutions (including chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library (requires ~120GB GPU RAM, e.g. 2x A100 80GB):

```python
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)

lora_config = LoraConfig(
    r=8,
    target_modules=[
        "embed_tokens",
        "x_proj", "in_proj", "out_proj",      # mamba
        "gate_proj", "up_proj", "down_proj",  # mlp
        "q_proj", "k_proj", "v_proj",         # attention
    ],
    task_type="CAUSAL_LM",
    bias="none"
)

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=1e-5,
    dataset_text_field="quote",
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
)
trainer.train()
```

## Results on common benchmarks
| Benchmark     | Score |
|---------------|:-----:|
| HellaSwag     | 87.1% |
| ARC-Challenge | 64.4% |
| WinoGrande    | 82.5% |
| PIQA          | 83.2% |
| MMLU          | 67.4% |
| BBH           | 45.4% |
| TruthfulQA    | 46.4% |
| GSM8K (CoT)   | 59.9% |

It's crucial that the `BOS` token is added to all prompts, which might not be enabled by default in all eval frameworks.
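In practice that means every prompt should start with token id 1 (the `bos_token_id` in this repo's config.json). A minimal sketch of the check an eval harness would need when it builds token ids itself — illustrative; `transformers` tokenizers normally handle this for you when `add_special_tokens=True`:

```python
BOS_TOKEN_ID = 1  # bos_token_id from this repo's config.json

def ensure_bos(token_ids, bos_id=BOS_TOKEN_ID):
    """Prepend the BOS id if the sequence doesn't already start with it."""
    if not token_ids or token_ids[0] != bos_id:
        return [bos_id] + list(token_ids)
    return list(token_ids)

print(ensure_bos([42, 7]))     # [1, 42, 7]
print(ensure_bos([1, 42, 7]))  # [1, 42, 7] (unchanged)
```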
## Notice
Jamba is a pretrained base model and did not undergo any alignment for instruct/chat interactions.

As a base model, Jamba is intended for use as a foundation layer for fine-tuning, training, and developing custom solutions. Jamba does not have safety moderation mechanisms, and guardrails should be added for responsible and safe use.

## About AI21
AI21 builds reliable, practical, and scalable AI solutions for the enterprise.

Jamba is the first in AI21’s new family of models, and the Instruct version of Jamba is coming soon to the [AI21 platform](https://www.ai21.com/studio).
config.json ADDED
@@ -0,0 +1,47 @@
{
  "architectures": [
    "JambaForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attn_layer_offset": 4,
  "attn_layer_period": 8,
  "auto_map": {
    "AutoConfig": "configuration_jamba.JambaConfig",
    "AutoModel": "modeling_jamba.JambaModel",
    "AutoModelForCausalLM": "modeling_jamba.JambaForCausalLM",
    "AutoModelForSequenceClassification": "model.JambaForSequenceClassification"
  },
  "bos_token_id": 1,
  "eos_token_id": 2,
  "expert_layer_offset": 1,
  "expert_layer_period": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "mamba_conv_bias": true,
  "mamba_d_conv": 4,
  "mamba_d_state": 16,
  "mamba_dt_rank": 256,
  "mamba_expand": 2,
  "mamba_proj_bias": false,
  "max_position_embeddings": 262144,
  "model_type": "jamba",
  "num_attention_heads": 32,
  "num_experts": 16,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "num_logits_to_keep": 1,
  "output_router_logits": false,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "router_aux_loss_coef": 0.001,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.1",
  "use_cache": true,
  "use_mamba_kernels": true,
  "vocab_size": 65536
}
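A few of the stored values follow arithmetically from the others; here is a quick sanity check of those relationships (my own arithmetic, using the `mamba_dt_rank="auto"` rule documented in `configuration_jamba.py` below):

```python
import math

hidden_size = 4096
num_attention_heads, num_key_value_heads = 32, 8
mamba_expand = 2

# GQA: each of the 8 key/value heads serves a group of query heads
gqa_group_size = num_attention_heads // num_key_value_heads

# Mamba mixer inner width = mamba_expand * hidden_size
mamba_intermediate = mamba_expand * hidden_size

# "auto" dt_rank rule: ceil(hidden_size / 16) -- matches the stored 256
auto_dt_rank = math.ceil(hidden_size / 16)

print(gqa_group_size, mamba_intermediate, auto_dt_rank)  # 4 8192 256
```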
configuration_jamba.py ADDED
@@ -0,0 +1,223 @@
# coding=utf-8
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Jamba model configuration"""
import math

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class JambaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
    Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Jamba-v0.1 model.

    [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 65536):
            Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`JambaModel`]
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
            model has an output word embedding layer.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
            integer value, only the last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
            logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
            sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
            significantly.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        sliding_window (`int`, *optional*):
            Sliding window attention window size. If not specified, will default to `None`.
        max_position_embeddings (`int`, *optional*, defaults to 262144):
            This value doesn't have any real effect. The maximum sequence length that this model is intended to be
            used with. It can be used with longer sequences, but performance may degrade.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route per token; can also be interpreted as the `top-k` routing
            parameter
        num_experts (`int`, *optional*, defaults to 16):
            Number of experts per Sparse MLP layer.
        expert_layer_period (`int`, *optional*, defaults to 2):
            Once in this many layers, we will have an expert layer
        expert_layer_offset (`int`, *optional*, defaults to 1):
            The first layer index that contains an expert mlp layer
        attn_layer_period (`int`, *optional*, defaults to 8):
            Once in this many layers, we will have a vanilla attention layer
        attn_layer_offset (`int`, *optional*, defaults to 4):
            The first layer index that contains a vanilla attention layer
        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
            `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
            `True` and kernels are not available
        mamba_d_state (`int`, *optional*, defaults to 16):
            The dimension of the mamba state space latents
        mamba_d_conv (`int`, *optional*, defaults to 4):
            The size of the mamba convolution kernel
        mamba_expand (`int`, *optional*, defaults to 2):
            Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
        mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
            Rank of the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block

    """

    model_type = "jamba"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=65536,
        tie_word_embeddings=False,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        num_logits_to_keep=1,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        sliding_window=None,
        max_position_embeddings=262144,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_experts=16,
        expert_layer_period=2,
        expert_layer_offset=1,
        attn_layer_period=8,
        attn_layer_offset=4,
        use_mamba_kernels=True,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_dt_rank="auto",
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        self.use_cache = use_cache
        self.num_logits_to_keep = num_logits_to_keep
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.expert_layer_period = expert_layer_period
        self.expert_layer_offset = expert_layer_offset
        self.attn_layer_period = attn_layer_period
        self.attn_layer_offset = attn_layer_offset

        self.use_mamba_kernels = use_mamba_kernels
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
        self.mamba_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        return [
            "attention" if i % self.attn_layer_period == self.attn_layer_offset else "mamba"
            for i in range(self.num_hidden_layers)
        ]

    @property
    def layers_num_experts(self):
        return [
            self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
            for i in range(self.num_hidden_layers)
        ]
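Evaluated with the default hyperparameters, the two properties at the end of the file fully determine the layer layout. This standalone re-implementation (illustrative; not an import of the class) shows the resulting mix of 4 attention layers, 28 mamba layers, and 16 MoE layers:

```python
num_hidden_layers = 32
attn_layer_period, attn_layer_offset = 8, 4
expert_layer_period, expert_layer_offset = 2, 1
num_experts = 16

# mirror of JambaConfig.layers_block_type with the defaults above
block_types = ["attention" if i % attn_layer_period == attn_layer_offset else "mamba"
               for i in range(num_hidden_layers)]

# mirror of JambaConfig.layers_num_experts: MoE on every other layer
experts_per_layer = [num_experts if i % expert_layer_period == expert_layer_offset else 1
                     for i in range(num_hidden_layers)]

print(block_types.count("attention"), block_types.count("mamba"))  # 4 28
print(experts_per_layer.count(num_experts))                        # 16
```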
generation_config.json ADDED
@@ -0,0 +1,7 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.40.1"
}
model-00001-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1aace34ee0da3bf95605bd150fff6d3e78110be4048a3c389b0a740354b2ccb7
size 4951761424
model-00002-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0ba1de67a86329431f14f7ffa165d84055d32ce57a6d2314e3b2464eac3732dc
size 4884669624
model-00003-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1abc4f16865fb78241c9453292ee3b2ca2c1e2d54ee945631da625834b95c9b2
size 4992557120
model-00004-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:45fab97739a58e924791572ea3d06f9c90b9ff2a299460aaa4bd87c6e9d424f3
size 4958853560
model-00005-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4b0ec6e8f33e6d7b1f837cd4c25818487dcc7e478734606da28110507e51c97
size 4975763832
model-00006-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ed98d5c3c8d7ab7352944bea09b0d54d98066cf567ba3d069da12c05575d56ed
size 4884669616
model-00007-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:735be2bc568711bf42a4caebcda8288dd300b31b48fa098b00df3cf1a98e10e2
size 4884669640
model-00008-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0c8d817b2b47661d361e8b520128b3194185f756cc2204a95d642e24895ee51
size 4992557176
model-00009-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e50222cf865ca5678d22574b131294303c46b249478cf70113c701f70331e999
size 4932507176
model-00010-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1b4b69b24ae55827b6c8b1e4a10807aa3525bc85f4d34dc002ac7440757fbf4
size 4884669672
model-00011-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:60213cac13b92ed34b93ce48e670434f22e3bf8b2b8df20c60b7bf8a9515c35c
size 4884669696
model-00012-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:05805eacd3bb40cc9da802350409f1cb078e8b276da7e06c7a8a5ca5b26cc887
size 4884669688
model-00013-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:201df979a1b34ced6cdbb7a790163412636779f1119e3845a704c489181d03d2
size 4932507176
model-00014-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0a7eb42a9ea3a385442c2e758dd5efd5dc5b913f1d10bfd37792cc963a33c93
size 4992557152
model-00015-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4b9afe4398000c28b36e3aa40c87086af673d4f8a64bfc5767941ab2008bcc9
size 4884669688
model-00016-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd1ac6cc861971c43bdf0c9c6d4c9fe72d33e5227e054a621e2e68f001419763
size 4884669688
model-00017-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52d9eea696dd29ef413d617bbcb62a9f159e8fe8170d36e018932cef45ee281d
size 4908522856
model-00018-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77acada7c098e81280645ea0a9dbfa00196dca6da8946498b9907e9e376fb42d
size 4908654000
model-00019-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09e10dfd6c6459cd3460b1d667639717d3657274c1694c19a6fdbac1be6a76bf
size 4992557168
model-00020-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2bd5c27b2cca6e06f7b4497ce8c9b1522a64846817a871bad274d08507960ed0
size 4884669696
model-00021-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a47ef23db8deb5364da676a40dc3dcb011fb9d9ceef13ba044c176e9a83ac1e3
size 4647318576
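The shard sizes above can be cross-checked against the advertised parameter count: with the checkpoint stored in BF16 (2 bytes per parameter, per config.json's `torch_dtype`), the total bytes divided by two should land near 52B. A quick tally (my own arithmetic; sizes copied from the LFS pointers above):

```python
# byte sizes of the 21 safetensors shards, in order
shard_sizes = [
    4951761424, 4884669624, 4992557120, 4958853560, 4975763832, 4884669616,
    4884669640, 4992557176, 4932507176, 4884669672, 4884669696, 4884669688,
    4932507176, 4992557152, 4884669688, 4884669688, 4908522856, 4908654000,
    4992557168, 4884669696, 4647318576,
]
total_bytes = sum(shard_sizes)
approx_params = total_bytes / 2  # BF16 stores 2 bytes per parameter

print(f"{total_bytes / 1e9:.1f} GB total, ~{approx_params / 1e9:.1f}B parameters")
```

The result is consistent with the 52B total parameters stated in the model card (the small shortfall is expected, since not every stored tensor is exactly BF16-sized metadata-free weight data).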
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_jamba.py ADDED
@@ -0,0 +1,1887 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# coding=utf-8
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Jamba model."""
import inspect
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import DynamicCache  # we need __iter__ and __len__ of pkv
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
)
from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.utils.import_utils import (
    is_causal_conv1d_available,
    is_flash_attn_2_available,
    is_mamba_ssm_available,
)
from .configuration_jamba import JambaConfig


# try/except block so it'll work with trust_remote_code.
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError:
    pass


# try/except block so it'll work with trust_remote_code.
try:
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

# try/except block so it'll work with trust_remote_code.
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "JambaConfig"


# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func with gate->router
def load_balancing_loss_func(
    router_logits: torch.Tensor,
    num_experts: int = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        router_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
            Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in the forward function,
            of shape [batch_size X sequence_length] if not None.
        num_experts (`int`, *optional*):
            Number of experts.

    Returns:
        The auxiliary loss.
    """
    if router_logits is None or not isinstance(router_logits, tuple):
        return 0

    if isinstance(router_logits, tuple):
        compute_device = router_logits[0].device
        concatenated_router_logits = torch.cat(
            [layer_router.to(compute_device) for layer_router in router_logits], dim=0
        )

    routing_weights = torch.nn.functional.softmax(concatenated_router_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_router_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0, with the same shape as expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0, with the same shape as tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts


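The core of the function above is: softmax the router logits, take the top-k experts per token, then multiply the fraction of tokens routed to each expert by the average routing probability of that expert. A minimal pure-Python sketch of that computation (top-1 routing, no padding mask; the function name and toy logits are illustrative, not part of the model code):

```python
import math

def load_balancing_loss(router_logits, num_experts, top_k=1):
    # softmax over experts for each token
    probs = []
    for logits in router_logits:
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        s = sum(exps)
        probs.append([e / s for e in exps])
    # top-k selection -> fraction of tokens routed to each expert
    counts = [0] * num_experts
    for p in probs:
        for e in sorted(range(num_experts), key=lambda i: -p[i])[:top_k]:
            counts[e] += 1
    tokens_per_expert = [c / (len(probs) * top_k) for c in counts]
    # average routing probability per expert
    prob_per_expert = [sum(p[e] for p in probs) / len(probs) for e in range(num_experts)]
    loss = sum(f * p for f, p in zip(tokens_per_expert, prob_per_expert))
    return loss * num_experts

# perfectly balanced routing gives loss == 1.0 for top_k=1;
# collapsing all tokens onto one expert pushes the loss above 1.0
balanced = load_balancing_loss([[5.0, 0.0], [0.0, 5.0]], num_experts=2)
collapsed = load_balancing_loss([[5.0, 0.0], [5.0, 0.0]], num_experts=2)
```

The loss is minimized when routing is uniform, which is exactly the pressure the auxiliary term applies during MoE training.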
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


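`_get_unpad_data` converts a padding mask into the flat-token bookkeeping that varlen flash attention expects: the flat indices of real tokens, cumulative sequence lengths, and the longest sequence. A pure-Python sketch of the same computation (helper name is illustrative):

```python
def get_unpad_data(attention_mask):
    # attention_mask: list of rows of 0/1 values, one row per sequence
    seqlens = [sum(row) for row in attention_mask]
    # flat indices of the non-padding tokens
    flat = [v for row in attention_mask for v in row]
    indices = [i for i, v in enumerate(flat) if v == 1]
    # cumulative sequence lengths, prefixed with 0 (the F.pad above)
    cu_seqlens = [0]
    for s in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + s)
    return indices, cu_seqlens, max(seqlens)

indices, cu_seqlens, max_len = get_unpad_data([[1, 1, 1, 0], [1, 1, 0, 0]])
# indices -> [0, 1, 2, 4, 5]; cu_seqlens -> [0, 3, 5]; max_len -> 3
```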
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba
class JambaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        JambaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


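RMSNorm divides each vector by the root mean square of its elements (plus epsilon) and scales by a learned weight; unlike LayerNorm, there is no mean subtraction or bias. A scalar pure-Python sketch of the same formula (function name is illustrative):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), then elementwise scale by the learned weight
    variance = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * v * inv_rms for w, v in zip(weight, x)]

out = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
# mean(x^2) = 12.5, rms ~= 3.5355 -> out ~= [0.8485, 1.1314]
```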
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


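For grouped-query attention, each KV head serves several query heads, so the cached K/V heads are duplicated along the head dimension before the attention product. A list-based sketch of the repeat-interleave pattern (the per-head strings stand in for tensors):

```python
def repeat_kv_sketch(heads, n_rep):
    # duplicate each KV head n_rep times in place, matching
    # torch.repeat_interleave(x, repeats=n_rep, dim=1)
    out = []
    for h in heads:
        out.extend([h] * n_rep)
    return out

# 2 KV heads shared by 4 query heads -> each KV head repeated twice
kv = repeat_kv_sketch(["k0", "k1"], n_rep=2)
# kv -> ["k0", "k0", "k1", "k1"]
```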
class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shapes are as follows:
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.dtype = dtype
        self.layers_block_type = config.layers_block_type
        self.has_previous_state = False  # only used by mamba
        intermediate_size = config.mamba_expand * config.hidden_size
        ssm_state_size = config.mamba_d_state
        conv_kernel_size = config.mamba_d_conv
        self.conv_states = []
        self.ssm_states = []
        for i in range(config.num_hidden_layers):
            if self.layers_block_type[i] == "mamba":
                self.conv_states += [
                    torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
                ]
                self.ssm_states += [
                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
                ]
            else:
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]

        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update the cache
        if self.key_cache[layer_idx].shape[-1] == 0:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

            device = self.conv_states[layer_idx].device
            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")


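The key property of this hybrid cache is that attention layers accumulate K/V entries as generation proceeds, while mamba layers keep a fixed-size recurrent state that is overwritten each step. A minimal sketch of that asymmetry (class and attribute names are illustrative simplifications, not the API above):

```python
class HybridCacheSketch:
    def __init__(self, layers_block_type):
        # e.g. ["attention", "mamba", ...], one entry per layer
        self.layers_block_type = layers_block_type
        self.key_cache = [[] for _ in layers_block_type]    # grows with seq_len
        self.ssm_states = [0.0 for _ in layers_block_type]  # constant-size state

    def update_attention(self, layer_idx, new_keys):
        # concatenate along the sequence dimension, like torch.cat(..., dim=2)
        self.key_cache[layer_idx].extend(new_keys)
        return self.key_cache[layer_idx]

    def update_mamba(self, layer_idx, new_state):
        # overwrite in place: the mamba state never grows
        self.ssm_states[layer_idx] = new_state
        return self.ssm_states[layer_idx]

cache = HybridCacheSketch(["attention", "mamba"])
cache.update_attention(0, ["k0", "k1"])
cache.update_attention(0, ["k2"])  # attention cache now holds 3 positions
cache.update_mamba(1, 7.0)
cache.update_mamba(1, 8.0)         # mamba state is still a single value
```

This is why Jamba's memory footprint during generation grows much more slowly than a pure-attention model's: only the attention layers pay the per-token cache cost.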
# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba
class JambaAttention(nn.Module):
    """
    Multi-headed attention from the 'Attention Is All You Need' paper. Modified to use sliding window attention, as
    described in Longformer and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: JambaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


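The eager forward above is just scaled dot-product attention: score each key against the query, scale by 1/sqrt(head_dim), softmax, and take the weighted sum of values. A single-query, single-head pure-Python sketch (function name and toy vectors are illustrative):

```python
import math

def attention(q, k, v):
    # softmax(q . k / sqrt(d)) weighted sum over values, for one query vector
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, key)) / math.sqrt(d) for key in k]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]
    # weighted sum of the value vectors
    return [sum(w * val[j] for w, val in zip(weights, v)) for j in range(len(v[0]))]

# query identical to the first key -> output pulled toward the first value
out = attention(q=[1.0, 0.0], k=[[1.0, 0.0], [0.0, 1.0]], v=[[10.0], [0.0]])
```

In the real module this runs batched over all heads via two `torch.matmul` calls, with the causal mask added to the scores before the softmax.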
# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba
class JambaFlashAttention2(JambaAttention):
    """
    Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stay
    untouched. The only required change is on the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = cache_position[-1]

        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
        )

        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention; for a more memory"
                " efficient implementation, make sure to upgrade the flash-attn library."
            )

        if past_key_value is not None:
            # Activate slicing cache only if the config has a `sliding_window` attribute
            cache_has_contents = cache_position[0] > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to have been silently cast to float32; this might be related to"
                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input"
                f" back to {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        it first unpads the input, then computes the attention scores, and finally pads the attention scores back.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            if not use_sliding_windows:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            if not use_sliding_windows:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

        return attn_output

    # Copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # On the first iteration we need to properly re-create the padding mask
        # by slicing it in the proper place
        if kv_seq_len != attention_mask.shape[-1]:
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


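When sliding-window attention is active, the forward pass above trims the KV cache with `slicing_tokens = 1 - sliding_window`, keeping only the last `sliding_window - 1` cached positions so that, together with the token being appended, attention covers exactly one window. A minimal sketch of that trimming (helper name and list-of-positions representation are illustrative):

```python
def trim_kv_for_sliding_window(past_keys, sliding_window):
    # negative index, exactly as computed in the forward pass above
    slicing_tokens = 1 - sliding_window
    # keep only the most recent (sliding_window - 1) cached positions
    return past_keys[slicing_tokens:]

kept = trim_kv_for_sliding_window(list(range(10)), sliding_window=4)
# keeps the last 3 cached positions: [7, 8, 9]
```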
# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
class JambaSdpaAttention(JambaAttention):
    """
    Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `JambaAttention` as the weights of the module stay untouched. The only changes are on the forward pass, to adapt
    to the SDPA API.
    """

    # Adapted from JambaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "JambaModel is using JambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask.
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


740
+ JAMBA_ATTENTION_CLASSES = {
741
+ "eager": JambaAttention,
742
+ "flash_attention_2": JambaFlashAttention2,
743
+ "sdpa": JambaSdpaAttention,
744
+ }
745
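The attention classes registered above all share the grouped-query pattern: `repeat_kv` expands the key/value heads to the full query-head count before `torch.nn.functional.scaled_dot_product_attention`. A self-contained sketch of that pattern (this `repeat_kv` is re-implemented here for illustration, mirroring the transformers helper; all tensors are random stand-ins):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

q = torch.randn(1, 8, 4, 16)   # 8 query heads
k = torch.randn(1, 2, 4, 16)   # 2 KV heads -> group size 4
v = torch.randn(1, 2, 4, 16)
k, v = repeat_kv(k, 4), repeat_kv(v, 4)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 4, 16])
```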
+
746
+
747
+ # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
748
+ class JambaMambaMixer(nn.Module):
749
+ """
750
+ Compute ∆, A, B, C, and D, the state space parameters, and compute the `contextualized_states`.
751
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
752
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
753
+ and is why Mamba is called **selective** state spaces)
754
+ """
755
+
756
+ def __init__(self, config: JambaConfig, layer_idx):
757
+ super().__init__()
758
+ self.config = config
759
+ self.layer_idx = layer_idx
760
+ self.hidden_size = config.hidden_size
761
+ self.ssm_state_size = config.mamba_d_state
762
+ self.conv_kernel_size = config.mamba_d_conv
763
+ self.intermediate_size = config.mamba_expand * config.hidden_size
764
+ self.time_step_rank = config.mamba_dt_rank
765
+ self.use_conv_bias = config.mamba_conv_bias
766
+ self.use_bias = config.mamba_proj_bias
767
+ self.conv1d = nn.Conv1d(
768
+ in_channels=self.intermediate_size,
769
+ out_channels=self.intermediate_size,
770
+ bias=self.use_conv_bias,
771
+ kernel_size=self.conv_kernel_size,
772
+ groups=self.intermediate_size,
773
+ padding=self.conv_kernel_size - 1,
774
+ )
775
+
776
+ self.activation = config.hidden_act
777
+ self.act = ACT2FN[config.hidden_act]
778
+
779
+ self.use_fast_kernels = config.use_mamba_kernels
780
+
781
+ # projection of the input hidden states
782
+ self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
783
+ # selective projection used to make dt, B and C input dependent
784
+ self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
785
+ # time step projection (discretization)
786
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
787
+
788
+ # S4D real initialization. These are not discretized!
789
+ # The core idea is to load them, compute the discrete states, then write the updated state. This keeps the memory bounded.
790
+ A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
791
+ A = A.expand(self.intermediate_size, -1).contiguous()
792
+
793
+ self.A_log = nn.Parameter(torch.log(A))
794
+ self.D = nn.Parameter(torch.ones(self.intermediate_size))
795
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
796
+
797
+ self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
798
+ self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
799
+ self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
800
+
801
+ if not is_fast_path_available:
802
+ logger.warning_once(
803
+ "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
804
+ " is None. To install follow https://github.com/state-spaces/mamba/#installation and"
805
+ " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
806
+ )
807
+
808
+ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
809
+ batch_size, seq_len, _ = hidden_states.shape
810
+ use_precomputed_states = (
811
+ cache_params is not None
812
+ and cache_params.has_previous_state
813
+ and seq_len == 1
814
+ and cache_params.conv_states[self.layer_idx].shape[0]
815
+ == cache_params.ssm_states[self.layer_idx].shape[0]
816
+ == batch_size
817
+ )
818
+ # 1. Gated MLP's linear projection
819
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
820
+
821
+ # We can't use `mamba_inner_fn` even if in training and without cache params because we have the
822
+ # inner layernorms which isn't supported by this fused kernel
823
+ hidden_states, gate = projected_states.chunk(2, dim=1)
824
+
825
+ # 2. Convolution sequence transformation
826
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
827
+ if use_precomputed_states:
828
+ hidden_states = causal_conv1d_update(
829
+ hidden_states.squeeze(-1),
830
+ cache_params.conv_states[self.layer_idx],
831
+ conv_weights,
832
+ self.conv1d.bias,
833
+ self.activation,
834
+ )
835
+ hidden_states = hidden_states.unsqueeze(-1)
836
+ else:
837
+ if cache_params is not None:
838
+ conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
839
+ cache_params.conv_states[self.layer_idx].copy_(conv_states)
840
+ hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
841
+
842
+ # 3. State Space Model sequence transformation
843
+ # 3.a. input varying initialization of time_step, B and C
844
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
845
+ time_step, B, C = torch.split(
846
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
847
+ )
848
+
849
+ time_step = self.dt_layernorm(time_step)
850
+ B = self.b_layernorm(B)
851
+ C = self.c_layernorm(C)
852
+
853
+ # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
854
+ # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
855
+ # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
856
+ # linear layers, and requires to call the forward pass directly.
857
+ # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
858
+ time_proj_bias = self.dt_proj.bias
859
+ self.dt_proj.bias = None
860
+ discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
861
+ self.dt_proj.bias = time_proj_bias
862
+
863
+ A = -torch.exp(self.A_log.float())
864
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
865
+ time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
866
+ if use_precomputed_states:
867
+ scan_outputs = selective_state_update(
868
+ cache_params.ssm_states[self.layer_idx],
869
+ hidden_states[..., 0],
870
+ discrete_time_step[..., 0],
871
+ A,
872
+ B[:, 0],
873
+ C[:, 0],
874
+ self.D,
875
+ gate[..., 0],
876
+ time_proj_bias,
877
+ dt_softplus=True,
878
+ ).unsqueeze(-1)
879
+ else:
880
+ scan_outputs, ssm_state = selective_scan_fn(
881
+ hidden_states,
882
+ discrete_time_step,
883
+ A,
884
+ B.transpose(1, 2),
885
+ C.transpose(1, 2),
886
+ self.D.float(),
887
+ gate,
888
+ time_proj_bias,
889
+ delta_softplus=True,
890
+ return_last_state=True,
891
+ )
892
+ if ssm_state is not None and cache_params is not None:
893
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
894
+
895
+ # 4. Final linear projection
896
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
897
+
898
+ return contextualized_states
899
+
900
+ # fmt: off
901
+ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
902
+ batch_size, seq_len, _ = input_states.shape
903
+ dtype = input_states.dtype
904
+ # 1. Gated MLP's linear projection
905
+ projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
906
+ hidden_states, gate = projected_states.chunk(2, dim=1)
907
+
908
+ use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
909
+ # 2. Convolution sequence transformation
910
+ if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
911
+ if self.training:
912
+ # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backward pass
913
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
914
+ else:
915
+ ssm_state = cache_params.ssm_states[self.layer_idx]
916
+
917
+ if cache_params.has_previous_state and seq_len == 1 and \
918
+ cache_params.conv_states[self.layer_idx].shape[0] == batch_size:
919
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
920
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
921
+ conv_state[:, :, -1] = hidden_states[:, :, 0]
922
+ cache_params.conv_states[self.layer_idx] = conv_state
923
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
924
+ if self.use_conv_bias:
925
+ hidden_states += self.conv1d.bias
926
+ hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
927
+ else:
928
+ conv_state = nn.functional.pad(
929
+ hidden_states,
930
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
931
+ )
932
+ cache_params.conv_states[self.layer_idx] = conv_state
933
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
934
+ else:
935
+ ssm_state = torch.zeros(
936
+ (batch_size, self.intermediate_size, self.ssm_state_size),
937
+ device=hidden_states.device, dtype=dtype
938
+ )
939
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
940
+
941
+ # 3. State Space Model sequence transformation
942
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
943
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
944
+ time_step, B, C = torch.split(
945
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
946
+ )
947
+
948
+ time_step = self.dt_layernorm(time_step)
949
+ B = self.b_layernorm(B)
950
+ C = self.c_layernorm(C)
951
+
952
+ discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
953
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
954
+
955
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
956
+ A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
957
+ discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
958
+ discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
959
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
960
+
961
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
962
+ scan_outputs = []
963
+ for i in range(seq_len):
964
+ ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state_size]
965
+ scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
966
+ scan_outputs.append(scan_output[:, :, 0])
967
+ scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
968
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
969
+ scan_output = (scan_output * self.act(gate))
970
+
971
+ if use_cache:
972
+ cache_params.ssm_states[self.layer_idx] = ssm_state
973
+
974
+ # 4. Final linear projection
975
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
976
+ return contextualized_states
977
+ # fmt: on
978
+
979
+ def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
980
+ if self.use_fast_kernels:
981
+ if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
982
+ raise ValueError(
983
+ "Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
984
+ )
985
+ return self.cuda_kernels_forward(hidden_states, cache_params)
986
+ return self.slow_forward(hidden_states, cache_params)
987
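The per-step recurrence in `slow_forward` (steps 3.b and 3.c) can be reproduced in isolation. A toy sketch following the same discretization, where every tensor is a random stand-in with made-up sizes (batch=1, d_inner=2, seq_len=3, d_state=4), not model weights:

```python
import torch

torch.manual_seed(0)
batch, d_inner, seq_len, d_state = 1, 2, 3, 4
x = torch.randn(batch, d_inner, seq_len)                                 # post-conv hidden states
dt = torch.nn.functional.softplus(torch.randn(batch, d_inner, seq_len))  # time step ∆
A = -torch.exp(torch.randn(d_inner, d_state))                            # negative-real A (S4D style)
B = torch.randn(batch, seq_len, d_state)
C = torch.randn(batch, seq_len, d_state)

# 3.b Discretization
dA = torch.exp(A[None, :, None, :] * dt[:, :, :, None])             # (batch, d_inner, seq_len, d_state)
dBx = dt[:, :, :, None] * B[:, None, :, :] * x[:, :, :, None]       # discrete B times input

# 3.c Sequential recurrence: state <- dA * state + dB * x; y_i = C_i . state
state = torch.zeros(batch, d_inner, d_state)
ys = []
for i in range(seq_len):
    state = dA[:, :, i, :] * state + dBx[:, :, i, :]
    ys.append((state * C[:, i, :][:, None, :]).sum(-1))
y = torch.stack(ys, dim=-1)
print(y.shape)  # torch.Size([1, 2, 3])
```

The fused CUDA kernels (`selective_scan_fn`, `selective_state_update`) compute the same recurrence without materializing the per-step states in HBM.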
+
988
+
989
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
990
+ class JambaMLP(nn.Module):
991
+ def __init__(self, config):
992
+ super().__init__()
993
+ self.config = config
994
+ self.hidden_size = config.hidden_size
995
+ self.intermediate_size = config.intermediate_size
996
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
997
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
998
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
999
+ self.act_fn = ACT2FN[config.hidden_act]
1000
+
1001
+ def forward(self, x):
1002
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
1003
+
1004
+
1005
+ # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mixtral->Jamba
1006
+ class JambaSparseMoeBlock(nn.Module):
1007
+ """
1008
+ This implementation is
1009
+ strictly equivalent to standard MoE with full capacity (no
1010
+ dropped tokens). It's faster since it formulates MoE operations
1011
+ in terms of block-sparse operations to accommodate imbalanced
1012
+ assignments of tokens to experts, whereas standard MoE either
1013
+ (1) drops tokens at the cost of reduced performance or (2) sets
1014
+ capacity factor to number of experts and thus waste computation
1015
+ and memory on padding.
1016
+ """
1017
+
1018
+ def __init__(self, config: JambaConfig):
1019
+ super().__init__()
1020
+ self.hidden_dim = config.hidden_size
1021
+ self.ffn_dim = config.intermediate_size
1022
+ self.num_experts = config.num_experts
1023
+ self.top_k = config.num_experts_per_tok
1024
+
1025
+ self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
1026
+ self.experts = nn.ModuleList([JambaMLP(config) for _ in range(self.num_experts)])
1027
+
1028
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1029
+ """ """
1030
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
1031
+
1032
+ hidden_states = hidden_states.view(-1, hidden_dim)
1033
+ # router_logits: (batch * sequence_length, n_experts)
1034
+ router_logits = self.router(hidden_states)
1035
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
1036
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
1037
+ # we cast back to the input dtype
1038
+ routing_weights = routing_weights.to(hidden_states.dtype)
1039
+
1040
+ final_hidden_states = torch.zeros(
1041
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
1042
+ )
1043
+
1044
+ # One hot encode the selected experts to create an expert mask
1045
+ # this will be used to easily index which expert is going to be solicited
1046
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
1047
+
1048
+ # Loop over all available experts in the model and perform the computation on each expert
1049
+ for expert_idx in range(self.num_experts):
1050
+ expert_layer = self.experts[expert_idx]
1051
+ idx, top_x = torch.where(expert_mask[expert_idx])
1052
+
1053
+ if top_x.shape[0] == 0:
1054
+ continue
1055
+
1056
+ # Index the correct hidden states and compute the expert hidden state for
1057
+ # the current expert. We need to make sure to multiply the output hidden
1058
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
1059
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
1060
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
1061
+
1062
+ # However `index_add_` only supports torch tensors for indexing so we'll use
1063
+ # the `top_x` tensor here.
1064
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
1065
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
1066
+ return final_hidden_states, router_logits
1067
+
1068
+
1069
+ class JambaAttentionDecoderLayer(nn.Module):
1070
+ def __init__(self, config: JambaConfig, layer_idx: int):
1071
+ super().__init__()
1072
+ num_experts = config.layers_num_experts[layer_idx]
1073
+ self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1074
+
1075
+ ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
1076
+ self.feed_forward = ffn_layer_class(config)
1077
+ self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1078
+ self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1079
+
1080
+ def forward(
1081
+ self,
1082
+ hidden_states: torch.Tensor,
1083
+ attention_mask: Optional[torch.Tensor] = None,
1084
+ position_ids: Optional[torch.LongTensor] = None,
1085
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1086
+ output_attentions: Optional[bool] = False,
1087
+ output_router_logits: Optional[bool] = False,
1088
+ use_cache: Optional[bool] = False,
1089
+ cache_position: Optional[torch.LongTensor] = None,
1090
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1091
+ """
1092
+ Args:
1093
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1094
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1095
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1096
+ past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
1097
+ output_attentions (`bool`, *optional*):
1098
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1099
+ returned tensors for more detail.
1100
+ output_router_logits (`bool`, *optional*):
1101
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1102
+ should not be returned during inference.
1103
+ use_cache (`bool`, *optional*):
1104
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1105
+ (see `past_key_values`).
1106
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1107
+ Indices depicting the position of the input sequence tokens in the sequence.
1108
+ """
1109
+
1110
+ residual = hidden_states
1111
+
1112
+ hidden_states = self.input_layernorm(hidden_states)
1113
+
1114
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1115
+ hidden_states=hidden_states,
1116
+ attention_mask=attention_mask,
1117
+ position_ids=position_ids,
1118
+ past_key_value=past_key_value,
1119
+ output_attentions=output_attentions,
1120
+ use_cache=use_cache,
1121
+ cache_position=cache_position,
1122
+ )
1123
+
1124
+ # residual connection after attention
1125
+ hidden_states = residual + hidden_states
1126
+
1127
+ # feed-forward (experts/MLP)
1128
+ residual = hidden_states
1129
+ hidden_states = self.pre_ff_layernorm(hidden_states)
1130
+ ff_outputs = self.feed_forward(hidden_states)
1131
+ if isinstance(ff_outputs, tuple):
1132
+ hidden_states, router_logits = ff_outputs
1133
+ else:
1134
+ hidden_states, router_logits = ff_outputs, None
1135
+ hidden_states = residual + hidden_states
1136
+
1137
+ outputs = (hidden_states,)
1138
+
1139
+ if output_attentions:
1140
+ outputs += (self_attn_weights,)
1141
+
1142
+ if use_cache:
1143
+ outputs += (present_key_value,)
1144
+
1145
+ if output_router_logits:
1146
+ outputs += (router_logits,)
1147
+
1148
+ return outputs
1149
+
1150
+
1151
+ class JambaMambaDecoderLayer(nn.Module):
1152
+ def __init__(self, config: JambaConfig, layer_idx: int):
1153
+ super().__init__()
1154
+ num_experts = config.layers_num_experts[layer_idx]
1155
+ self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)
1156
+
1157
+ ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
1158
+ self.feed_forward = ffn_layer_class(config)
1159
+ self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1160
+ self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1161
+
1162
+ def forward(
1163
+ self,
1164
+ hidden_states: torch.Tensor,
1165
+ attention_mask: Optional[torch.Tensor] = None,
1166
+ position_ids: Optional[torch.LongTensor] = None,
1167
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1168
+ output_attentions: Optional[bool] = False,
1169
+ output_router_logits: Optional[bool] = False,
1170
+ use_cache: Optional[bool] = False,
1171
+ cache_position: Optional[torch.LongTensor] = None,
1172
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1173
+ """
1174
+ Args:
1175
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1176
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1177
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1178
+ past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
1179
+ output_attentions (`bool`, *optional*):
1180
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1181
+ returned tensors for more detail.
1182
+ output_router_logits (`bool`, *optional*):
1183
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1184
+ should not be returned during inference.
1185
+ use_cache (`bool`, *optional*):
1186
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1187
+ (see `past_key_values`).
1188
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1189
+ Indices depicting the position of the input sequence tokens in the sequence.
1190
+ """
1191
+
1192
+ residual = hidden_states
1193
+
1194
+ hidden_states = self.input_layernorm(hidden_states)
1195
+
1196
+ hidden_states = self.mamba(
1197
+ hidden_states=hidden_states,
1198
+ cache_params=past_key_value,
1199
+ )
1200
+ self_attn_weights = None
1201
+
1202
+ # residual connection after mamba
1203
+ hidden_states = residual + hidden_states
1204
+
1205
+ # feed-forward (experts/MLP)
1206
+ residual = hidden_states
1207
+ hidden_states = self.pre_ff_layernorm(hidden_states)
1208
+ ff_outputs = self.feed_forward(hidden_states)
1209
+ if isinstance(ff_outputs, tuple):
1210
+ hidden_states, router_logits = ff_outputs
1211
+ else:
1212
+ hidden_states, router_logits = ff_outputs, None
1213
+ hidden_states = residual + hidden_states
1214
+
1215
+ outputs = (hidden_states,)
1216
+
1217
+ if output_attentions:
1218
+ outputs += (self_attn_weights,)
1219
+
1220
+ if use_cache:
1221
+ outputs += (past_key_value,)
1222
+
1223
+ if output_router_logits:
1224
+ outputs += (router_logits,)
1225
+
1226
+ return outputs
1227
+
1228
+
1229
+ JAMBA_START_DOCSTRING = r"""
1230
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1231
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1232
+ etc.)
1233
+
1234
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1235
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1236
+ and behavior.
1237
+
1238
+ Parameters:
1239
+ config ([`JambaConfig`]):
1240
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1241
+ load the weights associated with the model, only the configuration. Check out the
1242
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1243
+ """
1244
+
1245
+
1246
+ @add_start_docstrings(
1247
+ "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
1248
+ JAMBA_START_DOCSTRING,
1249
+ )
1250
+ class JambaPreTrainedModel(PreTrainedModel):
1251
+ config_class = JambaConfig
1252
+ base_model_prefix = "model"
1253
+ supports_gradient_checkpointing = True
1254
+ _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
1255
+ _skip_keys_device_placement = "past_key_values"
1256
+ _supports_flash_attn_2 = True
1257
+ _supports_sdpa = True
1258
+ _supports_cache_class = True
1259
+
1260
+ def _init_weights(self, module):
1261
+ std = self.config.initializer_range
1262
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
1263
+ module.weight.data.normal_(mean=0.0, std=std)
1264
+ if module.bias is not None:
1265
+ module.bias.data.zero_()
1266
+ elif isinstance(module, nn.Embedding):
1267
+ module.weight.data.normal_(mean=0.0, std=std)
1268
+ if module.padding_idx is not None:
1269
+ module.weight.data[module.padding_idx].zero_()
1270
+
1271
+
1272
+ JAMBA_INPUTS_DOCSTRING = r"""
1273
+ Args:
1274
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1275
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1276
+ it.
1277
+
1278
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1279
+ [`PreTrainedTokenizer.__call__`] for details.
1280
+
1281
+ [What are input IDs?](../glossary#input-ids)
1282
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1283
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1284
+
1285
+ - 1 for tokens that are **not masked**,
1286
+ - 0 for tokens that are **masked**.
1287
+
1288
+ [What are attention masks?](../glossary#attention-mask)
1289
+
1290
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1291
+ [`PreTrainedTokenizer.__call__`] for details.
1292
+
1293
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1294
+ `past_key_values`).
1295
+
1296
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1297
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1298
+ information on the default strategy.
1299
+
1300
+ - 1 indicates the head is **not masked**,
1301
+ - 0 indicates the head is **masked**.
1302
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1303
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1304
+ config.n_positions - 1]`.
1305
+
1306
+ [What are position IDs?](../glossary#position-ids)
1307
+ past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1308
+ A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
1309
+ self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
1310
+ `past_key_values` input) to speed up sequential decoding.
1311
+ Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
1312
+ Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
1313
+ `(batch_size, d_inner, d_state)` respectively.
1314
+ See the `HybridMambaAttentionDynamicCache` class for more details.
1315
+
1316
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
1317
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1318
+ `input_ids` of shape `(batch_size, sequence_length)`.
1319
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1320
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1321
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1322
+ model's internal embedding lookup matrix.
1323
+ use_cache (`bool`, *optional*):
1324
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1325
+ `past_key_values`).
1326
+ output_attentions (`bool`, *optional*):
1327
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1328
+ tensors for more detail.
1329
+ output_hidden_states (`bool`, *optional*):
1330
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1331
+ more detail.
1332
+ output_router_logits (`bool`, *optional*):
1333
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1334
+ should not be returned during inference.
1335
+ return_dict (`bool`, *optional*):
1336
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1337
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1338
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1339
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1340
+ the complete sequence length.
1341
+ """
1342
+
1343
+ ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
1344
+
1345
+
1346
+ @add_start_docstrings(
+     "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
+     JAMBA_START_DOCSTRING,
+ )
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->JAMBA, Mistral->Jamba
+ class JambaModel(JambaPreTrainedModel):
+     """
+     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`JambaDecoderLayer`].
+ 
+     Args:
+         config: JambaConfig
+     """
+ 
+     def __init__(self, config: JambaConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+ 
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         decoder_layers = []
+         for i in range(config.num_hidden_layers):
+             layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
+             decoder_layers.append(layer_class(config, layer_idx=i))
+         self.layers = nn.ModuleList(decoder_layers)
+ 
+         self._attn_implementation = config._attn_implementation
+         self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ 
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.embed_tokens
+ 
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+ 
+     @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_router_logits: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, MoeModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_router_logits = (
+             output_router_logits if output_router_logits is not None else self.config.output_router_logits
+         )
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+ 
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+             )
+ 
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+ 
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+         hidden_states = inputs_embeds
+ 
+         if use_cache and past_key_values is None:
+             logger.warning_once(
+                 "Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
+                 "provided, so no cache will be returned."
+             )
+ 
+         if cache_position is None:
+             cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+ 
+         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+ 
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         all_router_logits = () if output_router_logits else None
+ 
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+ 
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     output_router_logits,
+                     use_cache,
+                     cache_position,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     output_router_logits=output_router_logits,
+                     use_cache=use_cache,
+                     cache_position=cache_position,
+                 )
+ 
+             hidden_states = layer_outputs[0]
+ 
+             if output_attentions:
+                 if layer_outputs[1] is not None:
+                     # append attentions only of attention layers. Mamba layers return `None` as the attention weights
+                     all_self_attns += (layer_outputs[1],)
+ 
+             if output_router_logits:
+                 if layer_outputs[-1] is not None:
+                     # append router logits only of expert layers. Regular MLP layers return `None` as the router logits
+                     all_router_logits += (layer_outputs[-1],)
+ 
+         hidden_states = self.final_layernorm(hidden_states)
+ 
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+ 
+         if past_key_values and not past_key_values.has_previous_state:
+             past_key_values.has_previous_state = True
+ 
+         next_cache = None if not use_cache else past_key_values
+ 
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                 if v is not None
+             )
+         return MoeModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+             router_logits=all_router_logits,
+         )
+ 
+     def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and 0.0 in attention_mask:
+                 return attention_mask
+             return None
+ 
+         dtype, device = input_tensor.dtype, input_tensor.device
+         min_dtype = torch.finfo(dtype).min
+         sequence_length = input_tensor.shape[1]
+         target_length = cache_position[-1] + 1
+ 
+         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+         if sequence_length != 1:
+             causal_mask = torch.triu(causal_mask, diagonal=1)
+         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+         if attention_mask is not None:
+             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+             if attention_mask.dim() == 2:
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+                 causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+ 
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+ 
+         return causal_mask
+ 
+ 
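The additive mask built by `_update_causal_mask` can be pictured without torch: future positions receive a large negative value (standing in for `torch.finfo(dtype).min`) so they contribute nothing after softmax, and padded positions are masked out as well. A toy sketch with plain lists (names and shapes simplified; this is not the library routine):

```python
# Pure-Python sketch of an additive causal + padding mask.

NEG_INF = float("-inf")  # stand-in for min_dtype

def build_causal_mask(seq_len, padding_mask=None):
    """padding_mask: optional list of 0/1 per position (1 = real token)."""
    # upper-triangular part (strictly above the diagonal) is masked
    mask = [[0.0 if j <= i else NEG_INF for j in range(seq_len)] for i in range(seq_len)]
    if padding_mask is not None:
        for i in range(seq_len):
            for j in range(seq_len):
                if padding_mask[j] == 0:
                    mask[i][j] = NEG_INF  # padded key positions are never attended
    return mask

mask = build_causal_mask(4, padding_mask=[0, 1, 1, 1])  # left padding
print(mask[2])  # [-inf, 0.0, 0.0, -inf]
```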
+ # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
+ class JambaForCausalLM(JambaPreTrainedModel):
+     _tied_weights_keys = ["lm_head.weight"]
+ 
+     def __init__(self, config: JambaConfig):
+         super().__init__(config)
+         self.model = JambaModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.router_aux_loss_coef = config.router_aux_loss_coef
+         self.num_experts = config.num_experts
+         self.num_experts_per_tok = config.num_experts_per_tok
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+ 
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+ 
+     def get_output_embeddings(self):
+         return self.lm_head
+ 
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+ 
+     def set_decoder(self, decoder):
+         self.model = decoder
+ 
+     def get_decoder(self):
+         return self.model
+ 
+     @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+     # Ignore copy
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_router_logits: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         num_logits_to_keep: Optional[int] = None,
+     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+         r"""
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                 (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ 
+             num_logits_to_keep (`int` or `None`, *optional*):
+                 Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
+                 `input_ids`. Only the last token's logits are needed for generation, and calculating them only for
+                 that token can save memory, which becomes significant for long sequences.
+ 
+         Returns:
+ 
+         Example:
+ 
+         ```python
+         >>> from transformers import AutoTokenizer, JambaForCausalLM
+ 
+         >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
+         >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
+ 
+         >>> prompt = "Hey, are you conscious? Can you talk to me?"
+         >>> inputs = tokenizer(prompt, return_tensors="pt")
+ 
+         >>> # Generate
+         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+         ```"""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_router_logits = (
+             output_router_logits if output_router_logits is not None else self.config.output_router_logits
+         )
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             output_router_logits=output_router_logits,
+             cache_position=cache_position,
+             return_dict=return_dict,
+         )
+ 
+         hidden_states = outputs[0]
+         if num_logits_to_keep is None:
+             logits = self.lm_head(hidden_states)
+         else:
+             logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
+         logits = logits.float()
+ 
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+ 
+         aux_loss = None
+         if output_router_logits:
+             aux_loss = load_balancing_loss_func(
+                 outputs.router_logits if return_dict else outputs[-1],
+                 self.num_experts,
+                 self.num_experts_per_tok,
+                 attention_mask,
+             )
+             if labels is not None:
+                 loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+ 
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             if output_router_logits:
+                 output = (aux_loss,) + output
+             return (loss,) + output if loss is not None else output
+ 
+         return MoeCausalLMOutputWithPast(
+             loss=loss,
+             aux_loss=aux_loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             router_logits=outputs.router_logits,
+         )
+ 
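The shift-by-one alignment in the loss computation above (logits at position t scored against the label at position t+1, with `-100` labels ignored, as `CrossEntropyLoss` does by default) can be sketched with toy probabilities. The helper name and the values are illustrative only, not library code:

```python
# Toy next-token loss with the same shift-by-one alignment as the forward pass.
import math

def causal_lm_loss(probs, labels, ignore_index=-100):
    """probs: per-position distribution over the vocab; labels: token ids."""
    shift_probs = probs[:-1]   # positions 0..n-2 ...
    shift_labels = labels[1:]  # ... predict tokens 1..n-1
    losses = [
        -math.log(p[y])
        for p, y in zip(shift_probs, shift_labels)
        if y != ignore_index
    ]
    return sum(losses) / len(losses)

probs = [
    [0.1, 0.8, 0.1],  # position 0 predicts the token at position 1
    [0.7, 0.2, 0.1],  # position 1 predicts the token at position 2
    [0.3, 0.3, 0.4],  # last position has no target after shifting
]
labels = [-100, 1, 0]  # the first label is dropped by the shift anyway
print(causal_lm_loss(probs, labels))  # mean of -log(0.8) and -log(0.7)
```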
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         output_router_logits=False,
+         cache_position=None,
+         **kwargs,
+     ):
+         empty_past_kv = past_key_values is None
+ 
+         # Omit tokens covered by past_key_values
+         if not empty_past_kv:
+             past_length = cache_position[0] if cache_position is not None else attention_mask.shape[1]
+             max_cache_length = self.config.sliding_window
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+             #     input)
+             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             #     input_ids based on the past_length.
+             elif past_length < input_ids.shape[1]:
+                 input_ids = input_ids[:, past_length:]
+             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+ 
+             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+             if (
+                 max_cache_length is not None
+                 and attention_mask is not None
+                 and past_length + input_ids.shape[1] > max_cache_length
+             ):
+                 attention_mask = attention_mask[:, -max_cache_length:]
+         else:
+             past_key_values = HybridMambaAttentionDynamicCache(
+                 self.config, input_ids.shape[0], self.dtype, device=self.device
+             )
+ 
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if not empty_past_kv:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+ 
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and empty_past_kv:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+ 
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "output_router_logits": output_router_logits,
+                 "num_logits_to_keep": self.config.num_logits_to_keep,
+                 "cache_position": cache_position,
+             }
+         )
+         return model_inputs
+ 
+ 
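The three "keep only the unprocessed tokens" cases enumerated in the comments above reduce to simple slicing. A toy sketch with plain lists for a single sequence (the helper name is illustrative, not library code):

```python
# Sketch of the token-trimming cases in prepare_inputs_for_generation:
# given how many positions the cache already covers (past_length), pick
# the slice of input_ids that still needs to be fed to the model.

def trim_input_ids(input_ids, past_length, attention_mask=None):
    seq_len = len(input_ids)
    mask_len = len(attention_mask) if attention_mask is not None else seq_len
    if mask_len > seq_len:
        # case 1: part of the input lives only in the cache (e.g. inputs_embeds)
        return input_ids[-(mask_len - past_length):]
    if past_length < seq_len:
        # case 2: drop the prefix the cache has already processed
        return input_ids[past_length:]
    # case 3: assume input_ids already contains only unprocessed tokens
    return input_ids

print(trim_input_ids([11, 12, 13, 14], past_length=3))  # [14]
print(trim_input_ids([15], past_length=4))              # [15]
```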
+ @add_start_docstrings(
+     """
+     The Jamba Model with a sequence classification head on top (linear layer).
+ 
+     [`JambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+     (e.g. GPT-2) do.
+ 
+     Since it does classification on the last token, it needs to know the position of the last token. If a
+     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
+     If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess
+     the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value
+     in each row of the batch).
+     """,
+     JAMBA_START_DOCSTRING,
+ )
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->Jamba, MIXTRAL->JAMBA
+ class JambaForSequenceClassification(JambaPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = JambaModel(config)
+         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+ 
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+ 
+     @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.score(hidden_states)
+ 
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+ 
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+ 
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+ 
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+ 
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(pooled_logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(pooled_logits, labels)
+ 
+         if not return_dict:
+             output = (pooled_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
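The argmax-minus-one-modulo trick in the classification forward above (find the first pad token per row, step back one position, and wrap via modulo when a row has no padding at all) can be verified with plain lists. The helper is an illustrative stand-in for the tensor version, not library code:

```python
# Sketch of the last-non-padding-token pooling used for sequence classification.

def last_token_indices(input_ids, pad_token_id):
    indices = []
    for row in input_ids:
        # index of the first pad token; 0 when no pad is found (like argmax on eq)
        first_pad = row.index(pad_token_id) if pad_token_id in row else 0
        # step back one; modulo wraps 0 - 1 to the last position for unpadded rows
        indices.append((first_pad - 1) % len(row))
    return indices

batch = [
    [5, 6, 7, 0, 0],  # right padding with pad_token_id=0
    [5, 6, 7, 8, 9],  # no padding
]
print(last_token_indices(batch, pad_token_id=0))  # [2, 4]
```

Note that, as in the tensor version, the trick assumes right padding: a left-padded row whose first token is pad would also wrap to the last position.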
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<|startoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|pad|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<|unk|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02fd6530b8ede0eedd8e509fcab32da7b1dd04c8119f8498c787100f13112713
+ size 1124742
tokenizer_config.json ADDED
@@ -0,0 +1,47 @@
+ {
+   "add_bos_token": true,
+   "add_eos_token": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<|pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<|startoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<|unk|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<|startoftext|>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<|pad|>",
+   "spaces_between_special_tokens": false,
+   "tokenizer_class": "LlamaTokenizer",
+   "unk_token": "<|unk|>",
+   "use_default_system_prompt": false
+ }