<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DBRX[[dbrx]]

## Overview[[overview]]

DBRX is a [transformer-based](https://www.isattentionallyouneed.com/) decoder-only LLM trained with next-token prediction.
It uses a *fine-grained* mixture-of-experts (MoE) architecture with 132B total parameters, of which 36B parameters are active on any given input.
It was pretrained on 12T tokens of text and code data.

Compared to other open MoE models such as Mixtral-8x7B and Grok-1, DBRX is fine-grained, meaning it uses a larger number of smaller experts. DBRX has 16 experts and selects 4, whereas Mixtral-8x7B and Grok-1 have 8 experts and select 2.
This provides 65x more possible combinations of experts, and we found that this improves model quality.
DBRX uses rotary position embeddings (RoPE), gated linear units (GLU), and grouped-query attention (GQA).
It is a BPE-based model and uses the GPT-4 tokenizer as described in the [tiktoken](https://github.com/openai/tiktoken) repository.
These choices were made based on exhaustive evaluation and scaling experiments.

DBRX was pretrained on 12T tokens of carefully curated data, with a maximum context length of 32K tokens.
We estimate that this data is at least 2x better token-for-token than the data used to pretrain the MPT family of models.
This new dataset was developed using the full suite of Databricks tools, including Apache Sparkβ„’ and Databricks notebooks for data processing, and Unity Catalog for data management and governance.
We used curriculum learning for pretraining, and found that changing the data mix during training substantially improves model quality.
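
The exact curriculum and data sources are not public. As a loose sketch of what "changing the data mix during training" means, the following hypothetical schedule shifts sampling weights across made-up data sources as training progresses:

```python
import random

# Hypothetical schedule only: DBRX's actual curriculum and data mix are not published.
def mixture_weights(step, total_steps):
    """Linearly shift sampling weight between (made-up) data sources over training."""
    progress = step / total_steps
    return {
        "web_text": 0.7 - 0.3 * progress,  # start heavy on broad web text
        "code": 0.2 + 0.2 * progress,      # ramp up code later in training
        "curated": 0.1 + 0.1 * progress,   # ramp up high-quality curated data
    }

def sample_source(step, total_steps):
    weights = mixture_weights(step, total_steps)
    sources, probs = zip(*weights.items())
    return random.choices(sources, weights=probs, k=1)[0]

print(mixture_weights(0, 100_000))        # early-training mix
print(mixture_weights(100_000, 100_000))  # late-training mix
```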


More detailed descriptions of DBRX Instruct and DBRX Base can be found in this [technical blog post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm).

이 λͺ¨λΈμ€ [eitan-turok](https://huggingface.co/eitanturok)와 [abhi-db](https://huggingface.co/abhi-db)κ°€ κΈ°μ—¬ν–ˆμŠ΅λ‹ˆλ‹€. 원본 μ½”λ“œλŠ” [이곳](https://github.com/databricks/dbrx-instruct)μ—μ„œ 찾을 수 μžˆμ§€λ§Œ, μ΅œμ‹  버전이 아닐 수 μžˆμŠ΅λ‹ˆλ‹€.

## Usage Examples[[usage-examples]]

The `generate()` method can be used to generate text with DBRX. You can generate using the standard attention implementation, Flash Attention, or PyTorch's scaled dot-product attention (SDPA). The latter two attention implementations provide significant speed-ups.

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    )

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

If you install Flash Attention with `pip install flash-attn`, faster generation is possible. (The HuggingFace documentation for Flash Attention can be found [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).)


```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="flash_attention_2",
    )

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

You can also generate faster using PyTorch's scaled dot-product attention. (The HuggingFace documentation for scaled dot-product attention can be found [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention).)


```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="sdpa",
    )

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

## DbrxConfig[[transformers.DbrxConfig]]

[[autodoc]] DbrxConfig


## DbrxModel[[transformers.DbrxModel]]

[[autodoc]] DbrxModel
    - forward


## DbrxForCausalLM[[transformers.DbrxForCausalLM]]

[[autodoc]] DbrxForCausalLM
    - forward