vx7sh commited on
Commit
ff665de
·
1 Parent(s): 0f99e53

Ship Round 2 manifest/docs, dashboard, and GRPO training pipeline

Browse files
Files changed (7) hide show
  1. README.md +102 -184
  2. dashboard/README.md +14 -0
  3. dashboard/war_room.py +369 -0
  4. inference.py +4 -1
  5. openenv.yaml +55 -18
  6. requirements.txt +7 -9
  7. training/grpo_train.py +368 -0
README.md CHANGED
@@ -23,232 +23,150 @@ models like the one evaluating this environment.
23
  ## Why This Matters
24
 
25
  Large-scale AI training runs on clusters of hundreds of
26
- GPUs across many nodes. These clusters fail constantly:
27
 
28
- - **GPU OOM errors** stall entire training jobs
29
- - **Network congestion** cuts throughput by 40%+
30
- - **Code desynchronization** across ranks hangs jobs silently
31
 
32
- These failures require expert human SREs to debug and fix.
33
- There is no standardized benchmark for evaluating whether
34
- AI agents can handle these failures autonomously.
35
 
36
- NervousSystem-Env fills that gap.
37
 
38
- ---
39
 
40
- ## Environment Design
 
 
 
41
 
42
- ### Observation Space
43
 
44
- The agent receives a `ClusterObservation` at each step:
45
 
46
- | Field | Type | Description |
47
- |---|---|---|
48
- | `nodes` | `list[NodeState]` | Per-node GPU memory, utilization, health, XID errors |
49
- | `training` | `TrainingMetrics` | Throughput, target, job status, stalled steps |
50
- | `visible_logs` | `list[str]` | Surface telemetry — throughput and status logs |
51
- | `step_count` | `int` | Current step number in this episode |
52
- | `episode_id` | `str` | Deterministic episode identifier |
53
-
54
- **Key design:** Deep diagnostic data (Flight Recorder
55
- buffers, NCCL logs) is hidden by default. The agent must
56
- actively query for it using investigation actions.
57
-
58
- ### Action Space
59
-
60
- | Action | Parameters | Destructive | Description |
61
- |---|---|---|---|
62
- | `inspect_flight_recorder` | `rank_id: int` | No | Get PyTorch Flight Recorder data for a rank |
63
- | `query_nccl_logs` | `time_window: int` | No | Get NCCL communication log entries |
64
- | `topo_reorder` | `affinity: str` | No | Reorder ring topology (use "rack" for fix) |
65
- | `patch_divergent_code` | `file: str, fix_type: str` | No | Patch desynchronized code |
66
- | `restart_rank` | `rank_id: int` | **Yes** | Restart a specific rank (-0.2 penalty) |
67
- | `reset_ib_interface` | `node_id: int` | **Yes** | Reset IB interface (-0.2 penalty) |
68
- | `adjust_sharding_strategy` | `strategy: str` | No | Change sharding strategy |
69
- | `noop` | none | No | Take no action |
70
-
71
- ---
72
 
73
  ## Tasks
74
 
75
- ### Easy Culprit Rank Identification
76
- **Difficulty:** Easy
77
-
78
- Training is stalled. A NCCL watchdog timeout has fired
79
- across all 8 nodes. One rank failed to join a collective
80
- operation due to an OOM error (XID 79).
81
-
82
- The agent must use `inspect_flight_recorder(rank_id)` to
83
- examine each rank's Flight Recorder buffer and identify
84
- which rank has a stalled collective sequence.
85
-
86
- **Grader:** 1.0 for correct rank identified, 0.0 otherwise.
87
- Efficiency bonus up to +0.2 for early diagnosis.
88
- Penalty -0.1 per destructive action taken.
89
-
90
- **Anti-cheat:** The failing rank is randomly seeded on
91
- every `reset()` call. Hardcoding a rank ID scores 0.0.
92
-
93
- ---
94
-
95
- ### Medium — Spine Switch Congestion Resolution
96
- **Difficulty:** Medium
97
 
98
- Training is running but at 55-65% of target throughput.
99
- The ring topology stretches across oversubscribed spine
100
- switches. The agent must call `topo_reorder(affinity="rack")`
101
- to enforce rack-local communication.
102
 
103
- **Grader:** Continuous score based on
104
- `throughput / target_throughput` (0.0 to 1.0).
105
- Bonus +0.15 for sustaining recovery for 5+ steps.
106
- Penalty -0.2 per destructive action.
107
-
108
- ---
109
-
110
- ### Hard — Asymmetric Compilation Desync Fix
111
- **Difficulty:** Hard
112
-
113
- Training is completely hung. Different ranks compiled
114
- different NCCL collectives due to data-dependent branching
115
- in the model code. The job will never recover on its own.
116
-
117
- The agent must:
118
- 1. Investigate using `query_nccl_logs` or
119
- `inspect_flight_recorder`
120
- 2. Identify the divergent source file using
121
- `patch_divergent_code(file=..., fix_type=...)`
122
- 3. Verify training resumes for 5+ steps
123
-
124
- **Grader:** 3-stage scoring:
125
- - 0.3 for identifying the correct file
126
- - +0.4 for applying the correct patch
127
- - +0.3 for sustained training recovery (5+ steps)
128
- = 1.0 maximum
129
-
130
- ---
131
-
132
- ## Reward Function
133
-
134
- Rewards are continuous — the agent receives signal at
135
- every step, not just at episode end.
136
-
137
- | Situation | Reward |
138
- |---|---|
139
- | Correct rank identified (easy) | +0.5 |
140
- | Investigation action taken | +0.05 |
141
- | Throughput improvement (medium) | proportional to ratio |
142
- | Correct file identified (hard) | +0.3 |
143
- | Correct patch applied (hard) | +0.7 cumulative |
144
- | Training recovered 5+ steps | +0.3 |
145
- | Destructive action taken | -0.2 |
146
- | Noop | 0.0 |
147
-
148
- ---
149
-
150
- ## Setup and Usage
151
 
152
- ### Run with Docker
153
- ```bash
154
- # Build
155
- docker build -t nervousystem-env .
156
 
157
- # Run
158
- docker run -p 7860:7860 nervousystem-env
159
 
160
- # Verify
161
- curl http://localhost:7860/health
162
- ```
163
-
164
- ### Run locally
165
  ```bash
166
- # Install dependencies
167
  pip install -r requirements.txt
168
 
169
- # Start server
170
  uvicorn app.main:app --host 0.0.0.0 --port 7860
171
 
172
- # In another terminal, run inference
173
- export API_BASE_URL=https://router.huggingface.co/v1
174
- export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
175
- export HF_TOKEN=your_token_here
176
- export ENV_BASE_URL=http://localhost:7860
177
- python inference.py
178
- ```
179
-
180
- ### Environment Variables
181
 
182
- | Variable | Required | Description |
183
- |---|---|---|
184
- | `API_BASE_URL` | Yes | LLM API endpoint |
185
- | `MODEL_NAME` | Yes | Model identifier |
186
- | `HF_TOKEN` | Yes | HuggingFace API token |
187
- | `ENV_BASE_URL` | No | Env server URL (default: http://localhost:7860) |
188
 
189
- ---
 
 
190
 
191
  ## API Endpoints
192
 
193
  | Endpoint | Method | Description |
194
  |---|---|---|
195
  | `/health` | GET | Health check |
196
- | `/reset` | POST | Start new episode |
197
- | `/step` | POST | Take an action |
198
- | `/state` | GET | Get current observation |
199
- | `/grade` | POST | Get episode score |
200
- | `/tasks` | GET | List available tasks |
 
201
 
202
- ---
203
-
204
- ## Baseline Scores
205
 
206
- Scores produced by running inference.py with
207
- `meta-llama/Llama-3.1-8B-Instruct` (seed=42):
 
 
208
 
209
- | Task | Score | Passed |
210
- |---|---|---|
211
- | easy | TBD | TBD |
212
- | medium | TBD | TBD |
213
- | hard | TBD | TBD |
214
 
215
- *Run `python inference.py` to reproduce scores.
216
- Scores will be updated after HF Space deployment.*
 
 
 
 
217
 
218
- ---
219
 
220
  ## Project Structure
221
- ```
 
222
  nervousystem-env/
223
  ├── app/
224
- │ ├── main.py # FastAPI endpoints
225
- │ ├── env.py # Environment core logic
226
- │ ├── models.py # Pydantic typed models
227
- │ └── config.py # Scenarios and constants
228
- ├── simulation/
229
- │ ├── cluster.py # GPU cluster state machine
230
- │ ├── failures.py # Failure injection
231
- │ └── telemetry.py # Log generation
232
- ├── tasks/
233
- │ ├── easy.py # Culprit rank identification
234
- │ ├── medium.py # Congestion resolution
235
- │ └── hard.py # Desync fix
236
  ├── graders/
 
 
237
  │ ├── easy_grader.py
238
- │ ├── medium_grader.py
239
- │ └── hard_grader.py
240
- ├── inference.py # Baseline agent script
241
- ├── Dockerfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  ├── openenv.yaml
243
- ── pyproject.toml
 
 
244
  ```
245
 
246
- ---
247
-
248
  ## OpenEnv Compliance
249
 
250
- - `openenv validate` passes
251
- - Typed Pydantic v2 models
252
- - Deterministic graders
253
- - Docker deployment
254
- - 3 tasks with difficulty progression
 
 
 
23
  ## Why This Matters
24
 
25
  Large-scale AI training runs on clusters of hundreds of
26
+ # 🧠 NervousSystem-Env
27
 
28
+ > An AI agent fixing the infrastructure that trains AI.
29
+ > Every minute of cluster downtime wastes $5,000 in compute.
 
30
 
31
+ ## The Problem
 
 
32
 
33
+ Large-scale AI training across 1000+ GPU clusters fails constantly due to hardware faults, network bottlenecks, distributed synchronization bugs, and runtime version drift. Human SREs are forced to diagnose these incidents at 3am under extreme time pressure. NervousSystem-Env turns that operational pain into a training environment where autonomous agents learn to detect failures, route work to specialist workers, and recover jobs before expensive downtime compounds.
34
 
35
+ ## Why This Matters
36
 
37
+ - GPU OOM (XID 79): stalls entire training job.
38
+ - Spine switch congestion: cuts throughput 40%+.
39
+ - Compilation desync: hangs job permanently.
40
+ - LD_LIBRARY_PATH cascade: Severity-1 fleet-wide incident.
41
 
42
+ ## Architecture
43
 
44
+ NervousSystem-Env uses a Fleet AI Supervisor-Worker design. A supervisor agent receives global cluster state and delegates targeted sub-tasks to specialist workers via `/delegate`. Workers return structured results with confidence and coordination reward signals, enabling multi-agent training for routing, diagnosis, and remediation.
45
 
46
+ ```text
47
+ Supervisor Agent
48
+
49
+ ├── LogInspectorWorker (flight recorder, NCCL logs)
50
+ ├── PatchAgentWorker (code patching, verification)
51
+ ├── TopoAgentWorker (topology, bandwidth)
52
+ └── VersionCheckerWorker (NCCL version, LD_LIBRARY_PATH)
53
+ ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  ## Tasks
56
 
57
+ | Task | Difficulty | Max Steps | Failure Type | Key Actions |
58
+ |---|---:|---:|---|---|
59
+ | easy | easy | 50 | OOM rank failure | `inspect_flight_recorder` |
60
+ | medium | medium | 50 | network congestion | `topo_reorder(affinity="rack")` |
61
+ | hard | hard | 50 | collective desync | `query_nccl_logs`, `patch_divergent_code` |
62
+ | cascade | cascade | 120 | version cascade (OOM→congestion→desync) | ordered multi-phase recovery |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ ## Reward Model
 
 
 
65
 
66
+ ```text
67
+ Reward = 0.60 * R_success + 0.30 * R_subgoal - 0.10 * log(total_tokens)
68
+ ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ - `R_success`: binary completion signal (recovered/running within step limit).
71
+ - `R_subgoal`: continuous task-progress score.
72
+ - `log(total_tokens)`: efficiency penalty to discourage verbose reasoning.
 
73
 
74
+ ## Quick Start
 
75
 
 
 
 
 
 
76
  ```bash
77
+ # Install
78
  pip install -r requirements.txt
79
 
80
+ # Start environment server
81
  uvicorn app.main:app --host 0.0.0.0 --port 7860
82
 
83
+ # Start war room dashboard
84
+ python dashboard/war_room.py
 
 
 
 
 
 
 
85
 
86
+ # Run baseline agent
87
+ python inference.py
 
 
 
 
88
 
89
+ # Train with GRPO
90
+ python training/grpo_train.py
91
+ ```
92
 
93
  ## API Endpoints
94
 
95
  | Endpoint | Method | Description |
96
  |---|---|---|
97
  | `/health` | GET | Health check |
98
+ | `/reset` | POST | Reset episode by task and seed |
99
+ | `/step` | POST | Apply one SRE action |
100
+ | `/state` | GET | Fetch current observation |
101
+ | `/grade` | POST | Grade current episode |
102
+ | `/tasks` | GET | List task metadata |
103
+ | `/delegate` | POST | Supervisor delegates to worker agent |
104
 
105
+ ## Hackathon Themes
 
 
106
 
107
+ - Theme 1 (Fleet AI): Supervisor-Worker with `/delegate` endpoint.
108
+ - Theme 2 (Long-Horizon): Cascade task (120 steps), Mercor reward shaping.
109
+ - Theme 3.1 (Professional Tasks): NCCL diagnostics + Flight Recorder v2.5 workflow.
110
+ - Theme 4 (Self-Improvement): Adversarial curriculum via seeded failure permutations.
111
 
112
+ ## Training Results
 
 
 
 
113
 
114
+ | Task | Baseline | Trained | Improvement |
115
+ |---|---:|---:|---:|
116
+ | easy | TBD | TBD | TBD |
117
+ | medium | TBD | TBD | TBD |
118
+ | hard | TBD | TBD | TBD |
119
+ | cascade | TBD | TBD | TBD |
120
 
121
+ Run `python training/grpo_train.py` to reproduce.
122
 
123
  ## Project Structure
124
+
125
+ ```text
126
  nervousystem-env/
127
  ├── app/
128
+ │ ├── config.py
129
+ │ ├── env.py
130
+ │ ├── main.py
131
+ │ └── models.py
 
 
 
 
 
 
 
 
132
  ├── graders/
133
+ │ ├── base.py
134
+ │ ├── cascade_grader.py
135
  │ ├── easy_grader.py
136
+ │ ├── hard_grader.py
137
+ │ └── medium_grader.py
138
+ ├── simulation/
139
+ ├── cluster.py
140
+ │ ├── failures.py
141
+ │ ├── fleet.py
142
+ │ └── telemetry.py
143
+ ├── tasks/
144
+ │ ├── base.py
145
+ │ ├── cascade.py
146
+ │ ├── easy.py
147
+ │ ├── hard.py
148
+ │ └── medium.py
149
+ ├── dashboard/
150
+ │ ├── README.md
151
+ │ └── war_room.py
152
+ ├── training/
153
+ │ └── grpo_train.py
154
+ ├── tests/
155
+ │ ├── test_fleet.py
156
+ │ └── test_graders.py
157
+ ├── inference.py
158
  ├── openenv.yaml
159
+ ── requirements.txt
160
+ └── server/
161
+ └── app.py
162
  ```
163
 
 
 
164
  ## OpenEnv Compliance
165
 
166
+ - `openenv validate` passes.
167
+ - Typed Pydantic v2 models.
168
+ - Deterministic graders.
169
+ - Docker deployment.
170
+ - 4 tasks with difficulty progression.
171
+ - Multi-agent `/delegate` endpoint.
172
+ # In another terminal, run inference
dashboard/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SRE War Room Dashboard
2
+
3
+ Run the environment server first on `7860`, then launch the Gradio dashboard on `7861`.
4
+
5
+ ```bash
6
+ uvicorn app.main:app --host 0.0.0.0 --port 7860
7
+ python dashboard/war_room.py
8
+ ```
9
+
10
+ Optional custom server URL:
11
+
12
+ ```bash
13
+ ENV_BASE_URL=http://localhost:7860 python dashboard/war_room.py
14
+ ```
dashboard/war_room.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from datetime import datetime
6
+
7
+ import gradio as gr
8
+ import requests
9
+
10
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
11
+
12
+
13
+ def render_ring(nodes: list[dict]) -> str:
14
+ """Return an HTML string with 8 colored divs arranged in a circle."""
15
+ health_to_color = {
16
+ "healthy": "#22c55e",
17
+ "degraded": "#eab308",
18
+ "failed": "#ef4444",
19
+ }
20
+ health_to_emoji = {
21
+ "healthy": "🟢",
22
+ "degraded": "🟡",
23
+ "failed": "🔴",
24
+ }
25
+
26
+ padded_nodes = list(nodes[:8])
27
+ while len(padded_nodes) < 8:
28
+ padded_nodes.append(
29
+ {
30
+ "node_id": len(padded_nodes),
31
+ "health_status": "failed",
32
+ "gpu_memory_used_mb": 0,
33
+ "xid_errors": [],
34
+ }
35
+ )
36
+
37
+ cards: list[str] = []
38
+ for index, node in enumerate(padded_nodes):
39
+ health = str(node.get("health_status", "failed"))
40
+ color = health_to_color.get(health, "#ef4444")
41
+ emoji = health_to_emoji.get(health, "🔴")
42
+ node_id = node.get("node_id", index)
43
+ gpu_mem = node.get("gpu_memory_used_mb", 0)
44
+ xid_errors = node.get("xid_errors", [])
45
+ xid_text = ",".join(str(code) for code in xid_errors) if xid_errors else "none"
46
+ angle = index * 45
47
+ cards.append(
48
+ f"""
49
+ <div class='node-card' style='background:{color};
50
+ transform: rotate({angle}deg) translate(155px) rotate(-{angle}deg);'>
51
+ <div><strong>{emoji} node {node_id}</strong></div>
52
+ <div>health: {health}</div>
53
+ <div>gpu_mem: {float(gpu_mem):.0f} MB</div>
54
+ <div>xid: {xid_text}</div>
55
+ </div>
56
+ """
57
+ )
58
+
59
+ return f"""
60
+ <style>
61
+ .ring-wrap {{
62
+ position: relative;
63
+ width: 420px;
64
+ height: 420px;
65
+ margin: 0 auto;
66
+ border-radius: 50%;
67
+ background: radial-gradient(circle, #0b1220 0%, #111827 65%, #1f2937 100%);
68
+ border: 1px solid #374151;
69
+ }}
70
+ .ring-center {{
71
+ position: absolute;
72
+ left: 50%; top: 50%;
73
+ transform: translate(-50%, -50%);
74
+ color: #d1d5db;
75
+ font-weight: 700;
76
+ font-size: 14px;
77
+ }}
78
+ .node-card {{
79
+ position: absolute;
80
+ left: 50%;
81
+ top: 50%;
82
+ width: 132px;
83
+ min-height: 72px;
84
+ margin-left: -66px;
85
+ margin-top: -36px;
86
+ border-radius: 10px;
87
+ padding: 8px;
88
+ color: #111827;
89
+ box-shadow: 0 6px 20px rgba(0,0,0,0.25);
90
+ font-size: 11px;
91
+ line-height: 1.2;
92
+ }}
93
+ </style>
94
+ <div class='ring-wrap'>
95
+ <div class='ring-center'>Cluster Ring</div>
96
+ {''.join(cards)}
97
+ </div>
98
+ """
99
+
100
+
101
+ def _safe_get(path: str) -> dict | None:
102
+ try:
103
+ response = requests.get(f"{ENV_BASE_URL}{path}", timeout=5)
104
+ response.raise_for_status()
105
+ return response.json()
106
+ except Exception:
107
+ return None
108
+
109
+
110
+ def _safe_post(path: str, payload: dict) -> dict | None:
111
+ try:
112
+ response = requests.post(f"{ENV_BASE_URL}{path}", json=payload, timeout=8)
113
+ response.raise_for_status()
114
+ return response.json()
115
+ except Exception:
116
+ return None
117
+
118
+
119
+ def _offline_panel(action_log: list[list]) -> tuple:
120
+ offline_row = [["-", "offline", 0.0, 0.0, "⚠️ Server offline"]]
121
+ return (
122
+ "<h3>⚠️ Server offline</h3>",
123
+ "⚠️ Server offline",
124
+ 0.0,
125
+ 0.0,
126
+ 0.0,
127
+ 0.0,
128
+ 0.0,
129
+ "## ⚠️ Server offline",
130
+ action_log[-20:] if action_log else offline_row,
131
+ action_log,
132
+ )
133
+
134
+
135
+ def _panel_from_state(state: dict, action_log: list[list]) -> tuple:
136
+ nodes = state.get("nodes", [])
137
+ training = state.get("training", {})
138
+ throughput = float(training.get("throughput_tokens_per_sec", 0.0))
139
+ target = float(training.get("target_throughput", 1.0))
140
+ stalled_steps = float(training.get("stalled_steps", 0.0))
141
+ status = str(training.get("job_status", "unknown"))
142
+ cumulative_tokens = float(state.get("cumulative_tokens", 0))
143
+ throughput_pct = (throughput / max(1.0, target)) * 100.0
144
+ simulated_loss_prevented = stalled_steps * 83.33
145
+ loss_text = f"## 💰 ${simulated_loss_prevented:,.2f} saved"
146
+
147
+ return (
148
+ render_ring(nodes),
149
+ status,
150
+ throughput,
151
+ throughput_pct,
152
+ stalled_steps,
153
+ cumulative_tokens,
154
+ simulated_loss_prevented,
155
+ loss_text,
156
+ action_log[-20:],
157
+ action_log,
158
+ )
159
+
160
+
161
+ def refresh_panels(task_id: str, action_log: list[list]) -> tuple:
162
+ """Refresh dashboard panels from live server state."""
163
+ _ = task_id
164
+ state = _safe_get("/state")
165
+ if state is None:
166
+ return _offline_panel(action_log)
167
+ return _panel_from_state(state, action_log)
168
+
169
+
170
+ def reset_episode(task_id: str) -> tuple:
171
+ """Reset episode for selected task and clear action log."""
172
+ result = _safe_post("/reset", {"task_id": task_id})
173
+ if result is None:
174
+ offline = _offline_panel([])
175
+ return (*offline, gr.update(active=True))
176
+ panel = _panel_from_state(result, [])
177
+ return (*panel, gr.update(active=True))
178
+
179
+
180
+ def _demo_actions() -> list[dict]:
181
+ return [
182
+ {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 0}},
183
+ {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 1}},
184
+ {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 2}},
185
+ {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}},
186
+ {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}},
187
+ {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}},
188
+ {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}},
189
+ {"action_type": "noop", "parameters": {}},
190
+ {"action_type": "noop", "parameters": {}},
191
+ {"action_type": "noop", "parameters": {}},
192
+ ]
193
+
194
+
195
+ def run_demo_agent(task_id: str, action_log: list[list]) -> tuple:
196
+ """Run exactly 10 hardcoded demo steps for visualization."""
197
+ _ = _safe_post("/reset", {"task_id": task_id})
198
+ rows = list(action_log)
199
+
200
+ for action in _demo_actions():
201
+ step_result = _safe_post("/step", action)
202
+ if step_result is None:
203
+ return (*_offline_panel(rows), gr.update(active=True))
204
+ reward = step_result.get("reward", {})
205
+ observation = step_result.get("observation", {})
206
+ step_num = observation.get("step_count", len(rows) + 1)
207
+ row = [
208
+ step_num,
209
+ action["action_type"],
210
+ float(reward.get("value", 0.0)),
211
+ float(reward.get("token_efficiency_score", 0.0)),
212
+ f"{reward.get('info', '')} @ {datetime.utcnow().isoformat(timespec='seconds')}",
213
+ ]
214
+ rows.append(row)
215
+
216
+ state = _safe_get("/state")
217
+ if state is None:
218
+ return (*_offline_panel(rows), gr.update(active=True))
219
+ panel = _panel_from_state(state, rows)
220
+ return (*panel, gr.update(active=True))
221
+
222
+
223
+ def stop_refresh():
224
+ """Stop the auto-refresh timer."""
225
+ return gr.update(active=False)
226
+
227
+
228
+ def delegate_task(worker: str, action: str) -> dict:
229
+ """Submit delegation request to /delegate endpoint."""
230
+ payload = {
231
+ "worker": worker,
232
+ "action": action,
233
+ "parameters": {},
234
+ "supervisor_reasoning": f"War Room delegation at {datetime.utcnow().isoformat()}",
235
+ "token_count": 0,
236
+ }
237
+ result = _safe_post("/delegate", payload)
238
+ if result is None:
239
+ return {
240
+ "worker": worker,
241
+ "action": action,
242
+ "success": False,
243
+ "output": {"error": "⚠️ Server offline"},
244
+ "confidence": 0.0,
245
+ "coordination_reward": 0.0,
246
+ "explanation": "⚠️ Server offline",
247
+ "cumulative_coordination_reward": 0.0,
248
+ "raw": json.dumps(payload),
249
+ }
250
+ return result
251
+
252
+
253
+ with gr.Blocks(title="SRE War Room") as demo:
254
+ gr.Markdown("# 🛠️ SRE War Room")
255
+ gr.Markdown(f"Connected env: `{ENV_BASE_URL}`")
256
+
257
+ with gr.Row():
258
+ task_dropdown = gr.Dropdown(
259
+ choices=["easy", "medium", "hard", "cascade"],
260
+ value="easy",
261
+ label="Task",
262
+ )
263
+ reset_btn = gr.Button("Reset Episode", variant="primary")
264
+ demo_btn = gr.Button("Run Demo Agent")
265
+ stop_btn = gr.Button("Stop")
266
+
267
+ with gr.Row():
268
+ with gr.Column(scale=2):
269
+ gr.Markdown("## Panel A: Cluster Ring Topology")
270
+ ring_html = gr.HTML(render_ring([]))
271
+ with gr.Column(scale=1):
272
+ gr.Markdown("## Panel B: Training Metrics")
273
+ job_status_label = gr.Label(label="job_status", value="unknown")
274
+ throughput_num = gr.Number(label="throughput_tokens_per_sec", value=0.0)
275
+ throughput_pct_num = gr.Number(label="throughput_%_of_target", value=0.0)
276
+ stalled_steps_num = gr.Number(label="stalled_steps", value=0.0)
277
+ cumulative_tokens_num = gr.Number(label="cumulative_tokens", value=0.0)
278
+ loss_num = gr.Number(label="Simulated Loss Prevented $", value=0.0)
279
+ loss_text = gr.Markdown("## 💰 $0.00 saved")
280
+
281
+ with gr.Row():
282
+ with gr.Column(scale=3):
283
+ gr.Markdown("## Panel C: Agent Action Log")
284
+ action_df = gr.Dataframe(
285
+ headers=["step", "action_type", "reward", "mer_score", "info"],
286
+ value=[],
287
+ row_count=20,
288
+ column_count=(5, "fixed"),
289
+ datatype=["number", "str", "number", "number", "str"],
290
+ wrap=True,
291
+ )
292
+ with gr.Column(scale=2):
293
+ gr.Markdown("## Fleet Delegation")
294
+ worker_dropdown = gr.Dropdown(
295
+ choices=["log_inspector", "patch_agent", "topo_agent", "version_checker"],
296
+ value="log_inspector",
297
+ label="worker",
298
+ )
299
+ delegation_action = gr.Textbox(value="check_nccl_version", label="action")
300
+ delegate_btn = gr.Button("Delegate")
301
+ delegate_json = gr.JSON(label="Last delegation result")
302
+
303
+ action_log_state = gr.State([])
304
+ refresh_timer = gr.Timer(value=2.0, active=True)
305
+
306
+ refresh_timer.tick(
307
+ fn=refresh_panels,
308
+ inputs=[task_dropdown, action_log_state],
309
+ outputs=[
310
+ ring_html,
311
+ job_status_label,
312
+ throughput_num,
313
+ throughput_pct_num,
314
+ stalled_steps_num,
315
+ cumulative_tokens_num,
316
+ loss_num,
317
+ loss_text,
318
+ action_df,
319
+ action_log_state,
320
+ ],
321
+ )
322
+
323
+ reset_btn.click(
324
+ fn=reset_episode,
325
+ inputs=[task_dropdown],
326
+ outputs=[
327
+ ring_html,
328
+ job_status_label,
329
+ throughput_num,
330
+ throughput_pct_num,
331
+ stalled_steps_num,
332
+ cumulative_tokens_num,
333
+ loss_num,
334
+ loss_text,
335
+ action_df,
336
+ action_log_state,
337
+ refresh_timer,
338
+ ],
339
+ )
340
+
341
+ demo_btn.click(
342
+ fn=run_demo_agent,
343
+ inputs=[task_dropdown, action_log_state],
344
+ outputs=[
345
+ ring_html,
346
+ job_status_label,
347
+ throughput_num,
348
+ throughput_pct_num,
349
+ stalled_steps_num,
350
+ cumulative_tokens_num,
351
+ loss_num,
352
+ loss_text,
353
+ action_df,
354
+ action_log_state,
355
+ refresh_timer,
356
+ ],
357
+ )
358
+
359
+ stop_btn.click(fn=stop_refresh, outputs=[refresh_timer])
360
+
361
+ delegate_btn.click(
362
+ fn=delegate_task,
363
+ inputs=[worker_dropdown, delegation_action],
364
+ outputs=[delegate_json],
365
+ )
366
+
367
+
368
+ if __name__ == "__main__":
369
+ demo.launch(server_port=7861, share=False)
inference.py CHANGED
@@ -69,7 +69,7 @@ MODEL_NAME = os.getenv(
69
  )
70
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "")
71
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
72
- MAX_STEPS = 15
73
  TEMPERATURE = 0.1
74
  MAX_TOKENS = 300
75
  SEED = 42
@@ -100,6 +100,9 @@ Rules:
100
  - Use query_nccl_logs to see communication errors.
101
  - Avoid restart_rank unless absolutely necessary — it is destructive.
102
  - If you already know the failing rank, fix it directly.
 
 
 
103
 
104
  Example response:
105
  {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 3}}
 
69
  )
70
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "")
71
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
72
+ MAX_STEPS = 20
73
  TEMPERATURE = 0.1
74
  MAX_TOKENS = 300
75
  SEED = 42
 
100
  - Use query_nccl_logs to see communication errors.
101
  - Avoid restart_rank unless absolutely necessary — it is destructive.
102
  - If you already know the failing rank, fix it directly.
103
+ - For cascade failures: solve phases in order. Phase 1=OOM diagnosis,
104
+ Phase 2=topo_reorder, Phase 3=query_nccl_logs then patch_divergent_code
105
+ - Token efficiency matters: fewer tokens = higher reward
106
 
107
  Example response:
108
  {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 3}}
openenv.yaml CHANGED
@@ -1,43 +1,80 @@
1
  name: nervousystem-env
2
- version: "1.0.0"
3
  description: >
4
- SRE environment for diagnosing and fixing failures
5
- in a distributed GPU training cluster. The cluster
6
- is training a large-scale AI model. Agents must
7
- investigate, diagnose, and repair the system using
8
- real SRE workflows.
9
  author: v4xsh
10
  tags:
11
  - openenv
12
  - sre
13
  - distributed-training
14
- - gpu
15
- - infrastructure
16
- - hpc
17
- entry_point: "app.main:app"
 
 
18
  tasks:
19
  - id: easy
20
  name: "Culprit Rank Identification"
21
  difficulty: easy
 
22
  description: >
23
- Training is stalled due to an OOM failure on one rank.
24
- Identify the failing rank using Flight Recorder inspection.
25
  - id: medium
26
  name: "Spine Switch Congestion Resolution"
27
  difficulty: medium
 
28
  description: >
29
- Training throughput is degraded due to network congestion.
30
- Reorder the ring topology to restore bandwidth.
31
  - id: hard
32
  name: "Asymmetric Compilation Desync Fix"
33
  difficulty: hard
 
34
  description: >
35
- Training is hung due to different ranks compiling different
36
- NCCL collectives. Find and patch the divergent code.
 
 
 
 
 
 
 
 
37
  observation_space:
38
  type: object
39
- description: "ClusterObservation with node health, training metrics, and surface logs"
 
 
40
  action_space:
41
  type: object
42
- description: "SREAction with action_type and parameters dict"
 
 
 
 
43
  reward_range: [0.0, 1.0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  name: nervousystem-env
2
+ version: "2.0.0"
3
  description: >
4
+ Fleet AI environment for autonomous SRE agents managing distributed
5
+ GPU training clusters. Agents act as supervisors orchestrating
6
+ specialized worker agents to diagnose and fix cascading failures
7
+ across 1000+ GPU clusters. Every minute of downtime costs $5,000
8
+ in wasted compute.
9
  author: v4xsh
10
  tags:
11
  - openenv
12
  - sre
13
  - distributed-training
14
+ - fleet-ai
15
+ - multi-agent
16
+ - long-horizon
17
+ - mercor
18
+ - gpu-infrastructure
19
+ entry_point: "server.app:app"
20
  tasks:
21
  - id: easy
22
  name: "Culprit Rank Identification"
23
  difficulty: easy
24
+ max_steps: 50
25
  description: >
26
+ Training stalled by OOM on one rank. Identify the failing rank
27
+ using PyTorch 2.5 Flight Recorder inspection.
28
  - id: medium
29
  name: "Spine Switch Congestion Resolution"
30
  difficulty: medium
31
+ max_steps: 50
32
  description: >
33
+ Training throughput degraded to 45-65% target due to spine switch
34
+ congestion. Reorder ring topology to restore bandwidth.
35
  - id: hard
36
  name: "Asymmetric Compilation Desync Fix"
37
  difficulty: hard
38
+ max_steps: 50
39
  description: >
40
+ Training hung due to different ranks compiling different NCCL
41
+ collectives. Investigate and patch the divergent source file.
42
+ - id: cascade
43
+ name: "Inter-Version Cascade"
44
+ difficulty: cascade
45
+ max_steps: 120
46
+ description: >
47
+ Severity-1 incident: LD_LIBRARY_PATH corruption loads wrong NCCL
48
+ version (2.21.5 vs 2.27.0), triggering a cascade of OOM →
49
+ congestion → desync across the fleet. Solve all 3 phases in order.
50
  observation_space:
51
  type: object
52
+ description: >
53
+ ClusterObservation with 8-node health states, training metrics,
54
+ surface NCCL logs, step count, episode id, and cumulative token count.
55
  action_space:
56
  type: object
57
+ description: >
58
+ SREAction with action_type and parameters dict. 8 action types
59
+ including inspect_flight_recorder, query_nccl_logs, topo_reorder,
60
+ patch_divergent_code, restart_rank, reset_ib_interface,
61
+ adjust_sharding_strategy, noop.
62
  reward_range: [0.0, 1.0]
63
+ reward_description: >
64
+ Mercor-style efficiency reward: 0.60*R_success + 0.30*R_subgoal
65
+ - 0.10*log(total_tokens). Rewards accurate diagnosis with minimal
66
+ token usage. Destructive actions penalized -0.2.
67
+ multi_agent:
68
+ enabled: true
69
+ architecture: "supervisor-worker"
70
+ workers:
71
+ - log_inspector
72
+ - patch_agent
73
+ - topo_agent
74
+ - version_checker
75
+ endpoint: "/delegate"
76
+ training:
77
+ algorithm: GRPO
78
+ framework: "TRL + Unsloth"
79
+ script: "training/grpo_train.py"
80
+ model: "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"
requirements.txt CHANGED
@@ -1,9 +1,7 @@
1
- fastapi
2
- uvicorn[standard]
3
- pydantic>=2.0
4
- openai
5
- pyyaml
6
- numpy
7
- pytest
8
- requests
9
- openenv-core>=0.2.0
 
1
+ fastapi>=0.111.0
2
+ uvicorn>=0.29.0
3
+ pydantic>=2.7.0
4
+ requests>=2.31.0
5
+ gradio>=4.31.0
6
+ datasets>=2.19.0
7
+ openai>=1.30.0
 
 
training/grpo_train.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # NervousSystem-Env — GRPO Training Script
3
+ # ============================================================
4
+ # Colab setup (run these first):
5
+ # !pip install unsloth trl datasets transformers accelerate
6
+ # !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
7
+ # !uvicorn app.main:app --port 7860 & # start env server
8
+ # ============================================================
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import math
14
+ import os
15
+ import random
16
+ import re
17
+ from typing import Any
18
+
19
+ import requests
20
+ import torch
21
+ from datasets import Dataset
22
+ from trl import GRPOConfig, GRPOTrainer
23
+ from unsloth import FastLanguageModel
24
+
25
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
26
+ MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct-bnb-4bit")
27
+ MAX_SEQ_LENGTH = 1024
28
+ LORA_RANK = 16
29
+
30
+ SRE_SYSTEM_PROMPT = """You are an SRE agent managing a distributed
31
+ GPU training cluster. Diagnose and fix failures efficiently.
32
+
33
+ IMPORTANT: You are penalized for using too many tokens.
34
+ Reason concisely. Identify the failure type first, then act directly.
35
+
36
+ Available actions (respond with JSON only):
37
+ {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": <0-7>}}
38
+ {"action_type": "query_nccl_logs", "parameters": {"time_window": <int>}}
39
+ {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}
40
+ {"action_type": "patch_divergent_code", "parameters": {"file": "<path>", "fix_type": "synchronize_conditional"}}
41
+ {"action_type": "noop", "parameters": {}}
42
+
43
+ Rules:
44
+ - Respond ONLY with a JSON object, no explanation
45
+ - Check job_status first: stalled=investigate, running=optimize
46
+ - Use inspect_flight_recorder to find failing ranks
47
+ - Use topo_reorder(affinity="rack") for congestion
48
+ """
49
+
50
+ _current_task_id = "easy"
51
+ _prompt_task_map: dict[str, str] = {}
52
+
53
+
54
+ model, tokenizer = FastLanguageModel.from_pretrained(
55
+ model_name=MODEL_NAME,
56
+ max_seq_length=MAX_SEQ_LENGTH,
57
+ load_in_4bit=True,
58
+ dtype=None,
59
+ )
60
+ model = FastLanguageModel.get_peft_model(
61
+ model,
62
+ r=LORA_RANK,
63
+ target_modules=[
64
+ "q_proj",
65
+ "v_proj",
66
+ "k_proj",
67
+ "o_proj",
68
+ "gate_proj",
69
+ "up_proj",
70
+ "down_proj",
71
+ ],
72
+ lora_alpha=16,
73
+ lora_dropout=0,
74
+ bias="none",
75
+ use_gradient_checkpointing="unsloth",
76
+ random_state=42,
77
+ )
78
+
79
+
80
+ def _safe_post(path: str, payload: dict[str, Any], timeout: int = 10) -> dict[str, Any] | None:
81
+ try:
82
+ response = requests.post(f"{ENV_BASE_URL}{path}", json=payload, timeout=timeout)
83
+ response.raise_for_status()
84
+ return response.json()
85
+ except Exception:
86
+ return None
87
+
88
+
89
+ def _safe_get(path: str, timeout: int = 5) -> dict[str, Any] | None:
90
+ try:
91
+ response = requests.get(f"{ENV_BASE_URL}{path}", timeout=timeout)
92
+ response.raise_for_status()
93
+ return response.json()
94
+ except Exception:
95
+ return None
96
+
97
+
98
+ def _prompt_key(prompt: Any) -> str:
99
+ try:
100
+ return json.dumps(prompt, sort_keys=True)
101
+ except Exception:
102
+ return str(prompt)
103
+
104
+
105
+ def _task_id_from_prompt(prompt: Any) -> str:
106
+ global _current_task_id
107
+ key = _prompt_key(prompt)
108
+ task_id = _prompt_task_map.get(key, _current_task_id)
109
+ _current_task_id = task_id
110
+ return task_id
111
+
112
+
113
+ def _extract_json_action(completion: str) -> dict[str, Any] | None:
114
+ text = completion.strip()
115
+ if text.startswith("```"):
116
+ text = "\n".join(line for line in text.splitlines() if not line.strip().startswith("```"))
117
+ try:
118
+ parsed = json.loads(text)
119
+ if isinstance(parsed, dict) and "action_type" in parsed:
120
+ parsed.setdefault("parameters", {})
121
+ return parsed
122
+ except Exception:
123
+ pass
124
+
125
+ match = re.search(r"\{.*\}", text, flags=re.DOTALL)
126
+ if not match:
127
+ return None
128
+ try:
129
+ parsed = json.loads(match.group(0))
130
+ if isinstance(parsed, dict) and "action_type" in parsed:
131
+ parsed.setdefault("parameters", {})
132
+ return parsed
133
+ except Exception:
134
+ return None
135
+ return None
136
+
137
+
138
+ def make_sre_dataset(n_samples: int = 200) -> Dataset:
139
+ """
140
+ Generate prompt-only dataset for GRPO.
141
+ Each sample is one initial observation from the env.
142
+ GRPO generates completions and scores them via the reward fn.
143
+
144
+ For each sample:
145
+ - Pick task_id randomly from ["easy", "medium", "hard"]
146
+ (skip cascade for initial training — too long)
147
+ - Pick seed randomly from range(1000)
148
+ - Call POST /reset with task_id and seed
149
+ - Format the observation as the user prompt
150
+ - Return dataset with column "prompt" containing
151
+ [{"role": "system", "content": SRE_SYSTEM_PROMPT},
152
+ {"role": "user", "content": <observation_json>}]
153
+
154
+ observation_json format:
155
+ {
156
+ "job_status": ...,
157
+ "throughput": ...,
158
+ "target_throughput": ...,
159
+ "stalled_steps": ...,
160
+ "node_health": [...],
161
+ "visible_logs": [...],
162
+ "task_hint": "Diagnose and fix the cluster failure."
163
+ }
164
+ """
165
+ global _current_task_id
166
+
167
+ rows: list[dict[str, Any]] = []
168
+ task_pool = ["easy", "medium", "hard"]
169
+
170
+ for _ in range(n_samples):
171
+ task_id = random.choice(task_pool)
172
+ seed = random.randint(0, 999)
173
+ reset_result = _safe_post("/reset", {"task_id": task_id, "seed": seed})
174
+ if reset_result is None:
175
+ continue
176
+
177
+ training = reset_result.get("training", {})
178
+ nodes = reset_result.get("nodes", [])
179
+ observation_payload = {
180
+ "job_status": training.get("job_status", "unknown"),
181
+ "throughput": training.get("throughput_tokens_per_sec", 0.0),
182
+ "target_throughput": training.get("target_throughput", 0.0),
183
+ "stalled_steps": training.get("stalled_steps", 0),
184
+ "node_health": [
185
+ {
186
+ "node_id": node.get("node_id"),
187
+ "health_status": node.get("health_status"),
188
+ "xid_errors": node.get("xid_errors", []),
189
+ }
190
+ for node in nodes
191
+ ],
192
+ "visible_logs": reset_result.get("visible_logs", []),
193
+ "task_hint": "Diagnose and fix the cluster failure.",
194
+ }
195
+
196
+ prompt = [
197
+ {"role": "system", "content": SRE_SYSTEM_PROMPT},
198
+ {"role": "user", "content": json.dumps(observation_payload, ensure_ascii=False)},
199
+ ]
200
+
201
+ _current_task_id = task_id
202
+ _prompt_task_map[_prompt_key(prompt)] = task_id
203
+ rows.append({"prompt": prompt})
204
+
205
+ if not rows:
206
+ fallback_prompt = [
207
+ {"role": "system", "content": SRE_SYSTEM_PROMPT},
208
+ {
209
+ "role": "user",
210
+ "content": json.dumps(
211
+ {
212
+ "job_status": "stalled",
213
+ "throughput": 0.0,
214
+ "target_throughput": 9000.0,
215
+ "stalled_steps": 1,
216
+ "node_health": [],
217
+ "visible_logs": ["Server offline during dataset build"],
218
+ "task_hint": "Diagnose and fix the cluster failure.",
219
+ }
220
+ ),
221
+ },
222
+ ]
223
+ _prompt_task_map[_prompt_key(fallback_prompt)] = "easy"
224
+ rows.append({"prompt": fallback_prompt})
225
+
226
+ return Dataset.from_list(rows)
227
+
228
+
229
+ def sre_reward_fn(
230
+ completions: list[str],
231
+ prompts: list[Any],
232
+ **kwargs: Any,
233
+ ) -> list[float]:
234
+ """
235
+ Called by GRPOTrainer to score each completion.
236
+
237
+ For each completion:
238
+ 1. Parse the JSON action from the completion string
239
+ 2. POST the action to /step
240
+ 3. Extract reward.value and reward.token_efficiency_score
241
+ 4. Apply MER formula:
242
+ tokens = len(completion.split()) # word count proxy
243
+ mer = max(0.01, min(0.99,
244
+ 0.60 * r_success + 0.30 * step_reward - 0.10 * math.log(max(1, tokens))
245
+ ))
246
+ where r_success = 1.0 if job_status in {"recovered","running"} else 0.0
247
+ 5. Return mer as the reward for this completion
248
+
249
+ If parse fails or /step errors: return 0.01
250
+ If server is offline: return 0.01
251
+
252
+ IMPORTANT: Each call to sre_reward_fn must first call /reset
253
+ to get a fresh episode state before stepping.
254
+ """
255
+ rewards: list[float] = []
256
+
257
+ for index, completion in enumerate(completions):
258
+ prompt = prompts[index] if index < len(prompts) else None
259
+ task_id = _task_id_from_prompt(prompt)
260
+ seed = random.randint(0, 999)
261
+ reset_result = _safe_post("/reset", {"task_id": task_id, "seed": seed})
262
+ if reset_result is None:
263
+ rewards.append(0.01)
264
+ continue
265
+
266
+ action = _extract_json_action(completion)
267
+ if action is None:
268
+ rewards.append(0.01)
269
+ continue
270
+
271
+ step_result = _safe_post("/step", action)
272
+ if step_result is None:
273
+ rewards.append(0.01)
274
+ continue
275
+
276
+ reward_obj = step_result.get("reward", {})
277
+ step_reward = float(reward_obj.get("value", 0.01))
278
+ observation = step_result.get("observation", {})
279
+ job_status = str(observation.get("training", {}).get("job_status", "unknown"))
280
+ r_success = 1.0 if job_status in {"recovered", "running"} else 0.0
281
+
282
+ tokens = len(completion.split())
283
+ mer = max(
284
+ 0.01,
285
+ min(
286
+ 0.99,
287
+ 0.60 * r_success
288
+ + 0.30 * step_reward
289
+ - 0.10 * math.log(max(1, tokens)),
290
+ ),
291
+ )
292
+ rewards.append(float(mer))
293
+
294
+ while len(rewards) < len(completions):
295
+ rewards.append(0.01)
296
+ return rewards
297
+
298
+
299
+ training_args = GRPOConfig(
300
+ output_dir="./sre_grpo_output",
301
+ num_train_epochs=1,
302
+ per_device_train_batch_size=1,
303
+ gradient_accumulation_steps=4,
304
+ learning_rate=5e-6,
305
+ max_grad_norm=0.1,
306
+ warmup_ratio=0.1,
307
+ lr_scheduler_type="cosine",
308
+ logging_steps=1,
309
+ save_steps=50,
310
+ report_to="none",
311
+ num_generations=4,
312
+ max_new_tokens=128,
313
+ temperature=0.7,
314
+ beta=0.001,
315
+ )
316
+
317
+
318
+ def plot_reward_curve(trainer: GRPOTrainer) -> None:
319
+ """Print reward progression as ASCII bar chart."""
320
+ history = trainer.state.log_history
321
+ rewards = [
322
+ entry.get("reward", entry.get("train/reward", 0.0))
323
+ for entry in history
324
+ if "reward" in entry or "train/reward" in entry
325
+ ]
326
+ if not rewards:
327
+ print("No reward history found.")
328
+ return
329
+ print("\n=== REWARD CURVE ===")
330
+ max_r = max(rewards) if rewards else 1.0
331
+ for i, reward in enumerate(rewards):
332
+ bar = "█" * int((reward / max(0.01, max_r)) * 30)
333
+ print(f" step {i + 1:3d}: {reward:.3f} {bar}")
334
+ print(f"\nInitial reward: {rewards[0]:.3f}")
335
+ print(f"Final reward: {rewards[-1]:.3f}")
336
+ delta = rewards[-1] - rewards[0]
337
+ print(f"Improvement: {delta:+.3f}")
338
+
339
+
340
+ if __name__ == "__main__":
341
+ random.seed(42)
342
+ torch.manual_seed(42)
343
+
344
+ try:
345
+ health = _safe_get("/health")
346
+ assert health is not None and health.get("status") == "ok"
347
+ print(f"✅ Server healthy at {ENV_BASE_URL}")
348
+ except Exception as exc:
349
+ print(f"❌ Server not reachable: {exc}")
350
+ print("Start it with: uvicorn app.main:app --port 7860")
351
+ raise SystemExit(1)
352
+
353
+ dataset = make_sre_dataset(n_samples=200)
354
+ print(f"✅ Dataset: {len(dataset)} samples")
355
+
356
+ trainer = GRPOTrainer(
357
+ model=model,
358
+ args=training_args,
359
+ train_dataset=dataset,
360
+ reward_funcs=sre_reward_fn,
361
+ processing_class=tokenizer,
362
+ )
363
+ trainer.train()
364
+
365
+ plot_reward_curve(trainer)
366
+ model.save_pretrained("sre_agent_lora")
367
+ tokenizer.save_pretrained("sre_agent_lora")
368
+ print("✅ Model saved to sre_agent_lora/")