File size: 4,618 Bytes
59d358a 1c9e9cb 59d358a 1c9e9cb bf02fd9 1c9e9cb bf02fd9 1c9e9cb 08635e8 1c9e9cb deeb026 1c9e9cb deeb026 1c9e9cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import streamlit as st
import os, sys
import json
# Page configuration comes before any other Streamlit call in this script.
st.set_page_config(layout="wide")

# =====================
# Sidebar: scene / object
# =====================
# Scenes offered in the sidebar "Scene" selector.
target_scenes = [
    "00814-p53SfW6mjZe",
    "00848-ziup5kvtCCR",
    "00891-cvZr5TUy5C5",
]
# Base URL under which the region / object view images are hosted.
DATA_BASE = "https://huggingface.co/datasets/bo-miao/LangMap_Demo/resolve/main"
@st.cache_data
def load_data():
    """Load the annotation and correspondence JSON files shipped with the app.

    Both files are read as UTF-8 explicitly so that decoding does not depend
    on the platform's default locale encoding.

    Returns:
        tuple: ``(annot_data, corre_data)`` where ``annot_data`` is the parsed
        content of ``src/compare_langmap_goat_matching.json`` (a list of
        dicts) and ``corre_data`` is the parsed content of
        ``src/correspondence.json`` (nested dicts).

    Raises:
        OSError: if either file cannot be opened.
        json.JSONDecodeError: if either file is not valid JSON.
    """
    with open("src/compare_langmap_goat_matching.json", "r", encoding="utf-8") as f:
        annot_data = json.load(f)  # list of dict
    with open("src/correspondence.json", "r", encoding="utf-8") as f:
        corre_data = json.load(f)  # dict dict dict
    return annot_data, corre_data
annot_data, corre_data = load_data()

with st.sidebar:
    # select scene and object
    scene_name = st.selectbox("Scene", target_scenes)
    # Filter the annotations for the chosen scene once; the result feeds both
    # the object selector here and the per-object lookups below.
    scene_annot_data = [x for x in annot_data if x['Scene'] == scene_name]
    # Object ids are sorted numerically (key=int), not lexicographically.
    included_oids = sorted((x['Object_id'] for x in scene_annot_data), key=int)
    obj_id = st.selectbox("Object", included_oids)

# st.selectbox yields None when it has no options to offer; stop this run with
# a readable message instead of failing on the dictionary lookups below.
if obj_id is None:
    st.warning(f"No annotated objects found for scene {scene_name}.")
    st.stop()

# ===== load related data (outside sidebar) =====
scene_corre_data = corre_data[scene_name]
scene_annot_data_dict = {x["Object_id"]: x for x in scene_annot_data}

# Correspondence entry of the selected object: the regions and objects related
# to it, plus the mapping from each region to the objects shown for it.
obj_corre = scene_corre_data[obj_id]
related_regions = obj_corre['regions']
related_objects = obj_corre['objects']
related_region2objects = obj_corre['region2objects']

# Annotation entry of the selected object: the two descriptions and the
# per-benchmark match flags.
obj_annot = scene_annot_data_dict[obj_id]
lang_text = obj_annot['LangMap_Concise']
goat_text = obj_annot['GOAT_Bench']
lang_match = obj_annot['LangMap_Correct']
goat_match = obj_annot['GOAT_Correct']
# Stylesheet for the description rows and match lines rendered by show_desc.
# The lines are kept flush-left inside the literal on purpose.
_page_css = """
<style>
.row{font-size:26px;margin-bottom:6px;}
.title{font-weight:900;}
.ok{color:#16a34a;}
.bad{color:#dc2626;}
.match{font-size:22px;margin-left:2px;margin-bottom:18px;margin-top:-8px;}
.match2{font-size:22px;margin-left:2px;margin-bottom:18px;margin-top:-16px;}
</style>
"""
st.markdown(_page_css, unsafe_allow_html=True)
def show_desc(title, text, ok):
    """Render one description row followed by its match verdict line(s).

    Args:
        title: Label shown in bold before the description. When it equals
            "LangMap" an extra "Human-verified Match" line is appended.
        text: Description text. It is interpolated into the HTML unescaped.
        ok: Truthy when the MLLM one-to-many match succeeded; selects both the
            mark glyph and its colour class ("ok" / "bad").
    """
    # The success and failure marks must be different glyphs. They are written
    # as HTML entities (check mark / ballot X) so the source stays pure ASCII
    # and cannot be damaged by an encoding round-trip.
    # NOTE(review): the original glyphs were lost to an encoding error; a
    # check mark and a cross are assumed here - confirm the intended symbols.
    check_mark = "&#x2714;"
    cross_mark = "&#x2718;"
    mark = check_mark if ok else cross_mark
    cls = "ok" if ok else "bad"

    html_lines = [
        '<div class="row">',
        f'<span class="title">{title}:</span> {text}',
        '</div>',
        f'<div class="match"><b>MLLM One-to-Many Match:</b> <span class="{cls}">{mark}</span></div>',
    ]
    if title == "LangMap":
        # This line always shows the success mark, independent of `ok`.
        html_lines.append(
            f'<div class="match2"><b>Human-verified Match:</b> <span class="ok">{check_mark}</span></div>'
        )
    st.markdown("\n".join(html_lines), unsafe_allow_html=True)
# Centered heading naming the currently selected object.
heading_html = (
    "<h3 style='text-align:center;font-size:40px;margin-top:-42px;'>"
    f"Selected Object_id: {obj_id}</h3>"
)
st.markdown(heading_html, unsafe_allow_html=True)

# One description panel per benchmark, LangMap first.
for bench_title, bench_text, bench_ok in (
    ("LangMap", lang_text, lang_match),
    ("GOAT-Bench", goat_text, goat_match),
):
    show_desc(bench_title, bench_text, bench_ok)

# Thin divider between the descriptions and the image grid.
st.markdown("<hr style='margin:6px 0;'>", unsafe_allow_html=True)
# =====================
# Region + Object grid
# =====================
NUM_COLS = 2  # region views per grid row
MAX_OBJ = 5   # object thumbnails per row under each region view

for i in range(0, len(related_regions), NUM_COLS):
    # Regions shown in this grid row; the last row may hold fewer than NUM_COLS.
    row_regions = related_regions[i:i + NUM_COLS]
    # Always allocate NUM_COLS columns so a short last row keeps the same
    # column width as the full rows.
    cols = st.columns(NUM_COLS)
    for col, region in zip(cols, row_regions):
        with col:
            # region view
            # URLs are joined with "/" explicitly: os.path.join would insert
            # the platform separator, i.e. a backslash on Windows.
            region_url = f"{DATA_BASE}/{scene_name}/region_views/{region}.jpg"
            st.image(region_url, use_container_width=False)

            # objects
            obj_cols = st.columns(MAX_OBJ)
            for j, oid in enumerate(related_region2objects[region]):
                # Compare ids as text so the highlight check below matches the
                # selected obj_id even if the JSON stores ids as numbers.
                oid = str(oid)
                object_url = f"{DATA_BASE}/{scene_name}/object_views/{region}_{oid}.jpg"
                # More than MAX_OBJ objects wrap around into the same columns.
                with obj_cols[j % MAX_OBJ]:
                    st.image(object_url, use_container_width=False)
                    # The currently selected object is captioned in red.
                    display_color = "red" if oid == obj_id else "black"
                    st.markdown(
                        f"<div style='color:{display_color};font-size:19px;font-weight:600;"
                        f"text-align:center;margin-top:-12px;'>Object_id: {oid}</div>",
                        unsafe_allow_html=True,
                    )