# File: tab-hero/configs/default.yaml (author: MattGroho, commit cc67854 "Add default config")
---
# Tab Hero - Default Configuration
#
# This configuration file uses Hydra for hierarchical config management.
# Override values via command line: python train.py model.encoder_dim=768
#
# Architecture optimized for:
# - Arbitrary-length song generation (RoPE + streaming)
# - Memory efficiency (Flash Attention, bf16, gradient accumulation)
# - 32GB VRAM (RTX 5090)
# Data configuration
data:
  data_dir: "data/processed"  # Directory containing .tab files
  use_tab_format: true  # Use preprocessed .tab files (faster, recommended)
  use_chunked_dataset: true  # Use ChunkedTabDataset for long songs
  train_split: "train"  # Split to use for training (from manifest.json)
  val_split: "val"  # Split to use for validation (from manifest.json)
  instrument: "lead"  # Only used when use_tab_format=false
  difficulty: "expert"  # Only used when use_tab_format=false
  max_audio_duration_s: 300.0  # Max audio length - 5 min (raw mode only)
  max_sequence_length: 8192  # Max token sequence length (increased for long songs; matches model.max_seq_len)
  max_mel_frames: 8192  # Max mel frames per chunk (.tab mode only)
  chunk_overlap_frames: 512  # Overlap between chunks for continuity
  batch_size: 16  # Increased from 8 for better GPU utilization
  num_workers: 8  # Set >0 for multiprocessing
  prefetch_factor: 2  # Batches to prefetch per worker
  max_samples: null  # Set to limit samples for testing (e.g., 100)
  curriculum_learning: false  # Enable difficulty-based curriculum learning
  curriculum_schedule:  # List of [epoch, max_difficulty] pairs (used only when curriculum_learning=true)
    - [0, 1]  # Epochs 0-9: easy + medium
    - [10, 2]  # Epochs 10-24: + hard
    - [25, 3]  # Epochs 25+: all difficulties
# Audio processing (must match preprocessing MEL_CONFIG)
audio:
  sample_rate: 22050  # Audio sample rate (Hz)
  hop_length: 256  # Mel spectrogram hop length (samples)
  n_mels: 128  # Number of mel bands (must match model.audio_input_dim)
  # NOTE(review): sample_rate / hop_length = 22050 / 256 = 86.1328125, not 86.1 —
  # confirm whether the consumer expects the exact ratio or this rounded value.
  frame_rate: 86.1  # Derived: sample_rate / hop_length
# Model architecture (Large: ~100M params)
model:
  audio_input_dim: 128  # Mel spectrogram bins (must match audio.n_mels)
  encoder_dim: 768  # Increased from 512 for more capacity
  decoder_dim: 768  # Increased from 512 for more capacity
  n_decoder_layers: 8  # Increased from 6 for deeper network
  n_heads: 12  # Increased from 8 for finer attention (768 / 12 = 64-dim heads)
  ffn_dim: 3072  # Increased from 2048 (4x decoder_dim)
  max_seq_len: 8192  # Increased for long songs (matches data.max_sequence_length)
  dropout: 0.1
  audio_downsample: 4  # 4x downsampling (was 2x) for longer context
  use_flash: true  # Flash Attention 2 for O(n) memory
  use_rope: true  # RoPE for length extrapolation (critical for long songs)
  gradient_checkpointing: true  # Keep enabled to fit larger model in VRAM
# Training configuration
training:
  learning_rate: 1.0e-4
  weight_decay: 0.01
  max_epochs: 100  # Increased from 50 for better convergence
  warmup_steps: 1000
  gradient_clip: 1.0
  gradient_accumulation_steps: 2  # Effective batch = data.batch_size * 2 = 32
  checkpoint_dir: "checkpoints"
  log_every_n_steps: 100
  val_every_n_epochs: 1
  use_onecycle_lr: false  # Use cosine annealing (OneCycle optional)
  keep_top_k_checkpoints: 5  # Keep top 5 checkpoints by val_loss
  early_stopping_patience: 15  # Stop if no improvement for 15 epochs (0=disabled)
  resume_checkpoint: null  # Filename relative to checkpoint_dir, e.g. "last_model.pt"
# Inference configuration
inference:
  temperature: 1.0  # Sampling temperature
  top_k: 50  # Top-k sampling cutoff
  top_p: 0.95  # Nucleus sampling cutoff
  use_kv_cache: true  # KV caching for fast generation
  chunk_size: 4096  # Audio frames per chunk for streaming
  chunk_overlap: 512  # Overlap for streaming generation
# Logging
logging:
  log_level: "INFO"  # Standard logging level name (DEBUG/INFO/WARNING/ERROR)
# Hardware
hardware:
  device: "cuda"  # Target device string (e.g. "cuda", "cpu")
  precision: "bf16-mixed"  # bf16 mixed precision (stable, no scaler needed)
  compile: false  # torch.compile (enable for speed on CUDA 12+)