BillyAggarwal's picture
Update app.py
233e561 verified
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
# Layer-selection placeholders. Each is literally a one-element list holding
# `Ellipsis`; nothing in the visible code reads them.
# NOTE(review): these look like unfilled template values - confirm before relying on them.
CONTENT_LAYER = [...]  # your content layer(s)
STYLE_LAYERS = [...]  # your style layers

# ImageNet-pretrained VGG19 without its classifier head, used purely as a
# feature extractor. Freezing it keeps its weights out of any optimisation.
vgg = tf.keras.applications.VGG19(weights="imagenet", include_top=False)
vgg.trainable = False
def gram_matrix(A):
    """Return the Gram matrix of a feature tensor, normalised by position count.

    Accepted ranks:
      * 4 - ``(batch, H, W, C)``: spatial axes are flattened to ``H * W`` positions.
      * 3 - ``(batch, N, C)``: used as-is.
      * 2 - ``(N, C)``: a leading batch axis of size 1 is added.

    In every case the result has shape ``(batch, C, C)`` and is divided by the
    number of positions (``H * W`` or ``N``).

    Raises:
        ValueError: if the input has any other rank.
    """
    features = tf.convert_to_tensor(A)
    rank = len(features.shape)

    # Bring every supported input to the common (batch, positions, channels) layout
    # so a single matmul below covers all cases.
    if rank == 4:
        batch, height, width, channels = features.shape
        positions = height * width
        features = tf.reshape(features, (batch, positions, channels))
    elif rank == 3:
        positions = features.shape[1]
    elif rank == 2:
        positions = features.shape[0]
        features = tf.expand_dims(features, 0)
    else:
        raise ValueError(f"Unexpected tensor rank for gram_matrix: {features.shape}")

    # transpose_a contracts over the positions axis: (C, N) @ (N, C) -> (C, C).
    channel_correlations = tf.matmul(features, features, transpose_a=True)
    return channel_correlations / tf.cast(positions, tf.float32)
def compute_content_cost(a_C, a_G):
    """Return the mean squared difference between content and generated activations.

    Args:
        a_C: activations of the content image.
        a_G: activations of the generated image (same shape as ``a_C``).
    """
    activation_gap = a_C - a_G
    squared_gap = tf.square(activation_gap)
    return tf.reduce_mean(squared_gap)
def compute_style_cost(a_S, a_G):
    """Return the average Gram-matrix mean-squared error over paired activations.

    ``a_S`` and ``a_G`` are iterated in lockstep with ``zip``; each pair
    contributes the mean squared difference of its two Gram matrices, and the
    total is divided by ``len(a_S)``.

    Note that iterating a bare tensor walks its leading axis, so callers that
    want one term per layer should pass sequences of per-layer activations.
    """
    per_pair_costs = (
        tf.reduce_mean(tf.square(gram_matrix(style_act) - gram_matrix(generated_act)))
        for style_act, generated_act in zip(a_S, a_G)
    )
    # sum() starts from 0 and adds left to right, matching a manual accumulator.
    return sum(per_pair_costs) / len(a_S)
def total_cost(J_content, J_style, alpha=10, beta=40):
    """Combine content and style costs into one weighted objective.

    Args:
        J_content: content cost term.
        J_style: style cost term.
        alpha: weight applied to ``J_content`` (default 10).
        beta: weight applied to ``J_style`` (default 40).

    Returns:
        ``alpha * J_content + beta * J_style``.
    """
    weighted_content = alpha * J_content
    weighted_style = beta * J_style
    return weighted_content + weighted_style
def preprocess(img):
    """Convert an image array into a ``(1, 256, 256, 3)`` float32 tensor in [0, 1].

    Args:
        img: image as a NumPy array accepted by ``PIL.Image.fromarray``
            (e.g. the uint8 arrays Gradio delivers for ``type="numpy"``).

    Returns:
        A ``tf.Tensor`` of dtype float32 with a leading batch axis of size 1.
    """
    # Force 3-channel RGB: grayscale (H, W) or RGBA (H, W, 4) inputs would
    # otherwise yield a tensor whose channel count VGG19 cannot consume.
    pil_img = Image.fromarray(img).convert("RGB").resize((256, 256))
    arr = np.expand_dims(np.array(pil_img) / 255.0, axis=0).astype(np.float32)
    return tf.convert_to_tensor(arr)
def _extract_features(image):
    """Run a [0, 1] RGB image batch through `vgg` and return a list of feature maps."""
    # The ImageNet weights were trained on 0-255 inputs passed through the VGG19
    # preprocessing (channel reorder + mean subtraction), not on raw [0, 1] floats.
    vgg_input = tf.keras.applications.vgg19.preprocess_input(image * 255.0)
    outputs = vgg(vgg_input)
    # Always return a list of per-layer activations. Handing a bare tensor to
    # compute_style_cost would make its zip() iterate over the batch axis instead.
    return list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]


def style_transfer(content, style, steps):
    """Optimise an image to match `content`'s features and `style`'s Gram statistics.

    The generated image starts as a copy of the content image and is updated
    with Adam for `steps` iterations, being clipped back to [0, 1] after each step.

    Args:
        content: content image as a NumPy array.
        style: style image as a NumPy array.
        steps: number of optimisation iterations; coerced to ``int``.

    Returns:
        The stylised result as a 256x256 ``PIL.Image``.

    Raises:
        gr.Error: if either image is missing.
    """
    if content is None or style is None:
        # Surface a readable message in the UI instead of an opaque PIL failure.
        raise gr.Error("Please provide both a content image and a style image.")

    # range() needs an int; UI sliders may hand over a float.
    num_steps = int(steps)

    content_tensor = preprocess(content)
    style_tensor = preprocess(style)

    # Targets are fixed, so compute them once outside the optimisation loop.
    # Content is compared on the last (deepest) feature map returned.
    content_target = _extract_features(content_tensor)[-1]
    style_targets = _extract_features(style_tensor)

    generated_image = tf.Variable(content_tensor)
    opt = tf.keras.optimizers.Adam(learning_rate=0.01)

    for _ in range(num_steps):
        with tf.GradientTape() as tape:
            generated_features = _extract_features(generated_image)
            J_style = compute_style_cost(style_targets, generated_features)
            J_content = compute_content_cost(content_target, generated_features[-1])
            J = total_cost(J_content, J_style, alpha=10, beta=40)
        grad = tape.gradient(J, generated_image)
        opt.apply_gradients([(grad, generated_image)])
        # Keep pixel values in the valid [0, 1] range after every update.
        generated_image.assign(tf.clip_by_value(generated_image, 0.0, 1.0))

    # Round before casting: plain truncation can drop a pixel by one level
    # (e.g. 253.99998 -> 253) due to float32 round-off.
    out_img = np.rint(generated_image[0].numpy() * 255).astype("uint8")
    return Image.fromarray(out_img)
# Gradio UI: two image uploads (delivered as NumPy arrays) and an
# iteration-count slider feed `style_transfer`; its PIL result is displayed.
content_input = gr.Image(type="numpy", label="Content Image")
style_input = gr.Image(type="numpy", label="Style Image")
steps_input = gr.Slider(
    minimum=50,
    maximum=2000,
    value=1000,
    step=50,
    label="Number of Iterations",
)
result_output = gr.Image(type="pil", label="Stylized Image")

demo = gr.Interface(
    fn=style_transfer,
    inputs=[content_input, style_input, steps_input],
    outputs=result_output,
)
demo.launch()