#!/usr/bin/env python3
"""Validate CPRM pilot traces, candidate bundles, and saved annotations.

Checks structural invariants over trace/bundle records (known action types
and rule ids, consistent violation bookkeeping, soft-coverage budgets,
bundle label coverage and ordering), prints a JSON report to stdout, and
exits with status 1 if any hard error was found.
"""
from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path
from typing import Iterable

# Repository root: this script is assumed to live one directory below it.
ROOT = Path(__file__).resolve().parents[1]
RULE_CARDS_PATH = ROOT / "data" / "rules" / "rule_cards_starter.csv"
A_TRACES_PATH = ROOT / "data" / "canonical" / "A_canonical_traces_v1.jsonl"
B_TRACES_PATH = ROOT / "data" / "counterfactuals" / "B_counterfactual_traces_v1.jsonl"
BUNDLES_PATH = ROOT / "data" / "bundles" / "pilot_bundles_v1.jsonl"

# Closed vocabulary of per-step action types a trace may contain.
ALLOWED_ACTION_TYPES = {
    "ask_profile",
    "confirm_preference",
    "assess",
    "recommend",
    "redirect",
    "disclose_risk",
    "disclose_conflict",
    "marketing_claim",
}

# Every bundle must provide at least one candidate per label.
REQUIRED_BUNDLE_LABELS = {"compliant", "deadline_missed", "hard_violation"}

# Action types on which a deadline_missed trace's first violation is
# expected to occur.
DEADLINE_VIOLATION_ACTION_TYPES = {"recommend", "marketing_claim"}

# Action types that would have remedied the missed obligation; seeing one
# *before* the violation step is flagged as suspicious for a
# deadline_missed trace (heuristic — warning only, never an error).
REMEDIAL_ACTION_TYPES = {
    "ask_profile",
    "confirm_preference",
    "disclose_risk",
    "disclose_conflict",
}


def load_jsonl(path: Path) -> list[dict]:
    """Load a JSON-Lines file, skipping blank lines."""
    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def load_json(path: Path) -> dict:
    """Load a single JSON document."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


def load_rule_ids(path: Path) -> set[str]:
    """Read the rule-cards CSV and return the set of known rule ids."""
    with path.open("r", encoding="utf-8", newline="") as handle:
        return {row["rule_id"] for row in csv.DictReader(handle)}


def earliest_hard_violation(trace: dict) -> int | None:
    """Return the step_id of the first step flagged as a hard violation.

    Returns None when no step in the trace has a hard violation. Steps are
    scanned in stored order, which is assumed to be chronological.
    """
    for step in trace["steps"]:
        if step["hard_violation"]:
            return step["step_id"]
    return None


def has_hard_violation(trace: dict) -> bool:
    """True if any step in *trace* is flagged as a hard violation."""
    return earliest_hard_violation(trace) is not None


def validate_trace(trace: dict, valid_rule_ids: set[str], errors: list[str], origin: str) -> None:
    """Append structural errors for one trace to *errors*.

    Checks performed:
    - first_violation_step matches the earliest step flagged hard_violation;
    - overall_compliant is True exactly when there is no hard violation AND
      the trace label is "compliant";
    - every action_type, active_rule_id, violated_rule_id, and
      soft_coverage_delta key is known;
    - per-rule cumulative soft coverage never exceeds 1.0 (with a small
      float tolerance).

    *origin* is an identifier (e.g. trace_id) used to prefix messages.
    """
    expected_first = earliest_hard_violation(trace)
    if trace["first_violation_step"] != expected_first:
        errors.append(
            f"{origin}: first_violation_step={trace['first_violation_step']} but earliest hard violation is {expected_first}"
        )

    expected_compliant = expected_first is None and trace["label"] == "compliant"
    if trace["overall_compliant"] != expected_compliant:
        errors.append(
            f"{origin}: overall_compliant={trace['overall_compliant']} but expected {expected_compliant}"
        )

    # Running per-rule soft-coverage totals across the whole trace.
    cumulative: dict[str, float] = {}
    for step in trace["steps"]:
        if step["action_type"] not in ALLOWED_ACTION_TYPES:
            errors.append(f"{origin}: unsupported action_type {step['action_type']}")
        for rule_id in step["active_rule_ids"]:
            if rule_id not in valid_rule_ids:
                errors.append(f"{origin}: unknown active_rule_id {rule_id}")
        violated_rule_id = step["violated_rule_id"]
        if violated_rule_id is not None and violated_rule_id not in valid_rule_ids:
            errors.append(f"{origin}: unknown violated_rule_id {violated_rule_id}")
        for rule_id, delta in step["soft_coverage_delta"].items():
            if rule_id not in valid_rule_ids:
                errors.append(f"{origin}: unknown soft_coverage_delta rule_id {rule_id}")
            cumulative[rule_id] = cumulative.get(rule_id, 0.0) + float(delta)
            # Tolerance absorbs float rounding in deltas that sum to 1.0.
            if cumulative[rule_id] > 1.000001:
                errors.append(
                    f"{origin}: cumulative soft coverage for {rule_id} exceeds 1.0 ({cumulative[rule_id]:.3f})"
                )


def validate_canonical_trace_length(trace: dict, errors: list[str], origin: str) -> None:
    """Canonical (A-series) traces must have between 3 and 5 steps."""
    step_count = len(trace["steps"])
    if not 3 <= step_count <= 5:
        errors.append(f"{origin}: canonical trace has {step_count} steps; expected 3 to 5")


def validate_deadline_missed_temporal(trace: dict, warnings: list[str], origin: str) -> None:
    """Append temporal-plausibility warnings for a deadline_missed trace.

    Heuristics (warnings only, never errors):
    - the first violation should exist and occur early (step 1 or 2);
    - it should fall on a recommend/marketing_claim action;
    - no remedial action should appear before the violation step.
    Traces with other labels are ignored.
    """
    if trace["label"] != "deadline_missed":
        return

    first_violation_step = trace["first_violation_step"]
    if first_violation_step is None:
        warnings.append(f"{origin}: deadline_missed trace has null first_violation_step")
        return
    if first_violation_step > 2:
        warnings.append(
            f"{origin}: deadline_missed first_violation_step={first_violation_step}; expected early violation at step 1 or 2"
        )

    step_lookup = {step["step_id"]: step for step in trace["steps"]}
    violating_step = step_lookup.get(first_violation_step)
    if violating_step is None:
        warnings.append(
            f"{origin}: first_violation_step={first_violation_step} does not exist in steps"
        )
        return

    violating_action = violating_step["action_type"]
    if violating_action not in DEADLINE_VIOLATION_ACTION_TYPES:
        warnings.append(
            f"{origin}: deadline_missed first violation occurs on action_type={violating_action}; expected recommend or marketing_claim"
        )

    # Only one warning is emitted even if several remedial steps precede
    # the violation (break after the first hit).
    for step in trace["steps"]:
        if step["step_id"] >= first_violation_step:
            continue
        if step["action_type"] in REMEDIAL_ACTION_TYPES:
            warnings.append(
                f"{origin}: remedial action_type={step['action_type']} appears before first_violation_step={first_violation_step}"
            )
            break


def validate_bundle(bundle: dict, errors: list[str], warnings: list[str]) -> None:
    """Validate one candidate bundle.

    Errors: a required label (compliant / deadline_missed / hard_violation)
    is missing. Warnings: a candidate's violation bookkeeping contradicts
    its label, or candidates are not ordered by increasing severity
    (compliant < deadline_missed < hard_violation, by first occurrence).
    """
    bundle_id = bundle["bundle_id"]
    candidates = bundle["candidates"]

    label_set = {candidate["label"] for candidate in candidates}
    missing_labels = sorted(REQUIRED_BUNDLE_LABELS - label_set)
    if missing_labels:
        errors.append(f"{bundle_id}: missing required labels {missing_labels}")

    for candidate in candidates:
        label = candidate["label"]
        candidate_origin = f"{bundle_id}/{candidate['trace_id']}"
        candidate_has_hard_violation = has_hard_violation(candidate)
        first_violation_step = candidate["first_violation_step"]
        if label == "compliant" and (
            candidate_has_hard_violation or first_violation_step is not None
        ):
            warnings.append(
                f"{candidate_origin}: compliant candidate has hard_violation={candidate_has_hard_violation} and first_violation_step={first_violation_step}"
            )
        elif label == "hard_violation" and (
            not candidate_has_hard_violation or first_violation_step is None
        ):
            warnings.append(
                f"{candidate_origin}: hard_violation candidate has hard_violation={candidate_has_hard_violation} and first_violation_step={first_violation_step}"
            )
        elif label == "deadline_missed" and not candidate_has_hard_violation:
            # A missed deadline is itself recorded as a hard violation, so
            # its absence is suspicious.
            warnings.append(f"{candidate_origin}: deadline_missed candidate has no hard violation")

    # Record the first index at which each label appears, then check the
    # severity ordering of those first occurrences.
    first_positions: dict[str, int] = {}
    for index, candidate in enumerate(candidates):
        first_positions.setdefault(candidate["label"], index)
    if all(label in first_positions for label in REQUIRED_BUNDLE_LABELS):
        ordered_positions = [
            first_positions["compliant"],
            first_positions["deadline_missed"],
            first_positions["hard_violation"],
        ]
        if ordered_positions != sorted(ordered_positions):
            warnings.append(
                f"{bundle_id}: candidate order violates severity ordering compliant < deadline_missed < hard_violation; positions={ordered_positions}"
            )


def build_result(errors: list[str], warnings: list[str], traces_count: int, bundles_count: int) -> dict:
    """Assemble the JSON-serializable validation report."""
    return {
        "ok": len(errors) == 0,
        "errors": errors,
        "warnings": warnings,
        "validated_traces": traces_count,
        "validated_bundles": bundles_count,
    }


def validate_dataset(a_traces: list[dict], b_traces: list[dict], bundles: list[dict], valid_rule_ids: set[str]) -> dict:
    """Validate standalone traces plus bundles and return a report dict.

    A-series (canonical) traces additionally get the 3-5 step length check.
    Bundle candidates are structurally validated but — unlike in
    validate_single_bundle — do not get the temporal or length checks here.
    """
    traces = a_traces + b_traces
    errors: list[str] = []
    warnings: list[str] = []
    for trace in traces:
        validate_trace(trace, valid_rule_ids, errors, trace["trace_id"])
        validate_deadline_missed_temporal(trace, warnings, trace["trace_id"])
    for trace in a_traces:
        validate_canonical_trace_length(trace, errors, trace["trace_id"])
    for bundle in bundles:
        for candidate in bundle["candidates"]:
            validate_trace(candidate, valid_rule_ids, errors, candidate["trace_id"])
        validate_bundle(bundle, errors, warnings)
    return build_result(errors, warnings, len(traces), len(bundles))


def validate_single_bundle(bundle: dict, valid_rule_ids: set[str]) -> dict:
    """Fully validate one bundle (structure, temporal heuristics, lengths).

    Candidates whose trace_id ends in "_A" are treated as canonical and get
    the step-length check.
    """
    errors: list[str] = []
    warnings: list[str] = []
    for candidate in bundle["candidates"]:
        validate_trace(candidate, valid_rule_ids, errors, candidate["trace_id"])
        validate_deadline_missed_temporal(candidate, warnings, candidate["trace_id"])
        if candidate["trace_id"].endswith("_A"):
            validate_canonical_trace_length(candidate, errors, candidate["trace_id"])
    validate_bundle(bundle, errors, warnings)
    return build_result(errors, warnings, len(bundle["candidates"]), 1)


def load_annotation_bundles(annotations_dir: Path) -> list[dict]:
    """Load every *.json bundle under *annotations_dir*, recursively, in
    sorted path order (deterministic report output)."""
    bundles: list[dict] = []
    for path in sorted(annotations_dir.rglob("*.json")):
        bundles.append(load_json(path))
    return bundles


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser(description="Validate CPRM pilot traces, bundles, or saved annotations.")
    parser.add_argument(
        "--bundle-json",
        type=Path,
        help="Validate a single saved bundle JSON file instead of the default dataset.",
    )
    parser.add_argument(
        "--annotations-dir",
        type=Path,
        help="Validate all JSON bundle files under an annotations directory instead of the default dataset.",
    )
    return parser.parse_args()


def main() -> None:
    """Entry point: pick a validation mode, print the JSON report, and
    exit 1 when any error was recorded.

    NOTE: when both flags are given, --bundle-json wins and
    --annotations-dir is silently ignored.
    """
    args = parse_args()
    valid_rule_ids = load_rule_ids(RULE_CARDS_PATH)
    if args.bundle_json:
        result = validate_single_bundle(load_json(args.bundle_json), valid_rule_ids)
    elif args.annotations_dir:
        bundles = load_annotation_bundles(args.annotations_dir)
        result = validate_dataset([], [], bundles, valid_rule_ids)
    else:
        result = validate_dataset(
            load_jsonl(A_TRACES_PATH),
            load_jsonl(B_TRACES_PATH),
            load_jsonl(BUNDLES_PATH),
            valid_rule_ids,
        )
    print(json.dumps(result, indent=2))
    if not result["ok"]:
        raise SystemExit(1)


if __name__ == "__main__":
    main()