"""Train the baseline ODCNN beat / downbeat model.

The module uses relative imports, so it must be run as part of its package
(``python -m <package>.<module> ...``), not as a loose script.
"""

import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from .model import ODCNN
from .data import BeatTrackingDataset
from .utils import MultiViewSpectrogram
from ..data.load import ds

# Frames [7, 22) of the spectrogram time axis (a 15-frame window) are fed to
# the model. NOTE(review): assumes every example yields at least 22 frames at
# sample_rate=16000 / hop_length=160; a shorter T would silently give a
# narrower window. Confirm against BeatTrackingDataset.
_CONTEXT_SLICE = slice(7, 22)


def _make_loader(
    dataset,
    *,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    pin_memory: bool,
) -> DataLoader:
    """Build a DataLoader, enabling worker-only options only when workers exist."""
    worker_kwargs = {}
    if num_workers > 0:
        # DataLoader raises ValueError for these options when num_workers == 0.
        worker_kwargs = {"prefetch_factor": 4, "persistent_workers": True}
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **worker_kwargs,
    )


def _prepare_input(preprocessor: nn.Module, waveform: torch.Tensor) -> torch.Tensor:
    """Convert a waveform batch into the normalized, center-cropped model input.

    Computes the multi-view spectrogram, standardizes each (example, view)
    pair over its last two axes, and keeps the frames in ``_CONTEXT_SLICE``.
    Callers run this under ``torch.no_grad()``.
    """
    spec = preprocessor(waveform)  # (B, 3, 80, T) per the original comment
    mean = spec.mean(dim=(2, 3), keepdim=True)
    # The epsilon keeps the division finite when a view is constant.
    std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
    spec = (spec - mean) / std
    return spec[:, :, :, _CONTEXT_SLICE]


def _train_one_epoch(
    model: nn.Module,
    preprocessor: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    writer: SummaryWriter,
    device: str,
    desc: str,
    global_step: int,
):
    """Run one optimization pass over ``loader``.

    Logs every batch loss to ``writer`` under ``train/batch_loss``.

    Returns:
        ``(mean per-sample loss, updated global_step)``.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for waveform, y in tqdm(loader, desc=desc, leave=False):
        waveform = waveform.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # The spectrogram front-end is not trained; skip building its graph.
        with torch.no_grad():
            x = _prepare_input(preprocessor, waveform)

        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # one host sync, reused for the total and the log
        batch_size = y.size(0)
        # Weight by batch size so a short final batch is not over-counted.
        total_loss += batch_loss * batch_size
        num_samples += batch_size
        global_step += 1
        writer.add_scalar("train/batch_loss", batch_loss, global_step)
    return total_loss / num_samples, global_step


def _evaluate(
    model: nn.Module,
    preprocessor: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: str,
    desc: str,
) -> float:
    """Return the mean per-sample loss of ``model`` over ``loader`` (eval mode, no grad)."""
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for waveform, y in tqdm(loader, desc=desc, leave=False):
            waveform = waveform.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            loss = criterion(model(_prepare_input(preprocessor, waveform)), y)
            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size
    return total_loss / num_samples


def train(
    target_type: str,
    output_dir: str,
    *,
    epochs: int = 50,
    batch_size: int = 512,
    lr: float = 0.05,
    momentum: float = 0.9,
    num_workers: int = 4,
):
    """Train an ODCNN for one target and save the best and final checkpoints.

    Uses SGD with momentum and a cosine-annealed learning rate. Metrics are
    logged to TensorBoard under ``<output_dir>/logs``. The model with the
    lowest validation loss is saved to ``output_dir``, and the last-epoch
    model to ``<output_dir>/final``.

    Args:
        target_type: Forwarded to ``BeatTrackingDataset(target_type=...)``.
            The CLI offers "beats" or "downbeats".
        output_dir: Directory for checkpoints and logs; created if missing.
        epochs: Number of training epochs (also the cosine schedule's T_max).
        batch_size: Examples per batch for both train and validation.
        lr: Initial SGD learning rate.
        momentum: SGD momentum.
        num_workers: DataLoader worker processes (0 loads in the main process).

    Raises:
        ValueError: If the train or validation split is empty.
    """
    # Note: Paper uses SGD with Momentum, Dropout, and ReLU
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"--- Training Model for target: {target_type} ---")
    print(f"Output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # Data - use existing train/test splits.
    # NOTE(review): the test split is also used for best-checkpoint selection,
    # so it is not a held-out evaluation set for the saved "best" model.
    train_dataset = BeatTrackingDataset(ds["train"], target_type=target_type)
    val_dataset = BeatTrackingDataset(ds["test"], target_type=target_type)
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"Empty dataset split for target {target_type!r}: "
            f"train={len(train_dataset)}, val={len(val_dataset)}"
        )

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device == "cuda"
    train_loader = _make_loader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = _make_loader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    model = ODCNN(dropout_rate=0.5).to(device)
    # Spectrogram front-end, run on the same device as the model.
    preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(device)

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.BCELoss()  # Binary Cross Entropy

    writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))
    best_val_loss = float("inf")
    global_step = 0
    try:
        for epoch in range(epochs):
            progress = f"[{target_type}] Epoch {epoch + 1}/{epochs}"

            avg_train_loss, global_step = _train_one_epoch(
                model,
                preprocessor,
                train_loader,
                optimizer,
                criterion,
                writer,
                device,
                desc=f"{progress} Train",
                global_step=global_step,
            )
            avg_val_loss = _evaluate(
                model,
                preprocessor,
                val_loader,
                criterion,
                device,
                desc=f"{progress} Val",
            )

            writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
            writer.add_scalar("val/loss", avg_val_loss, epoch)
            # Log the LR used during this epoch, before the scheduler advances it.
            writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)
            scheduler.step()

            print(
                f"{progress} - "
                f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
            )

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                model.save_pretrained(output_dir)
                print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")
    finally:
        # Flush and close the event file even if training raises.
        writer.close()

    final_dir = os.path.join(output_dir, "final")
    model.save_pretrained(final_dir)
    print(f"Saved final model to {final_dir}")


def main():
    """CLI entry point: train one target, or both when --target is omitted."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target",
        type=str,
        choices=["beats", "downbeats"],
        default=None,
        help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/baseline1",
        help="Directory to save model and logs",
    )
    args = parser.parse_args()

    # Each target gets its own subdirectory under --output-dir.
    targets = [args.target] if args.target else ["beats", "downbeats"]
    for target in targets:
        train(target, os.path.join(args.output_dir, target))


if __name__ == "__main__":
    main()