File size: 3,282 Bytes
9eb0831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
SRE OpenEnv Client.

Provides the SREEnv client class for interacting with a running
SRE environment instance over HTTP.

This implementation uses 'requests' directly to maintain compatibility
with synchronous agent loops while aligning with OpenEnv 0.1 types.
"""

from __future__ import annotations

import requests
from typing import Any, Optional
from dataclasses import asdict, is_dataclass

try:
    from openenv.core.client_types import StepResult
except ImportError:
    # Handle legacy openenv_core or shim
    try:
        from openenv_core.client_types import StepResult
    except ImportError:
        # Fallback for older versions if necessary
        from openenv_core import StepResult

from models import SREAction, SREObservation, SREState


class SREEnv:
    """
    Synchronous client for the Autonomous SRE OpenEnv environment.

    Compatible with OpenEnv 0.1+ server endpoints (/reset, /step, /state).
    All requests go through a single ``requests.Session``; release it with
    :meth:`close`, or use the client as a context manager::

        with SREEnv("http://localhost:8000") as env:
            result = env.reset(seed=0)

    Args:
        base_url: Root URL of the environment server. A trailing slash is
            stripped so endpoint paths can be appended directly.
        timeout: Per-request timeout in seconds, forwarded to ``requests``.
            ``None`` disables it (a stalled server then blocks forever).
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: float | None = 60.0,
    ):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    def __enter__(self) -> SREEnv:
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the underlying session on leaving the ``with`` block."""
        self.close()

    def reset(
        self, seed: int | None = None, task_id: str | None = None
    ) -> StepResult[SREObservation]:
        """Reset the environment via POST /reset.

        Args:
            seed: Optional seed; only sent when not ``None``.
            task_id: Optional task identifier; only sent when not ``None``.

        Returns:
            The parsed initial ``StepResult``.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        payload: dict[str, Any] = {}
        if seed is not None:
            payload["seed"] = seed
        if task_id is not None:
            payload["task_id"] = task_id
        return self._parse_result(self._post("/reset", payload))

    def step(self, action: SREAction) -> StepResult[SREObservation]:
        """Execute an action via POST /step.

        The action is serialized with :meth:`_step_payload` and wrapped as
        ``{"action": ...}``.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        payload = {"action": self._step_payload(action)}
        return self._parse_result(self._post("/step", payload))

    @property
    def state(self) -> SREState:
        """Retrieve current environment state via GET /state.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = self._session.get(
            f"{self.base_url}/state", timeout=self.timeout
        )
        response.raise_for_status()
        return self._parse_state(response.json())

    def _post(self, path: str, payload: dict) -> dict:
        """POST *payload* as JSON to *path*, raise on HTTP errors, return JSON."""
        response = self._session.post(
            f"{self.base_url}{path}", json=payload, timeout=self.timeout
        )
        response.raise_for_status()
        return response.json()

    def _step_payload(self, action: Any) -> dict:
        """Serialize an action to a JSON-compatible dict.

        Tries, in order: dataclass ``asdict``, Pydantic v2 ``model_dump``,
        Pydantic v1 ``dict``, and finally the instance ``__dict__``.
        """
        if is_dataclass(action):
            return asdict(action)
        if hasattr(action, "model_dump"):
            return action.model_dump()
        if hasattr(action, "dict"):
            return action.dict()
        return vars(action)

    def _parse_result(self, payload: dict) -> StepResult[SREObservation]:
        """Parse the server response into a typed StepResult.

        Expected shape: ``{"observation": {...}, "reward": float, "done": bool}``.
        Missing ``reward``/``done`` default to ``0.0``/``False``.
        """
        # ``or {}`` also covers an explicit ``"observation": null``; a plain
        # ``.get(..., {})`` would return None and break the ** unpack.
        obs_data = payload.get("observation") or {}
        return StepResult(
            observation=SREObservation(**obs_data),
            reward=payload.get("reward", 0.0),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: dict) -> SREState:
        """Parse the server state response into a typed SREState."""
        return SREState(**payload)

    def close(self):
        """Close the underlying HTTP session."""
        self._session.close()