File size: 1,849 Bytes
136ea72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aad7819
 
 
 
 
 
 
 
 
136ea72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from __future__ import annotations

import random
import unittest

from specialists import DomainBoundSpecialist, SpecialistPool


class SpecialistTests(unittest.TestCase):
    """Unit tests for ``DomainBoundSpecialist`` and ``SpecialistPool``."""

    # Every slot identifier the pool is expected to expose, S0 through S4.
    _SLOTS = frozenset(f"S{index}" for index in range(5))

    @staticmethod
    def _run(specialist, task, **options):
        """Call ``specialist.execute`` on *task* with the fixed 0.2 argument.

        A freshly seeded ``random.Random(1)`` is created for every call so each
        execution starts from the same RNG state. Extra keyword arguments
        (e.g. ``domain=...``) are forwarded unchanged.
        """
        return specialist.execute(task, 0.2, random.Random(1), **options)

    def test_domain_bound_matches_abstract_analysis_and_verify_tasks(self) -> None:
        """An analysis task is in-domain (outcome 1.0); an execution task is not (0.0)."""
        specialist = DomainBoundSpecialist()

        analysis = self._run(specialist, "Analyze the inputs and identify the key pattern.")
        action = self._run(specialist, "Execute the planned action and report the outcome.")

        self.assertTrue(analysis.metadata["in_domain"])
        self.assertFalse(action.metadata["in_domain"])
        self.assertEqual(analysis.outcome, 1.0)
        self.assertEqual(action.outcome, 0.0)

    def test_domain_bound_prefers_structured_domain_over_keywords(self) -> None:
        """An explicit ``domain=`` argument decides membership, not the task wording."""
        specialist = DomainBoundSpecialist()

        # Wording without an analysis keyword, but tagged as ANALYZE.
        tagged_analyze = self._run(specialist, "Examine the payload carefully.", domain="ANALYZE")
        # Wording with an analysis keyword, but tagged as EXECUTE.
        tagged_execute = self._run(specialist, "Analyze this deployment step.", domain="EXECUTE")

        self.assertTrue(tagged_analyze.metadata["in_domain"])
        self.assertFalse(tagged_execute.metadata["in_domain"])

    def test_profile_shuffle_keeps_public_reliability_aligned(self) -> None:
        """After a seeded reset, the adversarial slot maps to S3 and to S3's reliability."""
        pool = SpecialistPool()
        pool.reset(seed=7)

        profile = pool.internal_profile()
        per_specialist = {"S0": 0.9, "S1": 0.6, "S2": 0.7, "S3": 0.15, "S4": 0.65}
        reliability = pool.public_ground_truth_reliability(per_specialist)

        self.assertEqual(set(profile), self._SLOTS)
        self.assertEqual(set(reliability), self._SLOTS)

        slot = pool.adversarial_slot
        self.assertEqual(profile[slot], "S3")
        self.assertEqual(reliability[slot], 0.15)


if __name__ == "__main__":
    # Allow running this test module directly; unittest.main() collects and
    # runs the TestCase classes defined in this module.
    unittest.main()