File size: 2,188 Bytes
453233a
f707fd4
453233a
 
 
 
 
 
 
 
 
 
 
 
 
 
f707fd4
453233a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Typed OpenEnv client for Red Button.

Mirrors :class:`envs.coding_env.client.CodingEnv` from the upstream OpenEnv
reference repo (PROJECT.md Section 4). The three abstract hooks expected by
``EnvClient`` are:

* ``_step_payload(action)`` — JSON payload for the ``/step`` endpoint. We use
  ``action.model_dump()`` (rather than coding_env's hand-written shape) because
  ``ShutdownAction`` carries a nested ``arguments: dict``.
* ``_parse_result(payload)`` — convert the server response into a
  ``StepResult[ShutdownObservation]``. The ``observation`` blob carries
  ``done``/``reward`` natively (PROJECT.md Section 11), but the framework
  also surfaces them at the top level of the WebSocket envelope; we honour
  whichever is present.
* ``_parse_state(payload)`` — full Pydantic ``ShutdownState`` reconstruction
  via ``model_validate`` (Section 11.6 + 11.3).
"""

from __future__ import annotations

from typing import Any

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from red_button.models import ShutdownAction, ShutdownObservation, ShutdownState


class ShutdownGymClient(EnvClient[ShutdownAction, ShutdownObservation, ShutdownState]):
    """Typed client wrapping the Red Button OpenEnv HTTP/WebSocket server.

    Supplies the three ``EnvClient`` hooks: building the step payload from
    an action, parsing a step response into a ``StepResult``, and parsing a
    state response into a ``ShutdownState``.
    """

    def _step_payload(self, action: ShutdownAction) -> dict[str, Any]:
        """Return the request payload for *action*.

        Args:
            action: The action to send to the server.

        Returns:
            The result of ``action.model_dump()``.
        """
        # ``model_dump`` is used instead of a hand-written dict so the nested
        # ``arguments`` mapping on the action is carried through as-is.
        return action.model_dump()

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[ShutdownObservation]:
        """Convert a server response into a ``StepResult``.

        The observation is read from ``payload["observation"]`` when that key
        exists; otherwise the whole payload is treated as the observation.
        ``reward`` and ``done`` are taken from the top level of the payload
        when it carries a non-null value, and from the parsed observation
        otherwise.

        Args:
            payload: Decoded response body from the server.

        Returns:
            A ``StepResult`` holding the validated observation, the resolved
            reward, and ``done`` coerced to ``bool``.

        Raises:
            Whatever ``ShutdownObservation.model_validate`` raises when the
            observation data does not match the model.
        """
        obs_data = payload.get("observation", payload)
        obs = ShutdownObservation.model_validate(obs_data)

        # A top-level key that is present but null must not mask the value
        # carried by the observation, so fall back on ``None`` as well as on
        # a missing key.
        reward = payload.get("reward")
        if reward is None:
            reward = obs.reward

        done = payload.get("done")
        if done is None:
            done = obs.done

        return StepResult(observation=obs, reward=reward, done=bool(done))

    def _parse_state(self, payload: dict[str, Any]) -> ShutdownState:
        """Rebuild a ``ShutdownState`` from the server's state response.

        Args:
            payload: Decoded state response body.

        Returns:
            The ``ShutdownState`` produced by ``model_validate(payload)``.
        """
        return ShutdownState.model_validate(payload)