"""GPU cluster state machine implementation.""" from __future__ import annotations from dataclasses import dataclass from random import Random from typing import Any, Literal from app.config import DESTRUCTIVE_ACTIONS, BASELINE_THROUGHPUT, NUM_NODES, ScenarioConfig from app.models import SREAction @dataclass class TrainingState: """Mutable training metrics for the cluster.""" throughput_tokens_per_sec: float target_throughput: float stalled_steps: int current_step: int job_status: Literal["running", "stalled", "failed", "recovered"] @dataclass class NodeSnapshot: """Internal node state for simulation.""" node_id: int gpu_memory_used_mb: float gpu_utilization_pct: float health_status: Literal["healthy", "degraded", "failed"] xid_errors: list[Literal[79, 48, 31, 74]] @dataclass class FailureSignal: """Represents an injected failure condition.""" node_id: int | None severity: Literal["minor", "major", "critical"] cause: str @dataclass class ActionResult: """Outcome metadata for an SRE action.""" success: bool is_destructive: bool action_output: dict[str, Any] class ClusterStateMachine: """State machine that simulates a distributed GPU cluster.""" def __init__(self, scenario: ScenarioConfig, seed: int) -> None: """Initialize the cluster with a scenario configuration.""" self._scenario = scenario self._seed = seed self._random = Random(seed) self.patch_stage: int = 0 self.divergent_rank_id: int | None = None self.nodes: list[NodeSnapshot] = self._initialize_nodes() self.training = self._initialize_training() def reset(self, scenario: ScenarioConfig, seed: int) -> None: """Reset cluster state with a new scenario and seed.""" self._scenario = scenario self._seed = seed self._random = Random(seed) self.patch_stage = 0 self.divergent_rank_id = None self.nodes = self._initialize_nodes() self.training = self._initialize_training() def inject_scenario_failure(self) -> None: """Inject the scenario failure once at episode start.""" if self._scenario.failure_type == "oom": node = self.nodes[self._scenario.failing_node_id] node.health_status = "failed" node.xid_errors.append(79) self.training.job_status = "stalled" elif self._scenario.failure_type == "congestion": self.training.throughput_tokens_per_sec = ( self._scenario.target_throughput * self._scenario.congestion_bandwidth_pct ) self.training.job_status = "running" elif self._scenario.failure_type == "desync": self.training.job_status = "stalled" self.training.stalled_steps = 5 self.divergent_rank_id = self._scenario.failing_rank_id elif self._scenario.failure_type == "cascade": node = self.nodes[self._scenario.failing_node_id] node.health_status = "failed" node.xid_errors.append(79) for correlated_node_id in self._scenario.correlated_fault_nodes: if 0 <= correlated_node_id < len(self.nodes): correlated_node = self.nodes[correlated_node_id] if correlated_node.health_status == "healthy": correlated_node.health_status = "degraded" if 48 not in correlated_node.xid_errors: correlated_node.xid_errors.append(48) self.training.job_status = "stalled" elif self._scenario.failure_type == "black_swan": node = self.nodes[self._scenario.failing_node_id] node.health_status = "failed" node.xid_errors.append(79) for false_positive_node_id in self._scenario.false_positive_nodes: if 0 <= false_positive_node_id < len(self.nodes): false_node = self.nodes[false_positive_node_id] if false_node.health_status == "healthy": false_node.health_status = "degraded" if 48 not in false_node.xid_errors: false_node.xid_errors.append(48) self.training.throughput_tokens_per_sec = ( self._scenario.target_throughput * self._scenario.congestion_bandwidth_pct ) self.training.job_status = "stalled" def apply_failure(self, failure: FailureSignal | None) -> None: """Apply a secondary failure signal to cluster nodes.""" if failure is None or failure.node_id is None: return node = self.nodes[failure.node_id] if failure.severity == "minor": node.health_status = "degraded" node.xid_errors.append(79) elif failure.severity == "major": node.health_status = "degraded" node.xid_errors.append(48) self.training.job_status = "stalled" else: node.health_status = "failed" node.xid_errors.append(74) self.training.job_status = "failed" def apply_action(self, action: SREAction) -> ActionResult: """Apply the given SRE action to the cluster.""" is_destructive = action.action_type in DESTRUCTIVE_ACTIONS params = action.parameters if action.action_type == "inspect_flight_recorder": rank_id = params.get("rank_id") if rank_id is None: return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "rank_id parameter required"}, ) return ActionResult( success=True, is_destructive=is_destructive, action_output={ "flight_recorder": self._generate_flight_recorder_data(int(rank_id)) }, ) if action.action_type == "query_nccl_logs": time_window = int(params.get("time_window", 10)) return ActionResult( success=True, is_destructive=is_destructive, action_output={"nccl_logs": self._generate_nccl_logs(time_window)}, ) if action.action_type == "topo_reorder": affinity = params.get("affinity") if affinity is None: return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "affinity parameter required"}, ) if affinity == "rack": rack_map = { node_id: rack_index for rack_index, rack in enumerate(self._scenario.rack_layout) for node_id in rack } ring = list(range(len(self.nodes))) crosses_racks = any( rack_map.get(ring[index]) != rack_map.get(ring[(index + 1) % len(ring)]) for index in range(len(ring)) ) # With randomized rack layout, ring usually crosses racks; rack-local reorder helps most. boost = 1.35 if crosses_racks else 1.05 self.training.throughput_tokens_per_sec *= boost self.training.job_status = "recovered" else: self.training.throughput_tokens_per_sec *= 1.05 return ActionResult( success=True, is_destructive=is_destructive, action_output={"affinity": affinity}, ) if action.action_type == "patch_divergent_code": file = params.get("file") fix_type = params.get("fix_type") if file is None or fix_type is None: return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "file and fix_type required"}, ) if fix_type == "identify_file": if file == self._scenario.divergent_file: self.patch_stage = max(self.patch_stage, 1) return ActionResult( success=True, is_destructive=is_destructive, action_output={ "stage": 1, "file": file, "hint": "File confirmed divergent. Propose a diff next.", }, ) return ActionResult( success=False, is_destructive=is_destructive, action_output={"stage": 0, "error": "wrong file"}, ) if fix_type == "propose_diff": if self.patch_stage >= 1: self.patch_stage = max(self.patch_stage, 2) return ActionResult( success=True, is_destructive=is_destructive, action_output={ "stage": 2, "hint": "Diff accepted. Apply synchronize_conditional to fix.", }, ) return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "Must identify file first"}, ) if fix_type == "synchronize_conditional": if self.patch_stage >= 2: self.training.job_status = "recovered" self.training.stalled_steps = 0 self.patch_stage = 3 return ActionResult( success=True, is_destructive=is_destructive, action_output={"stage": 3, "file": file, "fix_type": fix_type}, ) return ActionResult( success=False, is_destructive=is_destructive, action_output={ "error": "Must propose diff before applying patch", "current_stage": self.patch_stage, }, ) if file == self._scenario.divergent_file: self.training.job_status = "recovered" self.training.stalled_steps = 0 self.patch_stage = 3 return ActionResult( success=True, is_destructive=is_destructive, action_output={"file": file, "fix_type": fix_type}, ) return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "divergent file mismatch"}, ) if action.action_type == "restart_rank": rank_id = params.get("rank_id") if rank_id is None: return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "rank_id parameter required"}, ) node_id = int(rank_id) if node_id < 0 or node_id >= len(self.nodes): return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "rank_id out of range"}, ) node = self.nodes[node_id] node.health_status = "healthy" node.xid_errors.clear() return ActionResult( success=True, is_destructive=is_destructive, action_output={"rank_id": node_id}, ) if action.action_type == "reset_ib_interface": node_id = params.get("node_id") if node_id is None: return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "node_id parameter required"}, ) self.training.throughput_tokens_per_sec *= 1.05 return ActionResult( success=True, is_destructive=is_destructive, action_output={"node_id": int(node_id)}, ) if action.action_type == "adjust_sharding_strategy": strategy = params.get("strategy") if strategy is None: return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "strategy parameter required"}, ) self.training.throughput_tokens_per_sec *= 1.02 return ActionResult( success=True, is_destructive=is_destructive, action_output={"strategy": strategy}, ) if action.action_type == "noop": return ActionResult(success=True, is_destructive=False, action_output={}) return ActionResult( success=False, is_destructive=is_destructive, action_output={"error": "unsupported action"}, ) def advance_tick(self) -> None: """Advance training metrics by one tick.""" if self.training.job_status == "stalled": self.training.stalled_steps += 1 if self.training.stalled_steps <= 5: degrade_factor = 0.97 elif self.training.stalled_steps <= 15: degrade_factor = 0.92 else: degrade_factor = 0.85 self.training.throughput_tokens_per_sec = max( 100.0, self.training.throughput_tokens_per_sec * degrade_factor ) if self.training.stalled_steps >= 16 and self._random.random() < 0.15: cascading_candidates = [ node for node in self.nodes if node.health_status == "healthy" and node.node_id != self._scenario.failing_node_id ] if cascading_candidates: node = self._random.choice(cascading_candidates) node.health_status = "degraded" if 48 not in node.xid_errors: node.xid_errors.append(48) return if self.training.job_status == "failed": self.training.throughput_tokens_per_sec = 0.0 return if self.training.job_status == "recovered": ramp_step = min(500.0, self.training.target_throughput * 0.08) self.training.throughput_tokens_per_sec = min( self.training.target_throughput, self.training.throughput_tokens_per_sec + ramp_step, ) self.training.current_step += 1 return if ( self.training.job_status == "running" and self._scenario.failure_type in {"congestion", "cascade", "black_swan"} ): jitter = self._random.uniform(-0.05, 0.05) self.training.throughput_tokens_per_sec = max( 100.0, self.training.throughput_tokens_per_sec * (1.0 + jitter), ) if self.training.throughput_tokens_per_sec < self.training.target_throughput: self.training.throughput_tokens_per_sec *= 1.005 self.training.current_step += 1 def _initialize_nodes(self) -> list[NodeSnapshot]: """Create a fresh set of node snapshots.""" return [ NodeSnapshot( node_id=index, gpu_memory_used_mb=BASELINE_THROUGHPUT + (index * 250.0), gpu_utilization_pct=65.0, health_status="healthy", xid_errors=[], ) for index in range(NUM_NODES) ] def _initialize_training(self) -> TrainingState: """Create a fresh training state for the scenario.""" if self._scenario.failure_type == "congestion": throughput = self._scenario.target_throughput * self._scenario.congestion_bandwidth_pct else: throughput = self._scenario.target_throughput * 0.99 return TrainingState( throughput_tokens_per_sec=throughput, target_throughput=self._scenario.target_throughput, stalled_steps=0, current_step=0, job_status="running", ) def _generate_flight_recorder_data(self, rank_id: int) -> dict[str, Any]: """Generate a deterministic PyTorch 2.5 Flight Recorder payload.""" failing_rank = self._scenario.failing_rank_id base_seq_id = 1230 base_time_ns = (self._seed * 1_000_000) + (rank_id * 10_000) is_failing_rank = rank_id == failing_rank entries: list[dict[str, Any]] = [] for entry_index in range(8): seq_id = base_seq_id + entry_index time_created_ns = base_time_ns + (entry_index * 1000) if is_failing_rank and entry_index == 7: state = "scheduled" time_started_ns: int | None = None time_finished_ns: int | None = None elif is_failing_rank and entry_index == 6: state = "started" time_started_ns = time_created_ns + 100 time_finished_ns = None else: state = "completed" time_started_ns = time_created_ns + 100 time_finished_ns = time_started_ns + 500 entries.append( { "profiling_name": "nccl:all_reduce", "rank": rank_id, "collective_seq_id": seq_id, "p2p_seq_id": 0, "op_id": seq_id, "state": state, "input_sizes": [[2048, 4096]], "output_sizes": [[2048, 4096]], "input_dtypes": ["Float"], "output_dtypes": ["Float"], "timeout_ms": 1800000, "time_created_ns": time_created_ns, "time_started_ns": time_started_ns, "time_finished_ns": time_finished_ns, "frames": [ { "name": "all_reduce", "filename": "torch/distributed/distributed_c10d.py", "line": 2891, }, { "name": "nccl:ncclAllReduce", "filename": "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp", "line": 1456, }, ], } ) if is_failing_rank: last_completed_collective = base_seq_id + 5 last_started_collective = base_seq_id + 6 last_enqueued_collective = base_seq_id + 7 else: last_completed_collective = base_seq_id + 7 last_started_collective = base_seq_id + 7 last_enqueued_collective = base_seq_id + 7 payload: dict[str, Any] = { "version": "2.5", "pg_config": { "0": { "name": "default_pg", "desc": "default_pg", "ranks": list(range(NUM_NODES)), } }, "pg_status": { "0": { "last_enqueued_collective": last_enqueued_collective, "last_started_collective": last_started_collective, "last_completed_collective": last_completed_collective, } }, "entries": entries, "has_recording": True, "record_id": (self._seed * 100) + rank_id, "capture_time_ns": base_time_ns + 8000, "global_rank": rank_id, "world_size": NUM_NODES, } seq_gap = last_enqueued_collective - last_completed_collective if is_failing_rank and seq_gap >= 2: payload["circular_buffer_warning"] = ( "buffer may be overwritten; retrieve immediately" ) return payload def _generate_nccl_logs(self, time_window: int) -> list[str]: """Generate deterministic NCCL-style log lines.""" failing_rank = self._scenario.failing_rank_id logs: list[str] = [] for i in range(time_window): step = self.training.current_step - time_window + i for rank in range(NUM_NODES): if rank == failing_rank and i > time_window // 2: logs.append( f"[{step}][rank{rank}] NCCL INFO: ncclAllReduce() timeout " f"waiting for rank {failing_rank} to join collective " f"(seq_id=1198, op=AllReduce, timeout=1800000ms)" ) else: duration_ms = 12 + ((rank + i) % 34) logs.append( f"[{step}][rank{rank}] NCCL INFO: ncclAllReduce() " f"seq_id={(1190 + i)} completed in {duration_ms}ms" ) return logs