from __future__ import annotations

import logging
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Dict, Set

import yaml

from auditenv.models import AuditAction

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class RewardConfig:
    """Reward weights used by :func:`grade_step` and :func:`terminal_missed_penalty`."""

    # Added on top of ``true_positive`` for every correct finding.
    control_tested: float = 0.1
    # Base reward for a finding that matches the ground truth.
    true_positive: float = 0.2
    # Bonus when the finding supplies every piece of required evidence.
    full_evidence_chain: float = 0.3
    # Wrong finding; also charged when a finding action carries no payload.
    false_positive: float = -0.5
    # Charged once per ground-truth violation never found (end of episode).
    false_negative: float = -0.2
    # Charged for a "noop" action.
    noop_penalty: float = -0.05
    # Paid for a "flag_human_review" action.
    human_review_bonus: float = 0.1
    # Charged when an already-credited finding is reported again.
    duplicate_penalty: float = -0.1

    @classmethod
    def from_yaml(cls, path: str | Path = "configs/reward.yaml") -> RewardConfig:
        """Load reward weights from a YAML file, falling back to defaults.

        The file is expected to contain a top-level ``weights`` mapping whose
        keys are field names of this class; each recognised key overrides the
        corresponding default. Missing file, missing/empty ``weights`` section
        -> all defaults.

        Problems (unreadable file, invalid YAML, non-mapping content,
        non-numeric values) are logged as warnings and the defaults are
        returned, rather than being swallowed silently.

        Args:
            path: Location of the YAML config file.

        Returns:
            A ``RewardConfig`` with any valid overrides applied.
        """
        config_file = Path(path)
        if not config_file.exists():
            # No config is a normal situation: use defaults without noise.
            return cls()

        try:
            with config_file.open("r", encoding="utf-8") as fh:
                cfg = yaml.safe_load(fh) or {}
        except (OSError, yaml.YAMLError) as err:
            logger.warning(
                "Could not read reward config %s (%s); using defaults",
                config_file,
                err,
            )
            return cls()

        if not isinstance(cfg, dict):
            logger.warning(
                "Reward config %s is not a mapping; using defaults", config_file
            )
            return cls()

        # ``weights:`` with no value loads as None; treat it like an absent section.
        weights = cfg.get("weights") or {}
        if not isinstance(weights, dict):
            logger.warning(
                "'weights' in %s is not a mapping; using defaults", config_file
            )
            return cls()

        known = {f.name for f in fields(cls)}
        unknown = [k for k in weights if k not in known]
        if unknown:
            # Typos in key names would otherwise be dropped silently.
            logger.warning(
                "Ignoring unknown reward weights in %s: %s", config_file, unknown
            )

        try:
            # Coerce now so a bad value fails here, not later inside grade_step.
            overrides = {k: float(v) for k, v in weights.items() if k in known}
        except (TypeError, ValueError) as err:
            logger.warning(
                "Invalid reward weight in %s (%s); using defaults", config_file, err
            )
            return cls()
        return cls(**overrides)


def truth_key(document_id: str, violation_type: str) -> str:
    """Build the ``"<document_id>:<violation_type>"`` key used by ``found`` and ``evidence_map``."""
    return f"{document_id}:{violation_type}"


def normalize_reward(raw_reward: float, floor: float = -1.0, ceiling: float = 1.0) -> float:
    """Clamp ``raw_reward`` to ``[floor, ceiling]`` and rescale it linearly to ``[0, 1]``.

    Args:
        raw_reward: Unnormalised reward.
        floor: Value mapped to 0.0; anything below it is clamped.
        ceiling: Value mapped to 1.0; anything above it is clamped.

    Returns:
        The rescaled reward in ``[0.0, 1.0]``.

    Raises:
        ValueError: If ``ceiling <= floor`` (empty or inverted range, which
            would otherwise divide by zero or produce meaningless output).
    """
    if ceiling <= floor:
        raise ValueError(f"ceiling ({ceiling}) must be greater than floor ({floor})")
    clamped = max(floor, min(ceiling, raw_reward))
    return (clamped - floor) / (ceiling - floor)


def grade_step(
    action: AuditAction,
    ground_truth: Dict[str, str],
    evidence_map: Dict[str, list[str]],
    found: Set[str],
    cfg: RewardConfig,
) -> tuple[float, str]:
    """Score a single agent action.

    Args:
        action: The action to grade. Its ``action_type`` selects the rule;
            for finding actions, ``action.finding`` supplies ``document_id``,
            ``violation_type`` and ``evidence``.
        ground_truth: Maps each document id to its true violation type.
        evidence_map: Maps a :func:`truth_key` to the evidence items required
            for the full-evidence-chain bonus.
        found: Truth keys already credited this episode. **Mutated**: the key
            of a new true positive is added to it.
        cfg: Reward weights.

    Returns:
        ``(reward, reason)`` where ``reason`` is one of ``"noop_penalty"``,
        ``"human_review_bonus"``, ``"missing_finding_payload"``,
        ``"duplicate_finding"``, ``"false_positive"`` or ``"true_positive"``.
    """
    if action.action_type == "noop":
        return cfg.noop_penalty, "noop_penalty"
    if action.action_type == "flag_human_review":
        # NOTE(review): paid on every such action with no precondition or cap
        # in this function; if nothing upstream limits it, an agent could
        # farm this bonus — confirm intended.
        return cfg.human_review_bonus, "human_review_bonus"

    finding = action.finding
    if finding is None:
        return cfg.false_positive, "missing_finding_payload"

    key = truth_key(finding.document_id, finding.violation_type)
    if key in found:
        return cfg.duplicate_penalty, "duplicate_finding"

    # Explicit membership test: with ``.get`` an unknown document would yield
    # None and wrongly "match" a finding whose violation_type is None.
    if (
        finding.document_id not in ground_truth
        or ground_truth[finding.document_id] != finding.violation_type
    ):
        return cfg.false_positive, "false_positive"

    reward = cfg.true_positive + cfg.control_tested
    required_evidence = set(evidence_map.get(key, ()))
    # Tolerate a missing (None) evidence list rather than raising TypeError.
    supplied = set(finding.evidence or ())
    # The bonus needs a non-empty requirement that is fully covered.
    if required_evidence and required_evidence <= supplied:
        reward += cfg.full_evidence_chain
    found.add(key)
    return reward, "true_positive"


def terminal_missed_penalty(ground_truth: Dict[str, str], found: Set[str], cfg: RewardConfig) -> float:
    """Return the end-of-episode penalty for ground-truth violations never found.

    Each ``(document_id, violation_type)`` pair in ``ground_truth`` whose
    :func:`truth_key` is absent from ``found`` costs ``cfg.false_negative``.
    """
    missed = sum(
        1
        for doc_id, violation in ground_truth.items()
        if truth_key(doc_id, violation) not in found
    )
    return missed * cfg.false_negative