DrDavis's picture
Upload folder using huggingface_hub
17c6d62 verified

λ§˜λ°”2[[mamba-2]]

κ°œμš”[[overview]]

λ§˜λ°”2 λͺ¨λΈμ€ Tri Dao, Albert Guκ°€ μ œμ•ˆν•œ νŠΈλžœμŠ€ν¬λ¨ΈλŠ” SSM이닀: κ΅¬μ‘°ν™”λœ μƒνƒœ 곡간 이쀑성을 ν†΅ν•œ μΌλ°˜ν™”λœ λͺ¨λΈκ³Ό 효율적인 μ•Œκ³ λ¦¬μ¦˜λΌλŠ” λ…Όλ¬Έμ—μ„œ μ†Œκ°œλ˜μ—ˆμŠ΅λ‹ˆλ‹€. λ§˜λ°”2λŠ” λ§˜λ°”1κ³Ό μœ μ‚¬ν•œ μƒνƒœ 곡간 λͺ¨λΈλ‘œ, λ‹¨μˆœν™”λœ μ•„ν‚€ν…μ²˜μ—μ„œ 더 λ‚˜μ€ μ„±λŠ₯을 λ³΄μž…λ‹ˆλ‹€.

ν•΄λ‹Ή λ…Όλ¬Έμ˜ μ΄ˆλ‘μž…λ‹ˆλ‹€:

νŠΈλžœμŠ€ν¬λ¨ΈλŠ” μ–Έμ–΄ λͺ¨λΈλ§μ—μ„œ λ”₯λŸ¬λ‹ μ„±κ³΅μ˜ μ£Όμš” μ•„ν‚€ν…μ²˜μ˜€μ§€λ§Œ, λ§˜λ°”μ™€ 같은 μƒνƒœ 곡간 λͺ¨λΈ(SSM)이 졜근 μ†Œκ·œλͺ¨ ν˜Ήμ€ 쀑간 규λͺ¨μ—μ„œ νŠΈλžœμŠ€ν¬λ¨Έμ™€ λŒ€λ“±ν•˜κ±°λ‚˜ 더 λ‚˜μ€ μ„±λŠ₯을 λ³΄μ΄λŠ” κ²ƒμœΌλ‘œ λ‚˜νƒ€λ‚¬μŠ΅λ‹ˆλ‹€. μš°λ¦¬λŠ” μ΄λŸ¬ν•œ λͺ¨λΈ 계열듀이 μ‹€μ œλ‘œ 맀우 λ°€μ ‘ν•˜κ²Œ μ—°κ΄€λ˜μ–΄ μžˆμŒμ„ νŒŒμ•…ν–ˆμŠ΅λ‹ˆλ‹€. 그리고 κ΅¬μ‘°ν™”λœ 쀀뢄리(semiseparable) ν–‰λ ¬ 쀑 연ꡬ가 잘 이루어진 클래슀의 λ‹€μ–‘ν•œ λΆ„ν•΄λ₯Ό 톡해 μ—°κ²°λœ SSMκ³Ό μ–΄ν…μ…˜ λ³€ν˜• μ‚¬μ΄μ˜ ν’λΆ€ν•œ 이둠적 μ—°κ²° ν”„λ ˆμž„μ›Œν¬λ₯Ό κ°œλ°œν–ˆμŠ΅λ‹ˆλ‹€. μƒνƒœ 곡간 이쀑성(SSD) ν”„λ ˆμž„μ›Œν¬λ₯Ό 톡해 λ§˜λ°”1의 선택적 SSM을 κ°œμ„ ν•œ μƒˆλ‘œμš΄ μ•„ν‚€ν…μ²˜λ₯Ό 섀계할 수 μžˆμ—ˆκ³ , νŠΈλžœμŠ€ν¬λ¨Έμ™€ 경쟁λ ₯을 μœ μ§€ν•˜λ©΄μ„œλ„ μ†λ„λŠ” 2~8λ°° 더 λΉ λ₯Έ μ„±λŠ₯을 λƒ…λ‹ˆλ‹€.

팁:

이 버전은 λ§˜λ°”2 κ΅¬ν˜„μ„ 지원해야 ν•˜λ©°, 특히 Mistral AI의 Mamba-2 codestral을 μ§€μ›ν•©λ‹ˆλ‹€. 특히, mamba 2 codestral은 8개의 groups둜 μΆœμ‹œλ˜μ—ˆλŠ”λ°, μ΄λŠ” μ–΄ν…μ…˜ 기반 λͺ¨λΈμ˜ KV ν—€λ“œ μˆ˜μ™€ μœ μ‚¬ν•˜λ‹€κ³  νŒλ‹¨ κ°€λŠ₯ν•©λ‹ˆλ‹€.

이 λͺ¨λΈμ€ torch_forward와 cuda_kernels_forwardλΌλŠ” 두 κ°€μ§€ λ‹€λ₯Έ μ „λ°© 패슀λ₯Ό κ°€μ§‘λ‹ˆλ‹€. cuda_kernels_forwardλŠ” ν™˜κ²½μ—μ„œ cuda 컀널을 찾으면 이λ₯Ό μ‚¬μš©ν•˜λ©°, prefillμ—μ„œλŠ” 더 λŠλ¦½λ‹ˆλ‹€. 즉, 높은 CPU μ˜€λ²„ν—€λ“œλ‘œ 인해 "μ›œμ—… μ‹€ν–‰"이 ν•„μš”ν•˜κΈ° λ•Œλ¬Έμž…λ‹ˆλ‹€. κ΄€λ ¨ λ‚΄μš©μ€ 이곳과 이곳을 μ°Έκ³ ν•˜μ„Έμš”.

컴파일 μ—†μ΄λŠ” torch_forward κ΅¬ν˜„μ΄ 3~4λ°° λΉ λ¦…λ‹ˆλ‹€. λ˜ν•œ, 이 λͺ¨λΈμ—λŠ” μœ„μΉ˜ μž„λ² λ”©μ΄ μ—†μ§€λ§Œ attention_mask와 배치 μƒμ„±μ˜ 경우 두 κ³³μ—μ„œ 은닉 μƒνƒœ(hidden state)λ₯Ό λ§ˆμŠ€ν‚Ήν•˜λŠ” νŠΉμ • 둜직이 μžˆμŠ΅λ‹ˆλ‹€. κ΄€λ ¨ λ‚΄μš©μ€ 이곳을 μ°Έκ³ ν•˜μ„Έμš”.

μ΄λ‘œμΈν•΄ λ§˜λ°”2 μ»€λ„μ˜ μž¬κ΅¬ν˜„κ³Ό ν•¨κ»˜ 배치 생성 및 μΊμ‹œλœ μƒμ„±μ—μ„œ μ•½κ°„μ˜ 차이가 μ˜ˆμƒλ©λ‹ˆλ‹€. λ˜ν•œ cuda 컀널 λ˜λŠ” torch forwardκ°€ μ œκ³΅ν•˜λŠ” κ²°κ³Όκ°€ μ•½κ°„ λ‹€λ₯Ό κ²ƒμœΌλ‘œ μ˜ˆμƒλ©λ‹ˆλ‹€. SSM μ•Œκ³ λ¦¬μ¦˜μ€ ν…μ„œ μˆ˜μΆ•μ— 크게 μ˜μ‘΄ν•˜λŠ”λ°, μ΄λŠ” matmulκ³Ό λ™λ“±ν•˜μ§€λ§Œ μ—°μ‚° μˆœμ„œκ°€ μ•½κ°„ λ‹€λ₯΄λ©°, 이둜 인해 더 μž‘μ€ μ •λ°€λ„μ—μ„œ 차이가 더 μ»€μ§‘λ‹ˆλ‹€.

또 λ‹€λ₯Έ μ°Έκ³ μ‚¬ν•­μœΌλ‘œ, νŒ¨λ”© 토큰에 ν•΄λ‹Ήν•˜λŠ” 은닉 μƒνƒœ(hidden state)의 μ’…λ£ŒλŠ” 두 κ³³μ—μ„œ 이루어지며 주둜 μ™Όμͺ½ νŒ¨λ”©μœΌλ‘œ ν…ŒμŠ€νŠΈλ˜μ—ˆμŠ΅λ‹ˆλ‹€. 였λ₯Έμͺ½ νŒ¨λ”©μ€ λ…Έμ΄μ¦ˆλ₯Ό μ „νŒŒν•˜λ―€λ‘œ 만쑱슀러운 κ²°κ³Όλ₯Ό 보μž₯ν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€. tokenizer.padding_side = "left"λ₯Ό μ‚¬μš©ν•˜λ©΄ μ˜¬λ°”λ₯Έ νŒ¨λ”© λ°©ν–₯을 μ‚¬μš©ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

이 λͺ¨λΈμ€ Molbap이 κΈ°μ—¬ν–ˆμœΌλ©°, Anton Vlasjuk의 큰 도움을 λ°›μ•˜μŠ΅λ‹ˆλ‹€. 원본 μ½”λ“œλŠ” μ΄κ³³μ—μ„œ 확인할 수 μžˆμŠ΅λ‹ˆλ‹€.

μ‚¬μš©

κ°„λ‹¨ν•œ 생성 예:

from transformers import Mamba2Config, Mamba2ForCausalLM, AutoTokenizer
import torch
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

이곳은 미세쑰정을 μœ„ν•œ μ΄ˆμ•ˆ μŠ€ν¬λ¦½νŠΈμž…λ‹ˆλ‹€:

from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" #μ™Όμͺ½ νŒ¨λ”©μœΌλ‘œ μ„€μ •

model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
dataset = load_dataset("Abirate/english_quotes", split="train")
# CUDA μ»€λ„μ—†μ΄λŠ”, 배치크기 2κ°€ 80GB μž₯치λ₯Ό ν•˜λ‚˜ μ°¨μ§€ν•©λ‹ˆλ‹€.
# ν•˜μ§€λ§Œ μ •ν™•λ„λŠ” κ°μ†Œν•©λ‹ˆλ‹€.
# μ‹€ν—˜κ³Ό μ‹œλ„λ₯Ό ν™˜μ˜ν•©λ‹ˆλ‹€!
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config =  LoraConfig(
        r=8,
        target_modules=["embeddings", "in_proj", "out_proj"],
        task_type="CAUSAL_LM",
        bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()

Mamba2Config

[[autodoc]] Mamba2Config

Mamba2Model

[[autodoc]] Mamba2Model - forward

Mamba2LMHeadModel

[[autodoc]] Mamba2ForCausalLM - forward