Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| import pytorch_lightning as pl | |
| import torch.nn.functional as F | |
| from contextlib import contextmanager | |
| import numpy as np | |
| from latent_diffusion.modules.ema import * | |
| # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer | |
| from taming.modules.vqvae.quantize import VectorQuantizer | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from latent_diffusion.modules.diffusionmodules.model import Encoder, Decoder | |
| from latent_diffusion.modules.distributions.distributions import ( | |
| DiagonalGaussianDistribution, | |
| ) | |
| import wandb | |
| from latent_diffusion.util import instantiate_from_config | |
| import soundfile as sf | |
| from huggingface_hub import hf_hub_download | |
| from utilities.model import get_vocoder | |
| from utilities.tools import synth_one_sample | |
| import itertools | |
| from latent_encoder.wavedecoder import Generator | |
class AutoencoderKL(pl.LightningModule):
    """KL-regularised convolutional autoencoder over spectrogram-like inputs.

    ``encode`` maps an input to a ``DiagonalGaussianDistribution`` posterior and
    ``decode`` maps latents back to the input space. Depending on ``image_key`` a
    mel vocoder (``"fbank"``) or a learned ``Generator`` wave decoder
    (``"stft"``) is attached so reconstructions can be turned into audio.
    """

    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=(),
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
    ):
        """Build encoder/decoder, loss and audio back-end, optionally loading weights.

        Args:
            ddconfig: Keyword arguments for ``Encoder``/``Decoder``. Must contain
                ``"mel_bins"`` and ``"z_channels"``, and ``"double_z"`` must be truthy.
            lossconfig: Config handed to ``instantiate_from_config`` to build ``self.loss``.
            batchsize: Not used by this class; accepted for config compatibility.
            embed_dim: Number of latent channels.
            time_shuffle: Stored as ``self.time_shuffle``; not used by methods here.
            subband: Number of frequency subbands (only applied when image_key is "stft").
            sampling_rate: Sample rate used when writing/logging waveforms.
            ckpt_path: Optional checkpoint whose ``state_dict`` is loaded non-strictly.
            reload_from_ckpt: Optional checkpoint whose ``state_dict`` is loaded
                strictly at the end of construction. If the path does not exist
                locally, ``vae.ckpt`` is downloaded from the Hugging Face repo
                ``ShandaAI/FlowSep-hive`` instead.
            ignore_keys: State-dict key prefixes dropped when loading ``ckpt_path``.
            image_key: ``"fbank"`` (mel spectrogram + vocoder) or ``"stft"``
                (``Generator`` wave decoder).
            colorize_nlabels: If given, registers a random ``colorize`` buffer.
            monitor: Optional metric name stored as ``self.monitor``.
            base_learning_rate: Initial learning rate, stored as a float.
        """
        super().__init__()
        # Lightning will not step optimizers automatically; the training loop
        # (not defined in this class) is expected to do it manually.
        self.automatic_optimization = False
        assert ddconfig is not None and "mel_bins" in ddconfig.keys(), (
            "mel_bins is not specified in the Autoencoder config"
        )
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.subband = int(subband)
        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)
        # quant_conv expects 2 * z_channels inputs, i.e. the encoder must emit
        # both mean and log-variance channels.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, "cpu", num_mel)
        elif self.image_key == "stft":
            self.wave_decoder = Generator(input_channel=512)
            self.wave_decoder.train()
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)
        print("Initial learning rate %s" % self.learning_rate)
        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None
        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0
        self.logger_save_dir = None
        self.logger_exp_name = None
        if self.reload_from_ckpt is not None:
            ckpt_path = self.reload_from_ckpt
            if not os.path.exists(ckpt_path):
                print(
                    "--> Local autoencoder ckpt not found: %s, downloading vae.ckpt from Hugging Face..."
                    % ckpt_path
                )
                ckpt_path = hf_hub_download(
                    repo_id="ShandaAI/FlowSep-hive",
                    filename="vae.ckpt",
                )
            print("--> Reload weight of autoencoder from %s" % ckpt_path)
            # Load onto CPU first so GPU-saved checkpoints also load on CPU-only
            # hosts; load_state_dict copies into the module's own parameters.
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            self.load_state_dict(checkpoint["state_dict"])
            self.reloaded = True
        else:
            print("Train from scratch")

    def get_log_dir(self):
        """Return the logging directory for this run.

        Uses the explicit directory from ``set_log_dir`` when both parts are set.
        Otherwise it falls back to ``self.logger.save_dir`` joined with the
        logger's (private) ``_project`` attribute.
        """
        if self.logger_save_dir is None or self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        """Override the logging directory as ``save_dir/exp_name``."""
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load ``state_dict`` from a checkpoint non-strictly.

        Args:
            path: Checkpoint file containing a ``"state_dict"`` entry.
            ignore_keys: Key prefixes to drop before loading.
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # A key may match several prefixes; delete it exactly once.
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` posterior.

        The input is first split into frequency subbands. That step is a no-op
        unless image_key is "stft" and subband > 1.
        """
        # x = self.time_shuffle_operation(x)
        x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)  # 2 * embed_dim channels
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latent ``z`` and merge subbands back into a single channel."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        # bs, ch, shuffled_timesteps, fbins = dec.size()
        # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
        dec = self.freq_merge_subband(dec)
        return dec

    def decode_to_waveform(self, dec):
        """Turn a decoded spectrogram batch into waveforms.

        Axis 1 is squeezed and the last two axes are swapped before the data is
        passed to the vocoder ("fbank") or the wave decoder ("stft").

        Raises:
            ValueError: If ``image_key`` is neither "fbank" nor "stft".
        """
        from utilities.model import vocoder_infer

        if self.image_key == "fbank":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        else:
            raise ValueError("Unsupported image_key for waveform decoding: %s" % self.image_key)
        return wav_reconstruction

    def visualize_latent(self, input):
        """Debug helper: encode masked inputs and dump arrays/plots to the CWD.

        Two passes are made. In each pass a slice of ``input`` is overwritten
        with -11.59, and the masked input, the sampled latent and per-item plots
        of the channel-averaged latent are saved.
        """
        import matplotlib.pyplot as plt

        np.save("input.npy", input.cpu().detach().numpy())
        # Mask the first 32 entries of the last axis. NOTE(review): -11.59 looks
        # like the spectrogram's silence/floor value - confirm.
        time_input = input.clone()
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59
        np.save("time_input.npy", time_input.cpu().detach().numpy())
        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()
        # Mask the first 512 entries of axis 2.
        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59
        np.save("freq_input.npy", freq_input.cpu().detach().numpy())
        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def forward(self, input, sample_posterior=True):
        """Encode, take a posterior sample (or its mode), and decode.

        Returns:
            ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        if self.flag_first_run:
            print("Latent size: ", z.size())
            self.flag_first_run = False
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch):
        """Extract model inputs from a dataloader batch.

        Returns:
            A dict with ``"fbank"`` (from ``batch["log_mel_spec"]``), ``"stft"``
            and ``"waveform"`` (each unsqueezed at dim 1), plus ``"fname"``.
        """
        ret = {}
        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            batch["log_mel_spec"].unsqueeze(1),
            batch["stft"].unsqueeze(1),
            batch["fname"],
            batch["waveform"].unsqueeze(1),
        )
        return ret

    def freq_split_subband(self, fbank):
        """Reshape (bs, 1, T, F) into (bs, subband, T, F // subband).

        Returns ``fbank`` unchanged unless image_key is "stft" and subband > 1.
        """
        if self.subband == 1 or self.image_key != "stft":
            return fbank
        bs, ch, tstep, fbins = fbank.size()
        assert fbank.size(-1) % self.subband == 0
        assert ch == 1
        return (
            fbank.squeeze(1)
            .reshape(bs, tstep, self.subband, fbins // self.subband)
            .permute(0, 2, 1, 3)
        )

    def freq_merge_subband(self, subband_fbank):
        """Inverse of ``freq_split_subband``: (bs, subband, T, F') -> (bs, 1, T, subband * F')."""
        if self.subband == 1 or self.image_key != "stft":
            return subband_fbank
        assert subband_fbank.size(1) == self.subband  # Channel dimension
        bs, sub_ch, tstep, fbins = subband_fbank.size()
        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

    def save_wave(self, batch_wav, fname, save_dir):
        """Write each waveform to ``save_dir`` under the basename of its file name."""
        os.makedirs(save_dir, exist_ok=True)
        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)
            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        """Return the decoder's final conv weight (presumably used by the loss for adaptive weighting - confirm)."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        """Reconstruct ``batch`` and log images/audio for the first item.

        Returns:
            ``(wav_original, wav_prediction, wav_samples)`` from ``_log_img``.
            When ``only_inputs`` is True nothing is reconstructed, so audio cannot
            be logged and the ``{"inputs": x}`` dict is returned instead.
        """
        log = dict()
        x = batch.to(self.device)
        log["inputs"] = x
        if only_inputs:
            return log
        xrec, posterior = self(x)
        log["samples"] = self.decode(posterior.sample())
        log["reconstructions"] = xrec
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        """Log input/reconstruction/sample spectrograms and audio for one batch item.

        Raises:
            ValueError: If ``image_key`` is neither "fbank" nor "stft".
        """
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T
        name = "train" if train else "val"
        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )
        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )
        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            # The "stft" path needs the ground-truth waveform from the caller.
            wav_original = waveform[index, 0].cpu().detach().numpy()
        else:
            raise ValueError("Unsupported image_key for audio logging: %s" % self.image_key)
        if self.logger is not None:
            self.logger.experiment.log(
                {
                    "original_%s"
                    % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s"
                    % name: wandb.Audio(
                        wav_prediction, caption="reconstruct", sample_rate=self.sampling_rate
                    ),
                    "samples_%s"
                    % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )
        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        """Detach ``tensor``, move it to CPU and return it as a NumPy array."""
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels scaled to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
class VQModel(pl.LightningModule):
    """Vector-quantised autoencoder: Encoder -> VectorQuantizer -> Decoder.

    Optionally keeps an exponential moving average of the weights (``LitEma``)
    that can be swapped in temporarily via ``ema_scope``.
    """

    def __init__(
        self,
        ddconfig,
        lossconfig,
        n_embed,
        embed_dim,
        ckpt_path=None,
        ignore_keys=(),
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        batch_resize_range=None,
        scheduler_config=None,
        lr_g_factor=1.0,
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
        use_ema=False,
    ):
        """Build encoder, quantizer, decoder and loss; optionally load a checkpoint.

        Args:
            ddconfig: Keyword arguments for ``Encoder``/``Decoder``; must contain ``"z_channels"``.
            lossconfig: Config handed to ``instantiate_from_config`` to build ``self.loss``.
            n_embed: Codebook size.
            embed_dim: Codebook vector dimension.
            ckpt_path: Optional checkpoint loaded non-strictly.
            ignore_keys: State-dict key prefixes dropped when loading ``ckpt_path``.
            image_key: Batch key read by ``get_input``.
            colorize_nlabels: If given, registers a random ``colorize`` buffer.
            monitor: Optional metric name stored as ``self.monitor``.
            batch_resize_range: Optional ``(lower, upper)`` size range for per-batch resizing.
            scheduler_config: Stored as ``self.scheduler_config``.
            lr_g_factor: Stored as ``self.lr_g_factor``.
            remap, sane_index_shape: Forwarded to ``VectorQuantizer``.
            use_ema: Whether to keep an EMA copy of the weights.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(
            n_embed,
            embed_dim,
            beta=0.25,
            remap=remap,
            sane_index_shape=sane_index_shape,
        )
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(
                f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
            )
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Context manager that temporarily swaps in the EMA weights.

        The training weights are stored on entry and restored on exit, even if
        the body raises. This is a no-op when ``use_ema`` is False.

        Args:
            context: Optional label printed when switching weights.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load ``state_dict`` from a checkpoint non-strictly and report key mismatches.

        Args:
            path: Checkpoint file containing a ``"state_dict"`` entry.
            ignore_keys: Key prefixes to drop before loading.
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # A key may match several prefixes; delete it exactly once.
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after every training batch (when enabled)."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode and quantize ``x``; returns ``(quant, emb_loss, info)`` from the quantizer."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode ``x`` up to (but not including) quantization."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode quantized latents back to input space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode codebook indices via ``quantize.embed_code``."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Reconstruct ``input``.

        Returns:
            ``(dec, diff)``, or ``(dec, diff, indices)`` when ``return_pred_indices`` is set.
        """
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a contiguous float tensor with channels on axis 1.

        A 3-D tensor gets a trailing singleton axis, and axis 3 is then moved to
        axis 1. With ``batch_resize_range`` set, the tensor is bicubically resized.
        The first few steps use the upper size; later steps pick a random size
        in 16-step increments.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = (
            x.permute(0, 3, 1, 2)
            .to(memory_format=torch.contiguous_format)
            .float()
            .contiguous()
        )
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(
                    np.arange(lower_size, upper_size + 16, 16)
                )
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def get_last_layer(self):
        """Return the decoder's final conv weight (presumably used by the loss for adaptive weighting - confirm)."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Build a dict of inputs / reconstructions (and optionally EMA reconstructions).

        Inputs with more than 3 channels are projected to RGB with ``to_rgb``
        for display. The EMA pass is always fed the raw, un-projected input.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        colorize = x.shape[1] > 3
        if colorize:
            # colorize with random projection
            assert xrec.shape[1] > 3
            log["inputs"] = self.to_rgb(x)
            log["reconstructions"] = self.to_rgb(xrec)
        else:
            log["inputs"] = x
            log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if colorize:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels scaled to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x