import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import argparse
import os

from .model import SEResNeXt
from .data import BeatTrackingDataset
from ..baseline1.utils import MultiViewSpectrogram
from ..data.load import ds


def train(target_type: str, output_dir: str):
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 128  # Reduced batch size due to larger context
    EPOCHS = 3
    LR = 0.001  # Adjusted LR for Adam (ResNet usually prefers Adam/AdamW)
    NUM_WORKERS = 4
    CONTEXT_FRAMES = 50  # +/- 50 frames -> 101 frames total
    PATIENCE = 5  # Early stopping patience

    print(f"--- Training Model for target: {target_type} ---")
    print(f"Output directory: {output_dir}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # TensorBoard writer
    writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))

    # Data
    train_dataset = BeatTrackingDataset(
        ds["train"], target_type=target_type, context_frames=CONTEXT_FRAMES
    )
    val_dataset = BeatTrackingDataset(
        ds["test"], target_type=target_type, context_frames=CONTEXT_FRAMES
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )

    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # Model
    model = SEResNeXt(dropout_rate=0.5).to(DEVICE)

    # GPU spectrogram preprocessor
    preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)

    # Optimizer - using AdamW for ResNet-style architectures
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.BCELoss()  # Binary cross-entropy (expects probabilities in [0, 1])

    best_val_loss = float("inf")
    patience_counter = 0
    global_step = 0

    for epoch in range(EPOCHS):
        # Training
        model.train()
        total_train_loss = 0
        for waveform, y in tqdm(
            train_loader,
            desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Train",
            leave=False,
        ):
            waveform, y = waveform.to(DEVICE), y.to(DEVICE)

            # Compute spectrogram on GPU
            with torch.no_grad():
                spec = preprocessor(waveform)  # (B, 3, 80, T_raw)
                # Normalize per example
                mean = spec.mean(dim=(2, 3), keepdim=True)
                std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
                spec = (spec - mean) / std

            T_curr = spec.shape[-1]
            target_T = CONTEXT_FRAMES * 2 + 1
            if T_curr > target_T:
                start = (T_curr - target_T) // 2
                x = spec[:, :, :, start : start + target_T]
            elif T_curr < target_T:
                # This shouldn't happen if the dataset is correct, but pad just in case
                pad = target_T - T_curr
                x = torch.nn.functional.pad(spec, (0, pad))
            else:
                x = spec

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            global_step += 1

            # Log batch loss
            writer.add_scalar("train/batch_loss", loss.item(), global_step)

        avg_train_loss = total_train_loss / len(train_loader)

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for waveform, y in tqdm(
                val_loader,
                desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Val",
                leave=False,
            ):
                waveform, y = waveform.to(DEVICE), y.to(DEVICE)

                # Compute spectrogram on GPU
                spec = preprocessor(waveform)  # (B, 3, 80, T)
                # Normalize per example
                mean = spec.mean(dim=(2, 3), keepdim=True)
                std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
                spec = (spec - mean) / std
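                # Match the training path: center-crop (or right-pad) the
                # spectrogram to the model's expected 2 * CONTEXT_FRAMES + 1 frames.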
                T_curr = spec.shape[-1]
                target_T = CONTEXT_FRAMES * 2 + 1
                if T_curr > target_T:
                    start = (T_curr - target_T) // 2
                    x = spec[:, :, :, start : start + target_T]
                else:
                    pad = target_T - T_curr
                    x = torch.nn.functional.pad(spec, (0, pad))

                output = model(x)
                loss = criterion(output, y)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        # Log epoch metrics
        writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
        writer.add_scalar("val/loss", avg_val_loss, epoch)
        writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)

        # Step the scheduler
        scheduler.step()

        print(
            f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} - "
            f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
        )

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            model.save_pretrained(output_dir)
            print(f"  -> Saved best model (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1
            print(f"  -> No improvement (patience: {patience_counter}/{PATIENCE})")
            if patience_counter >= PATIENCE:
                print("Early stopping triggered.")
                break

    writer.close()

    # Save final model
    final_dir = os.path.join(output_dir, "final")
    model.save_pretrained(final_dir)
    print(f"Saved final model to {final_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target",
        type=str,
        choices=["beats", "downbeats"],
        default=None,
        help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/baseline2",
        help="Directory to save model and logs",
    )
    args = parser.parse_args()

    # Determine which targets to train
    targets = [args.target] if args.target else ["beats", "downbeats"]

    for target in targets:
        output_dir = os.path.join(args.output_dir, target)
        train(target, output_dir)
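# Example invocation (a sketch: the exact module path depends on where this
# package lives in the repo; `mypkg` below is a placeholder). Because of the
# relative imports at the top, the script must be run as a module from the
# project root rather than as a plain file:
#
#   python -m mypkg.baseline2.train --target beats --output-dir outputs/baseline2
#
# Omitting --target trains both the "beats" and "downbeats" models in sequence,
# each into its own subdirectory of --output-dir.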