# Source: uploaded by Madhavan86776 — "Upload 32 files" (commit 99c3bcf, verified).
import torch
import torch.nn as nn
import torchvision.models as models
# -------- Residual Block --------
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: ``x + F(x)``.

    ``F`` is conv3x3 -> InstanceNorm(affine) -> ReLU -> conv3x3 ->
    InstanceNorm(affine). Every convolution maps ``channels`` to ``channels``
    with kernel 3, stride 1 and padding 1. The branch therefore keeps both the
    channel count and the spatial size, so it can be added to its input.

    Args:
        channels: number of channels of the input (and of the output).
    """

    def __init__(self, channels):
        super().__init__()
        layers = []
        # Two conv+norm stages; only the first is followed by a ReLU, so the
        # residual branch ends on a normalization layer before the addition.
        for is_final_stage in (False, True):
            layers.append(
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.InstanceNorm2d(channels, affine=True))
            if not is_final_stage:
                # In-place is safe here: it mutates the norm output, not ``x``.
                layers.append(nn.ReLU(inplace=True))
        # Attribute name and Sequential indices (0..4) define the state_dict
        # keys ("block.0.weight", ...), so saved checkpoints depend on them.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        residual = self.block(x)
        return x + residual
# -------- Transformer Network --------
class TransformerNet(nn.Module):
    """Encoder / residual-trunk / decoder image-to-image network.

    Layout of ``self.model``:
      * 9x9 conv (3 -> 32, stride 1), then two 3x3 stride-2 convs
        (32 -> 64 -> 128), each followed by InstanceNorm and ReLU.
      * Five ``ResidualBlock(128)`` units.
      * Two 3x3 stride-2 transposed convs (128 -> 64 -> 32, output_padding 1),
        each followed by InstanceNorm and ReLU.
      * 9x9 conv (32 -> 3) and ``Tanh``, so outputs lie in [-1, 1].

    Spatial size: each stride-2 conv maps ``H`` to ``ceil(H / 2)`` and each
    transposed conv doubles it. The output is therefore ``4 * ceil(H / 4)`` by
    ``4 * ceil(W / 4)``. It matches the input only when H and W are multiples
    of 4.

    NOTE(review): the InstanceNorm layers here use the default
    ``affine=False``, while ``ResidualBlock`` uses ``affine=True``. Changing
    this would alter the state_dict keys and break loading existing weights.
    """

    def __init__(self):
        super().__init__()

        def down_stage(in_ch, out_ch, kernel, stride, padding):
            # Conv -> non-affine InstanceNorm -> ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel, stride, padding),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(),
            ]

        def up_stage(in_ch, out_ch):
            # 2x upsampling: kernel 3, stride 2, padding 1, output_padding 1.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 3, 2, 1, 1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(),
            ]

        # Construction order and flat layout (indices 0..21) match the
        # checkpoint key names "model.<idx>...", so keep them unchanged.
        layers = [
            *down_stage(3, 32, 9, 1, 4),
            *down_stage(32, 64, 3, 2, 1),
            *down_stage(64, 128, 3, 2, 1),
            *(ResidualBlock(128) for _ in range(5)),
            *up_stage(128, 64),
            *up_stage(64, 32),
            nn.Conv2d(32, 3, 9, 1, 4),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (presumably a 3-channel image batch) through the network."""
        return self.model(x)