File size: 2,807 Bytes
30f8290
 
 
 
 
 
 
1b3535e
 
 
30f8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import sys

import torch.nn as nn

sys.path.append(os.getcwd())

from infer.lib.predictors.DJCM.utils import init_bn
from infer.lib.predictors.DJCM.decoder import PE_Decoder, SVS_Decoder
from infer.lib.predictors.DJCM.encoder import ResEncoderBlock, Encoder

class LatentBlocks(nn.Module):
    """Sequential stack of residual encoder blocks used at the latent stage.

    Every stage is a ``ResEncoderBlock(384, 384, n_blocks, None)``: the
    384-channel latent width is preserved and no pooling is applied
    (``None`` pool argument).
    """

    def __init__(self, n_blocks, latent_layers):
        super().__init__()
        # Build `latent_layers` identical 384 -> 384 residual stages.
        stages = [
            ResEncoderBlock(384, 384, n_blocks, None)
            for _ in range(latent_layers)
        ]
        self.latent_blocks = nn.ModuleList(stages)

    def forward(self, x):
        """Apply each latent stage in registration order and return the result."""
        for stage in self.latent_blocks:
            x = stage(x)
        return x

class DJCMM(nn.Module):
    """Joint model: an optional singing-voice-separation (SVS) stage feeding a
    pitch-estimation (PE) encoder/latent/decoder pipeline.

    Args:
        in_channels: channel count of the input spectrogram.
        n_blocks: residual blocks per encoder/decoder stage.
        latent_layers: number of latent ``ResEncoderBlock`` stages.
        svs: when True, run the SVS branch before pitch estimation.
        window_length: analysis window size; frequency bins = window_length // 2 + 1.
        n_class: number of pitch classes emitted by the PE decoder.
    """

    def __init__(
        self,
        in_channels,
        n_blocks,
        latent_layers,
        svs=False,
        window_length=1024,
        n_class=360,
    ):
        super().__init__()

        # BatchNorm over frequency bins; forward() transposes so that the
        # frequency axis lands on dim 1 before this is applied.
        self.bn = nn.BatchNorm2d(window_length // 2 + 1, momentum=0.01)

        # Pitch-estimation branch.
        self.pe_encoder = Encoder(in_channels, n_blocks)
        self.pe_latent = LatentBlocks(n_blocks, latent_layers)
        self.pe_decoder = PE_Decoder(
            n_blocks, window_length=window_length, n_class=n_class
        )

        self.svs = svs
        if self.svs:
            # Optional separation branch with the same encoder/latent topology.
            self.svs_encoder = Encoder(in_channels, n_blocks)
            self.svs_latent = LatentBlocks(n_blocks, latent_layers)
            self.svs_decoder = SVS_Decoder(in_channels, n_blocks)

        init_bn(self.bn)

    def spec(self, x, spec_m):
        """Recombine SVS decoder output with the reference spectrogram.

        The decoder packs 4 values per output channel; only index 0 (a
        sigmoid mask) and index 3 (a linear residual) are consumed here.
        The reference spectrogram is detached so no gradient flows into it.
        """
        batch, channels, frames, bins = x.shape
        x = x.reshape(batch, channels // 4, 4, frames, bins)

        mask = x[:, :, 0, :, :].sigmoid()
        residual = x[:, :, 3, :, :]

        # Masked reference plus residual, clamped to non-negative values.
        return (spec_m.detach() * mask + residual).relu()

    def forward(self, spec):
        # Normalize per frequency bin, then drop the last bin.
        x = self.bn(spec.transpose(1, 3)).transpose(1, 3)[..., :-1]

        if self.svs:
            # Run separation first; feed the cleaned spectrogram to the PE branch.
            latent, skips = self.svs_encoder(x)
            decoded = self.svs_decoder(self.svs_latent(latent), skips)
            # Pad the dropped bin back before masking, then drop it again.
            x = self.spec(nn.functional.pad(decoded, pad=(0, 1)), spec)[..., :-1]

        latent, skips = self.pe_encoder(x)
        return self.pe_decoder(self.pe_latent(latent), skips)