# Exported from a web code viewer (commit f21dc8c); stray page chrome removed.
"""Streamlit app for browsing MathVision-like records."""
from __future__ import annotations
import hashlib
import io
import shutil
import tempfile
from importlib import import_module
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
from zipfile import BadZipFile, ZipFile
from mathvision_explorer.dataset import (
MathVisionRecord,
filter_records,
load_jsonl_records,
load_jsonl_records_from_text,
record_from_mapping,
summarize_records,
)
from mathvision_explorer.embeddings import (
ColorStatsEmbedder,
IJepaImageEmbedder,
ImageEmbedder,
render_patch_interest_overlay,
)
from mathvision_explorer.explorer import build_image_index
from mathvision_explorer.similarity import (
SimilarityWeights,
combined_score,
embedder_description,
interpret_match,
)
# Fixed pixel widths for images in the three UI contexts (query panel,
# neighbor cards, record browser) so the layout stays stable across reruns.
QUERY_IMAGE_WIDTH = 260
NEIGHBOR_IMAGE_WIDTH = 130
RECORD_IMAGE_WIDTH = 320
# Sidebar links for dataset discovery on the Hugging Face Hub.
MATHVISION_DATASET_URL = "https://huggingface.co/datasets/MathLLMs/MathVision"
HF_DATASETS_URL = "https://huggingface.co/datasets"
def main(jsonl_path: Path = Path("data/demo/demo.jsonl")) -> None:
    """Run the Streamlit explorer app.

    Args:
        jsonl_path: Dataset shown by default before the user selects another
            source (HF URL or upload) in the sidebar.
    """
    st = _load_streamlit()  # lazy import: module stays importable without Streamlit
    st.set_page_config(page_title="MathVision Explorer", layout="wide")
    _stabilize_layout(st)
    st.title("MathVision Explorer")
    records = _load_active_records(st, jsonl_path)
    # Distinct filter values from the initially loaded records; recomputed
    # again below because the sidebar may swap in a different dataset.
    subjects = sorted({record.subject for record in records if record.subject is not None})
    levels = sorted({record.level for record in records if record.level is not None})
    with st.sidebar:
        st.header("Dataset")
        st.markdown(
            f"[Example: MathLLMs/MathVision]({MATHVISION_DATASET_URL}) | "
            f"[Browse HF datasets]({HF_DATASETS_URL})"
        )
        dataset_source = st.radio(
            "Dataset source",
            ["Demo", "Hugging Face URL", "Upload file"],
            horizontal=False,
            help=(
                "Choose whether to use the bundled demo, paste a Hub dataset link, "
                "or upload files."
            ),
        )
        if dataset_source == "Hugging Face URL":
            hf_dataset_ref = st.text_input(
                "HF dataset URL or ID",
                value="MathLLMs/MathVision",
                placeholder="https://huggingface.co/datasets/MathLLMs/MathVision",
                help="Paste a Hugging Face dataset URL or repo id.",
            )
            hf_limit = st.number_input(
                "HF max records",
                min_value=1,
                max_value=500,
                value=50,
                step=10,
                help="Cap rows loaded from Hugging Face so exploration stays responsive.",
            )
            if st.button(
                "Load HF dataset",
                help="Download the first available split and convert compatible rows into records.",
            ):
                try:
                    records = _load_hf_dataset_records(
                        st,
                        hf_dataset_ref,
                        limit=int(hf_limit),
                    )
                except (RuntimeError, ValueError, OSError) as error:
                    st.error(str(error))
                    st.stop()
                    # NOTE(review): st.stop() interrupts the rerun, so this raise
                    # should be unreachable; it documents the dead end for
                    # static control-flow analysis.
                    raise RuntimeError("Streamlit stopped after HF dataset load error.") from error
                # Persist across reruns so the loaded set survives widget clicks.
                st.session_state["hf_dataset_records"] = records
            elif "hf_dataset_records" in st.session_state:
                # Button not pressed this rerun: reuse the previously loaded set.
                records = st.session_state["hf_dataset_records"]
        elif dataset_source == "Upload file":
            uploaded_dataset = st.file_uploader(
                "Upload dataset",
                type=["jsonl", "zip"],
                help=(
                    "Use a JSONL file for text-only records, or a ZIP containing one JSONL "
                    "file plus referenced images."
                ),
            )
            if uploaded_dataset is not None:
                records = _load_uploaded_records(st, uploaded_dataset)
        # Recompute filter values: `records` may have been replaced above.
        subjects = sorted({record.subject for record in records if record.subject is not None})
        levels = sorted({record.level for record in records if record.level is not None})
        summary = summarize_records(records)
        st.caption(f"{summary['records']} records | {summary['images']} images")
        st.header("Filters")
        subject = st.selectbox(
            "Subject",
            ["all", *subjects],
            help="Limit the visible record list to one MathVision subject.",
        )
        level_label = st.selectbox(
            "Level",
            ["all", *(str(level) for level in levels)],
            help="Limit the visible record list to one difficulty level.",
        )
        show_solutions = st.toggle(
            "Show solutions",
            value=True,
            help="Show or hide worked solutions in the record browser below the neighbor panel.",
        )
        st.header("Latent Space")
        embedder_label = st.selectbox(
            "Embedder",
            ["color (fast demo)", "ijepa (semantic, requires make sync-ijepa)"],
            help=(
                "Choose the image representation used for nearest-neighbor search. "
                "I-JEPA is slower but more semantic."
            ),
        )
        show_patch_maps = st.toggle(
            "Patch maps",
            value=True,
            help=(
                "When I-JEPA is selected, overlay patch-activation heatmaps on the query "
                "and neighbor images."
            ),
        )
        st.caption("Similarity swatch: red = lower cosine similarity, green = higher.")
        st.subheader("Ranking")
        # Sliders feed directly into the combined-score reranking weights.
        weights = SimilarityWeights(
            image=st.slider(
                "Image weight",
                min_value=0.0,
                max_value=2.0,
                value=1.0,
                step=0.05,
                help="How much raw image similarity contributes to the final ranking.",
            ),
            subject_bonus=st.slider(
                "Subject bonus",
                min_value=0.0,
                max_value=1.0,
                value=0.15,
                step=0.01,
                help="Extra score for neighbors from the same subject as the query.",
            ),
            problem_type_bonus=st.slider(
                "Problem-type bonus",
                min_value=0.0,
                max_value=1.0,
                value=0.10,
                step=0.01,
                help="Extra score for matching multiple-choice or open-ended format.",
            ),
            level_penalty=st.slider(
                "Level penalty",
                min_value=0.0,
                max_value=1.0,
                value=0.05,
                step=0.01,
                help="Penalty per difficulty-level step between query and neighbor.",
            ),
        )
        query_id = st.selectbox(
            "Query record",
            [record.problem_id for record in records],
            help="The record used as the visual search query.",
        )
        neighbor_count = st.slider(
            "Neighbors",
            min_value=1,
            max_value=8,
            value=3,
            help="Number of nearest records to show in the right-hand panel.",
        )
    # Main page: apply sidebar filters, render neighbor panel, then records.
    selected_subject = None if subject == "all" else subject
    selected_level = None if level_label == "all" else int(level_label)
    filtered = filter_records(records, subject=selected_subject, level=selected_level)
    _render_similarity_panel(
        st,
        records,
        query_id=query_id,
        embedder_name=_embedder_name_from_label(embedder_label),
        neighbor_count=neighbor_count,
        show_patch_maps=show_patch_maps,
        weights=weights,
    )
    st.caption(f"{len(filtered)} of {len(records)} records")
    for record in filtered:
        _render_record(st, record, show_solution=show_solutions)
def _render_similarity_panel(
    st: Any,
    records: list[MathVisionRecord],
    *,
    query_id: str,
    embedder_name: str,
    neighbor_count: int,
    show_patch_maps: bool,
    weights: SimilarityWeights,
) -> None:
    """Render the query image and its nearest-neighbor matches side by side.

    Args:
        st: The imported Streamlit module.
        records: All currently loaded records (unfiltered).
        query_id: ``problem_id`` of the record used as the search query.
        embedder_name: Internal embedder key ("color" or "ijepa").
        neighbor_count: Number of neighbor cards to show.
        show_patch_maps: Overlay I-JEPA patch heatmaps when applicable.
        weights: Metadata weights for the combined reranking score.
    """
    st.header("Nearest Neighbors")
    st.caption(embedder_description(embedder_name))
    record_by_id = {record.problem_id: record for record in records}
    query = record_by_id[query_id]
    if query.image_path is None:
        st.warning("Selected query has no image.")
        return
    try:
        # Embedding and index construction can raise RuntimeError (e.g. the
        # I-JEPA backend being unavailable); show the error inline instead
        # of crashing the app.
        embedder = _load_embedder(embedder_name)
        query_vector = embedder.embed_image(query.image_path)
        index = build_image_index(records, embedder)
        matches = _find_similar_records_combined(
            records,
            index,
            query.problem_id,
            query_vector,
            limit=neighbor_count,
            weights=weights,
        )
    except RuntimeError as error:
        st.error(str(error))
        return
    # Left column: the query image and metadata; right column: neighbor cards.
    columns = st.columns([1, 2])
    with columns[0]:
        st.caption("Query")
        st.image(str(query.image_path), width=QUERY_IMAGE_WIDTH)
        if show_patch_maps and embedder_name == "ijepa":
            _render_patch_attention(st, embedder, query.image_path, width=QUERY_IMAGE_WIDTH)
        st.write(query.problem_id)
        st.caption(_record_metadata_line(query))
    with columns[1]:
        for record, neighbor, combined in matches:
            interpretation = interpret_match(query, record, score=neighbor.score)
            with st.container(border=True):
                # Each match card: small image (optional) beside the scores.
                match_columns = st.columns([0.35, 1])
                with match_columns[0]:
                    if record.image_path is not None:
                        st.image(str(record.image_path), width=NEIGHBOR_IMAGE_WIDTH)
                        if show_patch_maps and embedder_name == "ijepa":
                            _render_patch_attention(
                                st,
                                embedder,
                                record.image_path,
                                width=NEIGHBOR_IMAGE_WIDTH,
                                expanded=False,
                            )
                with match_columns[1]:
                    st.write(f"**{record.problem_id}**")
                    st.caption(
                        f"cosine {neighbor.score:.4f} | combined {combined:.4f} | "
                        f"{interpretation.label} "
                        f"| {_record_metadata_line(record)}"
                    )
                    _render_score_swatch(st, neighbor.score)
                    st.write(record.question)
                    st.write(interpretation.summary)
def _render_record(st: Any, record: MathVisionRecord, *, show_solution: bool) -> None:
    """Render one record as a bordered card: image on the left, text on the right.

    Args:
        st: The imported Streamlit module.
        record: Record to display.
        show_solution: Whether to show the worked solution (when present).
    """
    with st.container(border=True):
        columns = st.columns([1, 1.4])
        with columns[0]:
            if record.image_path is not None:
                st.image(str(record.image_path), width=RECORD_IMAGE_WIDTH)
        with columns[1]:
            st.subheader(record.question)
            # Compact metadata badges; when the problem type is unset it is
            # inferred from the presence of answer options.
            badges = [record.problem_id]
            if record.subject is not None:
                badges.append(record.subject)
            if record.level is not None:
                badges.append(f"level {record.level}")
            if record.problem_type is not None:
                badges.append(record.problem_type)
            elif record.options:
                badges.append("multiple_choice")
            else:
                badges.append("open_ended")
            st.caption(" | ".join(badges))
            if record.options:
                st.write("Options: " + ", ".join(record.options))
            st.write(f"Answer: **{record.answer}**")
            if show_solution and record.solution:
                st.write(record.solution)
def _load_active_records(st: Any, jsonl_path: Path) -> list[MathVisionRecord]:
    """Load the default JSONL dataset, halting the app run on failure."""
    try:
        return load_jsonl_records(jsonl_path)
    except (OSError, ValueError) as error:
        st.error(str(error))
        st.stop()
        # NOTE(review): st.stop() interrupts the script rerun, so this raise
        # should never execute; it documents the dead end for control-flow
        # analysis and type-checkers.
        raise RuntimeError("Streamlit stopped after dataset load error.") from error
def _load_uploaded_records(st: Any, uploaded_dataset: Any) -> list[MathVisionRecord]:
    """Parse an uploaded ``.jsonl`` or ``.zip`` dataset into records.

    ZIP uploads are extracted to a cached temp directory; plain JSONL uploads
    are decoded as UTF-8 and parsed directly. Any parse/IO failure shows an
    error and stops the run.
    """
    dataset_bytes = uploaded_dataset.getvalue()
    dataset_name = uploaded_dataset.name
    # Name + content-hash key lets reruns reuse an already-extracted ZIP.
    dataset_key = _uploaded_dataset_key(dataset_name, dataset_bytes)
    try:
        if dataset_name.lower().endswith(".zip"):
            return _load_uploaded_zip_records(st, dataset_key, dataset_bytes)
        return load_jsonl_records_from_text(dataset_bytes.decode("utf-8"))
    except (BadZipFile, UnicodeDecodeError, ValueError, OSError) as error:
        st.error(str(error))
        st.stop()
        # NOTE(review): unreachable after st.stop(); kept for explicit flow.
        raise RuntimeError("Streamlit stopped after upload load error.") from error
def _load_uploaded_zip_records(
    st: Any,
    dataset_key: str,
    dataset_bytes: bytes,
) -> list[MathVisionRecord]:
    """Extract a ZIP upload (cached by key) and load its first JSONL file.

    Raises:
        ValueError: If the archive contains no ``.jsonl`` file or has unsafe
            member paths (via ``_extract_zip_safely``).
    """
    upload_state = st.session_state.setdefault("uploaded_dataset", {})
    if upload_state.get("key") != dataset_key:
        # New upload: drop the previous extraction dir, extract afresh, and
        # record the new location only after extraction succeeded.
        _remove_upload_dir(upload_state.get("extract_dir"))
        extract_dir = Path(tempfile.mkdtemp(prefix="mathvision-upload-"))
        _extract_zip_safely(dataset_bytes, extract_dir)
        upload_state.clear()
        upload_state.update({"key": dataset_key, "extract_dir": str(extract_dir)})
    extract_dir = Path(upload_state["extract_dir"])
    jsonl_files = sorted(extract_dir.rglob("*.jsonl"))
    if not jsonl_files:
        msg = "Uploaded ZIP must contain a .jsonl file."
        raise ValueError(msg)
    # Deterministic choice: alphabetically first JSONL in the archive.
    return load_jsonl_records(jsonl_files[0])
def _load_hf_dataset_records(
    st: Any,
    dataset_ref: str,
    *,
    limit: int,
) -> list[MathVisionRecord]:
    """Load up to *limit* records from a Hugging Face dataset, with caching.

    Results (and the temp dir holding extracted row images) are cached in
    ``st.session_state["hf_dataset"]`` keyed by ``repo_id:limit`` so repeated
    reruns skip the download.

    Raises:
        ValueError: Bad dataset reference or no convertible rows.
        RuntimeError: Missing `datasets` package or corrupt cache entry.
    """
    repo_id = _hf_dataset_id_from_ref(dataset_ref)
    datasets = _load_datasets_library()
    hf_state = st.session_state.setdefault("hf_dataset", {})
    hf_key = f"{repo_id}:{limit}"
    if hf_state.get("key") != hf_key:
        # Cache miss: clean up the previous image dir before downloading.
        _remove_upload_dir(hf_state.get("image_dir"))
        image_dir = Path(tempfile.mkdtemp(prefix="mathvision-hf-images-"))
        status = st.empty()
        progress = st.progress(0, text="Preparing Hugging Face dataset load")
        status.info(f"Connecting to `{repo_id}`...")
        try:
            progress.progress(20, text="Finding the first usable split")
            split_name = _choose_hf_split(datasets, repo_id)
            status.info(f"Loading `{repo_id}` split `{split_name}`...")
            progress.progress(45, text="Downloading dataset rows")
            dataset = datasets.load_dataset(repo_id, split=split_name)
            progress.progress(65, text=f"Converting up to {limit} rows")
            records = _records_from_hf_dataset(dataset, image_dir=image_dir, limit=limit)
            image_count = sum(1 for record in records if record.image_path is not None)
            progress.progress(100, text="Dataset ready")
            status.success(
                f"Loaded {len(records)} records with {image_count} images from `{split_name}`."
            )
        except Exception:
            # Broad on purpose: clear the progress UI, then re-raise so the
            # caller's handler (or Streamlit) reports the failure.
            progress.empty()
            status.error("Hugging Face dataset load failed.")
            raise
        # Only commit the cache entry after a fully successful load.
        hf_state.clear()
        hf_state.update({"key": hf_key, "image_dir": str(image_dir), "records": records})
    else:
        cached_records = hf_state.get("records", [])
        if isinstance(cached_records, list):
            st.success(f"Using cached Hugging Face dataset ({len(cached_records)} records).")
    records = hf_state["records"]
    if not isinstance(records, list):
        msg = "Cached HF dataset records are unavailable."
        raise RuntimeError(msg)
    return records
def _choose_hf_split(datasets: Any, repo_id: str) -> str:
dataset_builder = datasets.load_dataset_builder(repo_id)
split_names = [str(name) for name in (getattr(dataset_builder.info, "splits", {}) or {})]
if not split_names:
return "train"
for preferred in ("test", "validation", "valid", "train"):
if preferred in split_names:
return preferred
return split_names[0]
def _records_from_hf_dataset(
dataset: Any,
*,
image_dir: Path,
limit: int,
) -> list[MathVisionRecord]:
records: list[MathVisionRecord] = []
skipped = 0
for row_index, row in enumerate(dataset):
if len(records) >= limit:
break
if not isinstance(row, dict):
skipped += 1
continue
record = _record_from_hf_row(row, row_index=row_index, image_dir=image_dir)
if record is None:
skipped += 1
continue
records.append(record)
if not records:
msg = "No compatible rows found. Expected fields like id, question, answer, and image."
if skipped:
msg += f" Skipped {skipped} rows."
raise ValueError(msg)
return records
def _record_from_hf_row(
    row: dict[str, Any],
    *,
    row_index: int,
    image_dir: Path,
) -> MathVisionRecord | None:
    """Map one HF row onto a record payload, or return None when incompatible.

    A row must provide a question and an answer (under any of the accepted
    aliases); everything else is optional.
    """
    question = _text_from_row(row, "question", "problem", "prompt")
    answer = _text_from_row(row, "answer", "label", "target")
    if question is None or answer is None:
        return None  # without both required fields the row cannot become a record
    fallback_id = f"hf-{row_index}"
    payload: dict[str, Any] = {
        "id": _text_from_row(row, "id", "problem_id", "question_id") or fallback_id,
        "question": question,
        "answer": answer,
        "subject": _text_from_row(row, "subject", "category"),
        "problem_type": _text_from_row(row, "problem_type", "type", "task"),
        "solution": _text_from_row(row, "solution", "rationale", "explanation"),
    }
    level_value = row.get("level")
    if isinstance(level_value, int):
        payload["level"] = level_value
    raw_options = row.get("options")
    if isinstance(raw_options, list):
        payload["options"] = [str(option) for option in raw_options]
    saved_image = _save_hf_row_image(row, row_index=row_index, image_dir=image_dir)
    if saved_image is not None:
        payload["image"] = str(saved_image)
    return record_from_mapping(payload)
def _save_hf_row_image(
row: dict[str, Any],
*,
row_index: int,
image_dir: Path,
) -> Path | None:
for key in ("decoded_image", "image", "img"):
image_value = row.get(key)
save_image = getattr(image_value, "save", None)
if callable(save_image):
image_path = image_dir / f"row-{row_index:05d}.png"
save_image(image_path)
return image_path
return None
def _text_from_row(row: dict[str, Any], *keys: str) -> str | None:
for key in keys:
value = row.get(key)
if value is None:
continue
if isinstance(value, str):
stripped = value.strip()
if stripped:
return stripped
elif isinstance(value, int | float):
return str(value)
return None
def _hf_dataset_id_from_ref(dataset_ref: str) -> str:
stripped = dataset_ref.strip()
if not stripped:
msg = "Enter a Hugging Face dataset URL or repo id."
raise ValueError(msg)
parsed = urlparse(stripped)
if parsed.netloc:
parts = [part for part in parsed.path.split("/") if part]
if parts[:1] == ["datasets"] and len(parts) >= 3:
return "/".join(parts[1:3])
msg = "HF dataset URL should look like https://huggingface.co/datasets/org/name."
raise ValueError(msg)
return stripped.removeprefix("datasets/")
def _load_datasets_library() -> Any:
try:
return import_module("datasets")
except ImportError as error:
msg = "Install the `datasets` package to load Hugging Face dataset URLs."
raise RuntimeError(msg) from error
def _extract_zip_safely(dataset_bytes: bytes, extract_dir: Path) -> None:
with ZipFile(io.BytesIO(dataset_bytes)) as dataset_zip:
for member in dataset_zip.infolist():
target_path = (extract_dir / member.filename).resolve()
if not target_path.is_relative_to(extract_dir.resolve()):
msg = f"Unsafe ZIP member path: {member.filename}"
raise ValueError(msg)
dataset_zip.extract(member, extract_dir)
def _uploaded_dataset_key(dataset_name: str, dataset_bytes: bytes) -> str:
digest = hashlib.sha256(dataset_bytes).hexdigest()
return f"{dataset_name}:{digest}"
def _remove_upload_dir(path: object) -> None:
if isinstance(path, str):
shutil.rmtree(path, ignore_errors=True)
def _render_patch_attention(
    st: Any,
    embedder: ImageEmbedder,
    image_path: Path,
    *,
    width: int,
    expanded: bool = True,
) -> None:
    """Render an I-JEPA patch-activation overlay for *image_path* in an expander.

    Silently does nothing when *embedder* is not an ``IJepaImageEmbedder``
    (the color embedder has no patch map). Runtime failures are shown as a
    warning inside the expander rather than crashing the app.
    """
    if not isinstance(embedder, IJepaImageEmbedder):
        return
    with st.expander("Patch map", expanded=expanded):
        try:
            interest_map = embedder.patch_interest_map(image_path)
            overlay = render_patch_interest_overlay(image_path, interest_map)
            grid_rows, grid_columns = interest_map.grid_size
        except RuntimeError as error:
            st.warning(str(error))
            return
        st.image(overlay, width=width)
        st.caption(f"{grid_rows} x {grid_columns} patch activation map")
def _stabilize_layout(st: Any) -> None:
    """Keep nested Streamlit columns from resizing while images load."""
    # CSS-only injection (no scripts); applied once per rerun.
    st.markdown(
        """
        <style>
        [data-testid="stHorizontalBlock"] { align-items: flex-start; }
        [data-testid="column"] { min-width: 0; }
        [data-testid="stImage"] img {
        display: block;
        height: auto;
        max-width: 100%;
        }
        [data-testid="stMarkdownContainer"] {
        overflow-wrap: anywhere;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
def _load_streamlit() -> Any:
try:
return import_module("streamlit")
except ImportError as error:
msg = "Streamlit is missing. Install it with `uv sync --extra app --dev`."
raise RuntimeError(msg) from error
def _load_embedder(embedder_name: str) -> ImageEmbedder:
    """Instantiate the embedder for *embedder_name*: "ijepa", else the color default."""
    return IJepaImageEmbedder() if embedder_name == "ijepa" else ColorStatsEmbedder()
def _embedder_name_from_label(label: str) -> str:
return "ijepa" if label.startswith("ijepa") else "color"
def _record_metadata_line(record: MathVisionRecord) -> str:
subject = record.subject or "unknown subject"
level = f"level {record.level}" if record.level is not None else "unknown level"
if record.problem_type:
problem_type = record.problem_type
elif record.options:
problem_type = "multiple_choice"
else:
problem_type = "open_ended"
return f"{subject} | {level} | {problem_type}"
def _render_score_swatch(st: Any, score: float) -> None:
    """Draw a thin colored bar encoding *score* (red = low, green = high)."""
    swatch_color = _score_to_hex(score)
    swatch_html = (
        f"<div style='height: 10px; border-radius: 3px; background: {swatch_color}; "
        "border: 1px solid rgba(0,0,0,0.15);'></div>"
    )
    st.markdown(swatch_html, unsafe_allow_html=True)
def _score_to_hex(score: float) -> str:
"""Map cosine similarity (-1..1) to a red->green swatch for quick scanning."""
clamped = max(-1.0, min(1.0, score))
t = (clamped + 1.0) / 2.0
red = int(round(220 * (1.0 - t) + 30 * t))
green = int(round(30 * (1.0 - t) + 180 * t))
blue = int(round(60 * (1.0 - t) + 60 * t))
return f"#{red:02x}{green:02x}{blue:02x}"
def _find_similar_records_combined(
    records: list[MathVisionRecord],
    index: Any,
    query_id: str,
    query_vector: tuple[float, ...],
    *,
    limit: int,
    weights: SimilarityWeights,
) -> list[tuple[MathVisionRecord, Any, float]]:
    """Fetch candidates by cosine similarity, then rerank with metadata weights.

    Returns at most *limit* tuples of (record, raw neighbor, combined score),
    sorted by combined score descending.
    """
    lookup = {entry.problem_id: entry for entry in records}
    query_record = lookup[query_id]
    # Over-fetch candidates so reranking has room to promote records whose
    # metadata (subject, type, level) better matches the query.
    pool = index.search(query_vector, limit=max(25, limit * 10), exclude_id=query_id)
    reranked: list[tuple[MathVisionRecord, Any, float]] = []
    for candidate in pool:
        matched = lookup.get(candidate.item_id)
        if matched is None:
            continue  # index entry with no backing record; skip it
        total = combined_score(query_record, matched, image_score=candidate.score, weights=weights)
        reranked.append((matched, candidate, total))
    return sorted(reranked, key=lambda entry: entry[2], reverse=True)[:limit]
# Allow running the explorer directly (e.g. `streamlit run` / `python app.py`)
# with the bundled demo dataset as the default source.
if __name__ == "__main__":
    main()