| import os | |
| import sys | |
| import torch.nn as nn | |
| sys.path.append(os.getcwd()) | |
| from infer.lib.predictors.DJCM.utils import init_bn | |
| from infer.lib.predictors.DJCM.decoder import PE_Decoder, SVS_Decoder | |
| from infer.lib.predictors.DJCM.encoder import ResEncoderBlock, Encoder | |
class LatentBlocks(nn.Module):
    """Sequential stack of ``ResEncoderBlock`` layers used between encoder and decoder.

    Each layer is constructed as ``ResEncoderBlock(channels, channels, n_blocks, None)``
    (the first two arguments are presumably input/output channels — confirm in
    ``encoder.py``), and ``forward`` feeds the output of one layer into the next.

    Args:
        n_blocks: forwarded unchanged as the third positional argument of every
            ``ResEncoderBlock`` (presumably its number of residual units; confirm).
        latent_layers: number of ``ResEncoderBlock`` layers to stack. ``0`` yields
            an empty stack whose ``forward`` returns its input unchanged.
        channels: channel count passed as the first two arguments of every layer.
            Defaults to 384, the value previously hard-coded here, so existing
            callers and checkpoints are unaffected. It presumably has to match the
            channel count produced by the preceding encoder.
    """

    def __init__(self, n_blocks, latent_layers, channels=384):
        super().__init__()
        # The fourth argument stays ``None`` as before; its meaning is defined by
        # ResEncoderBlock (NOTE(review): possibly "no pooling" — confirm).
        self.latent_blocks = nn.ModuleList(
            [
                ResEncoderBlock(channels, channels, n_blocks, None)
                for _ in range(latent_layers)
            ]
        )

    def forward(self, x):
        """Apply every latent layer to ``x`` in order and return the final output."""
        for layer in self.latent_blocks:
            x = layer(x)
        return x
class DJCMM(nn.Module):
    """Network with a "PE" encoder/latent/decoder path and an optional "SVS" front end.

    ``forward`` first batch-normalises the input along its last axis and drops the
    final entry of that axis. When ``svs`` is true, an SVS encoder/latent/decoder
    pass follows, and its output is turned into a masked-plus-additive refinement
    of the raw input via ``spec``. The PE encoder/latent/decoder path runs last.
    ("PE" and "SVS" presumably mean pitch estimation and singing-voice
    separation — inferred from names only.)

    Args:
        in_channels: forwarded to both encoders and to the SVS decoder.
        n_blocks: forwarded to every encoder, latent stack and decoder.
        latent_layers: depth of each ``LatentBlocks`` stack.
        svs: when true, build and run the SVS branch before the PE branch.
        window_length: the BatchNorm feature count is ``window_length // 2 + 1``,
            which must equal the size of the input's last axis (presumably the
            number of STFT bins for an FFT size of ``window_length`` — confirm
            with the caller). Also forwarded to ``PE_Decoder``.
        n_class: forwarded to ``PE_Decoder`` (presumably the number of output classes).
    """

    def __init__(
        self,
        in_channels,
        n_blocks,
        latent_layers,
        svs=False,
        window_length=1024,
        n_class=360,
    ):
        super().__init__()
        n_bins = window_length // 2 + 1

        # Keep the registration order below: it fixes the order of entries in
        # parameters() / state_dict().
        self.bn = nn.BatchNorm2d(n_bins, momentum=0.01)
        self.pe_encoder = Encoder(in_channels, n_blocks)
        self.pe_latent = LatentBlocks(n_blocks, latent_layers)
        self.pe_decoder = PE_Decoder(
            n_blocks, window_length=window_length, n_class=n_class
        )

        self.svs = svs
        if self.svs:
            self.svs_encoder = Encoder(in_channels, n_blocks)
            self.svs_latent = LatentBlocks(n_blocks, latent_layers)
            self.svs_decoder = SVS_Decoder(in_channels, n_blocks)

        init_bn(self.bn)

    def spec(self, x, spec_m):
        """Build a refined spectrogram from a 4-D decoder output and a reference.

        The channel axis of ``x`` (dim 1) is read as consecutive groups of 4; its
        size must be a multiple of 4 or the reshape fails. In each group, slot 0
        goes through a sigmoid to form a multiplicative mask and slot 3 is used as
        an additive term; slots 1 and 2 are ignored. The result is
        ``relu(spec_m.detach() * mask + additive)``, so no gradient reaches
        ``spec_m``.

        Args:
            x: 4-D tensor, ``(batch, 4 * k, time, freq)`` per the original axis naming.
            spec_m: reference tensor broadcastable against ``(batch, k, time, freq)``.

        Returns:
            The broadcast result, ``(batch, k, time, freq)`` when ``spec_m`` has that shape.
        """
        batch, channels, n_frames, n_bins = x.shape
        groups = x.reshape(batch, channels // 4, 4, n_frames, n_bins)
        mask = groups[:, :, 0].sigmoid()
        additive = groups[:, :, 3]
        return nn.functional.relu(spec_m.detach() * mask + additive)

    def forward(self, spec):
        """Run the (optional) SVS branch followed by the PE branch.

        Args:
            spec: 4-D tensor whose last axis has ``window_length // 2 + 1`` entries
                (required by ``self.bn``, which sees the tensor with axes 1 and 3 swapped).

        Returns:
            Whatever ``self.pe_decoder`` returns.
        """
        # BatchNorm2d normalises dim 1, so swap the last axis into place and back.
        normed = self.bn(spec.transpose(1, 3)).transpose(1, 3)
        # Drop the final entry of the last axis before encoding.
        features = normed[..., :-1]

        if self.svs:
            latent, skips = self.svs_encoder(features)
            decoded = self.svs_decoder(self.svs_latent(latent), skips)
            # Re-append one zero entry on the last axis (presumably so it lines up
            # with the raw `spec` again), refine, then drop that entry once more.
            restored = nn.functional.pad(decoded, pad=(0, 1))
            features = self.spec(restored, spec)[..., :-1]

        latent, skips = self.pe_encoder(features)
        return self.pe_decoder(self.pe_latent(latent), skips)