Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from PIL import Image | |
| from streamlit_image_select import image_select | |
| from app_lib.utils import SUPPORTED_MODELS | |
| def _validate_class_name(class_name): | |
| if class_name is None: | |
| return (False, "Class name cannot be empty.") | |
| if class_name.strip() == "": | |
| return (False, "Class name cannot be empty.") | |
| return (True, None) | |
| def _validate_concepts(concepts): | |
| if len(concepts) < 3: | |
| return (False, "You must provide at least 3 concepts") | |
| if len(concepts) > 10: | |
| return (False, "Maximum 10 concepts allowed") | |
| return (True, None) | |
def get_model_name():
    """Render a select box listing every key of ``SUPPORTED_MODELS``.

    The widget is disabled whenever ``st.session_state.disabled`` is truthy.

    Returns:
        Whatever ``st.selectbox`` yields for the current selection.
    """
    model_names = list(SUPPORTED_MODELS.keys())
    is_disabled = st.session_state.disabled
    return st.selectbox(
        "Choose a model to test",
        options=model_names,
        help="Name of the vision-language model to test the predictions of.",
        disabled=is_disabled,
    )
def get_image():
    """Let the user upload an image in the sidebar, or pick a bundled one.

    The bundled-image selector is only rendered when no file has been uploaded.

    Returns:
        The result of ``PIL.Image.open`` on the uploaded file or on the
        selected asset path.
    """
    # NOTE(review): all four selectable entries point at the same file
    # ("assets/ace.jpg"); presumably placeholders for distinct sample images.
    sample_paths = [
        "assets/ace.jpg",
        "assets/ace.jpg",
        "assets/ace.jpg",
        "assets/ace.jpg",
    ]
    with st.sidebar:
        upload = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if upload:
            source = upload
        else:
            source = image_select(label="or select one", images=sample_paths)
    return Image.open(source)
def get_class_name():
    """Render a text input for the class name and validate its contents.

    The input defaults to "cat" and is disabled whenever
    ``st.session_state.disabled`` is truthy.

    Returns:
        A ``(class_name, class_ready, class_error)`` tuple: the raw text
        entered, whether ``_validate_class_name`` accepted it, and the
        validation error message (``None`` when accepted).
    """
    entered_name = st.text_input(
        "Class to test",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value="cat",
        disabled=st.session_state.disabled,
    )
    is_valid, message = _validate_class_name(entered_name)
    return entered_name, is_valid, message
| def get_concepts(): | |
| concepts = st.text_area( | |
| "Concepts to test (max 10)", | |
| help="List of concepts to test the predictions of the model with. Write one concept per line.", | |
| height=160, | |
| value="piano\ncute\nwhiskers\nmusic\nwild", | |
| disabled=st.session_state.disabled, | |
| ) | |
| concepts = concepts.split("\n") | |
| concepts = [concept.strip() for concept in concepts] | |
| concepts = [concept for concept in concepts if concept != ""] | |
| concepts = list(set(concepts)) | |
| concepts_ready, concepts_error = _validate_concepts(concepts) | |
| return concepts, concepts_ready, concepts_error | |
def get_cardinality(concepts, concepts_ready):
    """Render a slider selecting how many concepts to condition on.

    The slider ranges from 1 to ``len(concepts) - 1``, but its upper bound
    never drops below 2. It is disabled when ``st.session_state.disabled`` is
    truthy or when ``concepts_ready`` is falsy.

    Returns:
        Whatever ``st.slider`` yields for the current position.
    """
    # NOTE(review): the floor of 2 presumably keeps the slider range
    # non-degenerate when too few concepts were entered — confirm.
    upper_bound = len(concepts) - 1
    if upper_bound < 2:
        upper_bound = 2
    slider_disabled = st.session_state.disabled or not concepts_ready
    return st.slider(
        "Size of conditioning set",
        help="The number of concepts to condition model predictions on.",
        min_value=1,
        max_value=upper_bound,
        value=1,
        step=1,
        disabled=slider_disabled,
    )