Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from app.env import NervousSystemEnv | |
| from app.models import SREAction | |
# Directory that receives the JSON trace written by main(). The path is
# relative, so it resolves against the current working directory.
RESULTS_DIR: str = "results"

# Candidate files offered to the "patch_divergent_code" action, tried in this
# order; the first entry is also the fallback when no candidate is confirmed.
PATCH_FILES: list[str] = [
    "model/transformer.py",
    "model/attention.py",
    "model/feedforward.py",
    "model/embedding.py",
]
def snapshot(
    trace: list[dict[str, Any]],
    label: str,
    action: dict[str, Any] | None,
    result: Any,
    env: NervousSystemEnv,
) -> None:
    """Append one event describing the environment's current state to *trace*.

    Args:
        trace: List that accumulates trace events; mutated in place.
        label: Name recorded for this event. The special label
            ``"final_grade"`` additionally triggers ``env.grade("cascade")``
            and stores its outcome under the ``"grade"`` key.
        action: The action dict that produced this state, or ``None`` when
            the event is not tied to an action.
        result: The object returned by ``env.step``, or ``None``. Its
            ``reward.model_dump()`` is recorded when it has a ``reward``
            attribute; otherwise the recorded reward is ``None``.
        env: Environment queried through ``get_state()`` (and ``grade()``
            for the final event).
    """
    state = env.get_state()

    # Grading only happens for the closing event of the run.
    if label == "final_grade":
        verdict = env.grade("cascade")
        grade_payload = {
            "score": verdict.score,
            "passed": verdict.passed,
            "breakdown": verdict.breakdown,
            "explanation": verdict.explanation,
        }
    else:
        grade_payload = None

    # Events without a step result (or with one lacking a reward) record None.
    if result is None or not hasattr(result, "reward"):
        reward_payload = None
    else:
        reward_payload = result.reward.model_dump()

    training = state.training
    event = {
        "label": label,
        "step": state.step_count,
        "action": action,
        "reward": reward_payload,
        "job_status": training.job_status,
        "throughput": training.throughput_tokens_per_sec,
        "stale_telemetry": state.stale_telemetry,
        "log_retention_steps": state.log_retention_steps,
        "visible_logs": state.visible_logs,
        "grade": grade_payload,
    }
    trace.append(event)
def step(
    env: NervousSystemEnv,
    trace: list[dict[str, Any]],
    label: str,
    action: dict[str, Any],
) -> None:
    """Apply *action* to *env* and record the resulting state in *trace*.

    Args:
        env: Environment to advance.
        trace: List of trace events; a new event is appended to it.
        label: Name stored with the recorded event.
        action: Keyword arguments used to build the ``SREAction`` that is
            passed to ``env.step``; the same dict is stored in the event.
    """
    outcome = env.step(SREAction(**action))
    snapshot(trace=trace, label=label, action=action, result=outcome, env=env)
def main() -> None:
    """Run the scripted long-horizon "cascade" scenario and save its trace.

    The script resets a ``NervousSystemEnv`` (seed 42, task ``"cascade"``),
    plays a fixed sequence of actions, records one trace event per action via
    ``snapshot``/``step``, appends a final graded event, and writes everything
    to ``<RESULTS_DIR>/cascade_long_horizon_trace.json``.

    The numeric suffix in loop-generated labels is the ordinal of that action
    within this script (the first action after reset is number 1).

    Raises:
        RuntimeError: If the observation returned by ``env.reset`` contains no
            node whose ``health_status`` is ``"failed"``, so there is no rank
            to inspect in the first action.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    env = NervousSystemEnv(seed=42)
    obs = env.reset(task_id="cascade", seed=42)
    trace: list[dict[str, Any]] = []
    snapshot(trace, "reset", None, None, env)

    # Look up the node itself (never None) so that a falsy or None node_id is
    # still accepted, and fail with a clear message instead of letting a bare
    # StopIteration escape when no node is marked as failed.
    failed_node = next(
        (node for node in obs.nodes if node.health_status == "failed"),
        None,
    )
    if failed_node is None:
        raise RuntimeError(
            "reset observation for task 'cascade' contains no node with "
            "health_status == 'failed'"
        )
    failing_rank = failed_node.node_id

    # Action 1: inspect the flight recorder of the failed rank.
    step(
        env,
        trace,
        "phase1_diagnose_oom",
        {"action_type": "inspect_flight_recorder", "parameters": {"rank_id": failing_rank}},
    )

    # Actions 2-12: idle.
    for index in range(2, 13):
        step(env, trace, f"monitor_stale_surface_{index}", {"action_type": "noop", "parameters": {}})

    # Action 13: query the NCCL logs again.
    step(
        env,
        trace,
        "refresh_diagnostics_after_stale_warning",
        {"action_type": "query_nccl_logs", "parameters": {"time_window": 8}},
    )

    # Actions 14-21: idle.
    for index in range(14, 22):
        step(env, trace, f"wait_for_phase2_{index}", {"action_type": "noop", "parameters": {}})

    # Action 22: reorder the topology with rack affinity.
    step(
        env,
        trace,
        "phase2_rack_local_topology_fix",
        {"action_type": "topo_reorder", "parameters": {"affinity": "rack"}},
    )

    # Actions 23-50: idle, except for a log query at actions 30, 40 and 50.
    for index in range(23, 51):
        action = (
            {"action_type": "query_nccl_logs", "parameters": {"time_window": 5}}
            if index in {30, 40, 50}
            else {"action_type": "noop", "parameters": {}}
        )
        step(env, trace, f"monitor_recovery_and_delayed_desync_{index}", action)

    # Action 51: wider log query before attempting the patch.
    step(
        env,
        trace,
        "phase3_investigate_desync",
        {"action_type": "query_nccl_logs", "parameters": {"time_window": 10}},
    )

    # Probe candidate files in order and stop at the first one whose reward
    # info mentions "stage 1". If none does, the first candidate is kept as a
    # fallback and the run continues.
    selected_file = PATCH_FILES[0]
    for file_name in PATCH_FILES:
        action = {
            "action_type": "patch_divergent_code",
            "parameters": {"file": file_name, "fix_type": "identify_file"},
        }
        # step() is not used here because the step result is needed below.
        outcome = env.step(SREAction(**action))
        snapshot(trace, f"identify_candidate_{file_name}", action, outcome, env)
        if "stage 1" in outcome.reward.info.lower():
            selected_file = file_name
            break

    step(
        env,
        trace,
        "propose_patch_diff",
        {
            "action_type": "patch_divergent_code",
            "parameters": {"file": selected_file, "fix_type": "propose_diff"},
        },
    )
    step(
        env,
        trace,
        "apply_synchronize_conditional_patch",
        {
            "action_type": "patch_divergent_code",
            "parameters": {"file": selected_file, "fix_type": "synchronize_conditional"},
        },
    )

    # The "final_grade" label makes snapshot() attach the grading result.
    snapshot(trace, "final_grade", None, None, env)

    report = {
        "timestamp": datetime.now().isoformat(),
        "task_id": "cascade",
        "seed": 42,
        "trace_length": len(trace),
        "trace": trace,
    }
    path = os.path.join(RESULTS_DIR, "cascade_long_horizon_trace.json")
    with open(path, "w", encoding="utf-8") as file:
        json.dump(report, file, indent=2)
    print(f"Saved {path} with {len(trace)} trace events")
# Run the scenario only when this file is executed as a script, so that
# importing the module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()