# Source: HPCOpenenv/training/remote_env.py
# Last commit by huggingmenfordays (bc35a94):
#   deploy: ccyloopss/HPCOpenenv — with OPENENV_API_KEY auth guard
from __future__ import annotations
import itertools
import random
import threading
import time
from dataclasses import dataclass
from typing import Any
from typing import Iterable
try:
import httpx
except ImportError as exc:
raise ImportError("httpx is required for training.remote_env install httpx") from exc
try:
import gymnasium as gym
from gymnasium import spaces
except ImportError as exc:
raise ImportError("gymnasium is required for training.remote_env") from exc
@dataclass
class RemoteStep:
    """Flat record describing the outcome of one remote environment step.

    The fields are plain values only; the class adds no behaviour beyond
    what ``@dataclass`` generates (``__init__``, ``__repr__``, ``__eq__``).
    Nothing in the visible part of this module constructs it, so the field
    meanings below are read off the names — confirm against the server
    schema before relying on them.
    """

    # Output streams and exit status of the executed command.
    stdout: str
    stderr: str
    exit_code: int

    # Reward signal and episode-finished flag for this step.
    reward: float
    done: bool

    # Position within the episode and the episode's step budget.
    step_number: int
    max_steps: int

    # Identifiers of the task (scenario) and of the running episode.
    task_id: str
    episode_id: str
class RemoteEndpointPool:
def __init__(
self,
env_urls: Iterable[str],
timeout: float = 60.0,
retries: int = 3,
backoff_seconds: float = 1.5,
api_key: str | None = None,
) -> None:
cleaned = [u.rstrip("/") for u in env_urls if u.strip()]
if not cleaned:
raise ValueError("at least one env_url is required")
self._urls = cleaned
self._cycle = itertools.cycle(self._urls)
self._lock = threading.Lock()
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
self._client = httpx.Client(timeout=timeout, http2=False, headers=headers)
self._retries = retries
self._backoff = backoff_seconds
def next_url(self) -> str:
with self._lock:
return next(self._cycle)
def post_json(self, url: str, path: str, body: dict[str, Any] | None) -> dict[str, Any]:
last_exc: Exception | None = None
for attempt in range(self._retries):
try:
resp = self._client.post(f"{url}{path}", json=body or {})
resp.raise_for_status()
return resp.json()
except Exception as exc:
last_exc = exc
time.sleep(self._backoff * (attempt + 1))
raise RuntimeError(f"remote call failed {url}{path}") from last_exc
def get_json(self, url: str, path: str) -> dict[str, Any]:
resp = self._client.get(f"{url}{path}")
resp.raise_for_status()
return resp.json()
def close(self) -> None:
self._client.close()
class HttpEnterpriseHPCEnv(gym.Env):
    """Gymnasium environment that forwards ``reset``/``step`` to a remote server.

    Actions are text commands and observations are text built from the
    server's ``stdout``/``stderr``. Each ``reset`` picks the next endpoint
    from the pool and all steps of that episode go to the same endpoint.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        env_urls: list[str] | str,
        *,
        scenario: str | None = None,
        scenario_pool: list[str] | None = None,
        pool: RemoteEndpointPool | None = None,
        timeout: float = 60.0,
        api_key: str | None = None,
    ) -> None:
        """Create the environment.

        Args:
            env_urls: One base URL or a list of them. Ignored when ``pool``
                is supplied.
            scenario: Fixed task id to request on every reset.
            scenario_pool: Task ids to sample from when no fixed scenario is
                set. Defaults to ``[scenario]`` if given, otherwise
                ``["hpc_outage", "hpc_munge"]``.
            pool: Existing endpoint pool to share. A shared pool is not
                closed by ``close()``.
            timeout: HTTP timeout used when this instance builds its own pool.
            api_key: API key used when this instance builds its own pool.
        """
        super().__init__()
        self.action_space = spaces.Text(max_length=4096)
        self.observation_space = spaces.Text(max_length=65536)
        _external_pool = pool is not None
        if pool is None:
            if isinstance(env_urls, str):
                env_urls = [env_urls]
            pool = RemoteEndpointPool(env_urls, timeout=timeout, api_key=api_key)
        self._pool = pool
        # Only a pool created here is closed by close().
        self._owns_pool = not _external_pool
        self._scenario = scenario
        self._scenario_pool = scenario_pool or ([scenario] if scenario else ["hpc_outage", "hpc_munge"])
        self._active_url: str | None = None
        self._episode_task: str | None = None
        self._episode_id: str | None = None
        self._step_count = 0
        self._max_steps = 0
        self._rng = random.Random()

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[str, dict[str, Any]]:
        """Start a new episode on the next endpoint of the pool.

        The task id is taken from ``options["scenario"]`` if present, else
        the fixed scenario, else a random pick from the scenario pool. A
        given ``seed`` also reseeds that random pick.

        Returns:
            ``(observation_text, info)``.

        Raises:
            RuntimeError: If the remote call fails (raised by the pool).
            KeyError: If the response has no ``"observation"`` entry.
        """
        super().reset(seed=seed)
        if seed is not None:
            self._rng.seed(seed)
        task_id = None
        if options and "scenario" in options:
            task_id = options["scenario"]
        elif self._scenario:
            task_id = self._scenario
        elif self._scenario_pool:
            task_id = self._rng.choice(self._scenario_pool)
        self._active_url = self._pool.next_url()
        body = {"task_id": task_id} if task_id else {}
        data = self._pool.post_json(self._active_url, "/reset", body)
        # fresh=True drops whatever the previous episode left behind.
        self._absorb(data, fresh=True)
        obs_payload = data["observation"]
        observation = self._observation_text(obs_payload)
        info = {
            "task_id": self._episode_task or "",
            "episode_id": self._episode_id or "",
            "max_steps": self._max_steps,
            "endpoint": self._active_url,
            **self._grader_info(obs_payload),
        }
        return observation, info

    def step(
        self, action: str
    ) -> tuple[str, float, bool, bool, dict[str, Any]]:
        """Send one command to the active endpoint.

        Returns:
            ``(observation_text, reward, terminated, truncated, info)``.
            ``terminated`` mirrors the server's ``done`` flag; ``truncated``
            is set when the episode is not done but the step budget is spent.

        Raises:
            RuntimeError: If called before ``reset`` or if the remote call
                fails.
            KeyError: If the response has no ``"observation"`` entry.
        """
        if self._active_url is None:
            raise RuntimeError("HttpEnterpriseHPCEnv step called before reset")
        body: dict[str, Any] = {
            "action": {"command": action, "reasoning": None},
        }
        # Include episode_id so the server-side multi-session store routes
        # this step to the correct episode. Older servers ignore the field.
        if self._episode_id:
            body["episode_id"] = self._episode_id
        data = self._pool.post_json(self._active_url, "/step", body)
        self._absorb(data)
        obs_payload = data["observation"]
        obs = self._observation_text(obs_payload)
        raw_reward = obs_payload.get("reward")
        # A JSON null reward is treated like a missing one.
        reward = float(raw_reward) if raw_reward is not None else 0.0
        terminated = bool(obs_payload.get("done", False))
        truncated = (not terminated) and self._step_count >= self._max_steps
        info = {
            "task_id": self._episode_task or "",
            "episode_id": self._episode_id or "",
            "step": self._step_count,
            "max_steps": self._max_steps,
            "exit_code": obs_payload.get("exit_code"),
            "reward_source": "openenv_remote",
            "endpoint": self._active_url,
            **self._grader_info(obs_payload),
        }
        return obs, reward, terminated, truncated, info

    def close(self) -> None:
        """Close the endpoint pool if this instance created it."""
        if self._owns_pool:
            self._pool.close()

    def _absorb(self, data: dict[str, Any], *, fresh: bool = False) -> None:
        """Copy episode bookkeeping from the response's ``state`` section.

        With ``fresh=True`` (used by ``reset``) all fields are first set back
        to their defaults. Otherwise a field that is missing or null in the
        response keeps its current value, so a ``/step`` reply without
        ``state`` does not erase the episode id needed to route later steps.
        """
        state = data.get("state") or {}
        if fresh:
            self._episode_task = None
            self._episode_id = None
            self._step_count = 0
            self._max_steps = 1
        task_id = state.get("task_id")
        if task_id is not None:
            self._episode_task = task_id
        episode_id = state.get("episode_id")
        if episode_id is not None:
            self._episode_id = episode_id
        step_count = state.get("step_count")
        if step_count is not None:
            self._step_count = int(step_count)
        max_steps = state.get("max_steps")
        if max_steps is not None:
            self._max_steps = int(max_steps)

    @staticmethod
    def _grader_info(obs: dict[str, Any]) -> dict[str, Any]:
        """Extract the grader fields shared by the reset and step info dicts.

        Missing or null values fall back to ``0.0``, ``{}`` and ``""``.
        """
        health = obs.get("grader_health")
        http_code = obs.get("ood_http_code")
        return {
            "grader_health": float(health) if health is not None else 0.0,
            "grader_details": dict(obs.get("grader_details") or {}),
            "ood_http_code": "" if http_code is None else str(http_code),
        }

    @staticmethod
    def _observation_text(obs: dict[str, Any]) -> str:
        """Render an observation payload as text.

        Non-empty ``stdout`` comes first, then non-empty ``stderr`` under a
        ``stderr:`` heading; a placeholder line is returned when both are
        empty.
        """
        stdout = obs.get("stdout", "")
        stderr = obs.get("stderr", "")
        parts: list[str] = []
        if stdout:
            parts.append(stdout.rstrip())
        if stderr:
            parts.append("stderr:\n" + stderr.rstrip())
        if not parts:
            parts.append("[command executed no output]")
        return "\n".join(parts)