File size: 5,643 Bytes
54e81df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

from scenarios import Scenario, SubTask


# ---------------------------------------------------------------------------
# Node state
# ---------------------------------------------------------------------------

@dataclass
class TaskNode:
    """Mutable per-episode state attached to a single subtask of a TaskGraph.

    Only ``subtask`` must be supplied; every other field starts at its
    default and is filled in as the episode progresses.
    """

    # The scenario's subtask definition this node tracks.
    subtask: SubTask
    # Lifecycle state: pending | ready | in_progress | completed | failed | skipped
    status: str = field(default="pending")
    # Scored result: 1.0 correct | 0.5 partial | 0.0 wrong
    outcome: float = field(default=0.0)
    # Identifier of the specialist that produced the recorded outcome ("" until set).
    specialist_used: str = field(default="")
    # How many outcomes have been recorded for this node.
    attempts: int = field(default=0)
    # Whether the most recently recorded outcome came from an adversarial result.
    was_adversarial: bool = field(default=False)


# ---------------------------------------------------------------------------
# TaskGraph
# Manages the DAG of subtasks for one episode.
# Tracks dependencies, determines which nodes are "ready" to execute,
# and records outcomes.
# ---------------------------------------------------------------------------

class TaskGraph:
    """
    DAG of subtasks for one episode.

    Nodes are kept in the scenario's subtask order. A node becomes eligible
    to run ("ready") once every dependency that exists in the graph is
    ``completed``; outcomes are written back with :meth:`record_outcome`.
    """

    # Statuses that count as finished for is_done().
    _TERMINAL_STATUSES = ("completed", "failed", "skipped")
    # Statuses that still count as outstanding work for subtasks_remaining().
    _OPEN_STATUSES = ("pending", "ready", "in_progress")

    def __init__(self, scenario: Scenario) -> None:
        self._scenario = scenario
        self._nodes: dict[str, TaskNode] = {}
        self._order: list[str] = []   # insertion order (for iteration)
        self._build(scenario["subtasks"])

    def _build(self, subtasks: list[SubTask]) -> None:
        """Create one pending TaskNode per subtask, preserving input order."""
        # NOTE(review): a repeated subtask id overwrites the earlier node in
        # _nodes while _order lists the id twice — confirm ids are unique.
        for st in subtasks:
            self._nodes[st["id"]] = TaskNode(subtask=st)
            self._order.append(st["id"])

    # ------------------------------------------------------------------
    # State queries
    # ------------------------------------------------------------------

    def current_node(self) -> Optional[TaskNode]:
        """
        Return the first schedulable node in subtask order, or None.

        The first node (in insertion order) that is either already ``ready``
        or ``pending`` with all dependencies completed is returned; if it was
        ``pending`` it is promoted to ``ready`` first (side effect).

        Returns None when no such node exists: everything is finished, or the
        remaining nodes are ``in_progress`` or blocked. A node whose
        dependency ended ``failed`` or ``skipped`` never becomes ready, so it
        must be skipped explicitly (:meth:`skip_node`) before :meth:`is_done`
        can return True.
        """
        for sid in self._order:
            node = self._nodes[sid]
            if node.status == "pending" and self._deps_met(sid):
                node.status = "ready"
            if node.status == "ready":
                return node
        return None

    def _deps_met(self, subtask_id: str) -> bool:
        """True if every dependency present in the graph is 'completed'.

        Dependency ids that are not nodes of this graph are ignored.
        """
        deps = self._nodes[subtask_id].subtask["depends_on"]
        return all(
            self._nodes[dep].status == "completed"
            for dep in deps
            if dep in self._nodes
        )

    def is_done(self) -> bool:
        """True when every node is completed, failed or skipped (True if empty)."""
        return all(
            n.status in self._TERMINAL_STATUSES
            for n in self._nodes.values()
        )

    def completion_rate(self) -> float:
        """Fraction of nodes with status 'completed'; 0.0 for an empty graph."""
        total = self.subtasks_total()
        return self.subtasks_completed() / total if total else 0.0

    def adversarial_detections(self) -> int:
        """
        Count adversarial nodes whose recorded outcome was still positive.

        A node counts when ``was_adversarial`` is set, its status is
        ``completed`` and ``outcome > 0``. This is intended to approximate
        "adversarial attempts that were avoided" (e.g. by the orchestrator
        choosing VERIFY or SOLVE_INDEPENDENTLY), but the graph does not record
        which action was taken — only the outcome is checked here.
        """
        return sum(
            1 for n in self._nodes.values()
            if n.was_adversarial and n.status == "completed" and n.outcome > 0.0
        )

    def adversarial_poisonings(self) -> int:
        """
        Count adversarial nodes whose recorded outcome is 0.0 (wrong),
        i.e. adversarial results that slipped through unchecked.
        """
        return sum(
            1 for n in self._nodes.values()
            if n.was_adversarial and n.outcome == 0.0
        )

    def subtasks_remaining(self) -> int:
        """Number of nodes that are pending, ready or in progress."""
        return sum(
            1 for n in self._nodes.values()
            if n.status in self._OPEN_STATUSES
        )

    def subtasks_completed(self) -> int:
        """Number of nodes with status 'completed'."""
        return sum(1 for n in self._nodes.values() if n.status == "completed")

    def subtasks_total(self) -> int:
        """Total number of nodes in the graph."""
        return len(self._nodes)

    def high_stakes_nodes(self, threshold: float = 0.70) -> list[TaskNode]:
        """
        Return nodes whose subtask ``stakes`` is at least ``threshold``.

        Args:
            threshold: inclusive lower bound on ``subtask["stakes"]``
                (defaults to 0.70, the previously hard-coded cut-off).

        Returns:
            Matching nodes in insertion order.
        """
        return [n for n in self._nodes.values() if n.subtask["stakes"] >= threshold]

    # ------------------------------------------------------------------
    # Mutations
    # ------------------------------------------------------------------

    def record_outcome(
        self,
        subtask_id: str,
        outcome: float,
        specialist_id: str,
        was_adversarial: bool = False,
    ) -> None:
        """
        Store the result of one attempt at ``subtask_id``.

        Overwrites outcome, specialist and adversarial flag, increments
        ``attempts``, and sets status to 'completed' when ``outcome > 0``
        (so partial 0.5 counts as completed) or 'failed' otherwise.

        Raises:
            KeyError: if ``subtask_id`` is not a node of this graph.
        """
        if subtask_id not in self._nodes:
            raise KeyError(f"Unknown subtask_id: {subtask_id}")
        node = self._nodes[subtask_id]
        node.outcome = outcome
        node.specialist_used = specialist_id
        node.attempts += 1
        node.was_adversarial = was_adversarial
        node.status = "completed" if outcome > 0.0 else "failed"

    def skip_node(self, subtask_id: str) -> None:
        """Mark a node 'skipped'; unknown ids are silently ignored."""
        if subtask_id in self._nodes:
            self._nodes[subtask_id].status = "skipped"

    # ------------------------------------------------------------------
    # Summary (for info dict in StepResult)
    # ------------------------------------------------------------------

    def summary(self) -> dict:
        """Snapshot of episode progress (completion rate rounded to 3 places)."""
        return {
            "scenario_id":          self._scenario["scenario_id"],
            "task_type":            self._scenario["task_type"],
            "subtasks_total":       self.subtasks_total(),
            "subtasks_completed":   self.subtasks_completed(),
            "subtasks_remaining":   self.subtasks_remaining(),
            "completion_rate":      round(self.completion_rate(), 3),
            "adversarial_detections": self.adversarial_detections(),
            "adversarial_poisonings": self.adversarial_poisonings(),
            "is_done":              self.is_done(),
        }

    def node_statuses(self) -> dict[str, str]:
        """Map each subtask id to its node's current status."""
        return {sid: n.status for sid, n in self._nodes.items()}