Add Hydra integration and configuration support in train.py, enabling dynamic model loading and training control. Update requirements.txt to include the hydra-core dependency and introduce conf/config.yaml for model parameters and training settings.
Files changed:
- conf/config.yaml  +6 -0
- requirements.txt  +1 -0
- train.py          +24 -17
conf/config.yaml ADDED
@@ -0,0 +1,6 @@
+defaults:
+  - _self_
+
+model_name: "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
+train: false
+output_dir: "final_model"
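These keys become command-line overrides at launch (e.g. `python train.py train=true output_dir=my_model`). Below is a minimal sketch, assuming the standard hydra-core compose API and execution from the repo root, of loading the same config programmatically, which can be handy in tests or notebooks:

# Minimal sketch: build conf/config.yaml with Hydra's compose API
# (hydra-core >= 1.2); config_path is resolved relative to this file.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="conf"):
    # Overrides use the same syntax as the CLI, e.g. `train=true`.
    cfg = compose(config_name="config", overrides=["train=true"])

print(OmegaConf.to_yaml(cfg))  # model_name, train, output_dir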
requirements.txt CHANGED
@@ -4,6 +4,7 @@ bitsandbytes>=0.45.5
 duckduckgo-search>=8.0.1
 gradio[oauth]>=5.26.0
 hf-xet>=1.0.5
+hydra-core>=1.3.2
 ipywidgets>=8.1.6
 isort>=6.0.1
 jupyter>=1.1.1
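To confirm the new pin resolves in a given environment, a quick stdlib-only check (a sketch; importlib.metadata ships with Python 3.8+):

# Sanity check: verify hydra-core is installed and meets the pin.
from importlib.metadata import PackageNotFoundError, version

try:
    print("hydra-core", version("hydra-core"))  # expect >= 1.3.2
except PackageNotFoundError:
    print("missing; run: pip install -r requirements.txt")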
train.py CHANGED
@@ -19,6 +19,9 @@ from datetime import datetime
 from pathlib import Path
 from typing import Union
 
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
 # isort: off
 from unsloth import FastLanguageModel, is_bfloat16_supported  # noqa: E402
 from unsloth.chat_templates import get_chat_template  # noqa: E402
@@ -41,7 +44,6 @@ from transformers import (
 from trl import SFTTrainer
 
 # Configuration
-DEFAULT_MODEL_NAME = "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
 dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
 load_in_4bit = True  # Use 4bit quantization to reduce memory usage
 max_seq_length = 2048  # Auto supports RoPE Scaling internally
@@ -88,7 +90,7 @@ def install_dependencies():
         raise
 
 
-def load_model(model_name: str = DEFAULT_MODEL_NAME) -> tuple[FastLanguageModel, AutoTokenizer]:
+def load_model(model_name: str) -> tuple[FastLanguageModel, AutoTokenizer]:
     """Load and configure the model."""
     logger.info("Loading model and tokenizer...")
     try:
@@ -241,16 +243,18 @@ def create_trainer(
         raise
 
 
-def main():
+@hydra.main(version_base=None, config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
     """Main training function."""
     try:
         logger.info("Starting training process...")
+        logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")
 
         # Install dependencies
         install_dependencies()
 
         # Load model and tokenizer
-        model, tokenizer = load_model()
+        model, tokenizer = load_model(cfg.model_name)
 
         # Load and prepare dataset
         dataset, tokenizer = load_and_format_dataset(tokenizer)
@@ -258,19 +262,22 @@ def main():
         # Create trainer
         trainer: Trainer = create_trainer(model, tokenizer, dataset)
 
-        # Train
-        logger.info("Starting training...")
-        trainer.train()
-
-        # Save model
-        logger.info("Saving final model...")
-        trainer.save_model("final_model")
-
-        # Print final metrics
-        final_metrics = trainer.state.log_history[-1]
-        logger.info("\nTraining completed!")
-        logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
-        logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
+        # Train if requested
+        if cfg.train:
+            logger.info("Starting training...")
+            trainer.train()
+
+            # Save model
+            logger.info(f"Saving final model to {cfg.output_dir}...")
+            trainer.save_model(cfg.output_dir)
+
+            # Print final metrics
+            final_metrics = trainer.state.log_history[-1]
+            logger.info("\nTraining completed!")
+            logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
+            logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
+        else:
+            logger.info("Training skipped as train=False")
 
     except Exception as e:
         logger.error(f"Error in main training process: {e}")
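One caveat in the new metrics block: trainer.state.log_history[-1] assumes the last entry exists and carries both losses, but with a transformers Trainer the final entry is often train-only (no eval_loss), and the list is empty if no logging step has fired. A hedged sketch of a more defensive lookup; the last_logged helper is hypothetical, not part of this repo:

# Hypothetical helper (not in this repo): scan log_history backwards
# for the most recent entry that actually contains the requested key.
def last_logged(trainer, key: str, default: str = "N/A"):
    # log_history is a list of dicts; training steps log "loss", while
    # only evaluation entries carry "eval_loss".
    for entry in reversed(trainer.state.log_history):
        if key in entry:
            return entry[key]
    return default

# Usage after trainer.train():
#   logger.info(f"Final training loss: {last_logged(trainer, 'loss')}")
#   logger.info(f"Final validation loss: {last_logged(trainer, 'eval_loss')}")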