File size: 4,018 Bytes
1137e50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from __future__ import annotations

import json
from collections.abc import Iterable
from typing import Any


# Top-level keys a router response must contain for _as_output to report it
# as contract-valid.
ROUTER_CONTRACT_KEYS = {
    "status", "workflow", "confidence", "parameters",
    "missing_fields", "candidate_workflows", "failure_reasons", "clarifying_question",
}

# Expected statuses for which an actual status of "routed" is counted as a
# false route by false_route_rate.
NON_ROUTED_EXPECTED_STATUSES = {
    "needs_clarification",
    "rejected",
    "requires_confirmation",
}

# Expected statuses whose rows are scored by required_field_presence_accuracy.
FIELD_PRESENCE_STATUSES = {
    "routed",
    "requires_confirmation",
}


def _as_output(value: Any) -> tuple[bool, dict[str, Any]]:
    """Normalise a raw router output to a dict and check it against the contract.

    Objects exposing ``model_dump`` are dumped with ``mode="json"``; strings
    are decoded as JSON; any other value is inspected as-is.

    Returns:
        ``(is_valid, payload)``. ``payload`` is the normalised dict (returned
        even when keys are missing), or ``{}`` when the value is not, and does
        not decode to, a dict. ``is_valid`` is True only when ``payload``
        contains every key in ROUTER_CONTRACT_KEYS.
    """
    candidate = value
    if hasattr(candidate, "model_dump"):
        candidate = candidate.model_dump(mode="json")
    elif isinstance(candidate, str):
        try:
            candidate = json.loads(candidate)
        except json.JSONDecodeError:
            # Undecodable text is treated exactly like any other non-dict.
            candidate = None

    if isinstance(candidate, dict):
        return ROUTER_CONTRACT_KEYS.issubset(candidate), candidate
    return False, {}


def _safe_divide(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero."""
    return 0.0 if denominator == 0 else numerator / denominator


def json_validity_rate(actual_outputs: Iterable[Any]) -> float:
    """Return the fraction of outputs that satisfy the router contract.

    An output counts as valid when ``_as_output`` reports it as a dict holding
    every contract key. Returns ``0.0`` for an empty input.
    """
    validity_flags = [_as_output(output)[0] for output in actual_outputs]
    return _safe_divide(sum(validity_flags), len(validity_flags))


def workflow_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the share of workflow-labelled rows whose output names that workflow.

    A row is scored only when ``expected["workflow"]`` is present and not
    None. The parsed payload's ``workflow`` is compared even if the payload
    lacks other contract keys; outputs that fail to parse yield ``{}`` and so
    count as misses. Returns ``0.0`` when no row is scored.
    """
    labelled = [
        candidate
        for candidate in rows
        if candidate.get("expected", {}).get("workflow") is not None
    ]
    matches = sum(
        1
        for candidate in labelled
        if _as_output(candidate.get("actual"))[1].get("workflow")
        == candidate["expected"].get("workflow")
    )
    return _safe_divide(matches, len(labelled))


def status_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the share of status-labelled rows whose output has the expected status.

    Only rows whose ``expected["status"]`` is present and not None are
    scored, mirroring how ``workflow_accuracy`` treats its label. Outputs
    that fail to parse yield ``{}`` and so count as misses.
    Returns ``0.0`` when no row is scored.
    """
    # Unlabelled rows are skipped: otherwise an unparseable output ({}),
    # whose status reads as None, would "match" the missing label and be
    # counted as correct.
    scored = [
        row
        for row in rows
        if row.get("expected", {}).get("status") is not None
    ]
    correct = 0
    for row in scored:
        _, actual = _as_output(row.get("actual"))
        if actual.get("status") == row["expected"].get("status"):
            correct += 1
    return _safe_divide(correct, len(scored))


def required_field_presence_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the share of eligible rows whose output includes every expected parameter.

    Only rows whose expected status is in FIELD_PRESENCE_STATUSES are scored.
    A row is correct when the output's ``parameters`` is a dict containing
    every name in the expected ``parameters``; values are not compared. An
    output whose ``parameters`` is not a dict counts as a miss. A missing or
    null expected ``parameters`` is treated as "no required fields".
    Returns ``0.0`` when no row is eligible.
    """
    examples = [
        row
        for row in rows
        if row.get("expected", {}).get("status") in FIELD_PRESENCE_STATUSES
    ]
    correct = 0
    for row in examples:
        # ``or {}`` tolerates an explicit null label, which would otherwise make
        # ``set(None)`` raise TypeError and abort the whole metric run.
        expected_params = row.get("expected", {}).get("parameters") or {}
        _, actual = _as_output(row.get("actual"))
        actual_params = actual.get("parameters", {})
        if not isinstance(actual_params, dict):
            continue
        if set(expected_params).issubset(actual_params):
            correct += 1
    return _safe_divide(correct, len(examples))


def unsafe_rejection_accuracy(rows: Iterable[dict[str, Any]]) -> float:
    """Return the share of ``risky_rejected`` cases whose output status is "rejected".

    Rows are selected by ``case_type == "risky_rejected"``; outputs that fail
    to parse count as misses. Returns ``0.0`` when there are no such rows.
    """
    risky_cases = [case for case in rows if case.get("case_type") == "risky_rejected"]
    rejected = sum(
        1
        for case in risky_cases
        if _as_output(case.get("actual"))[1].get("status") == "rejected"
    )
    return _safe_divide(rejected, len(risky_cases))


def false_route_rate(rows: Iterable[dict[str, Any]]) -> float:
    """Return the fraction of should-not-route rows that were nevertheless routed.

    Eligible rows are those whose expected status is in
    NON_ROUTED_EXPECTED_STATUSES; a false route is an output whose status is
    "routed". Returns ``0.0`` when no row is eligible.
    """
    eligible = [
        case
        for case in rows
        if case.get("expected", {}).get("status") in NON_ROUTED_EXPECTED_STATUSES
    ]
    routed_anyway = sum(
        1
        for case in eligible
        if _as_output(case.get("actual"))[1].get("status") == "routed"
    )
    return _safe_divide(routed_anyway, len(eligible))


def compute_metrics(rows: Iterable[dict[str, Any]]) -> dict[str, float]:
    """Compute every router evaluation metric over one set of rows.

    ``rows`` is materialised once so each metric sees the full sequence even
    when a one-shot iterator is passed in. Keys appear in a fixed order:
    json_validity_rate first, then the per-row metrics.
    """
    examples = list(rows)
    raw_outputs = [example.get("actual") for example in examples]
    metrics: dict[str, float] = {"json_validity_rate": json_validity_rate(raw_outputs)}
    for metric_name, scorer in (
        ("workflow_accuracy", workflow_accuracy),
        ("status_accuracy", status_accuracy),
        ("required_field_presence_accuracy", required_field_presence_accuracy),
        ("unsafe_rejection_accuracy", unsafe_rejection_accuracy),
        ("false_route_rate", false_route_rate),
    ):
        metrics[metric_name] = scorer(examples)
    return metrics