discovery-env / app.py
echoboi
Fix G06 500 error: call env.reset() on session creation to init hidden state
3dbfa3d
#!/usr/bin/env python3
"""Discovery Environment REST API β€” Hugging Face Space.
Mirrors the 4 MCP tools as HTTP endpoints with session-per-problem management.
Each session is tied to one problem (G01-G08) and tracks query count for scoring.
Auth: set API_KEY env var in HF Space secrets β†’ pass X-Api-Key header.
If API_KEY is not set, the API is open.
"""
import json
import os
import random
import re
import secrets
import threading
import time
import uuid
from collections import Counter
from typing import Any

import numpy as np
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# discovery_env package lives alongside this file (copied for Docker build)
from discovery_env import get_problem, list_problems
from discovery_env.scoring import (
    _generate_test_states,
    compile_law,
    stripped_code_length,
)
# ── Config ────────────────────────────────────────────────────────────────────
# Shared secret expected in the X-Api-Key header; an empty string disables auth.
API_KEY = os.environ.get("API_KEY", "")
# Session time-to-live in seconds (default 1800 = 30 min), enforced by _cleanup().
# Overridable through the SESSION_TTL environment variable, like API_KEY above.
SESSION_TTL = int(os.environ.get("SESSION_TTL", "1800"))
# ── App ───────────────────────────────────────────────────────────────────────
# Long-form description shown in the generated OpenAPI docs.
_API_DESCRIPTION = (
    "Black-box physics discovery environments (G01–G08). "
    "Create a session for a problem, then query, simulate, and submit rules. "
    "Scoring: accuracy (0–1.0) + parsimony (0–0.2) + efficiency (0–0.1) = max 1.3."
)
app = FastAPI(
    title="Discovery Environment API",
    version="1.0.0",
    description=_API_DESCRIPTION,
)
# ── Session store ─────────────────────────────────────────────────────────────
_sessions: dict[str, dict] = {}
_lock = threading.Lock()
def _cleanup() -> None:
    """Drop every session that has been idle for longer than SESSION_TTL seconds.

    A session's idle clock starts at ``last_used`` (refreshed by _get_session on
    every lookup), falling back to ``created_at`` for entries without it.
    """
    cutoff = time.time() - SESSION_TTL
    with _lock:
        expired = [
            sid
            for sid, sess in _sessions.items()
            if sess.get("last_used", sess["created_at"]) < cutoff
        ]
        for sid in expired:
            del _sessions[sid]


def _get_session(session_id: str) -> dict:
    """Return the live session dict for ``session_id`` and mark it as used.

    Expired sessions are purged first. Refreshing ``last_used`` makes
    SESSION_TTL an inactivity timeout rather than a hard cap measured from
    creation, so a client that keeps working is not cut off mid-task.

    Raises:
        HTTPException(404): the id is unknown or its session has expired.
    """
    _cleanup()
    with _lock:
        sess = _sessions.get(session_id)
        if sess is None:
            raise HTTPException(
                404, f"Session '{session_id}' not found or expired (TTL {SESSION_TTL}s)"
            )
        sess["last_used"] = time.time()
        return sess
# ── Auth ──────────────────────────────────────────────────────────────────────
def _auth(x_api_key: str = Header(default="")) -> None:
    """FastAPI dependency enforcing the optional shared API key.

    When API_KEY is empty the API is open. Otherwise the X-Api-Key header must
    match API_KEY exactly. The comparison is constant-time so response timing
    does not reveal how much of a guessed key was correct; both sides are
    compared as UTF-8 bytes because compare_digest rejects non-ASCII ``str``.

    Raises:
        HTTPException(403): the header is missing or does not match.
    """
    if not API_KEY:
        return
    if not secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(403, "Invalid or missing X-Api-Key header")
# ── Request models ────────────────────────────────────────────────────────────
class SessionCreate(BaseModel):
    """Request body for POST /session: the problem id the new session is bound to."""

    problem_id: str = "G01"
class RandomStateRequest(BaseModel):
    """Request body for POST /session/{id}/random_state.

    A positive seed gives a reproducible grid; 0 (or any non-positive value)
    makes the endpoint use an unseeded generator.
    """

    seed: int = 0
class SimulateRequest(BaseModel):
    """Request body for POST /session/{id}/simulate.

    ``state_json`` is the starting grid in any format _parse_state() accepts;
    ``n_steps`` is clamped to the range 1-100 by the endpoint.
    """

    state_json: Any # accepts str, list, or nested — see _parse_state()
    n_steps: int
class SubmitRequest(BaseModel):
    """Request body for POST /session/{id}/submit_rule.

    ``code`` is Python source that must define predict_next(grid) -> grid.
    """

    code: str
# ── JSON serialisation helper ─────────────────────────────────────────────────
class _NpEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, np.ndarray):
return o.tolist()
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
return super().default(o)
def _parse_state(v: Any) -> np.ndarray:
"""Robustly parse state_json into a numpy array.
Handles all formats models commonly send:
- list / nested list (model sent JSON array directly, pydantic parsed it)
- str containing JSON array: "[[1,2],[3,4]]"
- double-encoded str: "\"[[1,2],[3,4]]\"" (model JSON-stringified twice)
- str of Python repr: "array([[1,2],[3,4]])" β€” stripped to inner list
Raises ValueError with a descriptive message on failure.
"""
MAX_DEPTH = 3
for _ in range(MAX_DEPTH):
if isinstance(v, (list, tuple)):
return np.array(v)
if isinstance(v, np.ndarray):
return v
if isinstance(v, str):
v = v.strip()
# strip leading/trailing extra quotes added by double-encoding
if (v.startswith('"') and v.endswith('"')) or \
(v.startswith("'") and v.endswith("'")):
v = v[1:-1]
continue
try:
v = json.loads(v) # str -> list (or another str if double-encoded)
continue
except json.JSONDecodeError:
pass
# last resort: strip non-JSON wrapper like "array([[...]])"
bracket = v.find("[[")
if bracket != -1:
v = v[bracket:]
rbracket = v.rfind("]]")
if rbracket != -1:
v = v[: rbracket + 2]
continue
break
raise ValueError(
f"Cannot parse state_json (type={type(v).__name__}). "
"Expected a JSON array string like '[[1,2],[3,4]]' or a nested list."
)
def _resp(data) -> JSONResponse:
    """Wrap ``data`` in a JSONResponse after coercing NumPy values to plain JSON.

    A dumps/loads round trip through _NpEncoder turns arrays and NumPy scalars
    into lists and built-in numbers before they are handed to JSONResponse.
    """
    plain = json.loads(json.dumps(data, cls=_NpEncoder))
    return JSONResponse(content=plain)
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/health", tags=["meta"])
def health():
    """Liveness check — also returns the number of unexpired sessions.

    Expired sessions are purged first, and the store is read under the lock,
    so the count is not inflated by sessions whose TTL has already elapsed.
    """
    _cleanup()
    with _lock:
        active = len(_sessions)
    return {"status": "ok", "active_sessions": active}
@app.get("/problems", tags=["meta"])
def problems(_: None = Depends(_auth)):
    """Return the catalogue of available problems, with metadata, from list_problems()."""
    catalogue = list_problems()
    return catalogue
@app.post("/session", tags=["session"])
def create_session(body: SessionCreate, _: None = Depends(_auth)):
    """Create a new session for one problem. Returns a session_id.

    The response also echoes the problem id and the session TTL in seconds.

    Raises:
        HTTPException(400): get_problem() rejected the problem id with KeyError.
    """
    try:
        env = get_problem(body.problem_id)
    except KeyError as exc:
        # str(KeyError) wraps its message in repr quotes; report the raw message.
        detail = exc.args[0] if exc.args else f"Unknown problem_id {body.problem_id!r}"
        raise HTTPException(400, str(detail)) from exc
    # Initialises hidden state (e.g. G06 orientation) before any scoring call.
    env.reset(seed=42)
    sid = str(uuid.uuid4())
    now = time.time()
    with _lock:
        _sessions[sid] = {
            "env": env,
            "problem_id": body.problem_id,
            "query_count": 0,
            "created_at": now,
            "last_used": now,
        }
    return {"session_id": sid, "problem_id": body.problem_id, "ttl_seconds": SESSION_TTL}
@app.get("/session/{session_id}/info", tags=["tools"])
def get_system_info(session_id: str, _: None = Depends(_auth)):
    """Get system info: grid shape, value range. Does NOT reveal the update rule.

    Only a fixed whitelist of keys from the environment's get_state_shape()
    metadata is exposed; a missing key raises KeyError.
    """
    environment = _get_session(session_id)["env"]
    shape_meta = environment.get_state_shape()
    public_keys = ("type", "rows", "cols", "values", "description")
    return {key: shape_meta[key] for key in public_keys}
# Leading "lo-hi" integer pair in an environment's "values" metadata, e.g.
# "0-1", "0 - 3" or "-1-1" (a minus directly before a number belongs to it).
_VALUE_RANGE_RE = re.compile(r"\s*(-?\d+)\s*-\s*(-?\d+)")


def _value_range(spec: Any) -> tuple[int, int]:
    """Parse the inclusive integer range from a ``values`` spec such as ``"0-2"``.

    Unlike ``spec.split("-")`` this accepts negative bounds (``"-1-1"``) and
    trailing text after the range.

    Raises:
        HTTPException(500): no range could be found. This points to a
            misconfigured environment rather than a bad request.
    """
    match = _VALUE_RANGE_RE.match(str(spec))
    if match is None:
        raise HTTPException(
            500, f"Cannot parse value range {spec!r} from environment metadata"
        )
    return int(match.group(1)), int(match.group(2))


@app.post("/session/{session_id}/random_state", tags=["tools"])
def random_state(session_id: str, body: RandomStateRequest, _: None = Depends(_auth)):
    """Generate a random initial state. seed=0 means random.

    Any non-positive seed uses an unseeded generator; positive seeds are
    reproducible. Cell values are drawn uniformly from the environment's
    inclusive value range. Counts as one query.
    """
    sess = _get_session(session_id)
    sess["query_count"] += 1
    meta = sess["env"].get_state_shape()
    rows, cols = meta["rows"], meta["cols"]
    lo, hi = _value_range(meta.get("values", "0-1"))
    rng = np.random.default_rng(body.seed if body.seed > 0 else None)
    state = rng.integers(lo, hi + 1, size=(rows, cols))
    return _resp({"seed": body.seed, "shape": [rows, cols], "state": state})
@app.post("/session/{session_id}/simulate", tags=["tools"])
def simulate(session_id: str, body: SimulateRequest, _: None = Depends(_auth)):
    """Advance the session's environment from a caller-supplied grid.

    ``n_steps`` is clamped into [1, 100]. The call counts as one query even if
    ``state_json`` turns out to be malformed (HTTP 400). The response includes
    every intermediate grid, the final grid, and how many cells differ between
    the start and final grids.
    """
    session = _get_session(session_id)
    session["query_count"] += 1
    environment = session["env"]

    try:
        start_grid = _parse_state(body.state_json)
    except ValueError as err:
        raise HTTPException(400, f"Invalid state_json: {err}")

    steps = max(1, min(body.n_steps, 100))
    environment.set_initial_conditions(start_grid)
    frames = [environment.step(1).tolist() for _step in range(steps)]
    last_frame = frames[-1]
    changed = int(np.sum(start_grid != np.array(last_frame)))

    return _resp(
        {
            "n_steps": steps,
            "cells_changed": changed,
            "final_state": last_frame,
            "trajectory": frames,
        }
    )
# Neighbour offsets (row, col) reported for each mis-predicted cell, in output order.
_NEIGHBOUR_OFFSETS = (
    ("N", (-1, 0)),
    ("S", (1, 0)),
    ("E", (0, 1)),
    ("W", (0, -1)),
    ("NW", (-1, -1)),
    ("NE", (-1, 1)),
    ("SW", (1, -1)),
    ("SE", (1, 1)),
)
# Longest exception text echoed back to the client in "first_error".
_MAX_ERROR_CHARS = 300


def _describe_error(exc: BaseException) -> str:
    """Render ``exc`` as "ExcType: message", truncated to _MAX_ERROR_CHARS.

    Submitted code may define exceptions whose __str__ itself fails, so this
    falls back to the bare type name instead of crashing the request.
    """
    try:
        text = f"{type(exc).__name__}: {exc}"
    except Exception:
        text = type(exc).__name__
    return text[:_MAX_ERROR_CHARS]


def _score_states(fn, env, states) -> dict:
    """Score ``fn`` against ``env.get_true_next`` for every grid in ``states``.

    A grid is scored when ``fn`` returns an ``np.ndarray`` with the same shape
    as the true next grid. Anything else counts as a failure and earns zero
    credit: a wrong type or shape, or an exception raised by the submitted code.
    Failures are tallied rather than silently ignored.

    Returns a dict with keys:
        cell_acc_sum: sum over grids of the fraction of cells predicted correctly.
        exact: number of grids predicted exactly.
        failed: number of grids that could not be scored.
        first_error: description of the first exception raised by ``fn``, or None.
        worst: (accuracy, input, predicted, expected) for the first grid with the
            lowest accuracy below 1.0, or None if no scored grid was wrong.
    """
    cell_acc_sum = 0.0
    exact = 0
    failed = 0
    first_error = None
    worst = None
    worst_acc = 1.0
    for state in states:
        # Ground truth is computed outside the try: an environment failure is a
        # server-side problem and should surface, not be blamed on the rule.
        expected = env.get_true_next(state)
        try:
            # The rule gets a copy, so mutating its input cannot alter `state`,
            # which is kept for the worst-grid diagnostics.
            predicted = fn(state.copy())
            if not (
                isinstance(predicted, np.ndarray) and predicted.shape == expected.shape
            ):
                failed += 1
                continue
            # Partial credit: fraction of cells predicted correctly
            state_acc = float(np.mean(predicted == expected))
            cell_acc_sum += state_acc
            if state_acc < worst_acc:
                worst_acc = state_acc
                worst = (state_acc, state, predicted, expected)
            if np.array_equal(predicted, expected):
                exact += 1
        except Exception as exc:  # untrusted submitted code: any failure scores zero
            failed += 1
            if first_error is None:
                first_error = _describe_error(exc)
    return {
        "cell_acc_sum": cell_acc_sum,
        "exact": exact,
        "failed": failed,
        "first_error": first_error,
        "worst": worst,
    }


def _diagnose_worst(acc, inp, pred, exp) -> tuple[list, dict, list]:
    """Describe how ``pred`` differs from ``exp`` for the input grid ``inp``.

    Returns ``(cell_errors, error_regions, common_error_patterns)``:
        cell_errors: up to 10 wrong cells, each with its input value, its 8
            neighbours (indices wrap around the grid edges via modulo) and the
            predicted and expected values.
        error_regions: wrong-cell counts per grid quadrant, the total, and
            ``acc`` rounded to 4 places as ``worst_state_accuracy``.
        common_error_patterns: the 5 most frequent (predicted, expected) pairs.
    Grids that are not 2-D get empty diagnostics instead of an unpacking error.
    """
    if pred.ndim != 2:
        return [], {}, []
    rows, cols = pred.shape
    wrong_mask = pred != exp

    cell_errors = []
    for r, c in np.argwhere(wrong_mask)[:10].tolist():
        entry = {"pos": [r, c], "center": int(inp[r, c])}
        for label, (dr, dc) in _NEIGHBOUR_OFFSETS:
            entry[label] = int(inp[(r + dr) % rows, (c + dc) % cols])
        entry["predicted"] = int(pred[r, c])
        entry["expected"] = int(exp[r, c])
        cell_errors.append(entry)

    mid_r, mid_c = rows // 2, cols // 2
    error_regions = {
        "top_left": int(wrong_mask[:mid_r, :mid_c].sum()),
        "top_right": int(wrong_mask[:mid_r, mid_c:].sum()),
        "bottom_left": int(wrong_mask[mid_r:, :mid_c].sum()),
        "bottom_right": int(wrong_mask[mid_r:, mid_c:].sum()),
        "total": int(wrong_mask.sum()),
        "worst_state_accuracy": round(acc, 4),
    }

    wrong_pairs = zip(pred[wrong_mask].tolist(), exp[wrong_mask].tolist())
    common_error_patterns = [
        {"predicted": p, "expected": e, "count": n}
        for (p, e), n in Counter(wrong_pairs).most_common(5)
    ]
    return cell_errors, error_regions, common_error_patterns


@app.post("/session/{session_id}/submit_rule", tags=["tools"])
def submit_rule(session_id: str, body: SubmitRequest, _: None = Depends(_auth)):
    """Score a proposed update rule. Code must define predict_next(grid) -> grid.

    Steps:
      1. Score the rule on 500 freshly seeded test grids. This yields the
         functional and exact accuracy, plus diagnostics for the worst grid.
      2. Score it on a fixed 200-grid validation set.
      3. Add a parsimony bonus (code length vs. the reference solution) and an
         efficiency bonus (queries used).

    Counts as one query. The response also includes:
      - ``failed_states`` / ``val_failed_states``: grids where the rule raised,
        or returned a non-array or wrong-shaped result.
      - ``first_error``: the first exception message, if any.
    """
    sess = _get_session(session_id)
    sess["query_count"] += 1
    queries = sess["query_count"]
    env = sess["env"]

    # NOTE(review): compile_law lives in discovery_env.scoring (not visible here).
    # If it exec()s the submitted source in-process, as appears likely, any
    # authorised client can run arbitrary Python (and infinite loops) on this
    # server — confirm, and consider a sandbox and a time limit.
    fn = compile_law(body.code)
    if fn is None:
        return {"functional_accuracy": 0.0, "error": "Could not compile code", "total": 0.0}

    # Fresh random seed each submission — prevents overfitting to a fixed test set
    eval_seed = random.randint(1, 999_999)
    test_states = _generate_test_states(env, n=500, seed=eval_seed)
    test = _score_states(fn, env, test_states)

    cell_errors, error_regions, common_error_patterns = [], {}, []
    if test["worst"] is not None:
        cell_errors, error_regions, common_error_patterns = _diagnose_worst(*test["worst"])

    n_test = len(test_states)
    # functional_accuracy = mean cell-level accuracy across all test states
    accuracy = test["cell_acc_sum"] / n_test if n_test else 0.0
    exact_accuracy = test["exact"] / n_test if n_test else 0.0

    # Held-out validation set: fixed seed, so the same grids for every submission.
    val_states = _generate_test_states(env, n=200, seed=42_000)
    val = _score_states(fn, env, val_states)
    n_val = len(val_states)
    val_accuracy = val["cell_acc_sum"] / n_val if n_val else 0.0
    val_exact_accuracy = val["exact"] / n_val if n_val else 0.0

    agent_dl = stripped_code_length(body.code)
    try:
        ref_dl = stripped_code_length(env.__class__.reference_code())
    except NotImplementedError:
        ref_dl = 0
    delta_dl = max(0, agent_dl - ref_dl)
    parsimony = 0.2 * max(0.0, 1.0 - delta_dl / 300)
    efficiency = 0.1 * max(0.0, 1.0 - queries / 60)

    return _resp({
        "functional_accuracy": round(accuracy, 6),
        "val_accuracy": round(val_accuracy, 6),
        "parsimony_bonus": round(parsimony, 4),
        "efficiency_bonus": round(efficiency, 4),
        "total": round(accuracy + parsimony + efficiency, 4),
        "correct_states": test["exact"],
        "total_states": n_test,
        "exact_accuracy": round(exact_accuracy, 6),
        "val_exact_accuracy": round(val_exact_accuracy, 6),
        "eval_seed": eval_seed,
        "queries_used": queries,
        "code_length": len(body.code.strip()),
        "agent_dl": agent_dl,
        "reference_dl": ref_dl,
        "delta_dl": delta_dl,
        "cell_errors": cell_errors,
        "error_regions": error_regions,
        "common_error_patterns": common_error_patterns,
        "failed_states": test["failed"],
        "val_failed_states": val["failed"],
        "first_error": test["first_error"] or val["first_error"],
    })
# ── Dashboard (agent monitor) ─────────────────────────────────────────────────
# Served at GET / and GET /api/* — does not affect any existing API routes.
# Late import (hence noqa: E402): placed after all API route definitions above.
from dashboard_routes import router as _dashboard_router # noqa: E402
app.include_router(_dashboard_router)
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 7860.
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=7860)