Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Dict, Optional | |
| import torch | |
| from torch.utils.data import DataLoader, Subset | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import CosineAnnealingLR | |
| from torch.cuda.amp import GradScaler, autocast | |
| from tqdm import tqdm | |
| from demucs import pretrained | |
| from transformers import AutoTokenizer, ClapModel, ClapTextModelWithProjection | |
| from src.models.stem_separation.ATHTDemucs_v2 import AudioTextHTDemucs | |
| from src.loss import combined_loss, combined_L1_sdr_loss, sdr_loss | |
| from src.dataloader import MusDBStemDataset, collate_fn, STEM_PROMPTS, PROMPT_TO_STEM | |
| from utils import load_config, log_separation_spectrograms_to_wandb, log_audio_to_wandb | |
| # ============================================================================ | |
| # Training Helper Functions | |
| # ============================================================================ | |
def train_epoch(
    model: AudioTextHTDemucs,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scaler: Optional[GradScaler],
    device: str,
    use_amp: bool,
    use_L1_cmb_loss: bool,
    l1_sdr_weight: Optional[float],
    l1_weight: Optional[float],
    grad_clip: float,
    sdr_weight: float,
    sisdr_weight: float,
    epoch: int,
    log_every: int,
    use_wandb: bool,
) -> Dict[str, float]:
    """Train the model for one epoch and return the averaged metrics.

    Args:
        model: Model called as ``model(mixture, prompts)``.
        dataloader: Yields batches with ``"mixture"``, ``"target"`` and
            ``"prompt"`` entries.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler; required when ``use_amp`` is True and
            ``device == "cuda"``, unused otherwise.
        device: Device the batch tensors are moved to.
        use_amp: Enable mixed precision (only honoured when
            ``device == "cuda"``).
        use_L1_cmb_loss: Use ``combined_L1_sdr_loss`` with
            ``(l1_sdr_weight, l1_weight)`` instead of ``combined_loss`` with
            ``(sdr_weight, sisdr_weight)``.
        l1_sdr_weight: First weight passed to the L1 combination loss.
        l1_weight: Second weight passed to the L1 combination loss.
        grad_clip: Max norm passed to ``clip_grad_norm_``.
        sdr_weight: First weight passed to ``combined_loss``.
        sisdr_weight: Second weight passed to ``combined_loss``.
        epoch: Zero-based epoch index (used for display and the logged step).
        log_every: Log to wandb every ``log_every`` batches; values <= 0
            disable per-batch logging.
        use_wandb: Whether to log metrics, spectrograms and audio to wandb.

    Returns:
        Dict with the per-batch averages under ``"loss"``, ``"sdr"`` and
        ``"sisdr"``.

    Raises:
        ValueError: If an L1 combination weight is missing, if AMP is enabled
            on CUDA without a scaler, or if the dataloader yields no batches.
    """
    model.train()

    # Select the loss function and the two weights it receives.
    if use_L1_cmb_loss:
        if l1_sdr_weight is None:
            raise ValueError("l1_sdr_weight must be provided when using L1 combination loss.")
        if l1_weight is None:
            raise ValueError("l1_weight must be provided when using L1 combination loss.")
        loss_function = combined_L1_sdr_loss
        weight1 = l1_sdr_weight
        weight2 = l1_weight
        print("**Using L1 + SDR combination loss for training")
    else:
        loss_function = combined_loss
        weight1 = sdr_weight
        weight2 = sisdr_weight

    amp_enabled = use_amp and device == "cuda"
    if amp_enabled and scaler is None:
        raise ValueError("scaler must be provided when use_amp is enabled on CUDA.")

    total_loss = 0.0
    total_sdr = 0.0
    total_sisdr = 0.0
    num_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
    for batch_idx, batch in enumerate(pbar):
        mixture = batch["mixture"].to(device)
        target = batch["target"].to(device)
        prompts = batch["prompt"]

        optimizer.zero_grad()

        if amp_enabled:
            with autocast():
                estimated = model(mixture, prompts)
                loss, metrics = loss_function(estimated, target, weight1, weight2)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on real gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            estimated = model(mixture, prompts)
            loss, metrics = loss_function(estimated, target, weight1, weight2)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        total_loss += metrics["loss/total"]
        total_sdr += metrics["metrics/sdr"]
        total_sisdr += metrics["metrics/sisdr"]
        num_batches += 1

        pbar.set_postfix({
            "loss": f"{metrics['loss/total']:.4f}",
            "SDR": f"{metrics['metrics/sdr']:.2f}",
        })

        # log_every <= 0 disables per-batch logging instead of crashing on modulo.
        if use_wandb and log_every > 0 and batch_idx % log_every == 0:
            import wandb

            global_step = epoch * len(dataloader) + batch_idx
            wandb.log({
                "train/loss": metrics["loss/total"],
                "train/sdr": metrics["metrics/sdr"],
                "train/sisdr": metrics["metrics/sisdr"],
                "train/step": global_step,
            })

            # Plot spectrograms for the first sample in the batch.
            # NOTE: only one extracted stem is visualized for now.
            stem_name_log = PROMPT_TO_STEM[prompts[0]]
            log_separation_spectrograms_to_wandb(
                mixture=mixture[0],
                estimated=estimated[0],
                reference=target[0],
                stem_name=stem_name_log,
                step=global_step,
            )

            # Log the matching audio clips.
            log_audio_to_wandb(mixture[0], "mixture", is_gt=True)
            log_audio_to_wandb(target[0], stem_name_log, is_gt=True)
            log_audio_to_wandb(estimated[0], stem_name_log, is_gt=False)

    if num_batches == 0:
        # Previously this surfaced as a bare ZeroDivisionError.
        raise ValueError(
            "train_epoch received no batches; check the dataset size, "
            "batch_size and drop_last settings."
        )

    return {
        "loss": total_loss / num_batches,
        "sdr": total_sdr / num_batches,
        "sisdr": total_sisdr / num_batches,
    }
def validate(
    model: AudioTextHTDemucs,
    dataloader: DataLoader,
    device: str,
    use_amp: bool,
    use_L1_cmb_loss: bool,
    l1_sdr_weight: Optional[float],
    l1_weight: Optional[float],
    sdr_weight: float = 0.9,
    sisdr_weight: float = 0.1,
) -> Dict[str, float]:
    """Evaluate the model on a dataloader without tracking gradients.

    Args:
        model: Model called as ``model(mixture, prompts)``.
        dataloader: Yields batches with ``"mixture"``, ``"target"``,
            ``"prompt"`` and ``"stem_name"`` entries.
        device: Device the batch tensors are moved to.
        use_amp: Enable autocast (only honoured when ``device == "cuda"``).
        use_L1_cmb_loss: Use ``combined_L1_sdr_loss`` with
            ``(l1_sdr_weight, l1_weight)`` instead of ``combined_loss`` with
            ``(sdr_weight, sisdr_weight)``.
        l1_sdr_weight: First weight passed to the L1 combination loss.
        l1_weight: Second weight passed to the L1 combination loss.
        sdr_weight: First weight passed to ``combined_loss``.
        sisdr_weight: Second weight passed to ``combined_loss``.

    Returns:
        Dict with per-batch averages under ``"loss"``, ``"sdr"`` and
        ``"sisdr"``, plus one ``"sdr/<stem_name>"`` entry (per-sample average
        of ``-sdr_loss``) for every stem that appeared at least once.

    Raises:
        ValueError: If an L1 combination weight is missing or the dataloader
            yields no batches.
    """
    model.eval()

    # Select the loss function and the two weights it receives.
    if use_L1_cmb_loss:
        if l1_sdr_weight is None:
            raise ValueError("l1_sdr_weight must be provided when using L1 combination loss.")
        if l1_weight is None:
            raise ValueError("l1_weight must be provided when using L1 combination loss.")
        loss_function = combined_L1_sdr_loss
        weight1 = l1_sdr_weight
        weight2 = l1_weight
    else:
        loss_function = combined_loss
        weight1 = sdr_weight
        weight2 = sisdr_weight

    total_loss = 0.0
    total_sdr = 0.0
    total_sisdr = 0.0
    num_batches = 0
    stem_metrics = {name: {"sdr": 0.0, "count": 0} for name in STEM_PROMPTS}

    amp_enabled = use_amp and device == "cuda"

    # Evaluation needs no autograd graph; building one only wastes memory.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            mixture = batch["mixture"].to(device)
            target = batch["target"].to(device)
            prompts = batch["prompt"]
            stem_names = batch["stem_name"]

            if amp_enabled:
                with autocast():
                    estimated = model(mixture, prompts)
                    loss, metrics = loss_function(estimated, target, weight1, weight2)
            else:
                estimated = model(mixture, prompts)
                loss, metrics = loss_function(estimated, target, weight1, weight2)

            total_loss += metrics["loss/total"]
            total_sdr += metrics["metrics/sdr"]
            total_sisdr += metrics["metrics/sisdr"]
            num_batches += 1

            # Per-sample SDR, bucketed by stem; slices keep the batch dimension.
            for i, stem_name in enumerate(stem_names):
                est_i = estimated[i:i + 1]
                tgt_i = target[i:i + 1]
                sdr_i = -sdr_loss(est_i, tgt_i).item()
                # Tolerate stem names that are not listed in STEM_PROMPTS.
                bucket = stem_metrics.setdefault(stem_name, {"sdr": 0.0, "count": 0})
                bucket["sdr"] += sdr_i
                bucket["count"] += 1

    if num_batches == 0:
        # Previously this surfaced as a bare ZeroDivisionError.
        raise ValueError("validate received no batches; the validation dataloader is empty.")

    avg_metrics = {
        "loss": total_loss / num_batches,
        "sdr": total_sdr / num_batches,
        "sisdr": total_sisdr / num_batches,
    }
    for stem_name, data in stem_metrics.items():
        if data["count"] > 0:
            avg_metrics[f"sdr/{stem_name}"] = data["sdr"] / data["count"]
    return avg_metrics
def save_checkpoint(
    model: AudioTextHTDemucs,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    epoch: int,
    metrics: Dict[str, float],
    checkpoint_dir: str,
    is_best: bool = False,
):
    """Write the training state to ``checkpoint_dir``.

    Always writes ``checkpoint_epoch_<epoch>.pt`` and ``latest.pt``; also
    writes ``best_model.pt`` when ``is_best`` is True. The directory is
    created if it does not exist.

    Args:
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        epoch: Epoch number stored in the checkpoint and used in the filename.
        metrics: Metrics dict stored alongside the states.
        checkpoint_dir: Output directory.
        is_best: Whether to additionally write ``best_model.pt``.
    """
    out_dir = Path(checkpoint_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        metrics=metrics,
    )

    # Per-epoch snapshot.
    epoch_file = out_dir / f"checkpoint_epoch_{epoch}.pt"
    torch.save(state, epoch_file)
    print(f"Saved checkpoint to {epoch_file}")

    # Best-so-far snapshot, only when the caller flags it.
    if is_best:
        best_file = out_dir / "best_model.pt"
        torch.save(state, best_file)
        print(f"Saved best model to {best_file}")

    # Rolling pointer to the most recent state.
    torch.save(state, out_dir / "latest.pt")
def load_checkpoint(
    model: AudioTextHTDemucs,
    optimizer: Optional[torch.optim.Optimizer],
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    checkpoint_path: str,
) -> int:
    """Load a checkpoint and return the epoch number stored in it.

    Model weights are loaded with ``strict=False``, so keys present on only
    one side are ignored (e.g. when ClapTextModelWithProjection is used but
    the checkpoint holds a full ClapModel with audio encoder weights); the
    number of such keys is printed. Optimizer and scheduler states are
    restored on a best-effort basis: a missing object, a missing checkpoint
    entry or an incompatible state is reported and skipped.

    Args:
        model: Model that receives ``checkpoint["model_state_dict"]``.
        optimizer: Optimizer to restore, or None to skip.
        scheduler: LR scheduler to restore, or None to skip.
        checkpoint_path: Path to the checkpoint file.

    Returns:
        The value stored under ``checkpoint["epoch"]``.
    """
    # NOTE(security): torch.load unpickles data; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    load_result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    missing_keys = list(getattr(load_result, "missing_keys", []))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
    if missing_keys:
        print(f"Model keys not found in checkpoint: {len(missing_keys)}")
    if unexpected_keys:
        print(f"Checkpoint keys not used by model: {len(unexpected_keys)}")

    # Restore optimizer state, ignoring mismatches (new CLAP model, etc.).
    if optimizer is None:
        print("Skipping optimizer state... (no optimizer provided)")
    else:
        try:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        except Exception as exc:  # deliberate best-effort; report the reason
            print(f"Skipping optimizer state... ({type(exc).__name__}: {exc})")

    # Same idea for the scheduler.
    if scheduler is None:
        print("Skipping scheduler state... (no scheduler provided)")
    else:
        try:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        except Exception as exc:  # was a bare except, which also trapped KeyboardInterrupt
            print(f"Skipping scheduler state... ({type(exc).__name__}: {exc})")

    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
    return checkpoint["epoch"]
| # ============================================================================ | |
| # Main Training Function | |
| # ============================================================================ | |
def train(config_path):
    """Run the full AudioTextHTDemucs training loop from a YAML config.

    Args:
        config_path: Path handed to ``load_config``. The loaded config must
            contain the sections ``data``, ``model``, ``training`` and
            ``wandb``. Keys read (defaults in parentheses):

            data: train_dir ("../data/train"), test_dir ("../data/test"),
                pct_train (1.0), pct_test (1.0), sample_rate (44100),
                segment_seconds (6.0).
            model: model_dim (384), text_dim (512), n_heads (8),
                device (auto-detected when missing/None).
            training: batch_size (4), num_workers (0), num_epochs (10),
                use_L1_comb_loss (False), use_amp (False), resume_from (None),
                optimizer.{lr (1e-4), weight_decay (1e-5), grad_clip (1.0)},
                L1_comb_loss.{sdr_weight (1.0), l1_weight (0.05)},
                loss_weights.{sdr (0.9), sisdr (0.1)}.
            wandb: checkpoint_dir ("../checkpoints"), use_wandb (True),
                project ("audio-text-htdemucs"), run_name (None),
                log_every (50), validate_every (1).

            ``wandb.save_every`` is not read: ``save_checkpoint`` is invoked
            after every epoch so ``latest.pt`` stays current, which also
            writes a per-epoch file each time.

    Returns:
        Dict with ``"final_train_metrics"``, ``"final_val_metrics"`` (both
        empty dicts if no epoch ran; the latter is also empty when the last
        epoch was not validated) and ``"best_sdr"``.

    Raises:
        ValueError: If ``pct_train`` or ``pct_test`` is not positive.
    """
    # Load configuration
    cfg = load_config(config_path)
    data_cfg = cfg["data"]
    model_cfg = cfg["model"]
    training_cfg = cfg["training"]
    wandb_cfg = cfg["wandb"]

    # Optional sub-sections: a missing or empty section falls back to defaults
    # instead of raising KeyError.
    optimizer_cfg = training_cfg.get("optimizer") or {}
    l1_cfg = training_cfg.get("L1_comb_loss") or {}
    loss_weights_cfg = training_cfg.get("loss_weights") or {}

    # Paths
    train_dir = data_cfg.get("train_dir", "../data/train")
    test_dir = data_cfg.get("test_dir", "../data/test")
    checkpoint_dir = wandb_cfg.get("checkpoint_dir", "../checkpoints")

    # Data splits
    pct_train = data_cfg.get("pct_train", 1.0)
    pct_test = data_cfg.get("pct_test", 1.0)
    if pct_train <= 0.0 or pct_test <= 0.0:
        raise ValueError("pct_train and pct_test must be greater than 0.")

    # Audio parameters
    sample_rate = data_cfg.get("sample_rate", 44100)
    segment_seconds = data_cfg.get("segment_seconds", 6.0)

    # Training parameters
    batch_size = training_cfg.get("batch_size", 4)
    num_workers = training_cfg.get("num_workers", 0)
    epochs = training_cfg.get("num_epochs", 10)
    # float() because YAML may deliver values such as "1e-4" as strings.
    learning_rate = float(optimizer_cfg.get("lr", 1e-4))
    weight_decay = float(optimizer_cfg.get("weight_decay", 1e-5))
    grad_clip = optimizer_cfg.get("grad_clip", 1.0)
    use_L1_cmb_loss = training_cfg.get("use_L1_comb_loss", False)
    l1_sdr_weight = l1_cfg.get("sdr_weight", 1.0)
    l1_weight = l1_cfg.get("l1_weight", 0.05)

    # Loss weights
    sdr_weight = loss_weights_cfg.get("sdr", 0.9)
    sisdr_weight = loss_weights_cfg.get("sisdr", 0.1)

    # Model parameters
    model_dim = model_cfg.get("model_dim", 384)
    text_dim = model_cfg.get("text_dim", 512)
    n_heads = model_cfg.get("n_heads", 8)

    # Logging
    use_wandb = wandb_cfg.get("use_wandb", True)
    wandb_project = wandb_cfg.get("project", "audio-text-htdemucs")
    wandb_run_name = wandb_cfg.get("run_name", None)
    log_every = wandb_cfg.get("log_every", 50)
    validate_every = wandb_cfg.get("validate_every", 1)

    # Mixed precision
    use_amp = training_cfg.get("use_amp", False)

    # Device (auto-detected when not configured)
    device = model_cfg.get("device", None)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resume training
    resume_from = training_cfg.get("resume_from", None)

    segment_samples = int(sample_rate * segment_seconds)

    # Initialize wandb
    if use_wandb:
        import wandb

        wandb.init(
            project=wandb_project,
            name=wandb_run_name,
            config={
                "train_dir": train_dir,
                "test_dir": test_dir,
                "sample_rate": sample_rate,
                "segment_seconds": segment_seconds,
                "batch_size": batch_size,
                "epochs": epochs,
                "learning_rate": learning_rate,
                "weight_decay": weight_decay,
                "grad_clip": grad_clip,
                "sdr_weight": sdr_weight,
                "sisdr_weight": sisdr_weight,
                "model_dim": model_dim,
                "text_dim": text_dim,
                "n_heads": n_heads,
                "use_amp": use_amp,
            },
        )

    print("=" * 60)
    print("Audio-Text HTDemucs Training")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Train directory: {train_dir}")
    print(f"Test directory: {test_dir}")
    print(f"Segment length: {segment_seconds}s ({segment_samples} samples)")
    print(f"Batch size: {batch_size}")
    print(f"Epochs: {epochs}")
    print(f"Learning rate: {learning_rate}")
    print("=" * 60)

    # Load pretrained models
    print("Loading pretrained HTDemucs...")
    htdemucs = pretrained.get_model('htdemucs').models[0]

    print("Loading CLAP model...")
    # Text-only CLAP: more memory efficient than the full ClapModel (text + audio).
    clap = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
    tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

    # Create model
    print("Building AudioTextHTDemucs model...")
    model = AudioTextHTDemucs(
        htdemucs_model=htdemucs,
        clap_encoder=clap,
        clap_tokenizer=tokenizer,
        model_dim=model_dim,
        text_dim=text_dim,
        num_heads=n_heads,
    )
    model = model.to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Create datasets
    print("Creating datasets...")
    train_dataset = MusDBStemDataset(
        root_dir=train_dir,
        segment_samples=segment_samples,
        sample_rate=sample_rate,
        random_segments=True,
        augment=True,
    )
    val_dataset = MusDBStemDataset(
        root_dir=test_dir,
        segment_samples=segment_samples,
        sample_rate=sample_rate,
        random_segments=False,
        augment=False,
    )

    # Create random subsets if a fraction below 1.0 is requested; otherwise
    # the full datasets are used.
    train_data = train_dataset
    if pct_train < 1.0:
        num_train = int(len(train_dataset) * pct_train)
        train_idxs = torch.randperm(len(train_dataset))[:num_train].tolist()
        train_data = Subset(train_dataset, train_idxs)

    val_data = val_dataset
    if pct_test < 1.0:
        num_val = int(len(val_dataset) * pct_test)
        val_idxs = torch.randperm(len(val_dataset))[:num_val].tolist()
        # Fixed: the subset was previously taken from train_dataset, so
        # validation ran on training data with out-of-domain indices.
        val_data = Subset(val_dataset, val_idxs)

    # Create dataloaders
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=(device == "cuda"),
        drop_last=True,
    )
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=(device == "cuda"),
    )

    # Optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=epochs,
        eta_min=learning_rate * 0.01,
    )

    # Mixed precision scaler
    scaler = GradScaler() if use_amp and device == "cuda" else None

    # Resume from checkpoint
    # NOTE(review): best_sdr is not restored on resume, so the first validated
    # epoch after a resume overwrites best_model.pt even if it is worse.
    start_epoch = 0
    best_sdr = -float("inf")
    if resume_from is not None:
        resume_path = Path(resume_from)
        if resume_path.exists():
            print(f"Resuming from {resume_path}")
            start_epoch = load_checkpoint(model, optimizer, scheduler, str(resume_path))
            start_epoch += 1
        else:
            print(f"Warning: resume_from path {resume_path} not found; starting from scratch")
    else:
        # Check for latest checkpoint
        latest_checkpoint = Path(checkpoint_dir) / "latest.pt"
        if latest_checkpoint.exists():
            print(f"Found latest checkpoint at {latest_checkpoint}")
            start_epoch = load_checkpoint(model, optimizer, scheduler, str(latest_checkpoint))
            start_epoch += 1

    # Defined up front so the return value is valid even when no epoch runs
    # (e.g. resuming a run that already reached `epochs`).
    train_metrics = {}
    val_metrics = {}

    # Training loop
    print("\nStarting training...")
    for epoch in range(start_epoch, epochs):
        print(f"\n{'=' * 60}")
        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"Learning rate: {scheduler.get_last_lr()[0]:.2e}")
        print(f"{'=' * 60}")

        # Train
        train_metrics = train_epoch(
            model=model,
            dataloader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            device=device,
            use_amp=use_amp,
            use_L1_cmb_loss=use_L1_cmb_loss,
            l1_sdr_weight=l1_sdr_weight,
            l1_weight=l1_weight,
            grad_clip=grad_clip,
            sdr_weight=sdr_weight,
            sisdr_weight=sisdr_weight,
            epoch=epoch,
            log_every=log_every,
            use_wandb=use_wandb,
        )
        print(f"Train - Loss: {train_metrics['loss']:.4f}, SDR: {train_metrics['sdr']:.2f} dB")

        # Step scheduler
        scheduler.step()

        # Validate
        if (epoch + 1) % validate_every == 0:
            val_metrics = validate(
                model=model,
                dataloader=val_loader,
                device=device,
                use_amp=use_amp,
                use_L1_cmb_loss=use_L1_cmb_loss,
                l1_sdr_weight=l1_sdr_weight,
                l1_weight=l1_weight,
                sdr_weight=sdr_weight,
                sisdr_weight=sisdr_weight,
            )
            print(f"Val - Loss: {val_metrics['loss']:.4f}, SDR: {val_metrics['sdr']:.2f} dB")
            for stem_name in STEM_PROMPTS:
                if f"sdr/{stem_name}" in val_metrics:
                    print(f" {stem_name}: {val_metrics[f'sdr/{stem_name}']:.2f} dB")

            if use_wandb:
                wandb.log({
                    "val/loss": val_metrics["loss"],
                    "val/sdr": val_metrics["sdr"],
                    "val/sisdr": val_metrics["sisdr"],
                    **{f"val/{k}": v for k, v in val_metrics.items() if k.startswith("sdr/")},
                    "epoch": epoch + 1,
                })

            is_best = val_metrics["sdr"] > best_sdr
            if is_best:
                best_sdr = val_metrics["sdr"]
                print(f"New best SDR: {best_sdr:.2f} dB")
        else:
            val_metrics = {}
            is_best = False

        # Save checkpoint every epoch. The former if/else on save_every called
        # save_checkpoint with equivalent arguments in both branches, so this
        # single call behaves identically.
        save_checkpoint(
            model, optimizer, scheduler, epoch + 1,
            {**train_metrics, **val_metrics},
            checkpoint_dir, is_best,
        )

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Best validation SDR: {best_sdr:.2f} dB")
    print("=" * 60)

    if use_wandb:
        wandb.finish()

    return {
        "final_train_metrics": train_metrics,
        "final_val_metrics": val_metrics,
        "best_sdr": best_sdr,
    }
if __name__ == "__main__":
    # train() accepts a single YAML config path. The previous call passed
    # train_dir/test_dir/checkpoint_dir keywords, which raised TypeError;
    # those values belong in the config file.
    import argparse

    parser = argparse.ArgumentParser(
        description="Train AudioTextHTDemucs from a YAML config file."
    )
    parser.add_argument("config_path", help="Path to the YAML training config.")
    cli_args = parser.parse_args()
    train(cli_args.config_path)