File size: 4,433 Bytes
9731ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

import json
import os
import statistics
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

# Put the parent of this file's directory at the front of the import path so
# the ``app`` imports below resolve when the script is run directly
# (presumably that directory is the repository root — confirm against layout).
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from app.env import NervousSystemEnv
from app.models import CoalitionRequest, DelegateRequest, SREAction

# Output directory for the JSON report; a relative path, so it is resolved
# against the current working directory at run time.
RESULTS_DIR = "results"


def failing_rank(obs: dict[str, Any]) -> int:
    """Return the id of the first node reported as failed.

    Args:
        obs: Observation mapping with a ``"nodes"`` list; each node is read
            for its ``"health_status"`` and ``"node_id"`` entries.

    Returns:
        ``int(node_id)`` of the first node whose ``health_status`` equals
        ``"failed"``, or ``0`` when no node matches.
    """
    failed_ids = (
        int(entry["node_id"])
        for entry in obs["nodes"]
        if entry["health_status"] == "failed"
    )
    # The generator is lazy, so only nodes up to the first match are examined.
    return next(failed_ids, 0)


def run_coordination_oracle(seed: int) -> dict[str, Any]:
    """Play the scripted multi-agent episode for one seed and grade it.

    Steps, in order: reset the ``fleet_coordination`` task, delegate a
    flight-recorder inspection of the failing rank to ``log_inspector``,
    delegate an NCCL version check to ``version_checker``, read the worker
    consensus, submit a ``topology_version_fix`` coalition action, grade the
    task, and fetch the episode summary.

    Args:
        seed: Seed passed to both the environment constructor and ``reset``.

    Returns:
        A dict holding the seed, score, pass flag, grade payload, the two
        delegation results, the consensus, the coalition result and the
        episode summary.
    """
    task_id = "fleet_coordination"
    env = NervousSystemEnv(seed=seed)
    observation = env.reset(task_id=task_id, seed=seed).model_dump()
    target_rank = failing_rank(observation)

    # (worker, action, parameters, supervisor reasoning), executed in order.
    delegation_plan = (
        (
            "log_inspector",
            "inspect_flight_recorder",
            {"rank_id": target_rank},
            "Confirm root-cause rank before remediation.",
        ),
        (
            "version_checker",
            "check_nccl_version",
            {},
            "Cascade symptoms require version compatibility evidence.",
        ),
    )
    delegations: list[dict[str, Any]] = []
    for worker, action, parameters, reasoning in delegation_plan:
        request = DelegateRequest(
            worker=worker,
            action=action,
            parameters=parameters,
            supervisor_reasoning=reasoning,
        )
        delegations.append(env.delegate(request))

    # Consensus is only queried when both the fleet and the cluster are set.
    if env._fleet is None or env._cluster is None:
        consensus = {}
    else:
        consensus = env._fleet.get_worker_consensus(env._cluster)

    coalition_request = CoalitionRequest(
        proposing_worker="topo_agent",
        supporting_worker="version_checker",
        action="topology_version_fix",
        parameters={},
        rationale="Topology and NCCL version mismatch must be fixed jointly.",
    )
    coalition = env.coalition_action(coalition_request)

    outcome = env.grade(task_id)
    grade = {
        "score": outcome.score,
        "passed": outcome.passed,
        "breakdown": outcome.breakdown,
        "explanation": outcome.explanation,
        "task_id": task_id,
    }
    episode = env.get_episode_summary()

    report: dict[str, Any] = {"seed": seed}
    report["score"] = float(grade["score"])
    report["passed"] = bool(grade["passed"])
    report["grade"] = grade
    report["delegations"] = delegations
    report["consensus"] = consensus
    report["coalition"] = coalition
    report["episode"] = episode
    return report


def run_direct_baseline(seed: int) -> dict[str, Any]:
    """Run the single-action baseline for one seed and grade it.

    The baseline resets the ``fleet_coordination`` task and takes exactly one
    ``topo_reorder`` step with rack affinity, with no delegation or coalition.

    Args:
        seed: Seed passed to both the environment constructor and ``reset``.

    Returns:
        A dict holding the seed, score, pass flag, the dumped step result and
        the grade payload.
    """
    task_id = "fleet_coordination"
    env = NervousSystemEnv(seed=seed)
    env.reset(task_id=task_id, seed=seed)

    reorder = SREAction(action_type="topo_reorder", parameters={"affinity": "rack"})
    step = env.step(reorder)

    outcome = env.grade(task_id)
    grade = {
        "score": outcome.score,
        "passed": outcome.passed,
        "breakdown": outcome.breakdown,
        "explanation": outcome.explanation,
        "task_id": task_id,
    }

    report: dict[str, Any] = {"seed": seed}
    report["score"] = float(grade["score"])
    report["passed"] = bool(grade["passed"])
    report["step"] = step.model_dump()
    report["grade"] = grade
    return report


def summarize(episodes: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-episode results into summary statistics.

    Args:
        episodes: Episode dicts, each read for its ``"score"`` and
            ``"passed"`` entries.

    Returns:
        A dict with ``mean_score`` (``0.0`` for no episodes), ``pass_rate``
        (fraction of episodes with a truthy ``passed``) and ``n_episodes``.
    """
    total = len(episodes)
    score_values = [float(episode["score"]) for episode in episodes]
    mean_score = statistics.mean(score_values) if score_values else 0.0
    passed_count = sum(bool(episode["passed"]) for episode in episodes)
    return {
        "mean_score": mean_score,
        # Divide by at least 1 so an empty run yields 0.0 instead of raising.
        "pass_rate": passed_count / max(1, total),
        "n_episodes": total,
    }


def main() -> None:
    """Run both evaluations over a range of seeds and save a JSON report.

    The seed count comes from the ``MULTI_AGENT_SEEDS`` environment variable
    (default ``"10"``); seeds are ``0 .. count - 1``. Every oracle episode is
    run first, then every direct-baseline episode. The combined report is
    written to ``multi_agent_eval.json`` inside ``RESULTS_DIR``, and both
    summaries plus the output path are printed.

    Raises:
        ValueError: If ``MULTI_AGENT_SEEDS`` is not an integer string.
        TypeError: If an episode contains a value ``json`` cannot serialize.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    seeds = list(range(int(os.getenv("MULTI_AGENT_SEEDS", "10"))))
    oracle = [run_coordination_oracle(seed) for seed in seeds]
    direct = [run_direct_baseline(seed) for seed in seeds]
    result = {
        "timestamp": datetime.now().isoformat(),
        "task_id": "fleet_coordination",
        "oracle_summary": summarize(oracle),
        "direct_baseline_summary": summarize(direct),
        "oracle_episodes": oracle,
        "direct_baseline_episodes": direct,
    }
    # Serialize before opening the file: opening with "w" truncates any
    # previous report, and a streaming dump that fails midway would leave a
    # partial, unparseable file behind.
    payload = json.dumps(result, indent=2)
    path = os.path.join(RESULTS_DIR, "multi_agent_eval.json")
    with open(path, "w", encoding="utf-8") as file:
        file.write(payload)
    print(json.dumps(result["oracle_summary"], indent=2))
    print(json.dumps(result["direct_baseline_summary"], indent=2))
    print(f"Saved {path}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()