"""
OpenEnv adapter for Optigami.
Thin wrapper around env.environment.OrigamiEnvironment that adapts it to the
OpenEnv protocol (Action/Observation types).
"""
from env.environment import OrigamiEnvironment as _Env
from .models import OrigamiAction, OrigamiObservation
class OpenEnvOrigamiEnvironment:
    """
    OpenEnv-compatible wrapper for env.environment.OrigamiEnvironment.

    Converts between env's dict-based API and OpenEnv's Action/Observation types.
    """

    def __init__(self, mode: str = "step", max_steps: int = 8, targets_dir=None):
        # All environment logic lives in the wrapped dict-based env; this
        # class only translates types at the boundary.
        self._env = _Env(mode=mode, max_steps=max_steps, targets_dir=targets_dir)

    def reset(self, target_name=None, **kwargs):
        """Reset the wrapped env and return the initial observation model."""
        raw = self._env.reset(target_name=target_name)
        return self._obs_dict_to_model(raw, reward=None, done=False)

    def step(self, action: OrigamiAction, **kwargs):
        """Apply one fold action and return the resulting observation model."""
        # Translate the typed action into the dict shape the inner env expects.
        fold = {
            "from": action.from_point,
            "to": action.to_point,
            "assignment": action.assignment,
        }
        raw, reward, done, _info = self._env.step(fold)
        # The inner env may report the reward as a per-component dict;
        # collapse it to the scalar total in that case.
        if isinstance(reward, dict):
            total = reward.get("total", 0.0)
        else:
            total = reward
        return self._obs_dict_to_model(raw, reward=total, done=done)

    def _obs_dict_to_model(self, obs_dict: dict, reward=None, done=False) -> OrigamiObservation:
        """Build an OrigamiObservation from the inner env's observation dict."""
        return OrigamiObservation(
            prompt=obs_dict.get("prompt", ""),
            target_name=obs_dict.get("target_name", ""),
            step=obs_dict.get("step", 0),
            paper_fold_json=obs_dict.get("paper_fold_json", {}),
            reward=reward,
            done=done,
        )

    def state(self):
        """Expose the wrapped env's state unchanged."""
        return self._env.state()

    def close(self):
        """Release any resources held by the wrapped env."""
        self._env.close()
# Explicit public API: only the OpenEnv wrapper class is exported.
__all__ = ["OpenEnvOrigamiEnvironment"]