| """ |
| 问题队列类定义 |
| 管理问题列表,提供各种排序和调度方法 |
| """ |
|
|
| from typing import List, Dict, Any, Optional, Callable, Iterator |
| from enum import Enum |
| from question import Question, SamplingRecord |
| import json |
| import os |
|
|
class QueueStrategy(Enum):
    """Scheduling strategies understood by ``QuestionQueue``.

    Each member's value is the short string that ``QuestionQueue`` writes
    under the ``'strategy'`` key when serializing, and parses back when
    loading from a file.
    """

    FCFS = "fcfs"  # first come, first served: insertion order is kept
    SJF = "sjf"  # shortest job first: ascending average token usage
    ROUND_ROBIN = "rr"  # cycle through the questions repeatedly
    PRIORITY = "priority"  # ascending accuracy (least accurate first)
    CUSTOM = "custom"  # ordered by a caller-supplied sort key
|
|
class QuestionQueue:
    """Ordered collection of questions with pluggable scheduling strategies.

    The queue keeps a cursor (``current_index``) that ``get_next_question``
    advances. How the cursor moves, and how ``sort_questions`` orders the
    list, depends on the active ``QueueStrategy``.
    """

    def __init__(self, questions: Optional[List[Question]] = None):
        """Initialize the question queue.

        Args:
            questions: Initial questions. The sequence is shallow-copied so
                that later add/remove/sort operations on the queue never
                mutate the caller's list.
        """
        # Always copy: the previous ``questions or []`` aliased non-empty
        # caller lists but not empty ones, which was inconsistent.
        self.questions: List[Question] = list(questions) if questions is not None else []
        self.current_index: int = 0
        self.strategy: QueueStrategy = QueueStrategy.FCFS
        self.custom_sort_key: Optional[Callable[[Question], Any]] = None

        # Running counters. NOTE(review): nothing in this class updates them;
        # get_queue_stats() recomputes totals from the questions instead.
        # They are kept for backward compatibility with external readers.
        self.total_processed: int = 0
        self.total_tokens_used: int = 0
        self.total_correct: int = 0

    def add_question(self, question: Question) -> None:
        """Append a single question to the end of the queue."""
        self.questions.append(question)

    def add_questions(self, questions: List[Question]) -> None:
        """Append several questions to the end of the queue, in order."""
        self.questions.extend(questions)

    def remove_question(self, question_id: int) -> bool:
        """Remove the first question whose ``question_id`` matches.

        The cursor is adjusted so that the question that would have been
        returned next by ``get_next_question`` is still returned next.

        Args:
            question_id: ID of the question to remove.

        Returns:
            True if a question was removed, False if no question matched.
        """
        for i, question in enumerate(self.questions):
            if question.question_id == question_id:
                del self.questions[i]
                # Removing an already-visited entry shifts every later entry
                # one slot left; follow it so nothing gets skipped.
                if i < self.current_index:
                    self.current_index -= 1
                return True
        return False

    def get_question(self, question_id: int) -> Optional[Question]:
        """Return the first question with the given ID, or None if absent."""
        for question in self.questions:
            if question.question_id == question_id:
                return question
        return None

    def set_strategy(self, strategy: QueueStrategy,
                     custom_sort_key: Optional[Callable[[Question], Any]] = None) -> None:
        """Set the scheduling strategy and reset the cursor to the start.

        Args:
            strategy: Strategy to use from now on.
            custom_sort_key: Key function used by ``sort_questions`` when
                ``strategy`` is ``QueueStrategy.CUSTOM``.
        """
        self.strategy = strategy
        self.custom_sort_key = custom_sort_key
        self.current_index = 0

    def sort_questions(self) -> None:
        """Reorder ``self.questions`` according to the current strategy.

        - FCFS / ROUND_ROBIN: insertion order is kept.
        - SJF: ascending ``avg_tokens`` (fewest average tokens first).
        - PRIORITY: ascending ``accuracy`` (least accurate first).
        - CUSTOM: ordered by ``custom_sort_key``; a no-op when no key is set.

        The cursor is always reset to the start afterwards.
        """
        if self.strategy == QueueStrategy.SJF:
            self.questions.sort(key=lambda q: q.avg_tokens)
        elif self.strategy == QueueStrategy.PRIORITY:
            self.questions.sort(key=lambda q: q.accuracy)
        elif self.strategy == QueueStrategy.CUSTOM and self.custom_sort_key:
            self.questions.sort(key=self.custom_sort_key)
        # FCFS and ROUND_ROBIN deliberately keep insertion order.
        self.current_index = 0

    def get_next_question(self) -> Optional[Question]:
        """Return the next question to process and advance the cursor.

        In ROUND_ROBIN mode the cursor wraps around, so a non-empty queue
        never runs out. In every other mode, None is returned once the end
        of the list has been reached.

        Returns:
            The next question, or None if the queue is empty or exhausted.
        """
        if not self.questions:
            return None

        if self.strategy == QueueStrategy.ROUND_ROBIN:
            # Wrap on read as well as on advance: the list may have shrunk
            # (e.g. via remove_question) since the cursor was last moved.
            count = len(self.questions)
            index = self.current_index % count
            self.current_index = (index + 1) % count
            return self.questions[index]

        if self.current_index < len(self.questions):
            question = self.questions[self.current_index]
            self.current_index += 1
            return question
        return None

    def reset_iterator(self) -> None:
        """Move the cursor back to the first question."""
        self.current_index = 0

    def get_questions_iterator(self) -> Iterator[Question]:
        """Iterate over the questions according to the current strategy.

        In ROUND_ROBIN mode the generator cycles through the list
        indefinitely. It stops as soon as the list is empty instead of
        spinning forever without yielding. Other modes yield each question
        once, in list order. The cursor (``current_index``) is not used.
        """
        if self.strategy == QueueStrategy.ROUND_ROBIN:
            # Re-checking emptiness each cycle prevents an infinite
            # non-yielding loop on an empty queue.
            while self.questions:
                yield from self.questions
        else:
            yield from self.questions

    def add_sampling_record(self, question_id: int, record: SamplingRecord) -> bool:
        """Attach a sampling record to the question with the given ID.

        Args:
            question_id: ID of the target question.
            record: Sampling record passed to ``Question.add_sampling_record``.

        Returns:
            True if the question was found and the record added, else False.
        """
        question = self.get_question(question_id)
        if question:
            question.add_sampling_record(record)
            return True
        return False

    def get_queue_stats(self) -> Dict[str, Any]:
        """Aggregate statistics across all questions in the queue.

        Returns:
            A dict with keys ``total_questions``, ``total_processed`` (sum of
            ``total_runs``), ``total_tokens_used`` (sum of ``total_tokens``),
            ``total_correct`` (sum of ``correct_count``), ``overall_accuracy``
            (correct runs / total runs), ``avg_tokens_per_question``,
            ``strategy`` (the strategy's string value) and ``current_index``.
            Ratios are 0.0 when their denominator is zero. The same keys are
            returned for an empty queue.
        """
        total_questions = len(self.questions)
        total_processed = sum(q.total_runs for q in self.questions)
        total_tokens_used = sum(q.total_tokens for q in self.questions)
        total_correct = sum(q.correct_count for q in self.questions)

        # Accuracy is correct runs over attempted runs. Dividing by the number
        # of distinct questions can exceed 1.0 once a question is sampled
        # more than once.
        overall_accuracy = total_correct / total_processed if total_processed > 0 else 0.0
        avg_tokens_per_question = total_tokens_used / total_questions if total_questions > 0 else 0.0

        return {
            'total_questions': total_questions,
            'total_processed': total_processed,
            'total_tokens_used': total_tokens_used,
            'total_correct': total_correct,
            'overall_accuracy': overall_accuracy,
            'avg_tokens_per_question': avg_tokens_per_question,
            'strategy': self.strategy.value,
            'current_index': self.current_index
        }

    def get_questions_by_criteria(self,
                                  min_accuracy: Optional[float] = None,
                                  max_accuracy: Optional[float] = None,
                                  min_tokens: Optional[float] = None,
                                  max_tokens: Optional[float] = None,
                                  min_runs: Optional[int] = None,
                                  max_runs: Optional[int] = None) -> List[Question]:
        """Return the questions whose metrics fall inside the given bounds.

        Every bound is inclusive. A bound left as None is not applied.

        Args:
            min_accuracy: Minimum ``accuracy``.
            max_accuracy: Maximum ``accuracy``.
            min_tokens: Minimum ``avg_tokens``.
            max_tokens: Maximum ``avg_tokens``.
            min_runs: Minimum ``total_runs``.
            max_runs: Maximum ``total_runs``.

        Returns:
            Matching questions, in queue order.
        """
        def within(value: Any, low: Optional[Any], high: Optional[Any]) -> bool:
            # Inclusive range check where a missing bound always passes.
            if low is not None and value < low:
                return False
            if high is not None and value > high:
                return False
            return True

        return [
            question for question in self.questions
            if within(question.accuracy, min_accuracy, max_accuracy)
            and within(question.avg_tokens, min_tokens, max_tokens)
            and within(question.total_runs, min_runs, max_runs)
        ]

    def save_to_file(self, filepath: str) -> None:
        """Serialize the queue (see ``to_dict``) to a UTF-8 JSON file.

        Parent directories are created when ``filepath`` has a directory
        component.
        """
        directory = os.path.dirname(filepath)
        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: str) -> 'QuestionQueue':
        """Build a queue from a JSON file written by ``save_to_file``.

        The stored ``stats`` are ignored, because they are derived from the
        questions. A missing ``strategy`` defaults to FCFS, and a missing
        ``current_index`` defaults to 0. A ``custom_sort_key`` cannot be
        serialized, so a loaded CUSTOM queue has none until
        ``set_strategy`` is called.

        Raises:
            KeyError: If the file has no ``questions`` entry.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        questions = [Question.from_dict(q_data) for q_data in data['questions']]
        queue = cls(questions)
        queue.strategy = QueueStrategy(data.get('strategy', QueueStrategy.FCFS.value))
        queue.current_index = data.get('current_index', 0)

        return queue

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of the queue and its stats."""
        return {
            'questions': [q.to_dict() for q in self.questions],
            'strategy': self.strategy.value,
            'current_index': self.current_index,
            'stats': self.get_queue_stats()
        }

    def __len__(self) -> int:
        """Return the number of questions in the queue."""
        return len(self.questions)

    def __getitem__(self, index: int) -> Question:
        """Return the question at ``index`` (list indexing semantics)."""
        return self.questions[index]

    def __iter__(self) -> Iterator[Question]:
        """Iterate once over the questions in list order, ignoring strategy."""
        return iter(self.questions)

    def copy(self) -> 'QuestionQueue':
        """Return an independent copy of the queue.

        Each question is duplicated via ``Question.copy()``. The strategy,
        cursor, sort key and counters are carried over.
        """
        copied_queue = QuestionQueue([q.copy() for q in self.questions])
        copied_queue.strategy = self.strategy
        copied_queue.current_index = self.current_index
        copied_queue.custom_sort_key = self.custom_sort_key
        copied_queue.total_processed = self.total_processed
        copied_queue.total_tokens_used = self.total_tokens_used
        copied_queue.total_correct = self.total_correct
        return copied_queue