Spaces:
Sleeping
Sleeping
| import time | |
| import numpy as np | |
| import streamlit as st | |
| from PIL import Image | |
| from s_multimae.da.base_da import BaseDataAugmentation | |
| from .base_model import BaseRGBDModel | |
| from .depth_model import BaseDepthModel | |
| from .model import base_inference | |
def image_inference(
    depth_model: BaseDepthModel,
    sod_model: BaseRGBDModel,
    da: BaseDataAugmentation,
    color: np.ndarray,
) -> None:
    """Render the single-image salient-object-detection page.

    The page shows two uploaders side by side (an RGB image and an optional
    depth image), an optional "number of sets" input, and a predict button.
    When the button is pressed, ``base_inference`` is called and its
    predictions are displayed.

    The depth image is kept in ``st.session_state.depth`` so that a depth map
    produced by one prediction is reused on later reruns instead of being
    recomputed. It is cleared whenever either uploader changes.

    Args:
        depth_model: Passed through to ``base_inference``.
        sod_model: Passed through to ``base_inference``. Its
            ``cfg.ground_truth_version`` decides whether the user may pick
            the number of sets of salient objects (only when it equals 6).
        da: Passed through to ``base_inference``.
        color: Passed through to ``base_inference``.
            NOTE(review): presumably the overlay colour used for
            visualisation; confirm against ``base_inference``.
    """
    if "depth" not in st.session_state:
        st.session_state.depth = None

    def reset_depth() -> None:
        # Any change to an uploader invalidates the cached depth map: a new
        # RGB image needs a new pseudo-depth, and removing the uploaded depth
        # file must not leave the old one behind in the session state.
        st.session_state.depth = None

    col1, col2 = st.columns(2)
    image: Image.Image | None = None

    with col1:
        img_file_buffer = st.file_uploader(
            "Upload an RGB image",
            key="img_file_buffer",
            type=["png", "jpg", "jpeg"],
            on_change=reset_depth,
        )
        if img_file_buffer is not None:
            image = Image.open(img_file_buffer).convert("RGB")
            st.image(image, caption="RGB")

    with col2:
        depth_file_buffer = st.file_uploader(
            "Upload a depth image (Optional)",
            key="depth_file_buffer",
            type=["png", "jpg", "jpeg"],
            on_change=reset_depth,
        )
        if depth_file_buffer is not None:
            st.session_state.depth = Image.open(depth_file_buffer).convert("L")
        if st.session_state.depth is not None:
            st.image(st.session_state.depth, caption="Depth")

    if sod_model.cfg.ground_truth_version == 6:
        num_sets_of_salient_objects = st.number_input(
            "Number of sets of salient objects", value=1, min_value=1, max_value=10
        )
    else:
        num_sets_of_salient_objects = 1

    # The button is disabled until an RGB image is uploaded, so ``image`` is
    # always set by the time a click is reported.
    is_predict = st.button(
        "Predict Salient Objects",
        key="predict_salient_objects",
        disabled=img_file_buffer is None,
    )
    if not is_predict:
        return

    with st.spinner(
        "Processing... (It usually takes about 30s - 1 minute per a set of salient objects)"
    ):
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by system clock adjustments; it is read right after the
        # call so that only inference is measured.
        start_time = time.perf_counter()
        pred_depth, pred_sods, pred_sms = base_inference(
            depth_model,
            sod_model,
            da,
            image,
            st.session_state.depth,
            color,
            num_sets_of_salient_objects,
        )
        elapsed = time.perf_counter() - start_time

        if st.session_state.depth is None:
            # No depth was available: keep the predicted one for later
            # reruns and show it where the depth image would have been.
            st.session_state.depth = Image.fromarray(pred_depth).convert("L")
            col2.image(st.session_state.depth, "Pseudo-depth")

    if num_sets_of_salient_objects == 1:
        st.warning(
            "HINT: To view a wider variety of sets of salient objects, try to increase the number of sets the model can produce."
        )
    else:
        # The number input is bounded below by 1, so this branch means > 1.
        st.warning(
            "NOTE: As single-GT accounts for 77.61% of training samples, the model may not consistently yield different sets. The best approach is to gradually increase the number of sets of salient objects until you achieve the desired result."
        )
    st.info(f"Inference time: {elapsed:.4f} seconds")

    # st.columns rejects a column count of zero, so report the empty result
    # instead of letting the page fail.
    if len(pred_sods) == 0:
        st.warning("The model did not return any salient objects.")
        return

    sod_cols = st.columns(len(pred_sods))
    for sod_col, pred_sod, pred_sm in zip(sod_cols, pred_sods, pred_sms):
        with sod_col:
            st.image(pred_sod, "Salient Objects (Otsu threshold)")
            st.image(pred_sm, "Salient Map")