sentinel-env / difficulty_controller.py
XcodeAddy's picture
Add adaptive trust curriculum wow features
74b74f1
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from statistics import mean
from sentinel_config import ADVERSARIAL_TRIGGER_STAKES
@dataclass
class DifficultyProfile:
"""Snapshot of the adaptive curriculum knobs for a new episode."""
adaptive: bool = False
episodes_seen: int = 0
rolling_detection_rate: float = 0.0
adversarial_threshold: float = ADVERSARIAL_TRIGGER_STAKES
high_stakes_ratio: float = 0.35
verify_budget_penalty: int = 0
adversary_benign_confidence: float = 0.88
adversary_poison_confidence: float = 0.92
def to_dict(self) -> dict[str, float | int | bool]:
payload = asdict(self)
payload["rolling_detection_rate"] = round(self.rolling_detection_rate, 3)
payload["adversarial_threshold"] = round(self.adversarial_threshold, 3)
payload["high_stakes_ratio"] = round(self.high_stakes_ratio, 3)
payload["adversary_benign_confidence"] = round(self.adversary_benign_confidence, 3)
payload["adversary_poison_confidence"] = round(self.adversary_poison_confidence, 3)
return payload
@dataclass
class DifficultyController:
"""
Tiny self-improving curriculum controller.
Every window of episodes, it watches adversarial detection rate. Strong
policies get harder episodes; struggling policies get easier recovery.
"""
window_size: int = 20
threshold_step: float = 0.05
high_stakes_step: float = 0.10
min_threshold: float = 0.40
max_threshold: float = 0.85
min_high_stakes_ratio: float = 0.25
max_high_stakes_ratio: float = 0.80
max_verify_budget_penalty: int = 8
_profile: DifficultyProfile = field(default_factory=DifficultyProfile)
_episode_detection_rates: list[float] = field(default_factory=list)
def profile(self, adaptive: bool) -> DifficultyProfile:
if not adaptive:
return DifficultyProfile(adaptive=False)
profile = DifficultyProfile(**asdict(self._profile))
profile.adaptive = True
return profile
def update(self, episode_metrics: dict[str, float | int]) -> DifficultyProfile:
detections = int(episode_metrics.get("adversarial_detections", 0))
poisonings = int(episode_metrics.get("adversarial_poisonings", 0))
encounters = int(episode_metrics.get("adversarial_encounters", detections + poisonings))
detection_rate = detections / max(1, encounters)
self._episode_detection_rates.append(detection_rate)
self._profile.episodes_seen += 1
window = self._episode_detection_rates[-self.window_size :]
self._profile.rolling_detection_rate = mean(window) if window else 0.0
if len(self._episode_detection_rates) % self.window_size == 0:
self._adapt_from_window(self._profile.rolling_detection_rate)
return self.profile(adaptive=True)
def reset(self) -> None:
self._profile = DifficultyProfile()
self._episode_detection_rates = []
def state(self) -> dict[str, float | int | bool]:
return self.profile(adaptive=True).to_dict()
def _adapt_from_window(self, detection_rate: float) -> None:
if detection_rate > 0.70:
self._profile.adversarial_threshold -= self.threshold_step
self._profile.high_stakes_ratio += self.high_stakes_step
self._profile.verify_budget_penalty += 1
elif detection_rate < 0.30:
self._profile.adversarial_threshold += self.threshold_step
self._profile.high_stakes_ratio -= self.high_stakes_step
self._profile.verify_budget_penalty -= 1
# Adversarial arms race: if the defender catches the adversary often,
# the attacker starts earlier and lowers confidence to blend in.
if detection_rate > 0.60:
self._profile.adversary_benign_confidence -= 0.03
self._profile.adversary_poison_confidence -= 0.03
self._profile.adversarial_threshold = max(
self.min_threshold,
min(self.max_threshold, self._profile.adversarial_threshold),
)
self._profile.high_stakes_ratio = max(
self.min_high_stakes_ratio,
min(self.max_high_stakes_ratio, self._profile.high_stakes_ratio),
)
self._profile.verify_budget_penalty = max(
0,
min(self.max_verify_budget_penalty, self._profile.verify_budget_penalty),
)
self._profile.adversary_benign_confidence = max(
0.60,
min(0.88, self._profile.adversary_benign_confidence),
)
self._profile.adversary_poison_confidence = max(
0.70,
min(0.92, self._profile.adversary_poison_confidence),
)
GLOBAL_DIFFICULTY_CONTROLLER = DifficultyController()