Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""Discovery Environment REST API -- Hugging Face Space.

Mirrors the 4 MCP tools as HTTP endpoints with session-per-problem management.
Each session is tied to one problem (G01-G08) and tracks query count for scoring.

Auth: set API_KEY env var in HF Space secrets -> pass X-Api-Key header.
If API_KEY is not set, the API is open.
"""
import json
import os
import random
import threading
import time
import uuid
from collections import Counter
from typing import Any

import numpy as np
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# discovery_env package lives alongside this file (copied for Docker build)
from discovery_env import get_problem, list_problems
from discovery_env.scoring import (
    _generate_test_states,
    compile_law,
    stripped_code_length,
)
# -- Config --------------------------------------------------------------------
API_KEY = os.environ.get("API_KEY", "")  # empty -> auth disabled (open API)
SESSION_TTL = 1800  # seconds -- sessions expire after 30 min of inactivity

# -- App -----------------------------------------------------------------------
# NOTE: the description string had mojibake ("β" where dashes belonged);
# restored to plain ASCII so API docs render correctly.
app = FastAPI(
    title="Discovery Environment API",
    description=(
        "Black-box physics discovery environments (G01-G08). "
        "Create a session for a problem, then query, simulate, and submit rules. "
        "Scoring: accuracy (0-1.0) + parsimony (0-0.2) + efficiency (0-0.1) = max 1.3."
    ),
    version="1.0.0",
)
# -- Session store -------------------------------------------------------------
_sessions: dict[str, dict] = {}
_lock = threading.Lock()


def _cleanup() -> None:
    """Remove every session whose timestamp is older than SESSION_TTL."""
    deadline = time.time() - SESSION_TTL
    with _lock:
        stale = [sid for sid, sess in _sessions.items() if sess["created_at"] < deadline]
        for sid in stale:
            _sessions.pop(sid, None)
def _get_session(session_id: str) -> dict:
    """Look up a session by id, refreshing its TTL on each access.

    The TTL is documented as *inactivity*-based ("sessions expire after 30 min
    of inactivity"), but the timestamp was never refreshed, so sessions expired
    SESSION_TTL seconds after creation regardless of use.  Fix: touch
    ``created_at`` on every successful lookup so active sessions stay alive.

    Raises:
        HTTPException: 404 when the session is unknown or has expired.
    """
    _cleanup()
    with _lock:
        if session_id not in _sessions:
            raise HTTPException(
                404, f"Session '{session_id}' not found or expired (TTL {SESSION_TTL}s)"
            )
        sess = _sessions[session_id]
        sess["created_at"] = time.time()  # refresh TTL: expiry is inactivity-based
        return sess
# -- Auth ----------------------------------------------------------------------
def _auth(x_api_key: str = Header(default="")) -> None:
    """Reject the request unless X-Api-Key matches API_KEY; no-op when unset."""
    if not API_KEY:
        return  # open API: no key configured
    if x_api_key != API_KEY:
        raise HTTPException(403, "Invalid or missing X-Api-Key header")
# -- Request models ------------------------------------------------------------
class SessionCreate(BaseModel):
    # Problem to bind the new session to; defaults to G01.
    problem_id: str = "G01"
class RandomStateRequest(BaseModel):
    # RNG seed; 0 (the default) means "use an unseeded / random generator".
    seed: int = 0
class SimulateRequest(BaseModel):
    # Initial grid; accepts str, list, or nested forms -- see _parse_state()
    state_json: Any
    # Number of steps to simulate (the endpoint clamps this to 1-100).
    n_steps: int
class SubmitRequest(BaseModel):
    # Python source that must define predict_next(grid) -> grid.
    code: str
| # ββ JSON serialisation helper βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class _NpEncoder(json.JSONEncoder): | |
| def default(self, o): | |
| if isinstance(o, np.ndarray): | |
| return o.tolist() | |
| if isinstance(o, np.integer): | |
| return int(o) | |
| if isinstance(o, np.floating): | |
| return float(o) | |
| return super().default(o) | |
| def _parse_state(v: Any) -> np.ndarray: | |
| """Robustly parse state_json into a numpy array. | |
| Handles all formats models commonly send: | |
| - list / nested list (model sent JSON array directly, pydantic parsed it) | |
| - str containing JSON array: "[[1,2],[3,4]]" | |
| - double-encoded str: "\"[[1,2],[3,4]]\"" (model JSON-stringified twice) | |
| - str of Python repr: "array([[1,2],[3,4]])" β stripped to inner list | |
| Raises ValueError with a descriptive message on failure. | |
| """ | |
| MAX_DEPTH = 3 | |
| for _ in range(MAX_DEPTH): | |
| if isinstance(v, (list, tuple)): | |
| return np.array(v) | |
| if isinstance(v, np.ndarray): | |
| return v | |
| if isinstance(v, str): | |
| v = v.strip() | |
| # strip leading/trailing extra quotes added by double-encoding | |
| if (v.startswith('"') and v.endswith('"')) or \ | |
| (v.startswith("'") and v.endswith("'")): | |
| v = v[1:-1] | |
| continue | |
| try: | |
| v = json.loads(v) # str -> list (or another str if double-encoded) | |
| continue | |
| except json.JSONDecodeError: | |
| pass | |
| # last resort: strip non-JSON wrapper like "array([[...]])" | |
| bracket = v.find("[[") | |
| if bracket != -1: | |
| v = v[bracket:] | |
| rbracket = v.rfind("]]") | |
| if rbracket != -1: | |
| v = v[: rbracket + 2] | |
| continue | |
| break | |
| raise ValueError( | |
| f"Cannot parse state_json (type={type(v).__name__}). " | |
| "Expected a JSON array string like '[[1,2],[3,4]]' or a nested list." | |
| ) | |
def _resp(data) -> JSONResponse:
    """Serialise *data* (which may contain numpy types) into a JSONResponse."""
    encoded = json.dumps(data, cls=_NpEncoder)  # numpy -> native via _NpEncoder
    return JSONResponse(content=json.loads(encoded))
# -- Endpoints -----------------------------------------------------------------
def health():
    """Liveness check -- also returns active session count."""
    session_count = len(_sessions)
    return {"status": "ok", "active_sessions": session_count}
def problems(_: None = Depends(_auth)):
    """List all available problems with metadata (delegates to discovery_env)."""
    return list_problems()
def create_session(body: SessionCreate, _: None = Depends(_auth)):
    """Create a new session for one problem. Returns a session_id."""
    try:
        env = get_problem(body.problem_id)
    except KeyError as exc:
        raise HTTPException(400, str(exc))
    # initialises hidden state (e.g. G06 orientation) before any scoring
    env.reset(seed=42)
    sid = str(uuid.uuid4())
    record = {
        "env": env,
        "problem_id": body.problem_id,
        "query_count": 0,
        "created_at": time.time(),
    }
    with _lock:
        _sessions[sid] = record
    return {
        "session_id": sid,
        "problem_id": body.problem_id,
        "ttl_seconds": SESSION_TTL,
    }
def get_system_info(session_id: str, _: None = Depends(_auth)):
    """Get system info: grid shape, value range. Does NOT reveal the update rule."""
    meta = _get_session(session_id)["env"].get_state_shape()
    # Expose only the whitelisted metadata keys, never the rule itself.
    exposed = ("type", "rows", "cols", "values", "description")
    return {key: meta[key] for key in exposed}
def random_state(session_id: str, body: RandomStateRequest, _: None = Depends(_auth)):
    """Generate a random initial state. seed=0 means random."""
    sess = _get_session(session_id)
    sess["query_count"] += 1
    meta = sess["env"].get_state_shape()
    rows = meta["rows"]
    cols = meta["cols"]
    # Value range comes as a "lo-hi" string, e.g. "0-1".
    lo, hi = (int(part) for part in meta.get("values", "0-1").split("-"))
    seed = body.seed if body.seed > 0 else None  # 0 -> unseeded generator
    rng = np.random.default_rng(seed)
    grid = rng.integers(lo, hi + 1, size=(rows, cols))
    return _resp({"seed": body.seed, "shape": [rows, cols], "state": grid})
def simulate(session_id: str, body: SimulateRequest, _: None = Depends(_auth)):
    """Simulate the environment forward n_steps (1-100) from a given state."""
    sess = _get_session(session_id)
    sess["query_count"] += 1
    env = sess["env"]
    try:
        initial = _parse_state(body.state_json)
    except ValueError as exc:
        raise HTTPException(400, f"Invalid state_json: {exc}")
    # clamp to the documented 1-100 range
    n_steps = max(1, min(body.n_steps, 100))
    env.set_initial_conditions(initial)
    trajectory = [env.step(1).tolist() for _ in range(n_steps)]
    final = trajectory[-1]
    cells_changed = int(np.sum(initial != np.array(final)))
    return _resp({
        "n_steps": n_steps,
        "cells_changed": cells_changed,
        "final_state": final,
        "trajectory": trajectory,
    })
def _score_states(fn, env, states, *, track_worst: bool = False):
    """Score *fn* against env.get_true_next over *states*.

    Returns ``(total_cell_acc, exact_matches, worst)`` where *worst* is a dict
    with keys acc/pred/exp/inp describing the lowest-accuracy state.  The
    pred/exp/inp entries stay None unless ``track_worst`` is True and at least
    one valid prediction was made.  Exceptions raised by *fn* and shape
    mismatches simply score 0 for that state (agent code is untrusted).
    """
    total_cell_acc = 0.0
    exact_matches = 0
    worst = {"acc": 1.0, "pred": None, "exp": None, "inp": None}
    for state in states:
        expected = env.get_true_next(state)
        try:
            predicted = fn(state.copy())
        except Exception:
            continue  # agent code crashed on this state -> no credit
        if not (isinstance(predicted, np.ndarray) and predicted.shape == expected.shape):
            continue
        # Partial credit: fraction of cells predicted correctly
        state_acc = float(np.mean(predicted == expected))
        total_cell_acc += state_acc
        if track_worst and state_acc < worst["acc"]:
            worst.update(acc=state_acc, pred=predicted, exp=expected, inp=state)
        if np.array_equal(predicted, expected):
            exact_matches += 1
    return total_cell_acc, exact_matches, worst


def _worst_state_diagnostics(worst):
    """Build cell-level diagnostics from the worst-performing test state.

    Returns ``(cell_errors, error_regions, common_error_patterns)``; all three
    are empty when no valid prediction was recorded (worst["pred"] is None).
    """
    worst_pred, worst_exp, worst_inp = worst["pred"], worst["exp"], worst["inp"]
    cell_errors: list[dict] = []
    error_regions: dict = {}
    common_error_patterns: list[dict] = []
    if worst_pred is None:
        return cell_errors, error_regions, common_error_patterns
    rows, cols = worst_pred.shape
    wrong_mask = worst_pred != worst_exp
    # Up to 10 wrong cells, each with its full toroidal 3x3 neighbourhood.
    for idx in np.argwhere(wrong_mask)[:10]:
        r, c = int(idx[0]), int(idx[1])
        up, down = (r - 1) % rows, (r + 1) % rows
        left, right = (c - 1) % cols, (c + 1) % cols
        cell_errors.append({
            "pos": [r, c],
            "center": int(worst_inp[r, c]),
            "N": int(worst_inp[up, c]),
            "S": int(worst_inp[down, c]),
            "E": int(worst_inp[r, right]),
            "W": int(worst_inp[r, left]),
            "NW": int(worst_inp[up, left]),
            "NE": int(worst_inp[up, right]),
            "SW": int(worst_inp[down, left]),
            "SE": int(worst_inp[down, right]),
            "predicted": int(worst_pred[r, c]),
            "expected": int(worst_exp[r, c]),
        })
    # Quadrant error counts help localize spatially-biased rules.
    mid_r, mid_c = rows // 2, cols // 2
    error_regions = {
        "top_left": int(wrong_mask[:mid_r, :mid_c].sum()),
        "top_right": int(wrong_mask[:mid_r, mid_c:].sum()),
        "bottom_left": int(wrong_mask[mid_r:, :mid_c].sum()),
        "bottom_right": int(wrong_mask[mid_r:, mid_c:].sum()),
        "total": int(wrong_mask.sum()),
        "worst_state_accuracy": round(worst["acc"], 4),
    }
    wrong_pairs = list(zip(worst_pred[wrong_mask].tolist(), worst_exp[wrong_mask].tolist()))
    common_error_patterns = [
        {"predicted": p, "expected": e, "count": n}
        for (p, e), n in Counter(wrong_pairs).most_common(5)
    ]
    return cell_errors, error_regions, common_error_patterns


def submit_rule(session_id: str, body: SubmitRequest, _: None = Depends(_auth)):
    """Score a proposed update rule. Code must define predict_next(grid) -> grid.

    Response combines functional accuracy (mean cell accuracy over 500 fresh
    test states), a held-out validation accuracy (fixed seed), a parsimony
    bonus (code length vs. reference) and an efficiency bonus (query count),
    plus diagnostics derived from the worst-performing state.
    """
    sess = _get_session(session_id)
    sess["query_count"] += 1
    queries = sess["query_count"]
    env = sess["env"]
    fn = compile_law(body.code)
    if fn is None:
        return {"functional_accuracy": 0.0, "error": "Could not compile code", "total": 0.0}
    # Fresh random seed each submission -> prevents overfitting to a fixed test set
    eval_seed = random.randint(1, 999_999)
    test_states = _generate_test_states(env, n=500, seed=eval_seed)
    total_cell_acc, exact_matches, worst = _score_states(
        fn, env, test_states, track_worst=True
    )
    cell_errors, error_regions, common_error_patterns = _worst_state_diagnostics(worst)
    # functional_accuracy = mean cell-level accuracy across all test states
    accuracy = total_cell_acc / len(test_states)
    # Held-out validation set (fixed seed, never seen during training)
    val_states = _generate_test_states(env, n=200, seed=42_000)
    val_cell_acc, val_exact, _ignored = _score_states(fn, env, val_states)
    val_accuracy = val_cell_acc / len(val_states)
    # Parsimony: penalise code longer than the reference implementation.
    agent_dl = stripped_code_length(body.code)
    try:
        ref_dl = stripped_code_length(env.__class__.reference_code())
    except NotImplementedError:
        ref_dl = 0  # no reference available -> compare against zero baseline
    delta_dl = max(0, agent_dl - ref_dl)
    parsimony = 0.2 * max(0.0, 1.0 - delta_dl / 300)
    # Efficiency: fewer queries -> larger bonus; zero at 60+ queries.
    efficiency = 0.1 * max(0.0, 1.0 - queries / 60)
    return _resp({
        "functional_accuracy": round(accuracy, 6),
        "val_accuracy": round(val_accuracy, 6),
        "parsimony_bonus": round(parsimony, 4),
        "efficiency_bonus": round(efficiency, 4),
        "total": round(accuracy + parsimony + efficiency, 4),
        "correct_states": exact_matches,
        "total_states": len(test_states),
        "exact_accuracy": round(exact_matches / len(test_states), 6),
        "val_exact_accuracy": round(val_exact / len(val_states), 6),
        "eval_seed": eval_seed,
        "queries_used": queries,
        "code_length": len(body.code.strip()),
        "agent_dl": agent_dl,
        "reference_dl": ref_dl,
        "delta_dl": delta_dl,
        "cell_errors": cell_errors,
        "error_regions": error_regions,
        "common_error_patterns": common_error_patterns,
    })
# -- Dashboard (agent monitor) -------------------------------------------------
# Served at GET / and GET /api/* -- does not affect any existing API routes.
from dashboard_routes import router as _dashboard_router  # noqa: E402

app.include_router(_dashboard_router)
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |