LuisZermeno commited on
Commit
b8febd7
·
verified ·
1 Parent(s): 0599958

Create answer_extractor.py

Browse files
Files changed (1) hide show
  1. answer_extractor.py +187 -0
answer_extractor.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Optional, List, Dict, Any
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ def extract_final_answer(text: str) -> Optional[str]:
8
+ """Extract final answer from text using multiple strategies"""
9
+ if not text:
10
+ return None
11
+
12
+ # Strategy 1: Look for explicit FINAL ANSWER format
13
+ final_answer_match = re.search(r'FINAL ANSWER:\s*(.+?)(?:\n|$)', text, re.IGNORECASE)
14
+ if final_answer_match:
15
+ answer = final_answer_match.group(1).strip()
16
+ return clean_answer(answer)
17
+
18
+ # Strategy 2: Look for answer patterns based on question type
19
+ answer = extract_by_pattern(text)
20
+ if answer:
21
+ return clean_answer(answer)
22
+
23
+ # Strategy 3: Look for the last definitive statement
24
+ answer = extract_last_statement(text)
25
+ if answer:
26
+ return clean_answer(answer)
27
+
28
+ return None
29
+
30
+ def clean_answer(answer: str) -> str:
31
+ """Clean and format answer according to GAIA requirements"""
32
+ if not answer:
33
+ return ""
34
+
35
+ # Remove quotes if they wrap the entire answer
36
+ if answer.startswith('"') and answer.endswith('"'):
37
+ answer = answer[1:-1]
38
+ if answer.startswith("'") and answer.endswith("'"):
39
+ answer = answer[1:-1]
40
+
41
+ # Remove common prefixes
42
+ prefixes_to_remove = [
43
+ "The answer is ",
44
+ "The result is ",
45
+ "It is ",
46
+ "This is ",
47
+ "Therefore, ",
48
+ "So, ",
49
+ "Thus, ",
50
+ ]
51
+
52
+ for prefix in prefixes_to_remove:
53
+ if answer.lower().startswith(prefix.lower()):
54
+ answer = answer[len(prefix):]
55
+
56
+ # Clean up whitespace
57
+ answer = answer.strip()
58
+
59
+ # Handle special formats
60
+ answer = format_special_answers(answer)
61
+
62
+ return answer
63
+
64
+ def format_special_answers(answer: str) -> str:
65
+ """Format answers according to common GAIA patterns"""
66
+ # If it's a pure number, return just the number
67
+ if re.match(r'^-?\d+\.?\d*$', answer):
68
+ return answer
69
+
70
+ # If it's yes/no, normalize
71
+ if answer.lower() in ['yes', 'no']:
72
+ return answer.lower()
73
+
74
+ # If it's a date, try to standardize
75
+ date_match = re.search(r'(\d{1,2})[/-](\d{1,2})[/-](\d{2,4})', answer)
76
+ if date_match:
77
+ month, day, year = date_match.groups()
78
+ if len(year) == 2:
79
+ year = '20' + year
80
+ return f"{month}/{day}/{year}"
81
+
82
+ return answer
83
+
84
+ def extract_by_pattern(text: str) -> Optional[str]:
85
+ """Extract answer based on common patterns"""
86
+ patterns = [
87
+ # Numbers
88
+ (r'(?:total|sum|count|number|result)(?:\s+is)?:?\s*(\d+\.?\d*)', lambda m: m.group(1)),
89
+ # Yes/No
90
+ (r'\b(yes|no)\b(?:\s*[,.\n]|$)', lambda m: m.group(1).lower()),
91
+ # Names
92
+ (r'(?:name is|called|known as)\s+([A-Z][a-zA-Z\s]+?)(?:[,.\n]|$)', lambda m: m.group(1).strip()),
93
+ # Years
94
+ (r'(?:year|in)\s+(19\d{2}|20\d{2})\b', lambda m: m.group(1)),
95
+ # Countries
96
+ (r'(?:country|nation|located in)\s+([A-Z][a-zA-Z\s]+?)(?:[,.\n]|$)', lambda m: m.group(1).strip()),
97
+ ]
98
+
99
+ for pattern, extractor in patterns:
100
+ matches = re.findall(pattern, text, re.IGNORECASE)
101
+ if matches:
102
+ # Return the last match (usually most relevant)
103
+ return extractor(re.search(pattern, text, re.IGNORECASE))
104
+
105
+ return None
106
+
107
+ def extract_last_statement(text: str) -> Optional[str]:
108
+ """Extract the last meaningful statement from text"""
109
+ # Split into sentences
110
+ sentences = re.split(r'[.!?]\s+', text)
111
+
112
+ # Work backwards to find a meaningful statement
113
+ for sentence in reversed(sentences):
114
+ sentence = sentence.strip()
115
+
116
+ # Skip empty or very short sentences
117
+ if len(sentence) < 3:
118
+ continue
119
+
120
+ # Skip meta-statements
121
+ if any(skip in sentence.lower() for skip in ['based on', 'according to', 'therefore', 'thus']):
122
+ continue
123
+
124
+ # Check if it contains an answer-like pattern
125
+ if re.search(r'\b(?:is|are|was|were|equals?|contains?)\b', sentence, re.IGNORECASE):
126
+ # Extract the part after the verb
127
+ match = re.search(r'\b(?:is|are|was|were|equals?|contains?)\s+(.+?)(?:[,.\n]|$)', sentence, re.IGNORECASE)
128
+ if match:
129
+ return match.group(1).strip()
130
+
131
+ # If it's a short definitive statement, return it
132
+ if len(sentence.split()) <= 5:
133
+ return sentence
134
+
135
+ return None
136
+
137
+ def extract_from_calculation(text: str) -> Optional[str]:
138
+ """Extract numerical answer from calculation text"""
139
+ # Look for equation results
140
+ patterns = [
141
+ r'=\s*(-?\d+\.?\d*)',
142
+ r'(?:equals?|is)\s+(-?\d+\.?\d*)',
143
+ r'(?:result|answer):\s*(-?\d+\.?\d*)',
144
+ r'^(-?\d+\.?\d*)$', # Just a number on its own line
145
+ ]
146
+
147
+ for pattern in patterns:
148
+ match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
149
+ if match:
150
+ return match.group(1)
151
+
152
+ return None
153
+
154
+ def extract_from_data_analysis(text: str) -> Optional[str]:
155
+ """Extract answer from data analysis results"""
156
+ # Look for summary statistics
157
+ patterns = [
158
+ r'(?:total|sum)(?:\s+is)?:?\s*(-?\d+\.?\d*)',
159
+ r'(?:mean|average)(?:\s+is)?:?\s*(-?\d+\.?\d*)',
160
+ r'(?:count|number)(?:\s+is)?:?\s*(\d+)',
161
+ r'(?:maximum|max)(?:\s+is)?:?\s*(-?\d+\.?\d*)',
162
+ r'(?:minimum|min)(?:\s+is)?:?\s*(-?\d+\.?\d*)',
163
+ ]
164
+
165
+ for pattern in patterns:
166
+ match = re.search(pattern, text, re.IGNORECASE)
167
+ if match:
168
+ return match.group(1)
169
+
170
+ return None
171
+
172
+ def validate_answer_format(answer: str, question: str) -> bool:
173
+ """Validate that answer format matches question requirements"""
174
+ question_lower = question.lower()
175
+
176
+ # Check for specific format requirements
177
+ if 'how many' in question_lower and not re.match(r'^\d+$', answer):
178
+ return False
179
+
180
+ if 'what year' in question_lower and not re.match(r'^\d{4}$', answer):
181
+ return False
182
+
183
+ if any(phrase in question_lower for phrase in ['yes or no', 'yes/no']):
184
+ if answer.lower() not in ['yes', 'no']:
185
+ return False
186
+
187
+ return True