Jamba[[jamba]]
Jamba๋ Transformer์ Mamba ๊ธฐ๋ฐ์ ํ์ด๋ธ๋ฆฌ๋ ์ ๋ฌธ๊ฐ ํผํฉ(MoE) ์ธ์ด ๋ชจ๋ธ๋ก, ์ด ๋งค๊ฐ๋ณ์ ์๋ 52B์์ 398B๊น์ง ๋ค์ํฉ๋๋ค. ์ด ๋ชจ๋ธ์ Transformer ๋ชจ๋ธ์ ์ฑ๋ฅ๊ณผ Mamba์ ๊ฐ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ํจ์จ์ฑ ๋ฐ ๊ธด ์ปจํ ์คํธ ์ฒ๋ฆฌ ๋ฅ๋ ฅ(256K ํ ํฐ)์ ๋ชจ๋ ํ์ฉํ๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค.
Jamba์ ์ํคํ ์ฒ๋ ๋ธ๋ก๊ณผ ๋ ์ด์ด ๊ธฐ๋ฐ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ์ฌ Transformer์ Mamba ์ํคํ ์ฒ๋ฅผ ํตํฉํ ์ ์๋๋ก ์ค๊ณ๋์์ต๋๋ค. ๊ฐ Jamba ๋ธ๋ก์ ์ดํ ์ ๋ ์ด์ด ๋๋ Mamba ๋ ์ด์ด ์ค ํ๋์ ๊ทธ ๋ค๋ฅผ ์๋ ๋ค์ธต ํผ์ ํธ๋ก (MLP)์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. Transformer ๋ ์ด์ด๋ 8๊ฐ์ ๋ ์ด์ด ์ค ํ๋์ ๋น์จ๋ก ์ฃผ๊ธฐ์ ์ผ๋ก ๋ฐฐ์น๋ฉ๋๋ค. ๋ํ ๋ชจ๋ธ ์ฉ๋์ ํ์ฅํ๊ธฐ ์ํด MoE ๋ ์ด์ด๊ฐ ํผํฉ๋์ด ์์ต๋๋ค.
모든 원본 Jamba 체크포인트는 AI21 조직에서 확인할 수 있습니다.
์ค๋ฅธ์ชฝ ์ฌ์ด๋๋ฐ์ ์๋ Jamba ๋ชจ๋ธ์ ๋๋ฅด๋ฉด ๋ค์ํ ์ธ์ด ์์ ์ Jamba๋ฅผ ์ ์ฉํ๋ ์์ ๋ฅผ ๋ ํ์ธํ ์ ์์ต๋๋ค.
아래 예제는 [Pipeline]과 [AutoModel], 그리고 커맨드라인을 통해 텍스트를 생성하는 방법을 보여줍니다.
# Install the optimized Mamba implementation. The version specifier is quoted so
# the shell does not treat `>` as an output redirection.
# !pip install mamba-ssm "causal-conv1d>=1.2.0"
import torch
from transformers import pipeline

# Bind the pipeline to its own name so the imported `pipeline` factory is not shadowed.
generator = pipeline(
    task="text-generation",
    model="ai21labs/AI21-Jamba-Mini-1.6",
    dtype=torch.float16,
    device=0,
)
# A text-generation pipeline returns a list of dicts, one per generated sequence.
result = generator("Plants create energy through a process known as")
print(result[0]["generated_text"])
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "ai21labs/AI21-Jamba-Large-1.6",
)
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Large-1.6",
    dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
)
# The tokenizer returns a dict-like batch (input ids plus attention mask), not a
# bare tensor. Move it to the model's input device rather than a hard-coded
# "cuda", because device_map="auto" decides where the weights live.
inputs = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)
# Jamba manages its own hybrid attention/Mamba cache during generation, so no
# `cache_implementation` override is passed (it would be ignored with a warning).
output = model.generate(**inputs)
print(tokenizer.decode(output[0], skip_special_tokens=True))
# Pipe the prompt into the transformers CLI. Plain `echo` is used because the prompt
# has no escape sequences and `echo -e` prints a literal "-e" under POSIX sh.
echo "Plants create energy through a process known as" | transformers run --task text-generation --model ai21labs/AI21-Jamba-Mini-1.6 --device 0
양자화는 가중치를 더 낮은 정밀도로 표현하여 대규모 모델의 메모리 부담을 줄여줍니다. 사용할 수 있는 다양한 양자화 백엔드에 대해서는 Quantization 문서를 참고하세요.
์๋ ์์ ๋ bitsandbytes๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ค์น๋ง 8๋นํธ๋ก ์์ํํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Quantize weights to 8-bit, but skip the Mamba modules: quantizing them
# degrades model quality.
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
                                         llm_int8_skip_modules=["mamba"])

# Device map that spreads the model evenly over 8 GPUs. The 72 decoder layers are
# split into contiguous groups of 9 (layers 0-8 on GPU 0, ..., layers 63-71 on
# GPU 7). The embeddings go with the first group; the final norm and LM head go
# with the last. Ceiling division keeps every layer index < num_gpus.
num_gpus = 8
num_layers = 72
layers_per_gpu = (num_layers + num_gpus - 1) // num_gpus
device_map = {
    "model.embed_tokens": 0,
    **{f"model.layers.{i}": i // layers_per_gpu for i in range(num_layers)},
    "model.final_layernorm": num_gpus - 1,
    "lm_head": num_gpus - 1,
}

model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-Large-1.6",
                                             dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             quantization_config=quantization_config,
                                             device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Large-1.6")

messages = [
    {"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."},
    {"role": "user", "content": "Hello!"},
]

input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to(model.device)
outputs = model.generate(input_ids, max_new_tokens=216)

# Extract only the assistant's reply by decoding just the tokens generated after
# the prompt. Splitting the decoded text on the user's message is fragile: it
# truncates if the reply repeats that message and fails if the template alters it.
prompt_length = input_ids.shape[-1]
assistant_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
print(assistant_response)
# Example output: Seek and you shall find. The path is winding, but the journey is enlightening. What wisdom do you seek from the ancient echoes?
참고[[notes]]
모델 성능 저하를 방지하기 위해 Mamba 블록은 양자화하지 마세요.
์ต์ ํ๋ Mamba ์ปค๋ ์์ด Mamba๋ฅผ ์ฌ์ฉํ๋ฉด ์ง์ฐ ์๊ฐ์ด ํฌ๊ฒ ์ฆ๊ฐํ๋ฏ๋ก ๊ถ์ฅ๋์ง ์์ต๋๋ค. ๊ทธ๋๋ ์ปค๋ ์์ด Mamba๋ฅผ ์ฌ์ฉํ๊ณ ์ ํ๋ค๋ฉด [
~AutoModel.from_pretrained]์์use_mamba_kernels=False๋ก ์ค์ ํ์ธ์.import torch from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Large", use_mamba_kernels=False)
JambaConfig[[transformers.JambaConfig]]
[[autodoc]] JambaConfig
JambaModel[[transformers.JambaModel]]
[[autodoc]] JambaModel
    - forward
JambaForCausalLM[[transformers.JambaForCausalLM]]
[[autodoc]] JambaForCausalLM
    - forward
JambaForSequenceClassification[[transformers.JambaForSequenceClassification]]
[[autodoc]] transformers.JambaForSequenceClassification
    - forward