# Author: vx7sh
# Commit 3928ed0:
# feat(curriculum): adaptive difficulty for telemetry, masking, and secondary failures
"""Failure injection utilities for cluster simulation."""
from __future__ import annotations
from typing import TYPE_CHECKING
from random import Random
from simulation.cluster import FailureSignal
if TYPE_CHECKING:
from simulation.cluster import ClusterStateMachine
class FailureInjector:
    """Inject rare secondary failures and drive the scripted failure cascade.

    All randomness comes from a private ``random.Random`` instance seeded by
    the caller, so the same seed and the same sequence of calls reproduce
    the same draws.
    """

    # Steps below this value never yield a secondary failure and consume no
    # random draws.
    _SECONDARY_WARMUP_STEPS = 10
    # Severity pool for secondary failures; "minor" appears twice so it is
    # drawn twice as often as "major".
    _SECONDARY_SEVERITIES = ("minor", "minor", "major")
    _SECONDARY_CAUSE = "secondary_hardware_fault"

    # Cascade schedule used by inject_cascade_phase:
    #   phase 1: step <= 20, phase 2: 21 <= step <= 50, phase 3: otherwise.
    _PHASE_ONE_LAST_STEP = 20
    _PHASE_TWO_LAST_STEP = 50
    # XID error code recorded on the failing node during phase 1.
    _PHASE_ONE_XID = 79
    # Fraction of the target throughput applied when phase 2 begins.
    _DEGRADED_THROUGHPUT_FRACTION = 0.55
    # Amount added to the stalled-step counter when phase 3 begins.
    _PHASE_THREE_STALL_PENALTY = 10

    def __init__(
        self,
        seed: int,
        secondary_failure_probability: float = 0.05,
        num_nodes: int = 8,
    ) -> None:
        """Initialize the injector.

        Args:
            seed: Seed for the private random number generator.
            secondary_failure_probability: Per-call chance that
                ``maybe_inject_secondary`` emits a failure once the warm-up
                steps have passed. Values ``<= 0`` never inject and values
                ``>= 1`` always inject.
            num_nodes: Number of nodes a secondary failure may target; node
                ids are drawn from ``0 .. num_nodes - 1``. The default of 8
                keeps the historical ``0 .. 7`` range.

        Raises:
            ValueError: If ``num_nodes`` is smaller than 1.
        """
        if num_nodes < 1:
            raise ValueError(f"num_nodes must be >= 1, got {num_nodes}")
        self._random = Random(seed)
        self._secondary_failure_probability = secondary_failure_probability
        self._num_nodes = num_nodes

    def reset(self, seed: int) -> None:
        """Replace the random number generator with one seeded by ``seed``.

        The configured secondary-failure probability and node count are
        left as they are.
        """
        self._random = Random(seed)

    def configure_difficulty(self, secondary_failure_probability: float) -> None:
        """Apply the adaptive-curriculum secondary-failure probability.

        Args:
            secondary_failure_probability: New per-call injection chance
                used by ``maybe_inject_secondary``.
        """
        self._secondary_failure_probability = secondary_failure_probability

    def maybe_inject_secondary(self, step: int) -> FailureSignal | None:
        """Return a secondary failure signal with small probability.

        Args:
            step: Current episode step. Steps below 10 never inject.

        Returns:
            A ``FailureSignal`` for a randomly chosen node, or ``None`` when
            no failure is injected on this call.
        """
        if step < self._SECONDARY_WARMUP_STEPS:
            return None
        if self._random.random() >= self._secondary_failure_probability:
            return None

        # Draw order (node id first, then severity) is kept fixed so that a
        # given seed keeps producing the same signals.
        node_id = self._random.randint(0, self._num_nodes - 1)
        severity = self._random.choice(self._SECONDARY_SEVERITIES)
        return FailureSignal(
            node_id=node_id,
            severity=severity,
            cause=self._SECONDARY_CAUSE,
        )

    def inject_cascade_phase(
        self,
        cluster: ClusterStateMachine,
        step: int,
    ) -> int:
        """Drive deterministic cascade phase transitions by step boundaries.

        Mutates ``cluster`` in place according to the step:

        * ``step <= 20`` (phase 1): on every call the failing node is marked
          ``"failed"``, XID 79 is recorded once, and the job is ``"stalled"``.
        * ``21 <= step <= 50`` (phase 2): only at step 21 the throughput is
          set to 55% of the target, the job resumes ``"running"`` and the
          node becomes ``"degraded"``; later steps leave the cluster alone.
        * otherwise (phase 3): only at step 51 the job is ``"stalled"``
          again, 10 stalled steps are added and the failing rank is recorded
          as divergent.

        Args:
            cluster: Cluster state machine to mutate.
            step: Current episode step.

        Returns:
            The cascade phase number (1, 2 or 3) that ``step`` falls in.
        """
        # Both scenario fields are read up front, whatever the phase.
        scenario = cluster._scenario
        failing_node_id = scenario.failing_node_id
        failing_rank_id = scenario.failing_rank_id
        failing_node = cluster.nodes[failing_node_id]

        phase_two_first_step = self._PHASE_ONE_LAST_STEP + 1
        phase_three_first_step = self._PHASE_TWO_LAST_STEP + 1

        if step <= self._PHASE_ONE_LAST_STEP:
            failing_node.health_status = "failed"
            # Record the XID only once even though phase 1 runs every step.
            if self._PHASE_ONE_XID not in failing_node.xid_errors:
                failing_node.xid_errors.append(self._PHASE_ONE_XID)
            cluster.training.job_status = "stalled"
            return 1

        if phase_two_first_step <= step <= self._PHASE_TWO_LAST_STEP:
            # Entry effects fire only on the first step of the phase, so
            # changes made to the cluster afterwards are not overwritten.
            if step == phase_two_first_step:
                cluster.training.throughput_tokens_per_sec = (
                    cluster.training.target_throughput
                    * self._DEGRADED_THROUGHPUT_FRACTION
                )
                cluster.training.job_status = "running"
                failing_node.health_status = "degraded"
            return 2

        if step == phase_three_first_step:
            cluster.training.job_status = "stalled"
            cluster.training.stalled_steps += self._PHASE_THREE_STALL_PENALTY
            cluster.divergent_rank_id = failing_rank_id
        return 3