# optigami/openenv_runtime/environment.py
"""
OpenEnv adapter for Optigami.
Thin wrapper around env.environment.OrigamiEnvironment that adapts it to the
OpenEnv protocol (Action/Observation types).
"""
from env.environment import OrigamiEnvironment as _Env
from .models import OrigamiAction, OrigamiObservation
class OpenEnvOrigamiEnvironment:
    """OpenEnv-compatible wrapper around env.environment.OrigamiEnvironment.

    Translates the underlying environment's dict-based reset/step API into
    OpenEnv's typed Action/Observation models and back.
    """

    def __init__(self, mode: str = "step", max_steps: int = 8, targets_dir=None):
        # All environment logic is delegated to the wrapped implementation.
        self._env = _Env(mode=mode, max_steps=max_steps, targets_dir=targets_dir)

    def reset(self, target_name=None, **kwargs):
        """Reset the wrapped env; return the initial typed observation.

        Extra keyword arguments are accepted for protocol compatibility and
        ignored, matching the wrapped env's reset signature.
        """
        raw = self._env.reset(target_name=target_name)
        return self._obs_dict_to_model(raw)

    def step(self, action: OrigamiAction, **kwargs):
        """Apply one action and return the resulting typed observation.

        The typed action is converted to the wrapped env's dict format; the
        reward may come back as a breakdown dict, in which case only its
        "total" entry is surfaced.
        """
        raw_action = {
            "from": action.from_point,
            "to": action.to_point,
            "assignment": action.assignment,
        }
        raw_obs, raw_reward, done, info = self._env.step(raw_action)
        if isinstance(raw_reward, dict):
            raw_reward = raw_reward.get("total", 0.0)
        return self._obs_dict_to_model(raw_obs, reward=raw_reward, done=done)

    def _obs_dict_to_model(self, obs_dict: dict, reward=None, done=False) -> OrigamiObservation:
        """Build an OrigamiObservation from the wrapped env's dict payload.

        Missing keys fall back to empty/zero defaults so partial payloads
        still produce a well-formed observation.
        """
        return OrigamiObservation(
            prompt=obs_dict.get("prompt", ""),
            target_name=obs_dict.get("target_name", ""),
            step=obs_dict.get("step", 0),
            paper_fold_json=obs_dict.get("paper_fold_json", {}),
            reward=reward,
            done=done,
        )

    def state(self):
        """Expose the wrapped env's raw state object."""
        return self._env.state()

    def close(self):
        """Release any resources held by the wrapped env."""
        self._env.close()
__all__ = ["OpenEnvOrigamiEnvironment"]