<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Mamba 2[[mamba-2]]
## Overview[[overview]]
The Mamba 2 model was proposed in [Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality](https://arxiv.org/abs/2405.21060) by Tri Dao and Albert Gu. It is a state space model similar to Mamba 1, with a simplified architecture and better performance.
The abstract from the paper is the following:
*While Transformers have been the main architecture behind deep learning's success in language modeling, state-space models (SSMs) such as Mamba have recently been shown to match or outperform Transformers at small to medium scale. We show that these families of models are actually quite closely related, and develop a rich framework of theoretical connections between SSMs and variants of attention, connected through various decompositions of a well-studied class of structured semiseparable matrices. Our state space duality (SSD) framework allows us to design a new architecture that refines Mamba's selective SSM and is 2-8x faster, while remaining competitive with Transformers.*
Tips:
This version should support all implementations of Mamba 2, and in particular [Mamba-2 codestral](https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1) from Mistral AI. In particular, Mamba 2 codestral was released with 8 `groups`, which can be thought of intuitively as analogous to the number of KV heads in an attention-based model.
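As a minimal illustration, a config-level sketch (assuming the number of groups is exposed through the `n_groups` argument of `Mamba2Config`):

```python
from transformers import Mamba2Config

# Sketch only: `n_groups` is assumed to be the config field controlling the
# number of groups; Mamba-2 codestral ships with 8 of them, loosely analogous
# to the number of KV heads in an attention-based model.
config = Mamba2Config(n_groups=8)
print(config.n_groups)
```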
This model has two different forward passes, `torch_forward` and `cuda_kernels_forward`. `cuda_kernels_forward` uses the original CUDA kernels if they are found in your environment, and is slower on the prefill, i.e. it requires a "warmup run" due to the high CPU overhead; see [here](https://github.com/state-spaces/mamba/issues/389#issuecomment-2171755306) and [here](https://github.com/state-spaces/mamba/issues/355#issuecomment-2147597457).
Without compilation, the `torch_forward` implementation is faster by a factor of 3 to 4. Further, this model has no positional embeddings, but it does use an `attention_mask` and specific logic to mask out hidden states in two places in the case of batched generation; see [here](https://github.com/state-spaces/mamba/issues/66#issuecomment-1863563829) as well.
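As a rough way to see which path your environment will take, you can check whether the kernel packages are importable. This is only a sketch: the package names `mamba_ssm` and `causal_conv1d` come from the original kernel repositories, and the actual dispatch logic lives inside the model.

```python
import importlib.util

def mamba2_cuda_kernels_available() -> bool:
    # Heuristic: the CUDA kernel path can only be taken when the original
    # kernel packages are installed; otherwise the pure PyTorch path is used.
    return all(
        importlib.util.find_spec(pkg) is not None
        for pkg in ("mamba_ssm", "causal_conv1d")
    )

print(mamba2_cuda_kernels_available())
```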
Because of this, in addition to the reimplementation of the Mamba 2 kernels, batched generation and cached generation are expected to show slight discrepancies. Moreover, the results given by the CUDA kernels and the torch forward are expected to differ slightly. The SSM algorithm relies heavily on tensor contractions, which have matmul equivalents but perform the operations in a slightly different order, and this difference becomes larger at lower precisions.
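A toy illustration of why the order of operations matters more at lower precision (this is plain matrix associativity, not the actual SSM contraction):

```python
import torch

torch.manual_seed(0)
A, B, C = (torch.randn(64, 64, dtype=torch.float64) for _ in range(3))

# (A @ B) @ C and A @ (B @ C) are mathematically identical, but the
# intermediate roundings differ, and the gap grows as precision drops.
for dtype in (torch.float64, torch.float32):
    a, b, c = A.to(dtype), B.to(dtype), C.to(dtype)
    diff = ((a @ b) @ c - a @ (b @ c)).abs().max().item()
    print(dtype, diff)
```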
As another note, the shutdown of hidden states corresponding to padding tokens is done in two places and has mostly been tested with left padding. Right padding propagates noise down the sequence and is not guaranteed to yield satisfactory results. Setting `tokenizer.padding_side = "left"` ensures you are using the correct padding side.
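A minimal batched-generation sketch with left padding, reusing the same checkpoint and tokenizer arguments as the generation example below:

```python
from transformers import AutoTokenizer, Mamba2ForCausalLM

model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # right padding would propagate noise into the states

model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')

prompts = ["Hey how are you doing?", "Write a haiku about state space models."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```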
This model was contributed by [Molbap](https://huggingface.co/Molbap), with tremendous help from [Anton Vlasjuk](https://github.com/vasqu).
The original code can be found [here](https://github.com/state-spaces/mamba).
# Usage
### A simple generation example:
```python
from transformers import Mamba2Config, Mamba2ForCausalLM, AutoTokenizer
import torch
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
Here is a draft script for fine-tuning:
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # enforce left padding
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
dataset = load_dataset("Abirate/english_quotes", split="train")
# Without CUDA kernels, a batch size of 2 takes up one 80GB device,
# but precision can be reduced.
# Experiments and trials welcome!
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=2,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
```
## Mamba2Config
[[autodoc]] Mamba2Config
## Mamba2Model
[[autodoc]] Mamba2Model
- forward
## Mamba2ForCausalLM
[[autodoc]] Mamba2ForCausalLM
- forward