File size: 2,821 Bytes
4f55ca2
80dc74c
4f55ca2
80dc74c
 
 
 
 
 
 
 
4f55ca2
 
 
 
 
 
 
 
 
 
 
80dc74c
 
4f55ca2
80dc74c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f55ca2
 
 
 
 
80dc74c
 
 
 
 
 
 
 
 
 
 
4f55ca2
 
80dc74c
 
 
 
4f55ca2
 
80dc74c
 
4f55ca2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import streamlit as st
import time

from app_lib.user_input import (
    get_cardinality,
    get_class_name,
    get_concepts,
    get_image,
    get_model_name,
)
from app_lib.test import (
    load_dataset,
    load_model,
    encode_image,
    encode_concepts,
    encode_class_name,
)


def _disable():
    """Flag the UI as busy; used as the "Test" button's ``on_click`` callback.

    Streamlit runs the callback before re-executing the script, so widgets
    that read ``disabled`` from session state render disabled on that run.
    """
    st.session_state["disabled"] = True


def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the concept-testing page and run the test pipeline on demand.

    The left column collects the model name, image, class name, concepts and
    cardinality via the ``app_lib.user_input`` helpers, and shows the
    "Change Image" and "Test" buttons. When "Test" is pressed, the right
    column loads the dataset and model, encodes the concepts, class name and
    image under spinners, re-enables the buttons and triggers a rerun.

    Args:
        device: Torch device passed to the model/encoding helpers. The
            default is evaluated once at import time: CUDA if available,
            otherwise CPU.
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        model_name = get_model_name()

        row1 = st.columns(2)
        row2 = st.columns(2)

        with row1[0]:
            image = get_image()
            st.image(image, use_column_width=True)
        with row1[1]:
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()
            cardinality = get_cardinality(concepts, concepts_ready)

        with row2[0]:
            change_image_button = st.button(
                "Change Image",
                use_container_width=True,
                # Default to enabled if the flag was never initialized.
                disabled=st.session_state.get("disabled", False),
            )
            if change_image_button:
                st.session_state.sidebar_state = "expanded"
                st.experimental_rerun()
        with row2[1]:
            ready = class_ready and concepts_ready

            error_message = ""
            if class_error is not None:
                error_message += f"- {class_error}\n"
            if concepts_error is not None:
                error_message += f"- {concepts_error}\n"
            if error_message:
                st.error(error_message)

            test_button = st.button(
                "Test",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.get("disabled", False) or not ready,
            )

    with columns[1]:
        if test_button:
            try:
                # NOTE(review): the sleeps below look cosmetic (they keep each
                # spinner visible) — confirm before removing.
                with st.spinner("Loading dataset"):
                    # NOTE(review): ``embedding`` is not used yet in this view.
                    embedding = load_dataset("imagenette", model_name)
                    time.sleep(1)

                with st.spinner("Loading model"):
                    model, preprocess, tokenizer = load_model(model_name, device)
                    time.sleep(1)

                with st.spinner("Encoding concepts"):
                    cbm = encode_concepts(tokenizer, model, concepts, device)
                    time.sleep(1)

                with st.spinner("Preparing zero-shot classifier"):
                    # NOTE(review): ``classifier`` is not used yet in this view.
                    classifier = encode_class_name(
                        tokenizer, model, class_name, device
                    )

                with st.spinner("Encoding image"):
                    h = encode_image(model, preprocess, image, device)
                    # Project the image encoding onto the concept encodings;
                    # presumably per-concept scores — not displayed yet.
                    z = h @ cbm.T
                    time.sleep(2)
            finally:
                # ``_disable`` locked the buttons when "Test" was clicked;
                # always unlock them, even if a step above raised, so the UI
                # is not left permanently disabled for this session.
                st.session_state.disabled = False
            st.experimental_rerun()