File size: 4,800 Bytes
c76423f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | """
End-to-end evaluation of the tool agent.
Checks whether the agent:
- calls at least one tool (rather than answering from memory)
- calls query_documents before web_search (local-RAG-first routing)
- finishes within the iteration limit
- refuses to answer unanswerable questions rather than fabricating a response
No LLM judge is used. All checks are deterministic heuristics over the
agent's final answer and tool trace, so scoring needs no extra LLM calls
(running the agent itself still uses the configured LLM).
"""
import warnings
from dotenv import load_dotenv
load_dotenv()
warnings.filterwarnings("ignore", category=DeprecationWarning)
from agents.tool_agent import run_tool_agent
from .eval_common import (
build_embeddings,
build_llm,
build_reranker,
load_dataset,
)
# Answer text that marks a run as unfinished: main() treats any answer equal
# to this string as "did not finish within the iteration limit".
# NOTE(review): must stay byte-identical to the message run_tool_agent returns
# when it hits its iteration cap — confirm against agents.tool_agent.
ITERATION_LIMIT_MESSAGE = "Agent reached the iteration limit without a final answer."

# Phrases that indicate the agent acknowledged it could not find the answer.
# They are matched as plain substrings of the lower-cased answer, so every
# entry must be lower-case and use straight ASCII apostrophes. Contracted and
# expanded forms are listed side by side because substring matching treats
# "can't" and "cannot" (or "couldn't" and "could not") as unrelated strings.
_REFUSAL_PHRASES = [
    "don't have", "do not have", "cannot find", "can't find",
    "could not find", "couldn't find", "not available",
    "no information", "not mentioned", "not provided", "not contain",
    "does not contain", "doesn't contain", "doesn't provide",
    "does not provide", "unable to find", "i don't know",
]
def _fallback_handled(answer: str) -> bool:
"""
Return True when the agent appears to have refused rather than fabricated.
Heuristic: the answer contains at least one refusal/can't-find signal
phrase. Domain-agnostic — works for any question type, not just salary
or compensation queries.
"""
answer_lower = answer.lower()
return any(phrase in answer_lower for phrase in _REFUSAL_PHRASES)
def _tool_sequence(trace: list) -> list:
    """
    Return the tool names recorded in *trace*, in call order.

    Steps without a ``"tool"`` key are skipped.
    NOTE(review): assumes each trace step is a dict-like mapping filled in by
    run_tool_agent — confirm against agents.tool_agent.
    """
    return [step["tool"] for step in trace if "tool" in step]


def main():
    """
    Run the tool agent over every case in the eval dataset and print a report.

    For each case a table row shows:
      - the truncated question,
      - the sequence of tools the agent called,
      - whether it finished within the iteration limit,
      - for cases whose ``expected_behavior`` is ``"fallback"``, whether the
        answer contained a refusal signal phrase.

    If the agent run raises, the case is printed as an ``ERR`` row, counted
    as an error, and evaluation continues with the next case.

    The closing summary reports how many cases:
      - called at least one tool;
      - called ``query_documents`` as their very first tool, out of the cases
        that called ``query_documents`` at all;
      - finished without hitting the iteration limit;
      - raised an exception during the agent run;
      - were fallback cases handled by refusing. Errored fallback cases count
        as not handled, mirroring how errors count against the other metrics.
    """
    print("Initializing models...")
    llm = build_llm()
    embeddings = build_embeddings()
    reranker = build_reranker()
    dataset = load_dataset()
    print(f"\nRunning tool agent on {len(dataset)} cases...\n")

    col_q = 60
    col_tools = 32
    print(f"{'Question':<{col_q}} {'Tools':<{col_tools}} Fin Fallback?")
    print("-" * 110)

    total_called_a_tool = 0
    total_used_local_first = 0
    local_first_applicable = 0  # cases where query_documents appeared at all
    total_finished = 0
    total_errors = 0
    fallback_cases_total = 0
    fallback_cases_handled = 0

    for case in dataset:
        question = case["question"]
        expected_behavior = case.get("expected_behavior", "answer")
        is_fallback_case = expected_behavior == "fallback"
        # Count fallback cases before running the agent, so that one which
        # errors out is scored as "not handled" instead of vanishing from the
        # denominator (which would inflate the handled rate).
        if is_fallback_case:
            fallback_cases_total += 1

        # Run the agent. A tool may raise (e.g. a failed web fetch); we record
        # the case as an error and keep going rather than aborting the whole run.
        trace: list = []
        try:
            answer = run_tool_agent(
                question,
                llm=llm,
                embeddings=embeddings,
                reranker=reranker,
                trace=trace,
            )
        except Exception as exc:
            total_errors += 1
            partial_tools = str(_tool_sequence(trace))[:col_tools]
            print(
                f"{question[:col_q]:<{col_q}} {partial_tools:<{col_tools}} "
                f"ERR {type(exc).__name__}: {str(exc)[:60]}"
            )
            continue

        tool_sequence = _tool_sequence(trace)
        called_a_tool = len(tool_sequence) > 0
        finished = answer != ITERATION_LIMIT_MESSAGE

        # Local-first is only scored when query_documents was called at all;
        # it passes when query_documents was the agent's first tool call.
        # NOTE(review): runs that used other tools (e.g. web search) without
        # ever calling query_documents are excluded from this ratio, not failed.
        used_local_first = False
        if "query_documents" in tool_sequence:
            local_first_applicable += 1
            used_local_first = tool_sequence[0] == "query_documents"

        if used_local_first:
            total_used_local_first += 1
        if called_a_tool:
            total_called_a_tool += 1
        if finished:
            total_finished += 1

        fallback_label = ""
        if is_fallback_case:
            handled = _fallback_handled(answer)
            if handled:
                fallback_cases_handled += 1
            fallback_label = "OK" if handled else "FAIL"

        tool_str = str(tool_sequence)[:col_tools]
        fin_label = "yes" if finished else "NO"
        print(
            f"{question[:col_q]:<{col_q}} {tool_str:<{col_tools}} {fin_label:<3} {fallback_label}"
        )

    n = len(dataset)
    print()
    print("=" * 60)
    print(f"Agent evaluation summary ({n} cases)")
    print("=" * 60)
    print(f" Called a tool: {total_called_a_tool}/{n}")
    print(f" Used local RAG first: {total_used_local_first}/{local_first_applicable}"
          f" (of cases that called query_documents)")
    print(f" Finished within limit: {total_finished}/{n}")
    # Errors were previously counted but never shown, hiding crashed cases.
    print(f" Agent errors: {total_errors}/{n}")
    if fallback_cases_total:
        print(f" Fallback handled: {fallback_cases_handled}/{fallback_cases_total}"
              f" (refused with a can't-find signal phrase)")


# This module uses a package-relative import (``from .eval_common import ...``),
# so it has to be launched as a module (``python -m <package>.<module>``);
# running the file directly as a script fails on that import.
if __name__ == "__main__":
    main()
|