Spaces:
Running
Running
| """Origami RL Environment — OpenEnv Environment subclass.""" | |
| import uuid | |
| from typing import Any, Optional | |
| import numpy as np | |
| from openenv.core import Environment | |
| from .engine.fold_parser import validate_fold | |
| from .engine.shape_match import compute_shape_match | |
| from .engine.simulate import SimResult, simulate | |
| from .models import OrigamiAction, OrigamiObservation, OrigamiState | |
| from .tasks import get_task | |
class OrigamiEnvironment(
    Environment[OrigamiAction, OrigamiObservation, OrigamiState]
):
    """Single-step origami folding environment.

    ``reset`` selects a task (keyword ``task_name``, default ``"triangle"``)
    and simulates that task's target fold to obtain reference positions.
    ``step`` validates and simulates the submitted FOLD data, scores it with
    ``compute_shape_match`` against the target, and always ends the episode
    (``done=True``).
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # A successfully simulated fold earns ``similarity * SIMILARITY_REWARD_SCALE``.
    SIMILARITY_REWARD_SCALE = 20.0
    # Reward for a submission that fails validation or simulation.
    FAILURE_REWARD = -2.0

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._state = OrigamiState()
        self._task: dict = {}
        # Empty (0, 3) array until a target fold has been simulated.
        self._target_positions: np.ndarray = np.zeros((0, 3))

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Start a new episode.

        Args:
            seed: Accepted for interface compatibility; not used.
            episode_id: Episode identifier; a random UUID4 string is used
                when omitted or empty.
            **kwargs: ``task_name`` is passed to ``get_task``
                (default ``"triangle"``); other keys are ignored.

        Returns:
            The initial (non-terminal) observation with task info and target
            positions. If simulating the target fold raises, the target
            positions are empty and ``error`` describes the failure.
        """
        self._state = OrigamiState(
            episode_id=episode_id or str(uuid.uuid4()), step_count=0
        )
        self._task = get_task(kwargs.get("task_name", "triangle"))
        self._state.task_name = self._task["name"]
        target_fold = self._task["target_fold"]

        extra: dict = {}
        try:
            target_result = simulate(target_fold, crease_percent=1.0)
            self._target_positions = target_result.positions
        except Exception as e:
            # Best effort: keep the episode usable, but do not hide the
            # failure -- every step would otherwise be scored against an
            # empty target with no indication why.
            self._target_positions = np.zeros((0, 3))
            extra["error"] = f"Target simulation error: {e}"
        return self._observation(done=False, reward=None, **extra)

    def step(
        self,
        action: OrigamiAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Score one FOLD submission; every step is terminal.

        Args:
            action: Carries the candidate crease pattern in ``fold_data``.
            timeout_s: Accepted for interface compatibility; not used.
            **kwargs: Ignored.

        Returns:
            A terminal observation. Invalid FOLD data, or an exception raised
            by the simulator, yields ``FAILURE_REWARD`` and an ``error``
            message. Otherwise the reward is the shape similarity times
            ``SIMILARITY_REWARD_SCALE``, and the final positions, strain and
            stability are reported.
        """
        self._state.step_count += 1
        fold_data = action.fold_data

        is_valid, error_msg = validate_fold(fold_data)
        if not is_valid:
            return self._failure(fold_data, f"Invalid FOLD data: {error_msg}")

        try:
            result: SimResult = simulate(fold_data, crease_percent=1.0)
        except Exception as e:
            return self._failure(fold_data, f"Simulation error: {e}")

        similarity = compute_shape_match(result.positions, self._target_positions)
        self._state.shape_similarity = similarity
        self._state.is_stable = result.converged
        return self._observation(
            done=True,
            reward=similarity * self.SIMILARITY_REWARD_SCALE,
            fold_data=fold_data,
            final_positions=result.positions.tolist(),
            shape_similarity=similarity,
            max_strain=result.max_strain,
            is_stable=result.converged,
        )

    # OpenEnv's Environment interface exposes ``state`` as a property; as a
    # plain method, ``env.state`` would yield a bound method, not the state.
    @property
    def state(self) -> OrigamiState:
        """Return the current episode state."""
        return self._state

    def _observation(self, **fields: Any) -> OrigamiObservation:
        """Build an observation that also carries task info and target positions."""
        return OrigamiObservation(
            task=self._task_info(),
            target_positions=self._target_positions.tolist(),
            **fields,
        )

    def _failure(self, fold_data: Any, error: str) -> OrigamiObservation:
        """Build the terminal observation for a submission that could not be scored."""
        return self._observation(
            done=True,
            reward=self.FAILURE_REWARD,
            fold_data=fold_data,
            error=error,
        )

    def _task_info(self) -> dict:
        """Return the public subset of the current task, or ``{}`` before ``reset``."""
        if not self._task:
            return {}
        return {
            "name": self._task.get("name", ""),
            "description": self._task.get("description", ""),
            "difficulty": self._task.get("difficulty", 0),
            "paper": self._task.get("paper", {}),
        }