Spaces:
Running
Running
File size: 5,643 Bytes
54e81df | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from scenarios import Scenario, SubTask
# ---------------------------------------------------------------------------
# Node state
# ---------------------------------------------------------------------------
@dataclass
class TaskNode:
    """Per-subtask execution state held by a task graph.

    Only ``subtask`` must be supplied; every other field starts at its
    "not yet attempted" value.
    """

    # The scenario subtask wrapped by this node.
    subtask: SubTask
    # Lifecycle state: pending | ready | in_progress | completed | failed | skipped.
    status: str = field(default="pending")
    # Recorded score: 1.0 correct | 0.5 partial | 0.0 wrong.
    outcome: float = field(default=0.0)
    # Identifier of the specialist whose result was recorded ("" until one is).
    specialist_used: str = field(default="")
    # How many outcomes have been recorded for this node.
    attempts: int = field(default=0)
    # Whether the recorded result was flagged as adversarial.
    was_adversarial: bool = field(default=False)
# ---------------------------------------------------------------------------
# TaskGraph
# Manages the DAG of subtasks for one episode.
# Tracks dependencies, determines which nodes are "ready" to execute,
# and records outcomes.
# ---------------------------------------------------------------------------
class TaskGraph:
    """Dependency graph of the subtasks for one episode.

    Wraps every subtask of a scenario in a ``TaskNode``, determines which
    node is "ready" to execute (all known dependencies completed) and
    records outcomes. The dependencies are intended to form a DAG; cycles
    are not detected here.
    """

    # Statuses that count as finished for ``is_done``.
    _TERMINAL_STATUSES = ("completed", "failed", "skipped")
    # Statuses that count as outstanding work for ``subtasks_remaining``.
    _OPEN_STATUSES = ("pending", "ready", "in_progress")

    def __init__(self, scenario: Scenario) -> None:
        """Build the graph for ``scenario``.

        Args:
            scenario: Scenario mapping. This class reads its ``"subtasks"``,
                ``"scenario_id"`` and ``"task_type"`` entries.
        """
        self._scenario = scenario
        self._nodes: dict[str, TaskNode] = {}
        # Subtask ids in scenario order; current_node() scans in this order.
        self._order: list[str] = []
        self._build(scenario["subtasks"])

    def _build(self, subtasks: list[SubTask]) -> None:
        """Create one TaskNode per subtask, keyed by its ``"id"``."""
        # NOTE(review): a duplicate id replaces the earlier node in _nodes but
        # is appended to _order twice; confirm ids are unique upstream.
        for st in subtasks:
            self._nodes[st["id"]] = TaskNode(subtask=st)
            self._order.append(st["id"])

    # ------------------------------------------------------------------
    # State queries
    # ------------------------------------------------------------------

    def current_node(self) -> Optional[TaskNode]:
        """Return the first executable node in scenario order, or None.

        Nodes are scanned in scenario order. A 'pending' node whose
        dependencies are all 'completed' is promoted to 'ready' (a side
        effect), and the first 'ready' node encountered is returned, so at
        most one node is promoted per call.

        Returns None when nothing is 'ready' and no 'pending' node has its
        dependencies met: either every node is finished, or the remaining
        nodes are blocked. A node that depends on a 'failed' or 'skipped'
        node stays blocked unless that dependency is later recorded as
        completed.
        """
        for sid in self._order:
            node = self._nodes[sid]
            if node.status == "pending" and self._deps_met(sid):
                node.status = "ready"
            if node.status == "ready":
                return node
        return None

    def _deps_met(self, subtask_id: str) -> bool:
        """Return True when every known dependency of the node is 'completed'.

        Dependency ids that are not part of this graph are ignored. A subtask
        whose ``"depends_on"`` entry is missing or None has no dependencies.
        """
        deps = self._nodes[subtask_id].subtask.get("depends_on") or ()
        return all(
            self._nodes[dep].status == "completed"
            for dep in deps
            if dep in self._nodes
        )

    def is_done(self) -> bool:
        """Return True when every node is 'completed', 'failed' or 'skipped'.

        An empty graph is done.

        NOTE(review): dependents of a 'failed' or 'skipped' node remain
        'pending', so this stays False in that state even though
        ``current_node()`` returns None. Callers looping until ``is_done()``
        should also stop when ``current_node()`` is None.
        """
        return all(
            n.status in self._TERMINAL_STATUSES for n in self._nodes.values()
        )

    def completion_rate(self) -> float:
        """Return the fraction of nodes that are 'completed' (0.0 if empty)."""
        total = self.subtasks_total()
        return self.subtasks_completed() / total if total else 0.0

    def adversarial_detections(self) -> int:
        """Count adversarial nodes whose recorded outcome was positive.

        A node counts when it is flagged ``was_adversarial``, has status
        'completed' and ``outcome > 0.0``. The graph does not record which
        action the orchestrator took. Presumably a positive outcome on an
        adversarial node means the bad result was caught (e.g. via VERIFY or
        SOLVE_INDEPENDENTLY); confirm this against the environment logic.
        """
        return sum(
            1 for n in self._nodes.values()
            if n.was_adversarial and n.status == "completed" and n.outcome > 0.0
        )

    def adversarial_poisonings(self) -> int:
        """Count adversarial nodes whose recorded outcome is 0.0.

        Presumably these are adversarial results that slipped through
        unchecked.
        """
        return sum(
            1 for n in self._nodes.values()
            if n.was_adversarial and n.outcome == 0.0
        )

    def subtasks_remaining(self) -> int:
        """Count nodes still 'pending', 'ready' or 'in_progress'."""
        return sum(
            1 for n in self._nodes.values() if n.status in self._OPEN_STATUSES
        )

    def subtasks_completed(self) -> int:
        """Count nodes whose status is 'completed'."""
        return sum(1 for n in self._nodes.values() if n.status == "completed")

    def subtasks_total(self) -> int:
        """Return the number of nodes in the graph."""
        return len(self._nodes)

    def high_stakes_nodes(self, threshold: float = 0.70) -> list[TaskNode]:
        """Return nodes whose subtask ``"stakes"`` is at least ``threshold``.

        Args:
            threshold: Inclusive lower bound on a subtask's stakes
                (default 0.70).

        Returns:
            Matching nodes in graph insertion order.
        """
        return [n for n in self._nodes.values() if n.subtask["stakes"] >= threshold]

    # ------------------------------------------------------------------
    # Mutations
    # ------------------------------------------------------------------

    def record_outcome(
        self,
        subtask_id: str,
        outcome: float,
        specialist_id: str,
        was_adversarial: bool = False,
    ) -> None:
        """Record the result of executing ``subtask_id``.

        Stores the outcome, specialist and adversarial flag, increments
        ``attempts``, and marks the node 'completed' when ``outcome > 0.0``,
        otherwise 'failed'. Each call overwrites the previously recorded
        outcome, specialist and adversarial flag.

        Raises:
            KeyError: if ``subtask_id`` is not part of the graph.
        """
        try:
            node = self._nodes[subtask_id]
        except KeyError:
            raise KeyError(f"Unknown subtask_id: {subtask_id}") from None
        node.outcome = outcome
        node.specialist_used = specialist_id
        node.attempts += 1
        node.was_adversarial = was_adversarial
        node.status = "completed" if outcome > 0.0 else "failed"

    def skip_node(self, subtask_id: str) -> None:
        """Mark ``subtask_id`` as 'skipped'.

        Unknown ids are silently ignored, whereas ``record_outcome`` raises
        KeyError for them.
        """
        if subtask_id in self._nodes:
            self._nodes[subtask_id].status = "skipped"

    # ------------------------------------------------------------------
    # Summary (for info dict in StepResult)
    # ------------------------------------------------------------------

    def summary(self) -> dict:
        """Return a snapshot of episode progress as a plain dict.

        The completion rate is rounded to three decimals.
        """
        return {
            "scenario_id": self._scenario["scenario_id"],
            "task_type": self._scenario["task_type"],
            "subtasks_total": self.subtasks_total(),
            "subtasks_completed": self.subtasks_completed(),
            "subtasks_remaining": self.subtasks_remaining(),
            "completion_rate": round(self.completion_rate(), 3),
            "adversarial_detections": self.adversarial_detections(),
            "adversarial_poisonings": self.adversarial_poisonings(),
            "is_done": self.is_done(),
        }

    def node_statuses(self) -> dict[str, str]:
        """Return a mapping of subtask id to current status."""
        return {sid: n.status for sid, n in self._nodes.items()}