# NOTE(review): the three lines here were web-viewer extraction residue
# (file size, commit hashes, and a line-number gutter), not module source;
# they are commented out so the file remains importable.
import os
import sys
import torch.nn as nn
sys.path.append(os.getcwd())
from infer.lib.predictors.DJCM.utils import init_bn
from infer.lib.predictors.DJCM.decoder import PE_Decoder, SVS_Decoder
from infer.lib.predictors.DJCM.encoder import ResEncoderBlock, Encoder
class LatentBlocks(nn.Module):
    """Bottleneck stack of residual encoder blocks applied sequentially.

    Sits between an ``Encoder`` and a decoder in ``DJCMM``; each block maps
    a feature map back to the same channel width so the stack composes.

    Args:
        n_blocks: number of residual sub-blocks inside each ``ResEncoderBlock``.
        latent_layers: how many ``ResEncoderBlock`` modules to stack.
        channels: feature width at the bottleneck. Defaults to 384, the value
            previously hard-coded, so existing callers are unaffected.
    """

    def __init__(
        self,
        n_blocks,
        latent_layers,
        channels=384
    ):
        super(LatentBlocks, self).__init__()
        # `None` pooling argument: ResEncoderBlock performs no downsampling
        # here, keeping the temporal/frequency resolution fixed through the
        # latent stack (presumably — confirm against ResEncoderBlock).
        self.latent_blocks = nn.ModuleList([
            ResEncoderBlock(
                channels,
                channels,
                n_blocks,
                None
            )
            for _ in range(latent_layers)
        ])

    def forward(self, x):
        """Run ``x`` through each latent block in order and return the result."""
        for layer in self.latent_blocks:
            x = layer(x)
        return x
class DJCMM(nn.Module):
    """Joint model: optional singing-voice separation (SVS) followed by
    pitch estimation (PE), each built as encoder -> latent stack -> decoder.

    Args:
        in_channels: channel count of the input spectrogram tensor.
        n_blocks: residual block depth passed to encoders/decoders.
        latent_layers: number of latent blocks in each ``LatentBlocks`` stack.
        svs: when True, an SVS branch pre-processes the spectrogram before
            the pitch-estimation branch.
        window_length: STFT window size; the BatchNorm runs over
            ``window_length // 2 + 1`` frequency bins.
        n_class: number of pitch classes emitted by the PE decoder.
    """

    def __init__(
        self,
        in_channels,
        n_blocks,
        latent_layers,
        svs=False,
        window_length=1024,
        n_class=360
    ):
        super(DJCMM, self).__init__()

        # Normalize per frequency bin: the input is transposed in forward()
        # so the bin axis lands on the BatchNorm2d channel dimension.
        self.bn = nn.BatchNorm2d(window_length // 2 + 1, momentum=0.01)

        # Pitch-estimation branch (always present).
        self.pe_encoder = Encoder(in_channels, n_blocks)
        self.pe_latent = LatentBlocks(n_blocks, latent_layers)
        self.pe_decoder = PE_Decoder(
            n_blocks,
            window_length=window_length,
            n_class=n_class
        )

        # Optional source-separation branch.
        self.svs = svs
        if svs:
            self.svs_encoder = Encoder(in_channels, n_blocks)
            self.svs_latent = LatentBlocks(n_blocks, latent_layers)
            self.svs_decoder = SVS_Decoder(in_channels, n_blocks)

        init_bn(self.bn)

    def spec(self, x, spec_m):
        """Turn the SVS decoder output into a separated spectrogram.

        ``x`` carries 4 sub-channels per output channel: index 0 is used as
        a sigmoid mask over the (detached) input mixture ``spec_m`` and
        index 3 as an additive linear term; indices 1-2 are unused here.
        Returns ``relu(mask * spec_m + linear)``.
        """
        bs, c, t_steps, f_steps = x.shape
        grouped = x.reshape(bs, c // 4, 4, t_steps, f_steps)
        mask = grouped[:, :, 0, :, :].sigmoid()
        linear = grouped[:, :, 3, :, :]
        # Detach the mixture so mask gradients do not flow into the input.
        masked = spec_m.detach() * mask + linear
        return masked.relu()

    def forward(self, spec):
        """Return pitch-estimation output for spectrogram ``spec``.

        The trailing frequency bin is dropped before each encoder
        (``[..., :-1]``) and padded back (``pad=(0, 1)``) around ``spec``.
        """
        # BatchNorm over frequency bins, then drop the last bin.
        normed = self.bn(spec.transpose(1, 3)).transpose(1, 3)
        x = normed[..., :-1]

        if self.svs:
            # Separate vocals first, then feed the cleaned spectrogram
            # into the pitch-estimation branch.
            h, skips = self.svs_encoder(x)
            h = self.svs_latent(h)
            decoded = self.svs_decoder(h, skips)
            restored = nn.functional.pad(decoded, pad=(0, 1))
            x = self.spec(restored, spec)[..., :-1]

        h, skips = self.pe_encoder(x)
        pe_out = self.pe_decoder(self.pe_latent(h), skips)
        return pe_out