| import torch |
| from audioldm.latent_diffusion.ema import * |
| from audioldm.variational_autoencoder.modules import Encoder, Decoder |
| from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution |
|
|
| from audioldm.hifigan.utilities import get_vocoder, vocoder_infer |
|
|
|
|
class AutoencoderKL(nn.Module):
    """Autoencoder whose encoder outputs a ``DiagonalGaussianDistribution`` over latents.

    ``encode`` maps an input tensor (``image_key`` defaults to ``"fbank"``) to a
    posterior over latents; ``decode`` maps latents back to the input space.
    ``decode_to_waveform`` runs decoded outputs through a vocoder created at
    construction time with ``get_vocoder(None, "cpu")``.

    Args:
        ddconfig: Keyword arguments passed to both ``Encoder`` and ``Decoder``;
            must contain ``"z_channels"``. Required.
        lossconfig: Accepted but unused by this class (presumably kept for
            config compatibility).
        image_key: Input type tag. Subband split/merge is only applied when it
            equals ``"stft"``.
        embed_dim: Number of latent channels. ``quant_conv`` emits
            ``2 * embed_dim`` channels (the posterior moments). Required.
        time_shuffle: Stored on the instance; not used in this class.
        subband: Number of frequency subbands (coerced to ``int``).
        ckpt_path, colorize_nlabels, base_learning_rate, ignore_keys: Accepted
            but unused by this class.
        reload_from_ckpt: Stored on the instance; not used in this class.
        monitor: Stored as ``self.monitor`` only when not ``None``.
        scale_factor: Multiplier applied by ``get_first_stage_encoding`` and
            divided out again by ``decode_first_stage``.

    Raises:
        TypeError: If ``ddconfig`` or ``embed_dim`` is ``None``.
    """

    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        image_key="fbank",
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=None,
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1,
    ):
        # Fail fast (same exception type the original eventually raised) before
        # building the comparatively expensive encoder/decoder/vocoder.
        if ddconfig is None:
            raise TypeError("AutoencoderKL requires a ddconfig mapping")
        if embed_dim is None:
            raise TypeError("AutoencoderKL requires embed_dim")

        super().__init__()

        # Read by freq_split_subband / freq_merge_subband; previously never
        # set, which made encode/decode fail whenever subband != 1.
        self.image_key = image_key

        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.subband = int(subband)
        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)

        # quant_conv produces the 2 * embed_dim moment channels consumed by
        # DiagonalGaussianDistribution (presumably mean and log-variance);
        # post_quant_conv maps embed_dim latents back to z_channels.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

        self.vocoder = get_vocoder(None, "cpu")
        self.embed_dim = embed_dim

        if monitor is not None:
            self.monitor = monitor

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None
        self.scale_factor = scale_factor

        # forward() prints the latent size once; previously never initialised,
        # so the first forward() call raised AttributeError.
        self.flag_first_run = True

    def encode(self, x):
        """Encode ``x`` into a posterior over latents.

        Applies ``freq_split_subband`` first, then the encoder and
        ``quant_conv``; the resulting moments parameterise a
        ``DiagonalGaussianDistribution``.
        """
        h = self.encoder(self.freq_split_subband(x))
        moments = self.quant_conv(h)
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Map latents ``z`` back to the input space (the inverse path of ``encode``)."""
        h = self.post_quant_conv(z)
        dec = self.decoder(h)
        return self.freq_merge_subband(dec)

    def decode_to_waveform(self, dec):
        """Vocode decoded outputs into waveforms with ``self.vocoder``.

        Squeezes axis 1 and swaps the last two axes before calling
        ``vocoder_infer``. NOTE(review): presumably ``dec`` is
        (batch, 1, time, freq) and the vocoder expects (batch, freq, time) —
        confirm against ``vocoder_infer``.
        """
        mel = dec.squeeze(1).permute(0, 2, 1)
        return vocoder_infer(mel, self.vocoder)

    def forward(self, input, sample_posterior=True):
        """Encode ``input``, take a latent from the posterior, and decode it.

        Args:
            input: Tensor accepted by ``encode``.
            sample_posterior: If true, draw ``posterior.sample()``; otherwise use
                ``posterior.mode()``.

        Returns:
            ``(dec, posterior)`` — the reconstruction and the encoder posterior.
        """
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()

        # One-time diagnostic of the latent shape.
        if self.flag_first_run:
            print("Latent size: ", z.size())
            self.flag_first_run = False

        dec = self.decode(z)
        return dec, posterior

    def freq_split_subband(self, fbank):
        """Split the last axis into ``self.subband`` groups stacked on axis 1.

        A no-op (returns ``fbank`` itself) when ``self.subband == 1`` or
        ``self.image_key != "stft"``. Otherwise the input is unpacked as
        (batch, channels, time, freq) with ``channels == 1`` and the result is
        (batch, subband, time, freq // subband).
        """
        if self.subband == 1 or self.image_key != "stft":
            return fbank

        bs, ch, tstep, fbins = fbank.size()

        assert fbank.size(-1) % self.subband == 0
        assert ch == 1

        return (
            fbank.squeeze(1)
            .reshape(bs, tstep, self.subband, fbins // self.subband)
            .permute(0, 2, 1, 3)
        )

    def freq_merge_subband(self, subband_fbank):
        """Inverse of ``freq_split_subband``: fold axis-1 subbands back into the last axis.

        (batch, subband, time, f) -> (batch, 1, time, subband * f). A no-op
        under the same conditions as ``freq_split_subband``.
        """
        if self.subband == 1 or self.image_key != "stft":
            return subband_fbank
        assert subband_fbank.size(1) == self.subband  # channel axis holds the subbands
        bs, _, tstep, _ = subband_fbank.size()
        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

    def device(self):
        """Return the device of the first registered parameter.

        Raises ``StopIteration`` if the module has no parameters.
        """
        return next(self.parameters()).device

    @torch.no_grad()
    def encode_first_stage(self, x):
        """Gradient-free ``encode``; returns the posterior, not a scaled latent."""
        return self.encode(x)

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Undo ``scale_factor`` and decode latents without tracking gradients.

        Args:
            z: Latent tensor, e.g. as produced by ``get_first_stage_encoding``.
            predict_cids: Codebook-index decoding. Unsupported: this class
                defines no quantizer/codebook, so passing true raises.
            force_not_quantize: Accepted for interface compatibility; unused.

        Raises:
            NotImplementedError: If ``predict_cids`` is true.
        """
        if predict_cids:
            # The former branch referenced self.first_stage_model and
            # `rearrange`, neither of which exists here, so it could only fail
            # with AttributeError/NameError. Fail explicitly instead.
            raise NotImplementedError(
                "predict_cids=True requires a vector-quantized first stage; "
                "AutoencoderKL has no codebook"
            )

        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn an encoder output into a latent scaled by ``scale_factor``.

        A ``DiagonalGaussianDistribution`` is sampled (not its mode); a plain
        tensor is used as-is.

        Raises:
            NotImplementedError: For any other input type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z