Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| st.set_page_config(layout="wide") | |
| import base64 | |
| import random | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from streamlit_drawable_canvas import st_canvas | |
| from utils import utils | |
# Placeholder <img> tag expected in index.html; main() replaces it with
# IMAGE_TAG_BASE64 before rendering the page header.
DEFAULT_IMG_TAG = '<img class="content-image" src="" alt="model-architecture">'
# Inline the architecture figure as a base64 PNG data URI so the HTML rendered
# by main() carries the image itself instead of referencing a separate file.
with open("figures/medsam.png", "rb") as image_file:
    encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
IMAGE_TAG_BASE64 = f'<img class="content-image" src="data:image/png;base64,{encoded_string}" alt="model-architecture">'
# Build the two model front-ends from the MedSAM ViT-B checkpoint. The first is
# used as the prompt (click/box) predictor and the second as the "Everything"
# mask generator by the process_* functions.
# NOTE(review): this runs at module level; if the script is re-executed on every
# Streamlit rerun, it is only cheap when utils.get_model caches — confirm.
PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model(checkpoint='checkpoint/medsam_vit_b.pth')
def process_box(predictor_model, show_mask, radius_width):
    """Render a rectangle-drawing canvas and segment the boxed regions.

    The uploaded image is read from ``st.session_state['image']`` (a PIL image)
    and shown on a 700-px-wide ``st_canvas`` in ``'rect'`` mode. Every drawn
    rectangle is mapped back to original-image pixel coordinates and sent, as
    a box plus its centre point (label 1), to ``utils.model_predict_masks_box``.

    Args:
        predictor_model: prompt predictor; ``is_image_set``/``set_image`` are
            used to compute the image embedding only when it is missing.
        show_mask: when False no prediction is made and the plain image is
            returned (after at most one forced rerun).
        radius_width: stroke width, in canvas pixels, of the drawn boxes.

    Returns:
        A PIL RGB image (mask overlay, resized to the canvas size) when masks
        were produced; otherwise the original image as a numpy array.
    """
    bg_image = st.session_state['image']
    width, height = bg_image.size
    container_width = 700
    scale = container_width / width
    scaled_wh = (container_width, int(height * scale))

    # The embedding is expensive: compute it only if the predictor has none.
    if not predictor_model.is_image_set:
        with st.spinner(text="Extracting embeddings.."):
            predictor_model.set_image(np.asanyarray(bg_image))
    if 'result_image' not in st.session_state:
        st.session_state.result_image = bg_image.resize(scaled_wh)

    box_canvas = st_canvas(
        fill_color="rgba(255, 255, 0, 0)",
        background_image=bg_image,
        drawing_mode='rect',
        stroke_color="rgba(0, 255, 0, 0.6)",
        stroke_width=radius_width,
        width=container_width,
        height=scaled_wh[1],  # integer canvas height matching scaled_wh
        point_display_radius=12,
        update_streamlit=True,
        key="box",
    )

    if not show_mask:
        # Toggle ``rerun_once`` so the script reruns once after hiding the
        # mask (flagging the result column for display). A missing key counts
        # as "not yet rerun" instead of raising on the read further down.
        if st.session_state.get('rerun_once', False):
            st.session_state.rerun_once = False
        else:
            st.session_state.rerun_once = True
            st.session_state.display_result = True
        st.warning("Mask view is disabled", icon="❗")
        if st.session_state.rerun_once:
            # st.experimental_rerun is deprecated in favour of st.rerun.
            (getattr(st, "rerun", None) or st.experimental_rerun)()
        return np.asarray(bg_image)

    if box_canvas.json_data is None:
        return np.asarray(bg_image)
    objects = box_canvas.json_data["objects"]
    if not objects:
        # Nothing drawn yet: don't call the predictor with empty prompts.
        return np.asarray(bg_image)

    center_point, center_label, input_box = [], [], []
    for _, row in pd.json_normalize(objects).iterrows():
        # Canvas pixels -> original image pixels.
        x = int(row["left"] / scale)
        y = int(row["top"] / scale)
        w = int(row["width"] / scale)
        h = int(row["height"] / scale)
        center_point.append([x + w / 2, y + h / 2])
        center_label.append([1])
        input_box.append([x, y, x + w, y + h])

    masks = utils.model_predict_masks_box(predictor_model, center_point, center_label, input_box)
    if len(masks) == 0:
        st.warning("No Masks Found", icon="❗")
        return np.asarray(bg_image)

    # Random palette colour with alpha 0.6 for the overlay.
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    overlay = Image.fromarray(utils.show_click(masks, color)).convert('RGBA')
    result_image = Image.alpha_composite(bg_image.convert('RGBA'), overlay).convert("RGB")
    st.session_state.display_result = True
    return result_image.resize(scaled_wh)
def process_click(predictor_model, show_mask, radius_width):
    """Render a point-prompt canvas and segment from the clicked points.

    The uploaded image is read from ``st.session_state['image']`` (a PIL image)
    and shown on a 700-px-wide ``st_canvas`` in ``'point'`` mode. Each point's
    centre is mapped back to original-image pixels and sent with its label to
    ``utils.model_predict_masks_click``.

    Args:
        predictor_model: prompt predictor; ``is_image_set``/``set_image`` are
            used to compute the image embedding only when it is missing.
        show_mask: when False no prediction is made and the plain image is
            returned (after at most one forced rerun).
        radius_width: display radius, in canvas pixels, of the drawn points.

    Returns:
        A PIL RGB image (mask overlay, resized to the canvas size) when masks
        were produced; otherwise the original image as a numpy array.
    """
    bg_image = st.session_state['image']
    width, height = bg_image.size
    container_width = 700
    scale = container_width / width
    scaled_wh = (container_width, int(height * scale))
    # Fill colour of the points this canvas draws; also used to label them.
    point_fill = "rgba(255, 255, 0, 0.8)"

    # The embedding is expensive: compute it only if the predictor has none.
    if not predictor_model.is_image_set:
        with st.spinner(text="Extracting embeddings.."):
            predictor_model.set_image(np.asanyarray(bg_image))
    if 'result_image' not in st.session_state:
        st.session_state.result_image = bg_image.resize(scaled_wh)

    click_canvas = st_canvas(
        fill_color=point_fill,
        background_image=bg_image,
        drawing_mode='point',
        width=container_width,
        height=scaled_wh[1],  # integer canvas height matching scaled_wh
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="point",
    )

    if not show_mask:
        # Toggle ``rerun_once`` so the script reruns once after hiding the
        # mask (flagging the result column for display). A missing key counts
        # as "not yet rerun" instead of raising on the read further down.
        if st.session_state.get('rerun_once', False):
            st.session_state.rerun_once = False
        else:
            st.session_state.rerun_once = True
            st.session_state.display_result = True
        st.warning("Mask view is disabled", icon="❗")
        if st.session_state.rerun_once:
            # st.experimental_rerun is deprecated in favour of st.rerun.
            (getattr(st, "rerun", None) or st.experimental_rerun)()
        return np.asarray(bg_image)

    if click_canvas.json_data is None:
        return np.asarray(bg_image)
    objects = click_canvas.json_data["objects"]
    if not objects:
        # Nothing clicked yet: don't call the predictor with empty prompts.
        return np.asarray(bg_image)

    # Compare colours ignoring whitespace so formatting differences don't matter.
    positive_fill = point_fill.replace(" ", "")
    input_points, input_labels = [], []
    for _, row in pd.json_normalize(objects).iterrows():
        # Centre of the drawn point in canvas pixels, then in image pixels.
        cx = int(row["left"] + row["width"] / 2)
        cy = int(row["top"] + row["height"] / 2)
        input_points.append([int(cx / scale), int(cy / scale)])
        # Points drawn with this canvas's fill colour are foreground (1);
        # anything else is treated as background (0).
        is_positive = str(row["fill"]).replace(" ", "") == positive_fill
        input_labels.append(1 if is_positive else 0)

    masks = utils.model_predict_masks_click(predictor_model, input_points, input_labels)
    if len(masks) == 0:
        st.warning("No Masks Found", icon="❗")
        return np.asarray(bg_image)

    # Random palette colour with alpha 0.6 for the overlay.
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    overlay = Image.fromarray(utils.show_click(masks, color)).convert('RGBA')
    result_image = Image.alpha_composite(bg_image.convert('RGBA'), overlay).convert("RGB")
    st.session_state.display_result = True
    return result_image.resize(scaled_wh)
def process_everything(automask_model, show_mask, radius_width):
    """Segment the whole image with the automatic mask generator.

    The uploaded image (``st.session_state['image']``, a PIL image) is shown
    on a display-only canvas. All masks come from
    ``utils.model_predict_masks_everything`` and are alpha-composited over it.

    Args:
        automask_model: automatic mask generator; when falsy, no prediction is
            made.
        show_mask: when False no prediction is made and the plain image is
            returned (after at most one forced rerun).
        radius_width: forwarded to the canvas as ``point_display_radius``.

    Returns:
        A PIL RGB image (mask overlay, resized to the canvas size) when masks
        were produced; otherwise the original image as a numpy array.
    """
    bg_image = st.session_state['image']
    width, height = bg_image.size
    container_width = 700
    scale = container_width / width
    scaled_wh = (container_width, int(height * scale))
    if 'result_image' not in st.session_state:
        st.session_state.result_image = bg_image.resize(scaled_wh)

    # Display-only canvas (update_streamlit=False): its drawing is never read.
    st_canvas(
        fill_color="rgba(255, 255, 0, 0.8)",
        background_image=bg_image,
        drawing_mode='freedraw',
        width=container_width,
        height=scaled_wh[1],  # integer canvas height matching scaled_wh
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=False,
        key="everything",
    )

    if not show_mask:
        # Toggle ``rerun_once`` so the script reruns once after hiding the
        # mask (flagging the result column for display). A missing key counts
        # as "not yet rerun" instead of raising on the read further down.
        if st.session_state.get('rerun_once', False):
            st.session_state.rerun_once = False
        else:
            st.session_state.rerun_once = True
            st.session_state.display_result = True
        st.warning("Mask view is disabled", icon="❗")
        if st.session_state.rerun_once:
            # st.experimental_rerun is deprecated in favour of st.rerun.
            (getattr(st, "rerun", None) or st.experimental_rerun)()
        return np.asarray(bg_image)

    if not automask_model:
        return np.asarray(bg_image)

    bg_array = np.asarray(bg_image)
    masks = utils.model_predict_masks_everything(automask_model, bg_array)
    im_masked = utils.show_everything(masks)
    if len(im_masked) == 0:
        st.warning("No Masks Found", icon="❗")
        return bg_array

    overlay = Image.fromarray(im_masked).convert('RGBA')
    result_image = Image.alpha_composite(bg_image.convert('RGBA'), overlay).convert("RGB")
    st.session_state.display_result = True
    return result_image.resize(scaled_wh)
def image_preprocess_callback(predictor_model, option):
    """``on_change`` handler for the image uploader.

    On a new upload:
        * decode the file to an RGB PIL image and store it in
          ``st.session_state.image``;
        * either pre-compute the predictor embedding (prompt modes), or reset
          it in ``'Everything'`` mode, which uses the automatic mask generator
          instead;
        * drop the previous image's ``result_image`` so the ``process_*``
          functions rebuild it for the new image.

    When the upload is removed, the stored image and result state are cleared
    and the predictor is reset.

    Args:
        predictor_model: predictor exposing ``set_image``/``reset_image``;
            skipped when falsy.
        option: segmentation mode selected when the uploader was rendered.
    """
    if 'uploaded_image' not in st.session_state:
        return

    uploaded = st.session_state.uploaded_image
    if uploaded is not None:
        with st.spinner(text="Uploading image..."):
            image = Image.open(uploaded).convert("RGB")
            if predictor_model and option != 'Everything':
                with st.spinner(text="Extracting embeddings.."):
                    predictor_model.set_image(np.asanyarray(image))
            elif predictor_model:
                predictor_model.reset_image()
            # Otherwise process_* would keep showing the previous image's preview.
            if 'result_image' in st.session_state:
                del st.session_state['result_image']
            st.session_state.image = image
        return

    with st.spinner(text="Cleaning up!"):
        if 'display_result' in st.session_state:
            st.session_state.display_result = False
        if 'image' in st.session_state:
            st.session_state.image = None
        if 'result_image' in st.session_state:
            del st.session_state['result_image']
        if predictor_model:
            predictor_model.reset_image()
def main():
    """Build the demo page.

    The page consists of:
        * the header HTML, with the inlined architecture figure;
        * the control row (mode, mask toggle, confidence threshold,
          point/box size);
        * the image uploader;
        * the input canvas and the result column.

    Once an image is present in ``st.session_state``, the selected mode
    dispatches to ``process_click``, ``process_box`` or ``process_everything``.
    """
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()
    # Swap the placeholder figure for the base64-inlined image.
    html_content = html_content.replace(DEFAULT_IMG_TAG, IMAGE_TAG_BASE64)
    st.components.v1.html(html_content, width=None, height=925, scrolling=False)

    with st.container():
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            option = st.selectbox('Segmentation mode', ('Click', 'Box', 'Everything'))
        with col2:
            st.write("Show or Hide Mask")
            show_mask = st.checkbox('Show mask', value=True)
        with col3:
            mask_threshold = st.slider('SAM Confidence Threshold', 0.0, 1.0, 0.5, 0.05)
            PREDICTOR_MODEL.model.mask_threshold = mask_threshold
        with col4:
            radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    with st.container():
        st.write("Upload Image")
        st.file_uploader(
            label='Upload image',
            # 'jpeg' and 'tiff' are common spellings of the accepted formats.
            type=['png', 'jpg', 'jpeg', 'tif', 'tiff'],
            key='uploaded_image',
            on_change=image_preprocess_callback,
            args=(PREDICTOR_MODEL, option,),
            label_visibility="hidden",
        )

    result_image = None
    canvas_input, canvas_output = st.columns(2)
    if 'image' not in st.session_state:
        # NOTE(review): the original indentation was lost; this branch is
        # reconstructed as the ``else`` of the "image in session state" check.
        # Confirm it was not meant to pair with the display_result check.
        st.cache_data.clear()
        return

    with canvas_input:
        st.write("Select Interest Area/Objects")
        if st.session_state.image is not None:
            with st.spinner(text="Computing masks"):
                if option == 'Click':
                    result_image = process_click(PREDICTOR_MODEL, show_mask, radius_width)
                elif option == 'Box':
                    result_image = process_box(PREDICTOR_MODEL, show_mask, radius_width)
                else:
                    result_image = process_everything(AUTOMASK_MODEL, show_mask, radius_width)

    if st.session_state.get('display_result', False):
        with canvas_output:
            if result_image is not None:
                st.write("Result")
                st.image(result_image)
            else:
                st.warning("No result found, please set input prompt", icon="⚠️")
            st.success('Process completed!', icon="✅")
if __name__ == "__main__":
    # Render the app when this file is executed as the main script.
    main()