import argparse
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from . import net
from .function import adaptive_instance_normalization


def test_transform(size, crop):
    """Build the preprocessing pipeline applied to an input image.

    Args:
        size: Value handed to ``transforms.Resize`` (and to
            ``transforms.CenterCrop`` when ``crop`` is set). ``0`` skips the
            resize step.
        crop: When truthy, a ``transforms.CenterCrop(size)`` step is added
            after the optional resize.

    Returns:
        A ``transforms.Compose`` whose last step is ``transforms.ToTensor()``.
    """
    steps = []
    if size != 0:
        steps.append(transforms.Resize(size))
    if crop:
        steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)


def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None):
    """Stylize ``content`` with ``style`` and decode the result.

    Both inputs are encoded with ``vgg``; the encoded content is re-normalized
    against the encoded style via ``adaptive_instance_normalization`` and then
    blended with the plain content features before being passed to
    ``decoder``.

    Args:
        vgg: Callable encoder applied to ``content`` and ``style``.
        decoder: Callable applied to the blended features.
        content: Content batch fed to ``vgg``.
        style: Style batch fed to ``vgg``.
        alpha: Blend factor in ``[0, 1]``; ``1`` uses only the stylized
            features, ``0`` only the content features.
        interpolation_weights: Optional sequence of weights. When non-empty,
            the i-th stylized feature map is scaled by the i-th weight and
            the results are summed into a single feature map; only the first
            content feature map is then used for the alpha blend.

    Returns:
        Whatever ``decoder`` returns for the blended features.

    Raises:
        AssertionError: If ``alpha`` is outside ``[0, 1]``.
    """
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        # Allocate the accumulator next to the features themselves instead of
        # re-deriving the device from CUDA availability: the models may be
        # running on CPU even when a GPU is present.
        feat = torch.zeros(1, C, H, W,
                           dtype=content_f.dtype, device=content_f.device)
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)


def _load_rgb_tensor(image_path, transform):
    """Open ``image_path``, force 3-channel RGB, and apply ``transform``."""
    # The context manager releases the file handle; convert("RGB") normalizes
    # grayscale / RGBA / palette inputs to three channels.
    with Image.open(str(image_path)) as img:
        return transform(img.convert("RGB"))


def save_style(output_dir, content_path, style_path):
    """Stylize one content image with one style image and save the result.

    Model weights are read from ``models/decoder.pth`` and
    ``models/vgg_normalised.pth`` (relative to the current working directory).
    Both images are resized and center-cropped to 512 before stylization.

    Args:
        output_dir: Directory to write into; created if missing. Falls back
            to ``"output"`` when empty or ``None``.
        content_path: Path of the content image.
        style_path: Path of the style image.

    Returns:
        ``pathlib.Path`` of the written file,
        ``<output_dir>/stylized_output.jpg``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    decoder_pth = Path("models/decoder.pth")
    vgg_pth = Path("models/vgg_normalised.pth")

    # Honor the caller-supplied directory; "output" stays the fallback.
    output_dir = Path(output_dir) if output_dir else Path("output")
    output_dir.mkdir(exist_ok=True, parents=True)

    content_path = Path(content_path)
    style_paths = [Path(style_path)]

    decoder = net.decoder
    vgg = net.vgg
    decoder.eval()
    vgg.eval()

    # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
    decoder.load_state_dict(torch.load(decoder_pth, map_location=device))
    vgg.load_state_dict(torch.load(vgg_pth, map_location=device))
    # Only the first 31 child modules of the VGG are used as the encoder.
    vgg = nn.Sequential(*list(vgg.children())[:31])

    vgg.to(device)
    decoder.to(device)

    content_tf = test_transform(512, True)
    style_tf = test_transform(512, True)

    style = torch.stack([_load_rgb_tensor(p, style_tf) for p in style_paths])
    content = _load_rgb_tensor(content_path, content_tf) \
        .unsqueeze(0).expand_as(style)

    style = style.to(device)
    content = content.to(device)
    with torch.no_grad():
        output = style_transfer(vgg, decoder, content, style, alpha=1.0)
    output = output.cpu()

    output_name = output_dir / 'stylized_output.jpg'
    save_image(output, str(output_name))
    return output_name