Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
# -------- Residual Block --------
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: ``x + F(x)``.

    ``F`` is conv3x3 -> affine InstanceNorm -> ReLU -> conv3x3 -> affine
    InstanceNorm, with stride 1 and padding 1 so that ``F(x)`` has the same
    shape as ``x`` and the skip connection can be added directly.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels, affine=True),
            # In-place is safe here: it rewrites the norm's output, never x.
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels, affine=True),
        ]
        # The attribute name and layer order define the state_dict keys
        # (block.0.*, block.1.*, ...); keep them stable for checkpoints.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        residual = self.block(x)
        return x + residual
# -------- Transformer Network --------
class TransformerNet(nn.Module):
    """Image-to-image network: conv encoder, residual trunk, transposed-conv decoder.

    Maps a 3-channel input to a 3-channel output squashed by ``tanh`` into
    [-1, 1]. The encoder downsamples twice with stride-2 convolutions
    (3 -> 32 -> 64 -> 128 channels). Five ``ResidualBlock(128)`` units follow.
    The decoder upsamples twice with stride-2 transposed convolutions
    (128 -> 64 -> 32) and ends with a 9x9 projection back to 3 channels.

    Spatial size: each stride-2 conv maps ``H -> ceil(H / 2)``, and each
    transposed conv maps ``H -> 2 * H`` exactly. The output is therefore
    ``4 * ceil(H / 4)`` by ``4 * ceil(W / 4)``, which matches the input only
    when H and W are multiples of 4.

    NOTE(review): the InstanceNorm2d layers here are non-affine, whereas
    ResidualBlock uses ``affine=True``. Changing either would alter the
    state_dict layout and break loading of existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Encoder stem: 9x9, stride 1, padding 4 (size-preserving).
        encoder = [
            nn.Conv2d(3, 32, kernel_size=9, stride=1, padding=4),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
        ]
        # Two downsampling stages.
        for in_ch, out_ch in ((32, 64), (64, 128)):
            encoder += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(),
            ]

        trunk = [ResidualBlock(128) for _ in range(5)]

        # Two upsampling stages; output_padding=1 makes each one double H/W.
        decoder = []
        for in_ch, out_ch in ((128, 64), (64, 32)):
            decoder += [
                nn.ConvTranspose2d(
                    in_ch, out_ch,
                    kernel_size=3, stride=2, padding=1, output_padding=1,
                ),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(),
            ]
        decoder += [
            nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4),
            nn.Tanh(),
        ]

        # A single flat Sequential: the module order fixes the model.<idx>.*
        # state_dict keys, so it must not be restructured.
        self.model = nn.Sequential(*encoder, *trunk, *decoder)

    def forward(self, x):
        """Run ``x`` through the encoder, residual trunk and decoder."""
        return self.model(x)