"""Gradio app for VQ-VAE image reconstruction.

Loads a pre-trained encoder/decoder pair and a codebook, quantizes the
encoder output against the codebook (nearest neighbour under cosine
distance), decodes it, and serves the round trip through a Gradio UI.
"""

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Add, Conv2D, Layer
from tensorflow.keras.utils import register_keras_serializable


# -------------------------
# Custom SelfAttention Layer
# -------------------------
@register_keras_serializable(package="Custom")
class SelfAttention(Layer):
    """Self-attention block with a residual connection.

    Projects the input to query/key maps at 1/8 channel width, computes an
    (H*W, H*W) attention matrix over spatial positions, applies it to a
    full-width value projection, and adds the result back onto the input.
    Registered as serializable so saved models can be reloaded via
    ``custom_objects``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel count is fixed at build time; query/key are reduced 8x
        # to keep the (H*W, H*W) attention matmul cheap.
        self.filters = input_shape[-1]
        self.f = Conv2D(self.filters // 8, 1, padding="same")  # query
        self.g = Conv2D(self.filters // 8, 1, padding="same")  # key
        self.h = Conv2D(self.filters, 1, padding="same")       # value
        super().build(input_shape)

    def call(self, x):
        f = self.f(x)
        g = self.g(x)
        h = self.h(x)
        # Dynamic shapes so the layer works for any spatial size.
        B, H, W = tf.shape(f)[0], tf.shape(f)[1], tf.shape(f)[2]
        # Flatten spatial dims: attention runs over all H*W positions.
        f_flat = tf.reshape(f, [B, H * W, self.filters // 8])
        g_flat = tf.reshape(g, [B, H * W, self.filters // 8])
        h_flat = tf.reshape(h, [B, H * W, self.filters])
        # Softmax over key positions -> attention weights.
        beta = tf.nn.softmax(tf.matmul(f_flat, g_flat, transpose_b=True), -1)
        o = tf.matmul(beta, h_flat)
        o = tf.reshape(o, [B, H, W, self.filters])
        # Residual connection.
        return Add()([x, o])


# -------------------------
# Helper functions
# -------------------------
def get_codebook_indices(inputs, codebook):
    """Return, for each spatial vector in `inputs`, the index of the
    nearest codebook entry under cosine distance.

    `inputs` is (..., embedding_dim); the result has shape `inputs.shape[:-1]`.
    """
    embedding_dim = tf.shape(codebook)[-1]
    input_shape = tf.shape(inputs)
    flat = tf.reshape(inputs, [-1, embedding_dim])
    flat = tf.cast(flat, codebook.dtype)
    # Cosine distance: 1 - cosine similarity of L2-normalized vectors.
    flat_norm = tf.nn.l2_normalize(flat, -1)
    code_norm = tf.nn.l2_normalize(codebook, -1)
    sim = tf.matmul(flat_norm, code_norm, transpose_b=True)
    dist = 1.0 - sim
    indices = tf.argmin(dist, axis=1)
    return tf.reshape(indices, input_shape[:-1])


def get_embeddings(indices, codebook):
    """Look up codebook vectors for `indices`.

    Result shape is `indices.shape + (embedding_dim,)`.
    """
    flat = tf.reshape(indices, [-1])
    embeds = tf.nn.embedding_lookup(codebook, flat)
    out_shape = tf.concat([tf.shape(indices), [tf.shape(codebook)[-1]]], 0)
    return tf.reshape(embeds, out_shape)


# -------------------------
# Load models and codebook
# -------------------------
# SelfAttention must be passed explicitly so Keras can deserialize the
# custom layer inside the saved models.
encoder = tf.keras.models.load_model(
    "epoch_66_encoder_mscoco.keras",
    custom_objects={"SelfAttention": SelfAttention},
)
decoder = tf.keras.models.load_model(
    "epoch_66_decoder_mscoco.keras",
    custom_objects={"SelfAttention": SelfAttention},
)
codebook = np.load("epoch_66_codebook_mscoco.npy").astype(np.float32)
codebook = tf.convert_to_tensor(codebook)


# -------------------------
# Gradio inference function
# -------------------------
def reconstruct_image(img):
    """Encode, quantize, decode, and upscale a PIL image.

    Returns a 512x512 PIL image, or None when no input was given.
    """
    if img is None:
        return None

    # Force 3 channels so grayscale/RGBA uploads don't feed the encoder an
    # array with the wrong channel count, and cast to float32 — the model's
    # expected input dtype (np.array(...)/255.0 would silently be float64).
    img = img.convert("RGB").resize((128, 128))
    arr = np.asarray(img, dtype=np.float32) / 255.0

    encoded = encoder(np.expand_dims(arr, 0))

    # Snap the encoder output to its nearest codebook vectors.
    z_q = get_embeddings(get_codebook_indices(encoded, codebook), codebook)

    out = decoder(z_q)
    out_img = np.clip(out[0].numpy() * 255, 0, 255).astype(np.uint8)
    # Upscale the 128x128 reconstruction for display.
    return Image.fromarray(out_img).resize((512, 512), Image.BICUBIC)


# -------------------------
# Gradio UI
# -------------------------
with gr.Blocks() as app:
    gr.Markdown("## Image Reconstruction with VQ-VAE")
    input_img = gr.Image(type="pil", label="Input Image (128x128)")
    output_img = gr.Image(label="Reconstructed Image")
    btn = gr.Button("Reconstruct")
    btn.click(reconstruct_image, input_img, output_img)

app.launch()