File size: 3,319 Bytes
1f89afe
bc52096
 
d662461
 
1f89afe
bc52096
1f89afe
d662461
1f89afe
bc52096
1f89afe
bc52096
1f89afe
 
 
bc52096
1f89afe
 
 
 
 
 
9ae534f
d662461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ae534f
 
1f89afe
9ae534f
 
1f89afe
 
9ae534f
bc52096
9ae534f
bc52096
 
 
 
 
 
 
9ae534f
1f89afe
 
9ae534f
 
bc52096
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""FastAPI entry point — OpenEnv create_app() + custom endpoints."""

import math
import os
import time
from collections import deque
from pathlib import Path

from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse

from openenv.core.env_server.http_server import create_app

from .environment import OrigamiEnvironment
from .models import OrigamiAction, OrigamiObservation
from .tasks import TASKS

# Build the OpenEnv server app for the origami environment and its
# action/observation models.
app = create_app(
    OrigamiEnvironment, OrigamiAction, OrigamiObservation, env_name="origami_env"
)

# Accept cross-origin requests so the frontend can poll these endpoints;
# every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# ── Training Activity Feed ───────────────────────────────────────────────────
# Ring buffer of recent training steps — the frontend polls this to visualize
# what's happening during GRPO training.

# Bounded FIFO of logged steps: once 50 entries are held, each append
# silently drops the oldest one.
ACTIVITY_FEED: deque = deque(maxlen=50)
# Running counters returned with the feed. best_reward starts at the -999
# placeholder and best_similarity at 0; log_training_step updates them.
TRAINING_STATS: dict = dict(total_steps=0, best_reward=-999, best_similarity=0)


@app.get("/training/feed")
def get_training_feed(since: int = 0):
    """Get recent training activity. Pass `since=<step>` to get only new entries.

    Args:
        since: Only entries whose ``step`` is strictly greater than this value
            are returned; pass the last step already seen to fetch just the
            new ones.

    Returns:
        ``{"entries": [...], "stats": {...}}``. ``entries`` are the buffered
        steps (bounded by ``ACTIVITY_FEED.maxlen``), oldest first. ``stats``
        is a copy of ``TRAINING_STATS``.
    """
    # FastAPI runs sync endpoints in a worker threadpool, so /training/log can
    # append to the deque while this request is reading it. A Python-level loop
    # over a deque that is mutated mid-iteration raises RuntimeError. list()
    # copies it in a single C-level call, with no bytecode boundary for another
    # thread to interleave, which gives a stable snapshot to filter.
    snapshot = list(ACTIVITY_FEED)
    entries = [entry for entry in snapshot if entry["step"] > since]
    # Copy the stats as well, so the response reflects one moment rather than a
    # dict another request may update while this one is being serialized.
    return {"entries": entries, "stats": dict(TRAINING_STATS)}


def _finite_metric(data: dict, key: str) -> float:
    """Return the numeric metric ``data[key]``, validated.

    A missing key or an explicit JSON ``null`` yields 0, the endpoint's
    default. Any other value must be an int or a finite float.

    Raises:
        HTTPException: 422 when the value is not numeric, or is NaN/Infinity.
            Non-numeric values cannot be compared against the running best.
            NaN/Infinity are not valid JSON numbers (RFC 8259), so storing them
            would corrupt the /training/feed response.
    """
    value = data.get(key)
    if value is None:
        return 0
    if not isinstance(value, (int, float)) or (
        isinstance(value, float) and not math.isfinite(value)
    ):
        raise HTTPException(
            status_code=422,
            detail=f"'{key}' must be a finite number, got {value!r}",
        )
    return value


@app.post("/training/log")
def log_training_step(data: dict):
    """Log a training step from the notebook. Called after each env.step().

    Args:
        data: JSON object describing the step. Recognised keys are task_name,
            reward, shape_similarity, is_valid, error, fold_data,
            final_positions and target_positions. Missing keys fall back to
            defaults.

    Returns:
        ``{"step": n}``: the 1-based sequence number assigned to this entry.

    Raises:
        HTTPException: 422 if ``reward`` or ``shape_similarity`` is present but
            is not a finite number. Validation runs before any state changes,
            so a rejected payload neither consumes a step number nor leaves a
            half-recorded entry in the feed.
    """
    reward = _finite_metric(data, "reward")
    sim = _finite_metric(data, "shape_similarity")

    # NOTE(review): sync endpoints run concurrently in FastAPI's threadpool, so
    # this read-then-write of total_steps is not atomic. Simultaneous posts
    # could receive the same step; add a lock if multiple writers are expected.
    step = TRAINING_STATS["total_steps"] + 1
    TRAINING_STATS["total_steps"] = step

    entry = {
        "step": step,
        "timestamp": time.time(),
        "task_name": data.get("task_name", ""),
        "reward": reward,
        "shape_similarity": sim,
        "is_valid": data.get("is_valid", False),
        "error": data.get("error", None),
        "fold_data": data.get("fold_data", None),
        "final_positions": data.get("final_positions", []),
        "target_positions": data.get("target_positions", []),
    }

    ACTIVITY_FEED.append(entry)

    # The first logged step always seeds the bests. The initial values in
    # TRAINING_STATS (-999 / 0) are placeholders and must not mask a run whose
    # rewards or similarities never exceed them.
    is_first = step == 1
    if is_first or reward > TRAINING_STATS["best_reward"]:
        TRAINING_STATS["best_reward"] = reward
    if is_first or sim > TRAINING_STATS["best_similarity"]:
        TRAINING_STATS["best_similarity"] = sim

    return {"step": step}


def _task_summary(task) -> dict:
    """Project one ``TASKS`` value onto the public fields served by /tasks.

    Both task endpoints use this helper, so the list view and the detail view
    always expose the same fields. ``task`` is indexed by field name (a dict,
    presumably — confirm against ``.tasks``).
    """
    return {
        "name": task["name"],
        "description": task["description"],
        "difficulty": task["difficulty"],
        "paper": task["paper"],
        "target_fold": task["target_fold"],
    }


@app.get("/tasks")
def get_tasks():
    """Return every task's public summary, keyed by its ``TASKS`` key."""
    return {name: _task_summary(t) for name, t in TASKS.items()}


@app.get("/tasks/{task_name}")
def get_task_detail(task_name: str):
    """Return the public summary of a single task.

    Raises:
        HTTPException: 404 if ``task_name`` is not a key of ``TASKS``.
    """
    if task_name not in TASKS:
        raise HTTPException(status_code=404, detail=f"Task '{task_name}' not found")
    return _task_summary(TASKS[task_name])


def main():
    """Serve ``app`` with uvicorn.

    Listens on every interface (0.0.0.0). The port is read from the PORT
    environment variable and defaults to 8000.
    """
    # uvicorn is only needed when serving from this entry point.
    import uvicorn

    listen_port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


# Start the server when this module is executed directly. Because of the
# relative imports above, it must be run as part of its package
# (e.g. `python -m <package>.<module>`), not as a standalone script path.
if __name__ == "__main__":
    main()