| | import torch |
| | import torchaudio |
| | import wandb |
| | from einops import rearrange |
| | from safetensors.torch import save_file, save_model |
| | from ema_pytorch import EMA |
| | from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss |
| | import pytorch_lightning as pl |
| | from ..models.autoencoders import AudioAutoencoder |
| | from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss |
| | from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck |
| | from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss |
| | from .utils import create_optimizer_from_config, create_scheduler_from_config |
| |
|
| |
|
| | from pytorch_lightning.utilities.rank_zero import rank_zero_only |
| | from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image |
| |
|
class AutoencoderTrainingWrapper(pl.LightningModule):
    """Lightning module that trains an ``AudioAutoencoder`` adversarially.

    Uses manual optimization. On each step, when ``global_step`` is odd and
    warm-up has finished, the discriminator is updated; otherwise the
    autoencoder (generator) is updated against a weighted sum of spectral,
    adversarial, feature-matching, optional time-domain L1, bottleneck and
    optional teacher-distillation losses. Before ``warmup_steps`` is reached
    the discriminator-derived losses are zero tensors.
    """

    def __init__(
        self,
        autoencoder: AudioAutoencoder,
        lr: float = 1e-4,
        warmup_steps: int = 0,
        encoder_freeze_on_warmup: bool = False,
        sample_rate=48000,
        loss_config: dict = None,
        optimizer_configs: dict = None,
        use_ema: bool = True,
        ema_copy=None,
        force_input_mono=False,
        latent_mask_ratio=0.0,
        teacher_model: AudioAutoencoder = None
    ):
        """
        Args:
            autoencoder: the model being trained.
            lr: learning rate used only by the default optimizer configs.
            warmup_steps: global step from which the discriminator and its
                losses are enabled.
            encoder_freeze_on_warmup: if True, the encoder runs under
                ``torch.no_grad`` once warmed up.
            sample_rate: sample rate passed to the STFT losses and the DAC
                discriminator.
            loss_config: loss configuration dict; ``_default_loss_config``
                shows the expected keys.
            optimizer_configs: dict with ``'autoencoder'`` and
                ``'discriminator'`` entries, each holding an ``'optimizer'``
                config and optionally a ``'scheduler'`` config.
            use_ema: keep an EMA copy of the autoencoder.
            ema_copy: optional model passed to ``EMA(ema_model=...)``.
            force_input_mono: average multichannel input to one channel
                before encoding.
            latent_mask_ratio: probability that each latent element is
                zeroed before decoding.
            teacher_model: optional teacher for distillation; it is frozen
                (``requires_grad_(False)``) and never optimized.

        Raises:
            ValueError: if ``loss_config['discriminator']['type']`` is not
                ``'oobleck'``, ``'encodec'`` or ``'dac'``.
        """
        super().__init__()

        # Generator and discriminator are stepped by hand in training_step.
        self.automatic_optimization = False

        self.autoencoder = autoencoder

        self.warmed_up = False
        self.warmup_steps = warmup_steps
        self.encoder_freeze_on_warmup = encoder_freeze_on_warmup
        self.lr = lr

        self.force_input_mono = force_input_mono

        self.teacher_model = teacher_model
        if teacher_model is not None:
            # Only autoencoder parameters are handed to the generator optimizer,
            # so the teacher is never updated. Freezing it lets gradients flow
            # *through* it to the student without filling its own .grad buffers.
            teacher_model.requires_grad_(False)

        if optimizer_configs is None:
            optimizer_configs = self._default_optimizer_configs(lr)
        self.optimizer_configs = optimizer_configs

        if loss_config is None:
            loss_config = self._default_loss_config()
        self.loss_config = loss_config

        stft_loss_args = loss_config['spectral']['config']
        if self.autoencoder.out_channels == 2:
            self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
            self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
        else:
            self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)

        self.discriminator = self._build_discriminator(loss_config['discriminator'], sample_rate)

        self.gen_loss_modules = self._build_gen_loss_modules()
        self.losses_gen = MultiLoss(self.gen_loss_modules)

        self.disc_loss_modules = [
            ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'),
        ]
        self.losses_disc = MultiLoss(self.disc_loss_modules)

        self.autoencoder_ema = None
        self.use_ema = use_ema
        if self.use_ema:
            self.autoencoder_ema = EMA(
                self.autoencoder,
                ema_model=ema_copy,
                beta=0.9999,
                power=3/4,
                update_every=1,
                update_after_step=1
            )

        self.latent_mask_ratio = latent_mask_ratio

    @staticmethod
    def _default_optimizer_configs(lr):
        """Return AdamW(lr, betas=(0.8, 0.99)) configs for both networks."""
        def adamw():
            return {
                "optimizer": {
                    "type": "AdamW",
                    "config": {
                        "lr": lr,
                        "betas": (.8, .99)
                    }
                }
            }

        return {
            "autoencoder": adamw(),
            "discriminator": adamw(),
        }

    @staticmethod
    def _default_loss_config():
        """Return the default loss config.

        The config uses an EnCodec discriminator, perceptually weighted
        multi-resolution STFT losses over 7 FFT sizes with 75% window overlap,
        and the time-domain L1 loss disabled (weight 0).
        """
        scales = [2048, 1024, 512, 256, 128, 64, 32]
        overlap = 0.75
        hop_sizes = [int(s * (1 - overlap)) for s in scales]
        win_lengths = list(scales)

        return {
            "discriminator": {
                "type": "encodec",
                "config": {
                    "n_ffts": scales,
                    "hop_lengths": hop_sizes,
                    "win_lengths": win_lengths,
                    "filters": 32
                },
                "weights": {
                    "adversarial": 0.1,
                    "feature_matching": 5.0,
                }
            },
            "spectral": {
                "type": "mrstft",
                "config": {
                    "fft_sizes": scales,
                    "hop_sizes": hop_sizes,
                    "win_lengths": win_lengths,
                    "perceptual_weighting": True
                },
                "weights": {
                    "mrstft": 1.0,
                }
            },
            "time": {
                "type": "l1",
                "config": {},
                "weights": {
                    "l1": 0.0,
                }
            }
        }

    def _build_discriminator(self, disc_config, sample_rate):
        """Instantiate the discriminator selected by ``disc_config['type']``."""
        disc_type = disc_config['type']
        if disc_type == 'oobleck':
            return OobleckDiscriminator(**disc_config['config'])
        if disc_type == 'encodec':
            return EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **disc_config['config'])
        if disc_type == 'dac':
            return DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **disc_config['config'])
        # Fail here rather than with an AttributeError on first use.
        raise ValueError(f"Unknown discriminator type: {disc_type!r}")

    def _build_gen_loss_modules(self):
        """Assemble the list of generator loss terms from ``self.loss_config``."""
        spectral_weight = self.loss_config['spectral']['weights']['mrstft']
        disc_weights = self.loss_config['discriminator']['weights']

        modules = [
            ValueLoss(key='loss_adv', weight=disc_weights['adversarial'], name='loss_adv'),
            ValueLoss(key='feature_matching_distance', weight=disc_weights['feature_matching'], name='feature_matching'),
        ]

        if self.teacher_model is not None:
            # Distillation: the spectral weight is split evenly over four terms.
            distill_weight = spectral_weight * 0.25
            modules += [
                # Student reconstructs the input.
                AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=distill_weight),
                # Student output matches the teacher's output.
                AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=distill_weight),
                # Student latents are decodable by the teacher's decoder.
                AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=distill_weight),
                # Teacher latents are decodable by the student's decoder.
                AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=distill_weight),
            ]
        else:
            # Reconstruction loss. Added exactly once: a second entry with the
            # same name would double its weight in the total while only one
            # value shows up in the logs.
            modules.append(
                AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=spectral_weight)
            )

            if self.autoencoder.out_channels == 2:
                # Per-channel L/R losses in addition to the sum/difference loss.
                modules += [
                    AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=spectral_weight/2),
                    AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=spectral_weight/2),
                ]

        l1_weight = self.loss_config['time']['weights']['l1']
        if l1_weight > 0.0:
            modules.append(L1Loss(key_a='reals', key_b='decoded', weight=l1_weight, name='l1_time_loss'))

        if self.autoencoder.bottleneck is not None:
            modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config)

        return modules

    def configure_optimizers(self):
        """Create generator/discriminator optimizers and, if both are configured, their schedulers."""
        ae_cfg = self.optimizer_configs['autoencoder']
        disc_cfg = self.optimizer_configs['discriminator']

        opt_gen = create_optimizer_from_config(ae_cfg['optimizer'], self.autoencoder.parameters())
        opt_disc = create_optimizer_from_config(disc_cfg['optimizer'], self.discriminator.parameters())

        if "scheduler" in ae_cfg and "scheduler" in disc_cfg:
            sched_gen = create_scheduler_from_config(ae_cfg['scheduler'], opt_gen)
            sched_disc = create_scheduler_from_config(disc_cfg['scheduler'], opt_disc)
            return [opt_gen, opt_disc], [sched_gen, sched_disc]

        return [opt_gen, opt_disc]

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step for either the discriminator or the generator.

        ``batch`` is unpacked as ``(reals, _)``. Returns the loss that was
        back-propagated on this step.
        """
        reals, _ = batch

        # If the batch arrives 4-D with a leading size-1 dim, drop that dim.
        if reals.ndim == 4 and reals.shape[0] == 1:
            reals = reals[0]

        if self.global_step >= self.warmup_steps:
            self.warmed_up = True

        loss_info = {}

        loss_info["reals"] = reals

        encoder_input = reals

        if self.force_input_mono and encoder_input.shape[1] > 1:
            encoder_input = encoder_input.mean(dim=1, keepdim=True)

        loss_info["encoder_input"] = encoder_input

        data_std = encoder_input.std()

        if self.warmed_up and self.encoder_freeze_on_warmup:
            with torch.no_grad():
                latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
        else:
            latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)

        loss_info["latents"] = latents

        loss_info.update(encoder_info)

        if self.teacher_model is not None:
            with torch.no_grad():
                teacher_latents = self.teacher_model.encode(encoder_input, return_info=False)
            loss_info['teacher_latents'] = teacher_latents

        if self.latent_mask_ratio > 0.0:
            mask = torch.rand_like(latents) < self.latent_mask_ratio
            latents = torch.where(mask, torch.zeros_like(latents), latents)

        decoded = self.autoencoder.decode(latents)

        loss_info["decoded"] = decoded

        if self.autoencoder.out_channels == 2:
            loss_info["decoded_left"] = decoded[:, 0:1, :]
            loss_info["decoded_right"] = decoded[:, 1:2, :]
            loss_info["reals_left"] = reals[:, 0:1, :]
            loss_info["reals_right"] = reals[:, 1:2, :]

        if self.teacher_model is not None:
            # The teacher's own reconstruction is a fixed target.
            with torch.no_grad():
                teacher_decoded = self.teacher_model.decode(teacher_latents)
            # These two must stay in the autograd graph. Otherwise their loss
            # terms are constants and train nothing. The first trains the
            # student encoder through the frozen teacher decoder; the second
            # trains the student decoder on the teacher's latents.
            own_latents_teacher_decoded = self.teacher_model.decode(latents)
            teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents)

            loss_info['teacher_decoded'] = teacher_decoded
            loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded
            loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded

        if self.warmed_up:
            loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded)
        else:
            loss_dis = torch.tensor(0.).to(reals)
            loss_adv = torch.tensor(0.).to(reals)
            feature_matching_distance = torch.tensor(0.).to(reals)

        loss_info["loss_dis"] = loss_dis
        loss_info["loss_adv"] = loss_adv
        loss_info["feature_matching_distance"] = feature_matching_distance

        opt_gen, opt_disc = self.optimizers()

        lr_schedulers = self.lr_schedulers()

        sched_gen = None
        sched_disc = None

        if lr_schedulers is not None:
            sched_gen, sched_disc = lr_schedulers

        # Alternate updates: the discriminator on odd steps once warmed up,
        # the generator otherwise.
        if self.global_step % 2 and self.warmed_up:
            loss, losses = self.losses_disc(loss_info)

            log_dict = {
                'train/disc_lr': opt_disc.param_groups[0]['lr']
            }

            opt_disc.zero_grad()
            self.manual_backward(loss)
            opt_disc.step()

            if sched_disc is not None:
                sched_disc.step()
        else:
            loss, losses = self.losses_gen(loss_info)

            opt_gen.zero_grad()
            self.manual_backward(loss)
            opt_gen.step()

            if sched_gen is not None:
                sched_gen.step()

            # Update the EMA after the step so it averages the freshly updated weights.
            if self.use_ema:
                self.autoencoder_ema.update()

            log_dict = {
                'train/loss': loss.detach(),
                'train/latent_std': latents.std().detach(),
                'train/data_std': data_std.detach(),
                'train/gen_lr': opt_gen.param_groups[0]['lr']
            }

        for loss_name, loss_value in losses.items():
            log_dict[f'train/{loss_name}'] = loss_value.detach()

        self.log_dict(log_dict, prog_bar=True, on_step=True)

        return loss

    def export_model(self, path, use_safetensors=False):
        """Save the EMA model (if present) or the raw autoencoder to ``path``.

        With ``use_safetensors`` the model is written via ``safetensors.torch.save_model``;
        otherwise via ``torch.save({"state_dict": ...})``.
        """
        if self.autoencoder_ema is not None:
            model = self.autoencoder_ema.ema_model
        else:
            model = self.autoencoder

        if use_safetensors:
            save_model(model, path)
        else:
            torch.save({"state_dict": model.state_dict()}, path)
| | |
| |
|
class AutoencoderDemoCallback(pl.Callback):
    """Periodically reconstruct a demo batch and log the results.

    Runs on rank zero only, when ``(global_step - 1) % demo_every == 0``. It
    writes an interleaved real/reconstructed WAV file and logs audio, a latent
    PCA point cloud and spectrogram images through
    ``trainer.logger.experiment.log``. That call receives ``wandb`` media
    objects, so a W&B logger is presumably required.
    """

    def __init__(
        self,
        demo_dl,
        demo_every=2000,
        sample_size=65536,
        sample_rate=48000
    ):
        super().__init__()
        self.demo_every = demo_every
        self.demo_samples = sample_size
        # Keep the source loader so the iterator can be restarted when it runs out.
        self._demo_loader = demo_dl
        self.demo_dl = iter(demo_dl)
        self.sample_rate = sample_rate
        self.last_demo_step = -1

    def _next_demo_batch(self):
        """Return the next demo batch, restarting the loader once it is exhausted.

        A ``StopIteration`` escaping a training hook could be misread by an
        enclosing loop as end-of-data. Exhaustion is therefore handled here,
        and an empty loader is reported explicitly.

        Raises:
            RuntimeError: if the demo loader yields no batches even after a restart.
        """
        try:
            return next(self.demo_dl)
        except StopIteration:
            self.demo_dl = iter(self._demo_loader)
        try:
            return next(self.demo_dl)
        except StopIteration:
            raise RuntimeError("Demo dataloader yielded no batches") from None

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        """Reconstruct one demo batch with the (EMA) autoencoder and log it."""
        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
            return

        self.last_demo_step = trainer.global_step

        module.eval()

        try:
            demo_reals, _ = self._next_demo_batch()

            # If the batch arrives 4-D with a leading size-1 dim, drop that dim.
            if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
                demo_reals = demo_reals[0]

            encoder_input = demo_reals.to(module.device)

            if module.force_input_mono:
                encoder_input = encoder_input.mean(dim=1, keepdim=True)

            demo_reals = demo_reals.to(module.device)

            # Already under no_grad via the decorator.
            model = module.autoencoder_ema.ema_model if module.use_ema else module.autoencoder
            latents = model.encode(encoder_input)
            fakes = model.decode(latents)

            # Interleave real/fake per example, then lay everything out along time.
            reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
            reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')

            filename = f'recon_{trainer.global_step:08}.wav'
            reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            torchaudio.save(filename, reals_fakes, self.sample_rate)

            log_dict = {
                'recon': wandb.Audio(filename, sample_rate=self.sample_rate, caption='Reconstructed'),
                'embeddings_3dpca': pca_point_cloud(latents),
                'embeddings_spec': wandb.Image(tokens_spectrogram_image(latents)),
                'recon_melspec_left': wandb.Image(audio_spectrogram_image(reals_fakes)),
            }

            trainer.logger.experiment.log(log_dict)
        except Exception as e:
            print(f'{type(e).__name__}: {e}')
            raise
        finally:
            module.train()
| |
|
def create_loss_modules_from_bottleneck(bottleneck, loss_config):
    """Return the regularization loss modules required by ``bottleneck``.

    - VAE-style bottlenecks get a KL term. Its weight is
      ``loss_config['bottleneck']['weights']['kl']``, default 1e-6.
    - RVQ bottlenecks get a quantizer loss (weight 1.0).
    - DAC RVQ bottlenecks get codebook (1.0) and commitment (0.25) losses.
    - Wasserstein bottlenecks get an MMD term. Its weight is
      ``loss_config['bottleneck']['weights']['mmd']``, default 100.

    A bottleneck matching several categories (e.g. ``RVQVAEBottleneck``)
    receives every applicable term.
    """
    def _configured_weight(name, default):
        # The 'bottleneck' section (or its 'weights') may be absent or None.
        try:
            return loss_config['bottleneck']['weights'][name]
        except (KeyError, TypeError):
            return default

    losses = []

    if isinstance(bottleneck, (VAEBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck)):
        losses.append(ValueLoss(key='kl', weight=_configured_weight('kl', 1e-6), name='kl_loss'))

    if isinstance(bottleneck, (RVQBottleneck, RVQVAEBottleneck)):
        losses.append(ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss'))

    if isinstance(bottleneck, (DACRVQBottleneck, DACRVQVAEBottleneck)):
        losses.append(ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss'))
        losses.append(ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss'))

    if isinstance(bottleneck, WassersteinBottleneck):
        losses.append(ValueLoss(key='mmd', weight=_configured_weight('mmd', 100), name='mmd_loss'))

    return losses