| import os |
| import soundfile as sf |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from FlashSR.AudioSR.EncoderDecoder import Encoder, Decoder |
| from FlashSR.AudioSR.Vocoder import get_vocoder |
|
|
|
|
class AutoencoderKL(nn.Module):
    """KL-regularized variational autoencoder over spectrogram "images".

    Encodes a (B, 1, T, F) spectrogram batch (mel "fbank" or "stft",
    selected by ``image_key``) into a diagonal-Gaussian posterior via
    ``Encoder`` + ``quant_conv``, and decodes latent samples back to
    spectrograms via ``post_quant_conv`` + ``Decoder``.  For the "fbank"
    key a vocoder turns decoded mels back into waveforms.

    NOTE(review): this class appears ported from a pytorch_lightning
    module — several names used below are not defined or imported in this
    file (``self.logger``, ``self.device``, ``self.wave_decoder``,
    ``wandb``, ``synth_one_sample``, and a ``forward`` invoked as
    ``self(x)``).  Verify those paths before using the logging /
    visualization methods.
    """

    def __init__(
        self,
        ddconfig=None,  # Encoder/Decoder config dict; must contain "mel_bins" and truthy "double_z"
        lossconfig=None,  # unused here (self.loss is set to None below)
        batchsize=None,  # unused here; kept for config compatibility
        embed_dim=None,  # channel width of the latent z
        time_shuffle=1,  # stored only; not used in this file
        subband=1,  # >1 enables subband decomposition (only printed here)
        sampling_rate=16000,  # audio sample rate used when writing wavs
        ckpt_path=None,  # optional checkpoint to restore weights from at init
        reload_from_ckpt=None,  # stored only; not used in this file
        ignore_keys=[],  # state_dict key prefixes to drop when loading ckpt_path
        image_key="fbank",  # "fbank" (mel) or "stft"; selects vocoder path
        colorize_nlabels=None,  # if int, registers a random colorization buffer
        monitor=None,  # metric name to monitor (Lightning vestige)
        base_learning_rate=1e-5,  # stored as self.learning_rate
    ):
        super().__init__()
        # Lightning vestige: disables automatic optimization in a
        # pl.LightningModule; harmless as a plain attribute on nn.Module.
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.loss = None
        self.subband = int(subband)

        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)

        # double_z: the encoder emits 2*z_channels maps (mean and logvar
        # stacked on dim 1), which quant_conv projects to 2*embed_dim.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

        # Only the mel path needs a vocoder to get back to waveforms.
        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, "cpu", num_mel)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None

        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0

        # Set via set_log_dir(); get_log_dir() falls back to self.logger
        # (not defined in this file) when these are still None.
        self.logger_save_dir = None
        self.logger_exp_name = None

    def get_log_dir(self):
        """Return the directory used for logging artifacts.

        NOTE(review): the fallback branch reads ``self.logger.save_dir`` /
        ``self.logger._project`` — ``self.logger`` is never assigned in
        this file (presumably a Lightning/wandb logger); confirm it exists
        before calling this without set_log_dir().
        """
        if self.logger_save_dir is None and self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        else:
            return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        """Override the log directory as save_dir/exp_name."""
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load weights from a checkpoint's "state_dict", non-strictly.

        Keys starting with any prefix in ``ignore_keys`` are deleted
        before loading; missing/unexpected keys are tolerated
        (strict=False).
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode x into a DiagonalGaussianDistribution posterior.

        quant_conv outputs mean and logvar packed along channel dim 1.
        """
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent z back to a spectrogram."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def decode_to_waveform(self, dec):
        """Convert a decoded spectrogram batch to waveforms.

        "fbank": runs the Griffin-free vocoder via ``vocoder_infer``.
        "stft": NOTE(review) uses ``self.wave_decoder``, which is never
        assigned in this file — confirm it is attached externally.
        Also note ``wav_reconstruction`` is unbound for any other
        image_key (would raise UnboundLocalError at return).
        """
        from audiosr.utilities.model import vocoder_infer

        if self.image_key == "fbank":
            # (B, 1, T, F) -> (B, F, T) expected by the vocoder.
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        return wav_reconstruction

    def visualize_latent(self, input):
        """Debug utility: dump inputs/latents as .npy and latent heatmaps as .png.

        Runs two masking ablations on copies of ``input`` (last dim, then
        dim 2), encodes each, and saves the channel-averaged latent per
        batch element.  Writes files into the current working directory.

        NOTE(review): -11.59 is presumably the log-mel "silence" floor
        value — confirm.  The freq_/time_ file-name pairing looks swapped
        relative to the axis being masked; verify intent.
        """
        import matplotlib.pyplot as plt

        np.save("input.npy", input.cpu().detach().numpy())

        # Ablation 1: blank the first 32 entries along the last axis.
        time_input = input.clone()
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59

        np.save("time_input.npy", time_input.cpu().detach().numpy())

        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()

        # Ablation 2: blank the first 512 entries along axis 2.
        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59

        np.save("freq_input.npy", freq_input.cpu().detach().numpy())

        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def get_input(self, batch):
        """Unpack a dataloader batch dict and add a channel dim.

        Expects keys "fname", "text", "label_vector", "waveform", "stft",
        "log_mel_spec".  Returns a dict with "fbank"/"stft"/"waveform"
        unsqueezed to (B, 1, ...) plus the passthrough "fname".
        ("text" and "label_vector" are read but not returned.)
        """
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )

        ret = {}

        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )

        return ret

    def save_wave(self, batch_wav, fname, save_dir):
        """Write each waveform in batch_wav to save_dir, named by basename(fname).

        Uses self.sampling_rate; creates save_dir if needed.
        """
        os.makedirs(save_dir, exist_ok=True)

        for wav, name in zip(batch_wav, fname):
            # Keep only the file name from any path the dataset provided.
            name = os.path.basename(name)

            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        """Return the decoder's final conv weight (used e.g. for adaptive loss weighting)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        """Build an inputs/reconstructions/samples image dict and log it.

        NOTE(review): relies on ``self.device`` and ``self(x)`` (a
        ``forward`` not defined in this file) — presumably inherited
        behavior; confirm.  Returns the waveforms from _log_img.
        """
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec

        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        """Log spectrogram images and audio for one batch element.

        NOTE(review): ``wandb`` and ``synth_one_sample`` are not imported
        in this file, and ``self.logger`` is never assigned here — this
        method only works when those are provided externally.  For an
        image_key other than "fbank"/"stft" the wav_* names are unbound
        at return.
        """
        # Transpose to (freq, time) for display.
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T

        if train:
            name = "train"
        else:
            name = "val"

        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )

        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )

        if self.image_key == "fbank":
            # Vocode original vs reconstruction, then original vs sample;
            # keep only the first item of each returned batch.
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()

        if self.logger is not None:
            self.logger.experiment.log(
                {
                    "original_%s"
                    % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s"
                    % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s"
                    % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )

        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        """Detach a tensor and return it as a CPU numpy array."""
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        """Project a segmentation map to 3-channel RGB via a random 1x1 conv,
        then rescale to [-1, 1].  Lazily registers the "colorize" buffer."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
|
|
|
|
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: encode/decode/forward return the input unchanged.

    Drop-in stand-in for an autoencoder first stage when the model should
    operate directly on the input space.  With ``vq_interface=True``,
    ``quantize`` mimics a VQ layer's ``(z, loss, info)`` return signature
    so callers written against a VQ first stage keep working.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Call super().__init__() BEFORE assigning attributes: the original
        # assigned first, which relies on nn.Module.__setattr__ tolerating
        # an uninitialized module — fragile across torch versions.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity encode: return x unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity decode: return x unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return x; with vq_interface, mimic a VQ layer's (z, loss, info) tuple."""
        if self.vq_interface:
            # (quantized, emb_loss, [perplexity, min_encodings, indices])
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity forward pass."""
        return x
|
|
| class DiagonalGaussianDistribution(object): |
| def __init__(self, parameters, deterministic=False): |
| self.parameters = parameters |
| self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) |
| self.logvar = torch.clamp(self.logvar, -30.0, 20.0) |
| self.deterministic = deterministic |
| self.std = torch.exp(0.5 * self.logvar) |
| self.var = torch.exp(self.logvar) |
| if self.deterministic: |
| self.var = self.std = torch.zeros_like(self.mean).to( |
| device=self.parameters.device |
| ) |
|
|
| def sample(self): |
| x = self.mean + self.std * torch.randn(self.mean.shape).to( |
| device=self.parameters.device |
| ) |
| return x |
|
|
| def kl(self, other=None): |
| if self.deterministic: |
| return torch.Tensor([0.0]) |
| else: |
| if other is None: |
| return 0.5 * torch.mean( |
| torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, |
| dim=[1, 2, 3], |
| ) |
| else: |
| return 0.5 * torch.mean( |
| torch.pow(self.mean - other.mean, 2) / other.var |
| + self.var / other.var |
| - 1.0 |
| - self.logvar |
| + other.logvar, |
| dim=[1, 2, 3], |
| ) |
|
|
| def nll(self, sample, dims=[1, 2, 3]): |
| if self.deterministic: |
| return torch.Tensor([0.0]) |
| logtwopi = np.log(2.0 * np.pi) |
| return 0.5 * torch.sum( |
| logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, |
| dim=dims, |
| ) |
|
|
| def mode(self): |
| return self.mean |