tts_schedule_730 / queue_utils.py
unfair11212's picture
Upload folder using huggingface_hub
815d1e1 verified
"""
队列工具函数
用于数据转换和队列辅助功能
"""
from typing import List, Dict, Any
from question import Question, SamplingRecord
from question_queue import QuestionQueue, QueueStrategy
# math-verify 校验
from math_verify import parse, verify
def simple_verify(gold_answer: str, output_text: str) -> Dict[str, Any]:
    """
    Verify a model output against a gold answer using math_verify.

    Kept consistent with the original infer.py: the gold answer is wrapped in
    ``$...$`` before parsing, and the whole model output is parsed as the
    prediction.

    Args:
        gold_answer: Reference answer; converted with ``str()`` before wrapping.
        output_text: Raw model output text.

    Returns:
        A dict with keys:
            'extracted_predicted': string form of the parsed prediction, or
                None when parsing raised or extracted nothing.
            'thinking_part': the whole output text.
            'answer_part': the extracted prediction, falling back to the whole
                output text when nothing was extracted.
            'is_correct': result of ``verify`` (False if anything raised).
            'score': 1.0 if correct else 0.0.
            'error': None on success, otherwise "<ExceptionType>: <message>".
    """
    error_message = None
    try:
        gold = parse("$" + str(gold_answer) + "$")
        answer = parse(output_text)
        verify_result = verify(gold, answer)
        # NOTE(review): per the math-verify docs, parse() returns a list of
        # candidates that is empty when nothing could be extracted. Treat an
        # empty/None result as "no extraction" instead of the string "[]", so
        # that answer_part below falls back to the raw output.
        extracted_answer = str(answer) if answer else None
    except Exception as e:
        # Best-effort: any parse/verify failure counts as an incorrect answer,
        # but keep the reason so callers can store it as verification_error.
        verify_result = False
        extracted_answer = None
        error_message = f"{type(e).__name__}: {e}"
        print(f"解析失败: {e}")
    return {
        'extracted_predicted': extracted_answer,
        'thinking_part': output_text,  # the whole output is the "thinking" part
        'answer_part': extracted_answer or output_text,  # extracted answer, else whole output
        'is_correct': verify_result,
        'score': 1.0 if verify_result else 0.0,
        'error': error_message,
    }
def create_questions_from_raw_data(raw_questions: List[Dict[str, Any]]) -> List[Question]:
    """
    Build Question objects from raw question dicts.

    Each dict must provide 'question', 'answer' and 'solution' (a KeyError is
    raised otherwise); 'file' defaults to '' and 'raw_id' to None. Each
    question's id is its position in ``raw_questions``.

    Args:
        raw_questions: Raw question records.

    Returns:
        The Question objects, in input order.
    """
    return [
        Question(
            question_id=position,
            question_text=entry['question'],
            gold_answer=entry['answer'],
            gold_solution=entry['solution'],
            file_path=entry.get('file', ''),
            raw_id=entry.get('raw_id'),
        )
        for position, entry in enumerate(raw_questions)
    ]
def create_question_queue_from_raw_data(raw_questions: List[Dict[str, Any]],
                                        strategy: QueueStrategy = QueueStrategy.FCFS) -> QuestionQueue:
    """
    Build a QuestionQueue directly from raw question dicts.

    Args:
        raw_questions: Raw question records (see create_questions_from_raw_data).
        strategy: Scheduling strategy applied via ``set_strategy``.

    Returns:
        The newly created queue.
    """
    question_queue = QuestionQueue(create_questions_from_raw_data(raw_questions))
    question_queue.set_strategy(strategy)
    return question_queue
def add_sampling_result_to_queue(queue: QuestionQueue,
                                 question_id: int,
                                 round_num: int,
                                 output_text: str,
                                 token_count: int,
                                 total_tokens_used: int,
                                 budget: int,
                                 run_count: int = 1) -> bool:
    """
    Verify one model output and append it as a SamplingRecord to a question.

    Args:
        queue: Queue holding the target question.
        question_id: Id passed to ``queue.get_question``.
        round_num: Sampling round number.
        output_text: Raw model output text.
        token_count: Tokens used by this output.
        total_tokens_used: Cumulative tokens used so far.
        budget: Total token budget; budget_remaining = budget - total_tokens_used.
        run_count: Run count stored on the record.

    Returns:
        False if ``get_question`` returned a falsy value, otherwise True.
    """
    target = queue.get_question(question_id)
    if not target:
        return False

    outcome = simple_verify(target.gold_answer, output_text)
    target.add_sampling_record(
        SamplingRecord(
            round_num=round_num,
            token_count=token_count,
            extracted_answer=outcome['extracted_predicted'],
            verify_result=outcome['is_correct'],
            thinking_part=outcome['thinking_part'],
            answer_part=outcome['answer_part'],
            verification_score=outcome['score'],
            verification_error=outcome.get('error'),
            cumulative_tokens=total_tokens_used,
            budget_remaining=budget - total_tokens_used,
            run_count=run_count,
        )
    )
    return True
def convert_queue_to_legacy_format(queue: QuestionQueue, scheduler_type: str) -> Dict[str, Any]:
    """
    Flatten a queue into the dict layout expected by the legacy answer_evaluator.

    Every sampling record of every question becomes one row in 'results'.

    Args:
        queue: Queue to convert.
        scheduler_type: Scheduler name stored verbatim (e.g. 'fcfs' or 'sjf').

    Returns:
        {'scheduler_type': ..., 'results': [row, ...], 'queue_stats': queue.get_queue_stats()}
    """
    flattened = [
        {
            'round': rec.round_num,
            'question': q.question_text,
            'gt_answer': q.gold_answer,
            'gt_solution': q.gold_solution,
            'pred_answer': rec.answer_part,  # the answer part doubles as the prediction
            'extracted_answer': rec.extracted_answer,
            'token_count': rec.token_count,
            'verify': rec.verify_result,
            'file': q.file_path,
            'question_idx': q.question_id,
            'cumulative_tokens': rec.cumulative_tokens,
            'budget_remaining': rec.budget_remaining,
            'run_count': rec.run_count,
            # extended verification details
            'verification_score': rec.verification_score,
            'thinking_part': rec.thinking_part,
            'answer_part': rec.answer_part,
            'verification_error': rec.verification_error,
        }
        for q in queue.questions
        for rec in q.sampling_records
    ]
    return {
        'scheduler_type': scheduler_type,
        'results': flattened,
        'queue_stats': queue.get_queue_stats(),
    }
def create_summary_from_queue(queue: QuestionQueue, scheduler_type: str,
                              budget: int, total_tokens_used: int) -> Dict[str, Any]:
    """
    Summarize a queue's run.

    'total_processed', 'correct_count' and 'accuracy' come from
    ``queue.get_queue_stats()``. 'total_skipped' counts questions whose
    ``total_runs`` is not > 0. 'total_rounds' is the largest round_num over all
    sampling records (0 if there are none).

    Args:
        queue: Queue to summarize.
        scheduler_type: Scheduler name; reported as '<UPPER>_Circular'.
        budget: Token budget (utilization is 0 when budget <= 0).
        total_tokens_used: Tokens consumed.

    Returns:
        Summary dict.
    """
    stats = queue.get_queue_stats()
    # Count questions that ran at least once (bools sum as ints).
    n_processed = sum(q.total_runs > 0 for q in queue.questions)

    if budget > 0:
        utilization = total_tokens_used / budget
    else:
        utilization = 0

    last_round = max(
        (rec.round_num for q in queue.questions for rec in q.sampling_records),
        default=0,
    )
    runs_by_question = {q.question_id: q.total_runs for q in queue.questions}

    return {
        'scheduler_type': f'{scheduler_type.upper()}_Circular',
        'budget': budget,
        'total_tokens_used': total_tokens_used,
        'budget_utilization': utilization,
        'total_processed': stats['total_processed'],
        'total_skipped': len(queue.questions) - n_processed,
        'correct_count': stats['total_correct'],
        'accuracy': stats['overall_accuracy'],
        'total_rounds': last_round,
        'question_run_count': runs_by_question,
        'queue_stats': stats,
    }
def save_queue_results(queue: QuestionQueue, scheduler_type: str,
                       out_dir: str, budget: int, total_tokens_used: int) -> None:
    """
    Save queue results to disk.

    Writes under ``out_dir``:
      * ``{scheduler_type}_results/{scheduler_type}_problem_{id:04d}.jsonl``:
        one JSON line per sampling record, only for questions with records;
      * ``{scheduler_type}_summary.json``: output of create_summary_from_queue;
      * ``{scheduler_type}_queue.json``: written by ``queue.save_to_file``.

    Args:
        queue: Queue whose results are saved.
        scheduler_type: Scheduler name used in directory/file names.
        out_dir: Output directory (created if missing).
        budget: Token budget, forwarded to the summary.
        total_tokens_used: Tokens consumed, forwarded to the summary.
    """
    import os
    import json

    results_dir = os.path.join(out_dir, f'{scheduler_type}_results')
    os.makedirs(results_dir, exist_ok=True)

    # One JSONL file per question that has at least one sampling record.
    for question in queue.questions:
        if not question.sampling_records:
            continue
        filename = f'{scheduler_type}_problem_{question.question_id:04d}.jsonl'
        filepath = os.path.join(results_dir, filename)
        # Serialize every row before opening the file so a serialization
        # error does not leave a truncated file behind.
        lines = [
            json.dumps({
                'round': record.round_num,
                'question': question.question_text,
                'gt_answer': question.gold_answer,
                'gt_solution': question.gold_solution,
                'pred_answer': record.answer_part,
                'extracted_answer': record.extracted_answer,
                'token_count': record.token_count,
                'verify': record.verify_result,
                'file': question.file_path,
                'question_idx': question.question_id,
                'cumulative_tokens': record.cumulative_tokens,
                'budget_remaining': record.budget_remaining,
                'run_count': record.run_count,
                'verification_score': record.verification_score,
                'thinking_part': record.thinking_part,
                'answer_part': record.answer_part,
                'verification_error': record.verification_error
            }, ensure_ascii=False) + '\n'
            for record in question.sampling_records
        ]
        with open(filepath, 'w', encoding='utf-8') as f:
            f.writelines(lines)

    summary = create_summary_from_queue(queue, scheduler_type, budget, total_tokens_used)
    summary_path = os.path.join(out_dir, f'{scheduler_type}_summary.json')
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    queue_path = os.path.join(out_dir, f'{scheduler_type}_queue.json')
    queue.save_to_file(queue_path)
def create_clean_question_queue(questions: List[Question],
                                strategy: QueueStrategy = QueueStrategy.FCFS,
                                preserve_token_info: bool = False,
                                preserve_stage1_answers: bool = False) -> QuestionQueue:
    """
    Create a queue of fresh Question objects without their sampling history.

    Identity fields (id, text, gold answer/solution, file path, raw_id) are
    copied; sampling records are dropped, except as selected by the flags.

    Args:
        questions: Source questions; the Question objects themselves are not
            mutated.
        strategy: Queue strategy applied via ``set_strategy``.
        preserve_token_info: If True, attach a placeholder round-0 record
            carrying only the token count of the question's first sampling
            record (used for budget calculation). Ignored when
            preserve_stage1_answers is True and the question has records.
        preserve_stage1_answers: If True, attach the question's first record
            with round_num == 0. The record object is shared with the source
            question, not copied. If the question has records but none with
            round 0, nothing is attached.

    Returns:
        The new queue.
    """
    clean_questions = []
    for q in questions:
        clean_question = Question(
            question_id=q.question_id,
            question_text=q.question_text,
            gold_answer=q.gold_answer,
            gold_solution=q.gold_solution,
            file_path=q.file_path,
            # Keep raw_id as create_questions_from_raw_data sets it; previously
            # the clean copy silently dropped it. getattr falls back to None
            # (the value used when raw data has no raw_id).
            raw_id=getattr(q, 'raw_id', None),
        )
        if preserve_stage1_answers and q.sampling_records:
            # Keep the complete stage-1 (round 0) record, if there is one.
            stage1_record = next(
                (record for record in q.sampling_records if record.round_num == 0),
                None,
            )
            if stage1_record is not None:
                clean_question.add_sampling_record(stage1_record)
        elif preserve_token_info and q.sampling_records:
            # Keep only the token count of the first record; answer fields
            # are blanked so no stage-1 answer leaks into the new run.
            first_record = q.sampling_records[0]
            token_record = SamplingRecord(
                round_num=0,  # round 0 marks this as token-info only
                token_count=first_record.token_count,
                extracted_answer="",
                verify_result=False,
                thinking_part="",
                answer_part="",
                verification_score=0.0,
                verification_error=None,
                cumulative_tokens=first_record.token_count,
                budget_remaining=0,
                run_count=0
            )
            clean_question.add_sampling_record(token_record)
        clean_questions.append(clean_question)
    queue = QuestionQueue(clean_questions)
    queue.set_strategy(strategy)
    return queue