Corp_AI / src /auditenv /grader.py
Arpit Deep
feat: initial AuditEnv submission
a617acd
from __future__ import annotations

import logging
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Set

import yaml

from auditenv.models import AuditAction
@dataclass(frozen=True)
class RewardConfig:
    """Immutable set of reward weights used when grading audit actions.

    Positive weights are bonuses and negative weights are penalties. Any
    weight not overridden through :meth:`from_yaml` keeps the default below.
    """

    # Added on top of ``true_positive`` for every correctly identified violation.
    control_tested: float = 0.1
    # Reward for a finding whose violation type matches the ground truth.
    true_positive: float = 0.2
    # Extra bonus when the finding supplies every required evidence item.
    full_evidence_chain: float = 0.3
    # Penalty for a wrong finding, or a finding action with no payload.
    false_positive: float = -0.5
    # Penalty per ground-truth violation that was never found.
    false_negative: float = -0.2
    # Penalty for a ``noop`` action.
    noop_penalty: float = -0.05
    # Reward for a ``flag_human_review`` action.
    human_review_bonus: float = 0.1

    @classmethod
    def from_yaml(cls, path: str = "configs/reward.yaml") -> "RewardConfig":
        """Load reward weights from the ``weights`` mapping of a YAML file.

        A missing file or an empty document silently yields the defaults.
        Other problems are logged as warnings instead of being swallowed, and
        the affected part falls back to its defaults. These problems are an
        unreadable or malformed file, a document or ``weights`` section that
        is not a mapping, unknown weight names, and non-numeric or non-finite
        values.

        Args:
            path: Path to the YAML config file.

        Returns:
            A new ``RewardConfig`` with every valid override applied.
        """
        log = logging.getLogger(__name__)
        try:
            with Path(path).open("r", encoding="utf-8") as fh:
                document = yaml.safe_load(fh)
        except FileNotFoundError:
            # No config file is a supported setup: use the built-in defaults.
            return cls()
        except (OSError, yaml.YAMLError) as err:
            log.warning("Could not load reward config %s (%s); using defaults", path, err)
            return cls()

        if document is None:
            return cls()
        if not isinstance(document, dict):
            log.warning("Reward config %s is not a mapping; using defaults", path)
            return cls()

        weights = document.get("weights") or {}
        if not isinstance(weights, dict):
            log.warning("'weights' in reward config %s is not a mapping; using defaults", path)
            return cls()

        overrides: dict[str, float] = {}
        for name, value in weights.items():
            if name not in cls.__dataclass_fields__:
                log.warning("Ignoring unknown reward weight %r in %s", name, path)
                continue
            # bool is an int subclass, but `true`/`false` is never a meaningful
            # weight; NaN/inf would poison every reward computed from it.
            if (
                isinstance(value, bool)
                or not isinstance(value, (int, float))
                or not math.isfinite(value)
            ):
                log.warning(
                    "Ignoring invalid value %r for reward weight %r in %s", value, name, path
                )
                continue
            overrides[name] = float(value)
        return cls(**overrides)
def truth_key(document_id: str, violation_type: str) -> str:
    """Return the ``"<document_id>:<violation_type>"`` key identifying one finding.

    NOTE(review): ids containing ':' can produce colliding keys. The format is
    kept as-is because other keys (e.g. evidence maps) must match it exactly.
    """
    return "{}:{}".format(document_id, violation_type)
def normalize_reward(raw_reward: float, floor: float = -1.0, ceiling: float = 1.0) -> float:
    """Clamp *raw_reward* to ``[floor, ceiling]`` and rescale it linearly onto ``[0, 1]``.

    ``floor`` maps to 0.0 and ``ceiling`` maps to 1.0. Values outside the range
    are clamped before rescaling.

    Args:
        raw_reward: The unnormalized reward.
        floor: Lower bound of the raw range.
        ceiling: Upper bound of the raw range. Must be strictly greater than
            ``floor``.

    Returns:
        The normalized reward in ``[0.0, 1.0]``.

    Raises:
        ValueError: If ``floor < ceiling`` does not hold, or if *raw_reward* is
            NaN. An empty range would divide by zero, and an inverted one
            yields a meaningless constant. ``min``/``max`` would otherwise
            silently map a NaN to the maximum reward of 1.0.
    """
    # Written as `not <` so that NaN bounds are rejected too.
    if not floor < ceiling:
        raise ValueError(f"floor ({floor}) must be strictly less than ceiling ({ceiling})")
    if math.isnan(raw_reward):
        raise ValueError("raw_reward is NaN")
    clamped = max(floor, min(ceiling, raw_reward))
    return (clamped - floor) / (ceiling - floor)
def grade_step(
    action: AuditAction,
    ground_truth: Dict[str, str],
    evidence_map: Dict[str, list[str]],
    found: Set[str],
    cfg: RewardConfig,
    *,
    duplicate_penalty: float = -0.1,
) -> tuple[float, str]:
    """Score a single audit action and return ``(reward, reason)``.

    Args:
        action: The agent's action. ``"noop"`` and ``"flag_human_review"`` are
            scored directly. Every other ``action_type`` is treated as a
            finding submission and must carry ``action.finding``.
        ground_truth: Maps each document id to the violation type expected in
            that document.
        evidence_map: Maps ``truth_key(document_id, violation_type)`` to the
            evidence items required for the full-evidence-chain bonus.
        found: Keys of findings already credited. This set is mutated: the
            key of a newly credited true positive is added to it.
        cfg: Reward weights.
        duplicate_penalty: Reward returned when an already-credited finding
            is submitted again. The default keeps the previous hard-coded
            value.

    Returns:
        ``(reward, reason)``, where ``reason`` is one of ``"noop_penalty"``,
        ``"human_review_bonus"``, ``"missing_finding_payload"``,
        ``"duplicate_finding"``, ``"false_positive"`` or ``"true_positive"``.
    """
    if action.action_type == "noop":
        return cfg.noop_penalty, "noop_penalty"
    if action.action_type == "flag_human_review":
        # NOTE(review): paid on every call with no cap; confirm the episode
        # logic elsewhere prevents an agent from farming this bonus.
        return cfg.human_review_bonus, "human_review_bonus"

    finding = action.finding
    if finding is None:
        return cfg.false_positive, "missing_finding_payload"

    key = truth_key(finding.document_id, finding.violation_type)
    # A finding that was already credited must not earn reward a second time.
    if key in found:
        return duplicate_penalty, "duplicate_finding"

    expected_violation = ground_truth.get(finding.document_id)
    # An unknown document is always wrong, even if violation_type is None
    # (which would otherwise compare equal to the .get() miss).
    if expected_violation is None or expected_violation != finding.violation_type:
        return cfg.false_positive, "false_positive"

    reward = cfg.true_positive + cfg.control_tested
    required_evidence = set(evidence_map.get(key, ()))
    # A missing evidence list is treated as "nothing supplied" rather than crashing.
    supplied = set(finding.evidence or ())
    # The bonus requires a non-empty requirement that is fully covered.
    if required_evidence and required_evidence <= supplied:
        reward += cfg.full_evidence_chain
    found.add(key)
    return reward, "true_positive"
def terminal_missed_penalty(ground_truth: Dict[str, str], found: Set[str], cfg: RewardConfig) -> float:
    """Return the end-of-episode penalty for ground-truth violations never found.

    Each ``document_id -> violation_type`` entry of *ground_truth* whose
    ``truth_key`` is absent from *found* costs ``cfg.false_negative`` once.
    """
    # Summing booleans counts the unmatched entries as an int.
    unmatched = sum(
        truth_key(doc_id, violation) not in found
        for doc_id, violation in ground_truth.items()
    )
    return unmatched * cfg.false_negative