Spaces:
Build error
Build error
| import torch | |
class Conv1d(torch.nn.Conv1d):
    """``torch.nn.Conv1d`` whose weight initialisation is chosen by ``w_init_gain``.

    Args:
        w_init_gain: how ``weight`` is initialised in ``reset_parameters``:

            * ``'zero'``: all zeros.
            * ``None``: PyTorch's default ``Conv1d`` initialisation.
            * ``'relu'`` / ``'leaky_relu'``: Kaiming-uniform for that nonlinearity.
            * ``'glu'``: first half of the output channels Kaiming-uniform
              (linear), second half Xavier-uniform with the sigmoid gain.
            * ``'gate'``: first half Xavier-uniform with the tanh gain,
              second half Xavier-uniform with the sigmoid gain.
            * any other value: Xavier-uniform with
              ``torch.nn.init.calculate_gain(w_init_gain)``.

            ``'glu'`` and ``'gate'`` require an even ``out_channels``.
        *args, **kwargs: forwarded to ``torch.nn.Conv1d``. ``w_init_gain`` is
            the first positional parameter, so pass the convolution arguments
            by keyword.

    The bias, if present, is always zero-initialised.
    """

    def __init__(self, w_init_gain='linear', *args, **kwargs):
        # Set before super().__init__(): the parent constructor calls
        # reset_parameters(), which reads this attribute.
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain == 'zero':
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain is None:
            # torch.nn.Conv1d allocates its weight uninitialised; use PyTorch's
            # default scheme instead of leaving arbitrary memory in it.
            super().reset_parameters()
        elif self.w_init_gain in ('relu', 'leaky_relu'):
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity=self.w_init_gain)
        elif self.w_init_gain in ('glu', 'gate'):
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            # Conv1d weights are (out_channels, in_channels // groups, kernel_size),
            # so slicing dim 0 separates the two output-channel halves.
            half = self.out_channels // 2
            if self.w_init_gain == 'glu':
                torch.nn.init.kaiming_uniform_(self.weight[:half], nonlinearity='linear')
            else:
                torch.nn.init.xavier_uniform_(self.weight[:half], gain=torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[half:], gain=torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
class ConvTranspose1d(torch.nn.ConvTranspose1d):
    """``torch.nn.ConvTranspose1d`` whose weight initialisation is chosen by ``w_init_gain``.

    Args:
        w_init_gain: how ``weight`` is initialised in ``reset_parameters``:

            * ``'zero'``: all zeros.
            * ``None``: PyTorch's default ``ConvTranspose1d`` initialisation.
            * ``'relu'`` / ``'leaky_relu'``: Kaiming-uniform for that nonlinearity.
            * ``'glu'``: first half of the output channels Kaiming-uniform
              (linear), second half Xavier-uniform with the sigmoid gain.
            * ``'gate'``: first half Xavier-uniform with the tanh gain,
              second half Xavier-uniform with the sigmoid gain.
            * any other value: Xavier-uniform with
              ``torch.nn.init.calculate_gain(w_init_gain)``.

            ``'glu'`` and ``'gate'`` require an even ``out_channels``.
        *args, **kwargs: forwarded to ``torch.nn.ConvTranspose1d``. Pass the
            convolution arguments by keyword, because ``w_init_gain`` is the
            first positional parameter.

    The bias, if present, is always zero-initialised.
    """

    def __init__(self, w_init_gain='linear', *args, **kwargs):
        # Set before super().__init__(): the parent constructor calls
        # reset_parameters(), which reads this attribute.
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain == 'zero':
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain is None:
            # Fall back to PyTorch's default scheme instead of leaving the
            # freshly allocated weight uninitialised.
            super().reset_parameters()
        elif self.w_init_gain in ('relu', 'leaky_relu'):
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity=self.w_init_gain)
        elif self.w_init_gain in ('glu', 'gate'):
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            # Transposed-conv weights are (in_channels, out_channels // groups, kernel_size):
            # the output channels live on dim 1, not dim 0.
            # NOTE(review): the split is exact for groups == 1; with groups > 1,
            # dim 1 holds per-group output channels.
            half = self.out_channels // 2
            if self.w_init_gain == 'glu':
                torch.nn.init.kaiming_uniform_(self.weight[:, :half], nonlinearity='linear')
            else:
                torch.nn.init.xavier_uniform_(self.weight[:, :half], gain=torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[:, half:], gain=torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` whose weight initialisation is chosen by ``w_init_gain``.

    Args:
        w_init_gain: how ``weight`` is initialised in ``reset_parameters``:

            * ``'zero'``: all zeros.
            * ``None``: PyTorch's default ``Conv2d`` initialisation.
            * ``'relu'`` / ``'leaky_relu'``: Kaiming-uniform for that nonlinearity.
            * ``'glu'``: first half of the output channels Kaiming-uniform
              (linear), second half Xavier-uniform with the sigmoid gain.
            * ``'gate'``: first half Xavier-uniform with the tanh gain,
              second half Xavier-uniform with the sigmoid gain.
            * any other value: Xavier-uniform with
              ``torch.nn.init.calculate_gain(w_init_gain)``.

            ``'glu'`` and ``'gate'`` require an even ``out_channels``.
        *args, **kwargs: forwarded to ``torch.nn.Conv2d``. Pass the
            convolution arguments by keyword, because ``w_init_gain`` is the
            first positional parameter.

    The bias, if present, is always zero-initialised.
    """

    def __init__(self, w_init_gain='linear', *args, **kwargs):
        # Set before super().__init__(): the parent constructor calls
        # reset_parameters(), which reads this attribute.
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain == 'zero':
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain is None:
            # Previously calculate_gain(None) raised; use PyTorch's default scheme.
            super().reset_parameters()
        elif self.w_init_gain in ('relu', 'leaky_relu'):
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity=self.w_init_gain)
        elif self.w_init_gain in ('glu', 'gate'):
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            # Conv2d weights are (out_channels, in_channels // groups, kH, kW),
            # so slicing dim 0 separates the two output-channel halves.
            half = self.out_channels // 2
            if self.w_init_gain == 'glu':
                torch.nn.init.kaiming_uniform_(self.weight[:half], nonlinearity='linear')
            else:
                torch.nn.init.xavier_uniform_(self.weight[:half], gain=torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[half:], gain=torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
class ConvTranspose2d(torch.nn.ConvTranspose2d):
    """``torch.nn.ConvTranspose2d`` whose weight initialisation is chosen by ``w_init_gain``.

    Args:
        w_init_gain: how ``weight`` is initialised in ``reset_parameters``:

            * ``'zero'``: all zeros.
            * ``None``: PyTorch's default ``ConvTranspose2d`` initialisation.
            * ``'relu'`` / ``'leaky_relu'``: Kaiming-uniform for that nonlinearity.
            * ``'glu'``: first half of the output channels Kaiming-uniform
              (linear), second half Xavier-uniform with the sigmoid gain.
            * ``'gate'``: first half Xavier-uniform with the tanh gain,
              second half Xavier-uniform with the sigmoid gain.
            * any other value: Xavier-uniform with
              ``torch.nn.init.calculate_gain(w_init_gain)``.

            ``'glu'`` and ``'gate'`` require an even ``out_channels``.
        *args, **kwargs: forwarded to ``torch.nn.ConvTranspose2d``. Pass the
            convolution arguments by keyword, because ``w_init_gain`` is the
            first positional parameter.

    The bias, if present, is always zero-initialised.
    """

    def __init__(self, w_init_gain='linear', *args, **kwargs):
        # Set before super().__init__(): the parent constructor calls
        # reset_parameters(), which reads this attribute.
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain == 'zero':
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain is None:
            # Fall back to PyTorch's default scheme instead of leaving the
            # freshly allocated weight uninitialised.
            super().reset_parameters()
        elif self.w_init_gain in ('relu', 'leaky_relu'):
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity=self.w_init_gain)
        elif self.w_init_gain in ('glu', 'gate'):
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            # Transposed-conv weights are (in_channels, out_channels // groups, kH, kW):
            # the output channels live on dim 1, not dim 0.
            # NOTE(review): the split is exact for groups == 1; with groups > 1,
            # dim 1 holds per-group output channels.
            half = self.out_channels // 2
            if self.w_init_gain == 'glu':
                torch.nn.init.kaiming_uniform_(self.weight[:, :half], nonlinearity='linear')
            else:
                torch.nn.init.xavier_uniform_(self.weight[:, :half], gain=torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[:, half:], gain=torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
class Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose weight initialisation is chosen by ``w_init_gain``.

    Args:
        w_init_gain: how ``weight`` is initialised in ``reset_parameters``:

            * ``'zero'``: all zeros.
            * ``None``: PyTorch's default ``Linear`` initialisation.
            * ``'relu'`` / ``'leaky_relu'``: Kaiming-uniform for that nonlinearity.
            * ``'glu'``: first half of the output features Kaiming-uniform
              (linear), second half Xavier-uniform with the sigmoid gain.
            * ``'gate'``: first half Xavier-uniform with the tanh gain,
              second half Xavier-uniform with the sigmoid gain.
            * any other value: Xavier-uniform with
              ``torch.nn.init.calculate_gain(w_init_gain)``.

            ``'glu'`` and ``'gate'`` require an even ``out_features``.
        *args, **kwargs: forwarded to ``torch.nn.Linear``. Pass its arguments
            by keyword, because ``w_init_gain`` is the first positional parameter.

    The bias, if present, is always zero-initialised.
    """

    def __init__(self, w_init_gain='linear', *args, **kwargs):
        # Set before super().__init__(): the parent constructor calls
        # reset_parameters(), which reads this attribute.
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain == 'zero':
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain is None:
            # Previously calculate_gain(None) raised; use PyTorch's default scheme.
            super().reset_parameters()
        elif self.w_init_gain in ('relu', 'leaky_relu'):
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity=self.w_init_gain)
        elif self.w_init_gain in ('glu', 'gate'):
            # torch.nn.Linear exposes out_features (it has no out_channels);
            # its weight is (out_features, in_features), so dim 0 is split.
            assert self.out_features % 2 == 0, 'The out_channels of GLU requires even number.'
            half = self.out_features // 2
            if self.w_init_gain == 'glu':
                torch.nn.init.kaiming_uniform_(self.weight[:half], nonlinearity='linear')
            else:
                torch.nn.init.xavier_uniform_(self.weight[:half], gain=torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[half:], gain=torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
class Lambda(torch.nn.Module):
    """Wrap an arbitrary callable so it can sit inside a ``torch.nn`` model.

    Args:
        lambd: callable applied to the single argument passed to ``forward``.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
class Residual(torch.nn.Module):
    """Forward every call, with all its arguments, to the wrapped ``module``.

    NOTE(review): despite the class name, ``forward`` does not add its input
    back onto the wrapped module's output. Any skip connection must be
    implemented inside ``module`` itself. Confirm this is intended.

    Args:
        module: the callable / ``torch.nn.Module`` to delegate to.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        wrapped = self.module
        return wrapped(*args, **kwargs)
class LayerNorm(torch.nn.Module):
    """Normalise over dim 1 of the input, then apply a per-feature affine map.

    Every position along the other dims is normalised independently, using the
    mean and biased variance taken over dim 1. ``gamma`` / ``beta`` are
    broadcast along dim 1. Input must have at least 2 dims, with
    ``size(1) == num_features``.

    Args:
        num_features: size of dim 1 (length of ``gamma`` and ``beta``).
        eps: added to the variance before the reciprocal square root.
    """

    def __init__(self, num_features: int, eps: float= 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(num_features))
        self.beta = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, inputs: torch.Tensor):
        centered = inputs - inputs.mean(dim=1, keepdim=True)
        inv_std = (centered.pow(2.0).mean(dim=1, keepdim=True) + self.eps).rsqrt()
        normalized = centered * inv_std
        # e.g. (1, -1, 1) for 3-D input: place the affine params on dim 1.
        broadcast = (1, -1) + (1,) * (normalized.ndim - 2)
        return normalized * self.gamma.view(broadcast) + self.beta.view(broadcast)
class LightweightConv1d(torch.nn.Module):
    '''
    Depthwise 1-D convolution whose kernels are shared across channels per head.

    Args:
        input_size: # of channels of the input and output
        kernel_size: width of the convolution kernel along the time axis
        padding: zero padding added on both sides of the time axis
        num_heads: number of heads used. The weight is of shape
            `(num_heads, 1, kernel_size)`; channel ``c`` is convolved with the
            kernel of head ``c % num_heads``
        weight_softmax: normalize the weight with softmax (over the kernel
            axis) before the convolution
        bias: if True, add a learnable per-channel bias
        weight_dropout: dropout probability applied to the kernel in forward
        w_init_gain: weight initialisation mode:
            'zero' -> zeros; 'relu' / 'leaky_relu' -> Kaiming-uniform;
            'glu' -> first half of the heads Kaiming-uniform (linear), second
            half Xavier-uniform with the sigmoid gain (needs even num_heads);
            anything else -> Xavier-uniform with calculate_gain(w_init_gain)
    Shape:
        Input: BxCxT, i.e. (batch_size, input_size, timesteps)
        Output: BxCxT, i.e. (batch_size, input_size, timesteps)
        C must be divisible by num_heads, and `padding` must keep the time
        length unchanged, because forward reshapes the result back to B x C x T.
    Attributes:
        weight: the learnable weights of the module of shape
            `(num_heads, 1, kernel_size)`
        bias: the learnable bias of the module of shape `(input_size)`
    '''

    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding=0,
        num_heads=1,
        weight_softmax=False,
        bias=False,
        weight_dropout=0.0,
        w_init_gain= 'linear'
        ):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.num_heads = num_heads
        self.padding = padding
        self.weight_softmax = weight_softmax
        # Allocated uninitialised; filled by reset_parameters() below.
        self.weight = torch.nn.Parameter(torch.empty(num_heads, 1, kernel_size))
        self.w_init_gain = w_init_gain
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(input_size))
        else:
            self.bias = None
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
            )
        self.reset_parameters()

    def reset_parameters(self):
        if self.w_init_gain == 'zero':
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ('relu', 'leaky_relu'):
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity=self.w_init_gain)
        elif self.w_init_gain == 'glu':
            # This module has no out_channels attribute: the weight's leading
            # axis is the head axis, so the GLU-style split is taken over heads.
            assert self.num_heads % 2 == 0, 'The out_channels of GLU requires even number.'
            half = self.num_heads // 2
            torch.nn.init.kaiming_uniform_(self.weight[:half], nonlinearity='linear')
            torch.nn.init.xavier_uniform_(self.weight[half:], gain=torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, input):
        """
        input size: B x C x T
        output size: B x C x T
        """
        B, C, T = input.size()
        H = self.num_heads
        weight = self.weight
        if self.weight_softmax:
            weight = weight.softmax(dim=-1)
        weight = self.weight_dropout_module(weight)
        # Merge every C/H entries into the batch dimension (C = self.input_size)
        # B x C x T -> (B * C/H) x H x T
        # One can also expand the weight to C x 1 x K by a factor of C/H
        # and do not reshape the input instead, which is slow though
        input = input.view(-1, H, T)
        output = torch.nn.functional.conv1d(input, weight, padding=self.padding, groups=self.num_heads)
        output = output.view(B, C, T)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)
        return output
class FairseqDropout(torch.nn.Module):
    """Dropout that is active in training mode and, optionally, at inference.

    Args:
        p: drop probability passed to ``torch.nn.functional.dropout``.
        module_name: name of the owning module; stored only, not used by ``forward``.

    Setting ``apply_during_inference = True`` keeps dropout active when the
    module is in eval mode. Otherwise the input is returned unchanged in eval mode.
    """

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        active = self.training or self.apply_during_inference
        if not active:
            return x
        return torch.nn.functional.dropout(x, p=self.p, training=True, inplace=inplace)
class LinearAttention(torch.nn.Module):
    """Multi-head linear attention over a ``[Batch, channels, Time]`` sequence.

    A single 1x1 ``prenet`` convolution produces queries, keys and values.
    Keys are softmax-normalised over time, and a per-head
    ``keys @ values^T`` summary is applied to the queries. The result is
    projected back to ``channels``. Optionally it is then scaled by a learnable
    scalar, passed through dropout, added to the input and layer-normalised.

    Args:
        channels: channel count of the input and of the output.
        calc_channels: total query/key/value channels; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        dropout_rate: dropout probability applied to the attention output.
        use_scale: multiply the output by a learnable scalar initialised to 0,
            so the attention branch contributes nothing at initialisation.
        use_residual: add the input back onto the attention output.
        use_norm: apply ``LayerNorm`` over the channel dim at the end.
    """

    def __init__(
        self,
        channels: int,
        calc_channels: int,
        num_heads: int,
        dropout_rate: float= 0.1,
        use_scale: bool= True,
        use_residual: bool= True,
        use_norm: bool= True
        ):
        super().__init__()
        assert calc_channels % num_heads == 0
        self.calc_channels = calc_channels
        self.num_heads = num_heads
        self.use_scale = use_scale
        self.use_residual = use_residual
        self.use_norm = use_norm

        # Queries, keys and values all come out of this one convolution.
        self.prenet = Conv1d(
            in_channels= channels,
            out_channels= calc_channels * 3,
            kernel_size= 1,
            bias=False,
            w_init_gain= 'linear'
            )
        self.projection = Conv1d(
            in_channels= calc_channels,
            out_channels= channels,
            kernel_size= 1,
            w_init_gain= 'linear'
            )
        self.dropout = torch.nn.Dropout(p= dropout_rate)

        if use_scale:
            self.scale = torch.nn.Parameter(torch.zeros(1))
        if use_norm:
            self.norm = LayerNorm(num_features= channels)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        '''
        x: [Batch, Enc_d, Enc_t]
        Extra positional / keyword arguments are accepted and ignored.
        '''
        residuals = x
        heads = self.num_heads

        projected = self.prenet(x)  # [Batch, Calc_d * 3, Enc_t]
        batch_size = projected.size(0)
        time_steps = projected.size(2)
        per_head = projected.view(batch_size, heads, projected.size(1) // heads, time_steps)  # [Batch, Head, Calc_d // Head * 3, Enc_t]
        queries, keys, values = per_head.chunk(chunks= 3, dim= 2)  # [Batch, Head, Calc_d // Head, Enc_t] * 3
        keys = (keys + 1e-5).softmax(dim= 3)

        summary = torch.matmul(keys, values.transpose(2, 3))  # [Batch, Head, Calc_d // Head, Calc_d // Head]
        attended = torch.matmul(summary.transpose(2, 3), queries)  # [Batch, Head, Calc_d // Head, Enc_t]
        merged = attended.view(batch_size, -1, time_steps)  # [Batch, Calc_d, Enc_t]

        out = self.projection(merged)  # [Batch, Enc_d, Enc_t]
        if self.use_scale:
            out = self.scale * out
        out = self.dropout(out)
        if self.use_residual:
            out = out + residuals
        if self.use_norm:
            out = self.norm(out)
        return out