Spaces:
Runtime error
Runtime error
| ''' | |
| Neural Style Transfer using TensorFlow's Pretrained Style Transfer Model | |
| https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2 | |
| ''' | |
| import gradio as gr | |
| import tensorflow as tf | |
| import tensorflow_hub as hub | |
| from PIL import Image | |
| import numpy as np | |
| import functools | |
| import cv2 | |
| import os | |
# TF Hub handle for Magenta's arbitrary image stylization model. The model is
# loaded once at import time and the resulting `model` object is reused by
# every style_transfer call.
_STYLIZATION_MODEL_URL = "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
model = hub.load(_STYLIZATION_MODEL_URL)
# source: https://stackoverflow.com/questions/4993082/how-can-i-sharpen-an-image-in-opencv
def unsharp_mask(image, kernel_size=(5, 5), sigma=1.0, amount=1.0, threshold=0):
    """Return a sharpened version of the image, using an unsharp mask.

    The image is blurred with a Gaussian kernel. The blur is then subtracted
    from a scaled copy of the original, ``(amount + 1) * image - amount * blurred``,
    and the result is clipped to [0, 255] and rounded to ``uint8``.

    Args:
        image: Input image array accepted by ``cv2.GaussianBlur``. This is
            presumably a uint8 RGB array from the Gradio input; confirm.
        kernel_size: Gaussian kernel size passed to ``cv2.GaussianBlur``.
        sigma: Gaussian standard deviation passed to ``cv2.GaussianBlur``.
        amount: Sharpening strength. 0 leaves the image unchanged (apart from
            clipping and rounding).
        threshold: If > 0, pixels whose absolute difference from the blurred
            image is below this value keep their original value.

    Returns:
        The sharpened image as a ``uint8`` array.
    """
    blurred = cv2.GaussianBlur(image, kernel_size, sigma)
    sharpened = float(amount + 1) * image - float(amount) * blurred
    sharpened = np.clip(sharpened, 0, 255).round().astype(np.uint8)
    if threshold > 0:
        # Take the difference in float64. With uint8 inputs, `image - blurred`
        # wraps around (e.g. 3 - 5 == 254), which would corrupt the mask.
        diff = np.abs(
            np.asarray(image, dtype=np.float64) - np.asarray(blurred, dtype=np.float64)
        )
        np.copyto(sharpened, image, where=diff < threshold)
    return sharpened
def style_transfer(content_img, style_image, style_weight=1, content_weight=1, style_blur=False):
    """Stylize ``content_img`` with the look of ``style_image``.

    Args:
        content_img: Content image as a NumPy array, as delivered by the
            ``gr.Image`` input. This is presumably an H x W x 3 uint8 RGB
            array; confirm.
        style_image: Style image as a NumPy array (same assumption).
        style_weight: Contrast factor applied to the style image
            (``tf.image.adjust_contrast``).
        content_weight: Contrast factor applied to the content image
            (``tf.image.adjust_contrast``).
        style_blur: If truthy, smooth the style image with a 3x3 average
            pool before stylization.

    Returns:
        The stylized image as a ``PIL.Image.Image``.

    Raises:
        gr.Error: If either image was not provided.
    """
    # Gradio passes None when an input was left empty; fail with a clear
    # message instead of an opaque OpenCV error.
    if content_img is None or style_image is None:
        raise gr.Error("Please upload both a content image and a style image.")

    # Content image: sharpen, scale to [0, 1], add a batch axis, and resize
    # to fit within 512x512 while preserving the aspect ratio.
    content_img = unsharp_mask(content_img, amount=1)
    content_img = tf.image.resize(
        tf.convert_to_tensor(content_img, dtype=tf.float32)[tf.newaxis, ...] / 255.0,
        (512, 512),
        preserve_aspect_ratio=True,
    )

    # Style image: resize to 256x256, scale to [0, 1], add a batch axis.
    style_image = Image.fromarray(style_image).resize((256, 256))
    style_img = tf.convert_to_tensor(np.array(style_image), dtype=tf.float32)[tf.newaxis, ...] / 255.0
    if style_blur:
        # 3x3 box blur; VALID padding shrinks the 256x256 image to 254x254.
        style_img = tf.nn.avg_pool(style_img, ksize=[3, 3], strides=[1, 1], padding="VALID")

    # The "weights" are used as contrast factors on each image.
    style_img = tf.image.adjust_contrast(style_img, style_weight)
    content_img = tf.image.adjust_contrast(content_img, content_weight)
    # Fixed extra boost of saturation and contrast on the content image.
    content_img = tf.image.adjust_saturation(content_img, 2)
    content_img = tf.image.adjust_contrast(content_img, 1.5)

    # adjust_contrast does not clip, so the steps above can leave [0, 1].
    # The TF Hub stylization model expects float inputs in [0, 1].
    content_img = tf.clip_by_value(content_img, 0.0, 1.0)
    style_img = tf.clip_by_value(style_img, 0.0, 1.0)

    # The model returns a list of outputs; the first one is the stylized batch.
    stylized_img = model(content_img, style_img)[0]

    # Drop the batch axis and convert back to an 8-bit image.
    stylized_img = tf.squeeze(stylized_img).numpy()
    stylized_img = np.clip(stylized_img * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(stylized_img)
title = "Artistic Neural Style Transfer Demo 🖼️"
description = "Gradio Demo for Artistic Neural Style Transfer. To use it, simply upload a content image and a style image. [Learn More](https://www.tensorflow.org/tutorials/generative/style_transfer)."
article = "</br><p style='text-align: center'><a href='https://github.com/Mr-Hexi' target='_blank'>GitHub</a></p> "

# Define inputs. Their order must match style_transfer's parameters:
# (content_img, style_image, style_weight, content_weight, style_blur).
content_input = gr.Image(label="Upload an image to which you want the style to be applied.")
style_input = gr.Image(label="Upload Style Image")
style_slider = gr.Slider(0, 2, label="Adjust Style Density", value=1)
# NOTE: despite its label, this value reaches style_transfer as content_weight,
# which is used as a contrast factor on the content image.
content_slider = gr.Slider(1, 5, label="Content Sharpness", value=1)
style_checkbox = gr.Checkbox(value=False, label="Tune Style (experimental)")

# Define examples. Each row is:
# [content path, style path, style weight, content weight, style blur].
# The last column feeds a gr.Checkbox, so it must be a bool, not a string.
examples = [
    ["Content/content_2.jpg", "Styles/style_15.jpg", 1.20, 1.70, False],
    ["Content/content_4.jpg", "Styles/style_10.jpg", 0.91, 2.54, True],
]
# Keep only rows whose image files exist (relative to the working directory),
# so a missing example asset does not break the demo at startup.
examples = [row for row in examples if os.path.isfile(row[0]) and os.path.isfile(row[1])]

# Define the interface.
interface = gr.Interface(
    fn=style_transfer,
    inputs=[content_input, style_input, style_slider, content_slider, style_checkbox],
    outputs=gr.Image(),
    title=title,
    description=description,
    article=article,
    examples=examples or None,
    allow_flagging="never",
)

# Launch the interface.
interface.launch(debug=True)