File size: 4,814 Bytes
569c142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""DiskPanicEnvironment — server-side OpenEnv Environment implementation."""
from __future__ import annotations

import os
from typing import Any, Optional

from openenv.core import Environment, State

from disk_panic.models import DiskPanicAction, DiskPanicObservation
from disk_panic.server.graders import GRADERS
from disk_panic.server.scenarios import SCENARIOS, TASK_ORDER
from disk_panic.server.vfs import VFS, execute

# Per-episode step budget: step() reports done=True once the step counter
# reaches this value, even if the grader score is still below the success bar.
MAX_STEPS = 15


class DiskPanicEnvironment(Environment[DiskPanicAction, DiskPanicObservation, State]):
    """OpenEnv environment in which an agent repairs a simulated server.

    Each episode loads one scenario from ``SCENARIOS`` into a ``VFS``. Every
    step runs the agent's command against that VFS, ticks the simulation once,
    and reports the scenario grader's current score as the reward.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        """Create an idle environment; call ``reset`` to load a scenario."""
        super().__init__()
        # Round-robin cursor into TASK_ORDER (only advanced by the fallback path).
        self._task_index = 0
        self._task_id: str = TASK_ORDER[0]
        # Simulated filesystem plus the grader targets produced by the scenario builder.
        self._vfs: VFS = VFS()
        self._targets: dict = {}
        # Per-episode bookkeeping.
        self._step_count: int = 0
        self._last_reward: float = 0.0
        self._state = State(episode_id=None, step_count=0)

    # -- lifecycle ---------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> DiskPanicObservation:
        """Start a new episode and return its opening observation.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Identifier stored on the new ``State``. When falsy,
                ``"ep-<task>"`` is used instead.
            task_id: Scenario to load. Ignored unless it is a key of
                ``SCENARIOS``; see ``_pick_task`` for the fallback order.
            **kwargs: Accepted and ignored.

        Returns:
            An observation with ``done=False``, ``reward=None``, ``step=0`` and
            a ``stdout`` banner holding the task name, its description, and the
            current disk usage percentage.
        """
        chosen = self._pick_task(task_id)

        # Build a fresh world for the chosen scenario and clear episode counters.
        fresh_vfs, fresh_targets = SCENARIOS[chosen]()
        self._vfs = fresh_vfs
        self._targets = fresh_targets
        self._task_id = chosen
        self._step_count = 0
        self._last_reward = 0.0
        self._state = State(episode_id=episode_id or f"ep-{chosen}", step_count=0)

        task_brief = self._describe_task()
        usage = self._vfs.usage_pct()
        opening = (
            f"=== DiskPanic task: {chosen} ===\n"
            f"{task_brief}\n"
            f"Current disk usage: {usage:.1f}%"
        )
        return DiskPanicObservation(
            done=False,
            reward=None,
            stdout=opening,
            df_output=self._vfs.df_output(),
            service_status=self._vfs.services.get("app", "unknown"),
            task_id=chosen,
            step=0,
            last_error=None,
        )

    def step(
        self,
        action: DiskPanicAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> DiskPanicObservation:
        """Run one agent command and grade the resulting world.

        Args:
            action: Carries the command string in ``action.command``.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted and ignored.

        Returns:
            An observation whose ``reward`` is the grader's current score
            rounded to four decimals, and whose ``done`` flag is set once the
            score reaches 0.98 or ``MAX_STEPS`` steps have been taken.
        """
        steps_taken = self._step_count + 1
        self._step_count = steps_taken
        self._state.step_count = steps_taken

        # The command runs first; the simulation then advances by one tick.
        outcome = execute(self._vfs, action.command)
        self._vfs.tick()

        # Graders are expected to return a score in [0.0, 1.0]; this is not
        # enforced here.
        score = GRADERS[self._task_id](self._vfs, self._targets)
        self._last_reward = score

        # Finish on (near-)full score or when the step budget is exhausted.
        finished = score >= 0.98 or steps_taken >= MAX_STEPS

        # The reward is the absolute grade rather than a delta, so the final
        # step's reward is the episode's final score.
        graded = round(float(score), 4)

        return DiskPanicObservation(
            done=finished,
            reward=graded,
            stdout=outcome.stdout or "",
            df_output=self._vfs.df_output(),
            service_status=self._vfs.services.get("app", "unknown"),
            task_id=self._task_id,
            step=steps_taken,
            last_error=outcome.error,
        )

    def state(self) -> State:
        """Return the ``State`` object tracking the current episode."""
        # NOTE(review): this is a plain method; confirm the openenv base class
        # does not expect ``state`` to be exposed as a property.
        return self._state

    def close(self) -> None:
        """Discard the current world by swapping in an empty VFS and targets."""
        self._vfs, self._targets = VFS(), {}

    # -- helpers -----------------------------------------------------------

    def _pick_task(self, task_id: Optional[str]) -> str:
        """Resolve which scenario the next episode should run.

        Precedence: a ``task_id`` that is a key of ``SCENARIOS``, then a
        ``DISK_PANIC_TASK`` environment variable that is a key of
        ``SCENARIOS``, then the next ``TASK_ORDER`` entry in round-robin
        order. The round-robin cursor advances only when that last fallback
        is used.
        """
        if task_id and task_id in SCENARIOS:
            return task_id

        from_env = os.getenv("DISK_PANIC_TASK")
        if from_env and from_env in SCENARIOS:
            return from_env

        fallback = TASK_ORDER[self._task_index % len(TASK_ORDER)]
        self._task_index += 1
        return fallback

    def _describe_task(self) -> str:
        """Return the human-readable brief for the active task.

        Falls back to ``"unknown task"`` when the task id has no brief.
        """
        briefs = {
            "easy": (
                "Root filesystem is at >95% usage — find and remove the bloated log "
                "under /var/log. DO NOT touch /var/log/audit/."
            ),
            "medium": (
                "Disk is full AND app.service has failed. Free space, restart "
                "app.service, and preserve /var/log/audit/."
            ),
            "hard": (
                "A runaway process keeps growing /var/log/app/runaway.log by ~100M "
                "every tick. Free space, restart app.service, preserve audit logs, "
                "AND drop a logrotate config at /etc/logrotate.d/app (must contain "
                "the words 'rotate' and 'size') to cap the growth."
            ),
        }
        return briefs.get(self._task_id, "unknown task")