| """
|
| Training configuration for TouchGrass models.
|
| Covers both 3B and 7B variants with hardware-specific optimizations.
|
| """
|
|
|
import copy

import torch
|
|
|
| TRAINING_CONFIG = {
|
|
|
| "learning_rate": 2e-4,
|
| "weight_decay": 0.1,
|
| "beta1": 0.9,
|
| "beta2": 0.95,
|
| "clip_grad_norm": 1.0,
|
|
|
|
|
| "global_batch_size": 512,
|
| "micro_batch_size": 8,
|
| "gradient_accumulation_steps": 4,
|
|
|
|
|
| "max_steps": 50000,
|
| "warmup_steps": 2000,
|
| "save_interval": 5000,
|
| "eval_interval": 1000,
|
| "log_interval": 100,
|
|
|
|
|
| "use_amp": True,
|
| "amp_dtype": torch.bfloat16,
|
|
|
|
|
| "optimizer": "AdamW",
|
| "use_fused": True,
|
|
|
|
|
| "loss_weights": {
|
| "lm_loss": 1.0,
|
| "eq_loss": 0.1,
|
| "music_module_loss": 0.05,
|
| },
|
|
|
|
|
| "checkpoint_dir": "checkpoints",
|
| "save_optimizer_state": True,
|
| "save_scheduler_state": True,
|
|
|
|
|
| "log_dir": "logs",
|
| "use_wandb": False,
|
| "wandb_project": "touchgrass-music",
|
|
|
|
|
| "num_workers": 8,
|
| "prefetch_factor": 2,
|
| "pin_memory": True,
|
|
|
|
|
| "device": "cuda",
|
| "use_mps": False,
|
|
|
|
|
| "quantization": None,
|
| }
|
|
|
|
|
| TRAINING_CONFIG_3B_CUDA = TRAINING_CONFIG.copy()
|
| TRAINING_CONFIG_3B_CUDA.update({
|
| "device": "cuda",
|
| "quantization": None,
|
| "micro_batch_size": 8,
|
| })
|
|
|
| TRAINING_CONFIG_7B_CUDA = TRAINING_CONFIG.copy()
|
| TRAINING_CONFIG_7B_CUDA.update({
|
| "device": "cuda",
|
| "quantization": None,
|
| "micro_batch_size": 4,
|
| })
|
|
|
| TRAINING_CONFIG_MPS = TRAINING_CONFIG.copy()
|
| TRAINING_CONFIG_MPS.update({
|
| "device": "mps",
|
| "use_mps": True,
|
| "use_amp": False,
|
| "micro_batch_size": 4,
|
| }) |