<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Mamba 2[[mamba-2]]

## Overview[[overview]]

λ§˜λ°”2 λͺ¨λΈμ€ Tri Dao, Albert Guκ°€ μ œμ•ˆν•œ [νŠΈλžœμŠ€ν¬λ¨ΈλŠ” SSM이닀: κ΅¬μ‘°ν™”λœ μƒνƒœ 곡간 이쀑성을 ν†΅ν•œ μΌλ°˜ν™”λœ λͺ¨λΈκ³Ό 효율적인 μ•Œκ³ λ¦¬μ¦˜](https://arxiv.org/abs/2405.21060)λΌλŠ” λ…Όλ¬Έμ—μ„œ μ†Œκ°œλ˜μ—ˆμŠ΅λ‹ˆλ‹€. λ§˜λ°”2λŠ” λ§˜λ°”1κ³Ό μœ μ‚¬ν•œ μƒνƒœ 곡간 λͺ¨λΈλ‘œ, λ‹¨μˆœν™”λœ μ•„ν‚€ν…μ²˜μ—μ„œ 더 λ‚˜μ€ μ„±λŠ₯을 λ³΄μž…λ‹ˆλ‹€.

The abstract from the paper is the following:

*νŠΈλžœμŠ€ν¬λ¨ΈλŠ” μ–Έμ–΄ λͺ¨λΈλ§μ—μ„œ λ”₯λŸ¬λ‹ μ„±κ³΅μ˜ μ£Όμš” μ•„ν‚€ν…μ²˜μ˜€μ§€λ§Œ, λ§˜λ°”μ™€ 같은 μƒνƒœ 곡간 λͺ¨λΈ(SSM)이 졜근 μ†Œκ·œλͺ¨ ν˜Ήμ€ 쀑간 규λͺ¨μ—μ„œ νŠΈλžœμŠ€ν¬λ¨Έμ™€ λŒ€λ“±ν•˜κ±°λ‚˜ 더 λ‚˜μ€ μ„±λŠ₯을 λ³΄μ΄λŠ” κ²ƒμœΌλ‘œ λ‚˜νƒ€λ‚¬μŠ΅λ‹ˆλ‹€. μš°λ¦¬λŠ” μ΄λŸ¬ν•œ λͺ¨λΈ 계열듀이 μ‹€μ œλ‘œ 맀우 λ°€μ ‘ν•˜κ²Œ μ—°κ΄€λ˜μ–΄ μžˆμŒμ„ νŒŒμ•…ν–ˆμŠ΅λ‹ˆλ‹€. 그리고 κ΅¬μ‘°ν™”λœ 쀀뢄리(semiseparable) ν–‰λ ¬ 쀑 연ꡬ가 잘 이루어진 클래슀의 λ‹€μ–‘ν•œ λΆ„ν•΄λ₯Ό 톡해 μ—°κ²°λœ SSMκ³Ό μ–΄ν…μ…˜ λ³€ν˜• μ‚¬μ΄μ˜ ν’λΆ€ν•œ 이둠적 μ—°κ²° ν”„λ ˆμž„μ›Œν¬λ₯Ό κ°œλ°œν–ˆμŠ΅λ‹ˆλ‹€. μƒνƒœ 곡간 이쀑성(SSD) ν”„λ ˆμž„μ›Œν¬λ₯Ό 톡해 λ§˜λ°”1의 선택적 SSM을 κ°œμ„ ν•œ μƒˆλ‘œμš΄ μ•„ν‚€ν…μ²˜λ₯Ό 섀계할 수 μžˆμ—ˆκ³ , νŠΈλžœμŠ€ν¬λ¨Έμ™€ 경쟁λ ₯을 μœ μ§€ν•˜λ©΄μ„œλ„ μ†λ„λŠ” 2~8λ°° 더 λΉ λ₯Έ μ„±λŠ₯을 λƒ…λ‹ˆλ‹€.*

Tips:

This version should support all Mamba 2 implementations, and in particular [Mamba-2 codestral](https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1) from Mistral AI. In particular, Mamba 2 codestral was released with 8 `groups`, which can be thought of as roughly analogous to the number of KV heads in an attention-based model.

이 λͺ¨λΈμ€ `torch_forward`와 `cuda_kernels_forward`λΌλŠ” 두 κ°€μ§€ λ‹€λ₯Έ μ „λ°© 패슀λ₯Ό κ°€μ§‘λ‹ˆλ‹€. `cuda_kernels_forward`λŠ” ν™˜κ²½μ—μ„œ cuda 컀널을 찾으면 이λ₯Ό μ‚¬μš©ν•˜λ©°, prefillμ—μ„œλŠ” 더 λŠλ¦½λ‹ˆλ‹€. 즉, 높은 CPU μ˜€λ²„ν—€λ“œλ‘œ 인해 "μ›œμ—… μ‹€ν–‰"이 ν•„μš”ν•˜κΈ° λ•Œλ¬Έμž…λ‹ˆλ‹€. κ΄€λ ¨ λ‚΄μš©μ€ [이곳](https://github.com/state-spaces/mamba/issues/389#issuecomment-2171755306)κ³Ό [이곳](https://github.com/state-spaces/mamba/issues/355#issuecomment-2147597457)을 μ°Έκ³ ν•˜μ„Έμš”. 

Without compilation, the `torch_forward` implementation is 3 to 4 times faster. Further, this model has no positional embeddings, but it does have an `attention_mask` and specific logic that masks out the hidden states in two places for batched generation; see [here](https://github.com/state-spaces/mamba/issues/66#issuecomment-1863563829) for details.

μ΄λ‘œμΈν•΄ λ§˜λ°”2 μ»€λ„μ˜ μž¬κ΅¬ν˜„κ³Ό ν•¨κ»˜ 배치 생성 및 μΊμ‹œλœ μƒμ„±μ—μ„œ μ•½κ°„μ˜ 차이가 μ˜ˆμƒλ©λ‹ˆλ‹€. λ˜ν•œ cuda 컀널 λ˜λŠ” torch forwardκ°€ μ œκ³΅ν•˜λŠ” κ²°κ³Όκ°€ μ•½κ°„ λ‹€λ₯Ό κ²ƒμœΌλ‘œ μ˜ˆμƒλ©λ‹ˆλ‹€. SSM μ•Œκ³ λ¦¬μ¦˜μ€ ν…μ„œ μˆ˜μΆ•μ— 크게 μ˜μ‘΄ν•˜λŠ”λ°, μ΄λŠ” matmulκ³Ό λ™λ“±ν•˜μ§€λ§Œ μ—°μ‚° μˆœμ„œκ°€ μ•½κ°„ λ‹€λ₯΄λ©°, 이둜 인해 더 μž‘μ€ μ •λ°€λ„μ—μ„œ 차이가 더 μ»€μ§‘λ‹ˆλ‹€.

One more note: the shutdown of hidden states corresponding to padding tokens is done in two places and has mostly been tested with left-padding. Right-padding propagates noise downstream and is not guaranteed to yield satisfactory results. Setting `tokenizer.padding_side = "left"` ensures the correct padding side is used.
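Why left-padding is the safe choice can be seen with a toy scalar recurrence standing in for the SSM scan (a hypothetical sketch, not the actual Mamba 2 implementation): a left-padded prefix whose contribution is masked to zero leaves the final state untouched, while right padding sits directly on top of the state that generation continues from.

```python
def scan(inputs, mask, a=0.5):
    """Toy stand-in for an SSM scan: h_t = a * h_{t-1} + x_t,
    zeroing the input contribution at padding positions."""
    h = 0.0
    for x, keep in zip(inputs, mask):
        h = a * h + (x if keep else 0.0)
    return h

tokens = [1.0, 2.0, 3.0]
PAD = 9.9  # a padding input that injects noise if not masked out

clean = scan(tokens, [True] * 3)                               # no padding
left = scan([PAD, PAD] + tokens, [False, False] + [True] * 3)  # masked left pads
right = scan(tokens + [PAD, PAD], [True] * 5)                  # unmasked right pads

print(clean, left, right)  # left padding matches the clean state; right padding corrupts it
```

With masked left padding the prefix contributes exactly zero, so the final state equals the unpadded one; any imperfectly masked right-padded noise instead lands on the exact state used to produce the next token.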

이 λͺ¨λΈμ€ [Molbap](https://huggingface.co/Molbap)이 κΈ°μ—¬ν–ˆμœΌλ©°, [Anton Vlasjuk](https://github.com/vasqu)의 큰 도움을 λ°›μ•˜μŠ΅λ‹ˆλ‹€.
원본 μ½”λ“œλŠ” [이곳](https://github.com/state-spaces/mamba)μ—μ„œ 확인할 수 μžˆμŠ΅λ‹ˆλ‹€.


# Usage

### A simple generation example:
```python 
from transformers import Mamba2Config, Mamba2ForCausalLM, AutoTokenizer
import torch
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```

Here is a draft script for fine-tuning:
```python 
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
from datasets import load_dataset
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # enforce left padding

model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
dataset = load_dataset("Abirate/english_quotes", split="train")
# Without CUDA kernels, a batch size of 2 fills one 80GB device,
# but precision can be reduced.
# Experiments and trials are welcome!
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
        r=8,
        target_modules=["embeddings", "in_proj", "out_proj"],
        task_type="CAUSAL_LM",
        bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
```


## Mamba2Config

[[autodoc]] Mamba2Config

## Mamba2Model

[[autodoc]] Mamba2Model
    - forward

## Mamba2ForCausalLM

[[autodoc]] Mamba2ForCausalLM
    - forward