SevZero Bot commited on
Commit
0f5092c
·
1 Parent(s): 382d0fd

Add Wave 1 training pipeline (SFT/GRPO/eval/preflight/launch) + gitignore hardening

Browse files
.gitignore CHANGED
@@ -1,13 +1,34 @@
1
  # Documentation and research (not part of the submission)
2
  Docs/
3
-
4
- # OpenEnv preparatory course (dev reference only, not part of submission)
5
  openenv-course/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Python
8
  __pycache__/
9
  *.pyc
10
  *.pyo
 
 
 
11
 
12
- # Environment
13
- .env
 
 
 
 
1
  # Documentation and research (not part of the submission)
2
  Docs/
3
+ DocsR2/
 
4
  openenv-course/
5
+ playbook/
6
+
7
+ # Secrets — NEVER commit
8
+ .env
9
+ *.env
10
+ api.env
11
+ hg.env
12
+
13
+ # Training artefacts
14
+ training/data/raw/
15
+ training/.preflight_grpo/
16
+ training/runs.jsonl
17
+ outputs/
18
+ out/
19
+ wandb/
20
+ trackio/
21
 
22
  # Python
23
  __pycache__/
24
  *.pyc
25
  *.pyo
26
+ *.egg-info/
27
+ .venv/
28
+ venv/
29
 
30
+ # OS / editor
31
+ .DS_Store
32
+ Thumbs.db
33
+ .idea/
34
+ .vscode/
training/README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SevZero — training (Round 2)
2
+
3
+ One-liner per script:
4
+
5
+ - **`train_sft.py`**: SFT on `Mist-ic/sevzero-expert-trajectories` with QLoRA (Unsloth or PEFT fallback) → push adapter with `HF_TOKEN`.
6
+ - **`train_grpo.py`**: GRPO with `rollout_func` + remote env (`SEVZERO_ENV_URL`); vLLM colocate, Trackio `Mist-ic/sevzero-trackio`.
7
+ - **`eval.py`**: Compare HF adapters and frontier models; write `eval_results.csv`, push `Mist-ic/sevzero-eval-results` with `HF_MAIN_TOKEN`.
8
+ - **`preflight.py`**: In-process grader + tiny GRPO smoke (5 steps) on CPU; starts local uvicorn.
9
+ - **`launch_hf_job.py`**: `huggingface_hub.run_job` wrapper; `--hardware l40sx1` (verify with `hf jobs hardware`).
10
+
11
+ ## Env files
12
+
13
+ Load with `python-dotenv` (auto-tried in `config_utils`):
14
+
15
+ - `hg.env` — `HF_TOKEN` (worker), `HF_MAIN_TOKEN` (Mist-ic, Trackio + eval dataset)
16
+ - `api.env` — `GEMINI_API_KEY`, `AZURE_*` for `eval.py`
17
+
18
+ | Variable | Role |
19
+ |----------|------|
20
+ | `HF_TOKEN` | Worker: train pushes, private adapter pulls |
21
+ | `HF_MAIN_TOKEN` | `Mist-ic`: Trackio + `sevzero-eval-results` only |
22
+ | `SEVZERO_ENV_URL` | HTTP base of SevZero Space/ server for GRPO + eval + preflight |
23
+ | `GEMINI_API_KEY` | Direct Gemini in eval |
24
+ | `AZURE_API_KEY` | Azure OpenAI + Azure AI Inference |
25
+ | `AZURE_OPENAI_ENDPOINT` | Deployment base for gpt-5.4-pro |
26
+ | `AZURE_AI_INFERENCE_ENDPOINT` | For grok / kimi / DeepSeek in eval |
27
+ | `AZURE_API_VERSION` | OpenAI client version header if needed |
28
+ | `GEMINI_EVAL_MODEL` | Optional override (default set in `eval.py`) |
29
+
30
+ ## Local debug (from repo root)
31
+
32
+ ```bash
33
+ # Install (pin versions in comments / orchestrator)
34
+ pip install -e ".[training]"
35
+
36
+ # SFT
37
+ python training/train_sft.py --output_dir ./out/sft --max_steps 10 --push_to_hub_repo "" --variant_name test
38
+
39
+ # GRPO (remote env required)
40
+ $env:SEVZERO_ENV_URL="https://<your-sevzero-space>.hf.space"
41
+ python training/train_grpo.py --sft_adapter_repo YOUR/adapters --max_steps 5 --output_dir ./out/grpo
42
+ ```
43
+
44
+ ## Wave 3 — three GRPO variants (see `playbook/00-orchestration.md`)
45
+
46
+ Primary (PhaseOfCode):
47
+
48
+ ```bash
49
+ python training/train_grpo.py --sft_adapter_repo PhaseOfCode/sevzero-llama3-8b-sft --K 4 --lr 7e-6 --max_steps 350 --variant_name primary
50
+ ```
51
+
52
+ Stability (NoahInOblivion):
53
+
54
+ ```bash
55
+ python training/train_grpo.py --sft_adapter_repo NoahInOblivion/sevzero-llama3-8b-sft --K 8 --lr 5e-6 --max_steps 350 --variant_name stability
56
+ ```
57
+
58
+ Innovation (NoxIsOblivion, env flags on):
59
+
60
+ ```bash
61
+ python training/train_grpo.py --sft_adapter_repo NoxIsOblivion/sevzero-llama3-8b-sft --enable_schema_drift --enable_curriculum --K 4 --max_steps 350 --variant_name innovation
62
+ ```
63
+
64
+ **HF Job (after merge + public git URL or bucket):**
65
+
66
+ ```bash
67
+ $env:HF_TOKEN="<worker>"
68
+ $env:SEVZERO_ENV_URL="https://....hf.space"
69
+ python training/launch_hf_job.py --script grpo --variant_name primary -- --sft_adapter_repo YOUR/sevzero-llama3-8b-sft
70
+ ```
71
+
72
+ **Dependency pins:** run `pip index versions trl openenv-core unsloth` and `python -c "import trl; print(trl.__version__)"` after install; pin in the orchestrator’s lock, not in this file.
training/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Training / trajectory pipeline (Round 2)
training/build_dataset.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Build Llama-3.1-8B-Instruct SFT jsonl from raw trajectory jsonl (score ≥ 0.85).
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import random
9
+ import sys
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Set, Tuple
12
+
13
+ from dotenv import load_dotenv
14
+
15
+ REPO_ROOT = Path(__file__).resolve().parent.parent
16
+ if str(REPO_ROOT) not in sys.path:
17
+ sys.path.insert(0, str(REPO_ROOT))
18
+
19
+ from inference import SYSTEM_PROMPT # noqa: E402
20
+
21
+ load_dotenv(REPO_ROOT / "api.env")
22
+ load_dotenv(REPO_ROOT / "hg.env")
23
+
24
+ DATA_DIR = REPO_ROOT / "training" / "data"
25
+ RAW_GLOB = "raw/*.jsonl"
26
+ OUT_TRAIN = DATA_DIR / "sft_train.jsonl"
27
+ OUT_EVAL = DATA_DIR / "sft_eval.jsonl"
28
+ OUT_STATS = DATA_DIR / "build_stats.json"
29
+
30
+ MAX_OBS_TOKENS = 2048
31
+
32
+
33
+ def _get_tokenizer():
34
+ import os
35
+
36
+ try:
37
+ from transformers import AutoTokenizer
38
+ except Exception:
39
+ return None
40
+ name = "meta-llama/Llama-3.1-8B-Instruct"
41
+ try:
42
+ tok = AutoTokenizer.from_pretrained(
43
+ name, token=os.environ.get("HF_MAIN_TOKEN")
44
+ )
45
+ return tok
46
+ except Exception:
47
+ try:
48
+ return AutoTokenizer.from_pretrained(
49
+ "hf-internal-testing/llama-tokenizer"
50
+ )
51
+ except Exception:
52
+ return None
53
+
54
+
55
+ def _count_tokens(toker, text: str) -> int:
56
+ if toker is not None:
57
+ return len(toker.encode(text, add_special_tokens=False))
58
+ return max(1, len(text) // 4)
59
+
60
+
61
+ def _shrink_observation(obs: Dict[str, Any], toker, max_toks: int) -> str:
62
+ """Serialize observation to JSON, shrink until user message fits max_toks (approximate)."""
63
+ o = {k: v for k, v in obs.items() if k not in ("reward",)}
64
+ order_drop = [
65
+ "metric_history",
66
+ "traces",
67
+ "logs",
68
+ "actions_taken",
69
+ "recent_deploys",
70
+ ]
71
+ for _ in range(40):
72
+ text = json.dumps(o, ensure_ascii=False, separators=(",", ":"), default=str)
73
+ tcount = _count_tokens(toker, text)
74
+ if tcount <= max_toks:
75
+ return text
76
+ shrunk = False
77
+ for k in order_drop:
78
+ if k in o and o[k]:
79
+ o[k] = None
80
+ if k == "actions_taken":
81
+ o[k] = []
82
+ elif k in ("metric_history", "recent_deploys"):
83
+ o[k] = []
84
+ shrunk = True
85
+ break
86
+ if shrunk:
87
+ continue
88
+ if "services" in o and isinstance(o["services"], list) and len(o["services"]) > 2:
89
+ o["services"] = o["services"][: max(1, len(o["services"]) - 1)]
90
+ continue
91
+ if "alerts" in o and isinstance(o["alerts"], list) and len(o["alerts"]) > 1:
92
+ o["alerts"] = o["alerts"][: max(0, len(o["alerts"]) - 1)]
93
+ continue
94
+ o["__truncated__"] = True
95
+ break
96
+ return json.dumps(o, ensure_ascii=False, separators=(",", ":"), default=str)
97
+
98
+
99
+ def _episode_id(ep: Dict[str, Any]) -> str:
100
+ return f"{ep.get('model', '')}|{ep.get('task_id', '')}|{ep.get('seed', 0)}"
101
+
102
+
103
+ def _assistant_action_json(action: Any) -> str:
104
+ if not isinstance(action, dict):
105
+ return json.dumps(
106
+ {"action_type": "noop", "params": {}}, ensure_ascii=False
107
+ )
108
+ a = {
109
+ "action_type": str(action.get("action_type", "noop")),
110
+ "params": action.get("params") or {},
111
+ }
112
+ return json.dumps(a, ensure_ascii=False)
113
+
114
+
115
+ def _load_episodes_from_raw(raw_dir: Path) -> List[Dict[str, Any]]:
116
+ out: List[Dict[str, Any]] = []
117
+ for p in sorted(raw_dir.glob("*.jsonl")):
118
+ with p.open(encoding="utf-8") as f:
119
+ for line in f:
120
+ line = line.strip()
121
+ if not line:
122
+ continue
123
+ out.append(json.loads(line))
124
+ return out
125
+
126
+
127
+ def build(
128
+ min_score: float = 0.85,
129
+ ) -> Dict[str, Any]:
130
+ toker = _get_tokenizer()
131
+ raw_dir = DATA_DIR / "raw"
132
+ episodes = _load_episodes_from_raw(raw_dir)
133
+ kept: List[Dict[str, Any]] = []
134
+ dropped: List[Dict[str, Any]] = []
135
+ for ep in episodes:
136
+ sc = float(ep.get("final_score", 0.0) or 0.0)
137
+ if sc >= min_score and ep.get("steps"):
138
+ kept.append(ep)
139
+ else:
140
+ dropped.append(ep)
141
+
142
+ eids = [_episode_id(e) for e in kept]
143
+ unique_eids = list(dict.fromkeys(eids))
144
+ n_ep = len(unique_eids)
145
+ rng = random.Random(42)
146
+ rng.shuffle(unique_eids)
147
+ if n_ep <= 1:
148
+ n_eval = 0
149
+ else:
150
+ n_eval = max(1, n_ep // 10)
151
+ eval_ids: Set[str] = set(unique_eids[:n_eval]) if n_eval else set()
152
+
153
+ train_rows: List[Dict[str, Any]] = []
154
+ eval_rows: List[Dict[str, Any]] = []
155
+ max_prompt_toks = 0
156
+
157
+ for ep in kept:
158
+ eid = _episode_id(ep)
159
+ is_eval = eid in eval_ids
160
+ for st in ep.get("steps", []):
161
+ obs = st.get("observation", {})
162
+ if not isinstance(obs, dict):
163
+ continue
164
+ user_str = _shrink_observation(obs, toker, MAX_OBS_TOKENS)
165
+ messages = [
166
+ {"role": "system", "content": SYSTEM_PROMPT},
167
+ {"role": "user", "content": user_str},
168
+ {
169
+ "role": "assistant",
170
+ "content": _assistant_action_json(st.get("action", {})),
171
+ },
172
+ ]
173
+ if toker is not None:
174
+ try:
175
+ plen = len(
176
+ toker.apply_chat_template(
177
+ messages, tokenize=True, add_generation_prompt=False
178
+ )
179
+ )
180
+ except Exception:
181
+ plen = _count_tokens(
182
+ toker, SYSTEM_PROMPT + "\n" + user_str
183
+ )
184
+ else:
185
+ plen = _count_tokens(
186
+ None, SYSTEM_PROMPT + "\n" + user_str
187
+ )
188
+ max_prompt_toks = max(max_prompt_toks, plen)
189
+ row = {
190
+ "messages": messages,
191
+ "meta": {
192
+ "episode_id": eid,
193
+ "model": ep.get("model"),
194
+ "task_id": ep.get("task_id"),
195
+ "seed": ep.get("seed"),
196
+ "step": st.get("step"),
197
+ "episode_score": ep.get("final_score"),
198
+ },
199
+ }
200
+ if is_eval:
201
+ eval_rows.append(row)
202
+ else:
203
+ train_rows.append(row)
204
+
205
+ scores = [float(x.get("final_score", 0) or 0) for x in kept]
206
+ mean_sc = sum(scores) / len(scores) if scores else 0.0
207
+
208
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
209
+ with OUT_TRAIN.open("w", encoding="utf-8") as ft:
210
+ for r in train_rows:
211
+ ft.write(json.dumps(r, ensure_ascii=False) + "\n")
212
+ with OUT_EVAL.open("w", encoding="utf-8") as fe:
213
+ for r in eval_rows:
214
+ fe.write(json.dumps(r, ensure_ascii=False) + "\n")
215
+
216
+ stats: Dict[str, Any] = {
217
+ "episodes_total_seen": len(episodes),
218
+ "episodes_kept": len(kept),
219
+ "episodes_dropped": len(dropped),
220
+ "mean_episode_score_kept": round(mean_sc, 6),
221
+ "train_rows": len(train_rows),
222
+ "eval_rows": len(eval_rows),
223
+ "max_prompt_token_length": max_prompt_toks,
224
+ "max_observation_user_token_budget": MAX_OBS_TOKENS,
225
+ "min_score_filter": min_score,
226
+ }
227
+ with OUT_STATS.open("w", encoding="utf-8") as f:
228
+ json.dump(stats, f, indent=2)
229
+ print(json.dumps(stats, indent=2), flush=True)
230
+ return stats
231
+
232
+
233
+ def main() -> None:
234
+ ap = argparse.ArgumentParser()
235
+ ap.add_argument("--min-score", type=float, default=0.85)
236
+ args = ap.parse_args()
237
+ build(min_score=args.min_score)
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
training/collect_trajectories.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Collect expert trajectories for SevZero SFT (Round 2).
3
+
4
+ Loads API keys from api.env and hg.env (gitignored). Does not log secrets.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import copy
10
+ import difflib
11
+ import json
12
+ import os
13
+ import re
14
+ import subprocess
15
+ import sys
16
+ import time
17
+ from dataclasses import dataclass, field
18
+ from pathlib import Path
19
+ from typing import Any, Dict, List, Optional, Set, Tuple
20
+
21
+ import httpx
22
+ from dotenv import load_dotenv
23
+ from openai import AzureOpenAI
24
+ from pydantic import BaseModel, Field
25
+
26
+ # Repo root: parent of training/
27
+ REPO_ROOT = Path(__file__).resolve().parent.parent
28
+ if str(REPO_ROOT) not in sys.path:
29
+ sys.path.insert(0, str(REPO_ROOT))
30
+
31
+ from inference import ( # noqa: E402
32
+ build_observation_prompt,
33
+ parse_action,
34
+ )
35
+ from inference import SYSTEM_PROMPT as _BASE_SYSTEM # noqa: E402
36
+
37
+ load_dotenv(REPO_ROOT / "api.env")
38
+ load_dotenv(REPO_ROOT / "hg.env")
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Config matrix (must match spec)
42
+ # ---------------------------------------------------------------------------
43
+
44
+ GEMINI_SEEDS = [
45
+ 42, 123, 7, 11, 23, 31, 47, 59, 67, 71, 83, 89, 97, 101, 109, 113, 127, 131, 137, 149
46
+ ]
47
+ GPT_SEEDS = [
48
+ 42, 123, 7, 13, 17, 19, 29, 37, 41, 43, 53, 61, 73, 79, 83, 89, 97, 101, 103, 107
49
+ ]
50
+ GROK_EXTRA_SEEDS = [13, 17, 19, 29, 37, 41, 43, 53, 61, 73]
51
+
52
+ # Combined pool for grok / kimi / deepseek (any from grok list + full Gemini list)
53
+ GROK_KIMI_POOL: List[int] = sorted(set(GEMINI_SEEDS) | set(GROK_EXTRA_SEEDS))
54
+
55
+ MODEL_GEMINI = "gemini-3.1-pro-preview"
56
+ MODEL_GPT = "gpt-5.4-pro"
57
+ MODEL_GROK = "grok-4.20-reasoning"
58
+ MODEL_KIMI = "kimi-k2.6"
59
+ MODEL_DEEPSEEK = "DeepSeek-V3.2"
60
+ ALL_CANON = {MODEL_GEMINI, MODEL_GPT, MODEL_GROK, MODEL_KIMI, MODEL_DEEPSEEK}
61
+
62
+
63
+ def _split_seeds(
64
+ pool: List[int], counts: Tuple[int, int, int], offset: int
65
+ ) -> List[Tuple[str, int]]:
66
+ """Return list of (task_id, seed) in order easy, medium, hard."""
67
+ c_e, c_m, c_h = counts
68
+ n = len(pool)
69
+ if n == 0:
70
+ return []
71
+ o = [pool[(i + offset) % n] for i in range(n)]
72
+ out: List[Tuple[str, int]] = []
73
+ i = 0
74
+ for _ in range(c_e):
75
+ out.append(("easy", o[i % len(o)]))
76
+ i += 1
77
+ for _ in range(c_m):
78
+ out.append(("medium", o[i % len(o)]))
79
+ i += 1
80
+ for _ in range(c_h):
81
+ out.append(("hard", o[i % len(o)]))
82
+ i += 1
83
+ return out
84
+
85
+
86
+ def plan_gemini(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
87
+ return [
88
+ (MODEL_GEMINI, t, s)
89
+ for t, s in _split_seeds(GEMINI_SEEDS, (c_e, c_m, c_h), offset=0)
90
+ ]
91
+
92
+
93
+ def plan_gpt(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
94
+ return [
95
+ (MODEL_GPT, t, s)
96
+ for t, s in _split_seeds(GPT_SEEDS, (c_e, c_m, c_h), offset=0)
97
+ ]
98
+
99
+
100
+ def plan_grok(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
101
+ return [
102
+ (MODEL_GROK, t, s)
103
+ for t, s in _split_seeds(GROK_KIMI_POOL, (c_e, c_m, c_h), offset=0)
104
+ ]
105
+
106
+
107
+ def plan_kimi(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
108
+ return [
109
+ (MODEL_KIMI, t, s)
110
+ for t, s in _split_seeds(GROK_KIMI_POOL, (c_e, c_m, c_h), offset=7)
111
+ ]
112
+
113
+
114
+ def plan_deepseek(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
115
+ return [
116
+ (MODEL_DEEPSEEK, t, s)
117
+ for t, s in _split_seeds(GROK_KIMI_POOL, (c_e, c_m, c_h), offset=3)
118
+ ]
119
+
120
+
121
+ def full_plan(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
122
+ return (
123
+ plan_gemini(c_e, c_m, c_h)
124
+ + plan_gpt(c_e, c_m, c_h)
125
+ + plan_grok(c_e, c_m, c_h)
126
+ + plan_kimi(c_e, c_m, c_h)
127
+ + plan_deepseek(c_e, c_m, c_h)
128
+ )
129
+
130
+
131
+ # Rough USD cost tracking (tunable; for guardrail only)
132
+ @dataclass
133
+ class CostTracker:
134
+ usd: float = 0.0
135
+ budget: float = 5.0
136
+ by_model: Dict[str, float] = field(default_factory=dict)
137
+ per_model_max: float = 2.0
138
+
139
+ def add(self, model: str, usd: float) -> None:
140
+ self.usd += usd
141
+ self.by_model[model] = self.by_model.get(model, 0.0) + usd
142
+ m = self.by_model[model]
143
+ cap = self.per_model_max
144
+ if m > cap:
145
+ raise RuntimeError(
146
+ f"Model {model} exceeded ${cap:.2f} in estimated spend (${m:.2f}); stopping per cap."
147
+ )
148
+ if self.usd > self.budget:
149
+ raise RuntimeError(
150
+ f"Total estimated API spend ${self.usd:.2f} exceeded budget ${self.budget:.2f}."
151
+ )
152
+
153
+
154
+ def _estimate_openai_style_cost(
155
+ model: str, prompt_tokens: int, completion_tokens: int
156
+ ) -> float:
157
+ # Conservative blended rate per 1K tokens (USD) — for guardrails only
158
+ if "gemini" in model:
159
+ p, c = 0.00125, 0.01
160
+ elif "gpt" in model.lower() or "5.4" in model:
161
+ p, c = 0.0025, 0.01
162
+ else:
163
+ p, c = 0.001, 0.006
164
+ return (prompt_tokens * p + completion_tokens * c) / 1000.0
165
+
166
+
167
+ # ---------------------------------------------------------------------------
168
+ # Pydantic for Gemini structured action JSON
169
+ # ---------------------------------------------------------------------------
170
+
171
+
172
+ class AgentActionOut(BaseModel):
173
+ action_type: str
174
+ params: Dict[str, Any] = Field(default_factory=dict)
175
+
176
+
177
+ # ---------------------------------------------------------------------------
178
+ # Azure deployment self-heal
179
+ # ---------------------------------------------------------------------------
180
+
181
+
182
+ def _is_not_found(err: str) -> bool:
183
+ s = (err or "").lower()
184
+ return "deploymentnotfound" in s or "deployment" in s and "not found" in s
185
+
186
+
187
+ def list_azure_openai_deployments() -> List[str]:
188
+ key = os.environ.get("AZURE_API_KEY", "")
189
+ ep = (os.environ.get("AZURE_OPENAI_ENDPOINT", "") or "").rstrip("/")
190
+ ver = os.environ.get("AZURE_API_VERSION", "2024-12-01-preview")
191
+ if not key or not ep:
192
+ return []
193
+ url = f"{ep}/openai/deployments?api-version={ver}"
194
+ try:
195
+ r = httpx.get(url, headers={"api-key": key}, timeout=30.0)
196
+ r.raise_for_status()
197
+ data = r.json()
198
+ return [d.get("id", "") for d in data.get("value", []) if d.get("id")]
199
+ except Exception:
200
+ return []
201
+
202
+
203
+ def list_foundry_deployments() -> List[str]:
204
+ """
205
+ Best-effort: project endpoint may expose deployments; schema varies.
206
+ """
207
+ fe = (os.environ.get("AZURE_FOUNDRY_PROJECT_ENDPOINT", "") or "").rstrip("/")
208
+ key = os.environ.get("AZURE_API_KEY", "")
209
+ if not fe or not key:
210
+ return []
211
+ for suffix in ("/deployments", "/openai/models"):
212
+ try:
213
+ url = f"{fe}{suffix}"
214
+ r = httpx.get(
215
+ url, headers={"api-key": key}, params={"api-version": "2024-12-01-preview"}, timeout=30.0
216
+ )
217
+ if r.status_code != 200:
218
+ continue
219
+ data = r.json()
220
+ if isinstance(data, list):
221
+ return [str(x.get("id", x)) for x in data if isinstance(x, dict)]
222
+ if "value" in data:
223
+ return [d.get("id", "") for d in data.get("value", []) if d.get("id")]
224
+ except Exception:
225
+ continue
226
+ return []
227
+
228
+
229
+ def pick_closest(name: str, options: List[str]) -> str:
230
+ if not options:
231
+ return name
232
+ if name in options:
233
+ return name
234
+ ranked = difflib.get_close_matches(name, options, n=1, cutoff=0.2)
235
+ if ranked:
236
+ return ranked[0]
237
+ return options[0]
238
+
239
+
240
+ # ---------------------------------------------------------------------------
241
+ # LLM backends
242
+ # ---------------------------------------------------------------------------
243
+
244
+
245
+ class LLMClient:
246
+ def __init__(self, model: str) -> None:
247
+ self.model = model
248
+ self.gemini_client: Any = None
249
+ self.azure_openai: Any = None
250
+ self.azure_inf: Any = None
251
+ if model == MODEL_GEMINI:
252
+ from google import genai
253
+
254
+ key = os.environ.get("GEMINI_API_KEY", "")
255
+ if not key:
256
+ raise ValueError("GEMINI_API_KEY missing for Gemini collection.")
257
+ self.gemini_client = genai.Client(api_key=key)
258
+ elif model == MODEL_GPT:
259
+ if not all(
260
+ os.environ.get(x)
261
+ for x in (
262
+ "AZURE_API_KEY",
263
+ "AZURE_OPENAI_ENDPOINT",
264
+ "AZURE_API_VERSION",
265
+ )
266
+ ):
267
+ raise ValueError("AZURE_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_API_VERSION required for gpt-5.4-pro.")
268
+ self.azure_openai = AzureOpenAI(
269
+ api_key=os.environ["AZURE_API_KEY"],
270
+ azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
271
+ api_version=os.environ["AZURE_API_VERSION"],
272
+ )
273
+ else:
274
+ if not all(os.environ.get(x) for x in ("AZURE_API_KEY", "AZURE_AI_INFERENCE_ENDPOINT")):
275
+ raise ValueError("AZURE_API_KEY and AZURE_AI_INFERENCE_ENDPOINT required for inference models.")
276
+ from azure.ai.inference import ChatCompletionsClient
277
+ from azure.core.credentials import AzureKeyCredential
278
+
279
+ self.azure_inf = ChatCompletionsClient(
280
+ endpoint=os.environ["AZURE_AI_INFERENCE_ENDPOINT"],
281
+ credential=AzureKeyCredential(os.environ["AZURE_API_KEY"]),
282
+ )
283
+
284
+ def _deployment_name(self) -> str:
285
+ m = {MODEL_GPT: "AZURE_MODEL_GPT", MODEL_GROK: "AZURE_MODEL_GROK", MODEL_KIMI: "AZURE_MODEL_KIMI", MODEL_DEEPSEEK: "AZURE_MODEL_DEEPSEEK"}.get(self.model)
286
+ if m:
287
+ v = os.environ.get(m, "").strip()
288
+ if v:
289
+ return v
290
+ return self.model
291
+
292
+ def call(
293
+ self,
294
+ messages: List[Dict[str, str]],
295
+ ) -> Tuple[str, int, int]:
296
+ """Return (raw_text, prompt_tokens, completion_tokens)."""
297
+ p_tok, c_tok = 0, 0
298
+ if self.gemini_client is not None:
299
+ return self._call_gemini(messages, p_tok, c_tok)
300
+ if self.azure_openai is not None:
301
+ return self._call_azure_openai(messages, p_tok, c_tok)
302
+ if self.azure_inf is not None:
303
+ return self._call_azure_inference(messages, p_tok, c_tok)
304
+ raise RuntimeError("No backend initialised")
305
+
306
+ def _call_gemini(
307
+ self, messages: List[Dict[str, str]], p0: int, c0: int
308
+ ) -> Tuple[str, int, int]:
309
+ from google.genai import types
310
+
311
+ if not messages:
312
+ return '{"action_type": "noop", "params": {}}', 0, 0
313
+ system = messages[0]["content"] if messages[0]["role"] == "system" else _BASE_SYSTEM
314
+ rest = messages[1:] if messages[0]["role"] == "system" else messages
315
+ name = os.environ.get("GEMINI_MODEL_PRO", MODEL_GEMINI)
316
+ config = types.GenerateContentConfig(
317
+ system_instruction=system,
318
+ response_mime_type="application/json",
319
+ response_json_schema=AgentActionOut,
320
+ temperature=0.0,
321
+ max_output_tokens=512,
322
+ )
323
+ # Build contents: alternating user / model for few-shot tail
324
+ contents: List[Any] = []
325
+ for m in rest:
326
+ if m["role"] == "user":
327
+ contents.append(
328
+ types.Content(role="user", parts=[types.Part.from_text(text=m["content"])])
329
+ )
330
+ else:
331
+ contents.append(
332
+ types.Content(
333
+ role="model",
334
+ parts=[types.Part.from_text(text=m["content"])],
335
+ )
336
+ )
337
+ for attempt in range(3):
338
+ try:
339
+ resp = self.gemini_client.models.generate_content(
340
+ model=name, contents=contents, config=config
341
+ )
342
+ text = (resp.text or "").strip() if hasattr(resp, "text") else ""
343
+ u = getattr(resp, "usage_metadata", None) or getattr(resp, "usage", None)
344
+ pt = int(getattr(u, "prompt_token_count", None) or getattr(u, "prompt_tokens", 0) or 0) if u else 0
345
+ ct = int(getattr(u, "candidates_token_count", None) or getattr(u, "completion_tokens", 0) or 0) if u else 0
346
+ if not text and hasattr(resp, "candidates") and resp.candidates:
347
+ p0x = resp.candidates[0].content.parts[0] if resp.candidates[0].content.parts else None
348
+ text = getattr(p0x, "text", "") or ""
349
+ return text, pt, ct
350
+ except Exception:
351
+ if attempt < 2:
352
+ time.sleep(1.0 + attempt)
353
+ else:
354
+ return '{"action_type": "noop", "params": {}}', p0, c0
355
+
356
+ def _call_azure_openai(
357
+ self, messages: List[Dict[str, str]], p0: int, c0: int
358
+ ) -> Tuple[str, int, int]:
359
+ dep = self._deployment_name()
360
+ for attempt in range(3):
361
+ try:
362
+ comp = self.azure_openai.chat.completions.create(
363
+ model=dep,
364
+ messages=messages, # type: ignore[arg-type]
365
+ temperature=0.0,
366
+ max_tokens=512,
367
+ timeout=90.0,
368
+ )
369
+ text = (comp.choices[0].message.content or "").strip()
370
+ u = comp.usage
371
+ pt = u.prompt_tokens if u else 0
372
+ ct = u.completion_tokens if u else 0
373
+ return text, pt, ct
374
+ except Exception as e:
375
+ err = str(e)
376
+ if _is_not_found(err):
377
+ names = list_azure_openai_deployments()
378
+ if names:
379
+ dep = pick_closest(dep, names)
380
+ if attempt == 2:
381
+ return '{"action_type": "noop", "params": {}}', p0, c0
382
+ time.sleep(1.0 + attempt)
383
+ return '{"action_type": "noop", "params": {}}', p0, c0
384
+
385
+ def _call_azure_inference(
386
+ self, messages: List[Dict[str, str]], p0: int, c0: int
387
+ ) -> Tuple[str, int, int]:
388
+ dep = self._deployment_name()
389
+ for attempt in range(3):
390
+ try:
391
+ resp = self.azure_inf.complete(
392
+ model=dep,
393
+ messages=messages, # type: ignore[arg-type]
394
+ temperature=0.0,
395
+ max_tokens=512,
396
+ )
397
+ ch = resp.choices[0].message
398
+ text = (ch.content or "").strip() if ch else ""
399
+ u = getattr(resp, "usage", None)
400
+ pt = int(getattr(u, "prompt_tokens", 0) or 0) if u else 0
401
+ ct = int(getattr(u, "completion_tokens", 0) or 0) if u else 0
402
+ return text, pt, ct
403
+ except Exception as e:
404
+ err = str(e)
405
+ if _is_not_found(err) or "404" in err or "not found" in err.lower():
406
+ names = [n for n in list_foundry_deployments() + list_azure_openai_deployments() if n]
407
+ if names:
408
+ dep = pick_closest(dep, names)
409
+ if attempt == 2:
410
+ return '{"action_type": "noop", "params": {}}', p0, c0
411
+ time.sleep(1.0 + attempt)
412
+ return '{"action_type": "noop", "params": {}}', p0, c0
413
+
414
+
415
+ # ---------------------------------------------------------------------------
416
+ # Episode (mirrors inference.run_episode; logs full trace)
417
+ # ---------------------------------------------------------------------------
418
+
419
+
420
+ def _memory_block(tried_actions: Dict[str, List[str]], resolved_services: List[str]) -> str:
421
+ if not tried_actions and not resolved_services:
422
+ return ""
423
+ lines = ["## Episode Memory (do not repeat failed approaches)"]
424
+ if resolved_services:
425
+ lines.append(f" Resolved: {', '.join(resolved_services)}")
426
+ for act, targets in tried_actions.items():
427
+ lines.append(f" {act}: {'; '.join(targets)}")
428
+ return "\n".join(lines)
429
+
430
+
431
+ def run_one_episode(
432
+ llm: LLMClient,
433
+ model_id: str,
434
+ base: str,
435
+ task_id: str,
436
+ seed: int,
437
+ cost: CostTracker,
438
+ ) -> Dict[str, Any]:
439
+ grade: Dict[str, Any] = {}
440
+ with httpx.Client(timeout=60.0) as http:
441
+ r = http.post(
442
+ f"{base}/reset", json={"seed": seed, "task_id": task_id}
443
+ )
444
+ r.raise_for_status()
445
+ resp_data = r.json()
446
+ obs: Dict[str, Any] = dict(resp_data.get("observation", resp_data))
447
+ max_steps = int(obs.get("max_steps", 10))
448
+ done = bool(resp_data.get("done", False))
449
+ conv: List[Dict[str, Any]] = []
450
+ tried: Dict[str, List[str]] = {}
451
+ resolved: List[str] = []
452
+ steps_out: List[Dict[str, Any]] = []
453
+ for step_num in range(1, max_steps + 1):
454
+ if done:
455
+ break
456
+ obs_pre = copy.deepcopy(obs)
457
+ user_msg = build_observation_prompt(obs_pre)
458
+ conv.append({"role": "user", "content": user_msg})
459
+ trimmed = conv[-6:]
460
+ memory = _memory_block(tried, resolved)
461
+ system_content = _BASE_SYSTEM + ("\n\n" + memory if memory else "")
462
+ messages: List[Dict[str, str]] = (
463
+ [{"role": "system", "content": system_content}] + trimmed
464
+ )
465
+ raw, pt, ct = llm.call(messages)
466
+ cost.add(
467
+ model_id, _estimate_openai_style_cost(model_id, pt, ct)
468
+ )
469
+ try:
470
+ action = parse_action(raw)
471
+ except Exception:
472
+ action = {"action_type": "noop", "params": {}}
473
+ if isinstance(action, dict) and "action_type" in action and model_id == MODEL_GEMINI:
474
+ try:
475
+ a2 = (
476
+ json.loads(raw[raw.find("{") : raw.rfind("}") + 1])
477
+ if "{" in raw
478
+ else None
479
+ )
480
+ if a2 and isinstance(a2, dict) and "action_type" in a2:
481
+ action = a2
482
+ except Exception:
483
+ pass
484
+ act_params = action.get("params", {}) or {}
485
+ if "replicas" in act_params:
486
+ try:
487
+ act_params["replicas"] = int(act_params["replicas"])
488
+ except (ValueError, TypeError):
489
+ act_params["replicas"] = 2
490
+ act_type = action.get("action_type", "noop")
491
+ target = act_params.get("service_id") or act_params.get("cache_name") or act_params.get("from_region") or ""
492
+ step_resp = http.post(
493
+ f"{base}/step",
494
+ json={"action": {"action_type": act_type, "params": act_params}},
495
+ )
496
+ sdata = step_resp.json() if step_resp.status_code == 200 else {}
497
+ obs = dict(sdata.get("observation", sdata))
498
+ done = bool(sdata.get("done", False))
499
+ reward = float(
500
+ obs.get("reward", sdata.get("reward", 0.0)) or 0.0
501
+ )
502
+ conv.append({"role": "assistant", "content": raw})
503
+ if act_type not in (
504
+ "inspect_logs",
505
+ "inspect_metrics",
506
+ "inspect_traces",
507
+ "noop",
508
+ ) and target:
509
+ new_slo = obs.get("global_slo_score", 0.0)
510
+ for svc in obs.get("services", []):
511
+ if svc.get("id") == target and svc.get("status") == "healthy":
512
+ if target not in resolved:
513
+ resolved.append(target)
514
+ entry = f"{target} (slo={new_slo:.0%})"
515
+ tried.setdefault(str(act_type), [])
516
+ if entry not in tried[str(act_type)]:
517
+ tried[str(act_type)].append(entry)
518
+ obs_ser = json.loads(
519
+ json.dumps(
520
+ {k: v for k, v in obs_pre.items() if k != "reward"},
521
+ default=str,
522
+ )
523
+ )
524
+ steps_out.append(
525
+ {
526
+ "step": step_num,
527
+ "observation": obs_ser,
528
+ "prompt": user_msg,
529
+ "messages": messages,
530
+ "completion": raw,
531
+ "action": action,
532
+ "reward": reward,
533
+ "info": {k: v for k, v in sdata.items() if k not in ("observation",)},
534
+ }
535
+ )
536
+ try:
537
+ final_state = http.get(f"{base}/state").json()
538
+ except Exception:
539
+ final_state = {}
540
+ try:
541
+ grade = http.post(
542
+ f"{base}/grader",
543
+ json={
544
+ "final_slo_score": final_state.get("global_slo_score", 0.0),
545
+ "steps_taken": final_state.get("step_count", 0),
546
+ "max_steps": max_steps,
547
+ "actions_taken": obs.get("actions_taken", []),
548
+ "terminated": final_state.get("terminated", True),
549
+ "termination_reason": final_state.get("termination_reason"),
550
+ },
551
+ ).json()
552
+ except Exception:
553
+ grade = {}
554
+ score = float(grade.get("score", 0.0) or 0.0)
555
+ return {
556
+ "model": model_id,
557
+ "task_id": task_id,
558
+ "seed": seed,
559
+ "steps": steps_out,
560
+ "grader": grade,
561
+ "final_score": score,
562
+ "max_steps": max_steps,
563
+ }
564
+
565
+
566
+ # ---------------------------------------------------------------------------
567
+ # Main
568
+ # ---------------------------------------------------------------------------
569
+
570
+
571
+ def _raw_path(model: str) -> Path:
572
+ safe = re.sub(r"[^a-zA-Z0-9._-]+", "_", model)
573
+ d = REPO_ROOT / "training" / "data" / "raw"
574
+ d.mkdir(parents=True, exist_ok=True)
575
+ return d / f"{safe}.jsonl"
576
+
577
+
578
+ def _wait_health(base: str, timeout: float = 45.0) -> None:
579
+ t0 = time.time()
580
+ while time.time() - t0 < timeout:
581
+ try:
582
+ r = httpx.get(f"{base}/health", timeout=3.0)
583
+ if r.status_code == 200:
584
+ return
585
+ except Exception:
586
+ pass
587
+ time.sleep(1.0)
588
+ print(f"[collect] health check timeout for {base} — continuing", flush=True)
589
+
590
+
591
+ def start_server(port: int) -> subprocess.Popen:
592
+ env = os.environ.copy()
593
+ pp = str(REPO_ROOT)
594
+ env["PYTHONPATH"] = pp if not env.get("PYTHONPATH") else pp + os.pathsep + env["PYTHONPATH"]
595
+ return subprocess.Popen(
596
+ [sys.executable, "-m", "uvicorn", "server.app:app", "--host", "127.0.0.1", "--port", str(port)],
597
+ cwd=REPO_ROOT,
598
+ env=env,
599
+ stdout=subprocess.DEVNULL,
600
+ stderr=subprocess.STDOUT,
601
+ )
602
+
603
+
604
+ def parse_models(s: str) -> List[str]:
605
+ return [m.strip() for m in s.split(",") if m.strip()]
606
+
607
+
608
+ def _plan_for_model(
609
+ model: str, c_e: int, c_m: int, c_h: int
610
+ ) -> List[Tuple[str, str, int]]:
611
+ p = {
612
+ MODEL_GEMINI: plan_gemini,
613
+ MODEL_GPT: plan_gpt,
614
+ MODEL_GROK: plan_grok,
615
+ MODEL_KIMI: plan_kimi,
616
+ MODEL_DEEPSEEK: plan_deepseek,
617
+ }
618
+ fn = p.get(model)
619
+ if not fn:
620
+ return []
621
+ return fn(c_e, c_m, c_h)
622
+
623
+
624
+ def sanity_runs() -> List[Tuple[str, str, int]]:
625
+ return [
626
+ (MODEL_GEMINI, "easy", 42),
627
+ (MODEL_GPT, "easy", 42),
628
+ (MODEL_GROK, "easy", 13),
629
+ ]
630
+
631
+
632
+ def main() -> None:
633
+ ap = argparse.ArgumentParser()
634
+ ap.add_argument(
635
+ "--models",
636
+ type=str,
637
+ default=",".join(sorted(ALL_CANON)),
638
+ help="Comma-separated model ids (default: all)",
639
+ )
640
+ ap.add_argument("--port", type=int, default=7860)
641
+ ap.add_argument("--no-start-server", action="store_true")
642
+ ap.add_argument("--sanity-only", action="store_true", help="Run only 3 smoke episodes (gemini, gpt, grok easy).")
643
+ ap.add_argument("--no-sanity", action="store_true", help="Skip pre-flight sanity runs.")
644
+ ap.add_argument(
645
+ "--budget-usd",
646
+ type=float,
647
+ default=5.0,
648
+ help="Total estimated-spend cap (heuristic) across all models.",
649
+ )
650
+ ap.add_argument(
651
+ "--per-model-budget-usd",
652
+ type=float,
653
+ default=0.0,
654
+ help="Per-model cap (0 = auto: max(2, budget/num selected models)).",
655
+ )
656
+ ap.add_argument(
657
+ "--episodes-easy",
658
+ type=int,
659
+ default=15,
660
+ help="Number of easy-task episodes per model (default 15, Wave 1.5).",
661
+ )
662
+ ap.add_argument(
663
+ "--episodes-medium",
664
+ type=int,
665
+ default=15,
666
+ help="Number of medium-task episodes per model (default 15).",
667
+ )
668
+ ap.add_argument(
669
+ "--episodes-hard",
670
+ type=int,
671
+ default=20,
672
+ help="Number of hard-task episodes per model (default 20).",
673
+ )
674
+ args = ap.parse_args()
675
+ want = set(parse_models(args.models))
676
+ bad = want - ALL_CANON
677
+ if bad:
678
+ raise SystemExit(f"Unknown model(s): {bad}. Valid: {sorted(ALL_CANON)}")
679
+
680
+ c_e, c_m, c_h = args.episodes_easy, args.episodes_medium, args.episodes_hard
681
+ if min(c_e, c_m, c_h) < 0:
682
+ raise SystemExit("--episodes-* must be non-negative.")
683
+ if c_e + c_m + c_h == 0:
684
+ raise SystemExit("At least one of --episodes-easy/medium/hard must be > 0.")
685
+
686
+ _ = full_plan(c_e, c_m, c_h) # exercise planner (raises if misconfigured)
687
+
688
+ # Required keys
689
+ for m in want:
690
+ if m == MODEL_GEMINI and not os.environ.get("GEMINI_API_KEY"):
691
+ raise SystemExit("GEMINI_API_KEY missing (needed for gemini-3.1-pro-preview).")
692
+ if m == MODEL_GPT and not all(
693
+ os.environ.get(x) for x in ("AZURE_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_API_VERSION")
694
+ ):
695
+ raise SystemExit("Azure OpenAI env vars missing for gpt-5.4-pro.")
696
+ if m in (MODEL_GROK, MODEL_KIMI, MODEL_DEEPSEEK) and not all(
697
+ os.environ.get(x) for x in ("AZURE_API_KEY", "AZURE_AI_INFERENCE_ENDPOINT")
698
+ ):
699
+ raise SystemExit("Azure inference env missing for " + m)
700
+
701
+ proc: Optional[subprocess.Popen] = None
702
+ if not args.no_start_server:
703
+ proc = start_server(args.port)
704
+ base = f"http://127.0.0.1:{args.port}"
705
+ _wait_health(base)
706
+ n_m = max(1, len(want))
707
+ per_cap = args.per_model_budget_usd
708
+ if per_cap <= 0.0:
709
+ per_cap = max(2.0, args.budget_usd / n_m)
710
+ cost = CostTracker(budget=args.budget_usd, per_model_max=per_cap)
711
+ # LLM clients (lazy)
712
+ _clients: Dict[str, LLMClient] = {}
713
+ def get_llm(mid: str) -> LLMClient:
714
+ if mid not in _clients:
715
+ _clients[mid] = LLMClient(mid)
716
+ return _clients[mid]
717
+
718
+ try:
719
+ already: Set[Tuple[str, str, int]] = set()
720
+ if args.sanity_only:
721
+ final_list = [r for r in sanity_runs() if r[0] in want]
722
+ else:
723
+ if not args.no_sanity:
724
+ for mid, task_id, seed in (r for r in sanity_runs() if r[0] in want):
725
+ print(f"[sanity] {mid} {task_id} seed={seed}", flush=True)
726
+ llm = get_llm(mid)
727
+ _ = run_one_episode(llm, mid, base, task_id, seed, cost)
728
+ already.add((mid, task_id, seed))
729
+ print("[sanity] pre-flight ok", flush=True)
730
+ final_list = []
731
+ for m in want:
732
+ for x in _plan_for_model(m, c_e, c_m, c_h):
733
+ if x in already:
734
+ continue
735
+ final_list.append(x)
736
+ n_done = 0
737
+ for mid, task_id, seed in final_list:
738
+ print(f"[episode] {mid} {task_id} seed={seed}", flush=True)
739
+ try:
740
+ llm = get_llm(mid)
741
+ ep = run_one_episode(llm, mid, base, task_id, seed, cost)
742
+ except RuntimeError as e:
743
+ print(f"[collect] Stopped: {e}", flush=True)
744
+ break
745
+ p = _raw_path(mid)
746
+ with p.open("a", encoding="utf-8") as f:
747
+ f.write(json.dumps(ep, ensure_ascii=False) + "\n")
748
+ n_done += 1
749
+ print(
750
+ f" -> score={ep.get('final_score', 0):.4f} lines->{p.name} (total est ${cost.usd:.2f})",
751
+ flush=True,
752
+ )
753
+ print(f"Done. Episodes written: {n_done}. Estimated spend: ${cost.usd:.2f}", flush=True)
754
+ finally:
755
+ if proc is not None:
756
+ proc.terminate()
757
+ try:
758
+ proc.wait(timeout=5)
759
+ except Exception:
760
+ proc.kill()
761
+
762
+
763
+ if __name__ == "__main__":
764
+ main()
training/config_utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load dotenv from repo api.env + hg.env (optional). Does not read secrets into logs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from pathlib import Path
7
+
8
+ _REPO_ROOT = Path(__file__).resolve().parent.parent
9
+
10
+
11
+ def try_load_env_files() -> None:
12
+ for name in ("api.env", "hg.env"):
13
+ p = _REPO_ROOT / name
14
+ if not p.is_file():
15
+ continue
16
+ try:
17
+ from dotenv import load_dotenv
18
+
19
+ load_dotenv(p, override=False)
20
+ except ImportError:
21
+ _manual_load(p)
22
+
23
+
24
+ def _manual_load(path: Path) -> None:
25
+ for line in path.read_text(encoding="utf-8", errors="ignore").splitlines():
26
+ line = line.strip()
27
+ if not line or line.startswith("#") or "=" not in line:
28
+ continue
29
+ k, v = line.split("=", 1)
30
+ k, v = k.strip(), v.strip().strip('"').strip("'")
31
+ if k and k not in os.environ:
32
+ os.environ[k] = v
training/data/DATASET_README_HF.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SevZero expert trajectories (SFT)
2
+
3
+ ## Sources
4
+
5
+ - Synthetic expert rollouts from frontier models (Gemini 3.1 Pro, Azure OpenAI, Azure AI Inference)
6
+ against the local OpenEnv `server.app` SevZero environment.
7
+
8
+ ## Filtering
9
+
10
+ - Episodes with final grader `score` **≥** `0.75` are included.
11
+
12
+ ## Schema
13
+
14
+ - Each example has a `messages` list (Llama-3.1-8B-Instruct–style SFT) and `meta` (episode / step provenance):
15
+ - `system`: SRE on-call system prompt (same as `inference.SYSTEM_PROMPT` in the repo)
16
+ - `user`: JSON-serialized observation (shrink to ≤ 2048 tokens for the user part)
17
+ - `assistant`: one JSON object `{"action_type": "...", "params": {...}}`
18
+
19
+ ## Stats (from `build_stats.json` at publish time)
20
+
21
+ {
22
+ "episodes_total_seen": 90,
23
+ "episodes_kept": 42,
24
+ "episodes_dropped": 48,
25
+ "mean_episode_score_kept": 0.836021,
26
+ "train_rows": 853,
27
+ "eval_rows": 80,
28
+ "max_prompt_token_length": 2,
29
+ "max_observation_user_token_budget": 2048,
30
+ "min_score_filter": 0.75
31
+ }
32
+
33
+ ## Parquet
34
+
35
+ - Splits `train` and `eval` are also pushed in Parquet for fast `datasets.load_dataset`.
training/data/HANDOFF.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ - **Dataset URL (after `python -m training.push_dataset`):** https://huggingface.co/datasets/Mist-ic/sevzero-expert-trajectories
2
+ - **Rows:** see `build_stats.json` for `train_rows` and `eval_rows` after you run `build_dataset.py` on real raw JSONL.
3
+ - **Max prompt tokens:** see `max_prompt_token_length` in `build_stats.json` — set SFT/GRPO `max_seq_length` to this + `max_completion_length` (e.g. +1024).
4
+ - **Mean episode score:** `mean_episode_score_kept` in `build_stats.json` (episodes with final grader ≥ 0.85).
5
+ - **Caveats:** run `collect_trajectories.py` with working `api.env`/`hg.env`; use `--no-sanity` to skip the 3 pre-flight API calls; install extras (`python-dotenv`, `google-genai`, `azure-ai-inference`, `huggingface_hub`, `datasets`, `transformers`, `pydantic`) as needed — `pyproject.toml` is unchanged.
training/data/build_stats.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "episodes_total_seen": 90,
3
+ "episodes_kept": 42,
4
+ "episodes_dropped": 48,
5
+ "mean_episode_score_kept": 0.836021,
6
+ "train_rows": 853,
7
+ "eval_rows": 80,
8
+ "max_prompt_token_length": 2,
9
+ "max_observation_user_token_budget": 2048,
10
+ "min_score_filter": 0.75
11
+ }
training/data/dataset_info.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "description": "SevZero SFT expert trajectories for Llama-3.1-8B-Instruct style chat training.",
3
+ "version": "1.0.0",
4
+ "license": "apache-2.0",
5
+ "build": {
6
+ "episodes_total_seen": 90,
7
+ "episodes_kept": 42,
8
+ "episodes_dropped": 48,
9
+ "mean_episode_score_kept": 0.836021,
10
+ "train_rows": 853,
11
+ "eval_rows": 80,
12
+ "max_prompt_token_length": 2,
13
+ "max_observation_user_token_budget": 2048,
14
+ "min_score_filter": 0.75
15
+ }
16
+ }
training/data/sft_eval.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
training/data/sft_train.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
training/env_client.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Async HTTP client for the SevZero OpenEnv server (stateful /reset, /step, /state, /grader).
3
+ Used by train_grpo rollout_func. Does not use root client.py (WebSocket); mirrors inference.py HTTP usage.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import asyncio
9
+ import os
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import httpx
13
+
14
+ _DEFAULT_TIMEOUT = 120.0
15
+ _MAX_RETRIES = 5
16
+ _BACKOFF = 1.6
17
+
18
+
19
+ def _space_id_to_runtime_url(space_id: str) -> str:
20
+ """HF Space 'org/name' -> https://org-name.hf.space (common runtime URL)."""
21
+ space_id = space_id.strip()
22
+ if space_id.startswith("http"):
23
+ return space_id.rstrip("/")
24
+ parts = space_id.split("/")
25
+ if len(parts) == 2:
26
+ org, name = parts[0], parts[1]
27
+ # HF uses lowercase, slashes -> dashes in subdomains
28
+ sub = f"{org}-{name}".replace("_", "-").lower()
29
+ return f"https://{sub}.hf.space"
30
+ raise ValueError(f"Invalid space_id (expected 'org/name' or URL): {space_id!r}")
31
+
32
+
33
+ def _backoff_delay(attempt: int) -> float:
34
+ return min(30.0, _BACKOFF**attempt)
35
+
36
+
37
+ def _is_transient_status(code: int) -> bool:
38
+ return code in (429, 500, 502, 503, 504)
39
+
40
+
41
+ class AsyncSevZeroEnvClient:
42
+ """
43
+ Minimal async env client: reset / step / state / grader.
44
+ Pass base_url from SEVZERO_ENV_URL or from_hf_space().
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ base_url: str,
50
+ *,
51
+ token: Optional[str] = None,
52
+ timeout: float = _DEFAULT_TIMEOUT,
53
+ ) -> None:
54
+ self._base = base_url.rstrip("/")
55
+ self._token = token
56
+ headers: Dict[str, str] = {"Content-Type": "application/json"}
57
+ if token:
58
+ headers["Authorization"] = f"Bearer {token}"
59
+ self._client = httpx.AsyncClient(
60
+ base_url=self._base,
61
+ headers=headers,
62
+ timeout=timeout,
63
+ )
64
+
65
+ @classmethod
66
+ def from_hf_space(
67
+ cls,
68
+ space_id: str,
69
+ token: Optional[str] = None,
70
+ ) -> "AsyncSevZeroEnvClient":
71
+ """
72
+ space_id: 'organization/space_name' (HF Space) or a full http(s) URL.
73
+ For private Spaces, pass a read token with Space access.
74
+ """
75
+ return cls(_space_id_to_runtime_url(space_id), token=token or os.environ.get("HF_TOKEN"))
76
+
77
+ async def aclose(self) -> None:
78
+ await self._client.aclose()
79
+
80
+ async def _request(
81
+ self,
82
+ method: str,
83
+ path: str,
84
+ *,
85
+ json: Any = None,
86
+ ) -> httpx.Response:
87
+ last_err: Optional[Exception] = None
88
+ for attempt in range(_MAX_RETRIES):
89
+ try:
90
+ r = await self._client.request(method, path, json=json)
91
+ if r.status_code < 400:
92
+ return r
93
+ if _is_transient_status(r.status_code) and attempt < _MAX_RETRIES - 1:
94
+ await asyncio.sleep(_backoff_delay(attempt + 1))
95
+ continue
96
+ return r
97
+ except (httpx.TimeoutException, httpx.NetworkError) as e:
98
+ last_err = e
99
+ if attempt < _MAX_RETRIES - 1:
100
+ await asyncio.sleep(_backoff_delay(attempt + 1))
101
+ continue
102
+ raise
103
+ if last_err:
104
+ raise last_err
105
+ raise RuntimeError("request failed")
106
+
107
+ async def reset(
108
+ self,
109
+ *,
110
+ task_id: str = "hard",
111
+ seed: int = 13,
112
+ episode_id: Optional[str] = None,
113
+ ) -> Dict[str, Any]:
114
+ body: Dict[str, Any] = {"task_id": task_id, "seed": seed}
115
+ if episode_id:
116
+ body["episode_id"] = episode_id
117
+ r = await self._request("POST", "/reset", json=body)
118
+ r.raise_for_status()
119
+ return r.json()
120
+
121
+ async def step(self, action: Dict[str, Any]) -> Dict[str, Any]:
122
+ r = await self._request("POST", "/step", json={"action": action})
123
+ r.raise_for_status()
124
+ return r.json()
125
+
126
+ async def get_state(self) -> Dict[str, Any]:
127
+ r = await self._request("GET", "/state")
128
+ r.raise_for_status()
129
+ return r.json()
130
+
131
+ async def grade_episode(
132
+ self,
133
+ *,
134
+ final_slo_score: float,
135
+ steps_taken: int,
136
+ max_steps: int,
137
+ actions_taken: List[Dict[str, Any]],
138
+ terminated: bool,
139
+ termination_reason: Optional[str],
140
+ ) -> Dict[str, Any]:
141
+ r = await self._request(
142
+ "POST",
143
+ "/grader",
144
+ json={
145
+ "final_slo_score": final_slo_score,
146
+ "steps_taken": steps_taken,
147
+ "max_steps": max_steps,
148
+ "actions_taken": actions_taken,
149
+ "terminated": terminated,
150
+ "termination_reason": termination_reason,
151
+ },
152
+ )
153
+ r.raise_for_status()
154
+ return r.json()
155
+
156
+
157
+ def run_async(coro):
158
+ """Run async coroutine from sync context (rollout_func)."""
159
+ return asyncio.run(coro)
training/eval.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Eval: local HF adapters + Gemini (google-genai) + Azure OpenAI + Azure AI Inference.
4
+ Writes eval_results.csv; pushes Mist-ic/sevzero-eval-results with HF_MAIN_TOKEN. No Claude.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import csv
11
+ import os
12
+ import sys
13
+ from pathlib import Path
14
+ from typing import Any, Callable, Dict, List
15
+
16
+ _REPO = Path(__file__).resolve().parent.parent
17
+ if str(_REPO) not in sys.path:
18
+ sys.path.insert(0, str(_REPO))
19
+
20
+ from training.config_utils import try_load_env_files
21
+ from training.rollout_sevzero import SRE_SYSTEM_PROMPT, build_observation_prompt, parse_action
22
+
23
+ try_load_env_files()
24
+
25
+ HELD_OUT = (13, 99, 777)
26
+ DEFAULT_TASKS = ("easy", "medium", "hard")
27
+ DATASET_HUB = "Mist-ic/sevzero-eval-results"
28
+
29
+ BUILTIN: Dict[str, str] = {
30
+ "untrained-llama": "base:meta-llama/Llama-3.1-8B-Instruct",
31
+ "sft-primary": os.getenv("SFT_ADAPTER_PRIMARY", "PhaseOfCode/sevzero-llama3-8b-sft"),
32
+ "sft-backup": os.getenv("SFT_ADAPTER_BACKUP", "NoahInOblivion/sevzero-llama3-8b-sft"),
33
+ "sft-innovation": os.getenv("SFT_ADAPTER_INNOVATION", "NoxIsOblivion/sevzero-llama3-8b-sft"),
34
+ "grpo-primary": os.getenv("GRPO_ADAPTER_PRIMARY", "PhaseOfCode/sevzero-llama3-8b-grpo-primary"),
35
+ "grpo-stability": os.getenv("GRPO_ADAPTER_STABILITY", "NoahInOblivion/sevzero-llama3-8b-grpo-stability"),
36
+ "grpo-innovation": os.getenv("GRPO_ADAPTER_INNOVATION", "NoxIsOblivion/sevzero-llama3-8b-grpo-innovation"),
37
+ }
38
+
39
+ AZURE_INF = {
40
+ "grok-4.20-reasoning": "grok-2-latest",
41
+ "kimi-k2.6": "kimi-k2-6-2025",
42
+ "DeepSeek-V3.2": "DeepSeek-V3-2",
43
+ }
44
+
45
+
46
+ def run_episode(
47
+ base: str, task: str, seed: int, answer: Callable[[str, str], str]
48
+ ) -> Dict[str, Any]:
49
+ import httpx
50
+
51
+ with httpx.Client(base_url=base.rstrip("/"), timeout=120.0) as client:
52
+ r = client.post("/reset", json={"task_id": task, "seed": seed})
53
+ r.raise_for_status()
54
+ ro = r.json()
55
+ obs = ro.get("observation", ro)
56
+ done = ro.get("done", False)
57
+ user_pfx = f"You are the on-call SRE. task={task!r} seed={seed}.\n\n## Session\n"
58
+ for _ in range(1 + int(obs.get("max_steps", 20))):
59
+ if done:
60
+ break
61
+ user_block = user_pfx + build_observation_prompt(obs)
62
+ text = answer(SRE_SYSTEM_PROMPT, user_block)
63
+ act = parse_action(text)
64
+ sr = client.post(
65
+ "/step",
66
+ json={"action": {"action_type": str(act.get("action_type", "noop")), "params": act.get("params") or {}}},
67
+ )
68
+ sr.raise_for_status()
69
+ out = sr.json()
70
+ obs = out.get("observation", out)
71
+ done = out.get("done", False)
72
+ stt = client.get("/state")
73
+ stt.raise_for_status()
74
+ fs = stt.json()
75
+ g = client.post(
76
+ "/grader",
77
+ json={
78
+ "final_slo_score": float(fs.get("global_slo_score", 0.0)),
79
+ "steps_taken": int(fs.get("step_count", 0)),
80
+ "max_steps": int((obs or {}).get("max_steps", 10)),
81
+ "actions_taken": list((obs or {}).get("actions_taken", [])),
82
+ "terminated": bool(fs.get("terminated", True)),
83
+ "termination_reason": fs.get("termination_reason"),
84
+ },
85
+ )
86
+ js: Dict[str, Any] = {}
87
+ if g.status_code < 400:
88
+ js = g.json()
89
+ return {
90
+ "score": float(js.get("score", 0.0)),
91
+ "slo_recovery": float(js.get("slo_recovery", 0.0)),
92
+ "action_efficiency": float(js.get("action_efficiency", 0.0)),
93
+ "time_efficiency": float(js.get("time_efficiency", 0.0)),
94
+ "steps_used": int(fs.get("step_count", 0)),
95
+ "terminated": fs.get("terminated", True),
96
+ "termination_reason": str(fs.get("termination_reason", "")),
97
+ }
98
+
99
+
100
+ def load_llama_peft(adapter_id: str | None):
101
+ import torch
102
+ from peft import PeftModel
103
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
104
+
105
+ base_id = "meta-llama/Llama-3.1-8B-Instruct"
106
+ tok = AutoTokenizer.from_pretrained(base_id, use_fast=True, token=os.environ.get("HF_TOKEN"))
107
+ if tok.pad_token is None:
108
+ tok.pad_token = tok.eos_token
109
+ bnb = BitsAndBytesConfig(
110
+ load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
111
+ )
112
+ m = AutoModelForCausalLM.from_pretrained(
113
+ base_id, quantization_config=bnb, device_map="auto", torch_dtype=torch.bfloat16, token=os.environ.get("HF_TOKEN")
114
+ )
115
+ if adapter_id:
116
+ m = PeftModel.from_pretrained(m, adapter_id, token=os.environ.get("HF_TOKEN"))
117
+ m.eval()
118
+ return tok, m
119
+
120
+
121
+ def hf_answer(tok, mdl):
122
+ import torch
123
+
124
+ def answer(system: str, user: str) -> str:
125
+ messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
126
+ p = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
127
+ inputs = tok(p, return_tensors="pt").to(mdl.device)
128
+ with torch.no_grad():
129
+ o = mdl.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.0)
130
+ gen = o[0, inputs["input_ids"].shape[1] :]
131
+ return tok.decode(gen, skip_special_tokens=True)
132
+
133
+ return answer
134
+
135
+
136
+ def answer_gemini(system: str, user: str) -> str:
137
+ from google import genai
138
+
139
+ model = os.environ.get(
140
+ "GEMINI_EVAL_MODEL",
141
+ os.environ.get("GEMINI_MODEL_PRO", "gemini-3.1-pro-preview"),
142
+ )
143
+ c = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
144
+ r = c.models.generate_content(model=model, contents=f"{system}\n\n{user}")
145
+ return (r.text or "").strip()
146
+
147
+
148
+ def answer_azure_openai(system: str, user: str) -> str:
149
+ from openai import OpenAI
150
+
151
+ ep = os.environ.get("AZURE_OPENAI_ENDPOINT", "").rstrip("/")
152
+ c = OpenAI(
153
+ api_key=os.environ.get("AZURE_API_KEY", ""),
154
+ base_url=ep + "/openai/v1",
155
+ )
156
+ dep = os.environ.get("AZURE_GPT_DEPLOYMENT", "gpt-5.4-pro")
157
+ r = c.chat.completions.create(
158
+ model=dep,
159
+ messages=[{"role": "system", "content": system}, {"role": "user", "content": user}],
160
+ temperature=0.0,
161
+ max_tokens=512,
162
+ )
163
+ return (r.choices[0].message.content or "").strip()
164
+
165
+
166
+ def answer_azure_inference(model_name: str, system: str, user: str) -> str:
167
+ from azure.ai.inference import ChatCompletionsClient
168
+ from azure.core.credentials import AzureKeyCredential
169
+
170
+ ep = os.environ.get("AZURE_AI_INFERENCE_ENDPOINT", "").rstrip("/") + "/"
171
+ c = ChatCompletionsClient(endpoint=ep, credential=AzureKeyCredential(os.environ.get("AZURE_API_KEY", "")))
172
+ r = c.complete(
173
+ model_name=model_name,
174
+ messages=[{"role": "user", "content": f"{system}\n\n{user}"}],
175
+ )
176
+ return (r.choices[0].message.content or "").strip()
177
+
178
+
179
+ def pick_answer_fn(name: str) -> Callable[[str, str], str]:
180
+ n = name.strip()
181
+ if n in BUILTIN:
182
+ spec = BUILTIN[n]
183
+ aid = None if spec.startswith("base:") else spec
184
+ tok, m = load_llama_peft(aid)
185
+ return hf_answer(tok, m)
186
+ if "/" in n and n.count("/") == 1 and not n.startswith("meta-llama/"):
187
+ tok, m = load_llama_peft(n)
188
+ return hf_answer(tok, m)
189
+ if n.startswith("gemini"):
190
+ return answer_gemini
191
+ if "gpt" in n.lower() or n == "gpt-5.4-pro":
192
+ return answer_azure_openai
193
+ if n in AZURE_INF:
194
+ mid = AZURE_INF[n]
195
+
196
+ def _fn(s: str, u: str) -> str:
197
+ return answer_azure_inference(mid, s, u)
198
+
199
+ return _fn
200
+ raise ValueError(f"Unknown model key: {name!r}")
201
+
202
+
203
+ def main() -> None:
204
+ ap = argparse.ArgumentParser()
205
+ ap.add_argument("--models", type=str, default="untrained-llama")
206
+ ap.add_argument("--out", type=str, default="eval_results.csv")
207
+ ap.add_argument("--seeds", type=str, default=",".join(str(s) for s in HELD_OUT))
208
+ ap.add_argument("--tasks", type=str, default=",".join(DEFAULT_TASKS))
209
+ a = ap.parse_args()
210
+
211
+ base = (os.environ.get("SEVZERO_ENV_URL") or "").rstrip("/")
212
+ if not base:
213
+ raise SystemExit("SEVZERO_ENV_URL required")
214
+
215
+ models = [m.strip() for m in a.models.split(",") if m.strip()]
216
+ seeds = [int(x) for x in a.seeds.split(",")]
217
+ tasks = [t.strip() for t in a.tasks.split(",")]
218
+
219
+ rows: List[Dict[str, Any]] = []
220
+ for mname in models:
221
+ try:
222
+ answer = pick_answer_fn(mname)
223
+ except ValueError as e:
224
+ print(f"SKIP {mname}: {e}", flush=True)
225
+ continue
226
+ for task in tasks:
227
+ for seed in seeds:
228
+ r = run_episode(base, task, seed, answer)
229
+ rows.append(
230
+ {
231
+ "model": mname,
232
+ "task": task,
233
+ "seed": seed,
234
+ **r,
235
+ }
236
+ )
237
+ print(rows[-1], flush=True)
238
+
239
+ with Path(a.out).open("w", newline="", encoding="utf-8") as f:
240
+ fieldnames = [
241
+ "model",
242
+ "task",
243
+ "seed",
244
+ "score",
245
+ "slo_recovery",
246
+ "action_efficiency",
247
+ "time_efficiency",
248
+ "steps_used",
249
+ "terminated",
250
+ "termination_reason",
251
+ ]
252
+ w = csv.DictWriter(f, fieldnames=fieldnames)
253
+ w.writeheader()
254
+ for r in rows:
255
+ w.writerow(r)
256
+
257
+ tok_m = os.environ.get("HF_MAIN_TOKEN", "")
258
+ if not tok_m:
259
+ print("HF_MAIN_TOKEN not set — skip Hub push", flush=True)
260
+ return
261
+ from datasets import Dataset
262
+
263
+ ds = Dataset.from_list([dict(x) for x in rows])
264
+ ds.push_to_hub(DATASET_HUB, token=tok_m, private=False)
265
+ print(f"OK: pushed hf.co/datasets/{DATASET_HUB}", flush=True)
266
+
267
+
268
+ if __name__ == "__main__":
269
+ main()
training/launch_hf_job.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Submit a HuggingFace Job to run training/train_sft.py or training/train_grpo.py.
4
+ Uses huggingface_hub.run_job; prints job URL; appends training/runs.jsonl.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import json
11
+ import os
12
+ import subprocess
13
+ import sys
14
+ from datetime import datetime, timezone
15
+ from pathlib import Path
16
+
17
+ _REPO = Path(__file__).resolve().parent.parent
18
+ if str(_REPO) not in sys.path:
19
+ sys.path.insert(0, str(_REPO))
20
+
21
+ from training.config_utils import try_load_env_files
22
+
23
+ try_load_env_files()
24
+
25
+
26
+ def _default_git_url() -> str:
27
+ r = subprocess.run(
28
+ ["git", "remote", "get-url", "origin"],
29
+ cwd=str(_REPO),
30
+ capture_output=True,
31
+ text=True,
32
+ )
33
+ return (r.stdout or "").strip() if r.returncode == 0 else ""
34
+
35
+
36
+ def main() -> None:
37
+ p = argparse.ArgumentParser()
38
+ p.add_argument("--account_token", type=str, default=os.environ.get("HF_TOKEN", ""))
39
+ p.add_argument("--script", type=str, choices=("sft", "grpo"), required=True)
40
+ p.add_argument("--variant_name", type=str, default="run")
41
+ p.add_argument("--hardware", type=str, default="l40sx1")
42
+ p.add_argument(
43
+ "--image",
44
+ type=str,
45
+ default="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime",
46
+ )
47
+ p.add_argument("--git-url", type=str, default="")
48
+ p.add_argument(
49
+ "--env_vars",
50
+ type=str,
51
+ default="",
52
+ help="KEY=val pairs comma-separated, e.g. SEVZERO_ENV_URL=https://x.hf.space,HF_MAIN_TOKEN=...",
53
+ )
54
+ a, rest = p.parse_known_args()
55
+ if not a.account_token:
56
+ raise SystemExit("Need HF_TOKEN or --account_token")
57
+ git_url = a.git_url or _default_git_url()
58
+ if not git_url:
59
+ raise SystemExit("Set --git-url or configure git origin")
60
+ ev = {k: v for k, v in [x.split("=", 1) for x in a.env_vars.split(",") if "=" in x]}
61
+ if "SEVZERO_ENV_URL" not in ev and os.environ.get("SEVZERO_ENV_URL"):
62
+ ev["SEVZERO_ENV_URL"] = os.environ["SEVZERO_ENV_URL"]
63
+
64
+ which = f"training/train_{a.script}.py"
65
+ extra = " ".join(rest)
66
+ inner = (
67
+ f"set -euo pipefail && git clone --depth 1 {git_url!r} /work/r && cd /work/r && "
68
+ "pip install -U pip 'trl>=0.20' 'peft' 'transformers' 'accelerate' 'bitsandbytes' 'datasets' "
69
+ "'huggingface_hub' 'httpx' 'python-dotenv' 'vllm' 'unsloth' 2>/dev/null || true && "
70
+ f"python {which} --variant_name {a.variant_name!r} {extra}"
71
+ )
72
+ from huggingface_hub import run_job
73
+
74
+ job = run_job(
75
+ image=a.image,
76
+ command=["bash", "-lc", inner],
77
+ env=ev,
78
+ secrets={"HF_TOKEN": a.account_token},
79
+ flavor=a.hardware,
80
+ )
81
+ with (_REPO / "training" / "runs.jsonl").open("a", encoding="utf-8") as f:
82
+ f.write(
83
+ json.dumps(
84
+ {
85
+ "account_token_tail": a.account_token[-4:] if len(a.account_token) > 4 else "",
86
+ "job_id": str(getattr(job, "id", job)),
87
+ "variant_name": a.variant_name,
88
+ "started_at": datetime.now(timezone.utc).isoformat(),
89
+ }
90
+ )
91
+ + "\n"
92
+ )
93
+ print(getattr(job, "url", f"https://huggingface.co/jobs/{getattr(job, 'id', job)}"), flush=True)
94
+
95
+
96
+ if __name__ == "__main__":
97
+ main()
training/loader.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Load SevZero SFT data for a trainer: local JSONL or the Hub Parquet copy.
3
+
4
+ The training config should set `max_seq_length` to at least
5
+ `max_prompt_token_length` from `build_stats.json` (plus max completion length).
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import os
11
+ import sys
12
+ from pathlib import Path
13
+ from typing import Any, Optional, Union
14
+
15
+ REPO_ROOT = Path(__file__).resolve().parent.parent
16
+ DATA_DIR = REPO_ROOT / "training" / "data"
17
+
18
+ try:
19
+ from datasets import Dataset, DatasetDict, load_dataset
20
+ except ImportError as e:
21
+ raise ImportError("Install `datasets` to use the loader.") from e
22
+
23
+
24
+ def load_local_jsonl(
25
+ train_path: Optional[Path] = None,
26
+ eval_path: Optional[Path] = None,
27
+ ) -> DatasetDict:
28
+ train_path = train_path or (DATA_DIR / "sft_train.jsonl")
29
+ eval_path = eval_path or (DATA_DIR / "sft_eval.jsonl")
30
+ train = load_dataset("json", data_files=str(train_path), split="train")
31
+ if eval_path.is_file() and eval_path.stat().st_size > 0:
32
+ ev = load_dataset("json", data_files=str(eval_path), split="train")
33
+ else:
34
+ ev = train.select([])
35
+ return DatasetDict(train=train, eval=ev)
36
+
37
+
38
+ def load_from_hub(
39
+ repo_id: str = "Mist-ic/sevzero-expert-trajectories",
40
+ token: Optional[str] = None,
41
+ ) -> DatasetDict:
42
+ tok = token or os.environ.get("HF_MAIN_TOKEN")
43
+ return load_dataset(repo_id, token=tok) # type: ignore[return-value]
44
+
45
+
46
+ def read_build_stats() -> dict[str, Any]:
47
+ p = DATA_DIR / "build_stats.json"
48
+ if not p.is_file():
49
+ return {}
50
+ return json.loads(p.read_text(encoding="utf-8"))
51
+
52
+
53
+ def recommended_max_seq_length(plus_completion: int = 1024) -> int:
54
+ s = read_build_stats()
55
+ m = int(s.get("max_prompt_token_length", 0) or 0)
56
+ return m + plus_completion
training/preflight.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ (1) In-process Sim + grader: golden remediation plan → score >= 0.9 when possible
4
+ (2) Uvicorn /health (optional) + 5 CPU GRPO steps with rollout_func + tiny model
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import signal
11
+ import subprocess
12
+ import sys
13
+ import time
14
+ from pathlib import Path
15
+ from typing import Any, Dict, List, Tuple
16
+
17
+ _REPO = Path(__file__).resolve().parent.parent
18
+ if str(_REPO) not in sys.path:
19
+ sys.path.insert(0, str(_REPO))
20
+
21
+ from training.config_utils import try_load_env_files
22
+
23
+ try_load_env_files()
24
+
25
+
26
+ def _action_plan(seed: int, task_id: str) -> List[Tuple[str, Dict[str, Any]]]:
27
+ from server.failures import FailureType
28
+ from server.scenarios import generate_scenario
29
+
30
+ sc = generate_scenario(seed, task_id)
31
+ if not sc.failure_specs:
32
+ return [("noop", {})]
33
+ spec = sc.failure_specs[0]
34
+ sid = spec.service_id
35
+ ft = spec.failure_type
36
+ if ft == FailureType.BAD_DEPLOY:
37
+ return [("rollback_service", {"service_id": sid})]
38
+ if ft in (FailureType.CONFIG_STARTUP, FailureType.CONFIG_RUNTIME):
39
+ k = spec.broken_config_key or "timeout_ms"
40
+ out = [("tune_config", {"service_id": sid, "key": k, "value": "correct"})]
41
+ if ft == FailureType.CONFIG_STARTUP:
42
+ out.append(("restart_service", {"service_id": sid}))
43
+ return out
44
+ if ft == FailureType.CACHE_FAILURE:
45
+ return [("clear_cache", {"cache_name": sid})]
46
+ if ft == FailureType.CASCADING_LATENCY:
47
+ return [("scale_service", {"service_id": sid, "replicas": 4})]
48
+ if ft == FailureType.NETWORK_ERROR:
49
+ return [("noop", {}), ("noop", {})]
50
+ return [("restart_service", {"service_id": sid})]
51
+
52
+
53
+ def _inproc_golden_score(seed: int, task_id: str) -> float:
54
+ from server.grader import grade_episode
55
+ from server.scenarios import generate_scenario
56
+ from server.simulator import Simulator
57
+
58
+ sc = generate_scenario(seed, task_id)
59
+ sim = Simulator()
60
+ sim.reset(seed=seed, difficulty=sc.difficulty, failure_specs=sc.failure_specs)
61
+ for at, p in _action_plan(seed, task_id):
62
+ sim.step(at, p)
63
+ for _ in range(4):
64
+ if sim.terminated:
65
+ break
66
+ sim.step("noop", {})
67
+ g = grade_episode(
68
+ final_slo_score=sim.get_slo_score(),
69
+ steps_taken=len(sim.actions_taken),
70
+ max_steps=sc.max_steps,
71
+ actions_taken=sim.actions_taken,
72
+ terminated=sim.terminated,
73
+ termination_reason=sim.termination_reason,
74
+ )
75
+ return float(g.score)
76
+
77
+
78
+ def _grpo_tiny() -> bool:
79
+ try:
80
+ import trl # noqa: F401
81
+ except ImportError:
82
+ print("GRPO preflight: trl not installed — skip (pip install trl)", flush=True)
83
+ return True
84
+ os.environ["UNSLOTH_DISABLE"] = "1"
85
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
86
+
87
+ from datasets import Dataset
88
+ from peft import LoraConfig, get_peft_model
89
+ from transformers import AutoModelForCausalLM, AutoTokenizer
90
+ from trl import GRPOConfig, GRPOTrainer
91
+ from trl.experimental.openenv import generate_rollout_completions
92
+
93
+ from training.env_client import AsyncSevZeroEnvClient, run_async
94
+ from training.rollout_sevzero import SRE_SYSTEM_PROMPT, build_observation_prompt, parse_action
95
+
96
+ base = (os.environ.get("SEVZERO_ENV_URL") or "").rstrip("/")
97
+ if not base:
98
+ print("SEVZERO_ENV_URL unset — skip GRPO smoke", flush=True)
99
+ return True
100
+
101
+ tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
102
+ m = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", device_map="cpu")
103
+ m = get_peft_model(
104
+ m,
105
+ LoraConfig(
106
+ r=4,
107
+ lora_alpha=8,
108
+ target_modules=["q_proj", "v_proj"],
109
+ lora_dropout=0.0,
110
+ task_type="CAUSAL_LM",
111
+ ),
112
+ )
113
+
114
+ def rollout_func(prompts, trainer):
115
+ ep_ids: List[int] = []
116
+ ec_ids: List[int] = []
117
+ elp: List[float] = []
118
+ env_r: List[float] = []
119
+ for pr in prompts:
120
+ client = AsyncSevZeroEnvClient(base, None)
121
+
122
+ async def run_one():
123
+ p_ids, c_ids, lps = [], [], []
124
+ step_sum = 0.0
125
+ try:
126
+ ro = await client.reset(task_id="easy", seed=7)
127
+ obs = ro.get("observation", ro)
128
+ done = ro.get("done", False)
129
+ for _ in range(2):
130
+ if done:
131
+ break
132
+ u = build_observation_prompt(obs)
133
+ msg = [
134
+ {"role": "system", "content": SRE_SYSTEM_PROMPT},
135
+ {"role": "user", "content": f"{pr}\n{u}"},
136
+ ]
137
+ ptxt = tok.apply_chat_template(msg, add_generation_prompt=True, tokenize=False)
138
+ out = generate_rollout_completions(trainer, [ptxt])[0]
139
+ p_ids.extend(out.get("prompt_ids", []))
140
+ c_ids.extend(out.get("completion_ids", []))
141
+ lps.extend(out.get("logprobs", []))
142
+ ctext = out.get("text")
143
+ if not ctext and cids:
144
+ ctext = tok.decode(cids, skip_special_tokens=True)
145
+ a = parse_action(ctext or "")
146
+ sr = await client.step(
147
+ {
148
+ "action": {
149
+ "action_type": str(a.get("action_type", "noop")),
150
+ "params": a.get("params") or {},
151
+ }
152
+ }
153
+ )
154
+ obs = sr.get("observation", sr)
155
+ done = sr.get("done", False)
156
+ step_sum += float(obs.get("reward", sr.get("reward", 0.0) or 0.0))
157
+ return p_ids, c_ids, lps, step_sum
158
+ finally:
159
+ await client.aclose()
160
+
161
+ p, c, lp, s = run_async(run_one())
162
+ ep_ids.append(p)
163
+ ec_ids.append(c)
164
+ elp.append(lp)
165
+ env_r.append(s)
166
+ return {
167
+ "prompt_ids": ep_ids,
168
+ "completion_ids": ec_ids,
169
+ "logprobs": elp,
170
+ "env_reward": env_r,
171
+ }
172
+
173
+ def rf(completions, **kwargs):
174
+ return [float(x) for x in kwargs.get("env_reward", [0.0] * len(completions))]
175
+
176
+ out_dir = str(_REPO / "training" / ".preflight_grpo")
177
+ os.makedirs(out_dir, exist_ok=True)
178
+ tr = GRPOTrainer(
179
+ model=m,
180
+ processing_class=tok,
181
+ args=GRPOConfig(
182
+ output_dir=out_dir,
183
+ per_device_train_batch_size=1,
184
+ max_steps=5,
185
+ num_generations=1,
186
+ use_vllm=False,
187
+ learning_rate=1e-5,
188
+ max_completion_length=32,
189
+ ),
190
+ train_dataset=Dataset.from_list([{"text": "x"}] * 2),
191
+ reward_funcs=[rf],
192
+ rollout_func=rollout_func,
193
+ )
194
+ tr.train()
195
+ return True
196
+
197
+
198
+ def main() -> None:
199
+ # --- Part A: in-process (no network)
200
+ for seed, task in ((100, "easy"), (13, "easy"), (7, "easy")):
201
+ s = _inproc_golden_score(seed, task)
202
+ print(f"in-proc grader: seed={seed} task={task} score={s:.3f}", flush=True)
203
+ if s >= 0.9:
204
+ print("OK: in-process golden path reached >=0.9", flush=True)
205
+ break
206
+ else:
207
+ print("WARN: no seed reached 0.9 in in-proc test — check failure coverage", flush=True)
208
+
209
+ # --- B: Uvicorn + optional GRPO (requires same deps as the project)
210
+ try:
211
+ import uvicorn # noqa: F401
212
+ except ImportError:
213
+ print("SKIP: uvicorn not installed — pip install the project (see training/README.md)", flush=True)
214
+ print("OK", flush=True)
215
+ return
216
+
217
+ port = int(os.environ.get("PREFLIGHT_PORT", "8765"))
218
+ base = f"http://127.0.0.1:{port}"
219
+ os.environ["SEVZERO_ENV_URL"] = base
220
+ import urllib.request
221
+
222
+ proc = subprocess.Popen(
223
+ [sys.executable, "-m", "uvicorn", "server.app:app", "--host", "127.0.0.1", "--port", str(port)],
224
+ cwd=str(_REPO),
225
+ )
226
+ try:
227
+ for _ in range(25):
228
+ try:
229
+ with urllib.request.urlopen(f"{base}/health", timeout=2) as r:
230
+ if getattr(r, "status", 200) < 500:
231
+ break
232
+ except Exception:
233
+ time.sleep(0.5)
234
+ else:
235
+ raise RuntimeError("uvicorn not up")
236
+ try:
237
+ _grpo_tiny()
238
+ except Exception as e:
239
+ print(f"GRPO smoke failed (env OK): {e}", flush=True)
240
+ finally:
241
+ proc.terminate()
242
+ try:
243
+ proc.wait(timeout=10)
244
+ except Exception:
245
+ proc.kill()
246
+ print("OK", flush=True)
247
+
248
+
249
+ if __name__ == "__main__":
250
+ main()
training/push_dataset.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Upload SFT jsonl to Hugging Face (Mist-ic Main account) as a public dataset with Parquet.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import os
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ from dotenv import load_dotenv
12
+ from huggingface_hub import HfApi
13
+
14
+ REPO_ROOT = Path(__file__).resolve().parent.parent
15
+ load_dotenv(REPO_ROOT / "api.env")
16
+ load_dotenv(REPO_ROOT / "hg.env")
17
+
18
+ if str(REPO_ROOT) not in sys.path:
19
+ sys.path.insert(0, str(REPO_ROOT))
20
+
21
+ DATA_DIR = REPO_ROOT / "training" / "data"
22
+ STATS_PATH = DATA_DIR / "build_stats.json"
23
+
24
+
25
+ def _readme(stats: dict) -> str:
26
+ return f"""# SevZero expert trajectories (SFT)
27
+
28
+ ## Sources
29
+
30
+ - Synthetic expert rollouts from frontier models (Gemini 3.1 Pro, Azure OpenAI, Azure AI Inference)
31
+ against the local OpenEnv `server.app` SevZero environment.
32
+
33
+ ## Filtering
34
+
35
+ - Episodes with final grader `score` **≥** `{stats.get("min_score_filter", 0.85)}` are included.
36
+
37
+ ## Schema
38
+
39
+ - Each example has a `messages` list (Llama-3.1-8B-Instruct–style SFT) and `meta` (episode / step provenance):
40
+ - `system`: SRE on-call system prompt (same as `inference.SYSTEM_PROMPT` in the repo)
41
+ - `user`: JSON-serialized observation (shrink to ≤ {stats.get("max_observation_user_token_budget", 2048)} tokens for the user part)
42
+ - `assistant`: one JSON object `{{"action_type": "...", "params": {{...}}}}`
43
+
44
+ ## Stats (from `build_stats.json` at publish time)
45
+
46
+ {json.dumps(stats, indent=2)}
47
+
48
+ ## Parquet
49
+
50
+ - Splits `train` and `eval` are also pushed in Parquet for fast `datasets.load_dataset`.
51
+ """
52
+
53
+
54
+ def _dataset_info(stats: dict) -> dict:
55
+ return {
56
+ "description": "SevZero SFT expert trajectories for Llama-3.1-8B-Instruct style chat training.",
57
+ "version": "1.0.0",
58
+ "license": "apache-2.0",
59
+ "build": stats,
60
+ }
61
+
62
+
63
+ def main() -> None:
64
+ token = os.environ.get("HF_MAIN_TOKEN", "")
65
+ if not token:
66
+ raise SystemExit("HF_MAIN_TOKEN missing (set in api.env or hg.env).")
67
+ user = (os.environ.get("HF_MAIN_USERNAME", "") or "").strip() or "Mist-ic"
68
+ repo_id = f"{user}/sevzero-expert-trajectories"
69
+ if not (DATA_DIR / "sft_train.jsonl").is_file():
70
+ raise SystemExit(f"Missing {DATA_DIR / 'sft_train.jsonl'} — run build_dataset.py first.")
71
+ stats: dict = {}
72
+ if STATS_PATH.is_file():
73
+ stats = json.loads(STATS_PATH.read_text(encoding="utf-8"))
74
+ readme = _readme(stats)
75
+ info = _dataset_info(stats)
76
+ (DATA_DIR / "DATASET_README_HF.md").write_text(readme, encoding="utf-8")
77
+ (DATA_DIR / "dataset_info.json").write_text(
78
+ json.dumps(info, indent=2), encoding="utf-8"
79
+ )
80
+
81
+ api = HfApi(token=token)
82
+ api.create_repo(
83
+ repo_id=repo_id,
84
+ repo_type="dataset",
85
+ private=False,
86
+ exist_ok=True,
87
+ )
88
+ for name in (
89
+ "sft_train.jsonl",
90
+ "sft_eval.jsonl",
91
+ "build_stats.json",
92
+ "dataset_info.json",
93
+ ):
94
+ p = DATA_DIR / name
95
+ if p.is_file():
96
+ api.upload_file(
97
+ path_or_fileobj=str(p),
98
+ path_in_repo=name,
99
+ repo_id=repo_id,
100
+ repo_type="dataset",
101
+ commit_message="Add SFT files and metadata",
102
+ )
103
+ api.upload_file(
104
+ path_or_fileobj=readme.encode("utf-8"),
105
+ path_in_repo="README.md",
106
+ repo_id=repo_id,
107
+ repo_type="dataset",
108
+ commit_message="Add dataset README",
109
+ )
110
+
111
+ from datasets import DatasetDict, load_dataset
112
+
113
+ train = load_dataset("json", data_files=str(DATA_DIR / "sft_train.jsonl"))["train"]
114
+ evp = DATA_DIR / "sft_eval.jsonl"
115
+ if evp.is_file() and evp.stat().st_size > 0:
116
+ ev = load_dataset("json", data_files=str(evp))["train"]
117
+ else:
118
+ ev = train.select([])
119
+ dd = DatasetDict(train=train, eval=ev)
120
+ dd.push_to_hub(repo_id, private=False, token=token)
121
+
122
+ url = f"https://huggingface.co/datasets/{repo_id}"
123
+ print(url, flush=True)
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
training/rollout_sevzero.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SevZero multi-turn rollout helpers for TRL GRPO (sync API for rollout_func).
3
+ Builds chat prompts from observations and parses one JSON action per turn.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import textwrap
10
+ from typing import Any, Dict, List, Optional, Tuple
11
+
12
+ SRE_SYSTEM_PROMPT = textwrap.dedent(
13
+ """\
14
+ You are an expert Site Reliability Engineer (SRE) responding to a production incident.
15
+ You are managing a microservice cluster experiencing failures.
16
+ Your goal: restore all services to healthy SLO compliance as efficiently as possible.
17
+
18
+ Respond with EXACTLY one JSON object — no explanation, no markdown, just raw JSON:
19
+ {"action_type": "...", "params": {...}}
20
+
21
+ Param rules (STRICT — single service only, never a list):
22
+ - inspect_logs / inspect_metrics / inspect_traces / restart_service / rollback_service / scale_service:
23
+ {"action_type": "X", "params": {"service_id": "order-service"}}
24
+ - tune_config:
25
+ {"action_type": "tune_config", "params": {"service_id": "order-service", "key": "api_endpoint", "value": "correct"}}
26
+ - clear_cache:
27
+ {"action_type": "clear_cache", "params": {"cache_name": "redis-cache"}}
28
+ - rebalance_traffic:
29
+ {"action_type": "rebalance_traffic", "params": {"from_region": "us-east-1", "to_region": "us-west-2"}}
30
+ - noop:
31
+ {"action_type": "noop", "params": {}}
32
+ """
33
+ )
34
+
35
+
36
+ def build_observation_prompt(obs: Dict[str, Any]) -> str:
37
+ """Port of inference.build_observation_prompt (observation dict from HTTP JSON)."""
38
+ parts = [f"## Incident Status\n{obs.get('observation_summary', 'N/A')}"]
39
+ alerts = obs.get("alerts") or []
40
+ if alerts:
41
+ alert_lines = [f" [{a['severity'].upper()}] {a['message']}" for a in alerts[:10]]
42
+ parts.append("## Active Alerts\n" + "\n".join(alert_lines))
43
+ services = obs.get("services") or []
44
+ degraded = [s for s in services if s.get("status") in ("degraded", "critical", "down")]
45
+ if degraded:
46
+ svc_lines = []
47
+ for s in degraded:
48
+ sid = s["id"]
49
+ svc_lines.append(
50
+ f" {sid} [{s['status']}]: error={s['error_rate']:.1%}, "
51
+ f"p99={s['latency_p99_ms']:.0f}ms, cpu={s['cpu_pct']:.0f}%, "
52
+ f"mem={s['memory_pct']:.0f}%"
53
+ )
54
+ parts.append("## Degraded Services\n" + "\n".join(svc_lines))
55
+ deploys = obs.get("recent_deploys") or []
56
+ if deploys:
57
+ dep_lines = [f" {d['service']} -> {d['version']} ({d['ticks_ago']} ticks ago)" for d in deploys]
58
+ parts.append("## Recent Deploys\n" + "\n".join(dep_lines))
59
+ actions = obs.get("actions_taken") or []
60
+ if actions:
61
+ act_lines = [
62
+ f" tick {a['tick']}: {a['action']}({a.get('target', '')}) -> {'OK' if a['success'] else 'FAIL'}"
63
+ for a in actions[-5:]
64
+ ]
65
+ parts.append("## Recent Actions\n" + "\n".join(act_lines))
66
+ logs = obs.get("logs")
67
+ if logs:
68
+ parts.append(f"## Logs\n{logs}")
69
+ traces = obs.get("traces")
70
+ if traces:
71
+ spans = (traces.get("spans") or []) if isinstance(traces, dict) else []
72
+ error_spans = [s for s in spans if s.get("status") == "ERROR"]
73
+ if error_spans:
74
+ trace_lines = [
75
+ f" {s.get('service')}: {s.get('tags', {}).get('error.message', 'ERROR')}"
76
+ for s in error_spans[:5]
77
+ ]
78
+ parts.append("## Trace Errors\n" + "\n".join(trace_lines))
79
+ legal = obs.get("legal_actions") or []
80
+ if legal:
81
+ legal_strs = [f" {la.get('action_type', '')}: targets={la.get('valid_targets', [])[:5]}" for la in legal]
82
+ parts.append("## Available Actions\n" + "\n".join(legal_strs))
83
+ return "\n\n".join(parts)
84
+
85
+
86
+ def parse_action(response_text: str) -> Dict[str, Any]:
87
+ text = (response_text or "").strip()
88
+ if "```json" in text:
89
+ text = text.split("```json", 1)[1].split("```", 1)[0].strip()
90
+ elif "```" in text:
91
+ text = text.split("```", 1)[1].split("```", 1)[0].strip()
92
+ start, end = text.find("{"), text.rfind("}") + 1
93
+ if start >= 0 and end > start:
94
+ try:
95
+ return json.loads(text[start:end])
96
+ except json.JSONDecodeError:
97
+ pass
98
+ return {"action_type": "noop", "params": {}}
99
+
100
+
101
+ def _normalize_action(action: Dict[str, Any]) -> Dict[str, Any]:
102
+ act_type = action.get("action_type", "noop")
103
+ params = dict(action.get("params") or {})
104
+ if "replicas" in params:
105
+ try:
106
+ params["replicas"] = int(params["replicas"])
107
+ except (TypeError, ValueError):
108
+ params["replicas"] = 2
109
+ return {"action_type": act_type, "params": params}
training/train_grpo.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO on SevZero via TRL rollout_func + trl.experimental.openenv.generate_rollout_completions.
4
+ Verify API with Context7 before changing integration (rollout_func is required; environment_factory is deprecated).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import json
11
+ import os
12
+ import random
13
+ import sys
14
+ from pathlib import Path
15
+ from typing import Any, Dict, List, Optional
16
+
17
+ _REPO = Path(__file__).resolve().parent.parent
18
+ if str(_REPO) not in sys.path:
19
+ sys.path.insert(0, str(_REPO))
20
+
21
+ from training.config_utils import try_load_env_files
22
+
23
+ try_load_env_files()
24
+
25
+ BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
26
+ METRICS_NAME = "metrics.jsonl"
27
+
28
+ # Pinned in README: trl, unsloth, vllm — orchestrator sets exact versions
29
+
30
+
31
+ def _parse_args() -> argparse.Namespace:
32
+ p = argparse.ArgumentParser()
33
+ p.add_argument("--output_dir", type=str, default="./outputs/grpo")
34
+ p.add_argument("--sft_adapter_repo", type=str, required=True, help="HF adapter repo (worker account)")
35
+ p.add_argument("--env_url", type=str, default="", help="Override; else SEVZERO_ENV_URL")
36
+ p.add_argument("--max_steps", type=int, default=350)
37
+ p.add_argument("--lr", type=float, default=7e-6)
38
+ p.add_argument("--K", type=int, default=4, dest="K", help="num_generations")
39
+ p.add_argument("--seed", type=int, default=42)
40
+ p.add_argument(
41
+ "--reward_shaping",
42
+ type=str,
43
+ default="dense_v1",
44
+ choices=("dense_v1", "dense_v2", "sparse"),
45
+ )
46
+ p.add_argument("--enable_schema_drift", action="store_true")
47
+ p.add_argument("--enable_curriculum", action="store_true")
48
+ p.add_argument("--enable_oversight", action="store_true")
49
+ p.add_argument(
50
+ "--task_mix",
51
+ type=str,
52
+ default="hard",
53
+ choices=("hard", "mixed", "curriculum"),
54
+ )
55
+ p.add_argument("--push_to_hub_repo", type=str, default="")
56
+ p.add_argument("--variant_name", type=str, default="grpo")
57
+ p.add_argument("--rollout_max_steps", type=int, default=0, help="0 = from env observation max_steps")
58
+ return p.parse_args()
59
+
60
+
61
+ def _pick_task_id(args, idx: int, step: int) -> str:
62
+ if args.task_mix == "hard":
63
+ return "hard"
64
+ if args.task_mix == "mixed":
65
+ return ["easy", "medium", "hard"][idx % 3]
66
+ # curriculum: escalate every ~50 steps
67
+ if args.enable_curriculum:
68
+ tier = min(2, step // 50)
69
+ return ["easy", "medium", "hard"][tier]
70
+ return "hard"
71
+
72
+
73
+ def _compute_episode_return(
74
+ shaping: str,
75
+ step_rewards: List[float],
76
+ grader: Optional[Dict[str, Any]],
77
+ ) -> float:
78
+ if shaping == "sparse" and grader is not None:
79
+ return float(grader.get("score", 0.0))
80
+ if shaping == "dense_v2" and grader is not None:
81
+ # Slightly weight terminal score
82
+ s = sum(step_rewards) if step_rewards else 0.0
83
+ return 0.7 * s + 0.3 * float(grader.get("score", 0.0))
84
+ return float(sum(step_rewards)) if step_rewards else 0.0
85
+
86
+
87
+ def _build_default_dataset():
88
+ from datasets import Dataset
89
+
90
+ rows = []
91
+ for i in range(64):
92
+ text = (
93
+ "You are the on-call SRE. Restore service health. "
94
+ f"Incident session {i} — triage, diagnose root cause, remediate, verify."
95
+ )
96
+ rows.append({"text": text, "row_id": i})
97
+ return Dataset.from_list(rows)
98
+
99
+
100
+ def _reward_from_env(completions, **kwargs):
101
+ r = kwargs.get("env_reward")
102
+ if r is None:
103
+ return [0.0] * len(completions)
104
+ return [float(x) for x in r]
105
+
106
+
107
+ def main() -> None:
108
+ args = _parse_args()
109
+ env_url = (args.env_url or os.environ.get("SEVZERO_ENV_URL", "")).rstrip("/")
110
+ if not env_url:
111
+ raise SystemExit("Set --env_url or SEVZERO_ENV_URL to the remote SevZero HTTP base URL")
112
+
113
+ worker_token = os.environ.get("HF_TOKEN", "")
114
+ main_token = os.environ.get("HF_MAIN_TOKEN", "")
115
+
116
+ try:
117
+ import trackio
118
+
119
+ trackio.init(
120
+ project="sevzero-grpo",
121
+ space_id="Mist-ic/sevzero-trackio",
122
+ **({"hf_token": main_token} if main_token else {}),
123
+ )
124
+ except Exception as e:
125
+ print(f"trackio init skipped: {e}", flush=True)
126
+
127
+ try:
128
+ from unsloth import FastLanguageModel, PatchFastRL
129
+ except ImportError as e:
130
+ raise SystemExit(
131
+ f"unsloth is required for GRPO on this path: {e}\n"
132
+ "Install training extras, or on unsupported platforms set UNSLOTH_DISABLE=1 and extend train_grpo."
133
+ ) from e
134
+
135
+ PatchFastRL(algorithm="grpo", FastLanguageModel=FastLanguageModel)
136
+
137
+ from peft import PeftModel
138
+ from trl import GRPOConfig, GRPOTrainer
139
+ from trl.experimental.openenv import generate_rollout_completions
140
+
141
+ from training.env_client import AsyncSevZeroEnvClient, run_async
142
+ from training.rollout_sevzero import (
143
+ SRE_SYSTEM_PROMPT,
144
+ build_observation_prompt,
145
+ parse_action,
146
+ )
147
+
148
+ max_seq = 4096
149
+ model, tokenizer = FastLanguageModel.from_pretrained(
150
+ model_name=BASE_MODEL,
151
+ max_seq_length=max_seq,
152
+ dtype=None,
153
+ load_in_4bit=True,
154
+ )
155
+ model = PeftModel.from_pretrained(model, args.sft_adapter_repo, token=worker_token or None)
156
+ # Optional env flags (future env upgrades) — no-op for baseline server
157
+ if args.enable_schema_drift:
158
+ os.environ["SEVZERO_SCHEMA_DRIFT"] = "1"
159
+ if args.enable_oversight:
160
+ os.environ["SEVZERO_OVERSIGHT"] = "1"
161
+
162
+ metrics_path = Path(args.output_dir) / METRICS_NAME
163
+ metrics_path.parent.mkdir(parents=True, exist_ok=True)
164
+
165
+ # Capture trainer ref for step index in seeding
166
+ _trainer_holder: List[Any] = [None]
167
+ _global_episode: List[int] = [0]
168
+
169
+ def rollout_func(prompts: List[str], trainer) -> Dict[str, List[Any]]:
170
+ _trainer_holder[0] = trainer
171
+ episode_prompt_ids: List[List[int]] = []
172
+ episode_completion_ids: List[List[int]] = []
173
+ episode_logprobs: List[List[float]] = []
174
+ env_rewards: List[float] = []
175
+ tkn = os.environ.get("HF_TOKEN", "") # for private Space
176
+ for batch_idx, prompt_text in enumerate(prompts):
177
+ tr = _trainer_holder[0]
178
+ state = getattr(tr, "state", None) if tr else None
179
+ step = getattr(state, "global_step", 0) if state else 0
180
+ _global_episode[0] += 1
181
+ task_id = _pick_task_id(args, batch_idx, step)
182
+ seed = 13 + (batch_idx * 997) + (step * 13) + _global_episode[0] + random.randint(0, 1_000_000) % 100_000
183
+
184
+ async def _one_ep() -> tuple:
185
+ client = AsyncSevZeroEnvClient(env_url, token=tkn or None)
186
+ try:
187
+ p_ids: List[int] = []
188
+ c_ids: List[int] = []
189
+ lps: List[float] = []
190
+ step_rewards: List[float] = []
191
+ ro = await client.reset(task_id=task_id, seed=seed)
192
+ obs = ro.get("observation", ro)
193
+ done = ro.get("done", False)
194
+ grader: Optional[Dict[str, Any]] = None
195
+ user_prefix = f"{prompt_text}\n\n## Session\n"
196
+ for _t in range(args.rollout_max_steps or int(obs.get("max_steps", 20))):
197
+ if done:
198
+ break
199
+ user_msg = build_observation_prompt(obs)
200
+ messages = [
201
+ {"role": "system", "content": SRE_SYSTEM_PROMPT},
202
+ {"role": "user", "content": user_prefix + user_msg},
203
+ ]
204
+ p_text = tokenizer.apply_chat_template(
205
+ messages, add_generation_prompt=True, tokenize=False,
206
+ )
207
+ out = generate_rollout_completions(tr, [p_text])[0]
208
+ p_ids.extend(out.get("prompt_ids", []))
209
+ c_ids.extend(out.get("completion_ids", []))
210
+ lps.extend(out.get("logprobs", []))
211
+ gen_ids = out.get("completion_ids", [])
212
+ raw = out.get("text")
213
+ if not raw and gen_ids:
214
+ raw = tokenizer.decode(gen_ids, skip_special_tokens=True)
215
+ action = parse_action(raw or "")
216
+ step_payload = {
217
+ "action_type": str(action.get("action_type", "noop")),
218
+ "params": action.get("params") or {},
219
+ }
220
+ sr = await client.step({"action": step_payload})
221
+ obs = sr.get("observation", sr)
222
+ done = sr.get("done", False)
223
+ r = float(obs.get("reward", sr.get("reward", 0.0) or 0.0))
224
+ step_rewards.append(r)
225
+ st = await client.get_state()
226
+ max_st = int(obs.get("max_steps", 10))
227
+ try:
228
+ grader = await client.grade_episode(
229
+ final_slo_score=float(st.get("global_slo_score", 0.0)),
230
+ steps_taken=int(st.get("step_count", 0)),
231
+ max_steps=max_st,
232
+ actions_taken=list(obs.get("actions_taken", [])),
233
+ terminated=bool(st.get("terminated", True)),
234
+ termination_reason=st.get("termination_reason"),
235
+ )
236
+ except Exception:
237
+ grader = None
238
+ R = _compute_episode_return(args.reward_shaping, step_rewards, grader)
239
+ return p_ids, c_ids, lps, R
240
+ finally:
241
+ await client.aclose()
242
+
243
+ p_ids, c_ids, lps, r_ep = run_async(_one_ep())
244
+ episode_prompt_ids.append(p_ids)
245
+ episode_completion_ids.append(c_ids)
246
+ episode_logprobs.append(lps)
247
+ env_rewards.append(r_ep)
248
+ return {
249
+ "prompt_ids": episode_prompt_ids,
250
+ "completion_ids": episode_completion_ids,
251
+ "logprobs": episode_logprobs,
252
+ "env_reward": env_rewards,
253
+ }
254
+
255
+ grpo = GRPOConfig(
256
+ output_dir=args.output_dir,
257
+ learning_rate=args.lr,
258
+ per_device_train_batch_size=1,
259
+ gradient_accumulation_steps=8,
260
+ max_completion_length=1024,
261
+ num_train_epochs=1,
262
+ max_steps=args.max_steps,
263
+ num_generations=args.K,
264
+ temperature=0.85,
265
+ max_prompt_length=4096,
266
+ beta=0.04,
267
+ lr_scheduler_type="cosine",
268
+ use_vllm=True,
269
+ vllm_mode="colocate",
270
+ vllm_gpu_memory_utilization=0.55,
271
+ report_to="trackio",
272
+ logging_steps=1,
273
+ save_steps=100,
274
+ )
275
+
276
+ train_ds = _build_default_dataset()
277
+
278
+ trainer = GRPOTrainer(
279
+ model=model,
280
+ processing_class=tokenizer,
281
+ args=grpo,
282
+ train_dataset=train_ds,
283
+ reward_funcs=[_reward_from_env],
284
+ rollout_func=rollout_func,
285
+ )
286
+
287
+ from transformers import TrainerCallback
288
+
289
+ class _MetricsJSONL(TrainerCallback):
290
+ def on_log(self, args, state, control, logs=None, **kwargs):
291
+ if not logs:
292
+ return
293
+ rec = {
294
+ "step": state.global_step,
295
+ "reward_mean": logs.get("rewards", logs.get("reward", None)),
296
+ "reward_std": logs.get("reward_std", None),
297
+ "kl": logs.get("kl", None),
298
+ "entropy": logs.get("entropy", None),
299
+ "grad_norm": logs.get("grad_norm", None),
300
+ "loss": logs.get("loss", None),
301
+ "frac_reward_zero_std": logs.get("frac_reward_zero", logs.get("frac_reward_zero_std", None)),
302
+ "lr": logs.get("learning_rate", None),
303
+ }
304
+ with metrics_path.open("a", encoding="utf-8") as f:
305
+ f.write(json.dumps(rec, default=str) + "\n")
306
+ print(json.dumps({"type": "grpo", **rec}, default=str), flush=True)
307
+
308
+ trainer.add_callback(_MetricsJSONL())
309
+ trainer.train()
310
+
311
+ if args.push_to_hub_repo:
312
+ model.push_to_hub(args.push_to_hub_repo, token=worker_token or None, private=True)
313
+ tokenizer.push_to_hub(args.push_to_hub_repo, token=worker_token or None, private=True)
314
+
315
+
316
+ if __name__ == "__main__":
317
+ main()
training/train_sft.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SFT warmup: QLoRA on Mist-ic/sevzero-expert-trajectories (see training/data/HANDOFF.md).
4
+ Target TRL / Unsloth versions: see comments after `pip index` in training/README.md.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import json
11
+ import os
12
+ import sys
13
+ from pathlib import Path
14
+
15
+ _REPO = Path(__file__).resolve().parent.parent
16
+ if str(_REPO) not in sys.path:
17
+ sys.path.insert(0, str(_REPO))
18
+
19
+ from training.config_utils import try_load_env_files
20
+
21
+ try_load_env_files()
22
+
23
+ # --- Pin guidance (orchestrator resolves exact pins): trl>=0.22, unsloth, bitsandbytes, peft, accelerate
24
+ BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
25
+ DATASET_ID = "Mist-ic/sevzero-expert-trajectories"
26
+ DEFAULT_MAX_SEQ = 2048
27
+
28
+
29
+ def _parse_args() -> argparse.Namespace:
30
+ p = argparse.ArgumentParser()
31
+ p.add_argument("--output_dir", type=str, default="./outputs/sft")
32
+ p.add_argument("--max_steps", type=int, default=250)
33
+ p.add_argument("--lr", type=float, default=1e-5)
34
+ p.add_argument("--seed", type=int, default=42)
35
+ p.add_argument("--push_to_hub_repo", type=str, default="", help="e.g. PhaseOfCode/sevzero-llama3-8b-sft")
36
+ p.add_argument("--variant_name", type=str, default="default")
37
+ p.add_argument("--max_seq_length", type=int, default=0, help="0 = read HANDOFF / 2048")
38
+ return p.parse_args()
39
+
40
+
41
+ def _read_default_max_seq() -> int:
42
+ handoff = _REPO / "training" / "data" / "HANDOFF.md"
43
+ if not handoff.is_file():
44
+ return DEFAULT_MAX_SEQ
45
+ text = handoff.read_text(encoding="utf-8", errors="ignore")
46
+ for line in text.splitlines():
47
+ if "max_seq" in line.lower() and "`" in line:
48
+ try:
49
+ return int(line.split("`")[1])
50
+ except (ValueError, IndexError):
51
+ pass
52
+ return DEFAULT_MAX_SEQ
53
+
54
+
55
+ def _format_row_to_text(row: dict, tokenizer) -> str:
56
+ """Support 'text' column or OpenAI-style messages JSON."""
57
+ if "text" in row and row["text"]:
58
+ return str(row["text"])
59
+ if "messages" in row and row["messages"]:
60
+ msgs = row["messages"]
61
+ if isinstance(msgs, str):
62
+ import json as _j
63
+
64
+ msgs = _j.loads(msgs)
65
+ return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
66
+ raise ValueError("Dataset row must have 'text' or 'messages'")
67
+
68
+
69
+ def main() -> None:
70
+ args = _parse_args()
71
+ max_seq = args.max_seq_length or _read_default_max_seq()
72
+
73
+ worker_token = os.environ.get("HF_TOKEN", "")
74
+ main_token = os.environ.get("HF_MAIN_TOKEN", "")
75
+ if not worker_token:
76
+ print("warning: HF_TOKEN not set — Hub push and model download may fail.", flush=True)
77
+
78
+ # Trackio with main account (read-only space) while training pushes use HF_TOKEN
79
+ try:
80
+ import trackio
81
+
82
+ if main_token:
83
+ os.environ.setdefault("HF_TOKEN", worker_token)
84
+ trackio.init(
85
+ project="sevzero-sft",
86
+ space_id="Mist-ic/sevzero-trackio",
87
+ **({"hf_token": main_token} if main_token else {}),
88
+ )
89
+ except Exception as e:
90
+ print(f"trackio init skipped: {e}", flush=True)
91
+
92
+ from datasets import load_dataset
93
+ from transformers import TrainingArguments
94
+ from trl import SFTConfig, SFTTrainer
95
+
96
+ ds = load_dataset(DATASET_ID, split="train")
97
+
98
+ use_unsloth = os.environ.get("UNSLOTH_DISABLE", "").lower() not in ("1", "true", "yes")
99
+ model = None
100
+ tokenizer = None
101
+
102
+ if use_unsloth:
103
+ try:
104
+ from unsloth import FastLanguageModel
105
+
106
+ model, tokenizer = FastLanguageModel.from_pretrained(
107
+ model_name=BASE_MODEL,
108
+ max_seq_length=max_seq,
109
+ dtype=None,
110
+ load_in_4bit=True,
111
+ )
112
+ target_modules = [
113
+ "q_proj",
114
+ "k_proj",
115
+ "v_proj",
116
+ "o_proj",
117
+ "gate_proj",
118
+ "up_proj",
119
+ "down_proj",
120
+ ]
121
+ model = FastLanguageModel.get_peft_model(
122
+ model,
123
+ r=32,
124
+ lora_alpha=64,
125
+ lora_dropout=0.0,
126
+ target_modules=target_modules,
127
+ use_gradient_checkpointing="unsloth",
128
+ )
129
+ except Exception as e:
130
+ print(f"Unsloth path failed ({e}), falling back to PEFT+bnb.", flush=True)
131
+ use_unsloth = False
132
+
133
+ if not use_unsloth:
134
+ import torch
135
+ from peft import LoraConfig, get_peft_model
136
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
137
+
138
+ bnb = BitsAndBytesConfig(
139
+ load_in_4bit=True,
140
+ bnb_4bit_quant_type="nf4",
141
+ bnb_4bit_compute_dtype=torch.bfloat16,
142
+ )
143
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
144
+ if tokenizer.pad_token is None:
145
+ tokenizer.pad_token = tokenizer.eos_token
146
+ model = AutoModelForCausalLM.from_pretrained(
147
+ BASE_MODEL,
148
+ quantization_config=bnb,
149
+ device_map="auto",
150
+ torch_dtype=torch.bfloat16,
151
+ )
152
+ lora = LoraConfig(
153
+ r=32,
154
+ lora_alpha=64,
155
+ lora_dropout=0.0,
156
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
157
+ task_type="CAUSAL_LM",
158
+ )
159
+ model = get_peft_model(model, lora)
160
+
161
+ def formatting_prompts(examples: dict) -> dict:
162
+ texts = []
163
+ n = len(next(iter(examples.values())))
164
+ keys = list(examples.keys())
165
+ for i in range(n):
166
+ row = {k: (examples[k][i] if k in examples else None) for k in keys}
167
+ texts.append(_format_row_to_text(row, tokenizer))
168
+ return {"text": texts}
169
+
170
+ cols = ds.column_names
171
+ if "text" not in ds.column_names:
172
+ if "messages" in ds.column_names:
173
+ ds = ds.map(
174
+ formatting_prompts,
175
+ batched=True,
176
+ remove_columns=[c for c in cols if c not in ("messages",)],
177
+ )
178
+ else:
179
+ raise ValueError("Dataset must include a 'text' or 'messages' column")
180
+ targs = SFTConfig(
181
+ output_dir=args.output_dir,
182
+ max_steps=args.max_steps,
183
+ learning_rate=args.lr,
184
+ per_device_train_batch_size=4,
185
+ gradient_accumulation_steps=8,
186
+ warmup_ratio=0.05,
187
+ lr_scheduler_type="cosine",
188
+ optim="paged_adamw_8bit",
189
+ bf16=True,
190
+ seed=args.seed,
191
+ logging_steps=1,
192
+ report_to="trackio",
193
+ save_total_limit=2,
194
+ max_seq_length=max_seq,
195
+ )
196
+
197
+ from transformers import TrainerCallback
198
+
199
+ class JsonStepLog(TrainerCallback):
200
+ def on_log(self, args, state, control, logs=None, **kwargs):
201
+ if not logs:
202
+ return
203
+ payload = {
204
+ "type": "sft_step",
205
+ "step": state.global_step,
206
+ "loss": logs.get("loss"),
207
+ "lr": logs.get("learning_rate"),
208
+ }
209
+ print(json.dumps(payload, default=str), flush=True)
210
+
211
+ trainer = SFTTrainer(
212
+ model=model,
213
+ processing_class=tokenizer,
214
+ args=targs,
215
+ train_dataset=ds,
216
+ dataset_text_field="text",
217
+ callbacks=[JsonStepLog()],
218
+ )
219
+ trainer.train()
220
+
221
+ if args.push_to_hub_repo:
222
+ print(json.dumps({"event": "push_to_hub", "repo": args.push_to_hub_repo}, default=str), flush=True)
223
+ model.push_to_hub(
224
+ args.push_to_hub_repo,
225
+ token=worker_token or None,
226
+ private=True,
227
+ )
228
+ tokenizer.push_to_hub(
229
+ args.push_to_hub_repo,
230
+ token=worker_token or None,
231
+ private=True,
232
+ )
233
+
234
+
235
+ if __name__ == "__main__":
236
+ main()