Spaces:
Sleeping
Sleeping
File size: 2,588 Bytes
3b6d764 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from . import net
from .function import adaptive_instance_normalization
def test_transform(size, crop):
    """Build the preprocessing pipeline applied to content/style images.

    Args:
        size: Target passed to ``transforms.Resize``; ``0`` skips resizing.
        crop: When truthy, append ``transforms.CenterCrop(size)``.

    Returns:
        A ``transforms.Compose`` whose final step is ``transforms.ToTensor()``.
    """
    steps = [transforms.Resize(size)] if size != 0 else []
    if crop:
        # The crop reuses ``size`` even when resizing was skipped (size == 0).
        steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None):
    """Stylize ``content`` with ``style`` via AdaIN and decode the result.

    Args:
        vgg: Encoder callable mapping image tensors to feature tensors.
        decoder: Decoder callable mapping feature tensors back to images.
        content: Content image batch fed to ``vgg``.
        style: Style image batch fed to ``vgg``.
        alpha: Style strength in ``[0, 1]``. ``1.0`` uses the AdaIN features
            unchanged; ``0.0`` keeps the content features.
        interpolation_weights: Optional sequence of per-style weights. When
            truthy, the AdaIN features of batch entry ``i`` are scaled by
            ``interpolation_weights[i]`` and summed into one feature map, and
            only the first content feature map is used for the alpha blend.

    Returns:
        Whatever ``decoder`` returns for the blended features.

    Raises:
        AssertionError: If ``alpha`` lies outside ``[0, 1]`` (note: skipped
            when Python runs with ``-O``).
    """
    assert 0.0 <= alpha <= 1.0, f"alpha must be in [0, 1], got {alpha}"
    content_f = vgg(content)
    style_f = vgg(style)
    # Truthiness (not ``is not None``) on purpose: callers may pass ``''``.
    if interpolation_weights:
        base_feat = adaptive_instance_normalization(content_f, style_f)
        # Allocate the accumulator on the encoder output's own device/dtype
        # instead of guessing from CUDA availability, so CPU inputs work on a
        # CUDA-capable machine and half-precision features stay consistent.
        _, C, H, W = content_f.size()
        feat = torch.zeros(1, C, H, W,
                           device=content_f.device, dtype=content_f.dtype)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        # The blend below needs a single content map to match ``feat``.
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
def _load_networks(device):
    """Load decoder/VGG weights from ``models/`` and return ``(vgg, decoder)``."""
    decoder = net.decoder
    vgg = net.vgg
    decoder.eval()
    vgg.eval()
    # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
    decoder.load_state_dict(
        torch.load(Path("models/decoder.pth"), map_location=device))
    vgg.load_state_dict(
        torch.load(Path("models/vgg_normalised.pth"), map_location=device))
    # Keep only the first 31 encoder layers as the feature extractor
    # (presumably up to relu4_1 as in AdaIN — verify against net.vgg).
    vgg = nn.Sequential(*list(vgg.children())[:31])
    return vgg.to(device), decoder.to(device)


def _load_image_batch(path, transform, device):
    """Open ``path`` as RGB, apply ``transform``; return a 1-image batch."""
    with Image.open(str(path)) as img:
        # Force 3 channels: RGBA / grayscale / palette images would otherwise
        # reach the VGG encoder with the wrong channel count.
        tensor = transform(img.convert("RGB"))
    return tensor.unsqueeze(0).to(device)


def save_style(output_dir, content_path, style_path):
    """Stylize one content image with one style image and save it as a JPEG.

    Args:
        output_dir: Directory for the result; created if missing. A falsy
            value (``None`` / ``""``) falls back to ``"output"``.
        content_path: Path to the content image.
        style_path: Path to the style image.

    Returns:
        ``Path`` of the written ``stylized_output.jpg`` inside ``output_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Honour the caller's directory (it was previously overwritten).
    output_dir = Path(output_dir) if output_dir else Path("output")
    output_dir.mkdir(exist_ok=True, parents=True)

    vgg, decoder = _load_networks(device)

    # Both images go through Resize(512) + CenterCrop(512), so their
    # batches share the shape (1, 3, 512, 512).
    content = _load_image_batch(content_path, test_transform(512, True),
                                device)
    style = _load_image_batch(style_path, test_transform(512, True), device)

    with torch.no_grad():
        output = style_transfer(vgg, decoder, content, style,
                                alpha=1.0, interpolation_weights=None)

    output_name = output_dir / 'stylized_output.jpg'
    save_image(output.cpu(), str(output_name))
    return output_name
|