| import torch | |
| import torch.nn as nn | |
| from .upsampling import UpsamplingBlock | |
class SRNetwork(nn.Module):
    """Super-resolution network with a global residual on a bicubic input.

    Architecture, in the order the modules are built:
      1. ``num_layers`` 3x3 convolutions (padding 1), each followed by an
         in-place ReLU. The first maps ``in_channels`` -> 64 channels; every
         later one maps 64 -> 64.
      2. ``UpsamplingBlock(64)`` from ``.upsampling``.
      3. A final 3x3 convolution (padding 1) mapping 64 -> ``out_channels``.
      4. The caller-supplied ``bicubic`` tensor is added to the result, so the
         network predicts a residual on top of that image.

    Args:
        in_channels: Number of channels of the input ``x``. Defaults to 64.
        out_channels: Number of channels produced by ``final_conv`` (and
            expected in ``bicubic``). Defaults to 3.
        num_layers: Number of conv+ReLU pairs in the feature extractor.
            Defaults to 9, the original fixed depth. A different depth changes
            the set of ``conv_layers`` parameters, so checkpoints saved at
            another depth will not load.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, in_channels=64, out_channels=3, num_layers=9):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Build the layers in the same order as the original hand-written
        # Sequential (Conv, ReLU, Conv, ReLU, ...). For the default depth the
        # parameter names (conv_layers.0, conv_layers.2, ..., conv_layers.16),
        # and therefore the state_dict keys, are unchanged.
        layers = []
        channels = in_channels
        for _ in range(num_layers):
            layers.append(nn.Conv2d(channels, 64, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            channels = 64
        self.conv_layers = nn.Sequential(*layers)

        # NOTE(review): final_conv takes 64 input channels, so this assumes
        # UpsamplingBlock(64) preserves the channel count — confirm in
        # .upsampling.
        self.upsampling = UpsamplingBlock(64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    def forward(self, x, bicubic):
        """Run the network and add the bicubic residual.

        Args:
            x: Input tensor with ``in_channels`` channels. Presumably shaped
                (N, in_channels, H, W) — TODO confirm with callers.
            bicubic: Tensor added to the network output. It must match (or be
                broadcastable to) the output of ``final_conv``; presumably
                (N, out_channels, s*H, s*W), where ``s`` is the scale factor of
                ``UpsamplingBlock`` — TODO confirm.

        Returns:
            ``final_conv(upsampling(conv_layers(x))) + bicubic``.
        """
        # Shapes were previously printed here on every call. For debugging,
        # register a forward hook instead of printing in the hot path.
        features = self.conv_layers(x)
        upsampled = self.upsampling(features)
        residual = self.final_conv(upsampled)
        return residual + bicubic