File size: 4,812 Bytes
74b74f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import annotations

from dataclasses import asdict, dataclass, field
from statistics import mean

from sentinel_config import ADVERSARIAL_TRIGGER_STAKES


@dataclass
class DifficultyProfile:
    """Snapshot of the adaptive curriculum knobs for a new episode."""

    # True when the adaptive curriculum produced this profile.
    adaptive: bool = False
    # How many episodes the controller had observed when this was taken.
    episodes_seen: int = 0
    # Mean adversarial detection rate over the controller's recent window.
    rolling_detection_rate: float = 0.0
    # Stakes level at which adversarial behavior is triggered.
    adversarial_threshold: float = ADVERSARIAL_TRIGGER_STAKES
    # Fraction of episodes drawn as high-stakes.
    high_stakes_ratio: float = 0.35
    # Extra verification-budget cost applied to the defender.
    verify_budget_penalty: int = 0
    # Confidence levels the adversary reports for benign / poisoned content.
    adversary_benign_confidence: float = 0.88
    adversary_poison_confidence: float = 0.92

    def to_dict(self) -> dict[str, float | int | bool]:
        """Serialize the profile, rounding every float knob to 3 decimals."""
        payload = asdict(self)
        for knob in (
            "rolling_detection_rate",
            "adversarial_threshold",
            "high_stakes_ratio",
            "adversary_benign_confidence",
            "adversary_poison_confidence",
        ):
            payload[knob] = round(payload[knob], 3)
        return payload


@dataclass
class DifficultyController:
    """
    Tiny self-improving curriculum controller.

    Every window of episodes, it watches adversarial detection rate. Strong
    policies get harder episodes; struggling policies get easier recovery.
    """

    # Episodes per adaptation window.
    window_size: int = 20
    # Step sizes applied to the curriculum knobs at each adaptation.
    threshold_step: float = 0.05
    high_stakes_step: float = 0.10
    # Clamp bounds for the adapted knobs.
    min_threshold: float = 0.40
    max_threshold: float = 0.85
    min_high_stakes_ratio: float = 0.25
    max_high_stakes_ratio: float = 0.80
    max_verify_budget_penalty: int = 8
    _profile: DifficultyProfile = field(default_factory=DifficultyProfile)
    # Per-episode detection rates. Only the trailing `window_size` entries
    # are ever read, so update() trims the list to that length to keep
    # memory bounded over long runs.
    _episode_detection_rates: list[float] = field(default_factory=list)

    def profile(self, adaptive: bool) -> DifficultyProfile:
        """Return a profile for a new episode.

        When *adaptive* is False, a pristine default profile is returned so
        non-adaptive runs are unaffected by accumulated state. Otherwise a
        copy of the internal profile is returned, so callers cannot mutate
        controller state.
        """
        if not adaptive:
            return DifficultyProfile(adaptive=False)
        profile = DifficultyProfile(**asdict(self._profile))
        profile.adaptive = True
        return profile

    def update(self, episode_metrics: dict[str, float | int]) -> DifficultyProfile:
        """Record one episode's adversarial metrics; adapt once per window.

        Missing keys default to 0 (encounters default to detections +
        poisonings). Returns a copy of the possibly-updated profile.
        """
        detections = int(episode_metrics.get("adversarial_detections", 0))
        poisonings = int(episode_metrics.get("adversarial_poisonings", 0))
        encounters = int(episode_metrics.get("adversarial_encounters", detections + poisonings))
        # max(1, ...) avoids division by zero; a zero-encounter episode
        # contributes a 0.0 detection rate.
        detection_rate = detections / max(1, encounters)

        self._episode_detection_rates.append(detection_rate)
        # Bound the history: the original list grew without limit even
        # though only the trailing window is read. Trim to window_size.
        if len(self._episode_detection_rates) > self.window_size:
            del self._episode_detection_rates[: -self.window_size]
        self._profile.episodes_seen += 1
        window = self._episode_detection_rates[-self.window_size :]
        self._profile.rolling_detection_rate = mean(window) if window else 0.0

        # episodes_seen increments exactly once per update() and is reset
        # together with the rate list, so this preserves the original
        # once-per-window cadence despite the trimmed history.
        if self._profile.episodes_seen % self.window_size == 0:
            self._adapt_from_window(self._profile.rolling_detection_rate)

        return self.profile(adaptive=True)

    def reset(self) -> None:
        """Discard all accumulated state, restoring factory defaults."""
        self._profile = DifficultyProfile()
        self._episode_detection_rates = []

    def state(self) -> dict[str, float | int | bool]:
        """Return the current adaptive profile as a rounded plain dict."""
        return self.profile(adaptive=True).to_dict()

    def _adapt_from_window(self, detection_rate: float) -> None:
        """Shift curriculum knobs based on one window's detection rate.

        >70% detection makes episodes harder; <30% eases them. All knobs
        are clamped to their configured bounds afterwards.
        """
        if detection_rate > 0.70:
            self._profile.adversarial_threshold -= self.threshold_step
            self._profile.high_stakes_ratio += self.high_stakes_step
            self._profile.verify_budget_penalty += 1
        elif detection_rate < 0.30:
            self._profile.adversarial_threshold += self.threshold_step
            self._profile.high_stakes_ratio -= self.high_stakes_step
            self._profile.verify_budget_penalty -= 1

        # Adversarial arms race: if the defender catches the adversary often,
        # the attacker starts earlier and lowers confidence to blend in.
        if detection_rate > 0.60:
            self._profile.adversary_benign_confidence -= 0.03
            self._profile.adversary_poison_confidence -= 0.03

        self._profile.adversarial_threshold = max(
            self.min_threshold,
            min(self.max_threshold, self._profile.adversarial_threshold),
        )
        self._profile.high_stakes_ratio = max(
            self.min_high_stakes_ratio,
            min(self.max_high_stakes_ratio, self._profile.high_stakes_ratio),
        )
        self._profile.verify_budget_penalty = max(
            0,
            min(self.max_verify_budget_penalty, self._profile.verify_budget_penalty),
        )
        self._profile.adversary_benign_confidence = max(
            0.60,
            min(0.88, self._profile.adversary_benign_confidence),
        )
        self._profile.adversary_poison_confidence = max(
            0.70,
            min(0.92, self._profile.adversary_poison_confidence),
        )


# Module-level shared controller instance — presumably the single curriculum
# used across the process; NOTE(review): not thread-safe, confirm callers
# only touch it from one thread.
GLOBAL_DIFFICULTY_CONTROLLER = DifficultyController()