""" AST-Based Function Call Evaluator ================================= Evaluates model predictions against ground truth using AST-based matching. """ import json import re from typing import Any, Dict, List, Optional, Tuple, Union from dataclasses import dataclass from .arabic_utils import ArabicNormalizer @dataclass class EvaluationResult: """Result of evaluating a single sample.""" sample_id: str category: str is_correct: bool score: float details: Dict[str, Any] class ArabicASTEvaluator: """ AST-based evaluator for Arabic function calling. Supports multiple evaluation modes: - exact: Exact match of function name and all arguments - relaxed: Allows minor variations in argument values - function_only: Only checks if correct function was called """ def __init__(self, mode: str = "exact"): self.mode = mode self.normalizer = ArabicNormalizer() def parse_function_call(self, response: str) -> Optional[Dict]: """ Parse a function call from model response. Handles multiple formats: - JSON: {"name": "func", "arguments": {...}} - OpenAI style: {"function_call": {"name": "func", "arguments": "..."}} - Plain text: func(arg1, arg2) """ if not response: return None response = response.strip() # Try JSON format first try: data = json.loads(response) if isinstance(data, dict): # Direct format if 'name' in data and 'arguments' in data: args = data['arguments'] if isinstance(args, str): args = json.loads(args) return {'name': data['name'], 'arguments': args} # OpenAI format if 'function_call' in data: fc = data['function_call'] args = fc.get('arguments', {}) if isinstance(args, str): args = json.loads(args) return {'name': fc['name'], 'arguments': args} # Tool calls format if 'tool_calls' in data and data['tool_calls']: tc = data['tool_calls'][0] func = tc.get('function', tc) args = func.get('arguments', {}) if isinstance(args, str): args = json.loads(args) return {'name': func['name'], 'arguments': args} except (json.JSONDecodeError, KeyError, TypeError): pass # Try extracting JSON from text json_match = re.search(r'\{[^{}]*"name"[^{}]*\}', response, re.DOTALL) if json_match: try: data = json.loads(json_match.group()) if 'name' in data: args = data.get('arguments', data.get('parameters', {})) if isinstance(args, str): args = json.loads(args) return {'name': data['name'], 'arguments': args} except (json.JSONDecodeError, KeyError): pass # Try plain text function call format: func(args) func_match = re.match(r'(\w+)\s*\((.*)\)', response, re.DOTALL) if func_match: name = func_match.group(1) args_str = func_match.group(2).strip() try: # Try parsing as JSON if args_str.startswith('{'): args = json.loads(args_str) else: # Parse as key=value pairs args = {} for pair in args_str.split(','): if '=' in pair: k, v = pair.split('=', 1) args[k.strip()] = self._parse_value(v.strip()) return {'name': name, 'arguments': args} except: pass return None def parse_multiple_calls(self, response: str) -> List[Dict]: """Parse multiple function calls from response.""" calls = [] if not response: return calls # Try JSON array try: data = json.loads(response) if isinstance(data, list): for item in data: parsed = self.parse_function_call(json.dumps(item)) if parsed: calls.append(parsed) return calls elif isinstance(data, dict) and 'tool_calls' in data: for tc in data['tool_calls']: func = tc.get('function', tc) args = func.get('arguments', {}) if isinstance(args, str): args = json.loads(args) calls.append({'name': func['name'], 'arguments': args}) return calls except (json.JSONDecodeError, KeyError, TypeError): pass # Try finding multiple JSON objects json_pattern = r'\{[^{}]*"name"[^{}]*\}' matches = re.findall(json_pattern, response, re.DOTALL) for match in matches: parsed = self.parse_function_call(match) if parsed: calls.append(parsed) # If no calls found, try single call if not calls: single = self.parse_function_call(response) if single: calls.append(single) return calls def _parse_value(self, value: str) -> Any: """Parse a string value to appropriate type.""" value = value.strip().strip('"\'') # Try numeric try: if '.' in value: return float(value) return int(value) except ValueError: pass # Boolean if value.lower() in ('true', 'false'): return value.lower() == 'true' # None if value.lower() in ('none', 'null'): return None return value def normalize_value(self, value: Any) -> Any: """Normalize a value for comparison.""" if isinstance(value, str): return self.normalizer.normalize(value) if isinstance(value, (list, tuple)): return [self.normalize_value(v) for v in value] if isinstance(value, dict): return {k: self.normalize_value(v) for k, v in value.items()} return value def compare_arguments( self, predicted: Dict[str, Any], expected: Dict[str, Any], strict: bool = True ) -> Tuple[bool, float, Dict]: """ Compare predicted arguments against expected. Returns: (is_match, score, details) """ if not expected: return len(predicted) == 0, 1.0 if len(predicted) == 0 else 0.0, {} details = {'matched': [], 'mismatched': [], 'missing': [], 'extra': []} expected_keys = set(expected.keys()) predicted_keys = set(predicted.keys()) # Check for missing and extra keys missing = expected_keys - predicted_keys extra = predicted_keys - expected_keys details['missing'] = list(missing) details['extra'] = list(extra) # Compare common keys common_keys = expected_keys & predicted_keys matched_count = 0 for key in common_keys: exp_val = self.normalize_value(expected[key]) pred_val = self.normalize_value(predicted[key]) if exp_val == pred_val: details['matched'].append(key) matched_count += 1 else: # Try numeric comparison with tolerance if isinstance(exp_val, (int, float)) and isinstance(pred_val, (int, float)): if abs(exp_val - pred_val) < 0.001: details['matched'].append(key) matched_count += 1 continue details['mismatched'].append({ 'key': key, 'expected': expected[key], 'predicted': predicted[key] }) # Calculate score total_expected = len(expected_keys) if strict: # All must match, no extras is_match = (matched_count == total_expected and len(extra) == 0) score = matched_count / max(total_expected, len(predicted_keys)) if predicted_keys else 0.0 else: # Partial credit is_match = matched_count == total_expected score = matched_count / total_expected if total_expected > 0 else 1.0 return is_match, score, details def evaluate_single_call( self, predicted: Optional[Dict], expected: Dict ) -> EvaluationResult: """Evaluate a single function call prediction.""" if predicted is None: return EvaluationResult( sample_id="", category="", is_correct=False, score=0.0, details={'error': 'Failed to parse prediction'} ) # Check function name pred_name = self.normalizer.normalize(predicted.get('name', '')) exp_name = self.normalizer.normalize(expected.get('name', '')) if pred_name != exp_name: return EvaluationResult( sample_id="", category="", is_correct=False, score=0.0, details={ 'error': 'Function name mismatch', 'expected_name': expected.get('name'), 'predicted_name': predicted.get('name') } ) # Compare arguments pred_args = predicted.get('arguments', {}) exp_args = expected.get('arguments', {}) is_match, score, details = self.compare_arguments( pred_args, exp_args, strict=(self.mode == 'exact') ) return EvaluationResult( sample_id="", category="", is_correct=is_match, score=score, details=details ) def evaluate_parallel_calls( self, predicted: List[Dict], expected: List[Dict] ) -> EvaluationResult: """ Evaluate parallel function calls (order-agnostic). Uses bipartite matching for optimal pairing. """ if len(predicted) == 0 and len(expected) == 0: return EvaluationResult( sample_id="", category="", is_correct=True, score=1.0, details={'matched_calls': 0} ) if len(predicted) == 0: return EvaluationResult( sample_id="", category="", is_correct=False, score=0.0, details={'error': 'No predictions', 'expected_count': len(expected)} ) # Build score matrix scores = [] for pred in predicted: row = [] for exp in expected: result = self.evaluate_single_call(pred, exp) row.append(result.score) scores.append(row) # Greedy matching (could use Hungarian algorithm for optimal) matched = 0 total_score = 0.0 used_expected = set() match_details = [] for i, pred in enumerate(predicted): best_j = -1 best_score = -1 for j, exp in enumerate(expected): if j not in used_expected and scores[i][j] > best_score: best_score = scores[i][j] best_j = j if best_j >= 0 and best_score > 0: used_expected.add(best_j) total_score += best_score if best_score == 1.0: matched += 1 match_details.append({ 'predicted': pred, 'matched_to': expected[best_j], 'score': best_score }) # Calculate overall score max_possible = max(len(predicted), len(expected)) avg_score = total_score / max_possible if max_possible > 0 else 0.0 is_correct = (matched == len(expected) and len(predicted) == len(expected)) return EvaluationResult( sample_id="", category="", is_correct=is_correct, score=avg_score, details={ 'matched_calls': matched, 'expected_count': len(expected), 'predicted_count': len(predicted), 'matches': match_details } ) def evaluate_irrelevance( self, predicted: Union[str, Dict, List], expected_no_call: bool = True ) -> EvaluationResult: """ Evaluate irrelevance detection (should not call any function). """ # Check if model made any function calls if isinstance(predicted, str): calls = self.parse_multiple_calls(predicted) elif isinstance(predicted, list): calls = predicted elif isinstance(predicted, dict): calls = [predicted] if 'name' in predicted else [] else: calls = [] made_call = len(calls) > 0 if expected_no_call: is_correct = not made_call score = 1.0 if is_correct else 0.0 details = { 'expected': 'no_call', 'actual': 'call_made' if made_call else 'no_call', 'calls_made': calls } else: is_correct = made_call score = 1.0 if is_correct else 0.0 details = { 'expected': 'call_required', 'actual': 'call_made' if made_call else 'no_call' } return EvaluationResult( sample_id="", category="irrelevance", is_correct=is_correct, score=score, details=details ) def evaluate( self, sample: Dict, prediction: str ) -> EvaluationResult: """ Main evaluation entry point. Dispatches to appropriate evaluator based on category. """ category = sample.get('category', 'simple') sample_id = sample.get('id', '') # Parse ground truth ground_truth = sample.get('ground_truth') if isinstance(ground_truth, str) and ground_truth: try: ground_truth = json.loads(ground_truth) except json.JSONDecodeError: ground_truth = None # Handle irrelevance if category == 'irrelevance': result = self.evaluate_irrelevance(prediction, expected_no_call=True) result.sample_id = sample_id return result # Parse prediction if category in ('parallel', 'parallel_multiple'): pred_calls = self.parse_multiple_calls(prediction) if ground_truth and 'calls' in ground_truth: exp_calls = ground_truth['calls'] else: exp_calls = [] result = self.evaluate_parallel_calls(pred_calls, exp_calls) else: pred_call = self.parse_function_call(prediction) if ground_truth: if 'calls' in ground_truth and ground_truth['calls']: exp_call = ground_truth['calls'][0] else: exp_call = ground_truth else: # No ground truth available result = EvaluationResult( sample_id=sample_id, category=category, is_correct=False, score=0.0, details={'error': 'No ground truth available'} ) return result result = self.evaluate_single_call(pred_call, exp_call) result.sample_id = sample_id result.category = category return result