Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitignore +41 -0
- CLAUDE.md +0 -111
- README.md +41 -41
- inference_audio.py +0 -451
- openenv_visual_reasoning.egg-info/PKG-INFO +0 -19
- openenv_visual_reasoning.egg-info/SOURCES.txt +0 -38
- openenv_visual_reasoning.egg-info/dependency_links.txt +0 -1
- openenv_visual_reasoning.egg-info/entry_points.txt +0 -2
- openenv_visual_reasoning.egg-info/requires.txt +0 -15
- openenv_visual_reasoning.egg-info/top_level.txt +0 -1
- push_to_space.ipynb +129 -0
- scripts/generate_rubric_data.py +0 -1132
- server/app.py +157 -91
- server/app_backup.py +0 -46
- train.ipynb +0 -913
- train.py +0 -632
- train_hf.py +0 -771
- uv.lock +0 -0
- viewer/audio_viewer.html +0 -865
.gitignore
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Virtual environment
|
| 2 |
+
.venv/
|
| 3 |
+
venv/
|
| 4 |
+
|
| 5 |
+
# Python
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.py[cod]
|
| 8 |
+
*.egg-info/
|
| 9 |
+
dist/
|
| 10 |
+
build/
|
| 11 |
+
*.egg
|
| 12 |
+
|
| 13 |
+
# Environment / secrets
|
| 14 |
+
.env
|
| 15 |
+
.env.local
|
| 16 |
+
|
| 17 |
+
# IDE
|
| 18 |
+
.idea/
|
| 19 |
+
.vscode/
|
| 20 |
+
*.swp
|
| 21 |
+
*.swo
|
| 22 |
+
*~
|
| 23 |
+
|
| 24 |
+
# OS
|
| 25 |
+
.DS_Store
|
| 26 |
+
Thumbs.db
|
| 27 |
+
|
| 28 |
+
# pytest
|
| 29 |
+
.pytest_cache/
|
| 30 |
+
htmlcov/
|
| 31 |
+
.coverage
|
| 32 |
+
|
| 33 |
+
# uv
|
| 34 |
+
uv.lock
|
| 35 |
+
|
| 36 |
+
# Jupyter
|
| 37 |
+
.ipynb_checkpoints/
|
| 38 |
+
|
| 39 |
+
CLAUDE.md
|
| 40 |
+
|
| 41 |
+
openenv_visual_reasoning.egg-info/
|
CLAUDE.md
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
# Visual Reasoning Environment
|
| 2 |
-
|
| 3 |
-
## What This Is
|
| 4 |
-
|
| 5 |
-
An OpenEnv RL environment for training LLMs (via GRPO/RLVR) to be expert visual explainers of CS algorithms. The LLM acts as a teacher drawing on a whiteboard — it creates data structures, walks through algorithms step-by-step, and narrates the reasoning. A scoring system provides dense, verifiable rewards.
|
| 6 |
-
|
| 7 |
-
## Architecture
|
| 8 |
-
|
| 9 |
-
```
|
| 10 |
-
inference.py / inference_tldraw.py — LLM inference loops (client-side)
|
| 11 |
-
client.py — WebSocket EnvClient wrapper
|
| 12 |
-
server/app.py — FastAPI OpenEnv server entry point
|
| 13 |
-
server/visual_reasoning_environment.py — Core env: reset(), step(), state management
|
| 14 |
-
server/scoring.py — 13 weighted sub-scores + 5 penalties → reward
|
| 15 |
-
server/invariant_checkers.py — Per-algorithm correctness checks (9 algorithms)
|
| 16 |
-
server/narration_scorer.py — Qwen3-Embedding-0.6B cosine similarity scorer
|
| 17 |
-
server/pedagogical_scoring.py — Teaching quality: attention coherence, pacing, scaffolding
|
| 18 |
-
server/scenario_loader.py — Loads scenarios.json + procedural generation
|
| 19 |
-
server/scenario_generator.py — Procedural scenario generation per difficulty
|
| 20 |
-
server/regions.py — Layout engine (queue/stack/tree/graph positioning)
|
| 21 |
-
server/constants.py — ALLOWED_OPS, ROLE_VALUES, limits
|
| 22 |
-
models.py — Pydantic models: VisualReasoningAction, VisualReasoningObservation
|
| 23 |
-
viewer/tldraw_viewer.html — tldraw-based browser visualizer
|
| 24 |
-
```
|
| 25 |
-
|
| 26 |
-
## Key Concepts
|
| 27 |
-
|
| 28 |
-
- **Empty canvas paradigm**: All scenarios start empty. The LLM draws the problem first (Phase 1), then solves it step-by-step (Phase 2), then completes (Phase 3).
|
| 29 |
-
- **16 canvas operations**: add_region, add_node, add_pointer, add_container, add_edge, remove_edge, push_to, pop_from, move_pointer, set_value, set_role, annotate, highlight, unhighlight, add_note, remove_entity.
|
| 30 |
-
- **10 entity roles**: default, current, visited, frontier, done, pivot, root, error, inactive, comparing.
|
| 31 |
-
- **Region vs Container**: Regions (`add_region`) are layout areas for visual positioning. Containers (`add_container`) track membership for push/pop. `push_to`/`pop_from` ONLY work on containers, not regions. This distinction is a common source of LLM confusion.
|
| 32 |
-
- **Delta-based rewards**: `reward = (new_overall_score - previous_overall_score) + flat_penalties`. Scoring is deterministic for RL training reproducibility.
|
| 33 |
-
- **Concept coverage**: The LLM claims concepts from a checklist; coverage is verified via narration evidencing with prefix matching + alias expansion (`_CONCEPT_ALIASES` and `_CONCEPT_PART_ALIASES` in scoring.py).
|
| 34 |
-
|
| 35 |
-
## Running
|
| 36 |
-
|
| 37 |
-
### Environment setup
|
| 38 |
-
```bash
|
| 39 |
-
conda activate unsloth_env
|
| 40 |
-
```
|
| 41 |
-
|
| 42 |
-
### Run tests
|
| 43 |
-
```bash
|
| 44 |
-
conda run -n unsloth_env python -m pytest tests/ -v
|
| 45 |
-
```
|
| 46 |
-
|
| 47 |
-
### Run inference (headless)
|
| 48 |
-
```bash
|
| 49 |
-
# Start the server first (in another terminal or via Docker)
|
| 50 |
-
LOCAL_IMAGE_NAME=http://127.0.0.1:8000 python inference.py
|
| 51 |
-
```
|
| 52 |
-
|
| 53 |
-
### Run inference with tldraw viewer
|
| 54 |
-
```bash
|
| 55 |
-
python inference_tldraw.py
|
| 56 |
-
# Opens browser at http://0.0.0.0:8765/
|
| 57 |
-
```
|
| 58 |
-
|
| 59 |
-
### Environment variables
|
| 60 |
-
- `LOCAL_IMAGE_NAME` — Docker image or `http://127.0.0.1:PORT` for local server
|
| 61 |
-
- `API_BASE_URL` — LLM API endpoint (default: HuggingFace router)
|
| 62 |
-
- `API_KEY` / `HF_TOKEN` — API authentication
|
| 63 |
-
- `MODEL_NAME` — LLM model (default: Qwen/Qwen2.5-72B-Instruct)
|
| 64 |
-
- `VISUAL_REASONING_TASKS` — Comma-separated task list (default: easy,medium,hard,expert)
|
| 65 |
-
- `DEBUG=1` — Enable verbose debug logging
|
| 66 |
-
- `VIS_PORT`, `VIS_HOST`, `VIS_WAIT` — tldraw viewer settings
|
| 67 |
-
|
| 68 |
-
## Scoring System
|
| 69 |
-
|
| 70 |
-
13 weighted sub-scores (weights vary by difficulty level):
|
| 71 |
-
- **validity** (~10-12%): Correct op formats, no invalid references
|
| 72 |
-
- **invariant** (~18-22%): Algorithm correctness checked against ground truth
|
| 73 |
-
- **coverage** (~17-18%): Concept checklist completion via narration evidencing
|
| 74 |
-
- **narration_quality** (~6-10%): Cosine similarity against reference narrations
|
| 75 |
-
- **structure** (~5-7%): Constraint satisfaction + entity monotonicity
|
| 76 |
-
- **progress** (~4-5%): State must change each step; granularity penalty for >4 creations per step
|
| 77 |
-
- **algorithm_completion** (~5%): Cumulative algorithm progress (% nodes placed, edges drawn, roles assigned)
|
| 78 |
-
- **spatial** (~6%): Region placement on the canvas grid (semantic fit, collision avoidance, reading-order flow)
|
| 79 |
-
- **consistency** (~5-7%): Unexplained entity changes are penalized
|
| 80 |
-
- **attention_coherence** (~6-7%): Narration entities match op targets (fuzzy matching)
|
| 81 |
-
- **visual** (~2%): Layout overlap/occlusion/crossing penalties
|
| 82 |
-
- **cognitive_pacing** (~7-8%): Information density vs. novelty; penalizes creation-heavy dumps
|
| 83 |
-
- **scaffolding** (~7-9%): Emphasis decreases over repeated patterns
|
| 84 |
-
|
| 85 |
-
5 penalties subtracted from weighted sum:
|
| 86 |
-
- `penalty_redundant` (0.2): All ops duplicate existing state
|
| 87 |
-
- `penalty_no_op` (applied as flat -0.05 in env): Zero state delta
|
| 88 |
-
- `penalty_unsupported_claims` (up to 0.3): Claiming concepts not evidenced in narration
|
| 89 |
-
- `penalty_too_many_ops` (up to 0.5): Exceeding MAX_OPS_PER_STEP (14)
|
| 90 |
-
- `penalty_info_dump` (up to 0.2): >5 creation ops in a single step (0.05 per excess)
|
| 91 |
-
|
| 92 |
-
## Algorithms / Scenarios
|
| 93 |
-
|
| 94 |
-
9 algorithm templates across 4 difficulty levels:
|
| 95 |
-
- **easy**: linked_list_traversal, stack_ops, binary_search
|
| 96 |
-
- **medium**: bfs_graph, hash_table_chaining
|
| 97 |
-
- **hard**: dijkstra_step, bst_insert
|
| 98 |
-
- **expert**: fib_memo, quicksort_lomuto
|
| 99 |
-
|
| 100 |
-
Static scenarios in `server/scenarios/scenarios.json`, plus procedurally generated ones via `scenario_generator.py`.
|
| 101 |
-
|
| 102 |
-
## Common Pitfalls When Modifying
|
| 103 |
-
|
| 104 |
-
1. **Scoring must be deterministic** — no randomness, no floating-point order sensitivity. The regression test (`test_easy_1_reproducible`) enforces bit-identical scores across runs.
|
| 105 |
-
2. **`first_conflict_message` takes an optional `action` arg** — pass it for context-specific error messages that help the LLM self-correct.
|
| 106 |
-
3. **`compute_progress_score` takes an optional `action` arg** — needed so `complete` steps get progress=1.0.
|
| 107 |
-
4. **Concept evidencing uses three layers**: exact token match → prefix morphological match (`_prefix_match`) → alias expansion (`_CONCEPT_ALIASES` / `_CONCEPT_PART_ALIASES`). When adding new scenarios, ensure all checklist concepts have corresponding aliases.
|
| 108 |
-
5. **The narration scorer uses Qwen3-Embedding-0.6B on CUDA** — falls back to a heuristic `_fallback_score` if the model fails to load. Check `warmup_scorer` logs to confirm which path runs.
|
| 109 |
-
6. **The `.venv` Python is broken (3.7 binary, 3.10 site-packages)** — always use `conda run -n unsloth_env` for running code.
|
| 110 |
-
7. **`openenv` import resolution**: Server modules use try/except for relative vs absolute imports. Tests run from the project root with `PYTHONPATH=.`.
|
| 111 |
-
8. **The inference loop uses `run_in_executor`** for LLM calls to avoid blocking the async event loop (which would cause WebSocket keepalive timeouts).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -11,7 +11,6 @@ tags:
|
|
| 11 |
- reinforcement-learning
|
| 12 |
- llm
|
| 13 |
- grpo
|
| 14 |
-
base_path: /web
|
| 15 |
---
|
| 16 |
|
| 17 |
# Visual Reasoning Environment
|
|
@@ -140,33 +139,32 @@ Clipped to `[-0.2, 1.0]`. Designed for GRPO / RLVR training.
|
|
| 140 |
Here's the full picture -- how the agent, the environment, and the reward signal fit together:
|
| 141 |
|
| 142 |
```
|
| 143 |
-
┌─────────────────────────────────────────────────────────────────
|
| 144 |
-
│
|
| 145 |
-
│
|
| 146 |
-
│ ┌───────────┐
|
| 147 |
-
│ │ │ ────────
|
| 148 |
-
│ │ Scenario │
|
| 149 |
-
│ │ Generator│
|
| 150 |
-
│ │ │
|
| 151 |
-
│ └───────────┘
|
| 152 |
-
│
|
| 153 |
-
│
|
| 154 |
-
│
|
| 155 |
-
│
|
| 156 |
-
│
|
| 157 |
-
│
|
| 158 |
-
│
|
| 159 |
-
│
|
| 160 |
-
│
|
| 161 |
-
│
|
| 162 |
-
│
|
| 163 |
-
│
|
| 164 |
-
│
|
| 165 |
-
│
|
| 166 |
-
│
|
| 167 |
-
│ Episode: empty canvas ──> Phase 1
|
| 168 |
-
|
| 169 |
-
└─────────────────────────────────────────────────────────────────────┘
|
| 170 |
```
|
| 171 |
|
| 172 |
Every episode starts with a blank canvas and a goal like *"Explain how Dijkstra's algorithm finds shortest paths in this graph."* The agent draws, narrates, and advances step by step. The scoring engine evaluates each step across all 13 dimensions. The reward signal flows back into the RL training loop, gradually shaping the agent into a better teacher.
|
|
@@ -265,19 +263,21 @@ The overall score jumped from **0.368 to 0.536** -- a 45.7% relative improvement
|
|
| 265 |
```
|
| 266 |
SFT+GRPO Score by Difficulty (Qwen2.5-3B, single A100)
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
0.
|
| 273 |
-
|
| 274 |
-
0.
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
| 281 |
```
|
| 282 |
|
| 283 |
A few things stand out from these results:
|
|
|
|
| 11 |
- reinforcement-learning
|
| 12 |
- llm
|
| 13 |
- grpo
|
|
|
|
| 14 |
---
|
| 15 |
|
| 16 |
# Visual Reasoning Environment
|
|
|
|
| 139 |
Here's the full picture -- how the agent, the environment, and the reward signal fit together:
|
| 140 |
|
| 141 |
```
|
| 142 |
+
┌─────────────────────────────────────────────────────────────────┐
|
| 143 |
+
│ TRAINING LOOP (GRPO / RLVR) │
|
| 144 |
+
│ │
|
| 145 |
+
│ ┌───────────┐ prompt ┌──────────────┐ JSON action │
|
| 146 |
+
│ │ │ ────────> │ │ ──────────────┐ │
|
| 147 |
+
│ │ Scenario │ │ LLM Agent │ │ │
|
| 148 |
+
│ │ Generator│ │ (Teacher) │ │ │
|
| 149 |
+
│ │ │ ┌──────> │ │ <────────┐ │ │
|
| 150 |
+
│ └───────────┘ │ └──────────────┘ │ │ │
|
| 151 |
+
│ │ │ │ │
|
| 152 |
+
│ observation reward │ │
|
| 153 |
+
│ + score breakdown signal │ │
|
| 154 |
+
│ │ │ │ │
|
| 155 |
+
│ ┌────────┴───────┐ score ┌──────┴─┐ │ │
|
| 156 |
+
│ │ │ <──────────────── │ │ │ │
|
| 157 |
+
│ │ Environment │ │Scoring │ │ │
|
| 158 |
+
│ │ (Empty Canvas) │ ───────────────> │ Engine │ │ │
|
| 159 |
+
│ │ │ canvas state │(13 dim)│ │ │
|
| 160 |
+
│ └────────────────┘ └────────┘ │ │
|
| 161 |
+
│ ^ │ │
|
| 162 |
+
│ │ step(action) │ │
|
| 163 |
+
│ └───────────────────────────────────────┘ │
|
| 164 |
+
│ │
|
| 165 |
+
│ reward = Δ(overall_score) + penalties + concept_bonuses │
|
| 166 |
+
│ Episode: empty canvas ──> Phase 1 ──> Phase 2 ──> Phase 3 │
|
| 167 |
+
└─────────────────────────────────────────────────────────────────┘
|
|
|
|
| 168 |
```
|
| 169 |
|
| 170 |
Every episode starts with a blank canvas and a goal like *"Explain how Dijkstra's algorithm finds shortest paths in this graph."* The agent draws, narrates, and advances step by step. The scoring engine evaluates each step across all 13 dimensions. The reward signal flows back into the RL training loop, gradually shaping the agent into a better teacher.
|
|
|
|
| 263 |
```
|
| 264 |
SFT+GRPO Score by Difficulty (Qwen2.5-3B, single A100)
|
| 265 |
|
| 266 |
+
|
|
| 267 |
+
0.7 |
|
| 268 |
+
| ┌────────┐
|
| 269 |
+
0.6 | ┌────────┐ │ 0.635 │
|
| 270 |
+
| │ 0.566 │ └────────┘
|
| 271 |
+
0.5 | ┌────────┐ └────────┘
|
| 272 |
+
| │ 0.481 │ ┌────────┐
|
| 273 |
+
0.4 | └────────┘ │ 0.461 │
|
| 274 |
+
| └────────┘
|
| 275 |
+
0.3 | +33.6% +28.4% +21.5% +120.5%
|
| 276 |
+
|
|
| 277 |
+
0.2 | ░ 0.360 ░ ░ 0.359 ░ ░ 0.466 ░ ░ 0.288 ░
|
| 278 |
+
| Baseline
|
| 279 |
+
+─────────────────────────────────────────────────
|
| 280 |
+
Easy Medium Hard Expert
|
| 281 |
```
|
| 282 |
|
| 283 |
A few things stand out from these results:
|
inference_audio.py
DELETED
|
@@ -1,451 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from __future__ import annotations
|
| 8 |
-
|
| 9 |
-
import asyncio
|
| 10 |
-
import base64
|
| 11 |
-
import contextlib
|
| 12 |
-
import json
|
| 13 |
-
import os
|
| 14 |
-
import sys
|
| 15 |
-
import time
|
| 16 |
-
from pathlib import Path
|
| 17 |
-
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 18 |
-
|
| 19 |
-
import uvicorn
|
| 20 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 21 |
-
from fastapi.responses import FileResponse
|
| 22 |
-
from openai import OpenAI
|
| 23 |
-
|
| 24 |
-
from client import VisualReasoningEnv
|
| 25 |
-
from models import VisualReasoningAction
|
| 26 |
-
|
| 27 |
-
from inference import (
|
| 28 |
-
API_BASE_URL,
|
| 29 |
-
API_KEY,
|
| 30 |
-
BENCHMARK,
|
| 31 |
-
HF_SPACE_URL,
|
| 32 |
-
LOCAL_IMAGE_NAME,
|
| 33 |
-
MAX_STEPS,
|
| 34 |
-
MODEL_NAME,
|
| 35 |
-
SUCCESS_SCORE_THRESHOLD,
|
| 36 |
-
TASK_NAMES,
|
| 37 |
-
action_to_string,
|
| 38 |
-
debug,
|
| 39 |
-
get_model_action_async,
|
| 40 |
-
log_end,
|
| 41 |
-
log_start,
|
| 42 |
-
log_step,
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
VIS_PORT = int(os.getenv("VIS_PORT", "8765"))
|
| 46 |
-
VIS_HOST = os.getenv("VIS_HOST", "0.0.0.0")
|
| 47 |
-
VIS_WAIT = float(os.getenv("VIS_WAIT", "30"))
|
| 48 |
-
|
| 49 |
-
VIEWER_DIR = Path(__file__).parent / "viewer"
|
| 50 |
-
HTML_PATH = VIEWER_DIR / "audio_viewer.html"
|
| 51 |
-
AUDIO_DIR = Path(__file__).parent / "others"
|
| 52 |
-
BACKGROUND_MUSIC_PATH = AUDIO_DIR / "tutorial_background.mp3"
|
| 53 |
-
HISTORY_CAP = 500
|
| 54 |
-
|
| 55 |
-
TTS_LEAD_TIME = 2.0
|
| 56 |
-
TTS_MIN_WAIT = 1.5
|
| 57 |
-
TTS_MODEL = "hexgrad/Kokoro-82M"
|
| 58 |
-
TTS_PROVIDER = "fal-ai"
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def vlog(msg: str) -> None:
|
| 62 |
-
print(f"[AUDIO] {msg}", file=sys.stderr, flush=True)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# ---------------------------------------------------------------------------
|
| 66 |
-
# TTS via huggingface_hub InferenceClient
|
| 67 |
-
# ---------------------------------------------------------------------------
|
| 68 |
-
|
| 69 |
-
def _make_tts_client():
|
| 70 |
-
from huggingface_hub import InferenceClient
|
| 71 |
-
return InferenceClient(provider=TTS_PROVIDER, api_key=os.environ.get("HF_TOKEN", ""))
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
_tts_client = None
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def _get_tts_client():
|
| 78 |
-
global _tts_client
|
| 79 |
-
if _tts_client is None:
|
| 80 |
-
_tts_client = _make_tts_client()
|
| 81 |
-
return _tts_client
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def _estimate_duration(audio_bytes: bytes) -> float:
|
| 85 |
-
if len(audio_bytes) < 44:
|
| 86 |
-
return 0.0
|
| 87 |
-
# Try WAV header
|
| 88 |
-
if audio_bytes[:4] == b"RIFF" and audio_bytes[8:12] == b"WAVE":
|
| 89 |
-
import struct
|
| 90 |
-
try:
|
| 91 |
-
byte_rate = struct.unpack_from("<I", audio_bytes, 28)[0]
|
| 92 |
-
data_offset = audio_bytes.find(b"data")
|
| 93 |
-
if data_offset >= 0 and byte_rate > 0:
|
| 94 |
-
data_size = struct.unpack_from("<I", audio_bytes, data_offset + 4)[0]
|
| 95 |
-
return data_size / byte_rate
|
| 96 |
-
except Exception:
|
| 97 |
-
pass
|
| 98 |
-
# Rough estimate for compressed audio (~16kB/s for mp3 at 128kbps)
|
| 99 |
-
return len(audio_bytes) / 16000.0
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def _synthesize_sync(text: str) -> Tuple[Optional[str], float]:
|
| 103 |
-
if not text or not text.strip():
|
| 104 |
-
return None, 0.0
|
| 105 |
-
try:
|
| 106 |
-
client = _get_tts_client()
|
| 107 |
-
audio_bytes = client.text_to_speech(text, model=TTS_MODEL)
|
| 108 |
-
if not audio_bytes:
|
| 109 |
-
vlog("TTS returned empty audio")
|
| 110 |
-
return None, 0.0
|
| 111 |
-
duration = _estimate_duration(audio_bytes)
|
| 112 |
-
vlog(f"TTS ok: {len(audio_bytes)} bytes, ~{duration:.1f}s")
|
| 113 |
-
return base64.b64encode(audio_bytes).decode("ascii"), duration
|
| 114 |
-
except Exception as exc:
|
| 115 |
-
vlog(f"TTS error: {exc}")
|
| 116 |
-
return None, 0.0
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
# ---------------------------------------------------------------------------
|
| 120 |
-
# Broadcaster
|
| 121 |
-
# ---------------------------------------------------------------------------
|
| 122 |
-
|
| 123 |
-
class Broadcaster:
|
| 124 |
-
def __init__(self) -> None:
|
| 125 |
-
self._clients: Set[WebSocket] = set()
|
| 126 |
-
self._history: List[Dict[str, Any]] = []
|
| 127 |
-
self._first_client = asyncio.Event()
|
| 128 |
-
self._lock = asyncio.Lock()
|
| 129 |
-
|
| 130 |
-
async def register(self, ws: WebSocket) -> None:
|
| 131 |
-
async with self._lock:
|
| 132 |
-
self._clients.add(ws)
|
| 133 |
-
replay = list(self._history)
|
| 134 |
-
self._first_client.set()
|
| 135 |
-
vlog(f"connected clients={len(self._clients)}")
|
| 136 |
-
for msg in replay:
|
| 137 |
-
try:
|
| 138 |
-
await ws.send_text(json.dumps(msg))
|
| 139 |
-
except Exception:
|
| 140 |
-
return
|
| 141 |
-
|
| 142 |
-
async def unregister(self, ws: WebSocket) -> None:
|
| 143 |
-
async with self._lock:
|
| 144 |
-
self._clients.discard(ws)
|
| 145 |
-
vlog(f"client disconnected clients={len(self._clients)}")
|
| 146 |
-
|
| 147 |
-
async def send(self, msg: Dict[str, Any]) -> None:
|
| 148 |
-
history_msg = {k: v for k, v in msg.items() if k != "audio"}
|
| 149 |
-
async with self._lock:
|
| 150 |
-
self._history.append(history_msg)
|
| 151 |
-
if len(self._history) > HISTORY_CAP:
|
| 152 |
-
self._history = self._history[-HISTORY_CAP:]
|
| 153 |
-
targets = list(self._clients)
|
| 154 |
-
if not targets:
|
| 155 |
-
return
|
| 156 |
-
payload = json.dumps(msg)
|
| 157 |
-
dead: List[WebSocket] = []
|
| 158 |
-
for ws in targets:
|
| 159 |
-
try:
|
| 160 |
-
await ws.send_text(payload)
|
| 161 |
-
except Exception:
|
| 162 |
-
dead.append(ws)
|
| 163 |
-
if dead:
|
| 164 |
-
async with self._lock:
|
| 165 |
-
for ws in dead:
|
| 166 |
-
self._clients.discard(ws)
|
| 167 |
-
|
| 168 |
-
async def wait_for_client(self, timeout: float) -> bool:
|
| 169 |
-
if timeout <= 0:
|
| 170 |
-
return self._first_client.is_set()
|
| 171 |
-
try:
|
| 172 |
-
await asyncio.wait_for(self._first_client.wait(), timeout=timeout)
|
| 173 |
-
return True
|
| 174 |
-
except asyncio.TimeoutError:
|
| 175 |
-
return False
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
# ---------------------------------------------------------------------------
|
| 179 |
-
# FastAPI app
|
| 180 |
-
# ---------------------------------------------------------------------------
|
| 181 |
-
|
| 182 |
-
def build_viewer_app(broadcaster: Broadcaster) -> FastAPI:
|
| 183 |
-
app = FastAPI()
|
| 184 |
-
|
| 185 |
-
@app.get("/")
|
| 186 |
-
async def index():
|
| 187 |
-
return FileResponse(str(HTML_PATH), media_type="text/html")
|
| 188 |
-
|
| 189 |
-
@app.get("/audio/background.mp3")
|
| 190 |
-
async def background_music():
|
| 191 |
-
if BACKGROUND_MUSIC_PATH.exists():
|
| 192 |
-
return FileResponse(str(BACKGROUND_MUSIC_PATH), media_type="audio/mpeg")
|
| 193 |
-
return {"error": "background music not found"}
|
| 194 |
-
|
| 195 |
-
@app.get("/health")
|
| 196 |
-
async def health():
|
| 197 |
-
return {"ok": True}
|
| 198 |
-
|
| 199 |
-
@app.websocket("/ws")
|
| 200 |
-
async def ws_endpoint(ws: WebSocket):
|
| 201 |
-
await ws.accept()
|
| 202 |
-
await broadcaster.register(ws)
|
| 203 |
-
try:
|
| 204 |
-
while True:
|
| 205 |
-
await ws.receive_text()
|
| 206 |
-
except WebSocketDisconnect:
|
| 207 |
-
pass
|
| 208 |
-
except Exception as exc:
|
| 209 |
-
debug(f"ws error: {exc}")
|
| 210 |
-
finally:
|
| 211 |
-
await broadcaster.unregister(ws)
|
| 212 |
-
|
| 213 |
-
return app
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
async def start_viewer(broadcaster: Broadcaster):
|
| 217 |
-
if not HTML_PATH.exists():
|
| 218 |
-
vlog(f"ERROR: viewer HTML missing at {HTML_PATH}")
|
| 219 |
-
sys.exit(1)
|
| 220 |
-
app = build_viewer_app(broadcaster)
|
| 221 |
-
config = uvicorn.Config(
|
| 222 |
-
app,
|
| 223 |
-
host=VIS_HOST,
|
| 224 |
-
port=VIS_PORT,
|
| 225 |
-
log_config=None,
|
| 226 |
-
access_log=False,
|
| 227 |
-
log_level="warning",
|
| 228 |
-
)
|
| 229 |
-
server = uvicorn.Server(config)
|
| 230 |
-
task = asyncio.create_task(server.serve())
|
| 231 |
-
for _ in range(50):
|
| 232 |
-
if server.started:
|
| 233 |
-
break
|
| 234 |
-
if task.done():
|
| 235 |
-
exc = task.exception()
|
| 236 |
-
vlog(f"viewer server failed to start: {exc}")
|
| 237 |
-
sys.exit(1)
|
| 238 |
-
await asyncio.sleep(0.05)
|
| 239 |
-
return server, task
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
# ---------------------------------------------------------------------------
|
| 243 |
-
# Observation snapshot
|
| 244 |
-
# ---------------------------------------------------------------------------
|
| 245 |
-
|
| 246 |
-
def _obs_snapshot(obs: Any) -> Dict[str, Any]:
|
| 247 |
-
return {
|
| 248 |
-
"entities": {k: dict(v) for k, v in (obs.entities or {}).items()},
|
| 249 |
-
"relations": [dict(r) for r in (obs.relations or [])],
|
| 250 |
-
"layout": {k: dict(v) for k, v in (obs.layout or {}).items()},
|
| 251 |
-
"annotations": [dict(a) for a in (obs.annotations or [])],
|
| 252 |
-
"notes": [dict(n) for n in (obs.notes or [])],
|
| 253 |
-
"score_breakdown": {
|
| 254 |
-
k: (float(v) if isinstance(v, (int, float)) else v)
|
| 255 |
-
for k, v in (obs.score_breakdown or {}).items()
|
| 256 |
-
},
|
| 257 |
-
"coverage": list(obs.concept_coverage),
|
| 258 |
-
"narration_history": list(obs.narration_history),
|
| 259 |
-
"remaining_step_budget": obs.remaining_step_budget,
|
| 260 |
-
}
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
# ---------------------------------------------------------------------------
|
| 264 |
-
# Episode runner with audio
|
| 265 |
-
# ---------------------------------------------------------------------------
|
| 266 |
-
|
| 267 |
-
async def run_episode_streaming(
|
| 268 |
-
env: Any, client: OpenAI, task_name: str, broadcaster: Broadcaster
|
| 269 |
-
) -> None:
|
| 270 |
-
history: List[str] = []
|
| 271 |
-
rewards: List[float] = []
|
| 272 |
-
steps_taken = 0
|
| 273 |
-
score = 0.0
|
| 274 |
-
success = False
|
| 275 |
-
last_reward = 0.0
|
| 276 |
-
last_action: Optional[Dict[str, Any]] = None
|
| 277 |
-
obs = None
|
| 278 |
-
loop = asyncio.get_running_loop()
|
| 279 |
-
|
| 280 |
-
log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
|
| 281 |
-
|
| 282 |
-
try:
|
| 283 |
-
result = await env.reset(task_name=task_name)
|
| 284 |
-
obs = result.observation
|
| 285 |
-
debug(
|
| 286 |
-
f"RESET: scenario={obs.scenario_id} goal={obs.goal} "
|
| 287 |
-
f"budget={obs.remaining_step_budget}"
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
-
goal_audio, goal_dur = await loop.run_in_executor(
|
| 291 |
-
None, _synthesize_sync, obs.goal
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
snap = _obs_snapshot(obs)
|
| 295 |
-
await broadcaster.send(
|
| 296 |
-
{
|
| 297 |
-
"type": "reset",
|
| 298 |
-
"task_name": obs.task_name,
|
| 299 |
-
"scenario_id": obs.scenario_id,
|
| 300 |
-
"goal": obs.goal,
|
| 301 |
-
"checklist": list(obs.concept_checklist),
|
| 302 |
-
"input_data": dict(obs.input_data),
|
| 303 |
-
"constraints": list(obs.constraints),
|
| 304 |
-
"max_steps": obs.max_steps,
|
| 305 |
-
"audio": goal_audio,
|
| 306 |
-
"audio_duration": goal_dur,
|
| 307 |
-
**snap,
|
| 308 |
-
}
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
if goal_dur > 0:
|
| 312 |
-
await asyncio.sleep(max(TTS_MIN_WAIT, goal_dur - TTS_LEAD_TIME))
|
| 313 |
-
else:
|
| 314 |
-
await asyncio.sleep(TTS_MIN_WAIT)
|
| 315 |
-
|
| 316 |
-
for step in range(1, MAX_STEPS + 1):
|
| 317 |
-
if result.done:
|
| 318 |
-
break
|
| 319 |
-
|
| 320 |
-
step_start = time.monotonic()
|
| 321 |
-
obs = result.observation
|
| 322 |
-
|
| 323 |
-
action_dict = await get_model_action_async(
|
| 324 |
-
client, obs, last_action, last_reward, history
|
| 325 |
-
)
|
| 326 |
-
action = VisualReasoningAction(**action_dict)
|
| 327 |
-
narration = action_dict.get("narration", "")
|
| 328 |
-
|
| 329 |
-
env_future = asyncio.ensure_future(env.step(action))
|
| 330 |
-
tts_future = loop.run_in_executor(None, _synthesize_sync, narration)
|
| 331 |
-
|
| 332 |
-
result = await env_future
|
| 333 |
-
audio_b64, audio_dur = await tts_future
|
| 334 |
-
|
| 335 |
-
obs = result.observation
|
| 336 |
-
reward = result.reward or 0.0
|
| 337 |
-
done = result.done
|
| 338 |
-
error = obs.action_error
|
| 339 |
-
|
| 340 |
-
rewards.append(reward)
|
| 341 |
-
steps_taken = step
|
| 342 |
-
last_reward = reward
|
| 343 |
-
last_action = action_dict
|
| 344 |
-
|
| 345 |
-
log_step(
|
| 346 |
-
step=step,
|
| 347 |
-
action=action_to_string(action_dict),
|
| 348 |
-
reward=reward,
|
| 349 |
-
done=done,
|
| 350 |
-
error=error,
|
| 351 |
-
)
|
| 352 |
-
history.append(f"Step {step}: action={action_to_string(action_dict)}")
|
| 353 |
-
|
| 354 |
-
snap = _obs_snapshot(obs)
|
| 355 |
-
await broadcaster.send(
|
| 356 |
-
{
|
| 357 |
-
"type": "step",
|
| 358 |
-
"task_name": obs.task_name,
|
| 359 |
-
"scenario_id": obs.scenario_id,
|
| 360 |
-
"step": step,
|
| 361 |
-
"step_type": action_dict.get("step_type"),
|
| 362 |
-
"intent": action_dict.get("intent", ""),
|
| 363 |
-
"narration": narration,
|
| 364 |
-
"ops": action_dict.get("ops", []),
|
| 365 |
-
"covered_concepts": action_dict.get("covered_concepts", []),
|
| 366 |
-
"reward": float(reward),
|
| 367 |
-
"score": float(obs.score_breakdown.get("overall_score", 0.0)),
|
| 368 |
-
"done": bool(done),
|
| 369 |
-
"error": error,
|
| 370 |
-
"audio": audio_b64,
|
| 371 |
-
"audio_duration": audio_dur,
|
| 372 |
-
**snap,
|
| 373 |
-
}
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
if done:
|
| 377 |
-
if audio_dur > 0:
|
| 378 |
-
await asyncio.sleep(max(0, audio_dur + 0.5))
|
| 379 |
-
break
|
| 380 |
-
|
| 381 |
-
elapsed = time.monotonic() - step_start
|
| 382 |
-
target = max(TTS_MIN_WAIT, audio_dur - TTS_LEAD_TIME) if audio_dur > 0 else TTS_MIN_WAIT
|
| 383 |
-
remaining = max(0, target - elapsed)
|
| 384 |
-
if remaining > 0:
|
| 385 |
-
await asyncio.sleep(remaining)
|
| 386 |
-
|
| 387 |
-
if obs is not None and steps_taken > 0:
|
| 388 |
-
score = float(obs.score_breakdown.get("overall_score", 0.0))
|
| 389 |
-
score = min(max(score, 0.0), 1.0)
|
| 390 |
-
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 391 |
-
|
| 392 |
-
finally:
|
| 393 |
-
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 394 |
-
await broadcaster.send(
|
| 395 |
-
{
|
| 396 |
-
"type": "end",
|
| 397 |
-
"task_name": task_name,
|
| 398 |
-
"success": bool(success),
|
| 399 |
-
"steps": steps_taken,
|
| 400 |
-
"score": float(score),
|
| 401 |
-
"rewards": [float(r) for r in rewards],
|
| 402 |
-
}
|
| 403 |
-
)
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
# ---------------------------------------------------------------------------
|
| 407 |
-
# Main
|
| 408 |
-
# ---------------------------------------------------------------------------
|
| 409 |
-
|
| 410 |
-
async def main() -> None:
|
| 411 |
-
broadcaster = Broadcaster()
|
| 412 |
-
server, server_task = await start_viewer(broadcaster)
|
| 413 |
-
|
| 414 |
-
vlog(f"open http://{VIS_HOST}:{VIS_PORT}/ in your browser")
|
| 415 |
-
if not BACKGROUND_MUSIC_PATH.exists():
|
| 416 |
-
vlog(f"WARNING: background music not found at {BACKGROUND_MUSIC_PATH}")
|
| 417 |
-
if VIS_WAIT > 0:
|
| 418 |
-
vlog(f"waiting up to {VIS_WAIT:.0f}s for a browser connection...")
|
| 419 |
-
connected = await broadcaster.wait_for_client(VIS_WAIT)
|
| 420 |
-
if not connected:
|
| 421 |
-
vlog("proceeding without viewer (no browser connected in time)")
|
| 422 |
-
else:
|
| 423 |
-
vlog("VIS_WAIT=0, starting immediately")
|
| 424 |
-
|
| 425 |
-
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 426 |
-
|
| 427 |
-
if LOCAL_IMAGE_NAME:
|
| 428 |
-
if LOCAL_IMAGE_NAME.startswith("http://127.0.0.1:"):
|
| 429 |
-
env = VisualReasoningEnv(base_url=LOCAL_IMAGE_NAME, message_timeout_s=120)
|
| 430 |
-
else:
|
| 431 |
-
env = await VisualReasoningEnv.from_docker_image(LOCAL_IMAGE_NAME)
|
| 432 |
-
else:
|
| 433 |
-
env = await VisualReasoningEnv.from_env(HF_SPACE_URL, use_docker=False)
|
| 434 |
-
|
| 435 |
-
try:
|
| 436 |
-
for task_name in TASK_NAMES:
|
| 437 |
-
await run_episode_streaming(env, client, task_name, broadcaster)
|
| 438 |
-
await broadcaster.send({"type": "shutdown"})
|
| 439 |
-
await asyncio.sleep(0.5)
|
| 440 |
-
finally:
|
| 441 |
-
try:
|
| 442 |
-
await env.close()
|
| 443 |
-
except Exception as exc:
|
| 444 |
-
print(f"[DEBUG] env.close() error: {exc}", file=sys.stderr, flush=True)
|
| 445 |
-
server.should_exit = True
|
| 446 |
-
with contextlib.suppress(Exception):
|
| 447 |
-
await server_task
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
if __name__ == "__main__":
|
| 451 |
-
asyncio.run(main())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
openenv_visual_reasoning.egg-info/PKG-INFO
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
Metadata-Version: 2.4
|
| 2 |
-
Name: openenv-visual_reasoning
|
| 3 |
-
Version: 0.1.0
|
| 4 |
-
Summary: Visual Reasoning environment for OpenEnv — step-based RL for grounded visual + textual CS explanations
|
| 5 |
-
Requires-Python: >=3.10
|
| 6 |
-
Requires-Dist: openenv-core[core]>=0.2.2
|
| 7 |
-
Requires-Dist: numpy<2.0
|
| 8 |
-
Requires-Dist: python-dotenv>=1.0.0
|
| 9 |
-
Requires-Dist: networkx>=3.1
|
| 10 |
-
Requires-Dist: shapely>=2.0
|
| 11 |
-
Requires-Dist: sentence-transformers>=2.2
|
| 12 |
-
Requires-Dist: rapidfuzz>=3.0
|
| 13 |
-
Requires-Dist: textstat>=0.7
|
| 14 |
-
Requires-Dist: sortedcontainers>=2.4
|
| 15 |
-
Requires-Dist: aiohttp>=3.9
|
| 16 |
-
Requires-Dist: openai>=1.0
|
| 17 |
-
Provides-Extra: dev
|
| 18 |
-
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 19 |
-
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
openenv_visual_reasoning.egg-info/SOURCES.txt
DELETED
|
@@ -1,38 +0,0 @@
|
|
| 1 |
-
README.md
|
| 2 |
-
__init__.py
|
| 3 |
-
client.py
|
| 4 |
-
inference.py
|
| 5 |
-
inference_audio.py
|
| 6 |
-
inference_excalidraw.py
|
| 7 |
-
inference_tldraw.py
|
| 8 |
-
models.py
|
| 9 |
-
pyproject.toml
|
| 10 |
-
./__init__.py
|
| 11 |
-
./client.py
|
| 12 |
-
./inference.py
|
| 13 |
-
./inference_audio.py
|
| 14 |
-
./inference_excalidraw.py
|
| 15 |
-
./inference_tldraw.py
|
| 16 |
-
./models.py
|
| 17 |
-
openenv_visual_reasoning.egg-info/PKG-INFO
|
| 18 |
-
openenv_visual_reasoning.egg-info/SOURCES.txt
|
| 19 |
-
openenv_visual_reasoning.egg-info/dependency_links.txt
|
| 20 |
-
openenv_visual_reasoning.egg-info/entry_points.txt
|
| 21 |
-
openenv_visual_reasoning.egg-info/requires.txt
|
| 22 |
-
openenv_visual_reasoning.egg-info/top_level.txt
|
| 23 |
-
server/__init__.py
|
| 24 |
-
server/app.py
|
| 25 |
-
server/app_backup.py
|
| 26 |
-
server/constants.py
|
| 27 |
-
server/invariant_checkers.py
|
| 28 |
-
server/narration_scorer.py
|
| 29 |
-
server/pedagogical_scoring.py
|
| 30 |
-
server/regions.py
|
| 31 |
-
server/scenario_generator.py
|
| 32 |
-
server/scenario_loader.py
|
| 33 |
-
server/scoring.py
|
| 34 |
-
server/visual_reasoning_environment.py
|
| 35 |
-
tests/test_environment.py
|
| 36 |
-
tests/test_regions.py
|
| 37 |
-
tests/test_scenario_loader.py
|
| 38 |
-
tests/test_scoring.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
openenv_visual_reasoning.egg-info/dependency_links.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
|
|
|
|
|
|
openenv_visual_reasoning.egg-info/entry_points.txt
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
[console_scripts]
|
| 2 |
-
server = visual_reasoning.server.app:main
|
|
|
|
|
|
|
|
|
openenv_visual_reasoning.egg-info/requires.txt
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
openenv-core[core]>=0.2.2
|
| 2 |
-
numpy<2.0
|
| 3 |
-
python-dotenv>=1.0.0
|
| 4 |
-
networkx>=3.1
|
| 5 |
-
shapely>=2.0
|
| 6 |
-
sentence-transformers>=2.2
|
| 7 |
-
rapidfuzz>=3.0
|
| 8 |
-
textstat>=0.7
|
| 9 |
-
sortedcontainers>=2.4
|
| 10 |
-
aiohttp>=3.9
|
| 11 |
-
openai>=1.0
|
| 12 |
-
|
| 13 |
-
[dev]
|
| 14 |
-
pytest>=8.0.0
|
| 15 |
-
pytest-cov>=4.0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
openenv_visual_reasoning.egg-info/top_level.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
visual_reasoning
|
|
|
|
|
|
push_to_space.ipynb
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "34c098b1",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"/opt/conda/envs/unsloth_env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 14 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"source": [
|
| 19 |
+
"from dotenv import load_dotenv\n",
|
| 20 |
+
"from huggingface_hub import HfApi\n",
|
| 21 |
+
"import os"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": 2,
|
| 27 |
+
"id": "eff398ee",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"outputs": [
|
| 30 |
+
{
|
| 31 |
+
"data": {
|
| 32 |
+
"text/plain": [
|
| 33 |
+
"True"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
"execution_count": 2,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"output_type": "execute_result"
|
| 39 |
+
}
|
| 40 |
+
],
|
| 41 |
+
"source": [
|
| 42 |
+
"load_dotenv(\"../.env\")"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": 4,
|
| 48 |
+
"id": "fb24296b",
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"outputs": [],
|
| 51 |
+
"source": [
|
| 52 |
+
"# api = HfApi(token=os.getenv(\"HF_TOKEN\"))\n",
|
| 53 |
+
"# api.upload_folder(\n",
|
| 54 |
+
"# repo_id=\"sreeramajay/visual_reasoning-env\",\n",
|
| 55 |
+
"# folder_path=\".\",\n",
|
| 56 |
+
"# repo_type=\"space\",\n",
|
| 57 |
+
"# delete_patterns=[\"*\"], # deletes all remote files not present locally\n",
|
| 58 |
+
"# )"
|
| 59 |
+
]
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"cell_type": "code",
|
| 63 |
+
"execution_count": null,
|
| 64 |
+
"id": "02ea1bf0",
|
| 65 |
+
"metadata": {},
|
| 66 |
+
"outputs": [],
|
| 67 |
+
"source": [
|
| 68 |
+
"api = HfApi(token=os.getenv(\"HF_TOKEN\"))\n",
|
| 69 |
+
"api.upload_folder(\n",
|
| 70 |
+
" repo_id=\"sreeramajay/visual_reasoning-env\",\n",
|
| 71 |
+
" folder_path=\".\",\n",
|
| 72 |
+
" repo_type=\"space\",\n",
|
| 73 |
+
" delete_patterns=[\"*\"],\n",
|
| 74 |
+
" ignore_patterns=[\n",
|
| 75 |
+
" \".venv/**\",\n",
|
| 76 |
+
" \"venv/**\",\n",
|
| 77 |
+
" \"**/__pycache__/**\",\n",
|
| 78 |
+
" \"**/*.py[cod]\",\n",
|
| 79 |
+
" \"**/*.egg-info/**\",\n",
|
| 80 |
+
" \"dist/**\",\n",
|
| 81 |
+
" \"build/**\",\n",
|
| 82 |
+
" \".env\",\n",
|
| 83 |
+
" \".env.local\",\n",
|
| 84 |
+
" \".idea/**\",\n",
|
| 85 |
+
" \".vscode/**\",\n",
|
| 86 |
+
" \".pytest_cache/**\",\n",
|
| 87 |
+
" \"htmlcov/**\",\n",
|
| 88 |
+
" \".coverage\",\n",
|
| 89 |
+
" \"uv.lock\",\n",
|
| 90 |
+
" \"**/.ipynb_checkpoints/**\",\n",
|
| 91 |
+
" \"CLAUDE.md\",\n",
|
| 92 |
+
" \"openenv_visual_reasoning.egg-info/**\",\n",
|
| 93 |
+
" \".DS_Store\",\n",
|
| 94 |
+
" \"Thumbs.db\",\n",
|
| 95 |
+
" ],\n",
|
| 96 |
+
")"
|
| 97 |
+
]
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"cell_type": "code",
|
| 101 |
+
"execution_count": null,
|
| 102 |
+
"id": "3dc787ac",
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": []
|
| 106 |
+
}
|
| 107 |
+
],
|
| 108 |
+
"metadata": {
|
| 109 |
+
"kernelspec": {
|
| 110 |
+
"display_name": "unsloth_env",
|
| 111 |
+
"language": "python",
|
| 112 |
+
"name": "python3"
|
| 113 |
+
},
|
| 114 |
+
"language_info": {
|
| 115 |
+
"codemirror_mode": {
|
| 116 |
+
"name": "ipython",
|
| 117 |
+
"version": 3
|
| 118 |
+
},
|
| 119 |
+
"file_extension": ".py",
|
| 120 |
+
"mimetype": "text/x-python",
|
| 121 |
+
"name": "python",
|
| 122 |
+
"nbconvert_exporter": "python",
|
| 123 |
+
"pygments_lexer": "ipython3",
|
| 124 |
+
"version": "3.11.11"
|
| 125 |
+
}
|
| 126 |
+
},
|
| 127 |
+
"nbformat": 4,
|
| 128 |
+
"nbformat_minor": 5
|
| 129 |
+
}
|
scripts/generate_rubric_data.py
DELETED
|
@@ -1,1132 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import argparse
|
| 8 |
-
import json
|
| 9 |
-
import os
|
| 10 |
-
import sys
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
from typing import Any, Dict, List, Optional
|
| 13 |
-
|
| 14 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 15 |
-
|
| 16 |
-
from server.visual_reasoning_environment import VisualReasoningEnvironment
|
| 17 |
-
from models import VisualReasoningAction
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
LABELING_PROMPT = """Rate this narration on four dimensions (0.0 to 1.0, steps of 0.25):
|
| 21 |
-
|
| 22 |
-
CONTEXT:
|
| 23 |
-
- Algorithm: {template}
|
| 24 |
-
- Ops this step: {ops_summary}
|
| 25 |
-
- Previous narration: {prev_narration}
|
| 26 |
-
- Current narration: {narration}
|
| 27 |
-
- Concepts claimed: {covered_concepts}
|
| 28 |
-
- Progress: {step_progress}
|
| 29 |
-
|
| 30 |
-
DIMENSIONS:
|
| 31 |
-
1. explanatory_depth: 0.0=mechanical, 0.5=basic context, 1.0=explains why+broader concept
|
| 32 |
-
2. grounding_accuracy: 0.0=contradicts ops, 0.5=partial, 1.0=accurate+describes effects
|
| 33 |
-
3. clarity: 0.0=incomprehensible, 0.5=adequate, 1.0=clear natural voiceover
|
| 34 |
-
4. flow: 0.0=disconnected, 0.5=adequate connection, 1.0=natural continuation
|
| 35 |
-
|
| 36 |
-
Output ONLY JSON: {{"explanatory_depth": X, "grounding_accuracy": X, "clarity": X, "flow": X}}"""
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# ---------------------------------------------------------------------------
|
| 40 |
-
# easy_1: linked_list_traversal, incremental, values=[10,20,30]
|
| 41 |
-
# ---------------------------------------------------------------------------
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def _easy_1_excellent() -> List[Dict[str, Any]]:
|
| 45 |
-
return [
|
| 46 |
-
{
|
| 47 |
-
"step_type": "advance",
|
| 48 |
-
"narration": "We create a linked list region and place the head pointer at position 0 — this is the only entry point for traversal.",
|
| 49 |
-
"ops": [
|
| 50 |
-
{"op": "add_region", "target_ids": ["list"], "params": {"style": "queue", "title": "Linked List"}},
|
| 51 |
-
{"op": "add_pointer", "target_ids": ["head"], "params": {"region": "list", "index": 0}},
|
| 52 |
-
],
|
| 53 |
-
"covered_concepts": ["head_pointer"],
|
| 54 |
-
"intent": "setup",
|
| 55 |
-
},
|
| 56 |
-
{
|
| 57 |
-
"step_type": "advance",
|
| 58 |
-
"narration": "Adding node_1 with value 10 — the head pointer starts here, making it the first node we visit in the traversal.",
|
| 59 |
-
"ops": [
|
| 60 |
-
{"op": "add_node", "target_ids": ["node_1"], "params": {"value": 10, "region": "list"}},
|
| 61 |
-
{"op": "set_role", "target_ids": ["node_1"], "params": {"role": "current"}},
|
| 62 |
-
{"op": "annotate", "target_ids": ["node_1"], "params": {"text": "val=10"}},
|
| 63 |
-
],
|
| 64 |
-
"covered_concepts": ["node_value"],
|
| 65 |
-
"intent": "first-node",
|
| 66 |
-
},
|
| 67 |
-
{
|
| 68 |
-
"step_type": "advance",
|
| 69 |
-
"narration": "We add node_2 with value 20 and link it from node_1 via a next pointer — following these next links is how we traverse the list.",
|
| 70 |
-
"ops": [
|
| 71 |
-
{"op": "add_node", "target_ids": ["node_2"], "params": {"value": 20, "region": "list"}},
|
| 72 |
-
{"op": "add_edge", "target_ids": ["node_1", "node_2"], "params": {"label": "next"}},
|
| 73 |
-
{"op": "set_role", "target_ids": ["node_1"], "params": {"role": "visited"}},
|
| 74 |
-
{"op": "set_role", "target_ids": ["node_2"], "params": {"role": "current"}},
|
| 75 |
-
{"op": "move_pointer", "target_ids": ["head"], "params": {"index": 1}},
|
| 76 |
-
],
|
| 77 |
-
"covered_concepts": ["next_link"],
|
| 78 |
-
"intent": "second-node",
|
| 79 |
-
},
|
| 80 |
-
{
|
| 81 |
-
"step_type": "advance",
|
| 82 |
-
"narration": "Finally node_3 with value 30 has no next link — this null terminator marks it as the tail, ending our traversal.",
|
| 83 |
-
"ops": [
|
| 84 |
-
{"op": "add_node", "target_ids": ["node_3"], "params": {"value": 30, "region": "list"}},
|
| 85 |
-
{"op": "add_edge", "target_ids": ["node_2", "node_3"], "params": {"label": "next"}},
|
| 86 |
-
{"op": "set_role", "target_ids": ["node_2"], "params": {"role": "visited"}},
|
| 87 |
-
{"op": "set_role", "target_ids": ["node_3"], "params": {"role": "done"}},
|
| 88 |
-
{"op": "annotate", "target_ids": ["node_3"], "params": {"text": "tail (null next)"}},
|
| 89 |
-
],
|
| 90 |
-
"covered_concepts": ["tail_marker"],
|
| 91 |
-
"intent": "tail",
|
| 92 |
-
},
|
| 93 |
-
{
|
| 94 |
-
"step_type": "complete",
|
| 95 |
-
"narration": "The traversal is complete — we visited every node from head to tail by following next pointers.",
|
| 96 |
-
"ops": [],
|
| 97 |
-
"covered_concepts": [],
|
| 98 |
-
"intent": "done",
|
| 99 |
-
},
|
| 100 |
-
]
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def _easy_1_good() -> List[Dict[str, Any]]:
|
| 104 |
-
return [
|
| 105 |
-
{
|
| 106 |
-
"step_type": "advance",
|
| 107 |
-
"narration": "Setting up the linked list with a head pointer.",
|
| 108 |
-
"ops": [
|
| 109 |
-
{"op": "add_region", "target_ids": ["list"], "params": {"style": "queue", "title": "List"}},
|
| 110 |
-
{"op": "add_pointer", "target_ids": ["head"], "params": {"region": "list"}},
|
| 111 |
-
{"op": "add_node", "target_ids": ["n1"], "params": {"value": 10, "region": "list"}},
|
| 112 |
-
],
|
| 113 |
-
"covered_concepts": ["head_pointer", "node_value"],
|
| 114 |
-
"intent": "setup",
|
| 115 |
-
},
|
| 116 |
-
{
|
| 117 |
-
"step_type": "advance",
|
| 118 |
-
"narration": "Adding the next two nodes and connecting them.",
|
| 119 |
-
"ops": [
|
| 120 |
-
{"op": "add_node", "target_ids": ["n2"], "params": {"value": 20, "region": "list"}},
|
| 121 |
-
{"op": "add_node", "target_ids": ["n3"], "params": {"value": 30, "region": "list"}},
|
| 122 |
-
{"op": "add_edge", "target_ids": ["n1", "n2"], "params": {}},
|
| 123 |
-
{"op": "add_edge", "target_ids": ["n2", "n3"], "params": {}},
|
| 124 |
-
],
|
| 125 |
-
"covered_concepts": ["next_link"],
|
| 126 |
-
"intent": "build",
|
| 127 |
-
},
|
| 128 |
-
{
|
| 129 |
-
"step_type": "complete",
|
| 130 |
-
"narration": "The tail node has no next pointer, traversal ends here.",
|
| 131 |
-
"ops": [
|
| 132 |
-
{"op": "annotate", "target_ids": ["n3"], "params": {"text": "tail"}},
|
| 133 |
-
],
|
| 134 |
-
"covered_concepts": ["tail_marker"],
|
| 135 |
-
"intent": "done",
|
| 136 |
-
},
|
| 137 |
-
]
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
def _easy_1_mediocre() -> List[Dict[str, Any]]:
|
| 141 |
-
return [
|
| 142 |
-
{
|
| 143 |
-
"step_type": "advance",
|
| 144 |
-
"narration": "Adding nodes to the list.",
|
| 145 |
-
"ops": [
|
| 146 |
-
{"op": "add_node", "target_ids": ["a"], "params": {"value": 10}},
|
| 147 |
-
{"op": "add_node", "target_ids": ["b"], "params": {"value": 20}},
|
| 148 |
-
{"op": "add_node", "target_ids": ["c"], "params": {"value": 30}},
|
| 149 |
-
{"op": "add_edge", "target_ids": ["a", "b"], "params": {}},
|
| 150 |
-
{"op": "add_edge", "target_ids": ["b", "c"], "params": {}},
|
| 151 |
-
],
|
| 152 |
-
"covered_concepts": ["head_pointer", "node_value", "next_link", "tail_marker"],
|
| 153 |
-
"intent": "",
|
| 154 |
-
},
|
| 155 |
-
{
|
| 156 |
-
"step_type": "complete",
|
| 157 |
-
"narration": "Done.",
|
| 158 |
-
"ops": [],
|
| 159 |
-
"covered_concepts": [],
|
| 160 |
-
"intent": "",
|
| 161 |
-
},
|
| 162 |
-
]
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
def _easy_1_bad() -> List[Dict[str, Any]]:
|
| 166 |
-
return [
|
| 167 |
-
{
|
| 168 |
-
"step_type": "advance",
|
| 169 |
-
"narration": "Starting.",
|
| 170 |
-
"ops": [
|
| 171 |
-
{"op": "add_node", "target_ids": ["x"], "params": {}},
|
| 172 |
-
],
|
| 173 |
-
"covered_concepts": ["head_pointer", "node_value", "next_link", "tail_marker"],
|
| 174 |
-
"intent": "",
|
| 175 |
-
},
|
| 176 |
-
{
|
| 177 |
-
"step_type": "complete",
|
| 178 |
-
"narration": "Finished.",
|
| 179 |
-
"ops": [],
|
| 180 |
-
"covered_concepts": [],
|
| 181 |
-
"intent": "",
|
| 182 |
-
},
|
| 183 |
-
]
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
# ---------------------------------------------------------------------------
|
| 187 |
-
# easy_2: stack_ops, incremental, operations=["push A","push B","pop","push C"]
|
| 188 |
-
# ---------------------------------------------------------------------------
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def _easy_2_excellent() -> List[Dict[str, Any]]:
|
| 192 |
-
return [
|
| 193 |
-
{
|
| 194 |
-
"step_type": "advance",
|
| 195 |
-
"narration": "We set up a stack region with a container — stacks follow Last In First Out order, so we'll use ordered=false.",
|
| 196 |
-
"ops": [
|
| 197 |
-
{"op": "add_region", "target_ids": ["stk"], "params": {"style": "stack", "title": "Stack"}},
|
| 198 |
-
{"op": "add_container", "target_ids": ["stack"], "params": {"region": "stk", "ordered": False}},
|
| 199 |
-
{"op": "add_pointer", "target_ids": ["top"], "params": {"region": "stk", "index": -1}},
|
| 200 |
-
],
|
| 201 |
-
"covered_concepts": ["top_pointer"],
|
| 202 |
-
"intent": "setup",
|
| 203 |
-
},
|
| 204 |
-
{
|
| 205 |
-
"step_type": "advance",
|
| 206 |
-
"narration": "Push A onto the stack — A becomes the new top. Then push B on top of A, so B is now the top element.",
|
| 207 |
-
"ops": [
|
| 208 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "stk"}},
|
| 209 |
-
{"op": "push_to", "target_ids": ["stack", "A"], "params": {}},
|
| 210 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "stk"}},
|
| 211 |
-
{"op": "push_to", "target_ids": ["stack", "B"], "params": {}},
|
| 212 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": 1}},
|
| 213 |
-
],
|
| 214 |
-
"covered_concepts": ["push"],
|
| 215 |
-
"intent": "push-A-B",
|
| 216 |
-
},
|
| 217 |
-
{
|
| 218 |
-
"step_type": "advance",
|
| 219 |
-
"narration": "Pop removes B — the most recently pushed element — revealing A underneath. This is the LIFO property in action.",
|
| 220 |
-
"ops": [
|
| 221 |
-
{"op": "pop_from", "target_ids": ["stack"], "params": {}},
|
| 222 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "inactive"}},
|
| 223 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": 0}},
|
| 224 |
-
],
|
| 225 |
-
"covered_concepts": ["pop", "lifo_order"],
|
| 226 |
-
"intent": "pop-B",
|
| 227 |
-
},
|
| 228 |
-
{
|
| 229 |
-
"step_type": "complete",
|
| 230 |
-
"narration": "Push C on top — the stack now holds [A, C] with C as the new top, confirming LIFO order throughout.",
|
| 231 |
-
"ops": [
|
| 232 |
-
{"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "stk"}},
|
| 233 |
-
{"op": "push_to", "target_ids": ["stack", "C"], "params": {}},
|
| 234 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": 1}},
|
| 235 |
-
],
|
| 236 |
-
"covered_concepts": [],
|
| 237 |
-
"intent": "push-C-complete",
|
| 238 |
-
},
|
| 239 |
-
]
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def _easy_2_good() -> List[Dict[str, Any]]:
|
| 243 |
-
return [
|
| 244 |
-
{
|
| 245 |
-
"step_type": "advance",
|
| 246 |
-
"narration": "Creating a stack region with a container and top pointer.",
|
| 247 |
-
"ops": [
|
| 248 |
-
{"op": "add_region", "target_ids": ["stk"], "params": {"style": "stack", "title": "Stack"}},
|
| 249 |
-
{"op": "add_container", "target_ids": ["stack"], "params": {"region": "stk", "ordered": False}},
|
| 250 |
-
{"op": "add_pointer", "target_ids": ["top"], "params": {"region": "stk", "index": -1}},
|
| 251 |
-
],
|
| 252 |
-
"covered_concepts": ["top_pointer"],
|
| 253 |
-
"intent": "setup",
|
| 254 |
-
},
|
| 255 |
-
{
|
| 256 |
-
"step_type": "advance",
|
| 257 |
-
"narration": "Push A and B onto the stack, B is now on top.",
|
| 258 |
-
"ops": [
|
| 259 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "stk"}},
|
| 260 |
-
{"op": "push_to", "target_ids": ["stack", "A"], "params": {}},
|
| 261 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "stk"}},
|
| 262 |
-
{"op": "push_to", "target_ids": ["stack", "B"], "params": {}},
|
| 263 |
-
],
|
| 264 |
-
"covered_concepts": ["push"],
|
| 265 |
-
"intent": "push-AB",
|
| 266 |
-
},
|
| 267 |
-
{
|
| 268 |
-
"step_type": "complete",
|
| 269 |
-
"narration": "Pop B, then push C. Stack ends as [A, C] demonstrating LIFO.",
|
| 270 |
-
"ops": [
|
| 271 |
-
{"op": "pop_from", "target_ids": ["stack"], "params": {}},
|
| 272 |
-
{"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "stk"}},
|
| 273 |
-
{"op": "push_to", "target_ids": ["stack", "C"], "params": {}},
|
| 274 |
-
],
|
| 275 |
-
"covered_concepts": ["pop", "lifo_order"],
|
| 276 |
-
"intent": "pop-push-done",
|
| 277 |
-
},
|
| 278 |
-
]
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
def _easy_2_mediocre() -> List[Dict[str, Any]]:
|
| 282 |
-
return [
|
| 283 |
-
{
|
| 284 |
-
"step_type": "advance",
|
| 285 |
-
"narration": "Performing stack operations.",
|
| 286 |
-
"ops": [
|
| 287 |
-
{"op": "add_region", "target_ids": ["stk"], "params": {"style": "stack", "title": "Stack"}},
|
| 288 |
-
{"op": "add_container", "target_ids": ["stack"], "params": {"region": "stk"}},
|
| 289 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "stk"}},
|
| 290 |
-
{"op": "push_to", "target_ids": ["stack", "A"], "params": {}},
|
| 291 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "stk"}},
|
| 292 |
-
],
|
| 293 |
-
"covered_concepts": ["top_pointer", "push", "pop", "lifo_order"],
|
| 294 |
-
"intent": "",
|
| 295 |
-
},
|
| 296 |
-
{
|
| 297 |
-
"step_type": "complete",
|
| 298 |
-
"narration": "Stack is ready.",
|
| 299 |
-
"ops": [
|
| 300 |
-
{"op": "push_to", "target_ids": ["stack", "B"], "params": {}},
|
| 301 |
-
],
|
| 302 |
-
"covered_concepts": [],
|
| 303 |
-
"intent": "",
|
| 304 |
-
},
|
| 305 |
-
]
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
def _easy_2_bad() -> List[Dict[str, Any]]:
|
| 309 |
-
return [
|
| 310 |
-
{
|
| 311 |
-
"step_type": "complete",
|
| 312 |
-
"narration": "Stack operations.",
|
| 313 |
-
"ops": [],
|
| 314 |
-
"covered_concepts": ["top_pointer", "push", "pop", "lifo_order"],
|
| 315 |
-
"intent": "",
|
| 316 |
-
},
|
| 317 |
-
]
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
# ---------------------------------------------------------------------------
|
| 321 |
-
# medium_1: bfs_graph, graph A->B,C; B->D; C->D,E, source=A
|
| 322 |
-
# ---------------------------------------------------------------------------
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
def _medium_1_excellent() -> List[Dict[str, Any]]:
|
| 326 |
-
return [
|
| 327 |
-
{
|
| 328 |
-
"step_type": "advance",
|
| 329 |
-
"narration": "We draw the directed graph from the input — five nodes A through E with edges showing adjacency.",
|
| 330 |
-
"ops": [
|
| 331 |
-
{"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
|
| 332 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
|
| 333 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
|
| 334 |
-
{"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
|
| 335 |
-
{"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
|
| 336 |
-
{"op": "add_node", "target_ids": ["E"], "params": {"value": "E", "region": "graph"}},
|
| 337 |
-
{"op": "add_edge", "target_ids": ["A", "B"], "params": {}},
|
| 338 |
-
{"op": "add_edge", "target_ids": ["A", "C"], "params": {}},
|
| 339 |
-
{"op": "add_edge", "target_ids": ["B", "D"], "params": {}},
|
| 340 |
-
{"op": "add_edge", "target_ids": ["C", "D"], "params": {}},
|
| 341 |
-
{"op": "add_edge", "target_ids": ["C", "E"], "params": {}},
|
| 342 |
-
],
|
| 343 |
-
"covered_concepts": [],
|
| 344 |
-
"intent": "draw-graph",
|
| 345 |
-
},
|
| 346 |
-
{
|
| 347 |
-
"step_type": "advance",
|
| 348 |
-
"narration": "We initialize BFS by creating a queue and seeding it with source A — BFS always begins from a single starting node.",
|
| 349 |
-
"ops": [
|
| 350 |
-
{"op": "add_region", "target_ids": ["bfs_region"], "params": {"style": "queue", "title": "BFS Queue"}},
|
| 351 |
-
{"op": "add_container", "target_ids": ["q"], "params": {"region": "bfs_region", "ordered": True}},
|
| 352 |
-
{"op": "push_to", "target_ids": ["q", "A"], "params": {}},
|
| 353 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "frontier"}},
|
| 354 |
-
],
|
| 355 |
-
"covered_concepts": ["queue", "frontier"],
|
| 356 |
-
"intent": "init-bfs",
|
| 357 |
-
},
|
| 358 |
-
{
|
| 359 |
-
"step_type": "advance",
|
| 360 |
-
"narration": "Dequeue A from the queue and mark it visited — we then push A's unvisited neighbors B and C as frontier nodes.",
|
| 361 |
-
"ops": [
|
| 362 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 363 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
|
| 364 |
-
{"op": "push_to", "target_ids": ["q", "B"], "params": {}},
|
| 365 |
-
{"op": "push_to", "target_ids": ["q", "C"], "params": {}},
|
| 366 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "frontier"}},
|
| 367 |
-
{"op": "set_role", "target_ids": ["C"], "params": {"role": "frontier"}},
|
| 368 |
-
],
|
| 369 |
-
"covered_concepts": ["visited_set", "dequeue"],
|
| 370 |
-
"intent": "visit-A",
|
| 371 |
-
},
|
| 372 |
-
{
|
| 373 |
-
"step_type": "advance",
|
| 374 |
-
"narration": "Dequeue B and visit it — B's neighbor D joins the frontier, demonstrating BFS level-by-level order.",
|
| 375 |
-
"ops": [
|
| 376 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 377 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
|
| 378 |
-
{"op": "push_to", "target_ids": ["q", "D"], "params": {}},
|
| 379 |
-
{"op": "set_role", "target_ids": ["D"], "params": {"role": "frontier"}},
|
| 380 |
-
],
|
| 381 |
-
"covered_concepts": ["level_order"],
|
| 382 |
-
"intent": "visit-B",
|
| 383 |
-
},
|
| 384 |
-
{
|
| 385 |
-
"step_type": "advance",
|
| 386 |
-
"narration": "Dequeue C and visit it — C connects to D and E, but D is already in the queue so only E is added.",
|
| 387 |
-
"ops": [
|
| 388 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 389 |
-
{"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
|
| 390 |
-
{"op": "push_to", "target_ids": ["q", "E"], "params": {}},
|
| 391 |
-
{"op": "set_role", "target_ids": ["E"], "params": {"role": "frontier"}},
|
| 392 |
-
],
|
| 393 |
-
"covered_concepts": [],
|
| 394 |
-
"intent": "visit-C",
|
| 395 |
-
},
|
| 396 |
-
{
|
| 397 |
-
"step_type": "complete",
|
| 398 |
-
"narration": "All nodes are now visited in BFS order A,B,C,D,E — each level was fully processed before moving deeper.",
|
| 399 |
-
"ops": [
|
| 400 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 401 |
-
{"op": "set_role", "target_ids": ["D"], "params": {"role": "visited"}},
|
| 402 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 403 |
-
{"op": "set_role", "target_ids": ["E"], "params": {"role": "visited"}},
|
| 404 |
-
],
|
| 405 |
-
"covered_concepts": [],
|
| 406 |
-
"intent": "complete",
|
| 407 |
-
},
|
| 408 |
-
]
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
def _medium_1_good() -> List[Dict[str, Any]]:
|
| 412 |
-
return [
|
| 413 |
-
{
|
| 414 |
-
"step_type": "advance",
|
| 415 |
-
"narration": "Drawing the graph: nodes A through E with directed edges from the input data.",
|
| 416 |
-
"ops": [
|
| 417 |
-
{"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
|
| 418 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
|
| 419 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
|
| 420 |
-
{"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
|
| 421 |
-
{"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
|
| 422 |
-
{"op": "add_node", "target_ids": ["E"], "params": {"value": "E", "region": "graph"}},
|
| 423 |
-
{"op": "add_edge", "target_ids": ["A", "B"], "params": {}},
|
| 424 |
-
{"op": "add_edge", "target_ids": ["A", "C"], "params": {}},
|
| 425 |
-
{"op": "add_edge", "target_ids": ["B", "D"], "params": {}},
|
| 426 |
-
{"op": "add_edge", "target_ids": ["C", "D"], "params": {}},
|
| 427 |
-
{"op": "add_edge", "target_ids": ["C", "E"], "params": {}},
|
| 428 |
-
],
|
| 429 |
-
"covered_concepts": [],
|
| 430 |
-
"intent": "draw-graph",
|
| 431 |
-
},
|
| 432 |
-
{
|
| 433 |
-
"step_type": "advance",
|
| 434 |
-
"narration": "Create a BFS queue and enqueue source node A, marking it as frontier.",
|
| 435 |
-
"ops": [
|
| 436 |
-
{"op": "add_region", "target_ids": ["bfs_region"], "params": {"style": "queue", "title": "BFS Queue"}},
|
| 437 |
-
{"op": "add_container", "target_ids": ["q"], "params": {"region": "bfs_region", "ordered": True}},
|
| 438 |
-
{"op": "push_to", "target_ids": ["q", "A"], "params": {}},
|
| 439 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "frontier"}},
|
| 440 |
-
],
|
| 441 |
-
"covered_concepts": ["queue", "frontier"],
|
| 442 |
-
"intent": "init",
|
| 443 |
-
},
|
| 444 |
-
{
|
| 445 |
-
"step_type": "advance",
|
| 446 |
-
"narration": "Dequeue A, visit it, and enqueue neighbors B and C.",
|
| 447 |
-
"ops": [
|
| 448 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 449 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
|
| 450 |
-
{"op": "push_to", "target_ids": ["q", "B"], "params": {}},
|
| 451 |
-
{"op": "push_to", "target_ids": ["q", "C"], "params": {}},
|
| 452 |
-
],
|
| 453 |
-
"covered_concepts": ["visited_set", "dequeue"],
|
| 454 |
-
"intent": "visit-A",
|
| 455 |
-
},
|
| 456 |
-
{
|
| 457 |
-
"step_type": "advance",
|
| 458 |
-
"narration": "Process B then C, adding D and E to the queue as we go.",
|
| 459 |
-
"ops": [
|
| 460 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 461 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
|
| 462 |
-
{"op": "push_to", "target_ids": ["q", "D"], "params": {}},
|
| 463 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 464 |
-
{"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
|
| 465 |
-
{"op": "push_to", "target_ids": ["q", "E"], "params": {}},
|
| 466 |
-
],
|
| 467 |
-
"covered_concepts": ["level_order"],
|
| 468 |
-
"intent": "visit-BC",
|
| 469 |
-
},
|
| 470 |
-
{
|
| 471 |
-
"step_type": "complete",
|
| 472 |
-
"narration": "D and E are dequeued and visited, finishing BFS traversal.",
|
| 473 |
-
"ops": [
|
| 474 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 475 |
-
{"op": "set_role", "target_ids": ["D"], "params": {"role": "visited"}},
|
| 476 |
-
{"op": "pop_from", "target_ids": ["q"], "params": {}},
|
| 477 |
-
{"op": "set_role", "target_ids": ["E"], "params": {"role": "visited"}},
|
| 478 |
-
],
|
| 479 |
-
"covered_concepts": [],
|
| 480 |
-
"intent": "done",
|
| 481 |
-
},
|
| 482 |
-
]
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
def _medium_1_mediocre() -> List[Dict[str, Any]]:
|
| 486 |
-
return [
|
| 487 |
-
{
|
| 488 |
-
"step_type": "advance",
|
| 489 |
-
"narration": "Drawing the graph nodes.",
|
| 490 |
-
"ops": [
|
| 491 |
-
{"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
|
| 492 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
|
| 493 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
|
| 494 |
-
{"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
|
| 495 |
-
{"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
|
| 496 |
-
{"op": "add_node", "target_ids": ["E"], "params": {"value": "E", "region": "graph"}},
|
| 497 |
-
{"op": "add_edge", "target_ids": ["A", "B"], "params": {}},
|
| 498 |
-
{"op": "add_edge", "target_ids": ["A", "C"], "params": {}},
|
| 499 |
-
{"op": "add_edge", "target_ids": ["B", "D"], "params": {}},
|
| 500 |
-
{"op": "add_edge", "target_ids": ["C", "D"], "params": {}},
|
| 501 |
-
{"op": "add_edge", "target_ids": ["C", "E"], "params": {}},
|
| 502 |
-
],
|
| 503 |
-
"covered_concepts": [],
|
| 504 |
-
"intent": "",
|
| 505 |
-
},
|
| 506 |
-
{
|
| 507 |
-
"step_type": "advance",
|
| 508 |
-
"narration": "Setting up BFS.",
|
| 509 |
-
"ops": [
|
| 510 |
-
{"op": "add_region", "target_ids": ["bfs_region"], "params": {"style": "queue", "title": "BFS"}},
|
| 511 |
-
{"op": "add_container", "target_ids": ["q"], "params": {"region": "bfs_region", "ordered": True}},
|
| 512 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
|
| 513 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
|
| 514 |
-
{"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
|
| 515 |
-
],
|
| 516 |
-
"covered_concepts": ["queue", "visited_set", "frontier", "level_order", "dequeue"],
|
| 517 |
-
"intent": "",
|
| 518 |
-
},
|
| 519 |
-
{
|
| 520 |
-
"step_type": "complete",
|
| 521 |
-
"narration": "BFS finished.",
|
| 522 |
-
"ops": [
|
| 523 |
-
{"op": "set_role", "target_ids": ["D"], "params": {"role": "visited"}},
|
| 524 |
-
],
|
| 525 |
-
"covered_concepts": [],
|
| 526 |
-
"intent": "",
|
| 527 |
-
},
|
| 528 |
-
]
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
def _medium_1_bad() -> List[Dict[str, Any]]:
|
| 532 |
-
return [
|
| 533 |
-
{
|
| 534 |
-
"step_type": "advance",
|
| 535 |
-
"narration": "Processing the graph.",
|
| 536 |
-
"ops": [
|
| 537 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A"}},
|
| 538 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B"}},
|
| 539 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
|
| 540 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
|
| 541 |
-
],
|
| 542 |
-
"covered_concepts": ["queue", "visited_set", "frontier", "level_order", "dequeue"],
|
| 543 |
-
"intent": "",
|
| 544 |
-
},
|
| 545 |
-
{
|
| 546 |
-
"step_type": "complete",
|
| 547 |
-
"narration": "Done.",
|
| 548 |
-
"ops": [],
|
| 549 |
-
"covered_concepts": [],
|
| 550 |
-
"intent": "",
|
| 551 |
-
},
|
| 552 |
-
]
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
# ---------------------------------------------------------------------------
|
| 556 |
-
# hard_1: dijkstra_step,
|
| 557 |
-
# graph A->B(1),A->C(4),B->C(2),B->D(5),C->D(1), source=A
|
| 558 |
-
# ---------------------------------------------------------------------------
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
def _hard_1_excellent() -> List[Dict[str, Any]]:
|
| 562 |
-
return [
|
| 563 |
-
{
|
| 564 |
-
"step_type": "advance",
|
| 565 |
-
"narration": "We draw the weighted directed graph — four nodes A through D with edge weights from the input data.",
|
| 566 |
-
"ops": [
|
| 567 |
-
{"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Weighted Graph"}},
|
| 568 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
|
| 569 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
|
| 570 |
-
{"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
|
| 571 |
-
{"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
|
| 572 |
-
{"op": "add_edge", "target_ids": ["A", "B"], "params": {"label": "1"}},
|
| 573 |
-
{"op": "add_edge", "target_ids": ["A", "C"], "params": {"label": "4"}},
|
| 574 |
-
{"op": "add_edge", "target_ids": ["B", "C"], "params": {"label": "2"}},
|
| 575 |
-
{"op": "add_edge", "target_ids": ["B", "D"], "params": {"label": "5"}},
|
| 576 |
-
{"op": "add_edge", "target_ids": ["C", "D"], "params": {"label": "1"}},
|
| 577 |
-
],
|
| 578 |
-
"covered_concepts": [],
|
| 579 |
-
"intent": "draw-graph",
|
| 580 |
-
},
|
| 581 |
-
{
|
| 582 |
-
"step_type": "advance",
|
| 583 |
-
"narration": "Initialize Dijkstra by creating a priority queue and distance table — source A gets distance 0, all others start at infinity.",
|
| 584 |
-
"ops": [
|
| 585 |
-
{"op": "add_region", "target_ids": ["pq_region"], "params": {"style": "queue", "title": "Priority Queue"}},
|
| 586 |
-
{"op": "add_container", "target_ids": ["pq"], "params": {"region": "pq_region", "ordered": True}},
|
| 587 |
-
{"op": "push_to", "target_ids": ["pq", "A"], "params": {}},
|
| 588 |
-
{"op": "annotate", "target_ids": ["A"], "params": {"text": "d=0"}},
|
| 589 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "current"}},
|
| 590 |
-
],
|
| 591 |
-
"covered_concepts": ["priority_queue", "distance_table"],
|
| 592 |
-
"intent": "init",
|
| 593 |
-
},
|
| 594 |
-
{
|
| 595 |
-
"step_type": "advance",
|
| 596 |
-
"narration": "Extract A (d=0) from the priority queue — Dijkstra guarantees this distance is final. Relax edges to B and C: d[B]=1, d[C]=4.",
|
| 597 |
-
"ops": [
|
| 598 |
-
{"op": "pop_from", "target_ids": ["pq"], "params": {}},
|
| 599 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
|
| 600 |
-
{"op": "set_value", "target_ids": ["B"], "params": {"value": 1}},
|
| 601 |
-
{"op": "annotate", "target_ids": ["B"], "params": {"text": "d=1"}},
|
| 602 |
-
{"op": "push_to", "target_ids": ["pq", "B"], "params": {}},
|
| 603 |
-
{"op": "set_value", "target_ids": ["C"], "params": {"value": 4}},
|
| 604 |
-
],
|
| 605 |
-
"covered_concepts": ["relaxation", "shortest_path_invariant"],
|
| 606 |
-
"intent": "visit-A",
|
| 607 |
-
},
|
| 608 |
-
{
|
| 609 |
-
"step_type": "advance",
|
| 610 |
-
"narration": "Extract B (d=1) — its shortest distance is now permanent. Relaxing B's edges: d[C] improves from 4 to 3 via B, and d[D]=6.",
|
| 611 |
-
"ops": [
|
| 612 |
-
{"op": "pop_from", "target_ids": ["pq"], "params": {}},
|
| 613 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
|
| 614 |
-
{"op": "annotate", "target_ids": ["B"], "params": {"text": "d=1 final"}},
|
| 615 |
-
{"op": "set_value", "target_ids": ["C"], "params": {"value": 3}},
|
| 616 |
-
{"op": "annotate", "target_ids": ["C"], "params": {"text": "d=3"}},
|
| 617 |
-
{"op": "push_to", "target_ids": ["pq", "C"], "params": {}},
|
| 618 |
-
],
|
| 619 |
-
"covered_concepts": ["visited_set"],
|
| 620 |
-
"intent": "visit-B",
|
| 621 |
-
},
|
| 622 |
-
{
|
| 623 |
-
"step_type": "complete",
|
| 624 |
-
"narration": "Extract C (d=3), relax to D: d[D] improves to 4. Then D is extracted with final distance 4. All shortest paths found.",
|
| 625 |
-
"ops": [
|
| 626 |
-
{"op": "pop_from", "target_ids": ["pq"], "params": {}},
|
| 627 |
-
{"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
|
| 628 |
-
{"op": "set_value", "target_ids": ["D"], "params": {"value": 4}},
|
| 629 |
-
{"op": "annotate", "target_ids": ["D"], "params": {"text": "d=4"}},
|
| 630 |
-
{"op": "set_role", "target_ids": ["D"], "params": {"role": "done"}},
|
| 631 |
-
],
|
| 632 |
-
"covered_concepts": [],
|
| 633 |
-
"intent": "complete",
|
| 634 |
-
},
|
| 635 |
-
]
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
def _hard_1_good() -> List[Dict[str, Any]]:
|
| 639 |
-
return [
|
| 640 |
-
{
|
| 641 |
-
"step_type": "advance",
|
| 642 |
-
"narration": "Drawing the weighted graph with nodes A, B, C, D and their edge weights.",
|
| 643 |
-
"ops": [
|
| 644 |
-
{"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
|
| 645 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
|
| 646 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
|
| 647 |
-
{"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
|
| 648 |
-
{"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
|
| 649 |
-
{"op": "add_edge", "target_ids": ["A", "B"], "params": {"label": "1"}},
|
| 650 |
-
{"op": "add_edge", "target_ids": ["A", "C"], "params": {"label": "4"}},
|
| 651 |
-
{"op": "add_edge", "target_ids": ["B", "C"], "params": {"label": "2"}},
|
| 652 |
-
{"op": "add_edge", "target_ids": ["B", "D"], "params": {"label": "5"}},
|
| 653 |
-
{"op": "add_edge", "target_ids": ["C", "D"], "params": {"label": "1"}},
|
| 654 |
-
],
|
| 655 |
-
"covered_concepts": [],
|
| 656 |
-
"intent": "draw-graph",
|
| 657 |
-
},
|
| 658 |
-
{
|
| 659 |
-
"step_type": "advance",
|
| 660 |
-
"narration": "Set up a priority queue for Dijkstra and initialize source A with distance 0.",
|
| 661 |
-
"ops": [
|
| 662 |
-
{"op": "add_region", "target_ids": ["pq_region"], "params": {"style": "queue", "title": "PQ"}},
|
| 663 |
-
{"op": "add_container", "target_ids": ["pq"], "params": {"region": "pq_region", "ordered": True}},
|
| 664 |
-
{"op": "push_to", "target_ids": ["pq", "A"], "params": {}},
|
| 665 |
-
{"op": "annotate", "target_ids": ["A"], "params": {"text": "d=0"}},
|
| 666 |
-
],
|
| 667 |
-
"covered_concepts": ["priority_queue", "distance_table"],
|
| 668 |
-
"intent": "init",
|
| 669 |
-
},
|
| 670 |
-
{
|
| 671 |
-
"step_type": "advance",
|
| 672 |
-
"narration": "Extract A, relax edges to B (d=1) and C (d=4).",
|
| 673 |
-
"ops": [
|
| 674 |
-
{"op": "pop_from", "target_ids": ["pq"], "params": {}},
|
| 675 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
|
| 676 |
-
{"op": "set_value", "target_ids": ["B"], "params": {"value": 1}},
|
| 677 |
-
{"op": "push_to", "target_ids": ["pq", "B"], "params": {}},
|
| 678 |
-
{"op": "set_value", "target_ids": ["C"], "params": {"value": 4}},
|
| 679 |
-
],
|
| 680 |
-
"covered_concepts": ["relaxation"],
|
| 681 |
-
"intent": "visit-A",
|
| 682 |
-
},
|
| 683 |
-
{
|
| 684 |
-
"step_type": "advance",
|
| 685 |
-
"narration": "Extract B, update C's distance to 3 via B. Push C into the queue.",
|
| 686 |
-
"ops": [
|
| 687 |
-
{"op": "pop_from", "target_ids": ["pq"], "params": {}},
|
| 688 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
|
| 689 |
-
{"op": "set_value", "target_ids": ["C"], "params": {"value": 3}},
|
| 690 |
-
{"op": "push_to", "target_ids": ["pq", "C"], "params": {}},
|
| 691 |
-
],
|
| 692 |
-
"covered_concepts": ["visited_set", "shortest_path_invariant"],
|
| 693 |
-
"intent": "visit-B",
|
| 694 |
-
},
|
| 695 |
-
{
|
| 696 |
-
"step_type": "complete",
|
| 697 |
-
"narration": "Extract C, relax D to distance 4. All shortest paths computed.",
|
| 698 |
-
"ops": [
|
| 699 |
-
{"op": "pop_from", "target_ids": ["pq"], "params": {}},
|
| 700 |
-
{"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
|
| 701 |
-
{"op": "set_value", "target_ids": ["D"], "params": {"value": 4}},
|
| 702 |
-
{"op": "set_role", "target_ids": ["D"], "params": {"role": "done"}},
|
| 703 |
-
],
|
| 704 |
-
"covered_concepts": [],
|
| 705 |
-
"intent": "complete",
|
| 706 |
-
},
|
| 707 |
-
]
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
def _hard_1_mediocre() -> List[Dict[str, Any]]:
|
| 711 |
-
return [
|
| 712 |
-
{
|
| 713 |
-
"step_type": "advance",
|
| 714 |
-
"narration": "Drawing the graph for Dijkstra.",
|
| 715 |
-
"ops": [
|
| 716 |
-
{"op": "add_region", "target_ids": ["graph"], "params": {"style": "graph", "title": "Graph"}},
|
| 717 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A", "region": "graph"}},
|
| 718 |
-
{"op": "add_node", "target_ids": ["B"], "params": {"value": "B", "region": "graph"}},
|
| 719 |
-
{"op": "add_node", "target_ids": ["C"], "params": {"value": "C", "region": "graph"}},
|
| 720 |
-
{"op": "add_node", "target_ids": ["D"], "params": {"value": "D", "region": "graph"}},
|
| 721 |
-
{"op": "add_edge", "target_ids": ["A", "B"], "params": {"label": "1"}},
|
| 722 |
-
{"op": "add_edge", "target_ids": ["A", "C"], "params": {"label": "4"}},
|
| 723 |
-
{"op": "add_edge", "target_ids": ["B", "C"], "params": {"label": "2"}},
|
| 724 |
-
{"op": "add_edge", "target_ids": ["B", "D"], "params": {"label": "5"}},
|
| 725 |
-
{"op": "add_edge", "target_ids": ["C", "D"], "params": {"label": "1"}},
|
| 726 |
-
],
|
| 727 |
-
"covered_concepts": [],
|
| 728 |
-
"intent": "",
|
| 729 |
-
},
|
| 730 |
-
{
|
| 731 |
-
"step_type": "advance",
|
| 732 |
-
"narration": "Running Dijkstra on the graph.",
|
| 733 |
-
"ops": [
|
| 734 |
-
{"op": "add_region", "target_ids": ["pq_region"], "params": {"style": "queue", "title": "PQ"}},
|
| 735 |
-
{"op": "add_container", "target_ids": ["pq"], "params": {"region": "pq_region", "ordered": True}},
|
| 736 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "visited"}},
|
| 737 |
-
{"op": "set_role", "target_ids": ["B"], "params": {"role": "visited"}},
|
| 738 |
-
{"op": "set_value", "target_ids": ["D"], "params": {"value": 4}},
|
| 739 |
-
],
|
| 740 |
-
"covered_concepts": ["priority_queue", "distance_table", "relaxation", "visited_set", "shortest_path_invariant"],
|
| 741 |
-
"intent": "",
|
| 742 |
-
},
|
| 743 |
-
{
|
| 744 |
-
"step_type": "complete",
|
| 745 |
-
"narration": "Shortest paths found.",
|
| 746 |
-
"ops": [
|
| 747 |
-
{"op": "set_role", "target_ids": ["C"], "params": {"role": "visited"}},
|
| 748 |
-
],
|
| 749 |
-
"covered_concepts": [],
|
| 750 |
-
"intent": "",
|
| 751 |
-
},
|
| 752 |
-
]
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
def _hard_1_bad() -> List[Dict[str, Any]]:
|
| 756 |
-
return [
|
| 757 |
-
{
|
| 758 |
-
"step_type": "advance",
|
| 759 |
-
"narration": "Starting Dijkstra.",
|
| 760 |
-
"ops": [
|
| 761 |
-
{"op": "add_node", "target_ids": ["A"], "params": {"value": "A"}},
|
| 762 |
-
{"op": "set_role", "target_ids": ["A"], "params": {"role": "done"}},
|
| 763 |
-
],
|
| 764 |
-
"covered_concepts": ["priority_queue", "distance_table", "relaxation", "visited_set", "shortest_path_invariant"],
|
| 765 |
-
"intent": "",
|
| 766 |
-
},
|
| 767 |
-
{
|
| 768 |
-
"step_type": "complete",
|
| 769 |
-
"narration": "Dijkstra done.",
|
| 770 |
-
"ops": [],
|
| 771 |
-
"covered_concepts": [],
|
| 772 |
-
"intent": "",
|
| 773 |
-
},
|
| 774 |
-
]
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
# ---------------------------------------------------------------------------
|
| 778 |
-
# hard_2: bst_insert, incremental, keys=[5,3,7,2,4,8]
|
| 779 |
-
# ---------------------------------------------------------------------------
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
def _hard_2_excellent() -> List[Dict[str, Any]]:
|
| 783 |
-
return [
|
| 784 |
-
{
|
| 785 |
-
"step_type": "advance",
|
| 786 |
-
"narration": "Create a tree region and insert 5 as the root node — the first key always becomes the root of the BST.",
|
| 787 |
-
"ops": [
|
| 788 |
-
{"op": "add_region", "target_ids": ["tree"], "params": {"style": "tree", "title": "BST", "root": "n5"}},
|
| 789 |
-
{"op": "add_node", "target_ids": ["n5"], "params": {"value": 5, "region": "tree"}},
|
| 790 |
-
{"op": "set_role", "target_ids": ["n5"], "params": {"role": "root"}},
|
| 791 |
-
],
|
| 792 |
-
"covered_concepts": ["root_node"],
|
| 793 |
-
"intent": "insert-root",
|
| 794 |
-
},
|
| 795 |
-
{
|
| 796 |
-
"step_type": "advance",
|
| 797 |
-
"narration": "Insert 3: comparing with root 5, 3 < 5 so it goes left — this maintains the BST invariant where left children are smaller.",
|
| 798 |
-
"ops": [
|
| 799 |
-
{"op": "add_node", "target_ids": ["n3"], "params": {"value": 3, "region": "tree"}},
|
| 800 |
-
{"op": "add_edge", "target_ids": ["n5", "n3"], "params": {"label": "L"}},
|
| 801 |
-
{"op": "set_role", "target_ids": ["n5"], "params": {"role": "comparing"}},
|
| 802 |
-
],
|
| 803 |
-
"covered_concepts": ["bst_invariant", "left_subtree"],
|
| 804 |
-
"intent": "insert-3",
|
| 805 |
-
},
|
| 806 |
-
{
|
| 807 |
-
"step_type": "advance",
|
| 808 |
-
"narration": "Insert 7: comparing with root 5, 7 > 5 so it goes right — the right subtree holds all keys greater than the parent.",
|
| 809 |
-
"ops": [
|
| 810 |
-
{"op": "add_node", "target_ids": ["n7"], "params": {"value": 7, "region": "tree"}},
|
| 811 |
-
{"op": "add_edge", "target_ids": ["n5", "n7"], "params": {"label": "R"}},
|
| 812 |
-
{"op": "set_role", "target_ids": ["n5"], "params": {"role": "root"}},
|
| 813 |
-
],
|
| 814 |
-
"covered_concepts": ["right_subtree"],
|
| 815 |
-
"intent": "insert-7",
|
| 816 |
-
},
|
| 817 |
-
{
|
| 818 |
-
"step_type": "advance",
|
| 819 |
-
"narration": "Insert 2 and 4: recursively comparing, 2 < 3 goes left of n3, and 4 > 3 goes right of n3.",
|
| 820 |
-
"ops": [
|
| 821 |
-
{"op": "add_node", "target_ids": ["n2"], "params": {"value": 2, "region": "tree"}},
|
| 822 |
-
{"op": "add_edge", "target_ids": ["n3", "n2"], "params": {"label": "L"}},
|
| 823 |
-
{"op": "add_node", "target_ids": ["n4"], "params": {"value": 4, "region": "tree"}},
|
| 824 |
-
{"op": "add_edge", "target_ids": ["n3", "n4"], "params": {"label": "R"}},
|
| 825 |
-
],
|
| 826 |
-
"covered_concepts": ["recursive_insert"],
|
| 827 |
-
"intent": "insert-2-4",
|
| 828 |
-
},
|
| 829 |
-
{
|
| 830 |
-
"step_type": "complete",
|
| 831 |
-
"narration": "Insert 8: 8 > 5 go right, 8 > 7 go right — each insertion recursively finds the correct leaf position.",
|
| 832 |
-
"ops": [
|
| 833 |
-
{"op": "add_node", "target_ids": ["n8"], "params": {"value": 8, "region": "tree"}},
|
| 834 |
-
{"op": "add_edge", "target_ids": ["n7", "n8"], "params": {"label": "R"}},
|
| 835 |
-
],
|
| 836 |
-
"covered_concepts": [],
|
| 837 |
-
"intent": "insert-8-complete",
|
| 838 |
-
},
|
| 839 |
-
]
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
def _hard_2_good() -> List[Dict[str, Any]]:
|
| 843 |
-
return [
|
| 844 |
-
{
|
| 845 |
-
"step_type": "advance",
|
| 846 |
-
"narration": "Create a BST tree region and insert 5 as root.",
|
| 847 |
-
"ops": [
|
| 848 |
-
{"op": "add_region", "target_ids": ["tree"], "params": {"style": "tree", "title": "BST", "root": "n5"}},
|
| 849 |
-
{"op": "add_node", "target_ids": ["n5"], "params": {"value": 5, "region": "tree"}},
|
| 850 |
-
{"op": "set_role", "target_ids": ["n5"], "params": {"role": "root"}},
|
| 851 |
-
],
|
| 852 |
-
"covered_concepts": ["root_node"],
|
| 853 |
-
"intent": "root",
|
| 854 |
-
},
|
| 855 |
-
{
|
| 856 |
-
"step_type": "advance",
|
| 857 |
-
"narration": "Insert 3 left of 5 and 7 right of 5, maintaining BST invariant.",
|
| 858 |
-
"ops": [
|
| 859 |
-
{"op": "add_node", "target_ids": ["n3"], "params": {"value": 3, "region": "tree"}},
|
| 860 |
-
{"op": "add_edge", "target_ids": ["n5", "n3"], "params": {"label": "L"}},
|
| 861 |
-
{"op": "add_node", "target_ids": ["n7"], "params": {"value": 7, "region": "tree"}},
|
| 862 |
-
{"op": "add_edge", "target_ids": ["n5", "n7"], "params": {"label": "R"}},
|
| 863 |
-
],
|
| 864 |
-
"covered_concepts": ["bst_invariant", "left_subtree", "right_subtree"],
|
| 865 |
-
"intent": "insert-3-7",
|
| 866 |
-
},
|
| 867 |
-
{
|
| 868 |
-
"step_type": "advance",
|
| 869 |
-
"narration": "Recursively insert 2 left of 3 and 4 right of 3.",
|
| 870 |
-
"ops": [
|
| 871 |
-
{"op": "add_node", "target_ids": ["n2"], "params": {"value": 2, "region": "tree"}},
|
| 872 |
-
{"op": "add_edge", "target_ids": ["n3", "n2"], "params": {"label": "L"}},
|
| 873 |
-
{"op": "add_node", "target_ids": ["n4"], "params": {"value": 4, "region": "tree"}},
|
| 874 |
-
{"op": "add_edge", "target_ids": ["n3", "n4"], "params": {"label": "R"}},
|
| 875 |
-
],
|
| 876 |
-
"covered_concepts": ["recursive_insert"],
|
| 877 |
-
"intent": "insert-2-4",
|
| 878 |
-
},
|
| 879 |
-
{
|
| 880 |
-
"step_type": "complete",
|
| 881 |
-
"narration": "Insert 8 right of 7 to complete the BST.",
|
| 882 |
-
"ops": [
|
| 883 |
-
{"op": "add_node", "target_ids": ["n8"], "params": {"value": 8, "region": "tree"}},
|
| 884 |
-
{"op": "add_edge", "target_ids": ["n7", "n8"], "params": {"label": "R"}},
|
| 885 |
-
],
|
| 886 |
-
"covered_concepts": [],
|
| 887 |
-
"intent": "done",
|
| 888 |
-
},
|
| 889 |
-
]
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
def _hard_2_mediocre() -> List[Dict[str, Any]]:
|
| 893 |
-
return [
|
| 894 |
-
{
|
| 895 |
-
"step_type": "advance",
|
| 896 |
-
"narration": "Building the BST.",
|
| 897 |
-
"ops": [
|
| 898 |
-
{"op": "add_region", "target_ids": ["tree"], "params": {"style": "tree", "title": "BST"}},
|
| 899 |
-
{"op": "add_node", "target_ids": ["n5"], "params": {"value": 5, "region": "tree"}},
|
| 900 |
-
{"op": "add_node", "target_ids": ["n3"], "params": {"value": 3, "region": "tree"}},
|
| 901 |
-
{"op": "add_edge", "target_ids": ["n5", "n3"], "params": {}},
|
| 902 |
-
{"op": "add_node", "target_ids": ["n7"], "params": {"value": 7, "region": "tree"}},
|
| 903 |
-
{"op": "add_edge", "target_ids": ["n5", "n7"], "params": {}},
|
| 904 |
-
],
|
| 905 |
-
"covered_concepts": ["root_node", "bst_invariant", "left_subtree", "right_subtree", "recursive_insert"],
|
| 906 |
-
"intent": "",
|
| 907 |
-
},
|
| 908 |
-
{
|
| 909 |
-
"step_type": "complete",
|
| 910 |
-
"narration": "Tree built.",
|
| 911 |
-
"ops": [
|
| 912 |
-
{"op": "add_node", "target_ids": ["n2"], "params": {"value": 2, "region": "tree"}},
|
| 913 |
-
{"op": "add_edge", "target_ids": ["n3", "n2"], "params": {}},
|
| 914 |
-
],
|
| 915 |
-
"covered_concepts": [],
|
| 916 |
-
"intent": "",
|
| 917 |
-
},
|
| 918 |
-
]
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
def _hard_2_bad() -> List[Dict[str, Any]]:
|
| 922 |
-
return [
|
| 923 |
-
{
|
| 924 |
-
"step_type": "advance",
|
| 925 |
-
"narration": "Inserting keys.",
|
| 926 |
-
"ops": [
|
| 927 |
-
{"op": "add_node", "target_ids": ["n5"], "params": {"value": 5}},
|
| 928 |
-
],
|
| 929 |
-
"covered_concepts": ["root_node", "bst_invariant", "left_subtree", "right_subtree", "recursive_insert"],
|
| 930 |
-
"intent": "",
|
| 931 |
-
},
|
| 932 |
-
{
|
| 933 |
-
"step_type": "complete",
|
| 934 |
-
"narration": "BST done.",
|
| 935 |
-
"ops": [],
|
| 936 |
-
"covered_concepts": [],
|
| 937 |
-
"intent": "",
|
| 938 |
-
},
|
| 939 |
-
]
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
# ---------------------------------------------------------------------------
|
| 943 |
-
# scenario registry
|
| 944 |
-
# ---------------------------------------------------------------------------
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
def get_all_scripted_scenarios() -> Dict[str, Dict[str, List[Dict[str, Any]]]]:
|
| 948 |
-
return {
|
| 949 |
-
"easy_1": {
|
| 950 |
-
"excellent": _easy_1_excellent(),
|
| 951 |
-
"good": _easy_1_good(),
|
| 952 |
-
"mediocre": _easy_1_mediocre(),
|
| 953 |
-
"bad": _easy_1_bad(),
|
| 954 |
-
},
|
| 955 |
-
"easy_2": {
|
| 956 |
-
"excellent": _easy_2_excellent(),
|
| 957 |
-
"good": _easy_2_good(),
|
| 958 |
-
"mediocre": _easy_2_mediocre(),
|
| 959 |
-
"bad": _easy_2_bad(),
|
| 960 |
-
},
|
| 961 |
-
"medium_1": {
|
| 962 |
-
"excellent": _medium_1_excellent(),
|
| 963 |
-
"good": _medium_1_good(),
|
| 964 |
-
"mediocre": _medium_1_mediocre(),
|
| 965 |
-
"bad": _medium_1_bad(),
|
| 966 |
-
},
|
| 967 |
-
"hard_1": {
|
| 968 |
-
"excellent": _hard_1_excellent(),
|
| 969 |
-
"good": _hard_1_good(),
|
| 970 |
-
"mediocre": _hard_1_mediocre(),
|
| 971 |
-
"bad": _hard_1_bad(),
|
| 972 |
-
},
|
| 973 |
-
"hard_2": {
|
| 974 |
-
"excellent": _hard_2_excellent(),
|
| 975 |
-
"good": _hard_2_good(),
|
| 976 |
-
"mediocre": _hard_2_mediocre(),
|
| 977 |
-
"bad": _hard_2_bad(),
|
| 978 |
-
},
|
| 979 |
-
}
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
# ---------------------------------------------------------------------------
|
| 983 |
-
# rollout runner
|
| 984 |
-
# ---------------------------------------------------------------------------
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
def run_scripted_rollout(
|
| 988 |
-
env: VisualReasoningEnvironment,
|
| 989 |
-
scenario_id: str,
|
| 990 |
-
actions: List[Dict[str, Any]],
|
| 991 |
-
) -> List[Dict[str, Any]]:
|
| 992 |
-
obs = env.reset(scenario_id=scenario_id)
|
| 993 |
-
template = obs.task_name
|
| 994 |
-
steps: List[Dict[str, Any]] = []
|
| 995 |
-
prev_narration = ""
|
| 996 |
-
|
| 997 |
-
for i, action_dict in enumerate(actions):
|
| 998 |
-
action = VisualReasoningAction(**action_dict)
|
| 999 |
-
obs = env.step(action)
|
| 1000 |
-
|
| 1001 |
-
ops_parts: List[str] = []
|
| 1002 |
-
for op in action_dict.get("ops") or []:
|
| 1003 |
-
tids = ",".join(op.get("target_ids") or [])
|
| 1004 |
-
ops_parts.append(f"{op['op']}[{tids}]")
|
| 1005 |
-
ops_summary = ", ".join(ops_parts)
|
| 1006 |
-
|
| 1007 |
-
checklist = list(obs.concept_checklist)
|
| 1008 |
-
covered = list(obs.concept_coverage)
|
| 1009 |
-
|
| 1010 |
-
steps.append({
|
| 1011 |
-
"scenario_id": scenario_id,
|
| 1012 |
-
"template": template,
|
| 1013 |
-
"step_id": i + 1,
|
| 1014 |
-
"narration": action_dict.get("narration", ""),
|
| 1015 |
-
"ops_summary": ops_summary,
|
| 1016 |
-
"prev_narration": prev_narration,
|
| 1017 |
-
"covered_concepts": action_dict.get("covered_concepts", []),
|
| 1018 |
-
"step_progress": f"step {i + 1}, {len(covered)}/{len(checklist)} concepts",
|
| 1019 |
-
"score_breakdown": dict(obs.score_breakdown),
|
| 1020 |
-
"reward": obs.reward,
|
| 1021 |
-
"action": action_dict,
|
| 1022 |
-
})
|
| 1023 |
-
prev_narration = action_dict.get("narration", "")
|
| 1024 |
-
if obs.done:
|
| 1025 |
-
break
|
| 1026 |
-
|
| 1027 |
-
return steps
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
# ---------------------------------------------------------------------------
|
| 1031 |
-
# labeling prompt builder
|
| 1032 |
-
# ---------------------------------------------------------------------------
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
def build_labeling_prompt(step: Dict[str, Any]) -> str:
|
| 1036 |
-
return LABELING_PROMPT.format(
|
| 1037 |
-
template=step["template"],
|
| 1038 |
-
ops_summary=step["ops_summary"] or "(none)",
|
| 1039 |
-
prev_narration=step["prev_narration"] or "(none)",
|
| 1040 |
-
narration=step["narration"],
|
| 1041 |
-
covered_concepts=", ".join(step["covered_concepts"]) or "(none)",
|
| 1042 |
-
step_progress=step["step_progress"],
|
| 1043 |
-
)
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
# ---------------------------------------------------------------------------
|
| 1047 |
-
# main
|
| 1048 |
-
# ---------------------------------------------------------------------------
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
def main() -> None:
|
| 1052 |
-
parser = argparse.ArgumentParser(description="Generate rubric training data")
|
| 1053 |
-
parser.add_argument("--output-dir", default=str(Path(__file__).parent / "output"))
|
| 1054 |
-
args = parser.parse_args()
|
| 1055 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
| 1056 |
-
|
| 1057 |
-
env = VisualReasoningEnvironment()
|
| 1058 |
-
all_steps: List[Dict[str, Any]] = []
|
| 1059 |
-
|
| 1060 |
-
scenarios = get_all_scripted_scenarios()
|
| 1061 |
-
|
| 1062 |
-
for scenario_id, quality_map in scenarios.items():
|
| 1063 |
-
for quality, actions in quality_map.items():
|
| 1064 |
-
try:
|
| 1065 |
-
steps = run_scripted_rollout(env, scenario_id, actions)
|
| 1066 |
-
except Exception as exc:
|
| 1067 |
-
print(f"WARN: {scenario_id}/{quality} failed: {exc}")
|
| 1068 |
-
continue
|
| 1069 |
-
for s in steps:
|
| 1070 |
-
s["quality_level"] = quality
|
| 1071 |
-
all_steps.extend(steps)
|
| 1072 |
-
|
| 1073 |
-
rollout_path = os.path.join(args.output_dir, "rollout_data.jsonl")
|
| 1074 |
-
with open(rollout_path, "w", encoding="utf-8") as f:
|
| 1075 |
-
for step in all_steps:
|
| 1076 |
-
f.write(json.dumps(step, default=str) + "\n")
|
| 1077 |
-
|
| 1078 |
-
prompts_path = os.path.join(args.output_dir, "labeling_prompts.jsonl")
|
| 1079 |
-
with open(prompts_path, "w", encoding="utf-8") as f:
|
| 1080 |
-
for step in all_steps:
|
| 1081 |
-
record = {
|
| 1082 |
-
"scenario_id": step["scenario_id"],
|
| 1083 |
-
"quality_level": step["quality_level"],
|
| 1084 |
-
"step_id": step["step_id"],
|
| 1085 |
-
"prompt": build_labeling_prompt(step),
|
| 1086 |
-
}
|
| 1087 |
-
f.write(json.dumps(record) + "\n")
|
| 1088 |
-
|
| 1089 |
-
quality_stats: Dict[str, Dict[str, Any]] = {}
|
| 1090 |
-
for step in all_steps:
|
| 1091 |
-
ql = step["quality_level"]
|
| 1092 |
-
if ql not in quality_stats:
|
| 1093 |
-
quality_stats[ql] = {"count": 0, "total_reward": 0.0, "scores": []}
|
| 1094 |
-
quality_stats[ql]["count"] += 1
|
| 1095 |
-
quality_stats[ql]["total_reward"] += step["reward"]
|
| 1096 |
-
quality_stats[ql]["scores"].append(
|
| 1097 |
-
step["score_breakdown"].get("overall_score", 0.0)
|
| 1098 |
-
)
|
| 1099 |
-
|
| 1100 |
-
summary: Dict[str, Any] = {
|
| 1101 |
-
"total_steps": len(all_steps),
|
| 1102 |
-
"scenarios": len(scenarios),
|
| 1103 |
-
"quality_levels": {},
|
| 1104 |
-
}
|
| 1105 |
-
for ql, stats in quality_stats.items():
|
| 1106 |
-
scores = stats["scores"]
|
| 1107 |
-
summary["quality_levels"][ql] = {
|
| 1108 |
-
"step_count": stats["count"],
|
| 1109 |
-
"avg_reward": round(stats["total_reward"] / max(1, stats["count"]), 4),
|
| 1110 |
-
"avg_score": round(sum(scores) / max(1, len(scores)), 4),
|
| 1111 |
-
"min_score": round(min(scores) if scores else 0.0, 4),
|
| 1112 |
-
"max_score": round(max(scores) if scores else 0.0, 4),
|
| 1113 |
-
}
|
| 1114 |
-
|
| 1115 |
-
summary_path = os.path.join(args.output_dir, "rubric_summary.json")
|
| 1116 |
-
with open(summary_path, "w", encoding="utf-8") as f:
|
| 1117 |
-
json.dump(summary, f, indent=2)
|
| 1118 |
-
|
| 1119 |
-
print(f"Generated {len(all_steps)} steps across {len(scenarios)} scenarios")
|
| 1120 |
-
print(f"Output: {args.output_dir}/")
|
| 1121 |
-
for ql in ("excellent", "good", "mediocre", "bad"):
|
| 1122 |
-
if ql in summary["quality_levels"]:
|
| 1123 |
-
info = summary["quality_levels"][ql]
|
| 1124 |
-
print(
|
| 1125 |
-
f" {ql:>10}: {info['step_count']} steps, "
|
| 1126 |
-
f"avg_score={info['avg_score']:.4f}, "
|
| 1127 |
-
f"avg_reward={info['avg_reward']:.4f}"
|
| 1128 |
-
)
|
| 1129 |
-
|
| 1130 |
-
|
| 1131 |
-
if __name__ == "__main__":
|
| 1132 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/app.py
CHANGED
|
@@ -43,7 +43,10 @@ try:
|
|
| 43 |
from .scenario_loader import load_scenarios
|
| 44 |
from .scoring import weights_for_difficulty
|
| 45 |
from .constants import (
|
| 46 |
-
ALLOWED_OPS,
|
|
|
|
|
|
|
|
|
|
| 47 |
)
|
| 48 |
except ImportError:
|
| 49 |
from models import VisualReasoningAction, VisualReasoningObservation
|
|
@@ -51,7 +54,10 @@ except ImportError:
|
|
| 51 |
from server.scenario_loader import load_scenarios
|
| 52 |
from server.scoring import weights_for_difficulty
|
| 53 |
from server.constants import (
|
| 54 |
-
ALLOWED_OPS,
|
|
|
|
|
|
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
|
|
@@ -97,7 +103,12 @@ def _scenario_display_name(s: Dict[str, Any]) -> str:
|
|
| 97 |
|
| 98 |
|
| 99 |
def _tier_emoji(t: str) -> str:
|
| 100 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
def _tier_label(t: str) -> str:
|
|
@@ -108,6 +119,7 @@ def _tier_label(t: str) -> str:
|
|
| 108 |
# Broadcaster (WebSocket fan-out for live viewer)
|
| 109 |
# ---------------------------------------------------------------------------
|
| 110 |
|
|
|
|
| 111 |
class Broadcaster:
|
| 112 |
"""Append-only message log with long-poll support."""
|
| 113 |
|
|
@@ -162,6 +174,7 @@ _broadcaster = Broadcaster()
|
|
| 162 |
# TTS via fal-ai Kokoro
|
| 163 |
# ---------------------------------------------------------------------------
|
| 164 |
|
|
|
|
| 165 |
async def _tts_to_base64(text: str) -> Optional[str]:
|
| 166 |
if not text or not text.strip():
|
| 167 |
return None
|
|
@@ -169,6 +182,7 @@ async def _tts_to_base64(text: str) -> Optional[str]:
|
|
| 169 |
return None
|
| 170 |
try:
|
| 171 |
import aiohttp
|
|
|
|
| 172 |
async with aiohttp.ClientSession() as session:
|
| 173 |
async with session.post(
|
| 174 |
"https://fal.run/fal-ai/kokoro",
|
|
@@ -181,14 +195,22 @@ async def _tts_to_base64(text: str) -> Optional[str]:
|
|
| 181 |
) as resp:
|
| 182 |
if resp.status != 200:
|
| 183 |
body = await resp.text()
|
| 184 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
return None
|
| 186 |
data = await resp.json()
|
| 187 |
audio_url = data.get("audio", {}).get("url")
|
| 188 |
if not audio_url:
|
| 189 |
-
print(
|
|
|
|
|
|
|
| 190 |
return None
|
| 191 |
-
async with session.get(
|
|
|
|
|
|
|
| 192 |
audio_bytes = await audio_resp.read()
|
| 193 |
result = base64.b64encode(audio_bytes).decode("ascii")
|
| 194 |
print(f"[TTS] generated {len(audio_bytes)} bytes of audio", flush=True)
|
|
@@ -211,6 +233,7 @@ def _get_llm_client(api_key: str = ""):
|
|
| 211 |
key = api_key or LLM_API_KEY
|
| 212 |
if _openai_client is None or key != _openai_client_key:
|
| 213 |
from openai import OpenAI
|
|
|
|
| 214 |
_openai_client = OpenAI(base_url=LLM_API_BASE, api_key=key)
|
| 215 |
_openai_client_key = key
|
| 216 |
return _openai_client
|
|
@@ -219,6 +242,7 @@ def _get_llm_client(api_key: str = ""):
|
|
| 219 |
def _build_system_prompt() -> str:
|
| 220 |
try:
|
| 221 |
from inference import SYSTEM_PROMPT
|
|
|
|
| 222 |
return SYSTEM_PROMPT
|
| 223 |
except Exception:
|
| 224 |
return (
|
|
@@ -227,9 +251,12 @@ def _build_system_prompt() -> str:
|
|
| 227 |
)
|
| 228 |
|
| 229 |
|
| 230 |
-
def _build_user_prompt(
|
|
|
|
|
|
|
| 231 |
try:
|
| 232 |
from inference import build_user_prompt
|
|
|
|
| 233 |
return build_user_prompt(obs, last_action, last_reward, history)
|
| 234 |
except Exception:
|
| 235 |
return f"Goal: {obs.goal}\nEntities: {list(obs.entities.keys())}\nRemaining steps: {obs.remaining_step_budget}"
|
|
@@ -245,10 +272,12 @@ def _strip_thinking_tokens(text: str) -> str:
|
|
| 245 |
def _parse_llm_response(text: str) -> Optional[Dict[str, Any]]:
|
| 246 |
try:
|
| 247 |
from inference import parse_action, normalize_action
|
|
|
|
| 248 |
parsed = parse_action(text)
|
| 249 |
return normalize_action(parsed or {})
|
| 250 |
except Exception:
|
| 251 |
import re
|
|
|
|
| 252 |
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
|
| 253 |
if match:
|
| 254 |
try:
|
|
@@ -315,18 +344,20 @@ async def _run_demo(scenario_id: str, api_key: str = "", model_name: str = "") -
|
|
| 315 |
|
| 316 |
goal_audio = await _tts_to_base64(obs.goal)
|
| 317 |
snap = _obs_snapshot(obs)
|
| 318 |
-
await _broadcaster.send(
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
|
|
|
|
|
|
| 330 |
|
| 331 |
await asyncio.sleep(2.0)
|
| 332 |
|
|
@@ -349,7 +380,10 @@ async def _run_demo(scenario_id: str, api_key: str = "", model_name: str = "") -
|
|
| 349 |
temperature=LLM_TEMPERATURE,
|
| 350 |
max_tokens=LLM_MAX_TOKENS,
|
| 351 |
stream=False,
|
| 352 |
-
)
|
|
|
|
|
|
|
|
|
|
| 353 |
)
|
| 354 |
text = _strip_thinking_tokens(text)
|
| 355 |
print(f"[DEMO] Step {step}: LLM returned {len(text)} chars", flush=True)
|
|
@@ -390,23 +424,25 @@ async def _run_demo(scenario_id: str, api_key: str = "", model_name: str = "") -
|
|
| 390 |
history.append(f"Step {step}: {action_dict.get('narration', '')}")
|
| 391 |
|
| 392 |
snap = _obs_snapshot(obs)
|
| 393 |
-
await _broadcaster.send(
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
|
|
|
|
|
|
| 410 |
|
| 411 |
if done:
|
| 412 |
await asyncio.sleep(2.0)
|
|
@@ -418,20 +454,23 @@ async def _run_demo(scenario_id: str, api_key: str = "", model_name: str = "") -
|
|
| 418 |
print(f"[DEMO] error: {exc}", file=sys.stderr, flush=True)
|
| 419 |
traceback.print_exc()
|
| 420 |
|
| 421 |
-
await _broadcaster.send(
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
|
|
|
|
|
|
| 429 |
|
| 430 |
|
| 431 |
# ---------------------------------------------------------------------------
|
| 432 |
# Scenario browser callbacks
|
| 433 |
# ---------------------------------------------------------------------------
|
| 434 |
|
|
|
|
| 435 |
def list_scenario_choices() -> List[str]:
|
| 436 |
return [_scenario_display_name(s) for s in _get_scenarios()]
|
| 437 |
|
|
@@ -460,7 +499,9 @@ def show_scenario(choice: str) -> tuple:
|
|
| 460 |
|
| 461 |
checklist_md = "\n".join(f"- `{c}`" for c in checklist)
|
| 462 |
|
| 463 |
-
constraints_md =
|
|
|
|
|
|
|
| 464 |
|
| 465 |
return header, input_md, checklist_md, constraints_md
|
| 466 |
return "Scenario not found.", "", "", ""
|
|
@@ -470,6 +511,7 @@ def show_scenario(choice: str) -> tuple:
|
|
| 470 |
# Scoring explorer callback
|
| 471 |
# ---------------------------------------------------------------------------
|
| 472 |
|
|
|
|
| 473 |
def show_weights(difficulty: str) -> str:
|
| 474 |
d = difficulty.lower()
|
| 475 |
w = weights_for_difficulty(d)
|
|
@@ -486,7 +528,10 @@ def show_weights(difficulty: str) -> str:
|
|
| 486 |
# Live demo callbacks (Gradio)
|
| 487 |
# ---------------------------------------------------------------------------
|
| 488 |
|
| 489 |
-
|
|
|
|
|
|
|
|
|
|
| 490 |
global _demo_task
|
| 491 |
if not scenario_choice:
|
| 492 |
return "Select a scenario first."
|
|
@@ -523,7 +568,10 @@ async def _stop_live_demo() -> str:
|
|
| 523 |
# Gradio UI (custom builder for openenv's gradio_builder parameter)
|
| 524 |
# ---------------------------------------------------------------------------
|
| 525 |
|
| 526 |
-
|
|
|
|
|
|
|
|
|
|
| 527 |
"""Custom Gradio UI builder for openenv's gradio_builder parameter."""
|
| 528 |
with gr.Blocks(
|
| 529 |
title="Visual Reasoning Environment",
|
|
@@ -537,7 +585,8 @@ def build_ui(web_manager, action_fields, metadata, is_chat_env, title, quick_sta
|
|
| 537 |
# TAB 1: About (blog-style, mirrors README)
|
| 538 |
# ============================================================
|
| 539 |
with gr.Tab("About"):
|
| 540 |
-
gr.Markdown(
|
|
|
|
| 541 |
## The Problem Nobody Talks About
|
| 542 |
|
| 543 |
Here's a question: *How do you teach a machine to teach?*
|
|
@@ -559,9 +608,11 @@ a visual explanation, step by step, where each step advances the algorithm AND a
|
|
| 559 |
the learner's understanding.**
|
| 560 |
|
| 561 |
That's the gap this project fills.
|
| 562 |
-
"""
|
|
|
|
| 563 |
|
| 564 |
-
gr.Markdown(
|
|
|
|
| 565 |
## What This Is
|
| 566 |
|
| 567 |
The Visual Reasoning Environment is an
|
|
@@ -575,7 +626,8 @@ correctness, concept coverage, narration quality, and teaching pedagogy.
|
|
| 575 |
Think of it this way: you're not training the model to *know* BFS.
|
| 576 |
You're training it to *teach* BFS the way the best professor you ever had would --
|
| 577 |
with a marker in hand and an audience that needs to follow along.
|
| 578 |
-
"""
|
|
|
|
| 579 |
|
| 580 |
gr.HTML(
|
| 581 |
'<a href="https://youtu.be/KwWqjuyfWzw" target="_blank">'
|
|
@@ -584,7 +636,8 @@ with a marker in hand and an audience that needs to follow along.
|
|
| 584 |
'alt="Watch the Demo"/></a>'
|
| 585 |
)
|
| 586 |
|
| 587 |
-
gr.Markdown(
|
|
|
|
| 588 |
## What the Agent Does
|
| 589 |
|
| 590 |
Every scenario starts with an **empty canvas**. Nothing is drawn.
|
|
@@ -682,33 +735,32 @@ unsupported concept claims (0.30), too many ops (0.50), and info dumps (0.20).
|
|
| 682 |
## The Reinforcement Learning Loop
|
| 683 |
|
| 684 |
```
|
| 685 |
-
+-------------------------------------------------------------------
|
| 686 |
-
|
|
| 687 |
-
|
|
| 688 |
-
| +-----------+
|
| 689 |
-
| | | --------
|
| 690 |
-
| | Scenario |
|
| 691 |
-
| | Generator|
|
| 692 |
-
| | |
|
| 693 |
-
| +-----------+
|
| 694 |
-
|
|
| 695 |
-
|
|
| 696 |
-
|
|
| 697 |
-
|
|
| 698 |
-
|
|
| 699 |
-
|
|
| 700 |
-
|
|
| 701 |
-
|
|
| 702 |
-
|
|
| 703 |
-
|
|
| 704 |
-
|
|
| 705 |
-
|
|
| 706 |
-
|
|
| 707 |
-
|
|
| 708 |
-
| Per-step reward = delta(overall_score) + penalties + bonuses
|
| 709 |
-
| Episode: empty canvas --> Phase 1
|
| 710 |
-
|
| 711 |
-
+---------------------------------------------------------------------+
|
| 712 |
```
|
| 713 |
|
| 714 |
Every episode starts with a blank canvas and a goal like
|
|
@@ -762,19 +814,22 @@ the system without hand-coding rules for every field.
|
|
| 762 |
"The goal is not to be impressive. The goal is to be clear."
|
| 763 |
That's the north star of this project -- training machines not to be
|
| 764 |
impressive explainers, but clear ones.*
|
| 765 |
-
"""
|
|
|
|
| 766 |
|
| 767 |
# ============================================================
|
| 768 |
# TAB 2: Live Demo
|
| 769 |
# ============================================================
|
| 770 |
with gr.Tab("Live Demo"):
|
| 771 |
-
gr.Markdown(
|
|
|
|
| 772 |
## Watch an LLM Teach
|
| 773 |
|
| 774 |
See the agent explain a CS algorithm in real-time -- canvas visualization
|
| 775 |
with voice narration. Select a scenario, click **Start Demo**, then click the
|
| 776 |
viewer area to activate audio.
|
| 777 |
-
"""
|
|
|
|
| 778 |
|
| 779 |
with gr.Row():
|
| 780 |
demo_hf_token = gr.Textbox(
|
|
@@ -800,7 +855,9 @@ viewer area to activate audio.
|
|
| 800 |
demo_start_btn = gr.Button("Start Demo", variant="primary", scale=1)
|
| 801 |
demo_stop_btn = gr.Button("Stop Demo", variant="stop", scale=1)
|
| 802 |
|
| 803 |
-
demo_status = gr.Markdown(
|
|
|
|
|
|
|
| 804 |
|
| 805 |
gr.HTML(
|
| 806 |
value=(
|
|
@@ -825,7 +882,8 @@ viewer area to activate audio.
|
|
| 825 |
# TAB 3: Scoring & Architecture (technical)
|
| 826 |
# ============================================================
|
| 827 |
with gr.Tab("Scoring & Architecture"):
|
| 828 |
-
gr.Markdown(
|
|
|
|
| 829 |
## Scoring System
|
| 830 |
|
| 831 |
The overall score is a **weighted sum of 13 sub-scores** (each 0-1) **minus 5 penalties**.
|
|
@@ -833,7 +891,8 @@ Weights are tuned per difficulty level -- harder tiers emphasize algorithm corre
|
|
| 833 |
while easier tiers give more weight to narration quality and concept coverage.
|
| 834 |
|
| 835 |
**Select a difficulty level** to see the weight distribution:
|
| 836 |
-
"""
|
|
|
|
| 837 |
|
| 838 |
difficulty_radio = gr.Radio(
|
| 839 |
choices=["easy", "medium", "hard", "expert"],
|
|
@@ -848,7 +907,8 @@ while easier tiers give more weight to narration quality and concept coverage.
|
|
| 848 |
outputs=[weights_display],
|
| 849 |
)
|
| 850 |
|
| 851 |
-
gr.Markdown(
|
|
|
|
| 852 |
---
|
| 853 |
### Sub-Score Details
|
| 854 |
|
|
@@ -932,13 +992,15 @@ track membership for `push_to`/`pop_from`. Common source of LLM confusion.
|
|
| 932 |
plus relative prefixes (`below:`, `right-of:`, ...) instead of numeric coordinates.
|
| 933 |
This reduces the positioning search space from ~331K to ~6.5K, making it learnable within
|
| 934 |
reasonable training budgets.
|
| 935 |
-
"""
|
|
|
|
| 936 |
|
| 937 |
# ============================================================
|
| 938 |
# TAB 4: API Reference (technical)
|
| 939 |
# ============================================================
|
| 940 |
with gr.Tab("API Reference"):
|
| 941 |
-
gr.Markdown(
|
|
|
|
| 942 |
## API Reference
|
| 943 |
|
| 944 |
This Space exposes the standard **OpenEnv** HTTP + WebSocket API under `/api`.
|
|
@@ -1044,7 +1106,8 @@ export LOCAL_IMAGE_NAME=http://127.0.0.1:8000
|
|
| 1044 |
python inference.py # headless
|
| 1045 |
python inference_tldraw.py # with tldraw browser viewer
|
| 1046 |
```
|
| 1047 |
-
"""
|
|
|
|
| 1048 |
|
| 1049 |
return demo
|
| 1050 |
|
|
@@ -1067,6 +1130,7 @@ app = create_app(
|
|
| 1067 |
# Additional routes for live viewer + audio
|
| 1068 |
# ---------------------------------------------------------------------------
|
| 1069 |
|
|
|
|
| 1070 |
@app.get("/viewer")
|
| 1071 |
async def serve_viewer():
|
| 1072 |
viewer_path = VIEWER_DIR / "audio_viewer.html"
|
|
@@ -1114,7 +1178,9 @@ def main():
|
|
| 1114 |
print(f" Live Viewer: http://{args.host}:{args.port}/viewer")
|
| 1115 |
print(f" OpenEnv API: http://{args.host}:{args.port}/reset, /step, /health")
|
| 1116 |
if not LLM_API_KEY:
|
| 1117 |
-
print(
|
|
|
|
|
|
|
| 1118 |
|
| 1119 |
uvicorn.run(app, host=args.host, port=args.port)
|
| 1120 |
|
|
|
|
| 43 |
from .scenario_loader import load_scenarios
|
| 44 |
from .scoring import weights_for_difficulty
|
| 45 |
from .constants import (
|
| 46 |
+
ALLOWED_OPS,
|
| 47 |
+
ROLE_VALUES,
|
| 48 |
+
REGION_STYLES,
|
| 49 |
+
NAMED_POSITIONS,
|
| 50 |
)
|
| 51 |
except ImportError:
|
| 52 |
from models import VisualReasoningAction, VisualReasoningObservation
|
|
|
|
| 54 |
from server.scenario_loader import load_scenarios
|
| 55 |
from server.scoring import weights_for_difficulty
|
| 56 |
from server.constants import (
|
| 57 |
+
ALLOWED_OPS,
|
| 58 |
+
ROLE_VALUES,
|
| 59 |
+
REGION_STYLES,
|
| 60 |
+
NAMED_POSITIONS,
|
| 61 |
)
|
| 62 |
|
| 63 |
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
def _tier_emoji(t: str) -> str:
|
| 106 |
+
return {
|
| 107 |
+
"easy": "\U0001f7e2",
|
| 108 |
+
"medium": "\U0001f7e1",
|
| 109 |
+
"hard": "\U0001f7e0",
|
| 110 |
+
"expert": "\U0001f534",
|
| 111 |
+
}.get(t, "⚪")
|
| 112 |
|
| 113 |
|
| 114 |
def _tier_label(t: str) -> str:
|
|
|
|
| 119 |
# Broadcaster (WebSocket fan-out for live viewer)
|
| 120 |
# ---------------------------------------------------------------------------
|
| 121 |
|
| 122 |
+
|
| 123 |
class Broadcaster:
|
| 124 |
"""Append-only message log with long-poll support."""
|
| 125 |
|
|
|
|
| 174 |
# TTS via fal-ai Kokoro
|
| 175 |
# ---------------------------------------------------------------------------
|
| 176 |
|
| 177 |
+
|
| 178 |
async def _tts_to_base64(text: str) -> Optional[str]:
|
| 179 |
if not text or not text.strip():
|
| 180 |
return None
|
|
|
|
| 182 |
return None
|
| 183 |
try:
|
| 184 |
import aiohttp
|
| 185 |
+
|
| 186 |
async with aiohttp.ClientSession() as session:
|
| 187 |
async with session.post(
|
| 188 |
"https://fal.run/fal-ai/kokoro",
|
|
|
|
| 195 |
) as resp:
|
| 196 |
if resp.status != 200:
|
| 197 |
body = await resp.text()
|
| 198 |
+
print(
|
| 199 |
+
f"[TTS] fal-ai error {resp.status}: {body}",
|
| 200 |
+
file=sys.stderr,
|
| 201 |
+
flush=True,
|
| 202 |
+
)
|
| 203 |
return None
|
| 204 |
data = await resp.json()
|
| 205 |
audio_url = data.get("audio", {}).get("url")
|
| 206 |
if not audio_url:
|
| 207 |
+
print(
|
| 208 |
+
"[TTS] no audio URL in fal-ai response", file=sys.stderr, flush=True
|
| 209 |
+
)
|
| 210 |
return None
|
| 211 |
+
async with session.get(
|
| 212 |
+
audio_url, timeout=aiohttp.ClientTimeout(total=15)
|
| 213 |
+
) as audio_resp:
|
| 214 |
audio_bytes = await audio_resp.read()
|
| 215 |
result = base64.b64encode(audio_bytes).decode("ascii")
|
| 216 |
print(f"[TTS] generated {len(audio_bytes)} bytes of audio", flush=True)
|
|
|
|
| 233 |
key = api_key or LLM_API_KEY
|
| 234 |
if _openai_client is None or key != _openai_client_key:
|
| 235 |
from openai import OpenAI
|
| 236 |
+
|
| 237 |
_openai_client = OpenAI(base_url=LLM_API_BASE, api_key=key)
|
| 238 |
_openai_client_key = key
|
| 239 |
return _openai_client
|
|
|
|
| 242 |
def _build_system_prompt() -> str:
|
| 243 |
try:
|
| 244 |
from inference import SYSTEM_PROMPT
|
| 245 |
+
|
| 246 |
return SYSTEM_PROMPT
|
| 247 |
except Exception:
|
| 248 |
return (
|
|
|
|
| 251 |
)
|
| 252 |
|
| 253 |
|
| 254 |
+
def _build_user_prompt(
|
| 255 |
+
obs: Any, last_action: Optional[Dict], last_reward: float, history: List[str]
|
| 256 |
+
) -> str:
|
| 257 |
try:
|
| 258 |
from inference import build_user_prompt
|
| 259 |
+
|
| 260 |
return build_user_prompt(obs, last_action, last_reward, history)
|
| 261 |
except Exception:
|
| 262 |
return f"Goal: {obs.goal}\nEntities: {list(obs.entities.keys())}\nRemaining steps: {obs.remaining_step_budget}"
|
|
|
|
| 272 |
def _parse_llm_response(text: str) -> Optional[Dict[str, Any]]:
|
| 273 |
try:
|
| 274 |
from inference import parse_action, normalize_action
|
| 275 |
+
|
| 276 |
parsed = parse_action(text)
|
| 277 |
return normalize_action(parsed or {})
|
| 278 |
except Exception:
|
| 279 |
import re
|
| 280 |
+
|
| 281 |
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
|
| 282 |
if match:
|
| 283 |
try:
|
|
|
|
| 344 |
|
| 345 |
goal_audio = await _tts_to_base64(obs.goal)
|
| 346 |
snap = _obs_snapshot(obs)
|
| 347 |
+
await _broadcaster.send(
|
| 348 |
+
{
|
| 349 |
+
"type": "reset",
|
| 350 |
+
"task_name": obs.task_name,
|
| 351 |
+
"scenario_id": obs.scenario_id,
|
| 352 |
+
"goal": obs.goal,
|
| 353 |
+
"checklist": list(obs.concept_checklist),
|
| 354 |
+
"input_data": dict(obs.input_data),
|
| 355 |
+
"constraints": list(obs.constraints),
|
| 356 |
+
"max_steps": obs.max_steps,
|
| 357 |
+
"audio": goal_audio,
|
| 358 |
+
**snap,
|
| 359 |
+
}
|
| 360 |
+
)
|
| 361 |
|
| 362 |
await asyncio.sleep(2.0)
|
| 363 |
|
|
|
|
| 380 |
temperature=LLM_TEMPERATURE,
|
| 381 |
max_tokens=LLM_MAX_TOKENS,
|
| 382 |
stream=False,
|
| 383 |
+
)
|
| 384 |
+
.choices[0]
|
| 385 |
+
.message.content
|
| 386 |
+
or "",
|
| 387 |
)
|
| 388 |
text = _strip_thinking_tokens(text)
|
| 389 |
print(f"[DEMO] Step {step}: LLM returned {len(text)} chars", flush=True)
|
|
|
|
| 424 |
history.append(f"Step {step}: {action_dict.get('narration', '')}")
|
| 425 |
|
| 426 |
snap = _obs_snapshot(obs)
|
| 427 |
+
await _broadcaster.send(
|
| 428 |
+
{
|
| 429 |
+
"type": "step",
|
| 430 |
+
"task_name": obs_dict.get("task_name", ""),
|
| 431 |
+
"scenario_id": obs_dict.get("scenario_id", ""),
|
| 432 |
+
"step": step,
|
| 433 |
+
"step_type": action_dict.get("step_type"),
|
| 434 |
+
"intent": action_dict.get("intent", ""),
|
| 435 |
+
"narration": narration,
|
| 436 |
+
"ops": action_dict.get("ops", []),
|
| 437 |
+
"covered_concepts": action_dict.get("covered_concepts", []),
|
| 438 |
+
"reward": float(reward),
|
| 439 |
+
"score": float(overall),
|
| 440 |
+
"done": bool(done),
|
| 441 |
+
"error": error,
|
| 442 |
+
"audio": audio_b64,
|
| 443 |
+
**snap,
|
| 444 |
+
}
|
| 445 |
+
)
|
| 446 |
|
| 447 |
if done:
|
| 448 |
await asyncio.sleep(2.0)
|
|
|
|
| 454 |
print(f"[DEMO] error: {exc}", file=sys.stderr, flush=True)
|
| 455 |
traceback.print_exc()
|
| 456 |
|
| 457 |
+
await _broadcaster.send(
|
| 458 |
+
{
|
| 459 |
+
"type": "end",
|
| 460 |
+
"task_name": scenario_id,
|
| 461 |
+
"success": score >= 0.65,
|
| 462 |
+
"steps": steps_taken,
|
| 463 |
+
"score": float(score),
|
| 464 |
+
"rewards": [float(r) for r in rewards],
|
| 465 |
+
}
|
| 466 |
+
)
|
| 467 |
|
| 468 |
|
| 469 |
# ---------------------------------------------------------------------------
|
| 470 |
# Scenario browser callbacks
|
| 471 |
# ---------------------------------------------------------------------------
|
| 472 |
|
| 473 |
+
|
| 474 |
def list_scenario_choices() -> List[str]:
|
| 475 |
return [_scenario_display_name(s) for s in _get_scenarios()]
|
| 476 |
|
|
|
|
| 499 |
|
| 500 |
checklist_md = "\n".join(f"- `{c}`" for c in checklist)
|
| 501 |
|
| 502 |
+
constraints_md = (
|
| 503 |
+
"\n".join(f"- `{c}`" for c in constraints) if constraints else "_None_"
|
| 504 |
+
)
|
| 505 |
|
| 506 |
return header, input_md, checklist_md, constraints_md
|
| 507 |
return "Scenario not found.", "", "", ""
|
|
|
|
| 511 |
# Scoring explorer callback
|
| 512 |
# ---------------------------------------------------------------------------
|
| 513 |
|
| 514 |
+
|
| 515 |
def show_weights(difficulty: str) -> str:
|
| 516 |
d = difficulty.lower()
|
| 517 |
w = weights_for_difficulty(d)
|
|
|
|
| 528 |
# Live demo callbacks (Gradio)
|
| 529 |
# ---------------------------------------------------------------------------
|
| 530 |
|
| 531 |
+
|
| 532 |
+
async def _start_live_demo(
|
| 533 |
+
scenario_choice: str, hf_token: str = "", model_name: str = ""
|
| 534 |
+
) -> str:
|
| 535 |
global _demo_task
|
| 536 |
if not scenario_choice:
|
| 537 |
return "Select a scenario first."
|
|
|
|
| 568 |
# Gradio UI (custom builder for openenv's gradio_builder parameter)
|
| 569 |
# ---------------------------------------------------------------------------
|
| 570 |
|
| 571 |
+
|
| 572 |
+
def build_ui(
|
| 573 |
+
web_manager, action_fields, metadata, is_chat_env, title, quick_start_md
|
| 574 |
+
) -> gr.Blocks:
|
| 575 |
"""Custom Gradio UI builder for openenv's gradio_builder parameter."""
|
| 576 |
with gr.Blocks(
|
| 577 |
title="Visual Reasoning Environment",
|
|
|
|
| 585 |
# TAB 1: About (blog-style, mirrors README)
|
| 586 |
# ============================================================
|
| 587 |
with gr.Tab("About"):
|
| 588 |
+
gr.Markdown(
|
| 589 |
+
"""
|
| 590 |
## The Problem Nobody Talks About
|
| 591 |
|
| 592 |
Here's a question: *How do you teach a machine to teach?*
|
|
|
|
| 608 |
the learner's understanding.**
|
| 609 |
|
| 610 |
That's the gap this project fills.
|
| 611 |
+
"""
|
| 612 |
+
)
|
| 613 |
|
| 614 |
+
gr.Markdown(
|
| 615 |
+
"""
|
| 616 |
## What This Is
|
| 617 |
|
| 618 |
The Visual Reasoning Environment is an
|
|
|
|
| 626 |
Think of it this way: you're not training the model to *know* BFS.
|
| 627 |
You're training it to *teach* BFS the way the best professor you ever had would --
|
| 628 |
with a marker in hand and an audience that needs to follow along.
|
| 629 |
+
"""
|
| 630 |
+
)
|
| 631 |
|
| 632 |
gr.HTML(
|
| 633 |
'<a href="https://youtu.be/KwWqjuyfWzw" target="_blank">'
|
|
|
|
| 636 |
'alt="Watch the Demo"/></a>'
|
| 637 |
)
|
| 638 |
|
| 639 |
+
gr.Markdown(
|
| 640 |
+
"""
|
| 641 |
## What the Agent Does
|
| 642 |
|
| 643 |
Every scenario starts with an **empty canvas**. Nothing is drawn.
|
|
|
|
| 735 |
## The Reinforcement Learning Loop
|
| 736 |
|
| 737 |
```
|
| 738 |
+
+-------------------------------------------------------------------+
|
| 739 |
+
| TRAINING LOOP (GRPO / RLVR) |
|
| 740 |
+
| |
|
| 741 |
+
| +-----------+ prompt +--------------+ JSON action |
|
| 742 |
+
| | | --------> | | --------------+ |
|
| 743 |
+
| | Scenario | | LLM Agent | | |
|
| 744 |
+
| | Generator| | (Teacher) | | |
|
| 745 |
+
| | | +------> | | <--------+ | |
|
| 746 |
+
| +-----------+ | +--------------+ | | |
|
| 747 |
+
| | | | |
|
| 748 |
+
| observation reward | |
|
| 749 |
+
| + score breakdown signal | |
|
| 750 |
+
| | | | |
|
| 751 |
+
| +--------+--------+ score +------+--+ | |
|
| 752 |
+
| | | <-------------- | | | |
|
| 753 |
+
| | Environment | | Scoring | | |
|
| 754 |
+
| | (Empty Canvas) | --------------> | Engine | | |
|
| 755 |
+
| | | canvas state |(13 dim) | | |
|
| 756 |
+
| +-----------------+ +---------+ | |
|
| 757 |
+
| ^ | |
|
| 758 |
+
| | step(action) | |
|
| 759 |
+
| +--------------------------------------+ |
|
| 760 |
+
| |
|
| 761 |
+
| Per-step reward = delta(overall_score) + penalties + bonuses |
|
| 762 |
+
| Episode: empty canvas --> Phase 1 --> Phase 2 --> Phase 3 |
|
| 763 |
+
+-------------------------------------------------------------------+
|
|
|
|
| 764 |
```
|
| 765 |
|
| 766 |
Every episode starts with a blank canvas and a goal like
|
|
|
|
| 814 |
"The goal is not to be impressive. The goal is to be clear."
|
| 815 |
That's the north star of this project -- training machines not to be
|
| 816 |
impressive explainers, but clear ones.*
|
| 817 |
+
"""
|
| 818 |
+
)
|
| 819 |
|
| 820 |
# ============================================================
|
| 821 |
# TAB 2: Live Demo
|
| 822 |
# ============================================================
|
| 823 |
with gr.Tab("Live Demo"):
|
| 824 |
+
gr.Markdown(
|
| 825 |
+
"""
|
| 826 |
## Watch an LLM Teach
|
| 827 |
|
| 828 |
See the agent explain a CS algorithm in real-time -- canvas visualization
|
| 829 |
with voice narration. Select a scenario, click **Start Demo**, then click the
|
| 830 |
viewer area to activate audio.
|
| 831 |
+
"""
|
| 832 |
+
)
|
| 833 |
|
| 834 |
with gr.Row():
|
| 835 |
demo_hf_token = gr.Textbox(
|
|
|
|
| 855 |
demo_start_btn = gr.Button("Start Demo", variant="primary", scale=1)
|
| 856 |
demo_stop_btn = gr.Button("Stop Demo", variant="stop", scale=1)
|
| 857 |
|
| 858 |
+
demo_status = gr.Markdown(
|
| 859 |
+
"_Enter your HF token, select a scenario, and click Start Demo._"
|
| 860 |
+
)
|
| 861 |
|
| 862 |
gr.HTML(
|
| 863 |
value=(
|
|
|
|
| 882 |
# TAB 3: Scoring & Architecture (technical)
|
| 883 |
# ============================================================
|
| 884 |
with gr.Tab("Scoring & Architecture"):
|
| 885 |
+
gr.Markdown(
|
| 886 |
+
"""
|
| 887 |
## Scoring System
|
| 888 |
|
| 889 |
The overall score is a **weighted sum of 13 sub-scores** (each 0-1) **minus 5 penalties**.
|
|
|
|
| 891 |
while easier tiers give more weight to narration quality and concept coverage.
|
| 892 |
|
| 893 |
**Select a difficulty level** to see the weight distribution:
|
| 894 |
+
"""
|
| 895 |
+
)
|
| 896 |
|
| 897 |
difficulty_radio = gr.Radio(
|
| 898 |
choices=["easy", "medium", "hard", "expert"],
|
|
|
|
| 907 |
outputs=[weights_display],
|
| 908 |
)
|
| 909 |
|
| 910 |
+
gr.Markdown(
|
| 911 |
+
"""
|
| 912 |
---
|
| 913 |
### Sub-Score Details
|
| 914 |
|
|
|
|
| 992 |
plus relative prefixes (`below:`, `right-of:`, ...) instead of numeric coordinates.
|
| 993 |
This reduces the positioning search space from ~331K to ~6.5K, making it learnable within
|
| 994 |
reasonable training budgets.
|
| 995 |
+
"""
|
| 996 |
+
)
|
| 997 |
|
| 998 |
# ============================================================
|
| 999 |
# TAB 4: API Reference (technical)
|
| 1000 |
# ============================================================
|
| 1001 |
with gr.Tab("API Reference"):
|
| 1002 |
+
gr.Markdown(
|
| 1003 |
+
f"""
|
| 1004 |
## API Reference
|
| 1005 |
|
| 1006 |
This Space exposes the standard **OpenEnv** HTTP + WebSocket API under `/api`.
|
|
|
|
| 1106 |
python inference.py # headless
|
| 1107 |
python inference_tldraw.py # with tldraw browser viewer
|
| 1108 |
```
|
| 1109 |
+
"""
|
| 1110 |
+
)
|
| 1111 |
|
| 1112 |
return demo
|
| 1113 |
|
|
|
|
| 1130 |
# Additional routes for live viewer + audio
|
| 1131 |
# ---------------------------------------------------------------------------
|
| 1132 |
|
| 1133 |
+
|
| 1134 |
@app.get("/viewer")
|
| 1135 |
async def serve_viewer():
|
| 1136 |
viewer_path = VIEWER_DIR / "audio_viewer.html"
|
|
|
|
| 1178 |
print(f" Live Viewer: http://{args.host}:{args.port}/viewer")
|
| 1179 |
print(f" OpenEnv API: http://{args.host}:{args.port}/reset, /step, /health")
|
| 1180 |
if not LLM_API_KEY:
|
| 1181 |
+
print(
|
| 1182 |
+
" NOTE: No HF_TOKEN/API_KEY env var — users can enter token in the Viewer tab"
|
| 1183 |
+
)
|
| 1184 |
|
| 1185 |
uvicorn.run(app, host=args.host, port=args.port)
|
| 1186 |
|
server/app_backup.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""FastAPI application for the Visual Reasoning Environment."""
|
| 8 |
-
|
| 9 |
-
try:
|
| 10 |
-
from openenv.core.env_server.http_server import create_app
|
| 11 |
-
except Exception as e:
|
| 12 |
-
raise ImportError(
|
| 13 |
-
"openenv is required for the web interface. Install dependencies with 'uv sync'."
|
| 14 |
-
) from e
|
| 15 |
-
|
| 16 |
-
try:
|
| 17 |
-
from ..models import VisualReasoningAction, VisualReasoningObservation
|
| 18 |
-
from .visual_reasoning_environment import VisualReasoningEnvironment
|
| 19 |
-
except ImportError:
|
| 20 |
-
from models import VisualReasoningAction, VisualReasoningObservation
|
| 21 |
-
from server.visual_reasoning_environment import VisualReasoningEnvironment
|
| 22 |
-
|
| 23 |
-
app = create_app(
|
| 24 |
-
VisualReasoningEnvironment,
|
| 25 |
-
VisualReasoningAction,
|
| 26 |
-
VisualReasoningObservation,
|
| 27 |
-
env_name="visual_reasoning",
|
| 28 |
-
max_concurrent_envs=1,
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def main():
|
| 33 |
-
"""Run the FastAPI server."""
|
| 34 |
-
import argparse
|
| 35 |
-
|
| 36 |
-
import uvicorn
|
| 37 |
-
|
| 38 |
-
parser = argparse.ArgumentParser()
|
| 39 |
-
parser.add_argument("--host", type=str, default="0.0.0.0")
|
| 40 |
-
parser.add_argument("--port", type=int, default=8000)
|
| 41 |
-
args = parser.parse_args()
|
| 42 |
-
uvicorn.run(app, host=args.host, port=args.port)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
if __name__ == "__main__":
|
| 46 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.ipynb
DELETED
|
@@ -1,913 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "markdown",
|
| 5 |
-
"id": "title",
|
| 6 |
-
"metadata": {},
|
| 7 |
-
"source": [
|
| 8 |
-
"# Visual Reasoning — Training Demo\n",
|
| 9 |
-
"\n",
|
| 10 |
-
"**Simulation-based RL for teaching CS algorithms on a whiteboard**\n",
|
| 11 |
-
"\n",
|
| 12 |
-
"This notebook trains a small LLM to be an expert visual explainer of CS algorithms.\n",
|
| 13 |
-
"The model draws data structures, walks through algorithms step-by-step, and narrates\n",
|
| 14 |
-
"the reasoning — scored by a 12-dimension reward system.\n",
|
| 15 |
-
"\n",
|
| 16 |
-
"| Stage | Purpose |\n",
|
| 17 |
-
"|-------|:--------|\n",
|
| 18 |
-
"| Baseline | Score the untrained model across all difficulties |\n",
|
| 19 |
-
"| SFT warmup | Teach JSON action format via gold demonstrations |\n",
|
| 20 |
-
"| GRPO | RL with dense environment rewards, easy → expert curriculum |\n",
|
| 21 |
-
"| Final eval | Delta report comparing all three checkpoints |\n",
|
| 22 |
-
"\n",
|
| 23 |
-
"- 17 scenarios (9 hand-crafted + 8 procedurally generated) across 4 difficulty levels\n",
|
| 24 |
-
"- 9 algorithm templates: linked list, stack, binary search, BFS, hash table, Dijkstra, BST, fib memo, quicksort\n",
|
| 25 |
-
"- 12 weighted sub-scores + 5 penalties → dense per-step reward\n",
|
| 26 |
-
"\n",
|
| 27 |
-
"> **No model saving** — everything stays in-memory. The same LoRA adapter flows from SFT into GRPO."
|
| 28 |
-
]
|
| 29 |
-
},
|
| 30 |
-
{
|
| 31 |
-
"cell_type": "markdown",
|
| 32 |
-
"id": "install-hdr",
|
| 33 |
-
"metadata": {},
|
| 34 |
-
"source": [
|
| 35 |
-
"## 1. Install Dependencies"
|
| 36 |
-
]
|
| 37 |
-
},
|
| 38 |
-
{
|
| 39 |
-
"cell_type": "code",
|
| 40 |
-
"execution_count": null,
|
| 41 |
-
"id": "install",
|
| 42 |
-
"metadata": {},
|
| 43 |
-
"outputs": [],
|
| 44 |
-
"source": [
|
| 45 |
-
"!pip install -q --upgrade \"torchvision>=0.25.0\"\n",
|
| 46 |
-
"!pip install -q huggingface_hub unsloth trl datasets transformers accelerate bitsandbytes peft torch\n",
|
| 47 |
-
"!pip install -q openenv-core fastapi uvicorn pydantic\n",
|
| 48 |
-
"!pip install -q python-dotenv networkx shapely sentence-transformers rapidfuzz textstat sortedcontainers \"numpy<2.0\""
|
| 49 |
-
]
|
| 50 |
-
},
|
| 51 |
-
{
|
| 52 |
-
"cell_type": "markdown",
|
| 53 |
-
"id": "clone-hdr",
|
| 54 |
-
"metadata": {},
|
| 55 |
-
"source": [
|
| 56 |
-
"## 2. Clone Visual Reasoning Environment"
|
| 57 |
-
]
|
| 58 |
-
},
|
| 59 |
-
{
|
| 60 |
-
"cell_type": "code",
|
| 61 |
-
"execution_count": null,
|
| 62 |
-
"id": "clone",
|
| 63 |
-
"metadata": {},
|
| 64 |
-
"outputs": [],
|
| 65 |
-
"source": [
|
| 66 |
-
"import os\n",
|
| 67 |
-
"\n",
|
| 68 |
-
"# Use heuristic narration scorer — no GPU contention with training model\n",
|
| 69 |
-
"os.environ[\"NARRATION_SCORER\"] = \"fallback\"\n",
|
| 70 |
-
"\n",
|
| 71 |
-
"from huggingface_hub import snapshot_download\n",
|
| 72 |
-
"\n",
|
| 73 |
-
"if os.path.basename(os.getcwd()) != \"visual_reasoning\":\n",
|
| 74 |
-
" if not os.path.isdir(\"visual_reasoning\"):\n",
|
| 75 |
-
" snapshot_download(\n",
|
| 76 |
-
" repo_id=\"sreeramajay/visual_reasoning-env\",\n",
|
| 77 |
-
" repo_type=\"space\",\n",
|
| 78 |
-
" local_dir=\"visual_reasoning\",\n",
|
| 79 |
-
" ignore_patterns=[\"*.gitattributes\", \".gitignore\", \"README.md\"],\n",
|
| 80 |
-
" )\n",
|
| 81 |
-
" os.chdir(\"visual_reasoning\")"
|
| 82 |
-
]
|
| 83 |
-
},
|
| 84 |
-
{
|
| 85 |
-
"cell_type": "markdown",
|
| 86 |
-
"id": "verify-hdr",
|
| 87 |
-
"metadata": {},
|
| 88 |
-
"source": [
|
| 89 |
-
"## 3. Verify Environment — Run a Smoke Test"
|
| 90 |
-
]
|
| 91 |
-
},
|
| 92 |
-
{
|
| 93 |
-
"cell_type": "code",
|
| 94 |
-
"execution_count": null,
|
| 95 |
-
"id": "verify",
|
| 96 |
-
"metadata": {},
|
| 97 |
-
"outputs": [],
|
| 98 |
-
"source": [
|
| 99 |
-
"import sys, os, json, torch\n",
|
| 100 |
-
"sys.path.insert(0, '.')\n",
|
| 101 |
-
"\n",
|
| 102 |
-
"from unsloth import FastLanguageModel\n",
|
| 103 |
-
"\n",
|
| 104 |
-
"from models import VisualReasoningAction\n",
|
| 105 |
-
"from server.visual_reasoning_environment import VisualReasoningEnvironment\n",
|
| 106 |
-
"\n",
|
| 107 |
-
"test_env = VisualReasoningEnvironment()\n",
|
| 108 |
-
"obs = test_env.reset(scenario_id='easy_1')\n",
|
| 109 |
-
"print(f'Scenario: {obs.scenario_id}')\n",
|
| 110 |
-
"print(f'Goal: {obs.goal}')\n",
|
| 111 |
-
"print(f'Concepts: {obs.concept_checklist}')\n",
|
| 112 |
-
"print(f'Step budget: {obs.remaining_step_budget}')\n",
|
| 113 |
-
"\n",
|
| 114 |
-
"obs = test_env.step(VisualReasoningAction(\n",
|
| 115 |
-
" step_type='advance',\n",
|
| 116 |
-
" narration='Adding the first node with value 10.',\n",
|
| 117 |
-
" ops=[{'op': 'add_node', 'target_ids': ['n0'], 'params': {'value': 10}}],\n",
|
| 118 |
-
" covered_concepts=['node_value'],\n",
|
| 119 |
-
" intent='test',\n",
|
| 120 |
-
"))\n",
|
| 121 |
-
"print(f'\\nAfter step: entities={list(obs.entities.keys())}, reward={obs.reward:.3f}, error={obs.action_error}')\n",
|
| 122 |
-
"print('Environment OK!')\n",
|
| 123 |
-
"del test_env"
|
| 124 |
-
]
|
| 125 |
-
},
|
| 126 |
-
{
|
| 127 |
-
"cell_type": "markdown",
|
| 128 |
-
"id": "config-hdr",
|
| 129 |
-
"metadata": {},
|
| 130 |
-
"source": [
|
| 131 |
-
"## 4. Configuration"
|
| 132 |
-
]
|
| 133 |
-
},
|
| 134 |
-
{
|
| 135 |
-
"cell_type": "code",
|
| 136 |
-
"execution_count": null,
|
| 137 |
-
"id": "config",
|
| 138 |
-
"metadata": {},
|
| 139 |
-
"outputs": [],
|
| 140 |
-
"source": [
|
| 141 |
-
"from collections import Counter\n",
|
| 142 |
-
"from inference import SYSTEM_PROMPT, build_user_prompt, parse_action, normalize_action\n",
|
| 143 |
-
"\n",
|
| 144 |
-
"MODEL_NAME = 'unsloth/Qwen2.5-3B-Instruct-bnb-4bit'\n",
|
| 145 |
-
"MAX_SEQ_LENGTH = 4096\n",
|
| 146 |
-
"LORA_R = 16\n",
|
| 147 |
-
"LORA_ALPHA = 32\n",
|
| 148 |
-
"SFT_EPOCHS = 3\n",
|
| 149 |
-
"GRPO_EPOCHS = 2\n",
|
| 150 |
-
"\n",
|
| 151 |
-
"SCENARIOS = {\n",
|
| 152 |
-
" 'easy': ['easy_1', 'easy_2', 'easy_3', 'gen_easy_1001', 'gen_easy_1002'],\n",
|
| 153 |
-
" 'medium': ['medium_1', 'medium_2', 'gen_medium_2001', 'gen_medium_2002'],\n",
|
| 154 |
-
" 'hard': ['hard_1', 'hard_2', 'gen_hard_3001', 'gen_hard_3002'],\n",
|
| 155 |
-
" 'expert': ['expert_1', 'expert_2', 'gen_expert_4001', 'gen_expert_4002'],\n",
|
| 156 |
-
"}\n",
|
| 157 |
-
"DIFFICULTIES = ('easy', 'medium', 'hard', 'expert')"
|
| 158 |
-
]
|
| 159 |
-
},
|
| 160 |
-
{
|
| 161 |
-
"cell_type": "markdown",
|
| 162 |
-
"id": "model-hdr",
|
| 163 |
-
"metadata": {},
|
| 164 |
-
"source": [
|
| 165 |
-
"## 5. Load Model (Unsloth + LoRA)"
|
| 166 |
-
]
|
| 167 |
-
},
|
| 168 |
-
{
|
| 169 |
-
"cell_type": "code",
|
| 170 |
-
"execution_count": null,
|
| 171 |
-
"id": "load-model",
|
| 172 |
-
"metadata": {},
|
| 173 |
-
"outputs": [],
|
| 174 |
-
"source": [
|
| 175 |
-
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 176 |
-
" model_name=MODEL_NAME,\n",
|
| 177 |
-
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
| 178 |
-
" dtype=None,\n",
|
| 179 |
-
" load_in_4bit=True,\n",
|
| 180 |
-
")\n",
|
| 181 |
-
"model = FastLanguageModel.get_peft_model(\n",
|
| 182 |
-
" model,\n",
|
| 183 |
-
" r=LORA_R,\n",
|
| 184 |
-
" lora_alpha=LORA_ALPHA,\n",
|
| 185 |
-
" lora_dropout=0,\n",
|
| 186 |
-
" target_modules=[\n",
|
| 187 |
-
" 'q_proj', 'k_proj', 'v_proj', 'o_proj',\n",
|
| 188 |
-
" 'gate_proj', 'up_proj', 'down_proj',\n",
|
| 189 |
-
" ],\n",
|
| 190 |
-
" bias='none',\n",
|
| 191 |
-
" use_gradient_checkpointing='unsloth',\n",
|
| 192 |
-
" random_state=0,\n",
|
| 193 |
-
")\n",
|
| 194 |
-
"if tokenizer.pad_token_id is None:\n",
|
| 195 |
-
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 196 |
-
"\n",
|
| 197 |
-
"model.generation_config.max_length = None\n",
|
| 198 |
-
"\n",
|
| 199 |
-
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 200 |
-
"total = sum(p.numel() for p in model.parameters())\n",
|
| 201 |
-
"print(f'Parameters: {total / 1e6:.1f}M total, {trainable / 1e6:.2f}M trainable ({trainable / total:.2%})')"
|
| 202 |
-
]
|
| 203 |
-
},
|
| 204 |
-
{
|
| 205 |
-
"cell_type": "markdown",
|
| 206 |
-
"id": "helpers-hdr",
|
| 207 |
-
"metadata": {},
|
| 208 |
-
"source": [
|
| 209 |
-
"## 6. Environment Wrapper & Batched Evaluation\n",
|
| 210 |
-
"\n",
|
| 211 |
-
"Uses the environment **directly in-process** — no server needed.\n",
|
| 212 |
-
"- **Batched generation**: collects prompts from up to 8 parallel episodes per `model.generate()` call\n",
|
| 213 |
-
"- **Early termination**: kills episodes after 3 consecutive no-ops\n",
|
| 214 |
-
"- **Environment pool**: reuses env instances across eval calls"
|
| 215 |
-
]
|
| 216 |
-
},
|
| 217 |
-
{
|
| 218 |
-
"cell_type": "code",
|
| 219 |
-
"execution_count": null,
|
| 220 |
-
"id": "helpers",
|
| 221 |
-
"metadata": {},
|
| 222 |
-
"outputs": [],
|
| 223 |
-
"source": [
|
| 224 |
-
"import time\n",
|
| 225 |
-
"\n",
|
| 226 |
-
"EVAL_BATCH_SIZE = 8\n",
|
| 227 |
-
"EVAL_MAX_STEPS = 24\n",
|
| 228 |
-
"NOOP_EARLY_STOP = 3\n",
|
| 229 |
-
"\n",
|
| 230 |
-
"\n",
|
| 231 |
-
"class EnvRunner:\n",
|
| 232 |
-
" \"\"\"Thin wrapper over VisualReasoningEnvironment for a clean reset/step API.\"\"\"\n",
|
| 233 |
-
"\n",
|
| 234 |
-
" def __init__(self):\n",
|
| 235 |
-
" self.env = VisualReasoningEnvironment()\n",
|
| 236 |
-
"\n",
|
| 237 |
-
" def reset(self, scenario_id=None, task_name=None):\n",
|
| 238 |
-
" return self.env.reset(scenario_id=scenario_id, task_name=task_name)\n",
|
| 239 |
-
"\n",
|
| 240 |
-
" def step(self, action_dict):\n",
|
| 241 |
-
" act = VisualReasoningAction(**action_dict)\n",
|
| 242 |
-
" obs = self.env.step(act)\n",
|
| 243 |
-
" return obs, float(obs.reward), bool(obs.done)\n",
|
| 244 |
-
"\n",
|
| 245 |
-
"\n",
|
| 246 |
-
"FALLBACK_ACTION = {\n",
|
| 247 |
-
" 'step_type': 'complete', 'narration': 'Explanation complete.',\n",
|
| 248 |
-
" 'ops': [], 'covered_concepts': [], 'intent': 'finalize',\n",
|
| 249 |
-
"}\n",
|
| 250 |
-
"\n",
|
| 251 |
-
"_env_pool = []\n",
|
| 252 |
-
"\n",
|
| 253 |
-
"def _get_env(idx):\n",
|
| 254 |
-
" while len(_env_pool) <= idx:\n",
|
| 255 |
-
" _env_pool.append(VisualReasoningEnvironment())\n",
|
| 256 |
-
" return _env_pool[idx]\n",
|
| 257 |
-
"\n",
|
| 258 |
-
"\n",
|
| 259 |
-
"class EpisodeState:\n",
|
| 260 |
-
" \"\"\"Tracks one episode's state for batched eval.\"\"\"\n",
|
| 261 |
-
" def __init__(self, env, scenario_id):\n",
|
| 262 |
-
" self.env = env\n",
|
| 263 |
-
" self.scenario_id = scenario_id\n",
|
| 264 |
-
" self.obs = env.reset(scenario_id=scenario_id)\n",
|
| 265 |
-
" self.last_action = None\n",
|
| 266 |
-
" self.last_reward = 0.0\n",
|
| 267 |
-
" self.history = []\n",
|
| 268 |
-
" self.steps = 0\n",
|
| 269 |
-
" self.done = False\n",
|
| 270 |
-
" self.score = 0.0\n",
|
| 271 |
-
" self.consecutive_noops = 0\n",
|
| 272 |
-
"\n",
|
| 273 |
-
" def build_prompt_text(self):\n",
|
| 274 |
-
" user_prompt = build_user_prompt(\n",
|
| 275 |
-
" self.obs, self.last_action, self.last_reward, self.history\n",
|
| 276 |
-
" )\n",
|
| 277 |
-
" messages = [\n",
|
| 278 |
-
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
|
| 279 |
-
" {'role': 'user', 'content': user_prompt},\n",
|
| 280 |
-
" ]\n",
|
| 281 |
-
" return tokenizer.apply_chat_template(\n",
|
| 282 |
-
" messages, tokenize=False, add_generation_prompt=True\n",
|
| 283 |
-
" )\n",
|
| 284 |
-
"\n",
|
| 285 |
-
" def apply_action(self, action_dict):\n",
|
| 286 |
-
" if action_dict is None:\n",
|
| 287 |
-
" action_dict = FALLBACK_ACTION\n",
|
| 288 |
-
" self.obs = self.env.step(VisualReasoningAction(**action_dict))\n",
|
| 289 |
-
" reward = float(self.obs.reward)\n",
|
| 290 |
-
" self.last_action = action_dict\n",
|
| 291 |
-
" self.last_reward = reward\n",
|
| 292 |
-
" self.steps += 1\n",
|
| 293 |
-
" self.history.append(f\"Step {self.steps}: {action_dict.get('narration', '')}\")\n",
|
| 294 |
-
" if reward <= -0.04:\n",
|
| 295 |
-
" self.consecutive_noops += 1\n",
|
| 296 |
-
" else:\n",
|
| 297 |
-
" self.consecutive_noops = 0\n",
|
| 298 |
-
" if self.obs.done or self.consecutive_noops >= NOOP_EARLY_STOP:\n",
|
| 299 |
-
" self.done = True\n",
|
| 300 |
-
" self.score = float(self.obs.score_breakdown.get('overall_score', 0.0))\n",
|
| 301 |
-
"\n",
|
| 302 |
-
"\n",
|
| 303 |
-
"def batched_generate(model, tokenizer, prompt_texts, max_new_tokens=384):\n",
|
| 304 |
-
" \"\"\"Run model.generate on a batch of prompt strings.\"\"\"\n",
|
| 305 |
-
" FastLanguageModel.for_inference(model)\n",
|
| 306 |
-
" orig_side = tokenizer.padding_side\n",
|
| 307 |
-
" tokenizer.padding_side = 'left'\n",
|
| 308 |
-
" inputs = tokenizer(\n",
|
| 309 |
-
" prompt_texts,\n",
|
| 310 |
-
" return_tensors='pt',\n",
|
| 311 |
-
" padding=True,\n",
|
| 312 |
-
" truncation=True,\n",
|
| 313 |
-
" max_length=MAX_SEQ_LENGTH,\n",
|
| 314 |
-
" ).to(model.device)\n",
|
| 315 |
-
" tokenizer.padding_side = orig_side\n",
|
| 316 |
-
" with torch.no_grad():\n",
|
| 317 |
-
" outputs = model.generate(\n",
|
| 318 |
-
" **inputs,\n",
|
| 319 |
-
" max_new_tokens=max_new_tokens,\n",
|
| 320 |
-
" do_sample=False,\n",
|
| 321 |
-
" pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,\n",
|
| 322 |
-
" )\n",
|
| 323 |
-
" results = []\n",
|
| 324 |
-
" for i, out in enumerate(outputs):\n",
|
| 325 |
-
" input_len = inputs.input_ids[i].ne(tokenizer.pad_token_id).sum()\n",
|
| 326 |
-
" text = tokenizer.decode(out[input_len:], skip_special_tokens=True)\n",
|
| 327 |
-
" results.append(text)\n",
|
| 328 |
-
" return results\n",
|
| 329 |
-
"\n",
|
| 330 |
-
"\n",
|
| 331 |
-
"def evaluate(model, tokenizer, label, stages=None):\n",
|
| 332 |
-
" \"\"\"Batched evaluation across all scenarios.\"\"\"\n",
|
| 333 |
-
" stages = stages or list(DIFFICULTIES)\n",
|
| 334 |
-
" print(f'\\n{\"=\" * 60}')\n",
|
| 335 |
-
" print(f' {label}')\n",
|
| 336 |
-
" print(f'{\"=\" * 60}')\n",
|
| 337 |
-
"\n",
|
| 338 |
-
" all_scenario_ids = []\n",
|
| 339 |
-
" for diff in stages:\n",
|
| 340 |
-
" all_scenario_ids.extend([(diff, sid) for sid in SCENARIOS.get(diff, [])])\n",
|
| 341 |
-
"\n",
|
| 342 |
-
" results_by_diff = {d: [] for d in stages}\n",
|
| 343 |
-
" t0 = time.time()\n",
|
| 344 |
-
"\n",
|
| 345 |
-
" for batch_start in range(0, len(all_scenario_ids), EVAL_BATCH_SIZE):\n",
|
| 346 |
-
" batch = all_scenario_ids[batch_start:batch_start + EVAL_BATCH_SIZE]\n",
|
| 347 |
-
" t_batch = time.time()\n",
|
| 348 |
-
" episodes = []\n",
|
| 349 |
-
" for i, (diff, sid) in enumerate(batch):\n",
|
| 350 |
-
" episodes.append(EpisodeState(_get_env(i), sid))\n",
|
| 351 |
-
"\n",
|
| 352 |
-
" for step_num in range(1, EVAL_MAX_STEPS + 1):\n",
|
| 353 |
-
" active = [ep for ep in episodes if not ep.done]\n",
|
| 354 |
-
" if not active:\n",
|
| 355 |
-
" break\n",
|
| 356 |
-
" prompts = [ep.build_prompt_text() for ep in active]\n",
|
| 357 |
-
" generated = batched_generate(model, tokenizer, prompts)\n",
|
| 358 |
-
" for ep, text in zip(active, generated):\n",
|
| 359 |
-
" parsed = parse_action(text)\n",
|
| 360 |
-
" action = normalize_action(parsed or {}) if parsed else None\n",
|
| 361 |
-
" ep.apply_action(action)\n",
|
| 362 |
-
"\n",
|
| 363 |
-
" batch_time = time.time() - t_batch\n",
|
| 364 |
-
" for ep, (diff, sid) in zip(episodes, batch):\n",
|
| 365 |
-
" results_by_diff[diff].append(ep.score)\n",
|
| 366 |
-
" print(f' [{diff:6}] {sid:22} score={ep.score:.3f} steps={ep.steps}')\n",
|
| 367 |
-
" sids_str = ', '.join(sid for _, sid in batch)\n",
|
| 368 |
-
" print(f' Batch {batch_start // EVAL_BATCH_SIZE + 1}: '\n",
|
| 369 |
-
" f'{len(batch)} scenarios in {batch_time:.1f}s [{sids_str}]')\n",
|
| 370 |
-
"\n",
|
| 371 |
-
" results = {}\n",
|
| 372 |
-
" for diff in stages:\n",
|
| 373 |
-
" scores = results_by_diff[diff]\n",
|
| 374 |
-
" if scores:\n",
|
| 375 |
-
" mean = sum(scores) / len(scores)\n",
|
| 376 |
-
" results[diff] = mean\n",
|
| 377 |
-
" print(f' [{diff:6}] mean={mean:.3f}')\n",
|
| 378 |
-
" overall = sum(results.values()) / max(len(results), 1)\n",
|
| 379 |
-
" results['overall'] = overall\n",
|
| 380 |
-
" elapsed = time.time() - t0\n",
|
| 381 |
-
" print(f' OVERALL: {overall:.3f} ({elapsed:.1f}s)')\n",
|
| 382 |
-
" return results\n",
|
| 383 |
-
"\n",
|
| 384 |
-
"\n",
|
| 385 |
-
"env = EnvRunner()\n",
|
| 386 |
-
"print('EnvRunner ready.')"
|
| 387 |
-
]
|
| 388 |
-
},
|
| 389 |
-
{
|
| 390 |
-
"cell_type": "markdown",
|
| 391 |
-
"id": "baseline-hdr",
|
| 392 |
-
"metadata": {},
|
| 393 |
-
"source": [
|
| 394 |
-
"## 7. Baseline Evaluation\n",
|
| 395 |
-
"\n",
|
| 396 |
-
"Score the untrained model (random LoRA weights) across all scenarios to establish a baseline."
|
| 397 |
-
]
|
| 398 |
-
},
|
| 399 |
-
{
|
| 400 |
-
"cell_type": "code",
|
| 401 |
-
"execution_count": null,
|
| 402 |
-
"id": "baseline",
|
| 403 |
-
"metadata": {},
|
| 404 |
-
"outputs": [],
|
| 405 |
-
"source": [
|
| 406 |
-
"baseline = evaluate(model, tokenizer, 'Baseline (untrained LoRA)')"
|
| 407 |
-
]
|
| 408 |
-
},
|
| 409 |
-
{
|
| 410 |
-
"cell_type": "markdown",
|
| 411 |
-
"id": "gold-hdr",
|
| 412 |
-
"metadata": {},
|
| 413 |
-
"source": [
|
| 414 |
-
"## 8. Gold Demonstrations\n",
|
| 415 |
-
"\n",
|
| 416 |
-
"Hand-crafted teaching episodes for the three easy scenarios. These teach the model:\n",
|
| 417 |
-
"- The JSON action format (`step_type`, `narration`, `ops`, `covered_concepts`, `intent`)\n",
|
| 418 |
-
"- Incremental setup (Phase 1), algorithm walk-through (Phase 2), wrap-up (Phase 3)\n",
|
| 419 |
-
"- Proper concept evidencing (narration must mention the concept, ops must back it)\n",
|
| 420 |
-
"- Region vs container distinction (regions for layout, containers for push/pop)"
|
| 421 |
-
]
|
| 422 |
-
},
|
| 423 |
-
{
|
| 424 |
-
"cell_type": "code",
|
| 425 |
-
"execution_count": null,
|
| 426 |
-
"id": "gold",
|
| 427 |
-
"metadata": {},
|
| 428 |
-
"outputs": [],
|
| 429 |
-
"source": [
|
| 430 |
-
"GOLD_EPISODES = {\n",
|
| 431 |
-
" # ── easy_1: linked_list_traversal ──────────────────────────────────\n",
|
| 432 |
-
" # input_data: {\"values\": [10, 20, 30]}\n",
|
| 433 |
-
" # concepts: head_pointer, node_value, next_link, tail_marker\n",
|
| 434 |
-
" 'easy_1': [\n",
|
| 435 |
-
" {\n",
|
| 436 |
-
" 'step_type': 'advance',\n",
|
| 437 |
-
" 'narration': 'Building a linked list — creating a centered layout region and the first two nodes with values 10 and 20.',\n",
|
| 438 |
-
" 'ops': [\n",
|
| 439 |
-
" {'op': 'add_region', 'target_ids': ['list'], 'params': {'style': 'array', 'title': 'Linked List', 'position': 'center'}},\n",
|
| 440 |
-
" {'op': 'add_node', 'target_ids': ['n1'], 'params': {'value': 10, 'region': 'list'}},\n",
|
| 441 |
-
" {'op': 'add_node', 'target_ids': ['n2'], 'params': {'value': 20, 'region': 'list'}},\n",
|
| 442 |
-
" ],\n",
|
| 443 |
-
" 'covered_concepts': ['node_value'],\n",
|
| 444 |
-
" 'intent': 'create_list_start',\n",
|
| 445 |
-
" },\n",
|
| 446 |
-
" {\n",
|
| 447 |
-
" 'step_type': 'advance',\n",
|
| 448 |
-
" 'narration': 'Adding node 30 and connecting all nodes with next links to form the chain 10 -> 20 -> 30.',\n",
|
| 449 |
-
" 'ops': [\n",
|
| 450 |
-
" {'op': 'add_node', 'target_ids': ['n3'], 'params': {'value': 30, 'region': 'list'}},\n",
|
| 451 |
-
" {'op': 'add_edge', 'target_ids': ['n1', 'n2'], 'params': {'kind': 'directed', 'label': 'next'}},\n",
|
| 452 |
-
" {'op': 'add_edge', 'target_ids': ['n2', 'n3'], 'params': {'kind': 'directed', 'label': 'next'}},\n",
|
| 453 |
-
" ],\n",
|
| 454 |
-
" 'covered_concepts': ['next_link'],\n",
|
| 455 |
-
" 'intent': 'connect_nodes',\n",
|
| 456 |
-
" },\n",
|
| 457 |
-
" {\n",
|
| 458 |
-
" 'step_type': 'advance',\n",
|
| 459 |
-
" 'narration': 'Placing a head pointer at node 10 because traversal always starts at the head, our only entry point into the list.',\n",
|
| 460 |
-
" 'ops': [\n",
|
| 461 |
-
" {'op': 'add_pointer', 'target_ids': ['head_ptr'], 'params': {'region': 'list'}},\n",
|
| 462 |
-
" {'op': 'move_pointer', 'target_ids': ['head_ptr'], 'params': {'index': 'n1'}},\n",
|
| 463 |
-
" {'op': 'annotate', 'target_ids': ['n1'], 'params': {'text': 'Head'}},\n",
|
| 464 |
-
" {'op': 'set_role', 'target_ids': ['n1'], 'params': {'role': 'current'}},\n",
|
| 465 |
-
" ],\n",
|
| 466 |
-
" 'covered_concepts': ['head_pointer'],\n",
|
| 467 |
-
" 'intent': 'mark_head',\n",
|
| 468 |
-
" },\n",
|
| 469 |
-
" {\n",
|
| 470 |
-
" 'step_type': 'advance',\n",
|
| 471 |
-
" 'narration': 'Following the next link from 10 to 20 — the pointer advances and we mark node 10 as visited.',\n",
|
| 472 |
-
" 'ops': [\n",
|
| 473 |
-
" {'op': 'set_role', 'target_ids': ['n1'], 'params': {'role': 'visited'}},\n",
|
| 474 |
-
" {'op': 'set_role', 'target_ids': ['n2'], 'params': {'role': 'current'}},\n",
|
| 475 |
-
" {'op': 'move_pointer', 'target_ids': ['head_ptr'], 'params': {'index': 'n2'}},\n",
|
| 476 |
-
" ],\n",
|
| 477 |
-
" 'covered_concepts': [],\n",
|
| 478 |
-
" 'intent': 'traverse_to_second',\n",
|
| 479 |
-
" },\n",
|
| 480 |
-
" {\n",
|
| 481 |
-
" 'step_type': 'advance',\n",
|
| 482 |
-
" 'narration': 'Reaching node 30 — it has no next link, making it the tail that signals the end of traversal.',\n",
|
| 483 |
-
" 'ops': [\n",
|
| 484 |
-
" {'op': 'set_role', 'target_ids': ['n2'], 'params': {'role': 'visited'}},\n",
|
| 485 |
-
" {'op': 'set_role', 'target_ids': ['n3'], 'params': {'role': 'current'}},\n",
|
| 486 |
-
" {'op': 'annotate', 'target_ids': ['n3'], 'params': {'text': 'Tail'}},\n",
|
| 487 |
-
" ],\n",
|
| 488 |
-
" 'covered_concepts': ['tail_marker'],\n",
|
| 489 |
-
" 'intent': 'reach_tail',\n",
|
| 490 |
-
" },\n",
|
| 491 |
-
" {\n",
|
| 492 |
-
" 'step_type': 'complete',\n",
|
| 493 |
-
" 'narration': 'Traversal complete — visited every node from head to tail following next links, reading values 10, 20, 30 in order.',\n",
|
| 494 |
-
" 'ops': [\n",
|
| 495 |
-
" {'op': 'set_role', 'target_ids': ['n3'], 'params': {'role': 'done'}},\n",
|
| 496 |
-
" ],\n",
|
| 497 |
-
" 'covered_concepts': [],\n",
|
| 498 |
-
" 'intent': 'summarize',\n",
|
| 499 |
-
" },\n",
|
| 500 |
-
" ],\n",
|
| 501 |
-
"\n",
|
| 502 |
-
" # ── easy_2: stack_ops ──────────────────────────────────────────────\n",
|
| 503 |
-
" # input_data: {\"operations\": [\"push A\", \"push B\", \"pop\", \"push C\"]}\n",
|
| 504 |
-
" # concepts: top_pointer, push, pop, lifo_order\n",
|
| 505 |
-
" 'easy_2': [\n",
|
| 506 |
-
" {\n",
|
| 507 |
-
" 'step_type': 'advance',\n",
|
| 508 |
-
" 'narration': 'Setting up a stack with a centered visual region and a container to track push and pop membership.',\n",
|
| 509 |
-
" 'ops': [\n",
|
| 510 |
-
" {'op': 'add_region', 'target_ids': ['stack_area'], 'params': {'style': 'stack', 'title': 'Stack', 'position': 'center'}},\n",
|
| 511 |
-
" {'op': 'add_container', 'target_ids': ['stk'], 'params': {'region': 'stack_area', 'ordered': False, 'title': 'Stack'}},\n",
|
| 512 |
-
" ],\n",
|
| 513 |
-
" 'covered_concepts': [],\n",
|
| 514 |
-
" 'intent': 'setup_stack',\n",
|
| 515 |
-
" },\n",
|
| 516 |
-
" {\n",
|
| 517 |
-
" 'step_type': 'advance',\n",
|
| 518 |
-
" 'narration': 'Pushing A onto the stack — A becomes the first element. Adding a top pointer to track the stack top.',\n",
|
| 519 |
-
" 'ops': [\n",
|
| 520 |
-
" {'op': 'add_node', 'target_ids': ['a'], 'params': {'value': 'A', 'region': 'stack_area'}},\n",
|
| 521 |
-
" {'op': 'push_to', 'target_ids': ['stk', 'a'], 'params': {}},\n",
|
| 522 |
-
" {'op': 'add_pointer', 'target_ids': ['top'], 'params': {'region': 'stack_area'}},\n",
|
| 523 |
-
" {'op': 'move_pointer', 'target_ids': ['top'], 'params': {'index': 'a'}},\n",
|
| 524 |
-
" ],\n",
|
| 525 |
-
" 'covered_concepts': ['push', 'top_pointer'],\n",
|
| 526 |
-
" 'intent': 'push_a',\n",
|
| 527 |
-
" },\n",
|
| 528 |
-
" {\n",
|
| 529 |
-
" 'step_type': 'advance',\n",
|
| 530 |
-
" 'narration': 'Pushing B — B sits on top of A and the top pointer moves up to B.',\n",
|
| 531 |
-
" 'ops': [\n",
|
| 532 |
-
" {'op': 'add_node', 'target_ids': ['b'], 'params': {'value': 'B', 'region': 'stack_area'}},\n",
|
| 533 |
-
" {'op': 'push_to', 'target_ids': ['stk', 'b'], 'params': {}},\n",
|
| 534 |
-
" {'op': 'move_pointer', 'target_ids': ['top'], 'params': {'index': 'b'}},\n",
|
| 535 |
-
" ],\n",
|
| 536 |
-
" 'covered_concepts': [],\n",
|
| 537 |
-
" 'intent': 'push_b',\n",
|
| 538 |
-
" },\n",
|
| 539 |
-
" {\n",
|
| 540 |
-
" 'step_type': 'advance',\n",
|
| 541 |
-
" 'narration': 'Popping from the stack — B was pushed last so B comes off first, demonstrating LIFO (last-in-first-out) order.',\n",
|
| 542 |
-
" 'ops': [\n",
|
| 543 |
-
" {'op': 'pop_from', 'target_ids': ['stk'], 'params': {}},\n",
|
| 544 |
-
" {'op': 'set_role', 'target_ids': ['b'], 'params': {'role': 'inactive'}},\n",
|
| 545 |
-
" {'op': 'move_pointer', 'target_ids': ['top'], 'params': {'index': 'a'}},\n",
|
| 546 |
-
" ],\n",
|
| 547 |
-
" 'covered_concepts': ['pop', 'lifo_order'],\n",
|
| 548 |
-
" 'intent': 'pop_b',\n",
|
| 549 |
-
" },\n",
|
| 550 |
-
" {\n",
|
| 551 |
-
" 'step_type': 'advance',\n",
|
| 552 |
-
" 'narration': 'Pushing C onto the stack — C now sits on top of A, with B already removed.',\n",
|
| 553 |
-
" 'ops': [\n",
|
| 554 |
-
" {'op': 'add_node', 'target_ids': ['c'], 'params': {'value': 'C', 'region': 'stack_area'}},\n",
|
| 555 |
-
" {'op': 'push_to', 'target_ids': ['stk', 'c'], 'params': {}},\n",
|
| 556 |
-
" {'op': 'move_pointer', 'target_ids': ['top'], 'params': {'index': 'c'}},\n",
|
| 557 |
-
" ],\n",
|
| 558 |
-
" 'covered_concepts': [],\n",
|
| 559 |
-
" 'intent': 'push_c',\n",
|
| 560 |
-
" },\n",
|
| 561 |
-
" {\n",
|
| 562 |
-
" 'step_type': 'complete',\n",
|
| 563 |
-
" 'narration': 'All four operations executed — stack holds A at bottom and C on top after push A, push B, pop, push C.',\n",
|
| 564 |
-
" 'ops': [],\n",
|
| 565 |
-
" 'covered_concepts': [],\n",
|
| 566 |
-
" 'intent': 'summarize',\n",
|
| 567 |
-
" },\n",
|
| 568 |
-
" ],\n",
|
| 569 |
-
"\n",
|
| 570 |
-
" # ── easy_3: binary_search ──────────────────────────────────────────\n",
|
| 571 |
-
" # input_data: {\"array\": [1, 3, 5, 7, 9, 11, 13], \"target\": 7}\n",
|
| 572 |
-
" # concepts: sorted_invariant, low_pointer, high_pointer, mid_pointer, comparison\n",
|
| 573 |
-
" 'easy_3': [\n",
|
| 574 |
-
" {\n",
|
| 575 |
-
" 'step_type': 'advance',\n",
|
| 576 |
-
" 'narration': 'Creating the first four elements of a sorted array in a centered region — the sorted invariant is what makes binary search possible.',\n",
|
| 577 |
-
" 'ops': [\n",
|
| 578 |
-
" {'op': 'add_region', 'target_ids': ['arr'], 'params': {'style': 'array', 'title': 'Sorted Array', 'position': 'center'}},\n",
|
| 579 |
-
" {'op': 'add_node', 'target_ids': ['n0'], 'params': {'value': 1, 'region': 'arr'}},\n",
|
| 580 |
-
" {'op': 'add_node', 'target_ids': ['n1'], 'params': {'value': 3, 'region': 'arr'}},\n",
|
| 581 |
-
" {'op': 'add_node', 'target_ids': ['n2'], 'params': {'value': 5, 'region': 'arr'}},\n",
|
| 582 |
-
" ],\n",
|
| 583 |
-
" 'covered_concepts': ['sorted_invariant'],\n",
|
| 584 |
-
" 'intent': 'create_array_part1',\n",
|
| 585 |
-
" },\n",
|
| 586 |
-
" {\n",
|
| 587 |
-
" 'step_type': 'advance',\n",
|
| 588 |
-
" 'narration': 'Adding the remaining elements 7, 9, 11, 13 to complete all seven values of the sorted array.',\n",
|
| 589 |
-
" 'ops': [\n",
|
| 590 |
-
" {'op': 'add_node', 'target_ids': ['n3'], 'params': {'value': 7, 'region': 'arr'}},\n",
|
| 591 |
-
" {'op': 'add_node', 'target_ids': ['n4'], 'params': {'value': 9, 'region': 'arr'}},\n",
|
| 592 |
-
" {'op': 'add_node', 'target_ids': ['n5'], 'params': {'value': 11, 'region': 'arr'}},\n",
|
| 593 |
-
" {'op': 'add_node', 'target_ids': ['n6'], 'params': {'value': 13, 'region': 'arr'}},\n",
|
| 594 |
-
" ],\n",
|
| 595 |
-
" 'covered_concepts': [],\n",
|
| 596 |
-
" 'intent': 'create_array_part2',\n",
|
| 597 |
-
" },\n",
|
| 598 |
-
" {\n",
|
| 599 |
-
" 'step_type': 'advance',\n",
|
| 600 |
-
" 'narration': 'Placing low pointer at index 0 (value 1) and high pointer at index 6 (value 13) to bracket the search range.',\n",
|
| 601 |
-
" 'ops': [\n",
|
| 602 |
-
" {'op': 'add_pointer', 'target_ids': ['low'], 'params': {'region': 'arr'}},\n",
|
| 603 |
-
" {'op': 'move_pointer', 'target_ids': ['low'], 'params': {'index': 'n0'}},\n",
|
| 604 |
-
" {'op': 'add_pointer', 'target_ids': ['high'], 'params': {'region': 'arr'}},\n",
|
| 605 |
-
" {'op': 'move_pointer', 'target_ids': ['high'], 'params': {'index': 'n6'}},\n",
|
| 606 |
-
" ],\n",
|
| 607 |
-
" 'covered_concepts': ['low_pointer', 'high_pointer'],\n",
|
| 608 |
-
" 'intent': 'init_pointers',\n",
|
| 609 |
-
" },\n",
|
| 610 |
-
" {\n",
|
| 611 |
-
" 'step_type': 'advance',\n",
|
| 612 |
-
" 'narration': 'Computing mid = (0+6)/2 = 3 — the mid pointer lands on value 7, which we compare against our target 7.',\n",
|
| 613 |
-
" 'ops': [\n",
|
| 614 |
-
" {'op': 'add_pointer', 'target_ids': ['mid'], 'params': {'region': 'arr'}},\n",
|
| 615 |
-
" {'op': 'move_pointer', 'target_ids': ['mid'], 'params': {'index': 'n3'}},\n",
|
| 616 |
-
" {'op': 'set_role', 'target_ids': ['n3'], 'params': {'role': 'current'}},\n",
|
| 617 |
-
" {'op': 'highlight', 'target_ids': ['n3'], 'params': {}},\n",
|
| 618 |
-
" ],\n",
|
| 619 |
-
" 'covered_concepts': ['mid_pointer', 'comparison'],\n",
|
| 620 |
-
" 'intent': 'compute_mid_and_compare',\n",
|
| 621 |
-
" },\n",
|
| 622 |
-
" {\n",
|
| 623 |
-
" 'step_type': 'complete',\n",
|
| 624 |
-
" 'narration': 'Target 7 found at index 3 — binary search located it in one comparison because the sorted invariant halves the search space each step.',\n",
|
| 625 |
-
" 'ops': [\n",
|
| 626 |
-
" {'op': 'set_role', 'target_ids': ['n3'], 'params': {'role': 'done'}},\n",
|
| 627 |
-
" {'op': 'annotate', 'target_ids': ['n3'], 'params': {'text': 'Found: 7'}},\n",
|
| 628 |
-
" ],\n",
|
| 629 |
-
" 'covered_concepts': [],\n",
|
| 630 |
-
" 'intent': 'found_target',\n",
|
| 631 |
-
" },\n",
|
| 632 |
-
" ],\n",
|
| 633 |
-
"}\n",
|
| 634 |
-
"\n",
|
| 635 |
-
"print(f'Gold episodes: {len(GOLD_EPISODES)} scenarios, {sum(len(v) for v in GOLD_EPISODES.values())} total steps')"
|
| 636 |
-
]
|
| 637 |
-
},
|
| 638 |
-
{
|
| 639 |
-
"cell_type": "markdown",
|
| 640 |
-
"id": "sft-hdr",
|
| 641 |
-
"metadata": {},
|
| 642 |
-
"source": [
|
| 643 |
-
"## 9. SFT Warmup\n",
|
| 644 |
-
"\n",
|
| 645 |
-
"Replay gold demonstrations through the live environment to collect accurate\n",
|
| 646 |
-
"(observation, action) pairs at each step, then train the model to imitate them."
|
| 647 |
-
]
|
| 648 |
-
},
|
| 649 |
-
{
|
| 650 |
-
"cell_type": "code",
|
| 651 |
-
"execution_count": null,
|
| 652 |
-
"id": "sft",
|
| 653 |
-
"metadata": {},
|
| 654 |
-
"outputs": [],
|
| 655 |
-
"source": [
|
| 656 |
-
"from datasets import Dataset\n",
|
| 657 |
-
"from trl import SFTConfig, SFTTrainer\n",
|
| 658 |
-
"\n",
|
| 659 |
-
"def generate_sft_data(env, gold_episodes, tokenizer):\n",
|
| 660 |
-
" \"\"\"Replay gold episodes through the env, collecting chat-formatted training data.\"\"\"\n",
|
| 661 |
-
" rows = []\n",
|
| 662 |
-
" for scenario_id, actions in gold_episodes.items():\n",
|
| 663 |
-
" obs = env.reset(scenario_id=scenario_id)\n",
|
| 664 |
-
" last_action, last_reward, history = None, 0.0, []\n",
|
| 665 |
-
" for i, action in enumerate(actions):\n",
|
| 666 |
-
" user_prompt = build_user_prompt(obs, last_action, last_reward, history)\n",
|
| 667 |
-
" messages = [\n",
|
| 668 |
-
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
|
| 669 |
-
" {'role': 'user', 'content': user_prompt},\n",
|
| 670 |
-
" {'role': 'assistant', 'content': json.dumps(action, separators=(',', ':'))},\n",
|
| 671 |
-
" ]\n",
|
| 672 |
-
" text = tokenizer.apply_chat_template(\n",
|
| 673 |
-
" messages, tokenize=False, add_generation_prompt=False\n",
|
| 674 |
-
" )\n",
|
| 675 |
-
" rows.append({'text': text})\n",
|
| 676 |
-
" obs, reward, done = env.step(action)\n",
|
| 677 |
-
" last_action, last_reward = action, reward\n",
|
| 678 |
-
" history.append(f'Step {i + 1}: {action.get(\"narration\", \"\")}')\n",
|
| 679 |
-
" if done:\n",
|
| 680 |
-
" break\n",
|
| 681 |
-
" return Dataset.from_list(rows)\n",
|
| 682 |
-
"\n",
|
| 683 |
-
"\n",
|
| 684 |
-
"sft_data = generate_sft_data(env, GOLD_EPISODES, tokenizer)\n",
|
| 685 |
-
"print(f'SFT training examples: {len(sft_data)}')\n",
|
| 686 |
-
"\n",
|
| 687 |
-
"FastLanguageModel.for_training(model)\n",
|
| 688 |
-
"\n",
|
| 689 |
-
"sft_config = SFTConfig(\n",
|
| 690 |
-
" output_dir='/tmp/vr_sft_scratch',\n",
|
| 691 |
-
" num_train_epochs=SFT_EPOCHS,\n",
|
| 692 |
-
" per_device_train_batch_size=2,\n",
|
| 693 |
-
" gradient_accumulation_steps=4,\n",
|
| 694 |
-
" learning_rate=2e-4,\n",
|
| 695 |
-
" lr_scheduler_type='cosine',\n",
|
| 696 |
-
" warmup_ratio=0.03,\n",
|
| 697 |
-
" logging_steps=5,\n",
|
| 698 |
-
" save_strategy='no',\n",
|
| 699 |
-
" fp16=True,\n",
|
| 700 |
-
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
| 701 |
-
" dataset_text_field='text',\n",
|
| 702 |
-
" optim='adamw_8bit',\n",
|
| 703 |
-
" report_to='none',\n",
|
| 704 |
-
")\n",
|
| 705 |
-
"\n",
|
| 706 |
-
"sft_trainer = SFTTrainer(\n",
|
| 707 |
-
" model=model,\n",
|
| 708 |
-
" processing_class=tokenizer,\n",
|
| 709 |
-
" args=sft_config,\n",
|
| 710 |
-
" train_dataset=sft_data,\n",
|
| 711 |
-
")\n",
|
| 712 |
-
"\n",
|
| 713 |
-
"print('\\nTraining SFT...')\n",
|
| 714 |
-
"sft_trainer.train()\n",
|
| 715 |
-
"print('SFT complete!')"
|
| 716 |
-
]
|
| 717 |
-
},
|
| 718 |
-
{
|
| 719 |
-
"cell_type": "markdown",
|
| 720 |
-
"id": "post-sft-hdr",
|
| 721 |
-
"metadata": {},
|
| 722 |
-
"source": [
|
| 723 |
-
"## 10. Post-SFT Evaluation"
|
| 724 |
-
]
|
| 725 |
-
},
|
| 726 |
-
{
|
| 727 |
-
"cell_type": "code",
|
| 728 |
-
"execution_count": null,
|
| 729 |
-
"id": "post-sft-eval",
|
| 730 |
-
"metadata": {},
|
| 731 |
-
"outputs": [],
|
| 732 |
-
"source": [
|
| 733 |
-
"sft_results = evaluate(model, tokenizer, 'Post-SFT')"
|
| 734 |
-
]
|
| 735 |
-
},
|
| 736 |
-
{
|
| 737 |
-
"cell_type": "markdown",
|
| 738 |
-
"id": "grpo-hdr",
|
| 739 |
-
"metadata": {},
|
| 740 |
-
"source": [
|
| 741 |
-
"## 11. GRPO Training\n",
|
| 742 |
-
"\n",
|
| 743 |
-
"Generate prompts from all scenarios (initial observation states) and train with\n",
|
| 744 |
-
"environment reward signal. Curriculum ordering: easy prompts first, expert last."
|
| 745 |
-
]
|
| 746 |
-
},
|
| 747 |
-
{
|
| 748 |
-
"cell_type": "code",
|
| 749 |
-
"execution_count": null,
|
| 750 |
-
"id": "grpo",
|
| 751 |
-
"metadata": {},
|
| 752 |
-
"outputs": [],
|
| 753 |
-
"source": [
|
| 754 |
-
"from torch.utils.data import SequentialSampler\n",
|
| 755 |
-
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 756 |
-
"\n",
|
| 757 |
-
"def generate_grpo_prompts(env, stages, samples_per_scenario=2):\n",
|
| 758 |
-
" \"\"\"Collect initial-state prompts for GRPO training.\"\"\"\n",
|
| 759 |
-
" rows = []\n",
|
| 760 |
-
" for stage in stages:\n",
|
| 761 |
-
" for sid in SCENARIOS[stage]:\n",
|
| 762 |
-
" for _ in range(samples_per_scenario):\n",
|
| 763 |
-
" obs = env.reset(scenario_id=sid)\n",
|
| 764 |
-
" user_prompt = build_user_prompt(obs, None, 0.0, [])\n",
|
| 765 |
-
" messages = [\n",
|
| 766 |
-
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
|
| 767 |
-
" {'role': 'user', 'content': user_prompt},\n",
|
| 768 |
-
" ]\n",
|
| 769 |
-
" rows.append({'prompt': messages, 'scenario_id': sid})\n",
|
| 770 |
-
" return Dataset.from_list(rows)\n",
|
| 771 |
-
"\n",
|
| 772 |
-
"\n",
|
| 773 |
-
"def make_reward_fn(env):\n",
|
| 774 |
-
" \"\"\"Reward function: parse model completion, step in env, return overall_score.\"\"\"\n",
|
| 775 |
-
" state = {'calls': 0, 'hist': Counter()}\n",
|
| 776 |
-
"\n",
|
| 777 |
-
" def reward_fn(completions, scenario_id=None, **_):\n",
|
| 778 |
-
" texts = []\n",
|
| 779 |
-
" for c in completions:\n",
|
| 780 |
-
" if isinstance(c, list):\n",
|
| 781 |
-
" texts.append(c[-1].get('content', '') if c else '')\n",
|
| 782 |
-
" else:\n",
|
| 783 |
-
" texts.append(str(c))\n",
|
| 784 |
-
"\n",
|
| 785 |
-
" sids = scenario_id if isinstance(scenario_id, list) else [scenario_id] * len(texts)\n",
|
| 786 |
-
" if len(sids) < len(texts):\n",
|
| 787 |
-
" n_gen = len(texts) // len(sids)\n",
|
| 788 |
-
" sids = [s for s in sids for _ in range(n_gen)]\n",
|
| 789 |
-
"\n",
|
| 790 |
-
" rewards = []\n",
|
| 791 |
-
" for sid, text in zip(sids, texts):\n",
|
| 792 |
-
" obs = env.reset(scenario_id=sid)\n",
|
| 793 |
-
" action = normalize_action(parse_action(text) or {})\n",
|
| 794 |
-
" if action is None:\n",
|
| 795 |
-
" rewards.append(0.0)\n",
|
| 796 |
-
" state['hist']['<unparseable>'] += 1\n",
|
| 797 |
-
" continue\n",
|
| 798 |
-
" obs, _, _ = env.step(action)\n",
|
| 799 |
-
" score = float(obs.score_breakdown.get('overall_score', 0.0))\n",
|
| 800 |
-
" rewards.append(score)\n",
|
| 801 |
-
" state['hist'][action.get('step_type', '?')] += 1\n",
|
| 802 |
-
"\n",
|
| 803 |
-
" state['calls'] += 1\n",
|
| 804 |
-
" if state['calls'] % 5 == 0:\n",
|
| 805 |
-
" print(f\" [reward] call={state['calls']} types={dict(state['hist'])}\")\n",
|
| 806 |
-
" return rewards\n",
|
| 807 |
-
"\n",
|
| 808 |
-
" return reward_fn\n",
|
| 809 |
-
"\n",
|
| 810 |
-
"\n",
|
| 811 |
-
"grpo_data = generate_grpo_prompts(env, list(DIFFICULTIES))\n",
|
| 812 |
-
"print(f'GRPO training prompts: {len(grpo_data)}')\n",
|
| 813 |
-
"\n",
|
| 814 |
-
"FastLanguageModel.for_training(model)\n",
|
| 815 |
-
"\n",
|
| 816 |
-
"for name, param in model.named_parameters():\n",
|
| 817 |
-
" if 'lora_' in name and param.dtype == torch.float32:\n",
|
| 818 |
-
" param.data = param.data.to(torch.bfloat16)\n",
|
| 819 |
-
"\n",
|
| 820 |
-
"grpo_config = GRPOConfig(\n",
|
| 821 |
-
" output_dir='/tmp/vr_grpo_scratch',\n",
|
| 822 |
-
" num_train_epochs=GRPO_EPOCHS,\n",
|
| 823 |
-
" per_device_train_batch_size=2,\n",
|
| 824 |
-
" gradient_accumulation_steps=4,\n",
|
| 825 |
-
" num_generations=4,\n",
|
| 826 |
-
" max_completion_length=384,\n",
|
| 827 |
-
" learning_rate=1e-5,\n",
|
| 828 |
-
" lr_scheduler_type='cosine',\n",
|
| 829 |
-
" warmup_ratio=0.1,\n",
|
| 830 |
-
" beta=0.05,\n",
|
| 831 |
-
" max_grad_norm=0.5,\n",
|
| 832 |
-
" temperature=0.9,\n",
|
| 833 |
-
" logging_steps=1,\n",
|
| 834 |
-
" save_strategy='no',\n",
|
| 835 |
-
" fp16=False,\n",
|
| 836 |
-
" bf16=True,\n",
|
| 837 |
-
" optim='adamw_8bit',\n",
|
| 838 |
-
" report_to='none',\n",
|
| 839 |
-
" remove_unused_columns=False,\n",
|
| 840 |
-
")\n",
|
| 841 |
-
"\n",
|
| 842 |
-
"\n",
|
| 843 |
-
"class CurriculumGRPOTrainer(GRPOTrainer):\n",
|
| 844 |
-
" \"\"\"Preserve easy -> expert ordering by disabling dataset shuffle.\"\"\"\n",
|
| 845 |
-
" def _get_train_sampler(self, *_args, **_kwargs):\n",
|
| 846 |
-
" return SequentialSampler(self.train_dataset)\n",
|
| 847 |
-
"\n",
|
| 848 |
-
"\n",
|
| 849 |
-
"grpo_trainer = CurriculumGRPOTrainer(\n",
|
| 850 |
-
" model=model,\n",
|
| 851 |
-
" tokenizer=tokenizer,\n",
|
| 852 |
-
" args=grpo_config,\n",
|
| 853 |
-
" train_dataset=grpo_data,\n",
|
| 854 |
-
" reward_funcs=make_reward_fn(env),\n",
|
| 855 |
-
")\n",
|
| 856 |
-
"\n",
|
| 857 |
-
"print('Training GRPO...')\n",
|
| 858 |
-
"grpo_trainer.train()\n",
|
| 859 |
-
"print('GRPO complete!')"
|
| 860 |
-
]
|
| 861 |
-
},
|
| 862 |
-
{
|
| 863 |
-
"cell_type": "markdown",
|
| 864 |
-
"id": "final-hdr",
|
| 865 |
-
"metadata": {},
|
| 866 |
-
"source": [
|
| 867 |
-
"## 12. Final Evaluation + Delta Report"
|
| 868 |
-
]
|
| 869 |
-
},
|
| 870 |
-
{
|
| 871 |
-
"cell_type": "code",
|
| 872 |
-
"execution_count": null,
|
| 873 |
-
"id": "final",
|
| 874 |
-
"metadata": {},
|
| 875 |
-
"outputs": [],
|
| 876 |
-
"source": [
|
| 877 |
-
"final_results = evaluate(model, tokenizer, 'Final (SFT + GRPO)')\n",
|
| 878 |
-
"\n",
|
| 879 |
-
"print(f'\\n{\"=\" * 60}')\n",
|
| 880 |
-
"print(' DELTA REPORT')\n",
|
| 881 |
-
"print(f'{\"=\" * 60}')\n",
|
| 882 |
-
"print(f' {\"Difficulty\":<12} {\"Baseline\":>10} {\"SFT\":>10} {\"SFT+GRPO\":>10}')\n",
|
| 883 |
-
"print(f' {\"-\" * 12} {\"-\" * 10} {\"-\" * 10} {\"-\" * 10}')\n",
|
| 884 |
-
"for diff in list(DIFFICULTIES) + ['overall']:\n",
|
| 885 |
-
" b = baseline.get(diff, 0.0)\n",
|
| 886 |
-
" s = sft_results.get(diff, 0.0)\n",
|
| 887 |
-
" f = final_results.get(diff, 0.0)\n",
|
| 888 |
-
" label = diff.upper() if diff == 'overall' else diff\n",
|
| 889 |
-
" print(f' {label:<12} {b:>10.3f} {s:>10.3f} {f:>10.3f}')\n",
|
| 890 |
-
"print(f'{\"=\" * 60}')\n",
|
| 891 |
-
"print('\\nDone. Model was NOT saved (in-memory only).')"
|
| 892 |
-
]
|
| 893 |
-
}
|
| 894 |
-
],
|
| 895 |
-
"metadata": {
|
| 896 |
-
"accelerator": "GPU",
|
| 897 |
-
"colab": {
|
| 898 |
-
"gpuType": "T4",
|
| 899 |
-
"provenance": []
|
| 900 |
-
},
|
| 901 |
-
"kernelspec": {
|
| 902 |
-
"display_name": "Python 3",
|
| 903 |
-
"language": "python",
|
| 904 |
-
"name": "python3"
|
| 905 |
-
},
|
| 906 |
-
"language_info": {
|
| 907 |
-
"name": "python",
|
| 908 |
-
"version": "3.10.0"
|
| 909 |
-
}
|
| 910 |
-
},
|
| 911 |
-
"nbformat": 4,
|
| 912 |
-
"nbformat_minor": 5
|
| 913 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.py
DELETED
|
@@ -1,632 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Visual Reasoning — Training Script
|
| 3 |
-
Simulation-based RL for teaching CS algorithms on a whiteboard.
|
| 4 |
-
|
| 5 |
-
Stages:
|
| 6 |
-
1. Baseline — score untrained model across all difficulties
|
| 7 |
-
2. SFT — imitate gold demonstrations to learn action format
|
| 8 |
-
3. GRPO — RL with dense environment rewards, easy → expert curriculum
|
| 9 |
-
4. Final eval — delta report comparing all three checkpoints
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
import subprocess
|
| 13 |
-
import sys
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def install_packages():
|
| 17 |
-
# Upgrade torchvision first to match whatever torch version unsloth pulls in
|
| 18 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "--upgrade", "torchvision>=0.25.0"])
|
| 19 |
-
packages = [
|
| 20 |
-
"huggingface_hub",
|
| 21 |
-
"unsloth",
|
| 22 |
-
"trl",
|
| 23 |
-
"datasets",
|
| 24 |
-
"transformers",
|
| 25 |
-
"accelerate",
|
| 26 |
-
"bitsandbytes",
|
| 27 |
-
"peft",
|
| 28 |
-
"torch",
|
| 29 |
-
"openenv-core",
|
| 30 |
-
"fastapi",
|
| 31 |
-
"uvicorn",
|
| 32 |
-
"pydantic",
|
| 33 |
-
"python-dotenv",
|
| 34 |
-
"networkx",
|
| 35 |
-
"shapely",
|
| 36 |
-
"sentence-transformers",
|
| 37 |
-
"rapidfuzz",
|
| 38 |
-
"textstat",
|
| 39 |
-
"sortedcontainers",
|
| 40 |
-
"numpy<2.0",
|
| 41 |
-
]
|
| 42 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + packages)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
install_packages()
|
| 46 |
-
|
| 47 |
-
# ── Section 2: Download Visual Reasoning Environment ─────────────────────────
|
| 48 |
-
|
| 49 |
-
import os
|
| 50 |
-
from huggingface_hub import snapshot_download
|
| 51 |
-
|
| 52 |
-
if not os.path.isdir("visual_reasoning"):
|
| 53 |
-
snapshot_download(
|
| 54 |
-
repo_id="sreeramajay/visual_reasoning-env",
|
| 55 |
-
repo_type="space",
|
| 56 |
-
local_dir="visual_reasoning",
|
| 57 |
-
ignore_patterns=["*.gitattributes", ".gitignore", "README.md"],
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
os.chdir("visual_reasoning")
|
| 61 |
-
|
| 62 |
-
# ── Section 3: Imports ────────────────────────────────────────────────────────
|
| 63 |
-
|
| 64 |
-
import json
|
| 65 |
-
import torch
|
| 66 |
-
from collections import Counter
|
| 67 |
-
|
| 68 |
-
sys.path.insert(0, ".")
|
| 69 |
-
|
| 70 |
-
from unsloth import FastLanguageModel
|
| 71 |
-
from datasets import Dataset
|
| 72 |
-
from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer
|
| 73 |
-
from torch.utils.data import SequentialSampler
|
| 74 |
-
|
| 75 |
-
from models import VisualReasoningAction
|
| 76 |
-
from server.visual_reasoning_environment import VisualReasoningEnvironment
|
| 77 |
-
from inference import SYSTEM_PROMPT, build_user_prompt, parse_action, normalize_action
|
| 78 |
-
|
| 79 |
-
# ── Section 4: Smoke Test ─────────────────────────────────────────────────────
|
| 80 |
-
|
| 81 |
-
test_env = VisualReasoningEnvironment()
|
| 82 |
-
obs = test_env.reset(scenario_id="easy_1")
|
| 83 |
-
print(f"Scenario: {obs.scenario_id}")
|
| 84 |
-
print(f"Goal: {obs.goal}")
|
| 85 |
-
print(f"Concepts: {obs.concept_checklist}")
|
| 86 |
-
print(f"Step budget: {obs.remaining_step_budget}")
|
| 87 |
-
|
| 88 |
-
obs = test_env.step(
|
| 89 |
-
VisualReasoningAction(
|
| 90 |
-
step_type="advance",
|
| 91 |
-
narration="Adding the first node with value 10.",
|
| 92 |
-
ops=[{"op": "add_node", "target_ids": ["n0"], "params": {"value": 10}}],
|
| 93 |
-
covered_concepts=["node_value"],
|
| 94 |
-
intent="test",
|
| 95 |
-
)
|
| 96 |
-
)
|
| 97 |
-
print(f"\nAfter step: entities={list(obs.entities.keys())}, reward={obs.reward:.3f}, error={obs.action_error}")
|
| 98 |
-
print("Environment OK!")
|
| 99 |
-
del test_env
|
| 100 |
-
|
| 101 |
-
# ── Section 5: Configuration ──────────────────────────────────────────────────
|
| 102 |
-
|
| 103 |
-
MODEL_NAME = "unsloth/Qwen2.5-3B-Instruct-bnb-4bit"
|
| 104 |
-
MAX_SEQ_LENGTH = 4096
|
| 105 |
-
LORA_R = 16
|
| 106 |
-
LORA_ALPHA = 32
|
| 107 |
-
SFT_EPOCHS = 3
|
| 108 |
-
GRPO_EPOCHS = 2
|
| 109 |
-
|
| 110 |
-
SCENARIOS = {
|
| 111 |
-
"easy": ["easy_1", "easy_2", "easy_3", "gen_easy_1001", "gen_easy_1002"],
|
| 112 |
-
"medium": ["medium_1", "medium_2", "gen_medium_2001", "gen_medium_2002"],
|
| 113 |
-
"hard": ["hard_1", "hard_2", "gen_hard_3001", "gen_hard_3002"],
|
| 114 |
-
"expert": ["expert_1", "expert_2", "gen_expert_4001", "gen_expert_4002"],
|
| 115 |
-
}
|
| 116 |
-
DIFFICULTIES = ("easy", "medium", "hard", "expert")
|
| 117 |
-
|
| 118 |
-
# ── Section 6: Load Model (Unsloth + LoRA) ────────────────────────────────────
|
| 119 |
-
|
| 120 |
-
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 121 |
-
model_name=MODEL_NAME,
|
| 122 |
-
max_seq_length=MAX_SEQ_LENGTH,
|
| 123 |
-
dtype=None,
|
| 124 |
-
load_in_4bit=True,
|
| 125 |
-
)
|
| 126 |
-
model = FastLanguageModel.get_peft_model(
|
| 127 |
-
model,
|
| 128 |
-
r=LORA_R,
|
| 129 |
-
lora_alpha=LORA_ALPHA,
|
| 130 |
-
lora_dropout=0,
|
| 131 |
-
target_modules=[
|
| 132 |
-
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 133 |
-
"gate_proj", "up_proj", "down_proj",
|
| 134 |
-
],
|
| 135 |
-
bias="none",
|
| 136 |
-
use_gradient_checkpointing="unsloth",
|
| 137 |
-
random_state=0,
|
| 138 |
-
)
|
| 139 |
-
if tokenizer.pad_token_id is None:
|
| 140 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 141 |
-
|
| 142 |
-
model.generation_config.max_length = None
|
| 143 |
-
|
| 144 |
-
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 145 |
-
total = sum(p.numel() for p in model.parameters())
|
| 146 |
-
print(f"Parameters: {total / 1e6:.1f}M total, {trainable / 1e6:.2f}M trainable ({trainable / total:.2%})")
|
| 147 |
-
|
| 148 |
-
# ── Section 7: Environment Wrapper & Helpers ──────────────────────────────────
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
class EnvRunner:
|
| 152 |
-
"""Thin wrapper over VisualReasoningEnvironment for a clean reset/step API."""
|
| 153 |
-
|
| 154 |
-
def __init__(self):
|
| 155 |
-
self.env = VisualReasoningEnvironment()
|
| 156 |
-
|
| 157 |
-
def reset(self, scenario_id=None, task_name=None):
|
| 158 |
-
return self.env.reset(scenario_id=scenario_id, task_name=task_name)
|
| 159 |
-
|
| 160 |
-
def step(self, action_dict):
|
| 161 |
-
act = VisualReasoningAction(**action_dict)
|
| 162 |
-
obs = self.env.step(act)
|
| 163 |
-
return obs, float(obs.reward), bool(obs.done)
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def generate_action(model, tokenizer, obs, last_action=None, last_reward=0.0, history=None):
|
| 167 |
-
"""Generate one action from the model given an observation."""
|
| 168 |
-
FastLanguageModel.for_inference(model)
|
| 169 |
-
user_prompt = build_user_prompt(obs, last_action, last_reward, history or [])
|
| 170 |
-
messages = [
|
| 171 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 172 |
-
{"role": "user", "content": user_prompt},
|
| 173 |
-
]
|
| 174 |
-
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 175 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 176 |
-
with torch.no_grad():
|
| 177 |
-
out = model.generate(
|
| 178 |
-
**inputs,
|
| 179 |
-
max_new_tokens=384,
|
| 180 |
-
do_sample=False,
|
| 181 |
-
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 182 |
-
)
|
| 183 |
-
text = tokenizer.decode(out[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
| 184 |
-
return normalize_action(parse_action(text) or {}), text
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
FALLBACK_ACTION = {
|
| 188 |
-
"step_type": "complete",
|
| 189 |
-
"narration": "Explanation complete.",
|
| 190 |
-
"ops": [],
|
| 191 |
-
"covered_concepts": [],
|
| 192 |
-
"intent": "finalize",
|
| 193 |
-
}
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
def run_episode(env, scenario_id, model, tokenizer, max_steps=24):
|
| 197 |
-
"""Run a full episode. Returns (final_score, steps_taken)."""
|
| 198 |
-
obs = env.reset(scenario_id=scenario_id)
|
| 199 |
-
last_action, last_reward, history = None, 0.0, []
|
| 200 |
-
steps = 0
|
| 201 |
-
for step_num in range(1, max_steps + 1):
|
| 202 |
-
action, _ = generate_action(model, tokenizer, obs, last_action, last_reward, history)
|
| 203 |
-
if action is None:
|
| 204 |
-
action = FALLBACK_ACTION
|
| 205 |
-
obs, reward, done = env.step(action)
|
| 206 |
-
last_action, last_reward = action, reward
|
| 207 |
-
history.append(f"Step {step_num}: {action.get('narration', '')}")
|
| 208 |
-
steps = step_num
|
| 209 |
-
if done:
|
| 210 |
-
break
|
| 211 |
-
return float(obs.score_breakdown.get("overall_score", 0.0)), steps
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
def evaluate(model, tokenizer, env, label, stages=None):
|
| 215 |
-
"""Evaluate across all scenarios. Returns per-difficulty + overall scores."""
|
| 216 |
-
stages = stages or list(DIFFICULTIES)
|
| 217 |
-
print(f"\n{'=' * 60}")
|
| 218 |
-
print(f" {label}")
|
| 219 |
-
print(f"{'=' * 60}")
|
| 220 |
-
results = {}
|
| 221 |
-
for diff in stages:
|
| 222 |
-
scores = []
|
| 223 |
-
for sid in SCENARIOS.get(diff, []):
|
| 224 |
-
score, steps = run_episode(env, sid, model, tokenizer)
|
| 225 |
-
scores.append(score)
|
| 226 |
-
print(f" [{diff:6}] {sid:22} score={score:.3f} steps={steps}")
|
| 227 |
-
mean = sum(scores) / max(len(scores), 1)
|
| 228 |
-
results[diff] = mean
|
| 229 |
-
print(f" [{diff:6}] mean={mean:.3f}")
|
| 230 |
-
overall = sum(results.values()) / max(len(results), 1)
|
| 231 |
-
results["overall"] = overall
|
| 232 |
-
print(f" OVERALL: {overall:.3f}")
|
| 233 |
-
return results
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
env = EnvRunner()
|
| 237 |
-
print("EnvRunner ready.")
|
| 238 |
-
|
| 239 |
-
# ── Section 8: Baseline Evaluation ───────────────────────────────────────────
|
| 240 |
-
|
| 241 |
-
baseline = evaluate(model, tokenizer, env, "Baseline (untrained LoRA)")
|
| 242 |
-
|
| 243 |
-
# ── Section 9: Gold Demonstrations ───────────────────────────────────────────
|
| 244 |
-
|
| 245 |
-
GOLD_EPISODES = {
|
| 246 |
-
# ── easy_1: linked_list_traversal ──────────────────────────────────
|
| 247 |
-
# input_data: {"values": [10, 20, 30]}
|
| 248 |
-
# concepts: head_pointer, node_value, next_link, tail_marker
|
| 249 |
-
"easy_1": [
|
| 250 |
-
{
|
| 251 |
-
"step_type": "advance",
|
| 252 |
-
"narration": "Building a linked list — creating a centered layout region and the first two nodes with values 10 and 20.",
|
| 253 |
-
"ops": [
|
| 254 |
-
{"op": "add_region", "target_ids": ["list"], "params": {"style": "array", "title": "Linked List", "position": "center"}},
|
| 255 |
-
{"op": "add_node", "target_ids": ["n1"], "params": {"value": 10, "region": "list"}},
|
| 256 |
-
{"op": "add_node", "target_ids": ["n2"], "params": {"value": 20, "region": "list"}},
|
| 257 |
-
],
|
| 258 |
-
"covered_concepts": ["node_value"],
|
| 259 |
-
"intent": "create_list_start",
|
| 260 |
-
},
|
| 261 |
-
{
|
| 262 |
-
"step_type": "advance",
|
| 263 |
-
"narration": "Adding node 30 and connecting all nodes with next links to form the chain 10 -> 20 -> 30.",
|
| 264 |
-
"ops": [
|
| 265 |
-
{"op": "add_node", "target_ids": ["n3"], "params": {"value": 30, "region": "list"}},
|
| 266 |
-
{"op": "add_edge", "target_ids": ["n1", "n2"], "params": {"kind": "directed", "label": "next"}},
|
| 267 |
-
{"op": "add_edge", "target_ids": ["n2", "n3"], "params": {"kind": "directed", "label": "next"}},
|
| 268 |
-
],
|
| 269 |
-
"covered_concepts": ["next_link"],
|
| 270 |
-
"intent": "connect_nodes",
|
| 271 |
-
},
|
| 272 |
-
{
|
| 273 |
-
"step_type": "advance",
|
| 274 |
-
"narration": "Placing a head pointer at node 10 because traversal always starts at the head, our only entry point into the list.",
|
| 275 |
-
"ops": [
|
| 276 |
-
{"op": "add_pointer", "target_ids": ["head_ptr"], "params": {"region": "list"}},
|
| 277 |
-
{"op": "move_pointer", "target_ids": ["head_ptr"], "params": {"index": "n1"}},
|
| 278 |
-
{"op": "annotate", "target_ids": ["n1"], "params": {"text": "Head"}},
|
| 279 |
-
{"op": "set_role", "target_ids": ["n1"], "params": {"role": "current"}},
|
| 280 |
-
],
|
| 281 |
-
"covered_concepts": ["head_pointer"],
|
| 282 |
-
"intent": "mark_head",
|
| 283 |
-
},
|
| 284 |
-
{
|
| 285 |
-
"step_type": "advance",
|
| 286 |
-
"narration": "Following the next link from 10 to 20 — the pointer advances and we mark node 10 as visited.",
|
| 287 |
-
"ops": [
|
| 288 |
-
{"op": "set_role", "target_ids": ["n1"], "params": {"role": "visited"}},
|
| 289 |
-
{"op": "set_role", "target_ids": ["n2"], "params": {"role": "current"}},
|
| 290 |
-
{"op": "move_pointer", "target_ids": ["head_ptr"], "params": {"index": "n2"}},
|
| 291 |
-
],
|
| 292 |
-
"covered_concepts": [],
|
| 293 |
-
"intent": "traverse_to_second",
|
| 294 |
-
},
|
| 295 |
-
{
|
| 296 |
-
"step_type": "advance",
|
| 297 |
-
"narration": "Reaching node 30 — it has no next link, making it the tail that signals the end of traversal.",
|
| 298 |
-
"ops": [
|
| 299 |
-
{"op": "set_role", "target_ids": ["n2"], "params": {"role": "visited"}},
|
| 300 |
-
{"op": "set_role", "target_ids": ["n3"], "params": {"role": "current"}},
|
| 301 |
-
{"op": "annotate", "target_ids": ["n3"], "params": {"text": "Tail"}},
|
| 302 |
-
],
|
| 303 |
-
"covered_concepts": ["tail_marker"],
|
| 304 |
-
"intent": "reach_tail",
|
| 305 |
-
},
|
| 306 |
-
{
|
| 307 |
-
"step_type": "complete",
|
| 308 |
-
"narration": "Traversal complete — visited every node from head to tail following next links, reading values 10, 20, 30 in order.",
|
| 309 |
-
"ops": [{"op": "set_role", "target_ids": ["n3"], "params": {"role": "done"}}],
|
| 310 |
-
"covered_concepts": [],
|
| 311 |
-
"intent": "summarize",
|
| 312 |
-
},
|
| 313 |
-
],
|
| 314 |
-
|
| 315 |
-
# ── easy_2: stack_ops ──────────────────────────────────────────────
|
| 316 |
-
# input_data: {"operations": ["push A", "push B", "pop", "push C"]}
|
| 317 |
-
# concepts: top_pointer, push, pop, lifo_order
|
| 318 |
-
"easy_2": [
|
| 319 |
-
{
|
| 320 |
-
"step_type": "advance",
|
| 321 |
-
"narration": "Setting up a stack with a centered visual region and a container to track push and pop membership.",
|
| 322 |
-
"ops": [
|
| 323 |
-
{"op": "add_region", "target_ids": ["stack_area"], "params": {"style": "stack", "title": "Stack", "position": "center"}},
|
| 324 |
-
{"op": "add_container", "target_ids": ["stk"], "params": {"region": "stack_area", "ordered": False, "title": "Stack"}},
|
| 325 |
-
],
|
| 326 |
-
"covered_concepts": [],
|
| 327 |
-
"intent": "setup_stack",
|
| 328 |
-
},
|
| 329 |
-
{
|
| 330 |
-
"step_type": "advance",
|
| 331 |
-
"narration": "Pushing A onto the stack — A becomes the first element. Adding a top pointer to track the stack top.",
|
| 332 |
-
"ops": [
|
| 333 |
-
{"op": "add_node", "target_ids": ["a"], "params": {"value": "A", "region": "stack_area"}},
|
| 334 |
-
{"op": "push_to", "target_ids": ["stk", "a"], "params": {}},
|
| 335 |
-
{"op": "add_pointer", "target_ids": ["top"], "params": {"region": "stack_area"}},
|
| 336 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": "a"}},
|
| 337 |
-
],
|
| 338 |
-
"covered_concepts": ["push", "top_pointer"],
|
| 339 |
-
"intent": "push_a",
|
| 340 |
-
},
|
| 341 |
-
{
|
| 342 |
-
"step_type": "advance",
|
| 343 |
-
"narration": "Pushing B — B sits on top of A and the top pointer moves up to B.",
|
| 344 |
-
"ops": [
|
| 345 |
-
{"op": "add_node", "target_ids": ["b"], "params": {"value": "B", "region": "stack_area"}},
|
| 346 |
-
{"op": "push_to", "target_ids": ["stk", "b"], "params": {}},
|
| 347 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": "b"}},
|
| 348 |
-
],
|
| 349 |
-
"covered_concepts": [],
|
| 350 |
-
"intent": "push_b",
|
| 351 |
-
},
|
| 352 |
-
{
|
| 353 |
-
"step_type": "advance",
|
| 354 |
-
"narration": "Popping from the stack — B was pushed last so B comes off first, demonstrating LIFO (last-in-first-out) order.",
|
| 355 |
-
"ops": [
|
| 356 |
-
{"op": "pop_from", "target_ids": ["stk"], "params": {}},
|
| 357 |
-
{"op": "set_role", "target_ids": ["b"], "params": {"role": "inactive"}},
|
| 358 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": "a"}},
|
| 359 |
-
],
|
| 360 |
-
"covered_concepts": ["pop", "lifo_order"],
|
| 361 |
-
"intent": "pop_b",
|
| 362 |
-
},
|
| 363 |
-
{
|
| 364 |
-
"step_type": "advance",
|
| 365 |
-
"narration": "Pushing C onto the stack — C now sits on top of A, with B already removed.",
|
| 366 |
-
"ops": [
|
| 367 |
-
{"op": "add_node", "target_ids": ["c"], "params": {"value": "C", "region": "stack_area"}},
|
| 368 |
-
{"op": "push_to", "target_ids": ["stk", "c"], "params": {}},
|
| 369 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": "c"}},
|
| 370 |
-
],
|
| 371 |
-
"covered_concepts": [],
|
| 372 |
-
"intent": "push_c",
|
| 373 |
-
},
|
| 374 |
-
{
|
| 375 |
-
"step_type": "complete",
|
| 376 |
-
"narration": "All four operations executed — stack holds A at bottom and C on top after push A, push B, pop, push C.",
|
| 377 |
-
"ops": [],
|
| 378 |
-
"covered_concepts": [],
|
| 379 |
-
"intent": "summarize",
|
| 380 |
-
},
|
| 381 |
-
],
|
| 382 |
-
|
| 383 |
-
# ── easy_3: binary_search ──────────────────────────────────────────
|
| 384 |
-
# input_data: {"array": [1, 3, 5, 7, 9, 11, 13], "target": 7}
|
| 385 |
-
# concepts: sorted_invariant, low_pointer, high_pointer, mid_pointer, comparison
|
| 386 |
-
"easy_3": [
|
| 387 |
-
{
|
| 388 |
-
"step_type": "advance",
|
| 389 |
-
"narration": "Creating the first four elements of a sorted array in a centered region — the sorted invariant is what makes binary search possible.",
|
| 390 |
-
"ops": [
|
| 391 |
-
{"op": "add_region", "target_ids": ["arr"], "params": {"style": "array", "title": "Sorted Array", "position": "center"}},
|
| 392 |
-
{"op": "add_node", "target_ids": ["n0"], "params": {"value": 1, "region": "arr"}},
|
| 393 |
-
{"op": "add_node", "target_ids": ["n1"], "params": {"value": 3, "region": "arr"}},
|
| 394 |
-
{"op": "add_node", "target_ids": ["n2"], "params": {"value": 5, "region": "arr"}},
|
| 395 |
-
],
|
| 396 |
-
"covered_concepts": ["sorted_invariant"],
|
| 397 |
-
"intent": "create_array_part1",
|
| 398 |
-
},
|
| 399 |
-
{
|
| 400 |
-
"step_type": "advance",
|
| 401 |
-
"narration": "Adding the remaining elements 7, 9, 11, 13 to complete all seven values of the sorted array.",
|
| 402 |
-
"ops": [
|
| 403 |
-
{"op": "add_node", "target_ids": ["n3"], "params": {"value": 7, "region": "arr"}},
|
| 404 |
-
{"op": "add_node", "target_ids": ["n4"], "params": {"value": 9, "region": "arr"}},
|
| 405 |
-
{"op": "add_node", "target_ids": ["n5"], "params": {"value": 11, "region": "arr"}},
|
| 406 |
-
{"op": "add_node", "target_ids": ["n6"], "params": {"value": 13, "region": "arr"}},
|
| 407 |
-
],
|
| 408 |
-
"covered_concepts": [],
|
| 409 |
-
"intent": "create_array_part2",
|
| 410 |
-
},
|
| 411 |
-
{
|
| 412 |
-
"step_type": "advance",
|
| 413 |
-
"narration": "Placing low pointer at index 0 (value 1) and high pointer at index 6 (value 13) to bracket the search range.",
|
| 414 |
-
"ops": [
|
| 415 |
-
{"op": "add_pointer", "target_ids": ["low"], "params": {"region": "arr"}},
|
| 416 |
-
{"op": "move_pointer", "target_ids": ["low"], "params": {"index": "n0"}},
|
| 417 |
-
{"op": "add_pointer", "target_ids": ["high"], "params": {"region": "arr"}},
|
| 418 |
-
{"op": "move_pointer", "target_ids": ["high"], "params": {"index": "n6"}},
|
| 419 |
-
],
|
| 420 |
-
"covered_concepts": ["low_pointer", "high_pointer"],
|
| 421 |
-
"intent": "init_pointers",
|
| 422 |
-
},
|
| 423 |
-
{
|
| 424 |
-
"step_type": "advance",
|
| 425 |
-
"narration": "Computing mid = (0+6)/2 = 3 — the mid pointer lands on value 7, which we compare against our target 7.",
|
| 426 |
-
"ops": [
|
| 427 |
-
{"op": "add_pointer", "target_ids": ["mid"], "params": {"region": "arr"}},
|
| 428 |
-
{"op": "move_pointer", "target_ids": ["mid"], "params": {"index": "n3"}},
|
| 429 |
-
{"op": "set_role", "target_ids": ["n3"], "params": {"role": "current"}},
|
| 430 |
-
{"op": "highlight", "target_ids": ["n3"], "params": {}},
|
| 431 |
-
],
|
| 432 |
-
"covered_concepts": ["mid_pointer", "comparison"],
|
| 433 |
-
"intent": "compute_mid_and_compare",
|
| 434 |
-
},
|
| 435 |
-
{
|
| 436 |
-
"step_type": "complete",
|
| 437 |
-
"narration": "Target 7 found at index 3 — binary search located it in one comparison because the sorted invariant halves the search space each step.",
|
| 438 |
-
"ops": [
|
| 439 |
-
{"op": "set_role", "target_ids": ["n3"], "params": {"role": "done"}},
|
| 440 |
-
{"op": "annotate", "target_ids": ["n3"], "params": {"text": "Found: 7"}},
|
| 441 |
-
],
|
| 442 |
-
"covered_concepts": [],
|
| 443 |
-
"intent": "found_target",
|
| 444 |
-
},
|
| 445 |
-
],
|
| 446 |
-
}
|
| 447 |
-
|
| 448 |
-
print(f"Gold episodes: {len(GOLD_EPISODES)} scenarios, {sum(len(v) for v in GOLD_EPISODES.values())} total steps")
|
| 449 |
-
|
| 450 |
-
# ── Section 10: SFT Warmup ────────────────────────────────────────────────────
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
def generate_sft_data(env, gold_episodes, tokenizer):
|
| 454 |
-
"""Replay gold episodes through the env, collecting chat-formatted training data."""
|
| 455 |
-
rows = []
|
| 456 |
-
for scenario_id, actions in gold_episodes.items():
|
| 457 |
-
obs = env.reset(scenario_id=scenario_id)
|
| 458 |
-
last_action, last_reward, history = None, 0.0, []
|
| 459 |
-
for i, action in enumerate(actions):
|
| 460 |
-
user_prompt = build_user_prompt(obs, last_action, last_reward, history)
|
| 461 |
-
messages = [
|
| 462 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 463 |
-
{"role": "user", "content": user_prompt},
|
| 464 |
-
{"role": "assistant", "content": json.dumps(action, separators=(",", ":"))},
|
| 465 |
-
]
|
| 466 |
-
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
| 467 |
-
rows.append({"text": text})
|
| 468 |
-
obs, reward, done = env.step(action)
|
| 469 |
-
last_action, last_reward = action, reward
|
| 470 |
-
history.append(f"Step {i + 1}: {action.get('narration', '')}")
|
| 471 |
-
if done:
|
| 472 |
-
break
|
| 473 |
-
return Dataset.from_list(rows)
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
sft_data = generate_sft_data(env, GOLD_EPISODES, tokenizer)
|
| 477 |
-
print(f"SFT training examples: {len(sft_data)}")
|
| 478 |
-
|
| 479 |
-
FastLanguageModel.for_training(model)
|
| 480 |
-
|
| 481 |
-
sft_config = SFTConfig(
|
| 482 |
-
output_dir="/tmp/vr_sft_scratch",
|
| 483 |
-
num_train_epochs=SFT_EPOCHS,
|
| 484 |
-
per_device_train_batch_size=2,
|
| 485 |
-
gradient_accumulation_steps=4,
|
| 486 |
-
learning_rate=2e-4,
|
| 487 |
-
lr_scheduler_type="cosine",
|
| 488 |
-
warmup_ratio=0.03,
|
| 489 |
-
logging_steps=5,
|
| 490 |
-
save_strategy="no",
|
| 491 |
-
fp16=True,
|
| 492 |
-
max_seq_length=MAX_SEQ_LENGTH,
|
| 493 |
-
dataset_text_field="text",
|
| 494 |
-
optim="adamw_8bit",
|
| 495 |
-
report_to="none",
|
| 496 |
-
)
|
| 497 |
-
|
| 498 |
-
sft_trainer = SFTTrainer(
|
| 499 |
-
model=model,
|
| 500 |
-
processing_class=tokenizer,
|
| 501 |
-
args=sft_config,
|
| 502 |
-
train_dataset=sft_data,
|
| 503 |
-
)
|
| 504 |
-
|
| 505 |
-
print("\nTraining SFT...")
|
| 506 |
-
sft_trainer.train()
|
| 507 |
-
print("SFT complete!")
|
| 508 |
-
|
| 509 |
-
# ── Section 11: Post-SFT Evaluation ──────────────────────────────────────────
|
| 510 |
-
|
| 511 |
-
sft_results = evaluate(model, tokenizer, env, "Post-SFT")
|
| 512 |
-
|
| 513 |
-
# ── Section 12: GRPO Training ─────────────────────────────────────────────────
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
def generate_grpo_prompts(env, stages, samples_per_scenario=2):
|
| 517 |
-
"""Collect initial-state prompts for GRPO training."""
|
| 518 |
-
rows = []
|
| 519 |
-
for stage in stages:
|
| 520 |
-
for sid in SCENARIOS[stage]:
|
| 521 |
-
for _ in range(samples_per_scenario):
|
| 522 |
-
obs = env.reset(scenario_id=sid)
|
| 523 |
-
user_prompt = build_user_prompt(obs, None, 0.0, [])
|
| 524 |
-
messages = [
|
| 525 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 526 |
-
{"role": "user", "content": user_prompt},
|
| 527 |
-
]
|
| 528 |
-
rows.append({"prompt": messages, "scenario_id": sid})
|
| 529 |
-
return Dataset.from_list(rows)
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
def make_reward_fn(env):
|
| 533 |
-
"""Reward function: parse model completion, step in env, return overall_score."""
|
| 534 |
-
state = {"calls": 0, "hist": Counter()}
|
| 535 |
-
|
| 536 |
-
def reward_fn(completions, scenario_id=None, **_):
|
| 537 |
-
texts = []
|
| 538 |
-
for c in completions:
|
| 539 |
-
if isinstance(c, list):
|
| 540 |
-
texts.append(c[-1].get("content", "") if c else "")
|
| 541 |
-
else:
|
| 542 |
-
texts.append(str(c))
|
| 543 |
-
|
| 544 |
-
sids = scenario_id if isinstance(scenario_id, list) else [scenario_id] * len(texts)
|
| 545 |
-
if len(sids) < len(texts):
|
| 546 |
-
n_gen = len(texts) // len(sids)
|
| 547 |
-
sids = [s for s in sids for _ in range(n_gen)]
|
| 548 |
-
|
| 549 |
-
rewards = []
|
| 550 |
-
for sid, text in zip(sids, texts):
|
| 551 |
-
obs = env.reset(scenario_id=sid)
|
| 552 |
-
action = normalize_action(parse_action(text) or {})
|
| 553 |
-
if action is None:
|
| 554 |
-
rewards.append(0.0)
|
| 555 |
-
state["hist"]["<unparseable>"] += 1
|
| 556 |
-
continue
|
| 557 |
-
obs, _, _ = env.step(action)
|
| 558 |
-
score = float(obs.score_breakdown.get("overall_score", 0.0))
|
| 559 |
-
rewards.append(score)
|
| 560 |
-
state["hist"][action.get("step_type", "?")] += 1
|
| 561 |
-
|
| 562 |
-
state["calls"] += 1
|
| 563 |
-
if state["calls"] % 5 == 0:
|
| 564 |
-
print(f" [reward] call={state['calls']} types={dict(state['hist'])}")
|
| 565 |
-
return rewards
|
| 566 |
-
|
| 567 |
-
return reward_fn
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
grpo_data = generate_grpo_prompts(env, list(DIFFICULTIES))
|
| 571 |
-
print(f"GRPO training prompts: {len(grpo_data)}")
|
| 572 |
-
|
| 573 |
-
FastLanguageModel.for_training(model)
|
| 574 |
-
|
| 575 |
-
grpo_config = GRPOConfig(
|
| 576 |
-
output_dir="/tmp/vr_grpo_scratch",
|
| 577 |
-
num_train_epochs=GRPO_EPOCHS,
|
| 578 |
-
per_device_train_batch_size=2,
|
| 579 |
-
gradient_accumulation_steps=4,
|
| 580 |
-
num_generations=4,
|
| 581 |
-
max_completion_length=384,
|
| 582 |
-
learning_rate=1e-5,
|
| 583 |
-
lr_scheduler_type="cosine",
|
| 584 |
-
warmup_ratio=0.1,
|
| 585 |
-
beta=0.05,
|
| 586 |
-
max_grad_norm=0.5,
|
| 587 |
-
temperature=0.9,
|
| 588 |
-
logging_steps=1,
|
| 589 |
-
save_strategy="no",
|
| 590 |
-
bf16=True,
|
| 591 |
-
optim="adamw_8bit",
|
| 592 |
-
report_to="none",
|
| 593 |
-
remove_unused_columns=False,
|
| 594 |
-
)
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
class CurriculumGRPOTrainer(GRPOTrainer):
|
| 598 |
-
"""Preserve easy -> expert ordering by disabling dataset shuffle."""
|
| 599 |
-
|
| 600 |
-
def _get_train_sampler(self, *_args, **_kwargs):
|
| 601 |
-
return SequentialSampler(self.train_dataset)
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
grpo_trainer = CurriculumGRPOTrainer(
|
| 605 |
-
model=model,
|
| 606 |
-
tokenizer=tokenizer,
|
| 607 |
-
args=grpo_config,
|
| 608 |
-
train_dataset=grpo_data,
|
| 609 |
-
reward_funcs=make_reward_fn(env),
|
| 610 |
-
)
|
| 611 |
-
|
| 612 |
-
print("Training GRPO...")
|
| 613 |
-
grpo_trainer.train()
|
| 614 |
-
print("GRPO complete!")
|
| 615 |
-
|
| 616 |
-
# ── Section 13: Final Evaluation + Delta Report ───────────────────────────────
|
| 617 |
-
|
| 618 |
-
final_results = evaluate(model, tokenizer, env, "Final (SFT + GRPO)")
|
| 619 |
-
|
| 620 |
-
print(f"\n{'=' * 60}")
|
| 621 |
-
print(" DELTA REPORT")
|
| 622 |
-
print(f"{'=' * 60}")
|
| 623 |
-
print(f" {'Difficulty':<12} {'Baseline':>10} {'SFT':>10} {'SFT+GRPO':>10}")
|
| 624 |
-
print(f" {'-' * 12} {'-' * 10} {'-' * 10} {'-' * 10}")
|
| 625 |
-
for diff in list(DIFFICULTIES) + ["overall"]:
|
| 626 |
-
b = baseline.get(diff, 0.0)
|
| 627 |
-
s = sft_results.get(diff, 0.0)
|
| 628 |
-
f = final_results.get(diff, 0.0)
|
| 629 |
-
label = diff.upper() if diff == "overall" else diff
|
| 630 |
-
print(f" {label:<12} {b:>10.3f} {s:>10.3f} {f:>10.3f}")
|
| 631 |
-
print(f"{'=' * 60}")
|
| 632 |
-
print("\nDone. Model was NOT saved (in-memory only).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_hf.py
DELETED
|
@@ -1,771 +0,0 @@
|
|
| 1 |
-
"""Visual Reasoning — HuggingFace Training Job
|
| 2 |
-
|
| 3 |
-
Train Qwen3-8B to be an expert visual CS teacher via SFT warmup + staged GRPO.
|
| 4 |
-
Run with: hf jobs run --flavor a100 train_hf.py
|
| 5 |
-
|
| 6 |
-
Speedups over train.ipynb:
|
| 7 |
-
- NARRATION_SCORER=fallback → CPU heuristic, no GPU contention
|
| 8 |
-
- Batched generation → N episodes in one model.generate() call
|
| 9 |
-
- Early termination → kill episodes after 3 consecutive no-ops
|
| 10 |
-
- snapshot_download → faster than git clone for large repos
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
# ── 0. Install dependencies before any imports ──────────────────────────────
|
| 14 |
-
|
| 15 |
-
import subprocess, sys
|
| 16 |
-
|
| 17 |
-
def pip_install(*packages):
|
| 18 |
-
subprocess.check_call(
|
| 19 |
-
[sys.executable, "-m", "pip", "install", "-q", *packages],
|
| 20 |
-
stdout=subprocess.DEVNULL,
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
print("[0/7] Installing dependencies...")
|
| 24 |
-
pip_install(
|
| 25 |
-
"unsloth", "trl", "datasets", "transformers", "accelerate",
|
| 26 |
-
"bitsandbytes", "peft", "torch",
|
| 27 |
-
)
|
| 28 |
-
pip_install(
|
| 29 |
-
"openenv-core", "fastapi", "uvicorn", "pydantic",
|
| 30 |
-
)
|
| 31 |
-
pip_install(
|
| 32 |
-
"python-dotenv", "networkx", "shapely", "sentence-transformers",
|
| 33 |
-
"rapidfuzz", "textstat", "sortedcontainers", "huggingface_hub",
|
| 34 |
-
)
|
| 35 |
-
print("[0/7] Dependencies installed.")
|
| 36 |
-
|
| 37 |
-
# ── 1. Imports (unsloth first to patch transformers early) ──────────────────
|
| 38 |
-
|
| 39 |
-
import os
|
| 40 |
-
os.environ["NARRATION_SCORER"] = "fallback" # CPU scorer, no GPU contention
|
| 41 |
-
|
| 42 |
-
from unsloth import FastLanguageModel # must be before transformers
|
| 43 |
-
|
| 44 |
-
import json
|
| 45 |
-
import time
|
| 46 |
-
import torch
|
| 47 |
-
from collections import Counter
|
| 48 |
-
from datasets import Dataset
|
| 49 |
-
from huggingface_hub import snapshot_download
|
| 50 |
-
from torch.utils.data import SequentialSampler
|
| 51 |
-
from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
# ── 2. Config ───────────────────────────────────────────────────────────────
|
| 55 |
-
|
| 56 |
-
# Model
|
| 57 |
-
MODEL_NAME = "unsloth/Qwen3-8B-unsloth-bnb-4bit"
|
| 58 |
-
MAX_SEQ_LENGTH = 4096
|
| 59 |
-
LORA_R = 32
|
| 60 |
-
LORA_ALPHA = 64
|
| 61 |
-
|
| 62 |
-
# SFT
|
| 63 |
-
SFT_EPOCHS = 3
|
| 64 |
-
SFT_LR = 2e-4
|
| 65 |
-
SFT_BATCH_SIZE = 2
|
| 66 |
-
SFT_GRAD_ACCUM = 4
|
| 67 |
-
|
| 68 |
-
# GRPO
|
| 69 |
-
GRPO_LR = 5e-6
|
| 70 |
-
GRPO_STAGE1_EPOCHS = 2
|
| 71 |
-
GRPO_STAGE2_EPOCHS = 2
|
| 72 |
-
GRPO_NUM_GENERATIONS = 4
|
| 73 |
-
GRPO_BATCH_SIZE = 2
|
| 74 |
-
GRPO_GRAD_ACCUM = 4
|
| 75 |
-
GRPO_TEMPERATURE = 0.9
|
| 76 |
-
|
| 77 |
-
# Scenario counts for procedural generation
|
| 78 |
-
GRPO_SCENARIOS_PER_DIFFICULTY = {"easy": 40, "medium": 30, "hard": 20, "expert": 10}
|
| 79 |
-
DIFFICULTIES = ("easy", "medium", "hard", "expert")
|
| 80 |
-
|
| 81 |
-
# Eval
|
| 82 |
-
EVAL_BATCH_SIZE = 8
|
| 83 |
-
EVAL_MAX_STEPS = 24
|
| 84 |
-
NOOP_EARLY_STOP = 3
|
| 85 |
-
|
| 86 |
-
# Hub — HF_TOKEN is set by the HuggingFace jobs runtime
|
| 87 |
-
HUB_REPO = None # set to "username/model-name" to push
|
| 88 |
-
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 89 |
-
|
| 90 |
-
# Static scenarios for evaluation
|
| 91 |
-
STATIC_SCENARIOS = {
|
| 92 |
-
"easy": ["easy_1", "easy_2", "easy_3"],
|
| 93 |
-
"medium": ["medium_1", "medium_2"],
|
| 94 |
-
"hard": ["hard_1", "hard_2"],
|
| 95 |
-
"expert": ["expert_1", "expert_2"],
|
| 96 |
-
}
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
# ── 3. Download environment repo ───────────────────────────────────────────
|
| 100 |
-
|
| 101 |
-
print("[1/7] Downloading visual_reasoning environment...")
|
| 102 |
-
env_path = snapshot_download(
|
| 103 |
-
repo_id="sreeramajay/visual_reasoning-env",
|
| 104 |
-
repo_type="space",
|
| 105 |
-
local_dir="visual_reasoning",
|
| 106 |
-
token=HF_TOKEN,
|
| 107 |
-
)
|
| 108 |
-
sys.path.insert(0, env_path)
|
| 109 |
-
print(f"[1/7] Environment downloaded to {env_path}")
|
| 110 |
-
|
| 111 |
-
from models import VisualReasoningAction, VisualReasoningObservation
|
| 112 |
-
from server.visual_reasoning_environment import VisualReasoningEnvironment
|
| 113 |
-
from server.scenario_generator import generate_scenario
|
| 114 |
-
from inference import (
|
| 115 |
-
SYSTEM_PROMPT, build_user_prompt, parse_action,
|
| 116 |
-
normalize_action as inf_normalize_action,
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
# Smoke test
|
| 120 |
-
_test_env = VisualReasoningEnvironment()
|
| 121 |
-
_obs = _test_env.reset(scenario_id="easy_1")
|
| 122 |
-
_obs = _test_env.step(VisualReasoningAction(
|
| 123 |
-
step_type="advance",
|
| 124 |
-
narration="Adding the first node with value 10.",
|
| 125 |
-
ops=[{"op": "add_node", "target_ids": ["n0"], "params": {"value": 10}}],
|
| 126 |
-
covered_concepts=["node_value"], intent="test",
|
| 127 |
-
))
|
| 128 |
-
assert _obs.reward != 0.0, "Environment smoke test failed"
|
| 129 |
-
del _test_env, _obs
|
| 130 |
-
print("[1/7] Environment smoke test passed.")
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
# ── 4. Load model ──────────────────────────────────────────────────────────
|
| 134 |
-
|
| 135 |
-
print("[2/7] Loading model...")
|
| 136 |
-
t0 = time.time()
|
| 137 |
-
|
| 138 |
-
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 139 |
-
model_name=MODEL_NAME,
|
| 140 |
-
max_seq_length=MAX_SEQ_LENGTH,
|
| 141 |
-
dtype=None,
|
| 142 |
-
load_in_4bit=True,
|
| 143 |
-
)
|
| 144 |
-
model = FastLanguageModel.get_peft_model(
|
| 145 |
-
model,
|
| 146 |
-
r=LORA_R,
|
| 147 |
-
lora_alpha=LORA_ALPHA,
|
| 148 |
-
lora_dropout=0,
|
| 149 |
-
target_modules=[
|
| 150 |
-
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 151 |
-
"gate_proj", "up_proj", "down_proj",
|
| 152 |
-
],
|
| 153 |
-
bias="none",
|
| 154 |
-
use_gradient_checkpointing="unsloth",
|
| 155 |
-
random_state=0,
|
| 156 |
-
)
|
| 157 |
-
if tokenizer.pad_token_id is None:
|
| 158 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 159 |
-
model.generation_config.max_length = None
|
| 160 |
-
|
| 161 |
-
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 162 |
-
total = sum(p.numel() for p in model.parameters())
|
| 163 |
-
print(f"[2/7] Model loaded in {time.time() - t0:.0f}s — "
|
| 164 |
-
f"{total/1e6:.0f}M params, {trainable/1e6:.1f}M trainable ({trainable/total:.1%})")
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
# ── 5. Batched evaluation engine ───────────────────────────────────────────
|
| 168 |
-
|
| 169 |
-
FALLBACK_ACTION = {
|
| 170 |
-
"step_type": "complete", "narration": "Explanation complete.",
|
| 171 |
-
"ops": [], "covered_concepts": [], "intent": "finalize",
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
# Reusable env pool — avoids repeated __init__ / warmup_scorer overhead
|
| 175 |
-
_env_pool = []
|
| 176 |
-
|
| 177 |
-
def _get_env(idx):
|
| 178 |
-
while len(_env_pool) <= idx:
|
| 179 |
-
_env_pool.append(VisualReasoningEnvironment())
|
| 180 |
-
return _env_pool[idx]
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
class EpisodeState:
|
| 184 |
-
"""Tracks one in-flight episode for batched eval."""
|
| 185 |
-
|
| 186 |
-
def __init__(self, env, scenario_id):
|
| 187 |
-
self.env = env
|
| 188 |
-
self.scenario_id = scenario_id
|
| 189 |
-
self.obs = env.reset(scenario_id=scenario_id)
|
| 190 |
-
self.last_action = None
|
| 191 |
-
self.last_reward = 0.0
|
| 192 |
-
self.history = []
|
| 193 |
-
self.steps = 0
|
| 194 |
-
self.done = False
|
| 195 |
-
self.score = 0.0
|
| 196 |
-
self.consecutive_noops = 0
|
| 197 |
-
|
| 198 |
-
def build_prompt_text(self):
|
| 199 |
-
user_prompt = build_user_prompt(
|
| 200 |
-
self.obs, self.last_action, self.last_reward, self.history
|
| 201 |
-
)
|
| 202 |
-
messages = [
|
| 203 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 204 |
-
{"role": "user", "content": user_prompt},
|
| 205 |
-
]
|
| 206 |
-
return tokenizer.apply_chat_template(
|
| 207 |
-
messages, tokenize=False, add_generation_prompt=True
|
| 208 |
-
)
|
| 209 |
-
|
| 210 |
-
def apply_action(self, action_dict):
|
| 211 |
-
if action_dict is None:
|
| 212 |
-
action_dict = FALLBACK_ACTION
|
| 213 |
-
self.obs = self.env.step(VisualReasoningAction(**action_dict))
|
| 214 |
-
reward = float(self.obs.reward)
|
| 215 |
-
self.last_action = action_dict
|
| 216 |
-
self.last_reward = reward
|
| 217 |
-
self.steps += 1
|
| 218 |
-
self.history.append(f"Step {self.steps}: {action_dict.get('narration', '')}")
|
| 219 |
-
if reward <= -0.04:
|
| 220 |
-
self.consecutive_noops += 1
|
| 221 |
-
else:
|
| 222 |
-
self.consecutive_noops = 0
|
| 223 |
-
if self.obs.done or self.consecutive_noops >= NOOP_EARLY_STOP:
|
| 224 |
-
self.done = True
|
| 225 |
-
self.score = float(self.obs.score_breakdown.get("overall_score", 0.0))
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
def batched_generate(prompt_texts, max_new_tokens=384):
|
| 229 |
-
"""Single batched model.generate() call for all active episodes."""
|
| 230 |
-
FastLanguageModel.for_inference(model)
|
| 231 |
-
inputs = tokenizer(
|
| 232 |
-
prompt_texts,
|
| 233 |
-
return_tensors="pt",
|
| 234 |
-
padding=True,
|
| 235 |
-
truncation=True,
|
| 236 |
-
max_length=MAX_SEQ_LENGTH,
|
| 237 |
-
).to(model.device)
|
| 238 |
-
with torch.no_grad():
|
| 239 |
-
outputs = model.generate(
|
| 240 |
-
**inputs,
|
| 241 |
-
max_new_tokens=max_new_tokens,
|
| 242 |
-
do_sample=False,
|
| 243 |
-
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 244 |
-
)
|
| 245 |
-
results = []
|
| 246 |
-
for i, out in enumerate(outputs):
|
| 247 |
-
input_len = inputs.input_ids[i].ne(tokenizer.pad_token_id).sum()
|
| 248 |
-
results.append(tokenizer.decode(out[input_len:], skip_special_tokens=True))
|
| 249 |
-
return results
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
def evaluate(label, stages=None, scenarios=None):
|
| 253 |
-
"""Batched eval: runs EVAL_BATCH_SIZE episodes in parallel per generate call."""
|
| 254 |
-
stages = stages or list(DIFFICULTIES)
|
| 255 |
-
scenarios = scenarios or STATIC_SCENARIOS
|
| 256 |
-
|
| 257 |
-
print(f"\n{'=' * 60}")
|
| 258 |
-
print(f" EVAL: {label}")
|
| 259 |
-
print(f"{'=' * 60}")
|
| 260 |
-
|
| 261 |
-
all_ids = []
|
| 262 |
-
for diff in stages:
|
| 263 |
-
all_ids.extend([(diff, sid) for sid in scenarios.get(diff, [])])
|
| 264 |
-
|
| 265 |
-
results_by_diff = {d: [] for d in stages}
|
| 266 |
-
t0 = time.time()
|
| 267 |
-
|
| 268 |
-
for batch_start in range(0, len(all_ids), EVAL_BATCH_SIZE):
|
| 269 |
-
batch = all_ids[batch_start : batch_start + EVAL_BATCH_SIZE]
|
| 270 |
-
episodes = [EpisodeState(_get_env(i), sid) for i, (_, sid) in enumerate(batch)]
|
| 271 |
-
|
| 272 |
-
for _ in range(1, EVAL_MAX_STEPS + 1):
|
| 273 |
-
active = [ep for ep in episodes if not ep.done]
|
| 274 |
-
if not active:
|
| 275 |
-
break
|
| 276 |
-
prompts = [ep.build_prompt_text() for ep in active]
|
| 277 |
-
texts = batched_generate(prompts)
|
| 278 |
-
for ep, text in zip(active, texts):
|
| 279 |
-
parsed = parse_action(text)
|
| 280 |
-
action = inf_normalize_action(parsed or {}) if parsed else None
|
| 281 |
-
ep.apply_action(action)
|
| 282 |
-
|
| 283 |
-
for ep, (diff, sid) in zip(episodes, batch):
|
| 284 |
-
results_by_diff[diff].append(ep.score)
|
| 285 |
-
print(f" [{diff:6}] {sid:22} score={ep.score:.3f} steps={ep.steps}")
|
| 286 |
-
|
| 287 |
-
overall_scores = {}
|
| 288 |
-
for diff in stages:
|
| 289 |
-
scores = results_by_diff[diff]
|
| 290 |
-
mean = sum(scores) / max(len(scores), 1)
|
| 291 |
-
overall_scores[diff] = mean
|
| 292 |
-
print(f" [{diff:6}] MEAN = {mean:.3f}")
|
| 293 |
-
|
| 294 |
-
overall = sum(overall_scores.values()) / max(len(overall_scores), 1)
|
| 295 |
-
overall_scores["overall"] = overall
|
| 296 |
-
print(f" OVERALL: {overall:.3f} ({time.time() - t0:.1f}s)")
|
| 297 |
-
return overall_scores
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
# ── 6. Gold demonstrations for SFT ─────────────────────────────────────────
|
| 301 |
-
|
| 302 |
-
# Hand-crafted episodes that teach the model:
|
| 303 |
-
# - JSON action format (step_type, narration, ops, covered_concepts, intent)
|
| 304 |
-
# - Incremental drawing (Phase 1), algorithm walk-through (Phase 2), wrap-up (Phase 3)
|
| 305 |
-
# - Region vs container distinction, concept evidencing, pacing
|
| 306 |
-
|
| 307 |
-
GOLD_EPISODES = {
|
| 308 |
-
# ── easy_1: linked_list_traversal ──
|
| 309 |
-
# input: {"values": [10, 20, 30]}, concepts: head_pointer, node_value, next_link, tail_marker
|
| 310 |
-
"easy_1": [
|
| 311 |
-
{
|
| 312 |
-
"step_type": "advance",
|
| 313 |
-
"narration": "Building a linked list — creating a centered layout region and the first two nodes with values 10 and 20.",
|
| 314 |
-
"ops": [
|
| 315 |
-
{"op": "add_region", "target_ids": ["list"], "params": {"style": "array", "title": "Linked List", "position": "center"}},
|
| 316 |
-
{"op": "add_node", "target_ids": ["n1"], "params": {"value": 10, "region": "list"}},
|
| 317 |
-
{"op": "add_node", "target_ids": ["n2"], "params": {"value": 20, "region": "list"}},
|
| 318 |
-
],
|
| 319 |
-
"covered_concepts": ["node_value"],
|
| 320 |
-
"intent": "create_list_start",
|
| 321 |
-
},
|
| 322 |
-
{
|
| 323 |
-
"step_type": "advance",
|
| 324 |
-
"narration": "Adding node 30 and connecting all nodes with next links to form the chain 10 -> 20 -> 30.",
|
| 325 |
-
"ops": [
|
| 326 |
-
{"op": "add_node", "target_ids": ["n3"], "params": {"value": 30, "region": "list"}},
|
| 327 |
-
{"op": "add_edge", "target_ids": ["n1", "n2"], "params": {"kind": "directed", "label": "next"}},
|
| 328 |
-
{"op": "add_edge", "target_ids": ["n2", "n3"], "params": {"kind": "directed", "label": "next"}},
|
| 329 |
-
],
|
| 330 |
-
"covered_concepts": ["next_link"],
|
| 331 |
-
"intent": "connect_nodes",
|
| 332 |
-
},
|
| 333 |
-
{
|
| 334 |
-
"step_type": "advance",
|
| 335 |
-
"narration": "Placing a head pointer at node 10 because traversal always starts at the head, our only entry point into the list.",
|
| 336 |
-
"ops": [
|
| 337 |
-
{"op": "add_pointer", "target_ids": ["head_ptr"], "params": {"region": "list"}},
|
| 338 |
-
{"op": "move_pointer", "target_ids": ["head_ptr"], "params": {"index": "n1"}},
|
| 339 |
-
{"op": "annotate", "target_ids": ["n1"], "params": {"text": "Head"}},
|
| 340 |
-
{"op": "set_role", "target_ids": ["n1"], "params": {"role": "current"}},
|
| 341 |
-
],
|
| 342 |
-
"covered_concepts": ["head_pointer"],
|
| 343 |
-
"intent": "mark_head",
|
| 344 |
-
},
|
| 345 |
-
{
|
| 346 |
-
"step_type": "advance",
|
| 347 |
-
"narration": "Following the next link from 10 to 20 — the pointer advances and we mark node 10 as visited.",
|
| 348 |
-
"ops": [
|
| 349 |
-
{"op": "set_role", "target_ids": ["n1"], "params": {"role": "visited"}},
|
| 350 |
-
{"op": "set_role", "target_ids": ["n2"], "params": {"role": "current"}},
|
| 351 |
-
{"op": "move_pointer", "target_ids": ["head_ptr"], "params": {"index": "n2"}},
|
| 352 |
-
],
|
| 353 |
-
"covered_concepts": [],
|
| 354 |
-
"intent": "traverse_to_second",
|
| 355 |
-
},
|
| 356 |
-
{
|
| 357 |
-
"step_type": "advance",
|
| 358 |
-
"narration": "Reaching node 30 — it has no next link, making it the tail that signals the end of traversal.",
|
| 359 |
-
"ops": [
|
| 360 |
-
{"op": "set_role", "target_ids": ["n2"], "params": {"role": "visited"}},
|
| 361 |
-
{"op": "set_role", "target_ids": ["n3"], "params": {"role": "current"}},
|
| 362 |
-
{"op": "annotate", "target_ids": ["n3"], "params": {"text": "Tail"}},
|
| 363 |
-
],
|
| 364 |
-
"covered_concepts": ["tail_marker"],
|
| 365 |
-
"intent": "reach_tail",
|
| 366 |
-
},
|
| 367 |
-
{
|
| 368 |
-
"step_type": "complete",
|
| 369 |
-
"narration": "Traversal complete — visited every node from head to tail following next links, reading values 10, 20, 30 in order.",
|
| 370 |
-
"ops": [{"op": "set_role", "target_ids": ["n3"], "params": {"role": "done"}}],
|
| 371 |
-
"covered_concepts": [],
|
| 372 |
-
"intent": "summarize",
|
| 373 |
-
},
|
| 374 |
-
],
|
| 375 |
-
|
| 376 |
-
# ── easy_2: stack_ops ──
|
| 377 |
-
# input: {"operations": ["push A", "push B", "pop", "push C"]}, concepts: top_pointer, push, pop, lifo_order
|
| 378 |
-
"easy_2": [
|
| 379 |
-
{
|
| 380 |
-
"step_type": "advance",
|
| 381 |
-
"narration": "Setting up a stack with a centered visual region and a container to track push and pop membership.",
|
| 382 |
-
"ops": [
|
| 383 |
-
{"op": "add_region", "target_ids": ["stack_area"], "params": {"style": "stack", "title": "Stack", "position": "center"}},
|
| 384 |
-
{"op": "add_container", "target_ids": ["stk"], "params": {"region": "stack_area", "ordered": False, "title": "Stack"}},
|
| 385 |
-
],
|
| 386 |
-
"covered_concepts": [],
|
| 387 |
-
"intent": "setup_stack",
|
| 388 |
-
},
|
| 389 |
-
{
|
| 390 |
-
"step_type": "advance",
|
| 391 |
-
"narration": "Pushing A onto the stack — A becomes the first element. Adding a top pointer to track the stack top.",
|
| 392 |
-
"ops": [
|
| 393 |
-
{"op": "add_node", "target_ids": ["a"], "params": {"value": "A", "region": "stack_area"}},
|
| 394 |
-
{"op": "push_to", "target_ids": ["stk", "a"], "params": {}},
|
| 395 |
-
{"op": "add_pointer", "target_ids": ["top"], "params": {"region": "stack_area"}},
|
| 396 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": "a"}},
|
| 397 |
-
],
|
| 398 |
-
"covered_concepts": ["push", "top_pointer"],
|
| 399 |
-
"intent": "push_a",
|
| 400 |
-
},
|
| 401 |
-
{
|
| 402 |
-
"step_type": "advance",
|
| 403 |
-
"narration": "Pushing B — B sits on top of A and the top pointer moves up to B.",
|
| 404 |
-
"ops": [
|
| 405 |
-
{"op": "add_node", "target_ids": ["b"], "params": {"value": "B", "region": "stack_area"}},
|
| 406 |
-
{"op": "push_to", "target_ids": ["stk", "b"], "params": {}},
|
| 407 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": "b"}},
|
| 408 |
-
],
|
| 409 |
-
"covered_concepts": [],
|
| 410 |
-
"intent": "push_b",
|
| 411 |
-
},
|
| 412 |
-
{
|
| 413 |
-
"step_type": "advance",
|
| 414 |
-
"narration": "Popping from the stack — B was pushed last so B comes off first, demonstrating LIFO (last-in-first-out) order.",
|
| 415 |
-
"ops": [
|
| 416 |
-
{"op": "pop_from", "target_ids": ["stk"], "params": {}},
|
| 417 |
-
{"op": "set_role", "target_ids": ["b"], "params": {"role": "inactive"}},
|
| 418 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": "a"}},
|
| 419 |
-
],
|
| 420 |
-
"covered_concepts": ["pop", "lifo_order"],
|
| 421 |
-
"intent": "pop_b",
|
| 422 |
-
},
|
| 423 |
-
{
|
| 424 |
-
"step_type": "advance",
|
| 425 |
-
"narration": "Pushing C onto the stack — C now sits on top of A, with B already removed.",
|
| 426 |
-
"ops": [
|
| 427 |
-
{"op": "add_node", "target_ids": ["c"], "params": {"value": "C", "region": "stack_area"}},
|
| 428 |
-
{"op": "push_to", "target_ids": ["stk", "c"], "params": {}},
|
| 429 |
-
{"op": "move_pointer", "target_ids": ["top"], "params": {"index": "c"}},
|
| 430 |
-
],
|
| 431 |
-
"covered_concepts": [],
|
| 432 |
-
"intent": "push_c",
|
| 433 |
-
},
|
| 434 |
-
{
|
| 435 |
-
"step_type": "complete",
|
| 436 |
-
"narration": "All four operations executed — stack holds A at bottom and C on top after push A, push B, pop, push C.",
|
| 437 |
-
"ops": [],
|
| 438 |
-
"covered_concepts": [],
|
| 439 |
-
"intent": "summarize",
|
| 440 |
-
},
|
| 441 |
-
],
|
| 442 |
-
|
| 443 |
-
# ── easy_3: binary_search ──
|
| 444 |
-
# input: {"array": [1,3,5,7,9,11,13], "target": 7}, concepts: sorted_invariant, low/high/mid_pointer, comparison
|
| 445 |
-
"easy_3": [
|
| 446 |
-
{
|
| 447 |
-
"step_type": "advance",
|
| 448 |
-
"narration": "Creating the first four elements of a sorted array in a centered region — the sorted invariant is what makes binary search possible.",
|
| 449 |
-
"ops": [
|
| 450 |
-
{"op": "add_region", "target_ids": ["arr"], "params": {"style": "array", "title": "Sorted Array", "position": "center"}},
|
| 451 |
-
{"op": "add_node", "target_ids": ["n0"], "params": {"value": 1, "region": "arr"}},
|
| 452 |
-
{"op": "add_node", "target_ids": ["n1"], "params": {"value": 3, "region": "arr"}},
|
| 453 |
-
{"op": "add_node", "target_ids": ["n2"], "params": {"value": 5, "region": "arr"}},
|
| 454 |
-
],
|
| 455 |
-
"covered_concepts": ["sorted_invariant"],
|
| 456 |
-
"intent": "create_array_part1",
|
| 457 |
-
},
|
| 458 |
-
{
|
| 459 |
-
"step_type": "advance",
|
| 460 |
-
"narration": "Adding the remaining elements 7, 9, 11, 13 to complete all seven values of the sorted array.",
|
| 461 |
-
"ops": [
|
| 462 |
-
{"op": "add_node", "target_ids": ["n3"], "params": {"value": 7, "region": "arr"}},
|
| 463 |
-
{"op": "add_node", "target_ids": ["n4"], "params": {"value": 9, "region": "arr"}},
|
| 464 |
-
{"op": "add_node", "target_ids": ["n5"], "params": {"value": 11, "region": "arr"}},
|
| 465 |
-
{"op": "add_node", "target_ids": ["n6"], "params": {"value": 13, "region": "arr"}},
|
| 466 |
-
],
|
| 467 |
-
"covered_concepts": [],
|
| 468 |
-
"intent": "create_array_part2",
|
| 469 |
-
},
|
| 470 |
-
{
|
| 471 |
-
"step_type": "advance",
|
| 472 |
-
"narration": "Placing low pointer at index 0 (value 1) and high pointer at index 6 (value 13) to bracket the search range.",
|
| 473 |
-
"ops": [
|
| 474 |
-
{"op": "add_pointer", "target_ids": ["low"], "params": {"region": "arr"}},
|
| 475 |
-
{"op": "move_pointer", "target_ids": ["low"], "params": {"index": "n0"}},
|
| 476 |
-
{"op": "add_pointer", "target_ids": ["high"], "params": {"region": "arr"}},
|
| 477 |
-
{"op": "move_pointer", "target_ids": ["high"], "params": {"index": "n6"}},
|
| 478 |
-
],
|
| 479 |
-
"covered_concepts": ["low_pointer", "high_pointer"],
|
| 480 |
-
"intent": "init_pointers",
|
| 481 |
-
},
|
| 482 |
-
{
|
| 483 |
-
"step_type": "advance",
|
| 484 |
-
"narration": "Computing mid = (0+6)/2 = 3 — the mid pointer lands on value 7, which we compare against our target 7.",
|
| 485 |
-
"ops": [
|
| 486 |
-
{"op": "add_pointer", "target_ids": ["mid"], "params": {"region": "arr"}},
|
| 487 |
-
{"op": "move_pointer", "target_ids": ["mid"], "params": {"index": "n3"}},
|
| 488 |
-
{"op": "set_role", "target_ids": ["n3"], "params": {"role": "current"}},
|
| 489 |
-
{"op": "highlight", "target_ids": ["n3"], "params": {}},
|
| 490 |
-
],
|
| 491 |
-
"covered_concepts": ["mid_pointer", "comparison"],
|
| 492 |
-
"intent": "compute_mid_and_compare",
|
| 493 |
-
},
|
| 494 |
-
{
|
| 495 |
-
"step_type": "complete",
|
| 496 |
-
"narration": "Target 7 found at index 3 — binary search located it in one comparison because the sorted invariant halves the search space each step.",
|
| 497 |
-
"ops": [
|
| 498 |
-
{"op": "set_role", "target_ids": ["n3"], "params": {"role": "done"}},
|
| 499 |
-
{"op": "annotate", "target_ids": ["n3"], "params": {"text": "Found: 7"}},
|
| 500 |
-
],
|
| 501 |
-
"covered_concepts": [],
|
| 502 |
-
"intent": "found_target",
|
| 503 |
-
},
|
| 504 |
-
],
|
| 505 |
-
}
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
# ── 7. SFT data generation ─────────────────────────────────────────────────
|
| 509 |
-
|
| 510 |
-
def generate_sft_data():
|
| 511 |
-
"""Replay gold episodes through the live env to collect (observation, action) pairs."""
|
| 512 |
-
env = VisualReasoningEnvironment()
|
| 513 |
-
rows = []
|
| 514 |
-
for scenario_id, actions in GOLD_EPISODES.items():
|
| 515 |
-
obs = env.reset(scenario_id=scenario_id)
|
| 516 |
-
last_action, last_reward, history = None, 0.0, []
|
| 517 |
-
for i, action in enumerate(actions):
|
| 518 |
-
user_prompt = build_user_prompt(obs, last_action, last_reward, history)
|
| 519 |
-
messages = [
|
| 520 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 521 |
-
{"role": "user", "content": user_prompt},
|
| 522 |
-
{"role": "assistant", "content": json.dumps(action, separators=(",", ":"))},
|
| 523 |
-
]
|
| 524 |
-
text = tokenizer.apply_chat_template(
|
| 525 |
-
messages, tokenize=False, add_generation_prompt=False
|
| 526 |
-
)
|
| 527 |
-
rows.append({"text": text})
|
| 528 |
-
obs = env.step(VisualReasoningAction(**action))
|
| 529 |
-
last_action, last_reward = action, float(obs.reward)
|
| 530 |
-
history.append(f"Step {i + 1}: {action.get('narration', '')}")
|
| 531 |
-
if obs.done:
|
| 532 |
-
break
|
| 533 |
-
return Dataset.from_list(rows)
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
# ── 8. GRPO helpers ─────────────────────────────────────────────────────────
|
| 537 |
-
|
| 538 |
-
def generate_grpo_scenarios():
|
| 539 |
-
"""Procedurally generate diverse scenarios for each difficulty."""
|
| 540 |
-
out = {}
|
| 541 |
-
for diff, count in GRPO_SCENARIOS_PER_DIFFICULTY.items():
|
| 542 |
-
base_seed = {"easy": 10000, "medium": 20000, "hard": 30000, "expert": 40000}[diff]
|
| 543 |
-
scenarios = [generate_scenario(task_name=diff, seed=base_seed + i) for i in range(count)]
|
| 544 |
-
templates = Counter(s["template"] for s in scenarios)
|
| 545 |
-
print(f" {diff}: {count} scenarios — {dict(templates)}")
|
| 546 |
-
out[diff] = scenarios
|
| 547 |
-
return out
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
def build_grpo_prompts(scenarios_by_diff, stages, samples_per_scenario=2):
|
| 551 |
-
"""Collect initial-state prompts by resetting the env for each scenario."""
|
| 552 |
-
env = VisualReasoningEnvironment()
|
| 553 |
-
rows = []
|
| 554 |
-
for stage in stages:
|
| 555 |
-
for scenario in scenarios_by_diff.get(stage, []):
|
| 556 |
-
sid = scenario["scenario_id"]
|
| 557 |
-
for _ in range(samples_per_scenario):
|
| 558 |
-
obs = env.reset(scenario_id=sid)
|
| 559 |
-
user_prompt = build_user_prompt(obs, None, 0.0, [])
|
| 560 |
-
messages = [
|
| 561 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 562 |
-
{"role": "user", "content": user_prompt},
|
| 563 |
-
]
|
| 564 |
-
rows.append({"prompt": messages, "scenario_id": sid})
|
| 565 |
-
return Dataset.from_list(rows)
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
def make_reward_fn():
|
| 569 |
-
"""Reward function: parse completion, step in env, return overall_score."""
|
| 570 |
-
env = VisualReasoningEnvironment()
|
| 571 |
-
state = {"calls": 0, "hist": Counter()}
|
| 572 |
-
|
| 573 |
-
def reward_fn(completions, scenario_id=None, **_):
|
| 574 |
-
texts = []
|
| 575 |
-
for c in completions:
|
| 576 |
-
if isinstance(c, list):
|
| 577 |
-
texts.append(c[-1].get("content", "") if c else "")
|
| 578 |
-
else:
|
| 579 |
-
texts.append(str(c))
|
| 580 |
-
|
| 581 |
-
sids = scenario_id if isinstance(scenario_id, list) else [scenario_id] * len(texts)
|
| 582 |
-
if len(sids) < len(texts):
|
| 583 |
-
n_gen = len(texts) // len(sids)
|
| 584 |
-
sids = [s for s in sids for _ in range(n_gen)]
|
| 585 |
-
|
| 586 |
-
rewards = []
|
| 587 |
-
for sid, text in zip(sids, texts):
|
| 588 |
-
obs = env.reset(scenario_id=sid)
|
| 589 |
-
parsed = parse_action(text)
|
| 590 |
-
action = inf_normalize_action(parsed or {}) if parsed else None
|
| 591 |
-
if action is None:
|
| 592 |
-
rewards.append(0.0)
|
| 593 |
-
state["hist"]["<unparseable>"] += 1
|
| 594 |
-
continue
|
| 595 |
-
obs = env.step(VisualReasoningAction(**action))
|
| 596 |
-
rewards.append(float(obs.score_breakdown.get("overall_score", 0.0)))
|
| 597 |
-
state["hist"][action.get("step_type", "?")] += 1
|
| 598 |
-
|
| 599 |
-
state["calls"] += 1
|
| 600 |
-
if state["calls"] % 10 == 0:
|
| 601 |
-
print(f" [reward] call={state['calls']} types={dict(state['hist'])}")
|
| 602 |
-
return rewards
|
| 603 |
-
|
| 604 |
-
return reward_fn
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
class CurriculumGRPOTrainer(GRPOTrainer):
|
| 608 |
-
"""Sequential sampler preserves easy → expert curriculum ordering."""
|
| 609 |
-
def _get_train_sampler(self, *_args, **_kwargs):
|
| 610 |
-
return SequentialSampler(self.train_dataset)
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
# ── 9. Main training loop ──────────────────────────────────────────────────
|
| 614 |
-
|
| 615 |
-
def main():
|
| 616 |
-
job_start = time.time()
|
| 617 |
-
|
| 618 |
-
# ── Baseline ──
|
| 619 |
-
print("\n[3/7] Baseline evaluation...")
|
| 620 |
-
baseline = evaluate("Baseline (untrained LoRA)")
|
| 621 |
-
|
| 622 |
-
# ── SFT ──
|
| 623 |
-
print("\n[4/7] SFT warmup...")
|
| 624 |
-
sft_data = generate_sft_data()
|
| 625 |
-
print(f" SFT examples: {len(sft_data)}")
|
| 626 |
-
|
| 627 |
-
FastLanguageModel.for_training(model)
|
| 628 |
-
sft_trainer = SFTTrainer(
|
| 629 |
-
model=model,
|
| 630 |
-
processing_class=tokenizer,
|
| 631 |
-
args=SFTConfig(
|
| 632 |
-
output_dir="/tmp/vr_sft",
|
| 633 |
-
num_train_epochs=SFT_EPOCHS,
|
| 634 |
-
per_device_train_batch_size=SFT_BATCH_SIZE,
|
| 635 |
-
gradient_accumulation_steps=SFT_GRAD_ACCUM,
|
| 636 |
-
learning_rate=SFT_LR,
|
| 637 |
-
lr_scheduler_type="cosine",
|
| 638 |
-
warmup_ratio=0.03,
|
| 639 |
-
logging_steps=5,
|
| 640 |
-
save_strategy="no",
|
| 641 |
-
bf16=True,
|
| 642 |
-
max_seq_length=MAX_SEQ_LENGTH,
|
| 643 |
-
dataset_text_field="text",
|
| 644 |
-
optim="adamw_8bit",
|
| 645 |
-
report_to="none",
|
| 646 |
-
),
|
| 647 |
-
train_dataset=sft_data,
|
| 648 |
-
)
|
| 649 |
-
|
| 650 |
-
t0 = time.time()
|
| 651 |
-
sft_trainer.train()
|
| 652 |
-
print(f" SFT done in {time.time() - t0:.0f}s")
|
| 653 |
-
|
| 654 |
-
sft_results = evaluate("Post-SFT")
|
| 655 |
-
|
| 656 |
-
# ── Generate GRPO scenarios ──
|
| 657 |
-
print("\n[5/7] Generating GRPO scenarios...")
|
| 658 |
-
grpo_scenarios = generate_grpo_scenarios()
|
| 659 |
-
total_scenarios = sum(len(v) for v in grpo_scenarios.values())
|
| 660 |
-
print(f" Total: {total_scenarios} scenarios")
|
| 661 |
-
|
| 662 |
-
# ── GRPO Stage 1: easy + medium ──
|
| 663 |
-
print("\n[6/7] GRPO Stage 1 (easy + medium)...")
|
| 664 |
-
stage1_data = build_grpo_prompts(grpo_scenarios, ["easy", "medium"])
|
| 665 |
-
print(f" Stage 1 prompts: {len(stage1_data)}")
|
| 666 |
-
|
| 667 |
-
FastLanguageModel.for_training(model)
|
| 668 |
-
grpo_s1 = CurriculumGRPOTrainer(
|
| 669 |
-
model=model,
|
| 670 |
-
tokenizer=tokenizer,
|
| 671 |
-
args=GRPOConfig(
|
| 672 |
-
output_dir="/tmp/vr_grpo_s1",
|
| 673 |
-
num_train_epochs=GRPO_STAGE1_EPOCHS,
|
| 674 |
-
per_device_train_batch_size=GRPO_BATCH_SIZE,
|
| 675 |
-
gradient_accumulation_steps=GRPO_GRAD_ACCUM,
|
| 676 |
-
num_generations=GRPO_NUM_GENERATIONS,
|
| 677 |
-
max_completion_length=384,
|
| 678 |
-
learning_rate=GRPO_LR,
|
| 679 |
-
lr_scheduler_type="cosine",
|
| 680 |
-
warmup_ratio=0.1,
|
| 681 |
-
beta=0.05,
|
| 682 |
-
max_grad_norm=0.5,
|
| 683 |
-
temperature=GRPO_TEMPERATURE,
|
| 684 |
-
logging_steps=1,
|
| 685 |
-
save_strategy="no",
|
| 686 |
-
bf16=True,
|
| 687 |
-
optim="adamw_8bit",
|
| 688 |
-
report_to="none",
|
| 689 |
-
remove_unused_columns=False,
|
| 690 |
-
),
|
| 691 |
-
train_dataset=stage1_data,
|
| 692 |
-
reward_funcs=make_reward_fn(),
|
| 693 |
-
)
|
| 694 |
-
|
| 695 |
-
t0 = time.time()
|
| 696 |
-
grpo_s1.train()
|
| 697 |
-
print(f" Stage 1 done in {time.time() - t0:.0f}s")
|
| 698 |
-
|
| 699 |
-
stage1_results = evaluate("Post-GRPO Stage 1", stages=["easy", "medium"])
|
| 700 |
-
|
| 701 |
-
# ── GRPO Stage 2: all difficulties ──
|
| 702 |
-
print("\n[7/7] GRPO Stage 2 (all difficulties)...")
|
| 703 |
-
stage2_data = build_grpo_prompts(grpo_scenarios, list(DIFFICULTIES))
|
| 704 |
-
print(f" Stage 2 prompts: {len(stage2_data)}")
|
| 705 |
-
|
| 706 |
-
FastLanguageModel.for_training(model)
|
| 707 |
-
grpo_s2 = CurriculumGRPOTrainer(
|
| 708 |
-
model=model,
|
| 709 |
-
tokenizer=tokenizer,
|
| 710 |
-
args=GRPOConfig(
|
| 711 |
-
output_dir="/tmp/vr_grpo_s2",
|
| 712 |
-
num_train_epochs=GRPO_STAGE2_EPOCHS,
|
| 713 |
-
per_device_train_batch_size=GRPO_BATCH_SIZE,
|
| 714 |
-
gradient_accumulation_steps=GRPO_GRAD_ACCUM,
|
| 715 |
-
num_generations=GRPO_NUM_GENERATIONS,
|
| 716 |
-
max_completion_length=384,
|
| 717 |
-
learning_rate=GRPO_LR * 0.5, # halved for stability with harder scenarios
|
| 718 |
-
lr_scheduler_type="cosine",
|
| 719 |
-
warmup_ratio=0.1,
|
| 720 |
-
beta=0.05,
|
| 721 |
-
max_grad_norm=0.5,
|
| 722 |
-
temperature=GRPO_TEMPERATURE,
|
| 723 |
-
logging_steps=1,
|
| 724 |
-
save_strategy="no",
|
| 725 |
-
bf16=True,
|
| 726 |
-
optim="adamw_8bit",
|
| 727 |
-
report_to="none",
|
| 728 |
-
remove_unused_columns=False,
|
| 729 |
-
),
|
| 730 |
-
train_dataset=stage2_data,
|
| 731 |
-
reward_funcs=make_reward_fn(),
|
| 732 |
-
)
|
| 733 |
-
|
| 734 |
-
t0 = time.time()
|
| 735 |
-
grpo_s2.train()
|
| 736 |
-
print(f" Stage 2 done in {time.time() - t0:.0f}s")
|
| 737 |
-
|
| 738 |
-
# ── Final eval + report ──
|
| 739 |
-
final_results = evaluate("Final (SFT + GRPO S1 + S2)")
|
| 740 |
-
|
| 741 |
-
print(f"\n{'=' * 72}")
|
| 742 |
-
print(" DELTA REPORT")
|
| 743 |
-
print(f"{'=' * 72}")
|
| 744 |
-
print(f" {'Difficulty':<12} {'Baseline':>10} {'SFT':>10} {'GRPO-S1':>10} {'Final':>10} {'Δ':>10}")
|
| 745 |
-
print(f" {'-'*12} {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
|
| 746 |
-
for diff in list(DIFFICULTIES) + ["overall"]:
|
| 747 |
-
b = baseline.get(diff, 0.0)
|
| 748 |
-
s = sft_results.get(diff, 0.0)
|
| 749 |
-
s1 = stage1_results.get(diff, 0.0)
|
| 750 |
-
f = final_results.get(diff, 0.0)
|
| 751 |
-
label = diff.upper() if diff == "overall" else diff
|
| 752 |
-
print(f" {label:<12} {b:>10.3f} {s:>10.3f} {s1:>10.3f} {f:>10.3f} {f - b:>+10.3f}")
|
| 753 |
-
print(f"{'=' * 72}")
|
| 754 |
-
|
| 755 |
-
# ── Push to hub ──
|
| 756 |
-
if HUB_REPO:
|
| 757 |
-
print(f"\nPushing LoRA adapter to {HUB_REPO}...")
|
| 758 |
-
model.push_to_hub(HUB_REPO, token=HF_TOKEN)
|
| 759 |
-
tokenizer.push_to_hub(HUB_REPO, token=HF_TOKEN)
|
| 760 |
-
print(f"Pushed: https://huggingface.co/{HUB_REPO}")
|
| 761 |
-
else:
|
| 762 |
-
print("\nHUB_REPO not set — saving locally to /tmp/vr_qwen3_8b_lora")
|
| 763 |
-
model.save_pretrained("/tmp/vr_qwen3_8b_lora")
|
| 764 |
-
tokenizer.save_pretrained("/tmp/vr_qwen3_8b_lora")
|
| 765 |
-
|
| 766 |
-
total_mins = (time.time() - job_start) / 60
|
| 767 |
-
print(f"\nJob finished in {total_mins:.1f} minutes.")
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
if __name__ == "__main__":
|
| 771 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv.lock
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
viewer/audio_viewer.html
DELETED
|
@@ -1,865 +0,0 @@
|
|
| 1 |
-
<!doctype html>
|
| 2 |
-
<html lang="en">
|
| 3 |
-
<head>
|
| 4 |
-
<meta charset="utf-8"/>
|
| 5 |
-
<title>Visual Reasoning — Audio Viewer</title>
|
| 6 |
-
<style>
|
| 7 |
-
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
| 8 |
-
|
| 9 |
-
:root {
|
| 10 |
-
--bg-dark: #1e1e2e;
|
| 11 |
-
--bg-panel: #181825;
|
| 12 |
-
--bg-surface: #11111b;
|
| 13 |
-
--border: #313244;
|
| 14 |
-
--text: #cdd6f4;
|
| 15 |
-
--text-dim: #6c7086;
|
| 16 |
-
--accent: #89b4fa;
|
| 17 |
-
--green: #a6e3a1;
|
| 18 |
-
--red: #f38ba8;
|
| 19 |
-
--yellow: #f9e2af;
|
| 20 |
-
--peach: #fab387;
|
| 21 |
-
--mauve: #cba6f7;
|
| 22 |
-
--teal: #94e2d5;
|
| 23 |
-
}
|
| 24 |
-
|
| 25 |
-
html, body { height: 100%; font-family: 'Inter', -apple-system, system-ui, sans-serif; background: var(--bg-dark); color: var(--text); overflow: hidden; }
|
| 26 |
-
|
| 27 |
-
#app { display: flex; flex-direction: column; height: 100vh; }
|
| 28 |
-
|
| 29 |
-
/* ---- Header ---- */
|
| 30 |
-
.header {
|
| 31 |
-
display: flex; align-items: center; gap: 12px;
|
| 32 |
-
padding: 6px 14px; background: var(--bg-panel); border-bottom: 1px solid var(--border);
|
| 33 |
-
font-size: 12px; flex-shrink: 0; min-height: 36px;
|
| 34 |
-
}
|
| 35 |
-
.header .badge { padding: 2px 8px; border-radius: 4px; font-weight: 600; font-size: 11px; }
|
| 36 |
-
.header .badge-task { background: rgba(137,180,250,0.15); color: var(--accent); }
|
| 37 |
-
.header .badge-scenario { color: var(--text-dim); font-size: 11px; }
|
| 38 |
-
.header .step-counter { color: var(--text); font-size: 11px; }
|
| 39 |
-
.header .score-display { font-weight: 700; color: var(--green); font-size: 14px; }
|
| 40 |
-
.header .spacer { flex: 1; }
|
| 41 |
-
.header .ws-dot { width: 7px; height: 7px; border-radius: 50%; }
|
| 42 |
-
.header .ws-dot.on { background: var(--green); }
|
| 43 |
-
.header .ws-dot.off { background: var(--red); }
|
| 44 |
-
|
| 45 |
-
.audio-btn {
|
| 46 |
-
background: none; border: 1px solid var(--border); color: var(--text-dim);
|
| 47 |
-
border-radius: 4px; padding: 2px 8px; font-size: 14px; cursor: pointer;
|
| 48 |
-
transition: all .2s;
|
| 49 |
-
}
|
| 50 |
-
.audio-btn:hover { border-color: var(--accent); color: var(--accent); }
|
| 51 |
-
.audio-btn.active { color: var(--accent); border-color: var(--accent); }
|
| 52 |
-
|
| 53 |
-
/* ---- Main layout ---- */
|
| 54 |
-
.main { display: flex; flex: 1; min-height: 0; overflow: hidden; }
|
| 55 |
-
|
| 56 |
-
/* ---- Canvas ---- */
|
| 57 |
-
.canvas-wrap {
|
| 58 |
-
flex: 1; min-width: 0; background: var(--bg-surface);
|
| 59 |
-
position: relative; overflow: hidden;
|
| 60 |
-
}
|
| 61 |
-
.canvas-wrap canvas { display: block; width: 100%; height: 100%; }
|
| 62 |
-
|
| 63 |
-
/* ---- Side panel ---- */
|
| 64 |
-
.side {
|
| 65 |
-
width: 270px; flex-shrink: 0; background: var(--bg-panel);
|
| 66 |
-
border-left: 1px solid var(--border);
|
| 67 |
-
display: flex; flex-direction: column; overflow: hidden;
|
| 68 |
-
}
|
| 69 |
-
.side-section {
|
| 70 |
-
padding: 8px 12px; border-bottom: 1px solid var(--border); flex-shrink: 0;
|
| 71 |
-
}
|
| 72 |
-
.side-section h3 {
|
| 73 |
-
font-size: 9px; text-transform: uppercase; letter-spacing: 1.2px;
|
| 74 |
-
color: var(--accent); margin-bottom: 4px; font-weight: 600;
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
-
/* Goal */
|
| 78 |
-
.goal-text {
|
| 79 |
-
font-size: 11px; line-height: 1.4; color: var(--text);
|
| 80 |
-
max-height: 48px; overflow: hidden; text-overflow: ellipsis;
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
/* Checklist */
|
| 84 |
-
.checklist-grid { display: flex; flex-wrap: wrap; gap: 2px 8px; }
|
| 85 |
-
.concept { display: flex; align-items: center; gap: 4px; padding: 1px 0; font-size: 11px; }
|
| 86 |
-
.concept .icon { width: 14px; text-align: center; font-size: 12px; }
|
| 87 |
-
.concept.covered .icon { color: var(--green); }
|
| 88 |
-
.concept.uncovered .icon { color: var(--text-dim); }
|
| 89 |
-
.concept.covered { color: var(--green); }
|
| 90 |
-
.concept.uncovered { color: var(--text-dim); }
|
| 91 |
-
|
| 92 |
-
/* Score bars */
|
| 93 |
-
.score-row { display: flex; align-items: center; gap: 4px; margin-bottom: 2px; font-size: 10px; }
|
| 94 |
-
.score-label { width: 78px; flex-shrink: 0; color: var(--text-dim); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
|
| 95 |
-
.score-track { flex: 1; height: 12px; background: var(--bg-surface); border-radius: 2px; overflow: hidden; }
|
| 96 |
-
.score-fill { height: 100%; border-radius: 2px; display: flex; align-items: center; padding-left: 3px; font-size: 8px; font-weight: 600; color: #fff; min-width: 20px; transition: width .3s; }
|
| 97 |
-
|
| 98 |
-
/* Narration log */
|
| 99 |
-
.narration-log { flex: 1; overflow-y: auto; padding: 8px 12px; min-height: 0; }
|
| 100 |
-
.narration-log h3 { font-size: 9px; text-transform: uppercase; letter-spacing: 1.2px; color: var(--accent); margin-bottom: 4px; font-weight: 600; }
|
| 101 |
-
.narr-entry { padding: 4px 0; border-bottom: 1px solid rgba(49,50,68,.5); font-size: 11px; animation: fadeIn .3s; }
|
| 102 |
-
.narr-entry .step-tag { font-weight: 600; color: var(--accent); font-size: 10px; }
|
| 103 |
-
.narr-entry .reward { font-size: 9px; font-weight: 700; margin-left: 4px; }
|
| 104 |
-
.narr-entry .reward.pos { color: var(--green); }
|
| 105 |
-
.narr-entry .reward.neg { color: var(--red); }
|
| 106 |
-
.narr-entry .text { color: var(--text); margin-top: 1px; line-height: 1.35; }
|
| 107 |
-
|
| 108 |
-
/* ---- Bottom narration bar ---- */
|
| 109 |
-
.narration-bar {
|
| 110 |
-
flex-shrink: 0; display: flex; align-items: center; gap: 10px;
|
| 111 |
-
padding: 8px 16px; background: var(--bg-panel); border-top: 1px solid var(--border);
|
| 112 |
-
min-height: 40px; max-height: 60px;
|
| 113 |
-
}
|
| 114 |
-
.narr-indicator {
|
| 115 |
-
width: 24px; height: 24px; display: flex; align-items: center; justify-content: center;
|
| 116 |
-
flex-shrink: 0; font-size: 16px; color: var(--text-dim); transition: color .3s;
|
| 117 |
-
}
|
| 118 |
-
.narr-indicator.speaking { color: var(--accent); }
|
| 119 |
-
.narr-indicator.speaking .bars { display: flex; align-items: flex-end; gap: 2px; height: 16px; }
|
| 120 |
-
.narr-indicator.speaking .bars span {
|
| 121 |
-
width: 3px; background: var(--accent); border-radius: 1px;
|
| 122 |
-
animation: barPulse .6s ease-in-out infinite alternate;
|
| 123 |
-
}
|
| 124 |
-
.narr-indicator.speaking .bars span:nth-child(2) { animation-delay: .15s; }
|
| 125 |
-
.narr-indicator.speaking .bars span:nth-child(3) { animation-delay: .3s; }
|
| 126 |
-
.narr-indicator.speaking .bars span:nth-child(4) { animation-delay: .45s; }
|
| 127 |
-
.narr-text {
|
| 128 |
-
flex: 1; font-size: 13px; color: var(--text); font-style: italic;
|
| 129 |
-
white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
|
| 130 |
-
transition: opacity .5s;
|
| 131 |
-
}
|
| 132 |
-
.narr-text.dim { opacity: .3; }
|
| 133 |
-
.music-badge {
|
| 134 |
-
flex-shrink: 0; font-size: 10px; color: var(--mauve); display: flex; align-items: center; gap: 4px;
|
| 135 |
-
opacity: 0; transition: opacity .5s;
|
| 136 |
-
}
|
| 137 |
-
.music-badge.on { opacity: 1; }
|
| 138 |
-
|
| 139 |
-
@keyframes barPulse {
|
| 140 |
-
0% { height: 4px; }
|
| 141 |
-
100% { height: 14px; }
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
/* ---- End overlay ---- */
|
| 145 |
-
.overlay {
|
| 146 |
-
position: fixed; inset: 0; background: rgba(17,17,27,.92);
|
| 147 |
-
display: flex; flex-direction: column; align-items: center; justify-content: center;
|
| 148 |
-
z-index: 1000; cursor: pointer;
|
| 149 |
-
}
|
| 150 |
-
.overlay.hidden { display: none; }
|
| 151 |
-
.overlay h1 { font-size: 24px; margin-bottom: 6px; }
|
| 152 |
-
.overlay h1.ok { color: var(--green); }
|
| 153 |
-
.overlay h1.fail { color: var(--red); }
|
| 154 |
-
.overlay .final { font-size: 16px; color: var(--text-dim); margin-bottom: 16px; }
|
| 155 |
-
.reward-chart { display: flex; align-items: flex-end; gap: 4px; height: 100px; }
|
| 156 |
-
.reward-chart .bar { width: 22px; border-radius: 3px 3px 0 0; min-height: 2px; position: relative; }
|
| 157 |
-
.reward-chart .bar .lbl { position: absolute; bottom: -14px; left: 50%; transform: translateX(-50%); font-size: 8px; color: var(--text-dim); }
|
| 158 |
-
|
| 159 |
-
/* Audio hint banner */
|
| 160 |
-
.audio-hint {
|
| 161 |
-
position: fixed; bottom: 60px; left: 50%; transform: translateX(-50%);
|
| 162 |
-
background: rgba(137,180,250,0.15); color: var(--accent); padding: 8px 20px;
|
| 163 |
-
border-radius: 8px; font-size: 13px; z-index: 2000; cursor: pointer;
|
| 164 |
-
border: 1px solid rgba(137,180,250,0.3); backdrop-filter: blur(8px);
|
| 165 |
-
animation: fadeIn .5s;
|
| 166 |
-
}
|
| 167 |
-
.audio-hint.hidden { display: none; }
|
| 168 |
-
|
| 169 |
-
/* Error toast */
|
| 170 |
-
.toast {
|
| 171 |
-
position: fixed; top: 46px; left: 50%; transform: translateX(-50%);
|
| 172 |
-
background: var(--red); color: #11111b; padding: 5px 16px;
|
| 173 |
-
border-radius: 6px; font-size: 11px; font-weight: 600; z-index: 999; display: none;
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
@keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } }
|
| 177 |
-
</style>
|
| 178 |
-
</head>
|
| 179 |
-
<body>
|
| 180 |
-
<div id="app">
|
| 181 |
-
<div class="header">
|
| 182 |
-
<span class="badge badge-task" id="h-task">--</span>
|
| 183 |
-
<span class="badge-scenario" id="h-scenario">--</span>
|
| 184 |
-
<span class="step-counter" id="h-step">Step 0 / 0</span>
|
| 185 |
-
<span class="spacer"></span>
|
| 186 |
-
<span class="score-display" id="h-score">0.000</span>
|
| 187 |
-
<button class="audio-btn active" id="btn-mute" title="Toggle audio">🔊</button>
|
| 188 |
-
<span class="ws-dot off" id="h-ws"></span>
|
| 189 |
-
</div>
|
| 190 |
-
<div class="main">
|
| 191 |
-
<div class="canvas-wrap">
|
| 192 |
-
<canvas id="canvas"></canvas>
|
| 193 |
-
</div>
|
| 194 |
-
<div class="side">
|
| 195 |
-
<div class="side-section" id="sec-goal"><h3>Goal</h3><div class="goal-text" id="goal-text">--</div></div>
|
| 196 |
-
<div class="side-section" id="sec-checklist"><h3>Concept Checklist</h3><div class="checklist-grid" id="checklist"></div></div>
|
| 197 |
-
<div class="side-section" id="sec-scores"><h3>Score Breakdown</h3><div id="scores"></div></div>
|
| 198 |
-
<div class="narration-log"><h3>Narration</h3><div id="narrations"></div></div>
|
| 199 |
-
</div>
|
| 200 |
-
</div>
|
| 201 |
-
<div class="narration-bar">
|
| 202 |
-
<div class="narr-indicator" id="narr-ind">
|
| 203 |
-
<div class="bars"><span></span><span></span><span></span><span></span></div>
|
| 204 |
-
</div>
|
| 205 |
-
<div class="narr-text dim" id="narr-text"></div>
|
| 206 |
-
<div class="music-badge" id="music-badge">♫ music</div>
|
| 207 |
-
</div>
|
| 208 |
-
</div>
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
<!-- End overlay -->
|
| 212 |
-
<div class="overlay hidden" id="overlay">
|
| 213 |
-
<h1 id="ov-title"></h1>
|
| 214 |
-
<div class="final" id="ov-score"></div>
|
| 215 |
-
<div class="reward-chart" id="ov-chart"></div>
|
| 216 |
-
</div>
|
| 217 |
-
<div class="audio-hint" id="audio-hint">Click here to enable audio narration</div>
|
| 218 |
-
<div class="toast" id="toast"></div>
|
| 219 |
-
|
| 220 |
-
<script>
|
| 221 |
-
/* ================================================================
|
| 222 |
-
Audio Viewer — Canvas2D renderer + TTS + background music
|
| 223 |
-
================================================================ */
|
| 224 |
-
|
| 225 |
-
const ROLE_COLORS = {
|
| 226 |
-
default: { fill: '#45475a', stroke: '#585b70', text: '#cdd6f4' },
|
| 227 |
-
current: { fill: '#89b4fa', stroke: '#74c7ec', text: '#1e1e2e' },
|
| 228 |
-
visited: { fill: '#a6e3a1', stroke: '#94e2d5', text: '#1e1e2e' },
|
| 229 |
-
frontier: { fill: '#fab387', stroke: '#f9e2af', text: '#1e1e2e' },
|
| 230 |
-
done: { fill: '#40a02b', stroke: '#a6e3a1', text: '#fff' },
|
| 231 |
-
pivot: { fill: '#f38ba8', stroke: '#eba0ac', text: '#1e1e2e' },
|
| 232 |
-
root: { fill: '#cba6f7', stroke: '#b4befe', text: '#1e1e2e' },
|
| 233 |
-
error: { fill: '#f38ba8', stroke: '#f38ba8', text: '#fff' },
|
| 234 |
-
inactive: { fill: '#313244', stroke: '#45475a', text: '#6c7086' },
|
| 235 |
-
comparing: { fill: '#f9e2af', stroke: '#fab387', text: '#1e1e2e' },
|
| 236 |
-
};
|
| 237 |
-
|
| 238 |
-
const GRID = 70;
|
| 239 |
-
const NODE_R = 26;
|
| 240 |
-
const ARROW_LEN = 10;
|
| 241 |
-
const FONT = '13px Inter, system-ui, sans-serif';
|
| 242 |
-
const FONT_BOLD = 'bold 14px Inter, system-ui, sans-serif';
|
| 243 |
-
const FONT_SMALL = '10px Inter, system-ui, sans-serif';
|
| 244 |
-
const FONT_ANN = '11px Inter, system-ui, sans-serif';
|
| 245 |
-
|
| 246 |
-
let S = {
|
| 247 |
-
entities: {}, relations: [], layout: {}, annotations: [], notes: [],
|
| 248 |
-
taskName: '', scenarioId: '', goal: '', checklist: [], coverage: [],
|
| 249 |
-
maxSteps: 0, step: 0, score: 0, breakdown: {}, narrations: [],
|
| 250 |
-
};
|
| 251 |
-
|
| 252 |
-
let cam = { ox: 40, oy: 30, scale: 1 };
|
| 253 |
-
let canvas, ctx;
|
| 254 |
-
let toastTimer;
|
| 255 |
-
|
| 256 |
-
const $ = id => document.getElementById(id);
|
| 257 |
-
|
| 258 |
-
/* ---- Audio system ---- */
|
| 259 |
-
let audioCtx = null;
|
| 260 |
-
let bgMusic = null;
|
| 261 |
-
let currentTTS = null;
|
| 262 |
-
let isMuted = false;
|
| 263 |
-
let audioUnlocked = false;
|
| 264 |
-
let bgMusicWanted = false;
|
| 265 |
-
let pendingAudioQueue = [];
|
| 266 |
-
const BG_VOLUME = 0.06;
|
| 267 |
-
const BG_DUCK_VOLUME = 0.02;
|
| 268 |
-
|
| 269 |
-
// Create AudioContext eagerly (starts suspended until user gesture)
|
| 270 |
-
try { audioCtx = new (window.AudioContext || window.webkitAudioContext)(); } catch {}
|
| 271 |
-
bgMusic = new Audio('/audio/background.mp3');
|
| 272 |
-
bgMusic.loop = true;
|
| 273 |
-
bgMusic.volume = BG_VOLUME;
|
| 274 |
-
bgMusic.addEventListener('error', () => console.warn('Background music not available'));
|
| 275 |
-
|
| 276 |
-
function unlockAudio() {
|
| 277 |
-
if (audioUnlocked) return;
|
| 278 |
-
audioUnlocked = true;
|
| 279 |
-
$('audio-hint').classList.add('hidden');
|
| 280 |
-
if (audioCtx && audioCtx.state === 'suspended') audioCtx.resume();
|
| 281 |
-
// play+pause to unlock the HTML Audio element for later use
|
| 282 |
-
bgMusic.play().then(() => {
|
| 283 |
-
if (!bgMusicWanted) { bgMusic.pause(); bgMusic.currentTime = 0; }
|
| 284 |
-
else { bgMusic.volume = isMuted ? 0 : BG_VOLUME; $('music-badge').classList.add('on'); }
|
| 285 |
-
}).catch(() => {});
|
| 286 |
-
// flush queued TTS
|
| 287 |
-
if (pendingAudioQueue.length) {
|
| 288 |
-
for (const b64 of pendingAudioQueue) playTTS(b64);
|
| 289 |
-
pendingAudioQueue.length = 0;
|
| 290 |
-
}
|
| 291 |
-
}
|
| 292 |
-
|
| 293 |
-
// Unlock on any user interaction (click, touch, keydown)
|
| 294 |
-
document.addEventListener('click', unlockAudio, { once: true });
|
| 295 |
-
document.addEventListener('touchstart', unlockAudio, { once: true });
|
| 296 |
-
document.addEventListener('keydown', unlockAudio, { once: true });
|
| 297 |
-
|
| 298 |
-
function startBgMusic() {
|
| 299 |
-
bgMusicWanted = true;
|
| 300 |
-
if (!audioUnlocked || isMuted) return;
|
| 301 |
-
bgMusic.currentTime = 0;
|
| 302 |
-
bgMusic.volume = BG_VOLUME;
|
| 303 |
-
bgMusic.play().catch(() => {});
|
| 304 |
-
$('music-badge').classList.add('on');
|
| 305 |
-
}
|
| 306 |
-
|
| 307 |
-
function stopBgMusic() {
|
| 308 |
-
bgMusicWanted = false;
|
| 309 |
-
bgMusic.pause();
|
| 310 |
-
bgMusic.currentTime = 0;
|
| 311 |
-
$('music-badge').classList.remove('on');
|
| 312 |
-
}
|
| 313 |
-
|
| 314 |
-
function duckBgMusic() {
|
| 315 |
-
bgMusic.volume = isMuted ? 0 : BG_DUCK_VOLUME;
|
| 316 |
-
}
|
| 317 |
-
|
| 318 |
-
function unduckBgMusic() {
|
| 319 |
-
bgMusic.volume = isMuted ? 0 : BG_VOLUME;
|
| 320 |
-
}
|
| 321 |
-
|
| 322 |
-
async function playTTS(base64Wav) {
|
| 323 |
-
if (!base64Wav || !audioCtx || isMuted) return;
|
| 324 |
-
if (!audioUnlocked) { pendingAudioQueue.push(base64Wav); return; }
|
| 325 |
-
if (currentTTS) { try { currentTTS.stop(); } catch {} currentTTS = null; }
|
| 326 |
-
|
| 327 |
-
try {
|
| 328 |
-
if (audioCtx.state === 'suspended') await audioCtx.resume();
|
| 329 |
-
const binary = atob(base64Wav);
|
| 330 |
-
const buffer = new ArrayBuffer(binary.length);
|
| 331 |
-
const view = new Uint8Array(buffer);
|
| 332 |
-
for (let i = 0; i < binary.length; i++) view[i] = binary.charCodeAt(i);
|
| 333 |
-
|
| 334 |
-
const audioBuffer = await audioCtx.decodeAudioData(buffer);
|
| 335 |
-
const source = audioCtx.createBufferSource();
|
| 336 |
-
const gain = audioCtx.createGain();
|
| 337 |
-
gain.gain.value = 1.0;
|
| 338 |
-
source.buffer = audioBuffer;
|
| 339 |
-
source.connect(gain).connect(audioCtx.destination);
|
| 340 |
-
source.start();
|
| 341 |
-
currentTTS = source;
|
| 342 |
-
|
| 343 |
-
duckBgMusic();
|
| 344 |
-
$('narr-ind').classList.add('speaking');
|
| 345 |
-
|
| 346 |
-
source.onended = () => {
|
| 347 |
-
currentTTS = null;
|
| 348 |
-
unduckBgMusic();
|
| 349 |
-
$('narr-ind').classList.remove('speaking');
|
| 350 |
-
$('narr-text').classList.add('dim');
|
| 351 |
-
};
|
| 352 |
-
} catch (e) {
|
| 353 |
-
console.warn('TTS playback failed:', e);
|
| 354 |
-
$('narr-ind').classList.remove('speaking');
|
| 355 |
-
}
|
| 356 |
-
}
|
| 357 |
-
|
| 358 |
-
function toggleMute() {
|
| 359 |
-
isMuted = !isMuted;
|
| 360 |
-
const btn = $('btn-mute');
|
| 361 |
-
btn.textContent = isMuted ? '\u{1f507}' : '\u{1f50a}';
|
| 362 |
-
btn.classList.toggle('active', !isMuted);
|
| 363 |
-
bgMusic.volume = isMuted ? 0 : BG_VOLUME;
|
| 364 |
-
if (currentTTS && isMuted) { try { currentTTS.stop(); } catch {} currentTTS = null; }
|
| 365 |
-
}
|
| 366 |
-
|
| 367 |
-
/* ---- Coordinate helpers ---- */
|
| 368 |
-
function wx(ex) { return cam.ox + ex * GRID * cam.scale; }
|
| 369 |
-
function wy(ey) { return cam.oy + ey * GRID * cam.scale; }
|
| 370 |
-
function pos(eid) {
|
| 371 |
-
const p = S.layout[eid];
|
| 372 |
-
if (!p) return null;
|
| 373 |
-
return { x: wx(Number(p.x)), y: wy(Number(p.y)) };
|
| 374 |
-
}
|
| 375 |
-
|
| 376 |
-
/* ---- Auto-fit camera ---- */
|
| 377 |
-
function autoFit() {
|
| 378 |
-
const wrap = document.querySelector('.canvas-wrap');
|
| 379 |
-
const cw = wrap.clientWidth;
|
| 380 |
-
const ch = wrap.clientHeight;
|
| 381 |
-
|
| 382 |
-
let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity;
|
| 383 |
-
let hasPoints = false;
|
| 384 |
-
|
| 385 |
-
for (const [eid, ent] of Object.entries(S.entities)) {
|
| 386 |
-
if (ent && ent.entity_type === 'region' && ent.bounds) {
|
| 387 |
-
minX = Math.min(minX, Number(ent.bounds.x0));
|
| 388 |
-
maxX = Math.max(maxX, Number(ent.bounds.x1));
|
| 389 |
-
minY = Math.min(minY, Number(ent.bounds.y0));
|
| 390 |
-
maxY = Math.max(maxY, Number(ent.bounds.y1));
|
| 391 |
-
hasPoints = true;
|
| 392 |
-
}
|
| 393 |
-
const p = S.layout[eid];
|
| 394 |
-
if (p) {
|
| 395 |
-
minX = Math.min(minX, Number(p.x));
|
| 396 |
-
maxX = Math.max(maxX, Number(p.x));
|
| 397 |
-
minY = Math.min(minY, Number(p.y));
|
| 398 |
-
maxY = Math.max(maxY, Number(p.y));
|
| 399 |
-
hasPoints = true;
|
| 400 |
-
}
|
| 401 |
-
}
|
| 402 |
-
|
| 403 |
-
if (!hasPoints) { cam = { ox: cw / 2, oy: ch / 2, scale: 1 }; return; }
|
| 404 |
-
if (minX === maxX && minY === maxY) {
|
| 405 |
-
cam.scale = 1.5;
|
| 406 |
-
cam.ox = cw / 2 - minX * GRID * cam.scale;
|
| 407 |
-
cam.oy = ch / 2 - minY * GRID * cam.scale;
|
| 408 |
-
return;
|
| 409 |
-
}
|
| 410 |
-
const spanX = (maxX - minX) || 1;
|
| 411 |
-
const spanY = (maxY - minY) || 1;
|
| 412 |
-
const pad = 50;
|
| 413 |
-
const sx = (cw - pad * 2) / (spanX * GRID);
|
| 414 |
-
const sy = (ch - pad * 2) / (spanY * GRID);
|
| 415 |
-
cam.scale = Math.min(sx, sy, 2.2);
|
| 416 |
-
const midX = (minX + maxX) / 2;
|
| 417 |
-
const midY = (minY + maxY) / 2;
|
| 418 |
-
cam.ox = cw / 2 - midX * GRID * cam.scale;
|
| 419 |
-
cam.oy = ch / 2 - midY * GRID * cam.scale;
|
| 420 |
-
}
|
| 421 |
-
|
| 422 |
-
/* ---- Drawing primitives ---- */
|
| 423 |
-
function drawRoundRect(x, y, w, h, r) {
|
| 424 |
-
ctx.beginPath();
|
| 425 |
-
ctx.moveTo(x + r, y);
|
| 426 |
-
ctx.lineTo(x + w - r, y);
|
| 427 |
-
ctx.quadraticCurveTo(x + w, y, x + w, y + r);
|
| 428 |
-
ctx.lineTo(x + w, y + h - r);
|
| 429 |
-
ctx.quadraticCurveTo(x + w, y + h, x + w - r, y + h);
|
| 430 |
-
ctx.lineTo(x + r, y + h);
|
| 431 |
-
ctx.quadraticCurveTo(x, y + h, x, y + h - r);
|
| 432 |
-
ctx.lineTo(x, y + r);
|
| 433 |
-
ctx.quadraticCurveTo(x, y, x + r, y);
|
| 434 |
-
ctx.closePath();
|
| 435 |
-
}
|
| 436 |
-
|
| 437 |
-
function drawArrowhead(fx, fy, tx, ty) {
|
| 438 |
-
const angle = Math.atan2(ty - fy, tx - fx);
|
| 439 |
-
const s = ARROW_LEN * cam.scale;
|
| 440 |
-
ctx.beginPath();
|
| 441 |
-
ctx.moveTo(tx, ty);
|
| 442 |
-
ctx.lineTo(tx - s * Math.cos(angle - Math.PI / 7), ty - s * Math.sin(angle - Math.PI / 7));
|
| 443 |
-
ctx.lineTo(tx - s * Math.cos(angle + Math.PI / 7), ty - s * Math.sin(angle + Math.PI / 7));
|
| 444 |
-
ctx.closePath();
|
| 445 |
-
ctx.fill();
|
| 446 |
-
}
|
| 447 |
-
|
| 448 |
-
/* ---- Main render ---- */
|
| 449 |
-
function render() {
|
| 450 |
-
const dpr = window.devicePixelRatio || 1;
|
| 451 |
-
const rect = canvas.getBoundingClientRect();
|
| 452 |
-
canvas.width = rect.width * dpr;
|
| 453 |
-
canvas.height = rect.height * dpr;
|
| 454 |
-
ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
|
| 455 |
-
ctx.clearRect(0, 0, rect.width, rect.height);
|
| 456 |
-
|
| 457 |
-
const highlighted = new Set();
|
| 458 |
-
for (const a of S.annotations) {
|
| 459 |
-
if (a && a.text === '[highlight]') highlighted.add(a.target_id);
|
| 460 |
-
}
|
| 461 |
-
|
| 462 |
-
// Regions
|
| 463 |
-
for (const [eid, ent] of Object.entries(S.entities)) {
|
| 464 |
-
if (!ent || ent.entity_type !== 'region') continue;
|
| 465 |
-
const b = ent.bounds;
|
| 466 |
-
if (!b) continue;
|
| 467 |
-
const x0 = wx(Number(b.x0)), y0 = wy(Number(b.y0));
|
| 468 |
-
const x1 = wx(Number(b.x1)), y1 = wy(Number(b.y1));
|
| 469 |
-
ctx.save();
|
| 470 |
-
ctx.strokeStyle = 'rgba(137,180,250,0.2)';
|
| 471 |
-
ctx.lineWidth = 1;
|
| 472 |
-
ctx.setLineDash([6, 4]);
|
| 473 |
-
drawRoundRect(x0, y0, x1 - x0, y1 - y0, 8);
|
| 474 |
-
ctx.stroke();
|
| 475 |
-
ctx.setLineDash([]);
|
| 476 |
-
ctx.font = FONT_SMALL;
|
| 477 |
-
ctx.fillStyle = 'rgba(137,180,250,0.45)';
|
| 478 |
-
ctx.fillText(ent.title || eid, x0 + 6, y0 + 13);
|
| 479 |
-
ctx.restore();
|
| 480 |
-
}
|
| 481 |
-
|
| 482 |
-
// Edges
|
| 483 |
-
for (const rel of S.relations) {
|
| 484 |
-
if (!rel || !rel.src || !rel.dst) continue;
|
| 485 |
-
const sp = pos(rel.src), dp = pos(rel.dst);
|
| 486 |
-
if (!sp || !dp) continue;
|
| 487 |
-
const r = NODE_R * cam.scale;
|
| 488 |
-
const dx = dp.x - sp.x, dy = dp.y - sp.y;
|
| 489 |
-
const dist = Math.hypot(dx, dy) || 1;
|
| 490 |
-
const ux = dx / dist, uy = dy / dist;
|
| 491 |
-
const sx = sp.x + ux * r, sy = sp.y + uy * r;
|
| 492 |
-
const ex = dp.x - ux * r, ey = dp.y - uy * r;
|
| 493 |
-
|
| 494 |
-
ctx.save();
|
| 495 |
-
ctx.strokeStyle = '#585b70';
|
| 496 |
-
ctx.lineWidth = 1.5 * cam.scale;
|
| 497 |
-
ctx.beginPath();
|
| 498 |
-
ctx.moveTo(sx, sy);
|
| 499 |
-
ctx.lineTo(ex, ey);
|
| 500 |
-
ctx.stroke();
|
| 501 |
-
ctx.fillStyle = '#585b70';
|
| 502 |
-
drawArrowhead(sx, sy, ex, ey);
|
| 503 |
-
if (rel.label) {
|
| 504 |
-
const mx = (sx + ex) / 2, my = (sy + ey) / 2;
|
| 505 |
-
ctx.font = FONT_SMALL;
|
| 506 |
-
ctx.fillStyle = '#bac2de';
|
| 507 |
-
ctx.textAlign = 'center';
|
| 508 |
-
ctx.fillText(rel.label, mx, my - 6 * cam.scale);
|
| 509 |
-
}
|
| 510 |
-
ctx.restore();
|
| 511 |
-
}
|
| 512 |
-
|
| 513 |
-
// Entities
|
| 514 |
-
for (const [eid, ent] of Object.entries(S.entities)) {
|
| 515 |
-
if (!ent) continue;
|
| 516 |
-
const type = ent.entity_type || 'node';
|
| 517 |
-
if (type === 'region') continue;
|
| 518 |
-
const p = pos(eid);
|
| 519 |
-
if (!p) continue;
|
| 520 |
-
const role = ent.role || 'default';
|
| 521 |
-
const colors = ROLE_COLORS[role] || ROLE_COLORS.default;
|
| 522 |
-
const r = NODE_R * cam.scale;
|
| 523 |
-
const isHl = highlighted.has(eid);
|
| 524 |
-
|
| 525 |
-
ctx.save();
|
| 526 |
-
if (isHl) { ctx.shadowColor = '#f9e2af'; ctx.shadowBlur = 18 * cam.scale; }
|
| 527 |
-
|
| 528 |
-
if (type === 'pointer') {
|
| 529 |
-
ctx.beginPath();
|
| 530 |
-
ctx.moveTo(p.x - 8 * cam.scale, p.y - 10 * cam.scale);
|
| 531 |
-
ctx.lineTo(p.x + 12 * cam.scale, p.y);
|
| 532 |
-
ctx.lineTo(p.x - 8 * cam.scale, p.y + 10 * cam.scale);
|
| 533 |
-
ctx.closePath();
|
| 534 |
-
ctx.fillStyle = colors.fill;
|
| 535 |
-
ctx.fill();
|
| 536 |
-
ctx.strokeStyle = colors.stroke;
|
| 537 |
-
ctx.lineWidth = 1.5 * cam.scale;
|
| 538 |
-
ctx.stroke();
|
| 539 |
-
ctx.shadowBlur = 0;
|
| 540 |
-
ctx.font = FONT_SMALL;
|
| 541 |
-
ctx.fillStyle = '#cdd6f4';
|
| 542 |
-
ctx.textAlign = 'left';
|
| 543 |
-
const label = eid;
|
| 544 |
-
const val = ent.value != null ? ` -> ${ent.value}` : '';
|
| 545 |
-
ctx.fillText(label + val, p.x + 16 * cam.scale, p.y + 4 * cam.scale);
|
| 546 |
-
|
| 547 |
-
} else if (type === 'container') {
|
| 548 |
-
const w = 90 * cam.scale, h = 44 * cam.scale;
|
| 549 |
-
drawRoundRect(p.x - w / 2, p.y - h / 2, w, h, 6 * cam.scale);
|
| 550 |
-
ctx.fillStyle = 'rgba(69,90,100,0.25)';
|
| 551 |
-
ctx.fill();
|
| 552 |
-
ctx.strokeStyle = colors.stroke;
|
| 553 |
-
ctx.lineWidth = 1.5 * cam.scale;
|
| 554 |
-
ctx.setLineDash([4, 3]);
|
| 555 |
-
ctx.stroke();
|
| 556 |
-
ctx.setLineDash([]);
|
| 557 |
-
ctx.shadowBlur = 0;
|
| 558 |
-
ctx.font = FONT_SMALL;
|
| 559 |
-
ctx.fillStyle = '#89b4fa';
|
| 560 |
-
ctx.textAlign = 'center';
|
| 561 |
-
ctx.fillText(ent.title || eid, p.x, p.y - h / 2 - 5 * cam.scale);
|
| 562 |
-
const contents = (ent.contents || []).join(', ');
|
| 563 |
-
if (contents) {
|
| 564 |
-
ctx.font = FONT_SMALL;
|
| 565 |
-
ctx.fillStyle = '#cdd6f4';
|
| 566 |
-
ctx.fillText(contents, p.x, p.y + 3 * cam.scale);
|
| 567 |
-
}
|
| 568 |
-
|
| 569 |
-
} else {
|
| 570 |
-
ctx.beginPath();
|
| 571 |
-
ctx.arc(p.x, p.y, r, 0, Math.PI * 2);
|
| 572 |
-
ctx.fillStyle = colors.fill;
|
| 573 |
-
ctx.fill();
|
| 574 |
-
ctx.strokeStyle = colors.stroke;
|
| 575 |
-
ctx.lineWidth = 2 * cam.scale;
|
| 576 |
-
ctx.stroke();
|
| 577 |
-
if (role === 'error') {
|
| 578 |
-
ctx.beginPath();
|
| 579 |
-
ctx.moveTo(p.x - r * .6, p.y - r * .6);
|
| 580 |
-
ctx.lineTo(p.x + r * .6, p.y + r * .6);
|
| 581 |
-
ctx.strokeStyle = '#fff';
|
| 582 |
-
ctx.lineWidth = 2 * cam.scale;
|
| 583 |
-
ctx.stroke();
|
| 584 |
-
}
|
| 585 |
-
ctx.shadowBlur = 0;
|
| 586 |
-
const valText = ent.value != null ? String(ent.value) : eid;
|
| 587 |
-
ctx.font = FONT_BOLD;
|
| 588 |
-
ctx.fillStyle = colors.text;
|
| 589 |
-
ctx.textAlign = 'center';
|
| 590 |
-
ctx.textBaseline = 'middle';
|
| 591 |
-
ctx.fillText(valText.length > 6 ? valText.slice(0, 6) : valText, p.x, p.y);
|
| 592 |
-
ctx.textBaseline = 'alphabetic';
|
| 593 |
-
if (role !== 'default') {
|
| 594 |
-
ctx.font = FONT_SMALL;
|
| 595 |
-
ctx.fillStyle = colors.stroke;
|
| 596 |
-
ctx.fillText(role, p.x, p.y + r + 12 * cam.scale);
|
| 597 |
-
}
|
| 598 |
-
}
|
| 599 |
-
ctx.restore();
|
| 600 |
-
}
|
| 601 |
-
|
| 602 |
-
// Annotations
|
| 603 |
-
const annMap = new Map();
|
| 604 |
-
for (const a of S.annotations) {
|
| 605 |
-
if (!a || a.text === '[highlight]' || a.text === '[popped]') continue;
|
| 606 |
-
annMap.set(a.target_id, a.text);
|
| 607 |
-
}
|
| 608 |
-
for (const [tid, text] of annMap) {
|
| 609 |
-
const p = pos(tid);
|
| 610 |
-
if (!p) continue;
|
| 611 |
-
const r = NODE_R * cam.scale;
|
| 612 |
-
ctx.save();
|
| 613 |
-
ctx.font = FONT_ANN;
|
| 614 |
-
const tw = ctx.measureText(text).width;
|
| 615 |
-
const px = p.x - tw / 2 - 4, py = p.y - r - 14 * cam.scale;
|
| 616 |
-
drawRoundRect(px, py - 10, tw + 8, 16, 4);
|
| 617 |
-
ctx.fillStyle = 'rgba(250,179,135,0.15)';
|
| 618 |
-
ctx.fill();
|
| 619 |
-
ctx.strokeStyle = 'rgba(250,179,135,0.4)';
|
| 620 |
-
ctx.lineWidth = 0.5;
|
| 621 |
-
ctx.stroke();
|
| 622 |
-
ctx.fillStyle = '#fab387';
|
| 623 |
-
ctx.textAlign = 'center';
|
| 624 |
-
ctx.fillText(text, p.x, py);
|
| 625 |
-
ctx.restore();
|
| 626 |
-
}
|
| 627 |
-
|
| 628 |
-
// Notes
|
| 629 |
-
for (const note of S.notes) {
|
| 630 |
-
if (!note || !note.text) continue;
|
| 631 |
-
let nx = 16, ny = rect.height - 24;
|
| 632 |
-
const regionEnt = note.region && S.entities[note.region];
|
| 633 |
-
if (regionEnt && regionEnt.bounds) {
|
| 634 |
-
nx = wx(Number(regionEnt.bounds.x0)) + 6;
|
| 635 |
-
ny = wy(Number(regionEnt.bounds.y1)) - 6;
|
| 636 |
-
}
|
| 637 |
-
ctx.save();
|
| 638 |
-
ctx.font = FONT_SMALL;
|
| 639 |
-
const t = note.text.length > 50 ? note.text.slice(0, 50) + '…' : note.text;
|
| 640 |
-
ctx.fillStyle = 'rgba(108,112,134,0.7)';
|
| 641 |
-
ctx.textAlign = 'left';
|
| 642 |
-
ctx.fillText(t, nx, ny);
|
| 643 |
-
ctx.restore();
|
| 644 |
-
}
|
| 645 |
-
}
|
| 646 |
-
|
| 647 |
-
/* ---- UI updates ---- */
|
| 648 |
-
function barColor(v) { return v >= .7 ? 'var(--green)' : v >= .4 ? 'var(--yellow)' : 'var(--red)'; }
|
| 649 |
-
function penaltyColor(v) { return v <= 0 ? 'var(--green)' : v < .1 ? 'var(--yellow)' : 'var(--red)'; }
|
| 650 |
-
|
| 651 |
-
function updateHeader() {
|
| 652 |
-
$('h-task').textContent = S.taskName || '--';
|
| 653 |
-
$('h-scenario').textContent = S.scenarioId || '';
|
| 654 |
-
$('h-step').textContent = `Step ${S.step} / ${S.maxSteps}`;
|
| 655 |
-
$('h-score').textContent = S.score.toFixed(3);
|
| 656 |
-
}
|
| 657 |
-
|
| 658 |
-
function updateChecklist() {
|
| 659 |
-
const el = $('checklist');
|
| 660 |
-
el.innerHTML = '';
|
| 661 |
-
for (const c of S.checklist) {
|
| 662 |
-
const ok = S.coverage.includes(c);
|
| 663 |
-
el.innerHTML += `<div class="concept ${ok ? 'covered' : 'uncovered'}"><span class="icon">${ok ? '✓' : '○'}</span>${c}</div>`;
|
| 664 |
-
}
|
| 665 |
-
}
|
| 666 |
-
|
| 667 |
-
function updateScores() {
|
| 668 |
-
const el = $('scores');
|
| 669 |
-
el.innerHTML = '';
|
| 670 |
-
const bd = S.breakdown || {};
|
| 671 |
-
const allKeys = Object.keys(bd).filter(k => k !== 'overall_score' && k !== 'phase' && typeof bd[k] === 'number');
|
| 672 |
-
const subKeys = allKeys.filter(k => !k.startsWith('penalty_'));
|
| 673 |
-
const penKeys = allKeys.filter(k => k.startsWith('penalty_'));
|
| 674 |
-
|
| 675 |
-
for (const k of subKeys) {
|
| 676 |
-
const v = Number(bd[k]) || 0;
|
| 677 |
-
const pct = Math.max(0, Math.min(100, v * 100));
|
| 678 |
-
const label = k.replace(/_score$/, '').replace(/_/g, ' ');
|
| 679 |
-
el.innerHTML += `<div class="score-row"><span class="score-label">${label}</span><div class="score-track"><div class="score-fill" style="width:${pct}%;background:${barColor(v)}">${v.toFixed(2)}</div></div></div>`;
|
| 680 |
-
}
|
| 681 |
-
if (penKeys.length) {
|
| 682 |
-
el.innerHTML += `<div style="margin-top:4px;margin-bottom:3px;font-size:9px;text-transform:uppercase;letter-spacing:1px;color:var(--red);font-weight:600">Penalties (lower = better)</div>`;
|
| 683 |
-
for (const k of penKeys) {
|
| 684 |
-
const v = Number(bd[k]) || 0;
|
| 685 |
-
const pct = Math.max(0, Math.min(100, v * 100));
|
| 686 |
-
const label = k.replace(/^penalty_/, '').replace(/_/g, ' ');
|
| 687 |
-
el.innerHTML += `<div class="score-row"><span class="score-label">${label}</span><div class="score-track"><div class="score-fill" style="width:${pct}%;background:${penaltyColor(v)}">${v.toFixed(2)}</div></div></div>`;
|
| 688 |
-
}
|
| 689 |
-
}
|
| 690 |
-
if (bd.overall_score != null) {
|
| 691 |
-
const v = Number(bd.overall_score);
|
| 692 |
-
const pct = Math.max(0, Math.min(100, v * 100));
|
| 693 |
-
el.innerHTML += `<div class="score-row" style="margin-top:3px;border-top:1px solid var(--border);padding-top:3px"><span class="score-label" style="font-weight:700">overall</span><div class="score-track"><div class="score-fill" style="width:${pct}%;background:${barColor(v)}">${v.toFixed(3)}</div></div></div>`;
|
| 694 |
-
}
|
| 695 |
-
}
|
| 696 |
-
|
| 697 |
-
function addNarration(step, text, reward) {
|
| 698 |
-
const el = $('narrations');
|
| 699 |
-
const cls = reward >= 0 ? 'pos' : 'neg';
|
| 700 |
-
const sign = reward >= 0 ? '+' : '';
|
| 701 |
-
el.innerHTML += `<div class="narr-entry"><span class="step-tag">Step ${step}</span><span class="reward ${cls}">${sign}${reward.toFixed(2)}</span><div class="text">${text || ''}</div></div>`;
|
| 702 |
-
el.parentElement.scrollTop = el.parentElement.scrollHeight;
|
| 703 |
-
}
|
| 704 |
-
|
| 705 |
-
function showNarration(text) {
|
| 706 |
-
const el = $('narr-text');
|
| 707 |
-
el.textContent = text || '';
|
| 708 |
-
el.classList.remove('dim');
|
| 709 |
-
}
|
| 710 |
-
|
| 711 |
-
function showToast(text) {
|
| 712 |
-
const el = $('toast');
|
| 713 |
-
el.textContent = text;
|
| 714 |
-
el.style.display = 'block';
|
| 715 |
-
clearTimeout(toastTimer);
|
| 716 |
-
toastTimer = setTimeout(() => el.style.display = 'none', 4000);
|
| 717 |
-
}
|
| 718 |
-
|
| 719 |
-
/* ---- Message handlers ---- */
|
| 720 |
-
function onClear() {
|
| 721 |
-
S.entities = {}; S.relations = []; S.layout = {};
|
| 722 |
-
S.annotations = []; S.notes = [];
|
| 723 |
-
S.taskName = ''; S.scenarioId = ''; S.goal = '';
|
| 724 |
-
S.checklist = []; S.coverage = [];
|
| 725 |
-
S.maxSteps = 0; S.step = 0; S.score = 0;
|
| 726 |
-
S.breakdown = {}; S.narrations = [];
|
| 727 |
-
updateHeader(); updateChecklist(); updateScores();
|
| 728 |
-
$('narrations').innerHTML = '';
|
| 729 |
-
$('goal-text').textContent = '--';
|
| 730 |
-
$('overlay').classList.add('hidden');
|
| 731 |
-
showNarration('');
|
| 732 |
-
stopBgMusic();
|
| 733 |
-
autoFit(); render();
|
| 734 |
-
}
|
| 735 |
-
|
| 736 |
-
function onReset(m) {
|
| 737 |
-
$('overlay').classList.add('hidden');
|
| 738 |
-
S.taskName = m.task_name || '';
|
| 739 |
-
S.scenarioId = m.scenario_id || '';
|
| 740 |
-
S.goal = m.goal || '';
|
| 741 |
-
S.checklist = m.checklist || [];
|
| 742 |
-
S.coverage = m.coverage || [];
|
| 743 |
-
S.maxSteps = m.max_steps || 0;
|
| 744 |
-
S.step = 0;
|
| 745 |
-
S.score = 0;
|
| 746 |
-
S.breakdown = m.score_breakdown || {};
|
| 747 |
-
S.entities = m.entities || {};
|
| 748 |
-
S.relations = m.relations || [];
|
| 749 |
-
S.layout = m.layout || {};
|
| 750 |
-
S.annotations = m.annotations || [];
|
| 751 |
-
S.notes = m.notes || [];
|
| 752 |
-
S.narrations = [];
|
| 753 |
-
|
| 754 |
-
updateHeader();
|
| 755 |
-
updateChecklist();
|
| 756 |
-
updateScores();
|
| 757 |
-
$('narrations').innerHTML = '';
|
| 758 |
-
$('goal-text').textContent = S.goal;
|
| 759 |
-
showNarration(S.goal);
|
| 760 |
-
autoFit();
|
| 761 |
-
render();
|
| 762 |
-
|
| 763 |
-
startBgMusic();
|
| 764 |
-
if (m.audio) playTTS(m.audio);
|
| 765 |
-
}
|
| 766 |
-
|
| 767 |
-
function onStep(m) {
|
| 768 |
-
S.step = m.step || S.step;
|
| 769 |
-
S.score = m.score != null ? m.score : S.score;
|
| 770 |
-
S.breakdown = m.score_breakdown || S.breakdown;
|
| 771 |
-
S.coverage = m.coverage || S.coverage;
|
| 772 |
-
S.entities = m.entities || S.entities;
|
| 773 |
-
S.relations = m.relations || S.relations;
|
| 774 |
-
S.layout = m.layout || S.layout;
|
| 775 |
-
S.annotations = m.annotations || S.annotations;
|
| 776 |
-
S.notes = m.notes || S.notes;
|
| 777 |
-
S.maxSteps = S.step + (m.remaining_step_budget || 0);
|
| 778 |
-
|
| 779 |
-
updateHeader();
|
| 780 |
-
updateChecklist();
|
| 781 |
-
updateScores();
|
| 782 |
-
autoFit();
|
| 783 |
-
render();
|
| 784 |
-
|
| 785 |
-
addNarration(m.step, m.narration || '', m.reward || 0);
|
| 786 |
-
showNarration(m.narration || '');
|
| 787 |
-
if (m.error) showToast(m.error);
|
| 788 |
-
if (m.audio) playTTS(m.audio);
|
| 789 |
-
}
|
| 790 |
-
|
| 791 |
-
function onEnd(m) {
|
| 792 |
-
stopBgMusic();
|
| 793 |
-
|
| 794 |
-
const ov = $('overlay');
|
| 795 |
-
ov.classList.remove('hidden');
|
| 796 |
-
$('ov-title').textContent = `Episode Done — ${m.task_name || ''}`;
|
| 797 |
-
$('ov-title').className = m.success ? 'ok' : 'fail';
|
| 798 |
-
$('ov-score').textContent = `Score: ${(m.score || 0).toFixed(3)} | Steps: ${m.steps || 0}`;
|
| 799 |
-
|
| 800 |
-
const chart = $('ov-chart');
|
| 801 |
-
chart.innerHTML = '';
|
| 802 |
-
const rr = m.rewards || [];
|
| 803 |
-
const mx = Math.max(0.01, ...rr.map(r => Math.abs(r)));
|
| 804 |
-
for (let i = 0; i < rr.length; i++) {
|
| 805 |
-
const r = rr[i];
|
| 806 |
-
const h = Math.max(2, (Math.abs(r) / mx) * 100);
|
| 807 |
-
chart.innerHTML += `<div class="bar" style="height:${h}px;background:${r >= 0 ? 'var(--accent)' : 'var(--red)'}"><span class="lbl">${i + 1}</span></div>`;
|
| 808 |
-
}
|
| 809 |
-
ov.onclick = () => ov.classList.add('hidden');
|
| 810 |
-
}
|
| 811 |
-
|
| 812 |
-
/* ---- Long-poll (single pending request, no spam) ---- */
|
| 813 |
-
let pollCursor = 0;
|
| 814 |
-
let pollRunning = false;
|
| 815 |
-
|
| 816 |
-
async function pollLoop() {
|
| 817 |
-
if (pollRunning) return;
|
| 818 |
-
pollRunning = true;
|
| 819 |
-
while (true) {
|
| 820 |
-
try {
|
| 821 |
-
const resp = await fetch(`/poll?since=${pollCursor}`);
|
| 822 |
-
if (!resp.ok) {
|
| 823 |
-
$('h-ws').className = 'ws-dot off';
|
| 824 |
-
await new Promise(r => setTimeout(r, 2000));
|
| 825 |
-
continue;
|
| 826 |
-
}
|
| 827 |
-
$('h-ws').className = 'ws-dot on';
|
| 828 |
-
const data = await resp.json();
|
| 829 |
-
pollCursor = data.next;
|
| 830 |
-
for (const m of (data.messages || [])) {
|
| 831 |
-
if (m.type === 'clear') onClear();
|
| 832 |
-
else if (m.type === 'reset') onReset(m);
|
| 833 |
-
else if (m.type === 'step') onStep(m);
|
| 834 |
-
else if (m.type === 'end') onEnd(m);
|
| 835 |
-
else if (m.type === 'shutdown') { stopBgMusic(); showNarration('All episodes complete.'); }
|
| 836 |
-
}
|
| 837 |
-
} catch {
|
| 838 |
-
$('h-ws').className = 'ws-dot off';
|
| 839 |
-
await new Promise(r => setTimeout(r, 2000));
|
| 840 |
-
}
|
| 841 |
-
}
|
| 842 |
-
}
|
| 843 |
-
|
| 844 |
-
function connect() {
|
| 845 |
-
pollCursor = 0;
|
| 846 |
-
pollLoop();
|
| 847 |
-
}
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
/* ---- Mute button ---- */
|
| 851 |
-
$('btn-mute').addEventListener('click', toggleMute);
|
| 852 |
-
|
| 853 |
-
/* ---- Init ---- */
|
| 854 |
-
function init() {
|
| 855 |
-
canvas = $('canvas');
|
| 856 |
-
ctx = canvas.getContext('2d');
|
| 857 |
-
window.addEventListener('resize', () => { autoFit(); render(); });
|
| 858 |
-
autoFit();
|
| 859 |
-
render();
|
| 860 |
-
connect();
|
| 861 |
-
}
|
| 862 |
-
init();
|
| 863 |
-
</script>
|
| 864 |
-
</body>
|
| 865 |
-
</html>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|