#!/usr/bin/env python3
"""Discovery Environment REST API — Hugging Face Space.

Mirrors the 4 MCP tools as HTTP endpoints with session-per-problem management.
Each session is tied to one problem (G01-G08) and tracks query count for scoring.

Auth: set API_KEY env var in HF Space secrets → pass X-Api-Key header.
      If API_KEY is not set, the API is open.
"""

import hmac
import json
import os
import random
import threading
import time
import uuid
from collections import Counter
from typing import Any

import numpy as np
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# discovery_env package lives alongside this file (copied for Docker build)
from discovery_env import get_problem, list_problems
from discovery_env.scoring import (
    _generate_test_states,
    compile_law,
    stripped_code_length,
)

# ── Config ────────────────────────────────────────────────────────────────────

API_KEY = os.environ.get("API_KEY", "")
SESSION_TTL = 1800  # seconds — sessions expire after 30 min of inactivity

# ── App ───────────────────────────────────────────────────────────────────────

app = FastAPI(
    title="Discovery Environment API",
    description=(
        "Black-box physics discovery environments (G01–G08). "
        "Create a session for a problem, then query, simulate, and submit rules. "
        "Scoring: accuracy (0–1.0) + parsimony (0–0.2) + efficiency (0–0.1) = max 1.3."
    ),
    version="1.0.0",
)

# ── Session store ─────────────────────────────────────────────────────────────

# session_id -> {"env", "problem_id", "query_count", "created_at", "last_access"}
_sessions: dict[str, dict] = {}
# Guards every read/write of _sessions and of the per-session counters.
_lock = threading.Lock()


def _cleanup() -> None:
    """Drop every session that has been idle for longer than SESSION_TTL.

    Idle time is measured from ``last_access`` (refreshed by ``_get_session``);
    entries that lack it fall back to ``created_at``.
    """
    cutoff = time.time() - SESSION_TTL
    with _lock:
        expired = [
            sid
            for sid, sess in _sessions.items()
            if sess.get("last_access", sess["created_at"]) < cutoff
        ]
        for sid in expired:
            del _sessions[sid]


def _get_session(session_id: str) -> dict:
    """Return the live session dict for *session_id* and mark it as just used.

    Raises:
        HTTPException: 404 when the id is unknown or the session has expired.
    """
    _cleanup()
    with _lock:
        sess = _sessions.get(session_id)
        if sess is None:
            raise HTTPException(
                404, f"Session '{session_id}' not found or expired (TTL {SESSION_TTL}s)"
            )
        # Refresh the inactivity clock so an active session is not evicted.
        sess["last_access"] = time.time()
        return sess


def _count_query(sess: dict) -> int:
    """Increment the session's query counter under the lock; return the new value."""
    with _lock:
        sess["query_count"] += 1
        return sess["query_count"]


# ── Auth ──────────────────────────────────────────────────────────────────────

def _auth(x_api_key: str = Header(default="")) -> None:
    """FastAPI dependency: reject the request unless X-Api-Key matches API_KEY.

    When API_KEY is empty the API is open and every request passes.

    Raises:
        HTTPException: 403 on a missing or wrong key.
    """
    if not API_KEY:
        return
    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    supplied = x_api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(403, "Invalid or missing X-Api-Key header")


# ── Request models ────────────────────────────────────────────────────────────

class SessionCreate(BaseModel):
    problem_id: str = "G01"


class RandomStateRequest(BaseModel):
    seed: int = 0


class SimulateRequest(BaseModel):
    state_json: Any  # accepts str, list, or nested — see _parse_state()
    n_steps: int


class SubmitRequest(BaseModel):
    code: str


# ── JSON serialisation helper ─────────────────────────────────────────────────

class _NpEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy arrays and scalars to plain Python."""

    def default(self, o):
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            return float(o)
        return super().default(o)


def _parse_state(v: Any) -> np.ndarray:
    """Robustly parse state_json into a numpy array.

    Handles all formats models commonly send:
      - list / nested list (model sent JSON array directly, pydantic parsed it)
      - str containing JSON array: '[[1,2],[3,4]]'
      - double-encoded str: the JSON text wrapped in an extra pair of quotes
      - str of Python repr: 'array([[1,2],[3,4]])' — stripped to inner list

    At most MAX_DEPTH string layers are unwrapped.

    Raises:
        ValueError: with a descriptive message when no array can be recovered.
    """
    MAX_DEPTH = 3
    for _ in range(MAX_DEPTH):
        if not isinstance(v, str):
            break
        v = v.strip()
        # strip leading/trailing extra quotes added by double-encoding
        if (v.startswith('"') and v.endswith('"')) or \
           (v.startswith("'") and v.endswith("'")):
            v = v[1:-1]
            continue
        try:
            v = json.loads(v)  # str -> list (or another str if double-encoded)
            continue
        except json.JSONDecodeError:
            pass
        # last resort: strip non-JSON wrapper like "array([[...]])"
        bracket = v.find("[[")
        if bracket == -1:
            break
        v = v[bracket:]
        rbracket = v.rfind("]]")
        if rbracket != -1:
            v = v[: rbracket + 2]

    # Checked after the loop so a value produced by the final unwrap still counts.
    if isinstance(v, np.ndarray):
        return v
    if isinstance(v, (list, tuple)):
        return np.array(v)
    raise ValueError(
        f"Cannot parse state_json (type={type(v).__name__}). "
        "Expected a JSON array string like '[[1,2],[3,4]]' or a nested list."
    )


def _resp(data) -> JSONResponse:
    """Build a JSONResponse from *data*, converting numpy values via _NpEncoder."""
    return JSONResponse(content=json.loads(json.dumps(data, cls=_NpEncoder)))


# ── Scoring helpers ───────────────────────────────────────────────────────────

# Neighbour offsets (name, d_row, d_col) in the order they appear in cell_errors.
_NEIGHBOURS = (
    ("N", -1, 0),
    ("S", 1, 0),
    ("E", 0, 1),
    ("W", 0, -1),
    ("NW", -1, -1),
    ("NE", -1, 1),
    ("SW", 1, -1),
    ("SE", 1, 1),
)


def _score_states(fn, env, states) -> dict:
    """Run *fn* on every state and compare with ``env.get_true_next``.

    Returns a dict with:
      - "cell_acc_sum": sum over states of the fraction of correct cells
      - "exact": number of states predicted exactly
      - "worst_acc": lowest per-state accuracy seen (1.0 if none below 1.0)
      - "worst": (input, predicted, expected) for that state, or None
    """
    cell_acc_sum = 0.0
    exact = 0
    worst_acc = 1.0
    worst = None
    for state in states:
        expected = env.get_true_next(state)
        try:
            predicted = fn(state.copy())
            if isinstance(predicted, np.ndarray) and predicted.shape == expected.shape:
                # Partial credit: fraction of cells predicted correctly
                state_acc = float(np.mean(predicted == expected))
                cell_acc_sum += state_acc
                if state_acc < worst_acc:
                    worst_acc = state_acc
                    worst = (state, predicted, expected)
                if np.array_equal(predicted, expected):
                    exact += 1
        except Exception:
            # Deliberate best-effort: fn is submitted code, and a state on
            # which it fails simply earns no credit.
            pass
    return {
        "cell_acc_sum": cell_acc_sum,
        "exact": exact,
        "worst_acc": worst_acc,
        "worst": worst,
    }


def _diagnose_worst(worst_inp, worst_pred, worst_exp, worst_state_acc):
    """Summarise the wrong cells of the worst-scoring state.

    Returns ``(cell_errors, error_regions, common_error_patterns)``: up to 10
    wrong cells with their wrap-around 8-neighbourhood from the input grid,
    wrong-cell counts per quadrant, and the 5 most frequent
    (predicted, expected) pairs.
    """
    rows, cols = worst_pred.shape
    wrong_mask = worst_pred != worst_exp

    cell_errors = []
    for idx in np.argwhere(wrong_mask)[:10]:
        r, c = int(idx[0]), int(idx[1])
        entry = {"pos": [r, c], "center": int(worst_inp[r, c])}
        for name, dr, dc in _NEIGHBOURS:
            entry[name] = int(worst_inp[(r + dr) % rows, (c + dc) % cols])
        entry["predicted"] = int(worst_pred[r, c])
        entry["expected"] = int(worst_exp[r, c])
        cell_errors.append(entry)

    mid_r, mid_c = rows // 2, cols // 2
    error_regions = {
        "top_left": int(wrong_mask[:mid_r, :mid_c].sum()),
        "top_right": int(wrong_mask[:mid_r, mid_c:].sum()),
        "bottom_left": int(wrong_mask[mid_r:, :mid_c].sum()),
        "bottom_right": int(wrong_mask[mid_r:, mid_c:].sum()),
        "total": int(wrong_mask.sum()),
        "worst_state_accuracy": round(worst_state_acc, 4),
    }

    wrong_pairs = list(
        zip(worst_pred[wrong_mask].tolist(), worst_exp[wrong_mask].tolist())
    )
    common_error_patterns = [
        {"predicted": p, "expected": e, "count": n}
        for (p, e), n in Counter(wrong_pairs).most_common(5)
    ]
    return cell_errors, error_regions, common_error_patterns


# ── Endpoints ─────────────────────────────────────────────────────────────────

@app.get("/health", tags=["meta"])
def health():
    """Liveness check — also returns active session count."""
    with _lock:
        active = len(_sessions)
    return {"status": "ok", "active_sessions": active}


@app.get("/problems", tags=["meta"])
def problems(_: None = Depends(_auth)):
    """List all available problems with metadata."""
    return list_problems()


@app.post("/session", tags=["session"])
def create_session(body: SessionCreate, _: None = Depends(_auth)):
    """Create a new session for one problem. Returns a session_id."""
    try:
        env = get_problem(body.problem_id)
    except KeyError as exc:
        raise HTTPException(400, str(exc)) from exc
    env.reset(seed=42)  # initialises hidden state (e.g. G06 orientation) before any scoring
    sid = str(uuid.uuid4())
    now = time.time()
    with _lock:
        _sessions[sid] = {
            "env": env,
            "problem_id": body.problem_id,
            "query_count": 0,
            "created_at": now,
            "last_access": now,
        }
    return {"session_id": sid, "problem_id": body.problem_id, "ttl_seconds": SESSION_TTL}


@app.get("/session/{session_id}/info", tags=["tools"])
def get_system_info(session_id: str, _: None = Depends(_auth)):
    """Get system info: grid shape, value range. Does NOT reveal the update rule."""
    sess = _get_session(session_id)
    meta = sess["env"].get_state_shape()
    return {
        "type": meta["type"],
        "rows": meta["rows"],
        "cols": meta["cols"],
        "values": meta["values"],
        "description": meta["description"],
    }


@app.post("/session/{session_id}/random_state", tags=["tools"])
def random_state(session_id: str, body: RandomStateRequest, _: None = Depends(_auth)):
    """Generate a random initial state. seed=0 means random."""
    sess = _get_session(session_id)
    _count_query(sess)
    meta = sess["env"].get_state_shape()
    rows, cols = meta["rows"], meta["cols"]
    # "values" is a "lo-hi" range string; both ends are inclusive.
    lo, hi = (int(x) for x in meta.get("values", "0-1").split("-"))
    # Any seed <= 0 means "unseeded": numpy draws fresh entropy.
    rng = np.random.default_rng(body.seed if body.seed > 0 else None)
    state = rng.integers(lo, hi + 1, size=(rows, cols))
    return _resp({"seed": body.seed, "shape": [rows, cols], "state": state})


@app.post("/session/{session_id}/simulate", tags=["tools"])
def simulate(session_id: str, body: SimulateRequest, _: None = Depends(_auth)):
    """Simulate the environment forward n_steps (1–100) from a given state."""
    sess = _get_session(session_id)
    _count_query(sess)
    env = sess["env"]
    try:
        initial = _parse_state(body.state_json)
    except ValueError as exc:
        raise HTTPException(400, f"Invalid state_json: {exc}") from exc
    n_steps = min(max(1, body.n_steps), 100)  # clamp into [1, 100]
    env.set_initial_conditions(initial)
    trajectory = []
    for _step in range(n_steps):
        state = env.step(1)
        trajectory.append(state.tolist())
    # n_steps >= 1, so `state` is the final state here.
    cells_changed = int(np.sum(initial != np.asarray(state)))
    return _resp({
        "n_steps": n_steps,
        "cells_changed": cells_changed,
        "final_state": trajectory[-1],
        "trajectory": trajectory,
    })


@app.post("/session/{session_id}/submit_rule", tags=["tools"])
def submit_rule(session_id: str, body: SubmitRequest, _: None = Depends(_auth)):
    """Score a proposed update rule. Code must define predict_next(grid) -> grid."""
    sess = _get_session(session_id)
    queries = _count_query(sess)
    env = sess["env"]

    # NOTE(review): compile_law is given caller-supplied source; presumably it
    # executes it — confirm it is sandboxed before exposing this endpoint.
    fn = compile_law(body.code)
    if fn is None:
        return {"functional_accuracy": 0.0, "error": "Could not compile code", "total": 0.0}

    # Fresh random seed each submission — prevents overfitting to a fixed test set
    eval_seed = random.randint(1, 999_999)
    test_states = _generate_test_states(env, n=500, seed=eval_seed)
    n_test = len(test_states)
    test = _score_states(fn, env, test_states)

    # Compute cell-level diagnostics from worst-performing state
    cell_errors = []
    error_regions = {}
    common_error_patterns = []
    if test["worst"] is not None:
        worst_inp, worst_pred, worst_exp = test["worst"]
        cell_errors, error_regions, common_error_patterns = _diagnose_worst(
            worst_inp, worst_pred, worst_exp, test["worst_acc"]
        )

    # functional_accuracy = mean cell-level accuracy across all test states
    accuracy = test["cell_acc_sum"] / n_test if n_test else 0.0
    exact_accuracy = test["exact"] / n_test if n_test else 0.0

    # Held-out validation set (fixed seed, never seen during training)
    val_states = _generate_test_states(env, n=200, seed=42_000)
    n_val = len(val_states)
    val = _score_states(fn, env, val_states)
    val_accuracy = val["cell_acc_sum"] / n_val if n_val else 0.0
    val_exact_accuracy = val["exact"] / n_val if n_val else 0.0

    agent_dl = stripped_code_length(body.code)
    try:
        ref_dl = stripped_code_length(env.__class__.reference_code())
    except NotImplementedError:
        ref_dl = 0
    delta_dl = max(0, agent_dl - ref_dl)
    # Bonuses decay linearly to zero at 300 extra characters / 60 queries.
    parsimony = 0.2 * max(0.0, 1.0 - delta_dl / 300)
    efficiency = 0.1 * max(0.0, 1.0 - queries / 60)

    return _resp({
        "functional_accuracy": round(accuracy, 6),
        "val_accuracy": round(val_accuracy, 6),
        "parsimony_bonus": round(parsimony, 4),
        "efficiency_bonus": round(efficiency, 4),
        "total": round(accuracy + parsimony + efficiency, 4),
        "correct_states": test["exact"],
        "total_states": n_test,
        "exact_accuracy": round(exact_accuracy, 6),
        "val_exact_accuracy": round(val_exact_accuracy, 6),
        "eval_seed": eval_seed,
        "queries_used": queries,
        "code_length": len(body.code.strip()),
        "agent_dl": agent_dl,
        "reference_dl": ref_dl,
        "delta_dl": delta_dl,
        "cell_errors": cell_errors,
        "error_regions": error_regions,
        "common_error_patterns": common_error_patterns,
    })


# ── Dashboard (agent monitor) ─────────────────────────────────────────────────
# Served at GET / and GET /api/* — does not affect any existing API routes.
from dashboard_routes import router as _dashboard_router  # noqa: E402

app.include_router(_dashboard_router)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)