# Tab Hero - Default Configuration
#
# This configuration file uses Hydra for hierarchical config management.
# Override values via command line: python train.py model.encoder_dim=768
#
# Architecture optimized for:
# - Arbitrary-length song generation (RoPE + streaming)
# - Memory efficiency (Flash Attention, bf16, gradient accumulation)
# - 32GB VRAM (RTX 5090)
---

# Data configuration
data:
  data_dir: "data/processed"  # Directory containing .tab files
  use_tab_format: true        # Use preprocessed .tab files (faster, recommended)
  use_chunked_dataset: true   # Use ChunkedTabDataset for long songs
  train_split: "train"        # Split to use for training (from manifest.json)
  val_split: "val"            # Split to use for validation (from manifest.json)
  instrument: "lead"          # Only used when use_tab_format=false
  difficulty: "expert"        # Only used when use_tab_format=false
  max_audio_duration_s: 300.0 # Max audio length - 5 min (raw mode only)
  max_sequence_length: 8192   # Max token sequence length (increased for long songs)
  max_mel_frames: 8192        # Max mel frames per chunk (.tab mode only)
  chunk_overlap_frames: 512   # Overlap between chunks for continuity
  # Per-step batch; combined with training.gradient_accumulation_steps (2)
  # the effective batch is 32 — keep the two values in sync when tuning.
  batch_size: 16              # Increased from 8 for better GPU utilization
  num_workers: 8              # Set >0 for multiprocessing
  prefetch_factor: 2          # Batches to prefetch per worker
  max_samples: null           # Set to limit samples for testing (e.g., 100)
  curriculum_learning: false   # Enable difficulty-based curriculum learning
  # NOTE(review): schedule is presumably ignored while curriculum_learning is
  # false — confirm against the trainer before relying on it.
  curriculum_schedule:         # List of [epoch, max_difficulty] pairs
    - [0, 1]                   # Epochs 0-9: easy + medium
    - [10, 2]                  # Epochs 10-24: + hard
    - [25, 3]                  # Epochs 25+: all difficulties

# Audio processing (must match preprocessing MEL_CONFIG)
audio:
  sample_rate: 22050          # Audio sample rate (Hz)
  hop_length: 256             # Mel spectrogram hop length (samples)
  n_mels: 128                 # Number of mel bands
  # Derived: sample_rate / hop_length = 22050 / 256 = 86.1328125 frames/sec.
  # Previously rounded to 86.1, which under-counts ~0.033 frames/sec and
  # drifts about 2 frames per minute in frame<->time conversions — noticeable
  # on the 5-minute songs this config targets. Use the exact quotient.
  frame_rate: 86.1328125      # Derived: sample_rate / hop_length (exact)

# Model architecture (Large: ~100M params)
model:
  # Keep equal to audio.n_mels (both 128).
  audio_input_dim: 128        # Mel spectrogram bins (n_mels)
  encoder_dim: 768            # Increased from 512 for more capacity
  decoder_dim: 768            # Increased from 512 for more capacity
  n_decoder_layers: 8         # Increased from 6 for deeper network
  # 768 / 12 = 64-dim heads; presumably n_heads must divide decoder_dim
  # (standard multi-head constraint) — verify in the attention module.
  n_heads: 12                 # Increased from 8 for finer attention
  ffn_dim: 3072               # Increased from 2048 (4x decoder_dim)
  max_seq_len: 8192           # Increased for long songs
  dropout: 0.1
  audio_downsample: 4         # 4x downsampling (was 2x) for longer context
  use_flash: true             # Flash Attention 2 for O(n) memory
  use_rope: true              # RoPE for length extrapolation (critical for long songs)
  gradient_checkpointing: true  # Keep enabled to fit larger model in VRAM

# Training configuration
training:
  # Peak LR; per the use_onecycle_lr note below, the schedule is cosine
  # annealing after warmup_steps.
  learning_rate: 1.0e-4
  weight_decay: 0.01
  max_epochs: 100             # Increased from 50 for better convergence
  warmup_steps: 1000
  gradient_clip: 1.0
  gradient_accumulation_steps: 2  # Effective batch = batch_size * 2 = 32
  checkpoint_dir: "checkpoints"
  log_every_n_steps: 100
  val_every_n_epochs: 1
  use_onecycle_lr: false      # Use cosine annealing (OneCycle optional)
  keep_top_k_checkpoints: 5   # Keep top 5 checkpoints by val_loss
  early_stopping_patience: 15 # Stop if no improvement for 15 epochs (0=disabled)
  resume_checkpoint: null     # Filename relative to checkpoint_dir, e.g. last_model.pt

# Inference configuration
inference:
  temperature: 1.0
  # NOTE(review): both top_k and top_p are set — presumably the sampler
  # applies them combined (top-k then nucleus); confirm in the decode code.
  top_k: 50
  top_p: 0.95
  use_kv_cache: true          # KV caching for fast generation
  chunk_size: 4096            # Audio frames per chunk for streaming
  # Mirrors data.chunk_overlap_frames (512) — presumably the two should
  # stay equal so train/inference chunking match; verify.
  chunk_overlap: 512          # Overlap for streaming generation

# Logging
logging:
  # Standard level name (DEBUG/INFO/WARNING/ERROR) — presumably passed to
  # Python's logging module; verify the consumer.
  log_level: "INFO"

# Hardware
hardware:
  device: "cuda"
  # NOTE(review): "bf16-mixed" matches PyTorch Lightning's precision flag
  # naming — confirm which trainer backend consumes this.
  precision: "bf16-mixed"     # bf16 mixed precision (stable, no scaler needed)
  compile: false              # torch.compile (enable for speed on CUDA 12+)