# Page header (scrape residue): Spaces: Sleeping · File size: 4,433 Bytes · 9731ebe
from __future__ import annotations
import json
import os
import statistics
import sys
from datetime import datetime
from pathlib import Path
from typing import Any
# Put the parent of this file's directory at the front of the import search
# path before the `app` imports that follow.
# NOTE(review): presumably that directory is the project root containing `app/` — confirm.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.env import NervousSystemEnv
from app.models import CoalitionRequest, DelegateRequest, SREAction
# Output directory for the evaluation report; a relative path, so it is
# resolved against the current working directory at run time.
RESULTS_DIR = "results"
def failing_rank(obs: dict[str, Any]) -> int:
    """Return the ID of the first node whose health status is ``"failed"``.

    Args:
        obs: Observation mapping with a ``"nodes"`` entry; each node is a
            mapping that provides ``"health_status"`` and ``"node_id"``.

    Returns:
        The ``node_id`` of the first failed node converted to ``int``, or
        ``0`` when no node is marked as failed.
    """
    # Lazy generator: nodes after the first failed one are never inspected.
    failed_ids = (
        int(entry["node_id"])
        for entry in obs["nodes"]
        if entry["health_status"] == "failed"
    )
    return next(failed_ids, 0)
def run_coordination_oracle(seed: int) -> dict[str, Any]:
    """Run one scripted multi-agent episode of ``fleet_coordination`` and grade it.

    The script issues two delegations (a flight-recorder inspection of the
    rank returned by ``failing_rank``, then an NCCL version check), snapshots
    the worker consensus, submits a ``topology_version_fix`` coalition action
    and finally grades the episode.

    Args:
        seed: Seed passed both to the environment constructor and to ``reset``.

    Returns:
        A dict with the keys ``seed``, ``score``, ``passed``, ``grade``,
        ``delegations``, ``consensus``, ``coalition`` and ``episode``.
    """
    env = NervousSystemEnv(seed=seed)
    initial_obs = env.reset(task_id="fleet_coordination", seed=seed).model_dump()
    suspect_rank = failing_rank(initial_obs)

    # Each request is built and delegated before the next one is constructed.
    delegations: list[dict[str, Any]] = [
        env.delegate(
            DelegateRequest(
                worker="log_inspector",
                action="inspect_flight_recorder",
                parameters={"rank_id": suspect_rank},
                supervisor_reasoning="Confirm root-cause rank before remediation.",
            )
        ),
        env.delegate(
            DelegateRequest(
                worker="version_checker",
                action="check_nccl_version",
                parameters={},
                supervisor_reasoning="Cascade symptoms require version compatibility evidence.",
            )
        ),
    ]

    # The consensus snapshot is taken after the delegations and before the
    # coalition action; it falls back to an empty dict when either the fleet
    # or the cluster is absent.
    if env._fleet is None or env._cluster is None:
        consensus = {}
    else:
        consensus = env._fleet.get_worker_consensus(env._cluster)

    coalition_request = CoalitionRequest(
        proposing_worker="topo_agent",
        supporting_worker="version_checker",
        action="topology_version_fix",
        parameters={},
        rationale="Topology and NCCL version mismatch must be fixed jointly.",
    )
    coalition = env.coalition_action(coalition_request)

    # Copy the grading fields into a plain dict for the returned record.
    verdict = env.grade("fleet_coordination")
    grade = {
        field: getattr(verdict, field)
        for field in ("score", "passed", "breakdown", "explanation")
    }
    grade["task_id"] = "fleet_coordination"

    episode = env.get_episode_summary()
    return dict(
        seed=seed,
        score=float(grade["score"]),
        passed=bool(grade["passed"]),
        grade=grade,
        delegations=delegations,
        consensus=consensus,
        coalition=coalition,
        episode=episode,
    )
def run_direct_baseline(seed: int) -> dict[str, Any]:
    """Run the single-action baseline on ``fleet_coordination`` and grade it.

    The baseline takes exactly one environment step, a ``topo_reorder``
    action with rack affinity, and issues no delegations or coalition actions.

    Args:
        seed: Seed passed both to the environment constructor and to ``reset``.

    Returns:
        A dict with the keys ``seed``, ``score``, ``passed``, ``step`` (the
        dumped step result) and ``grade``.
    """
    env = NervousSystemEnv(seed=seed)
    env.reset(task_id="fleet_coordination", seed=seed)

    reorder = SREAction(action_type="topo_reorder", parameters={"affinity": "rack"})
    step_result = env.step(reorder)

    # Copy the grading fields into a plain dict for the returned record.
    verdict = env.grade("fleet_coordination")
    grade = {
        field: getattr(verdict, field)
        for field in ("score", "passed", "breakdown", "explanation")
    }
    grade["task_id"] = "fleet_coordination"

    # The step result is dumped only here, after grading has completed.
    return dict(
        seed=seed,
        score=float(grade["score"]),
        passed=bool(grade["passed"]),
        step=step_result.model_dump(),
        grade=grade,
    )
def summarize(episodes: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-episode results into mean score, pass rate and count.

    Args:
        episodes: Episode records, each providing ``"score"`` and ``"passed"``.

    Returns:
        A dict with ``mean_score`` (``0.0`` for no episodes), ``pass_rate``
        (fraction of episodes with a truthy ``"passed"``) and ``n_episodes``.
    """
    episode_count = len(episodes)
    score_values = [float(episode["score"]) for episode in episodes]

    # statistics.mean raises on empty input, so the empty case is handled here.
    if score_values:
        mean_score = statistics.mean(score_values)
    else:
        mean_score = 0.0

    passed_count = len([episode for episode in episodes if episode["passed"]])
    return {
        "mean_score": mean_score,
        # Denominator is clamped to 1 so an empty run yields a rate of 0.0.
        "pass_rate": passed_count / max(1, episode_count),
        "n_episodes": episode_count,
    }
def main() -> None:
    """Evaluate both strategies over a range of seeds and save a JSON report.

    The number of seeds comes from the ``MULTI_AGENT_SEEDS`` environment
    variable (default ``"10"``); seeds run from 0 up to that number minus one.
    All oracle episodes run first, then all direct-baseline episodes. The
    report is written to ``multi_agent_eval.json`` inside ``RESULTS_DIR``, and
    both summaries plus the saved path are printed.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)

    seed_count = int(os.getenv("MULTI_AGENT_SEEDS", "10"))
    oracle_runs = list(map(run_coordination_oracle, range(seed_count)))
    baseline_runs = list(map(run_direct_baseline, range(seed_count)))

    report = dict(
        timestamp=datetime.now().isoformat(),
        task_id="fleet_coordination",
        oracle_summary=summarize(oracle_runs),
        direct_baseline_summary=summarize(baseline_runs),
        oracle_episodes=oracle_runs,
        direct_baseline_episodes=baseline_runs,
    )

    output_path = os.path.join(RESULTS_DIR, "multi_agent_eval.json")
    with open(output_path, "w", encoding="utf-8") as handle:
        json.dump(report, handle, indent=2)

    # Echo the two summaries, oracle first, then the baseline.
    for summary_key in ("oracle_summary", "direct_baseline_summary"):
        print(json.dumps(report[summary_key], indent=2))
    print(f"Saved {output_path}")
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|