Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| import streamlit as st | |
| from streamlit_drawable_canvas import st_canvas | |
| import tensorflow_addons as tfa | |
| import tensorflow as tf | |
| import numpy as np | |
| from tensorflow.keras.utils import custom_object_scope | |
# Define a function to create the InstanceNormalization layer
def create_in(**kwargs):
    """Build a ``tfa.layers.InstanceNormalization`` layer.

    This function is registered as the ``'InstanceNormalization'`` custom
    object when the saved model is loaded. The Keras deserializer calls a
    custom object that has no ``from_config`` as ``factory(**config)``.
    ``config`` is the layer's saved configuration (``name``, ``axis``,
    ``epsilon``, ...). Those keyword arguments must be accepted and
    forwarded. A zero-argument factory raises ``TypeError`` during
    ``load_model``, and dropping the config would discard the saved settings.

    Calling it with no arguments still returns a default-configured layer,
    as before.

    Args:
        **kwargs: Constructor arguments for
            ``tfa.layers.InstanceNormalization``.

    Returns:
        A new ``tfa.layers.InstanceNormalization`` instance.
    """
    return tfa.layers.InstanceNormalization(**kwargs)
# Streamlit re-executes this whole script on every widget interaction. With
# "Update in realtime" enabled, that includes each canvas change. A plain
# in-script cache would be rebuilt on every rerun, so ``st.cache_resource``
# is used to keep one loaded model per ``model_path`` across reruns.
# Streamlit releases that predate ``cache_resource`` fall back to no caching,
# which means reloading on every call, as the original code did.
_cache_model = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_model
def _load_model(model_path):
    """Load the Keras model at *model_path*, resolving InstanceNormalization."""
    with custom_object_scope({'InstanceNormalization': create_in}):
        return tf.keras.models.load_model(model_path)


def model_out(model_path, img):
    """Run the saved image-to-image model on a single image.

    Args:
        model_path: Path of the saved Keras model file (e.g. ``'Sketch2Car.h5'``).
        img: Image array with pixel values in ``[0, 255]``, without a batch
            axis (one is added here). NOTE(review): the caller in this file
            passes a 2-D grayscale array; confirm this matches the model's
            expected input channels.

    Returns:
        The model's prediction for *img* as a NumPy array, with the batch
        axis removed.
    """
    model = _load_model(model_path)
    # Scale pixel values from [0, 255] to [-1, 1].
    scaled = (np.asarray(img, dtype=np.float32) - 127.5) / 127.5
    batch = np.expand_dims(scaled, 0)
    pred = model.predict(batch)
    return np.asarray(pred)[0]
# Specify canvas parameters in application
drawing_mode = st.sidebar.selectbox(
    "Drawing tool:", ("freedraw", "line")
)
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 3)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# Create a canvas component
canvas_result = st_canvas(
    stroke_width=stroke_width,
    stroke_color='#000000',
    background_color='#FFFFFF',
    update_streamlit=realtime_update,
    height=256,
    width=256,
    drawing_mode=drawing_mode,
    key="canvas",
)

# Translate the sketch with the model and display the sketch and the result.
if canvas_result.image_data is not None:
    st.image(canvas_result.image_data)

    # Cast explicitly: older streamlit-drawable-canvas versions build
    # image_data from a list of ints, and Image.fromarray rejects int64.
    sketch = Image.fromarray(
        np.asarray(canvas_result.image_data).astype(np.uint8)
    ).convert("RGBA")
    # NOTE(review): the canvas background color may not be baked into
    # image_data (undrawn pixels can come back fully transparent, which
    # convert("L") would turn black). Flattening onto opaque white makes
    # them white, matching background_color above. The composite is a
    # no-op if the pixels are already opaque.
    white = Image.new("RGBA", sketch.size, (255, 255, 255, 255))
    gray = np.array(Image.alpha_composite(white, sketch).convert("L"))

    pred = model_out('Sketch2Car.h5', gray)
    pred = (pred + 1.0) / 2.0  # Undo tanh normalization to get values in [0.0, 1.0] range
    # Clip before the uint8 cast. Out-of-range values would otherwise wrap
    # around (e.g. a slightly negative value becoming near-white).
    pred = np.clip(pred, 0.0, 1.0)
    # Image.fromarray cannot handle an (H, W, 1) array; drop a singleton
    # channel axis so single-channel output becomes a plain 2-D image.
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[..., 0]
    result = Image.fromarray((pred * 255).astype(np.uint8)).convert("RGB")
    st.image(result)