File size: 2,965 Bytes
a03a89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Typed OpenEnv client for MiniGridEnv."""

from __future__ import annotations

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

try:
    from .env.models import MiniGridAction, MiniGridObservation, MiniGridState
except ImportError:
    from env.models import MiniGridAction, MiniGridObservation, MiniGridState


class MiniGridEnvClient(EnvClient[MiniGridAction, MiniGridObservation, MiniGridState]):
    """WebSocket client for interacting with a MiniGridEnv server.

    Converts ``MiniGridAction`` objects into JSON-ready dicts for the server, and
    converts server responses back into ``MiniGridObservation`` /
    ``MiniGridState`` models. Missing fields in a response fall back to the
    defaults used in the parse methods below.
    """

    @staticmethod
    def _unwrap_section(payload: Any, key: str) -> Dict[str, Any]:
        """Return the dict stored under ``key``, or ``payload`` itself if flat.

        Responses are accepted either wrapped (``{key: {...}}``) or flat
        (the fields at the top level). A non-dict ``payload`` yields ``{}`` so
        callers can always use ``.get`` on the result.
        """
        if not isinstance(payload, dict):
            return {}
        section = payload.get(key)
        return section if isinstance(section, dict) else payload

    def _step_payload(self, action: MiniGridAction) -> Dict[str, Any]:
        """Serialize ``action`` into the body sent with a step request.

        ``command`` is always included; ``thought`` is only added when it is
        truthy (empty or missing thoughts are omitted).
        """
        payload: Dict[str, Any] = {"command": action.command}
        if action.thought:
            payload["thought"] = action.thought
        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[MiniGridObservation]:
        """Build a ``StepResult`` from a step/reset response.

        Args:
            payload: Server response, either ``{"observation": {...}, ...}`` or
                a flat dict of observation fields.

        Returns:
            A ``StepResult`` whose ``reward``/``done`` mirror the observation's.
            A top-level ``done``/``reward`` key takes precedence over the same
            key inside the observation dict.
        """
        # Normalise first so a non-dict payload cannot raise on ``.get`` below.
        root: Dict[str, Any] = payload if isinstance(payload, dict) else {}
        obs_data = self._unwrap_section(payload, "observation")
        done = bool(root.get("done", obs_data.get("done", False)))
        reward = root.get("reward", obs_data.get("reward"))
        observation = MiniGridObservation(
            text=obs_data.get("text", ""),
            mission=obs_data.get("mission", ""),
            step_idx=obs_data.get("step_idx", 0),
            steps_remaining=obs_data.get("steps_remaining", 0),
            max_steps=obs_data.get("max_steps", 1),
            history=obs_data.get("history", []),
            level_name=obs_data.get("level_name", ""),
            last_action=obs_data.get("last_action"),
            action_success=obs_data.get("action_success"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> MiniGridState:
        """Build a ``MiniGridState`` from a state response.

        Args:
            payload: Server response, either ``{"state": {...}}`` or a flat
                dict of state fields. A non-dict payload yields a state made
                entirely of the defaults below.
        """
        state_data = self._unwrap_section(payload, "state")
        return MiniGridState(
            episode_id=state_data.get("episode_id"),
            step_count=state_data.get("step_count", 0),
            level_name=state_data.get("level_name", ""),
            level_difficulty=state_data.get("level_difficulty", 0),
            completed=state_data.get("completed", False),
            truncated=state_data.get("truncated", False),
            total_reward=state_data.get("total_reward", 0.0),
            steps_taken=state_data.get("steps_taken", 0),
            optimal_steps=state_data.get("optimal_steps"),
            efficiency_ratio=state_data.get("efficiency_ratio"),
            valid_actions=state_data.get("valid_actions", 0),
            invalid_actions=state_data.get("invalid_actions", 0),
            action_distribution=state_data.get("action_distribution", {}),
        )


# Shorter alias for MiniGridEnvClient; presumably kept so callers can import
# the client as ``MiniGridEnv`` — verify against downstream usages before removing.
MiniGridEnv = MiniGridEnvClient