| |
| |
| |
| |
| |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import torch.nn.init as init |
| import torch.nn.functional as F |
| class HITVPCTeam: |
| r""" |
| DWT and IDWT block written by: Yue Cao |
| """ |
| class CALayer(nn.Module): |
| def __init__(self, channel=64, reduction=16): |
| super(HITVPCTeam.CALayer, self).__init__() |
|
|
| self.avg_pool = nn.AdaptiveAvgPool2d(1) |
| self.conv_du = nn.Sequential( |
| nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True), |
| nn.ReLU(inplace=True), |
| nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True), |
| nn.Sigmoid() |
| ) |
|
|
| def forward(self, x): |
| y = self.avg_pool(x) |
| y = self.conv_du(y) |
| return x * y |
|
|
| |
| class RB(nn.Module): |
| def __init__(self, filters): |
| super(HITVPCTeam.RB, self).__init__() |
| self.conv1 = nn.Conv2d(filters, filters, 3, 1, 1) |
| self.act = nn.PReLU() |
| self.conv2 = nn.Conv2d(filters, filters, 3, 1, 1) |
| self.cuca = HITVPCTeam.CALayer(channel=filters) |
|
|
| def forward(self, x): |
| c0 = x |
| x = self.conv1(x) |
| x = self.act(x) |
| x = self.conv2(x) |
| out = self.cuca(x) |
| return out + c0 |
|
|
| class NRB(nn.Module): |
| def __init__(self, n, f): |
| super(HITVPCTeam.NRB, self).__init__() |
| nets = [] |
| for i in range(n): |
| nets.append(HITVPCTeam.RB(f)) |
| self.body = nn.Sequential(*nets) |
| self.tail = nn.Conv2d(f, f, 3, 1, 1) |
|
|
| def forward(self, x): |
| return x + self.tail(self.body(x)) |
|
|
| class DWTForward(nn.Module): |
| def __init__(self): |
| super(HITVPCTeam.DWTForward, self).__init__() |
| ll = np.array([[0.5, 0.5], [0.5, 0.5]]) |
| lh = np.array([[-0.5, -0.5], [0.5, 0.5]]) |
| hl = np.array([[-0.5, 0.5], [-0.5, 0.5]]) |
| hh = np.array([[0.5, -0.5], [-0.5, 0.5]]) |
| filts = np.stack([ll[None,::-1,::-1], lh[None,::-1,::-1], |
| hl[None,::-1,::-1], hh[None,::-1,::-1]], |
| axis=0) |
| self.weight = nn.Parameter( |
| torch.tensor(filts).to(torch.get_default_dtype()), |
| requires_grad=False) |
| def forward(self, x): |
| C = x.shape[1] |
| filters = torch.cat([self.weight,] * C, dim=0) |
| y = F.conv2d(x, filters, groups=C, stride=2) |
| return y |
|
|
| class DWTInverse(nn.Module): |
| def __init__(self): |
| super(HITVPCTeam.DWTInverse, self).__init__() |
| ll = np.array([[0.5, 0.5], [0.5, 0.5]]) |
| lh = np.array([[-0.5, -0.5], [0.5, 0.5]]) |
| hl = np.array([[-0.5, 0.5], [-0.5, 0.5]]) |
| hh = np.array([[0.5, -0.5], [-0.5, 0.5]]) |
| filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1], |
| hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]], |
| axis=0) |
| self.weight = nn.Parameter( |
| torch.tensor(filts).to(torch.get_default_dtype()), |
| requires_grad=False) |
|
|
| def forward(self, x): |
| C = int(x.shape[1] / 4) |
| filters = torch.cat([self.weight, ] * C, dim=0) |
| y = F.conv_transpose2d(x, filters, groups=C, stride=2) |
| return y |
|
|
|
|
| class Net(nn.Module): |
| def __init__(self, channels=1, filters_level1=96, filters_level2=256//2, filters_level3=256//2, n_rb=4*5): |
| super(Net, self).__init__() |
|
|
| self.head = HITVPCTeam.DWTForward() |
|
|
| self.down1 = nn.Sequential( |
| nn.Conv2d(channels * 4, filters_level1, 3, 1, 1), |
| nn.PReLU(), |
| HITVPCTeam.NRB(n_rb, filters_level1)) |
|
|
| |
| |
|
|
| |
| self.down2 = nn.Sequential( |
| HITVPCTeam.DWTForward(), |
| nn.Conv2d(filters_level1 * 4, filters_level2, 3, 1, 1), |
| nn.PReLU(), |
| HITVPCTeam.NRB(n_rb, filters_level2)) |
|
|
| self.down3 = nn.Sequential( |
| HITVPCTeam.DWTForward(), |
| nn.Conv2d(filters_level2 * 4, filters_level3, 3, 1, 1), |
| nn.PReLU()) |
|
|
| self.middle = HITVPCTeam.NRB(n_rb, filters_level3) |
|
|
| self.up1 = nn.Sequential( |
| nn.Conv2d(filters_level3, filters_level2 * 4, 3, 1, 1), |
| nn.PReLU(), |
| HITVPCTeam.DWTInverse()) |
|
|
| self.up2 = nn.Sequential( |
| HITVPCTeam.NRB(n_rb, filters_level2), |
| nn.Conv2d(filters_level2, filters_level1 * 4, 3, 1, 1), |
| nn.PReLU(), |
| HITVPCTeam.DWTInverse()) |
|
|
| self.up3 = nn.Sequential( |
| HITVPCTeam.NRB(n_rb, filters_level1), |
| nn.Conv2d(filters_level1, channels * 4, 3, 1, 1)) |
|
|
| self.tail = HITVPCTeam.DWTInverse() |
|
|
| def forward(self, inputs): |
| c0 = inputs |
| c1 = self.head(c0) |
| c2 = self.down1(c1) |
| c3 = self.down2(c2) |
| c4 = self.down3(c3) |
| m = self.middle(c4) |
| c5 = self.up1(m) + c3 |
| c6 = self.up2(c5) + c2 |
| c7 = self.up3(c6) + c1 |
| return self.tail(c7) |
|
|
| def _initialize_weights(self): |
| for m in self.modules(): |
| if isinstance(m, nn.Conv2d): |
| init.orthogonal_(m.weight) |
| print('init weight') |
| if m.bias is not None: |
| init.constant_(m.bias, 0) |
| elif isinstance(m, nn.BatchNorm2d): |
| init.constant_(m.weight, 1) |
| init.constant_(m.bias, 0) |