import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

class ConvNextV2LayerNorm(nn.Module):
    r"""LayerNorm supporting two data formats: channels_last (default) or channels_first.

    channels_last corresponds to inputs of shape (batch_size, seq_len, channels), while
    channels_first corresponds to inputs of shape (batch_size, channels, seq_len).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Normalize over the channel dimension in float32 for numerical
            # stability, then cast back to the input dtype.
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            x = self.weight[None, :, None] * x + self.bias[None, :, None]
        return x

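# A minimal usage sketch (illustrative, not part of the module): the same layer
# handles either layout; only the broadcasting of weight/bias differs.
#
#     ln = ConvNextV2LayerNorm(512, data_format="channels_first")
#     y = ln(torch.randn(2, 512, 100))  # (batch, channels, seq_len) -> same shape
#     ln = ConvNextV2LayerNorm(512)     # channels_last
#     y = ln(torch.randn(2, 100, 512))  # (batch, seq_len, channels) -> same shape
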
class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt V2), applied to channels_last
    inputs of shape (batch_size, seq_len, dim)."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # Global L2 response per channel, normalized by the mean response across
        # channels; gamma/beta then gate the recalibrated features residually.
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x

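# Sketch of the GRN call on a (batch, seq_len, dim) tensor (illustrative shapes):
#
#     grn = GRN(512)
#     y = grn(torch.randn(2, 100, 512))  # shape preserved; per-channel recalibration
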
class InterpolationLayer(nn.Module):
    """Resamples a (batch, channels, seq_len) tensor to a target length."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, target_len: int, *args, **kwargs) -> torch.Tensor:
        return F.interpolate(x, size=target_len, mode='linear')

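# Illustrative use: stretch an 80-channel feature sequence from 50 to 75 frames.
#
#     interp = InterpolationLayer()
#     y = interp(torch.randn(1, 80, 50), target_len=75)  # -> (1, 80, 75)
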
class ConvNeXtV2Stage(nn.Module):
    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 2048,
        num_blocks: int = 1,
        dilation: int = 1,
        downsample_layer_indices: List[int] = None,
        downsample_factors: List[int] = None,
        upsample_layer_indices: List[int] = None,
        upsample_factors: List[int] = None,
        interpolation_layer_indices: List[int] = None,
        input_dim: int = None,
        output_dim: int = None,
        gin_channels: int = 0,
    ):
        super().__init__()
        # Optional strided-conv downsampling blocks, applied just before the
        # block whose index appears in downsample_layer_indices.
        if downsample_layer_indices is not None:
            assert downsample_factors is not None
            self.downsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.Conv1d(
                            dim, dim, kernel_size=downsample_factor, stride=downsample_factor
                        ),
                    )
                    for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
                ]
            )
            self.downsample_layer_indices = downsample_layer_indices
        else:
            self.downsample_blocks = nn.ModuleList()
            self.downsample_layer_indices = []

        # Optional transposed-conv upsampling blocks, keyed the same way.
        if upsample_layer_indices is not None:
            assert upsample_factors is not None
            self.upsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.ConvTranspose1d(
                            dim, dim, kernel_size=upsample_factor, stride=upsample_factor
                        ),
                    )
                    for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
                ]
            )
            self.upsample_layer_indices = upsample_layer_indices
        else:
            self.upsample_blocks = nn.ModuleList()
            self.upsample_layer_indices = []

        # Optional linear-interpolation layers for arbitrary target lengths.
        if interpolation_layer_indices is not None:
            self.interpolation_blocks = nn.ModuleList(
                [InterpolationLayer() for _ in interpolation_layer_indices]
            )
            self.interpolation_layer_indices = interpolation_layer_indices
        else:
            self.interpolation_blocks = nn.ModuleList()
            self.interpolation_layer_indices = []

        self.blocks = nn.ModuleList(
            [
                ConvNeXtV2Block(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    dilation=dilation,
                )
                for _ in range(num_blocks)
            ]
        )

        # 1x1 projections when the stage's I/O widths differ from the block width.
        if input_dim is not None and input_dim != dim:
            self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
        else:
            self.input_projection = nn.Identity()
        if output_dim is not None and output_dim != dim:
            self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
        else:
            self.output_projection = nn.Identity()

        # Optional global conditioning injected after the input projection.
        if gin_channels > 0:
            self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x = self.input_projection(x)
        # Global conditioning tensor g is required when gin_channels > 0.
        if hasattr(self, 'gin'):
            g = kwargs['g']
            x = x + self.gin(g)

        # Right-pad the time axis with zeros so its length is divisible by the
        # product of all downsampling strides.
        if len(self.downsample_blocks) > 0:
            downsample_factor = 1
            for block in self.downsample_blocks:
                downsample_factor *= block[1].stride[0]
            pad_len = (downsample_factor - x.size(-1) % downsample_factor) % downsample_factor
            if pad_len > 0:
                x = F.pad(x, (0, pad_len))

        for layer_idx, block in enumerate(self.blocks):
            if layer_idx in self.downsample_layer_indices:
                x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.upsample_layer_indices:
                x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.interpolation_layer_indices:
                x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](
                    x, target_len=kwargs['target_len']
                )
            x = block(x)
        x = self.output_projection(x)
        return x

    def setup_caches(self, *args, **kwargs):
        # No-op: this stage keeps no state between forward calls.
        pass

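# A usage sketch under assumed shapes (illustrative only): a 4-block stage that
# downsamples 2x before block 1 and upsamples 2x before block 3, so the output
# length matches the input length.
#
#     stage = ConvNeXtV2Stage(
#         dim=256, intermediate_dim=1024, num_blocks=4,
#         downsample_layer_indices=[1], downsample_factors=[2],
#         upsample_layer_indices=[3], upsample_factors=[2],
#     )
#     y = stage(torch.randn(2, 256, 100))  # -> (2, 256, 100)
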
class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        # "Same" padding for a kernel of size 7 at the given dilation.
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise conv as a Linear
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = self.norm(x)
        # Pointwise ops run in channels_last, so move channels to the last axis.
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        return residual + x
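

if __name__ == "__main__":
    # Minimal smoke test (illustrative shapes, not from the original source):
    # a single block should preserve the (batch, channels, seq_len) shape.
    block = ConvNeXtV2Block(dim=64, intermediate_dim=256)
    out = block(torch.randn(2, 64, 50))
    assert out.shape == (2, 64, 50)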