
# DBRX[[dbrx]]

## Overview[[overview]]

DBRX is a transformer-based, decoder-only LLM trained with next-token prediction. It uses a fine-grained mixture-of-experts (MoE) architecture with 132B total parameters, of which 36B are active for any given input. It was pretrained on 12T tokens of text and code data.

Compared to other open MoE models such as Mixtral-8x7B and Grok-1, DBRX is fine-grained, meaning it uses a larger number of smaller experts. DBRX selects 4 of 16 experts, while Mixtral-8x7B and Grok-1 select 2 of 8.

This provides 65x more possible combinations of experts, which was found to improve model quality. DBRX uses rotary position encodings (RoPE), gated linear units (GLU), and grouped-query attention (GQA). It is a BPE-based model and uses the GPT-4 tokenizer as described in the tiktoken repository. These choices were made based on exhaustive evaluation and scaling experiments.
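The "65x" figure follows directly from counting expert subsets: choosing 4 of 16 experts admits far more routing combinations per layer than choosing 2 of 8. A quick check with Python's standard library:

```python
from math import comb

# Number of distinct expert subsets a single MoE layer can route to.
dbrx_combinations = comb(16, 4)     # DBRX: choose 4 of 16 experts -> 1820
mixtral_combinations = comb(8, 2)   # Mixtral-8x7B / Grok-1: choose 2 of 8 -> 28

ratio = dbrx_combinations // mixtral_combinations
print(dbrx_combinations, mixtral_combinations, ratio)  # 1820 28 65
```

The ratio 1820 / 28 = 65 is where the 65x number comes from.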

DBRX was pretrained on 12T tokens of carefully curated data with a maximum context length of 32K tokens. This data is estimated to be at least 2x better token-for-token than the data used to train the MPT family of models. This new dataset was developed using the full suite of Databricks tools, including Apache Spark™ and Databricks notebooks for data processing, and Unity Catalog for data management and governance. Curriculum learning was used for pretraining, changing the data mix during training in a way that was found to substantially improve model quality.
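To make the curriculum-learning idea concrete, here is a minimal sketch of a staged data-mix schedule. The source names, stage boundary, and weights below are purely illustrative assumptions, not DBRX's actual training recipe, which has not been published in this document:

```python
# Hypothetical sketch: sampling weights over data sources change as
# pretraining progresses. All numbers here are made up for illustration.
def data_mix(tokens_seen: int) -> dict[str, float]:
    """Return per-source sampling weights at a given point in training."""
    if tokens_seen < 6_000_000_000_000:
        # Early stage: emphasize broad web text.
        return {"web": 0.70, "code": 0.20, "academic": 0.10}
    # Late stage: shift weight toward code and academic text.
    return {"web": 0.50, "code": 0.35, "academic": 0.15}

early = data_mix(1_000_000_000_000)   # mix used early in training
late = data_mix(10_000_000_000_000)   # mix used late in training
print(early["code"], late["code"])    # 0.2 0.35
```

The key property of such a schedule is that the weights at any stage sum to 1, while the emphasis shifts as training progresses.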

More detailed information about DBRX Instruct and DBRX Base can be found in the technical blog post.

์ด ๋ชจ๋ธ์€ eitan-turok์™€ abhi-db๊ฐ€ ๊ธฐ์—ฌํ–ˆ์Šต๋‹ˆ๋‹ค. ์›๋ณธ ์ฝ”๋“œ๋Š” ์ด๊ณณ์—์„œ ์ฐพ์„ ์ˆ˜ ์žˆ์ง€๋งŒ, ์ตœ์‹  ๋ฒ„์ „์ด ์•„๋‹ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

## Usage Examples[[usage-examples]]

The `generate()` method can be used to generate text with DBRX. You can generate using the standard attention implementation, flash-attention, or PyTorch's scaled dot-product attention (SDPA). The latter two implementations give significant speed-ups.

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

If you install flash-attention via `pip install flash-attn`, faster generation is possible. (The HuggingFace documentation for flash-attention can be found here.)

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="flash_attention_2",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

PyTorch์˜ ์Šค์ผ€์ผ๋œ ๋‚ด์  ์–ดํ…์…˜์„ ์‚ฌ์šฉํ•˜์—ฌ๋„ ๋” ๋น ๋ฅธ ์ƒ์„ฑ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. (์Šค์ผ€์ผ๋œ ๋‚ด์  ์–ดํ…์…˜์— ๋Œ€ํ•œ HuggingFace ๋ฌธ์„œ๋Š” ์ด๊ณณ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.)

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="sdpa",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

## DbrxConfig[[transformers.DbrxConfig]]

[[autodoc]] DbrxConfig

## DbrxModel[[transformers.DbrxModel]]

[[autodoc]] DbrxModel
    - forward

## DbrxForCausalLM[[transformers.DbrxForCausalLM]]

[[autodoc]] DbrxForCausalLM
    - forward