# Brain_Emotion_Decoder / Src / NST_Inference.py
# Committed by Ihssane123: "Initial commit" (3b6d764)
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from . import net
from .function import adaptive_instance_normalization
def test_transform(size, crop):
    """Build the preprocessing pipeline applied to input images.

    Args:
        size: Target size passed to ``transforms.Resize`` (an int matches the
            shorter image edge to ``size``). ``0`` disables resizing.
        crop: If true, center-crop to ``size`` x ``size`` after resizing.
            Ignored when ``size`` is 0, because there is then no target size
            to crop to.

    Returns:
        A ``transforms.Compose`` that ends in ``transforms.ToTensor()``.
    """
    steps = []
    if size != 0:
        steps.append(transforms.Resize(size))
        # Only crop when a concrete target size exists: CenterCrop(0) would
        # reduce the image to an empty tensor.
        if crop:
            steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None):
    """Run AdaIN style transfer on already-preprocessed image batches.

    Args:
        vgg: Encoder module that maps an image batch to feature maps.
        decoder: Module that maps feature maps back to image space.
        content: Content image batch (presumably (N, 3, H, W) -- confirm
            against callers).
        style: Style image batch; must be compatible with ``content`` for
            ``adaptive_instance_normalization``.
        alpha: Blend factor in [0, 1] between the raw content features (0)
            and the fully stylized features (1).
        interpolation_weights: Optional sequence of weights. When non-empty,
            the AdaIN output for batch entry ``i`` is scaled by
            ``interpolation_weights[i]`` and the results are summed into a
            single feature map. Only the first content feature map is then
            used for the alpha blend.

    Returns:
        The decoder output for the blended features.

    Raises:
        AssertionError: if ``alpha`` is outside [0, 1].
    """
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        # Allocate the accumulator on the features' own device/dtype. The
        # previous "cuda if available" guess caused a device mismatch when
        # the caller's tensors lived on a different device.
        feat = content_f.new_zeros(1, C, H, W)
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        # The interpolated result is a single map, so blend against the
        # first content feature map only.
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
def _load_rgb(path, transform):
    """Open the image at *path*, force 3-channel RGB, and apply *transform*."""
    # The context manager closes the file handle; convert("RGB") keeps RGBA,
    # grayscale and palette images from yielding non-3-channel tensors.
    with Image.open(str(path)) as img:
        return transform(img.convert("RGB"))


def save_style(output_dir, content_path, style_path):
    """Stylize one content image with one style image and save the result.

    The decoder and VGG weights are read from ``models/decoder.pth`` and
    ``models/vgg_normalised.pth`` (relative to the current working
    directory). Both images are resized and center-cropped to 512, and the
    stylized image is written as ``stylized_output.jpg``.

    Args:
        output_dir: Directory to write the result into; created if missing.
            ``None`` falls back to ``"output"``.
        content_path: Path to the content image.
        style_path: Path to the style image.

    Returns:
        ``pathlib.Path`` of the saved image.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    decoder_pth = Path("models/decoder.pth")
    vgg_pth = Path("models/vgg_normalised.pth")

    # Honour the caller's directory; it used to be silently replaced by
    # "output", which is kept only as the fallback for None.
    output_dir = Path("output" if output_dir is None else output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    decoder = net.decoder
    vgg = net.vgg
    decoder.eval()
    vgg.eval()
    # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
    decoder.load_state_dict(torch.load(decoder_pth, map_location=device))
    vgg.load_state_dict(torch.load(vgg_pth, map_location=device))
    # Keep only the first 31 child modules of the VGG (presumably the AdaIN
    # encoder up to relu4_1 -- confirm against net.vgg).
    vgg = nn.Sequential(*list(vgg.children())[:31])
    vgg.to(device)
    decoder.to(device)

    transform = test_transform(512, True)
    style = _load_rgb(style_path, transform).unsqueeze(0)
    content = _load_rgb(Path(content_path), transform) \
        .unsqueeze(0).expand_as(style)
    style = style.to(device)
    content = content.to(device)

    with torch.no_grad():
        output = style_transfer(vgg, decoder, content, style,
                                alpha=1.0, interpolation_weights=None)
    output = output.cpu()

    output_name = output_dir / 'stylized_output.jpg'
    save_image(output, str(output_name))
    return output_name