import torch
import torch.nn as nn

from .upsampling import UpsamplingBlock


class SRNetwork(nn.Module):
    """Super-resolution network: conv feature extractor, upsampling, global residual.

    Pipeline:

    1. ``num_layers`` 3x3 convolutions (padding 1, so spatial size is kept),
       each followed by an in-place ReLU. The first maps ``in_channels`` to
       ``num_features``; the remaining ones keep ``num_features`` channels.
    2. ``UpsamplingBlock(num_features)`` from ``.upsampling``. Its scale
       factor is not visible here; ``final_conv`` assumes it still outputs
       ``num_features`` channels (TODO confirm against ``UpsamplingBlock``).
    3. A 3x3 convolution projecting to ``out_channels``.
    4. The result is added to ``bicubic`` (global residual), so the learned
       path produces a correction on top of that tensor.

    Args:
        in_channels: Number of channels of the input tensor ``x``.
        out_channels: Number of channels produced by ``final_conv``.
        num_features: Width (channel count) of the hidden feature maps.
        num_layers: Number of conv+ReLU pairs in the feature extractor.
            The defaults reproduce the original 9-layer, 64-channel model,
            including identical ``state_dict`` keys and initialization order.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, in_channels=64, out_channels=3, num_features=64, num_layers=9):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Build conv/ReLU pairs in the same order as the original hand-written
        # Sequential, so module indices (conv_layers.0, .1, ...) and the order
        # in which parameters are initialized are unchanged.
        layers = []
        channels = in_channels
        for _ in range(num_layers):
            layers.append(nn.Conv2d(channels, num_features, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            channels = num_features
        self.conv_layers = nn.Sequential(*layers)

        self.upsampling = UpsamplingBlock(num_features)
        self.final_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, x, bicubic):
        """Run the network and add the result to ``bicubic``.

        Args:
            x: Input tensor with ``in_channels`` channels; presumably shaped
                (N, C, H, W) — TODO confirm against caller.
            bicubic: Tensor added to the output of ``final_conv``; presumably
                the bicubic-upsampled low-resolution image at the target
                resolution. It must be broadcast-compatible with that output,
                otherwise the addition raises.

        Returns:
            ``final_conv(upsampling(conv_layers(x))) + bicubic``.
        """
        # The former per-call print() shape diagnostics were removed: they ran
        # on every forward pass and flooded stdout during training/inference.
        features = self.conv_layers(x)
        upsampled = self.upsampling(features)
        residual = self.final_conv(upsampled)
        return residual + bicubic