Spaces:
Sleeping
Sleeping
| """Utility functions""" | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional | |
| from pathlib import Path | |
| import json | |
| import random | |
| import numpy as np | |
def set_seed(seed: int) -> None:
    """Seed the available random number generators with one value.

    Python's ``random`` module and NumPy's global generator are always
    seeded. PyTorch (CPU, plus every CUDA device when CUDA is available) is
    seeded only when ``torch`` can be imported.

    Args:
        seed: Random seed
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    try:
        import torch

        torch.manual_seed(seed)
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        # torch is optional; without it there is nothing more to seed.
        pass
def save_json(data: Any, filepath: str, indent: int = 2) -> None:
    """Write ``data`` to ``filepath`` as UTF-8 encoded JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.

    Args:
        data: Data to save
        filepath: Output file path
        indent: JSON indentation
    """
    target_dir = Path(filepath).parent
    target_dir.mkdir(parents=True, exist_ok=True)

    with open(filepath, "w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=indent, ensure_ascii=False)
def load_json(filepath: str) -> Any:
    """Read a UTF-8 encoded JSON file and return the parsed content.

    Args:
        filepath: Input file path

    Returns:
        Loaded data
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
def save_jsonl(data: List[Dict[str, Any]], filepath: str) -> None:
    """Write each item of ``data`` as one JSON document per line (JSONL).

    Missing parent directories are created first. The file is UTF-8 encoded
    and non-ASCII characters are written unescaped.

    Args:
        data: List of dictionaries to save
        filepath: Output file path
    """
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    with open(filepath, "w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in data
        )
def load_jsonl(filepath: str) -> List[Dict[str, Any]]:
    """Read a UTF-8 encoded JSONL file, one JSON document per line.

    Lines that are empty or contain only whitespace are skipped.

    Args:
        filepath: Input file path

    Returns:
        List of dictionaries
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]
def last_boxed_only_string(text: str) -> Optional[str]:
    """Extract the content of the last \\boxed{} / \\fbox{} expression in text.

    The closing brace is located by brace counting, so nested groups such as
    \\boxed{\\frac{a}{b}} are returned whole. Occurrences are examined from the
    last one backwards; an occurrence whose braces never balance (for example
    a truncated generation) or whose content is empty is skipped in favour of
    the previous one.

    Previously a "non-nested" regex shortcut ran first, which returned an
    earlier simple box even when a later, nested box was the real final
    answer.

    Args:
        text: Text containing LaTeX \\boxed{} expressions

    Returns:
        Stripped content of the last usable \\boxed{} expression, or None if
        there is none
    """
    import re

    if not text:
        return None

    openers = list(re.finditer(r'\\(?:boxed|fbox)\{', text))

    for opener in reversed(openers):
        content_start = opener.end()
        depth = 1  # the opener's own "{" is already consumed
        for idx in range(content_start, len(text)):
            ch = text[idx]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    content = text[content_start:idx].strip()
                    if content:
                        return content
                    # Empty box: stop scanning it and try an earlier box.
                    break
        # Unbalanced or empty box: fall through to the previous occurrence.

    return None
def get_gold_answer_math(
    gold_answer_field: Optional[str] = None,
    gold_solution_text: Optional[str] = None,
) -> Optional[str]:
    """Return the gold answer for a MATH example, whichever format it is in.

    A direct answer field takes precedence and is returned stripped. Without
    one, the answer is taken from the last \\boxed{} expression of the
    solution text.

    Args:
        gold_answer_field: Direct answer field (MATH-500 HuggingFace format)
        gold_solution_text: Solution text with \\boxed{} answer (original MATH format)

    Returns:
        Extracted gold answer string, or None when neither source yields one
    """
    # MATH-500 HuggingFace format: the answer field is used as given.
    if gold_answer_field is not None:
        return gold_answer_field.strip()

    if gold_solution_text is None:
        return None

    # Original Hendrycks MATH format: the answer lives inside \boxed{}.
    boxed = last_boxed_only_string(gold_solution_text)
    return boxed if boxed else None
def extract_answer_before_hash_r1(text: str, task: str = "gsm8k") -> Optional[str]:
    """Extract the answer that appears BEFORE the last #### marker (R1 pattern).

    R1 models often put the final answer right before ####, then regenerate
    after it.

    Examples (GSM8K mode):
        - "Total time = 16 hours.\\n####" -> "16"
        - "The answer is \\boxed{72}.\\n####" -> "72"
        - "Total cost is 1,234 dollars\\n####" -> "1234"

    For MATH tasks the last \\boxed{} before the marker wins; otherwise the
    last of the final three lines that is non-empty and does not start with
    "Step" is returned (answer prefix and trailing periods removed) if it is
    shorter than 100 characters.

    For every other task the last number within the final 200 characters
    before the marker is returned with thousands separators removed.
    Previously the surrounding phrase was re-searched and its FIRST number
    returned, which truncated "1,234" to "1" and could pick an unrelated
    earlier number from the same sentence.

    Args:
        text: Generated text containing #### marker
        task: Task type (gsm8k or math/math-500)

    Returns:
        Extracted answer from before ####, or None
    """
    import re

    if '####' not in text:
        return None

    # Only the text preceding the LAST marker is considered.
    before_hash = text[:text.rfind('####')].strip()
    if not before_hash:
        return None

    if task in ("math", "math-500"):
        boxed_answer = last_boxed_only_string(before_hash)
        if boxed_answer:
            return boxed_answer

        answer_prefixes = (
            'So,', 'Therefore,', 'Thus,', 'Hence,', 'The answer is', 'Final answer:',
        )
        # Fallback: scan the last three lines, nearest to the marker first.
        for line in reversed(before_hash.split('\n')[-3:]):
            line = line.strip()
            if not line or line.startswith('Step'):
                continue
            for prefix in answer_prefixes:
                if line.startswith(prefix):
                    line = line[len(prefix):].strip()
                    break
            line = line.rstrip('.')
            if line and len(line) < 100:
                return line
        return None

    # GSM8K (and anything else): numeric answer.
    snippet = before_hash[-200:]

    # First alternative: digit groups separated by "," or a single space
    # ("1,234", "118 000"). Requiring exactly three digits per group keeps
    # unrelated neighbouring numbers (e.g. on consecutive lines) apart.
    number_pattern = (
        r'-?\d{1,3}(?:[, ]\d{3})+(?!\d)(?:\.\d+)?'
        r'|-?\d+(?:\.\d+)?'
    )
    found = re.findall(number_pattern, snippet)
    if not found:
        return None

    return re.sub(r'[, ]', '', found[-1])
def extract_answer(
    text: str,
    task: str = "gsm8k",
    use_r1_fallback: bool = False,
    gold_answer: Optional[str] = None
) -> Optional[str]:
    """Extract the final answer from generated text.

    Handles:
        - GSM8K format: "#### 36 + 3 = 39" -> "39"
        - MATH format: "\\boxed{42}" or "\\boxed{x \\in [-2,7]}" -> content of the box
        - MATH with ####: "#### $answer$" or "#### answer"
        - R1-specific: answer before #### when use_r1_fallback=True
        - Comma-separated numbers: "#### 1,234" -> "1234"
        - Currency symbols: "#### $70" -> "70"
        - Negative numbers: "#### -42" -> "-42"
        - Spaces between digits: "#### 118 000" -> "118000"
        - Text after number: "#### 70 dollars" -> "70"
        - Chained equations: "#### 36 + 3 = 39 + 1 = 40" -> "40"
        - Repetition detection (MATH tasks): repeated tails are cut off first
        - Fallback (MATH tasks): value from the last coherent "Step N:" block

    When the line after #### holds a number followed by text, grouped
    thousands ("118 000", "1,234 dollars") are read as one number. Earlier
    the digit groups were split and only the last group ("000", "234") was
    returned.

    Args:
        text: Generated text
        task: Task/dataset name (gsm8k, math-500, math); any other name is
            treated as gsm8k
        use_r1_fallback: If True and the text contains ####, also try the
            answer written before #### when normal extraction yields nothing
            or (with gold_answer given) does not match the gold answer
        gold_answer: Optional gold answer used only to validate the R1 fallback

    Returns:
        Extracted answer or None
    """
    if not text:
        return None

    import re

    # Validate and normalize task name
    task = task.lower().strip()
    if task not in ("gsm8k", "math", "math-500"):
        # Treat unknown tasks as generic numeric extraction
        task = "gsm8k"

    def remove_repetition(txt: str) -> str:
        """Cut off a degenerate repeated tail; short texts are returned as-is."""
        if len(txt) < 100:
            return txt

        # Exact repetition of a fixed-length suffix.
        for pattern_len in [20, 30, 40, 50]:
            if len(txt) < pattern_len * 3:
                continue

            pattern = txt[-pattern_len:]
            count = 1
            pos = len(txt) - pattern_len * 2
            while pos >= 0 and txt[pos:pos+pattern_len] == pattern:
                count += 1
                pos -= pattern_len

            if count >= 3:
                # pos + pattern_len is where the earliest repeated copy
                # starts, so every copy of the pattern is dropped.
                return txt[:pos + pattern_len]

        # "Step N: ####" repeated many times (common failure mode).
        if txt.count('Step ') > 10:
            matches = list(re.finditer(r'Step \d+: ####', txt))
            if len(matches) >= 5:
                return txt[:matches[0].start()]

        return txt

    def extract_last_number(txt: str) -> Optional[str]:
        """Return the last number in txt without spaces/commas, or None."""
        all_numbers = re.findall(r'(-?\d+(?:[\s,]\d+)*(?:\.\d+)?)', txt)
        if all_numbers:
            return all_numbers[-1].replace(' ', '').replace(',', '')
        return None

    def extract_from_hash_marker(txt: str) -> Optional[str]:
        """Extract the answer after the last #### marker (MATH-aware)."""
        if '####' not in txt:
            return None

        parts = txt.split('####')
        answer_part = parts[-1].strip()
        before_hash = parts[-2] if len(parts) >= 2 else ""

        # Heuristic: more than five words, not starting with a digit or "$",
        # and containing "the"/"is" (substring test) means prose, not answer.
        answer_looks_like_sentence = (
            answer_part and
            len(answer_part.split()) > 5 and
            not answer_part[0].isdigit() and
            not answer_part.startswith('$') and
            ('the' in answer_part.lower() or 'is' in answer_part.lower())
        )

        if not answer_part or len(answer_part) > 200 or answer_part.startswith('Step') or answer_looks_like_sentence:
            # Nothing usable after ####: look for an explicit answer phrase in
            # the last 300 characters before it.
            if before_hash:
                before_snippet = before_hash[-300:] if len(before_hash) > 300 else before_hash

                answer_patterns = [
                    r'Final Answer:\s*([^\n]+)',
                    r'Answer:\s*([^\n]+)',
                    r'Therefore,?\s+(?:the answer is|we get|we have)\s*:?\s*([^\n]+)',
                    r'(?:Thus|Hence|So),?\s+(?:the answer is|we get|we have)\s*:?\s*([^\n]+)',
                ]

                for pattern in answer_patterns:
                    match = re.search(pattern, before_snippet, re.IGNORECASE)
                    if match:
                        potential_answer = match.group(1).strip()
                        potential_answer = potential_answer.rstrip('.,;:')

                        # For a longer phrase prefer a leading number.
                        words = potential_answer.split()
                        if len(words) > 3:
                            number_match = re.match(r'^(-?\d+(?:\.\d+)?)', potential_answer)
                            if number_match:
                                return number_match.group(1)

                        return potential_answer

        # Continue with normal after-#### processing
        if not answer_part:
            # Nothing after the last ####: fall back to the segment before it.
            if len(parts) >= 2:
                answer_part = parts[-2].strip()
            if not answer_part:
                return None

        # A segment starting with "Step" is an incomplete generation.
        if answer_part.startswith('Step'):
            # Walk the segments that follow a ####, last first, for one that
            # does not start with "Step".
            for i in range(len(parts) - 1, 0, -1):
                candidate = parts[i].strip()
                if candidate and not candidate.startswith('Step'):
                    answer_part = candidate
                    break
            else:
                return None

        # Format 1: #### $latex_expression$
        first_line = answer_part.split('\n')[0].strip()
        if first_line.startswith('$') and first_line.count('$') >= 2:
            dollar_content = first_line[1:]  # drop the opening $
            if '$' in dollar_content:
                dollar_content = dollar_content[:dollar_content.rfind('$')]  # drop from last $
            dollar_content = dollar_content.strip()

            # For an equation like "f(x)=5" keep the right-most non-empty side.
            if '=' in dollar_content:
                eq_parts = dollar_content.split('=')
                for part in reversed(eq_parts):
                    part = part.strip()
                    if part:
                        return part

            return dollar_content

        # Format 2: #### plain answer (number or expression), first line only
        first_line = answer_part.split('\n')[0].strip()
        first_line = first_line.strip('$')

        # Remove markdown emphasis
        first_line = re.sub(r'\*\*([^*]+)\*\*', r'\1', first_line)  # **text** -> text
        first_line = re.sub(r'\*([^*]+)\*', r'\1', first_line)  # *text* -> text
        first_line = first_line.strip()

        # Remove one leading answer prefix (longer variants are listed first)
        prefixes_to_remove = [
            'The final answer is',
            'The answer is',
            'Final Answer:',
            'Final Answer',
            'Answer:',
            'Answer',
            'Therefore,',
            'Thus,',
            'So,',
            'Hence,',
        ]
        for prefix in prefixes_to_remove:
            if first_line.startswith(prefix):
                first_line = first_line[len(prefix):].strip()
                first_line = first_line.lstrip(':').strip().rstrip('.,;:')
                break

        # Drop trailing explanations
        first_line = re.split(r'\n|Explanation:|Note:|Solution:', first_line)[0].strip()

        # Several words: "3 treeks", "70 dollars", "36 + 3 = 39", a sentence...
        words = first_line.split()
        if len(words) >= 2:
            # Grouped thousands ("118 000", "1,234") count as ONE number; the
            # second alternative matches plain integers/decimals.
            numbers = re.findall(
                r'-?\d{1,3}(?:[ ,]\d{3})+(?!\d)(?:\.\d+)?|-?\d+(?:\.\d+)?',
                first_line,
            )
            if numbers:
                # The last number is taken as the answer (e.g. the result of
                # a chain of equations).
                return re.sub(r'[ ,]', '', numbers[-1])

            # No numbers: look for "X equals Y", "X = Y" or "has X units".
            if '=' in first_line or 'equals' in first_line.lower() or 'has' in first_line.lower():
                split_parts = re.split(r'=|equals|has', first_line, flags=re.IGNORECASE)
                if len(split_parts) >= 2:
                    last_part = split_parts[-1].strip().rstrip('.,;:')
                    cleaned = re.sub(r'\b(the|a|an|one|of|to|is|are)\b', '', last_part, flags=re.IGNORECASE).strip()
                    num_match = re.search(r'-?\d+(?:\.\d+)?', cleaned)
                    if num_match:
                        return num_match.group(0)
                    if cleaned and len(cleaned) < 20:
                        return cleaned

        return first_line if first_line else None

    def extract_from_last_step(txt: str) -> Optional[str]:
        """Extract a $...$ value from one of the last five "Step N:" blocks."""
        step_matches = list(re.finditer(r'Step \d+:', txt))
        if not step_matches:
            return None

        for match in reversed(step_matches[-5:]):
            start = match.end()
            # The step ends where the next "Step N:" begins, or at end of text.
            next_match_idx = step_matches.index(match) + 1
            if next_match_idx < len(step_matches):
                end = step_matches[next_match_idx].start()
            else:
                end = len(txt)

            step_content = txt[start:end].strip()

            lines = step_content.split('\n')
            for line in reversed(lines):
                line = line.strip()
                if '$' in line and not line.startswith('Step'):
                    if line.count('$') >= 2:
                        dollar_parts = line.split('$')
                        for part in reversed(dollar_parts):
                            part = part.strip()
                            if part and not part.startswith('Step'):
                                # Accept short fragments that look like math:
                                # an equation, a LaTeX command, or alphanumeric
                                # text once . - , space / are ignored.
                                if len(part) < 50 and ('=' in part or '\\' in part or part.replace('.', '').replace('-', '').replace(',', '').replace(' ', '').replace('/', '').isalnum()):
                                    if '=' in part:
                                        after_eq = part.split('=')[-1].strip()
                                        if after_eq:
                                            return after_eq
                                    return part

        return None

    if task in ("math-500", "math"):
        # Remove repetitive patterns first
        text = remove_repetition(text)

        # Strategy 1: \boxed{} notation (official MATH format, highest priority)
        boxed_answer = last_boxed_only_string(text)
        if boxed_answer:
            result = boxed_answer
        # Strategy 2: #### marker (GSM8K-style, but models sometimes use it)
        elif '####' in text:
            result = extract_from_hash_marker(text) or None
        # Strategy 3: short text without "Step" is treated as a bare answer
        # (e.g. a gold answer such as "p - q") and returned stripped.
        elif not ("Step" in text or len(text) > 200):
            result = text.strip()
        else:
            # Strategy 4: last coherent step. For MATH there is deliberately
            # no fallback to the last numeric token.
            result = extract_from_last_step(text) or None

        # R1 fallback for MATH: try the answer written before ####.
        if use_r1_fallback and '####' in text:
            should_try_r1 = False
            if result is None:
                should_try_r1 = True
            elif gold_answer is not None and not evaluate_math_answer(result, gold_answer):
                should_try_r1 = True

            if should_try_r1:
                r1_answer = extract_answer_before_hash_r1(text, task)
                if r1_answer and (gold_answer is None or evaluate_math_answer(r1_answer, gold_answer)):
                    return r1_answer

        return result

    if task == "gsm8k":
        # GSM8K format: answer is after ####
        from_hash = extract_from_hash_marker(text)

        if from_hash is None:
            # No #### found or extraction failed: use the last number.
            result = extract_last_number(text)
        else:
            # Enforce "single number" semantics via normalization.
            num = normalize_answer(from_hash)
            result = num if num != "" else extract_last_number(text)

        # R1 fallback for GSM8K: try the answer written before ####.
        if use_r1_fallback and '####' in text:
            should_try_r1 = False
            if result is None:
                should_try_r1 = True
            elif gold_answer is not None and not evaluate_answer(result, gold_answer):
                should_try_r1 = True

            if should_try_r1:
                r1_answer = extract_answer_before_hash_r1(text, task)
                if r1_answer and (gold_answer is None or evaluate_answer(r1_answer, gold_answer)):
                    return r1_answer

        return result

    # Not reached with the task normalization above; kept as a generic fallback.
    return extract_last_number(text)
def normalize_answer(answer: str) -> str:
    """Normalize a (mostly numeric) answer for comparison.

    Leading currency symbols are dropped and the number at the start of the
    answer is extracted, so trailing text such as "70 dollars" is ignored.
    Thousands separators are removed, including space-separated groups
    ("1 000" -> "1000", previously truncated to "1"). Integral values are
    rendered without a decimal part ("72.0" -> "72"). Plain integers are
    converted exactly instead of through ``float``, so long integers keep all
    their digits.

    Answers that do not start with a number are returned with whitespace and
    commas removed.

    Args:
        answer: Answer string

    Returns:
        Normalized answer string ("" for an empty answer)
    """
    if not answer:
        return ""

    import re

    # Remove leading currency symbols
    answer = re.sub(r'^[\$£€¥₹]+\s*', '', answer.strip())

    # Prefer a properly grouped number ("1 000", "1,234,567.5"); otherwise
    # fall back to the permissive digits-and-commas form.
    number_match = (
        re.match(r'-?\d{1,3}(?:[ ,]\d{3})+(?![\d,])(?:\.\d+)?', answer)
        or re.match(r'-?[\d,]+(?:\.\d+)?', answer)
    )
    if number_match:
        answer = number_match.group(0)

    # Remove remaining spaces and commas
    answer = re.sub(r'[\s,]', '', answer)

    try:
        if re.fullmatch(r'-?\d+', answer):
            # Exact conversion; also drops leading zeros.
            return str(int(answer))
        num = float(answer)
    except (ValueError, TypeError):
        return answer.strip()

    if num.is_integer():
        return str(int(num))
    return str(num)
def answers_match(answer1: str, answer2: str) -> bool:
    """Tell whether two answers are equal once both are normalized.

    Args:
        answer1: First answer
        answer2: Second answer

    Returns:
        True if answers match after normalization
    """
    left = normalize_answer(answer1)
    right = normalize_answer(answer2)
    return left == right
def evaluate_answer(
    predicted: str,
    ground_truth: str,
    tolerance: float = 1e-5,
) -> bool:
    """Decide whether a predicted answer equals the ground truth.

    Both answers are compared after normalization via ``answers_match``. An
    empty or missing answer on either side never matches.

    Args:
        predicted: Predicted answer
        ground_truth: Ground truth answer
        tolerance: Unused; kept for API compatibility

    Returns:
        True if answers match
    """
    if predicted and ground_truth:
        return answers_match(predicted, ground_truth)
    return False
def math_normalize_answer(answer: str) -> str:
    """Normalize a MATH dataset answer (possibly LaTeX) for comparison.

    Steps, in order:
        - strip, and remove math-mode markers: $, \\[ \\], \\( \\)
        - normalize whitespace around "," and "=" and collapse runs of spaces
        - rewrite \\frac{a}{b} (also \\dfrac, \\tfrac) to frac(a,b)
        - map a fixed set of LaTeX commands to plain tokens (\\leq -> <=, ...)
        - unescape \\{ and \\}; drop spaces just inside brackets/parentheses
        - reduce frac(p,q) with integer p and q to lowest terms
        - normalize numbers: remove thousands commas, "4.0" -> "4",
          decimals rounded to 6 places

    Fixes compared to the earlier implementation:
        - commands are matched as whole LaTeX tokens, so \\leq no longer
          becomes "<=q" (likewise \\geq, \\neq, \\subseteq, \\int)
        - the comma inside frac(a,b) is no longer mistaken for a thousands
          separator (\\frac{1}{200} used to become "frac(1200)")
        - plain integers are converted exactly rather than through float

    Known limitation: outside fractions, "a,bcd" with a three-digit second
    part is read as one grouped number, so "(1,200)" becomes "(1200)".

    Args:
        answer: Answer string (may contain LaTeX)

    Returns:
        Normalized answer string ("" for an empty answer)
    """
    if not answer:
        return ""

    import math
    import re

    # Placeholder for the comma of frac(a,b) while numbers are normalized.
    frac_sep = '\x00'

    answer = answer.strip().replace(frac_sep, '')

    # Remove LaTeX math-mode markers
    answer = answer.replace('$', '')
    for marker in ('\\[', '\\]', '\\(', '\\)'):
        answer = answer.replace(marker, '')

    # Normalize whitespace around common delimiters
    answer = re.sub(r'\s*,\s*', ',', answer)  # "a, b" -> "a,b"
    answer = re.sub(r'\s*=\s*', '=', answer)  # "x = 5" -> "x=5"
    answer = re.sub(r'\s+', ' ', answer)      # runs of whitespace -> one space

    # \frac{a}{b} -> frac(a<sep>b); only brace-free arguments are handled.
    def normalize_frac(match):
        num = match.group(1).strip()
        den = match.group(2).strip()
        return f"frac({num}{frac_sep}{den})"

    answer = re.sub(r'\\[dt]?frac\{([^}]+)\}\{([^}]+)\}', normalize_frac, answer)

    # Command name (without backslash) -> replacement
    latex_commands = {
        'sqrt': 'sqrt',
        'pi': 'pi',
        'infty': 'inf',
        'in': ' in ',
        'cup': ' cup ',
        'cap': ' cap ',
        'subset': ' subset ',
        'subseteq': ' subseteq ',
        'emptyset': 'emptyset',
        'varnothing': 'emptyset',
        'left': '',
        'right': '',
        'ldots': '...',
        'cdots': '...',
        'le': '<=',
        'ge': '>=',
        'leq': '<=',
        'geq': '>=',
        'ne': '!=',
        'neq': '!=',
        'times': '*',
        'cdot': '*',
        'div': '/',
    }
    # A LaTeX command name ends at the first non-letter, hence the lookahead;
    # longest names first so "leq" is never read as "le" + "q".
    names = sorted(latex_commands, key=len, reverse=True)
    command_re = re.compile(r'\\(' + '|'.join(names) + r')(?![A-Za-z])')
    answer = command_re.sub(lambda m: latex_commands[m.group(1)], answer)

    # Escaped braces
    answer = answer.replace('\\{', '{')
    answer = answer.replace('\\}', '}')

    # Remove spaces just inside brackets and parentheses
    answer = re.sub(r'\[\s*', '[', answer)
    answer = re.sub(r'\s*\]', ']', answer)
    answer = re.sub(r'\(\s*', '(', answer)
    answer = re.sub(r'\s*\)', ')', answer)

    # Reduce integer fractions to lowest terms
    def simplify_frac(match):
        try:
            num = int(match.group(1))
            den = int(match.group(2))
        except ValueError:
            return match.group(0)
        divisor = math.gcd(num, den)
        if divisor == 0:  # 0/0: leave untouched
            return match.group(0)
        return f"frac({num // divisor}{frac_sep}{den // divisor})"

    answer = re.sub(
        r'frac\((-?\d+)' + re.escape(frac_sep) + r'(-?\d+)\)',
        simplify_frac,
        answer,
    )

    # Normalize numbers: remove thousands commas, standardize decimals
    def normalize_number(match):
        num_str = match.group(0).replace(',', '')
        try:
            if '.' not in num_str:
                return str(int(num_str))  # exact for arbitrarily long integers
            num = float(num_str)
        except ValueError:
            return match.group(0)
        if num.is_integer():
            return str(int(num))
        # Round to 6 decimal places to avoid floating point noise
        return f"{num:.6f}".rstrip('0').rstrip('.')

    answer = re.sub(
        r'-?\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|-?\d+(?:\.\d+)?',
        normalize_number,
        answer,
    )

    # Restore the fraction comma and tidy whitespace introduced by
    # replacements such as " in ".
    answer = answer.replace(frac_sep, ',')
    answer = re.sub(r'\s+', ' ', answer).strip()

    return answer
def canonical_non_numeric(ans: str) -> Optional[str]:
    """Map a verbal, non-numeric answer to a canonical label.

    Recognized families and their labels:
        - "no solution(s)", "no real solution(s)"            -> "NO_SOLUTION"
        - "does not exist", "dne", "undefined"               -> "DNE"
        - "infinitely many", "infinite (number of) solutions" -> "INFINITELY_MANY"
        - "all real numbers", "any real number", "all reals",
          "for all real x"                                   -> "ALL_REALS"

    Matching is case-insensitive and ignores . , ; : ! ? punctuation.
    Keywords must appear as whole words; a plain substring test used to
    classify e.g. "Wednesday" as "DNE" because it contains "dne".

    Args:
        ans: Answer string

    Returns:
        Canonical label if recognized, None otherwise (including for an
        empty or missing answer)
    """
    import re

    if not ans:
        return None

    # Lowercase and turn punctuation/whitespace runs into single spaces.
    a = re.sub(r'[.,;:!?\s]+', ' ', ans.strip().lower()).strip()

    # Checked in this order; the first family with a hit wins.
    keyword_families = (
        ("NO_SOLUTION", ("no solution", "no real solution", "no real solutions", "no solutions")),
        ("DNE", ("does not exist", "dne", "undefined")),
        ("INFINITELY_MANY", ("infinitely many", "infinite number of solutions", "infinite solutions")),
        ("ALL_REALS", ("all real numbers", "any real number", "for all real x", "all reals")),
    )

    for label, keywords in keyword_families:
        if any(re.search(r'\b' + re.escape(kw) + r'\b', a) for kw in keywords):
            return label

    return None
def math_answers_match(answer1: str, answer2: str) -> bool:
    """Check whether two MATH dataset answers match after normalization.

    Strategies, in order:
        0. Canonical verbal answers ("no solution", "DNE", ...): if either
           side maps to a label, the result is whether both labels are equal
        1. Exact string match after normalization
        2. Numeric equality (within 1e-6) when both sides are plain numbers
        3. SymPy-based symbolic equivalence, when SymPy is installed
        4. Whitespace- and case-insensitive string comparison

    A normalized string that still contains a ``frac(`` token after
    ``frac(a,b)`` has been rewritten to a division is not handed to SymPy.
    SymPy would read it as its fractional-part function, which made
    unrelated values compare equal (both evaluating to 0).

    Args:
        answer1: First answer
        answer2: Second answer

    Returns:
        True if answers match after normalization
    """
    if not answer1 or not answer2:
        return False

    import re

    norm1 = math_normalize_answer(answer1)
    norm2 = math_normalize_answer(answer2)

    # 0. Non-numeric canonical forms (checked before everything else)
    canon1 = canonical_non_numeric(norm1)
    canon2 = canonical_non_numeric(norm2)
    if canon1 is not None or canon2 is not None:
        return canon1 == canon2

    # 1. Direct string match after normalization
    if norm1 == norm2:
        return True

    # 2. Plain numbers (handles e.g. "4" vs "4.0")
    plain_number = r'-?\d+(?:\.\d+)?'
    if re.fullmatch(plain_number, norm1) and re.fullmatch(plain_number, norm2):
        return abs(float(norm1) - float(norm2)) < 1e-6

    # 3. SymPy-based symbolic equivalence (optional dependency)
    try:
        import sympy as sp
    except ImportError:
        sp = None

    if sp is not None:
        def try_sympy_parse(expr: str):
            """Parse expr with SymPy; return None when that is not possible."""
            # Convert frac(a,b) back to (a)/(b)
            expr_clean = re.sub(r'frac\(([^,]+),([^)]+)\)', r'(\1)/(\2)', expr)
            if 'frac(' in expr_clean:
                # Malformed/nested fraction: never let SymPy interpret it as
                # its fractional-part function.
                return None
            expr_clean = re.sub(r'\s+', ' ', expr_clean)

            # Names passed as locals instead of being injected into the string
            locals_dict = {
                "sqrt": sp.sqrt,
                "pi": sp.pi,
                "inf": sp.oo,
                "e": sp.E,
            }
            try:
                # SECURITY NOTE(review): sympify evaluates the string with
                # eval internally. Answers come from model output, so this
                # must not run on untrusted text outside a sandbox; consider
                # a restricted parser.
                return sp.sympify(expr_clean, locals=locals_dict, rational=True)
            except Exception:
                # Parsing arbitrary text can fail in many ways; any failure
                # simply means "not comparable symbolically".
                return None

        expr1 = try_sympy_parse(norm1)
        expr2 = try_sympy_parse(norm2)

        if expr1 is not None and expr2 is not None:
            try:
                if sp.simplify(expr1 - expr2) == 0:
                    return True

                if sp.simplify(expr1) == sp.simplify(expr2):
                    return True

                # Two rationals that reach this point are decided here.
                if isinstance(expr1, sp.Rational) and isinstance(expr2, sp.Rational):
                    return expr1 == expr2
            except Exception:
                # Non-arithmetic objects (tuples, relations, ...) cannot be
                # subtracted or simplified; fall through to string comparison.
                pass

    # 4. Equal up to whitespace and letter case,
    #    e.g. "x in [-2,7]" vs "x in [ -2 , 7 ]"
    compact1 = re.sub(r'\s+', '', norm1).lower()
    compact2 = re.sub(r'\s+', '', norm2).lower()
    return compact1 == compact2
def evaluate_math_answer(
    predicted: str,
    ground_truth: str,
) -> bool:
    """Decide whether a predicted MATH answer agrees with the reference

    An empty (falsy) value on either side is treated as a mismatch without
    any comparison being attempted; otherwise the verdict is whatever
    ``math_answers_match`` returns for the pair.

    Args:
        predicted: Predicted answer
        ground_truth: Ground truth answer

    Returns:
        True if answers match
    """
    # Only compare when both sides actually carry an answer
    if predicted and ground_truth:
        return math_answers_match(predicted, ground_truth)
    return False
def format_prompt(
    question: str,
    template: str = "default",
    few_shot_examples: Optional[List[Dict[str, str]]] = None,
) -> str:
    """Format prompt for model input

    Builds the prompt body from the selected template and, when few-shot
    examples are supplied, prepends them as "Question: ... / Answer: ..."
    pairs in front of that body.

    Args:
        question: Input question
        template: Prompt template to use. Recognized names are "default",
            "cot", "cot_final", "math_cot" and "direct"; any other value
            yields the bare question, exactly like "direct".
        few_shot_examples: Optional few-shot examples. Each entry is read
            through its "question" and "answer" keys.

    Returns:
        Formatted prompt
    """
    if template == "default":
        # Minimal question/answer framing with no extra instructions
        prompt = f"Question: {question}\n\nAnswer:"
    elif template == "cot":
        # Structured prompt with clear instructions for GSM8K format
        # (numbered "Step N:" lines, final answer after "####")
        prompt = f"""You are a helpful assistant that solves problems step by step with each step signified by "Step [step_number]: ".
Always provide your final answer after #### at the end.
Question: {question}
Please solve this step by step, putting each step after "Step [step_number]: " and always provide your final answer after ####.
Solution:
"""
    elif template == "cot_final":
        # Chain of thought with "Final Answer:" marker
        prompt = f"""You are a helpful assistant that solves problems step by step with each step signified by "Step [step_number]: ".
Always provide your final answer after "Final Answer:" at the end.
Question: {question}
Please solve this step by step, putting each step after "Step [step_number]: " and always provide your final answer after "Final Answer:".
Solution:
"""
    elif template == "math_cot":
        # Chain of thought for MATH dataset with \boxed{} formatting
        # The doubled backslash and doubled braces below are f-string escapes
        # that render as a literal \boxed{...} in the prompt.
        prompt = f"""You are a helpful assistant that solves olympiad-style math problems step by step.
At the end, ALWAYS put your final answer in LaTeX form inside \\boxed{{...}} on its own line.
Question: {question}
Please solve this step by step, and put ONLY the final result (no explanation) inside \\boxed{{...}} on the last line.
Solution:
"""
    elif template == "direct":
        # Pass the question through unchanged
        prompt = question
    else:
        # Unrecognized template names behave like "direct"
        prompt = question
    # Add few-shot examples if provided
    # They are prepended to whichever prompt was built above and always use
    # the "Question:/Answer:" layout, regardless of the selected template.
    if few_shot_examples:
        examples_text = ""
        for ex in few_shot_examples:
            examples_text += f"Question: {ex['question']}\nAnswer: {ex['answer']}\n\n"
        prompt = examples_text + prompt
    return prompt
def calculate_metrics(
    predictions: List[Optional[str]],
    ground_truths: List[Optional[str]],
    task: str = "gsm8k",
) -> Dict[str, float]:
    """Compute accuracy metrics while tracking answer-extraction failures

    Three rates are reported so formatting quality can be told apart from
    reasoning quality:
    - accuracy: correct / valid examples
    - conditional_accuracy: correct / examples with an extracted prediction
    - extraction_rate: examples with an extracted prediction / valid examples

    An example is "valid" when its ground truth is not None. Every rate is
    0.0 when its denominator is zero.

    Args:
        predictions: List of predicted answers (None if extraction failed)
        ground_truths: List of ground truth answers (None entries are
            excluded from all counts except "total_prompts")
        task: Task format; "math" and "math-500" are scored with
            evaluate_math_answer, any other value with evaluate_answer

    Returns:
        Dictionary with the three rates plus the raw counts "correct",
        "extractable", "total" (valid examples) and "total_prompts"

    Raises:
        ValueError: If the two input lists differ in length
    """
    total_prompts = len(predictions)
    if total_prompts != len(ground_truths):
        raise ValueError("Number of predictions and ground truths must match")

    def _ratio(numerator: int, denominator: int) -> float:
        # Report 0.0 instead of dividing by zero when nothing was counted
        return numerator / denominator if denominator > 0 else 0.0

    # Examples whose gold answer is missing are dropped from every denominator
    scorable = [
        (pred, gt)
        for pred, gt in zip(predictions, ground_truths)
        if gt is not None
    ]
    # A None prediction marks an extraction failure, not a wrong answer
    answered = [(pred, gt) for pred, gt in scorable if pred is not None]

    use_math_rules = task in ("math", "math-500")
    n_correct = 0
    for pred, gt in answered:
        # The evaluator is resolved per pair, so neither function is touched
        # when there is nothing to score
        if use_math_rules:
            verdict = evaluate_math_answer(pred, gt)
        else:
            verdict = evaluate_answer(pred, gt)
        if verdict:
            n_correct += 1

    valid_total = len(scorable)
    n_extractable = len(answered)
    return {
        "accuracy": _ratio(n_correct, valid_total),
        "conditional_accuracy": _ratio(n_correct, n_extractable),
        "extraction_rate": _ratio(n_extractable, valid_total),
        "correct": n_correct,
        "extractable": n_extractable,
        "total": valid_total,  # excludes examples with a None gold answer
        "total_prompts": total_prompts,  # every prompt, valid or not
    }
def print_metrics(metrics: Dict[str, Any], title: str = "Results") -> None:
    """Print a metrics dictionary as a titled, fixed-width report

    The report is framed by 50-character "=" rules with the title centred
    between the first two. Each metric name is left-aligned and dot-padded
    to 30 characters; float values are shown with four decimals, everything
    else with its default formatting.

    Args:
        metrics: Dictionary of metrics
        title: Title for output
    """
    rule = "=" * 50

    # Header: blank line, rule, centred title, rule
    print(f"\n{rule}")
    print(f"{title:^50}")
    print(rule)

    for name, val in metrics.items():
        # Only true floats get fixed precision; ints and other values print as-is
        rendered = f"{val:.4f}" if isinstance(val, float) else f"{val}"
        print(f"{name:.<30} {rendered}")

    # Footer rule followed by a trailing blank line
    print(f"{rule}\n")
def create_directory_structure(base_dir: Optional[str] = None) -> None:
    """Create the standard project folders under a base directory

    Directories that already exist are left untouched, and a confirmation
    line naming the base directory is printed once all folders exist.

    Args:
        base_dir: Base directory (defaults to current directory)
    """
    root = Path.cwd() if base_dir is None else Path(base_dir)

    # Relative paths of every folder in the layout, created in this order
    layout = (
        "data",
        "data/gsm8k",
        "models",
        "models/cache",
        "output",
        "output/results",
        "output/trajectories",
        "output/checkpoints",
        "output/logs",
        "output/cache",
        "config",
        "src",
        "examples",
    )
    for relative in layout:
        # parents/exist_ok make the call safe to repeat on an existing tree
        (root / relative).mkdir(parents=True, exist_ok=True)

    print(f"Created directory structure at {root}")