Spaces:
Sleeping
Sleeping
| """Evaluation utilities matching standard implementations""" | |
| import re | |
| from typing import Optional, Union | |
| import numpy as np | |
| try: | |
| import sympy | |
| from sympy.parsing.latex import parse_latex | |
| SYMPY_AVAILABLE = True | |
| except ImportError: | |
| SYMPY_AVAILABLE = False | |
# Unit / filler words removed from answers before comparison. Removal is a
# plain substring replace, so a word is also removed from inside longer words
# (e.g. "meters" inside "parameters").
_MATH_ANSWER_FILLER_WORDS = (
    'square', 'units', 'integers', 'dollars', 'mph', 'inches', 'feet',
    'minutes', 'cm', 'gm', 'pounds', 'meters', 'meals', 'edges', 'students',
    'childrentickets', 'multiples', 'hours', 'degrees', 'ounces', 'bits',
    'factorization', 'greenmarbles', 'redmarbles', 'bluemarbles',
)

# An entire answer that is one number written with comma thousands separators,
# e.g. "1,000" or "-12,345.5". Only answers of this exact shape lose their
# commas, so tuples/lists such as "(1,2)" or "1,2,3" keep them.
_THOUSANDS_NUMBER_RE = re.compile(r'[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?')


def normalize_math_answer(answer: str) -> str:
    """Normalize a math answer string for string-level comparison.

    Steps, in order:
      * keep only the text after the last '=';
      * strip surrounding whitespace and '$' delimiters;
      * unwrap ``\\text{...}`` and ``\\textbf{...}``;
      * expand shorthand ``\\fracab`` to ``\\frac{a}{b}``;
      * delete the filler words in ``_MATH_ANSWER_FILLER_WORDS``;
      * collapse runs of whitespace;
      * drop thousands separators when the whole answer is a single number.

    Args:
        answer: raw answer text; falsy input yields "".

    Returns:
        The normalized answer with no leading/trailing whitespace.
    """
    if not answer:
        return ""
    # Keep only the right-hand side of the last '=' ("x = 5" -> " 5").
    if '=' in answer:
        answer = answer.split('=')[-1]
    answer = answer.strip().strip('$')
    answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\textbf\{([^}]*)\}', r'\1', answer)
    # Shorthand \fracab -> \frac{a}{b} (single-character numerator/denominator).
    answer = re.sub(r'\\frac([0-9a-zA-Z])([0-9a-zA-Z])', r'\\frac{\1}{\2}', answer)
    for word in _MATH_ANSWER_FILLER_WORDS:
        answer = answer.replace(word, '')
    # split()/join also trims leading and trailing whitespace.
    answer = ' '.join(answer.split())
    # Done last so filler words ("1,000 dollars") no longer hide the number.
    if _THOUSANDS_NUMBER_RE.fullmatch(answer):
        answer = answer.replace(',', '')
    return answer
# A number with optional sign that either uses comma thousands separators
# ("1,234.5") or none ("-3.5", ".5"). The grouped form is tried first so that
# "1,000" is captured whole instead of as "1" and "000".
_GSM8K_NUMBER_PATTERN = r'[-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d*\.?\d+)'


def extract_answer_gsm8k(response: str) -> Optional[float]:
    """Extract the final numeric answer from a GSM8K-style response.

    If the response contains the GSM8K answer marker ``#### <number>``
    (optionally with a '$' before the number), the number after the first
    marker is used. Otherwise the last number in the response is used.
    Comma thousands separators are accepted ("1,000" -> 1000.0).

    Args:
        response: model output text.

    Returns:
        The extracted number as a float, or None if no number is found.
    """
    marked = re.search(r'####\s*\$?\s*(' + _GSM8K_NUMBER_PATTERN + ')', response)
    if marked:
        token = marked.group(1)
    else:
        numbers = re.findall(_GSM8K_NUMBER_PATTERN, response)
        if not numbers:
            return None
        token = numbers[-1]
    try:
        return float(token.replace(',', ''))
    except ValueError:
        return None
def extract_answer_mmlu(response: str) -> Optional[str]:
    """Extract an MMLU multiple-choice letter (A-D) from a model response.

    Strategies, tried in order on the stripped response:
      1. the whole response is a single letter "A".."D";
      2. the response starts with an option marker such as "B)", "C.",
         "D:", "(A)" or "A ";
      3. an "answer is X" phrase, with "answer is" matched case-insensitively.
         X may be parenthesised and may follow a colon. An uppercase X must
         be a standalone letter. A lowercase X is accepted only when it is
         followed by ")" or ends the response, so "answer is a tricky one"
         is not read as "A";
      4. the first standalone capital A-D (word-bounded, so the "A" in
         "Answer" or the "B" in "Based" is ignored).

    Args:
        response: model output text.

    Returns:
        The uppercase answer letter, or None if no strategy matches.
    """
    response = response.strip()
    if len(response) == 1 and response in 'ABCD':
        return response

    match = re.match(r'\(?([ABCD])[).:\s]', response)
    if match:
        return match.group(1)

    # Only "answer is" is case-insensitive; the letter itself is not, so
    # the English article "a" is not mistaken for option A.
    match = re.search(
        r'(?i:answer is):?\s*\(?(?:([ABCD])\b|([abcd])(?=\)|[.\s]*$))',
        response,
    )
    if match:
        return (match.group(1) or match.group(2)).upper()

    match = re.search(r'\b[ABCD]\b', response)
    if match:
        return match.group(0)
    return None
def calculate_accuracy_with_confidence(results: list) -> dict:
    """Summarise per-item correctness as accuracy plus a 95% Wilson interval.

    Args:
        results: items supporting ``.get('is_correct', False)``; an item counts
            as correct when that value is truthy (a missing key counts as wrong).

    Returns:
        A dict with keys 'accuracy', 'correct', 'total' and
        'confidence_interval'. The interval is a ``(lower, upper)`` tuple
        clamped to [0, 1]. An empty ``results`` gives zeros everywhere.
    """
    num_correct = sum(bool(item.get('is_correct', False)) for item in results)
    num_total = len(results)

    if num_total == 0:
        return {
            'accuracy': 0.0,
            'correct': 0,
            'total': 0,
            'confidence_interval': (0.0, 0.0),
        }

    p_hat = num_correct / num_total

    # Wilson score interval; 1.96 is the two-sided 95% normal quantile.
    z = 1.96
    z_squared = z**2
    scale = 1 + z_squared / num_total
    centre = (p_hat + z_squared / (2*num_total)) / scale
    half_width = z * np.sqrt(p_hat * (1-p_hat) / num_total + z_squared / (4*num_total**2)) / scale

    interval = (max(0, centre - half_width), min(1, centre + half_width))
    return {
        'accuracy': p_hat,
        'correct': num_correct,
        'total': num_total,
        'confidence_interval': interval,
    }
def is_math_equiv(pred: str, gold: str) -> bool:
    """Check whether two math answers are equivalent.

    Both answers are first normalized with ``normalize_math_answer``. Equal
    normalized strings are equivalent. Otherwise, when SymPy is importable,
    both are parsed (as LaTeX via ``parse_latex``, falling back to
    ``sympy.sympify``). They count as equivalent when the simplified
    difference is provably zero.

    Args:
        pred: predicted answer text.
        gold: reference answer text.

    Returns:
        True only when equivalence is established; False when SymPy is
        unavailable, parsing fails, or SymPy cannot decide.
    """
    pred_norm = normalize_math_answer(pred)
    gold_norm = normalize_math_answer(gold)

    if pred_norm == gold_norm:
        return True

    # Without SymPy only the exact normalized-string match above is possible,
    # and it has already failed.
    if not SYMPY_AVAILABLE:
        return False

    try:
        try:
            pred_expr = parse_latex(pred_norm)
            gold_expr = parse_latex(gold_norm)
        except Exception:
            # SECURITY NOTE(review): sympy.sympify() evaluates its input with
            # eval(). `pred` is presumably untrusted model output, so this
            # fallback can execute arbitrary code. Consider a restricted parser
            # or sandboxing before running on untrusted data.
            pred_expr = sympy.sympify(pred_norm)
            gold_expr = sympy.sympify(gold_norm)
        # NOTE(review): simplify() has no timeout and can be slow on large
        # expressions.
        diff = sympy.simplify(pred_expr - gold_expr)
        # `is_zero` is None when SymPy cannot decide, so compare it to True
        # to always return a real bool.
        return bool(diff == 0 or diff.is_zero is True)
    except Exception:
        # Unparseable or incompatible expressions (e.g. relations) are treated
        # as not equivalent; the string match already failed above.
        return False
def is_gsm8k_correct(pred: str, gold: str) -> bool:
    """Check whether a GSM8K prediction matches the gold answer.

    The answers match when they are identical, or when both convert to numbers
    that differ by less than 1e-9. String inputs may carry surrounding
    whitespace, a leading '$' and comma thousands separators ("$1,000");
    numeric inputs (such as the float from ``extract_answer_gsm8k``) are used
    as-is.

    Args:
        pred: predicted answer (string or number; None is treated as wrong).
        gold: reference answer (string or number).

    Returns:
        True if the answers match, otherwise False.
    """
    if pred == gold:
        return True

    def _as_number(value) -> float:
        # Convert a str/number to float; raises TypeError/ValueError/OverflowError.
        if isinstance(value, str):
            value = value.strip().lstrip('$').strip()
            # Only remove commas that are genuine thousands separators, so
            # "1,2" is not read as 12.
            if re.fullmatch(r'[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?', value):
                value = value.replace(',', '')
        return float(value)

    try:
        pred_num = _as_number(pred)
        gold_num = _as_number(gold)
    except (TypeError, ValueError, OverflowError):
        return False
    # GSM8K uses exact match; the tiny tolerance only absorbs float rounding.
    return abs(pred_num - gold_num) < 1e-9