# VQGAN-Recon / app.py
# Author: Beasto — "Update app.py" (commit 6089e34, verified)
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D, Add
from tensorflow.keras.utils import register_keras_serializable
from PIL import Image
# -------------------------
# Custom SelfAttention Layer
# -------------------------
@register_keras_serializable(package="Custom")
class SelfAttention(Layer):
    """Spatial self-attention block with a residual connection.

    Projects the input feature map to query/key maps at ``filters // 8``
    channels and a value map at full channel width, applies softmax
    attention across all spatial positions, and adds the attended result
    back onto the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel count is fixed from the incoming feature map at build time.
        self.filters = input_shape[-1]
        # NOTE: attribute names f/g/h are preserved — saved-model weight
        # restoration matches sublayers by these names.
        self.f = Conv2D(self.filters // 8, 1, padding="same")
        self.g = Conv2D(self.filters // 8, 1, padding="same")
        self.h = Conv2D(self.filters, 1, padding="same")
        super().build(input_shape)

    def call(self, x):
        query = self.f(x)
        key = self.g(x)
        value = self.h(x)
        dims = tf.shape(query)
        batch, height, width = dims[0], dims[1], dims[2]
        positions = height * width
        # Collapse the spatial grid so attention runs over all positions.
        q2d = tf.reshape(query, [batch, positions, self.filters // 8])
        k2d = tf.reshape(key, [batch, positions, self.filters // 8])
        v2d = tf.reshape(value, [batch, positions, self.filters])
        attention = tf.nn.softmax(tf.matmul(q2d, k2d, transpose_b=True), -1)
        context = tf.matmul(attention, v2d)
        context = tf.reshape(context, [batch, height, width, self.filters])
        # Residual add of the attended features onto the input.
        return Add()([x, context])
# -------------------------
# Helper functions
# -------------------------
def get_codebook_indices(inputs, codebook):
    """Return, for each input vector, the index of the nearest codebook entry.

    Distance is cosine distance (1 - cosine similarity); the result has
    shape ``inputs.shape[:-1]``.
    """
    embedding_dim = tf.shape(codebook)[-1]
    input_shape = tf.shape(inputs)
    # Flatten all but the channel axis and match the codebook's dtype.
    vectors = tf.cast(tf.reshape(inputs, [-1, embedding_dim]), codebook.dtype)
    # Cosine similarity as a dot product of unit-normalized vectors.
    similarity = tf.matmul(
        tf.nn.l2_normalize(vectors, -1),
        tf.nn.l2_normalize(codebook, -1),
        transpose_b=True,
    )
    nearest = tf.argmin(1.0 - similarity, axis=1)
    return tf.reshape(nearest, input_shape[:-1])
def get_embeddings(indices, codebook):
    """Look up codebook vectors for ``indices``, appending the embedding axis."""
    gathered = tf.nn.embedding_lookup(codebook, tf.reshape(indices, [-1]))
    # Output shape = indices.shape + (embedding_dim,).
    target_shape = tf.concat(
        [tf.shape(indices), [tf.shape(codebook)[-1]]], axis=0
    )
    return tf.reshape(gathered, target_shape)
# -------------------------
# Load models and codebook
# -------------------------
# Both saved models reference the custom SelfAttention layer.
_CUSTOM_OBJECTS = {"SelfAttention": SelfAttention}

encoder = tf.keras.models.load_model(
    "epoch_66_encoder_mscoco.keras",
    custom_objects=_CUSTOM_OBJECTS,
)
decoder = tf.keras.models.load_model(
    "epoch_66_decoder_mscoco.keras",
    custom_objects=_CUSTOM_OBJECTS,
)
# Codebook used for vector quantization; presumably shaped
# (num_codes, embedding_dim) — TODO confirm against training code.
codebook = tf.convert_to_tensor(
    np.load("epoch_66_codebook_mscoco.npy").astype(np.float32)
)
# -------------------------
# Gradio inference function
# -------------------------
def reconstruct_image(img):
    """Encode, vector-quantize, and decode ``img``; return a 512x512 PIL image.

    Parameters:
        img: a PIL image from the Gradio input, or None when nothing was
            uploaded.

    Returns:
        The reconstructed image upscaled to 512x512, or None for no input.
    """
    if img is None:
        return None
    # Fix: force 3 channels — uploads may be RGBA, palette, or grayscale,
    # which would otherwise break the (H, W, 3) encoder input.
    img = img.convert("RGB")
    # Resize to the model's 128x128 input and scale pixels to [0, 1] float32.
    arr = (np.asarray(img.resize((128, 128))) / 255.0).astype(np.float32)
    encoded = encoder(np.expand_dims(arr, 0))
    # Pass through the codebook: nearest-neighbour quantization then lookup.
    z_q = get_embeddings(get_codebook_indices(encoded, codebook), codebook)
    # Decode back to pixel space.
    out = decoder(z_q)
    # Clip guards against values outside [0, 1] before the uint8 cast.
    out_img = np.clip(out[0].numpy() * 255, 0, 255).astype(np.uint8)
    # Upscale for display.
    return Image.fromarray(out_img).resize((512, 512), Image.BICUBIC)
# -------------------------
# Gradio UI
# -------------------------
with gr.Blocks() as app:
    gr.Markdown("## Image Reconstruction with VQ-VAE")
    input_img = gr.Image(type="pil", label="Input Image (128x128)")
    output_img = gr.Image(label="Reconstructed Image")
    reconstruct_btn = gr.Button("Reconstruct")
    # Wire the button to the inference function.
    reconstruct_btn.click(
        fn=reconstruct_image, inputs=input_img, outputs=output_img
    )

app.launch()