from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class Stem(nn.Module):
    """Input stem: an unpadded stride-2 3x3 conv (3 -> 64 channels) and a max-pool.

    Both layers use kernel 3 / stride 2 without padding, so each one maps a
    spatial extent ``s`` to ``floor((s - 3) / 2) + 1``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv + max-pool stem to ``x``."""
        return self.conv(x)


class ResidualBlock(nn.Module):
    """Residual block made of two 3x3 conv + BatchNorm layers.

    A LeakyReLU follows the first conv and the final sum. The shortcut is the
    identity when ``in_channels == out_channels`` and ``stride == 1``;
    otherwise it is a strided 1x1 conv followed by BatchNorm.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first conv (and of the projection shortcut).
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
        )
        self.shortcut = (
            nn.Identity()
            if in_channels == out_channels and stride == 1
            else nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )
        )
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``act(conv2(conv1(x)) + shortcut(x))``."""
        identity = self.shortcut(x)
        x = self.conv1(x)
        x = self.conv2(x)
        # In-place add targets the fresh conv2 output, not the caller's tensor.
        x += identity
        return self.act(x)


class FromZero(nn.Module):
    """ResNet-style image classifier with four stages of two residual blocks.

    Channel widths per stage are 64, 128, 256 and 512; stages 2-4 halve the
    spatial resolution. ``Dropout(0.2)`` is applied to the last feature map
    before global average pooling and the linear classifier.

    Args:
        num_classes: Size of the output logits vector.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # The Sequential wrapper is kept so parameter names ("stem.0. ...")
        # stay compatible with previously saved state dicts.
        self.stem = nn.Sequential(Stem())
        self.layer1 = nn.Sequential(ResidualBlock(64, 64), ResidualBlock(64, 64))
        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128, stride=2), ResidualBlock(128, 128)
        )
        self.layer3 = nn.Sequential(
            ResidualBlock(128, 256, stride=2), ResidualBlock(256, 256)
        )
        self.layer4 = nn.Sequential(
            ResidualBlock(256, 512, stride=2), ResidualBlock(512, 512), nn.Dropout(0.2)
        )
        self.flatten = nn.Flatten()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of 3-channel images to ``num_classes`` logits."""
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class SRResidualBlock(nn.Module):
    """Bottleneck residual block using GroupNorm(8 groups) and SiLU.

    Main path: 1x1 conv -> ``kernel_size`` conv (carries the stride) -> 1x1
    conv, with ``out_channels // 2`` channels in the middle. The shortcut is
    the identity when neither the stride nor the channel count changes, and a
    3x3 conv (padding 1) + GroupNorm otherwise.

    Both ``out_channels`` and ``out_channels // 2`` must be divisible by 8
    because of the GroupNorm layers.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        kernel_size: Kernel size of the middle conv. Odd sizes keep the main
            path aligned with the shortcut.
        stride: Stride of the middle conv and of the projection shortcut.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
    ):
        super().__init__()
        mid_channels = out_channels // 2
        # "Same"-style padding derived from the kernel size. A hard-coded
        # padding of 1 only matches the shortcut's spatial size for 3x3
        # kernels and makes the residual addition fail for e.g. kernel_size=5.
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1),
            nn.GroupNorm(8, mid_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(
                mid_channels,
                mid_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            nn.GroupNorm(8, mid_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1),
            nn.GroupNorm(8, out_channels),
        )
        # Attribute is named ``relu`` but holds a SiLU; the name is kept so
        # external code accessing ``block.relu`` keeps working.
        self.relu = nn.SiLU(inplace=True)
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                ),
                nn.GroupNorm(8, out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``SiLU(main_path(x) + shortcut(x))``."""
        shortcut = self.shortcut(x)
        out = self.conv1(x)
        out += shortcut
        out = self.relu(out)
        return out


class DownBlock(nn.Module):
    """Encoder stage: a stride-1 residual block followed by a stride-2 one.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of both returned feature maps.
        kernel_size: Kernel size forwarded to both ``SRResidualBlock``s.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        self.skip_block = SRResidualBlock(
            in_channels, out_channels, kernel_size=kernel_size, stride=1
        )
        self.downsample_block = SRResidualBlock(
            out_channels, out_channels, kernel_size=kernel_size, stride=2
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(skip, downsampled)``.

        ``skip`` is the stride-1 block's output; ``downsampled`` is the
        stride-2 block applied to ``skip``.
        """
        skip = self.skip_block(x)
        out_down = self.downsample_block(skip)
        return skip, out_down


class SRUp(nn.Module):
    """Upsampling layer: 3x3 conv -> PixelShuffle -> SiLU.

    The conv produces ``out_channels * scale_factor**2`` channels, which
    ``PixelShuffle`` rearranges into ``out_channels`` channels at
    ``scale_factor`` times the spatial resolution.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels after the pixel shuffle.
        scale_factor: Spatial upscaling factor.
    """

    def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * (scale_factor**2),
            kernel_size=3,
            padding=1,
        )
        self.shuffle = nn.PixelShuffle(scale_factor)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` by ``scale_factor``."""
        x = self.conv(x)
        x = self.shuffle(x)
        return self.act(x)


class SRHead(nn.Module):
    """Output head: two (x2 upsample + residual block) steps and a 3x3 conv.

    The two ``SRUp`` layers use their default ``scale_factor=2``, so the head
    upscales its input by 4 in each spatial dimension.

    Args:
        channels: Channel count kept throughout the head.
        out_channels: Number of channels of the final output.
    """

    def __init__(self, channels: int, out_channels: int):
        super().__init__()
        self.sr1 = SRUp(channels, channels)
        self.sr2 = SRUp(channels, channels)
        self.res = SRResidualBlock(channels, channels)
        self.res2 = SRResidualBlock(channels, channels)
        self.head = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply sr1 -> res -> sr2 -> res2 -> head."""
        x = self.sr1(x)
        x = self.res(x)
        x = self.sr2(x)
        x = self.res2(x)
        x = self.head(x)
        return x


class InputInjection(nn.Module):
    """Resize an input bilinearly and project it to ``target_channels``.

    Args:
        in_channels: Number of channels of the tensor being injected.
        target_channels: Number of channels after the 1x1 projection.
    """

    def __init__(self, in_channels: int, target_channels: int):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, target_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, target_size) -> torch.Tensor:
        """Resize ``x`` to ``target_size`` (H, W) and apply the 1x1 conv."""
        x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
        return self.proj(x)


def _crop_to_match(x: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Crop trailing rows/columns of ``x`` to the spatial size of ``reference``.

    Both tensors are indexed as (N, C, H, W). ``x`` is returned unchanged when
    the spatial sizes already agree.
    """
    height, width = reference.shape[2:]
    if x.shape[2] == height and x.shape[3] == width:
        return x
    return x[:, :, :height, :width]


class UNetSR(nn.Module):
    """U-Net style network whose output is 4x the input resolution.

    Three ``DownBlock`` encoder stages (64, 128, 256 channels), a bottleneck,
    three transposed-conv decoder stages with skip concatenation, and an
    ``SRHead``. A resized, 1x1-projected copy of the input is added (scaled by
    ``residual_scaling``) at several points, and a bilinearly upsampled copy
    of the input, also scaled by ``residual_scaling``, is added to the output.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output image. The final
            residual addition requires this to be broadcast-compatible with
            ``in_channels``.
        residual_scaling: Factor applied to every injected / residual term.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        residual_scaling: float = 0.1,
    ):
        super().__init__()
        self.residual_scaling = residual_scaling
        # NOTE(review): input_inject256 is registered but never called in
        # forward(). It is kept so existing checkpoints still load strictly.
        self.input_inject256 = InputInjection(in_channels, 256)
        self.input_inject128 = InputInjection(in_channels, 128)
        self.input_inject64 = InputInjection(in_channels, 64)
        self.down1 = DownBlock(in_channels, 64)
        self.down2 = DownBlock(64, 128)
        self.down3 = DownBlock(128, 256)
        self.bottleneck = SRResidualBlock(256, 256)
        self.pre_up3 = nn.ConvTranspose2d(
            256, 128, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up3 = SRResidualBlock(128 + 256, 128)
        self.pre_up2 = nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up2 = SRResidualBlock(128 + 64, 64)
        self.pre_up1 = nn.ConvTranspose2d(
            64, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up1 = SRResidualBlock(128, 64)
        self.sr_head = SRHead(64, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Super-resolve ``x`` of shape (N, in_channels, H, W) to 4H x 4W.

        The stride-2 encoder convs yield ``ceil(s / 2)`` while the transposed
        convs exactly double, so for sizes not divisible by 8 the upsampled
        map can be one pixel larger than its skip tensor. It is cropped to the
        skip's size before concatenation; for aligned sizes this is a no-op.
        """
        residual = x
        scale = self.residual_scaling

        # Encoder
        x1, x = self.down1(x)
        x = x + scale * self.input_inject64(residual, x.shape[2:])
        x2, x = self.down2(x)
        x = x + scale * self.input_inject128(residual, x.shape[2:])
        x3, x = self.down3(x)

        x = self.bottleneck(x)

        # Decoder
        x = _crop_to_match(self.pre_up3(x), x3)
        x = torch.cat([x, x3], dim=1)
        x = self.up3(x)

        x = _crop_to_match(self.pre_up2(x), x2)
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x)
        # NOTE(review): the same input_inject64 module (shared weights) is
        # reused here and below, as in the encoder; confirm this is intended.
        x = x + scale * self.input_inject64(residual, x.shape[2:])

        x = _crop_to_match(self.pre_up1(x), x1)
        x = torch.cat([x, x1], dim=1)
        x = self.up1(x)
        x = x + scale * self.input_inject64(residual, x.shape[2:])

        out = self.sr_head(x)
        # Global residual: bilinearly upsampled input, scaled.
        out = (
            out
            + F.interpolate(
                residual, size=out.shape[2:], mode="bilinear", align_corners=False
            )
            * scale
        )
        return out