from typing import Literal, Callable
import math
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import torchaudio
from alias_free_torch import Activation1d

from models.common import LoadPretrainedBase
from models.autoencoder.autoencoder_base import AutoEncoderBase
from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length


@torch.jit.script
def snake_beta(x, alpha, beta):
    # Snake-beta activation: x + (1 / beta) * sin^2(alpha * x); the small
    # epsilon guards against division by zero.
    return x + (1.0 / (beta + 1e-9)) * torch.pow(torch.sin(x * alpha), 2)


class SnakeBeta(nn.Module):
    def __init__(
        self,
        in_features,
        alpha=1.0,
        alpha_trainable=True,
        alpha_logscale=True
    ):
        super().__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # Log-scale parametrization: zeros here give
            # exp(alpha) = exp(beta) = 1 at initialization.
            self.alpha = nn.Parameter(torch.zeros(in_features))
            self.beta = nn.Parameter(torch.zeros(in_features))
        else:
            # Linear parametrization: initialized directly to `alpha`.
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        # Broadcast per-channel parameters over batch and time: [C] -> [1, C, 1].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x


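# Minimal shape-check sketch (illustration only, not used by the model):
# SnakeBeta acts elementwise on [B, C, T] tensors, with one learnable
# (alpha, beta) pair per channel, so the output shape matches the input.
def _demo_snake_beta_shapes():
    act = SnakeBeta(in_features=16)
    y = act(torch.randn(2, 16, 100))
    assert y.shape == (2, 16, 100)

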
def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


def get_activation(
    activation: Literal["elu", "snake", "none"],
    antialias=False,
    channels=None
) -> nn.Module:
    if activation == "elu":
        act = nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act


class ResidualUnit(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        dilation,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        self.dilation = dilation

        # "Same" padding for a kernel of size 7 at the given dilation.
        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
                padding=padding
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels
            ),
            WNConv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1
            )
        )

    def forward(self, x):
        res = x
        x = self.layers(x)
        return x + res


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=1,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=3,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=9,
                use_snake=use_snake
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2)
            ),
        )

    def forward(self, x):
        return self.layers(x)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False
    ):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding='same'
                )
            )
        else:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2)
            )

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels
            ),
            upsample_layer,
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=1,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=3,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=9,
                use_snake=use_snake
            ),
        )

    def forward(self, x):
        return self.layers(x)


class OobleckEncoder(nn.Module):
    def __init__(
        self,
        in_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=in_channels,
                out_channels=c_mults[0] * channels,
                kernel_size=7,
                padding=3
            )
        ]

        for i in range(self.depth - 1):
            layers += [
                EncoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i + 1] * channels,
                    stride=strides[i],
                    use_snake=use_snake
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[-1] * channels
            ),
            WNConv1d(
                in_channels=c_mults[-1] * channels,
                out_channels=latent_dim,
                kernel_size=3,
                padding=1
            )
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


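# Illustrative shape check (sketch only): with the default strides
# [2, 4, 8, 8] the encoder downsamples time by prod(strides) = 512,
# mapping [B, in_channels, T] audio to [B, latent_dim, T // 512] latents.
def _demo_oobleck_encoder_shapes():
    enc = OobleckEncoder(in_channels=2, channels=32, latent_dim=64)
    z = enc(torch.randn(1, 2, 2048))
    assert z.shape == (1, 64, 2048 // 512)

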
class OobleckDecoder(nn.Module):
    def __init__(
        self,
        out_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        final_tanh=True
    ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=c_mults[-1] * channels,
                kernel_size=7,
                padding=3
            ),
        ]

        for i in range(self.depth - 1, 0, -1):
            layers += [
                DecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[0] * channels
            ),
            WNConv1d(
                in_channels=c_mults[0] * channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                bias=False
            ),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class Bottleneck(nn.Module):
    def __init__(self, is_discrete: bool = False):
        super().__init__()

        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        raise NotImplementedError

    def decode(self, x):
        raise NotImplementedError


@torch.jit.script
def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
    # Reparameterization trick: softplus maps `scale` to a positive stdev,
    # then latents are drawn as mean + stdev * eps with eps ~ N(0, I).
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean

    # KL divergence of N(mean, var) from N(0, I), summed over channels and
    # averaged over the batch (without the conventional factor of 1/2).
    kl = (mean * mean + var - logvar - 1).sum(1).mean()
    return {"latents": latents, "kl": kl}


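# Usage sketch (illustration only): the encoder emits 2 * latent_dim
# channels that are split into mean/scale halves; drawing a sample also
# yields the scalar KL penalty used during training.
def _demo_vae_sample():
    mean, scale = torch.randn(2, 8, 10), torch.randn(2, 8, 10)
    out = vae_sample(mean, scale)
    assert out["latents"].shape == mean.shape and out["kl"].ndim == 0

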
class VAEBottleneck(Bottleneck):
    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(
        self,
        x,
        return_info=False,
        **kwargs
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]] | torch.Tensor:
        # Split the 2 * latent_dim encoder channels into mean and scale.
        mean, scale = x.chunk(2, dim=1)
        sampled = vae_sample(mean, scale)

        if return_info:
            return sampled["latents"], {"kl": sampled["kl"]}
        else:
            return sampled["latents"]

    def decode(self, x):
        return x


def compute_mean_kernel(x, y):
    kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
    return torch.exp(-kernel_input).mean()


class Pretransform(nn.Module):
    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        self.is_discrete = is_discrete
        self.io_channels = io_channels
        self.encoded_channels = None
        self.downsampling_ratio = None

        self.enable_grad = enable_grad

    def encode(self, x):
        raise NotImplementedError

    def decode(self, z):
        raise NotImplementedError

    def tokenize(self, x):
        raise NotImplementedError

    def decode_tokens(self, tokens):
        raise NotImplementedError


class StableVAE(LoadPretrainedBase, AutoEncoderBase):
    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck | None = None,
        pretransform: Pretransform | None = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False,
        pretrained_ckpt: str | Path | None = None
    ):
        LoadPretrainedBase.__init__(self)
        AutoEncoderBase.__init__(
            self,
            downsampling_ratio=downsampling_ratio,
            sample_rate=sample_rate,
            latent_shape=(latent_dim, None)
        )

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels
        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform
        self.soft_clip = soft_clip
        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

        self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
            "autoencoder."
        )
        if pretrained_ckpt is not None:
            self.load_pretrained(pretrained_ckpt)

    def process_state_dict(self, model_dict, state_dict):
        # Checkpoints store weights under "state_dict" with an
        # "autoencoder." prefix on every key; strip both before loading.
        state_dict = state_dict["state_dict"]
        state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
        return state_dict

    def encode(
        self,
        waveform: torch.Tensor,
        waveform_lengths: torch.Tensor,
        pad_latent_len: int = 500
    ) -> tuple[torch.Tensor, torch.Tensor]:
        z = self.encoder(waveform)
        z = self.bottleneck.encode(z)
        z_length = waveform_lengths // self.downsampling_ratio
        z_mask = create_mask_from_length(z_length, max_length=pad_latent_len)

        # Zero-pad latents along time to a fixed length so batches can be
        # stacked; sequences are assumed not to exceed `pad_latent_len`.
        B, C, L = z.shape
        if L < pad_latent_len:
            pad_size = pad_latent_len - L
            z = torch.cat(
                [z, torch.zeros(B, C, pad_size, device=z.device, dtype=z.dtype)],
                dim=-1
            )
        return z, z_mask

    def decode(
        self,
        latents: torch.Tensor,
        latent_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        latents: [B, C, T_latent]
        latent_mask: [B, T_latent], optional; 1 marks valid frames, 0 marks padding.
        """
        if latent_mask is not None:
            # Decode each sample on its valid (unpadded) frames only.
            # Concatenating along the batch dimension assumes all samples
            # in the batch share the same valid length.
            outputs = []
            for b in range(latents.size(0)):
                valid_idx = latent_mask[b].bool()
                valid_latents = latents[b, :, valid_idx]
                outputs.append(self.decoder(valid_latents.unsqueeze(0)))
            return torch.cat(outputs, dim=0)
        else:
            return self.decoder(latents)


class StableVAEProjectorWrapper(nn.Module):
    def __init__(
        self,
        vae_dim: int,
        embed_dim: int,
        model: StableVAE | None = None,
    ):
        super().__init__()
        self.model = model
        self.proj = nn.Linear(vae_dim, embed_dim)

    def forward(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        # Keep the VAE frozen in eval mode; only the projection is trained.
        self.model.eval()
        with torch.no_grad():
            z, z_mask = self.model.encode(waveform, waveform_lengths, pad_latent_len=500)
        z = self.proj(z.transpose(1, 2))
        return {"output": z, "mask": z_mask}


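# Usage sketch (illustration only; assumes an already-configured StableVAE
# instance `vae` whose latent_dim matches `vae_dim`):
#
#     wrapper = StableVAEProjectorWrapper(vae_dim=64, embed_dim=768, model=vae)
#     out = wrapper(waveform, waveform_lengths)
#     out["output"]  # [B, 500, embed_dim] projected latents
#     out["mask"]    # [B, 500] validity mask

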
if __name__ == '__main__':
    import hydra
    from utils.config import generate_config_from_command_line_overrides
    model_config = generate_config_from_command_line_overrides(
        "../../../configs"
    )
    autoencoder: StableVAE = hydra.utils.instantiate(model_config)
    autoencoder.eval()

    waveform, sr = torchaudio.load(
        "/edit/syn_7.wav"
    )
    waveform = waveform.mean(0, keepdim=True)
    waveform = torchaudio.functional.resample(
        waveform, sr, model_config["sample_rate"]
    )
    import soundfile as sf
    sf.write(
        "./torch_test.wav",
        waveform[0].numpy(),
        samplerate=model_config["sample_rate"]
    )
    print("waveform: ", waveform.shape)
    with torch.no_grad():
        latent, latent_mask = autoencoder.encode(
            waveform, torch.as_tensor([waveform.shape[-1]])
        )
        print("latent: ", latent.shape)
        print("latent_mask: ", latent_mask.shape)
        reconstructed = autoencoder.decode(latent, latent_mask)
        print("reconstructed: ", reconstructed.shape)

    sf.write(
        "./reconstructed.wav",
        reconstructed[0, 0].numpy(),
        samplerate=model_config["sample_rate"]
    )