| |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| from pathlib import Path |
| from typing import Iterable |
|
|
# Repository root: this script is assumed to live one directory below it
# (parents[1] of the resolved script path).
ROOT = Path(__file__).resolve().parents[1]
# Default dataset artifacts validated when no CLI flag overrides the mode.
RULE_CARDS_PATH = ROOT / "data" / "rules" / "rule_cards_starter.csv"
A_TRACES_PATH = ROOT / "data" / "canonical" / "A_canonical_traces_v1.jsonl"
B_TRACES_PATH = ROOT / "data" / "counterfactuals" / "B_counterfactual_traces_v1.jsonl"
BUNDLES_PATH = ROOT / "data" / "bundles" / "pilot_bundles_v1.jsonl"


# The closed set of step action types; any other value is reported as an error.
ALLOWED_ACTION_TYPES = {
    "ask_profile",
    "confirm_preference",
    "assess",
    "recommend",
    "redirect",
    "disclose_risk",
    "disclose_conflict",
    "marketing_claim",
}
# Every bundle must contain at least one candidate with each of these labels.
REQUIRED_BUNDLE_LABELS = {"compliant", "deadline_missed", "hard_violation"}
# Action types on which a deadline_missed trace's first violation is expected
# to occur (anything else draws a warning).
DEADLINE_VIOLATION_ACTION_TYPES = {"recommend", "marketing_claim"}
# Action types that trigger a warning when they appear BEFORE the first
# violation in a deadline_missed trace.
REMEDIAL_ACTION_TYPES = {
    "ask_profile",
    "confirm_preference",
    "disclose_risk",
    "disclose_conflict",
}
|
|
|
|
def load_jsonl(path: Path) -> list[dict]:
    """Parse *path* as JSON Lines, ignoring blank/whitespace-only lines."""
    records: list[dict] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                records.append(json.loads(stripped))
    return records
|
|
|
|
def load_json(path: Path) -> dict:
    """Deserialize a single JSON document stored at *path*."""
    return json.loads(path.read_text(encoding="utf-8"))
|
|
|
|
def load_rule_ids(path: Path) -> set[str]:
    """Collect the set of ``rule_id`` values from the rule-cards CSV."""
    with path.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        return {record["rule_id"] for record in reader}
|
|
|
|
def earliest_hard_violation(trace: dict) -> int | None:
    """Return the ``step_id`` of the first step flagged ``hard_violation``, or None."""
    flagged = (step["step_id"] for step in trace["steps"] if step["hard_violation"])
    return next(flagged, None)
|
|
|
|
def has_hard_violation(trace: dict) -> bool:
    """True when at least one step of *trace* carries a hard violation."""
    first = earliest_hard_violation(trace)
    return first is not None
|
|
|
|
def validate_trace(trace: dict, valid_rule_ids: set[str], errors: list[str], origin: str) -> None:
    """Append structural-consistency errors for a single trace to *errors*.

    Checks: recorded first_violation_step matches the earliest hard violation,
    overall_compliant agrees with (no violation AND label == "compliant"),
    action types and rule ids are known, and the per-rule cumulative soft
    coverage never exceeds 1.0 (with a small float tolerance).
    """
    actual_first = earliest_hard_violation(trace)
    recorded_first = trace["first_violation_step"]
    if recorded_first != actual_first:
        errors.append(
            f"{origin}: first_violation_step={recorded_first} but earliest hard violation is {actual_first}"
        )

    should_be_compliant = actual_first is None and trace["label"] == "compliant"
    if trace["overall_compliant"] != should_be_compliant:
        errors.append(
            f"{origin}: overall_compliant={trace['overall_compliant']} but expected {should_be_compliant}"
        )

    coverage_totals: dict[str, float] = {}
    for step in trace["steps"]:
        action = step["action_type"]
        if action not in ALLOWED_ACTION_TYPES:
            errors.append(f"{origin}: unsupported action_type {action}")

        for rule_id in step["active_rule_ids"]:
            if rule_id not in valid_rule_ids:
                errors.append(f"{origin}: unknown active_rule_id {rule_id}")

        violated = step["violated_rule_id"]
        if violated is not None and violated not in valid_rule_ids:
            errors.append(f"{origin}: unknown violated_rule_id {violated}")

        for rule_id, delta in step["soft_coverage_delta"].items():
            if rule_id not in valid_rule_ids:
                errors.append(f"{origin}: unknown soft_coverage_delta rule_id {rule_id}")
            total = coverage_totals.get(rule_id, 0.0) + float(delta)
            coverage_totals[rule_id] = total
            # Tolerance absorbs float accumulation noise around exactly 1.0.
            if total > 1.000001:
                errors.append(
                    f"{origin}: cumulative soft coverage for {rule_id} exceeds 1.0 ({total:.3f})"
                )
|
|
|
|
def validate_canonical_trace_length(trace: dict, errors: list[str], origin: str) -> None:
    """Record an error unless the canonical trace has between 3 and 5 steps (inclusive)."""
    count = len(trace["steps"])
    if count < 3 or count > 5:
        errors.append(f"{origin}: canonical trace has {count} steps; expected 3 to 5")
|
|
|
|
def validate_deadline_missed_temporal(trace: dict, warnings: list[str], origin: str) -> None:
    """Warn when a deadline_missed trace does not match the expected temporal shape.

    Expected shape: the first violation happens early (step 1 or 2), on a
    recommend/marketing_claim action, and no remedial action precedes it.
    Findings go to *warnings*; traces with other labels are skipped.
    """
    if trace["label"] != "deadline_missed":
        return

    first = trace["first_violation_step"]
    if first is None:
        warnings.append(f"{origin}: deadline_missed trace has null first_violation_step")
        return

    if first > 2:
        warnings.append(
            f"{origin}: deadline_missed first_violation_step={first}; expected early violation at step 1 or 2"
        )

    by_id = {step["step_id"]: step for step in trace["steps"]}
    offender = by_id.get(first)
    if offender is None:
        warnings.append(
            f"{origin}: first_violation_step={first} does not exist in steps"
        )
        return

    action = offender["action_type"]
    if action not in DEADLINE_VIOLATION_ACTION_TYPES:
        warnings.append(
            f"{origin}: deadline_missed first violation occurs on action_type={action}; expected recommend or marketing_claim"
        )

    # Only the first premature remedial action is reported.
    for step in trace["steps"]:
        if step["step_id"] >= first:
            continue
        if step["action_type"] in REMEDIAL_ACTION_TYPES:
            warnings.append(
                f"{origin}: remedial action_type={step['action_type']} appears before first_violation_step={first}"
            )
            break
|
|
|
|
def validate_bundle(bundle: dict, errors: list[str], warnings: list[str]) -> None:
    """Check bundle-level invariants.

    Errors: every required label must be present among the candidates.
    Warnings: each candidate's label must agree with its hard-violation state,
    and the first occurrence of each required label must appear in severity
    order compliant < deadline_missed < hard_violation.
    """
    bundle_id = bundle["bundle_id"]
    candidates = bundle["candidates"]

    present_labels = {candidate["label"] for candidate in candidates}
    absent = sorted(REQUIRED_BUNDLE_LABELS - present_labels)
    if absent:
        errors.append(f"{bundle_id}: missing required labels {absent}")

    for candidate in candidates:
        label = candidate["label"]
        candidate_origin = f"{bundle_id}/{candidate['trace_id']}"
        hard = has_hard_violation(candidate)
        first_step = candidate["first_violation_step"]

        if label == "compliant" and (hard or first_step is not None):
            warnings.append(
                f"{candidate_origin}: compliant candidate has hard_violation={hard} and first_violation_step={first_step}"
            )
        elif label == "hard_violation" and (not hard or first_step is None):
            warnings.append(
                f"{candidate_origin}: hard_violation candidate has hard_violation={hard} and first_violation_step={first_step}"
            )
        elif label == "deadline_missed" and not hard:
            warnings.append(f"{candidate_origin}: deadline_missed candidate has no hard violation")

    first_seen: dict[str, int] = {}
    for position, candidate in enumerate(candidates):
        first_seen.setdefault(candidate["label"], position)

    if REQUIRED_BUNDLE_LABELS <= first_seen.keys():
        ordered_positions = [
            first_seen["compliant"],
            first_seen["deadline_missed"],
            first_seen["hard_violation"],
        ]
        if ordered_positions != sorted(ordered_positions):
            warnings.append(
                f"{bundle_id}: candidate order violates severity ordering compliant < deadline_missed < hard_violation; positions={ordered_positions}"
            )
|
|
|
|
def build_result(errors: list[str], warnings: list[str], traces_count: int, bundles_count: int) -> dict:
    """Assemble the machine-readable validation summary ("ok" iff no errors)."""
    return {
        "ok": not errors,
        "errors": errors,
        "warnings": warnings,
        "validated_traces": traces_count,
        "validated_bundles": bundles_count,
    }
|
|
|
|
def validate_dataset(a_traces: list[dict], b_traces: list[dict], bundles: list[dict], valid_rule_ids: set[str]) -> dict:
    """Run structural checks over standalone traces and bundles; return a summary."""
    errors: list[str] = []
    warnings: list[str] = []
    all_traces = a_traces + b_traces

    for trace in all_traces:
        origin = trace["trace_id"]
        validate_trace(trace, valid_rule_ids, errors, origin)
        validate_deadline_missed_temporal(trace, warnings, origin)

    # Canonical (A) traces carry an additional step-count constraint.
    for trace in a_traces:
        validate_canonical_trace_length(trace, errors, trace["trace_id"])

    # NOTE(review): bundle candidates get validate_trace but not the
    # deadline_missed temporal check, unlike validate_single_bundle —
    # confirm whether that asymmetry is intentional.
    for bundle in bundles:
        for candidate in bundle["candidates"]:
            validate_trace(candidate, valid_rule_ids, errors, candidate["trace_id"])
        validate_bundle(bundle, errors, warnings)

    return build_result(errors, warnings, len(all_traces), len(bundles))
|
|
|
|
def validate_single_bundle(bundle: dict, valid_rule_ids: set[str]) -> dict:
    """Validate one saved bundle and all of its candidates; return a summary."""
    errors: list[str] = []
    warnings: list[str] = []

    for candidate in bundle["candidates"]:
        origin = candidate["trace_id"]
        validate_trace(candidate, valid_rule_ids, errors, origin)
        validate_deadline_missed_temporal(candidate, warnings, origin)
        # Canonical traces are identified by the "_A" trace-id suffix here.
        if origin.endswith("_A"):
            validate_canonical_trace_length(candidate, errors, origin)

    validate_bundle(bundle, errors, warnings)
    return build_result(errors, warnings, len(bundle["candidates"]), 1)
|
|
|
|
def load_annotation_bundles(annotations_dir: Path) -> list[dict]:
    """Load every ``*.json`` file below *annotations_dir*, in sorted path order."""
    return [load_json(path) for path in sorted(annotations_dir.rglob("*.json"))]
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the CLI options that select which artifacts to validate."""
    parser = argparse.ArgumentParser(
        description="Validate CPRM pilot traces, bundles, or saved annotations."
    )
    options = (
        (
            "--bundle-json",
            "Validate a single saved bundle JSON file instead of the default dataset.",
        ),
        (
            "--annotations-dir",
            "Validate all JSON bundle files under an annotations directory instead of the default dataset.",
        ),
    )
    for flag, help_text in options:
        parser.add_argument(flag, type=Path, help=help_text)
    return parser.parse_args()
|
|
|
|
def main() -> None:
    """Entry point: pick a validation mode from CLI flags and print a JSON report.

    Exits with status 1 when any errors were recorded.
    """
    args = parse_args()
    rule_ids = load_rule_ids(RULE_CARDS_PATH)

    if args.bundle_json:
        result = validate_single_bundle(load_json(args.bundle_json), rule_ids)
    elif args.annotations_dir:
        # Annotation bundles are validated without standalone A/B traces.
        annotation_bundles = load_annotation_bundles(args.annotations_dir)
        result = validate_dataset([], [], annotation_bundles, rule_ids)
    else:
        a_traces = load_jsonl(A_TRACES_PATH)
        b_traces = load_jsonl(B_TRACES_PATH)
        bundles = load_jsonl(BUNDLES_PATH)
        result = validate_dataset(a_traces, b_traces, bundles, rule_ids)

    print(json.dumps(result, indent=2))
    if not result["ok"]:
        raise SystemExit(1)
|
|
|
|
# Allow running this module directly as a validation script.
if __name__ == "__main__":
    main()
|
|