File size: 8,898 Bytes
1e49495
 
 
 
 
 
 
 
 
 
 
 
 
 
8ae8523
 
 
 
 
 
 
 
 
 
 
1e49495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
OrigamiEnvironment — OpenEnv environment wrapping the origami physics engine.

Implements reset() / step() / state following the OpenEnv interface.
Engine (physics, fold, validation, metrics) lives in engine/.
No server-side image rendering — paper_state contains all geometry data.
"""
from __future__ import annotations

import json
import os
import uuid
from typing import Any, Optional

# openenv base class — if openenv is not installed, fall back to a minimal
# generic stand-in so this module can still be imported without it.
try:
    from openenv.core.env_server.interfaces import Environment
except ImportError:
    from typing import Generic, TypeVar
    A = TypeVar("A")  # action type parameter (OrigamiAction below)
    O = TypeVar("O")  # observation type parameter (OrigamiObservation below)
    S = TypeVar("S")  # state type parameter (OrigamiState below)
    class Environment(Generic[A, O, S]):
        """Minimal stand-in for openenv.core.env_server.interfaces.Environment.

        It only supports subscripting with three type parameters. Its
        __init__ accepts and ignores arbitrary keyword arguments. None of the
        real base class's behavior is provided.
        """
        def __init__(self, **kwargs): pass

from engine.paper import Paper
from engine.fold_engine import apply_fold
from engine.physics import simulate
from engine.validation import validate_state
from engine.metrics import compute_all_metrics
from server.models import OrigamiAction, OrigamiObservation, OrigamiState
from server.tasks import get_task_by_name, sample_task


def _get_material(name: str):
    """Get material by name, falling back to paper."""
    try:
        from engine.materials import get_material
        return get_material(name)
    except Exception:
        from engine.materials import get_material
        return get_material("paper")


class OrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
    """Origami folding RL environment.

    In each episode the agent receives paper_state + task and applies folds
    one at a time via step(), receiving metrics on every step. The episode
    ends when any of the following happens:

    - the agent sends a 'stop' action;
    - max_folds is reached;
    - the fold engine rejects a fold;
    - validation reports a self-intersection.

    Rewards are terminal-only. Non-final observations carry ``reward=None``.
    The final observation carries the value from _compute_reward(), or
    INVALID_ACTION_REWARD when the episode ends on an action that could not
    be applied.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Terminal reward when step() cannot apply the action (reset() not called
    # yet, or the fold engine returned an error).
    INVALID_ACTION_REWARD = -5.0
    # Fold budget used when the task dict has no "max_folds" entry.
    DEFAULT_MAX_FOLDS = 50
    # Strain limit used by _compute_reward() when there is no paper.
    DEFAULT_STRAIN_LIMIT = 0.05

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._paper: Optional[Paper] = None
        self._task: Optional[dict] = None
        self._fold_history: list[dict] = []
        self._metrics: dict = {}
        self._validation: dict = {}
        self._error: Optional[str] = None
        self._episode_id: Optional[str] = None
        self._step_count: int = 0
        self._total_reward: float = 0.0

    # ── reset ─────────────────────────────────────────────────────────

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Start a new episode on a fresh flat sheet.

        Args:
            seed: Passed to ``sample_task`` when a task is sampled. Unused when
                ``task_name`` resolves to a task.
            episode_id: Identifier for the episode. A random UUID4 string is
                generated when omitted.
            **kwargs: Optional ``task_name``, looked up with
                ``get_task_by_name``. If it is missing or the lookup returns a
                falsy value, a task is drawn with ``sample_task(seed=seed)``.

        Returns:
            The initial observation (``done=False``, ``reward=None``).
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._step_count = 0
        self._fold_history = []
        self._error = None
        self._total_reward = 0.0

        # Select the task for THIS episode. A local variable is used so a task
        # left over from a previous episode is never silently reused. Reusing
        # it would make every reset() after the first ignore sampling and
        # `seed`.
        task_name = kwargs.get("task_name")
        task = get_task_by_name(task_name) if task_name else None
        self._task = task or sample_task(seed=seed)

        # Create flat sheet
        mat = _get_material(self._task["material"])
        self._paper = Paper.create_flat_sheet(
            width=self._task["width"],
            height=self._task["height"],
            material=mat,
        )

        # Initial validation + metrics (no physics needed for flat sheet)
        self._validation = validate_state(self._paper)
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)

        return self._make_observation(done=False, reward=None)

    # ── step ──────────────────────────────────────────────────────────

    def step(
        self,
        action: OrigamiAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Apply one action and return the resulting observation.

        A ``fold_type`` of ``"stop"`` ends the episode. Any other action
        goes through these stages:

        1. It is turned into a fold dict (``type``/``line``/``angle``).
        2. The fold is applied with ``apply_fold``.
        3. The paper is relaxed with ``simulate``.
        4. The result is validated and scored.

        Args:
            action: The agent's action.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An observation. When ``done`` is True, the episode ended on this
            step and ``reward`` holds its terminal reward.
        """
        if self._paper is None or self._task is None:
            self._error = "Environment not initialized; call reset() before step()"
            return self._end_with_invalid_action()

        self._step_count += 1
        self._error = None

        # ── Stop action ───────────────────────────────────────────────
        if action.fold_type == "stop":
            return self._finalize_episode()

        # ── Build fold dict ───────────────────────────────────────────
        fold_dict = {
            "type": action.fold_type,
            "line": action.fold_line,
            "angle": action.fold_angle,
        }

        # ── Apply fold ────────────────────────────────────────────────
        new_paper, err = apply_fold(self._paper, fold_dict)
        if err:
            self._error = err
            return self._end_with_invalid_action()

        self._paper = new_paper
        self._fold_history.append({**fold_dict, "step": self._step_count})

        # ── Physics relaxation ────────────────────────────────────────
        try:
            self._paper = simulate(self._paper, fold_percent=1.0)
        except Exception as exc:
            # Best-effort: keep the folded (unrelaxed) paper and report the
            # failure through the observation's error field instead of
            # aborting the episode.
            self._error = f"Physics failed: {exc}"

        # ── Validate ──────────────────────────────────────────────────
        self._validation = validate_state(self._paper)

        # ── Metrics ───────────────────────────────────────────────────
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)

        # ── Check termination ─────────────────────────────────────────
        # Check self-intersection first so its error is still reported when
        # this fold also exhausts the fold budget. Either way the episode is
        # finalized with the same reward.
        if self._validation.get("self_intersections", 0) > 0:
            self._error = "Self-intersection detected"
            return self._finalize_episode()

        max_folds = self._task.get("max_folds", self.DEFAULT_MAX_FOLDS)
        if self._step_count >= max_folds:
            return self._finalize_episode()

        return self._make_observation(done=False, reward=None)

    # ── state ─────────────────────────────────────────────────────────

    @property
    def state(self) -> OrigamiState:
        """Summary of the current episode (ids, counters, validity, reward)."""
        return OrigamiState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            task_name=self._task.get("name", "") if self._task else "",
            num_folds_applied=len(self._fold_history),
            is_valid=self._metrics.get("is_valid", True),
            total_reward=self._total_reward,
        )

    # ── internals ─────────────────────────────────────────────────────

    def _finalize_episode(self) -> OrigamiObservation:
        """Compute the terminal reward, record it, and return a done observation."""
        reward = self._compute_reward()
        self._total_reward = reward
        return self._make_observation(done=True, reward=reward)

    def _end_with_invalid_action(self) -> OrigamiObservation:
        """End the episode with the fixed INVALID_ACTION_REWARD penalty."""
        # Record the penalty so `state.total_reward` agrees with the terminal
        # observation's reward, as _finalize_episode() does for normal endings.
        self._total_reward = self.INVALID_ACTION_REWARD
        return self._make_observation(done=True, reward=self.INVALID_ACTION_REWARD)

    def _make_observation(self, done: bool, reward: Optional[float]) -> OrigamiObservation:
        """Build an observation snapshot of the current episode state."""
        return OrigamiObservation(
            done=done,
            reward=reward,
            task=self._task or {},
            paper_state=self._paper.to_observation_dict() if self._paper else {},
            metrics=self._metrics,
            # Copy: step() appends to _fold_history in place. Without a copy,
            # observations already returned to the caller would change too.
            fold_history=list(self._fold_history),
            error=self._error,
        )

    def _compute_reward(self) -> float:
        """Score the current metrics as the episode's terminal reward.

        Terms:

        - +20 x compactness;
        - +10 if the result fits the target box;
        - +5 if it is deployable;
        - -2 per Kawasaki or Maekawa violation;
        - -5 per self-intersection;
        - -0.5 per fold;
        - a penalty when max strain exceeds the material's strain limit.
        """
        m = self._metrics
        reward = 0.0

        # Compactness is the main signal
        reward += m.get("compactness", 0.0) * 20.0

        # Bonus for fitting in target box
        if m.get("fits_target_box", False):
            reward += 10.0

        # Bonus when the metrics report the result as deployable (applied
        # regardless of whether the task asks for deployability)
        if m.get("is_deployable", False):
            reward += 5.0

        # Penalties for violations
        reward -= m.get("kawasaki_violations", 0) * 2.0
        reward -= m.get("maekawa_violations", 0) * 2.0
        reward -= m.get("self_intersections", 0) * 5.0

        # Penalty for too many folds (encourage efficiency)
        reward -= m.get("fold_count", 0) * 0.5

        # Penalty for exceeding material strain limit
        max_strain = m.get("max_strain", 0.0)
        strain_limit = (
            self._paper.material.max_strain if self._paper else self.DEFAULT_STRAIN_LIMIT
        )
        if max_strain > strain_limit:
            if strain_limit > 0:
                # Scales with how far past the limit the strain is (> 3.0).
                reward -= 3.0 * (max_strain / strain_limit)
            else:
                # The ratio is undefined for a non-positive limit (it would
                # divide by zero or flip sign). Apply the smallest penalty
                # the ratio form can produce instead.
                reward -= 3.0

        return float(reward)