"""WatchDog Environment — Client implementation for multi-turn oversight."""

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from openenv.core import EnvClient

from .models import MultiTurnAction, MultiTurnObservation, MultiTurnState


class WatchDogMultiTurnEnv(
    EnvClient[MultiTurnAction, MultiTurnObservation, MultiTurnState]
):
    """Client for the WatchDog multi-turn oversight environment.

    Example:
        >>> with WatchDogMultiTurnEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.current_turn)
        ...     result = client.step(MultiTurnAction(action_type="pass"))
        ...     print(result.observation.feedback)
    """

    def _step_payload(self, action: MultiTurnAction) -> Dict[str, Any]:
        """Serialize *action* into the dict sent to the server for a step.

        Args:
            action: The agent's action for the current turn.

        Returns:
            The output of ``action.model_dump()``.
        """
        return action.model_dump()

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[MultiTurnObservation]:
        """Convert a server step/reset response into a ``StepResult``.

        Observation fields are read from ``payload["observation"]``. Any field
        that is missing falls back to the default given below. ``done`` and
        ``reward`` are read from the top level of *payload* and copied onto
        both the observation and the returned ``StepResult``.

        Args:
            payload: Decoded response body from the server.

        Returns:
            A ``StepResult`` wrapping the parsed ``MultiTurnObservation``.
        """
        # `or {}` also covers an explicit null "observation"; a plain
        # `.get("observation", {})` would return None and the `.get` calls
        # below would raise AttributeError.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = MultiTurnObservation(
            conversation_so_far=obs_data.get("conversation_so_far", ""),
            current_turn=obs_data.get("current_turn", ""),
            current_turn_number=obs_data.get("current_turn_number", 0),
            total_turns=obs_data.get("total_turns", 0),
            task_domain=obs_data.get("task_domain", "general"),
            task_id=obs_data.get("task_id", ""),
            difficulty=obs_data.get("difficulty", 1),
            remaining_questions=obs_data.get("remaining_questions", 0),
            flags_so_far=obs_data.get("flags_so_far", 0),
            phase=obs_data.get("phase", "observe"),
            step_reward=obs_data.get("step_reward"),
            cumulative_reward=obs_data.get("cumulative_reward"),
            feedback=obs_data.get("feedback"),
            done=done,
            reward=reward,
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> MultiTurnState:
        """Convert a server state response into a ``MultiTurnState``.

        Any key missing from *payload* falls back to the default given below.
        ``episode_id`` defaults to ``None``.

        Args:
            payload: Decoded state body from the server.

        Returns:
            The parsed ``MultiTurnState``.
        """
        return MultiTurnState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            current_level=payload.get("current_level", 1),
            total_episodes=payload.get("total_episodes", 0),
            errors_detected=payload.get("errors_detected", 0),
            errors_missed=payload.get("errors_missed", 0),
            false_flags=payload.get("false_flags", 0),
            correct_passes=payload.get("correct_passes", 0),
            questions_used=payload.get("questions_used", 0),
            interventions_correct=payload.get("interventions_correct", 0),
            interventions_wrong=payload.get("interventions_wrong", 0),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )