# BPO-Bench / evaluator.py
# Uploaded via huggingface_hub (commit d075a5b, verified)
"""Keyword-based evaluation for BPO benchmark."""
from typing import List, Dict, Any
def check_keywords(response: str, expected_keywords: List[str]) -> Dict[str, Any]:
    """
    Check if response contains expected keywords (supports OR with |).

    Matching is case-insensitive substring containment. Each keyword may
    list alternatives separated by | (e.g., "67%|67 %|67"); the keyword
    counts as found if ANY alternative occurs in the response.

    Args:
        response: The agent's response text
        expected_keywords: List of keywords to check. Each keyword can contain
            alternatives separated by | (e.g., "67%|67 %|67")

    Returns:
        Dictionary with:
            found: keywords with at least one matching alternative
            missing: keywords with no matching alternative
            match_rate: len(found) / len(expected_keywords), or 1.0 when
                no keywords are expected (vacuously passing)
            passed: True iff every expected keyword was found
    """
    # Hoist the loop-invariant lowercasing: the original recomputed
    # response.lower() for every alternative of every keyword.
    response_lower = response.lower()
    found: List[str] = []
    missing: List[str] = []
    for keyword in expected_keywords:
        alternatives = keyword.split("|")
        if any(alt.lower() in response_lower for alt in alternatives):
            found.append(keyword)
        else:
            missing.append(keyword)
    # Empty expectation list is treated as a perfect (vacuous) match.
    match_rate = len(found) / len(expected_keywords) if expected_keywords else 1.0
    return {
        "found": found,
        "missing": missing,
        "match_rate": match_rate,
        "passed": len(missing) == 0,
    }
def evaluate_task(task: Dict[str, Any], response: str, tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Evaluate a single task.

    Scores the response text against the task's expected keywords (via
    check_keywords) and the agent's tool calls against the task's expected
    tool calls (lenient substring matching on names).

    Args:
        task: Task definition from tasks.json
        response: The agent's response text
        tool_calls: List of tool calls made by the agent

    Returns:
        Evaluation result dictionary (keyword match details, tool-call
        accuracy, and API-call-count accuracy)
    """
    spec = task.get("expected_output", {})
    keywords = spec.get("keywords", [])
    kw_result = check_keywords(response, keywords)

    # Normalize tool calls into a flat list of names. Dict entries may
    # carry the name at the top level or nested under "function".
    called: List[str] = []
    for call in tool_calls:
        if isinstance(call, str):
            called.append(call)
        elif isinstance(call, dict):
            label = call.get("name") or call.get("function", {}).get("name", "")
            if label:
                called.append(label)

    # Names of the tools the task expects (non-dict entries are ignored).
    wanted = [
        entry.get("name", "")
        for entry in spec.get("tool_calls", [])
        if isinstance(entry, dict)
    ]

    # Tool accuracy: fraction of expected names that appear as a substring
    # of at least one actual call name. With no expectations, a silent run
    # scores 1.0 and any unexpected calls score 0.5.
    if not wanted:
        tool_accuracy = 1.0 if not called else 0.5
    else:
        hits = sum(any(w in name for name in called) for w in wanted)
        tool_accuracy = hits / len(wanted)

    # API count accuracy is lenient: calling at least as many tools as
    # expected counts as correct.
    n_called = len(called)
    n_expected = len(wanted)

    return {
        "task_id": task.get("name", "unknown"),
        "difficulty": task.get("difficulty", "unknown"),
        "intent": task.get("intent", ""),
        "response": response,
        "expected_keywords": keywords,
        "found_keywords": kw_result["found"],
        "missing_keywords": kw_result["missing"],
        "match_rate": kw_result["match_rate"],
        "passed": kw_result["passed"],
        "tool_calls": called,
        "expected_tool_calls": wanted,
        "tool_accuracy": tool_accuracy,
        "api_call_count": n_called,
        "expected_api_count": n_expected,
        "api_count_correct": 1 if n_called >= n_expected else 0,
    }
def calculate_summary(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Calculate summary statistics from evaluation results.

    Args:
        results: List of evaluation results from evaluate_task

    Returns:
        Summary dictionary with pass rates and averages; all zeros (and an
        empty by_difficulty map) when results is empty.
    """
    # Guard the empty case up front so all later divisions are safe.
    if not results:
        return {
            "total_tasks": 0,
            "passed": 0,
            "pass_rate": 0.0,
            "avg_match_rate": 0.0,
            "avg_tool_accuracy": 0.0,
            "api_count_accuracy": 0.0,
            "by_difficulty": {},
        }

    n_tasks = len(results)
    n_passed = sum(1 for item in results if item.get("passed", False))

    # Per-difficulty pass counts, grouped with setdefault.
    by_difficulty: Dict[str, Dict[str, Any]] = {}
    for item in results:
        bucket = by_difficulty.setdefault(
            item.get("difficulty", "unknown"), {"total": 0, "passed": 0}
        )
        bucket["total"] += 1
        bucket["passed"] += 1 if item.get("passed", False) else 0
    for bucket in by_difficulty.values():
        bucket["pass_rate"] = bucket["passed"] / bucket["total"]

    return {
        "total_tasks": n_tasks,
        "passed": n_passed,
        "pass_rate": n_passed / n_tasks,
        "avg_match_rate": sum(item.get("match_rate", 0) for item in results) / n_tasks,
        "avg_tool_accuracy": sum(item.get("tool_accuracy", 0) for item in results) / n_tasks,
        "api_count_accuracy": sum(item.get("api_count_correct", 0) for item in results) / n_tasks,
        "by_difficulty": by_difficulty,
    }