sreeramajay committed on
Commit
2cfb17f
·
verified ·
1 Parent(s): d4e3646

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,41 @@
1
+ # Virtual environment
2
+ .venv/
3
+ venv/
4
+
5
+ # Python
6
+ __pycache__/
7
+ *.py[cod]
8
+ *.egg-info/
9
+ dist/
10
+ build/
11
+ *.egg
12
+
13
+ # Environment / secrets
14
+ .env
15
+ .env.local
16
+
17
+ # IDE
18
+ .idea/
19
+ .vscode/
20
+ *.swp
21
+ *.swo
22
+ *~
23
+
24
+ # OS
25
+ .DS_Store
26
+ Thumbs.db
27
+
28
+ # pytest
29
+ .pytest_cache/
30
+ htmlcov/
31
+ .coverage
32
+
33
+ # uv
34
+ uv.lock
35
+
36
+ # Jupyter
37
+ .ipynb_checkpoints/
38
+
39
+ CLAUDE.md
40
+
41
+ openenv_visual_reasoning.egg-info/
CLAUDE.md DELETED
@@ -1,111 +0,0 @@
1
- # Visual Reasoning Environment
2
-
3
- ## What This Is
4
-
5
- An OpenEnv RL environment for training LLMs (via GRPO/RLVR) to be expert visual explainers of CS algorithms. The LLM acts as a teacher drawing on a whiteboard — it creates data structures, walks through algorithms step-by-step, and narrates the reasoning. A scoring system provides dense, verifiable rewards.
6
-
7
- ## Architecture
8
-
9
- ```
10
- inference.py / inference_tldraw.py — LLM inference loops (client-side)
11
- client.py — WebSocket EnvClient wrapper
12
- server/app.py — FastAPI OpenEnv server entry point
13
- server/visual_reasoning_environment.py — Core env: reset(), step(), state management
14
- server/scoring.py — 13 weighted sub-scores + 5 penalties → reward
15
- server/invariant_checkers.py — Per-algorithm correctness checks (9 algorithms)
16
- server/narration_scorer.py — Qwen3-Embedding-0.6B cosine similarity scorer
17
- server/pedagogical_scoring.py — Teaching quality: attention coherence, pacing, scaffolding
18
- server/scenario_loader.py — Loads scenarios.json + procedural generation
19
- server/scenario_generator.py — Procedural scenario generation per difficulty
20
- server/regions.py — Layout engine (queue/stack/tree/graph positioning)
21
- server/constants.py — ALLOWED_OPS, ROLE_VALUES, limits
22
- models.py — Pydantic models: VisualReasoningAction, VisualReasoningObservation
23
- viewer/tldraw_viewer.html — tldraw-based browser visualizer
24
- ```
25
-
26
- ## Key Concepts
27
-
28
- - **Empty canvas paradigm**: All scenarios start empty. The LLM draws the problem first (Phase 1), then solves it step-by-step (Phase 2), then completes (Phase 3).
29
- - **16 canvas operations**: add_region, add_node, add_pointer, add_container, add_edge, remove_edge, push_to, pop_from, move_pointer, set_value, set_role, annotate, highlight, unhighlight, add_note, remove_entity.
30
- - **10 entity roles**: default, current, visited, frontier, done, pivot, root, error, inactive, comparing.
31
- - **Region vs Container**: Regions (`add_region`) are layout areas for visual positioning. Containers (`add_container`) track membership for push/pop. `push_to`/`pop_from` ONLY work on containers, not regions. This distinction is a common source of LLM confusion.
32
- - **Delta-based rewards**: `reward = (new_overall_score - previous_overall_score) + flat_penalties`. Scoring is deterministic for RL training reproducibility.
33
- - **Concept coverage**: The LLM claims concepts from a checklist; coverage is verified via narration evidencing with prefix matching + alias expansion (`_CONCEPT_ALIASES` and `_CONCEPT_PART_ALIASES` in scoring.py).
34
-
35
- ## Running
36
-
37
- ### Environment setup
38
- ```bash
39
- conda activate unsloth_env
40
- ```
41
-
42
- ### Run tests
43
- ```bash
44
- conda run -n unsloth_env python -m pytest tests/ -v
45
- ```
46
-
47
- ### Run inference (headless)
48
- ```bash
49
- # Start the server first (in another terminal or via Docker)
50
- LOCAL_IMAGE_NAME=http://127.0.0.1:8000 python inference.py
51
- ```
52
-
53
- ### Run inference with tldraw viewer
54
- ```bash
55
- python inference_tldraw.py
56
- # Opens browser at http://0.0.0.0:8765/
57
- ```
58
-
59
- ### Environment variables
60
- - `LOCAL_IMAGE_NAME` — Docker image or `http://127.0.0.1:PORT` for local server
61
- - `API_BASE_URL` — LLM API endpoint (default: HuggingFace router)
62
- - `API_KEY` / `HF_TOKEN` — API authentication
63
- - `MODEL_NAME` — LLM model (default: Qwen/Qwen2.5-72B-Instruct)
64
- - `VISUAL_REASONING_TASKS` — Comma-separated task list (default: easy,medium,hard,expert)
65
- - `DEBUG=1` — Enable verbose debug logging
66
- - `VIS_PORT`, `VIS_HOST`, `VIS_WAIT` — tldraw viewer settings
67
-
68
- ## Scoring System
69
-
70
- 13 weighted sub-scores (weights vary by difficulty level):
71
- - **validity** (~10-12%): Correct op formats, no invalid references
72
- - **invariant** (~18-22%): Algorithm correctness checked against ground truth
73
- - **coverage** (~17-18%): Concept checklist completion via narration evidencing
74
- - **narration_quality** (~6-10%): Cosine similarity against reference narrations
75
- - **structure** (~5-7%): Constraint satisfaction + entity monotonicity
76
- - **progress** (~4-5%): State must change each step; granularity penalty for >4 creations per step
77
- - **algorithm_completion** (~5%): Cumulative algorithm progress (% nodes placed, edges drawn, roles assigned)
78
- - **spatial** (~6%): Region placement on the canvas grid (semantic fit, collision avoidance, reading-order flow)
79
- - **consistency** (~5-7%): Unexplained entity changes are penalized
80
- - **attention_coherence** (~6-7%): Narration entities match op targets (fuzzy matching)
81
- - **visual** (~2%): Layout overlap/occlusion/crossing penalties
82
- - **cognitive_pacing** (~7-8%): Information density vs. novelty; penalizes creation-heavy dumps
83
- - **scaffolding** (~7-9%): Emphasis decreases over repeated patterns
84
-
85
- 5 penalties subtracted from weighted sum:
86
- - `penalty_redundant` (0.2): All ops duplicate existing state
87
- - `penalty_no_op` (applied as flat -0.05 in env): Zero state delta
88
- - `penalty_unsupported_claims` (up to 0.3): Claiming concepts not evidenced in narration
89
- - `penalty_too_many_ops` (up to 0.5): Exceeding MAX_OPS_PER_STEP (14)
90
- - `penalty_info_dump` (up to 0.2): >5 creation ops in a single step (0.05 per excess)
91
-
92
- ## Algorithms / Scenarios
93
-
94
- 9 algorithm templates across 4 difficulty levels:
95
- - **easy**: linked_list_traversal, stack_ops, binary_search
96
- - **medium**: bfs_graph, hash_table_chaining
97
- - **hard**: dijkstra_step, bst_insert
98
- - **expert**: fib_memo, quicksort_lomuto
99
-
100
- Static scenarios in `server/scenarios/scenarios.json`, plus procedurally generated ones via `scenario_generator.py`.
101
-
102
- ## Common Pitfalls When Modifying
103
-
104
- 1. **Scoring must be deterministic** — no randomness, no floating-point order sensitivity. The regression test (`test_easy_1_reproducible`) enforces bit-identical scores across runs.
105
- 2. **`first_conflict_message` takes an optional `action` arg** — pass it for context-specific error messages that help the LLM self-correct.
106
- 3. **`compute_progress_score` takes an optional `action` arg** — needed so `complete` steps get progress=1.0.
107
- 4. **Concept evidencing uses three layers**: exact token match → prefix morphological match (`_prefix_match`) → alias expansion (`_CONCEPT_ALIASES` / `_CONCEPT_PART_ALIASES`). When adding new scenarios, ensure all checklist concepts have corresponding aliases.
108
- 5. **The narration scorer uses Qwen3-Embedding-0.6B on CUDA** — falls back to a heuristic `_fallback_score` if the model fails to load. Check `warmup_scorer` logs to confirm which path runs.
109
- 6. **The `.venv` Python is broken (3.7 binary, 3.10 site-packages)** — always use `conda run -n unsloth_env` for running code.
110
- 7. **`openenv` import resolution**: Server modules use try/except for relative vs absolute imports. Tests run from the project root with `PYTHONPATH=.`.
111
- 8. **The inference loop uses `run_in_executor`** for LLM calls to avoid blocking the async event loop (which would cause WebSocket keepalive timeouts).
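Reviewer note on pitfall 4 above: the three evidencing layers (exact token match, then prefix match, then alias expansion) can be sketched as follows. This is an illustration under stated assumptions, not the code in `server/scoring.py`: the alias table, the four-character prefix rule, and every function name here are hypothetical.

```python
import re
from typing import Dict, List

# Hypothetical alias table; the real _CONCEPT_ALIASES is not shown in this diff.
EXAMPLE_ALIASES: Dict[str, List[str]] = {
    "fifo": ["first in first out"],
}


def tokenize(text: str) -> List[str]:
    """Lowercase and split on anything that is not a letter or digit."""
    return re.findall(r"[a-z0-9]+", text.lower())


def prefix_match(part: str, tokens: List[str], min_prefix: int = 4) -> bool:
    """Assumed morphological rule: a narration token shares the first min_prefix chars."""
    if len(part) < min_prefix:
        return False
    return any(tok.startswith(part[:min_prefix]) for tok in tokens)


def is_evidenced(concept: str, narration: str, aliases: Dict[str, List[str]]) -> bool:
    """True when every part of the concept is supported by the narration."""
    tokens = tokenize(narration)
    token_set = set(tokens)
    padded = " " + " ".join(tokens) + " "
    parts = tokenize(concept)  # "queue_fifo" -> ["queue", "fifo"]
    for part in parts:
        if part in token_set:                      # layer 1: exact token
            continue
        if prefix_match(part, tokens):             # layer 2: prefix match
            continue
        alias_hit = any(
            " " + " ".join(tokenize(alias)) + " " in padded
            for alias in aliases.get(part, [])
        )
        if alias_hit:                              # layer 3: alias expansion
            continue
        return False
    return bool(parts)
```

In this sketch a concept with no alias entry can only be evidenced through layers 1 and 2, which illustrates why the pitfall asks for aliases to be added alongside new checklist concepts.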
README.md CHANGED
@@ -11,7 +11,6 @@ tags:
11
  - reinforcement-learning
12
  - llm
13
  - grpo
14
- base_path: /web
15
  ---
16
 
17
  # Visual Reasoning Environment
@@ -140,33 +139,32 @@ Clipped to `[-0.2, 1.0]`. Designed for GRPO / RLVR training.
140
  Here's the full picture -- how the agent, the environment, and the reward signal fit together:
141
 
142
  ```
143
- ┌─────────────────────────────────────────────────────────────────────
144
- TRAINING LOOP (GRPO / RLVR)
145
-
146
- │ ┌───────────┐ prompt ┌──────────────┐ JSON action
147
- │ │ │ ─────────────>│ │ ──────────────────
148
- │ │ Scenario │ │ LLM Agent │
149
- │ │ Generator│ │ (Teacher) │
150
- │ │ │ ┌─────────>│ │<──────────
151
- │ └───────────┘ └──────────────┘
152
-
153
- observation reward
154
- + score breakdown signal
155
-
156
- ┌────────┴─────── score ┌─────┴─
157
- │ <─────────────────── │ │
158
- Environment │ Scoring │
159
- (Empty Canvas) │ ──────────────────> │ Engine │
160
- canvas state │(13 dim)
161
- └──────────────── └────────┘
162
- ^
163
- step(action) │
164
- └───────────────────────────────────────────
165
-
166
- Per-step reward = Δ(overall_score) + penalties + concept_bonuses
167
- │ Episode: empty canvas ──> Phase 1 (draw) ──> Phase 2 (solve)
168
- ──> Phase 3 (summarize) ──> done │
169
- └─────────────────────────────────────────────────────────────────────┘
170
  ```
171
 
172
  Every episode starts with a blank canvas and a goal like *"Explain how Dijkstra's algorithm finds shortest paths in this graph."* The agent draws, narrates, and advances step by step. The scoring engine evaluates each step across all 13 dimensions. The reward signal flows back into the RL training loop, gradually shaping the agent into a better teacher.
@@ -265,19 +263,21 @@ The overall score jumped from **0.368 to 0.536** -- a 45.7% relative improvement
265
  ```
266
  SFT+GRPO Score by Difficulty (Qwen2.5-3B, single A100)
267
 
268
- 0.7 |
269
- | ┌───────┐
270
- 0.6 | ┌──────┐ │ 0.635 │ +120% from baseline
271
- | │ 0.566│ └───────
272
- 0.5 | ┌──────┐ └──────
273
- | │ 0.481│ ┌──────┐
274
- 0.4 | ──────┘ │ 0.461│
275
- | └──────┘
276
- 0.3 | +33.6% +28.4% +21.5% +120.5%
277
- |
278
- 0.2 | ░░0.360░ ░░0.359░ ░░0.466░ ░░0.288░ Baseline
279
- +────────────────────────────────────────────────────
280
- Easy Medium Hard Expert
 
 
281
  ```
282
 
283
  A few things stand out from these results:
 
11
  - reinforcement-learning
12
  - llm
13
  - grpo
 
14
  ---
15
 
16
  # Visual Reasoning Environment
 
139
  Here's the full picture -- how the agent, the environment, and the reward signal fit together:
140
 
141
  ```
142
+ ┌─────────────────────────────────────────────────────────────────┐
143
+ TRAINING LOOP (GRPO / RLVR)
144
+
145
+ │ ┌───────────┐ prompt ┌──────────────┐ JSON action
146
+ │ │ │ ────────> │ │ ──────────────┐
147
+ │ │ Scenario │ │ LLM Agent │
148
+ │ │ Generator│ │ (Teacher) │
149
+ │ │ │ ┌──────> │ │ <────────┐ │
150
+ │ └───────────┘ └──────────────┘ │ │
151
+ │ │
152
+ observation reward
153
+ + score breakdown signal
154
+ │ │
155
+ ┌────────┴───────┐ score ┌─────┴─┐
156
+ │ <──────────────── │ │
157
+ Environment │ │Scoring │
158
+ (Empty Canvas) │ ───────────────> │ Engine
159
+ canvas state │(13 dim)│
160
+ └────────────────┘ └────────┘
161
+ ^
162
+ step(action) │
163
+ └───────────────────────────────────────┘
164
+
165
+ │ reward = Δ(overall_score) + penalties + concept_bonuses
166
+ │ Episode: empty canvas ──> Phase 1 ──> Phase 2 ──> Phase 3
167
+ ─────────────────────────────────────────────────────────────────┘
 
168
  ```
169
 
170
  Every episode starts with a blank canvas and a goal like *"Explain how Dijkstra's algorithm finds shortest paths in this graph."* The agent draws, narrates, and advances step by step. The scoring engine evaluates each step across all 13 dimensions. The reward signal flows back into the RL training loop, gradually shaping the agent into a better teacher.
 
263
  ```
264
  SFT+GRPO Score by Difficulty (Qwen2.5-3B, single A100)
265
 
266
+ |
267
+ 0.7 |
268
+ | ┌────────
269
+ 0.6 | ┌────────┐ │ 0.635 │
270
+ | │ 0.566 │ └────────┘
271
+ 0.5 | ┌──────── └────────┘
272
+ | 0.481 ────────┐
273
+ 0.4 | └──────── │ 0.461 │
274
+ | └────────┘
275
+ 0.3 | +33.6% +28.4% +21.5% +120.5%
276
+ |
277
+ 0.2 | ░ 0.360 ░ ░ 0.359 ░ ░ 0.466 ░ ░ 0.288 ░
278
+ | Baseline
279
+ +─────────────────────────────────────────────────
280
+ Easy Medium Hard Expert
281
  ```
282
 
283
  A few things stand out from these results:
inference_audio.py DELETED
@@ -1,451 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from __future__ import annotations
8
-
9
- import asyncio
10
- import base64
11
- import contextlib
12
- import json
13
- import os
14
- import sys
15
- import time
16
- from pathlib import Path
17
- from typing import Any, Dict, List, Optional, Set, Tuple
18
-
19
- import uvicorn
20
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
21
- from fastapi.responses import FileResponse
22
- from openai import OpenAI
23
-
24
- from client import VisualReasoningEnv
25
- from models import VisualReasoningAction
26
-
27
- from inference import (
28
- API_BASE_URL,
29
- API_KEY,
30
- BENCHMARK,
31
- HF_SPACE_URL,
32
- LOCAL_IMAGE_NAME,
33
- MAX_STEPS,
34
- MODEL_NAME,
35
- SUCCESS_SCORE_THRESHOLD,
36
- TASK_NAMES,
37
- action_to_string,
38
- debug,
39
- get_model_action_async,
40
- log_end,
41
- log_start,
42
- log_step,
43
- )
44
-
45
- VIS_PORT = int(os.getenv("VIS_PORT", "8765"))
46
- VIS_HOST = os.getenv("VIS_HOST", "0.0.0.0")
47
- VIS_WAIT = float(os.getenv("VIS_WAIT", "30"))
48
-
49
- VIEWER_DIR = Path(__file__).parent / "viewer"
50
- HTML_PATH = VIEWER_DIR / "audio_viewer.html"
51
- AUDIO_DIR = Path(__file__).parent / "others"
52
- BACKGROUND_MUSIC_PATH = AUDIO_DIR / "tutorial_background.mp3"
53
- HISTORY_CAP = 500
54
-
55
- TTS_LEAD_TIME = 2.0
56
- TTS_MIN_WAIT = 1.5
57
- TTS_MODEL = "hexgrad/Kokoro-82M"
58
- TTS_PROVIDER = "fal-ai"
59
-
60
-
61
- def vlog(msg: str) -> None:
62
- print(f"[AUDIO] {msg}", file=sys.stderr, flush=True)
63
-
64
-
65
- # ---------------------------------------------------------------------------
66
- # TTS via huggingface_hub InferenceClient
67
- # ---------------------------------------------------------------------------
68
-
69
- def _make_tts_client():
70
- from huggingface_hub import InferenceClient
71
- return InferenceClient(provider=TTS_PROVIDER, api_key=os.environ.get("HF_TOKEN", ""))
72
-
73
-
74
- _tts_client = None
75
-
76
-
77
- def _get_tts_client():
78
- global _tts_client
79
- if _tts_client is None:
80
- _tts_client = _make_tts_client()
81
- return _tts_client
82
-
83
-
84
- def _estimate_duration(audio_bytes: bytes) -> float:
85
- if len(audio_bytes) < 44:
86
- return 0.0
87
- # Try WAV header
88
- if audio_bytes[:4] == b"RIFF" and audio_bytes[8:12] == b"WAVE":
89
- import struct
90
- try:
91
- byte_rate = struct.unpack_from("<I", audio_bytes, 28)[0]
92
- data_offset = audio_bytes.find(b"data")
93
- if data_offset >= 0 and byte_rate > 0:
94
- data_size = struct.unpack_from("<I", audio_bytes, data_offset + 4)[0]
95
- return data_size / byte_rate
96
- except Exception:
97
- pass
98
- # Rough estimate for compressed audio (~16kB/s for mp3 at 128kbps)
99
- return len(audio_bytes) / 16000.0
100
-
101
-
102
- def _synthesize_sync(text: str) -> Tuple[Optional[str], float]:
103
- if not text or not text.strip():
104
- return None, 0.0
105
- try:
106
- client = _get_tts_client()
107
- audio_bytes = client.text_to_speech(text, model=TTS_MODEL)
108
- if not audio_bytes:
109
- vlog("TTS returned empty audio")
110
- return None, 0.0
111
- duration = _estimate_duration(audio_bytes)
112
- vlog(f"TTS ok: {len(audio_bytes)} bytes, ~{duration:.1f}s")
113
- return base64.b64encode(audio_bytes).decode("ascii"), duration
114
- except Exception as exc:
115
- vlog(f"TTS error: {exc}")
116
- return None, 0.0
117
-
118
-
119
- # ---------------------------------------------------------------------------
120
- # Broadcaster
121
- # ---------------------------------------------------------------------------
122
-
123
- class Broadcaster:
124
- def __init__(self) -> None:
125
- self._clients: Set[WebSocket] = set()
126
- self._history: List[Dict[str, Any]] = []
127
- self._first_client = asyncio.Event()
128
- self._lock = asyncio.Lock()
129
-
130
- async def register(self, ws: WebSocket) -> None:
131
- async with self._lock:
132
- self._clients.add(ws)
133
- replay = list(self._history)
134
- self._first_client.set()
135
- vlog(f"connected clients={len(self._clients)}")
136
- for msg in replay:
137
- try:
138
- await ws.send_text(json.dumps(msg))
139
- except Exception:
140
- return
141
-
142
- async def unregister(self, ws: WebSocket) -> None:
143
- async with self._lock:
144
- self._clients.discard(ws)
145
- vlog(f"client disconnected clients={len(self._clients)}")
146
-
147
- async def send(self, msg: Dict[str, Any]) -> None:
148
- history_msg = {k: v for k, v in msg.items() if k != "audio"}
149
- async with self._lock:
150
- self._history.append(history_msg)
151
- if len(self._history) > HISTORY_CAP:
152
- self._history = self._history[-HISTORY_CAP:]
153
- targets = list(self._clients)
154
- if not targets:
155
- return
156
- payload = json.dumps(msg)
157
- dead: List[WebSocket] = []
158
- for ws in targets:
159
- try:
160
- await ws.send_text(payload)
161
- except Exception:
162
- dead.append(ws)
163
- if dead:
164
- async with self._lock:
165
- for ws in dead:
166
- self._clients.discard(ws)
167
-
168
- async def wait_for_client(self, timeout: float) -> bool:
169
- if timeout <= 0:
170
- return self._first_client.is_set()
171
- try:
172
- await asyncio.wait_for(self._first_client.wait(), timeout=timeout)
173
- return True
174
- except asyncio.TimeoutError:
175
- return False
176
-
177
-
178
- # ---------------------------------------------------------------------------
179
- # FastAPI app
180
- # ---------------------------------------------------------------------------
181
-
182
- def build_viewer_app(broadcaster: Broadcaster) -> FastAPI:
183
- app = FastAPI()
184
-
185
- @app.get("/")
186
- async def index():
187
- return FileResponse(str(HTML_PATH), media_type="text/html")
188
-
189
- @app.get("/audio/background.mp3")
190
- async def background_music():
191
- if BACKGROUND_MUSIC_PATH.exists():
192
- return FileResponse(str(BACKGROUND_MUSIC_PATH), media_type="audio/mpeg")
193
- return {"error": "background music not found"}
194
-
195
- @app.get("/health")
196
- async def health():
197
- return {"ok": True}
198
-
199
- @app.websocket("/ws")
200
- async def ws_endpoint(ws: WebSocket):
201
- await ws.accept()
202
- await broadcaster.register(ws)
203
- try:
204
- while True:
205
- await ws.receive_text()
206
- except WebSocketDisconnect:
207
- pass
208
- except Exception as exc:
209
- debug(f"ws error: {exc}")
210
- finally:
211
- await broadcaster.unregister(ws)
212
-
213
- return app
214
-
215
-
216
- async def start_viewer(broadcaster: Broadcaster):
217
- if not HTML_PATH.exists():
218
- vlog(f"ERROR: viewer HTML missing at {HTML_PATH}")
219
- sys.exit(1)
220
- app = build_viewer_app(broadcaster)
221
- config = uvicorn.Config(
222
- app,
223
- host=VIS_HOST,
224
- port=VIS_PORT,
225
- log_config=None,
226
- access_log=False,
227
- log_level="warning",
228
- )
229
- server = uvicorn.Server(config)
230
- task = asyncio.create_task(server.serve())
231
- for _ in range(50):
232
- if server.started:
233
- break
234
- if task.done():
235
- exc = task.exception()
236
- vlog(f"viewer server failed to start: {exc}")
237
- sys.exit(1)
238
- await asyncio.sleep(0.05)
239
- return server, task
240
-
241
-
242
- # ---------------------------------------------------------------------------
243
- # Observation snapshot
244
- # ---------------------------------------------------------------------------
245
-
246
- def _obs_snapshot(obs: Any) -> Dict[str, Any]:
247
- return {
248
- "entities": {k: dict(v) for k, v in (obs.entities or {}).items()},
249
- "relations": [dict(r) for r in (obs.relations or [])],
250
- "layout": {k: dict(v) for k, v in (obs.layout or {}).items()},
251
- "annotations": [dict(a) for a in (obs.annotations or [])],
252
- "notes": [dict(n) for n in (obs.notes or [])],
253
- "score_breakdown": {
254
- k: (float(v) if isinstance(v, (int, float)) else v)
255
- for k, v in (obs.score_breakdown or {}).items()
256
- },
257
- "coverage": list(obs.concept_coverage),
258
- "narration_history": list(obs.narration_history),
259
- "remaining_step_budget": obs.remaining_step_budget,
260
- }
261
-
262
-
263
- # ---------------------------------------------------------------------------
264
- # Episode runner with audio
265
- # ---------------------------------------------------------------------------
266
-
267
- async def run_episode_streaming(
268
- env: Any, client: OpenAI, task_name: str, broadcaster: Broadcaster
269
- ) -> None:
270
- history: List[str] = []
271
- rewards: List[float] = []
272
- steps_taken = 0
273
- score = 0.0
274
- success = False
275
- last_reward = 0.0
276
- last_action: Optional[Dict[str, Any]] = None
277
- obs = None
278
- loop = asyncio.get_running_loop()
279
-
280
- log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
281
-
282
- try:
283
- result = await env.reset(task_name=task_name)
284
- obs = result.observation
285
- debug(
286
- f"RESET: scenario={obs.scenario_id} goal={obs.goal} "
287
- f"budget={obs.remaining_step_budget}"
288
- )
289
-
290
- goal_audio, goal_dur = await loop.run_in_executor(
291
- None, _synthesize_sync, obs.goal
292
- )
293
-
294
- snap = _obs_snapshot(obs)
295
- await broadcaster.send(
296
- {
297
- "type": "reset",
298
- "task_name": obs.task_name,
299
- "scenario_id": obs.scenario_id,
300
- "goal": obs.goal,
301
- "checklist": list(obs.concept_checklist),
302
- "input_data": dict(obs.input_data),
303
- "constraints": list(obs.constraints),
304
- "max_steps": obs.max_steps,
305
- "audio": goal_audio,
306
- "audio_duration": goal_dur,
307
- **snap,
308
- }
309
- )
310
-
311
- if goal_dur > 0:
312
- await asyncio.sleep(max(TTS_MIN_WAIT, goal_dur - TTS_LEAD_TIME))
313
- else:
314
- await asyncio.sleep(TTS_MIN_WAIT)
315
-
316
- for step in range(1, MAX_STEPS + 1):
317
- if result.done:
318
- break
319
-
320
- step_start = time.monotonic()
321
- obs = result.observation
322
-
323
- action_dict = await get_model_action_async(
324
- client, obs, last_action, last_reward, history
325
- )
326
- action = VisualReasoningAction(**action_dict)
327
- narration = action_dict.get("narration", "")
328
-
329
- env_future = asyncio.ensure_future(env.step(action))
330
- tts_future = loop.run_in_executor(None, _synthesize_sync, narration)
331
-
332
- result = await env_future
333
- audio_b64, audio_dur = await tts_future
334
-
335
- obs = result.observation
336
- reward = result.reward or 0.0
337
- done = result.done
338
- error = obs.action_error
339
-
340
- rewards.append(reward)
341
- steps_taken = step
342
- last_reward = reward
343
- last_action = action_dict
344
-
345
- log_step(
346
- step=step,
347
- action=action_to_string(action_dict),
348
- reward=reward,
349
- done=done,
350
- error=error,
351
- )
352
- history.append(f"Step {step}: action={action_to_string(action_dict)}")
353
-
354
- snap = _obs_snapshot(obs)
355
- await broadcaster.send(
356
- {
357
- "type": "step",
358
- "task_name": obs.task_name,
359
- "scenario_id": obs.scenario_id,
360
- "step": step,
361
- "step_type": action_dict.get("step_type"),
362
- "intent": action_dict.get("intent", ""),
363
- "narration": narration,
364
- "ops": action_dict.get("ops", []),
365
- "covered_concepts": action_dict.get("covered_concepts", []),
366
- "reward": float(reward),
367
- "score": float(obs.score_breakdown.get("overall_score", 0.0)),
368
- "done": bool(done),
369
- "error": error,
370
- "audio": audio_b64,
371
- "audio_duration": audio_dur,
372
- **snap,
373
- }
374
- )
375
-
376
- if done:
377
- if audio_dur > 0:
378
- await asyncio.sleep(max(0, audio_dur + 0.5))
379
- break
380
-
381
- elapsed = time.monotonic() - step_start
382
- target = max(TTS_MIN_WAIT, audio_dur - TTS_LEAD_TIME) if audio_dur > 0 else TTS_MIN_WAIT
383
- remaining = max(0, target - elapsed)
384
- if remaining > 0:
385
- await asyncio.sleep(remaining)
386
-
387
- if obs is not None and steps_taken > 0:
388
- score = float(obs.score_breakdown.get("overall_score", 0.0))
389
- score = min(max(score, 0.0), 1.0)
390
- success = score >= SUCCESS_SCORE_THRESHOLD
391
-
392
- finally:
393
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
394
- await broadcaster.send(
395
- {
396
- "type": "end",
397
- "task_name": task_name,
398
- "success": bool(success),
399
- "steps": steps_taken,
400
- "score": float(score),
401
- "rewards": [float(r) for r in rewards],
402
- }
403
- )
404
-
405
-
406
- # ---------------------------------------------------------------------------
407
- # Main
408
- # ---------------------------------------------------------------------------
409
-
410
- async def main() -> None:
411
- broadcaster = Broadcaster()
412
- server, server_task = await start_viewer(broadcaster)
413
-
414
- vlog(f"open http://{VIS_HOST}:{VIS_PORT}/ in your browser")
415
- if not BACKGROUND_MUSIC_PATH.exists():
416
- vlog(f"WARNING: background music not found at {BACKGROUND_MUSIC_PATH}")
417
- if VIS_WAIT > 0:
418
- vlog(f"waiting up to {VIS_WAIT:.0f}s for a browser connection...")
419
- connected = await broadcaster.wait_for_client(VIS_WAIT)
420
- if not connected:
421
- vlog("proceeding without viewer (no browser connected in time)")
422
- else:
423
- vlog("VIS_WAIT=0, starting immediately")
424
-
425
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
426
-
427
- if LOCAL_IMAGE_NAME:
428
- if LOCAL_IMAGE_NAME.startswith("http://127.0.0.1:"):
429
- env = VisualReasoningEnv(base_url=LOCAL_IMAGE_NAME, message_timeout_s=120)
430
- else:
431
- env = await VisualReasoningEnv.from_docker_image(LOCAL_IMAGE_NAME)
432
- else:
433
- env = await VisualReasoningEnv.from_env(HF_SPACE_URL, use_docker=False)
434
-
435
- try:
436
- for task_name in TASK_NAMES:
437
- await run_episode_streaming(env, client, task_name, broadcaster)
438
- await broadcaster.send({"type": "shutdown"})
439
- await asyncio.sleep(0.5)
440
- finally:
441
- try:
442
- await env.close()
443
- except Exception as exc:
444
- print(f"[DEBUG] env.close() error: {exc}", file=sys.stderr, flush=True)
445
- server.should_exit = True
446
- with contextlib.suppress(Exception):
447
- await server_task
448
-
449
-
450
- if __name__ == "__main__":
451
- asyncio.run(main())
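Reviewer note on `_estimate_duration` in the deleted file above: it locates the `data` chunk with a plain byte search, which can hit a false match if the bytes `data` occur earlier in the header. A possible alternative, sketched here with the standard-library `wave` module, lets the parser walk the chunks instead. The function names are hypothetical, and `wave` generally handles only PCM WAV, so other encodings fall through to 0.0 in this sketch.

```python
import io
import wave


def wav_duration_seconds(audio_bytes: bytes) -> float:
    """Return the duration of a PCM WAV payload, or 0.0 if it cannot be parsed."""
    try:
        with wave.open(io.BytesIO(audio_bytes), "rb") as wav:
            rate = wav.getframerate()
            return wav.getnframes() / rate if rate > 0 else 0.0
    except (wave.Error, EOFError):
        # Not a WAV file, truncated, or an encoding the wave module rejects.
        return 0.0


def make_silent_wav(seconds: float, rate: int = 8000) -> bytes:
    """Build a mono 16-bit silent WAV in memory (test helper for this sketch)."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(rate)
        wav.writeframes(b"\x00\x00" * int(seconds * rate))
    return buf.getvalue()


duration = wav_duration_seconds(make_silent_wav(1.5))
```

A caller could still fall back to a size-based estimate for compressed audio when this helper returns 0.0, as the original function does.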
openenv_visual_reasoning.egg-info/PKG-INFO DELETED
@@ -1,19 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: openenv-visual_reasoning
3
- Version: 0.1.0
4
- Summary: Visual Reasoning environment for OpenEnv — step-based RL for grounded visual + textual CS explanations
5
- Requires-Python: >=3.10
6
- Requires-Dist: openenv-core[core]>=0.2.2
7
- Requires-Dist: numpy<2.0
8
- Requires-Dist: python-dotenv>=1.0.0
9
- Requires-Dist: networkx>=3.1
10
- Requires-Dist: shapely>=2.0
11
- Requires-Dist: sentence-transformers>=2.2
12
- Requires-Dist: rapidfuzz>=3.0
13
- Requires-Dist: textstat>=0.7
14
- Requires-Dist: sortedcontainers>=2.4
15
- Requires-Dist: aiohttp>=3.9
16
- Requires-Dist: openai>=1.0
17
- Provides-Extra: dev
18
- Requires-Dist: pytest>=8.0.0; extra == "dev"
19
- Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_visual_reasoning.egg-info/SOURCES.txt DELETED
@@ -1,38 +0,0 @@
1
- README.md
2
- __init__.py
3
- client.py
4
- inference.py
5
- inference_audio.py
6
- inference_excalidraw.py
7
- inference_tldraw.py
8
- models.py
9
- pyproject.toml
10
- ./__init__.py
11
- ./client.py
12
- ./inference.py
13
- ./inference_audio.py
14
- ./inference_excalidraw.py
15
- ./inference_tldraw.py
16
- ./models.py
17
- openenv_visual_reasoning.egg-info/PKG-INFO
18
- openenv_visual_reasoning.egg-info/SOURCES.txt
19
- openenv_visual_reasoning.egg-info/dependency_links.txt
20
- openenv_visual_reasoning.egg-info/entry_points.txt
21
- openenv_visual_reasoning.egg-info/requires.txt
22
- openenv_visual_reasoning.egg-info/top_level.txt
23
- server/__init__.py
24
- server/app.py
25
- server/app_backup.py
26
- server/constants.py
27
- server/invariant_checkers.py
28
- server/narration_scorer.py
29
- server/pedagogical_scoring.py
30
- server/regions.py
31
- server/scenario_generator.py
32
- server/scenario_loader.py
33
- server/scoring.py
34
- server/visual_reasoning_environment.py
35
- tests/test_environment.py
36
- tests/test_regions.py
37
- tests/test_scenario_loader.py
38
- tests/test_scoring.py
openenv_visual_reasoning.egg-info/dependency_links.txt DELETED
@@ -1 +0,0 @@
1
-
openenv_visual_reasoning.egg-info/entry_points.txt DELETED
@@ -1,2 +0,0 @@
1
- [console_scripts]
2
- server = visual_reasoning.server.app:main
openenv_visual_reasoning.egg-info/requires.txt DELETED
@@ -1,15 +0,0 @@
1
- openenv-core[core]>=0.2.2
2
- numpy<2.0
3
- python-dotenv>=1.0.0
4
- networkx>=3.1
5
- shapely>=2.0
6
- sentence-transformers>=2.2
7
- rapidfuzz>=3.0
8
- textstat>=0.7
9
- sortedcontainers>=2.4
10
- aiohttp>=3.9
11
- openai>=1.0
12
-
13
- [dev]
14
- pytest>=8.0.0
15
- pytest-cov>=4.0.0
openenv_visual_reasoning.egg-info/top_level.txt DELETED
@@ -1 +0,0 @@
1
- visual_reasoning
push_to_space.ipynb ADDED
@@ -0,0 +1,129 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "34c098b1",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/opt/conda/envs/unsloth_env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "from dotenv import load_dotenv\n",
20
+ "from huggingface_hub import HfApi\n",
21
+ "import os"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 2,
27
+ "id": "eff398ee",
28
+ "metadata": {},
29
+ "outputs": [
30
+ {
31
+ "data": {
32
+ "text/plain": [
33
+ "True"
34
+ ]
35
+ },
36
+ "execution_count": 2,
37
+ "metadata": {},
38
+ "output_type": "execute_result"
39
+ }
40
+ ],
41
+ "source": [
42
+ "load_dotenv(\"../.env\")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 4,
48
+ "id": "fb24296b",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "# api = HfApi(token=os.getenv(\"HF_TOKEN\"))\n",
53
+ "# api.upload_folder(\n",
54
+ "# repo_id=\"sreeramajay/visual_reasoning-env\",\n",
55
+ "# folder_path=\".\",\n",
56
+ "# repo_type=\"space\",\n",
57
+ "# delete_patterns=[\"*\"], # deletes all remote files not present locally\n",
58
+ "# )"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "id": "02ea1bf0",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "api = HfApi(token=os.getenv(\"HF_TOKEN\"))\n",
69
+ "api.upload_folder(\n",
70
+ " repo_id=\"sreeramajay/visual_reasoning-env\",\n",
71
+ " folder_path=\".\",\n",
72
+ " repo_type=\"space\",\n",
73
+ " delete_patterns=[\"*\"],\n",
74
+ " ignore_patterns=[\n",
75
+ " \".venv/**\",\n",
76
+ " \"venv/**\",\n",
77
+ " \"**/__pycache__/**\",\n",
78
+ " \"**/*.py[cod]\",\n",
79
+ " \"**/*.egg-info/**\",\n",
80
+ " \"dist/**\",\n",
81
+ " \"build/**\",\n",
82
+ " \".env\",\n",
83
+ " \".env.local\",\n",
84
+ " \".idea/**\",\n",
85
+ " \".vscode/**\",\n",
86
+ " \".pytest_cache/**\",\n",
87
+ " \"htmlcov/**\",\n",
88
+ " \".coverage\",\n",
89
+ " \"uv.lock\",\n",
90
+ " \"**/.ipynb_checkpoints/**\",\n",
91
+ " \"CLAUDE.md\",\n",
92
+ " \"openenv_visual_reasoning.egg-info/**\",\n",
93
+ " \".DS_Store\",\n",
94
+ " \"Thumbs.db\",\n",
95
+ " ],\n",
96
+ ")"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "id": "3dc787ac",
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": []
106
+ }
107
+ ],
108
+ "metadata": {
109
+ "kernelspec": {
110
+ "display_name": "unsloth_env",
111
+ "language": "python",
112
+ "name": "python3"
113
+ },
114
+ "language_info": {
115
+ "codemirror_mode": {
116
+ "name": "ipython",
117
+ "version": 3
118
+ },
119
+ "file_extension": ".py",
120
+ "mimetype": "text/x-python",
121
+ "name": "python",
122
+ "nbconvert_exporter": "python",
123
+ "pygments_lexer": "ipython3",
124
+ "version": "3.11.11"
125
+ }
126
+ },
127
+ "nbformat": 4,
128
+ "nbformat_minor": 5
129
+ }
scripts/generate_rubric_data.py DELETED
@@ -1,1132 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import argparse
8
- import json
9
- import os
10
- import sys
11
- from pathlib import Path
12
- from typing import Any, Dict, List, Optional
13
-
14
- sys.path.insert(0, str(Path(__file__).parent.parent))
15
-
16
- from server.visual_reasoning_environment import VisualReasoningEnvironment
17
- from models import VisualReasoningAction
18
-
19
-
20
- LABELING_PROMPT = """Rate this narration on four dimensions (0.0 to 1.0, steps of 0.25):
21
-
22
- CONTEXT:
23
- - Algorithm: {template}
24
- - Ops this step: {ops_summary}
25
- - Previous narration: {prev_narration}
26
- - Current narration: {narration}
27
- - Concepts claimed: {covered_concepts}
28
- - Progress: {step_progress}
29
-
30
- DIMENSIONS:
31
- 1. explanatory_depth: 0.0=mechanical, 0.5=basic context, 1.0=explains why+broader concept
32
- 2. grounding_accuracy: 0.0=contradicts ops, 0.5=partial, 1.0=accurate+describes effects
33
- 3. clarity: 0.0=incomprehensible, 0.5=adequate, 1.0=clear natural voiceover
34
- 4. flow: 0.0=disconnected, 0.5=adequate connection, 1.0=natural continuation
35
-
36
- Output ONLY JSON: {{"explanatory_depth": X, "grounding_accuracy": X, "clarity": X, "flow": X}}"""
37
-
38
-
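Review note: `LABELING_PROMPT` relies on doubled braces so that the JSON example survives `str.format`, and it asks for ratings in steps of 0.25. A minimal sketch of both ideas follows, assuming the labeler replies with a bare JSON object. `TEMPLATE` is an abbreviated stand-in for the real prompt, and `parse_ratings` is a hypothetical helper that does not appear in the deleted script.

```python
import json

# Abbreviated stand-in for LABELING_PROMPT; {{ and }} become literal braces after format().
TEMPLATE = (
    "Algorithm: {template}\n"
    "Current narration: {narration}\n"
    'Output ONLY JSON: {{"clarity": X, "flow": X}}'
)

DIMENSIONS = ("explanatory_depth", "grounding_accuracy", "clarity", "flow")
ALLOWED = {0.0, 0.25, 0.5, 0.75, 1.0}


def parse_ratings(reply):
    """Parse a labeler reply and require every dimension on the 0.25 grid (hypothetical)."""
    data = json.loads(reply)
    ratings = {}
    for name in DIMENSIONS:
        value = float(data[name])
        if value not in ALLOWED:
            raise ValueError(f"{name}={value} is not a multiple of 0.25 in [0, 1]")
        ratings[name] = value
    return ratings


prompt = TEMPLATE.format(template="bfs_graph", narration="Dequeue A and mark it visited.")
ratings = parse_ratings(
    '{"explanatory_depth": 0.5, "grounding_accuracy": 1.0, "clarity": 0.75, "flow": 0.25}'
)
print(prompt)
print(ratings)
```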
39
- # ---------------------------------------------------------------------------
40
- # easy_1: linked_list_traversal, incremental, values=[10,20,30]
41
- # ---------------------------------------------------------------------------
42
-
43
-
44
- def _easy_1_excellent() -> List[Dict[str, Any]]:
45
- return [
46
- {
47
- "step_type": "advance",
48
- "narration": "We create a linked list region and place the head pointer at position 0 — this is the only entry point for traversal.",
49
- "ops": [
50
- {"op": "add_region", "target_ids": ["list"], "params": {"style": "queue", "title": "Linked List"}},
51
- {"op": "add_pointer", "target_ids": ["head"], "params": {"region": "list", "index": 0}},
52
- ],
53
- "covered_concepts": ["head_pointer"],
54
- "intent": "setup",
55
- },
56
- {
57
- "step_type": "advance",
58
- "narration": "Adding node_1 with value 10 — the head pointer starts here, making it the first node we visit in the traversal.",
59
- "ops": [
60
- {"op": "add_node", "target_ids": ["node_1"], "params": {"value": 10, "region": "list"}},
61
- {"op": "set_role", "target_ids": ["node_1"], "params": {"role": "current"}},
62
- {"op": "annotate", "target_ids": ["node_1"], "params": {"text": "val=10"}},
63
- ],
64
- "covered_concepts": ["node_value"],
65
- "intent": "first-node",
66
- },
67
- {
68
- "step_type": "advance",
69
- "narration": "We add node_2 with value 20 and link it from node_1 via a next pointer — following these next links is how we traverse the list.",
70
- "ops": [
71
- {"op": "add_node", "target_ids": ["node_2"], "params": {"value": 20, "region": "list"}},
72
- {"op": "add_edge", "target_ids": ["node_1", "node_2"], "params": {"label": "next"}},
73
- {"op": "set_role", "target_ids": ["node_1"], "params": {"role": "visited"}},
74
- {"op": "set_role", "target_ids": ["node_2"], "params": {"role": "current"}},
75
- {"op": "move_pointer", "target_ids": ["head"], "params": {"index": 1}},
76
- ],
77
- "covered_concepts": ["next_link"],
78
- "intent": "second-node",
79
- },
80
- {
81
- "step_type": "advance",
82
- "narration": "Finally node_3 with value 30 has no next link — this null terminator marks it as the tail, ending our traversal.",
83
- "ops": [
84
- {"op": "add_node", "target_ids": ["node_3"], "params": {"value": 30, "region": "list"}},
85
- {"op": "add_edge", "target_ids": ["node_2", "node_3"], "params": {"label": "next"}},
86
- {"op": "set_role", "target_ids": ["node_2"], "params": {"role": "visited"}},
87
- {"op": "set_role", "target_ids": ["node_3"], "params": {"role": "done"}},
88
- {"op": "annotate", "target_ids": ["node_3"], "params": {"text": "tail (null next)"}},
89
- ],
90
- "covered_concepts": ["tail_marker"],
91
- "intent": "tail",
92
- },
93
- {
94
- "step_type": "complete",
95
- "narration": "The traversal is complete — we visited every node from head to tail by following next pointers.",
96
- "ops": [],
97
- "covered_concepts": [],
98
- "intent": "done",
99
- },
100
- ]
101
-
102
-
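Review note: the `easy_1` "excellent" trajectory above narrates a head-to-tail traversal over the values `[10, 20, 30]`. For reference, the traversal it describes can be sketched in plain Python as below. The `Node` class and both helpers are illustrative assumptions and are not part of the environment's op vocabulary.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    value: int
    next: Optional["Node"] = None  # None plays the role of the null terminator


def build_list(values: List[int]) -> Optional[Node]:
    """Build a singly linked list and return its head (hypothetical helper)."""
    head = None
    for value in reversed(values):
        head = Node(value, head)
    return head


def traverse(head: Optional[Node]) -> List[int]:
    """Follow next links from the head until the tail's null next is reached."""
    visited = []
    current = head
    while current is not None:
        visited.append(current.value)
        current = current.next
    return visited


order = traverse(build_list([10, 20, 30]))
print(order)
```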
103
- def _easy_1_good() -> List[Dict[str, Any]]:
104
- return [
105
- {
106
- "step_type": "advance",
107
- "narration": "Setting up the linked list with a head pointer.",
108
- "ops": [
109
- {"op": "add_region", "target_ids": ["list"], "params": {"style": "queue", "title": "List"}},
110
- {"op": "add_pointer", "target_ids": ["head"], "params": {"region": "list"}},
111
- {"op": "add_node", "target_ids": ["n1"], "params": {"value": 10, "region": "list"}},
112
- ],
113
- "covered_concepts": ["head_pointer", "node_value"],
114
- "intent": "setup",
115
- },
116
- {
117
- "step_type": "advance",
118
- "narration": "Adding the next two nodes and connecting them.",
119
- "ops": [
120
- {"op": "add_node", "target_ids": ["n2"], "params": {"value": 20, "region": "list"}},
121
- {"op": "add_node", "target_ids": ["n3"], "params": {"value": 30, "region": "list"}},
122
- {"op": "add_edge", "target_ids": ["n1", "n2"], "params": {}},
123
- {"op": "add_edge", "target_ids": ["n2", "n3"], "params": {}},
124
- ],
125
- "covered_concepts": ["next_link"],
126
- "intent": "build",
127
- },
128
- {
129
- "step_type": "complete",
130
- "narration": "The tail node has no next pointer, traversal ends here.",
131
- "ops": [
132
- {"op": "annotate", "target_ids": ["n3"], "params": {"text": "tail"}},
133
- ],
134
- "covered_concepts": ["tail_marker"],
135
- "intent": "done",
136
- },
137
- ]
138
-
139
-
140
- def _easy_1_mediocre() -> List[Dict[str, Any]]:
141
- return [
142
- {
143
- "step_type": "advance",
144
- "narration": "Adding nodes to the list.",
145
- "ops": [
146
- {"op": "add_node", "target_ids": ["a"], "params": {"value": 10}},
147
- {"op": "add_node", "target_ids": ["b"], "params": {"value": 20}},
148
- {"op": "add_node", "target_ids": ["c"], "params": {"value": 30}},
149
- {"op": "add_edge", "target_ids": ["a", "b"], "params": {}},
150
- {"op": "add_edge", "target_ids": ["b", "c"], "params": {}},
151
- ],
152
- "covered_concepts": ["head_pointer", "node_value", "next_link", "tail_marker"],
153
- "intent": "",
154
- },
155
- {
156
- "step_type": "complete",
157
- "narration": "Done.",
158
- "ops": [],
159
- "covered_concepts": [],
160
- "intent": "",
161
- },
162
- ]
163
-
164
-
165
- def _easy_1_bad() -> List[Dict[str, Any]]:
166
- return [
167
- {
168
- "step_type": "advance",
169
- "narration": "Starting.",
170
- "ops": [
171
- {"op": "add_node", "target_ids": ["x"], "params": {}},
172
- ],
173
- "covered_concepts": ["head_pointer", "node_value", "next_link", "tail_marker"],
174
- "intent": "",
175
- },
176
- {
177
- "step_type": "complete",
178
- "narration": "Finished.",
179
- "ops": [],
180
- "covered_concepts": [],
181
- "intent": "",
182
- },
183
- ]
184
-
185
-
186
- # ---------------------------------------------------------------------------
187
- # easy_2: stack_ops, incremental, operations=["push A","push B","pop","push C"]
188
- # ---------------------------------------------------------------------------
189
-
190
-
191
- def _easy_2_excellent() -> List[Dict[str, Any]]:
192
- return [
193
- {
194
- "step_type": "advance",
195
- "narration": "We set up a stack region with a container — stacks follow Last In First Out order, so we'll use ordered=false.",
196
- "ops": [
197
- {"op": "add_region", "target_ids": ["stk"], "params": {"style": "stack", "title": "Stack"}},
198
- {"op": "add_container", "target_ids": ["stack"], "params": {"region": "stk", "ordered": False}},
199
- {"op": "add_pointer", "target_ids": ["top"], "params": {"region": "stk", "index": -1}},
200
- ],
201
- "covered_concepts": ["top_pointer"],
202
- "intent": "setup",
203
- },
204
- {
205
- "step_type": "advance",
206
- "narration": "Push A onto the stack — A becomes the new top. Then push B on top of A, so B is now the top element.",
207
- "ops": [
208
- {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "stk"}},
209
- {"op": "push_to", "target_ids": ["stack", "A"], "params": {}},
210
- {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "stk"}},
211
- {"op": "push_to", "target_ids": ["stack", "B"], "params": {}},
212
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": 1}},
213
- ],
214
- "covered_concepts": ["push"],
215
- "intent": "push-A-B",
216
- },
217
- {
218
- "step_type": "advance",
219
- "narration": "Pop removes B — the most recently pushed element — revealing A underneath. This is the LIFO property in action.",
220
- "ops": [
221
- {"op": "pop_from", "target_ids": ["stack"], "params": {}},
222
- {"op": "set_role", "target_ids": ["B"], "params": {"role": "inactive"}},
223
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": 0}},
224
- ],
225
- "covered_concepts": ["pop", "lifo_order"],
226
- "intent": "pop-B",
227
- },
228
- {
229
- "step_type": "complete",
230
- "narration": "Push C on top — the stack now holds [A, C] with C as the new top, confirming LIFO order throughout.",
231
- "ops": [
232
- {"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "stk"}},
233
- {"op": "push_to", "target_ids": ["stack", "C"], "params": {}},
234
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": 1}},
235
- ],
236
- "covered_concepts": [],
237
- "intent": "push-C-complete",
238
- },
239
- ]
240
-
241
-
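Review note: the `easy_2` scenario runs the operations `push A`, `push B`, `pop`, `push C`, and the narration claims the stack ends as `[A, C]` after `B` is popped. A small sketch that replays those operations on a Python list confirms the claim. `run_stack_ops` is a hypothetical helper, not something defined in the deleted script.

```python
def run_stack_ops(operations):
    """Replay 'push X' / 'pop' strings on a list used as a LIFO stack (hypothetical)."""
    stack = []
    popped = []
    for op in operations:
        if op.startswith("push "):
            stack.append(op.split(" ", 1)[1])
        elif op == "pop":
            popped.append(stack.pop())  # removes the most recently pushed element
        else:
            raise ValueError(f"unknown operation: {op!r}")
    return stack, popped


final_stack, popped = run_stack_ops(["push A", "push B", "pop", "push C"])
print(final_stack, popped)
```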
242
- def _easy_2_good() -> List[Dict[str, Any]]:
243
- return [
244
- {
245
- "step_type": "advance",
246
- "narration": "Creating a stack region with a container and top pointer.",
247
- "ops": [
248
- {"op": "add_region", "target_ids": ["stk"], "params": {"style": "stack", "title": "Stack"}},
249
- {"op": "add_container", "target_ids": ["stack"], "params": {"region": "stk", "ordered": False}},
250
- {"op": "add_pointer", "target_ids": ["top"], "params": {"region": "stk", "index": -1}},
251
- ],
252
- "covered_concepts": ["top_pointer"],
253
- "intent": "setup",
254
- },
255
- {
256
- "step_type": "advance",
257
- "narration": "Push A and B onto the stack, B is now on top.",
258
- "ops": [
259
- {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "stk"}},
260
- {"op": "push_to", "target_ids": ["stack", "A"], "params": {}},
261
- {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "stk"}},
262
- {"op": "push_to", "target_ids": ["stack", "B"], "params": {}},
263
- ],
264
- "covered_concepts": ["push"],
265
- "intent": "push-AB",
266
- },
267
- {
268
- "step_type": "complete",
269
- "narration": "Pop B, then push C. Stack ends as [A, C] demonstrating LIFO.",
270
- "ops": [
271
- {"op": "pop_from", "target_ids": ["stack"], "params": {}},
272
- {"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "stk"}},
273
- {"op": "push_to", "target_ids": ["stack", "C"], "params": {}},
274
- ],
275
- "covered_concepts": ["pop", "lifo_order"],
276
- "intent": "pop-push-done",
277
- },
278
- ]
279
-
280
-
281
- def _easy_2_mediocre() -> List[Dict[str, Any]]:
282
- return [
283
- {
284
- "step_type": "advance",
285
- "narration": "Performing stack operations.",
286
- "ops": [
287
- {"op": "add_region", "target_ids": ["stk"], "params": {"style": "stack", "title": "Stack"}},
288
- {"op": "add_container", "target_ids": ["stack"], "params": {"region": "stk"}},
289
- {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "stk"}},
290
- {"op": "push_to", "target_ids": ["stack", "A"], "params": {}},
291
- {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "stk"}},
292
- ],
293
- "covered_concepts": ["top_pointer", "push", "pop", "lifo_order"],
294
- "intent": "",
295
- },
296
- {
297
- "step_type": "complete",
298
- "narration": "Stack is ready.",
299
- "ops": [
300
- {"op": "push_to", "target_ids": ["stack", "B"], "params": {}},
301
- ],
302
- "covered_concepts": [],
303
- "intent": "",
304
- },
305
- ]
306
-
307
-
308
- def _easy_2_bad() -> List[Dict[str, Any]]:
309
- return [
310
- {
311
- "step_type": "complete",
312
- "narration": "Stack operations.",
313
- "ops": [],
314
- "covered_concepts": ["top_pointer", "push", "pop", "lifo_order"],
315
- "intent": "",
316
- },
317
- ]
318
-
319
-
320
- # ---------------------------------------------------------------------------
321
- # medium_1: bfs_graph, graph A->B,C; B->D; C->D,E, source=A
322
- # ---------------------------------------------------------------------------
323
-
324
-
325
- def _medium_1_excellent() -> List[Dict[str, Any]]:
326
- return [
327
- {
328
- "step_type": "advance",
329
- "narration": "We draw the directed graph from the input — five nodes A through E with edges showing adjacency.",
330
- "ops": [
331
- {"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
332
- {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
333
- {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
334
- {"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
335
- {"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
336
- {"op": "add_node", "target_ids": ["E"], "params": {"value": "E", "region": "graph"}},
337
- {"op": "add_edge", "target_ids": ["A", "B"], "params": {}},
338
- {"op": "add_edge", "target_ids": ["A", "C"], "params": {}},
339
- {"op": "add_edge", "target_ids": ["B", "D"], "params": {}},
340
- {"op": "add_edge", "target_ids": ["C", "D"], "params": {}},
341
- {"op": "add_edge", "target_ids": ["C", "E"], "params": {}},
342
- ],
343
- "covered_concepts": [],
344
- "intent": "draw-graph",
345
- },
346
- {
347
- "step_type": "advance",
348
- "narration": "We initialize BFS by creating a queue and seeding it with source A — BFS always begins from a single starting node.",
349
- "ops": [
350
- {"op": "add_region", "target_ids": ["bfs_region"], "params": {"style": "queue", "title": "BFS Queue"}},
351
- {"op": "add_container", "target_ids": ["q"], "params": {"region": "bfs_region", "ordered": True}},
352
- {"op": "push_to", "target_ids": ["q", "A"], "params": {}},
353
- {"op": "set_role", "target_ids": ["A"], "params": {"role": "frontier"}},
354
- ],
355
- "covered_concepts": ["queue", "frontier"],
356
- "intent": "init-bfs",
357
- },
358
- {
359
- "step_type": "advance",
360
- "narration": "Dequeue A from the queue and mark it visited — we then push A's unvisited neighbors B and C as frontier nodes.",
361
- "ops": [
362
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
363
- {"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
364
- {"op": "push_to", "target_ids": ["q", "B"], "params": {}},
365
- {"op": "push_to", "target_ids": ["q", "C"], "params": {}},
366
- {"op": "set_role", "target_ids": ["B"], "params": {"role": "frontier"}},
367
- {"op": "set_role", "target_ids": ["C"], "params": {"role": "frontier"}},
368
- ],
369
- "covered_concepts": ["visited_set", "dequeue"],
370
- "intent": "visit-A",
371
- },
372
- {
373
- "step_type": "advance",
374
- "narration": "Dequeue B and visit it — B's neighbor D joins the frontier, demonstrating BFS level-by-level order.",
375
- "ops": [
376
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
377
- {"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
378
- {"op": "push_to", "target_ids": ["q", "D"], "params": {}},
379
- {"op": "set_role", "target_ids": ["D"], "params": {"role": "frontier"}},
380
- ],
381
- "covered_concepts": ["level_order"],
382
- "intent": "visit-B",
383
- },
384
- {
385
- "step_type": "advance",
386
- "narration": "Dequeue C and visit it — C connects to D and E, but D is already in the queue so only E is added.",
387
- "ops": [
388
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
389
- {"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
390
- {"op": "push_to", "target_ids": ["q", "E"], "params": {}},
391
- {"op": "set_role", "target_ids": ["E"], "params": {"role": "frontier"}},
392
- ],
393
- "covered_concepts": [],
394
- "intent": "visit-C",
395
- },
396
- {
397
- "step_type": "complete",
398
- "narration": "All nodes are now visited in BFS order A,B,C,D,E — each level was fully processed before moving deeper.",
399
- "ops": [
400
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
401
- {"op": "set_role", "target_ids": ["D"], "params": {"role": "visited"}},
402
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
403
- {"op": "set_role", "target_ids": ["E"], "params": {"role": "visited"}},
404
- ],
405
- "covered_concepts": [],
406
- "intent": "complete",
407
- },
408
- ]
409
-
410
-
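Review note: the `medium_1` "excellent" trajectory claims the BFS visit order `A, B, C, D, E` for the graph `A->B,C; B->D; C->D,E` with source `A`. The sketch below is a standard queue-based BFS over that adjacency list and can be used to check the claimed order. The function name and the dictionary representation are assumptions made for illustration.

```python
from collections import deque

# Adjacency list taken from the medium_1 scenario comment.
GRAPH = {
    "A": ["B", "C"],
    "B": ["D"],
    "C": ["D", "E"],
    "D": [],
    "E": [],
}


def bfs_order(graph, source):
    """Return nodes in the order they are dequeued by breadth-first search."""
    seen = {source}          # nodes that have been enqueued at least once
    queue = deque([source])
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for neighbor in graph[node]:
            if neighbor not in seen:  # D is skipped when reached again via C
                seen.add(neighbor)
                queue.append(neighbor)
    return order


order = bfs_order(GRAPH, "A")
print(order)
```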
411
- def _medium_1_good() -> List[Dict[str, Any]]:
412
- return [
413
- {
414
- "step_type": "advance",
415
- "narration": "Drawing the graph: nodes A through E with directed edges from the input data.",
416
- "ops": [
417
- {"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
418
- {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
419
- {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
420
- {"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
421
- {"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
422
- {"op": "add_node", "target_ids": ["E"], "params": {"value": "E", "region": "graph"}},
423
- {"op": "add_edge", "target_ids": ["A", "B"], "params": {}},
424
- {"op": "add_edge", "target_ids": ["A", "C"], "params": {}},
425
- {"op": "add_edge", "target_ids": ["B", "D"], "params": {}},
426
- {"op": "add_edge", "target_ids": ["C", "D"], "params": {}},
427
- {"op": "add_edge", "target_ids": ["C", "E"], "params": {}},
428
- ],
429
- "covered_concepts": [],
430
- "intent": "draw-graph",
431
- },
432
- {
433
- "step_type": "advance",
434
- "narration": "Create a BFS queue and enqueue source node A, marking it as frontier.",
435
- "ops": [
436
- {"op": "add_region", "target_ids": ["bfs_region"], "params": {"style": "queue", "title": "BFS Queue"}},
437
- {"op": "add_container", "target_ids": ["q"], "params": {"region": "bfs_region", "ordered": True}},
438
- {"op": "push_to", "target_ids": ["q", "A"], "params": {}},
439
- {"op": "set_role", "target_ids": ["A"], "params": {"role": "frontier"}},
440
- ],
441
- "covered_concepts": ["queue", "frontier"],
442
- "intent": "init",
443
- },
444
- {
445
- "step_type": "advance",
446
- "narration": "Dequeue A, visit it, and enqueue neighbors B and C.",
447
- "ops": [
448
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
449
- {"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
450
- {"op": "push_to", "target_ids": ["q", "B"], "params": {}},
451
- {"op": "push_to", "target_ids": ["q", "C"], "params": {}},
452
- ],
453
- "covered_concepts": ["visited_set", "dequeue"],
454
- "intent": "visit-A",
455
- },
456
- {
457
- "step_type": "advance",
458
- "narration": "Process B then C, adding D and E to the queue as we go.",
459
- "ops": [
460
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
461
- {"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
462
- {"op": "push_to", "target_ids": ["q", "D"], "params": {}},
463
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
464
- {"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
465
- {"op": "push_to", "target_ids": ["q", "E"], "params": {}},
466
- ],
467
- "covered_concepts": ["level_order"],
468
- "intent": "visit-BC",
469
- },
470
- {
471
- "step_type": "complete",
472
- "narration": "D and E are dequeued and visited, finishing BFS traversal.",
473
- "ops": [
474
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
475
- {"op": "set_role", "target_ids": ["D"], "params": {"role": "visited"}},
476
- {"op": "pop_from", "target_ids": ["q"], "params": {}},
477
- {"op": "set_role", "target_ids": ["E"], "params": {"role": "visited"}},
478
- ],
479
- "covered_concepts": [],
480
- "intent": "done",
481
- },
482
- ]
483
-
484
-
485
- def _medium_1_mediocre() -> List[Dict[str, Any]]:
486
- return [
487
- {
488
- "step_type": "advance",
489
- "narration": "Drawing the graph nodes.",
490
- "ops": [
491
- {"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
492
- {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
493
- {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
494
- {"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
495
- {"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
496
- {"op": "add_node", "target_ids": ["E"], "params": {"value": "E", "region": "graph"}},
497
- {"op": "add_edge", "target_ids": ["A", "B"], "params": {}},
498
- {"op": "add_edge", "target_ids": ["A", "C"], "params": {}},
499
- {"op": "add_edge", "target_ids": ["B", "D"], "params": {}},
500
- {"op": "add_edge", "target_ids": ["C", "D"], "params": {}},
501
- {"op": "add_edge", "target_ids": ["C", "E"], "params": {}},
502
- ],
503
- "covered_concepts": [],
504
- "intent": "",
505
- },
506
- {
507
- "step_type": "advance",
508
- "narration": "Setting up BFS.",
509
- "ops": [
510
- {"op": "add_region", "target_ids": ["bfs_region"], "params": {"style": "queue", "title": "BFS"}},
511
- {"op": "add_container", "target_ids": ["q"], "params": {"region": "bfs_region", "ordered": True}},
512
- {"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
513
- {"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
514
- {"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
515
- ],
516
- "covered_concepts": ["queue", "visited_set", "frontier", "level_order", "dequeue"],
517
- "intent": "",
518
- },
519
- {
520
- "step_type": "complete",
521
- "narration": "BFS finished.",
522
- "ops": [
523
- {"op": "set_role", "target_ids": ["D"], "params": {"role": "visited"}},
524
- ],
525
- "covered_concepts": [],
526
- "intent": "",
527
- },
528
- ]
529
-
530
-
531
- def _medium_1_bad() -> List[Dict[str, Any]]:
532
- return [
533
- {
534
- "step_type": "advance",
535
- "narration": "Processing the graph.",
536
- "ops": [
537
- {"op": "add_node", "target_ids": ["A"], "params": {"value": "A"}},
538
- {"op": "add_node", "target_ids": ["B"], "params": {"value": "B"}},
539
- {"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
540
- {"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
541
- ],
542
- "covered_concepts": ["queue", "visited_set", "frontier", "level_order", "dequeue"],
543
- "intent": "",
544
- },
545
- {
546
- "step_type": "complete",
547
- "narration": "Done.",
548
- "ops": [],
549
- "covered_concepts": [],
550
- "intent": "",
551
- },
552
- ]
553
-
554
-
555
- # ---------------------------------------------------------------------------
556
- # hard_1: dijkstra_step,
557
- # graph A->B(1),A->C(4),B->C(2),B->D(5),C->D(1), source=A
558
- # ---------------------------------------------------------------------------
559
-
560
-
561
- def _hard_1_excellent() -> List[Dict[str, Any]]:
562
- return [
563
- {
564
- "step_type": "advance",
565
- "narration": "We draw the weighted directed graph — four nodes A through D with edge weights from the input data.",
566
- "ops": [
567
- {"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Weighted Graph"}},
568
- {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
569
- {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
570
- {"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
571
- {"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
572
- {"op": "add_edge", "target_ids": ["A", "B"], "params": {"label": "1"}},
573
- {"op": "add_edge", "target_ids": ["A", "C"], "params": {"label": "4"}},
574
- {"op": "add_edge", "target_ids": ["B", "C"], "params": {"label": "2"}},
575
- {"op": "add_edge", "target_ids": ["B", "D"], "params": {"label": "5"}},
576
- {"op": "add_edge", "target_ids": ["C", "D"], "params": {"label": "1"}},
577
- ],
578
- "covered_concepts": [],
579
- "intent": "draw-graph",
580
- },
581
- {
582
- "step_type": "advance",
583
- "narration": "Initialize Dijkstra by creating a priority queue and distance table — source A gets distance 0, all others start at infinity.",
584
- "ops": [
585
- {"op": "add_region", "target_ids": ["pq_region"], "params": {"style": "queue", "title": "Priority Queue"}},
586
- {"op": "add_container", "target_ids": ["pq"], "params": {"region": "pq_region", "ordered": True}},
587
- {"op": "push_to", "target_ids": ["pq", "A"], "params": {}},
588
- {"op": "annotate", "target_ids": ["A"], "params": {"text": "d=0"}},
589
- {"op": "set_role", "target_ids": ["A"], "params": {"role": "current"}},
590
- ],
591
- "covered_concepts": ["priority_queue", "distance_table"],
592
- "intent": "init",
593
- },
594
- {
595
- "step_type": "advance",
596
- "narration": "Extract A (d=0) from the priority queue — Dijkstra guarantees this distance is final. Relax edges to B and C: d[B]=1, d[C]=4.",
597
- "ops": [
598
- {"op": "pop_from", "target_ids": ["pq"], "params": {}},
599
- {"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
600
- {"op": "set_value", "target_ids": ["B"], "params": {"value": 1}},
601
- {"op": "annotate", "target_ids": ["B"], "params": {"text": "d=1"}},
602
- {"op": "push_to", "target_ids": ["pq", "B"], "params": {}},
603
- {"op": "set_value", "target_ids": ["C"], "params": {"value": 4}},
604
- ],
605
- "covered_concepts": ["relaxation", "shortest_path_invariant"],
606
- "intent": "visit-A",
607
- },
608
- {
609
- "step_type": "advance",
610
- "narration": "Extract B (d=1) — its shortest distance is now permanent. Relaxing B's edges: d[C] improves from 4 to 3 via B, and d[D]=6.",
611
- "ops": [
612
- {"op": "pop_from", "target_ids": ["pq"], "params": {}},
613
- {"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
614
- {"op": "annotate", "target_ids": ["B"], "params": {"text": "d=1 final"}},
615
- {"op": "set_value", "target_ids": ["C"], "params": {"value": 3}},
616
- {"op": "annotate", "target_ids": ["C"], "params": {"text": "d=3"}},
617
- {"op": "push_to", "target_ids": ["pq", "C"], "params": {}},
618
- ],
619
- "covered_concepts": ["visited_set"],
620
- "intent": "visit-B",
621
- },
622
- {
623
- "step_type": "complete",
624
- "narration": "Extract C (d=3), relax to D: d[D] improves to 4. Then D is extracted with final distance 4. All shortest paths found.",
625
- "ops": [
626
- {"op": "pop_from", "target_ids": ["pq"], "params": {}},
627
- {"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
628
- {"op": "set_value", "target_ids": ["D"], "params": {"value": 4}},
629
- {"op": "annotate", "target_ids": ["D"], "params": {"text": "d=4"}},
630
- {"op": "set_role", "target_ids": ["D"], "params": {"role": "done"}},
631
- ],
632
- "covered_concepts": [],
633
- "intent": "complete",
634
- },
635
- ]
636
-
637
-
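Review note: the `hard_1` narrations state that, from source `A`, the final distances are `d[B]=1`, `d[C]=3` (improved from 4 via `B`) and `d[D]=4` (improved from 6 via `C`). A minimal `heapq`-based Dijkstra over the same weighted graph reproduces those numbers. This is a textbook sketch for checking the narration, not the environment's invariant checker, and the names used are assumptions.

```python
import heapq

# Weighted edges from the hard_1 scenario comment.
GRAPH = {
    "A": [("B", 1), ("C", 4)],
    "B": [("C", 2), ("D", 5)],
    "C": [("D", 1)],
    "D": [],
}


def dijkstra(graph, source):
    """Return (distances, settle order) using a binary heap with lazy deletion."""
    dist = {node: float("inf") for node in graph}
    dist[source] = 0
    heap = [(0, source)]
    settled = []
    while heap:
        d, node = heapq.heappop(heap)
        if d > dist[node]:
            continue  # stale heap entry left over from an earlier relaxation
        settled.append(node)
        for neighbor, weight in graph[node]:
            candidate = d + weight
            if candidate < dist[neighbor]:  # relaxation step
                dist[neighbor] = candidate
                heapq.heappush(heap, (candidate, neighbor))
    return dist, settled


dist, settled = dijkstra(GRAPH, "A")
print(dist, settled)
```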
638
- def _hard_1_good() -> List[Dict[str, Any]]:
639
- return [
640
- {
641
- "step_type": "advance",
642
- "narration": "Drawing the weighted graph with nodes A, B, C, D and their edge weights.",
643
- "ops": [
644
-                 {"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
-                 {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
-                 {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
-                 {"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
-                 {"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
-                 {"op": "add_edge", "target_ids": ["A", "B"], "params": {"label": "1"}},
-                 {"op": "add_edge", "target_ids": ["A", "C"], "params": {"label": "4"}},
-                 {"op": "add_edge", "target_ids": ["B", "C"], "params": {"label": "2"}},
-                 {"op": "add_edge", "target_ids": ["B", "D"], "params": {"label": "5"}},
-                 {"op": "add_edge", "target_ids": ["C", "D"], "params": {"label": "1"}},
-             ],
-             "covered_concepts": [],
-             "intent": "draw-graph",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Set up a priority queue for Dijkstra and initialize source A with distance 0.",
-             "ops": [
-                 {"op": "add_region", "target_ids": ["pq_region"], "params": {"style": "queue", "title": "PQ"}},
-                 {"op": "add_container", "target_ids": ["pq"], "params": {"region": "pq_region", "ordered": True}},
-                 {"op": "push_to", "target_ids": ["pq", "A"], "params": {}},
-                 {"op": "annotate", "target_ids": ["A"], "params": {"text": "d=0"}},
-             ],
-             "covered_concepts": ["priority_queue", "distance_table"],
-             "intent": "init",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Extract A, relax edges to B (d=1) and C (d=4).",
-             "ops": [
-                 {"op": "pop_from", "target_ids": ["pq"], "params": {}},
-                 {"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
-                 {"op": "set_value", "target_ids": ["B"], "params": {"value": 1}},
-                 {"op": "push_to", "target_ids": ["pq", "B"], "params": {}},
-                 {"op": "set_value", "target_ids": ["C"], "params": {"value": 4}},
-             ],
-             "covered_concepts": ["relaxation"],
-             "intent": "visit-A",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Extract B, update C's distance to 3 via B. Push C into the queue.",
-             "ops": [
-                 {"op": "pop_from", "target_ids": ["pq"], "params": {}},
-                 {"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
-                 {"op": "set_value", "target_ids": ["C"], "params": {"value": 3}},
-                 {"op": "push_to", "target_ids": ["pq", "C"], "params": {}},
-             ],
-             "covered_concepts": ["visited_set", "shortest_path_invariant"],
-             "intent": "visit-B",
-         },
-         {
-             "step_type": "complete",
-             "narration": "Extract C, relax D to distance 4. All shortest paths computed.",
-             "ops": [
-                 {"op": "pop_from", "target_ids": ["pq"], "params": {}},
-                 {"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
-                 {"op": "set_value", "target_ids": ["D"], "params": {"value": 4}},
-                 {"op": "set_role", "target_ids": ["D"], "params": {"role": "done"}},
-             ],
-             "covered_concepts": [],
-             "intent": "complete",
-         },
-     ]
-
-
- def _hard_1_mediocre() -> List[Dict[str, Any]]:
-     return [
-         {
-             "step_type": "advance",
-             "narration": "Drawing the graph for Dijkstra.",
-             "ops": [
-                 {"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
-                 {"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
-                 {"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
-                 {"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
-                 {"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
-                 {"op": "add_edge", "target_ids": ["A", "B"], "params": {"label": "1"}},
-                 {"op": "add_edge", "target_ids": ["A", "C"], "params": {"label": "4"}},
-                 {"op": "add_edge", "target_ids": ["B", "C"], "params": {"label": "2"}},
-                 {"op": "add_edge", "target_ids": ["B", "D"], "params": {"label": "5"}},
-                 {"op": "add_edge", "target_ids": ["C", "D"], "params": {"label": "1"}},
-             ],
-             "covered_concepts": [],
-             "intent": "",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Running Dijkstra on the graph.",
-             "ops": [
-                 {"op": "add_region", "target_ids": ["pq_region"], "params": {"style": "queue", "title": "PQ"}},
-                 {"op": "add_container", "target_ids": ["pq"], "params": {"region": "pq_region", "ordered": True}},
-                 {"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
-                 {"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
-                 {"op": "set_value", "target_ids": ["D"], "params": {"value": 4}},
-             ],
-             "covered_concepts": ["priority_queue", "distance_table", "relaxation", "visited_set", "shortest_path_invariant"],
-             "intent": "",
-         },
-         {
-             "step_type": "complete",
-             "narration": "Shortest paths found.",
-             "ops": [
-                 {"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
-             ],
-             "covered_concepts": [],
-             "intent": "",
-         },
-     ]
-
-
- def _hard_1_bad() -> List[Dict[str, Any]]:
-     return [
-         {
-             "step_type": "advance",
-             "narration": "Starting Dijkstra.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["A"], "params": {"value": "A"}},
-                 {"op": "set_role", "target_ids": ["A"], "params": {"role": "done"}},
-             ],
-             "covered_concepts": ["priority_queue", "distance_table", "relaxation", "visited_set", "shortest_path_invariant"],
-             "intent": "",
-         },
-         {
-             "step_type": "complete",
-             "narration": "Dijkstra done.",
-             "ops": [],
-             "covered_concepts": [],
-             "intent": "",
-         },
-     ]
-
-
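The distances narrated in the hard_1 scripts (B=1, C=3 via B, D=4 via C) can be checked independently of the environment. The following is a minimal sketch, assuming the five `add_edge` ops describe an undirected weighted graph whose weights are the edge labels; `EDGES` and `dijkstra` are illustrative names and are not part of the deleted script.

```python
import heapq

# Edge weights copied from the add_edge ops in the hard_1 scripts.
EDGES = [("A", "B", 1), ("A", "C", 4), ("B", "C", 2), ("B", "D", 5), ("C", "D", 1)]


def dijkstra(edges, source):
    """Return shortest distances from source (assumption: undirected edges)."""
    graph = {}
    for u, v, w in edges:
        graph.setdefault(u, []).append((v, w))
        graph.setdefault(v, []).append((u, w))

    dist = {source: 0}
    heap = [(0, source)]
    visited = set()
    while heap:
        d, u = heapq.heappop(heap)
        if u in visited:
            continue  # stale queue entry, a shorter path was already settled
        visited.add(u)
        for v, w in graph[u]:
            nd = d + w
            if nd < dist.get(v, float("inf")):
                dist[v] = nd  # relaxation step
                heapq.heappush(heap, (nd, v))
    return dist


distances = dijkstra(EDGES, "A")
print(distances)
```

The result agrees with the `set_value` ops in the scripted steps: C is first set to 4 and later lowered to 3 once B is extracted.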
- # ---------------------------------------------------------------------------
- # hard_2: bst_insert, incremental, keys=[5,3,7,2,4,8]
- # ---------------------------------------------------------------------------
-
-
- def _hard_2_excellent() -> List[Dict[str, Any]]:
-     return [
-         {
-             "step_type": "advance",
-             "narration": "Create a tree region and insert 5 as the root node — the first key always becomes the root of the BST.",
-             "ops": [
-                 {"op": "add_region", "target_ids": ["tree"], "params": {"style": "tree", "title": "BST", "root": "n5"}},
-                 {"op": "add_node", "target_ids": ["n5"], "params": {"value": 5, "region": "tree"}},
-                 {"op": "set_role", "target_ids": ["n5"], "params": {"role": "root"}},
-             ],
-             "covered_concepts": ["root_node"],
-             "intent": "insert-root",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Insert 3: comparing with root 5, 3 < 5 so it goes left — this maintains the BST invariant where left children are smaller.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n3"], "params": {"value": 3, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n5", "n3"], "params": {"label": "L"}},
-                 {"op": "set_role", "target_ids": ["n5"], "params": {"role": "comparing"}},
-             ],
-             "covered_concepts": ["bst_invariant", "left_subtree"],
-             "intent": "insert-3",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Insert 7: comparing with root 5, 7 > 5 so it goes right — the right subtree holds all keys greater than the parent.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n7"], "params": {"value": 7, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n5", "n7"], "params": {"label": "R"}},
-                 {"op": "set_role", "target_ids": ["n5"], "params": {"role": "root"}},
-             ],
-             "covered_concepts": ["right_subtree"],
-             "intent": "insert-7",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Insert 2 and 4: recursively comparing, 2 < 3 goes left of n3, and 4 > 3 goes right of n3.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n2"], "params": {"value": 2, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n3", "n2"], "params": {"label": "L"}},
-                 {"op": "add_node", "target_ids": ["n4"], "params": {"value": 4, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n3", "n4"], "params": {"label": "R"}},
-             ],
-             "covered_concepts": ["recursive_insert"],
-             "intent": "insert-2-4",
-         },
-         {
-             "step_type": "complete",
-             "narration": "Insert 8: 8 > 5 go right, 8 > 7 go right — each insertion recursively finds the correct leaf position.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n8"], "params": {"value": 8, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n7", "n8"], "params": {"label": "R"}},
-             ],
-             "covered_concepts": [],
-             "intent": "insert-8-complete",
-         },
-     ]
-
-
- def _hard_2_good() -> List[Dict[str, Any]]:
-     return [
-         {
-             "step_type": "advance",
-             "narration": "Create a BST tree region and insert 5 as root.",
-             "ops": [
-                 {"op": "add_region", "target_ids": ["tree"], "params": {"style": "tree", "title": "BST", "root": "n5"}},
-                 {"op": "add_node", "target_ids": ["n5"], "params": {"value": 5, "region": "tree"}},
-                 {"op": "set_role", "target_ids": ["n5"], "params": {"role": "root"}},
-             ],
-             "covered_concepts": ["root_node"],
-             "intent": "root",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Insert 3 left of 5 and 7 right of 5, maintaining BST invariant.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n3"], "params": {"value": 3, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n5", "n3"], "params": {"label": "L"}},
-                 {"op": "add_node", "target_ids": ["n7"], "params": {"value": 7, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n5", "n7"], "params": {"label": "R"}},
-             ],
-             "covered_concepts": ["bst_invariant", "left_subtree", "right_subtree"],
-             "intent": "insert-3-7",
-         },
-         {
-             "step_type": "advance",
-             "narration": "Recursively insert 2 left of 3 and 4 right of 3.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n2"], "params": {"value": 2, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n3", "n2"], "params": {"label": "L"}},
-                 {"op": "add_node", "target_ids": ["n4"], "params": {"value": 4, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n3", "n4"], "params": {"label": "R"}},
-             ],
-             "covered_concepts": ["recursive_insert"],
-             "intent": "insert-2-4",
-         },
-         {
-             "step_type": "complete",
-             "narration": "Insert 8 right of 7 to complete the BST.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n8"], "params": {"value": 8, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n7", "n8"], "params": {"label": "R"}},
-             ],
-             "covered_concepts": [],
-             "intent": "done",
-         },
-     ]
-
-
- def _hard_2_mediocre() -> List[Dict[str, Any]]:
-     return [
-         {
-             "step_type": "advance",
-             "narration": "Building the BST.",
-             "ops": [
-                 {"op": "add_region", "target_ids": ["tree"], "params": {"style": "tree", "title": "BST"}},
-                 {"op": "add_node", "target_ids": ["n5"], "params": {"value": 5, "region": "tree"}},
-                 {"op": "add_node", "target_ids": ["n3"], "params": {"value": 3, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n5", "n3"], "params": {}},
-                 {"op": "add_node", "target_ids": ["n7"], "params": {"value": 7, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n5", "n7"], "params": {}},
-             ],
-             "covered_concepts": ["root_node", "bst_invariant", "left_subtree", "right_subtree", "recursive_insert"],
-             "intent": "",
-         },
-         {
-             "step_type": "complete",
-             "narration": "Tree built.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n2"], "params": {"value": 2, "region": "tree"}},
-                 {"op": "add_edge", "target_ids": ["n3", "n2"], "params": {}},
-             ],
-             "covered_concepts": [],
-             "intent": "",
-         },
-     ]
-
-
- def _hard_2_bad() -> List[Dict[str, Any]]:
-     return [
-         {
-             "step_type": "advance",
-             "narration": "Inserting keys.",
-             "ops": [
-                 {"op": "add_node", "target_ids": ["n5"], "params": {"value": 5}},
-             ],
-             "covered_concepts": ["root_node", "bst_invariant", "left_subtree", "right_subtree", "recursive_insert"],
-             "intent": "",
-         },
-         {
-             "step_type": "complete",
-             "narration": "BST done.",
-             "ops": [],
-             "covered_concepts": [],
-             "intent": "",
-         },
-     ]
-
-
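The parent/child edges used by the hard_2 "excellent" and "good" scripts follow directly from inserting the keys in order. The following is a minimal sketch, assuming distinct keys and the `n<key>` node-id convention seen in the ops above; `bst_edges` is an illustrative helper, not a function from the repository.

```python
def bst_edges(keys):
    """Insert keys in order and return (parent_id, child_id, side) per insertion.

    Assumption: keys are distinct; a key equal to a node would go right.
    """
    left, right = {}, {}  # child maps keyed by parent value
    root = keys[0]
    edges = []
    for key in keys[1:]:
        node = root
        while True:
            # Smaller keys descend left, larger keys descend right.
            side, children = ("L", left) if key < node else ("R", right)
            if node in children:
                node = children[node]  # keep descending
            else:
                children[node] = key  # free slot found: attach as a leaf
                edges.append((f"n{node}", f"n{key}", side))
                break
    return edges


edges = bst_edges([5, 3, 7, 2, 4, 8])
for parent, child, side in edges:
    print(f"add_edge[{parent},{child}] label={side}")
```

Each tuple corresponds to one `add_edge` op with its `L`/`R` label, so a script that omits one of these edges (as the "mediocre" and "bad" variants do) leaves the tree incomplete.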
- # ---------------------------------------------------------------------------
- # scenario registry
- # ---------------------------------------------------------------------------
-
-
- def get_all_scripted_scenarios() -> Dict[str, Dict[str, List[Dict[str, Any]]]]:
-     return {
-         "easy_1": {
-             "excellent": _easy_1_excellent(),
-             "good": _easy_1_good(),
-             "mediocre": _easy_1_mediocre(),
-             "bad": _easy_1_bad(),
-         },
-         "easy_2": {
-             "excellent": _easy_2_excellent(),
-             "good": _easy_2_good(),
-             "mediocre": _easy_2_mediocre(),
-             "bad": _easy_2_bad(),
-         },
-         "medium_1": {
-             "excellent": _medium_1_excellent(),
-             "good": _medium_1_good(),
-             "mediocre": _medium_1_mediocre(),
-             "bad": _medium_1_bad(),
-         },
-         "hard_1": {
-             "excellent": _hard_1_excellent(),
-             "good": _hard_1_good(),
-             "mediocre": _hard_1_mediocre(),
-             "bad": _hard_1_bad(),
-         },
-         "hard_2": {
-             "excellent": _hard_2_excellent(),
-             "good": _hard_2_good(),
-             "mediocre": _hard_2_mediocre(),
-             "bad": _hard_2_bad(),
-         },
-     }
-
-
- # ---------------------------------------------------------------------------
- # rollout runner
- # ---------------------------------------------------------------------------
-
-
- def run_scripted_rollout(
-     env: VisualReasoningEnvironment,
-     scenario_id: str,
-     actions: List[Dict[str, Any]],
- ) -> List[Dict[str, Any]]:
-     obs = env.reset(scenario_id=scenario_id)
-     template = obs.task_name
-     steps: List[Dict[str, Any]] = []
-     prev_narration = ""
-
-     for i, action_dict in enumerate(actions):
-         action = VisualReasoningAction(**action_dict)
-         obs = env.step(action)
-
-         ops_parts: List[str] = []
-         for op in action_dict.get("ops") or []:
-             tids = ",".join(op.get("target_ids") or [])
-             ops_parts.append(f"{op['op']}[{tids}]")
-         ops_summary = ", ".join(ops_parts)
-
-         checklist = list(obs.concept_checklist)
-         covered = list(obs.concept_coverage)
-
-         steps.append({
-             "scenario_id": scenario_id,
-             "template": template,
-             "step_id": i + 1,
-             "narration": action_dict.get("narration", ""),
-             "ops_summary": ops_summary,
-             "prev_narration": prev_narration,
-             "covered_concepts": action_dict.get("covered_concepts", []),
-             "step_progress": f"step {i + 1}, {len(covered)}/{len(checklist)} concepts",
-             "score_breakdown": dict(obs.score_breakdown),
-             "reward": obs.reward,
-             "action": action_dict,
-         })
-         prev_narration = action_dict.get("narration", "")
-         if obs.done:
-             break
-
-     return steps
-
-
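The `ops_summary` string built inside `run_scripted_rollout` is what the labeling prompt later sees in place of the raw op dicts. The following is a minimal standalone sketch of that formatting, mirroring the loop above; `summarize_ops` is an illustrative name introduced here, and the sample action is the final step of the hard_2 "good" script.

```python
from typing import Any, Dict, List


def summarize_ops(action_dict: Dict[str, Any]) -> str:
    """Render each op as `name[target,target]`, joined by commas."""
    parts: List[str] = []
    for op in action_dict.get("ops") or []:  # tolerate a missing or None "ops"
        tids = ",".join(op.get("target_ids") or [])
        parts.append(f"{op['op']}[{tids}]")
    return ", ".join(parts)


action = {
    "step_type": "complete",
    "narration": "Insert 8 right of 7 to complete the BST.",
    "ops": [
        {"op": "add_node", "target_ids": ["n8"], "params": {"value": 8, "region": "tree"}},
        {"op": "add_edge", "target_ids": ["n7", "n8"], "params": {"label": "R"}},
    ],
}
summary = summarize_ops(action)
print(summary)
```

An action with no ops yields an empty string, which `build_labeling_prompt` replaces with `(none)`.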
- # ---------------------------------------------------------------------------
- # labeling prompt builder
- # ---------------------------------------------------------------------------
-
-
- def build_labeling_prompt(step: Dict[str, Any]) -> str:
-     return LABELING_PROMPT.format(
-         template=step["template"],
-         ops_summary=step["ops_summary"] or "(none)",
-         prev_narration=step["prev_narration"] or "(none)",
-         narration=step["narration"],
-         covered_concepts=", ".join(step["covered_concepts"]) or "(none)",
-         step_progress=step["step_progress"],
-     )
-
-
- # ---------------------------------------------------------------------------
- # main
- # ---------------------------------------------------------------------------
-
-
- def main() -> None:
-     parser = argparse.ArgumentParser(description="Generate rubric training data")
-     parser.add_argument("--output-dir", default=str(Path(__file__).parent / "output"))
-     args = parser.parse_args()
-     os.makedirs(args.output_dir, exist_ok=True)
-
-     env = VisualReasoningEnvironment()
-     all_steps: List[Dict[str, Any]] = []
-
-     scenarios = get_all_scripted_scenarios()
-
-     for scenario_id, quality_map in scenarios.items():
-         for quality, actions in quality_map.items():
-             try:
-                 steps = run_scripted_rollout(env, scenario_id, actions)
-             except Exception as exc:
-                 print(f"WARN: {scenario_id}/{quality} failed: {exc}")
-                 continue
-             for s in steps:
-                 s["quality_level"] = quality
-             all_steps.extend(steps)
-
-     rollout_path = os.path.join(args.output_dir, "rollout_data.jsonl")
-     with open(rollout_path, "w", encoding="utf-8") as f:
-         for step in all_steps:
-             f.write(json.dumps(step, default=str) + "\n")
-
-     prompts_path = os.path.join(args.output_dir, "labeling_prompts.jsonl")
-     with open(prompts_path, "w", encoding="utf-8") as f:
-         for step in all_steps:
-             record = {
-                 "scenario_id": step["scenario_id"],
-                 "quality_level": step["quality_level"],
-                 "step_id": step["step_id"],
-                 "prompt": build_labeling_prompt(step),
-             }
-             f.write(json.dumps(record) + "\n")
-
-     quality_stats: Dict[str, Dict[str, Any]] = {}
-     for step in all_steps:
-         ql = step["quality_level"]
-         if ql not in quality_stats:
-             quality_stats[ql] = {"count": 0, "total_reward": 0.0, "scores": []}
-         quality_stats[ql]["count"] += 1
-         quality_stats[ql]["total_reward"] += step["reward"]
-         quality_stats[ql]["scores"].append(
-             step["score_breakdown"].get("overall_score", 0.0)
-         )
-
-     summary: Dict[str, Any] = {
-         "total_steps": len(all_steps),
-         "scenarios": len(scenarios),
-         "quality_levels": {},
-     }
-     for ql, stats in quality_stats.items():
-         scores = stats["scores"]
-         summary["quality_levels"][ql] = {
-             "step_count": stats["count"],
-             "avg_reward": round(stats["total_reward"] / max(1, stats["count"]), 4),
-             "avg_score": round(sum(scores) / max(1, len(scores)), 4),
-             "min_score": round(min(scores) if scores else 0.0, 4),
-             "max_score": round(max(scores) if scores else 0.0, 4),
-         }
-
-     summary_path = os.path.join(args.output_dir, "rubric_summary.json")
-     with open(summary_path, "w", encoding="utf-8") as f:
-         json.dump(summary, f, indent=2)
-
-     print(f"Generated {len(all_steps)} steps across {len(scenarios)} scenarios")
-     print(f"Output: {args.output_dir}/")
-     for ql in ("excellent", "good", "mediocre", "bad"):
-         if ql in summary["quality_levels"]:
-             info = summary["quality_levels"][ql]
-             print(
-                 f" {ql:>10}: {info['step_count']} steps, "
-                 f"avg_score={info['avg_score']:.4f}, "
-                 f"avg_reward={info['avg_reward']:.4f}"
-             )
-
-
- if __name__ == "__main__":
-     main()
server/app.py CHANGED
@@ -43,7 +43,10 @@ try:
43
  from .scenario_loader import load_scenarios
44
  from .scoring import weights_for_difficulty
45
  from .constants import (
46
- ALLOWED_OPS, ROLE_VALUES, REGION_STYLES, NAMED_POSITIONS,
 
 
 
47
  )
48
  except ImportError:
49
  from models import VisualReasoningAction, VisualReasoningObservation
@@ -51,7 +54,10 @@ except ImportError:
51
  from server.scenario_loader import load_scenarios
52
  from server.scoring import weights_for_difficulty
53
  from server.constants import (
54
- ALLOWED_OPS, ROLE_VALUES, REGION_STYLES, NAMED_POSITIONS,
 
 
 
55
  )
56
 
57
 
@@ -97,7 +103,12 @@ def _scenario_display_name(s: Dict[str, Any]) -> str:
97
 
98
 
99
  def _tier_emoji(t: str) -> str:
100
- return {"easy": "\U0001f7e2", "medium": "\U0001f7e1", "hard": "\U0001f7e0", "expert": "\U0001f534"}.get(t, "⚪")
 
 
 
 
 
101
 
102
 
103
  def _tier_label(t: str) -> str:
@@ -108,6 +119,7 @@ def _tier_label(t: str) -> str:
108
  # Broadcaster (WebSocket fan-out for live viewer)
109
  # ---------------------------------------------------------------------------
110
 
 
111
  class Broadcaster:
112
  """Append-only message log with long-poll support."""
113
 
@@ -162,6 +174,7 @@ _broadcaster = Broadcaster()
162
  # TTS via fal-ai Kokoro
163
  # ---------------------------------------------------------------------------
164
 
 
165
  async def _tts_to_base64(text: str) -> Optional[str]:
166
  if not text or not text.strip():
167
  return None
@@ -169,6 +182,7 @@ async def _tts_to_base64(text: str) -> Optional[str]:
169
  return None
170
  try:
171
  import aiohttp
 
172
  async with aiohttp.ClientSession() as session:
173
  async with session.post(
174
  "https://fal.run/fal-ai/kokoro",
@@ -181,14 +195,22 @@ async def _tts_to_base64(text: str) -> Optional[str]:
181
  ) as resp:
182
  if resp.status != 200:
183
  body = await resp.text()
184
- print(f"[TTS] fal-ai error {resp.status}: {body}", file=sys.stderr, flush=True)
 
 
 
 
185
  return None
186
  data = await resp.json()
187
  audio_url = data.get("audio", {}).get("url")
188
  if not audio_url:
189
- print("[TTS] no audio URL in fal-ai response", file=sys.stderr, flush=True)
 
 
190
  return None
191
- async with session.get(audio_url, timeout=aiohttp.ClientTimeout(total=15)) as audio_resp:
 
 
192
  audio_bytes = await audio_resp.read()
193
  result = base64.b64encode(audio_bytes).decode("ascii")
194
  print(f"[TTS] generated {len(audio_bytes)} bytes of audio", flush=True)
@@ -211,6 +233,7 @@ def _get_llm_client(api_key: str = ""):
211
  key = api_key or LLM_API_KEY
212
  if _openai_client is None or key != _openai_client_key:
213
  from openai import OpenAI
 
214
  _openai_client = OpenAI(base_url=LLM_API_BASE, api_key=key)
215
  _openai_client_key = key
216
  return _openai_client
@@ -219,6 +242,7 @@ def _get_llm_client(api_key: str = ""):
219
  def _build_system_prompt() -> str:
220
  try:
221
  from inference import SYSTEM_PROMPT
 
222
  return SYSTEM_PROMPT
223
  except Exception:
224
  return (
@@ -227,9 +251,12 @@ def _build_system_prompt() -> str:
227
  )
228
 
229
 
230
- def _build_user_prompt(obs: Any, last_action: Optional[Dict], last_reward: float, history: List[str]) -> str:
 
 
231
  try:
232
  from inference import build_user_prompt
 
233
  return build_user_prompt(obs, last_action, last_reward, history)
234
  except Exception:
235
  return f"Goal: {obs.goal}\nEntities: {list(obs.entities.keys())}\nRemaining steps: {obs.remaining_step_budget}"
@@ -245,10 +272,12 @@ def _strip_thinking_tokens(text: str) -> str:
245
  def _parse_llm_response(text: str) -> Optional[Dict[str, Any]]:
246
  try:
247
  from inference import parse_action, normalize_action
 
248
  parsed = parse_action(text)
249
  return normalize_action(parsed or {})
250
  except Exception:
251
  import re
 
252
  match = re.search(r"\{.*\}", text, flags=re.DOTALL)
253
  if match:
254
  try:
@@ -315,18 +344,20 @@ async def _run_demo(scenario_id: str, api_key: str = "", model_name: str = "") -
315
 
316
  goal_audio = await _tts_to_base64(obs.goal)
317
  snap = _obs_snapshot(obs)
318
- await _broadcaster.send({
319
- "type": "reset",
320
- "task_name": obs.task_name,
321
- "scenario_id": obs.scenario_id,
322
- "goal": obs.goal,
323
- "checklist": list(obs.concept_checklist),
324
- "input_data": dict(obs.input_data),
325
- "constraints": list(obs.constraints),
326
- "max_steps": obs.max_steps,
327
- "audio": goal_audio,
328
- **snap,
329
- })
 
 
330
 
331
  await asyncio.sleep(2.0)
332
 
@@ -349,7 +380,10 @@ async def _run_demo(scenario_id: str, api_key: str = "", model_name: str = "") -
349
  temperature=LLM_TEMPERATURE,
350
  max_tokens=LLM_MAX_TOKENS,
351
  stream=False,
352
- ).choices[0].message.content or ""
 
 
 
353
  )
354
  text = _strip_thinking_tokens(text)
355
  print(f"[DEMO] Step {step}: LLM returned {len(text)} chars", flush=True)
@@ -390,23 +424,25 @@ async def _run_demo(scenario_id: str, api_key: str = "", model_name: str = "") -
390
  history.append(f"Step {step}: {action_dict.get('narration', '')}")
391
 
392
  snap = _obs_snapshot(obs)
393
- await _broadcaster.send({
394
- "type": "step",
395
- "task_name": obs_dict.get("task_name", ""),
396
- "scenario_id": obs_dict.get("scenario_id", ""),
397
- "step": step,
398
- "step_type": action_dict.get("step_type"),
399
- "intent": action_dict.get("intent", ""),
400
- "narration": narration,
401
- "ops": action_dict.get("ops", []),
402
- "covered_concepts": action_dict.get("covered_concepts", []),
403
- "reward": float(reward),
404
- "score": float(overall),
405
- "done": bool(done),
406
- "error": error,
407
- "audio": audio_b64,
408
- **snap,
409
- })
 
 
410
 
411
  if done:
412
  await asyncio.sleep(2.0)
@@ -418,20 +454,23 @@ async def _run_demo(scenario_id: str, api_key: str = "", model_name: str = "") -
418
  print(f"[DEMO] error: {exc}", file=sys.stderr, flush=True)
419
  traceback.print_exc()
420
 
421
- await _broadcaster.send({
422
- "type": "end",
423
- "task_name": scenario_id,
424
- "success": score >= 0.65,
425
- "steps": steps_taken,
426
- "score": float(score),
427
- "rewards": [float(r) for r in rewards],
428
- })
 
 
429
 
430
 
431
  # ---------------------------------------------------------------------------
432
  # Scenario browser callbacks
433
  # ---------------------------------------------------------------------------
434
 
 
435
  def list_scenario_choices() -> List[str]:
436
  return [_scenario_display_name(s) for s in _get_scenarios()]
437
 
@@ -460,7 +499,9 @@ def show_scenario(choice: str) -> tuple:
460
 
461
  checklist_md = "\n".join(f"- `{c}`" for c in checklist)
462
 
463
- constraints_md = "\n".join(f"- `{c}`" for c in constraints) if constraints else "_None_"
 
 
464
 
465
  return header, input_md, checklist_md, constraints_md
466
  return "Scenario not found.", "", "", ""
@@ -470,6 +511,7 @@ def show_scenario(choice: str) -> tuple:
470
  # Scoring explorer callback
471
  # ---------------------------------------------------------------------------
472
 
 
473
  def show_weights(difficulty: str) -> str:
474
  d = difficulty.lower()
475
  w = weights_for_difficulty(d)
@@ -486,7 +528,10 @@ def show_weights(difficulty: str) -> str:
486
  # Live demo callbacks (Gradio)
487
  # ---------------------------------------------------------------------------
488
 
489
- async def _start_live_demo(scenario_choice: str, hf_token: str = "", model_name: str = "") -> str:
 
 
 
490
  global _demo_task
491
  if not scenario_choice:
492
  return "Select a scenario first."
@@ -523,7 +568,10 @@ async def _stop_live_demo() -> str:
523
  # Gradio UI (custom builder for openenv's gradio_builder parameter)
524
  # ---------------------------------------------------------------------------
525
 
526
- def build_ui(web_manager, action_fields, metadata, is_chat_env, title, quick_start_md) -> gr.Blocks:
 
 
 
527
  """Custom Gradio UI builder for openenv's gradio_builder parameter."""
528
  with gr.Blocks(
529
  title="Visual Reasoning Environment",
@@ -537,7 +585,8 @@ def build_ui(web_manager, action_fields, metadata, is_chat_env, title, quick_sta
537
  # TAB 1: About (blog-style, mirrors README)
538
  # ============================================================
539
  with gr.Tab("About"):
540
- gr.Markdown("""
 
541
  ## The Problem Nobody Talks About
542
 
543
  Here's a question: *How do you teach a machine to teach?*
@@ -559,9 +608,11 @@ a visual explanation, step by step, where each step advances the algorithm AND a
559
  the learner's understanding.**
560
 
561
  That's the gap this project fills.
562
- """)
 
563
 
564
- gr.Markdown("""
 
565
  ## What This Is
566
 
567
  The Visual Reasoning Environment is an
@@ -575,7 +626,8 @@ correctness, concept coverage, narration quality, and teaching pedagogy.
575
  Think of it this way: you're not training the model to *know* BFS.
576
  You're training it to *teach* BFS the way the best professor you ever had would --
577
  with a marker in hand and an audience that needs to follow along.
578
- """)
 
579
 
580
  gr.HTML(
581
  '<a href="https://youtu.be/KwWqjuyfWzw" target="_blank">'
@@ -584,7 +636,8 @@ with a marker in hand and an audience that needs to follow along.
584
  'alt="Watch the Demo"/></a>'
585
  )
586
 
587
- gr.Markdown("""
 
588
  ## What the Agent Does
589
 
590
  Every scenario starts with an **empty canvas**. Nothing is drawn.
@@ -682,33 +735,32 @@ unsupported concept claims (0.30), too many ops (0.50), and info dumps (0.20).
682
  ## The Reinforcement Learning Loop
683
 
684
  ```
- +---------------------------------------------------------------------+
- |                     TRAINING LOOP (GRPO / RLVR)                     |
- |                                                                     |
- |  +-----------+    prompt      +--------------+   JSON action        |
- |  |           |  ------------->|              | ------------------+  |
- |  | Scenario  |                |  LLM Agent   |                   |  |
- |  | Generator |                |  (Teacher)   |                   |  |
- |  |           |     +--------->|              |<----------+       |  |
- |  +-----------+     |          +--------------+           |       |  |
- |                    |                                     |       |  |
- |               observation                             reward     |  |
- |            + score breakdown                          signal     |  |
- |                    |                                     |       |  |
- |           +--------+---------+        score         +----+---+   |  |
- |           |                  | <------------------- |        |   |  |
- |           |   Environment    |                      |Scoring |   |  |
- |           |  (Empty Canvas)  | ------------------>  | Engine |   |  |
- |           |                  |     canvas state     |(13 dim)|   |  |
- |           +------------------+                      +--------+   |  |
- |                    ^                                             |  |
- |                    |            step(action)                     |  |
- |                    +---------------------------------------------+  |
- |                                                                     |
- |  Per-step reward = delta(overall_score) + penalties + bonuses       |
- |  Episode: empty canvas --> Phase 1 (draw) --> Phase 2 (solve)       |
- |           --> Phase 3 (summarize) --> done                          |
- +---------------------------------------------------------------------+
712
  ```
713
 
714
  Every episode starts with a blank canvas and a goal like
@@ -762,19 +814,22 @@ the system without hand-coding rules for every field.
762
  "The goal is not to be impressive. The goal is to be clear."
763
  That's the north star of this project -- training machines not to be
764
  impressive explainers, but clear ones.*
765
- """)
 
766
 
767
  # ============================================================
768
  # TAB 2: Live Demo
769
  # ============================================================
770
  with gr.Tab("Live Demo"):
771
- gr.Markdown("""
 
772
  ## Watch an LLM Teach
773
 
774
  See the agent explain a CS algorithm in real-time -- canvas visualization
775
  with voice narration. Select a scenario, click **Start Demo**, then click the
776
  viewer area to activate audio.
777
- """)
 
778
 
779
  with gr.Row():
780
  demo_hf_token = gr.Textbox(
@@ -800,7 +855,9 @@ viewer area to activate audio.
800
  demo_start_btn = gr.Button("Start Demo", variant="primary", scale=1)
801
  demo_stop_btn = gr.Button("Stop Demo", variant="stop", scale=1)
802
 
803
- demo_status = gr.Markdown("_Enter your HF token, select a scenario, and click Start Demo._")
 
 
804
 
805
  gr.HTML(
806
  value=(
@@ -825,7 +882,8 @@ viewer area to activate audio.
825
  # TAB 3: Scoring & Architecture (technical)
826
  # ============================================================
827
  with gr.Tab("Scoring & Architecture"):
828
- gr.Markdown("""
 
829
  ## Scoring System
830
 
831
  The overall score is a **weighted sum of 13 sub-scores** (each 0-1) **minus 5 penalties**.
@@ -833,7 +891,8 @@ Weights are tuned per difficulty level -- harder tiers emphasize algorithm corre
833
  while easier tiers give more weight to narration quality and concept coverage.
834
 
835
  **Select a difficulty level** to see the weight distribution:
836
- """)
 
837
 
838
  difficulty_radio = gr.Radio(
839
  choices=["easy", "medium", "hard", "expert"],
@@ -848,7 +907,8 @@ while easier tiers give more weight to narration quality and concept coverage.
848
  outputs=[weights_display],
849
  )
850
 
851
- gr.Markdown("""
 
852
  ---
853
  ### Sub-Score Details
854
 
@@ -932,13 +992,15 @@ track membership for `push_to`/`pop_from`. Common source of LLM confusion.
932
  plus relative prefixes (`below:`, `right-of:`, ...) instead of numeric coordinates.
933
  This reduces the positioning search space from ~331K to ~6.5K, making it learnable within
934
  reasonable training budgets.
935
- """)
 
936
 
937
  # ============================================================
938
  # TAB 4: API Reference (technical)
939
  # ============================================================
940
  with gr.Tab("API Reference"):
941
- gr.Markdown(f"""
 
942
  ## API Reference
943
 
944
  This Space exposes the standard **OpenEnv** HTTP + WebSocket API under `/api`.
@@ -1044,7 +1106,8 @@ export LOCAL_IMAGE_NAME=http://127.0.0.1:8000
1044
  python inference.py # headless
1045
  python inference_tldraw.py # with tldraw browser viewer
1046
  ```
1047
- """)
 
1048
 
1049
  return demo
1050
 
@@ -1067,6 +1130,7 @@ app = create_app(
1067
  # Additional routes for live viewer + audio
1068
  # ---------------------------------------------------------------------------
1069
 
 
1070
  @app.get("/viewer")
1071
  async def serve_viewer():
1072
  viewer_path = VIEWER_DIR / "audio_viewer.html"
@@ -1114,7 +1178,9 @@ def main():
1114
  print(f" Live Viewer: http://{args.host}:{args.port}/viewer")
1115
  print(f" OpenEnv API: http://{args.host}:{args.port}/reset, /step, /health")
1116
  if not LLM_API_KEY:
1117
- print(" NOTE: No HF_TOKEN/API_KEY env var — users can enter token in the Viewer tab")
 
 
1118
 
1119
  uvicorn.run(app, host=args.host, port=args.port)
1120
 
 
43
  from .scenario_loader import load_scenarios
44
  from .scoring import weights_for_difficulty
45
  from .constants import (
46
+ ALLOWED_OPS,
47
+ ROLE_VALUES,
48
+ REGION_STYLES,
49
+ NAMED_POSITIONS,
50
  )
51
  except ImportError:
52
  from models import VisualReasoningAction, VisualReasoningObservation
 
54
  from server.scenario_loader import load_scenarios
55
  from server.scoring import weights_for_difficulty
56
  from server.constants import (
57
+ ALLOWED_OPS,
58
+ ROLE_VALUES,
59
+ REGION_STYLES,
60
+ NAMED_POSITIONS,
61
  )
62
 
63
 
 
103
 
104
 
105
  def _tier_emoji(t: str) -> str:
106
+ return {
107
+ "easy": "\U0001f7e2",
108
+ "medium": "\U0001f7e1",
109
+ "hard": "\U0001f7e0",
110
+ "expert": "\U0001f534",
111
+ }.get(t, "⚪")
112
 
113
 
114
  def _tier_label(t: str) -> str:
 
119
  # Broadcaster (WebSocket fan-out for live viewer)
120
  # ---------------------------------------------------------------------------
121
 
122
+
123
  class Broadcaster:
124
  """Append-only message log with long-poll support."""
125
 
 
174
  # TTS via fal-ai Kokoro
175
  # ---------------------------------------------------------------------------
176
 
177
+
178
  async def _tts_to_base64(text: str) -> Optional[str]:
179
  if not text or not text.strip():
180
  return None
 
182
  return None
183
  try:
184
  import aiohttp
185
+
186
  async with aiohttp.ClientSession() as session:
187
  async with session.post(
188
  "https://fal.run/fal-ai/kokoro",
 
195
  ) as resp:
196
  if resp.status != 200:
197
  body = await resp.text()
198
+ print(
199
+ f"[TTS] fal-ai error {resp.status}: {body}",
200
+ file=sys.stderr,
201
+ flush=True,
202
+ )
203
  return None
204
  data = await resp.json()
205
  audio_url = data.get("audio", {}).get("url")
206
  if not audio_url:
207
+ print(
208
+ "[TTS] no audio URL in fal-ai response", file=sys.stderr, flush=True
209
+ )
210
  return None
211
+ async with session.get(
212
+ audio_url, timeout=aiohttp.ClientTimeout(total=15)
213
+ ) as audio_resp:
214
  audio_bytes = await audio_resp.read()
215
  result = base64.b64encode(audio_bytes).decode("ascii")
216
  print(f"[TTS] generated {len(audio_bytes)} bytes of audio", flush=True)
 
233
  key = api_key or LLM_API_KEY
234
  if _openai_client is None or key != _openai_client_key:
235
  from openai import OpenAI
236
+
237
  _openai_client = OpenAI(base_url=LLM_API_BASE, api_key=key)
238
  _openai_client_key = key
239
  return _openai_client
 
242
  def _build_system_prompt() -> str:
243
  try:
244
  from inference import SYSTEM_PROMPT
245
+
246
  return SYSTEM_PROMPT
247
  except Exception:
248
  return (
 
251
  )
252
 
253
 
254
+ def _build_user_prompt(
255
+ obs: Any, last_action: Optional[Dict], last_reward: float, history: List[str]
256
+ ) -> str:
257
  try:
258
  from inference import build_user_prompt
259
+
260
  return build_user_prompt(obs, last_action, last_reward, history)
261
  except Exception:
262
  return f"Goal: {obs.goal}\nEntities: {list(obs.entities.keys())}\nRemaining steps: {obs.remaining_step_budget}"
 
272
  def _parse_llm_response(text: str) -> Optional[Dict[str, Any]]:
273
  try:
274
  from inference import parse_action, normalize_action
275
+
276
  parsed = parse_action(text)
277
  return normalize_action(parsed or {})
278
  except Exception:
279
  import re
280
+
281
  match = re.search(r"\{.*\}", text, flags=re.DOTALL)
282
  if match:
283
  try:
 
344
 
345
  goal_audio = await _tts_to_base64(obs.goal)
346
  snap = _obs_snapshot(obs)
347
+ await _broadcaster.send(
348
+ {
349
+ "type": "reset",
350
+ "task_name": obs.task_name,
351
+ "scenario_id": obs.scenario_id,
352
+ "goal": obs.goal,
353
+ "checklist": list(obs.concept_checklist),
354
+ "input_data": dict(obs.input_data),
355
+ "constraints": list(obs.constraints),
356
+ "max_steps": obs.max_steps,
357
+ "audio": goal_audio,
358
+ **snap,
359
+ }
360
+ )
361
 
362
  await asyncio.sleep(2.0)
363
 
 
380
  temperature=LLM_TEMPERATURE,
381
  max_tokens=LLM_MAX_TOKENS,
382
  stream=False,
383
+ )
384
+ .choices[0]
385
+ .message.content
386
+ or "",
387
  )
388
  text = _strip_thinking_tokens(text)
389
  print(f"[DEMO] Step {step}: LLM returned {len(text)} chars", flush=True)
 
424
  history.append(f"Step {step}: {action_dict.get('narration', '')}")
425
 
426
  snap = _obs_snapshot(obs)
427
+ await _broadcaster.send(
428
+ {
429
+ "type": "step",
430
+ "task_name": obs_dict.get("task_name", ""),
431
+ "scenario_id": obs_dict.get("scenario_id", ""),
432
+ "step": step,
433
+ "step_type": action_dict.get("step_type"),
434
+ "intent": action_dict.get("intent", ""),
435
+ "narration": narration,
436
+ "ops": action_dict.get("ops", []),
437
+ "covered_concepts": action_dict.get("covered_concepts", []),
438
+ "reward": float(reward),
439
+ "score": float(overall),
440
+ "done": bool(done),
441
+ "error": error,
442
+ "audio": audio_b64,
443
+ **snap,
444
+ }
445
+ )
446
 
447
  if done:
448
  await asyncio.sleep(2.0)
 
454
  print(f"[DEMO] error: {exc}", file=sys.stderr, flush=True)
455
  traceback.print_exc()
456
 
457
+ await _broadcaster.send(
458
+ {
459
+ "type": "end",
460
+ "task_name": scenario_id,
461
+ "success": score >= 0.65,
462
+ "steps": steps_taken,
463
+ "score": float(score),
464
+ "rewards": [float(r) for r in rewards],
465
+ }
466
+ )
467
 
468
 
469
  # ---------------------------------------------------------------------------
470
  # Scenario browser callbacks
471
  # ---------------------------------------------------------------------------
472
 
473
+
474
  def list_scenario_choices() -> List[str]:
475
  return [_scenario_display_name(s) for s in _get_scenarios()]
476
 
 
499
 
500
  checklist_md = "\n".join(f"- `{c}`" for c in checklist)
501
 
502
+ constraints_md = (
503
+ "\n".join(f"- `{c}`" for c in constraints) if constraints else "_None_"
504
+ )
505
 
506
  return header, input_md, checklist_md, constraints_md
507
  return "Scenario not found.", "", "", ""
 
511
  # Scoring explorer callback
512
  # ---------------------------------------------------------------------------
513
 
514
+
515
  def show_weights(difficulty: str) -> str:
516
  d = difficulty.lower()
517
  w = weights_for_difficulty(d)
 
528
  # Live demo callbacks (Gradio)
529
  # ---------------------------------------------------------------------------
530
 
531
+
532
+ async def _start_live_demo(
533
+ scenario_choice: str, hf_token: str = "", model_name: str = ""
534
+ ) -> str:
535
  global _demo_task
536
  if not scenario_choice:
537
  return "Select a scenario first."
 
568
  # Gradio UI (custom builder for openenv's gradio_builder parameter)
569
  # ---------------------------------------------------------------------------
570
 
571
+
572
+ def build_ui(
573
+ web_manager, action_fields, metadata, is_chat_env, title, quick_start_md
574
+ ) -> gr.Blocks:
575
  """Custom Gradio UI builder for openenv's gradio_builder parameter."""
576
  with gr.Blocks(
577
  title="Visual Reasoning Environment",
 
585
  # TAB 1: About (blog-style, mirrors README)
586
  # ============================================================
587
  with gr.Tab("About"):
588
+ gr.Markdown(
589
+ """
590
  ## The Problem Nobody Talks About
591
 
592
  Here's a question: *How do you teach a machine to teach?*
 
608
  the learner's understanding.**
609
 
610
  That's the gap this project fills.
611
+ """
612
+ )
613
 
614
+ gr.Markdown(
615
+ """
616
  ## What This Is
617
 
618
  The Visual Reasoning Environment is an
 
626
  Think of it this way: you're not training the model to *know* BFS.
627
  You're training it to *teach* BFS the way the best professor you ever had would --
628
  with a marker in hand and an audience that needs to follow along.
629
+ """
630
+ )
631
 
632
  gr.HTML(
633
  '<a href="https://youtu.be/KwWqjuyfWzw" target="_blank">'
 
636
  'alt="Watch the Demo"/></a>'
637
  )
638
 
639
+ gr.Markdown(
640
+ """
641
  ## What the Agent Does
642
 
643
  Every scenario starts with an **empty canvas**. Nothing is drawn.
 
735
  ## The Reinforcement Learning Loop
736
 
737
  ```
738
+ +-------------------------------------------------------------------+
739
+ | TRAINING LOOP (GRPO / RLVR) |
740
+ | |
741
+ | +-----------+ prompt +--------------+ JSON action |
742
+ | | | --------> | | --------------+ |
743
+ | | Scenario | | LLM Agent | | |
744
+ | | Generator| | (Teacher) | | |
745
+ | | | +------> | | <--------+ | |
746
+ | +-----------+ | +--------------+ | | |
747
+ | | | | |
748
+ | observation reward | |
749
+ | + score breakdown signal | |
750
+ | | | | |
751
+ | +--------+--------+ score +------+--+ | |
752
+ | | | <-------------- | | | |
753
+ | | Environment | | Scoring | | |
754
+ | | (Empty Canvas) | --------------> | Engine | | |
755
+ | | | canvas state |(13 dim) | | |
756
+ | +-----------------+ +---------+ | |
757
+ | ^ | |
758
+ | | step(action) | |
759
+ | +--------------------------------------+ |
760
+ | |
761
+ | Per-step reward = delta(overall_score) + penalties + bonuses |
762
+ | Episode: empty canvas --> Phase 1 --> Phase 2 --> Phase 3 |
763
+ +-------------------------------------------------------------------+
 
764
  ```
765
 
766
  Every episode starts with a blank canvas and a goal like
 
814
  "The goal is not to be impressive. The goal is to be clear."
815
  That's the north star of this project -- training machines not to be
816
  impressive explainers, but clear ones.*
817
+ """
818
+ )
819
 
820
  # ============================================================
821
  # TAB 2: Live Demo
822
  # ============================================================
823
  with gr.Tab("Live Demo"):
824
+ gr.Markdown(
825
+ """
826
  ## Watch an LLM Teach
827
 
828
  See the agent explain a CS algorithm in real-time -- canvas visualization
829
  with voice narration. Select a scenario, click **Start Demo**, then click the
830
  viewer area to activate audio.
831
+ """
832
+ )
833
 
834
  with gr.Row():
835
  demo_hf_token = gr.Textbox(
 
855
  demo_start_btn = gr.Button("Start Demo", variant="primary", scale=1)
856
  demo_stop_btn = gr.Button("Stop Demo", variant="stop", scale=1)
857
 
858
+ demo_status = gr.Markdown(
859
+ "_Enter your HF token, select a scenario, and click Start Demo._"
860
+ )
861
 
862
  gr.HTML(
863
  value=(
 
882
  # TAB 3: Scoring & Architecture (technical)
883
  # ============================================================
884
  with gr.Tab("Scoring & Architecture"):
885
+ gr.Markdown(
886
+ """
887
  ## Scoring System
888
 
889
  The overall score is a **weighted sum of 13 sub-scores** (each 0-1) **minus 5 penalties**.
 
891
  while easier tiers give more weight to narration quality and concept coverage.
892
 
893
  **Select a difficulty level** to see the weight distribution:
894
+ """
895
+ )
896
 
897
  difficulty_radio = gr.Radio(
898
  choices=["easy", "medium", "hard", "expert"],
 
907
  outputs=[weights_display],
908
  )
909
 
910
+ gr.Markdown(
911
+ """
912
  ---
913
  ### Sub-Score Details
914
 
 
992
  plus relative prefixes (`below:`, `right-of:`, ...) instead of numeric coordinates.
993
  This reduces the positioning search space from ~331K to ~6.5K, making it learnable within
994
  reasonable training budgets.
995
+ """
996
+ )
997
 
998
  # ============================================================
999
  # TAB 4: API Reference (technical)
1000
  # ============================================================
1001
  with gr.Tab("API Reference"):
1002
+ gr.Markdown(
1003
+ f"""
1004
  ## API Reference
1005
 
1006
  This Space exposes the standard **OpenEnv** HTTP + WebSocket API under `/api`.
 
1106
  python inference.py # headless
1107
  python inference_tldraw.py # with tldraw browser viewer
1108
  ```
1109
+ """
1110
+ )
1111
 
1112
  return demo
1113
 
 
1130
  # Additional routes for live viewer + audio
1131
  # ---------------------------------------------------------------------------
1132
 
1133
+
1134
  @app.get("/viewer")
1135
  async def serve_viewer():
1136
  viewer_path = VIEWER_DIR / "audio_viewer.html"
 
1178
  print(f" Live Viewer: http://{args.host}:{args.port}/viewer")
1179
  print(f" OpenEnv API: http://{args.host}:{args.port}/reset, /step, /health")
1180
  if not LLM_API_KEY:
1181
+ print(
1182
+ " NOTE: No HF_TOKEN/API_KEY env var — users can enter token in the Viewer tab"
1183
+ )
1184
 
1185
  uvicorn.run(app, host=args.host, port=args.port)
1186
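Reviewer note: the "Scoring & Architecture" text in this file describes the overall score as a weighted sum of 13 sub-scores (each 0-1) minus 5 penalties, and the diagram gives the per-step reward as `delta(overall_score) + penalties + bonuses`. The sketch below illustrates that arithmetic only. The function names, the clamping, and the convention that penalties are passed to `step_reward` as non-positive numbers are assumptions made for illustration; the real logic lives in `server/scoring.py`, which this diff does not show.

```python
from typing import Dict


def overall_score(
    sub_scores: Dict[str, float],
    weights: Dict[str, float],
    penalties: Dict[str, float],
) -> float:
    """Weighted sum of sub-scores (each clamped to 0-1) minus the summed penalties.

    Hypothetical helper: names and clamping are assumptions, not the project's API.
    """
    weighted = sum(
        weights.get(name, 0.0) * min(max(value, 0.0), 1.0)
        for name, value in sub_scores.items()
    )
    return weighted - sum(penalties.values())


def step_reward(
    prev_overall: float,
    new_overall: float,
    penalties: float = 0.0,
    bonuses: float = 0.0,
) -> float:
    """Per-step reward = delta(overall_score) + penalties + bonuses.

    Assumes penalties are already signed (non-positive), matching the
    additive form shown in the diagram.
    """
    return (new_overall - prev_overall) + penalties + bonuses


if __name__ == "__main__":
    score = overall_score(
        sub_scores={"algorithm": 1.0, "narration": 0.5},
        weights={"algorithm": 0.6, "narration": 0.4},
        penalties={"noop": 0.1},
    )
    print(f"overall={score:.2f}")
    print(f"reward={step_reward(0.2, 0.5, penalties=-0.04):.2f}")
```

Because the reward is a difference of consecutive overall scores, a step that leaves the canvas unchanged earns only its penalty and bonus terms, which is one way to read the "dense" reward described in the text.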
 
server/app_backup.py DELETED
@@ -1,46 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """FastAPI application for the Visual Reasoning Environment."""
8
-
9
- try:
10
- from openenv.core.env_server.http_server import create_app
11
- except Exception as e:
12
- raise ImportError(
13
- "openenv is required for the web interface. Install dependencies with 'uv sync'."
14
- ) from e
15
-
16
- try:
17
- from ..models import VisualReasoningAction, VisualReasoningObservation
18
- from .visual_reasoning_environment import VisualReasoningEnvironment
19
- except ImportError:
20
- from models import VisualReasoningAction, VisualReasoningObservation
21
- from server.visual_reasoning_environment import VisualReasoningEnvironment
22
-
23
- app = create_app(
24
- VisualReasoningEnvironment,
25
- VisualReasoningAction,
26
- VisualReasoningObservation,
27
- env_name="visual_reasoning",
28
- max_concurrent_envs=1,
29
- )
30
-
31
-
32
- def main():
33
- """Run the FastAPI server."""
34
- import argparse
35
-
36
- import uvicorn
37
-
38
- parser = argparse.ArgumentParser()
39
- parser.add_argument("--host", type=str, default="0.0.0.0")
40
- parser.add_argument("--port", type=int, default=8000)
41
- args = parser.parse_args()
42
- uvicorn.run(app, host=args.host, port=args.port)
43
-
44
-
45
- if __name__ == "__main__":
46
- main()
train.ipynb DELETED
@@ -1,913 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "title",
6
- "metadata": {},
7
- "source": [
8
- "# Visual Reasoning — Training Demo\n",
9
- "\n",
10
- "**Simulation-based RL for teaching CS algorithms on a whiteboard**\n",
11
- "\n",
12
- "This notebook trains a small LLM to be an expert visual explainer of CS algorithms.\n",
13
- "The model draws data structures, walks through algorithms step-by-step, and narrates\n",
14
- "the reasoning — scored by a 12-dimension reward system.\n",
15
- "\n",
16
- "| Stage | Purpose |\n",
17
- "|-------|:--------|\n",
18
- "| Baseline | Score the untrained model across all difficulties |\n",
19
- "| SFT warmup | Teach JSON action format via gold demonstrations |\n",
20
- "| GRPO | RL with dense environment rewards, easy → expert curriculum |\n",
21
- "| Final eval | Delta report comparing all three checkpoints |\n",
22
- "\n",
23
- "- 17 scenarios (9 hand-crafted + 8 procedurally generated) across 4 difficulty levels\n",
24
- "- 9 algorithm templates: linked list, stack, binary search, BFS, hash table, Dijkstra, BST, fib memo, quicksort\n",
25
- "- 12 weighted sub-scores + 5 penalties → dense per-step reward\n",
26
- "\n",
27
- "> **No model saving** — everything stays in-memory. The same LoRA adapter flows from SFT into GRPO."
28
- ]
29
- },
30
- {
31
- "cell_type": "markdown",
32
- "id": "install-hdr",
33
- "metadata": {},
34
- "source": [
35
- "## 1. Install Dependencies"
36
- ]
37
- },
38
- {
39
- "cell_type": "code",
40
- "execution_count": null,
41
- "id": "install",
42
- "metadata": {},
43
- "outputs": [],
44
- "source": [
45
- "!pip install -q --upgrade \"torchvision>=0.25.0\"\n",
46
- "!pip install -q huggingface_hub unsloth trl datasets transformers accelerate bitsandbytes peft torch\n",
47
- "!pip install -q openenv-core fastapi uvicorn pydantic\n",
48
- "!pip install -q python-dotenv networkx shapely sentence-transformers rapidfuzz textstat sortedcontainers \"numpy<2.0\""
49
- ]
50
- },
51
- {
52
- "cell_type": "markdown",
53
- "id": "clone-hdr",
54
- "metadata": {},
55
- "source": [
56
- "## 2. Clone Visual Reasoning Environment"
57
- ]
58
- },
59
- {
60
- "cell_type": "code",
61
- "execution_count": null,
62
- "id": "clone",
63
- "metadata": {},
64
- "outputs": [],
65
- "source": [
66
- "import os\n",
67
- "\n",
68
- "# Use heuristic narration scorer — no GPU contention with training model\n",
69
- "os.environ[\"NARRATION_SCORER\"] = \"fallback\"\n",
70
- "\n",
71
- "from huggingface_hub import snapshot_download\n",
72
- "\n",
73
- "if os.path.basename(os.getcwd()) != \"visual_reasoning\":\n",
74
- " if not os.path.isdir(\"visual_reasoning\"):\n",
75
- " snapshot_download(\n",
76
- " repo_id=\"sreeramajay/visual_reasoning-env\",\n",
77
- " repo_type=\"space\",\n",
78
- " local_dir=\"visual_reasoning\",\n",
79
- " ignore_patterns=[\"*.gitattributes\", \".gitignore\", \"README.md\"],\n",
80
- " )\n",
81
- " os.chdir(\"visual_reasoning\")"
82
- ]
83
- },
84
- {
85
- "cell_type": "markdown",
86
- "id": "verify-hdr",
87
- "metadata": {},
88
- "source": [
89
- "## 3. Verify Environment — Run a Smoke Test"
90
- ]
91
- },
92
- {
93
- "cell_type": "code",
94
- "execution_count": null,
95
- "id": "verify",
96
- "metadata": {},
97
- "outputs": [],
98
- "source": [
99
- "import sys, os, json, torch\n",
100
- "sys.path.insert(0, '.')\n",
101
- "\n",
102
- "from unsloth import FastLanguageModel\n",
103
- "\n",
104
- "from models import VisualReasoningAction\n",
105
- "from server.visual_reasoning_environment import VisualReasoningEnvironment\n",
106
- "\n",
107
- "test_env = VisualReasoningEnvironment()\n",
108
- "obs = test_env.reset(scenario_id='easy_1')\n",
109
- "print(f'Scenario: {obs.scenario_id}')\n",
110
- "print(f'Goal: {obs.goal}')\n",
111
- "print(f'Concepts: {obs.concept_checklist}')\n",
112
- "print(f'Step budget: {obs.remaining_step_budget}')\n",
113
- "\n",
114
- "obs = test_env.step(VisualReasoningAction(\n",
115
- " step_type='advance',\n",
116
- " narration='Adding the first node with value 10.',\n",
117
- " ops=[{'op': 'add_node', 'target_ids': ['n0'], 'params': {'value': 10}}],\n",
118
- " covered_concepts=['node_value'],\n",
119
- " intent='test',\n",
120
- "))\n",
121
- "print(f'\\nAfter step: entities={list(obs.entities.keys())}, reward={obs.reward:.3f}, error={obs.action_error}')\n",
122
- "print('Environment OK!')\n",
123
- "del test_env"
124
- ]
125
- },
126
- {
127
- "cell_type": "markdown",
128
- "id": "config-hdr",
129
- "metadata": {},
130
- "source": [
131
- "## 4. Configuration"
132
- ]
133
- },
134
- {
135
- "cell_type": "code",
136
- "execution_count": null,
137
- "id": "config",
138
- "metadata": {},
139
- "outputs": [],
140
- "source": [
141
- "from collections import Counter\n",
142
- "from inference import SYSTEM_PROMPT, build_user_prompt, parse_action, normalize_action\n",
143
- "\n",
144
- "MODEL_NAME = 'unsloth/Qwen2.5-3B-Instruct-bnb-4bit'\n",
145
- "MAX_SEQ_LENGTH = 4096\n",
146
- "LORA_R = 16\n",
147
- "LORA_ALPHA = 32\n",
148
- "SFT_EPOCHS = 3\n",
149
- "GRPO_EPOCHS = 2\n",
150
- "\n",
151
- "SCENARIOS = {\n",
152
- " 'easy': ['easy_1', 'easy_2', 'easy_3', 'gen_easy_1001', 'gen_easy_1002'],\n",
153
- " 'medium': ['medium_1', 'medium_2', 'gen_medium_2001', 'gen_medium_2002'],\n",
154
- " 'hard': ['hard_1', 'hard_2', 'gen_hard_3001', 'gen_hard_3002'],\n",
155
- " 'expert': ['expert_1', 'expert_2', 'gen_expert_4001', 'gen_expert_4002'],\n",
156
- "}\n",
157
- "DIFFICULTIES = ('easy', 'medium', 'hard', 'expert')"
158
- ]
159
- },
160
- {
161
- "cell_type": "markdown",
162
- "id": "model-hdr",
163
- "metadata": {},
164
- "source": [
165
- "## 5. Load Model (Unsloth + LoRA)"
166
- ]
167
- },
168
- {
169
- "cell_type": "code",
170
- "execution_count": null,
171
- "id": "load-model",
172
- "metadata": {},
173
- "outputs": [],
174
- "source": [
175
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
176
- " model_name=MODEL_NAME,\n",
177
- " max_seq_length=MAX_SEQ_LENGTH,\n",
178
- " dtype=None,\n",
179
- " load_in_4bit=True,\n",
180
- ")\n",
181
- "model = FastLanguageModel.get_peft_model(\n",
182
- " model,\n",
183
- " r=LORA_R,\n",
184
- " lora_alpha=LORA_ALPHA,\n",
185
- " lora_dropout=0,\n",
186
- " target_modules=[\n",
187
- " 'q_proj', 'k_proj', 'v_proj', 'o_proj',\n",
188
- " 'gate_proj', 'up_proj', 'down_proj',\n",
189
- " ],\n",
190
- " bias='none',\n",
191
- " use_gradient_checkpointing='unsloth',\n",
192
- " random_state=0,\n",
193
- ")\n",
194
- "if tokenizer.pad_token_id is None:\n",
195
- " tokenizer.pad_token = tokenizer.eos_token\n",
196
- "\n",
197
- "model.generation_config.max_length = None\n",
198
- "\n",
199
- "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
200
- "total = sum(p.numel() for p in model.parameters())\n",
201
- "print(f'Parameters: {total / 1e6:.1f}M total, {trainable / 1e6:.2f}M trainable ({trainable / total:.2%})')"
202
- ]
203
- },
204
- {
205
- "cell_type": "markdown",
206
- "id": "helpers-hdr",
207
- "metadata": {},
208
- "source": [
209
- "## 6. Environment Wrapper & Batched Evaluation\n",
210
- "\n",
211
- "Uses the environment **directly in-process** — no server needed.\n",
212
- "- **Batched generation**: collects prompts from up to 8 parallel episodes per `model.generate()` call\n",
213
- "- **Early termination**: kills episodes after 3 consecutive no-ops\n",
214
- "- **Environment pool**: reuses env instances across eval calls"
215
- ]
216
- },
217
- {
218
- "cell_type": "code",
219
- "execution_count": null,
220
- "id": "helpers",
221
- "metadata": {},
222
- "outputs": [],
223
- "source": [
224
- "import time\n",
225
- "\n",
226
- "EVAL_BATCH_SIZE = 8\n",
227
- "EVAL_MAX_STEPS = 24\n",
228
- "NOOP_EARLY_STOP = 3\n",
229
- "\n",
230
- "\n",
231
- "class EnvRunner:\n",
232
- " \"\"\"Thin wrapper over VisualReasoningEnvironment for a clean reset/step API.\"\"\"\n",
233
- "\n",
234
- " def __init__(self):\n",
235
- " self.env = VisualReasoningEnvironment()\n",
236
- "\n",
237
- " def reset(self, scenario_id=None, task_name=None):\n",
238
- " return self.env.reset(scenario_id=scenario_id, task_name=task_name)\n",
239
- "\n",
240
- " def step(self, action_dict):\n",
241
- " act = VisualReasoningAction(**action_dict)\n",
242
- " obs = self.env.step(act)\n",
243
- " return obs, float(obs.reward), bool(obs.done)\n",
244
- "\n",
245
- "\n",
246
- "FALLBACK_ACTION = {\n",
247
- " 'step_type': 'complete', 'narration': 'Explanation complete.',\n",
248
- " 'ops': [], 'covered_concepts': [], 'intent': 'finalize',\n",
249
- "}\n",
250
- "\n",
251
- "_env_pool = []\n",
252
- "\n",
253
- "def _get_env(idx):\n",
254
- " while len(_env_pool) <= idx:\n",
255
- " _env_pool.append(VisualReasoningEnvironment())\n",
256
- " return _env_pool[idx]\n",
257
- "\n",
258
- "\n",
259
- "class EpisodeState:\n",
260
- " \"\"\"Tracks one episode's state for batched eval.\"\"\"\n",
261
- " def __init__(self, env, scenario_id):\n",
262
- " self.env = env\n",
263
- " self.scenario_id = scenario_id\n",
264
- " self.obs = env.reset(scenario_id=scenario_id)\n",
265
- " self.last_action = None\n",
266
- " self.last_reward = 0.0\n",
267
- " self.history = []\n",
268
- " self.steps = 0\n",
269
- " self.done = False\n",
270
- " self.score = 0.0\n",
271
- " self.consecutive_noops = 0\n",
272
- "\n",
273
- " def build_prompt_text(self):\n",
274
- " user_prompt = build_user_prompt(\n",
275
- " self.obs, self.last_action, self.last_reward, self.history\n",
276
- " )\n",
277
- " messages = [\n",
278
- " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
279
- " {'role': 'user', 'content': user_prompt},\n",
280
- " ]\n",
281
- " return tokenizer.apply_chat_template(\n",
282
- " messages, tokenize=False, add_generation_prompt=True\n",
283
- " )\n",
284
- "\n",
285
- " def apply_action(self, action_dict):\n",
286
- " if action_dict is None:\n",
287
- " action_dict = FALLBACK_ACTION\n",
288
- " self.obs = self.env.step(VisualReasoningAction(**action_dict))\n",
289
- " reward = float(self.obs.reward)\n",
290
- " self.last_action = action_dict\n",
291
- " self.last_reward = reward\n",
292
- " self.steps += 1\n",
293
- " self.history.append(f\"Step {self.steps}: {action_dict.get('narration', '')}\")\n",
294
- " if reward <= -0.04:\n",
295
- " self.consecutive_noops += 1\n",
296
- " else:\n",
297
- " self.consecutive_noops = 0\n",
298
- " if self.obs.done or self.consecutive_noops >= NOOP_EARLY_STOP:\n",
299
- " self.done = True\n",
300
- " self.score = float(self.obs.score_breakdown.get('overall_score', 0.0))\n",
301
- "\n",
302
- "\n",
303
- "def batched_generate(model, tokenizer, prompt_texts, max_new_tokens=384):\n",
304
- " \"\"\"Run model.generate on a batch of prompt strings.\"\"\"\n",
305
- " FastLanguageModel.for_inference(model)\n",
306
- " orig_side = tokenizer.padding_side\n",
307
- " tokenizer.padding_side = 'left'\n",
308
- " inputs = tokenizer(\n",
309
- " prompt_texts,\n",
310
- " return_tensors='pt',\n",
311
- " padding=True,\n",
312
- " truncation=True,\n",
313
- " max_length=MAX_SEQ_LENGTH,\n",
314
- " ).to(model.device)\n",
315
- " tokenizer.padding_side = orig_side\n",
316
- " with torch.no_grad():\n",
317
- " outputs = model.generate(\n",
318
- " **inputs,\n",
319
- " max_new_tokens=max_new_tokens,\n",
320
- " do_sample=False,\n",
321
- " pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,\n",
322
- " )\n",
323
- " results = []\n",
324
- " for i, out in enumerate(outputs):\n",
325
- " input_len = inputs.input_ids.shape[1] # left-padded batch: every prompt spans the full padded width\n",
326
- " text = tokenizer.decode(out[input_len:], skip_special_tokens=True)\n",
327
- " results.append(text)\n",
328
- " return results\n",
329
- "\n",
330
- "\n",
331
- "def evaluate(model, tokenizer, label, stages=None):\n",
332
- " \"\"\"Batched evaluation across all scenarios.\"\"\"\n",
333
- " stages = stages or list(DIFFICULTIES)\n",
334
- " print(f'\\n{\"=\" * 60}')\n",
335
- " print(f' {label}')\n",
336
- " print(f'{\"=\" * 60}')\n",
337
- "\n",
338
- " all_scenario_ids = []\n",
339
- " for diff in stages:\n",
340
- " all_scenario_ids.extend([(diff, sid) for sid in SCENARIOS.get(diff, [])])\n",
341
- "\n",
342
- " results_by_diff = {d: [] for d in stages}\n",
343
- " t0 = time.time()\n",
344
- "\n",
345
- " for batch_start in range(0, len(all_scenario_ids), EVAL_BATCH_SIZE):\n",
346
- " batch = all_scenario_ids[batch_start:batch_start + EVAL_BATCH_SIZE]\n",
347
- " t_batch = time.time()\n",
348
- " episodes = []\n",
349
- " for i, (diff, sid) in enumerate(batch):\n",
350
- " episodes.append(EpisodeState(_get_env(i), sid))\n",
351
- "\n",
352
- " for step_num in range(1, EVAL_MAX_STEPS + 1):\n",
353
- " active = [ep for ep in episodes if not ep.done]\n",
354
- " if not active:\n",
355
- " break\n",
356
- " prompts = [ep.build_prompt_text() for ep in active]\n",
357
- " generated = batched_generate(model, tokenizer, prompts)\n",
358
- " for ep, text in zip(active, generated):\n",
359
- " parsed = parse_action(text)\n",
360
- " action = normalize_action(parsed or {}) if parsed else None\n",
361
- " ep.apply_action(action)\n",
362
- "\n",
363
- " batch_time = time.time() - t_batch\n",
364
- " for ep, (diff, sid) in zip(episodes, batch):\n",
365
- " results_by_diff[diff].append(ep.score)\n",
366
- " print(f' [{diff:6}] {sid:22} score={ep.score:.3f} steps={ep.steps}')\n",
367
- " sids_str = ', '.join(sid for _, sid in batch)\n",
368
- " print(f' Batch {batch_start // EVAL_BATCH_SIZE + 1}: '\n",
369
- " f'{len(batch)} scenarios in {batch_time:.1f}s [{sids_str}]')\n",
370
- "\n",
371
- " results = {}\n",
372
- " for diff in stages:\n",
373
- " scores = results_by_diff[diff]\n",
374
- " if scores:\n",
375
- " mean = sum(scores) / len(scores)\n",
376
- " results[diff] = mean\n",
377
- " print(f' [{diff:6}] mean={mean:.3f}')\n",
378
- " overall = sum(results.values()) / max(len(results), 1)\n",
379
- " results['overall'] = overall\n",
380
- " elapsed = time.time() - t0\n",
381
- " print(f' OVERALL: {overall:.3f} ({elapsed:.1f}s)')\n",
382
- " return results\n",
383
- "\n",
384
- "\n",
385
- "env = EnvRunner()\n",
386
- "print('EnvRunner ready.')"
387
- ]
388
- },
389
- {
390
- "cell_type": "markdown",
391
- "id": "baseline-hdr",
392
- "metadata": {},
393
- "source": [
394
- "## 7. Baseline Evaluation\n",
395
- "\n",
396
- "Score the untrained model (freshly initialized LoRA adapter) across all scenarios to establish a baseline."
397
- ]
398
- },
399
- {
400
- "cell_type": "code",
401
- "execution_count": null,
402
- "id": "baseline",
403
- "metadata": {},
404
- "outputs": [],
405
- "source": [
406
- "baseline = evaluate(model, tokenizer, 'Baseline (untrained LoRA)')"
407
- ]
408
- },
409
- {
410
- "cell_type": "markdown",
411
- "id": "gold-hdr",
412
- "metadata": {},
413
- "source": [
414
- "## 8. Gold Demonstrations\n",
415
- "\n",
416
- "Hand-crafted teaching episodes for the three easy scenarios. These teach the model:\n",
417
- "- The JSON action format (`step_type`, `narration`, `ops`, `covered_concepts`, `intent`)\n",
418
- "- Incremental setup (Phase 1), algorithm walk-through (Phase 2), wrap-up (Phase 3)\n",
419
- "- Proper concept evidencing (narration must mention the concept, ops must back it)\n",
420
- "- Region vs container distinction (regions for layout, containers for push/pop)"
421
- ]
422
- },
423
- {
424
- "cell_type": "code",
425
- "execution_count": null,
426
- "id": "gold",
427
- "metadata": {},
428
- "outputs": [],
429
- "source": [
430
- "GOLD_EPISODES = {\n",
431
- " # ── easy_1: linked_list_traversal ──────────────────────────────────\n",
432
- " # input_data: {\"values\": [10, 20, 30]}\n",
433
- " # concepts: head_pointer, node_value, next_link, tail_marker\n",
434
- " 'easy_1': [\n",
435
- " {\n",
436
- " 'step_type': 'advance',\n",
437
- " 'narration': 'Building a linked list — creating a centered layout region and the first two nodes with values 10 and 20.',\n",
438
- " 'ops': [\n",
439
- " {'op': 'add_region', 'target_ids': ['list'], 'params': {'style': 'array', 'title': 'Linked List', 'position': 'center'}},\n",
440
- " {'op': 'add_node', 'target_ids': ['n1'], 'params': {'value': 10, 'region': 'list'}},\n",
441
- " {'op': 'add_node', 'target_ids': ['n2'], 'params': {'value': 20, 'region': 'list'}},\n",
442
- " ],\n",
443
- " 'covered_concepts': ['node_value'],\n",
444
- " 'intent': 'create_list_start',\n",
445
- " },\n",
446
- " {\n",
447
- " 'step_type': 'advance',\n",
448
- " 'narration': 'Adding node 30 and connecting all nodes with next links to form the chain 10 -> 20 -> 30.',\n",
449
- " 'ops': [\n",
450
- " {'op': 'add_node', 'target_ids': ['n3'], 'params': {'value': 30, 'region': 'list'}},\n",
451
- " {'op': 'add_edge', 'target_ids': ['n1', 'n2'], 'params': {'kind': 'directed', 'label': 'next'}},\n",
452
- " {'op': 'add_edge', 'target_ids': ['n2', 'n3'], 'params': {'kind': 'directed', 'label': 'next'}},\n",
453
- " ],\n",
454
- " 'covered_concepts': ['next_link'],\n",
455
- " 'intent': 'connect_nodes',\n",
456
- " },\n",
457
- " {\n",
458
- " 'step_type': 'advance',\n",
459
- " 'narration': 'Placing a head pointer at node 10 because traversal always starts at the head, our only entry point into the list.',\n",
460
- " 'ops': [\n",
461
- " {'op': 'add_pointer', 'target_ids': ['head_ptr'], 'params': {'region': 'list'}},\n",
462
- " {'op': 'move_pointer', 'target_ids': ['head_ptr'], 'params': {'index': 'n1'}},\n",
463
- " {'op': 'annotate', 'target_ids': ['n1'], 'params': {'text': 'Head'}},\n",
464
- " {'op': 'set_role', 'target_ids': ['n1'], 'params': {'role': 'current'}},\n",
465
- " ],\n",
466
- " 'covered_concepts': ['head_pointer'],\n",
467
- " 'intent': 'mark_head',\n",
468
- " },\n",
469
- " {\n",
470
- " 'step_type': 'advance',\n",
471
- " 'narration': 'Following the next link from 10 to 20 — the pointer advances and we mark node 10 as visited.',\n",
472
- " 'ops': [\n",
473
- " {'op': 'set_role', 'target_ids': ['n1'], 'params': {'role': 'visited'}},\n",
474
- " {'op': 'set_role', 'target_ids': ['n2'], 'params': {'role': 'current'}},\n",
475
- " {'op': 'move_pointer', 'target_ids': ['head_ptr'], 'params': {'index': 'n2'}},\n",
476
- " ],\n",
477
- " 'covered_concepts': [],\n",
478
- " 'intent': 'traverse_to_second',\n",
479
- " },\n",
480
- " {\n",
481
- " 'step_type': 'advance',\n",
482
- " 'narration': 'Reaching node 30 — it has no next link, making it the tail that signals the end of traversal.',\n",
483
- " 'ops': [\n",
484
- " {'op': 'set_role', 'target_ids': ['n2'], 'params': {'role': 'visited'}},\n",
485
- " {'op': 'set_role', 'target_ids': ['n3'], 'params': {'role': 'current'}},\n",
486
- " {'op': 'annotate', 'target_ids': ['n3'], 'params': {'text': 'Tail'}},\n",
487
- " ],\n",
488
- " 'covered_concepts': ['tail_marker'],\n",
489
- " 'intent': 'reach_tail',\n",
490
- " },\n",
491
- " {\n",
492
- " 'step_type': 'complete',\n",
493
- " 'narration': 'Traversal complete — visited every node from head to tail following next links, reading values 10, 20, 30 in order.',\n",
494
- " 'ops': [\n",
495
- " {'op': 'set_role', 'target_ids': ['n3'], 'params': {'role': 'done'}},\n",
496
- " ],\n",
497
- " 'covered_concepts': [],\n",
498
- " 'intent': 'summarize',\n",
499
- " },\n",
500
- " ],\n",
501
- "\n",
502
- " # ── easy_2: stack_ops ──────────────────────────────────────────────\n",
503
- " # input_data: {\"operations\": [\"push A\", \"push B\", \"pop\", \"push C\"]}\n",
504
- " # concepts: top_pointer, push, pop, lifo_order\n",
505
- " 'easy_2': [\n",
506
- " {\n",
507
- " 'step_type': 'advance',\n",
508
- " 'narration': 'Setting up a stack with a centered visual region and a container to track push and pop membership.',\n",
509
- " 'ops': [\n",
510
- " {'op': 'add_region', 'target_ids': ['stack_area'], 'params': {'style': 'stack', 'title': 'Stack', 'position': 'center'}},\n",
511
- " {'op': 'add_container', 'target_ids': ['stk'], 'params': {'region': 'stack_area', 'ordered': False, 'title': 'Stack'}},\n",
512
- " ],\n",
513
- " 'covered_concepts': [],\n",
514
- " 'intent': 'setup_stack',\n",
515
- " },\n",
516
- " {\n",
517
- " 'step_type': 'advance',\n",
518
- " 'narration': 'Pushing A onto the stack — A becomes the first element. Adding a top pointer to track the stack top.',\n",
519
- " 'ops': [\n",
520
- " {'op': 'add_node', 'target_ids': ['a'], 'params': {'value': 'A', 'region': 'stack_area'}},\n",
521
- " {'op': 'push_to', 'target_ids': ['stk', 'a'], 'params': {}},\n",
522
- " {'op': 'add_pointer', 'target_ids': ['top'], 'params': {'region': 'stack_area'}},\n",
523
- " {'op': 'move_pointer', 'target_ids': ['top'], 'params': {'index': 'a'}},\n",
524
- " ],\n",
525
- " 'covered_concepts': ['push', 'top_pointer'],\n",
526
- " 'intent': 'push_a',\n",
527
- " },\n",
528
- " {\n",
529
- " 'step_type': 'advance',\n",
530
- " 'narration': 'Pushing B — B sits on top of A and the top pointer moves up to B.',\n",
531
- " 'ops': [\n",
532
- " {'op': 'add_node', 'target_ids': ['b'], 'params': {'value': 'B', 'region': 'stack_area'}},\n",
533
- " {'op': 'push_to', 'target_ids': ['stk', 'b'], 'params': {}},\n",
534
- " {'op': 'move_pointer', 'target_ids': ['top'], 'params': {'index': 'b'}},\n",
535
- " ],\n",
536
- " 'covered_concepts': [],\n",
537
- " 'intent': 'push_b',\n",
538
- " },\n",
539
- " {\n",
540
- " 'step_type': 'advance',\n",
541
- " 'narration': 'Popping from the stack — B was pushed last so B comes off first, demonstrating LIFO (last-in-first-out) order.',\n",
542
- " 'ops': [\n",
543
- " {'op': 'pop_from', 'target_ids': ['stk'], 'params': {}},\n",
544
- " {'op': 'set_role', 'target_ids': ['b'], 'params': {'role': 'inactive'}},\n",
545
- " {'op': 'move_pointer', 'target_ids': ['top'], 'params': {'index': 'a'}},\n",
546
- " ],\n",
547
- " 'covered_concepts': ['pop', 'lifo_order'],\n",
548
- " 'intent': 'pop_b',\n",
549
- " },\n",
550
- " {\n",
551
- " 'step_type': 'advance',\n",
552
- " 'narration': 'Pushing C onto the stack — C now sits on top of A, with B already removed.',\n",
553
- " 'ops': [\n",
554
- " {'op': 'add_node', 'target_ids': ['c'], 'params': {'value': 'C', 'region': 'stack_area'}},\n",
555
- " {'op': 'push_to', 'target_ids': ['stk', 'c'], 'params': {}},\n",
556
- " {'op': 'move_pointer', 'target_ids': ['top'], 'params': {'index': 'c'}},\n",
557
- " ],\n",
558
- " 'covered_concepts': [],\n",
559
- " 'intent': 'push_c',\n",
560
- " },\n",
561
- " {\n",
562
- " 'step_type': 'complete',\n",
563
- " 'narration': 'All four operations executed — stack holds A at bottom and C on top after push A, push B, pop, push C.',\n",
564
- " 'ops': [],\n",
565
- " 'covered_concepts': [],\n",
566
- " 'intent': 'summarize',\n",
567
- " },\n",
568
- " ],\n",
569
- "\n",
570
- " # ── easy_3: binary_search ──────────────────────────────────────────\n",
571
- " # input_data: {\"array\": [1, 3, 5, 7, 9, 11, 13], \"target\": 7}\n",
572
- " # concepts: sorted_invariant, low_pointer, high_pointer, mid_pointer, comparison\n",
573
- " 'easy_3': [\n",
574
- " {\n",
575
- " 'step_type': 'advance',\n",
576
- "            'narration': 'Creating the first three elements of a sorted array in a centered region; the sorted invariant is what makes binary search possible.',\n",
577
- " 'ops': [\n",
578
- " {'op': 'add_region', 'target_ids': ['arr'], 'params': {'style': 'array', 'title': 'Sorted Array', 'position': 'center'}},\n",
579
- " {'op': 'add_node', 'target_ids': ['n0'], 'params': {'value': 1, 'region': 'arr'}},\n",
580
- " {'op': 'add_node', 'target_ids': ['n1'], 'params': {'value': 3, 'region': 'arr'}},\n",
581
- " {'op': 'add_node', 'target_ids': ['n2'], 'params': {'value': 5, 'region': 'arr'}},\n",
582
- " ],\n",
583
- " 'covered_concepts': ['sorted_invariant'],\n",
584
- " 'intent': 'create_array_part1',\n",
585
- " },\n",
586
- " {\n",
587
- " 'step_type': 'advance',\n",
588
- " 'narration': 'Adding the remaining elements 7, 9, 11, 13 to complete all seven values of the sorted array.',\n",
589
- " 'ops': [\n",
590
- " {'op': 'add_node', 'target_ids': ['n3'], 'params': {'value': 7, 'region': 'arr'}},\n",
591
- " {'op': 'add_node', 'target_ids': ['n4'], 'params': {'value': 9, 'region': 'arr'}},\n",
592
- " {'op': 'add_node', 'target_ids': ['n5'], 'params': {'value': 11, 'region': 'arr'}},\n",
593
- " {'op': 'add_node', 'target_ids': ['n6'], 'params': {'value': 13, 'region': 'arr'}},\n",
594
- " ],\n",
595
- " 'covered_concepts': [],\n",
596
- " 'intent': 'create_array_part2',\n",
597
- " },\n",
598
- " {\n",
599
- " 'step_type': 'advance',\n",
600
- " 'narration': 'Placing low pointer at index 0 (value 1) and high pointer at index 6 (value 13) to bracket the search range.',\n",
601
- " 'ops': [\n",
602
- " {'op': 'add_pointer', 'target_ids': ['low'], 'params': {'region': 'arr'}},\n",
603
- " {'op': 'move_pointer', 'target_ids': ['low'], 'params': {'index': 'n0'}},\n",
604
- " {'op': 'add_pointer', 'target_ids': ['high'], 'params': {'region': 'arr'}},\n",
605
- " {'op': 'move_pointer', 'target_ids': ['high'], 'params': {'index': 'n6'}},\n",
606
- " ],\n",
607
- " 'covered_concepts': ['low_pointer', 'high_pointer'],\n",
608
- " 'intent': 'init_pointers',\n",
609
- " },\n",
610
- " {\n",
611
- " 'step_type': 'advance',\n",
612
- " 'narration': 'Computing mid = (0+6)/2 = 3 — the mid pointer lands on value 7, which we compare against our target 7.',\n",
613
- " 'ops': [\n",
614
- " {'op': 'add_pointer', 'target_ids': ['mid'], 'params': {'region': 'arr'}},\n",
615
- " {'op': 'move_pointer', 'target_ids': ['mid'], 'params': {'index': 'n3'}},\n",
616
- " {'op': 'set_role', 'target_ids': ['n3'], 'params': {'role': 'current'}},\n",
617
- " {'op': 'highlight', 'target_ids': ['n3'], 'params': {}},\n",
618
- " ],\n",
619
- " 'covered_concepts': ['mid_pointer', 'comparison'],\n",
620
- " 'intent': 'compute_mid_and_compare',\n",
621
- " },\n",
622
- " {\n",
623
- " 'step_type': 'complete',\n",
624
- "            'narration': 'Target 7 found at index 3. Binary search needed only one comparison here because the target sits exactly at the first midpoint; in general the sorted invariant lets each comparison halve the search space.',\n",
625
- " 'ops': [\n",
626
- " {'op': 'set_role', 'target_ids': ['n3'], 'params': {'role': 'done'}},\n",
627
- " {'op': 'annotate', 'target_ids': ['n3'], 'params': {'text': 'Found: 7'}},\n",
628
- " ],\n",
629
- " 'covered_concepts': [],\n",
630
- " 'intent': 'found_target',\n",
631
- " },\n",
632
- " ],\n",
633
- "}\n",
634
- "\n",
635
- "print(f'Gold episodes: {len(GOLD_EPISODES)} scenarios, {sum(len(v) for v in GOLD_EPISODES.values())} total steps')"
636
- ]
637
- },
638
- {
639
- "cell_type": "markdown",
640
- "id": "sft-hdr",
641
- "metadata": {},
642
- "source": [
643
- "## 9. SFT Warmup\n",
644
- "\n",
645
- "Replay gold demonstrations through the live environment to collect accurate\n",
646
- "(observation, action) pairs at each step, then train the model to imitate them."
647
- ]
648
- },
649
- {
650
- "cell_type": "code",
651
- "execution_count": null,
652
- "id": "sft",
653
- "metadata": {},
654
- "outputs": [],
655
- "source": [
656
- "from datasets import Dataset\n",
657
- "from trl import SFTConfig, SFTTrainer\n",
658
- "\n",
659
- "def generate_sft_data(env, gold_episodes, tokenizer):\n",
660
- " \"\"\"Replay gold episodes through the env, collecting chat-formatted training data.\"\"\"\n",
661
- " rows = []\n",
662
- " for scenario_id, actions in gold_episodes.items():\n",
663
- " obs = env.reset(scenario_id=scenario_id)\n",
664
- " last_action, last_reward, history = None, 0.0, []\n",
665
- " for i, action in enumerate(actions):\n",
666
- " user_prompt = build_user_prompt(obs, last_action, last_reward, history)\n",
667
- " messages = [\n",
668
- " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
669
- " {'role': 'user', 'content': user_prompt},\n",
670
- " {'role': 'assistant', 'content': json.dumps(action, separators=(',', ':'))},\n",
671
- " ]\n",
672
- " text = tokenizer.apply_chat_template(\n",
673
- " messages, tokenize=False, add_generation_prompt=False\n",
674
- " )\n",
675
- " rows.append({'text': text})\n",
676
- " obs, reward, done = env.step(action)\n",
677
- " last_action, last_reward = action, reward\n",
678
- " history.append(f'Step {i + 1}: {action.get(\"narration\", \"\")}')\n",
679
- " if done:\n",
680
- " break\n",
681
- " return Dataset.from_list(rows)\n",
682
- "\n",
683
- "\n",
684
- "sft_data = generate_sft_data(env, GOLD_EPISODES, tokenizer)\n",
685
- "print(f'SFT training examples: {len(sft_data)}')\n",
686
- "\n",
687
- "FastLanguageModel.for_training(model)\n",
688
- "\n",
689
- "sft_config = SFTConfig(\n",
690
- " output_dir='/tmp/vr_sft_scratch',\n",
691
- " num_train_epochs=SFT_EPOCHS,\n",
692
- " per_device_train_batch_size=2,\n",
693
- " gradient_accumulation_steps=4,\n",
694
- " learning_rate=2e-4,\n",
695
- " lr_scheduler_type='cosine',\n",
696
- " warmup_ratio=0.03,\n",
697
- " logging_steps=5,\n",
698
- " save_strategy='no',\n",
699
- " fp16=True,\n",
700
- " max_seq_length=MAX_SEQ_LENGTH,\n",
701
- " dataset_text_field='text',\n",
702
- " optim='adamw_8bit',\n",
703
- " report_to='none',\n",
704
- ")\n",
705
- "\n",
706
- "sft_trainer = SFTTrainer(\n",
707
- " model=model,\n",
708
- " processing_class=tokenizer,\n",
709
- " args=sft_config,\n",
710
- " train_dataset=sft_data,\n",
711
- ")\n",
712
- "\n",
713
- "print('\\nTraining SFT...')\n",
714
- "sft_trainer.train()\n",
715
- "print('SFT complete!')"
716
- ]
717
- },
718
- {
719
- "cell_type": "markdown",
720
- "id": "post-sft-hdr",
721
- "metadata": {},
722
- "source": [
723
- "## 10. Post-SFT Evaluation"
724
- ]
725
- },
726
- {
727
- "cell_type": "code",
728
- "execution_count": null,
729
- "id": "post-sft-eval",
730
- "metadata": {},
731
- "outputs": [],
732
- "source": [
733
- "sft_results = evaluate(model, tokenizer, 'Post-SFT')"
734
- ]
735
- },
736
- {
737
- "cell_type": "markdown",
738
- "id": "grpo-hdr",
739
- "metadata": {},
740
- "source": [
741
- "## 11. GRPO Training\n",
742
- "\n",
743
- "Generate prompts from each scenario's initial observation and train with the\n",
744
- "environment reward signal. Curriculum ordering: easy prompts first, expert last."
745
- ]
746
- },
747
- {
748
- "cell_type": "code",
749
- "execution_count": null,
750
- "id": "grpo",
751
- "metadata": {},
752
- "outputs": [],
753
- "source": [
754
- "from torch.utils.data import SequentialSampler\n",
755
- "from trl import GRPOConfig, GRPOTrainer\n",
756
- "\n",
757
- "def generate_grpo_prompts(env, stages, samples_per_scenario=2):\n",
758
- " \"\"\"Collect initial-state prompts for GRPO training.\"\"\"\n",
759
- " rows = []\n",
760
- " for stage in stages:\n",
761
- " for sid in SCENARIOS[stage]:\n",
762
- " for _ in range(samples_per_scenario):\n",
763
- " obs = env.reset(scenario_id=sid)\n",
764
- " user_prompt = build_user_prompt(obs, None, 0.0, [])\n",
765
- " messages = [\n",
766
- " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
767
- " {'role': 'user', 'content': user_prompt},\n",
768
- " ]\n",
769
- " rows.append({'prompt': messages, 'scenario_id': sid})\n",
770
- " return Dataset.from_list(rows)\n",
771
- "\n",
772
- "\n",
773
- "def make_reward_fn(env):\n",
774
- " \"\"\"Reward function: parse model completion, step in env, return overall_score.\"\"\"\n",
775
- " state = {'calls': 0, 'hist': Counter()}\n",
776
- "\n",
777
- " def reward_fn(completions, scenario_id=None, **_):\n",
778
- " texts = []\n",
779
- " for c in completions:\n",
780
- " if isinstance(c, list):\n",
781
- " texts.append(c[-1].get('content', '') if c else '')\n",
782
- " else:\n",
783
- " texts.append(str(c))\n",
784
- "\n",
785
- " sids = scenario_id if isinstance(scenario_id, list) else [scenario_id] * len(texts)\n",
786
- " if len(sids) < len(texts):\n",
787
- " n_gen = len(texts) // len(sids)\n",
788
- " sids = [s for s in sids for _ in range(n_gen)]\n",
789
- "\n",
790
- " rewards = []\n",
791
- " for sid, text in zip(sids, texts):\n",
792
- " obs = env.reset(scenario_id=sid)\n",
793
- " action = normalize_action(parse_action(text) or {})\n",
794
- " if action is None:\n",
795
- " rewards.append(0.0)\n",
796
- " state['hist']['<unparseable>'] += 1\n",
797
- " continue\n",
798
- " obs, _, _ = env.step(action)\n",
799
- " score = float(obs.score_breakdown.get('overall_score', 0.0))\n",
800
- " rewards.append(score)\n",
801
- " state['hist'][action.get('step_type', '?')] += 1\n",
802
- "\n",
803
- " state['calls'] += 1\n",
804
- " if state['calls'] % 5 == 0:\n",
805
- " print(f\" [reward] call={state['calls']} types={dict(state['hist'])}\")\n",
806
- " return rewards\n",
807
- "\n",
808
- " return reward_fn\n",
809
- "\n",
810
- "\n",
811
- "grpo_data = generate_grpo_prompts(env, list(DIFFICULTIES))\n",
812
- "print(f'GRPO training prompts: {len(grpo_data)}')\n",
813
- "\n",
814
- "FastLanguageModel.for_training(model)\n",
815
- "\n",
816
- "for name, param in model.named_parameters():\n",
817
- " if 'lora_' in name and param.dtype == torch.float32:\n",
818
- " param.data = param.data.to(torch.bfloat16)\n",
819
- "\n",
820
- "grpo_config = GRPOConfig(\n",
821
- " output_dir='/tmp/vr_grpo_scratch',\n",
822
- " num_train_epochs=GRPO_EPOCHS,\n",
823
- " per_device_train_batch_size=2,\n",
824
- " gradient_accumulation_steps=4,\n",
825
- " num_generations=4,\n",
826
- " max_completion_length=384,\n",
827
- " learning_rate=1e-5,\n",
828
- " lr_scheduler_type='cosine',\n",
829
- " warmup_ratio=0.1,\n",
830
- " beta=0.05,\n",
831
- " max_grad_norm=0.5,\n",
832
- " temperature=0.9,\n",
833
- " logging_steps=1,\n",
834
- " save_strategy='no',\n",
835
- " fp16=False,\n",
836
- " bf16=True,\n",
837
- " optim='adamw_8bit',\n",
838
- " report_to='none',\n",
839
- " remove_unused_columns=False,\n",
840
- ")\n",
841
- "\n",
842
- "\n",
843
- "class CurriculumGRPOTrainer(GRPOTrainer):\n",
844
- " \"\"\"Preserve easy -> expert ordering by disabling dataset shuffle.\"\"\"\n",
845
- " def _get_train_sampler(self, *_args, **_kwargs):\n",
846
- " return SequentialSampler(self.train_dataset)\n",
847
- "\n",
848
- "\n",
849
- "grpo_trainer = CurriculumGRPOTrainer(\n",
850
- " model=model,\n",
851
- "    processing_class=tokenizer,\n",
852
- " args=grpo_config,\n",
853
- " train_dataset=grpo_data,\n",
854
- " reward_funcs=make_reward_fn(env),\n",
855
- ")\n",
856
- "\n",
857
- "print('Training GRPO...')\n",
858
- "grpo_trainer.train()\n",
859
- "print('GRPO complete!')"
860
- ]
861
- },
862
- {
863
- "cell_type": "markdown",
864
- "id": "final-hdr",
865
- "metadata": {},
866
- "source": [
867
- "## 12. Final Evaluation + Delta Report"
868
- ]
869
- },
870
- {
871
- "cell_type": "code",
872
- "execution_count": null,
873
- "id": "final",
874
- "metadata": {},
875
- "outputs": [],
876
- "source": [
877
- "final_results = evaluate(model, tokenizer, 'Final (SFT + GRPO)')\n",
878
- "\n",
879
- "print(f'\\n{\"=\" * 60}')\n",
880
- "print(' DELTA REPORT')\n",
881
- "print(f'{\"=\" * 60}')\n",
882
- "print(f' {\"Difficulty\":<12} {\"Baseline\":>10} {\"SFT\":>10} {\"SFT+GRPO\":>10}')\n",
883
- "print(f' {\"-\" * 12} {\"-\" * 10} {\"-\" * 10} {\"-\" * 10}')\n",
884
- "for diff in list(DIFFICULTIES) + ['overall']:\n",
885
- " b = baseline.get(diff, 0.0)\n",
886
- " s = sft_results.get(diff, 0.0)\n",
887
- " f = final_results.get(diff, 0.0)\n",
888
- " label = diff.upper() if diff == 'overall' else diff\n",
889
- " print(f' {label:<12} {b:>10.3f} {s:>10.3f} {f:>10.3f}')\n",
890
- "print(f'{\"=\" * 60}')\n",
891
- "print('\\nDone. Model was NOT saved (in-memory only).')"
892
- ]
893
- }
894
- ],
895
- "metadata": {
896
- "accelerator": "GPU",
897
- "colab": {
898
- "gpuType": "T4",
899
- "provenance": []
900
- },
901
- "kernelspec": {
902
- "display_name": "Python 3",
903
- "language": "python",
904
- "name": "python3"
905
- },
906
- "language_info": {
907
- "name": "python",
908
- "version": "3.10.0"
909
- }
910
- },
911
- "nbformat": 4,
912
- "nbformat_minor": 5
913
- }
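The gold demonstrations in the deleted notebook all follow the same action format (`step_type`, `narration`, `ops`, `covered_concepts`, `intent`), and each episode ends with a single `complete` step. A minimal sketch of a structural check for such episodes might look like the following. The helper name `check_episode`, the rule that `complete` appears only as the final step, and the required op keys are assumptions inferred from the gold data shown in this diff, not part of the environment's API.

```python
# Hypothetical helper, not part of the deleted notebook or train.py.
# It checks only the structure of a gold episode, not its score.
REQUIRED_KEYS = {"step_type", "narration", "ops", "covered_concepts", "intent"}
REQUIRED_OP_KEYS = {"op", "target_ids", "params"}


def check_episode(steps, expected_concepts):
    """Return a list of problems; an empty list means the episode looks well-formed."""
    problems = []
    covered = set()
    for i, step in enumerate(steps, start=1):
        missing = REQUIRED_KEYS - step.keys()
        if missing:
            problems.append(f"step {i}: missing keys {sorted(missing)}")
            continue
        for op in step["ops"]:
            if not REQUIRED_OP_KEYS <= op.keys():
                problems.append(f"step {i}: malformed op {op!r}")
        covered.update(step["covered_concepts"])
        # Assumption: 'complete' closes the episode, so it must be last and only last.
        is_last = i == len(steps)
        if (step["step_type"] == "complete") != is_last:
            problems.append(f"step {i}: 'complete' must be the final step and only the final step")
    uncovered = set(expected_concepts) - covered
    if uncovered:
        problems.append(f"concepts never claimed: {sorted(uncovered)}")
    return problems


# Tiny illustrative episode (made up for this sketch, shaped like the stack scenario).
episode = [
    {
        "step_type": "advance",
        "narration": "Pushing A onto the stack.",
        "ops": [{"op": "add_node", "target_ids": ["a"], "params": {"value": "A"}}],
        "covered_concepts": ["push"],
        "intent": "push_a",
    },
    {
        "step_type": "complete",
        "narration": "All operations executed.",
        "ops": [],
        "covered_concepts": [],
        "intent": "summarize",
    },
]
print(check_episode(episode, ["push", "pop"]))
```

Running a check like this over each entry of `GOLD_EPISODES`, with the concept list from the comment above each scenario, would flag a missing key or an unclaimed concept before the episode is replayed through the environment for SFT data.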
train.py DELETED
@@ -1,632 +0,0 @@
1
- """
2
- Visual Reasoning — Training Script
3
- Simulation-based RL for teaching CS algorithms on a whiteboard.
4
-
5
- Stages:
6
- 1. Baseline — score untrained model across all difficulties
7
- 2. SFT — imitate gold demonstrations to learn action format
8
- 3. GRPO — RL with dense environment rewards, easy → expert curriculum
9
- 4. Final eval — delta report comparing all three checkpoints
10
- """
11
-
12
- import subprocess
13
- import sys
14
-
15
-
16
- def install_packages():
17
- # Upgrade torchvision first to match whatever torch version unsloth pulls in
18
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "--upgrade", "torchvision>=0.25.0"])
19
- packages = [
20
- "huggingface_hub",
21
- "unsloth",
22
- "trl",
23
- "datasets",
24
- "transformers",
25
- "accelerate",
26
- "bitsandbytes",
27
- "peft",
28
- "torch",
29
- "openenv-core",
30
- "fastapi",
31
- "uvicorn",
32
- "pydantic",
33
- "python-dotenv",
34
- "networkx",
35
- "shapely",
36
- "sentence-transformers",
37
- "rapidfuzz",
38
- "textstat",
39
- "sortedcontainers",
40
- "numpy<2.0",
41
- ]
42
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + packages)
43
-
44
-
45
- install_packages()
46
-
47
- # ── Section 2: Download Visual Reasoning Environment ─────────────────────────
48
-
49
- import os
50
- from huggingface_hub import snapshot_download
51
-
52
- if not os.path.isdir("visual_reasoning"):
53
- snapshot_download(
54
- repo_id="sreeramajay/visual_reasoning-env",
55
- repo_type="space",
56
- local_dir="visual_reasoning",
57
- ignore_patterns=["*.gitattributes", ".gitignore", "README.md"],
58
- )
59
-
60
- os.chdir("visual_reasoning")
61
-
62
- # ── Section 3: Imports ────────────────────────────────────────────────────────
63
-
64
- import json
65
- import torch
66
- from collections import Counter
67
-
68
- sys.path.insert(0, ".")
69
-
70
- from unsloth import FastLanguageModel
71
- from datasets import Dataset
72
- from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer
73
- from torch.utils.data import SequentialSampler
74
-
75
- from models import VisualReasoningAction
76
- from server.visual_reasoning_environment import VisualReasoningEnvironment
77
- from inference import SYSTEM_PROMPT, build_user_prompt, parse_action, normalize_action
78
-
79
- # ── Section 4: Smoke Test ─────────────────────────────────────────────────────
80
-
81
- test_env = VisualReasoningEnvironment()
82
- obs = test_env.reset(scenario_id="easy_1")
83
- print(f"Scenario: {obs.scenario_id}")
84
- print(f"Goal: {obs.goal}")
85
- print(f"Concepts: {obs.concept_checklist}")
86
- print(f"Step budget: {obs.remaining_step_budget}")
87
-
88
- obs = test_env.step(
89
- VisualReasoningAction(
90
- step_type="advance",
91
- narration="Adding the first node with value 10.",
92
- ops=[{"op": "add_node", "target_ids": ["n0"], "params": {"value": 10}}],
93
- covered_concepts=["node_value"],
94
- intent="test",
95
- )
96
- )
97
- print(f"\nAfter step: entities={list(obs.entities.keys())}, reward={obs.reward:.3f}, error={obs.action_error}")
98
- print("Environment OK!")
99
- del test_env
100
-
101
- # ── Section 5: Configuration ──────────────────────────────────────────────────
102
-
103
- MODEL_NAME = "unsloth/Qwen2.5-3B-Instruct-bnb-4bit"
104
- MAX_SEQ_LENGTH = 4096
105
- LORA_R = 16
106
- LORA_ALPHA = 32
107
- SFT_EPOCHS = 3
108
- GRPO_EPOCHS = 2
109
-
110
- SCENARIOS = {
111
- "easy": ["easy_1", "easy_2", "easy_3", "gen_easy_1001", "gen_easy_1002"],
112
- "medium": ["medium_1", "medium_2", "gen_medium_2001", "gen_medium_2002"],
113
- "hard": ["hard_1", "hard_2", "gen_hard_3001", "gen_hard_3002"],
114
- "expert": ["expert_1", "expert_2", "gen_expert_4001", "gen_expert_4002"],
115
- }
116
- DIFFICULTIES = ("easy", "medium", "hard", "expert")
117
-
118
- # ── Section 6: Load Model (Unsloth + LoRA) ────────────────────────────────────
119
-
120
- model, tokenizer = FastLanguageModel.from_pretrained(
121
- model_name=MODEL_NAME,
122
- max_seq_length=MAX_SEQ_LENGTH,
123
- dtype=None,
124
- load_in_4bit=True,
125
- )
126
- model = FastLanguageModel.get_peft_model(
127
- model,
128
- r=LORA_R,
129
- lora_alpha=LORA_ALPHA,
130
- lora_dropout=0,
131
- target_modules=[
132
- "q_proj", "k_proj", "v_proj", "o_proj",
133
- "gate_proj", "up_proj", "down_proj",
134
- ],
135
- bias="none",
136
- use_gradient_checkpointing="unsloth",
137
- random_state=0,
138
- )
139
- if tokenizer.pad_token_id is None:
140
- tokenizer.pad_token = tokenizer.eos_token
141
-
142
- model.generation_config.max_length = None
143
-
144
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
145
- total = sum(p.numel() for p in model.parameters())
146
- print(f"Parameters: {total / 1e6:.1f}M total, {trainable / 1e6:.2f}M trainable ({trainable / total:.2%})")
147
-
148
- # ── Section 7: Environment Wrapper & Helpers ──────────────────────────────────
149
-
150
-
151
- class EnvRunner:
152
- """Thin wrapper over VisualReasoningEnvironment for a clean reset/step API."""
153
-
154
- def __init__(self):
155
- self.env = VisualReasoningEnvironment()
156
-
157
- def reset(self, scenario_id=None, task_name=None):
158
- return self.env.reset(scenario_id=scenario_id, task_name=task_name)
159
-
160
- def step(self, action_dict):
161
- act = VisualReasoningAction(**action_dict)
162
- obs = self.env.step(act)
163
- return obs, float(obs.reward), bool(obs.done)
164
-
165
-
166
- def generate_action(model, tokenizer, obs, last_action=None, last_reward=0.0, history=None):
167
- """Generate one action from the model given an observation."""
168
- FastLanguageModel.for_inference(model)
169
- user_prompt = build_user_prompt(obs, last_action, last_reward, history or [])
170
- messages = [
171
- {"role": "system", "content": SYSTEM_PROMPT},
172
- {"role": "user", "content": user_prompt},
173
- ]
174
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
175
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
176
- with torch.no_grad():
177
- out = model.generate(
178
- **inputs,
179
- max_new_tokens=384,
180
- do_sample=False,
181
- pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
182
- )
183
- text = tokenizer.decode(out[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)
184
- return normalize_action(parse_action(text) or {}), text
185
-
186
-
187
- FALLBACK_ACTION = {
188
- "step_type": "complete",
189
- "narration": "Explanation complete.",
190
- "ops": [],
191
- "covered_concepts": [],
192
- "intent": "finalize",
193
- }
194
-
195
-
196
- def run_episode(env, scenario_id, model, tokenizer, max_steps=24):
197
- """Run a full episode. Returns (final_score, steps_taken)."""
198
- obs = env.reset(scenario_id=scenario_id)
199
- last_action, last_reward, history = None, 0.0, []
200
- steps = 0
201
- for step_num in range(1, max_steps + 1):
202
- action, _ = generate_action(model, tokenizer, obs, last_action, last_reward, history)
203
- if action is None:
204
- action = FALLBACK_ACTION
205
- obs, reward, done = env.step(action)
206
- last_action, last_reward = action, reward
207
- history.append(f"Step {step_num}: {action.get('narration', '')}")
208
- steps = step_num
209
- if done:
210
- break
211
- return float(obs.score_breakdown.get("overall_score", 0.0)), steps
212
-
213
-
214
- def evaluate(model, tokenizer, env, label, stages=None):
215
- """Evaluate across all scenarios. Returns per-difficulty + overall scores."""
216
- stages = stages or list(DIFFICULTIES)
217
- print(f"\n{'=' * 60}")
218
- print(f" {label}")
219
- print(f"{'=' * 60}")
220
- results = {}
221
- for diff in stages:
222
- scores = []
223
- for sid in SCENARIOS.get(diff, []):
224
- score, steps = run_episode(env, sid, model, tokenizer)
225
- scores.append(score)
226
- print(f" [{diff:6}] {sid:22} score={score:.3f} steps={steps}")
227
- mean = sum(scores) / max(len(scores), 1)
228
- results[diff] = mean
229
- print(f" [{diff:6}] mean={mean:.3f}")
230
- overall = sum(results.values()) / max(len(results), 1)
231
- results["overall"] = overall
232
- print(f" OVERALL: {overall:.3f}")
233
- return results
234
-
235
-
236
- env = EnvRunner()
237
- print("EnvRunner ready.")
238
-
239
- # ── Section 8: Baseline Evaluation ───────────────────────────────────────────
240
-
241
- baseline = evaluate(model, tokenizer, env, "Baseline (untrained LoRA)")
242
-
243
- # ── Section 9: Gold Demonstrations ───────────────────────────────────────────
244
-
245
- GOLD_EPISODES = {
246
- # ── easy_1: linked_list_traversal ──────────────────────────────────
247
- # input_data: {"values": [10, 20, 30]}
248
- # concepts: head_pointer, node_value, next_link, tail_marker
249
- "easy_1": [
250
- {
251
- "step_type": "advance",
252
- "narration": "Building a linked list — creating a centered layout region and the first two nodes with values 10 and 20.",
253
- "ops": [
254
- {"op": "add_region", "target_ids": ["list"], "params": {"style": "array", "title": "Linked List", "position": "center"}},
255
- {"op": "add_node", "target_ids": ["n1"], "params": {"value": 10, "region": "list"}},
256
- {"op": "add_node", "target_ids": ["n2"], "params": {"value": 20, "region": "list"}},
257
- ],
258
- "covered_concepts": ["node_value"],
259
- "intent": "create_list_start",
260
- },
261
- {
262
- "step_type": "advance",
263
- "narration": "Adding node 30 and connecting all nodes with next links to form the chain 10 -> 20 -> 30.",
264
- "ops": [
265
- {"op": "add_node", "target_ids": ["n3"], "params": {"value": 30, "region": "list"}},
266
- {"op": "add_edge", "target_ids": ["n1", "n2"], "params": {"kind": "directed", "label": "next"}},
267
- {"op": "add_edge", "target_ids": ["n2", "n3"], "params": {"kind": "directed", "label": "next"}},
268
- ],
269
- "covered_concepts": ["next_link"],
270
- "intent": "connect_nodes",
271
- },
272
- {
273
- "step_type": "advance",
274
- "narration": "Placing a head pointer at node 10 because traversal always starts at the head, our only entry point into the list.",
275
- "ops": [
276
- {"op": "add_pointer", "target_ids": ["head_ptr"], "params": {"region": "list"}},
277
- {"op": "move_pointer", "target_ids": ["head_ptr"], "params": {"index": "n1"}},
278
- {"op": "annotate", "target_ids": ["n1"], "params": {"text": "Head"}},
279
- {"op": "set_role", "target_ids": ["n1"], "params": {"role": "current"}},
280
- ],
281
- "covered_concepts": ["head_pointer"],
282
- "intent": "mark_head",
283
- },
284
- {
285
- "step_type": "advance",
286
- "narration": "Following the next link from 10 to 20 — the pointer advances and we mark node 10 as visited.",
287
- "ops": [
288
- {"op": "set_role", "target_ids": ["n1"], "params": {"role": "visited"}},
289
- {"op": "set_role", "target_ids": ["n2"], "params": {"role": "current"}},
290
- {"op": "move_pointer", "target_ids": ["head_ptr"], "params": {"index": "n2"}},
291
- ],
292
- "covered_concepts": [],
293
- "intent": "traverse_to_second",
294
- },
295
- {
296
- "step_type": "advance",
297
- "narration": "Reaching node 30 — it has no next link, making it the tail that signals the end of traversal.",
298
- "ops": [
299
- {"op": "set_role", "target_ids": ["n2"], "params": {"role": "visited"}},
300
- {"op": "set_role", "target_ids": ["n3"], "params": {"role": "current"}},
301
- {"op": "annotate", "target_ids": ["n3"], "params": {"text": "Tail"}},
302
- ],
303
- "covered_concepts": ["tail_marker"],
304
- "intent": "reach_tail",
305
- },
306
- {
307
- "step_type": "complete",
308
- "narration": "Traversal complete — visited every node from head to tail following next links, reading values 10, 20, 30 in order.",
309
- "ops": [{"op": "set_role", "target_ids": ["n3"], "params": {"role": "done"}}],
310
- "covered_concepts": [],
311
- "intent": "summarize",
312
- },
313
- ],
314
-
315
- # ── easy_2: stack_ops ──────────────────────────────────────────────
316
- # input_data: {"operations": ["push A", "push B", "pop", "push C"]}
317
- # concepts: top_pointer, push, pop, lifo_order
318
- "easy_2": [
319
- {
320
- "step_type": "advance",
321
- "narration": "Setting up a stack with a centered visual region and a container to track push and pop membership.",
322
- "ops": [
323
- {"op": "add_region", "target_ids": ["stack_area"], "params": {"style": "stack", "title": "Stack", "position": "center"}},
324
- {"op": "add_container", "target_ids": ["stk"], "params": {"region": "stack_area", "ordered": False, "title": "Stack"}},
325
- ],
326
- "covered_concepts": [],
327
- "intent": "setup_stack",
328
- },
329
- {
330
- "step_type": "advance",
331
- "narration": "Pushing A onto the stack — A becomes the first element. Adding a top pointer to track the stack top.",
332
- "ops": [
333
- {"op": "add_node", "target_ids": ["a"], "params": {"value": "A", "region": "stack_area"}},
334
- {"op": "push_to", "target_ids": ["stk", "a"], "params": {}},
335
- {"op": "add_pointer", "target_ids": ["top"], "params": {"region": "stack_area"}},
336
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": "a"}},
337
- ],
338
- "covered_concepts": ["push", "top_pointer"],
339
- "intent": "push_a",
340
- },
341
- {
342
- "step_type": "advance",
343
- "narration": "Pushing B — B sits on top of A and the top pointer moves up to B.",
344
- "ops": [
345
- {"op": "add_node", "target_ids": ["b"], "params": {"value": "B", "region": "stack_area"}},
346
- {"op": "push_to", "target_ids": ["stk", "b"], "params": {}},
347
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": "b"}},
348
- ],
349
- "covered_concepts": [],
350
- "intent": "push_b",
351
- },
352
- {
353
- "step_type": "advance",
354
- "narration": "Popping from the stack — B was pushed last so B comes off first, demonstrating LIFO (last-in-first-out) order.",
355
- "ops": [
356
- {"op": "pop_from", "target_ids": ["stk"], "params": {}},
357
- {"op": "set_role", "target_ids": ["b"], "params": {"role": "inactive"}},
358
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": "a"}},
359
- ],
360
- "covered_concepts": ["pop", "lifo_order"],
361
- "intent": "pop_b",
362
- },
363
- {
364
- "step_type": "advance",
365
- "narration": "Pushing C onto the stack — C now sits on top of A, with B already removed.",
366
- "ops": [
367
- {"op": "add_node", "target_ids": ["c"], "params": {"value": "C", "region": "stack_area"}},
368
- {"op": "push_to", "target_ids": ["stk", "c"], "params": {}},
369
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": "c"}},
370
- ],
371
- "covered_concepts": [],
372
- "intent": "push_c",
373
- },
374
- {
375
- "step_type": "complete",
376
- "narration": "All four operations executed — stack holds A at bottom and C on top after push A, push B, pop, push C.",
377
- "ops": [],
378
- "covered_concepts": [],
379
- "intent": "summarize",
380
- },
381
- ],
382
-
383
- # ── easy_3: binary_search ──────────────────────────────────────────
384
- # input_data: {"array": [1, 3, 5, 7, 9, 11, 13], "target": 7}
385
- # concepts: sorted_invariant, low_pointer, high_pointer, mid_pointer, comparison
386
- "easy_3": [
387
- {
388
- "step_type": "advance",
389
- "narration": "Creating the first three elements of a sorted array in a centered region — the sorted invariant is what makes binary search possible.",
390
- "ops": [
391
- {"op": "add_region", "target_ids": ["arr"], "params": {"style": "array", "title": "Sorted Array", "position": "center"}},
392
- {"op": "add_node", "target_ids": ["n0"], "params": {"value": 1, "region": "arr"}},
393
- {"op": "add_node", "target_ids": ["n1"], "params": {"value": 3, "region": "arr"}},
394
- {"op": "add_node", "target_ids": ["n2"], "params": {"value": 5, "region": "arr"}},
395
- ],
396
- "covered_concepts": ["sorted_invariant"],
397
- "intent": "create_array_part1",
398
- },
399
- {
400
- "step_type": "advance",
401
- "narration": "Adding the remaining elements 7, 9, 11, 13 to complete all seven values of the sorted array.",
402
- "ops": [
403
- {"op": "add_node", "target_ids": ["n3"], "params": {"value": 7, "region": "arr"}},
404
- {"op": "add_node", "target_ids": ["n4"], "params": {"value": 9, "region": "arr"}},
405
- {"op": "add_node", "target_ids": ["n5"], "params": {"value": 11, "region": "arr"}},
406
- {"op": "add_node", "target_ids": ["n6"], "params": {"value": 13, "region": "arr"}},
407
- ],
408
- "covered_concepts": [],
409
- "intent": "create_array_part2",
410
- },
411
- {
412
- "step_type": "advance",
413
- "narration": "Placing low pointer at index 0 (value 1) and high pointer at index 6 (value 13) to bracket the search range.",
414
- "ops": [
415
- {"op": "add_pointer", "target_ids": ["low"], "params": {"region": "arr"}},
416
- {"op": "move_pointer", "target_ids": ["low"], "params": {"index": "n0"}},
417
- {"op": "add_pointer", "target_ids": ["high"], "params": {"region": "arr"}},
418
- {"op": "move_pointer", "target_ids": ["high"], "params": {"index": "n6"}},
419
- ],
420
- "covered_concepts": ["low_pointer", "high_pointer"],
421
- "intent": "init_pointers",
422
- },
423
- {
424
- "step_type": "advance",
425
- "narration": "Computing mid = (0+6)/2 = 3 — the mid pointer lands on value 7, which we compare against our target 7.",
426
- "ops": [
427
- {"op": "add_pointer", "target_ids": ["mid"], "params": {"region": "arr"}},
428
- {"op": "move_pointer", "target_ids": ["mid"], "params": {"index": "n3"}},
429
- {"op": "set_role", "target_ids": ["n3"], "params": {"role": "current"}},
430
- {"op": "highlight", "target_ids": ["n3"], "params": {}},
431
- ],
432
- "covered_concepts": ["mid_pointer", "comparison"],
433
- "intent": "compute_mid_and_compare",
434
- },
435
- {
436
- "step_type": "complete",
437
- "narration": "Target 7 found at index 3 — binary search located it in one comparison because the first midpoint happened to hold the target; in general the sorted invariant lets each comparison halve the search space.",
438
- "ops": [
439
- {"op": "set_role", "target_ids": ["n3"], "params": {"role": "done"}},
440
- {"op": "annotate", "target_ids": ["n3"], "params": {"text": "Found: 7"}},
441
- ],
442
- "covered_concepts": [],
443
- "intent": "found_target",
444
- },
445
- ],
446
- }
447
-
448
- print(f"Gold episodes: {len(GOLD_EPISODES)} scenarios, {sum(len(v) for v in GOLD_EPISODES.values())} total steps")
449
-
450
- # ── Section 10: SFT Warmup ────────────────────────────────────────────────────
451
-
452
-
453
- def generate_sft_data(env, gold_episodes, tokenizer):
454
- """Replay gold episodes through the env, collecting chat-formatted training data."""
455
- rows = []
456
- for scenario_id, actions in gold_episodes.items():
457
- obs = env.reset(scenario_id=scenario_id)
458
- last_action, last_reward, history = None, 0.0, []
459
- for i, action in enumerate(actions):
460
- user_prompt = build_user_prompt(obs, last_action, last_reward, history)
461
- messages = [
462
- {"role": "system", "content": SYSTEM_PROMPT},
463
- {"role": "user", "content": user_prompt},
464
- {"role": "assistant", "content": json.dumps(action, separators=(",", ":"))},
465
- ]
466
- text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
467
- rows.append({"text": text})
468
- obs, reward, done = env.step(action)
469
- last_action, last_reward = action, reward
470
- history.append(f"Step {i + 1}: {action.get('narration', '')}")
471
- if done:
472
- break
473
- return Dataset.from_list(rows)
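A minimal sketch of how one gold step could become a chat-formatted SFT example, mirroring the message layout used by `generate_sft_data`. The prompt text and the `build_messages` helper are hypothetical stand-ins: the script itself takes `SYSTEM_PROMPT` and `build_user_prompt` from `inference.py` and then renders the messages with `tokenizer.apply_chat_template`.

```python
import json

# Hypothetical placeholder; the real SYSTEM_PROMPT is imported from inference.py.
SYSTEM_PROMPT = "You are a visual CS teacher. Reply with one JSON action."


def build_messages(user_prompt, action):
    """Return the three-turn chat for one gold step (illustrative helper)."""
    # Compact separators drop the spaces json.dumps inserts by default,
    # so the assistant target is shorter.
    target = json.dumps(action, separators=(",", ":"))
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": target},
    ]


gold_step = {
    "step_type": "complete",
    "narration": "Done.",
    "ops": [],
    "covered_concepts": [],
    "intent": "summarize",
}
messages = build_messages("Scenario: easy_1", gold_step)
print(messages[-1]["content"])
```

The assistant turn is the supervised target, so whatever serialization is chosen here is the format the model learns to emit.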
474
-
475
-
476
- sft_data = generate_sft_data(env, GOLD_EPISODES, tokenizer)
477
- print(f"SFT training examples: {len(sft_data)}")
478
-
479
- FastLanguageModel.for_training(model)
480
-
481
- sft_config = SFTConfig(
482
- output_dir="/tmp/vr_sft_scratch",
483
- num_train_epochs=SFT_EPOCHS,
484
- per_device_train_batch_size=2,
485
- gradient_accumulation_steps=4,
486
- learning_rate=2e-4,
487
- lr_scheduler_type="cosine",
488
- warmup_ratio=0.03,
489
- logging_steps=5,
490
- save_strategy="no",
491
- fp16=True,
492
- max_seq_length=MAX_SEQ_LENGTH,
493
- dataset_text_field="text",
494
- optim="adamw_8bit",
495
- report_to="none",
496
- )
497
-
498
- sft_trainer = SFTTrainer(
499
- model=model,
500
- processing_class=tokenizer,
501
- args=sft_config,
502
- train_dataset=sft_data,
503
- )
504
-
505
- print("\nTraining SFT...")
506
- sft_trainer.train()
507
- print("SFT complete!")
508
-
509
- # ── Section 11: Post-SFT Evaluation ──────────────────────────────────────────
510
-
511
- sft_results = evaluate(model, tokenizer, env, "Post-SFT")
512
-
513
- # ── Section 12: GRPO Training ─────────────────────────────────────────────────
514
-
515
-
516
- def generate_grpo_prompts(env, stages, samples_per_scenario=2):
517
- """Collect initial-state prompts for GRPO training."""
518
- rows = []
519
- for stage in stages:
520
- for sid in SCENARIOS[stage]:
521
- for _ in range(samples_per_scenario):
522
- obs = env.reset(scenario_id=sid)
523
- user_prompt = build_user_prompt(obs, None, 0.0, [])
524
- messages = [
525
- {"role": "system", "content": SYSTEM_PROMPT},
526
- {"role": "user", "content": user_prompt},
527
- ]
528
- rows.append({"prompt": messages, "scenario_id": sid})
529
- return Dataset.from_list(rows)
530
-
531
-
532
- def make_reward_fn(env):
533
- """Reward function: parse model completion, step in env, return overall_score."""
534
- state = {"calls": 0, "hist": Counter()}
535
-
536
- def reward_fn(completions, scenario_id=None, **_):
537
- texts = []
538
- for c in completions:
539
- if isinstance(c, list):
540
- texts.append(c[-1].get("content", "") if c else "")
541
- else:
542
- texts.append(str(c))
543
-
544
- sids = scenario_id if isinstance(scenario_id, list) else [scenario_id] * len(texts)
545
- if len(sids) < len(texts):
546
- n_gen = len(texts) // len(sids)
547
- sids = [s for s in sids for _ in range(n_gen)]
548
-
549
- rewards = []
550
- for sid, text in zip(sids, texts):
551
- obs = env.reset(scenario_id=sid)
552
- action = normalize_action(parsed) if (parsed := parse_action(text)) else None  # None marks unparseable output for the check below
553
- if action is None:
554
- rewards.append(0.0)
555
- state["hist"]["<unparseable>"] += 1
556
- continue
557
- obs, _, _ = env.step(action)
558
- score = float(obs.score_breakdown.get("overall_score", 0.0))
559
- rewards.append(score)
560
- state["hist"][action.get("step_type", "?")] += 1
561
-
562
- state["calls"] += 1
563
- if state["calls"] % 5 == 0:
564
- print(f" [reward] call={state['calls']} types={dict(state['hist'])}")
565
- return rewards
566
-
567
- return reward_fn
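A minimal sketch of the scenario-id alignment step inside `reward_fn`, pulled out as a standalone function. The name `expand_scenario_ids` is hypothetical; the logic follows the lines above and assumes completions arrive grouped by prompt, with the same number of generations per prompt.

```python
def expand_scenario_ids(scenario_id, n_texts):
    """Return one scenario id per completion (illustrative helper)."""
    # A single id (or None) is broadcast to every completion.
    sids = scenario_id if isinstance(scenario_id, list) else [scenario_id] * n_texts
    # Fewer ids than completions: repeat each id once per generation,
    # keeping the completions of one prompt adjacent.
    if len(sids) < n_texts:
        n_gen = n_texts // len(sids)
        sids = [s for s in sids for _ in range(n_gen)]
    return sids


print(expand_scenario_ids(["easy_1", "easy_2"], 4))
print(expand_scenario_ids("easy_3", 3))
```

If the trainer already passes one id per completion, the list is returned unchanged.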
568
-
569
-
570
- grpo_data = generate_grpo_prompts(env, list(DIFFICULTIES))
571
- print(f"GRPO training prompts: {len(grpo_data)}")
572
-
573
- FastLanguageModel.for_training(model)
574
-
575
- grpo_config = GRPOConfig(
576
- output_dir="/tmp/vr_grpo_scratch",
577
- num_train_epochs=GRPO_EPOCHS,
578
- per_device_train_batch_size=2,
579
- gradient_accumulation_steps=4,
580
- num_generations=4,
581
- max_completion_length=384,
582
- learning_rate=1e-5,
583
- lr_scheduler_type="cosine",
584
- warmup_ratio=0.1,
585
- beta=0.05,
586
- max_grad_norm=0.5,
587
- temperature=0.9,
588
- logging_steps=1,
589
- save_strategy="no",
590
- bf16=True,
591
- optim="adamw_8bit",
592
- report_to="none",
593
- remove_unused_columns=False,
594
- )
595
-
596
-
597
- class CurriculumGRPOTrainer(GRPOTrainer):
598
- """Preserve easy -> expert ordering by disabling dataset shuffle."""
599
-
600
- def _get_train_sampler(self, *_args, **_kwargs):
601
- return SequentialSampler(self.train_dataset)
602
-
603
-
604
- grpo_trainer = CurriculumGRPOTrainer(
605
- model=model,
606
- processing_class=tokenizer,
607
- args=grpo_config,
608
- train_dataset=grpo_data,
609
- reward_funcs=make_reward_fn(env),
610
- )
611
-
612
- print("Training GRPO...")
613
- grpo_trainer.train()
614
- print("GRPO complete!")
615
-
616
- # ── Section 13: Final Evaluation + Delta Report ───────────────────────────────
617
-
618
- final_results = evaluate(model, tokenizer, env, "Final (SFT + GRPO)")
619
-
620
- print(f"\n{'=' * 60}")
621
- print(" DELTA REPORT")
622
- print(f"{'=' * 60}")
623
- print(f" {'Difficulty':<12} {'Baseline':>10} {'SFT':>10} {'SFT+GRPO':>10}")
624
- print(f" {'-' * 12} {'-' * 10} {'-' * 10} {'-' * 10}")
625
- for diff in list(DIFFICULTIES) + ["overall"]:
626
- b = baseline.get(diff, 0.0)
627
- s = sft_results.get(diff, 0.0)
628
- f = final_results.get(diff, 0.0)
629
- label = diff.upper() if diff == "overall" else diff
630
- print(f" {label:<12} {b:>10.3f} {s:>10.3f} {f:>10.3f}")
631
- print(f"{'=' * 60}")
632
- print("\nDone. Model was NOT saved (in-memory only).")
train_hf.py DELETED
@@ -1,771 +0,0 @@
1
- """Visual Reasoning — HuggingFace Training Job
2
-
3
- Train Qwen3-8B to be an expert visual CS teacher via SFT warmup + staged GRPO.
4
- Run with: hf jobs run --flavor a100 train_hf.py
5
-
6
- Speedups over train.ipynb:
7
- - NARRATION_SCORER=fallback → CPU heuristic, no GPU contention
8
- - Batched generation → N episodes in one model.generate() call
9
- - Early termination → kill episodes after 3 consecutive no-ops
10
- - snapshot_download → faster than git clone for large repos
11
- """
12
-
13
- # ── 0. Install dependencies before any imports ──────────────────────────────
14
-
15
- import subprocess, sys
16
-
17
- def pip_install(*packages):
18
- subprocess.check_call(
19
- [sys.executable, "-m", "pip", "install", "-q", *packages],
20
- stdout=subprocess.DEVNULL,
21
- )
22
-
23
- print("[0/7] Installing dependencies...")
24
- pip_install(
25
- "unsloth", "trl", "datasets", "transformers", "accelerate",
26
- "bitsandbytes", "peft", "torch",
27
- )
28
- pip_install(
29
- "openenv-core", "fastapi", "uvicorn", "pydantic",
30
- )
31
- pip_install(
32
- "python-dotenv", "networkx", "shapely", "sentence-transformers",
33
- "rapidfuzz", "textstat", "sortedcontainers", "huggingface_hub",
34
- )
35
- print("[0/7] Dependencies installed.")
36
-
37
- # ── 1. Imports (unsloth first to patch transformers early) ──────────────────
38
-
39
- import os
40
- os.environ["NARRATION_SCORER"] = "fallback" # CPU scorer, no GPU contention
41
-
42
- from unsloth import FastLanguageModel # must be before transformers
43
-
44
- import json
45
- import time
46
- import torch
47
- from collections import Counter
48
- from datasets import Dataset
49
- from huggingface_hub import snapshot_download
50
- from torch.utils.data import SequentialSampler
51
- from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer
52
-
53
-
54
- # ── 2. Config ───────────────────────────────────────────────────────────────
55
-
56
- # Model
57
- MODEL_NAME = "unsloth/Qwen3-8B-unsloth-bnb-4bit"
58
- MAX_SEQ_LENGTH = 4096
59
- LORA_R = 32
60
- LORA_ALPHA = 64
61
-
62
- # SFT
63
- SFT_EPOCHS = 3
64
- SFT_LR = 2e-4
65
- SFT_BATCH_SIZE = 2
66
- SFT_GRAD_ACCUM = 4
67
-
68
- # GRPO
69
- GRPO_LR = 5e-6
70
- GRPO_STAGE1_EPOCHS = 2
71
- GRPO_STAGE2_EPOCHS = 2
72
- GRPO_NUM_GENERATIONS = 4
73
- GRPO_BATCH_SIZE = 2
74
- GRPO_GRAD_ACCUM = 4
75
- GRPO_TEMPERATURE = 0.9
76
-
77
- # Scenario counts for procedural generation
78
- GRPO_SCENARIOS_PER_DIFFICULTY = {"easy": 40, "medium": 30, "hard": 20, "expert": 10}
79
- DIFFICULTIES = ("easy", "medium", "hard", "expert")
80
-
81
- # Eval
82
- EVAL_BATCH_SIZE = 8
83
- EVAL_MAX_STEPS = 24
84
- NOOP_EARLY_STOP = 3
85
-
86
- # Hub — HF_TOKEN is set by the HuggingFace jobs runtime
87
- HUB_REPO = None # set to "username/model-name" to push
88
- HF_TOKEN = os.environ.get("HF_TOKEN")
89
-
90
- # Static scenarios for evaluation
91
- STATIC_SCENARIOS = {
92
- "easy": ["easy_1", "easy_2", "easy_3"],
93
- "medium": ["medium_1", "medium_2"],
94
- "hard": ["hard_1", "hard_2"],
95
- "expert": ["expert_1", "expert_2"],
96
- }
97
-
98
-
99
- # ── 3. Download environment repo ───────────────────────────────────────────
100
-
101
- print("[1/7] Downloading visual_reasoning environment...")
102
- env_path = snapshot_download(
103
- repo_id="sreeramajay/visual_reasoning-env",
104
- repo_type="space",
105
- local_dir="visual_reasoning",
106
- token=HF_TOKEN,
107
- )
108
- sys.path.insert(0, env_path)
109
- print(f"[1/7] Environment downloaded to {env_path}")
110
-
111
- from models import VisualReasoningAction, VisualReasoningObservation
112
- from server.visual_reasoning_environment import VisualReasoningEnvironment
113
- from server.scenario_generator import generate_scenario
114
- from inference import (
115
- SYSTEM_PROMPT, build_user_prompt, parse_action,
116
- normalize_action as inf_normalize_action,
117
- )
118
-
119
- # Smoke test
120
- _test_env = VisualReasoningEnvironment()
121
- _obs = _test_env.reset(scenario_id="easy_1")
122
- _obs = _test_env.step(VisualReasoningAction(
123
- step_type="advance",
124
- narration="Adding the first node with value 10.",
125
- ops=[{"op": "add_node", "target_ids": ["n0"], "params": {"value": 10}}],
126
- covered_concepts=["node_value"], intent="test",
127
- ))
128
- assert _obs.reward != 0.0, "Environment smoke test failed"
129
- del _test_env, _obs
130
- print("[1/7] Environment smoke test passed.")
131
-
132
-
133
- # ── 4. Load model ──────────────────────────────────────────────────────────
134
-
135
- print("[2/7] Loading model...")
136
- t0 = time.time()
137
-
138
- model, tokenizer = FastLanguageModel.from_pretrained(
139
- model_name=MODEL_NAME,
140
- max_seq_length=MAX_SEQ_LENGTH,
141
- dtype=None,
142
- load_in_4bit=True,
143
- )
144
- model = FastLanguageModel.get_peft_model(
145
- model,
146
- r=LORA_R,
147
- lora_alpha=LORA_ALPHA,
148
- lora_dropout=0,
149
- target_modules=[
150
- "q_proj", "k_proj", "v_proj", "o_proj",
151
- "gate_proj", "up_proj", "down_proj",
152
- ],
153
- bias="none",
154
- use_gradient_checkpointing="unsloth",
155
- random_state=0,
156
- )
157
- if tokenizer.pad_token_id is None:
158
- tokenizer.pad_token = tokenizer.eos_token
159
- model.generation_config.max_length = None
160
-
161
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
162
- total = sum(p.numel() for p in model.parameters())
163
- print(f"[2/7] Model loaded in {time.time() - t0:.0f}s — "
164
- f"{total/1e6:.0f}M params, {trainable/1e6:.1f}M trainable ({trainable/total:.1%})")
165
-
166
-
167
- # ── 5. Batched evaluation engine ───────────────────────────────────────────
168
-
169
- FALLBACK_ACTION = {
170
- "step_type": "complete", "narration": "Explanation complete.",
171
- "ops": [], "covered_concepts": [], "intent": "finalize",
172
- }
173
-
174
- # Reusable env pool — avoids repeated __init__ / warmup_scorer overhead
175
- _env_pool = []
176
-
177
- def _get_env(idx):
178
- while len(_env_pool) <= idx:
179
- _env_pool.append(VisualReasoningEnvironment())
180
- return _env_pool[idx]
181
-
182
-
183
- class EpisodeState:
184
- """Tracks one in-flight episode for batched eval."""
185
-
186
- def __init__(self, env, scenario_id):
187
- self.env = env
188
- self.scenario_id = scenario_id
189
- self.obs = env.reset(scenario_id=scenario_id)
190
- self.last_action = None
191
- self.last_reward = 0.0
192
- self.history = []
193
- self.steps = 0
194
- self.done = False
195
- self.score = 0.0
196
- self.consecutive_noops = 0
197
-
198
- def build_prompt_text(self):
199
- user_prompt = build_user_prompt(
200
- self.obs, self.last_action, self.last_reward, self.history
201
- )
202
- messages = [
203
- {"role": "system", "content": SYSTEM_PROMPT},
204
- {"role": "user", "content": user_prompt},
205
- ]
206
- return tokenizer.apply_chat_template(
207
- messages, tokenize=False, add_generation_prompt=True
208
- )
209
-
210
- def apply_action(self, action_dict):
211
- if action_dict is None:
212
- action_dict = FALLBACK_ACTION
213
- self.obs = self.env.step(VisualReasoningAction(**action_dict))
214
- reward = float(self.obs.reward)
215
- self.last_action = action_dict
216
- self.last_reward = reward
217
- self.steps += 1
218
- self.history.append(f"Step {self.steps}: {action_dict.get('narration', '')}")
219
- if reward <= -0.04:
220
- self.consecutive_noops += 1
221
- else:
222
- self.consecutive_noops = 0
223
- if self.obs.done or self.consecutive_noops >= NOOP_EARLY_STOP:
224
- self.done = True
225
- self.score = float(self.obs.score_breakdown.get("overall_score", 0.0))
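A minimal sketch of the early-termination rule in `apply_action`, assuming (as the code above does) that a reward at or below -0.04 marks a no-op step. The function name `steps_until_stop` and its parameters are hypothetical; the defaults echo the -0.04 threshold and `NOOP_EARLY_STOP = 3`.

```python
def steps_until_stop(rewards, noop_threshold=-0.04, patience=3):
    """Return how many steps run before the no-op streak ends the episode."""
    streak = 0
    for step, reward in enumerate(rewards, start=1):
        # A productive step resets the streak; a no-op extends it.
        streak = streak + 1 if reward <= noop_threshold else 0
        if streak >= patience:
            return step
    return len(rewards)


rewards = [0.1, -0.05, -0.05, 0.2, -0.05, -0.05, -0.05, 0.3]
print(steps_until_stop(rewards))
```

Only consecutive no-ops count, so the two penalised steps early in the list do not end the episode once a positive reward follows them.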
226
-
227
-
228
- def batched_generate(prompt_texts, max_new_tokens=384):
229
- """Single batched model.generate() call for all active episodes."""
230
- FastLanguageModel.for_inference(model)
231
- inputs = tokenizer(
232
- prompt_texts,
233
- return_tensors="pt",
234
- padding=True,
235
- truncation=True,
236
- max_length=MAX_SEQ_LENGTH,
237
- ).to(model.device)
238
- with torch.no_grad():
239
- outputs = model.generate(
240
- **inputs,
241
- max_new_tokens=max_new_tokens,
242
- do_sample=False,
243
- pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
244
- )
245
- results = []
246
- for i, out in enumerate(outputs):
247
- input_len = inputs.input_ids.shape[1]  # generate() output starts with the padded prompt, so new tokens begin at the padded width
248
- results.append(tokenizer.decode(out[input_len:], skip_special_tokens=True))
249
- return results
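A minimal sketch, in plain Python lists rather than tensors, of where the continuation starts in a padded batch. `fake_generate` is a hypothetical stand-in for `model.generate()`, which returns each padded prompt row followed by its new tokens. Slicing at the padded width recovers the continuation for every row, while slicing at a per-row non-pad count leaks prompt tokens for shorter prompts.

```python
PAD = 0


def left_pad(seqs, pad=PAD):
    """Left-pad token lists to a common width."""
    width = max(len(s) for s in seqs)
    return [[pad] * (width - len(s)) + s for s in seqs]


def fake_generate(padded, continuations):
    """Stand-in for model.generate(): padded prompt followed by new tokens."""
    return [row + cont for row, cont in zip(padded, continuations)]


def slice_by_width(padded, outputs):
    """Cut every row at the shared padded prompt width."""
    width = len(padded[0])
    return [out[width:] for out in outputs]


def slice_by_nonpad_count(padded, outputs, pad=PAD):
    """Cut each row at its own count of non-pad prompt tokens."""
    return [out[sum(t != pad for t in row):] for row, out in zip(padded, outputs)]


padded = left_pad([[5, 6, 7], [8]])
outputs = fake_generate(padded, [[11, 12], [21, 22]])
print(slice_by_width(padded, outputs))
print(slice_by_nonpad_count(padded, outputs))
```

The two slices agree only for the longest prompt in the batch, which is why a single-episode test would not reveal the difference.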
250
-
251
-
252
- def evaluate(label, stages=None, scenarios=None):
253
- """Batched eval: runs EVAL_BATCH_SIZE episodes in parallel per generate call."""
254
- stages = stages or list(DIFFICULTIES)
255
- scenarios = scenarios or STATIC_SCENARIOS
256
-
257
- print(f"\n{'=' * 60}")
258
- print(f" EVAL: {label}")
259
- print(f"{'=' * 60}")
260
-
261
- all_ids = []
262
- for diff in stages:
263
- all_ids.extend([(diff, sid) for sid in scenarios.get(diff, [])])
264
-
265
- results_by_diff = {d: [] for d in stages}
266
- t0 = time.time()
267
-
268
- for batch_start in range(0, len(all_ids), EVAL_BATCH_SIZE):
269
- batch = all_ids[batch_start : batch_start + EVAL_BATCH_SIZE]
270
- episodes = [EpisodeState(_get_env(i), sid) for i, (_, sid) in enumerate(batch)]
271
-
272
- for _ in range(1, EVAL_MAX_STEPS + 1):
273
- active = [ep for ep in episodes if not ep.done]
274
- if not active:
275
- break
276
- prompts = [ep.build_prompt_text() for ep in active]
277
- texts = batched_generate(prompts)
278
- for ep, text in zip(active, texts):
279
- parsed = parse_action(text)
280
- action = inf_normalize_action(parsed or {}) if parsed else None
281
- ep.apply_action(action)
282
-
283
- for ep, (diff, sid) in zip(episodes, batch):
284
- results_by_diff[diff].append(ep.score)
285
- print(f" [{diff:6}] {sid:22} score={ep.score:.3f} steps={ep.steps}")
286
-
287
- overall_scores = {}
288
- for diff in stages:
289
- scores = results_by_diff[diff]
290
- mean = sum(scores) / max(len(scores), 1)
291
- overall_scores[diff] = mean
292
- print(f" [{diff:6}] MEAN = {mean:.3f}")
293
-
294
- overall = sum(overall_scores.values()) / max(len(overall_scores), 1)
295
- overall_scores["overall"] = overall
296
- print(f" OVERALL: {overall:.3f} ({time.time() - t0:.1f}s)")
297
- return overall_scores
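A minimal sketch of the aggregation performed at the end of `evaluate`: a mean per difficulty, then an unweighted mean of those means. The helper name `summarize` and the sample scores are hypothetical. Because every difficulty carries equal weight, the overall figure can differ from a plain average over all episodes when the difficulties have different scenario counts.

```python
def summarize(results_by_diff):
    """Return per-difficulty means plus their unweighted 'overall' mean."""
    summary = {}
    for diff, scores in results_by_diff.items():
        # max(..., 1) guards against a difficulty with no scenarios.
        summary[diff] = sum(scores) / max(len(scores), 1)
    summary["overall"] = sum(summary.values()) / max(len(summary), 1)
    return summary


# Made-up scores, chosen only to make the arithmetic easy to follow.
sample = {"easy": [1.0, 0.5, 0.0], "hard": [0.25]}
summary = summarize(sample)
all_scores = [s for scores in sample.values() for s in scores]
per_episode_mean = sum(all_scores) / len(all_scores)
print(summary, per_episode_mean)
```

Here the per-difficulty average is 0.375 while the per-episode average is 0.4375, so comparisons between runs should use the same aggregation throughout.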
298
-
299
-
300
- # ── 6. Gold demonstrations for SFT ─────────────────────────────────────────
301
-
302
- # Hand-crafted episodes that teach the model:
303
- # - JSON action format (step_type, narration, ops, covered_concepts, intent)
304
- # - Incremental drawing (Phase 1), algorithm walk-through (Phase 2), wrap-up (Phase 3)
305
- # - Region vs container distinction, concept evidencing, pacing
306
-
307
- GOLD_EPISODES = {
308
- # ── easy_1: linked_list_traversal ──
309
- # input: {"values": [10, 20, 30]}, concepts: head_pointer, node_value, next_link, tail_marker
310
- "easy_1": [
311
- {
312
- "step_type": "advance",
313
- "narration": "Building a linked list — creating a centered layout region and the first two nodes with values 10 and 20.",
314
- "ops": [
315
- {"op": "add_region", "target_ids": ["list"], "params": {"style": "array", "title": "Linked List", "position": "center"}},
316
- {"op": "add_node", "target_ids": ["n1"], "params": {"value": 10, "region": "list"}},
317
- {"op": "add_node", "target_ids": ["n2"], "params": {"value": 20, "region": "list"}},
318
- ],
319
- "covered_concepts": ["node_value"],
320
- "intent": "create_list_start",
321
- },
322
- {
323
- "step_type": "advance",
324
- "narration": "Adding node 30 and connecting all nodes with next links to form the chain 10 -> 20 -> 30.",
325
- "ops": [
326
- {"op": "add_node", "target_ids": ["n3"], "params": {"value": 30, "region": "list"}},
327
- {"op": "add_edge", "target_ids": ["n1", "n2"], "params": {"kind": "directed", "label": "next"}},
328
- {"op": "add_edge", "target_ids": ["n2", "n3"], "params": {"kind": "directed", "label": "next"}},
329
- ],
330
- "covered_concepts": ["next_link"],
331
- "intent": "connect_nodes",
332
- },
333
- {
334
- "step_type": "advance",
335
- "narration": "Placing a head pointer at node 10 because traversal always starts at the head, our only entry point into the list.",
336
- "ops": [
337
- {"op": "add_pointer", "target_ids": ["head_ptr"], "params": {"region": "list"}},
338
- {"op": "move_pointer", "target_ids": ["head_ptr"], "params": {"index": "n1"}},
339
- {"op": "annotate", "target_ids": ["n1"], "params": {"text": "Head"}},
340
- {"op": "set_role", "target_ids": ["n1"], "params": {"role": "current"}},
341
- ],
342
- "covered_concepts": ["head_pointer"],
343
- "intent": "mark_head",
344
- },
345
- {
346
- "step_type": "advance",
347
- "narration": "Following the next link from 10 to 20 — the pointer advances and we mark node 10 as visited.",
348
- "ops": [
349
- {"op": "set_role", "target_ids": ["n1"], "params": {"role": "visited"}},
350
- {"op": "set_role", "target_ids": ["n2"], "params": {"role": "current"}},
351
- {"op": "move_pointer", "target_ids": ["head_ptr"], "params": {"index": "n2"}},
352
- ],
353
- "covered_concepts": [],
354
- "intent": "traverse_to_second",
355
- },
356
- {
357
- "step_type": "advance",
358
- "narration": "Reaching node 30 — it has no next link, making it the tail that signals the end of traversal.",
359
- "ops": [
360
- {"op": "set_role", "target_ids": ["n2"], "params": {"role": "visited"}},
361
- {"op": "set_role", "target_ids": ["n3"], "params": {"role": "current"}},
362
- {"op": "annotate", "target_ids": ["n3"], "params": {"text": "Tail"}},
363
- ],
364
- "covered_concepts": ["tail_marker"],
365
- "intent": "reach_tail",
366
- },
367
- {
368
- "step_type": "complete",
369
- "narration": "Traversal complete — visited every node from head to tail following next links, reading values 10, 20, 30 in order.",
370
- "ops": [{"op": "set_role", "target_ids": ["n3"], "params": {"role": "done"}}],
371
- "covered_concepts": [],
372
- "intent": "summarize",
373
- },
374
- ],
375
-
376
- # ── easy_2: stack_ops ──
377
- # input: {"operations": ["push A", "push B", "pop", "push C"]}, concepts: top_pointer, push, pop, lifo_order
378
- "easy_2": [
379
- {
380
- "step_type": "advance",
381
- "narration": "Setting up a stack with a centered visual region and a container to track push and pop membership.",
382
- "ops": [
383
- {"op": "add_region", "target_ids": ["stack_area"], "params": {"style": "stack", "title": "Stack", "position": "center"}},
384
- {"op": "add_container", "target_ids": ["stk"], "params": {"region": "stack_area", "ordered": False, "title": "Stack"}},
385
- ],
386
- "covered_concepts": [],
387
- "intent": "setup_stack",
388
- },
389
- {
390
- "step_type": "advance",
391
- "narration": "Pushing A onto the stack — A becomes the first element. Adding a top pointer to track the stack top.",
392
- "ops": [
393
- {"op": "add_node", "target_ids": ["a"], "params": {"value": "A", "region": "stack_area"}},
394
- {"op": "push_to", "target_ids": ["stk", "a"], "params": {}},
395
- {"op": "add_pointer", "target_ids": ["top"], "params": {"region": "stack_area"}},
396
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": "a"}},
397
- ],
398
- "covered_concepts": ["push", "top_pointer"],
399
- "intent": "push_a",
400
- },
401
- {
402
- "step_type": "advance",
403
- "narration": "Pushing B — B sits on top of A and the top pointer moves up to B.",
404
- "ops": [
405
- {"op": "add_node", "target_ids": ["b"], "params": {"value": "B", "region": "stack_area"}},
406
- {"op": "push_to", "target_ids": ["stk", "b"], "params": {}},
407
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": "b"}},
408
- ],
409
- "covered_concepts": [],
410
- "intent": "push_b",
411
- },
412
- {
413
- "step_type": "advance",
414
- "narration": "Popping from the stack — B was pushed last so B comes off first, demonstrating LIFO (last-in-first-out) order.",
415
- "ops": [
416
- {"op": "pop_from", "target_ids": ["stk"], "params": {}},
417
- {"op": "set_role", "target_ids": ["b"], "params": {"role": "inactive"}},
418
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": "a"}},
419
- ],
420
- "covered_concepts": ["pop", "lifo_order"],
421
- "intent": "pop_b",
422
- },
423
- {
424
- "step_type": "advance",
425
- "narration": "Pushing C onto the stack — C now sits on top of A, with B already removed.",
426
- "ops": [
427
- {"op": "add_node", "target_ids": ["c"], "params": {"value": "C", "region": "stack_area"}},
428
- {"op": "push_to", "target_ids": ["stk", "c"], "params": {}},
429
- {"op": "move_pointer", "target_ids": ["top"], "params": {"index": "c"}},
430
- ],
431
- "covered_concepts": [],
432
- "intent": "push_c",
433
- },
434
- {
435
- "step_type": "complete",
436
- "narration": "All four operations executed — stack holds A at bottom and C on top after push A, push B, pop, push C.",
437
- "ops": [],
438
- "covered_concepts": [],
439
- "intent": "summarize",
440
- },
441
- ],
442
-
443
- # ── easy_3: binary_search ──
444
- # input: {"array": [1,3,5,7,9,11,13], "target": 7}, concepts: sorted_invariant, low/high/mid_pointer, comparison
445
- "easy_3": [
446
- {
447
- "step_type": "advance",
448
- "narration": "Creating the first three elements of a sorted array in a centered region; the sorted invariant is what makes binary search possible.",
449
- "ops": [
450
- {"op": "add_region", "target_ids": ["arr"], "params": {"style": "array", "title": "Sorted Array", "position": "center"}},
451
- {"op": "add_node", "target_ids": ["n0"], "params": {"value": 1, "region": "arr"}},
452
- {"op": "add_node", "target_ids": ["n1"], "params": {"value": 3, "region": "arr"}},
453
- {"op": "add_node", "target_ids": ["n2"], "params": {"value": 5, "region": "arr"}},
454
- ],
455
- "covered_concepts": ["sorted_invariant"],
456
- "intent": "create_array_part1",
457
- },
458
- {
459
- "step_type": "advance",
460
- "narration": "Adding the remaining elements 7, 9, 11, 13 to complete all seven values of the sorted array.",
461
- "ops": [
462
- {"op": "add_node", "target_ids": ["n3"], "params": {"value": 7, "region": "arr"}},
463
- {"op": "add_node", "target_ids": ["n4"], "params": {"value": 9, "region": "arr"}},
464
- {"op": "add_node", "target_ids": ["n5"], "params": {"value": 11, "region": "arr"}},
465
- {"op": "add_node", "target_ids": ["n6"], "params": {"value": 13, "region": "arr"}},
466
- ],
467
- "covered_concepts": [],
468
- "intent": "create_array_part2",
469
- },
470
- {
471
- "step_type": "advance",
472
- "narration": "Placing low pointer at index 0 (value 1) and high pointer at index 6 (value 13) to bracket the search range.",
473
- "ops": [
474
- {"op": "add_pointer", "target_ids": ["low"], "params": {"region": "arr"}},
475
- {"op": "move_pointer", "target_ids": ["low"], "params": {"index": "n0"}},
476
- {"op": "add_pointer", "target_ids": ["high"], "params": {"region": "arr"}},
477
- {"op": "move_pointer", "target_ids": ["high"], "params": {"index": "n6"}},
478
- ],
479
- "covered_concepts": ["low_pointer", "high_pointer"],
480
- "intent": "init_pointers",
481
- },
482
- {
483
- "step_type": "advance",
484
- "narration": "Computing mid = (0+6)/2 = 3 — the mid pointer lands on value 7, which we compare against our target 7.",
485
- "ops": [
486
- {"op": "add_pointer", "target_ids": ["mid"], "params": {"region": "arr"}},
487
- {"op": "move_pointer", "target_ids": ["mid"], "params": {"index": "n3"}},
488
- {"op": "set_role", "target_ids": ["n3"], "params": {"role": "current"}},
489
- {"op": "highlight", "target_ids": ["n3"], "params": {}},
490
- ],
491
- "covered_concepts": ["mid_pointer", "comparison"],
492
- "intent": "compute_mid_and_compare",
493
- },
494
- {
495
- "step_type": "complete",
496
- "narration": "Target 7 found at index 3: binary search located it in one comparison because the target sits exactly at the first midpoint, and the sorted invariant would otherwise let each comparison halve the search space.",
497
- "ops": [
498
- {"op": "set_role", "target_ids": ["n3"], "params": {"role": "done"}},
499
- {"op": "annotate", "target_ids": ["n3"], "params": {"text": "Found: 7"}},
500
- ],
501
- "covered_concepts": [],
502
- "intent": "found_target",
503
- },
504
- ],
505
- }
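The easy_3 narration states that mid = (0+6)/2 = 3 and that the target is found in a single comparison. A minimal standalone sketch to check that arithmetic follows. `binary_search_trace` is a hypothetical helper written for illustration; it is not part of the repository and does not use the environment's op vocabulary.

```python
def binary_search_trace(array, target):
    """Return (index or None, list of midpoints probed) for a sorted array."""
    low, high = 0, len(array) - 1
    probes = []
    while low <= high:
        mid = (low + high) // 2  # integer midpoint, as in the narration
        probes.append(mid)
        if array[mid] == target:
            return mid, probes
        if array[mid] < target:
            low = mid + 1   # target can only be in the right half
        else:
            high = mid - 1  # target can only be in the left half
    return None, probes


# Same input as the easy_3 scenario comment: array [1..13], target 7.
index, probes = binary_search_trace([1, 3, 5, 7, 9, 11, 13], 7)
print(index, probes)
```

For this input the first midpoint is already the target, which is why the gold episode needs only one `mid` pointer placement. A target such as 13 would probe further midpoints before terminating.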
506
-
507
-
508
- # ── 7. SFT data generation ─────────────────────────────────────────────────
509
-
510
- def generate_sft_data():
511
- """Replay gold episodes through the live env to collect (observation, action) pairs."""
512
- env = VisualReasoningEnvironment()
513
- rows = []
514
- for scenario_id, actions in GOLD_EPISODES.items():
515
- obs = env.reset(scenario_id=scenario_id)
516
- last_action, last_reward, history = None, 0.0, []
517
- for i, action in enumerate(actions):
518
- user_prompt = build_user_prompt(obs, last_action, last_reward, history)
519
- messages = [
520
- {"role": "system", "content": SYSTEM_PROMPT},
521
- {"role": "user", "content": user_prompt},
522
- {"role": "assistant", "content": json.dumps(action, separators=(",", ":"))},
523
- ]
524
- text = tokenizer.apply_chat_template(
525
- messages, tokenize=False, add_generation_prompt=False
526
- )
527
- rows.append({"text": text})
528
- obs = env.step(VisualReasoningAction(**action))
529
- last_action, last_reward = action, float(obs.reward)
530
- history.append(f"Step {i + 1}: {action.get('narration', '')}")
531
- if obs.done:
532
- break
533
- return Dataset.from_list(rows)
534
-
535
-
536
- # ── 8. GRPO helpers ─────────────────────────────────────────────────────────
537
-
538
- def generate_grpo_scenarios():
539
- """Procedurally generate diverse scenarios for each difficulty."""
540
- out = {}
541
- for diff, count in GRPO_SCENARIOS_PER_DIFFICULTY.items():
542
- base_seed = {"easy": 10000, "medium": 20000, "hard": 30000, "expert": 40000}[diff]
543
- scenarios = [generate_scenario(task_name=diff, seed=base_seed + i) for i in range(count)]
544
- templates = Counter(s["template"] for s in scenarios)
545
- print(f" {diff}: {count} scenarios — {dict(templates)}")
546
- out[diff] = scenarios
547
- return out
548
-
549
-
550
- def build_grpo_prompts(scenarios_by_diff, stages, samples_per_scenario=2):
551
- """Collect initial-state prompts by resetting the env for each scenario."""
552
- env = VisualReasoningEnvironment()
553
- rows = []
554
- for stage in stages:
555
- for scenario in scenarios_by_diff.get(stage, []):
556
- sid = scenario["scenario_id"]
557
- for _ in range(samples_per_scenario):
558
- obs = env.reset(scenario_id=sid)
559
- user_prompt = build_user_prompt(obs, None, 0.0, [])
560
- messages = [
561
- {"role": "system", "content": SYSTEM_PROMPT},
562
- {"role": "user", "content": user_prompt},
563
- ]
564
- rows.append({"prompt": messages, "scenario_id": sid})
565
- return Dataset.from_list(rows)
566
-
567
-
568
- def make_reward_fn():
569
- """Reward function: parse completion, step in env, return overall_score."""
570
- env = VisualReasoningEnvironment()
571
- state = {"calls": 0, "hist": Counter()}
572
-
573
- def reward_fn(completions, scenario_id=None, **_):
574
- texts = []
575
- for c in completions:
576
- if isinstance(c, list):
577
- texts.append(c[-1].get("content", "") if c else "")
578
- else:
579
- texts.append(str(c))
580
-
581
- sids = scenario_id if isinstance(scenario_id, list) else [scenario_id] * len(texts)
582
- if len(sids) < len(texts):
583
- n_gen = len(texts) // len(sids)
584
- sids = [s for s in sids for _ in range(n_gen)]
585
-
586
- rewards = []
587
- for sid, text in zip(sids, texts):
588
- obs = env.reset(scenario_id=sid)
589
- parsed = parse_action(text)
590
- action = inf_normalize_action(parsed) if parsed else None
591
- if action is None:
592
- rewards.append(0.0)
593
- state["hist"]["<unparseable>"] += 1
594
- continue
595
- obs = env.step(VisualReasoningAction(**action))
596
- rewards.append(float(obs.score_breakdown.get("overall_score", 0.0)))
597
- state["hist"][action.get("step_type", "?")] += 1
598
-
599
- state["calls"] += 1
600
- if state["calls"] % 10 == 0:
601
- print(f" [reward] call={state['calls']} types={dict(state['hist'])}")
602
- return rewards
603
-
604
- return reward_fn
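The scenario-id handling at the top of `reward_fn` can be isolated as a small pure function. The sketch below is a hypothetical extraction (the name `expand_scenario_ids` does not appear in the source) that mirrors the same logic, assuming the trainer passes either a single id, one id per prompt, or one id per completion.

```python
def expand_scenario_ids(scenario_id, n_completions):
    """Broadcast scenario ids so that each completion gets one id."""
    # A scalar (or None) is repeated once per completion.
    if isinstance(scenario_id, list):
        sids = scenario_id
    else:
        sids = [scenario_id] * n_completions
    # One id per prompt: repeat each id for its group of generations.
    if len(sids) < n_completions:
        n_gen = n_completions // len(sids)
        sids = [s for s in sids for _ in range(n_gen)]
    return sids


print(expand_scenario_ids(["easy_1", "easy_2"], 4))
```

One consequence of this design: when the number of completions is not a multiple of the number of ids, the expanded list is shorter than the completions, and the `zip` in `reward_fn` would then drop the trailing completions without an error. An explicit length check before the loop may be worth adding.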
605
-
606
-
607
- class CurriculumGRPOTrainer(GRPOTrainer):
608
- """Sequential sampler preserves easy → expert curriculum ordering."""
609
- def _get_train_sampler(self, *_args, **_kwargs):
610
- return SequentialSampler(self.train_dataset)
611
-
612
-
613
- # ── 9. Main training loop ──────────────────────────────────────────────────
614
-
615
- def main():
616
- job_start = time.time()
617
-
618
- # ── Baseline ──
619
- print("\n[3/7] Baseline evaluation...")
620
- baseline = evaluate("Baseline (untrained LoRA)")
621
-
622
- # ── SFT ──
623
- print("\n[4/7] SFT warmup...")
624
- sft_data = generate_sft_data()
625
- print(f" SFT examples: {len(sft_data)}")
626
-
627
- FastLanguageModel.for_training(model)
628
- sft_trainer = SFTTrainer(
629
- model=model,
630
- processing_class=tokenizer,
631
- args=SFTConfig(
632
- output_dir="/tmp/vr_sft",
633
- num_train_epochs=SFT_EPOCHS,
634
- per_device_train_batch_size=SFT_BATCH_SIZE,
635
- gradient_accumulation_steps=SFT_GRAD_ACCUM,
636
- learning_rate=SFT_LR,
637
- lr_scheduler_type="cosine",
638
- warmup_ratio=0.03,
639
- logging_steps=5,
640
- save_strategy="no",
641
- bf16=True,
642
- max_seq_length=MAX_SEQ_LENGTH,
643
- dataset_text_field="text",
644
- optim="adamw_8bit",
645
- report_to="none",
646
- ),
647
- train_dataset=sft_data,
648
- )
649
-
650
- t0 = time.time()
651
- sft_trainer.train()
652
- print(f" SFT done in {time.time() - t0:.0f}s")
653
-
654
- sft_results = evaluate("Post-SFT")
655
-
656
- # ── Generate GRPO scenarios ──
657
- print("\n[5/7] Generating GRPO scenarios...")
658
- grpo_scenarios = generate_grpo_scenarios()
659
- total_scenarios = sum(len(v) for v in grpo_scenarios.values())
660
- print(f" Total: {total_scenarios} scenarios")
661
-
662
- # ── GRPO Stage 1: easy + medium ──
663
- print("\n[6/7] GRPO Stage 1 (easy + medium)...")
664
- stage1_data = build_grpo_prompts(grpo_scenarios, ["easy", "medium"])
665
- print(f" Stage 1 prompts: {len(stage1_data)}")
666
-
667
- FastLanguageModel.for_training(model)
668
- grpo_s1 = CurriculumGRPOTrainer(
669
- model=model,
670
- tokenizer=tokenizer,
671
- args=GRPOConfig(
672
- output_dir="/tmp/vr_grpo_s1",
673
- num_train_epochs=GRPO_STAGE1_EPOCHS,
674
- per_device_train_batch_size=GRPO_BATCH_SIZE,
675
- gradient_accumulation_steps=GRPO_GRAD_ACCUM,
676
- num_generations=GRPO_NUM_GENERATIONS,
677
- max_completion_length=384,
678
- learning_rate=GRPO_LR,
679
- lr_scheduler_type="cosine",
680
- warmup_ratio=0.1,
681
- beta=0.05,
682
- max_grad_norm=0.5,
683
- temperature=GRPO_TEMPERATURE,
684
- logging_steps=1,
685
- save_strategy="no",
686
- bf16=True,
687
- optim="adamw_8bit",
688
- report_to="none",
689
- remove_unused_columns=False,
690
- ),
691
- train_dataset=stage1_data,
692
- reward_funcs=make_reward_fn(),
693
- )
694
-
695
- t0 = time.time()
696
- grpo_s1.train()
697
- print(f" Stage 1 done in {time.time() - t0:.0f}s")
698
-
699
- stage1_results = evaluate("Post-GRPO Stage 1", stages=["easy", "medium"])
700
-
701
- # ── GRPO Stage 2: all difficulties ──
702
- print("\n[7/7] GRPO Stage 2 (all difficulties)...")
703
- stage2_data = build_grpo_prompts(grpo_scenarios, list(DIFFICULTIES))
704
- print(f" Stage 2 prompts: {len(stage2_data)}")
705
-
706
- FastLanguageModel.for_training(model)
707
- grpo_s2 = CurriculumGRPOTrainer(
708
- model=model,
709
- tokenizer=tokenizer,
710
- args=GRPOConfig(
711
- output_dir="/tmp/vr_grpo_s2",
712
- num_train_epochs=GRPO_STAGE2_EPOCHS,
713
- per_device_train_batch_size=GRPO_BATCH_SIZE,
714
- gradient_accumulation_steps=GRPO_GRAD_ACCUM,
715
- num_generations=GRPO_NUM_GENERATIONS,
716
- max_completion_length=384,
717
- learning_rate=GRPO_LR * 0.5, # halved for stability with harder scenarios
718
- lr_scheduler_type="cosine",
719
- warmup_ratio=0.1,
720
- beta=0.05,
721
- max_grad_norm=0.5,
722
- temperature=GRPO_TEMPERATURE,
723
- logging_steps=1,
724
- save_strategy="no",
725
- bf16=True,
726
- optim="adamw_8bit",
727
- report_to="none",
728
- remove_unused_columns=False,
729
- ),
730
- train_dataset=stage2_data,
731
- reward_funcs=make_reward_fn(),
732
- )
733
-
734
- t0 = time.time()
735
- grpo_s2.train()
736
- print(f" Stage 2 done in {time.time() - t0:.0f}s")
737
-
738
- # ── Final eval + report ──
739
- final_results = evaluate("Final (SFT + GRPO S1 + S2)")
740
-
741
- print(f"\n{'=' * 72}")
742
- print(" DELTA REPORT")
743
- print(f"{'=' * 72}")
744
- print(f" {'Difficulty':<12} {'Baseline':>10} {'SFT':>10} {'GRPO-S1':>10} {'Final':>10} {'Δ':>10}")
745
- print(f" {'-'*12} {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
746
- for diff in list(DIFFICULTIES) + ["overall"]:
747
- b = baseline.get(diff, 0.0)
748
- s = sft_results.get(diff, 0.0)
749
- s1 = stage1_results.get(diff, 0.0)
750
- f = final_results.get(diff, 0.0)
751
- label = diff.upper() if diff == "overall" else diff
752
- print(f" {label:<12} {b:>10.3f} {s:>10.3f} {s1:>10.3f} {f:>10.3f} {f - b:>+10.3f}")
753
- print(f"{'=' * 72}")
754
-
755
- # ── Push to hub ──
756
- if HUB_REPO:
757
- print(f"\nPushing LoRA adapter to {HUB_REPO}...")
758
- model.push_to_hub(HUB_REPO, token=HF_TOKEN)
759
- tokenizer.push_to_hub(HUB_REPO, token=HF_TOKEN)
760
- print(f"Pushed: https://huggingface.co/{HUB_REPO}")
761
- else:
762
- print("\nHUB_REPO not set — saving locally to /tmp/vr_qwen3_8b_lora")
763
- model.save_pretrained("/tmp/vr_qwen3_8b_lora")
764
- tokenizer.save_pretrained("/tmp/vr_qwen3_8b_lora")
765
-
766
- total_mins = (time.time() - job_start) / 60
767
- print(f"\nJob finished in {total_mins:.1f} minutes.")
768
-
769
-
770
- if __name__ == "__main__":
771
- main()
uv.lock DELETED
The diff for this file is too large to render. See raw diff
 
viewer/audio_viewer.html DELETED
@@ -1,865 +0,0 @@
1
- <!doctype html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="utf-8"/>
5
- <title>Visual Reasoning — Audio Viewer</title>
6
- <style>
7
- *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
8
-
9
- :root {
10
- --bg-dark: #1e1e2e;
11
- --bg-panel: #181825;
12
- --bg-surface: #11111b;
13
- --border: #313244;
14
- --text: #cdd6f4;
15
- --text-dim: #6c7086;
16
- --accent: #89b4fa;
17
- --green: #a6e3a1;
18
- --red: #f38ba8;
19
- --yellow: #f9e2af;
20
- --peach: #fab387;
21
- --mauve: #cba6f7;
22
- --teal: #94e2d5;
23
- }
24
-
25
- html, body { height: 100%; font-family: 'Inter', -apple-system, system-ui, sans-serif; background: var(--bg-dark); color: var(--text); overflow: hidden; }
26
-
27
- #app { display: flex; flex-direction: column; height: 100vh; }
28
-
29
- /* ---- Header ---- */
30
- .header {
31
- display: flex; align-items: center; gap: 12px;
32
- padding: 6px 14px; background: var(--bg-panel); border-bottom: 1px solid var(--border);
33
- font-size: 12px; flex-shrink: 0; min-height: 36px;
34
- }
35
- .header .badge { padding: 2px 8px; border-radius: 4px; font-weight: 600; font-size: 11px; }
36
- .header .badge-task { background: rgba(137,180,250,0.15); color: var(--accent); }
37
- .header .badge-scenario { color: var(--text-dim); font-size: 11px; }
38
- .header .step-counter { color: var(--text); font-size: 11px; }
39
- .header .score-display { font-weight: 700; color: var(--green); font-size: 14px; }
40
- .header .spacer { flex: 1; }
41
- .header .ws-dot { width: 7px; height: 7px; border-radius: 50%; }
42
- .header .ws-dot.on { background: var(--green); }
43
- .header .ws-dot.off { background: var(--red); }
44
-
45
- .audio-btn {
46
- background: none; border: 1px solid var(--border); color: var(--text-dim);
47
- border-radius: 4px; padding: 2px 8px; font-size: 14px; cursor: pointer;
48
- transition: all .2s;
49
- }
50
- .audio-btn:hover { border-color: var(--accent); color: var(--accent); }
51
- .audio-btn.active { color: var(--accent); border-color: var(--accent); }
52
-
53
- /* ---- Main layout ---- */
54
- .main { display: flex; flex: 1; min-height: 0; overflow: hidden; }
55
-
56
- /* ---- Canvas ---- */
57
- .canvas-wrap {
58
- flex: 1; min-width: 0; background: var(--bg-surface);
59
- position: relative; overflow: hidden;
60
- }
61
- .canvas-wrap canvas { display: block; width: 100%; height: 100%; }
62
-
63
- /* ---- Side panel ---- */
64
- .side {
65
- width: 270px; flex-shrink: 0; background: var(--bg-panel);
66
- border-left: 1px solid var(--border);
67
- display: flex; flex-direction: column; overflow: hidden;
68
- }
69
- .side-section {
70
- padding: 8px 12px; border-bottom: 1px solid var(--border); flex-shrink: 0;
71
- }
72
- .side-section h3 {
73
- font-size: 9px; text-transform: uppercase; letter-spacing: 1.2px;
74
- color: var(--accent); margin-bottom: 4px; font-weight: 600;
75
- }
76
-
77
- /* Goal */
78
- .goal-text {
79
- font-size: 11px; line-height: 1.4; color: var(--text);
80
- max-height: 48px; overflow: hidden; text-overflow: ellipsis;
81
- }
82
-
83
- /* Checklist */
84
- .checklist-grid { display: flex; flex-wrap: wrap; gap: 2px 8px; }
85
- .concept { display: flex; align-items: center; gap: 4px; padding: 1px 0; font-size: 11px; }
86
- .concept .icon { width: 14px; text-align: center; font-size: 12px; }
87
- .concept.covered .icon { color: var(--green); }
88
- .concept.uncovered .icon { color: var(--text-dim); }
89
- .concept.covered { color: var(--green); }
90
- .concept.uncovered { color: var(--text-dim); }
91
-
92
- /* Score bars */
93
- .score-row { display: flex; align-items: center; gap: 4px; margin-bottom: 2px; font-size: 10px; }
94
- .score-label { width: 78px; flex-shrink: 0; color: var(--text-dim); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
95
- .score-track { flex: 1; height: 12px; background: var(--bg-surface); border-radius: 2px; overflow: hidden; }
96
- .score-fill { height: 100%; border-radius: 2px; display: flex; align-items: center; padding-left: 3px; font-size: 8px; font-weight: 600; color: #fff; min-width: 20px; transition: width .3s; }
97
-
98
- /* Narration log */
99
- .narration-log { flex: 1; overflow-y: auto; padding: 8px 12px; min-height: 0; }
100
- .narration-log h3 { font-size: 9px; text-transform: uppercase; letter-spacing: 1.2px; color: var(--accent); margin-bottom: 4px; font-weight: 600; }
101
- .narr-entry { padding: 4px 0; border-bottom: 1px solid rgba(49,50,68,.5); font-size: 11px; animation: fadeIn .3s; }
102
- .narr-entry .step-tag { font-weight: 600; color: var(--accent); font-size: 10px; }
103
- .narr-entry .reward { font-size: 9px; font-weight: 700; margin-left: 4px; }
104
- .narr-entry .reward.pos { color: var(--green); }
105
- .narr-entry .reward.neg { color: var(--red); }
106
- .narr-entry .text { color: var(--text); margin-top: 1px; line-height: 1.35; }
107
-
108
- /* ---- Bottom narration bar ---- */
109
- .narration-bar {
110
- flex-shrink: 0; display: flex; align-items: center; gap: 10px;
111
- padding: 8px 16px; background: var(--bg-panel); border-top: 1px solid var(--border);
112
- min-height: 40px; max-height: 60px;
113
- }
114
- .narr-indicator {
115
- width: 24px; height: 24px; display: flex; align-items: center; justify-content: center;
116
- flex-shrink: 0; font-size: 16px; color: var(--text-dim); transition: color .3s;
117
- }
118
- .narr-indicator.speaking { color: var(--accent); }
119
- .narr-indicator.speaking .bars { display: flex; align-items: flex-end; gap: 2px; height: 16px; }
120
- .narr-indicator.speaking .bars span {
121
- width: 3px; background: var(--accent); border-radius: 1px;
122
- animation: barPulse .6s ease-in-out infinite alternate;
123
- }
124
- .narr-indicator.speaking .bars span:nth-child(2) { animation-delay: .15s; }
125
- .narr-indicator.speaking .bars span:nth-child(3) { animation-delay: .3s; }
126
- .narr-indicator.speaking .bars span:nth-child(4) { animation-delay: .45s; }
127
- .narr-text {
128
- flex: 1; font-size: 13px; color: var(--text); font-style: italic;
129
- white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
130
- transition: opacity .5s;
131
- }
132
- .narr-text.dim { opacity: .3; }
133
- .music-badge {
134
- flex-shrink: 0; font-size: 10px; color: var(--mauve); display: flex; align-items: center; gap: 4px;
135
- opacity: 0; transition: opacity .5s;
136
- }
137
- .music-badge.on { opacity: 1; }
138
-
139
- @keyframes barPulse {
140
- 0% { height: 4px; }
141
- 100% { height: 14px; }
142
- }
143
-
144
- /* ---- End overlay ---- */
145
- .overlay {
146
- position: fixed; inset: 0; background: rgba(17,17,27,.92);
147
- display: flex; flex-direction: column; align-items: center; justify-content: center;
148
- z-index: 1000; cursor: pointer;
149
- }
150
- .overlay.hidden { display: none; }
151
- .overlay h1 { font-size: 24px; margin-bottom: 6px; }
152
- .overlay h1.ok { color: var(--green); }
153
- .overlay h1.fail { color: var(--red); }
154
- .overlay .final { font-size: 16px; color: var(--text-dim); margin-bottom: 16px; }
155
- .reward-chart { display: flex; align-items: flex-end; gap: 4px; height: 100px; }
156
- .reward-chart .bar { width: 22px; border-radius: 3px 3px 0 0; min-height: 2px; position: relative; }
157
- .reward-chart .bar .lbl { position: absolute; bottom: -14px; left: 50%; transform: translateX(-50%); font-size: 8px; color: var(--text-dim); }
158
-
159
- /* Audio hint banner */
160
- .audio-hint {
161
- position: fixed; bottom: 60px; left: 50%; transform: translateX(-50%);
162
- background: rgba(137,180,250,0.15); color: var(--accent); padding: 8px 20px;
163
- border-radius: 8px; font-size: 13px; z-index: 2000; cursor: pointer;
164
- border: 1px solid rgba(137,180,250,0.3); backdrop-filter: blur(8px);
165
- animation: fadeIn .5s;
166
- }
167
- .audio-hint.hidden { display: none; }
168
-
169
- /* Error toast */
170
- .toast {
171
- position: fixed; top: 46px; left: 50%; transform: translateX(-50%);
172
- background: var(--red); color: #11111b; padding: 5px 16px;
173
- border-radius: 6px; font-size: 11px; font-weight: 600; z-index: 999; display: none;
174
- }
175
-
176
- @keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } }
177
- </style>
178
- </head>
179
- <body>
180
- <div id="app">
181
- <div class="header">
182
- <span class="badge badge-task" id="h-task">--</span>
183
- <span class="badge-scenario" id="h-scenario">--</span>
184
- <span class="step-counter" id="h-step">Step 0 / 0</span>
185
- <span class="spacer"></span>
186
- <span class="score-display" id="h-score">0.000</span>
187
- <button class="audio-btn active" id="btn-mute" title="Toggle audio">&#x1f50a;</button>
188
- <span class="ws-dot off" id="h-ws"></span>
189
- </div>
190
- <div class="main">
191
- <div class="canvas-wrap">
192
- <canvas id="canvas"></canvas>
193
- </div>
194
- <div class="side">
195
- <div class="side-section" id="sec-goal"><h3>Goal</h3><div class="goal-text" id="goal-text">--</div></div>
196
- <div class="side-section" id="sec-checklist"><h3>Concept Checklist</h3><div class="checklist-grid" id="checklist"></div></div>
197
- <div class="side-section" id="sec-scores"><h3>Score Breakdown</h3><div id="scores"></div></div>
198
- <div class="narration-log"><h3>Narration</h3><div id="narrations"></div></div>
199
- </div>
200
- </div>
201
- <div class="narration-bar">
202
- <div class="narr-indicator" id="narr-ind">
203
- <div class="bars"><span></span><span></span><span></span><span></span></div>
204
- </div>
205
- <div class="narr-text dim" id="narr-text"></div>
206
- <div class="music-badge" id="music-badge">&#x266b; music</div>
207
- </div>
208
- </div>
209
-
210
-
211
- <!-- End overlay -->
212
- <div class="overlay hidden" id="overlay">
213
- <h1 id="ov-title"></h1>
214
- <div class="final" id="ov-score"></div>
215
- <div class="reward-chart" id="ov-chart"></div>
216
- </div>
217
- <div class="audio-hint" id="audio-hint">Click here to enable audio narration</div>
218
- <div class="toast" id="toast"></div>
219
-
220
- <script>
221
- /* ================================================================
222
- Audio Viewer — Canvas2D renderer + TTS + background music
223
- ================================================================ */
224
-
225
- const ROLE_COLORS = {
226
- default: { fill: '#45475a', stroke: '#585b70', text: '#cdd6f4' },
227
- current: { fill: '#89b4fa', stroke: '#74c7ec', text: '#1e1e2e' },
228
- visited: { fill: '#a6e3a1', stroke: '#94e2d5', text: '#1e1e2e' },
229
- frontier: { fill: '#fab387', stroke: '#f9e2af', text: '#1e1e2e' },
230
- done: { fill: '#40a02b', stroke: '#a6e3a1', text: '#fff' },
231
- pivot: { fill: '#f38ba8', stroke: '#eba0ac', text: '#1e1e2e' },
232
- root: { fill: '#cba6f7', stroke: '#b4befe', text: '#1e1e2e' },
233
- error: { fill: '#f38ba8', stroke: '#f38ba8', text: '#fff' },
234
- inactive: { fill: '#313244', stroke: '#45475a', text: '#6c7086' },
235
- comparing: { fill: '#f9e2af', stroke: '#fab387', text: '#1e1e2e' },
236
- };
237
-
238
- const GRID = 70;
239
- const NODE_R = 26;
240
- const ARROW_LEN = 10;
241
- const FONT = '13px Inter, system-ui, sans-serif';
242
- const FONT_BOLD = 'bold 14px Inter, system-ui, sans-serif';
243
- const FONT_SMALL = '10px Inter, system-ui, sans-serif';
244
- const FONT_ANN = '11px Inter, system-ui, sans-serif';
245
-
246
- let S = {
247
- entities: {}, relations: [], layout: {}, annotations: [], notes: [],
248
- taskName: '', scenarioId: '', goal: '', checklist: [], coverage: [],
249
- maxSteps: 0, step: 0, score: 0, breakdown: {}, narrations: [],
250
- };
251
-
252
- let cam = { ox: 40, oy: 30, scale: 1 };
253
- let canvas, ctx;
254
- let toastTimer;
255
-
256
- const $ = id => document.getElementById(id);
257
-
258
- /* ---- Audio system ---- */
259
- let audioCtx = null;
260
- let bgMusic = null;
261
- let currentTTS = null;
262
- let isMuted = false;
263
- let audioUnlocked = false;
264
- let bgMusicWanted = false;
265
- let pendingAudioQueue = [];
266
- const BG_VOLUME = 0.06;
267
- const BG_DUCK_VOLUME = 0.02;
268
-
269
- // Create AudioContext eagerly (starts suspended until user gesture)
270
- try { audioCtx = new (window.AudioContext || window.webkitAudioContext)(); } catch {}
271
- bgMusic = new Audio('/audio/background.mp3');
272
- bgMusic.loop = true;
273
- bgMusic.volume = BG_VOLUME;
274
- bgMusic.addEventListener('error', () => console.warn('Background music not available'));
275
-
276
- function unlockAudio() {
277
- if (audioUnlocked) return;
278
- audioUnlocked = true;
279
- $('audio-hint').classList.add('hidden');
280
- if (audioCtx && audioCtx.state === 'suspended') audioCtx.resume();
281
- // play+pause to unlock the HTML Audio element for later use
282
- bgMusic.play().then(() => {
283
- if (!bgMusicWanted) { bgMusic.pause(); bgMusic.currentTime = 0; }
284
- else { bgMusic.volume = isMuted ? 0 : BG_VOLUME; $('music-badge').classList.add('on'); }
285
- }).catch(() => {});
286
- // flush queued TTS
287
- if (pendingAudioQueue.length) {
288
- for (const b64 of pendingAudioQueue) playTTS(b64);
289
- pendingAudioQueue.length = 0;
290
- }
291
- }
292
-
293
- // Unlock on any user interaction (click, touch, keydown)
294
- document.addEventListener('click', unlockAudio, { once: true });
295
- document.addEventListener('touchstart', unlockAudio, { once: true });
296
- document.addEventListener('keydown', unlockAudio, { once: true });
297
-
298
- function startBgMusic() {
299
- bgMusicWanted = true;
300
- if (!audioUnlocked || isMuted) return;
301
- bgMusic.currentTime = 0;
302
- bgMusic.volume = BG_VOLUME;
303
- bgMusic.play().catch(() => {});
304
- $('music-badge').classList.add('on');
305
- }
306
-
307
- function stopBgMusic() {
308
- bgMusicWanted = false;
309
- bgMusic.pause();
310
- bgMusic.currentTime = 0;
311
- $('music-badge').classList.remove('on');
312
- }
313
-
314
- function duckBgMusic() {
315
- bgMusic.volume = isMuted ? 0 : BG_DUCK_VOLUME;
316
- }
317
-
318
- function unduckBgMusic() {
319
- bgMusic.volume = isMuted ? 0 : BG_VOLUME;
320
- }
321
-
322
- async function playTTS(base64Wav) {
323
- if (!base64Wav || !audioCtx || isMuted) return;
324
- if (!audioUnlocked) { pendingAudioQueue.push(base64Wav); return; }
325
- if (currentTTS) { try { currentTTS.stop(); } catch {} currentTTS = null; }
326
-
327
- try {
328
- if (audioCtx.state === 'suspended') await audioCtx.resume();
329
- const binary = atob(base64Wav);
330
- const buffer = new ArrayBuffer(binary.length);
331
- const view = new Uint8Array(buffer);
332
- for (let i = 0; i < binary.length; i++) view[i] = binary.charCodeAt(i);
333
-
334
- const audioBuffer = await audioCtx.decodeAudioData(buffer);
335
- const source = audioCtx.createBufferSource();
336
- const gain = audioCtx.createGain();
337
- gain.gain.value = 1.0;
338
- source.buffer = audioBuffer;
339
- source.connect(gain).connect(audioCtx.destination);
340
- source.start();
341
- currentTTS = source;
342
-
343
- duckBgMusic();
344
- $('narr-ind').classList.add('speaking');
345
-
346
- source.onended = () => {
347
- currentTTS = null;
348
- unduckBgMusic();
349
- $('narr-ind').classList.remove('speaking');
350
- $('narr-text').classList.add('dim');
351
- };
352
- } catch (e) {
353
- console.warn('TTS playback failed:', e);
354
- $('narr-ind').classList.remove('speaking');
355
- }
356
- }
357
-
358
- function toggleMute() {
359
- isMuted = !isMuted;
360
- const btn = $('btn-mute');
361
- btn.textContent = isMuted ? '\u{1f507}' : '\u{1f50a}';
362
- btn.classList.toggle('active', !isMuted);
363
- bgMusic.volume = isMuted ? 0 : BG_VOLUME;
364
- if (currentTTS && isMuted) { try { currentTTS.stop(); } catch {} currentTTS = null; }
365
- }
366
-
367
- /* ---- Coordinate helpers ---- */
368
- function wx(ex) { return cam.ox + ex * GRID * cam.scale; }
369
- function wy(ey) { return cam.oy + ey * GRID * cam.scale; }
370
- function pos(eid) {
371
- const p = S.layout[eid];
372
- if (!p) return null;
373
- return { x: wx(Number(p.x)), y: wy(Number(p.y)) };
374
- }
375
-
- /* ---- Auto-fit camera ---- */
- function autoFit() {
-   const wrap = document.querySelector('.canvas-wrap');
-   const cw = wrap.clientWidth;
-   const ch = wrap.clientHeight;
-
-   let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity;
-   let hasPoints = false;
-
-   for (const [eid, ent] of Object.entries(S.entities)) {
-     if (ent && ent.entity_type === 'region' && ent.bounds) {
-       minX = Math.min(minX, Number(ent.bounds.x0));
-       maxX = Math.max(maxX, Number(ent.bounds.x1));
-       minY = Math.min(minY, Number(ent.bounds.y0));
-       maxY = Math.max(maxY, Number(ent.bounds.y1));
-       hasPoints = true;
-     }
-     const p = S.layout[eid];
-     if (p) {
-       minX = Math.min(minX, Number(p.x));
-       maxX = Math.max(maxX, Number(p.x));
-       minY = Math.min(minY, Number(p.y));
-       maxY = Math.max(maxY, Number(p.y));
-       hasPoints = true;
-     }
-   }
-
-   if (!hasPoints) { cam = { ox: cw / 2, oy: ch / 2, scale: 1 }; return; }
-   if (minX === maxX && minY === maxY) {
-     cam.scale = 1.5;
-     cam.ox = cw / 2 - minX * GRID * cam.scale;
-     cam.oy = ch / 2 - minY * GRID * cam.scale;
-     return;
-   }
-   const spanX = (maxX - minX) || 1;
-   const spanY = (maxY - minY) || 1;
-   const pad = 50;
-   const sx = (cw - pad * 2) / (spanX * GRID);
-   const sy = (ch - pad * 2) / (spanY * GRID);
-   cam.scale = Math.min(sx, sy, 2.2);
-   const midX = (minX + maxX) / 2;
-   const midY = (minY + maxY) / 2;
-   cam.ox = cw / 2 - midX * GRID * cam.scale;
-   cam.oy = ch / 2 - midY * GRID * cam.scale;
- }
-
- /* ---- Drawing primitives ---- */
- function drawRoundRect(x, y, w, h, r) {
-   ctx.beginPath();
-   ctx.moveTo(x + r, y);
-   ctx.lineTo(x + w - r, y);
-   ctx.quadraticCurveTo(x + w, y, x + w, y + r);
-   ctx.lineTo(x + w, y + h - r);
-   ctx.quadraticCurveTo(x + w, y + h, x + w - r, y + h);
-   ctx.lineTo(x + r, y + h);
-   ctx.quadraticCurveTo(x, y + h, x, y + h - r);
-   ctx.lineTo(x, y + r);
-   ctx.quadraticCurveTo(x, y, x + r, y);
-   ctx.closePath();
- }
-
- function drawArrowhead(fx, fy, tx, ty) {
-   const angle = Math.atan2(ty - fy, tx - fx);
-   const s = ARROW_LEN * cam.scale;
-   ctx.beginPath();
-   ctx.moveTo(tx, ty);
-   ctx.lineTo(tx - s * Math.cos(angle - Math.PI / 7), ty - s * Math.sin(angle - Math.PI / 7));
-   ctx.lineTo(tx - s * Math.cos(angle + Math.PI / 7), ty - s * Math.sin(angle + Math.PI / 7));
-   ctx.closePath();
-   ctx.fill();
- }
-
- /* ---- Main render ---- */
- function render() {
-   const dpr = window.devicePixelRatio || 1;
-   const rect = canvas.getBoundingClientRect();
-   canvas.width = rect.width * dpr;
-   canvas.height = rect.height * dpr;
-   ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
-   ctx.clearRect(0, 0, rect.width, rect.height);
-
-   const highlighted = new Set();
-   for (const a of S.annotations) {
-     if (a && a.text === '[highlight]') highlighted.add(a.target_id);
-   }
-
-   // Regions
-   for (const [eid, ent] of Object.entries(S.entities)) {
-     if (!ent || ent.entity_type !== 'region') continue;
-     const b = ent.bounds;
-     if (!b) continue;
-     const x0 = wx(Number(b.x0)), y0 = wy(Number(b.y0));
-     const x1 = wx(Number(b.x1)), y1 = wy(Number(b.y1));
-     ctx.save();
-     ctx.strokeStyle = 'rgba(137,180,250,0.2)';
-     ctx.lineWidth = 1;
-     ctx.setLineDash([6, 4]);
-     drawRoundRect(x0, y0, x1 - x0, y1 - y0, 8);
-     ctx.stroke();
-     ctx.setLineDash([]);
-     ctx.font = FONT_SMALL;
-     ctx.fillStyle = 'rgba(137,180,250,0.45)';
-     ctx.fillText(ent.title || eid, x0 + 6, y0 + 13);
-     ctx.restore();
-   }
-
-   // Edges
-   for (const rel of S.relations) {
-     if (!rel || !rel.src || !rel.dst) continue;
-     const sp = pos(rel.src), dp = pos(rel.dst);
-     if (!sp || !dp) continue;
-     const r = NODE_R * cam.scale;
-     const dx = dp.x - sp.x, dy = dp.y - sp.y;
-     const dist = Math.hypot(dx, dy) || 1;
-     const ux = dx / dist, uy = dy / dist;
-     const sx = sp.x + ux * r, sy = sp.y + uy * r;
-     const ex = dp.x - ux * r, ey = dp.y - uy * r;
-
-     ctx.save();
-     ctx.strokeStyle = '#585b70';
-     ctx.lineWidth = 1.5 * cam.scale;
-     ctx.beginPath();
-     ctx.moveTo(sx, sy);
-     ctx.lineTo(ex, ey);
-     ctx.stroke();
-     ctx.fillStyle = '#585b70';
-     drawArrowhead(sx, sy, ex, ey);
-     if (rel.label) {
-       const mx = (sx + ex) / 2, my = (sy + ey) / 2;
-       ctx.font = FONT_SMALL;
-       ctx.fillStyle = '#bac2de';
-       ctx.textAlign = 'center';
-       ctx.fillText(rel.label, mx, my - 6 * cam.scale);
-     }
-     ctx.restore();
-   }
-
-   // Entities
-   for (const [eid, ent] of Object.entries(S.entities)) {
-     if (!ent) continue;
-     const type = ent.entity_type || 'node';
-     if (type === 'region') continue;
-     const p = pos(eid);
-     if (!p) continue;
-     const role = ent.role || 'default';
-     const colors = ROLE_COLORS[role] || ROLE_COLORS.default;
-     const r = NODE_R * cam.scale;
-     const isHl = highlighted.has(eid);
-
-     ctx.save();
-     if (isHl) { ctx.shadowColor = '#f9e2af'; ctx.shadowBlur = 18 * cam.scale; }
-
-     if (type === 'pointer') {
-       ctx.beginPath();
-       ctx.moveTo(p.x - 8 * cam.scale, p.y - 10 * cam.scale);
-       ctx.lineTo(p.x + 12 * cam.scale, p.y);
-       ctx.lineTo(p.x - 8 * cam.scale, p.y + 10 * cam.scale);
-       ctx.closePath();
-       ctx.fillStyle = colors.fill;
-       ctx.fill();
-       ctx.strokeStyle = colors.stroke;
-       ctx.lineWidth = 1.5 * cam.scale;
-       ctx.stroke();
-       ctx.shadowBlur = 0;
-       ctx.font = FONT_SMALL;
-       ctx.fillStyle = '#cdd6f4';
-       ctx.textAlign = 'left';
-       const label = eid;
-       const val = ent.value != null ? ` -> ${ent.value}` : '';
-       ctx.fillText(label + val, p.x + 16 * cam.scale, p.y + 4 * cam.scale);
-
-     } else if (type === 'container') {
-       const w = 90 * cam.scale, h = 44 * cam.scale;
-       drawRoundRect(p.x - w / 2, p.y - h / 2, w, h, 6 * cam.scale);
-       ctx.fillStyle = 'rgba(69,90,100,0.25)';
-       ctx.fill();
-       ctx.strokeStyle = colors.stroke;
-       ctx.lineWidth = 1.5 * cam.scale;
-       ctx.setLineDash([4, 3]);
-       ctx.stroke();
-       ctx.setLineDash([]);
-       ctx.shadowBlur = 0;
-       ctx.font = FONT_SMALL;
-       ctx.fillStyle = '#89b4fa';
-       ctx.textAlign = 'center';
-       ctx.fillText(ent.title || eid, p.x, p.y - h / 2 - 5 * cam.scale);
-       const contents = (ent.contents || []).join(', ');
-       if (contents) {
-         ctx.font = FONT_SMALL;
-         ctx.fillStyle = '#cdd6f4';
-         ctx.fillText(contents, p.x, p.y + 3 * cam.scale);
-       }
-
-     } else {
-       ctx.beginPath();
-       ctx.arc(p.x, p.y, r, 0, Math.PI * 2);
-       ctx.fillStyle = colors.fill;
-       ctx.fill();
-       ctx.strokeStyle = colors.stroke;
-       ctx.lineWidth = 2 * cam.scale;
-       ctx.stroke();
-       if (role === 'error') {
-         ctx.beginPath();
-         ctx.moveTo(p.x - r * .6, p.y - r * .6);
-         ctx.lineTo(p.x + r * .6, p.y + r * .6);
-         ctx.strokeStyle = '#fff';
-         ctx.lineWidth = 2 * cam.scale;
-         ctx.stroke();
-       }
-       ctx.shadowBlur = 0;
-       const valText = ent.value != null ? String(ent.value) : eid;
-       ctx.font = FONT_BOLD;
-       ctx.fillStyle = colors.text;
-       ctx.textAlign = 'center';
-       ctx.textBaseline = 'middle';
-       ctx.fillText(valText.length > 6 ? valText.slice(0, 6) : valText, p.x, p.y);
-       ctx.textBaseline = 'alphabetic';
-       if (role !== 'default') {
-         ctx.font = FONT_SMALL;
-         ctx.fillStyle = colors.stroke;
-         ctx.fillText(role, p.x, p.y + r + 12 * cam.scale);
-       }
-     }
-     ctx.restore();
-   }
-
-   // Annotations
-   const annMap = new Map();
-   for (const a of S.annotations) {
-     if (!a || a.text === '[highlight]' || a.text === '[popped]') continue;
-     annMap.set(a.target_id, a.text);
-   }
-   for (const [tid, text] of annMap) {
-     const p = pos(tid);
-     if (!p) continue;
-     const r = NODE_R * cam.scale;
-     ctx.save();
-     ctx.font = FONT_ANN;
-     const tw = ctx.measureText(text).width;
-     const px = p.x - tw / 2 - 4, py = p.y - r - 14 * cam.scale;
-     drawRoundRect(px, py - 10, tw + 8, 16, 4);
-     ctx.fillStyle = 'rgba(250,179,135,0.15)';
-     ctx.fill();
-     ctx.strokeStyle = 'rgba(250,179,135,0.4)';
-     ctx.lineWidth = 0.5;
-     ctx.stroke();
-     ctx.fillStyle = '#fab387';
-     ctx.textAlign = 'center';
-     ctx.fillText(text, p.x, py);
-     ctx.restore();
-   }
-
-   // Notes
-   for (const note of S.notes) {
-     if (!note || !note.text) continue;
-     let nx = 16, ny = rect.height - 24;
-     const regionEnt = note.region && S.entities[note.region];
-     if (regionEnt && regionEnt.bounds) {
-       nx = wx(Number(regionEnt.bounds.x0)) + 6;
-       ny = wy(Number(regionEnt.bounds.y1)) - 6;
-     }
-     ctx.save();
-     ctx.font = FONT_SMALL;
-     const t = note.text.length > 50 ? note.text.slice(0, 50) + '…' : note.text;
-     ctx.fillStyle = 'rgba(108,112,134,0.7)';
-     ctx.textAlign = 'left';
-     ctx.fillText(t, nx, ny);
-     ctx.restore();
-   }
- }
-
- /* ---- UI updates ---- */
- function barColor(v) { return v >= .7 ? 'var(--green)' : v >= .4 ? 'var(--yellow)' : 'var(--red)'; }
- function penaltyColor(v) { return v <= 0 ? 'var(--green)' : v < .1 ? 'var(--yellow)' : 'var(--red)'; }
-
- function updateHeader() {
-   $('h-task').textContent = S.taskName || '--';
-   $('h-scenario').textContent = S.scenarioId || '';
-   $('h-step').textContent = `Step ${S.step} / ${S.maxSteps}`;
-   $('h-score').textContent = S.score.toFixed(3);
- }
-
- function updateChecklist() {
-   const el = $('checklist');
-   el.innerHTML = '';
-   for (const c of S.checklist) {
-     const ok = S.coverage.includes(c);
-     el.innerHTML += `<div class="concept ${ok ? 'covered' : 'uncovered'}"><span class="icon">${ok ? '✓' : '○'}</span>${c}</div>`;
-   }
- }
-
- function updateScores() {
-   const el = $('scores');
-   el.innerHTML = '';
-   const bd = S.breakdown || {};
-   const allKeys = Object.keys(bd).filter(k => k !== 'overall_score' && k !== 'phase' && typeof bd[k] === 'number');
-   const subKeys = allKeys.filter(k => !k.startsWith('penalty_'));
-   const penKeys = allKeys.filter(k => k.startsWith('penalty_'));
-
-   for (const k of subKeys) {
-     const v = Number(bd[k]) || 0;
-     const pct = Math.max(0, Math.min(100, v * 100));
-     const label = k.replace(/_score$/, '').replace(/_/g, ' ');
-     el.innerHTML += `<div class="score-row"><span class="score-label">${label}</span><div class="score-track"><div class="score-fill" style="width:${pct}%;background:${barColor(v)}">${v.toFixed(2)}</div></div></div>`;
-   }
-   if (penKeys.length) {
-     el.innerHTML += `<div style="margin-top:4px;margin-bottom:3px;font-size:9px;text-transform:uppercase;letter-spacing:1px;color:var(--red);font-weight:600">Penalties (lower = better)</div>`;
-     for (const k of penKeys) {
-       const v = Number(bd[k]) || 0;
-       const pct = Math.max(0, Math.min(100, v * 100));
-       const label = k.replace(/^penalty_/, '').replace(/_/g, ' ');
-       el.innerHTML += `<div class="score-row"><span class="score-label">${label}</span><div class="score-track"><div class="score-fill" style="width:${pct}%;background:${penaltyColor(v)}">${v.toFixed(2)}</div></div></div>`;
-     }
-   }
-   if (bd.overall_score != null) {
-     const v = Number(bd.overall_score);
-     const pct = Math.max(0, Math.min(100, v * 100));
-     el.innerHTML += `<div class="score-row" style="margin-top:3px;border-top:1px solid var(--border);padding-top:3px"><span class="score-label" style="font-weight:700">overall</span><div class="score-track"><div class="score-fill" style="width:${pct}%;background:${barColor(v)}">${v.toFixed(3)}</div></div></div>`;
-   }
- }
-
- function addNarration(step, text, reward) {
-   const el = $('narrations');
-   const cls = reward >= 0 ? 'pos' : 'neg';
-   const sign = reward >= 0 ? '+' : '';
-   el.innerHTML += `<div class="narr-entry"><span class="step-tag">Step ${step}</span><span class="reward ${cls}">${sign}${reward.toFixed(2)}</span><div class="text">${text || ''}</div></div>`;
-   el.parentElement.scrollTop = el.parentElement.scrollHeight;
- }
-
- function showNarration(text) {
-   const el = $('narr-text');
-   el.textContent = text || '';
-   el.classList.remove('dim');
- }
-
- function showToast(text) {
-   const el = $('toast');
-   el.textContent = text;
-   el.style.display = 'block';
-   clearTimeout(toastTimer);
-   toastTimer = setTimeout(() => el.style.display = 'none', 4000);
- }
-
- /* ---- Message handlers ---- */
- function onClear() {
-   S.entities = {}; S.relations = []; S.layout = {};
-   S.annotations = []; S.notes = [];
-   S.taskName = ''; S.scenarioId = ''; S.goal = '';
-   S.checklist = []; S.coverage = [];
-   S.maxSteps = 0; S.step = 0; S.score = 0;
-   S.breakdown = {}; S.narrations = [];
-   updateHeader(); updateChecklist(); updateScores();
-   $('narrations').innerHTML = '';
-   $('goal-text').textContent = '--';
-   $('overlay').classList.add('hidden');
-   showNarration('');
-   stopBgMusic();
-   autoFit(); render();
- }
-
- function onReset(m) {
-   $('overlay').classList.add('hidden');
-   S.taskName = m.task_name || '';
-   S.scenarioId = m.scenario_id || '';
-   S.goal = m.goal || '';
-   S.checklist = m.checklist || [];
-   S.coverage = m.coverage || [];
-   S.maxSteps = m.max_steps || 0;
-   S.step = 0;
-   S.score = 0;
-   S.breakdown = m.score_breakdown || {};
-   S.entities = m.entities || {};
-   S.relations = m.relations || [];
-   S.layout = m.layout || {};
-   S.annotations = m.annotations || [];
-   S.notes = m.notes || [];
-   S.narrations = [];
-
-   updateHeader();
-   updateChecklist();
-   updateScores();
-   $('narrations').innerHTML = '';
-   $('goal-text').textContent = S.goal;
-   showNarration(S.goal);
-   autoFit();
-   render();
-
-   startBgMusic();
-   if (m.audio) playTTS(m.audio);
- }
-
- function onStep(m) {
-   S.step = m.step || S.step;
-   S.score = m.score != null ? m.score : S.score;
-   S.breakdown = m.score_breakdown || S.breakdown;
-   S.coverage = m.coverage || S.coverage;
-   S.entities = m.entities || S.entities;
-   S.relations = m.relations || S.relations;
-   S.layout = m.layout || S.layout;
-   S.annotations = m.annotations || S.annotations;
-   S.notes = m.notes || S.notes;
-   S.maxSteps = S.step + (m.remaining_step_budget || 0);
-
-   updateHeader();
-   updateChecklist();
-   updateScores();
-   autoFit();
-   render();
-
-   addNarration(m.step, m.narration || '', m.reward || 0);
-   showNarration(m.narration || '');
-   if (m.error) showToast(m.error);
-   if (m.audio) playTTS(m.audio);
- }
-
- function onEnd(m) {
-   stopBgMusic();
-
-   const ov = $('overlay');
-   ov.classList.remove('hidden');
-   $('ov-title').textContent = `Episode Done — ${m.task_name || ''}`;
-   $('ov-title').className = m.success ? 'ok' : 'fail';
-   $('ov-score').textContent = `Score: ${(m.score || 0).toFixed(3)} | Steps: ${m.steps || 0}`;
-
-   const chart = $('ov-chart');
-   chart.innerHTML = '';
-   const rr = m.rewards || [];
-   const mx = Math.max(0.01, ...rr.map(r => Math.abs(r)));
-   for (let i = 0; i < rr.length; i++) {
-     const r = rr[i];
-     const h = Math.max(2, (Math.abs(r) / mx) * 100);
-     chart.innerHTML += `<div class="bar" style="height:${h}px;background:${r >= 0 ? 'var(--accent)' : 'var(--red)'}"><span class="lbl">${i + 1}</span></div>`;
-   }
-   ov.onclick = () => ov.classList.add('hidden');
- }
-
- /* ---- Long-poll (single pending request, no spam) ---- */
- let pollCursor = 0;
- let pollRunning = false;
-
- async function pollLoop() {
-   if (pollRunning) return;
-   pollRunning = true;
-   while (true) {
-     try {
-       const resp = await fetch(`/poll?since=${pollCursor}`);
-       if (!resp.ok) {
-         $('h-ws').className = 'ws-dot off';
-         await new Promise(r => setTimeout(r, 2000));
-         continue;
-       }
-       $('h-ws').className = 'ws-dot on';
-       const data = await resp.json();
-       pollCursor = data.next;
-       for (const m of (data.messages || [])) {
-         if (m.type === 'clear') onClear();
-         else if (m.type === 'reset') onReset(m);
-         else if (m.type === 'step') onStep(m);
-         else if (m.type === 'end') onEnd(m);
-         else if (m.type === 'shutdown') { stopBgMusic(); showNarration('All episodes complete.'); }
-       }
-     } catch {
-       $('h-ws').className = 'ws-dot off';
-       await new Promise(r => setTimeout(r, 2000));
-     }
-   }
- }
-
- function connect() {
-   pollCursor = 0;
-   pollLoop();
- }
-
-
- /* ---- Mute button ---- */
- $('btn-mute').addEventListener('click', toggleMute);
-
- /* ---- Init ---- */
- function init() {
-   canvas = $('canvas');
-   ctx = canvas.getContext('2d');
-   window.addEventListener('resize', () => { autoFit(); render(); });
-   autoFit();
-   render();
-   connect();
- }
- init();
- </script>
- </body>
- </html>