File size: 1,813 Bytes
e9b7141
 
19abe39
e9b7141
 
 
 
19abe39
e9b7141
19abe39
 
e9b7141
 
 
19abe39
e9b7141
 
19abe39
e9b7141
 
19abe39
e9b7141
19abe39
e9b7141
19abe39
e9b7141
 
 
 
 
 
 
 
 
 
 
19abe39
 
e9b7141
19abe39
e9b7141
 
 
19abe39
 
e9b7141
 
19abe39
e9b7141
 
19abe39
 
e9b7141
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
OpenEnv adapter for Optigami.

Thin wrapper around env.environment.OrigamiEnvironment that adapts it to the
OpenEnv protocol (Action/Observation types).
"""
from env.environment import OrigamiEnvironment as _Env

from .models import OrigamiAction, OrigamiObservation


class OpenEnvOrigamiEnvironment:
    """
    Adapter exposing env.environment.OrigamiEnvironment through OpenEnv types.

    The wrapped environment speaks plain dicts. This class turns incoming
    OrigamiAction objects into the dict form that environment's step() takes,
    and packs the observation dicts it returns into OrigamiObservation objects.
    """

    def __init__(self, mode: str = "step", max_steps: int = 8, targets_dir=None):
        # Every construction argument is forwarded unchanged to the wrapped env.
        inner = _Env(mode=mode, max_steps=max_steps, targets_dir=targets_dir)
        self._env = inner

    def reset(self, target_name=None, **kwargs):
        """
        Reset the wrapped environment and return its first observation.

        target_name is passed straight through to the wrapped env's reset().
        Any extra keyword arguments are accepted but ignored. The returned
        observation carries reward=None and done=False.
        """
        first_obs = self._env.reset(target_name=target_name)
        return self._obs_dict_to_model(first_obs, reward=None, done=False)

    def step(self, action: OrigamiAction, **kwargs):
        """
        Apply one action and return the resulting observation.

        The action's from_point, to_point and assignment are repacked into a
        dict keyed "from", "to" and "assignment". The wrapped env's step() is
        unpacked as a 4-tuple (obs, reward, done, info), and info is dropped.
        A dict-valued reward is reduced to its "total" entry (0.0 when the
        key is absent). Any other reward value is passed through unchanged.
        Extra keyword arguments are accepted but ignored.
        """
        payload = {
            "from": action.from_point,
            "to": action.to_point,
            "assignment": action.assignment,
        }
        next_obs, raw_reward, done, _info = self._env.step(payload)

        if isinstance(raw_reward, dict):
            # Dict-valued reward: keep only its "total" entry.
            scalar_reward = raw_reward.get("total", 0.0)
        else:
            scalar_reward = raw_reward

        return self._obs_dict_to_model(next_obs, reward=scalar_reward, done=done)

    def _obs_dict_to_model(self, obs_dict: dict, reward=None, done=False) -> OrigamiObservation:
        """
        Build an OrigamiObservation from the wrapped env's observation dict.

        Missing keys fall back to "" (prompt, target_name), 0 (step) and a
        new empty dict (paper_fold_json). reward and done are copied as given.
        """
        lookup = obs_dict.get
        return OrigamiObservation(
            prompt=lookup("prompt", ""),
            target_name=lookup("target_name", ""),
            step=lookup("step", 0),
            # A fresh {} per call, so observations never share one mutable default.
            paper_fold_json=lookup("paper_fold_json", {}),
            reward=reward,
            done=done,
        )

    def state(self):
        """Return whatever the wrapped environment's state() returns."""
        return self._env.state()

    def close(self):
        """Delegate shutdown to the wrapped environment's close()."""
        self._env.close()


__all__ = ["OpenEnvOrigamiEnvironment"]