import os
import random
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
import torch.distributed as dist
import torchaudio
from tqdm import tqdm

import utils
from .trainers import register
from trainers.base_trainer import BaseTrainer
from models.ldm.dac.audiotools import AudioSignal


@register('audio_ldm_trainer')
class AudioLDMTrainer(BaseTrainer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def make_model(self):
        super().make_model()
        self.has_optimizer = dict()
        total_params = 0
        for name, m in self.model.named_children():
            params = utils.compute_num_params(m, text=False)
            self.log(f' .{name} {params}')
            total_params = total_params + params
            # Log to Comet
            if self.experiment:
                self.experiment.log_metric(f"model/{name}_params", params)
        if self.experiment:
            self.experiment.log_metric("model/total_params", total_params)

    def make_optimizers(self):
        self.optimizers = dict()
        self.has_optimizer = dict()
        for name, spec in self.config.optimizers.items():
            self.optimizers[name] = utils.make_optimizer(self.model.get_parameters(name), spec)
            self.has_optimizer[name] = True
            # Log optimizer config to Comet
            if self.experiment:
                self.experiment.log_parameters({
                    f"optimizer/{name}/type": spec.get("type", "adam"),
                    f"optimizer/{name}/lr": spec.get("lr", 1e-4),
                    f"optimizer/{name}/weight_decay": spec.get("weight_decay", 0),
                })

    def train_step(self, data, bp=True):
        kwargs = {'has_optimizer': self.has_optimizer}

        # Start timing
        step_start_time = time.time()

        # Audio-specific data preparation
        if 'signal' in data:
            # Convert AudioSignal to the tensor format expected by the model
            audio_data = data['signal'].audio_data  # [batch, channels, samples]
            sample_rate = data['signal'].sample_rate
            # Prepare data dict for the model
            model_data = {
                'inp': audio_data,
                'gt': audio_data,  # For autoencoder training
                'sample_rate': sample_rate,
            }
        else:
            model_data = data

        # self.log(f'Audio data shape: {model_data["inp"].shape}')

        # Log batch info to Comet
        if self.experiment and self.iter % 500 == 0:
            self.experiment.log_metric("train/batch_size", model_data["inp"].shape[0], step=self.iter)
            self.experiment.log_metric("train/audio_length_samples", model_data["inp"].shape[-1], step=self.iter)
            self.experiment.log_metric(
                "train/audio_duration_sec",
                model_data["inp"].shape[-1] / model_data.get("sample_rate", 24000),
                step=self.iter)

        if self.config.get('autocast_bfloat16', False):
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                ret = self.model_ddp(model_data, mode='loss', **kwargs)
        else:
            ret = self.model_ddp(model_data, mode='loss', **kwargs)
        loss = ret.pop('loss')
        ret['loss'] = loss.item()

        if bp:
            self.model_ddp.zero_grad(set_to_none=True)
            loss.backward()

            # Log gradients to Comet
            if self.experiment and self.iter % 5 == 0:
                self._log_gradients()

            for name, o in self.optimizers.items():
                if name != 'disc':
                    o.step()
            if hasattr(self.model, 'update_ema'):
                self.model.update_ema()

        # Log training metrics to Comet (only for actual training steps, so
        # eval-time calls with bp=False do not pollute the train curves)
        if self.experiment and bp:
            # Log all losses
            for k, v in ret.items():
                if 'loss' in k.lower():
                    self.experiment.log_metric(f"train/{k}", v, step=self.iter)

            # Log learning rates
            for name, opt in self.optimizers.items():
                lr = opt.param_groups[0]['lr']
                self.experiment.log_metric(f"train/lr_{name}", lr, step=self.iter)

            # Log timing
            step_time = time.time() - step_start_time
            self.experiment.log_metric("train/step_time", step_time, step=self.iter)

            # Log GPU memory usage
            if torch.cuda.is_available():
                self.experiment.log_metric("train/gpu_memory_allocated",
                                           torch.cuda.memory_allocated() / 1e9, step=self.iter)
                self.experiment.log_metric("train/gpu_memory_reserved",
                                           torch.cuda.memory_reserved() / 1e9, step=self.iter)

        return ret
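
    # NOTE (assumption): train_step expects the model's `mode='loss'` forward
    # to return a dict containing a differentiable 'loss' tensor plus scalar
    # sub-losses. The 'disc' optimizer is deliberately not stepped here, on
    # the premise that the discriminator update is handled separately (e.g.
    # inside the loss-mode forward, which receives `has_optimizer`).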

    def _log_gradients(self):
        """Log gradient statistics to Comet ML"""
        if not self.experiment:
            return

        grad_stats = {}
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                grad_mean = param.grad.mean().item()
                grad_std = param.grad.std().item()

                # Aggregate stats by top-level module
                module_name = name.split('.')[0]
                if module_name not in grad_stats:
                    grad_stats[module_name] = {'norm': [], 'mean': [], 'std': []}
                grad_stats[module_name]['norm'].append(grad_norm)
                grad_stats[module_name]['mean'].append(grad_mean)
                grad_stats[module_name]['std'].append(grad_std)

        # Log aggregated stats
        for module, stats in grad_stats.items():
            self.experiment.log_metric(f"gradients/{module}/norm_mean", np.mean(stats['norm']), step=self.iter)
            self.experiment.log_metric(f"gradients/{module}/norm_max", np.max(stats['norm']), step=self.iter)
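
    # "Epochs" below are logging/checkpointing units of epoch_iter steps, not
    # dataset passes: max_iter is split into max_iter // epoch_iter epochs,
    # and the train-loader iterator persists across them, being rebuilt (with
    # its sampler re-seeded) only when it is actually exhausted.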
    def run_training(self):
        config = self.config
        max_iter = config['max_iter']
        epoch_iter = config['epoch_iter']
        assert max_iter % epoch_iter == 0
        max_epoch = max_iter // epoch_iter

        save_iter = config.get('save_iter')
        if save_iter is not None:
            assert save_iter % epoch_iter == 0
            save_epoch = save_iter // epoch_iter
            print('save_epoch', save_epoch)
        else:
            save_epoch = max_epoch + 1

        eval_iter = config.get('eval_iter')
        if eval_iter is not None:
            assert eval_iter % epoch_iter == 0
            eval_epoch = eval_iter // epoch_iter
        else:
            eval_epoch = max_epoch + 1

        vis_iter = config.get('vis_iter')
        if vis_iter is not None:
            assert vis_iter % epoch_iter == 0
            vis_epoch = vis_iter // epoch_iter
        else:
            vis_epoch = max_epoch + 1

        if config.get('ckpt_select_metric') is not None:
            m = config.ckpt_select_metric
            self.ckpt_select_metric = m.name
            self.ckpt_select_type = m.type
            if m.type == 'min':
                self.ckpt_select_v = 1e18
            elif m.type == 'max':
                self.ckpt_select_v = -1e18
        else:
            self.ckpt_select_metric = None
            self.ckpt_select_v = 0

        self.train_loader = self.loaders['train']
        self.train_loader_sampler = self.loader_samplers['train']
        self.train_loader_epoch = 0
        self.train_loader_iter = None

        self.iter = 0
        if self.resume_ckpt is not None:
            for _ in range(self.resume_ckpt['iter']):
                self.iter += 1
                self.at_train_iter_start()
            self.ckpt_select_v = self.resume_ckpt['ckpt_select_v']
            self.train_loader_epoch = self.resume_ckpt['train_loader_epoch']
            self.train_loader_iter = None
            self.resume_ckpt = None
            self.log('Resumed iter status.')
            self.visualize()

        start_epoch = self.iter // epoch_iter + 1
        for epoch in range(start_epoch, max_epoch + 1):
            self.log_buffer = [f'Epoch {epoch}']
            for sampler in self.loader_samplers.values():
                if sampler is not self.train_loader_sampler:
                    sampler.set_epoch(epoch)
            self.model_ddp.train()

            pbar = range(1, epoch_iter + 1)
            if self.is_master and epoch == start_epoch:
                pbar = tqdm(pbar, desc='train', leave=False)

            t_data = 0
            t_nondata = 0
            t_before_data = time.time()
            for _ in pbar:
                self.iter += 1
                self.at_train_iter_start()

                try:
                    if self.train_loader_iter is None:
                        raise StopIteration
                    data = next(self.train_loader_iter)
                except StopIteration:
                    self.train_loader_epoch += 1
                    self.train_loader_sampler.set_epoch(self.train_loader_epoch)
                    self.train_loader_iter = iter(self.train_loader)
                    data = next(self.train_loader_iter)
                t_after_data = time.time()
                t_data += t_after_data - t_before_data

                for k, v in data.items():
                    data[k] = v.to(self.device) if torch.is_tensor(v) else v
                ret = self.train_step(data)
                t_before_data = time.time()
                t_nondata += t_before_data - t_after_data

                if self.is_master and epoch == start_epoch:
                    pbar.set_description(desc=f'train: loss={ret["loss"]:.4f}')

                # Save the model every 2000 iterations
                if self.iter % 2000 == 0:
                    self.save_ckpt(f'ckpt-{self.iter}.pth')

            self.save_ckpt('ckpt-last.pth')

            if epoch % save_epoch == 0 and epoch != max_epoch:
                self.save_ckpt(f'ckpt-{self.iter}.pth')

            if epoch % eval_epoch == 0:
                with torch.no_grad():
                    eval_ave_scalars = self.evaluate()
                if self.ckpt_select_metric is not None:
                    v = eval_ave_scalars[self.ckpt_select_metric].item()
                    if ((self.ckpt_select_type == 'min' and v < self.ckpt_select_v)
                            or (self.ckpt_select_type == 'max' and v > self.ckpt_select_v)):
                        self.ckpt_select_v = v
                        self.save_ckpt('ckpt-best.pth')

            if epoch % vis_epoch == 0:
                with torch.no_grad():
                    self.visualize()

    def evaluate(self):
        self.model_ddp.eval()
        ave_scalars = dict()

        pbar = self.loaders['val']
        for data in pbar:
            # Move audio data to the GPU
            if 'signal' in data:
                data['signal'] = data['signal'].to(self.device)
            else:
                for k, v in data.items():
                    data[k] = v.to(self.device) if torch.is_tensor(v) else v

            ret = self.train_step(data, bp=False)
            bs = data['signal'].batch_size if 'signal' in data else len(next(iter(data.values())))
            for k, v in ret.items():
                if ave_scalars.get(k) is None:
                    ave_scalars[k] = utils.Averager()
                ave_scalars[k].add(v, n=bs)
        self.sync_ave_scalars(ave_scalars)

        # Audio-specific evaluation
        if self.config.get('evaluate_ae', False):
            ave_scalars.update(self.evaluate_audio_ae())
        if self.config.get('evaluate_zdm', False):
            ema = self.config.get('evaluate_zdm_ema', True)
            ave_scalars.update(self.evaluate_audio_zdm(ema=ema))

        logtext = 'val:'
        for k, v in ave_scalars.items():
            logtext += f' {k}={v.item():.4f}'
            self.log_scalar('val/' + k, v.item())
            # Log to Comet
            if self.experiment:
                self.experiment.log_metric(f"val/{k}", v.item(), step=self.iter)
        self.log_buffer.append(logtext)

        return ave_scalars

    def visualize(self):
        self.model_ddp.eval()
        if self.config.get('evaluate_ae', False):
            self.visualize_audio_ae_random()
        if self.config.get('evaluate_zdm', False):
            ema = self.config.get('evaluate_zdm_ema', True)
            self.visualize_audio_zdm_random(ema=ema)
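
    # The eval helpers below assume `utils.Averager` supports add(v, n=...),
    # .item(), construction from a value (Averager(v)), and a raw `.v` field.
    # Saved wav files are indexed as rank + cnt * world_size so that ranks
    # write disjoint, interleaved file names into the shared cache dirs.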
    def evaluate_audio_ae(self):
        """Audio autoencoder evaluation with spectral metrics"""
        max_samples = self.config.get('eval_ae_max_samples', 1000)
        self.loader_samplers['eval_ae'].set_epoch(0)

        l1_loss_avg = utils.Averager()
        snr_avg = utils.Averager()
        spectral_convergence_avg = utils.Averager()
        cnt = 0

        # Create cache directories for audio samples
        cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gen')
        cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gt')
        if self.is_master:
            utils.ensure_path(cache_gen_dir, force_replace=True)
            utils.ensure_path(cache_gt_dir, force_replace=True)
        dist.barrier()

        # Build the spectrogram transform once instead of per batch
        stft_transform = torchaudio.transforms.Spectrogram(
            n_fft=1024, hop_length=256, power=2
        ).to(self.device)

        for data in self.loaders['eval_ae']:
            if 'signal' in data:
                data['signal'] = data['signal'].to(self.device)
                signal = data['signal']
            else:
                for k, v in data.items():
                    data[k] = v.to(self.device) if torch.is_tensor(v) else v
                signal = AudioSignal(data['inp'], data.get('sample_rate', 22050))

            # Get reconstruction
            pred_audio = self.model(data, mode='pred')
            if isinstance(pred_audio, dict):
                pred_audio = pred_audio.get('audio', pred_audio.get('recons', pred_audio))
            recons = AudioSignal(pred_audio, signal.sample_rate)

            # SNR: 10 * log10(signal power / residual power)
            signal_power = (signal.audio_data ** 2).mean()
            noise_power = ((recons.audio_data - signal.audio_data) ** 2).mean()
            snr = 10 * torch.log10(signal_power / (noise_power + 1e-8))
            snr_avg.add(snr.item())

            # Spectral convergence: ||S_gt - S_rec||_F / ||S_gt||_F
            orig_spec = stft_transform(signal.audio_data)
            recon_spec = stft_transform(recons.audio_data)
            spec_diff = torch.norm(orig_spec - recon_spec, p='fro')
            spec_norm = torch.norm(orig_spec, p='fro')
            spectral_convergence = spec_diff / (spec_norm + 1e-8)
            spectral_convergence_avg.add(spectral_convergence.item())

            l1_loss = torch.nn.functional.l1_loss(recons.audio_data, signal.audio_data).item()
            l1_loss_avg.add(l1_loss)

            # Save audio samples for potential subjective evaluation
            for i in range(min(signal.batch_size, 5)):  # Save up to 5 per batch
                idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
                if max_samples is None or idx < max_samples:
                    # Reshape to [samples, channels] on the tensor before
                    # converting to numpy (numpy arrays have no .dim()/.unsqueeze())
                    tmp_recon = recons[i].audio_data.cpu()
                    if tmp_recon.dim() == 3:
                        tmp_recon = tmp_recon.squeeze(0)
                    elif tmp_recon.dim() == 1:
                        tmp_recon = tmp_recon.unsqueeze(0)
                    tmp_recon = tmp_recon.numpy().T

                    tmp_signal = signal[i].audio_data.cpu()
                    if tmp_signal.dim() == 3:
                        tmp_signal = tmp_signal.squeeze(0)
                    elif tmp_signal.dim() == 1:
                        tmp_signal = tmp_signal.unsqueeze(0)
                    tmp_signal = tmp_signal.numpy().T

                    # Save as wav files
                    sf.write(
                        os.path.join(cache_gen_dir, f'{idx}.wav'),
                        tmp_recon, int(recons[i].sample_rate)
                    )
                    sf.write(
                        os.path.join(cache_gt_dir, f'{idx}.wav'),
                        tmp_signal, int(signal[i].sample_rate)
                    )
                cnt += 1
        dist.barrier()

        # Sync metrics across processes
        for avg_metric in [l1_loss_avg, snr_avg, spectral_convergence_avg]:
            vt = torch.tensor(avg_metric.item(), device=self.device)
            dist.all_reduce(vt, op=dist.ReduceOp.SUM)
            torch.cuda.synchronize()
            avg_metric.v = vt.item() / int(os.environ['WORLD_SIZE'])

        if self.is_master:
            prefix = 'eval_ae'
            ret = {
                f'{prefix}/L1_Loss': l1_loss_avg.item(),
                f'{prefix}/SNR': snr_avg.item(),
                f'{prefix}/Spectral_Convergence': spectral_convergence_avg.item(),
            }
        else:
            ret = {}
        dist.barrier()
        ret = {k: utils.Averager(v) for k, v in ret.items()}
        return ret

    def evaluate_audio_zdm(self, ema):
        """Audio latent diffusion model evaluation"""
        max_samples = self.config.get('eval_zdm_max_samples', 1000)
        self.loader_samplers['eval_zdm'].set_epoch(0)
        cnt = 0
        l1_loss_avg = utils.Averager()

        cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gen')
        cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gt')
        if self.is_master:
            utils.ensure_path(cache_gen_dir, force_replace=True)
            utils.ensure_path(cache_gt_dir, force_replace=True)
        dist.barrier()

        for data in self.loaders['eval_zdm']:
            if 'signal' in data:
                data['signal'] = data['signal'].to(self.device)
                gt_signal = data['signal']
            else:
                for k, v in data.items():
                    data[k] = v.to(self.device) if torch.is_tensor(v) else v
                gt_signal = AudioSignal(data['inp'], data.get('sample_rate', 22050))

            # Generate samples from the latent diffusion model
            net_kwargs = dict()
            uncond_net_kwargs = dict()
            # Add conditioning here if available (e.g., for conditional generation)
            pred_audio = self.model.generate_samples(
                batch_size=gt_signal.batch_size,
                n_steps=self.model.zdm_n_steps,
                net_kwargs=net_kwargs,
                uncond_net_kwargs=uncond_net_kwargs,
                ema=ema,
            )
            pred_signal = AudioSignal(pred_audio, gt_signal.sample_rate)

            l1_loss = torch.nn.functional.l1_loss(pred_signal.audio_data, gt_signal.audio_data).item()
            l1_loss_avg.add(l1_loss)

            # Save samples
            for i in range(min(gt_signal.batch_size, 5)):
                idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
                if max_samples is None or idx < max_samples:
                    # As above, reshape on the tensor before converting to numpy
                    tmp_recon = pred_signal[i].audio_data.cpu()
                    if tmp_recon.dim() == 3:
                        tmp_recon = tmp_recon.squeeze(0)
                    elif tmp_recon.dim() == 1:
                        tmp_recon = tmp_recon.unsqueeze(0)
                    tmp_recon = tmp_recon.numpy().T

                    tmp_signal = gt_signal[i].audio_data.cpu()
                    if tmp_signal.dim() == 3:
                        tmp_signal = tmp_signal.squeeze(0)
                    elif tmp_signal.dim() == 1:
                        tmp_signal = tmp_signal.unsqueeze(0)
                    tmp_signal = tmp_signal.numpy().T

                    sf.write(
                        os.path.join(cache_gen_dir, f'{idx}.wav'),
                        tmp_recon, int(pred_signal[i].sample_rate)
                    )
                    sf.write(
                        os.path.join(cache_gt_dir, f'{idx}.wav'),
                        tmp_signal, int(gt_signal[i].sample_rate)
                    )
                cnt += 1
        dist.barrier()

        # Sync metrics
        for avg_metric in [l1_loss_avg]:
            vt = torch.tensor(avg_metric.item(), device=self.device)
            dist.all_reduce(vt, op=dist.ReduceOp.SUM)
            torch.cuda.synchronize()
            avg_metric.v = vt.item() / int(os.environ['WORLD_SIZE'])

        if self.is_master:
            prefix = 'eval_zdm' + ('_ema' if ema else '')
            ret = {
                f'{prefix}/l1_loss_avg': l1_loss_avg.item(),
            }
        else:
            ret = {}
        dist.barrier()
        ret = {k: utils.Averager(v) for k, v in ret.items()}
        return ret
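
    # NOTE (assumption): `self.model.generate_samples(...)` and
    # `self.model.zdm_n_steps` are provided by the LDM wrapper; the sampler
    # is expected to return waveform tensors shaped like the ground-truth
    # batch, which are wrapped back into an AudioSignal for metrics/saving.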
    def visualize_audio_ae_random(self):
        """Save random audio reconstructions for listening"""
        if self.is_master:
            idx_list = list(range(len(self.datasets['eval_ae'])))
            random.shuffle(idx_list)
            n_samples = self.config.get('visualize_ae_random_n_samples', 8)

            for idx in idx_list[:n_samples]:
                data = self.datasets['eval_ae'][idx]

                # Prepare data
                if 'signal' in data:
                    signal = data['signal'].unsqueeze(0).to(self.device)
                    model_data = {
                        'inp': signal.audio_data,
                        'gt': signal.audio_data,
                        'sample_rate': signal.sample_rate,
                    }
                else:
                    for k, v in data.items():
                        data[k] = v.unsqueeze(0).to(self.device) if torch.is_tensor(v) else v
                    signal = AudioSignal(data['inp'], data.get('sample_rate', 24000))
                    model_data = data

                # Get reconstruction
                pred_audio = self.model(model_data, mode='pred')
                if isinstance(pred_audio, dict):
                    pred_audio = pred_audio.get('audio', pred_audio.get('recons', pred_audio))
                recons = AudioSignal(pred_audio, signal.sample_rate)

                # Save to file and log to Comet
                self.save_audio_sample(signal, f'audio_ae_original_{idx}')
                self.save_audio_sample(recons, f'audio_ae_recons_{idx}')
        dist.barrier()

    def visualize_audio_zdm_random(self, ema):
        """Save random audio generations from the latent diffusion model"""
        if self.is_master:
            n_samples = self.config.get('visualize_zdm_random_n_samples', 8)

            for i in range(n_samples):
                # Generate a random sample
                net_kwargs = dict()
                uncond_net_kwargs = dict()

                # Get a reference item from the dataset for parameters like sample_rate
                ref_data = self.datasets['eval_ae'][0]
                if 'signal' in ref_data:
                    ref_signal = ref_data['signal']
                    sample_rate = ref_signal.sample_rate
                    batch_size = 1
                else:
                    sample_rate = ref_data.get('sample_rate', 24000)
                    batch_size = 1

                pred_audio = self.model.generate_samples(
                    batch_size=batch_size,
                    n_steps=self.model.zdm_n_steps,
                    net_kwargs=net_kwargs,
                    uncond_net_kwargs=uncond_net_kwargs,
                    ema=ema,
                )
                pred_signal = AudioSignal(pred_audio, sample_rate)

                # Save generated audio
                self.save_audio_sample(pred_signal, f'audio_zdm_generated_{i}')
        dist.barrier()
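
    # soundfile expects arrays shaped [samples] or [samples, channels], so
    # save_audio_sample below flattens the leading batch dim and transposes
    # the usual [channels, samples] tensor layout before writing.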
    def save_audio_sample(self, audio_signal, name):
        """Save an audio sample and log it to Comet ML"""
        try:
            # Ensure audio is in the correct format
            audio_data = audio_signal.audio_data.cpu()

            # Handle different dimensions
            if audio_data.dim() == 3:  # [batch, channels, samples]
                audio_data = audio_data[0]  # Take first sample
            if audio_data.dim() == 2:  # [channels, samples]
                audio_data = audio_data.transpose(0, 1)  # [samples, channels]
            elif audio_data.dim() == 1:  # [samples]
                audio_data = audio_data.unsqueeze(1)  # [samples, 1]

            audio_data = audio_data.numpy()

            # Normalize if needed
            if np.abs(audio_data).max() > 1.0:
                audio_data = audio_data / np.abs(audio_data).max()

            # Save to file
            save_path = os.path.join(self.env['save_dir'], 'audio_samples')
            os.makedirs(save_path, exist_ok=True)
            file_path = os.path.join(save_path, f'{name}_step_{self.iter}.wav')
            sf.write(file_path, audio_data, int(audio_signal.sample_rate))

            # Log to Comet ML
            if self.experiment:
                self.experiment.log_audio(
                    file_path,
                    metadata={
                        'name': name,
                        'step': self.iter,
                        'sample_rate': int(audio_signal.sample_rate),
                        'duration': len(audio_data) / audio_signal.sample_rate,
                        'channels': audio_data.shape[1] if audio_data.ndim > 1 else 1,
                    },
                    step=self.iter,
                )

                # Also log spectrograms for visualization
                if self.iter % self.config.get('spectrogram_log_freq', 1000) == 0:
                    self._log_spectrogram(audio_signal, name)

            self.log(f"Saved audio sample: {file_path}")

        except Exception as e:
            self.log(f"Error saving audio sample {name}: {e}")
            if self.experiment:
                self.experiment.log_text(f"Error saving audio {name}: {str(e)}", step=self.iter)

    def _log_spectrogram(self, audio_signal, name):
        """Log a spectrogram visualization to Comet ML"""
        if not self.experiment:
            return
        try:
            # Compute spectrogram
            stft_transform = torchaudio.transforms.Spectrogram(
                n_fft=2048, hop_length=512, power=2
            )
            audio_data = audio_signal.audio_data
            if audio_data.dim() == 3:
                audio_data = audio_data[0]
            if audio_data.dim() == 2:
                audio_data = audio_data[0]  # Take first channel

            spec = stft_transform(audio_data.cpu())
            spec_db = 10 * torch.log10(spec + 1e-8)

            # Create figure
            fig, ax = plt.subplots(figsize=(10, 4))
            im = ax.imshow(
                spec_db.numpy(), aspect='auto', origin='lower', cmap='viridis',
                extent=[0, len(audio_data) / audio_signal.sample_rate,
                        0, audio_signal.sample_rate / 2],
            )
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Frequency (Hz)')
            ax.set_title(f'{name} - Spectrogram')
            plt.colorbar(im, ax=ax, label='dB')

            # Log to Comet
            self.experiment.log_figure(f"spectrogram/{name}", fig, step=self.iter)
            plt.close(fig)

        except Exception as e:
            self.log(f"Error logging spectrogram for {name}: {e}")

    def save_checkpoint(self, tag="latest"):
        """Save a checkpoint and log it to Comet ML"""
        checkpoint_path = super().save_checkpoint(tag)

        if self.experiment and checkpoint_path:
            # Log checkpoint to Comet
            self.experiment.log_model(
                f"checkpoint_{tag}",
                checkpoint_path,
                metadata={
                    "step": self.iter,
                    "tag": tag,
                    "timestamp": datetime.now().isoformat(),
                },
            )
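
# A minimal config sketch for this trainer (key names are taken from the
# self.config lookups above; the values shown are illustrative assumptions,
# not defaults from this repo):
#
#   max_iter: 200000          # total training iterations
#   epoch_iter: 1000          # iterations per logging "epoch"
#   save_iter: 10000          # must be divisible by epoch_iter
#   eval_iter: 10000          # must be divisible by epoch_iter
#   vis_iter: 10000           # must be divisible by epoch_iter
#   ckpt_select_metric: {name: eval_ae/L1_Loss, type: min}
#   autocast_bfloat16: true
#   evaluate_ae: true
#   evaluate_zdm: true
#   evaluate_zdm_ema: true
#   eval_ae_max_samples: 1000
#   eval_zdm_max_samples: 1000
#   visualize_ae_random_n_samples: 8
#   visualize_zdm_random_n_samples: 8
#   spectrogram_log_freq: 1000
#   optimizers: {...}         # one spec per parameter group, e.g. 'ae', 'zdm', 'disc'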