"""
AST-Based Function Call Evaluator
=================================
Evaluates model predictions against ground truth using AST-based matching.
"""
import json
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from .arabic_utils import ArabicNormalizer
@dataclass
class EvaluationResult:
"""Result of evaluating a single sample."""
sample_id: str
category: str
is_correct: bool
score: float
details: Dict[str, Any]
class ArabicASTEvaluator:
"""
AST-based evaluator for Arabic function calling.
Supports multiple evaluation modes:
- exact: Exact match of function name and all arguments
- relaxed: Allows minor variations in argument values
- function_only: Only checks if correct function was called
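
    Example (a minimal usage sketch; ``sample`` and ``model_response`` are
    hypothetical placeholders)::

        evaluator = ArabicASTEvaluator(mode="exact")
        result = evaluator.evaluate(sample, model_response)
        print(result.is_correct, result.score)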
"""
def __init__(self, mode: str = "exact"):
self.mode = mode
self.normalizer = ArabicNormalizer()
def parse_function_call(self, response: str) -> Optional[Dict]:
"""
        Parse a single function call from a model response.
Handles multiple formats:
- JSON: {"name": "func", "arguments": {...}}
- OpenAI style: {"function_call": {"name": "func", "arguments": "..."}}
- Plain text: func(arg1, arg2)
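
        Illustrative examples (hypothetical function name and arguments)::

            >>> ev = ArabicASTEvaluator()
            >>> ev.parse_function_call('{"name": "get_time", "arguments": {"city": "Cairo"}}')
            {'name': 'get_time', 'arguments': {'city': 'Cairo'}}
            >>> ev.parse_function_call('get_time(city="Cairo")')
            {'name': 'get_time', 'arguments': {'city': 'Cairo'}}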
"""
if not response:
return None
response = response.strip()
# Try JSON format first
try:
data = json.loads(response)
if isinstance(data, dict):
# Direct format
if 'name' in data and 'arguments' in data:
args = data['arguments']
if isinstance(args, str):
args = json.loads(args)
return {'name': data['name'], 'arguments': args}
# OpenAI format
if 'function_call' in data:
fc = data['function_call']
args = fc.get('arguments', {})
if isinstance(args, str):
args = json.loads(args)
return {'name': fc['name'], 'arguments': args}
# Tool calls format
if 'tool_calls' in data and data['tool_calls']:
tc = data['tool_calls'][0]
func = tc.get('function', tc)
args = func.get('arguments', {})
if isinstance(args, str):
args = json.loads(args)
return {'name': func['name'], 'arguments': args}
except (json.JSONDecodeError, KeyError, TypeError):
pass
        # Try extracting an embedded JSON object from surrounding text.
        # Note: the pattern only matches flat objects; nested argument
        # braces will not be captured.
        json_match = re.search(r'\{[^{}]*"name"[^{}]*\}', response, re.DOTALL)
if json_match:
try:
data = json.loads(json_match.group())
if 'name' in data:
args = data.get('arguments', data.get('parameters', {}))
if isinstance(args, str):
args = json.loads(args)
return {'name': data['name'], 'arguments': args}
except (json.JSONDecodeError, KeyError):
pass
# Try plain text function call format: func(args)
func_match = re.match(r'(\w+)\s*\((.*)\)', response, re.DOTALL)
if func_match:
name = func_match.group(1)
args_str = func_match.group(2).strip()
try:
# Try parsing as JSON
if args_str.startswith('{'):
args = json.loads(args_str)
else:
# Parse as key=value pairs
args = {}
for pair in args_str.split(','):
if '=' in pair:
k, v = pair.split('=', 1)
args[k.strip()] = self._parse_value(v.strip())
return {'name': name, 'arguments': args}
            except json.JSONDecodeError:
                pass
return None
def parse_multiple_calls(self, response: str) -> List[Dict]:
"""Parse multiple function calls from response."""
calls = []
if not response:
return calls
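        # Illustrative (hypothetical payload): a JSON array such as
        # '[{"name": "f", "arguments": {}}]' yields [{'name': 'f', 'arguments': {}}];
        # non-array responses fall back to the single-call parser below.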
# Try JSON array
try:
data = json.loads(response)
if isinstance(data, list):
for item in data:
parsed = self.parse_function_call(json.dumps(item))
if parsed:
calls.append(parsed)
return calls
elif isinstance(data, dict) and 'tool_calls' in data:
for tc in data['tool_calls']:
func = tc.get('function', tc)
args = func.get('arguments', {})
if isinstance(args, str):
args = json.loads(args)
calls.append({'name': func['name'], 'arguments': args})
return calls
except (json.JSONDecodeError, KeyError, TypeError):
pass
# Try finding multiple JSON objects
json_pattern = r'\{[^{}]*"name"[^{}]*\}'
matches = re.findall(json_pattern, response, re.DOTALL)
for match in matches:
parsed = self.parse_function_call(match)
if parsed:
calls.append(parsed)
# If no calls found, try single call
if not calls:
single = self.parse_function_call(response)
if single:
calls.append(single)
return calls
def _parse_value(self, value: str) -> Any:
"""Parse a string value to appropriate type."""
value = value.strip().strip('"\'')
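        # Illustrative coercions: '3.5' -> 3.5, '7' -> 7, 'true' -> True,
        # 'null' -> None; anything unrecognized is returned as the bare string.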
# Try numeric
try:
if '.' in value:
return float(value)
return int(value)
except ValueError:
pass
# Boolean
if value.lower() in ('true', 'false'):
return value.lower() == 'true'
# None
if value.lower() in ('none', 'null'):
return None
return value
def normalize_value(self, value: Any) -> Any:
"""Normalize a value for comparison."""
if isinstance(value, str):
return self.normalizer.normalize(value)
if isinstance(value, (list, tuple)):
return [self.normalize_value(v) for v in value]
if isinstance(value, dict):
return {k: self.normalize_value(v) for k, v in value.items()}
return value
def compare_arguments(
self,
predicted: Dict[str, Any],
expected: Dict[str, Any],
strict: bool = True
) -> Tuple[bool, float, Dict]:
"""
Compare predicted arguments against expected.
Returns: (is_match, score, details)
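
        Doctest (illustrative; identical inputs always normalize identically)::

            >>> ev = ArabicASTEvaluator()
            >>> ok, score, details = ev.compare_arguments({'days': 3}, {'days': 3})
            >>> ok, score, details['matched']
            (True, 1.0, ['days'])
            >>> ok, score, details = ev.compare_arguments({'days': 3}, {'days': 5})
            >>> ok, details['mismatched'][0]['key']
            (False, 'days')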
"""
        if not expected:
            # No expected arguments: correct only when the prediction is empty too.
            is_empty = len(predicted) == 0
            return is_empty, 1.0 if is_empty else 0.0, {}
details = {'matched': [], 'mismatched': [], 'missing': [], 'extra': []}
expected_keys = set(expected.keys())
predicted_keys = set(predicted.keys())
# Check for missing and extra keys
missing = expected_keys - predicted_keys
extra = predicted_keys - expected_keys
details['missing'] = list(missing)
details['extra'] = list(extra)
# Compare common keys
common_keys = expected_keys & predicted_keys
matched_count = 0
for key in common_keys:
exp_val = self.normalize_value(expected[key])
pred_val = self.normalize_value(predicted[key])
if exp_val == pred_val:
details['matched'].append(key)
matched_count += 1
else:
# Try numeric comparison with tolerance
if isinstance(exp_val, (int, float)) and isinstance(pred_val, (int, float)):
if abs(exp_val - pred_val) < 0.001:
details['matched'].append(key)
matched_count += 1
continue
details['mismatched'].append({
'key': key,
'expected': expected[key],
'predicted': predicted[key]
})
# Calculate score
total_expected = len(expected_keys)
if strict:
# All must match, no extras
is_match = (matched_count == total_expected and len(extra) == 0)
score = matched_count / max(total_expected, len(predicted_keys)) if predicted_keys else 0.0
else:
# Partial credit
is_match = matched_count == total_expected
score = matched_count / total_expected if total_expected > 0 else 1.0
return is_match, score, details
def evaluate_single_call(
self,
predicted: Optional[Dict],
expected: Dict
) -> EvaluationResult:
"""Evaluate a single function call prediction."""
if predicted is None:
return EvaluationResult(
sample_id="",
category="",
is_correct=False,
score=0.0,
details={'error': 'Failed to parse prediction'}
)
# Check function name
pred_name = self.normalizer.normalize(predicted.get('name', ''))
exp_name = self.normalizer.normalize(expected.get('name', ''))
if pred_name != exp_name:
return EvaluationResult(
sample_id="",
category="",
is_correct=False,
score=0.0,
details={
'error': 'Function name mismatch',
'expected_name': expected.get('name'),
'predicted_name': predicted.get('name')
}
)
# Compare arguments
pred_args = predicted.get('arguments', {})
exp_args = expected.get('arguments', {})
is_match, score, details = self.compare_arguments(
pred_args, exp_args, strict=(self.mode == 'exact')
)
return EvaluationResult(
sample_id="",
category="",
is_correct=is_match,
score=score,
details=details
)
def evaluate_parallel_calls(
self,
predicted: List[Dict],
expected: List[Dict]
) -> EvaluationResult:
"""
Evaluate parallel function calls (order-agnostic).
        Pairs predictions with expected calls via greedy best-score matching.
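
        Doctest (hypothetical call; an identical prediction scores 1.0)::

            >>> ev = ArabicASTEvaluator()
            >>> call = {'name': 'f', 'arguments': {'x': 1}}
            >>> r = ev.evaluate_parallel_calls([call], [call])
            >>> r.is_correct, r.score
            (True, 1.0)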
"""
if len(predicted) == 0 and len(expected) == 0:
return EvaluationResult(
sample_id="",
category="",
is_correct=True,
score=1.0,
details={'matched_calls': 0}
)
if len(predicted) == 0:
return EvaluationResult(
sample_id="",
category="",
is_correct=False,
score=0.0,
details={'error': 'No predictions', 'expected_count': len(expected)}
)
# Build score matrix
scores = []
for pred in predicted:
row = []
for exp in expected:
result = self.evaluate_single_call(pred, exp)
row.append(result.score)
scores.append(row)
# Greedy matching (could use Hungarian algorithm for optimal)
matched = 0
total_score = 0.0
used_expected = set()
match_details = []
for i, pred in enumerate(predicted):
best_j = -1
best_score = -1
for j, exp in enumerate(expected):
if j not in used_expected and scores[i][j] > best_score:
best_score = scores[i][j]
best_j = j
if best_j >= 0 and best_score > 0:
used_expected.add(best_j)
total_score += best_score
if best_score == 1.0:
matched += 1
match_details.append({
'predicted': pred,
'matched_to': expected[best_j],
'score': best_score
})
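        # Note: scipy.optimize.linear_sum_assignment (maximizing by negating
        # the score matrix) would give an optimal assignment instead of this
        # greedy pass, at the cost of a scipy dependency.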
# Calculate overall score
max_possible = max(len(predicted), len(expected))
avg_score = total_score / max_possible if max_possible > 0 else 0.0
is_correct = (matched == len(expected) and len(predicted) == len(expected))
return EvaluationResult(
sample_id="",
category="",
is_correct=is_correct,
score=avg_score,
details={
'matched_calls': matched,
'expected_count': len(expected),
'predicted_count': len(predicted),
'matches': match_details
}
)
def evaluate_irrelevance(
self,
predicted: Union[str, Dict, List],
expected_no_call: bool = True
) -> EvaluationResult:
"""
Evaluate irrelevance detection (should not call any function).
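
        Doctest (hypothetical Arabic refusal text with no function call)::

            >>> ev = ArabicASTEvaluator()
            >>> ev.evaluate_irrelevance('عذراً، لا توجد أداة مناسبة لهذا الطلب').is_correct
            True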
"""
# Check if model made any function calls
if isinstance(predicted, str):
calls = self.parse_multiple_calls(predicted)
elif isinstance(predicted, list):
calls = predicted
elif isinstance(predicted, dict):
calls = [predicted] if 'name' in predicted else []
else:
calls = []
made_call = len(calls) > 0
if expected_no_call:
is_correct = not made_call
score = 1.0 if is_correct else 0.0
details = {
'expected': 'no_call',
'actual': 'call_made' if made_call else 'no_call',
'calls_made': calls
}
else:
is_correct = made_call
score = 1.0 if is_correct else 0.0
details = {
'expected': 'call_required',
'actual': 'call_made' if made_call else 'no_call'
}
return EvaluationResult(
sample_id="",
category="irrelevance",
is_correct=is_correct,
score=score,
details=details
)
def evaluate(
self,
sample: Dict,
prediction: str
) -> EvaluationResult:
"""
Main evaluation entry point.
Dispatches to appropriate evaluator based on category.
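
        Doctest (hypothetical sample; ground truth given as a JSON string)::

            >>> ev = ArabicASTEvaluator()
            >>> sample = {'id': 's1', 'category': 'simple',
            ...           'ground_truth': '{"name": "f", "arguments": {"x": 1}}'}
            >>> r = ev.evaluate(sample, '{"name": "f", "arguments": {"x": 1}}')
            >>> r.is_correct, r.score, r.sample_id
            (True, 1.0, 's1')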
"""
category = sample.get('category', 'simple')
sample_id = sample.get('id', '')
# Parse ground truth
ground_truth = sample.get('ground_truth')
if isinstance(ground_truth, str) and ground_truth:
try:
ground_truth = json.loads(ground_truth)
except json.JSONDecodeError:
ground_truth = None
# Handle irrelevance
if category == 'irrelevance':
result = self.evaluate_irrelevance(prediction, expected_no_call=True)
result.sample_id = sample_id
return result
# Parse prediction
if category in ('parallel', 'parallel_multiple'):
pred_calls = self.parse_multiple_calls(prediction)
            if isinstance(ground_truth, dict) and 'calls' in ground_truth:
                exp_calls = ground_truth['calls']
            elif isinstance(ground_truth, list):
                # Some ground truths may already be a bare list of calls.
                exp_calls = ground_truth
            else:
                exp_calls = []
result = self.evaluate_parallel_calls(pred_calls, exp_calls)
else:
pred_call = self.parse_function_call(prediction)
if ground_truth:
if 'calls' in ground_truth and ground_truth['calls']:
exp_call = ground_truth['calls'][0]
else:
exp_call = ground_truth
else:
# No ground truth available
result = EvaluationResult(
sample_id=sample_id,
category=category,
is_correct=False,
score=0.0,
details={'error': 'No ground truth available'}
)
return result
result = self.evaluate_single_call(pred_call, exp_call)
result.sample_id = sample_id
result.category = category
return result
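

if __name__ == "__main__":
    # Minimal smoke test with a hypothetical sample. Run as a module
    # (e.g. `python -m <package>.<this_module>`) so the relative import
    # of ArabicNormalizer resolves.
    evaluator = ArabicASTEvaluator(mode="exact")
    demo_sample = {
        'id': 'demo-1',
        'category': 'simple',
        'ground_truth': '{"name": "get_weather", "arguments": {"city": "الرياض"}}',
    }
    demo_prediction = '{"name": "get_weather", "arguments": {"city": "الرياض"}}'
    demo_result = evaluator.evaluate(demo_sample, demo_prediction)
    print(f"correct={demo_result.is_correct} score={demo_result.score:.2f}")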