Spaces:
Sleeping
Sleeping
| import math | |
| from typing import List, Union | |
| import numpy as np | |
| import torch | |
| from einops import rearrange | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from torch.nn.utils import weight_norm | |
| from audiotools import AudioSignal, STFTParams, ml | |
| from audiotools.ml import BaseModel | |
| from base import CodecMixin | |
| from layers import WNConv1d, WNConvTranspose1d, get_activation | |
| def init_weights(m, mean=0.0, std=0.02, init_type="xavier", gain=0.02): | |
| """ | |
| Initialize weights of the entire model using xavier_normal_ or kaiming_normal_. | |
| Args: | |
| m (nn.Module): The module to initialize. | |
| mean (float): Mean for weight initialization. | |
| std (float): Standard deviation for weight initialization. | |
| init_type (str): Type of initialization ('xavier' or 'kaiming'). | |
| gain (float): Gain for xavier initialization. | |
| """ | |
| classname = m.__class__.__name__ | |
| if init_type == "xavier": | |
| # Handle convolutional layers | |
| if "Depthwise_Separable" in classname: | |
| nn.init.xavier_normal_(m.depth_conv.weight.data, gain=gain) | |
| nn.init.xavier_normal_(m.point_conv.weight.data, gain=gain) | |
| if hasattr(m.depth_conv, "bias") and m.depth_conv.bias is not None: | |
| nn.init.zeros_(m.depth_conv.bias.data) | |
| if hasattr(m.point_conv, "bias") and m.point_conv.bias is not None: | |
| nn.init.zeros_(m.point_conv.bias.data) | |
| elif classname.find("Conv") != -1: | |
| nn.init.xavier_normal_(m.weight.data, gain=gain) | |
| if hasattr(m, "bias") and m.bias is not None: | |
| nn.init.zeros_(m.bias.data) | |
| # Handle batch normalization layers | |
| elif classname.find("BatchNorm") != -1: | |
| if hasattr(m, "weight") and m.weight is not None: | |
| nn.init.xavier_normal_(m.weight.data, gain=gain) | |
| if hasattr(m, "bias") and m.bias is not None: | |
| nn.init.zeros_(m.bias.data) | |
| # Handle custom layers like Snake1d and SnakeBeta | |
| elif classname == "Snake1d": | |
| if hasattr(m, "alpha") and m.alpha is not None: | |
| if m.alpha.data.dim() >= 2: | |
| nn.init.xavier_normal_(m.alpha.data, gain=gain) | |
| else: | |
| nn.init.normal_(m.alpha.data, mean=1.0, std=std) | |
| elif classname == "SnakeBeta": | |
| # Respect the alpha_logscale setting in SnakeBeta | |
| if hasattr(m, "alpha") and m.alpha is not None: | |
| if m.alpha_logscale: | |
| nn.init.constant_(m.alpha.data, 0.0) # Matches SnakeBeta's default | |
| else: | |
| nn.init.constant_(m.alpha.data, 1.0) | |
| if hasattr(m, "beta") and m.beta is not None: | |
| if m.alpha_logscale: | |
| nn.init.constant_(m.beta.data, 0.0) # Matches SnakeBeta's default | |
| else: | |
| nn.init.constant_(m.beta.data, 1.0) | |
| # Handle residual scaling parameters | |
| elif hasattr(m, "residual_scale") and m.residual_scale is not None: | |
| nn.init.xavier_normal_(m.residual_scale.data, gain=gain) | |
| else: | |
| # Kaiming initialization | |
| if "Depthwise_Separable" in classname: | |
| nn.init.kaiming_normal_( | |
| m.depth_conv.weight.data, mode="fan_out", nonlinearity="relu" | |
| ) | |
| nn.init.kaiming_normal_( | |
| m.point_conv.weight.data, mode="fan_out", nonlinearity="relu" | |
| ) | |
| elif classname.find("Conv") != -1: | |
| nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu") | |
| if hasattr(m, "bias") and m.bias is not None: | |
| nn.init.zeros_(m.bias.data) | |
| elif classname.find("BatchNorm") != -1: | |
| if hasattr(m, "weight") and m.weight is not None: | |
| nn.init.normal_(m.weight.data, 1.0, std) | |
| if hasattr(m, "bias") and m.bias is not None: | |
| nn.init.zeros_(m.bias.data) | |
| elif classname == "Snake1d": | |
| if hasattr(m, "alpha") and m.alpha is not None: | |
| nn.init.normal_(m.alpha.data, 1.0, std) | |
| elif classname == "SnakeBeta": | |
| if hasattr(m, "beta") and m.beta is not None: | |
| nn.init.normal_(m.beta.data, 1.0, std) | |
| elif ( | |
| hasattr(m, "alpha") and m.alpha is not None | |
| ): # Fallback if SnakeBeta uses alpha | |
| nn.init.normal_(m.alpha.data, 1.0, std) | |
| elif hasattr(m, "residual_scale") and m.residual_scale is not None: | |
| nn.init.normal_(m.residual_scale.data, 0.1, std) | |
| class ResidualUnit(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int = 16, | |
| dilation: int = 1, | |
| activation: str = "snake", | |
| alpha: float = 1.0, | |
| scale_residual: bool = False, | |
| ): | |
| """ | |
| Residual Unit with weight normalization and dilated convolutions. | |
| Args: | |
| dim (int): Number of input and output channels. | |
| dilation (int): Dilation factor for the convolution. | |
| activation (str): Activation function to use. | |
| alpha (float): Scaling factor for the activation function. | |
| """ | |
| super().__init__() | |
| pad = ((7 - 1) * dilation) // 2 | |
| self.block = nn.Sequential( | |
| get_activation(activation=activation, channels=dim, alpha=alpha), | |
| WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), | |
| get_activation(activation=activation, channels=dim, alpha=alpha), | |
| WNConv1d(dim, dim, kernel_size=1), | |
| ) | |
| self.scale_residual = scale_residual | |
| if self.scale_residual: | |
| self.res_scale = nn.Parameter(torch.tensor(0.0)) # start at 0 | |
| def forward(self, x): | |
| y = self.block(x) | |
| pad = (x.shape[-1] - y.shape[-1]) // 2 | |
| if pad > 0: | |
| x = x[..., pad:-pad] | |
| if self.scale_residual: | |
| y = self.res_scale * y | |
| return x + y | |
| class EncoderBlock(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int = 16, | |
| stride: int = 1, | |
| activation: str = "snake", | |
| alpha: float = 1.0, | |
| scale_residual: bool = False, | |
| ): | |
| """ | |
| Encoder block that downsamples the input and applies residual units. | |
| """ | |
| super().__init__() | |
| self.block = nn.Sequential( | |
| ResidualUnit( | |
| dim // 2, | |
| dilation=1, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| ResidualUnit( | |
| dim // 2, | |
| dilation=3, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| ResidualUnit( | |
| dim // 2, | |
| dilation=9, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| get_activation(activation=activation, channels=dim // 2, alpha=alpha), | |
| WNConv1d( | |
| dim // 2, | |
| dim, | |
| kernel_size=2 * stride, | |
| stride=stride, | |
| padding=math.ceil(stride / 2), | |
| ), | |
| ) | |
| def forward(self, x): | |
| return self.block(x) | |
| class Encoder(nn.Module): | |
| def __init__( | |
| self, | |
| d_model: int = 64, | |
| strides: list = [2, 4, 8, 8], | |
| d_latent: int = 64, | |
| d_in: int = 1, | |
| activation: str = "snake", | |
| alpha: float = 1.0, | |
| scale_residual: bool = False, | |
| ): | |
| super().__init__() | |
| # Create first convolution | |
| self.block = [WNConv1d(d_in, d_model, kernel_size=7, padding=3)] | |
| # Create EncoderBlocks that double channels as they downsample by `stride` | |
| for stride in strides: | |
| d_model *= 2 | |
| self.block += [ | |
| EncoderBlock( | |
| d_model, | |
| stride=stride, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ) | |
| ] | |
| # Create last convolution | |
| self.block += [ | |
| get_activation(activation=activation, channels=d_model, alpha=alpha), | |
| WNConv1d(d_model, d_latent, kernel_size=3, padding=1), | |
| ] | |
| # Wrap black into nn.Sequential | |
| self.block = nn.Sequential(*self.block) | |
| self.enc_dim = d_model | |
| def forward(self, x): | |
| return self.block(x) | |
| class DecoderBlock(nn.Module): | |
| def __init__( | |
| self, | |
| input_dim: int = 16, | |
| output_dim: int = 8, | |
| stride: int = 1, | |
| norm: bool = False, | |
| activation: str = "snake", | |
| alpha: float = 1.0, | |
| scale_residual: bool = False, | |
| ): | |
| """ | |
| Decoder block that upsamples the input and applies residual units. | |
| """ | |
| super().__init__() | |
| if not norm: | |
| self.block = nn.Sequential( | |
| get_activation(activation=activation, channels=input_dim, alpha=alpha), | |
| WNConvTranspose1d( | |
| input_dim, | |
| output_dim, | |
| kernel_size=2 * stride, | |
| stride=stride, | |
| padding=math.ceil(stride / 2), | |
| output_padding=0 if stride % 2 == 0 else 1, | |
| ), | |
| ResidualUnit( | |
| output_dim, | |
| dilation=1, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| ResidualUnit( | |
| output_dim, | |
| dilation=3, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| ResidualUnit( | |
| output_dim, | |
| dilation=9, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| ) | |
| else: | |
| self.block = nn.Sequential( | |
| get_activation(activation=activation, channels=input_dim, alpha=alpha), | |
| WNConvTranspose1d( | |
| input_dim, | |
| output_dim, | |
| kernel_size=2 * stride, | |
| stride=stride, | |
| padding=math.ceil(stride / 2), | |
| output_padding=0 if stride % 2 == 0 else 1, | |
| ), | |
| nn.BatchNorm1d(output_dim), | |
| ResidualUnit( | |
| output_dim, | |
| dilation=1, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| nn.BatchNorm1d(output_dim), | |
| ResidualUnit( | |
| output_dim, | |
| dilation=3, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| nn.BatchNorm1d(output_dim), | |
| ResidualUnit( | |
| output_dim, | |
| dilation=9, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ), | |
| ) | |
| def forward(self, x): | |
| return self.block(x) | |
| class Decoder(nn.Module): | |
| def __init__( | |
| self, | |
| input_channel, | |
| channels, | |
| rates, | |
| d_out: int = 1, | |
| norm: bool = False, | |
| activation: str = "snake", | |
| alpha: float = 1.0, | |
| scale_residual: bool = False, | |
| use_tanh_as_final: bool = True, | |
| use_bias_at_final: bool = True, | |
| ): | |
| super().__init__() | |
| # Add first conv layer | |
| layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] | |
| # Add upsampling + MRF blocks | |
| for i, stride in enumerate(rates): | |
| input_dim = channels // 2**i | |
| output_dim = channels // 2 ** (i + 1) | |
| layers += [ | |
| DecoderBlock( | |
| input_dim, | |
| output_dim, | |
| stride, | |
| norm=norm, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ) | |
| ] | |
| # Add final conv layer | |
| layers += [ | |
| get_activation(activation=activation, channels=output_dim, alpha=alpha), | |
| WNConv1d( | |
| output_dim, d_out, kernel_size=7, padding=3, bias=use_bias_at_final | |
| ), | |
| nn.Tanh() if use_tanh_as_final else nn.Identity(), | |
| ] | |
| self.use_tanh_as_final = use_tanh_as_final | |
| self.model = nn.Sequential(*layers) | |
| def forward(self, x): | |
| x = self.model(x) | |
| if not self.use_tanh_as_final: | |
| x = torch.clamp( | |
| x, min=-1.0, max=1.0 | |
| ) # Ensure output is within [-1, 1] range | |
| return x | |
| class DACVAE(BaseModel, CodecMixin): | |
| def __init__( | |
| self, | |
| encoder_dim: int = 64, | |
| encoder_rates: List[int] = [2, 4, 8, 8], | |
| latent_dim: int = 64, | |
| decoder_dim: int = 1536, | |
| decoder_rates: List[int] = [8, 8, 4, 2], | |
| sample_rate: int = 44100, | |
| d_in: int = 2, | |
| d_out: int = 2, | |
| weight_init: str = "xavier", | |
| norm: bool = False, | |
| activation: str = "snake", | |
| alpha: float = 1.0, | |
| gain: float = 0.02, | |
| scale_residual: bool = False, | |
| use_tanh_as_final: bool = True, | |
| use_bias_at_final: bool = True, | |
| ): | |
| super().__init__() | |
| self.encoder_dim = encoder_dim | |
| self.encoder_rates = encoder_rates | |
| self.decoder_dim = decoder_dim | |
| self.decoder_rates = decoder_rates | |
| self.sample_rate = sample_rate | |
| self.d_in = d_in | |
| self.d_out = d_out | |
| if latent_dim is None: | |
| latent_dim = encoder_dim * (2 ** len(encoder_rates)) | |
| self.latent_dim = latent_dim | |
| self.hop_length = np.prod(encoder_rates) | |
| self.encoder = Encoder( | |
| encoder_dim, | |
| encoder_rates, | |
| latent_dim, | |
| d_in=d_in, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| ) | |
| self.decoder = Decoder( | |
| latent_dim, | |
| decoder_dim, | |
| decoder_rates, | |
| d_out=d_out, | |
| norm=norm, | |
| activation=activation, | |
| alpha=alpha, | |
| scale_residual=scale_residual, | |
| use_tanh_as_final=use_tanh_as_final, | |
| use_bias_at_final=use_bias_at_final, | |
| ) | |
| self.en_conv_post = WNConv1d( | |
| self.latent_dim, 2 * self.latent_dim, kernel_size=1 | |
| ) | |
| self.de_conv_pre = WNConv1d(self.latent_dim, self.latent_dim, kernel_size=1) | |
| self.sample_rate = sample_rate | |
| self.apply(lambda m: init_weights(m, init_type=weight_init, gain=gain)) | |
| self.step = 0 # Initialize step counter for noise decay | |
| def freeze_encoder(self): | |
| for param in self.encoder.parameters(): | |
| param.requires_grad = False | |
| for param in self.en_conv_post.parameters(): | |
| param.requires_grad = False | |
| print("Encoder and en_conv_post frozen") | |
| def preprocess(self, audio_data, sample_rate): | |
| if sample_rate is None: | |
| sample_rate = self.sample_rate | |
| assert sample_rate == self.sample_rate | |
| length = audio_data.shape[-1] | |
| # print(f"Audio length: {length}", "math.ceil(length / self.hop_length) * self.hop_length: ", math.ceil(length / self.hop_length) * self.hop_length) | |
| right_pad = math.ceil(length / self.hop_length) * self.hop_length - length | |
| audio_data = nn.functional.pad(audio_data, (0, right_pad)) | |
| return audio_data | |
| def encode( | |
| self, | |
| audio_data: torch.Tensor, | |
| training: bool = True, | |
| ): | |
| x = self.encoder(audio_data) | |
| x = F.leaky_relu(x) | |
| x = self.en_conv_post(x) | |
| print('x shape: ', x.shape) | |
| m, logs = torch.split(x, self.latent_dim, dim=1) | |
| logs = torch.clamp(logs, min=-14.0, max=14.0) | |
| z = m + torch.randn_like(m) * torch.exp(logs) | |
| return z, m, logs | |
| def decode(self, z: torch.Tensor): | |
| z = self.de_conv_pre(z) | |
| z = self.decoder(z) | |
| return z | |
| def forward( | |
| self, | |
| audio_data: torch.Tensor, | |
| sample_rate: int = 24000, | |
| ): | |
| # print(f"Audio data shape: {audio_data.shape}") | |
| length = audio_data.shape[-1] | |
| audio_data = self.preprocess(audio_data, sample_rate) | |
| print('audio_data: ', audio_data.shape) | |
| z, m, logs = self.encode(audio_data) | |
| x = self.decode(z) | |
| return { | |
| "audio": x[..., :length], | |
| "z": z, | |
| "mu": m, | |
| "logs": logs, | |
| } | |
| def WNConv1d(*args, **kwargs): | |
| act = kwargs.pop("act", True) | |
| conv = weight_norm(nn.Conv1d(*args, **kwargs)) | |
| if not act: | |
| return conv | |
| return nn.Sequential(conv, nn.LeakyReLU(0.1)) | |
| def WNConv2d(*args, **kwargs): | |
| act = kwargs.pop("act", True) | |
| conv = weight_norm(nn.Conv2d(*args, **kwargs)) | |
| if not act: | |
| return conv | |
| return nn.Sequential(conv, nn.LeakyReLU(0.1)) | |
| class MPD(nn.Module): | |
| def __init__(self, period): | |
| super().__init__() | |
| self.period = period | |
| self.convs = nn.ModuleList( | |
| [ | |
| WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), | |
| WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), | |
| WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), | |
| WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), | |
| WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), | |
| ] | |
| ) | |
| self.conv_post = WNConv2d( | |
| 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False | |
| ) | |
| def pad_to_period(self, x): | |
| t = x.shape[-1] | |
| x = F.pad(x, (0, self.period - t % self.period), mode="reflect") | |
| return x | |
| def forward(self, x): | |
| fmap = [] | |
| x = self.pad_to_period(x) | |
| x = rearrange(x, "b c (l p) -> b c l p", p=self.period) | |
| for layer in self.convs: | |
| x = layer(x) | |
| fmap.append(x) | |
| x = self.conv_post(x) | |
| fmap.append(x) | |
| return fmap | |
| class MSD(nn.Module): | |
| def __init__(self, rate: int = 1, sample_rate: int = 44100): | |
| super().__init__() | |
| self.convs = nn.ModuleList( | |
| [ | |
| WNConv1d(1, 16, 15, 1, padding=7), | |
| WNConv1d(16, 64, 41, 4, groups=4, padding=20), | |
| WNConv1d(64, 256, 41, 4, groups=16, padding=20), | |
| WNConv1d(256, 1024, 41, 4, groups=64, padding=20), | |
| WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), | |
| WNConv1d(1024, 1024, 5, 1, padding=2), | |
| ] | |
| ) | |
| self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) | |
| self.sample_rate = sample_rate | |
| self.rate = rate | |
| def forward(self, x): | |
| x = AudioSignal(x, self.sample_rate) | |
| x.resample(self.sample_rate // self.rate) | |
| x = x.audio_data | |
| fmap = [] | |
| for l in self.convs: | |
| x = l(x) | |
| fmap.append(x) | |
| x = self.conv_post(x) | |
| fmap.append(x) | |
| return fmap | |
| BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] | |
| class MRD(nn.Module): | |
| def __init__( | |
| self, | |
| window_length: int, | |
| hop_factor: float = 0.25, | |
| sample_rate: int = 44100, | |
| bands: list = BANDS, | |
| ): | |
| """Complex multi-band spectrogram discriminator. | |
| Parameters | |
| ---------- | |
| window_length : int | |
| Window length of STFT. | |
| hop_factor : float, optional | |
| Hop factor of the STFT, defaults to ``0.25 * window_length``. | |
| sample_rate : int, optional | |
| Sampling rate of audio in Hz, by default 44100 | |
| bands : list, optional | |
| Bands to run discriminator over. | |
| """ | |
| super().__init__() | |
| self.window_length = window_length | |
| self.hop_factor = hop_factor | |
| self.sample_rate = sample_rate | |
| self.stft_params = STFTParams( | |
| window_length=window_length, | |
| hop_length=int(window_length * hop_factor), | |
| match_stride=True, | |
| ) | |
| n_fft = window_length // 2 + 1 | |
| bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] | |
| self.bands = bands | |
| ch = 32 | |
| convs = lambda: nn.ModuleList( | |
| [ | |
| WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), | |
| WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), | |
| WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), | |
| WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), | |
| WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), | |
| ] | |
| ) | |
| self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) | |
| self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) | |
| def spectrogram(self, x): | |
| x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) | |
| x = torch.view_as_real(x.stft()) | |
| x = rearrange(x, "b 1 f t c -> (b 1) c t f") | |
| # Split into bands | |
| x_bands = [x[..., b[0] : b[1]] for b in self.bands] | |
| return x_bands | |
| def forward(self, x): | |
| x_bands = self.spectrogram(x) | |
| fmap = [] | |
| x = [] | |
| for band, stack in zip(x_bands, self.band_convs): | |
| for layer in stack: | |
| band = layer(band) | |
| fmap.append(band) | |
| x.append(band) | |
| x = torch.cat(x, dim=-1) | |
| x = self.conv_post(x) | |
| fmap.append(x) | |
| return fmap | |
| class Discriminator(ml.BaseModel): | |
| def __init__( | |
| self, | |
| rates: list = [], | |
| periods: list = [2, 3, 5, 7, 11], | |
| fft_sizes: list = [2048, 1024, 512], | |
| sample_rate: int = 44100, | |
| bands: list = BANDS, | |
| d_in: int = 1, | |
| ): | |
| """Discriminator that combines multiple discriminators. | |
| Parameters | |
| ---------- | |
| rates : list, optional | |
| sampling rates (in Hz) to run MSD at, by default [] | |
| If empty, MSD is not used. | |
| periods : list, optional | |
| periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] | |
| fft_sizes : list, optional | |
| Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] | |
| sample_rate : int, optional | |
| Sampling rate of audio in Hz, by default 44100 | |
| bands : list, optional | |
| Bands to run MRD at, by default `BANDS` | |
| """ | |
| super().__init__() | |
| discs = [] | |
| discs += [MPD(p) for p in periods] | |
| discs += [MSD(r, sample_rate=sample_rate) for r in rates] | |
| discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] | |
| self.discriminators = nn.ModuleList(discs) | |
| def preprocess(self, y): | |
| # Remove DC offset | |
| y = y - y.mean(dim=-1, keepdims=True) | |
| # Peak normalize the volume of input audio | |
| y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) | |
| return y | |
| def forward(self, x): | |
| x = self.preprocess(x) | |
| fmaps = [d(x) for d in self.discriminators] | |
| return fmaps | |
| if __name__ == "__main__": | |
| disc = Discriminator() | |
| x = torch.zeros(1, 1, 44100) | |
| results = disc(x) | |
| for i, result in enumerate(results): | |
| print(f"disc{i}") | |
| for i, r in enumerate(result): | |
| print(r.shape, r.mean(), r.min(), r.max()) | |
| print() | |