import torch import torch.nn as nn import torchvision.models as models # -------- Residual Block -------- class ResidualBlock(nn.Module): def __init__(self, channels): super().__init__() self.block = nn.Sequential( nn.Conv2d(channels, channels, 3, 1, 1), nn.InstanceNorm2d(channels, affine=True), nn.ReLU(inplace=True), nn.Conv2d(channels, channels, 3, 1, 1), nn.InstanceNorm2d(channels, affine=True), ) def forward(self, x): return x + self.block(x) # -------- Transformer Network -------- class TransformerNet(nn.Module): def __init__(self): super().__init__() self.model = nn.Sequential( nn.Conv2d(3, 32, 9, 1, 4), nn.InstanceNorm2d(32), nn.ReLU(), nn.Conv2d(32, 64, 3, 2, 1), nn.InstanceNorm2d(64), nn.ReLU(), nn.Conv2d(64, 128, 3, 2, 1), nn.InstanceNorm2d(128), nn.ReLU(), ResidualBlock(128), ResidualBlock(128), ResidualBlock(128), ResidualBlock(128), ResidualBlock(128), nn.ConvTranspose2d(128, 64, 3, 2, 1, 1), nn.InstanceNorm2d(64), nn.ReLU(), nn.ConvTranspose2d(64, 32, 3, 2, 1, 1), nn.InstanceNorm2d(32), nn.ReLU(), nn.Conv2d(32, 3, 9, 1, 4), nn.Tanh() ) def forward(self, x): return self.model(x)