λ§λ°2[[mamba-2]]
κ°μ[[overview]]
λ§λ°2 λͺ¨λΈμ Tri Dao, Albert Guκ° μ μν νΈλμ€ν¬λ¨Έλ SSMμ΄λ€: ꡬ쑰νλ μν κ³΅κ° μ΄μ€μ±μ ν΅ν μΌλ°νλ λͺ¨λΈκ³Ό ν¨μ¨μ μΈ μκ³ λ¦¬μ¦λΌλ λ Όλ¬Έμμ μκ°λμμ΅λλ€. λ§λ°2λ λ§λ°1κ³Ό μ μ¬ν μν κ³΅κ° λͺ¨λΈλ‘, λ¨μνλ μν€ν μ²μμ λ λμ μ±λ₯μ 보μ λλ€.
ν΄λΉ λ Όλ¬Έμ μ΄λ‘μ λλ€:
νΈλμ€ν¬λ¨Έλ μΈμ΄ λͺ¨λΈλ§μμ λ₯λ¬λ μ±κ³΅μ μ£Όμ μν€ν μ²μμ§λ§, λ§λ°μ κ°μ μν κ³΅κ° λͺ¨λΈ(SSM)μ΄ μ΅κ·Ό μκ·λͺ¨ νΉμ μ€κ° κ·λͺ¨μμ νΈλμ€ν¬λ¨Έμ λλ±νκ±°λ λ λμ μ±λ₯μ 보μ΄λ κ²μΌλ‘ λνλ¬μ΅λλ€. μ°λ¦¬λ μ΄λ¬ν λͺ¨λΈ κ³μ΄λ€μ΄ μ€μ λ‘ λ§€μ° λ°μ νκ² μ°κ΄λμ΄ μμμ νμ νμ΅λλ€. κ·Έλ¦¬κ³ κ΅¬μ‘°νλ μ€λΆλ¦¬(semiseparable) νλ ¬ μ€ μ°κ΅¬κ° μ μ΄λ£¨μ΄μ§ ν΄λμ€μ λ€μν λΆν΄λ₯Ό ν΅ν΄ μ°κ²°λ SSMκ³Ό μ΄ν μ λ³ν μ¬μ΄μ νλΆν μ΄λ‘ μ μ°κ²° νλ μμν¬λ₯Ό κ°λ°νμ΅λλ€. μν κ³΅κ° μ΄μ€μ±(SSD) νλ μμν¬λ₯Ό ν΅ν΄ λ§λ°1μ μ νμ SSMμ κ°μ ν μλ‘μ΄ μν€ν μ²λ₯Ό μ€κ³ν μ μμκ³ , νΈλμ€ν¬λ¨Έμ κ²½μλ ₯μ μ μ§νλ©΄μλ μλλ 2~8λ°° λ λΉ λ₯Έ μ±λ₯μ λ λλ€.
ν:
μ΄ λ²μ μ λ§λ°2 ꡬνμ μ§μν΄μΌ νλ©°, νΉν Mistral AIμ Mamba-2 codestralμ μ§μν©λλ€. νΉν, mamba 2 codestralμ 8κ°μ groupsλ‘ μΆμλμλλ°, μ΄λ μ΄ν
μ
κΈ°λ° λͺ¨λΈμ KV ν€λ μμ μ μ¬νλ€κ³ νλ¨ κ°λ₯ν©λλ€.
μ΄ λͺ¨λΈμ torch_forwardμ cuda_kernels_forwardλΌλ λ κ°μ§ λ€λ₯Έ μ λ°© ν¨μ€λ₯Ό κ°μ§λλ€. cuda_kernels_forwardλ νκ²½μμ cuda 컀λμ μ°ΎμΌλ©΄ μ΄λ₯Ό μ¬μ©νλ©°, prefillμμλ λ λ립λλ€. μ¦, λμ CPU μ€λ²ν€λλ‘ μΈν΄ "μμ
μ€ν"μ΄ νμνκΈ° λλ¬Έμ
λλ€. κ΄λ ¨ λ΄μ©μ μ΄κ³³κ³Ό μ΄κ³³μ μ°Έκ³ νμΈμ.
μ»΄νμΌ μμ΄λ torch_forward ꡬνμ΄ 3~4λ°° λΉ λ¦
λλ€. λν, μ΄ λͺ¨λΈμλ μμΉ μλ² λ©μ΄ μμ§λ§ attention_maskμ λ°°μΉ μμ±μ κ²½μ° λ κ³³μμ μλ μν(hidden state)λ₯Ό λ§μ€νΉνλ νΉμ λ‘μ§μ΄ μμ΅λλ€. κ΄λ ¨ λ΄μ©μ μ΄κ³³μ μ°Έκ³ νμΈμ.
μ΄λ‘μΈν΄ λ§λ°2 컀λμ μ¬κ΅¬νκ³Ό ν¨κ» λ°°μΉ μμ± λ° μΊμλ μμ±μμ μ½κ°μ μ°¨μ΄κ° μμλ©λλ€. λν cuda 컀λ λλ torch forwardκ° μ 곡νλ κ²°κ³Όκ° μ½κ° λ€λ₯Ό κ²μΌλ‘ μμλ©λλ€. SSM μκ³ λ¦¬μ¦μ ν μ μμΆμ ν¬κ² μμ‘΄νλλ°, μ΄λ matmulκ³Ό λλ±νμ§λ§ μ°μ° μμκ° μ½κ° λ€λ₯΄λ©°, μ΄λ‘ μΈν΄ λ μμ μ λ°λμμ μ°¨μ΄κ° λ 컀μ§λλ€.
λ λ€λ₯Έ μ°Έκ³ μ¬νμΌλ‘, ν¨λ© ν ν°μ ν΄λΉνλ μλ μν(hidden state)μ μ’
λ£λ λ κ³³μμ μ΄λ£¨μ΄μ§λ©° μ£Όλ‘ μΌμͺ½ ν¨λ©μΌλ‘ ν
μ€νΈλμμ΅λλ€. μ€λ₯Έμͺ½ ν¨λ©μ λ
Έμ΄μ¦λ₯Ό μ ννλ―λ‘ λ§μ‘±μ€λ¬μ΄ κ²°κ³Όλ₯Ό 보μ₯νμ§ μμ΅λλ€. tokenizer.padding_side = "left"λ₯Ό μ¬μ©νλ©΄ μ¬λ°λ₯Έ ν¨λ© λ°©ν₯μ μ¬μ©ν μ μμ΅λλ€.
μ΄ λͺ¨λΈμ Molbapμ΄ κΈ°μ¬νμΌλ©°, Anton Vlasjukμ ν° λμμ λ°μμ΅λλ€. μλ³Έ μ½λλ μ΄κ³³μμ νμΈν μ μμ΅λλ€.
μ¬μ©
κ°λ¨ν μμ± μ:
from transformers import Mamba2Config, Mamba2ForCausalLM, AutoTokenizer
import torch
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
μ΄κ³³μ λ―ΈμΈμ‘°μ μ μν μ΄μ μ€ν¬λ¦½νΈμ λλ€:
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" #μΌμͺ½ ν¨λ©μΌλ‘ μ€μ
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
dataset = load_dataset("Abirate/english_quotes", split="train")
# CUDA 컀λμμ΄λ, λ°°μΉν¬κΈ° 2κ° 80GB μ₯μΉλ₯Ό νλ μ°¨μ§ν©λλ€.
# νμ§λ§ μ νλλ κ°μν©λλ€.
# μ€νκ³Ό μλλ₯Ό νμν©λλ€!
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=2,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
Mamba2Config
[[autodoc]] Mamba2Config
Mamba2Model
[[autodoc]] Mamba2Model - forward
Mamba2LMHeadModel
[[autodoc]] Mamba2ForCausalLM - forward