# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Slipstream Governance Environment Client."""

from __future__ import annotations

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

try:
    from .models import SlipstreamAction, SlipstreamObservation, SlipstreamState
except ImportError:  # pragma: no cover
    # Fall back to a top-level ``models`` import when the relative import
    # fails (e.g. this module is loaded outside of its package).
    from models import SlipstreamAction, SlipstreamObservation, SlipstreamState


class SlipstreamGovEnv(EnvClient[SlipstreamAction, SlipstreamObservation, SlipstreamState]):
    """Client for the SlipstreamGov OpenEnv environment.

    Implements the ``EnvClient`` payload hooks. It serializes a
    ``SlipstreamAction`` into a request dict and converts response dicts into
    ``StepResult[SlipstreamObservation]`` and ``SlipstreamState`` objects.
    """

    def _step_payload(self, action: SlipstreamAction) -> Dict[str, Any]:
        """Build the step request body.

        Args:
            action: The action to send. Only its ``message`` field is
                transmitted.

        Returns:
            A dict of the form ``{"message": action.message}``.
        """
        return {"message": action.message}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SlipstreamObservation]:
        """Convert a step response dict into a typed ``StepResult``.

        Args:
            payload: Response dict. Reads the optional keys ``observation``
                (a dict), ``reward`` and ``done``.

        Returns:
            A ``StepResult`` whose observation, reward and done flag all come
            from ``payload``. Missing or ``null`` fields fall back to neutral
            defaults (``0.0``, ``False``, empty list/dict).
        """
        # ``x.get(k) or default`` (rather than ``x.get(k, default)``) also
        # covers keys that are present but explicitly null in the payload.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        done = bool(payload.get("done") or False)

        observation = SlipstreamObservation(
            task_prompt=obs_data.get("task_prompt"),
            parsed_slip=obs_data.get("parsed_slip"),
            expected_anchor=obs_data.get("expected_anchor"),
            predicted_anchor=obs_data.get("predicted_anchor"),
            arg_overlap=obs_data.get("arg_overlap") or 0.0,
            violations=obs_data.get("violations") or [],
            metrics=obs_data.get("metrics") or {},
            # done/reward are top-level payload fields, mirrored onto the
            # observation as well as onto the StepResult.
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> SlipstreamState:
        """Convert a state response dict into a ``SlipstreamState``.

        Args:
            payload: State dict. Reads ``episode_id``, ``step_count``,
                ``scenario_id`` and ``attack``.

        Returns:
            A ``SlipstreamState``. A missing or ``null`` ``step_count`` becomes
            ``0`` and a missing or ``null`` ``attack`` becomes ``False``.
        """
        return SlipstreamState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count") or 0,
            scenario_id=payload.get("scenario_id"),
            attack=bool(payload.get("attack") or False),
        )