# vx7sh's picture
# feat(env): add curriculum, challenge generation, coalition, and black-swan mechanics
# edc6488
"""Fleet AI Supervisor-Worker coordination layer."""
from __future__ import annotations
import math
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from app.models import SREAction
if TYPE_CHECKING:
from simulation.cluster import ClusterStateMachine
from simulation.telemetry import TelemetryStream
class WorkerType(str, Enum):
    """Identifiers of the specialised worker agents a supervisor can address.

    Members are also ``str`` instances, so they compare equal to their
    string values and can be built from them (``WorkerType("topo_agent")``).
    """

    # Log / flight-recorder inspection.
    LOG_INSPECTOR = "log_inspector"
    # Code patch application and verification.
    PATCH_AGENT = "patch_agent"
    # Topology reordering and bandwidth checks.
    TOPO_AGENT = "topo_agent"
    # NCCL version and init-log checks.
    VERSION_CHECKER = "version_checker"
class DelegationRequest(BaseModel):
    """Instruction sent by the supervisor to a single worker.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Which worker should handle the request.
    worker: WorkerType
    # Worker-specific action name (e.g. "apply_patch").
    action: str
    # Free-form arguments forwarded to the worker; defaults to an empty dict.
    parameters: dict[str, Any] = Field(default_factory=dict)
    # Supervisor's stated rationale; echoed in the delegation explanation.
    supervisor_reasoning: str = Field(default="")
class WorkerResult(BaseModel):
    """Outcome a worker reports back to the supervisor.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Worker that produced this result.
    worker: WorkerType
    # Action name that was requested.
    action: str
    # Whether the worker considers the action to have succeeded.
    success: bool
    # Action-specific payload (or an {"error": ...} dict on failure).
    output: dict[str, Any]
    # Worker's self-reported confidence in the result.
    confidence: float = Field(default=0.0)
    # Worker's self-reported uncertainty; the coordinator may overwrite it.
    uncertainty: float = Field(default=0.0)
    # Token count attributed to the action (taken from the request parameters).
    tokens_used: int = Field(default=0)
class DelegationResult(BaseModel):
    """Complete record of one supervisor -> worker delegation cycle.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # The request as issued by the supervisor.
    delegation: DelegationRequest
    # What the addressed worker returned.
    worker_result: WorkerResult
    # Reward assigned by the coordinator for this routing decision.
    coordination_reward: float
    # Human-readable summary of the delegation outcome.
    explanation: str
class CoalitionProposal(BaseModel):
    """Request for a joint action that needs agreement between two workers.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Worker names are plain strings here; they are parsed into WorkerType
    # values when the proposal is validated.
    proposing_worker: str
    supporting_worker: str
    # Name of the coalition action being proposed.
    action: str
    # Optional arguments for the joint action; defaults to an empty dict.
    parameters: dict[str, Any] = Field(default_factory=dict)
    # Optional justification supplied with the proposal.
    rationale: str = Field(default="")
class CoalitionResult(BaseModel):
    """Outcome of attempting a coalition (two-worker) action.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # The proposal this result answers.
    proposal: CoalitionProposal
    # True when the proposal passed all checks and was executed.
    agreement_reached: bool
    # Why the proposal was rejected; empty when agreement was reached.
    dissent_reason: str = Field(default="")
    # Payload produced by the joint action; defaults to an empty dict.
    joint_output: dict[str, Any] = Field(default_factory=dict)
    # Reward granted for the coalition action.
    coalition_reward: float = Field(default=0.0)
    # Whether applying the joint action's effect succeeded.
    execution_success: bool = Field(default=False)
class LogInspectorWorker:
    """Handles flight recorder and NCCL log inspection."""

    worker_type = WorkerType.LOG_INSPECTOR

    # Lower-case substrings that mark a visible log line as an error hit.
    _ERROR_PATTERNS = ("error", "timeout", "failed", "xid")

    def _reply(
        self,
        action: str,
        parameters: dict[str, Any],
        *,
        success: bool,
        output: dict[str, Any],
        confidence: float,
        uncertainty: float,
    ) -> WorkerResult:
        """Build a WorkerResult, taking ``tokens_used`` from ``token_count``."""
        return WorkerResult(
            worker=self.worker_type,
            action=action,
            success=success,
            output=output,
            confidence=confidence,
            uncertainty=uncertainty,
            tokens_used=int(parameters.get("token_count", 0)),
        )

    def _failure(
        self,
        action: str,
        parameters: dict[str, Any],
        message: str,
    ) -> WorkerResult:
        """Build a failed result carrying ``message`` under the "error" key."""
        return self._reply(
            action,
            parameters,
            success=False,
            output={"error": message},
            confidence=0.0,
            uncertainty=1.0,
        )

    def execute(
        self,
        action: str,
        parameters: dict[str, Any],
        cluster: "ClusterStateMachine",
        telemetry: "TelemetryStream",
    ) -> WorkerResult:
        """Execute a log inspection action.

        Supported actions:

        * ``inspect_flight_recorder`` - requires ``rank_id``; returns the
          cluster's flight recorder data for that rank.
        * ``query_nccl_subsystem`` - optional ``subsystem`` (default
          ``"watchdog"``) and ``time_window`` (default ``10``).
        * ``grep_errors`` - returns visible log lines containing one of the
          error patterns (case-insensitive).

        Any other action, a missing ``rank_id``, or a non-integer
        ``rank_id`` / ``time_window`` yields a failed result with an
        ``"error"`` message instead of raising.
        """
        if action == "inspect_flight_recorder":
            rank_id = parameters.get("rank_id")
            if rank_id is None:
                return self._failure(action, parameters, "rank_id parameter required")
            # Report malformed input the same way as missing input rather
            # than letting int() raise out of the worker.
            try:
                rank = int(rank_id)
            except (TypeError, ValueError):
                return self._failure(action, parameters, "rank_id must be an integer")
            payload = cluster._generate_flight_recorder_data(rank)
            return self._reply(
                action,
                parameters,
                success=True,
                output={"flight_recorder": payload},
                confidence=0.9,
                uncertainty=0.3,
            )

        if action == "query_nccl_subsystem":
            subsystem = str(parameters.get("subsystem", "watchdog"))
            try:
                time_window = int(parameters.get("time_window", 10))
            except (TypeError, ValueError):
                return self._failure(
                    action, parameters, "time_window must be an integer"
                )
            logs = telemetry.generate_nccl_subsystem_logs(cluster, subsystem, time_window)
            return self._reply(
                action,
                parameters,
                success=True,
                output={"nccl_logs": logs, "subsystem": subsystem},
                confidence=0.7,
                uncertainty=0.4,
            )

        if action == "grep_errors":
            matches = [
                line
                for line in telemetry.visible_logs()
                if any(pattern in line.lower() for pattern in self._ERROR_PATTERNS)
            ]
            return self._reply(
                action,
                parameters,
                success=True,
                output={"matches": matches},
                confidence=0.5,
                uncertainty=0.4,
            )

        return self._failure(action, parameters, "unknown action")
class PatchAgentWorker:
    """Worker that applies code patches and checks whether they took effect."""

    worker_type = WorkerType.PATCH_AGENT

    def _reply(
        self,
        action: str,
        parameters: dict[str, Any],
        *,
        success: bool,
        output: dict[str, Any],
        confidence: float,
        uncertainty: float,
    ) -> WorkerResult:
        """Build a WorkerResult, taking ``tokens_used`` from ``token_count``."""
        return WorkerResult(
            worker=self.worker_type,
            action=action,
            success=success,
            output=output,
            confidence=confidence,
            uncertainty=uncertainty,
            tokens_used=int(parameters.get("token_count", 0)),
        )

    def execute(
        self,
        action: str,
        parameters: dict[str, Any],
        cluster: "ClusterStateMachine",
        telemetry: "TelemetryStream",
    ) -> WorkerResult:
        """Run a patch-related action against ``cluster``.

        ``apply_patch`` needs ``file`` and ``fix_type`` and forwards them to
        the cluster as a ``patch_divergent_code`` action. ``verify_patch``
        reports whether the training job status is ``"recovered"``. Anything
        else produces a failed result. ``telemetry`` is not used.
        """
        if action == "apply_patch":
            target_file = parameters.get("file")
            fix_type = parameters.get("fix_type")
            if target_file is None or fix_type is None:
                return self._reply(
                    action,
                    parameters,
                    success=False,
                    output={"error": "file and fix_type required"},
                    confidence=0.0,
                    uncertainty=1.0,
                )
            file_name = str(target_file)
            outcome = cluster.apply_action(
                SREAction(
                    action_type="patch_divergent_code",
                    parameters={"file": file_name, "fix_type": str(fix_type)},
                )
            )
            # Confidence is high only when the patched file is the one the
            # scenario marks as divergent.
            on_target = file_name == cluster._scenario.divergent_file
            return self._reply(
                action,
                parameters,
                success=outcome.success,
                output=outcome.action_output,
                confidence=0.95 if on_target else 0.3,
                uncertainty=0.1 if on_target else 0.9,
            )

        if action == "verify_patch":
            job_status = cluster.training.job_status
            return self._reply(
                action,
                parameters,
                success=True,
                output={
                    "verified": job_status == "recovered",
                    "job_status": job_status,
                },
                confidence=1.0,
                uncertainty=0.4,
            )

        return self._reply(
            action,
            parameters,
            success=False,
            output={"error": "unknown action"},
            confidence=0.0,
            uncertainty=1.0,
        )
class TopoAgentWorker:
    """Worker for topology reordering and throughput/bandwidth checks."""

    worker_type = WorkerType.TOPO_AGENT

    def _reply(
        self,
        action: str,
        parameters: dict[str, Any],
        *,
        success: bool,
        output: dict[str, Any],
        confidence: float,
        uncertainty: float,
    ) -> WorkerResult:
        """Build a WorkerResult, taking ``tokens_used`` from ``token_count``."""
        return WorkerResult(
            worker=self.worker_type,
            action=action,
            success=success,
            output=output,
            confidence=confidence,
            uncertainty=uncertainty,
            tokens_used=int(parameters.get("token_count", 0)),
        )

    def execute(
        self,
        action: str,
        parameters: dict[str, Any],
        cluster: "ClusterStateMachine",
        telemetry: "TelemetryStream",
    ) -> WorkerResult:
        """Run a topology action against ``cluster``.

        ``reorder_topology`` forwards a ``topo_reorder`` action with the
        requested ``affinity`` (default ``"rack"``). ``check_bandwidth``
        compares current throughput to the target and flags a ratio below
        0.8 as degraded. Anything else produces a failed result.
        ``telemetry`` is not used.
        """
        if action == "reorder_topology":
            affinity = str(parameters.get("affinity", "rack"))
            outcome = cluster.apply_action(
                SREAction(
                    action_type="topo_reorder",
                    parameters={"affinity": affinity},
                )
            )
            rack_affinity = affinity == "rack"
            return self._reply(
                action,
                parameters,
                success=outcome.success,
                output=outcome.action_output,
                confidence=0.85 if rack_affinity else 0.4,
                uncertainty=0.15 if rack_affinity else 0.4,
            )

        if action == "check_bandwidth":
            training = cluster.training
            throughput = training.throughput_tokens_per_sec
            # The divisor is floored at 1.0 so a zero target cannot divide by zero.
            ratio = throughput / max(1.0, training.target_throughput)
            return self._reply(
                action,
                parameters,
                success=True,
                output={
                    "ratio": ratio,
                    "degraded": ratio < 0.8,
                    "throughput": throughput,
                },
                confidence=1.0,
                uncertainty=0.1,
            )

        return self._reply(
            action,
            parameters,
            success=False,
            output={"error": "unknown action"},
            confidence=0.0,
            uncertainty=1.0,
        )
class VersionCheckerWorker:
    """Worker that checks NCCL version state and init logs."""

    worker_type = WorkerType.VERSION_CHECKER

    def _reply(
        self,
        action: str,
        parameters: dict[str, Any],
        *,
        success: bool,
        output: dict[str, Any],
        confidence: float,
        uncertainty: float,
    ) -> WorkerResult:
        """Build a WorkerResult, taking ``tokens_used`` from ``token_count``."""
        return WorkerResult(
            worker=self.worker_type,
            action=action,
            success=success,
            output=output,
            confidence=confidence,
            uncertainty=uncertainty,
            tokens_used=int(parameters.get("token_count", 0)),
        )

    def execute(
        self,
        action: str,
        parameters: dict[str, Any],
        cluster: "ClusterStateMachine",
        telemetry: "TelemetryStream",
    ) -> WorkerResult:
        """Run a version-check action.

        ``check_nccl_version`` reports the scenario's loaded and expected
        NCCL versions, whether they differ, and the scenario's
        ``ld_library_path_corrupted`` flag. ``check_init_logs`` fetches the
        "init" subsystem logs and looks for "version mismatch"
        (case-insensitive). Anything else produces a failed result.
        """
        if action == "check_nccl_version":
            scenario = cluster._scenario
            loaded = scenario.nccl_version_loaded
            expected = scenario.nccl_version_expected
            return self._reply(
                action,
                parameters,
                success=True,
                output={
                    "loaded": loaded,
                    "expected": expected,
                    "mismatch": loaded != expected,
                    "ld_corrupted": scenario.ld_library_path_corrupted,
                },
                confidence=1.0,
                uncertainty=0.05,
            )

        if action == "check_init_logs":
            init_logs = telemetry.generate_nccl_subsystem_logs(cluster, "init", 5)
            mismatch_seen = False
            for entry in init_logs:
                if "version mismatch" in entry.lower():
                    mismatch_seen = True
                    break
            return self._reply(
                action,
                parameters,
                success=True,
                output={"version_mismatch_detected": mismatch_seen, "logs": init_logs},
                confidence=0.8,
                uncertainty=0.4,
            )

        return self._reply(
            action,
            parameters,
            success=False,
            output={"error": "unknown action"},
            confidence=0.0,
            uncertainty=1.0,
        )
class FleetCoordinator:
    """Routes supervisor delegations to workers and scores the coordination.

    The coordinator owns one instance of every worker type, keeps a log of
    delegations and coalition attempts, enforces a soft delegation budget
    (overruns reduce the coordination reward) and accumulates coalition
    rewards.
    """

    def __init__(
        self,
        cluster: "ClusterStateMachine",
        telemetry: "TelemetryStream",
    ) -> None:
        """Bind the coordinator to a cluster and telemetry stream."""
        self._cluster = cluster
        self._telemetry = telemetry
        self._workers = {
            WorkerType.LOG_INSPECTOR: LogInspectorWorker(),
            WorkerType.PATCH_AGENT: PatchAgentWorker(),
            WorkerType.TOPO_AGENT: TopoAgentWorker(),
            WorkerType.VERSION_CHECKER: VersionCheckerWorker(),
        }
        # Mixed log: DelegationResult objects for delegations, plain dicts
        # for coalition attempts.
        self._delegation_log: list[Any] = []
        self._delegation_budget: int = 10
        self._delegation_count: int = 0
        self._budget_penalty_per_overrun: float = 0.05
        self._coalition_reward_total: float = 0.0

    # Catalogue of coalition actions: who must propose and support, which
    # failure types they apply to, the reward on success, and which
    # recommendations from the supporting worker count as agreement.
    _COALITION_SPECS: dict[str, dict[str, Any]] = {
        "topology_version_fix": {
            "proposing": WorkerType.TOPO_AGENT,
            "supporting": WorkerType.VERSION_CHECKER,
            "valid_failures": {"cascade", "black_swan"},
            "coalition_reward": 0.25,
            "description": "Fix topology congestion and version mismatch jointly",
            "support_signals": {"check_nccl_version"},
        },
        "deep_patch_with_verification": {
            "proposing": WorkerType.PATCH_AGENT,
            "supporting": WorkerType.LOG_INSPECTOR,
            "valid_failures": {"desync", "cascade", "black_swan"},
            "coalition_reward": 0.20,
            "description": "Apply patch and verify via deep logs in one joint action",
            "support_signals": {"inspect_flight_recorder", "query_nccl_logs"},
        },
        "rack_aware_oom_restart": {
            "proposing": WorkerType.LOG_INSPECTOR,
            "supporting": WorkerType.TOPO_AGENT,
            "valid_failures": {"oom", "cascade", "black_swan"},
            "coalition_reward": 0.15,
            "description": "Non-destructive restart coordinated with topology awareness",
            "support_signals": {"check_bandwidth", "topo_reorder"},
        },
    }

    def delegate(self, request: DelegationRequest) -> DelegationResult:
        """Route delegation request to correct worker and score coordination.

        Every call consumes one unit of the delegation budget. Once the
        budget is exceeded, the coordination reward is reduced by
        ``overrun * penalty_per_overrun`` (never below 0.0) and the
        explanation gets an "OVER BUDGET" suffix. The result is appended to
        the delegation log.
        """
        self._delegation_count += 1
        worker = self._workers[request.worker]
        worker_result = worker.execute(
            action=request.action,
            parameters=request.parameters,
            cluster=self._cluster,
            telemetry=self._telemetry,
        )
        worker_result.uncertainty = self._adjust_uncertainty(
            request=request,
            worker_result=worker_result,
        )
        coordination_reward = self._score_coordination(request, worker_result)
        # NOTE: the explanation is built from the pre-penalty reward; only
        # the suffix below signals that a budget penalty was applied.
        explanation = self._explain(request, worker_result, coordination_reward)
        if self._delegation_count > self._delegation_budget:
            overrun = self._delegation_count - self._delegation_budget
            coordination_reward = max(
                0.0,
                coordination_reward - (overrun * self._budget_penalty_per_overrun),
            )
            explanation += (
                f" [OVER BUDGET: {overrun} excess delegations, penalty applied]"
            )
        result = DelegationResult(
            delegation=request,
            worker_result=worker_result,
            coordination_reward=coordination_reward,
            explanation=explanation,
        )
        self._delegation_log.append(result)
        return result

    def _expected_worker_for_failure(self, failure_type: str) -> WorkerType:
        """Map a failure type to the worker used for uncertainty calibration.

        Any failure type other than "oom", "congestion" or "desync" maps to
        the version checker.
        """
        if failure_type == "oom":
            return WorkerType.LOG_INSPECTOR
        if failure_type == "congestion":
            return WorkerType.TOPO_AGENT
        if failure_type == "desync":
            return WorkerType.PATCH_AGENT
        return WorkerType.VERSION_CHECKER

    def _validate_coalition_workers(
        self,
        proposal: CoalitionProposal,
        spec: dict[str, Any],
    ) -> str:
        """Check the proposal's worker names against the coalition spec.

        Returns an empty string when both workers are known, present in the
        fleet and match the spec; otherwise a short reason string.
        """
        try:
            proposing = WorkerType(proposal.proposing_worker)
            supporting = WorkerType(proposal.supporting_worker)
        except ValueError:
            return "unknown worker in coalition proposal"
        if proposing not in self._workers or supporting not in self._workers:
            return "worker not available in fleet"
        # The supporting worker is checked first so that the existing
        # message is kept whenever the supporting worker is wrong; a
        # proposing-only mismatch gets its own, accurate message.
        if supporting != spec["supporting"]:
            return "wrong supporting worker for this coalition"
        if proposing != spec["proposing"]:
            return "wrong proposing worker for this coalition"
        return ""

    def _execute_coalition_effect(
        self,
        action: str,
        cluster: "ClusterStateMachine",
    ) -> tuple[bool, dict[str, Any]]:
        """Apply the cluster-side effect of an agreed coalition action.

        Returns ``(success, joint_output)``. Unsupported actions leave the
        cluster untouched and return ``(False, {"error": ...})``.
        """
        if action == "topology_version_fix":
            cluster.training.throughput_tokens_per_sec *= 1.35
            cluster.training.job_status = "recovered"
            return True, {"fixed": ["congestion", "version_mismatch"]}
        if action == "deep_patch_with_verification":
            cluster.patch_stage = 3
            cluster.training.job_status = "recovered"
            cluster.training.stalled_steps = 0
            return True, {"patch_stage": 3, "verified": True}
        if action == "rack_aware_oom_restart":
            failing_node = cluster.nodes[cluster._scenario.failing_node_id]
            failing_node.health_status = "healthy"
            failing_node.xid_errors.clear()
            return True, {
                "restarted": cluster._scenario.failing_node_id,
                "destructive_penalty_waived": True,
            }
        return False, {"error": "unsupported coalition action"}

    def _reject_coalition(
        self,
        proposal: CoalitionProposal,
        dissent_reason: str,
        *,
        log: bool = True,
    ) -> CoalitionResult:
        """Build a rejected CoalitionResult and optionally log the rejection."""
        result = CoalitionResult(
            proposal=proposal,
            agreement_reached=False,
            dissent_reason=dissent_reason,
        )
        if log:
            self._delegation_log.append(
                {
                    "type": "coalition",
                    "action": proposal.action,
                    "agreement_reached": False,
                    "dissent_reason": dissent_reason,
                    "coalition_reward": 0.0,
                }
            )
        return result

    def propose_coalition(
        self,
        proposal: CoalitionProposal,
        cluster: "ClusterStateMachine",
        telemetry: "TelemetryStream",
    ) -> CoalitionResult:
        """Attempt a coalition action requiring 2-worker agreement.

        The proposal is rejected (in this order) when the action is unknown,
        the workers do not match the spec, the current failure type is not
        covered by the coalition, or the supporting worker's consensus
        recommendation is not one of the spec's support signals. Otherwise
        the joint effect is applied to ``cluster`` and the coalition reward
        is granted if that succeeded. All outcomes except the
        unknown-action rejection are appended to the delegation log.
        ``telemetry`` is currently unused.
        """
        _ = telemetry
        spec = self._COALITION_SPECS.get(proposal.action)
        if spec is None:
            # Unknown actions are rejected without a log entry.
            return self._reject_coalition(
                proposal, "unknown coalition action", log=False
            )

        worker_validation_error = self._validate_coalition_workers(proposal, spec)
        if worker_validation_error:
            return self._reject_coalition(proposal, worker_validation_error)

        failure_type = cluster._scenario.failure_type
        valid_failures: set[str] = spec["valid_failures"]
        if failure_type not in valid_failures:
            return self._reject_coalition(
                proposal, "coalition not applicable to current failure"
            )

        consensus = self.get_worker_consensus(cluster)
        supporting_worker = spec["supporting"]
        support_signals: set[str] = spec["support_signals"]
        supporting_recommendation = str(
            consensus["recommendations"].get(supporting_worker.value, "")
        )
        if supporting_recommendation not in support_signals:
            return self._reject_coalition(
                proposal, "supporting worker did not agree with coalition action"
            )

        execution_success, joint_output = self._execute_coalition_effect(
            proposal.action, cluster
        )
        coalition_reward = float(spec["coalition_reward"]) if execution_success else 0.0
        self._coalition_reward_total += coalition_reward
        result = CoalitionResult(
            proposal=proposal,
            agreement_reached=True,
            joint_output=joint_output,
            coalition_reward=coalition_reward,
            execution_success=execution_success,
        )
        self._delegation_log.append(
            {
                "type": "coalition",
                "action": proposal.action,
                "agreement_reached": True,
                "dissent_reason": "",
                "coalition_reward": coalition_reward,
                "joint_output": joint_output,
                "execution_success": execution_success,
                "proposing_worker": proposal.proposing_worker,
                "supporting_worker": proposal.supporting_worker,
            }
        )
        return result

    def get_coalition_options(
        self,
        cluster: "ClusterStateMachine",
    ) -> list[dict[str, Any]]:
        """Describe every coalition action.

        All coalition specs are returned; ``valid_for_current_failure``
        tells whether each one applies to the cluster's current failure
        type.
        """
        failure_type = cluster._scenario.failure_type
        return [
            {
                "action": action,
                "proposing_worker": spec["proposing"].value,
                "supporting_worker": spec["supporting"].value,
                "valid_for_current_failure": failure_type in spec["valid_failures"],
                "coalition_reward": float(spec["coalition_reward"]),
                "description": str(spec["description"]),
            }
            for action, spec in self._COALITION_SPECS.items()
        ]

    def _direct_actions_for_worker(self, worker: WorkerType) -> set[str]:
        """Return the actions treated as 'direct' for ``worker``.

        Direct actions get their uncertainty capped at 0.1 during
        calibration.
        """
        if worker == WorkerType.LOG_INSPECTOR:
            return {"inspect_flight_recorder"}
        if worker == WorkerType.TOPO_AGENT:
            return {"reorder_topology", "check_bandwidth"}
        if worker == WorkerType.PATCH_AGENT:
            return {"apply_patch"}
        if worker == WorkerType.VERSION_CHECKER:
            return {"check_nccl_version"}
        return set()

    def _adjust_uncertainty(
        self,
        request: DelegationRequest,
        worker_result: WorkerResult,
    ) -> float:
        """Apply global uncertainty calibration while preserving explicit worker signals.

        * Worker-reported uncertainty >= 0.9 is kept as-is.
        * Routing to a worker other than the expected one raises it to at
          least 0.8.
        * A direct action on the expected worker caps it at 0.1.
        * Any other action on the expected worker raises it to at least 0.4.
        """
        failure_type = self._cluster._scenario.failure_type
        expected_worker = self._expected_worker_for_failure(failure_type)
        if worker_result.uncertainty >= 0.9:
            return worker_result.uncertainty
        if request.worker != expected_worker:
            return max(worker_result.uncertainty, 0.8)
        direct_actions = self._direct_actions_for_worker(request.worker)
        if request.action in direct_actions:
            return min(worker_result.uncertainty, 0.1)
        return max(worker_result.uncertainty, 0.4)

    def get_worker_consensus(
        self,
        cluster: "ClusterStateMachine",
    ) -> dict[str, Any]:
        """Poll all workers and return consensus or disagreement signal.

        Each worker's recommendation is derived from the cluster's current
        failure type. The returned dict contains:

        * ``recommendations`` - worker name -> recommended action
        * ``disagreement_score`` - distinct recommendations divided by the
          number of polled workers, rounded to 2 decimals
        * ``consensus`` - True when all recommendations are identical
        * ``suggested_next`` - the most frequent recommendation; ties go to
          the worker polled first, so the result is deterministic
        """
        recommendations: dict[str, str] = {}
        failure_type = cluster._scenario.failure_type
        for worker_type in self._workers:
            if worker_type == WorkerType.LOG_INSPECTOR:
                rec = (
                    "inspect_flight_recorder"
                    if failure_type in {"oom", "cascade", "black_swan"}
                    else "query_nccl_logs"
                )
            elif worker_type == WorkerType.TOPO_AGENT:
                rec = (
                    "topo_reorder"
                    if failure_type in {"congestion", "cascade", "black_swan"}
                    else "check_bandwidth"
                )
            elif worker_type == WorkerType.PATCH_AGENT:
                rec = (
                    "apply_patch"
                    if failure_type in {"desync", "cascade", "black_swan"}
                    else "verify_patch"
                )
            else:
                rec = (
                    "check_nccl_version"
                    if failure_type in {"cascade", "black_swan"}
                    else "check_init_logs"
                )
            recommendations[worker_type.value] = rec

        ordered_recs = list(recommendations.values())
        unique_recs = set(ordered_recs)
        disagreement_score = len(unique_recs) / len(ordered_recs)
        # Iterate the ordered list, not a set: max() returns the first
        # maximal item, so ties are broken by polling order instead of by
        # hash-dependent set ordering.
        suggested_next = max(ordered_recs, key=ordered_recs.count)
        return {
            "recommendations": recommendations,
            "disagreement_score": round(disagreement_score, 2),
            "consensus": len(unique_recs) == 1,
            "suggested_next": suggested_next,
        }

    def _score_coordination(
        self,
        request: DelegationRequest,
        result: WorkerResult,
    ) -> float:
        """Score whether supervisor routed to the correct worker.

        +0.3 when the request went to the expected worker for the current
        failure type, +0.1 when the worker succeeded; the total is clamped
        to [0.0, 0.4]. For "cascade"/"black_swan" an expected worker (the
        version checker) only exists once at least 10 steps have stalled.
        """
        failure_type: Literal["oom", "congestion", "desync", "cascade", "black_swan"] = (
            self._cluster._scenario.failure_type
        )
        expected_worker: WorkerType | None = None
        if failure_type == "oom":
            expected_worker = WorkerType.LOG_INSPECTOR
        elif failure_type == "congestion":
            expected_worker = WorkerType.TOPO_AGENT
        elif failure_type == "desync":
            expected_worker = WorkerType.PATCH_AGENT
        elif (
            failure_type in {"cascade", "black_swan"}
            and self._cluster.training.stalled_steps >= 10
        ):
            expected_worker = WorkerType.VERSION_CHECKER
        reward = 0.0
        if expected_worker is not None and request.worker == expected_worker:
            reward += 0.3
        if result.success:
            reward += 0.1
        reward = max(0.0, min(0.4, reward))
        return float(round(reward, 4))

    def _explain(
        self,
        request: DelegationRequest,
        result: WorkerResult,
        coordination_reward: float,
    ) -> str:
        """Build human-readable explanation of delegation outcome."""
        outcome = "succeeded" if result.success else "failed"
        return (
            f"Supervisor delegated '{request.action}' to {request.worker.value}. "
            f"Worker {outcome} "
            f"(confidence={result.confidence:.2f}). "
            f"Coordination reward: {coordination_reward:.2f}. "
            f"Reasoning: '{request.supervisor_reasoning}'"
        )

    def get_delegation_log(self) -> list[dict[str, Any]]:
        """Return serializable delegation history.

        Delegations are flattened into dicts with ``type == "delegation"``;
        coalition entries are returned as shallow copies.
        ``budget_remaining`` is the budget left after that delegation,
        counting delegations only (coalition entries do not consume budget).
        """
        serializable: list[dict[str, Any]] = []
        delegations_seen = 0
        for record in self._delegation_log:
            if isinstance(record, DelegationResult):
                delegations_seen += 1
                serializable.append(
                    {
                        "type": "delegation",
                        "worker": record.delegation.worker.value,
                        "action": record.delegation.action,
                        "success": record.worker_result.success,
                        "coordination_reward": record.coordination_reward,
                        "confidence": record.worker_result.confidence,
                        "uncertainty": record.worker_result.uncertainty,
                        "budget_remaining": max(
                            0, self._delegation_budget - delegations_seen
                        ),
                    }
                )
            elif isinstance(record, dict):
                serializable.append(dict(record))
        return serializable

    def budget_status(self) -> dict[str, Any]:
        """Return current delegation budget usage and penalties.

        ``total_penalty`` is ``overrun * penalty_per_overrun`` (0.0 while
        within budget), i.e. the deduction applied to the most recent
        over-budget delegation before clamping.
        """
        overrun = self._delegation_count - self._delegation_budget
        return {
            "budget": self._delegation_budget,
            "used": self._delegation_count,
            "remaining": max(0, self._delegation_budget - self._delegation_count),
            "over_budget": self._delegation_count > self._delegation_budget,
            "total_penalty": max(0.0, overrun * self._budget_penalty_per_overrun),
        }

    def total_coordination_reward(self) -> float:
        """Return delegation rewards plus accumulated coalition rewards."""
        delegation_total = sum(
            r.coordination_reward
            for r in self._delegation_log
            if isinstance(r, DelegationResult)
        )
        return delegation_total + self._coalition_reward_total