ianalin123 commited on
Commit
8652f7e
·
1 Parent(s): 438e23a

refactor(openenv): simplify runtime environment and models, extend server API

Browse files
openenv_runtime/environment.py CHANGED
@@ -1,183 +1,53 @@
1
- from __future__ import annotations
 
2
 
3
- from typing import Any, Optional
 
 
 
4
 
5
- from openenv.core.env_server.interfaces import Environment
6
 
7
- from env.environment import OrigamiEnvironment
8
 
9
- from .models import OrigamiAction, OrigamiObservation, OrigamiState
 
 
10
 
 
 
11
 
12
- class OpenEnvOrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
13
- """OpenEnv adapter over the existing OrigamiEnvironment implementation."""
14
 
15
- SUPPORTS_CONCURRENT_SESSIONS = True
16
-
17
- def __init__(
18
- self,
19
- default_mode: str = "step",
20
- max_steps: int = 8,
21
- targets_dir: Optional[str] = None,
22
- ):
23
- super().__init__()
24
- self.default_mode = default_mode
25
- self.max_steps = max_steps
26
- self.targets_dir = targets_dir
27
- self._env: Optional[OrigamiEnvironment] = None
28
- self._episode_id: Optional[str] = None
29
-
30
- def _new_env(self, mode: Optional[str] = None) -> OrigamiEnvironment:
31
- return OrigamiEnvironment(
32
- mode=mode or self.default_mode,
33
- max_steps=self.max_steps,
34
- targets_dir=self.targets_dir,
35
- )
36
-
37
- def reset(
38
- self,
39
- seed: Optional[int] = None,
40
- episode_id: Optional[str] = None,
41
- **kwargs: Any,
42
- ) -> OrigamiObservation:
43
- del seed # deterministic seed plumbing can be added later
44
-
45
- mode = kwargs.get("mode", self.default_mode)
46
- target_name = kwargs.get("target_name")
47
-
48
- self._env = self._new_env(mode=mode)
49
- self._episode_id = episode_id
50
  obs_dict = self._env.reset(target_name=target_name)
 
51
 
 
 
 
 
 
 
 
 
 
 
 
52
  return OrigamiObservation(
53
- done=False,
54
- reward=None,
55
- metadata={"available_targets": self._env.available_targets()},
56
  prompt=obs_dict.get("prompt", ""),
57
- target_name=obs_dict.get("target_name"),
58
  step=obs_dict.get("step", 0),
59
- paper_state=self._paper_state_snapshot(),
60
- info=self._env._info(),
61
- reward_components={},
62
- )
63
-
64
- def step(
65
- self,
66
- action: OrigamiAction,
67
- timeout_s: Optional[float] = None,
68
- **kwargs: Any,
69
- ) -> OrigamiObservation:
70
- del timeout_s, kwargs
71
-
72
- if self._env is None:
73
- self.reset(target_name=action.target_name)
74
-
75
- assert self._env is not None
76
-
77
- if action.target_name and action.target_name != self._env.target_name:
78
- self.reset(target_name=action.target_name, mode=self._env.mode)
79
-
80
- try:
81
- if action.mode == "sequence":
82
- if not action.completion:
83
- return self._error_observation("sequence mode requires completion")
84
-
85
- seq_env = self._new_env(mode="code_as_policy")
86
- seq_env.reset(target_name=self._env.target_name)
87
- obs_dict, reward_dict, done, info = seq_env.step(action.completion)
88
- self._env = seq_env
89
- else:
90
- if action.fold is not None:
91
- fold_payload = {
92
- "from": list(action.fold.from_point),
93
- "to": list(action.fold.to_point),
94
- "assignment": action.fold.assignment,
95
- "instruction": action.fold.instruction,
96
- }
97
- env_action: Any = fold_payload
98
- elif action.completion:
99
- env_action = action.completion
100
- else:
101
- return self._error_observation("single mode requires fold or completion")
102
-
103
- obs_dict, reward_dict, done, info = self._env.step(env_action)
104
-
105
- total = reward_dict.get("total") if isinstance(reward_dict, dict) else None
106
- return OrigamiObservation(
107
- done=bool(done),
108
- reward=float(total) if isinstance(total, (int, float)) else None,
109
- metadata={"target_name": self._env.target_name},
110
- prompt=obs_dict.get("prompt", ""),
111
- target_name=obs_dict.get("target_name", self._env.target_name),
112
- step=obs_dict.get("step", self._env.step_count),
113
- paper_state=self._paper_state_snapshot(),
114
- info=info or {},
115
- reward_components=reward_dict or {},
116
- )
117
- except Exception as exc: # pragma: no cover - defensive path
118
- return self._error_observation(str(exc))
119
-
120
- @property
121
- def state(self) -> OrigamiState:
122
- if self._env is None:
123
- tmp_env = self._new_env(mode=self.default_mode)
124
- return OrigamiState(
125
- episode_id=self._episode_id,
126
- step_count=0,
127
- mode=tmp_env.mode,
128
- target_name=None,
129
- paper={},
130
- last_reward={},
131
- available_targets=tmp_env.available_targets(),
132
- )
133
-
134
- env_state = self._env.state()
135
- return OrigamiState(
136
- episode_id=self._episode_id,
137
- step_count=env_state.get("step", self._env.step_count),
138
- mode=env_state.get("mode", self._env.mode),
139
- target_name=env_state.get("target", self._env.target_name),
140
- paper=env_state.get("paper", {}),
141
- last_reward=self._env.last_reward or {},
142
- available_targets=self._env.available_targets(),
143
  )
144
 
145
- def close(self) -> None:
146
- if self._env is not None:
147
- self._env.close()
148
- self._env = None
149
 
150
- def _paper_state_snapshot(self) -> dict[str, Any]:
151
- if self._env is None or self._env.paper is None:
152
- return {"vertices": {}, "edges": [], "anchor_points": []}
153
 
154
- graph = self._env.paper.graph
155
- return {
156
- "vertices": {str(k): [float(v[0]), float(v[1])] for k, v in graph.vertices.items()},
157
- "edges": [
158
- {
159
- "id": int(eid),
160
- "v1": [float(graph.vertices[v1][0]), float(graph.vertices[v1][1])],
161
- "v2": [float(graph.vertices[v2][0]), float(graph.vertices[v2][1])],
162
- "assignment": assignment,
163
- }
164
- for eid, (v1, v2, assignment) in graph.edges.items()
165
- ],
166
- "anchor_points": [
167
- [float(x), float(y)] for (x, y) in self._env.paper.anchor_points()
168
- ],
169
- }
170
 
171
- def _error_observation(self, message: str) -> OrigamiObservation:
172
- return OrigamiObservation(
173
- done=False,
174
- reward=-0.1,
175
- metadata={"error": True},
176
- prompt="",
177
- target_name=self._env.target_name if self._env else None,
178
- step=self._env.step_count if self._env else 0,
179
- paper_state=self._paper_state_snapshot(),
180
- info=self._env._info() if self._env else {},
181
- reward_components={"format": 0.0, "total": -0.1, "error": message},
182
- error=message,
183
- )
 
1
+ """
2
+ OpenEnv adapter for Optigami.
3
 
4
+ Thin wrapper around env.environment.OrigamiEnvironment that adapts it to the
5
+ OpenEnv protocol (Action/Observation types).
6
+ """
7
+ from env.environment import OrigamiEnvironment as _Env
8
 
9
+ from .models import OrigamiAction, OrigamiObservation
10
 
 
11
 
12
+ class OpenEnvOrigamiEnvironment:
13
+ """
14
+ OpenEnv-compatible wrapper for env.environment.OrigamiEnvironment.
15
 
16
+ Converts between env's dict-based API and OpenEnv's Action/Observation types.
17
+ """
18
 
19
+ def __init__(self, mode: str = "step", max_steps: int = 8, targets_dir=None):
20
+ self._env = _Env(mode=mode, max_steps=max_steps, targets_dir=targets_dir)
21
 
22
+ def reset(self, target_name=None, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  obs_dict = self._env.reset(target_name=target_name)
24
+ return self._obs_dict_to_model(obs_dict, reward=None, done=False)
25
 
26
+ def step(self, action: OrigamiAction, **kwargs):
27
+ action_dict = {
28
+ "from": action.from_point,
29
+ "to": action.to_point,
30
+ "assignment": action.assignment,
31
+ }
32
+ obs_dict, reward, done, info = self._env.step(action_dict)
33
+ reward_val = reward.get("total", 0.0) if isinstance(reward, dict) else reward
34
+ return self._obs_dict_to_model(obs_dict, reward=reward_val, done=done)
35
+
36
+ def _obs_dict_to_model(self, obs_dict: dict, reward=None, done=False) -> OrigamiObservation:
37
  return OrigamiObservation(
 
 
 
38
  prompt=obs_dict.get("prompt", ""),
39
+ target_name=obs_dict.get("target_name", ""),
40
  step=obs_dict.get("step", 0),
41
+ paper_fold_json=obs_dict.get("paper_fold_json", {}),
42
+ reward=reward,
43
+ done=done,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  )
45
 
46
+ def state(self):
47
+ return self._env.state()
 
 
48
 
49
+ def close(self):
50
+ self._env.close()
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ __all__ = ["OpenEnvOrigamiEnvironment"]
 
 
 
 
 
 
 
 
 
 
 
 
openenv_runtime/models.py CHANGED
@@ -1,63 +1,53 @@
1
- from __future__ import annotations
 
2
 
3
- from typing import Any, Literal, Optional
 
 
 
4
 
5
- from pydantic import BaseModel, Field, field_validator
6
 
7
  from openenv.core.env_server.types import Action, Observation, State
8
 
9
 
10
- class OrigamiFold(BaseModel):
11
- """Single fold action payload for step-level execution."""
12
-
13
- from_point: list[float] = Field(..., description="Fold line start [x, y]")
14
- to_point: list[float] = Field(..., description="Fold line end [x, y]")
15
- assignment: Literal["M", "V"] = Field(..., description="Mountain or valley")
16
- instruction: str = Field(default="", description="Optional natural language instruction")
17
-
18
- @field_validator("from_point", "to_point")
19
- @classmethod
20
- def _validate_point(cls, point: list[float]) -> list[float]:
21
- if len(point) != 2:
22
- raise ValueError("Point must contain exactly 2 coordinates")
23
- return [float(point[0]), float(point[1])]
24
 
 
25
 
26
- class OrigamiAction(Action):
27
- """
28
- OpenEnv action for Optigami.
29
-
30
- Modes:
31
- - single: execute one fold (pass `fold` or JSON `completion` for a single-fold object)
32
- - sequence: execute a full <folds>[...]</folds> completion in one step
33
- """
34
-
35
- mode: Literal["single", "sequence"] = Field(default="single")
36
- fold: Optional[OrigamiFold] = Field(default=None)
37
- completion: Optional[str] = Field(default=None)
38
- target_name: Optional[str] = Field(
39
- default=None,
40
- description="Optional target override; reset to this target before stepping",
41
  )
42
 
43
 
44
  class OrigamiObservation(Observation):
45
- """OpenEnv observation payload returned by Optigami."""
46
-
47
- prompt: str = Field(default="")
48
- target_name: Optional[str] = Field(default=None)
49
- step: int = Field(default=0)
50
- paper_state: dict[str, Any] = Field(default_factory=dict)
51
- info: dict[str, Any] = Field(default_factory=dict)
52
- reward_components: dict[str, float | int | str] = Field(default_factory=dict)
53
- error: Optional[str] = Field(default=None)
54
 
55
 
56
  class OrigamiState(State):
57
- """OpenEnv state payload for Optigami."""
 
 
 
 
 
 
58
 
59
- mode: str = Field(default="step")
60
- target_name: Optional[str] = Field(default=None)
61
- paper: dict[str, Any] = Field(default_factory=dict)
62
- last_reward: dict[str, Any] = Field(default_factory=dict)
63
- available_targets: list[str] = Field(default_factory=list)
 
1
+ """
2
+ OpenEnv Pydantic models for the env/ stack.
3
 
4
+ Matches the env/environment data shape: observations with prompt, target_name,
5
+ step, paper_fold_json; actions as fold dicts with from/to/assignment.
6
+ """
7
+ from typing import Optional
8
 
9
+ from pydantic import ConfigDict, Field
10
 
11
  from openenv.core.env_server.types import Action, Observation, State
12
 
13
 
14
+ class OrigamiAction(Action):
15
+ """One fold operation from_point, to_point, assignment."""
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ model_config = ConfigDict(populate_by_name=True)
18
 
19
+ from_point: list[float] = Field(
20
+ alias="from",
21
+ description="[x, y] start point of the crease",
22
+ )
23
+ to_point: list[float] = Field(
24
+ alias="to",
25
+ description="[x, y] end point of the crease",
26
+ )
27
+ assignment: str = Field(
28
+ description="'M' (mountain) or 'V' (valley)",
 
 
 
 
 
29
  )
30
 
31
 
32
  class OrigamiObservation(Observation):
33
+ """Observation from env.environment prompt, target, step, paper state."""
34
+
35
+ prompt: str = Field(default="", description="LLM prompt for the current step")
36
+ target_name: str = Field(default="", description="Name of the target (.fold stem)")
37
+ step: int = Field(default=0, ge=0, description="Current step index")
38
+ paper_fold_json: dict = Field(
39
+ default_factory=dict,
40
+ description="Graph edges (crease pattern state)",
41
+ )
42
 
43
 
44
  class OrigamiState(State):
45
+ """Server-side episode state."""
46
+
47
+ paper: dict = Field(default_factory=dict, description="Paper state")
48
+ target: Optional[str] = Field(default=None, description="Target name")
49
+ step: int = Field(default=0, ge=0, description="Step count")
50
+ mode: str = Field(default="step", description="'step' or 'code_as_policy'")
51
+
52
 
53
+ __all__ = ["OrigamiAction", "OrigamiObservation", "OrigamiState"]
 
 
 
 
openenv_server/app.py CHANGED
@@ -1,12 +1,25 @@
1
  from __future__ import annotations
2
 
 
3
  import json
4
  from pathlib import Path
5
 
6
  import numpy as np
 
7
  from fastapi.responses import HTMLResponse, JSONResponse
8
  from fastapi.staticfiles import StaticFiles
9
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def _np_default(obj):
12
  if isinstance(obj, np.bool_):
@@ -23,56 +36,150 @@ def _np_default(obj):
23
  class NumpyJSONResponse(JSONResponse):
24
  def render(self, content) -> bytes:
25
  return json.dumps(content, default=_np_default).encode("utf-8")
26
- from openenv.core.env_server.http_server import create_app
27
 
28
- from openenv_runtime.environment import OpenEnvOrigamiEnvironment
29
- from openenv_runtime.models import OrigamiAction, OrigamiObservation
 
 
 
 
30
 
31
 
 
 
 
 
32
  app = create_app(
33
- env=lambda: OpenEnvOrigamiEnvironment(),
34
  action_cls=OrigamiAction,
35
  observation_cls=OrigamiObservation,
36
  env_name="optigami",
37
  )
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # ---------------------------------------------------------------------------
41
- # Demo fold sequences — new format: type, line {start, end}, angle
42
- # ---------------------------------------------------------------------------
43
-
44
- DEMO_SEQUENCES: dict[str, list[dict]] = {
45
- "half_fold": [
46
- {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
47
- ],
48
- "quarter_fold": [
49
- {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
50
- {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
51
- ],
52
- "letter_fold": [
53
- {"type": "valley", "line": {"start": [0.0, 0.333], "end": [1.0, 0.333]}, "angle": 180.0},
54
- {"type": "mountain", "line": {"start": [0.0, 0.667], "end": [1.0, 0.667]}, "angle": 180.0},
55
- ],
56
- "map_fold": [
57
- {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
58
- {"type": "mountain", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0},
59
- ],
60
- "solar_panel": [
61
- {"type": "valley", "line": {"start": [0.0, 0.25], "end": [1.0, 0.25]}, "angle": 180.0},
62
- {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
63
- {"type": "valley", "line": {"start": [0.0, 0.75], "end": [1.0, 0.75]}, "angle": 180.0},
64
- ],
65
- "shelter_wall": [
66
- {"type": "valley", "line": {"start": [0.0, 0.333], "end": [1.0, 0.333]}, "angle": 180.0},
67
- {"type": "valley", "line": {"start": [0.0, 0.667], "end": [1.0, 0.667]}, "angle": 180.0},
68
- ],
69
- "stent": [
70
- {"type": "valley", "line": {"start": [0.0, 0.25], "end": [1.0, 0.25]}, "angle": 90.0},
71
- {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 90.0},
72
- {"type": "valley", "line": {"start": [0.0, 0.75], "end": [1.0, 0.75]}, "angle": 90.0},
73
- {"type": "stop", "line": {"start": [0.0, 0.0], "end": [1.0, 1.0]}, "angle": 0.0},
74
- ],
75
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
  # ---------------------------------------------------------------------------
@@ -81,68 +188,62 @@ DEMO_SEQUENCES: dict[str, list[dict]] = {
81
 
82
  @app.get("/targets", include_in_schema=True, response_class=NumpyJSONResponse)
83
  def get_targets():
84
- """Return available task names and metadata for the frontend."""
85
- from server.tasks import get_task_by_name, available_task_names
86
-
87
  result: dict[str, dict] = {}
88
- for name in available_task_names():
89
- t = get_task_by_name(name)
90
  result[name] = {
91
  "name": name,
92
- "level": t.get("difficulty", 1),
93
- "description": t.get("description", ""),
94
- "n_creases": t.get("max_folds", 3),
95
- "difficulty": t.get("difficulty", 1),
96
- "material": t.get("material", "paper"),
97
  }
98
  return NumpyJSONResponse(result)
99
 
100
 
101
  @app.get("/episode/demo", include_in_schema=True, response_class=NumpyJSONResponse)
102
- def demo_episode(target: str = "half_fold"):
103
- """Return a pre-solved demo episode for the given task."""
104
- from server.origami_environment import OrigamiEnvironment
105
- from server.models import OrigamiAction as NewOrigamiAction
106
- from server.tasks import get_task_by_name
 
107
 
108
- # Fall back to half_fold if target not found
109
- folds = DEMO_SEQUENCES.get(target, DEMO_SEQUENCES["half_fold"])
110
-
111
- env = OrigamiEnvironment()
112
- obs = env.reset(task_name=target)
113
 
 
114
  steps: list[dict] = []
115
 
116
  for i, fold_dict in enumerate(folds):
117
- if fold_dict.get("type") == "stop":
118
- break
119
-
120
- action = NewOrigamiAction(
121
- fold_type=fold_dict["type"],
122
- fold_line=fold_dict["line"],
123
- fold_angle=float(fold_dict.get("angle", 180.0)),
124
- )
125
-
126
- obs = env.step(action)
127
 
128
  steps.append({
129
  "step": i + 1,
130
  "fold": fold_dict,
131
- "paper_state": obs.paper_state,
132
- "metrics": obs.metrics,
133
- "done": obs.done,
134
  })
135
-
136
- if obs.done:
137
  break
138
 
139
- task_def = get_task_by_name(target) if target else {}
140
-
141
  return NumpyJSONResponse({
142
  "task_name": target,
143
- "task": task_def,
 
144
  "steps": steps,
145
- "final_metrics": obs.metrics if steps else {},
146
  })
147
 
148
 
 
1
  from __future__ import annotations
2
 
3
+ import asyncio
4
  import json
5
  from pathlib import Path
6
 
7
  import numpy as np
8
+ from fastapi import HTTPException, WebSocket
9
  from fastapi.responses import HTMLResponse, JSONResponse
10
  from fastapi.staticfiles import StaticFiles
11
 
12
+ from openenv.core.env_server.http_server import create_app
13
+
14
+ from env.environment import OrigamiEnvironment
15
+ from openenv_runtime.environment import OpenEnvOrigamiEnvironment
16
+ from openenv_runtime.models import OrigamiAction, OrigamiObservation
17
+ from server.training_broadcast import TrainingBroadcastServer
18
+
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Numpy-safe JSON response
22
+ # ---------------------------------------------------------------------------
23
 
24
  def _np_default(obj):
25
  if isinstance(obj, np.bool_):
 
36
  class NumpyJSONResponse(JSONResponse):
37
  def render(self, content) -> bytes:
38
  return json.dumps(content, default=_np_default).encode("utf-8")
 
39
 
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Episode registry for replay
43
+ # ---------------------------------------------------------------------------
44
+
45
+ _episode_registry: dict[str, dict] = {}
46
 
47
 
48
+ # ---------------------------------------------------------------------------
49
+ # OpenEnv app + training broadcast server
50
+ # ---------------------------------------------------------------------------
51
+
52
  app = create_app(
53
+ env=lambda: OpenEnvOrigamiEnvironment(mode="step"),
54
  action_cls=OrigamiAction,
55
  observation_cls=OrigamiObservation,
56
  env_name="optigami",
57
  )
58
 
59
+ broadcast = TrainingBroadcastServer()
60
+
61
+
62
+ def _ensure_broadcast_loop():
63
+ """Set broadcast loop on first use (replaces deprecated on_event('startup'))."""
64
+ if broadcast._loop is None or broadcast._loop.is_closed():
65
+ try:
66
+ broadcast._loop = asyncio.get_running_loop()
67
+ except RuntimeError:
68
+ pass
69
+
70
+
71
+ @app.middleware("http")
72
+ async def _set_broadcast_loop(request, call_next):
73
+ """Ensure broadcast has event loop before handling requests."""
74
+ _ensure_broadcast_loop()
75
+ return await call_next(request)
76
+
77
 
78
  # ---------------------------------------------------------------------------
79
+ # Health endpoint
80
+ # ---------------------------------------------------------------------------
81
+
82
+ @app.get("/health", include_in_schema=True)
83
+ async def health():
84
+ return {"status": "ok"}
85
+
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # Episode replay endpoint
89
+ # ---------------------------------------------------------------------------
90
+
91
+ @app.get("/episode/replay/{ep_id}", include_in_schema=True, response_class=NumpyJSONResponse)
92
+ async def replay_episode(ep_id: str):
93
+ if ep_id not in _episode_registry:
94
+ raise HTTPException(status_code=404, detail="Episode not found")
95
+ return NumpyJSONResponse(_episode_registry[ep_id])
96
+
97
+
98
+ # ---------------------------------------------------------------------------
99
+ # Training grid viewer WebSocket
100
+ # ---------------------------------------------------------------------------
101
+
102
+ @app.websocket("/ws/training")
103
+ async def training_ws(websocket: WebSocket):
104
+ """Read-only spectator WebSocket for the training grid viewer."""
105
+ _ensure_broadcast_loop()
106
+ await broadcast.connect_spectator(websocket)
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Helper: extract crease folds from .fold target
111
+ # ---------------------------------------------------------------------------
112
+
113
+ def _target_to_folds(target: dict) -> list[dict]:
114
+ """Extract crease folds from a target .fold dict (edges with M or V)."""
115
+ verts = target.get("vertices_coords", [])
116
+ edges_v = target.get("edges_vertices", [])
117
+ edges_a = target.get("edges_assignment", [])
118
+ folds = []
119
+ for (v1, v2), ass in zip(edges_v, edges_a):
120
+ if ass in ("M", "V") and v1 < len(verts) and v2 < len(verts):
121
+ p1 = verts[v1]
122
+ p2 = verts[v2]
123
+ folds.append({"from": p1, "to": p2, "assignment": ass})
124
+ return folds
125
+
126
+
127
+ def _graph_state_to_fold(paper_dict: dict) -> dict:
128
+ """Convert internal graph state dict to FOLD-format arrays for the frontend.
129
+
130
+ Input format (from env.state()['paper']):
131
+ vertices: {id: (x, y), ...}
132
+ edges: {id: (v1_id, v2_id, assignment), ...} (only M/V)
133
+
134
+ Output format (FOLD):
135
+ vertices_coords: [[x, y, 0], ...]
136
+ edges_vertices: [[i, j], ...]
137
+ edges_assignment: ['M'|'V'|'B', ...]
138
+ faces_vertices: [[i, j, k], ...] (Delaunay triangulation for 3D)
139
+ """
140
+ raw_verts = paper_dict.get("vertices", {})
141
+ raw_edges = paper_dict.get("edges", {})
142
+
143
+ if not raw_verts:
144
+ return {}
145
+
146
+ sorted_ids = sorted(raw_verts.keys(), key=lambda k: int(k) if isinstance(k, (int, str)) else k)
147
+ id_to_idx = {vid: idx for idx, vid in enumerate(sorted_ids)}
148
+
149
+ vertices_coords = []
150
+ for vid in sorted_ids:
151
+ xy = raw_verts[vid]
152
+ vertices_coords.append([float(xy[0]), float(xy[1]), 0.0])
153
+
154
+ edges_vertices = []
155
+ edges_assignment = []
156
+ for eid in sorted(raw_edges.keys(), key=lambda k: int(k) if isinstance(k, (int, str)) else k):
157
+ v1_id, v2_id, asgn = raw_edges[eid]
158
+ if v1_id in id_to_idx and v2_id in id_to_idx:
159
+ edges_vertices.append([id_to_idx[v1_id], id_to_idx[v2_id]])
160
+ edges_assignment.append(asgn)
161
+
162
+ faces_vertices = _triangulate_vertices(vertices_coords)
163
+
164
+ return {
165
+ "vertices_coords": vertices_coords,
166
+ "edges_vertices": edges_vertices,
167
+ "edges_assignment": edges_assignment,
168
+ "faces_vertices": faces_vertices,
169
+ }
170
+
171
+
172
+ def _triangulate_vertices(vertices_coords: list) -> list:
173
+ """Delaunay triangulate the 2D vertex set for 3D mesh rendering."""
174
+ if len(vertices_coords) < 3:
175
+ return []
176
+ try:
177
+ from scipy.spatial import Delaunay
178
+ pts = np.array([[v[0], v[1]] for v in vertices_coords])
179
+ tri = Delaunay(pts)
180
+ return tri.simplices.tolist()
181
+ except Exception:
182
+ return [[0, 1, 2], [0, 2, 3]] if len(vertices_coords) >= 4 else []
183
 
184
 
185
  # ---------------------------------------------------------------------------
 
188
 
189
  @app.get("/targets", include_in_schema=True, response_class=NumpyJSONResponse)
190
  def get_targets():
191
+ """Return available target names and metadata from env/targets/*.fold."""
192
+ env = OrigamiEnvironment()
193
+ names = env.available_targets()
194
  result: dict[str, dict] = {}
195
+ for name in names:
196
+ target = env._targets.get(name, {})
197
  result[name] = {
198
  "name": name,
199
+ "level": target.get("level", 1),
200
+ "description": target.get("description", ""),
201
+ "n_creases": len([a for a in target.get("edges_assignment", []) if a in ("M", "V")]),
202
+ "difficulty": target.get("level", 1),
203
+ "material": "paper",
204
  }
205
  return NumpyJSONResponse(result)
206
 
207
 
208
  @app.get("/episode/demo", include_in_schema=True, response_class=NumpyJSONResponse)
209
+ def demo_episode(target: str = "half_horizontal"):
210
+ """Return a pre-solved demo episode for the given .fold target."""
211
+ env = OrigamiEnvironment(mode="step")
212
+ targets = env.available_targets()
213
+ if target not in targets:
214
+ target = targets[0] if targets else "half_horizontal"
215
 
216
+ t = env._targets.get(target, {})
217
+ folds = _target_to_folds(t)
 
 
 
218
 
219
+ obs_dict = env.reset(target_name=target)
220
  steps: list[dict] = []
221
 
222
  for i, fold_dict in enumerate(folds):
223
+ obs_dict, reward, done, info = env.step(fold_dict)
224
+ graph = env.paper.graph
225
+ all_edges = {eid: (v1, v2, a) for eid, (v1, v2, a) in graph.edges.items()}
226
+ fold_state = _graph_state_to_fold({
227
+ "vertices": dict(graph.vertices),
228
+ "edges": all_edges,
229
+ })
 
 
 
230
 
231
  steps.append({
232
  "step": i + 1,
233
  "fold": fold_dict,
234
+ "paper_state": fold_state,
235
+ "metrics": reward if isinstance(reward, dict) else {"total": reward},
236
+ "done": done,
237
  })
238
+ if done:
 
239
  break
240
 
 
 
241
  return NumpyJSONResponse({
242
  "task_name": target,
243
+ "task": {"name": target, "level": t.get("level", 1), "description": t.get("description", "")},
244
+ "target_crease": t,
245
  "steps": steps,
246
+ "final_metrics": steps[-1]["metrics"] if steps else {},
247
  })
248
 
249