| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from collections import Counter, defaultdict |
| from pathlib import Path |
| from time import perf_counter |
| from typing import Any |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
|
|
| from jawbreaker.analyzers import ( |
| analysis_to_prediction, |
| build_llama_cpp_analyzer, |
| build_transformers_analyzer, |
| has_unsafe_action, |
| heuristic_analyzer, |
| load_prediction_jsonl, |
| prediction_file_analyzer, |
| prediction_to_analysis, |
| repair_prediction, |
| should_apply_heuristic_guard, |
| validate_prediction, |
| write_predictions, |
| ) |
| from jawbreaker.schema import ScamAnalysis |
| from jawbreaker.schema import RISK_LEVELS |
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description="Run Jawbreaker scam-risk evals.") |
| parser.add_argument("--dataset", type=Path, default=Path(__file__).with_name("scam_eval.jsonl")) |
| parser.add_argument( |
| "--backend", |
| choices=["heuristic", "predictions", "llama-cpp", "transformers"], |
| default="heuristic", |
| ) |
| parser.add_argument("--predictions", type=Path, help="JSONL predictions for --backend predictions.") |
| parser.add_argument("--predictions-out", type=Path, help="Write predictions as JSONL.") |
| parser.add_argument("--json-out", type=Path, help="Write metrics as JSON.") |
| parser.add_argument("--limit", type=int, help="Limit number of eval cases for smoke tests.") |
| parser.add_argument("--show-failures", type=int, default=5, help="Failures to print per category.") |
| parser.add_argument( |
| "--apply-safety-guard", |
| action="store_true", |
| help="Apply the same deterministic undercall guard used by the app before scoring.", |
| ) |
|
|
| parser.add_argument("--model-path", type=Path, help="GGUF path for --backend llama-cpp.") |
| parser.add_argument("--chat-format", help="Optional llama-cpp-python chat_format.") |
| parser.add_argument("--n-ctx", type=int, default=4096) |
| parser.add_argument("--n-threads", type=int) |
| parser.add_argument("--n-gpu-layers", type=int, default=0) |
| parser.add_argument("--n-batch", type=int, default=512) |
| parser.add_argument("--n-ubatch", type=int, default=512) |
| parser.add_argument("--offload-kqv", action=argparse.BooleanOptionalAction, default=True) |
| parser.add_argument("--op-offload", action=argparse.BooleanOptionalAction) |
| parser.add_argument("--max-tokens", type=int, default=512) |
| parser.add_argument("--temperature", type=float, default=0.0) |
| parser.add_argument("--model-id", default="openbmb/MiniCPM4.1-8B", help="HF model id for --backend transformers.") |
| parser.add_argument("--adapter-id", help="Optional PEFT adapter id for --backend transformers.") |
| parser.add_argument("--device-map", default="auto", help="Transformers device_map.") |
| parser.add_argument("--dtype", default="auto", help="Transformers dtype.") |
| parser.add_argument("--trust-remote-code", action=argparse.BooleanOptionalAction, default=True) |
| parser.add_argument("--attn-implementation", default="eager", help="Transformers attention implementation.") |
| return parser.parse_args() |
|
|
|
|
| def load_rows(path: Path, limit: int | None = None) -> list[dict[str, Any]]: |
| rows = [] |
| ids = set() |
| errors = [] |
|
|
| for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1): |
| if not line.strip(): |
| continue |
| try: |
| row = json.loads(line) |
| except json.JSONDecodeError as exc: |
| errors.append(f"line {line_number}: invalid JSON: {exc}") |
| continue |
|
|
| missing = {"id", "category", "input", "expected_risk_level", "expected_scam_type", "expected_tactics"} - set(row) |
| if missing: |
| errors.append(f"line {line_number}: missing fields: {sorted(missing)}") |
|
|
| if row.get("id") in ids: |
| errors.append(f"line {line_number}: duplicate id: {row.get('id')}") |
| ids.add(row.get("id")) |
|
|
| if row.get("expected_risk_level") not in RISK_LEVELS: |
| errors.append(f"line {line_number}: invalid expected_risk_level: {row.get('expected_risk_level')}") |
|
|
| if not isinstance(row.get("expected_tactics"), list): |
| errors.append(f"line {line_number}: expected_tactics must be a list") |
|
|
| rows.append(row) |
| if limit is not None and len(rows) >= limit: |
| break |
|
|
| if errors: |
| raise SystemExit("Eval dataset validation failed:\n" + "\n".join(errors)) |
| return rows |
|
|
|
|
| def tactic_recall(expected: list[str], actual: list[str]) -> float: |
| if not expected: |
| return 1.0 |
| expected_set = set(expected) |
| actual_set = set(actual) |
| return len(expected_set & actual_set) / len(expected_set) |
|
|
|
|
| def build_analyzer(args: argparse.Namespace): |
| if args.backend == "heuristic": |
| return lambda row: heuristic_analyzer(row["input"]) |
|
|
| if args.backend == "predictions": |
| if not args.predictions: |
| raise SystemExit("--predictions is required with --backend predictions") |
| predictions = load_prediction_jsonl(args.predictions) |
| return lambda row: prediction_file_analyzer(predictions, row["id"]) |
|
|
| if args.backend == "llama-cpp": |
| if not args.model_path: |
| raise SystemExit("--model-path is required with --backend llama-cpp") |
| analyzer = build_llama_cpp_analyzer( |
| args.model_path, |
| chat_format=args.chat_format, |
| n_ctx=args.n_ctx, |
| n_threads=args.n_threads, |
| n_gpu_layers=args.n_gpu_layers, |
| n_batch=args.n_batch, |
| n_ubatch=args.n_ubatch, |
| offload_kqv=args.offload_kqv, |
| op_offload=args.op_offload, |
| max_tokens=args.max_tokens, |
| temperature=args.temperature, |
| ) |
| return lambda row: analyzer(row["input"]) |
|
|
| if args.backend == "transformers": |
| analyzer = build_transformers_analyzer( |
| args.model_id, |
| adapter_id=args.adapter_id, |
| max_new_tokens=args.max_tokens, |
| temperature=args.temperature, |
| device_map=args.device_map, |
| dtype=args.dtype, |
| trust_remote_code=args.trust_remote_code, |
| attn_implementation=args.attn_implementation, |
| ) |
| return lambda row: analyzer(row["input"]) |
|
|
| raise SystemExit(f"Unsupported backend: {args.backend}") |
|
|
|
|
| def apply_safety_guard(message: str, prediction: dict[str, Any]) -> tuple[dict[str, Any], bool]: |
| model_analysis = prediction_to_analysis(prediction) |
| heuristic = ScamAnalysis.from_heuristics(message) |
| if should_apply_heuristic_guard(message, model_analysis, heuristic, validation_errors=[]): |
| return analysis_to_prediction(heuristic), True |
|
|
| return prediction, False |
|
|
|
|
| def score_rows(rows: list[dict[str, Any]], predictions: dict[str, dict[str, Any]], elapsed: float) -> dict[str, Any]: |
| risk_correct = 0 |
| scam_type_correct = 0 |
| dangerous_as_safe = [] |
| dangerous_as_needs_check = [] |
| suspicious_as_safe = [] |
| safe_as_dangerous = [] |
| unsafe_actions = [] |
| invalid_predictions = [] |
| model_errors = [] |
| safety_guard_promotions = [] |
| recalls = [] |
| risk_confusion: Counter[tuple[str, str]] = Counter() |
| category_counts: Counter[str] = Counter() |
| category_correct: Counter[str] = Counter() |
| failures_by_category: dict[str, list[str]] = defaultdict(list) |
|
|
| for row in rows: |
| case_id = row["id"] |
| prediction = predictions[case_id] |
| if "_jawbreaker_model_error" in prediction: |
| model_errors.append({"id": case_id, "error": str(prediction["_jawbreaker_model_error"])}) |
| if "_jawbreaker_safety_guard" in prediction: |
| safety_guard_promotions.append({"id": case_id, "from": str(prediction["_jawbreaker_safety_guard"])}) |
| validation_errors = validate_prediction(prediction) |
| if validation_errors: |
| invalid_predictions.append({"id": case_id, "errors": validation_errors}) |
|
|
| actual_risk = prediction.get("risk_level", "invalid") |
| actual_scam_type = prediction.get("scam_type", "invalid") |
| actual_tactics = prediction.get("tactics", []) |
| if not isinstance(actual_tactics, list): |
| actual_tactics = [] |
|
|
| expected_risk = row["expected_risk_level"] |
| category = row["category"] |
| risk_ok = actual_risk == expected_risk |
| type_ok = actual_scam_type == row["expected_scam_type"] |
| recall = tactic_recall(row["expected_tactics"], [str(tactic) for tactic in actual_tactics]) |
|
|
| risk_correct += int(risk_ok) |
| scam_type_correct += int(type_ok) |
| recalls.append(recall) |
| risk_confusion[(expected_risk, str(actual_risk))] += 1 |
| category_counts[category] += 1 |
| category_correct[category] += int(risk_ok) |
|
|
| if expected_risk == "dangerous" and actual_risk == "safe": |
| dangerous_as_safe.append(case_id) |
| if expected_risk == "dangerous" and actual_risk == "needs_check": |
| dangerous_as_needs_check.append(case_id) |
| if expected_risk == "suspicious" and actual_risk == "safe": |
| suspicious_as_safe.append(case_id) |
| if expected_risk == "safe" and actual_risk in {"dangerous", "suspicious"}: |
| safe_as_dangerous.append(case_id) |
| if has_unsafe_action(str(prediction.get("safest_action", ""))): |
| unsafe_actions.append(case_id) |
| if not risk_ok: |
| failures_by_category[category].append(f"{case_id} expected={expected_risk} actual={actual_risk}") |
|
|
| total = len(rows) |
| return { |
| "cases": total, |
| "risk_level_correct": risk_correct, |
| "risk_level_accuracy": risk_correct / total, |
| "scam_type_correct": scam_type_correct, |
| "scam_type_accuracy": scam_type_correct / total, |
| "mean_tactic_recall": sum(recalls) / len(recalls), |
| "dangerous_as_safe": dangerous_as_safe, |
| "dangerous_as_needs_check": dangerous_as_needs_check, |
| "suspicious_as_safe": suspicious_as_safe, |
| "safe_as_dangerous_or_suspicious": safe_as_dangerous, |
| "unsafe_action_violations": unsafe_actions, |
| "invalid_predictions": invalid_predictions, |
| "model_errors": model_errors, |
| "safety_guard_promotions": safety_guard_promotions, |
| "elapsed_seconds": elapsed, |
| "risk_confusion": {f"{expected}->{actual}": count for (expected, actual), count in sorted(risk_confusion.items())}, |
| "category_risk_accuracy": { |
| category: { |
| "correct": category_correct[category], |
| "total": count, |
| "accuracy": category_correct[category] / count, |
| } |
| for category, count in sorted(category_counts.items()) |
| }, |
| "failures_by_category": {category: failures for category, failures in sorted(failures_by_category.items())}, |
| } |
|
|
|
|
| def print_report(metrics: dict[str, Any], show_failures: int) -> None: |
| total = metrics["cases"] |
| print(f"cases={total}") |
| print( |
| "risk_level_accuracy=" |
| f"{metrics['risk_level_correct']}/{total} ({metrics['risk_level_accuracy']:.1%})" |
| ) |
| print( |
| "scam_type_accuracy=" |
| f"{metrics['scam_type_correct']}/{total} ({metrics['scam_type_accuracy']:.1%})" |
| ) |
| print(f"mean_tactic_recall={metrics['mean_tactic_recall']:.1%}") |
| print(f"dangerous_as_safe={len(metrics['dangerous_as_safe'])} {metrics['dangerous_as_safe']}") |
| print( |
| "dangerous_as_needs_check=" |
| f"{len(metrics['dangerous_as_needs_check'])} {metrics['dangerous_as_needs_check']}" |
| ) |
| print(f"suspicious_as_safe={len(metrics['suspicious_as_safe'])} {metrics['suspicious_as_safe']}") |
| print( |
| "safe_as_dangerous_or_suspicious=" |
| f"{len(metrics['safe_as_dangerous_or_suspicious'])} {metrics['safe_as_dangerous_or_suspicious']}" |
| ) |
| print(f"unsafe_action_violations={len(metrics['unsafe_action_violations'])} {metrics['unsafe_action_violations']}") |
| print(f"invalid_predictions={len(metrics['invalid_predictions'])} {metrics['invalid_predictions'][:show_failures]}") |
| print(f"model_errors={len(metrics['model_errors'])} {metrics['model_errors'][:show_failures]}") |
| print( |
| "safety_guard_promotions=" |
| f"{len(metrics['safety_guard_promotions'])} {metrics['safety_guard_promotions'][:show_failures]}" |
| ) |
| print(f"elapsed_seconds={metrics['elapsed_seconds']:.3f}") |
|
|
| print("\nrisk_confusion expected->actual:") |
| for pair, count in metrics["risk_confusion"].items(): |
| expected, actual = pair.split("->", 1) |
| print(f" {expected:12s} -> {actual:12s} {count}") |
|
|
| print("\ncategory_risk_accuracy:") |
| for category, result in metrics["category_risk_accuracy"].items(): |
| print( |
| f" {category:24s} {result['correct']:2d}/{result['total']:2d} " |
| f"({result['accuracy']:.1%})" |
| ) |
|
|
| if metrics["failures_by_category"]: |
| print("\nfirst_failures_by_category:") |
| for category, failures in metrics["failures_by_category"].items(): |
| print(f" {category}:") |
| for failure in failures[:show_failures]: |
| print(f" {failure}") |
|
|
|
|
| def main() -> None: |
| args = parse_args() |
| rows = load_rows(args.dataset, args.limit) |
| analyzer = build_analyzer(args) |
| predictions = {} |
|
|
| started = perf_counter() |
| for index, row in enumerate(rows, start=1): |
| print(f"eval case {index}/{len(rows)} id={row['id']}", flush=True) |
| try: |
| prediction = repair_prediction(analyzer(row)) |
| except Exception as exc: |
| prediction = heuristic_analyzer(row["input"]) |
| prediction["_jawbreaker_model_error"] = repr(exc) |
| prediction = repair_prediction(prediction) |
| if args.apply_safety_guard: |
| guarded_prediction, promoted = apply_safety_guard(row["input"], prediction) |
| if promoted: |
| guarded_prediction["_jawbreaker_safety_guard"] = prediction.get("risk_level", "unknown") |
| prediction = repair_prediction(guarded_prediction) |
| predictions[row["id"]] = prediction |
| elapsed = perf_counter() - started |
|
|
| metrics = score_rows(rows, predictions, elapsed) |
| print_report(metrics, args.show_failures) |
|
|
| if args.predictions_out: |
| write_predictions(args.predictions_out, rows, predictions) |
| if args.json_out: |
| args.json_out.write_text(json.dumps(metrics, indent=2, sort_keys=True) + "\n", encoding="utf-8") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|