Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, Set | |
| import yaml | |
| from auditenv.models import AuditAction | |
@dataclass
class RewardConfig:
    """Reward weights applied when grading audit actions.

    Every field has a default, so ``RewardConfig()`` is always valid;
    ``from_yaml`` overrides individual weights from a config file.
    """

    # Added on top of ``true_positive`` for every correct finding.
    control_tested: float = 0.1
    # Base reward for a finding that matches the ground truth.
    true_positive: float = 0.2
    # Extra reward when the finding supplies all required evidence.
    full_evidence_chain: float = 0.3
    # Penalty for a finding that is missing or contradicts the ground truth.
    false_positive: float = -0.5
    # Per-violation penalty for ground-truth entries never found.
    false_negative: float = -0.2
    # Penalty for a "noop" action.
    noop_penalty: float = -0.05
    # Reward for a "flag_human_review" action.
    human_review_bonus: float = 0.1

    @classmethod
    def from_yaml(cls, path: str = "configs/reward.yaml") -> "RewardConfig":
        """Load reward weights from YAML config, falling back to defaults.

        Only keys under the top-level ``weights`` mapping that name a field
        of this dataclass are applied; unknown keys are ignored.

        Args:
            path: Location of the YAML file.

        Returns:
            A config with the overrides applied, or an all-defaults config
            when the file is missing, unreadable, not valid YAML, or does
            not contain a ``weights`` mapping.
        """
        config_file = Path(path)
        if not config_file.exists():
            return cls()

        # Keep the try body to the I/O and parse step so that programming
        # errors elsewhere are not silently turned into "use defaults".
        try:
            with config_file.open("r", encoding="utf-8") as fh:
                cfg = yaml.safe_load(fh) or {}
        except (OSError, ValueError, yaml.YAMLError):
            # ValueError covers UnicodeDecodeError for mis-encoded files.
            return cls()

        # A YAML document may be a list/scalar, and ``weights:`` may be null;
        # treat both as "no overrides" instead of relying on AttributeError.
        weights = cfg.get("weights") if isinstance(cfg, dict) else None
        if not isinstance(weights, dict):
            return cls()

        known_fields = cls.__dataclass_fields__
        overrides = {k: v for k, v in weights.items() if k in known_fields}
        return cls(**overrides)
def truth_key(document_id: str, violation_type: str) -> str:
    """Return the lookup key for a (document, violation) pair.

    The key is the document id and the violation type joined by a colon.
    """
    return "{}:{}".format(document_id, violation_type)
def normalize_reward(raw_reward: float, floor: float = -1.0, ceiling: float = 1.0) -> float:
    """Clamp *raw_reward* to ``[floor, ceiling]`` and rescale it to ``[0, 1]``.

    Args:
        raw_reward: Unnormalized reward value.
        floor: Lower bound; rewards at or below it map to 0.0.
        ceiling: Upper bound; rewards at or above it map to 1.0.

    Returns:
        The clamped reward's position between ``floor`` and ``ceiling``.

    Raises:
        ValueError: If ``ceiling`` is not strictly greater than ``floor``.
    """
    # An empty or inverted range would divide by zero or yield values
    # outside [0, 1]; reject it up front with a clear message.
    if ceiling <= floor:
        raise ValueError(
            f"ceiling ({ceiling}) must be greater than floor ({floor})"
        )
    clamped = max(floor, min(ceiling, raw_reward))
    return (clamped - floor) / (ceiling - floor)
def grade_step(
    action: AuditAction,
    ground_truth: Dict[str, str],
    evidence_map: Dict[str, list[str]],
    found: Set[str],
    cfg: RewardConfig,
) -> tuple[float, str]:
    """Score a single audit action and return ``(reward, reason)``.

    Args:
        action: The action to grade; its ``action_type`` and ``finding``
            attributes are read.
        ground_truth: Maps a document id to its expected violation type.
        evidence_map: Maps a truth key to the evidence items required for
            the full-evidence bonus.
        found: Truth keys already credited. Mutated: the key of a newly
            credited finding is added.
        cfg: Reward weights.

    Returns:
        The reward and a short label naming the branch that produced it.
    """
    kind = action.action_type
    if kind == "noop":
        return cfg.noop_penalty, "noop_penalty"
    if kind == "flag_human_review":
        return cfg.human_review_bonus, "human_review_bonus"

    claim = action.finding
    if claim is None:
        # Any other action type without a finding payload is penalized
        # with the false-positive weight.
        return cfg.false_positive, "missing_finding_payload"

    doc_id = claim.document_id
    violation = claim.violation_type
    key = truth_key(doc_id, violation)

    # NOTE(review): the duplicate penalty is hard-coded rather than read
    # from cfg like the other weights.
    if key in found:
        return -0.1, "duplicate_finding"

    if ground_truth.get(doc_id) != violation:
        return cfg.false_positive, "false_positive"

    score = cfg.true_positive + cfg.control_tested

    needed = set(evidence_map.get(key, []))
    provided = set(claim.evidence)
    # The bonus applies only when some evidence is required and all of it
    # was supplied.
    if needed and needed <= provided:
        score += cfg.full_evidence_chain

    found.add(key)
    return score, "true_positive"
def terminal_missed_penalty(ground_truth: Dict[str, str], found: Set[str], cfg: RewardConfig) -> float:
    """Return the total penalty for ground-truth violations never found.

    Each ``(document id, violation)`` entry of ``ground_truth`` whose truth
    key is absent from ``found`` contributes one ``cfg.false_negative``.
    """
    unfound_count = sum(
        1
        for doc_id, violation in ground_truth.items()
        if truth_key(doc_id, violation) not in found
    )
    return unfound_count * cfg.false_negative