Spaces:
Running
Running
| """FastAPI entry point — OpenEnv create_app() + custom endpoints.""" | |
| import os | |
| import time | |
| from collections import deque | |
| from pathlib import Path | |
| from fastapi import HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse | |
| from openenv.core.env_server.http_server import create_app | |
| from .environment import OrigamiEnvironment | |
| from .models import OrigamiAction, OrigamiObservation | |
| from .tasks import TASKS | |
# Build the server application through OpenEnv's create_app() factory, wiring in
# the origami environment together with its action and observation models.
app = create_app(OrigamiEnvironment, OrigamiAction, OrigamiObservation, env_name="origami_env")

# Allow CORS for frontend polling: any origin, method and header is accepted.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
| # ── Training Activity Feed ─────────────────────────────────────────────────── | |
| # Ring buffer of recent training steps — the frontend polls this to visualize | |
| # what's happening during GRPO training. | |
# Bounded buffer: once 50 entries are stored, appending drops the oldest one.
ACTIVITY_FEED: deque = deque(maxlen=50)

# Running aggregates over every logged step (not just those still in the feed).
TRAINING_STATS: dict = {
    "total_steps": 0,
    "best_reward": -999,  # low starting value, replaced by any larger reward
    "best_similarity": 0,
}
def get_training_feed(since: int = 0):
    """Return buffered training entries newer than ``since`` plus the running stats.

    Args:
        since: Step number already seen by the caller; only entries whose
            ``step`` is strictly greater are returned. The default of 0
            returns everything still held in the feed.

    Returns:
        A dict with ``entries`` (list of matching feed entries, oldest first)
        and ``stats`` (the module-level ``TRAINING_STATS`` dict).
    """
    newer_entries = [record for record in ACTIVITY_FEED if record["step"] > since]
    response = {"entries": newer_entries, "stats": TRAINING_STATS}
    return response
def log_training_step(data: dict):
    """Log a training step from the notebook. Called after each env.step().

    Appends one entry to ``ACTIVITY_FEED`` and updates the counters and
    running bests in ``TRAINING_STATS``.

    Args:
        data: Step payload. Recognised keys are ``task_name``, ``reward``,
            ``shape_similarity``, ``is_valid``, ``error``, ``fold_data``,
            ``final_positions`` and ``target_positions``. Missing keys fall
            back to defaults; a ``reward`` or ``shape_similarity`` that is
            explicitly ``None`` is treated as 0.

    Returns:
        ``{"step": n}`` where ``n`` is the 1-based sequence number assigned
        to this entry.
    """
    # Normalise the metrics before touching shared state. dict.get() only
    # applies its default when the key is absent, so an explicit None (JSON
    # null) would otherwise reach the ``>`` comparisons below and raise
    # TypeError after the counter and feed had already been modified.
    reward = data.get("reward")
    if reward is None:
        reward = 0
    similarity = data.get("shape_similarity")
    if similarity is None:
        similarity = 0

    step = TRAINING_STATS["total_steps"] + 1
    TRAINING_STATS["total_steps"] = step

    entry = {
        "step": step,
        "timestamp": time.time(),
        "task_name": data.get("task_name", ""),
        "reward": reward,
        "shape_similarity": similarity,
        "is_valid": data.get("is_valid", False),
        "error": data.get("error", None),
        "fold_data": data.get("fold_data", None),
        "final_positions": data.get("final_positions", []),
        "target_positions": data.get("target_positions", []),
    }
    ACTIVITY_FEED.append(entry)

    # Track the best values seen so far; ties keep the existing record.
    if reward > TRAINING_STATS["best_reward"]:
        TRAINING_STATS["best_reward"] = reward
    if similarity > TRAINING_STATS["best_similarity"]:
        TRAINING_STATS["best_similarity"] = similarity

    return {"step": step}
def get_tasks():
    """Return a summary of every registered task, keyed by task name.

    Each summary carries the task's ``name``, ``description``, ``difficulty``,
    ``paper`` and ``target_fold`` fields copied from ``TASKS``.
    """
    summary_fields = ("name", "description", "difficulty", "paper", "target_fold")
    catalog = {}
    for task_key, spec in TASKS.items():
        catalog[task_key] = {field: spec[field] for field in summary_fields}
    return catalog
def get_task_detail(task_name: str):
    """Return the summary fields of a single task.

    Args:
        task_name: Key of the task in ``TASKS``.

    Returns:
        A dict with the task's ``name``, ``description``, ``difficulty``,
        ``paper`` and ``target_fold``.

    Raises:
        HTTPException: With status 404 when ``task_name`` is not in ``TASKS``.
    """
    if task_name not in TASKS:
        raise HTTPException(status_code=404, detail=f"Task '{task_name}' not found")

    spec = TASKS[task_name]
    detail_fields = ("name", "description", "difficulty", "paper", "target_fold")
    return {field: spec[field] for field in detail_fields}
def main():
    """Serve ``app`` with uvicorn on all interfaces.

    The listening port is read from the ``PORT`` environment variable and
    defaults to 8000 when it is not set.
    """
    import uvicorn

    listen_port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    main()