DBRX[[dbrx]]
๊ฐ์[[overview]]
DBRX๋ ํธ๋์คํฌ๋จธ ๊ธฐ๋ฐ์ ๋ค์ ํ ํฐ์ ์์ธกํ๋ ๋์ฝ๋ ์ ์ฉ LLM ๋ชจ๋ธ์ ๋๋ค. ์ด 132B ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ์ง ์ธ๋ฐํ ์ ๋ฌธ๊ฐ ํผํฉ(MoE) ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ๋ฉฐ, ์ด ์ค 36B ๋งค๊ฐ๋ณ์๊ฐ ์ ๋ ฅ๋ง๋ค ํ์ฑํ๋ฉ๋๋ค. 12T ํ ํฐ์ ํ ์คํธ์ ์ฝ๋ ๋ฐ์ดํฐ๋ก ์ฌ์ ํ์ต๋์์ต๋๋ค.
Mixtral-8x7B์ Grok-1๊ณผ ๊ฐ์ ๋ค๋ฅธ ๊ณต๊ฐ MoE ๋ชจ๋ธ๋ค๊ณผ ๋น๊ตํ์ ๋, DBRX๋ ๋ ๋ง์ ์์ ์์ ์ ๋ฌธ๊ฐ๋ค์ ์ฌ์ฉํ๋ ์ธ๋ฐํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. DBRX๋ 16๊ฐ์ ์ ๋ฌธ๊ฐ ์ค 4๊ฐ๋ฅผ ์ ํํ๋ ๋ฐ๋ฉด, Mixtral-8x7B์ Grok-1์ 8๊ฐ์ ์ ๋ฌธ๊ฐ ์ค 2๊ฐ๋ฅผ ์ ํํฉ๋๋ค.
์ด๋ 65๋ฐฐ ๋ ๋ง์ ์ ๋ฌธ๊ฐ ์กฐํฉ์ ๊ฐ๋ฅํ๊ฒ ํ๋ฉฐ, ์ด๋ฅผ ํตํด ๋ชจ๋ธ์ ํ์ง์ด ํฅ์๋๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค. DBRX๋ ํ์ ์์น ์ธ์ฝ๋ฉ(RoPE), ๊ฒ์ดํธ ์ ํ ์ ๋(GLU), ๊ทธ๋ฃน ์ฟผ๋ฆฌ ์ดํ ์ (GQA)์ ์ฌ์ฉํฉ๋๋ค. BPE ๊ธฐ๋ฐ ๋ชจ๋ธ์ด๋ฉฐ tiktoken ์ ์ฅ์์ ์ค๋ช ๋ GPT-4 ํ ํฌ๋์ด์ ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด๋ฌํ ์ ํ๋ค์ ์ฒ ์ ํ ํ๊ฐ์ ์ค์ผ์ผ๋ง ์คํ์ ๊ธฐ๋ฐ์ผ๋ก ์ด๋ฃจ์ด์ก์ต๋๋ค.
DBRX๋ ์ ์คํ๊ฒ ์ ๋ณ๋ 12T ํ ํฐ์ ๋ฐ์ดํฐ๋ก ์ฌ์ ํ์ต๋์์ผ๋ฉฐ, ์ต๋ ๋ฌธ๋งฅ ๊ธธ์ด๋ 32K ํ ํฐ์ ๋๋ค. ์ด ๋ฐ์ดํฐ๋ ํ ํฐ ๋๋น MPT ๊ณ์ด ๋ชจ๋ธ ํ์ต์ ์ฌ์ฉ๋ ๋ฐ์ดํฐ๋ณด๋ค ์ต์ 2๋ฐฐ ์ด์ ๋ ์ข์ ๊ฒ์ผ๋ก ์ถ์ ๋ฉ๋๋ค. ์ด ์๋ก์ด ๋ฐ์ดํฐ์ ์ ๋ฐ์ดํฐ ์ฒ๋ฆฌ๋ฅผ ์ํ Apache Sparkโข์ Databricks ๋ ธํธ๋ถ, ๊ทธ๋ฆฌ๊ณ ๋ฐ์ดํฐ ๊ด๋ฆฌ์ ๊ฑฐ๋ฒ๋์ค๋ฅผ ์ํ Unity Catalog๋ฅผ ํฌํจํ Databricks ๋๊ตฌ ์ ์ฒด๋ฅผ ํ์ฉํ์ฌ ๊ฐ๋ฐ๋์์ต๋๋ค. ์ฐ๋ฆฌ๋ ์ฌ์ ํ์ต์ ์ํด ์ปค๋ฆฌํ๋ผ ํ์ต์ ์ฌ์ฉํ์ผ๋ฉฐ, ํ์ต ์ค ๋ฐ์ดํฐ ๋ฏน์ค๋ฅผ ๋ณ๊ฒฝํ๋ ๋ฐฉ์์ด ๋ชจ๋ธ ํ์ง์ ์๋นํ ๊ฐ์ ํ๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค.
DBRX Instruct์ DBRX Base์ ๋ํ ๋ ์์ธํ ์ ๋ณด๋ ์ด ๊ธฐ์ ๋ธ๋ก๊ทธ ํฌ์คํธ์์ ํ์ธํ ์ ์์ต๋๋ค.
์ด ๋ชจ๋ธ์ eitan-turok์ abhi-db๊ฐ ๊ธฐ์ฌํ์ต๋๋ค. ์๋ณธ ์ฝ๋๋ ์ด๊ณณ์์ ์ฐพ์ ์ ์์ง๋ง, ์ต์ ๋ฒ์ ์ด ์๋ ์ ์์ต๋๋ค.
์ฌ์ฉ ์[[usage-examples]]
generate() ๋ฉ์๋๋ DBRX๋ฅผ ์ฌ์ฉํ์ฌ ํ
์คํธ๋ฅผ ์์ฑํ๋ ๋ฐ ์ฌ์ฉ๋ ์ ์์ต๋๋ค. ํ์ค ์ดํ
์
๊ตฌํ, ํ๋์ ์ดํ
์
, PyTorch์ ์ค์ผ์ผ๋ ๋ด์ ์ดํ
์
(Scaled Dot-Product Attention)์ ์ฌ์ฉํ์ฌ ์์ฑํ ์ ์์ต๋๋ค. ํ์์ ๋ ์ดํ
์
๊ตฌํ ๋ฐฉ์์ ์ฒ๋ฆฌ ์๋๋ฅผ ํฌ๊ฒ ๋์ฌ์ค๋๋ค.
from transformers import DbrxForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
"databricks/dbrx-instruct",
device_map="auto",
torch_dtype=torch.bfloat16,
token="YOUR_HF_TOKEN",
)
input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
pip install flash-attn๋ฅผ ํตํด ํ๋์ ์ดํ
์
์ ์ค์นํ๋ฉด, ๋ ๋น ๋ฅธ ์์ฑ์ด ๊ฐ๋ฅํฉ๋๋ค. (ํ๋์ ์ดํ
์
์ ๋ํ HuggingFace ๋ฌธ์๋ ์ด๊ณณ์์ ํ์ธํ ์ ์์ต๋๋ค.)
from transformers import DbrxForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
"databricks/dbrx-instruct",
device_map="auto",
torch_dtype=torch.bfloat16,
token="YOUR_HF_TOKEN",
attn_implementation="flash_attention_2",
)
input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
PyTorch์ ์ค์ผ์ผ๋ ๋ด์ ์ดํ ์ ์ ์ฌ์ฉํ์ฌ๋ ๋ ๋น ๋ฅธ ์์ฑ์ด ๊ฐ๋ฅํฉ๋๋ค. (์ค์ผ์ผ๋ ๋ด์ ์ดํ ์ ์ ๋ํ HuggingFace ๋ฌธ์๋ ์ด๊ณณ์์ ํ์ธํ ์ ์์ต๋๋ค.)
from transformers import DbrxForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
"databricks/dbrx-instruct",
device_map="auto",
torch_dtype=torch.bfloat16,
token="YOUR_HF_TOKEN",
attn_implementation="sdpa",
)
input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
DbrxConfig[[transformers.DbrxConfig]]
[[autodoc]] DbrxConfig
DbrxModel[[transformers.DbrxModel]]
[[autodoc]] DbrxModel - forward
DbrxForCausalLM[[transformers.DbrxForCausalLM]]
[[autodoc]] DbrxForCausalLM - forward