# cprm/scripts/validate_compliance_prm.py
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import csv
import json
from pathlib import Path
from typing import Iterable
# Repository root: this script lives one level below it (scripts/ -> repo root).
ROOT = Path(__file__).resolve().parents[1]
# Default dataset locations: rule cards (CSV) plus trace/bundle JSONL files.
RULE_CARDS_PATH = ROOT / "data" / "rules" / "rule_cards_starter.csv"
A_TRACES_PATH = ROOT / "data" / "canonical" / "A_canonical_traces_v1.jsonl"
B_TRACES_PATH = ROOT / "data" / "counterfactuals" / "B_counterfactual_traces_v1.jsonl"
BUNDLES_PATH = ROOT / "data" / "bundles" / "pilot_bundles_v1.jsonl"
# The closed set of step action types; anything else is reported as an error.
ALLOWED_ACTION_TYPES = {
    "ask_profile",
    "confirm_preference",
    "assess",
    "recommend",
    "redirect",
    "disclose_risk",
    "disclose_conflict",
    "marketing_claim",
}
# Every bundle must contain at least one candidate with each of these labels.
REQUIRED_BUNDLE_LABELS = {"compliant", "deadline_missed", "hard_violation"}
# Action types on which a deadline_missed trace's first violation is expected.
DEADLINE_VIOLATION_ACTION_TYPES = {"recommend", "marketing_claim"}
# Action types treated as remedial (profiling/disclosure) steps; a remedial
# step occurring before the first violation triggers a temporal warning.
REMEDIAL_ACTION_TYPES = {
    "ask_profile",
    "confirm_preference",
    "disclose_risk",
    "disclose_conflict",
}
def load_jsonl(path: Path) -> list[dict]:
    """Parse a JSON Lines file into a list of records, skipping blank lines."""
    records: list[dict] = []
    with path.open("r", encoding="utf-8") as fh:
        for raw in fh:
            if raw.strip():
                records.append(json.loads(raw))
    return records
def load_json(path: Path) -> dict:
    """Read a single JSON document from *path*."""
    with path.open("r", encoding="utf-8") as fh:
        document = json.load(fh)
    return document
def load_rule_ids(path: Path) -> set[str]:
    """Collect the set of rule_id values from the rule-cards CSV."""
    with path.open("r", encoding="utf-8", newline="") as fh:
        reader = csv.DictReader(fh)
        return {record["rule_id"] for record in reader}
def earliest_hard_violation(trace: dict) -> int | None:
    """Return the step_id of the first step flagged hard_violation, else None."""
    flagged = (step["step_id"] for step in trace["steps"] if step["hard_violation"])
    return next(flagged, None)
def has_hard_violation(trace: dict) -> bool:
    """Report whether *trace* contains at least one hard-violation step."""
    first = earliest_hard_violation(trace)
    return first is not None
def validate_trace(trace: dict, valid_rule_ids: set[str], errors: list[str], origin: str) -> None:
    """Append structural-consistency errors for one trace onto *errors*.

    Checks: first_violation_step matches the earliest flagged step,
    overall_compliant agrees with the label/violation state, action types and
    rule ids are known, and cumulative soft coverage never exceeds 1.0.
    """
    actual_first = earliest_hard_violation(trace)
    recorded_first = trace["first_violation_step"]
    if recorded_first != actual_first:
        errors.append(
            f"{origin}: first_violation_step={recorded_first} but earliest hard violation is {actual_first}"
        )
    should_be_compliant = actual_first is None and trace["label"] == "compliant"
    if trace["overall_compliant"] != should_be_compliant:
        errors.append(
            f"{origin}: overall_compliant={trace['overall_compliant']} but expected {should_be_compliant}"
        )
    # Running soft-coverage totals per rule, accumulated across steps.
    running_coverage: dict[str, float] = {}
    for step in trace["steps"]:
        action = step["action_type"]
        if action not in ALLOWED_ACTION_TYPES:
            errors.append(f"{origin}: unsupported action_type {action}")
        for rule_id in step["active_rule_ids"]:
            if rule_id not in valid_rule_ids:
                errors.append(f"{origin}: unknown active_rule_id {rule_id}")
        violated = step["violated_rule_id"]
        if violated is not None and violated not in valid_rule_ids:
            errors.append(f"{origin}: unknown violated_rule_id {violated}")
        for rule_id, delta in step["soft_coverage_delta"].items():
            if rule_id not in valid_rule_ids:
                errors.append(f"{origin}: unknown soft_coverage_delta rule_id {rule_id}")
            running_coverage[rule_id] = running_coverage.get(rule_id, 0.0) + float(delta)
            # Small epsilon tolerates float rounding on sums that equal 1.0.
            if running_coverage[rule_id] > 1.000001:
                errors.append(
                    f"{origin}: cumulative soft coverage for {rule_id} exceeds 1.0 ({running_coverage[rule_id]:.3f})"
                )
def validate_canonical_trace_length(trace: dict, errors: list[str], origin: str) -> None:
    """Record an error when a canonical trace's step count is outside 3..5."""
    count = len(trace["steps"])
    if count < 3 or count > 5:
        errors.append(f"{origin}: canonical trace has {count} steps; expected 3 to 5")
def validate_deadline_missed_temporal(trace: dict, warnings: list[str], origin: str) -> None:
    """Append temporal-plausibility warnings for a deadline_missed trace.

    Traces with any other label are ignored. Warnings cover: a missing
    first_violation_step, a late first violation, an unexpected violating
    action type, and remedial actions that precede the violation.
    """
    if trace["label"] != "deadline_missed":
        return
    violation_step = trace["first_violation_step"]
    if violation_step is None:
        warnings.append(f"{origin}: deadline_missed trace has null first_violation_step")
        return
    if violation_step > 2:
        warnings.append(
            f"{origin}: deadline_missed first_violation_step={violation_step}; expected early violation at step 1 or 2"
        )
    by_id = {step["step_id"]: step for step in trace["steps"]}
    offender = by_id.get(violation_step)
    if offender is None:
        warnings.append(
            f"{origin}: first_violation_step={violation_step} does not exist in steps"
        )
        return
    offending_action = offender["action_type"]
    if offending_action not in DEADLINE_VIOLATION_ACTION_TYPES:
        warnings.append(
            f"{origin}: deadline_missed first violation occurs on action_type={offending_action}; expected recommend or marketing_claim"
        )
    # Only the first remedial action preceding the violation is reported.
    for step in trace["steps"]:
        if step["step_id"] < violation_step and step["action_type"] in REMEDIAL_ACTION_TYPES:
            warnings.append(
                f"{origin}: remedial action_type={step['action_type']} appears before first_violation_step={violation_step}"
            )
            break
def validate_bundle(bundle: dict, errors: list[str], warnings: list[str]) -> None:
    """Check label coverage, label/trace agreement, and severity ordering.

    Errors: a required label missing entirely. Warnings: a candidate whose
    violation state contradicts its label, or candidates listed out of
    severity order (compliant, then deadline_missed, then hard_violation).
    """
    bundle_id = bundle["bundle_id"]
    candidates = bundle["candidates"]
    present = {candidate["label"] for candidate in candidates}
    absent = sorted(REQUIRED_BUNDLE_LABELS - present)
    if absent:
        errors.append(f"{bundle_id}: missing required labels {absent}")
    for candidate in candidates:
        origin = f"{bundle_id}/{candidate['trace_id']}"
        hard = has_hard_violation(candidate)
        first_step = candidate["first_violation_step"]
        label = candidate["label"]
        if label == "compliant":
            if hard or first_step is not None:
                warnings.append(
                    f"{origin}: compliant candidate has hard_violation={hard} and first_violation_step={first_step}"
                )
        elif label == "hard_violation":
            if not hard or first_step is None:
                warnings.append(
                    f"{origin}: hard_violation candidate has hard_violation={hard} and first_violation_step={first_step}"
                )
        elif label == "deadline_missed":
            if not hard:
                warnings.append(f"{origin}: deadline_missed candidate has no hard violation")
    # Position of the first candidate seen for each label, in list order.
    first_seen: dict[str, int] = {}
    for position, candidate in enumerate(candidates):
        first_seen.setdefault(candidate["label"], position)
    if REQUIRED_BUNDLE_LABELS <= first_seen.keys():
        order = [
            first_seen["compliant"],
            first_seen["deadline_missed"],
            first_seen["hard_violation"],
        ]
        if order != sorted(order):
            warnings.append(
                f"{bundle_id}: candidate order violates severity ordering compliant < deadline_missed < hard_violation; positions={order}"
            )
def build_result(errors: list[str], warnings: list[str], traces_count: int, bundles_count: int) -> dict:
    """Assemble the JSON-serializable validation summary dictionary."""
    summary = {
        "ok": not errors,
        "errors": errors,
        "warnings": warnings,
        "validated_traces": traces_count,
        "validated_bundles": bundles_count,
    }
    return summary
def validate_dataset(a_traces: list[dict], b_traces: list[dict], bundles: list[dict], valid_rule_ids: set[str]) -> dict:
    """Validate canonical traces, counterfactual traces, and bundles together.

    Returns the summary dict produced by build_result.
    """
    errors: list[str] = []
    warnings: list[str] = []
    all_traces = a_traces + b_traces
    for trace in all_traces:
        origin = trace["trace_id"]
        validate_trace(trace, valid_rule_ids, errors, origin)
        validate_deadline_missed_temporal(trace, warnings, origin)
    # Only canonical (A) traces carry the 3-5 step-length requirement.
    for trace in a_traces:
        validate_canonical_trace_length(trace, errors, trace["trace_id"])
    for bundle in bundles:
        for candidate in bundle["candidates"]:
            validate_trace(candidate, valid_rule_ids, errors, candidate["trace_id"])
        validate_bundle(bundle, errors, warnings)
    return build_result(errors, warnings, len(all_traces), len(bundles))
def validate_single_bundle(bundle: dict, valid_rule_ids: set[str]) -> dict:
    """Validate one saved bundle (the --bundle-json path) and summarize it."""
    errors: list[str] = []
    warnings: list[str] = []
    for candidate in bundle["candidates"]:
        origin = candidate["trace_id"]
        validate_trace(candidate, valid_rule_ids, errors, origin)
        validate_deadline_missed_temporal(candidate, warnings, origin)
        # Canonical traces are identified here by the "_A" trace-id suffix.
        if origin.endswith("_A"):
            validate_canonical_trace_length(candidate, errors, origin)
    validate_bundle(bundle, errors, warnings)
    return build_result(errors, warnings, len(bundle["candidates"]), 1)
def load_annotation_bundles(annotations_dir: Path) -> list[dict]:
    """Load every *.json bundle found (recursively) under *annotations_dir*,
    in sorted path order."""
    return [load_json(path) for path in sorted(annotations_dir.rglob("*.json"))]
def parse_args() -> argparse.Namespace:
    """Build and run the CLI argument parser for the validator."""
    parser = argparse.ArgumentParser(
        description="Validate CPRM pilot traces, bundles, or saved annotations."
    )
    option_specs = (
        (
            "--bundle-json",
            "Validate a single saved bundle JSON file instead of the default dataset.",
        ),
        (
            "--annotations-dir",
            "Validate all JSON bundle files under an annotations directory instead of the default dataset.",
        ),
    )
    for flag, help_text in option_specs:
        parser.add_argument(flag, type=Path, help=help_text)
    return parser.parse_args()
def main() -> None:
    """Entry point: validate the selected input, print the JSON summary, and
    exit with status 1 when any error was recorded."""
    args = parse_args()
    valid_rule_ids = load_rule_ids(RULE_CARDS_PATH)
    if args.bundle_json:
        result = validate_single_bundle(load_json(args.bundle_json), valid_rule_ids)
    elif args.annotations_dir:
        annotation_bundles = load_annotation_bundles(args.annotations_dir)
        result = validate_dataset([], [], annotation_bundles, valid_rule_ids)
    else:
        a_traces = load_jsonl(A_TRACES_PATH)
        b_traces = load_jsonl(B_TRACES_PATH)
        bundles = load_jsonl(BUNDLES_PATH)
        result = validate_dataset(a_traces, b_traces, bundles, valid_rule_ids)
    print(json.dumps(result, indent=2))
    if not result["ok"]:
        raise SystemExit(1)


if __name__ == "__main__":
    main()