hemantvirmani commited on
Commit
9168c5e
·
1 Parent(s): 395b552

Improve agent performance and testing infrastructure

Browse files

- Increase step limit from 25 to 40 for complex multi-step questions
- Add number formatting validation (remove commas, trailing periods)
- Enhance system prompt with location name expansion rules (St. → Saint)
- Optimize tool usage guidance with priority ordering
- Refactor run_test_code() to accept filter parameter for flexible testing
- Integrate official GAIA scorer for accurate answer verification
- Update ground truth lookup to use task_id instead of question text
- Add summary statistics to test results
- Clean up requirements.txt and remove duplicates

Files changed (3) hide show
  1. app.py +85 -28
  2. requirements.txt +2 -2
  3. scorer.py +107 -0
app.py CHANGED
@@ -9,6 +9,8 @@ import json
9
  from agents import MyLangGraphAgent
10
  # Import Gradio UI creation function
11
  from gradioapp import create_ui
 
 
12
 
13
  # --- Constants ---
14
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
@@ -150,54 +152,102 @@ def run_and_submit_all(username: str):
150
  return status_message, results_df
151
 
152
  def load_ground_truth(file_path="files/metadata.jsonl"):
 
 
 
 
 
153
  truth_mapping = {}
154
  try:
155
  with open(file_path, 'r', encoding='utf-8') as f:
156
  for line in f:
157
  data = json.loads(line)
 
158
  question = data.get("Question")
159
  answer = data.get("Final answer")
160
- if question and answer:
161
- truth_mapping[question] = answer
 
 
 
162
  except Exception as e:
163
  print(f"Error loading ground truth: {e}")
164
  return truth_mapping
165
 
166
  def verify_answers(results, log_output):
 
 
 
 
 
 
167
  ground_truth = load_ground_truth()
168
  log_output.append("\n=== Verification Results ===")
169
- for question, answer in results:
170
- if question in ground_truth:
171
- correct_answer = ground_truth[question]
172
- is_correct = (str(answer).strip() == str(correct_answer).strip())
173
- log_output.append(f"Question: {question}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  log_output.append(f"Expected: {correct_answer}")
175
  log_output.append(f"Got: {answer}")
176
- log_output.append(f"Match: {'Correct' if is_correct else 'Incorrect'}\n")
177
  else:
178
- log_output.append(f"Question: {question[:50]}... - No ground truth found.\n")
179
-
180
- def run_test_code():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  log_output = []
182
  results_to_verify = []
183
  log_output.append("=== Processing Example Questions One by One ===")
184
 
185
- my_questions = []
186
  my_questions_data = FetchQuestions(DEFAULT_API_URL)
187
- if isinstance(my_questions_data, list):
188
- # Extract both question and file_name for selected indices
189
- my_questions = [
190
- {
191
- "question": my_questions_data[i]["question"],
192
- "file_name": my_questions_data[i].get("file_name")
193
- }
194
- #for i in (0, 1, 3, 4, 5, 9, 11, 13, 14, 17, 18) if i < len(my_questions_data) - All 11 incorrect questions
195
- #for i in (0, 1, 4, 5, 14, 17) if i < len(my_questions_data) - All 6 incorrect except ones with files
196
- for i in (0, 5, 17) if i < len(my_questions_data)
197
  ]
198
- #print(f"Running these Questions:\n{chr(10).join(f'{i}. {q[\"question\"]}' for i, q in enumerate(my_questions, 1))}\n")
 
 
 
 
199
 
200
- # 1. Instantiate Agent ( modify this part to create your agent)
201
  try:
202
  my_agent = MyLangGraphAgent()
203
  except Exception as e:
@@ -207,8 +257,9 @@ def run_test_code():
207
 
208
  # Process each question separately
209
  try:
210
- for i, question_item in enumerate(my_questions, 1):
211
  # Use .get() for safe access (returns None if key doesn't exist)
 
212
  question_text = question_item.get("question")
213
  file_name = question_item.get("file_name")
214
 
@@ -216,7 +267,7 @@ def run_test_code():
216
  log_output.append(f"\nQuestion {i}: [ERROR] Missing question text")
217
  continue
218
 
219
- log_output.append(f"\nQuestion {i}: {question_text}")
220
  if file_name:
221
  log_output.append(f"File: {file_name}")
222
 
@@ -225,7 +276,7 @@ def run_test_code():
225
 
226
  print(f"Question: {question_text} Answer: {my_answer}")
227
  log_output.append(f"Answer: {my_answer}")
228
- results_to_verify.append((question_text, my_answer))
229
  except Exception as e:
230
  error_msg = f"Error running agent on task: {e}"
231
  print(error_msg)
@@ -263,7 +314,13 @@ if __name__ == "__main__":
263
 
264
  if args.test and not space_id_startup:
265
  print("Running test code (CLI mode)...")
266
- result = run_test_code()
 
 
 
 
 
 
267
  if isinstance(result, pd.DataFrame):
268
  # Print DataFrame content without truncation
269
  pd.set_option('display.max_colwidth', None)
 
9
  from agents import MyLangGraphAgent
10
  # Import Gradio UI creation function
11
  from gradioapp import create_ui
12
+ # Import scoring function for answer verification
13
+ from scorer import question_scorer
14
 
15
  # --- Constants ---
16
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
152
  return status_message, results_df
153
 
154
def load_ground_truth(file_path="files/metadata.jsonl"):
    """Load ground truth data indexed by task_id.

    Malformed or blank lines are skipped individually so one bad record
    does not abort loading of the rest of the file (previously a single
    bad line raised out of the for-loop and dropped every later record).

    Args:
        file_path: Path to the JSONL metadata file.

    Returns:
        dict: Mapping of task_id -> {"question": str, "answer": str}.
              Empty if the file is missing or unreadable (best-effort).
    """
    truth_mapping = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # skip blank lines silently
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    # Skip just this record instead of aborting the whole load
                    print(f"Skipping malformed ground truth line: {e}")
                    continue
                task_id = data.get("task_id")
                question = data.get("Question")
                answer = data.get("Final answer")
                if task_id and answer:
                    truth_mapping[task_id] = {
                        "question": question,
                        "answer": answer,
                    }
    except Exception as e:
        # Deliberate best-effort: a missing/unreadable file yields {}
        print(f"Error loading ground truth: {e}")
    return truth_mapping
176
 
177
def verify_answers(results, log_output):
    """Verify answers against ground truth using the official GAIA scorer.

    Args:
        results: List of tuples (task_id, question_text, answer)
        log_output: List to append verification results to
    """
    truth_by_task = load_ground_truth()
    log_output.append("\n=== Verification Results ===")

    scored = 0
    matched = 0

    for task_id, question_text, answer in results:
        entry = truth_by_task.get(task_id)
        if entry is None:
            # No metadata record for this task; report and move on
            log_output.append(f"Task ID: {task_id}")
            log_output.append(f"Question: {question_text[:50]}...")
            log_output.append(f"No ground truth found.\n")
            continue

        correct_answer = entry["answer"]
        # The official GAIA question_scorer normalizes numbers, lists
        # and strings before comparing
        is_correct = question_scorer(str(answer), str(correct_answer))
        scored += 1
        if is_correct:
            matched += 1

        log_output.append(f"Task ID: {task_id}")
        log_output.append(f"Question: {question_text[:100]}...")
        log_output.append(f"Expected: {correct_answer}")
        log_output.append(f"Got: {answer}")
        log_output.append(f"Match: {'Correct' if is_correct else 'Incorrect'}\n")

    # Summary statistics over the tasks that had ground truth
    if scored > 0:
        accuracy = matched / scored * 100
        log_output.append("=" * 60)
        log_output.append(f"SUMMARY: {matched}/{scored} correct ({accuracy:.1f}%)")
        log_output.append("=" * 60)
+
220
+ def run_test_code(filter=None):
221
+ """Run test code on selected questions.
222
+
223
+ Args:
224
+ filter: Optional tuple/list of question indices to test (e.g., (4, 7, 15)).
225
+ If None, processes all questions.
226
+ """
227
  log_output = []
228
  results_to_verify = []
229
  log_output.append("=== Processing Example Questions One by One ===")
230
 
231
+ # Fetch all questions
232
  my_questions_data = FetchQuestions(DEFAULT_API_URL)
233
+ if not isinstance(my_questions_data, list):
234
+ error_msg = f"Failed to fetch questions: {my_questions_data}"
235
+ print(error_msg)
236
+ return error_msg
237
+
238
+ # Apply filter or use all questions
239
+ if filter is not None:
240
+ # Filter to specific indices
241
+ questions_to_process = [
242
+ my_questions_data[i] for i in filter if i < len(my_questions_data)
243
  ]
244
+ log_output.append(f"Testing {len(questions_to_process)} selected questions (indices: {filter})")
245
+ else:
246
+ # Process all questions
247
+ questions_to_process = my_questions_data
248
+ log_output.append(f"Testing all {len(questions_to_process)} questions")
249
 
250
+ # Instantiate Agent
251
  try:
252
  my_agent = MyLangGraphAgent()
253
  except Exception as e:
 
257
 
258
  # Process each question separately
259
  try:
260
+ for i, question_item in enumerate(questions_to_process, 1):
261
  # Use .get() for safe access (returns None if key doesn't exist)
262
+ task_id = question_item.get("task_id")
263
  question_text = question_item.get("question")
264
  file_name = question_item.get("file_name")
265
 
 
267
  log_output.append(f"\nQuestion {i}: [ERROR] Missing question text")
268
  continue
269
 
270
+ log_output.append(f"\nQuestion {i} (Task ID: {task_id}): {question_text}")
271
  if file_name:
272
  log_output.append(f"File: {file_name}")
273
 
 
276
 
277
  print(f"Question: {question_text} Answer: {my_answer}")
278
  log_output.append(f"Answer: {my_answer}")
279
+ results_to_verify.append((task_id, question_text, my_answer))
280
  except Exception as e:
281
  error_msg = f"Error running agent on task: {e}"
282
  print(error_msg)
 
314
 
315
  if args.test and not space_id_startup:
316
  print("Running test code (CLI mode)...")
317
+ # Specify question indices to test, or None for all questions
318
+ # Examples:
319
+ # - (0, 1, 3, 4, 5, 9, 11, 13, 14, 17, 18) - All 11 incorrect questions
320
+ # - (0, 1, 4, 5, 14, 17) - All 6 incorrect except ones with files
321
+ # - None - Test all 20 questions
322
+ test_filter = (4,) # Testing Q5 only (index 4)
323
+ result = run_test_code(filter=test_filter)
324
  if isinstance(result, pd.DataFrame):
325
  # Print DataFrame content without truncation
326
  pd.set_option('display.max_colwidth', None)
requirements.txt CHANGED
@@ -11,12 +11,12 @@ langchain-core
11
  langchain-google-genai
12
  langchain-huggingface
13
  langchain-community
14
- ddgs
15
  pypdf
16
  youtube-transcript-api
17
  pytube
18
  pymupdf
19
- wikipedia
20
  nest_asyncio
21
  speechrecognition
22
  markdownify
 
 
 
11
  langchain-google-genai
12
  langchain-huggingface
13
  langchain-community
 
14
  pypdf
15
  youtube-transcript-api
16
  pytube
17
  pymupdf
 
18
  nest_asyncio
19
  speechrecognition
20
  markdownify
21
+ numpy
22
+ pandas
scorer.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Official GAIA Scorer Module from HF. Copied from https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/scorer.py for offline use. Hoping there are no licensing issues as it is intended for learning purposes only.
2
+ #Thanks, Hemant Virmani
3
+
4
+ import json
5
+ import re
6
+ import string
7
+ import warnings
8
+
9
+ import numpy as np
10
+
11
+
12
def normalize_number_str(number_str: str) -> float:
    """Parse a numeric answer string as a float.

    Strips common unit symbols ($, %) and thousands separators (commas)
    before conversion. Returns float("inf") as a sentinel when the cleaned
    string is still not numeric, which is very unlikely to equal a real
    ground-truth number.
    """
    number_str = number_str.replace("$", "").replace("%", "").replace(",", "")
    try:
        return float(number_str)
    except ValueError:
        print(f"String {number_str} cannot be normalized to number str.")
        return float("inf")
22
+
23
+
24
def split_string(
    s: str,
    char_list: "tuple[str, ...] | list[str]" = (",", ";"),
) -> list[str]:
    """Split *s* on any of the delimiter characters in *char_list*.

    Args:
        s: String to split.
        char_list: Delimiter characters (default: comma and semicolon).

    Returns:
        List of substrings; delimiters removed, no whitespace stripping.
    """
    # re.escape guards against regex-special delimiters (e.g. "-", "]", "^")
    # that would otherwise corrupt the character class; for the default
    # "," and ";" the pattern is unchanged. Default is a tuple so it can
    # never be mutated across calls.
    pattern = f"[{''.join(map(re.escape, char_list))}]"
    return re.split(pattern, s)
30
+
31
+
32
def question_scorer(
    model_answer: str,
    ground_truth: str,
) -> bool:
    """Score a model answer against GAIA ground truth (official logic).

    Dispatches on the form of the ground truth:
    - numeric: compare as floats after stripping $, %, and commas
    - list (contains "," or ";"): element-wise compare, numbers as floats
      and strings case/whitespace-insensitively (punctuation kept)
    - otherwise: whitespace- and punctuation-insensitive lowercase compare

    NOTE(review): copied verbatim from the official GAIA leaderboard
    scorer; keep logic unchanged to preserve scoring parity with upstream.

    Args:
        model_answer: The agent's answer; None is treated as the string "None".
        ground_truth: The expected answer from metadata.

    Returns:
        True if the answer matches under GAIA normalization rules.
    """
    def is_float(element: object) -> bool:
        # True when element parses as a float (so the gt is "numeric")
        try:
            float(element)
            return True
        except ValueError:
            return False

    if model_answer is None:
        model_answer = "None"

    # if gt is a number
    if is_float(ground_truth):
        print(f"Evaluating {model_answer} as a number.")
        normalized_answer = normalize_number_str(model_answer)
        return normalized_answer == float(ground_truth)

    # if gt is a list
    elif any(char in ground_truth for char in [",", ";"]):
        print(f"Evaluating {model_answer} as a comma separated list.")
        # question with the fish: normalization removes punct

        gt_elems = split_string(ground_truth)
        ma_elems = split_string(model_answer)

        # check length is the same
        if len(gt_elems) != len(ma_elems):
            warnings.warn(
                "Answer lists have different lengths, returning False.", UserWarning
            )
            return False

        # compare each element as float or str
        comparisons = []
        for ma_elem, gt_elem in zip(ma_elems, gt_elems):
            if is_float(gt_elem):
                normalized_ma_elem = normalize_number_str(ma_elem)
                comparisons.append(normalized_ma_elem == float(gt_elem))
            else:
                # we do not remove punct since comparisons can include punct
                comparisons.append(
                    normalize_str(ma_elem, remove_punct=False)
                    == normalize_str(gt_elem, remove_punct=False)
                )
        return all(comparisons)

    # if gt is a str
    else:
        print(f"Evaluating {model_answer} as a string.")
        return normalize_str(model_answer) == normalize_str(ground_truth)
85
+
86
+
87
def normalize_str(input_str, remove_punct=True) -> str:
    """Normalize a string for comparison.

    Strips all whitespace, lowercases, and (optionally) removes ASCII
    punctuation, so that e.g. "Sea Gull" and "seagull" compare equal.

    Parameters:
    - input_str: str, the string to normalize
    - remove_punct: bool, whether to remove punctuation (default: True)
    Returns:
    - str, the normalized string
    """
    # Dropping every whitespace char handles "sea gull" vs "seagull"
    lowered = re.sub(r"\s", "", input_str).lower()

    if not remove_punct:
        return lowered
    # Delete all ASCII punctuation in one C-level pass
    return lowered.translate(str.maketrans("", "", string.punctuation))