Spaces:
Sleeping
Sleeping
| """Typed OpenEnv client for MiniGridEnv.""" | |
| from __future__ import annotations | |
| from typing import Any, Dict | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| try: | |
| from .env.models import MiniGridAction, MiniGridObservation, MiniGridState | |
| except ImportError: | |
| from env.models import MiniGridAction, MiniGridObservation, MiniGridState | |
class MiniGridEnvClient(EnvClient[MiniGridAction, MiniGridObservation, MiniGridState]):
    """WebSocket client for interacting with a MiniGridEnv server."""

    @staticmethod
    def _unwrap_section(payload: Any, key: str) -> Dict[str, Any]:
        """Return the dict stored under ``key`` in a server payload.

        Handles both wire shapes the parsers accept:

        * wrapped: ``{key: {...}, ...}`` -> the nested dict is returned;
        * flat: the fields live directly in ``payload`` -> ``payload`` is
          returned unchanged.

        A non-dict ``payload`` yields an empty dict, so every field falls
        back to its default instead of raising ``AttributeError``.
        """
        if not isinstance(payload, dict):
            return {}
        nested = payload.get(key)
        return nested if isinstance(nested, dict) else payload

    def _step_payload(self, action: MiniGridAction) -> Dict[str, Any]:
        """Serialize ``action`` into the message body sent for a step.

        ``command`` is always included; ``thought`` is included only when it
        is truthy (empty or missing thoughts are omitted).
        """
        payload: Dict[str, Any] = {"command": action.command}
        if action.thought:
            payload["thought"] = action.thought
        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[MiniGridObservation]:
        """Convert a server response into a ``StepResult``.

        The observation fields are read from ``payload["observation"]`` when
        that is a dict, otherwise from ``payload`` itself. Top-level ``done``
        and ``reward`` take precedence over the same keys inside the
        observation. Missing fields fall back to the defaults below.

        Args:
            payload: Decoded message from the server; a non-dict value is
                treated as empty.

        Returns:
            A ``StepResult`` whose observation, reward and done flag agree.
        """
        # Normalize before any ``.get`` call so a non-dict payload produces
        # defaults rather than an AttributeError.
        top: Dict[str, Any] = payload if isinstance(payload, dict) else {}
        obs_data = self._unwrap_section(top, "observation")

        done = bool(top.get("done", obs_data.get("done", False)))
        reward = top.get("reward", obs_data.get("reward"))

        observation = MiniGridObservation(
            text=obs_data.get("text", ""),
            mission=obs_data.get("mission", ""),
            step_idx=obs_data.get("step_idx", 0),
            steps_remaining=obs_data.get("steps_remaining", 0),
            max_steps=obs_data.get("max_steps", 1),
            history=obs_data.get("history", []),
            level_name=obs_data.get("level_name", ""),
            last_action=obs_data.get("last_action"),
            action_success=obs_data.get("action_success"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> MiniGridState:
        """Convert a server state response into a ``MiniGridState``.

        Fields are read from ``payload["state"]`` when that is a dict,
        otherwise from ``payload`` itself; a non-dict payload is treated as
        empty. Missing fields fall back to the defaults below.
        """
        state_data = self._unwrap_section(payload, "state")
        return MiniGridState(
            episode_id=state_data.get("episode_id"),
            step_count=state_data.get("step_count", 0),
            level_name=state_data.get("level_name", ""),
            level_difficulty=state_data.get("level_difficulty", 0),
            completed=state_data.get("completed", False),
            truncated=state_data.get("truncated", False),
            total_reward=state_data.get("total_reward", 0.0),
            steps_taken=state_data.get("steps_taken", 0),
            optimal_steps=state_data.get("optimal_steps"),
            efficiency_ratio=state_data.get("efficiency_ratio"),
            valid_actions=state_data.get("valid_actions", 0),
            invalid_actions=state_data.get("invalid_actions", 0),
            action_distribution=state_data.get("action_distribution", {}),
        )
# Alias: exposes the same client class under the shorter name ``MiniGridEnv``.
MiniGridEnv = MiniGridEnvClient