import math
from typing import List, Optional, Union
import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
from einops import rearrange


def WNConv1d(*args, **kwargs):
    # nn.Conv1d wrapped with weight normalization
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    # nn.ConvTranspose1d wrapped with weight normalization
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


@torch.jit.script
def snake(x, alpha):
    # Snake activation (Ziyin et al., 2020): x + (1/alpha) * sin^2(alpha * x).
    # The 1e-9 guards against division by zero for very small alpha.
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Learnable per-channel frequency parameter of the snake activation
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
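

# A quick sketch: Snake1d applies `snake` with a learnable per-channel alpha
# and preserves the input shape (values below are hypothetical).
#
#   act = Snake1d(channels=64)
#   y = act(torch.randn(1, 64, 16000))         # -> (1, 64, 16000)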


class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses the following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
    1. Factorized codes: Perform nearest-neighbor lookup in a low-dimensional space
        for improved codebook usage
    2. l2-normalized codes: Converts euclidean distance to cosine similarity, which
        improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantize the input tensor using a fixed codebook and return
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[B]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN): project input into low-dimensional space
        z_e = self.in_proj(z)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        # No-op in the forward pass, straight-through gradient estimator
        # in the backward pass
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance between encodings and codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices
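
    # A minimal usage sketch (shapes are hypothetical): quantize a batch of
    # latents and combine the VQ losses, as in VQ-VAE training.
    #
    #   vq = VectorQuantize(input_dim=512, codebook_size=1024, codebook_dim=8)
    #   z = torch.randn(4, 512, 100)             # B x D x T
    #   z_q, commit, cb, indices, z_e = vq(z)
    #   loss = cb.mean() + 0.25 * commit.mean()  # beta-weighted commitment term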


class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An End-to-End Neural Audio Codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: Optional[int] = None):
        """Quantize the input tensor using a fixed set of `n` codebooks and return
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            Number of quantizers to use
            (n_quantizers < self.n_codebooks, e.g. for quantizer dropout).
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.

        Returns
        -------
        z_q : Tensor[B x D x T]
            Quantized continuous representation of input
        codes : Tensor[B x N x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        latents : Tensor[B x N*D x T]
            Projected latents (continuous representation of input before quantization)
        commitment_loss : Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        codebook_loss : Tensor[1]
            Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Quantizer dropout (SoundStream): each of the first `n_dropout`
            # batch items is quantized with a random number of codebooks;
            # the rest use all of them.
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses, counting only items where this quantizer is active
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation.

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input

        Returns
        -------
        z_q : Tensor[B x D x T]
            Quantized continuous representation of input
        z_p : Tensor[B x N*D x T]
            Per-codebook latents, concatenated along the channel dimension
        codes : Tensor[B x N x T]
            The input codes, returned unchanged
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N*D x T]
            Continuous representation of input after projection

        Returns
        -------
        z_q : Tensor[B x D x T]
            Quantized representation of the full projected space
        z_p : Tensor[B x N*D x T]
            Quantized representation of the latent space
        codes : Tensor[B x N x T]
            Codebook indices for each codebook
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Infer how many codebooks the given latents span from the
        # cumulative codebook dimensions
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
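
    # A minimal usage sketch (hypothetical shapes): quantize, then reconstruct
    # the continuous representation from the discrete codes alone.
    #
    #   rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9)
    #   z = torch.randn(2, 512, 50)              # B x D x T
    #   z_q, codes, latents, commit, cb = rvq(z) # codes: B x 9 x 50
    #   z_rec, _, _ = rvq.from_codes(codes)      # decode from indices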


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(AbstractDistribution):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        # `parameters` packs mean and log-variance along the channel dim
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        # Reparameterization trick: mean + std * eps
        x = self.mean + self.std * torch.randn_like(self.mean)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.tensor([0.0], device=self.parameters.device)
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2],
                )

    def nll(self, sample, dims=[1, 2]):
        if self.deterministic:
            return torch.tensor([0.0], device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
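

# A minimal usage sketch (hypothetical shapes): split 2*C channels into a
# diagonal Gaussian posterior, sample via the reparameterization trick, and
# compute its KL against a standard normal prior.
#
#   params = torch.randn(2, 256, 50)           # B x 2C x T (mean and logvar)
#   post = DiagonalGaussianDistribution(params)
#   z = post.sample()                          # -> (2, 128, 50)
#   kl = post.kl()                             # -> (2,), mean over C and T dims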


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two Gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force log-variances to be tensors: broadcasting promotes scalars in the
    # arithmetic below, but torch.exp() needs tensor inputs.
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
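

# Sanity checks (sketch): the KL of a Gaussian against itself is 0, and
# KL(N(1, 1) || N(0, 1)) = 0.5 per element.
#
#   m = torch.zeros(3)
#   normal_kl(m, m, m, m)                      # -> tensor([0., 0., 0.])
#   normal_kl(torch.ones(3), m, m, m)          # -> tensor([0.5, 0.5, 0.5])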


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # Center-crop the skip connection if the block shortened the signal
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap the blocks into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
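

# A shape sketch (hypothetical values): with strides [2, 4, 8, 8] the encoder
# downsamples by prod(strides) = 512, so 1 s of 44.1 kHz mono audio maps to
# roughly 44100 / 512 ≈ 86 latent frames.
#
#   enc = Encoder(d_model=64, strides=[2, 4, 8, 8], d_latent=64)
#   z = enc(torch.randn(1, 1, 44100))          # -> (1, 64, ~86)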


class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + residual blocks, halving channels at each stage
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
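

# The decoder mirrors the encoder: it upsamples by prod(rates), and the final
# Tanh keeps the waveform in [-1, 1]. A sketch with hypothetical values:
#
#   dec = Decoder(input_channel=64, channels=1536, rates=[8, 8, 4, 2])
#   audio = dec(torch.randn(1, 64, 86))        # -> (1, 1, ~44000), in [-1, 1]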


class DacVAE(nn.Module):
    """Convolutional audio autoencoder with either a residual-VQ (discrete)
    or a VAE (continuous) bottleneck."""

    def __init__(
        self,
        encoder_dim: int = 128,
        encoder_rates: List[int] = [2, 3, 4, 5, 8],
        latent_dim: int = 128,
        decoder_dim: int = 2048,
        decoder_rates: List[int] = [8, 5, 4, 3, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
        sample_rate: int = 48000,
        continuous: bool = True,
        use_weight_norm: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        self.continuous = continuous
        self.use_weight_norm = use_weight_norm

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        if not continuous:
            # Discrete path: residual vector quantization (DAC-style)
            self.n_codebooks = n_codebooks
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.quantizer = ResidualVectorQuantize(
                input_dim=latent_dim,
                n_codebooks=n_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            # Continuous path: VAE bottleneck; quant_conv predicts mean and
            # log-variance, hence the doubled output channels
            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        self.apply(init_weights)

        self.delay = self.get_delay()

        if not self.use_weight_norm:
            self.remove_weight_norm()

    def get_delay(self):
        # Any number works here; the delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        # Walk the layers backwards, inverting each length formula to find
        # the input length that produces `l_out`
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        L = input_length
        # Calculate output length, assuming no padding ("valid" convolutions)
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @property
    def dtype(self):
        """Get the dtype of the model parameters."""
        for param in self.parameters():
            return param.dtype
        return torch.float32

    @property
    def device(self):
        """Get the device of the model parameters."""
        for param in self.parameters():
            return param.device
        return torch.device("cpu")

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        # Right-pad so the length is a whole multiple of the hop length
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: Optional[int] = None,
    ):
        """Encode the given audio data and return the latent representation.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used. Ignored in continuous mode.

        Returns
        -------
        z : Tensor[B x D x T] or DiagonalGaussianDistribution
            Quantized continuous representation of input (discrete mode), or
            the posterior distribution over latents (continuous mode)
        codes : Tensor[B x N x T] or None
            Codebook indices for each codebook
            (quantized discrete representation of input)
        latents : Tensor[B x N*D x T] or None
            Projected latents (continuous representation of input before quantization)
        commitment_loss : Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries (0 in continuous mode)
        codebook_loss : Tensor[1]
            Codebook loss to update the codebook (0 in continuous mode)
        """
        z = self.encoder(audio_data)
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
        else:
            z = self.quant_conv(z)
            z = DiagonalGaussianDistribution(z)
            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0

        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode the given latents and return audio data.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized (discrete mode) or sampled (continuous mode) latents

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data
        """
        if not self.continuous:
            audio = self.decoder(z)
        else:
            z = self.post_quant_conv(z)
            audio = self.decoder(z)

        return audio

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: Optional[int] = None,
        n_quantizers: Optional[int] = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            In discrete mode, a dictionary with the following keys:

            "audio" : Tensor[B x 1 x length]
                Decoded audio data
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook

            In continuous mode, a dictionary with the keys "audio", "z",
            and "kl_loss" (the mean KL divergence of the posterior).
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)

            x = self.decode(z)
            return {
                "audio": x[..., :length],
                "z": z,
                "codes": codes,
                "latents": latents,
                "vq/commitment_loss": commitment_loss,
                "vq/codebook_loss": codebook_loss,
            }
        else:
            # Sample from the posterior and penalize its divergence from the prior
            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
            z = posterior.sample()
            x = self.decode(z)

            kl_loss = posterior.kl()
            kl_loss = kl_loss.mean()

            return {
                "audio": x[..., :length],
                "z": z,
                "kl_loss": kl_loss,
            }

    def remove_weight_norm(self):
        """
        Remove weight_norm from all modules in the model.
        This fuses the weight_g and weight_v parameters into a single weight parameter.
        Should be called before inference for better performance.
        Returns:
            self: The model with weight_norm removed
        """
        from torch.nn.utils import remove_weight_norm

        num_removed = 0
        for name, module in list(self.named_modules()):
            if hasattr(module, "_forward_pre_hooks"):
                for hook_id, hook in list(module._forward_pre_hooks.items()):
                    if "WeightNorm" in str(type(hook)):
                        try:
                            remove_weight_norm(module)
                            num_removed += 1
                        except ValueError as e:
                            print(f"Failed to remove weight_norm from {name}: {e}")
        if num_removed > 0:
            self.use_weight_norm = False
        else:
            print("No weight_norm found in the model")
        return self
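

# A minimal end-to-end sketch (hypothetical shapes; `wav` stands in for a mono
# waveform at the model's sample rate):
#
#   model = DacVAE(continuous=True, sample_rate=48000).eval()
#   wav = torch.randn(1, 1, 48000)             # B x 1 x T
#   with torch.no_grad():
#       out = model(wav)
#   out["audio"]                               # (1, 1, 48000), same length as input
#   out["kl_loss"]                             # scalar KL term for VAE training
#
# In discrete mode (`continuous=False`), the output dict instead carries
# "codes" of shape B x n_codebooks x T', which can be turned back into latents
# with `model.quantizer.from_codes(codes)` and decoded via `model.decode(z_q)`.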