Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
# Number of per-position input features for protein sequence / pocket tensors;
# used as the in_features of DeepDTAF.seq_embed.
PT_FEATURE_SIZE: int = 40
class DeepDTAF(nn.Module):
    """Encoder for a (protein sequence, binding pocket, SMILES) triple.

    Each input goes through its own 1-D convolutional branch ending in a
    global max pool. The three pooled vectors are concatenated, passed through
    dropout, and projected by ``classifier``. There is no final scalar output
    layer in this module, so ``forward`` returns a 64-dim feature vector per
    example.

    Args:
        smi_charset_len: number of rows in the SMILES token embedding; ``smi``
            must hold integer indices in ``[0, smi_charset_len)``.
    """

    def __init__(self, smi_charset_len):
        super().__init__()

        smi_embed_size = 128
        seq_embed_size = 128

        seq_oc = 128
        pkt_oc = 128
        smi_oc = 128

        self.smi_embed = nn.Embedding(smi_charset_len, smi_embed_size)
        # Linear over the last dim: (N, L, PT_FEATURE_SIZE) -> (N, L, seq_embed_size).
        # Shared by the sequence and the pocket branch (see forward).
        self.seq_embed = nn.Linear(PT_FEATURE_SIZE, seq_embed_size)

        # Sequence branch: four dilated residual blocks, then a global max pool.
        conv_seq = []
        ic = seq_embed_size
        for oc in [32, 64, 64, seq_oc]:
            conv_seq.append(DilatedParllelResidualBlockA(ic, oc))
            ic = oc
        conv_seq.append(nn.AdaptiveMaxPool1d(1))  # (N, seq_oc, 1)
        conv_seq.append(Squeeze())  # drop the pooled length dim
        self.conv_seq = nn.Sequential(*conv_seq)

        # Pocket branch: three unpadded kernel-3 Conv1d + BatchNorm + PReLU
        # stages. Each conv shortens the length by 2, so the pocket needs at
        # least 7 positions to survive all three.
        conv_pkt = []
        ic = seq_embed_size
        for oc in [32, 64, pkt_oc]:
            conv_pkt.append(nn.Conv1d(ic, oc, 3))
            conv_pkt.append(nn.BatchNorm1d(oc))
            conv_pkt.append(nn.PReLU())
            ic = oc
        conv_pkt.append(nn.AdaptiveMaxPool1d(1))  # (N, pkt_oc, 1)
        conv_pkt.append(Squeeze())
        self.conv_pkt = nn.Sequential(*conv_pkt)

        # SMILES branch: three dilated residual blocks, then a global max pool.
        conv_smi = []
        ic = smi_embed_size
        for oc in [32, 64, smi_oc]:
            conv_smi.append(DilatedParllelResidualBlockB(ic, oc))
            ic = oc
        conv_smi.append(nn.AdaptiveMaxPool1d(1))  # (N, smi_oc, 1)
        conv_smi.append(Squeeze())
        self.conv_smi = nn.Sequential(*conv_smi)

        self.cat_dropout = nn.Dropout(0.2)

        # Projects the concatenated branch features (seq_oc + pkt_oc + smi_oc)
        # down to 64 dims. There is no Linear(64, 1) output layer here;
        # forward returns these 64-dim features.
        self.classifier = nn.Sequential(
            nn.Linear(seq_oc + pkt_oc + smi_oc, 128),
            nn.Dropout(0.5),
            nn.PReLU(),
            nn.Linear(128, 64),
            nn.Dropout(0.5),
            nn.PReLU(),
        )

    def forward(self, seq, pkt, smi):
        """Encode one batch.

        Args:
            seq: float tensor, assumed (N, L_seq, PT_FEATURE_SIZE).
            pkt: float tensor, assumed (N, L_pkt, PT_FEATURE_SIZE); embedded
                with the same ``seq_embed`` layer as ``seq``.
            smi: integer token indices, assumed (N, L_smi).

        Returns:
            Tensor of shape (N, 64).
        """
        # Embed, then move features to dim 1 for Conv1d: (N, L, 128) -> (N, 128, L).
        seq_embed = torch.transpose(self.seq_embed(seq), 1, 2)
        seq_conv = self.conv_seq(seq_embed)

        pkt_embed = torch.transpose(self.seq_embed(pkt), 1, 2)
        pkt_conv = self.conv_pkt(pkt_embed)

        smi_embed = torch.transpose(self.smi_embed(smi), 1, 2)
        smi_conv = self.conv_smi(smi_embed)

        # A Squeeze that removes every size-1 dim turns a (1, C, 1) pooled map
        # into (C,), dropping the batch axis. Restore it so concatenation along
        # dim 1 also works for batch size 1.
        branches = [
            out.unsqueeze(0) if out.dim() == 1 else out
            for out in (seq_conv, pkt_conv, smi_conv)
        ]
        cat = torch.cat(branches, dim=1)  # (N, seq_oc + pkt_oc + smi_oc)
        cat = self.cat_dropout(cat)
        return self.classifier(cat)
class Squeeze(nn.Module):
    """Remove a size-1 dimension from a tensor.

    Intended for use after ``nn.AdaptiveMaxPool1d(1)`` to turn ``(N, C, 1)``
    into ``(N, C)``. Only the chosen dimension is removed, so a batch of size
    1 keeps its batch axis. (Squeezing every size-1 dim would yield ``(C,)``
    there and break a later ``torch.cat(..., dim=1)``.)

    Args:
        dim: dimension to squeeze; defaults to ``-1`` (the trailing dim).
            Pass ``None`` to squeeze every size-1 dimension.
    """

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, input: torch.Tensor):
        # getattr: instances unpickled from before ``dim`` existed lack the attribute.
        dim = getattr(self, "dim", -1)
        if dim is None:
            return input.squeeze()
        # torch.squeeze(dim) is a no-op if that dimension's size is not 1.
        return input.squeeze(dim)
class CDilated(nn.Module):
    """Bias-free 1-D convolution with a configurable dilation factor.

    Padding is ``d * int((kSize - 1) / 2)``. With ``stride == 1`` and an odd
    ``kSize``, this keeps the output length equal to the input length.

    Args:
        nIn: number of input channels.
        nOut: number of output channels.
        kSize: kernel size.
        stride: convolution stride.
        d: dilation factor.
    """

    def __init__(self, nIn, nOut, kSize, stride=1, d=1):
        super().__init__()
        half_kernel = int((kSize - 1) / 2)
        self.conv = nn.Conv1d(
            nIn,
            nOut,
            kSize,
            stride=stride,
            padding=d * half_kernel,
            bias=False,
            dilation=d,
        )

    def forward(self, input):
        return self.conv(input)
class DilatedParllelResidualBlockA(nn.Module):
    """Residual block with five parallel dilated convolutions (dilations 1-16).

    Data flow:
      1. A 1x1 conv reduces the input to ``nOut // 5`` channels (then
         BatchNorm + PReLU).
      2. Five kernel-3 ``CDilated`` branches run on that reduced map, with
         dilations 1, 2, 4, 8 and 16.
      3. The outputs of the dilation-2..16 branches are summed cumulatively
         (hierarchical fusion) and concatenated after the dilation-1 branch.
      4. ``d1`` carries the remainder ``nOut - 4 * (nOut // 5)`` channels, so
         the concatenation has exactly ``nOut`` channels.

    The input is added back before the final BatchNorm + PReLU only when
    ``add`` is truthy and ``nIn == nOut``.

    Args:
        nIn: number of input channels.
        nOut: number of output channels.
        add: request the identity shortcut (ignored when nIn != nOut).
    """

    def __init__(self, nIn, nOut, add=True):
        super().__init__()
        width = nOut // 5
        first_width = nOut - 4 * width
        self.c1 = nn.Conv1d(nIn, width, 1, padding=0)
        self.br1 = nn.Sequential(nn.BatchNorm1d(width), nn.PReLU())
        self.d1 = CDilated(width, first_width, 3, 1, 1)  # dilation 2**0
        self.d2 = CDilated(width, width, 3, 1, 2)  # dilation 2**1
        self.d4 = CDilated(width, width, 3, 1, 4)  # dilation 2**2
        self.d8 = CDilated(width, width, 3, 1, 8)  # dilation 2**3
        self.d16 = CDilated(width, width, 3, 1, 16)  # dilation 2**4
        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())
        # The identity shortcut only fits when channel counts match.
        self.add = add if nIn == nOut else False

    def forward(self, input):
        reduced = self.br1(self.c1(input))

        pieces = [self.d1(reduced)]
        # Hierarchical fusion: each branch is added onto the running sum of
        # the previous ones, which reduces gridding artifacts.
        running = None
        for branch in (self.d2, self.d4, self.d8, self.d16):
            out = branch(reduced)
            running = out if running is None else running + out
            pieces.append(running)

        merged = torch.cat(pieces, 1)
        if self.add:
            merged = input + merged
        return self.br2(merged)
class DilatedParllelResidualBlockB(nn.Module):
    """Residual block with four parallel dilated convolutions (dilations 1-8).

    Data flow:
      1. A 1x1 conv reduces the input to ``nOut // 4`` channels (then
         BatchNorm + PReLU).
      2. Four kernel-3 ``CDilated`` branches run on that reduced map, with
         dilations 1, 2, 4 and 8.
      3. The outputs of the dilation-2..8 branches are summed cumulatively
         (hierarchical fusion) and concatenated after the dilation-1 branch.
      4. ``d1`` carries the remainder ``nOut - 3 * (nOut // 4)`` channels, so
         the concatenation has exactly ``nOut`` channels.

    The input is added back before the final BatchNorm + PReLU only when
    ``add`` is truthy and ``nIn == nOut``.

    Args:
        nIn: number of input channels.
        nOut: number of output channels.
        add: request the identity shortcut (ignored when nIn != nOut).
    """

    def __init__(self, nIn, nOut, add=True):
        super().__init__()
        width = nOut // 4
        first_width = nOut - 3 * width
        self.c1 = nn.Conv1d(nIn, width, 1, padding=0)
        self.br1 = nn.Sequential(nn.BatchNorm1d(width), nn.PReLU())
        self.d1 = CDilated(width, first_width, 3, 1, 1)  # dilation 2**0
        self.d2 = CDilated(width, width, 3, 1, 2)  # dilation 2**1
        self.d4 = CDilated(width, width, 3, 1, 4)  # dilation 2**2
        self.d8 = CDilated(width, width, 3, 1, 8)  # dilation 2**3
        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())
        # The identity shortcut only fits when channel counts match.
        self.add = add if nIn == nOut else False

    def forward(self, input):
        reduced = self.br1(self.c1(input))

        pieces = [self.d1(reduced)]
        # Hierarchical fusion: each branch is added onto the running sum of
        # the previous ones, which reduces gridding artifacts.
        running = None
        for branch in (self.d2, self.d4, self.d8):
            out = branch(reduced)
            running = out if running is None else running + out
            pieces.append(running)

        merged = torch.cat(pieces, 1)
        if self.add:
            merged = input + merged
        return self.br2(merged)