ianalin123 commited on
Commit
883cccb
·
1 Parent(s): e971f8f

Add OpenEnv runtime adapter and server entrypoint

Browse files
openenv_runtime/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEnv integration runtime for Optigami."""
2
+
3
+ from .environment import OpenEnvOrigamiEnvironment
4
+ from .models import OrigamiAction, OrigamiObservation, OrigamiState
5
+
6
+ __all__ = [
7
+ "OpenEnvOrigamiEnvironment",
8
+ "OrigamiAction",
9
+ "OrigamiObservation",
10
+ "OrigamiState",
11
+ ]
openenv_runtime/environment.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Optional
4
+
5
+ from openenv.core.env_server.interfaces import Environment
6
+
7
+ from env.environment import OrigamiEnvironment
8
+
9
+ from .models import OrigamiAction, OrigamiObservation, OrigamiState
10
+
11
+
12
+ class OpenEnvOrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
13
+ """OpenEnv adapter over the existing OrigamiEnvironment implementation."""
14
+
15
+ SUPPORTS_CONCURRENT_SESSIONS = True
16
+
17
+ def __init__(
18
+ self,
19
+ default_mode: str = "step",
20
+ max_steps: int = 8,
21
+ targets_dir: Optional[str] = None,
22
+ ):
23
+ super().__init__()
24
+ self.default_mode = default_mode
25
+ self.max_steps = max_steps
26
+ self.targets_dir = targets_dir
27
+ self._env: Optional[OrigamiEnvironment] = None
28
+ self._episode_id: Optional[str] = None
29
+
30
+ def _new_env(self, mode: Optional[str] = None) -> OrigamiEnvironment:
31
+ return OrigamiEnvironment(
32
+ mode=mode or self.default_mode,
33
+ max_steps=self.max_steps,
34
+ targets_dir=self.targets_dir,
35
+ )
36
+
37
+ def reset(
38
+ self,
39
+ seed: Optional[int] = None,
40
+ episode_id: Optional[str] = None,
41
+ **kwargs: Any,
42
+ ) -> OrigamiObservation:
43
+ del seed # deterministic seed plumbing can be added later
44
+
45
+ mode = kwargs.get("mode", self.default_mode)
46
+ target_name = kwargs.get("target_name")
47
+
48
+ self._env = self._new_env(mode=mode)
49
+ self._episode_id = episode_id
50
+ obs_dict = self._env.reset(target_name=target_name)
51
+
52
+ return OrigamiObservation(
53
+ done=False,
54
+ reward=None,
55
+ metadata={"available_targets": self._env.available_targets()},
56
+ prompt=obs_dict.get("prompt", ""),
57
+ target_name=obs_dict.get("target_name"),
58
+ step=obs_dict.get("step", 0),
59
+ paper_state=self._paper_state_snapshot(),
60
+ info=self._env._info(),
61
+ reward_components={},
62
+ )
63
+
64
+ def step(
65
+ self,
66
+ action: OrigamiAction,
67
+ timeout_s: Optional[float] = None,
68
+ **kwargs: Any,
69
+ ) -> OrigamiObservation:
70
+ del timeout_s, kwargs
71
+
72
+ if self._env is None:
73
+ self.reset(target_name=action.target_name)
74
+
75
+ assert self._env is not None
76
+
77
+ if action.target_name and action.target_name != self._env.target_name:
78
+ self.reset(target_name=action.target_name, mode=self._env.mode)
79
+
80
+ try:
81
+ if action.mode == "sequence":
82
+ if not action.completion:
83
+ return self._error_observation("sequence mode requires completion")
84
+
85
+ seq_env = self._new_env(mode="code_as_policy")
86
+ seq_env.reset(target_name=self._env.target_name)
87
+ obs_dict, reward_dict, done, info = seq_env.step(action.completion)
88
+ self._env = seq_env
89
+ else:
90
+ if action.fold is not None:
91
+ fold_payload = {
92
+ "from": list(action.fold.from_point),
93
+ "to": list(action.fold.to_point),
94
+ "assignment": action.fold.assignment,
95
+ "instruction": action.fold.instruction,
96
+ }
97
+ env_action: Any = fold_payload
98
+ elif action.completion:
99
+ env_action = action.completion
100
+ else:
101
+ return self._error_observation("single mode requires fold or completion")
102
+
103
+ obs_dict, reward_dict, done, info = self._env.step(env_action)
104
+
105
+ total = reward_dict.get("total") if isinstance(reward_dict, dict) else None
106
+ return OrigamiObservation(
107
+ done=bool(done),
108
+ reward=float(total) if isinstance(total, (int, float)) else None,
109
+ metadata={"target_name": self._env.target_name},
110
+ prompt=obs_dict.get("prompt", ""),
111
+ target_name=obs_dict.get("target_name", self._env.target_name),
112
+ step=obs_dict.get("step", self._env.step_count),
113
+ paper_state=self._paper_state_snapshot(),
114
+ info=info or {},
115
+ reward_components=reward_dict or {},
116
+ )
117
+ except Exception as exc: # pragma: no cover - defensive path
118
+ return self._error_observation(str(exc))
119
+
120
+ @property
121
+ def state(self) -> OrigamiState:
122
+ if self._env is None:
123
+ tmp_env = self._new_env(mode=self.default_mode)
124
+ return OrigamiState(
125
+ episode_id=self._episode_id,
126
+ step_count=0,
127
+ mode=tmp_env.mode,
128
+ target_name=None,
129
+ paper={},
130
+ last_reward={},
131
+ available_targets=tmp_env.available_targets(),
132
+ )
133
+
134
+ env_state = self._env.state()
135
+ return OrigamiState(
136
+ episode_id=self._episode_id,
137
+ step_count=env_state.get("step", self._env.step_count),
138
+ mode=env_state.get("mode", self._env.mode),
139
+ target_name=env_state.get("target", self._env.target_name),
140
+ paper=env_state.get("paper", {}),
141
+ last_reward=self._env.last_reward or {},
142
+ available_targets=self._env.available_targets(),
143
+ )
144
+
145
+ def close(self) -> None:
146
+ if self._env is not None:
147
+ self._env.close()
148
+ self._env = None
149
+
150
+ def _paper_state_snapshot(self) -> dict[str, Any]:
151
+ if self._env is None or self._env.paper is None:
152
+ return {"vertices": {}, "edges": [], "anchor_points": []}
153
+
154
+ graph = self._env.paper.graph
155
+ return {
156
+ "vertices": {str(k): [float(v[0]), float(v[1])] for k, v in graph.vertices.items()},
157
+ "edges": [
158
+ {
159
+ "id": int(eid),
160
+ "v1": [float(graph.vertices[v1][0]), float(graph.vertices[v1][1])],
161
+ "v2": [float(graph.vertices[v2][0]), float(graph.vertices[v2][1])],
162
+ "assignment": assignment,
163
+ }
164
+ for eid, (v1, v2, assignment) in graph.edges.items()
165
+ ],
166
+ "anchor_points": [
167
+ [float(x), float(y)] for (x, y) in self._env.paper.anchor_points()
168
+ ],
169
+ }
170
+
171
+ def _error_observation(self, message: str) -> OrigamiObservation:
172
+ return OrigamiObservation(
173
+ done=False,
174
+ reward=-0.1,
175
+ metadata={"error": True},
176
+ prompt="",
177
+ target_name=self._env.target_name if self._env else None,
178
+ step=self._env.step_count if self._env else 0,
179
+ paper_state=self._paper_state_snapshot(),
180
+ info=self._env._info() if self._env else {},
181
+ reward_components={"format": 0.0, "total": -0.1, "error": message},
182
+ error=message,
183
+ )
openenv_runtime/models.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Literal, Optional
4
+
5
+ from pydantic import BaseModel, Field, field_validator
6
+
7
+ from openenv.core.env_server.types import Action, Observation, State
8
+
9
+
10
+ class OrigamiFold(BaseModel):
11
+ """Single fold action payload for step-level execution."""
12
+
13
+ from_point: list[float] = Field(..., description="Fold line start [x, y]")
14
+ to_point: list[float] = Field(..., description="Fold line end [x, y]")
15
+ assignment: Literal["M", "V"] = Field(..., description="Mountain or valley")
16
+ instruction: str = Field(default="", description="Optional natural language instruction")
17
+
18
+ @field_validator("from_point", "to_point")
19
+ @classmethod
20
+ def _validate_point(cls, point: list[float]) -> list[float]:
21
+ if len(point) != 2:
22
+ raise ValueError("Point must contain exactly 2 coordinates")
23
+ return [float(point[0]), float(point[1])]
24
+
25
+
26
+ class OrigamiAction(Action):
27
+ """
28
+ OpenEnv action for Optigami.
29
+
30
+ Modes:
31
+ - single: execute one fold (pass `fold` or JSON `completion` for a single-fold object)
32
+ - sequence: execute a full <folds>[...]</folds> completion in one step
33
+ """
34
+
35
+ mode: Literal["single", "sequence"] = Field(default="single")
36
+ fold: Optional[OrigamiFold] = Field(default=None)
37
+ completion: Optional[str] = Field(default=None)
38
+ target_name: Optional[str] = Field(
39
+ default=None,
40
+ description="Optional target override; reset to this target before stepping",
41
+ )
42
+
43
+
44
+ class OrigamiObservation(Observation):
45
+ """OpenEnv observation payload returned by Optigami."""
46
+
47
+ prompt: str = Field(default="")
48
+ target_name: Optional[str] = Field(default=None)
49
+ step: int = Field(default=0)
50
+ paper_state: dict[str, Any] = Field(default_factory=dict)
51
+ info: dict[str, Any] = Field(default_factory=dict)
52
+ reward_components: dict[str, float | int | str] = Field(default_factory=dict)
53
+ error: Optional[str] = Field(default=None)
54
+
55
+
56
+ class OrigamiState(State):
57
+ """OpenEnv state payload for Optigami."""
58
+
59
+ mode: str = Field(default="step")
60
+ target_name: Optional[str] = Field(default=None)
61
+ paper: dict[str, Any] = Field(default_factory=dict)
62
+ last_reward: dict[str, Any] = Field(default_factory=dict)
63
+ available_targets: list[str] = Field(default_factory=list)
openenv_server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """OpenEnv FastAPI app package."""
openenv_server/app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from openenv.core.env_server.http_server import create_app
4
+
5
+ from openenv_runtime.environment import OpenEnvOrigamiEnvironment
6
+ from openenv_runtime.models import OrigamiAction, OrigamiObservation
7
+
8
+
9
+ app = create_app(
10
+ env=lambda: OpenEnvOrigamiEnvironment(),
11
+ action_cls=OrigamiAction,
12
+ observation_cls=OrigamiObservation,
13
+ env_name="optigami",
14
+ )