from __future__ import annotations from dataclasses import dataclass from random import Random from typing import Any, Literal from app.config import NUM_NODES, TARGET_THROUGHPUT, ScenarioConfig @dataclass class ChallengeSpec: """Specification for a generated challenge scenario.""" challenge_id: str base_task_id: str seed: int difficulty_multiplier: float has_red_herring_node: bool = False has_intermittent_failure: bool = False has_version_drift: bool = False has_correlated_rack_failure: bool = False has_delayed_symptom: bool = False description: str = "" solvability_score: float = 1.0 class ChallengeGenerator: """ Generates novel failure compositions for targeted agent training. """ CHALLENGE_TYPES = [ "oom_with_red_herring", "congestion_with_version_drift", "delayed_desync", "intermittent_oom", "multi_rack_congestion", ] def __init__(self, seed: int = 42) -> None: self._rng = Random(seed) self._generated: list[ChallengeSpec] = [] def generate( self, challenge_type: str | None = None, difficulty_multiplier: float = 1.0, seed: int | None = None, ) -> ChallengeSpec: """ Generate a challenge spec. If challenge_type is None, pick randomly. """ effective_seed = seed if seed is not None else self._rng.randint(0, 9999) if challenge_type is None: challenge_type = self._rng.choice(self.CHALLENGE_TYPES) rng = Random(effective_seed) _ = rng.random() spec = ChallengeSpec( challenge_id=f"{challenge_type}_{effective_seed}_{difficulty_multiplier:.1f}", base_task_id=self._base_task_for(challenge_type), seed=effective_seed, difficulty_multiplier=difficulty_multiplier, ) if challenge_type == "oom_with_red_herring": spec.has_red_herring_node = True spec.has_correlated_rack_failure = True spec.description = ( "OOM on one rank with a correlated rack neighbor " "showing degraded symptoms. Agent must distinguish " "primary failure from rack-blast-radius noise." ) elif challenge_type == "congestion_with_version_drift": spec.has_version_drift = True spec.description = ( "Spine congestion AND wrong NCCL version loaded. " "Fixing topology alone does not fully restore throughput. " "Agent must identify and fix both issues." ) elif challenge_type == "delayed_desync": spec.has_delayed_symptom = True spec.description = ( "Desync failure injected at step 5. Agent sees healthy " "cluster initially. Must recognize delayed emergence and " "not waste investigation actions early." ) elif challenge_type == "intermittent_oom": spec.has_intermittent_failure = True spec.description = ( "OOM clears and re-triggers every 8 steps. Agent must " "catch the failure in the active window. " "Diagnosis during clear window yields no signal." ) elif challenge_type == "multi_rack_congestion": spec.has_correlated_rack_failure = True spec.description = ( "Two spine switches congested simultaneously. " "Single topo_reorder insufficient — agent must " "call topo_reorder twice with different affinities." ) if difficulty_multiplier > 2.5: spec.solvability_score = 0.3 elif difficulty_multiplier >= 1.5: spec.solvability_score = 0.6 else: spec.solvability_score = 0.9 self._generated.append(spec) return spec def _base_task_for(self, challenge_type: str) -> str: """Map challenge type to base task.""" mapping = { "oom_with_red_herring": "easy", "congestion_with_version_drift": "medium", "delayed_desync": "hard", "intermittent_oom": "easy", "multi_rack_congestion": "medium", } return mapping.get(challenge_type, "easy") def to_scenario_config(self, spec: ChallengeSpec) -> ScenarioConfig: """ Convert a ChallengeSpec into a ScenarioConfig that the environment can use directly. """ from dataclasses import replace from app.config import build_scenario base = build_scenario(spec.base_task_id, spec.seed) overrides: dict[str, Any] = {} if spec.has_version_drift: overrides["nccl_version_loaded"] = "2.21.5" overrides["nccl_version_expected"] = "2.27.0" overrides["ld_library_path_corrupted"] = True if spec.has_correlated_rack_failure: rng = Random(spec.seed) _ = rng.random() other = (base.failing_node_id + 1) % NUM_NODES overrides["correlated_fault_nodes"] = [other] if spec.has_delayed_symptom: overrides["cascade_phase"] = 0 if spec.has_intermittent_failure: overrides["congestion_bandwidth_pct"] = 0.45 if overrides: return replace(base, **overrides) return base def validate_solvability(self, spec: ChallengeSpec) -> dict[str, Any]: """ Validate that a generated challenge is actually solvable. """ issues = [] if spec.solvability_score <= 0.2: issues.append("solvability_score too low — may be unsolvable") if spec.base_task_id not in ["easy", "medium", "hard", "cascade"]: issues.append(f"invalid base_task_id: {spec.base_task_id}") ids = [s.challenge_id for s in self._generated] if ids.count(spec.challenge_id) > 1: issues.append("duplicate challenge_id") return { "valid": len(issues) == 0, "solvability_score": spec.solvability_score, "issues": issues, } def get_generated(self) -> list[dict[str, Any]]: """Return list of generated challenge specs as dicts.""" return [ { "challenge_id": s.challenge_id, "base_task_id": s.base_task_id, "seed": s.seed, "difficulty_multiplier": s.difficulty_multiplier, "description": s.description, "solvability_score": s.solvability_score, "flags": { "red_herring": s.has_red_herring_node, "intermittent": s.has_intermittent_failure, "version_drift": s.has_version_drift, "correlated_rack": s.has_correlated_rack_failure, "delayed_symptom": s.has_delayed_symptom, }, } for s in self._generated ]