# Streamlit demo: VQ-VAE latent-space interpolation between two uploaded images.
import io

import imageio
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Add, Conv2D, Layer
from tensorflow.keras.utils import register_keras_serializable
@register_keras_serializable()
class SelfAttention(Layer):
    """Self-attention layer (SAGAN-style) over 2D feature maps.

    Projects the input to query/key maps at C//8 channels and a value map
    at C channels, computes full pairwise attention over the H*W spatial
    positions, and adds the attended output back onto the input (residual
    connection). Output shape equals input shape (B, H, W, C).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel count is taken from the input, so the layer adapts to any C.
        self.filters = input_shape[-1]
        self.f = Conv2D(self.filters // 8, kernel_size=1, padding='same')  # query
        self.g = Conv2D(self.filters // 8, kernel_size=1, padding='same')  # key
        self.h = Conv2D(self.filters, kernel_size=1, padding='same')       # value
        super().build(input_shape)

    def call(self, x):
        f = self.f(x)  # (B, H, W, C//8) query
        g = self.g(x)  # (B, H, W, C//8) key
        h = self.h(x)  # (B, H, W, C)    value
        shape_f = tf.shape(f)
        B, H, W = shape_f[0], shape_f[1], shape_f[2]
        # Flatten spatial dims: each of the N = H*W positions attends to all.
        f_flat = tf.reshape(f, [B, H * W, self.filters // 8])
        g_flat = tf.reshape(g, [B, H * W, self.filters // 8])
        h_flat = tf.reshape(h, [B, H * W, self.filters])
        beta = tf.nn.softmax(tf.matmul(f_flat, g_flat, transpose_b=True), axis=-1)  # (B, N, N)
        o = tf.matmul(beta, h_flat)  # (B, N, C)
        o = tf.reshape(o, [B, H, W, self.filters])
        # Residual connection. Plain addition instead of Add()([x, o]):
        # the original created a fresh Add layer on every call, which is a
        # Keras anti-pattern (layers should be built once, not per call).
        return x + o

    def get_config(self):
        # No constructor arguments beyond the base Layer's, so the base
        # config is sufficient for (de)serialization.
        return super().get_config()
def frames_to_video_bytes(frames, fps=30):
    """Encode a sequence of frames into MP4 bytes in memory.

    Parameters
    ----------
    frames : iterable of np.ndarray
        Frames as (H, W, 3) uint8 arrays.
    fps : int
        Playback frame rate of the output video.

    Returns
    -------
    bytes
        The encoded MP4 container.
    """
    buffer = io.BytesIO()
    # Context manager guarantees the writer is closed (and the stream
    # finalized) even if append_data raises partway through — the original
    # leaked the writer on error.
    # NOTE(review): writing MP4 into a file-like object depends on the
    # installed imageio backend supporting it — confirm with the deployed
    # imageio/ffmpeg plugin versions.
    with imageio.get_writer(buffer, format="mp4", fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)
    return buffer.getvalue()
def get_codebook_indices(inputs, codebook):
    """Map continuous latents to the indices of their nearest codebook entries.

    Nearest neighbours are found by cosine distance (1 - cosine similarity),
    not Euclidean distance; the dead commented-out Euclidean variant (which
    referenced a nonexistent ``self.embeddings``) has been removed.

    Parameters
    ----------
    inputs : tensor of shape (..., D)
        Continuous latent vectors; D must equal the codebook embedding dim.
    codebook : tensor of shape (K, D)
        The K learned embedding vectors.

    Returns
    -------
    Integer tensor of shape ``inputs.shape[:-1]`` holding indices into codebook.
    """
    embeddings = codebook
    embedding_dim = tf.shape(codebook)[-1]
    input_shape = tf.shape(inputs)
    # Collapse all leading dims so every latent vector is one row.
    flat_inputs = tf.reshape(inputs, [-1, embedding_dim])
    flat_inputs = tf.cast(flat_inputs, dtype=embeddings.dtype)
    # Cosine distance: L2-normalize both sides, then 1 - similarity.
    flat_inputs_norm = tf.nn.l2_normalize(flat_inputs, axis=1)   # [N, D]
    embeddings_norm = tf.nn.l2_normalize(embeddings, axis=1)     # [K, D]
    cosine_sim = tf.matmul(flat_inputs_norm, embeddings_norm, transpose_b=True)  # [N, K]
    distances = 1.0 - cosine_sim
    encoding_indices = tf.argmin(distances, axis=1)
    return tf.reshape(encoding_indices, input_shape[:-1])
def get_embeddings(indices, codebook):
    """Look up codebook vectors for a tensor of indices.

    Returns a tensor shaped like ``indices`` with one extra trailing axis of
    size D (the codebook embedding dimension).
    """
    index_shape = tf.shape(indices)
    depth = tf.shape(codebook)[-1]
    gathered = tf.nn.embedding_lookup(codebook, tf.reshape(indices, [-1]))
    return tf.reshape(gathered, tf.concat([index_shape, [depth]], axis=0))
def latent_interpolation(t, a, b):
    """Linearly interpolate between ``a`` and ``b``.

    ``t`` is a percentage in [0, 100]: 0 yields ``a``, 100 yields ``b``.
    Works elementwise on array/tensor inputs via broadcasting.
    """
    frac = t / 100
    return (1 - frac) * a + frac * b
@st.cache_resource
def load_all():
    """Load the trained encoder/decoder models and the VQ codebook.

    Decorated with ``st.cache_resource`` so Streamlit loads the models once
    per process instead of re-reading them from disk on every script rerun.

    Returns
    -------
    tuple
        (encoder, decoder, codebook) where codebook is a float32 TF tensor.
    """
    custom = {'SelfAttention': SelfAttention}
    encoder = tf.keras.models.load_model('src/epoch_66_encoder_mscoco.keras',
                                         custom_objects=custom)
    decoder = tf.keras.models.load_model('src/epoch_66_decoder_mscoco.keras',
                                         custom_objects=custom)
    codebook = np.load('src/epoch_66_codebook_mscoco.npy')
    return encoder, decoder, tf.convert_to_tensor(codebook, dtype=tf.float32)
encoder, decoder, codebook = load_all()

col1, col2 = st.columns(2)
with col1:
    img1 = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="uploader1")
with col2:
    img2 = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="uploader2")

if img1 is not None and img2 is not None:
    frames = []
    # convert("RGB") guards against RGBA/grayscale uploads (e.g. PNG with an
    # alpha channel), which would otherwise feed the encoder a wrong channel
    # count and crash.
    img1 = Image.open(img1).convert("RGB").resize((128, 128))
    img1 = np.array(img1) / 255.0
    encoded_img1 = encoder(np.expand_dims(img1, 0))
    img2 = Image.open(img2).convert("RGB").resize((128, 128))
    img2 = np.array(img2) / 255.0
    encoded_img2 = encoder(np.expand_dims(img2, 0))
    # NOTE(review): range(0, 100, 5) stops at t=95, so the final frame never
    # exactly reaches the second image — confirm whether that is intended.
    for i in range(0, 100, 5):
        lerped = latent_interpolation(i, encoded_img1, encoded_img2)
        out = decoder(get_embeddings(get_codebook_indices(lerped, codebook), codebook))
        # Clip before the uint8 cast: decoder outputs slightly outside [0, 1]
        # would otherwise wrap around (e.g. 1.01*255 -> 2) in the cast.
        frame = np.clip(np.array(out[0]) * 255, 0, 255).astype(np.uint8)
        frames.append(frame)
    video_bytes = frames_to_video_bytes(frames, 12)
    st.video(video_bytes)