LangMap / src /streamlit_app.py
bo-miao's picture
Upload streamlit_app.py
08635e8 verified
import streamlit as st
import os, sys
import json
# Render the app in Streamlit's "wide" page layout.
st.set_page_config(layout="wide")

# =====================
# Sidebar: scene / object
# =====================
# Scene identifiers offered by the "Scene" selector in the sidebar.
target_scenes = [
    "00814-p53SfW6mjZe",
    "00848-ziup5kvtCCR",
    "00891-cvZr5TUy5C5",
]
# Base URL that region-view and object-view image paths are appended to.
DATA_BASE = "https://huggingface.co/datasets/bo-miao/LangMap_Demo/resolve/main"
@st.cache_data
def load_data():
    """Load the annotation and correspondence JSON files used by the app.

    Both files are read from the ``src`` directory, resolved relative to the
    current working directory. The function is wrapped in ``st.cache_data``.

    Returns:
        tuple: ``(annot_data, corre_data)`` where ``annot_data`` is the parsed
        content of ``compare_langmap_goat_matching.json`` (a list of dicts)
        and ``corre_data`` is the parsed content of ``correspondence.json``
        (dicts nested three levels deep).

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    # JSON text is UTF-8; pass the encoding explicitly so decoding does not
    # depend on the platform's default locale encoding.
    with open("src/compare_langmap_goat_matching.json", "r", encoding="utf-8") as f:
        annot_data = json.load(f)  # list of dict
    with open("src/correspondence.json", "r", encoding="utf-8") as f:
        corre_data = json.load(f)  # dict dict dict
    return annot_data, corre_data
annot_data, corre_data = load_data()

with st.sidebar:
    # Pick the scene first, then one of the objects annotated in that scene.
    scene_name = st.selectbox("Scene", target_scenes)
    # Annotation records belonging to the chosen scene.
    scene_annot_data = [rec for rec in annot_data if rec['Scene'] == scene_name]
    # Object ids are ordered by their integer value.
    included_oids = sorted((rec['Object_id'] for rec in scene_annot_data), key=int)
    obj_id = st.selectbox("Object", included_oids)

# ===== load related data (outside sidebar) =====
scene_corre_data = corre_data[scene_name]
scene_annot_data_dict = {rec["Object_id"]: rec for rec in scene_annot_data}

# Correspondence entry of the selected object: related regions, related
# objects, and the per-region object lists.
obj_corre = scene_corre_data[obj_id]
related_regions = obj_corre['regions']
related_objects = obj_corre['objects']
related_region2objects = obj_corre['region2objects']

# Annotation entry of the selected object: both descriptions and their
# match flags.
obj_annot = scene_annot_data_dict[obj_id]
lang_text = obj_annot['LangMap_Concise']
goat_text = obj_annot['GOAT_Bench']
lang_match = obj_annot['LangMap_Correct']
goat_match = obj_annot['GOAT_Correct']
# Stylesheet for the description rows and match lines; the class names here
# are the ones referenced by the HTML that show_desc emits.
_DESC_STYLE = """
<style>
.row{font-size:26px;margin-bottom:6px;}
.title{font-weight:900;}
.ok{color:#16a34a;}
.bad{color:#dc2626;}
.match{font-size:22px;margin-left:2px;margin-bottom:18px;margin-top:-8px;}
.match2{font-size:22px;margin-left:2px;margin-bottom:18px;margin-top:-16px;}
</style>
"""
st.markdown(_DESC_STYLE, unsafe_allow_html=True)
def show_desc(title, text, ok):
    """Render one description row followed by its match indicator line(s).

    Args:
        title: Label shown in bold before the description. When it equals
            ``"LangMap"`` an extra "Human-verified Match" line with a check
            mark is appended.
        text: Description text, inserted into the HTML as-is.
        ok: Truthy to show a check mark with the ``ok`` CSS class for the
            "MLLM One-to-Many Match" line; falsy to show a cross with the
            ``bad`` CSS class.
    """
    # Use escapes for U+2713 (check mark) and U+2717 (ballot x) so the glyphs
    # cannot be corrupted again by a wrong-codec round trip of this file.
    check, cross = "\u2713", "\u2717"
    mark = check if ok else cross
    cls = "ok" if ok else "bad"

    # NOTE(review): title and text are interpolated unescaped and rendered
    # with unsafe_allow_html=True - confirm the annotation text never contains
    # stray "<" or "&" characters.
    html = f"""
<div class="row">
<span class="title">{title}:</span> {text}
</div>
<div class="match"><b>MLLM One-to-Many Match:</b> <span class="{cls}">{mark}</span></div>
"""
    if title == "LangMap":
        # Only the LangMap row carries the additional human-verified line.
        html += f"""<div class="match2"><b>Human-verified Match:</b> <span class="ok">{check}</span></div>
"""
    st.markdown(html, unsafe_allow_html=True)
# Centered heading naming the currently selected object.
selected_heading = (
    "<h3 style='text-align:center;font-size:40px;margin-top:-42px;'>"
    f"Selected Object_id: {obj_id}</h3>"
)
st.markdown(selected_heading, unsafe_allow_html=True)

# One description row per source, each with its own match flag.
for desc_title, desc_text, desc_ok in (
    ("LangMap", lang_text, lang_match),
    ("GOAT-Bench", goat_text, goat_match),
):
    show_desc(desc_title, desc_text, desc_ok)

# Thin separator between the descriptions and the image grid.
st.markdown("<hr style='margin:6px 0;'>", unsafe_allow_html=True)
# =====================
# Region + Object grid
# =====================
NUM_COLS = 2  # region panels per grid row
MAX_OBJ = 5   # object thumbnails per row inside one region panel


def _view_url(kind, name):
    """Return the URL of ``<name>.jpg`` in the ``kind`` folder of the selected scene."""
    # URLs always use "/" as separator, so build them explicitly;
    # os.path.join would insert the platform separator (a backslash on
    # Windows) and produce a broken URL.
    return f"{DATA_BASE}/{scene_name}/{kind}/{name}.jpg"


for i in range(0, len(related_regions), NUM_COLS):
    # Regions shown in this grid row.
    row_regions = related_regions[i:i + NUM_COLS]
    # Always create NUM_COLS columns (not len(row_regions)); a short last row
    # simply leaves the remaining columns empty.
    cols = st.columns(NUM_COLS)
    for col, region in zip(cols, row_regions):
        with col:
            # Region view.
            st.image(_view_url("region_views", region), use_container_width=False)

            # Object views of this region; thumbnails beyond MAX_OBJ wrap
            # around into the same columns.
            obj_cols = st.columns(MAX_OBJ)
            for j, oid in enumerate(related_region2objects[region]):
                with obj_cols[j % MAX_OBJ]:
                    st.image(
                        _view_url("object_views", f"{region}_{oid}"),
                        use_container_width=False,
                    )
                    # Highlight the selected object's label. Compare as
                    # strings so the check does not depend on whether ids are
                    # stored as numbers or strings.
                    display_color = "red" if str(oid) == str(obj_id) else "black"
                    st.markdown(
                        f"<div style='color:{display_color};font-size:19px;font-weight:600;text-align:center;margin-top:-12px;'>Object_id: {oid}</div>",
                        unsafe_allow_html=True,
                    )