Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from random import Random | |
| from typing import Any, Literal | |
| from app.config import NUM_NODES, TARGET_THROUGHPUT, ScenarioConfig | |
@dataclass
class ChallengeSpec:
    """Specification for a generated challenge scenario.

    The spec is a plain mutable record: ``ChallengeGenerator.generate``
    constructs it with the four required fields and then switches on the
    ``has_*`` flags and fills ``description`` / ``solvability_score``.
    It must therefore stay non-frozen.
    """

    # Identifier string for this challenge (the generator checks it for
    # duplicates across its history).
    challenge_id: str
    # Name of the base task the challenge is layered on.
    base_task_id: str
    # Seed recorded for the challenge; also used when building the scenario.
    seed: int
    # Difficulty knob supplied by the caller of the generator.
    difficulty_multiplier: float

    # Failure-composition flags; all are off unless the generator enables them.
    has_red_herring_node: bool = False
    has_intermittent_failure: bool = False
    has_version_drift: bool = False
    has_correlated_rack_failure: bool = False
    has_delayed_symptom: bool = False

    # Human-readable summary of the challenge.
    description: str = ""
    # Solvability estimate; defaults to 1.0 until the generator overwrites it.
    solvability_score: float = 1.0
class ChallengeGenerator:
    """
    Generates novel failure compositions for targeted agent training.

    Every spec produced by :meth:`generate` is appended to an internal
    history, which :meth:`validate_solvability` consults for duplicate ids
    and :meth:`get_generated` exposes as plain dicts.
    """

    CHALLENGE_TYPES = [
        "oom_with_red_herring",
        "congestion_with_version_drift",
        "delayed_desync",
        "intermittent_oom",
        "multi_rack_congestion",
    ]

    # Base task each known challenge type is layered on.
    _BASE_TASKS = {
        "oom_with_red_herring": "easy",
        "congestion_with_version_drift": "medium",
        "delayed_desync": "hard",
        "intermittent_oom": "easy",
        "multi_rack_congestion": "medium",
    }

    # base_task_id values accepted by validate_solvability.
    _KNOWN_BASE_TASKS = ("easy", "medium", "hard", "cascade")

    # challenge type -> (spec flags to switch on, description text).
    _PROFILES = {
        "oom_with_red_herring": (
            ("has_red_herring_node", "has_correlated_rack_failure"),
            "OOM on one rank with a correlated rack neighbor "
            "showing degraded symptoms. Agent must distinguish "
            "primary failure from rack-blast-radius noise.",
        ),
        "congestion_with_version_drift": (
            ("has_version_drift",),
            "Spine congestion AND wrong NCCL version loaded. "
            "Fixing topology alone does not fully restore throughput. "
            "Agent must identify and fix both issues.",
        ),
        "delayed_desync": (
            ("has_delayed_symptom",),
            "Desync failure injected at step 5. Agent sees healthy "
            "cluster initially. Must recognize delayed emergence and "
            "not waste investigation actions early.",
        ),
        "intermittent_oom": (
            ("has_intermittent_failure",),
            "OOM clears and re-triggers every 8 steps. Agent must "
            "catch the failure in the active window. "
            "Diagnosis during clear window yields no signal.",
        ),
        "multi_rack_congestion": (
            ("has_correlated_rack_failure",),
            "Two spine switches congested simultaneously. "
            "Single topo_reorder insufficient — agent must "
            "call topo_reorder twice with different affinities.",
        ),
    }

    def __init__(self, seed: int = 42) -> None:
        """Create a generator whose random picks are driven by ``seed``."""
        self._rng = Random(seed)
        self._generated: list[ChallengeSpec] = []

    def generate(
        self,
        challenge_type: str | None = None,
        difficulty_multiplier: float = 1.0,
        seed: int | None = None,
    ) -> ChallengeSpec:
        """
        Generate a challenge spec and record it in the generator's history.

        Args:
            challenge_type: One of ``CHALLENGE_TYPES``. If None, pick randomly.
                An unrecognised value is accepted: the spec then has no flags
                set, an empty description, and the "easy" base task.
            difficulty_multiplier: Drives ``solvability_score``: above 2.5
                gives 0.3, from 1.5 up to 2.5 gives 0.6, anything else 0.9.
            seed: Seed stored on the spec. If None, one is drawn from the
                generator's own RNG in the range 0..9999.

        Returns:
            The newly created spec.
        """
        # Draw order matters for reproducibility: seed first, then the type.
        effective_seed = seed if seed is not None else self._rng.randint(0, 9999)
        if challenge_type is None:
            challenge_type = self._rng.choice(self.CHALLENGE_TYPES)

        spec = ChallengeSpec(
            challenge_id=f"{challenge_type}_{effective_seed}_{difficulty_multiplier:.1f}",
            base_task_id=self._base_task_for(challenge_type),
            seed=effective_seed,
            difficulty_multiplier=difficulty_multiplier,
        )

        profile = self._PROFILES.get(challenge_type)
        if profile is not None:
            flag_names, description = profile
            for flag_name in flag_names:
                setattr(spec, flag_name, True)
            spec.description = description

        if difficulty_multiplier > 2.5:
            spec.solvability_score = 0.3
        elif difficulty_multiplier >= 1.5:
            spec.solvability_score = 0.6
        else:
            spec.solvability_score = 0.9

        self._generated.append(spec)
        return spec

    def _base_task_for(self, challenge_type: str) -> str:
        """Map challenge type to base task; unknown types fall back to "easy"."""
        return self._BASE_TASKS.get(challenge_type, "easy")

    def to_scenario_config(self, spec: ChallengeSpec) -> ScenarioConfig:
        """
        Convert a ChallengeSpec into a ScenarioConfig that the
        environment can use directly.

        The base scenario is built from ``spec.base_task_id`` and
        ``spec.seed``; each enabled flag then contributes field overrides.
        ``has_red_herring_node`` contributes no override of its own. When no
        override applies, the base scenario is returned as-is.
        """
        # Imported locally, as in the original code
        # (NOTE(review): presumably to avoid an import cycle — confirm).
        from dataclasses import replace

        from app.config import build_scenario

        base = build_scenario(spec.base_task_id, spec.seed)
        overrides: dict[str, Any] = {}

        if spec.has_version_drift:
            overrides["nccl_version_loaded"] = "2.21.5"
            overrides["nccl_version_expected"] = "2.27.0"
            overrides["ld_library_path_corrupted"] = True

        if spec.has_correlated_rack_failure:
            # Deterministic choice: the node id right after the failing one,
            # wrapping around at NUM_NODES.
            neighbour = (base.failing_node_id + 1) % NUM_NODES
            overrides["correlated_fault_nodes"] = [neighbour]

        if spec.has_delayed_symptom:
            # NOTE(review): phase 0 presumably means "not yet started" — confirm.
            overrides["cascade_phase"] = 0

        if spec.has_intermittent_failure:
            # NOTE(review): an intermittent OOM overrides a congestion field;
            # verify this is the intended knob.
            overrides["congestion_bandwidth_pct"] = 0.45

        if overrides:
            return replace(base, **overrides)
        return base

    def validate_solvability(self, spec: ChallengeSpec) -> dict[str, Any]:
        """
        Validate that a generated challenge is actually solvable.

        Returns a dict with ``valid`` (True when no issue was found),
        ``solvability_score`` (copied from the spec) and ``issues`` (list of
        messages). Checks: score not above 0.2, base task id is a known one,
        and the spec's id does not occur more than once in the history.
        """
        issues = []

        if spec.solvability_score <= 0.2:
            issues.append("solvability_score too low — may be unsolvable")

        if spec.base_task_id not in self._KNOWN_BASE_TASKS:
            issues.append(f"invalid base_task_id: {spec.base_task_id}")

        occurrences = sum(
            1 for past in self._generated if past.challenge_id == spec.challenge_id
        )
        if occurrences > 1:
            issues.append("duplicate challenge_id")

        return {
            "valid": not issues,
            "solvability_score": spec.solvability_score,
            "issues": issues,
        }

    def get_generated(self) -> list[dict[str, Any]]:
        """Return list of generated challenge specs as dicts, oldest first."""
        return [
            {
                "challenge_id": s.challenge_id,
                "base_task_id": s.base_task_id,
                "seed": s.seed,
                "difficulty_multiplier": s.difficulty_multiplier,
                "description": s.description,
                "solvability_score": s.solvability_score,
                "flags": {
                    "red_herring": s.has_red_herring_node,
                    "intermittent": s.has_intermittent_failure,
                    "version_drift": s.has_version_drift,
                    "correlated_rack": s.has_correlated_rack_failure,
                    "delayed_symptom": s.has_delayed_symptom,
                },
            }
            for s in self._generated
        ]