Style-Transfer / app.py
hari31416's picture
Update app.py
2c6ab21
import gradio as gr
from style_transfer import StyleTransfer
import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np
def validate_inputs(epochs, steps_per_epoch, image_frequency, alpha, beta, lr):
    """Validate the UI inputs and convert them to the types the training loop expects.

    Gradio sliders deliver numbers as floats and API clients may send numeric
    strings, so each loop count is parsed through ``float`` first and then
    truncated to ``int`` (``int("4.0")`` alone would fail).

    Parameters
    ----------
    epochs, steps_per_epoch, image_frequency : int, float or str
        Loop counts; each must be a finite number >= 1. ``image_frequency``
        is later used as a modulus, so 0 would otherwise raise
        ``ZeroDivisionError`` mid-run.
    alpha, beta : int, float or str
        Content and style loss weights; must be finite and non-negative.
    lr : int, float or str
        Initial learning rate; must be finite and strictly positive.

    Returns
    -------
    tuple
        ``(epochs, steps_per_epoch, image_frequency, alpha, beta, lr)`` typed
        as ``(int, int, int, float, float, float)``.

    Raises
    ------
    ValueError
        If a value cannot be parsed as a number or is outside its valid range.
    """
    infinity = float("inf")

    counts = []
    for name, value in (
        ("epochs", epochs),
        ("steps_per_epoch", steps_per_epoch),
        ("image_frequency", image_frequency),
    ):
        number = float(value)
        # Chained comparisons are False for NaN, so NaN and inf are rejected too.
        if not 1 <= number < infinity:
            raise ValueError(f"{name} must be a finite number >= 1, got {value!r}")
        counts.append(int(number))
    epochs, steps_per_epoch, image_frequency = counts

    alpha = float(alpha)
    beta = float(beta)
    lr = float(lr)
    for name, weight in (("alpha", alpha), ("beta", beta)):
        if not 0 <= weight < infinity:
            raise ValueError(f"{name} must be a finite number >= 0, got {weight!r}")
    if not 0 < lr < infinity:
        raise ValueError(f"lr must be a finite number > 0, got {lr!r}")

    return epochs, steps_per_epoch, image_frequency, alpha, beta, lr
def stylize_image(
    content_image_path,
    style_image_path,
    epochs,
    steps_per_epoch,
    image_frequency,
    alpha,
    beta,
    lr,
):
    """Run neural style transfer, streaming intermediate results to the UI.

    This is a generator: Gradio consumes it and updates the outputs with every
    tuple it yields, so the image refreshes while the optimisation runs.

    Parameters
    ----------
    content_image_path : str or None
        Path to the content image; ``None`` when nothing was uploaded.
    style_image_path : str or None
        Path to the style image; ``None`` when nothing was uploaded.
    epochs : int or float
        Number of epochs (outer loop iterations).
    steps_per_epoch : int or float
        Optimisation steps per epoch.
    image_frequency : int or float
        Emit an intermediate image every ``image_frequency`` steps. The last
        step of every epoch is always emitted as well, so the final result
        is shown regardless of this value.
    alpha : float
        Content weight.
    beta : float
        Style weight.
    lr : float
        Initial learning rate of the exponentially decaying Adam schedule.

    Yields
    ------
    tuple
        ``(image, epoch, step, loss)``: the current image as a NumPy array,
        the 1-based epoch and step numbers, and the value returned by
        ``StyleTransfer.update_image`` for that step.

    Raises
    ------
    gr.Error
        If either image is missing.
    ValueError
        If a numeric setting is out of range (see ``validate_inputs``).
    """
    if content_image_path is None or style_image_path is None:
        raise gr.Error("Please upload both a content image and a style image.")

    epochs, steps_per_epoch, image_frequency, alpha, beta, lr = validate_inputs(
        epochs, steps_per_epoch, image_frequency, alpha, beta, lr
    )

    style_transfer = StyleTransfer(
        content_image_path=content_image_path,
        style_image_path=style_image_path,
    )
    if style_transfer.model is None:
        # Reset Keras' global state (e.g. leftovers from earlier requests)
        # before loading the model.
        K.clear_session()
        _ = style_transfer.load_model()

    style_image = style_transfer.load_image(style_transfer.style_image_path)
    content_image = style_transfer.load_image(style_transfer.content_image_path)
    style_target = style_transfer.get_features(style_image, "style")
    content_target = style_transfer.get_features(content_image, "content")
    # NOTE(review): content features come first; presumably update_image
    # relies on this ordering -- confirm in style_transfer.py.
    target = content_target + style_target

    # The optimised variable starts as a float32 copy of the content image.
    image = tf.Variable(tf.cast(content_image, dtype=tf.float32))
    optimizer = tf.optimizers.Adam(
        tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=lr, decay_steps=100, decay_rate=0.80
        )
    )

    last_step = steps_per_epoch - 1
    for epoch in range(epochs):
        for step in range(steps_per_epoch):
            loss = style_transfer.update_image(image, target, alpha, beta, optimizer)
            # Also emit the last step of each epoch: otherwise, whenever
            # last_step is not a multiple of image_frequency, the final
            # optimised image is never shown to the user.
            if step % image_frequency == 0 or step == last_step:
                # Convert only when the frame is actually displayed.
                display_image = style_transfer.tensor_to_image(image)
                yield np.array(display_image), epoch + 1, step + 1, loss
def main():
    """Build the style-transfer Gradio interface and launch the web server.

    ``stylize_image`` is a generator that streams intermediate frames, so the
    interface is launched with ``queue()`` enabled. The server binds to
    ``0.0.0.0`` on port 7860.
    """
    content_image = gr.Image(type="filepath", label="Content Image", shape=(512, 512))
    style_image = gr.Image(type="filepath", label="Style Image", shape=(512, 512))
    # step=1 keeps integer-valued settings integral. Without it Gradio derives
    # a fractional default step from the slider range, and the fractional
    # part would be silently truncated by validate_inputs.
    epochs = gr.Slider(minimum=1, maximum=20, step=1, label="Epochs", value=10)
    steps_per_epoch = gr.Slider(
        minimum=1, maximum=20, step=1, label="Steps per Epoch", value=10
    )
    image_frequency = gr.Slider(
        minimum=1, maximum=10, step=1, label="Show Image Frequency", value=2
    )
    alpha = gr.Slider(minimum=0, maximum=1, label="Alpha", value=1)
    beta = gr.Slider(minimum=0, maximum=1, label="Beta", value=0.1)
    lr = gr.Slider(minimum=0.1, maximum=100, label="Learning Rate", value=40.0)

    output_image = gr.Image(type="numpy", label="Output Image", shape=(512, 512))
    current_epoch = gr.Number(label="Current Epoch")
    current_step = gr.Number(label="Current Step")
    current_loss = gr.Number(label="Current Loss")

    # Order must match stylize_image's positional parameters / yielded tuple.
    inputs = [
        content_image,
        style_image,
        epochs,
        steps_per_epoch,
        image_frequency,
        alpha,
        beta,
        lr,
    ]
    outputs = [output_image, current_epoch, current_step, current_loss]

    description = """### This is a demo of neural style transfer. Upload a content image and a style image, and see the result! You can play around with the parameters to see how they affect the result.
"""

    # One example per content/style pairing, all sharing the same settings:
    # epochs, steps_per_epoch, image_frequency, alpha, beta, lr.
    example_contents = ("examples/landscape_1.jpg", "examples/landscape_2.jpg")
    example_styles = ("examples/van_gogh.jpg", "examples/picaso.jpg")
    example_settings = [10, 10, 1, 1, 0.1, 30.0]
    examples = [
        [content, style, *example_settings]
        for content in example_contents
        for style in example_styles
    ]

    interface = gr.Interface(
        fn=stylize_image,
        inputs=inputs,
        outputs=outputs,
        title="Style Transfer",
        description=description,
        examples=examples,
        theme="gstaff/xkcd",
    )
    interface.queue().launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Launch the Gradio server only when this file is executed directly;
    # importing the module does not start the app.
    main()