DrDavis's picture
Upload folder using huggingface_hub
17c6d62 verified

λ§˜λ°”[[mamba]]

κ°œμš”[[overview]]

λ§˜λ°”(Mamba) λͺ¨λΈμ€ Albert Gu, Tri Daoκ°€ μ œμ•ˆν•œ λ§˜λ°”: 선택적 μƒνƒœ 곡간을 μ΄μš©ν•œ μ„ ν˜• μ‹œκ°„ μ‹œν€€μŠ€ λͺ¨λΈλ§λΌλŠ” λ…Όλ¬Έμ—μ„œ μ†Œκ°œ λ˜μ—ˆμŠ΅λ‹ˆλ‹€.

이 λͺ¨λΈμ€ state-space-models을 기반으둜 ν•œ μƒˆλ‘œμš΄ νŒ¨λŸ¬λ‹€μž„ μ•„ν‚€ν…μ²˜μž…λ‹ˆλ‹€. 직관적인 이해λ₯Ό μ–»κ³  μ‹Άλ‹€λ©΄ 이곳을 μ°Έκ³  ν•˜μ„Έμš”.

ν•΄λ‹Ή λ…Όλ¬Έμ˜ μ΄ˆλ‘μž…λ‹ˆλ‹€:

ν˜„μž¬ λ”₯λŸ¬λ‹μ—μ„œ ν₯미둜운 μ‘μš© ν”„λ‘œκ·Έλž¨μ„ κ΅¬λ™ν•˜λŠ” λŒ€λΆ€λΆ„μ˜ 기초 λͺ¨λΈλ“€μ€ 거의 보편적으둜 트랜슀포머 μ•„ν‚€ν…μ²˜μ™€ κ·Έ 핡심 μ–΄ν…μ…˜ λͺ¨λ“ˆμ„ 기반으둜 ν•©λ‹ˆλ‹€. μ„ ν˜• μ–΄ν…μ…˜, 게이트된 μ»¨λ³Όλ£¨μ…˜κ³Ό μˆœν™˜ λͺ¨λΈ, κ΅¬μ‘°ν™”λœ μƒνƒœ 곡간 λͺ¨λΈ(SSM) λ“± λ§Žμ€ μ€€μ΄μ°¨μ‹œκ°„(subquadratic-time) μ•„ν‚€ν…μ²˜κ°€ κΈ΄ μ‹œν€€μŠ€μ— λŒ€ν•œ 트랜슀포머의 계산 λΉ„νš¨μœ¨μ„±μ„ ν•΄κ²°ν•˜κΈ° μœ„ν•΄ κ°œλ°œλ˜μ—ˆμ§€λ§Œ, 언어와 같은 μ€‘μš”ν•œ μ–‘μ‹μ—μ„œλŠ” μ–΄ν…μ…˜λ§ŒνΌ μ„±λŠ₯을 λ‚΄μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€. μš°λ¦¬λŠ” μ΄λŸ¬ν•œ λͺ¨λΈμ˜ μ£Όμš” 약점이 λ‚΄μš© 기반 좔둠을 μˆ˜ν–‰ν•˜μ§€ λͺ»ν•œλ‹€λŠ” μ μž„μ„ μ•Œκ³  λͺ‡ κ°€μ§€λ₯Ό κ°œμ„ ν–ˆμŠ΅λ‹ˆλ‹€. 첫째, SSM λ§€κ°œλ³€μˆ˜λ₯Ό μž…λ ₯의 ν•¨μˆ˜λ‘œ λ§Œλ“œλŠ” κ²ƒλ§ŒμœΌλ‘œλ„ 이산 λͺ¨λ‹¬λ¦¬ν‹°(discrete modalities)의 약점을 ν•΄κ²°ν•  수 μžˆμ–΄, ν˜„μž¬ 토큰에 따라 μ‹œν€€μŠ€ 길이 차원을 따라 정보λ₯Ό μ„ νƒμ μœΌλ‘œ μ „νŒŒν•˜κ±°λ‚˜ μžŠμ„ 수 있게 ν•©λ‹ˆλ‹€. λ‘˜μ§Έ, μ΄λŸ¬ν•œ λ³€κ²½μœΌλ‘œ 효율적인 μ»¨λ³Όλ£¨μ…˜μ„ μ‚¬μš©ν•  수 μ—†κ²Œ λ˜μ—ˆμ§€λ§Œ, μš°λ¦¬λŠ” μˆœν™˜ λͺ¨λ“œμ—μ„œ ν•˜λ“œμ›¨μ–΄λ₯Ό μΈμ‹ν•˜λŠ” 병렬 μ•Œκ³ λ¦¬μ¦˜μ„ μ„€κ³„ν–ˆμŠ΅λ‹ˆλ‹€. μš°λ¦¬λŠ” μ΄λŸ¬ν•œ 선택적 SSM을 μ–΄ν…μ…˜μ΄λ‚˜ MLP 블둝도 μ—†λŠ” λ‹¨μˆœν™”λœ 쒅단간 신경망 μ•„ν‚€ν…μ²˜μΈ λ§˜λ°”μ— ν†΅ν•©μ‹œμΌ°μŠ΅λ‹ˆλ‹€. λ§˜λ°”λŠ” λΉ λ₯Έ μΆ”λ‘ (νŠΈλžœμŠ€ν¬λ¨Έλ³΄λ‹€ 5λ°° 높은 μ²˜λ¦¬λŸ‰)κ³Ό μ‹œν€€μŠ€ 길이에 λŒ€ν•œ μ„ ν˜• ν™•μž₯성을 λˆ„λ¦¬λ©°, 백만 길이 μ‹œν€€μŠ€κΉŒμ§€ μ‹€μ œ λ°μ΄ν„°μ—μ„œ μ„±λŠ₯이 ν–₯μƒλ©λ‹ˆλ‹€. 일반적인 μ‹œν€€μŠ€ λͺ¨λΈ λ°±λ³ΈμœΌλ‘œμ„œ λ§˜λ°”λŠ” μ–Έμ–΄, μ˜€λ””μ˜€, μœ μ „μ²΄ν•™κ³Ό 같은 μ—¬λŸ¬ μ–‘μ‹μ—μ„œ μ΅œμ²¨λ‹¨ μ„±λŠ₯을 λ‹¬μ„±ν•©λ‹ˆλ‹€. μ–Έμ–΄ λͺ¨λΈλ§μ—μ„œ 우리의 λ§˜λ°”-3B λͺ¨λΈμ€ 같은 크기의 트랜슀포머λ₯Ό λŠ₯κ°€ν•˜κ³  두 λ°° 크기의 νŠΈλžœμŠ€ν¬λ¨Έμ™€ λ§žλ¨ΉλŠ” μ„±λŠ₯을 보이며, 사전 ν›ˆλ ¨κ³Ό λ‹€μš΄μŠ€νŠΈλ¦Ό 평가 λͺ¨λ‘μ—μ„œ μ„±λŠ₯을 λ‚˜νƒ€λ‚©λ‹ˆλ‹€.

팁:

  • λ§˜λ°”λŠ” 고전적인 νŠΈλžœμŠ€ν¬λ¨Έμ™€ 견쀄 λ§Œν•œ μƒˆλ‘œμš΄ μƒνƒœ 곡간 λͺ¨λΈ μ•„ν‚€ν…μ²˜μž…λ‹ˆλ‹€. μ΄λŠ” κ΅¬μ‘°ν™”λœ μƒνƒœ 곡간 λͺ¨λΈμ˜ λ°œμ „ 선상에 있으며, ν”Œλž˜μ‹œμ–΄ν…μ…˜μ˜ 정신을 λ”°λ₯΄λŠ” 효율적인 ν•˜λ“œμ›¨μ–΄ 인식 섀계와 κ΅¬ν˜„μ„ νŠΉμ§•μœΌλ‘œ ν•©λ‹ˆλ‹€.
  • λ§˜λ°”λŠ” μ–΄ν…μ…˜ λ ˆμ΄μ–΄μ™€ λ™λ“±ν•œ λ―Ήμ„œ(mixer) λ ˆμ΄μ–΄λ₯Ό μŒ“μŠ΅λ‹ˆλ‹€. λ§˜λ°”μ˜ 핡심 λ‘œμ§μ€ MambaMixer ν΄λž˜μŠ€μ— μžˆμŠ΅λ‹ˆλ‹€.
  • 두 κ°€μ§€ κ΅¬ν˜„μ΄ κ³΅μ‘΄ν•©λ‹ˆλ‹€: ν•˜λ‚˜λŠ” μ΅œμ ν™”λ˜μ–΄ λΉ λ₯Έ cuda컀널을 μ‚¬μš©ν•˜κ³ , λ‹€λ₯Έ ν•˜λ‚˜λŠ” λ‹¨μˆœν•˜μ§€λ§Œ λͺ¨λ“  μž₯μΉ˜μ—μ„œ μ‹€ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€!
  • ν˜„μž¬ κ΅¬ν˜„μ€ 원본 cuda컀널을 ν™œμš©ν•©λ‹ˆλ‹€: λ§˜λ°”λ₯Ό μœ„ν•œ ν”Œλž˜μ‹œ μ–΄ν…μ…˜μ˜ 역할을 ν•˜λŠ” 것은 mamba-ssm와 causal_conv1d μ €μž₯μ†Œμ— ν˜ΈμŠ€νŒ…λ˜μ–΄ μžˆμŠ΅λ‹ˆλ‹€. ν•˜λ“œμ›¨μ–΄κ°€ μ§€μ›ν•œλ‹€λ©΄ λ°˜λ“œμ‹œ μ„€μΉ˜ν•˜μ„Έμš”!
  • cuda 컀널을 μ΅œμ ν™”ν•˜λŠ” λ°©ν–₯ λ³΄λ‹€λŠ”, λ‹¨μˆœν•˜μ§€λ§Œ λͺ¨λ“  μž₯μΉ˜μ—μ„œ μ‹€ν–‰κ°€λŠ₯ν•˜λ„λ‘ν•˜λŠ” λ°©ν–₯인 'λ‹¨μˆœκ΅¬ν˜„'의 μ„±λŠ₯을 λΉ λ₯΄κ²Œ ν–₯μƒμ‹œν‚€λŠ” κΈ°μ—¬λ₯Ό 더 ν™˜μ˜ν•˜κ³  μžˆμŠ΅λ‹ˆλ‹€. πŸ€—

이 λͺ¨λΈμ€ ArthurZ에 μ˜ν•΄ κΈ°μ—¬λ˜μ—ˆμŠ΅λ‹ˆλ‹€. 원본 μ½”λ“œλŠ” μ΄κ³³μ—μ„œ 확인할 수 μžˆμŠ΅λ‹ˆλ‹€.

μ‚¬μš©

κ°„λ‹¨ν•œ 생성 예제

from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

Peft νŒŒμΈνŠœλ‹

느린 버전은 ν•™μŠ΅μ—μ„œ μ•„μ£Ό μ•ˆμ •μ μ΄μ§„ μ•ŠμŠ΅λ‹ˆλ‹€. λΉ λ₯Έ 버전은 float32κ°€ ν•„μš”ν•©λ‹ˆλ‹€!

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config =  LoraConfig(
        r=8,
        target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
        task_type="CAUSAL_LM",
        bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()

MambaConfig

[[autodoc]] MambaConfig

MambaModel

[[autodoc]] MambaModel - forward

MambaLMHeadModel

[[autodoc]] MambaForCausalLM - forward