File size: 4,618 Bytes
59d358a
1c9e9cb
 
59d358a
1c9e9cb
 
 
 
 
 
 
 
 
 
 
bf02fd9
1c9e9cb
bf02fd9
1c9e9cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08635e8
1c9e9cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deeb026
1c9e9cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deeb026
1c9e9cb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import streamlit as st
import os, sys
import json


# Wide layout so the side-by-side region grid can use the full page width.
st.set_page_config(layout="wide")

# =====================
# Sidebar: scene / object
# =====================
# Scenes offered in the sidebar "Scene" selector.
target_scenes = [
    "00814-p53SfW6mjZe",
    "00848-ziup5kvtCCR",
    "00891-cvZr5TUy5C5",
]
# Base URL that the region/object view image URLs are built from.
DATA_BASE = "https://huggingface.co/datasets/bo-miao/LangMap_Demo/resolve/main"

@st.cache_data
def load_data():
    """Load the annotation and correspondence JSON files from ``src/``.

    Both paths are relative to the current working directory. The result is
    wrapped in ``st.cache_data``, so the files are not re-parsed on every
    script rerun.

    Returns:
        tuple: ``(annot_data, corre_data)`` where ``annot_data`` is the parsed
        content of ``compare_langmap_goat_matching.json`` (a list of dicts)
        and ``corre_data`` is the parsed content of ``correspondence.json``
        (nested dicts).

    Raises:
        FileNotFoundError: If either file is missing.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    # Pin the encoding: without it open() falls back to the platform locale,
    # which can fail or mis-decode non-ASCII text on non-UTF-8 systems.
    with open("src/compare_langmap_goat_matching.json", "r", encoding="utf-8") as f:
        annot_data = json.load(f)  # list of dict
    with open("src/correspondence.json", "r", encoding="utf-8") as f:
        corre_data = json.load(f)  # dict dict dict
    return annot_data, corre_data

annot_data, corre_data = load_data()


with st.sidebar:
    # The scene is chosen first; the object choices depend on it.
    scene_name = st.selectbox("Scene", target_scenes)

    # Object ids annotated for the chosen scene, in numeric order.
    included_oids = sorted(
        (entry['Object_id'] for entry in annot_data if entry['Scene'] == scene_name),
        key=int,
    )
    obj_id = st.selectbox("Object", included_oids)

# ===== load related data (outside sidebar) =====
# Narrow both data sources down to the selected scene.
scene_annot_data = [entry for entry in annot_data if entry['Scene'] == scene_name]
scene_corre_data = corre_data[scene_name]
scene_annot_data_dict = {entry["Object_id"]: entry for entry in scene_annot_data}

# Correspondence record of the selected object.
obj_corre = scene_corre_data[obj_id]
related_regions = obj_corre['regions']
related_objects = obj_corre['objects']
related_region2objects = obj_corre['region2objects']

# Annotation record of the selected object: both descriptions and their match flags.
obj_annot = scene_annot_data_dict[obj_id]
lang_text = obj_annot['LangMap_Concise']
goat_text = obj_annot['GOAT_Bench']
lang_match = obj_annot['LangMap_Correct']
goat_match = obj_annot['GOAT_Correct']


# Style sheet injected once per run; it defines the row/title/ok/bad/match/match2
# classes that the description HTML refers to.
_PAGE_CSS = """
<style>
.row{font-size:26px;margin-bottom:6px;}
.title{font-weight:900;}
.ok{color:#16a34a;}
.bad{color:#dc2626;}
.match{font-size:22px;margin-left:2px;margin-bottom:18px;margin-top:-8px;}
.match2{font-size:22px;margin-left:2px;margin-bottom:18px;margin-top:-16px;}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)


def show_desc(title, text, ok):
    """Render one benchmark description followed by its match verdict line(s).

    Args:
        title: Label shown in bold before the description. When it is exactly
            "LangMap", an extra "Human-verified Match" line is appended, which
            always shows a green check mark.
        text: Description text. It is inserted into the HTML as-is (not
            escaped).
        ok: Truthy renders a green check mark on the "MLLM One-to-Many Match"
            line; falsy renders a red cross.
    """
    # The glyphs are written as escapes so they survive any re-encoding of
    # this file; the previous literals had been corrupted into mojibake.
    check_mark, cross_mark = "\u2713", "\u2717"
    mark, cls = (check_mark, "ok") if ok else (cross_mark, "bad")

    # Lines are kept contiguous (no blank line in between), as in the
    # original markup, so the pieces are emitted as one HTML fragment.
    parts = [
        '<div class="row">',
        f'<span class="title">{title}:</span> {text}',
        '</div>',
        f'<div class="match"><b>MLLM One-to-Many Match:</b> <span class="{cls}">{mark}</span></div>',
    ]
    if title == "LangMap":
        parts.append(
            f'<div class="match2"><b>Human-verified Match:</b> <span class="ok">{check_mark}</span></div>'
        )

    st.markdown("\n".join(parts), unsafe_allow_html=True)

# Centered headline naming the current selection.
selected_header = (
    f"<h3 style='text-align:center;font-size:40px;margin-top:-42px;'>Selected Object_id: {obj_id}</h3>"
)
st.markdown(selected_header, unsafe_allow_html=True)

# One description block per benchmark, LangMap first, then a thin divider.
for bench_title, bench_text, bench_ok in (
    ("LangMap", lang_text, lang_match),
    ("GOAT-Bench", goat_text, goat_match),
):
    show_desc(bench_title, bench_text, bench_ok)
st.markdown("<hr style='margin:6px 0;'>", unsafe_allow_html=True)



# =====================
# Region + Object grid
# =====================
NUM_COLS = 2  # region views per grid row
MAX_OBJ = 5   # object thumbnails per row underneath each region view
for i in range(0, len(related_regions), NUM_COLS):
    # Regions shown in this grid row (the last row may hold fewer than NUM_COLS).
    row_regions = related_regions[i:i+NUM_COLS]
    # Always allocate NUM_COLS columns rather than len(row_regions); on a
    # short last row the unused column simply stays empty.
    cols = st.columns(NUM_COLS)

    for col, region in zip(cols, row_regions):
        with col:
            # region view
            # DATA_BASE is a URL, so it is joined with "/" explicitly;
            # os.path.join would insert backslashes on Windows.
            region_p = f"{DATA_BASE}/{scene_name}/region_views/{region}.jpg"
            st.image(region_p, use_container_width=False)

            # objects
            obj_cols = st.columns(MAX_OBJ)
            for j, raw_oid in enumerate(related_region2objects[region]):
                # Use the id from the data directly instead of parsing it back
                # out of the URL. str() keeps the comparison with obj_id
                # string-to-string, as the parsed id was before.
                oid = str(raw_oid)
                p = f"{DATA_BASE}/{scene_name}/object_views/{region}_{oid}.jpg"
                # Objects beyond MAX_OBJ wrap around into the same columns.
                with obj_cols[j % MAX_OBJ]:
                    st.image(p, use_container_width=False)
                    # Highlight the currently selected object in red.
                    display_color = "red" if oid == obj_id else "black"
                    st.markdown(
                        f"<div style='color:{display_color};font-size:19px;font-weight:600;text-align:center;margin-top:-12px;'>Object_id: {oid}</div>", unsafe_allow_html=True)