| | |
| | |
| | |
| | |
| | |
| |
|
| | """Slipstream Governance Environment Client.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import Dict |
| |
|
| | try: |
| | from openenv.core.client_types import StepResult |
| | from openenv.core.env_client import EnvClient |
| | from .models import SlipstreamAction, SlipstreamObservation, SlipstreamState |
| | except ImportError: |
| | from openenv.core.client_types import StepResult |
| | from openenv.core.env_client import EnvClient |
| | from models import SlipstreamAction, SlipstreamObservation, SlipstreamState |
| |
|
| |
|
class SlipstreamGovEnv(EnvClient[SlipstreamAction, SlipstreamObservation, SlipstreamState]):
    """Client for the SlipstreamGov OpenEnv environment.

    Converts ``SlipstreamAction`` objects into request dicts, and response
    dicts back into ``SlipstreamObservation`` / ``SlipstreamState`` objects.
    Missing keys and explicit ``None`` (JSON ``null``) values are treated the
    same way: both fall back to the field's empty/zero default.
    """

    def _step_payload(self, action: SlipstreamAction) -> Dict:
        """Serialize ``action`` into the dict sent for a step.

        Only the action's ``message`` attribute is transmitted.
        """
        return {"message": action.message}

    def _parse_result(self, payload: Dict) -> StepResult[SlipstreamObservation]:
        """Build a ``StepResult`` from a step response payload.

        Args:
            payload: Response dict. Observation fields are read from its
                ``"observation"`` entry; ``"reward"`` and ``"done"`` are read
                from the top level.

        Returns:
            A ``StepResult`` wrapping a ``SlipstreamObservation``. ``reward``
            is passed through unchanged (may be ``None``); ``done`` is always
            a bool; ``arg_overlap`` defaults to ``0.0``; ``violations``,
            ``metrics`` and ``metadata`` default to empty containers.
        """
        obs_data = payload.get("observation") or {}

        # Read once so the observation and the StepResult always agree.
        reward = payload.get("reward")
        # An explicit null "done" must not propagate as None.
        done = bool(payload.get("done"))

        observation = SlipstreamObservation(
            task_prompt=obs_data.get("task_prompt"),
            parsed_slip=obs_data.get("parsed_slip"),
            expected_anchor=obs_data.get("expected_anchor"),
            predicted_anchor=obs_data.get("predicted_anchor"),
            # `or` (rather than a .get default) also maps an explicit null
            # to the default, matching the container fields below.
            arg_overlap=obs_data.get("arg_overlap") or 0.0,
            violations=obs_data.get("violations") or [],
            metrics=obs_data.get("metrics") or {},
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> SlipstreamState:
        """Build a ``SlipstreamState`` from a state response payload.

        ``episode_id`` and ``scenario_id`` are passed through as-is (may be
        ``None``); ``step_count`` defaults to ``0`` and ``attack`` to
        ``False`` when missing or null.
        """
        return SlipstreamState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count") or 0,
            scenario_id=payload.get("scenario_id"),
            attack=bool(payload.get("attack")),
        )
| |
|