| """ |
| 队列工具函数 |
| 用于数据转换和队列辅助功能 |
| """ |
|
|
| from typing import List, Dict, Any |
| from question import Question, SamplingRecord |
| from question_queue import QuestionQueue, QueueStrategy |
|
|
| |
| from math_verify import parse, verify |
|
|
def simple_verify(gold_answer: str, output_text: str) -> Dict[str, Any]:
    """Check a model output against the gold answer using math_verify.

    The gold answer is converted to ``str`` and wrapped in ``$...$`` before
    being handed to ``parse``; the model output is parsed as-is, and the two
    parsed values are compared with ``verify``. The original note states
    this is kept consistent with the existing ``infer.py``.

    Any exception raised while parsing or verifying is caught: the sample is
    scored as incorrect, the failure is printed, and the exception text is
    returned in the ``'error'`` field so callers can record it.

    Args:
        gold_answer: Reference answer (converted with ``str()``).
        output_text: Raw text produced by the model.

    Returns:
        A dict with the keys:
            ``extracted_predicted``: ``str()`` of the parsed answer, or
                ``None`` if parsing returned ``None`` or failed.
            ``thinking_part``: ``output_text`` unchanged.
            ``answer_part``: the extracted answer, or ``output_text`` when
                no (non-empty) answer was extracted.
            ``is_correct``: result of ``verify`` (``False`` on failure).
            ``score``: ``1.0`` if ``is_correct`` is truthy, else ``0.0``.
            ``error``: ``None`` on success, otherwise
                ``"<ExceptionType>: <message>"``.
    """
    error = None
    try:
        # NOTE(review): the ``$`` delimiters presumably make math_verify
        # treat the gold answer as LaTeX math -- confirm against its docs.
        gold = parse("$" + str(gold_answer) + "$")
        answer = parse(output_text)
        verify_result = verify(gold, answer)
        extracted_answer = str(answer) if answer is not None else None
    except Exception as e:
        # Best-effort boundary: a sample that cannot be parsed or verified
        # counts as wrong rather than aborting the whole run.
        verify_result = False
        extracted_answer = None
        error = f"{type(e).__name__}: {e}"
        print(f"解析失败: {e}")

    return {
        'extracted_predicted': extracted_answer,
        'thinking_part': output_text,
        'answer_part': extracted_answer or output_text,
        'is_correct': verify_result,
        'score': 1.0 if verify_result else 0.0,
        'error': error,
    }
|
|
def create_questions_from_raw_data(raw_questions: List[Dict[str, Any]]) -> List[Question]:
    """Build ``Question`` objects from raw question dictionaries.

    Each question's ``question_id`` is its position in ``raw_questions``.
    The ``'question'``, ``'answer'`` and ``'solution'`` keys are read with
    plain indexing, so they must be present; ``'file'`` falls back to ``''``
    and ``'raw_id'`` to ``None``.

    Args:
        raw_questions: Raw question dictionaries.

    Returns:
        One ``Question`` per input entry, in input order.
    """
    return [
        Question(
            question_id=position,
            question_text=entry['question'],
            gold_answer=entry['answer'],
            gold_solution=entry['solution'],
            file_path=entry.get('file', ''),
            raw_id=entry.get('raw_id', None),
        )
        for position, entry in enumerate(raw_questions)
    ]
|
|
def create_question_queue_from_raw_data(raw_questions: List[Dict[str, Any]],
                                        strategy: QueueStrategy = QueueStrategy.FCFS) -> QuestionQueue:
    """Build a ``QuestionQueue`` straight from raw question dictionaries.

    The raw entries are first converted with
    ``create_questions_from_raw_data``; the resulting queue then has
    ``strategy`` applied through ``set_strategy``.

    Args:
        raw_questions: Raw question dictionaries.
        strategy: Queue strategy to set on the new queue.

    Returns:
        The newly created question queue.
    """
    new_queue = QuestionQueue(create_questions_from_raw_data(raw_questions))
    new_queue.set_strategy(strategy)
    return new_queue
|
|
def add_sampling_result_to_queue(queue: QuestionQueue,
                                 question_id: int,
                                 round_num: int,
                                 output_text: str,
                                 token_count: int,
                                 total_tokens_used: int,
                                 budget: int,
                                 run_count: int = 1) -> bool:
    """Verify one sampled output and attach it to its question in the queue.

    The question is looked up with ``queue.get_question``. The output is
    checked with ``simple_verify`` against the question's gold answer, and
    the outcome is stored on the question as a new ``SamplingRecord``.

    Args:
        queue: Question queue holding the target question.
        question_id: ID of the question the sample belongs to.
        round_num: Sampling round number.
        output_text: Text produced by the model.
        token_count: Token count of this sample.
        total_tokens_used: Cumulative tokens used so far.
        budget: Total token budget.
        run_count: Run count stored on the record.

    Returns:
        ``True`` if the record was added, ``False`` if the lookup returned
        a falsy value (no such question).
    """
    target = queue.get_question(question_id)
    if not target:
        return False

    checked = simple_verify(target.gold_answer, output_text)

    record_fields = {
        'round_num': round_num,
        'token_count': token_count,
        'extracted_answer': checked['extracted_predicted'],
        'verify_result': checked['is_correct'],
        'thinking_part': checked['thinking_part'],
        'answer_part': checked['answer_part'],
        'verification_score': checked['score'],
        'verification_error': checked.get('error'),
        'cumulative_tokens': total_tokens_used,
        # Remaining budget is whatever is left after the cumulative usage.
        'budget_remaining': budget - total_tokens_used,
        'run_count': run_count,
    }
    target.add_sampling_record(SamplingRecord(**record_fields))
    return True
|
|
def convert_queue_to_legacy_format(queue: QuestionQueue, scheduler_type: str) -> Dict[str, Any]:
    """Flatten a queue into the dict layout used by the legacy answer_evaluator.

    Every sampling record of every question becomes one flat result dict
    combining question-level and record-level fields. ``'pred_answer'`` and
    ``'answer_part'`` both carry the record's ``answer_part``.

    Args:
        queue: Question queue to convert.
        scheduler_type: Scheduler type (``'fcfs'`` or ``'sjf'``), copied
            into the output unchanged.

    Returns:
        A dict with ``'scheduler_type'``, ``'results'`` (the flattened
        records, in question then record order) and ``'queue_stats'`` (from
        ``queue.get_queue_stats()``).
    """
    flattened = [
        {
            'round': rec.round_num,
            'question': item.question_text,
            'gt_answer': item.gold_answer,
            'gt_solution': item.gold_solution,
            'pred_answer': rec.answer_part,
            'extracted_answer': rec.extracted_answer,
            'token_count': rec.token_count,
            'verify': rec.verify_result,
            'file': item.file_path,
            'question_idx': item.question_id,
            'cumulative_tokens': rec.cumulative_tokens,
            'budget_remaining': rec.budget_remaining,
            'run_count': rec.run_count,
            'verification_score': rec.verification_score,
            'thinking_part': rec.thinking_part,
            'answer_part': rec.answer_part,
            'verification_error': rec.verification_error,
        }
        for item in queue.questions
        for rec in item.sampling_records
    ]

    legacy = {'scheduler_type': scheduler_type, 'results': flattened}
    legacy['queue_stats'] = queue.get_queue_stats()
    return legacy
|
|
def create_summary_from_queue(queue: QuestionQueue, scheduler_type: str,
                              budget: int, total_tokens_used: int) -> Dict[str, Any]:
    """Summarise a finished queue run as a plain dictionary.

    A question counts as processed here when its ``total_runs`` is greater
    than zero; all others are reported as skipped. ``'total_processed'``,
    ``'correct_count'`` and ``'accuracy'`` are taken from
    ``queue.get_queue_stats()`` rather than recomputed.

    Args:
        queue: Question queue to summarise.
        scheduler_type: Scheduler type; reported upper-cased with a
            ``_Circular`` suffix.
        budget: Token budget.
        total_tokens_used: Tokens actually used.

    Returns:
        Summary dictionary. ``'budget_utilization'`` is ``0`` when the
        budget is not positive, and ``'total_rounds'`` is ``0`` when no
        sampling record exists.
    """
    stats = queue.get_queue_stats()
    all_questions = queue.questions

    processed = len([q for q in all_questions if q.total_runs > 0])
    skipped = len(all_questions) - processed

    # Guard against division by zero for an empty or invalid budget.
    if budget > 0:
        utilization = total_tokens_used / budget
    else:
        utilization = 0

    round_numbers = [rec.round_num
                     for q in all_questions
                     for rec in q.sampling_records]
    last_round = max(round_numbers) if round_numbers else 0

    runs_per_question = {q.question_id: q.total_runs for q in all_questions}

    summary = {
        'scheduler_type': f'{scheduler_type.upper()}_Circular',
        'budget': budget,
        'total_tokens_used': total_tokens_used,
        'budget_utilization': utilization,
        'total_processed': stats['total_processed'],
        'total_skipped': skipped,
        'correct_count': stats['total_correct'],
        'accuracy': stats['overall_accuracy'],
        'total_rounds': last_round,
        'question_run_count': runs_per_question,
        'queue_stats': stats,
    }
    return summary
|
|
def save_queue_results(queue: QuestionQueue, scheduler_type: str,
                       out_dir: str, budget: int, total_tokens_used: int) -> None:
    """Write per-question results, a summary and the queue itself to disk.

    Files produced:
        * ``<out_dir>/<scheduler_type>_results/<scheduler_type>_problem_NNNN.jsonl``
          -- one JSON line per sampling record, written only for questions
          that have at least one record;
        * ``<out_dir>/<scheduler_type>_summary.json`` -- the output of
          ``create_summary_from_queue``;
        * ``<out_dir>/<scheduler_type>_queue.json`` -- written by
          ``queue.save_to_file``.

    Args:
        queue: Question queue to persist.
        scheduler_type: Scheduler type, used as the file-name prefix.
        out_dir: Output directory.
        budget: Token budget (forwarded to the summary).
        total_tokens_used: Tokens used (forwarded to the summary).
    """
    import json
    import os

    def _as_row(item, rec):
        # One flat JSON object per sampling record.
        return {
            'round': rec.round_num,
            'question': item.question_text,
            'gt_answer': item.gold_answer,
            'gt_solution': item.gold_solution,
            'pred_answer': rec.answer_part,
            'extracted_answer': rec.extracted_answer,
            'token_count': rec.token_count,
            'verify': rec.verify_result,
            'file': item.file_path,
            'question_idx': item.question_id,
            'cumulative_tokens': rec.cumulative_tokens,
            'budget_remaining': rec.budget_remaining,
            'run_count': rec.run_count,
            'verification_score': rec.verification_score,
            'thinking_part': rec.thinking_part,
            'answer_part': rec.answer_part,
            'verification_error': rec.verification_error,
        }

    # Per-question result files live in their own sub-directory.
    per_question_dir = os.path.join(out_dir, f'{scheduler_type}_results')
    os.makedirs(per_question_dir, exist_ok=True)

    for item in queue.questions:
        if not item.sampling_records:
            continue  # nothing sampled for this question -> no file
        jsonl_name = f'{scheduler_type}_problem_{item.question_id:04d}.jsonl'
        jsonl_path = os.path.join(per_question_dir, jsonl_name)
        with open(jsonl_path, 'w', encoding='utf-8') as handle:
            for rec in item.sampling_records:
                line = json.dumps(_as_row(item, rec), ensure_ascii=False)
                handle.write(line + '\n')

    # Overall run summary.
    summary = create_summary_from_queue(queue, scheduler_type, budget, total_tokens_used)
    summary_target = os.path.join(out_dir, f'{scheduler_type}_summary.json')
    with open(summary_target, 'w', encoding='utf-8') as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)

    # Full queue state, serialised by the queue itself.
    queue.save_to_file(os.path.join(out_dir, f'{scheduler_type}_queue.json'))
|
|
def create_clean_question_queue(questions: List[Question],
                                strategy: QueueStrategy = QueueStrategy.FCFS,
                                preserve_token_info: bool = False,
                                preserve_stage1_answers: bool = False) -> QuestionQueue:
    """Build a fresh queue whose questions carry (almost) no sampling history.

    Each input question is copied into a new ``Question`` with the same id,
    text, gold answer, gold solution and file path. By default no sampling
    records are carried over. Two flags relax that, and
    ``preserve_stage1_answers`` takes precedence whenever the question has
    records:

    * ``preserve_stage1_answers``: the first record with ``round_num == 0``
      is attached to the copy (the same record object, not a copy). If no
      such record exists nothing is attached.
    * ``preserve_token_info``: a placeholder record is attached that keeps
      only the first record's ``token_count`` (used for budget calculation);
      all answer and verification fields are blank.

    Args:
        questions: Original questions.
        strategy: Queue strategy to set on the new queue.
        preserve_token_info: Keep token information for budget calculation.
        preserve_stage1_answers: Keep the complete first-stage answer.

    Returns:
        The new question queue.
    """
    stripped = []
    for source in questions:
        # NOTE(review): raw_id is not passed on to the copy -- confirm that
        # dropping it here is intended.
        fresh = Question(
            question_id=source.question_id,
            question_text=source.question_text,
            gold_answer=source.gold_answer,
            gold_solution=source.gold_solution,
            file_path=source.file_path,
        )

        if preserve_stage1_answers and source.sampling_records:
            # Only the first round-0 record is kept.
            stage1 = next(
                (rec for rec in source.sampling_records if rec.round_num == 0),
                None,
            )
            if stage1 is not None:
                fresh.add_sampling_record(stage1)
        elif preserve_token_info and source.sampling_records:
            head = source.sampling_records[0]
            # Placeholder record: only the token figures are meaningful.
            fresh.add_sampling_record(SamplingRecord(
                round_num=0,
                token_count=head.token_count,
                extracted_answer="",
                verify_result=False,
                thinking_part="",
                answer_part="",
                verification_score=0.0,
                verification_error=None,
                cumulative_tokens=head.token_count,
                budget_remaining=0,
                run_count=0,
            ))

        stripped.append(fresh)

    clean_queue = QuestionQueue(stripped)
    clean_queue.set_strategy(strategy)
    return clean_queue