"""Streamlit sketch-to-image demo.

The user draws on a canvas; the drawing is converted to grayscale, passed
through a saved Keras model, and the model output is displayed as an image.
"""

import numpy as np
import streamlit as st
import tensorflow as tf
import tensorflow_addons as tfa
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from tensorflow.keras.utils import custom_object_scope

MODEL_PATH = 'Sketch2Car.h5'
CANVAS_SIZE = 256  # canvas height and width, in pixels


# Define a function to create the InstanceNormalization layer
def create_in():
    """Return a new ``tfa.layers.InstanceNormalization`` layer with default arguments."""
    return tfa.layers.InstanceNormalization()


# Streamlit re-executes this script on every widget/canvas update, so the
# model must be cached or it would be re-read from disk on each stroke.
# ``st.cache_resource`` is the current API; older Streamlit releases only
# offer ``st.cache``, hence the fallback.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def _load_model(model_path):
    """Load the Keras model stored at *model_path* (cached across reruns)."""
    with custom_object_scope({'InstanceNormalization': create_in}):
        return tf.keras.models.load_model(model_path)


def model_out(model_path, img):
    """Run the saved model on a single image and return its prediction.

    Args:
        model_path: Path of the saved Keras model to load.
        img: Array of pixel values. It is rescaled with
            ``(img - 127.5) / 127.5`` (mapping 0..255 onto -1..1) and given
            a leading batch axis before being passed to ``model.predict``.

    Returns:
        The first (and only) entry of the prediction batch as a NumPy array.
    """
    model = _load_model(model_path)
    scaled = (img - 127.5) / 127.5
    batch = np.expand_dims(scaled, 0)
    pred = np.asarray(model.predict(batch))
    return pred[0]


# Specify canvas parameters in application
drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line"))
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 3)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# Create a canvas component
canvas_result = st_canvas(
    stroke_width=stroke_width,
    stroke_color='#000000',
    background_color='#FFFFFF',
    update_streamlit=realtime_update,
    height=CANVAS_SIZE,
    width=CANVAS_SIZE,
    drawing_mode=drawing_mode,
    key="canvas",
)

# Feed the drawing through the model and show both the sketch and the result
if canvas_result.image_data is not None:
    st.image(canvas_result.image_data)

    # The model is fed a single-channel (grayscale) version of the drawing.
    # NOTE(review): presumably the saved model expects one input channel — confirm.
    img = np.array(Image.fromarray(np.array(canvas_result.image_data)).convert("L"))

    pred = model_out(MODEL_PATH, img)
    pred = (pred + 1.0) / 2.0  # Undo tanh normalization to get values in [0.0, 1.0] range
    # Clip before the uint8 cast so an out-of-range value cannot wrap around.
    pred = np.clip(pred, 0.0, 1.0)
    pred = Image.fromarray((pred * 255).astype(np.uint8)).convert("RGB")
    st.image(pred)