"""
Queue utility functions.

Helpers for converting raw data into Question / QuestionQueue objects,
recording sampling results on queued questions, and exporting queue contents
to the legacy result format and to disk.
"""
import json
import os
from typing import List, Dict, Any

# math-verify based answer checking
from math_verify import parse, verify

from question import Question, SamplingRecord
from question_queue import QuestionQueue, QueueStrategy


def simple_verify(gold_answer: str, output_text: str) -> Dict[str, Any]:
    """Verify a model output against the gold answer using math_verify.

    Kept consistent with the original infer.py: the gold answer is wrapped in
    ``$...$`` before parsing so math_verify treats it as LaTeX math.

    Args:
        gold_answer: Reference answer (converted with ``str`` before wrapping).
        output_text: Full model output text.

    Returns:
        Dict with keys ``extracted_predicted``, ``thinking_part``,
        ``answer_part``, ``is_correct``, ``score`` and ``error``. If parsing or
        verification raises, ``is_correct`` is False, ``extracted_predicted``
        is None and ``error`` holds a description of the exception.
    """
    error = None
    try:
        gold = parse("$" + str(gold_answer) + "$")
        answer = parse(output_text)
        verify_result = verify(gold, answer)
        # Extracted answer, as produced by math-verify's parser.
        extracted_answer = str(answer) if answer is not None else None
    except Exception as e:  # best-effort: any parse/verify failure counts as wrong
        verify_result = False
        extracted_answer = None
        # Previously the failure reason was only printed and 'error' stayed
        # None; keep it so it reaches SamplingRecord.verification_error.
        error = f"{type(e).__name__}: {e}"
        print(f"解析失败: {e}")

    return {
        'extracted_predicted': extracted_answer,
        'thinking_part': output_text,  # the whole output is treated as reasoning
        'answer_part': extracted_answer or output_text,  # extracted answer, else full output
        'is_correct': verify_result,
        'score': 1.0 if verify_result else 0.0,
        'error': error,
    }


def create_questions_from_raw_data(raw_questions: List[Dict[str, Any]]) -> List[Question]:
    """Build Question objects from raw question dicts.

    The position in ``raw_questions`` becomes the ``question_id``.

    Args:
        raw_questions: Raw records; each must contain ``question``, ``answer``
            and ``solution``; ``file`` and ``raw_id`` are optional.

    Returns:
        List of Question objects in input order.

    Raises:
        KeyError: If a record lacks ``question``, ``answer`` or ``solution``.
    """
    return [
        Question(
            question_id=idx,
            question_text=raw_q['question'],
            gold_answer=raw_q['answer'],
            gold_solution=raw_q['solution'],
            file_path=raw_q.get('file', ''),
            raw_id=raw_q.get('raw_id', None),
        )
        for idx, raw_q in enumerate(raw_questions)
    ]


def create_question_queue_from_raw_data(raw_questions: List[Dict[str, Any]],
                                        strategy: QueueStrategy = QueueStrategy.FCFS) -> QuestionQueue:
    """Build a QuestionQueue from raw question dicts.

    Args:
        raw_questions: Raw records (see ``create_questions_from_raw_data``).
        strategy: Queue strategy to apply to the new queue.

    Returns:
        The new question queue.
    """
    questions = create_questions_from_raw_data(raw_questions)
    queue = QuestionQueue(questions)
    queue.set_strategy(strategy)
    return queue


def add_sampling_result_to_queue(queue: QuestionQueue, question_id: int,
                                 round_num: int, output_text: str,
                                 token_count: int, total_tokens_used: int,
                                 budget: int, run_count: int = 1) -> bool:
    """Verify one model output and attach it as a SamplingRecord to a question.

    Args:
        queue: Question queue holding the target question.
        question_id: ID of the question to update.
        round_num: Sampling round number.
        output_text: Model output text.
        token_count: Tokens used by this output.
        total_tokens_used: Cumulative tokens used so far.
        budget: Total token budget.
        run_count: Run count to store on the record.

    Returns:
        True if the record was added, False if the question was not found.
    """
    question = queue.get_question(question_id)
    # Use an explicit None check: a valid Question must not be rejected just
    # because it happens to evaluate falsy.
    if question is None:
        return False

    verification = simple_verify(question.gold_answer, output_text)

    record = SamplingRecord(
        round_num=round_num,
        token_count=token_count,
        extracted_answer=verification['extracted_predicted'],
        verify_result=verification['is_correct'],
        thinking_part=verification['thinking_part'],
        answer_part=verification['answer_part'],
        verification_score=verification['score'],
        verification_error=verification.get('error'),
        cumulative_tokens=total_tokens_used,
        budget_remaining=budget - total_tokens_used,
        run_count=run_count,
    )

    question.add_sampling_record(record)
    return True


def _record_to_result_dict(question: Question, record: SamplingRecord) -> Dict[str, Any]:
    """Flatten one (question, sampling record) pair into the legacy result dict.

    Shared by ``convert_queue_to_legacy_format`` and ``save_queue_results`` so
    the in-memory and on-disk formats cannot drift apart.
    """
    return {
        'round': record.round_num,
        'question': question.question_text,
        'gt_answer': question.gold_answer,
        'gt_solution': question.gold_solution,
        'pred_answer': record.answer_part,  # uses the answer part
        'extracted_answer': record.extracted_answer,
        'token_count': record.token_count,
        'verify': record.verify_result,
        'file': question.file_path,
        'question_idx': question.question_id,
        'cumulative_tokens': record.cumulative_tokens,
        'budget_remaining': record.budget_remaining,
        'run_count': record.run_count,
        # Extended verification info
        'verification_score': record.verification_score,
        'thinking_part': record.thinking_part,
        'answer_part': record.answer_part,
        'verification_error': record.verification_error,
    }


def convert_queue_to_legacy_format(queue: QuestionQueue, scheduler_type: str) -> Dict[str, Any]:
    """Convert a queue into the format expected by the original answer_evaluator.

    Args:
        queue: Question queue.
        scheduler_type: Scheduler type (e.g. ``'fcfs'`` or ``'sjf'``).

    Returns:
        Dict with ``scheduler_type``, ``results`` (one entry per sampling
        record, in queue/record order) and ``queue_stats``.
    """
    results = [
        _record_to_result_dict(question, record)
        for question in queue.questions
        for record in question.sampling_records
    ]

    return {
        'scheduler_type': scheduler_type,
        'results': results,
        'queue_stats': queue.get_queue_stats(),
    }


def create_summary_from_queue(queue: QuestionQueue, scheduler_type: str,
                              budget: int, total_tokens_used: int) -> Dict[str, Any]:
    """Build a summary dict for a finished queue run.

    Args:
        queue: Question queue.
        scheduler_type: Scheduler type name.
        budget: Token budget.
        total_tokens_used: Tokens actually used.

    Returns:
        Summary dict; ``budget_utilization`` is 0 when ``budget`` is not positive.
    """
    stats = queue.get_queue_stats()

    # A question counts as processed once it has at least one run.
    processed_questions = sum(1 for q in queue.questions if q.total_runs > 0)
    skipped_questions = len(queue.questions) - processed_questions

    return {
        'scheduler_type': f'{scheduler_type.upper()}_Circular',
        'budget': budget,
        'total_tokens_used': total_tokens_used,
        'budget_utilization': total_tokens_used / budget if budget > 0 else 0,
        'total_processed': stats['total_processed'],
        'total_skipped': skipped_questions,
        'correct_count': stats['total_correct'],
        'accuracy': stats['overall_accuracy'],
        'total_rounds': max((record.round_num for q in queue.questions
                             for record in q.sampling_records), default=0),
        'question_run_count': {q.question_id: q.total_runs for q in queue.questions},
        'queue_stats': stats,
    }


def save_queue_results(queue: QuestionQueue, scheduler_type: str, out_dir: str,
                       budget: int, total_tokens_used: int) -> None:
    """Write per-question JSONL results, a JSON summary and the full queue to disk.

    Files written:
        ``<out_dir>/<scheduler_type>_results/<scheduler_type>_problem_NNNN.jsonl``
            one line per sampling record, only for questions with records;
        ``<out_dir>/<scheduler_type>_summary.json``;
        ``<out_dir>/<scheduler_type>_queue.json`` (via ``queue.save_to_file``).

    Args:
        queue: Question queue.
        scheduler_type: Scheduler type name (used in paths).
        out_dir: Output directory.
        budget: Token budget.
        total_tokens_used: Tokens actually used.
    """
    results_dir = os.path.join(out_dir, f'{scheduler_type}_results')
    os.makedirs(results_dir, exist_ok=True)

    # Detailed per-question results
    for question in queue.questions:
        if not question.sampling_records:
            continue
        filename = f'{scheduler_type}_problem_{question.question_id:04d}.jsonl'
        filepath = os.path.join(results_dir, filename)
        with open(filepath, 'w', encoding='utf-8') as f:
            for record in question.sampling_records:
                result = _record_to_result_dict(question, record)
                f.write(json.dumps(result, ensure_ascii=False) + '\n')

    # Summary
    summary = create_summary_from_queue(queue, scheduler_type, budget, total_tokens_used)
    summary_path = os.path.join(out_dir, f'{scheduler_type}_summary.json')
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    # Full queue state
    queue_path = os.path.join(out_dir, f'{scheduler_type}_queue.json')
    queue.save_to_file(queue_path)


def create_clean_question_queue(questions: List[Question],
                                strategy: QueueStrategy = QueueStrategy.FCFS,
                                preserve_token_info: bool = False,
                                preserve_stage1_answers: bool = False) -> QuestionQueue:
    """Build a fresh queue from existing questions, without their sampling records.

    Args:
        questions: Source questions.
        strategy: Queue strategy for the new queue.
        preserve_token_info: If True, keep a placeholder round-0 record that
            carries only the first record's token count (used for budgeting).
        preserve_stage1_answers: If True, copy the first record with
            ``round_num == 0`` (the stage-1 answer) as-is. Takes precedence
            over ``preserve_token_info``.

    Returns:
        The new question queue.
    """
    clean_questions = []
    for q in questions:
        clean_question = Question(
            question_id=q.question_id,
            question_text=q.question_text,
            gold_answer=q.gold_answer,
            gold_solution=q.gold_solution,
            file_path=q.file_path,
            # Keep the original dataset id (set by create_questions_from_raw_data);
            # getattr guards against Question objects that never stored it.
            raw_id=getattr(q, 'raw_id', None),
        )

        if preserve_stage1_answers and q.sampling_records:
            # Copy the first stage-1 (round 0) record verbatim, if any.
            for record in q.sampling_records:
                if record.round_num == 0:
                    clean_question.add_sampling_record(record)
                    break
        elif preserve_token_info and q.sampling_records:
            # Keep only token usage from the first record; answer fields are blanked.
            first_record = q.sampling_records[0]
            token_record = SamplingRecord(
                round_num=0,  # round 0 marks this as a token-info placeholder
                token_count=first_record.token_count,
                extracted_answer="",
                verify_result=False,
                thinking_part="",
                answer_part="",
                verification_score=0.0,
                verification_error=None,
                cumulative_tokens=first_record.token_count,
                budget_remaining=0,
                run_count=0,
            )
            clean_question.add_sampling_record(token_record)

        clean_questions.append(clean_question)

    queue = QuestionQueue(clean_questions)
    queue.set_strategy(strategy)
    return queue