# cyclegan-ukiyoe / models.py
# Mike
# afslkd
# 2adf651
import torch
import torch.nn as nn
def init_weights(net: nn.Module, init_gain=0.02):
    """Re-initialize the convolution layers of ``net`` in place.

    Every ``nn.Conv2d`` and ``nn.ConvTranspose2d`` found by ``net.apply`` gets
    its weight redrawn from a normal distribution with mean 0 and standard
    deviation ``init_gain``; its bias, when the layer has one, is set to zero.
    All other module types are left untouched.

    Args:
        net: module tree to initialize (modified in place).
        init_gain: standard deviation of the normal distribution for weights.

    Returns:
        None.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)

    def _reset_conv(module: nn.Module) -> None:
        # Guard clause: anything that is not a (transposed) conv is skipped.
        if not isinstance(module, conv_types):
            return
        nn.init.normal_(module.weight, 0.0, init_gain)
        layer_bias = module.bias
        if layer_bias is not None:
            nn.init.zeros_(layer_bias)

    # net.apply visits children before their parent, in registration order,
    # which fixes the order in which random draws are consumed.
    net.apply(_reset_conv)
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is: 3x3 conv -> InstanceNorm2d -> ReLU -> 3x3 conv -> InstanceNorm2d.
    Both convolutions use reflect padding of 1 and stride 1 and keep ``dim``
    channels, so ``F(x)`` has the same shape as ``x`` and the skip connection
    is a plain element-wise sum.

    Args:
        dim: number of input (and output) channels.
        use_bias: whether the two convolutions have a bias term.
    """

    def __init__(self, dim, use_bias=True):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(
                dim,
                dim,
                kernel_size=3,
                padding=1,
                padding_mode='reflect',
                bias=use_bias,
            )

        # Elements are built left to right, so the two convs are created in
        # the same order as their position in the Sequential.
        residual_branch = [
            conv3x3(),
            nn.InstanceNorm2d(dim),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(dim),
        ]
        self.block = nn.Sequential(*residual_branch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.block(x)
        return x + residual
class ResnetGenerator(nn.Module):
    """Image-to-image generator: conv encoder, residual bottleneck, conv decoder.

    Layer stack, in order:

    1. 7x7 reflect-padded conv ``in_channels -> ngf``, InstanceNorm2d, ReLU.
    2. ``n_downsampling`` stride-2 3x3 convs, each doubling the channel count
       (each followed by InstanceNorm2d and ReLU).
    3. ``n_blocks`` ``ResnetBlock``s at ``ngf * 2 ** n_downsampling`` channels.
    4. ``n_downsampling`` stride-2 3x3 transposed convs (``output_padding=1``),
       each halving the channel count (each followed by InstanceNorm2d, ReLU).
    5. 7x7 reflect-padded conv ``ngf -> out_channels`` followed by Tanh, so
       output values lie in [-1, 1].

    Each stride-2 conv maps an even side length ``s`` to ``s // 2`` and each
    transposed conv maps ``s`` to ``2 * s``. The output therefore has the
    input's height and width whenever both are divisible by
    ``2 ** n_downsampling``; otherwise the output size differs.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        ngf: channel width after the first convolution.
        n_blocks: number of residual blocks in the bottleneck (>= 0).
        n_downsampling: number of stride-2 down/up-sampling stages (>= 0).
            The default of 2 builds exactly the same layers (and state_dict
            keys) as the previous fixed architecture.

    Raises:
        ValueError: if ``n_blocks`` or ``n_downsampling`` is negative.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        ngf: int = 64,
        n_blocks: int = 9,
        n_downsampling: int = 2,
    ):
        super().__init__()
        # Explicit exceptions instead of ``assert``: asserts are stripped under
        # ``python -O``, and a negative count would then silently build nothing.
        if n_blocks < 0:
            raise ValueError(f"n_blocks must be >= 0, got {n_blocks}")
        if n_downsampling < 0:
            raise ValueError(f"n_downsampling must be >= 0, got {n_downsampling}")

        use_bias = True  # shared by every conv in the stack
        layers = [
            nn.Conv2d(in_channels, ngf, kernel_size=7, padding=3, padding_mode='reflect', bias=use_bias),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True),
        ]

        # Encoder: halve spatial size, double channels at each stage.
        for i in range(n_downsampling):
            mult = 2 ** i
            layers += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                nn.InstanceNorm2d(ngf * mult * 2),
                nn.ReLU(inplace=True),
            ]

        # Residual bottleneck at the widest channel count.
        mult = 2 ** n_downsampling
        for _ in range(n_blocks):
            layers += [ResnetBlock(ngf * mult, use_bias=use_bias)]

        # Decoder: mirror of the encoder; output_padding=1 makes each
        # transposed conv exactly double the spatial size.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            layers += [
                nn.ConvTranspose2d(
                    ngf * mult,
                    ngf * mult // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    bias=use_bias,
                ),
                nn.InstanceNorm2d(ngf * mult // 2),
                nn.ReLU(inplace=True),
            ]

        layers += [
            nn.Conv2d(ngf, out_channels, kernel_size=7, padding=3, padding_mode='reflect', bias=use_bias),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full generator stack on ``x``."""
        return self.model(x)
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator producing a single-channel spatial score map.

    Architecture (all convs 4x4 with padding 1):

    * conv ``in_channels -> ndf`` (stride 2) + LeakyReLU(0.2), no normalization;
    * for each ``depth`` in ``1 .. n_layers - 1``: stride-2 conv to
      ``ndf * min(2 ** depth, 8)`` channels + InstanceNorm2d + LeakyReLU(0.2);
    * stride-1 conv to ``ndf * min(2 ** n_layers, 8)`` channels
      + InstanceNorm2d + LeakyReLU(0.2);
    * stride-1 conv to 1 channel with no activation (raw, unnormalized scores).

    Args:
        in_channels: channels of the images being scored.
        ndf: channel width of the first convolution.
        n_layers: controls depth; see the layer list above.
    """

    def __init__(self, in_channels: int = 3, ndf: int = 64, n_layers: int = 3):
        super().__init__()
        with_bias = True
        kernel, pad = 4, 1

        def normed_stage(c_in, c_out, stride):
            # conv -> InstanceNorm2d -> LeakyReLU(0.2)
            return [
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, bias=with_bias),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # The first stage is deliberately left un-normalized.
        stages = [
            nn.Conv2d(in_channels, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        width = ndf  # channel count produced by the most recent stage
        for depth in range(1, n_layers):
            next_width = ndf * min(2 ** depth, 8)
            stages.extend(normed_stage(width, next_width, stride=2))
            width = next_width

        final_width = ndf * min(2 ** n_layers, 8)
        stages.extend(normed_stage(width, final_width, stride=1))
        # Collapse features to a one-channel map of scores.
        stages.append(nn.Conv2d(final_width, 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = self.model(x)
        return scores
class CycleGAN(nn.Module):
    """Container for the two generators and two discriminators of a CycleGAN.

    Domain A images have ``in_channels`` channels and domain B images have
    ``out_channels`` channels.

    * ``G_ab``: ResnetGenerator mapping A -> B.
    * ``G_ba``: ResnetGenerator mapping B -> A.
    * ``D_a`` / ``D_b``: PatchDiscriminators for domain A / domain B images.

    After construction each of the four networks is passed through
    ``init_weights(..., init_gain=0.02)``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        ngf: int = 64,
        ndf: int = 64,
        n_res_blocks: int = 9,
        n_disc_layers: int = 3,
    ):
        super().__init__()
        gen_kwargs = dict(ngf=ngf, n_blocks=n_res_blocks)
        disc_kwargs = dict(ndf=ndf, n_layers=n_disc_layers)
        # Keep this construction/registration order (G_ab, G_ba, D_a, D_b):
        # it determines sub-module order and the sequence of random draws.
        self.G_ab = ResnetGenerator(in_channels, out_channels, **gen_kwargs)
        self.G_ba = ResnetGenerator(out_channels, in_channels, **gen_kwargs)
        self.D_a = PatchDiscriminator(in_channels, **disc_kwargs)
        self.D_b = PatchDiscriminator(out_channels, **disc_kwargs)
        for sub_net in (self.G_ab, self.G_ba, self.D_a, self.D_b):
            init_weights(sub_net, init_gain=0.02)

    def forward(self, imgs_a: torch.Tensor, imgs_b: torch.Tensor):
        """Translate both batches and translate each result back again.

        Returns:
            ``(fake_a, fake_b, rec_a, rec_b)`` with ``fake_b = G_ab(imgs_a)``,
            ``fake_a = G_ba(imgs_b)``, ``rec_a = G_ba(fake_b)`` and
            ``rec_b = G_ab(fake_a)``.
        """
        translated_b = self.G_ab(imgs_a)
        translated_a = self.G_ba(imgs_b)
        reconstructed_a = self.G_ba(translated_b)
        reconstructed_b = self.G_ab(translated_a)
        return translated_a, translated_b, reconstructed_a, reconstructed_b

    def discriminate(
        self,
        imgs_a: torch.Tensor,
        imgs_b: torch.Tensor,
        fake_a: torch.Tensor,
        fake_b: torch.Tensor,
    ):
        """Score real and generated images with the matching discriminator.

        Inputs are used as given (nothing is detached here); presumably the
        caller detaches ``fake_a``/``fake_b`` for discriminator-only updates.
        TODO confirm against the training loop.

        Returns:
            ``(a_real_pred, b_real_pred, a_fake_pred, b_fake_pred)``.
        """
        real_scores = (self.D_a(imgs_a), self.D_b(imgs_b))
        fake_scores = (self.D_a(fake_a), self.D_b(fake_b))
        return (*real_scores, *fake_scores)
def load_model_only(path, model, map_location=None, strict=True):
    """Load only the model weights from a checkpoint file into ``model``.

    Two checkpoint layouts are accepted:

    * a dict holding the weights under ``"model_state_dict"``; any other
      entries (e.g. optimizer state, counters) are ignored;
    * a bare state dict whose values are all tensors, such as the result of
      ``torch.save(model.state_dict(), path)``.

    Args:
        path: anything ``torch.load`` accepts (file path or file-like object).
        model: ``nn.Module`` whose parameters/buffers are overwritten in place.
        map_location: forwarded to ``torch.load`` (e.g. ``"cpu"``).
        strict: forwarded to ``model.load_state_dict``; ``False`` tolerates
            missing or unexpected keys.

    Returns:
        ``model`` itself, so the call can be chained.

    Raises:
        KeyError: if the checkpoint has no ``"model_state_dict"`` entry and is
            not a bare state dict.
    """
    # NOTE(security): torch.load deserializes via pickle; depending on the
    # installed torch version's weights_only default this may execute code
    # embedded in the file. Only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location=map_location)
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict) and ckpt and all(torch.is_tensor(v) for v in ckpt.values()):
        # Checkpoint is itself a plain state dict.
        state_dict = ckpt
    else:
        raise KeyError("checkpoint has no 'model_state_dict' entry and is not a bare state dict")
    model.load_state_dict(state_dict, strict=strict)
    return model