Spaces:
Running
Running
File size: 4,638 Bytes
a36db1b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | from __future__ import annotations
import statistics
from dataclasses import dataclass
from typing import Any
@dataclass()
class RewardClaim:
    """
    One reward value an agent claimed, plus the auditor's verdict if any.

    ``approved`` defaults to ``None``; ``AuditLedger`` counts only an explicit
    ``False`` as a denied claim.
    """

    agent_id: str  # agent that made the claim
    reward: float  # claimed reward (AuditLedger.record_claim clamps it into [0, 1])
    justification: str  # free-text reason supplied with the claim
    timestamp: int  # integer timestamp at which the claim was recorded
    approved: bool | None = None  # verdict; None when no verdict was supplied
@dataclass()
class ActionRecord:
    """One action taken by an agent, stored with the timestamp it was recorded at."""

    agent_id: str  # agent that took the action
    action: dict[str, Any]  # action payload; its key schema is not fixed here
    timestamp: int  # integer timestamp of the action
class AuditLedger:
    """
    Records reward claims and actions so an Auditor policy can detect hacking.

    The ledger keeps reward claims, actions, and free-text flags keyed by
    agent id. ``anomaly_scores``, ``investigate`` and ``snapshot`` derive views
    from that state; the claim/action dicts they return are copies, so callers
    may mutate them without altering the ledger.
    """

    def __init__(self) -> None:
        self._claims: list[RewardClaim] = []
        self._actions: list[ActionRecord] = []
        self._flags: dict[str, list[str]] = {}

    def record_claim(
        self,
        agent_id: str,
        reward: float,
        justification: str,
        timestamp: int,
        approved: bool | None = None,
    ) -> None:
        """
        Append a reward claim for ``agent_id``.

        ``reward`` is clamped into [0, 1] before it is stored. Only
        ``approved=False`` is counted as a denial when scoring; ``None`` (the
        default) and ``True`` are not.
        """
        self._claims.append(
            RewardClaim(
                agent_id=agent_id,
                reward=max(0.0, min(1.0, reward)),
                justification=justification,
                timestamp=timestamp,
                approved=approved,
            )
        )

    def record_action(self, agent_id: str, action: dict[str, Any], timestamp: int) -> None:
        """Append an action by ``agent_id``; ``action`` is shallow-copied before storing."""
        self._actions.append(ActionRecord(agent_id=agent_id, action=dict(action), timestamp=timestamp))

    def flag(self, agent_id: str, reason: str) -> None:
        """Record a flag ``reason`` against ``agent_id``; flags feed into the anomaly score."""
        self._flags.setdefault(agent_id, []).append(reason)

    def clear(self, agent_id: str) -> None:
        """Remove all flags for ``agent_id`` (no-op if it has none)."""
        self._flags.pop(agent_id, None)

    def anomaly_scores(self) -> dict[str, float]:
        """
        Simple deterministic anomaly score in [0, 1].

        High scores mean the agent's reward claims are unusually high, volatile,
        denied by the auditor, or already flagged. Every agent seen in claims,
        actions or flags gets an entry; keys are in sorted order so the result
        (and ``snapshot``) is reproducible across processes.
        """
        # Group claims once instead of re-scanning all claims for every agent.
        claims_by_agent: dict[str, list[RewardClaim]] = {}
        for claim in self._claims:
            claims_by_agent.setdefault(claim.agent_id, []).append(claim)
        return {
            agent_id: self._score_claims(agent_id, claims_by_agent.get(agent_id, []))
            for agent_id in sorted(self._agent_ids())
        }

    def investigate(self, agent_id: str, window: int = 10) -> dict[str, Any]:
        """
        Summarize ``agent_id``'s recent claims and actions.

        The window is measured back from the latest timestamp across *all*
        agents: records with ``timestamp >= max(0, latest - window)`` are
        included. ``anomaly_score`` uses the agent's full claim history (the
        same value ``anomaly_scores()`` reports), not just the window.
        """
        latest_time = self._latest_timestamp()
        since = max(0, latest_time - window)
        agent_claims = [claim for claim in self._claims if claim.agent_id == agent_id]
        claims = [claim for claim in agent_claims if claim.timestamp >= since]
        actions = [
            record for record in self._actions
            if record.agent_id == agent_id and record.timestamp >= since
        ]
        rewards = [claim.reward for claim in claims]
        return {
            "agent_id": agent_id,
            "window": window,
            "claims": [self._claim_to_dict(claim) for claim in claims],
            "actions": [self._action_to_dict(record) for record in actions],
            "avg_claimed_reward": round(statistics.fmean(rewards), 4) if rewards else 0.0,
            "denied_claims": sum(1 for claim in claims if claim.approved is False),
            "flags": list(self._flags.get(agent_id, [])),
            # Score only this agent rather than recomputing every agent's score.
            "anomaly_score": self._score_claims(agent_id, agent_claims),
        }

    def snapshot(self) -> dict[str, Any]:
        """Return a detached copy of all claims, actions, anomaly scores and flags."""
        return {
            "claims": [self._claim_to_dict(claim) for claim in self._claims],
            "actions": [self._action_to_dict(record) for record in self._actions],
            "anomaly_scores": self.anomaly_scores(),
            "flags": {agent: list(reasons) for agent, reasons in self._flags.items()},
        }

    def _score_claims(self, agent_id: str, claims: list[RewardClaim]) -> float:
        """Compute one agent's anomaly score from its claims and flags (0.0 if no claims)."""
        # NOTE(review): flags are ignored for an agent with no claims, even though
        # anomaly_scores documents "already flagged" as a cause -- confirm intent.
        if not claims:
            return 0.0
        rewards = [claim.reward for claim in claims]
        avg_reward = statistics.fmean(rewards)
        volatility = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0
        denied_rate = sum(1 for claim in claims if claim.approved is False) / len(claims)
        high_claim_rate = sum(1 for reward in rewards if reward >= 0.90) / len(rewards)
        # Each flag adds 0.08, capped at 0.25 in total.
        flag_bonus = min(0.25, 0.08 * len(self._flags.get(agent_id, [])))
        # Each sub-term lies in [0, 1] and the weights sum to 1.0: an average above
        # 0.70 starts contributing, and a pstdev of 0.40 saturates volatility.
        score = (
            0.35 * high_claim_rate
            + 0.25 * max(0.0, avg_reward - 0.70) / 0.30
            + 0.20 * min(1.0, volatility / 0.40)
            + 0.20 * denied_rate
            + flag_bonus
        )
        return round(max(0.0, min(1.0, score)), 4)

    @staticmethod
    def _claim_to_dict(claim: RewardClaim) -> dict[str, Any]:
        """Copy a claim's fields into a new dict so callers cannot mutate the ledger."""
        return dict(vars(claim))

    @staticmethod
    def _action_to_dict(record: ActionRecord) -> dict[str, Any]:
        """Copy an action record's fields, including a fresh copy of its ``action`` mapping."""
        return {**vars(record), "action": dict(record.action)}

    def _agent_ids(self) -> set[str]:
        """Every agent id seen in claims, actions or flags."""
        return (
            {claim.agent_id for claim in self._claims}
            | {record.agent_id for record in self._actions}
            | set(self._flags)
        )

    def _latest_timestamp(self) -> int:
        """Largest timestamp across all claims and actions, or 0 when both are empty."""
        timestamps = [claim.timestamp for claim in self._claims] + [
            record.timestamp for record in self._actions
        ]
        return max(timestamps) if timestamps else 0
|