| import os |
| import sys |
|
|
| import torch.nn as nn |
|
|
| sys.path.append(os.getcwd()) |
|
|
| from infer.lib.predictors.DJCM.utils import init_bn |
| from infer.lib.predictors.DJCM.decoder import PE_Decoder, SVS_Decoder |
| from infer.lib.predictors.DJCM.encoder import ResEncoderBlock, Encoder |
|
|
class LatentBlocks(nn.Module):
    """Chain of ``ResEncoderBlock`` modules applied one after another.

    Args:
        n_blocks: Forwarded unchanged as the third positional argument of
            every ``ResEncoderBlock``.
        latent_layers: Number of ``ResEncoderBlock`` instances to create.
            With ``0`` the module holds no blocks and ``forward`` returns
            its input untouched.
    """

    def __init__(self, n_blocks, latent_layers):
        super().__init__()

        # Every block receives 384 as its first two arguments and ``None``
        # as its fourth.
        # NOTE(review): 384 is presumably the in/out channel count produced
        # by the encoder -- confirm against ResEncoderBlock / Encoder.
        stack = []
        for _ in range(latent_layers):
            stack.append(ResEncoderBlock(384, 384, n_blocks, None))

        # ModuleList registers the blocks so their parameters are tracked.
        self.latent_blocks = nn.ModuleList(stack)

    def forward(self, x):
        """Feed ``x`` through each block in order and return the result."""
        out = x
        for block in self.latent_blocks:
            out = block(out)
        return out
|
|
class DJCMM(nn.Module):
    """Encoder / latent / decoder network with an optional "svs" front stage.

    The module always owns a ``pe_*`` branch (``Encoder`` ->
    ``LatentBlocks`` -> ``PE_Decoder``). When ``svs`` is truthy it also owns
    an ``svs_*`` branch (``Encoder`` -> ``LatentBlocks`` -> ``SVS_Decoder``)
    whose output is turned into a new spectrogram by :meth:`spec` before
    being handed to the ``pe_*`` branch.

    Args:
        in_channels: Forwarded to both ``Encoder`` instances and to
            ``SVS_Decoder``.
        n_blocks: Forwarded to every encoder, latent stack and decoder.
        latent_layers: Number of blocks in each ``LatentBlocks`` stack.
        svs: Whether to build and run the ``svs_*`` branch.
        window_length: ``window_length // 2 + 1`` is the feature count of
            the input batch norm; also forwarded to ``PE_Decoder``.
        n_class: Forwarded to ``PE_Decoder``.
    """

    def __init__(
        self,
        in_channels,
        n_blocks,
        latent_layers,
        svs=False,
        window_length=1024,
        n_class=360,
    ):
        super().__init__()

        # The batch norm is applied with the input's last axis moved to
        # dim 1 (see ``forward``), so that axis must have this many entries.
        bn_features = window_length // 2 + 1
        self.bn = nn.BatchNorm2d(bn_features, momentum=0.01)

        # Branch that produces the value returned by ``forward``.
        self.pe_encoder = Encoder(in_channels, n_blocks)
        self.pe_latent = LatentBlocks(n_blocks, latent_layers)
        self.pe_decoder = PE_Decoder(
            n_blocks,
            window_length=window_length,
            n_class=n_class,
        )

        self.svs = svs

        # Optional front branch; its submodules only exist when requested.
        if svs:
            self.svs_encoder = Encoder(in_channels, n_blocks)
            self.svs_latent = LatentBlocks(n_blocks, latent_layers)
            self.svs_decoder = SVS_Decoder(in_channels, n_blocks)

        init_bn(self.bn)

    def spec(self, x, spec_m):
        """Combine a decoder output with a reference spectrogram.

        ``x`` must be 4-D; its dim 1 is split into groups of four. Within
        each group, slot 0 (after a sigmoid) multiplies ``spec_m`` and slot 3
        is added to the product; slots 1 and 2 are not used. The sum is
        passed through a ReLU.

        Args:
            x: Tensor unpacked as ``(bs, c, time_steps, freqs_steps)``;
                ``c`` must be divisible by 4 for the reshape to succeed.
            spec_m: Tensor multiplied by the mask. It is detached, so no
                gradient flows back through it.

        Returns:
            ``relu(spec_m.detach() * sigmoid(slot 0) + slot 3)``.
        """
        batch, channels, n_time, n_freq = x.shape
        grouped = x.reshape(batch, channels // 4, 4, n_time, n_freq)

        mask = grouped[:, :, 0].sigmoid()
        offset = grouped[:, :, 3]

        combined = spec_m.detach() * mask + offset
        return combined.relu()

    def forward(self, spec):
        """Run the network on ``spec`` and return the ``PE_Decoder`` output.

        Args:
            spec: 4-D tensor whose last axis has ``window_length // 2 + 1``
                entries (required by the batch norm).
        """
        # Normalise with the last axis swapped into the channel position,
        # swap it back, then drop the final entry of the last axis.
        normed = self.bn(spec.transpose(1, 3))
        x = normed.transpose(1, 3)[..., :-1]

        if self.svs:
            x, side_tensors = self.svs_encoder(x)
            x = self.svs_latent(x)
            x = self.svs_decoder(x, side_tensors)

            # Pad one element at the end of the last axis so the decoder
            # output lines up with the untrimmed ``spec``, then trim the
            # combined result again.
            x = nn.functional.pad(x, pad=(0, 1))
            x = self.spec(x, spec)[..., :-1]

        x, side_tensors = self.pe_encoder(x)
        return self.pe_decoder(self.pe_latent(x), side_tensors)