|
|
import html
import json
import os
import sys

import streamlit as st
|
|
|
|
|
|
|
|
# Configure the page before anything else is rendered; Streamlit has
# historically required set_page_config to be the first st.* call.
st.set_page_config(layout="wide")

# Scenes offered in the sidebar "Scene" selector.
target_scenes = [
    "00814-p53SfW6mjZe",
    "00848-ziup5kvtCCR",
    "00891-cvZr5TUy5C5",
]

# Remote root from which the per-scene region/object view images are fetched.
DATA_BASE = "https://huggingface.co/datasets/bo-miao/LangMap_Demo/resolve/main"
|
|
|
|
|
@st.cache_data
def load_data():
    """Load the annotation and correspondence JSON files.

    Returns:
        tuple: ``(annot_data, corre_data)`` -- the parsed contents of
        ``src/compare_langmap_goat_matching.json`` and
        ``src/correspondence.json`` respectively.

    The result is memoized with ``st.cache_data`` so the files are parsed
    once instead of on every Streamlit rerun. Paths are relative to the
    current working directory.
    """
    # JSON is UTF-8 by specification; relying on the platform default
    # encoding (e.g. cp1252 on Windows) can fail on non-ASCII text.
    with open("src/compare_langmap_goat_matching.json", "r", encoding="utf-8") as f:
        annot_data = json.load(f)

    with open("src/correspondence.json", "r", encoding="utf-8") as f:
        corre_data = json.load(f)

    return annot_data, corre_data
|
|
|
|
|
annot_data, corre_data = load_data()

with st.sidebar:
    scene_name = st.selectbox("Scene", target_scenes)

    # Annotation records for the chosen scene; computed once and reused below.
    scene_annot_data = [x for x in annot_data if x["Scene"] == scene_name]

    # Sort numerically (key=int) so e.g. "10" comes after "9".
    included_oids = sorted((x["Object_id"] for x in scene_annot_data), key=int)

    obj_id = st.selectbox("Object", included_oids)

if obj_id is None:
    # st.selectbox returns None when it has no options; without this guard the
    # dictionary lookups below would raise KeyError and crash the page.
    st.warning(f"No annotated objects found for scene {scene_name}.")
    st.stop()

scene_corre_data = corre_data[scene_name]
scene_annot_data_dict = {x["Object_id"]: x for x in scene_annot_data}

# Regions/objects related to the selected object, plus the per-region object list.
obj_corre = scene_corre_data[obj_id]
related_regions = obj_corre["regions"]
related_objects = obj_corre["objects"]
related_region2objects = obj_corre["region2objects"]

# Both benchmark descriptions of the selected object and their MLLM match flags.
obj_annot = scene_annot_data_dict[obj_id]
lang_text = obj_annot["LangMap_Concise"]
goat_text = obj_annot["GOAT_Bench"]
lang_match = obj_annot["LangMap_Correct"]
goat_match = obj_annot["GOAT_Correct"]
|
|
|
|
|
|
|
|
# Inject page-wide CSS. The classes defined here are used by the HTML emitted
# in show_desc:
#   .row / .title  -> description line and its bold "<title>:" label
#   .ok / .bad     -> green / red colouring of the match mark
#   .match / .match2 -> the "MLLM One-to-Many Match" and "Human-verified Match"
#                       lines (negative top margins pull them up under .row)
st.markdown("""
<style>
.row{font-size:26px;margin-bottom:6px;}
.title{font-weight:900;}
.ok{color:#16a34a;}
.bad{color:#dc2626;}
.match{font-size:22px;margin-left:2px;margin-bottom:18px;margin-top:-8px;}
.match2{font-size:22px;margin-left:2px;margin-bottom:18px;margin-top:-16px;}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
def build_desc_html(title, text, ok):
    """Build the HTML snippet for one benchmark description.

    Args:
        title: Benchmark name shown as the bold label (e.g. "LangMap").
        text: The free-text object description; HTML-escaped before insertion.
        ok: Truthy if the MLLM judged the description a one-to-many match.

    Returns:
        str: HTML using the ``row``/``title``/``match``/``match2``/``ok``/``bad``
        CSS classes. When ``title == "LangMap"``, an extra always-positive
        "Human-verified Match" line is appended.
    """
    # Text glyphs (not emoji) so the .ok/.bad CSS colour actually applies.
    mark = "\u2713" if ok else "\u2717"
    cls = "ok" if ok else "bad"

    # Lines are joined with single newlines, never blank lines, so Markdown
    # treats the whole snippet as one raw-HTML block.
    parts = [
        '<div class="row">',
        f'<span class="title">{html.escape(str(title))}:</span> {html.escape(str(text))}',
        "</div>",
        f'<div class="match"><b>MLLM One-to-Many Match:</b> <span class="{cls}">{mark}</span></div>',
    ]
    if title == "LangMap":
        parts.append(
            '<div class="match2"><b>Human-verified Match:</b> <span class="ok">\u2713</span></div>'
        )
    return "\n".join(parts)


def show_desc(title, text, ok):
    """Render one benchmark description with its match marks via ``st.markdown``.

    Args:
        title: Benchmark name shown as the bold label (e.g. "LangMap").
        text: The free-text object description.
        ok: Truthy if the MLLM judged the description a one-to-many match.
    """
    st.markdown(build_desc_html(title, text, ok), unsafe_allow_html=True)
|
|
|
|
|
# Page heading: the id of the object picked in the sidebar.
_header_html = (
    f"<h3 style='text-align:center;font-size:40px;margin-top:-42px;'>"
    f"Selected Object_id: {obj_id}</h3>"
)
st.markdown(_header_html, unsafe_allow_html=True)

# LangMap description first, then GOAT-Bench, each with its match flag.
for _title, _text, _ok in (
    ("LangMap", lang_text, lang_match),
    ("GOAT-Bench", goat_text, goat_match),
):
    show_desc(_title, _text, _ok)

# Thin divider between the descriptions and the image grid.
st.markdown("<hr style='margin:6px 0;'>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NUM_COLS = 2  # region views per grid row
MAX_OBJ = 5  # object thumbnails per row beneath each region view

# Lay out the related regions NUM_COLS per row. Under each region view, show
# thumbnails of that region's objects, highlighting the selected object in red.
for i in range(0, len(related_regions), NUM_COLS):
    row_regions = related_regions[i:i + NUM_COLS]
    cols = st.columns(NUM_COLS)

    for col, region in zip(cols, row_regions):
        with col:
            # Build URLs with "/" explicitly: os.path.join would insert "\\"
            # on Windows and produce a broken URL.
            region_p = f"{DATA_BASE}/{scene_name}/region_views/{region}.jpg"
            st.image(region_p, use_container_width=False)

            obj_cols = st.columns(MAX_OBJ)
            # Use each object id directly instead of re-parsing it from the file
            # name, which would break if an id contained "_".
            for j, oid in enumerate(related_region2objects[region]):
                object_p = f"{DATA_BASE}/{scene_name}/object_views/{region}_{oid}.jpg"
                with obj_cols[j % MAX_OBJ]:
                    st.image(object_p, use_container_width=False)
                    # Compare as strings: the original compared the string parsed
                    # from the file name, and ids may be ints in the JSON.
                    display_color = "red" if str(oid) == str(obj_id) else "black"
                    st.markdown(
                        f"<div style='color:{display_color};font-size:19px;font-weight:600;text-align:center;margin-top:-12px;'>Object_id: {oid}</div>",
                        unsafe_allow_html=True,
                    )