Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from app_lib.test import get_testing_config, load_precomputed_results, test | |
| from app_lib.user_input import ( | |
| get_advanced_settings, | |
| get_class_name, | |
| get_concepts, | |
| get_image, | |
| get_model_name, | |
| ) | |
| from app_lib.viz import viz_results | |
def _disable():
    """Mark the session as busy so the interactive buttons render disabled.

    Registered as the ``on_click`` callback of the "Test Concepts" button.
    """
    session = st.session_state
    session.disabled = True
def _toggle_sidebar(button):
    """Expand the sidebar and rerun the script when ``button`` was clicked.

    Args:
        button: value returned by ``st.button``; truthy when it was clicked
            in the current run.
    """
    if not button:
        return
    st.session_state.sidebar_state = "expanded"
    st.experimental_rerun()
def _preload_results(image_name):
    """Track the selected image and load its precomputed results when needed.

    Selecting a different image clears ``st.session_state.tested``. As a
    result, the precomputed results for the new image are loaded instead of
    keeping results from an earlier test run.

    Args:
        image_name: name of the currently selected image, or ``None``.
    """
    state = st.session_state
    if image_name != state.image_name:
        state.image_name = image_name
        state.tested = False

    # Nothing to load without an image; results from a completed test run
    # take precedence over the precomputed ones.
    if state.image_name is None or state.tested:
        return
    state.results = load_precomputed_results(image_name)
def demo(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the concept-testing demo page.

    The left column collects the image, model, class name, concepts and
    advanced testing settings. The right column shows results through
    ``viz_results``. When "Test Concepts" is clicked, ``test`` is run and its
    output is stored in ``st.session_state.results`` before the script
    reruns. Otherwise, precomputed results for the selected image are loaded.

    Args:
        device: torch device forwarded to ``test``. The default is evaluated
            once, when this module is imported (CUDA if available, else CPU).
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        st.header("Choose Image and Concepts")

        image_col, concepts_col = st.columns(2)

        with image_col:
            image_name, image = get_image()
            st.image(image, use_column_width=True)

            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            )
            _toggle_sidebar(change_image_button)

        with concepts_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name(image_name)
            concepts, concepts_ready, concepts_error = get_concepts(image_name)

            ready = class_ready and concepts_ready

            # One markdown bullet per reported input error.
            error_message = "".join(
                f"- {error}\n"
                for error in (class_error, concepts_error)
                if error is not None
            )
            if error_message:
                st.error(error_message)

        with st.container():
            (
                significance_level,
                tau_max,
                r,
                cardinality,
                dataset_name,
            ) = get_advanced_settings(concepts, concepts_ready)

            # `_disable` runs as a callback before the rerun, so both buttons
            # are already rendered disabled while the test below executes.
            test_button = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    if test_button:
        # Drop stale results before re-rendering the results column;
        # presumably viz_results shows an empty/placeholder state for None.
        st.session_state.results = None
        with columns[1]:
            viz_results()

        testing_config = get_testing_config(
            significance_level=significance_level, tau_max=tau_max, r=r
        )

        try:
            with columns[0]:
                results = test(
                    testing_config,
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset_name,
                    model_name,
                    device=device,
                )
        finally:
            # Always re-enable the UI. If `test` raises, or the run is
            # interrupted, the flag set by `_disable` would otherwise stay True
            # and keep both buttons disabled for the rest of the session.
            st.session_state.disabled = False

        st.session_state.tested = True
        st.session_state.results = results
        st.experimental_rerun()
    else:
        _preload_results(image_name)
        with columns[1]:
            viz_results()