# SeeSharp / ersvr/models/sr_network.py
# Uploaded by Abhinavexists
# Upload folder using huggingface_hub
# Commit 5b9bb29 (verified)
import torch
import torch.nn as nn
from .upsampling import UpsamplingBlock
class SRNetwork(nn.Module):
    """Super Resolution Network with ESPCN-like backbone.

    The network is a stack of nine 3x3 convolutions (each followed by a
    ReLU), an ``UpsamplingBlock``, and a final 3x3 convolution that maps
    the 64 feature channels to ``out_channels``. The tensor passed as
    ``bicubic`` is added to the output as a residual.
    """

    def __init__(self, in_channels: int = 64, out_channels: int = 3) -> None:
        """Build the convolutional stack, upsampler and output projection.

        Args:
            in_channels: Number of channels of the feature map fed to
                ``forward``.
            out_channels: Number of channels produced by the final
                convolution.
        """
        super().__init__()

        hidden_channels = 64
        num_conv_layers = 9

        # Layers are created in the same order and at the same indices as a
        # hand-written Sequential (conv at 0, 2, 4, ...; ReLU at 1, 3, 5, ...)
        # so parameter names in existing checkpoints keep matching.
        layers = []
        channels = in_channels
        for _ in range(num_conv_layers):
            layers.append(
                nn.Conv2d(channels, hidden_channels, kernel_size=3, padding=1)
            )
            layers.append(nn.ReLU(inplace=True))
            channels = hidden_channels
        self.conv_layers = nn.Sequential(*layers)

        self.upsampling = UpsamplingBlock(hidden_channels)
        self.final_conv = nn.Conv2d(
            hidden_channels, out_channels, kernel_size=3, padding=1
        )

    def forward(self, x: torch.Tensor, bicubic: torch.Tensor) -> torch.Tensor:
        """Run the network and add the ``bicubic`` residual.

        Args:
            x: Input feature map with ``in_channels`` channels.
            bicubic: Tensor added element-wise to the network output. It
                must be broadcast-compatible with the output of the final
                convolution (presumably a bicubic-upscaled version of the
                low-resolution image -- confirm against the caller).

        Returns:
            The output of the final convolution plus ``bicubic``.
        """
        features = self.conv_layers(x)
        upscaled = self.upsampling(features)
        residual = self.final_conv(upscaled)
        # No shape printing here: forward runs once per batch and must not
        # write to stdout.
        return residual + bicubic