xiaobin123 commited on
Commit
07f1898
·
verified ·
1 Parent(s): b77683b

Delete gaia_test.py

Browse files
Files changed (1) hide show
  1. gaia_test.py +0 -895
gaia_test.py DELETED
@@ -1,895 +0,0 @@
1
- """
2
- 测试脚本:支持 validation 和 test 数据集,支持全量、增量、混合三种模式
3
- - 全量模式(full):不管文件是否存在,都删除重新开始
4
- - 增量模式(incremental):如果文件存在则增量,不存在则全量执行
5
- - 混合模式(hybrid):第一次时全量(文件不存在),后面就增量(文件存在)
6
- """
7
- import argparse
8
- import json
9
- import os
10
- import re
11
- import time
12
- import traceback
13
- from collections import OrderedDict
14
- from concurrent.futures import ThreadPoolExecutor, as_completed
15
- from threading import Lock
16
-
17
- import pandas as pd
18
- import requests
19
- from datasets import load_dataset
20
- from huggingface_hub import snapshot_download, hf_hub_download
21
- from pathlib import Path
22
- import shutil
23
-
24
- # --- 1. 配置区 ---
25
- BASE_URL = "http://localhost:5173/api/v1"
26
- CHAT_URL = f"{BASE_URL}/sessions/chat"
27
- UPLOAD_URL = f"{BASE_URL}/files"
28
- HEADERS = {"Authorization": "Bearer hawk_YhCZLQYqtPOwOiEyEgeCNdfAFAbrHtTUxQvRiaOInyekgVgE"}
29
- DATA_PATH = "./gaia_data"
30
- REQUEST_TIMEOUT = 1800
31
- MAX_CONCURRENT = 2
32
-
33
- file_lock = Lock()
34
-
35
- # 全局变量,根据 split 类型动态设置
36
- dataset = None
37
- OUTPUT_FILE = None
38
- SUBMISSION_FILE = None
39
- SPLIT_TYPE = None # "validation" 或 "test"
40
-
41
-
42
- def check_and_download_dataset_files():
43
- """
44
- 检查并下载完整的 GAIA 数据集文件到本地目录(包含 validation 和 test 的所有文件)
45
-
46
- Returns:
47
- bool: 如果文件已存在或下载成功返回 True,否则返回 False
48
- """
49
- base_target_dir = Path(DATA_PATH) / "2023"
50
- validation_dir = base_target_dir / "validation"
51
- test_dir = base_target_dir / "test"
52
-
53
- # 检查两个目录是否都存在且有文件
54
- validation_files = list(validation_dir.glob("*")) if validation_dir.exists() else []
55
- test_files = list(test_dir.glob("*")) if test_dir.exists() else []
56
-
57
- if validation_files and test_files:
58
- print(f"✅ 检测到数据集文件已存在")
59
- print(f" validation 文件数: {len(validation_files)}")
60
- print(f" test 文件数: {len(test_files)}")
61
- return True
62
-
63
- # 需要下载数据集文件
64
- print(f"📥 开始下载完整的 GAIA 数据集文件...")
65
- print(f" 目标目录: {base_target_dir}")
66
-
67
- try:
68
- # 创建基础目录
69
- base_target_dir.mkdir(parents=True, exist_ok=True)
70
-
71
- # 下载完整数据集到临时目录,然后复制到目标目录
72
- print(" 步骤 1/4: 正在从 Hugging Face 下载完整数据集...")
73
- print(" 提示: 下载进度会显示在下方,请耐心等待...")
74
- download_start = time.time()
75
-
76
- # 使用 snapshot_download 下载完整数据集
77
- cache_dir = snapshot_download(
78
- repo_id="gaia-benchmark/GAIA",
79
- repo_type="dataset",
80
- local_dir=None, # 使用默认缓存目录
81
- resume_download=True
82
- )
83
-
84
- download_duration = time.time() - download_start
85
- print(f" ✅ 数据集下载完成,耗时 {download_duration:.2f} 秒")
86
-
87
- cache_path = Path(cache_dir)
88
- source_2023_dir = cache_path / "2023"
89
-
90
- if not source_2023_dir.exists():
91
- print(f" ❌ 错误: 缓存目录中未找到 2023 目录")
92
- return False
93
-
94
- # 复制 validation 和 test 目录
95
- print(" 步骤 2/4: 正在复制 validation 文件...")
96
- validation_source = source_2023_dir / "validation"
97
- if validation_source.exists():
98
- if validation_dir.exists():
99
- shutil.rmtree(validation_dir)
100
- shutil.copytree(validation_source, validation_dir)
101
- validation_count = len(list(validation_dir.glob("*")))
102
- print(f" ✅ validation 文件复制完成,共 {validation_count} 个文件")
103
- else:
104
- print(f" ⚠️ 警告: 未找到 validation 目录")
105
-
106
- print(" 步骤 3/4: 正在复制 test 文件...")
107
- test_source = source_2023_dir / "test"
108
- if test_source.exists():
109
- if test_dir.exists():
110
- shutil.rmtree(test_dir)
111
- shutil.copytree(test_source, test_dir)
112
- test_count = len(list(test_dir.glob("*")))
113
- print(f" ✅ test 文件复制完成,共 {test_count} 个文件")
114
- else:
115
- print(f" ⚠️ 警告: 未找到 test 目录")
116
-
117
- print(" 步骤 4/4: 数据集文件准备完成!")
118
- print(f" 目标目录: {base_target_dir}")
119
- return True
120
-
121
- except Exception as e:
122
- print(f" ❌ 下载数据集文件时出错: {e}")
123
- import traceback
124
- traceback.print_exc()
125
- return False
126
-
127
-
128
- def build_ordered_record(task_id, question, level, agent_answer, duration, has_file,
129
- session_id=None, attachment_name=None, ground_truth=None, is_correct=None):
130
- """
131
- 按照固定顺序构建记录字典,确保字段��序一致
132
-
133
- Args:
134
- task_id: 任务ID
135
- question: 问题
136
- level: 难度级别
137
- agent_answer: Agent答案
138
- duration: 执行时长
139
- has_file: 是否有文件
140
- session_id: 会话 ID
141
- attachment_name: 附件名称(如果有附件)
142
- ground_truth: 标准答案(仅validation数据集)
143
- is_correct: 是否正确(仅validation数据集)
144
-
145
- Returns:
146
- OrderedDict: 按固定顺序排列的记录
147
- """
148
- record = OrderedDict()
149
- record["task_id"] = task_id
150
- record["question"] = question
151
- record["level"] = level
152
- record["duration"] = duration
153
- record["has_file"] = has_file
154
- # attachment_name: 如果有值就写入(即使 agent 出错也应该写入)
155
- # 只要 attachment_name 不是 None 且不是空字符串,就写入
156
- if attachment_name and attachment_name.strip():
157
- record["attachment_name"] = attachment_name
158
- # session_id: 如果有值就写入(如果 agent 出错可能为 None,不写入是合理的)
159
- if session_id:
160
- record["session_id"] = session_id
161
- record["agent_answer"] = agent_answer
162
- # validation 数据集特有字段
163
- if ground_truth is not None:
164
- record["ground_truth"] = ground_truth
165
- if is_correct is not None:
166
- record["is_correct"] = is_correct
167
-
168
- return record
169
-
170
-
171
- def load_existing_results():
172
- """
173
- 加载已有的测试结果文件
174
- 返回: dict, task_id -> 完整记录字典
175
- """
176
- if not os.path.exists(OUTPUT_FILE):
177
- return {}
178
-
179
- results = {}
180
- try:
181
- with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
182
- for line in f:
183
- if not line.strip():
184
- continue
185
- try:
186
- data = json.loads(line)
187
- task_id = data.get("task_id")
188
- if task_id:
189
- results[task_id] = data
190
- except json.JSONDecodeError:
191
- continue
192
- print(f"✅ 已加载 {len(results)} 条历史记录")
193
- except Exception as e:
194
- print(f"⚠️ 加载历史记录时出错: {e}")
195
- return {}
196
-
197
- return results
198
-
199
-
200
- def update_result_in_file(task_id, new_record):
201
- """
202
- 更新 jsonl 文件中指定 task_id 的记录
203
- 使用临时文件方式,确保线程安全
204
- """
205
- if not os.path.exists(OUTPUT_FILE):
206
- # 如果文件不存在,直接写入
207
- with file_lock:
208
- with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
209
- f.write(json.dumps(new_record, ensure_ascii=False) + "\n")
210
- return
211
-
212
- # 读取所有记录,更新指定记录,写回文件
213
- with file_lock:
214
- temp_file = OUTPUT_FILE + ".tmp"
215
- updated = False
216
-
217
- try:
218
- with open(OUTPUT_FILE, "r", encoding="utf-8") as f_in, \
219
- open(temp_file, "w", encoding="utf-8") as f_out:
220
- for line in f_in:
221
- if not line.strip():
222
- continue
223
- try:
224
- data = json.loads(line)
225
- if data.get("task_id") == task_id:
226
- # 更新这条记录
227
- f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
228
- updated = True
229
- else:
230
- # 保持原记录
231
- f_out.write(line)
232
- except json.JSONDecodeError:
233
- continue
234
-
235
- # 如果没找到要更新的记录,追加新记录
236
- if not updated:
237
- f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
238
-
239
- # 替换原文件
240
- os.replace(temp_file, OUTPUT_FILE)
241
- except Exception as e:
242
- # 如果出错,删除临时文件
243
- if os.path.exists(temp_file):
244
- os.remove(temp_file)
245
- raise e
246
-
247
-
248
- def upload_file(local_path):
249
- """上传文件并返回符合接口要求的 file_id 和 filename"""
250
- try:
251
- if not os.path.exists(local_path):
252
- print(f"❌ 本地文件不存在: {local_path}")
253
- return None
254
-
255
- with open(local_path, 'rb') as f:
256
- files = {'file': f}
257
- response = requests.post(UPLOAD_URL, headers=HEADERS, files=files, timeout=60)
258
- response.raise_for_status()
259
- res_data = response.json()
260
-
261
- if res_data.get("code") == 0:
262
- file_info = res_data.get("data", {})
263
- return {
264
- "file_id": file_info.get("file_id"),
265
- "filename": file_info.get("filename")
266
- }
267
- else:
268
- print(f"❌ 上传接口返回错误: {res_data.get('msg')}")
269
- except Exception as e:
270
- print(f"❌ 文件上传异常 ({os.path.basename(local_path)}): {e}")
271
- return None
272
-
273
-
274
- def extract_answer(text):
275
- """从文本中提取答案"""
276
- if not text:
277
- return ""
278
- pattern = r"(?si)<\s*answer\s*>\s*(.*?)\s*</\s*answer\s*>"
279
- match = re.search(pattern, text)
280
- if match:
281
- ans = match.group(1).strip()
282
- return re.sub(r'^["\']|["\']$', '', ans)
283
- backup_pattern = r"(?i)answer\s*is[::]\s*(.*)"
284
- backup_match = re.search(backup_pattern, text)
285
- if backup_match:
286
- return backup_match.group(1).strip().rstrip('.')
287
- lines = [l.strip() for l in text.strip().split('\n') if l.strip()]
288
- return lines[-1] if lines else text.strip()
289
-
290
-
291
- def call_my_agent_safe(question, attachments=None, task_id=None):
292
- """
293
- 发送对话请求,包含附件数组
294
-
295
- Args:
296
- question: 问题内容
297
- attachments: 附件列表
298
- task_id: 任务ID,用于确保会话隔离
299
-
300
- Returns:
301
- tuple: (parsed_answer, session_id, raw_content)
302
- """
303
- guided_prompt = (
304
- f"{question}\n\n Important Requirement: \nprovide the final answer (the answer only, without explanation) inside the tags in the following format: <answer>your answer</answer>"
305
- )
306
-
307
- payload = {
308
- "message": guided_prompt,
309
- "streaming": False,
310
- "attachments": attachments if attachments else [],
311
- "recycle_sandbox": True,
312
- # 明确指定创建新会话,避免会话内容混乱
313
- # 如果 API 支持 session_id 参数,设置为 null 表示创建新会话
314
- # 如果不支持,则不传递 session_id 参数(当前做法)
315
- }
316
-
317
- # 如果 API 支持,可以尝试以下方式之一来确保创建新会话:
318
- # 1. payload["session_id"] = None # 明确创建新会话
319
- # 2. payload["new_session"] = True # 如果 API 支持此参数
320
- # 3. 在请求头中添加唯一标识
321
-
322
- if task_id:
323
- # 添加 task_id 作为请求标识,帮助后端区分不同请求,确保会话隔离
324
- payload["task_id"] = task_id
325
-
326
- # 在请求头中添加唯一标识,进一步确保请求隔离
327
- # 如果后端支持,可以通过 X-Request-ID 或类似头部来区分请求
328
- request_headers = HEADERS.copy()
329
- if task_id:
330
- # 添加 task_id 到请求头,帮助后端识别和隔离不同请求
331
- request_headers["X-Task-ID"] = task_id
332
-
333
- try:
334
- response = requests.post(CHAT_URL, headers=request_headers, json=payload, timeout=(30, REQUEST_TIMEOUT))
335
- response.raise_for_status()
336
- res_data = response.json()
337
- raw_content = (res_data.get("answer") or res_data.get("content") or res_data.get("response") or "").strip()
338
- session_id = res_data.get("session_id")
339
- parsed_answer = extract_answer(raw_content)
340
- return parsed_answer, session_id, raw_content
341
- except Exception as e:
342
- error_traceback = traceback.format_exc()
343
- return f"ERROR: {str(e)}", session_id, error_traceback
344
-
345
-
346
- def process_item(item, existing_results, mode):
347
- """
348
- 处理单条数据:上传文件 -> 发起对话 -> 记录结果
349
- hybrid + validation 模式下:如果记录已存在且 is_correct 为 true,则跳过 agent 调用,只刷新字段顺序
350
- 其他情况:所有记录都重新执行并刷新,确保字段顺序一致
351
- (test 数据集没有 is_correct 字段,无法判断是否正确,所以总是重新执行)
352
-
353
- Args:
354
- item: 数据集项
355
- existing_results: 已有结果字典
356
- mode: 执行模式 ("full"、"incremental" 或 "hybrid")
357
- """
358
- task_id = item['task_id']
359
- level = item.get('Level', 'Unknown')
360
- question = item['Question']
361
- file_name = item.get('file_name', "")
362
-
363
- # hybrid + validation 模式下:如果记录已存在且成功,只刷新字段顺序,不调用 agent
364
- # 只有 validation 数据集有 is_correct 字段,可以判断是否正确
365
- if mode == "hybrid" and SPLIT_TYPE == "validation" and task_id in existing_results:
366
- existing_record = existing_results[task_id]
367
- if existing_record.get("is_correct", False):
368
- # 已成功,只刷新字段顺序,不调用 agent
369
- # 使用当前的 file_name 更新 attachment_name,确保数据一致性
370
- current_has_file = bool(file_name)
371
- current_attachment_name = file_name if file_name else None
372
- record = build_ordered_record(
373
- task_id=task_id,
374
- question=existing_record.get("question", question),
375
- level=existing_record.get("level", level),
376
- agent_answer=existing_record.get("agent_answer", ""),
377
- duration=existing_record.get("duration", 0),
378
- has_file=current_has_file,
379
- session_id=existing_record.get("session_id"),
380
- attachment_name=current_attachment_name,
381
- ground_truth=existing_record.get("ground_truth", ""),
382
- is_correct=True
383
- )
384
- # 更新已有记录(刷新字段顺序)
385
- update_result_in_file(task_id, record)
386
- return task_id, True, "refreshed"
387
-
388
- # 需要调用 agent 的情况(新记录、错误记录、或非 hybrid 模式)
389
- attachments = []
390
-
391
- # 1. 如果有文件,先执行上传
392
- if file_name:
393
- # 根据 split 类型选择不同的文件夹
394
- folder = "validation" if SPLIT_TYPE == "validation" else "test"
395
- local_file_path = os.path.abspath(os.path.join(DATA_PATH, "2023", folder, file_name))
396
- upload_data = upload_file(local_file_path)
397
- if upload_data:
398
- attachments.append(upload_data)
399
-
400
- # 2. 调用 Agent(传递 task_id 确保会话隔离)
401
- start_time = time.time()
402
- agent_answer, session_id, _ = call_my_agent_safe(question, attachments, task_id=task_id)
403
- duration = time.time() - start_time
404
-
405
- # 3. 构建记录(使用固定顺序)
406
- if SPLIT_TYPE == "validation":
407
- # validation 数据集:添加标准答案和正确性判断
408
- ground_truth = str(item['Final answer']).strip()
409
- clean_agent = str(agent_answer).lower().rstrip('.')
410
- clean_gt = ground_truth.lower().rstrip('.')
411
- is_correct = (clean_agent == clean_gt)
412
- record = build_ordered_record(
413
- task_id=task_id,
414
- question=question,
415
- level=level,
416
- duration=round(duration, 2),
417
- has_file=bool(file_name),
418
- session_id=session_id,
419
- attachment_name=file_name if file_name else None,
420
- agent_answer=agent_answer,
421
- ground_truth=ground_truth,
422
- is_correct=is_correct
423
- )
424
- result_correct = is_correct
425
- else: # test 数据集:没有标准答案
426
- record = build_ordered_record(
427
- task_id=task_id,
428
- question=question,
429
- level=level,
430
- duration=round(duration, 2),
431
- has_file=bool(file_name),
432
- session_id=session_id,
433
- attachment_name=file_name if file_name else None,
434
- agent_answer=agent_answer
435
- )
436
- result_correct = None
437
-
438
- # 4. 更新或追加记录
439
- if task_id in existing_results:
440
- # 更新已有记录
441
- update_result_in_file(task_id, record)
442
- return task_id, result_correct, "updated"
443
- else:
444
- # 追加新记录
445
- with file_lock:
446
- with open(OUTPUT_FILE, "a", encoding="utf-8") as f:
447
- f.write(json.dumps(record, ensure_ascii=False) + "\n")
448
- return task_id, result_correct, "new"
449
-
450
-
451
- def generate_submission():
452
- """
453
- 生成官网提交格式文件
454
- GAIA 提交格式要求:
455
- - 文件格式:JSONL(每行一个 JSON 对象)
456
- - 必需字段:task_id, model_answer
457
- - 编码:UTF-8
458
- - test 数据集需要包含所有 285 个用例的答案
459
- """
460
- if not os.path.exists(OUTPUT_FILE):
461
- print(f"⚠️ 警告:结果文件 {OUTPUT_FILE} 不存在,无法生成提交文件")
462
- return
463
-
464
- # 读取所有结果并按 task_id 排序(确保顺序一致)
465
- results = []
466
- with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
467
- for line in f:
468
- if not line.strip():
469
- continue
470
- try:
471
- data = json.loads(line)
472
- if "task_id" in data and "agent_answer" in data:
473
- results.append(data)
474
- except json.JSONDecodeError:
475
- continue
476
-
477
- if not results:
478
- print(f"⚠️ 警告:结果文件 {OUTPUT_FILE} 中没有有效数据")
479
- return
480
-
481
- # 按 task_id 排序,确保顺序一致
482
- results.sort(key=lambda x: x.get("task_id", ""))
483
-
484
- # 生成提交文件
485
- with open(SUBMISSION_FILE, "w", encoding="utf-8") as f_out:
486
- for data in results:
487
- submission_data = {
488
- "task_id": data["task_id"],
489
- "model_answer": str(data["agent_answer"])
490
- }
491
- f_out.write(json.dumps(submission_data, ensure_ascii=False) + "\n")
492
-
493
- print(f"✅ 提交文件已生成: {SUBMISSION_FILE} (共 {len(results)} 条记录)")
494
-
495
- # test 数据集验证:检查是否包含所有用例
496
- if SPLIT_TYPE == "test":
497
- expected_count = 285
498
- if len(results) < expected_count:
499
- print(f"⚠️ 警告:test 数据集应该有 {expected_count} 个用例,当前只有 {len(results)} 个")
500
- else:
501
- print(f"✅ test 数据集已包含 {len(results)} 个用例,符合提交要求")
502
-
503
-
504
- def get_current_accuracy():
505
- """
506
- 获取当前的整体正确率(仅 validation 数据集)
507
-
508
- Returns:
509
- float or None: 正确率(百分比),如果不是 validation 数据集或文件不存在则返回 None
510
- """
511
- # test 数据集没有标准答案,无法计算正确率
512
- if SPLIT_TYPE != "validation":
513
- return None
514
-
515
- if not os.path.exists(OUTPUT_FILE):
516
- return None
517
-
518
- try:
519
- results = []
520
- with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
521
- for line in f:
522
- if not line.strip():
523
- continue
524
- try:
525
- data = json.loads(line)
526
- results.append(data)
527
- except json.JSONDecodeError:
528
- continue
529
-
530
- if not results:
531
- return None
532
-
533
- total = len(results)
534
- correct = sum(1 for r in results if r.get("is_correct", False))
535
- accuracy = (correct / total * 100) if total > 0 else 0.0
536
- return accuracy
537
- except Exception:
538
- return None
539
-
540
-
541
- def generate_report():
542
- """生成统计成绩单(仅 validation 数据集有标准答案,才生成成绩单)"""
543
- # test 数据集没有标准答案,不生成成绩单
544
- if SPLIT_TYPE != "validation":
545
- return
546
-
547
- if not os.path.exists(OUTPUT_FILE):
548
- return
549
-
550
- results = [json.loads(line) for line in open(OUTPUT_FILE, "r", encoding="utf-8")]
551
- df = pd.DataFrame(results)
552
- total = len(df)
553
- acc = (df['is_correct'].sum() / total) * 100
554
-
555
- print("\n" + "=" * 50)
556
- print(f"测试完成! 总数: {total} | 总准确率: {acc:.2f}%")
557
- print("=" * 50)
558
-
559
-
560
- def run_test_concurrent(num_questions=200, mode="hybrid", split="validation", threads=MAX_CONCURRENT, target_task_id=None):
561
- """
562
- 测试主函数
563
-
564
- Args:
565
- num_questions: 要执行的用例数量
566
- mode: 执行模式
567
- - "full": 全量模式,不管文件是否存在,都删除重新开始
568
- - "incremental": 增量模式,如果文件存在则增量,不存在则全量执行
569
- - "hybrid": 混合模式,第一次时全量(文件不存在),后面就增量(文件存在)
570
- - "error": 错误模式,只重新执行 agent_answer 包含 ERROR 的记录
571
- split: 数据集类型,"validation" 或 "test"
572
- threads: 并发线程数
573
- target_task_id: 可选,指定要运行的 task_id,如果指定则只运行该用例
574
- """
575
- global dataset, OUTPUT_FILE, SUBMISSION_FILE, SPLIT_TYPE
576
-
577
- # 设置全局变量
578
- SPLIT_TYPE = split
579
-
580
- # 根据 split 类型设置输出文件名
581
- if split == "validation":
582
- OUTPUT_FILE = "validation_results.jsonl"
583
- SUBMISSION_FILE = "validation_submission.jsonl"
584
- print("📥 正在检查 GAIA 验证集数据...")
585
- else: # test
586
- OUTPUT_FILE = "test_results.jsonl"
587
- SUBMISSION_FILE = "test_submission.jsonl"
588
- print("📥 正在检查 GAIA 测试集数据...")
589
-
590
- # 1. 检查并下载完整数据集文件(如果需要,包含 validation 和 test 的所有文件)
591
- print("\n【步骤 1/2】检查数据集文件...")
592
- check_and_download_dataset_files()
593
-
594
- # 2. 加载数据集元数据(如果首次下载会显示下载进度)
595
- print(f"\n【步骤 2/2】加载数据集元数据...")
596
- print(f" 数据集: gaia-benchmark/GAIA (2023_all, split={split})")
597
- print(" 提示: 如果是首次下载,请耐心等待,下载进度会显示在下方...")
598
- print(" 如果已下载过,会直接从缓存加载,速度较快")
599
- start_time = time.time()
600
- dataset = load_dataset("gaia-benchmark/GAIA", "2023_all", split=split)
601
- load_duration = time.time() - start_time
602
- print(f"✅ 数据集元数据加载完成!共 {len(dataset)} 条记录,耗时 {load_duration:.2f} 秒\n")
603
-
604
- # 1. 根据模式处理已有结果
605
- file_exists = os.path.exists(OUTPUT_FILE)
606
-
607
- if mode == "full":
608
- # 全量模式:删除旧文件,从头开始
609
- if file_exists:
610
- os.remove(OUTPUT_FILE)
611
- print("🔄 全量模式:已删除旧结果文件,从头开始执行")
612
- existing_results = {}
613
- elif mode == "incremental":
614
- # 增量模式:如果文件存在则增量,不存在则全量执行
615
- if file_exists:
616
- existing_results = load_existing_results()
617
- print(f"📋 增量模式:已加载 {len(existing_results)} 条历史记录")
618
- else:
619
- existing_results = {}
620
- print("📋 增量模式:未找到历史记录,将全量执行")
621
- elif mode == "error":
622
- # 错误模式:只重新执行 agent_answer 包含 ERROR 的记录
623
- if file_exists:
624
- existing_results = load_existing_results()
625
- print(f"📋 错误模式:已加载 {len(existing_results)} 条历史记录,将重新执行包含 ERROR 的记录")
626
- else:
627
- existing_results = {}
628
- print("📋 错误模式:未找到历史记录,无法执行错误重试")
629
- else: # hybrid
630
- # 混合模式:第一次时全量(文件不存在),后面就增量(文件存在)
631
- if file_exists:
632
- existing_results = load_existing_results()
633
- print(f"📋 混合模式:检测到已有文件,进入增量模式(已加载 {len(existing_results)} 条历史记录)")
634
- else:
635
- existing_results = {}
636
- print("📋 混合模式:首次执行,进入全量模式")
637
-
638
- # 2. 筛选需要执行的用例
639
- if target_task_id:
640
- # 如果指定了 task_id,只运行该用例
641
- print(f"🎯 指定运行 task_id: {target_task_id}")
642
- tasks_to_run = []
643
- found = False
644
- for item in dataset:
645
- if item['task_id'] == target_task_id:
646
- tasks_to_run = [item]
647
- found = True
648
- break
649
- if not found:
650
- print(f"❌ 错误: 在 {split} 数据集中未找到 task_id: {target_task_id}")
651
- return
652
- num_to_run = 1
653
- else:
654
- # 正常模式,根据 num_questions 筛选
655
- num_to_run = min(num_questions, len(dataset))
656
- tasks_to_run = dataset.select(range(num_to_run))
657
-
658
- # 统计需要执行的用例
659
- tasks_to_execute = []
660
- refresh_count = 0 # hybrid 模式下只刷新字段顺序的记录数
661
- update_count = 0 # 需要重新调用 agent 的记录数
662
- new_count = 0 # 新记录数
663
- error_count = 0 # error 模式下包含 ERROR 的记录数
664
-
665
- for item in tasks_to_run:
666
- task_id = item['task_id']
667
-
668
- if mode == "error":
669
- # error 模式:只重新执行 agent_answer 包含 ERROR 的记录
670
- if task_id in existing_results:
671
- agent_answer = existing_results[task_id].get("agent_answer", "")
672
- if agent_answer and "ERROR" in str(agent_answer):
673
- error_count += 1
674
- tasks_to_execute.append(item)
675
- # 如果记录不存在或 agent_answer 不包含 ERROR,则跳过
676
- else:
677
- # 其他模式:正常处理
678
- if task_id in existing_results:
679
- # hybrid + validation 模式下:如果已成功,只刷新字段顺序
680
- # test 数据集没有 is_correct 字段,无法判断是否正确,所以总是重新执行
681
- if mode == "hybrid" and split == "validation":
682
- if existing_results[task_id].get("is_correct", False):
683
- refresh_count += 1
684
- else:
685
- update_count += 1
686
- else:
687
- # 非 hybrid 模式,或 test 数据集:所有已有记录都需要重新执行
688
- update_count += 1
689
- else:
690
- new_count += 1
691
- tasks_to_execute.append(item)
692
-
693
- total_to_execute = len(tasks_to_execute)
694
-
695
- print(f"\n📊 统计信息:")
696
- print(f" 数据集: {split}")
697
- print(f" 执行模式: {mode}")
698
- if target_task_id:
699
- print(f" 指定 task_id: {target_task_id}")
700
- print(f" 总用例数: {num_to_run}")
701
- if existing_results:
702
- if mode == "error":
703
- print(f" 需要执行: {total_to_execute} (包含 ERROR 的记录: {error_count})")
704
- elif mode == "hybrid":
705
- print(
706
- f" 需要执行: {total_to_execute} (新用例: {new_count}, 刷新字段顺序: {refresh_count}, 重新测试: {update_count})")
707
- else:
708
- print(
709
- f" 需要执行: {total_to_execute} (新用例: {new_count}, 刷新已有记录: {refresh_count + update_count})")
710
- else:
711
- print(f" 需要执行: {total_to_execute} (全量执行)")
712
- print(f"🚀 开始测试 | 并发数: {threads} | 待执行: {total_to_execute}")
713
-
714
- if total_to_execute == 0:
715
- if mode == "error":
716
- print("✅ 没有包含 ERROR 的记录,无需执行")
717
- elif split == "validation":
718
- print("✅ 所有用例已完成且正确,无需执行")
719
- else:
720
- print("✅ 所有用例已完成,无需执行")
721
- generate_report()
722
- generate_submission()
723
- return
724
-
725
- # 3. 并发执行
726
- with ThreadPoolExecutor(max_workers=threads) as executor:
727
- future_to_item = {executor.submit(process_item, item, existing_results, mode): item for item in
728
- tasks_to_execute}
729
-
730
- done = 0
731
- for future in as_completed(future_to_item):
732
- done += 1
733
- item = future_to_item[future]
734
- tid = item['task_id']
735
- try:
736
- _, is_ok, status = future.result()
737
- if status == "refreshed":
738
- status_icon = "🔄"
739
- elif split == "validation":
740
- status_icon = "✅" if is_ok else "❌"
741
- else: # test
742
- status_icon = "✅"
743
-
744
- # 计算并显示当前整体正确率
745
- accuracy_info = ""
746
- if split == "validation":
747
- current_accuracy = get_current_accuracy()
748
- if current_accuracy is not None:
749
- accuracy_info = f" | 当前正确率: {current_accuracy:.2f}%"
750
-
751
- print(f"[{done}/{total_to_execute}] ID: {tid} | 状态: {status_icon} ({status}){accuracy_info}")
752
- except Exception as e:
753
- error_traceback = traceback.format_exc()
754
- print(f"[{done}/{total_to_execute}] ID: {tid} 运行异常: {e}")
755
- print(f"异常堆栈:\n{error_traceback}")
756
-
757
- # 4. 生成报表
758
- generate_report()
759
- generate_submission()
760
-
761
-
762
- def print_help():
763
- """打印详细的帮助信息"""
764
- print("=" * 70)
765
- print("GAIA 测试脚本 - 参数说明")
766
- print("=" * 70)
767
- print()
768
- print("用法:")
769
- print(" python gaia_test.py [参数]")
770
- print()
771
- print("参数说明:")
772
- print()
773
- print(" --split <类型>")
774
- print(" 数据集类型")
775
- print(" 可选值: validation, test")
776
- print(" 默认值: validation")
777
- print(" 说明:")
778
- print(" - validation: 验证集,有标准答案,可以计算正确率")
779
- print(" - test: 测试集,无标准答案,用于最终提交")
780
- print()
781
- print(" --mode <模式>")
782
- print(" 执行模式")
783
- print(" 可选值: full, incremental, hybrid, error")
784
- print(" 默认值: hybrid")
785
- print(" 说明:")
786
- print(" - full: 全量模式,删除旧结果文件,从头开始执行")
787
- print(" - incremental: 增量模式,如果文件存在则增量,不存在则全量执行")
788
- print(" - hybrid: 混合模式(推荐),首次全量,后续增量")
789
- print(" 在 hybrid 模式下,validation 数据集中已正确的记录")
790
- print(" 只刷新字段顺序,不重新调用 agent")
791
- print(" - error: 错误模式,只重新执行 agent_answer 包含 ERROR 的记录")
792
- print()
793
- print(" --num <数量>")
794
- print(" 要执行的用例数量")
795
- print(" 类型: 整数")
796
- print(" 默认值: 200")
797
- print(" 说明:")
798
- print(" - test 数据集共 285 题,可以设置 --num 285 执行全部")
799
- print(" - validation 数据集可以根据需要设置数量")
800
- print()
801
- print(" --threads <数量>")
802
- print(" 并发执行的线程数")
803
- print(" 类型: 整数")
804
- print(" 默认值: 2")
805
- print(" 说明:")
806
- print(" - 根据服务器性能调整,过高可能导致服务器压力过大")
807
- print(" - 建议范围: 1-4")
808
- print()
809
- print(" --task-id <task_id>")
810
- print(" 指定要运行的 task_id")
811
- print(" 类型: 字符串")
812
- print(" 默认值: 无(运行多个用例)")
813
- print(" 说明:")
814
- print(" - 如果指定此参数,则只运行该 task_id 对应的用例")
815
- print(" - 指定此参数时,--num 参数会被忽略")
816
- print(" - 如果指定的 task_id 不存在,脚本会报错并退出")
817
- print()
818
- print(" -h, --help")
819
- print(" 显示此帮助信息并退出")
820
- print()
821
- print("示例:")
822
- print(" # 使用默认参数(validation 数据集,hybrid 模式,200 题)")
823
- print(" python gaia_test.py")
824
- print()
825
- print(" # 测试 test 数据集,执行全部 285 题")
826
- print(" python gaia_test.py --split test --num 285")
827
- print()
828
- print(" # 使用 error 模式重新执行错误记录")
829
- print(" python gaia_test.py --mode error")
830
- print()
831
- print(" # 使用全量模式,4 个并发线程")
832
- print(" python gaia_test.py --mode full --threads 4")
833
- print()
834
- print(" # 运行指定的 task_id")
835
- print(" python gaia_test.py --task-id c61d22de-5f6c-4958-a7f6-5e9707bd3466")
836
- print()
837
- print("=" * 70)
838
- print("配置文件:")
839
- print(" 运行前请确保已正确配置 gaia_test.py 中的以下参数:")
840
- print(" - BASE_URL: API 服务地址")
841
- print(" - HEADERS: 认证 Token(必须修改)")
842
- print(" - DATA_PATH: 数据文件路径")
843
- print(" - REQUEST_TIMEOUT: 请求超时时间")
844
- print(" - MAX_CONCURRENT: 最大并发数")
845
- print("=" * 70)
846
-
847
-
848
- if __name__ == "__main__":
849
- import sys
850
-
851
- # 检查是否有 -h 或 --help 参数
852
- if "-h" in sys.argv or "--help" in sys.argv:
853
- print_help()
854
- sys.exit(0)
855
-
856
- parser = argparse.ArgumentParser(
857
- description="GAIA 测试脚本(支持 validation 和 test 数据集,支持全量、增量、混合三种模式)",
858
- formatter_class=argparse.RawDescriptionHelpFormatter
859
- )
860
- parser.add_argument(
861
- "--split",
862
- type=str,
863
- choices=["validation", "test"],
864
- default="validation",
865
- help="数据集类型: 'validation' 验证集(有标准答案)、'test' 测试集(无标准答案,默认: validation)"
866
- )
867
- parser.add_argument(
868
- "--mode",
869
- type=str,
870
- choices=["full", "incremental", "hybrid", "error"],
871
- default="hybrid",
872
- help="执行模式: 'full' 全量模式(删除旧文件重新执行)、'incremental' 增量模式(文件存在则增量,不存在则全量)、'hybrid' 混合模式(首次全量,后续增量,默认)、'error' 错误模式(只重新执行 agent_answer 包含 ERROR 的记录)"
873
- )
874
- parser.add_argument(
875
- "--num",
876
- type=int,
877
- default=200,
878
- help="要执行的用例数量(默认: 200,test 集共 285 题)"
879
- )
880
- parser.add_argument(
881
- "--threads",
882
- type=int,
883
- default=MAX_CONCURRENT,
884
- help="执行的并发数(默认: 2)"
885
- )
886
- parser.add_argument(
887
- "--task-id",
888
- type=str,
889
- default=None,
890
- help="指定要运行的 task_id,如果指定则只运行该用例(忽略 --num 参数)"
891
- )
892
-
893
- args = parser.parse_args()
894
-
895
- run_test_concurrent(num_questions=args.num, mode=args.mode, split=args.split, threads=args.threads, target_task_id=args.task_id)