| |
| |
| |
|
|
| import math |
| from torch import nn |
| import torch.nn.functional as F |
|
|
class SRCNN(nn.Module):
    """Plain convolutional super-resolution network with pixel-shuffle upsampling.

    Layout:
      * head: one 5x5 convolution from ``in_chns`` to ``num_chns`` channels;
      * body: ``depth - 1`` pairs of 5x5 convolution + ``LeakyReLU(0.2)``;
      * tail: ``log2(sf)`` upsampling stages (3x3 convolution to
        ``4 * num_chns`` channels, ``LeakyReLU(0.2)``, ``PixelShuffle(2)``)
        followed by a 5x5 convolution to ``out_chns`` channels.

    Every convolution is padded so that it keeps the spatial size; only the
    ``PixelShuffle`` layers change it (by a factor of two each).

    Args:
        in_chns: Number of channels of the input.
        out_chns: Number of channels of the output; defaults to ``in_chns``.
        num_chns: Number of feature channels used inside the network.
        depth: The body holds ``depth - 1`` convolution/activation pairs.
        sf: Upscaling factor; must be a power of two (1, 2, 4, ...).

    Raises:
        ValueError: If ``sf`` is smaller than 1 or not a power of two.
    """

    def __init__(self, in_chns, out_chns=None, num_chns=64, depth=8, sf=4):
        super().__init__()

        # Each tail stage doubles the resolution, so only powers of two can be
        # realised. Reject everything else instead of silently truncating the
        # exponent (which would build e.g. a 2x network for sf=3).
        num_upsamples = round(math.log2(sf)) if sf >= 1 else -1
        if num_upsamples < 0 or 2 ** num_upsamples != sf:
            raise ValueError(f"sf must be a power of two >= 1, got {sf!r}")

        self.sf = sf
        out_chns = in_chns if out_chns is None else out_chns

        # head
        self.head = nn.Conv2d(in_chns, num_chns, kernel_size=5, padding=2)

        # body
        body = []
        for _ in range(depth - 1):
            body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*body)

        # tail
        tail = []
        for _ in range(num_upsamples):
            # 4x the channels so that PixelShuffle(2) returns to num_chns.
            tail.append(nn.Conv2d(num_chns, num_chns * 4, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
        tail.append(nn.Conv2d(num_chns, out_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Run ``x`` through head, body and tail and return the result."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y
|
|
class SRCNNFSR(nn.Module):
    """Convolutional network that processes a pixel-unshuffled copy of its input.

    Layout, with ``n = log2(down_scale_factor)``:
      * head: ``n`` stages of ``PixelUnshuffle(2)`` + 3x3 convolution to
        ``num_chns`` channels; every stage but the last additionally reduces
        to ``num_chns // 4`` channels with a 5x5 convolution so that the next
        unshuffle yields ``num_chns`` channels again;
      * body: ``depth - 1`` pairs of 5x5 convolution + ``LeakyReLU(0.2)``;
      * tail: ``n`` stages of 3x3 convolution, ``LeakyReLU(0.2)`` and
        ``PixelShuffle(2)`` (each dividing the channel count by four),
        followed by a 5x5 convolution back to ``in_chns`` channels.

    The head and the tail contain the same number of (un)shuffle layers and
    all convolutions are size-preserving, so the output has the same channel
    count and spatial size as the input.

    Args:
        in_chns: Number of channels of the input (and of the output).
        down_scale_factor: Total spatial reduction applied by the head; must
            be a power of two (1, 2, 4, ...).
        num_chns: Number of feature channels in the body. Must be divisible
            by ``4 ** n`` (and always by 4) because every tail stage divides
            the channel count by four.
        depth: The body holds ``depth - 1`` convolution/activation pairs.
        sf: Stored on the instance as ``self.sf``; not otherwise used here.

    Raises:
        ValueError: If ``down_scale_factor`` is smaller than 1 or not a power
            of two, or if ``num_chns`` does not satisfy the divisibility
            requirement above.
    """

    def __init__(self, in_chns, down_scale_factor=2, num_chns=64, depth=8, sf=4):
        super().__init__()

        # Number of 2x (un)shuffle stages, computed once and exactly; a factor
        # that is not a power of two would otherwise be silently truncated.
        if down_scale_factor >= 1:
            num_stages = round(math.log2(down_scale_factor))
        else:
            num_stages = -1
        if num_stages < 0 or 2 ** num_stages != down_scale_factor:
            raise ValueError(
                f"down_scale_factor must be a power of two >= 1, "
                f"got {down_scale_factor!r}"
            )

        # Every PixelShuffle(2) in the tail needs a channel count divisible by
        # four, so num_chns must survive num_stages successive divisions.
        # Checked explicitly (not with assert) so it also holds under -O.
        required = 4 ** max(num_stages, 1)
        if num_chns % required != 0:
            raise ValueError(
                f"num_chns must be divisible by {required} for "
                f"down_scale_factor={down_scale_factor!r}, got {num_chns!r}"
            )

        self.sf = sf

        # head
        head = []
        stage_in_chns = in_chns * 4  # channel count after the first unshuffle
        for stage in range(num_stages):
            head.append(nn.PixelUnshuffle(2))
            head.append(nn.Conv2d(stage_in_chns, num_chns, kernel_size=3, padding=1))
            if stage + 1 < num_stages:
                # Shrink to a quarter so the next unshuffle gives num_chns.
                head.append(nn.Conv2d(num_chns, num_chns // 4, kernel_size=5, padding=2))
            # NOTE(review): the activation is assumed to close every stage,
            # including the last one -- confirm against the intended layout.
            head.append(nn.LeakyReLU(0.2, inplace=True))
            stage_in_chns = num_chns
        self.head = nn.Sequential(*head)

        # body
        body = []
        for _ in range(depth - 1):
            body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*body)

        # tail
        tail = []
        feat_chns = num_chns  # local counter; the parameter itself is left alone
        for _ in range(num_stages):
            tail.append(nn.Conv2d(feat_chns, feat_chns, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
            feat_chns //= 4
        tail.append(nn.Conv2d(feat_chns, in_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Run ``x`` through head, body and tail and return the result."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y
|
|