File size: 4,800 Bytes
c76423f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
End-to-end evaluation of the tool agent.

Checks whether the agent:
  - routes tool calls at all (rather than answering from memory)
  - calls query_documents before web_search (local-RAG-first routing)
  - finishes within the iteration limit
  - refuses to answer unanswerable questions rather than fabricating a response

No LLM judge is used. All checks are deterministic heuristics so results
are reproducible without API credits.
"""
import warnings

from dotenv import load_dotenv

load_dotenv()

warnings.filterwarnings("ignore", category=DeprecationWarning)

from agents.tool_agent import run_tool_agent
from .eval_common import (
    build_embeddings,
    build_llm,
    build_reranker,
    load_dataset,
)

ITERATION_LIMIT_MESSAGE = "Agent reached the iteration limit without a final answer."

# Phrases that indicate the agent acknowledged it could not find the answer.
_REFUSAL_PHRASES = [
    "don't have", "do not have", "cannot find", "not available",
    "no information", "not mentioned", "not provided", "not contain",
    "does not contain", "doesn't contain", "doesn't provide",
    "does not provide", "unable to find", "i don't know",
]


def _fallback_handled(answer: str) -> bool:
    """
    Return True when the agent appears to have refused rather than fabricated.

    Heuristic: the answer contains at least one refusal/can't-find signal
    phrase. Domain-agnostic — works for any question type, not just salary
    or compensation queries.
    """
    answer_lower = answer.lower()
    return any(phrase in answer_lower for phrase in _REFUSAL_PHRASES)


def main():
    print("Initializing models...")
    llm = build_llm()
    embeddings = build_embeddings()
    reranker = build_reranker()

    dataset = load_dataset()

    print(f"\nRunning tool agent on {len(dataset)} cases...\n")

    col_q = 60
    print(f"{'Question':<{col_q}}  {'Tools':<32}  Fin  Fallback?")
    print("-" * 110)

    total_called_a_tool = 0
    total_used_local_first = 0
    local_first_applicable = 0  # cases where query_documents appeared at all
    total_finished = 0
    total_errors = 0
    fallback_cases_total = 0
    fallback_cases_handled = 0

    for case in dataset:
        question = case["question"]
        expected_behavior = case.get("expected_behavior", "answer")

        # Run the agent. A tool may raise (e.g. a failed web fetch); we record
        # the case as an error and keep going rather than aborting the whole run.
        trace: list = []
        try:
            answer = run_tool_agent(
                question,
                llm=llm,
                embeddings=embeddings,
                reranker=reranker,
                trace=trace,
            )
        except Exception as exc:
            total_errors += 1
            tool_sequence = [step["tool"] for step in trace if "tool" in step]
            print(
                f"{question[:col_q]:<{col_q}}  {str(tool_sequence)[:32]:<32}  "
                f"ERR  {type(exc).__name__}: {str(exc)[:60]}"
            )
            continue

        tool_sequence = [step["tool"] for step in trace if "tool" in step]
        called_a_tool = len(tool_sequence) > 0
        finished = answer != ITERATION_LIMIT_MESSAGE

        # Local-first: only meaningful when query_documents appears in the sequence.
        used_local_first = False
        if "query_documents" in tool_sequence:
            local_first_applicable += 1
            used_local_first = tool_sequence[0] == "query_documents"
            if used_local_first:
                total_used_local_first += 1

        if called_a_tool:
            total_called_a_tool += 1
        if finished:
            total_finished += 1

        fallback_label = ""
        if expected_behavior == "fallback":
            fallback_cases_total += 1
            handled = _fallback_handled(answer)
            if handled:
                fallback_cases_handled += 1
            fallback_label = "OK" if handled else "FAIL"

        tool_str = str(tool_sequence)[:32]
        fin_label = "yes" if finished else "NO"
        print(
            f"{question[:col_q]:<{col_q}}  {tool_str:<32}  {fin_label:<3}  {fallback_label}"
        )

    n = len(dataset)
    print()
    print("=" * 60)
    print(f"Agent evaluation summary ({n} cases)")
    print("=" * 60)
    print(f"  Called a tool:          {total_called_a_tool}/{n}")
    print(f"  Used local RAG first:   {total_used_local_first}/{local_first_applicable}"
          f"  (of cases that called query_documents)")
    print(f"  Finished within limit:  {total_finished}/{n}")
    if fallback_cases_total:
        print(f"  Fallback handled:       {fallback_cases_handled}/{fallback_cases_total}"
              f"  (refused with a can't-find signal phrase)")


if __name__ == "__main__":
    main()