# File: nervousystem-env/simulation/challenge_generator.py
# Committer (from page header): vx7sh
# Commit edc6488: feat(env): add curriculum, challenge generation, coalition, and black-swan mechanics
from __future__ import annotations
from dataclasses import dataclass
from random import Random
from typing import Any, Literal
from app.config import NUM_NODES, TARGET_THROUGHPUT, ScenarioConfig
@dataclass
class ChallengeSpec:
    """
    Describe one generated challenge scenario.

    The first four fields are required and identify the challenge; the
    boolean ``has_*`` flags all default to ``False`` and mark which failure
    traits are layered onto the base task. ``description`` defaults to an
    empty string and ``solvability_score`` to ``1.0``.
    """

    # --- Identity (required, positional order matters to callers) ---
    challenge_id: str
    base_task_id: str
    seed: int
    difficulty_multiplier: float

    # --- Failure-trait flags (all off unless explicitly enabled) ---
    has_red_herring_node: bool = False
    has_intermittent_failure: bool = False
    has_version_drift: bool = False
    has_correlated_rack_failure: bool = False
    has_delayed_symptom: bool = False

    # --- Human-readable summary and solvability estimate ---
    description: str = ""
    solvability_score: float = 1.0
class ChallengeGenerator:
    """
    Generates novel failure compositions for targeted agent training.

    Every spec returned by ``generate`` is also recorded on the instance, so
    ``get_generated`` and the duplicate check in ``validate_solvability`` see
    all specs this generator has produced.
    """

    CHALLENGE_TYPES = [
        "oom_with_red_herring",
        "congestion_with_version_drift",
        "delayed_desync",
        "intermittent_oom",
        "multi_rack_congestion",
    ]

    # Base task each challenge type is layered on top of. Types missing from
    # this table fall back to "easy" (see ``_base_task_for``).
    _BASE_TASKS = {
        "oom_with_red_herring": "easy",
        "congestion_with_version_drift": "medium",
        "delayed_desync": "hard",
        "intermittent_oom": "easy",
        "multi_rack_congestion": "medium",
    }

    # Task IDs that ``validate_solvability`` accepts as a base task.
    _VALID_BASE_TASKS = ("easy", "medium", "hard", "cascade")

    def __init__(self, seed: int = 42) -> None:
        """Create a generator whose random picks are driven by *seed*."""
        self._rng = Random(seed)
        self._generated: list[ChallengeSpec] = []

    def generate(
        self,
        challenge_type: str | None = None,
        difficulty_multiplier: float = 1.0,
        seed: int | None = None,
    ) -> ChallengeSpec:
        """
        Generate a challenge spec and record it on this generator.

        Args:
            challenge_type: One of ``CHALLENGE_TYPES``. If ``None``, a type is
                picked at random. A value outside ``CHALLENGE_TYPES`` is
                accepted as-is: no failure flag is set, the description stays
                empty and the base task falls back to ``"easy"``.
            difficulty_multiplier: Stored on the spec and used to choose the
                solvability tier: above 2.5 -> 0.3, from 1.5 up to and
                including 2.5 -> 0.6, otherwise 0.9.
            seed: Seed stored on the spec. If ``None``, an integer in
                ``[0, 9999]`` is drawn from the generator's own RNG.

        Returns:
            The newly created ``ChallengeSpec``. Its ``challenge_id`` is
            ``"<type>_<seed>_<multiplier to one decimal place>"``.
        """
        # Draw the seed first, then the type: this order fixes how the
        # internal RNG is consumed, so it must not be swapped.
        effective_seed = seed if seed is not None else self._rng.randint(0, 9999)
        if challenge_type is None:
            challenge_type = self._rng.choice(self.CHALLENGE_TYPES)

        spec = ChallengeSpec(
            challenge_id=f"{challenge_type}_{effective_seed}_{difficulty_multiplier:.1f}",
            base_task_id=self._base_task_for(challenge_type),
            seed=effective_seed,
            difficulty_multiplier=difficulty_multiplier,
        )

        # Switch on the failure traits that define each challenge type.
        if challenge_type == "oom_with_red_herring":
            spec.has_red_herring_node = True
            spec.has_correlated_rack_failure = True
            spec.description = (
                "OOM on one rank with a correlated rack neighbor "
                "showing degraded symptoms. Agent must distinguish "
                "primary failure from rack-blast-radius noise."
            )
        elif challenge_type == "congestion_with_version_drift":
            spec.has_version_drift = True
            spec.description = (
                "Spine congestion AND wrong NCCL version loaded. "
                "Fixing topology alone does not fully restore throughput. "
                "Agent must identify and fix both issues."
            )
        elif challenge_type == "delayed_desync":
            spec.has_delayed_symptom = True
            spec.description = (
                "Desync failure injected at step 5. Agent sees healthy "
                "cluster initially. Must recognize delayed emergence and "
                "not waste investigation actions early."
            )
        elif challenge_type == "intermittent_oom":
            spec.has_intermittent_failure = True
            spec.description = (
                "OOM clears and re-triggers every 8 steps. Agent must "
                "catch the failure in the active window. "
                "Diagnosis during clear window yields no signal."
            )
        elif challenge_type == "multi_rack_congestion":
            spec.has_correlated_rack_failure = True
            spec.description = (
                "Two spine switches congested simultaneously. "
                "Single topo_reorder insufficient — agent must "
                "call topo_reorder twice with different affinities."
            )

        # Coarse solvability tiers keyed on the difficulty multiplier.
        if difficulty_multiplier > 2.5:
            spec.solvability_score = 0.3
        elif difficulty_multiplier >= 1.5:
            spec.solvability_score = 0.6
        else:
            spec.solvability_score = 0.9

        self._generated.append(spec)
        return spec

    def _base_task_for(self, challenge_type: str) -> str:
        """Map challenge type to base task; unknown types map to ``"easy"``."""
        return self._BASE_TASKS.get(challenge_type, "easy")

    def to_scenario_config(self, spec: ChallengeSpec) -> ScenarioConfig:
        """
        Convert a ChallengeSpec into a ScenarioConfig that the
        environment can use directly.

        The base scenario is built from ``spec.base_task_id`` and
        ``spec.seed``; each enabled flag then contributes field overrides.
        ``has_red_herring_node`` and ``difficulty_multiplier`` do not produce
        any override here. When no override applies, the base scenario object
        itself is returned rather than a copy.
        """
        # Function-scope imports are kept from the original.
        # NOTE(review): presumably this avoids an import cycle with
        # app.config — confirm before hoisting to module level.
        from dataclasses import replace

        from app.config import build_scenario

        base = build_scenario(spec.base_task_id, spec.seed)
        overrides: dict[str, Any] = {}

        if spec.has_version_drift:
            overrides["nccl_version_loaded"] = "2.21.5"
            overrides["nccl_version_expected"] = "2.27.0"
            overrides["ld_library_path_corrupted"] = True

        if spec.has_correlated_rack_failure:
            # The correlated node is the next node ID after the failing one,
            # wrapping around at NUM_NODES.
            neighbour = (base.failing_node_id + 1) % NUM_NODES
            overrides["correlated_fault_nodes"] = [neighbour]

        if spec.has_delayed_symptom:
            overrides["cascade_phase"] = 0

        if spec.has_intermittent_failure:
            overrides["congestion_bandwidth_pct"] = 0.45

        return replace(base, **overrides) if overrides else base

    def validate_solvability(self, spec: ChallengeSpec) -> dict[str, Any]:
        """
        Validate that a generated challenge is actually solvable.

        Three checks are run: the solvability score must be above 0.2, the
        base task must be one of easy/medium/hard/cascade, and the
        ``challenge_id`` must not appear more than once among the specs this
        generator has recorded.

        Returns:
            A dict with ``"valid"`` (``True`` when no issue was found),
            ``"solvability_score"`` and ``"issues"`` (list of messages).
        """
        issues: list[str] = []

        if spec.solvability_score <= 0.2:
            issues.append("solvability_score too low — may be unsolvable")

        if spec.base_task_id not in self._VALID_BASE_TASKS:
            issues.append(f"invalid base_task_id: {spec.base_task_id}")

        # One occurrence is the spec itself; more than one means a clash.
        occurrences = sum(
            1 for recorded in self._generated
            if recorded.challenge_id == spec.challenge_id
        )
        if occurrences > 1:
            issues.append("duplicate challenge_id")

        return {
            "valid": not issues,
            "solvability_score": spec.solvability_score,
            "issues": issues,
        }

    def get_generated(self) -> list[dict[str, Any]]:
        """Return list of generated challenge specs as dicts."""
        return [
            {
                "challenge_id": s.challenge_id,
                "base_task_id": s.base_task_id,
                "seed": s.seed,
                "difficulty_multiplier": s.difficulty_multiplier,
                "description": s.description,
                "solvability_score": s.solvability_score,
                "flags": {
                    "red_herring": s.has_red_herring_node,
                    "intermittent": s.has_intermittent_failure,
                    "version_drift": s.has_version_drift,
                    "correlated_rack": s.has_correlated_rack_failure,
                    "delayed_symptom": s.has_delayed_symptom,
                },
            }
            for s in self._generated
        ]