File size: 2,742 Bytes
a617acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from __future__ import annotations

import logging
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Set

import yaml

from auditenv.models import AuditAction


@dataclass(frozen=True)
class RewardConfig:
    """Immutable bundle of reward weights used by the grading functions.

    Each field is a float weight. The defaults below are used whenever a
    weight is not overridden through :meth:`from_yaml`.
    """

    # Added on top of ``true_positive`` when a finding matches ground truth.
    control_tested: float = 0.1
    # Base reward for a finding that matches ground truth.
    true_positive: float = 0.2
    # Extra reward when every required evidence item is supplied.
    full_evidence_chain: float = 0.3
    # Reward for a finding that contradicts ground truth or has no payload.
    false_positive: float = -0.5
    # Multiplied by the number of missed violations in the terminal penalty.
    false_negative: float = -0.2
    # Reward for a "noop" action.
    noop_penalty: float = -0.05
    # Reward for a "flag_human_review" action.
    human_review_bonus: float = 0.1

    @classmethod
    def from_yaml(cls, path: str = "configs/reward.yaml") -> "RewardConfig":
        """Load reward weights from a YAML config, falling back to defaults.

        The file is expected to contain a top-level ``weights`` mapping whose
        keys are field names of this class. Unknown keys are ignored.

        Loading is best-effort and never raises: a missing file, an
        unreadable or malformed file, or a document of the wrong shape all
        yield the default configuration. Unlike a missing file, the other
        failures are logged as warnings so that a broken config does not go
        unnoticed. Individual weights that are not real numbers are skipped
        (keeping the default for that field) instead of being stored as-is.

        Args:
            path: Location of the YAML file.

        Returns:
            A ``RewardConfig`` with every valid override applied.
        """
        log = logging.getLogger(__name__)
        config_file = Path(path)
        if not config_file.exists():
            return cls()

        try:
            with config_file.open("r", encoding="utf-8") as fh:
                loaded = yaml.safe_load(fh)
        except Exception as exc:  # best-effort boundary: never crash the caller
            log.warning("Could not load reward config %s (%s); using defaults", path, exc)
            return cls()

        if loaded is None:
            # Empty document: nothing to override.
            return cls()
        if not isinstance(loaded, dict):
            log.warning("Reward config %s is not a mapping; using defaults", path)
            return cls()

        weights = loaded.get("weights")
        if weights is None:
            return cls()
        if not isinstance(weights, dict):
            log.warning("'weights' in %s is not a mapping; using defaults", path)
            return cls()

        overrides = {}
        for name, value in weights.items():
            if name not in cls.__dataclass_fields__:
                continue
            # bool is a subclass of int but is never a sensible weight.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                log.warning(
                    "Ignoring non-numeric reward weight %s=%r in %s", name, value, path
                )
                continue
            overrides[name] = float(value)
        return cls(**overrides)


def truth_key(document_id: str, violation_type: str) -> str:
    """Build the ``"<document_id>:<violation_type>"`` key identifying a finding."""
    return "{}:{}".format(document_id, violation_type)


def normalize_reward(raw_reward: float, floor: float = -1.0, ceiling: float = 1.0) -> float:
    """Clamp *raw_reward* to ``[floor, ceiling]`` and rescale it to ``[0, 1]``.

    Args:
        raw_reward: Unnormalized reward value.
        floor: Lower bound; values at or below it map to ``0.0``.
        ceiling: Upper bound; values at or above it map to ``1.0``.

    Returns:
        The clamped reward mapped linearly onto ``[0, 1]``.

    Raises:
        ValueError: If ``ceiling`` is not strictly greater than ``floor``
            (which would otherwise divide by zero or produce a meaningless
            result), or if ``raw_reward`` is NaN.
    """
    # ``not ceiling > floor`` also rejects NaN bounds, since NaN comparisons are False.
    if not ceiling > floor:
        raise ValueError(
            f"ceiling ({ceiling!r}) must be strictly greater than floor ({floor!r})"
        )
    # min()/max() comparisons against NaN are always False, so without this
    # check a NaN reward would silently be clamped to ``ceiling`` (the best score).
    if math.isnan(raw_reward):
        raise ValueError("raw_reward must not be NaN")
    clamped = max(floor, min(ceiling, raw_reward))
    return (clamped - floor) / (ceiling - floor)


def grade_step(
    action: AuditAction,
    ground_truth: Dict[str, str],
    evidence_map: Dict[str, list[str]],
    found: Set[str],
    cfg: RewardConfig,
    *,
    duplicate_penalty: float = -0.1,
) -> tuple[float, str]:
    """Score a single audit action and return ``(reward, reason)``.

    Outcomes, checked in this order:

    * ``"noop"`` action -> ``cfg.noop_penalty``, ``"noop_penalty"``.
    * ``"flag_human_review"`` action -> ``cfg.human_review_bonus``,
      ``"human_review_bonus"``.
    * Any other action without a finding -> ``cfg.false_positive``,
      ``"missing_finding_payload"``.
    * Finding whose key is already in ``found`` -> ``duplicate_penalty``,
      ``"duplicate_finding"``.
    * Finding whose violation type differs from ``ground_truth`` for that
      document (including unknown documents) -> ``cfg.false_positive``,
      ``"false_positive"``.
    * Otherwise -> ``cfg.true_positive + cfg.control_tested``, plus
      ``cfg.full_evidence_chain`` when every required evidence item was
      supplied; reason ``"true_positive"``.

    Args:
        action: The action to grade; read via ``action_type`` and ``finding``.
        ground_truth: Maps a document id to its expected violation type.
        evidence_map: Maps a ``truth_key`` string to the evidence items
            required for the full-evidence bonus.
        found: Keys of findings already credited. Mutated in place: the key
            of a true positive is added so a repeat counts as a duplicate.
        cfg: Reward weights.
        duplicate_penalty: Reward returned for re-reporting an already
            credited finding. Defaults to the previously hard-coded ``-0.1``.

    Returns:
        A ``(reward, reason)`` tuple.
    """
    if action.action_type == "noop":
        return cfg.noop_penalty, "noop_penalty"

    if action.action_type == "flag_human_review":
        return cfg.human_review_bonus, "human_review_bonus"

    finding = action.finding
    if finding is None:
        return cfg.false_positive, "missing_finding_payload"

    key = truth_key(finding.document_id, finding.violation_type)

    if key in found:
        return duplicate_penalty, "duplicate_finding"

    expected_violation = ground_truth.get(finding.document_id)
    if expected_violation != finding.violation_type:
        return cfg.false_positive, "false_positive"

    reward = cfg.true_positive + cfg.control_tested
    required_evidence = set(evidence_map.get(key, []))
    # Treat an absent (None) evidence field as "no evidence supplied" rather
    # than failing on set(None).
    supplied = set(finding.evidence or ())
    # The bonus needs a non-empty requirement: findings with no registered
    # evidence never earn it.
    if required_evidence and required_evidence.issubset(supplied):
        reward += cfg.full_evidence_chain

    found.add(key)
    return reward, "true_positive"


def terminal_missed_penalty(ground_truth: Dict[str, str], found: Set[str], cfg: RewardConfig) -> float:
    """Return the total penalty for ground-truth violations absent from ``found``.

    Each ``(document id, violation)`` pair in ``ground_truth`` whose
    ``truth_key`` is not in ``found`` contributes one ``cfg.false_negative``.
    """
    unreported = sum(
        1
        for doc_id, violation in ground_truth.items()
        if truth_key(doc_id, violation) not in found
    )
    return unreported * cfg.false_negative