""" Dataset Visualizer MMB-style CSVs from upload, output/, or data/. Image sets table includes thumbnails; overview, difficulty & questions, answer matrix. """ from __future__ import annotations import io import json import shutil import tempfile import zipfile import streamlit as st import pandas as pd import plotly.express as px from pathlib import Path st.set_page_config(page_title="Dataset Visualizer", layout="wide") PROJECT_ROOT = Path(__file__).resolve().parent OUTPUT_DIR = PROJECT_ROOT / "output" DATA_DIR = PROJECT_ROOT / "data" HF_DATASET_DIR = PROJECT_ROOT / "hf_dataset" # For Hugging Face Space / repo bundle WORKSPACE_ROOT = PROJECT_ROOT # MMB CSV markers MMB_IMAGE_COLS = ["original_image", "counterfactual1_image", "counterfactual2_image"] MMB_QA_COLS = ["original_question", "counterfactual1_question", "counterfactual2_question"] MMB_DIFF_COLS = [ "original_question_difficulty", "counterfactual1_question_difficulty", "counterfactual2_question_difficulty", ] def resolve_image_path(csv_path: Path, fname: str) -> Path | None: """Try multiple locations; return first existing path for fname.""" base = csv_path.resolve().parent candidates = [ base / "images" / fname, base / fname, base.parent / "images" / fname, base.parent / fname, ] for p in candidates: if p.exists() and p.is_file(): return p return None def scene_id_from_original(original_image: str) -> str: """e.g. 'scene_0001_original.png' -> 'scene_0001'.""" s = str(original_image) for suf in ("_original.png", "_original", ".png"): if s.endswith(suf): s = s[: -len(suf)] break return s or original_image def scene_id_from_image_name(fname: str) -> str: """e.g. 'scene_0001_cf1.png' or 'scene_0001_original.png' -> 'scene_0001'. Handles _original, _cf1, _cf2.""" s = str(fname).strip() for suf in ( "_original.png", "_original", "_cf1.png", "_cf1", "_cf2.png", "_cf2", ".png", ): if s.lower().endswith(suf.lower()): s = s[: -len(suf)] break return s.strip() or fname def _row_scene_id(row, index: int) -> str: """Scene ID for display from first image column.""" v = row.get(MMB_IMAGE_COLS[0], None) if isinstance(v, str) and v: return scene_id_from_original(v) return f"row_{index + 1}" def resolve_scenes_dir(csv_path: Path) -> Path | None: """Scenes dir: /scenes/ or /scenes/.""" base = csv_path.resolve().parent for candidate in (base / "scenes", base.parent / "scenes"): if candidate.is_dir(): return candidate return None def _scene_id_for_lookup(s: str) -> str: """Normalize scene id for scene file lookup (lowercase, strip).""" return str(s).strip().lower() if s else "" @st.cache_data def _get_cf_type_from_scene_file(csv_path: Path, scene_id: str, variant: str) -> str | None: """Load scenes/{scene_id}_{variant}.json; return cf_metadata.cf_type. Uses variant from JSON to validate.""" scenes_dir = resolve_scenes_dir(csv_path) if not scenes_dir: return None sid = _scene_id_for_lookup(scene_id) if not sid: return None path = scenes_dir / f"{sid}_{variant}.json" if not path.exists(): return None try: with open(path, encoding="utf-8") as f: data = json.load(f) meta = data.get("cf_metadata") or {} v = meta.get("variant") if v is not None and str(v).lower() != variant.lower(): return None t = meta.get("cf_type") return str(t) if t is not None else None except Exception: return None def get_cf_types_for_scene(csv_path: Path | None, row) -> tuple[str | None, str | None]: """Load scene JSONs for this row; return (cf1_type, cf2_type). Uses CF image columns to resolve scene.""" if csv_path is None: return (None, None) sid_orig = scene_id_from_original(str(row.get(MMB_IMAGE_COLS[0], "") or "")) sid_cf1 = scene_id_from_image_name(str(row.get(MMB_IMAGE_COLS[1], "") or "")) or sid_orig sid_cf2 = scene_id_from_image_name(str(row.get(MMB_IMAGE_COLS[2], "") or "")) or sid_orig cf1_type = _get_cf_type_from_scene_file(csv_path, sid_cf1, "cf1") cf2_type = _get_cf_type_from_scene_file(csv_path, sid_cf2, "cf2") return (cf1_type, cf2_type) def discover_csvs() -> list[Path]: """Find CSVs under output/, data/, and hf_dataset/ (recursive).""" out: list[Path] = [] for base in (OUTPUT_DIR, DATA_DIR, HF_DATASET_DIR): if base.exists(): out.extend(base.rglob("*.csv")) return sorted(set(out), key=lambda p: str(p)) def _extract_upload_zip(zip_bytes: bytes) -> Path: """Extract zip to a temp directory; return path to that directory.""" root = Path(tempfile.mkdtemp(prefix="dataset_upload_")) with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as z: z.extractall(root) return root def _discover_csvs_in_dir(root: Path) -> list[Path]: """Find all CSVs under root (recursive).""" return sorted(root.rglob("*.csv"), key=lambda p: str(p)) def _csv_from_upload(uploaded) -> tuple[Path, Path] | None: """ Process uploaded file (ZIP or CSV). Return (base_dir, csv_path) or None. - ZIP: extract to temp, discover CSVs; base_dir = extract root, csv_path = chosen CSV. - CSV: write to temp dir; base_dir = temp dir, csv_path = that CSV. """ fname = (uploaded.name or "").lower() data = uploaded.read() if fname.endswith(".zip"): base = _extract_upload_zip(data) csvs = _discover_csvs_in_dir(base) if not csvs: return None # Prefer MMB-style CSV (name contains "question") if multiple with_q = [p for p in csvs if "question" in p.name.lower()] chosen = (with_q[0] if with_q else csvs[0]) return (base, chosen) if fname.endswith(".csv"): base = Path(tempfile.mkdtemp(prefix="dataset_upload_")) path = base / (uploaded.name or "uploaded.csv") path.write_bytes(data) return (base, path) return None def detect_format(df: pd.DataFrame) -> str: """Return 'mmb_qa' | 'mmb_images' | 'mib' | 'unknown'.""" cols = set(df.columns) has_mmb_images = all(c in cols for c in MMB_IMAGE_COLS) if has_mmb_images: if MMB_QA_COLS[0] in cols or "original_question" in cols: return "mmb_qa" return "mmb_images" if "k" in cols and "f" in cols and "method" in cols: return "mib" if "dataset" in cols and "split" in cols and "count" in cols: return "mib" if "method" in cols and "task" in cols and "CPR" in cols: return "mib" return "unknown" @st.cache_data def load_csv(path: Path) -> pd.DataFrame: return pd.read_csv(path) QA_LABELS = {"original_question": "Original", "counterfactual1_question": "CF1", "counterfactual2_question": "CF2"} # 3×3 grid: [image][question] -> CSV column name ANSWER_GRID = [ ["original_image_answer_to_original_question", "original_image_answer_to_cf1_question", "original_image_answer_to_cf2_question"], ["cf1_image_answer_to_original_question", "cf1_image_answer_to_cf1_question", "cf1_image_answer_to_cf2_question"], ["cf2_image_answer_to_original_question", "cf2_image_answer_to_cf1_question", "cf2_image_answer_to_cf2_question"], ] def _render_image_cell(csv_path: Path | None, value, label: str): """Display image from file path (CSV/upload) or filename fallback.""" fname = str(value) if value is not None else "" fp = resolve_image_path(csv_path, fname) if csv_path and fname else None if fp: st.image(str(fp), use_container_width=True, caption=label) return st.caption(label) st.code(fname or "(no image)", language=None) def _answers_as_grid(row, cols: list[str]) -> str | None: """Format 3×3 answer matrix as markdown table. Returns None if columns missing.""" def _cell(v): return str(v).replace("|", "·").replace("\n", " ").strip() grid = [] for r in range(3): line = [] for c in range(3): key = ANSWER_GRID[r][c] if key not in cols or key not in row.index: return None line.append(_cell(row[key])) grid.append(line) header = "| | **Original Q** | **CF1 Q** | **CF2 Q** |" sep = "| --- | --- | --- | --- |" rows = [ "| **Original** | " + " | ".join(grid[0]) + " |", "| **CF1** | " + " | ".join(grid[1]) + " |", "| **CF2** | " + " | ".join(grid[2]) + " |", ] return "\n".join([header, sep] + rows) def render_mmb_qa(df: pd.DataFrame, csv_name: str, csv_path: Path | None = None): st.subheader("MMB-style dataset: " + csv_name) cols = df.columns.tolist() tab_sets, tab_overview, tab_difficulty, tab_cf_types, tab_answers = st.tabs([ "Image sets", "Overview", "Difficulty & questions", "Counterfactual types", "Answer matrix", ]) with tab_sets: st.markdown("**Scene sets**: original + 2 counterfactuals per row. Images and info below.") include_answers = st.checkbox("Include answer matrix in each row", value=True, key="sets_answers") if df.empty: st.caption("No scene sets.") else: for i, (_, row) in enumerate(df.iterrows()): sid = _row_scene_id(row, i) cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None) labels = [ "Original", "Counterfactual 1" + (f" ({cf1_t})" if cf1_t else ""), "Counterfactual 2" + (f" ({cf2_t})" if cf2_t else ""), ] st.markdown(f"**{sid}** (row {i + 1})") c1, c2, c3 = st.columns(3) for j, (col_name, label) in enumerate([ (MMB_IMAGE_COLS[0], labels[0]), (MMB_IMAGE_COLS[1], labels[1]), (MMB_IMAGE_COLS[2], labels[2]), ]): with (c1, c2, c3)[j]: if col_name in row.index: _render_image_cell(csv_path, row[col_name], label) qa_lines = [] for qc, dc in zip(MMB_QA_COLS, MMB_DIFF_COLS): if qc in row.index and dc in row.index: lab = QA_LABELS.get(qc, qc) qa_lines.append(f"**{lab}** ({row[dc]}): {row[qc]}") for line in qa_lines: st.markdown("- " + line) if include_answers: tbl = _answers_as_grid(row, cols) if tbl: st.markdown("**Answers** (image → question)") st.markdown(tbl) else: ans_cols = [c for c in cols if "answer" in c.lower()] if ans_cols: st.json({c: str(row[c]) for c in ans_cols}) st.divider() with tab_overview: st.markdown("**Scene sets** (original + 2 counterfactuals per row).") n = len(df) st.metric("Scene sets", n) if MMB_IMAGE_COLS[0] in cols: st.caption("Columns: " + ", ".join(cols[:6]) + (" …" if len(cols) > 6 else "")) fig = px.bar( x=["Scene sets"], y=[n], title="Number of scene sets", labels={"y": "count", "x": ""}, ) fig.update_layout(template="plotly_white", showlegend=False) st.plotly_chart(fig, use_container_width=True) with tab_difficulty: diff_cols = [c for c in MMB_DIFF_COLS if c in cols] if not diff_cols: st.info("No difficulty columns in this CSV.") else: st.markdown("Question difficulty counts (original, CF1, CF2).") label_map = { "original_question_difficulty": "Original", "counterfactual1_question_difficulty": "CF1", "counterfactual2_question_difficulty": "CF2", } rows = [] for c in diff_cols: vc = df[c].value_counts() lab = label_map.get(c, c) for lev, cnt in vc.items(): rows.append({"question": lab, "difficulty": str(lev), "count": int(cnt)}) diff_df = pd.DataFrame(rows) fig = px.bar( diff_df, x="question", y="count", color="difficulty", barmode="group", title="Difficulty by question type", color_discrete_map={"easy": "#2ecc71", "medium": "#f39c12", "hard": "#e74c3c"}, ) fig.update_layout(template="plotly_white", xaxis_tickangle=-20) st.plotly_chart(fig, use_container_width=True) with tab_cf_types: cf_rows: list[dict] = [] for i, (_, row) in enumerate(df.iterrows()): sid = _row_scene_id(row, i) cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None) cf_rows.append({"scene": sid, "CF1 type": cf1_t or "—", "CF2 type": cf2_t or "—"}) cf_df = pd.DataFrame(cf_rows) if cf_df.empty: st.caption("No scene sets.") else: st.markdown("**Counterfactual types** per scene (from `scenes/*_cf1.json`, `*_cf2.json`).") st.dataframe(cf_df, use_container_width=True, hide_index=True) # Counts for bar chart flat: list[dict] = [] for r in cf_rows: if r["CF1 type"] != "—": flat.append({"slot": "CF1", "cf_type": r["CF1 type"]}) if r["CF2 type"] != "—": flat.append({"slot": "CF2", "cf_type": r["CF2 type"]}) if flat: flat_df = pd.DataFrame(flat) agg = flat_df.groupby(["cf_type", "slot"]).size().reset_index(name="count") fig = px.bar( agg, x="cf_type", y="count", color="slot", barmode="group", title="Counterfactual types by slot", labels={"cf_type": "type"}, ) fig.update_layout(template="plotly_white", xaxis_tickangle=-45) st.plotly_chart(fig, use_container_width=True) else: st.info("No `scenes/` folder or `cf_type` in scene JSONs. Add scenes next to the CSV.") with tab_answers: answer_cols = [c for c in cols if "answer" in c.lower()] if not answer_cols: st.info("No answer-matrix columns in this CSV.") else: st.markdown("Answer matrix: each image’s answer to each question (3×3 grid per set).") for i, (_, row) in enumerate(df.iterrows()): sid = _row_scene_id(row, i) st.markdown(f"**{sid}** (row {i + 1})") tbl = _answers_as_grid(row, cols) if tbl: st.markdown(tbl) else: st.json({c: str(row[c]) for c in answer_cols}) st.divider() def render_mmb_images(df: pd.DataFrame, csv_name: str, csv_path: Path | None = None): st.subheader("MMB-style (images only): " + csv_name) n = len(df) st.metric("Scene sets", n) st.markdown("**Image sets** (images in each row below).") if df.empty: st.caption("No scene sets.") else: for i, (_, row) in enumerate(df.iterrows()): sid = _row_scene_id(row, i) cf1_t, cf2_t = get_cf_types_for_scene(csv_path, row) if csv_path else (None, None) labels = [ "Original", "CF1" + (f" ({cf1_t})" if cf1_t else ""), "CF2" + (f" ({cf2_t})" if cf2_t else ""), ] st.markdown(f"**{sid}** (row {i + 1})") c1, c2, c3 = st.columns(3) for j, (col_name, label) in enumerate([ (MMB_IMAGE_COLS[0], labels[0]), (MMB_IMAGE_COLS[1], labels[1]), (MMB_IMAGE_COLS[2], labels[2]), ]): with (c1, c2, c3)[j]: if col_name in row.index: _render_image_cell(csv_path, row[col_name], label) st.divider() def main(): st.title("Dataset Visualizer") st.caption("**Upload** your dataset (ZIP or CSV) or use **CSV files** from output/data.") with st.sidebar: source = st.radio("Data source", ["Upload dataset", "CSV files"], horizontal=True) if source == "Upload dataset": with st.sidebar: st.header("Upload dataset") uploaded = st.file_uploader( "ZIP (CSV + images/ + optional scenes/) or CSV", type=["zip", "csv"], key="upload_dataset", ) use_btn = st.button("Use this file", key="upload_use") if st.session_state.get("upload_csv_path") is not None: st.caption("Uploaded: " + (st.session_state.get("upload_name") or "dataset")) clear_btn = st.button("Clear uploaded dataset", key="upload_clear") else: clear_btn = False if clear_btn: base = st.session_state.get("upload_base") if base is not None: try: shutil.rmtree(base, ignore_errors=True) except Exception: pass for k in ("upload_base", "upload_csv_path", "upload_name"): st.session_state.pop(k, None) st.rerun() if use_btn and uploaded: prev_base = st.session_state.get("upload_base") if prev_base is not None: try: shutil.rmtree(prev_base, ignore_errors=True) except Exception: pass with st.spinner("Processing upload…"): out = _csv_from_upload(uploaded) if out is None: st.error("No CSV found in ZIP, or invalid upload.") else: base, csv_path = out st.session_state["upload_base"] = base st.session_state["upload_csv_path"] = csv_path st.session_state["upload_name"] = uploaded.name st.rerun() csv_path = st.session_state.get("upload_csv_path") if csv_path is None: if uploaded and not use_btn: st.info("Click **Use this file** to visualize the uploaded dataset.") else: st.info("Upload a **ZIP** (CSV + **images/** folder, optional **scenes/**) or **CSV** to visualize.") with st.expander("Expected ZIP structure"): st.code("""mydata.zip image_mapping_with_questions.csv (or your CSV) images/ scene_0001_original.png scene_0001_cf1.png ... scenes/ (optional, for counterfactual types) scene_0001_cf1.json ...""", language="text") return path = Path(csv_path) if not isinstance(csv_path, Path) else csv_path df = load_csv(path) fmt = detect_format(df) name = st.session_state.get("upload_name") or path.name if fmt == "mmb_qa": render_mmb_qa(df, name, csv_path=path) elif fmt == "mmb_images": render_mmb_images(df, name, csv_path=path) elif fmt == "mib": st.info("MIB-style CSV detected. Not supported for visualization.") else: st.warning("Unknown CSV format. Not supported for visualization.") return # CSV source csv_options = discover_csvs() if not csv_options: st.info( "No CSVs found in **output/** or **data/** or **hf_dataset/**. " "Upload a dataset above (ZIP or CSV) or add CSV files to this repo and redeploy." ) return with st.sidebar: st.header("CSV file") rel_paths = [] for p in csv_options: try: rel_paths.append(str(p.relative_to(WORKSPACE_ROOT))) except ValueError: rel_paths.append(str(p)) chosen = st.selectbox("Choose CSV", options=rel_paths, format_func=lambda x: x) path = next(p for p in csv_options if (str(p.relative_to(WORKSPACE_ROOT)) == chosen or str(p) == chosen)) df = load_csv(path) fmt = detect_format(df) csv_name = path.name if fmt == "mmb_qa": render_mmb_qa(df, csv_name, csv_path=path) elif fmt == "mmb_images": render_mmb_images(df, csv_name, csv_path=path) elif fmt == "mib": st.info("MIB-style CSV detected. Not supported for visualization.") else: st.warning("Unknown CSV format. Not supported for visualization.") if __name__ == "__main__": main()