# Hugging Face Spaces app (Space status when this file was captured: Sleeping)
| import gradio as gr | |
| from style_transfer import StyleTransfer | |
| import tensorflow as tf | |
| from tensorflow.keras import backend as K | |
| import numpy as np | |
def validate_inputs(epochs, steps_per_epoch, image_frequency, alpha, beta, lr):
    """Validate the UI inputs and convert them to the types the optimizer uses.

    Parameters
    ----------
    epochs, steps_per_epoch, image_frequency : int-like
        Iteration counts. Converted with ``int()``, so fractional values are
        truncated toward zero.
    alpha, beta, lr : float-like
        Content weight, style weight and learning rate. Converted with
        ``float()``.

    Returns
    -------
    tuple
        ``(epochs, steps_per_epoch, image_frequency, alpha, beta, lr)`` with
        the first three as ``int`` and the last three as ``float``.

    Raises
    ------
    ValueError
        If any of the three counts is less than 1. ``image_frequency`` is
        used as a modulus by the caller, so 0 would otherwise fail later with
        ``ZeroDivisionError``, and non-positive epoch/step counts would
        silently produce no output. ``ValueError``/``TypeError`` raised by
        ``int()``/``float()`` for unconvertible values propagate unchanged.
    """
    epochs = int(epochs)
    steps_per_epoch = int(steps_per_epoch)
    image_frequency = int(image_frequency)
    for name, value in (
        ("epochs", epochs),
        ("steps_per_epoch", steps_per_epoch),
        ("image_frequency", image_frequency),
    ):
        if value < 1:
            raise ValueError(f"{name} must be at least 1, got {value}")
    alpha = float(alpha)
    beta = float(beta)
    lr = float(lr)
    return epochs, steps_per_epoch, image_frequency, alpha, beta, lr
def stylize_image(
    content_image_path,
    style_image_path,
    epochs,
    steps_per_epoch,
    image_frequency,
    alpha,
    beta,
    lr,
):
    """Stylize the content image with the style image, streaming progress.

    This is a generator: each yielded tuple is rendered by Gradio as an
    intermediate result while optimization continues.

    Parameters
    ----------
    content_image_path : str
        Path to the content image.
    style_image_path : str
        Path to the style image.
    epochs : int
        Number of epochs.
    steps_per_epoch : int
        Number of optimization steps per epoch.
    image_frequency : int
        Within each epoch, emit a preview on every step whose 0-based index
        is a multiple of this value.
    alpha : float
        Content weight.
    beta : float
        Style weight.
    lr : float
        Initial learning rate for Adam, decayed exponentially.

    Yields
    ------
    tuple
        ``(image, epoch, step, loss)`` where ``image`` is a numpy array of the
        current stylized image, ``epoch`` and ``step`` are 1-based, and
        ``loss`` is whatever ``StyleTransfer.update_image`` returned for that
        step. The very last optimization step is always yielded, so the fully
        optimized image is what the UI ends up showing.
    """
    epochs, steps_per_epoch, image_frequency, alpha, beta, lr = validate_inputs(
        epochs, steps_per_epoch, image_frequency, alpha, beta, lr
    )
    style_transfer = StyleTransfer(
        content_image_path=content_image_path,
        style_image_path=style_image_path,
    )
    if style_transfer.model is None:
        # Clear Keras global state before loading the model.
        K.clear_session()
        style_transfer.load_model()

    style_image = style_transfer.load_image(style_transfer.style_image_path)
    content_image = style_transfer.load_image(style_transfer.content_image_path)
    style_target = style_transfer.get_features(style_image, "style")
    content_target = style_transfer.get_features(content_image, "content")
    # NOTE(review): presumably both targets are lists of feature tensors, so
    # "+" concatenates them in content-then-style order — confirm against
    # StyleTransfer.get_features.
    target = content_target + style_target

    # The image being optimized starts as a float copy of the content image.
    image = tf.Variable(tf.cast(content_image, dtype=tf.float32))
    # lr_t = lr * 0.80 ** (t / 100): smooth (non-staircase) exponential decay.
    optimizer = tf.optimizers.Adam(
        tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=lr, decay_steps=100, decay_rate=0.80
        )
    )

    last_epoch = epochs - 1
    last_step = steps_per_epoch - 1
    for epoch in range(epochs):
        for step in range(steps_per_epoch):
            loss = style_transfer.update_image(image, target, alpha, beta, optimizer)
            is_final_step = epoch == last_epoch and step == last_step
            if step % image_frequency == 0 or is_final_step:
                # Convert to a displayable image only when a frame is emitted.
                display_image = style_transfer.tensor_to_image(image)
                yield np.array(display_image), epoch + 1, step + 1, loss
def main():
    """Build the style-transfer Gradio interface and launch it.

    The app is served on ``0.0.0.0:7860`` with Gradio's queue enabled. The
    queue is what lets the generator ``stylize_image`` stream intermediate
    frames to the outputs.
    """
    content_image = gr.Image(type="filepath", label="Content Image", shape=(512, 512))
    style_image = gr.Image(type="filepath", label="Style Image", shape=(512, 512))

    # Count-valued sliders use step=1. Without an explicit step, Gradio derives
    # a fractional default from the range, letting users pick values like 10.3
    # that validate_inputs silently truncates with int().
    epochs = gr.Slider(minimum=1, maximum=20, step=1, label="Epochs", value=10)
    steps_per_epoch = gr.Slider(
        minimum=1, maximum=20, step=1, label="Steps per Epoch", value=10
    )
    image_frequency = gr.Slider(
        minimum=1, maximum=10, step=1, label="Show Image Frequency", value=2
    )
    alpha = gr.Slider(minimum=0, maximum=1, label="Alpha", value=1)
    beta = gr.Slider(minimum=0, maximum=1, label="Beta", value=0.1)
    lr = gr.Slider(minimum=0.1, maximum=100, label="Learning Rate", value=40.0)

    output_image = gr.Image(type="numpy", label="Output Image", shape=(512, 512))
    current_epoch = gr.Number(label="Current Epoch")
    current_step = gr.Number(label="Current Step")
    current_loss = gr.Number(label="Current Loss")

    # Order must match stylize_image's positional parameters / yielded tuple.
    inputs = [
        content_image,
        style_image,
        epochs,
        steps_per_epoch,
        image_frequency,
        alpha,
        beta,
        lr,
    ]
    outputs = [output_image, current_epoch, current_step, current_loss]

    description = """### This is a demo of neural style transfer. Upload a content image and a style image, and see the result! You can play around with the parameters to see how they affect the result.
"""

    # Every content/style pairing with the same parameters:
    # epochs, steps_per_epoch, image_frequency, alpha, beta, lr.
    example_params = [10, 10, 1, 1, 0.1, 30.0]
    examples = [
        [content, style, *example_params]
        for content in ("examples/landscape_1.jpg", "examples/landscape_2.jpg")
        for style in ("examples/van_gogh.jpg", "examples/picaso.jpg")
    ]

    interface = gr.Interface(
        fn=stylize_image,
        inputs=inputs,
        outputs=outputs,
        title="Style Transfer",
        description=description,
        examples=examples,
        theme="gstaff/xkcd",
    )
    interface.queue().launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Script entry point: build the interface and start the Gradio server.
    main()