ddebree commited on
Commit
45ade5d
·
1 Parent(s): 44f945d
src/mathvision_explorer/dataset.py CHANGED
@@ -26,17 +26,26 @@ class MathVisionRecord:
26
  def load_jsonl_records(path: Path) -> list[MathVisionRecord]:
27
  """Load MathVision-like records from a UTF-8 JSONL file."""
28
 
 
 
 
 
 
 
 
 
 
 
29
  records: list[MathVisionRecord] = []
30
- with path.open("r", encoding="utf-8") as jsonl_file:
31
- for line_number, line in enumerate(jsonl_file, start=1):
32
- stripped = line.strip()
33
- if not stripped:
34
- continue
35
- payload = json.loads(stripped)
36
- if not isinstance(payload, dict):
37
- msg = f"Line {line_number} must contain a JSON object."
38
- raise ValueError(msg)
39
- records.append(record_from_mapping(payload, source_dir=path.parent))
40
  return records
41
 
42
 
 
26
  def load_jsonl_records(path: Path) -> list[MathVisionRecord]:
27
  """Load MathVision-like records from a UTF-8 JSONL file."""
28
 
29
+ return load_jsonl_records_from_text(path.read_text(encoding="utf-8"), source_dir=path.parent)
30
+
31
+
32
+ def load_jsonl_records_from_text(
33
+ content: str,
34
+ *,
35
+ source_dir: Path | None = None,
36
+ ) -> list[MathVisionRecord]:
37
+ """Load MathVision-like records from UTF-8 JSONL text."""
38
+
39
  records: list[MathVisionRecord] = []
40
+ for line_number, line in enumerate(content.splitlines(), start=1):
41
+ stripped = line.strip()
42
+ if not stripped:
43
+ continue
44
+ payload = json.loads(stripped)
45
+ if not isinstance(payload, dict):
46
+ msg = f"Line {line_number} must contain a JSON object."
47
+ raise ValueError(msg)
48
+ records.append(record_from_mapping(payload, source_dir=source_dir))
 
49
  return records
50
 
51
 
src/mathvision_explorer/streamlit_app.py CHANGED
@@ -2,11 +2,22 @@
2
 
3
  from __future__ import annotations
4
 
 
 
 
 
5
  from importlib import import_module
6
  from pathlib import Path
7
  from typing import Any
8
-
9
- from mathvision_explorer.dataset import MathVisionRecord, filter_records, load_jsonl_records
 
 
 
 
 
 
 
10
  from mathvision_explorer.embeddings import (
11
  ColorStatsEmbedder,
12
  IJepaImageEmbedder,
@@ -30,16 +41,31 @@ def main(jsonl_path: Path = Path("data/demo/demo.jsonl")) -> None:
30
  """Run the Streamlit explorer app."""
31
 
32
  st = _load_streamlit()
33
- records = load_jsonl_records(jsonl_path)
34
 
35
  st.set_page_config(page_title="MathVision Explorer", layout="wide")
36
  _stabilize_layout(st)
37
  st.title("MathVision Explorer")
38
 
 
39
  subjects = sorted({record.subject for record in records if record.subject is not None})
40
  levels = sorted({record.level for record in records if record.level is not None})
41
 
42
  with st.sidebar:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  st.header("Filters")
44
  subject = st.selectbox(
45
  "Subject",
@@ -238,6 +264,71 @@ def _render_record(st: Any, record: MathVisionRecord, *, show_solution: bool) ->
238
  st.write(record.solution)
239
 
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  def _render_patch_attention(
242
  st: Any,
243
  embedder: ImageEmbedder,
 
2
 
3
  from __future__ import annotations
4
 
5
+ import hashlib
6
+ import io
7
+ import shutil
8
+ import tempfile
9
  from importlib import import_module
10
  from pathlib import Path
11
  from typing import Any
12
+ from zipfile import BadZipFile, ZipFile
13
+
14
+ from mathvision_explorer.dataset import (
15
+ MathVisionRecord,
16
+ filter_records,
17
+ load_jsonl_records,
18
+ load_jsonl_records_from_text,
19
+ summarize_records,
20
+ )
21
  from mathvision_explorer.embeddings import (
22
  ColorStatsEmbedder,
23
  IJepaImageEmbedder,
 
41
  """Run the Streamlit explorer app."""
42
 
43
  st = _load_streamlit()
 
44
 
45
  st.set_page_config(page_title="MathVision Explorer", layout="wide")
46
  _stabilize_layout(st)
47
  st.title("MathVision Explorer")
48
 
49
+ records = _load_active_records(st, jsonl_path)
50
  subjects = sorted({record.subject for record in records if record.subject is not None})
51
  levels = sorted({record.level for record in records if record.level is not None})
52
 
53
  with st.sidebar:
54
+ st.header("Dataset")
55
+ uploaded_dataset = st.file_uploader(
56
+ "Upload dataset",
57
+ type=["jsonl", "zip"],
58
+ help=(
59
+ "Use a JSONL file for text-only records, or a ZIP containing one JSONL "
60
+ "file plus referenced images."
61
+ ),
62
+ )
63
+ if uploaded_dataset is not None:
64
+ records = _load_uploaded_records(st, uploaded_dataset)
65
+ subjects = sorted({record.subject for record in records if record.subject is not None})
66
+ levels = sorted({record.level for record in records if record.level is not None})
67
+ summary = summarize_records(records)
68
+ st.caption(f"{summary['records']} records | {summary['images']} images")
69
  st.header("Filters")
70
  subject = st.selectbox(
71
  "Subject",
 
264
  st.write(record.solution)
265
 
266
 
267
+ def _load_active_records(st: Any, jsonl_path: Path) -> list[MathVisionRecord]:
268
+ try:
269
+ return load_jsonl_records(jsonl_path)
270
+ except (OSError, ValueError) as error:
271
+ st.error(str(error))
272
+ st.stop()
273
+ raise RuntimeError("Streamlit stopped after dataset load error.") from error
274
+
275
+
276
+ def _load_uploaded_records(st: Any, uploaded_dataset: Any) -> list[MathVisionRecord]:
277
+ dataset_bytes = uploaded_dataset.getvalue()
278
+ dataset_name = uploaded_dataset.name
279
+ dataset_key = _uploaded_dataset_key(dataset_name, dataset_bytes)
280
+
281
+ try:
282
+ if dataset_name.lower().endswith(".zip"):
283
+ return _load_uploaded_zip_records(st, dataset_key, dataset_bytes)
284
+ return load_jsonl_records_from_text(dataset_bytes.decode("utf-8"))
285
+ except (BadZipFile, UnicodeDecodeError, ValueError, OSError) as error:
286
+ st.error(str(error))
287
+ st.stop()
288
+ raise RuntimeError("Streamlit stopped after upload load error.") from error
289
+
290
+
291
+ def _load_uploaded_zip_records(
292
+ st: Any,
293
+ dataset_key: str,
294
+ dataset_bytes: bytes,
295
+ ) -> list[MathVisionRecord]:
296
+ upload_state = st.session_state.setdefault("uploaded_dataset", {})
297
+ if upload_state.get("key") != dataset_key:
298
+ _remove_upload_dir(upload_state.get("extract_dir"))
299
+ extract_dir = Path(tempfile.mkdtemp(prefix="mathvision-upload-"))
300
+ _extract_zip_safely(dataset_bytes, extract_dir)
301
+ upload_state.clear()
302
+ upload_state.update({"key": dataset_key, "extract_dir": str(extract_dir)})
303
+
304
+ extract_dir = Path(upload_state["extract_dir"])
305
+ jsonl_files = sorted(extract_dir.rglob("*.jsonl"))
306
+ if not jsonl_files:
307
+ msg = "Uploaded ZIP must contain a .jsonl file."
308
+ raise ValueError(msg)
309
+ return load_jsonl_records(jsonl_files[0])
310
+
311
+
312
+ def _extract_zip_safely(dataset_bytes: bytes, extract_dir: Path) -> None:
313
+ with ZipFile(io.BytesIO(dataset_bytes)) as dataset_zip:
314
+ for member in dataset_zip.infolist():
315
+ target_path = (extract_dir / member.filename).resolve()
316
+ if not target_path.is_relative_to(extract_dir.resolve()):
317
+ msg = f"Unsafe ZIP member path: {member.filename}"
318
+ raise ValueError(msg)
319
+ dataset_zip.extract(member, extract_dir)
320
+
321
+
322
+ def _uploaded_dataset_key(dataset_name: str, dataset_bytes: bytes) -> str:
323
+ digest = hashlib.sha256(dataset_bytes).hexdigest()
324
+ return f"{dataset_name}:{digest}"
325
+
326
+
327
+ def _remove_upload_dir(path: object) -> None:
328
+ if isinstance(path, str):
329
+ shutil.rmtree(path, ignore_errors=True)
330
+
331
+
332
  def _render_patch_attention(
333
  st: Any,
334
  embedder: ImageEmbedder,
tests/test_dataset.py CHANGED
@@ -10,6 +10,7 @@ import pytest
10
  from mathvision_explorer.dataset import (
11
  filter_records,
12
  load_jsonl_records,
 
13
  record_from_mapping,
14
  summarize_records,
15
  )
@@ -52,6 +53,21 @@ def test_load_jsonl_records_resolves_relative_image_paths(tmp_path: Path) -> Non
52
  assert records[0].image_path == tmp_path / "images" / "mv-2.png"
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def test_filter_and_summary() -> None:
56
  """Records can be filtered and summarized for explorer views."""
57
 
 
10
  from mathvision_explorer.dataset import (
11
  filter_records,
12
  load_jsonl_records,
13
+ load_jsonl_records_from_text,
14
  record_from_mapping,
15
  summarize_records,
16
  )
 
53
  assert records[0].image_path == tmp_path / "images" / "mv-2.png"
54
 
55
 
56
+ def test_load_jsonl_records_from_text_resolves_relative_image_paths(tmp_path: Path) -> None:
57
+ """Uploaded JSONL content can still use a caller-provided image base directory."""
58
+
59
+ payload = {
60
+ "id": "mv-3",
61
+ "question": "Pick the matching graph.",
62
+ "answer": "A",
63
+ "image": "images/mv-3.png",
64
+ }
65
+
66
+ records = load_jsonl_records_from_text(json.dumps(payload), source_dir=tmp_path)
67
+
68
+ assert records[0].image_path == tmp_path / "images" / "mv-3.png"
69
+
70
+
71
  def test_filter_and_summary() -> None:
72
  """Records can be filtered and summarized for explorer views."""
73