# hw2 / model.py
# Uploaded by klasser ("Upload model.py", commit 10842a1, verified).
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Residual block with an identity skip connection.

    The residual branch is two stages of ReflectionPad2d(1) -> 3x3 Conv2d ->
    InstanceNorm2d, with a ReLU between them (none after the second stage).
    Padding by 1 before each unpadded 3x3 conv keeps H and W unchanged, and the
    channel count stays ``in_features``, so the branch output can be added to
    the input directly.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()
        layers = []
        for stage in range(2):
            layers.extend([
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ])
            # Only the first stage is followed by an activation; the sum with
            # the skip connection is returned without a final ReLU.
            if stage == 0:
                layers.append(nn.ReLU(inplace=True))
        # Kept as ``self.block`` with the same layer order so state_dict keys
        # (block.1.*, block.5.*) stay checkpoint-compatible.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + block(x)``; output shape equals input shape."""
        residual = self.block(x)
        return x + residual
class ResNetGenerator(nn.Module):
    """ResNet-style image-to-image generator (CycleGAN generator layout).

    Layout, as built in ``__init__``:
      * stem: ReflectionPad2d(3) -> 7x7 Conv2d -> InstanceNorm2d -> ReLU,
        producing ``base_features`` channels;
      * ``num_sampling`` downsampling stages: 3x3 stride-2 Conv2d (padding=1)
        that doubles the channels -> InstanceNorm2d -> ReLU;
      * ``num_residual_blocks`` ``ResidualBlock`` s at the bottleneck width;
      * ``num_sampling`` upsampling stages: 3x3 stride-2 ConvTranspose2d
        (padding=1, output_padding=1) that halves the channels ->
        InstanceNorm2d -> ReLU;
      * head: ReflectionPad2d(3) -> 7x7 Conv2d -> Tanh, so outputs lie in
        [-1, 1].

    Each downsampling stage maps a side of length H to floor((H - 1) / 2) + 1,
    and each upsampling stage maps H to 2 * H. The input's spatial size is
    therefore reproduced exactly when H and W are divisible by
    ``2 ** num_sampling``.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        num_residual_blocks: number of residual blocks at the bottleneck.
        base_features: channel width produced by the stem (default 64).
        num_sampling: number of down-/up-sampling stages (default 2).

    The defaults reproduce the original fixed architecture exactly, including
    the ``model.<i>`` indices in ``state_dict``, so existing checkpoints load.
    """

    def __init__(self, input_channels=3, output_channels=3, num_residual_blocks=9,
                 base_features=64, num_sampling=2):
        super().__init__()
        features = base_features
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, features, 7),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves H/W and doubles the channel count.
        for _ in range(num_sampling):
            model += [
                nn.Conv2d(features, features * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(features * 2),
                nn.ReLU(inplace=True),
            ]
            features *= 2

        # Bottleneck: shape-preserving residual blocks.
        model += [ResidualBlock(features) for _ in range(num_residual_blocks)]

        # Upsampling mirrors downsampling; output_padding=1 makes every stage
        # exactly double H/W: (H-1)*2 - 2*1 + 3 + 1 == 2*H.
        for _ in range(num_sampling):
            model += [
                nn.ConvTranspose2d(features, features // 2, 3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(features // 2),
                nn.ReLU(inplace=True),
            ]
            features //= 2

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(features, output_channels, 7),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run the generator; see the class docstring for the shape contract."""
        return self.model(x)
class PatchGANDiscriminator(nn.Module):
    """Fully convolutional PatchGAN discriminator.

    Four 4x4 conv stages (64, 128, 256, 512 channels; strides 2, 2, 2, 1;
    padding 1), each followed by LeakyReLU(0.2), with InstanceNorm2d on every
    stage except the first. A final 4x4 stride-1 conv maps to a single
    channel. Each output element therefore scores one patch of the input; the
    stack's receptive field works out to 70x70 pixels. For example, a 256x256
    input yields a 30x30 map.

    Args:
        input_channels: channels of the image being judged.
    """

    def __init__(self, input_channels=3):
        super().__init__()
        # (in_channels, out_channels, stride, use_instance_norm) per stage.
        stage_specs = (
            (input_channels, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        )
        layers = []
        for in_ch, out_ch, stride, use_norm in stage_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, 4, stride=stride, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Per-patch real/fake score (no activation; raw logits/scores).
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        # Layer order matches the original so model.<i> state_dict keys agree.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the (N, 1, H', W') patch score map for input ``x``."""
        return self.model(x)
def weights_init_normal(m):
    """Initialize convolution layers in place; intended for ``module.apply``.

    Every module whose class name contains ``'Conv'`` (e.g. ``Conv2d``,
    ``ConvTranspose2d``) and that owns a ``weight`` parameter gets
    ``weight ~ N(0, 0.02)``; its bias, if present, is set to zero. All other
    modules are left untouched.

    Args:
        m: the module visited by ``nn.Module.apply``.
    """
    if 'Conv' not in type(m).__name__:
        return
    weight = getattr(m, 'weight', None)
    # Skip Conv-named wrappers/containers without a weight of their own
    # instead of crashing with AttributeError.
    if weight is None:
        return
    # nn.init functions run under torch.no_grad() themselves, so the
    # Parameter can be passed directly (no legacy ``.data`` access needed).
    nn.init.normal_(weight, 0.0, 0.02)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)
class CycleGAN(nn.Module):
    """Container registering the four CycleGAN networks as submodules.

    ``G_A2B`` / ``G_B2A`` are ``ResNetGenerator`` s with 9 residual blocks
    (named for the A->B and B->A translation directions), and ``D_A`` /
    ``D_B`` are ``PatchGANDiscriminator`` s. After construction,
    ``weights_init_normal`` is applied recursively to every submodule.

    No ``forward`` is defined here; the sub-networks are meant to be called
    directly.
    """

    def __init__(self):
        super().__init__()
        generator_kwargs = {'num_residual_blocks': 9}
        # Attribute names are the state_dict prefixes; construction order
        # (G_A2B, G_B2A, D_A, D_B) is kept so random init matches under a seed.
        self.G_A2B = ResNetGenerator(**generator_kwargs)
        self.G_B2A = ResNetGenerator(**generator_kwargs)
        self.D_A, self.D_B = PatchGANDiscriminator(), PatchGANDiscriminator()
        # Overwrite the default conv init with N(0, 0.02) weights / zero bias.
        self.apply(weights_init_normal)