sissississi's picture
Redesign frontend as training dashboard + add live activity feed
d662461
"""FastAPI entry point — OpenEnv create_app() + custom endpoints."""
import math
import os
import threading
import time
from collections import deque
from pathlib import Path

from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from openenv.core.env_server.http_server import create_app

from .environment import OrigamiEnvironment
from .models import OrigamiAction, OrigamiObservation
from .tasks import TASKS
# Build the FastAPI application with OpenEnv's create_app(), wiring in this
# package's environment, action and observation classes under the name
# "origami_env". Presumably this registers OpenEnv's standard environment
# HTTP routes — confirm against openenv.core.env_server.http_server. The
# custom endpoints in this module are registered on top of this `app`.
app = create_app(
    OrigamiEnvironment,
    OrigamiAction,
    OrigamiObservation,
    env_name="origami_env",
)
# Open CORS fully so the dashboard frontend (presumably served from another
# origin) can poll these endpoints from the browser.
# NOTE(review): with every origin, method and header allowed, any web page
# can also POST to /training/log — narrow allow_origins if this server is
# exposed beyond a demo setting.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ── Training Activity Feed ───────────────────────────────────────────────────
# Ring buffer of recent training steps — the frontend polls this to visualize
# what's happening during GRPO training.
_FEED_CAPACITY = 50  # how many of the most recent steps the feed retains
# Bounded deque: once full, each append silently evicts the oldest entry.
ACTIVITY_FEED: deque = deque(maxlen=_FEED_CAPACITY)
# Aggregate counters shown alongside the feed; the bests start at
# placeholder values until real steps are logged.
TRAINING_STATS: dict = dict(total_steps=0, best_reward=-999, best_similarity=0)
@app.get("/training/feed")
def get_training_feed(since: int = 0):
    """Get recent training activity for the dashboard to poll.

    Args:
        since: Only entries whose ``step`` is strictly greater than this are
            returned, so a client can pass the last step it has seen to get
            just the new entries. Defaults to 0 (everything still buffered).

    Returns:
        ``{"entries": [...], "stats": {...}}``: the matching buffered entries
        in append order (oldest first) and a copy of the aggregate stats.
    """
    # FastAPI runs plain ``def`` endpoints in a worker threadpool, so
    # /training/log may append to the deque while this handler reads it.
    # Iterating a deque that is mutated mid-iteration raises RuntimeError;
    # ``list()`` copies it in one C-level call (effectively atomic under
    # CPython's GIL), giving a stable snapshot to filter.
    snapshot = list(ACTIVITY_FEED)
    entries = [entry for entry in snapshot if entry["step"] > since]
    # Copy the stats as well so the response is not serialized from a dict
    # that another request is updating at the same time.
    return {"entries": entries, "stats": dict(TRAINING_STATS)}
# Serializes step numbering, feed appends and stats updates across the
# worker threads FastAPI uses to run plain ``def`` endpoints.
_TRAINING_LOG_LOCK = threading.Lock()


def _finite_number(value, default):
    """Return *value* if it is an int or a finite float, otherwise *default*.

    Protects the stats comparisons from ``None``/strings (which would raise
    ``TypeError``) and keeps NaN/inf out of the feed: Starlette's JSON
    response rendering rejects them, which would make every /training/feed
    response containing such an entry fail.
    """
    if isinstance(value, int):
        return value
    if isinstance(value, float) and math.isfinite(value):
        return value
    return default


@app.post("/training/log")
def log_training_step(data: dict):
    """Log a training step from the notebook. Called after each env.step().

    Args:
        data: JSON body. Recognized keys (all optional): ``task_name``,
            ``reward``, ``shape_similarity``, ``is_valid``, ``error``,
            ``fold_data``, ``final_positions``, ``target_positions``.
            Missing keys fall back to defaults. A ``reward`` or
            ``shape_similarity`` that is not an int or a finite float is
            recorded as 0.

    Returns:
        ``{"step": n}``: the 1-based step number assigned to this entry,
        which a client can later pass as ``since`` to /training/feed.
    """
    reward = _finite_number(data.get("reward"), 0)
    sim = _finite_number(data.get("shape_similarity"), 0)
    with _TRAINING_LOG_LOCK:
        # Assigning the step and appending under one lock keeps the feed
        # strictly ordered by step, so a client polling with ``since`` can
        # never skip an entry that was appended late by a racing request.
        step = TRAINING_STATS["total_steps"] + 1
        TRAINING_STATS["total_steps"] = step
        entry = {
            "step": step,
            "timestamp": time.time(),
            "task_name": data.get("task_name", ""),
            "reward": reward,
            "shape_similarity": sim,
            "is_valid": data.get("is_valid", False),
            "error": data.get("error", None),
            "fold_data": data.get("fold_data", None),
            "final_positions": data.get("final_positions", []),
            "target_positions": data.get("target_positions", []),
        }
        ACTIVITY_FEED.append(entry)
        # The initial bests in TRAINING_STATS are placeholders (-999 for the
        # reward); let the first step seed them so genuinely lower values
        # are still reported as the best seen so far.
        if step == 1 or reward > TRAINING_STATS["best_reward"]:
            TRAINING_STATS["best_reward"] = reward
        if step == 1 or sim > TRAINING_STATS["best_similarity"]:
            TRAINING_STATS["best_similarity"] = sim
    return {"step": step}
@app.get("/tasks")
def get_tasks():
    """List every registered task, keyed by its key in ``TASKS``.

    Each value is a summary of the task definition containing its name,
    description, difficulty, paper and target fold.
    """
    exposed_fields = ("name", "description", "difficulty", "paper", "target_fold")
    summaries = {}
    for task_key, task in TASKS.items():
        summaries[task_key] = {field: task[field] for field in exposed_fields}
    return summaries
@app.get("/tasks/{task_name}")
def get_task_detail(task_name: str):
    """Return the summary of a single task.

    Args:
        task_name: Key of the task in ``TASKS``.

    Raises:
        HTTPException: 404 when ``task_name`` is not a registered task.
    """
    try:
        task = TASKS[task_name]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Task '{task_name}' not found"
        ) from None
    return {
        "name": task["name"],
        "description": task["description"],
        "difficulty": task["difficulty"],
        "paper": task["paper"],
        "target_fold": task["target_fold"],
    }
def main():
    """Serve ``app`` with uvicorn, listening on all interfaces.

    The port is read from the ``PORT`` environment variable and defaults
    to 8000 when it is not set.
    """
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8000)))
# Direct-execution entry point. This module uses relative imports
# (``from .environment import ...``), so it must be run as part of its
# package, e.g. ``python -m <package>.<module>``, not as a standalone script.
if __name__ == "__main__":
    main()