Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import streamlit as st | |
| from PIL import Image | |
| from streamlit_image_select import image_select | |
| from app_lib.defaults import Defaults as d | |
| from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS | |
# Directory holding the bundled example images offered in the image picker.
IMAGE_DIR = os.path.join("assets", "images")
# Sorted file names of the bundled ``.jpg`` examples, and their full paths
# (IMAGE_PATHS[i] corresponds to IMAGE_NAMES[i]).
IMAGE_NAMES = sorted(
    name for name in os.listdir(IMAGE_DIR) if name.endswith(".jpg")
)
IMAGE_PATHS = [os.path.join(IMAGE_DIR, name) for name in IMAGE_NAMES]
# Per-image presets, looked up below by the image file name without its
# extension; each entry provides a "class_name" and a list of "concepts".
_IMAGE_PRESETS_PATH = os.path.join("assets", "image_presets.json")
# Use a context manager so the file handle is closed right after parsing.
with open(_IMAGE_PRESETS_PATH, encoding="utf-8") as _presets_file:
    IMAGE_PRESETS = json.load(_presets_file)
| def _validate_class_name(class_name): | |
| if class_name is None: | |
| return (False, "Class name cannot be empty.") | |
| if class_name.strip() == "": | |
| return (False, "Class name cannot be empty.") | |
| return (True, None) | |
| def _validate_concepts(concepts): | |
| if len(concepts) < 3: | |
| return (False, "You must provide at least 3 concepts") | |
| if len(concepts) > 10: | |
| return (False, "Maximum 10 concepts allowed") | |
| return (True, None) | |
def _get_significance_level():
    """Render the significance-level slider and return the selected value.

    The slider ranges from ``d.SIGNIFICANCE_LEVEL_STEP`` to ``1.0`` in steps of
    ``d.SIGNIFICANCE_LEVEL_STEP``, starts at ``d.SIGNIFICANCE_LEVEL_VALUE``, and
    is disabled while ``st.session_state.disabled`` is truthy.
    """
    increment = d.SIGNIFICANCE_LEVEL_STEP
    initial = d.SIGNIFICANCE_LEVEL_VALUE
    tooltip = f"The level of significance of the tests. Defaults to {initial:.2F}."
    return st.slider(
        "Significance level",
        min_value=increment,
        max_value=1.0,
        value=initial,
        step=increment,
        help=tooltip,
        disabled=st.session_state.disabled,
    )
def _get_tau_max():
    """Render the test-length slider and return the selected value as an int.

    The slider ranges from ``d.TAU_MAX_STEP`` to ``1000`` in steps of
    ``d.TAU_MAX_STEP`` and starts at ``d.TAU_MAX_VALUE``.
    """
    increment = d.TAU_MAX_STEP
    initial = d.TAU_MAX_VALUE
    chosen = st.slider(
        "Length of test",
        help=f"The maximum number of steps for each test. Defaults to {initial}.",
        min_value=increment,
        max_value=1000,
        step=increment,
        value=initial,
        disabled=st.session_state.disabled,
    )
    return int(chosen)
def _get_number_of_tests():
    """Render the tests-per-concept slider and return the selected value as an int.

    The slider ranges from ``d.R_STEP`` to ``100`` in steps of ``d.R_STEP`` and
    starts at ``d.R_VALUE``.
    """
    increment = d.R_STEP
    initial = d.R_VALUE
    tooltip = (
        "The number of tests to average for each concept. "
        f"Defaults to {initial}."
    )
    chosen = st.slider(
        "Number of tests per concept",
        help=tooltip,
        min_value=increment,
        max_value=100,
        step=increment,
        value=initial,
        disabled=st.session_state.disabled,
    )
    return int(chosen)
def _get_cardinality(concepts, concepts_ready):
    """Render the conditioning-set-size slider and return the selected value.

    Args:
        concepts: Concepts entered by the user; only its length is used, to
            bound the slider's maximum.
        concepts_ready: Whether the concepts passed validation; the slider is
            disabled when they did not.

    Returns:
        The value selected on the slider (between 1 and the computed maximum).
    """
    default = d.CARDINALITY_VALUE
    step = d.CARDINALITY_STEP
    # Upper bound is one less than the number of concepts, but never below 2.
    max_value = max(2, len(concepts) - 1)
    return st.slider(
        "Size of conditioning set",
        help=(
            "The number of concepts to condition model predictions on. "
            f"Defaults to {default}."
        ),
        min_value=1,
        max_value=max_value,
        # Clamp so the initial value always lies inside [min_value, max_value];
        # st.slider rejects a default above max_value (e.g. with few concepts).
        value=min(default, max_value),
        step=step,
        disabled=st.session_state.disabled or not concepts_ready,
    )
def _get_dataset_name():
    """Render the dataset selector and return the chosen dataset name.

    Options come from ``SUPPORTED_DATASETS``; the initial selection is
    ``d.DATASET_NAME`` (``ValueError`` if it is not among the options).
    """
    options = list(SUPPORTED_DATASETS)
    default_idx = options.index(d.DATASET_NAME)
    return st.selectbox(
        "Dataset",
        options=options,
        index=default_idx,
        help=(
            # Trailing space keeps the two sentences from running together.
            "Name of the dataset to use to train sampler. "
            f"Defaults to {options[default_idx]}."
        ),
        disabled=st.session_state.disabled,
    )
def get_model_name():
    """Render the model selector and return the chosen model name.

    Options are the entries of ``SUPPORTED_MODELS``; the initial selection is
    ``d.MODEL_NAME`` (``ValueError`` if it is not among the options).
    """
    options = list(SUPPORTED_MODELS)
    default_idx = options.index(d.MODEL_NAME)
    return st.selectbox(
        "Model to test",
        options=options,
        index=default_idx,
        help=(
            # Trailing space keeps the two sentences from running together.
            "Name of the vision-language model to test the predictions of. "
            f"Defaults to {options[default_idx]}."
        ),
        disabled=st.session_state.disabled,
    )
def get_image():
    """Let the user upload an image or pick one of the bundled examples.

    The widgets are rendered in the sidebar. An uploaded file takes precedence;
    the example picker is shown only when nothing has been uploaded, with
    ``bowl_ace.jpg`` preselected.

    Returns:
        A ``(image_name, image)`` tuple. ``image_name`` is ``None`` for an
        uploaded file, otherwise the example's file name from ``IMAGE_NAMES``;
        ``image`` is the result of ``PIL.Image.open``.
    """
    with st.sidebar:
        upload = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if upload is not None:
            return (None, Image.open(upload))

        selected = image_select(
            label="or select one",
            images=IMAGE_PATHS,
            index=IMAGE_NAMES.index("bowl_ace.jpg"),
            return_value="index",
        )
        return (IMAGE_NAMES[selected], Image.open(IMAGE_PATHS[selected]))
def get_class_name(image_name=None):
    """Render the class-name text input and validate the entered value.

    Args:
        image_name: File name of the selected example image. When truthy, its
            part before the first "." selects the ``IMAGE_PRESETS`` entry whose
            "class_name" pre-fills the input; otherwise the input starts empty.

    Returns:
        ``(class_name, class_ready, class_error)``, where the last two are the
        result of ``_validate_class_name(class_name)``.
    """
    if image_name:
        preset_key = image_name.split(".")[0]
        prefill = IMAGE_PRESETS[preset_key]["class_name"]
    else:
        prefill = ""

    class_name = st.text_input(
        "Class to predict",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value=prefill,
        disabled=st.session_state.disabled,
        placeholder="Type class name here",
    )
    return (class_name, *_validate_class_name(class_name))
def get_concepts(image_name=None):
    """Render the concepts text area, parse its contents, and validate them.

    Args:
        image_name: File name of the selected example image. When truthy, its
            part before the first "." selects the ``IMAGE_PRESETS`` entry whose
            "concepts" pre-fill the text area (one per line); otherwise it
            starts empty.

    Returns:
        ``(concepts, concepts_ready, concepts_error)``. ``concepts`` holds the
        entered lines, whitespace-stripped, with blank lines dropped and
        duplicates removed (first occurrence kept, original order preserved).
        The other two are the result of ``_validate_concepts(concepts)``.
    """
    default = (
        "\n".join(IMAGE_PRESETS[image_name.split(".")[0]]["concepts"])
        if image_name
        else ""
    )
    raw_text = st.text_area(
        "Concepts to test",
        help=(
            "List of concepts to test the predictions of the model with. "
            "Write one concept per line. Maximum 10 concepts allowed."
        ),
        height=180,
        value=default,
        disabled=st.session_state.disabled,
        placeholder="Type one concept\nper line",
    )
    stripped = (line.strip() for line in raw_text.split("\n"))
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # scramble the user's order, and that order can change between processes
    # because str hashing is randomized.
    concepts = list(dict.fromkeys(line for line in stripped if line))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_advanced_settings(concepts, concepts_ready):
    """Render the "Advanced settings" expander and return the chosen values.

    Widgets are rendered in this order: dataset, significance level, test
    length, tests per concept, conditioning-set size.

    Args:
        concepts: Parsed concepts; forwarded to ``_get_cardinality``.
        concepts_ready: Whether the concepts are valid; forwarded to
            ``_get_cardinality``.

    Returns:
        ``(significance_level, tau_max, r, cardinality, dataset_name)``.
    """
    with st.expander("Advanced settings"):
        chosen_dataset = _get_dataset_name()
        alpha = _get_significance_level()
        test_length = _get_tau_max()
        num_tests = _get_number_of_tests()
        conditioning_size = _get_cardinality(concepts, concepts_ready)
        # NOTE(review): the original indentation was lost; confirm whether this
        # divider belongs inside the expander or just after it.
        st.divider()

    return alpha, test_length, num_tests, conditioning_size, chosen_dataset