# Author: LuisZermeno
# Commit b8febd7 (verified): "Create answer_extractor.py"
import re
from typing import Optional, List, Dict, Any
import logging
# Module-level logger named after this module (not referenced by the functions shown here).
logger = logging.getLogger(__name__)
def extract_final_answer(text: str) -> Optional[str]:
    """Extract a final answer from free-form model output.

    Strategies are tried in order; the first one that yields a non-empty
    cleaned answer wins:

    1. An explicit ``FINAL ANSWER: ...`` line (case-insensitive). If the
       marker appears more than once, the last occurrence is used, on the
       assumption that later statements supersede earlier drafts.
    2. Question-type patterns via ``extract_by_pattern``.
    3. The last definitive statement via ``extract_last_statement``.

    Every candidate is passed through ``clean_answer``. A candidate that
    cleans down to an empty string is discarded, and the next strategy is
    tried instead.

    Args:
        text: Raw model output; may be empty or None.

    Returns:
        The cleaned answer, or None if no strategy produced a non-empty one.
    """
    if not text:
        return None

    # Strategy 1: explicit "FINAL ANSWER:" marker; prefer the last one.
    explicit = re.findall(r'FINAL ANSWER:\s*(.+?)(?:\n|$)', text, re.IGNORECASE)
    if explicit:
        answer = clean_answer(explicit[-1].strip())
        if answer:
            return answer

    # Strategy 2: question-type patterns; Strategy 3: last definitive statement.
    for strategy in (extract_by_pattern, extract_last_statement):
        candidate = strategy(text)
        if candidate:
            answer = clean_answer(candidate)
            if answer:
                return answer

    return None
def clean_answer(answer: str) -> str:
    """Clean and format an answer according to GAIA requirements.

    The steps are:

    1. Trim surrounding whitespace.
    2. Repeatedly peel off a pair of matching quotes (``"`` or ``'``) that
       wrap the whole answer, and any of a fixed set of conversational
       lead-ins ("The answer is ", "Therefore, ", ...), compared
       case-insensitively.
    3. Stop when neither step changes the text, so combinations such as
       ``Therefore, the answer is "42"`` reduce to ``42`` regardless of the
       order in which they appear.
    4. Pass the result through ``format_special_answers``.

    Args:
        answer: Raw answer text; may be empty or None.

    Returns:
        The cleaned answer, or "" when *answer* is empty/None.
    """
    if not answer:
        return ""

    prefixes_to_remove = (
        "The answer is ",
        "The result is ",
        "It is ",
        "This is ",
        "Therefore, ",
        "So, ",
        "Thus, ",
    )

    answer = answer.strip()
    changed = True
    # Each iteration strictly shortens the string when it changes it,
    # so the loop always terminates.
    while changed:
        changed = False
        # Remove one pair of matching quotes wrapping the entire answer.
        if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in ('"', "'"):
            answer = answer[1:-1].strip()
            changed = True
        lowered = answer.lower()
        for prefix in prefixes_to_remove:
            if lowered.startswith(prefix.lower()):
                answer = answer[len(prefix):].strip()
                changed = True
                break

    return format_special_answers(answer)
def format_special_answers(answer: str) -> str:
    """Format answers according to common GAIA patterns.

    - A plain number (optional leading ``-``, optional decimal part) is
      returned unchanged.
    - "yes"/"no" in any letter case is lower-cased.
    - An answer that consists *entirely* of a numeric date ``A/B/Y`` or
      ``A-B-Y`` is rewritten with ``/`` separators. The first two fields
      have 1-2 digits and the year has exactly 2 or 4 digits. A 2-digit year
      gets a ``20`` prefix. Field order is preserved (presumably
      month/day, US style; not verified).

    Dates embedded in longer text are deliberately left alone, so that the
    rest of the answer is not discarded. Anything else is returned
    unchanged.
    """
    # If it's a pure number, return just the number
    if re.match(r'^-?\d+\.?\d*$', answer):
        return answer

    # If it's yes/no, normalize
    if answer.lower() in ['yes', 'no']:
        return answer.lower()

    # Only reformat when the whole answer is a date; fullmatch prevents
    # replacing a longer answer with the date it happens to contain.
    date_match = re.fullmatch(r'(\d{1,2})[/-](\d{1,2})[/-](\d{4}|\d{2})', answer)
    if date_match:
        month, day, year = date_match.groups()
        if len(year) == 2:
            year = '20' + year
        return f"{month}/{day}/{year}"

    return answer
def extract_by_pattern(text: str) -> Optional[str]:
    """Extract an answer using a prioritized list of common answer shapes.

    Patterns are tried in this order:

    1. A number introduced by "total"/"sum"/"count"/"number"/"result".
    2. A standalone "yes"/"no".
    3. A capitalized name after "name is"/"called"/"known as".
    4. A 19xx/20xx year after "year"/"in".
    5. A capitalized name after "country"/"nation"/"located in".

    For the first pattern that matches anywhere in *text*, the value from
    its *last* occurrence is returned.

    Matching details:

    - Keywords match case-insensitively.
    - For the name and country patterns, the captured value must start with
      an uppercase letter. Only the keyword part uses a scoped ``(?i:...)``
      group, so the capture stays case-sensitive; otherwise a phrase like
      "called when ..." would be taken for a name.
    - Numbers do not absorb a trailing sentence period: "is 5." yields "5".

    Returns:
        The extracted answer string, or None if no pattern matches.
    """
    patterns = [
        # Numbers (a fractional part needs at least one digit after the point)
        (r'(?:total|sum|count|number|result)(?:\s+is)?:?\s*(\d+(?:\.\d+)?)',
         re.IGNORECASE, lambda m: m.group(1)),
        # Yes/No
        (r'\b(yes|no)\b(?:\s*[,.\n]|$)',
         re.IGNORECASE, lambda m: m.group(1).lower()),
        # Names: case-insensitive keyword, capitalized capture
        (r'(?i:name is|called|known as)\s+([A-Z][a-zA-Z\s]+?)(?:[,.\n]|$)',
         0, lambda m: m.group(1).strip()),
        # Years
        (r'(?:year|in)\s+(19\d{2}|20\d{2})\b',
         re.IGNORECASE, lambda m: m.group(1)),
        # Countries: case-insensitive keyword, capitalized capture
        (r'(?i:country|nation|located in)\s+([A-Z][a-zA-Z\s]+?)(?:[,.\n]|$)',
         0, lambda m: m.group(1).strip()),
    ]

    for pattern, flags, extractor in patterns:
        matches = list(re.finditer(pattern, text, flags))
        if matches:
            # Return the last match (usually most relevant)
            return extractor(matches[-1])
    return None
def extract_last_statement(text: str) -> Optional[str]:
    """Extract an answer-like fragment from the last meaningful sentence.

    *text* is split into sentences on ``.``, ``!`` or ``?`` followed by
    whitespace. Sentences are then scanned from last to first, applying
    these rules in order:

    - Sentences shorter than 3 characters are ignored.
    - Sentences containing the whole words/phrases "based on",
      "according to", "therefore" or "thus" (case-insensitive) are skipped
      as meta-statements.
    - If a sentence has a copula-like verb ("is", "are", "was", "were",
      "equal(s)", "contain(s)") followed by text, the text after the verb
      is returned, up to the next comma, period, newline, or end.
      A ``.`` or ``,`` immediately followed by a digit is treated as part
      of a number (``3.14``, ``1,000``), not as a terminator.
    - Otherwise, a sentence of at most five words is returned as-is.

    Returns:
        The extracted fragment, or None when no sentence qualifies.
    """
    # Split into sentences
    sentences = re.split(r'[.!?]\s+', text)

    # Work backwards to find a meaningful statement
    for sentence in reversed(sentences):
        sentence = sentence.strip()

        # Skip empty or very short sentences
        if len(sentence) < 3:
            continue

        # Skip meta-statements; whole-word match so e.g. "enthusiast" does
        # not count as "thus".
        if re.search(r'\b(?:based on|according to|therefore|thus)\b', sentence, re.IGNORECASE):
            continue

        # Check if it contains an answer-like pattern
        if re.search(r'\b(?:is|are|was|were|equals?|contains?)\b', sentence, re.IGNORECASE):
            # Extract the part after the verb; [,.] followed by a digit is
            # part of a number, not the end of the answer.
            match = re.search(
                r'\b(?:is|are|was|were|equals?|contains?)\s+(.+?)(?:[,.](?!\d)|\n|$)',
                sentence,
                re.IGNORECASE,
            )
            if match:
                return match.group(1).strip()

        # If it's a short definitive statement, return it
        if len(sentence.split()) <= 5:
            return sentence

    return None
def extract_from_calculation(text: str) -> Optional[str]:
    """Extract a numerical answer from calculation text.

    Patterns are tried in priority order:

    1. ``= N``
    2. ``equals N`` / ``is N``
    3. ``result: N`` / ``answer: N``
    4. A line consisting solely of a number (a trailing period is allowed).

    All matching is case-insensitive and multi-line. For the first pattern
    that matches anywhere, the *last* occurrence is returned, on the
    assumption that in a chain of steps the final value is the result.
    Numbers may be negative and may have a fractional part. A trailing
    sentence period is never included in the returned number.

    Returns:
        The matched number as a string, or None if no pattern matches.
    """
    # A fractional part requires at least one digit, so "5." yields "5".
    number = r'(-?\d+(?:\.\d+)?)'
    patterns = [
        r'=\s*' + number,
        r'(?:equals?|is)\s+' + number,
        r'(?:result|answer):\s*' + number,
        r'^' + number + r'\.?$',  # Just a number on its own line
    ]

    for pattern in patterns:
        # One capture group, so findall yields the captured strings.
        matches = re.findall(pattern, text, re.MULTILINE | re.IGNORECASE)
        if matches:
            return matches[-1]
    return None
def extract_from_data_analysis(text: str) -> Optional[str]:
    """Extract an answer from data-analysis output.

    Summary-statistic labels are checked in priority order: total/sum,
    mean/average, count/number, maximum/max, minimum/min. Each label may be
    followed by an optional "is" and/or ":". The first match of the
    highest-priority label found is returned, matched case-insensitively.

    Count values must be non-negative integers. The other statistics may
    be negative and may have a fractional part. A trailing sentence period
    is not included ("The total is 12." yields "12").

    Returns:
        The matched number as a string, or None if no pattern matches.
    """
    # A fractional part requires at least one digit after the point.
    number = r'(-?\d+(?:\.\d+)?)'
    patterns = [
        r'(?:total|sum)(?:\s+is)?:?\s*' + number,
        r'(?:mean|average)(?:\s+is)?:?\s*' + number,
        r'(?:count|number)(?:\s+is)?:?\s*(\d+)',
        r'(?:maximum|max)(?:\s+is)?:?\s*' + number,
        r'(?:minimum|min)(?:\s+is)?:?\s*' + number,
    ]

    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1)
    return None
def validate_answer_format(answer: str, question: str) -> bool:
    """Report whether *answer* satisfies the format implied by *question*.

    The question is inspected case-insensitively for three cues:

    - "how many": the answer must consist only of digits.
    - "what year": the answer must be exactly four digits.
    - "yes or no" / "yes/no": the answer must be "yes" or "no", in any
      letter case.

    Questions with none of these cues accept any answer.
    """
    lowered = question.lower()

    if 'how many' in lowered:
        if not re.match(r'^\d+$', answer):
            return False

    if 'what year' in lowered:
        if not re.match(r'^\d{4}$', answer):
            return False

    wants_yes_no = 'yes or no' in lowered or 'yes/no' in lowered
    return not wants_yes_no or answer.lower() in ('yes', 'no')