""" 问题队列类定义 管理问题列表,提供各种排序和调度方法 """ from typing import List, Dict, Any, Optional, Callable, Iterator from enum import Enum from question import Question, SamplingRecord import json import os class QueueStrategy(Enum): """队列策略枚举""" FCFS = "fcfs" # First Come First Serve SJF = "sjf" # Shortest Job First ROUND_ROBIN = "rr" # Round Robin PRIORITY = "priority" # Priority based CUSTOM = "custom" # Custom sorting class QuestionQueue: """问题队列类""" def __init__(self, questions: Optional[List[Question]] = None): """ 初始化问题队列 Args: questions: 问题列表 """ self.questions: List[Question] = questions or [] self.current_index: int = 0 self.strategy: QueueStrategy = QueueStrategy.FCFS self.custom_sort_key: Optional[Callable[[Question], Any]] = None # 队列统计信息 self.total_processed: int = 0 self.total_tokens_used: int = 0 self.total_correct: int = 0 def add_question(self, question: Question) -> None: """添加问题到队列""" self.questions.append(question) def add_questions(self, questions: List[Question]) -> None: """批量添加问题""" self.questions.extend(questions) def remove_question(self, question_id: int) -> bool: """根据ID移除问题""" for i, question in enumerate(self.questions): if question.question_id == question_id: self.questions.pop(i) return True return False def get_question(self, question_id: int) -> Optional[Question]: """根据ID获取问题""" for question in self.questions: if question.question_id == question_id: return question return None def set_strategy(self, strategy: QueueStrategy, custom_sort_key: Optional[Callable[[Question], Any]] = None) -> None: """ 设置队列策略 Args: strategy: 队列策略 custom_sort_key: 自定义排序键函数 """ self.strategy = strategy self.custom_sort_key = custom_sort_key self.current_index = 0 def sort_questions(self) -> None: """根据当前策略排序问题""" if self.strategy == QueueStrategy.FCFS: # FCFS: 保持原始顺序 pass elif self.strategy == QueueStrategy.SJF: # SJF: 按平均token数量排序 self.questions.sort(key=lambda q: q.avg_tokens) elif self.strategy == QueueStrategy.ROUND_ROBIN: # Round Robin: 保持原始顺序 pass elif self.strategy == QueueStrategy.PRIORITY: # Priority: 按准确率排序(准确率低的优先) self.questions.sort(key=lambda q: q.accuracy) elif self.strategy == QueueStrategy.CUSTOM and self.custom_sort_key: # Custom: 使用自定义排序键 self.questions.sort(key=self.custom_sort_key) self.current_index = 0 def get_next_question(self) -> Optional[Question]: """获取下一个要处理的问题""" if not self.questions: return None if self.strategy == QueueStrategy.ROUND_ROBIN: # Round Robin: 循环访问 question = self.questions[self.current_index] self.current_index = (self.current_index + 1) % len(self.questions) return question else: # 其他策略: 按排序顺序访问 if self.current_index < len(self.questions): question = self.questions[self.current_index] self.current_index += 1 return question return None def reset_iterator(self) -> None: """重置迭代器位置""" self.current_index = 0 def get_questions_iterator(self) -> Iterator[Question]: """获取问题迭代器""" if self.strategy == QueueStrategy.ROUND_ROBIN: # Round Robin: 无限循环 while True: for question in self.questions: yield question else: # 其他策略: 单次遍历 for question in self.questions: yield question def add_sampling_record(self, question_id: int, record: SamplingRecord) -> bool: """ 为指定问题添加采样记录 Args: question_id: 问题ID record: 采样记录 Returns: 是否成功添加 """ question = self.get_question(question_id) if question: question.add_sampling_record(record) return True return False def get_queue_stats(self) -> Dict[str, Any]: """获取队列统计信息""" if not self.questions: return { 'total_questions': 0, 'total_processed': 0, 'total_tokens_used': 0, 'total_correct': 0, 'overall_accuracy': 0.0, 'avg_tokens_per_question': 0.0 } total_questions = len(self.questions) total_processed = sum(q.total_runs for q in self.questions) total_tokens_used = sum(q.total_tokens for q in self.questions) total_correct = sum(q.correct_count for q in self.questions) overall_accuracy = total_correct / total_questions if total_questions > 0 else 0.0 avg_tokens_per_question = total_tokens_used / total_questions if total_questions > 0 else 0.0 return { 'total_questions': total_questions, 'total_processed': total_processed, 'total_tokens_used': total_tokens_used, 'total_correct': total_correct, 'overall_accuracy': overall_accuracy, 'avg_tokens_per_question': avg_tokens_per_question, 'strategy': self.strategy.value, 'current_index': self.current_index } def get_questions_by_criteria(self, min_accuracy: Optional[float] = None, max_accuracy: Optional[float] = None, min_tokens: Optional[float] = None, max_tokens: Optional[float] = None, min_runs: Optional[int] = None, max_runs: Optional[int] = None) -> List[Question]: """ 根据条件筛选问题 Args: min_accuracy: 最小准确率 max_accuracy: 最大准确率 min_tokens: 最小平均token数 max_tokens: 最大平均token数 min_runs: 最小运行次数 max_runs: 最大运行次数 Returns: 符合条件的问题列表 """ filtered_questions = [] for question in self.questions: # 检查准确率 if min_accuracy is not None and question.accuracy < min_accuracy: continue if max_accuracy is not None and question.accuracy > max_accuracy: continue # 检查平均token数 if min_tokens is not None and question.avg_tokens < min_tokens: continue if max_tokens is not None and question.avg_tokens > max_tokens: continue # 检查运行次数 if min_runs is not None and question.total_runs < min_runs: continue if max_runs is not None and question.total_runs > max_runs: continue filtered_questions.append(question) return filtered_questions def save_to_file(self, filepath: str) -> None: """保存队列到文件""" data = { 'questions': [q.to_dict() for q in self.questions], 'strategy': self.strategy.value, 'current_index': self.current_index, 'stats': self.get_queue_stats() } os.makedirs(os.path.dirname(filepath), exist_ok=True) with open(filepath, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) @classmethod def load_from_file(cls, filepath: str) -> 'QuestionQueue': """从文件加载队列""" with open(filepath, 'r', encoding='utf-8') as f: data = json.load(f) questions = [Question.from_dict(q_data) for q_data in data['questions']] queue = cls(questions) queue.strategy = QueueStrategy(data['strategy']) queue.current_index = data.get('current_index', 0) return queue def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" return { 'questions': [q.to_dict() for q in self.questions], 'strategy': self.strategy.value, 'current_index': self.current_index, 'stats': self.get_queue_stats() } def __len__(self) -> int: """返回队列长度""" return len(self.questions) def __getitem__(self, index: int) -> Question: """索引访问""" return self.questions[index] def __iter__(self) -> Iterator[Question]: """迭代器""" return iter(self.questions) def copy(self) -> 'QuestionQueue': """创建队列的深拷贝""" copied_questions = [q.copy() for q in self.questions] copied_queue = QuestionQueue(copied_questions) copied_queue.strategy = self.strategy copied_queue.current_index = self.current_index copied_queue.custom_sort_key = self.custom_sort_key return copied_queue