# Hugging Face repo: EAR_VAE, file model/ear_vae.py
# Commit b3c4dc3 (verified) by earlab: "Upload folder using huggingface_hub"
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, nn, no_grad

from .autoencoders import OobleckDecoder, OobleckEncoder
from .transformer import ContinuousTransformer
# Presumably the negative slope for leaky-ReLU activations; it is not
# referenced in the visible part of this file (TODO confirm whether it is
# still needed).
LRELU_SLOPE = 0.1
# Presumably a convolution padding-mode name; likewise not referenced in the
# visible part of this file (TODO confirm usage).
padding_mode = "zeros"
# Small constant added to the variance inside vae_sample() before taking its
# logarithm.
sample_eps = 1e-6
def vae_sample(mean, scale):
    """Draw a reparameterised latent sample and compute its KL penalty.

    The raw ``scale`` tensor is mapped to a positive standard deviation with
    softplus. The sample is ``mean + noise * stdev`` with ``noise`` drawn from
    a standard normal of the same shape as ``mean``.

    The penalty is ``mean**2 + var - log(var) - 1`` summed over dimension 1
    and averaged over the remaining dimensions, where
    ``var = stdev**2 + sample_eps``. Note that the conventional 1/2 factor of
    the Gaussian KL divergence is not applied, and that ``sample_eps`` enters
    the penalty only (the sample itself uses the un-offset ``stdev``).

    Args:
        mean: Posterior mean tensor.
        scale: Unconstrained scale tensor, same shape as ``mean``.

    Returns:
        A ``(latents, kl)`` tuple: the sampled latents (same shape as
        ``mean``) and a scalar penalty tensor.
    """
    stdev = F.softplus(scale)

    # Exactly one random draw per call, taken before the penalty is computed.
    noise = torch.randn_like(mean)
    latents = noise * stdev + mean

    # The epsilon keeps the logarithm finite when softplus underflows to zero.
    var = stdev * stdev + sample_eps
    logvar = var.log()
    penalty = mean * mean + var - logvar - 1
    kl = penalty.sum(dim=1).mean()

    return latents, kl
class EAR_VAE(nn.Module):
    """Variational autoencoder built from an Oobleck encoder/decoder pair.

    The encoder output is split in two halves along dimension 1 (mean and
    unconstrained scale), a latent is sampled with ``vae_sample`` and, if a
    transformer is configured, the latent is passed through it before being
    decoded.

    Because of that split, the encoder's ``latent_dim`` has to be twice the
    decoder's ``latent_dim`` (128 vs. 64 in the default configuration).

    Attributes:
        transformers: Optional ``ContinuousTransformer`` applied to the latent
            before decoding, or ``None`` when the config has no
            ``"transformer"`` entry.
        encoder: ``OobleckEncoder`` built from ``model_config["encoder"]["config"]``.
        decoder: ``OobleckDecoder`` built from ``model_config["decoder"]["config"]``.
    """

    def __init__(self, model_config: Optional[dict] = None):
        """Build the sub-modules described by ``model_config``.

        Args:
            model_config: Dict with ``"encoder"`` and ``"decoder"`` entries
                (each holding a ``"config"`` dict of constructor kwargs) and
                an optional ``"transformer"`` entry providing ``"depth"`` and
                an optional ``"config"`` dict of extra kwargs. When ``None``,
                the built-in default configuration is used.
        """
        super().__init__()
        if model_config is None:
            model_config = self._default_config()

        transformer_cfg = model_config.get("transformer")
        if transformer_cfg is not None:
            # The transformer works on the sampled latent, so its width is
            # the decoder's latent dimension.
            self.transformers = ContinuousTransformer(
                dim=model_config["decoder"]["config"]["latent_dim"],
                depth=transformer_cfg["depth"],
                **transformer_cfg.get("config", {}),
            )
        else:
            self.transformers = None

        self.encoder = OobleckEncoder(**model_config["encoder"]["config"])
        self.decoder = OobleckDecoder(**model_config["decoder"]["config"])

    @staticmethod
    def _default_config() -> dict:
        """Return a fresh copy of the default model configuration."""
        return {
            "encoder": {
                "config": {
                    "in_channels": 2,
                    "channels": 128,
                    "c_mults": [1, 2, 4, 8, 16],
                    "strides": [2, 4, 4, 4, 8],
                    "latent_dim": 128,
                    "use_snake": True,
                },
            },
            "decoder": {
                "config": {
                    "out_channels": 2,
                    "channels": 128,
                    "c_mults": [1, 2, 4, 8, 16],
                    "strides": [2, 4, 4, 4, 8],
                    "latent_dim": 64,
                    "use_nearest_upsample": False,
                    "use_snake": True,
                    "final_tanh": False,
                },
            },
            "latent_dim": 64,
            "downsampling_ratio": 1024,
            "io_channels": 2,
        }

    def forward(self, audio: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode, sample and decode ``audio``.

        Args:
            audio: Input audio tensor [B,C,T].

        Returns:
            A ``(reconstruction, kl)`` tuple: the decoder output and the
            penalty term returned by ``vae_sample``.
        """
        mean, scale = self.encoder(audio).chunk(2, dim=1)
        z, kl = vae_sample(mean, scale)
        # Same transformer + decoder path as decode(); reuse it rather than
        # keeping two copies in sync.
        return self.decode(z), kl

    def encode(self, audio: Tensor, use_sample: bool = True) -> Tensor:
        """Map ``audio`` to a latent.

        Args:
            audio: Input audio tensor passed to the encoder.
            use_sample: When true, draw a random sample with ``vae_sample``;
                otherwise return the posterior mean (deterministic).

        Returns:
            The latent tensor (first half of the encoder output along
            dimension 1, or a sample centred on it).
        """
        mean, scale = self.encoder(audio).chunk(2, dim=1)
        if not use_sample:
            return mean
        z, _ = vae_sample(mean, scale)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """Decode a latent, applying the transformer first when configured.

        Args:
            z: Latent tensor as produced by ``encode``.

        Returns:
            The decoder output.
        """
        if self.transformers is not None:
            # The last two axes are swapped for the transformer and swapped
            # back afterwards (presumably it expects a channels-last layout).
            z = self.transformers(z.permute(0, 2, 1)).permute(0, 2, 1)
        return self.decoder(z)

    @no_grad()
    def inference(self, audio: Tensor) -> Tensor:
        """Reconstruct ``audio`` without tracking gradients.

        ``encode`` is called with its default ``use_sample=True``, so the
        result is stochastic.

        Args:
            audio: Input audio tensor passed to the encoder.

        Returns:
            The reconstructed audio from the decoder.
        """
        return self.decode(self.encode(audio))