"""Replay a scripted long-horizon episode of the ``cascade`` task and save its trace.

The script drives ``NervousSystemEnv`` through a fixed action sequence
(diagnose, monitor, topology fix, desync investigation, patch), appends a
snapshot of the environment state after every action, grades the episode,
and writes everything as JSON to ``results/cascade_long_horizon_trace.json``.
"""

from __future__ import annotations

import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

# Make the parent of this script's directory importable so the ``app`` package
# resolves when the script is executed directly.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from app.env import NervousSystemEnv  # noqa: E402  (requires the sys.path tweak above)
from app.models import SREAction  # noqa: E402

RESULTS_DIR = "results"
# Task id and seed shared by env construction, reset, grading and the output metadata.
TASK_ID = "cascade"
SEED = 42
# Candidate files probed in order with ``fix_type="identify_file"``; the first is
# also the fallback target if no probe is confirmed.
PATCH_FILES = [
    "model/transformer.py",
    "model/attention.py",
    "model/feedforward.py",
    "model/embedding.py",
]


def snapshot(
    trace: list[dict[str, Any]],
    label: str,
    action: dict[str, Any] | None,
    result: Any,
    env: NervousSystemEnv,
    task_id: str = TASK_ID,
) -> None:
    """Append one trace event describing the environment's current state.

    Args:
        trace: List the event is appended to (mutated in place).
        label: Event name. ``"final_grade"`` additionally grades the episode
            and stores the grade in the event.
        action: Action dict that led to this state, or ``None`` for
            reset/grading snapshots.
        result: Value returned by ``env.step`` (or ``None``). If it has a
            ``reward`` attribute, that reward is stored via ``model_dump()``.
        env: Environment whose ``get_state()`` is recorded.
        task_id: Task passed to ``env.grade`` for the ``"final_grade"`` event.
    """
    obs = env.get_state()

    grade = None
    if label == "final_grade":
        grade_result = env.grade(task_id)
        grade = {
            "score": grade_result.score,
            "passed": grade_result.passed,
            "breakdown": grade_result.breakdown,
            "explanation": grade_result.explanation,
        }

    reward = (
        result.reward.model_dump()
        if result is not None and hasattr(result, "reward")
        else None
    )

    trace.append(
        {
            "label": label,
            "step": obs.step_count,
            "action": action,
            "reward": reward,
            "job_status": obs.training.job_status,
            "throughput": obs.training.throughput_tokens_per_sec,
            "stale_telemetry": obs.stale_telemetry,
            "log_retention_steps": obs.log_retention_steps,
            "visible_logs": obs.visible_logs,
            "grade": grade,
        }
    )


def step(
    env: NervousSystemEnv,
    trace: list[dict[str, Any]],
    label: str,
    action: dict[str, Any],
) -> Any:
    """Execute ``action`` in ``env``, record a snapshot under ``label``, and
    return the value produced by ``env.step`` so callers can inspect it."""
    result = env.step(SREAction(**action))
    snapshot(trace, label, action, result, env)
    return result


def main() -> None:
    """Run the scripted episode and write the trace JSON to ``RESULTS_DIR``.

    Raises:
        RuntimeError: If the reset observation contains no node whose
            ``health_status`` is ``"failed"``, since the script needs one to
            inspect.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)

    env = NervousSystemEnv(seed=SEED)
    obs = env.reset(task_id=TASK_ID, seed=SEED)
    trace: list[dict[str, Any]] = []
    snapshot(trace, "reset", None, None, env)

    # An explicit default avoids leaking a bare StopIteration with no context.
    failing_rank = next(
        (node.node_id for node in obs.nodes if node.health_status == "failed"),
        None,
    )
    if failing_rank is None:
        raise RuntimeError(
            f"reset of task {TASK_ID!r} produced no node with "
            "health_status == 'failed'; cannot choose a rank to inspect"
        )

    noop: dict[str, Any] = {"action_type": "noop", "parameters": {}}

    # Phase 1: diagnose the failing rank.
    step(
        env,
        trace,
        "phase1_diagnose_oom",
        {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": failing_rank}},
    )
    # Loop indices appear to track the expected step number; the actual value
    # is recorded per event as ``step`` in the trace.
    for index in range(2, 13):
        step(env, trace, f"monitor_stale_surface_{index}", noop)
    step(
        env,
        trace,
        "refresh_diagnostics_after_stale_warning",
        {"action_type": "query_nccl_logs", "parameters": {"time_window": 8}},
    )
    for index in range(14, 22):
        step(env, trace, f"wait_for_phase2_{index}", noop)

    # Phase 2: topology fix, then monitor with periodic log queries.
    step(
        env,
        trace,
        "phase2_rack_local_topology_fix",
        {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}},
    )
    for index in range(23, 51):
        action = (
            {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}
            if index in {30, 40, 50}
            else noop
        )
        step(env, trace, f"monitor_recovery_and_delayed_desync_{index}", action)

    # Phase 3: investigate the desync and locate the divergent file.
    step(
        env,
        trace,
        "phase3_investigate_desync",
        {"action_type": "query_nccl_logs", "parameters": {"time_window": 10}},
    )
    selected_file = PATCH_FILES[0]
    for file_name in PATCH_FILES:
        result = step(
            env,
            trace,
            f"identify_candidate_{file_name}",
            {
                "action_type": "patch_divergent_code",
                "parameters": {"file": file_name, "fix_type": "identify_file"},
            },
        )
        # A reward message mentioning "stage 1" is taken as confirmation that
        # this candidate is the divergent file.
        info = result.reward.info or ""
        if "stage 1" in info.lower():
            selected_file = file_name
            break

    step(
        env,
        trace,
        "propose_patch_diff",
        {
            "action_type": "patch_divergent_code",
            "parameters": {"file": selected_file, "fix_type": "propose_diff"},
        },
    )
    step(
        env,
        trace,
        "apply_synchronize_conditional_patch",
        {
            "action_type": "patch_divergent_code",
            "parameters": {"file": selected_file, "fix_type": "synchronize_conditional"},
        },
    )
    snapshot(trace, "final_grade", None, None, env)

    payload = {
        "timestamp": datetime.now().isoformat(),
        "task_id": TASK_ID,
        "seed": SEED,
        "trace_length": len(trace),
        "trace": trace,
    }
    path = os.path.join(RESULTS_DIR, f"{TASK_ID}_long_horizon_trace.json")
    with open(path, "w", encoding="utf-8") as file:
        json.dump(payload, file, indent=2)
    print(f"Saved {path} with {len(trace)} trace events")


if __name__ == "__main__":
    main()