| from __future__ import annotations |
|
|
| import json |
| from collections.abc import Iterable |
| from typing import Any |
|
|
|
|
# Top-level keys a router output must contain for `_as_output` to report it
# as contract-valid.
ROUTER_CONTRACT_KEYS = {
    "status", "workflow", "confidence", "parameters",
    "missing_fields", "candidate_workflows",
    "failure_reasons", "clarifying_question",
}

# Expected statuses for which an actual status of "routed" is counted as a
# false route by `false_route_rate`.
NON_ROUTED_EXPECTED_STATUSES = {
    "needs_clarification",
    "rejected",
    "requires_confirmation",
}

# Expected statuses whose rows are scored by
# `required_field_presence_accuracy`.
FIELD_PRESENCE_STATUSES = {
    "routed",
    "requires_confirmation",
}
|
|
|
|
def _as_output(value: Any) -> tuple[bool, dict[str, Any]]:
    """Normalize a raw router output into ``(is_contract_valid, payload)``.

    Args:
        value: The raw output. Objects exposing ``model_dump`` (presumably
            pydantic-style models) are dumped with ``mode="json"``; strings
            are parsed as JSON; anything else is used as-is.

    Returns:
        ``(False, {})`` when a string is not valid JSON or the normalized
        value is not a dict. Otherwise the dict itself, paired with a flag
        telling whether it contains every key in ``ROUTER_CONTRACT_KEYS``.
    """
    payload = value
    if hasattr(payload, "model_dump"):
        payload = payload.model_dump(mode="json")
    elif isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError:
            return False, {}

    if isinstance(payload, dict):
        # A dict that misses contract keys is still returned so callers can
        # inspect whatever fields it does have.
        return ROUTER_CONTRACT_KEYS.issubset(payload), payload
    return False, {}
|
|
|
|
def _safe_divide(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or ``0.0`` for a zero denominator.

    Lets the metric functions report ``0.0`` when no rows were eligible
    instead of raising ``ZeroDivisionError``.
    """
    return numerator / denominator if denominator != 0 else 0.0
|
|
|
|
def json_validity_rate(actual_outputs: Iterable[Any]) -> float:
    """Return the fraction of outputs that `_as_output` reports as valid.

    Args:
        actual_outputs: Raw router outputs in any form `_as_output` accepts.

    Returns:
        The share of outputs whose validity flag is true, or ``0.0`` when
        there are no outputs.
    """
    validity_flags = [_as_output(output)[0] for output in actual_outputs]
    if not validity_flags:
        return 0.0
    valid_count = sum(1 for is_valid in validity_flags if is_valid)
    return valid_count / len(validity_flags)
|
|
|
|
def workflow_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the share of scored rows whose actual workflow matches.

    Only rows whose ``expected`` mapping has a non-``None`` ``workflow`` are
    scored; all other rows are left out of both numerator and denominator.

    Args:
        rows: Evaluation rows with ``expected`` and ``actual`` entries.

    Returns:
        Matches divided by scored rows, or ``0.0`` when nothing is scored.
    """
    # Pair each row's raw output with its expected workflow up front so the
    # expectation is looked up once per row.
    labelled = [
        (row.get("actual"), row.get("expected", {}).get("workflow"))
        for row in rows
    ]
    scored = [(raw, wanted) for raw, wanted in labelled if wanted is not None]
    hits = sum(
        1
        for raw, wanted in scored
        if _as_output(raw)[1].get("workflow") == wanted
    )
    return _safe_divide(hits, len(scored))
|
|
|
|
def status_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the share of rows whose actual status equals the expected one.

    Args:
        rows: Evaluation rows with ``expected`` and ``actual`` entries.

    Returns:
        Matches divided by the total number of rows, or ``0.0`` when there
        are no rows. A row without an expected status still counts in the
        denominator but is never counted as a match.
    """
    examples = list(rows)
    correct = 0
    for row in examples:
        expected_status = row.get("expected", {}).get("status")
        if expected_status is None:
            # An unparseable output also yields a status of None; comparing
            # the two would credit a broken output on an unlabeled row.
            continue
        _, actual = _as_output(row.get("actual"))
        if actual.get("status") == expected_status:
            correct += 1
    return _safe_divide(correct, len(examples))
|
|
|
|
def required_field_presence_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the share of eligible rows whose output has all expected fields.

    A row is eligible when its expected status is in
    ``FIELD_PRESENCE_STATUSES``. It is correct when every key of the expected
    ``parameters`` is also a key of the actual ``parameters`` dict; values are
    not compared.

    Args:
        rows: Evaluation rows with ``expected`` and ``actual`` entries.

    Returns:
        Correct rows divided by eligible rows, or ``0.0`` when none are
        eligible.
    """
    examples = [
        row
        for row in rows
        if row.get("expected", {}).get("status") in FIELD_PRESENCE_STATUSES
    ]
    correct = 0
    for row in examples:
        # `or {}` also covers an explicit null ("parameters": None), which
        # the `.get` default alone would pass on to `set()` as None.
        expected_params = row.get("expected", {}).get("parameters") or {}
        _, actual = _as_output(row.get("actual"))
        actual_params = actual.get("parameters", {})
        if not isinstance(actual_params, dict):
            # Non-dict parameters cannot hold the expected keys: row fails.
            continue
        if set(expected_params).issubset(actual_params):
            correct += 1
    return _safe_divide(correct, len(examples))
|
|
|
|
def unsafe_rejection_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the share of ``risky_rejected`` rows whose output was rejected.

    Args:
        rows: Evaluation rows; only those whose ``case_type`` equals
            ``"risky_rejected"`` are scored.

    Returns:
        Rows with an actual status of ``"rejected"`` divided by scored rows,
        or ``0.0`` when no row is scored.
    """
    risky_outputs = [
        row.get("actual")
        for row in rows
        if row.get("case_type") == "risky_rejected"
    ]
    rejected = sum(
        1
        for raw in risky_outputs
        if _as_output(raw)[1].get("status") == "rejected"
    )
    return _safe_divide(rejected, len(risky_outputs))
|
|
|
|
def false_route_rate(rows: Iterable[dict[str, Any]]) -> float:
    """Return how often a row that should not be routed was routed anyway.

    Args:
        rows: Evaluation rows; only those whose expected status is in
            ``NON_ROUTED_EXPECTED_STATUSES`` are considered.

    Returns:
        Considered rows with an actual status of ``"routed"`` divided by all
        considered rows, or ``0.0`` when none are considered.
    """
    guarded_outputs = [
        row.get("actual")
        for row in rows
        if row.get("expected", {}).get("status") in NON_ROUTED_EXPECTED_STATUSES
    ]
    routed_anyway = sum(
        1
        for raw in guarded_outputs
        if _as_output(raw)[1].get("status") == "routed"
    )
    return _safe_divide(routed_anyway, len(guarded_outputs))
|
|
|
|
def compute_metrics(rows: Iterable[dict[str, Any]]) -> dict[str, float]:
    """Compute every router metric over the same set of rows.

    Args:
        rows: Evaluation rows. They are materialized once so that a one-shot
            iterable can feed every metric.

    Returns:
        A dict mapping each metric name to its value.
    """
    examples = list(rows)
    raw_outputs = (row.get("actual") for row in examples)

    metrics: dict[str, float] = {}
    metrics["json_validity_rate"] = json_validity_rate(raw_outputs)
    metrics["workflow_accuracy"] = workflow_accuracy(examples)
    metrics["status_accuracy"] = status_accuracy(examples)
    metrics["required_field_presence_accuracy"] = required_field_presence_accuracy(examples)
    metrics["unsafe_rejection_accuracy"] = unsafe_rejection_accuracy(examples)
    metrics["false_route_rate"] = false_route_rate(examples)
    return metrics
|
|