noblebarkrr's picture
Updated to Dzeta
4f175c5
from typing import List, Tuple
import torch
import torch.nn as nn
from models.scnet_unofficial.utils import get_convtranspose_output_padding
class FusionLayer(nn.Module):
    """Merge a feature map with its skip connection through a gated convolution.

    The two inputs are summed and the last (channel) axis is tiled twice. A
    ``(kernel_size x 1)`` convolution is applied along axis 1 of the
    channels-last layout. ``nn.GLU`` then halves the last axis back to
    ``input_dim`` channels.

    Args:
        input_dim: Size of the last axis of ``x1`` / ``x2``.
        kernel_size: Convolution extent along input axis 1.
        stride: Convolution stride along input axis 1.
        padding: Zero padding along input axis 1.
    """

    def __init__(
        self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
    ):
        super().__init__()
        doubled = 2 * input_dim
        # The convolution sees the tiled input, so it consumes and emits twice
        # the channel count; the GLU afterwards gates it back down to input_dim.
        self.conv = nn.Conv2d(
            doubled,
            doubled,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )
        self.activation = nn.GLU()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # NOTE(review): inputs appear to be 4-D channels-last, presumably
        # (batch, freq, time, channels) -- confirm against the caller.
        merged = (x1 + x2).repeat(1, 1, 1, 2)
        # nn.Conv2d wants channels at axis 1: move them there, convolve, and
        # restore the channels-last layout.
        channels_first = merged.permute(0, 3, 1, 2)
        convolved = self.conv(channels_first).permute(0, 2, 3, 1)
        # nn.GLU() with its default dim=-1 splits the last axis in half.
        return self.activation(convolved)
class Upsample(nn.Module):
    """Upsample axis 2 of a channels-first tensor with a 1x1 transposed convolution.

    Args:
        input_dim: Input channel count (axis 1).
        output_dim: Output channel count (axis 1).
        stride: Upsampling stride along axis 2; axis 3 uses stride 1.
        output_padding: Extra output padding along axis 2 (none on axis 3).
    """

    def __init__(
        self, input_dim: int, output_dim: int, stride: int, output_padding: int
    ):
        super().__init__()
        # A 1x1 kernel means each output position is a per-channel linear map
        # of one input position; only the stride stretches axis 2.
        self.conv = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=(stride, 1),
            output_padding=(output_padding, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = self.conv(x)
        return upsampled
class SULayer(nn.Module):
    """Upsample one sub-band slice of a channels-last feature map.

    The forward pass keeps entries ``sd_interval[0]:sd_interval[1]`` along
    axis 1 of the input. It moves the last axis to position 1 and upsamples
    with :class:`Upsample`, then moves the channels back to the last axis.
    The transposed convolution's output padding is taken from
    ``get_convtranspose_output_padding`` for the slice length, the target
    ``subband_shape`` and ``upsample_stride``.

    Args:
        input_dim: Channel count fed to the upsampler.
        output_dim: Channel count produced by the upsampler.
        upsample_stride: Upsampling stride along the sliced axis.
        subband_shape: Target length passed to the output-padding helper.
        sd_interval: ``(start, end)`` bounds of the slice along axis 1.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_stride: int,
        subband_shape: int,
        sd_interval: Tuple[int, int],
    ):
        super().__init__()
        slice_len = sd_interval[1] - sd_interval[0]
        pad = get_convtranspose_output_padding(
            input_shape=slice_len,
            output_shape=subband_shape,
            stride=upsample_stride,
        )
        self.upsample = Upsample(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=upsample_stride,
            output_padding=pad,
        )
        self.sd_interval = sd_interval

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lo, hi = self.sd_interval[0], self.sd_interval[1]
        band = x[:, lo:hi]
        # Upsample works channels-first; swap the last axis in and back out.
        band = self.upsample(band.permute(0, 3, 1, 2))
        return band.permute(0, 2, 3, 1)
class SUBlock(nn.Module):
    """Fuse with a skip connection, then upsample each sub-band separately.

    ``FusionLayer`` merges ``x`` with ``x_skip``. The fused tensor goes through
    one ``SULayer`` per sub-band, and the per-band results are concatenated
    along axis 1.

    Args:
        input_dim: Channel count of ``x`` / ``x_skip``.
        output_dim: Channel count produced by every ``SULayer``.
        upsample_strides: Upsampling stride for each sub-band.
        subband_shapes: Target length for each sub-band.
        sd_intervals: ``(start, end)`` slice bounds along axis 1 for each sub-band.

    Raises:
        ValueError: If the three per-band sequences differ in length.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_strides: List[int],
        subband_shapes: List[int],
        sd_intervals: List[Tuple[int, int]],
    ):
        super().__init__()
        strides = list(upsample_strides)
        shapes = list(subband_shapes)
        intervals = list(sd_intervals)
        # zip() would silently drop the trailing entries of the longer
        # sequences, building fewer SULayers than configured.
        if not (len(strides) == len(shapes) == len(intervals)):
            raise ValueError(
                "upsample_strides, subband_shapes and sd_intervals must have the "
                f"same length, got {len(strides)}, {len(shapes)} and {len(intervals)}"
            )
        self.fusion_layer = FusionLayer(input_dim=input_dim)
        self.su_layers = nn.ModuleList(
            SULayer(
                input_dim=input_dim,
                output_dim=output_dim,
                upsample_stride=uss,
                subband_shape=sbs,
                sd_interval=sdi,
            )
            for uss, sbs, sdi in zip(strides, shapes, intervals)
        )

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        x = self.fusion_layer(x, x_skip)
        # Each SULayer handles its own slice; stitch the bands back together
        # along axis 1.
        return torch.cat([layer(x) for layer in self.su_layers], dim=1)