"""Typed OpenEnv client for MiniGridEnv.""" from __future__ import annotations from typing import Any, Dict from openenv.core.client_types import StepResult from openenv.core.env_client import EnvClient try: from .env.models import MiniGridAction, MiniGridObservation, MiniGridState except ImportError: from env.models import MiniGridAction, MiniGridObservation, MiniGridState class MiniGridEnvClient(EnvClient[MiniGridAction, MiniGridObservation, MiniGridState]): """WebSocket client for interacting with a MiniGridEnv server.""" def _step_payload(self, action: MiniGridAction) -> Dict[str, Any]: payload: Dict[str, Any] = {"command": action.command} if action.thought: payload["thought"] = action.thought return payload def _parse_result(self, payload: Dict[str, Any]) -> StepResult[MiniGridObservation]: obs_data = payload.get("observation") if not isinstance(obs_data, dict): obs_data = payload if isinstance(payload, dict) else {} done = bool(payload.get("done", obs_data.get("done", False))) reward = payload.get("reward", obs_data.get("reward")) observation = MiniGridObservation( text=obs_data.get("text", ""), mission=obs_data.get("mission", ""), step_idx=obs_data.get("step_idx", 0), steps_remaining=obs_data.get("steps_remaining", 0), max_steps=obs_data.get("max_steps", 1), history=obs_data.get("history", []), level_name=obs_data.get("level_name", ""), last_action=obs_data.get("last_action"), action_success=obs_data.get("action_success"), done=done, reward=reward, metadata=obs_data.get("metadata", {}), ) return StepResult(observation=observation, reward=reward, done=done) def _parse_state(self, payload: Dict[str, Any]) -> MiniGridState: state_data = payload.get("state") if not isinstance(state_data, dict): state_data = payload if isinstance(payload, dict) else {} return MiniGridState( episode_id=state_data.get("episode_id"), step_count=state_data.get("step_count", 0), level_name=state_data.get("level_name", ""), level_difficulty=state_data.get("level_difficulty", 0), completed=state_data.get("completed", False), truncated=state_data.get("truncated", False), total_reward=state_data.get("total_reward", 0.0), steps_taken=state_data.get("steps_taken", 0), optimal_steps=state_data.get("optimal_steps"), efficiency_ratio=state_data.get("efficiency_ratio"), valid_actions=state_data.get("valid_actions", 0), invalid_actions=state_data.get("invalid_actions", 0), action_distribution=state_data.get("action_distribution", {}), ) MiniGridEnv = MiniGridEnvClient