"""Fleet AI Supervisor-Worker coordination layer.""" from __future__ import annotations import math from enum import Enum from typing import TYPE_CHECKING, Any, Literal from pydantic import BaseModel, ConfigDict, Field from app.models import SREAction if TYPE_CHECKING: from simulation.cluster import ClusterStateMachine from simulation.telemetry import TelemetryStream class WorkerType(str, Enum): """Available specialized worker agent types.""" LOG_INSPECTOR = "log_inspector" PATCH_AGENT = "patch_agent" TOPO_AGENT = "topo_agent" VERSION_CHECKER = "version_checker" class DelegationRequest(BaseModel): """Supervisor delegation payload to a worker.""" model_config = ConfigDict(extra="forbid") worker: WorkerType action: str parameters: dict[str, Any] = Field(default_factory=dict) supervisor_reasoning: str = "" class WorkerResult(BaseModel): """Result returned by a worker to the supervisor.""" model_config = ConfigDict(extra="forbid") worker: WorkerType action: str success: bool output: dict[str, Any] confidence: float = 0.0 uncertainty: float = 0.0 tokens_used: int = 0 class DelegationResult(BaseModel): """Full result of a supervisor delegation cycle.""" model_config = ConfigDict(extra="forbid") delegation: DelegationRequest worker_result: WorkerResult coordination_reward: float explanation: str class CoalitionProposal(BaseModel): """A proposal requiring 2-worker agreement before execution.""" model_config = ConfigDict(extra="forbid") proposing_worker: str supporting_worker: str action: str parameters: dict[str, Any] = Field(default_factory=dict) rationale: str = "" class CoalitionResult(BaseModel): """Result of a coalition action attempt.""" model_config = ConfigDict(extra="forbid") proposal: CoalitionProposal agreement_reached: bool dissent_reason: str = "" joint_output: dict[str, Any] = Field(default_factory=dict) coalition_reward: float = 0.0 execution_success: bool = False class LogInspectorWorker: """Handles flight recorder and NCCL log inspection.""" worker_type = WorkerType.LOG_INSPECTOR def execute( self, action: str, parameters: dict[str, Any], cluster: "ClusterStateMachine", telemetry: "TelemetryStream", ) -> WorkerResult: """Execute log inspection action.""" if action == "inspect_flight_recorder": rank_id = parameters.get("rank_id") if rank_id is None: return WorkerResult( worker=self.worker_type, action=action, success=False, output={"error": "rank_id parameter required"}, confidence=0.0, uncertainty=1.0, tokens_used=int(parameters.get("token_count", 0)), ) payload = cluster._generate_flight_recorder_data(int(rank_id)) return WorkerResult( worker=self.worker_type, action=action, success=True, output={"flight_recorder": payload}, confidence=0.9, uncertainty=0.3, tokens_used=int(parameters.get("token_count", 0)), ) if action == "query_nccl_subsystem": subsystem = str(parameters.get("subsystem", "watchdog")) time_window = int(parameters.get("time_window", 10)) logs = telemetry.generate_nccl_subsystem_logs(cluster, subsystem, time_window) return WorkerResult( worker=self.worker_type, action=action, success=True, output={"nccl_logs": logs, "subsystem": subsystem}, confidence=0.7, uncertainty=0.4, tokens_used=int(parameters.get("token_count", 0)), ) if action == "grep_errors": patterns = ("error", "timeout", "failed", "xid") matches = [ line for line in telemetry.visible_logs() if any(pattern in line.lower() for pattern in patterns) ] return WorkerResult( worker=self.worker_type, action=action, success=True, output={"matches": matches}, confidence=0.5, uncertainty=0.4, tokens_used=int(parameters.get("token_count", 0)), ) return WorkerResult( worker=self.worker_type, action=action, success=False, output={"error": "unknown action"}, confidence=0.0, uncertainty=1.0, tokens_used=int(parameters.get("token_count", 0)), ) class PatchAgentWorker: """Handles code patching and fix application.""" worker_type = WorkerType.PATCH_AGENT def execute( self, action: str, parameters: dict[str, Any], cluster: "ClusterStateMachine", telemetry: "TelemetryStream", ) -> WorkerResult: """Execute patch action.""" if action == "apply_patch": file = parameters.get("file") fix_type = parameters.get("fix_type") if file is None or fix_type is None: return WorkerResult( worker=self.worker_type, action=action, success=False, output={"error": "file and fix_type required"}, confidence=0.0, uncertainty=1.0, tokens_used=int(parameters.get("token_count", 0)), ) action_result = cluster.apply_action( SREAction( action_type="patch_divergent_code", parameters={"file": str(file), "fix_type": str(fix_type)}, ) ) confidence = 0.95 if str(file) == cluster._scenario.divergent_file else 0.3 uncertainty = 0.1 if str(file) == cluster._scenario.divergent_file else 0.9 return WorkerResult( worker=self.worker_type, action=action, success=action_result.success, output=action_result.action_output, confidence=confidence, uncertainty=uncertainty, tokens_used=int(parameters.get("token_count", 0)), ) if action == "verify_patch": verified = cluster.training.job_status == "recovered" return WorkerResult( worker=self.worker_type, action=action, success=True, output={"verified": verified, "job_status": cluster.training.job_status}, confidence=1.0, uncertainty=0.4, tokens_used=int(parameters.get("token_count", 0)), ) return WorkerResult( worker=self.worker_type, action=action, success=False, output={"error": "unknown action"}, confidence=0.0, uncertainty=1.0, tokens_used=int(parameters.get("token_count", 0)), ) class TopoAgentWorker: """Handles topology and network configuration.""" worker_type = WorkerType.TOPO_AGENT def execute( self, action: str, parameters: dict[str, Any], cluster: "ClusterStateMachine", telemetry: "TelemetryStream", ) -> WorkerResult: """Execute topology action.""" if action == "reorder_topology": affinity = str(parameters.get("affinity", "rack")) action_result = cluster.apply_action( SREAction( action_type="topo_reorder", parameters={"affinity": affinity}, ) ) confidence = 0.85 if affinity == "rack" else 0.4 return WorkerResult( worker=self.worker_type, action=action, success=action_result.success, output=action_result.action_output, confidence=confidence, uncertainty=0.15 if affinity == "rack" else 0.4, tokens_used=int(parameters.get("token_count", 0)), ) if action == "check_bandwidth": throughput = cluster.training.throughput_tokens_per_sec target = cluster.training.target_throughput ratio = throughput / max(1.0, target) return WorkerResult( worker=self.worker_type, action=action, success=True, output={ "ratio": ratio, "degraded": ratio < 0.8, "throughput": throughput, }, confidence=1.0, uncertainty=0.1, tokens_used=int(parameters.get("token_count", 0)), ) return WorkerResult( worker=self.worker_type, action=action, success=False, output={"error": "unknown action"}, confidence=0.0, uncertainty=1.0, tokens_used=int(parameters.get("token_count", 0)), ) class VersionCheckerWorker: """Checks NCCL version compatibility and LD_LIBRARY_PATH.""" worker_type = WorkerType.VERSION_CHECKER def execute( self, action: str, parameters: dict[str, Any], cluster: "ClusterStateMachine", telemetry: "TelemetryStream", ) -> WorkerResult: """Execute version check action.""" if action == "check_nccl_version": loaded = cluster._scenario.nccl_version_loaded expected = cluster._scenario.nccl_version_expected mismatch = loaded != expected ld_corrupted = cluster._scenario.ld_library_path_corrupted return WorkerResult( worker=self.worker_type, action=action, success=True, output={ "loaded": loaded, "expected": expected, "mismatch": mismatch, "ld_corrupted": ld_corrupted, }, confidence=1.0, uncertainty=0.05, tokens_used=int(parameters.get("token_count", 0)), ) if action == "check_init_logs": logs = telemetry.generate_nccl_subsystem_logs(cluster, "init", 5) detected = any("version mismatch" in line.lower() for line in logs) return WorkerResult( worker=self.worker_type, action=action, success=True, output={"version_mismatch_detected": detected, "logs": logs}, confidence=0.8, uncertainty=0.4, tokens_used=int(parameters.get("token_count", 0)), ) return WorkerResult( worker=self.worker_type, action=action, success=False, output={"error": "unknown action"}, confidence=0.0, uncertainty=1.0, tokens_used=int(parameters.get("token_count", 0)), ) class FleetCoordinator: """Routes supervisor delegations to correct workers.""" def __init__( self, cluster: "ClusterStateMachine", telemetry: "TelemetryStream", ) -> None: self._cluster = cluster self._telemetry = telemetry self._workers = { WorkerType.LOG_INSPECTOR: LogInspectorWorker(), WorkerType.PATCH_AGENT: PatchAgentWorker(), WorkerType.TOPO_AGENT: TopoAgentWorker(), WorkerType.VERSION_CHECKER: VersionCheckerWorker(), } self._delegation_log: list[Any] = [] self._delegation_budget: int = 10 self._delegation_count: int = 0 self._budget_penalty_per_overrun: float = 0.05 self._coalition_reward_total: float = 0.0 _COALITION_SPECS: dict[str, dict[str, Any]] = { "topology_version_fix": { "proposing": WorkerType.TOPO_AGENT, "supporting": WorkerType.VERSION_CHECKER, "valid_failures": {"cascade", "black_swan"}, "coalition_reward": 0.25, "description": "Fix topology congestion and version mismatch jointly", "support_signals": {"check_nccl_version"}, }, "deep_patch_with_verification": { "proposing": WorkerType.PATCH_AGENT, "supporting": WorkerType.LOG_INSPECTOR, "valid_failures": {"desync", "cascade", "black_swan"}, "coalition_reward": 0.20, "description": "Apply patch and verify via deep logs in one joint action", "support_signals": {"inspect_flight_recorder", "query_nccl_logs"}, }, "rack_aware_oom_restart": { "proposing": WorkerType.LOG_INSPECTOR, "supporting": WorkerType.TOPO_AGENT, "valid_failures": {"oom", "cascade", "black_swan"}, "coalition_reward": 0.15, "description": "Non-destructive restart coordinated with topology awareness", "support_signals": {"check_bandwidth", "topo_reorder"}, }, } def delegate(self, request: DelegationRequest) -> DelegationResult: """Route delegation request to correct worker and score coordination.""" self._delegation_count += 1 worker = self._workers[request.worker] worker_result = worker.execute( action=request.action, parameters=request.parameters, cluster=self._cluster, telemetry=self._telemetry, ) worker_result.uncertainty = self._adjust_uncertainty( request=request, worker_result=worker_result, ) coordination_reward = self._score_coordination(request, worker_result) explanation = self._explain(request, worker_result, coordination_reward) if self._delegation_count > self._delegation_budget: overrun = self._delegation_count - self._delegation_budget coordination_reward = max( 0.0, coordination_reward - (overrun * self._budget_penalty_per_overrun), ) explanation += ( f" [OVER BUDGET: {overrun} excess delegations, penalty applied]" ) result = DelegationResult( delegation=request, worker_result=worker_result, coordination_reward=coordination_reward, explanation=explanation, ) self._delegation_log.append(result) return result def _expected_worker_for_failure(self, failure_type: str) -> WorkerType: if failure_type == "oom": return WorkerType.LOG_INSPECTOR if failure_type == "congestion": return WorkerType.TOPO_AGENT if failure_type == "desync": return WorkerType.PATCH_AGENT return WorkerType.VERSION_CHECKER def _validate_coalition_workers( self, proposal: CoalitionProposal, spec: dict[str, Any], ) -> str: try: proposing = WorkerType(proposal.proposing_worker) supporting = WorkerType(proposal.supporting_worker) except ValueError: return "unknown worker in coalition proposal" if proposing not in self._workers or supporting not in self._workers: return "worker not available in fleet" expected_proposing = spec["proposing"] expected_supporting = spec["supporting"] if supporting != expected_supporting or proposing != expected_proposing: return "wrong supporting worker for this coalition" return "" def _execute_coalition_effect( self, action: str, cluster: "ClusterStateMachine", ) -> tuple[bool, dict[str, Any]]: if action == "topology_version_fix": cluster.training.throughput_tokens_per_sec *= 1.35 cluster.training.job_status = "recovered" return True, {"fixed": ["congestion", "version_mismatch"]} if action == "deep_patch_with_verification": cluster.patch_stage = 3 cluster.training.job_status = "recovered" cluster.training.stalled_steps = 0 return True, {"patch_stage": 3, "verified": True} if action == "rack_aware_oom_restart": failing_node = cluster.nodes[cluster._scenario.failing_node_id] failing_node.health_status = "healthy" failing_node.xid_errors.clear() return True, { "restarted": cluster._scenario.failing_node_id, "destructive_penalty_waived": True, } return False, {"error": "unsupported coalition action"} def propose_coalition( self, proposal: CoalitionProposal, cluster: "ClusterStateMachine", telemetry: "TelemetryStream", ) -> CoalitionResult: """Attempt a coalition action requiring 2-worker agreement.""" _ = telemetry spec = self._COALITION_SPECS.get(proposal.action) if spec is None: return CoalitionResult( proposal=proposal, agreement_reached=False, dissent_reason="unknown coalition action", ) worker_validation_error = self._validate_coalition_workers(proposal, spec) if worker_validation_error: result = CoalitionResult( proposal=proposal, agreement_reached=False, dissent_reason=worker_validation_error, ) self._delegation_log.append( { "type": "coalition", "action": proposal.action, "agreement_reached": False, "dissent_reason": worker_validation_error, "coalition_reward": 0.0, } ) return result failure_type = cluster._scenario.failure_type valid_failures: set[str] = spec["valid_failures"] if failure_type not in valid_failures: dissent_reason = "coalition not applicable to current failure" result = CoalitionResult( proposal=proposal, agreement_reached=False, dissent_reason=dissent_reason, ) self._delegation_log.append( { "type": "coalition", "action": proposal.action, "agreement_reached": False, "dissent_reason": dissent_reason, "coalition_reward": 0.0, } ) return result consensus = self.get_worker_consensus(cluster) supporting_worker = spec["supporting"] support_signals: set[str] = spec["support_signals"] supporting_recommendation = str( consensus["recommendations"].get(supporting_worker.value, "") ) if supporting_recommendation not in support_signals: dissent_reason = "supporting worker did not agree with coalition action" result = CoalitionResult( proposal=proposal, agreement_reached=False, dissent_reason=dissent_reason, ) self._delegation_log.append( { "type": "coalition", "action": proposal.action, "agreement_reached": False, "dissent_reason": dissent_reason, "coalition_reward": 0.0, } ) return result execution_success, joint_output = self._execute_coalition_effect(proposal.action, cluster) coalition_reward = float(spec["coalition_reward"]) if execution_success else 0.0 self._coalition_reward_total += coalition_reward result = CoalitionResult( proposal=proposal, agreement_reached=True, joint_output=joint_output, coalition_reward=coalition_reward, execution_success=execution_success, ) self._delegation_log.append( { "type": "coalition", "action": proposal.action, "agreement_reached": True, "dissent_reason": "", "coalition_reward": coalition_reward, "joint_output": joint_output, "execution_success": execution_success, "proposing_worker": proposal.proposing_worker, "supporting_worker": proposal.supporting_worker, } ) return result def get_coalition_options( self, cluster: "ClusterStateMachine", ) -> list[dict[str, Any]]: """Return available coalition actions for current failure type.""" failure_type = cluster._scenario.failure_type options: list[dict[str, Any]] = [] for action, spec in self._COALITION_SPECS.items(): options.append( { "action": action, "proposing_worker": spec["proposing"].value, "supporting_worker": spec["supporting"].value, "valid_for_current_failure": failure_type in spec["valid_failures"], "coalition_reward": float(spec["coalition_reward"]), "description": str(spec["description"]), } ) return options def _direct_actions_for_worker(self, worker: WorkerType) -> set[str]: if worker == WorkerType.LOG_INSPECTOR: return {"inspect_flight_recorder"} if worker == WorkerType.TOPO_AGENT: return {"reorder_topology", "check_bandwidth"} if worker == WorkerType.PATCH_AGENT: return {"apply_patch"} if worker == WorkerType.VERSION_CHECKER: return {"check_nccl_version"} return set() def _adjust_uncertainty( self, request: DelegationRequest, worker_result: WorkerResult, ) -> float: """Apply global uncertainty calibration while preserving explicit worker signals.""" failure_type = self._cluster._scenario.failure_type expected_worker = self._expected_worker_for_failure(failure_type) if worker_result.uncertainty >= 0.9: return worker_result.uncertainty if request.worker != expected_worker: return max(worker_result.uncertainty, 0.8) direct_actions = self._direct_actions_for_worker(request.worker) if request.action in direct_actions: return min(worker_result.uncertainty, 0.1) return max(worker_result.uncertainty, 0.4) def get_worker_consensus( self, cluster: "ClusterStateMachine", ) -> dict[str, Any]: """Poll all workers and return consensus or disagreement signal.""" recommendations: dict[str, str] = {} failure_type = cluster._scenario.failure_type for worker_type in self._workers: if worker_type == WorkerType.LOG_INSPECTOR: rec = ( "inspect_flight_recorder" if failure_type in {"oom", "cascade", "black_swan"} else "query_nccl_logs" ) elif worker_type == WorkerType.TOPO_AGENT: rec = ( "topo_reorder" if failure_type in {"congestion", "cascade", "black_swan"} else "check_bandwidth" ) elif worker_type == WorkerType.PATCH_AGENT: rec = ( "apply_patch" if failure_type in {"desync", "cascade", "black_swan"} else "verify_patch" ) else: rec = ( "check_nccl_version" if failure_type in {"cascade", "black_swan"} else "check_init_logs" ) recommendations[worker_type.value] = rec unique_recs = set(recommendations.values()) disagreement_score = len(unique_recs) / 4.0 suggested_next = max( set(recommendations.values()), key=list(recommendations.values()).count, ) return { "recommendations": recommendations, "disagreement_score": round(disagreement_score, 2), "consensus": len(unique_recs) == 1, "suggested_next": suggested_next, } def _score_coordination( self, request: DelegationRequest, result: WorkerResult, ) -> float: """Score whether supervisor routed to the correct worker.""" failure_type: Literal["oom", "congestion", "desync", "cascade", "black_swan"] = ( self._cluster._scenario.failure_type ) expected_worker: WorkerType | None = None if failure_type == "oom": expected_worker = WorkerType.LOG_INSPECTOR elif failure_type == "congestion": expected_worker = WorkerType.TOPO_AGENT elif failure_type == "desync": expected_worker = WorkerType.PATCH_AGENT elif failure_type in {"cascade", "black_swan"} and self._cluster.training.stalled_steps >= 10: expected_worker = WorkerType.VERSION_CHECKER reward = 0.0 if expected_worker is not None and request.worker == expected_worker: reward += 0.3 if result.success: reward += 0.1 reward = max(0.0, min(0.4, reward)) return float(round(reward, 4)) def _explain( self, request: DelegationRequest, result: WorkerResult, coordination_reward: float, ) -> str: """Build human-readable explanation of delegation outcome.""" _ = math.log(max(1.0, result.confidence + 1.0)) return ( f"Supervisor delegated '{request.action}' to {request.worker.value}. " f"Worker {'succeeded' if result.success else 'failed'} " f"(confidence={result.confidence:.2f}). " f"Coordination reward: {coordination_reward:.2f}. " f"Reasoning: '{request.supervisor_reasoning}'" ) def get_delegation_log(self) -> list[dict[str, Any]]: """Return serializable delegation history.""" serializable: list[dict[str, Any]] = [] for index, record in enumerate(self._delegation_log): if isinstance(record, DelegationResult): serializable.append( { "type": "delegation", "worker": record.delegation.worker.value, "action": record.delegation.action, "success": record.worker_result.success, "coordination_reward": record.coordination_reward, "confidence": record.worker_result.confidence, "uncertainty": record.worker_result.uncertainty, "budget_remaining": max(0, self._delegation_budget - index - 1), } ) elif isinstance(record, dict): serializable.append({**record}) return serializable def budget_status(self) -> dict[str, Any]: """Return current delegation budget usage and penalties.""" overrun = self._delegation_count - self._delegation_budget return { "budget": self._delegation_budget, "used": self._delegation_count, "remaining": max(0, self._delegation_budget - self._delegation_count), "over_budget": self._delegation_count > self._delegation_budget, "total_penalty": max(0.0, overrun * self._budget_penalty_per_overrun), } def total_coordination_reward(self) -> float: """Sum of coordination rewards across all delegations.""" delegation_total = sum( r.coordination_reward for r in self._delegation_log if isinstance(r, DelegationResult) ) return delegation_total + self._coalition_reward_total