File size: 3,029 Bytes
4f175c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from typing import List, Tuple

import torch
import torch.nn as nn

from models.scnet_unofficial.utils import get_convtranspose_output_padding


class FusionLayer(nn.Module):
    """Fuse a feature map with its skip connection through a gated convolution.

    The two inputs are added element-wise. The last (channel) axis of the sum
    is tiled twice, giving ``2 * input_dim`` channels. A ``Conv2d`` with a
    ``(kernel_size, 1)`` kernel then runs along axis 1. Finally a GLU gate
    halves the channel axis back to ``input_dim``.

    Inputs are expected to be channels-last 4-D tensors. The axis meanings
    are presumably (batch, frequency, time, channels); confirm against callers.

    Args:
        input_dim: Number of channels on the last axis of each input.
        kernel_size: Kernel extent along axis 1 of the input.
        stride: Stride along axis 1 of the input.
        padding: Zero padding along axis 1 of the input.
    """

    def __init__(
        self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
    ):
        super().__init__()
        doubled = 2 * input_dim
        self.conv = nn.Conv2d(
            in_channels=doubled,
            out_channels=doubled,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )
        # GLU splits the last axis in half and gates one half with the other.
        self.activation = nn.GLU(dim=-1)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return the gated fusion of ``x1`` and ``x2`` (channels-last)."""
        merged = (x1 + x2).repeat(1, 1, 1, 2)
        # Conv2d expects channels on axis 1: move the last axis there and back.
        channels_first = merged.permute(0, 3, 1, 2)
        convolved = self.conv(channels_first).permute(0, 2, 3, 1)
        return self.activation(convolved)


class Upsample(nn.Module):
    """Stretch axis 2 of a channels-first tensor with a 1x1 transposed conv.

    With a kernel of 1 and no padding, PyTorch's ConvTranspose2d output-size
    formula gives ``H_out = (H_in - 1) * stride + 1 + output_padding`` along
    axis 2. Axis 3 is left unchanged (stride 1).

    Args:
        input_dim: Channels on axis 1 of the input.
        output_dim: Channels on axis 1 of the output.
        stride: Upsampling factor along axis 2.
        output_padding: Extra rows appended along axis 2 (must be < stride).
    """

    def __init__(
        self, input_dim: int, output_dim: int, stride: int, output_padding: int
    ):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=(stride, 1),
            output_padding=(output_padding, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transposed convolution to channels-first ``x``."""
        upsampled = self.conv(x)
        return upsampled


class SULayer(nn.Module):
    """Upsample one slice of axis 1 of a channels-last feature map.

    ``forward`` does four things:

    1. Takes rows ``sd_interval[0]:sd_interval[1]`` of axis 1.
    2. Moves the last (channel) axis to position 1.
    3. Runs the ``Upsample`` module, which stretches the sliced axis by
       ``upsample_stride``.
    4. Moves the channels back to the end.

    Args:
        input_dim: Channels on the last axis of the incoming tensor.
        output_dim: Channels produced by the upsampling convolution.
        upsample_stride: Stretch factor applied to the sliced axis.
        subband_shape: Only used to derive the transposed-conv
            ``output_padding`` via ``get_convtranspose_output_padding``.
            Presumably the desired length of the sliced axis after
            upsampling; confirm against that helper.
        sd_interval: ``(start, stop)`` bounds of the slice taken on axis 1.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_stride: int,
        subband_shape: int,
        sd_interval: Tuple[int, int],
    ):
        super().__init__()
        slice_length = sd_interval[1] - sd_interval[0]
        padding_rows = get_convtranspose_output_padding(
            input_shape=slice_length,
            output_shape=subband_shape,
            stride=upsample_stride,
        )
        self.upsample = Upsample(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=upsample_stride,
            output_padding=padding_rows,
        )
        self.sd_interval = sd_interval

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Slice axis 1 of ``x``, upsample it, and return it channels-last."""
        lo, hi = self.sd_interval[0], self.sd_interval[1]
        # Upsample works channels-first, so the last axis goes to position 1.
        band = x[:, lo:hi].permute(0, 3, 1, 2)
        return self.upsample(band).permute(0, 2, 3, 1)


class SUBlock(nn.Module):
    """Fuse with a skip connection, then upsample each slice and re-join them.

    A ``FusionLayer`` first combines ``x`` and ``x_skip``. Then one
    ``SULayer`` per entry of the three configuration sequences slices its own
    interval of axis 1 and upsamples it. The per-slice results are
    concatenated back along axis 1.

    Args:
        input_dim: Channels on the last axis of ``x`` / ``x_skip``.
        output_dim: Channels produced by each ``SULayer``.
        upsample_strides: Per-slice upsampling stride.
        subband_shapes: Per-slice value forwarded to ``SULayer`` as
            ``subband_shape``.
        sd_intervals: Per-slice ``(start, stop)`` bounds on axis 1.

    Raises:
        ValueError: If the three per-slice sequences differ in length.
            Previously ``zip`` silently dropped the extra entries.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_strides: List[int],
        subband_shapes: List[int],
        sd_intervals: List[Tuple[int, int]],
    ):
        super().__init__()
        # Materialise so any iterable still works and lengths can be compared.
        upsample_strides = list(upsample_strides)
        subband_shapes = list(subband_shapes)
        sd_intervals = list(sd_intervals)
        if not (len(upsample_strides) == len(subband_shapes) == len(sd_intervals)):
            raise ValueError(
                "upsample_strides, subband_shapes and sd_intervals must have the "
                f"same length, got {len(upsample_strides)}, {len(subband_shapes)} "
                f"and {len(sd_intervals)}"
            )
        self.fusion_layer = FusionLayer(input_dim=input_dim)
        self.su_layers = nn.ModuleList(
            SULayer(
                input_dim=input_dim,
                output_dim=output_dim,
                upsample_stride=uss,
                subband_shape=sbs,
                sd_interval=sdi,
            )
            for uss, sbs, sdi in zip(upsample_strides, subband_shapes, sd_intervals)
        )

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """Fuse ``x`` with ``x_skip`` and return the upsampled slices joined on axis 1."""
        fused = self.fusion_layer(x, x_skip)
        return torch.cat([layer(fused) for layer in self.su_layers], dim=1)