<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Mamba[[mamba]]

## Overview[[overview]]
The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.

This model is a new paradigm architecture based on `state-space-models`. For an intuitive introduction to these models, see [the annotated S4](https://srush.github.io/annotated-s4/).

The abstract from the paper is:

*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.*

Tips:

- Mamba is a new `state space model` architecture that rivals the classic Transformers. It builds on the line of progress in structured state space models, with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
- Mamba stacks `mixer` layers, which are the equivalent of `Attention` layers. The core logic of `mamba` is held in the `MambaMixer` class.
- Two implementations cohabit: one is optimized and uses fast CUDA kernels, while the other is naive but can run on any device!
- The current implementation leverages the original CUDA kernels: the equivalent of flash attention for Mamba is hosted in the [`mamba-ssm`](https://github.com/state-spaces/mamba) and [`causal_conv1d`](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports them (see the availability check sketched after this list)!
- Rather than further optimizing the CUDA kernels, contributions that speed up the naive implementation, which is simple but runs on any device, are especially welcome 🤗
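
The following is a minimal sketch, not part of the Transformers API, that checks whether the optional fast-path packages are installed; the package names used here are the import names of the two repositories above.

```python
import importlib.util

# Check for the optional packages that provide the optimized CUDA kernels.
# If either one is missing, the slower but device-agnostic naive
# implementation is used instead.
for pkg in ("mamba_ssm", "causal_conv1d"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'installed' if found else 'not installed (naive path will be used)'}")
```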

This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
The original code can be found [here](https://github.com/state-spaces/mamba).

## Usage

### A simple generation example
```python
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
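
Because the paper's headline result is fast inference that scales linearly with sequence length, it can be instructive to time generation directly. The following is a rough, hardware-dependent sketch (not a rigorous benchmark), reusing the checkpoint from the example above:

```python
import time

import torch
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
model.eval()

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

# Time a single generation pass; throughput varies heavily with hardware
# and with whether the optimized CUDA kernels are installed.
start = time.perf_counter()
with torch.no_grad():
    out = model.generate(input_ids, max_new_tokens=50)
elapsed = time.perf_counter() - start

new_tokens = out.shape[1] - input_ids.shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/sec")
```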

### PEFT fine-tuning

The slow version is not very stable for training, and the fast one needs `float32`!
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir="./logs",
    logging_steps=10,
    learning_rate=2e-3,
)
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
```
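
To be explicit about the dtype requirement noted above, the weights can also be loaded in `float32` directly; a minimal sketch:

```python
import torch
from transformers import AutoModelForCausalLM

# Load the weights explicitly in float32, since the optimized kernels
# need float32 for training (see the note above).
model = AutoModelForCausalLM.from_pretrained(
    "state-spaces/mamba-130m-hf", torch_dtype=torch.float32
)
```
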
## MambaConfig
[[autodoc]] MambaConfig
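
As a usage note, a randomly initialized model can be built from a custom configuration. This is a minimal sketch; the hyperparameter values below are illustrative, not the pretrained defaults.

```python
from transformers import MambaConfig, MambaModel

# Small, illustrative configuration (not the pretrained defaults).
config = MambaConfig(vocab_size=50280, hidden_size=256, num_hidden_layers=4)
model = MambaModel(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```
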
## MambaModel
[[autodoc]] MambaModel
    - forward
## MambaForCausalLM
[[autodoc]] MambaForCausalLM
    - forward