xiaobin123 commited on
Commit
a1766c7
·
verified ·
1 Parent(s): 07f1898

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +889 -16
app.py CHANGED
@@ -1,22 +1,895 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def display_info(task_type):
4
- info = {
5
- "Level 1": "Tasks that are easy for humans and can be solved by most LLMs.",
6
- "Level 2": "Tasks requiring multi-modality or multi-step reasoning.",
7
- "Level 3": "Complex tasks requiring tool use and long-term planning."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  }
9
- return info.get(task_type, "Select a level")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- with gr.Blocks() as demo:
12
- gr.Markdown("# 🚀 My GAIA Research Dashboard")
13
- gr.Markdown("This space is used for analyzing Agent performance on GAIA benchmark.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- with gr.Row():
16
- input_text = gr.Dropdown(["Level 1", "Level 2", "Level 3"], label="Select GAIA Task Level")
17
- output_text = gr.Textbox(label="Description")
18
 
19
- btn = gr.Button("Analyze Task Structure")
20
- btn.click(fn=display_info, inputs=input_text, outputs=output_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- demo.launch()
 
1
+ """
2
+ 测试脚本:支持 validation 和 test 数据集,支持全量、增量、混合三种模式
3
+ - 全量模式(full):不管文件是否存在,都删除重新开始
4
+ - 增量模式(incremental):如果文件存在则增量,不存在则全量执行
5
+ - 混合模式(hybrid):第一次时全量(文件不存在),后面就增量(文件存在)
6
+ """
7
+ import argparse
8
+ import json
9
+ import os
10
+ import re
11
+ import time
12
+ import traceback
13
+ from collections import OrderedDict
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+ from threading import Lock
16
 
17
+ import pandas as pd
18
+ import requests
19
+ from datasets import load_dataset
20
+ from huggingface_hub import snapshot_download, hf_hub_download
21
+ from pathlib import Path
22
+ import shutil
23
+
24
+ # --- 1. 配置区 ---
25
+ BASE_URL = "http://localhost:5173/api/v1"
26
+ CHAT_URL = f"{BASE_URL}/sessions/chat"
27
+ UPLOAD_URL = f"{BASE_URL}/files"
28
+ HEADERS = {"Authorization": "Bearer hawk_YhCZLQYqtPOwOiEyEgeCNdfAFAbrHtTUxQvRiaOInyekgVgE"}
29
+ DATA_PATH = "./gaia_data"
30
+ REQUEST_TIMEOUT = 1800
31
+ MAX_CONCURRENT = 2
32
+
33
+ file_lock = Lock()
34
+
35
+ # 全局变量,根据 split 类型动态设置
36
+ dataset = None
37
+ OUTPUT_FILE = None
38
+ SUBMISSION_FILE = None
39
+ SPLIT_TYPE = None # "validation" 或 "test"
40
+
41
+
42
+ def check_and_download_dataset_files():
43
+ """
44
+ 检查并下载完整的 GAIA 数据集文件到本地目录(包含 validation 和 test 的所有文件)
45
+
46
+ Returns:
47
+ bool: 如果文件已存在或下载成功返回 True,否则返回 False
48
+ """
49
+ base_target_dir = Path(DATA_PATH) / "2023"
50
+ validation_dir = base_target_dir / "validation"
51
+ test_dir = base_target_dir / "test"
52
+
53
+ # 检查两个目录是否都存在且有文件
54
+ validation_files = list(validation_dir.glob("*")) if validation_dir.exists() else []
55
+ test_files = list(test_dir.glob("*")) if test_dir.exists() else []
56
+
57
+ if validation_files and test_files:
58
+ print(f"✅ 检测到数据集文件已存在")
59
+ print(f" validation 文件数: {len(validation_files)}")
60
+ print(f" test 文件数: {len(test_files)}")
61
+ return True
62
+
63
+ # 需要下载数据集文件
64
+ print(f"📥 开始下载完整的 GAIA 数据集文件...")
65
+ print(f" 目标目录: {base_target_dir}")
66
+
67
+ try:
68
+ # 创建基础目录
69
+ base_target_dir.mkdir(parents=True, exist_ok=True)
70
+
71
+ # 下载完整数据集到临时目录,然后复制到目标目录
72
+ print(" 步骤 1/4: 正在从 Hugging Face 下载完整数据集...")
73
+ print(" 提示: 下载进度会显示在下方,请耐心等待...")
74
+ download_start = time.time()
75
+
76
+ # 使用 snapshot_download 下载完整数据集
77
+ cache_dir = snapshot_download(
78
+ repo_id="gaia-benchmark/GAIA",
79
+ repo_type="dataset",
80
+ local_dir=None, # 使用默认缓存目录
81
+ resume_download=True
82
+ )
83
+
84
+ download_duration = time.time() - download_start
85
+ print(f" ✅ 数据集下载完成,耗时 {download_duration:.2f} 秒")
86
+
87
+ cache_path = Path(cache_dir)
88
+ source_2023_dir = cache_path / "2023"
89
+
90
+ if not source_2023_dir.exists():
91
+ print(f" ❌ 错误: 缓存目录中未找到 2023 目录")
92
+ return False
93
+
94
+ # 复制 validation 和 test 目录
95
+ print(" 步骤 2/4: 正在复制 validation 文件...")
96
+ validation_source = source_2023_dir / "validation"
97
+ if validation_source.exists():
98
+ if validation_dir.exists():
99
+ shutil.rmtree(validation_dir)
100
+ shutil.copytree(validation_source, validation_dir)
101
+ validation_count = len(list(validation_dir.glob("*")))
102
+ print(f" ✅ validation 文件复制完成,共 {validation_count} 个文件")
103
+ else:
104
+ print(f" ⚠️ 警告: 未找到 validation 目录")
105
+
106
+ print(" 步骤 3/4: 正在复制 test 文件...")
107
+ test_source = source_2023_dir / "test"
108
+ if test_source.exists():
109
+ if test_dir.exists():
110
+ shutil.rmtree(test_dir)
111
+ shutil.copytree(test_source, test_dir)
112
+ test_count = len(list(test_dir.glob("*")))
113
+ print(f" ✅ test 文件复制完成,共 {test_count} 个文件")
114
+ else:
115
+ print(f" ⚠️ 警告: 未找到 test 目录")
116
+
117
+ print(" 步骤 4/4: 数据集文件准备完成!")
118
+ print(f" 目标目录: {base_target_dir}")
119
+ return True
120
+
121
+ except Exception as e:
122
+ print(f" ❌ 下载数据集文件时出错: {e}")
123
+ import traceback
124
+ traceback.print_exc()
125
+ return False
126
+
127
+
128
+ def build_ordered_record(task_id, question, level, agent_answer, duration, has_file,
129
+ session_id=None, attachment_name=None, ground_truth=None, is_correct=None):
130
+ """
131
+ 按照固定顺序构建记录字典,确保字段顺序一致
132
+
133
+ Args:
134
+ task_id: 任务ID
135
+ question: 问题
136
+ level: 难度级别
137
+ agent_answer: Agent答案
138
+ duration: 执行时长
139
+ has_file: 是否有文件
140
+ session_id: 会话 ID
141
+ attachment_name: 附件名称(如果有附件)
142
+ ground_truth: 标准答案(仅validation数据集)
143
+ is_correct: 是否正确(仅validation数据集)
144
+
145
+ Returns:
146
+ OrderedDict: 按固定顺序排列的记录
147
+ """
148
+ record = OrderedDict()
149
+ record["task_id"] = task_id
150
+ record["question"] = question
151
+ record["level"] = level
152
+ record["duration"] = duration
153
+ record["has_file"] = has_file
154
+ # attachment_name: 如果有值就写入(即使 agent 出错也应该写入)
155
+ # 只要 attachment_name 不是 None 且不是空字符串,就写入
156
+ if attachment_name and attachment_name.strip():
157
+ record["attachment_name"] = attachment_name
158
+ # session_id: 如果有值就写入(如果 agent 出错可能为 None,不写入是合理的)
159
+ if session_id:
160
+ record["session_id"] = session_id
161
+ record["agent_answer"] = agent_answer
162
+ # validation 数据集特有字段
163
+ if ground_truth is not None:
164
+ record["ground_truth"] = ground_truth
165
+ if is_correct is not None:
166
+ record["is_correct"] = is_correct
167
+
168
+ return record
169
+
170
+
171
+ def load_existing_results():
172
+ """
173
+ 加载已有的测试结果文件
174
+ 返回: dict, task_id -> 完整记录字典
175
+ """
176
+ if not os.path.exists(OUTPUT_FILE):
177
+ return {}
178
+
179
+ results = {}
180
+ try:
181
+ with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
182
+ for line in f:
183
+ if not line.strip():
184
+ continue
185
+ try:
186
+ data = json.loads(line)
187
+ task_id = data.get("task_id")
188
+ if task_id:
189
+ results[task_id] = data
190
+ except json.JSONDecodeError:
191
+ continue
192
+ print(f"✅ 已加载 {len(results)} 条历史记录")
193
+ except Exception as e:
194
+ print(f"⚠️ 加载历史记录时出错: {e}")
195
+ return {}
196
+
197
+ return results
198
+
199
+
200
+ def update_result_in_file(task_id, new_record):
201
+ """
202
+ 更新 jsonl 文件中指定 task_id 的记录
203
+ 使用临时文件方式,确保线程安全
204
+ """
205
+ if not os.path.exists(OUTPUT_FILE):
206
+ # 如果文件不存在,直接写入
207
+ with file_lock:
208
+ with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
209
+ f.write(json.dumps(new_record, ensure_ascii=False) + "\n")
210
+ return
211
+
212
+ # 读取所有记录,更新指定记录,写回文件
213
+ with file_lock:
214
+ temp_file = OUTPUT_FILE + ".tmp"
215
+ updated = False
216
+
217
+ try:
218
+ with open(OUTPUT_FILE, "r", encoding="utf-8") as f_in, \
219
+ open(temp_file, "w", encoding="utf-8") as f_out:
220
+ for line in f_in:
221
+ if not line.strip():
222
+ continue
223
+ try:
224
+ data = json.loads(line)
225
+ if data.get("task_id") == task_id:
226
+ # 更新这条记录
227
+ f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
228
+ updated = True
229
+ else:
230
+ # 保持原记录
231
+ f_out.write(line)
232
+ except json.JSONDecodeError:
233
+ continue
234
+
235
+ # 如果没找到要更新的记录,追加新记录
236
+ if not updated:
237
+ f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
238
+
239
+ # 替换原文件
240
+ os.replace(temp_file, OUTPUT_FILE)
241
+ except Exception as e:
242
+ # 如果出错,删除临时文件
243
+ if os.path.exists(temp_file):
244
+ os.remove(temp_file)
245
+ raise e
246
+
247
+
248
+ def upload_file(local_path):
249
+ """上传文件并返回符合接口要求的 file_id 和 filename"""
250
+ try:
251
+ if not os.path.exists(local_path):
252
+ print(f"❌ 本地文件不存在: {local_path}")
253
+ return None
254
+
255
+ with open(local_path, 'rb') as f:
256
+ files = {'file': f}
257
+ response = requests.post(UPLOAD_URL, headers=HEADERS, files=files, timeout=60)
258
+ response.raise_for_status()
259
+ res_data = response.json()
260
+
261
+ if res_data.get("code") == 0:
262
+ file_info = res_data.get("data", {})
263
+ return {
264
+ "file_id": file_info.get("file_id"),
265
+ "filename": file_info.get("filename")
266
+ }
267
+ else:
268
+ print(f"❌ 上传接口返回错误: {res_data.get('msg')}")
269
+ except Exception as e:
270
+ print(f"❌ 文件上传异常 ({os.path.basename(local_path)}): {e}")
271
+ return None
272
+
273
+
274
+ def extract_answer(text):
275
+ """从文本中提取答案"""
276
+ if not text:
277
+ return ""
278
+ pattern = r"(?si)<\s*answer\s*>\s*(.*?)\s*</\s*answer\s*>"
279
+ match = re.search(pattern, text)
280
+ if match:
281
+ ans = match.group(1).strip()
282
+ return re.sub(r'^["\']|["\']$', '', ans)
283
+ backup_pattern = r"(?i)answer\s*is[::]\s*(.*)"
284
+ backup_match = re.search(backup_pattern, text)
285
+ if backup_match:
286
+ return backup_match.group(1).strip().rstrip('.')
287
+ lines = [l.strip() for l in text.strip().split('\n') if l.strip()]
288
+ return lines[-1] if lines else text.strip()
289
+
290
+
291
+ def call_my_agent_safe(question, attachments=None, task_id=None):
292
+ """
293
+ 发送对话请求,包含附件数组
294
+
295
+ Args:
296
+ question: 问题内容
297
+ attachments: 附件列表
298
+ task_id: 任务ID,用于确保会话隔离
299
+
300
+ Returns:
301
+ tuple: (parsed_answer, session_id, raw_content)
302
+ """
303
+ guided_prompt = (
304
+ f"{question}\n\n Important Requirement: \nprovide the final answer (the answer only, without explanation) inside the tags in the following format: <answer>your answer</answer>"
305
+ )
306
+
307
+ payload = {
308
+ "message": guided_prompt,
309
+ "streaming": False,
310
+ "attachments": attachments if attachments else [],
311
+ "recycle_sandbox": True,
312
+ # 明确指定创建新会话,避免会话内容混乱
313
+ # 如果 API 支持 session_id 参数,设置为 null 表示创建新会话
314
+ # 如果不支持,则不传递 session_id 参数(当前做法)
315
  }
316
+
317
+ # 如果 API 支持,可以尝试以下方式之一来确保创建新会话:
318
+ # 1. payload["session_id"] = None # 明确创建新会话
319
+ # 2. payload["new_session"] = True # 如果 API 支持此参数
320
+ # 3. 在请求头中添加唯一标识
321
+
322
+ if task_id:
323
+ # 添加 task_id 作为请求标识,帮助后端区分不同请求,确保会话隔离
324
+ payload["task_id"] = task_id
325
+
326
+ # 在请求头中添加唯一标识,进一步确保请求隔离
327
+ # 如果后端支持,可以通过 X-Request-ID 或类似头部来区分请求
328
+ request_headers = HEADERS.copy()
329
+ if task_id:
330
+ # 添加 task_id 到请求头,帮助后端识别和隔离不同请求
331
+ request_headers["X-Task-ID"] = task_id
332
+
333
+ try:
334
+ response = requests.post(CHAT_URL, headers=request_headers, json=payload, timeout=(30, REQUEST_TIMEOUT))
335
+ response.raise_for_status()
336
+ res_data = response.json()
337
+ raw_content = (res_data.get("answer") or res_data.get("content") or res_data.get("response") or "").strip()
338
+ session_id = res_data.get("session_id")
339
+ parsed_answer = extract_answer(raw_content)
340
+ return parsed_answer, session_id, raw_content
341
+ except Exception as e:
342
+ error_traceback = traceback.format_exc()
343
+ return f"ERROR: {str(e)}", session_id, error_traceback
344
+
345
+
346
+ def process_item(item, existing_results, mode):
347
+ """
348
+ 处理单条数据:上传文件 -> 发起对话 -> 记录结果
349
+ hybrid + validation 模式下:如果记录已存在且 is_correct 为 true,则跳过 agent 调用,只刷新字段顺序
350
+ 其他情况:所有记录都重新执行并刷新,确保字段顺序一致
351
+ (test 数据集没有 is_correct 字段,无法判断是否正确,所以总是重新执行)
352
+
353
+ Args:
354
+ item: 数据集项
355
+ existing_results: 已有结果字典
356
+ mode: 执行模式 ("full"、"incremental" 或 "hybrid")
357
+ """
358
+ task_id = item['task_id']
359
+ level = item.get('Level', 'Unknown')
360
+ question = item['Question']
361
+ file_name = item.get('file_name', "")
362
+
363
+ # hybrid + validation 模式下:如果记录已存在且成功,只刷新字段顺序,不调用 agent
364
+ # 只有 validation 数据集有 is_correct 字段,可以判断是否正确
365
+ if mode == "hybrid" and SPLIT_TYPE == "validation" and task_id in existing_results:
366
+ existing_record = existing_results[task_id]
367
+ if existing_record.get("is_correct", False):
368
+ # 已成功,只刷新字段顺序,不调用 agent
369
+ # 使用当前的 file_name 更新 attachment_name,确保数据一致性
370
+ current_has_file = bool(file_name)
371
+ current_attachment_name = file_name if file_name else None
372
+ record = build_ordered_record(
373
+ task_id=task_id,
374
+ question=existing_record.get("question", question),
375
+ level=existing_record.get("level", level),
376
+ agent_answer=existing_record.get("agent_answer", ""),
377
+ duration=existing_record.get("duration", 0),
378
+ has_file=current_has_file,
379
+ session_id=existing_record.get("session_id"),
380
+ attachment_name=current_attachment_name,
381
+ ground_truth=existing_record.get("ground_truth", ""),
382
+ is_correct=True
383
+ )
384
+ # 更新已有记录(刷新字段顺序)
385
+ update_result_in_file(task_id, record)
386
+ return task_id, True, "refreshed"
387
+
388
+ # 需要调用 agent 的情况(新记录、错误记录、或非 hybrid 模式)
389
+ attachments = []
390
 
391
+ # 1. 如果有文件,先执行上传
392
+ if file_name:
393
+ # 根据 split 类型选择不同的文件夹
394
+ folder = "validation" if SPLIT_TYPE == "validation" else "test"
395
+ local_file_path = os.path.abspath(os.path.join(DATA_PATH, "2023", folder, file_name))
396
+ upload_data = upload_file(local_file_path)
397
+ if upload_data:
398
+ attachments.append(upload_data)
399
+
400
+ # 2. 调用 Agent(传递 task_id 确保会话隔离)
401
+ start_time = time.time()
402
+ agent_answer, session_id, _ = call_my_agent_safe(question, attachments, task_id=task_id)
403
+ duration = time.time() - start_time
404
+
405
+ # 3. 构建记录(使用固定顺序)
406
+ if SPLIT_TYPE == "validation":
407
+ # validation 数据集:添加标准答案和正确性判断
408
+ ground_truth = str(item['Final answer']).strip()
409
+ clean_agent = str(agent_answer).lower().rstrip('.')
410
+ clean_gt = ground_truth.lower().rstrip('.')
411
+ is_correct = (clean_agent == clean_gt)
412
+ record = build_ordered_record(
413
+ task_id=task_id,
414
+ question=question,
415
+ level=level,
416
+ duration=round(duration, 2),
417
+ has_file=bool(file_name),
418
+ session_id=session_id,
419
+ attachment_name=file_name if file_name else None,
420
+ agent_answer=agent_answer,
421
+ ground_truth=ground_truth,
422
+ is_correct=is_correct
423
+ )
424
+ result_correct = is_correct
425
+ else: # test 数据集:没有标准答案
426
+ record = build_ordered_record(
427
+ task_id=task_id,
428
+ question=question,
429
+ level=level,
430
+ duration=round(duration, 2),
431
+ has_file=bool(file_name),
432
+ session_id=session_id,
433
+ attachment_name=file_name if file_name else None,
434
+ agent_answer=agent_answer
435
+ )
436
+ result_correct = None
437
+
438
+ # 4. 更新或追加记录
439
+ if task_id in existing_results:
440
+ # 更新已有记录
441
+ update_result_in_file(task_id, record)
442
+ return task_id, result_correct, "updated"
443
+ else:
444
+ # 追加新记录
445
+ with file_lock:
446
+ with open(OUTPUT_FILE, "a", encoding="utf-8") as f:
447
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
448
+ return task_id, result_correct, "new"
449
+
450
+
451
+ def generate_submission():
452
+ """
453
+ 生成官网提交格式文件
454
+ GAIA 提交格式要求:
455
+ - 文件格式:JSONL(每行一个 JSON 对象)
456
+ - 必需字段:task_id, model_answer
457
+ - 编码:UTF-8
458
+ - test 数据集需要包含所有 285 个用例的答案
459
+ """
460
+ if not os.path.exists(OUTPUT_FILE):
461
+ print(f"⚠️ 警告:结果文件 {OUTPUT_FILE} 不存在,无法生成提交文件")
462
+ return
463
+
464
+ # 读取所有结果并按 task_id 排序(确保顺序一致)
465
+ results = []
466
+ with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
467
+ for line in f:
468
+ if not line.strip():
469
+ continue
470
+ try:
471
+ data = json.loads(line)
472
+ if "task_id" in data and "agent_answer" in data:
473
+ results.append(data)
474
+ except json.JSONDecodeError:
475
+ continue
476
+
477
+ if not results:
478
+ print(f"⚠️ 警告:结果文件 {OUTPUT_FILE} 中没有有效数据")
479
+ return
480
+
481
+ # 按 task_id 排序,确保顺序一致
482
+ results.sort(key=lambda x: x.get("task_id", ""))
483
+
484
+ # 生成提交文件
485
+ with open(SUBMISSION_FILE, "w", encoding="utf-8") as f_out:
486
+ for data in results:
487
+ submission_data = {
488
+ "task_id": data["task_id"],
489
+ "model_answer": str(data["agent_answer"])
490
+ }
491
+ f_out.write(json.dumps(submission_data, ensure_ascii=False) + "\n")
492
+
493
+ print(f"✅ 提交文件已生成: {SUBMISSION_FILE} (共 {len(results)} 条记录)")
494
+
495
+ # test 数据集验证:检查是否包含所有用例
496
+ if SPLIT_TYPE == "test":
497
+ expected_count = 285
498
+ if len(results) < expected_count:
499
+ print(f"⚠️ 警告:test 数据集应该有 {expected_count} 个用例,当前只有 {len(results)} 个")
500
+ else:
501
+ print(f"✅ test 数据集已包含 {len(results)} 个用例,符合提交要求")
502
+
503
+
504
+ def get_current_accuracy():
505
+ """
506
+ 获取当前的整体正确率(仅 validation 数据集)
507
+
508
+ Returns:
509
+ float or None: 正确率(百分比),如果不是 validation 数据集或文件不存在则返回 None
510
+ """
511
+ # test 数据集没有标准答案,无法计算正确率
512
+ if SPLIT_TYPE != "validation":
513
+ return None
514
 
515
+ if not os.path.exists(OUTPUT_FILE):
516
+ return None
 
517
 
518
+ try:
519
+ results = []
520
+ with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
521
+ for line in f:
522
+ if not line.strip():
523
+ continue
524
+ try:
525
+ data = json.loads(line)
526
+ results.append(data)
527
+ except json.JSONDecodeError:
528
+ continue
529
+
530
+ if not results:
531
+ return None
532
+
533
+ total = len(results)
534
+ correct = sum(1 for r in results if r.get("is_correct", False))
535
+ accuracy = (correct / total * 100) if total > 0 else 0.0
536
+ return accuracy
537
+ except Exception:
538
+ return None
539
+
540
+
541
+ def generate_report():
542
+ """生成统计成绩单(仅 validation 数据集有标准答案,才生成成绩单)"""
543
+ # test 数据集没有标准答案,不生成成绩单
544
+ if SPLIT_TYPE != "validation":
545
+ return
546
+
547
+ if not os.path.exists(OUTPUT_FILE):
548
+ return
549
+
550
+ results = [json.loads(line) for line in open(OUTPUT_FILE, "r", encoding="utf-8")]
551
+ df = pd.DataFrame(results)
552
+ total = len(df)
553
+ acc = (df['is_correct'].sum() / total) * 100
554
+
555
+ print("\n" + "=" * 50)
556
+ print(f"测试完成! 总数: {total} | 总准确率: {acc:.2f}%")
557
+ print("=" * 50)
558
+
559
+
560
+ def run_test_concurrent(num_questions=200, mode="hybrid", split="validation", threads=MAX_CONCURRENT, target_task_id=None):
561
+ """
562
+ 测试主函数
563
+
564
+ Args:
565
+ num_questions: 要执行的用例数量
566
+ mode: 执行模式
567
+ - "full": 全量模式,不管文件是否存在,都删除重新开始
568
+ - "incremental": 增量模式,如果文件存在则增量,不存在则全量执行
569
+ - "hybrid": 混合模式,第一次时全量(文件不存在),后面就增量(文件存在)
570
+ - "error": 错误模式,只重新执行 agent_answer 包含 ERROR 的记录
571
+ split: 数据集类型,"validation" 或 "test"
572
+ threads: 并发线程数
573
+ target_task_id: 可选,指定要运行的 task_id,如果指定则只运行该用例
574
+ """
575
+ global dataset, OUTPUT_FILE, SUBMISSION_FILE, SPLIT_TYPE
576
+
577
+ # 设置全局变量
578
+ SPLIT_TYPE = split
579
+
580
+ # 根据 split 类型设置输出文件名
581
+ if split == "validation":
582
+ OUTPUT_FILE = "validation_results.jsonl"
583
+ SUBMISSION_FILE = "validation_submission.jsonl"
584
+ print("📥 正在检查 GAIA 验证集数据...")
585
+ else: # test
586
+ OUTPUT_FILE = "test_results.jsonl"
587
+ SUBMISSION_FILE = "test_submission.jsonl"
588
+ print("📥 正在检查 GAIA 测试集数据...")
589
+
590
+ # 1. 检查并下载完整数据集文件(如果需要,包含 validation 和 test 的所有文件)
591
+ print("\n【步骤 1/2】检查数据集文件...")
592
+ check_and_download_dataset_files()
593
+
594
+ # 2. 加载数据集元数据(如果首次下载会显示下载进度)
595
+ print(f"\n【步骤 2/2】加载数据集元数据...")
596
+ print(f" 数据集: gaia-benchmark/GAIA (2023_all, split={split})")
597
+ print(" 提示: 如果是首次下载,请耐心等待,下载进度会显示在下方...")
598
+ print(" 如果已下载过,会直接从缓存加载,速度较快")
599
+ start_time = time.time()
600
+ dataset = load_dataset("gaia-benchmark/GAIA", "2023_all", split=split)
601
+ load_duration = time.time() - start_time
602
+ print(f"✅ 数据集元数据加载完成!共 {len(dataset)} 条记录,耗时 {load_duration:.2f} 秒\n")
603
+
604
+ # 1. 根据模式处理已有结果
605
+ file_exists = os.path.exists(OUTPUT_FILE)
606
+
607
+ if mode == "full":
608
+ # 全量模式:删除旧文件,从头开始
609
+ if file_exists:
610
+ os.remove(OUTPUT_FILE)
611
+ print("🔄 全量模式:已删除旧结果文件,从头开始执行")
612
+ existing_results = {}
613
+ elif mode == "incremental":
614
+ # 增量模式:如果文件存在则增量,不存在则全量执行
615
+ if file_exists:
616
+ existing_results = load_existing_results()
617
+ print(f"📋 增量模式:已加载 {len(existing_results)} 条历史记录")
618
+ else:
619
+ existing_results = {}
620
+ print("📋 增量模式:未找到历史记录,将全量执行")
621
+ elif mode == "error":
622
+ # 错误模式:只重新执行 agent_answer 包含 ERROR 的记录
623
+ if file_exists:
624
+ existing_results = load_existing_results()
625
+ print(f"📋 错误模式:已加载 {len(existing_results)} 条历史记录,将重新执行包含 ERROR 的记录")
626
+ else:
627
+ existing_results = {}
628
+ print("📋 错误模式:未找到历史记录,无法执行错误重试")
629
+ else: # hybrid
630
+ # 混合模式:第一次时全量(文件不存在),后面就增量(文件存在)
631
+ if file_exists:
632
+ existing_results = load_existing_results()
633
+ print(f"📋 混合模式:检测到已有文件,进入增量模式(已加载 {len(existing_results)} 条历史记录)")
634
+ else:
635
+ existing_results = {}
636
+ print("📋 混合模式:首次执行,进入全量模式")
637
+
638
+ # 2. 筛选需要执行的用例
639
+ if target_task_id:
640
+ # 如果指定了 task_id,只运行该用例
641
+ print(f"🎯 指定运行 task_id: {target_task_id}")
642
+ tasks_to_run = []
643
+ found = False
644
+ for item in dataset:
645
+ if item['task_id'] == target_task_id:
646
+ tasks_to_run = [item]
647
+ found = True
648
+ break
649
+ if not found:
650
+ print(f"❌ 错误: 在 {split} 数据集中未找到 task_id: {target_task_id}")
651
+ return
652
+ num_to_run = 1
653
+ else:
654
+ # 正常模式,根据 num_questions 筛选
655
+ num_to_run = min(num_questions, len(dataset))
656
+ tasks_to_run = dataset.select(range(num_to_run))
657
+
658
+ # 统计需要执行的用例
659
+ tasks_to_execute = []
660
+ refresh_count = 0 # hybrid 模式下只刷新字段顺序的记录数
661
+ update_count = 0 # 需要重新调用 agent 的记录数
662
+ new_count = 0 # 新记录数
663
+ error_count = 0 # error 模式下包含 ERROR 的记录数
664
+
665
+ for item in tasks_to_run:
666
+ task_id = item['task_id']
667
+
668
+ if mode == "error":
669
+ # error 模式:只重新执行 agent_answer 包含 ERROR 的记录
670
+ if task_id in existing_results:
671
+ agent_answer = existing_results[task_id].get("agent_answer", "")
672
+ if agent_answer and "ERROR" in str(agent_answer):
673
+ error_count += 1
674
+ tasks_to_execute.append(item)
675
+ # 如果记录不存在或 agent_answer 不包含 ERROR,则跳过
676
+ else:
677
+ # 其他模式:正常处理
678
+ if task_id in existing_results:
679
+ # hybrid + validation 模式下:如果已成功,只刷新字段顺序
680
+ # test 数据集没有 is_correct 字段,无法判断是否正确,所以总是重新执行
681
+ if mode == "hybrid" and split == "validation":
682
+ if existing_results[task_id].get("is_correct", False):
683
+ refresh_count += 1
684
+ else:
685
+ update_count += 1
686
+ else:
687
+ # 非 hybrid 模式,或 test 数据集:所有已有记录都需要重新执行
688
+ update_count += 1
689
+ else:
690
+ new_count += 1
691
+ tasks_to_execute.append(item)
692
+
693
+ total_to_execute = len(tasks_to_execute)
694
+
695
+ print(f"\n📊 统计信息:")
696
+ print(f" 数据集: {split}")
697
+ print(f" 执行模式: {mode}")
698
+ if target_task_id:
699
+ print(f" 指定 task_id: {target_task_id}")
700
+ print(f" 总用例数: {num_to_run}")
701
+ if existing_results:
702
+ if mode == "error":
703
+ print(f" 需要执行: {total_to_execute} (包含 ERROR 的记录: {error_count})")
704
+ elif mode == "hybrid":
705
+ print(
706
+ f" 需要执行: {total_to_execute} (新用例: {new_count}, 刷新字段顺序: {refresh_count}, 重新测试: {update_count})")
707
+ else:
708
+ print(
709
+ f" 需要执行: {total_to_execute} (新用例: {new_count}, 刷新已有记录: {refresh_count + update_count})")
710
+ else:
711
+ print(f" 需要执行: {total_to_execute} (全量执行)")
712
+ print(f"🚀 开始测试 | 并发数: {threads} | 待执行: {total_to_execute}")
713
+
714
+ if total_to_execute == 0:
715
+ if mode == "error":
716
+ print("✅ 没有包含 ERROR 的记录,无需执行")
717
+ elif split == "validation":
718
+ print("✅ 所有用例已完成且正确,无需执行")
719
+ else:
720
+ print("✅ 所有用例已完成,无需执行")
721
+ generate_report()
722
+ generate_submission()
723
+ return
724
+
725
+ # 3. 并发执行
726
+ with ThreadPoolExecutor(max_workers=threads) as executor:
727
+ future_to_item = {executor.submit(process_item, item, existing_results, mode): item for item in
728
+ tasks_to_execute}
729
+
730
+ done = 0
731
+ for future in as_completed(future_to_item):
732
+ done += 1
733
+ item = future_to_item[future]
734
+ tid = item['task_id']
735
+ try:
736
+ _, is_ok, status = future.result()
737
+ if status == "refreshed":
738
+ status_icon = "🔄"
739
+ elif split == "validation":
740
+ status_icon = "✅" if is_ok else "❌"
741
+ else: # test
742
+ status_icon = "✅"
743
+
744
+ # 计算并显示当前整体正确率
745
+ accuracy_info = ""
746
+ if split == "validation":
747
+ current_accuracy = get_current_accuracy()
748
+ if current_accuracy is not None:
749
+ accuracy_info = f" | 当前正确率: {current_accuracy:.2f}%"
750
+
751
+ print(f"[{done}/{total_to_execute}] ID: {tid} | 状态: {status_icon} ({status}){accuracy_info}")
752
+ except Exception as e:
753
+ error_traceback = traceback.format_exc()
754
+ print(f"[{done}/{total_to_execute}] ID: {tid} 运行异常: {e}")
755
+ print(f"异常堆栈:\n{error_traceback}")
756
+
757
+ # 4. 生成报表
758
+ generate_report()
759
+ generate_submission()
760
+
761
+
762
+ def print_help():
763
+ """打印详细的帮助信息"""
764
+ print("=" * 70)
765
+ print("GAIA 测试脚本 - 参数说明")
766
+ print("=" * 70)
767
+ print()
768
+ print("用法:")
769
+ print(" python gaia_test.py [参数]")
770
+ print()
771
+ print("参数说明:")
772
+ print()
773
+ print(" --split <类型>")
774
+ print(" 数据集类型")
775
+ print(" 可选值: validation, test")
776
+ print(" 默认值: validation")
777
+ print(" 说明:")
778
+ print(" - validation: 验证集,有标准答案,可以计算正确率")
779
+ print(" - test: 测试集,无标准答案,用于最终提交")
780
+ print()
781
+ print(" --mode <模式>")
782
+ print(" 执行模式")
783
+ print(" 可选值: full, incremental, hybrid, error")
784
+ print(" 默认值: hybrid")
785
+ print(" 说明:")
786
+ print(" - full: 全量模式,删除旧结果文件,从头开始执行")
787
+ print(" - incremental: 增量模式,如果文件存在则增量,不存在则全量执行")
788
+ print(" - hybrid: 混合模式(推荐),首次全量,后续增量")
789
+ print(" 在 hybrid 模式下,validation 数据集中已正确的记录")
790
+ print(" 只刷新字段顺序,不重新调用 agent")
791
+ print(" - error: 错误模式,只重新执行 agent_answer 包含 ERROR 的记录")
792
+ print()
793
+ print(" --num <数量>")
794
+ print(" 要执行的用例数量")
795
+ print(" 类型: 整数")
796
+ print(" 默认值: 200")
797
+ print(" 说明:")
798
+ print(" - test 数据集共 285 题,可以设置 --num 285 执行全部")
799
+ print(" - validation 数据集可以根据需要设置数量")
800
+ print()
801
+ print(" --threads <数量>")
802
+ print(" 并发执行的线程数")
803
+ print(" 类型: 整数")
804
+ print(" 默认值: 2")
805
+ print(" 说明:")
806
+ print(" - 根据服务器性能调整,过高可能导致服务器压力过大")
807
+ print(" - 建议范围: 1-4")
808
+ print()
809
+ print(" --task-id <task_id>")
810
+ print(" 指定要运行的 task_id")
811
+ print(" 类型: 字符串")
812
+ print(" 默认值: 无(运行多个用例)")
813
+ print(" 说明:")
814
+ print(" - 如果指定此参数,则只运行该 task_id 对应的用例")
815
+ print(" - 指定此参数时,--num 参数会被忽略")
816
+ print(" - 如果指定的 task_id 不存在,脚本会报错并退出")
817
+ print()
818
+ print(" -h, --help")
819
+ print(" 显示此帮助信息并退出")
820
+ print()
821
+ print("示例:")
822
+ print(" # 使用默认参数(validation 数据集,hybrid 模式,200 题)")
823
+ print(" python gaia_test.py")
824
+ print()
825
+ print(" # 测试 test 数据集,执行全部 285 题")
826
+ print(" python gaia_test.py --split test --num 285")
827
+ print()
828
+ print(" # 使用 error 模式重新执行错误记录")
829
+ print(" python gaia_test.py --mode error")
830
+ print()
831
+ print(" # 使用全量模式,4 个并发线程")
832
+ print(" python gaia_test.py --mode full --threads 4")
833
+ print()
834
+ print(" # 运行指定的 task_id")
835
+ print(" python gaia_test.py --task-id c61d22de-5f6c-4958-a7f6-5e9707bd3466")
836
+ print()
837
+ print("=" * 70)
838
+ print("配置文件:")
839
+ print(" 运行前请确保已正确配置 gaia_test.py 中的以下参数:")
840
+ print(" - BASE_URL: API 服务地址")
841
+ print(" - HEADERS: 认证 Token(必须修改)")
842
+ print(" - DATA_PATH: 数据文件路径")
843
+ print(" - REQUEST_TIMEOUT: 请求超时时间")
844
+ print(" - MAX_CONCURRENT: 最大并发数")
845
+ print("=" * 70)
846
+
847
+
848
+ if __name__ == "__main__":
849
+ import sys
850
+
851
+ # 检查是否有 -h 或 --help 参数
852
+ if "-h" in sys.argv or "--help" in sys.argv:
853
+ print_help()
854
+ sys.exit(0)
855
+
856
+ parser = argparse.ArgumentParser(
857
+ description="GAIA 测试脚本(支持 validation 和 test 数据集,支持全量、���量、混合三种模式)",
858
+ formatter_class=argparse.RawDescriptionHelpFormatter
859
+ )
860
+ parser.add_argument(
861
+ "--split",
862
+ type=str,
863
+ choices=["validation", "test"],
864
+ default="validation",
865
+ help="数据集类型: 'validation' 验证集(有标准答案)、'test' 测试集(无标准答案,默认: validation)"
866
+ )
867
+ parser.add_argument(
868
+ "--mode",
869
+ type=str,
870
+ choices=["full", "incremental", "hybrid", "error"],
871
+ default="hybrid",
872
+ help="执行模式: 'full' 全量模式(删除旧文件重新执行)、'incremental' 增量模式(文件存在则增量,不存在则全量)、'hybrid' 混合模式(首次全量,后续增量,默认)、'error' 错误模式(只重新执行 agent_answer 包含 ERROR 的记录)"
873
+ )
874
+ parser.add_argument(
875
+ "--num",
876
+ type=int,
877
+ default=200,
878
+ help="要执行的用例数量(默认: 200,test 集共 285 题)"
879
+ )
880
+ parser.add_argument(
881
+ "--threads",
882
+ type=int,
883
+ default=MAX_CONCURRENT,
884
+ help="执行的并发数(默认: 2)"
885
+ )
886
+ parser.add_argument(
887
+ "--task-id",
888
+ type=str,
889
+ default=None,
890
+ help="指定要运行的 task_id,如果指定则只运行该用例(忽略 --num 参数)"
891
+ )
892
+
893
+ args = parser.parse_args()
894
 
895
+ run_test_concurrent(num_questions=args.num, mode=args.mode, split=args.split, threads=args.threads, target_task_id=args.task_id)