# TouchGrass-3b / configs / training_config.py
"""
Training configuration for TouchGrass models.
Covers both 3B and 7B variants with hardware-specific optimizations.
"""
import copy

import torch
# Base training configuration shared by all TouchGrass variants.
# Hardware-specific configs below are derived from this dict; every value
# here can be overridden per variant.
TRAINING_CONFIG = {
    # --- Training hyperparameters ---
    "learning_rate": 2e-4,      # LoRA learning rate
    "weight_decay": 0.1,
    "beta1": 0.9,               # AdamW first-moment decay
    "beta2": 0.95,              # AdamW second-moment decay
    "clip_grad_norm": 1.0,      # global gradient-norm clipping threshold
    # --- Batch sizing ---
    # NOTE(review): the original comment labels this "tokens per batch", but
    # micro_batch_size * gradient_accumulation_steps = 32 suggests it may be
    # sequences instead -- confirm against the data loader before relying on it.
    "global_batch_size": 512,
    "micro_batch_size": 8,      # per GPU
    "gradient_accumulation_steps": 4,
    # --- Training schedule (all counts are optimizer steps) ---
    "max_steps": 50000,
    "warmup_steps": 2000,
    "save_interval": 5000,
    "eval_interval": 1000,
    "log_interval": 100,
    # --- Mixed precision ---
    "use_amp": True,
    "amp_dtype": torch.bfloat16,
    # --- Optimizer ---
    "optimizer": "AdamW",
    "use_fused": True,          # fused AdamW kernel (CUDA)
    # --- Loss weights (music-aware loss) ---
    "loss_weights": {
        "lm_loss": 1.0,
        "eq_loss": 0.1,             # Frustration detection loss
        "music_module_loss": 0.05,  # Music module auxiliary losses
    },
    # --- Checkpointing ---
    "checkpoint_dir": "checkpoints",
    "save_optimizer_state": True,
    "save_scheduler_state": True,
    # --- Logging ---
    "log_dir": "logs",
    "use_wandb": False,
    "wandb_project": "touchgrass-music",
    # --- Data loading ---
    "num_workers": 8,
    "prefetch_factor": 2,
    "pin_memory": True,
    # --- Device configuration ---
    "device": "cuda",
    "use_mps": False,
    # --- Quantization ---
    "quantization": None,       # None, "int8", "int4"
}


def _derive(base: dict, **overrides) -> dict:
    """Return a hardware-specific variant of *base* with *overrides* applied.

    Uses ``copy.deepcopy`` so nested mutable values (e.g. ``loss_weights``)
    are NOT shared between variants. The previous ``dict.copy()`` was a
    shallow copy: all four configs aliased the same inner ``loss_weights``
    dict, so an in-place tweak to one variant silently mutated them all.
    """
    cfg = copy.deepcopy(base)
    cfg.update(overrides)
    return cfg


# Hardware-specific overrides
TRAINING_CONFIG_3B_CUDA = _derive(
    TRAINING_CONFIG,
    device="cuda",
    quantization=None,
    micro_batch_size=8,
)

TRAINING_CONFIG_7B_CUDA = _derive(
    TRAINING_CONFIG,
    device="cuda",
    quantization=None,
    micro_batch_size=4,  # 7B needs smaller batch
)

TRAINING_CONFIG_MPS = _derive(
    TRAINING_CONFIG,
    device="mps",
    use_mps=True,
    use_amp=False,  # AMP disabled on MPS (matches original behavior)
    micro_batch_size=4,
)