from __future__ import annotations from dataclasses import dataclass from typing import Literal from sentinel_config import ADVERSARIAL_AWARENESS_STAKES from scenarios import Scenario, SubTask TaskStatus = Literal["pending", "ready", "in_progress", "completed", "failed", "skipped"] SummaryValue = str | int | float | bool # --------------------------------------------------------------------------- # Node state # --------------------------------------------------------------------------- @dataclass class TaskNode: subtask: SubTask status: TaskStatus = "pending" outcome: float = 0.0 # 1.0 correct | 0.5 partial | 0.0 wrong specialist_used: str = "" attempts: int = 0 was_adversarial: bool = False adversarial_detection_count: int = 0 adversarial_poisoning_count: int = 0 # --------------------------------------------------------------------------- # TaskGraph # Manages the DAG of subtasks for one episode. # Tracks dependencies, determines which nodes are "ready" to execute, # and records outcomes. # --------------------------------------------------------------------------- class TaskGraph: MAX_ATTEMPTS_PER_NODE = 2 def __init__(self, scenario: Scenario) -> None: self._scenario = scenario self._nodes: dict[str, TaskNode] = {} self._order: list[str] = [] # insertion order (for iteration) self._build(scenario["subtasks"]) def _build(self, subtasks: list[SubTask]) -> None: for st in subtasks: self._nodes[st["id"]] = TaskNode(subtask=st) self._order.append(st["id"]) # ------------------------------------------------------------------ # State queries # ------------------------------------------------------------------ def current_node(self) -> TaskNode | None: """ Returns the first 'ready' node (all dependencies completed). Returns None if all nodes are done or none are unblocked yet. """ for sid in self._order: node = self._nodes[sid] if node.status == "failed" and node.attempts < self.MAX_ATTEMPTS_PER_NODE: node.status = "ready" if node.status == "pending" and self._deps_met(sid): node.status = "ready" if node.status == "ready": return node return None def _deps_met(self, subtask_id: str) -> bool: """All dependencies must be resolved before downstream work starts.""" deps = self._nodes[subtask_id].subtask["depends_on"] return all( self._is_dependency_resolved(dep) for dep in deps if dep in self._nodes ) def _is_dependency_resolved(self, subtask_id: str) -> bool: node = self._nodes[subtask_id] if node.status in ("completed", "skipped"): return True return node.status == "failed" and node.attempts >= self.MAX_ATTEMPTS_PER_NODE def is_done(self) -> bool: return all( n.status in ("completed", "skipped") or (n.status == "failed" and n.attempts >= self.MAX_ATTEMPTS_PER_NODE) for n in self._nodes.values() ) def completion_rate(self) -> float: completed = sum(1 for n in self._nodes.values() if n.status == "completed") return completed / len(self._nodes) if self._nodes else 0.0 def adversarial_detections(self) -> int: """ Count of high-stakes adversarial attempts that were avoided. Avoided = node was adversarial AND orchestrator chose VERIFY or SOLVE_INDEPENDENTLY. """ return sum( n.adversarial_detection_count for n in self._nodes.values() ) def adversarial_poisonings(self) -> int: """ Count of adversarial results that slipped through unchecked. """ return sum( n.adversarial_poisoning_count for n in self._nodes.values() ) def subtasks_remaining(self) -> int: return sum( 1 for n in self._nodes.values() if n.status in ("pending", "ready", "in_progress") or (n.status == "failed" and n.attempts < self.MAX_ATTEMPTS_PER_NODE) ) def subtasks_completed(self) -> int: return sum(1 for n in self._nodes.values() if n.status == "completed") def subtasks_total(self) -> int: return len(self._nodes) def subtasks_failed(self) -> int: return sum(1 for n in self._nodes.values() if n.status == "failed") def node_index(self, subtask_id: str) -> int: return self._order.index(subtask_id) def high_stakes_nodes(self) -> list[TaskNode]: return [n for n in self._nodes.values() if n.subtask["stakes"] >= ADVERSARIAL_AWARENESS_STAKES] # ------------------------------------------------------------------ # Mutations # ------------------------------------------------------------------ def record_outcome( self, subtask_id: str, outcome: float, specialist_id: str, was_adversarial: bool = False, ) -> None: if subtask_id not in self._nodes: raise KeyError(f"Unknown subtask_id: {subtask_id}") node = self._nodes[subtask_id] node.outcome = outcome node.specialist_used = specialist_id node.attempts += 1 node.was_adversarial = node.was_adversarial or was_adversarial if was_adversarial and outcome > 0.0: node.adversarial_detection_count += 1 elif was_adversarial: node.adversarial_poisoning_count += 1 node.status = "completed" if outcome > 0.0 else "failed" def skip_node(self, subtask_id: str) -> None: if subtask_id in self._nodes: self._nodes[subtask_id].status = "skipped" # ------------------------------------------------------------------ # Summary (for info dict in StepResult) # ------------------------------------------------------------------ def summary(self) -> dict[str, SummaryValue]: return { "scenario_id": self._scenario["scenario_id"], "task_type": self._scenario["task_type"], "subtasks_total": self.subtasks_total(), "subtasks_completed": self.subtasks_completed(), "subtasks_failed": self.subtasks_failed(), "subtasks_remaining": self.subtasks_remaining(), "completion_rate": round(self.completion_rate(), 3), "adversarial_detections": self.adversarial_detections(), "adversarial_poisonings": self.adversarial_poisonings(), "is_done": self.is_done(), } def node_statuses(self) -> dict[str, TaskStatus]: return {sid: n.status for sid, n in self._nodes.items()}