λ§λ°[[mamba]]
κ°μ[[overview]]
λ§λ°(Mamba) λͺ¨λΈμ Albert Gu, Tri Daoκ° μ μν λ§λ°: μ νμ μν 곡κ°μ μ΄μ©ν μ ν μκ° μνμ€ λͺ¨λΈλ§λΌλ λ Όλ¬Έμμ μκ° λμμ΅λλ€.
μ΄ λͺ¨λΈμ state-space-modelsμ κΈ°λ°μΌλ‘ ν μλ‘μ΄ ν¨λ¬λ€μ μν€ν
μ²μ
λλ€. μ§κ΄μ μΈ μ΄ν΄λ₯Ό μ»κ³ μΆλ€λ©΄ μ΄κ³³μ μ°Έκ³ νμΈμ.
ν΄λΉ λ Όλ¬Έμ μ΄λ‘μ λλ€:
νμ¬ λ₯λ¬λμμ ν₯λ―Έλ‘μ΄ μμ© νλ‘κ·Έλ¨μ ꡬλνλ λλΆλΆμ κΈ°μ΄ λͺ¨λΈλ€μ κ±°μ 보νΈμ μΌλ‘ νΈλμ€ν¬λ¨Έ μν€ν μ²μ κ·Έ ν΅μ¬ μ΄ν μ λͺ¨λμ κΈ°λ°μΌλ‘ ν©λλ€. μ ν μ΄ν μ , κ²μ΄νΈλ 컨볼루μ κ³Ό μν λͺ¨λΈ, ꡬ쑰νλ μν κ³΅κ° λͺ¨λΈ(SSM) λ± λ§μ μ€μ΄μ°¨μκ°(subquadratic-time) μν€ν μ²κ° κΈ΄ μνμ€μ λν νΈλμ€ν¬λ¨Έμ κ³μ° λΉν¨μ¨μ±μ ν΄κ²°νκΈ° μν΄ κ°λ°λμμ§λ§, μΈμ΄μ κ°μ μ€μν μμμμλ μ΄ν μ λ§νΌ μ±λ₯μ λ΄μ§ λͺ»νμ΅λλ€. μ°λ¦¬λ μ΄λ¬ν λͺ¨λΈμ μ£Όμ μ½μ μ΄ λ΄μ© κΈ°λ° μΆλ‘ μ μννμ§ λͺ»νλ€λ μ μμ μκ³ λͺ κ°μ§λ₯Ό κ°μ νμ΅λλ€. 첫째, SSM λ§€κ°λ³μλ₯Ό μ λ ₯μ ν¨μλ‘ λ§λλ κ²λ§μΌλ‘λ μ΄μ° λͺ¨λ¬λ¦¬ν°(discrete modalities)μ μ½μ μ ν΄κ²°ν μ μμ΄, νμ¬ ν ν°μ λ°λΌ μνμ€ κΈΈμ΄ μ°¨μμ λ°λΌ μ 보λ₯Ό μ νμ μΌλ‘ μ ννκ±°λ μμ μ μκ² ν©λλ€. λμ§Έ, μ΄λ¬ν λ³κ²½μΌλ‘ ν¨μ¨μ μΈ μ»¨λ³Όλ£¨μ μ μ¬μ©ν μ μκ² λμμ§λ§, μ°λ¦¬λ μν λͺ¨λμμ νλμ¨μ΄λ₯Ό μΈμνλ λ³λ ¬ μκ³ λ¦¬μ¦μ μ€κ³νμ΅λλ€. μ°λ¦¬λ μ΄λ¬ν μ νμ SSMμ μ΄ν μ μ΄λ MLP λΈλ‘λ μλ λ¨μνλ μ’ λ¨κ° μ κ²½λ§ μν€ν μ²μΈ λ§λ°μ ν΅ν©μμΌ°μ΅λλ€. λ§λ°λ λΉ λ₯Έ μΆλ‘ (νΈλμ€ν¬λ¨Έλ³΄λ€ 5λ°° λμ μ²λ¦¬λ)κ³Ό μνμ€ κΈΈμ΄μ λν μ ν νμ₯μ±μ λ리며, λ°±λ§ κΈΈμ΄ μνμ€κΉμ§ μ€μ λ°μ΄ν°μμ μ±λ₯μ΄ ν₯μλ©λλ€. μΌλ°μ μΈ μνμ€ λͺ¨λΈ λ°±λ³ΈμΌλ‘μ λ§λ°λ μΈμ΄, μ€λμ€, μ μ 체νκ³Ό κ°μ μ¬λ¬ μμμμ μ΅μ²¨λ¨ μ±λ₯μ λ¬μ±ν©λλ€. μΈμ΄ λͺ¨λΈλ§μμ μ°λ¦¬μ λ§λ°-3B λͺ¨λΈμ κ°μ ν¬κΈ°μ νΈλμ€ν¬λ¨Έλ₯Ό λ₯κ°νκ³ λ λ°° ν¬κΈ°μ νΈλμ€ν¬λ¨Έμ λ§λ¨Ήλ μ±λ₯μ 보μ΄λ©°, μ¬μ νλ ¨κ³Ό λ€μ΄μ€νΈλ¦Ό νκ° λͺ¨λμμ μ±λ₯μ λνλ©λλ€.
ν:
- λ§λ°λ κ³ μ μ μΈ νΈλμ€ν¬λ¨Έμ κ²¬μ€ λ§ν μλ‘μ΄
μν κ³΅κ° λͺ¨λΈμν€ν μ²μ λλ€. μ΄λ ꡬ쑰νλ μν κ³΅κ° λͺ¨λΈμ λ°μ μ μμ μμΌλ©°, νλμμ΄ν μ μ μ μ μ λ°λ₯΄λ ν¨μ¨μ μΈ νλμ¨μ΄ μΈμ μ€κ³μ ꡬνμ νΉμ§μΌλ‘ ν©λλ€. - λ§λ°λ
μ΄ν μ λ μ΄μ΄μ λλ±νλ―Ήμ(mixer)λ μ΄μ΄λ₯Ό μμ΅λλ€.λ§λ°μ ν΅μ¬ λ‘μ§μMambaMixerν΄λμ€μ μμ΅λλ€. - λ κ°μ§ ꡬνμ΄ κ³΅μ‘΄ν©λλ€: νλλ μ΅μ νλμ΄ λΉ λ₯Έ cuda컀λμ μ¬μ©νκ³ , λ€λ₯Έ νλλ λ¨μνμ§λ§ λͺ¨λ μ₯μΉμμ μ€νν μ μμ΅λλ€!
- νμ¬ κ΅¬νμ μλ³Έ cuda컀λμ νμ©ν©λλ€: λ§λ°λ₯Ό μν νλμ μ΄ν
μ
μ μν μ νλ κ²μ
mamba-ssmμcausal_conv1dμ μ₯μμ νΈμ€ν λμ΄ μμ΅λλ€. νλμ¨μ΄κ° μ§μνλ€λ©΄ λ°λμ μ€μΉνμΈμ! - cuda 컀λμ μ΅μ ννλ λ°©ν₯ 보λ€λ, λ¨μνμ§λ§ λͺ¨λ μ₯μΉμμ μ€νκ°λ₯νλλ‘νλ λ°©ν₯μΈ 'λ¨μꡬν'μ μ±λ₯μ λΉ λ₯΄κ² ν₯μμν€λ κΈ°μ¬λ₯Ό λ νμνκ³ μμ΅λλ€. π€
μ΄ λͺ¨λΈμ ArthurZμ μν΄ κΈ°μ¬λμμ΅λλ€. μλ³Έ μ½λλ μ΄κ³³μμ νμΈν μ μμ΅λλ€.
μ¬μ©
κ°λ¨ν μμ± μμ
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
Peft νμΈνλ
λλ¦° λ²μ μ νμ΅μμ μμ£Ό μμ μ μ΄μ§ μμ΅λλ€. λΉ λ₯Έ λ²μ μ float32κ° νμν©λλ€!
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
MambaConfig
[[autodoc]] MambaConfig
MambaModel
[[autodoc]] MambaModel - forward
MambaLMHeadModel
[[autodoc]] MambaForCausalLM - forward