| | """
|
| | Dataset Visualizer
|
| |
|
| | MMB-style CSVs from upload, output/, or data/.
|
| | Image sets table includes thumbnails; overview, difficulty & questions, answer matrix.
|
| | """
|
| | from __future__ import annotations
|
| |
|
| | import io
|
| | import json
|
| | import shutil
|
| | import tempfile
|
| | import zipfile
|
| | import streamlit as st
|
| | import pandas as pd
|
| | import plotly.express as px
|
| | from pathlib import Path
|
| |
|
st.set_page_config(page_title="Dataset Visualizer", layout="wide")

# Project-relative directories searched (recursively) for bundled CSV datasets.
PROJECT_ROOT = Path(__file__).resolve().parent
OUTPUT_DIR = PROJECT_ROOT / "output"
DATA_DIR = PROJECT_ROOT / "data"
HF_DATASET_DIR = PROJECT_ROOT / "hf_dataset"
# Base used to display CSV paths relative to the project in the sidebar.
WORKSPACE_ROOT = PROJECT_ROOT

# Expected MMB-style CSV columns: image filenames, question text, and
# per-question difficulty, each for original / counterfactual1 / counterfactual2.
MMB_IMAGE_COLS = ["original_image", "counterfactual1_image", "counterfactual2_image"]
MMB_QA_COLS = ["original_question", "counterfactual1_question", "counterfactual2_question"]
MMB_DIFF_COLS = [
    "original_question_difficulty",
    "counterfactual1_question_difficulty",
    "counterfactual2_question_difficulty",
]
|
| |
|
| |
|
def resolve_image_path(csv_path: Path, fname: str) -> Path | None:
    """Locate *fname* near *csv_path*.

    Checks, in order: <parent>/images/, <parent>/, <grandparent>/images/,
    <grandparent>/. Returns the first existing regular file, else None.
    """
    parent = csv_path.resolve().parent
    search_dirs = (
        parent / "images",
        parent,
        parent.parent / "images",
        parent.parent,
    )
    for directory in search_dirs:
        candidate = directory / fname
        # is_file() is False for missing paths, so it covers exists() too.
        if candidate.is_file():
            return candidate
    return None
|
| |
|
| |
|
def scene_id_from_original(original_image: str) -> str:
    """Derive a scene id, e.g. 'scene_0001_original.png' -> 'scene_0001'.

    Strips the first matching suffix only; returns the input unchanged when
    stripping would leave an empty string.
    """
    name = str(original_image)
    for suffix in ("_original.png", "_original", ".png"):
        if name.endswith(suffix):
            name = name[: len(name) - len(suffix)]
            break
    return name if name else original_image
|
| |
|
| |
|
def scene_id_from_image_name(fname: str) -> str:
    """Derive a scene id from any variant filename, case-insensitively.

    e.g. 'scene_0001_cf1.png' or 'scene_0001_original.png' -> 'scene_0001'.
    Handles _original, _cf1, _cf2 (with or without .png); returns the input
    unchanged when stripping would leave an empty string.
    """
    name = str(fname).strip()
    lowered = name.lower()
    known_suffixes = (
        "_original.png", "_original",
        "_cf1.png", "_cf1",
        "_cf2.png", "_cf2",
        ".png",
    )
    for suffix in known_suffixes:
        if lowered.endswith(suffix):
            name = name[: -len(suffix)]
            break
    return name.strip() or fname
|
| |
|
| |
|
def _row_scene_id(row, index: int) -> str:
    """Display scene id from the first image column, else 'row_<n>' (1-based)."""
    candidate = row.get(MMB_IMAGE_COLS[0], None)
    has_name = isinstance(candidate, str) and bool(candidate)
    return scene_id_from_original(candidate) if has_name else f"row_{index + 1}"
|
| |
|
| |
|
def resolve_scenes_dir(csv_path: Path) -> Path | None:
    """Return <csv_parent>/scenes or <csv_grandparent>/scenes if one exists."""
    parent = csv_path.resolve().parent
    options = (parent / "scenes", parent.parent / "scenes")
    return next((d for d in options if d.is_dir()), None)
|
| |
|
| |
|
| | def _scene_id_for_lookup(s: str) -> str:
|
| | """Normalize scene id for scene file lookup (lowercase, strip)."""
|
| | return str(s).strip().lower() if s else ""
|
| |
|
| |
|
@st.cache_data
def _get_cf_type_from_scene_file(csv_path: Path, scene_id: str, variant: str) -> str | None:
    """Read scenes/{scene_id}_{variant}.json and return cf_metadata.cf_type.

    Returns None when the scenes dir, the file, or the field is missing, when
    the JSON's own 'variant' disagrees with *variant*, or on any parse error
    (best-effort lookup, deliberately silent).
    """
    scenes_dir = resolve_scenes_dir(csv_path)
    if scenes_dir is None:
        return None
    sid = _scene_id_for_lookup(scene_id)
    if not sid:
        return None
    scene_file = scenes_dir / f"{sid}_{variant}.json"
    if not scene_file.exists():
        return None
    try:
        data = json.loads(scene_file.read_text(encoding="utf-8"))
        meta = data.get("cf_metadata") or {}
        declared = meta.get("variant")
        # Cross-check the variant recorded inside the JSON against the slot.
        if declared is not None and str(declared).lower() != variant.lower():
            return None
        cf_type = meta.get("cf_type")
        return None if cf_type is None else str(cf_type)
    except Exception:
        return None
|
| |
|
| |
|
def get_cf_types_for_scene(csv_path: Path | None, row) -> tuple[str | None, str | None]:
    """Return (cf1_type, cf2_type) for *row* by reading its scene JSON files.

    Scene ids are derived from the CF image columns, falling back to the
    original image's scene id when a CF column is empty.
    """
    if csv_path is None:
        return (None, None)

    def _column_value(col: str) -> str:
        return str(row.get(col, "") or "")

    sid_orig = scene_id_from_original(_column_value(MMB_IMAGE_COLS[0]))
    sid_cf1 = scene_id_from_image_name(_column_value(MMB_IMAGE_COLS[1])) or sid_orig
    sid_cf2 = scene_id_from_image_name(_column_value(MMB_IMAGE_COLS[2])) or sid_orig
    return (
        _get_cf_type_from_scene_file(csv_path, sid_cf1, "cf1"),
        _get_cf_type_from_scene_file(csv_path, sid_cf2, "cf2"),
    )
|
| |
|
| |
|
def discover_csvs() -> list[Path]:
    """All CSV paths under output/, data/, and hf_dataset/ (recursive), sorted."""
    found: set[Path] = set()
    for root in (OUTPUT_DIR, DATA_DIR, HF_DATASET_DIR):
        if root.exists():
            found.update(root.rglob("*.csv"))
    return sorted(found, key=str)
|
| |
|
| |
|
| | def _extract_upload_zip(zip_bytes: bytes) -> Path:
|
| | """Extract zip to a temp directory; return path to that directory."""
|
| | root = Path(tempfile.mkdtemp(prefix="dataset_upload_"))
|
| | with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as z:
|
| | z.extractall(root)
|
| | return root
|
| |
|
| |
|
| | def _discover_csvs_in_dir(root: Path) -> list[Path]:
|
| | """Find all CSVs under root (recursive)."""
|
| | return sorted(root.rglob("*.csv"), key=lambda p: str(p))
|
| |
|
| |
|
def _csv_from_upload(uploaded) -> tuple[Path, Path] | None:
    """
    Turn an uploaded ZIP or CSV into (base_dir, csv_path), or None on failure.

    - ZIP: extracted to a temp dir; prefers a CSV whose name contains
      'question', otherwise the first CSV found. None when the archive has
      no CSV at all.
    - CSV: written as-is into a fresh temp dir.
    - Anything else: None.
    """
    lower_name = (uploaded.name or "").lower()
    payload = uploaded.read()

    if lower_name.endswith(".zip"):
        extract_root = _extract_upload_zip(payload)
        candidates = _discover_csvs_in_dir(extract_root)
        if not candidates:
            return None
        preferred = [p for p in candidates if "question" in p.name.lower()]
        return (extract_root, preferred[0] if preferred else candidates[0])

    if lower_name.endswith(".csv"):
        tmp_dir = Path(tempfile.mkdtemp(prefix="dataset_upload_"))
        csv_file = tmp_dir / (uploaded.name or "uploaded.csv")
        csv_file.write_bytes(payload)
        return (tmp_dir, csv_file)

    return None
|
| |
|
| |
|
def detect_format(df: pd.DataFrame) -> str:
    """Classify a CSV: 'mmb_qa' | 'mmb_images' | 'mib' | 'unknown'.

    MMB files are recognized by the presence of all three image columns; the
    question column upgrades 'mmb_images' to 'mmb_qa'. Several known MIB
    column combinations map to 'mib'.
    """
    cols = set(df.columns)
    if all(c in cols for c in MMB_IMAGE_COLS):
        # Fix: the original tested `MMB_QA_COLS[0] in cols or
        # "original_question" in cols` — both operands are the same string,
        # so a single membership test suffices.
        return "mmb_qa" if MMB_QA_COLS[0] in cols else "mmb_images"
    # Any of these column signatures marks a MIB-style results CSV.
    mib_signatures = (
        {"k", "f", "method"},
        {"dataset", "split", "count"},
        {"method", "task", "CPR"},
    )
    if any(sig <= cols for sig in mib_signatures):
        return "mib"
    return "unknown"
|
| |
|
| |
|
@st.cache_data
def load_csv(path: Path) -> pd.DataFrame:
    """Read *path* as a DataFrame; memoized by Streamlit across reruns."""
    return pd.read_csv(path)
|
| |
|
| |
|
# Short display labels for the question columns.
QA_LABELS = {"original_question": "Original", "counterfactual1_question": "CF1", "counterfactual2_question": "CF2"}

# 3x3 answer-matrix column names: rows = answering image (orig/cf1/cf2),
# columns = question asked (orig/cf1/cf2).
ANSWER_GRID = [
    ["original_image_answer_to_original_question", "original_image_answer_to_cf1_question", "original_image_answer_to_cf2_question"],
    ["cf1_image_answer_to_original_question", "cf1_image_answer_to_cf1_question", "cf1_image_answer_to_cf2_question"],
    ["cf2_image_answer_to_original_question", "cf2_image_answer_to_cf1_question", "cf2_image_answer_to_cf2_question"],
]
|
| |
|
| |
|
def _render_image_cell(csv_path: Path | None, value, label: str):
    """Render the image resolved from disk, or fall back to showing its name."""
    fname = "" if value is None else str(value)
    resolved = resolve_image_path(csv_path, fname) if (csv_path and fname) else None
    if resolved is not None:
        st.image(str(resolved), use_container_width=True, caption=label)
    else:
        st.caption(label)
        st.code(fname or "(no image)", language=None)
|
| |
|
| |
|
def _answers_as_grid(row, cols: list[str]) -> str | None:
    """Format the 3×3 answer matrix as a markdown table.

    Returns None as soon as any expected answer column is missing from
    either *cols* or the row itself.
    """

    def _clean(value) -> str:
        # '|' would break the markdown table; newlines would break the row.
        return str(value).replace("|", "·").replace("\n", " ").strip()

    cells: list[list[str]] = []
    for grid_row in ANSWER_GRID:
        rendered: list[str] = []
        for key in grid_row:
            if key not in cols or key not in row.index:
                return None
            rendered.append(_clean(row[key]))
        cells.append(rendered)

    lines = [
        "| | **Original Q** | **CF1 Q** | **CF2 Q** |",
        "| --- | --- | --- | --- |",
        "| **Original** | " + " | ".join(cells[0]) + " |",
        "| **CF1** | " + " | ".join(cells[1]) + " |",
        "| **CF2** | " + " | ".join(cells[2]) + " |",
    ]
    return "\n".join(lines)
|
| |
|
| |
|
def render_mmb_qa(df: pd.DataFrame, csv_name: str, csv_path: Path | None = None):
    """Render the full MMB QA view in five tabs.

    Tabs: image sets (thumbnails + questions + optional answers), overview
    (set count), difficulty histogram, counterfactual types (from scene
    JSONs), and the raw answer matrix. *csv_path* is needed to resolve
    image files and scenes/ directories; without it only text is shown.
    """
    st.subheader("MMB-style dataset: " + csv_name)
    cols = df.columns.tolist()

    tab_sets, tab_overview, tab_difficulty, tab_cf_types, tab_answers = st.tabs([
        "Image sets", "Overview", "Difficulty & questions", "Counterfactual types", "Answer matrix",
    ])

    with tab_sets:
        st.markdown("**Scene sets**: original + 2 counterfactuals per row. Images and info below.")
        include_answers = st.checkbox("Include answer matrix in each row", value=True, key="sets_answers")
        if df.empty:
            st.caption("No scene sets.")
        else:
            for i, (_, row) in enumerate(df.iterrows()):
                sid = _row_scene_id(row, i)
                # CF types come from scenes/*_cfN.json files next to the CSV.
                cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None)
                labels = [
                    "Original",
                    "Counterfactual 1" + (f" ({cf1_t})" if cf1_t else ""),
                    "Counterfactual 2" + (f" ({cf2_t})" if cf2_t else ""),
                ]
                st.markdown(f"**{sid}** (row {i + 1})")
                c1, c2, c3 = st.columns(3)
                # One column per image: original, CF1, CF2.
                for j, (col_name, label) in enumerate([
                    (MMB_IMAGE_COLS[0], labels[0]),
                    (MMB_IMAGE_COLS[1], labels[1]),
                    (MMB_IMAGE_COLS[2], labels[2]),
                ]):
                    with (c1, c2, c3)[j]:
                        if col_name in row.index:
                            _render_image_cell(csv_path, row[col_name], label)
                # Question text with difficulty, one bullet per variant.
                qa_lines = []
                for qc, dc in zip(MMB_QA_COLS, MMB_DIFF_COLS):
                    if qc in row.index and dc in row.index:
                        lab = QA_LABELS.get(qc, qc)
                        qa_lines.append(f"**{lab}** ({row[dc]}): {row[qc]}")
                for line in qa_lines:
                    st.markdown("- " + line)
                if include_answers:
                    tbl = _answers_as_grid(row, cols)
                    if tbl:
                        st.markdown("**Answers** (image → question)")
                        st.markdown(tbl)
                    else:
                        # Grid columns incomplete: dump raw answer columns as JSON.
                        ans_cols = [c for c in cols if "answer" in c.lower()]
                        if ans_cols:
                            st.json({c: str(row[c]) for c in ans_cols})
                st.divider()

    with tab_overview:
        st.markdown("**Scene sets** (original + 2 counterfactuals per row).")
        n = len(df)
        st.metric("Scene sets", n)
        if MMB_IMAGE_COLS[0] in cols:
            st.caption("Columns: " + ", ".join(cols[:6]) + (" …" if len(cols) > 6 else ""))

        fig = px.bar(
            x=["Scene sets"],
            y=[n],
            title="Number of scene sets",
            labels={"y": "count", "x": ""},
        )
        fig.update_layout(template="plotly_white", showlegend=False)
        st.plotly_chart(fig, use_container_width=True)

    with tab_difficulty:
        diff_cols = [c for c in MMB_DIFF_COLS if c in cols]
        if not diff_cols:
            st.info("No difficulty columns in this CSV.")
        else:
            st.markdown("Question difficulty counts (original, CF1, CF2).")
            label_map = {
                "original_question_difficulty": "Original",
                "counterfactual1_question_difficulty": "CF1",
                "counterfactual2_question_difficulty": "CF2",
            }
            # Long-form rows (question, difficulty, count) for a grouped bar chart.
            rows = []
            for c in diff_cols:
                vc = df[c].value_counts()
                lab = label_map.get(c, c)
                for lev, cnt in vc.items():
                    rows.append({"question": lab, "difficulty": str(lev), "count": int(cnt)})
            diff_df = pd.DataFrame(rows)
            fig = px.bar(
                diff_df,
                x="question",
                y="count",
                color="difficulty",
                barmode="group",
                title="Difficulty by question type",
                color_discrete_map={"easy": "#2ecc71", "medium": "#f39c12", "hard": "#e74c3c"},
            )
            fig.update_layout(template="plotly_white", xaxis_tickangle=-20)
            st.plotly_chart(fig, use_container_width=True)

    with tab_cf_types:
        # "—" marks a missing/unresolvable cf_type.
        cf_rows: list[dict] = []
        for i, (_, row) in enumerate(df.iterrows()):
            sid = _row_scene_id(row, i)
            cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None)
            cf_rows.append({"scene": sid, "CF1 type": cf1_t or "—", "CF2 type": cf2_t or "—"})
        cf_df = pd.DataFrame(cf_rows)
        if cf_df.empty:
            st.caption("No scene sets.")
        else:
            st.markdown("**Counterfactual types** per scene (from `scenes/*_cf1.json`, `*_cf2.json`).")
            st.dataframe(cf_df, use_container_width=True, hide_index=True)

            # Flatten to (slot, cf_type) pairs, skipping unresolved entries.
            flat: list[dict] = []
            for r in cf_rows:
                if r["CF1 type"] != "—":
                    flat.append({"slot": "CF1", "cf_type": r["CF1 type"]})
                if r["CF2 type"] != "—":
                    flat.append({"slot": "CF2", "cf_type": r["CF2 type"]})
            if flat:
                flat_df = pd.DataFrame(flat)
                agg = flat_df.groupby(["cf_type", "slot"]).size().reset_index(name="count")
                fig = px.bar(
                    agg,
                    x="cf_type",
                    y="count",
                    color="slot",
                    barmode="group",
                    title="Counterfactual types by slot",
                    labels={"cf_type": "type"},
                )
                fig.update_layout(template="plotly_white", xaxis_tickangle=-45)
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("No `scenes/` folder or `cf_type` in scene JSONs. Add scenes next to the CSV.")

    with tab_answers:
        answer_cols = [c for c in cols if "answer" in c.lower()]
        if not answer_cols:
            st.info("No answer-matrix columns in this CSV.")
        else:
            st.markdown("Answer matrix: each image’s answer to each question (3×3 grid per set).")
            for i, (_, row) in enumerate(df.iterrows()):
                sid = _row_scene_id(row, i)
                st.markdown(f"**{sid}** (row {i + 1})")
                tbl = _answers_as_grid(row, cols)
                if tbl:
                    st.markdown(tbl)
                else:
                    st.json({c: str(row[c]) for c in answer_cols})
                st.divider()
|
| |
|
| |
|
def render_mmb_images(df: pd.DataFrame, csv_name: str, csv_path: Path | None = None):
    """Render an images-only MMB CSV: one row of three thumbnails per scene set.

    *csv_path* is needed to resolve image files and scenes/ JSONs; without
    it, filenames are shown as text and CF types are omitted.
    """
    st.subheader("MMB-style (images only): " + csv_name)
    n = len(df)
    st.metric("Scene sets", n)

    st.markdown("**Image sets** (images in each row below).")
    if df.empty:
        st.caption("No scene sets.")
    else:
        for i, (_, row) in enumerate(df.iterrows()):
            sid = _row_scene_id(row, i)
            # CF types come from scenes/*_cfN.json files next to the CSV.
            cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None)
            labels = [
                "Original",
                "CF1" + (f" ({cf1_t})" if cf1_t else ""),
                "CF2" + (f" ({cf2_t})" if cf2_t else ""),
            ]
            st.markdown(f"**{sid}** (row {i + 1})")
            c1, c2, c3 = st.columns(3)
            for j, (col_name, label) in enumerate([
                (MMB_IMAGE_COLS[0], labels[0]),
                (MMB_IMAGE_COLS[1], labels[1]),
                (MMB_IMAGE_COLS[2], labels[2]),
            ]):
                with (c1, c2, c3)[j]:
                    if col_name in row.index:
                        _render_image_cell(csv_path, row[col_name], label)
            st.divider()
|
| |
|
| |
|
def main():
    """App entry point: choose a data source, then dispatch to a renderer.

    Upload flow keeps (upload_base, upload_csv_path, upload_name) in
    st.session_state so the extracted dataset survives reruns; temp dirs are
    removed on clear/replace. The alternative flow lists CSVs discovered in
    the repo.
    """
    st.title("Dataset Visualizer")
    st.caption("**Upload** your dataset (ZIP or CSV) or use **CSV files** from output/data.")

    with st.sidebar:
        source = st.radio("Data source", ["Upload dataset", "CSV files"], horizontal=True)

    if source == "Upload dataset":
        with st.sidebar:
            st.header("Upload dataset")
            uploaded = st.file_uploader(
                "ZIP (CSV + images/ + optional scenes/) or CSV",
                type=["zip", "csv"],
                key="upload_dataset",
            )
            use_btn = st.button("Use this file", key="upload_use")
            if st.session_state.get("upload_csv_path") is not None:
                st.caption("Uploaded: " + (st.session_state.get("upload_name") or "dataset"))
                clear_btn = st.button("Clear uploaded dataset", key="upload_clear")
            else:
                clear_btn = False

        if clear_btn:
            # Drop the extracted temp dir and forget the upload, then rerun clean.
            base = st.session_state.get("upload_base")
            if base is not None:
                try:
                    shutil.rmtree(base, ignore_errors=True)
                except Exception:
                    pass
            for k in ("upload_base", "upload_csv_path", "upload_name"):
                st.session_state.pop(k, None)
            st.rerun()

        if use_btn and uploaded:
            # Replacing a previous upload: remove its temp dir first.
            prev_base = st.session_state.get("upload_base")
            if prev_base is not None:
                try:
                    shutil.rmtree(prev_base, ignore_errors=True)
                except Exception:
                    pass
            with st.spinner("Processing upload…"):
                out = _csv_from_upload(uploaded)
                if out is None:
                    st.error("No CSV found in ZIP, or invalid upload.")
                else:
                    base, csv_path = out
                    st.session_state["upload_base"] = base
                    st.session_state["upload_csv_path"] = csv_path
                    st.session_state["upload_name"] = uploaded.name
                    st.rerun()

        csv_path = st.session_state.get("upload_csv_path")
        if csv_path is None:
            if uploaded and not use_btn:
                st.info("Click **Use this file** to visualize the uploaded dataset.")
            else:
                st.info("Upload a **ZIP** (CSV + **images/** folder, optional **scenes/**) or **CSV** to visualize.")
                with st.expander("Expected ZIP structure"):
                    st.code("""mydata.zip
image_mapping_with_questions.csv (or your CSV)
images/
scene_0001_original.png
scene_0001_cf1.png
...
scenes/ (optional, for counterfactual types)
scene_0001_cf1.json
...""", language="text")
            return

        path = Path(csv_path) if not isinstance(csv_path, Path) else csv_path
        df = load_csv(path)
        fmt = detect_format(df)
        name = st.session_state.get("upload_name") or path.name

        if fmt == "mmb_qa":
            render_mmb_qa(df, name, csv_path=path)
        elif fmt == "mmb_images":
            render_mmb_images(df, name, csv_path=path)
        elif fmt == "mib":
            st.info("MIB-style CSV detected. Not supported for visualization.")
        else:
            st.warning("Unknown CSV format. Not supported for visualization.")
        return

    csv_options = discover_csvs()
    if not csv_options:
        st.info(
            "No CSVs found in **output/** or **data/** or **hf_dataset/**. "
            "Upload a dataset above (ZIP or CSV) or add CSV files to this repo and redeploy."
        )
        return

    with st.sidebar:
        st.header("CSV file")
        # Show paths relative to the project root where possible.
        rel_paths = []
        for p in csv_options:
            try:
                rel_paths.append(str(p.relative_to(WORKSPACE_ROOT)))
            except ValueError:
                rel_paths.append(str(p))
        chosen = st.selectbox("Choose CSV", options=rel_paths, format_func=lambda x: x)
    # NOTE(review): relative_to() here is not wrapped in try/except — if a
    # discovered CSV ever lies outside WORKSPACE_ROOT this raises ValueError
    # before the `or` fallback can match. discover_csvs() only searches under
    # PROJECT_ROOT today, so it cannot trigger as written — confirm if the
    # search roots ever change.
    path = next(p for p in csv_options if (str(p.relative_to(WORKSPACE_ROOT)) == chosen or str(p) == chosen))
    df = load_csv(path)
    fmt = detect_format(df)
    csv_name = path.name

    if fmt == "mmb_qa":
        render_mmb_qa(df, csv_name, csv_path=path)
    elif fmt == "mmb_images":
        render_mmb_images(df, csv_name, csv_path=path)
    elif fmt == "mib":
        st.info("MIB-style CSV detected. Not supported for visualization.")
    else:
        st.warning("Unknown CSV format. Not supported for visualization.")
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | main()
|
| |
|