<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Mamba[[mamba]]

## Overview[[overview]]
The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://huggingface.co/papers/2312.00752) by Albert Gu and Tri Dao.

This model is a new paradigm architecture based on `state-space-models`. You can read more about the intuition behind it [here](https://srush.github.io/annotated-s4/).

The abstract from the paper is the following:

*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.*
Tips:

- Mamba is a new `state space model` architecture that rivals the classic Transformers. It builds on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
- Mamba stacks `mixer` layers, which are the equivalent of `attention` layers. The core logic of `mamba` is held in the `MambaMixer` class.
- Two implementations cohabit: one is optimized and uses fast CUDA kernels, while the other is naive but runs on any device!
- The current implementation leverages the original CUDA kernels: the equivalent of flash attention for Mamba is hosted in the [`mamba-ssm`](https://github.com/state-spaces/mamba) and [`causal_conv1d`](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports them!
- Contributions to speed up the naive implementation (the one that is simple but runs on any device, as opposed to the optimized CUDA kernels) are most welcome 🤗
This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
The original code can be found [here](https://github.com/state-spaces/mamba).
# Usage

### A simple generation example
```python
from transformers import MambaForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
### Peft finetuning

The slow version is not very stable for training, and the fast one needs `float32`!
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

model_id = "state-spaces/mamba-130m-hf"
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(dataset_text_field="quote")
lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"])
trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=dataset,
    peft_config=lora_config,
)
trainer.train()
```
## MambaConfig

[[autodoc]] MambaConfig

## MambaModel

[[autodoc]] MambaModel
    - forward

## MambaForCausalLM

[[autodoc]] MambaForCausalLM
    - forward