PyTorch FlashAttention SDPA

Jamba[[jamba]]

Jamba is a hybrid Transformer-Mamba mixture-of-experts (MoE) language model ranging from 52B to 398B total parameters. It aims to combine the performance of Transformer models with the efficiency and long-context handling (256K tokens) of state space models such as Mamba.

Jamba's architecture uses a blocks-and-layers structure designed to integrate the Transformer and Mamba architectures. Each Jamba block consists of either an attention layer or a Mamba layer, followed by a multi-layer perceptron (MLP). Transformer layers are placed periodically at a ratio of one in every eight layers. MoE layers are also mixed in to increase model capacity.
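The layer pattern described above can be sketched in a few lines of plain Python. This is only an illustration: the function name, the parameter names, the offsets, and the choice of placing an MoE layer on every second layer are all assumptions made for the example; the text itself only states the one-in-eight attention ratio and that MoE layers are mixed in.

```python
def jamba_layer_layout(num_layers=32, attn_period=8, attn_offset=4,
                       expert_period=2, expert_offset=1):
    """Return a list of (mixer, feed_forward) pairs, one per layer.

    All default values are hypothetical and chosen for illustration only.
    """
    layout = []
    for i in range(num_layers):
        # One layer in every `attn_period` uses attention; the rest use Mamba.
        mixer = "attention" if i % attn_period == attn_offset else "mamba"
        # Assumed: an MoE feed-forward replaces the dense MLP periodically.
        ffn = "moe" if i % expert_period == expert_offset else "mlp"
        layout.append((mixer, ffn))
    return layout


layout = jamba_layer_layout()
num_attention = sum(1 for mixer, _ in layout if mixer == "attention")
num_moe = sum(1 for _, ffn in layout if ffn == "moe")
print(f"{num_attention} attention layers out of {len(layout)}")
```

With these assumed values, the ratio of attention layers to total layers works out to one in eight, matching the description.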

All of the original Jamba checkpoints are available under the AI21 organization.

Click a Jamba model in the right sidebar to see more examples of applying Jamba to different language tasks.

The examples below show how to generate text with [Pipeline], [AutoModel], and from the command line.

# Install the optimized Mamba implementation
# !pip install mamba-ssm "causal-conv1d>=1.2.0"
import torch
from transformers import pipeline

# Use a distinct name so the imported pipeline() factory is not shadowed
pipe = pipeline(
    task="text-generation",
    model="ai21labs/AI21-Jamba-Mini-1.6",
    dtype=torch.float16,
    device=0
)
pipe("Plants create energy through a process known as")

# Generate text with AutoModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "ai21labs/AI21-Jamba-Large-1.6",
)
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Large-1.6",
    dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"
)
# Move the tokenized prompt to the device that holds the model's first parameters
inputs = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)

output = model.generate(**inputs, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))

# Generate text from the command line
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model ai21labs/AI21-Jamba-Mini-1.6 --device 0

Quantization reduces the memory footprint of large models by representing their weights at lower precision. See Quantization for the available quantization backends.
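As a rough back-of-the-envelope sketch of why lower precision helps, the snippet below estimates the memory needed for the weights alone at 16-bit and 8-bit precision, using the 398B total parameter count mentioned earlier. The helper name is hypothetical, and the estimate ignores activations, the cache, and any modules that are deliberately left unquantized, so real usage will differ.

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


total_params = 398e9  # largest total parameter count quoted for Jamba

bf16_gb = weight_memory_gb(total_params, 2)  # 16-bit weights: 2 bytes each
int8_gb = weight_memory_gb(total_params, 1)  # 8-bit weights: 1 byte each

print(f"bf16: {bf16_gb:.0f} GB, int8: {int8_gb:.0f} GB")
```

Under these assumptions, halving the bytes per weight halves the weight storage, which is what makes multi-GPU loading of the largest checkpoint more practical.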

The example below uses bitsandbytes to quantize only the weights to 8 bits.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True,
                                         llm_int8_skip_modules=["mamba"])

# Device map that spreads the model evenly across 8 GPUs
device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 2, 'model.layers.26': 2, 'model.layers.27': 3, 'model.layers.28': 3, 'model.layers.29': 3, 'model.layers.30': 3, 'model.layers.31': 3, 'model.layers.32': 3, 'model.layers.33': 3, 'model.layers.34': 3, 'model.layers.35': 3, 'model.layers.36': 4, 'model.layers.37': 4, 'model.layers.38': 4, 'model.layers.39': 4, 'model.layers.40': 4, 'model.layers.41': 4, 'model.layers.42': 4, 'model.layers.43': 4, 'model.layers.44': 4, 'model.layers.45': 5, 'model.layers.46': 5, 'model.layers.47': 5, 'model.layers.48': 5, 'model.layers.49': 5, 'model.layers.50': 5, 'model.layers.51': 5, 'model.layers.52': 5, 'model.layers.53': 5, 'model.layers.54': 6, 'model.layers.55': 6, 'model.layers.56': 6, 'model.layers.57': 6, 'model.layers.58': 6, 'model.layers.59': 6, 'model.layers.60': 6, 'model.layers.61': 6, 'model.layers.62': 6, 'model.layers.63': 7, 'model.layers.64': 7, 'model.layers.65': 7, 'model.layers.66': 7, 'model.layers.67': 7, 'model.layers.68': 7, 'model.layers.69': 7, 'model.layers.70': 7, 'model.layers.71': 7, 'model.final_layernorm': 7, 'lm_head': 7}
model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-Large-1.6",
                                             dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             quantization_config=quantization_config,
                                             device_map=device_map)

tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Large-1.6")

messages = [
   {"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."},
   {"role": "user", "content": "Hello!"},
]

input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to(model.device)

outputs = model.generate(input_ids, max_new_tokens=216)

# Decode the output
conversation = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Extract only the assistant's response
assistant_response = conversation.split(messages[-1]['content'])[1].strip()
print(assistant_response)
# Output: Seek and you shall find. The path is winding, but the journey is enlightening. What wisdom do you seek from the ancient echoes?
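The hand-written device map in the example above follows a regular pattern, so it could also be generated programmatically. The sketch below is one possible way to do that; the helper name `build_device_map` and its parameters are hypothetical, and the module names and layer count are simply copied from the dictionary shown in the example.

```python
def build_device_map(num_layers=72, num_gpus=8):
    """Assign consecutive layers to GPUs in equal-sized chunks."""
    layers_per_gpu = -(-num_layers // num_gpus)  # ceiling division

    # The embeddings live with the first layers.
    device_map = {"model.embed_tokens": 0}
    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = i // layers_per_gpu

    # The final norm and the output head live with the last layers.
    device_map["model.final_layernorm"] = num_gpus - 1
    device_map["lm_head"] = num_gpus - 1
    return device_map


device_map = build_device_map()
print(len(device_map), device_map["model.layers.9"])
```

With the default arguments this places nine layers on each of the eight GPUs, the same assignment as the explicit dictionary.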

Notes[[notes]]

  • Do not quantize the Mamba blocks, to prevent degradation in model quality.

  • Using Mamba without the optimized Mamba kernels is not recommended because it significantly increases latency. If you still want to use Mamba without the kernels, set use_mamba_kernels=False in [~AutoModel.from_pretrained].

    import torch
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Large",
                                                 use_mamba_kernels=False)
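One way to choose the flag automatically is to check whether the optimized kernel packages can be imported before loading the model. This is a minimal sketch, assuming the `mamba-ssm` and `causal-conv1d` packages install the import names `mamba_ssm` and `causal_conv1d`; the helper name `mamba_kernels_available` is hypothetical, and the check does not verify that a compatible GPU is present.

```python
import importlib.util


def mamba_kernels_available():
    """Return True only if both optimized kernel packages are importable."""
    required = ("mamba_ssm", "causal_conv1d")
    return all(importlib.util.find_spec(name) is not None for name in required)


use_mamba_kernels = mamba_kernels_available()
print(f"use_mamba_kernels={use_mamba_kernels}")

# The result could then be passed on when loading the model, for example:
# model = AutoModelForCausalLM.from_pretrained(
#     "ai21labs/AI21-Jamba-1.5-Large",
#     use_mamba_kernels=use_mamba_kernels,
# )
```

Falling back to `use_mamba_kernels=False` keeps the model loadable, at the latency cost described in the note above.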
    

JambaConfig[[transformers.JambaConfig]]

[[autodoc]] JambaConfig

JambaModel[[transformers.JambaModel]]

[[autodoc]] JambaModel - forward

JambaForCausalLM[[transformers.JambaForCausalLM]]

[[autodoc]] JambaForCausalLM - forward

JambaForSequenceClassification[[transformers.JambaForSequenceClassification]]

[[autodoc]] transformers.JambaForSequenceClassification - forward