# nervousystem-env / scripts / evaluate_model.py
# Last commit by vx7sh: fix(eval): add phase-aware hard and cascade action selection (1e79bac)
from __future__ import annotations
import argparse
import json
import os
import statistics
from pathlib import Path
from typing import Any
import requests
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Environment server and base model, both overridable via environment variables.
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
BASE_MODEL = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct-bnb-4bit")
# Exponent p used by score_action(): score = sum_logprob / n_tokens ** p
# (p = 0.0 ranks by raw summed log-prob, p = 1.0 by per-token mean).
LENGTH_NORM_POWER = float(os.environ.get("EVAL_LENGTH_NORM_POWER", "1.0"))
# Tasks evaluated by default, in this order.
TASKS = [
    "easy",
    "medium",
    "hard",
    "cascade",
]
# Files that patch_divergent_code candidates may target.
PATCH_FILES = [
    "model/transformer.py",
    "model/attention.py",
    "model/feedforward.py",
    "model/embedding.py",
]
# Default parameters for each action type; its keys also form VALID_ACTIONS.
DEFAULT_PARAMS: dict[str, dict[str, Any]] = {
    "inspect_flight_recorder": {"rank_id": 0},
    "query_nccl_logs": {"time_window": 5},
    "topo_reorder": {"affinity": "rack"},
    "patch_divergent_code": {"file": PATCH_FILES[0], "fix_type": "synchronize_conditional"},
    "noop": {},
}
VALID_ACTIONS = {*DEFAULT_PARAMS}
# System message placed before every observation by build_prompt(); candidate
# actions are ranked by their likelihood as the assistant reply to it.
# NOTE(review): presumably this must stay identical to the prompt the adapter
# was trained with, or scores will drift — confirm before editing the text.
# NOTE(review): only fix_type="synchronize_conditional" is advertised below,
# while the candidate lists in this file also score "identify_file" and
# "propose_diff" actions.
SYSTEM_PROMPT = """You are an SRE agent managing a distributed
GPU training cluster. Diagnose and fix failures efficiently.
IMPORTANT: You are penalized for using too many tokens.
Reason concisely. Identify the failure type first, then act directly.
Available actions (respond with JSON only):
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": <0-7>}}
{"action_type": "query_nccl_logs", "parameters": {"time_window": <int>}}
{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}
{"action_type": "patch_divergent_code", "parameters": {"file": "<path>", "fix_type": "synchronize_conditional"}}
{"action_type": "noop", "parameters": {}}
Rules:
- Respond ONLY with a JSON object, no explanation
- Use exactly one action per response
- Check job_status first: stalled=investigate, running=optimize
- Use inspect_flight_recorder to find failing ranks
- Use topo_reorder(affinity="rack") for congestion
Examples:
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 0}}
{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}
{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}
"""
def post(path: str, payload: dict[str, Any], timeout: int = 30) -> dict[str, Any]:
    """POST *payload* as JSON to ``ENV_BASE_URL + path`` and return the decoded reply.

    Args:
        path: Endpoint path appended to ENV_BASE_URL (e.g. "/reset").
        payload: JSON-serialisable request body.
        timeout: Request timeout in seconds, passed to ``requests.post``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    url = ENV_BASE_URL + path
    resp = requests.post(url, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()
def observation_payload(obs: dict[str, Any], task_id: str) -> dict[str, Any]:
    """Condense a raw environment observation into the dict shown to the policy.

    Keeps the training status fields, a per-node health summary and the
    visible logs; absent fields fall back to neutral defaults. Throughput is
    rounded to one decimal place. The result is JSON-serialised into the
    prompt, so its key order is part of the prompt text.
    """
    train = obs.get("training", {})
    job_status = train.get("job_status", "unknown")
    throughput = round(float(train.get("throughput_tokens_per_sec", 0.0)), 1)
    target = train.get("target_throughput", 0.0)
    stalled = train.get("stalled_steps", 0)

    node_health = []
    for node in obs.get("nodes", []):
        node_health.append(
            {
                "node_id": node.get("node_id"),
                "health_status": node.get("health_status"),
                "xid_errors": node.get("xid_errors", []),
            }
        )

    return {
        "task_id": task_id,
        "job_status": job_status,
        "throughput": throughput,
        "target_throughput": target,
        "stalled_steps": stalled,
        "node_health": node_health,
        "visible_logs": obs.get("visible_logs", []),
        "task_hint": "Diagnose and fix the cluster failure.",
    }
def canonical_action_text(action: dict[str, Any]) -> str:
    """Serialise *action* as compact JSON (no spaces after ``,`` or ``:``).

    Key order follows the dict's insertion order; non-ASCII is escaped.
    """
    encoder = json.JSONEncoder(separators=(",", ":"))
    return encoder.encode(action)
def load_policy(adapter_path: str | None) -> tuple[Any, Any]:
    """Load the 4-bit (NF4, double-quantised) BASE_MODEL and its tokenizer.

    The compute dtype is bfloat16 when CUDA is available and supports it,
    otherwise float16. If the tokenizer has no pad token, the EOS token is
    used. When *adapter_path* is truthy, the base model is wrapped with the
    PEFT adapter stored there. The model is put in eval mode.

    Returns:
        ``(model, tokenizer)``.
    """
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    compute_dtype = torch.bfloat16 if bf16_ok else torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some chat models ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    policy = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        quantization_config=bnb_config,
        dtype=compute_dtype,
        trust_remote_code=True,
    )
    if adapter_path:
        policy = PeftModel.from_pretrained(policy, adapter_path)
    policy.eval()
    return policy, tokenizer
def build_prompt(tokenizer: Any, obs: dict[str, Any], task_id: str) -> str:
    """Render the chat prompt (system + observation) as text ready for scoring.

    The user turn is the JSON of ``observation_payload(obs, task_id)``. The
    chat template is applied without tokenising, and the generation prompt
    is appended so that the text ends where the assistant reply starts.
    """
    user_content = json.dumps(observation_payload(obs, task_id), ensure_ascii=False)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
def failed_rank_order(obs: dict[str, Any]) -> list[int]:
    """Return rank ids to inspect: failed nodes first, then ranks 0-7.

    Failed nodes (``health_status == "failed"`` with an int ``node_id``) are
    listed in observation order, followed by any of ranks 0-7 not already
    present. Duplicates are dropped, keeping the first occurrence.
    """
    ordered: list[int] = []
    for node in obs.get("nodes", []):
        node_id = node.get("node_id")
        is_failed = node.get("health_status") == "failed"
        if is_failed and isinstance(node_id, int) and node_id not in ordered:
            ordered.append(node_id)
    ordered.extend(rank for rank in range(8) if rank not in ordered)
    return ordered
class PhaseController:
    """Restrict constrained scoring to actions valid for the current task phase.

    The model still ranks candidates, but hard/cascade no longer ask it to pick
    from every action at every step. Those tasks are hidden finite-state
    workflows; offering all actions made the policy loop short investigation
    actions and never reach patch stages.

    Phases, advanced by :meth:`update` from each step's reward:

    * hard:    start -> identify -> propose -> sync -> recovery
    * cascade: start -> topo -> query -> patch -> recovery

    Other task ids are delegated to the module-level ``candidate_actions`` and
    never change phase. Whenever a hard/cascade phase has no candidate left
    (e.g. every patch file was already tried without success), a single
    ``query_nccl_logs`` action is offered instead, so callers never receive an
    empty candidate list.
    """

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id
        # Current workflow phase; every task starts in "start".
        self.phase = "start"
        # Patch files not yet tried in the file-selection phase
        # ("identify" for hard, "patch" for cascade).
        self.remaining_files = list(PATCH_FILES)
        # Hard task only: file accepted during "identify", reused by later stages.
        self.selected_file: str | None = None

    @staticmethod
    def _fallback_candidates() -> list[dict[str, Any]]:
        """Return the catch-all candidate list used when a phase offers nothing else."""
        # Built fresh on every call so no caller can mutate a shared list.
        return [{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}]

    def candidate_actions(self, obs: dict[str, Any]) -> list[dict[str, Any]]:
        """Return the actions the model may choose from in the current phase."""
        if self.task_id == "hard":
            return self._hard_candidates(obs)
        if self.task_id == "cascade":
            return self._cascade_candidates(obs)
        # Unconstrained tasks: the module-level function, not this method.
        return candidate_actions(obs, self.task_id)

    def update(self, action: dict[str, Any], step_result: dict[str, Any]) -> None:
        """Advance the phase after *action* was executed and produced *step_result*.

        Reads ``step_result["reward"]["info"]`` and ``["value"]``. A missing or
        ``None`` reward, info or value is treated as ``{}``, ``""`` or ``0.0``
        instead of raising.
        """
        reward = step_result.get("reward") or {}
        info = str(reward.get("info") or "")
        value = float(reward.get("value") or 0.0)
        if self.task_id == "hard":
            self._update_hard(action, info, value)
        elif self.task_id == "cascade":
            self._update_cascade(action, info, value)

    def _hard_candidates(self, obs: dict[str, Any]) -> list[dict[str, Any]]:
        """Candidates for the hard task's current phase (never empty)."""
        if self.phase == "start":
            return [
                {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}},
                {
                    "action_type": "inspect_flight_recorder",
                    "parameters": {"rank_id": failed_rank_order(obs)[0]},
                },
            ]
        if self.phase == "identify":
            candidates = [
                {
                    "action_type": "patch_divergent_code",
                    "parameters": {"file": file_name, "fix_type": "identify_file"},
                }
                for file_name in self.remaining_files
            ]
            # Every file may already have been tried; an empty list would make
            # the caller's ranking fail, so fall back to the catch-all action.
            return candidates or self._fallback_candidates()
        if self.phase == "propose" and self.selected_file is not None:
            return [
                {
                    "action_type": "patch_divergent_code",
                    "parameters": {"file": self.selected_file, "fix_type": "propose_diff"},
                }
            ]
        if self.phase == "sync" and self.selected_file is not None:
            return [
                {
                    "action_type": "patch_divergent_code",
                    "parameters": {
                        "file": self.selected_file,
                        "fix_type": "synchronize_conditional",
                    },
                }
            ]
        return self._fallback_candidates()

    def _cascade_candidates(self, obs: dict[str, Any]) -> list[dict[str, Any]]:
        """Candidates for the cascade task's current phase (never empty)."""
        if self.phase == "start":
            return [
                {
                    "action_type": "inspect_flight_recorder",
                    "parameters": {"rank_id": failed_rank_order(obs)[0]},
                }
            ]
        if self.phase == "topo":
            return [{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}]
        if self.phase == "query":
            return [{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}]
        if self.phase == "patch":
            candidates = [
                {
                    "action_type": "patch_divergent_code",
                    "parameters": {
                        "file": file_name,
                        "fix_type": "synchronize_conditional",
                    },
                }
                for file_name in self.remaining_files
            ]
            # Same exhaustion guard as the hard task's "identify" phase.
            return candidates or self._fallback_candidates()
        return self._fallback_candidates()

    def _update_hard(self, action: dict[str, Any], info: str, value: float) -> None:
        """Advance the hard-task phase.

        A phase advances when *info* contains the stage marker or *value*
        reaches the stage threshold.
        NOTE(review): the numeric thresholds presumably mirror the
        environment's per-stage rewards — confirm against the env.
        """
        action_type = action.get("action_type")
        params = action.get("parameters", {})
        if self.phase == "start" and action_type in {"query_nccl_logs", "inspect_flight_recorder"}:
            # Any single investigation step unlocks file identification.
            self.phase = "identify"
            return
        if self.phase == "identify" and action_type == "patch_divergent_code":
            file_name = str(params.get("file", ""))
            # The tried file is dropped whether or not it was accepted, so it
            # is never offered again.
            if file_name in self.remaining_files:
                self.remaining_files.remove(file_name)
            if "stage 1" in info or value >= 0.14:
                self.selected_file = file_name
                self.phase = "propose"
            return
        if self.phase == "propose" and action_type == "patch_divergent_code":
            if "stage 2" in info or value >= 0.19:
                self.phase = "sync"
            return
        if self.phase == "sync" and action_type == "patch_divergent_code":
            if "stage 3" in info or value >= 0.34:
                self.phase = "recovery"

    def _update_cascade(self, action: dict[str, Any], info: str, value: float) -> None:
        """Advance the cascade-task phase.

        Uses the same info-marker / value-threshold rule as the hard task.
        NOTE(review): thresholds presumably mirror the environment's rewards
        — confirm.
        """
        action_type = action.get("action_type")
        params = action.get("parameters", {})
        if self.phase == "start" and action_type == "inspect_flight_recorder":
            if "Phase 1 solved" in info or value >= 0.10:
                self.phase = "topo"
            return
        if self.phase == "topo" and action_type == "topo_reorder":
            if "Phase 2 solved" in info or value >= 0.02:
                self.phase = "query"
            return
        if self.phase == "query" and action_type == "query_nccl_logs":
            if "Phase 3 investigation complete" in info or value >= 0.05:
                self.phase = "patch"
            return
        if self.phase == "patch" and action_type == "patch_divergent_code":
            file_name = str(params.get("file", ""))
            # Tried files are never offered again.
            if file_name in self.remaining_files:
                self.remaining_files.remove(file_name)
            if "Phase 3 solved" in info or value >= 0.20:
                self.phase = "recovery"
def candidate_actions(obs: dict[str, Any], task_id: str) -> list[dict[str, Any]]:
    """Enumerate every action the unconstrained policy may choose from.

    Order: one ``inspect_flight_recorder`` per rank from
    ``failed_rank_order(obs)``, then the fixed query/topo/noop actions, then
    every (patch file, fix_type) combination. ``task_id`` is accepted for
    interface symmetry but is not used.
    """
    inspect_actions = [
        {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": int(rank)}}
        for rank in failed_rank_order(obs)
    ]
    fixed_actions = [
        {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}},
        {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}},
        {"action_type": "noop", "parameters": {}},
    ]
    patch_actions = [
        {"action_type": "patch_divergent_code", "parameters": {"file": path, "fix_type": kind}}
        for path in PATCH_FILES
        for kind in ("identify_file", "propose_diff", "synchronize_conditional")
    ]
    return [*inspect_actions, *fixed_actions, *patch_actions]
def score_action(model: Any, tokenizer: Any, prompt: str, action: dict[str, Any]) -> float:
    """Return the length-normalised log-likelihood of *action* as the reply to *prompt*.

    The action is rendered with ``canonical_action_text`` and followed by the
    tokenizer's EOS token, so the score also covers ending the reply right
    after the JSON object. Only continuation tokens are scored; prompt
    positions are masked with -100.

    Returns:
        ``sum(log p(token)) / n_tokens ** LENGTH_NORM_POWER`` over the
        continuation tokens.
    """
    action_text = canonical_action_text(action) + tokenizer.eos_token
    # Tokenise prompt and continuation separately and concatenate the ids.
    # Tokenising `prompt + action_text` as one string can merge tokens across
    # the boundary. The -100 mask computed from len(prompt_ids) would then be
    # off by a token, scoring a prompt token or skipping an action token.
    prompt_ids = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
    action_ids = tokenizer(action_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
    prompt_len = prompt_ids.shape[-1]
    input_ids = torch.cat([prompt_ids, action_ids], dim=-1).to(model.device)
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100
    with torch.inference_mode():
        output = model(input_ids=input_ids, labels=labels)
    scored_tokens = max(1, int((labels != -100).sum().item()))
    # output.loss is the mean negative log-likelihood over unmasked labels;
    # multiply back to recover the summed log-probability.
    sum_logprob = -float(output.loss.item()) * scored_tokens
    return sum_logprob / (scored_tokens ** LENGTH_NORM_POWER)
def choose_action(
    model: Any,
    tokenizer: Any,
    obs: dict[str, Any],
    task_id: str,
    controller: PhaseController | None = None,
) -> tuple[dict[str, Any], str]:
    """Pick the highest-scoring candidate action for *obs*.

    Candidates come from *controller* when one is given, otherwise from the
    module-level ``candidate_actions``. If the controller offers no
    candidates, the unconstrained set is used instead of failing on an empty
    ranking.

    Returns:
        ``(best_action, debug)``. ``debug`` is a JSON string containing the
        chosen action, its score, the length-normalisation power, the
        controller phase, the candidate count and the top-3/top-5 candidates.
    """
    prompt = build_prompt(tokenizer, obs, task_id)
    candidates = controller.candidate_actions(obs) if controller is not None else []
    if not candidates:
        # Without this fallback, an exhausted controller phase would make
        # `scored[0]` raise IndexError and abort the whole evaluation run.
        candidates = candidate_actions(obs, task_id)
    scored = [
        (score_action(model, tokenizer, prompt, action), action)
        for action in candidates
    ]
    # list.sort is stable even with reverse=True, so ties keep candidate order.
    scored.sort(key=lambda item: item[0], reverse=True)
    best_score, best_action = scored[0]
    debug = json.dumps(
        {
            "chosen": best_action,
            "score": best_score,
            "length_norm_power": LENGTH_NORM_POWER,
            "phase": controller.phase if controller is not None else "uncontrolled",
            "candidate_count": len(candidates),
            "top3": [
                {"score": score, "action": action}
                for score, action in scored[:3]
            ],
            "top5": [
                {"score": score, "action": action}
                for score, action in scored[:5]
            ],
        },
        ensure_ascii=False,
    )
    return best_action, debug
def run_episode(model: Any, tokenizer: Any, task_id: str, seed: int, max_steps: int) -> dict[str, Any]:
    """Play one episode of *task_id* against the environment server and grade it.

    Resets the environment with *seed* and takes at most *max_steps* policy
    actions, stopping early when a step result reports ``done``. Hard and
    cascade tasks use a PhaseController. The server is then asked for the
    grade.

    Returns:
        A summary dict with the grade score/pass flag, the step count, the
        summed step reward, all actions taken, and the debug output of the
        first three steps.
    """
    obs = post("/reset", {"task_id": task_id, "seed": seed})
    uses_phases = task_id in {"hard", "cascade"}
    controller = PhaseController(task_id) if uses_phases else None

    taken: list[dict[str, Any]] = []
    debug_lines: list[str] = []
    step_rewards: list[float] = []
    for _step in range(max_steps):
        action, debug = choose_action(model, tokenizer, obs, task_id, controller)
        taken.append(action)
        debug_lines.append(debug)
        step_result = post("/step", action)
        step_rewards.append(float(step_result.get("reward", {}).get("value", 0.0)))
        if controller is not None:
            controller.update(action, step_result)
        # Keep the previous observation if the server omits a new one.
        obs = step_result.get("observation", obs)
        if step_result.get("done", False):
            break

    grade = post("/grade", {"task_id": task_id})
    return {
        "task_id": task_id,
        "seed": seed,
        "score": float(grade.get("score", 0.01)),
        "passed": bool(grade.get("passed", False)),
        "steps": len(taken),
        "total_reward": sum(step_rewards),
        "actions": taken,
        "sample_outputs": debug_lines[:3],
    }
def main() -> None:
    """Evaluate the policy on every (task, seed) pair and save the results.

    Prints one line per episode, then the summary. Writes the summary and all
    episodes to ``results/<label>_model_eval.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", default="", help="Path to trained LoRA adapter.")
    parser.add_argument("--label", default="model", help="Label for output filenames.")
    parser.add_argument("--seeds", type=int, default=3)
    parser.add_argument("--max-steps", type=int, default=8)
    parser.add_argument("--tasks", default=",".join(TASKS))
    args = parser.parse_args()

    stripped = (part.strip() for part in args.tasks.split(","))
    selected_tasks = [name for name in stripped if name]
    # An empty --adapter means "evaluate the base model".
    adapter_path = args.adapter if args.adapter else None
    model, tokenizer = load_policy(adapter_path)

    episodes: list[dict[str, Any]] = []
    for task_id in selected_tasks:
        for seed in range(args.seeds):
            episode = run_episode(model, tokenizer, task_id, seed, args.max_steps)
            episodes.append(episode)
            print(f"[{args.label}] task={task_id} seed={seed} score={episode['score']:.3f} steps={episode['steps']}")

    n_episodes = len(episodes)
    scores = [float(ep["score"]) for ep in episodes]
    n_passed = sum(1 for ep in episodes if ep["passed"])
    summary = {
        "label": args.label,
        "adapter": adapter_path,
        "base_model": BASE_MODEL,
        "mean_score": statistics.mean(scores) if scores else 0.0,
        "pass_rate": n_passed / max(1, n_episodes),
        "n_episodes": n_episodes,
    }

    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)
    output_path = output_dir / f"{args.label}_model_eval.json"
    report = {"summary": summary, "episodes": episodes}
    output_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))
    print(f"Saved model evaluation JSON: {output_path}")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()