# Latent-Walk — src/streamlit_app.py
# (Hugging Face Space file header: "Update src/streamlit_app.py", commit a7b53e7, verified)
import streamlit as st
import imageio
import io
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D, Add
from tensorflow.keras.utils import register_keras_serializable
from PIL import Image
@register_keras_serializable(package="Custom")
class SelfAttention(Layer):
    """Single-head spatial self-attention block (SAGAN-style) with a residual add.

    Projects the input to query/key maps at C//8 channels and a value map at C
    channels, computes attention over the H*W spatial positions, and adds the
    attended output back onto the input. Output shape equals input shape.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel depth comes from the input, so the layer adapts to any C.
        self.filters = input_shape[-1]
        self.f = Conv2D(self.filters // 8, kernel_size=1, padding='same')  # query
        self.g = Conv2D(self.filters // 8, kernel_size=1, padding='same')  # key
        self.h = Conv2D(self.filters, kernel_size=1, padding='same')       # value
        super().build(input_shape)

    def call(self, x):
        f = self.f(x)  # (B, H, W, C//8) query
        g = self.g(x)  # (B, H, W, C//8) key
        h = self.h(x)  # (B, H, W, C)    value
        shape_f = tf.shape(f)
        B, H, W = shape_f[0], shape_f[1], shape_f[2]
        # Flatten the spatial grid to N = H*W positions for matrix attention.
        f_flat = tf.reshape(f, [B, H * W, self.filters // 8])
        g_flat = tf.reshape(g, [B, H * W, self.filters // 8])
        h_flat = tf.reshape(h, [B, H * W, self.filters])
        # Attention over positions. NOTE(review): unscaled dot-product
        # (no 1/sqrt(d) factor) — kept as-is to match the trained checkpoint.
        beta = tf.nn.softmax(tf.matmul(f_flat, g_flat, transpose_b=True), axis=-1)  # (B, N, N)
        o = tf.matmul(beta, h_flat)  # (B, N, C)
        o = tf.reshape(o, [B, H, W, self.filters])
        # Residual connection. Fix: plain tensor add instead of instantiating a
        # fresh keras Add() layer on every call (untracked-layer churn; same math).
        return x + o

    def get_config(self):
        # No constructor arguments beyond the base Layer's, so nothing to add.
        return super().get_config()
def frames_to_video_bytes(frames, fps=30):
    """Encode a sequence of HxWx3 uint8 frames into an in-memory MP4.

    Args:
        frames: iterable of numpy uint8 arrays, all the same shape.
        fps: playback frame rate of the output video.

    Returns:
        The encoded MP4 file contents as ``bytes``.
    """
    buffer = io.BytesIO()
    # Fix: context manager guarantees the writer is finalized (MP4 container
    # properly closed) even if appending a frame raises mid-loop.
    with imageio.get_writer(buffer, format="mp4", fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)
    return buffer.getvalue()
def get_codebook_indices(inputs, codebook):
    """Map latent vectors to the index of their nearest codebook entry (VQ lookup).

    "Nearest" is measured by cosine distance (1 - cosine similarity), not the
    squared-euclidean distance of the classic VQ-VAE formulation.

    Args:
        inputs: tensor of shape (..., D) of latent vectors; D must match the
            codebook's embedding dimension.
        codebook: tensor of shape (K, D) holding the K embedding vectors.

    Returns:
        Integer tensor of indices with shape ``inputs.shape[:-1]``.
    """
    embedding_dim = tf.shape(codebook)[-1]
    input_shape = tf.shape(inputs)
    flat_inputs = tf.reshape(inputs, [-1, embedding_dim])      # (N, D)
    flat_inputs = tf.cast(flat_inputs, dtype=codebook.dtype)
    # Cosine distance: L2-normalize both sides, then 1 - similarity.
    flat_inputs_norm = tf.nn.l2_normalize(flat_inputs, axis=1)  # (N, D)
    embeddings_norm = tf.nn.l2_normalize(codebook, axis=1)      # (K, D)
    cosine_sim = tf.matmul(flat_inputs_norm, embeddings_norm, transpose_b=True)  # (N, K)
    distances = 1.0 - cosine_sim
    encoding_indices = tf.argmin(distances, axis=1)
    return tf.reshape(encoding_indices, input_shape[:-1])
def get_embeddings(indices, codebook):
    """Inverse of the VQ lookup: replace each index with its codebook vector.

    Returns a tensor shaped like ``indices`` plus one trailing axis of size D,
    where D is the codebook's embedding dimension.
    """
    gathered = tf.nn.embedding_lookup(codebook, tf.reshape(indices, [-1]))
    target_shape = tf.concat([tf.shape(indices), tf.shape(codebook)[-1:]], axis=0)
    return tf.reshape(gathered, target_shape)
def latent_interpolation(t, a, b):
    """Linearly interpolate between ``a`` and ``b``.

    ``t`` is a percentage in [0, 100]: t=0 yields ``a``, t=100 yields ``b``.
    Works elementwise on arrays/tensors as well as on plain scalars.
    """
    frac = t / 100
    return (1 - frac) * a + frac * b
@st.cache_resource
def load_all():
    """Load the VQ-VAE encoder/decoder models and the codebook (cached once).

    Returns:
        Tuple ``(encoder, decoder, codebook)`` where ``codebook`` is a
        float32 TensorFlow tensor.
    """
    # Both saved models contain the custom SelfAttention layer.
    custom = {'SelfAttention': SelfAttention}
    encoder = tf.keras.models.load_model(
        'src/epoch_66_encoder_mscoco.keras', custom_objects=custom)
    decoder = tf.keras.models.load_model(
        'src/epoch_66_decoder_mscoco.keras', custom_objects=custom)
    codebook_arr = np.load('src/epoch_66_codebook_mscoco.npy')
    return encoder, decoder, tf.convert_to_tensor(codebook_arr, dtype=tf.float32)
encoder, decoder, codebook = load_all()

col1, col2 = st.columns(2)
with col1:
    img1 = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="uploader1")
with col2:
    img2 = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="uploader2")

if img1 is not None and img2 is not None:
    def _encode_upload(uploaded):
        # Fix: force RGB so RGBA/grayscale uploads don't break the
        # (1, 128, 128, 3) input the encoder expects from an RGB upload.
        arr = np.array(Image.open(uploaded).convert("RGB").resize((128, 128))) / 255.0
        return encoder(np.expand_dims(arr, 0))

    encoded_img1 = _encode_upload(img1)
    encoded_img2 = _encode_upload(img2)

    frames = []
    # Fix: end at 101 so t=100 is included and the walk actually lands on the
    # second image (range(0, 100, 5) stopped at t=95).
    for t in range(0, 101, 5):
        lerped = latent_interpolation(t, encoded_img1, encoded_img2)
        out = decoder(get_embeddings(get_codebook_indices(lerped, codebook), codebook))
        # Fix: clip before the uint8 cast — decoder output can slightly
        # over/undershoot [0, 1], which would wrap around in uint8.
        frame = np.clip(np.array(out[0]) * 255.0, 0, 255).astype(np.uint8)
        frames.append(frame)

    st.video(frames_to_video_bytes(frames, fps=12))