Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| st.set_page_config(layout="wide") | |
| import random | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from streamlit_drawable_canvas import st_canvas | |
| from utils import utils | |
# Streamlit re-executes this whole script on every interaction. Caching the
# loader keeps a single predictor per server process, so the embedding set via
# ``model.set_image`` in image_preprocess_callback is still present on the
# rerun that predicts masks. (Harmless if utils.get_model already caches.)
# NOTE(review): st.cache_resource needs Streamlit >= 1.18 — confirm the pinned version.
SAM_MODEL = st.cache_resource(utils.get_model)('vit_b')
def _click_prompts(objects, scale, fg_fill):
    """Convert st_canvas point objects into SAM prompt coordinates and labels.

    Args:
        objects: The ``json_data["objects"]`` list produced by ``st_canvas``;
            each entry has ``left``/``top``/``width``/``height`` in canvas pixels
            and a ``fill`` colour string.
        scale: Canvas-pixels-per-image-pixel factor used to display the image.
        fg_fill: Fill colour that marks a foreground click.

    Returns:
        ``(points, labels)``: ``points`` is a list of ``[x, y]`` ints in original
        image pixels; ``labels`` holds 1 for ``fg_fill`` points, else 0.
    """
    points, labels = [], []
    for obj in objects:
        # Centre of the point's bounding box, mapped back to image pixels.
        cx = obj["left"] + obj["width"] / 2
        cy = obj["top"] + obj["height"] / 2
        points.append([int(cx / scale), int(cy / scale)])
        labels.append(1 if obj.get("fill") == fg_fill else 0)
    return points, labels


def click_process(model, show_mask, radius_width):
    """Render the point-prompt canvas and return the image to display as result.

    The image stored in ``st.session_state['image']`` (a PIL image) is shown on
    an ``st_canvas`` scaled to a 700 px wide display. Every clicked point becomes
    a prompt for ``utils.model_predict_masks_click`` and the predicted mask is
    overlaid on the image.

    Args:
        model: Predictor handed to ``utils.model_predict_masks_click``; when
            falsy, no prediction is attempted.
        show_mask: When False the plain image is returned; the canvas is still
            rendered so the clicked points are kept.
        radius_width: Display radius of the clicked points on the canvas.

    Returns:
        A PIL ``RGB`` image at the canvas display size: the masked image when a
        mask was predicted, otherwise the plain background image.
    """
    # Every click is drawn with this fill, so it is also the colour that marks
    # a foreground prompt (label 1). Assumes SAM's convention 1 = foreground,
    # 0 = background — TODO confirm utils.model_predict_masks_click follows it.
    point_fill = "rgba(255, 255, 0, 0.8)"
    container_width = 700

    bg_image = st.session_state['image']
    width, height = bg_image.size
    scale = container_width / width
    # PIL sizes are (width, height); st_canvas needs integer dimensions.
    display_size = (container_width, int(height * scale))
    display_image = bg_image.resize(display_size)
    if 'result_image' not in st.session_state:
        st.session_state.result_image = display_image

    canvas_result = st_canvas(
        fill_color=point_fill,
        background_image=bg_image,
        drawing_mode='point',
        width=container_width,
        height=display_size[1],
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="point",
    )

    # Hiding the mask just shows the plain image. Do not trigger a rerun here:
    # it would fire again on every pass and the script would never settle.
    if not show_mask or canvas_result.json_data is None:
        return display_image

    input_points, input_labels = _click_prompts(
        canvas_result.json_data.get("objects", []), scale, point_fill
    )
    if not input_points or not model:
        return display_image

    masks = utils.model_predict_masks_click(model, input_points, input_labels)
    # len() rather than truthiness: masks may be a NumPy array.
    if len(masks) == 0:
        return display_image

    # RGBA overlay colour: a random palette colour with 0.6 alpha.
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    im_masked = Image.fromarray(utils.show_click(masks, color)).convert('RGBA')
    result_image = Image.alpha_composite(bg_image.convert('RGBA'), im_masked).convert("RGB")
    return result_image.resize(display_size)
def image_preprocess_callback(model):
    """Sync session state and the model with the file uploader's current value.

    Registered as the ``on_change`` callback of the uploader keyed
    ``'uploaded_image'``.

    * A file is present: it is decoded into an RGB PIL image; if ``model`` is
      truthy, ``model.set_image`` receives it as a NumPy array; the image is
      then stored in ``st.session_state.image``.
    * The upload was cleared: ``st.session_state.image`` (if present) is set to
      None and ``model.reset_image()`` is called when ``model`` is truthy.

    In both cases any cached ``st.session_state['result_image']`` is removed,
    because it was built for the previous image.

    Args:
        model: Predictor exposing ``set_image`` / ``reset_image``, or a falsy
            value to skip embedding.
    """
    if 'uploaded_image' not in st.session_state:
        return

    # The cached result is sized for whichever image was shown before.
    if 'result_image' in st.session_state:
        del st.session_state['result_image']

    uploaded = st.session_state.uploaded_image
    if uploaded is not None:
        with st.spinner(text="Uploading image..."):
            image = Image.open(uploaded).convert("RGB")
            if model:
                with st.spinner(text="Extracting embeddings.."):
                    model.set_image(np.asanyarray(image))
            st.session_state.image = image
    else:
        with st.spinner(text="Cleaning up!"):
            if 'image' in st.session_state:
                st.session_state.image = None
            if model:
                model.reset_image()
def main():
    """Render the page: header HTML, controls, image uploader, canvas and result."""
    with open('index.html', encoding='utf-8') as f:
        st.markdown(f.read(), unsafe_allow_html=True)

    with st.container():
        col1, col2, col3 = st.columns(3)
        with col1:
            # Options must be a sequence of labels. A bare ('Click') is just the
            # string 'Click', which would be offered as 'C', 'l', 'i', 'c', 'k'.
            option = st.selectbox('Segmentation mode', ('Click',))
        with col2:
            st.write("Show or Hide Mask")
            show_mask = st.checkbox('Show mask', value=True)
        with col3:
            radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    with st.container():
        st.write("Upload Image")
        st.file_uploader(
            label='Upload image',
            type=['png', 'jpg', 'jpeg', 'tif', 'tiff'],
            key='uploaded_image',
            on_change=image_preprocess_callback,
            args=(SAM_MODEL,),
            label_visibility="hidden",
        )

    result_image = None
    canvas_input, canvas_output = st.columns(2)
    if 'image' not in st.session_state:
        print(f'embedding is empty - {option} - {show_mask} - {radius_width}')
        return

    with canvas_input:
        st.write("Select Interest Area/Objects")
        if st.session_state.image is not None and option == 'Click':
            with st.spinner(text="Computing masks"):
                result_image = click_process(SAM_MODEL, show_mask, radius_width)
    with canvas_output:
        if result_image is not None:
            st.write("Result")
            st.image(result_image)
# Build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()