optigami / server /origami_environment.py
sissississi's picture
Fix openenv dependency — graceful fallback when openenv-core not installed
8ae8523
"""
OrigamiEnvironment — OpenEnv environment wrapping the origami physics engine.
Implements reset() / step() / state following the OpenEnv interface.
Engine (physics, fold, validation, metrics) lives in engine/.
No server-side image rendering — paper_state contains all geometry data.
"""
from __future__ import annotations
import json
import os
import uuid
from typing import Any, Optional
# openenv base class β€” fall back to plain object if not installed
try:
from openenv.core.env_server.interfaces import Environment
except ImportError:
from typing import Generic, TypeVar
A = TypeVar("A")
O = TypeVar("O")
S = TypeVar("S")
class Environment(Generic[A, O, S]):
"""Minimal stand-in for openenv.core.env_server.interfaces.Environment."""
def __init__(self, **kwargs): pass
from engine.paper import Paper
from engine.fold_engine import apply_fold
from engine.physics import simulate
from engine.validation import validate_state
from engine.metrics import compute_all_metrics
from server.models import OrigamiAction, OrigamiObservation, OrigamiState
from server.tasks import get_task_by_name, sample_task
def _get_material(name: str):
"""Get material by name, falling back to paper."""
try:
from engine.materials import get_material
return get_material(name)
except Exception:
from engine.materials import get_material
return get_material("paper")
class OrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
    """Origami folding RL environment.

    Each episode: the agent receives paper_state + task, applies folds one at
    a time via step(), and the episode ends on a 'stop' action, a rejected
    fold, a detected self-intersection, or when max_folds is reached.

    The shaped reward is computed once, when the episode is finalized;
    intermediate steps return ``reward=None``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Reward returned when step() cannot proceed: either no episode is active
    # or the fold engine reported an error for the requested fold.
    INVALID_STEP_REWARD = -5.0
    # Fold budget used when the task dict has no "max_folds" entry.
    DEFAULT_MAX_FOLDS = 50
    # Strain limit used by the reward when there is no paper to read it from.
    DEFAULT_STRAIN_LIMIT = 0.05

    def __init__(self, **kwargs):
        """Create an environment with no active episode; call reset() first."""
        super().__init__(**kwargs)
        self._paper: Optional[Paper] = None
        self._task: Optional[dict] = None
        self._fold_history: list[dict] = []
        self._metrics: dict = {}
        self._validation: dict = {}
        self._error: Optional[str] = None
        self._episode_id: Optional[str] = None
        self._step_count: int = 0
        self._total_reward: float = 0.0

    # ── reset ─────────────────────────────────────────────────────────
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Start a new episode on a flat sheet and return the first observation.

        Args:
            seed: Forwarded to ``sample_task`` when a task has to be sampled.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when it is omitted or empty.
            **kwargs: ``task_name`` (optional) selects a task through
                ``get_task_by_name``.  Other keys are ignored.

        Returns:
            An observation with ``done=False`` and ``reward=None``.
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._step_count = 0
        self._fold_history = []
        self._error = None
        self._total_reward = 0.0

        # Select the task afresh on every reset.  The choice must not depend
        # on the task left over from the previous episode, otherwise a reset
        # without task_name would silently replay the old task and ignore
        # `seed`.
        task_name = kwargs.get("task_name")
        task = get_task_by_name(task_name) if task_name else None
        if not task:
            # No name given, or the lookup returned nothing: sample instead.
            task = sample_task(seed=seed)
        self._task = task

        # Create flat sheet
        mat = _get_material(self._task["material"])
        self._paper = Paper.create_flat_sheet(
            width=self._task["width"],
            height=self._task["height"],
            material=mat,
        )

        # Initial validation + metrics (no physics needed for flat sheet)
        self._validation = validate_state(self._paper)
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)
        return self._make_observation(done=False, reward=None)

    # ── step ──────────────────────────────────────────────────────────
    def step(
        self,
        action: OrigamiAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Apply one action and return the resulting observation.

        Args:
            action: Read for ``fold_type``, ``fold_line`` and ``fold_angle``.
                A ``fold_type`` of ``"stop"`` finalizes the episode without
                folding.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``done=False, reward=None`` while the episode continues.
            ``done=True`` with the final shaped reward on stop, on reaching
            the fold budget, or on self-intersection.  ``done=True`` with
            ``INVALID_STEP_REWARD`` when no episode is active or the fold is
            rejected by the fold engine.
        """
        if self._paper is None or self._task is None:
            # No episode yet: tell the caller why the step was refused.
            self._error = "Environment not reset: call reset() before step()"
            return self._make_observation(done=True, reward=self.INVALID_STEP_REWARD)

        self._step_count += 1
        self._error = None

        # ── Stop action ───────────────────────────────────────────────
        if action.fold_type == "stop":
            return self._finalize_episode()

        # ── Build fold dict ───────────────────────────────────────────
        fold_dict = {
            "type": action.fold_type,
            "line": action.fold_line,
            "angle": action.fold_angle,
        }

        # ── Apply fold ────────────────────────────────────────────────
        new_paper, err = apply_fold(self._paper, fold_dict)
        if err:
            self._error = err
            # Record the penalty so state.total_reward agrees with the
            # reward carried by the returned observation.
            self._total_reward = self.INVALID_STEP_REWARD
            return self._make_observation(done=True, reward=self.INVALID_STEP_REWARD)
        self._paper = new_paper
        self._fold_history.append({**fold_dict, "step": self._step_count})

        # ── Physics relaxation ────────────────────────────────────────
        try:
            self._paper = simulate(self._paper, fold_percent=1.0)
        except Exception as exc:
            # Continue — don't abort the episode on physics failure; the
            # un-relaxed paper from apply_fold is kept and the error reported.
            self._error = f"Physics failed: {exc}"

        # ── Validate ──────────────────────────────────────────────────
        self._validation = validate_state(self._paper)

        # ── Metrics ───────────────────────────────────────────────────
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)

        # ── Check termination ─────────────────────────────────────────
        # Self-intersection is checked first so that its error message is
        # still reported when the offending fold is also the last one allowed.
        if self._validation.get("self_intersections", 0) > 0:
            self._error = "Self-intersection detected"
            return self._finalize_episode()
        max_folds = self._task.get("max_folds", self.DEFAULT_MAX_FOLDS)
        if self._step_count >= max_folds:
            return self._finalize_episode()
        return self._make_observation(done=False, reward=None)

    # ── state ─────────────────────────────────────────────────────────
    @property
    def state(self) -> OrigamiState:
        """Summary of the current episode as an ``OrigamiState``."""
        return OrigamiState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            task_name=self._task.get("name", "") if self._task else "",
            num_folds_applied=len(self._fold_history),
            # Treated as valid when the metrics carry no "is_valid" entry.
            is_valid=self._metrics.get("is_valid", True),
            total_reward=self._total_reward,
        )

    # ── internals ─────────────────────────────────────────────────────
    def _finalize_episode(self) -> OrigamiObservation:
        """Compute the final reward, record it, and return a done observation."""
        reward = self._compute_reward()
        self._total_reward = reward
        return self._make_observation(done=True, reward=reward)

    def _make_observation(self, done: bool, reward: Optional[float]) -> OrigamiObservation:
        """Build an observation from the environment's current fields."""
        return OrigamiObservation(
            done=done,
            reward=reward,
            task=self._task or {},
            paper_state=self._paper.to_observation_dict() if self._paper else {},
            metrics=self._metrics,
            # Copy: step() appends to the history list in place, and earlier
            # observations should not change after they were returned.
            fold_history=list(self._fold_history),
            error=self._error,
        )

    def _compute_reward(self) -> float:
        """Return the end-of-episode reward derived from ``self._metrics``.

        Missing metric keys count as zero / False.
        """
        m = self._metrics
        reward = 0.0

        # Compactness is the main signal
        reward += m.get("compactness", 0.0) * 20.0

        # Bonus for fitting in target box
        if m.get("fits_target_box", False):
            reward += 10.0

        # Bonus for deployability (if task requires it)
        if m.get("is_deployable", False):
            reward += 5.0

        # Penalties for violations
        reward -= m.get("kawasaki_violations", 0) * 2.0
        reward -= m.get("maekawa_violations", 0) * 2.0
        reward -= m.get("self_intersections", 0) * 5.0

        # Penalty for too many folds (encourage efficiency)
        reward -= m.get("fold_count", 0) * 0.5

        # Penalty for exceeding material strain limit, scaled by how far the
        # limit is exceeded.
        max_strain = m.get("max_strain", 0.0)
        if self._paper:
            strain_limit = self._paper.material.max_strain
        else:
            strain_limit = self.DEFAULT_STRAIN_LIMIT
        # A non-positive limit cannot serve as a divisor; skip the penalty in
        # that case instead of raising ZeroDivisionError at episode end.
        if strain_limit > 0 and max_strain > strain_limit:
            reward -= 3.0 * (max_strain / strain_limit)

        return float(reward)