import os
import random
import time
from datetime import datetime

import torch
import torch.distributed as dist
import torchaudio
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import utils
from .trainers import register
from trainers.base_trainer import BaseTrainer
from models.ldm.dac.audiotools import AudioSignal


class AudioLDMTrainer(BaseTrainer):
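    """Trainer for an audio latent diffusion model.

    Wraps BaseTrainer with audio-specific data handling, autoencoder and
    latent-diffusion evaluation, and Comet ML logging (metrics, gradients,
    audio samples, spectrograms).
    """
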
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def make_model(self):
        super().make_model()
        self.has_optimizer = dict()
        total_params = 0
        for name, m in self.model.named_children():
            params = utils.compute_num_params(m, text=False)
            self.log(f' .{name} {params}')
            total_params = total_params + params
            # Log per-module parameter counts to Comet
            if self.experiment:
                self.experiment.log_metric(f"model/{name}_params", params)
        if self.experiment:
            self.experiment.log_metric("model/total_params", total_params)
    def make_optimizers(self):
        self.optimizers = dict()
        self.has_optimizer = dict()
        for name, spec in self.config.optimizers.items():
            self.optimizers[name] = utils.make_optimizer(self.model.get_parameters(name), spec)
            self.has_optimizer[name] = True
            # Log optimizer config to Comet
            if self.experiment:
                self.experiment.log_parameters({
                    f"optimizer/{name}/type": spec.get("type", "adam"),
                    f"optimizer/{name}/lr": spec.get("lr", 1e-4),
                    f"optimizer/{name}/weight_decay": spec.get("weight_decay", 0),
                })
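
    # Example `config.optimizers` entry the loop above consumes (illustrative
    # names; the exact schema is defined by utils.make_optimizer):
    #   optimizers:
    #     ae:   {type: adam, lr: 1.0e-4, weight_decay: 0}
    #     disc: {type: adam, lr: 1.0e-4, weight_decay: 0}
    # Note that train_step() below steps every optimizer except 'disc'.
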
    def train_step(self, data, bp=True):
        kwargs = {'has_optimizer': self.has_optimizer}
        # Start timing
        step_start_time = time.time()

        # Audio-specific data preparation
        if 'signal' in data:
            # Convert AudioSignal to the tensor format expected by the model
            audio_data = data['signal'].audio_data  # [batch, channels, samples]
            sample_rate = data['signal'].sample_rate
            # Prepare data dict for the model
            model_data = {
                'inp': audio_data,
                'gt': audio_data,  # For autoencoder training
                'sample_rate': sample_rate,
            }
        else:
            model_data = data
        # self.log(f'Audio data shape: {model_data["inp"].shape}')

        # Periodically log batch info to Comet
        if self.experiment and self.iter % 500 == 0:
            self.experiment.log_metric("train/batch_size", model_data["inp"].shape[0], step=self.iter)
            self.experiment.log_metric("train/audio_length_samples", model_data["inp"].shape[-1], step=self.iter)
            self.experiment.log_metric("train/audio_duration_sec",
                                       model_data["inp"].shape[-1] / model_data.get("sample_rate", 24000),
                                       step=self.iter)

        if self.config.get('autocast_bfloat16', False):
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                ret = self.model_ddp(model_data, mode='loss', **kwargs)
        else:
            ret = self.model_ddp(model_data, mode='loss', **kwargs)
        loss = ret.pop('loss')
        ret['loss'] = loss.item()

        if bp:
            self.model_ddp.zero_grad(set_to_none=True)
            loss.backward()
            # Log gradients to Comet
            if self.experiment and self.iter % 5 == 0:
                self._log_gradients()
            # Step all optimizers except the discriminator's
            for name, o in self.optimizers.items():
                if name != 'disc':
                    o.step()
            if hasattr(self.model, 'update_ema'):
                self.model.update_ema()

        # Log training metrics to Comet
        if self.experiment:
            # Log all losses
            for k, v in ret.items():
                if 'loss' in k.lower():
                    self.experiment.log_metric(f"train/{k}", v, step=self.iter)
            # Log learning rates
            for name, opt in self.optimizers.items():
                lr = opt.param_groups[0]['lr']
                self.experiment.log_metric(f"train/lr_{name}", lr, step=self.iter)
            # Log timing
            step_time = time.time() - step_start_time
            self.experiment.log_metric("train/step_time", step_time, step=self.iter)
            # Log GPU memory usage (GB)
            if torch.cuda.is_available():
                self.experiment.log_metric("train/gpu_memory_allocated",
                                           torch.cuda.memory_allocated() / 1e9,
                                           step=self.iter)
                self.experiment.log_metric("train/gpu_memory_reserved",
                                           torch.cuda.memory_reserved() / 1e9,
                                           step=self.iter)
        return ret
    def _log_gradients(self):
        """Log gradient statistics to Comet ML, aggregated per top-level module."""
        if not self.experiment:
            return
        grad_stats = {}
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                grad_mean = param.grad.mean().item()
                grad_std = param.grad.std().item()
                # Aggregate stats by top-level module name
                module_name = name.split('.')[0]
                if module_name not in grad_stats:
                    grad_stats[module_name] = {'norm': [], 'mean': [], 'std': []}
                grad_stats[module_name]['norm'].append(grad_norm)
                grad_stats[module_name]['mean'].append(grad_mean)
                grad_stats[module_name]['std'].append(grad_std)
        # Log aggregated stats
        for module, stats in grad_stats.items():
            self.experiment.log_metric(f"gradients/{module}/norm_mean", np.mean(stats['norm']), step=self.iter)
            self.experiment.log_metric(f"gradients/{module}/norm_max", np.max(stats['norm']), step=self.iter)
            self.experiment.log_metric(f"gradients/{module}/mean", np.mean(stats['mean']), step=self.iter)
            self.experiment.log_metric(f"gradients/{module}/std", np.mean(stats['std']), step=self.iter)
    def run_training(self):
        config = self.config
        max_iter = config['max_iter']
        epoch_iter = config['epoch_iter']
        assert max_iter % epoch_iter == 0
        max_epoch = max_iter // epoch_iter

        save_iter = config.get('save_iter')
        if save_iter is not None:
            assert save_iter % epoch_iter == 0
            save_epoch = save_iter // epoch_iter
            print('save_epoch', save_epoch)
        else:
            save_epoch = max_epoch + 1

        eval_iter = config.get('eval_iter')
        if eval_iter is not None:
            assert eval_iter % epoch_iter == 0
            eval_epoch = eval_iter // epoch_iter
        else:
            eval_epoch = max_epoch + 1

        vis_iter = config.get('vis_iter')
        if vis_iter is not None:
            assert vis_iter % epoch_iter == 0
            vis_epoch = vis_iter // epoch_iter
        else:
            vis_epoch = max_epoch + 1
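
        # Worked example (illustrative numbers): with max_iter=100000 and
        # epoch_iter=1000, max_epoch=100; save_iter=10000 then saves a numbered
        # checkpoint every 10 "epochs" (every 10000 iterations), and a missing
        # eval_iter/vis_iter sets the period to max_epoch + 1 so it never fires.
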
        if config.get('ckpt_select_metric') is not None:
            m = config.ckpt_select_metric
            self.ckpt_select_metric = m.name
            self.ckpt_select_type = m.type
            if m.type == 'min':
                self.ckpt_select_v = 1e18
            elif m.type == 'max':
                self.ckpt_select_v = -1e18
        else:
            self.ckpt_select_metric = None
            self.ckpt_select_v = 0
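
        # Example (illustrative metric name): ckpt_select_metric: {name: eval_ae/SNR, type: max}
        # would keep ckpt-best.pth at the highest validation SNR seen so far.
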
        self.train_loader = self.loaders['train']
        self.train_loader_sampler = self.loader_samplers['train']
        self.train_loader_epoch = 0
        self.train_loader_iter = None

        self.iter = 0
        if self.resume_ckpt is not None:
            # Replay the iteration counter so state tied to self.iter resumes correctly
            for _ in range(self.resume_ckpt['iter']):
                self.iter += 1
                self.at_train_iter_start()
            self.ckpt_select_v = self.resume_ckpt['ckpt_select_v']
            self.train_loader_epoch = self.resume_ckpt['train_loader_epoch']
            self.train_loader_iter = None
            self.resume_ckpt = None
            self.log('Resumed iter status.')
        self.visualize()

        start_epoch = self.iter // epoch_iter + 1
        for epoch in range(start_epoch, max_epoch + 1):
            self.log_buffer = [f'Epoch {epoch}']
            for sampler in self.loader_samplers.values():
                if sampler is not self.train_loader_sampler:
                    sampler.set_epoch(epoch)
            self.model_ddp.train()

            pbar = range(1, epoch_iter + 1)
            if self.is_master and epoch == start_epoch:
                pbar = tqdm(pbar, desc='train', leave=False)

            t_data = 0
            t_nondata = 0
            t_before_data = time.time()
            for _ in pbar:
                self.iter += 1
                self.at_train_iter_start()
                # Pull the next batch, re-wrapping the loader iterator at epoch boundaries
                try:
                    if self.train_loader_iter is None:
                        raise StopIteration
                    data = next(self.train_loader_iter)
                except StopIteration:
                    self.train_loader_epoch += 1
                    self.train_loader_sampler.set_epoch(self.train_loader_epoch)
                    self.train_loader_iter = iter(self.train_loader)
                    data = next(self.train_loader_iter)
                t_after_data = time.time()
                t_data += t_after_data - t_before_data

                for k, v in data.items():
                    data[k] = v.to(self.device) if torch.is_tensor(v) else v
                ret = self.train_step(data)

                t_before_data = time.time()
                t_nondata += t_before_data - t_after_data
                if self.is_master and epoch == start_epoch:
                    pbar.set_description(desc=f'train: loss={ret["loss"]:.4f}')
                # Save the model every 2000 iterations
                if self.iter % 2000 == 0:
                    self.save_ckpt(f'ckpt-{self.iter}.pth')

            self.save_ckpt('ckpt-last.pth')
            if epoch % save_epoch == 0 and epoch != max_epoch:
                self.save_ckpt(f'ckpt-{self.iter}.pth')

            if epoch % eval_epoch == 0:
                with torch.no_grad():
                    eval_ave_scalars = self.evaluate()
                if self.ckpt_select_metric is not None:
                    v = eval_ave_scalars[self.ckpt_select_metric].item()
                    if ((self.ckpt_select_type == 'min' and v < self.ckpt_select_v) or
                            (self.ckpt_select_type == 'max' and v > self.ckpt_select_v)):
                        self.ckpt_select_v = v
                        self.save_ckpt('ckpt-best.pth')

            if epoch % vis_epoch == 0:
                with torch.no_grad():
                    self.visualize()
    def evaluate(self):
        self.model_ddp.eval()
        ave_scalars = dict()
        pbar = self.loaders['val']
        for data in pbar:
            # Move audio data to the GPU
            if 'signal' in data:
                data['signal'] = data['signal'].to(self.device)
            else:
                for k, v in data.items():
                    data[k] = v.to(self.device) if torch.is_tensor(v) else v
            ret = self.train_step(data, bp=False)
            bs = data['signal'].batch_size if 'signal' in data else len(next(iter(data.values())))
            for k, v in ret.items():
                if ave_scalars.get(k) is None:
                    ave_scalars[k] = utils.Averager()
                ave_scalars[k].add(v, n=bs)
        self.sync_ave_scalars(ave_scalars)

        # Audio-specific evaluation
        if self.config.get('evaluate_ae', False):
            ave_scalars.update(self.evaluate_audio_ae())
        if self.config.get('evaluate_zdm', False):
            ema = self.config.get('evaluate_zdm_ema', True)
            ave_scalars.update(self.evaluate_audio_zdm(ema=ema))

        logtext = 'val:'
        for k, v in ave_scalars.items():
            logtext += f' {k}={v.item():.4f}'
            self.log_scalar('val/' + k, v.item())
            # Log to Comet
            if self.experiment:
                self.experiment.log_metric(f"val/{k}", v.item(), step=self.iter)
        self.log_buffer.append(logtext)
        return ave_scalars
    def visualize(self):
        self.model_ddp.eval()
        if self.config.get('evaluate_ae', False):
            self.visualize_audio_ae_random()
        if self.config.get('evaluate_zdm', False):
            ema = self.config.get('evaluate_zdm_ema', True)
            self.visualize_audio_zdm_random(ema=ema)
    def evaluate_audio_ae(self):
        """Audio autoencoder evaluation with waveform and spectral metrics."""
        max_samples = self.config.get('eval_ae_max_samples', 1000)
        self.loader_samplers['eval_ae'].set_epoch(0)

        l1_loss_avg = utils.Averager()
        snr_avg = utils.Averager()
        spectral_convergence_avg = utils.Averager()
        cnt = 0

        # Create cache directories for audio samples
        cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gen')
        cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gt')
        if self.is_master:
            utils.ensure_path(cache_gen_dir, force_replace=True)
            utils.ensure_path(cache_gt_dir, force_replace=True)
        dist.barrier()

        # Build the spectrogram transform once instead of once per batch
        stft_transform = torchaudio.transforms.Spectrogram(
            n_fft=1024,
            hop_length=256,
            power=2,
        ).to(self.device)

        for data in self.loaders['eval_ae']:
            if 'signal' in data:
                data['signal'] = data['signal'].to(self.device)
                signal = data['signal']
            else:
                for k, v in data.items():
                    data[k] = v.to(self.device) if torch.is_tensor(v) else v
                signal = AudioSignal(data['inp'], data.get('sample_rate', 22050))

            # Get reconstruction
            pred_audio = self.model(data, mode='pred')
            if isinstance(pred_audio, dict):
                pred_audio = pred_audio.get('audio', pred_audio.get('recons', pred_audio))
            recons = AudioSignal(pred_audio, signal.sample_rate)
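
            # Metric definitions implemented below (standard forms):
            #   SNR (dB)              = 10 * log10( E[x^2] / E[(x_hat - x)^2] )
            #   spectral convergence  = ||S(x) - S(x_hat)||_F / ||S(x)||_F,
            #     with S(.) a power spectrogram and ||.||_F the Frobenius norm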
            # SNR calculation
            signal_power = (signal.audio_data ** 2).mean()
            noise_power = ((recons.audio_data - signal.audio_data) ** 2).mean()
            snr = 10 * torch.log10(signal_power / (noise_power + 1e-8))
            snr_avg.add(snr.item())

            # Spectral convergence
            orig_spec = stft_transform(signal.audio_data)
            recon_spec = stft_transform(recons.audio_data)
            spec_diff = torch.norm(orig_spec - recon_spec, p='fro')
            spec_norm = torch.norm(orig_spec, p='fro')
            spectral_convergence = spec_diff / (spec_norm + 1e-8)
            spectral_convergence_avg.add(spectral_convergence.item())

            l1_loss = torch.nn.functional.l1_loss(recons.audio_data, signal.audio_data).item()
            l1_loss_avg.add(l1_loss)
            # Save audio samples for potential subjective evaluation
            for i in range(min(signal.batch_size, 5)):  # Save up to 5 per batch
                # Interleave sample indices across ranks so files do not collide
                idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
                if max_samples is None or idx < max_samples:
                    # Squeeze/expand to [channels, samples], then transpose to
                    # the [samples, channels] layout soundfile expects
                    tmp_recon = recons[i].audio_data.cpu()
                    if tmp_recon.dim() == 3:
                        tmp_recon = tmp_recon.squeeze(0)
                    elif tmp_recon.dim() == 1:
                        tmp_recon = tmp_recon.unsqueeze(0)
                    tmp_recon = tmp_recon.numpy().T

                    tmp_signal = signal[i].audio_data.cpu()
                    if tmp_signal.dim() == 3:
                        tmp_signal = tmp_signal.squeeze(0)
                    elif tmp_signal.dim() == 1:
                        tmp_signal = tmp_signal.unsqueeze(0)
                    tmp_signal = tmp_signal.numpy().T

                    # Save as wav files
                    sf.write(
                        os.path.join(cache_gen_dir, f'{idx}.wav'),
                        tmp_recon,
                        int(recons[i].sample_rate)
                    )
                    sf.write(
                        os.path.join(cache_gt_dir, f'{idx}.wav'),
                        tmp_signal,
                        int(signal[i].sample_rate)
                    )
                cnt += 1
        dist.barrier()

        # Sync metrics across processes (all_reduce sums, then divide by world size)
        for avg_metric in [l1_loss_avg, snr_avg, spectral_convergence_avg]:
            vt = torch.tensor(avg_metric.item(), device=self.device)
            dist.all_reduce(vt, op=dist.ReduceOp.SUM)
            torch.cuda.synchronize()
            avg_metric.v = vt.item() / int(os.environ['WORLD_SIZE'])

        if self.is_master:
            prefix = 'eval_ae'
            ret = {
                f'{prefix}/L1_Loss': l1_loss_avg.item(),
                f'{prefix}/SNR': snr_avg.item(),
                f'{prefix}/Spectral_Convergence': spectral_convergence_avg.item(),
            }
        else:
            ret = {}
        dist.barrier()
        ret = {k: utils.Averager(v) for k, v in ret.items()}
        return ret
    def evaluate_audio_zdm(self, ema):
        """Audio latent diffusion model evaluation."""
        max_samples = self.config.get('eval_zdm_max_samples', 1000)
        self.loader_samplers['eval_zdm'].set_epoch(0)
        cnt = 0
        l1_loss_avg = utils.Averager()

        cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gen')
        cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gt')
        if self.is_master:
            utils.ensure_path(cache_gen_dir, force_replace=True)
            utils.ensure_path(cache_gt_dir, force_replace=True)
        dist.barrier()

        for data in self.loaders['eval_zdm']:
            if 'signal' in data:
                data['signal'] = data['signal'].to(self.device)
                gt_signal = data['signal']
            else:
                for k, v in data.items():
                    data[k] = v.to(self.device) if torch.is_tensor(v) else v
                gt_signal = AudioSignal(data['inp'], data.get('sample_rate', 22050))

            # Generate samples from the latent diffusion model
            net_kwargs = dict()
            uncond_net_kwargs = dict()
            # Add conditioning here if available (e.g., for conditional generation)
            pred_audio = self.model.generate_samples(
                batch_size=gt_signal.batch_size,
                n_steps=self.model.zdm_n_steps,
                net_kwargs=net_kwargs,
                uncond_net_kwargs=uncond_net_kwargs,
                ema=ema
            )
            pred_signal = AudioSignal(pred_audio, gt_signal.sample_rate)

            l1_loss = torch.nn.functional.l1_loss(pred_signal.audio_data, gt_signal.audio_data).item()
            l1_loss_avg.add(l1_loss)

            # Save samples
            for i in range(min(gt_signal.batch_size, 5)):
                idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
                if max_samples is None or idx < max_samples:
                    # Same [samples, channels] conversion as in evaluate_audio_ae
                    tmp_recon = pred_signal[i].audio_data.cpu()
                    if tmp_recon.dim() == 3:
                        tmp_recon = tmp_recon.squeeze(0)
                    elif tmp_recon.dim() == 1:
                        tmp_recon = tmp_recon.unsqueeze(0)
                    tmp_recon = tmp_recon.numpy().T

                    tmp_signal = gt_signal[i].audio_data.cpu()
                    if tmp_signal.dim() == 3:
                        tmp_signal = tmp_signal.squeeze(0)
                    elif tmp_signal.dim() == 1:
                        tmp_signal = tmp_signal.unsqueeze(0)
                    tmp_signal = tmp_signal.numpy().T

                    sf.write(
                        os.path.join(cache_gen_dir, f'{idx}.wav'),
                        tmp_recon,
                        int(pred_signal[i].sample_rate)
                    )
                    sf.write(
                        os.path.join(cache_gt_dir, f'{idx}.wav'),
                        tmp_signal,
                        int(gt_signal[i].sample_rate)
                    )
                cnt += 1
        dist.barrier()

        # Sync metrics across processes
        for avg_metric in [l1_loss_avg]:
            vt = torch.tensor(avg_metric.item(), device=self.device)
            dist.all_reduce(vt, op=dist.ReduceOp.SUM)
            torch.cuda.synchronize()
            avg_metric.v = vt.item() / int(os.environ['WORLD_SIZE'])

        if self.is_master:
            prefix = 'eval_zdm' + ('_ema' if ema else '')
            ret = {
                f'{prefix}/l1_loss_avg': l1_loss_avg.item(),
            }
        else:
            ret = {}
        dist.barrier()
        ret = {k: utils.Averager(v) for k, v in ret.items()}
        return ret
    def visualize_audio_ae_random(self):
        """Save random audio reconstructions for listening."""
        if self.is_master:
            idx_list = list(range(len(self.datasets['eval_ae'])))
            random.shuffle(idx_list)
            n_samples = self.config.get('visualize_ae_random_n_samples', 8)
            for idx in idx_list[:n_samples]:
                data = self.datasets['eval_ae'][idx]
                # Prepare data
                if 'signal' in data:
                    signal = data['signal'].unsqueeze(0).to(self.device)
                    model_data = {
                        'inp': signal.audio_data,
                        'gt': signal.audio_data,
                        'sample_rate': signal.sample_rate
                    }
                else:
                    for k, v in data.items():
                        data[k] = v.unsqueeze(0).to(self.device) if torch.is_tensor(v) else v
                    signal = AudioSignal(data['inp'], data.get('sample_rate', 24000))
                    model_data = data
                # Get reconstruction
                pred_audio = self.model(model_data, mode='pred')
                if isinstance(pred_audio, dict):
                    pred_audio = pred_audio.get('audio', pred_audio.get('recons', pred_audio))
                recons = AudioSignal(pred_audio, signal.sample_rate)
                # Save to file and log to Comet
                self.save_audio_sample(signal, f'audio_ae_original_{idx}')
                self.save_audio_sample(recons, f'audio_ae_recons_{idx}')
        dist.barrier()
    def visualize_audio_zdm_random(self, ema):
        """Save random audio generations from the latent diffusion model."""
        if self.is_master:
            n_samples = self.config.get('visualize_zdm_random_n_samples', 8)
            for i in range(n_samples):
                # Generate a random sample
                net_kwargs = dict()
                uncond_net_kwargs = dict()
                # Get a reference item from the dataset for parameters like sample_rate
                ref_data = self.datasets['eval_ae'][0]
                if 'signal' in ref_data:
                    ref_signal = ref_data['signal']
                    sample_rate = ref_signal.sample_rate
                else:
                    sample_rate = ref_data.get('sample_rate', 24000)
                batch_size = 1
                pred_audio = self.model.generate_samples(
                    batch_size=batch_size,
                    n_steps=self.model.zdm_n_steps,
                    net_kwargs=net_kwargs,
                    uncond_net_kwargs=uncond_net_kwargs,
                    ema=ema
                )
                pred_signal = AudioSignal(pred_audio, sample_rate)
                # Save the generated audio
                self.save_audio_sample(pred_signal, f'audio_zdm_generated_{i}')
        dist.barrier()
    def save_audio_sample(self, audio_signal, name):
        """Save an audio sample to disk and log it to Comet ML."""
        try:
            # Ensure audio is in the correct format
            audio_data = audio_signal.audio_data.cpu()
            # Handle different dimensions
            if audio_data.dim() == 3:  # [batch, channels, samples]
                audio_data = audio_data[0]  # Take the first sample
            if audio_data.dim() == 2:  # [channels, samples]
                audio_data = audio_data.transpose(0, 1)  # [samples, channels]
            elif audio_data.dim() == 1:  # [samples]
                audio_data = audio_data.unsqueeze(1)  # [samples, 1]
            audio_data = audio_data.numpy()

            # Normalize if needed
            if np.abs(audio_data).max() > 1.0:
                audio_data = audio_data / np.abs(audio_data).max()

            # Save to file
            save_path = os.path.join(self.env['save_dir'], 'audio_samples')
            os.makedirs(save_path, exist_ok=True)
            file_path = os.path.join(save_path, f'{name}_step_{self.iter}.wav')
            sf.write(file_path, audio_data, int(audio_signal.sample_rate))

            # Log to Comet ML
            if self.experiment:
                self.experiment.log_audio(
                    file_path,
                    metadata={
                        'name': name,
                        'step': self.iter,
                        'sample_rate': int(audio_signal.sample_rate),
                        'duration': len(audio_data) / audio_signal.sample_rate,
                        'channels': audio_data.shape[1] if audio_data.ndim > 1 else 1
                    },
                    step=self.iter
                )
                # Also log spectrograms for visualization
                if self.iter % self.config.get('spectrogram_log_freq', 1000) == 0:
                    self._log_spectrogram(audio_signal, name)
            self.log(f"Saved audio sample: {file_path}")
        except Exception as e:
            self.log(f"Error saving audio sample {name}: {e}")
            if self.experiment:
                self.experiment.log_text(f"Error saving audio {name}: {str(e)}", step=self.iter)
    def _log_spectrogram(self, audio_signal, name):
        """Log a spectrogram visualization to Comet ML."""
        if not self.experiment:
            return
        try:
            # Compute the power spectrogram
            stft_transform = torchaudio.transforms.Spectrogram(
                n_fft=2048,
                hop_length=512,
                power=2
            )
            audio_data = audio_signal.audio_data
            if audio_data.dim() == 3:
                audio_data = audio_data[0]
            if audio_data.dim() == 2:
                audio_data = audio_data[0]  # Take the first channel
            spec = stft_transform(audio_data.cpu())
            spec_db = 10 * torch.log10(spec + 1e-8)

            # Create the figure
            fig, ax = plt.subplots(figsize=(10, 4))
            im = ax.imshow(
                spec_db.numpy(),
                aspect='auto',
                origin='lower',
                cmap='viridis',
                extent=[0, len(audio_data) / audio_signal.sample_rate, 0, audio_signal.sample_rate / 2]
            )
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Frequency (Hz)')
            ax.set_title(f'{name} - Spectrogram')
            plt.colorbar(im, ax=ax, label='dB')

            # Log to Comet
            self.experiment.log_figure(f"spectrogram/{name}", fig, step=self.iter)
            plt.close(fig)
        except Exception as e:
            self.log(f"Error logging spectrogram for {name}: {e}")
| def save_checkpoint(self, tag="latest"): | |
| """Save checkpoint and log to Comet ML""" | |
| checkpoint_path = super().save_checkpoint(tag) | |
| if self.experiment and checkpoint_path: | |
| # Log checkpoint to Comet | |
| self.experiment.log_model( | |
| f"checkpoint_{tag}", | |
| checkpoint_path, | |
| metadata={ | |
| "step": self.iter, | |
| "tag": tag, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| ) |
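
# A minimal driving sketch (illustrative only; constructor arguments, the
# registry name, and config loading are defined elsewhere in this repo and
# assumed here, not shown in this file):
#   trainer = AudioLDMTrainer(...)  # built by the repo's launcher from a config
#   trainer.make_model()
#   trainer.make_optimizers()
#   trainer.run_training()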