Spaces:
Sleeping
Sleeping
| """ | |
| SRE OpenEnv Client. | |
| Provides the SREEnv client class for interacting with a running | |
| SRE environment instance via HTTP. | |
| This implementation uses 'requests' directly to maintain compatibility | |
| with synchronous agent loops while aligning with OpenEnv 0.1 types. | |
| """ | |
| from __future__ import annotations | |
| import requests | |
| from typing import Any, Optional | |
| from dataclasses import asdict, is_dataclass | |
| try: | |
| from openenv.core.client_types import StepResult | |
| except ImportError: | |
| # Handle legacy openenv_core or shim | |
| try: | |
| from openenv_core.client_types import StepResult | |
| except ImportError: | |
| # Fallback for older versions if necessary | |
| from openenv_core import StepResult | |
| from models import SREAction, SREObservation, SREState | |
class SREEnv:
    """
    Synchronous client for the Autonomous SRE OpenEnv environment.

    Compatible with OpenEnv 0.1+ server endpoints (/reset, /step, /state).
    All calls share a single ``requests.Session``. The client can also be used
    as a context manager, which closes that session on exit::

        with SREEnv("http://localhost:8000", timeout=30) as env:
            result = env.reset(seed=0)
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: float | tuple[float, float] | None = None,
    ):
        """
        Args:
            base_url: Root URL of the environment server. A trailing "/" is
                stripped so endpoint paths can be appended directly.
            timeout: Forwarded as ``timeout=`` to every HTTP request (seconds,
                or a ``(connect, read)`` tuple). The default ``None`` keeps the
                previous behaviour of waiting indefinitely; synchronous agent
                loops should normally pass a finite value so an unresponsive
                server cannot block them forever.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    def __enter__(self) -> SREEnv:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the HTTP session. Returning None lets any exception
        # raised inside the `with` body propagate unchanged.
        self.close()

    def reset(
        self, seed: int | None = None, task_id: str | None = None
    ) -> StepResult[SREObservation]:
        """Reset the environment via POST /reset.

        Only arguments that are not None are included in the request body.
        Omitted fields are simply not sent.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        payload: dict[str, Any] = {}
        if seed is not None:
            payload["seed"] = seed
        if task_id is not None:
            payload["task_id"] = task_id
        return self._parse_result(self._request("POST", "/reset", payload))

    def step(self, action: SREAction) -> StepResult[SREObservation]:
        """Execute an action via POST /step.

        The action is serialized with `_step_payload` and sent as
        ``{"action": {...}}``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        payload = {"action": self._step_payload(action)}
        return self._parse_result(self._request("POST", "/step", payload))

    def state(self) -> SREState:
        """Retrieve the current environment state via GET /state.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        return self._parse_state(self._request("GET", "/state"))

    def _request(self, method: str, path: str, payload: dict | None = None) -> Any:
        """Send an HTTP request to ``base_url + path`` and return the decoded JSON body.

        ``payload`` (when not None) is sent as the JSON request body. The
        configured ``timeout`` is applied to every call.

        Raises:
            requests.HTTPError: if the response status is 4xx/5xx.
        """
        response = self._session.request(
            method,
            f"{self.base_url}{path}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _step_payload(self, action: Any) -> dict:
        """Serialize an action to a JSON-compatible dict.

        The following inputs are supported, checked in this order:
        - plain dicts, which are shallow-copied;
        - dataclasses, via `dataclasses.asdict`;
        - objects exposing `model_dump()` (Pydantic v2 style);
        - objects exposing `dict()` (Pydantic v1 style);
        - any other object, via its instance ``__dict__``.
        """
        # A plain dict has no `model_dump`/`dict` attribute, and `vars()`
        # rejects it, so handle it explicitly before the fallbacks.
        if isinstance(action, dict):
            return dict(action)
        if is_dataclass(action):
            return asdict(action)
        if hasattr(action, "model_dump"):
            return action.model_dump()
        if hasattr(action, "dict"):
            return action.dict()
        return vars(action)

    def _parse_result(self, payload: dict) -> StepResult[SREObservation]:
        """Parse a /reset or /step response body into a typed StepResult.

        Missing keys fall back to an empty observation, a reward of 0.0 and
        ``done=False``.
        """
        # OpenEnv 0.1 server returns {"observation": {...}, "reward": float, "done": bool}.
        # `or {}` also covers an explicit `"observation": null`, which would
        # otherwise fail when unpacked with `**`.
        obs_data = payload.get("observation") or {}
        return StepResult(
            observation=SREObservation(**obs_data),
            reward=payload.get("reward", 0.0),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: dict) -> SREState:
        """Parse the server state response into a typed SREState.

        Response keys are passed straight through as constructor keywords.
        """
        return SREState(**payload)

    def close(self):
        """Close the underlying HTTP session."""
        self._session.close()