"""
OrigamiEnvironment — OpenEnv environment wrapping the origami physics engine.
Implements reset() / step() / state following the OpenEnv interface.
Engine (physics, fold, validation, metrics) lives in engine/.
No server-side image rendering — paper_state contains all geometry data.
"""
from __future__ import annotations
import json
import os
import uuid
from typing import Any, Optional
# openenv base class — if the package is not installed, fall back to a
# minimal generic stand-in with the same three type parameters.
try:
    from openenv.core.env_server.interfaces import Environment
except ImportError:
    from typing import Generic, TypeVar

    A = TypeVar("A")  # action type parameter
    O = TypeVar("O")  # observation type parameter
    S = TypeVar("S")  # state type parameter

    class Environment(Generic[A, O, S]):
        """Minimal stand-in for openenv.core.env_server.interfaces.Environment."""

        def __init__(self, **kwargs):
            """Accept and ignore arbitrary keyword arguments."""
            return None
from engine.paper import Paper
from engine.fold_engine import apply_fold
from engine.physics import simulate
from engine.validation import validate_state
from engine.metrics import compute_all_metrics
from server.models import OrigamiAction, OrigamiObservation, OrigamiState
from server.tasks import get_task_by_name, sample_task
def _get_material(name: str):
    """Look up a material by name, falling back to ``"paper"``.

    Args:
        name: Material name (``reset()`` passes the task's ``"material"`` field).

    Returns:
        The result of ``engine.materials.get_material(name)``, or of
        ``get_material("paper")`` if the lookup for ``name`` raises.

    Raises:
        ImportError: if ``engine.materials`` cannot be imported; a name
            fallback cannot fix a missing module, so this surfaces directly.
        Exception: whatever ``get_material("paper")`` raises if the fallback
            lookup itself fails.
    """
    # Import once, outside the try, so only the lookup is guarded.
    from engine.materials import get_material

    try:
        return get_material(name)
    except Exception:
        # Best-effort fallback for unknown/invalid names. The concrete
        # exception type raised by get_material is not visible here, so this
        # stays broad — NOTE(review): narrow once that type is confirmed.
        return get_material("paper")
class OrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
    """Origami folding RL environment.

    Each episode: agent receives paper_state + task, applies folds one at a
    time via step(), receives metrics + reward, ends with 'stop' action or
    when max_folds is reached. An episode also ends early on an invalid fold
    (fixed penalty reward) or when validation reports a self-intersection.
    Once an episode is done, reset() must be called before folding again.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Reward emitted for misuse: stepping before reset() or an invalid fold.
    _PENALTY_REWARD = -5.0
    # Strain limit used by the reward when no paper (and so no material) exists.
    _DEFAULT_STRAIN_LIMIT = 0.05

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._paper: Optional[Paper] = None
        self._task: Optional[dict] = None
        self._fold_history: list[dict] = []
        self._metrics: dict = {}
        self._validation: dict = {}
        self._error: Optional[str] = None
        self._episode_id: Optional[str] = None
        self._step_count: int = 0
        self._total_reward: float = 0.0
        # True once a terminal observation has been returned for this episode.
        self._done: bool = False

    # ── reset ─────────────────────────────────────────────────────────
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Start a new episode on a fresh flat sheet.

        Args:
            seed: Forwarded to ``sample_task`` when a task is sampled.
            episode_id: Episode identifier; a random UUID4 string if omitted.
            **kwargs: ``task_name`` selects a task via ``get_task_by_name``.
                If it is absent or unknown, a task is sampled instead.

        Returns:
            The initial observation (``done=False``, ``reward=None``).
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._step_count = 0
        self._fold_history = []
        self._error = None
        self._total_reward = 0.0
        self._done = False

        # Select task. Always start from None so the previous episode's task
        # is never silently reused (and `seed` is honoured on every reset).
        task_name = kwargs.get("task_name")
        self._task = get_task_by_name(task_name) if task_name else None
        if not self._task:
            self._task = sample_task(seed=seed)

        # Create flat sheet
        mat = _get_material(self._task["material"])
        self._paper = Paper.create_flat_sheet(
            width=self._task["width"],
            height=self._task["height"],
            material=mat,
        )

        # Initial validation + metrics (no physics needed for flat sheet)
        self._validation = validate_state(self._paper)
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)
        return self._make_observation(done=False, reward=None)

    # ── step ──────────────────────────────────────────────────────────
    def step(
        self,
        action: OrigamiAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Apply one action and return the resulting observation.

        Args:
            action: A fold (type/line/angle) or ``fold_type == "stop"``.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A non-terminal observation (``reward=None``), or a terminal one:
            the reward from ``_compute_reward`` on stop / max_folds /
            self-intersection, ``_PENALTY_REWARD`` on an invalid fold or a
            call before ``reset()``, and ``None`` for a call after the
            episode already ended (no state change).
        """
        if self._paper is None or self._task is None:
            self._error = "step() called before reset()"
            return self._make_observation(done=True, reward=self._PENALTY_REWARD)
        if self._done:
            # Terminal state: don't keep folding or re-finalize the episode.
            self._error = "Episode is done; call reset() before stepping again"
            return self._make_observation(done=True, reward=None)

        self._step_count += 1
        self._error = None

        # ── Stop action ──────────────────────────────────────────────
        if action.fold_type == "stop":
            return self._finalize_episode()

        # ── Build fold dict ──────────────────────────────────────────
        fold_dict = {
            "type": action.fold_type,
            "line": action.fold_line,
            "angle": action.fold_angle,
        }

        # ── Apply fold ───────────────────────────────────────────────
        new_paper, err = apply_fold(self._paper, fold_dict)
        if err:
            # Invalid fold ends the episode with a fixed penalty; record it so
            # state.total_reward matches the rewards actually emitted.
            self._error = err
            self._total_reward += self._PENALTY_REWARD
            self._done = True
            return self._make_observation(done=True, reward=self._PENALTY_REWARD)
        self._paper = new_paper
        self._fold_history.append({**fold_dict, "step": self._step_count})

        # ── Physics relaxation ───────────────────────────────────────
        try:
            self._paper = simulate(self._paper, fold_percent=1.0)
        except Exception as exc:
            # Continue — don't abort episode on physics failure; the message
            # is reported in this step's observation.
            self._error = f"Physics failed: {exc}"

        # ── Validate + metrics ───────────────────────────────────────
        self._validation = validate_state(self._paper)
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)

        # ── Check termination ────────────────────────────────────────
        max_folds = self._task.get("max_folds", 50)
        if self._step_count >= max_folds:
            return self._finalize_episode()
        if self._validation.get("self_intersections", 0) > 0:
            self._error = "Self-intersection detected"
            return self._finalize_episode()

        return self._make_observation(done=False, reward=None)

    # ── state ─────────────────────────────────────────────────────────
    @property
    def state(self) -> OrigamiState:
        """Snapshot of episode bookkeeping (id, counters, validity, reward)."""
        return OrigamiState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            task_name=self._task.get("name", "") if self._task else "",
            num_folds_applied=len(self._fold_history),
            is_valid=self._metrics.get("is_valid", True),
            total_reward=self._total_reward,
        )

    # ── internals ─────────────────────────────────────────────────────
    def _finalize_episode(self) -> OrigamiObservation:
        """Compute the terminal reward, mark the episode done, return final obs."""
        reward = self._compute_reward()
        # The done guard in step() ensures this runs at most once per episode.
        self._total_reward += reward
        self._done = True
        return self._make_observation(done=True, reward=reward)

    def _make_observation(self, done: bool, reward: Optional[float]) -> OrigamiObservation:
        """Build an observation from the current environment state.

        ``fold_history`` is copied so an observation already handed out is
        not retroactively changed when later steps append to the history.
        """
        return OrigamiObservation(
            done=done,
            reward=reward,
            task=self._task or {},
            paper_state=self._paper.to_observation_dict() if self._paper else {},
            metrics=self._metrics,
            fold_history=list(self._fold_history),
            error=self._error,
        )

    def _compute_reward(self) -> float:
        """Terminal reward computed from ``self._metrics``.

        Components (each metric read with a default if missing):
          + 20 * compactness                      (main signal)
          + 10 if fits_target_box
          + 5  if is_deployable
          - 2 per Kawasaki violation, - 2 per Maekawa violation
          - 5 per self-intersection
          - 0.5 per fold                          (encourage efficiency)
          - 3 * (max_strain / strain_limit) when max_strain exceeds the
            material limit (flat -3 if the limit is not positive)
        """
        m = self._metrics
        reward = 0.0

        # Compactness is the main signal
        reward += m.get("compactness", 0.0) * 20.0

        # Bonus for fitting in target box
        if m.get("fits_target_box", False):
            reward += 10.0

        # Bonus for deployability. Awarded whenever the metric is truthy; the
        # task is not consulted here — NOTE(review): confirm whether
        # compute_all_metrics already gates this on the task's requirements.
        if m.get("is_deployable", False):
            reward += 5.0

        # Penalties for violations
        reward -= m.get("kawasaki_violations", 0) * 2.0
        reward -= m.get("maekawa_violations", 0) * 2.0
        reward -= m.get("self_intersections", 0) * 5.0

        # Penalty for too many folds (encourage efficiency)
        reward -= m.get("fold_count", 0) * 0.5

        # Penalty for exceeding material strain limit
        max_strain = m.get("max_strain", 0.0)
        strain_limit = (
            self._paper.material.max_strain if self._paper else self._DEFAULT_STRAIN_LIMIT
        )
        if max_strain > strain_limit:
            if strain_limit > 0:
                reward -= 3.0 * (max_strain / strain_limit)
            else:
                # A zero limit would divide by zero and a negative one would
                # flip the penalty into a bonus; apply the base penalty.
                reward -= 3.0

        return float(reward)