Spaces:
Sleeping
Sleeping
| """GPU cluster state machine implementation.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from random import Random | |
| from typing import Any, Literal | |
| from app.config import DESTRUCTIVE_ACTIONS, BASELINE_THROUGHPUT, NUM_NODES, ScenarioConfig | |
| from app.models import SREAction | |
@dataclass
class TrainingState:
    """Mutable training metrics for the cluster.

    The ``@dataclass`` decorator supplies the keyword-argument constructor
    that the cluster state machine relies on when it builds a fresh state.
    """

    # Current throughput; decays while stalled and is zeroed once failed.
    throughput_tokens_per_sec: float
    # Ceiling that throughput ramps back toward after a recovery.
    target_throughput: float
    # Stall counter; incremented once per tick while the job is stalled.
    stalled_steps: int
    # Training step counter; advances on running and recovered ticks only.
    current_step: int
    job_status: Literal["running", "stalled", "failed", "recovered"]
@dataclass
class NodeSnapshot:
    """Internal node state for simulation.

    The ``@dataclass`` decorator supplies the keyword-argument constructor
    used when the node list is (re)built.
    """

    # Index of the node within the cluster's node list.
    node_id: int
    gpu_memory_used_mb: float
    gpu_utilization_pct: float
    health_status: Literal["healthy", "degraded", "failed"]
    # XID codes recorded against this node; cleared when its rank restarts.
    xid_errors: list[Literal[79, 48, 31, 74]]
@dataclass
class FailureSignal:
    """Represents an injected failure condition.

    The ``@dataclass`` decorator supplies the constructor so a signal can be
    built from its three fields.
    """

    # Target node index; ``None`` means the signal is not tied to a node.
    node_id: int | None
    severity: Literal["minor", "major", "critical"]
    # Free-text description of what triggered the failure.
    cause: str
@dataclass
class ActionResult:
    """Outcome metadata for an SRE action.

    The ``@dataclass`` decorator supplies the keyword-argument constructor
    used wherever an action outcome is reported.
    """

    # Whether the action was accepted and applied.
    success: bool
    # Whether the action type is classified as destructive.
    is_destructive: bool
    # Action-specific payload; failures carry an "error" message here.
    action_output: dict[str, Any]
| class ClusterStateMachine: | |
| """State machine that simulates a distributed GPU cluster.""" | |
def __init__(self, scenario: ScenarioConfig, seed: int) -> None:
    """Build a cluster for ``scenario`` with a seeded random source.

    Args:
        scenario: Scenario configuration that drives failure injection.
        seed: Seed for the private ``Random`` instance; also stored as-is.
    """
    self._scenario, self._seed = scenario, seed
    self._random = Random(seed)

    # Progress markers start cleared.
    self.patch_stage: int = 0
    self.divergent_rank_id: int | None = None

    # The scenario must already be stored: the training initializer reads it.
    self.nodes: list[NodeSnapshot] = self._initialize_nodes()
    self.training = self._initialize_training()
def reset(self, scenario: ScenarioConfig, seed: int) -> None:
    """Discard all cluster state and start over with a new scenario.

    Args:
        scenario: Scenario configuration for the next episode.
        seed: Seed for a brand-new ``Random`` instance; also stored as-is.
    """
    self._scenario, self._seed = scenario, seed
    self._random = Random(seed)

    # Clear progress markers left over from the previous episode.
    self.patch_stage = 0
    self.divergent_rank_id = None

    # Rebuild nodes and training metrics from the freshly stored scenario.
    self.nodes = self._initialize_nodes()
    self.training = self._initialize_training()
def inject_scenario_failure(self) -> None:
    """Inject the scenario's primary failure; meant to run once per episode.

    Behaviour by ``failure_type``:
        * ``"oom"``: the failing node is marked failed (XID 79); job stalls.
        * ``"congestion"``: throughput drops to the congested fraction of
          the target; job keeps running.
        * ``"desync"``: job stalls with five stalled steps already counted
          and the divergent rank is recorded.
        * ``"cascade"``: as ``"oom"``, plus correlated nodes are degraded.
        * ``"black_swan"``: as ``"oom"``, plus false-positive nodes are
          degraded and throughput drops to the congested fraction.

    Any other failure type leaves the cluster untouched.
    """
    scenario = self._scenario
    training = self.training
    kind = scenario.failure_type

    def fail_primary_node() -> None:
        # The scenario's failing node goes hard-down with XID 79.
        primary = self.nodes[scenario.failing_node_id]
        primary.health_status = "failed"
        primary.xid_errors.append(79)

    def degrade_bystanders(node_ids) -> None:
        # Ids outside the node list are skipped. Only healthy nodes are
        # downgraded, but XID 48 is recorded (once) on every valid node.
        node_count = len(self.nodes)
        for node_id in node_ids:
            if not 0 <= node_id < node_count:
                continue
            bystander = self.nodes[node_id]
            if bystander.health_status == "healthy":
                bystander.health_status = "degraded"
            if 48 not in bystander.xid_errors:
                bystander.xid_errors.append(48)

    def congested_throughput() -> float:
        return scenario.target_throughput * scenario.congestion_bandwidth_pct

    if kind == "oom":
        fail_primary_node()
        training.job_status = "stalled"
        return

    if kind == "congestion":
        training.throughput_tokens_per_sec = congested_throughput()
        training.job_status = "running"
        return

    if kind == "desync":
        training.job_status = "stalled"
        training.stalled_steps = 5
        self.divergent_rank_id = scenario.failing_rank_id
        return

    if kind == "cascade":
        fail_primary_node()
        degrade_bystanders(scenario.correlated_fault_nodes)
        training.job_status = "stalled"
        return

    if kind == "black_swan":
        fail_primary_node()
        degrade_bystanders(scenario.false_positive_nodes)
        training.throughput_tokens_per_sec = congested_throughput()
        training.job_status = "stalled"
def apply_failure(self, failure: FailureSignal | None) -> None:
    """Apply a secondary failure signal to a single cluster node.

    Args:
        failure: Signal to apply. It is ignored when it is ``None``, has no
            ``node_id``, or has a ``node_id`` that does not index
            ``self.nodes``.

    Effects by ``failure.severity``:
        * ``"minor"``: node becomes degraded and records XID 79; the job
          status is left unchanged.
        * ``"major"``: node becomes degraded, records XID 48, and the job
          is marked stalled.
        * anything else (``"critical"``): node becomes failed, records
          XID 74, and the job is marked failed.
    """
    if failure is None or failure.node_id is None:
        return

    # Skip ids outside the node list, the same way inject_scenario_failure
    # skips out-of-range correlated ids. Without this guard a negative id
    # would wrap around and silently damage an unrelated node.
    if not 0 <= failure.node_id < len(self.nodes):
        return

    node = self.nodes[failure.node_id]
    if failure.severity == "minor":
        node.health_status = "degraded"
        node.xid_errors.append(79)
    elif failure.severity == "major":
        node.health_status = "degraded"
        node.xid_errors.append(48)
        self.training.job_status = "stalled"
    else:
        node.health_status = "failed"
        node.xid_errors.append(74)
        self.training.job_status = "failed"
def apply_action(self, action: SREAction) -> ActionResult:
    """Apply the given SRE action to the cluster and report the outcome.

    A required parameter that is missing or cannot be converted to an
    integer produces a failed ``ActionResult`` carrying an ``"error"``
    message, and the cluster is left unmodified in that case.

    Supported ``action.action_type`` values:
        * ``"inspect_flight_recorder"`` (``rank_id``): returns the flight
          recorder payload for that rank.
        * ``"query_nccl_logs"`` (optional ``time_window``, default 10):
          returns generated NCCL log lines.
        * ``"topo_reorder"`` (``affinity``): ``"rack"`` boosts throughput
          by 1.35x when the node ring crosses racks (1.05x otherwise) and
          marks the job recovered; any other affinity gives a 1.05x boost.
        * ``"patch_divergent_code"`` (``file``, ``fix_type``): staged fix
          via ``identify_file`` -> ``propose_diff`` ->
          ``synchronize_conditional``; any other ``fix_type`` recovers the
          job directly when ``file`` is the scenario's divergent file.
        * ``"restart_rank"`` (``rank_id``): marks the node healthy and
          clears its XID errors.
        * ``"reset_ib_interface"`` (``node_id``): 1.05x throughput boost.
        * ``"adjust_sharding_strategy"`` (``strategy``): 1.02x boost.
        * ``"noop"``: always succeeds and is never flagged destructive.

    Args:
        action: The action to apply; ``action.parameters`` is read with
            ``.get`` and is therefore expected to be dict-like.

    Returns:
        An ``ActionResult`` describing success, destructiveness, and the
        action-specific output payload.
    """
    is_destructive = action.action_type in DESTRUCTIVE_ACTIONS
    params = action.parameters
    action_type = action.action_type

    def result(success: bool, output: dict[str, Any]) -> ActionResult:
        return ActionResult(
            success=success,
            is_destructive=is_destructive,
            action_output=output,
        )

    def fail(message: str) -> ActionResult:
        return result(False, {"error": message})

    def to_int(value: Any) -> int | None:
        # Parameters come from the acting agent, so a value such as "abc"
        # or [] must become a failed result rather than an exception.
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    if action_type == "inspect_flight_recorder":
        raw_rank = params.get("rank_id")
        if raw_rank is None:
            return fail("rank_id parameter required")
        rank_id = to_int(raw_rank)
        if rank_id is None:
            return fail("rank_id must be an integer")
        return result(
            True,
            {"flight_recorder": self._generate_flight_recorder_data(rank_id)},
        )

    if action_type == "query_nccl_logs":
        time_window = to_int(params.get("time_window", 10))
        if time_window is None:
            return fail("time_window must be an integer")
        return result(True, {"nccl_logs": self._generate_nccl_logs(time_window)})

    if action_type == "topo_reorder":
        affinity = params.get("affinity")
        if affinity is None:
            return fail("affinity parameter required")
        if affinity == "rack":
            rack_map = {
                node_id: rack_index
                for rack_index, rack in enumerate(self._scenario.rack_layout)
                for node_id in rack
            }
            # The ring visits nodes in index order and wraps back to node 0.
            ring = list(range(len(self.nodes)))
            crosses_racks = any(
                rack_map.get(current) != rack_map.get(following)
                for current, following in zip(ring, ring[1:] + ring[:1])
            )
            # With randomized rack layout, ring usually crosses racks;
            # rack-local reorder helps most.
            boost = 1.35 if crosses_racks else 1.05
            self.training.throughput_tokens_per_sec *= boost
            self.training.job_status = "recovered"
        else:
            self.training.throughput_tokens_per_sec *= 1.05
        return result(True, {"affinity": affinity})

    if action_type == "patch_divergent_code":
        target_file = params.get("file")
        fix_type = params.get("fix_type")
        if target_file is None or fix_type is None:
            return fail("file and fix_type required")

        if fix_type == "identify_file":
            if target_file != self._scenario.divergent_file:
                return result(False, {"stage": 0, "error": "wrong file"})
            # max() keeps a later stage from being rolled back to 1.
            self.patch_stage = max(self.patch_stage, 1)
            return result(
                True,
                {
                    "stage": 1,
                    "file": target_file,
                    "hint": "File confirmed divergent. Propose a diff next.",
                },
            )

        if fix_type == "propose_diff":
            if self.patch_stage < 1:
                return fail("Must identify file first")
            self.patch_stage = max(self.patch_stage, 2)
            return result(
                True,
                {
                    "stage": 2,
                    "hint": "Diff accepted. Apply synchronize_conditional to fix.",
                },
            )

        if fix_type == "synchronize_conditional":
            if self.patch_stage < 2:
                return result(
                    False,
                    {
                        "error": "Must propose diff before applying patch",
                        "current_stage": self.patch_stage,
                    },
                )
            self.training.job_status = "recovered"
            self.training.stalled_steps = 0
            self.patch_stage = 3
            return result(
                True,
                {"stage": 3, "file": target_file, "fix_type": fix_type},
            )

        # Any other fix_type skips the staged flow: naming the right file
        # recovers the job directly.
        if target_file != self._scenario.divergent_file:
            return fail("divergent file mismatch")
        self.training.job_status = "recovered"
        self.training.stalled_steps = 0
        self.patch_stage = 3
        return result(True, {"file": target_file, "fix_type": fix_type})

    if action_type == "restart_rank":
        raw_rank = params.get("rank_id")
        if raw_rank is None:
            return fail("rank_id parameter required")
        node_id = to_int(raw_rank)
        if node_id is None:
            return fail("rank_id must be an integer")
        if not 0 <= node_id < len(self.nodes):
            return fail("rank_id out of range")
        node = self.nodes[node_id]
        node.health_status = "healthy"
        node.xid_errors.clear()
        return result(True, {"rank_id": node_id})

    if action_type == "reset_ib_interface":
        raw_node = params.get("node_id")
        if raw_node is None:
            return fail("node_id parameter required")
        # Convert before touching throughput so a bad value cannot leave
        # the boost half-applied.
        node_id = to_int(raw_node)
        if node_id is None:
            return fail("node_id must be an integer")
        self.training.throughput_tokens_per_sec *= 1.05
        return result(True, {"node_id": node_id})

    if action_type == "adjust_sharding_strategy":
        strategy = params.get("strategy")
        if strategy is None:
            return fail("strategy parameter required")
        self.training.throughput_tokens_per_sec *= 1.02
        return result(True, {"strategy": strategy})

    if action_type == "noop":
        # A no-op is reported as non-destructive regardless of configuration.
        return ActionResult(success=True, is_destructive=False, action_output={})

    return fail("unsupported action")
def advance_tick(self) -> None:
    """Advance training metrics by one tick according to the job status.

    * ``"stalled"``: the stall counter grows and throughput decays (floor
      100.0) by 3%, 8% or 15% depending on how long the stall has lasted.
      From the 16th stalled step on, each tick has a 15% chance of
      degrading one healthy node other than the scenario's failing node.
    * ``"failed"``: throughput is forced to zero.
    * ``"recovered"``: throughput ramps toward the target by a fixed step
      (8% of target, at most 500) and the step counter advances.
    * otherwise: for congestion, cascade and black-swan scenarios a running
      job gets +/-5% jitter (floor 100.0); throughput below target grows
      by 0.5%; the step counter advances.
    """
    training = self.training
    status = training.job_status

    if status == "stalled":
        training.stalled_steps += 1
        stalled_for = training.stalled_steps

        # The longer the stall, the faster throughput bleeds away.
        if stalled_for > 15:
            decay = 0.85
        elif stalled_for > 5:
            decay = 0.92
        else:
            decay = 0.97
        training.throughput_tokens_per_sec = max(
            100.0, training.throughput_tokens_per_sec * decay
        )

        # The random draw only happens once the stall is 16+ steps old.
        if stalled_for < 16 or self._random.random() >= 0.15:
            return
        healthy_bystanders = [
            candidate
            for candidate in self.nodes
            if candidate.health_status == "healthy"
            and candidate.node_id != self._scenario.failing_node_id
        ]
        if not healthy_bystanders:
            return
        victim = self._random.choice(healthy_bystanders)
        victim.health_status = "degraded"
        if 48 not in victim.xid_errors:
            victim.xid_errors.append(48)
        return

    if status == "failed":
        training.throughput_tokens_per_sec = 0.0
        return

    if status == "recovered":
        ramp = min(500.0, training.target_throughput * 0.08)
        training.throughput_tokens_per_sec = min(
            training.target_throughput,
            training.throughput_tokens_per_sec + ramp,
        )
        training.current_step += 1
        return

    jittery_scenarios = {"congestion", "cascade", "black_swan"}
    if status == "running" and self._scenario.failure_type in jittery_scenarios:
        wobble = self._random.uniform(-0.05, 0.05)
        training.throughput_tokens_per_sec = max(
            100.0,
            training.throughput_tokens_per_sec * (1.0 + wobble),
        )

    if training.throughput_tokens_per_sec < training.target_throughput:
        training.throughput_tokens_per_sec *= 1.005
    training.current_step += 1
def _initialize_nodes(self) -> list[NodeSnapshot]:
    """Build ``NUM_NODES`` healthy node snapshots with no XID errors.

    Returns:
        A new list whose element at position ``i`` has ``node_id == i``.
    """
    snapshots: list[NodeSnapshot] = []
    for node_index in range(NUM_NODES):
        # NOTE(review): the memory figure is seeded from BASELINE_THROUGHPUT
        # plus 250 per node index - confirm a throughput constant is really
        # the intended baseline for a memory field.
        memory_used_mb = BASELINE_THROUGHPUT + (node_index * 250.0)
        snapshots.append(
            NodeSnapshot(
                node_id=node_index,
                gpu_memory_used_mb=memory_used_mb,
                gpu_utilization_pct=65.0,
                health_status="healthy",
                xid_errors=[],
            )
        )
    return snapshots
def _initialize_training(self) -> TrainingState:
    """Build the starting training state for the current scenario.

    Congestion scenarios start at the congested fraction of the target
    throughput; every other scenario starts at 99% of target. The job
    always starts in the running status with both counters at zero.
    """
    scenario = self._scenario
    is_congested = scenario.failure_type == "congestion"
    start_ratio = scenario.congestion_bandwidth_pct if is_congested else 0.99

    return TrainingState(
        throughput_tokens_per_sec=scenario.target_throughput * start_ratio,
        target_throughput=scenario.target_throughput,
        stalled_steps=0,
        current_step=0,
        job_status="running",
    )
def _generate_flight_recorder_data(self, rank_id: int) -> dict[str, Any]:
    """Build a deterministic flight-recorder style payload for one rank.

    The payload holds eight ``nccl:all_reduce`` entries with sequence ids
    1230-1237. Timestamps derive only from the stored seed and ``rank_id``.
    For the scenario's failing rank the last two entries are left
    unfinished (``"started"`` then ``"scheduled"``) and a circular-buffer
    warning is attached; every other rank shows all entries completed.

    Args:
        rank_id: Rank to report on; it is not range-checked here.

    Returns:
        A dictionary payload labelled as version ``"2.5"``.
    """
    first_seq_id = 1230
    on_failing_rank = rank_id == self._scenario.failing_rank_id
    clock_origin_ns = (self._seed * 1_000_000) + (rank_id * 10_000)

    def lifecycle(slot: int, created_ns: int) -> tuple[str, int | None, int | None]:
        # Only the failing rank's two newest entries are left incomplete.
        if on_failing_rank and slot == 7:
            return "scheduled", None, None
        started_ns = created_ns + 100
        if on_failing_rank and slot == 6:
            return "started", started_ns, None
        return "completed", started_ns, started_ns + 500

    entries: list[dict[str, Any]] = []
    for slot in range(8):
        seq_id = first_seq_id + slot
        created_ns = clock_origin_ns + (slot * 1000)
        state, started_ns, finished_ns = lifecycle(slot, created_ns)
        entry: dict[str, Any] = {
            "profiling_name": "nccl:all_reduce",
            "rank": rank_id,
            "collective_seq_id": seq_id,
            "p2p_seq_id": 0,
            "op_id": seq_id,
            "state": state,
            "input_sizes": [[2048, 4096]],
            "output_sizes": [[2048, 4096]],
            "input_dtypes": ["Float"],
            "output_dtypes": ["Float"],
            "timeout_ms": 1800000,
            "time_created_ns": created_ns,
            "time_started_ns": started_ns,
            "time_finished_ns": finished_ns,
            "frames": [
                {
                    "name": "all_reduce",
                    "filename": "torch/distributed/distributed_c10d.py",
                    "line": 2891,
                },
                {
                    "name": "nccl:ncclAllReduce",
                    "filename": "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
                    "line": 1456,
                },
            ],
        }
        entries.append(entry)

    # Summary counters mirror the per-entry states built above.
    last_enqueued = first_seq_id + 7
    if on_failing_rank:
        last_started = first_seq_id + 6
        last_completed = first_seq_id + 5
    else:
        last_started = last_enqueued
        last_completed = last_enqueued

    payload: dict[str, Any] = {
        "version": "2.5",
        "pg_config": {
            "0": {
                "name": "default_pg",
                "desc": "default_pg",
                "ranks": list(range(NUM_NODES)),
            }
        },
        "pg_status": {
            "0": {
                "last_enqueued_collective": last_enqueued,
                "last_started_collective": last_started,
                "last_completed_collective": last_completed,
            }
        },
        "entries": entries,
        "has_recording": True,
        "record_id": (self._seed * 100) + rank_id,
        "capture_time_ns": clock_origin_ns + 8000,
        "global_rank": rank_id,
        "world_size": NUM_NODES,
    }

    backlog = last_enqueued - last_completed
    if on_failing_rank and backlog >= 2:
        payload["circular_buffer_warning"] = (
            "buffer may be overwritten; retrieve immediately"
        )
    return payload
def _generate_nccl_logs(self, time_window: int) -> list[str]:
    """Build deterministic NCCL-style log lines for the last few steps.

    One line is produced per step per rank, for ``time_window`` steps
    ending just before the current training step. In the second half of
    the window the scenario's failing rank reports an all-reduce timeout;
    every other line reports a completed all-reduce whose duration depends
    only on the rank and the position within the window.

    Args:
        time_window: Number of steps to cover; zero or a negative value
            produces an empty list.

    Returns:
        Log lines ordered by step, then by rank.
    """
    failing_rank = self._scenario.failing_rank_id
    # NOTE(review): step labels go negative while current_step is smaller
    # than time_window - confirm that is acceptable to log consumers.
    first_step = self.training.current_step - time_window
    midpoint = time_window // 2

    lines: list[str] = []
    for offset in range(time_window):
        step = first_step + offset
        in_timeout_phase = offset > midpoint
        for rank in range(NUM_NODES):
            if in_timeout_phase and rank == failing_rank:
                # NOTE(review): the timeout is logged by the failing rank and
                # names that same rank as the peer - confirm intent.
                line = (
                    f"[{step}][rank{rank}] NCCL INFO: ncclAllReduce() timeout "
                    f"waiting for rank {failing_rank} to join collective "
                    f"(seq_id=1198, op=AllReduce, timeout=1800000ms)"
                )
            else:
                duration_ms = 12 + ((rank + offset) % 34)
                line = (
                    f"[{step}][rank{rank}] NCCL INFO: ncclAllReduce() "
                    f"seq_id={(1190 + offset)} completed in {duration_ms}ms"
                )
            lines.append(line)
    return lines