Updated README.md.
README.md CHANGED

@@ -1,29 +1,15 @@
 ---
-extra_gated_heading: You need to share contact information with Databricks to access this model
-extra_gated_prompt: >-
-
-  ### DBRX Terms of Use
-
-  Use of DBRX is governed by the [Databricks Open Model License](https://www.databricks.com/legal/open-model-license) and the [Databricks Open Model Acceptable Use Policy](https://www.databricks.com/legal/acceptable-use-policy-open-model).
-
-extra_gated_fields:
-  First Name: text
-  Last Name: text
-  Organization: text
-  Purpose for Base Model Access: text
-  By clicking 'Submit' below, I accept the terms of the license and acknowledge that the information I provide will be collected, stored, processed, and shared in accordance with Databricks' Privacy Notice and I understand I can update my preferences at any time: checkbox
-extra_gated_description: >-
-  The information you provide will be collected, stored, processed, and shared in accordance with Databricks [Privacy Notice](https://www.databricks.com/legal/privacynotice).
-extra_gated_button_content: Submit
 inference: false
 license: other
 license_name: databricks-open-model-license
 license_link: https://www.databricks.com/legal/open-model-license
 ---
 
-
-
-
 
 # DBRX Base
 
@@ -86,8 +72,8 @@ export HF_HUB_ENABLE_HF_TRANSFER=1
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-tokenizer = AutoTokenizer.from_pretrained("
-model = AutoModelForCausalLM.from_pretrained("
 
 input_text = "Databricks was founded in "
 input_ids = tokenizer(input_text, return_tensors="pt")
@@ -101,8 +87,8 @@ print(tokenizer.decode(outputs[0]))
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-tokenizer = AutoTokenizer.from_pretrained("
-model = AutoModelForCausalLM.from_pretrained("
 
 input_text = "Databricks was founded in "
 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 ---
 inference: false
 license: other
 license_name: databricks-open-model-license
 license_link: https://www.databricks.com/legal/open-model-license
 ---
+# Fix for the DBRX Code
+The original DBRX implementation code has a few bugs, all of which only affect training; I fixed them in this re-upload.
 
+The issues and how I fixed them:
+1. Error when using gradient checkpointing - fixed by passing positional arguments instead, because `_gradient_checkpointing_func` doesn't support kwargs.
+2. VRAM usage blows up and `CUDA Out of Memory` is raised when backpropagating through the MLP layer - fixed by storing each expert's weights in a separate tensor instead of a single stacked tensor for all the experts. I'm not certain why this fixes it, but **maybe** it's because torch allocates a gradient for the entire stacked tensor at once, even though only a few experts are active per token in a MoE model.
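Issue 2 above can be illustrated with a toy snippet. This is not the actual DBRX module - the names and shapes below are made up for demonstration - but it shows the mechanism: with one stacked tensor, autograd materializes a gradient buffer sized for *every* expert even when only one expert ran, while per-expert tensors only get gradients when actually used.

```python
import torch

torch.manual_seed(0)
num_experts, d = 4, 8
x = torch.randn(2, d)

# Stacked layout: one parameter holds every expert's weight matrix.
stacked = torch.nn.Parameter(torch.randn(num_experts, d, d))
(x @ stacked[0]).sum().backward()
# Autograd allocates a gradient buffer covering the WHOLE stacked
# parameter, even though only expert 0 was used in the forward pass.
print(stacked.grad.shape)  # torch.Size([4, 8, 8])

# Split layout: one parameter per expert.
experts = [torch.nn.Parameter(torch.randn(d, d)) for _ in range(num_experts)]
(x @ experts[0]).sum().backward()
# Only the expert that actually ran gets a gradient; the rest stay None.
print([e.grad is not None for e in experts])  # [True, False, False, False]
```

At DBRX scale (16 large experts), a full-size gradient buffer for the stacked tensor is plausibly enough to trigger the OOM on its own.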
 
 # DBRX Base
 
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
+tokenizer = AutoTokenizer.from_pretrained("v2ray/dbrx-base-fixed", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained("v2ray/dbrx-base-fixed", device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True)
 
 input_text = "Databricks was founded in "
 input_ids = tokenizer(input_text, return_tensors="pt")
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
+tokenizer = AutoTokenizer.from_pretrained("v2ray/dbrx-base-fixed", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained("v2ray/dbrx-base-fixed", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
 
 input_text = "Databricks was founded in "
 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")