Mustafa Tag Eldeen
HF Space: Reasoning Trajectory Demo
8675765
"""Utility functions"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pathlib import Path
import json
import random
import numpy as np
def set_seed(seed: int) -> None:
    """Seed the available random number generators with one value.

    Python's ``random`` module and NumPy's global generator are always seeded.
    If PyTorch can be imported, its generator is seeded too, along with every
    CUDA generator when CUDA reports as available. A missing PyTorch install
    is ignored.

    Args:
        seed: Random seed
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    try:
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        # PyTorch is optional: without it there is nothing more to seed.
        pass
def save_json(data: Any, filepath: str, indent: int = 2) -> None:
    """Write ``data`` to ``filepath`` as UTF-8 encoded JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.

    Args:
        data: Data to save
        filepath: Output file path
        indent: JSON indentation
    """
    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=indent, ensure_ascii=False)
def load_json(filepath: str) -> Any:
    """Read a UTF-8 encoded JSON file and return the decoded value.

    Args:
        filepath: Input file path

    Returns:
        Loaded data
    """
    with open(filepath, encoding="utf-8") as handle:
        raw = handle.read()
    return json.loads(raw)
def save_jsonl(data: List[Dict[str, Any]], filepath: str) -> None:
    """Write each item of ``data`` as one JSON document per line (UTF-8).

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.

    Args:
        data: List of dictionaries to save
        filepath: Output file path
    """
    destination = Path(filepath)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in data
        )
def load_jsonl(filepath: str) -> List[Dict[str, Any]]:
    """Read a UTF-8 JSONL file, decoding every non-blank line as JSON.

    Lines that are empty or contain only whitespace are skipped.

    Args:
        filepath: Input file path

    Returns:
        List of decoded records, in file order
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]
def last_boxed_only_string(text: str) -> Optional[str]:
    """Extract the last \\boxed{} expression from text (Hendrycks MATH standard)

    Both \\boxed{...} and \\fbox{...} are recognized. The content is located by
    brace counting, so nested braces such as \\boxed{\\frac{a}{b}} are handled.

    The search starts at the LAST opener in the text. If that box is unclosed
    or empty, the previous opener is tried, and so on. An earlier flat box
    therefore never shadows a later nested one.

    Args:
        text: Text containing LaTeX \\boxed{} expressions

    Returns:
        Stripped content of the last well-formed, non-empty \\boxed{} / \\fbox{}
        expression, or None if there is none
    """
    import re

    openers = list(re.finditer(r'\\(?:boxed|fbox)\{', text))

    # Walk backwards so the final answer wins; fall back to earlier boxes only
    # when a later one is truncated (unbalanced) or empty.
    for opener in reversed(openers):
        start_pos = opener.end()
        depth = 1  # the opener's own "{" is already consumed
        for i in range(start_pos, len(text)):
            ch = text[i]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    content = text[start_pos:i].strip()
                    if content:
                        return content
                    break  # empty box: try the previous opener
    return None
def get_gold_answer_math(
    gold_answer_field: Optional[str] = None,
    gold_solution_text: Optional[str] = None,
) -> Optional[str]:
    """Resolve the gold answer for a MATH example from whichever field exists.

    A direct answer field takes precedence and is returned stripped. Otherwise
    the answer is taken from the last \\boxed{} expression of the solution
    text.

    Args:
        gold_answer_field: Direct answer field (MATH-500 HuggingFace format)
        gold_solution_text: Solution text with \\boxed{} answer (original MATH format)

    Returns:
        Gold answer string, or None when neither source yields one
    """
    # MATH-500 HuggingFace format: the answer is given directly.
    if gold_answer_field is not None:
        return gold_answer_field.strip()

    if gold_solution_text is None:
        return None

    # Original Hendrycks MATH format: the answer is boxed inside the solution.
    boxed = last_boxed_only_string(gold_solution_text)
    return boxed if boxed else None
def extract_answer_before_hash_r1(text: str, task: str = "gsm8k") -> Optional[str]:
    """Extract answer that appears BEFORE #### marker (R1-specific pattern)

    R1 models often put the final answer right before ####, then regenerate after.

    Examples:
        - "Total time = 16 hours.\\n####" → "16"
        - "The answer is \\boxed{72}.\\n####" → "72"
        - "Total = 1,234 dollars.\\n####" → "1234"

    For MATH tasks the last \\boxed{} before the marker is preferred. Failing
    that, the last usable line (among the final three) is returned with a
    leading answer phrase and trailing period removed.

    For every other task the LAST number found in the 200 characters preceding
    the marker is returned. A comma or single space counts as part of a number
    only when it separates groups of exactly three digits, so "1,234" and
    "118 000" are kept whole while "12 + 5" or "3, 4" are not merged.

    Args:
        text: Generated text containing #### marker
        task: Task type (gsm8k or math/math-500)

    Returns:
        Extracted answer from before ####, or None
    """
    import re

    # Only the text preceding the LAST marker is considered.
    last_hash_idx = text.rfind('####')
    if last_hash_idx == -1:
        return None
    before_hash = text[:last_hash_idx].strip()
    if not before_hash:
        return None

    if task in ("math", "math-500"):
        # MATH: look for \boxed{} before ####
        boxed_answer = last_boxed_only_string(before_hash)
        if boxed_answer:
            return boxed_answer

        # Fallback: scan the last three lines, bottom-up, for a short answer.
        lead_ins = ('So,', 'Therefore,', 'Thus,', 'Hence,', 'The answer is', 'Final answer:')
        for line in reversed(before_hash.split('\n')[-3:]):
            line = line.strip()
            if not line or line.startswith('Step'):
                continue
            # Drop at most one leading answer phrase.
            for prefix in lead_ins:
                if line.startswith(prefix):
                    line = line[len(prefix):].strip()
                    break
            line = line.rstrip('.')
            # Anything of 100+ characters is treated as prose, not an answer.
            if line and len(line) < 100:
                return line
        return None

    # GSM8K (and any other task): numeric answer.
    # Limit the scan to the tail so an early number is not picked up.
    snippet = before_hash[-200:]

    # Either a thousands-grouped number ("1,234" / "118 000") or a plain run
    # of digits, each with an optional sign and decimal part. The (?!\d) guard
    # stops "1,2345" from being read as "1,234" followed by "5".
    number_pattern = r'-?(?:\d{1,3}(?:[, ]\d{3})+(?!\d)|\d+)(?:\.\d+)?'
    found = re.findall(number_pattern, snippet)
    if not found:
        return None

    # The number closest to the marker is the answer; drop group separators.
    return found[-1].replace(',', '').replace(' ', '')
def extract_answer(
    text: str,
    task: str = "gsm8k",
    use_r1_fallback: bool = False,
    gold_answer: Optional[str] = None
) -> Optional[str]:
    """Extract final answer from generated text (ROBUST version with MATH-specific improvements)

    Handles:
        - GSM8K format: "#### 36 + 3 = 39" → "39"
        - MATH format: "\\boxed{42}" or "\\boxed{x \\in [-2,7]}" → "42" or "x \\in [-2,7]"
        - MATH with ####: "#### $answer$" or "#### answer"
        - R1-specific: Answer before #### when use_r1_fallback=True
        - Comma-separated numbers: "#### 1,234" → "1234"
        - Currency symbols: "#### $70" → "70"
        - Negative numbers: "#### -42" → "-42"
        - Spaces between digits: "#### 118 000" → "118000"
        - Text after number: "#### 70 dollars" → "70"
        - Recursive operators: "#### 36 + 3 = 39 + 1 = 40" → "40"
        - Repetition detection: Stops before repetitive patterns
        - Fallback: Extracts from last coherent step

    Strategy order for MATH tasks: last \\boxed{}, then the #### marker, then
    the text itself when it is short and has no "Step" (gold-answer case), then
    the last coherent step. MATH never falls back to "last number in the text".
    For GSM8K (and unknown tasks, which are treated as GSM8K) the #### marker
    is used first and the last number in the text is the fallback.

    Args:
        text: Generated text
        task: Task/dataset name (gsm8k, math-500, math, etc.)
        use_r1_fallback: If True, try extracting answer before #### when after-#### extraction
            fails or is incorrect (for R1 models)
        gold_answer: Optional gold answer for R1 fallback validation

    Returns:
        Extracted answer or None
    """
    if not text:
        return None

    import re

    # Validate and normalize task name
    task = task.lower().strip()
    if task not in ("gsm8k", "math", "math-500"):
        # Treat unknown tasks as generic numeric extraction
        task = "gsm8k"

    # Number with optional sign and decimals. A comma or single space is part
    # of the number only when it separates groups of exactly three digits
    # ("1,234", "118 000"); "36 + 3" or "3, 4" stay separate numbers.
    grouped_number = r'-?(?:\d{1,3}(?:[, ]\d{3})+(?!\d)|\d+)(?:\.\d+)?'

    # Helper function: Detect repetitive pattern and truncate
    def remove_repetition(txt: str) -> str:
        """Remove repetitive patterns from end of text"""
        if len(txt) < 100:
            return txt

        # Check for exact repetition of substrings at the end
        for pattern_len in [20, 30, 40, 50]:
            if len(txt) < pattern_len * 3:
                continue

            pattern = txt[-pattern_len:]
            # Count how many times this pattern repeats at the end
            count = 1
            pos = len(txt) - pattern_len * 2
            while pos >= 0 and txt[pos:pos+pattern_len] == pattern:
                count += 1
                pos -= pattern_len

            if count >= 3:  # If pattern repeats 3+ times, truncate
                # pos now sits one chunk before the first repeated copy, so
                # this cuts off every copy of the repeated chunk.
                return txt[:pos + pattern_len]

        # Check for "Step N: ####" repetition (common failure mode)
        if txt.count('Step ') > 10:
            # Find where "Step N: ####" pattern starts repeating
            matches = list(re.finditer(r'Step \d+: ####', txt))
            if len(matches) >= 5:  # If we see this pattern 5+ times
                # Truncate at the first occurrence of this pattern
                return txt[:matches[0].start()]

        return txt

    # Helper function: Extract last number from text as fallback
    def extract_last_number(txt: str) -> Optional[str]:
        all_numbers = re.findall(r'(-?\d+(?:[\s,]\d+)*(?:\.\d+)?)', txt)
        if all_numbers:
            return all_numbers[-1].replace(' ', '').replace(',', '')
        return None

    # Helper function: Extract answer from #### marker (MATH-aware)
    def extract_from_hash_marker(txt: str) -> Optional[str]:
        """Extract answer after #### marker, handling MATH-specific formats"""
        if '####' not in txt:
            return None

        # Split by #### and get the LAST occurrence (most recent answer)
        parts = txt.split('####')
        answer_part = parts[-1].strip()
        before_hash = parts[-2] if len(parts) >= 2 else ""

        # Check if answer_part is actually useful
        # It's not useful if: empty, too long, starts with "Step", or is a full sentence without clear answer
        answer_looks_like_sentence = (
            answer_part and
            len(answer_part.split()) > 5 and  # More than 5 words
            not answer_part[0].isdigit() and  # Doesn't start with a number
            not answer_part.startswith('$') and  # Doesn't start with LaTeX
            ('the' in answer_part.lower() or 'is' in answer_part.lower())  # Contains common sentence words
        )

        if not answer_part or len(answer_part) > 200 or answer_part.startswith('Step') or answer_looks_like_sentence:
            # If nothing useful after ####, or it's too long (likely paragraph), or starts with Step, or looks like a sentence
            # Try to extract from the text BEFORE #### (look for "Final Answer:" or similar)
            if before_hash:
                # Look for common answer indicators in the last 300 chars before ####
                before_snippet = before_hash[-300:] if len(before_hash) > 300 else before_hash

                # Try to find "Final Answer:", "Answer:", etc.
                answer_patterns = [
                    r'Final Answer:\s*([^\n]+)',
                    r'Answer:\s*([^\n]+)',
                    r'Therefore,?\s+(?:the answer is|we get|we have)\s*:?\s*([^\n]+)',
                    r'(?:Thus|Hence|So),?\s+(?:the answer is|we get|we have)\s*:?\s*([^\n]+)',
                ]

                for pattern in answer_patterns:
                    match = re.search(pattern, before_snippet, re.IGNORECASE)
                    if match:
                        potential_answer = match.group(1).strip()
                        # Clean up the extracted answer
                        potential_answer = potential_answer.rstrip('.,;:')
                        # Extract just the number/value if it's a sentence
                        words = potential_answer.split()
                        if len(words) > 3:  # If it's a phrase, extract the first number/value
                            # Look for a number at the start
                            number_match = re.match(r'^(-?\d+(?:\.\d+)?)', potential_answer)
                            if number_match:
                                return number_match.group(1)
                        return potential_answer

        # Continue with normal after-#### processing
        if not answer_part:
            # If last #### has nothing after it, try second-to-last
            if len(parts) >= 2:
                answer_part = parts[-2].strip()
            if not answer_part:
                return None

        # Skip if it starts with "Step" (incomplete generation)
        if answer_part.startswith('Step'):
            # Try to find a #### that's NOT followed by Step
            for i in range(len(parts) - 1, 0, -1):
                candidate = parts[i].strip()
                if candidate and not candidate.startswith('Step'):
                    answer_part = candidate
                    break
            else:
                # Every segment after a marker is empty or starts with "Step".
                return None

        # Format 1: #### $latex_expression$
        # Extract content between first $ and last $ on same line
        first_line = answer_part.split('\n')[0].strip()
        if first_line.startswith('$') and first_line.count('$') >= 2:
            # Extract content between first and last $
            dollar_content = first_line[1:]  # Remove first $
            if '$' in dollar_content:
                dollar_content = dollar_content[:dollar_content.rfind('$')]  # Remove last $
            dollar_content = dollar_content.strip()

            # If it's an equation like "f(x)=5", extract the value after =
            if '=' in dollar_content:
                # Split by = and get the rightmost part
                eq_parts = dollar_content.split('=')
                # Get last non-empty part
                for part in reversed(eq_parts):
                    part = part.strip()
                    if part:
                        return part

            return dollar_content

        # Format 2: #### plain answer (number or expression)
        # Take first line only, clean it
        first_line = answer_part.split('\n')[0].strip()

        # Remove trailing/leading $ if present
        first_line = first_line.strip('$')

        # Remove markdown formatting (**, *, etc.)
        first_line = re.sub(r'\*\*([^*]+)\*\*', r'\1', first_line)  # **text** → text
        first_line = re.sub(r'\*([^*]+)\*', r'\1', first_line)  # *text* → text
        first_line = first_line.strip()

        # Remove common prefixes that indicate an answer
        prefixes_to_remove = [
            'The final answer is',
            'The answer is',
            'Final Answer:',
            'Final Answer',
            'Answer:',
            'Answer',
            'Therefore,',
            'Thus,',
            'So,',
            'Hence,',
        ]
        for prefix in prefixes_to_remove:
            if first_line.startswith(prefix):
                first_line = first_line[len(prefix):].strip()
                # Remove trailing punctuation and colons after removing prefix
                first_line = first_line.lstrip(':').strip().rstrip('.,;:')
                break

        # Remove common suffixes like explanations
        first_line = re.split(r'\n|Explanation:|Note:|Solution:', first_line)[0].strip()

        # If what remains is a sentence (contains many words) OR has text after a number, extract just the number
        # This handles cases like "3 treeks", "70 dollars", "The combined weight of three treeks equals the weight of one squig."
        words = first_line.split()
        if len(words) >= 2:  # Has multiple words (number + text), try to extract just the value
            # Look for numbers in the text. Thousands groups written with a
            # comma or a space ("1,234 dollars", "118 000") are read as ONE
            # number; a plain digit regex would split them and return only the
            # trailing group ("234", "000").
            numbers = [
                token.replace(',', '').replace(' ', '')
                for token in re.findall(grouped_number, first_line)
            ]
            if numbers:
                # With several numbers (e.g. "36 + 3 = 39") the last one is
                # taken as the result.
                return numbers[-1]

            # If no numbers, try to extract key mathematical value
            # Look for pattern like "X equals Y" or "X = Y" or "has X units"
            if '=' in first_line or 'equals' in first_line.lower() or 'has' in first_line.lower():
                split_parts = re.split(r'=|equals|has', first_line, flags=re.IGNORECASE)
                if len(split_parts) >= 2:
                    last_part = split_parts[-1].strip().rstrip('.,;:')
                    # Try to extract number or short expression
                    cleaned = re.sub(r'\b(the|a|an|one|of|to|is|are)\b', '', last_part, flags=re.IGNORECASE).strip()
                    # Extract first number from cleaned part
                    num_match = re.search(r'-?\d+(?:\.\d+)?', cleaned)
                    if num_match:
                        return num_match.group(0)
                    if cleaned and len(cleaned) < 20:
                        return cleaned

        return first_line if first_line else None

    # Helper function: Extract from last coherent step
    def extract_from_last_step(txt: str) -> Optional[str]:
        """Extract answer from the last coherent step before repetition"""
        # Find all Step N: patterns
        step_matches = list(re.finditer(r'Step \d+:', txt))
        if not step_matches:
            return None

        # Get the last few steps
        for match in reversed(step_matches[-5:]):  # Check last 5 steps
            start = match.end()
            # Find end of this step (next Step or end of text)
            next_match_idx = step_matches.index(match) + 1
            if next_match_idx < len(step_matches):
                end = step_matches[next_match_idx].start()
            else:
                end = len(txt)

            step_content = txt[start:end].strip()

            # Look for equations or expressions ending with $ on a line by itself
            lines = step_content.split('\n')
            for line in reversed(lines):
                line = line.strip()
                # Check if line contains LaTeX expression
                if '$' in line and not line.startswith('Step'):
                    # Extract from $...$
                    if line.count('$') >= 2:
                        dollar_parts = line.split('$')
                        for part in reversed(dollar_parts):
                            part = part.strip()
                            if part and not part.startswith('Step'):
                                # Check if this looks like an answer (not a full sentence)
                                if len(part) < 50 and ('=' in part or '\\' in part or part.replace('.', '').replace('-', '').replace(',', '').replace(' ', '').replace('/', '').isalnum()):
                                    # Extract the value after = if present
                                    if '=' in part:
                                        after_eq = part.split('=')[-1].strip()
                                        if after_eq:
                                            return after_eq
                                    return part

        return None

    if task in ("math-500", "math"):
        # MATH format: For gold answers, the text IS the answer (may contain LaTeX)
        # For generated text, use multi-strategy extraction

        # Remove repetitive patterns first
        text = remove_repetition(text)

        # Strategy 1: Look for \boxed{} notation (official MATH format, highest priority)
        # Use the canonical last_boxed_only_string extractor
        boxed_answer = last_boxed_only_string(text)
        if boxed_answer:
            result = boxed_answer
        # Note: If boxed_answer is None, this is a formatting failure for MATH.
        # Upstream code may want to track n_with_boxed vs n_without_boxed
        # to distinguish formatting quality from reasoning quality.

        # Strategy 2: Look for #### marker (GSM8K-style, but model sometimes uses it)
        elif '####' in text:
            hash_answer = extract_from_hash_marker(text)
            if hash_answer:
                result = hash_answer
            else:
                result = None

        # Strategy 3: If text looks like a direct answer (gold answer case)
        # This handles gold answers like "\frac{14}{3}", "x \\in [-2,7]", or "p - q"
        elif not ("Step" in text or len(text) > 200):
            # Likely a gold answer - return as-is after light cleanup
            result = text.strip()

        else:
            # Strategy 4: Extract from last coherent step
            step_answer = extract_from_last_step(text)
            if step_answer:
                result = step_answer
            else:
                # Strategy 5: For MATH, do NOT default to last numeric token
                result = None

        # R1 FALLBACK for MATH: Check if we should try before-hash extraction
        if use_r1_fallback and '####' in text:
            should_try_r1 = False
            if result is None:
                should_try_r1 = True
            elif gold_answer is not None and not evaluate_math_answer(result, gold_answer):
                should_try_r1 = True

            if should_try_r1:
                r1_answer = extract_answer_before_hash_r1(text, task)
                # With a gold answer, the fallback is accepted only if it matches.
                if r1_answer and (gold_answer is None or evaluate_math_answer(r1_answer, gold_answer)):
                    return r1_answer

        return result

    if task == "gsm8k":
        # GSM8K format: answer is after ####
        # Use the shared extract_from_hash_marker helper for consistency
        from_hash = extract_from_hash_marker(text)

        if from_hash is None:
            # No #### found or extraction failed, try last number
            result = extract_last_number(text)
        else:
            # For GSM8K, enforce "single number" semantics
            # Normalize to extract just the numeric value
            num = normalize_answer(from_hash)
            result = num if num != "" else extract_last_number(text)

        # R1 FALLBACK for GSM8K: Check if we should try before-hash extraction
        if use_r1_fallback and '####' in text:
            should_try_r1 = False
            if result is None:
                should_try_r1 = True
            elif gold_answer is not None and not evaluate_answer(result, gold_answer):
                should_try_r1 = True

            if should_try_r1:
                r1_answer = extract_answer_before_hash_r1(text, task)
                # With a gold answer, the fallback is accepted only if it matches.
                if r1_answer and (gold_answer is None or evaluate_answer(r1_answer, gold_answer)):
                    return r1_answer

        return result

    # Generic extraction: try to find last number in entire text
    # (defensive only: task has already been coerced to one of the values above)
    return extract_last_number(text)
def normalize_answer(answer: str) -> str:
    """Normalize answer for comparison

    Removes spaces, commas, currency symbols and converts to standard number format.
    If the answer STARTS with a number (after any currency symbol), only that
    number is kept, so "70 dollars" becomes "70". Thousands groups separated by
    commas or single spaces are kept whole ("1,000" and "1 000" both become
    "1000"). Text that does not start with a number is returned with its
    whitespace and commas removed.

    Args:
        answer: Answer string

    Returns:
        Normalized answer string ("" for empty input)
    """
    if not answer:
        return ""

    import re

    # Remove leading currency symbols
    answer = re.sub(r'^[\$£€¥₹]+\s*', '', answer.strip())

    # Extract the leading number. The first alternative accepts comma- or
    # space-separated groups of exactly three digits ("1 000"); the second
    # keeps the permissive digits-and-commas form for anything else.
    number_match = re.match(
        r'-?(?:\d{1,3}(?:[ ,]\d{3})+(?!\d)|[\d,]+)(?:\.\d+)?',
        answer,
    )
    if number_match:
        answer = number_match.group(0)

    # Remove remaining spaces and commas
    answer = re.sub(r'[\s,]', '', answer)

    try:
        # Pure integers go through int() so large values are not rounded by a
        # float round-trip.
        if re.fullmatch(r'-?\d+', answer):
            return str(int(answer))
        num = float(answer)
    except (ValueError, TypeError):
        # Not numeric: compare as cleaned text.
        return answer.strip()

    # "72.0" and "72" must compare equal.
    return str(int(num)) if num.is_integer() else str(num)
def answers_match(answer1: str, answer2: str) -> bool:
    """Compare two answers by their normalized forms.

    Args:
        answer1: First answer
        answer2: Second answer

    Returns:
        True if both answers normalize to the same string
    """
    left = normalize_answer(answer1)
    right = normalize_answer(answer2)
    return left == right
def evaluate_answer(
    predicted: str,
    ground_truth: str,
    tolerance: float = 1e-5,
) -> bool:
    """Decide whether a predicted answer equals the ground truth.

    Comparison is delegated to ``answers_match``, i.e. it is done on the
    normalized forms of both answers. An empty or missing value on either side
    never matches.

    Args:
        predicted: Predicted answer
        ground_truth: Ground truth answer
        tolerance: Numerical tolerance (not used with normalization, kept for API compatibility)

    Returns:
        True if answers match
    """
    both_present = bool(predicted) and bool(ground_truth)
    return answers_match(predicted, ground_truth) if both_present else False
def math_normalize_answer(answer: str) -> str:
    """Normalize MATH dataset answer for comparison

    Handles:
        - LaTeX expressions: \\frac{a}{b}, \\sqrt{x}, etc.
        - Intervals: [a,b], (a,b), x \\in [a,b]
        - Sets: \\{a, b, c\\}
        - Numerical values with various formats
        - Whitespace normalization

    Processing order (it matters):
        1. strip math-mode markers ($, \\[ \\], \\( \\))
        2. tighten whitespace around "," and "="
        3. normalize numbers (thousands commas, trailing zeros)
        4. rewrite \\frac{a}{b} as frac(a,b)
        5. map LaTeX commands to plain tokens
        6. tidy brackets, then reduce integer fractions

    Numbers are normalized BEFORE frac(a,b) is built. Otherwise the comma that
    separates numerator and denominator could be mistaken for a thousands
    separator (frac(1,200) would collapse to frac(1200)).

    Args:
        answer: Answer string (may contain LaTeX)

    Returns:
        Normalized answer string ("" for empty input)
    """
    if not answer:
        return ""

    import re
    from math import gcd

    # Strip outer whitespace
    answer = answer.strip()

    # Remove LaTeX display mode markers
    answer = answer.replace('$', '')
    answer = answer.replace('\\[', '').replace('\\]', '')
    answer = answer.replace('\\(', '').replace('\\)', '')

    # Normalize whitespace around common delimiters
    answer = re.sub(r'\s*,\s*', ',', answer)  # "a, b" → "a,b"
    answer = re.sub(r'\s*=\s*', '=', answer)  # "x = 5" → "x=5"
    answer = re.sub(r'\s+', ' ', answer)  # Multiple spaces → single space

    # Normalize numbers: remove thousands commas, standardize decimals
    def normalize_number(match):
        num_str = match.group(0).replace(',', '')
        try:
            num = float(num_str)
        except (ValueError, OverflowError):
            return match.group(0)
        if num.is_integer():
            return str(int(num))
        # Round to 6 decimal places to avoid floating point issues
        return f"{num:.6f}".rstrip('0').rstrip('.')

    answer = re.sub(r'-?\d+(?:,\d{3})*(?:\.\d+)?', normalize_number, answer)

    # Normalize LaTeX fractions: \frac{a}{b} → frac(a,b) for comparison
    # This allows \frac{1}{2} to match \frac{2}{4} after simplification
    def normalize_frac(match):
        num = match.group(1).strip()
        den = match.group(2).strip()
        return f"frac({num},{den})"

    answer = re.sub(r'\\frac\{([^}]+)\}\{([^}]+)\}', normalize_frac, answer)

    # Normalize common LaTeX commands (preserve but standardize)
    latex_commands = {
        '\\sqrt': 'sqrt',
        '\\pi': 'pi',
        '\\infty': 'inf',
        '\\in': ' in ',
        '\\cup': ' cup ',
        '\\cap': ' cap ',
        '\\subset': ' subset ',
        '\\subseteq': ' subseteq ',
        '\\emptyset': 'emptyset',
        '\\varnothing': 'emptyset',
        '\\left': '',
        '\\right': '',
        # Note: \\{ and \\} handled separately below for clarity
        '\\ldots': '...',
        '\\cdots': '...',
        '\\le': '<=',
        '\\ge': '>=',
        '\\leq': '<=',
        '\\geq': '>=',
        '\\ne': '!=',
        '\\neq': '!=',
        '\\times': '*',
        '\\cdot': '*',
        '\\div': '/',
    }
    # Replace whole commands only, in a single pass. The (?![A-Za-z]) guard
    # keeps a short command from matching the head of a longer one: \le inside
    # \leq, \subset inside \subseteq, \in inside \int.
    command_pattern = re.compile(
        '(?:'
        + '|'.join(
            re.escape(cmd)
            for cmd in sorted(latex_commands, key=len, reverse=True)
        )
        + r')(?![A-Za-z])'
    )
    answer = command_pattern.sub(lambda m: latex_commands[m.group(0)], answer)

    # Handle escaped braces separately (avoid duplicate with dict)
    answer = answer.replace('\\{', '{')
    answer = answer.replace('\\}', '}')

    # Normalize interval notation: remove extra spaces inside brackets
    answer = re.sub(r'\[\s*', '[', answer)
    answer = re.sub(r'\s*\]', ']', answer)
    answer = re.sub(r'\(\s*', '(', answer)
    answer = re.sub(r'\s*\)', ')', answer)

    # Try to simplify fractions if both numerator and denominator are integers
    def simplify_frac(match):
        try:
            num = int(match.group(1))
            den = int(match.group(2))
        except ValueError:
            return match.group(0)
        divisor = gcd(abs(num), abs(den))
        if divisor == 0:
            # frac(0,0): nothing to reduce, leave untouched
            return match.group(0)
        return f"frac({num // divisor},{den // divisor})"

    answer = re.sub(r'frac\((-?\d+),(-?\d+)\)', simplify_frac, answer)

    # Final cleanup
    return answer.strip()
def canonical_non_numeric(ans: str) -> Optional[str]:
    """Map various verbal answers to a canonical label

    Handles non-numeric MATH answers like:
        - "no solution"
        - "does not exist"
        - "infinitely many solutions"
        - "all real numbers"

    Matching is case-insensitive and ignores punctuation. Phrases are matched
    as substrings. The abbreviation "dne" must appear as a whole word, so that
    ordinary words containing those letters (e.g. "Wednesday") are not
    classified as DNE.

    Args:
        ans: Answer string

    Returns:
        One of "NO_SOLUTION", "DNE", "INFINITELY_MANY", "ALL_REALS" if
        recognized, None otherwise
    """
    import re

    # Lowercase and turn punctuation / whitespace runs into single spaces
    a = re.sub(r'[.,;:!?\s]+', ' ', ans.strip().lower()).strip()

    # Common patterns
    if any(kw in a for kw in ["no solution", "no real solution", "no real solutions", "no solutions"]):
        return "NO_SOLUTION"

    if any(kw in a for kw in ["does not exist", "undefined"]) or re.search(r'\bdne\b', a):
        return "DNE"

    if any(kw in a for kw in ["infinitely many", "infinite number of solutions", "infinite solutions"]):
        return "INFINITELY_MANY"

    if any(kw in a for kw in ["all real numbers", "any real number", "for all real x", "all reals"]):
        return "ALL_REALS"

    return None
def math_answers_match(answer1: str, answer2: str) -> bool:
    """Check if two MATH dataset answers match after normalization

    Strategies, tried in order on the normalized answers:
        0. Canonical verbal answers ("no solution", "DNE", ...): if either side
           is one, the result is decided by whether both have the same label
        1. Exact string match
        2. Numeric equality (within 1e-6) when both are plain numbers
        3. SymPy-based symbolic equivalence, if SymPy is installed
        4. Whitespace- and case-insensitive string comparison

    Args:
        answer1: First answer
        answer2: Second answer

    Returns:
        True if answers match after normalization
    """
    if not answer1 or not answer2:
        return False

    import re

    norm1 = math_normalize_answer(answer1)
    norm2 = math_normalize_answer(answer2)

    # 0. Non-numeric canonical forms (check before other strategies)
    # Handles "no solution", "DNE", "infinitely many", etc.
    canon1 = canonical_non_numeric(norm1)
    canon2 = canonical_non_numeric(norm2)
    if canon1 is not None or canon2 is not None:
        return canon1 == canon2

    # 1. Direct string match after normalization
    if norm1 == norm2:
        return True

    # 2. Compare as numbers if both are plain numbers ("4" vs "4.0")
    plain_number = r'^-?\d+(?:\.\d+)?$'
    if re.search(plain_number, norm1) and re.search(plain_number, norm2):
        try:
            return abs(float(norm1) - float(norm2)) < 1e-6
        except ValueError:
            # Should not happen for strings accepted by the pattern above;
            # fall through to the remaining strategies if it does.
            pass

    # 3. Try SymPy-based symbolic equivalence
    # This handles algebraic equivalence, fraction simplification, etc.
    try:
        import sympy as sp

        def try_sympy_parse(expr: str) -> Optional[sp.Expr]:
            """Attempt to parse expression as SymPy object"""
            try:
                expr_clean = expr

                # Convert frac(a,b) back to a/b
                expr_clean = re.sub(r'frac\(([^,]+),([^)]+)\)', r'(\1)/(\2)', expr_clean)

                # Basic cleanup: collapse whitespace runs
                expr_clean = re.sub(r'\s+', ' ', expr_clean)

                # Locals mapping: let sympify know about sqrt, pi, etc.
                # Don't inject sp. into the string; pass locals dict instead
                locals_dict = {
                    "sqrt": sp.sqrt,
                    "pi": sp.pi,
                    "inf": sp.oo,
                    "e": sp.E,
                }

                # SECURITY NOTE(review): sympify() evaluates its input with
                # eval(). The strings here can come from model output, i.e.
                # they are untrusted. Consider a restricted parser or
                # sandboxing before using this on arbitrary input.
                return sp.sympify(expr_clean, locals=locals_dict, rational=True)
            except Exception:
                # Most answers are not valid SymPy syntax, and sympify can
                # raise many error types. Treat any failure as "cannot
                # compare symbolically".
                return None

        expr1 = try_sympy_parse(norm1)
        expr2 = try_sympy_parse(norm2)

        if expr1 is not None and expr2 is not None:
            try:
                # Check symbolic equality
                diff = sp.simplify(expr1 - expr2)
                if diff == 0:
                    return True

                # Also try expanding and simplifying both sides
                if sp.simplify(expr1) == sp.simplify(expr2):
                    return True

                # Two rationals that got this far are decided right here
                if isinstance(expr1, sp.Rational) and isinstance(expr2, sp.Rational):
                    return expr1 == expr2
            except Exception:
                # Non-expression results (tuples, relations, ...) cannot be
                # subtracted or simplified; fall through to string comparison.
                pass
    except ImportError:
        # SymPy not available, skip symbolic checking
        pass

    # 4. Check if they differ only in equivalent representations
    # E.g., "x in [-2,7]" vs "x in [ -2 , 7 ]"
    compact1 = re.sub(r'\s+', '', norm1).lower()
    compact2 = re.sub(r'\s+', '', norm2).lower()

    return compact1 == compact2
def evaluate_math_answer(
    predicted: str,
    ground_truth: str,
) -> bool:
    """Decide whether a predicted MATH answer equals the ground truth.

    Comparison is delegated to ``math_answers_match``. An empty or missing
    value on either side never matches.

    Args:
        predicted: Predicted answer
        ground_truth: Ground truth answer

    Returns:
        True if answers match
    """
    both_present = bool(predicted) and bool(ground_truth)
    return math_answers_match(predicted, ground_truth) if both_present else False
def format_prompt(
    question: str,
    template: str = "default",
    few_shot_examples: Optional[List[Dict[str, str]]] = None,
) -> str:
    """Build the model prompt for a question.

    Supported templates:
        - "default": plain "Question: ... Answer:" form
        - "cot": step-by-step instructions, final answer after ####
        - "cot_final": step-by-step instructions, final answer after "Final Answer:"
        - "math_cot": step-by-step instructions, final answer inside \\boxed{...}
        - "direct" (and any unrecognized name): the question text itself

    Few-shot examples, when given, are rendered as "Question/Answer" pairs and
    placed in front of the prompt.

    Args:
        question: Input question
        template: Prompt template to use
        few_shot_examples: Optional few-shot examples; each needs the keys
            "question" and "answer"

    Returns:
        Formatted prompt
    """
    if template == "default":
        body = f"Question: {question}\n\nAnswer:"
    elif template == "cot":
        # Structured prompt with clear instructions for GSM8K format
        body = f"""You are a helpful assistant that solves problems step by step with each step signified by "Step [step_number]: ".
Always provide your final answer after #### at the end.
Question: {question}
Please solve this step by step, putting each step after "Step [step_number]: " and always provide your final answer after ####.
Solution:
"""
    elif template == "cot_final":
        # Chain of thought with "Final Answer:" marker
        body = f"""You are a helpful assistant that solves problems step by step with each step signified by "Step [step_number]: ".
Always provide your final answer after "Final Answer:" at the end.
Question: {question}
Please solve this step by step, putting each step after "Step [step_number]: " and always provide your final answer after "Final Answer:".
Solution:
"""
    elif template == "math_cot":
        # Chain of thought for MATH dataset with \boxed{} formatting
        body = f"""You are a helpful assistant that solves olympiad-style math problems step by step.
At the end, ALWAYS put your final answer in LaTeX form inside \\boxed{{...}} on its own line.
Question: {question}
Please solve this step by step, and put ONLY the final result (no explanation) inside \\boxed{{...}} on the last line.
Solution:
"""
    else:
        # "direct" and every unknown template name: pass the question through.
        body = question

    if not few_shot_examples:
        return body

    # Few-shot demonstrations go before the actual prompt.
    shots = "".join(
        f"Question: {ex['question']}\nAnswer: {ex['answer']}\n\n"
        for ex in few_shot_examples
    )
    return shots + body
def calculate_metrics(
    predictions: List[Optional[str]],
    ground_truths: List[Optional[str]],
    task: str = "gsm8k",
) -> Dict[str, float]:
    """Compute accuracy figures while keeping extraction failures visible.

    Examples whose ground truth is None are excluded from every ratio. A
    prediction of None counts as an extraction failure: it lowers ``accuracy``
    but is left out of ``conditional_accuracy``.

    Args:
        predictions: List of predicted answers (None if extraction failed)
        ground_truths: List of ground truth answers
        task: Task format ("gsm8k" or "math" / "math-500")

    Returns:
        Dictionary with:
            - accuracy: correct / valid examples
            - conditional_accuracy: correct / extractable
            - extraction_rate: extractable / valid examples
            - correct, extractable: raw counts
            - total: number of valid examples (ground truth not None)
            - total_prompts: number of all examples passed in
        Each ratio is 0.0 when its denominator is zero.

    Raises:
        ValueError: If the two lists differ in length
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Number of predictions and ground truths must match")

    use_math_scoring = task in ("math", "math-500")

    # Drop examples with no usable gold answer, then those with no prediction.
    scorable = [
        (pred, gold)
        for pred, gold in zip(predictions, ground_truths)
        if gold is not None
    ]
    answered = [(pred, gold) for pred, gold in scorable if pred is not None]

    hits = 0
    for pred, gold in answered:
        # MATH answers need LaTeX-aware comparison; everything else is numeric.
        if use_math_scoring:
            matched = evaluate_math_answer(pred, gold)
        else:
            matched = evaluate_answer(pred, gold)
        if matched:
            hits += 1

    n_valid = len(scorable)
    n_answered = len(answered)

    return {
        "accuracy": hits / n_valid if n_valid > 0 else 0.0,
        "conditional_accuracy": hits / n_answered if n_answered > 0 else 0.0,
        "extraction_rate": n_answered / n_valid if n_valid > 0 else 0.0,
        "correct": hits,
        "extractable": n_answered,
        "total": n_valid,
        "total_prompts": len(predictions),
    }
def print_metrics(metrics: Dict[str, Any], title: str = "Results") -> None:
    """Print a metrics dictionary as a small framed table on stdout.

    The title is centered between two 50-character rules. Each metric is shown
    as its name padded with dots to 30 characters, followed by the value.
    Floats are printed with four decimal places; other values as-is.

    Args:
        metrics: Dictionary of metrics
        title: Title for output
    """
    rule = f"{'='*50}"

    print(f"\n{rule}")
    print(f"{title:^50}")
    print(rule)

    for name, val in metrics.items():
        rendered = f"{val:.4f}" if isinstance(val, float) else f"{val}"
        print(f"{name:.<30} {rendered}")

    print(f"{rule}\n")
def create_directory_structure(base_dir: Optional[str] = None) -> None:
    """Create the project's standard folder layout and report where.

    Directories that already exist are left alone. A one-line confirmation
    naming the base directory is printed at the end.

    Args:
        base_dir: Base directory (defaults to current directory)
    """
    root = Path.cwd() if base_dir is None else Path(base_dir)

    layout = (
        "data",
        "data/gsm8k",
        "models",
        "models/cache",
        "output",
        "output/results",
        "output/trajectories",
        "output/checkpoints",
        "output/logs",
        "output/cache",
        "config",
        "src",
        "examples",
    )
    for relative in layout:
        (root / relative).mkdir(parents=True, exist_ok=True)

    print(f"Created directory structure at {root}")