Spaces:
Running
Running
File size: 6,873 Bytes
325aa05 136ea72 aad7819 325aa05 aad7819 325aa05 aad7819 325aa05 136ea72 325aa05 136ea72 325aa05 aad7819 325aa05 136ea72 325aa05 136ea72 325aa05 136ea72 325aa05 136ea72 325aa05 136ea72 325aa05 136ea72 325aa05 136ea72 325aa05 136ea72 325aa05 136ea72 325aa05 aad7819 325aa05 136ea72 325aa05 aad7819 325aa05 136ea72 325aa05 aad7819 136ea72 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from sentinel_config import ADVERSARIAL_AWARENESS_STAKES
from scenarios import Scenario, SubTask
TaskStatus = Literal["pending", "ready", "in_progress", "completed", "failed", "skipped"]
SummaryValue = str | int | float | bool
# ---------------------------------------------------------------------------
# Node state
# ---------------------------------------------------------------------------
@dataclass()
class TaskNode:
    """Mutable per-subtask state tracked by a TaskGraph during one episode."""

    # Scenario definition wrapped by this node (TaskGraph reads its "id",
    # "depends_on" and "stakes" keys).
    subtask: SubTask
    # Current lifecycle state; one of the TaskStatus values.
    status: TaskStatus = "pending"
    # Score of the most recent attempt: 1.0 correct | 0.5 partial | 0.0 wrong.
    outcome: float = 0.0
    # Identifier of the specialist that produced the most recent outcome.
    specialist_used: str = ""
    # Number of outcomes recorded for this node so far.
    attempts: int = 0
    # Sticky flag: True once any recorded attempt was marked adversarial.
    was_adversarial: bool = False
    # Adversarial attempts whose outcome was > 0.0 (detection) or not (poisoning).
    adversarial_detection_count: int = 0
    adversarial_poisoning_count: int = 0
# ---------------------------------------------------------------------------
# TaskGraph
# Manages the DAG of subtasks for one episode.
# Tracks dependencies, determines which nodes are "ready" to execute,
# and records outcomes.
# ---------------------------------------------------------------------------
class TaskGraph:
    """
    Manage the DAG of subtasks for one episode.

    Tracks dependencies, determines which nodes are "ready" to execute, and
    records outcomes. A node is *terminal* when it is "completed", "skipped",
    or "failed" with ``MAX_ATTEMPTS_PER_NODE`` attempts used up. Terminal
    nodes unblock their dependents, and the episode is done once every node
    is terminal.
    """

    # A failed node is made "ready" again until it has been attempted this often.
    MAX_ATTEMPTS_PER_NODE = 2

    def __init__(self, scenario: Scenario) -> None:
        self._scenario = scenario
        self._nodes: dict[str, TaskNode] = {}
        self._order: list[str] = []  # insertion order (for iteration)
        self._build(scenario["subtasks"])

    def _build(self, subtasks: list[SubTask]) -> None:
        """Create one "pending" TaskNode per subtask, keeping input order."""
        # NOTE(review): duplicate ids are not rejected -- a later subtask
        # replaces the earlier node while its id is appended to _order again.
        for st in subtasks:
            self._nodes[st["id"]] = TaskNode(subtask=st)
            self._order.append(st["id"])

    # ------------------------------------------------------------------
    # State queries
    # ------------------------------------------------------------------

    def _is_exhausted(self, node: TaskNode) -> bool:
        """Return True if ``node`` failed and has no retry attempts left."""
        return node.status == "failed" and node.attempts >= self.MAX_ATTEMPTS_PER_NODE

    def _is_terminal(self, node: TaskNode) -> bool:
        """Return True if ``node`` is completed, skipped, or an exhausted failure."""
        return node.status in ("completed", "skipped") or self._is_exhausted(node)

    def current_node(self) -> TaskNode | None:
        """
        Return the first node, in insertion order, that is ready to execute.

        Scanning has side effects on node status:
          * a "failed" node that still has attempts left is reset to "ready";
          * a "pending" node whose dependencies are all resolved (completed,
            skipped, or exhausted failure -- see ``_deps_met``) is promoted
            to "ready".

        Returns None when no node can be made ready (every node is terminal,
        in progress, or still blocked by unresolved dependencies).
        """
        for sid in self._order:
            node = self._nodes[sid]
            # Retry: a failed node with attempts left goes back in the queue.
            if node.status == "failed" and node.attempts < self.MAX_ATTEMPTS_PER_NODE:
                node.status = "ready"
            if node.status == "pending" and self._deps_met(sid):
                node.status = "ready"
            if node.status == "ready":
                return node
        return None

    def _deps_met(self, subtask_id: str) -> bool:
        """
        Return True when every dependency of ``subtask_id`` is resolved.

        Dependency ids that are not nodes of this graph are ignored.
        """
        deps = self._nodes[subtask_id].subtask["depends_on"]
        return all(
            self._is_dependency_resolved(dep)
            for dep in deps
            if dep in self._nodes
        )

    def _is_dependency_resolved(self, subtask_id: str) -> bool:
        """Return True if the node ``subtask_id`` is terminal (no longer blocks dependents)."""
        return self._is_terminal(self._nodes[subtask_id])

    def is_done(self) -> bool:
        """Return True when every node is terminal (vacuously True for an empty graph)."""
        return all(self._is_terminal(n) for n in self._nodes.values())

    def completion_rate(self) -> float:
        """
        Return the fraction of nodes whose status is "completed".

        Skipped and failed nodes count against the rate; an empty graph
        yields 0.0.
        """
        completed = sum(1 for n in self._nodes.values() if n.status == "completed")
        return completed / len(self._nodes) if self._nodes else 0.0

    def adversarial_detections(self) -> int:
        """
        Return the number of adversarial attempts that still scored > 0.0.

        Each call to ``record_outcome`` with ``was_adversarial=True`` and a
        positive outcome adds one. NOTE(review): any filtering by stakes or
        by the orchestrator's chosen action (e.g. verify / solve
        independently) is presumably done by the caller that sets
        ``was_adversarial`` -- it is not checked here.
        """
        return sum(
            n.adversarial_detection_count for n in self._nodes.values()
        )

    def adversarial_poisonings(self) -> int:
        """
        Return the number of adversarial attempts that scored <= 0.0.

        Each call to ``record_outcome`` with ``was_adversarial=True`` and a
        non-positive outcome adds one.
        """
        return sum(
            n.adversarial_poisoning_count for n in self._nodes.values()
        )

    def subtasks_remaining(self) -> int:
        """
        Return the number of non-terminal nodes.

        This is the exact complement of ``is_done``: it counts pending,
        ready and in-progress nodes plus failed nodes with retries left.
        """
        return sum(1 for n in self._nodes.values() if not self._is_terminal(n))

    def subtasks_completed(self) -> int:
        """Return the number of nodes whose status is "completed"."""
        return sum(1 for n in self._nodes.values() if n.status == "completed")

    def subtasks_total(self) -> int:
        """Return the number of distinct nodes in the graph."""
        return len(self._nodes)

    def subtasks_failed(self) -> int:
        """
        Return the number of nodes whose status is currently "failed".

        NOTE(review): this includes failed nodes that still have retries
        left; those are also counted by ``subtasks_remaining`` until
        ``current_node`` resets them to "ready".
        """
        return sum(1 for n in self._nodes.values() if n.status == "failed")

    def node_index(self, subtask_id: str) -> int:
        """Return the insertion position of ``subtask_id``; raises ValueError if unknown."""
        return self._order.index(subtask_id)

    def high_stakes_nodes(self) -> list[TaskNode]:
        """Return nodes whose subtask "stakes" is >= ADVERSARIAL_AWARENESS_STAKES."""
        return [n for n in self._nodes.values() if n.subtask["stakes"] >= ADVERSARIAL_AWARENESS_STAKES]

    # ------------------------------------------------------------------
    # Mutations
    # ------------------------------------------------------------------

    def record_outcome(
        self,
        subtask_id: str,
        outcome: float,
        specialist_id: str,
        was_adversarial: bool = False,
    ) -> None:
        """
        Record the result of one attempt at ``subtask_id``.

        Stores ``outcome`` and ``specialist_id``, increments ``attempts``,
        and sets the status to "completed" when ``outcome > 0.0`` (partial
        credit counts) or "failed" otherwise. When ``was_adversarial`` is
        True the attempt is tallied as a detection (outcome > 0.0) or a
        poisoning (otherwise), and the node's ``was_adversarial`` flag
        becomes True permanently.

        Raises:
            KeyError: if ``subtask_id`` is not a node of this graph.
        """
        if subtask_id not in self._nodes:
            raise KeyError(f"Unknown subtask_id: {subtask_id}")
        node = self._nodes[subtask_id]
        node.outcome = outcome
        node.specialist_used = specialist_id
        node.attempts += 1
        node.was_adversarial = node.was_adversarial or was_adversarial
        if was_adversarial and outcome > 0.0:
            node.adversarial_detection_count += 1
        elif was_adversarial:
            node.adversarial_poisoning_count += 1
        node.status = "completed" if outcome > 0.0 else "failed"

    def skip_node(self, subtask_id: str) -> None:
        """
        Mark ``subtask_id`` as "skipped".

        Unknown ids are silently ignored (unlike ``record_outcome``, which
        raises KeyError).
        """
        if subtask_id in self._nodes:
            self._nodes[subtask_id].status = "skipped"

    # ------------------------------------------------------------------
    # Summary (for info dict in StepResult)
    # ------------------------------------------------------------------

    def summary(self) -> dict[str, SummaryValue]:
        """Return a flat dict of scenario identifiers and progress counters."""
        return {
            "scenario_id": self._scenario["scenario_id"],
            "task_type": self._scenario["task_type"],
            "subtasks_total": self.subtasks_total(),
            "subtasks_completed": self.subtasks_completed(),
            "subtasks_failed": self.subtasks_failed(),
            "subtasks_remaining": self.subtasks_remaining(),
            "completion_rate": round(self.completion_rate(), 3),
            "adversarial_detections": self.adversarial_detections(),
            "adversarial_poisonings": self.adversarial_poisonings(),
            "is_done": self.is_done(),
        }

    def node_statuses(self) -> dict[str, TaskStatus]:
        """Return a mapping of subtask id to its current status."""
        return {sid: n.status for sid, n in self._nodes.items()}
|