from __future__ import annotations import argparse import json import os import statistics from pathlib import Path from typing import Any import requests import torch from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860") BASE_MODEL = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct-bnb-4bit") LENGTH_NORM_POWER = float(os.getenv("EVAL_LENGTH_NORM_POWER", "1.0")) TASKS = ["easy", "medium", "hard", "cascade"] PATCH_FILES = [ "model/transformer.py", "model/attention.py", "model/feedforward.py", "model/embedding.py", ] DEFAULT_PARAMS: dict[str, dict[str, Any]] = { "inspect_flight_recorder": {"rank_id": 0}, "query_nccl_logs": {"time_window": 5}, "topo_reorder": {"affinity": "rack"}, "patch_divergent_code": { "file": PATCH_FILES[0], "fix_type": "synchronize_conditional", }, "noop": {}, } VALID_ACTIONS = set(DEFAULT_PARAMS) SYSTEM_PROMPT = """You are an SRE agent managing a distributed GPU training cluster. Diagnose and fix failures efficiently. IMPORTANT: You are penalized for using too many tokens. Reason concisely. Identify the failure type first, then act directly. Available actions (respond with JSON only): {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": <0-7>}} {"action_type": "query_nccl_logs", "parameters": {"time_window": }} {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}} {"action_type": "patch_divergent_code", "parameters": {"file": "", "fix_type": "synchronize_conditional"}} {"action_type": "noop", "parameters": {}} Rules: - Respond ONLY with a JSON object, no explanation - Use exactly one action per response - Check job_status first: stalled=investigate, running=optimize - Use inspect_flight_recorder to find failing ranks - Use topo_reorder(affinity="rack") for congestion Examples: {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 0}} {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}} {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}} """ def post(path: str, payload: dict[str, Any], timeout: int = 30) -> dict[str, Any]: response = requests.post(f"{ENV_BASE_URL}{path}", json=payload, timeout=timeout) response.raise_for_status() return response.json() def observation_payload(obs: dict[str, Any], task_id: str) -> dict[str, Any]: training = obs.get("training", {}) return { "task_id": task_id, "job_status": training.get("job_status", "unknown"), "throughput": round(float(training.get("throughput_tokens_per_sec", 0.0)), 1), "target_throughput": training.get("target_throughput", 0.0), "stalled_steps": training.get("stalled_steps", 0), "node_health": [ { "node_id": node.get("node_id"), "health_status": node.get("health_status"), "xid_errors": node.get("xid_errors", []), } for node in obs.get("nodes", []) ], "visible_logs": obs.get("visible_logs", []), "task_hint": "Diagnose and fix the cluster failure.", } def canonical_action_text(action: dict[str, Any]) -> str: return json.dumps(action, separators=(",", ":")) def load_policy(adapter_path: str | None) -> tuple[Any, Any]: compute_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, ) tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, device_map="auto", quantization_config=quantization_config, dtype=compute_dtype, trust_remote_code=True, ) if adapter_path: model = PeftModel.from_pretrained(model, adapter_path) model.eval() return model, tokenizer def build_prompt(tokenizer: Any, obs: dict[str, Any], task_id: str) -> str: return tokenizer.apply_chat_template( [ {"role": "system", "content": SYSTEM_PROMPT}, { "role": "user", "content": json.dumps(observation_payload(obs, task_id), ensure_ascii=False), }, ], tokenize=False, add_generation_prompt=True, ) def failed_rank_order(obs: dict[str, Any]) -> list[int]: failed_ranks = [ node.get("node_id") for node in obs.get("nodes", []) if node.get("health_status") == "failed" and isinstance(node.get("node_id"), int) ] return list(dict.fromkeys([*failed_ranks, *range(8)])) class PhaseController: """Restrict constrained scoring to actions valid for the current task phase. The model still ranks candidates, but hard/cascade no longer ask it to pick from every action at every step. Those tasks are hidden finite-state workflows; offering all actions made the policy loop short investigation actions and never reach patch stages. """ def __init__(self, task_id: str) -> None: self.task_id = task_id self.phase = "start" self.remaining_files = list(PATCH_FILES) self.selected_file: str | None = None def candidate_actions(self, obs: dict[str, Any]) -> list[dict[str, Any]]: if self.task_id == "hard": return self._hard_candidates(obs) if self.task_id == "cascade": return self._cascade_candidates(obs) return candidate_actions(obs, self.task_id) def update(self, action: dict[str, Any], step_result: dict[str, Any]) -> None: reward = step_result.get("reward", {}) info = str(reward.get("info", "")) value = float(reward.get("value", 0.0)) if self.task_id == "hard": self._update_hard(action, info, value) elif self.task_id == "cascade": self._update_cascade(action, info, value) def _hard_candidates(self, obs: dict[str, Any]) -> list[dict[str, Any]]: if self.phase == "start": return [ {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}, { "action_type": "inspect_flight_recorder", "parameters": {"rank_id": failed_rank_order(obs)[0]}, }, ] if self.phase == "identify": return [ { "action_type": "patch_divergent_code", "parameters": {"file": file_name, "fix_type": "identify_file"}, } for file_name in self.remaining_files ] if self.phase == "propose" and self.selected_file is not None: return [ { "action_type": "patch_divergent_code", "parameters": {"file": self.selected_file, "fix_type": "propose_diff"}, } ] if self.phase == "sync" and self.selected_file is not None: return [ { "action_type": "patch_divergent_code", "parameters": { "file": self.selected_file, "fix_type": "synchronize_conditional", }, } ] return [{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}] def _cascade_candidates(self, obs: dict[str, Any]) -> list[dict[str, Any]]: if self.phase == "start": return [ { "action_type": "inspect_flight_recorder", "parameters": {"rank_id": failed_rank_order(obs)[0]}, } ] if self.phase == "topo": return [{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}] if self.phase == "query": return [{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}] if self.phase == "patch": return [ { "action_type": "patch_divergent_code", "parameters": { "file": file_name, "fix_type": "synchronize_conditional", }, } for file_name in self.remaining_files ] return [{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}] def _update_hard(self, action: dict[str, Any], info: str, value: float) -> None: action_type = action.get("action_type") params = action.get("parameters", {}) if self.phase == "start" and action_type in {"query_nccl_logs", "inspect_flight_recorder"}: self.phase = "identify" return if self.phase == "identify" and action_type == "patch_divergent_code": file_name = str(params.get("file", "")) if file_name in self.remaining_files: self.remaining_files.remove(file_name) if "stage 1" in info or value >= 0.14: self.selected_file = file_name self.phase = "propose" return if self.phase == "propose" and action_type == "patch_divergent_code": if "stage 2" in info or value >= 0.19: self.phase = "sync" return if self.phase == "sync" and action_type == "patch_divergent_code": if "stage 3" in info or value >= 0.34: self.phase = "recovery" def _update_cascade(self, action: dict[str, Any], info: str, value: float) -> None: action_type = action.get("action_type") params = action.get("parameters", {}) if self.phase == "start" and action_type == "inspect_flight_recorder": if "Phase 1 solved" in info or value >= 0.10: self.phase = "topo" return if self.phase == "topo" and action_type == "topo_reorder": if "Phase 2 solved" in info or value >= 0.02: self.phase = "query" return if self.phase == "query" and action_type == "query_nccl_logs": if "Phase 3 investigation complete" in info or value >= 0.05: self.phase = "patch" return if self.phase == "patch" and action_type == "patch_divergent_code": file_name = str(params.get("file", "")) if file_name in self.remaining_files: self.remaining_files.remove(file_name) if "Phase 3 solved" in info or value >= 0.20: self.phase = "recovery" def candidate_actions(obs: dict[str, Any], task_id: str) -> list[dict[str, Any]]: candidates: list[dict[str, Any]] = [] rank_order = failed_rank_order(obs) for rank_id in rank_order: candidates.append( { "action_type": "inspect_flight_recorder", "parameters": {"rank_id": int(rank_id)}, } ) candidates.extend( [ {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}, {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}, {"action_type": "noop", "parameters": {}}, ] ) for file_name in PATCH_FILES: for fix_type in ("identify_file", "propose_diff", "synchronize_conditional"): candidates.append( { "action_type": "patch_divergent_code", "parameters": {"file": file_name, "fix_type": fix_type}, } ) return candidates def score_action(model: Any, tokenizer: Any, prompt: str, action: dict[str, Any]) -> float: action_text = canonical_action_text(action) + tokenizer.eos_token prompt_ids = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"] full_ids = tokenizer(prompt + action_text, add_special_tokens=False, return_tensors="pt")["input_ids"] prompt_len = prompt_ids.shape[-1] input_ids = full_ids.to(model.device) labels = input_ids.clone() labels[:, :prompt_len] = -100 with torch.inference_mode(): output = model(input_ids=input_ids, labels=labels) scored_tokens = max(1, int((labels != -100).sum().item())) sum_logprob = -float(output.loss.item()) * scored_tokens return sum_logprob / (scored_tokens ** LENGTH_NORM_POWER) def choose_action( model: Any, tokenizer: Any, obs: dict[str, Any], task_id: str, controller: PhaseController | None = None, ) -> tuple[dict[str, Any], str]: prompt = build_prompt(tokenizer, obs, task_id) candidates = ( controller.candidate_actions(obs) if controller is not None else candidate_actions(obs, task_id) ) scored = [ (score_action(model, tokenizer, prompt, action), action) for action in candidates ] scored.sort(key=lambda item: item[0], reverse=True) best_score, best_action = scored[0] debug = json.dumps( { "chosen": best_action, "score": best_score, "length_norm_power": LENGTH_NORM_POWER, "phase": controller.phase if controller is not None else "uncontrolled", "candidate_count": len(candidates), "top3": [ {"score": score, "action": action} for score, action in scored[:3] ], "top5": [ {"score": score, "action": action} for score, action in scored[:5] ], }, ensure_ascii=False, ) return best_action, debug def run_episode(model: Any, tokenizer: Any, task_id: str, seed: int, max_steps: int) -> dict[str, Any]: obs = post("/reset", {"task_id": task_id, "seed": seed}) controller = PhaseController(task_id) if task_id in {"hard", "cascade"} else None actions: list[dict[str, Any]] = [] raw_outputs: list[str] = [] rewards: list[float] = [] for _ in range(max_steps): action, raw = choose_action(model, tokenizer, obs, task_id, controller) actions.append(action) raw_outputs.append(raw) result = post("/step", action) reward = float(result.get("reward", {}).get("value", 0.0)) rewards.append(reward) if controller is not None: controller.update(action, result) obs = result.get("observation", obs) if result.get("done", False): break grade = post("/grade", {"task_id": task_id}) return { "task_id": task_id, "seed": seed, "score": float(grade.get("score", 0.01)), "passed": bool(grade.get("passed", False)), "steps": len(actions), "total_reward": sum(rewards), "actions": actions, "sample_outputs": raw_outputs[:3], } def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--adapter", default="", help="Path to trained LoRA adapter.") parser.add_argument("--label", default="model", help="Label for output filenames.") parser.add_argument("--seeds", type=int, default=3) parser.add_argument("--max-steps", type=int, default=8) parser.add_argument("--tasks", default=",".join(TASKS)) args = parser.parse_args() selected_tasks = [task.strip() for task in args.tasks.split(",") if task.strip()] adapter_path = args.adapter or None model, tokenizer = load_policy(adapter_path) episodes: list[dict[str, Any]] = [] for task_id in selected_tasks: for seed in range(args.seeds): episode = run_episode(model, tokenizer, task_id, seed, args.max_steps) episodes.append(episode) print( f"[{args.label}] task={task_id} seed={seed} " f"score={episode['score']:.3f} steps={episode['steps']}" ) scores = [float(ep["score"]) for ep in episodes] pass_rate = sum(1 for ep in episodes if ep["passed"]) / max(1, len(episodes)) summary = { "label": args.label, "adapter": adapter_path, "base_model": BASE_MODEL, "mean_score": statistics.mean(scores) if scores else 0.0, "pass_rate": pass_rate, "n_episodes": len(episodes), } result = {"summary": summary, "episodes": episodes} output_dir = Path("results") output_dir.mkdir(exist_ok=True) output_path = output_dir / f"{args.label}_model_eval.json" output_path.write_text(json.dumps(result, indent=2), encoding="utf-8") print(json.dumps(summary, indent=2)) print(f"Saved model evaluation JSON: {output_path}") if __name__ == "__main__": main()