# Sketch2Car / app.py
# Beasto's picture
# Update app.py
# 5179ae7
from PIL import Image
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import custom_object_scope
def create_in():
    """Build and return a new ``tfa.layers.InstanceNormalization`` layer.

    Passed to ``custom_object_scope`` under the name ``'InstanceNormalization'``
    when the saved model is loaded.
    """
    layer = tfa.layers.InstanceNormalization()
    return layer
def _load_model(model_path):
    """Load the Keras model stored at ``model_path``.

    The load runs inside a ``custom_object_scope`` that maps the name
    ``'InstanceNormalization'`` to ``create_in``.
    """
    with custom_object_scope({'InstanceNormalization': create_in}):
        return tf.keras.models.load_model(model_path)


def model_out(model_path, img):
    """Run the saved model on a single image and return its first prediction.

    Args:
        model_path: Path of the saved Keras model file to load.
        img: Image as a NumPy array. It is scaled with ``(img - 127.5) / 127.5``
            (which maps 0..255 onto -1..1) and given a leading batch axis
            before being passed to ``model.predict``.

    Returns:
        The first element of the prediction, as a NumPy array.
    """
    # Streamlit re-executes the whole script on every interaction, so loading
    # the model here unconditionally would deserialize it from disk for every
    # stroke.  ``st.cache_resource`` keeps one loaded model per ``model_path``
    # across reruns; Streamlit releases that lack it fall back to loading on
    # each call.  NOTE: a cached model is not reloaded if the file on disk
    # changes while the app is running.
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is None:
        loader = _load_model
    else:
        loader = cache_resource(_load_model)
    model = loader(model_path)

    scaled = (img - 127.5) / 127.5
    batch = np.expand_dims(scaled, 0)  # the model is fed a batch of one image
    pred = np.asarray(model.predict(batch))
    return pred[0]
# Sidebar controls that configure the drawing canvas.
drawing_mode = st.sidebar.selectbox(
    label="Drawing tool:",
    options=("freedraw", "line"),
)
stroke_width = st.sidebar.slider(
    label="Stroke width: ",
    min_value=1,
    max_value=25,
    value=3,
)
realtime_update = st.sidebar.checkbox(label="Update in realtime", value=True)

# 256x256 canvas: black strokes on a white background.
canvas_result = st_canvas(
    key="canvas",
    drawing_mode=drawing_mode,
    width=256,
    height=256,
    background_color='#FFFFFF',
    stroke_color='#000000',
    stroke_width=stroke_width,
    update_streamlit=realtime_update,
)

# Once the canvas has returned pixel data, echo the sketch and show the
# model's output underneath it.
if canvas_result.image_data is not None:
    st.image(canvas_result.image_data)

    # Despite its name, ``img_rgb`` holds the single-channel ("L" mode)
    # version of the sketch; that is what gets fed to the model.
    img_rgb = Image.fromarray(np.array(canvas_result.image_data)).convert("L")
    img = np.array(img_rgb)

    pred = model_out('Sketch2Car.h5', img)
    # Shift the model output from [-1, 1] to [0, 1] (assumes a tanh-style
    # output range), then scale to 8-bit for display.
    pred = (pred + 1.0) / 2.0
    pred_bytes = (pred * 255).astype(np.uint8)
    pred = Image.fromarray(pred_bytes).convert("RGB")
    st.image(pred)