"""FastAPI entry point — OpenEnv create_app() + custom endpoints."""

import math
import os
import threading
import time
from collections import deque
from pathlib import Path

from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from openenv.core.env_server.http_server import create_app

from .environment import OrigamiEnvironment
from .models import OrigamiAction, OrigamiObservation
from .tasks import TASKS

app = create_app(
    OrigamiEnvironment,
    OrigamiAction,
    OrigamiObservation,
    env_name="origami_env",
)

# Allow CORS for frontend polling
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Training Activity Feed ───────────────────────────────────────────────────
# Ring buffer of recent training steps — the frontend polls this to visualize
# what's happening during GRPO training.
ACTIVITY_FEED: deque = deque(maxlen=50)  # Last 50 steps
# Running totals across all logged steps. -999 / 0 are the starting values
# for the "best so far" trackers (kept as finite numbers so they stay
# JSON-serializable).
TRAINING_STATS: dict = {"total_steps": 0, "best_reward": -999, "best_similarity": 0}

# Plain ``def`` endpoints run in FastAPI's worker threadpool, so the feed
# poller and the notebook logger can execute concurrently. This lock keeps
# step numbers unique and prevents iterating ACTIVITY_FEED while another
# thread appends to it (a deque raises RuntimeError if mutated mid-iteration).
_FEED_LOCK = threading.Lock()


def _finite_number(value, default):
    """Return *value* if it is a finite int/float, otherwise *default*.

    ``None``/strings would make the best-so-far ``>`` comparisons raise
    TypeError, and NaN/inf cannot be rendered by Starlette's JSONResponse
    (``allow_nan=False``), which would break every later feed poll.
    """
    if isinstance(value, (int, float)) and math.isfinite(value):
        return value
    return default


@app.get("/training/feed")
def get_training_feed(since: int = 0):
    """Get recent training activity. Pass `since=` to get only new entries.

    Args:
        since: Only entries whose ``step`` is strictly greater than this are
            returned, so a poller can pass the last step it has already seen.

    Returns:
        ``{"entries": [...], "stats": {...}}`` — matching entries from the
        ring buffer (at most its last 50 steps) and a snapshot of
        TRAINING_STATS taken at the same moment.
    """
    with _FEED_LOCK:
        entries = [e for e in ACTIVITY_FEED if e["step"] > since]
        # Copy so serialization after return never sees a concurrent update.
        stats = dict(TRAINING_STATS)
    return {"entries": entries, "stats": stats}


@app.post("/training/log")
def log_training_step(data: dict):
    """Log a training step from the notebook. Called after each env.step().

    ``data`` is the raw JSON body. Recognized keys (all optional):
    task_name, reward, shape_similarity, is_valid, error, fold_data,
    final_positions, target_positions. A ``reward`` or ``shape_similarity``
    that is missing, null, non-numeric, or non-finite is recorded as 0.

    Returns:
        ``{"step": n}`` with the 1-based global step number assigned.
    """
    reward = _finite_number(data.get("reward", 0), 0)
    sim = _finite_number(data.get("shape_similarity", 0), 0)

    with _FEED_LOCK:
        # Increment-and-assign under the lock so concurrent logs never share
        # a step number.
        step = TRAINING_STATS["total_steps"] + 1
        TRAINING_STATS["total_steps"] = step
        entry = {
            "step": step,
            "timestamp": time.time(),
            "task_name": data.get("task_name", ""),
            "reward": reward,
            "shape_similarity": sim,
            "is_valid": data.get("is_valid", False),
            "error": data.get("error", None),
            "fold_data": data.get("fold_data", None),
            "final_positions": data.get("final_positions", []),
            "target_positions": data.get("target_positions", []),
        }
        ACTIVITY_FEED.append(entry)

        if reward > TRAINING_STATS["best_reward"]:
            TRAINING_STATS["best_reward"] = reward
        if sim > TRAINING_STATS["best_similarity"]:
            TRAINING_STATS["best_similarity"] = sim

    return {"step": step}


# Fields of a TASKS entry that are exposed through the HTTP API, in order.
_TASK_PUBLIC_FIELDS = ("name", "description", "difficulty", "paper", "target_fold")


def _task_summary(task: dict) -> dict:
    """Project a TASKS entry onto the public fields returned by the API."""
    return {field: task[field] for field in _TASK_PUBLIC_FIELDS}


@app.get("/tasks")
def get_tasks():
    """Return every task's public summary, keyed by task name."""
    return {name: _task_summary(t) for name, t in TASKS.items()}


@app.get("/tasks/{task_name}")
def get_task_detail(task_name: str):
    """Return the public summary of one task.

    Raises:
        HTTPException: 404 if ``task_name`` is not a key of TASKS.
    """
    if task_name not in TASKS:
        raise HTTPException(status_code=404, detail=f"Task '{task_name}' not found")
    return _task_summary(TASKS[task_name])


def main():
    """Serve the app with uvicorn on 0.0.0.0, port from $PORT (default 8000)."""
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()