Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import statistics | |
| from pathlib import Path | |
| from typing import Any | |
| import requests | |
| import torch | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
# Base URL of the environment server that `post` talks to; override with ENV_BASE_URL.
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
# Model identifier handed to `from_pretrained` in `load_policy`; override with MODEL_NAME.
BASE_MODEL = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct-bnb-4bit")
# Exponent applied to the scored-token count in `score_action`:
# 1.0 divides the summed log-probability by the token count, 0.0 leaves the raw sum.
LENGTH_NORM_POWER = float(os.getenv("EVAL_LENGTH_NORM_POWER", "1.0"))
# Task ids evaluated by default (the default value of the --tasks CLI option).
TASKS = ["easy", "medium", "hard", "cascade"]
# Files offered as targets for `patch_divergent_code` candidates.
PATCH_FILES = [
    "model/transformer.py",
    "model/attention.py",
    "model/feedforward.py",
    "model/embedding.py",
]
# Default parameters per action type.
# NOTE(review): neither DEFAULT_PARAMS nor VALID_ACTIONS is referenced elsewhere
# in the visible part of this file — confirm whether other modules import them.
DEFAULT_PARAMS: dict[str, dict[str, Any]] = {
    "inspect_flight_recorder": {"rank_id": 0},
    "query_nccl_logs": {"time_window": 5},
    "topo_reorder": {"affinity": "rack"},
    "patch_divergent_code": {
        "file": PATCH_FILES[0],
        "fix_type": "synchronize_conditional",
    },
    "noop": {},
}
# Set of known action type names (the keys of DEFAULT_PARAMS).
VALID_ACTIONS = set(DEFAULT_PARAMS)
# System message placed first in the chat built by `build_prompt`.
# NOTE(review): the prompt only advertises fix_type "synchronize_conditional",
# while the scored candidates also use "identify_file" and "propose_diff" —
# confirm this mismatch is intentional.
SYSTEM_PROMPT = """You are an SRE agent managing a distributed
GPU training cluster. Diagnose and fix failures efficiently.
IMPORTANT: You are penalized for using too many tokens.
Reason concisely. Identify the failure type first, then act directly.
Available actions (respond with JSON only):
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": <0-7>}}
{"action_type": "query_nccl_logs", "parameters": {"time_window": <int>}}
{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}
{"action_type": "patch_divergent_code", "parameters": {"file": "<path>", "fix_type": "synchronize_conditional"}}
{"action_type": "noop", "parameters": {}}
Rules:
- Respond ONLY with a JSON object, no explanation
- Use exactly one action per response
- Check job_status first: stalled=investigate, running=optimize
- Use inspect_flight_recorder to find failing ranks
- Use topo_reorder(affinity="rack") for congestion
Examples:
{"action_type": "inspect_flight_recorder", "parameters": {"rank_id": 0}}
{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}
{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}
"""
def post(path: str, payload: dict[str, Any], timeout: int = 30) -> dict[str, Any]:
    """POST ``payload`` as JSON to the environment server and return the decoded reply.

    Args:
        path: URL path appended verbatim to ``ENV_BASE_URL`` (e.g. ``"/reset"``).
        payload: JSON-serialisable request body.
        timeout: Timeout in seconds passed to ``requests.post``.

    Returns:
        The response body parsed as JSON.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    url = f"{ENV_BASE_URL}{path}"
    reply = requests.post(url, json=payload, timeout=timeout)
    # Surface HTTP failures instead of trying to parse an error body.
    reply.raise_for_status()
    return reply.json()
def observation_payload(obs: dict[str, Any], task_id: str) -> dict[str, Any]:
    """Condense a raw environment observation into the dict shown to the model.

    Args:
        obs: Observation dict; the ``training``, ``nodes`` and ``visible_logs``
            keys are read, each falling back to an empty value when absent.
        task_id: Task identifier copied into the payload.

    Returns:
        A dict with training status fields, a per-node health summary
        (``node_id``, ``health_status``, ``xid_errors``), the visible logs and
        a fixed task hint. Key order is fixed because the payload is later
        serialised to JSON for the prompt.
    """
    training = obs.get("training", {})
    throughput = float(training.get("throughput_tokens_per_sec", 0.0))

    # Keep only the three health-related fields of every node.
    node_health = []
    for node in obs.get("nodes", []):
        summary = {
            "node_id": node.get("node_id"),
            "health_status": node.get("health_status"),
            "xid_errors": node.get("xid_errors", []),
        }
        node_health.append(summary)

    payload: dict[str, Any] = {"task_id": task_id}
    payload["job_status"] = training.get("job_status", "unknown")
    payload["throughput"] = round(throughput, 1)
    payload["target_throughput"] = training.get("target_throughput", 0.0)
    payload["stalled_steps"] = training.get("stalled_steps", 0)
    payload["node_health"] = node_health
    payload["visible_logs"] = obs.get("visible_logs", [])
    payload["task_hint"] = "Diagnose and fix the cluster failure."
    return payload
def canonical_action_text(action: dict[str, Any]) -> str:
    """Serialise ``action`` to compact JSON (no spaces after ``,`` or ``:``).

    This is the exact text whose likelihood is scored for each candidate, so
    key order follows the insertion order of ``action``.
    """
    item_sep, key_sep = ",", ":"
    compact = json.dumps(action, separators=(item_sep, key_sep))
    return compact
def load_policy(adapter_path: str | None) -> tuple[Any, Any]:
    """Load the 4-bit quantised base model, its tokenizer and an optional adapter.

    Args:
        adapter_path: Path of a PEFT adapter to wrap around the base model;
            a falsy value loads the base model only.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    # bfloat16 only when a CUDA device exists and reports bf16 support.
    bf16_usable = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    if bf16_usable:
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float16

    bnb_settings = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": compute_dtype,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
    }
    quantization_config = BitsAndBytesConfig(**bnb_settings)

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    # Reuse EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): `dtype=` is passed straight to from_pretrained — confirm the
    # installed transformers version accepts this keyword.
    policy = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        quantization_config=quantization_config,
        dtype=compute_dtype,
        trust_remote_code=True,
    )
    if adapter_path:
        policy = PeftModel.from_pretrained(policy, adapter_path)
    policy.eval()
    return policy, tokenizer
def build_prompt(tokenizer: Any, obs: dict[str, Any], task_id: str) -> str:
    """Render the chat prompt (system + observation) as untokenised text.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template``.
        obs: Raw environment observation.
        task_id: Task identifier embedded in the user message.

    Returns:
        The templated prompt string, ending with the generation prompt so the
        next tokens belong to the assistant turn.
    """
    user_content = json.dumps(observation_payload(obs, task_id), ensure_ascii=False)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
def failed_rank_order(obs: dict[str, Any]) -> list[int]:
    """Return rank ids to inspect, failed nodes first, then ranks 0-7.

    Nodes count as failed when ``health_status == "failed"`` and their
    ``node_id`` is an int. Duplicates are dropped, keeping first occurrence,
    so the result always contains at least ranks 0 through 7.
    """
    # A dict doubles as an insertion-ordered set.
    ordered: dict[int, None] = {}
    for node in obs.get("nodes", []):
        if node.get("health_status") != "failed":
            continue
        node_id = node.get("node_id")
        if not isinstance(node_id, int):
            continue
        ordered.setdefault(node_id, None)
    for rank in range(8):
        ordered.setdefault(rank, None)
    return list(ordered)
class PhaseController:
    """Restrict constrained scoring to actions valid for the current task phase.

    The model still ranks candidates, but hard/cascade no longer ask it to pick
    from every action at every step. Those tasks are hidden finite-state
    workflows; offering all actions made the policy loop short investigation
    actions and never reach patch stages.

    Phase order tracked here:
        hard:    start -> identify -> propose -> sync -> recovery
        cascade: start -> topo -> query -> patch -> recovery

    A phase advances when the step reward's ``info`` text contains the
    expected marker or its ``value`` reaches a threshold.
    NOTE(review): the numeric thresholds presumably mirror the environment's
    per-stage rewards — confirm against the environment implementation.
    """

    def __init__(self, task_id: str) -> None:
        """Create a controller in the ``start`` phase for ``task_id``."""
        self.task_id = task_id
        self.phase = "start"
        # Files not yet tried during the file-selection phase.
        self.remaining_files = list(PATCH_FILES)
        # File accepted during the hard task's identify stage, once known.
        self.selected_file: str | None = None

    @staticmethod
    def _fallback_candidates() -> list[dict[str, Any]]:
        """Return the single log-query action used when a phase has no options."""
        return [{"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}]

    @staticmethod
    def _patch_action(file_name: str, fix_type: str) -> dict[str, Any]:
        """Build a ``patch_divergent_code`` action for one file and fix type."""
        return {
            "action_type": "patch_divergent_code",
            "parameters": {"file": file_name, "fix_type": fix_type},
        }

    def candidate_actions(self, obs: dict[str, Any]) -> list[dict[str, Any]]:
        """Return the actions to score for the current phase.

        Hard and cascade tasks get phase-specific candidates; any other task
        falls back to the module-level ``candidate_actions``. The returned
        list is never empty for hard/cascade tasks.
        """
        if self.task_id == "hard":
            return self._hard_candidates(obs)
        if self.task_id == "cascade":
            return self._cascade_candidates(obs)
        return candidate_actions(obs, self.task_id)

    def update(self, action: dict[str, Any], step_result: dict[str, Any]) -> None:
        """Advance the phase based on the action taken and the step's reward.

        Args:
            action: The action that was just sent to the environment.
            step_result: The environment's reply; its ``reward`` entry is read
                for ``info`` and ``value``. A missing or ``None`` reward or
                value is treated as empty info and a value of 0.0.
        """
        reward = step_result.get("reward") or {}
        info = str(reward.get("info", ""))
        value = float(reward.get("value") or 0.0)
        if self.task_id == "hard":
            self._update_hard(action, info, value)
        elif self.task_id == "cascade":
            self._update_cascade(action, info, value)

    def _hard_candidates(self, obs: dict[str, Any]) -> list[dict[str, Any]]:
        """Candidates for the hard task's current phase."""
        if self.phase == "start":
            return [
                {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}},
                {
                    "action_type": "inspect_flight_recorder",
                    "parameters": {"rank_id": failed_rank_order(obs)[0]},
                },
            ]
        # Only offer file identification while untried files remain; once every
        # file has been rejected the list would be empty and the caller would
        # have nothing to score, so fall through to the fallback instead.
        if self.phase == "identify" and self.remaining_files:
            return [
                self._patch_action(file_name, "identify_file")
                for file_name in self.remaining_files
            ]
        if self.phase == "propose" and self.selected_file is not None:
            return [self._patch_action(self.selected_file, "propose_diff")]
        if self.phase == "sync" and self.selected_file is not None:
            return [self._patch_action(self.selected_file, "synchronize_conditional")]
        return self._fallback_candidates()

    def _cascade_candidates(self, obs: dict[str, Any]) -> list[dict[str, Any]]:
        """Candidates for the cascade task's current phase."""
        if self.phase == "start":
            return [
                {
                    "action_type": "inspect_flight_recorder",
                    "parameters": {"rank_id": failed_rank_order(obs)[0]},
                }
            ]
        if self.phase == "topo":
            return [{"action_type": "topo_reorder", "parameters": {"affinity": "rack"}}]
        if self.phase == "query":
            return self._fallback_candidates()
        # Same exhaustion guard as the hard task's identify phase.
        if self.phase == "patch" and self.remaining_files:
            return [
                self._patch_action(file_name, "synchronize_conditional")
                for file_name in self.remaining_files
            ]
        return self._fallback_candidates()

    def _update_hard(self, action: dict[str, Any], info: str, value: float) -> None:
        """Apply the hard task's phase transitions for one completed step."""
        action_type = action.get("action_type")
        params = action.get("parameters", {})
        if self.phase == "start" and action_type in {"query_nccl_logs", "inspect_flight_recorder"}:
            # Any investigation action moves on, regardless of reward.
            self.phase = "identify"
            return
        if self.phase == "identify" and action_type == "patch_divergent_code":
            file_name = str(params.get("file", ""))
            # A tried file is never offered again, whether or not it was accepted.
            if file_name in self.remaining_files:
                self.remaining_files.remove(file_name)
            if "stage 1" in info or value >= 0.14:
                self.selected_file = file_name
                self.phase = "propose"
            return
        if self.phase == "propose" and action_type == "patch_divergent_code":
            if "stage 2" in info or value >= 0.19:
                self.phase = "sync"
            return
        if self.phase == "sync" and action_type == "patch_divergent_code":
            if "stage 3" in info or value >= 0.34:
                self.phase = "recovery"

    def _update_cascade(self, action: dict[str, Any], info: str, value: float) -> None:
        """Apply the cascade task's phase transitions for one completed step."""
        action_type = action.get("action_type")
        params = action.get("parameters", {})
        if self.phase == "start" and action_type == "inspect_flight_recorder":
            if "Phase 1 solved" in info or value >= 0.10:
                self.phase = "topo"
            return
        if self.phase == "topo" and action_type == "topo_reorder":
            if "Phase 2 solved" in info or value >= 0.02:
                self.phase = "query"
            return
        if self.phase == "query" and action_type == "query_nccl_logs":
            if "Phase 3 investigation complete" in info or value >= 0.05:
                self.phase = "patch"
            return
        if self.phase == "patch" and action_type == "patch_divergent_code":
            file_name = str(params.get("file", ""))
            # A tried file is never offered again, whether or not it was accepted.
            if file_name in self.remaining_files:
                self.remaining_files.remove(file_name)
            if "Phase 3 solved" in info or value >= 0.20:
                self.phase = "recovery"
def candidate_actions(obs: dict[str, Any], task_id: str) -> list[dict[str, Any]]:
    """Enumerate every action the policy may be asked to score.

    Order: one ``inspect_flight_recorder`` per rank from ``failed_rank_order``,
    then the log query, topology reorder and noop, then every
    (file, fix type) patch combination.

    Args:
        obs: Current observation, used only to order the ranks.
        task_id: Accepted for interface parity; not used by the body.
    """
    inspections = [
        {
            "action_type": "inspect_flight_recorder",
            "parameters": {"rank_id": int(rank_id)},
        }
        for rank_id in failed_rank_order(obs)
    ]
    cluster_wide = [
        {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}},
        {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}},
        {"action_type": "noop", "parameters": {}},
    ]
    fix_types = ("identify_file", "propose_diff", "synchronize_conditional")
    patches = []
    for file_name in PATCH_FILES:
        patches.extend(
            {
                "action_type": "patch_divergent_code",
                "parameters": {"file": file_name, "fix_type": fix_type},
            }
            for fix_type in fix_types
        )
    return [*inspections, *cluster_wide, *patches]
def score_action(model: Any, tokenizer: Any, prompt: str, action: dict[str, Any]) -> float:
    """Return the length-normalised log-likelihood of ``action`` after ``prompt``.

    The action is rendered as compact JSON plus the EOS token, appended to the
    prompt, and scored with the prompt positions masked out of the loss.

    Args:
        model: Model called with ``input_ids`` and ``labels``, returning ``.loss``.
        tokenizer: Callable tokenizer exposing ``eos_token``.
        prompt: Fully templated prompt text.
        action: Candidate action to score.

    Returns:
        Summed log-probability divided by ``tokens ** LENGTH_NORM_POWER``.
    """

    def encode(text: str) -> Any:
        return tokenizer(text, add_special_tokens=False, return_tensors="pt")["input_ids"]

    completion = canonical_action_text(action) + tokenizer.eos_token
    # NOTE(review): the prompt is tokenised on its own to find where the
    # completion starts; this assumes the prompt tokens are a prefix of the
    # joint tokenisation — confirm for the tokenizer in use.
    prompt_len = encode(prompt).shape[-1]
    input_ids = encode(prompt + completion).to(model.device)

    # -100 marks positions excluded from the loss.
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100

    with torch.inference_mode():
        loss = model(input_ids=input_ids, labels=labels).loss

    n_scored = max(1, int((labels != -100).sum().item()))
    # Presumably the loss is a mean over scored tokens, so multiplying by the
    # count recovers the summed negative log-likelihood — verify for the model.
    total_logprob = -float(loss.item()) * n_scored
    return total_logprob / (n_scored ** LENGTH_NORM_POWER)
def choose_action(
    model: Any,
    tokenizer: Any,
    obs: dict[str, Any],
    task_id: str,
    controller: PhaseController | None = None,
) -> tuple[dict[str, Any], str]:
    """Score every candidate action and return the best one with debug info.

    Args:
        model: Policy model passed through to ``score_action``.
        tokenizer: Tokenizer passed through to prompt building and scoring.
        obs: Current environment observation.
        task_id: Task identifier.
        controller: Optional phase controller; when given it supplies the
            candidate list, otherwise the full ``candidate_actions`` set is used.

    Returns:
        ``(action, debug_json)`` where ``debug_json`` is a JSON string holding
        the chosen action, its score and the top-ranked alternatives. If no
        candidates are available a ``noop`` action is returned with a ``null``
        score instead of raising ``IndexError``.
    """
    if controller is not None:
        candidates = controller.candidate_actions(obs)
        phase = controller.phase
    else:
        candidates = candidate_actions(obs, task_id)
        phase = "uncontrolled"

    if not candidates:
        # Nothing to rank: keep the episode alive with a harmless action.
        fallback: dict[str, Any] = {"action_type": "noop", "parameters": {}}
        debug = json.dumps(
            {
                "chosen": fallback,
                "score": None,
                "length_norm_power": LENGTH_NORM_POWER,
                "phase": phase,
                "candidate_count": 0,
                "top3": [],
                "top5": [],
            },
            ensure_ascii=False,
        )
        return fallback, debug

    prompt = build_prompt(tokenizer, obs, task_id)
    # Sort on the score only; the sort is stable, so ties keep candidate order.
    scored = sorted(
        ((score_action(model, tokenizer, prompt, action), action) for action in candidates),
        key=lambda item: item[0],
        reverse=True,
    )
    best_score, best_action = scored[0]
    ranked = [{"score": score, "action": action} for score, action in scored[:5]]
    debug = json.dumps(
        {
            "chosen": best_action,
            "score": best_score,
            "length_norm_power": LENGTH_NORM_POWER,
            "phase": phase,
            "candidate_count": len(candidates),
            "top3": ranked[:3],
            "top5": ranked,
        },
        ensure_ascii=False,
    )
    return best_action, debug
def run_episode(model: Any, tokenizer: Any, task_id: str, seed: int, max_steps: int) -> dict[str, Any]:
    """Play one episode against the environment server and return its summary.

    The episode is reset via ``/reset``, stepped via ``/step`` for at most
    ``max_steps`` actions (stopping early when the server reports ``done``),
    and finally graded via ``/grade``.

    Returns:
        A dict with the task id, seed, grade score and pass flag, number of
        steps, summed step rewards, the actions taken and the first three
        debug strings from ``choose_action``.
    """
    observation = post("/reset", {"task_id": task_id, "seed": seed})

    # Only the multi-phase tasks get a phase controller.
    controller = None
    if task_id in {"hard", "cascade"}:
        controller = PhaseController(task_id)

    trajectory: list[dict[str, Any]] = []
    debug_log: list[str] = []
    step_rewards: list[float] = []

    for _step in range(max_steps):
        chosen, debug = choose_action(model, tokenizer, observation, task_id, controller)
        trajectory.append(chosen)
        debug_log.append(debug)

        step_result = post("/step", chosen)
        step_rewards.append(float(step_result.get("reward", {}).get("value", 0.0)))
        if controller is not None:
            controller.update(chosen, step_result)

        # Keep the previous observation when the reply carries none.
        observation = step_result.get("observation", observation)
        if step_result.get("done", False):
            break

    grade = post("/grade", {"task_id": task_id})
    summary: dict[str, Any] = {"task_id": task_id, "seed": seed}
    summary["score"] = float(grade.get("score", 0.01))
    summary["passed"] = bool(grade.get("passed", False))
    summary["steps"] = len(trajectory)
    summary["total_reward"] = sum(step_rewards)
    summary["actions"] = trajectory
    summary["sample_outputs"] = debug_log[:3]
    return summary
def main() -> None:
    """Evaluate the policy on the selected tasks and write a JSON report.

    Parses the CLI options, loads the model (with an optional adapter), runs
    ``--seeds`` episodes per task, prints a line per episode, and saves the
    summary plus all episodes to ``results/<label>_model_eval.json``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--adapter", default="", help="Path to trained LoRA adapter.")
    cli.add_argument("--label", default="model", help="Label for output filenames.")
    cli.add_argument("--seeds", type=int, default=3)
    cli.add_argument("--max-steps", type=int, default=8)
    cli.add_argument("--tasks", default=",".join(TASKS))
    args = cli.parse_args()

    # Comma-separated task list; blank entries are dropped.
    selected_tasks = []
    for raw_name in args.tasks.split(","):
        name = raw_name.strip()
        if name:
            selected_tasks.append(name)

    # An empty --adapter means "base model only".
    adapter_path = args.adapter or None
    model, tokenizer = load_policy(adapter_path)

    episodes: list[dict[str, Any]] = []
    for task_id in selected_tasks:
        for seed in range(args.seeds):
            episode = run_episode(model, tokenizer, task_id, seed, args.max_steps)
            episodes.append(episode)
            print(
                f"[{args.label}] task={task_id} seed={seed} "
                f"score={episode['score']:.3f} steps={episode['steps']}"
            )

    scores = [float(ep["score"]) for ep in episodes]
    passed_count = sum(1 for ep in episodes if ep["passed"])
    # max(1, ...) avoids dividing by zero when no episode ran.
    pass_rate = passed_count / max(1, len(episodes))
    mean_score = statistics.mean(scores) if scores else 0.0

    summary = {
        "label": args.label,
        "adapter": adapter_path,
        "base_model": BASE_MODEL,
        "mean_score": mean_score,
        "pass_rate": pass_rate,
        "n_episodes": len(episodes),
    }
    report = {"summary": summary, "episodes": episodes}

    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)
    output_path = output_dir / f"{args.label}_model_eval.json"
    output_path.write_text(json.dumps(report, indent=2), encoding="utf-8")

    print(json.dumps(summary, indent=2))
    print(f"Saved model evaluation JSON: {output_path}")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()