Spaces:
Build error
Build error
| import numpy as np | |
| import paddlehub as phub | |
| import StyleTransfer.srcTransformer.StyTR as StyTR | |
| import StyleTransfer.srcTransformer.transformer as transformer | |
| import tensorflow as tf | |
| import tensorflow_hub as tfhub | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from torchvision import transforms | |
# TRANSFORMER
# Locations of the pretrained checkpoints used by the transformer-based
# style transfer. All of them live in the same model directory.
_TRANSFORMER_MODEL_DIR = "StyleTransfer/srcTransformer/Transformer_models"
vgg_path = f"{_TRANSFORMER_MODEL_DIR}/vgg_normalised.pth"
decoder_path = f"{_TRANSFORMER_MODEL_DIR}/decoder_iter_160000.pth"
Trans_path = f"{_TRANSFORMER_MODEL_DIR}/transformer_iter_160000.pth"
embedding_path = f"{_TRANSFORMER_MODEL_DIR}/embedding_iter_160000.pth"
def style_transform(h, w):
    """
    Build the preprocessing pipeline for the style image.

    The pipeline first center-crops the image to (h, w) and then
    converts the cropped image into a tensor.

    Parameters:
        h = height of the crop
        w = width of the crop
    Return:
        the composed transformation pipeline
    """
    crop_step = transforms.CenterCrop((h, w))
    tensor_step = transforms.ToTensor()
    return transforms.Compose([crop_step, tensor_step])
def content_transform():
    """
    Build the preprocessing pipeline for the content image.

    The pipeline has a single step: converting the image into a tensor.
    No cropping or resizing is applied.

    Returns:
        the composed transformation pipeline
    """
    return transforms.Compose([transforms.ToTensor()])
# This loads the network architecture already at building time, so the
# components are created once at import and shared by every call.
vgg = StyTR.vgg
# map_location="cpu" keeps the checkpoint loadable on hosts without CUDA:
# without it, tensors that were saved from a GPU cannot be deserialized on a
# CPU-only machine. load_state_dict copies the values into the existing
# parameters, so the module's own device placement is not affected.
vgg.load_state_dict(torch.load(vgg_path, map_location="cpu"))
# Keep only the first 44 child modules of the VGG network.
vgg = nn.Sequential(*list(vgg.children())[:44])
decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()

# The (square) shape of the content and style image is fixed.
# NOTE(review): in the visible code only style_size is used (as the side of
# the center crop for the style image); content_size is not referenced.
content_size = 640
style_size = 640
import functools  # only needed by the cached network loader below


@functools.lru_cache(maxsize=1)
def _load_style_network():
    """
    Assemble the StyTR network and load its pretrained weights.

    The result is cached, so the three checkpoint files are read from disk
    on the first call only instead of on every stylization request.

    Returns:
        network = the StyTrans network, switched to evaluation mode
    """
    decoder.eval()
    Trans.eval()
    vgg.eval()
    for module, checkpoint_path in (
        (decoder, decoder_path),
        (Trans, Trans_path),
        (embedding, embedding_path),
    ):
        # map_location="cpu" lets checkpoints that were saved from a GPU be
        # deserialized on CPU-only hosts; load_state_dict then copies the
        # values into the module's existing parameters.
        module.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    return network


def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """
    This function creates the Transformer network (once, on first use) and
    applies it on a content and style image to create a styled image.

    Both images are converted to RGB. The style image is additionally
    center-cropped to style_size x style_size; the content image is only
    converted to a tensor.

    Parameters:
        content_img = the image with the content
        style_img = the image with the style/pattern
    Returns:
        output = an image that is a combination of both
    """
    network = _load_style_network()
    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network.to(device)

    # Add a leading batch dimension of size one to each image tensor.
    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = network(content, style)

    # Take the first element of the network result and drop singleton
    # dimensions, then scale to 0..255 with rounding (+0.5 before the
    # truncating uint8 cast) and move channels last for PIL.
    output = output[0].cpu().squeeze()
    output = (
        output.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(output)
# STYLE-FAST
# TensorFlow Hub handle of the pretrained stylization model.
_STYLE_FAST_MODEL_URL = (
    "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
)
# Loaded once at import time and reused by every StyleFAST call.
style_transfer_model = tfhub.load(_STYLE_FAST_MODEL_URL)
def _image_to_batch_tensor(image):
    """Convert a PIL image to a float32 RGB tensor in [0, 1] with a batch axis."""
    # Force exactly three channels, as the other style functions in this
    # file do: RGBA, greyscale or palette images would otherwise yield an
    # array with a different channel layout.
    rgb_pixels = np.array(image.convert("RGB"))
    return tf.convert_to_tensor(rgb_pixels, np.float32)[tf.newaxis, ...] / 255.0


def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """
    This function applies a Fast image style transfer technique,
    which uses a pretrained model from tensorhub.

    Both inputs are converted to RGB and scaled to the range 0..1
    before they are passed to the model.

    Parameters:
        content_image = the image with the content
        style_image = the image with the style/pattern
    Returns:
        stylized_image = an image that is a combination of both
    """
    content_tensor = _image_to_batch_tensor(content_image)
    style_tensor = _image_to_batch_tensor(style_image)
    output = style_transfer_model(content_tensor, style_tensor)
    # The first model output is the batch of stylized images; take its
    # single entry and scale it back to 0..255 for PIL.
    stylized_image = output[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))
# STYLE PROJECTION
# Name of the PaddleHub module that performs the style projection.
_STYLE_PROJECTION_MODULE = "stylepro_artistic"
# Loaded once at import time and reused by every styleProjection call.
stylepro_artistic = phub.Module(name=_STYLE_PROJECTION_MODULE)
def styleProjection(
    content_image: Image.Image, style_image: Image.Image, alpha: float = 1.0
):
    """
    Apply parameter free style transfer with the paddlehub
    stylepro_artistic module.

    Both images are converted to RGB and their channel order is reversed
    before they are handed to the module; the channel order of the result
    is reversed back before it is returned.

    Parameters:
        content_image = the image with the content
        style_image = the image with the style/pattern
        alpha = optional weight that controls the balance between
                image and style. This should be a float between 0 and 1.
    Returns:
        result = an image that is a combination of both
    """
    # Reverse the last axis: RGB -> BGR channel order.
    content_bgr = np.array(content_image.convert("RGB"))[:, :, ::-1]
    style_bgr = np.array(style_image.convert("RGB"))[:, :, ::-1]

    request = {"content": content_bgr, "styles": [style_bgr]}
    result = stylepro_artistic.style_transfer(images=[request], alpha=alpha)

    # One request was sent, so read the first result and flip the channel
    # order back for PIL.
    stylized_bgr = np.uint8(result[0]["data"])
    return Image.fromarray(stylized_bgr[:, :, ::-1]).convert("RGB")