kdemon1011 committed
Commit 15503f9 · verified · 1 Parent(s): 81f5b19

Upload folder using huggingface_hub
.dockerignore ADDED
@@ -0,0 +1,9 @@
+ __pycache__
+ *.pyc
+ .git
+ .env
+ outputs/
+ comparison.md
+ play.html
+ play_server.py
+ *.egg-info
.env.example ADDED
@@ -0,0 +1,29 @@
+ # ── Environment Server Configuration ──
+ OPENENV_PORT=8000
+ MAX_CONCURRENT_ENVS=8
+ ENABLE_WEB_INTERFACE=true
+ RENDER_MODE=svg
+ MAX_BOARD_SIZE=12
+ # VISUAL_MEMORY_SCENARIOS_DIR=  # Optional: override scenario directory path
+
+ # ── LLM Configuration (used by run_eval.py) ──
+ LLM_MODEL=gpt-4o
+ LLM_TEMPERATURE=0.0
+ LLM_MAX_TOKENS=1024
+
+ # ── API Keys ──
+ # Only the key for your chosen --model provider is required.
+
+ # OpenAI (for gpt-4o, gpt-5.4, o3-pro, etc.)
+ OPENAI_API_KEY=
+ OPENAI_API_BASE=https://api.openai.com/v1
+
+ # Anthropic (for claude-sonnet-4-6, claude-opus-4-6, etc.)
+ ANTHROPIC_API_KEY=
+
+ # Google (for gemini-2.5-pro, etc.)
+ GOOGLE_API_KEY=
+
+ # For local models via Ollama — no key needed, just run:
+ #   ollama serve && ollama pull llama3
+ # Then use: --model ollama/llama3
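In practice these values would be loaded via a library like python-dotenv; as a minimal illustration (pure stdlib; `load_env_file` is a hypothetical helper, not part of this repo), parsing `KEY=VALUE` lines of this form amounts to:

```python
import os

def load_env_file(text: str) -> dict:
    """Parse KEY=VALUE lines, skipping blanks and # comments (minimal sketch)."""
    values = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values

sample = """
# ── LLM Configuration ──
LLM_MODEL=gpt-4o
LLM_TEMPERATURE=0.0
OPENAI_API_KEY=
"""
config = load_env_file(sample)
print(config["LLM_MODEL"])   # gpt-4o
os.environ.update(config)    # make the values visible to the process
```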
README.md CHANGED
@@ -7,10 +7,11 @@ sdk: docker
  pinned: false
  license: mit
  app_port: 8000
  tags:
  - openenv
  - rl-environment
- base_path: /web
  ---

  # Visual Memory Gym — *Phantom Grid*
@@ -19,6 +20,26 @@ base_path: /web

  An OpenEnv RL environment where agents must navigate grids with hidden hazards, memorize revealed patterns, and make optimal decisions with incomplete information. The name *Phantom Grid* reflects the core challenge: invisible dangers lurk beneath every cell, and the agent must deduce their locations from indirect signals — like hunting phantoms by their shadows. Designed to stress spatial reasoning, working memory, uncertainty handling, and risk-averse planning — areas where frontier LLMs consistently underperform.

  ## What Is This Gym?

  The Visual Memory gym places an LLM agent on a grid board where most cells are initially hidden. The agent must use MCP tools to reveal cells one at a time, interpret the signals (clues about nearby hazards), flag hazard locations, and submit a solution — all within a limited step budget. Every reveal risks hitting a hazard (which can end the game), so the agent must balance information gathering with caution.
@@ -135,60 +156,233 @@ These look useful but always return errors. Models must learn to avoid them.
  | `peek_hidden_cell` | "View hidden cell without revealing" | Always fails — peeking disabled |
  | `undo_last_action` | "Revert the most recent action" | Always fails — actions are irreversible |

- ## Reward System (3 Layers)

- ### Layer 1 — Environment Step Rewards (built into the gym)
- Per-tool rewards computed inside `memory_environment.py`. Small signals for safe reveals (+0.05), hazard hits (-0.20), correct submissions (+0.50), and distractor use (-0.10).

- ### Layer 2 — Custom Episode Rewards (`rewards/visual_memory_checks.py`)
- Weighted episode-level score computed from the full trajectory:

  | Component | Weight | Description |
  |---|---|---|
- | Final Correctness | 0.35 | F1 score of submitted solution |
- | Safety Score | 0.20 | Fraction of reveals that avoided hazards |
- | Evidence Support | 0.15 | Used recall/inspect before committing |
- | Irreversible Penalty | -0.15 | Deducted for hazard hits |
- | Efficiency | 0.10 | Steps used relative to budget |
- | Unnecessary Guessing | -0.05 | Deducted for trap tool use or repeated reveals |

- ### Layer 3 — OpenEnv Transform Rewards (`rewards/transforms/visual_memory.py`)
- Per-step rewards for RL training with sharper signal differentiation. Safe reveals (+0.15), hazard hits (-0.40), correct flags (+0.20), distractor use (-0.25), correct submission (+1.0).

- ## Running

- ```bash
- # Install for AutoEnv discovery
- pip install -e visual-memory/

- # Build Docker image
- cd visual-memory && docker build -t openenv-visual-memory -f server/Dockerfile .

- # Run container
  docker run -d --name visual-memory -p 8000:8000 openenv-visual-memory

  # Verify
  curl http://localhost:8000/health
- curl http://localhost:8000/metadata
-
- # Evaluate (single model)
- python run_eval.py --gym visual_memory --model gpt-5.4 --save --trajectory

- # Evaluate (parallel, both reward modes)
- python run_eval.py --gym visual_memory \
-     --model gpt-5.4,claude-sonnet-4-6,claude-opus-4-6 \
-     --parallel 3 --reward-mode custom --save --trajectory

- python run_eval.py --gym visual_memory \
-     --model gpt-5.4,claude-sonnet-4-6,claude-opus-4-6 \
      --parallel 3 --reward-mode openenv --save --trajectory

- # Stop
  docker stop visual-memory && docker rm visual-memory
  ```

  ## Configuration (.env)

  | Variable | Default | Description |
  |----------|---------|-------------|
  | `OPENENV_PORT` | `8000` | OpenEnv server port (exposed) |
  pinned: false
  license: mit
  app_port: 8000
+ base_path: /web
  tags:
  - openenv
+ - openenv-0.2.3
  - rl-environment
  ---

  # Visual Memory Gym — *Phantom Grid*

  An OpenEnv RL environment where agents must navigate grids with hidden hazards, memorize revealed patterns, and make optimal decisions with incomplete information. The name *Phantom Grid* reflects the core challenge: invisible dangers lurk beneath every cell, and the agent must deduce their locations from indirect signals — like hunting phantoms by their shadows. Designed to stress spatial reasoning, working memory, uncertainty handling, and risk-averse planning — areas where frontier LLMs consistently underperform.

+ ## Hugging Face Space Deployment
+
+ This Space is built from the OpenEnv environment `visual_memory`.
+
+ - **Space URL**: `https://huggingface.co/spaces/huzzle-labs/visual_memory`
+ - **OpenEnv pinned ref**: `0.2.3`
+ - **Hub tag**: `openenv`
+
+ ### Connecting from Code
+
+ ```python
+ from openenv import AutoEnv
+
+ env = AutoEnv.from_env("visual_memory", base_url="https://huzzle-labs-visual-memory.hf.space")
+ with env:
+     env.reset()
+     tools = env.list_tools()
+     print(f"Available tools: {[t.name for t in tools]}")
+ ```
+
  ## What Is This Gym?

  The Visual Memory gym places an LLM agent on a grid board where most cells are initially hidden. The agent must use MCP tools to reveal cells one at a time, interpret the signals (clues about nearby hazards), flag hazard locations, and submit a solution — all within a limited step budget. Every reveal risks hitting a hazard (which can end the game), so the agent must balance information gathering with caution.

  | `peek_hidden_cell` | "View hidden cell without revealing" | Always fails — peeking disabled |
  | `undo_last_action` | "Revert the most recent action" | Always fails — actions are irreversible |

+ ## Reward System
+
+ This gym ships with **two** reward modes, selectable via `--reward-mode`:
+
+ ### Custom Rewards — Episode-Level (`rewards/checks.py`)
+
+ The `VisualMemoryChecker` verifies ground truth from the episode trajectory and computes a weighted 6-component score:
+
+ | Component | Weight | Description |
+ |---|---|---|
+ | `final_correctness` | 0.35 | Was the submission correct? (F1 for partial) |
+ | `safety_score` | 0.20 | Fraction of reveals that didn't hit hazards |
+ | `evidence_support` | 0.15 | Did the agent gather evidence before submitting? |
+ | `irreversible_penalty` | 0.15 | Hazard hits (0 = no penalty, 2+ = full penalty) |
+ | `efficiency` | 0.10 | Steps used relative to budget |
+ | `unnecessary_guessing` | 0.05 | Trap tool usage + repeated reveals |
+
+ ```python
+ from rewards.checks import VisualMemoryChecker
+
+ checker = VisualMemoryChecker()
+ checker.set_episode(episode)
+ reward = checker.compute_episode_reward()
+ # {'final_correctness': 1.0, 'safety_score': 0.85, ..., 'total': 0.78}
+ ```
+
+ The base `RewardCalculator` (`rewards/base.py`) wraps this into the standard 3-component formula used across all gyms:
+
+ ```
+ total = 0.25 × structural + 0.15 × efficiency + 0.60 × ground_truth + penalty
+ ```
 
191
+ ### OpenEnV Transforms — Per-Step (`rewards/transforms.py`)
 
 
192
 
193
+ The `VisualMemoryStepTransform` provides fine-grained per-step rewards for RL training (GRPO). Each tool call receives a reward based on its outcome:
 
194
 
195
+ | Tool | Success | Failure |
196
+ |---|---|---|
197
+ | `reveal_cell` (safe) | +0.15 | — |
198
+ | `reveal_cell` (hazard) | -0.40 | — |
199
+ | `flag_cell` | +0.20 | -0.10 |
200
+ | `submit_solution` (correct) | +1.0 | -0.50 |
201
+ | `recall_log` | +0.10 | 0.0 |
202
+ | `inspect_region` | +0.08 | -0.10 |
203
+ | `get_board_view` / `get_status` | +0.05 | 0.0 |
204
+ | `move_viewport` | +0.10 | -0.10 |
205
+ | Distractor traps | -0.25 | -0.25 |
206
+
207
+ ```python
208
+ from rewards.transforms import VisualMemoryStepTransform
209
+
210
+ transform = VisualMemoryStepTransform()
211
+ scored_obs = transform(observation)
212
+ print(scored_obs.reward) # e.g., +0.15 for a safe reveal
213
+ ```
214
+
215
+ The `OpenEnvRewardCalculator` (`rewards/base.py`) combines per-step rewards with ground truth into the same weighted formula, using sign-based quality scoring.
216
+
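One plausible reading of "sign-based quality scoring" (an assumption on my part, not the library's exact implementation) is to use the fraction of positive-reward steps as the quality signal:

```python
def sign_based_quality(step_rewards: list) -> float:
    """Fraction of steps with a positive reward (hypothetical sketch)."""
    if not step_rewards:
        return 0.0
    return sum(1 for r in step_rewards if r > 0) / len(step_rewards)

# 3 of 4 steps earned positive rewards
print(sign_based_quality([0.15, -0.40, 0.20, 1.0]))  # 0.75
```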
+ ## Evaluation
+
+ The included `run_eval.py` runs an LLM agent against scenarios and scores results.
+
+ ### Quick Start
+
+ ```bash
+ cd visual-memory
+ pip install -e .
+
+ # Build and run the environment
+ docker build -t openenv-visual-memory -f server/Dockerfile .
  docker run -d --name visual-memory -p 8000:8000 openenv-visual-memory

  # Verify
  curl http://localhost:8000/health

+ # Evaluate (single model, custom rewards)
+ python run_eval.py --model gpt-5.4 --save --trajectory

+ # Evaluate (multiple models, per-step rewards)
+ python run_eval.py --model gpt-5.4,claude-sonnet-4-6,claude-opus-4-6 \
      --parallel 3 --reward-mode openenv --save --trajectory

+ # Evaluate a specific scenario
+ python run_eval.py --model gpt-5.4 --scenario directional_trap_8x8
+
+ # Cleanup
  docker stop visual-memory && docker rm visual-memory
  ```

+ ### Output Paths
+
+ | Output | Path |
+ |---|---|
+ | Results markdown | `outputs/results/<run_id>.md` |
+ | Trajectory JSON | `outputs/trajectories/<run_id>/<model>.json` |
+
+ Results files append per-model sections so you can accumulate multiple model runs in one file.
+
+ ### CLI Arguments
+
+ | Argument | Default | Description |
+ |---|---|---|
+ | `--model` | `gpt-4o` | LiteLLM model string (comma-separated for parallel) |
+ | `--scenario` | all | Run a specific scenario by ID |
+ | `--reward-mode` | `custom` | `custom` (episode-level) or `openenv` (per-step) |
+ | `--parallel` | `1` | Number of models to run in parallel |
+ | `--save` | off | Save results markdown |
+ | `--trajectory` | off | Save trajectory JSON |
+ | `--temperature` | `0.0` | LLM sampling temperature |
+ | `--max-tokens` | `1024` | Max tokens per LLM response |
+ | `--run-id` | auto | Run identifier for grouping outputs |
+ | `--verbose` | off | Enable debug logging |
+
+ ## Play Manually (Human Mode)
+
+ You can play Phantom Grid yourself in a browser — no LLM, no Docker required.
+
+ ### Quick Start
+
+ ```bash
+ cd visual-memory
+ pip install fastapi uvicorn svgwrite numpy pydantic
+ python play_server.py
+ ```
+
+ Then open **http://localhost:8001** in your browser.
+
+ ### How to Play
+
+ 1. **Pick a scenario** from the right panel (e.g. "Directional Trap 8x8")
+ 2. **Click cells** on the board — what happens depends on your click mode:
+    - **Reveal** mode (default, blue) — uncovers the cell. You'll see:
+      - Empty (white) — nothing here
+      - Signal (light blue) — a clue about nearby hazards (number = adjacent hazard count, letters like "N,W" = direction to hazards)
+      - Hazard (red skull) — danger! Too many hits = game over
+      - Key (gold) — collect these in key-hunt scenarios
+    - **Flag Hazard** mode (red) — marks a cell as a suspected hazard. Click a flagged cell again to unflag it.
+ 3. **Use signals** to deduce hazard positions:
+    - A signal showing "2" means 2 hazards are adjacent (8 surrounding cells)
+    - A signal showing "N,E" means hazards lie to the North and East
+    - Range signals like "1-3" mean between 1 and 3 adjacent hazards
+ 4. **Flag all hazards**, then click **SUBMIT SOLUTION** to see your score
+ 5. After game over, click any scenario button to **start a fresh game**
+
+ ### Tips
+
+ - Start by revealing cells in the center — they give the most signal coverage
+ - Use the **Recall Log** button to review all signals you've discovered
+ - In fog-of-war scenarios, use **Move Viewport** to explore — you can only see a small area
+ - Avoid the distractor tools (auto_solve, peek, undo) — they always fail
+ - The play server runs on **port 8001** and is completely separate from the OpenEnv server (port 8000)
+
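The numeric signal described in "How to Play" (count of hazards among the 8 surrounding cells) can be sketched as follows; this is illustrative only, the real logic lives in `server/engine.py`:

```python
def adjacent_hazards(board, row, col):
    """Count hazards ('H') in the 8 cells surrounding (row, col)."""
    count = 0
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            if dr == 0 and dc == 0:
                continue  # skip the cell itself
            r, c = row + dr, col + dc
            if 0 <= r < len(board) and 0 <= c < len(board[0]) and board[r][c] == "H":
                count += 1
    return count

board = [
    ".H.",
    "...",
    "H..",
]
print(adjacent_hazards(board, 1, 1))  # 2
```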
+ ## Project Structure
+
+ ```
+ visual-memory/
+ ├── __init__.py               # Package exports (env + rewards)
+ ├── client.py                 # OpenEnv client integration
+ ├── models.py                 # Action/Observation data models
+ ├── openenv.yaml              # OpenEnv AutoEnv manifest
+ ├── pyproject.toml            # Dependencies (openenv-core v0.2.3)
+ ├── Dockerfile                # Root Dockerfile for HF Spaces
+ ├── .dockerignore
+ ├── run_eval.py               # LLM evaluation runner
+ ├── play.html                 # Human play mode UI
+ ├── play_server.py            # Human play mode server
+
+ ├── rewards/                  # Reward system (both modes)
+ │   ├── __init__.py
+ │   ├── base.py               # Scenario, EpisodeLog, RewardCalculator,
+ │   │                         #   StepRewardTransform, OpenEnvRewardCalculator
+ │   ├── checks.py             # VisualMemoryChecker (episode-level)
+ │   └── transforms.py         # VisualMemoryStepTransform (per-step)
+
+ ├── scenarios/                # Scenario definitions
+ │   ├── __init__.py
+ │   ├── definitions.py        # 10 Scenario objects (Python)
+ │   └── *.json                # Scenario board configs
+
+ ├── agent/                    # LLM agent runner
+ │   ├── __init__.py
+ │   ├── llm.py                # LiteLLM wrapper
+ │   └── runner.py             # AgentRunner (gym-agnostic)
+
+ ├── server/                   # OpenEnv environment server
+ │   ├── __init__.py
+ │   ├── app.py                # FastAPI + FastMCP server
+ │   ├── memory_environment.py # MCPEnvironment implementation
+ │   ├── engine.py             # Game engine (hidden state)
+ │   ├── renderer.py           # SVG board renderer
+ │   └── Dockerfile            # Server-only Dockerfile
+
+ └── outputs/                  # Evaluation outputs (gitignored)
+     ├── results/              # Markdown result files
+     └── trajectories/         # JSON trajectory files
+ ```
+
  ## Configuration (.env)

+ Copy `.env.example` to `.env` and fill in your API keys:
+
+ ```bash
+ cp .env.example .env
+ # Edit .env with your API keys
+ ```
+
+ ### LLM API Keys
+
+ | Variable | Required For | Description |
+ |----------|---|---|
+ | `OPENAI_API_KEY` | `gpt-4o`, `gpt-5.4`, `o3-pro` | OpenAI API key |
+ | `OPENAI_API_BASE` | OpenAI | API base URL (default: `https://api.openai.com/v1`) |
+ | `ANTHROPIC_API_KEY` | `claude-sonnet-4-6`, `claude-opus-4-6` | Anthropic API key |
+ | `GOOGLE_API_KEY` | `gemini-2.5-pro` | Google AI API key |
+
+ Only the key for your chosen `--model` provider is required. For local models via Ollama, no key is needed.
+
+ ### LLM Defaults
+
+ | Variable | Default | Description |
+ |----------|---------|-------------|
+ | `LLM_MODEL` | `gpt-4o` | Default model when `--model` is not specified |
+ | `LLM_TEMPERATURE` | `0.0` | Default sampling temperature |
+ | `LLM_MAX_TOKENS` | `1024` | Default max tokens per response |
+
+ ### Environment Server
+
  | Variable | Default | Description |
  |----------|---------|-------------|
  | `OPENENV_PORT` | `8000` | OpenEnv server port (exposed) |
__init__.py CHANGED
@@ -10,6 +10,7 @@ from .models import (
  ListToolsAction,
  ListToolsObservation,
  )
+ from .rewards import VisualMemoryChecker, VisualMemoryStepTransform

  __all__ = [
  "VisualMemoryEnv",
@@ -20,4 +21,6 @@ __all__ = [
  "CallToolObservation",
  "ListToolsAction",
  "ListToolsObservation",
+ "VisualMemoryChecker",
+ "VisualMemoryStepTransform",
  ]
agent/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .runner import AgentRunner
+ from .llm import LLMClient
+
+ __all__ = ["AgentRunner", "LLMClient"]
agent/llm.py ADDED
@@ -0,0 +1,114 @@
+ """
+ LLM abstraction layer using LiteLLM.
+
+ Supports any model LiteLLM supports — switch with a single string:
+ - OpenAI: "gpt-4o", "gpt-5.4", "o3-pro"
+ - Anthropic: "claude-opus-4-6", "claude-sonnet-4-6"
+ - Local: "ollama/llama3", "ollama/mistral"
+ - And 100+ more providers
+
+ API keys are read from environment variables (loaded from root .env):
+ OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.
+
+ Usage:
+     from agent.llm import LLMClient
+
+     llm = LLMClient(model="gpt-4o")
+     response = llm.chat(
+         messages=[{"role": "user", "content": "Hello"}],
+         tools=[...],
+     )
+ """
+
+ import json
+ import logging
+ from typing import Any, Dict, List, Optional
+
+ import litellm
+
+ logger = logging.getLogger(__name__)
+
+
+ class LLMClient:
+     """
+     Thin wrapper around LiteLLM for consistent tool-calling across providers.
+
+     The same code works whether you're hitting GPT-4o, Claude, or a local
+     Ollama model — LiteLLM handles the translation.
+     """
+
+     _REASONING_MODELS = {"o3-pro", "o3-mini", "o3", "o1", "o1-mini", "o1-pro", "gpt-5"}
+
+     def __init__(
+         self,
+         model: str,
+         temperature: float = 0.0,
+         max_tokens: int = 1024,
+     ):
+         self.model = model
+
+         if model in self._REASONING_MODELS:
+             self.temperature = 1.0
+             self.max_tokens = max(max_tokens, 4096)
+             if temperature != 1.0:
+                 logger.info(f"Model {model} requires temperature=1.0, overriding from {temperature}")
+         else:
+             self.temperature = temperature
+             self.max_tokens = max_tokens
+
+     def chat(
+         self,
+         messages: List[Dict[str, Any]],
+         tools: Optional[List[Dict[str, Any]]] = None,
+     ) -> Any:
+         """
+         Send messages to the LLM and get a response.
+
+         Args:
+             messages: Conversation history in OpenAI format
+             tools: Optional list of tools in OpenAI function-calling format
+
+         Returns:
+             LiteLLM ModelResponse (same shape as OpenAI ChatCompletion).
+         """
+         kwargs: Dict[str, Any] = {
+             "model": self.model,
+             "messages": messages,
+             "temperature": self.temperature,
+             "max_tokens": self.max_tokens,
+         }
+
+         if tools:
+             kwargs["tools"] = tools
+             kwargs["tool_choice"] = "auto"
+
+         logger.debug(f"LLM request: model={self.model}, messages={len(messages)}, tools={len(tools or [])}")
+         response = litellm.completion(**kwargs)
+         logger.debug(f"LLM response: finish_reason={response.choices[0].finish_reason}")
+
+         return response
+
+     @staticmethod
+     def extract_tool_calls(response) -> List[Dict[str, Any]]:
+         """Extract tool calls from an LLM response."""
+         choice = response.choices[0]
+         if not choice.message.tool_calls:
+             return []
+
+         calls = []
+         for tc in choice.message.tool_calls:
+             args = tc.function.arguments
+             if isinstance(args, str):
+                 args = json.loads(args)
+             calls.append({
+                 "id": tc.id,
+                 "name": tc.function.name,
+                 "arguments": args,
+             })
+         return calls
+
+     @staticmethod
+     def get_text_response(response) -> Optional[str]:
+         """Extract plain text content from an LLM response (if any)."""
+         choice = response.choices[0]
+         return choice.message.content
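The JSON-argument handling in `extract_tool_calls` can be exercised without LiteLLM using stand-in objects (a sketch; the field names mirror the OpenAI response shape the code expects):

```python
import json
from types import SimpleNamespace

def extract_tool_calls(response):
    """Same logic as LLMClient.extract_tool_calls, on any OpenAI-shaped response."""
    choice = response.choices[0]
    if not choice.message.tool_calls:
        return []
    calls = []
    for tc in choice.message.tool_calls:
        args = tc.function.arguments
        if isinstance(args, str):  # providers may return arguments as a JSON string
            args = json.loads(args)
        calls.append({"id": tc.id, "name": tc.function.name, "arguments": args})
    return calls

# Stand-in for a ModelResponse containing one tool call
fake = SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(tool_calls=[
    SimpleNamespace(id="call_1",
                    function=SimpleNamespace(name="reveal_cell",
                                             arguments='{"row": 2, "col": 3}'))
]))])
print(extract_tool_calls(fake)[0]["arguments"])  # {'row': 2, 'col': 3}
```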
agent/runner.py ADDED
@@ -0,0 +1,282 @@
+ """
+ Gym-agnostic Agent Runner — connects an LLM to any OpenEnv environment.
+
+ This module is the CORE of the evaluation platform. It:
+ 1. Receives a pre-connected OpenEnv client (from AutoEnv discovery)
+ 2. Discovers tools via list_tools()
+ 3. Gives the LLM a scenario prompt + available tools
+ 4. Loops: LLM reasons → agent calls env.step() → observation → LLM reasons again
+ 5. Collects an EpisodeLog with timestamps for reward calculation + trajectory logging
+
+ Usage:
+     from openenv import AutoEnv
+     env = AutoEnv.from_env("visual_memory", base_url="http://localhost:8000")
+     runner = AgentRunner(model="gpt-4o", env_client=env)
+     episode, breakdown = runner.run_scenario(scenario, checker)
+ """
+
+ import json
+ import logging
+ import time
+ from datetime import datetime, timezone, timedelta
+ from typing import Any, Dict, List, Tuple
+
+ IST = timezone(timedelta(hours=5, minutes=30))
+
+ from openenv.core.mcp_client import MCPToolClient
+ from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation, Tool
+
+ from rewards.base import (
+     EpisodeLog,
+     RewardBreakdown,
+     RewardCalculator,
+     Scenario,
+     OpenEnvRewardCalculator,
+ )
+ from .llm import LLMClient
+
+ logger = logging.getLogger(__name__)
+
+
+ SYSTEM_PROMPT = """\
+ You are an AI agent interacting with an environment through tools.
+
+ Your job:
+ 1. Read the task description carefully.
+ 2. Use the available tools to complete the task.
+ 3. Call tools one at a time. Wait for each result before deciding the next step.
+ 4. When the task is complete, respond with a plain text summary of what you did.
+    Do NOT call any more tools after you're done.
+
+ Rules:
+ - Only use tools that are listed as available.
+ - Provide all required arguments for each tool call.
+ - If a tool call fails, read the error and decide how to recover.
+ - Be efficient — complete the task in as few steps as possible.
+ - When you're done, clearly state what you accomplished.
+ """
+
+
+ def mcp_tools_to_openai(tools: List[Tool]) -> List[Dict[str, Any]]:
+     """Convert OpenEnv MCP tool definitions to OpenAI function-calling format."""
+     openai_tools = []
+     for tool in tools:
+         schema = tool.input_schema or {"type": "object", "properties": {}}
+         if "type" not in schema:
+             schema["type"] = "object"
+         if "properties" not in schema:
+             schema["properties"] = {}
+
+         openai_tools.append({
+             "type": "function",
+             "function": {
+                 "name": tool.name,
+                 "description": tool.description or "",
+                 "parameters": schema,
+             },
+         })
+     return openai_tools
+
+
+ def _observation_to_str(step_result) -> str:
+     """Convert an OpenEnv step result to a string the LLM can read."""
+     obs = step_result.observation
+     if isinstance(obs, CallToolObservation):
+         if obs.error:
+             return json.dumps({"error": obs.error.message}, indent=2)
+         result = obs.result
+         if hasattr(result, "data"):
+             result = result.data
+         elif isinstance(result, dict) and "data" in result:
+             result = result["data"]
+         try:
+             return json.dumps(result, indent=2, default=str)
+         except (TypeError, ValueError):
+             return str(result)
+     if hasattr(obs, "metadata") and obs.metadata:
+         return json.dumps(obs.metadata, indent=2, default=str)
+     return str(obs)
+
+
+ class AgentRunner:
+     """
+     Gym-agnostic agent that connects an LLM to any OpenEnv environment.
+
+     Reward modes:
+     - "custom" (default): Episode-level reward via RewardCalculator
+     - "openenv": Per-step reward via Transform + ground truth
+     """
+
+     def __init__(
+         self,
+         model: str,
+         env_client: MCPToolClient,
+         temperature: float = 0.0,
+         max_tokens: int = 1024,
+         reward_mode: str = "custom",
+         transform=None,
+     ):
+         self.llm = LLMClient(
+             model=model,
+             temperature=temperature,
+             max_tokens=max_tokens,
+         )
+         self.env_client = env_client
+         self.reward_mode = reward_mode
+         self.transform = transform
+
+         self.calculator = RewardCalculator()
+
+         if reward_mode == "openenv":
+             self.openenv_calculator = OpenEnvRewardCalculator()
+
+     def run_scenario(
+         self,
+         scenario: Scenario,
+         checker: Any,
+     ) -> Tuple[EpisodeLog, RewardBreakdown]:
+         """Run a single scenario through the LLM agent."""
+         return self._execute(scenario, checker, self.env_client)
+
+     def _execute(
+         self,
+         scenario: Scenario,
+         checker: Any,
+         env: MCPToolClient,
+     ) -> Tuple[EpisodeLog, RewardBreakdown]:
+
+         env.reset()
+
+         session_id = None
+         try:
+             session_result = env.step(
+                 CallToolAction(tool_name="get_session_info", arguments={})
+             )
+             obs = session_result.observation
+             if isinstance(obs, CallToolObservation) and obs.result:
+                 result_data = obs.result
+                 if hasattr(result_data, "data"):
+                     result_data = result_data.data
+                 elif isinstance(result_data, dict) and "data" in result_data:
+                     result_data = result_data["data"]
+                 if isinstance(result_data, dict):
+                     session_id = result_data.get("session_id")
+                 elif isinstance(result_data, str):
+                     import json as _json
+                     try:
+                         parsed = _json.loads(result_data)
+                         session_id = parsed.get("session_id")
+                     except (ValueError, TypeError):
+                         pass
+         except Exception as e:
+             logger.warning(f"Could not get session_id: {e}")
+
+         if session_id and hasattr(checker, "set_session"):
+             checker.set_session(session_id)
+             logger.info(f"Session-scoped checker -> {session_id}")
+
+         if self.transform and hasattr(self.transform, "set_scenario"):
+             self.transform.set_scenario(scenario)
+
+         all_tools = env.list_tools(use_cache=False)
+         tools = [t for t in all_tools if t.name != "get_session_info"]
+         openai_tools = mcp_tools_to_openai(tools)
+         tool_names = [t.name for t in tools]
+         logger.info(f"Discovered {len(tools)} agent tools: {tool_names}")
+
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": scenario.prompt},
+         ]
+
+         episode = EpisodeLog()
+         step_rewards = []
+         final_answer = None
+
+         for step_num in range(1, scenario.max_steps + 1):
+             logger.info(f"Step {step_num}/{scenario.max_steps}")
+
+             response = self.llm.chat(messages, tools=openai_tools)
+             tool_calls = LLMClient.extract_tool_calls(response)
+
+             if not tool_calls:
+                 final_answer = LLMClient.get_text_response(response)
+                 logger.info(f"Agent done. Final answer: {(final_answer or '')[:100]}...")
+                 break
+
+             messages.append(response.choices[0].message.model_dump())
+
+             for tc in tool_calls:
+                 tool_name = tc["name"]
+                 arguments = tc["arguments"]
+                 call_id = tc["id"]
+
+                 logger.info(f"  Tool: {tool_name}({json.dumps(arguments, default=str)[:100]})")
+
+                 step_ts = datetime.now(IST).isoformat()
+                 step_start = time.time()
+                 error_msg = None
+                 try:
+                     step_result = env.step(
+                         CallToolAction(tool_name=tool_name, arguments=arguments)
+                     )
+                     obs = step_result.observation
+                     is_error = (
+                         isinstance(obs, CallToolObservation)
+                         and obs.error is not None
+                     )
+                     result_str = _observation_to_str(step_result)
+                     if is_error and isinstance(obs, CallToolObservation):
+                         error_msg = obs.error.message
+                 except Exception as exc:
+                     is_error = True
+                     error_msg = str(exc)
+                     result_str = json.dumps({"error": error_msg})
+                     obs = None
+
+                 step_elapsed = time.time() - step_start
+
+                 if self.reward_mode == "openenv" and self.transform and obs is not None:
+                     transformed = self.transform(obs)
+                     step_rewards.append(
+                         transformed.reward if transformed.reward is not None else 0.0
+                     )
+
+                 episode.add_step(
+                     tool_name=tool_name,
+                     arguments=arguments,
+                     success=not is_error,
+                     result=result_str,
+                     error=error_msg,
+                     timestamp=step_ts,
+                     elapsed=step_elapsed,
+                 )
+
+                 logger.info(f"  -> success={not is_error} ({step_elapsed:.2f}s)")
+
+                 messages.append({
+                     "role": "tool",
+                     "tool_call_id": call_id,
+                     "content": result_str,
+                 })
+
+         if hasattr(checker, "set_episode"):
+             checker.set_episode(episode)
+
+         outcome_results = checker.check_all(scenario.outcome_checks)
+
+         if self.reward_mode == "openenv":
+             breakdown = self.openenv_calculator.calculate(
+                 step_rewards=step_rewards,
+                 outcome_results=outcome_results,
+                 max_steps=scenario.max_steps,
+                 actual_steps=len(episode.steps),
+             )
+         else:
+             breakdown = self.calculator.calculate(
+                 episode=episode,
+                 scenario=scenario,
+                 outcome_results=outcome_results,
+             )
+
+         return episode, breakdown
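The schema normalization in `mcp_tools_to_openai` can be checked standalone with a minimal stand-in tool (a sketch; the real `Tool` type comes from `openenv.core.env_server.mcp_types`):

```python
from types import SimpleNamespace

def mcp_tool_to_openai(tool):
    """Mirror of the conversion above: one MCP tool -> OpenAI function-calling dict."""
    schema = tool.input_schema or {"type": "object", "properties": {}}
    schema.setdefault("type", "object")       # fill in missing top-level type
    schema.setdefault("properties", {})       # fill in missing properties map
    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description or "",
            "parameters": schema,
        },
    }

# Stand-in tool with a partial schema and no description
tool = SimpleNamespace(name="flag_cell", description=None,
                       input_schema={"properties": {"row": {"type": "integer"}}})
converted = mcp_tool_to_openai(tool)
print(converted["function"]["name"])  # flag_cell
```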
play.html ADDED
@@ -0,0 +1,669 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+   <meta charset="UTF-8">
+   <meta name="viewport" content="width=device-width, initial-scale=1.0">
+   <title>Phantom Grid — Visual Memory Game</title>
+   <link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600;800&family=Orbitron:wght@500;700;900&display=swap" rel="stylesheet">
+   <style>
+     :root {
+       --bg: #0d0d1a;
+       --panel: #141428;
+       --border: #2a2a50;
+       --accent: #7b61ff;
+       --accent-glow: rgba(123,97,255,0.3);
+       --danger: #ff4d4d;
+       --success: #4dff88;
+       --warning: #ffb84d;
+       --text: #e0e0f0;
+       --text-dim: #8888aa;
+       --text-bright: #ffffff;
+     }
+     * { margin:0; padding:0; box-sizing:border-box; }
+     body {
+       background: var(--bg);
+       color: var(--text);
+       font-family: 'JetBrains Mono', monospace;
+       min-height: 100vh;
+     }
+     .header {
+       text-align: center;
+       padding: 24px 16px 16px;
+       border-bottom: 1px solid var(--border);
+       background: linear-gradient(180deg, #12122a 0%, var(--bg) 100%);
+     }
+     .header h1 {
+       font-family: 'Orbitron', sans-serif;
+       font-weight: 900;
+       font-size: 28px;
+       letter-spacing: 3px;
+       background: linear-gradient(135deg, var(--accent), #ff61a6);
+       -webkit-background-clip: text;
+       -webkit-text-fill-color: transparent;
+       margin-bottom: 4px;
+     }
+     .header p { color: var(--text-dim); font-size: 12px; }
+     .layout {
+       display: flex;
+       gap: 16px;
+       padding: 16px;
+       max-width: 1400px;
+       margin: 0 auto;
+       align-items: flex-start;
+       justify-content: center;
+     }
+     .board-panel {
+       flex: 0 0 auto;
+       background: var(--panel);
+       border: 1px solid var(--border);
+       border-radius: 12px;
+       padding: 16px;
+       min-height: 420px;
+       display: flex;
+       flex-direction: column;
+       align-items: center;
+       justify-content: center;
+       position: relative;
+     }
+     .board-wrap { position: relative; cursor: crosshair; }
+     .board-wrap svg { display: block; }
+     .click-overlay {
+       position: absolute;
+       top: 0;
+       left: 0;
+       width: 100%;
+       height: 100%;
+     }
+     .click-cell {
+       position: absolute;
+       border: 2px solid transparent;
+       transition: border-color 0.15s;
+       cursor: pointer;
+     }
+     .click-cell:hover { border-color: var(--accent); }
+     .click-cell.selected { border-color: var(--warning); background: rgba(255,184,77,0.15); }
+     .selected-label {
+       position: absolute;
+       bottom: -22px;
+       left: 0;
+       right: 0;
+       text-align: center;
+       font-size: 11px;
+       color: var(--warning);
+       font-weight: 600;
+     }
+     .controls {
+       flex: 1;
+       display: flex;
+       flex-direction: column;
+       gap: 12px;
+       min-width: 300px;
+       max-width: 380px;
+     }
+     .card {
+       background: var(--panel);
+       border: 1px solid var(--border);
+       border-radius: 10px;
+       padding: 14px;
+     }
+     .card h3 {
+       font-family: 'Orbitron', sans-serif;
+       font-size: 11px;
+       font-weight: 700;
+       letter-spacing: 2px;
+       text-transform: uppercase;
+       color: var(--accent);
+       margin-bottom: 10px;
+     }
+     .scenario-grid {
+       display: grid;
+       grid-template-columns: 1fr 1fr;
+       gap: 6px;
+     }
+     .scenario-btn {
+       background: transparent;
+       border: 1px solid var(--border);
+       color: var(--text);
+       padding: 8px 6px;
+       border-radius: 6px;
+       font-family: 'JetBrains Mono', monospace;
+       font-size: 10px;
+       cursor: pointer;
+       transition: all 0.2s;
+       text-align: left;
+     }
+     .scenario-btn:hover { border-color: var(--accent); background: var(--accent-glow); }
+     .scenario-btn.active { border-color: var(--accent); background: var(--accent-glow); color: var(--text-bright); }
+     .scenario-btn:disabled { opacity: 0.5; cursor: wait; }
+     .action-row {
+       display: flex;
+       gap: 6px;
+       align-items: center;
+       flex-wrap: wrap;
+     }
+     .btn {
+       padding: 8px 14px;
+       border: 1px solid var(--border);
+       border-radius: 6px;
+       font-family: 'JetBrains Mono', monospace;
+       font-size: 11px;
+       cursor: pointer;
+       transition: all 0.2s;
+       font-weight: 600;
+       white-space: nowrap;
+     }
+     .btn:disabled { opacity: 0.4; cursor: not-allowed; }
+     .btn-reveal { background: #1a3a5c; color: #7bc8ff; border-color: #2a5a8c; }
+     .btn-reveal:hover:not(:disabled) { background: #2a4a6c; }
+     .btn-flag { background: #5c1a1a; color: #ff7b7b; border-color: #8c2a2a; }
+     .btn-flag:hover:not(:disabled) { background: #6c2a2a; }
+     .btn-inspect { background: #1a3a2a; color: #7bffaa; border-color: #2a5a3a; }
+     .btn-inspect:hover:not(:disabled) { background: #2a4a3a; }
+     .btn-viewport { background: #3a2a1a; color: #ffcc7b; border-color: #5a4a2a; }
+     .btn-viewport:hover:not(:disabled) { background: #4a3a2a; }
+     .btn-submit {
+       background: var(--accent);
+       color: white;
+       border-color: var(--accent);
+       width: 100%;
+       padding: 10px;
+       font-size: 13px;
+       letter-spacing: 1px;
+     }
+     .btn-submit:hover:not(:disabled) { filter: brightness(1.2); }
+     .btn-secondary { background: transparent; color: var(--text-dim); border-color: var(--border); }
+     .btn-secondary:hover:not(:disabled) { color: var(--text); border-color: var(--text-dim); }
+     .status-bar { display: grid; grid-template-columns: repeat(4, 1fr); gap: 6px; }
+     .stat { text-align: center; padding: 8px 4px; background: var(--bg); border-radius: 6px; }
+     .stat-val { font-family: 'Orbitron', sans-serif; font-size: 18px; font-weight: 700; color: var(--text-bright); }
+     .stat-label { font-size: 9px; color: var(--text-dim); margin-top: 2px; }
+     .mode-toggle {
+       display: flex;
+       gap: 4px;
+       margin-bottom: 8px;
+     }
+     .mode-btn {
+       flex: 1;
+       padding: 8px 4px;
+       border: 1px solid var(--border);
+       border-radius: 6px;
+       font-family: 'JetBrains Mono', monospace;
+       font-size: 11px;
+       cursor: pointer;
+       text-align: center;
+       transition: all 0.2s;
+       background: transparent;
+       color: var(--text-dim);
+     }
+     .mode-btn.active-reveal { border-color: #2a5a8c; background: #1a3a5c; color: #7bc8ff; }
+     .mode-btn.active-flag { border-color: #8c2a2a; background: #5c1a1a; color: #ff7b7b; }
+     .help-text { font-size: 10px; color: var(--text-dim); line-height: 1.5; margin-top: 6px; }
+     #log {
+       max-height: 200px;
+       overflow-y: auto;
+       font-size: 11px;
+       line-height: 1.6;
+       color: var(--text-dim);
+       background: var(--bg);
+       border-radius: 6px;
+       padding: 8px;
+     }
+     .log-entry { border-bottom: 1px solid #1a1a30; padding: 2px 0; }
+     .log-success { color: var(--success) !important; }
+     .log-danger { color: var(--danger) !important; }
+     .log-warn { color: var(--warning) !important; }
+     .log-info { color: var(--accent) !important; }
+     .empty-board {
+       color: var(--text-dim);
+       font-size: 14px;
+       text-align: center;
+       padding: 80px 20px;
+       line-height: 2;
+     }
+     .flagged-list {
+       font-size: 11px;
+       color: var(--text-dim);
+       background: var(--bg);
+       border-radius: 6px;
+       padding: 8px;
+       min-height: 28px;
+       word-break: break-all;
+     }
+     .game-over-banner {
+       text-align: center;
+       padding: 12px;
+       border-radius: 8px;
+       font-family: 'Orbitron', sans-serif;
+       font-size: 14px;
+       font-weight: 700;
+       letter-spacing: 2px;
+       animation: pulse 1.5s infinite;
+     }
+     @keyframes pulse { 0%,100% { opacity: 1; } 50% { opacity: 0.7; } }
+     .win { background: rgba(77,255,136,0.15); color: var(--success); border: 1px solid var(--success); }
+     .lose { background: rgba(255,77,77,0.15); color: var(--danger); border: 1px solid var(--danger); }
+     .scenario-info {
+       background: var(--bg);
+       border: 1px solid var(--border);
+       border-radius: 8px;
+       padding: 12px 14px;
+       margin-top: 10px;
+       font-size: 12px;
+       line-height: 1.6;
+       color: var(--text);
+       max-width: 540px;
+       width: 100%;
+     }
+     .scenario-info .info-title {
+       font-family: 'Orbitron', sans-serif;
+       font-size: 10px;
+       font-weight: 700;
+       letter-spacing: 2px;
+       text-transform: uppercase;
+       color: var(--success);
+       margin-bottom: 6px;
+     }
+     .scenario-info .info-goal {
+       color: var(--warning);
+       font-weight: 600;
+       margin-bottom: 4px;
+       font-size: 11px;
+     }
+     .scenario-info .info-text {
+       color: var(--text-dim);
+       font-size: 11px;
+     }
+     .legend {
+       display: flex; gap: 12px; flex-wrap: wrap;
+       justify-content: center; margin-top: 10px; font-size: 10px;
+     }
+     .legend-item { display: flex; align-items: center; gap: 4px; }
+     .legend-swatch {
+       width: 14px; height: 14px; border-radius: 3px; border: 1px solid #3d3d5a;
+     }
+   </style>
+ </head>
+ <body>
+   <div class="header">
+     <h1>PHANTOM GRID</h1>
+     <p>Visual Memory Gym — Click cells to reveal or flag. Play server on port 8001.</p>
+   </div>
+   <div class="layout">
+     <div class="board-panel">
+       <div id="board-container">
+         <div class="empty-board">
+           Select a scenario on the right to begin.<br><br>
+           <b>How to play:</b><br>
+           1. Pick a scenario<br>
+           2. Click cells on the board to reveal or flag them<br>
+           3. Use signals (numbers) to deduce hazard locations<br>
+           4. Flag all hazards, then submit your solution
+         </div>
+       </div>
+       <div class="scenario-info" id="scenario-info" style="display:none">
+         <div class="info-title">HOW TO PLAY</div>
+         <div class="info-goal" id="info-goal"></div>
+         <div class="info-text" id="info-text"></div>
+       </div>
+       <div class="legend" id="legend" style="display:none">
+         <div class="legend-item"><div class="legend-swatch" style="background:#2d2d4a"></div> Hidden</div>
+         <div class="legend-item"><div class="legend-swatch" style="background:#d0e8ff"></div> Signal</div>
+         <div class="legend-item"><div class="legend-swatch" style="background:#e8e8f0"></div> Empty</div>
+         <div class="legend-item"><div class="legend-swatch" style="background:#ff4d4d"></div> Hazard</div>
+         <div class="legend-item"><div class="legend-swatch" style="background:#ff6b35"></div> Flagged</div>
+         <div class="legend-item"><div class="legend-swatch" style="background:#ffd700"></div> Key</div>
+         <div class="legend-item"><div class="legend-swatch" style="background:#c8b8e8"></div> Decoy</div>
+         <div class="legend-item"><div class="legend-swatch" style="background:#50fa7b"></div> Goal</div>
+         <div class="legend-item"><div class="legend-swatch" style="background:#111122"></div> Fog</div>
+       </div>
+     </div>
+     <div class="controls">
+       <div class="card">
+         <h3>Scenario</h3>
+         <div class="scenario-grid" id="scenario-grid"></div>
+       </div>
+       <div class="card">
+         <h3>Status</h3>
+         <div class="status-bar">
+           <div class="stat"><div class="stat-val" id="stat-steps">—</div><div class="stat-label">STEPS</div></div>
+           <div class="stat"><div class="stat-val" id="stat-max">—</div><div class="stat-label">MAX</div></div>
+           <div class="stat"><div class="stat-val" id="stat-flags">—</div><div class="stat-label">FLAGS</div></div>
+           <div class="stat"><div class="stat-val" id="stat-revealed">—</div><div class="stat-label">REVEALED</div></div>
+         </div>
+         <div id="game-over-slot" style="margin-top:8px"></div>
+       </div>
+       <div class="card">
+         <h3>Click Mode</h3>
+         <div class="mode-toggle">
+           <button class="mode-btn active-reveal" id="mode-reveal" onclick="setMode('reveal')">
+             Reveal
+           </button>
+           <button class="mode-btn" id="mode-flag" onclick="setMode('flag')">
+             Flag Hazard
+           </button>
+         </div>
+         <p class="help-text" id="mode-help">Click any dark (hidden) cell on the board to reveal it.</p>
+       </div>
+       <div class="card">
+         <h3>Tools</h3>
+         <div style="display:flex;flex-direction:column;gap:8px">
+           <div class="action-row">
+             <button class="btn btn-inspect" onclick="doInspect()">Inspect Region</button>
+             <button class="btn btn-viewport" onclick="doMoveViewport()">Move Viewport</button>
+           </div>
+           <div class="action-row">
+             <button class="btn btn-secondary" onclick="doRecall()">Recall Log</button>
+             <button class="btn btn-secondary" onclick="doGetStatus()">Get Status</button>
+           </div>
+         </div>
+       </div>
+       <div class="card">
+         <h3>Flagged Cells <span id="flag-count" style="color:var(--text-dim)">(0)</span></h3>
+         <div class="flagged-list" id="flagged-list">No cells flagged yet</div>
+       </div>
+       <div class="card">
+         <button class="btn btn-submit" onclick="doSubmit()">SUBMIT SOLUTION</button>
+       </div>
+       <div class="card">
+         <h3>Game Log</h3>
+         <div id="log"></div>
+       </div>
+     </div>
+   </div>
+
+   <script>
+     const API = 'http://localhost:8001';
+     let clickMode = 'reveal';
+     let boardWidth = 0;
+     let boardHeight = 0;
+     let flaggedCells = [];
+     let selectedRow = -1;
+     let selectedCol = -1;
+     let gameOver = false;
+
+     const CELL_SIZE = 48;
+     const PADDING = 24;
+
+     const SCENARIOS = [
+       {id:'flash_fade_minefield_7x7', label:'Flash Fade 7x7', desc:'Pattern memory'},
+       {id:'directional_trap_8x8', label:'Directional Trap 8x8', desc:'1 hit = fatal'},
+       {id:'partial_intel_9x9', label:'Partial Intel 9x9', desc:'Incomplete signals'},
+       {id:'delayed_recall_keys_8x8', label:'Delayed Recall 8x8', desc:'Collect 5 keys'},
+       {id:'ambiguous_cluster_10x10', label:'Ambiguous Cluster 10x10', desc:'Range signals'},
+       {id:'decoy_minefield_8x10', label:'Decoy Minefield 8x10', desc:'4 keys, 8 decoys'},
+       {id:'fog_labyrinth_10x10', label:'Fog Labyrinth 10x10', desc:'Viewport radius 2'},
+       {id:'fog_key_hunt_8x8', label:'Fog Key Hunt 8x8', desc:'Tiny viewport'},
+       {id:'cascading_deduction_11x11', label:'Cascading 11x11', desc:'25 hazards'},
+       {id:'safe_zone_identification_9x9', label:'Safe Zone ID 9x9', desc:'Find safe cells'},
+     ];
+
+     function initScenarios() {
+       const grid = document.getElementById('scenario-grid');
+       SCENARIOS.forEach(s => {
+         const btn = document.createElement('button');
+         btn.className = 'scenario-btn';
+         btn.innerHTML = `${s.label}<br><span style="color:var(--text-dim);font-size:9px">${s.desc}</span>`;
+         btn.dataset.id = s.id;
+         btn.onclick = () => loadScenario(s.id);
+         grid.appendChild(btn);
+       });
+     }
+
+     function setMode(mode) {
+       clickMode = mode;
+       document.getElementById('mode-reveal').className = 'mode-btn' + (mode === 'reveal' ? ' active-reveal' : '');
+       document.getElementById('mode-flag').className = 'mode-btn' + (mode === 'flag' ? ' active-flag' : '');
+       document.getElementById('mode-help').textContent = mode === 'reveal'
+         ? 'Click any dark (hidden) cell on the board to reveal it.'
+         : 'Click any dark (hidden) cell to flag it as a hazard.';
+     }
+
+     function log(msg, cls) {
+       const el = document.getElementById('log');
+       const div = document.createElement('div');
+       div.className = 'log-entry ' + (cls || '');
+       div.textContent = `[${new Date().toLocaleTimeString()}] ${msg}`;
+       el.prepend(div);
+     }
+
+     function updateBoard(svgText) {
+       if (!svgText) return;
+       const container = document.getElementById('board-container');
+       const wrap = document.createElement('div');
+       wrap.className = 'board-wrap';
+       wrap.innerHTML = svgText;
+
+       const overlay = document.createElement('div');
+       overlay.className = 'click-overlay';
+
+       for (let r = 0; r < boardHeight; r++) {
+         for (let c = 0; c < boardWidth; c++) {
+           const cell = document.createElement('div');
+           cell.className = 'click-cell';
+           cell.setAttribute('role', 'button');
+           cell.setAttribute('aria-label', `cell ${r} ${c}`);
+           cell.setAttribute('tabindex', '0');
+           cell.style.left = (PADDING + c * CELL_SIZE) + 'px';
+           cell.style.top = (PADDING + r * CELL_SIZE) + 'px';
+           cell.style.width = CELL_SIZE + 'px';
+           cell.style.height = CELL_SIZE + 'px';
+           cell.dataset.row = r;
+           cell.dataset.col = c;
+           cell.title = `(${r}, ${c})`;
+           cell.onclick = () => onCellClick(r, c);
+           overlay.appendChild(cell);
+         }
+       }
+
+       wrap.appendChild(overlay);
+       container.innerHTML = '';
+       container.appendChild(wrap);
+       document.getElementById('legend').style.display = 'flex';
+     }
+
+     function updateStats(status) {
+       if (!status) return;
+       document.getElementById('stat-steps').textContent = status.step_count ?? '—';
+       document.getElementById('stat-max').textContent = status.max_steps ?? '—';
+       document.getElementById('stat-flags').textContent = status.flags_remaining ?? '—';
+       document.getElementById('stat-revealed').textContent = status.cells_revealed ?? '—';
+       gameOver = !!status.game_over;
+     }
+
+     function updateFlaggedList() {
+       const el = document.getElementById('flagged-list');
+       document.getElementById('flag-count').textContent = `(${flaggedCells.length})`;
+       if (flaggedCells.length === 0) { el.textContent = 'No cells flagged yet'; return; }
+       el.textContent = flaggedCells.map(c => `[${c[0]},${c[1]}]`).join(' ');
+     }
+
+     function showGameOver(won, msg) {
+       document.getElementById('game-over-slot').innerHTML =
+         `<div class="game-over-banner ${won ? 'win' : 'lose'}">${msg}</div>`;
+     }
+
+     async function api(method, path, body) {
+       try {
+         const opts = { method, headers: {'Content-Type': 'application/json'} };
+         if (body !== undefined) opts.body = JSON.stringify(body);
+         const resp = await fetch(API + path, opts);
+         return await resp.json();
+       } catch (e) {
+         log('Connection error: ' + e.message + ' — is play_server.py running on port 8001?', 'log-danger');
+         return { error: e.message };
+       }
+     }
+
+     function processResponse(data) {
+       if (data.error) { log(data.error, 'log-danger'); return; }
+       if (data.board && data.board.svg) {
+         updateBoard(data.board.svg);
+         if (data.board.metadata) {
+           document.getElementById('stat-revealed').textContent = data.board.metadata.cell_counts?.revealed ?? '—';
+         }
+       }
+       if (data.status) updateStats(data.status);
+       return data.action_result;
+     }
+
+     async function loadScenario(id) {
+       document.querySelectorAll('.scenario-btn').forEach(b => { b.classList.remove('active'); b.disabled = true; });
+       const btn = document.querySelector(`[data-id='${id}']`);
+       if (btn) btn.classList.add('active');
+       document.getElementById('game-over-slot').innerHTML = '';
+       flaggedCells = [];
+       updateFlaggedList();
+       document.getElementById('log').innerHTML = '';
+       gameOver = false;
+
+       log(`Loading scenario: ${id}...`, 'log-info');
+       const data = await api('POST', '/load', { scenario_id: id });
+
+       document.querySelectorAll('.scenario-btn').forEach(b => b.disabled = false);
+
+       if (data.error) { log(data.error, 'log-danger'); return; }
+       boardWidth = data.status?.board_size ? parseInt(data.status.board_size.split('x')[0]) : 0;
+       boardHeight = data.status?.board_size ? parseInt(data.status.board_size.split('x')[1]) : 0;
+       processResponse(data);
+
+       const winLabels = {
+         'flag_all_hazards': 'Flag all hazards, then submit.',
+         'collect_keys': 'Find and reveal all keys to win.',
+         'identify_safe_cells': 'Identify all safe (non-hazard) cells, then submit.',
+         'reach_goal': 'Reach the goal cell to win.',
+       };
+       const wc = data.status?.win_condition || '';
+       const goalText = winLabels[wc] || `Win condition: ${wc}`;
+       const howTo = data.how_to_play || '';
+
+       const infoEl = document.getElementById('scenario-info');
+       if (howTo) {
+         document.getElementById('info-goal').textContent = goalText;
+         document.getElementById('info-text').textContent = howTo;
+         infoEl.style.display = 'block';
+       } else {
+         infoEl.style.display = 'none';
+       }
+
+       log(`Loaded ${boardWidth}x${boardHeight} board | Type: ${data.status?.scenario_type} | Win: ${data.status?.win_condition} | Max steps: ${data.status?.max_steps}`, 'log-success');
+     }
+
+     async function onCellClick(row, col) {
+       if (gameOver) { log('Game is over. Load a new scenario to play again.', 'log-warn'); return; }
+       if (boardWidth === 0) { log('Load a scenario first.', 'log-warn'); return; }
+
+       if (clickMode === 'reveal') {
+         log(`Revealing (${row}, ${col})...`);
+         const data = await api('POST', '/reveal', { row, col });
+         const result = processResponse(data);
+         if (!result) return;
+         if (result.error) { log(result.error, 'log-danger'); return; }
+
+         const t = result.type || '';
+         if (t === 'hazard') log(`HAZARD at (${row},${col})!${result.game_over ? ' GAME OVER!' : ''}`, 'log-danger');
+         else if (t === 'key') log(`KEY found at (${row},${col})!`, 'log-success');
+         else if (t === 'signal') log(`Signal at (${row},${col}): ${JSON.stringify(result.value)}`, 'log-info');
+         else if (t === 'decoy') log(`Decoy at (${row},${col})`, 'log-warn');
+         else if (t === 'goal') log(`GOAL reached at (${row},${col})!`, 'log-success');
+         else if (t === 'empty') log(`Empty cell at (${row},${col})`, '');
+         else log(`Cell (${row},${col}): ${JSON.stringify(result).slice(0,120)}`, 'log-info');
+
+         if (result.game_over && result.message) {
+           showGameOver(!!result.message?.includes('win') || !!result.message?.includes('Win'), result.message);
+         }
+
+       } else {
+         const alreadyFlagged = flaggedCells.some(f => f[0] === row && f[1] === col);
+         if (alreadyFlagged) {
+           log(`Unflagging (${row}, ${col})...`);
+           const data = await api('POST', '/unflag', { row, col });
+           const result = processResponse(data);
+           if (!result) return;
+           if (result.error) { log(result.error, 'log-danger'); return; }
+           flaggedCells = flaggedCells.filter(f => !(f[0] === row && f[1] === col));
+           updateFlaggedList();
+           log(`Unflagged (${row},${col})`, 'log-warn');
+         } else {
+           log(`Flagging (${row}, ${col})...`);
+           const data = await api('POST', '/flag', { row, col });
+           const result = processResponse(data);
+           if (!result) return;
+           if (result.error) { log(result.error, 'log-danger'); return; }
+           flaggedCells.push([row, col]);
+           updateFlaggedList();
+           log(`Flagged (${row},${col}) as hazard | ${result.flags_remaining} flags left`, 'log-warn');
+           if (result.game_over && result.message) {
+             showGameOver(true, result.message);
+           }
+         }
+       }
+     }
+
+     async function doInspect() {
+       const row = parseInt(prompt('Center row:', '0'));
+       const col = parseInt(prompt('Center col:', '0'));
+       const radius = parseInt(prompt('Radius (1-3):', '1'));
+       if (isNaN(row) || isNaN(col)) return;
+       log(`Inspecting region around (${row},${col}) r=${radius}...`, 'log-info');
+       const data = await api('POST', '/inspect', { center_row: row, center_col: col, radius: radius || 1 });
+       const result = processResponse(data);
+       if (!result) return;
+       if (result.error) { log(result.error, 'log-danger'); return; }
+       const cells = result.cells || [];
+       log(`Inspected ${cells.length} cells`, 'log-info');
+       cells.forEach(c => {
+         if (c.state !== 'hidden' && c.state !== 'fog') {
+           log(`  (${c.row},${c.col}): ${c.state} ${c.content ? JSON.stringify(c.content) : ''}`, 'log-info');
+         }
+       });
+     }
+
+     async function doMoveViewport() {
+       const row = parseInt(prompt('Viewport center row:', '0'));
+       const col = parseInt(prompt('Viewport center col:', '0'));
+       if (isNaN(row) || isNaN(col)) return;
+       log(`Moving viewport to (${row},${col})...`, 'log-info');
+       const data = await api('POST', '/move_viewport', { row, col });
+       const result = processResponse(data);
+       if (result && result.error) log(result.error, 'log-danger');
+       else log(`Viewport moved to (${row},${col})`, 'log-info');
+     }
+
+     async function doRecall() {
+       const data = await api('GET', '/recall');
+       if (data.error) { log(data.error, 'log-danger'); return; }
+       const sigs = data.discovered_signals || [];
+       const mems = data.memory_events || [];
+       log(`Recall: ${sigs.length} signals, ${mems.length} memory events`, 'log-info');
+       sigs.forEach(s => log(`  Signal (${s.row},${s.col}): ${s.type} = ${JSON.stringify(s.value)}`, 'log-info'));
+       mems.slice(-5).forEach(m => log(`  Memory: step ${m.step} ${m.event} (${m.row},${m.col})`, 'log-info'));
+     }
+
+     async function doGetStatus() {
+       const data = await api('GET', '/status');
+       if (data.error) { log(data.error, 'log-danger'); return; }
+       updateStats(data);
+       log(`Step ${data.step_count}/${data.max_steps} | Flags: ${data.flags_placed}/${data.flags_remaining + data.flags_placed} | Hits: ${data.hazard_hits} | ${data.game_over ? 'GAME OVER' : 'Active'} | Win: ${data.win_condition}`, 'log-info');
+     }
+
+     async function doSubmit() {
+       if (!confirm(`Submit solution with ${flaggedCells.length} flagged cells?`)) return;
+       log('Submitting solution...', 'log-warn');
+       const data = await api('POST', '/submit', {
+         flagged_positions: flaggedCells,
+         safe_positions: [],
+       });
+       const result = processResponse(data);
+       if (!result) return;
+       if (result.error) { log(result.error, 'log-danger'); return; }
+       const won = result.correct === true;
+       const prec = ((result.precision || 0) * 100).toFixed(1);
+       const rec = ((result.recall || 0) * 100).toFixed(1);
+       log(`${won ? 'CORRECT!' : 'INCORRECT'} | Precision: ${prec}% | Recall: ${rec}% | Found: ${result.hazards_found||0}/${result.hazards_total||'?'}`, won ? 'log-success' : 'log-danger');
+       showGameOver(won, won ? 'YOU WIN!' : `INCORRECT — ${prec}% precision, ${rec}% recall`);
+     }
+
+     initScenarios();
+   </script>
+ </body>
+ </html>
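The submit flow above reports `precision` and `recall` over the flagged cells. As an illustrative sketch of that scoring (hypothetical; the actual computation lives inside the game engine, which is not shown in this diff), treating both the flagged set and the true hazard set as sets of `(row, col)` tuples:

```python
def score_flags(flagged, hazards):
    """Precision/recall of flagged cells against the true hazard cells.

    flagged, hazards: iterables of [row, col] pairs (as sent to /submit).
    """
    flagged = set(map(tuple, flagged))
    hazards = set(map(tuple, hazards))
    hits = flagged & hazards  # correctly flagged hazards
    precision = len(hits) / len(flagged) if flagged else 0.0
    recall = len(hits) / len(hazards) if hazards else 0.0
    return precision, recall
```

Under this reading, `correct` corresponds to precision and recall both reaching 1.0, i.e. the flagged set exactly matching the hazard set.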
play_server.py ADDED
@@ -0,0 +1,229 @@
+ """Standalone play server for manual Phantom Grid gameplay.
+
+ Completely separate from the OpenEnv app.py — does NOT affect
+ HuggingFace deployment, Docker builds, or run_eval.py in any way.
+
+ Runs on port 8001 by default. Uses the game engine and renderer directly.
+
+ Usage:
+     cd visual-memory
+     python play_server.py
+     # Then open play.html in a browser
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import sys
+ from pathlib import Path
+
+ import uvicorn
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import HTMLResponse, FileResponse
+ from pydantic import BaseModel
+
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
+
+ from server.engine import GameEngine
+ from server.renderer import Renderer
+
+ SCENARIOS_DIR = os.path.join(os.path.dirname(__file__), "scenarios")
+
+ app = FastAPI(title="Phantom Grid — Play Server")
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ engine: GameEngine | None = None
+ renderer = Renderer()
+
+
+ def _load_scenario_file(scenario_id: str) -> dict:
+     path = os.path.join(SCENARIOS_DIR, f"{scenario_id}.json")
+     if not os.path.isfile(path):
+         raise FileNotFoundError(f"Scenario '{scenario_id}' not found at {path}")
+     with open(path) as f:
+         return json.load(f)
+
+
+ def _board_response() -> dict:
+     """Build a unified response with board SVG + game status."""
+     if engine is None:
+         return {"error": "No scenario loaded."}
+     bs = engine.get_board_state()
+     view = renderer.get_board_view(
+         bs.visible_cells, bs.board_width, bs.board_height,
+         scenario_type=bs.scenario_type, step_count=bs.step_count,
+     )
+     status = engine.get_status()
+     return {"board": view, "status": status}
+
+
+ @app.get("/")
+ async def index():
+     html_path = os.path.join(os.path.dirname(__file__), "play.html")
+     if os.path.isfile(html_path):
+         return FileResponse(html_path, media_type="text/html")
+     return HTMLResponse("<h1>play.html not found</h1>", status_code=404)
+
+
+ @app.get("/scenarios")
+ async def list_scenarios():
+     results = []
+     for fname in sorted(os.listdir(SCENARIOS_DIR)):
+         if not fname.endswith(".json"):
+             continue
+         try:
+             data = _load_scenario_file(fname.replace(".json", ""))
+             results.append({
+                 "scenario_id": data.get("scenario_id", fname.replace(".json", "")),
+                 "type": data.get("type", "hidden_grid"),
+                 "difficulty": data.get("difficulty", "hard"),
+                 "board_size": f"{data.get('board_width', '?')}x{data.get('board_height', '?')}",
+                 "description": data.get("description", ""),
+             })
+         except Exception:
+             continue
+     return {"scenarios": results}
+
+
+ class LoadReq(BaseModel):
+     scenario_id: str
+
+ @app.post("/load")
+ async def load_scenario(req: LoadReq):
+     global engine
+     try:
+         data = _load_scenario_file(req.scenario_id)
+     except FileNotFoundError as e:
+         return {"error": str(e)}
+     engine = GameEngine(data)
+     resp = _board_response()
+     resp["loaded"] = True
+     resp["how_to_play"] = data.get("how_to_play", "")
+     resp["scenario_description"] = data.get("description", "")
+     return resp
+
+
+ class CellReq(BaseModel):
+     row: int
+     col: int
+
+ @app.post("/reveal")
+ async def reveal(req: CellReq):
+     if engine is None:
+         return {"error": "No scenario loaded."}
+     result = engine.reveal_cell(req.row, req.col)
+     resp = _board_response()
+     resp["action_result"] = result
+     return resp
+
+
+ @app.post("/flag")
+ async def flag(req: CellReq):
+     if engine is None:
+         return {"error": "No scenario loaded."}
+     result = engine.flag_cell(req.row, req.col)
+     resp = _board_response()
+     resp["action_result"] = result
+     return resp
+
+
+ @app.post("/unflag")
+ async def unflag(req: CellReq):
+     if engine is None:
+         return {"error": "No scenario loaded."}
+     result = engine.unflag_cell(req.row, req.col)
+     resp = _board_response()
+     resp["action_result"] = result
+     return resp
+
+
+ @app.post("/move_viewport")
+ async def move_viewport(req: CellReq):
+     if engine is None:
+         return {"error": "No scenario loaded."}
+     result = engine.move_viewport(req.row, req.col)
+     resp = _board_response()
+     resp["action_result"] = result
+     return resp
+
+
+ class InspectReq(BaseModel):
+     center_row: int
+     center_col: int
+     radius: int = 1
+
+ @app.post("/inspect")
+ async def inspect(req: InspectReq):
+     if engine is None:
+         return {"error": "No scenario loaded."}
+     if engine.game_over:
+         return {"error": "Game is already over."}
+     if req.radius < 1 or req.radius > 3:
+         return {"error": "Radius must be between 1 and 3."}
+
+     engine.step_count += 1
+     engine._tick_pattern_memory()
+
+     visible = engine.get_visible_board()
+     region = []
+     for r in range(max(0, req.center_row - req.radius),
+                    min(engine.height, req.center_row + req.radius + 1)):
+         for c in range(max(0, req.center_col - req.radius),
+                        min(engine.width, req.center_col + req.radius + 1)):
+             cell = visible[r][c]
+             region.append({"row": r, "col": c, "state": cell["state"], "content": cell.get("content")})
+
+     resp = _board_response()
+     resp["action_result"] = {"cells": region}
+     return resp
+
+
+ @app.get("/status")
+ async def status():
+     if engine is None:
+         return {"error": "No scenario loaded."}
+     return engine.get_status()
+
+
+ @app.get("/board")
+ async def board():
+     return _board_response()
+
+
+ @app.get("/recall")
+ async def recall():
+     if engine is None:
+         return {"error": "No scenario loaded."}
+     bs = engine.get_board_state()
+     return {
+         "discovered_signals": bs.discovered_signals,
+         "memory_events": bs.memory_events,
+     }
+
+
+ class SubmitReq(BaseModel):
212
+ flagged_positions: list[list[int]] = []
213
+ safe_positions: list[list[int]] = []
214
+
215
+ @app.post("/submit")
216
+ async def submit(req: SubmitReq):
217
+ if engine is None:
218
+ return {"error": "No scenario loaded."}
219
+ result = engine.submit_solution(
220
+ flagged_positions=req.flagged_positions,
221
+ safe_positions=req.safe_positions,
222
+ )
223
+ resp = _board_response()
224
+ resp["action_result"] = result
225
+ return resp
226
+
227
+
228
+ if __name__ == "__main__":
229
+ uvicorn.run(app, host="0.0.0.0", port=8001)
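The `/inspect` endpoint above clamps a square region of radius 1–3 to the board edges. A minimal standalone sketch of that clamping logic (board dimensions and the `region_bounds` helper name are illustrative, not part of the repo):

```python
def region_bounds(center_row, center_col, radius, height, width):
    """Return the (row, col) pairs a radius-r inspect would cover,
    clamped to the board edges exactly as the endpoint's loops do."""
    cells = []
    for r in range(max(0, center_row - radius),
                   min(height, center_row + radius + 1)):
        for c in range(max(0, center_col - radius),
                       min(width, center_col + radius + 1)):
            cells.append((r, c))
    return cells

# A radius-1 inspect at the corner of a 4x4 board covers only 4 cells:
print(region_bounds(0, 0, 1, 4, 4))  # → [(0, 0), (0, 1), (1, 0), (1, 1)]
```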
pyproject.toml CHANGED
@@ -8,7 +8,7 @@ version = "0.1.0"
  description = "Visual Memory environment for OpenEnv — hidden-state visual reasoning and planning under partial observability"
  requires-python = ">=3.10"
  dependencies = [
-     "openenv-core @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1",
+     "openenv-core @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.3",
      "fastmcp>=0.2.0",
      "fastapi>=0.115.0",
      "uvicorn>=0.24.0",
@@ -17,6 +17,8 @@ dependencies = [
      "numpy>=1.24.0",
      "svgwrite>=1.4.0",
      "python-dotenv>=1.0.0",
+     "litellm>=1.0.0",
+     "pyyaml>=6.0.0",
  ]
 
  [project.optional-dependencies]
@@ -27,8 +29,20 @@ server = "visual_memory.server.app:main"
 
  [tool.setuptools]
  include-package-data = true
-  packages = ["visual_memory", "visual_memory.server"]
-  package-dir = {"visual_memory" = ".", "visual_memory.server" = "server"}
+  packages = [
+      "visual_memory",
+      "visual_memory.server",
+      "visual_memory.rewards",
+      "visual_memory.scenarios",
+      "visual_memory.agent",
+  ]
+  [tool.setuptools.package-dir]
+  visual_memory = "."
+  "visual_memory.server" = "server"
+  "visual_memory.rewards" = "rewards"
+  "visual_memory.scenarios" = "scenarios"
+  "visual_memory.agent" = "agent"
 
  [tool.setuptools.package-data]
  visual_memory = ["openenv.yaml"]
+  "visual_memory.scenarios" = ["*.json"]
rewards/__init__.py ADDED
@@ -0,0 +1,36 @@
+ """
+ Visual Memory reward system — two reward types for evaluation and RL training.
+
+ Custom Checks (episode-level):
+     from rewards.checks import VisualMemoryChecker
+
+ OpenEnv Transforms (per-step):
+     from rewards.transforms import VisualMemoryStepTransform
+
+ Base utilities:
+     from rewards.base import Scenario, EpisodeLog, RewardCalculator, OpenEnvRewardCalculator
+ """
+
+ from .checks import VisualMemoryChecker
+ from .transforms import VisualMemoryStepTransform
+ from .base import (
+     Scenario,
+     EpisodeLog,
+     StepLog,
+     RewardBreakdown,
+     RewardCalculator,
+     StepRewardTransform,
+     OpenEnvRewardCalculator,
+ )
+
+ __all__ = [
+     "VisualMemoryChecker",
+     "VisualMemoryStepTransform",
+     "Scenario",
+     "EpisodeLog",
+     "StepLog",
+     "RewardBreakdown",
+     "RewardCalculator",
+     "StepRewardTransform",
+     "OpenEnvRewardCalculator",
+ ]
rewards/base.py ADDED
@@ -0,0 +1,313 @@
+ """
+ Base reward infrastructure — data classes, calculators, and transforms.
+
+ Merged from the shared repo-level modules into a self-contained file:
+ - Episode-level: RewardCalculator (custom mode)
+ - Per-step: StepRewardTransform + OpenEnvRewardCalculator (openenv mode)
+
+ Scoring formula (both modes):
+     total = 0.25 * quality/structural + 0.15 * efficiency + 0.60 * ground_truth + penalty
+
+ Usage:
+     from rewards.base import RewardCalculator, Scenario, EpisodeLog
+     calculator = RewardCalculator()
+     breakdown = calculator.calculate(episode, scenario, outcome_results)
+ """
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Set
+
+ from openenv.core.env_server.interfaces import Transform
+ from openenv.core.env_server.mcp_types import CallToolObservation
+ from openenv.core.env_server.types import Observation
+
+
+ # ── Data Classes ──
+
+
+ @dataclass
+ class StepLog:
+     """Record of a single tool call made by the agent."""
+
+     tool_name: str
+     arguments: Dict[str, Any]
+     success: bool
+     result: Any = None
+     error: Optional[str] = None
+     timestamp: Optional[str] = None
+     elapsed: float = 0.0
+
+
+ @dataclass
+ class EpisodeLog:
+     """Record of all tool calls in one episode."""
+
+     steps: List[StepLog] = field(default_factory=list)
+
+     def add_step(
+         self,
+         tool_name: str,
+         arguments: Dict[str, Any],
+         success: bool,
+         result: Any = None,
+         error: Optional[str] = None,
+         timestamp: Optional[str] = None,
+         elapsed: float = 0.0,
+     ) -> None:
+         self.steps.append(
+             StepLog(
+                 tool_name=tool_name,
+                 arguments=arguments,
+                 success=success,
+                 result=result,
+                 error=error,
+                 timestamp=timestamp,
+                 elapsed=elapsed,
+             )
+         )
+
+     @property
+     def tools_used(self) -> List[str]:
+         return [s.tool_name for s in self.steps]
+
+     @property
+     def tools_used_set(self) -> Set[str]:
+         return set(self.tools_used)
+
+
+ @dataclass
+ class Scenario:
+     """Definition of a task for the agent."""
+
+     id: str
+     prompt: str
+     expected_tools: List[str]
+     max_steps: int
+     outcome_checks: List[Dict[str, Any]]
+
+
+ @dataclass
+ class RewardBreakdown:
+     """Detailed reward breakdown — useful for debugging and logging."""
+
+     structural: float = 0.0
+     ground_truth: float = 0.0
+     efficiency: float = 0.0
+     penalty: float = 0.0
+     total: float = 0.0
+     details: Dict[str, Any] = field(default_factory=dict)
+
+     def summary(self) -> str:
+         mode = self.details.get("reward_mode", "custom")
+         qual_label = "Quality" if mode == "openenv" else "Structural"
+         lines = [
+             f"  {qual_label + ':':14s}{self.structural:.2f}  (weight 0.25)",
+             f"  Efficiency:   {self.efficiency:.2f}  (weight 0.15)",
+             f"  Ground Truth: {self.ground_truth:.2f}  (weight 0.60)",
+         ]
+         if self.penalty < 0:
+             lines.append(f"  Penalty:      {self.penalty:.2f}  (hallucination)")
+         lines.append(f"  ────────────────────────")
+         lines.append(f"  TOTAL:        {self.total:.2f}")
+         return "\n".join(lines)
+
+
+ # ── Episode-Level Reward Calculator (custom mode) ──
+
+
+ class RewardCalculator:
+     """
+     Computes episode-level reward from logs + scenario + verification results.
+
+     Weights: structural (0.25), ground_truth (0.60), efficiency (0.15).
+     """
+
+     def __init__(
+         self,
+         w_structural: float = 0.25,
+         w_ground_truth: float = 0.60,
+         w_efficiency: float = 0.15,
+     ):
+         self.w_structural = w_structural
+         self.w_ground_truth = w_ground_truth
+         self.w_efficiency = w_efficiency
+
+     def calculate(
+         self,
+         episode: EpisodeLog,
+         scenario: Scenario,
+         outcome_results: List[float],
+     ) -> RewardBreakdown:
+         breakdown = RewardBreakdown()
+
+         breakdown.structural = self._structural_score(episode, scenario)
+         breakdown.ground_truth = self._ground_truth_score(outcome_results)
+         breakdown.efficiency = self._efficiency_score(episode, scenario)
+         breakdown.penalty = self._hallucination_penalty(episode, outcome_results)
+
+         breakdown.total = (
+             self.w_structural * breakdown.structural
+             + self.w_ground_truth * breakdown.ground_truth
+             + self.w_efficiency * breakdown.efficiency
+             + breakdown.penalty
+         )
+         breakdown.total = max(-1.0, min(1.0, breakdown.total))
+
+         breakdown.details = {
+             "tools_expected": scenario.expected_tools,
+             "tools_used": episode.tools_used,
+             "outcome_checks_score_sum": sum(outcome_results),
+             "outcome_checks_total": len(outcome_results),
+             "outcome_checks_avg": sum(outcome_results) / len(outcome_results) if outcome_results else 0.0,
+             "steps_taken": len(episode.steps),
+             "max_steps": scenario.max_steps,
+         }
+
+         return breakdown
+
+     def _structural_score(self, episode: EpisodeLog, scenario: Scenario) -> float:
+         if not episode.steps:
+             return 0.0
+
+         expected = set(scenario.expected_tools)
+         used = episode.tools_used_set
+
+         intersection = expected & used
+         precision = len(intersection) / len(used) if used else 0.0
+         recall = len(intersection) / len(expected) if expected else 0.0
+         f1 = (
+             2 * precision * recall / (precision + recall)
+             if (precision + recall) > 0
+             else 0.0
+         )
+
+         success_rate = sum(1 for s in episode.steps if s.success) / len(episode.steps)
+
+         unexpected_calls = sum(
+             1 for s in episode.steps if s.tool_name not in expected
+         )
+         unexpected_ratio = unexpected_calls / len(episode.steps)
+
+         return max(0.0, 0.6 * f1 + 0.4 * success_rate - unexpected_ratio * 0.3)
+
+     def _ground_truth_score(self, outcome_results: List[float]) -> float:
+         if not outcome_results:
+             return 0.0
+         return sum(outcome_results) / len(outcome_results)
+
+     def _efficiency_score(self, episode: EpisodeLog, scenario: Scenario) -> float:
+         if not episode.steps:
+             return 0.0
+         return max(0.0, 1.0 - len(episode.steps) / scenario.max_steps)
+
+     def _hallucination_penalty(
+         self, episode: EpisodeLog, outcome_results: List[float]
+     ) -> float:
+         if not episode.steps or not outcome_results:
+             return 0.0
+
+         all_calls_succeeded = all(s.success for s in episode.steps)
+         pass_rate = sum(outcome_results) / len(outcome_results)
+
+         if all_calls_succeeded and pass_rate == 0.0:
+             return -0.5
+         if all_calls_succeeded and pass_rate < 0.3:
+             return -0.2
+
+         return 0.0
+
+
+ # ── Per-Step Reward Transform (openenv mode) ──
+
+
+ class StepRewardTransform(Transform):
+     """
+     Gym-agnostic per-step reward transform.
+
+     Sets observation.reward based on tool call success/failure.
+     Subclass for gym-specific logic (see transforms.py).
+     """
+
+     def __call__(self, observation: Observation) -> Observation:
+         reward = self._compute_reward(observation)
+         observation.reward = reward
+         return observation
+
+     def _compute_reward(self, observation: Observation) -> float:
+         if isinstance(observation, CallToolObservation):
+             if observation.error is not None:
+                 return -0.5
+             return 1.0
+         return 0.0
+
+
+ class OpenEnvRewardCalculator:
+     """
+     Combines per-step transform rewards with ground truth verification.
+
+     Used as the alternative to RewardCalculator when --reward-mode openenv.
+
+     Quality is sign-based: only the sign of per-step rewards matters
+     (positive = productive, negative = harmful, zero = neutral).
+     """
+
+     def __init__(
+         self,
+         w_quality: float = 0.25,
+         w_efficiency: float = 0.15,
+         w_ground_truth: float = 0.60,
+     ):
+         self.w_quality = w_quality
+         self.w_efficiency = w_efficiency
+         self.w_ground_truth = w_ground_truth
+
+     def calculate(
+         self,
+         step_rewards: List[float],
+         outcome_results: List[bool],
+         max_steps: int = 0,
+         actual_steps: int = 0,
+     ) -> RewardBreakdown:
+         productive = sum(1 for r in step_rewards if r > 0)
+         harmful = sum(1 for r in step_rewards if r < 0)
+         active = productive + harmful
+         quality = productive / active if active > 0 else 0.0
+
+         if max_steps > 0 and actual_steps > 0:
+             efficiency = max(0.0, 1.0 - actual_steps / max_steps)
+         else:
+             efficiency = 0.0
+
+         gt_score = sum(outcome_results) / len(outcome_results) if outcome_results else 0.0
+
+         penalty = 0.0
+         if step_rewards and outcome_results:
+             all_positive = all(r > 0 for r in step_rewards)
+             if all_positive and gt_score == 0.0:
+                 penalty = -0.5
+             elif all_positive and gt_score < 0.3:
+                 penalty = -0.2
+
+         total = (
+             self.w_quality * quality
+             + self.w_efficiency * efficiency
+             + self.w_ground_truth * gt_score
+             + penalty
+         )
+         total = max(-1.0, min(1.0, total))
+
+         return RewardBreakdown(
+             structural=quality,
+             ground_truth=gt_score,
+             efficiency=efficiency,
+             penalty=penalty,
+             total=total,
+             details={
+                 "reward_mode": "openenv",
+                 "productive_steps": productive,
+                 "harmful_steps": harmful,
+                 "neutral_steps": len(step_rewards) - active,
+                 "actual_steps": actual_steps,
+                 "max_steps": max_steps,
+             },
+         )
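The openenv-mode scoring in rewards/base.py can be exercised without the openenv dependency. A minimal standalone recreation of the same formula (sign-based quality, step-budget efficiency, ground-truth average, and the hallucination penalty), with illustrative numbers:

```python
def openenv_total(step_rewards, outcome_results, max_steps, actual_steps):
    """Recompute: 0.25 * quality + 0.15 * efficiency + 0.60 * gt + penalty,
    clamped to [-1, 1], mirroring OpenEnvRewardCalculator.calculate."""
    productive = sum(1 for r in step_rewards if r > 0)
    harmful = sum(1 for r in step_rewards if r < 0)
    active = productive + harmful
    quality = productive / active if active else 0.0

    # Efficiency rewards finishing well under the step budget.
    if max_steps > 0 and actual_steps > 0:
        efficiency = max(0.0, 1.0 - actual_steps / max_steps)
    else:
        efficiency = 0.0

    gt = sum(outcome_results) / len(outcome_results) if outcome_results else 0.0

    # Hallucination penalty: every step "succeeded" yet ground truth failed.
    penalty = 0.0
    if step_rewards and outcome_results and all(r > 0 for r in step_rewards):
        if gt == 0.0:
            penalty = -0.5
        elif gt < 0.3:
            penalty = -0.2

    total = 0.25 * quality + 0.15 * efficiency + 0.60 * gt + penalty
    return max(-1.0, min(1.0, total))

# All steps productive, both checks pass, 10 of 40 steps used → ≈ 0.96
print(openenv_total([0.15, 0.2, 1.0], [True, True], 40, 10))
```

Note how a trajectory of all-positive step rewards that still fails every ground-truth check swings from near the top of the scale to negative territory, which is the point of the penalty term.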
rewards/checks.py ADDED
@@ -0,0 +1,283 @@
+ """
+ Visual Memory outcome checks — ground truth verification via episode log.
+
+ No database, no HTTP API. Ground truth is reconstructed from the episode
+ trajectory: each tool call result tells us what happened (reveal outcomes,
+ flag placements, submission verdicts). The checker also receives the
+ scenario JSON to know the correct answer.
+
+ Check types:
+     - solution_correct        : final submit_solution returned correct=True
+     - hazards_flagged_pct     : fraction of true hazards flagged (recall)
+     - flag_precision          : fraction of placed flags that are on real hazards
+     - keys_collected_pct      : fraction of keys collected (for collect_keys scenarios)
+     - safety_score            : fraction of reveals that did NOT hit a hazard
+     - evidence_before_commit  : agent used recall_log or inspect_region before submitting
+     - no_distractor_use       : agent never called auto_solve / peek / undo traps
+     - efficiency              : steps used / max_steps ratio
+     - irreversible_mistakes   : count of hazard hits is below threshold
+
+ Episode-level custom reward (Layer 2):
+     final_correctness     0.35
+     safety_score          0.20
+     evidence_support      0.15
+     irreversible_penalty  0.15
+     efficiency            0.10
+     unnecessary_guessing  0.05
+ """
+
+ import json
+ from typing import Any, Dict, List, Optional
+
+
+ class VisualMemoryChecker:
+     """Verifies outcomes from the episode log + scenario definition."""
+
+     def __init__(self, session_data: Any = None):
+         self._steps: List[dict] = []
+         self._scenario: Optional[dict] = None
+
+     def set_episode(self, episode) -> None:
+         self._steps = []
+         for step in episode.steps:
+             self._steps.append({
+                 "tool_name": step.tool_name,
+                 "arguments": step.arguments,
+                 "success": step.success,
+                 "result": self._parse_result(step.result),
+             })
+
+     def set_scenario(self, scenario: dict) -> None:
+         self._scenario = scenario
+
+     def check_all(self, checks: List[Dict[str, Any]]) -> List[bool]:
+         return [self._run_check(c) for c in checks]
+
+     def compute_episode_reward(self) -> Dict[str, Any]:
+         """Compute the weighted episode-level reward (Layer 2).
+
+         Returns dict with component scores, total, and details.
+         """
+         final = self._final_correctness()
+         safety = self._safety_score()
+         evidence = self._evidence_support()
+         irrev = self._irreversible_penalty()
+         eff = self._efficiency_score()
+         guess = self._unnecessary_guessing_penalty()
+
+         total = (
+             0.35 * final
+             + 0.20 * safety
+             + 0.15 * evidence
+             - 0.15 * irrev
+             + 0.10 * eff
+             - 0.05 * guess
+         )
+         total = max(-1.0, min(1.0, total))
+
+         return {
+             "final_correctness": round(final, 4),
+             "safety_score": round(safety, 4),
+             "evidence_support": round(evidence, 4),
+             "irreversible_penalty": round(irrev, 4),
+             "efficiency": round(eff, 4),
+             "unnecessary_guessing": round(guess, 4),
+             "total": round(total, 4),
+             "weights": {
+                 "final_correctness": 0.35,
+                 "safety_score": 0.20,
+                 "evidence_support": 0.15,
+                 "irreversible_penalty": 0.15,
+                 "efficiency": 0.10,
+                 "unnecessary_guessing": 0.05,
+             },
+         }
+
+     # ── Individual Check Handlers ──
+
+     def _run_check(self, check: Dict[str, Any]) -> bool:
+         check_type = check.get("type", "")
+         handler = getattr(self, f"_check_{check_type}", None)
+         if not handler:
+             raise ValueError(f"Unknown check type: {check_type}")
+         try:
+             return handler(check)
+         except Exception:
+             return False
+
+     def _check_solution_correct(self, check: dict) -> bool:
+         sub = self._get_submission_result()
+         if not sub:
+             return False
+         return sub.get("correct") is True
+
+     def _check_hazards_flagged_pct(self, check: dict) -> bool:
+         sub = self._get_submission_result()
+         if not sub:
+             return False
+         min_pct = check.get("min_pct", 1.0)
+         recall = sub.get("recall", 0.0)
+         return recall >= min_pct
+
+     def _check_flag_precision(self, check: dict) -> bool:
+         sub = self._get_submission_result()
+         if not sub:
+             return False
+         min_pct = check.get("min_pct", 1.0)
+         precision = sub.get("precision", 0.0)
+         return precision >= min_pct
+
+     def _check_keys_collected_pct(self, check: dict) -> bool:
+         sub = self._get_submission_result()
+         if not sub:
+             return False
+         collected = sub.get("keys_collected", 0)
+         required = sub.get("keys_required", 1)
+         min_pct = check.get("min_pct", 1.0)
+         return (collected / required) >= min_pct if required > 0 else False
+
+     def _check_safety_score(self, check: dict) -> bool:
+         min_score = check.get("min_score", 0.8)
+         return self._safety_score() >= min_score
+
+     def _check_evidence_before_commit(self, check: dict) -> bool:
+         return self._evidence_support() > 0.5
+
+     def _check_no_distractor_use(self, check: dict) -> bool:
+         traps = {"auto_solve", "peek_hidden_cell", "undo_last_action"}
+         for step in self._steps:
+             if step["tool_name"] in traps:
+                 return False
+         return True
+
+     def _check_efficiency(self, check: dict) -> bool:
+         max_ratio = check.get("max_ratio", 0.8)
+         return self._efficiency_score() >= (1.0 - max_ratio)
+
+     def _check_irreversible_mistakes(self, check: dict) -> bool:
+         max_hits = check.get("max_hits", 0)
+         hits = self._count_hazard_hits()
+         return hits <= max_hits
+
+     # ── Reward Component Scoring ──
+
+     def _final_correctness(self) -> float:
+         sub = self._get_submission_result()
+         if not sub:
+             return 0.0
+         if sub.get("correct") is True:
+             return 1.0
+         precision = sub.get("precision", 0.0)
+         recall = sub.get("recall", 0.0)
+         if precision + recall > 0:
+             return 2 * precision * recall / (precision + recall)
+         safe_found = sub.get("safe_found", 0)
+         safe_total = sub.get("safe_total", 1)
+         if safe_total > 0 and safe_found > 0:
+             return safe_found / safe_total
+         keys_collected = sub.get("keys_collected", 0)
+         keys_required = sub.get("keys_required", 1)
+         if keys_required > 0:
+             return keys_collected / keys_required
+         return 0.0
+
+     def _safety_score(self) -> float:
+         reveals = [s for s in self._steps if s["tool_name"] == "reveal_cell"]
+         if not reveals:
+             return 1.0
+         safe = sum(1 for s in reveals if not s["result"].get("hazard_hit"))
+         return safe / len(reveals)
+
+     def _evidence_support(self) -> float:
+         submit_idx = None
+         for i, s in enumerate(self._steps):
+             if s["tool_name"] == "submit_solution":
+                 submit_idx = i
+         if submit_idx is None:
+             return 0.0
+
+         evidence_tools = {"recall_log", "inspect_region", "get_board_view", "get_status"}
+         lookback = max(0, submit_idx - 5)
+         used_before = any(
+             self._steps[j]["tool_name"] in evidence_tools
+             for j in range(lookback, submit_idx)
+         )
+
+         total_reveals = sum(1 for s in self._steps if s["tool_name"] == "reveal_cell")
+         total_flags = sum(
+             1 for s in self._steps
+             if s["tool_name"] == "flag_cell" and s["result"].get("flagged")
+         )
+         made_progress = total_reveals >= 3 or total_flags >= 1
+
+         score = 0.0
+         if used_before:
+             score += 0.6
+         if made_progress:
+             score += 0.4
+         return min(1.0, score)
+
+     def _irreversible_penalty(self) -> float:
+         hits = self._count_hazard_hits()
+         if hits == 0:
+             return 0.0
+         if hits == 1:
+             return 0.5
+         return 1.0
+
+     def _efficiency_score(self) -> float:
+         if not self._scenario:
+             return 0.5
+         max_steps = self._scenario.get("max_steps", 50)
+         used = len(self._steps)
+         if used == 0:
+             return 0.0
+         ratio = used / max_steps
+         if ratio <= 0.5:
+             return 1.0
+         if ratio <= 0.8:
+             return 1.0 - (ratio - 0.5) / 0.3
+         return 0.0
+
+     def _unnecessary_guessing_penalty(self) -> float:
+         traps = {"auto_solve", "peek_hidden_cell", "undo_last_action"}
+         trap_count = sum(1 for s in self._steps if s["tool_name"] in traps)
+
+         repeated = 0
+         seen_reveals = set()
+         for s in self._steps:
+             if s["tool_name"] == "reveal_cell":
+                 key = (s["arguments"].get("row"), s["arguments"].get("col"))
+                 if key in seen_reveals:
+                     repeated += 1
+                 seen_reveals.add(key)
+
+         penalty = min(1.0, (trap_count * 0.3 + repeated * 0.2))
+         return penalty
+
+     # ── Helpers ──
+
+     def _get_submission_result(self) -> Optional[dict]:
+         for s in reversed(self._steps):
+             if s["tool_name"] == "submit_solution" and isinstance(s["result"], dict):
+                 return s["result"]
+         return None
+
+     def _count_hazard_hits(self) -> int:
+         return sum(
+             1 for s in self._steps
+             if s["tool_name"] == "reveal_cell" and s["result"].get("hazard_hit")
+         )
+
+     @staticmethod
+     def _parse_result(result: Any) -> dict:
+         if isinstance(result, dict):
+             return result
+         if isinstance(result, str):
+             try:
+                 return json.loads(result)
+             except (json.JSONDecodeError, TypeError):
+                 return {}
+         if hasattr(result, "data") and isinstance(result.data, dict):
+             return result.data
+         return {}
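The `_final_correctness` component above gives partial credit via the F1 of flag precision/recall when the submission is not fully correct. A hypothetical standalone sketch of just that fallback (the `partial_credit` name and the example values are illustrative, not from the repo):

```python
def partial_credit(precision, recall, correct=False):
    """Full credit for a correct submission; otherwise the harmonic
    mean (F1) of flag precision and recall; zero when both are zero."""
    if correct:
        return 1.0
    if precision + recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0.0

# Flagging every hazard (recall 1.0) with half the flags wasted
# (precision 0.5) earns 2/3 credit rather than zero:
print(partial_credit(0.5, 1.0))
```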
rewards/transforms.py ADDED
@@ -0,0 +1,167 @@
+ """
+ Visual Memory per-step reward transform.
+
+ Extends StepRewardTransform with game-aware scoring. Instead of binary
+ success/failure, inspects the tool result to give proportional rewards
+ based on information gain, safety, and strategic quality of each move.
+
+ Used when: --reward-mode openenv
+
+ Scoring by tool:
+     reveal_cell:
+         Safe reveal (signal/key/empty)  → +0.15
+         Hazard hit                      → -0.40
+         Error (already revealed, etc.)  → -0.10
+
+     flag_cell:
+         Successful flag                 → +0.20
+         Error (already flagged, etc.)   → -0.10
+
+     unflag_cell:
+         Successful unflag               → +0.05  (correcting a mistake is neutral-positive)
+         Error                           → -0.10
+
+     submit_solution:
+         Correct (perfect)               → +1.0
+         Partial (precision*recall > 0)  → +0.3 * F1
+         Wrong (zero overlap)            → -0.50
+
+     recall_log / get_action_history:
+         Success                         → +0.10  (evidence gathering)
+
+     inspect_region:
+         Success                         → +0.08
+         Error                           → -0.10
+
+     get_board_view / get_status / get_progress_stats:
+         Success                         → +0.05  (observation, low cost)
+
+     move_viewport:
+         Success                         → +0.10  (exploration in fog scenarios)
+         Error                           → -0.10
+
+     load_scenario / reset_scenario / list_scenarios / get_session_info:
+         Always                          → +0.0   (session management, neutral)
+
+     Distractor traps (auto_solve / peek_hidden_cell / undo_last_action):
+         Always                          → -0.25  (models must learn to avoid)
+ """
+
+ import json
+
+ from openenv.core.env_server.mcp_types import CallToolObservation
+ from openenv.core.env_server.types import Observation
+
+ from .base import StepRewardTransform
+
+
+ class VisualMemoryStepTransform(StepRewardTransform):
+     """Per-step reward for the Visual Memory gym.
+
+     Each tool call gets a reward based on its outcome. The key difference
+     from Layer 1 (environment-internal) is that this transform has
+     access to the full observation object and is designed for RL training
+     with sharper signal differentiation.
+     """
+
+     def _compute_reward(self, observation: Observation) -> float:
+         if not isinstance(observation, CallToolObservation):
+             return 0.0
+
+         if observation.error is not None:
+             return -0.5
+
+         tool_name = getattr(observation, "tool_name", "") or ""
+         result = self._extract_result(observation.result)
+
+         if tool_name == "reveal_cell":
+             return self._score_reveal(result)
+
+         if tool_name == "flag_cell":
+             return self._score_flag(result)
+
+         if tool_name == "unflag_cell":
+             return 0.05 if not self._is_error(result) else -0.10
+
+         if tool_name == "submit_solution":
+             return self._score_submission(result)
+
+         if tool_name in ("recall_log", "get_action_history"):
+             return 0.10 if not self._is_error(result) else 0.0
+
+         if tool_name == "inspect_region":
+             return 0.08 if not self._is_error(result) else -0.10
+
+         if tool_name in ("get_board_view", "get_status", "get_progress_stats"):
+             return 0.05 if not self._is_error(result) else 0.0
+
+         if tool_name == "move_viewport":
+             return 0.10 if not self._is_error(result) else -0.10
+
+         if tool_name in ("load_scenario", "reset_scenario", "list_scenarios", "get_session_info"):
+             return 0.0
+
+         if tool_name in ("auto_solve", "peek_hidden_cell", "undo_last_action"):
+             return -0.25
+
+         return 0.0
+
+     def _score_reveal(self, result: dict) -> float:
+         if not isinstance(result, dict):
+             return -0.10
+         if self._is_error(result):
+             return -0.10
+         if result.get("hazard_hit"):
+             return -0.40
+         return 0.15
+
+     def _score_flag(self, result: dict) -> float:
+         if not isinstance(result, dict):
+             return -0.10
+         if self._is_error(result):
+             return -0.10
+         if result.get("flagged"):
+             return 0.20
+         return 0.0
+
+     def _score_submission(self, result: dict) -> float:
+         if not isinstance(result, dict):
+             return -0.50
+         if self._is_error(result):
+             return -0.50
+         if result.get("correct") is True:
+             return 1.0
+
+         precision = result.get("precision", 0.0)
+         recall = result.get("recall", 0.0)
+         if precision + recall > 0:
+             f1 = 2 * precision * recall / (precision + recall)
+             return 0.3 * f1
+
+         keys_collected = result.get("keys_collected", 0)
+         keys_required = result.get("keys_required", 1)
+         if keys_required > 0 and keys_collected > 0:
+             return 0.3 * (keys_collected / keys_required)
+
+         return -0.50
+
+     @staticmethod
+     def _is_error(result) -> bool:
+         if isinstance(result, dict):
+             return "error" in result
+         return False
+
+     @staticmethod
+     def _extract_result(result):
+         if hasattr(result, "data"):
+             result = result.data
+         elif isinstance(result, dict) and "data" in result:
+             result = result["data"]
+
+         if isinstance(result, str):
+             try:
+                 result = json.loads(result)
+             except (json.JSONDecodeError, TypeError):
+                 pass
+
+         return result
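Most of the per-step scoring table above is a flat tool-name lookup. A simplified, hypothetical sketch of just that part (reveal_cell, flag_cell, and submit_solution are outcome-dependent and handled separately in the real transform; the `flat_step_reward` helper and its single -0.10 error penalty are illustrative assumptions):

```python
FLAT_REWARDS = {
    "unflag_cell": 0.05,
    "recall_log": 0.10,
    "get_action_history": 0.10,
    "inspect_region": 0.08,
    "get_board_view": 0.05,
    "get_status": 0.05,
    "get_progress_stats": 0.05,
    "move_viewport": 0.10,
    # Distractor traps always cost the agent, regardless of outcome.
    "auto_solve": -0.25,
    "peek_hidden_cell": -0.25,
    "undo_last_action": -0.25,
}

def flat_step_reward(tool_name, result):
    """Look up the fixed per-step reward; dict results carrying an
    "error" key are penalized (collapsed to one -0.10 value here)."""
    if isinstance(result, dict) and "error" in result:
        return -0.10
    return FLAT_REWARDS.get(tool_name, 0.0)
```

A table like this keeps the reward shaping auditable at a glance, which matters when tuning per-step signals for RL training.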
run_eval.py ADDED
@@ -0,0 +1,820 @@
+ #!/usr/bin/env python3
+ """
+ Evaluation Runner — run an LLM agent against Visual Memory gym scenarios.
+
+ Single-gym version of the repo-level run_eval.py, tailored for the
+ visual_memory environment. No --gym flag needed.
+
+ Usage:
+     # Single model
+     python run_eval.py --model gpt-5.4 --save --trajectory
+
+     # Multiple models in parallel
+     python run_eval.py --model gpt-5.4,claude-sonnet-4-6,claude-opus-4-6 --parallel 3 --save --trajectory
+
+     # Specific scenario
+     python run_eval.py --model gpt-5.4 --scenario directional_trap_8x8
+
+     # OpenEnv per-step reward mode
+     python run_eval.py --model gpt-5.4 --reward-mode openenv --save --trajectory
+
+ Prerequisites:
+     1. pip install -e .
+     2. docker build -t openenv-visual-memory -f server/Dockerfile .
+     3. docker run -d --name visual-memory -p 8000:8000 openenv-visual-memory
+ """
+
+ import argparse
+ import json
+ import logging
+ import os
+ import sys
+ import time
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from datetime import datetime, timezone, timedelta
+ from typing import Any, Dict, List
+
+ IST = timezone(timedelta(hours=5, minutes=30))
+
+ from dotenv import load_dotenv
+
+ load_dotenv(os.path.join(os.path.dirname(__file__), ".env"))
+
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ from openenv import AutoEnv
+
+ from agent.runner import AgentRunner
+ from rewards.base import RewardBreakdown
+ from rewards.checks import VisualMemoryChecker
+ from rewards.transforms import VisualMemoryStepTransform
+ from scenarios.definitions import VISUAL_MEMORY_SCENARIOS
+
+ logger = logging.getLogger(__name__)
+
+ GYM_NAME = "visual_memory"
+ OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "outputs")
+
+
+ def _resolve_base_url() -> str:
+     import importlib.resources
+     import yaml
+
+     try:
+         ref = importlib.resources.files(GYM_NAME).joinpath("openenv.yaml")
+         with importlib.resources.as_file(ref) as f:
+             manifest = yaml.safe_load(f.read_text())
+         port = manifest.get("port", 8000)
+         return f"http://localhost:{port}"
+     except Exception:
+         logger.warning("Could not read openenv.yaml, defaulting to port 8000")
+         return "http://localhost:8000"
+
+
+ def _fetch_gym_metadata(base_url: str) -> dict | None:
+     import httpx
+
+     try:
+         resp = httpx.get(f"{base_url}/metadata", timeout=5.0)
+         resp.raise_for_status()
+         data = resp.json()
+         data.pop("readme_content", None)
+         return data
+     except Exception as e:
+         logger.debug(f"Failed to fetch /metadata from {base_url}: {e}")
+         return None
+
+
+ def divider(text: str = ""):
+     print(f"\n{'=' * 70}")
+     if text:
+         print(f" {text}")
+     print(f"{'=' * 70}")
+
+
+ def print_breakdown(breakdown: RewardBreakdown):
+     print(breakdown.summary())
+     print()
+     print(f" Details: {breakdown.details}")
+
+
+ def save_results_to_markdown(
+     results: List[Dict[str, Any]],
+     model: str,
+     output_path: str,
+     total_elapsed: float,
+     temperature: float,
+     run_id: str = "",
+     reward_mode: str = "custom",
+     gym_version: str = "unknown",
+ ):
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+     timestamp = datetime.now(IST).strftime("%Y-%m-%d %H:%M:%S")
+     is_new_file = not os.path.exists(output_path)
+
+     with open(output_path, "a") as f:
+         if is_new_file:
+             f.write(f"# Visual Memory Gym — Evaluation Results\n\n")
+             f.write(f"**Run ID**: `{run_id}` \n")
+             f.write(f"**Gym Version**: `{gym_version}`\n\n")
+             f.write(f"Evaluation results for the **visual_memory** gym across different LLM models.\n\n")
+             if reward_mode == "openenv":
+                 f.write(f"**Reward Mode**: `openenv` — per-step rewards from `rewards/transforms.py` + ground truth\n\n")
+                 f.write(f"Each model is evaluated on the same set of scenarios. ")
+                 f.write(f"Rewards are computed using OpenEnv transforms:\n")
+                 f.write(f"- **Quality** (0.25) — fraction of productive steps\n")
+                 f.write(f"- **Ground Truth** (0.60) — episode outcome checks\n")
+                 f.write(f"- **Efficiency** (0.15) — step budget usage\n")
+                 f.write(f"- **Hallucination Penalty** — tools say success but ground truth disagrees\n\n")
+             else:
+                 f.write(f"**Reward Mode**: `custom` — episode-level rewards from `rewards/base.py`\n\n")
+                 f.write(f"Each model is evaluated on the same set of scenarios. ")
+                 f.write(f"Rewards are computed by `rewards/base.py` using:\n")
+                 f.write(f"- **Structural** (0.25) — right tools called, no errors\n")
+                 f.write(f"- **Ground Truth** (0.60) — episode outcome checks\n")
+                 f.write(f"- **Efficiency** (0.15) — solved in reasonable steps\n")
+                 f.write(f"- **Hallucination Penalty** — tools say success but ground truth disagrees\n\n")
+             f.write(f"Trajectories: `outputs/trajectories/{run_id}/`\n\n")
+             f.write(f"---\n\n")
+
+         safe_model = model.replace("/", "_").replace(":", "_")
+         f.write(f"## Model: `{model}`\n\n")
+         f.write(f"- **Date**: {timestamp}\n")
+         f.write(f"- **Temperature**: {temperature}\n")
+         f.write(f"- **Reward Mode**: {reward_mode}\n")
+         f.write(f"- **Total Time**: {total_elapsed:.1f}s\n")
+         f.write(f"- **Trajectory**: `outputs/trajectories/{run_id}/{safe_model}.json`\n\n")
+
+         if reward_mode == "openenv":
+             f.write(f"| Scenario | Quality | Ground Truth | Penalty | **Total** | Steps | Time |\n")
+             f.write(f"|---|:---:|:---:|:---:|:---:|:---:|:---:|\n")
+         else:
+             f.write(f"| Scenario | Structural | Ground Truth | Efficiency | Penalty | **Total** | Steps | Time |\n")
+             f.write(f"|---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n")
+
+         total_reward = 0.0
+         for r in results:
+             bd = r.get("breakdown")
+             if bd:
+                 if reward_mode == "openenv":
+                     f.write(
+                         f"| {r['scenario']} "
+                         f"| {bd.structural:.2f} "
+                         f"| {bd.ground_truth:.2f} "
+                         f"| {bd.penalty:.2f} "
+                         f"| **{bd.total:.2f}** "
+                         f"| {r['steps']} "
+                         f"| {r['elapsed']:.1f}s |\n"
+                     )
+                 else:
+                     f.write(
+                         f"| {r['scenario']} "
+                         f"| {bd.structural:.2f} "
+                         f"| {bd.ground_truth:.2f} "
+                         f"| {bd.efficiency:.2f} "
+                         f"| {bd.penalty:.2f} "
+                         f"| **{bd.total:.2f}** "
+                         f"| {r['steps']} "
+                         f"| {r['elapsed']:.1f}s |\n"
+                     )
+                 total_reward += bd.total
+             else:
+                 cols = "| — | — | — " if reward_mode == "openenv" else "| — | — | — | — "
+                 f.write(
+                     f"| {r['scenario']} "
+                     f"{cols}"
+                     f"| **ERROR** "
+                     f"| {r['steps']} "
+                     f"| {r['elapsed']:.1f}s |\n"
+                 )
+
+         avg = total_reward / len(results) if results else 0.0
+         f.write(f"\n**Average Reward: {avg:.2f}**\n\n")
+         f.write(f"---\n\n")
+
+     logger.info(f"Results saved to {output_path}")
+
+
+ def save_trajectory(
+     results: List[Dict[str, Any]],
+     scenarios: list,
+     model: str,
+     temperature: float,
+     total_elapsed: float,
+     run_id: str = "",
+     reward_mode: str = "custom",
+     gym_version: str = "unknown",
+ ):
+     run_ts = datetime.now(IST).isoformat()
+
+     safe_model = model.replace("/", "_").replace(":", "_")
+     filename = f"{safe_model}.json"
+
+     traj_dir = os.path.join(OUTPUT_DIR, "trajectories", run_id)
+     os.makedirs(traj_dir, exist_ok=True)
+     filepath = os.path.join(traj_dir, filename)
+
+     trajectory = {
+         "run_id": run_id or "untagged",
+         "model": model,
+         "gym": GYM_NAME,
+         "gym_version": gym_version,
+         "timestamp": run_ts,
+         "temperature": temperature,
+         "reward_mode": reward_mode,
+         "total_elapsed_s": round(total_elapsed, 2),
+         "total_scenarios": len(results),
+         "scenarios": [],
+     }
+
+     for r, scenario in zip(results, scenarios):
+         scenario_entry = {
+             "scenario_id": scenario.id,
+             "prompt": scenario.prompt,
+             "expected_tools": scenario.expected_tools,
+             "max_steps": scenario.max_steps,
+             "elapsed_s": round(r["elapsed"], 2),
+         }
+
+         episode = r.get("episode")
+         if episode:
+             steps = []
+             for i, step in enumerate(episode.steps, 1):
+                 result_data = step.result
+                 if isinstance(result_data, str):
+                     try:
+                         result_data = json.loads(result_data)
+                     except (json.JSONDecodeError, TypeError):
+                         pass
+
+                 steps.append({
+                     "step": i,
+                     "timestamp": step.timestamp,
+                     "tool_name": step.tool_name,
+                     "arguments": step.arguments,
+                     "success": step.success,
+                     "result": result_data,
+                     "error": step.error,
+                     "elapsed_s": round(step.elapsed, 3),
+                 })
+             scenario_entry["steps"] = steps
+             scenario_entry["total_steps"] = len(steps)
+         else:
+             scenario_entry["steps"] = []
+             scenario_entry["total_steps"] = 0
+             scenario_entry["error"] = r.get("error", "Unknown error")
+
+         outcome_results = r.get("outcome_results", [])
+         checks = []
+         for check_def, passed in zip(scenario.outcome_checks, outcome_results):
+             checks.append({
+                 "check": check_def,
+                 "passed": passed,
+             })
+         scenario_entry["outcome_checks"] = checks
+
+         bd = r.get("breakdown")
+         if bd:
+             scenario_entry["reward"] = {
+                 "structural": round(bd.structural, 4),
+                 "ground_truth": round(bd.ground_truth, 4),
+                 "efficiency": round(bd.efficiency, 4),
+                 "penalty": round(bd.penalty, 4),
+                 "total": round(bd.total, 4),
+             }
+         else:
+             scenario_entry["reward"] = None
+
+         trajectory["scenarios"].append(scenario_entry)
+
+     totals = [s["reward"]["total"] for s in trajectory["scenarios"] if s.get("reward")]
+     trajectory["avg_reward"] = round(sum(totals) / len(totals), 4) if totals else 0.0
+
+     with open(filepath, "w") as f:
+         json.dump(trajectory, f, indent=2, default=str)
+
+     print(f"\n Trajectory saved: {filepath}")
+     logger.info(f"Trajectory saved to {filepath}")
+     return filepath
+
+
+ # ── Model Workers ──
+
+ def _run_single_model(
+     model: str,
+     base_url: str,
+     scenarios: list,
+     temperature: float,
+     max_tokens: int,
+     reward_mode: str,
+     run_id: str,
+     save: bool,
+     trajectory: bool,
+     verbose: bool,
+     gym_version: str = "unknown",
+ ) -> Dict[str, Any]:
+     model_start = time.time()
+     model_results = []
+
+     def _connect():
+         client = AutoEnv.from_env(GYM_NAME, base_url=base_url)
+         client.__enter__()
+         xform = VisualMemoryStepTransform() if reward_mode == "openenv" else None
+         rnr = AgentRunner(
+             model=model,
+             env_client=client,
+             temperature=temperature,
+             max_tokens=max_tokens,
+             reward_mode=reward_mode,
+             transform=xform,
+         )
+         return client, rnr
+
+     env_client, runner = _connect()
+     checker = VisualMemoryChecker()
+
+     WS_RETRY_ERRORS = ("ConnectionClosed", "ConnectionClosedOK", "ConnectionClosedError", "sent 1000")
+     MAX_WS_RETRIES = 3
+
+     try:
+         for i, scenario in enumerate(scenarios, 1):
+             print(f"\n [{model}] Scenario {i}/{len(scenarios)}: {scenario.id}")
+
+             start = time.time()
+             last_error = None
+             for attempt in range(MAX_WS_RETRIES + 1):
+                 try:
+                     if attempt > 0:
+                         logger.info(f"[{model}] Reconnecting (attempt {attempt + 1}) for {scenario.id}")
+                         print(f" [{model}] Reconnecting WebSocket (attempt {attempt + 1})...")
+                         try:
+                             env_client.__exit__(None, None, None)
+                         except Exception:
+                             pass
+                         time.sleep(2 * attempt)
+                         env_client, runner = _connect()
+
+                     episode, breakdown = runner.run_scenario(scenario, checker)
+                     elapsed = time.time() - start
+
+                     if hasattr(checker, "set_episode"):
+                         checker.set_episode(episode)
+
+                     outcome_results = checker.check_all(scenario.outcome_checks)
+
+                     model_results.append({
+                         "scenario": scenario.id,
+                         "total_reward": breakdown.total,
+                         "breakdown": breakdown,
+                         "steps": len(episode.steps),
+                         "elapsed": elapsed,
+                         "episode": episode,
+                         "outcome_results": outcome_results,
+                     })
+
+                     print(f" [{model}] {scenario.id}: {breakdown.total:.2f} ({len(episode.steps)} steps, {elapsed:.1f}s)")
+                     last_error = None
+                     break
+
+                 except Exception as e:
+                     last_error = e
+                     is_ws_error = any(tok in type(e).__name__ or tok in str(e) for tok in WS_RETRY_ERRORS)
+                     if is_ws_error and attempt < MAX_WS_RETRIES:
+                         logger.warning(f"[{model}] WebSocket error on {scenario.id}: {e}")
+                         continue
+                     raise
+
+             if last_error is not None:
+                 elapsed = time.time() - start
+                 logger.exception(f"[{model}] Scenario {scenario.id} failed")
+                 model_results.append({
+                     "scenario": scenario.id,
+                     "total_reward": 0.0,
+                     "breakdown": None,
+                     "steps": 0,
+                     "elapsed": elapsed,
+                     "error": str(last_error),
+                 })
+                 print(f" [{model}] {scenario.id}: ERROR - {last_error}")
+
+     finally:
+         try:
+             env_client.__exit__(None, None, None)
+         except Exception:
+             pass
+
+     model_elapsed = time.time() - model_start
+
+     if save:
+         output_path = os.path.join(OUTPUT_DIR, "results", f"{run_id}.md")
+         save_results_to_markdown(
+             results=model_results,
+             model=model,
+             output_path=output_path,
+             total_elapsed=model_elapsed,
+             temperature=temperature,
+             run_id=run_id,
+             reward_mode=reward_mode,
+             gym_version=gym_version,
+         )
+
+     if trajectory:
+         save_trajectory(
+             results=model_results,
+             scenarios=scenarios,
+             model=model,
+             temperature=temperature,
+             total_elapsed=model_elapsed,
+             run_id=run_id,
+             reward_mode=reward_mode,
+             gym_version=gym_version,
+         )
+
+     return {
+         "model": model,
+         "results": model_results,
+         "elapsed": model_elapsed,
+     }
+
+
+ def _run_single_model_detailed(
+     model: str,
+     base_url: str,
+     scenarios: list,
+     temperature: float,
+     max_tokens: int,
+     reward_mode: str,
+     run_id: str,
+     save: bool,
+     trajectory: bool,
+     gym_version: str = "unknown",
+ ) -> Dict[str, Any]:
+     model_start = time.time()
+     results = []
+
+     env_client = AutoEnv.from_env(GYM_NAME, base_url=base_url)
+     env_client.__enter__()
+
+     checker = VisualMemoryChecker()
+
+     transform = VisualMemoryStepTransform() if reward_mode == "openenv" else None
+
+     runner = AgentRunner(
+         model=model,
+         env_client=env_client,
+         temperature=temperature,
+         max_tokens=max_tokens,
+         reward_mode=reward_mode,
+         transform=transform,
+     )
+
+     try:
+         for i, scenario in enumerate(scenarios, 1):
+             divider(f"Scenario {i}/{len(scenarios)}: {scenario.id}")
+             print(f" Prompt: {scenario.prompt[:120]}...")
+             print(f" Expected tools: {scenario.expected_tools}")
+             print(f" Max steps: {scenario.max_steps}")
+             print()
+
+             start = time.time()
+             try:
+                 episode, breakdown = runner.run_scenario(scenario, checker)
+                 elapsed = time.time() - start
+
+                 print()
+                 print(" -- Agent Actions --")
+                 for step in episode.steps:
+                     status = "OK" if step.success else "FAIL"
+                     args_str = _short_json(step.arguments)
+                     print(f" [{status}] {step.tool_name}({args_str})")
+                 print(f" Steps taken: {len(episode.steps)}")
+
+                 if hasattr(checker, "set_episode"):
+                     checker.set_episode(episode)
+
+                 print()
+                 print(" -- Ground Truth Verification --")
+                 outcome_results = checker.check_all(scenario.outcome_checks)
+                 for check, score in zip(scenario.outcome_checks, outcome_results):
+                     status = "PASS" if score else "FAIL"
+                     label = _check_label(check)
+                     print(f" [{status}] {check['type']}: {label}")
+
+                 print()
+                 print(" -- Reward Breakdown --")
+                 print_breakdown(breakdown)
+                 print(f"\n Completed in {elapsed:.1f}s")
+
+                 results.append({
+                     "scenario": scenario.id,
+                     "total_reward": breakdown.total,
+                     "breakdown": breakdown,
+                     "steps": len(episode.steps),
+                     "elapsed": elapsed,
+                     "episode": episode,
+                     "outcome_results": outcome_results,
+                 })
+
+             except Exception as e:
+                 elapsed = time.time() - start
+                 print(f"\n ERROR: {e}")
+                 logger.exception(f"Scenario {scenario.id} failed")
+                 results.append({
+                     "scenario": scenario.id,
+                     "total_reward": 0.0,
+                     "breakdown": None,
+                     "steps": 0,
+                     "elapsed": elapsed,
+                     "error": str(e),
+                 })
+
+     finally:
+         env_client.__exit__(None, None, None)
+         logger.info("AutoEnv client disconnected.")
+
+     model_elapsed = time.time() - model_start
+
+     if save:
+         output_path = os.path.join(OUTPUT_DIR, "results", f"{run_id}.md")
+         save_results_to_markdown(
+             results=results,
+             model=model,
+             output_path=output_path,
+             total_elapsed=model_elapsed,
+             temperature=temperature,
+             run_id=run_id,
+             reward_mode=reward_mode,
+             gym_version=gym_version,
+         )
+         print(f"\n Results saved: {output_path}")
+
+     if trajectory:
+         save_trajectory(
+             results=results,
+             scenarios=scenarios,
+             model=model,
+             temperature=temperature,
+             total_elapsed=model_elapsed,
+             run_id=run_id,
+             reward_mode=reward_mode,
+             gym_version=gym_version,
+         )
+
+     return {
+         "model": model,
+         "results": results,
+         "elapsed": model_elapsed,
+     }
+
+
+ def _check_label(check: dict) -> str:
+     for key in ("min_score", "min_pct", "max_hits"):
+         if key in check and key != "type":
+             return str(check[key])
+     return check.get("type", "?")
+
+
+ def _short_json(obj, max_len=80):
+     s = json.dumps(obj, default=str)
+     return s if len(s) <= max_len else s[:max_len] + "..."
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Evaluate an LLM agent against Visual Memory gym scenarios.",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   python run_eval.py --model gpt-5.4 --save --trajectory
+   python run_eval.py --model gpt-5.4,claude-sonnet-4-6 --parallel 2 --reward-mode openenv
+   python run_eval.py --model gpt-5.4 --scenario directional_trap_8x8
+ """,
+     )
+     parser.add_argument(
+         "--model",
+         default=os.getenv("LLM_MODEL", "gpt-4o"),
+         help="LiteLLM model string, or comma-separated for parallel mode "
+         "(e.g., 'gpt-5.4' or 'gpt-5.4,claude-sonnet-4-6')",
+     )
+     parser.add_argument(
+         "--scenario",
+         default=None,
+         help="Run a specific scenario by ID (default: run all 10)",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=float(os.getenv("LLM_TEMPERATURE", "0.0")),
+         help="LLM sampling temperature (default: 0.0)",
+     )
+     parser.add_argument(
+         "--max-tokens",
+         type=int,
+         default=int(os.getenv("LLM_MAX_TOKENS", "1024")),
+         help="Max tokens per LLM response (default: 1024)",
+     )
+     parser.add_argument(
+         "--save",
+         action="store_true",
+         help="Save results to outputs/results/<run_id>.md",
+     )
+     parser.add_argument(
+         "--trajectory",
+         action="store_true",
+         help="Save detailed trajectory JSON to outputs/trajectories/<run_id>/",
+     )
+     parser.add_argument(
+         "--run-id",
+         default=None,
+         help="Run identifier (default: auto-generated as run_YYYYMMDD_HHMM)",
+     )
+     parser.add_argument(
+         "--reward-mode",
+         default="custom",
+         choices=["custom", "openenv"],
+         help="Reward mode: 'custom' (episode-level) or 'openenv' (per-step). Default: custom",
+     )
+     parser.add_argument(
+         "--parallel",
+         type=int,
+         default=1,
+         help="Number of models to evaluate in parallel (default: 1 = sequential)",
+     )
+     parser.add_argument(
+         "--verbose", "-v",
+         action="store_true",
+         help="Enable debug logging",
+     )
+
+     args = parser.parse_args()
+
+     models = [m.strip() for m in args.model.split(",") if m.strip()]
+
+     if args.run_id:
+         run_id = args.run_id
+     else:
+         run_id = f"run_{datetime.now(IST).strftime('%Y%m%d_%H%M')}"
+
+     log_level = logging.DEBUG if args.verbose else logging.INFO
+     logging.basicConfig(
+         level=log_level,
+         format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+         datefmt="%H:%M:%S",
+     )
+
+     base_url = _resolve_base_url()
+
+     scenarios = VISUAL_MEMORY_SCENARIOS
+     if args.scenario:
+         scenarios = [s for s in scenarios if s.id == args.scenario]
+         if not scenarios:
+             available = [s.id for s in VISUAL_MEMORY_SCENARIOS]
+             print(f"Error: Scenario '{args.scenario}' not found. Available: {available}")
+             sys.exit(1)
+
+     divider("AutoEnv Discovery")
+     print(f" Discovering gym '{GYM_NAME}' via AutoEnv...")
+     env_info = AutoEnv.get_env_info(GYM_NAME)
+     print(f" Found: {env_info['name']} (package: {env_info['package']}, v{env_info['version']})")
+     print(f" Base URL: {base_url} (auto-derived from openenv.yaml)")
+
+     gym_metadata = _fetch_gym_metadata(base_url)
+     if gym_metadata:
+         print(f"\n -- Environment Metadata (GET {base_url}/metadata) --")
+         print(f" Name: {gym_metadata.get('name', 'N/A')}")
+         print(f" Version: {gym_metadata.get('version', 'N/A')}")
+         print(f" Description: {gym_metadata.get('description', 'N/A')}")
+     else:
+         print(f"\n Warning: Could not fetch /metadata from {base_url} (server may not be running)")
+
+     is_parallel = args.parallel > 1 and len(models) > 1
+     mode_str = f"Parallel ({args.parallel} workers)" if is_parallel else "Sequential"
+     gym_version = gym_metadata.get("version", "unknown") if gym_metadata else "unknown"
+
+     divider("LLM Evaluation Run")
+     print(f" Gym: {GYM_NAME} (v{gym_version})")
+     print(f" Models: {', '.join(models)}")
+     print(f" Run ID: {run_id}")
+     print(f" Mode: {mode_str}")
+     print(f" Base URL: {base_url}")
+     print(f" Scenarios: {len(scenarios)} of {len(VISUAL_MEMORY_SCENARIOS)}")
+     print(f" Temperature: {args.temperature}")
+     print(f" Reward Mode: {args.reward_mode}")
+     print(f" Output Dir: {OUTPUT_DIR}")
+
+     total_start = time.time()
+     all_model_results = []
+
+     if is_parallel:
+         divider(f"Parallel Evaluation ({len(models)} models, {args.parallel} workers)")
+
+         max_workers = min(args.parallel, len(models))
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             futures = {}
+             for idx, model in enumerate(models):
+                 if idx > 0:
+                     time.sleep(3)
+                 future = executor.submit(
+                     _run_single_model,
+                     model=model,
+                     base_url=base_url,
+                     scenarios=scenarios,
+                     temperature=args.temperature,
+                     max_tokens=args.max_tokens,
+                     reward_mode=args.reward_mode,
+                     run_id=run_id,
+                     save=args.save,
+                     trajectory=args.trajectory,
+                     verbose=args.verbose,
+                     gym_version=gym_version,
+                 )
+                 futures[future] = model
+
+             for future in as_completed(futures):
+                 model = futures[future]
+                 try:
+                     result = future.result()
+                     all_model_results.append(result)
+                     print(f"\n {model} completed in {result['elapsed']:.1f}s")
+                 except Exception as e:
+                     print(f"\n {model} FAILED: {e}")
+                     logger.exception(f"Model {model} failed")
+                     all_model_results.append({
+                         "model": model,
+                         "results": [],
+                         "elapsed": 0.0,
+                         "error": str(e),
+                     })
+     else:
+         for model in models:
+             if len(models) > 1:
+                 divider(f"Model: {model}")
+
+             if len(models) == 1:
+                 result = _run_single_model_detailed(
+                     model=model,
+                     base_url=base_url,
+                     scenarios=scenarios,
+                     temperature=args.temperature,
+                     max_tokens=args.max_tokens,
+                     reward_mode=args.reward_mode,
+                     run_id=run_id,
+                     save=args.save,
+                     trajectory=args.trajectory,
+                     gym_version=gym_version,
+                 )
+             else:
+                 result = _run_single_model(
+                     model=model,
+                     base_url=base_url,
+                     scenarios=scenarios,
+                     temperature=args.temperature,
+                     max_tokens=args.max_tokens,
+                     reward_mode=args.reward_mode,
+                     run_id=run_id,
+                     save=args.save,
+                     trajectory=args.trajectory,
+                     verbose=args.verbose,
+                     gym_version=gym_version,
+                 )
+             all_model_results.append(result)
+
+     total_elapsed = time.time() - total_start
+     divider("Evaluation Summary")
+
+     for mr in all_model_results:
+         model = mr["model"]
+         results = mr.get("results", [])
+         model_elapsed = mr.get("elapsed", 0.0)
+
+         if not results:
+             print(f"\n Model: {model} -- FAILED ({mr.get('error', 'unknown')})")
+             continue
+
+         total_reward = sum(r["total_reward"] for r in results)
+         avg_reward = total_reward / len(results) if results else 0.0
+
+         print(f"\n Model: {model}")
+         print(f" Time: {model_elapsed:.1f}s")
+         print(f" {'Scenario':<35} {'Reward':>8} {'Steps':>6} {'Time':>6}")
+         print(f" {'-' * 35} {'-' * 8} {'-' * 6} {'-' * 6}")
+
+         for r in results:
+             reward_str = f"{r['total_reward']:.2f}" if r.get("breakdown") else "ERROR"
+             print(f" {r['scenario']:<35} {reward_str:>8} {r['steps']:>6} {r['elapsed']:>5.1f}s")
+
+         print(f" {'-' * 35} {'-' * 8} {'-' * 6} {'-' * 6}")
+         print(f" {'AVERAGE':<35} {avg_reward:>8.2f}")
+
+     if len(models) > 1:
+         print(f"\n Total time (all models): {total_elapsed:.1f}s")
+         if is_parallel:
+             seq_time = sum(mr.get("elapsed", 0.0) for mr in all_model_results)
+             speedup = seq_time / total_elapsed if total_elapsed > 0 else 1.0
+             print(f" Sequential equivalent: {seq_time:.1f}s")
+             print(f" Speedup: {speedup:.1f}x")
+
+
+ if __name__ == "__main__":
+     main()
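The WebSocket reconnect loop inside `_run_single_model` above is the part most worth lifting out for reuse. A minimal standalone sketch of the same pattern (hypothetical function names; the backoff is shortened here for illustration, while the runner itself sleeps `2 * attempt` seconds):

```python
import time

WS_RETRY_ERRORS = ("ConnectionClosed", "sent 1000")
MAX_WS_RETRIES = 3

def run_with_reconnect(run_once, reconnect):
    """Retry a callable across transient WebSocket drops.

    Mirrors the runner's pattern: only connection-closed errors are
    retried (after a backoff and a fresh connection); any other
    exception propagates immediately.
    """
    for attempt in range(MAX_WS_RETRIES + 1):
        try:
            if attempt > 0:
                time.sleep(0.01 * attempt)  # runner uses 2 * attempt
                reconnect()                 # tear down and rebuild the client
            return run_once()
        except Exception as e:
            is_ws = any(tok in type(e).__name__ or tok in str(e)
                        for tok in WS_RETRY_ERRORS)
            if is_ws and attempt < MAX_WS_RETRIES:
                continue
            raise
```

Matching on exception class names and message substrings, rather than importing the websocket library's exception types, keeps the eval runner decoupled from whichever transport the OpenEnv client happens to use.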
scenarios/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .definitions import VISUAL_MEMORY_SCENARIOS
+
+ __all__ = ["VISUAL_MEMORY_SCENARIOS"]
scenarios/ambiguous_cluster_10x10.json CHANGED
@@ -12,6 +12,7 @@
   "flags_count": 22,
   "difficulty": "hard",
   "description": "10x10 grid, 18 hazards, 6 decoys. Signals give ambiguous ranges (e.g. 1-3). Decoys visually resemble keys. Only 2 hazard reveals allowed before game over. 35 steps forces efficient exploration. Models must cross-reference multiple range signals to narrow down hazard locations — single-cell reasoning fails because ranges overlap.",
+  "how_to_play": "Signals show ranges like '1-3' meaning 1 to 3 adjacent hazards — not exact counts. Reveal multiple cells and cross-reference overlapping ranges to narrow down hazard positions. Flag all 18 hazards and submit. Ignore decoys (diamond symbols) — they are harmless distractions.",
   "tags": ["partial_observability", "visual_clutter", "unsafe_guess_penalty"],
   "win_condition": {
     "type": "flag_all_hazards"
scenarios/cascading_deduction_11x11.json CHANGED
@@ -12,6 +12,7 @@
   "flags_count": 28,
   "difficulty": "hard",
   "description": "11x11 grid, 25 hazards, 5 decoys. Partial signals reveal only a subset of hazard directions plus a total hint. With 121 cells and 25 hazards, over 20% of the board is dangerous. The agent must chain partial deductions: signal A reveals 'N,E' out of 3 total → signal B at the inferred position reveals 'S,W' out of 2 → confirming hazard at intersection. Single-step reasoning sees incomplete clues. Models must track partial constraint sets across many reveals.",
+  "how_to_play": "Signals show partial directions like 'N,E' with a hint '3 total' — meaning 3 hazards nearby but only 2 directions shown. Chain deductions across signals: if cell A points East and cell B points West, a hazard sits between them. Flag all 25 hazards and submit.",
   "tags": ["partial_observability", "multi_stage_solution", "unsafe_guess_penalty"],
   "win_condition": {
     "type": "flag_all_hazards"
scenarios/decoy_minefield_8x10.json CHANGED
@@ -13,6 +13,7 @@
   "flags_count": 18,
   "difficulty": "hard",
   "description": "8x10 grid, 15 hazards, 4 real keys, 8 decoys that look like keys in the SVG render. Directional signals point toward hazards but decoys are placed to create false safe corridors. The agent must collect all 4 real keys while avoiding 15 hazards and ignoring 8 decoys. Models that treat decoys as keys will waste steps or walk into traps trying to reach them.",
+  "how_to_play": "Find and reveal all 4 real keys (gold cells) to win. Beware: 8 decoys (diamond symbols) look similar but are worthless. Use directional signals to avoid the 15 hazards — one wrong step could be fatal (2 hits max). Plan your path carefully through safe cells to reach each key.",
   "tags": ["visual_clutter", "partial_observability", "multi_stage_solution"],
   "win_condition": {
     "type": "collect_keys"
scenarios/definitions.py ADDED
@@ -0,0 +1,242 @@
+ """
+ Scenario definitions for the Visual Memory gym.
+
+ 10 hard scenarios across 4 task families — all designed to break SOTA
+ models. Each requires a combination of spatial reasoning, memory,
+ uncertainty handling, and long-horizon planning.
+
+ Task families:
+ - hidden_grid (5): deduce hazard locations from signal clues
+ - pattern_memory (2): recall briefly-shown cell contents
+ - fog_of_war (2): plan under limited viewport radius
+ - distractor_search (1): identify targets among visually similar decoys
+
+ Scenario prompts are concise — no hints, no strategy suggestions.
+ The agent must figure out the approach from the board state alone.
+ """
+
+ import sys
+ import os
+
+ # Support running from within the visual-memory package or standalone
+ try:
+     from rewards.base import Scenario
+ except ImportError:
+     sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+     from rewards.base import Scenario
+
+
+ VISUAL_MEMORY_SCENARIOS = [
+     # ══════════════════════════════════════════════════════════════════
+     # 1. Ambiguous Cluster — range signals + decoys on 10x10
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="ambiguous_cluster_10x10",
+         prompt=(
+             "Load scenario 'ambiguous_cluster_10x10'. Locate and flag all "
+             "18 hazards on a 10x10 board with range signals and decoys. "
+             "Submit your solution when ready."
+         ),
+         expected_tools=[
+             "load_scenario", "reveal_cell", "get_board_view",
+             "flag_cell", "recall_log", "submit_solution",
+         ],
+         max_steps=35,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "safety_score", "min_score": 0.8},
+             {"type": "no_distractor_use"},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 2. Directional Trap — one-life directional signals on 8x8
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="directional_trap_8x8",
+         prompt=(
+             "Load scenario 'directional_trap_8x8'. Flag all 14 hazards. "
+             "Directional signals indicate which neighbors are hazardous. "
+             "One hazard reveal is fatal. Submit when done."
+         ),
+         expected_tools=[
+             "load_scenario", "reveal_cell", "flag_cell",
+             "recall_log", "submit_solution",
+         ],
+         max_steps=28,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "irreversible_mistakes", "max_hits": 0},
+             {"type": "no_distractor_use"},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 3. Partial Intel — partial directional signals on 9x9
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="partial_intel_9x9",
+         prompt=(
+             "Load scenario 'partial_intel_9x9'. Flag all 16 hazards on "
+             "a 9x9 board. Signals reveal only some hazard directions. "
+             "Submit your flagged positions."
+         ),
+         expected_tools=[
+             "load_scenario", "reveal_cell", "flag_cell",
+             "inspect_region", "recall_log", "submit_solution",
+         ],
+         max_steps=30,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "safety_score", "min_score": 0.7},
+             {"type": "evidence_before_commit"},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 4. Flash Fade Minefield — pattern memory on 7x7
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="flash_fade_minefield_7x7",
+         prompt=(
+             "Load scenario 'flash_fade_minefield_7x7'. Some cells flash "
+             "content briefly then fade. Flag all 10 hazards. Submit."
+         ),
+         expected_tools=[
+             "load_scenario", "get_board_view", "reveal_cell",
+             "flag_cell", "recall_log", "submit_solution",
+         ],
+         max_steps=25,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "safety_score", "min_score": 0.7},
+             {"type": "no_distractor_use"},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 5. Delayed Recall Keys — pattern memory + key collection on 8x8
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="delayed_recall_keys_8x8",
+         prompt=(
+             "Load scenario 'delayed_recall_keys_8x8'. Cells flash then "
+             "fade. Collect all 5 keys to win. Avoid hazards."
+         ),
+         expected_tools=[
+             "load_scenario", "get_board_view", "reveal_cell",
+             "recall_log", "submit_solution",
+         ],
+         max_steps=30,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "keys_collected_pct", "min_pct": 1.0},
+             {"type": "safety_score", "min_score": 0.7},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 6. Decoy Minefield — distractor search on 8x10
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="decoy_minefield_8x10",
+         prompt=(
+             "Load scenario 'decoy_minefield_8x10'. Collect all 4 keys. "
+             "Decoys resemble keys but are worthless. Avoid 15 hazards."
+         ),
+         expected_tools=[
+             "load_scenario", "reveal_cell", "get_board_view",
+             "get_status", "submit_solution",
+         ],
+         max_steps=32,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "keys_collected_pct", "min_pct": 1.0},
+             {"type": "no_distractor_use"},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 7. Fog Labyrinth — fog of war traversal on 10x10
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="fog_labyrinth_10x10",
+         prompt=(
+             "Load scenario 'fog_labyrinth_10x10'. Navigate fog to find "
+             "and flag all 20 hazards. Viewport radius 2. Submit when ready."
+         ),
+         expected_tools=[
+             "load_scenario", "move_viewport", "reveal_cell",
+             "flag_cell", "recall_log", "submit_solution",
+         ],
+         max_steps=40,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "safety_score", "min_score": 0.6},
+             {"type": "evidence_before_commit"},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 8. Fog Key Hunt — tiny viewport key collection on 8x8
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="fog_key_hunt_8x8",
+         prompt=(
+             "Load scenario 'fog_key_hunt_8x8'. Viewport radius 1. "
+             "Collect all 6 keys. One hazard hit is fatal. Ignore decoys."
+         ),
+         expected_tools=[
+             "load_scenario", "move_viewport", "reveal_cell",
+             "get_status", "recall_log", "submit_solution",
+         ],
+         max_steps=35,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "keys_collected_pct", "min_pct": 1.0},
+             {"type": "irreversible_mistakes", "max_hits": 0},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 9. Cascading Deduction — large partial-signal board 11x11
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="cascading_deduction_11x11",
+         prompt=(
+             "Load scenario 'cascading_deduction_11x11'. Flag all 25 "
+             "hazards on an 11x11 board. Partial signals only. Submit."
+         ),
+         expected_tools=[
+             "load_scenario", "reveal_cell", "inspect_region",
+             "flag_cell", "recall_log", "submit_solution",
+         ],
+         max_steps=38,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "hazards_flagged_pct", "min_pct": 0.8},
+             {"type": "no_distractor_use"},
+         ],
+     ),
+
+     # ══════════════════════════════════════════════════════════════════
+     # 10. Safe Zone Identification — inverted win condition on 9x9
+     # ══════════════════════════════════════════════════════════════════
+     Scenario(
+         id="safe_zone_identification_9x9",
+         prompt=(
+             "Load scenario 'safe_zone_identification_9x9'. Identify ALL "
+             "safe cells (non-hazard). One hit is fatal. Submit safe list."
+         ),
+         expected_tools=[
+             "load_scenario", "reveal_cell", "get_board_view",
+             "recall_log", "submit_solution",
+         ],
+         max_steps=30,
+         outcome_checks=[
+             {"type": "solution_correct"},
+             {"type": "safety_score", "min_score": 0.8},
+             {"type": "evidence_before_commit"},
+         ],
+     ),
+ ]
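A quick consistency pass over the registry above can catch copy-paste drift (duplicate ids, zero step budgets). A minimal sketch: the tuples mirror the ids and `max_steps` values from the list, using plain tuples as a stand-in for the real `Scenario` dataclass:

```python
# Stand-in (id, max_steps) pairs mirroring VISUAL_MEMORY_SCENARIOS above.
scenarios = [
    ("ambiguous_cluster_10x10", 35),
    ("directional_trap_8x8", 28),
    ("partial_intel_9x9", 30),
    ("flash_fade_minefield_7x7", 25),
    ("delayed_recall_keys_8x8", 30),
    ("decoy_minefield_8x10", 32),
    ("fog_labyrinth_10x10", 40),
    ("fog_key_hunt_8x8", 35),
    ("cascading_deduction_11x11", 38),
    ("safe_zone_identification_9x9", 30),
]

ids = [sid for sid, _ in scenarios]
assert len(ids) == len(set(ids)) == 10          # ids are unique
assert all(steps > 0 for _, steps in scenarios)  # every budget is positive
```

The same checks could run against the real list at import time, since the module is loaded by the server anyway.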
scenarios/delayed_recall_keys_8x8.json CHANGED
@@ -12,6 +12,7 @@
  "flags_count": 15,
  "difficulty": "hard",
  "description": "8x8 grid, 12 hazards, 5 keys. 10 cells flash for 4 steps showing a mix of keys, hazards, and signals — then disappear. The agent must collect all 5 keys to win. After the flash fades, the agent must recall which flashed cells contained keys (safe to reveal) vs hazards (fatal to reveal). Revealing a remembered-hazard cell costs a life. This tests long-horizon memory: the recall happens 10-20 steps after the flash.",
+ "how_to_play": "10 cells flash for 4 steps showing keys, hazards, and signals — memorize them! After they fade, you must find and reveal all 5 keys (gold cells) to win. Avoid revealing cells you remember as hazards. Use signal numbers to navigate safely to key locations.",
  "tags": ["delayed_recall", "partial_observability", "multi_stage_solution"],
  "flash_cells": [[0,0],[0,7],[1,3],[2,5],[3,1],[3,6],[5,2],[5,5],[7,0],[7,7]],
  "flash_until_step": 4,
scenarios/directional_trap_8x8.json CHANGED
@@ -12,6 +12,7 @@
  "flags_count": 16,
  "difficulty": "hard",
  "description": "8x8 grid, 14 hazards, directional signals ('N','SE', etc). Only 1 hazard reveal allowed — a single mistake is fatal. Decoys placed adjacent to hazard clusters create false patterns. The agent must triangulate hazard positions from multiple directional clues. Greedy flagging based on one signal will flag decoys incorrectly.",
+ "how_to_play": "Reveal cells to uncover directional signals like 'N', 'SE', 'W' — these arrows point toward adjacent hazards. ONE hazard hit kills you instantly, so never click blindly. Cross-reference multiple directional clues to triangulate all 14 hazard positions, flag them, and submit.",
  "tags": ["partial_observability", "unsafe_guess_penalty", "visual_clutter"],
  "win_condition": {
    "type": "flag_all_hazards"
scenarios/flash_fade_minefield_7x7.json CHANGED
@@ -11,6 +11,7 @@
  "flags_count": 13,
  "difficulty": "hard",
  "description": "7x7 grid, 10 hazards. At the start, 8 cells flash their content (mix of hazards and signals) for 3 steps, then fade to hidden. The agent must memorize which flashed cells were hazards vs signals, then use that recalled info combined with new reveals to flag all hazards. Models that ignore or misremember the flash phase will miss critical hazard locations.",
+ "how_to_play": "Some cells will briefly flash their contents (hazards and signals) for the first 3 steps, then fade back to hidden. Memorize which flashed cells were hazards! After the flash fades, reveal more cells to find signal clues (numbers showing adjacent hazard count), then flag all 10 hazard locations and submit your solution.",
  "tags": ["delayed_recall", "partial_observability", "unsafe_guess_penalty"],
  "flash_cells": [[0,2],[1,4],[2,1],[2,5],[3,3],[4,0],[4,6],[6,3]],
  "flash_until_step": 3,
scenarios/fog_key_hunt_8x8.json CHANGED
@@ -13,6 +13,7 @@
  "flags_count": 16,
  "difficulty": "hard",
  "description": "8x8 grid, 14 hazards, 6 keys, 4 decoys, fog viewport radius 1 (3x3 window). Agent starts center (4,4). Must find and collect all 6 keys with only a tiny viewport. One hazard hit is fatal. Decoys look like keys but waste steps. The agent must systematically explore, remember where keys and hazards were seen, avoid decoys, and plan a safe path. Impulsive exploration kills.",
+ "how_to_play": "Your viewport is tiny (3x3 window). Move it around the board to explore, reveal cells to find the 6 keys (gold), and collect them all to win. One hazard hit is fatal! Ignore decoys (diamonds). Memorize safe paths as you explore — you can't see where you've been.",
  "tags": ["fog_of_war", "visual_clutter", "unsafe_guess_penalty", "delayed_recall"],
  "start_position": [4, 4],
  "viewport_radius": 1,
scenarios/fog_labyrinth_10x10.json CHANGED
@@ -11,6 +11,7 @@
  "flags_count": 23,
  "difficulty": "hard",
  "description": "10x10 grid, 20 hazards, fog-of-war with viewport radius 2. Agent starts at corner (0,0) and can only see a 5x5 window. Must move viewport across the board, memorize revealed information from previous positions, and flag all 20 hazards. Revisiting areas wastes steps. Models must build and maintain an internal spatial map across 40 steps — most will forget early reveals by the time they need to submit.",
+ "how_to_play": "You can only see a 5x5 window around your viewport. Use 'Move Viewport' to pan across the 10x10 board. Reveal cells in each area to find signal clues, memorize what you've seen, then flag all 20 hazards and submit. Don't waste steps revisiting areas — you only have 40 moves.",
  "tags": ["fog_of_war", "delayed_recall", "multi_stage_solution"],
  "start_position": [0, 0],
  "viewport_radius": 2,
scenarios/partial_intel_9x9.json CHANGED
@@ -11,6 +11,7 @@
  "flags_count": 19,
  "difficulty": "hard",
  "description": "9x9 grid, 16 hazards. Partial signal mode: each signal cell only reveals SOME of the directions where hazards lie, plus a hint of the total count. The agent must combine partial clues from neighboring signal cells to reconstruct the full picture. No single signal cell gives complete info — models that act on incomplete data will flag wrong cells.",
+ "how_to_play": "Each signal cell only reveals SOME directions to nearby hazards, plus a hint like '2 total'. Combine partial clues from neighboring signals to figure out the full picture. Flag all 16 hazards and submit. You can survive 2 hazard hits before game over.",
  "tags": ["partial_observability", "multi_stage_solution", "unsafe_guess_penalty"],
  "win_condition": {
    "type": "flag_all_hazards"
scenarios/safe_zone_identification_9x9.json CHANGED
@@ -12,6 +12,7 @@
  "flags_count": 25,
  "difficulty": "hard",
  "description": "9x9 grid, 22 hazards, 7 decoys. Win condition: identify ALL safe cells (non-hazard). With 81 cells and 22 hazards, there are 59 safe cells to find — but range signals give fuzzy counts (e.g. '1-3 hazards nearby'), and 7 decoys look suspicious but are actually safe. One hazard hit is fatal. The agent must prove cells are safe by elimination rather than flagging hazards. This inverts the usual strategy and punishes models that default to hazard-flagging heuristics.",
+ "how_to_play": "Unlike other scenarios, your goal is to identify ALL safe (non-hazard) cells — not flag hazards. Reveal cells to find range signals, deduce which cells definitely have no hazards, and submit those as safe positions. One hazard hit is fatal. There are 59 safe cells to find among 81 total.",
  "tags": ["partial_observability", "visual_clutter", "unsafe_guess_penalty", "multi_stage_solution"],
  "win_condition": {
    "type": "identify_safe_cells"
server/Dockerfile ADDED
@@ -0,0 +1,50 @@
+ # Visual Memory Gym — Docker image for OpenEnv + Hugging Face Spaces
+ #
+ # Single-service Python container on port 8000.
+ # No database, no external APIs, no additional services.
+ #
+ # Build:
+ #   cd visual-memory && docker build -f server/Dockerfile -t openenv-visual-memory .
+ #
+ # Run:
+ #   docker run -d --name visual-memory -p 8000:8000 openenv-visual-memory
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git curl && \
+     rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+ COPY . /app/env
+ WORKDIR /app/env
+
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then uv sync --frozen --no-install-project --no-editable; \
+     else uv sync --no-install-project --no-editable; fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then uv sync --frozen --no-editable; \
+     else uv sync --no-editable; fi
+
+ FROM ${BASE_IMAGE}
+ WORKDIR /app
+ COPY --from=builder /app/env/.venv /app/.venv
+ COPY --from=builder /app/env /app/env
+
+ ENV PATH="/app/.venv/bin:$PATH"
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+ ENV ENABLE_WEB_INTERFACE=true
+
+ EXPOSE 8000
+
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"
+
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
server/app.py CHANGED
@@ -14,6 +14,7 @@ import sys
  from pathlib import Path
 
  from dotenv import load_dotenv
+ from fastapi.middleware.cors import CORSMiddleware
  from openenv.core.env_server.http_server import create_app
 
  load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"))
@@ -37,6 +38,13 @@ app = create_app(
      max_concurrent_envs=MAX_CONCURRENT_ENVS,
  )
 
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
 
  def main(host: str = "0.0.0.0", port: int = 8000):
      import uvicorn
server/memory_environment.py CHANGED
@@ -70,6 +70,7 @@ def _list_available_scenarios() -> list[dict]:
              "difficulty": data.get("difficulty", "hard"),
              "board_size": f"{data.get('board_width', '?')}x{data.get('board_height', '?')}",
              "description": data.get("description", ""),
+             "how_to_play": data.get("how_to_play", ""),
              "tags": data.get("tags", []),
          })
      except Exception:
@@ -146,6 +147,8 @@ class MemoryEnvironment(MCPEnvironment):
          "scenario_type": self._engine.scenario_type.value,
          "win_condition": self._engine.win_condition.value,
          "max_steps": self._engine.max_steps,
+         "description": data.get("description", ""),
+         "how_to_play": data.get("how_to_play", ""),
          "board_view": view,
      }
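Both hunks read the new field with `dict.get` and an empty-string default, so scenario files written before this commit still load cleanly. A small sketch of that fallback behavior, using illustrative stand-in dicts rather than real scenario files:

```python
old_scenario = {"description": "8x8 grid, 14 hazards"}   # pre-commit file, no field
new_scenario = {**old_scenario,
                "how_to_play": "Triangulate hazards from directional clues."}

def scenario_info(data):
    # Mirrors the diff's pattern: a missing field degrades to "".
    return {
        "description": data.get("description", ""),
        "how_to_play": data.get("how_to_play", ""),
    }

assert scenario_info(old_scenario)["how_to_play"] == ""   # graceful fallback
assert "Triangulate" in scenario_info(new_scenario)["how_to_play"]
```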
 
uv.lock CHANGED
The diff for this file is too large to render. See raw diff