import pytorch_lightning as pl
import sys, gc
import random
import torch
import torchaudio
import typing as tp
import wandb

from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
from ema_pytorch import EMA
from einops import rearrange
from safetensors.torch import save_file
from torch import optim
from torch.nn import functional as F
from pytorch_lightning.utilities.rank_zero import rank_zero_only

from ..models.lm import AudioLanguageModelWrapper
from .utils import create_optimizer_from_config, create_scheduler_from_config

class AudioLanguageModelTrainingWrapper(pl.LightningModule):
    def __init__(
        self,
        model: AudioLanguageModelWrapper,
        lr: float = 1e-4,
        use_ema: bool = False,
        ema_copy=None,
        optimizer_configs: tp.Optional[dict] = None,
        pre_encoded: bool = False
    ):
        super().__init__()

        self.model = model

        # The pretransform (audio tokenizer) is frozen; only the language model is trained
        self.model.pretransform.requires_grad_(False)

        # Optionally keep an exponential moving average of the model weights
        self.model_ema = None
        if use_ema:
            self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)

        assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"

        if optimizer_configs is None:
            optimizer_configs = {
                "lm": {
                    "optimizer": {
                        "type": "AdamW",
                        "config": {
                            "lr": lr,
                            "betas": (0.9, 0.95),
                            "weight_decay": 0.1
                        }
                    }
                }
            }
        elif lr is not None:
            print("WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")

        self.optimizer_configs = optimizer_configs

        self.pre_encoded = pre_encoded

    def configure_optimizers(self):
        lm_opt_config = self.optimizer_configs['lm']
        opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters())

        if "scheduler" in lm_opt_config:
            sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm)
            sched_lm_config = {
                "scheduler": sched_lm,
                "interval": "step"
            }
            return [opt_lm], [sched_lm_config]

        return [opt_lm]
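    # An illustrative optimizer_configs entry that also requests a scheduler. The
    # dict layout matches the default built in __init__ above; the "type" strings
    # here are hypothetical and depend on what create_optimizer_from_config /
    # create_scheduler_from_config in .utils actually accept:
    #
    #   optimizer_configs = {
    #       "lm": {
    #           "optimizer": {"type": "AdamW", "config": {"lr": 1e-4, "betas": (0.9, 0.95), "weight_decay": 0.1}},
    #           "scheduler": {"type": "CosineAnnealingLR", "config": {"T_max": 1_000_000}}
    #       }
    #   }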

    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.
        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Valid timesteps for each codebook are pulled from the mask, where invalid
        timesteps are set to 0.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks.
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # Average the cross entropy over the codebooks
        ce = ce / K
        return ce, ce_per_codebook
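    # Shape sanity check (hypothetical numbers): with B=2 sequences, K=4 codebooks,
    # T=100 timesteps, and a codebook size ("card") of 2048:
    #
    #   logits  = torch.randn(2, 4, 100, 2048)
    #   targets = torch.randint(0, 2048, (2, 4, 100))
    #   mask    = torch.ones(2, 4, 100, dtype=torch.bool)
    #   ce, ce_per_q = self._compute_cross_entropy(logits, targets, mask)
    #
    # ce is a scalar averaged over the 4 codebooks; an untrained (near-uniform)
    # model gives ce ≈ ln(2048) ≈ 7.62, so train/perplexity starts near 2048.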

    def training_step(self, batch, batch_idx):
        reals, metadata = batch

        # Squeeze a leading singleton batch dimension: [1, B, C, N] -> [B, C, N]
        if reals.ndim == 4 and reals.shape[0] == 1:
            reals = reals[0]

        # Tokenize the audio with the frozen pretransform, unless the dataset is pre-encoded
        if not self.pre_encoded:
            codes = self.model.pretransform.tokenize(reals)
        else:
            codes = reals

        padding_masks = []
        for md in metadata:
            if md["padding_mask"].ndim == 1:
                padding_masks.append(md["padding_mask"])
            else:
                # Stacked per-channel masks: take the first row
                padding_masks.append(md["padding_mask"][0])

        padding_masks = torch.stack(padding_masks, dim=0).to(self.device)  # [B, N]

        # Interpolate the sample-level padding masks down to the token length of the codes
        padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool()

        condition_tensors = None

        # If the model is conditioned, compute the conditioning tensors from the metadata
        if self.model.conditioner is not None:
            condition_tensors = self.model.conditioner(metadata, self.device)

        lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1)

        logits = lm_output.logits  # [B, K, T, card]
        logits_mask = lm_output.mask  # [B, K, T]

        # Exclude padded timesteps from the loss
        logits_mask = logits_mask & padding_masks

        cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask)

        loss = cross_entropy

        log_dict = {
            'train/loss': loss.detach(),
            'train/cross_entropy': cross_entropy.detach(),
            'train/perplexity': torch.exp(cross_entropy).detach(),
            'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
        }

        for k, ce_q in enumerate(cross_entropy_per_codebook):
            log_dict[f'cross_entropy_q{k + 1}'] = ce_q
            log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q)

        self.log_dict(log_dict, prog_bar=True, on_step=True)
        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        # Called once per optimizer step, just before gradients are zeroed;
        # with update_every=10 the EMA weights only advance every 10th call
        if self.model_ema is not None:
            self.model_ema.update()

    def export_model(self, path, use_safetensors=False):
        # Prefer the EMA weights for export, if EMA is enabled
        model = self.model_ema.ema_model if self.model_ema is not None else self.model

        if use_safetensors:
            save_file(model.state_dict(), path)
        else:
            torch.save({"state_dict": model.state_dict()}, path)


class AudioLanguageModelDemoCallback(pl.Callback):
    def __init__(self,
                 demo_every=2000,
                 num_demos=8,
                 sample_size=65536,
                 sample_rate=48000,
                 demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
                 demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
                 **kwargs
                 ):
        super().__init__()

        self.demo_every = demo_every
        self.num_demos = num_demos
        self.demo_samples = sample_size
        self.sample_rate = sample_rate
        self.last_demo_step = -1
        self.demo_conditioning = demo_conditioning
        self.demo_cfg_scales = demo_cfg_scales

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx):

        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
            return

        module.eval()

        print("Generating demo")
        self.last_demo_step = trainer.global_step

        # Number of tokens needed to cover self.demo_samples audio samples
        demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio
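        # For example, with the default sample_size=65536 and a pretransform that
        # downsamples audio by 1024x, this would request 64 tokens (the actual
        # ratio depends on the pretransform used).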

        try:
            print("Getting conditioning")

            for cfg_scale in self.demo_cfg_scales:

                model = module.model

                print(f"Generating demo for cfg scale {cfg_scale}")
                fakes = model.generate_audio(
                    batch_size=self.num_demos,
                    max_gen_len=demo_length_tokens,
                    conditioning=self.demo_conditioning,
                    cfg_scale=cfg_scale,
                    temp=1.0,
                    top_p=0.95
                )

                # Concatenate the demos along the time axis into one clip: [B, D, N] -> [D, B*N]
                fakes = rearrange(fakes, 'b d n -> d (b n)')

                log_dict = {}

                filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
                # Peak-normalize and convert to 16-bit PCM for saving
                fakes = fakes / fakes.abs().max()
                fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu()
                torchaudio.save(filename, fakes, self.sample_rate)

                log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
                                                                sample_rate=self.sample_rate,
                                                                caption='Generated')

                log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))

                trainer.logger.experiment.log(log_dict)

        except Exception as e:
            raise e
        finally:
            # Always free memory and restore training mode, even if generation fails
            gc.collect()
            torch.cuda.empty_cache()
            module.train()
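
# Usage sketch (hypothetical wiring; the model, dataloader, and EMA copy are built
# elsewhere in the codebase, and the conditioning format must match the model's
# conditioner):
#
#   training_wrapper = AudioLanguageModelTrainingWrapper(model, lr=1e-4, use_ema=True, ema_copy=ema_copy)
#   demo_callback = AudioLanguageModelDemoCallback(
#       demo_every=2000,
#       sample_size=65536,
#       sample_rate=48000,
#       demo_conditioning=[{"prompt": "..."}] * 8,
#       demo_cfg_scales=[3, 5, 7],
#   )
#   trainer = pl.Trainer(accelerator="gpu", devices=1, callbacks=[demo_callback], max_steps=1_000_000)
#   trainer.fit(training_wrapper, train_dataloader)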