import torch.nn as nn


class SubpixelUpsampling(nn.Module):
    """Subpixel upsampling: a 3x3 convolution followed by ``nn.PixelShuffle``.

    The convolution expands the channel count from ``in_channels`` to
    ``in_channels * scale_factor ** 2``; the pixel shuffle then folds those
    extra channels into the spatial dimensions. The output therefore has the
    same number of channels as the input, with height and width multiplied
    by ``scale_factor``.

    Args:
        in_channels: Number of channels of the input (and of the output).
        scale_factor: Spatial upscaling factor applied to height and width.

    Raises:
        ValueError: If ``scale_factor`` is smaller than 1.
    """

    def __init__(self, in_channels: int, scale_factor: int = 2) -> None:
        super().__init__()
        # Fail at construction time: PixelShuffle only accepts a positive
        # factor, and the error would otherwise surface at forward time.
        if scale_factor < 1:
            raise ValueError(
                f"scale_factor must be a positive integer, got {scale_factor!r}"
            )
        self.scale_factor = scale_factor
        # kernel_size=3 with padding=1 (default stride 1) keeps H and W
        # unchanged; only the channel count grows, by scale_factor ** 2.
        self.conv = nn.Conv2d(
            in_channels,
            in_channels * scale_factor ** 2,
            kernel_size=3,
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        """Apply the convolution, then the pixel shuffle, and return the result.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels, e.g. shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor with ``in_channels`` channels and spatial size
            ``(H * scale_factor, W * scale_factor)``.
        """
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        return x


class UpsamplingBlock(nn.Module):
    """Block for 4x upsampling using two chained SubpixelUpsampling modules.

    Each stage uses the default ``scale_factor`` of 2 and preserves the
    channel count, so both stages are built with the same ``in_channels``.

    Args:
        in_channels: Number of channels of the input (and of the output).
    """

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.upsample1 = SubpixelUpsampling(in_channels)
        self.upsample2 = SubpixelUpsampling(in_channels)

    def forward(self, x):
        """Run both 2x stages in sequence and return the upsampled tensor.

        Args:
            x: Input tensor with ``in_channels`` channels, e.g. shape
                ``(N, in_channels, H, W)``.

        Returns:
            Tensor with ``in_channels`` channels and spatial size
            ``(4 * H, 4 * W)``.
        """
        x = self.upsample1(x)
        x = self.upsample2(x)
        return x