| import torch | |
| import torch.nn as nn | |
| from network.architecture import * | |
class Decom(nn.Module):
    """Convolutional network whose 4-channel output is split into ``(R, L)``.

    NOTE(review): by name this looks like a Retinex-style decomposition
    (R = reflectance, L = illumination); that reading is inferred from the
    identifiers only -- confirm against the training code.

    The body is four ``get_conv2d_layer`` calls (k=3, s=1, p=1) mapping
    3 -> 32 -> 32 -> 32 -> 4 channels. The hidden layers use LeakyReLU(0.2)
    and the final layer uses ReLU, so both outputs are non-negative.
    ``get_conv2d_layer`` comes from ``network.architecture``; presumably it
    builds a Conv2d with those kernel/stride/padding values -- verify there.
    """

    def __init__(self):
        super().__init__()
        # (in_c, out_c) for each convolution, in network order.
        channel_plan = [(3, 32), (32, 32), (32, 32), (32, 4)]
        last_idx = len(channel_plan) - 1

        stages = []
        for idx, (c_in, c_out) in enumerate(channel_plan):
            stages.append(get_conv2d_layer(in_c=c_in, out_c=c_out, k=3, s=1, p=1))
            if idx == last_idx:
                # Final activation clamps both outputs to be >= 0.
                stages.append(nn.ReLU())
            else:
                stages.append(nn.LeakyReLU(0.2, inplace=True))

        # Layers are created and registered in the same order as a literal
        # nn.Sequential(...) would be, so parameter names (decom.<idx>.*) and
        # therefore checkpoint compatibility are unchanged.
        self.decom = nn.Sequential(*stages)

    def forward(self, input):
        """Run the network and split its output along dim 1.

        Args:
            input: tensor with 3 channels along dim 1. The 4-index slicing
                below requires a 4-D output, so presumably (N, C, H, W) --
                confirm with callers.

        Returns:
            Tuple ``(R, L)``: output channels 0-2 and output channel 3
            respectively, each keeping the channel dimension.
        """
        features = self.decom(input)
        reflectance = features[:, :3, :, :]
        illumination = features[:, 3:4, :, :]
        return reflectance, illumination