Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from dijkprofile_annotator.utils import extract_img | |
class Double_conv(nn.Module):
    """Three stacked length-preserving 1D convolutions with ReLU, then dropout.

    Layout of ``self.conv``:
        Conv1d(k=3) -> ReLU -> Conv1d(k=5) -> ReLU -> Conv1d(k=7) -> ReLU -> Dropout

    Every convolution uses stride 1 and padding ``(k - 1) // 2``, so the
    sequence length of the input is preserved; only the channel count changes
    (``in_ch`` -> ``out_ch``). Despite the class name there are three
    convolutions, and no pooling happens in this module.
    """

    def __init__(self, in_ch, out_ch, p):
        """
        Args:
            in_ch (int): number of input channels.
            out_ch (int): number of output channels.
            p (float): dropout probability for the trailing ``nn.Dropout``.
        """
        super().__init__()
        # The order/indices of this Sequential define the state_dict keys
        # (conv.0.*, conv.2.*, conv.4.*); keep them stable so previously
        # saved weights still load.
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, 3, padding=1, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, 5, padding=2, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, 7, padding=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p),
        )

    def forward(self, x):
        """Apply the convolution stack.

        Args:
            x: tensor of shape (batch, in_ch, length).

        Returns:
            Tensor of shape (batch, out_ch, length).
        """
        return self.conv(x)
class Conv_down(nn.Module):
    """Encoder stage: ``Double_conv`` feature extraction followed by 1D max-pooling.

    ``forward`` returns both the pooled tensor (fed to the next, deeper stage)
    and the un-pooled features (used by ``Dijknet`` as the skip connection for
    the matching decoder stage).
    """

    def __init__(self, in_ch, out_ch, p):
        """
        Args:
            in_ch (int): number of input channels.
            out_ch (int): number of output channels.
            p (float): dropout probability forwarded to ``Double_conv``.
        """
        super().__init__()
        self.conv = Double_conv(in_ch, out_ch, p)
        # kernel 2 / stride 2 halves the length (floored for odd lengths).
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        """Extract features and downsample.

        Args:
            x: tensor of shape (batch, in_ch, length).

        Returns:
            tuple ``(pooled, features)``: ``features`` has shape
            (batch, out_ch, length); ``pooled`` has shape
            (batch, out_ch, length // 2).
        """
        features = self.conv(x)
        pooled = self.pool(features)
        return pooled, features
class Conv_up(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, ``Double_conv``.

    The upsampled tensor has ``out_ch`` channels and is concatenated along the
    channel axis with the skip tensor, so ``Double_conv(in_ch, out_ch)`` only
    accepts the result when ``out_ch + skip_channels == in_ch``. Every stage in
    ``Dijknet`` satisfies this with ``in_ch == 2 * out_ch``.
    """

    def __init__(self, in_ch, out_ch, p):
        """
        Args:
            in_ch (int): channels of the deeper-stage input, and of the
                concatenated (upsampled + skip) tensor.
            out_ch (int): number of output channels.
            p (float): dropout probability forwarded to ``Double_conv``.
        """
        super().__init__()
        # kernel 2 / stride 2 with no padding exactly doubles the length.
        self.up = nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = Double_conv(in_ch, out_ch, p)

    def forward(self, x1, x2):
        """Upsample ``x1``, merge it with skip tensor ``x2`` and convolve.

        Args:
            x1: tensor from the deeper stage, shape (batch, in_ch, length).
            x2: skip-connection tensor from the matching encoder stage.

        Returns:
            Tensor of shape (batch, out_ch, 2 * length).
        """
        x1 = self.up(x1)
        # NOTE(review): extract_img presumably crops x2 along the length axis
        # to x1's length; torch.cat below requires the lengths to match.
        x2 = extract_img(x1.size(2), x2)
        x1 = torch.cat((x1, x2), dim=1)
        return self.conv(x1)
class Dijknet(nn.Module):
    """Dijknet convolutional neural network. 1D U-Net variant."""

    def __init__(self, in_channels, out_channels, p=0.25):
        """Dijknet convolutional neural network, 1D U-Net variant. Model is probably a bit too big
        for what it needs to do, but it seems to work just fine.

        Args:
            in_channels (int): number of input channels, should be 1
            out_channels (int): number of output channels/classes
            p (float, optional): dropout chance for the dropout layers. Defaults to 0.25.
        """
        super().__init__()
        # Encoder: channels double per stage (64 -> 1024).
        self.Conv_down1 = Conv_down(in_channels, 64, p)
        self.Conv_down2 = Conv_down(64, 128, p)
        self.Conv_down3 = Conv_down(128, 256, p)
        self.Conv_down4 = Conv_down(256, 512, p)
        self.Conv_down5 = Conv_down(512, 1024, p)
        # Decoder: each stage halves the channels; in == 2 * out because the
        # skip connection is concatenated onto the upsampled tensor.
        self.Conv_up1 = Conv_up(1024, 512, p)
        self.Conv_up2 = Conv_up(512, 256, p)
        self.Conv_up3 = Conv_up(256, 128, p)
        self.Conv_up4 = Conv_up(128, 64, p)
        # NOTE: Conv_up5 is never used in forward(). It is kept because
        # removing it would change the state_dict keys and break strict
        # loading of weights that were saved with it.
        self.Conv_up5 = Conv_up(128, 64, p)
        # Two 1x1 convolutions mapping the 64 decoder features to
        # out_channels at each position. With no activation between them they
        # compose to a single linear map; both are kept for weight compatibility.
        self.Conv_out = nn.Conv1d(64, out_channels, 1, padding=0, stride=1)
        self.Conv_final = nn.Conv1d(out_channels, out_channels, 1, padding=0, stride=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (batch, in_channels, length).

        Returns:
            Tensor of shape (batch, out_channels, 16 * (length // 16)). No
            activation is applied to the output. The encoder pools four times
            (each flooring the length) and the decoder doubles the length four
            times. The output therefore matches the input length only when
            ``length`` is a multiple of 16; otherwise it is shorter.
            NOTE(review): inputs shorter than 16 collapse to zero length in the
            encoder; confirm callers always pass at least 16 positions.
        """
        x, conv1 = self.Conv_down1(x)
        x, conv2 = self.Conv_down2(x)
        x, conv3 = self.Conv_down3(x)
        x, conv4 = self.Conv_down4(x)
        # Deepest stage: only the un-pooled features feed the decoder; the
        # pooled output is discarded.
        _, x = self.Conv_down5(x)
        x = self.Conv_up1(x, conv4)
        x = self.Conv_up2(x, conv3)
        x = self.Conv_up3(x, conv2)
        x = self.Conv_up4(x, conv1)
        # Per-position projection to out_channels (1x1 convs, no upscaling).
        x = self.Conv_out(x)
        x = self.Conv_final(x)
        return x