import gradio as gr
from style_transfer import StyleTransfer
import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np


def validate_inputs(epochs, steps_per_epoch, image_frequency, alpha, beta, lr):
    """Coerce raw UI values to the types ``stylize_image`` expects.

    Gradio sliders may deliver numbers as floats (e.g. ``10.0``), so the
    count-like parameters are converted with ``int()`` and the weights /
    learning rate with ``float()``.

    Parameters
    ----------
    epochs, steps_per_epoch, image_frequency : int-like
        Converted with ``int()``. ``image_frequency`` is clamped to at least 1
        because it is used as a modulus when deciding which steps to display.
    alpha, beta, lr : float-like
        Converted with ``float()``.

    Returns
    -------
    tuple
        ``(epochs, steps_per_epoch, image_frequency, alpha, beta, lr)`` with
        the converted values, in the same order as the arguments.
    """
    epochs = int(epochs)
    steps_per_epoch = int(steps_per_epoch)
    # A frequency of 0 would make ``step % image_frequency`` raise
    # ZeroDivisionError, so treat anything below 1 as "every step".
    image_frequency = max(1, int(image_frequency))
    alpha = float(alpha)
    beta = float(beta)
    lr = float(lr)
    return epochs, steps_per_epoch, image_frequency, alpha, beta, lr


def stylize_image(
    content_image_path,
    style_image_path,
    epochs,
    steps_per_epoch,
    image_frequency,
    alpha,
    beta,
    lr,
):
    """Run neural style transfer and stream intermediate results.

    This is a generator: each yielded tuple is sent to the Gradio outputs as it
    is produced, so the user can watch the image evolve.

    Parameters
    ----------
    content_image_path : str
        Path to the content image.
    style_image_path : str
        Path to the style image.
    epochs : int
        Number of epochs.
    steps_per_epoch : int
        Number of optimization steps per epoch.
    image_frequency : int
        Within each epoch, yield the current image on steps 0, f, 2f, ...
        The very last step of the last epoch is always yielded so the final
        result is displayed.
    alpha : float
        Content weight.
    beta : float
        Style weight.
    lr : float
        Initial learning rate for Adam; decayed by ``ExponentialDecay``
        (rate 0.80 per 100 optimizer steps).

    Yields
    ------
    tuple
        ``(image_array, epoch, step, loss)`` where ``image_array`` is
        ``np.array(style_transfer.tensor_to_image(image))``, ``epoch`` and
        ``step`` are 1-based counters, and ``loss`` is the value returned by
        ``StyleTransfer.update_image`` (presumably a scalar tensor — confirm).
    """
    epochs, steps_per_epoch, image_frequency, alpha, beta, lr = validate_inputs(
        epochs, steps_per_epoch, image_frequency, alpha, beta, lr
    )

    style_transfer = StyleTransfer(
        content_image_path=content_image_path,
        style_image_path=style_image_path,
    )
    if style_transfer.model is None:
        # Reset Keras global state before (re)loading the model.
        K.clear_session()
        _ = style_transfer.load_model()

    style_image = style_transfer.load_image(style_transfer.style_image_path)
    content_image = style_transfer.load_image(style_transfer.content_image_path)

    # Targets are computed once, before the optimization loop.
    style_target = style_transfer.get_features(style_image, "style")
    content_target = style_transfer.get_features(content_image, "content")
    target = content_target + style_target

    # Start from the content image. ``image`` is a tf.Variable handed to
    # ``update_image`` each step (presumably updated in place there).
    image = tf.Variable(tf.cast(content_image, dtype=tf.float32))

    optimizer = tf.optimizers.Adam(
        tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=lr, decay_steps=100, decay_rate=0.80
        )
    )

    last_epoch = epochs - 1
    last_step = steps_per_epoch - 1
    for epoch in range(epochs):
        for step in range(steps_per_epoch):
            loss = style_transfer.update_image(image, target, alpha, beta, optimizer)
            is_final_step = epoch == last_epoch and step == last_step
            if step % image_frequency == 0 or is_final_step:
                # Convert to a displayable image only for frames actually sent.
                display_image = style_transfer.tensor_to_image(image)
                yield np.array(display_image), epoch + 1, step + 1, loss


def main():
    """Build the Gradio interface for ``stylize_image`` and launch the server."""
    content_image = gr.Image(type="filepath", label="Content Image", shape=(512, 512))
    style_image = gr.Image(type="filepath", label="Style Image", shape=(512, 512))
    epochs = gr.Slider(minimum=1, maximum=20, label="Epochs", value=10)
    steps_per_epoch = gr.Slider(
        minimum=1, maximum=20, label="Steps per Epoch", value=10
    )
    image_frequency = gr.Slider(
        minimum=1, maximum=10, label="Show Image Frequency", value=2
    )
    alpha = gr.Slider(minimum=0, maximum=1, label="Alpha", value=1)
    beta = gr.Slider(minimum=0, maximum=1, label="Beta", value=0.1)
    lr = gr.Slider(minimum=0.1, maximum=100, label="Learning Rate", value=40.0)

    output_image = gr.Image(type="numpy", label="Output Image", shape=(512, 512))
    current_epoch = gr.Number(label="Current Epoch")
    current_step = gr.Number(label="Current Step")
    current_loss = gr.Number(label="Current Loss")

    # Order must match the positional parameters of ``stylize_image``.
    inputs = [
        content_image,
        style_image,
        epochs,
        steps_per_epoch,
        image_frequency,
        alpha,
        beta,
        lr,
    ]
    # Order must match the tuple yielded by ``stylize_image``.
    outputs = [output_image, current_epoch, current_step, current_loss]

    # Heading on its own line so the body text is not rendered as part of it.
    description = """### This is a demo of neural style transfer.
Upload a content image and a style image, and see the result!
You can play around with the parameters to see how they affect the result.
"""

    interface = gr.Interface(
        fn=stylize_image,
        inputs=inputs,
        outputs=outputs,
        title="Style Transfer",
        description=description,
        examples=[
            [
                "examples/landscape_1.jpg",
                "examples/van_gogh.jpg",
                10,
                10,
                1,
                1,
                0.1,
                30.0,
            ],
            [
                "examples/landscape_1.jpg",
                "examples/picaso.jpg",
                10,
                10,
                1,
                1,
                0.1,
                30.0,
            ],
            [
                "examples/landscape_2.jpg",
                "examples/van_gogh.jpg",
                10,
                10,
                1,
                1,
                0.1,
                30.0,
            ],
            [
                "examples/landscape_2.jpg",
                "examples/picaso.jpg",
                10,
                10,
                1,
                1,
                0.1,
                30.0,
            ],
        ],
        theme="gstaff/xkcd",
    )

    # Queueing is needed so the generator's intermediate yields are streamed.
    interface.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Run Gradio app
    main()