"""Streamlit app for browsing MathVision-like records.""" from __future__ import annotations import hashlib import io import shutil import tempfile from importlib import import_module from pathlib import Path from typing import Any from urllib.parse import urlparse from zipfile import BadZipFile, ZipFile from mathvision_explorer.dataset import ( MathVisionRecord, filter_records, load_jsonl_records, load_jsonl_records_from_text, record_from_mapping, summarize_records, ) from mathvision_explorer.embeddings import ( ColorStatsEmbedder, IJepaImageEmbedder, ImageEmbedder, render_patch_interest_overlay, ) from mathvision_explorer.explorer import build_image_index from mathvision_explorer.similarity import ( SimilarityWeights, combined_score, embedder_description, interpret_match, ) QUERY_IMAGE_WIDTH = 260 NEIGHBOR_IMAGE_WIDTH = 130 RECORD_IMAGE_WIDTH = 320 MATHVISION_DATASET_URL = "https://huggingface.co/datasets/MathLLMs/MathVision" HF_DATASETS_URL = "https://huggingface.co/datasets" def main(jsonl_path: Path = Path("data/demo/demo.jsonl")) -> None: """Run the Streamlit explorer app.""" st = _load_streamlit() st.set_page_config(page_title="MathVision Explorer", layout="wide") _stabilize_layout(st) st.title("MathVision Explorer") records = _load_active_records(st, jsonl_path) subjects = sorted({record.subject for record in records if record.subject is not None}) levels = sorted({record.level for record in records if record.level is not None}) with st.sidebar: st.header("Dataset") st.markdown( f"[Example: MathLLMs/MathVision]({MATHVISION_DATASET_URL}) | " f"[Browse HF datasets]({HF_DATASETS_URL})" ) dataset_source = st.radio( "Dataset source", ["Demo", "Hugging Face URL", "Upload file"], horizontal=False, help=( "Choose whether to use the bundled demo, paste a Hub dataset link, " "or upload files." ), ) if dataset_source == "Hugging Face URL": hf_dataset_ref = st.text_input( "HF dataset URL or ID", value="MathLLMs/MathVision", placeholder="https://huggingface.co/datasets/MathLLMs/MathVision", help="Paste a Hugging Face dataset URL or repo id.", ) hf_limit = st.number_input( "HF max records", min_value=1, max_value=500, value=50, step=10, help="Cap rows loaded from Hugging Face so exploration stays responsive.", ) if st.button( "Load HF dataset", help="Download the first available split and convert compatible rows into records.", ): try: records = _load_hf_dataset_records( st, hf_dataset_ref, limit=int(hf_limit), ) except (RuntimeError, ValueError, OSError) as error: st.error(str(error)) st.stop() raise RuntimeError("Streamlit stopped after HF dataset load error.") from error st.session_state["hf_dataset_records"] = records elif "hf_dataset_records" in st.session_state: records = st.session_state["hf_dataset_records"] elif dataset_source == "Upload file": uploaded_dataset = st.file_uploader( "Upload dataset", type=["jsonl", "zip"], help=( "Use a JSONL file for text-only records, or a ZIP containing one JSONL " "file plus referenced images." ), ) if uploaded_dataset is not None: records = _load_uploaded_records(st, uploaded_dataset) subjects = sorted({record.subject for record in records if record.subject is not None}) levels = sorted({record.level for record in records if record.level is not None}) summary = summarize_records(records) st.caption(f"{summary['records']} records | {summary['images']} images") st.header("Filters") subject = st.selectbox( "Subject", ["all", *subjects], help="Limit the visible record list to one MathVision subject.", ) level_label = st.selectbox( "Level", ["all", *(str(level) for level in levels)], help="Limit the visible record list to one difficulty level.", ) show_solutions = st.toggle( "Show solutions", value=True, help="Show or hide worked solutions in the record browser below the neighbor panel.", ) st.header("Latent Space") embedder_label = st.selectbox( "Embedder", ["color (fast demo)", "ijepa (semantic, requires make sync-ijepa)"], help=( "Choose the image representation used for nearest-neighbor search. " "I-JEPA is slower but more semantic." ), ) show_patch_maps = st.toggle( "Patch maps", value=True, help=( "When I-JEPA is selected, overlay patch-activation heatmaps on the query " "and neighbor images." ), ) st.caption("Similarity swatch: red = lower cosine similarity, green = higher.") st.subheader("Ranking") weights = SimilarityWeights( image=st.slider( "Image weight", min_value=0.0, max_value=2.0, value=1.0, step=0.05, help="How much raw image similarity contributes to the final ranking.", ), subject_bonus=st.slider( "Subject bonus", min_value=0.0, max_value=1.0, value=0.15, step=0.01, help="Extra score for neighbors from the same subject as the query.", ), problem_type_bonus=st.slider( "Problem-type bonus", min_value=0.0, max_value=1.0, value=0.10, step=0.01, help="Extra score for matching multiple-choice or open-ended format.", ), level_penalty=st.slider( "Level penalty", min_value=0.0, max_value=1.0, value=0.05, step=0.01, help="Penalty per difficulty-level step between query and neighbor.", ), ) query_id = st.selectbox( "Query record", [record.problem_id for record in records], help="The record used as the visual search query.", ) neighbor_count = st.slider( "Neighbors", min_value=1, max_value=8, value=3, help="Number of nearest records to show in the right-hand panel.", ) selected_subject = None if subject == "all" else subject selected_level = None if level_label == "all" else int(level_label) filtered = filter_records(records, subject=selected_subject, level=selected_level) _render_similarity_panel( st, records, query_id=query_id, embedder_name=_embedder_name_from_label(embedder_label), neighbor_count=neighbor_count, show_patch_maps=show_patch_maps, weights=weights, ) st.caption(f"{len(filtered)} of {len(records)} records") for record in filtered: _render_record(st, record, show_solution=show_solutions) def _render_similarity_panel( st: Any, records: list[MathVisionRecord], *, query_id: str, embedder_name: str, neighbor_count: int, show_patch_maps: bool, weights: SimilarityWeights, ) -> None: st.header("Nearest Neighbors") st.caption(embedder_description(embedder_name)) record_by_id = {record.problem_id: record for record in records} query = record_by_id[query_id] if query.image_path is None: st.warning("Selected query has no image.") return try: embedder = _load_embedder(embedder_name) query_vector = embedder.embed_image(query.image_path) index = build_image_index(records, embedder) matches = _find_similar_records_combined( records, index, query.problem_id, query_vector, limit=neighbor_count, weights=weights, ) except RuntimeError as error: st.error(str(error)) return columns = st.columns([1, 2]) with columns[0]: st.caption("Query") st.image(str(query.image_path), width=QUERY_IMAGE_WIDTH) if show_patch_maps and embedder_name == "ijepa": _render_patch_attention(st, embedder, query.image_path, width=QUERY_IMAGE_WIDTH) st.write(query.problem_id) st.caption(_record_metadata_line(query)) with columns[1]: for record, neighbor, combined in matches: interpretation = interpret_match(query, record, score=neighbor.score) with st.container(border=True): match_columns = st.columns([0.35, 1]) with match_columns[0]: if record.image_path is not None: st.image(str(record.image_path), width=NEIGHBOR_IMAGE_WIDTH) if show_patch_maps and embedder_name == "ijepa": _render_patch_attention( st, embedder, record.image_path, width=NEIGHBOR_IMAGE_WIDTH, expanded=False, ) with match_columns[1]: st.write(f"**{record.problem_id}**") st.caption( f"cosine {neighbor.score:.4f} | combined {combined:.4f} | " f"{interpretation.label} " f"| {_record_metadata_line(record)}" ) _render_score_swatch(st, neighbor.score) st.write(record.question) st.write(interpretation.summary) def _render_record(st: Any, record: MathVisionRecord, *, show_solution: bool) -> None: with st.container(border=True): columns = st.columns([1, 1.4]) with columns[0]: if record.image_path is not None: st.image(str(record.image_path), width=RECORD_IMAGE_WIDTH) with columns[1]: st.subheader(record.question) badges = [record.problem_id] if record.subject is not None: badges.append(record.subject) if record.level is not None: badges.append(f"level {record.level}") if record.problem_type is not None: badges.append(record.problem_type) elif record.options: badges.append("multiple_choice") else: badges.append("open_ended") st.caption(" | ".join(badges)) if record.options: st.write("Options: " + ", ".join(record.options)) st.write(f"Answer: **{record.answer}**") if show_solution and record.solution: st.write(record.solution) def _load_active_records(st: Any, jsonl_path: Path) -> list[MathVisionRecord]: try: return load_jsonl_records(jsonl_path) except (OSError, ValueError) as error: st.error(str(error)) st.stop() raise RuntimeError("Streamlit stopped after dataset load error.") from error def _load_uploaded_records(st: Any, uploaded_dataset: Any) -> list[MathVisionRecord]: dataset_bytes = uploaded_dataset.getvalue() dataset_name = uploaded_dataset.name dataset_key = _uploaded_dataset_key(dataset_name, dataset_bytes) try: if dataset_name.lower().endswith(".zip"): return _load_uploaded_zip_records(st, dataset_key, dataset_bytes) return load_jsonl_records_from_text(dataset_bytes.decode("utf-8")) except (BadZipFile, UnicodeDecodeError, ValueError, OSError) as error: st.error(str(error)) st.stop() raise RuntimeError("Streamlit stopped after upload load error.") from error def _load_uploaded_zip_records( st: Any, dataset_key: str, dataset_bytes: bytes, ) -> list[MathVisionRecord]: upload_state = st.session_state.setdefault("uploaded_dataset", {}) if upload_state.get("key") != dataset_key: _remove_upload_dir(upload_state.get("extract_dir")) extract_dir = Path(tempfile.mkdtemp(prefix="mathvision-upload-")) _extract_zip_safely(dataset_bytes, extract_dir) upload_state.clear() upload_state.update({"key": dataset_key, "extract_dir": str(extract_dir)}) extract_dir = Path(upload_state["extract_dir"]) jsonl_files = sorted(extract_dir.rglob("*.jsonl")) if not jsonl_files: msg = "Uploaded ZIP must contain a .jsonl file." raise ValueError(msg) return load_jsonl_records(jsonl_files[0]) def _load_hf_dataset_records( st: Any, dataset_ref: str, *, limit: int, ) -> list[MathVisionRecord]: repo_id = _hf_dataset_id_from_ref(dataset_ref) datasets = _load_datasets_library() hf_state = st.session_state.setdefault("hf_dataset", {}) hf_key = f"{repo_id}:{limit}" if hf_state.get("key") != hf_key: _remove_upload_dir(hf_state.get("image_dir")) image_dir = Path(tempfile.mkdtemp(prefix="mathvision-hf-images-")) status = st.empty() progress = st.progress(0, text="Preparing Hugging Face dataset load") status.info(f"Connecting to `{repo_id}`...") try: progress.progress(20, text="Finding the first usable split") split_name = _choose_hf_split(datasets, repo_id) status.info(f"Loading `{repo_id}` split `{split_name}`...") progress.progress(45, text="Downloading dataset rows") dataset = datasets.load_dataset(repo_id, split=split_name) progress.progress(65, text=f"Converting up to {limit} rows") records = _records_from_hf_dataset(dataset, image_dir=image_dir, limit=limit) image_count = sum(1 for record in records if record.image_path is not None) progress.progress(100, text="Dataset ready") status.success( f"Loaded {len(records)} records with {image_count} images from `{split_name}`." ) except Exception: progress.empty() status.error("Hugging Face dataset load failed.") raise hf_state.clear() hf_state.update({"key": hf_key, "image_dir": str(image_dir), "records": records}) else: cached_records = hf_state.get("records", []) if isinstance(cached_records, list): st.success(f"Using cached Hugging Face dataset ({len(cached_records)} records).") records = hf_state["records"] if not isinstance(records, list): msg = "Cached HF dataset records are unavailable." raise RuntimeError(msg) return records def _choose_hf_split(datasets: Any, repo_id: str) -> str: dataset_builder = datasets.load_dataset_builder(repo_id) split_names = [str(name) for name in (getattr(dataset_builder.info, "splits", {}) or {})] if not split_names: return "train" for preferred in ("test", "validation", "valid", "train"): if preferred in split_names: return preferred return split_names[0] def _records_from_hf_dataset( dataset: Any, *, image_dir: Path, limit: int, ) -> list[MathVisionRecord]: records: list[MathVisionRecord] = [] skipped = 0 for row_index, row in enumerate(dataset): if len(records) >= limit: break if not isinstance(row, dict): skipped += 1 continue record = _record_from_hf_row(row, row_index=row_index, image_dir=image_dir) if record is None: skipped += 1 continue records.append(record) if not records: msg = "No compatible rows found. Expected fields like id, question, answer, and image." if skipped: msg += f" Skipped {skipped} rows." raise ValueError(msg) return records def _record_from_hf_row( row: dict[str, Any], *, row_index: int, image_dir: Path, ) -> MathVisionRecord | None: question = _text_from_row(row, "question", "problem", "prompt") answer = _text_from_row(row, "answer", "label", "target") if question is None or answer is None: return None payload: dict[str, Any] = { "id": _text_from_row(row, "id", "problem_id", "question_id") or f"hf-{row_index}", "question": question, "answer": answer, "subject": _text_from_row(row, "subject", "category"), "problem_type": _text_from_row(row, "problem_type", "type", "task"), "solution": _text_from_row(row, "solution", "rationale", "explanation"), } if isinstance(row.get("level"), int): payload["level"] = row["level"] options = row.get("options") if isinstance(options, list): payload["options"] = [str(option) for option in options] image_path = _save_hf_row_image(row, row_index=row_index, image_dir=image_dir) if image_path is not None: payload["image"] = str(image_path) return record_from_mapping(payload) def _save_hf_row_image( row: dict[str, Any], *, row_index: int, image_dir: Path, ) -> Path | None: for key in ("decoded_image", "image", "img"): image_value = row.get(key) save_image = getattr(image_value, "save", None) if callable(save_image): image_path = image_dir / f"row-{row_index:05d}.png" save_image(image_path) return image_path return None def _text_from_row(row: dict[str, Any], *keys: str) -> str | None: for key in keys: value = row.get(key) if value is None: continue if isinstance(value, str): stripped = value.strip() if stripped: return stripped elif isinstance(value, int | float): return str(value) return None def _hf_dataset_id_from_ref(dataset_ref: str) -> str: stripped = dataset_ref.strip() if not stripped: msg = "Enter a Hugging Face dataset URL or repo id." raise ValueError(msg) parsed = urlparse(stripped) if parsed.netloc: parts = [part for part in parsed.path.split("/") if part] if parts[:1] == ["datasets"] and len(parts) >= 3: return "/".join(parts[1:3]) msg = "HF dataset URL should look like https://huggingface.co/datasets/org/name." raise ValueError(msg) return stripped.removeprefix("datasets/") def _load_datasets_library() -> Any: try: return import_module("datasets") except ImportError as error: msg = "Install the `datasets` package to load Hugging Face dataset URLs." raise RuntimeError(msg) from error def _extract_zip_safely(dataset_bytes: bytes, extract_dir: Path) -> None: with ZipFile(io.BytesIO(dataset_bytes)) as dataset_zip: for member in dataset_zip.infolist(): target_path = (extract_dir / member.filename).resolve() if not target_path.is_relative_to(extract_dir.resolve()): msg = f"Unsafe ZIP member path: {member.filename}" raise ValueError(msg) dataset_zip.extract(member, extract_dir) def _uploaded_dataset_key(dataset_name: str, dataset_bytes: bytes) -> str: digest = hashlib.sha256(dataset_bytes).hexdigest() return f"{dataset_name}:{digest}" def _remove_upload_dir(path: object) -> None: if isinstance(path, str): shutil.rmtree(path, ignore_errors=True) def _render_patch_attention( st: Any, embedder: ImageEmbedder, image_path: Path, *, width: int, expanded: bool = True, ) -> None: if not isinstance(embedder, IJepaImageEmbedder): return with st.expander("Patch map", expanded=expanded): try: interest_map = embedder.patch_interest_map(image_path) overlay = render_patch_interest_overlay(image_path, interest_map) grid_rows, grid_columns = interest_map.grid_size except RuntimeError as error: st.warning(str(error)) return st.image(overlay, width=width) st.caption(f"{grid_rows} x {grid_columns} patch activation map") def _stabilize_layout(st: Any) -> None: """Keep nested Streamlit columns from resizing while images load.""" st.markdown( """ """, unsafe_allow_html=True, ) def _load_streamlit() -> Any: try: return import_module("streamlit") except ImportError as error: msg = "Streamlit is missing. Install it with `uv sync --extra app --dev`." raise RuntimeError(msg) from error def _load_embedder(embedder_name: str) -> ImageEmbedder: if embedder_name == "ijepa": return IJepaImageEmbedder() return ColorStatsEmbedder() def _embedder_name_from_label(label: str) -> str: return "ijepa" if label.startswith("ijepa") else "color" def _record_metadata_line(record: MathVisionRecord) -> str: subject = record.subject or "unknown subject" level = f"level {record.level}" if record.level is not None else "unknown level" if record.problem_type: problem_type = record.problem_type elif record.options: problem_type = "multiple_choice" else: problem_type = "open_ended" return f"{subject} | {level} | {problem_type}" def _render_score_swatch(st: Any, score: float) -> None: color = _score_to_hex(score) st.markdown( f"
", unsafe_allow_html=True, ) def _score_to_hex(score: float) -> str: """Map cosine similarity (-1..1) to a red->green swatch for quick scanning.""" clamped = max(-1.0, min(1.0, score)) t = (clamped + 1.0) / 2.0 red = int(round(220 * (1.0 - t) + 30 * t)) green = int(round(30 * (1.0 - t) + 180 * t)) blue = int(round(60 * (1.0 - t) + 60 * t)) return f"#{red:02x}{green:02x}{blue:02x}" def _find_similar_records_combined( records: list[MathVisionRecord], index: Any, query_id: str, query_vector: tuple[float, ...], *, limit: int, weights: SimilarityWeights, ) -> list[tuple[MathVisionRecord, Any, float]]: """Fetch candidates by cosine similarity, then rerank with metadata weights.""" record_by_id = {record.problem_id: record for record in records} query = record_by_id[query_id] candidate_count = max(25, limit * 10) neighbors = index.search(query_vector, limit=candidate_count, exclude_id=query_id) scored: list[tuple[MathVisionRecord, Any, float]] = [] for neighbor in neighbors: record = record_by_id.get(neighbor.item_id) if record is None: continue combined = combined_score(query, record, image_score=neighbor.score, weights=weights) scored.append((record, neighbor, combined)) scored.sort(key=lambda row: row[2], reverse=True) return scored[:limit] if __name__ == "__main__": main()