ianalin123 committed on
Commit
02d14b3
·
1 Parent(s): 033b6ea

feat(v2): implement multi-step environment with PaperState and per-step rewards

Browse files
Files changed (1) hide show
  1. origami_server/environment.py +168 -24
origami_server/environment.py CHANGED
@@ -1,9 +1,12 @@
1
  """Origami RL Environment — OpenEnv Environment subclass.
2
 
3
- Single-shot episodes: LLM submits a FOLD crease pattern, physics simulates it,
4
- reward = shape similarity to target. Like AlphaFold for origami.
 
 
5
  """
6
 
 
7
  import uuid
8
  from typing import Any, Optional
9
 
@@ -11,8 +14,10 @@ import numpy as np
11
  from openenv.core import Environment
12
 
13
  from .engine.fold_parser import validate_fold
 
14
  from .engine.shape_match import compute_shape_match
15
  from .engine.simulate import SimResult, simulate
 
16
  from .models import OrigamiAction, OrigamiObservation, OrigamiState
17
  from .tasks import get_task
18
 
@@ -20,22 +25,21 @@ from .tasks import get_task
20
  class OrigamiEnvironment(
21
  Environment[OrigamiAction, OrigamiObservation, OrigamiState]
22
  ):
23
- """Origami folding environment.
24
-
25
- Episode flow:
26
- 1. reset(task_name="triangle") → returns task description + target info
27
- 2. step(OrigamiAction(fold_data={...})) → simulates, scores, returns done=True
28
 
29
- Single action per episode. The action IS the complete crease pattern.
 
30
  """
31
 
32
  SUPPORTS_CONCURRENT_SESSIONS = True
33
 
34
- def __init__(self, **kwargs: Any):
35
  super().__init__(**kwargs)
 
36
  self._state = OrigamiState()
37
  self._task: dict = {}
38
  self._target_positions: np.ndarray = np.zeros((0, 3))
 
39
 
40
  def reset(
41
  self,
@@ -44,33 +48,51 @@ class OrigamiEnvironment(
44
  **kwargs: Any,
45
  ) -> OrigamiObservation:
46
  """Start a new episode with a target shape task."""
 
 
 
47
  self._state = OrigamiState(
48
  episode_id=episode_id or str(uuid.uuid4()),
49
  step_count=0,
 
 
50
  )
51
 
52
- # Get task
53
- task_name = kwargs.get("task_name", "triangle")
54
- self._task = get_task(task_name)
55
- self._state.task_name = self._task["name"]
56
-
57
- # Simulate the target FOLD to get target positions
58
  target_fold = self._task["target_fold"]
59
  try:
60
  target_result = simulate(target_fold, crease_percent=1.0)
61
  self._target_positions = target_result.positions
62
- except Exception as e:
63
  self._target_positions = np.zeros((0, 3))
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  return OrigamiObservation(
66
  done=False,
67
  reward=None,
68
- task={
69
- "name": self._task["name"],
70
- "description": self._task["description"],
71
- "difficulty": self._task["difficulty"],
72
- "paper": self._task["paper"],
73
- },
74
  fold_data={},
75
  final_positions=[],
76
  target_positions=self._target_positions.tolist(),
@@ -86,7 +108,107 @@ class OrigamiEnvironment(
86
  timeout_s: Optional[float] = None,
87
  **kwargs: Any,
88
  ) -> OrigamiObservation:
89
- """Evaluate the LLM's crease pattern.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  1. Validate FOLD data
92
  2. Run physics simulation (creasePercent=1.0)
@@ -94,7 +216,6 @@ class OrigamiEnvironment(
94
  4. Return observation with reward = similarity × 20
95
  """
96
  self._state.step_count += 1
97
- fold_data = action.fold_data
98
 
99
  # Validate
100
  is_valid, error_msg = validate_fold(fold_data)
@@ -153,6 +274,29 @@ class OrigamiEnvironment(
153
  error=None,
154
  )
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  @property
157
  def state(self) -> OrigamiState:
158
  return self._state
 
1
  """Origami RL Environment — OpenEnv Environment subclass.
2
 
3
+ Origami folding environment supports single-shot (V1) and multi-step (V2) modes.
4
+
5
+ V1 (mode='single'): LLM submits complete FOLD JSON, gets Chamfer-distance reward. Done=True after 1 step.
6
+ V2 (mode='step'): LLM submits one crease per step, gets per-step reward (progress + geometry). Done=True when max_folds reached or completion bonus triggered.
7
  """
8
 
9
+ import copy
10
  import uuid
11
  from typing import Any, Optional
12
 
 
14
  from openenv.core import Environment
15
 
16
  from .engine.fold_parser import validate_fold
17
+ from .engine.paper_state import PaperState
18
  from .engine.shape_match import compute_shape_match
19
  from .engine.simulate import SimResult, simulate
20
+ from .engine.step_reward import compute_reward
21
  from .models import OrigamiAction, OrigamiObservation, OrigamiState
22
  from .tasks import get_task
23
 
 
25
  class OrigamiEnvironment(
26
  Environment[OrigamiAction, OrigamiObservation, OrigamiState]
27
  ):
28
+ """Origami folding environment — supports single-shot (V1) and multi-step (V2) modes.
 
 
 
 
29
 
30
+ V1 (mode='single'): LLM submits complete FOLD JSON, gets Chamfer-distance reward. Done=True after 1 step.
31
+ V2 (mode='step'): LLM submits one crease per step, gets per-step reward (progress + geometry). Done=True when max_folds reached or completion bonus triggered.
32
  """
33
 
34
  SUPPORTS_CONCURRENT_SESSIONS = True
35
 
36
def __init__(self, mode: str = "step", **kwargs: Any):
    """Create an environment that has no active episode yet.

    Args:
        mode: Stored as-is on the instance; "step" (V2, the default) or
            "single" (V1 compatibility).
        **kwargs: Forwarded unchanged to the base ``Environment``.
    """
    super().__init__(**kwargs)

    # Episode-scoped fields start empty; reset() populates them.
    self._paper_state: Optional[PaperState] = None
    self._target_positions: np.ndarray = np.zeros((0, 3))
    self._task: dict = {}
    self._state = OrigamiState()

    # "step" (V2 default) | "single" (V1 compat)
    self._mode = mode
43
 
44
  def reset(
45
  self,
 
48
  **kwargs: Any,
49
  ) -> OrigamiObservation:
50
  """Start a new episode with a target shape task."""
51
+ task_name = kwargs.get("task_name", "triangle")
52
+ self._task = get_task(task_name)
53
+
54
  self._state = OrigamiState(
55
  episode_id=episode_id or str(uuid.uuid4()),
56
  step_count=0,
57
+ mode=self._mode,
58
+ task_name=self._task["name"],
59
  )
60
 
61
+ # Simulate target FOLD to get target positions
 
 
 
 
 
62
  target_fold = self._task["target_fold"]
63
  try:
64
  target_result = simulate(target_fold, crease_percent=1.0)
65
  self._target_positions = target_result.positions
66
+ except Exception:
67
  self._target_positions = np.zeros((0, 3))
68
 
69
+ # V2: initialize empty paper state
70
+ if self._mode == "step":
71
+ self._paper_state = PaperState()
72
+ anchor_pts = [[x, y] for x, y in self._paper_state.anchor_points()]
73
+ return OrigamiObservation(
74
+ done=False,
75
+ reward=None,
76
+ task=self._task_info(),
77
+ fold_data={},
78
+ final_positions=[],
79
+ target_positions=self._target_positions.tolist(),
80
+ shape_similarity=0.0,
81
+ max_strain=0.0,
82
+ is_stable=True,
83
+ error=None,
84
+ step_count=0,
85
+ max_steps=self._task.get("max_folds", 1),
86
+ current_creases=[],
87
+ anchor_points=anchor_pts,
88
+ reward_breakdown={},
89
+ )
90
+
91
+ # V1: return initial observation (unchanged behavior)
92
  return OrigamiObservation(
93
  done=False,
94
  reward=None,
95
+ task=self._task_info(),
 
 
 
 
 
96
  fold_data={},
97
  final_positions=[],
98
  target_positions=self._target_positions.tolist(),
 
108
  timeout_s: Optional[float] = None,
109
  **kwargs: Any,
110
  ) -> OrigamiObservation:
111
+ """Dispatch to V2 crease step or V1 full-fold step."""
112
+ # V2: single-crease step
113
+ if action.crease is not None:
114
+ return self._step_crease(action.crease)
115
+ # V1: complete FOLD JSON (backward compat)
116
+ if action.fold_data is not None:
117
+ return self._step_fold(action.fold_data)
118
+ # Neither set
119
+ return OrigamiObservation(
120
+ done=True,
121
+ reward=-2.0,
122
+ task=self._task_info(),
123
+ fold_data={},
124
+ final_positions=[],
125
+ target_positions=self._target_positions.tolist(),
126
+ shape_similarity=0.0,
127
+ max_strain=0.0,
128
+ is_stable=False,
129
+ error="OrigamiAction must set either fold_data (V1) or crease (V2)",
130
+ )
131
+
132
def _step_crease(self, crease: dict) -> OrigamiObservation:
    """V2: apply one crease to the paper state and compute the per-step reward.

    Args:
        crease: Mapping read for three keys: ``"assignment"`` (must be
            ``"M"`` or ``"V"``), ``"from"`` and ``"to"`` (both required).

    Returns:
        An ``OrigamiObservation``. A crease that fails the field check
        yields ``reward=-0.1`` with an error message and leaves the paper
        state and step counter untouched. Otherwise the crease is added,
        the step counter advances, and the reward is
        ``reward_dict["total"]`` from ``compute_reward``. ``done`` becomes
        true once the step counter reaches the task's ``max_folds``
        (default 1) or the reward's ``"completion"`` entry is positive.
        On a finishing step with at least one crease edge, the crease
        graph is simulated to fill ``final_positions``; if that simulation
        fails, ``final_positions`` stays empty.
    """
    import logging  # function-scope import keeps this block self-contained

    if self._paper_state is None:
        self._paper_state = PaperState()

    # Looked up once; every branch below needs the same limit.
    max_folds = self._task.get("max_folds", 1)

    # Validate crease fields
    assignment = crease.get("assignment", "")
    from_pt = crease.get("from")
    to_pt = crease.get("to")
    if assignment not in ("M", "V") or from_pt is None or to_pt is None:
        # NOTE(review): a rejected crease does not advance step_count, so an
        # agent that only sends invalid creases never reaches max_folds here.
        # _step_fold increments before validating -- confirm this is intended.
        return OrigamiObservation(
            done=self._state.step_count >= max_folds,
            reward=-0.1,
            task=self._task_info(),
            fold_data={},
            final_positions=[],
            target_positions=self._target_positions.tolist(),
            shape_similarity=0.0,
            max_strain=0.0,
            is_stable=False,
            error=f"Invalid crease: {crease}",
            step_count=self._state.step_count,
            max_steps=max_folds,
            current_creases=self._paper_state.crease_edges(),
            anchor_points=[[x, y] for x, y in self._paper_state.anchor_points()],
            reward_breakdown={},
        )

    # Snapshot before mutating so the reward can compare old and new states.
    prev_state = copy.deepcopy(self._paper_state)
    result = self._paper_state.add_crease(from_pt, to_pt, assignment)
    self._state.step_count += 1

    reward_dict = compute_reward(
        prev_state=prev_state,
        action_result=result,
        new_state=self._paper_state,
        target=self._task,
        step=self._state.step_count,
        max_steps=max_folds,
    )

    done = (
        self._state.step_count >= max_folds
        or reward_dict.get("completion", 0) > 0
    )

    progress = reward_dict.get("progress", 0.0)
    self._state.shape_similarity = progress

    # On final step, run full simulation for viewer
    final_positions: list = []
    if done and self._paper_state.crease_edges():
        try:
            sim = simulate(self._paper_state_to_fold(), crease_percent=1.0)
            final_positions = sim.positions.tolist()
        except Exception:
            # Best-effort: a viewer-only simulation must not fail the step,
            # but the failure is recorded instead of vanishing silently.
            logging.getLogger(__name__).warning(
                "Final-step simulation failed; returning empty final_positions",
                exc_info=True,
            )

    return OrigamiObservation(
        done=done,
        reward=reward_dict["total"],
        task=self._task_info(),
        fold_data={},
        final_positions=final_positions,
        target_positions=self._target_positions.tolist(),
        shape_similarity=progress,
        max_strain=0.0,
        is_stable=True,
        error=None,
        step_count=self._state.step_count,
        max_steps=max_folds,
        current_creases=self._paper_state.crease_edges(),
        anchor_points=[[x, y] for x, y in self._paper_state.anchor_points()],
        # Only numeric entries are exposed, coerced to plain floats.
        reward_breakdown={
            k: float(v)
            for k, v in reward_dict.items()
            if isinstance(v, (int, float))
        },
    )
209
+
210
+ def _step_fold(self, fold_data: dict) -> OrigamiObservation:
211
+ """V1: evaluate a complete FOLD crease pattern.
212
 
213
  1. Validate FOLD data
214
  2. Run physics simulation (creasePercent=1.0)
 
216
  4. Return observation with reward = similarity × 20
217
  """
218
  self._state.step_count += 1
 
219
 
220
  # Validate
221
  is_valid, error_msg = validate_fold(fold_data)
 
274
  error=None,
275
  )
276
 
277
def _paper_state_to_fold(self) -> dict:
    """Export the current PaperState crease graph as a minimal FOLD dict.

    Returns an empty dict when there is no paper state. Otherwise the dict
    has ``vertices_coords`` (one ``[x, y]`` per graph vertex),
    ``edges_vertices`` (index pairs into that list) and
    ``edges_assignment`` (the third element of each graph edge).
    """
    paper = self._paper_state
    if paper is None:
        return {}

    graph = paper.graph

    # Vertex ids are remapped to consecutive positions in iteration order.
    index_of = {vid: idx for idx, vid in enumerate(graph.vertices)}
    coords = [[x, y] for x, y in graph.vertices.values()]

    # Each edge value unpacks as (start vertex id, end vertex id, assignment).
    endpoints = []
    assignments = []
    for v1, v2, assign in graph.edges.values():
        endpoints.append([index_of[v1], index_of[v2]])
        assignments.append(assign)

    return {
        "vertices_coords": coords,
        "edges_vertices": endpoints,
        "edges_assignment": assignments,
    }
299
+
300
  @property
301
  def state(self) -> OrigamiState:
302
  return self._state