# AnonymousECCV15285's picture
# Upload 143 files
# 51c36ad verified
"""
Dataset Visualizer
MMB-style CSVs from upload, output/, or data/.
Image sets table includes thumbnails; overview, difficulty & questions, answer matrix.
"""
from __future__ import annotations
import io
import json
import shutil
import tempfile
import zipfile
import streamlit as st
import pandas as pd
import plotly.express as px
from pathlib import Path
# Page-level Streamlit configuration (browser tab title, wide layout).
st.set_page_config(page_title="Dataset Visualizer", layout="wide")
# Directories searched for bundled CSVs; all are resolved relative to the
# folder that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parent
OUTPUT_DIR = PROJECT_ROOT / "output"
DATA_DIR = PROJECT_ROOT / "data"
HF_DATASET_DIR = PROJECT_ROOT / "hf_dataset" # For Hugging Face Space / repo bundle
# Root used to shorten CSV paths into relative labels in the sidebar picker.
WORKSPACE_ROOT = PROJECT_ROOT
# MMB CSV markers
# Image filename columns: the original scene plus its two counterfactuals.
MMB_IMAGE_COLS = ["original_image", "counterfactual1_image", "counterfactual2_image"]
# Question text columns, in the same original/CF1/CF2 order.
MMB_QA_COLS = ["original_question", "counterfactual1_question", "counterfactual2_question"]
# Per-question difficulty columns, in the same original/CF1/CF2 order.
MMB_DIFF_COLS = [
    "original_question_difficulty",
    "counterfactual1_question_difficulty",
    "counterfactual2_question_difficulty",
]
def resolve_image_path(csv_path: Path, fname: str) -> Path | None:
"""Try multiple locations; return first existing path for fname."""
base = csv_path.resolve().parent
candidates = [
base / "images" / fname,
base / fname,
base.parent / "images" / fname,
base.parent / fname,
]
for p in candidates:
if p.exists() and p.is_file():
return p
return None
def scene_id_from_original(original_image: str) -> str:
    """Derive a scene id from an original-image file name.

    Example: ``'scene_0001_original.png'`` -> ``'scene_0001'``.

    Exactly one trailing marker is removed, trying ``_original.png``,
    ``_original`` and ``.png`` in that order. Matching ignores case and
    surrounding whitespace so that this function agrees with
    ``scene_id_from_image_name`` (e.g. ``'Scene_1_ORIGINAL.PNG'`` ->
    ``'Scene_1'``).

    Returns:
        The stripped name, or ``original_image`` unchanged when stripping
        would leave an empty string.
    """
    name = str(original_image).strip()
    lowered = name.lower()
    for suffix in ("_original.png", "_original", ".png"):
        if lowered.endswith(suffix):
            # Slice the original-cased text so the id keeps its casing.
            name = name[: -len(suffix)]
            break
    return name.strip() or original_image
def scene_id_from_image_name(fname: str) -> str:
    """Derive a scene id from any image file name of a scene set.

    Examples: ``'scene_0001_cf1.png'`` and ``'scene_0001_original.png'`` both
    give ``'scene_0001'``. The ``_original``, ``_cf1`` and ``_cf2`` markers
    (with or without ``.png``) and a bare ``.png`` are recognised
    case-insensitively; only the first matching marker is removed.

    Returns:
        The whitespace-stripped id, or ``fname`` unchanged when the result
        would be empty.
    """
    known_suffixes = (
        "_original.png", "_original",
        "_cf1.png", "_cf1",
        "_cf2.png", "_cf2",
        ".png",
    )
    stem = str(fname).strip()
    folded = stem.lower()
    # Order matters: the longer "<marker>.png" forms come before the bare ones.
    matched = next((suf for suf in known_suffixes if folded.endswith(suf.lower())), None)
    if matched is not None:
        stem = stem[: -len(matched)]
    return stem.strip() or fname
def _row_scene_id(row, index: int) -> str:
    """Return the display id for a CSV row.

    Uses the value of the first image column when it is a non-empty string;
    otherwise falls back to a 1-based ``row_<n>`` placeholder built from the
    0-based ``index``.
    """
    first_image = row.get(MMB_IMAGE_COLS[0], None)
    if not (isinstance(first_image, str) and first_image):
        return f"row_{index + 1}"
    return scene_id_from_original(first_image)
def resolve_scenes_dir(csv_path: Path) -> Path | None:
"""Scenes dir: <csv_parent>/scenes/ or <csv_grandparent>/scenes/."""
base = csv_path.resolve().parent
for candidate in (base / "scenes", base.parent / "scenes"):
if candidate.is_dir():
return candidate
return None
def _scene_id_for_lookup(s: str) -> str:
    """Normalise a scene id for scene-file lookup.

    Falsy input yields ``""``; anything else is converted to ``str``,
    stripped of surrounding whitespace and lower-cased.
    """
    if not s:
        return ""
    return str(s).strip().lower()
@st.cache_data
def _get_cf_type_from_scene_file(csv_path: Path, scene_id: str, variant: str) -> str | None:
    """Read ``cf_metadata.cf_type`` from ``scenes/{scene_id}_{variant}.json``.

    The scenes directory is located relative to ``csv_path`` and the scene id
    is normalised (stripped, lower-cased) before the file name is built.

    Returns:
        ``cf_type`` as a string, or ``None`` when the scenes directory, the
        normalised id or the file is missing, when the JSON declares a
        ``variant`` that differs (case-insensitively) from the requested one,
        when ``cf_type`` is absent, or when reading/parsing fails for any
        reason (the lookup is deliberately best-effort).
    """
    scenes_dir = resolve_scenes_dir(csv_path)
    if not scenes_dir:
        return None
    normalized_id = _scene_id_for_lookup(scene_id)
    if not normalized_id:
        return None
    scene_file = scenes_dir / f"{normalized_id}_{variant}.json"
    if not scene_file.exists():
        return None
    try:
        with open(scene_file, encoding="utf-8") as handle:
            payload = json.load(handle)
        meta = payload.get("cf_metadata") or {}
        declared_variant = meta.get("variant")
        # A variant recorded in the file must agree with the one asked for.
        if declared_variant is not None and str(declared_variant).lower() != variant.lower():
            return None
        cf_type = meta.get("cf_type")
        return None if cf_type is None else str(cf_type)
    except Exception:
        # Unreadable or unexpectedly shaped scene files count as "no type".
        return None
def get_cf_types_for_scene(csv_path: Path | None, row) -> tuple[str | None, str | None]:
    """Look up the counterfactual types for one CSV row.

    The scene id for each counterfactual is taken from its own image column;
    when that yields nothing, the id derived from the original image column
    is used instead.

    Returns:
        ``(cf1_type, cf2_type)``; each entry is ``None`` when no type could
        be read. ``(None, None)`` is returned right away when ``csv_path``
        is ``None``.
    """
    if csv_path is None:
        return (None, None)

    def _cell_text(column: str) -> str:
        # Missing or falsy cells are treated as an empty file name.
        return str(row.get(column, "") or "")

    fallback_sid = scene_id_from_original(_cell_text(MMB_IMAGE_COLS[0]))
    cf1_sid = scene_id_from_image_name(_cell_text(MMB_IMAGE_COLS[1])) or fallback_sid
    cf2_sid = scene_id_from_image_name(_cell_text(MMB_IMAGE_COLS[2])) or fallback_sid
    return (
        _get_cf_type_from_scene_file(csv_path, cf1_sid, "cf1"),
        _get_cf_type_from_scene_file(csv_path, cf2_sid, "cf2"),
    )
def discover_csvs() -> list[Path]:
    """Collect every ``*.csv`` found recursively under output/, data/ and hf_dataset/.

    Base directories that do not exist are skipped. The result contains no
    duplicates and is sorted by the string form of each path.
    """
    found: set[Path] = set()
    for base in (OUTPUT_DIR, DATA_DIR, HF_DATASET_DIR):
        if base.exists():
            found.update(base.rglob("*.csv"))
    return sorted(found, key=str)
def _extract_upload_zip(zip_bytes: bytes) -> Path:
    """Extract an uploaded ZIP into a fresh temporary directory.

    Args:
        zip_bytes: Raw bytes of the ZIP archive.

    Returns:
        Path of the newly created ``dataset_upload_*`` directory holding the
        extracted files. The caller owns it and is responsible for removal.

    Raises:
        zipfile.BadZipFile: If ``zip_bytes`` is not a valid ZIP archive.
            Other errors raised while extracting propagate unchanged.
    """
    root = Path(tempfile.mkdtemp(prefix="dataset_upload_"))
    try:
        with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as archive:
            archive.extractall(root)
    except BaseException:
        # The directory was created before the archive was validated; remove
        # it so a bad upload does not leave an orphaned temp folder behind.
        shutil.rmtree(root, ignore_errors=True)
        raise
    return root
def _discover_csvs_in_dir(root: Path) -> list[Path]:
    """Return every ``*.csv`` under ``root`` (recursive), sorted by path string."""
    csv_files = list(root.rglob("*.csv"))
    csv_files.sort(key=lambda candidate: str(candidate))
    return csv_files
def _csv_from_upload(uploaded) -> tuple[Path, Path] | None:
"""
Process uploaded file (ZIP or CSV). Return (base_dir, csv_path) or None.
- ZIP: extract to temp, discover CSVs; base_dir = extract root, csv_path = chosen CSV.
- CSV: write to temp dir; base_dir = temp dir, csv_path = that CSV.
"""
fname = (uploaded.name or "").lower()
data = uploaded.read()
if fname.endswith(".zip"):
base = _extract_upload_zip(data)
csvs = _discover_csvs_in_dir(base)
if not csvs:
return None
# Prefer MMB-style CSV (name contains "question") if multiple
with_q = [p for p in csvs if "question" in p.name.lower()]
chosen = (with_q[0] if with_q else csvs[0])
return (base, chosen)
if fname.endswith(".csv"):
base = Path(tempfile.mkdtemp(prefix="dataset_upload_"))
path = base / (uploaded.name or "uploaded.csv")
path.write_bytes(data)
return (base, path)
return None
def detect_format(df: pd.DataFrame) -> str:
    """Classify a DataFrame by its column names.

    Returns:
        ``'mmb_qa'`` when all three MMB image columns and the original
        question column are present, ``'mmb_images'`` when only the image
        columns are present, ``'mib'`` when one of the three MIB column
        signatures matches, and ``'unknown'`` otherwise.
    """
    cols = set(df.columns)
    if cols.issuperset(MMB_IMAGE_COLS):
        has_question = bool(cols & {MMB_QA_COLS[0], "original_question"})
        return "mmb_qa" if has_question else "mmb_images"
    # Any one complete signature is enough to call the file MIB-style.
    mib_signatures = (
        ("k", "f", "method"),
        ("dataset", "split", "count"),
        ("method", "task", "CPR"),
    )
    if any(cols.issuperset(signature) for signature in mib_signatures):
        return "mib"
    return "unknown"
@st.cache_data
def load_csv(path: Path) -> pd.DataFrame:
    """Read the CSV at ``path`` with pandas defaults (wrapped in ``st.cache_data``)."""
    frame = pd.read_csv(path)
    return frame
# Short display labels for the three question columns.
QA_LABELS = {"original_question": "Original", "counterfactual1_question": "CF1", "counterfactual2_question": "CF2"}
# 3×3 grid: [image][question] -> CSV column name
# Rows follow the image order original, cf1, cf2; columns follow the same
# order for the question being answered.
ANSWER_GRID = [
    ["original_image_answer_to_original_question", "original_image_answer_to_cf1_question", "original_image_answer_to_cf2_question"],
    ["cf1_image_answer_to_original_question", "cf1_image_answer_to_cf1_question", "cf1_image_answer_to_cf2_question"],
    ["cf2_image_answer_to_original_question", "cf2_image_answer_to_cf1_question", "cf2_image_answer_to_cf2_question"],
]
def _render_image_cell(csv_path: Path | None, value, label: str):
    """Draw one image cell.

    When ``value`` names a file that can be resolved relative to
    ``csv_path`` the image is shown with ``label`` as its caption. Otherwise
    the label is shown as a caption, followed by the raw file name (or
    ``(no image)`` when there is none) as a code snippet.
    """
    fname = "" if value is None else str(value)
    resolved = None
    if csv_path and fname:
        resolved = resolve_image_path(csv_path, fname)
    if not resolved:
        # Fallback: no file on disk, so show what the CSV referenced.
        st.caption(label)
        st.code(fname or "(no image)", language=None)
        return
    st.image(str(resolved), use_container_width=True, caption=label)
def _answers_as_grid(row, cols: list[str]) -> str | None:
    """Format the 3×3 answer matrix of ``row`` as a markdown table.

    Table rows are the images (Original, CF1, CF2) and table columns are the
    questions. Cell text has ``|`` replaced by ``·`` and newlines replaced by
    spaces so it cannot break the table markup.

    Returns:
        The markdown table, or ``None`` as soon as one of the nine expected
        answer columns is missing from ``cols`` or from ``row``.
    """
    def _clean(value) -> str:
        return str(value).replace("|", "·").replace("\n", " ").strip()

    table_lines = [
        "| | **Original Q** | **CF1 Q** | **CF2 Q** |",
        "| --- | --- | --- | --- |",
    ]
    for image_label, answer_keys in zip(("Original", "CF1", "CF2"), ANSWER_GRID):
        cells = []
        for key in answer_keys:
            if key not in cols or key not in row.index:
                return None
            cells.append(_clean(row[key]))
        table_lines.append(f"| **{image_label}** | " + " | ".join(cells) + " |")
    return "\n".join(table_lines)
def render_mmb_qa(df: pd.DataFrame, csv_name: str, csv_path: Path | None = None):
    """Render the tabbed view for an MMB-style CSV that includes questions.

    Five tabs are drawn: per-row image sets (with questions and, optionally,
    the answer matrix), an overview count, difficulty counts, counterfactual
    types read from scene JSON files, and the answer matrix on its own.

    Args:
        df: The loaded CSV.
        csv_name: Name shown in the sub-header.
        csv_path: Location of the CSV, used to resolve images and scene
            files. When ``None`` no counterfactual types are looked up and
            images fall back to showing their file names.
    """
    st.subheader("MMB-style dataset: " + csv_name)
    cols = df.columns.tolist()
    tab_titles = [
        "Image sets", "Overview", "Difficulty & questions", "Counterfactual types", "Answer matrix",
    ]
    tab_sets, tab_overview, tab_difficulty, tab_cf_types, tab_answers = st.tabs(tab_titles)

    with tab_sets:
        st.markdown("**Scene sets**: original + 2 counterfactuals per row. Images and info below.")
        include_answers = st.checkbox("Include answer matrix in each row", value=True, key="sets_answers")
        if df.empty:
            st.caption("No scene sets.")
        # An empty frame simply yields no iterations below.
        for row_no, (_, row) in enumerate(df.iterrows(), start=1):
            sid = _row_scene_id(row, row_no - 1)
            cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None)
            labels = (
                "Original",
                "Counterfactual 1" + (f" ({cf1_t})" if cf1_t else ""),
                "Counterfactual 2" + (f" ({cf2_t})" if cf2_t else ""),
            )
            st.markdown(f"**{sid}** (row {row_no})")
            for slot, col_name, label in zip(st.columns(3), MMB_IMAGE_COLS, labels):
                with slot:
                    if col_name in row.index:
                        _render_image_cell(csv_path, row[col_name], label)
            # One bullet per question that has both text and difficulty.
            for q_col, d_col in zip(MMB_QA_COLS, MMB_DIFF_COLS):
                if q_col in row.index and d_col in row.index:
                    tag = QA_LABELS.get(q_col, q_col)
                    st.markdown("- " + f"**{tag}** ({row[d_col]}): {row[q_col]}")
            if include_answers:
                table = _answers_as_grid(row, cols)
                if table:
                    st.markdown("**Answers** (image → question)")
                    st.markdown(table)
                else:
                    # Incomplete grid: dump whatever answer columns exist.
                    loose_answers = [c for c in cols if "answer" in c.lower()]
                    if loose_answers:
                        st.json({c: str(row[c]) for c in loose_answers})
            st.divider()

    with tab_overview:
        st.markdown("**Scene sets** (original + 2 counterfactuals per row).")
        n = len(df)
        st.metric("Scene sets", n)
        if MMB_IMAGE_COLS[0] in cols:
            st.caption("Columns: " + ", ".join(cols[:6]) + (" …" if len(cols) > 6 else ""))
        fig = px.bar(
            x=["Scene sets"],
            y=[n],
            title="Number of scene sets",
            labels={"y": "count", "x": ""},
        )
        fig.update_layout(template="plotly_white", showlegend=False)
        st.plotly_chart(fig, use_container_width=True)

    with tab_difficulty:
        diff_cols = [c for c in MMB_DIFF_COLS if c in cols]
        if not diff_cols:
            st.info("No difficulty columns in this CSV.")
        else:
            st.markdown("Question difficulty counts (original, CF1, CF2).")
            label_map = {
                "original_question_difficulty": "Original",
                "counterfactual1_question_difficulty": "CF1",
                "counterfactual2_question_difficulty": "CF2",
            }
            # Long-form table: one record per (question slot, difficulty level).
            records = [
                {"question": label_map.get(c, c), "difficulty": str(level), "count": int(count)}
                for c in diff_cols
                for level, count in df[c].value_counts().items()
            ]
            diff_df = pd.DataFrame(records)
            fig = px.bar(
                diff_df,
                x="question",
                y="count",
                color="difficulty",
                barmode="group",
                title="Difficulty by question type",
                color_discrete_map={"easy": "#2ecc71", "medium": "#f39c12", "hard": "#e74c3c"},
            )
            fig.update_layout(template="plotly_white", xaxis_tickangle=-20)
            st.plotly_chart(fig, use_container_width=True)

    with tab_cf_types:
        cf_rows: list[dict] = []
        for i, (_, row) in enumerate(df.iterrows()):
            sid = _row_scene_id(row, i)
            cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None)
            cf_rows.append({"scene": sid, "CF1 type": cf1_t or "—", "CF2 type": cf2_t or "—"})
        cf_df = pd.DataFrame(cf_rows)
        if cf_df.empty:
            st.caption("No scene sets.")
        else:
            st.markdown("**Counterfactual types** per scene (from `scenes/*_cf1.json`, `*_cf2.json`).")
            st.dataframe(cf_df, use_container_width=True, hide_index=True)
            # Counts for bar chart: keep only slots with a known type.
            flat: list[dict] = [
                {"slot": slot, "cf_type": entry[key]}
                for entry in cf_rows
                for slot, key in (("CF1", "CF1 type"), ("CF2", "CF2 type"))
                if entry[key] != "—"
            ]
            if flat:
                flat_df = pd.DataFrame(flat)
                agg = flat_df.groupby(["cf_type", "slot"]).size().reset_index(name="count")
                fig = px.bar(
                    agg,
                    x="cf_type",
                    y="count",
                    color="slot",
                    barmode="group",
                    title="Counterfactual types by slot",
                    labels={"cf_type": "type"},
                )
                fig.update_layout(template="plotly_white", xaxis_tickangle=-45)
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("No `scenes/` folder or `cf_type` in scene JSONs. Add scenes next to the CSV.")

    with tab_answers:
        answer_cols = [c for c in cols if "answer" in c.lower()]
        if not answer_cols:
            st.info("No answer-matrix columns in this CSV.")
        else:
            st.markdown("Answer matrix: each image’s answer to each question (3×3 grid per set).")
            for row_no, (_, row) in enumerate(df.iterrows(), start=1):
                sid = _row_scene_id(row, row_no - 1)
                st.markdown(f"**{sid}** (row {row_no})")
                table = _answers_as_grid(row, cols)
                if table:
                    st.markdown(table)
                else:
                    st.json({c: str(row[c]) for c in answer_cols})
                st.divider()
def render_mmb_images(df: pd.DataFrame, csv_name: str, csv_path: Path | None = None):
    """Render an MMB-style CSV that only carries image columns.

    Shows the number of scene sets, then one three-column strip of images
    (original, CF1, CF2) per row. Counterfactual types, when they can be read
    via ``csv_path``, are appended to the CF captions.
    """
    st.subheader("MMB-style (images only): " + csv_name)
    st.metric("Scene sets", len(df))
    st.markdown("**Image sets** (images in each row below).")
    if df.empty:
        st.caption("No scene sets.")
        return
    for row_no, (_, row) in enumerate(df.iterrows(), start=1):
        sid = _row_scene_id(row, row_no - 1)
        cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None)
        labels = (
            "Original",
            "CF1" + (f" ({cf1_t})" if cf1_t else ""),
            "CF2" + (f" ({cf2_t})" if cf2_t else ""),
        )
        st.markdown(f"**{sid}** (row {row_no})")
        for slot, col_name, label in zip(st.columns(3), MMB_IMAGE_COLS, labels):
            with slot:
                if col_name in row.index:
                    _render_image_cell(csv_path, row[col_name], label)
        st.divider()
def _render_csv(path: Path, display_name: str) -> None:
    """Load the CSV at ``path`` and draw the view matching its detected format."""
    try:
        df = load_csv(path)
    except (OSError, ValueError) as exc:
        # OSError covers a missing/unreadable file; pandas' parser, empty-data
        # and decoding errors are ValueError subclasses. Report instead of
        # letting the page crash with a traceback.
        st.error(f"Could not read CSV: {exc}")
        return
    fmt = detect_format(df)
    if fmt == "mmb_qa":
        render_mmb_qa(df, display_name, csv_path=path)
    elif fmt == "mmb_images":
        render_mmb_images(df, display_name, csv_path=path)
    elif fmt == "mib":
        st.info("MIB-style CSV detected. Not supported for visualization.")
    else:
        st.warning("Unknown CSV format. Not supported for visualization.")


def main():
    """Streamlit entry point.

    Lets the user pick between an uploaded dataset (ZIP or CSV) and CSV files
    discovered under output/, data/ and hf_dataset/, then renders the chosen
    CSV. Upload state is kept in ``st.session_state`` under ``upload_base``,
    ``upload_csv_path`` and ``upload_name``.
    """
    st.title("Dataset Visualizer")
    st.caption("**Upload** your dataset (ZIP or CSV) or use **CSV files** from output/data.")
    with st.sidebar:
        source = st.radio("Data source", ["Upload dataset", "CSV files"], horizontal=True)

    if source == "Upload dataset":
        with st.sidebar:
            st.header("Upload dataset")
            uploaded = st.file_uploader(
                "ZIP (CSV + images/ + optional scenes/) or CSV",
                type=["zip", "csv"],
                key="upload_dataset",
            )
            use_btn = st.button("Use this file", key="upload_use")
            if st.session_state.get("upload_csv_path") is not None:
                st.caption("Uploaded: " + (st.session_state.get("upload_name") or "dataset"))
                clear_btn = st.button("Clear uploaded dataset", key="upload_clear")
            else:
                clear_btn = False

        if clear_btn:
            base = st.session_state.get("upload_base")
            if base is not None:
                # ignore_errors already makes removal best-effort.
                shutil.rmtree(base, ignore_errors=True)
            for k in ("upload_base", "upload_csv_path", "upload_name"):
                st.session_state.pop(k, None)
            st.rerun()

        if use_btn and uploaded:
            with st.spinner("Processing upload…"):
                out = _csv_from_upload(uploaded)
            if out is None:
                st.error("No CSV found in ZIP, or invalid upload.")
            else:
                # Drop the previous upload only once the new one is known to
                # be usable; otherwise a failed upload would leave the session
                # pointing at files that were already deleted.
                prev_base = st.session_state.get("upload_base")
                if prev_base is not None:
                    shutil.rmtree(prev_base, ignore_errors=True)
                base, csv_path = out
                st.session_state["upload_base"] = base
                st.session_state["upload_csv_path"] = csv_path
                st.session_state["upload_name"] = uploaded.name
                st.rerun()

        csv_path = st.session_state.get("upload_csv_path")
        if csv_path is None:
            if uploaded and not use_btn:
                st.info("Click **Use this file** to visualize the uploaded dataset.")
            else:
                st.info("Upload a **ZIP** (CSV + **images/** folder, optional **scenes/**) or **CSV** to visualize.")
            with st.expander("Expected ZIP structure"):
                st.code("""mydata.zip
image_mapping_with_questions.csv (or your CSV)
images/
scene_0001_original.png
scene_0001_cf1.png
...
scenes/ (optional, for counterfactual types)
scene_0001_cf1.json
...""", language="text")
            return

        path = Path(csv_path) if not isinstance(csv_path, Path) else csv_path
        name = st.session_state.get("upload_name") or path.name
        _render_csv(path, name)
        return

    # CSV source
    csv_options = discover_csvs()
    if not csv_options:
        st.info(
            "No CSVs found in **output/** or **data/** or **hf_dataset/**. "
            "Upload a dataset above (ZIP or CSV) or add CSV files to this repo and redeploy."
        )
        return

    with st.sidebar:
        st.header("CSV file")
        # Map each display label back to its path so the selection can be
        # resolved without recomputing relative paths (which may raise).
        path_by_label: dict[str, Path] = {}
        for p in csv_options:
            try:
                label = str(p.relative_to(WORKSPACE_ROOT))
            except ValueError:
                label = str(p)
            path_by_label.setdefault(label, p)
        chosen = st.selectbox("Choose CSV", options=list(path_by_label), format_func=lambda x: x)

    path = path_by_label[chosen]
    _render_csv(path, path.name)


if __name__ == "__main__":
    main()