Spaces:
Sleeping
Sleeping
| import torch | |
| import streamlit as st | |
| import time | |
| from app_lib.user_input import ( | |
| get_cardinality, | |
| get_class_name, | |
| get_concepts, | |
| get_image, | |
| get_model_name, | |
| ) | |
| from app_lib.test import ( | |
| load_dataset, | |
| load_model, | |
| encode_image, | |
| encode_concepts, | |
| encode_class_name, | |
| ) | |
def _disable():
    """Button callback that marks the UI as locked.

    Stores ``True`` under the ``disabled`` key of Streamlit's session state;
    ``main`` reads that key to decide whether its buttons are greyed out.
    """
    st.session_state["disabled"] = True
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the two-column page and run the test pipeline on demand.

    The left column collects the inputs (model name, image, class name,
    concepts, cardinality) and shows the "Change Image" and "Test" buttons.
    The right column runs the pipeline when "Test" was clicked: load the
    dataset and model, encode the concepts, the class name and the image.

    Args:
        device: torch device passed to the model-loading and encoding
            helpers. The default is evaluated once, when the function is
            defined: CUDA if available, otherwise CPU.
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        model_name = get_model_name()

        row1 = st.columns(2)
        row2 = st.columns(2)

        with row1[0]:
            image = get_image()
            st.image(image, use_column_width=True)

        with row1[1]:
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()
            cardinality = get_cardinality(concepts, concepts_ready)

        with row2[0]:
            change_image_button = st.button(
                "Change Image",
                use_container_width=True,
                disabled=st.session_state.disabled,
            )
            if change_image_button:
                st.session_state.sidebar_state = "expanded"
                # NOTE(review): newer Streamlit releases replace
                # experimental_rerun with st.rerun -- confirm pinned version.
                st.experimental_rerun()

        with row2[1]:
            ready = class_ready and concepts_ready

            # Collect every input problem into one markdown bullet list.
            error_message = ""
            if class_error is not None:
                error_message += f"- {class_error}\n"
            if concepts_error is not None:
                error_message += f"- {concepts_error}\n"
            if error_message:
                st.error(error_message)

            # The on_click callback locks the UI before the rerun that
            # executes the pipeline below.
            test_button = st.button(
                "Test",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    with columns[1]:
        if test_button:
            try:
                with st.spinner("Loading dataset"):
                    embedding = load_dataset("imagenette", model_name)
                    time.sleep(1)

                with st.spinner("Loading model"):
                    model, preprocess, tokenizer = load_model(model_name, device)
                    time.sleep(1)

                with st.spinner("Encoding concepts"):
                    cbm = encode_concepts(tokenizer, model, concepts, device)
                    time.sleep(1)

                with st.spinner("Preparing zero-shot classifier"):
                    classifier = encode_class_name(
                        tokenizer, model, class_name, device
                    )

                with st.spinner("Encoding image"):
                    h = encode_image(model, preprocess, image, device)
                z = h @ cbm.T

                print(h.shape, cbm.shape, z.shape)
                time.sleep(2)
            finally:
                # Always unlock the UI. Without this, a failure or an
                # interrupted run in any stage above would leave the
                # "disabled" flag set and the buttons greyed out for the
                # rest of the session.
                st.session_state.disabled = False

            # Redraw the page so the buttons pick up the cleared flag.
            st.experimental_rerun()