EM_TEST2 / util /base_train_config.py
Jack-Payne1's picture
Upload folder using huggingface_hub
256ee59 verified
import os
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator
class TrainingConfig(BaseModel):
class Config:
extra = "forbid" # Prevent extra fields not defined in the model
# Required model and data paths
model: str = Field(..., description="Hugging Face model ID")
training_file: str = Field(..., description="File ID of the training dataset")
test_file: Optional[str] = Field(None, description="File ID of the test dataset")
# Output model
finetuned_model_id: str = Field('{org_id}/{model_name}-{job_id}', description="File ID of the finetuned model")
# Model configuration
max_seq_length: int = Field(2048, description="Maximum sequence length for training")
load_in_4bit: bool = Field(False, description="Whether to load model in 4-bit quantization")
# Training type configuration
loss: Literal["dpo", "orpo", "sft"] = Field(..., description="Loss function / training type")
# PEFT configuration
is_peft: bool = Field(True, description="Whether to use PEFT for training")
target_modules: Optional[List[str]] = Field(
default=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
description="Target modules for LoRA"
)
layers_to_transform: Optional[List[int]] = Field(
None,
description="Layers to transform for LoRA. If None, all layers will be transformed."
)
lora_bias: Literal["all", "none"] = Field("none", description="Value for FastLanguageModel.get_peft_model(bias=?)")
# LoRA specific arguments
r: int = Field(16, description="LoRA attention dimension")
lora_alpha: int = Field(16, description="LoRA alpha parameter")
lora_dropout: float = Field(0.0, description="LoRA dropout rate")
use_rslora: bool = Field(True, description="Whether to use RSLoRA")
merge_before_push: bool = Field(True, description="Whether to merge model before pushing to Hub. Only merged models can be used as parent models for further finetunes. Only supported for bf16 models.")
push_to_private: bool = Field(True, description="Whether to push to private Hub")
push_only_adapters: bool = Field(False, description="Whether to push only the LoRA adapters to Hub instead of the full model")
# Training hyperparameters
epochs: int = Field(1, description="Number of training epochs")
max_steps: Optional[int] = Field(None, description="Maximum number of training steps")
per_device_train_batch_size: int = Field(2, description="Training batch size per device")
gradient_accumulation_steps: int = Field(8, description="Number of gradient accumulation steps")
warmup_steps: int = Field(5, description="Number of warmup steps")
learning_rate: Union[float, str] = Field(1e-4, description="Learning rate or string expression")
logging_steps: int = Field(1, description="Number of steps between logging")
evaluation_steps: int = Field(50, description="Number of steps between evaluations on the test set")
optim: str = Field("adamw_8bit", description="Optimizer to use for training")
weight_decay: float = Field(0.01, description="Weight decay rate")
lr_scheduler_type: str = Field("linear", description="Learning rate scheduler type")
seed: int = Field(3407, description="Random seed for reproducibility")
beta: float = Field(0.1, description="Beta parameter for DPO/ORPO training")
save_steps: int = Field(5000, description="Save checkpoint every X steps")
output_dir: str = Field("./tmp", description="Output directory for training checkpoints")
train_on_responses_only: bool = Field(False, description="Whether to train on responses only")
@model_validator(mode="before")
def validate_training_file_prefixes(cls, values):
loss = values.get('loss', 'orpo')
training_file = values.get('training_file')
if os.path.exists(training_file):
return values
# if loss == 'sft' and not training_file.startswith('conversations'):
# raise ValueError(f"For SFT training, dataset filename must start with 'conversations', got: {training_file}")
if loss in ['dpo', 'orpo'] and not training_file.startswith('preference'):
raise ValueError(f"For DPO/ORPO training, dataset filename must start with 'preference', got: {training_file}")
return values
@field_validator("finetuned_model_id")
def validate_finetuned_model_id(cls, v):
# if v and model_exists(v):
# raise ValueError(f"Model {v} already exists")
if len(v.split("/")) != 2:
raise ValueError("Model ID must be in the format 'user/model'")
org, model = v.split("/")
if org in ["datasets", "models", "unsloth", "None"]:
raise ValueError(f"You have set org={org}, but it must be an org you have access to")
return v
@field_validator("learning_rate", mode="before")
def validate_learning_rate(cls, v):
if isinstance(v, float) and v <= 0:
raise ValueError("Learning rate must be positive")
return v
@field_validator("lora_dropout")
def validate_dropout(cls, v):
if not 0 <= v <= 1:
raise ValueError("Dropout rate must be between 0 and 1")
return v
@field_validator("optim")
def validate_optimizer(cls, v):
allowed_optimizers = ["adamw_8bit", "adamw", "adam", "sgd"]
if v not in allowed_optimizers:
raise ValueError(f"Optimizer must be one of {allowed_optimizers}")
return v
@field_validator("lr_scheduler_type")
def validate_scheduler(cls, v):
allowed_schedulers = ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
if v not in allowed_schedulers:
raise ValueError(f"Scheduler must be one of {allowed_schedulers}")
return v