Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| import statistics | |
| import sys | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any | |
# Prepend the parent of this file's directory to sys.path before the `app`
# imports that follow. NOTE(review): presumably that directory is the project
# root containing the `app` package -- confirm against the repository layout.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
| from app.env import NervousSystemEnv | |
| from app.models import CoalitionRequest, DelegateRequest, SREAction | |
# Output directory for the JSON report written by main(). It is a relative
# path, so it resolves against the current working directory at run time.
RESULTS_DIR = "results"
def failing_rank(obs: dict[str, Any]) -> int:
    """Return the id of the first failed node in an observation.

    Args:
        obs: Observation mapping with a ``"nodes"`` entry; each node is a
            mapping with ``"health_status"`` and ``"node_id"`` keys.

    Returns:
        ``int(node_id)`` of the first node whose ``health_status`` equals
        ``"failed"``, or ``0`` when no node is marked as failed.
    """
    failed_ids = (
        int(entry["node_id"])
        for entry in obs["nodes"]
        if entry["health_status"] == "failed"
    )
    # The generator is lazy, so only the first match is ever converted.
    return next(failed_ids, 0)
def run_coordination_oracle(seed: int) -> dict[str, Any]:
    """Run one scripted multi-agent episode of the ``fleet_coordination`` task.

    The episode resets the environment, delegates a flight-recorder
    inspection of the failing rank to ``log_inspector``, delegates an NCCL
    version check to ``version_checker``, reads the worker consensus,
    submits a joint ``topology_version_fix`` coalition action, and finally
    grades the task.

    Args:
        seed: Seed passed both to the environment constructor and to
            ``reset``.

    Returns:
        A dict with the keys ``seed``, ``score``, ``passed``, ``grade``,
        ``delegations``, ``consensus``, ``coalition`` and ``episode``.
    """
    task_id = "fleet_coordination"
    env = NervousSystemEnv(seed=seed)
    initial_obs = env.reset(task_id=task_id, seed=seed).model_dump()
    target_rank = failing_rank(initial_obs)

    # Delegations are issued strictly in this order; each request is built
    # immediately before it is sent.
    delegation_specs = (
        {
            "worker": "log_inspector",
            "action": "inspect_flight_recorder",
            "parameters": {"rank_id": target_rank},
            "supervisor_reasoning": "Confirm root-cause rank before remediation.",
        },
        {
            "worker": "version_checker",
            "action": "check_nccl_version",
            "parameters": {},
            "supervisor_reasoning": "Cascade symptoms require version compatibility evidence.",
        },
    )
    delegations = [
        env.delegate(DelegateRequest(**spec)) for spec in delegation_specs
    ]

    # Consensus is only available when both the fleet and the cluster exist.
    if env._fleet is None or env._cluster is None:
        consensus = {}
    else:
        consensus = env._fleet.get_worker_consensus(env._cluster)

    coalition_request = CoalitionRequest(
        proposing_worker="topo_agent",
        supporting_worker="version_checker",
        action="topology_version_fix",
        parameters={},
        rationale="Topology and NCCL version mismatch must be fixed jointly.",
    )
    coalition_outcome = env.coalition_action(coalition_request)

    # Flatten the grade result into a plain dict, keeping the field order.
    graded = env.grade(task_id)
    grade = {
        field: getattr(graded, field)
        for field in ("score", "passed", "breakdown", "explanation")
    }
    grade["task_id"] = task_id

    episode_summary = env.get_episode_summary()

    record: dict[str, Any] = {"seed": seed}
    record["score"] = float(grade["score"])
    record["passed"] = bool(grade["passed"])
    record["grade"] = grade
    record["delegations"] = delegations
    record["consensus"] = consensus
    record["coalition"] = coalition_outcome
    record["episode"] = episode_summary
    return record
def run_direct_baseline(seed: int) -> dict[str, Any]:
    """Run the single-action baseline episode of ``fleet_coordination``.

    The environment is reset and exactly one ``topo_reorder`` action with
    rack affinity is stepped, with no delegation or coalition, before the
    task is graded.

    Args:
        seed: Seed passed both to the environment constructor and to
            ``reset``.

    Returns:
        A dict with the keys ``seed``, ``score``, ``passed``, ``step``
        (the dumped step result) and ``grade``.
    """
    task_id = "fleet_coordination"
    env = NervousSystemEnv(seed=seed)
    env.reset(task_id=task_id, seed=seed)

    reorder = SREAction(action_type="topo_reorder", parameters={"affinity": "rack"})
    step_result = env.step(reorder)

    # Flatten the grade result into a plain dict, keeping the field order.
    graded = env.grade(task_id)
    grade = {
        field: getattr(graded, field)
        for field in ("score", "passed", "breakdown", "explanation")
    }
    grade["task_id"] = task_id

    record: dict[str, Any] = {"seed": seed}
    record["score"] = float(grade["score"])
    record["passed"] = bool(grade["passed"])
    record["step"] = step_result.model_dump()
    record["grade"] = grade
    return record
def summarize(episodes: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-episode results into summary statistics.

    Args:
        episodes: Episode dicts, each with a ``"score"`` convertible to
            float and a ``"passed"`` value interpreted by truthiness.

    Returns:
        A dict with ``mean_score`` (0.0 for no episodes), ``pass_rate``
        (passed episodes divided by the episode count, 0.0 for no
        episodes) and ``n_episodes``.
    """
    score_values = [float(episode["score"]) for episode in episodes]
    if score_values:
        average = statistics.mean(score_values)
    else:
        average = 0.0

    passed_count = sum(bool(episode["passed"]) for episode in episodes)
    # Guard the denominator so an empty run yields a rate of 0.0.
    denominator = len(episodes) or 1

    return {
        "mean_score": average,
        "pass_rate": passed_count / denominator,
        "n_episodes": len(episodes),
    }
def main() -> None:
    """Run both evaluation modes over a range of seeds and save the report.

    The number of seeds comes from the ``MULTI_AGENT_SEEDS`` environment
    variable (default ``"10"``); seeds are ``0 .. n-1``. All oracle episodes
    run first, then all direct-baseline episodes. The combined report is
    written to ``multi_agent_eval.json`` inside ``RESULTS_DIR`` and both
    summaries are printed to stdout.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)

    seed_count = int(os.getenv("MULTI_AGENT_SEEDS", "10"))
    seed_values = range(seed_count)

    # Keep the two passes separate: every oracle run precedes any baseline run.
    oracle_runs = []
    for seed in seed_values:
        oracle_runs.append(run_coordination_oracle(seed))
    baseline_runs = []
    for seed in seed_values:
        baseline_runs.append(run_direct_baseline(seed))

    # Keys are inserted in the order they appear in the saved JSON.
    report: dict[str, Any] = {"timestamp": datetime.now().isoformat()}
    report["task_id"] = "fleet_coordination"
    report["oracle_summary"] = summarize(oracle_runs)
    report["direct_baseline_summary"] = summarize(baseline_runs)
    report["oracle_episodes"] = oracle_runs
    report["direct_baseline_episodes"] = baseline_runs

    path = os.path.join(RESULTS_DIR, "multi_agent_eval.json")
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(report, handle, indent=2)

    for summary_key in ("oracle_summary", "direct_baseline_summary"):
        print(json.dumps(report[summary_key], indent=2))
    print(f"Saved {path}")
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger any episodes or file writes.
if __name__ == "__main__":
    main()