Spaces:
Sleeping
Sleeping
Refactor model loading in train.py to use a default model name parameter, enhancing flexibility. Adjust configuration for max sequence length and dtype for improved clarity and consistency.
Browse files
train.py
CHANGED
|
@@ -41,11 +41,10 @@ from transformers import (
|
|
| 41 |
from trl import SFTTrainer
|
| 42 |
|
| 43 |
# Configuration
|
| 44 |
-
|
| 45 |
-
dtype =
|
| 46 |
-
None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
| 47 |
-
)
|
| 48 |
load_in_4bit = True # Use 4bit quantization to reduce memory usage
|
|
|
|
| 49 |
validation_split = 0.1 # 10% of data for validation
|
| 50 |
|
| 51 |
|
|
@@ -89,12 +88,12 @@ def install_dependencies():
|
|
| 89 |
raise
|
| 90 |
|
| 91 |
|
| 92 |
-
def load_model() -> tuple[FastLanguageModel, AutoTokenizer]:
|
| 93 |
"""Load and configure the model."""
|
| 94 |
logger.info("Loading model and tokenizer...")
|
| 95 |
try:
|
| 96 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 97 |
-
model_name=
|
| 98 |
max_seq_length=max_seq_length,
|
| 99 |
dtype=dtype,
|
| 100 |
load_in_4bit=load_in_4bit,
|
|
|
|
| 41 |
from trl import SFTTrainer
|
| 42 |
|
| 43 |
# Configuration
|
| 44 |
+
DEFAULT_MODEL_NAME = "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
|
| 45 |
+
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
|
|
|
|
|
|
| 46 |
load_in_4bit = True # Use 4bit quantization to reduce memory usage
|
| 47 |
+
max_seq_length = 2048 # Auto supports RoPE Scaling internally
|
| 48 |
validation_split = 0.1 # 10% of data for validation
|
| 49 |
|
| 50 |
|
|
|
|
| 88 |
raise
|
| 89 |
|
| 90 |
|
| 91 |
+
def load_model(model_name: str = DEFAULT_MODEL_NAME) -> tuple[FastLanguageModel, AutoTokenizer]:
|
| 92 |
"""Load and configure the model."""
|
| 93 |
logger.info("Loading model and tokenizer...")
|
| 94 |
try:
|
| 95 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 96 |
+
model_name=model_name,
|
| 97 |
max_seq_length=max_seq_length,
|
| 98 |
dtype=dtype,
|
| 99 |
load_in_4bit=load_in_4bit,
|