| <!--Copyright 2024 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| ⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| <div style="float: right;"> | |
| <div class="flex flex-wrap space-x-1"> | |
| <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white"> | |
| <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat"> | |
| <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white"> | |
| </div> | |
| </div> | |
| # Jamba[[jamba]] | |
| [Jamba](https://huggingface.co/papers/2403.19887)๋ Transformer์ Mamba ๊ธฐ๋ฐ์ ํ์ด๋ธ๋ฆฌ๋ ์ ๋ฌธ๊ฐ ํผํฉ(MoE) ์ธ์ด ๋ชจ๋ธ๋ก, ์ด ๋งค๊ฐ๋ณ์ ์๋ 52B์์ 398B๊น์ง ๋ค์ํฉ๋๋ค. ์ด ๋ชจ๋ธ์ Transformer ๋ชจ๋ธ์ ์ฑ๋ฅ๊ณผ Mamba์ ๊ฐ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ํจ์จ์ฑ ๋ฐ ๊ธด ์ปจํ ์คํธ ์ฒ๋ฆฌ ๋ฅ๋ ฅ(256K ํ ํฐ)์ ๋ชจ๋ ํ์ฉํ๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค. | |
| Jamba์ ์ํคํ ์ฒ๋ ๋ธ๋ก๊ณผ ๋ ์ด์ด ๊ธฐ๋ฐ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ์ฌ Transformer์ Mamba ์ํคํ ์ฒ๋ฅผ ํตํฉํ ์ ์๋๋ก ์ค๊ณ๋์์ต๋๋ค. ๊ฐ Jamba ๋ธ๋ก์ ์ดํ ์ ๋ ์ด์ด ๋๋ Mamba ๋ ์ด์ด ์ค ํ๋์ ๊ทธ ๋ค๋ฅผ ์๋ ๋ค์ธต ํผ์ ํธ๋ก (MLP)์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. Transformer ๋ ์ด์ด๋ 8๊ฐ์ ๋ ์ด์ด ์ค ํ๋์ ๋น์จ๋ก ์ฃผ๊ธฐ์ ์ผ๋ก ๋ฐฐ์น๋ฉ๋๋ค. ๋ํ ๋ชจ๋ธ ์ฉ๋์ ํ์ฅํ๊ธฐ ์ํด MoE ๋ ์ด์ด๊ฐ ํผํฉ๋์ด ์์ต๋๋ค. | |
| 모든 원본 Jamba 체크포인트는 [AI21](https://huggingface.co/ai21labs) 조직에서 확인할 수 있습니다. | |
| > [!TIP] | |
| > ์ค๋ฅธ์ชฝ ์ฌ์ด๋๋ฐ์ ์๋ Jamba ๋ชจ๋ธ์ ๋๋ฅด๋ฉด ๋ค์ํ ์ธ์ด ์์ ์ Jamba๋ฅผ ์ ์ฉํ๋ ์์ ๋ฅผ ๋ ํ์ธํ ์ ์์ต๋๋ค. | |
| 아래 예제는 [`Pipeline`]과 [`AutoModel`], 그리고 커맨드라인을 통해 텍스트를 생성하는 방법을 보여줍니다. | |
| <hfoptions id="usage"> | |
| <hfoption id="Pipeline"> | |
| ```py | |
| # 최적화된 Mamba 구현 설치 | |
| # !pip install mamba-ssm "causal-conv1d>=1.2.0" | |
| import torch | |
| from transformers import pipeline | |
| pipeline = pipeline( | |
| task="text-generation", | |
| model="ai21labs/AI21-Jamba-Mini-1.6", | |
| dtype=torch.float16, | |
| device=0 | |
| ) | |
| pipeline("Plants create energy through a process known as") | |
| ``` | |
| </hfoption> | |
| <hfoption id="AutoModel"> | |
| ```py | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| "ai21labs/AI21-Jamba-Large-1.6", | |
| ) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| "ai21labs/AI21-Jamba-Large-1.6", | |
| dtype=torch.float16, | |
| device_map="auto", | |
| attn_implementation="sdpa" | |
| ) | |
| input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda") | |
| output = model.generate(**input_ids, cache_implementation="static") | |
| print(tokenizer.decode(output[0], skip_special_tokens=True)) | |
| ``` | |
| </hfoption> | |
| <hfoption id="transformers CLI"> | |
| ```bash | |
| echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model ai21labs/AI21-Jamba-Mini-1.6 --device 0 | |
| ``` | |
| </hfoption> | |
| </hfoptions> | |
| 양자화는 가중치를 더 낮은 정밀도로 표현하여 대규모 모델의 메모리 부담을 줄여줍니다. 사용할 수 있는 다양한 양자화 백엔드에 대해서는 [Quantization](../quantization/overview)를 참고하세요. | |
| 아래 예제는 [bitsandbytes](../quantization/bitsandbytes)를 사용하여 가중치만 8비트로 양자화하는 방법을 보여줍니다. | |
| ```py | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| quantization_config = BitsAndBytesConfig(load_in_8bit=True, | |
| llm_int8_skip_modules=["mamba"]) | |
| # ๋ชจ๋ธ์ 8๊ฐ์ GPU์ ๊ณ ๋ฅด๊ฒ ๋ถ์ฐ์ํค๊ธฐ ์ํ ๋๋ฐ์ด์ค ๋งต | |
| device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 2, 'model.layers.26': 2, 'model.layers.27': 3, 'model.layers.28': 3, 'model.layers.29': 3, 'model.layers.30': 3, 'model.layers.31': 3, 'model.layers.32': 3, 'model.layers.33': 3, 'model.layers.34': 3, 'model.layers.35': 3, 'model.layers.36': 4, 'model.layers.37': 4, 'model.layers.38': 4, 'model.layers.39': 4, 'model.layers.40': 4, 'model.layers.41': 4, 'model.layers.42': 4, 'model.layers.43': 4, 'model.layers.44': 4, 'model.layers.45': 5, 'model.layers.46': 5, 'model.layers.47': 5, 'model.layers.48': 5, 'model.layers.49': 5, 'model.layers.50': 5, 'model.layers.51': 5, 'model.layers.52': 5, 'model.layers.53': 5, 'model.layers.54': 6, 'model.layers.55': 6, 'model.layers.56': 6, 'model.layers.57': 6, 'model.layers.58': 6, 'model.layers.59': 6, 'model.layers.60': 6, 'model.layers.61': 6, 'model.layers.62': 6, 'model.layers.63': 7, 'model.layers.64': 7, 'model.layers.65': 7, 'model.layers.66': 7, 'model.layers.67': 7, 'model.layers.68': 7, 'model.layers.69': 7, 'model.layers.70': 7, 'model.layers.71': 7, 'model.final_layernorm': 7, 'lm_head': 7} | |
| model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-Large-1.6", | |
| dtype=torch.bfloat16, | |
| attn_implementation="flash_attention_2", | |
| quantization_config=quantization_config, | |
| device_map=device_map) | |
| tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Large-1.6") | |
| messages = [ | |
| {"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."}, | |
| {"role": "user", "content": "Hello!"}, | |
| ] | |
| input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to(model.device) | |
| outputs = model.generate(input_ids, max_new_tokens=216) | |
| # 출력 디코딩 | |
| conversation = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # 어시스턴트의 응답만 추출 | |
| assistant_response = conversation.split(messages[-1]['content'])[1].strip() | |
| print(assistant_response) | |
| # 출력: Seek and you shall find. The path is winding, but the journey is enlightening. What wisdom do you seek from the ancient echoes? | |
| ``` | |
| ## 참고[[notes]] | |
| - 모델 성능 저하를 방지하기 위해 Mamba 블록은 양자화하지 마세요. | |
| - ์ต์ ํ๋ Mamba ์ปค๋ ์์ด Mamba๋ฅผ ์ฌ์ฉํ๋ฉด ์ง์ฐ ์๊ฐ์ด ํฌ๊ฒ ์ฆ๊ฐํ๋ฏ๋ก ๊ถ์ฅ๋์ง ์์ต๋๋ค. ๊ทธ๋๋ ์ปค๋ ์์ด Mamba๋ฅผ ์ฌ์ฉํ๊ณ ์ ํ๋ค๋ฉด [`~AutoModel.from_pretrained`]์์ `use_mamba_kernels=False`๋ก ์ค์ ํ์ธ์. | |
| ```py | |
| import torch | |
| from transformers import AutoModelForCausalLM | |
| model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Large", | |
| use_mamba_kernels=False) | |
| ``` | |
| ## JambaConfig[[transformers.JambaConfig]] | |
| [[autodoc]] JambaConfig | |
| ## JambaModel[[transformers.JambaModel]] | |
| [[autodoc]] JambaModel | |
| - forward | |
| ## JambaForCausalLM[[transformers.JambaForCausalLM]] | |
| [[autodoc]] JambaForCausalLM | |
| - forward | |
| ## JambaForSequenceClassification[[transformers.JambaForSequenceClassification]] | |
| [[autodoc]] transformers.JambaForSequenceClassification | |
| - forward | |