File size: 1,495 Bytes
4058302
 
 
 
 
 
 
 
 
 
 
906a5a5
4058302
 
eb2d131
a574799
eb2d131
4058302
 
 
 
906a5a5
4058302
 
eb2d131
4058302
 
 
906a5a5
 
4058302
 
906a5a5
 
4058302
 
 
eb2d131
4058302
 
eb2d131
4058302
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""Typed client for the Incident Command Center environment.

Built on OpenEnv's generic `EnvClient` so it exposes the full gym-style API
(`reset`, `step`, `state`, `close`) plus the rich typed fields added by this
environment (reward breakdowns, investigation targets, playbook hints, etc.).
"""

from __future__ import annotations

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from models import IncidentAction, IncidentObservation, IncidentState


class IncidentCommandEnvClient(
    EnvClient[IncidentAction, IncidentObservation, IncidentState]
):
    """Client-side wrapper around the environment's HTTP contract.

    Converts typed ``IncidentAction`` objects into JSON-ready dicts for the
    server, and validates the server's JSON responses back into typed
    ``IncidentObservation`` / ``IncidentState`` models.
    """

    def _step_payload(self, action: IncidentAction) -> Dict[str, Any]:
        """Serialize *action* into the dict sent as the step request body.

        Fields whose value is ``None`` are omitted (``exclude_none=True``), so
        unset optional action parameters are not transmitted.
        """
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Build a ``StepResult`` from a decoded server response.

        Args:
            payload: Decoded JSON response. The ``"observation"`` entry is
                validated into an ``IncidentObservation``; a missing or null
                entry is treated as an empty dict. ``"reward"`` and ``"done"``
                are read from the top level of the payload.

        Returns:
            A ``StepResult`` whose ``reward`` is always a float (``0.0`` when
            the server omits it or sends ``null``) and whose ``done`` is a
            bool (``False`` when omitted).
        """
        obs_data: Dict[str, Any] = payload.get("observation") or {}
        observation = IncidentObservation.model_validate(obs_data)
        # The .get() default only covers a *missing* key; an explicit null
        # reward (presumably sent on reset responses -- confirm against the
        # server) would make float(None) raise TypeError. Treat null like
        # a missing key.
        raw_reward = payload.get("reward")
        reward = 0.0 if raw_reward is None else float(raw_reward)
        return StepResult(
            observation=observation,
            reward=reward,
            done=bool(payload.get("done", False)),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> IncidentState:
        """Validate the server's state payload into an ``IncidentState``."""
        return IncidentState.model_validate(payload)


# Legacy name retained so code written against the round-1 client still imports.
SREEnvClient = IncidentCommandEnvClient

__all__ = [
    "IncidentCommandEnvClient",
    "SREEnvClient",
]