# Tab Hero - Default Configuration
#
# This configuration file uses Hydra for hierarchical config management.
# Override values via command line: python train.py model.encoder_dim=768
#
# Architecture optimized for:
#   - Arbitrary-length song generation (RoPE + streaming)
#   - Memory efficiency (Flash Attention, bf16, gradient accumulation)
#   - 32GB VRAM (RTX 5090)
---
# Data configuration
data:
  data_dir: "data/processed"  # Directory containing .tab files
  use_tab_format: true  # Use preprocessed .tab files (faster, recommended)
  use_chunked_dataset: true  # Use ChunkedTabDataset for long songs
  train_split: "train"  # Split to use for training (from manifest.json)
  val_split: "val"  # Split to use for validation (from manifest.json)
  instrument: "lead"  # Only used when use_tab_format=false
  difficulty: "expert"  # Only used when use_tab_format=false
  max_audio_duration_s: 300.0  # Max audio length - 5 min (raw mode only)
  max_sequence_length: 8192  # Max token sequence length (increased for long songs)
  max_mel_frames: 8192  # Max mel frames per chunk (.tab mode only)
  chunk_overlap_frames: 512  # Overlap between chunks for continuity
  batch_size: 16  # Increased from 8 for better GPU utilization
  num_workers: 8  # Set >0 for multiprocessing
  prefetch_factor: 2  # Batches to prefetch per worker
  max_samples: null  # Set to limit samples for testing (e.g., 100)
  curriculum_learning: false  # Enable difficulty-based curriculum learning
  curriculum_schedule:  # List of [epoch, max_difficulty] pairs
    - [0, 1]   # Epochs 0-9: easy + medium
    - [10, 2]  # Epochs 10-24: + hard
    - [25, 3]  # Epochs 25+: all difficulties
# Audio processing (must match preprocessing MEL_CONFIG)
audio:
  sample_rate: 22050  # Audio sample rate (Hz)
  hop_length: 256  # Mel spectrogram hop length (samples)
  n_mels: 128  # Number of mel bands
  # NOTE(review): sample_rate / hop_length = 22050 / 256 = 86.1328125;
  # this value is rounded. Confirm consumers tolerate the drift, or use
  # the exact quotient if frame alignment matters over long songs.
  frame_rate: 86.1  # Approximately sample_rate / hop_length
# Model architecture (Large: ~100M params)
model:
  audio_input_dim: 128  # Mel spectrogram bins (must match audio.n_mels)
  encoder_dim: 768  # Increased from 512 for more capacity
  decoder_dim: 768  # Increased from 512 for more capacity
  n_decoder_layers: 8  # Increased from 6 for deeper network
  n_heads: 12  # Increased from 8 for finer attention
  ffn_dim: 3072  # Increased from 2048 (4x decoder_dim)
  max_seq_len: 8192  # Increased for long songs
  dropout: 0.1
  audio_downsample: 4  # 4x downsampling (was 2x) for longer context
  use_flash: true  # Flash Attention 2 for O(n) memory
  use_rope: true  # RoPE for length extrapolation (critical for long songs)
  gradient_checkpointing: true  # Keep enabled to fit larger model in VRAM
# Training configuration
training:
  learning_rate: 1.0e-4
  weight_decay: 0.01
  max_epochs: 100  # Increased from 50 for better convergence
  warmup_steps: 1000
  gradient_clip: 1.0
  gradient_accumulation_steps: 2  # Effective batch = batch_size * 2 = 32
  checkpoint_dir: "checkpoints"
  log_every_n_steps: 100
  val_every_n_epochs: 1
  use_onecycle_lr: false  # Use cosine annealing (OneCycle optional)
  keep_top_k_checkpoints: 5  # Keep top 5 checkpoints by val_loss
  early_stopping_patience: 15  # Stop if no improvement for 15 epochs (0=disabled)
  resume_checkpoint: null  # Filename relative to checkpoint_dir, e.g. last_model.pt
# Inference configuration
inference:
  temperature: 1.0
  top_k: 50
  top_p: 0.95
  use_kv_cache: true  # KV caching for fast generation
  chunk_size: 4096  # Audio frames per chunk for streaming
  chunk_overlap: 512  # Overlap for streaming generation
# Logging
logging:
  log_level: "INFO"
# Hardware
hardware:
  device: "cuda"
  precision: "bf16-mixed"  # bf16 mixed precision (stable, no scaler needed)
  compile: false  # torch.compile (enable for speed on CUDA 12+)