File size: 1,772 Bytes
136ea72
 
 
 
b3b9bbd
136ea72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3b9bbd
 
 
 
 
136ea72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from __future__ import annotations

import unittest

from graders import clamp, confidence_alignment, grade_task3_step
from scenarios import get_scenario
from task_graph import TaskGraph


class GraderAndGraphTests(unittest.TestCase):
    """Tests for the grading helpers (clamp, confidence_alignment,
    grade_task3_step) and for TaskGraph's retry behaviour on failed nodes."""

    def test_clamp_is_boundary_exclusive(self) -> None:
        """clamp() maps far-out-of-range inputs to 0.01 / 0.99 and leaves an
        in-range value such as 0.42 unchanged."""
        # assertAlmostEqual rather than assertEqual: these are float results.
        self.assertAlmostEqual(clamp(-10), 0.01)
        self.assertAlmostEqual(clamp(10), 0.99)
        self.assertAlmostEqual(clamp(0.42), 0.42)

    def test_adversarial_verify_gets_detection_reward(self) -> None:
        """A "verify" action on an adversarial step (stakes=0.85) should earn
        a reward above 0.8, report the detection in its reason string, and
        expose the expected breakdown entries."""
        reward, reason, breakdown = grade_task3_step(
            specialist_outcome=1.0,
            stakes=0.85,
            was_adversarial=True,
            action_type="verify",
            step_count=10,
            max_steps=45,
        )

        self.assertGreater(reward, 0.8)
        self.assertIn("Adversarial detected", reason)
        self.assertAlmostEqual(breakdown["stakes_awareness"], 0.99)
        self.assertIn("verification_quality", breakdown)

    def test_overconfident_wrong_answer_is_penalized(self) -> None:
        """High confidence on a wrong answer (outcome 0.0) scores low, while
        high confidence on a correct answer (outcome 1.0) scores high."""
        self.assertLess(confidence_alignment(0.95, 0.0), 0.1)
        self.assertGreater(confidence_alignment(0.85, 1.0), 0.8)

    def test_failed_nodes_are_retriable_then_resolved(self) -> None:
        """After one failure the same subtask is offered again; after a
        second failure the graph moves on to a different subtask."""
        graph = TaskGraph(get_scenario("SCN-TASK1-001"))
        node = graph.current_node()
        self.assertIsNotNone(node)
        # Capture the id once so every later comparison refers to the subtask
        # that actually failed, regardless of how record_outcome mutates state.
        failed_id = node.subtask["id"]

        # First failure: the same subtask should be offered for a retry.
        graph.record_outcome(failed_id, 0.0, "S1")
        retry = graph.current_node()
        # Guard before dereferencing so a None result fails cleanly instead
        # of raising AttributeError.
        self.assertIsNotNone(retry)
        self.assertEqual(retry.subtask["id"], failed_id)

        # Second failure: the graph should advance to a different subtask.
        graph.record_outcome(failed_id, 0.0, "S1")
        next_node = graph.current_node()
        self.assertIsNotNone(next_node)
        self.assertNotEqual(next_node.subtask["id"], failed_id)


if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this module when it is
    # executed directly as a script (module="__main__" is unittest's default).
    unittest.main(module="__main__")