Spaces:
Runtime error
Runtime error
menouar
commited on
Commit
·
09bdd6c
1
Parent(s):
611507d
Use FalconForCausalLM for falcon finetuning instead of AutoModelForCausalLM
Browse files
utils/notebook_generator.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import Optional
|
|
| 2 |
|
| 3 |
import nbformat as nbf
|
| 4 |
|
| 5 |
-
from utils import FTDataSet
|
| 6 |
|
| 7 |
|
| 8 |
def create_install_libraries_cells(cells: list):
|
|
@@ -130,9 +130,13 @@ def create_model_cells(cells: list, model_id: str, version: str, flash_attention
|
|
| 130 |
if pad_value is None:
|
| 131 |
pad_value_str = ""
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
code = f"""
|
| 134 |
import torch
|
| 135 |
-
from transformers import AutoTokenizer,
|
| 136 |
from trl import setup_chat_format
|
| 137 |
|
| 138 |
# Hugging Face model id
|
|
@@ -145,7 +149,7 @@ bnb_config = BitsAndBytesConfig(
|
|
| 145 |
)
|
| 146 |
|
| 147 |
# Load model and tokenizer
|
| 148 |
-
model =
|
| 149 |
model_id,
|
| 150 |
device_map="auto",
|
| 151 |
trust_remote_code=True,
|
|
|
|
| 2 |
|
| 3 |
import nbformat as nbf
|
| 4 |
|
| 5 |
+
from utils import FTDataSet, falcon
|
| 6 |
|
| 7 |
|
| 8 |
def create_install_libraries_cells(cells: list):
|
|
|
|
| 130 |
if pad_value is None:
|
| 131 |
pad_value_str = ""
|
| 132 |
|
| 133 |
+
auto_model_import = "AutoModelForCausalLM"
|
| 134 |
+
if model_id == falcon.name:
|
| 135 |
+
auto_model_import = "FalconForCausalLM"
|
| 136 |
+
|
| 137 |
code = f"""
|
| 138 |
import torch
|
| 139 |
+
from transformers import AutoTokenizer, {auto_model_import}, BitsAndBytesConfig
|
| 140 |
from trl import setup_chat_format
|
| 141 |
|
| 142 |
# Hugging Face model id
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
# Load model and tokenizer
|
| 152 |
+
model = {auto_model_import}.from_pretrained(
|
| 153 |
model_id,
|
| 154 |
device_map="auto",
|
| 155 |
trust_remote_code=True,
|