|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from PIL import Image |
|
|
import torchvision.models as models |
|
|
import torchvision.transforms as transforms |
|
|
from tqdm import tqdm |
|
|
import spaces |
|
|
|
|
|
from dataTransform import load_image |
|
|
from vggModel import VGGNet |
|
|
|
|
|
@spaces.GPU(duration = 242)
def style_transfer(content_img, style_img, total_steps, alpha=1e5, beta=1e10, learning_rate=0.001):
    """Run neural style transfer (Gatys et al.) on a content/style image pair.

    Starts from a copy of the content image and optimizes its pixels with Adam
    so that its VGG feature maps stay close to the content image's features
    (content loss) while the Gram matrices of its features match those of the
    style image (style loss).

    Args:
        content_img: content image in whatever form ``load_image`` accepts.
        style_img: style image in whatever form ``load_image`` accepts.
        total_steps: number of Adam optimization steps to run.
        alpha: weight of the content (original) loss term.
        beta: weight of the style loss term.
        learning_rate: Adam learning rate for the generated image tensor.

    Returns:
        The optimized generated image tensor on the selected device
        (with ``requires_grad=True``).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('-'*30)
    print(f'Device Initialized: {device}')
    print('-'*30)

    content_img = load_image(content_img, device)
    style_img = load_image(style_img, device)

    # Optimization starts from the content image; this is the ONLY tensor
    # the optimizer updates.
    generated_img = content_img.clone().requires_grad_(True)
    optimizer = optim.Adam([generated_img], lr=learning_rate)

    model = VGGNet().to(device).eval()

    # The VGG weights are fixed feature extractors — freeze them so that
    # backward() does not waste time and memory accumulating gradients for
    # parameters that are never stepped.
    for param in model.parameters():
        param.requires_grad_(False)

    # Content and style features never change during optimization: compute
    # them once here instead of re-running the network on both images at
    # every step (the original recomputed them inside the loop).
    with torch.no_grad():
        original_image_feats = model(content_img)
        style_feats = model(style_img)

    for step in tqdm(range(total_steps)):
        generated_feats = model(generated_img)

        style_loss = original_loss = 0

        for gen_feat, orig_image_feat, styl_feat in zip(generated_feats, original_image_feats, style_feats):
            batch, channel, height, width = gen_feat.shape

            # Content loss: MSE between generated and content feature maps.
            original_loss += torch.mean((gen_feat - orig_image_feat)**2)

            # Gram matrices capture channel-wise feature correlations, i.e.
            # the "style" of a layer. NOTE(review): the view() drops the
            # batch dimension — assumes batch == 1 (single image).
            G = gen_feat.view(channel, height*width).mm(
                gen_feat.view(channel, height*width).t()
            )
            A = styl_feat.view(channel, height*width).mm(
                styl_feat.view(channel, height*width).t()
            )
            style_loss += torch.mean((G-A)**2)

        total_loss = alpha*original_loss + beta*style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    # Returning after the loop (rather than on the last iteration) also
    # covers the total_steps == 0 edge case, which previously returned None.
    return generated_img