| """ | |
| 测试脚本:支持 validation 和 test 数据集,支持全量、增量、混合三种模式 | |
| - 全量模式(full):不管文件是否存在,都删除重新开始 | |
| - 增量模式(incremental):如果文件存在则增量,不存在则全量执行 | |
| - 混合模式(hybrid):第一次时全量(文件不存在),后面就增量(文件存在) | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import re | |
| import time | |
| import traceback | |
| from collections import OrderedDict | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from threading import Lock | |
| import pandas as pd | |
| import requests | |
| from datasets import load_dataset | |
| from huggingface_hub import snapshot_download, hf_hub_download | |
| from pathlib import Path | |
| import shutil | |
# --- 1. Configuration ---
BASE_URL = "http://localhost:5173/api/v1"  # root of the agent API service
CHAT_URL = f"{BASE_URL}/sessions/chat"     # chat endpoint (non-streaming use below)
UPLOAD_URL = f"{BASE_URL}/files"           # attachment upload endpoint
# NOTE(review): hard-coded bearer token in source — consider loading it from an
# environment variable or config file instead of committing the secret.
HEADERS = {"Authorization": "Bearer hawk_YhCZLQYqtPOwOiEyEgeCNdfAFAbrHtTUxQvRiaOInyekgVgE"}
DATA_PATH = "./gaia_data"   # local root directory for the GAIA dataset files
REQUEST_TIMEOUT = 1800      # read timeout (seconds) for chat requests; connect timeout is 30s
MAX_CONCURRENT = 2          # default number of worker threads
file_lock = Lock()          # serializes all writes/rewrites of the results file
# Globals set dynamically by run_test_concurrent() based on the chosen split
dataset = None
OUTPUT_FILE = None
SUBMISSION_FILE = None
SPLIT_TYPE = None  # "validation" or "test"
def check_and_download_dataset_files():
    """Ensure the complete GAIA dataset files exist under DATA_PATH/2023.

    Checks for non-empty ``validation`` and ``test`` directories; if either is
    missing, downloads the full dataset snapshot from Hugging Face and copies
    both splits into place.

    Returns:
        bool: True if the files already exist or were downloaded and copied
        successfully, False otherwise.
    """
    base_target_dir = Path(DATA_PATH) / "2023"
    validation_dir = base_target_dir / "validation"
    test_dir = base_target_dir / "test"
    # The dataset counts as present only when BOTH split directories exist
    # and each contains at least one file.
    validation_files = list(validation_dir.glob("*")) if validation_dir.exists() else []
    test_files = list(test_dir.glob("*")) if test_dir.exists() else []
    if validation_files and test_files:
        print(f"✅ 检测到数据集文件已存在")
        print(f" validation 文件数: {len(validation_files)}")
        print(f" test 文件数: {len(test_files)}")
        return True
    # Files are missing: download the full snapshot.
    print(f"📥 开始下载完整的 GAIA 数据集文件...")
    print(f" 目标目录: {base_target_dir}")
    try:
        base_target_dir.mkdir(parents=True, exist_ok=True)
        print(" 步骤 1/4: 正在从 Hugging Face 下载完整数据集...")
        print(" 提示: 下载进度会显示在下方,请耐心等待...")
        download_start = time.time()
        # snapshot_download resumes interrupted downloads by default in current
        # huggingface_hub releases; the old resume_download=True flag is
        # deprecated (it only emitted a warning), so it has been dropped.
        cache_dir = snapshot_download(
            repo_id="gaia-benchmark/GAIA",
            repo_type="dataset",
            local_dir=None,  # use the default HF cache directory
        )
        download_duration = time.time() - download_start
        print(f" ✅ 数据集下载完成,耗时 {download_duration:.2f} 秒")
        cache_path = Path(cache_dir)
        source_2023_dir = cache_path / "2023"
        if not source_2023_dir.exists():
            print(f" ❌ 错误: 缓存目录中未找到 2023 目录")
            return False
        # Copy the validation split, replacing any partial local copy first.
        print(" 步骤 2/4: 正在复制 validation 文件...")
        validation_source = source_2023_dir / "validation"
        if validation_source.exists():
            if validation_dir.exists():
                shutil.rmtree(validation_dir)
            shutil.copytree(validation_source, validation_dir)
            validation_count = len(list(validation_dir.glob("*")))
            print(f" ✅ validation 文件复制完成,共 {validation_count} 个文件")
        else:
            print(f" ⚠️ 警告: 未找到 validation 目录")
        # Copy the test split the same way.
        print(" 步骤 3/4: 正在复制 test 文件...")
        test_source = source_2023_dir / "test"
        if test_source.exists():
            if test_dir.exists():
                shutil.rmtree(test_dir)
            shutil.copytree(test_source, test_dir)
            test_count = len(list(test_dir.glob("*")))
            print(f" ✅ test 文件复制完成,共 {test_count} 个文件")
        else:
            print(f" ⚠️ 警告: 未找到 test 目录")
        print(" 步骤 4/4: 数据集文件准备完成!")
        print(f" 目标目录: {base_target_dir}")
        return True
    except Exception as e:
        print(f" ❌ 下载数据集文件时出错: {e}")
        # traceback is already imported at module level; the original's
        # redundant local `import traceback` has been removed.
        traceback.print_exc()
        return False
def build_ordered_record(task_id, question, level, agent_answer, duration, has_file,
                         session_id=None, attachment_name=None, ground_truth=None, is_correct=None):
    """Assemble a result record whose keys always appear in a fixed order.

    Args:
        task_id: Task identifier.
        question: The question text.
        level: Difficulty level.
        agent_answer: The agent's answer.
        duration: Execution time in seconds.
        has_file: Whether the task has an attachment.
        session_id: Session id; included only when truthy.
        attachment_name: Attachment file name; included only when non-blank.
        ground_truth: Reference answer (validation split only).
        is_correct: Correctness flag (validation split only).

    Returns:
        OrderedDict: Record with a stable, fixed key order.
    """
    fields = [
        ("task_id", task_id),
        ("question", question),
        ("level", level),
        ("duration", duration),
        ("has_file", has_file),
    ]
    # An attachment name is recorded whenever it is non-empty after stripping
    # whitespace — even when the agent call itself failed.
    if attachment_name and attachment_name.strip():
        fields.append(("attachment_name", attachment_name))
    # A missing/falsy session id (e.g. after an agent error) is simply omitted.
    if session_id:
        fields.append(("session_id", session_id))
    fields.append(("agent_answer", agent_answer))
    # Fields specific to the validation split.
    if ground_truth is not None:
        fields.append(("ground_truth", ground_truth))
    if is_correct is not None:
        fields.append(("is_correct", is_correct))
    return OrderedDict(fields)
def load_existing_results():
    """Load previously saved results from OUTPUT_FILE.

    Returns:
        dict: Mapping of task_id -> full record dict; empty when the file is
        absent or unreadable.
    """
    if not os.path.exists(OUTPUT_FILE):
        return {}
    loaded = {}
    try:
        with open(OUTPUT_FILE, "r", encoding="utf-8") as fh:
            for raw in fh:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    record = json.loads(raw)
                except json.JSONDecodeError:
                    # Skip lines that are not valid JSON.
                    continue
                tid = record.get("task_id")
                if tid:
                    loaded[tid] = record
        print(f"✅ 已加载 {len(loaded)} 条历史记录")
    except Exception as e:
        print(f"⚠️ 加载历史记录时出错: {e}")
        return {}
    return loaded
def update_result_in_file(task_id, new_record):
    """Replace (or append) the record for *task_id* in OUTPUT_FILE.

    Rewrites the jsonl file through a temporary file while holding the global
    file lock, so concurrent workers cannot corrupt it.
    """
    if not os.path.exists(OUTPUT_FILE):
        # No file yet: create it with this single record.
        with file_lock:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as fh:
                fh.write(json.dumps(new_record, ensure_ascii=False) + "\n")
        return
    serialized = json.dumps(new_record, ensure_ascii=False) + "\n"
    with file_lock:
        temp_file = OUTPUT_FILE + ".tmp"
        replaced = False
        try:
            with open(OUTPUT_FILE, "r", encoding="utf-8") as src, \
                    open(temp_file, "w", encoding="utf-8") as dst:
                for raw in src:
                    if not raw.strip():
                        continue
                    try:
                        entry = json.loads(raw)
                    except json.JSONDecodeError:
                        continue
                    if entry.get("task_id") == task_id:
                        # Overwrite the matching record.
                        dst.write(serialized)
                        replaced = True
                    else:
                        # Keep the original line verbatim.
                        dst.write(raw)
                # task_id not found: append the new record at the end.
                if not replaced:
                    dst.write(serialized)
            # Swap the rewritten file into place.
            os.replace(temp_file, OUTPUT_FILE)
        except Exception:
            # Clean up the temp file on any failure, then propagate.
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise
def upload_file(local_path):
    """Upload a local file; return the file_id/filename dict the chat API expects.

    Returns None when the file is missing, the API reports an error, or the
    request fails.
    """
    try:
        if not os.path.exists(local_path):
            print(f"❌ 本地文件不存在: {local_path}")
            return None
        with open(local_path, 'rb') as fh:
            response = requests.post(UPLOAD_URL, headers=HEADERS, files={'file': fh}, timeout=60)
            response.raise_for_status()
        res_data = response.json()
        if res_data.get("code") == 0:
            info = res_data.get("data", {})
            return {
                "file_id": info.get("file_id"),
                "filename": info.get("filename"),
            }
        print(f"❌ 上传接口返回错误: {res_data.get('msg')}")
    except Exception as e:
        print(f"❌ 文件上传异常 ({os.path.basename(local_path)}): {e}")
    return None
def extract_answer(text):
    """Extract the final answer from agent output.

    Tries, in order: an <answer>...</answer> tag, an "answer is:" phrase, and
    finally the last non-empty line of the text.
    """
    if not text:
        return ""
    # Primary format: <answer>...</answer>, case-insensitive, dot matches newline.
    tag_match = re.search(r"(?si)<\s*answer\s*>\s*(.*?)\s*</\s*answer\s*>", text)
    if tag_match:
        candidate = tag_match.group(1).strip()
        # Drop a single leading and/or trailing quote character.
        return re.sub(r'^["\']|["\']$', '', candidate)
    # Fallback: "answer is:" (ASCII or fullwidth colon), take the rest of the line.
    phrase_match = re.search(r"(?i)answer\s*is[::]\s*(.*)", text)
    if phrase_match:
        return phrase_match.group(1).strip().rstrip('.')
    # Last resort: the final non-empty line, or the stripped text itself.
    non_empty = [segment.strip() for segment in text.strip().split('\n') if segment.strip()]
    return non_empty[-1] if non_empty else text.strip()
def call_my_agent_safe(question, attachments=None, task_id=None):
    """Send a chat request to the agent, optionally with attachments.

    Args:
        question: Question text.
        attachments: List of attachment dicts ({"file_id", "filename"}).
        task_id: Task id used to help the backend isolate sessions.

    Returns:
        tuple: (parsed_answer, session_id, raw_content). On failure the first
        element is an "ERROR: ..." string and the last is the formatted
        traceback; session_id is None if the request never completed.
    """
    guided_prompt = (
        f"{question}\n\n Important Requirement: \nprovide the final answer (the answer only, without explanation) inside the tags in the following format: <answer>your answer</answer>"
    )
    payload = {
        "message": guided_prompt,
        "streaming": False,
        "attachments": attachments if attachments else [],
        "recycle_sandbox": True,
    }
    if task_id:
        # Tag the payload so the backend can distinguish concurrent requests.
        payload["task_id"] = task_id
    request_headers = HEADERS.copy()
    if task_id:
        # Also tag the headers; if the backend supports it, this further
        # isolates sessions between concurrent tasks.
        request_headers["X-Task-ID"] = task_id
    # BUG FIX: initialize before the try block. Previously, if requests.post
    # itself raised (timeout, connection error), session_id was never bound
    # and the except clause crashed with UnboundLocalError, masking the real
    # error instead of returning the "ERROR: ..." tuple.
    session_id = None
    try:
        response = requests.post(CHAT_URL, headers=request_headers, json=payload, timeout=(30, REQUEST_TIMEOUT))
        response.raise_for_status()
        res_data = response.json()
        raw_content = (res_data.get("answer") or res_data.get("content") or res_data.get("response") or "").strip()
        session_id = res_data.get("session_id")
        parsed_answer = extract_answer(raw_content)
        return parsed_answer, session_id, raw_content
    except Exception as e:
        error_traceback = traceback.format_exc()
        return f"ERROR: {str(e)}", session_id, error_traceback
def process_item(item, existing_results, mode):
    """Run one dataset item: upload its attachment, query the agent, record the result.

    In hybrid mode on the validation split, a record that already exists and is
    marked correct is not re-run; only its field order is refreshed. Everything
    else (including the whole test split, which has no is_correct field) is
    re-executed.

    Args:
        item: Dataset row.
        existing_results: Mapping task_id -> previously saved record.
        mode: "full", "incremental", "hybrid" or "error".

    Returns:
        tuple: (task_id, correctness_or_None, status) where status is one of
        "refreshed", "updated", "new".
    """
    task_id = item['task_id']
    level = item.get('Level', 'Unknown')
    question = item['Question']
    file_name = item.get('file_name', "")
    # Hybrid + validation: an already-correct record only gets its field order
    # refreshed — the agent is not called again.
    if mode == "hybrid" and SPLIT_TYPE == "validation" and task_id in existing_results:
        prior = existing_results[task_id]
        if prior.get("is_correct", False):
            # Rebuild with the current file_name so attachment data stays consistent.
            refreshed = build_ordered_record(
                task_id=task_id,
                question=prior.get("question", question),
                level=prior.get("level", level),
                agent_answer=prior.get("agent_answer", ""),
                duration=prior.get("duration", 0),
                has_file=bool(file_name),
                session_id=prior.get("session_id"),
                attachment_name=file_name if file_name else None,
                ground_truth=prior.get("ground_truth", ""),
                is_correct=True,
            )
            update_result_in_file(task_id, refreshed)
            return task_id, True, "refreshed"
    # Otherwise the agent must be (re-)invoked.
    attachments = []
    if file_name:
        # Attachment files live under the split-specific data folder.
        folder = "validation" if SPLIT_TYPE == "validation" else "test"
        local_file_path = os.path.abspath(os.path.join(DATA_PATH, "2023", folder, file_name))
        upload_data = upload_file(local_file_path)
        if upload_data:
            attachments.append(upload_data)
    started = time.time()
    agent_answer, session_id, _ = call_my_agent_safe(question, attachments, task_id=task_id)
    elapsed = time.time() - started
    if SPLIT_TYPE == "validation":
        # Validation split: compare against the reference answer
        # (case-insensitive, ignoring a trailing period).
        ground_truth = str(item['Final answer']).strip()
        is_correct = str(agent_answer).lower().rstrip('.') == ground_truth.lower().rstrip('.')
        record = build_ordered_record(
            task_id=task_id,
            question=question,
            level=level,
            duration=round(elapsed, 2),
            has_file=bool(file_name),
            session_id=session_id,
            attachment_name=file_name if file_name else None,
            agent_answer=agent_answer,
            ground_truth=ground_truth,
            is_correct=is_correct,
        )
        result_correct = is_correct
    else:
        # Test split: no reference answer available.
        record = build_ordered_record(
            task_id=task_id,
            question=question,
            level=level,
            duration=round(elapsed, 2),
            has_file=bool(file_name),
            session_id=session_id,
            attachment_name=file_name if file_name else None,
            agent_answer=agent_answer,
        )
        result_correct = None
    # Update an existing record in place, or append a brand-new one.
    if task_id in existing_results:
        update_result_in_file(task_id, record)
        return task_id, result_correct, "updated"
    with file_lock:
        with open(OUTPUT_FILE, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")
    return task_id, result_correct, "new"
def generate_submission():
    """Write the GAIA leaderboard submission file.

    Submission format: UTF-8 JSONL, one object per line, with the required
    fields task_id and model_answer. The test split should contain answers for
    all 285 cases.
    """
    if not os.path.exists(OUTPUT_FILE):
        print(f"⚠️ 警告:结果文件 {OUTPUT_FILE} 不存在,无法生成提交文件")
        return
    # Collect valid result lines (those with both required source fields).
    entries = []
    with open(OUTPUT_FILE, "r", encoding="utf-8") as fh:
        for raw in fh:
            if not raw.strip():
                continue
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if "task_id" in parsed and "agent_answer" in parsed:
                entries.append(parsed)
    if not entries:
        print(f"⚠️ 警告:结果文件 {OUTPUT_FILE} 中没有有效数据")
        return
    # Sort by task_id for a stable, reproducible submission order.
    entries.sort(key=lambda e: e.get("task_id", ""))
    with open(SUBMISSION_FILE, "w", encoding="utf-8") as out:
        for entry in entries:
            row = {
                "task_id": entry["task_id"],
                "model_answer": str(entry["agent_answer"]),
            }
            out.write(json.dumps(row, ensure_ascii=False) + "\n")
    print(f"✅ 提交文件已生成: {SUBMISSION_FILE} (共 {len(entries)} 条记录)")
    # For the test split, warn when the submission is incomplete.
    if SPLIT_TYPE == "test":
        expected_count = 285
        if len(entries) < expected_count:
            print(f"⚠️ 警告:test 数据集应该有 {expected_count} 个用例,当前只有 {len(entries)} 个")
        else:
            print(f"✅ test 数据集已包含 {len(entries)} 个用例,符合提交要求")
def get_current_accuracy():
    """Return the overall accuracy percentage (validation split only).

    Returns:
        float | None: Accuracy in percent, or None when not running the
        validation split, the results file is missing/unreadable, or it
        contains no records.
    """
    # The test split has no reference answers, so accuracy is undefined.
    if SPLIT_TYPE != "validation":
        return None
    if not os.path.exists(OUTPUT_FILE):
        return None
    try:
        records = []
        with open(OUTPUT_FILE, "r", encoding="utf-8") as fh:
            for raw in fh:
                if not raw.strip():
                    continue
                try:
                    records.append(json.loads(raw))
                except json.JSONDecodeError:
                    continue
        if not records:
            return None
        total = len(records)
        correct = sum(1 for rec in records if rec.get("is_correct", False))
        return (correct / total * 100) if total > 0 else 0.0
    except Exception:
        return None
def generate_report():
    """Print the accuracy scoreboard (validation split only).

    The test split has no reference answers, so no report is produced for it.
    Fixes over the original: the results file is read with a context manager
    (the old comprehension leaked the handle), blank/malformed JSONL lines are
    skipped instead of raising, an empty file no longer causes
    ZeroDivisionError, and a missing is_correct column no longer raises
    KeyError.
    """
    if SPLIT_TYPE != "validation":
        return
    if not os.path.exists(OUTPUT_FILE):
        return
    results = []
    with open(OUTPUT_FILE, "r", encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue
            try:
                results.append(json.loads(line))
            except json.JSONDecodeError:
                continue
    total = len(results)
    if total == 0:
        # Nothing to report; avoid dividing by zero.
        return
    df = pd.DataFrame(results)
    # Records without an is_correct field count as incorrect.
    correct = df["is_correct"].fillna(False).sum() if "is_correct" in df.columns else 0
    acc = (correct / total) * 100
    print("\n" + "=" * 50)
    print(f"测试完成! 总数: {total} | 总准确率: {acc:.2f}%")
    print("=" * 50)
def run_test_concurrent(num_questions=200, mode="hybrid", split="validation", threads=MAX_CONCURRENT, target_task_id=None):
    """
    Main test driver.

    Args:
        num_questions: Number of cases to run.
        mode: Execution mode.
            - "full": delete any existing results file and start over.
            - "incremental": resume from the file if it exists, else run everything.
            - "hybrid": full on the first run (no file), incremental afterwards.
            - "error": re-run only records whose agent_answer contains ERROR.
        split: Dataset split, "validation" or "test".
        threads: Number of concurrent worker threads.
        target_task_id: Optional; when given, only that single case is run.
    """
    global dataset, OUTPUT_FILE, SUBMISSION_FILE, SPLIT_TYPE
    # Set the globals that the worker functions read.
    SPLIT_TYPE = split
    # Output file names depend on the split.
    if split == "validation":
        OUTPUT_FILE = "validation_results.jsonl"
        SUBMISSION_FILE = "validation_submission.jsonl"
        print("📥 正在检查 GAIA 验证集数据...")
    else:  # test
        OUTPUT_FILE = "test_results.jsonl"
        SUBMISSION_FILE = "test_submission.jsonl"
        print("📥 正在检查 GAIA 测试集数据...")
    # 1. Make sure the dataset files (both splits) are available locally.
    print("\n【步骤 1/2】检查数据集文件...")
    check_and_download_dataset_files()
    # 2. Load the dataset metadata (shows download progress on first use).
    print(f"\n【步骤 2/2】加载数据集元数据...")
    print(f" 数据集: gaia-benchmark/GAIA (2023_all, split={split})")
    print(" 提示: 如果是首次下载,请耐心等待,下载进度会显示在下方...")
    print(" 如果已下载过,会直接从缓存加载,速度较快")
    start_time = time.time()
    dataset = load_dataset("gaia-benchmark/GAIA", "2023_all", split=split)
    load_duration = time.time() - start_time
    print(f"✅ 数据集元数据加载完成!共 {len(dataset)} 条记录,耗时 {load_duration:.2f} 秒\n")
    # 1. Handle existing results according to the mode.
    file_exists = os.path.exists(OUTPUT_FILE)
    if mode == "full":
        # Full mode: delete the old file and start from scratch.
        if file_exists:
            os.remove(OUTPUT_FILE)
            print("🔄 全量模式:已删除旧结果文件,从头开始执行")
        existing_results = {}
    elif mode == "incremental":
        # Incremental mode: resume when the file exists, else run everything.
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 增量模式:已加载 {len(existing_results)} 条历史记录")
        else:
            existing_results = {}
            print("📋 增量模式:未找到历史记录,将全量执行")
    elif mode == "error":
        # Error mode: re-run only records whose agent_answer contains ERROR.
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 错误模式:已加载 {len(existing_results)} 条历史记录,将重新执行包含 ERROR 的记录")
        else:
            existing_results = {}
            print("📋 错误模式:未找到历史记录,无法执行错误重试")
    else:  # hybrid
        # Hybrid mode: full on the first run (no file), incremental afterwards.
        if file_exists:
            existing_results = load_existing_results()
            print(f"📋 混合模式:检测到已有文件,进入增量模式(已加载 {len(existing_results)} 条历史记录)")
        else:
            existing_results = {}
            print("📋 混合模式:首次执行,进入全量模式")
    # 2. Select which cases to run.
    if target_task_id:
        # A specific task_id was requested: run only that case.
        print(f"🎯 指定运行 task_id: {target_task_id}")
        tasks_to_run = []
        found = False
        for item in dataset:
            if item['task_id'] == target_task_id:
                tasks_to_run = [item]
                found = True
                break
        if not found:
            print(f"❌ 错误: 在 {split} 数据集中未找到 task_id: {target_task_id}")
            return
        num_to_run = 1
    else:
        # Normal mode: take the first num_questions cases.
        num_to_run = min(num_questions, len(dataset))
        tasks_to_run = dataset.select(range(num_to_run))
    # Classify the selected cases before running.
    tasks_to_execute = []
    refresh_count = 0  # hybrid mode: records that only get their field order refreshed
    update_count = 0  # records that need a fresh agent call
    new_count = 0  # records seen for the first time
    error_count = 0  # error mode: records whose answer contains ERROR
    for item in tasks_to_run:
        task_id = item['task_id']
        if mode == "error":
            # Error mode: only re-run records whose agent_answer contains ERROR.
            if task_id in existing_results:
                agent_answer = existing_results[task_id].get("agent_answer", "")
                if agent_answer and "ERROR" in str(agent_answer):
                    error_count += 1
                    tasks_to_execute.append(item)
            # Records that are absent or have no ERROR are skipped.
        else:
            # All other modes: every selected case is executed.
            if task_id in existing_results:
                # hybrid + validation: already-correct records are only refreshed.
                # The test split has no is_correct field, so it always re-runs.
                if mode == "hybrid" and split == "validation":
                    if existing_results[task_id].get("is_correct", False):
                        refresh_count += 1
                    else:
                        update_count += 1
                else:
                    # Non-hybrid mode, or test split: all existing records re-run.
                    update_count += 1
            else:
                new_count += 1
            tasks_to_execute.append(item)
    total_to_execute = len(tasks_to_execute)
    print(f"\n📊 统计信息:")
    print(f" 数据集: {split}")
    print(f" 执行模式: {mode}")
    if target_task_id:
        print(f" 指定 task_id: {target_task_id}")
    print(f" 总用例数: {num_to_run}")
    if existing_results:
        if mode == "error":
            print(f" 需要执行: {total_to_execute} (包含 ERROR 的记录: {error_count})")
        elif mode == "hybrid":
            print(
                f" 需要执行: {total_to_execute} (新用例: {new_count}, 刷新字段顺序: {refresh_count}, 重新测试: {update_count})")
        else:
            print(
                f" 需要执行: {total_to_execute} (新用例: {new_count}, 刷新已有记录: {refresh_count + update_count})")
    else:
        print(f" 需要执行: {total_to_execute} (全量执行)")
    print(f"🚀 开始测试 | 并发数: {threads} | 待执行: {total_to_execute}")
    if total_to_execute == 0:
        if mode == "error":
            print("✅ 没有包含 ERROR 的记录,无需执行")
        elif split == "validation":
            print("✅ 所有用例已完成且正确,无需执行")
        else:
            print("✅ 所有用例已完成,无需执行")
        generate_report()
        generate_submission()
        return
    # 3. Run the cases concurrently.
    with ThreadPoolExecutor(max_workers=threads) as executor:
        future_to_item = {executor.submit(process_item, item, existing_results, mode): item for item in
                          tasks_to_execute}
        done = 0
        for future in as_completed(future_to_item):
            done += 1
            item = future_to_item[future]
            tid = item['task_id']
            try:
                _, is_ok, status = future.result()
                if status == "refreshed":
                    status_icon = "🔄"
                elif split == "validation":
                    status_icon = "✅" if is_ok else "❌"
                else:  # test
                    status_icon = "✅"
                # Compute and show the running overall accuracy.
                accuracy_info = ""
                if split == "validation":
                    current_accuracy = get_current_accuracy()
                    if current_accuracy is not None:
                        accuracy_info = f" | 当前正确率: {current_accuracy:.2f}%"
                print(f"[{done}/{total_to_execute}] ID: {tid} | 状态: {status_icon} ({status}){accuracy_info}")
            except Exception as e:
                error_traceback = traceback.format_exc()
                print(f"[{done}/{total_to_execute}] ID: {tid} 运行异常: {e}")
                print(f"异常堆栈:\n{error_traceback}")
    # 4. Generate the report and the submission file.
    generate_report()
    generate_submission()
def print_help():
    """Print the detailed, hand-written CLI usage text (used for -h/--help)."""
    print("=" * 70)
    print("GAIA 测试脚本 - 参数说明")
    print("=" * 70)
    print()
    print("用法:")
    print(" python gaia_test.py [参数]")
    print()
    print("参数说明:")
    print()
    print(" --split <类型>")
    print(" 数据集类型")
    print(" 可选值: validation, test")
    print(" 默认值: validation")
    print(" 说明:")
    print(" - validation: 验证集,有标准答案,可以计算正确率")
    print(" - test: 测试集,无标准答案,用于最终提交")
    print()
    print(" --mode <模式>")
    print(" 执行模式")
    print(" 可选值: full, incremental, hybrid, error")
    print(" 默认值: hybrid")
    print(" 说明:")
    print(" - full: 全量模式,删除旧结果文件,从头开始执行")
    print(" - incremental: 增量模式,如果文件存在则增量,不存在则全量执行")
    print(" - hybrid: 混合模式(推荐),首次全量,后续增量")
    print(" 在 hybrid 模式下,validation 数据集中已正确的记录")
    print(" 只刷新字段顺序,不重新调用 agent")
    print(" - error: 错误模式,只重新执行 agent_answer 包含 ERROR 的记录")
    print()
    print(" --num <数量>")
    print(" 要执行的用例数量")
    print(" 类型: 整数")
    print(" 默认值: 200")
    print(" 说明:")
    print(" - test 数据集共 285 题,可以设置 --num 285 执行全部")
    print(" - validation 数据集可以根据需要设置数量")
    print()
    print(" --threads <数量>")
    print(" 并发执行的线程数")
    print(" 类型: 整数")
    print(" 默认值: 2")
    print(" 说明:")
    print(" - 根据服务器性能调整,过高可能导致服务器压力过大")
    print(" - 建议范围: 1-4")
    print()
    print(" --task-id <task_id>")
    print(" 指定要运行的 task_id")
    print(" 类型: 字符串")
    print(" 默认值: 无(运行多个用例)")
    print(" 说明:")
    print(" - 如果指定此参数,则只运行该 task_id 对应的用例")
    print(" - 指定此参数时,--num 参数会被忽略")
    print(" - 如果指定的 task_id 不存在,脚本会报错并退出")
    print()
    print(" -h, --help")
    print(" 显示此帮助信息并退出")
    print()
    print("示例:")
    print(" # 使用默认参数(validation 数据集,hybrid 模式,200 题)")
    print(" python gaia_test.py")
    print()
    print(" # 测试 test 数据集,执行全部 285 题")
    print(" python gaia_test.py --split test --num 285")
    print()
    print(" # 使用 error 模式重新执行错误记录")
    print(" python gaia_test.py --mode error")
    print()
    print(" # 使用全量模式,4 个并发线程")
    print(" python gaia_test.py --mode full --threads 4")
    print()
    print(" # 运行指定的 task_id")
    print(" python gaia_test.py --task-id c61d22de-5f6c-4958-a7f6-5e9707bd3466")
    print()
    print("=" * 70)
    print("配置文件:")
    print(" 运行前请确保已正确配置 gaia_test.py 中的以下参数:")
    print(" - BASE_URL: API 服务地址")
    print(" - HEADERS: 认证 Token(必须修改)")
    print(" - DATA_PATH: 数据文件路径")
    print(" - REQUEST_TIMEOUT: 请求超时时间")
    print(" - MAX_CONCURRENT: 最大并发数")
    print("=" * 70)
if __name__ == "__main__":
    import sys
    # Intercept -h/--help to show the detailed hand-written help text instead
    # of argparse's auto-generated one.
    if "-h" in sys.argv or "--help" in sys.argv:
        print_help()
        sys.exit(0)
    parser = argparse.ArgumentParser(
        description="GAIA 测试脚本(支持 validation 和 test 数据集,支持全量、增量、混合三种模式)",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["validation", "test"],
        default="validation",
        help="数据集类型: 'validation' 验证集(有标准答案)、'test' 测试集(无标准答案,默认: validation)"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["full", "incremental", "hybrid", "error"],
        default="hybrid",
        help="执行模式: 'full' 全量模式(删除旧文件重新执行)、'incremental' 增量模式(文件存在则增量,不存在则全量)、'hybrid' 混合模式(首次全量,后续增量,默认)、'error' 错误模式(只重新执行 agent_answer 包含 ERROR 的记录)"
    )
    parser.add_argument(
        "--num",
        type=int,
        default=200,
        help="要执行的用例数量(默认: 200,test 集共 285 题)"
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=MAX_CONCURRENT,
        help="执行的并发数(默认: 2)"
    )
    parser.add_argument(
        "--task-id",
        type=str,
        default=None,
        help="指定要运行的 task_id,如果指定则只运行该用例(忽略 --num 参数)"
    )
    args = parser.parse_args()
    # Kick off the concurrent test run with the parsed CLI options.
    run_test_concurrent(num_questions=args.num, mode=args.mode, split=args.split, threads=args.threads, target_task_id=args.task_id)