import torch
import torch.nn as nn
from .upsampling import UpsamplingBlock
class SRNetwork(nn.Module):
    """Super Resolution Network with ESPCN-like backbone.

    The input goes through a stack of nine 3x3 convolutions (each followed
    by an in-place ReLU), then an ``UpsamplingBlock``, then a final 3x3
    convolution. The caller-supplied ``bicubic`` tensor is added to that
    result to form the output.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final convolution.
    """

    # Width of every hidden feature map and depth of the convolution stack.
    _FEATURES = 64
    _NUM_CONV_LAYERS = 9

    def __init__(self, in_channels: int = 64, out_channels: int = 3) -> None:
        super().__init__()
        features = self._FEATURES

        # Build the Conv -> ReLU pairs in the same order as the hand-written
        # list, so positions inside the Sequential (and therefore the
        # parameter names in the state dict) stay the same.
        layers = []
        channels = in_channels
        for _ in range(self._NUM_CONV_LAYERS):
            layers.append(
                nn.Conv2d(channels, features, kernel_size=3, padding=1)
            )
            layers.append(nn.ReLU(inplace=True))
            channels = features  # only the first conv sees in_channels
        self.conv_layers = nn.Sequential(*layers)

        self.upsampling = UpsamplingBlock(features)
        self.final_conv = nn.Conv2d(
            features, out_channels, kernel_size=3, padding=1
        )

    def forward(self, x: torch.Tensor, bicubic: torch.Tensor) -> torch.Tensor:
        """Run the network and add ``bicubic`` to the result.

        Args:
            x: Input tensor fed to the convolution stack.
            bicubic: Tensor added element-wise to the output of the final
                convolution; it must be broadcast-compatible with that
                output. NOTE(review): presumably a bicubic-upsampled copy
                of the low-resolution image -- confirm against the caller.

        Returns:
            The output of the final convolution plus ``bicubic``.
        """
        features = self.conv_layers(x)
        upsampled = self.upsampling(features)
        residual = self.final_conv(upsampled)
        return residual + bicubic