| """Streamlit app for browsing MathVision-like records.""" |
|
|
from __future__ import annotations

import hashlib
import io
import math
import shutil
import tempfile
from importlib import import_module
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
from zipfile import BadZipFile, ZipFile

from mathvision_explorer.dataset import (
    MathVisionRecord,
    filter_records,
    load_jsonl_records,
    load_jsonl_records_from_text,
    record_from_mapping,
    summarize_records,
)
from mathvision_explorer.embeddings import (
    ColorStatsEmbedder,
    IJepaImageEmbedder,
    ImageEmbedder,
    render_patch_interest_overlay,
)
from mathvision_explorer.explorer import build_image_index
from mathvision_explorer.similarity import (
    SimilarityWeights,
    combined_score,
    embedder_description,
    interpret_match,
)
|
|
# Display widths passed to ``st.image(width=...)``.
QUERY_IMAGE_WIDTH = 260  # query image (and its patch map) in the neighbor panel
NEIGHBOR_IMAGE_WIDTH = 130  # each neighbor thumbnail (and its patch map)
RECORD_IMAGE_WIDTH = 320  # record cards in the record browser
# Links rendered at the top of the "Dataset" sidebar section.
MATHVISION_DATASET_URL = "https://huggingface.co/datasets/MathLLMs/MathVision"
HF_DATASETS_URL = "https://huggingface.co/datasets"
|
|
|
|
def main(jsonl_path: Path = Path("data/demo/demo.jsonl")) -> None:
    """Run the Streamlit explorer app.

    Lays out the sidebar (dataset source, filters, latent-space and ranking
    controls), renders the nearest-neighbor panel for the chosen query record,
    and then lists every record that passes the subject/level filters.

    Args:
        jsonl_path: JSONL file loaded on every run; its records are used
            unless the sidebar switches to a Hugging Face or uploaded dataset.
    """
    st = _load_streamlit()

    st.set_page_config(page_title="MathVision Explorer", layout="wide")
    _stabilize_layout(st)
    st.title("MathVision Explorer")

    records = _load_active_records(st, jsonl_path)

    with st.sidebar:
        st.header("Dataset")
        st.markdown(
            f"[Example: MathLLMs/MathVision]({MATHVISION_DATASET_URL}) | "
            f"[Browse HF datasets]({HF_DATASETS_URL})"
        )
        records = _select_dataset_records(st, records)
        # Filter choices are derived from whichever dataset ended up active.
        subjects = sorted({record.subject for record in records if record.subject is not None})
        levels = sorted({record.level for record in records if record.level is not None})
        summary = summarize_records(records)
        st.caption(f"{summary['records']} records | {summary['images']} images")
        st.header("Filters")
        subject = st.selectbox(
            "Subject",
            ["all", *subjects],
            help="Limit the visible record list to one MathVision subject.",
        )
        level_label = st.selectbox(
            "Level",
            ["all", *(str(level) for level in levels)],
            help="Limit the visible record list to one difficulty level.",
        )
        show_solutions = st.toggle(
            "Show solutions",
            value=True,
            help="Show or hide worked solutions in the record browser below the neighbor panel.",
        )
        st.header("Latent Space")
        embedder_label = st.selectbox(
            "Embedder",
            ["color (fast demo)", "ijepa (semantic, requires make sync-ijepa)"],
            help=(
                "Choose the image representation used for nearest-neighbor search. "
                "I-JEPA is slower but more semantic."
            ),
        )
        show_patch_maps = st.toggle(
            "Patch maps",
            value=True,
            help=(
                "When I-JEPA is selected, overlay patch-activation heatmaps on the query "
                "and neighbor images."
            ),
        )
        st.caption("Similarity swatch: red = lower cosine similarity, green = higher.")
        st.subheader("Ranking")
        weights = _ranking_weight_controls(st)
        query_id = st.selectbox(
            "Query record",
            [record.problem_id for record in records],
            help="The record used as the visual search query.",
        )
        neighbor_count = st.slider(
            "Neighbors",
            min_value=1,
            max_value=8,
            value=3,
            help="Number of nearest records to show in the right-hand panel.",
        )

    if not records or query_id is None:
        # An empty dataset leaves the query selectbox without options; the
        # neighbor panel would otherwise fail looking up a ``None`` id.
        st.info("The selected dataset has no records to explore.")
        return

    selected_subject = None if subject == "all" else subject
    selected_level = None if level_label == "all" else int(level_label)
    filtered = filter_records(records, subject=selected_subject, level=selected_level)

    _render_similarity_panel(
        st,
        records,
        query_id=query_id,
        embedder_name=_embedder_name_from_label(embedder_label),
        neighbor_count=neighbor_count,
        show_patch_maps=show_patch_maps,
        weights=weights,
    )

    st.caption(f"{len(filtered)} of {len(records)} records")
    for record in filtered:
        _render_record(st, record, show_solution=show_solutions)


def _select_dataset_records(
    st: Any,
    records: list[MathVisionRecord],
) -> list[MathVisionRecord]:
    """Render the dataset-source radio and return the records to explore.

    Returns *records* (the demo dataset) unless the chosen source has produced
    records of its own.
    """
    dataset_source = st.radio(
        "Dataset source",
        ["Demo", "Hugging Face URL", "Upload file"],
        horizontal=False,
        help=(
            "Choose whether to use the bundled demo, paste a Hub dataset link, "
            "or upload files."
        ),
    )
    if dataset_source == "Hugging Face URL":
        return _select_hf_dataset_records(st, records)
    if dataset_source == "Upload file":
        uploaded_dataset = st.file_uploader(
            "Upload dataset",
            type=["jsonl", "zip"],
            help=(
                "Use a JSONL file for text-only records, or a ZIP containing one JSONL "
                "file plus referenced images."
            ),
        )
        if uploaded_dataset is not None:
            return _load_uploaded_records(st, uploaded_dataset)
    return records


def _select_hf_dataset_records(
    st: Any,
    records: list[MathVisionRecord],
) -> list[MathVisionRecord]:
    """Render the Hugging Face controls and return loaded or previously loaded records.

    Clicking the load button fetches records and stores them in
    ``st.session_state["hf_dataset_records"]``; later reruns reuse that entry.
    Without either, *records* is returned unchanged. Load errors are shown
    with ``st.error`` and stop the script run.
    """
    hf_dataset_ref = st.text_input(
        "HF dataset URL or ID",
        value="MathLLMs/MathVision",
        placeholder="https://huggingface.co/datasets/MathLLMs/MathVision",
        help="Paste a Hugging Face dataset URL or repo id.",
    )
    hf_limit = st.number_input(
        "HF max records",
        min_value=1,
        max_value=500,
        value=50,
        step=10,
        help="Cap rows loaded from Hugging Face so exploration stays responsive.",
    )
    if st.button(
        "Load HF dataset",
        help="Download the first available split and convert compatible rows into records.",
    ):
        try:
            loaded = _load_hf_dataset_records(
                st,
                hf_dataset_ref,
                limit=int(hf_limit),
            )
        except (RuntimeError, ValueError, OSError) as error:
            st.error(str(error))
            st.stop()
            raise RuntimeError("Streamlit stopped after HF dataset load error.") from error
        st.session_state["hf_dataset_records"] = loaded
        return loaded
    if "hf_dataset_records" in st.session_state:
        return st.session_state["hf_dataset_records"]
    return records


def _ranking_weight_controls(st: Any) -> SimilarityWeights:
    """Render the four ranking sliders and return their values as weights."""
    return SimilarityWeights(
        image=st.slider(
            "Image weight",
            min_value=0.0,
            max_value=2.0,
            value=1.0,
            step=0.05,
            help="How much raw image similarity contributes to the final ranking.",
        ),
        subject_bonus=st.slider(
            "Subject bonus",
            min_value=0.0,
            max_value=1.0,
            value=0.15,
            step=0.01,
            help="Extra score for neighbors from the same subject as the query.",
        ),
        problem_type_bonus=st.slider(
            "Problem-type bonus",
            min_value=0.0,
            max_value=1.0,
            value=0.10,
            step=0.01,
            help="Extra score for matching multiple-choice or open-ended format.",
        ),
        level_penalty=st.slider(
            "Level penalty",
            min_value=0.0,
            max_value=1.0,
            value=0.05,
            step=0.01,
            help="Penalty per difficulty-level step between query and neighbor.",
        ),
    )
|
|
|
|
def _render_similarity_panel(
    st: Any,
    records: list[MathVisionRecord],
    *,
    query_id: str,
    embedder_name: str,
    neighbor_count: int,
    show_patch_maps: bool,
    weights: SimilarityWeights,
) -> None:
    """Render the query image beside its reranked nearest neighbors.

    The selected embedder embeds the query image and an index is built over
    *records*. Cosine-similarity candidates are reranked with *weights*, and
    the best *neighbor_count* are shown with their cosine and combined
    scores, a similarity swatch and a short interpretation. Patch-map overlays
    are added when *show_patch_maps* is on and the ``ijepa`` embedder is
    active.

    ``RuntimeError`` and ``OSError`` raised while embedding, indexing or
    ranking are reported with ``st.error`` and the panel stops there, so the
    rest of the page still renders.
    """
    st.header("Nearest Neighbors")
    st.caption(embedder_description(embedder_name))
    record_by_id = {record.problem_id: record for record in records}
    query = record_by_id[query_id]
    if query.image_path is None:
        st.warning("Selected query has no image.")
        return

    try:
        embedder = _load_embedder(embedder_name)
        query_vector = embedder.embed_image(query.image_path)
        index = build_image_index(records, embedder)
        matches = _find_similar_records_combined(
            records,
            index,
            query.problem_id,
            query_vector,
            limit=neighbor_count,
            weights=weights,
        )
    except (RuntimeError, OSError) as error:
        # OSError covers file-system failures (presumably e.g. an image path
        # that no longer exists on disk), which previously escaped as an
        # uncaught traceback.
        st.error(str(error))
        return

    show_overlays = show_patch_maps and embedder_name == "ijepa"
    query_column, matches_column = st.columns([1, 2])
    with query_column:
        st.caption("Query")
        st.image(str(query.image_path), width=QUERY_IMAGE_WIDTH)
        if show_overlays:
            _render_patch_attention(st, embedder, query.image_path, width=QUERY_IMAGE_WIDTH)
        st.write(query.problem_id)
        st.caption(_record_metadata_line(query))
    with matches_column:
        for record, neighbor, combined in matches:
            interpretation = interpret_match(query, record, score=neighbor.score)
            with st.container(border=True):
                image_column, text_column = st.columns([0.35, 1])
                with image_column:
                    if record.image_path is not None:
                        st.image(str(record.image_path), width=NEIGHBOR_IMAGE_WIDTH)
                        if show_overlays:
                            _render_patch_attention(
                                st,
                                embedder,
                                record.image_path,
                                width=NEIGHBOR_IMAGE_WIDTH,
                                expanded=False,
                            )
                with text_column:
                    st.write(f"**{record.problem_id}**")
                    st.caption(
                        f"cosine {neighbor.score:.4f} | combined {combined:.4f} | "
                        f"{interpretation.label} "
                        f"| {_record_metadata_line(record)}"
                    )
                    _render_score_swatch(st, neighbor.score)
                    st.write(record.question)
                    st.write(interpretation.summary)
|
|
|
|
| def _render_record(st: Any, record: MathVisionRecord, *, show_solution: bool) -> None: |
| with st.container(border=True): |
| columns = st.columns([1, 1.4]) |
| with columns[0]: |
| if record.image_path is not None: |
| st.image(str(record.image_path), width=RECORD_IMAGE_WIDTH) |
| with columns[1]: |
| st.subheader(record.question) |
| badges = [record.problem_id] |
| if record.subject is not None: |
| badges.append(record.subject) |
| if record.level is not None: |
| badges.append(f"level {record.level}") |
| if record.problem_type is not None: |
| badges.append(record.problem_type) |
| elif record.options: |
| badges.append("multiple_choice") |
| else: |
| badges.append("open_ended") |
| st.caption(" | ".join(badges)) |
| if record.options: |
| st.write("Options: " + ", ".join(record.options)) |
| st.write(f"Answer: **{record.answer}**") |
| if show_solution and record.solution: |
| st.write(record.solution) |
|
|
|
|
def _load_active_records(st: Any, jsonl_path: Path) -> list[MathVisionRecord]:
    """Load the records in *jsonl_path*, halting the Streamlit run on failure.

    ``OSError``/``ValueError`` from loading are reported with ``st.error``
    followed by ``st.stop()``. The trailing ``raise`` runs only if ``st.stop``
    returns, so the function never falls through without a value.
    """
    try:
        loaded = load_jsonl_records(jsonl_path)
    except (OSError, ValueError) as error:
        st.error(str(error))
        st.stop()
        raise RuntimeError("Streamlit stopped after dataset load error.") from error
    return loaded
|
|
|
|
def _load_uploaded_records(st: Any, uploaded_dataset: Any) -> list[MathVisionRecord]:
    """Parse an uploaded ``.jsonl`` or ``.zip`` dataset into records.

    Files named ``*.zip`` (case-insensitive) go through
    ``_load_uploaded_zip_records``, keyed by file name plus content hash.
    Anything else is decoded as UTF-8 text and parsed as JSONL; a leading
    UTF-8 byte-order mark is dropped first.

    ``BadZipFile``, ``UnicodeDecodeError``, ``ValueError`` and ``OSError`` are
    shown with ``st.error`` and stop the script run.
    """
    dataset_bytes = uploaded_dataset.getvalue()
    dataset_name = uploaded_dataset.name

    try:
        if dataset_name.lower().endswith(".zip"):
            # The content hash lets reruns with the same upload skip re-extraction;
            # it is only needed for ZIPs, so plain JSONL uploads are not hashed.
            dataset_key = _uploaded_dataset_key(dataset_name, dataset_bytes)
            return _load_uploaded_zip_records(st, dataset_key, dataset_bytes)
        # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
        # byte-order mark (written by some Windows editors). With "utf-8" the
        # BOM would stay glued to the first JSON line.
        return load_jsonl_records_from_text(dataset_bytes.decode("utf-8-sig"))
    except (BadZipFile, UnicodeDecodeError, ValueError, OSError) as error:
        st.error(str(error))
        st.stop()
        raise RuntimeError("Streamlit stopped after upload load error.") from error
|
|
|
|
def _load_uploaded_zip_records(
    st: Any,
    dataset_key: str,
    dataset_bytes: bytes,
) -> list[MathVisionRecord]:
    """Extract an uploaded ZIP once per distinct upload and load its JSONL records.

    ``st.session_state["uploaded_dataset"]`` holds ``{"key", "extract_dir"}``,
    so reruns with the same *dataset_key* reuse the already-extracted files.
    A new upload is extracted into a fresh temp directory first. Only after
    that succeeds is the previous directory deleted and the state replaced,
    so a failed extraction never leaves the state pointing at a deleted
    directory. The failed attempt's own directory is removed.

    The first ``*.jsonl`` file in sorted path order is loaded. macOS
    ``__MACOSX`` folders and ``._*`` AppleDouble files are ignored.

    Raises:
        BadZipFile: If the upload is not a valid ZIP archive.
        ValueError: If a member path is unsafe or no usable ``.jsonl`` exists.
    """
    upload_state = st.session_state.setdefault("uploaded_dataset", {})
    if upload_state.get("key") != dataset_key:
        new_extract_dir = Path(tempfile.mkdtemp(prefix="mathvision-upload-"))
        try:
            _extract_zip_safely(dataset_bytes, new_extract_dir)
        except Exception:
            # Keep the previous upload's state and files intact; drop this attempt.
            _remove_upload_dir(str(new_extract_dir))
            raise
        _remove_upload_dir(upload_state.get("extract_dir"))
        upload_state.clear()
        upload_state.update({"key": dataset_key, "extract_dir": str(new_extract_dir)})

    extract_dir = Path(upload_state["extract_dir"])
    jsonl_files = sorted(
        path
        for path in extract_dir.rglob("*.jsonl")
        # Finder-made archives add ``__MACOSX/._name.jsonl`` resource-fork files.
        # They match the glob and sort before real files, but are not JSONL.
        if "__MACOSX" not in path.relative_to(extract_dir).parts
        and not path.name.startswith("._")
    )
    if not jsonl_files:
        msg = "Uploaded ZIP must contain a .jsonl file."
        raise ValueError(msg)
    return load_jsonl_records(jsonl_files[0])
|
|
|
|
def _load_hf_dataset_records(
    st: Any,
    dataset_ref: str,
    *,
    limit: int,
) -> list[MathVisionRecord]:
    """Load up to *limit* records from a Hugging Face dataset, cached per session.

    Records are cached in ``st.session_state["hf_dataset"]`` under the key
    ``"<repo_id>:<limit>"``, together with the temp directory holding the
    saved row images.

    A new key triggers a fresh load with progress feedback. On success, the
    previous image directory is deleted and the cache is replaced. On failure,
    only the new, partially written directory is removed, so records returned
    earlier keep valid image paths.

    Raises:
        ValueError: For an unusable *dataset_ref* or when no row is compatible.
        RuntimeError: If ``datasets`` is missing or the cached records are invalid.
        Exception: Errors from the ``datasets`` library are re-raised after the
            progress UI is cleared.
    """
    repo_id = _hf_dataset_id_from_ref(dataset_ref)
    datasets = _load_datasets_library()
    hf_state = st.session_state.setdefault("hf_dataset", {})
    hf_key = f"{repo_id}:{limit}"
    if hf_state.get("key") != hf_key:
        image_dir = Path(tempfile.mkdtemp(prefix="mathvision-hf-images-"))
        status = st.empty()
        progress = st.progress(0, text="Preparing Hugging Face dataset load")
        status.info(f"Connecting to `{repo_id}`...")
        try:
            progress.progress(20, text="Finding the first usable split")
            split_name = _choose_hf_split(datasets, repo_id)
            status.info(f"Loading `{repo_id}` split `{split_name}`...")
            progress.progress(45, text="Downloading dataset rows")
            dataset = datasets.load_dataset(repo_id, split=split_name)
            progress.progress(65, text=f"Converting up to {limit} rows")
            records = _records_from_hf_dataset(dataset, image_dir=image_dir, limit=limit)
            image_count = sum(1 for record in records if record.image_path is not None)
            progress.progress(100, text="Dataset ready")
            status.success(
                f"Loaded {len(records)} records with {image_count} images from `{split_name}`."
            )
        except Exception:
            # Discard only this attempt's partial images; the previous dataset's
            # directory and cache entry stay untouched and remain valid.
            _remove_upload_dir(str(image_dir))
            progress.empty()
            status.error("Hugging Face dataset load failed.")
            raise
        # Retire the previous dataset's images only after the new load succeeded.
        # Records loaded earlier (which callers may still hold) point into it.
        _remove_upload_dir(hf_state.get("image_dir"))
        hf_state.clear()
        hf_state.update({"key": hf_key, "image_dir": str(image_dir), "records": records})
    else:
        cached_records = hf_state.get("records", [])
        if isinstance(cached_records, list):
            st.success(f"Using cached Hugging Face dataset ({len(cached_records)} records).")

    records = hf_state["records"]
    if not isinstance(records, list):
        msg = "Cached HF dataset records are unavailable."
        raise RuntimeError(msg)
    return records
|
|
|
|
| def _choose_hf_split(datasets: Any, repo_id: str) -> str: |
| dataset_builder = datasets.load_dataset_builder(repo_id) |
| split_names = [str(name) for name in (getattr(dataset_builder.info, "splits", {}) or {})] |
| if not split_names: |
| return "train" |
| for preferred in ("test", "validation", "valid", "train"): |
| if preferred in split_names: |
| return preferred |
| return split_names[0] |
|
|
|
|
| def _records_from_hf_dataset( |
| dataset: Any, |
| *, |
| image_dir: Path, |
| limit: int, |
| ) -> list[MathVisionRecord]: |
| records: list[MathVisionRecord] = [] |
| skipped = 0 |
| for row_index, row in enumerate(dataset): |
| if len(records) >= limit: |
| break |
| if not isinstance(row, dict): |
| skipped += 1 |
| continue |
| record = _record_from_hf_row(row, row_index=row_index, image_dir=image_dir) |
| if record is None: |
| skipped += 1 |
| continue |
| records.append(record) |
|
|
| if not records: |
| msg = "No compatible rows found. Expected fields like id, question, answer, and image." |
| if skipped: |
| msg += f" Skipped {skipped} rows." |
| raise ValueError(msg) |
| return records |
|
|
|
|
def _record_from_hf_row(
    row: dict[str, Any],
    *,
    row_index: int,
    image_dir: Path,
) -> MathVisionRecord | None:
    """Convert one Hugging Face row into a record, or return ``None`` if unusable.

    A row needs a question (``question``/``problem``/``prompt``) and an answer
    (``answer``/``label``/``target``). Other fields are picked up from common
    column aliases. Rows without an id get ``hf-<row_index>``. ``level`` is
    accepted as an int (not a bool) or a decimal-digit string, and ``options``
    as a list. The first saveable image is written to *image_dir*.
    """
    question = _text_from_row(row, "question", "problem", "prompt")
    answer = _text_from_row(row, "answer", "label", "target")
    if question is None or answer is None:
        return None

    payload: dict[str, Any] = {
        "id": _text_from_row(row, "id", "problem_id", "question_id") or f"hf-{row_index}",
        "question": question,
        "answer": answer,
        "subject": _text_from_row(row, "subject", "category"),
        "problem_type": _text_from_row(row, "problem_type", "type", "task"),
        "solution": _text_from_row(row, "solution", "rationale", "explanation"),
    }
    level = row.get("level")
    if isinstance(level, str) and level.strip().isdecimal():
        # Some exports store the level as text ("3"); the sidebar level filter
        # compares ints, so normalize it here.
        level = int(level.strip())
    # ``bool`` subclasses ``int``; a True/False flag is not a difficulty level.
    if isinstance(level, int) and not isinstance(level, bool):
        payload["level"] = level
    options = row.get("options")
    if isinstance(options, list):
        payload["options"] = [str(option) for option in options]

    image_path = _save_hf_row_image(row, row_index=row_index, image_dir=image_dir)
    if image_path is not None:
        payload["image"] = str(image_path)
    return record_from_mapping(payload)
|
|
|
|
| def _save_hf_row_image( |
| row: dict[str, Any], |
| *, |
| row_index: int, |
| image_dir: Path, |
| ) -> Path | None: |
| for key in ("decoded_image", "image", "img"): |
| image_value = row.get(key) |
| save_image = getattr(image_value, "save", None) |
| if callable(save_image): |
| image_path = image_dir / f"row-{row_index:05d}.png" |
| save_image(image_path) |
| return image_path |
| return None |
|
|
|
|
| def _text_from_row(row: dict[str, Any], *keys: str) -> str | None: |
| for key in keys: |
| value = row.get(key) |
| if value is None: |
| continue |
| if isinstance(value, str): |
| stripped = value.strip() |
| if stripped: |
| return stripped |
| elif isinstance(value, int | float): |
| return str(value) |
| return None |
|
|
|
|
| def _hf_dataset_id_from_ref(dataset_ref: str) -> str: |
| stripped = dataset_ref.strip() |
| if not stripped: |
| msg = "Enter a Hugging Face dataset URL or repo id." |
| raise ValueError(msg) |
| parsed = urlparse(stripped) |
| if parsed.netloc: |
| parts = [part for part in parsed.path.split("/") if part] |
| if parts[:1] == ["datasets"] and len(parts) >= 3: |
| return "/".join(parts[1:3]) |
| msg = "HF dataset URL should look like https://huggingface.co/datasets/org/name." |
| raise ValueError(msg) |
| return stripped.removeprefix("datasets/") |
|
|
|
|
| def _load_datasets_library() -> Any: |
| try: |
| return import_module("datasets") |
| except ImportError as error: |
| msg = "Install the `datasets` package to load Hugging Face dataset URLs." |
| raise RuntimeError(msg) from error |
|
|
|
|
| def _extract_zip_safely(dataset_bytes: bytes, extract_dir: Path) -> None: |
| with ZipFile(io.BytesIO(dataset_bytes)) as dataset_zip: |
| for member in dataset_zip.infolist(): |
| target_path = (extract_dir / member.filename).resolve() |
| if not target_path.is_relative_to(extract_dir.resolve()): |
| msg = f"Unsafe ZIP member path: {member.filename}" |
| raise ValueError(msg) |
| dataset_zip.extract(member, extract_dir) |
|
|
|
|
| def _uploaded_dataset_key(dataset_name: str, dataset_bytes: bytes) -> str: |
| digest = hashlib.sha256(dataset_bytes).hexdigest() |
| return f"{dataset_name}:{digest}" |
|
|
|
|
| def _remove_upload_dir(path: object) -> None: |
| if isinstance(path, str): |
| shutil.rmtree(path, ignore_errors=True) |
|
|
|
|
def _render_patch_attention(
    st: Any,
    embedder: ImageEmbedder,
    image_path: Path,
    *,
    width: int,
    expanded: bool = True,
) -> None:
    """Show the I-JEPA patch-interest overlay for *image_path* inside an expander.

    Does nothing unless *embedder* is an ``IJepaImageEmbedder``. A
    ``RuntimeError`` while computing the map is shown as a warning inside the
    expander; otherwise the overlay is drawn at *width* with a caption giving
    the patch-grid size.
    """
    if isinstance(embedder, IJepaImageEmbedder):
        with st.expander("Patch map", expanded=expanded):
            try:
                interest_map = embedder.patch_interest_map(image_path)
                overlay = render_patch_interest_overlay(image_path, interest_map)
                grid_rows, grid_columns = interest_map.grid_size
            except RuntimeError as error:
                st.warning(str(error))
            else:
                st.image(overlay, width=width)
                st.caption(f"{grid_rows} x {grid_columns} patch activation map")
|
|
|
|
| def _stabilize_layout(st: Any) -> None: |
| """Keep nested Streamlit columns from resizing while images load.""" |
|
|
| st.markdown( |
| """ |
| <style> |
| [data-testid="stHorizontalBlock"] { align-items: flex-start; } |
| [data-testid="column"] { min-width: 0; } |
| [data-testid="stImage"] img { |
| display: block; |
| height: auto; |
| max-width: 100%; |
| } |
| [data-testid="stMarkdownContainer"] { |
| overflow-wrap: anywhere; |
| } |
| </style> |
| """, |
| unsafe_allow_html=True, |
| ) |
|
|
|
|
| def _load_streamlit() -> Any: |
| try: |
| return import_module("streamlit") |
| except ImportError as error: |
| msg = "Streamlit is missing. Install it with `uv sync --extra app --dev`." |
| raise RuntimeError(msg) from error |
|
|
|
|
def _load_embedder(embedder_name: str) -> ImageEmbedder:
    """Return a new embedder instance for *embedder_name*.

    ``"ijepa"`` selects ``IJepaImageEmbedder``; any other name falls back to
    ``ColorStatsEmbedder``.
    """
    embedder_class = IJepaImageEmbedder if embedder_name == "ijepa" else ColorStatsEmbedder
    return embedder_class()
|
|
|
|
| def _embedder_name_from_label(label: str) -> str: |
| return "ijepa" if label.startswith("ijepa") else "color" |
|
|
|
|
| def _record_metadata_line(record: MathVisionRecord) -> str: |
| subject = record.subject or "unknown subject" |
| level = f"level {record.level}" if record.level is not None else "unknown level" |
| if record.problem_type: |
| problem_type = record.problem_type |
| elif record.options: |
| problem_type = "multiple_choice" |
| else: |
| problem_type = "open_ended" |
| return f"{subject} | {level} | {problem_type}" |
|
|
|
|
def _render_score_swatch(st: Any, score: float) -> None:
    """Draw a thin colored bar whose color is ``_score_to_hex(score)``."""
    swatch_style = (
        "height: 10px; border-radius: 3px; "
        f"background: {_score_to_hex(score)}; "
        "border: 1px solid rgba(0,0,0,0.15);"
    )
    st.markdown(f"<div style='{swatch_style}'></div>", unsafe_allow_html=True)
|
|
|
|
| def _score_to_hex(score: float) -> str: |
| """Map cosine similarity (-1..1) to a red->green swatch for quick scanning.""" |
|
|
| clamped = max(-1.0, min(1.0, score)) |
| t = (clamped + 1.0) / 2.0 |
| red = int(round(220 * (1.0 - t) + 30 * t)) |
| green = int(round(30 * (1.0 - t) + 180 * t)) |
| blue = int(round(60 * (1.0 - t) + 60 * t)) |
| return f"#{red:02x}{green:02x}{blue:02x}" |
|
|
|
|
def _find_similar_records_combined(
    records: list[MathVisionRecord],
    index: Any,
    query_id: str,
    query_vector: tuple[float, ...],
    *,
    limit: int,
    weights: SimilarityWeights,
) -> list[tuple[MathVisionRecord, Any, float]]:
    """Fetch candidates by cosine similarity, then rerank with metadata weights.

    Requests ``max(25, 10 * limit)`` candidates from ``index.search``,
    excluding the query itself, and drops any whose id is not in *records*.
    The rest are scored with ``combined_score``. Returns up to *limit*
    ``(record, neighbor, combined)`` tuples, highest combined score first;
    ties keep the order returned by the index.
    """
    by_id = {record.problem_id: record for record in records}
    query = by_id[query_id]
    pool_size = max(25, limit * 10)
    candidates = index.search(query_vector, limit=pool_size, exclude_id=query_id)
    reranked = [
        (
            candidate,
            neighbor,
            combined_score(query, candidate, image_score=neighbor.score, weights=weights),
        )
        for neighbor in candidates
        if (candidate := by_id.get(neighbor.item_id)) is not None
    ]
    reranked.sort(key=lambda entry: entry[2], reverse=True)
    return reranked[:limit]
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the explorer with main()'s default demo JSONL path.
    main()
|
|