""" 测试脚本:支持 validation 和 test 数据集,支持全量、增量、混合三种模式 - 全量模式(full):不管文件是否存在,都删除重新开始 - 增量模式(incremental):如果文件存在则增量,不存在则全量执行 - 混合模式(hybrid):第一次时全量(文件不存在),后面就增量(文件存在) """ import argparse import json import os import re import time import traceback from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor, as_completed from threading import Lock import pandas as pd import requests from datasets import load_dataset from huggingface_hub import snapshot_download, hf_hub_download from pathlib import Path import shutil # --- 1. 配置区 --- BASE_URL = "http://localhost:5173/api/v1" CHAT_URL = f"{BASE_URL}/sessions/chat" UPLOAD_URL = f"{BASE_URL}/files" HEADERS = {"Authorization": "Bearer hawk_YhCZLQYqtPOwOiEyEgeCNdfAFAbrHtTUxQvRiaOInyekgVgE"} DATA_PATH = "./gaia_data" REQUEST_TIMEOUT = 1800 MAX_CONCURRENT = 2 file_lock = Lock() # 全局变量,根据 split 类型动态设置 dataset = None OUTPUT_FILE = None SUBMISSION_FILE = None SPLIT_TYPE = None # "validation" 或 "test" def check_and_download_dataset_files(): """ 检查并下载完整的 GAIA 数据集文件到本地目录(包含 validation 和 test 的所有文件) Returns: bool: 如果文件已存在或下载成功返回 True,否则返回 False """ base_target_dir = Path(DATA_PATH) / "2023" validation_dir = base_target_dir / "validation" test_dir = base_target_dir / "test" # 检查两个目录是否都存在且有文件 validation_files = list(validation_dir.glob("*")) if validation_dir.exists() else [] test_files = list(test_dir.glob("*")) if test_dir.exists() else [] if validation_files and test_files: print(f"✅ 检测到数据集文件已存在") print(f" validation 文件数: {len(validation_files)}") print(f" test 文件数: {len(test_files)}") return True # 需要下载数据集文件 print(f"📥 开始下载完整的 GAIA 数据集文件...") print(f" 目标目录: {base_target_dir}") try: # 创建基础目录 base_target_dir.mkdir(parents=True, exist_ok=True) # 下载完整数据集到临时目录,然后复制到目标目录 print(" 步骤 1/4: 正在从 Hugging Face 下载完整数据集...") print(" 提示: 下载进度会显示在下方,请耐心等待...") download_start = time.time() # 使用 snapshot_download 下载完整数据集 cache_dir = snapshot_download( repo_id="gaia-benchmark/GAIA", repo_type="dataset", local_dir=None, # 使用默认缓存目录 
resume_download=True ) download_duration = time.time() - download_start print(f" ✅ 数据集下载完成,耗时 {download_duration:.2f} 秒") cache_path = Path(cache_dir) source_2023_dir = cache_path / "2023" if not source_2023_dir.exists(): print(f" ❌ 错误: 缓存目录中未找到 2023 目录") return False # 复制 validation 和 test 目录 print(" 步骤 2/4: 正在复制 validation 文件...") validation_source = source_2023_dir / "validation" if validation_source.exists(): if validation_dir.exists(): shutil.rmtree(validation_dir) shutil.copytree(validation_source, validation_dir) validation_count = len(list(validation_dir.glob("*"))) print(f" ✅ validation 文件复制完成,共 {validation_count} 个文件") else: print(f" ⚠️ 警告: 未找到 validation 目录") print(" 步骤 3/4: 正在复制 test 文件...") test_source = source_2023_dir / "test" if test_source.exists(): if test_dir.exists(): shutil.rmtree(test_dir) shutil.copytree(test_source, test_dir) test_count = len(list(test_dir.glob("*"))) print(f" ✅ test 文件复制完成,共 {test_count} 个文件") else: print(f" ⚠️ 警告: 未找到 test 目录") print(" 步骤 4/4: 数据集文件准备完成!") print(f" 目标目录: {base_target_dir}") return True except Exception as e: print(f" ❌ 下载数据集文件时出错: {e}") import traceback traceback.print_exc() return False def build_ordered_record(task_id, question, level, agent_answer, duration, has_file, session_id=None, attachment_name=None, ground_truth=None, is_correct=None): """ 按照固定顺序构建记录字典,确保字段顺序一致 Args: task_id: 任务ID question: 问题 level: 难度级别 agent_answer: Agent答案 duration: 执行时长 has_file: 是否有文件 session_id: 会话 ID attachment_name: 附件名称(如果有附件) ground_truth: 标准答案(仅validation数据集) is_correct: 是否正确(仅validation数据集) Returns: OrderedDict: 按固定顺序排列的记录 """ record = OrderedDict() record["task_id"] = task_id record["question"] = question record["level"] = level record["duration"] = duration record["has_file"] = has_file # attachment_name: 如果有值就写入(即使 agent 出错也应该写入) # 只要 attachment_name 不是 None 且不是空字符串,就写入 if attachment_name and attachment_name.strip(): record["attachment_name"] = attachment_name # session_id: 如果有值就写入(如果 agent 出错可能为 None,不写入是合理的) if 
session_id: record["session_id"] = session_id record["agent_answer"] = agent_answer # validation 数据集特有字段 if ground_truth is not None: record["ground_truth"] = ground_truth if is_correct is not None: record["is_correct"] = is_correct return record def load_existing_results(): """ 加载已有的测试结果文件 返回: dict, task_id -> 完整记录字典 """ if not os.path.exists(OUTPUT_FILE): return {} results = {} try: with open(OUTPUT_FILE, "r", encoding="utf-8") as f: for line in f: if not line.strip(): continue try: data = json.loads(line) task_id = data.get("task_id") if task_id: results[task_id] = data except json.JSONDecodeError: continue print(f"✅ 已加载 {len(results)} 条历史记录") except Exception as e: print(f"⚠️ 加载历史记录时出错: {e}") return {} return results def update_result_in_file(task_id, new_record): """ 更新 jsonl 文件中指定 task_id 的记录 使用临时文件方式,确保线程安全 """ if not os.path.exists(OUTPUT_FILE): # 如果文件不存在,直接写入 with file_lock: with open(OUTPUT_FILE, "w", encoding="utf-8") as f: f.write(json.dumps(new_record, ensure_ascii=False) + "\n") return # 读取所有记录,更新指定记录,写回文件 with file_lock: temp_file = OUTPUT_FILE + ".tmp" updated = False try: with open(OUTPUT_FILE, "r", encoding="utf-8") as f_in, \ open(temp_file, "w", encoding="utf-8") as f_out: for line in f_in: if not line.strip(): continue try: data = json.loads(line) if data.get("task_id") == task_id: # 更新这条记录 f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n") updated = True else: # 保持原记录 f_out.write(line) except json.JSONDecodeError: continue # 如果没找到要更新的记录,追加新记录 if not updated: f_out.write(json.dumps(new_record, ensure_ascii=False) + "\n") # 替换原文件 os.replace(temp_file, OUTPUT_FILE) except Exception as e: # 如果出错,删除临时文件 if os.path.exists(temp_file): os.remove(temp_file) raise e def upload_file(local_path): """上传文件并返回符合接口要求的 file_id 和 filename""" try: if not os.path.exists(local_path): print(f"❌ 本地文件不存在: {local_path}") return None with open(local_path, 'rb') as f: files = {'file': f} response = requests.post(UPLOAD_URL, headers=HEADERS, files=files, 
timeout=60) response.raise_for_status() res_data = response.json() if res_data.get("code") == 0: file_info = res_data.get("data", {}) return { "file_id": file_info.get("file_id"), "filename": file_info.get("filename") } else: print(f"❌ 上传接口返回错误: {res_data.get('msg')}") except Exception as e: print(f"❌ 文件上传异常 ({os.path.basename(local_path)}): {e}") return None def extract_answer(text): """从文本中提取答案""" if not text: return "" pattern = r"(?si)<\s*answer\s*>\s*(.*?)\s*" match = re.search(pattern, text) if match: ans = match.group(1).strip() return re.sub(r'^["\']|["\']$', '', ans) backup_pattern = r"(?i)answer\s*is[::]\s*(.*)" backup_match = re.search(backup_pattern, text) if backup_match: return backup_match.group(1).strip().rstrip('.') lines = [l.strip() for l in text.strip().split('\n') if l.strip()] return lines[-1] if lines else text.strip() def call_my_agent_safe(question, attachments=None, task_id=None): """ 发送对话请求,包含附件数组 Args: question: 问题内容 attachments: 附件列表 task_id: 任务ID,用于确保会话隔离 Returns: tuple: (parsed_answer, session_id, raw_content) """ guided_prompt = ( f"{question}\n\n Important Requirement: \nprovide the final answer (the answer only, without explanation) inside the tags in the following format: your answer" ) payload = { "message": guided_prompt, "streaming": False, "attachments": attachments if attachments else [], "recycle_sandbox": True, # 明确指定创建新会话,避免会话内容混乱 # 如果 API 支持 session_id 参数,设置为 null 表示创建新会话 # 如果不支持,则不传递 session_id 参数(当前做法) } # 如果 API 支持,可以尝试以下方式之一来确保创建新会话: # 1. payload["session_id"] = None # 明确创建新会话 # 2. payload["new_session"] = True # 如果 API 支持此参数 # 3. 
    if task_id:
        # Attach the task_id to the payload as a request marker, so the
        # backend can tell requests apart and keep sessions isolated
        payload["task_id"] = task_id

    # Add a unique identifier to the request headers as an extra isolation
    # measure; if the backend supports it, X-Request-ID or a similar header
    # can be used to distinguish requests
    request_headers = HEADERS.copy()
    if task_id:
        # Put the task_id in a header to help the backend identify and
        # isolate different requests
        request_headers["X-Task-ID"] = task_id

    session_id = None  # ensure it is defined even if the request fails
    try:
        response = requests.post(CHAT_URL, headers=request_headers, json=payload,
                                 timeout=(30, REQUEST_TIMEOUT))
        response.raise_for_status()
        res_data = response.json()
        raw_content = (res_data.get("answer") or res_data.get("content")
                       or res_data.get("response") or "").strip()
        session_id = res_data.get("session_id")
        parsed_answer = extract_answer(raw_content)
        return parsed_answer, session_id, raw_content
    except Exception as e:
        error_traceback = traceback.format_exc()
        return f"ERROR: {str(e)}", session_id, error_traceback


def process_item(item, existing_results, mode):
    """
    Process one dataset item: upload the attachment -> chat with the agent ->
    record the result.

    In hybrid mode on the validation split, a record that already exists with
    is_correct == true skips the agent call and only has its field order
    refreshed. In every other case the task is re-run, so the field order
    stays consistent. (The test split has no is_correct field, so correctness
    cannot be checked and those tasks are always re-run.)

    Args:
        item: the dataset item
        existing_results: dict of existing results
        mode: execution mode ("full", "incremental", "hybrid" or "error")
    """
    task_id = item['task_id']
    level = item.get('Level', 'Unknown')
    question = item['Question']
    file_name = item.get('file_name', "")

    # Hybrid + validation: if the record exists and is correct, refresh the
    # field order without calling the agent. Only the validation split has an
    # is_correct field to check against.
    if mode == "hybrid" and SPLIT_TYPE == "validation" and task_id in existing_results:
        existing_record = existing_results[task_id]
        if existing_record.get("is_correct", False):
            # Already correct: refresh the field order only, using the current
            # file_name for attachment_name to keep the data consistent
            current_has_file = bool(file_name)
            current_attachment_name = file_name if file_name else None
            record = build_ordered_record(
                task_id=task_id,
                question=existing_record.get("question", question),
                level=existing_record.get("level", level),
                agent_answer=existing_record.get("agent_answer", ""),
                duration=existing_record.get("duration", 0),
                has_file=current_has_file,
                session_id=existing_record.get("session_id"),
                attachment_name=current_attachment_name,
                ground_truth=existing_record.get("ground_truth", ""),
                is_correct=True
            )
            # Rewrite the existing record (refreshing the field order)
            update_result_in_file(task_id, record)
            return task_id, True, "refreshed"

    # The agent must be called (new record, failed record, or a mode other
    # than hybrid)
    attachments = []

    # 1. If the task has an attachment, upload it first
    if file_name:
        # Pick the folder matching the split
        folder = "validation" if SPLIT_TYPE == "validation" else "test"
        local_file_path = os.path.abspath(os.path.join(DATA_PATH, "2023", folder, file_name))
        upload_data = upload_file(local_file_path)
        if upload_data:
            attachments.append(upload_data)

    # 2. Call the agent (passing task_id for session isolation)
    start_time = time.time()
    agent_answer, session_id, _ = call_my_agent_safe(question, attachments, task_id=task_id)
    duration = time.time() - start_time

    # 3. Build the record (fixed field order)
    if SPLIT_TYPE == "validation":
        # Validation split: include the reference answer and a correctness flag
        ground_truth = str(item['Final answer']).strip()
        clean_agent = str(agent_answer).lower().rstrip('.')
        clean_gt = ground_truth.lower().rstrip('.')
        is_correct = (clean_agent == clean_gt)
        record = build_ordered_record(
            task_id=task_id,
            question=question,
            level=level,
            duration=round(duration, 2),
            has_file=bool(file_name),
            session_id=session_id,
            attachment_name=file_name if file_name else None,
            agent_answer=agent_answer,
            ground_truth=ground_truth,
            is_correct=is_correct
        )
        result_correct = is_correct
    else:
        # Test split: no reference answer
        record = build_ordered_record(
            task_id=task_id,
            question=question,
            level=level,
            duration=round(duration, 2),
            has_file=bool(file_name),
            session_id=session_id,
            attachment_name=file_name if file_name else None,
            agent_answer=agent_answer
        )
        result_correct = None

    # 4. Update or append the record
    if task_id in existing_results:
        # Update the existing record
        update_result_in_file(task_id, record)
        return task_id, result_correct, "updated"
    else:
        # Append a new record
        with file_lock:
            with open(OUTPUT_FILE, "a", encoding="utf-8") as f:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
        return task_id, result_correct, "new"


def generate_submission():
    """
    Generate a submission file in the official GAIA format.

    GAIA submission requirements:
    - file format: JSONL (one JSON object per line)
    - required fields: task_id, model_answer
    - encoding: UTF-8
    - a test-split submission must contain answers for all 285 tasks
    """
    if not os.path.exists(OUTPUT_FILE):
        print(f"⚠️ Warning: results file {OUTPUT_FILE} does not exist; cannot generate a submission")
        return

    # Read all valid results
    results = []
    with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
                if "task_id" in data and "agent_answer" in data:
                    results.append(data)
            except json.JSONDecodeError:
                continue

    if not results:
        print(f"⚠️ Warning: results file {OUTPUT_FILE} contains no valid data")
        return

    # Sort by task_id for a consistent order
    results.sort(key=lambda x: x.get("task_id", ""))

    # Write the submission file
    with open(SUBMISSION_FILE, "w", encoding="utf-8") as f_out:
        for data in results:
            submission_data = {
                "task_id": data["task_id"],
                "model_answer": str(data["agent_answer"])
            }
            f_out.write(json.dumps(submission_data, ensure_ascii=False) + "\n")

    print(f"✅ Submission file generated: {SUBMISSION_FILE} ({len(results)} records)")

    # For the test split, check that every task is covered
    if SPLIT_TYPE == "test":
        expected_count = 285
        if len(results) < expected_count:
            print(f"⚠️ Warning: the test split has {expected_count} tasks but only {len(results)} are present")
        else:
            print(f"✅ The test split contains {len(results)} tasks and meets the submission requirements")


def get_current_accuracy():
    """
    Compute the current overall accuracy (validation split only).

    Returns:
        float or None: accuracy as a percentage, or None if this is not the
        validation split or the results file does not exist.
    """
    # The test split has no reference answers, so accuracy cannot be computed
    if SPLIT_TYPE != "validation":
        return None

    if not os.path.exists(OUTPUT_FILE):
        return None

    try:
        results = []
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                    results.append(data)
                except json.JSONDecodeError:
                    continue
        if not results:
            return None
        total = len(results)
        correct = sum(1 for r in results if r.get("is_correct", False))
        accuracy = (correct / total * 100) if total > 0 else 0.0
        return accuracy
    except Exception:
        return None


def generate_report():
    """Print a score summary (only the validation split has reference answers)."""
    # The test split has no reference answers, so there is no report to build
    if SPLIT_TYPE != "validation":
        return
    if not os.path.exists(OUTPUT_FILE):
        return
    with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
        results = [json.loads(line) for line in f if line.strip()]
    df = pd.DataFrame(results)
    total = len(df)
    acc = (df['is_correct'].sum() / total) * 100
    print("\n" + "=" * 50)
    print(f"Testing finished! Total: {total} | Overall accuracy: {acc:.2f}%")
    print("=" * 50)


def run_test_concurrent(num_questions=200, mode="hybrid", split="validation",
                        threads=MAX_CONCURRENT, target_task_id=None):
    """
    Main test entry point.

    Args:
        num_questions: number of tasks to run
        mode: execution mode
            - "full": delete any existing results file and start from scratch
            - "incremental": resume from the results file if it exists,
              otherwise run everything
            - "hybrid": full on the first run (no file yet), incremental afterwards
            - "error": re-run only records whose agent_answer contains ERROR
        split: dataset split, "validation" or "test"
        threads: number of concurrent worker threads
        target_task_id: optional; if given, run only this task
    """
    global dataset, OUTPUT_FILE, SUBMISSION_FILE, SPLIT_TYPE

    # Set the globals
    SPLIT_TYPE = split

    # Choose the output filenames according to the split
    if split == "validation":
        OUTPUT_FILE = "validation_results.jsonl"
        SUBMISSION_FILE = "validation_submission.jsonl"
        print("📥 Checking the GAIA validation data...")
    else:  # test
        OUTPUT_FILE = "test_results.jsonl"
        SUBMISSION_FILE = "test_submission.jsonl"
        print("📥 Checking the GAIA test data...")

    # 1. Check for, and if needed download, the full dataset files
    #    (covers both validation and test)
    print("\n[Step 1/2] Checking the dataset files...")
    check_and_download_dataset_files()

    # 2. Load the dataset metadata (shows download progress on first use)
    print("\n[Step 2/2] Loading the dataset metadata...")
    print(f"   Dataset: gaia-benchmark/GAIA (2023_all, split={split})")
    print("   Note: the first download can take a while; progress is shown below.")
    print("   On later runs the data is loaded straight from the cache.")
    start_time = time.time()
    dataset = load_dataset("gaia-benchmark/GAIA", "2023_all", split=split)
    load_duration = time.time() - start_time
    print(f"✅ Dataset metadata loaded: {len(dataset)} records in {load_duration:.2f} s\n")

    # 1. Handle existing results according to the mode
    file_exists = os.path.exists(OUTPUT_FILE)

    if mode == "full":
        # Full mode: delete the old file and start from scratch
        if file_exists:
            os.remove(OUTPUT_FILE)
            print("🔄 Full mode: old results file deleted, starting from scratch")
        existing_results = {}
    elif mode == "incremental":
        # Incremental mode: resume from the file if it exists, otherwise run everything
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 Incremental mode: loaded {len(existing_results)} historical records")
        else:
            existing_results = {}
            print("📋 Incremental mode: no history found, running everything")
    elif mode == "error":
        # Error mode: re-run only records whose agent_answer contains ERROR
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 Error mode: loaded {len(existing_results)} historical records; "
                  "re-running those containing ERROR")
        else:
            existing_results = {}
            print("📋 Error mode: no history found, there is nothing to retry")
    else:  # hybrid
        # Hybrid mode: full on the first run (no file yet), incremental afterwards
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 Hybrid mode: results file found, running incrementally "
                  f"(loaded {len(existing_results)} historical records)")
        else:
            existing_results = {}
            print("📋 Hybrid mode: first run, running everything")

    # 2. Select the tasks to run
    if target_task_id:
        # A specific task_id was requested: run only that task
        print(f"🎯 Running only task_id: {target_task_id}")
        tasks_to_run = []
        found = False
        for item in dataset:
            if item['task_id'] == target_task_id:
                tasks_to_run = [item]
                found = True
                break
        if not found:
            print(f"❌ Error: task_id {target_task_id} not found in the {split} split")
            return
        num_to_run = 1
    else:
        # Normal mode: select according to num_questions
        num_to_run = min(num_questions, len(dataset))
        tasks_to_run = dataset.select(range(num_to_run))

    # Work out which tasks actually need to run
    tasks_to_execute = []
    refresh_count = 0  # hybrid mode: records that only get their field order refreshed
    update_count = 0   # records that need a fresh agent call
    new_count = 0      # brand-new records
    error_count = 0    # error mode: records containing ERROR

    for item in tasks_to_run:
        task_id = item['task_id']
        if mode == "error":
            # Error mode: only re-run records whose agent_answer contains ERROR
            if task_id in existing_results:
                agent_answer = existing_results[task_id].get("agent_answer", "")
                if agent_answer and "ERROR" in str(agent_answer):
                    error_count += 1
                    tasks_to_execute.append(item)
            # Records that are missing or error-free are skipped
        else:
            # All other modes
            if task_id in existing_results:
                # Hybrid + validation: an already-correct record only gets its
                # field order refreshed. The test split has no is_correct
                # field, so its tasks are always re-run.
                if mode == "hybrid" and split == "validation":
                    if existing_results[task_id].get("is_correct", False):
                        refresh_count += 1
                    else:
                        update_count += 1
                else:
                    # Non-hybrid modes, or the test split: re-run every existing record
                    update_count += 1
            else:
                new_count += 1
            tasks_to_execute.append(item)

    total_to_execute = len(tasks_to_execute)
    print("\n📊 Summary:")
    print(f"   Split: {split}")
    print(f"   Mode: {mode}")
    if target_task_id:
        print(f"   Requested task_id: {target_task_id}")
    print(f"   Total tasks: {num_to_run}")
    if existing_results:
        if mode == "error":
            print(f"   To execute: {total_to_execute} (records containing ERROR: {error_count})")
        elif mode == "hybrid":
            print(f"   To execute: {total_to_execute} (new: {new_count}, "
                  f"field-order refresh: {refresh_count}, re-run: {update_count})")
        else:
            print(f"   To execute: {total_to_execute} (new: {new_count}, "
                  f"existing records to refresh: {refresh_count + update_count})")
    else:
        print(f"   To execute: {total_to_execute} (full run)")
    print(f"🚀 Starting | concurrency: {threads} | to execute: {total_to_execute}")

    if total_to_execute == 0:
        if mode == "error":
            print("✅ No records contain ERROR; nothing to do")
        elif split == "validation":
            print("✅ All tasks are already finished and correct; nothing to do")
        else:
            print("✅ All tasks are already finished; nothing to do")
        generate_report()
        generate_submission()
        return

    # 3. Run the selected tasks concurrently
    with ThreadPoolExecutor(max_workers=threads) as executor:
        future_to_item = {executor.submit(process_item, item, existing_results, mode): item
                          for item in tasks_to_execute}
        done = 0
        for future in as_completed(future_to_item):
            done += 1
            item = future_to_item[future]
            tid = item['task_id']
            try:
                _, is_ok, status = future.result()
                if status == "refreshed":
                    status_icon = "🔄"
                elif split == "validation":
                    status_icon = "✅" if is_ok else "❌"
                else:  # test
                    status_icon = "✅"
                # Compute and show the current overall accuracy
                accuracy_info = ""
                if split == "validation":
                    current_accuracy = get_current_accuracy()
                    if current_accuracy is not None:
                        accuracy_info = f" | current accuracy: {current_accuracy:.2f}%"
                print(f"[{done}/{total_to_execute}] ID: {tid} | status: {status_icon} ({status}){accuracy_info}")
            except Exception as e:
                error_traceback = traceback.format_exc()
                print(f"[{done}/{total_to_execute}] ID: {tid} raised an exception: {e}")
                print(f"Traceback:\n{error_traceback}")

    # 4. Generate the report and the submission file
    generate_report()
    generate_submission()


def print_help():
    """Print detailed usage information."""
    print("=" * 70)
    print("GAIA test script - options")
    print("=" * 70)
    print()
    print("Usage:")
    print("  python gaia_test.py [options]")
    print()
    print("Options:")
    print()
    print("  --split <type>")
    print("      Dataset split")
    print("      Choices: validation, test")
    print("      Default: validation")
    print("      Notes:")
    print("        - validation: has reference answers, so accuracy can be computed")
    print("        - test: no reference answers, used for the final submission")
    print()
    print("  --mode <mode>")
    print("      Execution mode")
    print("      Choices: full, incremental, hybrid, error")
    print("      Default: hybrid")
    print("      Notes:")
    print("        - full: delete the old results file and start from scratch")
    print("        - incremental: resume from the results file if it exists,")
    print("          otherwise run everything")
    print("        - hybrid: (recommended) full on the first run, incremental afterwards;")
    print("          on the validation split, already-correct records only get their")
    print("          field order refreshed instead of a new agent call")
    print("        - error: re-run only records whose agent_answer contains ERROR")
    print()
    print("  --num <count>")
    print("      Number of tasks to run")
    print("      Type: integer")
    print("      Default: 200")
    print("      Notes:")
    print("        - the test split has 285 tasks; use --num 285 to run them all")
    print("        - choose any count you need for the validation split")
    print()
    print("  --threads <count>")
    print("      Number of concurrent worker threads")
    print("      Type: integer")
    print("      Default: 2")
    print("      Notes:")
    print("        - tune to your server; too many threads can overload it")
    print("        - recommended range: 1-4")
    print()
    print("  --task-id <id>")
    print("      Run only the given task_id")
    print("      Type: string")
    print("      Default: none (run multiple tasks)")
    print("      Notes:")
    print("        - when set, only that task is run")
    print("        - when set, --num is ignored")
    print("        - if the task_id does not exist, the script reports an error and exits")
    print()
    print("  -h, --help")
    print("      Show this help message and exit")
    print()
    print("Examples:")
    print("  # Defaults (validation split, hybrid mode, 200 tasks)")
    print("  python gaia_test.py")
    print()
    print("  # Run all 285 tasks of the test split")
    print("  python gaia_test.py --split test --num 285")
    print()
    print("  # Re-run failed records in error mode")
    print("  python gaia_test.py --mode error")
    print()
    print("  # Full mode with 4 worker threads")
    print("  python gaia_test.py --mode full --threads 4")
    print()
    print("  # Run a single task_id")
    print("  python gaia_test.py --task-id c61d22de-5f6c-4958-a7f6-5e9707bd3466")
    print()
    print("=" * 70)
    print("Configuration:")
    print("  Before running, make sure the following settings in gaia_test.py are correct:")
    print("  - BASE_URL: API endpoint")
    print("  - HEADERS: auth token (must be changed)")
    print("  - DATA_PATH: path to the data files")
    print("  - REQUEST_TIMEOUT: request timeout")
    print("  - MAX_CONCURRENT: maximum concurrency")
    print("=" * 70)


if __name__ == "__main__":
    import sys

    # Show the detailed help text for -h / --help
    if "-h" in sys.argv or "--help" in sys.argv:
        print_help()
        sys.exit(0)

    parser = argparse.ArgumentParser(
        description="GAIA test script (supports the validation and test splits "
                    "and the full, incremental, hybrid and error modes)",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--split", type=str, choices=["validation", "test"], default="validation",
        help="dataset split: 'validation' (has reference answers) or 'test' "
             "(no reference answers; default: validation)"
    )
    parser.add_argument(
        "--mode", type=str, choices=["full", "incremental", "hybrid", "error"],
        default="hybrid",
        help="execution mode: 'full' (delete the old file and re-run everything), "
             "'incremental' (resume if the file exists, otherwise run everything), "
             "'hybrid' (full on the first run, incremental afterwards; default), "
             "'error' (re-run only records whose agent_answer contains ERROR)"
    )
    parser.add_argument(
        "--num", type=int, default=200,
        help="number of tasks to run (default: 200; the test split has 285 tasks)"
    )
    parser.add_argument(
        "--threads", type=int, default=MAX_CONCURRENT,
        help="number of concurrent worker threads (default: 2)"
    )
    parser.add_argument(
        "--task-id", type=str, default=None,
        help="run only the given task_id (ignores --num when set)"
    )

    args = parser.parse_args()
    run_test_concurrent(num_questions=args.num, mode=args.mode, split=args.split,
                        threads=args.threads, target_task_id=args.task_id)