File size: 8,408 Bytes
569c142
 
 
 
 
 
 
 
 
 
da9e926
569c142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30c52ad
 
 
 
 
569c142
 
 
da9e926
569c142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da9e926
569c142
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""
DiskPanic Inference Script
==========================
Runs all 3 tasks (easy, medium, hard) sequentially against the DiskPanic
OpenEnv, using an OpenAI-compatible LLM as the SRE agent.

Required environment variables:
    API_BASE_URL   The LLM endpoint (OpenAI-compatible)
    MODEL_NAME     The model id to use
    HF_TOKEN       API key for the LLM provider
    LOCAL_IMAGE_NAME    (optional) Docker image for the env server;
                        defaults to disk-panic:latest

Stdout format (one per episode):
    [START] task=<task> env=disk_panic model=<model>
    [STEP]  step=<n> action=<cmd> reward=<0.00> done=<bool> error=<msg|null>
    [END]   success=<bool> steps=<n> score=<0.000> rewards=<r1,r2,...>
"""
from __future__ import annotations

import asyncio
import os
import textwrap
from typing import List, Optional

from openai import OpenAI

try:
    from disk_panic import DiskPanicAction, DiskPanicEnv
except ImportError:
    from client import DiskPanicEnv
    from models import DiskPanicAction

# -- config ----------------------------------------------------------------

# Docker image used to spin up the DiskPanic env server.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "disk-panic:latest")
# LLM credentials/endpoint: HF_TOKEN takes precedence over API_KEY.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.groq.com/openai/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "llama-3.3-70b-versatile"
# Benchmark id used only in the [START] log line.
BENCHMARK = "disk_panic"

# Episodes run sequentially, one per task, in this order.
TASKS = ["easy", "medium", "hard"]
MAX_STEPS = 15          # hard cap on env.step() calls per episode
TEMPERATURE = 0.2       # low temperature: we want one deterministic-ish command
MAX_TOKENS = 120        # replies are a single shell command; keep them short
SUCCESS_SCORE_THRESHOLD = 0.6   # episode counts as success at/above this score

# System prompt sent with every model request. The command palette and rules
# below mirror the simulated shell implemented by the DiskPanic env server —
# keep them in sync if the env changes. The string content is runtime
# behavior (it steers the model); do not edit it casually.
SYSTEM_PROMPT = textwrap.dedent(
    """
    You are an SRE responding to a production incident. A Linux server has a
    full root filesystem and (sometimes) a crashed app.service. You must fix it.

    COMMAND PALETTE (this is a SIMULATED shell — ONLY these commands work, no
    pipes, no subshells, no globs except trailing /*, no flags beyond what's shown):

      df                                 show disk usage
      ls <path>                          list a directory
      du <path>                          breakdown of subdir sizes (use this to find the big file!)
      cat <path>                         view a file
      find <path>                        recursive file list
      sha256sum <path>                   hash a file or dir
      rm <path>                          delete a file
      rm -rf <path>                      delete recursively
      systemctl is-active <svc>          check service state
      systemctl restart <svc>            restart a service
      echo "content" > /path/to/file     write a file (needed for logrotate config)

    IMPORTANT RULES:
      1. NEVER touch /var/log/audit/ — it is business-critical. Touching it caps reward.
      2. Get disk usage below 80% (see `df`).
      3. For the medium task: also restart app with `systemctl restart app`.
      4. For the hard task: ALSO write a logrotate config to /etc/logrotate.d/app
         containing both the words "rotate" and "size". Example:
             echo "rotate 5 size 100M" > /etc/logrotate.d/app
      5. Start with `du /var/log` to see which subdirectory is bloated, then drill down.
      6. DO NOT use pipes (|), sort, head, or any other command not in the palette.
      7. DO NOT use glob other than trailing /*.

    Reply with EXACTLY ONE command on a single line. No markdown, no code fences,
    no leading $, no prose, no quotes around the whole line. Just the command.
    """
).strip()


# -- log helpers -----------------------------------------------------------

def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] marker line that opens an episode's log output."""
    line = "[START] task=%s env=%s model=%s" % (task, env, model)
    print(line, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] marker line describing a single env.step() result."""
    # Newlines in the action would break the one-line-per-step format;
    # map each CR/LF to a single space (same as two chained replace calls).
    flat_action = action.translate(str.maketrans({"\n": " ", "\r": " "}))
    fields = (
        f"step={step}",
        f"action={flat_action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the [END] marker line that closes an episode's log output."""
    joined_rewards = ",".join(format(r, ".2f") for r in rewards)
    summary = "[END] success=%s steps=%d score=%.3f rewards=%s" % (
        str(success).lower(), steps, score, joined_rewards,
    )
    print(summary, flush=True)


# -- prompt builder --------------------------------------------------------

def build_user_prompt(task: str, step: int, obs_stdout: str, df: str, svc: str,
                      last_error: Optional[str], history: List[str]) -> str:
    """Build the per-step user message shown to the model.

    BUG FIX: the original formatted an indented f-string template and THEN
    called textwrap.dedent on the result. Whenever an interpolated value
    (``df`` or ``obs_stdout``) was multi-line, its lines sat at column 0,
    making the common whitespace prefix empty — so dedent removed nothing
    and every template line kept 8 spaces of indentation. Since real
    ``df`` output is always multi-line, the prompt was routinely mangled.
    Assembling the lines explicitly avoids dedent entirely and produces
    the same text the original produced for single-line inputs.

    Args:
        task: Task id ("easy" / "medium" / "hard").
        step: 1-based step counter.
        obs_stdout: stdout of the previous command.
        df: current `df -h /` output from the observation.
        svc: app.service status string.
        last_error: last env error message, or None.
        history: human-readable lines of previous commands (last 6 shown).

    Returns:
        The complete user prompt as a single string (no leading/trailing
        whitespace).
    """
    history_block = "\n".join(history[-6:]) if history else "(no previous commands)"
    err_line = f"Last error: {last_error}" if last_error else "Last error: none"
    return "\n".join([
        f"Task: {task}",
        f"Step: {step}",
        "Current df -h /:",
        df,
        f"app.service: {svc}",
        err_line,
        "",
        "Previous commands:",
        history_block,
        "",
        "Last command output:",
        obs_stdout,
        "",
        "What is your next single command?",
    ])


def _extract_command(raw: str) -> str:
    """Reduce a raw model reply to a single shell command.

    Skips markdown fence marker lines (including language-tagged ones like
    ```` ```python ````), strips surrounding backticks, a leading ``$``
    prompt, and trailing semicolons, then returns the first non-empty line.
    Falls back to the harmless "df" when nothing usable remains.
    """
    for line in raw.splitlines():
        line = line.strip()
        # Fence markers (``` or ```python) are markup, never a command —
        # the original whole-text strip("`") left the "python" tag behind
        # and returned it as the command.
        if not line or line.startswith("```"):
            continue
        line = line.strip("`").strip()
        # Accept both "$ cmd" and "$cmd".
        if line.startswith("$"):
            line = line[1:].strip()
        # Drop trailing semicolons (the original comment promised this but
        # the code never did it).
        line = line.rstrip(";").strip()
        if line:
            return line
    return "df"


def get_next_command(client: OpenAI, task: str, step: int, obs_stdout: str,
                     df: str, svc: str, last_error: Optional[str],
                     history: List[str]) -> str:
    """Ask the LLM for the next shell command for this step.

    Any request failure is logged and degrades to "df", a safe read-only
    command, so the episode can continue.
    """
    user_prompt = build_user_prompt(task, step, obs_stdout, df, svc, last_error, history)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = completion.choices[0].message.content or ""
        return _extract_command(text)
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "df"


# -- episode runner --------------------------------------------------------

async def run_episode(client: OpenAI, env: DiskPanicEnv, task: str) -> float:
    """Run one episode of *task* against the env and return its final score.

    Emits [START], one [STEP] per command, and a [END] summary (the [END]
    line is emitted even if reset/step raises, via the finally block).
    """
    command_history: List[str] = []
    step_rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)

    try:
        result = await env.reset(task_id=task)
        obs = result.observation
        last_error = obs.last_error

        step = 0
        # Loop until the env reports done or the step budget runs out.
        while step < MAX_STEPS and not result.done:
            step += 1

            command = get_next_command(
                client, task, step, obs.stdout, obs.df_output,
                obs.service_status, last_error, command_history,
            )

            result = await env.step(DiskPanicAction(command=command))
            obs = result.observation
            reward = float(result.reward or 0.0)
            done = bool(result.done)

            step_rewards.append(reward)
            steps_taken = step
            last_error = obs.last_error

            log_step(step=step, action=command, reward=reward, done=done, error=last_error)
            command_history.append(f"  step {step}: {command} -> reward {reward:+.2f}")

        # Reward is the absolute current grade each step, so final score =
        # last reward (or the max observed if the episode timed out before
        # the best state was seen). Clamp to [0, 1].
        best = max(step_rewards) if step_rewards else 0.0
        score = min(max(best, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=step_rewards)

    return score


async def main() -> None:
    """Boot the env container once and run every task through it in order."""
    llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = await DiskPanicEnv.from_docker_image(LOCAL_IMAGE_NAME)
    try:
        for task in TASKS:
            await run_episode(llm, env, task)
    finally:
        # Best-effort teardown: a close failure shouldn't mask episode output.
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error: {e}", flush=True)


if __name__ == "__main__":
    asyncio.run(main())