"""
AST-Based Function Call Evaluator
=================================

Evaluates model predictions against ground truth using AST-based matching.
"""

import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

from .arabic_utils import ArabicNormalizer


@dataclass
class EvaluationResult:
    """Result of evaluating a single sample."""

    sample_id: str
    category: str
    is_correct: bool
    score: float
    details: Dict[str, Any]


class ArabicASTEvaluator:
    """
    AST-based evaluator for Arabic function calling.

    Supports multiple evaluation modes:

    - exact: Exact match of the function name and all arguments
    - relaxed: Allows minor variations in argument values
    - function_only: Only checks whether the correct function was called
    """

    def __init__(self, mode: str = "exact"):
        self.mode = mode
        self.normalizer = ArabicNormalizer()
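
    # The mode only changes how strictly arguments are compared: it is passed
    # to compare_arguments as strict=(self.mode == 'exact'). Illustrative use
    # (hypothetical):
    #
    #   ArabicASTEvaluator(mode="exact")    # extra predicted arguments count as errors
    #   ArabicASTEvaluator(mode="relaxed")  # extra predicted arguments are tolerated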

    def parse_function_call(self, response: str) -> Optional[Dict]:
        """
        Parse a function call from a model response.

        Handles multiple formats:

        - JSON: {"name": "func", "arguments": {...}}
        - OpenAI style: {"function_call": {"name": "func", "arguments": "..."}}
        - Plain text: func(arg1=value1, arg2=value2)
        """
        if not response:
            return None

        response = response.strip()

        # First, try to parse the entire response as a JSON object.
        try:
            data = json.loads(response)
            if isinstance(data, dict):
                # Direct format: {"name": ..., "arguments": ...}
                if 'name' in data and 'arguments' in data:
                    args = data['arguments']
                    if isinstance(args, str):
                        args = json.loads(args)
                    return {'name': data['name'], 'arguments': args}

                # OpenAI chat-completion style: {"function_call": {...}}
                if 'function_call' in data:
                    fc = data['function_call']
                    args = fc.get('arguments', {})
                    if isinstance(args, str):
                        args = json.loads(args)
                    return {'name': fc['name'], 'arguments': args}

                # Tool-calls style: {"tool_calls": [{"function": {...}}, ...]}
                if 'tool_calls' in data and data['tool_calls']:
                    tc = data['tool_calls'][0]
                    func = tc.get('function', tc)
                    args = func.get('arguments', {})
                    if isinstance(args, str):
                        args = json.loads(args)
                    return {'name': func['name'], 'arguments': args}
        except (json.JSONDecodeError, KeyError, TypeError):
            pass

        # Otherwise, look for an embedded JSON object that contains "name".
        json_match = re.search(r'\{[^{}]*"name"[^{}]*\}', response, re.DOTALL)
        if json_match:
            try:
                data = json.loads(json_match.group())
                if 'name' in data:
                    args = data.get('arguments', data.get('parameters', {}))
                    if isinstance(args, str):
                        args = json.loads(args)
                    return {'name': data['name'], 'arguments': args}
            except (json.JSONDecodeError, KeyError):
                pass

        # Finally, try a plain-text call such as func(arg1=value1, arg2=value2).
        func_match = re.match(r'(\w+)\s*\((.*)\)', response, re.DOTALL)
        if func_match:
            name = func_match.group(1)
            args_str = func_match.group(2).strip()
            try:
                # Arguments written as a JSON object inside the parentheses.
                if args_str.startswith('{'):
                    args = json.loads(args_str)
                else:
                    # Arguments written as comma-separated key=value pairs.
                    args = {}
                    for pair in args_str.split(','):
                        if '=' in pair:
                            k, v = pair.split('=', 1)
                            args[k.strip()] = self._parse_value(v.strip())
                return {'name': name, 'arguments': args}
            except (json.JSONDecodeError, ValueError):
                pass

        return None
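
    # Illustrative inputs handled by parse_function_call (the function name and
    # argument values below are hypothetical, not taken from the benchmark):
    #
    #   {"name": "get_weather", "arguments": {"city": "الرياض", "days": 3}}
    #   {"function_call": {"name": "get_weather", "arguments": "{\"city\": \"الرياض\"}"}}
    #   {"tool_calls": [{"function": {"name": "get_weather", "arguments": {"city": "الرياض"}}}]}
    #   get_weather(city=الرياض, days=3)
    #
    # Each is expected to parse to {'name': 'get_weather', 'arguments': {...}};
    # anything unparseable yields None.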

    def parse_multiple_calls(self, response: str) -> List[Dict]:
        """Parse multiple function calls from a response."""
        calls = []

        if not response:
            return calls

        # First, try the whole response as JSON: either a list of calls or an
        # object with a "tool_calls" array.
        try:
            data = json.loads(response)
            if isinstance(data, list):
                for item in data:
                    parsed = self.parse_function_call(json.dumps(item))
                    if parsed:
                        calls.append(parsed)
                return calls
            elif isinstance(data, dict) and 'tool_calls' in data:
                for tc in data['tool_calls']:
                    func = tc.get('function', tc)
                    args = func.get('arguments', {})
                    if isinstance(args, str):
                        args = json.loads(args)
                    calls.append({'name': func['name'], 'arguments': args})
                return calls
        except (json.JSONDecodeError, KeyError, TypeError):
            pass

        # Otherwise, extract every embedded JSON object that contains "name".
        json_pattern = r'\{[^{}]*"name"[^{}]*\}'
        matches = re.findall(json_pattern, response, re.DOTALL)
        for match in matches:
            parsed = self.parse_function_call(match)
            if parsed:
                calls.append(parsed)

        # As a last resort, try the response as a single call.
        if not calls:
            single = self.parse_function_call(response)
            if single:
                calls.append(single)

        return calls
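
    # For example (hypothetical values), parse_multiple_calls accepts a JSON list
    #
    #   [{"name": "get_weather", "arguments": {"city": "مكة"}},
    #    {"name": "get_time", "arguments": {"timezone": "Asia/Riyadh"}}]
    #
    # as well as an object with "tool_calls", and returns one dict per call.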

    def _parse_value(self, value: str) -> Any:
        """Parse a string value into an appropriate Python type."""
        value = value.strip().strip('"\'')

        try:
            if '.' in value:
                return float(value)
            return int(value)
        except ValueError:
            pass

        if value.lower() in ('true', 'false'):
            return value.lower() == 'true'

        if value.lower() in ('none', 'null'):
            return None

        return value
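
    # Sketch of the coercions performed by _parse_value (illustrative values):
    #
    #   "3"       -> 3        (int)
    #   "2.5"     -> 2.5      (float)
    #   "true"    -> True     (case-insensitive booleans)
    #   "null"    -> None     ("none"/"null")
    #   "'مكة'"   -> "مكة"    (surrounding quotes stripped, otherwise kept as str)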

    def normalize_value(self, value: Any) -> Any:
        """Normalize a value for comparison."""
        if isinstance(value, str):
            return self.normalizer.normalize(value)
        if isinstance(value, (list, tuple)):
            return [self.normalize_value(v) for v in value]
        if isinstance(value, dict):
            return {k: self.normalize_value(v) for k, v in value.items()}
        return value

    def compare_arguments(
        self,
        predicted: Dict[str, Any],
        expected: Dict[str, Any],
        strict: bool = True
    ) -> Tuple[bool, float, Dict]:
        """
        Compare predicted arguments against expected arguments.

        Returns: (is_match, score, details)
        """
        if not expected:
            return len(predicted) == 0, 1.0 if len(predicted) == 0 else 0.0, {}

        details = {'matched': [], 'mismatched': [], 'missing': [], 'extra': []}

        expected_keys = set(expected.keys())
        predicted_keys = set(predicted.keys())

        # Record keys missing from, or extra in, the prediction.
        missing = expected_keys - predicted_keys
        extra = predicted_keys - expected_keys

        details['missing'] = list(missing)
        details['extra'] = list(extra)

        # Compare normalized values for the keys both sides share.
        common_keys = expected_keys & predicted_keys
        matched_count = 0

        for key in common_keys:
            exp_val = self.normalize_value(expected[key])
            pred_val = self.normalize_value(predicted[key])

            if exp_val == pred_val:
                details['matched'].append(key)
                matched_count += 1
            else:
                # Treat numeric values within a small tolerance as equal.
                if isinstance(exp_val, (int, float)) and isinstance(pred_val, (int, float)):
                    if abs(exp_val - pred_val) < 0.001:
                        details['matched'].append(key)
                        matched_count += 1
                        continue
                details['mismatched'].append({
                    'key': key,
                    'expected': expected[key],
                    'predicted': predicted[key]
                })

        total_expected = len(expected_keys)
        if strict:
            # Strict: every expected key must match and no extra keys are allowed.
            is_match = (matched_count == total_expected and len(extra) == 0)
            score = matched_count / max(total_expected, len(predicted_keys)) if predicted_keys else 0.0
        else:
            # Relaxed: extra keys are tolerated.
            is_match = matched_count == total_expected
            score = matched_count / total_expected if total_expected > 0 else 1.0

        return is_match, score, details
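
    # Worked example of the scoring above (hypothetical arguments):
    #
    #   expected  = {"city": "جدة", "days": 3}
    #   predicted = {"city": "جدة", "days": 3, "units": "metric"}
    #
    # Both expected keys match but one extra key is present, so:
    #   strict=True  -> is_match=False, score = 2 / max(2, 3) ≈ 0.67
    #   strict=False -> is_match=True,  score = 2 / 2 = 1.0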

    def evaluate_single_call(
        self,
        predicted: Optional[Dict],
        expected: Dict
    ) -> EvaluationResult:
        """Evaluate a single function call prediction."""
        if predicted is None:
            return EvaluationResult(
                sample_id="",
                category="",
                is_correct=False,
                score=0.0,
                details={'error': 'Failed to parse prediction'}
            )

        # Compare function names after normalization.
        pred_name = self.normalizer.normalize(predicted.get('name', ''))
        exp_name = self.normalizer.normalize(expected.get('name', ''))

        if pred_name != exp_name:
            return EvaluationResult(
                sample_id="",
                category="",
                is_correct=False,
                score=0.0,
                details={
                    'error': 'Function name mismatch',
                    'expected_name': expected.get('name'),
                    'predicted_name': predicted.get('name')
                }
            )

        # Names match; compare arguments ("exact" mode enforces strict matching).
        pred_args = predicted.get('arguments', {})
        exp_args = expected.get('arguments', {})

        is_match, score, details = self.compare_arguments(
            pred_args, exp_args, strict=(self.mode == 'exact')
        )

        return EvaluationResult(
            sample_id="",
            category="",
            is_correct=is_match,
            score=score,
            details=details
        )

    def evaluate_parallel_calls(
        self,
        predicted: List[Dict],
        expected: List[Dict]
    ) -> EvaluationResult:
        """
        Evaluate parallel function calls (order-agnostic).

        Greedily pairs each predicted call with its highest-scoring unused
        expected call, approximating an optimal bipartite matching.
        """
        if len(predicted) == 0 and len(expected) == 0:
            return EvaluationResult(
                sample_id="",
                category="",
                is_correct=True,
                score=1.0,
                details={'matched_calls': 0}
            )

        if len(predicted) == 0:
            return EvaluationResult(
                sample_id="",
                category="",
                is_correct=False,
                score=0.0,
                details={'error': 'No predictions', 'expected_count': len(expected)}
            )

        # Score every predicted call against every expected call.
        scores = []
        for pred in predicted:
            row = []
            for exp in expected:
                result = self.evaluate_single_call(pred, exp)
                row.append(result.score)
            scores.append(row)

        # Greedily assign each prediction to its best unused expected call.
        matched = 0
        total_score = 0.0
        used_expected = set()
        match_details = []

        for i, pred in enumerate(predicted):
            best_j = -1
            best_score = -1

            for j in range(len(expected)):
                if j not in used_expected and scores[i][j] > best_score:
                    best_score = scores[i][j]
                    best_j = j

            if best_j >= 0 and best_score > 0:
                used_expected.add(best_j)
                total_score += best_score
                if best_score == 1.0:
                    matched += 1
                match_details.append({
                    'predicted': pred,
                    'matched_to': expected[best_j],
                    'score': best_score
                })

        # Normalize by the larger call count so both missing and extra calls
        # lower the score.
        max_possible = max(len(predicted), len(expected))
        avg_score = total_score / max_possible if max_possible > 0 else 0.0

        is_correct = (matched == len(expected) and len(predicted) == len(expected))

        return EvaluationResult(
            sample_id="",
            category="",
            is_correct=is_correct,
            score=avg_score,
            details={
                'matched_calls': matched,
                'expected_count': len(expected),
                'predicted_count': len(predicted),
                'matches': match_details
            }
        )
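
    # Note on the greedy pairing above: each prediction takes its best unused
    # expected call in turn, which can be suboptimal. With the (hypothetical)
    # score matrix
    #
    #   scores = [[0.9, 1.0],
    #             [0.0, 0.9]]
    #
    # prediction 0 claims expected 1 (1.0), leaving prediction 1 with expected 0
    # (0.0, not counted), for a total of 1.0, whereas the optimal assignment
    # (0 -> 0, 1 -> 1) would total 1.8. An exact matcher such as
    # scipy.optimize.linear_sum_assignment could be substituted if needed.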

    def evaluate_irrelevance(
        self,
        predicted: Union[str, Dict, List],
        expected_no_call: bool = True
    ) -> EvaluationResult:
        """
        Evaluate irrelevance detection: by default, a correct response makes
        no function call.
        """
        # Normalize the prediction into a list of parsed calls.
        if isinstance(predicted, str):
            calls = self.parse_multiple_calls(predicted)
        elif isinstance(predicted, list):
            calls = predicted
        elif isinstance(predicted, dict):
            calls = [predicted] if 'name' in predicted else []
        else:
            calls = []

        made_call = len(calls) > 0

        if expected_no_call:
            is_correct = not made_call
            score = 1.0 if is_correct else 0.0
            details = {
                'expected': 'no_call',
                'actual': 'call_made' if made_call else 'no_call',
                'calls_made': calls
            }
        else:
            is_correct = made_call
            score = 1.0 if is_correct else 0.0
            details = {
                'expected': 'call_required',
                'actual': 'call_made' if made_call else 'no_call'
            }

        return EvaluationResult(
            sample_id="",
            category="irrelevance",
            is_correct=is_correct,
            score=score,
            details=details
        )

    def evaluate(
        self,
        sample: Dict,
        prediction: str
    ) -> EvaluationResult:
        """
        Main evaluation entry point.

        Dispatches to the appropriate evaluator based on the sample's category.
        """
        category = sample.get('category', 'simple')
        sample_id = sample.get('id', '')

        # Ground truth may be stored as a JSON string; decode it if so.
        ground_truth = sample.get('ground_truth')
        if isinstance(ground_truth, str) and ground_truth:
            try:
                ground_truth = json.loads(ground_truth)
            except json.JSONDecodeError:
                ground_truth = None

        # Irrelevance samples: the model should not call any function.
        if category == 'irrelevance':
            result = self.evaluate_irrelevance(prediction, expected_no_call=True)
            result.sample_id = sample_id
            return result

        # Parallel samples: compare the full sets of calls.
        if category in ('parallel', 'parallel_multiple'):
            pred_calls = self.parse_multiple_calls(prediction)
            if ground_truth and 'calls' in ground_truth:
                exp_calls = ground_truth['calls']
            else:
                exp_calls = []
            result = self.evaluate_parallel_calls(pred_calls, exp_calls)
        else:
            # Single-call samples.
            pred_call = self.parse_function_call(prediction)
            if ground_truth:
                if 'calls' in ground_truth and ground_truth['calls']:
                    exp_call = ground_truth['calls'][0]
                else:
                    exp_call = ground_truth
            else:
                # Without ground truth the sample cannot be scored.
                result = EvaluationResult(
                    sample_id=sample_id,
                    category=category,
                    is_correct=False,
                    score=0.0,
                    details={'error': 'No ground truth available'}
                )
                return result
            result = self.evaluate_single_call(pred_call, exp_call)

        result.sample_id = sample_id
        result.category = category
        return result
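

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The sample structure mirrors what
# `evaluate` expects (id / category / ground_truth), but the concrete values
# below are hypothetical and not taken from the benchmark data. Because of the
# relative import at the top of this file, run it as a module
# (python -m <package>.<this_module>) rather than as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    evaluator = ArabicASTEvaluator(mode="exact")

    sample = {
        "id": "demo-001",
        "category": "simple",
        "ground_truth": json.dumps({
            "calls": [{"name": "get_weather", "arguments": {"city": "الرياض"}}]
        }),
    }
    prediction = '{"name": "get_weather", "arguments": {"city": "الرياض"}}'

    result = evaluator.evaluate(sample, prediction)
    print(result.sample_id, result.category, result.is_correct, result.score)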