"""Streamlit app: interpolate between two images in a VQ-VAE latent space.

Two uploaded images are encoded, their latents linearly interpolated,
each intermediate latent snapped to its nearest codebook vector, decoded,
and the decoded frames are stitched into an MP4 shown in the browser.
"""

import io
import os
import tempfile

import imageio
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Add, Conv2D, Layer
from tensorflow.keras.utils import register_keras_serializable


@register_keras_serializable(package="Custom")
class SelfAttention(Layer):
    """Self-attention block over spatial positions with a residual connection.

    f/g project the input to a reduced C//8 query/key space; h keeps the
    full channel count for the value projection (SAGAN-style layout).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.filters = input_shape[-1]
        self.f = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.g = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.h = Conv2D(self.filters, kernel_size=1, padding='same')
        super().build(input_shape)

    def call(self, x):
        f = self.f(x)  # (B, H, W, C//8)
        g = self.g(x)
        h = self.h(x)  # (B, H, W, C)

        shape_f = tf.shape(f)
        B, H, W = shape_f[0], shape_f[1], shape_f[2]

        # Flatten spatial dims so attention runs over N = H*W positions.
        f_flat = tf.reshape(f, [B, H * W, self.filters // 8])
        g_flat = tf.reshape(g, [B, H * W, self.filters // 8])
        h_flat = tf.reshape(h, [B, H * W, self.filters])

        # Attention map over positions, then weighted sum of values.
        beta = tf.nn.softmax(tf.matmul(f_flat, g_flat, transpose_b=True), axis=-1)  # (B, N, N)
        o = tf.matmul(beta, h_flat)  # (B, N, C)
        o = tf.reshape(o, [B, H, W, self.filters])

        return Add()([x, o])  # residual connection

    def get_config(self):
        # No extra constructor arguments beyond the base Layer's.
        return super().get_config()


def frames_to_video_bytes(frames, fps=30):
    """Encode a sequence of HxWx3 uint8 frames as MP4 and return raw bytes.

    FIX: imageio's ffmpeg backend writes to a real file path, not an
    in-memory stream, so a BytesIO target fails — encode via a temporary
    file and read the bytes back.
    """
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        path = tmp.name
    try:
        writer = imageio.get_writer(path, format="mp4", fps=fps)
        try:
            for frame in frames:
                writer.append_data(frame)
        finally:
            writer.close()  # flush the encoder even if a frame append fails
        with open(path, "rb") as fh:
            return fh.read()
    finally:
        os.remove(path)  # never leave the temp file behind


def get_codebook_indices(inputs, codebook):
    """Map each D-dim vector in `inputs` (shape [..., D]) to the index of
    its nearest codebook row under cosine distance.

    Returns an integer tensor with `inputs`' shape minus the last axis.
    """
    embedding_dim = tf.shape(codebook)[-1]
    input_shape = tf.shape(inputs)
    flat_inputs = tf.reshape(inputs, [-1, embedding_dim])
    flat_inputs = tf.cast(flat_inputs, dtype=codebook.dtype)

    # Cosine distance = 1 - cosine similarity of L2-normalised vectors.
    flat_inputs_norm = tf.nn.l2_normalize(flat_inputs, axis=1)  # [N, D]
    embeddings_norm = tf.nn.l2_normalize(codebook, axis=1)      # [K, D]
    cosine_sim = tf.matmul(flat_inputs_norm, embeddings_norm, transpose_b=True)  # [N, K]
    distances = 1.0 - cosine_sim

    encoding_indices = tf.argmin(distances, axis=1)
    return tf.reshape(encoding_indices, input_shape[:-1])


def get_embeddings(indices, codebook):
    """Look up codebook vectors for `indices`; output shape is
    indices.shape + [embedding_dim]."""
    flat_indices = tf.reshape(indices, [-1])
    flat_embeddings = tf.nn.embedding_lookup(codebook, flat_indices)
    out_shape = tf.concat([tf.shape(indices), [tf.shape(codebook)[-1]]], axis=0)
    return tf.reshape(flat_embeddings, out_shape)


def latent_interpolation(t, a, b):
    """Linear interpolation between `a` and `b`; `t` is a percentage in
    [0, 100] (t=0 -> a, t=100 -> b)."""
    t = t / 100
    return (1 - t) * a + t * b


@st.cache_resource
def load_all():
    """Load the encoder/decoder models and the VQ codebook once
    (cached across reruns by Streamlit)."""
    encoder = tf.keras.models.load_model(
        'src/epoch_66_encoder_mscoco.keras',
        custom_objects={'SelfAttention': SelfAttention})
    decoder = tf.keras.models.load_model(
        'src/epoch_66_decoder_mscoco.keras',
        custom_objects={'SelfAttention': SelfAttention})
    codebook = np.load('src/epoch_66_codebook_mscoco.npy')
    return encoder, decoder, tf.convert_to_tensor(codebook, dtype=tf.float32)


def _load_and_encode(uploaded):
    """Read an uploaded image, normalise it and run it through the encoder.

    FIX: force RGB — uploaded PNGs may be RGBA or greyscale, which would
    break the 3-channel encoder input.
    """
    img = Image.open(uploaded).convert("RGB").resize((128, 128))
    arr = np.array(img) / 255.0
    return encoder(np.expand_dims(arr, 0))


encoder, decoder, codebook = load_all()

col1, col2 = st.columns(2)
with col1:
    img1 = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="uploader1")
with col2:
    img2 = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="uploader2")

if img1 is not None and img2 is not None:
    encoded_img1 = _load_and_encode(img1)
    encoded_img2 = _load_and_encode(img2)

    frames = []
    # FIX: sweep to 100 inclusive so the video actually ends on the second
    # image (range(0, 100, 5) stopped at t=95).
    for i in range(0, 101, 5):
        lerped = latent_interpolation(i, encoded_img1, encoded_img2)
        out = decoder(get_embeddings(get_codebook_indices(lerped, codebook), codebook))
        # FIX: clip before the uint8 cast — the decoder can overshoot
        # [0, 1] slightly, and a bare cast wraps around instead of
        # saturating, producing speckled pixels.
        frame = np.array(np.clip(out[0], 0.0, 1.0) * 255, dtype=np.uint8)
        frames.append(frame)

    video_bytes = frames_to_video_bytes(frames, 12)
    st.video(video_bytes)