| from typing import List, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from models.scnet_unofficial.utils import get_convtranspose_output_padding |
|
|
|
|
class FusionLayer(nn.Module):
    """Fuse two tensors with a gated convolution.

    ``x1`` and ``x2`` are added element-wise. The last axis of the sum is
    tiled twice so that it holds ``2 * input_dim`` entries. A
    ``(kernel_size, 1)`` convolution then runs along axis 1. Finally a GLU
    over the last axis gates the result back down to ``input_dim`` entries.

    Args:
        input_dim: Size of the last axis of ``x1 + x2``.
        kernel_size: Convolution extent along axis 1 of the inputs.
        stride: Convolution stride along axis 1 of the inputs.
        padding: Zero padding added on both sides of axis 1 of the inputs.

    With the default ``kernel_size=3, stride=1, padding=1`` the length of
    axis 1 is preserved.

    NOTE(review): the inputs are presumably laid out as
    (batch, freq, time, channels). TODO: confirm against the callers.
    """

    def __init__(
        self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
    ):
        super().__init__()
        doubled = 2 * input_dim
        self.conv = nn.Conv2d(
            in_channels=doubled,
            out_channels=doubled,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )
        # GLU halves its input along dim=-1: first_half * sigmoid(second_half).
        self.activation = nn.GLU()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return the gated fusion of ``x1`` and ``x2``. The last axis is ``input_dim``."""
        merged = (x1 + x2).repeat(1, 1, 1, 2)
        # Conv2d expects channels at axis 1, so move the last axis there and back.
        conv_out = self.conv(merged.permute(0, 3, 1, 2))
        return self.activation(conv_out.permute(0, 2, 3, 1))
|
|
|
|
class Upsample(nn.Module):
    """Stretch axis 2 of a channel-first 4D tensor with a 1x1 transposed conv.

    Only axis 2 is strided. For an input of shape ``(N, input_dim, H, W)``
    the output is ``(N, output_dim, (H - 1) * stride + 1 + output_padding, W)``.
    This follows from the ``ConvTranspose2d`` shape formula with
    ``kernel_size=1``, no padding and dilation 1.

    Args:
        input_dim: Number of input channels (axis 1).
        output_dim: Number of output channels (axis 1).
        stride: Upsampling factor applied to axis 2.
        output_padding: Extra rows appended to axis 2 of the output. PyTorch
            requires this to be smaller than ``stride``.
    """

    def __init__(
        self, input_dim: int, output_dim: int, stride: int, output_padding: int
    ):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=(stride, 1),
            output_padding=(output_padding, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transposed convolution to ``x``."""
        upsampled = self.conv(x)
        return upsampled
|
|
|
|
class SULayer(nn.Module):
    """Upsample one slice of axis 1 and change the size of the last axis.

    ``forward`` keeps only ``x[:, start:stop]``, where
    ``(start, stop) = sd_interval``. It moves the last axis into the channel
    position expected by ``Upsample``, upsamples the sliced axis, and moves
    the channels back to the last position.

    Args:
        input_dim: Size of the last axis of the incoming tensor.
        output_dim: Size of the last axis of the returned tensor.
        upsample_stride: Upsampling factor applied to the sliced axis.
        subband_shape: Target length handed to
            ``get_convtranspose_output_padding``. This is presumably the
            length the slice should be restored to; TODO confirm the helper's
            contract.
        sd_interval: ``(start, stop)`` bounds of the slice taken from axis 1.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_stride: int,
        subband_shape: int,
        sd_interval: Tuple[int, int],
    ):
        super().__init__()
        start, stop = sd_interval[0], sd_interval[1]
        self.sd_interval = sd_interval
        self.upsample = Upsample(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=upsample_stride,
            # Extra rows so the transposed conv lands on the requested length.
            output_padding=get_convtranspose_output_padding(
                input_shape=stop - start,
                output_shape=subband_shape,
                stride=upsample_stride,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Slice axis 1 of ``x`` to ``sd_interval`` and upsample that slice."""
        start, stop = self.sd_interval[0], self.sd_interval[1]
        band = x[:, start:stop]
        # Upsample works on channel-first tensors: last axis -> axis 1 -> last axis.
        channel_first = band.permute(0, 3, 1, 2)
        return self.upsample(channel_first).permute(0, 2, 3, 1)
|
|
|
|
class SUBlock(nn.Module):
    """Fuse with a skip tensor, upsample each band, and join the bands on axis 1.

    ``forward`` first combines ``x`` and ``x_skip`` with a ``FusionLayer``.
    Each ``SULayer`` then takes its own ``sd_interval`` slice of axis 1 from
    the fused tensor and upsamples it. The per-layer results are concatenated
    along axis 1 in the order the intervals were given.

    Args:
        input_dim: Size of the last axis of ``x`` and ``x_skip``.
        output_dim: Size of the last axis of each ``SULayer`` output.
        upsample_strides: Upsampling stride for each band.
        subband_shapes: Target length for each band, passed to ``SULayer``.
        sd_intervals: ``(start, stop)`` slice of axis 1 for each band.

    Raises:
        ValueError: If the three per-band sequences differ in length.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_strides: List[int],
        subband_shapes: List[int],
        sd_intervals: List[Tuple[int, int]],
    ):
        super().__init__()
        upsample_strides = list(upsample_strides)
        subband_shapes = list(subband_shapes)
        sd_intervals = list(sd_intervals)
        # zip() would silently drop trailing bands on a length mismatch and
        # build a block that reconstructs fewer sub-bands than intended.
        if not (len(upsample_strides) == len(subband_shapes) == len(sd_intervals)):
            raise ValueError(
                "upsample_strides, subband_shapes and sd_intervals must have the "
                f"same length, got {len(upsample_strides)}, {len(subband_shapes)} "
                f"and {len(sd_intervals)}"
            )
        self.fusion_layer = FusionLayer(input_dim=input_dim)
        self.su_layers = nn.ModuleList(
            SULayer(
                input_dim=input_dim,
                output_dim=output_dim,
                upsample_stride=uss,
                subband_shape=sbs,
                sd_interval=sdi,
            )
            for uss, sbs, sdi in zip(upsample_strides, subband_shapes, sd_intervals)
        )

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """Fuse ``x`` with ``x_skip`` and return all upsampled bands joined on axis 1."""
        fused = self.fusion_layer(x, x_skip)
        return torch.cat([layer(fused) for layer in self.su_layers], dim=1)
|
|