File size: 14,476 Bytes
bc35a94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
from __future__ import annotations

import os
import random
import re
import shutil
import tempfile
import time
from pathlib import Path
from types import ModuleType
from typing import Any
from typing import Sequence

try:
    import gymnasium as gym
    from gymnasium import spaces
except ImportError as exc:
    raise ImportError(
        "gymnasium is required for hpc_gym import with pip install gymnasium"
    ) from exc

try:
    import pexpect
except ImportError as exc:
    raise ImportError(
        "pexpect is required for hpc_gym import with pip install pexpect"
    ) from exc

from sysadmin_env.sandbox import Sandbox
from sysadmin_env.tasks import hpc_gpu_ecc
from sysadmin_env.tasks import hpc_munge
from sysadmin_env.tasks import hpc_nfs_stale
from sysadmin_env.tasks import hpc_ood_apache
from sysadmin_env.tasks import hpc_outage
from sysadmin_env.tasks import hpc_pid_stale


PROMPT_PATTERN = re.compile(r"\[[^\]\r\n]+\][#$]\s?")
PRIMARY_HOSTNAME = "hpc-login"
OOD_PORT = 8080
OOD_LOG_PATH = "/tmp/ood.log"
OOD_DAEMON_SCRIPT = "/usr/local/bin/ood_server.py"
DEFAULT_STEP_TIMEOUT = 60.0
DEFAULT_SHELL_TIMEOUT = 30.0

SCENARIO_REGISTRY: dict[str, ModuleType] = {
    hpc_outage.TASK_ID: hpc_outage,
    hpc_munge.TASK_ID: hpc_munge,
    hpc_pid_stale.TASK_ID: hpc_pid_stale,
    hpc_gpu_ecc.TASK_ID: hpc_gpu_ecc,
    hpc_nfs_stale.TASK_ID: hpc_nfs_stale,
    hpc_ood_apache.TASK_ID: hpc_ood_apache,
}


def resolve_scenario(name_or_module: str | ModuleType) -> ModuleType:
    if isinstance(name_or_module, ModuleType):
        return name_or_module
    if name_or_module in SCENARIO_REGISTRY:
        return SCENARIO_REGISTRY[name_or_module]
    raise KeyError(
        f"unknown scenario {name_or_module} expected one of {sorted(SCENARIO_REGISTRY)}"
    )


class EnterpriseHPCEnv(gym.Env):
    metadata = {"render_modes": []}

    def __init__(
        self,
        task_root: str | None = None,
        *,
        scenario: str | ModuleType = hpc_outage.TASK_ID,
        scenario_pool: Sequence[str | ModuleType] | None = None,
        overlay_base_dir: str | None = None,
        shell_timeout: float = DEFAULT_SHELL_TIMEOUT,
        step_timeout: float = DEFAULT_STEP_TIMEOUT,
    ) -> None:
        super().__init__()
        self.action_space = spaces.Text(max_length=4096)
        self.observation_space = spaces.Text(max_length=65536)

        self._configured_task_root = task_root
        self._overlay_base_dir = overlay_base_dir
        self._shell_timeout = shell_timeout
        self._step_timeout = step_timeout

        self._scenario_pool: list[ModuleType]
        if scenario_pool is not None:
            self._scenario_pool = [resolve_scenario(item) for item in scenario_pool]
        else:
            self._scenario_pool = [resolve_scenario(scenario)]

        self._scenario: ModuleType = self._scenario_pool[0]
        self._sandbox: Sandbox | None = None
        self._sandbox_scenario_id: str | None = None
        self._shell: pexpect.spawn | None = None
        self._tmp_task_dir: str | None = None
        self._step_count = 0
        self._max_steps = 0
        self._last_reward = 0.0
        self._ood_started = False
        self._rng = random.Random()
        self._prev_health = 0.0

    @property
    def sandbox(self) -> Sandbox | None:
        return self._sandbox

    @property
    def scenario(self) -> ModuleType:
        return self._scenario

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[str, dict[str, Any]]:
        super().reset(seed=seed)
        if seed is not None:
            self._rng.seed(seed)

        self._select_scenario(options)
        self._close_shell()

        scenario_changed = self._sandbox_scenario_id != self._scenario.TASK_ID
        if self._sandbox is not None and scenario_changed:
            try:
                self._sandbox.destroy()
            except Exception as exc:
                print(f"hpc_gym sandbox destroy failed {type(exc).__name__.lower()} {exc}")
            self._sandbox = None

        task_root = self._ensure_task_root()
        if self._sandbox is None:
            print(f"hpc_gym create sandbox scenario {self._scenario.TASK_ID} task_root {task_root}")
            self._sandbox = Sandbox(
                task_root,
                timeout=self._step_timeout,
                isolate_network=False,
                overlay_base_dir=self._overlay_base_dir,
                allow_nested_sandbox=True,
            )
            self._sandbox.create()
            self._sandbox_scenario_id = self._scenario.TASK_ID
        else:
            start = time.perf_counter()
            latency_ms = self._sandbox.reset()
            print(
                f"hpc_gym overlay reset scenario {self._scenario.TASK_ID} "
                f"{latency_ms:.2f}ms wall {((time.perf_counter()-start)*1000):.2f}ms"
            )

        if self._sandbox.state_root is not None:
            self._scenario.synchronize(self._sandbox.state_root)

        definition = self._scenario.build_definition(str(self._sandbox.state_root or ""))
        self._max_steps = definition.metadata.max_steps
        self._step_count = 0
        self._last_reward = 0.0
        self._prev_health = 0.0
        self._ood_started = False

        self._spawn_shell()
        self._bootstrap_primary_prompt()
        self._launch_ood_daemon()
        self._enter_login_node()

        observation = (
            f"login node ready scenario {self._scenario.TASK_ID} ood :"
            f"{OOD_PORT} max_steps {self._max_steps}"
        )
        info = {
            "task_id": self._scenario.TASK_ID,
            "max_steps": self._max_steps,
            "ood_port": OOD_PORT,
            "prompt_pattern": PROMPT_PATTERN.pattern,
        }
        return observation, info

    def step(
        self, action: str
    ) -> tuple[str, float, bool, bool, dict[str, Any]]:
        if self._shell is None or self._sandbox is None:
            raise RuntimeError("EnterpriseHPCEnv step called before reset")

        command = action if isinstance(action, str) else str(action)
        self._step_count += 1

        self._shell.sendline(command)
        output = self._await_prompt(self._step_timeout)

        grade = self._scenario.grade(self._sandbox.state_root or Path("."))
        health_delta = grade.health - self._prev_health
        self._prev_health = grade.health
        reward = health_delta
        self._last_reward = reward
        terminated = grade.done
        truncated = not terminated and self._step_count >= self._max_steps

        http_code = self._probe_ood_code()
        info: dict[str, Any] = {
            "task_id": self._scenario.TASK_ID,
            "step": self._step_count,
            "max_steps": self._max_steps,
            "reward_source": "grader",
            "command": command,
            "grader_health": grade.health,
            "grader_details": grade.details,
            "ood_http_code": http_code,
        }
        return output, reward, terminated, truncated, info

    def render(self) -> None:
        return None

    def close(self) -> None:
        self._close_shell()
        if self._sandbox is not None:
            try:
                self._sandbox.destroy()
            except Exception as exc:
                print(f"hpc_gym sandbox destroy failed {type(exc).__name__.lower()} {exc}")
            self._sandbox = None
            self._sandbox_scenario_id = None
        if self._tmp_task_dir is not None:
            shutil.rmtree(self._tmp_task_dir, ignore_errors=True)
            self._tmp_task_dir = None

    def __enter__(self) -> "EnterpriseHPCEnv":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        self.close()
        return False

    def _select_scenario(self, options: dict[str, Any] | None) -> None:
        if options and "scenario" in options:
            self._scenario = resolve_scenario(options["scenario"])
            return
        if len(self._scenario_pool) == 1:
            self._scenario = self._scenario_pool[0]
            return
        self._scenario = self._rng.choice(self._scenario_pool)

    def _ensure_task_root(self) -> Path:
        if self._configured_task_root is not None:
            root = Path(self._configured_task_root)
            if self._tmp_task_dir is None:
                root.mkdir(parents=True, exist_ok=True)
        else:
            if self._tmp_task_dir is not None:
                shutil.rmtree(self._tmp_task_dir, ignore_errors=True)
            self._tmp_task_dir = tempfile.mkdtemp(prefix="hpc_task_")
            root = Path(self._tmp_task_dir)
        self._scenario.prepare_filesystem(root)
        return root

    def _spawn_shell(self) -> None:
        if self._sandbox is None:
            raise RuntimeError("sandbox must be created before shell spawn")

        runtime_cmd = self._sandbox._build_runtime_command(
            "exec /bin/bash --noprofile --norc -i"
        )
        env = {
            "PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
            "HOME": "/root",
            "TERM": "xterm",
            "HOSTNAME": PRIMARY_HOSTNAME,
            "PS1": f"[root@{PRIMARY_HOSTNAME} \\W]\\$ ",
            "LANG": "C.UTF-8",
        }
        print(f"hpc_gym spawning pexpect runtime {runtime_cmd[0]}")
        self._shell = pexpect.spawn(
            runtime_cmd[0],
            runtime_cmd[1:],
            timeout=self._shell_timeout,
            encoding="utf-8",
            codec_errors="replace",
            env=env,
        )
        self._shell.setecho(False)

    def _bootstrap_primary_prompt(self) -> None:
        if self._shell is None:
            raise RuntimeError("shell not spawned")
        # disable bracketed-paste mode (\e[?2004l) so terminal escape sequences
        # do not pollute command output and confuse the prompt regex
        self._shell.sendline(
            "printf '\\e[?2004l'; "
            f"export PS1='[root@{PRIMARY_HOSTNAME} \\W]\\$ '; "
            "export PROMPT_COMMAND=''; stty -echo 2>/dev/null; true"
        )
        self._await_prompt(self._shell_timeout)

    @staticmethod
    def _find_python3_in_sandbox() -> str:
        """return the first python3 binary that exists on the host and will
        therefore be available inside the bwrap ro-bind at the same path."""
        candidates = [
            "/usr/bin/python3",
            "/usr/bin/python3.11",
            "/usr/bin/python3.12",
            "/usr/bin/python3.9",
            "/usr/bin/python",
        ]
        for c in candidates:
            if Path(c).exists():
                return c
        return "python3"  # fallback, may fail — caught by grace window

    def _launch_ood_daemon(self) -> None:
        if self._shell is None or self._sandbox is None:
            raise RuntimeError("shell or sandbox missing for ood launch")

        python3 = self._find_python3_in_sandbox()
        self._shell.sendline(
            f"nohup {python3} {OOD_DAEMON_SCRIPT} >{OOD_LOG_PATH} 2>&1 & disown; true"
        )
        self._await_prompt(self._shell_timeout)

        for attempt in range(20):
            code = self._probe_ood_code()
            if code in {"200", "502"}:
                self._ood_started = True
                print(f"hpc_gym ood ready http_code {code} attempts {attempt + 1}")
                return
            time.sleep(0.1)
        print("hpc_gym ood did not respond within grace window proceeding anyway")

    def _enter_login_node(self) -> None:
        if self._shell is None:
            raise RuntimeError("shell not spawned")
        self._shell.sendline("ssh login")
        self._await_prompt(self._shell_timeout)

    def _await_prompt(self, timeout: float) -> str:
        if self._shell is None:
            raise RuntimeError("shell not spawned")
        try:
            self._shell.expect(PROMPT_PATTERN, timeout=timeout)
            before = self._shell.before or ""
        except pexpect.exceptions.TIMEOUT:
            before = self._shell.before or ""
            print("hpc_gym prompt timeout sending ctrl-c to recover")
            try:
                self._shell.sendcontrol("c")
                self._shell.expect(PROMPT_PATTERN, timeout=5)
            except Exception as exc:
                print(f"hpc_gym recovery failed {type(exc).__name__.lower()} {exc}")
        except pexpect.exceptions.EOF:
            before = self._shell.before or ""
            print("hpc_gym shell eof observed")
        return _strip_ansi(before).lstrip("\r\n")

    def _probe_ood_code(self) -> str:
        if self._sandbox is None:
            return ""
        probe = self._sandbox.execute(
            f"curl -s -o /dev/null -w '%{{http_code}}' http://127.0.0.1:{OOD_PORT}/",
            timeout=10.0,
        )
        return (probe.stdout or "").strip()

    def _close_shell(self) -> None:
        if self._shell is None:
            return
        try:
            if self._shell.isalive():
                self._shell.sendline("exit 0")
                try:
                    self._shell.expect(pexpect.exceptions.EOF, timeout=2)
                except Exception:
                    pass
            self._shell.close(force=True)
        except Exception as exc:
            print(f"hpc_gym shell close failed {type(exc).__name__.lower()} {exc}")
        self._shell = None


_ANSI_RE = re.compile(r"\x1b\[[0-9;?]*[A-Za-z]")


def _strip_ansi(text: str) -> str:
    return _ANSI_RE.sub("", text)


def register_env() -> None:
    try:
        gym.register(
            id="EnterpriseHPC-v0",
            entry_point="hpc_gym:EnterpriseHPCEnv",
            max_episode_steps=hpc_outage.build_definition("").metadata.max_steps,
        )
    except gym.error.Error as exc:
        print(f"hpc_gym register skipped {type(exc).__name__.lower()} {exc}")


def main() -> None:
    env = EnterpriseHPCEnv(scenario_pool=list(SCENARIO_REGISTRY))
    try:
        obs, info = env.reset(seed=0)
        print(f"reset observation {obs[:120]}")
        print(f"reset info {info}")
        obs, reward, terminated, truncated, info = env.step("sinfo")
        print(f"step reward {reward} terminated {terminated} truncated {truncated}")
        print(f"step info {info}")
        print(f"step observation\n{obs}")
    finally:
        env.close()


if __name__ == "__main__":
    os.environ.setdefault("OOD_PORT", str(OOD_PORT))
    main()