# Hugging Face Space: VQ-VAE image reconstruction demo
# (scrape artifact "Spaces: Sleeping" removed — it was page chrome, not code)
# --- Third-party dependencies (grouped and alphabetized) ---
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Add, Conv2D, Layer
from tensorflow.keras.utils import register_keras_serializable

# -------------------------
# Custom SelfAttention Layer
# -------------------------
@register_keras_serializable()
class SelfAttention(Layer):
    """SAGAN-style self-attention block with a residual connection.

    Computes attention over all spatial positions of a (B, H, W, C)
    feature map: queries/keys use a reduced channel count (C // 8),
    values keep C channels, and the attended output is added back to
    the input, so the output shape equals the input shape.

    Registered via ``register_keras_serializable`` so saved models can
    be reloaded without passing ``custom_objects`` (passing it still
    works and remains harmless).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel count comes from the incoming feature map.
        self.filters = int(input_shape[-1])
        # Guard: with fewer than 8 channels, C // 8 would be 0 and
        # Conv2D would fail at build time. Use at least 1 channel.
        self.reduced_filters = max(1, self.filters // 8)
        self.f = Conv2D(self.reduced_filters, 1, padding="same")  # query
        self.g = Conv2D(self.reduced_filters, 1, padding="same")  # key
        self.h = Conv2D(self.filters, 1, padding="same")          # value
        super().build(input_shape)

    def call(self, x):
        f = self.f(x)  # (B, H, W, C//8) queries
        g = self.g(x)  # (B, H, W, C//8) keys
        h = self.h(x)  # (B, H, W, C)    values

        B, H, W = tf.shape(f)[0], tf.shape(f)[1], tf.shape(f)[2]
        f_flat = tf.reshape(f, [B, H * W, self.reduced_filters])
        g_flat = tf.reshape(g, [B, H * W, self.reduced_filters])
        h_flat = tf.reshape(h, [B, H * W, self.filters])

        # Attention map over all HW positions: softmax(Q K^T).
        beta = tf.nn.softmax(tf.matmul(f_flat, g_flat, transpose_b=True), -1)
        o = tf.matmul(beta, h_flat)
        o = tf.reshape(o, [B, H, W, self.filters])

        # Residual connection. Plain tensor addition replaces the
        # original Add()([x, o]), which instantiated a fresh Keras
        # layer object on every forward call.
        return x + o
| # ------------------------- | |
| # Helper functions | |
| # ------------------------- | |
| def get_codebook_indices(inputs, codebook): | |
| embedding_dim = tf.shape(codebook)[-1] | |
| input_shape = tf.shape(inputs) | |
| flat = tf.reshape(inputs, [-1, embedding_dim]) | |
| flat = tf.cast(flat, codebook.dtype) | |
| flat_norm = tf.nn.l2_normalize(flat, -1) | |
| code_norm = tf.nn.l2_normalize(codebook, -1) | |
| sim = tf.matmul(flat_norm, code_norm, transpose_b=True) | |
| dist = 1.0 - sim | |
| indices = tf.argmin(dist, axis=1) | |
| return tf.reshape(indices, input_shape[:-1]) | |
def get_embeddings(indices, codebook):
    """Replace each code index with its codebook vector.

    Inverse of ``get_codebook_indices``: output shape is the shape of
    `indices` with the codebook's embedding dimension appended.
    """
    flat_idx = tf.reshape(indices, [-1])
    looked_up = tf.nn.embedding_lookup(codebook, flat_idx)
    # Original index shape + trailing embedding axis.
    target_shape = tf.concat([tf.shape(indices), [tf.shape(codebook)[-1]]], 0)
    return tf.reshape(looked_up, target_shape)
| # ------------------------- | |
| # Load models and codebook | |
| # ------------------------- | |
| encoder = tf.keras.models.load_model( | |
| "epoch_66_encoder_mscoco.keras", | |
| custom_objects={"SelfAttention": SelfAttention}, | |
| ) | |
| decoder = tf.keras.models.load_model( | |
| "epoch_66_decoder_mscoco.keras", | |
| custom_objects={"SelfAttention": SelfAttention}, | |
| ) | |
| codebook = np.load("epoch_66_codebook_mscoco.npy").astype(np.float32) | |
| codebook = tf.convert_to_tensor(codebook) | |
| # ------------------------- | |
| # Gradio inference function | |
| # ------------------------- | |
| def reconstruct_image(img): | |
| if img is None: | |
| return None | |
| # Resize to model input | |
| img = np.array(img.resize((128, 128))) / 255.0 | |
| encoded = encoder(np.expand_dims(img, 0)) | |
| # Pass through codebook | |
| z_q = get_embeddings(get_codebook_indices(encoded, codebook), codebook) | |
| # Decode | |
| out = decoder(z_q) | |
| out_img = np.clip(out[0].numpy() * 255, 0, 255).astype(np.uint8) | |
| out_img_pil = Image.fromarray(out_img) | |
| out_img_upscaled = out_img_pil.resize((512, 512), Image.BICUBIC) | |
| return out_img_upscaled | |
| # ------------------------- | |
| # Gradio UI | |
| # ------------------------- | |
| with gr.Blocks() as app: | |
| gr.Markdown("## Image Reconstruction with VQ-VAE") | |
| input_img = gr.Image(type="pil", label="Input Image (128x128)") | |
| output_img = gr.Image(label="Reconstructed Image") | |
| btn = gr.Button("Reconstruct") | |
| btn.click(reconstruct_image, input_img, output_img) | |
| app.launch() | |