import torch
from torch import Tensor, nn, no_grad

from .autoencoders import OobleckDecoder, OobleckEncoder
from .transformer import ContinuousTransformer

LRELU_SLOPE = 0.1
padding_mode = "zeros"
sample_eps = 1e-6


def vae_sample(mean, scale):
    """Reparameterized sample from N(mean, stdev^2) plus its KL term.

    `scale` is unconstrained; softplus maps it to a positive stdev.
    """
    stdev = nn.functional.softplus(scale)
    var = stdev * stdev + sample_eps  # epsilon keeps log(var) finite
    logvar = torch.log(var)
    # Reparameterization trick: z = mean + stdev * eps, eps ~ N(0, I)
    latents = torch.randn_like(mean) * stdev + mean
    # KL divergence to a standard normal prior (without the conventional 1/2
    # factor), summed over the latent dimension and averaged over the batch.
    kl = (mean * mean + var - logvar - 1).sum(1).mean()
    return latents, kl


class EAR_VAE(nn.Module):
    def __init__(self, model_config: dict = None):
        super().__init__()
        if model_config is None:
            model_config = {
                "encoder": {
                    "config": {
                        "in_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 4, 8],
                        # 2x the decoder latent_dim: the encoder emits mean
                        # and scale, split via chunk() in forward()/encode().
                        "latent_dim": 128,
                        "use_snake": True,
                    }
                },
                "decoder": {
                    "config": {
                        "out_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 4, 8],
                        "latent_dim": 64,
                        "use_nearest_upsample": False,
                        "use_snake": True,
                        "final_tanh": False,
                    },
                },
                "latent_dim": 64,
                "downsampling_ratio": 1024,  # product of the strides: 2*4*4*4*8
                "io_channels": 2,
            }

        # Optional latent-space transformer applied before decoding.
        if model_config.get("transformer") is not None:
            self.transformers = ContinuousTransformer(
                dim=model_config["decoder"]["config"]["latent_dim"],
                depth=model_config["transformer"]["depth"],
                **model_config["transformer"].get("config", {}),
            )
        else:
            self.transformers = None

        self.encoder = OobleckEncoder(**model_config["encoder"]["config"])
        self.decoder = OobleckDecoder(**model_config["decoder"]["config"])

    def forward(self, audio) -> tuple[Tensor, Tensor]:
        """
        audio: input audio tensor [B, C, T]
        Returns the reconstruction [B, C, T] and the KL term.
        """
        status = self.encoder(audio)
        mean, scale = status.chunk(2, dim=1)
        z, kl = vae_sample(mean, scale)
        if self.transformers is not None:
            # ContinuousTransformer expects [B, T, D]; latents are [B, D, T].
            z = z.permute(0, 2, 1)
            z = self.transformers(z)
            z = z.permute(0, 2, 1)
        x = self.decoder(z)
        return x, kl

    def encode(self, audio, use_sample=True):
        x = self.encoder(audio)
        mean, scale = x.chunk(2, dim=1)
        if use_sample:
            z, _ = vae_sample(mean, scale)
        else:
            # Deterministic encoding: use the posterior mean directly.
            z = mean
        return z

    def decode(self, z):
        if self.transformers is not None:
            z = z.permute(0, 2, 1)
            z = self.transformers(z)
            z = z.permute(0, 2, 1)
        x = self.decoder(z)
        return x

    @no_grad()
    def inference(self, audio):
        z = self.encode(audio)
        recon_audio = self.decode(z)
        return recon_audio
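

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module). Assumes the default
# config above; the batch size, sample rate, and clip length are arbitrary
# choices. Because this file uses relative imports, run it as a module from
# the package root (python -m <package>.<this_module>), not as a standalone
# script. With a downsampling_ratio of 1024, the input length should be a
# multiple of 1024 samples so the latents map back to the original length;
# the latent shapes below are expectations under the default config.
if __name__ == "__main__":
    model = EAR_VAE().eval()
    t = (4 * 44100 // 1024) * 1024  # ~4 s of audio, padded to a multiple of 1024
    audio = torch.randn(1, 2, t)  # [B, C, T] stereo input
    latents = model.encode(audio, use_sample=False)  # expected [B, 64, T // 1024]
    recon = model.inference(audio)  # expected [B, 2, T]
    print(latents.shape, recon.shape)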