Spaces:
Sleeping
Sleeping
File size: 2,602 Bytes
"""Failure injection utilities for cluster simulation."""
from __future__ import annotations
from typing import TYPE_CHECKING
from random import Random
from simulation.cluster import FailureSignal
if TYPE_CHECKING:
from simulation.cluster import ClusterStateMachine
class FailureInjector:
    """Injects rare secondary failures mid-episode.

    The injector owns a seeded ``random.Random`` instance, so the sequence of
    secondary failures is reproducible for a given seed. It also exposes a
    step-driven cascade helper that mutates a cluster object in place.
    """

    # --- Secondary-failure tuning -------------------------------------
    # Steps below this value never yield a secondary failure (and do not
    # consume a random draw).
    _SECONDARY_WARMUP_STEPS = 10
    # Inclusive bounds handed to ``Random.randint`` when picking a node id.
    _SECONDARY_NODE_ID_MIN = 0
    _SECONDARY_NODE_ID_MAX = 7
    # "minor" appears twice so it is drawn twice as often as "major".
    _SECONDARY_SEVERITIES = ("minor", "minor", "major")
    _SECONDARY_CAUSE = "secondary_hardware_fault"

    # --- Cascade tuning -----------------------------------------------
    # Last step (inclusive) that is still reported as phase 1 / phase 2.
    _PHASE_ONE_LAST_STEP = 20
    _PHASE_TWO_LAST_STEP = 50
    # Code appended to the failing node's ``xid_errors`` during phase 1.
    _CASCADE_XID_CODE = 79
    # Fraction of ``target_throughput`` applied when phase 2 begins.
    _PHASE_TWO_THROUGHPUT_FACTOR = 0.55
    # Amount added to ``stalled_steps`` when phase 3 begins.
    _PHASE_THREE_STALL_INCREMENT = 10

    def __init__(
        self,
        seed: int,
        secondary_failure_probability: float = 0.05,
    ) -> None:
        """Initialize failure injector with a seed.

        Args:
            seed: Seed for the injector's private random number generator.
            secondary_failure_probability: Per-call threshold compared against
                a uniform draw in ``maybe_inject_secondary``.
        """
        self._random = Random(seed)
        self._secondary_failure_probability = secondary_failure_probability

    def reset(self, seed: int) -> None:
        """Reset the injector to a new seed.

        Only the random number generator is replaced; the configured
        secondary-failure probability is left as it was.
        """
        self._random = Random(seed)

    def configure_difficulty(self, secondary_failure_probability: float) -> None:
        """Apply adaptive curriculum secondary-failure probability."""
        self._secondary_failure_probability = secondary_failure_probability

    def maybe_inject_secondary(self, step: int) -> FailureSignal | None:
        """Return a secondary failure signal with small probability.

        Args:
            step: Current episode step.

        Returns:
            ``None`` during the warm-up steps or when the random draw does not
            fall below the configured probability; otherwise a
            ``FailureSignal`` with a randomly chosen node id and severity.
        """
        if step < self._SECONDARY_WARMUP_STEPS:
            return None

        if self._random.random() < self._secondary_failure_probability:
            # Draw order (node id, then severity) is kept fixed so that a
            # given seed keeps producing the same sequence of signals.
            node_id = self._random.randint(
                self._SECONDARY_NODE_ID_MIN,
                self._SECONDARY_NODE_ID_MAX,
            )
            severity = self._random.choice(self._SECONDARY_SEVERITIES)
            return FailureSignal(
                node_id=node_id,
                severity=severity,
                cause=self._SECONDARY_CAUSE,
            )
        return None

    def inject_cascade_phase(
        self,
        cluster: "ClusterStateMachine",
        step: int,
    ) -> int:
        """Drive deterministic cascade phase transitions by step boundaries.

        The cluster is mutated in place:

        * steps up to 20: the failing node is marked ``"failed"``, code 79 is
          recorded once in its ``xid_errors`` and the job is ``"stalled"``;
        * step 21 only: throughput is set to 55% of the target, the job is
          ``"running"`` again and the node becomes ``"degraded"``;
        * step 51 only: the job is ``"stalled"``, ``stalled_steps`` grows by
          10 and ``divergent_rank_id`` is set to the failing rank.

        Args:
            cluster: Cluster state to mutate.
            step: Current episode step.

        Returns:
            The phase number: 1 for steps up to 20, 2 for steps 21 through 50,
            3 otherwise.
        """
        # NOTE(review): reads the cluster's private ``_scenario`` attribute;
        # kept as-is because no public accessor is visible here.
        scenario = cluster._scenario
        failing_node = cluster.nodes[scenario.failing_node_id]
        training = cluster.training

        phase_two_first_step = self._PHASE_ONE_LAST_STEP + 1
        phase_three_first_step = self._PHASE_TWO_LAST_STEP + 1

        if step <= self._PHASE_ONE_LAST_STEP:
            failing_node.health_status = "failed"
            # Guard keeps repeated phase-1 calls from duplicating the code.
            if self._CASCADE_XID_CODE not in failing_node.xid_errors:
                failing_node.xid_errors.append(self._CASCADE_XID_CODE)
            training.job_status = "stalled"
            return 1

        if step == phase_two_first_step:
            # One-time transition applied on the first phase-2 step only.
            training.throughput_tokens_per_sec = (
                training.target_throughput * self._PHASE_TWO_THROUGHPUT_FACTOR
            )
            training.job_status = "running"
            failing_node.health_status = "degraded"

        if phase_two_first_step <= step <= self._PHASE_TWO_LAST_STEP:
            return 2

        if step == phase_three_first_step:
            # One-time transition applied on the first phase-3 step only.
            training.job_status = "stalled"
            training.stalled_steps += self._PHASE_THREE_STALL_INCREMENT
            cluster.divergent_rank_id = scenario.failing_rank_id

        # NOTE(review): indentation was reconstructed; phase 3 is assumed to be
        # reported for every step past phase 2, matching the ``-> int`` hint.
        return 3
|