vx7sh's picture
feat(env): add curriculum, challenge generation, coalition, and black-swan mechanics
edc6488
"""GPU cluster state machine implementation."""
from __future__ import annotations
from dataclasses import dataclass
from random import Random
from typing import Any, Literal
from app.config import DESTRUCTIVE_ACTIONS, BASELINE_THROUGHPUT, NUM_NODES, ScenarioConfig
from app.models import SREAction
@dataclass()
class TrainingState:
    """Job-level training metrics that the cluster simulation mutates in place."""

    # Observed training throughput, in tokens per second.
    throughput_tokens_per_sec: float
    # Throughput the job is expected to reach when healthy, in tokens per second.
    target_throughput: float
    # Number of ticks the job has spent in the "stalled" state.
    stalled_steps: int
    # Training steps completed so far.
    current_step: int
    # Lifecycle state of the training job.
    job_status: Literal["running", "stalled", "failed", "recovered"]
@dataclass()
class NodeSnapshot:
    """Simulated per-node state (memory, utilisation, health, XID history)."""

    # Index of the node within the cluster.
    node_id: int
    # GPU memory currently in use, in megabytes.
    gpu_memory_used_mb: float
    # GPU utilisation as a percentage.
    gpu_utilization_pct: float
    # Coarse health classification of the node.
    health_status: Literal["healthy", "degraded", "failed"]
    # NVIDIA XID error codes observed on this node (mutable list).
    xid_errors: list[Literal[79, 48, 31, 74]]
@dataclass()
class FailureSignal:
    """Description of a failure condition to inject into the cluster."""

    # Target node index, or None when the signal is not tied to a node.
    node_id: int | None
    # How severe the failure is.
    severity: Literal["minor", "major", "critical"]
    # Free-form description of what caused the failure.
    cause: str
@dataclass()
class ActionResult:
    """What happened when an SRE action was applied to the cluster."""

    # Whether the action was accepted and took effect.
    success: bool
    # Whether the action is classified as destructive.
    is_destructive: bool
    # Action-specific payload (results, hints, or an "error" message).
    action_output: dict[str, Any]
class ClusterStateMachine:
    """State machine that simulates a distributed GPU cluster.

    One instance models a single episode. Per-node snapshots live in
    ``self.nodes`` and job-level metrics live in ``self.training``.

    - ``inject_scenario_failure`` applies the scenario's failure.
    - ``apply_action`` mutates state in response to SRE actions.
    - ``advance_tick`` evolves the metrics by one simulated step.

    All randomness comes from a ``Random`` seeded with the episode seed.
    """

    # Progress through the staged divergent-code fix:
    # 0 = nothing done, 1 = file identified, 2 = diff proposed, 3 = patched.
    patch_stage: int
    # Rank recorded as divergent by a "desync" scenario; None otherwise.
    divergent_rank_id: int | None
    nodes: list[NodeSnapshot]
    training: TrainingState

    def __init__(self, scenario: ScenarioConfig, seed: int) -> None:
        """Initialize the cluster with a scenario configuration and seed."""
        self.reset(scenario, seed)

    def reset(self, scenario: ScenarioConfig, seed: int) -> None:
        """Reset cluster state with a new scenario and seed.

        Re-seeds the RNG, clears staged-patch progress and the divergent rank,
        and rebuilds healthy nodes plus fresh training metrics. The scenario
        failure is not applied here; ``inject_scenario_failure`` must be
        called separately.
        """
        self._scenario = scenario
        self._seed = seed
        self._random = Random(seed)
        self.patch_stage = 0
        self.divergent_rank_id = None
        self.nodes = self._initialize_nodes()
        # Must run after self._scenario is set: it reads the scenario.
        self.training = self._initialize_training()

    # ------------------------------------------------------------------
    # Failure injection
    # ------------------------------------------------------------------

    def inject_scenario_failure(self) -> None:
        """Inject the scenario failure once at episode start.

        Effects by ``failure_type``:

        * ``oom``: the failing node is marked failed (XID 79) and the job
          stalls.
        * ``congestion``: throughput drops to
          ``target * congestion_bandwidth_pct`` and the job keeps running.
        * ``desync``: the job stalls with 5 stalled steps already counted,
          and the scenario's failing rank is recorded as divergent.
        * ``cascade``: the failing node fails (XID 79), in-range correlated
          nodes are degraded (XID 48), and the job stalls.
        * ``black_swan``: the failing node fails (XID 79), in-range
          false-positive nodes get the same degraded/XID 48 signal,
          throughput is congested, and the job stalls.

        Unknown failure types leave the state unchanged.
        """
        scenario = self._scenario
        failure_type = scenario.failure_type
        if failure_type == "oom":
            self._fail_primary_node()
            self.training.job_status = "stalled"
        elif failure_type == "congestion":
            self.training.throughput_tokens_per_sec = self._congested_throughput()
            self.training.job_status = "running"
        elif failure_type == "desync":
            self.training.job_status = "stalled"
            self.training.stalled_steps = 5
            self.divergent_rank_id = scenario.failing_rank_id
        elif failure_type == "cascade":
            self._fail_primary_node()
            self._degrade_nodes(scenario.correlated_fault_nodes)
            self.training.job_status = "stalled"
        elif failure_type == "black_swan":
            self._fail_primary_node()
            self._degrade_nodes(scenario.false_positive_nodes)
            self.training.throughput_tokens_per_sec = self._congested_throughput()
            self.training.job_status = "stalled"

    def _fail_primary_node(self) -> None:
        """Mark the scenario's failing node as failed and record XID 79."""
        node = self.nodes[self._scenario.failing_node_id]
        node.health_status = "failed"
        node.xid_errors.append(79)

    def _degrade_nodes(self, node_ids: Any) -> None:
        """Degrade each in-range node in ``node_ids`` and record XID 48 once.

        Out-of-range ids are skipped. Only healthy nodes are demoted to
        ``degraded``; XID 48 is added to every in-range node that lacks it.
        """
        for node_id in node_ids:
            if not 0 <= node_id < len(self.nodes):
                continue
            node = self.nodes[node_id]
            if node.health_status == "healthy":
                node.health_status = "degraded"
            if 48 not in node.xid_errors:
                node.xid_errors.append(48)

    def _congested_throughput(self) -> float:
        """Return the target throughput scaled by the scenario's congestion share."""
        return self._scenario.target_throughput * self._scenario.congestion_bandwidth_pct

    def apply_failure(self, failure: FailureSignal | None) -> None:
        """Apply a secondary failure signal to one cluster node.

        Effects by ``severity``:

        * ``minor``: the node is degraded and XID 79 is recorded.
        * ``major``: the node is degraded, XID 48 is recorded, and the job
          stalls.
        * Any other severity (critical): the node fails, XID 74 is recorded,
          and the job fails.

        Status only ever gets worse here. A node that is already ``failed``
        stays failed, and a ``failed`` job is not revived to ``stalled``.
        A ``None`` signal, or one without a ``node_id``, is ignored.

        Raises:
            IndexError: if ``failure.node_id`` is not in ``[0, len(self.nodes))``.
                A plain list index would silently wrap negative ids.
        """
        if failure is None or failure.node_id is None:
            return
        node_id = failure.node_id
        if not 0 <= node_id < len(self.nodes):
            raise IndexError(f"failure node_id {node_id} out of range")
        node = self.nodes[node_id]
        if failure.severity == "minor":
            if node.health_status != "failed":
                node.health_status = "degraded"
            # NOTE(review): NVIDIA's XID catalog lists 79 as "GPU has fallen
            # off the bus", usually a fatal error; confirm it is intended for
            # a *minor* signal.
            node.xid_errors.append(79)
        elif failure.severity == "major":
            if node.health_status != "failed":
                node.health_status = "degraded"
            node.xid_errors.append(48)
            if self.training.job_status != "failed":
                self.training.job_status = "stalled"
        else:
            node.health_status = "failed"
            node.xid_errors.append(74)
            self.training.job_status = "failed"

    # ------------------------------------------------------------------
    # SRE actions
    # ------------------------------------------------------------------

    def apply_action(self, action: SREAction) -> ActionResult:
        """Apply the given SRE action to the cluster.

        Dispatches on ``action.action_type`` to a ``_handle_*`` method. Each
        handler reads ``action.parameters`` and returns
        ``(success, action_output)``. A missing, non-integer, or out-of-range
        parameter produces ``success=False`` with an ``"error"`` entry instead
        of an exception.

        ``is_destructive`` is ``action_type in DESTRUCTIVE_ACTIONS``, except
        that ``noop`` is always reported as non-destructive.

        Returns:
            The outcome. An unknown action type yields ``success=False`` and
            ``{"error": "unsupported action"}``.
        """
        action_type = action.action_type
        if action_type == "noop":
            return ActionResult(success=True, is_destructive=False, action_output={})
        is_destructive = action_type in DESTRUCTIVE_ACTIONS
        handlers = {
            "inspect_flight_recorder": self._handle_inspect_flight_recorder,
            "query_nccl_logs": self._handle_query_nccl_logs,
            "topo_reorder": self._handle_topo_reorder,
            "patch_divergent_code": self._handle_patch_divergent_code,
            "restart_rank": self._handle_restart_rank,
            "reset_ib_interface": self._handle_reset_ib_interface,
            "adjust_sharding_strategy": self._handle_adjust_sharding_strategy,
        }
        handler = handlers.get(action_type)
        if handler is None:
            success, output = False, {"error": "unsupported action"}
        else:
            success, output = handler(action.parameters)
        return ActionResult(
            success=success,
            is_destructive=is_destructive,
            action_output=output,
        )

    def _parse_node_index(self, value: Any, name: str) -> int | str:
        """Convert an agent-supplied id into a valid node index.

        Returns:
            The index as an ``int`` when ``value`` converts with ``int()``
            and lies in ``[0, len(self.nodes))``. Otherwise an error message
            string naming the parameter.
        """
        try:
            index = int(value)
        except (TypeError, ValueError, OverflowError):
            return f"{name} must be an integer"
        if not 0 <= index < len(self.nodes):
            return f"{name} out of range"
        return index

    def _boost_throughput(self, factor: float) -> None:
        """Scale throughput up by ``factor`` without overshooting the target.

        Throughput already at or above target is left unchanged. This matches
        the clamp ``advance_tick`` applies to recovered jobs, so repeating a
        tuning action cannot inflate throughput past the target.
        """
        current = self.training.throughput_tokens_per_sec
        target = self.training.target_throughput
        self.training.throughput_tokens_per_sec = max(current, min(target, current * factor))

    def _handle_inspect_flight_recorder(self, params: Any) -> tuple[bool, dict[str, Any]]:
        """Return the simulated Flight Recorder dump for ``params["rank_id"]``."""
        rank_id = params.get("rank_id")
        if rank_id is None:
            return False, {"error": "rank_id parameter required"}
        rank = self._parse_node_index(rank_id, "rank_id")
        if isinstance(rank, str):
            return False, {"error": rank}
        return True, {"flight_recorder": self._generate_flight_recorder_data(rank)}

    def _handle_query_nccl_logs(self, params: Any) -> tuple[bool, dict[str, Any]]:
        """Return simulated NCCL log lines for ``time_window`` steps (default 10)."""
        try:
            time_window = int(params.get("time_window", 10))
        except (TypeError, ValueError, OverflowError):
            return False, {"error": "time_window must be an integer"}
        return True, {"nccl_logs": self._generate_nccl_logs(time_window)}

    def _handle_topo_reorder(self, params: Any) -> tuple[bool, dict[str, Any]]:
        """Reorder the communication ring by ``affinity``.

        ``affinity == "rack"`` boosts throughput by 1.35x when the current
        ring crosses racks (1.05x otherwise) and marks the job recovered.
        Any other affinity gives a 1.05x boost. Boosts never exceed target.
        """
        affinity = params.get("affinity")
        if affinity is None:
            return False, {"error": "affinity parameter required"}
        if affinity == "rack":
            rack_of = {
                node_id: rack_index
                for rack_index, rack in enumerate(self._scenario.rack_layout)
                for node_id in rack
            }
            node_count = len(self.nodes)
            # The current ring is the identity node order. Check whether any
            # neighbouring pair, including the wrap-around, spans two racks.
            crosses_racks = any(
                rack_of.get(index) != rack_of.get((index + 1) % node_count)
                for index in range(node_count)
            )
            # With randomized rack layout, ring usually crosses racks;
            # rack-local reorder helps most.
            self._boost_throughput(1.35 if crosses_racks else 1.05)
            # NOTE(review): this marks the job recovered regardless of its
            # current state or scenario (including stalled desync/OOM jobs and
            # failed jobs); confirm that is intended.
            self.training.job_status = "recovered"
        else:
            self._boost_throughput(1.05)
        return True, {"affinity": affinity}

    def _handle_patch_divergent_code(self, params: Any) -> tuple[bool, dict[str, Any]]:
        """Drive the staged fix for the divergent code path.

        Each step advances ``patch_stage``:

        1. ``identify_file`` with the scenario's divergent file moves to
           stage 1.
        2. ``propose_diff`` moves to stage 2, and requires stage >= 1.
        3. ``synchronize_conditional`` requires stage >= 2. It recovers the
           job, clears stalled steps, and moves to stage 3.

        Any other ``fix_type`` applied to the divergent file recovers the job
        immediately.
        """
        file = params.get("file")
        fix_type = params.get("fix_type")
        if file is None or fix_type is None:
            return False, {"error": "file and fix_type required"}

        if fix_type == "identify_file":
            if file != self._scenario.divergent_file:
                return False, {"stage": 0, "error": "wrong file"}
            self.patch_stage = max(self.patch_stage, 1)
            return True, {
                "stage": 1,
                "file": file,
                "hint": "File confirmed divergent. Propose a diff next.",
            }

        if fix_type == "propose_diff":
            if self.patch_stage < 1:
                return False, {"error": "Must identify file first"}
            self.patch_stage = max(self.patch_stage, 2)
            return True, {
                "stage": 2,
                "hint": "Diff accepted. Apply synchronize_conditional to fix.",
            }

        if fix_type == "synchronize_conditional":
            if self.patch_stage < 2:
                return False, {
                    "error": "Must propose diff before applying patch",
                    "current_stage": self.patch_stage,
                }
            self.training.job_status = "recovered"
            self.training.stalled_steps = 0
            self.patch_stage = 3
            return True, {"stage": 3, "file": file, "fix_type": fix_type}

        # NOTE(review): this single-shot path skips the identify/propose
        # stages entirely; confirm it is a deliberate legacy shortcut.
        if file == self._scenario.divergent_file:
            self.training.job_status = "recovered"
            self.training.stalled_steps = 0
            self.patch_stage = 3
            return True, {"file": file, "fix_type": fix_type}
        return False, {"error": "divergent file mismatch"}

    def _handle_restart_rank(self, params: Any) -> tuple[bool, dict[str, Any]]:
        """Restart ``rank_id``: mark its node healthy and clear its XID errors."""
        rank_id = params.get("rank_id")
        if rank_id is None:
            return False, {"error": "rank_id parameter required"}
        node_id = self._parse_node_index(rank_id, "rank_id")
        if isinstance(node_id, str):
            return False, {"error": node_id}
        node = self.nodes[node_id]
        node.health_status = "healthy"
        node.xid_errors.clear()
        return True, {"rank_id": node_id}

    def _handle_reset_ib_interface(self, params: Any) -> tuple[bool, dict[str, Any]]:
        """Reset a node's InfiniBand interface; boosts throughput by 1.05x (capped)."""
        node_id = params.get("node_id")
        if node_id is None:
            return False, {"error": "node_id parameter required"}
        index = self._parse_node_index(node_id, "node_id")
        if isinstance(index, str):
            return False, {"error": index}
        self._boost_throughput(1.05)
        return True, {"node_id": index}

    def _handle_adjust_sharding_strategy(self, params: Any) -> tuple[bool, dict[str, Any]]:
        """Switch sharding strategy; boosts throughput by 1.02x (capped)."""
        strategy = params.get("strategy")
        if strategy is None:
            return False, {"error": "strategy parameter required"}
        self._boost_throughput(1.02)
        return True, {"strategy": strategy}

    # ------------------------------------------------------------------
    # Time evolution
    # ------------------------------------------------------------------

    def advance_tick(self) -> None:
        """Advance training metrics by one tick.

        Behavior by job status:

        * ``stalled``: ``stalled_steps`` increments. Throughput decays by
          0.97x while stalled_steps <= 5, by 0.92x up to 15, and by 0.85x
          after that, with a floor of 100. Once stalled_steps >= 16, each
          tick has a 15% chance to degrade one more healthy node (see
          ``_spread_cascading_degradation``). The step counter does not
          advance.
        * ``failed``: throughput becomes 0 and the step counter does not
          advance.
        * ``recovered``: throughput ramps up by ``min(500, 8% of target)``,
          capped at target, and the step counter advances.
        * ``running``: for congestion/cascade/black_swan scenarios, throughput
          first gets +/-5% jitter (floored at 100). Then, in any scenario, it
          creeps up 0.5% if below target. The step counter advances.
        """
        training = self.training
        status = training.job_status

        if status == "stalled":
            training.stalled_steps += 1
            if training.stalled_steps <= 5:
                degrade_factor = 0.97
            elif training.stalled_steps <= 15:
                degrade_factor = 0.92
            else:
                degrade_factor = 0.85
            training.throughput_tokens_per_sec = max(
                100.0, training.throughput_tokens_per_sec * degrade_factor
            )
            # The RNG is only consulted once the stall is long enough, which
            # keeps the random stream identical for shorter stalls.
            if training.stalled_steps >= 16 and self._random.random() < 0.15:
                self._spread_cascading_degradation()
            return

        if status == "failed":
            training.throughput_tokens_per_sec = 0.0
            return

        if status == "recovered":
            ramp_step = min(500.0, training.target_throughput * 0.08)
            training.throughput_tokens_per_sec = min(
                training.target_throughput,
                training.throughput_tokens_per_sec + ramp_step,
            )
            training.current_step += 1
            return

        if status == "running" and self._scenario.failure_type in {
            "congestion",
            "cascade",
            "black_swan",
        }:
            jitter = self._random.uniform(-0.05, 0.05)
            training.throughput_tokens_per_sec = max(
                100.0, training.throughput_tokens_per_sec * (1.0 + jitter)
            )
        if training.throughput_tokens_per_sec < training.target_throughput:
            training.throughput_tokens_per_sec *= 1.005
        training.current_step += 1

    def _spread_cascading_degradation(self) -> None:
        """Degrade one random healthy node (never the scenario's failing node).

        The chosen node gets XID 48 if it lacks it. Nothing happens when no
        healthy candidate remains.
        """
        candidates = [
            node
            for node in self.nodes
            if node.health_status == "healthy"
            and node.node_id != self._scenario.failing_node_id
        ]
        if not candidates:
            return
        node = self._random.choice(candidates)
        node.health_status = "degraded"
        if 48 not in node.xid_errors:
            node.xid_errors.append(48)

    # ------------------------------------------------------------------
    # Initial state and observation generators
    # ------------------------------------------------------------------

    def _initialize_nodes(self) -> list[NodeSnapshot]:
        """Create NUM_NODES healthy node snapshots with no XID errors.

        Memory grows by 250 MB per node index; utilisation starts at 65%.
        """
        # NOTE(review): the memory baseline reuses BASELINE_THROUGHPUT, which by
        # its name is a throughput constant, as megabytes; confirm a dedicated
        # memory constant was not intended.
        return [
            NodeSnapshot(
                node_id=index,
                gpu_memory_used_mb=BASELINE_THROUGHPUT + (index * 250.0),
                gpu_utilization_pct=65.0,
                health_status="healthy",
                xid_errors=[],
            )
            for index in range(NUM_NODES)
        ]

    def _initialize_training(self) -> TrainingState:
        """Create fresh, running training metrics for the scenario.

        Congestion scenarios start at the congested throughput; all others
        start at 99% of target.
        """
        if self._scenario.failure_type == "congestion":
            throughput = self._congested_throughput()
        else:
            throughput = self._scenario.target_throughput * 0.99
        return TrainingState(
            throughput_tokens_per_sec=throughput,
            target_throughput=self._scenario.target_throughput,
            stalled_steps=0,
            current_step=0,
            job_status="running",
        )

    def _generate_flight_recorder_data(self, rank_id: int) -> dict[str, Any]:
        """Generate a deterministic PyTorch 2.5 Flight Recorder payload.

        The payload holds 8 ``nccl:all_reduce`` entries with collective_seq_id
        1230-1237. Timestamps derive from the seed and ``rank_id``.

        For the scenario's failing rank:

        * entry 6 is ``started`` and never finishes;
        * entry 7 is only ``scheduled``;
        * ``pg_status`` reports last completed/started/enqueued as
          1235/1236/1237;
        * a ``circular_buffer_warning`` is added.

        Every other rank shows all entries completed through 1237.
        """
        failing_rank = self._scenario.failing_rank_id
        base_seq_id = 1230
        base_time_ns = (self._seed * 1_000_000) + (rank_id * 10_000)
        is_failing_rank = rank_id == failing_rank

        entries: list[dict[str, Any]] = []
        for entry_index in range(8):
            seq_id = base_seq_id + entry_index
            time_created_ns = base_time_ns + (entry_index * 1000)
            if is_failing_rank and entry_index == 7:
                state = "scheduled"
                time_started_ns: int | None = None
                time_finished_ns: int | None = None
            elif is_failing_rank and entry_index == 6:
                state = "started"
                time_started_ns = time_created_ns + 100
                time_finished_ns = None
            else:
                state = "completed"
                time_started_ns = time_created_ns + 100
                time_finished_ns = time_started_ns + 500
            entries.append(
                {
                    "profiling_name": "nccl:all_reduce",
                    "rank": rank_id,
                    "collective_seq_id": seq_id,
                    "p2p_seq_id": 0,
                    "op_id": seq_id,
                    "state": state,
                    "input_sizes": [[2048, 4096]],
                    "output_sizes": [[2048, 4096]],
                    "input_dtypes": ["Float"],
                    "output_dtypes": ["Float"],
                    "timeout_ms": 1800000,
                    "time_created_ns": time_created_ns,
                    "time_started_ns": time_started_ns,
                    "time_finished_ns": time_finished_ns,
                    "frames": [
                        {
                            "name": "all_reduce",
                            "filename": "torch/distributed/distributed_c10d.py",
                            "line": 2891,
                        },
                        {
                            "name": "nccl:ncclAllReduce",
                            "filename": "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
                            "line": 1456,
                        },
                    ],
                }
            )

        if is_failing_rank:
            last_completed_collective = base_seq_id + 5
            last_started_collective = base_seq_id + 6
            last_enqueued_collective = base_seq_id + 7
        else:
            last_completed_collective = base_seq_id + 7
            last_started_collective = base_seq_id + 7
            last_enqueued_collective = base_seq_id + 7

        payload: dict[str, Any] = {
            "version": "2.5",
            "pg_config": {
                "0": {
                    "name": "default_pg",
                    "desc": "default_pg",
                    "ranks": list(range(NUM_NODES)),
                }
            },
            "pg_status": {
                "0": {
                    "last_enqueued_collective": last_enqueued_collective,
                    "last_started_collective": last_started_collective,
                    "last_completed_collective": last_completed_collective,
                }
            },
            "entries": entries,
            "has_recording": True,
            "record_id": (self._seed * 100) + rank_id,
            "capture_time_ns": base_time_ns + 8000,
            "global_rank": rank_id,
            "world_size": NUM_NODES,
        }
        seq_gap = last_enqueued_collective - last_completed_collective
        if is_failing_rank and seq_gap >= 2:
            payload["circular_buffer_warning"] = (
                "buffer may be overwritten; retrieve immediately"
            )
        return payload

    def _generate_nccl_logs(self, time_window: int) -> list[str]:
        """Generate deterministic NCCL-style log lines.

        Emits one line per (step, rank) for the ``time_window`` steps before
        ``current_step``. Step numbers are negative when ``current_step`` is
        smaller than the window.

        * Once ``i > time_window // 2``, the failing rank logs an AllReduce
          timeout.
        * Every other line reports a completed AllReduce taking 12-45 ms.
        """
        # NOTE(review): the timeout line hard-codes seq_id=1198, is logged by
        # the failing rank about itself, and does not match the Flight
        # Recorder seq ids (1230+); confirm this is the intended observation.
        failing_rank = self._scenario.failing_rank_id
        logs: list[str] = []
        for i in range(time_window):
            step = self.training.current_step - time_window + i
            for rank in range(NUM_NODES):
                if rank == failing_rank and i > time_window // 2:
                    logs.append(
                        f"[{step}][rank{rank}] NCCL INFO: ncclAllReduce() timeout "
                        f"waiting for rank {failing_rank} to join collective "
                        f"(seq_id=1198, op=AllReduce, timeout=1800000ms)"
                    )
                else:
                    duration_ms = 12 + ((rank + i) % 34)
                    logs.append(
                        f"[{step}][rank{rank}] NCCL INFO: ncclAllReduce() "
                        f"seq_id={(1190 + i)} completed in {duration_ms}ms"
                    )
        return logs