NeoNude / src /model.py
fahimahamed1's picture
Enable LFS
995526c
"""
NeoNude model architecture, dataset, and utilities.
Based on a modified pix2pixHD GAN architecture.
"""
from PIL import Image
import numpy as np
import cv2
import torch
import torchvision.transforms as transforms
import functools
import os
# ---------------------------------------------------------------------------
# Dataset & DataLoader
# ---------------------------------------------------------------------------
class Dataset(torch.utils.data.Dataset):
"""Wraps a single OpenCV image into a PyTorch dataset."""
def __init__(self):
super().__init__()
def initialize(self, opt, cv_img):
self.opt = opt
self.root = opt.dataroot
self.A = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
self.dataset_size = 1
def __getitem__(self, index):
transform = get_transform(self.opt)
a_tensor = transform(self.A.convert("RGB"))
input_dict = {
"label": a_tensor,
"inst": 0,
"image": 0,
"feat": 0,
"path": "",
}
return input_dict
def __len__(self):
return 1
class DataLoader:
"""Thin wrapper around torch.utils.data.DataLoader for a single image."""
def __init__(self, opt, cv_img):
super().__init__()
self.dataset = Dataset()
self.dataset.initialize(opt, cv_img)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
shuffle=not opt.serial_batches,
num_workers=int(opt.n_threads),
)
def load_data(self):
return self.dataloader
def __len__(self):
return 1
# ---------------------------------------------------------------------------
# Generator Network
# ---------------------------------------------------------------------------
class ResnetBlock(torch.nn.Module):
"""A single residual block with reflection padding."""
def __init__(self, dim, padding_type, norm_layer,
activation=torch.nn.ReLU(True), use_dropout=False):
super().__init__()
self.conv_block = self._build_conv_block(
dim, padding_type, norm_layer, activation, use_dropout
)
def _build_conv_block(self, dim, padding_type, norm_layer,
activation, use_dropout):
conv_block = []
p = 0
if padding_type == "reflect":
conv_block += [torch.nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [torch.nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError(
f"Padding [{padding_type}] is not implemented"
)
conv_block += [
torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim),
activation,
]
if use_dropout:
conv_block += [torch.nn.Dropout(0.5)]
p = 0
if padding_type == "reflect":
conv_block += [torch.nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [torch.nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError(
f"Padding [{padding_type}] is not implemented"
)
conv_block += [
torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim),
]
return torch.nn.Sequential(*conv_block)
def forward(self, x):
return x + self.conv_block(x)
class GlobalGenerator(torch.nn.Module):
"""Global generator network (U-Net style encoder-decoder)."""
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
n_blocks=9, norm_layer=torch.nn.BatchNorm2d,
padding_type="reflect"):
assert n_blocks >= 0
super().__init__()
activation = torch.nn.ReLU(True)
model = [
torch.nn.ReflectionPad2d(3),
torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf),
activation,
]
# Encoder (downsample)
for i in range(n_downsampling):
mult = 2 ** i
model += [
torch.nn.Conv2d(
ngf * mult, ngf * mult * 2,
kernel_size=3, stride=2, padding=1,
),
norm_layer(ngf * mult * 2),
activation,
]
# ResNet blocks
mult = 2 ** n_downsampling
for _ in range(n_blocks):
model += [
ResnetBlock(ngf * mult, padding_type=padding_type,
activation=activation, norm_layer=norm_layer)
]
# Decoder (upsample)
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [
torch.nn.ConvTranspose2d(
ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2, padding=1, output_padding=1,
),
norm_layer(int(ngf * mult / 2)),
activation,
]
model += [
torch.nn.ReflectionPad2d(3),
torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
torch.nn.Tanh(),
]
self.model = torch.nn.Sequential(*model)
def forward(self, x):
return self.model(x)
# ---------------------------------------------------------------------------
# High-level Model (load + inference)
# ---------------------------------------------------------------------------
class DeepModel:
"""Wraps a GlobalGenerator with weight loading and inference."""
def initialize(self, opt):
self.opt = opt
self.gpu_ids = []
self.net_g = self._define_g(
opt.input_nc, opt.output_nc, opt.ngf, opt.net_g,
opt.n_downsample_global, opt.n_blocks_global,
opt.n_local_enhancers, opt.n_blocks_local,
opt.norm, self.gpu_ids,
)
self._load_network(self.net_g)
def inference(self, label, inst):
input_label, _, _, _ = self._encode_input(label, inst, infer=True)
with torch.no_grad():
return self.net_g.forward(input_label)
# -- private helpers --
def _load_network(self, network):
save_path = os.path.join(self.opt.checkpoints_dir)
network.load_state_dict(torch.load(save_path))
def _encode_input(self, label_map, inst_map=None,
real_image=None, feat_map=None, infer=False):
if len(self.gpu_ids) > 0:
input_label = label_map.data.cuda()
else:
input_label = label_map.data
return input_label, inst_map, real_image, feat_map
def _weights_init(self, m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def _define_g(self, input_nc, output_nc, ngf, net_g,
n_downsample_global=3, n_blocks_global=9,
n_local_enhancers=1, n_blocks_local=3,
norm="instance", gpu_ids=[]):
norm_layer = functools.partial(
torch.nn.InstanceNorm2d, affine=False
)
net = GlobalGenerator(
input_nc, output_nc, ngf,
n_downsample_global, n_blocks_global, norm_layer,
)
if len(gpu_ids) > 0:
net.cuda(gpu_ids[0])
net.apply(self._weights_init)
return net
# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------
def get_transform(opt, method=Image.BICUBIC, normalize=True):
"""Build the image transform pipeline for the dataset."""
transform_list = []
base = float(2 ** opt.n_downsample_global)
if opt.net_g == "local":
base *= 2 ** opt.n_local_enhancers
transform_list.append(
transforms.Lambda(lambda img: _make_power_2(img, base, method))
)
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
return transforms.Compose(transform_list)
def _make_power_2(img, base, method=Image.BICUBIC):
"""Resize image so dimensions are multiples of base."""
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
"""Convert a PyTorch tensor to a NumPy image array."""
if isinstance(image_tensor, list):
return [tensor2im(t, imtype, normalize) for t in image_tensor]
image_numpy = image_tensor.cpu().float().numpy()
if normalize:
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
else:
image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
image_numpy = np.clip(image_numpy, 0, 255)
if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
image_numpy = image_numpy[:, :, 0]
return image_numpy.astype(imtype)