Spaces:
Paused
Paused
| from __future__ import annotations | |
| import itertools | |
| import random | |
| import threading | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Any | |
| from typing import Iterable | |
| try: | |
| import httpx | |
| except ImportError as exc: | |
| raise ImportError("httpx is required for training.remote_env install httpx") from exc | |
| try: | |
| import gymnasium as gym | |
| from gymnasium import spaces | |
| except ImportError as exc: | |
| raise ImportError("gymnasium is required for training.remote_env") from exc | |
| class RemoteStep: | |
| stdout: str | |
| stderr: str | |
| exit_code: int | |
| reward: float | |
| done: bool | |
| step_number: int | |
| max_steps: int | |
| task_id: str | |
| episode_id: str | |
| class RemoteEndpointPool: | |
| def __init__( | |
| self, | |
| env_urls: Iterable[str], | |
| timeout: float = 60.0, | |
| retries: int = 3, | |
| backoff_seconds: float = 1.5, | |
| api_key: str | None = None, | |
| ) -> None: | |
| cleaned = [u.rstrip("/") for u in env_urls if u.strip()] | |
| if not cleaned: | |
| raise ValueError("at least one env_url is required") | |
| self._urls = cleaned | |
| self._cycle = itertools.cycle(self._urls) | |
| self._lock = threading.Lock() | |
| headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} | |
| self._client = httpx.Client(timeout=timeout, http2=False, headers=headers) | |
| self._retries = retries | |
| self._backoff = backoff_seconds | |
| def next_url(self) -> str: | |
| with self._lock: | |
| return next(self._cycle) | |
| def post_json(self, url: str, path: str, body: dict[str, Any] | None) -> dict[str, Any]: | |
| last_exc: Exception | None = None | |
| for attempt in range(self._retries): | |
| try: | |
| resp = self._client.post(f"{url}{path}", json=body or {}) | |
| resp.raise_for_status() | |
| return resp.json() | |
| except Exception as exc: | |
| last_exc = exc | |
| time.sleep(self._backoff * (attempt + 1)) | |
| raise RuntimeError(f"remote call failed {url}{path}") from last_exc | |
| def get_json(self, url: str, path: str) -> dict[str, Any]: | |
| resp = self._client.get(f"{url}{path}") | |
| resp.raise_for_status() | |
| return resp.json() | |
| def close(self) -> None: | |
| self._client.close() | |
| class HttpEnterpriseHPCEnv(gym.Env): | |
| metadata = {"render_modes": []} | |
| def __init__( | |
| self, | |
| env_urls: list[str] | str, | |
| *, | |
| scenario: str | None = None, | |
| scenario_pool: list[str] | None = None, | |
| pool: RemoteEndpointPool | None = None, | |
| timeout: float = 60.0, | |
| api_key: str | None = None, | |
| ) -> None: | |
| super().__init__() | |
| self.action_space = spaces.Text(max_length=4096) | |
| self.observation_space = spaces.Text(max_length=65536) | |
| _external_pool = pool is not None | |
| if pool is None: | |
| if isinstance(env_urls, str): | |
| env_urls = [env_urls] | |
| pool = RemoteEndpointPool(env_urls, timeout=timeout, api_key=api_key) | |
| self._pool = pool | |
| self._owns_pool = not _external_pool | |
| self._scenario = scenario | |
| self._scenario_pool = scenario_pool or ([scenario] if scenario else ["hpc_outage", "hpc_munge"]) | |
| self._active_url: str | None = None | |
| self._episode_task: str | None = None | |
| self._episode_id: str | None = None | |
| self._step_count = 0 | |
| self._max_steps = 0 | |
| self._rng = random.Random() | |
| def reset( | |
| self, | |
| *, | |
| seed: int | None = None, | |
| options: dict[str, Any] | None = None, | |
| ) -> tuple[str, dict[str, Any]]: | |
| super().reset(seed=seed) | |
| if seed is not None: | |
| self._rng.seed(seed) | |
| task_id = None | |
| if options and "scenario" in options: | |
| task_id = options["scenario"] | |
| elif self._scenario: | |
| task_id = self._scenario | |
| elif self._scenario_pool: | |
| task_id = self._rng.choice(self._scenario_pool) | |
| self._active_url = self._pool.next_url() | |
| body = {"task_id": task_id} if task_id else {} | |
| data = self._pool.post_json(self._active_url, "/reset", body) | |
| self._absorb(data) | |
| obs_payload = data["observation"] | |
| observation = self._observation_text(obs_payload) | |
| info = { | |
| "task_id": self._episode_task or "", | |
| "episode_id": self._episode_id or "", | |
| "max_steps": self._max_steps, | |
| "endpoint": self._active_url, | |
| "grader_health": float(obs_payload.get("grader_health", 0.0)), | |
| "grader_details": dict(obs_payload.get("grader_details") or {}), | |
| "ood_http_code": str(obs_payload.get("ood_http_code", "")), | |
| } | |
| return observation, info | |
| def step( | |
| self, action: str | |
| ) -> tuple[str, float, bool, bool, dict[str, Any]]: | |
| if self._active_url is None: | |
| raise RuntimeError("HttpEnterpriseHPCEnv step called before reset") | |
| body: dict[str, Any] = { | |
| "action": {"command": action, "reasoning": None}, | |
| } | |
| # include episode_id so the server-side multi-session store routes | |
| # this step to the correct episode. older servers ignore the field. | |
| if self._episode_id: | |
| body["episode_id"] = self._episode_id | |
| data = self._pool.post_json(self._active_url, "/step", body) | |
| self._absorb(data) | |
| obs_payload = data["observation"] | |
| obs = self._observation_text(obs_payload) | |
| reward = float(obs_payload.get("reward", 0.0)) | |
| terminated = bool(obs_payload.get("done", False)) | |
| truncated = (not terminated) and self._step_count >= self._max_steps | |
| info = { | |
| "task_id": self._episode_task or "", | |
| "episode_id": self._episode_id or "", | |
| "step": self._step_count, | |
| "max_steps": self._max_steps, | |
| "exit_code": obs_payload.get("exit_code"), | |
| "reward_source": "openenv_remote", | |
| "endpoint": self._active_url, | |
| "grader_health": float(obs_payload.get("grader_health", 0.0)), | |
| "grader_details": dict(obs_payload.get("grader_details") or {}), | |
| "ood_http_code": str(obs_payload.get("ood_http_code", "")), | |
| } | |
| return obs, reward, terminated, truncated, info | |
| def close(self) -> None: | |
| if self._owns_pool: | |
| self._pool.close() | |
| def _absorb(self, data: dict[str, Any]) -> None: | |
| state = data.get("state") or {} | |
| self._episode_task = state.get("task_id") | |
| self._episode_id = state.get("episode_id") | |
| self._step_count = int(state.get("step_count", 0)) | |
| self._max_steps = int(state.get("max_steps", 1)) | |
| def _observation_text(obs: dict[str, Any]) -> str: | |
| stdout = obs.get("stdout", "") | |
| stderr = obs.get("stderr", "") | |
| parts: list[str] = [] | |
| if stdout: | |
| parts.append(stdout.rstrip()) | |
| if stderr: | |
| parts.append("stderr:\n" + stderr.rstrip()) | |
| if not parts: | |
| parts.append("[command executed no output]") | |
| return "\n".join(parts) | |