Spaces:
Runtime error
Runtime error
| """PRobe Environment Client.""" | |
| from __future__ import annotations | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_server.types import State | |
| from .models import ProbeAction, ProbeObservation | |
class ProbeEnv(EnvClient[ProbeAction, ProbeObservation, State]):
    """
    Client for the PRobe environment.

    Maintains a persistent WebSocket connection to the server.

    Example::

        with ProbeEnv(base_url="http://localhost:8000") as env:
            result = env.reset()
            print(result.observation.task_description)

            action = ProbeAction(
                action_type="add_comment",
                line_number=4,
                comment="Off-by-one: range(len+1) causes IndexError",
                severity="error",
                category="bug",
            )
            result = env.step(action)
            print(result.reward)
    """

    def _step_payload(self, action: ProbeAction) -> dict:
        """Serialize *action* into the dict sent to the server for a step.

        ``action_type`` is always present. Every other field is emitted only
        when it is not ``None``, so unset fields are omitted from the payload
        rather than sent as nulls.

        Args:
            action: The action to serialize.

        Returns:
            A dict keyed by field name; fields read through ``.value`` are
            stored as that underlying value.
        """
        payload: dict = {"action_type": action.action_type.value}

        # Plain fields are forwarded as-is.
        if action.line_number is not None:
            payload["line_number"] = action.line_number
        if action.comment is not None:
            payload["comment"] = action.comment

        # These fields are unwrapped through ``.value`` before being sent.
        for field_name in ("severity", "category", "classification"):
            member = getattr(action, field_name)
            if member is not None:
                payload[field_name] = member.value

        return payload

    def _parse_result(
        self, payload: dict
    ) -> StepResult[ProbeObservation]:
        """Build a ``StepResult`` from a server response payload.

        Args:
            payload: Response dict; the keys read are ``"observation"``,
                ``"reward"`` and ``"done"``.

        Returns:
            A ``StepResult`` holding the validated observation, the reward as
            a float (``0.0`` when missing or null) and the done flag
            (``False`` when missing).
        """
        # ``or {}`` treats an explicit null observation like a missing key,
        # the same way the reward below is guarded.
        obs_data: dict = payload.get("observation") or {}
        # Use model_validate so new fields added to ProbeObservation
        # are picked up automatically without changing this method.
        observation = ProbeObservation.model_validate(obs_data)
        return StepResult(
            observation=observation,
            reward=float(payload.get("reward") or 0.0),
            done=bool(payload.get("done", False)),
        )

    def _parse_state(self, payload: dict) -> State:
        """Build a ``State`` from a server state payload.

        Args:
            payload: State dict; the keys read are ``"episode_id"`` and
                ``"step_count"``.

        Returns:
            A ``State`` whose ``episode_id`` is ``None`` when the key is
            absent and whose ``step_count`` is ``0`` when the key is absent
            or null.
        """
        return State(
            episode_id=payload.get("episode_id"),
            # ``or 0`` also covers an explicit null, not just a missing key.
            step_count=payload.get("step_count") or 0,
        )