"""Failure injection utilities for cluster simulation.""" from __future__ import annotations from typing import TYPE_CHECKING from random import Random from simulation.cluster import FailureSignal if TYPE_CHECKING: from simulation.cluster import ClusterStateMachine class FailureInjector: """Injects rare secondary failures mid-episode.""" def __init__( self, seed: int, secondary_failure_probability: float = 0.05, ) -> None: """Initialize failure injector with a seed.""" self._random = Random(seed) self._secondary_failure_probability = secondary_failure_probability def reset(self, seed: int) -> None: """Reset the injector to a new seed.""" self._random = Random(seed) def configure_difficulty(self, secondary_failure_probability: float) -> None: """Apply adaptive curriculum secondary-failure probability.""" self._secondary_failure_probability = secondary_failure_probability def maybe_inject_secondary(self, step: int) -> FailureSignal | None: """Return a secondary failure signal with small probability.""" if step < 10: return None if self._random.random() < self._secondary_failure_probability: return FailureSignal( node_id=self._random.randint(0, 7), severity=self._random.choice(["minor", "minor", "major"]), cause="secondary_hardware_fault", ) return None def inject_cascade_phase( self, cluster: "ClusterStateMachine", step: int, ) -> int: """Drive deterministic cascade phase transitions by step boundaries.""" failing_node_id = cluster._scenario.failing_node_id failing_rank_id = cluster._scenario.failing_rank_id failing_node = cluster.nodes[failing_node_id] if step <= 20: failing_node.health_status = "failed" if 79 not in failing_node.xid_errors: failing_node.xid_errors.append(79) cluster.training.job_status = "stalled" return 1 if step == 21: cluster.training.throughput_tokens_per_sec = cluster.training.target_throughput * 0.55 cluster.training.job_status = "running" failing_node.health_status = "degraded" if 21 <= step <= 50: return 2 if step == 51: cluster.training.job_status = "stalled" cluster.training.stalled_steps += 10 cluster.divergent_rank_id = failing_rank_id return 3