Spaces:
Sleeping
Sleeping
SevZero Bot commited on
Commit ·
0f5092c
1
Parent(s): 382d0fd
Add Wave 1 training pipeline (SFT/GRPO/eval/preflight/launch) + gitignore hardening
Browse files- .gitignore +25 -4
- training/README.md +72 -0
- training/__init__.py +1 -0
- training/build_dataset.py +241 -0
- training/collect_trajectories.py +764 -0
- training/config_utils.py +32 -0
- training/data/DATASET_README_HF.md +35 -0
- training/data/HANDOFF.md +5 -0
- training/data/build_stats.json +11 -0
- training/data/dataset_info.json +16 -0
- training/data/sft_eval.jsonl +0 -0
- training/data/sft_train.jsonl +0 -0
- training/env_client.py +159 -0
- training/eval.py +269 -0
- training/launch_hf_job.py +97 -0
- training/loader.py +56 -0
- training/preflight.py +250 -0
- training/push_dataset.py +127 -0
- training/rollout_sevzero.py +109 -0
- training/train_grpo.py +317 -0
- training/train_sft.py +236 -0
.gitignore
CHANGED
|
@@ -1,13 +1,34 @@
|
|
| 1 |
# Documentation and research (not part of the submission)
|
| 2 |
Docs/
|
| 3 |
-
|
| 4 |
-
# OpenEnv preparatory course (dev reference only, not part of submission)
|
| 5 |
openenv-course/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# Python
|
| 8 |
__pycache__/
|
| 9 |
*.pyc
|
| 10 |
*.pyo
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
.
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Documentation and research (not part of the submission)
|
| 2 |
Docs/
|
| 3 |
+
DocsR2/
|
|
|
|
| 4 |
openenv-course/
|
| 5 |
+
playbook/
|
| 6 |
+
|
| 7 |
+
# Secrets — NEVER commit
|
| 8 |
+
.env
|
| 9 |
+
*.env
|
| 10 |
+
api.env
|
| 11 |
+
hg.env
|
| 12 |
+
|
| 13 |
+
# Training artefacts
|
| 14 |
+
training/data/raw/
|
| 15 |
+
training/.preflight_grpo/
|
| 16 |
+
training/runs.jsonl
|
| 17 |
+
outputs/
|
| 18 |
+
out/
|
| 19 |
+
wandb/
|
| 20 |
+
trackio/
|
| 21 |
|
| 22 |
# Python
|
| 23 |
__pycache__/
|
| 24 |
*.pyc
|
| 25 |
*.pyo
|
| 26 |
+
*.egg-info/
|
| 27 |
+
.venv/
|
| 28 |
+
venv/
|
| 29 |
|
| 30 |
+
# OS / editor
|
| 31 |
+
.DS_Store
|
| 32 |
+
Thumbs.db
|
| 33 |
+
.idea/
|
| 34 |
+
.vscode/
|
training/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SevZero — training (Round 2)
|
| 2 |
+
|
| 3 |
+
One-liner per script:
|
| 4 |
+
|
| 5 |
+
- **`train_sft.py`**: SFT on `Mist-ic/sevzero-expert-trajectories` with QLoRA (Unsloth or PEFT fallback) → push adapter with `HF_TOKEN`.
|
| 6 |
+
- **`train_grpo.py`**: GRPO with `rollout_func` + remote env (`SEVZERO_ENV_URL`); vLLM colocate, Trackio `Mist-ic/sevzero-trackio`.
|
| 7 |
+
- **`eval.py`**: Compare HF adapters and frontier models; write `eval_results.csv`, push `Mist-ic/sevzero-eval-results` with `HF_MAIN_TOKEN`.
|
| 8 |
+
- **`preflight.py`**: In-process grader + tiny GRPO smoke (5 steps) on CPU; starts local uvicorn.
|
| 9 |
+
- **`launch_hf_job.py`**: `huggingface_hub.run_job` wrapper; `--hardware l40sx1` (verify with `hf jobs hardware`).
|
| 10 |
+
|
| 11 |
+
## Env files
|
| 12 |
+
|
| 13 |
+
Load with `python-dotenv` (auto-tried in `config_utils`):
|
| 14 |
+
|
| 15 |
+
- `hg.env` — `HF_TOKEN` (worker), `HF_MAIN_TOKEN` (Mist-ic, Trackio + eval dataset)
|
| 16 |
+
- `api.env` — `GEMINI_API_KEY`, `AZURE_*` for `eval.py`
|
| 17 |
+
|
| 18 |
+
| Variable | Role |
|
| 19 |
+
|----------|------|
|
| 20 |
+
| `HF_TOKEN` | Worker: train pushes, private adapter pulls |
|
| 21 |
+
| `HF_MAIN_TOKEN` | `Mist-ic`: Trackio + `sevzero-eval-results` only |
|
| 22 |
+
| `SEVZERO_ENV_URL` | HTTP base of SevZero Space/ server for GRPO + eval + preflight |
|
| 23 |
+
| `GEMINI_API_KEY` | Direct Gemini in eval |
|
| 24 |
+
| `AZURE_API_KEY` | Azure OpenAI + Azure AI Inference |
|
| 25 |
+
| `AZURE_OPENAI_ENDPOINT` | Deployment base for gpt-5.4-pro |
|
| 26 |
+
| `AZURE_AI_INFERENCE_ENDPOINT` | For grok / kimi / DeepSeek in eval |
|
| 27 |
+
| `AZURE_API_VERSION` | OpenAI client version header if needed |
|
| 28 |
+
| `GEMINI_EVAL_MODEL` | Optional override (default set in `eval.py`) |
|
| 29 |
+
|
| 30 |
+
## Local debug (from repo root)
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
# Install (pin versions in comments / orchestrator)
|
| 34 |
+
pip install -e ".[training]"
|
| 35 |
+
|
| 36 |
+
# SFT
|
| 37 |
+
python training/train_sft.py --output_dir ./out/sft --max_steps 10 --push_to_hub_repo "" --variant_name test
|
| 38 |
+
|
| 39 |
+
# GRPO (remote env required)
|
| 40 |
+
$env:SEVZERO_ENV_URL="https://<your-sevzero-space>.hf.space"
|
| 41 |
+
python training/train_grpo.py --sft_adapter_repo YOUR/adapters --max_steps 5 --output_dir ./out/grpo
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Wave 3 — three GRPO variants (see `playbook/00-orchestration.md`)
|
| 45 |
+
|
| 46 |
+
Primary (PhaseOfCode):
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
python training/train_grpo.py --sft_adapter_repo PhaseOfCode/sevzero-llama3-8b-sft --K 4 --lr 7e-6 --max_steps 350 --variant_name primary
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
Stability (NoahInOblivion):
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
python training/train_grpo.py --sft_adapter_repo NoahInOblivion/sevzero-llama3-8b-sft --K 8 --lr 5e-6 --max_steps 350 --variant_name stability
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Innovation (NoxIsOblivion, env flags on):
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python training/train_grpo.py --sft_adapter_repo NoxIsOblivion/sevzero-llama3-8b-sft --enable_schema_drift --enable_curriculum --K 4 --max_steps 350 --variant_name innovation
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
**HF Job (after merge + public git URL or bucket):**
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
$env:HF_TOKEN="<worker>"
|
| 68 |
+
$env:SEVZERO_ENV_URL="https://....hf.space"
|
| 69 |
+
python training/launch_hf_job.py --script grpo --variant_name primary -- --sft_adapter_repo YOUR/sevzero-llama3-8b-sft
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
**Dependency pins:** run `pip index versions trl openenv-core unsloth` and `python -c "import trl; print(trl.__version__)"` after install; pin in the orchestrator’s lock, not in this file.
|
training/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Training / trajectory pipeline (Round 2)
|
training/build_dataset.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Build Llama-3.1-8B-Instruct SFT jsonl from raw trajectory jsonl (score ≥ 0.85).
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import random
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, List, Set, Tuple
|
| 12 |
+
|
| 13 |
+
from dotenv import load_dotenv
|
| 14 |
+
|
| 15 |
+
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 16 |
+
if str(REPO_ROOT) not in sys.path:
|
| 17 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 18 |
+
|
| 19 |
+
from inference import SYSTEM_PROMPT # noqa: E402
|
| 20 |
+
|
| 21 |
+
load_dotenv(REPO_ROOT / "api.env")
|
| 22 |
+
load_dotenv(REPO_ROOT / "hg.env")
|
| 23 |
+
|
| 24 |
+
DATA_DIR = REPO_ROOT / "training" / "data"
|
| 25 |
+
RAW_GLOB = "raw/*.jsonl"
|
| 26 |
+
OUT_TRAIN = DATA_DIR / "sft_train.jsonl"
|
| 27 |
+
OUT_EVAL = DATA_DIR / "sft_eval.jsonl"
|
| 28 |
+
OUT_STATS = DATA_DIR / "build_stats.json"
|
| 29 |
+
|
| 30 |
+
MAX_OBS_TOKENS = 2048
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _get_tokenizer():
|
| 34 |
+
import os
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from transformers import AutoTokenizer
|
| 38 |
+
except Exception:
|
| 39 |
+
return None
|
| 40 |
+
name = "meta-llama/Llama-3.1-8B-Instruct"
|
| 41 |
+
try:
|
| 42 |
+
tok = AutoTokenizer.from_pretrained(
|
| 43 |
+
name, token=os.environ.get("HF_MAIN_TOKEN")
|
| 44 |
+
)
|
| 45 |
+
return tok
|
| 46 |
+
except Exception:
|
| 47 |
+
try:
|
| 48 |
+
return AutoTokenizer.from_pretrained(
|
| 49 |
+
"hf-internal-testing/llama-tokenizer"
|
| 50 |
+
)
|
| 51 |
+
except Exception:
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _count_tokens(toker, text: str) -> int:
|
| 56 |
+
if toker is not None:
|
| 57 |
+
return len(toker.encode(text, add_special_tokens=False))
|
| 58 |
+
return max(1, len(text) // 4)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _shrink_observation(obs: Dict[str, Any], toker, max_toks: int) -> str:
|
| 62 |
+
"""Serialize observation to JSON, shrink until user message fits max_toks (approximate)."""
|
| 63 |
+
o = {k: v for k, v in obs.items() if k not in ("reward",)}
|
| 64 |
+
order_drop = [
|
| 65 |
+
"metric_history",
|
| 66 |
+
"traces",
|
| 67 |
+
"logs",
|
| 68 |
+
"actions_taken",
|
| 69 |
+
"recent_deploys",
|
| 70 |
+
]
|
| 71 |
+
for _ in range(40):
|
| 72 |
+
text = json.dumps(o, ensure_ascii=False, separators=(",", ":"), default=str)
|
| 73 |
+
tcount = _count_tokens(toker, text)
|
| 74 |
+
if tcount <= max_toks:
|
| 75 |
+
return text
|
| 76 |
+
shrunk = False
|
| 77 |
+
for k in order_drop:
|
| 78 |
+
if k in o and o[k]:
|
| 79 |
+
o[k] = None
|
| 80 |
+
if k == "actions_taken":
|
| 81 |
+
o[k] = []
|
| 82 |
+
elif k in ("metric_history", "recent_deploys"):
|
| 83 |
+
o[k] = []
|
| 84 |
+
shrunk = True
|
| 85 |
+
break
|
| 86 |
+
if shrunk:
|
| 87 |
+
continue
|
| 88 |
+
if "services" in o and isinstance(o["services"], list) and len(o["services"]) > 2:
|
| 89 |
+
o["services"] = o["services"][: max(1, len(o["services"]) - 1)]
|
| 90 |
+
continue
|
| 91 |
+
if "alerts" in o and isinstance(o["alerts"], list) and len(o["alerts"]) > 1:
|
| 92 |
+
o["alerts"] = o["alerts"][: max(0, len(o["alerts"]) - 1)]
|
| 93 |
+
continue
|
| 94 |
+
o["__truncated__"] = True
|
| 95 |
+
break
|
| 96 |
+
return json.dumps(o, ensure_ascii=False, separators=(",", ":"), default=str)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _episode_id(ep: Dict[str, Any]) -> str:
|
| 100 |
+
return f"{ep.get('model', '')}|{ep.get('task_id', '')}|{ep.get('seed', 0)}"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _assistant_action_json(action: Any) -> str:
|
| 104 |
+
if not isinstance(action, dict):
|
| 105 |
+
return json.dumps(
|
| 106 |
+
{"action_type": "noop", "params": {}}, ensure_ascii=False
|
| 107 |
+
)
|
| 108 |
+
a = {
|
| 109 |
+
"action_type": str(action.get("action_type", "noop")),
|
| 110 |
+
"params": action.get("params") or {},
|
| 111 |
+
}
|
| 112 |
+
return json.dumps(a, ensure_ascii=False)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _load_episodes_from_raw(raw_dir: Path) -> List[Dict[str, Any]]:
|
| 116 |
+
out: List[Dict[str, Any]] = []
|
| 117 |
+
for p in sorted(raw_dir.glob("*.jsonl")):
|
| 118 |
+
with p.open(encoding="utf-8") as f:
|
| 119 |
+
for line in f:
|
| 120 |
+
line = line.strip()
|
| 121 |
+
if not line:
|
| 122 |
+
continue
|
| 123 |
+
out.append(json.loads(line))
|
| 124 |
+
return out
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def build(
|
| 128 |
+
min_score: float = 0.85,
|
| 129 |
+
) -> Dict[str, Any]:
|
| 130 |
+
toker = _get_tokenizer()
|
| 131 |
+
raw_dir = DATA_DIR / "raw"
|
| 132 |
+
episodes = _load_episodes_from_raw(raw_dir)
|
| 133 |
+
kept: List[Dict[str, Any]] = []
|
| 134 |
+
dropped: List[Dict[str, Any]] = []
|
| 135 |
+
for ep in episodes:
|
| 136 |
+
sc = float(ep.get("final_score", 0.0) or 0.0)
|
| 137 |
+
if sc >= min_score and ep.get("steps"):
|
| 138 |
+
kept.append(ep)
|
| 139 |
+
else:
|
| 140 |
+
dropped.append(ep)
|
| 141 |
+
|
| 142 |
+
eids = [_episode_id(e) for e in kept]
|
| 143 |
+
unique_eids = list(dict.fromkeys(eids))
|
| 144 |
+
n_ep = len(unique_eids)
|
| 145 |
+
rng = random.Random(42)
|
| 146 |
+
rng.shuffle(unique_eids)
|
| 147 |
+
if n_ep <= 1:
|
| 148 |
+
n_eval = 0
|
| 149 |
+
else:
|
| 150 |
+
n_eval = max(1, n_ep // 10)
|
| 151 |
+
eval_ids: Set[str] = set(unique_eids[:n_eval]) if n_eval else set()
|
| 152 |
+
|
| 153 |
+
train_rows: List[Dict[str, Any]] = []
|
| 154 |
+
eval_rows: List[Dict[str, Any]] = []
|
| 155 |
+
max_prompt_toks = 0
|
| 156 |
+
|
| 157 |
+
for ep in kept:
|
| 158 |
+
eid = _episode_id(ep)
|
| 159 |
+
is_eval = eid in eval_ids
|
| 160 |
+
for st in ep.get("steps", []):
|
| 161 |
+
obs = st.get("observation", {})
|
| 162 |
+
if not isinstance(obs, dict):
|
| 163 |
+
continue
|
| 164 |
+
user_str = _shrink_observation(obs, toker, MAX_OBS_TOKENS)
|
| 165 |
+
messages = [
|
| 166 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 167 |
+
{"role": "user", "content": user_str},
|
| 168 |
+
{
|
| 169 |
+
"role": "assistant",
|
| 170 |
+
"content": _assistant_action_json(st.get("action", {})),
|
| 171 |
+
},
|
| 172 |
+
]
|
| 173 |
+
if toker is not None:
|
| 174 |
+
try:
|
| 175 |
+
plen = len(
|
| 176 |
+
toker.apply_chat_template(
|
| 177 |
+
messages, tokenize=True, add_generation_prompt=False
|
| 178 |
+
)
|
| 179 |
+
)
|
| 180 |
+
except Exception:
|
| 181 |
+
plen = _count_tokens(
|
| 182 |
+
toker, SYSTEM_PROMPT + "\n" + user_str
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
plen = _count_tokens(
|
| 186 |
+
None, SYSTEM_PROMPT + "\n" + user_str
|
| 187 |
+
)
|
| 188 |
+
max_prompt_toks = max(max_prompt_toks, plen)
|
| 189 |
+
row = {
|
| 190 |
+
"messages": messages,
|
| 191 |
+
"meta": {
|
| 192 |
+
"episode_id": eid,
|
| 193 |
+
"model": ep.get("model"),
|
| 194 |
+
"task_id": ep.get("task_id"),
|
| 195 |
+
"seed": ep.get("seed"),
|
| 196 |
+
"step": st.get("step"),
|
| 197 |
+
"episode_score": ep.get("final_score"),
|
| 198 |
+
},
|
| 199 |
+
}
|
| 200 |
+
if is_eval:
|
| 201 |
+
eval_rows.append(row)
|
| 202 |
+
else:
|
| 203 |
+
train_rows.append(row)
|
| 204 |
+
|
| 205 |
+
scores = [float(x.get("final_score", 0) or 0) for x in kept]
|
| 206 |
+
mean_sc = sum(scores) / len(scores) if scores else 0.0
|
| 207 |
+
|
| 208 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 209 |
+
with OUT_TRAIN.open("w", encoding="utf-8") as ft:
|
| 210 |
+
for r in train_rows:
|
| 211 |
+
ft.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 212 |
+
with OUT_EVAL.open("w", encoding="utf-8") as fe:
|
| 213 |
+
for r in eval_rows:
|
| 214 |
+
fe.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 215 |
+
|
| 216 |
+
stats: Dict[str, Any] = {
|
| 217 |
+
"episodes_total_seen": len(episodes),
|
| 218 |
+
"episodes_kept": len(kept),
|
| 219 |
+
"episodes_dropped": len(dropped),
|
| 220 |
+
"mean_episode_score_kept": round(mean_sc, 6),
|
| 221 |
+
"train_rows": len(train_rows),
|
| 222 |
+
"eval_rows": len(eval_rows),
|
| 223 |
+
"max_prompt_token_length": max_prompt_toks,
|
| 224 |
+
"max_observation_user_token_budget": MAX_OBS_TOKENS,
|
| 225 |
+
"min_score_filter": min_score,
|
| 226 |
+
}
|
| 227 |
+
with OUT_STATS.open("w", encoding="utf-8") as f:
|
| 228 |
+
json.dump(stats, f, indent=2)
|
| 229 |
+
print(json.dumps(stats, indent=2), flush=True)
|
| 230 |
+
return stats
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def main() -> None:
|
| 234 |
+
ap = argparse.ArgumentParser()
|
| 235 |
+
ap.add_argument("--min-score", type=float, default=0.85)
|
| 236 |
+
args = ap.parse_args()
|
| 237 |
+
build(min_score=args.min_score)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
if __name__ == "__main__":
|
| 241 |
+
main()
|
training/collect_trajectories.py
ADDED
|
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Collect expert trajectories for SevZero SFT (Round 2).
|
| 3 |
+
|
| 4 |
+
Loads API keys from api.env and hg.env (gitignored). Does not log secrets.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import copy
|
| 10 |
+
import difflib
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
import re
|
| 14 |
+
import subprocess
|
| 15 |
+
import sys
|
| 16 |
+
import time
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 20 |
+
|
| 21 |
+
import httpx
|
| 22 |
+
from dotenv import load_dotenv
|
| 23 |
+
from openai import AzureOpenAI
|
| 24 |
+
from pydantic import BaseModel, Field
|
| 25 |
+
|
| 26 |
+
# Repo root: parent of training/
|
| 27 |
+
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 28 |
+
if str(REPO_ROOT) not in sys.path:
|
| 29 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 30 |
+
|
| 31 |
+
from inference import ( # noqa: E402
|
| 32 |
+
build_observation_prompt,
|
| 33 |
+
parse_action,
|
| 34 |
+
)
|
| 35 |
+
from inference import SYSTEM_PROMPT as _BASE_SYSTEM # noqa: E402
|
| 36 |
+
|
| 37 |
+
load_dotenv(REPO_ROOT / "api.env")
|
| 38 |
+
load_dotenv(REPO_ROOT / "hg.env")
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Config matrix (must match spec)
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
GEMINI_SEEDS = [
|
| 45 |
+
42, 123, 7, 11, 23, 31, 47, 59, 67, 71, 83, 89, 97, 101, 109, 113, 127, 131, 137, 149
|
| 46 |
+
]
|
| 47 |
+
GPT_SEEDS = [
|
| 48 |
+
42, 123, 7, 13, 17, 19, 29, 37, 41, 43, 53, 61, 73, 79, 83, 89, 97, 101, 103, 107
|
| 49 |
+
]
|
| 50 |
+
GROK_EXTRA_SEEDS = [13, 17, 19, 29, 37, 41, 43, 53, 61, 73]
|
| 51 |
+
|
| 52 |
+
# Combined pool for grok / kimi / deepseek (any from grok list + full Gemini list)
|
| 53 |
+
GROK_KIMI_POOL: List[int] = sorted(set(GEMINI_SEEDS) | set(GROK_EXTRA_SEEDS))
|
| 54 |
+
|
| 55 |
+
MODEL_GEMINI = "gemini-3.1-pro-preview"
|
| 56 |
+
MODEL_GPT = "gpt-5.4-pro"
|
| 57 |
+
MODEL_GROK = "grok-4.20-reasoning"
|
| 58 |
+
MODEL_KIMI = "kimi-k2.6"
|
| 59 |
+
MODEL_DEEPSEEK = "DeepSeek-V3.2"
|
| 60 |
+
ALL_CANON = {MODEL_GEMINI, MODEL_GPT, MODEL_GROK, MODEL_KIMI, MODEL_DEEPSEEK}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _split_seeds(
|
| 64 |
+
pool: List[int], counts: Tuple[int, int, int], offset: int
|
| 65 |
+
) -> List[Tuple[str, int]]:
|
| 66 |
+
"""Return list of (task_id, seed) in order easy, medium, hard."""
|
| 67 |
+
c_e, c_m, c_h = counts
|
| 68 |
+
n = len(pool)
|
| 69 |
+
if n == 0:
|
| 70 |
+
return []
|
| 71 |
+
o = [pool[(i + offset) % n] for i in range(n)]
|
| 72 |
+
out: List[Tuple[str, int]] = []
|
| 73 |
+
i = 0
|
| 74 |
+
for _ in range(c_e):
|
| 75 |
+
out.append(("easy", o[i % len(o)]))
|
| 76 |
+
i += 1
|
| 77 |
+
for _ in range(c_m):
|
| 78 |
+
out.append(("medium", o[i % len(o)]))
|
| 79 |
+
i += 1
|
| 80 |
+
for _ in range(c_h):
|
| 81 |
+
out.append(("hard", o[i % len(o)]))
|
| 82 |
+
i += 1
|
| 83 |
+
return out
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def plan_gemini(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
|
| 87 |
+
return [
|
| 88 |
+
(MODEL_GEMINI, t, s)
|
| 89 |
+
for t, s in _split_seeds(GEMINI_SEEDS, (c_e, c_m, c_h), offset=0)
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def plan_gpt(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
|
| 94 |
+
return [
|
| 95 |
+
(MODEL_GPT, t, s)
|
| 96 |
+
for t, s in _split_seeds(GPT_SEEDS, (c_e, c_m, c_h), offset=0)
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def plan_grok(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
|
| 101 |
+
return [
|
| 102 |
+
(MODEL_GROK, t, s)
|
| 103 |
+
for t, s in _split_seeds(GROK_KIMI_POOL, (c_e, c_m, c_h), offset=0)
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def plan_kimi(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
|
| 108 |
+
return [
|
| 109 |
+
(MODEL_KIMI, t, s)
|
| 110 |
+
for t, s in _split_seeds(GROK_KIMI_POOL, (c_e, c_m, c_h), offset=7)
|
| 111 |
+
]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def plan_deepseek(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
|
| 115 |
+
return [
|
| 116 |
+
(MODEL_DEEPSEEK, t, s)
|
| 117 |
+
for t, s in _split_seeds(GROK_KIMI_POOL, (c_e, c_m, c_h), offset=3)
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def full_plan(c_e: int, c_m: int, c_h: int) -> List[Tuple[str, str, int]]:
|
| 122 |
+
return (
|
| 123 |
+
plan_gemini(c_e, c_m, c_h)
|
| 124 |
+
+ plan_gpt(c_e, c_m, c_h)
|
| 125 |
+
+ plan_grok(c_e, c_m, c_h)
|
| 126 |
+
+ plan_kimi(c_e, c_m, c_h)
|
| 127 |
+
+ plan_deepseek(c_e, c_m, c_h)
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Rough USD cost tracking (tunable; for guardrail only)
|
| 132 |
+
@dataclass
|
| 133 |
+
class CostTracker:
|
| 134 |
+
usd: float = 0.0
|
| 135 |
+
budget: float = 5.0
|
| 136 |
+
by_model: Dict[str, float] = field(default_factory=dict)
|
| 137 |
+
per_model_max: float = 2.0
|
| 138 |
+
|
| 139 |
+
def add(self, model: str, usd: float) -> None:
|
| 140 |
+
self.usd += usd
|
| 141 |
+
self.by_model[model] = self.by_model.get(model, 0.0) + usd
|
| 142 |
+
m = self.by_model[model]
|
| 143 |
+
cap = self.per_model_max
|
| 144 |
+
if m > cap:
|
| 145 |
+
raise RuntimeError(
|
| 146 |
+
f"Model {model} exceeded ${cap:.2f} in estimated spend (${m:.2f}); stopping per cap."
|
| 147 |
+
)
|
| 148 |
+
if self.usd > self.budget:
|
| 149 |
+
raise RuntimeError(
|
| 150 |
+
f"Total estimated API spend ${self.usd:.2f} exceeded budget ${self.budget:.2f}."
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _estimate_openai_style_cost(
|
| 155 |
+
model: str, prompt_tokens: int, completion_tokens: int
|
| 156 |
+
) -> float:
|
| 157 |
+
# Conservative blended rate per 1K tokens (USD) — for guardrails only
|
| 158 |
+
if "gemini" in model:
|
| 159 |
+
p, c = 0.00125, 0.01
|
| 160 |
+
elif "gpt" in model.lower() or "5.4" in model:
|
| 161 |
+
p, c = 0.0025, 0.01
|
| 162 |
+
else:
|
| 163 |
+
p, c = 0.001, 0.006
|
| 164 |
+
return (prompt_tokens * p + completion_tokens * c) / 1000.0
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
# Pydantic for Gemini structured action JSON
|
| 169 |
+
# ---------------------------------------------------------------------------
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class AgentActionOut(BaseModel):
|
| 173 |
+
action_type: str
|
| 174 |
+
params: Dict[str, Any] = Field(default_factory=dict)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ---------------------------------------------------------------------------
|
| 178 |
+
# Azure deployment self-heal
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _is_not_found(err: str) -> bool:
|
| 183 |
+
s = (err or "").lower()
|
| 184 |
+
return "deploymentnotfound" in s or "deployment" in s and "not found" in s
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def list_azure_openai_deployments() -> List[str]:
|
| 188 |
+
key = os.environ.get("AZURE_API_KEY", "")
|
| 189 |
+
ep = (os.environ.get("AZURE_OPENAI_ENDPOINT", "") or "").rstrip("/")
|
| 190 |
+
ver = os.environ.get("AZURE_API_VERSION", "2024-12-01-preview")
|
| 191 |
+
if not key or not ep:
|
| 192 |
+
return []
|
| 193 |
+
url = f"{ep}/openai/deployments?api-version={ver}"
|
| 194 |
+
try:
|
| 195 |
+
r = httpx.get(url, headers={"api-key": key}, timeout=30.0)
|
| 196 |
+
r.raise_for_status()
|
| 197 |
+
data = r.json()
|
| 198 |
+
return [d.get("id", "") for d in data.get("value", []) if d.get("id")]
|
| 199 |
+
except Exception:
|
| 200 |
+
return []
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def list_foundry_deployments() -> List[str]:
|
| 204 |
+
"""
|
| 205 |
+
Best-effort: project endpoint may expose deployments; schema varies.
|
| 206 |
+
"""
|
| 207 |
+
fe = (os.environ.get("AZURE_FOUNDRY_PROJECT_ENDPOINT", "") or "").rstrip("/")
|
| 208 |
+
key = os.environ.get("AZURE_API_KEY", "")
|
| 209 |
+
if not fe or not key:
|
| 210 |
+
return []
|
| 211 |
+
for suffix in ("/deployments", "/openai/models"):
|
| 212 |
+
try:
|
| 213 |
+
url = f"{fe}{suffix}"
|
| 214 |
+
r = httpx.get(
|
| 215 |
+
url, headers={"api-key": key}, params={"api-version": "2024-12-01-preview"}, timeout=30.0
|
| 216 |
+
)
|
| 217 |
+
if r.status_code != 200:
|
| 218 |
+
continue
|
| 219 |
+
data = r.json()
|
| 220 |
+
if isinstance(data, list):
|
| 221 |
+
return [str(x.get("id", x)) for x in data if isinstance(x, dict)]
|
| 222 |
+
if "value" in data:
|
| 223 |
+
return [d.get("id", "") for d in data.get("value", []) if d.get("id")]
|
| 224 |
+
except Exception:
|
| 225 |
+
continue
|
| 226 |
+
return []
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def pick_closest(name: str, options: List[str]) -> str:
|
| 230 |
+
if not options:
|
| 231 |
+
return name
|
| 232 |
+
if name in options:
|
| 233 |
+
return name
|
| 234 |
+
ranked = difflib.get_close_matches(name, options, n=1, cutoff=0.2)
|
| 235 |
+
if ranked:
|
| 236 |
+
return ranked[0]
|
| 237 |
+
return options[0]
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ---------------------------------------------------------------------------
|
| 241 |
+
# LLM backends
|
| 242 |
+
# ---------------------------------------------------------------------------
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class LLMClient:
|
| 246 |
+
def __init__(self, model: str) -> None:
|
| 247 |
+
self.model = model
|
| 248 |
+
self.gemini_client: Any = None
|
| 249 |
+
self.azure_openai: Any = None
|
| 250 |
+
self.azure_inf: Any = None
|
| 251 |
+
if model == MODEL_GEMINI:
|
| 252 |
+
from google import genai
|
| 253 |
+
|
| 254 |
+
key = os.environ.get("GEMINI_API_KEY", "")
|
| 255 |
+
if not key:
|
| 256 |
+
raise ValueError("GEMINI_API_KEY missing for Gemini collection.")
|
| 257 |
+
self.gemini_client = genai.Client(api_key=key)
|
| 258 |
+
elif model == MODEL_GPT:
|
| 259 |
+
if not all(
|
| 260 |
+
os.environ.get(x)
|
| 261 |
+
for x in (
|
| 262 |
+
"AZURE_API_KEY",
|
| 263 |
+
"AZURE_OPENAI_ENDPOINT",
|
| 264 |
+
"AZURE_API_VERSION",
|
| 265 |
+
)
|
| 266 |
+
):
|
| 267 |
+
raise ValueError("AZURE_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_API_VERSION required for gpt-5.4-pro.")
|
| 268 |
+
self.azure_openai = AzureOpenAI(
|
| 269 |
+
api_key=os.environ["AZURE_API_KEY"],
|
| 270 |
+
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
|
| 271 |
+
api_version=os.environ["AZURE_API_VERSION"],
|
| 272 |
+
)
|
| 273 |
+
else:
|
| 274 |
+
if not all(os.environ.get(x) for x in ("AZURE_API_KEY", "AZURE_AI_INFERENCE_ENDPOINT")):
|
| 275 |
+
raise ValueError("AZURE_API_KEY and AZURE_AI_INFERENCE_ENDPOINT required for inference models.")
|
| 276 |
+
from azure.ai.inference import ChatCompletionsClient
|
| 277 |
+
from azure.core.credentials import AzureKeyCredential
|
| 278 |
+
|
| 279 |
+
self.azure_inf = ChatCompletionsClient(
|
| 280 |
+
endpoint=os.environ["AZURE_AI_INFERENCE_ENDPOINT"],
|
| 281 |
+
credential=AzureKeyCredential(os.environ["AZURE_API_KEY"]),
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
def _deployment_name(self) -> str:
|
| 285 |
+
m = {MODEL_GPT: "AZURE_MODEL_GPT", MODEL_GROK: "AZURE_MODEL_GROK", MODEL_KIMI: "AZURE_MODEL_KIMI", MODEL_DEEPSEEK: "AZURE_MODEL_DEEPSEEK"}.get(self.model)
|
| 286 |
+
if m:
|
| 287 |
+
v = os.environ.get(m, "").strip()
|
| 288 |
+
if v:
|
| 289 |
+
return v
|
| 290 |
+
return self.model
|
| 291 |
+
|
| 292 |
+
def call(
|
| 293 |
+
self,
|
| 294 |
+
messages: List[Dict[str, str]],
|
| 295 |
+
) -> Tuple[str, int, int]:
|
| 296 |
+
"""Return (raw_text, prompt_tokens, completion_tokens)."""
|
| 297 |
+
p_tok, c_tok = 0, 0
|
| 298 |
+
if self.gemini_client is not None:
|
| 299 |
+
return self._call_gemini(messages, p_tok, c_tok)
|
| 300 |
+
if self.azure_openai is not None:
|
| 301 |
+
return self._call_azure_openai(messages, p_tok, c_tok)
|
| 302 |
+
if self.azure_inf is not None:
|
| 303 |
+
return self._call_azure_inference(messages, p_tok, c_tok)
|
| 304 |
+
raise RuntimeError("No backend initialised")
|
| 305 |
+
|
| 306 |
+
def _call_gemini(
|
| 307 |
+
self, messages: List[Dict[str, str]], p0: int, c0: int
|
| 308 |
+
) -> Tuple[str, int, int]:
|
| 309 |
+
from google.genai import types
|
| 310 |
+
|
| 311 |
+
if not messages:
|
| 312 |
+
return '{"action_type": "noop", "params": {}}', 0, 0
|
| 313 |
+
system = messages[0]["content"] if messages[0]["role"] == "system" else _BASE_SYSTEM
|
| 314 |
+
rest = messages[1:] if messages[0]["role"] == "system" else messages
|
| 315 |
+
name = os.environ.get("GEMINI_MODEL_PRO", MODEL_GEMINI)
|
| 316 |
+
config = types.GenerateContentConfig(
|
| 317 |
+
system_instruction=system,
|
| 318 |
+
response_mime_type="application/json",
|
| 319 |
+
response_json_schema=AgentActionOut,
|
| 320 |
+
temperature=0.0,
|
| 321 |
+
max_output_tokens=512,
|
| 322 |
+
)
|
| 323 |
+
# Build contents: alternating user / model for few-shot tail
|
| 324 |
+
contents: List[Any] = []
|
| 325 |
+
for m in rest:
|
| 326 |
+
if m["role"] == "user":
|
| 327 |
+
contents.append(
|
| 328 |
+
types.Content(role="user", parts=[types.Part.from_text(text=m["content"])])
|
| 329 |
+
)
|
| 330 |
+
else:
|
| 331 |
+
contents.append(
|
| 332 |
+
types.Content(
|
| 333 |
+
role="model",
|
| 334 |
+
parts=[types.Part.from_text(text=m["content"])],
|
| 335 |
+
)
|
| 336 |
+
)
|
| 337 |
+
for attempt in range(3):
|
| 338 |
+
try:
|
| 339 |
+
resp = self.gemini_client.models.generate_content(
|
| 340 |
+
model=name, contents=contents, config=config
|
| 341 |
+
)
|
| 342 |
+
text = (resp.text or "").strip() if hasattr(resp, "text") else ""
|
| 343 |
+
u = getattr(resp, "usage_metadata", None) or getattr(resp, "usage", None)
|
| 344 |
+
pt = int(getattr(u, "prompt_token_count", None) or getattr(u, "prompt_tokens", 0) or 0) if u else 0
|
| 345 |
+
ct = int(getattr(u, "candidates_token_count", None) or getattr(u, "completion_tokens", 0) or 0) if u else 0
|
| 346 |
+
if not text and hasattr(resp, "candidates") and resp.candidates:
|
| 347 |
+
p0x = resp.candidates[0].content.parts[0] if resp.candidates[0].content.parts else None
|
| 348 |
+
text = getattr(p0x, "text", "") or ""
|
| 349 |
+
return text, pt, ct
|
| 350 |
+
except Exception:
|
| 351 |
+
if attempt < 2:
|
| 352 |
+
time.sleep(1.0 + attempt)
|
| 353 |
+
else:
|
| 354 |
+
return '{"action_type": "noop", "params": {}}', p0, c0
|
| 355 |
+
|
| 356 |
+
def _call_azure_openai(
|
| 357 |
+
self, messages: List[Dict[str, str]], p0: int, c0: int
|
| 358 |
+
) -> Tuple[str, int, int]:
|
| 359 |
+
dep = self._deployment_name()
|
| 360 |
+
for attempt in range(3):
|
| 361 |
+
try:
|
| 362 |
+
comp = self.azure_openai.chat.completions.create(
|
| 363 |
+
model=dep,
|
| 364 |
+
messages=messages, # type: ignore[arg-type]
|
| 365 |
+
temperature=0.0,
|
| 366 |
+
max_tokens=512,
|
| 367 |
+
timeout=90.0,
|
| 368 |
+
)
|
| 369 |
+
text = (comp.choices[0].message.content or "").strip()
|
| 370 |
+
u = comp.usage
|
| 371 |
+
pt = u.prompt_tokens if u else 0
|
| 372 |
+
ct = u.completion_tokens if u else 0
|
| 373 |
+
return text, pt, ct
|
| 374 |
+
except Exception as e:
|
| 375 |
+
err = str(e)
|
| 376 |
+
if _is_not_found(err):
|
| 377 |
+
names = list_azure_openai_deployments()
|
| 378 |
+
if names:
|
| 379 |
+
dep = pick_closest(dep, names)
|
| 380 |
+
if attempt == 2:
|
| 381 |
+
return '{"action_type": "noop", "params": {}}', p0, c0
|
| 382 |
+
time.sleep(1.0 + attempt)
|
| 383 |
+
return '{"action_type": "noop", "params": {}}', p0, c0
|
| 384 |
+
|
| 385 |
+
def _call_azure_inference(
|
| 386 |
+
self, messages: List[Dict[str, str]], p0: int, c0: int
|
| 387 |
+
) -> Tuple[str, int, int]:
|
| 388 |
+
dep = self._deployment_name()
|
| 389 |
+
for attempt in range(3):
|
| 390 |
+
try:
|
| 391 |
+
resp = self.azure_inf.complete(
|
| 392 |
+
model=dep,
|
| 393 |
+
messages=messages, # type: ignore[arg-type]
|
| 394 |
+
temperature=0.0,
|
| 395 |
+
max_tokens=512,
|
| 396 |
+
)
|
| 397 |
+
ch = resp.choices[0].message
|
| 398 |
+
text = (ch.content or "").strip() if ch else ""
|
| 399 |
+
u = getattr(resp, "usage", None)
|
| 400 |
+
pt = int(getattr(u, "prompt_tokens", 0) or 0) if u else 0
|
| 401 |
+
ct = int(getattr(u, "completion_tokens", 0) or 0) if u else 0
|
| 402 |
+
return text, pt, ct
|
| 403 |
+
except Exception as e:
|
| 404 |
+
err = str(e)
|
| 405 |
+
if _is_not_found(err) or "404" in err or "not found" in err.lower():
|
| 406 |
+
names = [n for n in list_foundry_deployments() + list_azure_openai_deployments() if n]
|
| 407 |
+
if names:
|
| 408 |
+
dep = pick_closest(dep, names)
|
| 409 |
+
if attempt == 2:
|
| 410 |
+
return '{"action_type": "noop", "params": {}}', p0, c0
|
| 411 |
+
time.sleep(1.0 + attempt)
|
| 412 |
+
return '{"action_type": "noop", "params": {}}', p0, c0
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
# ---------------------------------------------------------------------------
|
| 416 |
+
# Episode (mirrors inference.run_episode; logs full trace)
|
| 417 |
+
# ---------------------------------------------------------------------------
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def _memory_block(tried_actions: Dict[str, List[str]], resolved_services: List[str]) -> str:
|
| 421 |
+
if not tried_actions and not resolved_services:
|
| 422 |
+
return ""
|
| 423 |
+
lines = ["## Episode Memory (do not repeat failed approaches)"]
|
| 424 |
+
if resolved_services:
|
| 425 |
+
lines.append(f" Resolved: {', '.join(resolved_services)}")
|
| 426 |
+
for act, targets in tried_actions.items():
|
| 427 |
+
lines.append(f" {act}: {'; '.join(targets)}")
|
| 428 |
+
return "\n".join(lines)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def run_one_episode(
|
| 432 |
+
llm: LLMClient,
|
| 433 |
+
model_id: str,
|
| 434 |
+
base: str,
|
| 435 |
+
task_id: str,
|
| 436 |
+
seed: int,
|
| 437 |
+
cost: CostTracker,
|
| 438 |
+
) -> Dict[str, Any]:
|
| 439 |
+
grade: Dict[str, Any] = {}
|
| 440 |
+
with httpx.Client(timeout=60.0) as http:
|
| 441 |
+
r = http.post(
|
| 442 |
+
f"{base}/reset", json={"seed": seed, "task_id": task_id}
|
| 443 |
+
)
|
| 444 |
+
r.raise_for_status()
|
| 445 |
+
resp_data = r.json()
|
| 446 |
+
obs: Dict[str, Any] = dict(resp_data.get("observation", resp_data))
|
| 447 |
+
max_steps = int(obs.get("max_steps", 10))
|
| 448 |
+
done = bool(resp_data.get("done", False))
|
| 449 |
+
conv: List[Dict[str, Any]] = []
|
| 450 |
+
tried: Dict[str, List[str]] = {}
|
| 451 |
+
resolved: List[str] = []
|
| 452 |
+
steps_out: List[Dict[str, Any]] = []
|
| 453 |
+
for step_num in range(1, max_steps + 1):
|
| 454 |
+
if done:
|
| 455 |
+
break
|
| 456 |
+
obs_pre = copy.deepcopy(obs)
|
| 457 |
+
user_msg = build_observation_prompt(obs_pre)
|
| 458 |
+
conv.append({"role": "user", "content": user_msg})
|
| 459 |
+
trimmed = conv[-6:]
|
| 460 |
+
memory = _memory_block(tried, resolved)
|
| 461 |
+
system_content = _BASE_SYSTEM + ("\n\n" + memory if memory else "")
|
| 462 |
+
messages: List[Dict[str, str]] = (
|
| 463 |
+
[{"role": "system", "content": system_content}] + trimmed
|
| 464 |
+
)
|
| 465 |
+
raw, pt, ct = llm.call(messages)
|
| 466 |
+
cost.add(
|
| 467 |
+
model_id, _estimate_openai_style_cost(model_id, pt, ct)
|
| 468 |
+
)
|
| 469 |
+
try:
|
| 470 |
+
action = parse_action(raw)
|
| 471 |
+
except Exception:
|
| 472 |
+
action = {"action_type": "noop", "params": {}}
|
| 473 |
+
if isinstance(action, dict) and "action_type" in action and model_id == MODEL_GEMINI:
|
| 474 |
+
try:
|
| 475 |
+
a2 = (
|
| 476 |
+
json.loads(raw[raw.find("{") : raw.rfind("}") + 1])
|
| 477 |
+
if "{" in raw
|
| 478 |
+
else None
|
| 479 |
+
)
|
| 480 |
+
if a2 and isinstance(a2, dict) and "action_type" in a2:
|
| 481 |
+
action = a2
|
| 482 |
+
except Exception:
|
| 483 |
+
pass
|
| 484 |
+
act_params = action.get("params", {}) or {}
|
| 485 |
+
if "replicas" in act_params:
|
| 486 |
+
try:
|
| 487 |
+
act_params["replicas"] = int(act_params["replicas"])
|
| 488 |
+
except (ValueError, TypeError):
|
| 489 |
+
act_params["replicas"] = 2
|
| 490 |
+
act_type = action.get("action_type", "noop")
|
| 491 |
+
target = act_params.get("service_id") or act_params.get("cache_name") or act_params.get("from_region") or ""
|
| 492 |
+
step_resp = http.post(
|
| 493 |
+
f"{base}/step",
|
| 494 |
+
json={"action": {"action_type": act_type, "params": act_params}},
|
| 495 |
+
)
|
| 496 |
+
sdata = step_resp.json() if step_resp.status_code == 200 else {}
|
| 497 |
+
obs = dict(sdata.get("observation", sdata))
|
| 498 |
+
done = bool(sdata.get("done", False))
|
| 499 |
+
reward = float(
|
| 500 |
+
obs.get("reward", sdata.get("reward", 0.0)) or 0.0
|
| 501 |
+
)
|
| 502 |
+
conv.append({"role": "assistant", "content": raw})
|
| 503 |
+
if act_type not in (
|
| 504 |
+
"inspect_logs",
|
| 505 |
+
"inspect_metrics",
|
| 506 |
+
"inspect_traces",
|
| 507 |
+
"noop",
|
| 508 |
+
) and target:
|
| 509 |
+
new_slo = obs.get("global_slo_score", 0.0)
|
| 510 |
+
for svc in obs.get("services", []):
|
| 511 |
+
if svc.get("id") == target and svc.get("status") == "healthy":
|
| 512 |
+
if target not in resolved:
|
| 513 |
+
resolved.append(target)
|
| 514 |
+
entry = f"{target} (slo={new_slo:.0%})"
|
| 515 |
+
tried.setdefault(str(act_type), [])
|
| 516 |
+
if entry not in tried[str(act_type)]:
|
| 517 |
+
tried[str(act_type)].append(entry)
|
| 518 |
+
obs_ser = json.loads(
|
| 519 |
+
json.dumps(
|
| 520 |
+
{k: v for k, v in obs_pre.items() if k != "reward"},
|
| 521 |
+
default=str,
|
| 522 |
+
)
|
| 523 |
+
)
|
| 524 |
+
steps_out.append(
|
| 525 |
+
{
|
| 526 |
+
"step": step_num,
|
| 527 |
+
"observation": obs_ser,
|
| 528 |
+
"prompt": user_msg,
|
| 529 |
+
"messages": messages,
|
| 530 |
+
"completion": raw,
|
| 531 |
+
"action": action,
|
| 532 |
+
"reward": reward,
|
| 533 |
+
"info": {k: v for k, v in sdata.items() if k not in ("observation",)},
|
| 534 |
+
}
|
| 535 |
+
)
|
| 536 |
+
try:
|
| 537 |
+
final_state = http.get(f"{base}/state").json()
|
| 538 |
+
except Exception:
|
| 539 |
+
final_state = {}
|
| 540 |
+
try:
|
| 541 |
+
grade = http.post(
|
| 542 |
+
f"{base}/grader",
|
| 543 |
+
json={
|
| 544 |
+
"final_slo_score": final_state.get("global_slo_score", 0.0),
|
| 545 |
+
"steps_taken": final_state.get("step_count", 0),
|
| 546 |
+
"max_steps": max_steps,
|
| 547 |
+
"actions_taken": obs.get("actions_taken", []),
|
| 548 |
+
"terminated": final_state.get("terminated", True),
|
| 549 |
+
"termination_reason": final_state.get("termination_reason"),
|
| 550 |
+
},
|
| 551 |
+
).json()
|
| 552 |
+
except Exception:
|
| 553 |
+
grade = {}
|
| 554 |
+
score = float(grade.get("score", 0.0) or 0.0)
|
| 555 |
+
return {
|
| 556 |
+
"model": model_id,
|
| 557 |
+
"task_id": task_id,
|
| 558 |
+
"seed": seed,
|
| 559 |
+
"steps": steps_out,
|
| 560 |
+
"grader": grade,
|
| 561 |
+
"final_score": score,
|
| 562 |
+
"max_steps": max_steps,
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
# ---------------------------------------------------------------------------
|
| 567 |
+
# Main
|
| 568 |
+
# ---------------------------------------------------------------------------
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def _raw_path(model: str) -> Path:
|
| 572 |
+
safe = re.sub(r"[^a-zA-Z0-9._-]+", "_", model)
|
| 573 |
+
d = REPO_ROOT / "training" / "data" / "raw"
|
| 574 |
+
d.mkdir(parents=True, exist_ok=True)
|
| 575 |
+
return d / f"{safe}.jsonl"
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def _wait_health(base: str, timeout: float = 45.0) -> None:
|
| 579 |
+
t0 = time.time()
|
| 580 |
+
while time.time() - t0 < timeout:
|
| 581 |
+
try:
|
| 582 |
+
r = httpx.get(f"{base}/health", timeout=3.0)
|
| 583 |
+
if r.status_code == 200:
|
| 584 |
+
return
|
| 585 |
+
except Exception:
|
| 586 |
+
pass
|
| 587 |
+
time.sleep(1.0)
|
| 588 |
+
print(f"[collect] health check timeout for {base} — continuing", flush=True)
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def start_server(port: int) -> subprocess.Popen:
|
| 592 |
+
env = os.environ.copy()
|
| 593 |
+
pp = str(REPO_ROOT)
|
| 594 |
+
env["PYTHONPATH"] = pp if not env.get("PYTHONPATH") else pp + os.pathsep + env["PYTHONPATH"]
|
| 595 |
+
return subprocess.Popen(
|
| 596 |
+
[sys.executable, "-m", "uvicorn", "server.app:app", "--host", "127.0.0.1", "--port", str(port)],
|
| 597 |
+
cwd=REPO_ROOT,
|
| 598 |
+
env=env,
|
| 599 |
+
stdout=subprocess.DEVNULL,
|
| 600 |
+
stderr=subprocess.STDOUT,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def parse_models(s: str) -> List[str]:
|
| 605 |
+
return [m.strip() for m in s.split(",") if m.strip()]
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def _plan_for_model(
|
| 609 |
+
model: str, c_e: int, c_m: int, c_h: int
|
| 610 |
+
) -> List[Tuple[str, str, int]]:
|
| 611 |
+
p = {
|
| 612 |
+
MODEL_GEMINI: plan_gemini,
|
| 613 |
+
MODEL_GPT: plan_gpt,
|
| 614 |
+
MODEL_GROK: plan_grok,
|
| 615 |
+
MODEL_KIMI: plan_kimi,
|
| 616 |
+
MODEL_DEEPSEEK: plan_deepseek,
|
| 617 |
+
}
|
| 618 |
+
fn = p.get(model)
|
| 619 |
+
if not fn:
|
| 620 |
+
return []
|
| 621 |
+
return fn(c_e, c_m, c_h)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def sanity_runs() -> List[Tuple[str, str, int]]:
|
| 625 |
+
return [
|
| 626 |
+
(MODEL_GEMINI, "easy", 42),
|
| 627 |
+
(MODEL_GPT, "easy", 42),
|
| 628 |
+
(MODEL_GROK, "easy", 13),
|
| 629 |
+
]
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def main() -> None:
|
| 633 |
+
ap = argparse.ArgumentParser()
|
| 634 |
+
ap.add_argument(
|
| 635 |
+
"--models",
|
| 636 |
+
type=str,
|
| 637 |
+
default=",".join(sorted(ALL_CANON)),
|
| 638 |
+
help="Comma-separated model ids (default: all)",
|
| 639 |
+
)
|
| 640 |
+
ap.add_argument("--port", type=int, default=7860)
|
| 641 |
+
ap.add_argument("--no-start-server", action="store_true")
|
| 642 |
+
ap.add_argument("--sanity-only", action="store_true", help="Run only 3 smoke episodes (gemini, gpt, grok easy).")
|
| 643 |
+
ap.add_argument("--no-sanity", action="store_true", help="Skip pre-flight sanity runs.")
|
| 644 |
+
ap.add_argument(
|
| 645 |
+
"--budget-usd",
|
| 646 |
+
type=float,
|
| 647 |
+
default=5.0,
|
| 648 |
+
help="Total estimated-spend cap (heuristic) across all models.",
|
| 649 |
+
)
|
| 650 |
+
ap.add_argument(
|
| 651 |
+
"--per-model-budget-usd",
|
| 652 |
+
type=float,
|
| 653 |
+
default=0.0,
|
| 654 |
+
help="Per-model cap (0 = auto: max(2, budget/num selected models)).",
|
| 655 |
+
)
|
| 656 |
+
ap.add_argument(
|
| 657 |
+
"--episodes-easy",
|
| 658 |
+
type=int,
|
| 659 |
+
default=15,
|
| 660 |
+
help="Number of easy-task episodes per model (default 15, Wave 1.5).",
|
| 661 |
+
)
|
| 662 |
+
ap.add_argument(
|
| 663 |
+
"--episodes-medium",
|
| 664 |
+
type=int,
|
| 665 |
+
default=15,
|
| 666 |
+
help="Number of medium-task episodes per model (default 15).",
|
| 667 |
+
)
|
| 668 |
+
ap.add_argument(
|
| 669 |
+
"--episodes-hard",
|
| 670 |
+
type=int,
|
| 671 |
+
default=20,
|
| 672 |
+
help="Number of hard-task episodes per model (default 20).",
|
| 673 |
+
)
|
| 674 |
+
args = ap.parse_args()
|
| 675 |
+
want = set(parse_models(args.models))
|
| 676 |
+
bad = want - ALL_CANON
|
| 677 |
+
if bad:
|
| 678 |
+
raise SystemExit(f"Unknown model(s): {bad}. Valid: {sorted(ALL_CANON)}")
|
| 679 |
+
|
| 680 |
+
c_e, c_m, c_h = args.episodes_easy, args.episodes_medium, args.episodes_hard
|
| 681 |
+
if min(c_e, c_m, c_h) < 0:
|
| 682 |
+
raise SystemExit("--episodes-* must be non-negative.")
|
| 683 |
+
if c_e + c_m + c_h == 0:
|
| 684 |
+
raise SystemExit("At least one of --episodes-easy/medium/hard must be > 0.")
|
| 685 |
+
|
| 686 |
+
_ = full_plan(c_e, c_m, c_h) # exercise planner (raises if misconfigured)
|
| 687 |
+
|
| 688 |
+
# Required keys
|
| 689 |
+
for m in want:
|
| 690 |
+
if m == MODEL_GEMINI and not os.environ.get("GEMINI_API_KEY"):
|
| 691 |
+
raise SystemExit("GEMINI_API_KEY missing (needed for gemini-3.1-pro-preview).")
|
| 692 |
+
if m == MODEL_GPT and not all(
|
| 693 |
+
os.environ.get(x) for x in ("AZURE_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_API_VERSION")
|
| 694 |
+
):
|
| 695 |
+
raise SystemExit("Azure OpenAI env vars missing for gpt-5.4-pro.")
|
| 696 |
+
if m in (MODEL_GROK, MODEL_KIMI, MODEL_DEEPSEEK) and not all(
|
| 697 |
+
os.environ.get(x) for x in ("AZURE_API_KEY", "AZURE_AI_INFERENCE_ENDPOINT")
|
| 698 |
+
):
|
| 699 |
+
raise SystemExit("Azure inference env missing for " + m)
|
| 700 |
+
|
| 701 |
+
proc: Optional[subprocess.Popen] = None
|
| 702 |
+
if not args.no_start_server:
|
| 703 |
+
proc = start_server(args.port)
|
| 704 |
+
base = f"http://127.0.0.1:{args.port}"
|
| 705 |
+
_wait_health(base)
|
| 706 |
+
n_m = max(1, len(want))
|
| 707 |
+
per_cap = args.per_model_budget_usd
|
| 708 |
+
if per_cap <= 0.0:
|
| 709 |
+
per_cap = max(2.0, args.budget_usd / n_m)
|
| 710 |
+
cost = CostTracker(budget=args.budget_usd, per_model_max=per_cap)
|
| 711 |
+
# LLM clients (lazy)
|
| 712 |
+
_clients: Dict[str, LLMClient] = {}
|
| 713 |
+
def get_llm(mid: str) -> LLMClient:
|
| 714 |
+
if mid not in _clients:
|
| 715 |
+
_clients[mid] = LLMClient(mid)
|
| 716 |
+
return _clients[mid]
|
| 717 |
+
|
| 718 |
+
try:
|
| 719 |
+
already: Set[Tuple[str, str, int]] = set()
|
| 720 |
+
if args.sanity_only:
|
| 721 |
+
final_list = [r for r in sanity_runs() if r[0] in want]
|
| 722 |
+
else:
|
| 723 |
+
if not args.no_sanity:
|
| 724 |
+
for mid, task_id, seed in (r for r in sanity_runs() if r[0] in want):
|
| 725 |
+
print(f"[sanity] {mid} {task_id} seed={seed}", flush=True)
|
| 726 |
+
llm = get_llm(mid)
|
| 727 |
+
_ = run_one_episode(llm, mid, base, task_id, seed, cost)
|
| 728 |
+
already.add((mid, task_id, seed))
|
| 729 |
+
print("[sanity] pre-flight ok", flush=True)
|
| 730 |
+
final_list = []
|
| 731 |
+
for m in want:
|
| 732 |
+
for x in _plan_for_model(m, c_e, c_m, c_h):
|
| 733 |
+
if x in already:
|
| 734 |
+
continue
|
| 735 |
+
final_list.append(x)
|
| 736 |
+
n_done = 0
|
| 737 |
+
for mid, task_id, seed in final_list:
|
| 738 |
+
print(f"[episode] {mid} {task_id} seed={seed}", flush=True)
|
| 739 |
+
try:
|
| 740 |
+
llm = get_llm(mid)
|
| 741 |
+
ep = run_one_episode(llm, mid, base, task_id, seed, cost)
|
| 742 |
+
except RuntimeError as e:
|
| 743 |
+
print(f"[collect] Stopped: {e}", flush=True)
|
| 744 |
+
break
|
| 745 |
+
p = _raw_path(mid)
|
| 746 |
+
with p.open("a", encoding="utf-8") as f:
|
| 747 |
+
f.write(json.dumps(ep, ensure_ascii=False) + "\n")
|
| 748 |
+
n_done += 1
|
| 749 |
+
print(
|
| 750 |
+
f" -> score={ep.get('final_score', 0):.4f} lines->{p.name} (total est ${cost.usd:.2f})",
|
| 751 |
+
flush=True,
|
| 752 |
+
)
|
| 753 |
+
print(f"Done. Episodes written: {n_done}. Estimated spend: ${cost.usd:.2f}", flush=True)
|
| 754 |
+
finally:
|
| 755 |
+
if proc is not None:
|
| 756 |
+
proc.terminate()
|
| 757 |
+
try:
|
| 758 |
+
proc.wait(timeout=5)
|
| 759 |
+
except Exception:
|
| 760 |
+
proc.kill()
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
if __name__ == "__main__":
|
| 764 |
+
main()
|
training/config_utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load dotenv from repo api.env + hg.env (optional). Does not read secrets into logs."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
_REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def try_load_env_files() -> None:
|
| 12 |
+
for name in ("api.env", "hg.env"):
|
| 13 |
+
p = _REPO_ROOT / name
|
| 14 |
+
if not p.is_file():
|
| 15 |
+
continue
|
| 16 |
+
try:
|
| 17 |
+
from dotenv import load_dotenv
|
| 18 |
+
|
| 19 |
+
load_dotenv(p, override=False)
|
| 20 |
+
except ImportError:
|
| 21 |
+
_manual_load(p)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _manual_load(path: Path) -> None:
|
| 25 |
+
for line in path.read_text(encoding="utf-8", errors="ignore").splitlines():
|
| 26 |
+
line = line.strip()
|
| 27 |
+
if not line or line.startswith("#") or "=" not in line:
|
| 28 |
+
continue
|
| 29 |
+
k, v = line.split("=", 1)
|
| 30 |
+
k, v = k.strip(), v.strip().strip('"').strip("'")
|
| 31 |
+
if k and k not in os.environ:
|
| 32 |
+
os.environ[k] = v
|
training/data/DATASET_README_HF.md
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SevZero expert trajectories (SFT)
|
| 2 |
+
|
| 3 |
+
## Sources
|
| 4 |
+
|
| 5 |
+
- Synthetic expert rollouts from frontier models (Gemini 3.1 Pro, Azure OpenAI, Azure AI Inference)
|
| 6 |
+
against the local OpenEnv `server.app` SevZero environment.
|
| 7 |
+
|
| 8 |
+
## Filtering
|
| 9 |
+
|
| 10 |
+
- Episodes with final grader `score` **≥** `0.75` are included.
|
| 11 |
+
|
| 12 |
+
## Schema
|
| 13 |
+
|
| 14 |
+
- Each example has a `messages` list (Llama-3.1-8B-Instruct–style SFT) and `meta` (episode / step provenance):
|
| 15 |
+
- `system`: SRE on-call system prompt (same as `inference.SYSTEM_PROMPT` in the repo)
|
| 16 |
+
- `user`: JSON-serialized observation (shrink to ≤ 2048 tokens for the user part)
|
| 17 |
+
- `assistant`: one JSON object `{"action_type": "...", "params": {...}}`
|
| 18 |
+
|
| 19 |
+
## Stats (from `build_stats.json` at publish time)
|
| 20 |
+
|
| 21 |
+
{
|
| 22 |
+
"episodes_total_seen": 90,
|
| 23 |
+
"episodes_kept": 42,
|
| 24 |
+
"episodes_dropped": 48,
|
| 25 |
+
"mean_episode_score_kept": 0.836021,
|
| 26 |
+
"train_rows": 853,
|
| 27 |
+
"eval_rows": 80,
|
| 28 |
+
"max_prompt_token_length": 2,
|
| 29 |
+
"max_observation_user_token_budget": 2048,
|
| 30 |
+
"min_score_filter": 0.75
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
## Parquet
|
| 34 |
+
|
| 35 |
+
- Splits `train` and `eval` are also pushed in Parquet for fast `datasets.load_dataset`.
|
training/data/HANDOFF.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- **Dataset URL (after `python -m training.push_dataset`):** https://huggingface.co/datasets/Mist-ic/sevzero-expert-trajectories
|
| 2 |
+
- **Rows:** see `build_stats.json` for `train_rows` and `eval_rows` after you run `build_dataset.py` on real raw JSONL.
|
| 3 |
+
- **Max prompt tokens:** see `max_prompt_token_length` in `build_stats.json` — set SFT/GRPO `max_seq_length` to this + `max_completion_length` (e.g. +1024).
|
| 4 |
+
- **Mean episode score:** `mean_episode_score_kept` in `build_stats.json` (episodes with final grader ≥ 0.85).
|
| 5 |
+
- **Caveats:** run `collect_trajectories.py` with working `api.env`/`hg.env`; use `--no-sanity` to skip the 3 pre-flight API calls; install extras (`python-dotenv`, `google-genai`, `azure-ai-inference`, `huggingface_hub`, `datasets`, `transformers`, `pydantic`) as needed — `pyproject.toml` is unchanged.
|
training/data/build_stats.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"episodes_total_seen": 90,
|
| 3 |
+
"episodes_kept": 42,
|
| 4 |
+
"episodes_dropped": 48,
|
| 5 |
+
"mean_episode_score_kept": 0.836021,
|
| 6 |
+
"train_rows": 853,
|
| 7 |
+
"eval_rows": 80,
|
| 8 |
+
"max_prompt_token_length": 2,
|
| 9 |
+
"max_observation_user_token_budget": 2048,
|
| 10 |
+
"min_score_filter": 0.75
|
| 11 |
+
}
|
training/data/dataset_info.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"description": "SevZero SFT expert trajectories for Llama-3.1-8B-Instruct style chat training.",
|
| 3 |
+
"version": "1.0.0",
|
| 4 |
+
"license": "apache-2.0",
|
| 5 |
+
"build": {
|
| 6 |
+
"episodes_total_seen": 90,
|
| 7 |
+
"episodes_kept": 42,
|
| 8 |
+
"episodes_dropped": 48,
|
| 9 |
+
"mean_episode_score_kept": 0.836021,
|
| 10 |
+
"train_rows": 853,
|
| 11 |
+
"eval_rows": 80,
|
| 12 |
+
"max_prompt_token_length": 2,
|
| 13 |
+
"max_observation_user_token_budget": 2048,
|
| 14 |
+
"min_score_filter": 0.75
|
| 15 |
+
}
|
| 16 |
+
}
|
training/data/sft_eval.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training/data/sft_train.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training/env_client.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Async HTTP client for the SevZero OpenEnv server (stateful /reset, /step, /state, /grader).
|
| 3 |
+
Used by train_grpo rollout_func. Does not use root client.py (WebSocket); mirrors inference.py HTTP usage.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import os
|
| 10 |
+
from typing import Any, Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
import httpx
|
| 13 |
+
|
| 14 |
+
_DEFAULT_TIMEOUT = 120.0
|
| 15 |
+
_MAX_RETRIES = 5
|
| 16 |
+
_BACKOFF = 1.6
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _space_id_to_runtime_url(space_id: str) -> str:
|
| 20 |
+
"""HF Space 'org/name' -> https://org-name.hf.space (common runtime URL)."""
|
| 21 |
+
space_id = space_id.strip()
|
| 22 |
+
if space_id.startswith("http"):
|
| 23 |
+
return space_id.rstrip("/")
|
| 24 |
+
parts = space_id.split("/")
|
| 25 |
+
if len(parts) == 2:
|
| 26 |
+
org, name = parts[0], parts[1]
|
| 27 |
+
# HF uses lowercase, slashes -> dashes in subdomains
|
| 28 |
+
sub = f"{org}-{name}".replace("_", "-").lower()
|
| 29 |
+
return f"https://{sub}.hf.space"
|
| 30 |
+
raise ValueError(f"Invalid space_id (expected 'org/name' or URL): {space_id!r}")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _backoff_delay(attempt: int) -> float:
|
| 34 |
+
return min(30.0, _BACKOFF**attempt)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _is_transient_status(code: int) -> bool:
|
| 38 |
+
return code in (429, 500, 502, 503, 504)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class AsyncSevZeroEnvClient:
|
| 42 |
+
"""
|
| 43 |
+
Minimal async env client: reset / step / state / grader.
|
| 44 |
+
Pass base_url from SEVZERO_ENV_URL or from_hf_space().
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
base_url: str,
|
| 50 |
+
*,
|
| 51 |
+
token: Optional[str] = None,
|
| 52 |
+
timeout: float = _DEFAULT_TIMEOUT,
|
| 53 |
+
) -> None:
|
| 54 |
+
self._base = base_url.rstrip("/")
|
| 55 |
+
self._token = token
|
| 56 |
+
headers: Dict[str, str] = {"Content-Type": "application/json"}
|
| 57 |
+
if token:
|
| 58 |
+
headers["Authorization"] = f"Bearer {token}"
|
| 59 |
+
self._client = httpx.AsyncClient(
|
| 60 |
+
base_url=self._base,
|
| 61 |
+
headers=headers,
|
| 62 |
+
timeout=timeout,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def from_hf_space(
|
| 67 |
+
cls,
|
| 68 |
+
space_id: str,
|
| 69 |
+
token: Optional[str] = None,
|
| 70 |
+
) -> "AsyncSevZeroEnvClient":
|
| 71 |
+
"""
|
| 72 |
+
space_id: 'organization/space_name' (HF Space) or a full http(s) URL.
|
| 73 |
+
For private Spaces, pass a read token with Space access.
|
| 74 |
+
"""
|
| 75 |
+
return cls(_space_id_to_runtime_url(space_id), token=token or os.environ.get("HF_TOKEN"))
|
| 76 |
+
|
| 77 |
+
async def aclose(self) -> None:
|
| 78 |
+
await self._client.aclose()
|
| 79 |
+
|
| 80 |
+
async def _request(
|
| 81 |
+
self,
|
| 82 |
+
method: str,
|
| 83 |
+
path: str,
|
| 84 |
+
*,
|
| 85 |
+
json: Any = None,
|
| 86 |
+
) -> httpx.Response:
|
| 87 |
+
last_err: Optional[Exception] = None
|
| 88 |
+
for attempt in range(_MAX_RETRIES):
|
| 89 |
+
try:
|
| 90 |
+
r = await self._client.request(method, path, json=json)
|
| 91 |
+
if r.status_code < 400:
|
| 92 |
+
return r
|
| 93 |
+
if _is_transient_status(r.status_code) and attempt < _MAX_RETRIES - 1:
|
| 94 |
+
await asyncio.sleep(_backoff_delay(attempt + 1))
|
| 95 |
+
continue
|
| 96 |
+
return r
|
| 97 |
+
except (httpx.TimeoutException, httpx.NetworkError) as e:
|
| 98 |
+
last_err = e
|
| 99 |
+
if attempt < _MAX_RETRIES - 1:
|
| 100 |
+
await asyncio.sleep(_backoff_delay(attempt + 1))
|
| 101 |
+
continue
|
| 102 |
+
raise
|
| 103 |
+
if last_err:
|
| 104 |
+
raise last_err
|
| 105 |
+
raise RuntimeError("request failed")
|
| 106 |
+
|
| 107 |
+
async def reset(
|
| 108 |
+
self,
|
| 109 |
+
*,
|
| 110 |
+
task_id: str = "hard",
|
| 111 |
+
seed: int = 13,
|
| 112 |
+
episode_id: Optional[str] = None,
|
| 113 |
+
) -> Dict[str, Any]:
|
| 114 |
+
body: Dict[str, Any] = {"task_id": task_id, "seed": seed}
|
| 115 |
+
if episode_id:
|
| 116 |
+
body["episode_id"] = episode_id
|
| 117 |
+
r = await self._request("POST", "/reset", json=body)
|
| 118 |
+
r.raise_for_status()
|
| 119 |
+
return r.json()
|
| 120 |
+
|
| 121 |
+
async def step(self, action: Dict[str, Any]) -> Dict[str, Any]:
|
| 122 |
+
r = await self._request("POST", "/step", json={"action": action})
|
| 123 |
+
r.raise_for_status()
|
| 124 |
+
return r.json()
|
| 125 |
+
|
| 126 |
+
async def get_state(self) -> Dict[str, Any]:
|
| 127 |
+
r = await self._request("GET", "/state")
|
| 128 |
+
r.raise_for_status()
|
| 129 |
+
return r.json()
|
| 130 |
+
|
| 131 |
+
async def grade_episode(
|
| 132 |
+
self,
|
| 133 |
+
*,
|
| 134 |
+
final_slo_score: float,
|
| 135 |
+
steps_taken: int,
|
| 136 |
+
max_steps: int,
|
| 137 |
+
actions_taken: List[Dict[str, Any]],
|
| 138 |
+
terminated: bool,
|
| 139 |
+
termination_reason: Optional[str],
|
| 140 |
+
) -> Dict[str, Any]:
|
| 141 |
+
r = await self._request(
|
| 142 |
+
"POST",
|
| 143 |
+
"/grader",
|
| 144 |
+
json={
|
| 145 |
+
"final_slo_score": final_slo_score,
|
| 146 |
+
"steps_taken": steps_taken,
|
| 147 |
+
"max_steps": max_steps,
|
| 148 |
+
"actions_taken": actions_taken,
|
| 149 |
+
"terminated": terminated,
|
| 150 |
+
"termination_reason": termination_reason,
|
| 151 |
+
},
|
| 152 |
+
)
|
| 153 |
+
r.raise_for_status()
|
| 154 |
+
return r.json()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def run_async(coro):
|
| 158 |
+
"""Run async coroutine from sync context (rollout_func)."""
|
| 159 |
+
return asyncio.run(coro)
|
training/eval.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Eval: local HF adapters + Gemini (google-genai) + Azure OpenAI + Azure AI Inference.
|
| 4 |
+
Writes eval_results.csv; pushes Mist-ic/sevzero-eval-results with HF_MAIN_TOKEN. No Claude.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import csv
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any, Callable, Dict, List
|
| 15 |
+
|
| 16 |
+
_REPO = Path(__file__).resolve().parent.parent
|
| 17 |
+
if str(_REPO) not in sys.path:
|
| 18 |
+
sys.path.insert(0, str(_REPO))
|
| 19 |
+
|
| 20 |
+
from training.config_utils import try_load_env_files
|
| 21 |
+
from training.rollout_sevzero import SRE_SYSTEM_PROMPT, build_observation_prompt, parse_action
|
| 22 |
+
|
| 23 |
+
try_load_env_files()
|
| 24 |
+
|
| 25 |
+
HELD_OUT = (13, 99, 777)
|
| 26 |
+
DEFAULT_TASKS = ("easy", "medium", "hard")
|
| 27 |
+
DATASET_HUB = "Mist-ic/sevzero-eval-results"
|
| 28 |
+
|
| 29 |
+
BUILTIN: Dict[str, str] = {
|
| 30 |
+
"untrained-llama": "base:meta-llama/Llama-3.1-8B-Instruct",
|
| 31 |
+
"sft-primary": os.getenv("SFT_ADAPTER_PRIMARY", "PhaseOfCode/sevzero-llama3-8b-sft"),
|
| 32 |
+
"sft-backup": os.getenv("SFT_ADAPTER_BACKUP", "NoahInOblivion/sevzero-llama3-8b-sft"),
|
| 33 |
+
"sft-innovation": os.getenv("SFT_ADAPTER_INNOVATION", "NoxIsOblivion/sevzero-llama3-8b-sft"),
|
| 34 |
+
"grpo-primary": os.getenv("GRPO_ADAPTER_PRIMARY", "PhaseOfCode/sevzero-llama3-8b-grpo-primary"),
|
| 35 |
+
"grpo-stability": os.getenv("GRPO_ADAPTER_STABILITY", "NoahInOblivion/sevzero-llama3-8b-grpo-stability"),
|
| 36 |
+
"grpo-innovation": os.getenv("GRPO_ADAPTER_INNOVATION", "NoxIsOblivion/sevzero-llama3-8b-grpo-innovation"),
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
AZURE_INF = {
|
| 40 |
+
"grok-4.20-reasoning": "grok-2-latest",
|
| 41 |
+
"kimi-k2.6": "kimi-k2-6-2025",
|
| 42 |
+
"DeepSeek-V3.2": "DeepSeek-V3-2",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def run_episode(
|
| 47 |
+
base: str, task: str, seed: int, answer: Callable[[str, str], str]
|
| 48 |
+
) -> Dict[str, Any]:
|
| 49 |
+
import httpx
|
| 50 |
+
|
| 51 |
+
with httpx.Client(base_url=base.rstrip("/"), timeout=120.0) as client:
|
| 52 |
+
r = client.post("/reset", json={"task_id": task, "seed": seed})
|
| 53 |
+
r.raise_for_status()
|
| 54 |
+
ro = r.json()
|
| 55 |
+
obs = ro.get("observation", ro)
|
| 56 |
+
done = ro.get("done", False)
|
| 57 |
+
user_pfx = f"You are the on-call SRE. task={task!r} seed={seed}.\n\n## Session\n"
|
| 58 |
+
for _ in range(1 + int(obs.get("max_steps", 20))):
|
| 59 |
+
if done:
|
| 60 |
+
break
|
| 61 |
+
user_block = user_pfx + build_observation_prompt(obs)
|
| 62 |
+
text = answer(SRE_SYSTEM_PROMPT, user_block)
|
| 63 |
+
act = parse_action(text)
|
| 64 |
+
sr = client.post(
|
| 65 |
+
"/step",
|
| 66 |
+
json={"action": {"action_type": str(act.get("action_type", "noop")), "params": act.get("params") or {}}},
|
| 67 |
+
)
|
| 68 |
+
sr.raise_for_status()
|
| 69 |
+
out = sr.json()
|
| 70 |
+
obs = out.get("observation", out)
|
| 71 |
+
done = out.get("done", False)
|
| 72 |
+
stt = client.get("/state")
|
| 73 |
+
stt.raise_for_status()
|
| 74 |
+
fs = stt.json()
|
| 75 |
+
g = client.post(
|
| 76 |
+
"/grader",
|
| 77 |
+
json={
|
| 78 |
+
"final_slo_score": float(fs.get("global_slo_score", 0.0)),
|
| 79 |
+
"steps_taken": int(fs.get("step_count", 0)),
|
| 80 |
+
"max_steps": int((obs or {}).get("max_steps", 10)),
|
| 81 |
+
"actions_taken": list((obs or {}).get("actions_taken", [])),
|
| 82 |
+
"terminated": bool(fs.get("terminated", True)),
|
| 83 |
+
"termination_reason": fs.get("termination_reason"),
|
| 84 |
+
},
|
| 85 |
+
)
|
| 86 |
+
js: Dict[str, Any] = {}
|
| 87 |
+
if g.status_code < 400:
|
| 88 |
+
js = g.json()
|
| 89 |
+
return {
|
| 90 |
+
"score": float(js.get("score", 0.0)),
|
| 91 |
+
"slo_recovery": float(js.get("slo_recovery", 0.0)),
|
| 92 |
+
"action_efficiency": float(js.get("action_efficiency", 0.0)),
|
| 93 |
+
"time_efficiency": float(js.get("time_efficiency", 0.0)),
|
| 94 |
+
"steps_used": int(fs.get("step_count", 0)),
|
| 95 |
+
"terminated": fs.get("terminated", True),
|
| 96 |
+
"termination_reason": str(fs.get("termination_reason", "")),
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def load_llama_peft(adapter_id: str | None):
|
| 101 |
+
import torch
|
| 102 |
+
from peft import PeftModel
|
| 103 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 104 |
+
|
| 105 |
+
base_id = "meta-llama/Llama-3.1-8B-Instruct"
|
| 106 |
+
tok = AutoTokenizer.from_pretrained(base_id, use_fast=True, token=os.environ.get("HF_TOKEN"))
|
| 107 |
+
if tok.pad_token is None:
|
| 108 |
+
tok.pad_token = tok.eos_token
|
| 109 |
+
bnb = BitsAndBytesConfig(
|
| 110 |
+
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
|
| 111 |
+
)
|
| 112 |
+
m = AutoModelForCausalLM.from_pretrained(
|
| 113 |
+
base_id, quantization_config=bnb, device_map="auto", torch_dtype=torch.bfloat16, token=os.environ.get("HF_TOKEN")
|
| 114 |
+
)
|
| 115 |
+
if adapter_id:
|
| 116 |
+
m = PeftModel.from_pretrained(m, adapter_id, token=os.environ.get("HF_TOKEN"))
|
| 117 |
+
m.eval()
|
| 118 |
+
return tok, m
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def hf_answer(tok, mdl):
|
| 122 |
+
import torch
|
| 123 |
+
|
| 124 |
+
def answer(system: str, user: str) -> str:
|
| 125 |
+
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
|
| 126 |
+
p = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
| 127 |
+
inputs = tok(p, return_tensors="pt").to(mdl.device)
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
o = mdl.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.0)
|
| 130 |
+
gen = o[0, inputs["input_ids"].shape[1] :]
|
| 131 |
+
return tok.decode(gen, skip_special_tokens=True)
|
| 132 |
+
|
| 133 |
+
return answer
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def answer_gemini(system: str, user: str) -> str:
|
| 137 |
+
from google import genai
|
| 138 |
+
|
| 139 |
+
model = os.environ.get(
|
| 140 |
+
"GEMINI_EVAL_MODEL",
|
| 141 |
+
os.environ.get("GEMINI_MODEL_PRO", "gemini-3.1-pro-preview"),
|
| 142 |
+
)
|
| 143 |
+
c = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
|
| 144 |
+
r = c.models.generate_content(model=model, contents=f"{system}\n\n{user}")
|
| 145 |
+
return (r.text or "").strip()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def answer_azure_openai(system: str, user: str) -> str:
|
| 149 |
+
from openai import OpenAI
|
| 150 |
+
|
| 151 |
+
ep = os.environ.get("AZURE_OPENAI_ENDPOINT", "").rstrip("/")
|
| 152 |
+
c = OpenAI(
|
| 153 |
+
api_key=os.environ.get("AZURE_API_KEY", ""),
|
| 154 |
+
base_url=ep + "/openai/v1",
|
| 155 |
+
)
|
| 156 |
+
dep = os.environ.get("AZURE_GPT_DEPLOYMENT", "gpt-5.4-pro")
|
| 157 |
+
r = c.chat.completions.create(
|
| 158 |
+
model=dep,
|
| 159 |
+
messages=[{"role": "system", "content": system}, {"role": "user", "content": user}],
|
| 160 |
+
temperature=0.0,
|
| 161 |
+
max_tokens=512,
|
| 162 |
+
)
|
| 163 |
+
return (r.choices[0].message.content or "").strip()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def answer_azure_inference(model_name: str, system: str, user: str) -> str:
|
| 167 |
+
from azure.ai.inference import ChatCompletionsClient
|
| 168 |
+
from azure.core.credentials import AzureKeyCredential
|
| 169 |
+
|
| 170 |
+
ep = os.environ.get("AZURE_AI_INFERENCE_ENDPOINT", "").rstrip("/") + "/"
|
| 171 |
+
c = ChatCompletionsClient(endpoint=ep, credential=AzureKeyCredential(os.environ.get("AZURE_API_KEY", "")))
|
| 172 |
+
r = c.complete(
|
| 173 |
+
model_name=model_name,
|
| 174 |
+
messages=[{"role": "user", "content": f"{system}\n\n{user}"}],
|
| 175 |
+
)
|
| 176 |
+
return (r.choices[0].message.content or "").strip()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def pick_answer_fn(name: str) -> Callable[[str, str], str]:
|
| 180 |
+
n = name.strip()
|
| 181 |
+
if n in BUILTIN:
|
| 182 |
+
spec = BUILTIN[n]
|
| 183 |
+
aid = None if spec.startswith("base:") else spec
|
| 184 |
+
tok, m = load_llama_peft(aid)
|
| 185 |
+
return hf_answer(tok, m)
|
| 186 |
+
if "/" in n and n.count("/") == 1 and not n.startswith("meta-llama/"):
|
| 187 |
+
tok, m = load_llama_peft(n)
|
| 188 |
+
return hf_answer(tok, m)
|
| 189 |
+
if n.startswith("gemini"):
|
| 190 |
+
return answer_gemini
|
| 191 |
+
if "gpt" in n.lower() or n == "gpt-5.4-pro":
|
| 192 |
+
return answer_azure_openai
|
| 193 |
+
if n in AZURE_INF:
|
| 194 |
+
mid = AZURE_INF[n]
|
| 195 |
+
|
| 196 |
+
def _fn(s: str, u: str) -> str:
|
| 197 |
+
return answer_azure_inference(mid, s, u)
|
| 198 |
+
|
| 199 |
+
return _fn
|
| 200 |
+
raise ValueError(f"Unknown model key: {name!r}")
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def main() -> None:
|
| 204 |
+
ap = argparse.ArgumentParser()
|
| 205 |
+
ap.add_argument("--models", type=str, default="untrained-llama")
|
| 206 |
+
ap.add_argument("--out", type=str, default="eval_results.csv")
|
| 207 |
+
ap.add_argument("--seeds", type=str, default=",".join(str(s) for s in HELD_OUT))
|
| 208 |
+
ap.add_argument("--tasks", type=str, default=",".join(DEFAULT_TASKS))
|
| 209 |
+
a = ap.parse_args()
|
| 210 |
+
|
| 211 |
+
base = (os.environ.get("SEVZERO_ENV_URL") or "").rstrip("/")
|
| 212 |
+
if not base:
|
| 213 |
+
raise SystemExit("SEVZERO_ENV_URL required")
|
| 214 |
+
|
| 215 |
+
models = [m.strip() for m in a.models.split(",") if m.strip()]
|
| 216 |
+
seeds = [int(x) for x in a.seeds.split(",")]
|
| 217 |
+
tasks = [t.strip() for t in a.tasks.split(",")]
|
| 218 |
+
|
| 219 |
+
rows: List[Dict[str, Any]] = []
|
| 220 |
+
for mname in models:
|
| 221 |
+
try:
|
| 222 |
+
answer = pick_answer_fn(mname)
|
| 223 |
+
except ValueError as e:
|
| 224 |
+
print(f"SKIP {mname}: {e}", flush=True)
|
| 225 |
+
continue
|
| 226 |
+
for task in tasks:
|
| 227 |
+
for seed in seeds:
|
| 228 |
+
r = run_episode(base, task, seed, answer)
|
| 229 |
+
rows.append(
|
| 230 |
+
{
|
| 231 |
+
"model": mname,
|
| 232 |
+
"task": task,
|
| 233 |
+
"seed": seed,
|
| 234 |
+
**r,
|
| 235 |
+
}
|
| 236 |
+
)
|
| 237 |
+
print(rows[-1], flush=True)
|
| 238 |
+
|
| 239 |
+
with Path(a.out).open("w", newline="", encoding="utf-8") as f:
|
| 240 |
+
fieldnames = [
|
| 241 |
+
"model",
|
| 242 |
+
"task",
|
| 243 |
+
"seed",
|
| 244 |
+
"score",
|
| 245 |
+
"slo_recovery",
|
| 246 |
+
"action_efficiency",
|
| 247 |
+
"time_efficiency",
|
| 248 |
+
"steps_used",
|
| 249 |
+
"terminated",
|
| 250 |
+
"termination_reason",
|
| 251 |
+
]
|
| 252 |
+
w = csv.DictWriter(f, fieldnames=fieldnames)
|
| 253 |
+
w.writeheader()
|
| 254 |
+
for r in rows:
|
| 255 |
+
w.writerow(r)
|
| 256 |
+
|
| 257 |
+
tok_m = os.environ.get("HF_MAIN_TOKEN", "")
|
| 258 |
+
if not tok_m:
|
| 259 |
+
print("HF_MAIN_TOKEN not set — skip Hub push", flush=True)
|
| 260 |
+
return
|
| 261 |
+
from datasets import Dataset
|
| 262 |
+
|
| 263 |
+
ds = Dataset.from_list([dict(x) for x in rows])
|
| 264 |
+
ds.push_to_hub(DATASET_HUB, token=tok_m, private=False)
|
| 265 |
+
print(f"OK: pushed hf.co/datasets/{DATASET_HUB}", flush=True)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
if __name__ == "__main__":
|
| 269 |
+
main()
|
training/launch_hf_job.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Submit a HuggingFace Job to run training/train_sft.py or training/train_grpo.py.
|
| 4 |
+
Uses huggingface_hub.run_job; prints job URL; appends training/runs.jsonl.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import subprocess
|
| 13 |
+
import sys
|
| 14 |
+
from datetime import datetime, timezone
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
_REPO = Path(__file__).resolve().parent.parent
|
| 18 |
+
if str(_REPO) not in sys.path:
|
| 19 |
+
sys.path.insert(0, str(_REPO))
|
| 20 |
+
|
| 21 |
+
from training.config_utils import try_load_env_files
|
| 22 |
+
|
| 23 |
+
try_load_env_files()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _default_git_url() -> str:
|
| 27 |
+
r = subprocess.run(
|
| 28 |
+
["git", "remote", "get-url", "origin"],
|
| 29 |
+
cwd=str(_REPO),
|
| 30 |
+
capture_output=True,
|
| 31 |
+
text=True,
|
| 32 |
+
)
|
| 33 |
+
return (r.stdout or "").strip() if r.returncode == 0 else ""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main() -> None:
|
| 37 |
+
p = argparse.ArgumentParser()
|
| 38 |
+
p.add_argument("--account_token", type=str, default=os.environ.get("HF_TOKEN", ""))
|
| 39 |
+
p.add_argument("--script", type=str, choices=("sft", "grpo"), required=True)
|
| 40 |
+
p.add_argument("--variant_name", type=str, default="run")
|
| 41 |
+
p.add_argument("--hardware", type=str, default="l40sx1")
|
| 42 |
+
p.add_argument(
|
| 43 |
+
"--image",
|
| 44 |
+
type=str,
|
| 45 |
+
default="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime",
|
| 46 |
+
)
|
| 47 |
+
p.add_argument("--git-url", type=str, default="")
|
| 48 |
+
p.add_argument(
|
| 49 |
+
"--env_vars",
|
| 50 |
+
type=str,
|
| 51 |
+
default="",
|
| 52 |
+
help="KEY=val pairs comma-separated, e.g. SEVZERO_ENV_URL=https://x.hf.space,HF_MAIN_TOKEN=...",
|
| 53 |
+
)
|
| 54 |
+
a, rest = p.parse_known_args()
|
| 55 |
+
if not a.account_token:
|
| 56 |
+
raise SystemExit("Need HF_TOKEN or --account_token")
|
| 57 |
+
git_url = a.git_url or _default_git_url()
|
| 58 |
+
if not git_url:
|
| 59 |
+
raise SystemExit("Set --git-url or configure git origin")
|
| 60 |
+
ev = {k: v for k, v in [x.split("=", 1) for x in a.env_vars.split(",") if "=" in x]}
|
| 61 |
+
if "SEVZERO_ENV_URL" not in ev and os.environ.get("SEVZERO_ENV_URL"):
|
| 62 |
+
ev["SEVZERO_ENV_URL"] = os.environ["SEVZERO_ENV_URL"]
|
| 63 |
+
|
| 64 |
+
which = f"training/train_{a.script}.py"
|
| 65 |
+
extra = " ".join(rest)
|
| 66 |
+
inner = (
|
| 67 |
+
f"set -euo pipefail && git clone --depth 1 {git_url!r} /work/r && cd /work/r && "
|
| 68 |
+
"pip install -U pip 'trl>=0.20' 'peft' 'transformers' 'accelerate' 'bitsandbytes' 'datasets' "
|
| 69 |
+
"'huggingface_hub' 'httpx' 'python-dotenv' 'vllm' 'unsloth' 2>/dev/null || true && "
|
| 70 |
+
f"python {which} --variant_name {a.variant_name!r} {extra}"
|
| 71 |
+
)
|
| 72 |
+
from huggingface_hub import run_job
|
| 73 |
+
|
| 74 |
+
job = run_job(
|
| 75 |
+
image=a.image,
|
| 76 |
+
command=["bash", "-lc", inner],
|
| 77 |
+
env=ev,
|
| 78 |
+
secrets={"HF_TOKEN": a.account_token},
|
| 79 |
+
flavor=a.hardware,
|
| 80 |
+
)
|
| 81 |
+
with (_REPO / "training" / "runs.jsonl").open("a", encoding="utf-8") as f:
|
| 82 |
+
f.write(
|
| 83 |
+
json.dumps(
|
| 84 |
+
{
|
| 85 |
+
"account_token_tail": a.account_token[-4:] if len(a.account_token) > 4 else "",
|
| 86 |
+
"job_id": str(getattr(job, "id", job)),
|
| 87 |
+
"variant_name": a.variant_name,
|
| 88 |
+
"started_at": datetime.now(timezone.utc).isoformat(),
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
+ "\n"
|
| 92 |
+
)
|
| 93 |
+
print(getattr(job, "url", f"https://huggingface.co/jobs/{getattr(job, 'id', job)}"), flush=True)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
main()
|
training/loader.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Load SevZero SFT data for a trainer: local JSONL or the Hub Parquet copy.
|
| 3 |
+
|
| 4 |
+
The training config should set `max_seq_length` to at least
|
| 5 |
+
`max_prompt_token_length` from `build_stats.json` (plus max completion length).
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any, Optional, Union
|
| 14 |
+
|
| 15 |
+
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 16 |
+
DATA_DIR = REPO_ROOT / "training" / "data"
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from datasets import Dataset, DatasetDict, load_dataset
|
| 20 |
+
except ImportError as e:
|
| 21 |
+
raise ImportError("Install `datasets` to use the loader.") from e
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_local_jsonl(
|
| 25 |
+
train_path: Optional[Path] = None,
|
| 26 |
+
eval_path: Optional[Path] = None,
|
| 27 |
+
) -> DatasetDict:
|
| 28 |
+
train_path = train_path or (DATA_DIR / "sft_train.jsonl")
|
| 29 |
+
eval_path = eval_path or (DATA_DIR / "sft_eval.jsonl")
|
| 30 |
+
train = load_dataset("json", data_files=str(train_path), split="train")
|
| 31 |
+
if eval_path.is_file() and eval_path.stat().st_size > 0:
|
| 32 |
+
ev = load_dataset("json", data_files=str(eval_path), split="train")
|
| 33 |
+
else:
|
| 34 |
+
ev = train.select([])
|
| 35 |
+
return DatasetDict(train=train, eval=ev)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_from_hub(
|
| 39 |
+
repo_id: str = "Mist-ic/sevzero-expert-trajectories",
|
| 40 |
+
token: Optional[str] = None,
|
| 41 |
+
) -> DatasetDict:
|
| 42 |
+
tok = token or os.environ.get("HF_MAIN_TOKEN")
|
| 43 |
+
return load_dataset(repo_id, token=tok) # type: ignore[return-value]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def read_build_stats() -> dict[str, Any]:
|
| 47 |
+
p = DATA_DIR / "build_stats.json"
|
| 48 |
+
if not p.is_file():
|
| 49 |
+
return {}
|
| 50 |
+
return json.loads(p.read_text(encoding="utf-8"))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def recommended_max_seq_length(plus_completion: int = 1024) -> int:
|
| 54 |
+
s = read_build_stats()
|
| 55 |
+
m = int(s.get("max_prompt_token_length", 0) or 0)
|
| 56 |
+
return m + plus_completion
|
training/preflight.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
(1) In-process Sim + grader: golden remediation plan → score >= 0.9 when possible
|
| 4 |
+
(2) Uvicorn /health (optional) + 5 CPU GRPO steps with rollout_func + tiny model
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import signal
|
| 11 |
+
import subprocess
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any, Dict, List, Tuple
|
| 16 |
+
|
| 17 |
+
_REPO = Path(__file__).resolve().parent.parent
|
| 18 |
+
if str(_REPO) not in sys.path:
|
| 19 |
+
sys.path.insert(0, str(_REPO))
|
| 20 |
+
|
| 21 |
+
from training.config_utils import try_load_env_files
|
| 22 |
+
|
| 23 |
+
try_load_env_files()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _action_plan(seed: int, task_id: str) -> List[Tuple[str, Dict[str, Any]]]:
|
| 27 |
+
from server.failures import FailureType
|
| 28 |
+
from server.scenarios import generate_scenario
|
| 29 |
+
|
| 30 |
+
sc = generate_scenario(seed, task_id)
|
| 31 |
+
if not sc.failure_specs:
|
| 32 |
+
return [("noop", {})]
|
| 33 |
+
spec = sc.failure_specs[0]
|
| 34 |
+
sid = spec.service_id
|
| 35 |
+
ft = spec.failure_type
|
| 36 |
+
if ft == FailureType.BAD_DEPLOY:
|
| 37 |
+
return [("rollback_service", {"service_id": sid})]
|
| 38 |
+
if ft in (FailureType.CONFIG_STARTUP, FailureType.CONFIG_RUNTIME):
|
| 39 |
+
k = spec.broken_config_key or "timeout_ms"
|
| 40 |
+
out = [("tune_config", {"service_id": sid, "key": k, "value": "correct"})]
|
| 41 |
+
if ft == FailureType.CONFIG_STARTUP:
|
| 42 |
+
out.append(("restart_service", {"service_id": sid}))
|
| 43 |
+
return out
|
| 44 |
+
if ft == FailureType.CACHE_FAILURE:
|
| 45 |
+
return [("clear_cache", {"cache_name": sid})]
|
| 46 |
+
if ft == FailureType.CASCADING_LATENCY:
|
| 47 |
+
return [("scale_service", {"service_id": sid, "replicas": 4})]
|
| 48 |
+
if ft == FailureType.NETWORK_ERROR:
|
| 49 |
+
return [("noop", {}), ("noop", {})]
|
| 50 |
+
return [("restart_service", {"service_id": sid})]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _inproc_golden_score(seed: int, task_id: str) -> float:
|
| 54 |
+
from server.grader import grade_episode
|
| 55 |
+
from server.scenarios import generate_scenario
|
| 56 |
+
from server.simulator import Simulator
|
| 57 |
+
|
| 58 |
+
sc = generate_scenario(seed, task_id)
|
| 59 |
+
sim = Simulator()
|
| 60 |
+
sim.reset(seed=seed, difficulty=sc.difficulty, failure_specs=sc.failure_specs)
|
| 61 |
+
for at, p in _action_plan(seed, task_id):
|
| 62 |
+
sim.step(at, p)
|
| 63 |
+
for _ in range(4):
|
| 64 |
+
if sim.terminated:
|
| 65 |
+
break
|
| 66 |
+
sim.step("noop", {})
|
| 67 |
+
g = grade_episode(
|
| 68 |
+
final_slo_score=sim.get_slo_score(),
|
| 69 |
+
steps_taken=len(sim.actions_taken),
|
| 70 |
+
max_steps=sc.max_steps,
|
| 71 |
+
actions_taken=sim.actions_taken,
|
| 72 |
+
terminated=sim.terminated,
|
| 73 |
+
termination_reason=sim.termination_reason,
|
| 74 |
+
)
|
| 75 |
+
return float(g.score)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _grpo_tiny() -> bool:
|
| 79 |
+
try:
|
| 80 |
+
import trl # noqa: F401
|
| 81 |
+
except ImportError:
|
| 82 |
+
print("GRPO preflight: trl not installed — skip (pip install trl)", flush=True)
|
| 83 |
+
return True
|
| 84 |
+
os.environ["UNSLOTH_DISABLE"] = "1"
|
| 85 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
|
| 86 |
+
|
| 87 |
+
from datasets import Dataset
|
| 88 |
+
from peft import LoraConfig, get_peft_model
|
| 89 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 90 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 91 |
+
from trl.experimental.openenv import generate_rollout_completions
|
| 92 |
+
|
| 93 |
+
from training.env_client import AsyncSevZeroEnvClient, run_async
|
| 94 |
+
from training.rollout_sevzero import SRE_SYSTEM_PROMPT, build_observation_prompt, parse_action
|
| 95 |
+
|
| 96 |
+
base = (os.environ.get("SEVZERO_ENV_URL") or "").rstrip("/")
|
| 97 |
+
if not base:
|
| 98 |
+
print("SEVZERO_ENV_URL unset — skip GRPO smoke", flush=True)
|
| 99 |
+
return True
|
| 100 |
+
|
| 101 |
+
tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
|
| 102 |
+
m = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", device_map="cpu")
|
| 103 |
+
m = get_peft_model(
|
| 104 |
+
m,
|
| 105 |
+
LoraConfig(
|
| 106 |
+
r=4,
|
| 107 |
+
lora_alpha=8,
|
| 108 |
+
target_modules=["q_proj", "v_proj"],
|
| 109 |
+
lora_dropout=0.0,
|
| 110 |
+
task_type="CAUSAL_LM",
|
| 111 |
+
),
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def rollout_func(prompts, trainer):
|
| 115 |
+
ep_ids: List[int] = []
|
| 116 |
+
ec_ids: List[int] = []
|
| 117 |
+
elp: List[float] = []
|
| 118 |
+
env_r: List[float] = []
|
| 119 |
+
for pr in prompts:
|
| 120 |
+
client = AsyncSevZeroEnvClient(base, None)
|
| 121 |
+
|
| 122 |
+
async def run_one():
|
| 123 |
+
p_ids, c_ids, lps = [], [], []
|
| 124 |
+
step_sum = 0.0
|
| 125 |
+
try:
|
| 126 |
+
ro = await client.reset(task_id="easy", seed=7)
|
| 127 |
+
obs = ro.get("observation", ro)
|
| 128 |
+
done = ro.get("done", False)
|
| 129 |
+
for _ in range(2):
|
| 130 |
+
if done:
|
| 131 |
+
break
|
| 132 |
+
u = build_observation_prompt(obs)
|
| 133 |
+
msg = [
|
| 134 |
+
{"role": "system", "content": SRE_SYSTEM_PROMPT},
|
| 135 |
+
{"role": "user", "content": f"{pr}\n{u}"},
|
| 136 |
+
]
|
| 137 |
+
ptxt = tok.apply_chat_template(msg, add_generation_prompt=True, tokenize=False)
|
| 138 |
+
out = generate_rollout_completions(trainer, [ptxt])[0]
|
| 139 |
+
p_ids.extend(out.get("prompt_ids", []))
|
| 140 |
+
c_ids.extend(out.get("completion_ids", []))
|
| 141 |
+
lps.extend(out.get("logprobs", []))
|
| 142 |
+
ctext = out.get("text")
|
| 143 |
+
if not ctext and cids:
|
| 144 |
+
ctext = tok.decode(cids, skip_special_tokens=True)
|
| 145 |
+
a = parse_action(ctext or "")
|
| 146 |
+
sr = await client.step(
|
| 147 |
+
{
|
| 148 |
+
"action": {
|
| 149 |
+
"action_type": str(a.get("action_type", "noop")),
|
| 150 |
+
"params": a.get("params") or {},
|
| 151 |
+
}
|
| 152 |
+
}
|
| 153 |
+
)
|
| 154 |
+
obs = sr.get("observation", sr)
|
| 155 |
+
done = sr.get("done", False)
|
| 156 |
+
step_sum += float(obs.get("reward", sr.get("reward", 0.0) or 0.0))
|
| 157 |
+
return p_ids, c_ids, lps, step_sum
|
| 158 |
+
finally:
|
| 159 |
+
await client.aclose()
|
| 160 |
+
|
| 161 |
+
p, c, lp, s = run_async(run_one())
|
| 162 |
+
ep_ids.append(p)
|
| 163 |
+
ec_ids.append(c)
|
| 164 |
+
elp.append(lp)
|
| 165 |
+
env_r.append(s)
|
| 166 |
+
return {
|
| 167 |
+
"prompt_ids": ep_ids,
|
| 168 |
+
"completion_ids": ec_ids,
|
| 169 |
+
"logprobs": elp,
|
| 170 |
+
"env_reward": env_r,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
def rf(completions, **kwargs):
|
| 174 |
+
return [float(x) for x in kwargs.get("env_reward", [0.0] * len(completions))]
|
| 175 |
+
|
| 176 |
+
out_dir = str(_REPO / "training" / ".preflight_grpo")
|
| 177 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 178 |
+
tr = GRPOTrainer(
|
| 179 |
+
model=m,
|
| 180 |
+
processing_class=tok,
|
| 181 |
+
args=GRPOConfig(
|
| 182 |
+
output_dir=out_dir,
|
| 183 |
+
per_device_train_batch_size=1,
|
| 184 |
+
max_steps=5,
|
| 185 |
+
num_generations=1,
|
| 186 |
+
use_vllm=False,
|
| 187 |
+
learning_rate=1e-5,
|
| 188 |
+
max_completion_length=32,
|
| 189 |
+
),
|
| 190 |
+
train_dataset=Dataset.from_list([{"text": "x"}] * 2),
|
| 191 |
+
reward_funcs=[rf],
|
| 192 |
+
rollout_func=rollout_func,
|
| 193 |
+
)
|
| 194 |
+
tr.train()
|
| 195 |
+
return True
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def main() -> None:
|
| 199 |
+
# --- Part A: in-process (no network)
|
| 200 |
+
for seed, task in ((100, "easy"), (13, "easy"), (7, "easy")):
|
| 201 |
+
s = _inproc_golden_score(seed, task)
|
| 202 |
+
print(f"in-proc grader: seed={seed} task={task} score={s:.3f}", flush=True)
|
| 203 |
+
if s >= 0.9:
|
| 204 |
+
print("OK: in-process golden path reached >=0.9", flush=True)
|
| 205 |
+
break
|
| 206 |
+
else:
|
| 207 |
+
print("WARN: no seed reached 0.9 in in-proc test — check failure coverage", flush=True)
|
| 208 |
+
|
| 209 |
+
# --- B: Uvicorn + optional GRPO (requires same deps as the project)
|
| 210 |
+
try:
|
| 211 |
+
import uvicorn # noqa: F401
|
| 212 |
+
except ImportError:
|
| 213 |
+
print("SKIP: uvicorn not installed — pip install the project (see training/README.md)", flush=True)
|
| 214 |
+
print("OK", flush=True)
|
| 215 |
+
return
|
| 216 |
+
|
| 217 |
+
port = int(os.environ.get("PREFLIGHT_PORT", "8765"))
|
| 218 |
+
base = f"http://127.0.0.1:{port}"
|
| 219 |
+
os.environ["SEVZERO_ENV_URL"] = base
|
| 220 |
+
import urllib.request
|
| 221 |
+
|
| 222 |
+
proc = subprocess.Popen(
|
| 223 |
+
[sys.executable, "-m", "uvicorn", "server.app:app", "--host", "127.0.0.1", "--port", str(port)],
|
| 224 |
+
cwd=str(_REPO),
|
| 225 |
+
)
|
| 226 |
+
try:
|
| 227 |
+
for _ in range(25):
|
| 228 |
+
try:
|
| 229 |
+
with urllib.request.urlopen(f"{base}/health", timeout=2) as r:
|
| 230 |
+
if getattr(r, "status", 200) < 500:
|
| 231 |
+
break
|
| 232 |
+
except Exception:
|
| 233 |
+
time.sleep(0.5)
|
| 234 |
+
else:
|
| 235 |
+
raise RuntimeError("uvicorn not up")
|
| 236 |
+
try:
|
| 237 |
+
_grpo_tiny()
|
| 238 |
+
except Exception as e:
|
| 239 |
+
print(f"GRPO smoke failed (env OK): {e}", flush=True)
|
| 240 |
+
finally:
|
| 241 |
+
proc.terminate()
|
| 242 |
+
try:
|
| 243 |
+
proc.wait(timeout=10)
|
| 244 |
+
except Exception:
|
| 245 |
+
proc.kill()
|
| 246 |
+
print("OK", flush=True)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
if __name__ == "__main__":
|
| 250 |
+
main()
|
training/push_dataset.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Upload SFT jsonl to Hugging Face (Mist-ic Main account) as a public dataset with Parquet.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
from huggingface_hub import HfApi
|
| 13 |
+
|
| 14 |
+
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 15 |
+
load_dotenv(REPO_ROOT / "api.env")
|
| 16 |
+
load_dotenv(REPO_ROOT / "hg.env")
|
| 17 |
+
|
| 18 |
+
if str(REPO_ROOT) not in sys.path:
|
| 19 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 20 |
+
|
| 21 |
+
DATA_DIR = REPO_ROOT / "training" / "data"
|
| 22 |
+
STATS_PATH = DATA_DIR / "build_stats.json"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _readme(stats: dict) -> str:
|
| 26 |
+
return f"""# SevZero expert trajectories (SFT)
|
| 27 |
+
|
| 28 |
+
## Sources
|
| 29 |
+
|
| 30 |
+
- Synthetic expert rollouts from frontier models (Gemini 3.1 Pro, Azure OpenAI, Azure AI Inference)
|
| 31 |
+
against the local OpenEnv `server.app` SevZero environment.
|
| 32 |
+
|
| 33 |
+
## Filtering
|
| 34 |
+
|
| 35 |
+
- Episodes with final grader `score` **≥** `{stats.get("min_score_filter", 0.85)}` are included.
|
| 36 |
+
|
| 37 |
+
## Schema
|
| 38 |
+
|
| 39 |
+
- Each example has a `messages` list (Llama-3.1-8B-Instruct–style SFT) and `meta` (episode / step provenance):
|
| 40 |
+
- `system`: SRE on-call system prompt (same as `inference.SYSTEM_PROMPT` in the repo)
|
| 41 |
+
- `user`: JSON-serialized observation (shrink to ≤ {stats.get("max_observation_user_token_budget", 2048)} tokens for the user part)
|
| 42 |
+
- `assistant`: one JSON object `{{"action_type": "...", "params": {{...}}}}`
|
| 43 |
+
|
| 44 |
+
## Stats (from `build_stats.json` at publish time)
|
| 45 |
+
|
| 46 |
+
{json.dumps(stats, indent=2)}
|
| 47 |
+
|
| 48 |
+
## Parquet
|
| 49 |
+
|
| 50 |
+
- Splits `train` and `eval` are also pushed in Parquet for fast `datasets.load_dataset`.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _dataset_info(stats: dict) -> dict:
|
| 55 |
+
return {
|
| 56 |
+
"description": "SevZero SFT expert trajectories for Llama-3.1-8B-Instruct style chat training.",
|
| 57 |
+
"version": "1.0.0",
|
| 58 |
+
"license": "apache-2.0",
|
| 59 |
+
"build": stats,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main() -> None:
|
| 64 |
+
token = os.environ.get("HF_MAIN_TOKEN", "")
|
| 65 |
+
if not token:
|
| 66 |
+
raise SystemExit("HF_MAIN_TOKEN missing (set in api.env or hg.env).")
|
| 67 |
+
user = (os.environ.get("HF_MAIN_USERNAME", "") or "").strip() or "Mist-ic"
|
| 68 |
+
repo_id = f"{user}/sevzero-expert-trajectories"
|
| 69 |
+
if not (DATA_DIR / "sft_train.jsonl").is_file():
|
| 70 |
+
raise SystemExit(f"Missing {DATA_DIR / 'sft_train.jsonl'} — run build_dataset.py first.")
|
| 71 |
+
stats: dict = {}
|
| 72 |
+
if STATS_PATH.is_file():
|
| 73 |
+
stats = json.loads(STATS_PATH.read_text(encoding="utf-8"))
|
| 74 |
+
readme = _readme(stats)
|
| 75 |
+
info = _dataset_info(stats)
|
| 76 |
+
(DATA_DIR / "DATASET_README_HF.md").write_text(readme, encoding="utf-8")
|
| 77 |
+
(DATA_DIR / "dataset_info.json").write_text(
|
| 78 |
+
json.dumps(info, indent=2), encoding="utf-8"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
api = HfApi(token=token)
|
| 82 |
+
api.create_repo(
|
| 83 |
+
repo_id=repo_id,
|
| 84 |
+
repo_type="dataset",
|
| 85 |
+
private=False,
|
| 86 |
+
exist_ok=True,
|
| 87 |
+
)
|
| 88 |
+
for name in (
|
| 89 |
+
"sft_train.jsonl",
|
| 90 |
+
"sft_eval.jsonl",
|
| 91 |
+
"build_stats.json",
|
| 92 |
+
"dataset_info.json",
|
| 93 |
+
):
|
| 94 |
+
p = DATA_DIR / name
|
| 95 |
+
if p.is_file():
|
| 96 |
+
api.upload_file(
|
| 97 |
+
path_or_fileobj=str(p),
|
| 98 |
+
path_in_repo=name,
|
| 99 |
+
repo_id=repo_id,
|
| 100 |
+
repo_type="dataset",
|
| 101 |
+
commit_message="Add SFT files and metadata",
|
| 102 |
+
)
|
| 103 |
+
api.upload_file(
|
| 104 |
+
path_or_fileobj=readme.encode("utf-8"),
|
| 105 |
+
path_in_repo="README.md",
|
| 106 |
+
repo_id=repo_id,
|
| 107 |
+
repo_type="dataset",
|
| 108 |
+
commit_message="Add dataset README",
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
from datasets import DatasetDict, load_dataset
|
| 112 |
+
|
| 113 |
+
train = load_dataset("json", data_files=str(DATA_DIR / "sft_train.jsonl"))["train"]
|
| 114 |
+
evp = DATA_DIR / "sft_eval.jsonl"
|
| 115 |
+
if evp.is_file() and evp.stat().st_size > 0:
|
| 116 |
+
ev = load_dataset("json", data_files=str(evp))["train"]
|
| 117 |
+
else:
|
| 118 |
+
ev = train.select([])
|
| 119 |
+
dd = DatasetDict(train=train, eval=ev)
|
| 120 |
+
dd.push_to_hub(repo_id, private=False, token=token)
|
| 121 |
+
|
| 122 |
+
url = f"https://huggingface.co/datasets/{repo_id}"
|
| 123 |
+
print(url, flush=True)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
main()
|
training/rollout_sevzero.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SevZero multi-turn rollout helpers for TRL GRPO (sync API for rollout_func).
|
| 3 |
+
Builds chat prompts from observations and parses one JSON action per turn.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import textwrap
|
| 10 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
SRE_SYSTEM_PROMPT = textwrap.dedent(
|
| 13 |
+
"""\
|
| 14 |
+
You are an expert Site Reliability Engineer (SRE) responding to a production incident.
|
| 15 |
+
You are managing a microservice cluster experiencing failures.
|
| 16 |
+
Your goal: restore all services to healthy SLO compliance as efficiently as possible.
|
| 17 |
+
|
| 18 |
+
Respond with EXACTLY one JSON object — no explanation, no markdown, just raw JSON:
|
| 19 |
+
{"action_type": "...", "params": {...}}
|
| 20 |
+
|
| 21 |
+
Param rules (STRICT — single service only, never a list):
|
| 22 |
+
- inspect_logs / inspect_metrics / inspect_traces / restart_service / rollback_service / scale_service:
|
| 23 |
+
{"action_type": "X", "params": {"service_id": "order-service"}}
|
| 24 |
+
- tune_config:
|
| 25 |
+
{"action_type": "tune_config", "params": {"service_id": "order-service", "key": "api_endpoint", "value": "correct"}}
|
| 26 |
+
- clear_cache:
|
| 27 |
+
{"action_type": "clear_cache", "params": {"cache_name": "redis-cache"}}
|
| 28 |
+
- rebalance_traffic:
|
| 29 |
+
{"action_type": "rebalance_traffic", "params": {"from_region": "us-east-1", "to_region": "us-west-2"}}
|
| 30 |
+
- noop:
|
| 31 |
+
{"action_type": "noop", "params": {}}
|
| 32 |
+
"""
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_observation_prompt(obs: Dict[str, Any]) -> str:
|
| 37 |
+
"""Port of inference.build_observation_prompt (observation dict from HTTP JSON)."""
|
| 38 |
+
parts = [f"## Incident Status\n{obs.get('observation_summary', 'N/A')}"]
|
| 39 |
+
alerts = obs.get("alerts") or []
|
| 40 |
+
if alerts:
|
| 41 |
+
alert_lines = [f" [{a['severity'].upper()}] {a['message']}" for a in alerts[:10]]
|
| 42 |
+
parts.append("## Active Alerts\n" + "\n".join(alert_lines))
|
| 43 |
+
services = obs.get("services") or []
|
| 44 |
+
degraded = [s for s in services if s.get("status") in ("degraded", "critical", "down")]
|
| 45 |
+
if degraded:
|
| 46 |
+
svc_lines = []
|
| 47 |
+
for s in degraded:
|
| 48 |
+
sid = s["id"]
|
| 49 |
+
svc_lines.append(
|
| 50 |
+
f" {sid} [{s['status']}]: error={s['error_rate']:.1%}, "
|
| 51 |
+
f"p99={s['latency_p99_ms']:.0f}ms, cpu={s['cpu_pct']:.0f}%, "
|
| 52 |
+
f"mem={s['memory_pct']:.0f}%"
|
| 53 |
+
)
|
| 54 |
+
parts.append("## Degraded Services\n" + "\n".join(svc_lines))
|
| 55 |
+
deploys = obs.get("recent_deploys") or []
|
| 56 |
+
if deploys:
|
| 57 |
+
dep_lines = [f" {d['service']} -> {d['version']} ({d['ticks_ago']} ticks ago)" for d in deploys]
|
| 58 |
+
parts.append("## Recent Deploys\n" + "\n".join(dep_lines))
|
| 59 |
+
actions = obs.get("actions_taken") or []
|
| 60 |
+
if actions:
|
| 61 |
+
act_lines = [
|
| 62 |
+
f" tick {a['tick']}: {a['action']}({a.get('target', '')}) -> {'OK' if a['success'] else 'FAIL'}"
|
| 63 |
+
for a in actions[-5:]
|
| 64 |
+
]
|
| 65 |
+
parts.append("## Recent Actions\n" + "\n".join(act_lines))
|
| 66 |
+
logs = obs.get("logs")
|
| 67 |
+
if logs:
|
| 68 |
+
parts.append(f"## Logs\n{logs}")
|
| 69 |
+
traces = obs.get("traces")
|
| 70 |
+
if traces:
|
| 71 |
+
spans = (traces.get("spans") or []) if isinstance(traces, dict) else []
|
| 72 |
+
error_spans = [s for s in spans if s.get("status") == "ERROR"]
|
| 73 |
+
if error_spans:
|
| 74 |
+
trace_lines = [
|
| 75 |
+
f" {s.get('service')}: {s.get('tags', {}).get('error.message', 'ERROR')}"
|
| 76 |
+
for s in error_spans[:5]
|
| 77 |
+
]
|
| 78 |
+
parts.append("## Trace Errors\n" + "\n".join(trace_lines))
|
| 79 |
+
legal = obs.get("legal_actions") or []
|
| 80 |
+
if legal:
|
| 81 |
+
legal_strs = [f" {la.get('action_type', '')}: targets={la.get('valid_targets', [])[:5]}" for la in legal]
|
| 82 |
+
parts.append("## Available Actions\n" + "\n".join(legal_strs))
|
| 83 |
+
return "\n\n".join(parts)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def parse_action(response_text: str) -> Dict[str, Any]:
|
| 87 |
+
text = (response_text or "").strip()
|
| 88 |
+
if "```json" in text:
|
| 89 |
+
text = text.split("```json", 1)[1].split("```", 1)[0].strip()
|
| 90 |
+
elif "```" in text:
|
| 91 |
+
text = text.split("```", 1)[1].split("```", 1)[0].strip()
|
| 92 |
+
start, end = text.find("{"), text.rfind("}") + 1
|
| 93 |
+
if start >= 0 and end > start:
|
| 94 |
+
try:
|
| 95 |
+
return json.loads(text[start:end])
|
| 96 |
+
except json.JSONDecodeError:
|
| 97 |
+
pass
|
| 98 |
+
return {"action_type": "noop", "params": {}}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _normalize_action(action: Dict[str, Any]) -> Dict[str, Any]:
|
| 102 |
+
act_type = action.get("action_type", "noop")
|
| 103 |
+
params = dict(action.get("params") or {})
|
| 104 |
+
if "replicas" in params:
|
| 105 |
+
try:
|
| 106 |
+
params["replicas"] = int(params["replicas"])
|
| 107 |
+
except (TypeError, ValueError):
|
| 108 |
+
params["replicas"] = 2
|
| 109 |
+
return {"action_type": act_type, "params": params}
|
training/train_grpo.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GRPO on SevZero via TRL rollout_func + trl.experimental.openenv.generate_rollout_completions.
|
| 4 |
+
Verify API with Context7 before changing integration (rollout_func is required; environment_factory is deprecated).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import random
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any, Dict, List, Optional
|
| 16 |
+
|
| 17 |
+
_REPO = Path(__file__).resolve().parent.parent
|
| 18 |
+
if str(_REPO) not in sys.path:
|
| 19 |
+
sys.path.insert(0, str(_REPO))
|
| 20 |
+
|
| 21 |
+
from training.config_utils import try_load_env_files
|
| 22 |
+
|
| 23 |
+
try_load_env_files()
|
| 24 |
+
|
| 25 |
+
BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
|
| 26 |
+
METRICS_NAME = "metrics.jsonl"
|
| 27 |
+
|
| 28 |
+
# Pinned in README: trl, unsloth, vllm — orchestrator sets exact versions
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _parse_args() -> argparse.Namespace:
|
| 32 |
+
p = argparse.ArgumentParser()
|
| 33 |
+
p.add_argument("--output_dir", type=str, default="./outputs/grpo")
|
| 34 |
+
p.add_argument("--sft_adapter_repo", type=str, required=True, help="HF adapter repo (worker account)")
|
| 35 |
+
p.add_argument("--env_url", type=str, default="", help="Override; else SEVZERO_ENV_URL")
|
| 36 |
+
p.add_argument("--max_steps", type=int, default=350)
|
| 37 |
+
p.add_argument("--lr", type=float, default=7e-6)
|
| 38 |
+
p.add_argument("--K", type=int, default=4, dest="K", help="num_generations")
|
| 39 |
+
p.add_argument("--seed", type=int, default=42)
|
| 40 |
+
p.add_argument(
|
| 41 |
+
"--reward_shaping",
|
| 42 |
+
type=str,
|
| 43 |
+
default="dense_v1",
|
| 44 |
+
choices=("dense_v1", "dense_v2", "sparse"),
|
| 45 |
+
)
|
| 46 |
+
p.add_argument("--enable_schema_drift", action="store_true")
|
| 47 |
+
p.add_argument("--enable_curriculum", action="store_true")
|
| 48 |
+
p.add_argument("--enable_oversight", action="store_true")
|
| 49 |
+
p.add_argument(
|
| 50 |
+
"--task_mix",
|
| 51 |
+
type=str,
|
| 52 |
+
default="hard",
|
| 53 |
+
choices=("hard", "mixed", "curriculum"),
|
| 54 |
+
)
|
| 55 |
+
p.add_argument("--push_to_hub_repo", type=str, default="")
|
| 56 |
+
p.add_argument("--variant_name", type=str, default="grpo")
|
| 57 |
+
p.add_argument("--rollout_max_steps", type=int, default=0, help="0 = from env observation max_steps")
|
| 58 |
+
return p.parse_args()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _pick_task_id(args, idx: int, step: int) -> str:
|
| 62 |
+
if args.task_mix == "hard":
|
| 63 |
+
return "hard"
|
| 64 |
+
if args.task_mix == "mixed":
|
| 65 |
+
return ["easy", "medium", "hard"][idx % 3]
|
| 66 |
+
# curriculum: escalate every ~50 steps
|
| 67 |
+
if args.enable_curriculum:
|
| 68 |
+
tier = min(2, step // 50)
|
| 69 |
+
return ["easy", "medium", "hard"][tier]
|
| 70 |
+
return "hard"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _compute_episode_return(
|
| 74 |
+
shaping: str,
|
| 75 |
+
step_rewards: List[float],
|
| 76 |
+
grader: Optional[Dict[str, Any]],
|
| 77 |
+
) -> float:
|
| 78 |
+
if shaping == "sparse" and grader is not None:
|
| 79 |
+
return float(grader.get("score", 0.0))
|
| 80 |
+
if shaping == "dense_v2" and grader is not None:
|
| 81 |
+
# Slightly weight terminal score
|
| 82 |
+
s = sum(step_rewards) if step_rewards else 0.0
|
| 83 |
+
return 0.7 * s + 0.3 * float(grader.get("score", 0.0))
|
| 84 |
+
return float(sum(step_rewards)) if step_rewards else 0.0
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _build_default_dataset():
|
| 88 |
+
from datasets import Dataset
|
| 89 |
+
|
| 90 |
+
rows = []
|
| 91 |
+
for i in range(64):
|
| 92 |
+
text = (
|
| 93 |
+
"You are the on-call SRE. Restore service health. "
|
| 94 |
+
f"Incident session {i} — triage, diagnose root cause, remediate, verify."
|
| 95 |
+
)
|
| 96 |
+
rows.append({"text": text, "row_id": i})
|
| 97 |
+
return Dataset.from_list(rows)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _reward_from_env(completions, **kwargs):
|
| 101 |
+
r = kwargs.get("env_reward")
|
| 102 |
+
if r is None:
|
| 103 |
+
return [0.0] * len(completions)
|
| 104 |
+
return [float(x) for x in r]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def main() -> None:
|
| 108 |
+
args = _parse_args()
|
| 109 |
+
env_url = (args.env_url or os.environ.get("SEVZERO_ENV_URL", "")).rstrip("/")
|
| 110 |
+
if not env_url:
|
| 111 |
+
raise SystemExit("Set --env_url or SEVZERO_ENV_URL to the remote SevZero HTTP base URL")
|
| 112 |
+
|
| 113 |
+
worker_token = os.environ.get("HF_TOKEN", "")
|
| 114 |
+
main_token = os.environ.get("HF_MAIN_TOKEN", "")
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
import trackio
|
| 118 |
+
|
| 119 |
+
trackio.init(
|
| 120 |
+
project="sevzero-grpo",
|
| 121 |
+
space_id="Mist-ic/sevzero-trackio",
|
| 122 |
+
**({"hf_token": main_token} if main_token else {}),
|
| 123 |
+
)
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f"trackio init skipped: {e}", flush=True)
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
from unsloth import FastLanguageModel, PatchFastRL
|
| 129 |
+
except ImportError as e:
|
| 130 |
+
raise SystemExit(
|
| 131 |
+
f"unsloth is required for GRPO on this path: {e}\n"
|
| 132 |
+
"Install training extras, or on unsupported platforms set UNSLOTH_DISABLE=1 and extend train_grpo."
|
| 133 |
+
) from e
|
| 134 |
+
|
| 135 |
+
PatchFastRL(algorithm="grpo", FastLanguageModel=FastLanguageModel)
|
| 136 |
+
|
| 137 |
+
from peft import PeftModel
|
| 138 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 139 |
+
from trl.experimental.openenv import generate_rollout_completions
|
| 140 |
+
|
| 141 |
+
from training.env_client import AsyncSevZeroEnvClient, run_async
|
| 142 |
+
from training.rollout_sevzero import (
|
| 143 |
+
SRE_SYSTEM_PROMPT,
|
| 144 |
+
build_observation_prompt,
|
| 145 |
+
parse_action,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
max_seq = 4096
|
| 149 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 150 |
+
model_name=BASE_MODEL,
|
| 151 |
+
max_seq_length=max_seq,
|
| 152 |
+
dtype=None,
|
| 153 |
+
load_in_4bit=True,
|
| 154 |
+
)
|
| 155 |
+
model = PeftModel.from_pretrained(model, args.sft_adapter_repo, token=worker_token or None)
|
| 156 |
+
# Optional env flags (future env upgrades) — no-op for baseline server
|
| 157 |
+
if args.enable_schema_drift:
|
| 158 |
+
os.environ["SEVZERO_SCHEMA_DRIFT"] = "1"
|
| 159 |
+
if args.enable_oversight:
|
| 160 |
+
os.environ["SEVZERO_OVERSIGHT"] = "1"
|
| 161 |
+
|
| 162 |
+
metrics_path = Path(args.output_dir) / METRICS_NAME
|
| 163 |
+
metrics_path.parent.mkdir(parents=True, exist_ok=True)
|
| 164 |
+
|
| 165 |
+
# Capture trainer ref for step index in seeding
|
| 166 |
+
_trainer_holder: List[Any] = [None]
|
| 167 |
+
_global_episode: List[int] = [0]
|
| 168 |
+
|
| 169 |
+
def rollout_func(prompts: List[str], trainer) -> Dict[str, List[Any]]:
|
| 170 |
+
_trainer_holder[0] = trainer
|
| 171 |
+
episode_prompt_ids: List[List[int]] = []
|
| 172 |
+
episode_completion_ids: List[List[int]] = []
|
| 173 |
+
episode_logprobs: List[List[float]] = []
|
| 174 |
+
env_rewards: List[float] = []
|
| 175 |
+
tkn = os.environ.get("HF_TOKEN", "") # for private Space
|
| 176 |
+
for batch_idx, prompt_text in enumerate(prompts):
|
| 177 |
+
tr = _trainer_holder[0]
|
| 178 |
+
state = getattr(tr, "state", None) if tr else None
|
| 179 |
+
step = getattr(state, "global_step", 0) if state else 0
|
| 180 |
+
_global_episode[0] += 1
|
| 181 |
+
task_id = _pick_task_id(args, batch_idx, step)
|
| 182 |
+
seed = 13 + (batch_idx * 997) + (step * 13) + _global_episode[0] + random.randint(0, 1_000_000) % 100_000
|
| 183 |
+
|
| 184 |
+
async def _one_ep() -> tuple:
|
| 185 |
+
client = AsyncSevZeroEnvClient(env_url, token=tkn or None)
|
| 186 |
+
try:
|
| 187 |
+
p_ids: List[int] = []
|
| 188 |
+
c_ids: List[int] = []
|
| 189 |
+
lps: List[float] = []
|
| 190 |
+
step_rewards: List[float] = []
|
| 191 |
+
ro = await client.reset(task_id=task_id, seed=seed)
|
| 192 |
+
obs = ro.get("observation", ro)
|
| 193 |
+
done = ro.get("done", False)
|
| 194 |
+
grader: Optional[Dict[str, Any]] = None
|
| 195 |
+
user_prefix = f"{prompt_text}\n\n## Session\n"
|
| 196 |
+
for _t in range(args.rollout_max_steps or int(obs.get("max_steps", 20))):
|
| 197 |
+
if done:
|
| 198 |
+
break
|
| 199 |
+
user_msg = build_observation_prompt(obs)
|
| 200 |
+
messages = [
|
| 201 |
+
{"role": "system", "content": SRE_SYSTEM_PROMPT},
|
| 202 |
+
{"role": "user", "content": user_prefix + user_msg},
|
| 203 |
+
]
|
| 204 |
+
p_text = tokenizer.apply_chat_template(
|
| 205 |
+
messages, add_generation_prompt=True, tokenize=False,
|
| 206 |
+
)
|
| 207 |
+
out = generate_rollout_completions(tr, [p_text])[0]
|
| 208 |
+
p_ids.extend(out.get("prompt_ids", []))
|
| 209 |
+
c_ids.extend(out.get("completion_ids", []))
|
| 210 |
+
lps.extend(out.get("logprobs", []))
|
| 211 |
+
gen_ids = out.get("completion_ids", [])
|
| 212 |
+
raw = out.get("text")
|
| 213 |
+
if not raw and gen_ids:
|
| 214 |
+
raw = tokenizer.decode(gen_ids, skip_special_tokens=True)
|
| 215 |
+
action = parse_action(raw or "")
|
| 216 |
+
step_payload = {
|
| 217 |
+
"action_type": str(action.get("action_type", "noop")),
|
| 218 |
+
"params": action.get("params") or {},
|
| 219 |
+
}
|
| 220 |
+
sr = await client.step({"action": step_payload})
|
| 221 |
+
obs = sr.get("observation", sr)
|
| 222 |
+
done = sr.get("done", False)
|
| 223 |
+
r = float(obs.get("reward", sr.get("reward", 0.0) or 0.0))
|
| 224 |
+
step_rewards.append(r)
|
| 225 |
+
st = await client.get_state()
|
| 226 |
+
max_st = int(obs.get("max_steps", 10))
|
| 227 |
+
try:
|
| 228 |
+
grader = await client.grade_episode(
|
| 229 |
+
final_slo_score=float(st.get("global_slo_score", 0.0)),
|
| 230 |
+
steps_taken=int(st.get("step_count", 0)),
|
| 231 |
+
max_steps=max_st,
|
| 232 |
+
actions_taken=list(obs.get("actions_taken", [])),
|
| 233 |
+
terminated=bool(st.get("terminated", True)),
|
| 234 |
+
termination_reason=st.get("termination_reason"),
|
| 235 |
+
)
|
| 236 |
+
except Exception:
|
| 237 |
+
grader = None
|
| 238 |
+
R = _compute_episode_return(args.reward_shaping, step_rewards, grader)
|
| 239 |
+
return p_ids, c_ids, lps, R
|
| 240 |
+
finally:
|
| 241 |
+
await client.aclose()
|
| 242 |
+
|
| 243 |
+
p_ids, c_ids, lps, r_ep = run_async(_one_ep())
|
| 244 |
+
episode_prompt_ids.append(p_ids)
|
| 245 |
+
episode_completion_ids.append(c_ids)
|
| 246 |
+
episode_logprobs.append(lps)
|
| 247 |
+
env_rewards.append(r_ep)
|
| 248 |
+
return {
|
| 249 |
+
"prompt_ids": episode_prompt_ids,
|
| 250 |
+
"completion_ids": episode_completion_ids,
|
| 251 |
+
"logprobs": episode_logprobs,
|
| 252 |
+
"env_reward": env_rewards,
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
grpo = GRPOConfig(
|
| 256 |
+
output_dir=args.output_dir,
|
| 257 |
+
learning_rate=args.lr,
|
| 258 |
+
per_device_train_batch_size=1,
|
| 259 |
+
gradient_accumulation_steps=8,
|
| 260 |
+
max_completion_length=1024,
|
| 261 |
+
num_train_epochs=1,
|
| 262 |
+
max_steps=args.max_steps,
|
| 263 |
+
num_generations=args.K,
|
| 264 |
+
temperature=0.85,
|
| 265 |
+
max_prompt_length=4096,
|
| 266 |
+
beta=0.04,
|
| 267 |
+
lr_scheduler_type="cosine",
|
| 268 |
+
use_vllm=True,
|
| 269 |
+
vllm_mode="colocate",
|
| 270 |
+
vllm_gpu_memory_utilization=0.55,
|
| 271 |
+
report_to="trackio",
|
| 272 |
+
logging_steps=1,
|
| 273 |
+
save_steps=100,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
train_ds = _build_default_dataset()
|
| 277 |
+
|
| 278 |
+
trainer = GRPOTrainer(
|
| 279 |
+
model=model,
|
| 280 |
+
processing_class=tokenizer,
|
| 281 |
+
args=grpo,
|
| 282 |
+
train_dataset=train_ds,
|
| 283 |
+
reward_funcs=[_reward_from_env],
|
| 284 |
+
rollout_func=rollout_func,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
from transformers import TrainerCallback
|
| 288 |
+
|
| 289 |
+
class _MetricsJSONL(TrainerCallback):
|
| 290 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 291 |
+
if not logs:
|
| 292 |
+
return
|
| 293 |
+
rec = {
|
| 294 |
+
"step": state.global_step,
|
| 295 |
+
"reward_mean": logs.get("rewards", logs.get("reward", None)),
|
| 296 |
+
"reward_std": logs.get("reward_std", None),
|
| 297 |
+
"kl": logs.get("kl", None),
|
| 298 |
+
"entropy": logs.get("entropy", None),
|
| 299 |
+
"grad_norm": logs.get("grad_norm", None),
|
| 300 |
+
"loss": logs.get("loss", None),
|
| 301 |
+
"frac_reward_zero_std": logs.get("frac_reward_zero", logs.get("frac_reward_zero_std", None)),
|
| 302 |
+
"lr": logs.get("learning_rate", None),
|
| 303 |
+
}
|
| 304 |
+
with metrics_path.open("a", encoding="utf-8") as f:
|
| 305 |
+
f.write(json.dumps(rec, default=str) + "\n")
|
| 306 |
+
print(json.dumps({"type": "grpo", **rec}, default=str), flush=True)
|
| 307 |
+
|
| 308 |
+
trainer.add_callback(_MetricsJSONL())
|
| 309 |
+
trainer.train()
|
| 310 |
+
|
| 311 |
+
if args.push_to_hub_repo:
|
| 312 |
+
model.push_to_hub(args.push_to_hub_repo, token=worker_token or None, private=True)
|
| 313 |
+
tokenizer.push_to_hub(args.push_to_hub_repo, token=worker_token or None, private=True)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
if __name__ == "__main__":
|
| 317 |
+
main()
|
training/train_sft.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
SFT warmup: QLoRA on Mist-ic/sevzero-expert-trajectories (see training/data/HANDOFF.md).
|
| 4 |
+
Target TRL / Unsloth versions: see comments after `pip index` in training/README.md.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
_REPO = Path(__file__).resolve().parent.parent
|
| 16 |
+
if str(_REPO) not in sys.path:
|
| 17 |
+
sys.path.insert(0, str(_REPO))
|
| 18 |
+
|
| 19 |
+
from training.config_utils import try_load_env_files
|
| 20 |
+
|
| 21 |
+
try_load_env_files()
|
| 22 |
+
|
| 23 |
+
# --- Pin guidance (orchestrator resolves exact pins): trl>=0.22, unsloth, bitsandbytes, peft, accelerate
|
| 24 |
+
BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
|
| 25 |
+
DATASET_ID = "Mist-ic/sevzero-expert-trajectories"
|
| 26 |
+
DEFAULT_MAX_SEQ = 2048
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _parse_args() -> argparse.Namespace:
|
| 30 |
+
p = argparse.ArgumentParser()
|
| 31 |
+
p.add_argument("--output_dir", type=str, default="./outputs/sft")
|
| 32 |
+
p.add_argument("--max_steps", type=int, default=250)
|
| 33 |
+
p.add_argument("--lr", type=float, default=1e-5)
|
| 34 |
+
p.add_argument("--seed", type=int, default=42)
|
| 35 |
+
p.add_argument("--push_to_hub_repo", type=str, default="", help="e.g. PhaseOfCode/sevzero-llama3-8b-sft")
|
| 36 |
+
p.add_argument("--variant_name", type=str, default="default")
|
| 37 |
+
p.add_argument("--max_seq_length", type=int, default=0, help="0 = read HANDOFF / 2048")
|
| 38 |
+
return p.parse_args()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _read_default_max_seq() -> int:
|
| 42 |
+
handoff = _REPO / "training" / "data" / "HANDOFF.md"
|
| 43 |
+
if not handoff.is_file():
|
| 44 |
+
return DEFAULT_MAX_SEQ
|
| 45 |
+
text = handoff.read_text(encoding="utf-8", errors="ignore")
|
| 46 |
+
for line in text.splitlines():
|
| 47 |
+
if "max_seq" in line.lower() and "`" in line:
|
| 48 |
+
try:
|
| 49 |
+
return int(line.split("`")[1])
|
| 50 |
+
except (ValueError, IndexError):
|
| 51 |
+
pass
|
| 52 |
+
return DEFAULT_MAX_SEQ
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _format_row_to_text(row: dict, tokenizer) -> str:
|
| 56 |
+
"""Support 'text' column or OpenAI-style messages JSON."""
|
| 57 |
+
if "text" in row and row["text"]:
|
| 58 |
+
return str(row["text"])
|
| 59 |
+
if "messages" in row and row["messages"]:
|
| 60 |
+
msgs = row["messages"]
|
| 61 |
+
if isinstance(msgs, str):
|
| 62 |
+
import json as _j
|
| 63 |
+
|
| 64 |
+
msgs = _j.loads(msgs)
|
| 65 |
+
return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
|
| 66 |
+
raise ValueError("Dataset row must have 'text' or 'messages'")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def main() -> None:
|
| 70 |
+
args = _parse_args()
|
| 71 |
+
max_seq = args.max_seq_length or _read_default_max_seq()
|
| 72 |
+
|
| 73 |
+
worker_token = os.environ.get("HF_TOKEN", "")
|
| 74 |
+
main_token = os.environ.get("HF_MAIN_TOKEN", "")
|
| 75 |
+
if not worker_token:
|
| 76 |
+
print("warning: HF_TOKEN not set — Hub push and model download may fail.", flush=True)
|
| 77 |
+
|
| 78 |
+
# Trackio with main account (read-only space) while training pushes use HF_TOKEN
|
| 79 |
+
try:
|
| 80 |
+
import trackio
|
| 81 |
+
|
| 82 |
+
if main_token:
|
| 83 |
+
os.environ.setdefault("HF_TOKEN", worker_token)
|
| 84 |
+
trackio.init(
|
| 85 |
+
project="sevzero-sft",
|
| 86 |
+
space_id="Mist-ic/sevzero-trackio",
|
| 87 |
+
**({"hf_token": main_token} if main_token else {}),
|
| 88 |
+
)
|
| 89 |
+
except Exception as e:
|
| 90 |
+
print(f"trackio init skipped: {e}", flush=True)
|
| 91 |
+
|
| 92 |
+
from datasets import load_dataset
|
| 93 |
+
from transformers import TrainingArguments
|
| 94 |
+
from trl import SFTConfig, SFTTrainer
|
| 95 |
+
|
| 96 |
+
ds = load_dataset(DATASET_ID, split="train")
|
| 97 |
+
|
| 98 |
+
use_unsloth = os.environ.get("UNSLOTH_DISABLE", "").lower() not in ("1", "true", "yes")
|
| 99 |
+
model = None
|
| 100 |
+
tokenizer = None
|
| 101 |
+
|
| 102 |
+
if use_unsloth:
|
| 103 |
+
try:
|
| 104 |
+
from unsloth import FastLanguageModel
|
| 105 |
+
|
| 106 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 107 |
+
model_name=BASE_MODEL,
|
| 108 |
+
max_seq_length=max_seq,
|
| 109 |
+
dtype=None,
|
| 110 |
+
load_in_4bit=True,
|
| 111 |
+
)
|
| 112 |
+
target_modules = [
|
| 113 |
+
"q_proj",
|
| 114 |
+
"k_proj",
|
| 115 |
+
"v_proj",
|
| 116 |
+
"o_proj",
|
| 117 |
+
"gate_proj",
|
| 118 |
+
"up_proj",
|
| 119 |
+
"down_proj",
|
| 120 |
+
]
|
| 121 |
+
model = FastLanguageModel.get_peft_model(
|
| 122 |
+
model,
|
| 123 |
+
r=32,
|
| 124 |
+
lora_alpha=64,
|
| 125 |
+
lora_dropout=0.0,
|
| 126 |
+
target_modules=target_modules,
|
| 127 |
+
use_gradient_checkpointing="unsloth",
|
| 128 |
+
)
|
| 129 |
+
except Exception as e:
|
| 130 |
+
print(f"Unsloth path failed ({e}), falling back to PEFT+bnb.", flush=True)
|
| 131 |
+
use_unsloth = False
|
| 132 |
+
|
| 133 |
+
if not use_unsloth:
|
| 134 |
+
import torch
|
| 135 |
+
from peft import LoraConfig, get_peft_model
|
| 136 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 137 |
+
|
| 138 |
+
bnb = BitsAndBytesConfig(
|
| 139 |
+
load_in_4bit=True,
|
| 140 |
+
bnb_4bit_quant_type="nf4",
|
| 141 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 142 |
+
)
|
| 143 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
|
| 144 |
+
if tokenizer.pad_token is None:
|
| 145 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 146 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 147 |
+
BASE_MODEL,
|
| 148 |
+
quantization_config=bnb,
|
| 149 |
+
device_map="auto",
|
| 150 |
+
torch_dtype=torch.bfloat16,
|
| 151 |
+
)
|
| 152 |
+
lora = LoraConfig(
|
| 153 |
+
r=32,
|
| 154 |
+
lora_alpha=64,
|
| 155 |
+
lora_dropout=0.0,
|
| 156 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 157 |
+
task_type="CAUSAL_LM",
|
| 158 |
+
)
|
| 159 |
+
model = get_peft_model(model, lora)
|
| 160 |
+
|
| 161 |
+
def formatting_prompts(examples: dict) -> dict:
|
| 162 |
+
texts = []
|
| 163 |
+
n = len(next(iter(examples.values())))
|
| 164 |
+
keys = list(examples.keys())
|
| 165 |
+
for i in range(n):
|
| 166 |
+
row = {k: (examples[k][i] if k in examples else None) for k in keys}
|
| 167 |
+
texts.append(_format_row_to_text(row, tokenizer))
|
| 168 |
+
return {"text": texts}
|
| 169 |
+
|
| 170 |
+
cols = ds.column_names
|
| 171 |
+
if "text" not in ds.column_names:
|
| 172 |
+
if "messages" in ds.column_names:
|
| 173 |
+
ds = ds.map(
|
| 174 |
+
formatting_prompts,
|
| 175 |
+
batched=True,
|
| 176 |
+
remove_columns=[c for c in cols if c not in ("messages",)],
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
raise ValueError("Dataset must include a 'text' or 'messages' column")
|
| 180 |
+
targs = SFTConfig(
|
| 181 |
+
output_dir=args.output_dir,
|
| 182 |
+
max_steps=args.max_steps,
|
| 183 |
+
learning_rate=args.lr,
|
| 184 |
+
per_device_train_batch_size=4,
|
| 185 |
+
gradient_accumulation_steps=8,
|
| 186 |
+
warmup_ratio=0.05,
|
| 187 |
+
lr_scheduler_type="cosine",
|
| 188 |
+
optim="paged_adamw_8bit",
|
| 189 |
+
bf16=True,
|
| 190 |
+
seed=args.seed,
|
| 191 |
+
logging_steps=1,
|
| 192 |
+
report_to="trackio",
|
| 193 |
+
save_total_limit=2,
|
| 194 |
+
max_seq_length=max_seq,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
from transformers import TrainerCallback
|
| 198 |
+
|
| 199 |
+
class JsonStepLog(TrainerCallback):
|
| 200 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 201 |
+
if not logs:
|
| 202 |
+
return
|
| 203 |
+
payload = {
|
| 204 |
+
"type": "sft_step",
|
| 205 |
+
"step": state.global_step,
|
| 206 |
+
"loss": logs.get("loss"),
|
| 207 |
+
"lr": logs.get("learning_rate"),
|
| 208 |
+
}
|
| 209 |
+
print(json.dumps(payload, default=str), flush=True)
|
| 210 |
+
|
| 211 |
+
trainer = SFTTrainer(
|
| 212 |
+
model=model,
|
| 213 |
+
processing_class=tokenizer,
|
| 214 |
+
args=targs,
|
| 215 |
+
train_dataset=ds,
|
| 216 |
+
dataset_text_field="text",
|
| 217 |
+
callbacks=[JsonStepLog()],
|
| 218 |
+
)
|
| 219 |
+
trainer.train()
|
| 220 |
+
|
| 221 |
+
if args.push_to_hub_repo:
|
| 222 |
+
print(json.dumps({"event": "push_to_hub", "repo": args.push_to_hub_repo}, default=str), flush=True)
|
| 223 |
+
model.push_to_hub(
|
| 224 |
+
args.push_to_hub_repo,
|
| 225 |
+
token=worker_token or None,
|
| 226 |
+
private=True,
|
| 227 |
+
)
|
| 228 |
+
tokenizer.push_to_hub(
|
| 229 |
+
args.push_to_hub_repo,
|
| 230 |
+
token=worker_token or None,
|
| 231 |
+
private=True,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
if __name__ == "__main__":
|
| 236 |
+
main()
|