<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Mamba 2[[mamba-2]]

## Overview[[overview]]

The Mamba2 model was proposed in [Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality](https://huggingface.co/papers/2405.21060) by Tri Dao and Albert Gu. It is a state space model similar to Mamba 1, with better performance in a simplified architecture.

The abstract from the paper is the following:

*While Transformers have been the main architecture behind deep learning's success in language modeling, state-space models (SSMs) such as Mamba have recently been shown to match or outperform Transformers at small to medium scale. We show that these families of models are actually quite closely related, and develop a rich framework of theoretical connections between SSMs and variants of attention, connected through various decompositions of a well-studied class of structured semiseparable matrices. Our state space duality (SSD) framework allows us to design a new architecture (Mamba-2) whose core layer is a refinement of Mamba's selective SSM that is 2-8X faster, while continuing to be competitive with Transformers on language modeling.*

Tips:

This version should support all implementations of Mamba 2, and in particular [Mamba-2 codestral](https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1) from Mistral AI. Notably, Mamba 2 codestral was released with 8 `groups`, which can be thought of as similar to the number of KV heads in an attention-based model.
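
To make the KV-head analogy concrete, here is a minimal sketch of how heads could map onto groups. It assumes consecutive heads share one group, the way several query heads share one KV head in grouped-query attention. The helper `group_for_head` and the head count of 128 are illustrative assumptions; only the 8 groups come from the tip above.

```python
def group_for_head(head: int, num_heads: int, n_groups: int) -> int:
    """Return the index of the group that a given head would share (illustrative)."""
    if num_heads % n_groups != 0:
        raise ValueError("num_heads must be divisible by n_groups")
    heads_per_group = num_heads // n_groups
    return head // heads_per_group


num_heads = 128  # assumed value, for illustration only
n_groups = 8     # the number of groups mentioned in the tip

# One group index per head: heads 0..15 -> group 0, heads 16..31 -> group 1, ...
assignment = [group_for_head(h, num_heads, n_groups) for h in range(num_heads)]
print(assignment[:20])
```

Under these assumptions each group serves 16 heads, so the group-level tensors are 16 times smaller than they would be with one group per head.
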
This model has two different forward passes, `torch_forward` and `cuda_kernels_forward`. `cuda_kernels_forward` uses the original CUDA kernels if they are found in your environment, and it is slower on the prefill: because of high CPU overhead, it requires a "warmup run". See [here](https://github.com/state-spaces/mamba/issues/389#issuecomment-2171755306) and [here](https://github.com/state-spaces/mamba/issues/355#issuecomment-2147597457) for details.
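
The choice between the two paths can be pictured with a small sketch. This is not the library's dispatch code: `pick_forward_path` is a hypothetical helper, and the package names `mamba_ssm` and `causal_conv1d` are assumptions about which kernel packages would need to be installed.

```python
import importlib.util


def pick_forward_path(on_cuda: bool) -> str:
    """Hypothetical helper: name the forward path a caller could expect."""
    # Assumption: the fused kernels live in these two optional packages.
    kernels_found = all(
        importlib.util.find_spec(pkg) is not None
        for pkg in ("mamba_ssm", "causal_conv1d")
    )
    # The kernel path only makes sense when the kernels exist and we run on a GPU.
    if kernels_found and on_cuda:
        return "cuda_kernels_forward"
    return "torch_forward"


print(pick_forward_path(on_cuda=False))
```

When the kernel path is selected, running one short throwaway generation first is a simple way to absorb the warmup cost before timing anything.
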
Without compilation, the `torch_forward` implementation is faster by a factor of 3 to 4. The model has no positional embeddings, but it does take an `attention_mask`, and for batched generation there is specific logic that masks out hidden states in two places. See [here](https://github.com/state-spaces/mamba/issues/66#issuecomment-1863563829) for details.
Because of this, and because the Mamba2 kernels were reimplemented, batched generation and cached generation are expected to differ slightly. The results from the CUDA kernels and from the torch forward are also expected to differ slightly. The SSM algorithm relies heavily on tensor contractions, which have matmul equivalents but a slightly different order of operations, so the difference grows at lower precisions.
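
The effect of operation order at low precision can be reproduced without any model. The sketch below is a toy illustration, not the SSM computation itself: it sums the same three numbers in two orders, once in double precision and once with every intermediate result rounded to half precision. The helpers `to_fp16` and `sum_in_order` are names made up for this example.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def keep_fp64(x: float) -> float:
    """Leave the value in double precision."""
    return x


def sum_in_order(values, round_fn) -> float:
    """Accumulate left to right, rounding after every addition."""
    total = 0.0
    for v in values:
        total = round_fn(total + v)
    return total


big_first = [2048.0, 0.75, 0.75]
small_first = [0.75, 0.75, 2048.0]

# Double precision: both orders agree.
print(sum_in_order(big_first, keep_fp64), sum_in_order(small_first, keep_fp64))
# Half precision: the small terms are lost when they are added to 2048 one at a time.
print(sum_in_order(big_first, to_fp16), sum_in_order(small_first, to_fp16))
```

Near 2048 the half-precision grid has a spacing of 2, so adding 0.75 twice in a row changes nothing, while adding the pre-summed 1.5 rounds up to 2050. The same mechanism, spread over many contractions, is why two mathematically equivalent implementations can drift apart.
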
As another note, the shutdown of hidden states corresponding to padding tokens is done in two places and has mostly been tested with left padding. Right padding propagates noise down the line and is not guaranteed to yield satisfactory results. Setting `tokenizer.padding_side = "left"` ensures you are using the correct padding side.
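
A toy recurrence shows why the padding side matters for a stateful model. The sketch below is a scalar stand-in, not the Mamba2 layer: it assumes a fixed decay of 0.5 and a mask that zeroes the input at padded positions. The function `final_state` and all values are invented for illustration.

```python
def final_state(inputs, mask, decay=0.5):
    """Run h_t = decay * h_{t-1} + mask_t * x_t and return the last state."""
    h = 0.0
    for x, m in zip(inputs, mask):
        h = decay * h + m * x
    return h


tokens = [1.0, 2.0, 3.0]
pad = 9.0  # arbitrary value standing in for a padding token

reference = final_state(tokens, [1, 1, 1])
left_padded = final_state([pad, pad] + tokens, [0, 0, 1, 1, 1])
right_padded = final_state(tokens + [pad, pad], [1, 1, 1, 0, 0])

print(reference, left_padded, right_padded)
```

With left padding the masked positions act on an empty state, so the final state matches the unpadded run. With right padding the state keeps decaying over the padded steps even though their inputs are masked, so whatever is generated next starts from a different state.
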

This model was contributed by [Molbap](https://huggingface.co/Molbap), with tremendous help from [Anton Vlasjuk](https://github.com/vasqu).
The original code can be found [here](https://github.com/state-spaces/mamba).

# Usage

### A simple generation example:

```python
from transformers import AutoTokenizer, Mamba2ForCausalLM

model_id = 'mistralai/Mamba-Codestral-7B-v0.1'

# Load the tokenizer and model from the same revision of the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')

# Tokenize a single prompt and generate up to 10 new tokens.
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```

Here's a draft script for finetuning:

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

model_id = "mistralai/Mamba-Codestral-7B-v0.1"

# Small text dataset; the "quote" column holds the training text.
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(dataset_text_field="quote", gradient_checkpointing=True, per_device_train_batch_size=4)

# Attach LoRA adapters to the listed modules.
lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"])

# Passing the model id as a string lets the trainer load the model itself.
trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=dataset,
    peft_config=lora_config,
)
trainer.train()
```

## Mamba2Config

[[autodoc]] Mamba2Config

## Mamba2Model

[[autodoc]] Mamba2Model
    - forward

## Mamba2ForCausalLM

[[autodoc]] Mamba2ForCausalLM
    - forward