"""
测试脚本:支持 validation 和 test 数据集,支持全量、增量、混合三种模式
- 全量模式(full):不管文件是否存在,都删除重新开始
- 增量模式(incremental):如果文件存在则增量,不存在则全量执行
- 混合模式(hybrid):第一次时全量(文件不存在),后面就增量(文件存在)
"""
import argparse
import json
import os
import re
import time
import traceback
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
import pandas as pd
import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download, hf_hub_download
from pathlib import Path
import shutil
# --- 1. Configuration section ---
BASE_URL = "http://localhost:5173/api/v1"
CHAT_URL = f"{BASE_URL}/sessions/chat"
UPLOAD_URL = f"{BASE_URL}/files"
HEADERS = {"Authorization": "Bearer hawk_YhCZLQYqtPOwOiEyEgeCNdfAFAbrHtTUxQvRiaOInyekgVgE"}
DATA_PATH = "./gaia_data"  # local root directory for the GAIA dataset files
REQUEST_TIMEOUT = 1800  # read timeout (seconds) for chat requests
MAX_CONCURRENT = 2  # default number of worker threads
file_lock = Lock()  # serializes all reads/writes of OUTPUT_FILE across worker threads
# Globals configured at runtime according to the selected split
dataset = None
OUTPUT_FILE = None
SUBMISSION_FILE = None
SPLIT_TYPE = None  # "validation" or "test"
def _copy_split_files(source_root, split_name, target_dir, step):
    """Copy one split's attachment files from the HF snapshot into target_dir.

    Replaces target_dir if it already exists; prints a warning when the split
    directory is missing from the downloaded snapshot.

    Args:
        source_root: Path of the snapshot's "2023" directory.
        split_name: "validation" or "test" (used for the folder name and logs).
        target_dir: destination Path for this split's files.
        step: step number shown in the progress messages.
    """
    print(f" 步骤 {step}/4: 正在复制 {split_name} 文件...")
    source_dir = source_root / split_name
    if source_dir.exists():
        if target_dir.exists():
            shutil.rmtree(target_dir)
        shutil.copytree(source_dir, target_dir)
        count = len(list(target_dir.glob("*")))
        print(f" ✅ {split_name} 文件复制完成,共 {count} 个文件")
    else:
        print(f" ⚠️ 警告: 未找到 {split_name} 目录")
def check_and_download_dataset_files():
    """
    Ensure the full GAIA dataset files exist locally (both the validation and
    test splits), downloading them from Hugging Face when missing.

    Returns:
        bool: True when the files already exist or were downloaded
        successfully, False otherwise.
    """
    base_target_dir = Path(DATA_PATH) / "2023"
    validation_dir = base_target_dir / "validation"
    test_dir = base_target_dir / "test"
    # Fast path: both split directories already exist and are non-empty.
    validation_files = list(validation_dir.glob("*")) if validation_dir.exists() else []
    test_files = list(test_dir.glob("*")) if test_dir.exists() else []
    if validation_files and test_files:
        print(f"✅ 检测到数据集文件已存在")
        print(f" validation 文件数: {len(validation_files)}")
        print(f" test 文件数: {len(test_files)}")
        return True
    # Dataset files are missing: download the full snapshot.
    print(f"📥 开始下载完整的 GAIA 数据集文件...")
    print(f" 目标目录: {base_target_dir}")
    try:
        base_target_dir.mkdir(parents=True, exist_ok=True)
        print(" 步骤 1/4: 正在从 Hugging Face 下载完整数据集...")
        print(" 提示: 下载进度会显示在下方,请耐心等待...")
        download_start = time.time()
        # NOTE: resume_download is deprecated in recent huggingface_hub
        # releases (resuming is always enabled), so it is no longer passed.
        cache_dir = snapshot_download(
            repo_id="gaia-benchmark/GAIA",
            repo_type="dataset",
            local_dir=None,  # use the default HF cache directory
        )
        download_duration = time.time() - download_start
        print(f" ✅ 数据集下载完成,耗时 {download_duration:.2f} 秒")
        source_2023_dir = Path(cache_dir) / "2023"
        if not source_2023_dir.exists():
            print(f" ❌ 错误: 缓存目录中未找到 2023 目录")
            return False
        # Copy both splits out of the snapshot (deduplicated into a helper).
        _copy_split_files(source_2023_dir, "validation", validation_dir, step=2)
        _copy_split_files(source_2023_dir, "test", test_dir, step=3)
        print(" 步骤 4/4: 数据集文件准备完成!")
        print(f" 目标目录: {base_target_dir}")
        return True
    except Exception as e:
        print(f" ❌ 下载数据集文件时出错: {e}")
        # traceback is already imported at module level; the original's
        # redundant local "import traceback" was removed.
        traceback.print_exc()
        return False
def build_ordered_record(task_id, question, level, agent_answer, duration, has_file,
                         session_id=None, attachment_name=None, ground_truth=None, is_correct=None):
    """Assemble a result record whose keys always appear in a fixed order.

    Args:
        task_id: task identifier.
        question: the question text.
        level: difficulty level.
        agent_answer: the agent's parsed answer.
        duration: wall-clock execution time.
        has_file: whether the task carries an attachment.
        session_id: chat session id; omitted from the record when falsy.
        attachment_name: attachment file name; omitted when empty or blank.
        ground_truth: reference answer (validation split only).
        is_correct: correctness flag (validation split only).

    Returns:
        OrderedDict: the record, keys in a stable, fixed order.
    """
    rec = OrderedDict(
        task_id=task_id,
        question=question,
        level=level,
        duration=duration,
        has_file=has_file,
    )
    # The attachment name is recorded whenever it is non-blank, even if the
    # agent call failed; a missing or whitespace-only name is omitted.
    if attachment_name and attachment_name.strip():
        rec["attachment_name"] = attachment_name
    # The session id may legitimately be absent (e.g. the agent errored out).
    if session_id:
        rec["session_id"] = session_id
    rec["agent_answer"] = agent_answer
    # The fields below only exist for the validation split.
    if ground_truth is not None:
        rec["ground_truth"] = ground_truth
    if is_correct is not None:
        rec["is_correct"] = is_correct
    return rec
def load_existing_results():
    """Load previously recorded results from OUTPUT_FILE.

    Returns:
        dict: mapping of task_id -> full record dict; empty when the file is
        missing or unreadable. Blank and non-JSON lines are skipped.
    """
    if not os.path.exists(OUTPUT_FILE):
        return {}
    loaded = {}
    try:
        with open(OUTPUT_FILE, "r", encoding="utf-8") as fh:
            for raw in fh:
                if not raw.strip():
                    continue
                try:
                    record = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                tid = record.get("task_id")
                if tid:
                    loaded[tid] = record
        print(f"✅ 已加载 {len(loaded)} 条历史记录")
    except Exception as e:
        print(f"⚠️ 加载历史记录时出错: {e}")
        return {}
    return loaded
def update_result_in_file(task_id, new_record):
    """Update (or append) the record for task_id in the OUTPUT_FILE jsonl.

    The file is rewritten through a temporary file and atomically swapped in
    with os.replace, so a crash mid-write cannot corrupt the results file.

    Args:
        task_id: id of the record to replace.
        new_record: the full replacement record (a JSON-serializable mapping).
    """
    # Bug fix: hold file_lock for the *entire* operation, including the
    # existence check. The original checked os.path.exists outside the lock,
    # so two threads could both observe a missing file and both open it with
    # mode "w", one truncating the other's first record.
    with file_lock:
        if not os.path.exists(OUTPUT_FILE):
            # File does not exist yet: just write the single record.
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                f.write(json.dumps(new_record, ensure_ascii=False) + "\n")
            return
        temp_file = OUTPUT_FILE + ".tmp"
        updated = False
        try:
            with open(OUTPUT_FILE, "r", encoding="utf-8") as f_in, \
                    open(temp_file, "w", encoding="utf-8") as f_out:
                for line in f_in:
                    if not line.strip():
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        continue  # drop unreadable lines, matching the loader
                    if data.get("task_id") == task_id:
                        # Replace this record with the new one.
                        f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
                        updated = True
                    else:
                        # Keep the original line untouched.
                        f_out.write(line)
                if not updated:
                    # Record was not present yet: append it.
                    f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n")
            # Atomically swap the rewritten file over the original.
            os.replace(temp_file, OUTPUT_FILE)
        except Exception:
            # Clean up the temp file, then re-raise with the original
            # traceback (bare raise instead of "raise e").
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise
def upload_file(local_path):
    """Upload a local file and return the descriptor the chat API expects.

    Returns:
        dict | None: {"file_id": ..., "filename": ...} on success, otherwise
        None (missing file, HTTP error, or an API-level error code).
    """
    try:
        if not os.path.exists(local_path):
            print(f"❌ 本地文件不存在: {local_path}")
            return None
        with open(local_path, 'rb') as fh:
            response = requests.post(UPLOAD_URL, headers=HEADERS, files={'file': fh}, timeout=60)
        response.raise_for_status()
        res_data = response.json()
        if res_data.get("code") != 0:
            # API returned a business-level error.
            print(f"❌ 上传接口返回错误: {res_data.get('msg')}")
            return None
        info = res_data.get("data", {})
        return {
            "file_id": info.get("file_id"),
            "filename": info.get("filename")
        }
    except Exception as e:
        print(f"❌ 文件上传异常 ({os.path.basename(local_path)}): {e}")
        return None
def extract_answer(text):
    """Extract the final answer from the agent's raw reply.

    Tries, in order:
      1. the content of an <answer>...</answer> tag pair,
      2. a trailing "answer is: ..." phrase (ASCII or full-width colon),
      3. the last non-empty line of the text.

    Args:
        text: raw reply text (may be None or empty).

    Returns:
        str: the extracted answer, or "" for empty input.
    """
    if not text:
        return ""
    # Bug fix: the closing-tag part of the pattern previously lacked "</",
    # so a reply like "<answer>42</answer>" was captured as "42</". Match a
    # proper closing tag, case-insensitively and whitespace-tolerantly.
    pattern = r"(?si)<\s*answer\s*>\s*(.*?)\s*<\s*/\s*answer\s*>"
    match = re.search(pattern, text)
    if match:
        ans = match.group(1).strip()
        # Drop a single surrounding quote character, if present.
        return re.sub(r'^["\']|["\']$', '', ans)
    backup_pattern = r"(?i)answer\s*is[::]\s*(.*)"
    backup_match = re.search(backup_pattern, text)
    if backup_match:
        return backup_match.group(1).strip().rstrip('.')
    # Last resort: the final non-empty line of the reply.
    lines = [l.strip() for l in text.strip().split('\n') if l.strip()]
    return lines[-1] if lines else text.strip()
def call_my_agent_safe(question, attachments=None, task_id=None):
    """Send one chat request to the agent, including any attachments.

    Args:
        question: the question text.
        attachments: list of uploaded-file descriptors, or None.
        task_id: task id used to keep chat sessions isolated per request.

    Returns:
        tuple: (parsed_answer, session_id, raw_content). On failure the first
        element is an "ERROR: ..." string, session_id is None (or whatever
        was obtained before the failure) and raw_content holds the formatted
        traceback.
    """
    # Bug fix: the instruction previously showed no tags at all ("inside the
    # tags in the following format: your answer"), leaving the agent with no
    # way to know which tags to emit. Spell out <answer>...</answer>
    # explicitly so replies match what extract_answer() parses.
    guided_prompt = (
        f"{question}\n\n Important Requirement: \nprovide the final answer "
        "(the answer only, without explanation) inside the <answer></answer> "
        "tags in the following format: <answer>your answer</answer>"
    )
    payload = {
        "message": guided_prompt,
        "streaming": False,
        "attachments": attachments if attachments else [],
        "recycle_sandbox": True,
    }
    request_headers = HEADERS.copy()
    if task_id:
        # Tag both the payload and the headers with the task id so the
        # backend can distinguish concurrent requests and isolate sessions.
        payload["task_id"] = task_id
        request_headers["X-Task-ID"] = task_id
    # Bug fix: session_id must be bound before the try block — the original
    # referenced it in the except handler, raising UnboundLocalError whenever
    # requests.post itself failed.
    session_id = None
    try:
        response = requests.post(CHAT_URL, headers=request_headers, json=payload,
                                 timeout=(30, REQUEST_TIMEOUT))
        response.raise_for_status()
        res_data = response.json()
        raw_content = (res_data.get("answer") or res_data.get("content")
                       or res_data.get("response") or "").strip()
        session_id = res_data.get("session_id")
        parsed_answer = extract_answer(raw_content)
        return parsed_answer, session_id, raw_content
    except Exception as e:
        return f"ERROR: {str(e)}", session_id, traceback.format_exc()
def process_item(item, existing_results, mode):
    """Process a single dataset item: upload attachment -> call agent -> record result.

    In hybrid mode on the validation split, an existing record whose
    is_correct flag is True skips the agent call entirely and is merely
    rewritten to refresh its field order. In every other case the agent is
    re-invoked (the test split has no is_correct field, so its records are
    always re-executed).

    Args:
        item: one dataset row.
        existing_results: mapping task_id -> previously recorded result.
        mode: execution mode ("full", "incremental", "hybrid" or "error").

    Returns:
        tuple: (task_id, is_correct_or_None, status) where status is one of
        "refreshed", "updated" or "new".
    """
    task_id = item['task_id']
    level = item.get('Level', 'Unknown')
    question = item['Question']
    file_name = item.get('file_name', "")
    # Hybrid + validation: a previously-correct record only gets its field
    # order refreshed; no agent call is made.
    if mode == "hybrid" and SPLIT_TYPE == "validation" and task_id in existing_results:
        existing_record = existing_results[task_id]
        if existing_record.get("is_correct", False):
            # Recompute the attachment fields from the current dataset row to
            # keep them consistent with the source data.
            current_has_file = bool(file_name)
            current_attachment_name = file_name if file_name else None
            record = build_ordered_record(
                task_id=task_id,
                question=existing_record.get("question", question),
                level=existing_record.get("level", level),
                agent_answer=existing_record.get("agent_answer", ""),
                duration=existing_record.get("duration", 0),
                has_file=current_has_file,
                session_id=existing_record.get("session_id"),
                attachment_name=current_attachment_name,
                ground_truth=existing_record.get("ground_truth", ""),
                is_correct=True
            )
            # Rewrite the stored record so its field order matches the
            # current schema.
            update_result_in_file(task_id, record)
            return task_id, True, "refreshed"
    # From here on the agent must actually be called (new record, failed
    # record, or a non-hybrid mode).
    attachments = []
    # 1. Upload the attachment first, when the task has one.
    if file_name:
        # Attachment files live under a per-split folder.
        folder = "validation" if SPLIT_TYPE == "validation" else "test"
        local_file_path = os.path.abspath(os.path.join(DATA_PATH, "2023", folder, file_name))
        upload_data = upload_file(local_file_path)
        if upload_data:
            attachments.append(upload_data)
    # 2. Call the agent (task_id keeps the chat session isolated).
    start_time = time.time()
    agent_answer, session_id, _ = call_my_agent_safe(question, attachments, task_id=task_id)
    duration = time.time() - start_time
    # 3. Build the result record with a fixed field order.
    if SPLIT_TYPE == "validation":
        # Validation split: compare against the reference answer
        # (case-insensitive, ignoring a trailing period).
        ground_truth = str(item['Final answer']).strip()
        clean_agent = str(agent_answer).lower().rstrip('.')
        clean_gt = ground_truth.lower().rstrip('.')
        is_correct = (clean_agent == clean_gt)
        record = build_ordered_record(
            task_id=task_id,
            question=question,
            level=level,
            duration=round(duration, 2),
            has_file=bool(file_name),
            session_id=session_id,
            attachment_name=file_name if file_name else None,
            agent_answer=agent_answer,
            ground_truth=ground_truth,
            is_correct=is_correct
        )
        result_correct = is_correct
    else:  # test split: no reference answer available
        record = build_ordered_record(
            task_id=task_id,
            question=question,
            level=level,
            duration=round(duration, 2),
            has_file=bool(file_name),
            session_id=session_id,
            attachment_name=file_name if file_name else None,
            agent_answer=agent_answer
        )
        result_correct = None
    # 4. Persist: overwrite an existing record or append a new one.
    if task_id in existing_results:
        # Overwrite the existing record in place.
        update_result_in_file(task_id, record)
        return task_id, result_correct, "updated"
    else:
        # Append a brand-new record under the shared file lock.
        with file_lock:
            with open(OUTPUT_FILE, "a", encoding="utf-8") as f:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
        return task_id, result_correct, "new"
def generate_submission():
    """Write the official-submission file from the recorded results.

    GAIA submission format:
      - JSONL (one JSON object per line), UTF-8 encoded
      - required fields: task_id, model_answer
      - the test split must cover all 285 tasks
    """
    if not os.path.exists(OUTPUT_FILE):
        print(f"⚠️ 警告:结果文件 {OUTPUT_FILE} 不存在,无法生成提交文件")
        return
    # Collect every valid result line (skip blanks and unparseable lines).
    entries = []
    with open(OUTPUT_FILE, "r", encoding="utf-8") as fh:
        for raw in fh:
            if not raw.strip():
                continue
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if "task_id" in parsed and "agent_answer" in parsed:
                entries.append(parsed)
    if not entries:
        print(f"⚠️ 警告:结果文件 {OUTPUT_FILE} 中没有有效数据")
        return
    # Stable ordering by task_id.
    entries.sort(key=lambda rec: rec.get("task_id", ""))
    with open(SUBMISSION_FILE, "w", encoding="utf-8") as out:
        out.writelines(
            json.dumps({"task_id": rec["task_id"], "model_answer": str(rec["agent_answer"])},
                       ensure_ascii=False) + "\n"
            for rec in entries
        )
    print(f"✅ 提交文件已生成: {SUBMISSION_FILE} (共 {len(entries)} 条记录)")
    # Sanity check for the test split: every task must be covered.
    if SPLIT_TYPE == "test":
        expected_count = 285
        if len(entries) < expected_count:
            print(f"⚠️ 警告:test 数据集应该有 {expected_count} 个用例,当前只有 {len(entries)} 个")
        else:
            print(f"✅ test 数据集已包含 {len(entries)} 个用例,符合提交要求")
def get_current_accuracy():
    """Return the running overall accuracy (percent), validation split only.

    Returns:
        float | None: accuracy percentage; None when the split has no ground
        truth, or the results file is missing, unreadable, or empty.
    """
    if SPLIT_TYPE != "validation":
        # The test split has no reference answers to score against.
        return None
    if not os.path.exists(OUTPUT_FILE):
        return None
    try:
        records = []
        with open(OUTPUT_FILE, "r", encoding="utf-8") as fh:
            for raw in fh:
                if not raw.strip():
                    continue
                try:
                    records.append(json.loads(raw))
                except json.JSONDecodeError:
                    continue
        if not records:
            return None
        correct = sum(1 for rec in records if rec.get("is_correct", False))
        return correct / len(records) * 100
    except Exception:
        return None
def generate_report():
    """Print a summary scorecard (validation split only — it has ground truth)."""
    if SPLIT_TYPE != "validation":
        return
    if not os.path.exists(OUTPUT_FILE):
        return
    # Bug fixes vs. the original: the file handle is now closed (context
    # manager instead of a bare open() in a list comprehension), blank or
    # corrupt lines are skipped instead of raising JSONDecodeError, and an
    # empty result set no longer divides by zero.
    results = []
    with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                results.append(json.loads(line))
            except json.JSONDecodeError:
                continue
    total = len(results)
    if total == 0:
        return
    df = pd.DataFrame(results)
    # Records written before a ground truth existed may lack the column.
    correct = df['is_correct'].sum() if 'is_correct' in df.columns else 0
    acc = (correct / total) * 100
    print("\n" + "=" * 50)
    print(f"测试完成! 总数: {total} | 总准确率: {acc:.2f}%")
    print("=" * 50)
def run_test_concurrent(num_questions=200, mode="hybrid", split="validation", threads=MAX_CONCURRENT, target_task_id=None):
    """Main test driver.

    Args:
        num_questions: number of cases to execute.
        mode: execution mode:
            - "full": delete any existing results file and start over.
            - "incremental": resume when the file exists, otherwise run all.
            - "hybrid": full on the first run (no file), incremental afterwards.
            - "error": re-run only records whose agent_answer contains ERROR.
        split: dataset split, "validation" or "test".
        threads: number of concurrent worker threads.
        target_task_id: optional; when given, only that single case is run.
    """
    global dataset, OUTPUT_FILE, SUBMISSION_FILE, SPLIT_TYPE
    # Configure the module-level globals for this run.
    SPLIT_TYPE = split
    # Output file names depend on the selected split.
    if split == "validation":
        OUTPUT_FILE = "validation_results.jsonl"
        SUBMISSION_FILE = "validation_submission.jsonl"
        print("📥 正在检查 GAIA 验证集数据...")
    else:  # test
        OUTPUT_FILE = "test_results.jsonl"
        SUBMISSION_FILE = "test_submission.jsonl"
        print("📥 正在检查 GAIA 测试集数据...")
    # Step 1: ensure the dataset attachment files are available locally
    # (covers both the validation and test splits).
    print("\n【步骤 1/2】检查数据集文件...")
    check_and_download_dataset_files()
    # Step 2: load the dataset metadata (shows download progress on first use).
    print(f"\n【步骤 2/2】加载数据集元数据...")
    print(f" 数据集: gaia-benchmark/GAIA (2023_all, split={split})")
    print(" 提示: 如果是首次下载,请耐心等待,下载进度会显示在下方...")
    print(" 如果已下载过,会直接从缓存加载,速度较快")
    start_time = time.time()
    dataset = load_dataset("gaia-benchmark/GAIA", "2023_all", split=split)
    load_duration = time.time() - start_time
    print(f"✅ 数据集元数据加载完成!共 {len(dataset)} 条记录,耗时 {load_duration:.2f} 秒\n")
    # 1. Handle any existing results file according to the chosen mode.
    file_exists = os.path.exists(OUTPUT_FILE)
    if mode == "full":
        # Full mode: delete the old file and start from scratch.
        if file_exists:
            os.remove(OUTPUT_FILE)
            print("🔄 全量模式:已删除旧结果文件,从头开始执行")
        existing_results = {}
    elif mode == "incremental":
        # Incremental mode: resume when the file exists, else run everything.
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 增量模式:已加载 {len(existing_results)} 条历史记录")
        else:
            existing_results = {}
            print("📋 增量模式:未找到历史记录,将全量执行")
    elif mode == "error":
        # Error mode: only re-run records whose agent_answer contains ERROR.
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 错误模式:已加载 {len(existing_results)} 条历史记录,将重新执行包含 ERROR 的记录")
        else:
            existing_results = {}
            print("📋 错误模式:未找到历史记录,无法执行错误重试")
    else:  # hybrid
        # Hybrid mode: full on the first run, incremental afterwards.
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 混合模式:检测到已有文件,进入增量模式(已加载 {len(existing_results)} 条历史记录)")
        else:
            existing_results = {}
            print("📋 混合模式:首次执行,进入全量模式")
    # 2. Select the cases to run.
    if target_task_id:
        # A specific task_id was requested; run only that case.
        print(f"🎯 指定运行 task_id: {target_task_id}")
        tasks_to_run = []
        found = False
        for item in dataset:
            if item['task_id'] == target_task_id:
                tasks_to_run = [item]
                found = True
                break
        if not found:
            print(f"❌ 错误: 在 {split} 数据集中未找到 task_id: {target_task_id}")
            return
        num_to_run = 1
    else:
        # Normal path: take the first num_questions cases.
        num_to_run = min(num_questions, len(dataset))
        tasks_to_run = dataset.select(range(num_to_run))
    # Classify the selected cases.
    tasks_to_execute = []
    refresh_count = 0  # hybrid: records that only need a field-order refresh
    update_count = 0   # records that need a fresh agent call
    new_count = 0      # records not seen before
    error_count = 0    # error mode: records whose answer contains ERROR
    for item in tasks_to_run:
        task_id = item['task_id']
        if mode == "error":
            # Error mode: only records whose agent_answer contains ERROR.
            if task_id in existing_results:
                agent_answer = existing_results[task_id].get("agent_answer", "")
                if agent_answer and "ERROR" in str(agent_answer):
                    error_count += 1
                    tasks_to_execute.append(item)
            # Records that are missing, or whose answer has no ERROR, skip.
        else:
            # All other modes.
            if task_id in existing_results:
                # Hybrid + validation: already-correct records are merely
                # refreshed. The test split has no is_correct field, so its
                # records are always re-run.
                if mode == "hybrid" and split == "validation":
                    if existing_results[task_id].get("is_correct", False):
                        refresh_count += 1
                    else:
                        update_count += 1
                else:
                    # Non-hybrid modes (or the test split): re-run all
                    # previously recorded cases.
                    update_count += 1
            else:
                new_count += 1
            tasks_to_execute.append(item)
    total_to_execute = len(tasks_to_execute)
    print(f"\n📊 统计信息:")
    print(f" 数据集: {split}")
    print(f" 执行模式: {mode}")
    if target_task_id:
        print(f" 指定 task_id: {target_task_id}")
    print(f" 总用例数: {num_to_run}")
    if existing_results:
        if mode == "error":
            print(f" 需要执行: {total_to_execute} (包含 ERROR 的记录: {error_count})")
        elif mode == "hybrid":
            print(
                f" 需要执行: {total_to_execute} (新用例: {new_count}, 刷新字段顺序: {refresh_count}, 重新测试: {update_count})")
        else:
            print(
                f" 需要执行: {total_to_execute} (新用例: {new_count}, 刷新已有记录: {refresh_count + update_count})")
    else:
        print(f" 需要执行: {total_to_execute} (全量执行)")
    print(f"🚀 开始测试 | 并发数: {threads} | 待执行: {total_to_execute}")
    if total_to_execute == 0:
        if mode == "error":
            print("✅ 没有包含 ERROR 的记录,无需执行")
        elif split == "validation":
            print("✅ 所有用例已完成且正确,无需执行")
        else:
            print("✅ 所有用例已完成,无需执行")
        generate_report()
        generate_submission()
        return
    # 3. Run the selected cases concurrently.
    with ThreadPoolExecutor(max_workers=threads) as executor:
        future_to_item = {executor.submit(process_item, item, existing_results, mode): item for item in
                          tasks_to_execute}
        done = 0
        for future in as_completed(future_to_item):
            done += 1
            item = future_to_item[future]
            tid = item['task_id']
            try:
                _, is_ok, status = future.result()
                if status == "refreshed":
                    status_icon = "🔄"
                elif split == "validation":
                    status_icon = "✅" if is_ok else "❌"
                else:  # test
                    status_icon = "✅"
                # Compute and show the current overall accuracy.
                accuracy_info = ""
                if split == "validation":
                    current_accuracy = get_current_accuracy()
                    if current_accuracy is not None:
                        accuracy_info = f" | 当前正确率: {current_accuracy:.2f}%"
                print(f"[{done}/{total_to_execute}] ID: {tid} | 状态: {status_icon} ({status}){accuracy_info}")
            except Exception as e:
                error_traceback = traceback.format_exc()
                print(f"[{done}/{total_to_execute}] ID: {tid} 运行异常: {e}")
                print(f"异常堆栈:\n{error_traceback}")
    # 4. Generate the final report and submission files.
    generate_report()
    generate_submission()
def print_help():
    """Print the detailed, custom usage/help text (shown for -h/--help)."""
    print("=" * 70)
    print("GAIA 测试脚本 - 参数说明")
    print("=" * 70)
    print()
    print("用法:")
    print(" python gaia_test.py [参数]")
    print()
    print("参数说明:")
    print()
    print(" --split <类型>")
    print(" 数据集类型")
    print(" 可选值: validation, test")
    print(" 默认值: validation")
    print(" 说明:")
    print(" - validation: 验证集,有标准答案,可以计算正确率")
    print(" - test: 测试集,无标准答案,用于最终提交")
    print()
    print(" --mode <模式>")
    print(" 执行模式")
    print(" 可选值: full, incremental, hybrid, error")
    print(" 默认值: hybrid")
    print(" 说明:")
    print(" - full: 全量模式,删除旧结果文件,从头开始执行")
    print(" - incremental: 增量模式,如果文件存在则增量,不存在则全量执行")
    print(" - hybrid: 混合模式(推荐),首次全量,后续增量")
    print(" 在 hybrid 模式下,validation 数据集中已正确的记录")
    print(" 只刷新字段顺序,不重新调用 agent")
    print(" - error: 错误模式,只重新执行 agent_answer 包含 ERROR 的记录")
    print()
    print(" --num <数量>")
    print(" 要执行的用例数量")
    print(" 类型: 整数")
    print(" 默认值: 200")
    print(" 说明:")
    print(" - test 数据集共 285 题,可以设置 --num 285 执行全部")
    print(" - validation 数据集可以根据需要设置数量")
    print()
    print(" --threads <数量>")
    print(" 并发执行的线程数")
    print(" 类型: 整数")
    print(" 默认值: 2")
    print(" 说明:")
    print(" - 根据服务器性能调整,过高可能导致服务器压力过大")
    print(" - 建议范围: 1-4")
    print()
    print(" --task-id ")
    print(" 指定要运行的 task_id")
    print(" 类型: 字符串")
    print(" 默认值: 无(运行多个用例)")
    print(" 说明:")
    print(" - 如果指定此参数,则只运行该 task_id 对应的用例")
    print(" - 指定此参数时,--num 参数会被忽略")
    print(" - 如果指定的 task_id 不存在,脚本会报错并退出")
    print()
    print(" -h, --help")
    print(" 显示此帮助信息并退出")
    print()
    print("示例:")
    print(" # 使用默认参数(validation 数据集,hybrid 模式,200 题)")
    print(" python gaia_test.py")
    print()
    print(" # 测试 test 数据集,执行全部 285 题")
    print(" python gaia_test.py --split test --num 285")
    print()
    print(" # 使用 error 模式重新执行错误记录")
    print(" python gaia_test.py --mode error")
    print()
    print(" # 使用全量模式,4 个并发线程")
    print(" python gaia_test.py --mode full --threads 4")
    print()
    print(" # 运行指定的 task_id")
    print(" python gaia_test.py --task-id c61d22de-5f6c-4958-a7f6-5e9707bd3466")
    print()
    print("=" * 70)
    print("配置文件:")
    print(" 运行前请确保已正确配置 gaia_test.py 中的以下参数:")
    print(" - BASE_URL: API 服务地址")
    print(" - HEADERS: 认证 Token(必须修改)")
    print(" - DATA_PATH: 数据文件路径")
    print(" - REQUEST_TIMEOUT: 请求超时时间")
    print(" - MAX_CONCURRENT: 最大并发数")
    print("=" * 70)
if __name__ == "__main__":
    import sys
    # Intercept -h/--help before argparse does, so the custom help text is
    # shown instead of argparse's auto-generated one.
    if "-h" in sys.argv or "--help" in sys.argv:
        print_help()
        sys.exit(0)
    parser = argparse.ArgumentParser(
        description="GAIA 测试脚本(支持 validation 和 test 数据集,支持全量、增量、混合三种模式)",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["validation", "test"],
        default="validation",
        help="数据集类型: 'validation' 验证集(有标准答案)、'test' 测试集(无标准答案,默认: validation)"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["full", "incremental", "hybrid", "error"],
        default="hybrid",
        help="执行模式: 'full' 全量模式(删除旧文件重新执行)、'incremental' 增量模式(文件存在则增量,不存在则全量)、'hybrid' 混合模式(首次全量,后续增量,默认)、'error' 错误模式(只重新执行 agent_answer 包含 ERROR 的记录)"
    )
    parser.add_argument(
        "--num",
        type=int,
        default=200,
        help="要执行的用例数量(默认: 200,test 集共 285 题)"
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=MAX_CONCURRENT,
        help="执行的并发数(默认: 2)"
    )
    parser.add_argument(
        "--task-id",
        type=str,
        default=None,
        help="指定要运行的 task_id,如果指定则只运行该用例(忽略 --num 参数)"
    )
    args = parser.parse_args()
    run_test_concurrent(num_questions=args.num, mode=args.mode, split=args.split, threads=args.threads, target_task_id=args.task_id)