# optigami/openenv_runtime/environment.py
"""
OpenEnv adapter for Optigami.
Thin wrapper around env.environment.OrigamiEnvironment that adapts it to the
OpenEnv protocol (Action/Observation types).
"""
from env.environment import OrigamiEnvironment as _Env
from .models import OrigamiAction, OrigamiObservation
class OpenEnvOrigamiEnvironment:
    """
    OpenEnv-compatible wrapper for env.environment.OrigamiEnvironment.

    Translates between the wrapped env's dict-based API and OpenEnv's
    typed Action/Observation models, delegating all actual logic to
    the underlying environment.
    """

    def __init__(self, mode: str = "step", max_steps: int = 8, targets_dir=None):
        # All behavior is delegated to the wrapped dict-based environment.
        self._env = _Env(mode=mode, max_steps=max_steps, targets_dir=targets_dir)

    def reset(self, target_name=None, **kwargs):
        """Reset the wrapped env and return the initial typed observation."""
        raw = self._env.reset(target_name=target_name)
        return self._obs_dict_to_model(raw, reward=None, done=False)

    def step(self, action: OrigamiAction, **kwargs):
        """Apply one action; return a typed observation carrying reward/done."""
        payload = {
            "from": action.from_point,
            "to": action.to_point,
            "assignment": action.assignment,
        }
        raw, raw_reward, done, _info = self._env.step(payload)
        # The env may report reward as a dict of components; use its total.
        if isinstance(raw_reward, dict):
            raw_reward = raw_reward.get("total", 0.0)
        return self._obs_dict_to_model(raw, reward=raw_reward, done=done)

    def _obs_dict_to_model(self, obs_dict: dict, reward=None, done=False) -> OrigamiObservation:
        """Build an OrigamiObservation from the env's raw observation dict."""
        return OrigamiObservation(
            prompt=obs_dict.get("prompt", ""),
            target_name=obs_dict.get("target_name", ""),
            step=obs_dict.get("step", 0),
            paper_fold_json=obs_dict.get("paper_fold_json", {}),
            reward=reward,
            done=done,
        )

    def state(self):
        """Return the wrapped environment's state snapshot."""
        return self._env.state()

    def close(self):
        """Release any resources held by the wrapped environment."""
        self._env.close()
# Explicit public API: only the wrapper class is exported.
__all__ = ["OpenEnvOrigamiEnvironment"]