# SofaStyler/StyleTransfer/styleTransfer.py
import numpy as np
import paddlehub as phub
import StyleTransfer.srcTransformer.StyTR as StyTR
import StyleTransfer.srcTransformer.transformer as transformer
import tensorflow as tf
import tensorflow_hub as tfhub
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
# TRANSFORMER
vgg_path = "StyleTransfer/srcTransformer/Transformer_models/vgg_normalised.pth"
decoder_path = "StyleTransfer/srcTransformer/Transformer_models/decoder_iter_160000.pth"
Trans_path = (
"StyleTransfer/srcTransformer/Transformer_models/transformer_iter_160000.pth"
)
embedding_path = (
"StyleTransfer/srcTransformer/Transformer_models/embedding_iter_160000.pth"
)


def style_transform(h, w):
    """
    This function creates a transformation pipeline for the style image
    that center-crops it to (h, w) and converts it into a tensor.

    Parameters:
        h = height of the crop
        w = width of the crop
    Returns:
        transform = the transformation pipeline
    """
    transform_list = []
    transform_list.append(transforms.CenterCrop((h, w)))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform


def content_transform():
    """
    This function creates a transformation pipeline that converts
    the content image into a tensor.

    Returns:
        transform = the transformation pipeline
    """
    transform_list = []
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform
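

# Illustrative sketch of what the two pipelines produce: applying them to a
# PIL image yields float tensors in [0, 1] with shape (C, H, W). The file
# name "style.jpg" below is a hypothetical placeholder, not part of this repo:
#
#     img = Image.open("style.jpg")
#     style_tensor = style_transform(640, 640)(img)  # -> shape (3, 640, 640)
#     content_tensor = content_transform()(img)      # -> original H and W kept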


# Load the network architecture and the pretrained VGG weights once, at
# import time; map_location="cpu" keeps the load device-agnostic (the
# network is moved to the GPU later if one is available)
vgg = StyTR.vgg
vgg.load_state_dict(torch.load(vgg_path, map_location="cpu"))
# Truncate the VGG encoder to the feature layers used by StyTR
vgg = nn.Sequential(*list(vgg.children())[:44])
decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()

# The (square) sizes of the content and style images are fixed.
# Only the style image is cropped in this module; content_size is
# currently unused here.
content_size = 640
style_size = 640


def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """
    This function assembles the StyTR transformer network and applies it
    to a content and a style image to create a styled image.

    Parameters:
        content_img = the image with the content
        style_img = the image with the style/pattern
    Returns:
        output = an image that combines both
    """
    # Put all sub-networks in inference mode
    decoder.eval()
    Trans.eval()
    vgg.eval()
    # Load the pretrained weights; map_location="cpu" avoids failures on
    # machines without a GPU (the network is moved to `device` below)
    state_dict = torch.load(decoder_path, map_location="cpu")
    decoder.load_state_dict(state_dict)
    state_dict = torch.load(Trans_path, map_location="cpu")
    Trans.load_state_dict(state_dict)
    state_dict = torch.load(embedding_path, map_location="cpu")
    embedding.load_state_dict(state_dict)
network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
network.eval()
content_tf = content_transform()
style_tf = style_transform(style_size, style_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network.to(device)
content = content_tf(content_img.convert("RGB"))
style = style_tf(style_img.convert("RGB"))
style = style.to(device).unsqueeze(0)
content = content.to(device).unsqueeze(0)
    with torch.no_grad():
        output = network(content, style)
    # The network returns a tuple whose first element is the stylized image
    output = output[0].cpu().squeeze()
    # Convert the float tensor (C, H, W) in [0, 1] to a uint8 array (H, W, C)
    output = (
        output.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
return Image.fromarray(output)
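

# Minimal usage sketch for StyleTransformer (the file names below are
# hypothetical placeholders, not part of this repo):
#
#     content = Image.open("sofa.jpg").convert("RGB")
#     style = Image.open("pattern.jpg").convert("RGB")
#     styled = StyleTransformer(content, style)
#     styled.save("styled_sofa.jpg")
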
# STYLE-FAST
# This downloads and caches the Magenta arbitrary-image-stylization model
# from TensorFlow Hub at import time
style_transfer_model = tfhub.load(
    "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
)


def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """
    This function applies a fast style transfer technique,
    using a pretrained arbitrary-stylization model from TensorFlow Hub.

    Parameters:
        content_image = the image with the content
        style_image = the image with the style/pattern
    Returns:
        stylized_image = an image that combines both
    """
    # Scale both images to float32 in [0, 1] and add a batch dimension,
    # as the TF Hub model expects
    content_image = (
        tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...]
        / 255.0
    )
    style_image = (
        tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...]
        / 255.0
    )
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    # Drop the batch dimension and convert back to a uint8 PIL image
    return Image.fromarray(np.uint8(stylized_image[0] * 255))
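

# Minimal usage sketch for StyleFAST (the file names below are hypothetical
# placeholders, not part of this repo):
#
#     content = Image.open("sofa.jpg").convert("RGB")
#     style = Image.open("pattern.jpg").convert("RGB")
#     StyleFAST(content, style).save("styled_fast.jpg")
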
# STYLE PROJECTION
# This downloads and caches the stylepro_artistic module from PaddleHub
# at import time
stylepro_artistic = phub.Module(name="stylepro_artistic")


def styleProjection(
    content_image: Image.Image, style_image: Image.Image, alpha: float = 1.0
):
    """
    This function uses parameter-free style transfer,
    based on the stylepro_artistic model from PaddleHub.
    The optional weight parameter alpha controls the
    balance between the content image and the style.

    Parameters:
        content_image = the image with the content
        style_image = the image with the style/pattern
        alpha = weight for the content vs. the style;
            this should be a float between 0 and 1
    Returns:
        result = an image that combines both
    """
    result = stylepro_artistic.style_transfer(
        images=[
            {
                # PaddleHub expects OpenCV-style BGR arrays, hence the
                # channel flips [:, :, ::-1]
                "content": np.array(content_image.convert("RGB"))[:, :, ::-1],
                "styles": [np.array(style_image.convert("RGB"))[:, :, ::-1]],
            }
        ],
        alpha=alpha,
    )
    # Flip the BGR result back to RGB before building the PIL image
    return Image.fromarray(np.uint8(result[0]["data"])[:, :, ::-1]).convert("RGB")
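

# Minimal usage sketch for styleProjection (the file names are hypothetical
# placeholders, and alpha=0.8 is just an illustrative value):
#
#     content = Image.open("sofa.jpg").convert("RGB")
#     style = Image.open("pattern.jpg").convert("RGB")
#     styleProjection(content, style, alpha=0.8).save("styled_projection.jpg")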