Spaces:
Running
Running
Commit ·
4c6824f
1
Parent(s): 0153179
Merge origin/main into pr/6 — keep pr/6 refactor (openenv, env, no legacy engine)
Browse files- openenv_server/app.py +0 -1
- server/app.py +162 -0
- server/models.py +72 -0
- server/origami_environment.py +221 -0
- server/tasks.py +123 -0
- src/App.js +1 -1
- src/components/Fold3DCanvas.js +6 -5
- training/__init__.py +0 -0
- training/demo.py +251 -0
- training/demo_llm.py +232 -0
- training/runner.py +191 -0
openenv_server/app.py
CHANGED
|
@@ -160,7 +160,6 @@ def _graph_state_to_fold(paper_dict: dict) -> dict:
|
|
| 160 |
edges_assignment.append(asgn)
|
| 161 |
|
| 162 |
faces_vertices = _triangulate_vertices(vertices_coords)
|
| 163 |
-
|
| 164 |
return {
|
| 165 |
"vertices_coords": vertices_coords,
|
| 166 |
"edges_vertices": edges_vertices,
|
|
|
|
| 160 |
edges_assignment.append(asgn)
|
| 161 |
|
| 162 |
faces_vertices = _triangulate_vertices(vertices_coords)
|
|
|
|
| 163 |
return {
|
| 164 |
"vertices_coords": vertices_coords,
|
| 165 |
"edges_vertices": edges_vertices,
|
server/app.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
server/app.py — Training WebSocket server for Colab environment.
|
| 3 |
+
|
| 4 |
+
Provides /ws/training for live streaming of RL training episodes to browsers.
|
| 5 |
+
Mount at a publicly accessible URL in Colab (e.g., via ngrok or Colab's proxy).
|
| 6 |
+
|
| 7 |
+
Usage in training:
|
| 8 |
+
from server.app import broadcast
|
| 9 |
+
broadcast.publish(episode_id, {"type": "episode_update", ...})
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import uvicorn
|
| 16 |
+
from fastapi import FastAPI, HTTPException, WebSocket
|
| 17 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 18 |
+
from fastapi.responses import HTMLResponse
|
| 19 |
+
from fastapi.staticfiles import StaticFiles
|
| 20 |
+
|
| 21 |
+
from server.training_broadcast import TrainingBroadcastServer
|
| 22 |
+
|
| 23 |
+
app = FastAPI(title="Optigami Training Server", version="1.0")
|
| 24 |
+
|
| 25 |
+
# Allow cross-origin connections (Colab public URL → browser)
|
| 26 |
+
app.add_middleware(
|
| 27 |
+
CORSMiddleware,
|
| 28 |
+
allow_origins=["*"],
|
| 29 |
+
allow_credentials=True,
|
| 30 |
+
allow_methods=["*"],
|
| 31 |
+
allow_headers=["*"],
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Global broadcast server — import and use from training code
|
| 35 |
+
broadcast = TrainingBroadcastServer()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@app.on_event("startup")
|
| 39 |
+
async def _store_loop() -> None:
|
| 40 |
+
"""Capture the asyncio event loop so training threads can schedule coroutines."""
|
| 41 |
+
import asyncio
|
| 42 |
+
broadcast._loop = asyncio.get_running_loop()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@app.websocket("/ws/training")
|
| 46 |
+
async def training_ws(websocket: WebSocket) -> None:
|
| 47 |
+
"""Spectator WebSocket endpoint. Viewers connect here to watch training."""
|
| 48 |
+
await broadcast.connect_spectator(websocket)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@app.get("/health")
|
| 52 |
+
def health() -> dict:
|
| 53 |
+
return {
|
| 54 |
+
"status": "ok",
|
| 55 |
+
"spectators": broadcast.spectator_count,
|
| 56 |
+
"active_episodes": broadcast.active_episodes,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ── Demo endpoints (same as openenv_server/app.py so the React UI works) ──
|
| 61 |
+
|
| 62 |
+
@app.get("/targets")
|
| 63 |
+
def get_targets() -> dict:
|
| 64 |
+
from server.tasks import available_task_names, get_task_by_name
|
| 65 |
+
return {
|
| 66 |
+
name: {
|
| 67 |
+
"name": name,
|
| 68 |
+
"level": t["difficulty"],
|
| 69 |
+
"description": t.get("description", ""),
|
| 70 |
+
"n_creases": t.get("max_folds", 3),
|
| 71 |
+
"difficulty": t["difficulty"],
|
| 72 |
+
"material": t.get("material", "paper"),
|
| 73 |
+
}
|
| 74 |
+
for name in available_task_names()
|
| 75 |
+
if (t := get_task_by_name(name))
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
_DEMO_SEQUENCES: dict[str, list[dict]] = {
|
| 80 |
+
"half_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}],
|
| 81 |
+
"quarter_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
|
| 82 |
+
{"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}],
|
| 83 |
+
"letter_fold": [{"type": "valley", "line": {"start": [0.0, 0.333], "end": [1.0, 0.333]}, "angle": 180.0},
|
| 84 |
+
{"type": "mountain", "line": {"start": [0.0, 0.667], "end": [1.0, 0.667]}, "angle": 180.0}],
|
| 85 |
+
"map_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
|
| 86 |
+
{"type": "mountain", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}],
|
| 87 |
+
"solar_panel": [{"type": "valley", "line": {"start": [0.0, 0.25], "end": [1.0, 0.25]}, "angle": 180.0},
|
| 88 |
+
{"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
|
| 89 |
+
{"type": "valley", "line": {"start": [0.0, 0.75], "end": [1.0, 0.75]}, "angle": 180.0}],
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@app.get("/episode/demo")
|
| 94 |
+
def demo_episode(target: str = "half_fold") -> dict:
|
| 95 |
+
from server.origami_environment import OrigamiEnvironment
|
| 96 |
+
from server.models import OrigamiAction as NewAction
|
| 97 |
+
from server.tasks import get_task_by_name
|
| 98 |
+
|
| 99 |
+
folds = _DEMO_SEQUENCES.get(target, _DEMO_SEQUENCES["half_fold"])
|
| 100 |
+
env = OrigamiEnvironment()
|
| 101 |
+
obs = env.reset(task_name=target)
|
| 102 |
+
steps: list[dict] = []
|
| 103 |
+
|
| 104 |
+
for i, fold_dict in enumerate(folds):
|
| 105 |
+
action = NewAction(
|
| 106 |
+
fold_type=fold_dict["type"],
|
| 107 |
+
fold_line=fold_dict["line"],
|
| 108 |
+
fold_angle=float(fold_dict.get("angle", 180.0)),
|
| 109 |
+
)
|
| 110 |
+
obs = env.step(action)
|
| 111 |
+
steps.append({"step": i + 1, "fold": fold_dict,
|
| 112 |
+
"paper_state": obs.paper_state, "metrics": obs.metrics,
|
| 113 |
+
"done": obs.done})
|
| 114 |
+
if obs.done:
|
| 115 |
+
break
|
| 116 |
+
|
| 117 |
+
return {"task_name": target, "task": get_task_by_name(target) or {},
|
| 118 |
+
"steps": steps, "final_metrics": obs.metrics if steps else {}}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@app.get("/episode/replay/{ep_id}")
|
| 122 |
+
def replay_episode(ep_id: str) -> dict:
|
| 123 |
+
"""Return a stored training episode in the same format as /episode/demo."""
|
| 124 |
+
from server.tasks import get_task_by_name
|
| 125 |
+
ep = broadcast._registry.get(ep_id)
|
| 126 |
+
if not ep:
|
| 127 |
+
raise HTTPException(status_code=404, detail=f"Episode '{ep_id}' not found in registry")
|
| 128 |
+
return {
|
| 129 |
+
"task_name": ep.task_name,
|
| 130 |
+
"task": get_task_by_name(ep.task_name) or {},
|
| 131 |
+
"steps": ep.steps,
|
| 132 |
+
"final_metrics": ep.final_metrics or (ep.steps[-1]["metrics"] if ep.steps else {}),
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# ── Static files — viewer first, then React app (LAST, catch-all) ──
|
| 137 |
+
|
| 138 |
+
_VIEWER_DIR = Path(__file__).resolve().parent.parent / "viewer"
|
| 139 |
+
_BUILD_DIR = Path(__file__).resolve().parent.parent / "build"
|
| 140 |
+
|
| 141 |
+
if _VIEWER_DIR.exists():
|
| 142 |
+
app.mount("/viewer", StaticFiles(directory=str(_VIEWER_DIR), html=True), name="viewer")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if _BUILD_DIR.exists():
|
| 146 |
+
app.mount("/", StaticFiles(directory=str(_BUILD_DIR), html=True), name="react")
|
| 147 |
+
else:
|
| 148 |
+
@app.get("/", include_in_schema=False)
|
| 149 |
+
def _no_build() -> HTMLResponse:
|
| 150 |
+
return HTMLResponse(
|
| 151 |
+
"<p>React build not found. Run <code>npm run build</code> in the frontend directory.</p>"
|
| 152 |
+
"<p>Training viewer: <a href='/viewer/training.html'>/viewer/training.html</a></p>"
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run(host: str = "0.0.0.0", port: int = 9001) -> None:
|
| 157 |
+
"""Start the training server. Call from Colab notebook."""
|
| 158 |
+
uvicorn.run(app, host=host, port=port)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
|
| 162 |
+
run()
|
server/models.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenEnv Pydantic models for the origami RL environment.
|
| 3 |
+
|
| 4 |
+
OrigamiAction — one fold per step
|
| 5 |
+
OrigamiObservation — everything the LLM and Three.js viewer need
|
| 6 |
+
OrigamiState — server-side episode tracking
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Any, Optional
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
|
| 14 |
+
# openenv base classes — use them if available, fall back to plain Pydantic
|
| 15 |
+
try:
|
| 16 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 17 |
+
except ImportError:
|
| 18 |
+
Action = BaseModel
|
| 19 |
+
class State(BaseModel):
|
| 20 |
+
"""Minimal stand-in for openenv State base class."""
|
| 21 |
+
episode_id: Optional[str] = None
|
| 22 |
+
step_count: int = 0
|
| 23 |
+
|
| 24 |
+
class Observation(BaseModel):
|
| 25 |
+
"""Minimal stand-in for openenv Observation base class."""
|
| 26 |
+
done: bool = False
|
| 27 |
+
reward: Optional[float] = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class OrigamiAction(Action):
|
| 31 |
+
"""One fold operation sent by the client each step."""
|
| 32 |
+
|
| 33 |
+
fold_type: str = Field(
|
| 34 |
+
default="valley",
|
| 35 |
+
description="'valley' | 'mountain' | 'pleat' | 'crimp' | 'stop'",
|
| 36 |
+
)
|
| 37 |
+
fold_line: dict[str, list[float]] = Field(
|
| 38 |
+
default_factory=lambda: {"start": [0.0, 0.5], "end": [1.0, 0.5]},
|
| 39 |
+
description="{'start': [x, y], 'end': [x, y]} normalized 0-1",
|
| 40 |
+
)
|
| 41 |
+
fold_angle: float = Field(
|
| 42 |
+
default=180.0,
|
| 43 |
+
description="Fold angle in degrees, 0-180",
|
| 44 |
+
)
|
| 45 |
+
layer_select: str = Field(
|
| 46 |
+
default="all",
|
| 47 |
+
description="'all' | 'top' | 'bottom'",
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class OrigamiObservation(Observation):
|
| 52 |
+
"""Everything the LLM and Three.js viewer need.
|
| 53 |
+
|
| 54 |
+
paper_state contains FOLD-compatible geometry + physics data.
|
| 55 |
+
metrics contains all computed quality metrics.
|
| 56 |
+
No render_urls — the browser renders from paper_state directly.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
task: dict[str, Any] = Field(default_factory=dict)
|
| 60 |
+
paper_state: dict[str, Any] = Field(default_factory=dict)
|
| 61 |
+
metrics: dict[str, Any] = Field(default_factory=dict)
|
| 62 |
+
fold_history: list[dict[str, Any]] = Field(default_factory=list)
|
| 63 |
+
error: Optional[str] = Field(default=None)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class OrigamiState(State):
|
| 67 |
+
"""Server-side episode tracking."""
|
| 68 |
+
|
| 69 |
+
task_name: str = Field(default="")
|
| 70 |
+
num_folds_applied: int = Field(default=0)
|
| 71 |
+
is_valid: bool = Field(default=True)
|
| 72 |
+
total_reward: float = Field(default=0.0)
|
server/origami_environment.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OrigamiEnvironment — OpenEnv environment wrapping the origami physics engine.
|
| 3 |
+
|
| 4 |
+
Implements reset() / step() / state following the OpenEnv interface.
|
| 5 |
+
Engine (physics, fold, validation, metrics) lives in engine/.
|
| 6 |
+
No server-side image rendering — paper_state contains all geometry data.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import uuid
|
| 13 |
+
from typing import Any, Optional
|
| 14 |
+
|
| 15 |
+
# openenv base class — fall back to plain object if not installed
|
| 16 |
+
try:
|
| 17 |
+
from openenv.core.env_server.interfaces import Environment
|
| 18 |
+
except ImportError:
|
| 19 |
+
from typing import Generic, TypeVar
|
| 20 |
+
A = TypeVar("A")
|
| 21 |
+
O = TypeVar("O")
|
| 22 |
+
S = TypeVar("S")
|
| 23 |
+
class Environment(Generic[A, O, S]):
|
| 24 |
+
"""Minimal stand-in for openenv.core.env_server.interfaces.Environment."""
|
| 25 |
+
def __init__(self, **kwargs): pass
|
| 26 |
+
|
| 27 |
+
from engine.paper import Paper
|
| 28 |
+
from engine.fold_engine import apply_fold
|
| 29 |
+
from engine.physics import simulate
|
| 30 |
+
from engine.validation import validate_state
|
| 31 |
+
from engine.metrics import compute_all_metrics
|
| 32 |
+
from server.models import OrigamiAction, OrigamiObservation, OrigamiState
|
| 33 |
+
from server.tasks import get_task_by_name, sample_task
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _get_material(name: str):
|
| 37 |
+
"""Get material by name, falling back to paper."""
|
| 38 |
+
try:
|
| 39 |
+
from engine.materials import get_material
|
| 40 |
+
return get_material(name)
|
| 41 |
+
except Exception:
|
| 42 |
+
from engine.materials import get_material
|
| 43 |
+
return get_material("paper")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class OrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
|
| 47 |
+
"""Origami folding RL environment.
|
| 48 |
+
|
| 49 |
+
Each episode: agent receives paper_state + task, applies folds one at a
|
| 50 |
+
time via step(), receives metrics + reward, ends with 'stop' action or
|
| 51 |
+
when max_folds is reached.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
SUPPORTS_CONCURRENT_SESSIONS = False
|
| 55 |
+
|
| 56 |
+
def __init__(self, **kwargs):
|
| 57 |
+
super().__init__(**kwargs)
|
| 58 |
+
self._paper: Optional[Paper] = None
|
| 59 |
+
self._task: Optional[dict] = None
|
| 60 |
+
self._fold_history: list[dict] = []
|
| 61 |
+
self._metrics: dict = {}
|
| 62 |
+
self._validation: dict = {}
|
| 63 |
+
self._error: Optional[str] = None
|
| 64 |
+
self._episode_id: Optional[str] = None
|
| 65 |
+
self._step_count: int = 0
|
| 66 |
+
self._total_reward: float = 0.0
|
| 67 |
+
|
| 68 |
+
# ── reset ─────────────────────────────────────────────────────────
|
| 69 |
+
|
| 70 |
+
def reset(
|
| 71 |
+
self,
|
| 72 |
+
seed: Optional[int] = None,
|
| 73 |
+
episode_id: Optional[str] = None,
|
| 74 |
+
**kwargs: Any,
|
| 75 |
+
) -> OrigamiObservation:
|
| 76 |
+
self._episode_id = episode_id or str(uuid.uuid4())
|
| 77 |
+
self._step_count = 0
|
| 78 |
+
self._fold_history = []
|
| 79 |
+
self._error = None
|
| 80 |
+
self._total_reward = 0.0
|
| 81 |
+
|
| 82 |
+
# Select task
|
| 83 |
+
task_name = kwargs.get("task_name")
|
| 84 |
+
if task_name:
|
| 85 |
+
self._task = get_task_by_name(task_name)
|
| 86 |
+
if not self._task:
|
| 87 |
+
self._task = sample_task(seed=seed)
|
| 88 |
+
|
| 89 |
+
# Create flat sheet
|
| 90 |
+
mat = _get_material(self._task["material"])
|
| 91 |
+
self._paper = Paper.create_flat_sheet(
|
| 92 |
+
width=self._task["width"],
|
| 93 |
+
height=self._task["height"],
|
| 94 |
+
material=mat,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Initial validation + metrics (no physics needed for flat sheet)
|
| 98 |
+
self._validation = validate_state(self._paper)
|
| 99 |
+
self._metrics = compute_all_metrics(self._paper, self._task, self._validation)
|
| 100 |
+
|
| 101 |
+
return self._make_observation(done=False, reward=None)
|
| 102 |
+
|
| 103 |
+
# ── step ──────────────────────────────────────────────────────────
|
| 104 |
+
|
| 105 |
+
def step(
|
| 106 |
+
self,
|
| 107 |
+
action: OrigamiAction,
|
| 108 |
+
timeout_s: Optional[float] = None,
|
| 109 |
+
**kwargs: Any,
|
| 110 |
+
) -> OrigamiObservation:
|
| 111 |
+
if self._paper is None or self._task is None:
|
| 112 |
+
return self._make_observation(done=True, reward=-5.0)
|
| 113 |
+
|
| 114 |
+
self._step_count += 1
|
| 115 |
+
self._error = None
|
| 116 |
+
|
| 117 |
+
# ── Stop action ───────────────────────────────────────────────
|
| 118 |
+
if action.fold_type == "stop":
|
| 119 |
+
return self._finalize_episode()
|
| 120 |
+
|
| 121 |
+
# ── Build fold dict ───────────────────────────────────────────
|
| 122 |
+
fold_dict = {
|
| 123 |
+
"type": action.fold_type,
|
| 124 |
+
"line": action.fold_line,
|
| 125 |
+
"angle": action.fold_angle,
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
# ── Apply fold ────────────────────────────────────────────────
|
| 129 |
+
new_paper, err = apply_fold(self._paper, fold_dict)
|
| 130 |
+
if err:
|
| 131 |
+
self._error = err
|
| 132 |
+
return self._make_observation(done=True, reward=-5.0)
|
| 133 |
+
|
| 134 |
+
self._paper = new_paper
|
| 135 |
+
self._fold_history.append({**fold_dict, "step": self._step_count})
|
| 136 |
+
|
| 137 |
+
# ── Physics relaxation ────────────────────────────────────────
|
| 138 |
+
try:
|
| 139 |
+
self._paper = simulate(self._paper, fold_percent=1.0)
|
| 140 |
+
except Exception as exc:
|
| 141 |
+
self._error = f"Physics failed: {exc}"
|
| 142 |
+
# Continue — don't abort episode on physics failure
|
| 143 |
+
|
| 144 |
+
# ── Validate ──────────────────────────────────────────────────
|
| 145 |
+
self._validation = validate_state(self._paper)
|
| 146 |
+
|
| 147 |
+
# ── Metrics ───────────────────────────────────────────────────
|
| 148 |
+
self._metrics = compute_all_metrics(self._paper, self._task, self._validation)
|
| 149 |
+
|
| 150 |
+
# ── Check termination ─────────────────────────────────────────
|
| 151 |
+
max_folds = self._task.get("max_folds", 50)
|
| 152 |
+
if self._step_count >= max_folds:
|
| 153 |
+
return self._finalize_episode()
|
| 154 |
+
|
| 155 |
+
if self._validation.get("self_intersections", 0) > 0:
|
| 156 |
+
self._error = "Self-intersection detected"
|
| 157 |
+
return self._finalize_episode()
|
| 158 |
+
|
| 159 |
+
return self._make_observation(done=False, reward=None)
|
| 160 |
+
|
| 161 |
+
# ── state ─────────────────────────────────────────────────────────
|
| 162 |
+
|
| 163 |
+
@property
|
| 164 |
+
def state(self) -> OrigamiState:
|
| 165 |
+
return OrigamiState(
|
| 166 |
+
episode_id=self._episode_id,
|
| 167 |
+
step_count=self._step_count,
|
| 168 |
+
task_name=self._task.get("name", "") if self._task else "",
|
| 169 |
+
num_folds_applied=len(self._fold_history),
|
| 170 |
+
is_valid=self._metrics.get("is_valid", True),
|
| 171 |
+
total_reward=self._total_reward,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# ── internals ─────────────────────────────────────────────────────
|
| 175 |
+
|
| 176 |
+
def _finalize_episode(self) -> OrigamiObservation:
|
| 177 |
+
reward = self._compute_reward()
|
| 178 |
+
self._total_reward = reward
|
| 179 |
+
return self._make_observation(done=True, reward=reward)
|
| 180 |
+
|
| 181 |
+
def _make_observation(self, done: bool, reward: Optional[float]) -> OrigamiObservation:
|
| 182 |
+
return OrigamiObservation(
|
| 183 |
+
done=done,
|
| 184 |
+
reward=reward,
|
| 185 |
+
task=self._task or {},
|
| 186 |
+
paper_state=self._paper.to_observation_dict() if self._paper else {},
|
| 187 |
+
metrics=self._metrics,
|
| 188 |
+
fold_history=self._fold_history,
|
| 189 |
+
error=self._error,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def _compute_reward(self) -> float:
|
| 193 |
+
m = self._metrics
|
| 194 |
+
reward = 0.0
|
| 195 |
+
|
| 196 |
+
# Compactness is the main signal
|
| 197 |
+
reward += m.get("compactness", 0.0) * 20.0
|
| 198 |
+
|
| 199 |
+
# Bonus for fitting in target box
|
| 200 |
+
if m.get("fits_target_box", False):
|
| 201 |
+
reward += 10.0
|
| 202 |
+
|
| 203 |
+
# Bonus for deployability (if task requires it)
|
| 204 |
+
if m.get("is_deployable", False):
|
| 205 |
+
reward += 5.0
|
| 206 |
+
|
| 207 |
+
# Penalties for violations
|
| 208 |
+
reward -= m.get("kawasaki_violations", 0) * 2.0
|
| 209 |
+
reward -= m.get("maekawa_violations", 0) * 2.0
|
| 210 |
+
reward -= m.get("self_intersections", 0) * 5.0
|
| 211 |
+
|
| 212 |
+
# Penalty for too many folds (encourage efficiency)
|
| 213 |
+
reward -= m.get("fold_count", 0) * 0.5
|
| 214 |
+
|
| 215 |
+
# Penalty for exceeding material strain limit
|
| 216 |
+
max_strain = m.get("max_strain", 0.0)
|
| 217 |
+
strain_limit = self._paper.material.max_strain if self._paper else 0.05
|
| 218 |
+
if max_strain > strain_limit:
|
| 219 |
+
reward -= 3.0 * (max_strain / strain_limit)
|
| 220 |
+
|
| 221 |
+
return float(reward)
|
server/tasks.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task pool and curriculum for the origami RL environment.
|
| 3 |
+
|
| 4 |
+
7 tasks across 4 difficulty levels.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
TASKS: dict[str, dict] = {
|
| 13 |
+
"half_fold": {
|
| 14 |
+
"name": "half_fold",
|
| 15 |
+
"description": "Fold a 1x1 paper sheet in half along the horizontal midline.",
|
| 16 |
+
"width": 1.0,
|
| 17 |
+
"height": 1.0,
|
| 18 |
+
"material": "paper",
|
| 19 |
+
"target_ratio": 0.50,
|
| 20 |
+
"max_folds": 3,
|
| 21 |
+
"target_box": [1.0, 0.5, 0.02],
|
| 22 |
+
"must_deploy": False,
|
| 23 |
+
"difficulty": 1,
|
| 24 |
+
},
|
| 25 |
+
"quarter_fold": {
|
| 26 |
+
"name": "quarter_fold",
|
| 27 |
+
"description": "Fold a 1x1 paper sheet into quarters using two perpendicular folds.",
|
| 28 |
+
"width": 1.0,
|
| 29 |
+
"height": 1.0,
|
| 30 |
+
"material": "paper",
|
| 31 |
+
"target_ratio": 0.25,
|
| 32 |
+
"max_folds": 5,
|
| 33 |
+
"target_box": [0.5, 0.5, 0.04],
|
| 34 |
+
"must_deploy": False,
|
| 35 |
+
"difficulty": 1,
|
| 36 |
+
},
|
| 37 |
+
"letter_fold": {
|
| 38 |
+
"name": "letter_fold",
|
| 39 |
+
"description": "Fold a 1x1 paper into thirds (letter fold) using two parallel folds.",
|
| 40 |
+
"width": 1.0,
|
| 41 |
+
"height": 1.0,
|
| 42 |
+
"material": "paper",
|
| 43 |
+
"target_ratio": 0.33,
|
| 44 |
+
"max_folds": 5,
|
| 45 |
+
"target_box": [1.0, 0.34, 0.03],
|
| 46 |
+
"must_deploy": False,
|
| 47 |
+
"difficulty": 2,
|
| 48 |
+
},
|
| 49 |
+
"map_fold": {
|
| 50 |
+
"name": "map_fold",
|
| 51 |
+
"description": "Fold a 1x1 paper into eighths using a grid fold pattern. Must be re-deployable.",
|
| 52 |
+
"width": 1.0,
|
| 53 |
+
"height": 1.0,
|
| 54 |
+
"material": "paper",
|
| 55 |
+
"target_ratio": 0.125,
|
| 56 |
+
"max_folds": 8,
|
| 57 |
+
"target_box": [0.5, 0.25, 0.08],
|
| 58 |
+
"must_deploy": True,
|
| 59 |
+
"difficulty": 2,
|
| 60 |
+
},
|
| 61 |
+
"solar_panel": {
|
| 62 |
+
"name": "solar_panel",
|
| 63 |
+
"description": "Pack a 1x1 Mylar solar panel into a compact configuration using a Miura-ori style fold. Must deploy.",
|
| 64 |
+
"width": 1.0,
|
| 65 |
+
"height": 1.0,
|
| 66 |
+
"material": "mylar",
|
| 67 |
+
"target_ratio": 0.05,
|
| 68 |
+
"max_folds": 20,
|
| 69 |
+
"target_box": [0.25, 0.25, 0.05],
|
| 70 |
+
"must_deploy": True,
|
| 71 |
+
"difficulty": 3,
|
| 72 |
+
},
|
| 73 |
+
"shelter_wall": {
|
| 74 |
+
"name": "shelter_wall",
|
| 75 |
+
"description": "Fold a 1x1 aluminum sheet into a compact structural panel within strain limits.",
|
| 76 |
+
"width": 1.0,
|
| 77 |
+
"height": 1.0,
|
| 78 |
+
"material": "aluminum",
|
| 79 |
+
"target_ratio": 0.10,
|
| 80 |
+
"max_folds": 15,
|
| 81 |
+
"target_box": [0.5, 0.25, 0.1],
|
| 82 |
+
"must_deploy": False,
|
| 83 |
+
"difficulty": 3,
|
| 84 |
+
},
|
| 85 |
+
"stent": {
|
| 86 |
+
"name": "stent",
|
| 87 |
+
"description": "Fold a 0.5x1.5 nitinol sheet into a compact tube configuration for a medical stent. Superelastic material.",
|
| 88 |
+
"width": 0.5,
|
| 89 |
+
"height": 1.5,
|
| 90 |
+
"material": "nitinol",
|
| 91 |
+
"target_ratio": 0.09,
|
| 92 |
+
"max_folds": 25,
|
| 93 |
+
"target_box": [0.1, 0.1, 0.15],
|
| 94 |
+
"must_deploy": True,
|
| 95 |
+
"difficulty": 4,
|
| 96 |
+
},
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_task_by_name(name: str) -> Optional[dict]:
|
| 101 |
+
"""Return task dict by name, or None if not found."""
|
| 102 |
+
return TASKS.get(name)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def sample_task(seed: Optional[int] = None, difficulty: Optional[int] = None) -> dict:
|
| 106 |
+
"""Sample a random task, optionally filtered by difficulty level."""
|
| 107 |
+
rng = random.Random(seed)
|
| 108 |
+
pool = list(TASKS.values())
|
| 109 |
+
if difficulty is not None:
|
| 110 |
+
pool = [t for t in pool if t["difficulty"] == difficulty]
|
| 111 |
+
if not pool:
|
| 112 |
+
pool = list(TASKS.values())
|
| 113 |
+
return dict(rng.choice(pool))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_tasks_by_difficulty(level: int) -> list[dict]:
|
| 117 |
+
"""Return all tasks at a given difficulty level."""
|
| 118 |
+
return [dict(t) for t in TASKS.values() if t["difficulty"] == level]
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def available_task_names() -> list[str]:
|
| 122 |
+
"""Return sorted list of all task names."""
|
| 123 |
+
return sorted(TASKS.keys())
|
src/App.js
CHANGED
|
@@ -16,7 +16,7 @@ const REPLAY_EP_ID = _urlParams.get('ep') || null;
|
|
| 16 |
|
| 17 |
function App() {
|
| 18 |
const [targets, setTargets] = useState({});
|
| 19 |
-
const [selectedTarget, setSelectedTarget] = useState('
|
| 20 |
const [episode, setEpisode] = useState(null);
|
| 21 |
const [currentStep, setCurrentStep] = useState(0);
|
| 22 |
const [playing, setPlaying] = useState(false);
|
|
|
|
| 16 |
|
| 17 |
function App() {
|
| 18 |
const [targets, setTargets] = useState({});
|
| 19 |
+
const [selectedTarget, setSelectedTarget] = useState('half_fold');
|
| 20 |
const [episode, setEpisode] = useState(null);
|
| 21 |
const [currentStep, setCurrentStep] = useState(0);
|
| 22 |
const [playing, setPlaying] = useState(false);
|
src/components/Fold3DCanvas.js
CHANGED
|
@@ -7,10 +7,8 @@ const PITCH_MAX = Math.PI / 2 - 0.1;
|
|
| 7 |
const ZOOM_MIN = 0.3;
|
| 8 |
const ZOOM_MAX = 5.0;
|
| 9 |
const LIGHT_DIR = normalize3([0.4, -0.45, 1.0]);
|
| 10 |
-
const
|
| 11 |
-
const
|
| 12 |
-
const MOUNTAIN_COLOR = 'rgba(245, 158, 11, 0.95)';
|
| 13 |
-
const VALLEY_COLOR = 'rgba(56, 189, 248, 0.95)';
|
| 14 |
|
| 15 |
function clamp(value, min, max) {
|
| 16 |
return Math.min(Math.max(value, min), max);
|
|
@@ -46,6 +44,9 @@ function shadePaper(intensity) {
|
|
| 46 |
return `rgb(${r}, ${g}, ${b})`;
|
| 47 |
}
|
| 48 |
|
|
|
|
|
|
|
|
|
|
| 49 |
function buildGridMesh(resolution = 18) {
|
| 50 |
const vertices = [];
|
| 51 |
for (let y = 0; y <= resolution; y += 1) {
|
|
@@ -170,7 +171,7 @@ function applyAllFolds(vertices, foldMasks, progresses) {
|
|
| 170 |
function projectVertex(vertex, dim, pitch, yaw, zoom) {
|
| 171 |
let x = vertex[0] - 0.5;
|
| 172 |
let y = vertex[1] - 0.5;
|
| 173 |
-
let z = vertex[2];
|
| 174 |
|
| 175 |
const cp = Math.cos(pitch);
|
| 176 |
const sp = Math.sin(pitch);
|
|
|
|
| 7 |
const ZOOM_MIN = 0.3;
|
| 8 |
const ZOOM_MAX = 5.0;
|
| 9 |
const LIGHT_DIR = normalize3([0.4, -0.45, 1.0]);
|
| 10 |
+
const MOUNTAIN_COLOR = 'rgba(245, 158, 11, 0.9)';
|
| 11 |
+
const VALLEY_COLOR = 'rgba(56, 189, 248, 0.9)';
|
|
|
|
|
|
|
| 12 |
|
| 13 |
function clamp(value, min, max) {
|
| 14 |
return Math.min(Math.max(value, min), max);
|
|
|
|
| 44 |
return `rgb(${r}, ${g}, ${b})`;
|
| 45 |
}
|
| 46 |
|
| 47 |
+
const SIDE_EPS = 1e-10;
|
| 48 |
+
const MAX_FOLD_RAD = Math.PI;
|
| 49 |
+
|
| 50 |
function buildGridMesh(resolution = 18) {
|
| 51 |
const vertices = [];
|
| 52 |
for (let y = 0; y <= resolution; y += 1) {
|
|
|
|
| 171 |
function projectVertex(vertex, dim, pitch, yaw, zoom) {
|
| 172 |
let x = vertex[0] - 0.5;
|
| 173 |
let y = vertex[1] - 0.5;
|
| 174 |
+
let z = vertex[2] || 0;
|
| 175 |
|
| 176 |
const cp = Math.cos(pitch);
|
| 177 |
const sp = Math.sin(pitch);
|
training/__init__.py
ADDED
|
File without changes
|
training/demo.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
training/demo.py — Run 8 zero-shot rollouts and stream them to the grid viewer.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
cd /path/to/optigami
|
| 6 |
+
python -m training.demo
|
| 7 |
+
|
| 8 |
+
Then open: http://localhost:9001/viewer/training.html
|
| 9 |
+
|
| 10 |
+
Each of the 8 "strategies" is a heuristic that mimics what a pretrained LLM might
|
| 11 |
+
produce for different tasks — varying from near-optimal to poor. This exercises
|
| 12 |
+
the full broadcast → grid viewer pipeline without requiring an LLM API key.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import asyncio
|
| 17 |
+
import time
|
| 18 |
+
import uuid
|
| 19 |
+
from typing import Callable
|
| 20 |
+
|
| 21 |
+
import uvicorn
|
| 22 |
+
|
| 23 |
+
from server.app import app, broadcast
|
| 24 |
+
from training.runner import run_batch
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ── 8 zero-shot heuristic strategies ──────────────────────────────────────────
|
| 28 |
+
# Each is a callable: paper_state (dict) → fold_dict
|
| 29 |
+
# These represent the range of strategies a pretrained LLM might generate.
|
| 30 |
+
|
| 31 |
+
def strategy_perfect_half(paper_state: dict) -> dict:
|
| 32 |
+
"""Valley fold exactly at horizontal midline — optimal for half_fold."""
|
| 33 |
+
return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def strategy_slight_offset(paper_state: dict) -> dict:
|
| 37 |
+
"""Valley fold slightly off-center — almost optimal."""
|
| 38 |
+
return {"type": "valley", "line": {"start": [0.0, 0.48], "end": [1.0, 0.48]}, "angle": 180.0}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def strategy_thirds(paper_state: dict) -> dict:
|
| 42 |
+
"""Letter fold at one-third — wrong for half_fold, generates interesting geometry."""
|
| 43 |
+
fold_count = paper_state.get("fold_count", 0)
|
| 44 |
+
positions = [0.333, 0.667]
|
| 45 |
+
if fold_count >= len(positions):
|
| 46 |
+
return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
|
| 47 |
+
return {
|
| 48 |
+
"type": "valley" if fold_count == 0 else "mountain",
|
| 49 |
+
"line": {"start": [0.0, positions[fold_count]], "end": [1.0, positions[fold_count]]},
|
| 50 |
+
"angle": 180.0,
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def strategy_vertical(paper_state: dict) -> dict:
|
| 55 |
+
"""Vertical fold — gets compactness but in wrong dimension for target_box."""
|
| 56 |
+
return {"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def strategy_mountain(paper_state: dict) -> dict:
|
| 60 |
+
"""Mountain fold at midline — same geometry, different assignment."""
|
| 61 |
+
return {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def strategy_accordion(paper_state: dict) -> dict:
|
| 65 |
+
"""Accordion 3-fold — overfolds, achieves high compactness but more folds."""
|
| 66 |
+
fold_count = paper_state.get("fold_count", 0)
|
| 67 |
+
positions = [0.25, 0.5, 0.75]
|
| 68 |
+
assignments = ["valley", "mountain", "valley"]
|
| 69 |
+
if fold_count >= len(positions):
|
| 70 |
+
return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
|
| 71 |
+
return {
|
| 72 |
+
"type": assignments[fold_count],
|
| 73 |
+
"line": {"start": [0.0, positions[fold_count]], "end": [1.0, positions[fold_count]]},
|
| 74 |
+
"angle": 180.0,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def strategy_diagonal(paper_state: dict) -> dict:
|
| 79 |
+
"""Diagonal fold — achieves compactness but irregular bounding box."""
|
| 80 |
+
return {"type": "valley", "line": {"start": [0.0, 0.0], "end": [1.0, 1.0]}, "angle": 180.0}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def strategy_quarter(paper_state: dict) -> dict:
|
| 84 |
+
"""Two perpendicular folds — 4x compactness for quarter_fold task."""
|
| 85 |
+
fold_count = paper_state.get("fold_count", 0)
|
| 86 |
+
if fold_count == 0:
|
| 87 |
+
return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
|
| 88 |
+
if fold_count == 1:
|
| 89 |
+
return {"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}
|
| 90 |
+
return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
STRATEGIES: list[tuple[str, Callable]] = [
|
| 94 |
+
("perfect_half", strategy_perfect_half),
|
| 95 |
+
("slight_offset", strategy_slight_offset),
|
| 96 |
+
("thirds_fold", strategy_thirds),
|
| 97 |
+
("vertical_fold", strategy_vertical),
|
| 98 |
+
("mountain_fold", strategy_mountain),
|
| 99 |
+
("accordion_3", strategy_accordion),
|
| 100 |
+
("diagonal", strategy_diagonal),
|
| 101 |
+
("quarter_fold", strategy_quarter),
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ── Demo runner ────────────────────────────────────────────────────────────────
|
| 106 |
+
|
| 107 |
+
async def run_demo(task_name: str = "half_fold", delay_s: float = 0.5) -> None:
|
| 108 |
+
"""Wait for server to be ready, then fire 8 episodes."""
|
| 109 |
+
# Give uvicorn time to bind and call startup hook (sets broadcast._loop)
|
| 110 |
+
await asyncio.sleep(1.5)
|
| 111 |
+
|
| 112 |
+
batch_id = 1
|
| 113 |
+
names, fns = zip(*STRATEGIES)
|
| 114 |
+
ep_ids = [f"ep_{name}" for name in names]
|
| 115 |
+
|
| 116 |
+
print(f"\n[demo] Starting batch {batch_id} — task: {task_name}")
|
| 117 |
+
print(f"[demo] Open http://localhost:9001/viewer/training.html\n")
|
| 118 |
+
|
| 119 |
+
# Signal grid to clear and show G=8
|
| 120 |
+
await broadcast.start_batch(batch_id, len(fns))
|
| 121 |
+
|
| 122 |
+
await asyncio.sleep(delay_s)
|
| 123 |
+
|
| 124 |
+
# Run all 8 episodes in the thread pool; broadcast_fn fires into this loop
|
| 125 |
+
results = await asyncio.gather(*[
|
| 126 |
+
asyncio.to_thread(
|
| 127 |
+
_run_one,
|
| 128 |
+
fn,
|
| 129 |
+
task_name,
|
| 130 |
+
ep_id,
|
| 131 |
+
broadcast.publish,
|
| 132 |
+
)
|
| 133 |
+
for fn, ep_id in zip(fns, ep_ids)
|
| 134 |
+
])
|
| 135 |
+
|
| 136 |
+
scores = [r["score"] for r in results]
|
| 137 |
+
best_idx = max(range(len(scores)), key=lambda i: scores[i])
|
| 138 |
+
|
| 139 |
+
await broadcast.finish_batch(batch_id, scores, best_episode_id=ep_ids[best_idx])
|
| 140 |
+
|
| 141 |
+
print("\n[demo] Results:")
|
| 142 |
+
for name, result in zip(names, results):
|
| 143 |
+
print(f" {name:20s} score={result['score']:+.2f} status={result['status']}")
|
| 144 |
+
print(f"\n[demo] Best: {names[best_idx]} (score={scores[best_idx]:+.2f})")
|
| 145 |
+
print("\n[demo] Grid viewer running. Press Ctrl+C to stop.\n")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _run_one(
|
| 149 |
+
strategy_fn: Callable,
|
| 150 |
+
task_name: str,
|
| 151 |
+
ep_id: str,
|
| 152 |
+
broadcast_fn: Callable,
|
| 153 |
+
) -> dict:
|
| 154 |
+
"""Thin wrapper: adds a small sleep between steps so the viewer can animate."""
|
| 155 |
+
from server.models import OrigamiAction
|
| 156 |
+
from server.origami_environment import OrigamiEnvironment
|
| 157 |
+
|
| 158 |
+
env = OrigamiEnvironment()
|
| 159 |
+
obs = env.reset(task_name=task_name)
|
| 160 |
+
|
| 161 |
+
broadcast_fn(ep_id, {
|
| 162 |
+
"type": "episode_update",
|
| 163 |
+
"episode_id": ep_id,
|
| 164 |
+
"task_name": task_name,
|
| 165 |
+
"step": 0,
|
| 166 |
+
"observation": _obs_dict(obs),
|
| 167 |
+
})
|
| 168 |
+
|
| 169 |
+
max_steps = env._task.get("max_folds", 10) if env._task else 10
|
| 170 |
+
status = "done"
|
| 171 |
+
|
| 172 |
+
for step_idx in range(max_steps):
|
| 173 |
+
if obs.done:
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
time.sleep(0.3) # pace so the viewer can animate each step
|
| 177 |
+
|
| 178 |
+
fold_dict = strategy_fn(obs.paper_state)
|
| 179 |
+
|
| 180 |
+
if fold_dict.get("type") == "stop":
|
| 181 |
+
break
|
| 182 |
+
|
| 183 |
+
action = OrigamiAction(
|
| 184 |
+
fold_type=fold_dict["type"],
|
| 185 |
+
fold_line=fold_dict["line"],
|
| 186 |
+
fold_angle=float(fold_dict.get("angle", 180.0)),
|
| 187 |
+
)
|
| 188 |
+
obs = env.step(action)
|
| 189 |
+
|
| 190 |
+
broadcast_fn(ep_id, {
|
| 191 |
+
"type": "episode_update",
|
| 192 |
+
"episode_id": ep_id,
|
| 193 |
+
"task_name": task_name,
|
| 194 |
+
"step": step_idx + 1,
|
| 195 |
+
"observation": _obs_dict(obs),
|
| 196 |
+
})
|
| 197 |
+
|
| 198 |
+
if obs.done:
|
| 199 |
+
break
|
| 200 |
+
else:
|
| 201 |
+
status = "timeout"
|
| 202 |
+
|
| 203 |
+
score = obs.reward if obs.reward is not None else env._total_reward or 0.0
|
| 204 |
+
|
| 205 |
+
broadcast_fn(ep_id, {
|
| 206 |
+
"type": "episode_done",
|
| 207 |
+
"episode_id": ep_id,
|
| 208 |
+
"status": status,
|
| 209 |
+
"score": float(score),
|
| 210 |
+
"final_metrics": obs.metrics,
|
| 211 |
+
})
|
| 212 |
+
|
| 213 |
+
return {
|
| 214 |
+
"episode_id": ep_id,
|
| 215 |
+
"score": float(score),
|
| 216 |
+
"final_metrics": obs.metrics,
|
| 217 |
+
"status": status,
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _obs_dict(obs) -> dict:
|
| 222 |
+
try:
|
| 223 |
+
return obs.model_dump()
|
| 224 |
+
except AttributeError:
|
| 225 |
+
return {
|
| 226 |
+
"paper_state": getattr(obs, "paper_state", {}),
|
| 227 |
+
"metrics": getattr(obs, "metrics", {}),
|
| 228 |
+
"fold_history": getattr(obs, "fold_history", []),
|
| 229 |
+
"done": getattr(obs, "done", False),
|
| 230 |
+
"reward": getattr(obs, "reward", None),
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ── Entry point ────────────────────────────────────────────────────────────────
|
| 235 |
+
|
| 236 |
+
async def _main() -> None:
|
| 237 |
+
config = uvicorn.Config(app, host="0.0.0.0", port=9001, log_level="warning")
|
| 238 |
+
server = uvicorn.Server(config)
|
| 239 |
+
|
| 240 |
+
# Run demo concurrently with the uvicorn server
|
| 241 |
+
await asyncio.gather(
|
| 242 |
+
server.serve(),
|
| 243 |
+
run_demo(task_name="half_fold"),
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
try:
|
| 249 |
+
asyncio.run(_main())
|
| 250 |
+
except KeyboardInterrupt:
|
| 251 |
+
print("\n[demo] Stopped.")
|
training/demo_llm.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
training/demo_llm.py — 8 rollouts using Claude as the zero-shot fold strategist.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
cd /path/to/optigami
|
| 6 |
+
ANTHROPIC_API_KEY=sk-... python -m training.demo_llm
|
| 7 |
+
|
| 8 |
+
Each of the 8 episodes calls Claude (claude-haiku-4-5) once per fold step.
|
| 9 |
+
Claude sees the current paper_state metrics and decides the next fold.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import asyncio
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import re
|
| 17 |
+
import time
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import anthropic
|
| 21 |
+
import uvicorn
|
| 22 |
+
|
| 23 |
+
from server.app import app, broadcast
|
| 24 |
+
from server.models import OrigamiAction
|
| 25 |
+
from server.origami_environment import OrigamiEnvironment
|
| 26 |
+
from server.tasks import get_task_by_name
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
TASK_NAME = "half_fold"
|
| 30 |
+
NUM_EPISODES = 8
|
| 31 |
+
MODEL = "claude-haiku-4-5-20251001"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ── LLM strategy factory ───────────────────────────────────────────────────────
|
| 35 |
+
|
| 36 |
+
def make_llm_strategy(client: anthropic.Anthropic, task: dict, episode_num: int):
|
| 37 |
+
"""Return a strategy_fn for one episode. Each episode gets its own call history."""
|
| 38 |
+
history: list[dict[str, Any]] = []
|
| 39 |
+
|
| 40 |
+
def strategy(paper_state: dict) -> dict:
|
| 41 |
+
fold_count = paper_state.get("fold_count", 0)
|
| 42 |
+
compactness = paper_state.get("compactness", 0)
|
| 43 |
+
bb = paper_state.get("bounding_box", [1, 1, 0])
|
| 44 |
+
target_box = task.get("target_box", [1, 0.5, 0.02])
|
| 45 |
+
max_folds = task.get("max_folds", 3)
|
| 46 |
+
|
| 47 |
+
user_msg = f"""You are folding a {task['width']}x{task['height']} sheet of {task['material']}.
|
| 48 |
+
Task: {task['description']}
|
| 49 |
+
Target box to fit inside: {target_box}
|
| 50 |
+
Max folds allowed: {max_folds}
|
| 51 |
+
|
| 52 |
+
Current state (fold {fold_count}/{max_folds}):
|
| 53 |
+
compactness: {compactness:.3f} (1.0 = fully packed, 0.0 = flat)
|
| 54 |
+
bounding_box: [{bb[0]:.3f}, {bb[1]:.3f}, {bb[2]:.4f}]
|
| 55 |
+
fits_target_box: {paper_state.get('fits_target_box', False)}
|
| 56 |
+
|
| 57 |
+
Choose the next fold. Respond with ONLY valid JSON, no other text:
|
| 58 |
+
{{
|
| 59 |
+
"type": "valley" or "mountain" or "stop",
|
| 60 |
+
"line": {{"start": [x, y], "end": [x, y]}},
|
| 61 |
+
"angle": 180
|
| 62 |
+
}}
|
| 63 |
+
|
| 64 |
+
Coordinates are normalized 0-1. Use "stop" if done."""
|
| 65 |
+
|
| 66 |
+
history.append({"role": "user", "content": user_msg})
|
| 67 |
+
|
| 68 |
+
response = client.messages.create(
|
| 69 |
+
model=MODEL,
|
| 70 |
+
max_tokens=120,
|
| 71 |
+
messages=history,
|
| 72 |
+
)
|
| 73 |
+
reply = response.content[0].text.strip()
|
| 74 |
+
history.append({"role": "assistant", "content": reply})
|
| 75 |
+
|
| 76 |
+
# Extract JSON — handle markdown code blocks
|
| 77 |
+
match = re.search(r'\{[^{}]+\}', reply, re.DOTALL)
|
| 78 |
+
if not match:
|
| 79 |
+
return {"type": "stop", "line": {"start": [0, 0.5], "end": [1, 0.5]}, "angle": 0.0}
|
| 80 |
+
|
| 81 |
+
fold_dict = json.loads(match.group())
|
| 82 |
+
# Normalize: ensure required keys
|
| 83 |
+
fold_dict.setdefault("type", "valley")
|
| 84 |
+
fold_dict.setdefault("line", {"start": [0.0, 0.5], "end": [1.0, 0.5]})
|
| 85 |
+
fold_dict.setdefault("angle", 180.0)
|
| 86 |
+
return fold_dict
|
| 87 |
+
|
| 88 |
+
return strategy
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ── Episode runner ─────────────────────────────────────────────────────────────
|
| 92 |
+
|
| 93 |
+
def run_episode_llm(
|
| 94 |
+
strategy_fn,
|
| 95 |
+
task_name: str,
|
| 96 |
+
ep_id: str,
|
| 97 |
+
broadcast_fn,
|
| 98 |
+
) -> dict:
|
| 99 |
+
env = OrigamiEnvironment()
|
| 100 |
+
obs = env.reset(task_name=task_name)
|
| 101 |
+
task = env._task or {}
|
| 102 |
+
|
| 103 |
+
broadcast_fn(ep_id, {
|
| 104 |
+
"type": "episode_update",
|
| 105 |
+
"episode_id": ep_id,
|
| 106 |
+
"task_name": task_name,
|
| 107 |
+
"step": 0,
|
| 108 |
+
"observation": _obs_dict(obs),
|
| 109 |
+
})
|
| 110 |
+
|
| 111 |
+
max_steps = task.get("max_folds", 5)
|
| 112 |
+
status = "done"
|
| 113 |
+
|
| 114 |
+
for step_idx in range(max_steps):
|
| 115 |
+
if obs.done:
|
| 116 |
+
break
|
| 117 |
+
|
| 118 |
+
# Build a flat paper_state dict for the LLM (add metrics inline)
|
| 119 |
+
ps = dict(obs.paper_state)
|
| 120 |
+
ps.update(obs.metrics) # compactness, fits_target_box, etc.
|
| 121 |
+
ps["fold_count"] = step_idx
|
| 122 |
+
|
| 123 |
+
try:
|
| 124 |
+
fold_dict = strategy_fn(ps)
|
| 125 |
+
except Exception as exc:
|
| 126 |
+
broadcast_fn(ep_id, {
|
| 127 |
+
"type": "episode_done", "episode_id": ep_id,
|
| 128 |
+
"status": "error", "score": 0.0,
|
| 129 |
+
"final_metrics": obs.metrics, "error": str(exc),
|
| 130 |
+
})
|
| 131 |
+
return {"episode_id": ep_id, "score": 0.0, "status": "error"}
|
| 132 |
+
|
| 133 |
+
if fold_dict.get("type") == "stop":
|
| 134 |
+
break
|
| 135 |
+
|
| 136 |
+
time.sleep(0.4) # pace for viewer animation
|
| 137 |
+
|
| 138 |
+
action = OrigamiAction(
|
| 139 |
+
fold_type=fold_dict["type"],
|
| 140 |
+
fold_line=fold_dict["line"],
|
| 141 |
+
fold_angle=float(fold_dict.get("angle", 180.0)),
|
| 142 |
+
)
|
| 143 |
+
obs = env.step(action)
|
| 144 |
+
|
| 145 |
+
broadcast_fn(ep_id, {
|
| 146 |
+
"type": "episode_update",
|
| 147 |
+
"episode_id": ep_id,
|
| 148 |
+
"task_name": task_name,
|
| 149 |
+
"step": step_idx + 1,
|
| 150 |
+
"observation": _obs_dict(obs),
|
| 151 |
+
})
|
| 152 |
+
|
| 153 |
+
if obs.done:
|
| 154 |
+
break
|
| 155 |
+
else:
|
| 156 |
+
status = "timeout"
|
| 157 |
+
|
| 158 |
+
score = obs.reward if obs.reward is not None else (env._total_reward or 0.0)
|
| 159 |
+
broadcast_fn(ep_id, {
|
| 160 |
+
"type": "episode_done",
|
| 161 |
+
"episode_id": ep_id,
|
| 162 |
+
"status": status,
|
| 163 |
+
"score": float(score),
|
| 164 |
+
"final_metrics": obs.metrics,
|
| 165 |
+
})
|
| 166 |
+
|
| 167 |
+
return {"episode_id": ep_id, "score": float(score), "status": status}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _obs_dict(obs) -> dict:
|
| 171 |
+
try:
|
| 172 |
+
return obs.model_dump()
|
| 173 |
+
except AttributeError:
|
| 174 |
+
return {
|
| 175 |
+
"paper_state": getattr(obs, "paper_state", {}),
|
| 176 |
+
"metrics": getattr(obs, "metrics", {}),
|
| 177 |
+
"fold_history": getattr(obs, "fold_history", []),
|
| 178 |
+
"done": getattr(obs, "done", False),
|
| 179 |
+
"reward": getattr(obs, "reward", None),
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ── Main ──────────────────────────────────────────────────────────────────────
|
| 184 |
+
|
| 185 |
+
async def run_demo() -> None:
|
| 186 |
+
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
| 187 |
+
if not api_key:
|
| 188 |
+
raise RuntimeError("Set ANTHROPIC_API_KEY environment variable")
|
| 189 |
+
|
| 190 |
+
client = anthropic.Anthropic(api_key=api_key)
|
| 191 |
+
task = get_task_by_name(TASK_NAME)
|
| 192 |
+
|
| 193 |
+
await asyncio.sleep(1.5) # wait for server startup
|
| 194 |
+
|
| 195 |
+
print(f"\n[llm-demo] Model: {MODEL}")
|
| 196 |
+
print(f"[llm-demo] Task: {TASK_NAME} — {task['description']}")
|
| 197 |
+
print(f"[llm-demo] Open: http://localhost:9001/viewer/training.html\n")
|
| 198 |
+
|
| 199 |
+
await broadcast.start_batch(1, NUM_EPISODES)
|
| 200 |
+
|
| 201 |
+
ep_ids = [f"ep_{i:02d}" for i in range(NUM_EPISODES)]
|
| 202 |
+
strategies = [make_llm_strategy(client, task, i) for i in range(NUM_EPISODES)]
|
| 203 |
+
|
| 204 |
+
# Run all episodes concurrently (each makes its own Claude API calls)
|
| 205 |
+
results = await asyncio.gather(*[
|
| 206 |
+
asyncio.to_thread(run_episode_llm, fn, TASK_NAME, ep_id, broadcast.publish)
|
| 207 |
+
for fn, ep_id in zip(strategies, ep_ids)
|
| 208 |
+
])
|
| 209 |
+
|
| 210 |
+
scores = [r["score"] for r in results]
|
| 211 |
+
best_idx = max(range(len(scores)), key=lambda i: scores[i])
|
| 212 |
+
|
| 213 |
+
await broadcast.finish_batch(1, scores, best_episode_id=ep_ids[best_idx])
|
| 214 |
+
|
| 215 |
+
print("\n[llm-demo] Results:")
|
| 216 |
+
for i, result in enumerate(results):
|
| 217 |
+
print(f" ep_{i:02d} score={result['score']:+.2f} status={result['status']}")
|
| 218 |
+
print(f"\n[llm-demo] Best: ep_{best_idx:02d} (score={scores[best_idx]:+.2f})")
|
| 219 |
+
print("\n[llm-demo] Press Ctrl+C to stop.\n")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
async def _main() -> None:
|
| 223 |
+
config = uvicorn.Config(app, host="0.0.0.0", port=9001, log_level="warning")
|
| 224 |
+
server = uvicorn.Server(config)
|
| 225 |
+
await asyncio.gather(server.serve(), run_demo())
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
if __name__ == "__main__":
|
| 229 |
+
try:
|
| 230 |
+
asyncio.run(_main())
|
| 231 |
+
except KeyboardInterrupt:
|
| 232 |
+
print("\n[llm-demo] Stopped.")
|
training/runner.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TrainingRunner — parallel episode executor for GRPO training.
|
| 3 |
+
|
| 4 |
+
Each episode runs in a ThreadPoolExecutor thread.
|
| 5 |
+
After every env.step(), observations are pushed to the broadcast server (fire-and-forget).
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import uuid
|
| 10 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 11 |
+
from typing import Any, Callable, Optional
|
| 12 |
+
|
| 13 |
+
from server.models import OrigamiAction
|
| 14 |
+
from server.origami_environment import OrigamiEnvironment
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
BroadcastFn = Callable[[str, dict], None]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run_episode(
|
| 21 |
+
strategy_fn: Callable[[dict], dict],
|
| 22 |
+
task_name: str,
|
| 23 |
+
ep_id: Optional[str] = None,
|
| 24 |
+
broadcast_fn: Optional[BroadcastFn] = None,
|
| 25 |
+
max_steps: Optional[int] = None,
|
| 26 |
+
) -> dict:
|
| 27 |
+
"""Run a single origami episode with a given strategy function.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
strategy_fn: Callable that receives paper_state dict and returns a fold dict:
|
| 31 |
+
{"type": "valley"|"mountain"|"pleat"|"crimp"|"stop",
|
| 32 |
+
"line": {"start": [x, y], "end": [x, y]},
|
| 33 |
+
"angle": 180.0}
|
| 34 |
+
task_name: Name of the task (from server/tasks.py)
|
| 35 |
+
ep_id: Episode identifier for broadcast; auto-generated if None
|
| 36 |
+
broadcast_fn: Optional callback(ep_id, data) for live streaming
|
| 37 |
+
max_steps: Override task's max_folds if provided
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
dict with keys: episode_id, score, final_metrics, fold_history, status
|
| 41 |
+
"""
|
| 42 |
+
ep_id = ep_id or str(uuid.uuid4())[:8]
|
| 43 |
+
env = OrigamiEnvironment()
|
| 44 |
+
|
| 45 |
+
obs = env.reset(task_name=task_name)
|
| 46 |
+
|
| 47 |
+
if broadcast_fn:
|
| 48 |
+
broadcast_fn(ep_id, {
|
| 49 |
+
"type": "episode_update",
|
| 50 |
+
"episode_id": ep_id,
|
| 51 |
+
"task_name": task_name,
|
| 52 |
+
"step": 0,
|
| 53 |
+
"observation": _obs_to_dict(obs),
|
| 54 |
+
})
|
| 55 |
+
|
| 56 |
+
step_limit = max_steps or env._task.get("max_folds", 20) if env._task else 20
|
| 57 |
+
status = "done"
|
| 58 |
+
|
| 59 |
+
for step_idx in range(step_limit):
|
| 60 |
+
if obs.done:
|
| 61 |
+
break
|
| 62 |
+
|
| 63 |
+
# Strategy generates a fold dict
|
| 64 |
+
try:
|
| 65 |
+
fold_dict = strategy_fn(obs.paper_state)
|
| 66 |
+
except Exception as exc:
|
| 67 |
+
status = "error"
|
| 68 |
+
if broadcast_fn:
|
| 69 |
+
broadcast_fn(ep_id, {
|
| 70 |
+
"type": "episode_done",
|
| 71 |
+
"episode_id": ep_id,
|
| 72 |
+
"status": "error",
|
| 73 |
+
"score": obs.reward or 0.0,
|
| 74 |
+
"final_metrics": obs.metrics,
|
| 75 |
+
"error": str(exc),
|
| 76 |
+
})
|
| 77 |
+
break
|
| 78 |
+
|
| 79 |
+
fold_type = fold_dict.get("type", "valley")
|
| 80 |
+
fold_line = fold_dict.get("line", {"start": [0, 0.5], "end": [1, 0.5]})
|
| 81 |
+
fold_angle = float(fold_dict.get("angle", 180.0))
|
| 82 |
+
|
| 83 |
+
action = OrigamiAction(
|
| 84 |
+
fold_type=fold_type,
|
| 85 |
+
fold_line=fold_line,
|
| 86 |
+
fold_angle=fold_angle,
|
| 87 |
+
)
|
| 88 |
+
obs = env.step(action)
|
| 89 |
+
|
| 90 |
+
if broadcast_fn:
|
| 91 |
+
broadcast_fn(ep_id, {
|
| 92 |
+
"type": "episode_update",
|
| 93 |
+
"episode_id": ep_id,
|
| 94 |
+
"task_name": task_name,
|
| 95 |
+
"step": step_idx + 1,
|
| 96 |
+
"observation": _obs_to_dict(obs),
|
| 97 |
+
})
|
| 98 |
+
|
| 99 |
+
if obs.done:
|
| 100 |
+
break
|
| 101 |
+
else:
|
| 102 |
+
status = "timeout"
|
| 103 |
+
|
| 104 |
+
score = obs.reward if obs.reward is not None else (env._total_reward or 0.0)
|
| 105 |
+
|
| 106 |
+
if broadcast_fn:
|
| 107 |
+
broadcast_fn(ep_id, {
|
| 108 |
+
"type": "episode_done",
|
| 109 |
+
"episode_id": ep_id,
|
| 110 |
+
"status": status,
|
| 111 |
+
"score": float(score),
|
| 112 |
+
"final_metrics": obs.metrics,
|
| 113 |
+
})
|
| 114 |
+
|
| 115 |
+
return {
|
| 116 |
+
"episode_id": ep_id,
|
| 117 |
+
"score": float(score),
|
| 118 |
+
"final_metrics": obs.metrics,
|
| 119 |
+
"fold_history": obs.fold_history,
|
| 120 |
+
"status": status,
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def run_batch(
|
| 125 |
+
strategy_fns: list[Callable[[dict], dict]],
|
| 126 |
+
task_name: str,
|
| 127 |
+
broadcast_fn: Optional[BroadcastFn] = None,
|
| 128 |
+
batch_id: Optional[int] = None,
|
| 129 |
+
max_workers: int = 8,
|
| 130 |
+
) -> list[dict]:
|
| 131 |
+
"""Run G episodes in parallel with a ThreadPoolExecutor.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
strategy_fns: List of G strategy callables (one per completion)
|
| 135 |
+
task_name: Task to use for all episodes
|
| 136 |
+
broadcast_fn: Optional broadcast callback, called after each step
|
| 137 |
+
batch_id: Batch identifier for broadcast
|
| 138 |
+
max_workers: Max parallel threads (bounded by G)
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
List of episode result dicts, in same order as strategy_fns
|
| 142 |
+
"""
|
| 143 |
+
n = len(strategy_fns)
|
| 144 |
+
ep_ids = [f"ep_{(batch_id or 0):04d}_{i:02d}" for i in range(n)]
|
| 145 |
+
workers = min(max_workers, n)
|
| 146 |
+
|
| 147 |
+
results: list[dict] = [{}] * n
|
| 148 |
+
|
| 149 |
+
with ThreadPoolExecutor(max_workers=workers) as pool:
|
| 150 |
+
futures = {
|
| 151 |
+
pool.submit(
|
| 152 |
+
run_episode,
|
| 153 |
+
fn,
|
| 154 |
+
task_name,
|
| 155 |
+
ep_ids[i],
|
| 156 |
+
broadcast_fn,
|
| 157 |
+
): i
|
| 158 |
+
for i, fn in enumerate(strategy_fns)
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
for future in as_completed(futures):
|
| 162 |
+
idx = futures[future]
|
| 163 |
+
try:
|
| 164 |
+
results[idx] = future.result()
|
| 165 |
+
except Exception as exc:
|
| 166 |
+
results[idx] = {
|
| 167 |
+
"episode_id": ep_ids[idx],
|
| 168 |
+
"score": 0.0,
|
| 169 |
+
"final_metrics": {},
|
| 170 |
+
"fold_history": [],
|
| 171 |
+
"status": "error",
|
| 172 |
+
"error": str(exc),
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
return results
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _obs_to_dict(obs) -> dict:
|
| 179 |
+
"""Convert OrigamiObservation to a JSON-serializable dict."""
|
| 180 |
+
try:
|
| 181 |
+
return obs.model_dump()
|
| 182 |
+
except AttributeError:
|
| 183 |
+
return {
|
| 184 |
+
"task": obs.task if hasattr(obs, "task") else {},
|
| 185 |
+
"paper_state": obs.paper_state if hasattr(obs, "paper_state") else {},
|
| 186 |
+
"metrics": obs.metrics if hasattr(obs, "metrics") else {},
|
| 187 |
+
"fold_history": obs.fold_history if hasattr(obs, "fold_history") else [],
|
| 188 |
+
"done": obs.done if hasattr(obs, "done") else False,
|
| 189 |
+
"reward": obs.reward if hasattr(obs, "reward") else None,
|
| 190 |
+
"error": obs.error if hasattr(obs, "error") else None,
|
| 191 |
+
}
|