# File: optigami_/origami_server/environment.py
# Last commit (sissississi): Restore OpenEnv with optimized Docker multi-stage build
# Revision: 1f89afe
"""Origami RL Environment — OpenEnv Environment subclass."""
import logging
import uuid
from typing import Any, Optional

import numpy as np
from openenv.core import Environment

from .engine.fold_parser import validate_fold
from .engine.shape_match import compute_shape_match
from .engine.simulate import SimResult, simulate
from .models import OrigamiAction, OrigamiObservation, OrigamiState
from .tasks import get_task
class OrigamiEnvironment(
    Environment[OrigamiAction, OrigamiObservation, OrigamiState]
):
    """Single-step origami environment.

    ``reset`` loads a task and simulates its target fold; ``step`` validates
    and simulates the submitted fold data, scores the resulting positions
    against the target, and always returns an observation with ``done=True``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Reward returned when a submission fails validation or simulation.
    _FAILURE_REWARD = -2.0
    # Factor applied to the shape-match score to obtain the reward.
    _REWARD_SCALE = 20.0

    def __init__(self, **kwargs: Any):
        """Create an environment with no task loaded and an empty target."""
        super().__init__(**kwargs)
        self._state = OrigamiState()
        self._task: dict = {}
        self._target_positions: np.ndarray = np.zeros((0, 3))

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Begin a new episode for the requested task.

        Args:
            seed: Accepted but not used by this method.
            episode_id: Episode identifier; a random UUID4 string is used
                when it is omitted or empty.
            **kwargs: ``task_name`` selects the task (default ``"triangle"``).

        Returns:
            The initial observation: not done, no reward, the task summary
            and the target positions as nested lists.
        """
        self._state = OrigamiState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
        )
        self._task = get_task(kwargs.get("task_name", "triangle"))
        self._state.task_name = self._task["name"]
        target_fold = self._task["target_fold"]

        try:
            target_result = simulate(target_fold, crease_percent=1.0)
            self._target_positions = target_result.positions
        except Exception:
            # Best-effort: the episode still starts with an empty target, but
            # the failure is logged so a broken task definition is visible
            # instead of silently producing meaningless scores.
            logging.getLogger(__name__).exception(
                "Target simulation failed for task %r; using empty target positions",
                self._task["name"],
            )
            self._target_positions = np.zeros((0, 3))

        return OrigamiObservation(
            done=False,
            reward=None,
            task=self._task_info(),
            target_positions=self._target_positions.tolist(),
        )

    def step(
        self,
        action: OrigamiAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Score one submitted fold and end the episode.

        Args:
            action: Carries the submission in ``action.fold_data``.
            timeout_s: Accepted but not used by this method.
            **kwargs: Ignored.

        Returns:
            A terminal observation. When validation or simulation fails it
            carries the failure reward and an ``error`` message; otherwise it
            carries the scaled similarity reward together with the simulated
            positions, strain and convergence flag.
        """
        self._state.step_count += 1
        fold_data = action.fold_data

        is_valid, error_msg = validate_fold(fold_data)
        if not is_valid:
            return self._failure_observation(
                fold_data, f"Invalid FOLD data: {error_msg}"
            )

        try:
            result: SimResult = simulate(fold_data, crease_percent=1.0)
        except Exception as exc:
            # Any simulator failure is reported to the caller as a penalized
            # terminal observation rather than propagated.
            return self._failure_observation(
                fold_data, f"Simulation error: {exc}"
            )

        similarity = compute_shape_match(result.positions, self._target_positions)
        self._state.shape_similarity = similarity
        self._state.is_stable = result.converged

        return OrigamiObservation(
            done=True,
            reward=similarity * self._REWARD_SCALE,
            task=self._task_info(),
            fold_data=fold_data,
            final_positions=result.positions.tolist(),
            target_positions=self._target_positions.tolist(),
            shape_similarity=similarity,
            max_strain=result.max_strain,
            is_stable=result.converged,
        )

    @property
    def state(self) -> OrigamiState:
        """Return the state object of the current episode."""
        return self._state

    def _failure_observation(self, fold_data: Any, error: str) -> OrigamiObservation:
        """Build the terminal, penalized observation for a failed submission."""
        return OrigamiObservation(
            done=True,
            reward=self._FAILURE_REWARD,
            task=self._task_info(),
            fold_data=fold_data,
            target_positions=self._target_positions.tolist(),
            error=error,
        )

    def _task_info(self) -> dict:
        """Return the summary of the active task, or ``{}`` if none is loaded."""
        if not self._task:
            return {}
        return {
            "name": self._task.get("name", ""),
            "description": self._task.get("description", ""),
            "difficulty": self._task.get("difficulty", 0),
            "paper": self._task.get("paper", {}),
        }