File size: 6,189 Bytes
b8febd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import re
from typing import Optional, List, Dict, Any
import logging

logger = logging.getLogger(__name__)

def extract_final_answer(text: str) -> Optional[str]:
    """Pull a final answer out of free-form text.

    Tries, in order: an explicit ``FINAL ANSWER:`` line, the keyword
    patterns in ``extract_by_pattern``, and the trailing statement found by
    ``extract_last_statement``. The first strategy that yields something is
    passed through ``clean_answer`` and returned.

    Returns:
        The cleaned answer, or None when the text is empty or no strategy
        produced a candidate.
    """
    if not text:
        return None

    # Strategy 1: explicit marker; the answer is the rest of that line.
    marker_hit = re.search(r'FINAL ANSWER:\s*(.+?)(?:\n|$)', text, re.IGNORECASE)
    if marker_hit:
        return clean_answer(marker_hit.group(1).strip())

    # Strategy 2: keyword-driven patterns.
    candidate = extract_by_pattern(text)
    if not candidate:
        # Strategy 3: fall back to the last definitive statement.
        candidate = extract_last_statement(text)

    return clean_answer(candidate) if candidate else None

def clean_answer(answer: str) -> str:
    """Clean and format answer according to GAIA requirements.

    Steps, in order:
      1. Trim surrounding whitespace.
      2. Remove one layer of wrapping double quotes, then one layer of
         wrapping single quotes.
      3. Remove leading filler phrases ("The answer is", "Therefore,", ...),
         case-insensitively and repeatedly, so stacked fillers such as
         "Therefore, the answer is 42" are fully removed regardless of order.
      4. Hand the remainder to ``format_special_answers``.

    Args:
        answer: Raw answer text; may be empty or None.

    Returns:
        The cleaned answer, or "" when nothing remains after cleaning.
    """
    if not answer:
        return ""

    # Trim first so wrapping quotes are recognised even when padded.
    answer = answer.strip()

    for quote in ('"', "'"):
        if answer.startswith(quote) and answer.endswith(quote):
            answer = answer[1:-1].strip()

    # A filler phrase must be followed by whitespace or the end of the
    # string, so words that merely start with one (e.g. "So,me") are kept.
    filler = re.compile(
        r'(?:the answer is|the result is|it is|this is|therefore,|so,|thus,)'
        r'(?:\s+|$)',
        re.IGNORECASE,
    )
    # Each hit consumes at least one character, so this loop terminates.
    lead = filler.match(answer)
    while lead:
        answer = answer[lead.end():]
        lead = filler.match(answer)

    answer = answer.strip()
    if not answer:
        return ""

    # Handle special formats (numbers, yes/no, dates).
    return format_special_answers(answer)

def format_special_answers(answer: str) -> str:
    """Format answers according to common GAIA patterns.

    Each rule applies only when the *whole* answer has that shape; an answer
    that merely contains a number or a date-like fragment is returned
    unchanged so no surrounding text is lost.

    Rules, in order:
      * pure number (optional sign, optional decimal part): returned as-is;
      * "yes"/"no" in any letter case: lower-cased;
      * ``N/N/YY`` or ``N/N/YYYY`` with "/" or "-" separators: rewritten
        with "/" separators, and a two-digit year is prefixed with "20".
        The component order is preserved, not validated.

    Args:
        answer: Already-trimmed answer text.

    Returns:
        The normalised answer, or the input unchanged when no rule applies.
    """
    # If it's a pure number, return just the number.
    if re.fullmatch(r'-?\d+\.?\d*', answer):
        return answer

    # If it's yes/no, normalize.
    lowered = answer.lower()
    if lowered in ('yes', 'no'):
        return lowered

    # If it's a date, standardize the separators and year width.
    # fullmatch (not search) keeps text such as "Call 555-12-3456 now" from
    # being truncated to a bogus date.
    date_match = re.fullmatch(r'(\d{1,2})[/-](\d{1,2})[/-](\d{2}|\d{4})', answer)
    if date_match:
        month, day, year = date_match.groups()
        if len(year) == 2:
            # Two-digit years are assumed to be in the 2000s.
            year = '20' + year
        return f"{month}/{day}/{year}"

    return answer

def extract_by_pattern(text: str) -> Optional[str]:
    """Extract answer based on common patterns.

    Pattern families are tried in a fixed priority order (numbers, yes/no,
    names, years, countries). For the first family that matches anywhere in
    the text, the *last* occurrence is used, since a later statement usually
    supersedes an earlier one.

    Args:
        text: Free-form text to scan; may be empty or None.

    Returns:
        The extracted answer string, or None when no pattern matches.
    """
    if not text:
        return None

    ignore_case = re.IGNORECASE
    # Each entry: (regex, flags, extractor). Name/country patterns scope
    # case-insensitivity to the keyword via (?i:...) so the captured value
    # must really start with a capital letter.
    patterns = [
        # Numbers: the decimal part needs digits, so a sentence-ending
        # period ("total is 5.") is not captured as part of the number.
        (r'(?:total|sum|count|number|result)(?:\s+is)?:?\s*(\d+(?:\.\d+)?)',
         ignore_case, lambda m: m.group(1)),
        # Yes/No
        (r'\b(yes|no)\b(?:\s*[,.\n]|$)',
         ignore_case, lambda m: m.group(1).lower()),
        # Names
        (r'(?i:name is|called|known as)\s+([A-Z][a-zA-Z\s]+?)(?:[,.\n]|$)',
         0, lambda m: m.group(1).strip()),
        # Years
        (r'(?:year|in)\s+(19\d{2}|20\d{2})\b',
         ignore_case, lambda m: m.group(1)),
        # Countries
        (r'(?i:country|nation|located in)\s+([A-Z][a-zA-Z\s]+?)(?:[,.\n]|$)',
         0, lambda m: m.group(1).strip()),
    ]

    for pattern, flags, extractor in patterns:
        last_match = None
        # Walk all occurrences and keep the final one (usually most relevant).
        for last_match in re.finditer(pattern, text, flags):
            pass
        if last_match is not None:
            return extractor(last_match)

    return None

def extract_last_statement(text: str) -> Optional[str]:
    """Extract the last meaningful statement from text.

    The text is split into sentences and scanned from the end. A sentence is
    skipped when it is shorter than three characters or contains a
    meta-phrase ("based on", "according to", "therefore", "thus"). For the
    first remaining sentence:

      * if it contains a linking verb (is/are/was/were/equals/contains),
        the clause after the verb is returned, up to the next comma, period
        or newline that is not inside a number;
      * otherwise, if it has at most five words, the sentence itself is
        returned without trailing sentence punctuation.

    Args:
        text: Free-form text to scan; may be empty or None.

    Returns:
        The extracted statement, or None when nothing qualifies.
    """
    if not text:
        return None

    # Split into sentences.
    sentences = re.split(r'[.!?]\s+', text)

    skip_markers = ('based on', 'according to', 'therefore', 'thus')
    # "." and "," only end the clause when not followed by a digit, so
    # values like 3.14 or 1,000 are kept whole.
    verb_clause = re.compile(
        r'\b(?:is|are|was|were|equals?|contains?)\s+(.+?)(?:[,.](?!\d)|\n|$)',
        re.IGNORECASE,
    )

    # Work backwards to find a meaningful statement.
    for sentence in reversed(sentences):
        sentence = sentence.strip()

        # Skip empty or very short sentences.
        if len(sentence) < 3:
            continue

        # The split removes terminators from every sentence but the last;
        # drop them here too so all sentences are treated alike.
        sentence = sentence.rstrip('.!?').rstrip()
        if not sentence:
            continue

        # Skip meta-statements.
        lowered = sentence.lower()
        if any(marker in lowered for marker in skip_markers):
            continue

        # Prefer the part after a linking verb.
        match = verb_clause.search(sentence)
        if match:
            return match.group(1).strip()

        # If it's a short definitive statement, return it.
        if len(sentence.split()) <= 5:
            return sentence

    return None

def extract_from_calculation(text: str) -> Optional[str]:
    """Extract numerical answer from calculation text.

    Patterns are tried in priority order and the first hit of the first
    matching pattern wins: a value after "=", a value after "equals"/"is",
    a value after "result:"/"answer:", then a number alone on a line.

    Args:
        text: Calculation output to scan; may be empty or None.

    Returns:
        The number as it appears in the text (optional sign and decimal
        part, never a trailing period), or None when no number is found.
    """
    if not text:
        return None

    # A decimal part needs digits, so "= 12." yields "12" rather than "12.".
    number = r'(-?\d+(?:\.\d+)?)'
    patterns = [
        r'=\s*' + number,
        r'(?:equals?|is)\s+' + number,
        r'(?:result|answer):\s*' + number,
        r'^' + number + r'\.?$',  # Just a number on its own line
    ]

    for pattern in patterns:
        match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
        if match:
            return match.group(1)

    return None

def extract_from_data_analysis(text: str) -> Optional[str]:
    """Extract answer from data analysis results.

    Looks for a labelled summary statistic, trying the labels in priority
    order: total/sum, mean/average, count/number, maximum/max, minimum/min.
    The first hit of the first matching label is returned.

    Args:
        text: Analysis output to scan; may be empty or None.

    Returns:
        The number as it appears in the text (counts are integer-only;
        other statistics may carry a sign and decimal part, never a
        trailing period), or None when no labelled value is found.
    """
    if not text:
        return None

    # A decimal part needs digits, so "Total: 15." yields "15", not "15.".
    number = r'(-?\d+(?:\.\d+)?)'
    label_tail = r'(?:\s+is)?:?\s*'
    patterns = [
        r'(?:total|sum)' + label_tail + number,
        r'(?:mean|average)' + label_tail + number,
        r'(?:count|number)' + label_tail + r'(\d+)',
        r'(?:maximum|max)' + label_tail + number,
        r'(?:minimum|min)' + label_tail + number,
    ]

    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1)

    return None

def validate_answer_format(answer: str, question: str) -> bool:
    """Check that the answer's shape fits cues found in the question.

    Three cues are recognised (case-insensitively) in the question:
    "how many" requires a digits-only answer, "what year" requires exactly
    four digits, and "yes or no" / "yes/no" requires the answer to be
    "yes" or "no" in any letter case. Questions with none of these cues
    accept any answer.

    Returns:
        False as soon as one applicable requirement is violated, else True.
    """
    q = question.lower()

    # Counting questions: digits only.
    if 'how many' in q:
        if re.match(r'^\d+$', answer) is None:
            return False

    # Year questions: exactly four digits.
    if 'what year' in q:
        if re.match(r'^\d{4}$', answer) is None:
            return False

    # Binary questions: the answer must be yes or no.
    asks_yes_no = 'yes or no' in q or 'yes/no' in q
    if asks_yes_no and answer.lower() not in ['yes', 'no']:
        return False

    return True