File size: 6,647 Bytes
0584798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from __future__ import annotations

import os
from collections import Counter

from combined_inference import classify_query
from iab_classifier import predict_iab_content_classifier_batch
from iab_retrieval import predict_iab_content_retrieval_batch
from iab_taxonomy import parse_path_label


def _include_shadow_retrieval_in_iab_views() -> bool:
    """Shadow retrieval loads Alibaba-NLP/gte-Qwen2-1.5B (~7GB) when the taxonomy index exists."""
    value = os.environ.get("IAB_EVAL_INCLUDE_SHADOW_RETRIEVAL", "0").strip().lower()
    return value in ("1", "true", "yes")


def path_from_content(content: dict) -> tuple[str, ...]:
    """Collect the tier labels of an IAB content dict into an ordered path tuple.

    Walks "tier1".."tier4" in order and keeps the label of every tier key
    that is present; absent tiers are simply skipped.
    """
    tiers = ("tier1", "tier2", "tier3", "tier4")
    return tuple(content[tier]["label"] for tier in tiers if tier in content)


def path_from_label(label: str) -> tuple[str, ...]:
    """Parse a serialized IAB path label into its tier tuple via the taxonomy parser."""
    parsed = parse_path_label(label)
    return parsed


def is_parent_safe(true_path: tuple[str, ...], pred_path: tuple[str, ...]) -> bool:
    """Return True when pred_path is a non-empty prefix of true_path.

    A "parent-safe" prediction stops at or above the true node while staying
    on the true branch. Empty predictions and predictions deeper than the
    truth are never parent-safe.
    """
    depth = len(pred_path)
    if depth == 0 or depth > len(true_path):
        return False
    return all(truth == guess for truth, guess in zip(true_path, pred_path))


def error_bucket(true_path: tuple[str, ...], pred_path: tuple[str, ...]) -> str:
    """Categorize one prediction against the truth into a diagnostic bucket.

    Buckets, checked in priority order:
      exact_match            — prediction equals the truth.
      no_prediction          — empty prediction.
      wrong_tier1            — top-level tier differs.
      right_tier1_wrong_tier2 — tier1 agrees but tier2 is missing or wrong
                                 (only when the truth reaches tier2).
      parent_safe_stop       — prediction is a proper prefix of the truth.
      wrong_deep_leaf        — everything else.
    """
    if true_path == pred_path:
        return "exact_match"
    if len(pred_path) == 0:
        return "no_prediction"
    if pred_path[:1] != true_path[:1]:
        return "wrong_tier1"
    tier2_expected = len(true_path) >= 2
    tier2_matched = len(pred_path) >= 2 and pred_path[:2] == true_path[:2]
    if tier2_expected and not tier2_matched:
        return "right_tier1_wrong_tier2"
    # Inline parent-safety check: pred is non-empty here, so it only needs to
    # be a prefix of the truth that is no deeper than the truth.
    if len(pred_path) <= len(true_path) and true_path[: len(pred_path)] == pred_path:
        return "parent_safe_stop"
    return "wrong_deep_leaf"


def compute_path_metrics(true_paths: list[tuple[str, ...]], pred_paths: list[tuple[str, ...]]) -> dict:
    """Aggregate path-prediction quality metrics over paired true/pred paths.

    Returns per-tier prefix accuracies (denominator: rows whose true path
    reaches that tier), exact-path and parent-safe rates, the average
    predicted depth, and a key-sorted error-bucket histogram. All rates are
    rounded to 4 decimal places; empty input yields an all-zero report.
    """
    count = len(true_paths)
    if not count:
        return {
            "tier1_accuracy": 0.0,
            "tier2_accuracy": 0.0,
            "tier3_accuracy": 0.0,
            "tier4_accuracy": 0.0,
            "exact_path_accuracy": 0.0,
            "parent_safe_accuracy": 0.0,
            "average_prediction_depth": 0.0,
            "error_buckets": {},
        }

    tier_hits = dict.fromkeys((1, 2, 3, 4), 0)
    tier_totals = dict.fromkeys((1, 2, 3, 4), 0)
    exact = 0
    safe = 0
    histogram = Counter()
    for truth, guess in zip(true_paths, pred_paths):
        exact += truth == guess
        safe += is_parent_safe(truth, guess)
        histogram[error_bucket(truth, guess)] += 1
        for depth in (1, 2, 3, 4):
            # Only rows whose true path reaches this tier count toward it.
            if len(truth) < depth:
                continue
            tier_totals[depth] += 1
            if len(guess) >= depth and truth[:depth] == guess[:depth]:
                tier_hits[depth] += 1

    depth_sum = sum(len(guess) for guess in pred_paths)
    return {
        "tier1_accuracy": round(tier_hits[1] / max(tier_totals[1], 1), 4),
        "tier2_accuracy": round(tier_hits[2] / max(tier_totals[2], 1), 4),
        "tier3_accuracy": round(tier_hits[3] / max(tier_totals[3], 1), 4),
        "tier4_accuracy": round(tier_hits[4] / max(tier_totals[4], 1), 4),
        "exact_path_accuracy": round(exact / count, 4),
        "parent_safe_accuracy": round(safe / count, 4),
        "average_prediction_depth": round(depth_sum / count, 4),
        "error_buckets": {bucket: histogram[bucket] for bucket in sorted(histogram)},
    }


def evaluate_iab_views(rows: list[dict], max_combined_rows: int = 500) -> dict:
    """Evaluate IAB path predictions across several model "views".

    Args:
        rows: Dicts with at least "text" (model input) and "iab_path"
            (serialized true path, parsed into a tier tuple as ground truth).
        max_combined_rows: Above this row count the combined-pipeline view and
            the disagreement counts are replaced by "skipped" stubs instead of
            being computed (classify_query runs once per row).

    Returns:
        A dict mapping view names ("classifier", "shadow_embedding_retrieval",
        "combined_path", "disagreements") to either a metrics dict from
        compute_path_metrics or a {"skipped": True, ...} stub.

    Raises:
        RuntimeError: If the batch classifier yields no usable outputs, i.e.
            its model artifacts have not been trained/calibrated yet.
    """
    texts = [row["text"] for row in rows]
    true_paths = [path_from_label(row["iab_path"]) for row in rows]
    classifier_outputs = predict_iab_content_classifier_batch(texts)
    # All-None means the classifier artifacts are missing entirely; individual
    # None entries below are tolerated and scored as empty predictions.
    if not any(output is not None for output in classifier_outputs):
        raise RuntimeError(
            "IAB classifier artifacts are unavailable. Run `python3 training/train_iab.py` "
            "and `python3 training/calibrate_confidence.py --head iab_content` "
            "from the `agentic-intent-classifier` directory first."
        )

    # A None output becomes an empty path, which compute_path_metrics buckets
    # as "no_prediction".
    classifier_paths = [path_from_content(output["content"]) if output is not None else tuple() for output in classifier_outputs]
    views = {"classifier": compute_path_metrics(true_paths, classifier_paths)}

    # Shadow retrieval is opt-in (env flag) because it loads a large embedding
    # model; when disabled we still emit a stub entry explaining how to enable it.
    if _include_shadow_retrieval_in_iab_views():
        retrieval_outputs = predict_iab_content_retrieval_batch(texts)
    else:
        retrieval_outputs = [None for _ in texts]
        views["shadow_embedding_retrieval"] = {
            "skipped": True,
            "reason": "disabled_by_default",
            "hint": "Set IAB_EVAL_INCLUDE_SHADOW_RETRIEVAL=1 to run shadow embedding retrieval (downloads/loads gte-Qwen2 when index is present).",
        }

    # Replaces the stub above when retrieval actually produced something.
    if any(output is not None for output in retrieval_outputs):
        retrieval_paths = [path_from_content(output["content"]) if output is not None else tuple() for output in retrieval_outputs]
        views["shadow_embedding_retrieval"] = compute_path_metrics(true_paths, retrieval_paths)

    # The combined view calls classify_query per row, so large datasets get
    # stubs for both the combined metrics and the disagreement counts.
    if len(rows) > max_combined_rows:
        views["combined_path"] = {
            "skipped": True,
            "reason": "dataset_too_large_for_combined_view",
            "count": len(rows),
            "max_combined_rows": max_combined_rows,
        }
        views["disagreements"] = {
            "skipped": True,
            "reason": "dataset_too_large_for_combined_view",
            "count": len(rows),
            "max_combined_rows": max_combined_rows,
        }
        return views

    combined_payloads = [classify_query(text) for text in texts]
    combined_contents = [payload["model_output"]["classification"]["iab_content"] for payload in combined_payloads]
    combined_fallbacks = [bool(payload["model_output"].get("fallback")) for payload in combined_payloads]
    combined_paths = [path_from_content(content) for content in combined_contents]
    views["combined_path"] = {
        **compute_path_metrics(true_paths, combined_paths),
        # max(..., 1) guards division by zero when rows is empty.
        "fallback_rate": round(sum(combined_fallbacks) / max(len(combined_fallbacks), 1), 4),
        "fallback_overuse_count": sum(combined_fallbacks),
    }
    # Disagreement counts: rows where two views predicted different paths.
    disagreements = {
        "classifier_vs_combined": sum(1 for left, right in zip(classifier_paths, combined_paths) if left != right),
    }
    # retrieval_paths is only bound when this same condition held above.
    if any(output is not None for output in retrieval_outputs):
        disagreements["retrieval_vs_classifier"] = sum(
            1 for left, right in zip(retrieval_paths, classifier_paths) if left != right
        )
        disagreements["retrieval_vs_combined"] = sum(
            1 for left, right in zip(retrieval_paths, combined_paths) if left != right
        )
    views["disagreements"] = disagreements
    return views