File size: 7,979 Bytes
952f360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Doom Environment HTTP Client.

This module provides the client for connecting to a Doom Environment server
over HTTP.
"""

from typing import Any, Dict, Optional

import numpy as np

from openenv_core.client_types import StepResult
from openenv_core.env_server.types import State
from openenv_core.http_env_client import HTTPEnvClient

from .models import DoomAction, DoomObservation


class DoomEnv(HTTPEnvClient[DoomAction, DoomObservation]):
    """
    HTTP client for the Doom Environment.

    This client connects to a DoomEnvironment HTTP server and provides
    methods to interact with it: reset(), step(), and state access.

    The Doom environment wraps ViZDoom scenarios for visual RL research.

    Example:
        >>> # Connect to a running server
        >>> client = DoomEnv(base_url="http://localhost:8000")
        >>> result = client.reset()
        >>> print(result.observation.screen_shape)
        >>>
        >>> # Take an action
        >>> result = client.step(DoomAction(action_id=2))
        >>> print(result.observation.reward, result.observation.done)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = DoomEnv.from_docker_image("doom-env:latest")
        >>> result = client.reset()
        >>> result = client.step(DoomAction(action_id=0))
        >>> client.close()

    Example with rendering:
        >>> client = DoomEnv.from_docker_image("doom-env:latest")
        >>> result = client.reset()
        >>> for _ in range(100):
        ...     result = client.step(DoomAction(action_id=1))
        ...     client.render()  # Display the game
        >>> client.close()
    """

    # Title of the OpenCV window / matplotlib figure created by render().
    _WINDOW_TITLE = "ViZDoom - Doom Environment"

    def __init__(self, *args, **kwargs):
        """
        Initialize DoomEnv client.

        All positional and keyword arguments are forwarded unchanged to the
        parent ``HTTPEnvClient`` constructor.
        """
        super().__init__(*args, **kwargs)
        # OpenCV window name (str) or matplotlib Figure, created lazily by render().
        self._render_window: Any = None
        # Most recent observation parsed by _parse_result(); read by render().
        self._last_observation: Optional[DoomObservation] = None

    def _step_payload(self, action: DoomAction) -> Dict[str, Any]:
        """
        Convert a DoomAction into the JSON payload for a step request.

        Fields whose value is ``None`` are omitted. NumPy scalars and arrays,
        including ones nested inside lists, tuples or dicts, are converted to
        native Python values so the payload is JSON-serializable.

        Args:
            action: DoomAction instance (must be a dataclass, since it is
                passed to ``dataclasses.asdict``).

        Returns:
            Dictionary representation suitable for JSON encoding.
        """
        from dataclasses import asdict

        return {
            key: self._to_json_native(value)
            for key, value in asdict(action).items()
            if value is not None
        }

    @staticmethod
    def _to_json_native(value: Any) -> Any:
        """Recursively convert NumPy values inside ``value`` to native Python types."""
        if isinstance(value, np.ndarray):
            # ndarray.item() raises for arrays with more than one element;
            # tolist() handles every shape and yields Python scalars.
            return value.tolist()
        if isinstance(value, (list, tuple)):
            return [DoomEnv._to_json_native(item) for item in value]
        if isinstance(value, dict):
            return {key: DoomEnv._to_json_native(item) for key, item in value.items()}
        if hasattr(value, "item"):  # NumPy (or similar) scalar types
            return value.item()
        return value

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[DoomObservation]:
        """
        Parse server response into StepResult[DoomObservation].

        The parsed observation is also cached so render() can display it.

        Args:
            payload: JSON response from server. The observation fields are read
                from ``payload["observation"]``; ``reward`` and ``done`` are
                read from the top level.

        Returns:
            StepResult with DoomObservation
        """
        # ``or {}`` also tolerates an explicit ``"observation": null``.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = DoomObservation(
            screen_buffer=obs_data.get("screen_buffer", []),
            screen_shape=obs_data.get("screen_shape", [120, 160, 3]),
            game_variables=obs_data.get("game_variables"),
            available_actions=obs_data.get("available_actions"),
            episode_finished=obs_data.get("episode_finished", False),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )

        # Store for rendering
        self._last_observation = observation

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from /state endpoint

        Returns:
            State object with episode_id and step_count (``step_count``
            defaults to 0 when absent).
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Render the most recent observation.

        Args:
            mode: "human" to display the frame in a window (OpenCV if
                available, otherwise matplotlib), or "rgb_array" to return it.

        Returns:
            The frame as a ``uint8`` array reshaped to ``screen_shape`` when
            mode is "rgb_array"; None otherwise, and also None when there is
            no observation yet or its screen buffer/shape is empty.

        Raises:
            ValueError: If ``mode`` is not "human" or "rgb_array", or if the
                screen buffer size does not match ``screen_shape``.
        """
        # Validate the argument up front so a bad mode is reported even before
        # the first observation arrives.
        if mode not in ("human", "rgb_array"):
            raise ValueError(
                f"Invalid render mode: {mode}. Use 'human' or 'rgb_array'."
            )

        if self._last_observation is None:
            print("Warning: No observation to render. Call reset() or step() first.")
            return None

        screen_buffer = self._last_observation.screen_buffer
        screen_shape = self._last_observation.screen_shape

        # len() instead of truthiness: bool() of a multi-element array raises.
        if (
            screen_buffer is None
            or screen_shape is None
            or len(screen_buffer) == 0
            or len(screen_shape) == 0
        ):
            return None

        # np.array (not asarray) so the returned frame never aliases the
        # observation's own buffer.
        screen = np.array(screen_buffer, dtype=np.uint8).reshape(screen_shape)

        if mode == "rgb_array":
            return screen

        try:
            self._render_cv2(screen)
        except ImportError:
            try:
                self._render_matplotlib(screen)
            except ImportError:
                print(
                    "Warning: Neither cv2 nor matplotlib available for rendering. "
                    "Install with: pip install opencv-python or pip install matplotlib"
                )
        return None

    def _render_cv2(self, screen: np.ndarray) -> None:
        """Show ``screen`` in an OpenCV window; raises ImportError if cv2 is missing."""
        import cv2

        if self._render_window is None:
            self._render_window = self._WINDOW_TITLE
            cv2.namedWindow(self._render_window, cv2.WINDOW_NORMAL)

        # OpenCV expects BGR channel order; convert 3-channel frames from RGB.
        if screen.ndim == 3 and screen.shape[2] == 3:
            frame = cv2.cvtColor(screen, cv2.COLOR_RGB2BGR)
        else:
            frame = screen

        cv2.imshow(self._render_window, frame)
        cv2.waitKey(1)  # Let the GUI event loop process the redraw.

    def _render_matplotlib(self, screen: np.ndarray) -> None:
        """Show ``screen`` in a matplotlib figure; raises ImportError if unavailable."""
        import matplotlib.pyplot as plt

        fig = self._render_window
        if fig is None or not plt.fignum_exists(fig.number):
            # First call, or the user closed our figure: (re)create it.
            plt.ion()
            fig = plt.figure(figsize=(8, 6))
            fig.canvas.manager.set_window_title(self._WINDOW_TITLE)
            self._render_window = fig
        else:
            # Make our figure current so the pyplot calls below (and pause)
            # draw into it rather than into some other figure the user owns.
            plt.figure(fig.number)

        plt.clf()
        if screen.ndim == 3:
            plt.imshow(screen)
        else:
            plt.imshow(screen, cmap="gray")
        plt.axis("off")
        plt.pause(0.001)

    def _close_render_window(self) -> None:
        """Best-effort close of the window/figure created by render(), if any."""
        window = self._render_window
        if window is None:
            return
        self._render_window = None

        if isinstance(window, str):
            # OpenCV backend: the window is identified by its name.
            try:
                import cv2
            except ImportError:
                return
            try:
                cv2.destroyWindow(window)
            except cv2.error:
                pass  # Window may already have been closed by the user.
        else:
            # matplotlib backend: close only our figure, not all of the user's.
            try:
                import matplotlib.pyplot as plt
            except ImportError:
                return
            plt.close(window)

    def close(self) -> None:
        """Close the environment and clean up resources.

        The render window (if any) is closed first; the parent ``close()`` is
        always called, even if closing the window fails.
        """
        try:
            self._close_render_window()
        finally:
            super().close()