| --- |
| library_name: transformers |
| license: apache-2.0 |
| tags: |
| - jamba |
| - mamba |
| - moe |
| --- |
| |
| This is the base version of the Jamba model. We’ve since released a better, instruct-tuned version, [Jamba-1.6-Mini](https://huggingface.co/ai21labs/AI21-Jamba-Mini-1.6). For even greater performance, check out the scaled-up [Jamba-1.6-Large](https://huggingface.co/ai21labs/AI21-Jamba-Large-1.6). |
|
|
| # Model Card for Jamba |
|
|
| Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It delivers throughput gains over traditional Transformer-based models, while outperforming or matching the leading models of its size class on most common benchmarks. |
|
|
| Jamba is the first production-scale Mamba implementation, which opens up interesting research and application opportunities. While this initial experimentation shows encouraging gains, we expect these to be further enhanced with future optimizations and explorations. |
|
|
| This model card is for the base version of Jamba. It’s a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length, and can fit up to 140K tokens on a single 80GB GPU. |
|
|
| For full details of this model please read the [white paper](https://arxiv.org/abs/2403.19887) and the [release blog post](https://www.ai21.com/blog/announcing-jamba). |
|
|
| ## Model Details |
|
|
| - **Developed by:** [AI21](https://www.ai21.com) |
| - **Model type:** Joint Attention and Mamba (Jamba) |
| - **License:** Apache 2.0 |
| - **Context length:** 256K |
| - **Knowledge cutoff date:** March 5, 2024 |
|
|
| ## Usage |
| ### Prerequisites |
| Jamba requires `transformers` version 4.39.0 or higher; version 4.40.0 or higher is recommended: |
| ```bash |
| pip install "transformers>=4.40.0" |
| ``` |
|
|
| In order to run optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`: |
| ```bash |
| pip install mamba-ssm "causal-conv1d>=1.2.0" |
| ``` |
| The model must also be on a CUDA device. |
|
|
| You can run the model without the optimized Mamba kernels, but it is **not** recommended, as it will result in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model. |
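
As a rough sketch of how that choice could be automated, the snippet below checks whether the optimized-kernel packages are importable and builds the matching keyword argument. The import names `mamba_ssm` and `causal_conv1d`, the helper `mamba_kernels_installed`, and the `load_kwargs` dictionary are assumptions made for illustration; the check does not verify that a CUDA device is present.

```python
from importlib.util import find_spec

# Assumed import names for the `mamba-ssm` and `causal-conv1d` pip packages.
KERNEL_PACKAGES = ("mamba_ssm", "causal_conv1d")


def mamba_kernels_installed() -> bool:
    """Return True only if every optimized-kernel package can be found."""
    return all(find_spec(name) is not None for name in KERNEL_PACKAGES)


# Hypothetical convenience dict: fall back to the slow path when kernels are missing.
load_kwargs = {"use_mamba_kernels": mamba_kernels_installed()}
print(load_kwargs)
```

The resulting dictionary could then be passed along when loading, for example `AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", **load_kwargs)`.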
|
|
| ### Run the model |
| ```python |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
| model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1") |
| tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1") |
| |
| input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"] |
| |
| outputs = model.generate(input_ids, max_new_tokens=216) |
| |
| print(tokenizer.batch_decode(outputs)) |
| # ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"] |
| ``` |
|
|
| Please note that if you're using `transformers<4.40.0`, `trust_remote_code=True` is required for running the new Jamba architecture. |
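
A minimal sketch of that version rule, assuming the version string follows the usual `major.minor.patch` layout. The helper name `needs_trust_remote_code` is hypothetical; in practice the string would typically come from `transformers.__version__`. The helper does not check the 4.39.0 minimum stated in the prerequisites.

```python
import re


def needs_trust_remote_code(transformers_version: str) -> bool:
    """Return True for transformers releases older than 4.40."""
    # Keep only the first two numeric components (major, minor).
    major, minor = (int(part) for part in re.findall(r"\d+", transformers_version)[:2])
    return (major, minor) < (4, 40)


print(needs_trust_remote_code("4.39.3"))  # True
print(needs_trust_remote_code("4.40.0"))  # False
```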
|
|
| <details> |
| <summary><strong>Loading the model in half precision</strong></summary> |
| |
| The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you need to specify `torch_dtype`: |
| |
| ```python |
| from transformers import AutoModelForCausalLM |
| import torch |
| model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", |
| torch_dtype=torch.bfloat16) # you can also use torch_dtype=torch.float16 |
| ``` |
|
|
| When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. To use it, the model must also be on a CUDA device. Since the model is too big to fit on a single 80GB GPU in this precision, you'll also need to parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index): |
| ```python |
| from transformers import AutoModelForCausalLM |
| import torch |
| model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", |
| torch_dtype=torch.bfloat16, |
| attn_implementation="flash_attention_2", |
| device_map="auto") |
| ``` |
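
The single-GPU limit can be sanity-checked with a back-of-the-envelope estimate. This is only a sketch: it counts weights alone (no activations, caches, or framework overhead), uses decimal gigabytes, and takes the 52B total parameter count from this card. The helper `weight_memory_gb` is a hypothetical name.

```python
TOTAL_PARAMS = 52_000_000_000  # total parameter count stated in this card
GPU_MEMORY_GB = 80             # capacity of the single GPU discussed above


def weight_memory_gb(num_params: int, bytes_per_param: int) -> float:
    """Weights-only footprint in decimal gigabytes."""
    return num_params * bytes_per_param / 1e9


# BF16 and FP16 both store each parameter in 2 bytes.
bf16_gb = weight_memory_gb(TOTAL_PARAMS, 2)
print(f"BF16 weights: {bf16_gb:.0f} GB, fits on one GPU: {bf16_gb <= GPU_MEMORY_GB}")
```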
|
|
| </details> |
| <details><summary><strong>Load the model in 8-bit</strong></summary> |
| |
| **Using 8-bit precision, it is possible to fit sequences of up to 140K tokens on a single 80GB GPU.** You can easily quantize the model to 8-bit using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index). To avoid degrading model quality, we recommend excluding the Mamba blocks from quantization: |
|
|
| ```python |
| import torch |
| from transformers import AutoModelForCausalLM, BitsAndBytesConfig |
| quantization_config = BitsAndBytesConfig(load_in_8bit=True, |
| llm_int8_skip_modules=["mamba"]) |
| model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", |
| torch_dtype=torch.bfloat16, |
| attn_implementation="flash_attention_2", |
| quantization_config=quantization_config) |
| ``` |
| </details> |
|
|
| ### Fine-tuning example |
| Jamba is a base model that can be fine-tuned for custom solutions (including chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library (requires ~120GB of GPU RAM, for example 2x A100 80GB): |
|
|
| ```python |
| import torch |
| from datasets import load_dataset |
| from trl import SFTTrainer, SFTConfig |
| from peft import LoraConfig |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| |
| tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1") |
| model = AutoModelForCausalLM.from_pretrained( |
| "ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16) |
| |
| lora_config = LoraConfig( |
| r=8, |
| target_modules=[ |
| "embed_tokens", |
| "x_proj", "in_proj", "out_proj", # mamba |
| "gate_proj", "up_proj", "down_proj", # mlp |
| "q_proj", "k_proj", "v_proj" # attention |
| ], |
| task_type="CAUSAL_LM", |
| bias="none" |
| ) |
| |
| dataset = load_dataset("Abirate/english_quotes", split="train") |
| training_args = SFTConfig( |
| output_dir="./results", |
| num_train_epochs=2, |
| per_device_train_batch_size=4, |
| logging_dir='./logs', |
| logging_steps=10, |
| learning_rate=1e-5, |
| dataset_text_field="quote", |
| ) |
| trainer = SFTTrainer( |
| model=model, |
| tokenizer=tokenizer, |
| args=training_args, |
| peft_config=lora_config, |
| train_dataset=dataset, |
| ) |
| trainer.train() |
| ``` |
|
|
| ## Results on common benchmarks |
| | Benchmark | Score | |
| |--------------|:-----:| |
| | HellaSwag | 87.1% | |
| | ARC Challenge | 64.4% | |
| | WinoGrande | 82.5% | |
| | PIQA | 83.2% | |
| | MMLU | 67.4% | |
| | BBH | 45.4% | |
| | TruthfulQA | 46.4% | |
| | GSM8K (CoT) | 59.9% | |
|
|
| It's crucial that the 'BOS' token is added to all prompts. Not all eval frameworks enable this by default. |
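
One way to guard against a missing BOS token is to check the encoded prompt before evaluation. The sketch below works on plain lists of token ids; the helper `ensure_bos` and the example ids are hypothetical, and in practice the BOS id would come from `tokenizer.bos_token_id`.

```python
from typing import List


def ensure_bos(input_ids: List[int], bos_token_id: int) -> List[int]:
    """Return input_ids with bos_token_id prepended unless it already leads."""
    if input_ids and input_ids[0] == bos_token_id:
        return list(input_ids)
    return [bos_token_id] + list(input_ids)


# Made-up ids for illustration only; real ids come from the tokenizer.
print(ensure_bos([17, 42], bos_token_id=1))     # [1, 17, 42]
print(ensure_bos([1, 17, 42], bos_token_id=1))  # [1, 17, 42]
```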
|
|
|
|
| ## Notice |
| Jamba is a pretrained base model and did not undergo any alignment for instruct/chat interactions. |
|
|
| As a base model, Jamba is intended for use as a foundation layer for fine-tuning, training, and developing custom solutions. Jamba does not have safety moderation mechanisms, and guardrails should be added for responsible and safe use. |
|
|
| ## About AI21 |
| AI21 builds reliable, practical, and scalable AI solutions for the enterprise. |
|
|
| Jamba is the first in AI21’s new family of models, and the Instruct version of Jamba is coming soon to the [AI21 platform](https://www.ai21.com/studio). |
|
|