Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| from fastapi import HTTPException, WebSocket | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from openenv.core.env_server.http_server import create_app | |
| from env.environment import OrigamiEnvironment | |
| from openenv_runtime.environment import OpenEnvOrigamiEnvironment | |
| from openenv_runtime.models import OrigamiAction, OrigamiObservation | |
| from server.training_broadcast import TrainingBroadcastServer | |
# ---------------------------------------------------------------------------
# Numpy-safe JSON response
# ---------------------------------------------------------------------------
def _np_default(obj):
    """Fallback hook for ``json.dumps(default=...)`` that understands NumPy.

    NumPy booleans, integers and floats are turned into the matching Python
    scalar, and arrays into (nested) lists. Any other object raises
    ``TypeError`` so ``json.dumps`` reports it as unserializable.
    """
    # Checked in order; np.bool_ is tested first, matching the original chain.
    converters = (
        (np.bool_, bool),
        (np.integer, int),
        (np.floating, float),
        (np.ndarray, lambda array: array.tolist()),
    )
    for numpy_type, convert in converters:
        if isinstance(obj, numpy_type):
            return convert(obj)
    raise TypeError(f"Not serializable: {type(obj)}")
class NumpyJSONResponse(JSONResponse):
    """``JSONResponse`` that can also serialize NumPy scalars and arrays.

    Encoding matches Starlette's ``JSONResponse.render``: compact separators
    (no spaces) and non-ASCII characters emitted as-is in UTF-8 rather than
    escaped. Objects ``json`` cannot handle natively are routed through
    ``_np_default``.
    """

    def render(self, content) -> bytes:
        return json.dumps(
            content,
            default=_np_default,
            ensure_ascii=False,
            separators=(",", ":"),
            # NOTE(review): allow_nan stays at json's default (True), unlike
            # Starlette; NaN/Infinity are emitted as bare tokens that strict
            # JSON parsers (e.g. browser JSON.parse) reject.
        ).encode("utf-8")
# ---------------------------------------------------------------------------
# Episode registry for replay
# ---------------------------------------------------------------------------
# Episode id -> episode payload, read by ``replay_episode``. Nothing in this
# section writes to it.
_episode_registry: dict[str, dict] = dict()

# ---------------------------------------------------------------------------
# OpenEnv app + training broadcast server
# ---------------------------------------------------------------------------
# The OpenEnv HTTP app; ``env`` is a factory producing a step-mode
# OpenEnvOrigamiEnvironment.
app = create_app(
    env_name="optigami",
    env=lambda: OpenEnvOrigamiEnvironment(mode="step"),
    action_cls=OrigamiAction,
    observation_cls=OrigamiObservation,
)
# Shared broadcaster used by the spectator WebSocket below.
broadcast = TrainingBroadcastServer()
def _ensure_broadcast_loop():
    """Give ``broadcast`` the running event loop if it lacks a usable one.

    Called lazily from handlers (replaces the deprecated
    ``on_event('startup')`` hook). If the current loop is set and still open,
    nothing changes. When called with no running loop,
    ``asyncio.get_running_loop`` raises ``RuntimeError`` and this is a no-op.
    """
    current = broadcast._loop
    if current is not None and not current.is_closed():
        return
    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        # Not inside an event loop; try again on a later call.
        return
    broadcast._loop = running
async def _set_broadcast_loop(request, call_next):
    """HTTP middleware: bind ``broadcast`` to the loop, then pass the request on.

    NOTE(review): no ``app.middleware("http")`` registration is visible in
    this source; confirm the middleware is actually attached to ``app``.
    """
    _ensure_broadcast_loop()
    response = await call_next(request)
    return response
# ---------------------------------------------------------------------------
# Health endpoint
# ---------------------------------------------------------------------------
async def health():
    """Liveness probe; always answers ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
# ---------------------------------------------------------------------------
# Episode replay endpoint
# ---------------------------------------------------------------------------
async def replay_episode(ep_id: str):
    """Return the stored episode *ep_id* from ``_episode_registry``.

    Raises:
        HTTPException: 404 when no episode is stored under *ep_id*.
    """
    if ep_id in _episode_registry:
        return NumpyJSONResponse(_episode_registry[ep_id])
    raise HTTPException(status_code=404, detail="Episode not found")
# ---------------------------------------------------------------------------
# Training grid viewer WebSocket
# ---------------------------------------------------------------------------
async def training_ws(websocket: WebSocket):
    """Attach *websocket* to the training broadcast as a read-only spectator.

    Used by the training grid viewer. It first makes sure ``broadcast`` knows
    the running event loop, then hands the socket to
    ``broadcast.connect_spectator``.
    """
    _ensure_broadcast_loop()
    attach_spectator = broadcast.connect_spectator
    await attach_spectator(websocket)
# ---------------------------------------------------------------------------
# Helper: extract crease folds from .fold target
# ---------------------------------------------------------------------------
def _target_to_folds(target: dict) -> list[dict]:
    """Extract the crease folds from a target FOLD dict.

    Each edge whose assignment is ``"M"`` or ``"V"`` becomes
    ``{"from": p1, "to": p2, "assignment": a}``, where ``p1``/``p2`` are the
    edge's endpoints taken from ``vertices_coords``. Edges with any other
    assignment, or with a vertex index outside ``vertices_coords``, are
    skipped. Missing FOLD arrays, or arrays set to ``None``, are treated as
    empty.

    Args:
        target: Parsed FOLD dictionary.

    Returns:
        Fold dicts in edge order.
    """

    def _array(key):
        # Explicit None check (not ``or``) so array-like values are not truth-tested.
        value = target.get(key)
        return [] if value is None else value

    verts = _array("vertices_coords")
    edges_v = _array("edges_vertices")
    edges_a = _array("edges_assignment")
    n_verts = len(verts)

    folds = []
    # zip stops at the shorter array, so a short assignment list drops trailing edges.
    for (v1, v2), assignment in zip(edges_v, edges_a):
        if assignment not in ("M", "V"):
            continue
        # Require 0 <= index < n: a negative index would otherwise wrap silently.
        if not (0 <= v1 < n_verts and 0 <= v2 < n_verts):
            continue
        folds.append({"from": verts[v1], "to": verts[v2], "assignment": assignment})
    return folds
def _id_sort_key(key):
    """Sort key for graph ids: numeric ids in numeric order, non-numeric strings after them."""
    if isinstance(key, (int, str)):
        try:
            return (0, int(key))
        except ValueError:
            # Non-numeric string id: order lexically after every numeric id.
            return (1, key)
    # Other id types (e.g. floats) keep their natural ordering among numeric ids.
    return (0, key)


def _graph_state_to_fold(paper_dict: dict) -> dict:
    """Convert an internal graph-state dict to FOLD-format arrays for the frontend.

    Input format:
        vertices: {id: (x, y), ...}
        edges:    {id: (v1_id, v2_id, assignment), ...}  (any assignment;
                  it is copied through unchanged)

    Output format (FOLD):
        vertices_coords:  [[x, y, 0.0], ...]
        edges_vertices:   [[i, j], ...]  (indices into vertices_coords)
        edges_assignment: [assignment, ...]
        faces_vertices:   triangles from ``_triangulate_vertices``

    Vertices and edges are emitted in id order (numeric ids numerically).
    Edges that reference an unknown vertex id are dropped. Returns ``{}``
    when there are no vertices. A missing or ``None`` ``edges`` entry is
    treated as empty.
    """
    raw_verts = paper_dict.get("vertices") or {}
    raw_edges = paper_dict.get("edges") or {}
    if not raw_verts:
        return {}

    sorted_ids = sorted(raw_verts, key=_id_sort_key)
    id_to_idx = {vid: idx for idx, vid in enumerate(sorted_ids)}
    # Flat paper: every vertex sits at z = 0.
    vertices_coords = [
        [float(raw_verts[vid][0]), float(raw_verts[vid][1]), 0.0] for vid in sorted_ids
    ]

    edges_vertices = []
    edges_assignment = []
    for eid in sorted(raw_edges, key=_id_sort_key):
        v1_id, v2_id, asgn = raw_edges[eid]
        if v1_id in id_to_idx and v2_id in id_to_idx:
            edges_vertices.append([id_to_idx[v1_id], id_to_idx[v2_id]])
            edges_assignment.append(asgn)

    return {
        "vertices_coords": vertices_coords,
        "edges_vertices": edges_vertices,
        "edges_assignment": edges_assignment,
        "faces_vertices": _triangulate_vertices(vertices_coords),
    }
def _triangulate_vertices(vertices_coords: list) -> list:
    """Triangulate the (x, y) projection of the vertices for 3D mesh rendering.

    Returns Delaunay simplices as ``[[i, j, k], ...]`` index lists. Fewer
    than three vertices yields ``[]``. If the scipy import or the
    triangulation fails (e.g. collinear points), it falls back to two fixed
    triangles over the first four vertices, or ``[]`` with only three.
    """
    n_points = len(vertices_coords)
    if n_points < 3:
        return []
    try:
        from scipy.spatial import Delaunay

        xy = np.array([[vertex[0], vertex[1]] for vertex in vertices_coords])
        return Delaunay(xy).simplices.tolist()
    except Exception:
        # Best-effort fallback; presumably assumes vertices 0-3 are the
        # sheet's corners in order — TODO confirm.
        if n_points >= 4:
            return [[0, 1, 2], [0, 2, 3]]
        return []
# ---------------------------------------------------------------------------
# API routes — must be registered BEFORE the StaticFiles catch-all mount
# NOTE(review): no route decorators are visible in this source; confirm these
# handlers are registered on ``app``.
# ---------------------------------------------------------------------------
def get_targets():
    """List the targets known to ``OrigamiEnvironment`` (env/targets/*.fold).

    Response maps each target name to a summary dict with keys ``name``,
    ``level`` (default 1), ``description`` (default ""), ``n_creases``
    (number of "M"/"V" edge assignments), ``difficulty`` (same value as
    ``level``) and ``material`` (always "paper").
    """
    env = OrigamiEnvironment()
    names = env.available_targets()
    summaries: dict[str, dict] = {}
    for name in names:
        spec = env._targets.get(name, {})
        level = spec.get("level", 1)
        description = spec.get("description", "")
        crease_count = sum(
            1 for assignment in spec.get("edges_assignment", []) if assignment in ("M", "V")
        )
        summaries[name] = {
            "name": name,
            "level": level,
            "description": description,
            "n_creases": crease_count,
            "difficulty": level,
            "material": "paper",
        }
    return NumpyJSONResponse(summaries)
def _paper_graph_to_fold(graph) -> dict:
    """Snapshot a paper graph's ``vertices``/``edges`` mappings as FOLD arrays."""
    # Normalize each edge to a (v1, v2, assignment) tuple for _graph_state_to_fold.
    edges = {eid: (v1, v2, a) for eid, (v1, v2, a) in graph.edges.items()}
    return _graph_state_to_fold({"vertices": dict(graph.vertices), "edges": edges})


def demo_episode(target: str = "half_horizontal"):
    """Replay a target's own creases as a step-by-step demo episode.

    Unknown *target* names fall back to the first available target (or
    "half_horizontal" when none are listed). Each M/V crease of the target is
    applied via ``env.step`` until the environment reports ``done``.

    Returns:
        NumpyJSONResponse with ``task_name``, ``task`` (name/level/description),
        ``target_crease`` (the raw target dict), ``steps`` (per-step fold,
        FOLD paper state, metrics, done flag) and ``final_metrics`` (metrics of
        the last step, or ``{}`` when no step ran).
    """
    env = OrigamiEnvironment(mode="step")
    targets = env.available_targets()
    if target not in targets:
        target = targets[0] if targets else "half_horizontal"
    spec = env._targets.get(target, {})
    folds = _target_to_folds(spec)
    env.reset(target_name=target)  # the initial observation is not needed

    steps: list[dict] = []
    for step_no, fold_dict in enumerate(folds, start=1):
        _obs, reward, done, _info = env.step(fold_dict)
        steps.append({
            "step": step_no,
            "fold": fold_dict,
            "paper_state": _paper_graph_to_fold(env.paper.graph),
            # Scalar rewards are wrapped so the frontend always gets a dict.
            "metrics": reward if isinstance(reward, dict) else {"total": reward},
            "done": done,
        })
        if done:
            break

    return NumpyJSONResponse({
        "task_name": target,
        "task": {
            "name": target,
            "level": spec.get("level", 1),
            "description": spec.get("description", ""),
        },
        "target_crease": spec,
        "steps": steps,
        "final_metrics": steps[-1]["metrics"] if steps else {},
    })
# ---------------------------------------------------------------------------
# Static file serving — must come LAST so API routes take priority
# ---------------------------------------------------------------------------
_BUILD_DIR = Path(__file__).resolve().parent.parent / "build"

if not _BUILD_DIR.exists():
    # No compiled renderer in the container: provide a placeholder page.
    # NOTE(review): no route registration for this handler is visible in
    # this source; confirm it is attached to ``app``.
    def missing_renderer_build() -> HTMLResponse:
        """Placeholder HTML explaining that ``build/`` is missing."""
        page = """
<html><body style="font-family: sans-serif; margin: 24px;">
<h3>Renderer build not found</h3>
<p>No <code>build/</code> directory is present in the container.</p>
<p>OpenEnv API docs are available at <a href="/docs">/docs</a>.</p>
</body></html>
"""
        return HTMLResponse(page, status_code=200)
else:
    # Catch-all mount at "/" serving the renderer build (html=True serves index.html).
    app.mount("/", StaticFiles(directory=str(_BUILD_DIR), html=True), name="renderer")