<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# DBRX[[dbrx]]
## Overview[[overview]]

DBRX is a [transformer-based](https://www.isattentionallyouneed.com/) decoder-only large language model (LLM) trained with next-token prediction.
It uses a *fine-grained* mixture-of-experts (MoE) architecture with 132B total parameters, of which 36B are active for any given input.
It was pretrained on 12T tokens of text and code data.
Compared to other open MoE models such as Mixtral-8x7B and Grok-1, DBRX is fine-grained, meaning it uses a larger number of smaller experts: DBRX selects 4 of 16 experts, while Mixtral-8x7B and Grok-1 select 2 of 8.
This provides 65x more possible combinations of experts, which we found improves model quality.
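The 65x figure follows directly from counting the possible expert subsets per token; a quick arithmetic check:

```python
from math import comb

# DBRX routes each token to 4 of its 16 experts; Mixtral-8x7B and
# Grok-1 route each token to 2 of 8 experts.
dbrx_combinations = comb(16, 4)      # number of possible expert subsets for DBRX
mixtral_combinations = comb(8, 2)    # number of possible expert subsets for Mixtral/Grok-1

print(dbrx_combinations, mixtral_combinations, dbrx_combinations // mixtral_combinations)
# 1820 28 65
```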
DBRX uses rotary position encodings (RoPE), gated linear units (GLU), and grouped-query attention (GQA).
It is a BPE-based model and uses the GPT-4 tokenizer as described in the [tiktoken](https://github.com/openai/tiktoken) repository.
These choices were made based on exhaustive evaluation and scaling experiments.
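Of the components above, grouped-query attention is easy to sketch: several query heads share each key/value head, shrinking the KV cache relative to full multi-head attention. A minimal NumPy illustration with arbitrary head counts and dimensions (not DBRX's actual configuration):

```python
import numpy as np

def grouped_query_attention(q, k, v, n_kv_heads):
    """Minimal GQA sketch: query heads share key/value heads in groups.
    q: (n_heads, seq, d); k, v: (n_kv_heads, seq, d)."""
    n_heads, seq, d = q.shape
    group = n_heads // n_kv_heads
    # Repeat each KV head so every query head in a group attends to it
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    # Numerically stable softmax over the key axis
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

# 8 query heads sharing 2 KV heads (illustrative sizes only)
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 4, 16))
k = rng.standard_normal((2, 4, 16))
v = rng.standard_normal((2, 4, 16))
out = grouped_query_attention(q, k, v, n_kv_heads=2)
print(out.shape)  # (8, 4, 16)
```

Here only 2 KV heads are stored instead of 8, which is the memory saving GQA provides at inference time.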
DBRX was pretrained on 12T tokens of carefully curated data with a maximum context length of 32K tokens.
We estimate that this data is at least 2x better token-for-token than the data used to pretrain the MPT family of models.
This new dataset was developed using the full suite of Databricks tools, including Apache Spark™ and Databricks notebooks for data processing, and Unity Catalog for data management and governance.
We used curriculum learning for pretraining, changing the data mix during training in a way we found to substantially improve model quality.

More detailed information about DBRX Instruct and DBRX Base can be found in this [technical blog post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm).

This model was contributed by [eitan-turok](https://huggingface.co/eitanturok) and [abhi-db](https://huggingface.co/abhi-db). The original code can be found [here](https://github.com/databricks/dbrx-instruct), though it may not be up to date.
## Usage examples[[usage-examples]]

The `generate()` method can be used to generate text with DBRX. You can generate with the standard attention implementation, Flash Attention, or PyTorch scaled dot-product attention (SDPA). The latter two attention implementations provide significant speedups.
```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```
If you install Flash Attention with `pip install flash-attn`, faster generation is possible. (The HuggingFace documentation for Flash Attention is available [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).)
```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="flash_attention_2",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```
You can also achieve faster generation using PyTorch scaled dot-product attention. (The HuggingFace documentation for scaled dot-product attention is available [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention).)
```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="sdpa",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```
## DbrxConfig[[transformers.DbrxConfig]]

[[autodoc]] DbrxConfig

## DbrxModel[[transformers.DbrxModel]]

[[autodoc]] DbrxModel
    - forward

## DbrxForCausalLM[[transformers.DbrxForCausalLM]]

[[autodoc]] DbrxForCausalLM
    - forward