Spaces:
Running
Running
Ship Round 2 manifest/docs, dashboard, and GRPO training pipeline
Browse files- README.md +102 -184
- dashboard/README.md +14 -0
- dashboard/war_room.py +369 -0
- inference.py +4 -1
- openenv.yaml +55 -18
- requirements.txt +7 -9
- training/grpo_train.py +368 -0
README.md
CHANGED
|
@@ -23,232 +23,150 @@ models like the one evaluating this environment.
|
|
| 23 |
## Why This Matters
|
| 24 |
|
| 25 |
Large-scale AI training runs on clusters of hundreds of
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
- **Code desynchronization** across ranks hangs jobs silently
|
| 31 |
|
| 32 |
-
|
| 33 |
-
There is no standardized benchmark for evaluating whether
|
| 34 |
-
AI agents can handle these failures autonomously.
|
| 35 |
|
| 36 |
-
NervousSystem-Env
|
| 37 |
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
##
|
| 43 |
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
**Key design:** Deep diagnostic data (Flight Recorder
|
| 55 |
-
buffers, NCCL logs) is hidden by default. The agent must
|
| 56 |
-
actively query for it using investigation actions.
|
| 57 |
-
|
| 58 |
-
### Action Space
|
| 59 |
-
|
| 60 |
-
| Action | Parameters | Destructive | Description |
|
| 61 |
-
|---|---|---|---|
|
| 62 |
-
| `inspect_flight_recorder` | `rank_id: int` | No | Get PyTorch Flight Recorder data for a rank |
|
| 63 |
-
| `query_nccl_logs` | `time_window: int` | No | Get NCCL communication log entries |
|
| 64 |
-
| `topo_reorder` | `affinity: str` | No | Reorder ring topology (use "rack" for fix) |
|
| 65 |
-
| `patch_divergent_code` | `file: str, fix_type: str` | No | Patch desynchronized code |
|
| 66 |
-
| `restart_rank` | `rank_id: int` | **Yes** | Restart a specific rank (-0.2 penalty) |
|
| 67 |
-
| `reset_ib_interface` | `node_id: int` | **Yes** | Reset IB interface (-0.2 penalty) |
|
| 68 |
-
| `adjust_sharding_strategy` | `strategy: str` | No | Change sharding strategy |
|
| 69 |
-
| `noop` | none | No | Take no action |
|
| 70 |
-
|
| 71 |
-
---
|
| 72 |
|
| 73 |
## Tasks
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
The agent must use `inspect_flight_recorder(rank_id)` to
|
| 83 |
-
examine each rank's Flight Recorder buffer and identify
|
| 84 |
-
which rank has a stalled collective sequence.
|
| 85 |
-
|
| 86 |
-
**Grader:** 1.0 for correct rank identified, 0.0 otherwise.
|
| 87 |
-
Efficiency bonus up to +0.2 for early diagnosis.
|
| 88 |
-
Penalty -0.1 per destructive action taken.
|
| 89 |
-
|
| 90 |
-
**Anti-cheat:** The failing rank is randomly seeded on
|
| 91 |
-
every `reset()` call. Hardcoding a rank ID scores 0.0.
|
| 92 |
-
|
| 93 |
-
---
|
| 94 |
-
|
| 95 |
-
### Medium — Spine Switch Congestion Resolution
|
| 96 |
-
**Difficulty:** Medium
|
| 97 |
|
| 98 |
-
|
| 99 |
-
The ring topology stretches across oversubscribed spine
|
| 100 |
-
switches. The agent must call `topo_reorder(affinity="rack")`
|
| 101 |
-
to enforce rack-local communication.
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
Penalty -0.2 per destructive action.
|
| 107 |
-
|
| 108 |
-
---
|
| 109 |
-
|
| 110 |
-
### Hard — Asymmetric Compilation Desync Fix
|
| 111 |
-
**Difficulty:** Hard
|
| 112 |
-
|
| 113 |
-
Training is completely hung. Different ranks compiled
|
| 114 |
-
different NCCL collectives due to data-dependent branching
|
| 115 |
-
in the model code. The job will never recover on its own.
|
| 116 |
-
|
| 117 |
-
The agent must:
|
| 118 |
-
1. Investigate using `query_nccl_logs` or
|
| 119 |
-
`inspect_flight_recorder`
|
| 120 |
-
2. Identify the divergent source file using
|
| 121 |
-
`patch_divergent_code(file=..., fix_type=...)`
|
| 122 |
-
3. Verify training resumes for 5+ steps
|
| 123 |
-
|
| 124 |
-
**Grader:** 3-stage scoring:
|
| 125 |
-
- 0.3 for identifying the correct file
|
| 126 |
-
- +0.4 for applying the correct patch
|
| 127 |
-
- +0.3 for sustained training recovery (5+ steps)
|
| 128 |
-
= 1.0 maximum
|
| 129 |
-
|
| 130 |
-
---
|
| 131 |
-
|
| 132 |
-
## Reward Function
|
| 133 |
-
|
| 134 |
-
Rewards are continuous — the agent receives signal at
|
| 135 |
-
every step, not just at episode end.
|
| 136 |
-
|
| 137 |
-
| Situation | Reward |
|
| 138 |
-
|---|---|
|
| 139 |
-
| Correct rank identified (easy) | +0.5 |
|
| 140 |
-
| Investigation action taken | +0.05 |
|
| 141 |
-
| Throughput improvement (medium) | proportional to ratio |
|
| 142 |
-
| Correct file identified (hard) | +0.3 |
|
| 143 |
-
| Correct patch applied (hard) | +0.7 cumulative |
|
| 144 |
-
| Training recovered 5+ steps | +0.3 |
|
| 145 |
-
| Destructive action taken | -0.2 |
|
| 146 |
-
| Noop | 0.0 |
|
| 147 |
-
|
| 148 |
-
---
|
| 149 |
-
|
| 150 |
-
## Setup and Usage
|
| 151 |
|
| 152 |
-
|
| 153 |
-
``
|
| 154 |
-
|
| 155 |
-
docker build -t nervousystem-env .
|
| 156 |
|
| 157 |
-
#
|
| 158 |
-
docker run -p 7860:7860 nervousystem-env
|
| 159 |
|
| 160 |
-
# Verify
|
| 161 |
-
curl http://localhost:7860/health
|
| 162 |
-
```
|
| 163 |
-
|
| 164 |
-
### Run locally
|
| 165 |
```bash
|
| 166 |
-
# Install
|
| 167 |
pip install -r requirements.txt
|
| 168 |
|
| 169 |
-
# Start server
|
| 170 |
uvicorn app.main:app --host 0.0.0.0 --port 7860
|
| 171 |
|
| 172 |
-
#
|
| 173 |
-
|
| 174 |
-
export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
|
| 175 |
-
export HF_TOKEN=your_token_here
|
| 176 |
-
export ENV_BASE_URL=http://localhost:7860
|
| 177 |
-
python inference.py
|
| 178 |
-
```
|
| 179 |
-
|
| 180 |
-
### Environment Variables
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
| `API_BASE_URL` | Yes | LLM API endpoint |
|
| 185 |
-
| `MODEL_NAME` | Yes | Model identifier |
|
| 186 |
-
| `HF_TOKEN` | Yes | HuggingFace API token |
|
| 187 |
-
| `ENV_BASE_URL` | No | Env server URL (default: http://localhost:7860) |
|
| 188 |
|
| 189 |
-
|
|
|
|
|
|
|
| 190 |
|
| 191 |
## API Endpoints
|
| 192 |
|
| 193 |
| Endpoint | Method | Description |
|
| 194 |
|---|---|---|
|
| 195 |
| `/health` | GET | Health check |
|
| 196 |
-
| `/reset` | POST |
|
| 197 |
-
| `/step` | POST |
|
| 198 |
-
| `/state` | GET |
|
| 199 |
-
| `/grade` | POST |
|
| 200 |
-
| `/tasks` | GET | List
|
|
|
|
| 201 |
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
## Baseline Scores
|
| 205 |
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|---|---|---|
|
| 211 |
-
| easy | TBD | TBD |
|
| 212 |
-
| medium | TBD | TBD |
|
| 213 |
-
| hard | TBD | TBD |
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
-
|
| 219 |
|
| 220 |
## Project Structure
|
| 221 |
-
|
|
|
|
| 222 |
nervousystem-env/
|
| 223 |
├── app/
|
| 224 |
-
│ ├──
|
| 225 |
-
│ ├── env.py
|
| 226 |
-
│ ├──
|
| 227 |
-
│ └──
|
| 228 |
-
├── simulation/
|
| 229 |
-
│ ├── cluster.py # GPU cluster state machine
|
| 230 |
-
│ ├── failures.py # Failure injection
|
| 231 |
-
│ └── telemetry.py # Log generation
|
| 232 |
-
├── tasks/
|
| 233 |
-
│ ├── easy.py # Culprit rank identification
|
| 234 |
-
│ ├── medium.py # Congestion resolution
|
| 235 |
-
│ └── hard.py # Desync fix
|
| 236 |
├── graders/
|
|
|
|
|
|
|
| 237 |
│ ├── easy_grader.py
|
| 238 |
-
│ ├──
|
| 239 |
-
│ └──
|
| 240 |
-
├──
|
| 241 |
-
├──
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
├── openenv.yaml
|
| 243 |
-
|
|
|
|
|
|
|
| 244 |
```
|
| 245 |
|
| 246 |
-
---
|
| 247 |
-
|
| 248 |
## OpenEnv Compliance
|
| 249 |
|
| 250 |
-
- `openenv validate` passes
|
| 251 |
-
- Typed Pydantic v2 models
|
| 252 |
-
- Deterministic graders
|
| 253 |
-
- Docker deployment
|
| 254 |
-
-
|
|
|
|
|
|
|
|
|
| 23 |
## Why This Matters
|
| 24 |
|
| 25 |
Large-scale AI training runs on clusters of hundreds of
|
| 26 |
+
# 🧠 NervousSystem-Env
|
| 27 |
|
| 28 |
+
> An AI agent fixing the infrastructure that trains AI.
|
| 29 |
+
> Every minute of cluster downtime wastes $5,000 in compute.
|
|
|
|
| 30 |
|
| 31 |
+
## The Problem
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
Large-scale AI training across 1000+ GPU clusters fails constantly due to hardware faults, network bottlenecks, distributed synchronization bugs, and runtime version drift. Human SREs are forced to diagnose these incidents at 3am under extreme time pressure. NervousSystem-Env turns that operational pain into a training environment where autonomous agents learn to detect failures, route work to specialist workers, and recover jobs before expensive downtime compounds.
|
| 34 |
|
| 35 |
+
## Why This Matters
|
| 36 |
|
| 37 |
+
- GPU OOM (XID 79): stalls entire training job.
|
| 38 |
+
- Spine switch congestion: cuts throughput 40%+.
|
| 39 |
+
- Compilation desync: hangs job permanently.
|
| 40 |
+
- LD_LIBRARY_PATH cascade: Severity-1 fleet-wide incident.
|
| 41 |
|
| 42 |
+
## Architecture
|
| 43 |
|
| 44 |
+
NervousSystem-Env uses a Fleet AI Supervisor-Worker design. A supervisor agent receives global cluster state and delegates targeted sub-tasks to specialist workers via `/delegate`. Workers return structured results with confidence and coordination reward signals, enabling multi-agent training for routing, diagnosis, and remediation.
|
| 45 |
|
| 46 |
+
```text
|
| 47 |
+
Supervisor Agent
|
| 48 |
+
│
|
| 49 |
+
├── LogInspectorWorker (flight recorder, NCCL logs)
|
| 50 |
+
├── PatchAgentWorker (code patching, verification)
|
| 51 |
+
├── TopoAgentWorker (topology, bandwidth)
|
| 52 |
+
└── VersionCheckerWorker (NCCL version, LD_LIBRARY_PATH)
|
| 53 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
## Tasks
|
| 56 |
|
| 57 |
+
| Task | Difficulty | Max Steps | Failure Type | Key Actions |
|
| 58 |
+
|---|---:|---:|---|---|
|
| 59 |
+
| easy | easy | 50 | OOM rank failure | `inspect_flight_recorder` |
|
| 60 |
+
| medium | medium | 50 | network congestion | `topo_reorder(affinity="rack")` |
|
| 61 |
+
| hard | hard | 50 | collective desync | `query_nccl_logs`, `patch_divergent_code` |
|
| 62 |
+
| cascade | cascade | 120 | version cascade (OOM→congestion→desync) | ordered multi-phase recovery |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
## Reward Model
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
```text
|
| 67 |
+
Reward = 0.60 * R_success + 0.30 * R_subgoal - 0.10 * log(total_tokens)
|
| 68 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
- `R_success`: binary completion signal (recovered/running within step limit).
|
| 71 |
+
- `R_subgoal`: continuous task-progress score.
|
| 72 |
+
- `log(total_tokens)`: efficiency penalty to discourage verbose reasoning.
|
|
|
|
| 73 |
|
| 74 |
+
## Quick Start
|
|
|
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
```bash
|
| 77 |
+
# Install
|
| 78 |
pip install -r requirements.txt
|
| 79 |
|
| 80 |
+
# Start environment server
|
| 81 |
uvicorn app.main:app --host 0.0.0.0 --port 7860
|
| 82 |
|
| 83 |
+
# Start war room dashboard
|
| 84 |
+
python dashboard/war_room.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
+
# Run baseline agent
|
| 87 |
+
python inference.py
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
# Train with GRPO
|
| 90 |
+
python training/grpo_train.py
|
| 91 |
+
```
|
| 92 |
|
| 93 |
## API Endpoints
|
| 94 |
|
| 95 |
| Endpoint | Method | Description |
|
| 96 |
|---|---|---|
|
| 97 |
| `/health` | GET | Health check |
|
| 98 |
+
| `/reset` | POST | Reset episode by task and seed |
|
| 99 |
+
| `/step` | POST | Apply one SRE action |
|
| 100 |
+
| `/state` | GET | Fetch current observation |
|
| 101 |
+
| `/grade` | POST | Grade current episode |
|
| 102 |
+
| `/tasks` | GET | List task metadata |
|
| 103 |
+
| `/delegate` | POST | Supervisor delegates to worker agent |
|
| 104 |
|
| 105 |
+
## Hackathon Themes
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
- Theme 1 (Fleet AI): Supervisor-Worker with `/delegate` endpoint.
|
| 108 |
+
- Theme 2 (Long-Horizon): Cascade task (120 steps), Mercor reward shaping.
|
| 109 |
+
- Theme 3.1 (Professional Tasks): NCCL diagnostics + Flight Recorder v2.5 workflow.
|
| 110 |
+
- Theme 4 (Self-Improvement): Adversarial curriculum via seeded failure permutations.
|
| 111 |
|
| 112 |
+
## Training Results
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
+
| Task | Baseline | Trained | Improvement |
|
| 115 |
+
|---|---:|---:|---:|
|
| 116 |
+
| easy | TBD | TBD | TBD |
|
| 117 |
+
| medium | TBD | TBD | TBD |
|
| 118 |
+
| hard | TBD | TBD | TBD |
|
| 119 |
+
| cascade | TBD | TBD | TBD |
|
| 120 |
|
| 121 |
+
Run `python training/grpo_train.py` to reproduce.
|
| 122 |
|
| 123 |
## Project Structure
|
| 124 |
+
|
| 125 |
+
```text
|
| 126 |
nervousystem-env/
|
| 127 |
├── app/
|
| 128 |
+
│ ├── config.py
|
| 129 |
+
│ ├── env.py
|
| 130 |
+
│ ├── main.py
|
| 131 |
+
│ └── models.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
├── graders/
|
| 133 |
+
│ ├── base.py
|
| 134 |
+
│ ├── cascade_grader.py
|
| 135 |
│ ├── easy_grader.py
|
| 136 |
+
│ ├── hard_grader.py
|
| 137 |
+
│ └── medium_grader.py
|
| 138 |
+
├── simulation/
|
| 139 |
+
│ ├── cluster.py
|
| 140 |
+
│ ├── failures.py
|
| 141 |
+
│ ├── fleet.py
|
| 142 |
+
│ └── telemetry.py
|
| 143 |
+
├── tasks/
|
| 144 |
+
│ ├── base.py
|
| 145 |
+
│ ├── cascade.py
|
| 146 |
+
│ ├── easy.py
|
| 147 |
+
│ ├── hard.py
|
| 148 |
+
│ └── medium.py
|
| 149 |
+
├── dashboard/
|
| 150 |
+
│ ├── README.md
|
| 151 |
+
│ └── war_room.py
|
| 152 |
+
├── training/
|
| 153 |
+
│ └── grpo_train.py
|
| 154 |
+
├── tests/
|
| 155 |
+
│ ├── test_fleet.py
|
| 156 |
+
│ └── test_graders.py
|
| 157 |
+
├── inference.py
|
| 158 |
├── openenv.yaml
|
| 159 |
+
├── requirements.txt
|
| 160 |
+
└── server/
|
| 161 |
+
└── app.py
|
| 162 |
```
|
| 163 |
|
|
|
|
|
|
|
| 164 |
## OpenEnv Compliance
|
| 165 |
|
| 166 |
+
- `openenv validate` passes.
|
| 167 |
+
- Typed Pydantic v2 models.
|
| 168 |
+
- Deterministic graders.
|
| 169 |
+
- Docker deployment.
|
| 170 |
+
- 4 tasks with difficulty progression.
|
| 171 |
+
- Multi-agent `/delegate` endpoint.
|
| 172 |
+
# In another terminal, run inference
|
dashboard/README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SRE War Room Dashboard
|
| 2 |
+
|
| 3 |
+
Run the environment server first on `7860`, then launch the Gradio dashboard on `7861`.
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
uvicorn app.main:app --host 0.0.0.0 --port 7860
|
| 7 |
+
python dashboard/war_room.py
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
Optional custom server URL:
|
| 11 |
+
|
| 12 |
+
```bash
|
| 13 |
+
ENV_BASE_URL=http://localhost:7860 python dashboard/war_room.py
|
| 14 |
+
```
|
dashboard/war_room.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import requests
|
| 9 |
+
|
| 10 |
+
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def render_ring(nodes: list[dict]) -> str:
|
| 14 |
+
"""Return an HTML string with 8 colored divs arranged in a circle."""
|
| 15 |
+
health_to_color = {
|
| 16 |
+
"healthy": "#22c55e",
|
| 17 |
+
"degraded": "#eab308",
|
| 18 |
+
"failed": "#ef4444",
|
| 19 |
+
}
|
| 20 |
+
health_to_emoji = {
|
| 21 |
+
"healthy": "🟢",
|
| 22 |
+
"degraded": "🟡",
|
| 23 |
+
"failed": "🔴",
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
padded_nodes = list(nodes[:8])
|
| 27 |
+
while len(padded_nodes) < 8:
|
| 28 |
+
padded_nodes.append(
|
| 29 |
+
{
|
| 30 |
+
"node_id": len(padded_nodes),
|
| 31 |
+
"health_status": "failed",
|
| 32 |
+
"gpu_memory_used_mb": 0,
|
| 33 |
+
"xid_errors": [],
|
| 34 |
+
}
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
cards: list[str] = []
|
| 38 |
+
for index, node in enumerate(padded_nodes):
|
| 39 |
+
health = str(node.get("health_status", "failed"))
|
| 40 |
+
color = health_to_color.get(health, "#ef4444")
|
| 41 |
+
emoji = health_to_emoji.get(health, "🔴")
|
| 42 |
+
node_id = node.get("node_id", index)
|
| 43 |
+
gpu_mem = node.get("gpu_memory_used_mb", 0)
|
| 44 |
+
xid_errors = node.get("xid_errors", [])
|
| 45 |
+
xid_text = ",".join(str(code) for code in xid_errors) if xid_errors else "none"
|
| 46 |
+
angle = index * 45
|
| 47 |
+
cards.append(
|
| 48 |
+
f"""
|
| 49 |
+
<div class='node-card' style='background:{color};
|
| 50 |
+
transform: rotate({angle}deg) translate(155px) rotate(-{angle}deg);'>
|
| 51 |
+
<div><strong>{emoji} node {node_id}</strong></div>
|
| 52 |
+
<div>health: {health}</div>
|
| 53 |
+
<div>gpu_mem: {float(gpu_mem):.0f} MB</div>
|
| 54 |
+
<div>xid: {xid_text}</div>
|
| 55 |
+
</div>
|
| 56 |
+
"""
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return f"""
|
| 60 |
+
<style>
|
| 61 |
+
.ring-wrap {{
|
| 62 |
+
position: relative;
|
| 63 |
+
width: 420px;
|
| 64 |
+
height: 420px;
|
| 65 |
+
margin: 0 auto;
|
| 66 |
+
border-radius: 50%;
|
| 67 |
+
background: radial-gradient(circle, #0b1220 0%, #111827 65%, #1f2937 100%);
|
| 68 |
+
border: 1px solid #374151;
|
| 69 |
+
}}
|
| 70 |
+
.ring-center {{
|
| 71 |
+
position: absolute;
|
| 72 |
+
left: 50%; top: 50%;
|
| 73 |
+
transform: translate(-50%, -50%);
|
| 74 |
+
color: #d1d5db;
|
| 75 |
+
font-weight: 700;
|
| 76 |
+
font-size: 14px;
|
| 77 |
+
}}
|
| 78 |
+
.node-card {{
|
| 79 |
+
position: absolute;
|
| 80 |
+
left: 50%;
|
| 81 |
+
top: 50%;
|
| 82 |
+
width: 132px;
|
| 83 |
+
min-height: 72px;
|
| 84 |
+
margin-left: -66px;
|
| 85 |
+
margin-top: -36px;
|
| 86 |
+
border-radius: 10px;
|
| 87 |
+
padding: 8px;
|
| 88 |
+
color: #111827;
|
| 89 |
+
box-shadow: 0 6px 20px rgba(0,0,0,0.25);
|
| 90 |
+
font-size: 11px;
|
| 91 |
+
line-height: 1.2;
|
| 92 |
+
}}
|
| 93 |
+
</style>
|
| 94 |
+
<div class='ring-wrap'>
|
| 95 |
+
<div class='ring-center'>Cluster Ring</div>
|
| 96 |
+
{''.join(cards)}
|
| 97 |
+
</div>
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _safe_get(path: str) -> dict | None:
|
| 102 |
+
try:
|
| 103 |
+
response = requests.get(f"{ENV_BASE_URL}{path}", timeout=5)
|
| 104 |
+
response.raise_for_status()
|
| 105 |
+
return response.json()
|
| 106 |
+
except Exception:
|
| 107 |
+
return None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _safe_post(path: str, payload: dict) -> dict | None:
|
| 111 |
+
try:
|
| 112 |
+
response = requests.post(f"{ENV_BASE_URL}{path}", json=payload, timeout=8)
|
| 113 |
+
response.raise_for_status()
|
| 114 |
+
return response.json()
|
| 115 |
+
except Exception:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _offline_panel(action_log: list[list]) -> tuple:
|
| 120 |
+
offline_row = [["-", "offline", 0.0, 0.0, "⚠️ Server offline"]]
|
| 121 |
+
return (
|
| 122 |
+
"<h3>⚠️ Server offline</h3>",
|
| 123 |
+
"⚠️ Server offline",
|
| 124 |
+
0.0,
|
| 125 |
+
0.0,
|
| 126 |
+
0.0,
|
| 127 |
+
0.0,
|
| 128 |
+
0.0,
|
| 129 |
+
"## ⚠️ Server offline",
|
| 130 |
+
action_log[-20:] if action_log else offline_row,
|
| 131 |
+
action_log,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _panel_from_state(state: dict, action_log: list[list]) -> tuple:
|
| 136 |
+
nodes = state.get("nodes", [])
|
| 137 |
+
training = state.get("training", {})
|
| 138 |
+
throughput = float(training.get("throughput_tokens_per_sec", 0.0))
|
| 139 |
+
target = float(training.get("target_throughput", 1.0))
|
| 140 |
+
stalled_steps = float(training.get("stalled_steps", 0.0))
|
| 141 |
+
status = str(training.get("job_status", "unknown"))
|
| 142 |
+
cumulative_tokens = float(state.get("cumulative_tokens", 0))
|
| 143 |
+
throughput_pct = (throughput / max(1.0, target)) * 100.0
|
| 144 |
+
simulated_loss_prevented = stalled_steps * 83.33
|
| 145 |
+
loss_text = f"## 💰 ${simulated_loss_prevented:,.2f} saved"
|
| 146 |
+
|
| 147 |
+
return (
|
| 148 |
+
render_ring(nodes),
|
| 149 |
+
status,
|
| 150 |
+
throughput,
|
| 151 |
+
throughput_pct,
|
| 152 |
+
stalled_steps,
|
| 153 |
+
cumulative_tokens,
|
| 154 |
+
simulated_loss_prevented,
|
| 155 |
+
loss_text,
|
| 156 |
+
action_log[-20:],
|
| 157 |
+
action_log,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def refresh_panels(task_id: str, action_log: list[list]) -> tuple:
|
| 162 |
+
"""Refresh dashboard panels from live server state."""
|
| 163 |
+
_ = task_id
|
| 164 |
+
state = _safe_get("/state")
|
| 165 |
+
if state is None:
|
| 166 |
+
return _offline_panel(action_log)
|
| 167 |
+
return _panel_from_state(state, action_log)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def reset_episode(task_id: str) -> tuple:
|
| 171 |
+
"""Reset episode for selected task and clear action log."""
|
| 172 |
+
result = _safe_post("/reset", {"task_id": task_id})
|
| 173 |
+
if result is None:
|
| 174 |
+
offline = _offline_panel([])
|
| 175 |
+
return (*offline, gr.update(active=True))
|
| 176 |
+
panel = _panel_from_state(result, [])
|
| 177 |
+
return (*panel, gr.update(active=True))
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _demo_actions() -> list[dict]:
|
| 181 |
+
return [
|
| 182 |
+
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 0}},
|
| 183 |
+
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 1}},
|
| 184 |
+
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 2}},
|
| 185 |
+
{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}},
|
| 186 |
+
{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}},
|
| 187 |
+
{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}},
|
| 188 |
+
{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}},
|
| 189 |
+
{"action_type": "noop", "parameters": {}},
|
| 190 |
+
{"action_type": "noop", "parameters": {}},
|
| 191 |
+
{"action_type": "noop", "parameters": {}},
|
| 192 |
+
]
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def run_demo_agent(task_id: str, action_log: list[list]) -> tuple:
|
| 196 |
+
"""Run exactly 10 hardcoded demo steps for visualization."""
|
| 197 |
+
_ = _safe_post("/reset", {"task_id": task_id})
|
| 198 |
+
rows = list(action_log)
|
| 199 |
+
|
| 200 |
+
for action in _demo_actions():
|
| 201 |
+
step_result = _safe_post("/step", action)
|
| 202 |
+
if step_result is None:
|
| 203 |
+
return (*_offline_panel(rows), gr.update(active=True))
|
| 204 |
+
reward = step_result.get("reward", {})
|
| 205 |
+
observation = step_result.get("observation", {})
|
| 206 |
+
step_num = observation.get("step_count", len(rows) + 1)
|
| 207 |
+
row = [
|
| 208 |
+
step_num,
|
| 209 |
+
action["action_type"],
|
| 210 |
+
float(reward.get("value", 0.0)),
|
| 211 |
+
float(reward.get("token_efficiency_score", 0.0)),
|
| 212 |
+
f"{reward.get('info', '')} @ {datetime.utcnow().isoformat(timespec='seconds')}",
|
| 213 |
+
]
|
| 214 |
+
rows.append(row)
|
| 215 |
+
|
| 216 |
+
state = _safe_get("/state")
|
| 217 |
+
if state is None:
|
| 218 |
+
return (*_offline_panel(rows), gr.update(active=True))
|
| 219 |
+
panel = _panel_from_state(state, rows)
|
| 220 |
+
return (*panel, gr.update(active=True))
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def stop_refresh():
|
| 224 |
+
"""Stop the auto-refresh timer."""
|
| 225 |
+
return gr.update(active=False)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def delegate_task(worker: str, action: str) -> dict:
|
| 229 |
+
"""Submit delegation request to /delegate endpoint."""
|
| 230 |
+
payload = {
|
| 231 |
+
"worker": worker,
|
| 232 |
+
"action": action,
|
| 233 |
+
"parameters": {},
|
| 234 |
+
"supervisor_reasoning": f"War Room delegation at {datetime.utcnow().isoformat()}",
|
| 235 |
+
"token_count": 0,
|
| 236 |
+
}
|
| 237 |
+
result = _safe_post("/delegate", payload)
|
| 238 |
+
if result is None:
|
| 239 |
+
return {
|
| 240 |
+
"worker": worker,
|
| 241 |
+
"action": action,
|
| 242 |
+
"success": False,
|
| 243 |
+
"output": {"error": "⚠️ Server offline"},
|
| 244 |
+
"confidence": 0.0,
|
| 245 |
+
"coordination_reward": 0.0,
|
| 246 |
+
"explanation": "⚠️ Server offline",
|
| 247 |
+
"cumulative_coordination_reward": 0.0,
|
| 248 |
+
"raw": json.dumps(payload),
|
| 249 |
+
}
|
| 250 |
+
return result
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
with gr.Blocks(title="SRE War Room") as demo:
|
| 254 |
+
gr.Markdown("# 🛠️ SRE War Room")
|
| 255 |
+
gr.Markdown(f"Connected env: `{ENV_BASE_URL}`")
|
| 256 |
+
|
| 257 |
+
with gr.Row():
|
| 258 |
+
task_dropdown = gr.Dropdown(
|
| 259 |
+
choices=["easy", "medium", "hard", "cascade"],
|
| 260 |
+
value="easy",
|
| 261 |
+
label="Task",
|
| 262 |
+
)
|
| 263 |
+
reset_btn = gr.Button("Reset Episode", variant="primary")
|
| 264 |
+
demo_btn = gr.Button("Run Demo Agent")
|
| 265 |
+
stop_btn = gr.Button("Stop")
|
| 266 |
+
|
| 267 |
+
with gr.Row():
|
| 268 |
+
with gr.Column(scale=2):
|
| 269 |
+
gr.Markdown("## Panel A: Cluster Ring Topology")
|
| 270 |
+
ring_html = gr.HTML(render_ring([]))
|
| 271 |
+
with gr.Column(scale=1):
|
| 272 |
+
gr.Markdown("## Panel B: Training Metrics")
|
| 273 |
+
job_status_label = gr.Label(label="job_status", value="unknown")
|
| 274 |
+
throughput_num = gr.Number(label="throughput_tokens_per_sec", value=0.0)
|
| 275 |
+
throughput_pct_num = gr.Number(label="throughput_%_of_target", value=0.0)
|
| 276 |
+
stalled_steps_num = gr.Number(label="stalled_steps", value=0.0)
|
| 277 |
+
cumulative_tokens_num = gr.Number(label="cumulative_tokens", value=0.0)
|
| 278 |
+
loss_num = gr.Number(label="Simulated Loss Prevented $", value=0.0)
|
| 279 |
+
loss_text = gr.Markdown("## 💰 $0.00 saved")
|
| 280 |
+
|
| 281 |
+
with gr.Row():
|
| 282 |
+
with gr.Column(scale=3):
|
| 283 |
+
gr.Markdown("## Panel C: Agent Action Log")
|
| 284 |
+
action_df = gr.Dataframe(
|
| 285 |
+
headers=["step", "action_type", "reward", "mer_score", "info"],
|
| 286 |
+
value=[],
|
| 287 |
+
row_count=20,
|
| 288 |
+
column_count=(5, "fixed"),
|
| 289 |
+
datatype=["number", "str", "number", "number", "str"],
|
| 290 |
+
wrap=True,
|
| 291 |
+
)
|
| 292 |
+
with gr.Column(scale=2):
|
| 293 |
+
gr.Markdown("## Fleet Delegation")
|
| 294 |
+
worker_dropdown = gr.Dropdown(
|
| 295 |
+
choices=["log_inspector", "patch_agent", "topo_agent", "version_checker"],
|
| 296 |
+
value="log_inspector",
|
| 297 |
+
label="worker",
|
| 298 |
+
)
|
| 299 |
+
delegation_action = gr.Textbox(value="check_nccl_version", label="action")
|
| 300 |
+
delegate_btn = gr.Button("Delegate")
|
| 301 |
+
delegate_json = gr.JSON(label="Last delegation result")
|
| 302 |
+
|
| 303 |
+
action_log_state = gr.State([])
|
| 304 |
+
refresh_timer = gr.Timer(value=2.0, active=True)
|
| 305 |
+
|
| 306 |
+
refresh_timer.tick(
|
| 307 |
+
fn=refresh_panels,
|
| 308 |
+
inputs=[task_dropdown, action_log_state],
|
| 309 |
+
outputs=[
|
| 310 |
+
ring_html,
|
| 311 |
+
job_status_label,
|
| 312 |
+
throughput_num,
|
| 313 |
+
throughput_pct_num,
|
| 314 |
+
stalled_steps_num,
|
| 315 |
+
cumulative_tokens_num,
|
| 316 |
+
loss_num,
|
| 317 |
+
loss_text,
|
| 318 |
+
action_df,
|
| 319 |
+
action_log_state,
|
| 320 |
+
],
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
reset_btn.click(
|
| 324 |
+
fn=reset_episode,
|
| 325 |
+
inputs=[task_dropdown],
|
| 326 |
+
outputs=[
|
| 327 |
+
ring_html,
|
| 328 |
+
job_status_label,
|
| 329 |
+
throughput_num,
|
| 330 |
+
throughput_pct_num,
|
| 331 |
+
stalled_steps_num,
|
| 332 |
+
cumulative_tokens_num,
|
| 333 |
+
loss_num,
|
| 334 |
+
loss_text,
|
| 335 |
+
action_df,
|
| 336 |
+
action_log_state,
|
| 337 |
+
refresh_timer,
|
| 338 |
+
],
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
demo_btn.click(
|
| 342 |
+
fn=run_demo_agent,
|
| 343 |
+
inputs=[task_dropdown, action_log_state],
|
| 344 |
+
outputs=[
|
| 345 |
+
ring_html,
|
| 346 |
+
job_status_label,
|
| 347 |
+
throughput_num,
|
| 348 |
+
throughput_pct_num,
|
| 349 |
+
stalled_steps_num,
|
| 350 |
+
cumulative_tokens_num,
|
| 351 |
+
loss_num,
|
| 352 |
+
loss_text,
|
| 353 |
+
action_df,
|
| 354 |
+
action_log_state,
|
| 355 |
+
refresh_timer,
|
| 356 |
+
],
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
stop_btn.click(fn=stop_refresh, outputs=[refresh_timer])
|
| 360 |
+
|
| 361 |
+
delegate_btn.click(
|
| 362 |
+
fn=delegate_task,
|
| 363 |
+
inputs=[worker_dropdown, delegation_action],
|
| 364 |
+
outputs=[delegate_json],
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
|
| 369 |
+
demo.launch(server_port=7861, share=False)
|
inference.py
CHANGED
|
@@ -69,7 +69,7 @@ MODEL_NAME = os.getenv(
|
|
| 69 |
)
|
| 70 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "")
|
| 71 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
|
| 72 |
-
MAX_STEPS =
|
| 73 |
TEMPERATURE = 0.1
|
| 74 |
MAX_TOKENS = 300
|
| 75 |
SEED = 42
|
|
@@ -100,6 +100,9 @@ Rules:
|
|
| 100 |
- Use query_nccl_logs to see communication errors.
|
| 101 |
- Avoid restart_rank unless absolutely necessary — it is destructive.
|
| 102 |
- If you already know the failing rank, fix it directly.
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
Example response:
|
| 105 |
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 3}}
|
|
|
|
| 69 |
)
|
| 70 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "")
|
| 71 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
|
| 72 |
+
MAX_STEPS = 20
|
| 73 |
TEMPERATURE = 0.1
|
| 74 |
MAX_TOKENS = 300
|
| 75 |
SEED = 42
|
|
|
|
| 100 |
- Use query_nccl_logs to see communication errors.
|
| 101 |
- Avoid restart_rank unless absolutely necessary — it is destructive.
|
| 102 |
- If you already know the failing rank, fix it directly.
|
| 103 |
+
- For cascade failures: solve phases in order. Phase 1=OOM diagnosis,
|
| 104 |
+
Phase 2=topo_reorder, Phase 3=query_nccl_logs then patch_divergent_code
|
| 105 |
+
- Token efficiency matters: fewer tokens = higher reward
|
| 106 |
|
| 107 |
Example response:
|
| 108 |
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 3}}
|
openenv.yaml
CHANGED
|
@@ -1,43 +1,80 @@
|
|
| 1 |
name: nervousystem-env
|
| 2 |
-
version: "
|
| 3 |
description: >
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
author: v4xsh
|
| 10 |
tags:
|
| 11 |
- openenv
|
| 12 |
- sre
|
| 13 |
- distributed-training
|
| 14 |
-
-
|
| 15 |
-
-
|
| 16 |
-
-
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
tasks:
|
| 19 |
- id: easy
|
| 20 |
name: "Culprit Rank Identification"
|
| 21 |
difficulty: easy
|
|
|
|
| 22 |
description: >
|
| 23 |
-
Training
|
| 24 |
-
|
| 25 |
- id: medium
|
| 26 |
name: "Spine Switch Congestion Resolution"
|
| 27 |
difficulty: medium
|
|
|
|
| 28 |
description: >
|
| 29 |
-
Training throughput
|
| 30 |
-
Reorder
|
| 31 |
- id: hard
|
| 32 |
name: "Asymmetric Compilation Desync Fix"
|
| 33 |
difficulty: hard
|
|
|
|
| 34 |
description: >
|
| 35 |
-
Training
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
observation_space:
|
| 38 |
type: object
|
| 39 |
-
description:
|
|
|
|
|
|
|
| 40 |
action_space:
|
| 41 |
type: object
|
| 42 |
-
description:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
reward_range: [0.0, 1.0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
name: nervousystem-env
|
| 2 |
+
version: "2.0.0"
|
| 3 |
description: >
|
| 4 |
+
Fleet AI environment for autonomous SRE agents managing distributed
|
| 5 |
+
GPU training clusters. Agents act as supervisors orchestrating
|
| 6 |
+
specialized worker agents to diagnose and fix cascading failures
|
| 7 |
+
across 1000+ GPU clusters. Every minute of downtime costs $5,000
|
| 8 |
+
in wasted compute.
|
| 9 |
author: v4xsh
|
| 10 |
tags:
|
| 11 |
- openenv
|
| 12 |
- sre
|
| 13 |
- distributed-training
|
| 14 |
+
- fleet-ai
|
| 15 |
+
- multi-agent
|
| 16 |
+
- long-horizon
|
| 17 |
+
- mercor
|
| 18 |
+
- gpu-infrastructure
|
| 19 |
+
entry_point: "server.app:app"
|
| 20 |
tasks:
|
| 21 |
- id: easy
|
| 22 |
name: "Culprit Rank Identification"
|
| 23 |
difficulty: easy
|
| 24 |
+
max_steps: 50
|
| 25 |
description: >
|
| 26 |
+
Training stalled by OOM on one rank. Identify the failing rank
|
| 27 |
+
using PyTorch 2.5 Flight Recorder inspection.
|
| 28 |
- id: medium
|
| 29 |
name: "Spine Switch Congestion Resolution"
|
| 30 |
difficulty: medium
|
| 31 |
+
max_steps: 50
|
| 32 |
description: >
|
| 33 |
+
Training throughput degraded to 45-65% target due to spine switch
|
| 34 |
+
congestion. Reorder ring topology to restore bandwidth.
|
| 35 |
- id: hard
|
| 36 |
name: "Asymmetric Compilation Desync Fix"
|
| 37 |
difficulty: hard
|
| 38 |
+
max_steps: 50
|
| 39 |
description: >
|
| 40 |
+
Training hung due to different ranks compiling different NCCL
|
| 41 |
+
collectives. Investigate and patch the divergent source file.
|
| 42 |
+
- id: cascade
|
| 43 |
+
name: "Inter-Version Cascade"
|
| 44 |
+
difficulty: cascade
|
| 45 |
+
max_steps: 120
|
| 46 |
+
description: >
|
| 47 |
+
Severity-1 incident: LD_LIBRARY_PATH corruption loads wrong NCCL
|
| 48 |
+
version (2.21.5 vs 2.27.0), triggering a cascade of OOM →
|
| 49 |
+
congestion → desync across the fleet. Solve all 3 phases in order.
|
| 50 |
observation_space:
|
| 51 |
type: object
|
| 52 |
+
description: >
|
| 53 |
+
ClusterObservation with 8-node health states, training metrics,
|
| 54 |
+
surface NCCL logs, step count, episode id, and cumulative token count.
|
| 55 |
action_space:
|
| 56 |
type: object
|
| 57 |
+
description: >
|
| 58 |
+
SREAction with action_type and parameters dict. 8 action types
|
| 59 |
+
including inspect_flight_recorder, query_nccl_logs, topo_reorder,
|
| 60 |
+
patch_divergent_code, restart_rank, reset_ib_interface,
|
| 61 |
+
adjust_sharding_strategy, noop.
|
| 62 |
reward_range: [0.0, 1.0]
|
| 63 |
+
reward_description: >
|
| 64 |
+
Mercor-style efficiency reward: 0.60*R_success + 0.30*R_subgoal
|
| 65 |
+
- 0.10*log(total_tokens). Rewards accurate diagnosis with minimal
|
| 66 |
+
token usage. Destructive actions penalized -0.2.
|
| 67 |
+
multi_agent:
|
| 68 |
+
enabled: true
|
| 69 |
+
architecture: "supervisor-worker"
|
| 70 |
+
workers:
|
| 71 |
+
- log_inspector
|
| 72 |
+
- patch_agent
|
| 73 |
+
- topo_agent
|
| 74 |
+
- version_checker
|
| 75 |
+
endpoint: "/delegate"
|
| 76 |
+
training:
|
| 77 |
+
algorithm: GRPO
|
| 78 |
+
framework: "TRL + Unsloth"
|
| 79 |
+
script: "training/grpo_train.py"
|
| 80 |
+
model: "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"
|
requirements.txt
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
-
fastapi
|
| 2 |
-
uvicorn
|
| 3 |
-
pydantic>=2.0
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
requests
|
| 9 |
-
openenv-core>=0.2.0
|
|
|
|
| 1 |
+
fastapi>=0.111.0
|
| 2 |
+
uvicorn>=0.29.0
|
| 3 |
+
pydantic>=2.7.0
|
| 4 |
+
requests>=2.31.0
|
| 5 |
+
gradio>=4.31.0
|
| 6 |
+
datasets>=2.19.0
|
| 7 |
+
openai>=1.30.0
|
|
|
|
|
|
training/grpo_train.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# NervousSystem-Env — GRPO Training Script
|
| 3 |
+
# ============================================================
|
| 4 |
+
# Colab setup (run these first):
|
| 5 |
+
# !pip install unsloth trl datasets transformers accelerate
|
| 6 |
+
# !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
|
| 7 |
+
# !uvicorn app.main:app --port 7860 & # start env server
|
| 8 |
+
# ============================================================
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import math
|
| 14 |
+
import os
|
| 15 |
+
import random
|
| 16 |
+
import re
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import requests
|
| 20 |
+
import torch
|
| 21 |
+
from datasets import Dataset
|
| 22 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 23 |
+
from unsloth import FastLanguageModel
|
| 24 |
+
|
| 25 |
+
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
|
| 26 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct-bnb-4bit")
|
| 27 |
+
MAX_SEQ_LENGTH = 1024
|
| 28 |
+
LORA_RANK = 16
|
| 29 |
+
|
| 30 |
+
SRE_SYSTEM_PROMPT = """You are an SRE agent managing a distributed
|
| 31 |
+
GPU training cluster. Diagnose and fix failures efficiently.
|
| 32 |
+
|
| 33 |
+
IMPORTANT: You are penalized for using too many tokens.
|
| 34 |
+
Reason concisely. Identify the failure type first, then act directly.
|
| 35 |
+
|
| 36 |
+
Available actions (respond with JSON only):
|
| 37 |
+
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": <0-7>}}
|
| 38 |
+
{"action_type": "query_nccl_logs", "parameters": {"time_window": <int>}}
|
| 39 |
+
{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}
|
| 40 |
+
{"action_type": "patch_divergent_code", "parameters": {"file": "<path>", "fix_type": "synchronize_conditional"}}
|
| 41 |
+
{"action_type": "noop", "parameters": {}}
|
| 42 |
+
|
| 43 |
+
Rules:
|
| 44 |
+
- Respond ONLY with a JSON object, no explanation
|
| 45 |
+
- Check job_status first: stalled=investigate, running=optimize
|
| 46 |
+
- Use inspect_flight_recorder to find failing ranks
|
| 47 |
+
- Use topo_reorder(affinity="rack") for congestion
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
_current_task_id = "easy"
|
| 51 |
+
_prompt_task_map: dict[str, str] = {}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 55 |
+
model_name=MODEL_NAME,
|
| 56 |
+
max_seq_length=MAX_SEQ_LENGTH,
|
| 57 |
+
load_in_4bit=True,
|
| 58 |
+
dtype=None,
|
| 59 |
+
)
|
| 60 |
+
model = FastLanguageModel.get_peft_model(
|
| 61 |
+
model,
|
| 62 |
+
r=LORA_RANK,
|
| 63 |
+
target_modules=[
|
| 64 |
+
"q_proj",
|
| 65 |
+
"v_proj",
|
| 66 |
+
"k_proj",
|
| 67 |
+
"o_proj",
|
| 68 |
+
"gate_proj",
|
| 69 |
+
"up_proj",
|
| 70 |
+
"down_proj",
|
| 71 |
+
],
|
| 72 |
+
lora_alpha=16,
|
| 73 |
+
lora_dropout=0,
|
| 74 |
+
bias="none",
|
| 75 |
+
use_gradient_checkpointing="unsloth",
|
| 76 |
+
random_state=42,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _safe_post(path: str, payload: dict[str, Any], timeout: int = 10) -> dict[str, Any] | None:
|
| 81 |
+
try:
|
| 82 |
+
response = requests.post(f"{ENV_BASE_URL}{path}", json=payload, timeout=timeout)
|
| 83 |
+
response.raise_for_status()
|
| 84 |
+
return response.json()
|
| 85 |
+
except Exception:
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _safe_get(path: str, timeout: int = 5) -> dict[str, Any] | None:
|
| 90 |
+
try:
|
| 91 |
+
response = requests.get(f"{ENV_BASE_URL}{path}", timeout=timeout)
|
| 92 |
+
response.raise_for_status()
|
| 93 |
+
return response.json()
|
| 94 |
+
except Exception:
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _prompt_key(prompt: Any) -> str:
|
| 99 |
+
try:
|
| 100 |
+
return json.dumps(prompt, sort_keys=True)
|
| 101 |
+
except Exception:
|
| 102 |
+
return str(prompt)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _task_id_from_prompt(prompt: Any) -> str:
|
| 106 |
+
global _current_task_id
|
| 107 |
+
key = _prompt_key(prompt)
|
| 108 |
+
task_id = _prompt_task_map.get(key, _current_task_id)
|
| 109 |
+
_current_task_id = task_id
|
| 110 |
+
return task_id
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _extract_json_action(completion: str) -> dict[str, Any] | None:
|
| 114 |
+
text = completion.strip()
|
| 115 |
+
if text.startswith("```"):
|
| 116 |
+
text = "\n".join(line for line in text.splitlines() if not line.strip().startswith("```"))
|
| 117 |
+
try:
|
| 118 |
+
parsed = json.loads(text)
|
| 119 |
+
if isinstance(parsed, dict) and "action_type" in parsed:
|
| 120 |
+
parsed.setdefault("parameters", {})
|
| 121 |
+
return parsed
|
| 122 |
+
except Exception:
|
| 123 |
+
pass
|
| 124 |
+
|
| 125 |
+
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
|
| 126 |
+
if not match:
|
| 127 |
+
return None
|
| 128 |
+
try:
|
| 129 |
+
parsed = json.loads(match.group(0))
|
| 130 |
+
if isinstance(parsed, dict) and "action_type" in parsed:
|
| 131 |
+
parsed.setdefault("parameters", {})
|
| 132 |
+
return parsed
|
| 133 |
+
except Exception:
|
| 134 |
+
return None
|
| 135 |
+
return None
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def make_sre_dataset(n_samples: int = 200) -> Dataset:
|
| 139 |
+
"""
|
| 140 |
+
Generate prompt-only dataset for GRPO.
|
| 141 |
+
Each sample is one initial observation from the env.
|
| 142 |
+
GRPO generates completions and scores them via the reward fn.
|
| 143 |
+
|
| 144 |
+
For each sample:
|
| 145 |
+
- Pick task_id randomly from ["easy", "medium", "hard"]
|
| 146 |
+
(skip cascade for initial training — too long)
|
| 147 |
+
- Pick seed randomly from range(1000)
|
| 148 |
+
- Call POST /reset with task_id and seed
|
| 149 |
+
- Format the observation as the user prompt
|
| 150 |
+
- Return dataset with column "prompt" containing
|
| 151 |
+
[{"role": "system", "content": SRE_SYSTEM_PROMPT},
|
| 152 |
+
{"role": "user", "content": <observation_json>}]
|
| 153 |
+
|
| 154 |
+
observation_json format:
|
| 155 |
+
{
|
| 156 |
+
"job_status": ...,
|
| 157 |
+
"throughput": ...,
|
| 158 |
+
"target_throughput": ...,
|
| 159 |
+
"stalled_steps": ...,
|
| 160 |
+
"node_health": [...],
|
| 161 |
+
"visible_logs": [...],
|
| 162 |
+
"task_hint": "Diagnose and fix the cluster failure."
|
| 163 |
+
}
|
| 164 |
+
"""
|
| 165 |
+
global _current_task_id
|
| 166 |
+
|
| 167 |
+
rows: list[dict[str, Any]] = []
|
| 168 |
+
task_pool = ["easy", "medium", "hard"]
|
| 169 |
+
|
| 170 |
+
for _ in range(n_samples):
|
| 171 |
+
task_id = random.choice(task_pool)
|
| 172 |
+
seed = random.randint(0, 999)
|
| 173 |
+
reset_result = _safe_post("/reset", {"task_id": task_id, "seed": seed})
|
| 174 |
+
if reset_result is None:
|
| 175 |
+
continue
|
| 176 |
+
|
| 177 |
+
training = reset_result.get("training", {})
|
| 178 |
+
nodes = reset_result.get("nodes", [])
|
| 179 |
+
observation_payload = {
|
| 180 |
+
"job_status": training.get("job_status", "unknown"),
|
| 181 |
+
"throughput": training.get("throughput_tokens_per_sec", 0.0),
|
| 182 |
+
"target_throughput": training.get("target_throughput", 0.0),
|
| 183 |
+
"stalled_steps": training.get("stalled_steps", 0),
|
| 184 |
+
"node_health": [
|
| 185 |
+
{
|
| 186 |
+
"node_id": node.get("node_id"),
|
| 187 |
+
"health_status": node.get("health_status"),
|
| 188 |
+
"xid_errors": node.get("xid_errors", []),
|
| 189 |
+
}
|
| 190 |
+
for node in nodes
|
| 191 |
+
],
|
| 192 |
+
"visible_logs": reset_result.get("visible_logs", []),
|
| 193 |
+
"task_hint": "Diagnose and fix the cluster failure.",
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
prompt = [
|
| 197 |
+
{"role": "system", "content": SRE_SYSTEM_PROMPT},
|
| 198 |
+
{"role": "user", "content": json.dumps(observation_payload, ensure_ascii=False)},
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
+
_current_task_id = task_id
|
| 202 |
+
_prompt_task_map[_prompt_key(prompt)] = task_id
|
| 203 |
+
rows.append({"prompt": prompt})
|
| 204 |
+
|
| 205 |
+
if not rows:
|
| 206 |
+
fallback_prompt = [
|
| 207 |
+
{"role": "system", "content": SRE_SYSTEM_PROMPT},
|
| 208 |
+
{
|
| 209 |
+
"role": "user",
|
| 210 |
+
"content": json.dumps(
|
| 211 |
+
{
|
| 212 |
+
"job_status": "stalled",
|
| 213 |
+
"throughput": 0.0,
|
| 214 |
+
"target_throughput": 9000.0,
|
| 215 |
+
"stalled_steps": 1,
|
| 216 |
+
"node_health": [],
|
| 217 |
+
"visible_logs": ["Server offline during dataset build"],
|
| 218 |
+
"task_hint": "Diagnose and fix the cluster failure.",
|
| 219 |
+
}
|
| 220 |
+
),
|
| 221 |
+
},
|
| 222 |
+
]
|
| 223 |
+
_prompt_task_map[_prompt_key(fallback_prompt)] = "easy"
|
| 224 |
+
rows.append({"prompt": fallback_prompt})
|
| 225 |
+
|
| 226 |
+
return Dataset.from_list(rows)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def sre_reward_fn(
|
| 230 |
+
completions: list[str],
|
| 231 |
+
prompts: list[Any],
|
| 232 |
+
**kwargs: Any,
|
| 233 |
+
) -> list[float]:
|
| 234 |
+
"""
|
| 235 |
+
Called by GRPOTrainer to score each completion.
|
| 236 |
+
|
| 237 |
+
For each completion:
|
| 238 |
+
1. Parse the JSON action from the completion string
|
| 239 |
+
2. POST the action to /step
|
| 240 |
+
3. Extract reward.value and reward.token_efficiency_score
|
| 241 |
+
4. Apply MER formula:
|
| 242 |
+
tokens = len(completion.split()) # word count proxy
|
| 243 |
+
mer = max(0.01, min(0.99,
|
| 244 |
+
0.60 * r_success + 0.30 * step_reward - 0.10 * math.log(max(1, tokens))
|
| 245 |
+
))
|
| 246 |
+
where r_success = 1.0 if job_status in {"recovered","running"} else 0.0
|
| 247 |
+
5. Return mer as the reward for this completion
|
| 248 |
+
|
| 249 |
+
If parse fails or /step errors: return 0.01
|
| 250 |
+
If server is offline: return 0.01
|
| 251 |
+
|
| 252 |
+
IMPORTANT: Each call to sre_reward_fn must first call /reset
|
| 253 |
+
to get a fresh episode state before stepping.
|
| 254 |
+
"""
|
| 255 |
+
rewards: list[float] = []
|
| 256 |
+
|
| 257 |
+
for index, completion in enumerate(completions):
|
| 258 |
+
prompt = prompts[index] if index < len(prompts) else None
|
| 259 |
+
task_id = _task_id_from_prompt(prompt)
|
| 260 |
+
seed = random.randint(0, 999)
|
| 261 |
+
reset_result = _safe_post("/reset", {"task_id": task_id, "seed": seed})
|
| 262 |
+
if reset_result is None:
|
| 263 |
+
rewards.append(0.01)
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
action = _extract_json_action(completion)
|
| 267 |
+
if action is None:
|
| 268 |
+
rewards.append(0.01)
|
| 269 |
+
continue
|
| 270 |
+
|
| 271 |
+
step_result = _safe_post("/step", action)
|
| 272 |
+
if step_result is None:
|
| 273 |
+
rewards.append(0.01)
|
| 274 |
+
continue
|
| 275 |
+
|
| 276 |
+
reward_obj = step_result.get("reward", {})
|
| 277 |
+
step_reward = float(reward_obj.get("value", 0.01))
|
| 278 |
+
observation = step_result.get("observation", {})
|
| 279 |
+
job_status = str(observation.get("training", {}).get("job_status", "unknown"))
|
| 280 |
+
r_success = 1.0 if job_status in {"recovered", "running"} else 0.0
|
| 281 |
+
|
| 282 |
+
tokens = len(completion.split())
|
| 283 |
+
mer = max(
|
| 284 |
+
0.01,
|
| 285 |
+
min(
|
| 286 |
+
0.99,
|
| 287 |
+
0.60 * r_success
|
| 288 |
+
+ 0.30 * step_reward
|
| 289 |
+
- 0.10 * math.log(max(1, tokens)),
|
| 290 |
+
),
|
| 291 |
+
)
|
| 292 |
+
rewards.append(float(mer))
|
| 293 |
+
|
| 294 |
+
while len(rewards) < len(completions):
|
| 295 |
+
rewards.append(0.01)
|
| 296 |
+
return rewards
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
training_args = GRPOConfig(
|
| 300 |
+
output_dir="./sre_grpo_output",
|
| 301 |
+
num_train_epochs=1,
|
| 302 |
+
per_device_train_batch_size=1,
|
| 303 |
+
gradient_accumulation_steps=4,
|
| 304 |
+
learning_rate=5e-6,
|
| 305 |
+
max_grad_norm=0.1,
|
| 306 |
+
warmup_ratio=0.1,
|
| 307 |
+
lr_scheduler_type="cosine",
|
| 308 |
+
logging_steps=1,
|
| 309 |
+
save_steps=50,
|
| 310 |
+
report_to="none",
|
| 311 |
+
num_generations=4,
|
| 312 |
+
max_new_tokens=128,
|
| 313 |
+
temperature=0.7,
|
| 314 |
+
beta=0.001,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def plot_reward_curve(trainer: GRPOTrainer) -> None:
|
| 319 |
+
"""Print reward progression as ASCII bar chart."""
|
| 320 |
+
history = trainer.state.log_history
|
| 321 |
+
rewards = [
|
| 322 |
+
entry.get("reward", entry.get("train/reward", 0.0))
|
| 323 |
+
for entry in history
|
| 324 |
+
if "reward" in entry or "train/reward" in entry
|
| 325 |
+
]
|
| 326 |
+
if not rewards:
|
| 327 |
+
print("No reward history found.")
|
| 328 |
+
return
|
| 329 |
+
print("\n=== REWARD CURVE ===")
|
| 330 |
+
max_r = max(rewards) if rewards else 1.0
|
| 331 |
+
for i, reward in enumerate(rewards):
|
| 332 |
+
bar = "█" * int((reward / max(0.01, max_r)) * 30)
|
| 333 |
+
print(f" step {i + 1:3d}: {reward:.3f} {bar}")
|
| 334 |
+
print(f"\nInitial reward: {rewards[0]:.3f}")
|
| 335 |
+
print(f"Final reward: {rewards[-1]:.3f}")
|
| 336 |
+
delta = rewards[-1] - rewards[0]
|
| 337 |
+
print(f"Improvement: {delta:+.3f}")
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
if __name__ == "__main__":
|
| 341 |
+
random.seed(42)
|
| 342 |
+
torch.manual_seed(42)
|
| 343 |
+
|
| 344 |
+
try:
|
| 345 |
+
health = _safe_get("/health")
|
| 346 |
+
assert health is not None and health.get("status") == "ok"
|
| 347 |
+
print(f"✅ Server healthy at {ENV_BASE_URL}")
|
| 348 |
+
except Exception as exc:
|
| 349 |
+
print(f"❌ Server not reachable: {exc}")
|
| 350 |
+
print("Start it with: uvicorn app.main:app --port 7860")
|
| 351 |
+
raise SystemExit(1)
|
| 352 |
+
|
| 353 |
+
dataset = make_sre_dataset(n_samples=200)
|
| 354 |
+
print(f"✅ Dataset: {len(dataset)} samples")
|
| 355 |
+
|
| 356 |
+
trainer = GRPOTrainer(
|
| 357 |
+
model=model,
|
| 358 |
+
args=training_args,
|
| 359 |
+
train_dataset=dataset,
|
| 360 |
+
reward_funcs=sre_reward_fn,
|
| 361 |
+
processing_class=tokenizer,
|
| 362 |
+
)
|
| 363 |
+
trainer.train()
|
| 364 |
+
|
| 365 |
+
plot_reward_curve(trainer)
|
| 366 |
+
model.save_pretrained("sre_agent_lora")
|
| 367 |
+
tokenizer.save_pretrained("sre_agent_lora")
|
| 368 |
+
print("✅ Model saved to sre_agent_lora/")
|