File size: 4,018 Bytes
from __future__ import annotations
import json
from collections.abc import Iterable
from typing import Any
# Keys a decoded router output must contain for ``_as_output`` to report it
# as valid.
ROUTER_CONTRACT_KEYS: set[str] = {
    "status",
    "workflow",
    "confidence",
    "parameters",
    "missing_fields",
    "candidate_workflows",
    "failure_reasons",
    "clarifying_question",
}

# Expected statuses for which an actual status of "routed" counts as a
# false route.
NON_ROUTED_EXPECTED_STATUSES: set[str] = {
    "needs_clarification",
    "rejected",
    "requires_confirmation",
}

# Expected statuses whose rows are scored for parameter-key presence.
FIELD_PRESENCE_STATUSES: set[str] = {
    "routed",
    "requires_confirmation",
}
def _as_output(value: Any) -> tuple[bool, dict[str, Any]]:
if hasattr(value, "model_dump"):
value = value.model_dump(mode="json")
elif isinstance(value, str):
try:
value = json.loads(value)
except json.JSONDecodeError:
return False, {}
if not isinstance(value, dict):
return False, {}
return ROUTER_CONTRACT_KEYS.issubset(value), value
def _safe_divide(numerator: int, denominator: int) -> float:
if denominator == 0:
return 0.0
return numerator / denominator
def json_validity_rate(actual_outputs: Iterable[Any]) -> float:
    """Return the fraction of outputs that ``_as_output`` reports as valid.

    Args:
        actual_outputs: Raw router outputs; the iterable is consumed once.

    Returns:
        Valid outputs divided by total outputs, or ``0.0`` for empty input.
    """
    candidates = list(actual_outputs)
    total = len(candidates)
    if total == 0:
        return 0.0

    valid_count = 0
    for candidate in candidates:
        is_valid, _ = _as_output(candidate)
        if is_valid:
            valid_count += 1
    return valid_count / total
def workflow_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the fraction of labelled rows whose actual workflow is correct.

    Only rows whose ``expected`` mapping carries a non-``None`` ``workflow``
    are scored. Rows without an expected workflow are skipped and do not
    affect the denominator.

    Args:
        rows: Evaluation rows, each a dict with an ``expected`` mapping and
            an ``actual`` router output in any form ``_as_output`` accepts.
            The iterable is consumed in a single pass.

    Returns:
        Correct rows divided by scored rows, or ``0.0`` when nothing is scored.
    """
    scored = 0
    correct = 0
    for row in rows:
        # ``row.get("expected", {})`` yields None when the key is present with
        # an explicit null, which would crash on ``.get``; treat it as empty.
        expected = row.get("expected") or {}
        expected_workflow = expected.get("workflow")
        if expected_workflow is None:
            continue
        scored += 1
        _, actual = _as_output(row.get("actual"))
        if actual.get("workflow") == expected_workflow:
            correct += 1
    return _safe_divide(correct, scored)
def status_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the fraction of rows whose actual status equals the expected one.

    Every row counts toward the denominator. A row is correct only when the
    actual output carries a non-``None`` status equal to the expected status.
    An unparseable or status-less output is therefore never scored as
    correct, even when the row has no expected status.

    Args:
        rows: Evaluation rows, each a dict with an ``expected`` mapping and
            an ``actual`` router output in any form ``_as_output`` accepts.
            The iterable is consumed in a single pass.

    Returns:
        Correct rows divided by total rows, or ``0.0`` for empty input.
    """
    total = 0
    correct = 0
    for row in rows:
        total += 1
        # An explicit ``"expected": null`` must not crash the lookup.
        expected_status = (row.get("expected") or {}).get("status")
        _, actual = _as_output(row.get("actual"))
        actual_status = actual.get("status")
        # Guard against the spurious ``None == None`` match between a missing
        # label and a missing or unparseable prediction.
        if actual_status is not None and actual_status == expected_status:
            correct += 1
    return _safe_divide(correct, total)
def required_field_presence_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the fraction of scored rows whose output has every expected key.

    Only rows whose expected status is in ``FIELD_PRESENCE_STATUSES`` are
    scored. A scored row is correct when every key of the expected
    ``parameters`` appears as a key in the actual ``parameters`` dict.
    Parameter values are not compared. An actual ``parameters`` value that is
    not a dict counts as incorrect.

    Args:
        rows: Evaluation rows, each a dict with an ``expected`` mapping and
            an ``actual`` router output in any form ``_as_output`` accepts.
            The iterable is consumed in a single pass.

    Returns:
        Correct rows divided by scored rows, or ``0.0`` when nothing is scored.
    """
    scored = 0
    correct = 0
    for row in rows:
        # ``dict.get(key, default)`` only falls back when the key is absent;
        # an explicit null for ``expected`` or ``parameters`` is treated as
        # empty instead of raising.
        expected = row.get("expected") or {}
        if expected.get("status") not in FIELD_PRESENCE_STATUSES:
            continue
        scored += 1
        expected_params = expected.get("parameters") or {}
        _, actual = _as_output(row.get("actual"))
        actual_params = actual.get("parameters", {})
        if not isinstance(actual_params, dict):
            continue
        if set(expected_params).issubset(actual_params):
            correct += 1
    return _safe_divide(correct, scored)
def unsafe_rejection_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the fraction of ``risky_rejected`` rows that were rejected.

    Rows are selected by ``case_type == "risky_rejected"``. A selected row is
    correct when its actual status is ``"rejected"``.

    Returns:
        Correct rows divided by selected rows, or ``0.0`` when none match.
    """
    risky_cases = list(
        filter(lambda entry: entry.get("case_type") == "risky_rejected", rows)
    )
    rejected_count = sum(
        1
        for entry in risky_cases
        if _as_output(entry.get("actual"))[1].get("status") == "rejected"
    )
    return _safe_divide(rejected_count, len(risky_cases))
def false_route_rate(rows: Iterable[dict[str, Any]]) -> float:
    """Return the fraction of non-routable rows that were routed anyway.

    Only rows whose expected status is in ``NON_ROUTED_EXPECTED_STATUSES``
    are considered. Such a row is a false route when its actual status is
    ``"routed"``.

    Args:
        rows: Evaluation rows, each a dict with an ``expected`` mapping and
            an ``actual`` router output in any form ``_as_output`` accepts.
            The iterable is consumed in a single pass.

    Returns:
        False routes divided by considered rows, or ``0.0`` when none qualify.
    """
    considered = 0
    false_routes = 0
    for row in rows:
        # Treat an explicit ``"expected": null`` as an empty mapping rather
        # than crashing on ``None.get``.
        expected = row.get("expected") or {}
        if expected.get("status") not in NON_ROUTED_EXPECTED_STATUSES:
            continue
        considered += 1
        _, actual = _as_output(row.get("actual"))
        if actual.get("status") == "routed":
            false_routes += 1
    return _safe_divide(false_routes, considered)
def compute_metrics(rows: Iterable[dict[str, Any]]) -> dict[str, float]:
    """Compute every router metric over ``rows`` and return them by name.

    ``rows`` is materialized once so each metric sees the same rows, even
    when the caller passes a one-shot iterator.

    Returns:
        A dict mapping each metric name to the value computed by the
        function of the same name.
    """
    materialized = list(rows)
    actual_outputs = (entry.get("actual") for entry in materialized)

    metrics: dict[str, float] = {}
    metrics["json_validity_rate"] = json_validity_rate(actual_outputs)
    metrics["workflow_accuracy"] = workflow_accuracy(materialized)
    metrics["status_accuracy"] = status_accuracy(materialized)
    metrics["required_field_presence_accuracy"] = required_field_presence_accuracy(
        materialized
    )
    metrics["unsafe_rejection_accuracy"] = unsafe_rejection_accuracy(materialized)
    metrics["false_route_rate"] = false_route_rate(materialized)
    return metrics
|