File size: 2,588 Bytes
3b6d764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from . import net
from .function import adaptive_instance_normalization


def test_transform(size, crop):
    """Build the torchvision preprocessing pipeline for an input image.

    The composed steps are, in order:

    * ``Resize(size)`` -- only when ``size`` is not 0,
    * ``CenterCrop(size)`` -- only when ``crop`` is truthy,
    * ``ToTensor()`` -- always.

    Args:
        size: Value forwarded to ``Resize`` and ``CenterCrop``; 0 skips
            the resize step.
        crop: Whether to append the center-crop step.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps.
    """
    steps = [transforms.Resize(size)] if size != 0 else []
    if crop:
        steps += [transforms.CenterCrop(size)]
    steps += [transforms.ToTensor()]
    return transforms.Compose(steps)


def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None):
    """Stylize ``content`` with ``style`` and decode the result.

    Both inputs are encoded with ``vgg``. The content features are passed
    with the style features through ``adaptive_instance_normalization``,
    the result is blended with the plain content features according to
    ``alpha``, and the blend is handed to ``decoder``.

    Args:
        vgg: Callable encoder applied to ``content`` and to ``style``.
        decoder: Callable applied to the blended feature tensor.
        content: Content batch fed to ``vgg``.
        style: Style batch fed to ``vgg``.
        alpha: Blend factor in ``[0, 1]``. 1 keeps only the normalized
            features, 0 keeps only the content features.
        interpolation_weights: Optional sequence of weights. When truthy,
            the normalized features of batch item ``i`` are scaled by
            weight ``i`` and summed into a single feature map, and only the
            first content feature map takes part in the ``alpha`` blend.

    Returns:
        The output of ``decoder`` for the blended features.

    Raises:
        AssertionError: If ``alpha`` is outside ``[0, 1]``.
    """
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        base_feat = adaptive_instance_normalization(content_f, style_f)
        content_f = content_f[0:1]
        # Allocate the accumulator with the device and dtype of the encoded
        # features instead of guessing from CUDA availability, so the sum
        # below works wherever the caller placed the models.
        feat = torch.zeros_like(content_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)


def _load_rgb_tensor(image_path, transform):
    """Open ``image_path``, force three RGB channels and apply ``transform``."""
    # The context manager releases the file handle once the tensor is built;
    # convert("RGB") keeps grayscale/RGBA inputs from reaching the encoder
    # with 1 or 4 channels.
    with Image.open(str(image_path)) as img:
        return transform(img.convert("RGB"))


def save_style(output_dir, content_path, style_path):
    """Stylize one content image with one style image and save the result.

    Loads the decoder and VGG weights from ``models/decoder.pth`` and
    ``models/vgg_normalised.pth``, keeps the first 31 child modules of the
    VGG network as the encoder, resizes and center-crops both images to
    512, runs ``style_transfer`` with ``alpha`` 1 and writes the output as
    ``stylized_output.jpg``.

    Args:
        output_dir: Directory that receives the result; created when
            missing. ``None`` or an empty string falls back to ``output``.
        content_path: Path of the content image.
        style_path: Path of the style image.

    Returns:
        ``pathlib.Path`` of the written ``stylized_output.jpg``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    decoder_pth = Path("models/decoder.pth")
    vgg_pth = Path("models/vgg_normalised.pth")
    # Honour the caller's directory; "output" stays the fallback for callers
    # that pass nothing meaningful.
    output_dir = Path(output_dir) if output_dir else Path("output")
    output_dir.mkdir(exist_ok=True, parents=True)
    content_path = Path(content_path)
    style_paths = [Path(style_path)]

    decoder = net.decoder
    vgg = net.vgg

    decoder.eval()
    vgg.eval()

    # map_location lets checkpoints saved on a GPU load on the CPU fallback.
    decoder.load_state_dict(torch.load(decoder_pth, map_location=device))
    vgg.load_state_dict(torch.load(vgg_pth, map_location=device))
    # Only the first 31 child modules are used as the encoder.
    vgg = nn.Sequential(*list(vgg.children())[:31])

    vgg.to(device)
    decoder.to(device)

    content_tf = test_transform(512, True)
    style_tf = test_transform(512, True)
    style = torch.stack([_load_rgb_tensor(p, style_tf) for p in style_paths])
    # Repeat the single content image so it lines up with the style batch.
    content = _load_rgb_tensor(content_path, content_tf) \
        .unsqueeze(0).expand_as(style)
    style = style.to(device)
    content = content.to(device)
    with torch.no_grad():
        output = style_transfer(vgg, decoder, content, style, alpha=1.0)
    output = output.cpu()
    output_name = output_dir / 'stylized_output.jpg'
    save_image(output, str(output_name))
    return output_name