""" 数据管理模块 - 负责加载数据集和管理审核记录 """ import os import json import uuid from datetime import datetime from typing import Dict, List, Optional, Any from dataclasses import dataclass, asdict from pathlib import Path # ============== 数据类定义 ============== @dataclass class LabelInfo: """图表标签信息""" Number: str Type: str Source: str Weblink: str Topic: str Describe: str Other: str = "" @dataclass class QAItem: """问答对""" id: str chart: str question: str answer: str @dataclass class ReviewRecord: """审核记录""" review_id: str chart_id: str qa_id: str source: str chart_type: str model: str original_question: str original_answer: str status: str # correct, incorrect, needs_modification, pending modified_question: str = "" modified_answer: str = "" issue_type: str = "" # question_ambiguous, answer_wrong, chart_unclear, other comment: str = "" reviewer: str = "default" review_time: str = "" def to_dict(self) -> Dict: return asdict(self) # ============== 数据管理器 ============== class DataManager: """数据管理器""" def __init__(self, dataset_path: str = "./dataset", reviews_path: str = "./reviews"): """ 初始化数据管理器 Args: dataset_path: 数据集根目录路径 reviews_path: 审核记录保存目录路径 """ self.dataset_path = Path(dataset_path) self.reviews_path = Path(reviews_path) self.web_path = self.dataset_path / "web" self.label_path = self.dataset_path / "label" self.qa_path = self.dataset_path / "question_answer" self.reviews_file = self.reviews_path / "reviews.json" # 确保目录存在 self.reviews_path.mkdir(parents=True, exist_ok=True) if not self.reviews_file.exists(): self._save_reviews([]) # ============== 数据集结构获取 ============== def get_dataset_structure(self) -> Dict[str, Any]: """ 获取数据集的目录结构 Returns: 包含 source -> chart_type -> models 的树形结构 """ structure = {"sources": {}} if not self.web_path.exists(): return structure # 遍历所有 source for source_name in self._list_dirs(self.web_path): source_web_path = self.web_path / source_name source_qa_path = self.qa_path / source_name structure["sources"][source_name] = {"chart_types": {}} # 遍历所有 chart_type for chart_type_name in self._list_dirs(source_web_path): chart_type_web_path = source_web_path / chart_type_name chart_type_qa_path = source_qa_path / chart_type_name # 获取图表数量 chart_files = [f for f in self._list_files(chart_type_web_path) if f.endswith('.html')] chart_count = len(chart_files) # 获取模型列表 models = self._list_dirs(chart_type_qa_path) if chart_type_qa_path.exists() else [] # 获取已审核数量 reviews = self.get_all_reviews() reviewed_charts = set() for r in reviews: if r.get('source') == source_name and r.get('chart_type') == chart_type_name: reviewed_charts.add(r.get('chart_id')) reviewed_count = len(reviewed_charts) structure["sources"][source_name]["chart_types"][chart_type_name] = { "chart_count": chart_count, "reviewed_count": reviewed_count, "models": models } return structure def get_chart_list(self, source: str, chart_type: str) -> List[str]: """获取指定 source 和 chart_type 下的所有图表 ID""" chart_type_path = self.web_path / source / chart_type if not chart_type_path.exists(): return [] return [f.replace('.html', '') for f in self._list_files(chart_type_path) if f.endswith('.html')] def get_all_chart_paths(self) -> List[Dict[str, str]]: """ 获取所有图表的完整路径(用于导航) Returns: 包含 {source, chart_type, chart_id, model} 的列表 """ paths = [] structure = self.get_dataset_structure() for source_name, source_data in structure["sources"].items(): for chart_type_name, chart_type_data in source_data["chart_types"].items(): chart_ids = self.get_chart_list(source_name, chart_type_name) for chart_id in chart_ids: for model in chart_type_data.get("models", []): paths.append({ "source": source_name, "chart_type": chart_type_name, "chart_id": chart_id, "model": model }) return paths # ============== 图表数据获取 ============== def get_chart_data(self, source: str, chart_type: str, chart_id: str) -> Dict[str, Any]: """ 获取图表数据(HTML内容和标签信息) Args: source: 数据来源 chart_type: 图表类型 chart_id: 图表ID Returns: 包含 html_content, html_path, label_info 的字典 """ html_path = self.web_path / source / chart_type / f"{chart_id}.html" label_path = self.label_path / source / chart_type / f"{chart_id}.json" result = { "html_content": "", "html_path": str(html_path) if html_path.exists() else "", "label_info": None } # 读取 HTML 内容 if html_path.exists(): with open(html_path, 'r', encoding='utf-8') as f: result["html_content"] = f.read() # 读取标签信息 if label_path.exists(): try: with open(label_path, 'r', encoding='utf-8') as f: label_data = json.load(f) result["label_info"] = label_data except Exception as e: print(f"Error reading label file: {e}") return result # ============== QA 数据获取 ============== def get_qa_list(self, source: str, chart_type: str, model: str, chart_id: str) -> List[QAItem]: """ 获取指定图表的 QA 列表 Args: source: 数据来源 chart_type: 图表类型 model: 模型名称 chart_id: 图表ID Returns: QAItem 列表 """ qa_model_path = self.qa_path / source / chart_type / model if not qa_model_path.exists(): return [] qa_list = [] for qa_file in self._list_files(qa_model_path): if not qa_file.endswith('.json'): continue try: with open(qa_model_path / qa_file, 'r', encoding='utf-8') as f: qa_data = json.load(f) # 筛选属于当前图表的 QA if qa_data.get('chart') == chart_id or chart_id in qa_data.get('id', ''): qa_list.append(QAItem( id=qa_data.get('id', ''), chart=qa_data.get('chart', ''), question=qa_data.get('question', ''), answer=qa_data.get('answer', '') )) except Exception as e: print(f"Error reading QA file {qa_file}: {e}") return qa_list def get_all_qa_for_chart(self, source: str, chart_type: str, chart_id: str) -> Dict[str, List[QAItem]]: """ 获取指定图表所有模型的 QA 数据 Returns: {model_name: [QAItem, ...], ...} 的字典 """ chart_qa_path = self.qa_path / source / chart_type if not chart_qa_path.exists(): return {} result = {} for model in self._list_dirs(chart_qa_path): qa_list = self.get_qa_list(source, chart_type, model, chart_id) if qa_list: result[model] = qa_list return result # ============== 审核记录管理 ============== def _save_reviews(self, reviews: List[Dict]): """保存审核记录到文件""" with open(self.reviews_file, 'w', encoding='utf-8') as f: json.dump(reviews, f, ensure_ascii=False, indent=2) def get_all_reviews(self) -> List[Dict]: """获取所有审核记录""" if not self.reviews_file.exists(): return [] try: with open(self.reviews_file, 'r', encoding='utf-8') as f: return json.load(f) except: return [] def get_review_by_qa_id(self, qa_id: str) -> Optional[Dict]: """根据 QA ID 获取审核记录""" reviews = self.get_all_reviews() for r in reviews: if r.get('qa_id') == qa_id: return r return None def get_reviews_by_chart(self, chart_id: str, model: Optional[str] = None) -> List[Dict]: """获取指定图表的审核记录""" reviews = self.get_all_reviews() result = [] for r in reviews: if r.get('chart_id') == chart_id: if model is None or r.get('model') == model: result.append(r) return result def save_review(self, review_data: Dict) -> Dict: """ 保存审核记录 Args: review_data: 审核数据字典 Returns: 保存后的审核记录 """ reviews = self.get_all_reviews() # 生成审核记录 review = { "review_id": review_data.get('review_id') or str(uuid.uuid4()), "chart_id": review_data.get('chart_id', ''), "qa_id": review_data.get('qa_id', ''), "source": review_data.get('source', ''), "chart_type": review_data.get('chart_type', ''), "model": review_data.get('model', ''), "original_question": review_data.get('original_question', ''), "original_answer": review_data.get('original_answer', ''), "status": review_data.get('status', 'pending'), "modified_question": review_data.get('modified_question', ''), "modified_answer": review_data.get('modified_answer', ''), "issue_type": review_data.get('issue_type', ''), "comment": review_data.get('comment', ''), "reviewer": review_data.get('reviewer', 'default'), "review_time": datetime.now().isoformat() } # 更新或添加 existing_index = None for i, r in enumerate(reviews): if r.get('qa_id') == review['qa_id']: existing_index = i break if existing_index is not None: reviews[existing_index] = review else: reviews.append(review) self._save_reviews(reviews) return review def get_review_stats(self) -> Dict[str, int]: """获取审核统计信息""" reviews = self.get_all_reviews() return { "total": len(reviews), "correct": len([r for r in reviews if r.get('status') == 'correct']), "incorrect": len([r for r in reviews if r.get('status') == 'incorrect']), "needs_modification": len([r for r in reviews if r.get('status') == 'needs_modification']), "pending": len([r for r in reviews if r.get('status') == 'pending']) } def export_reviews(self, output_path: str = None) -> str: """ 导出审核记录 Args: output_path: 输出文件路径,如果为 None 则自动生成 Returns: 导出文件的路径 """ reviews = self.get_all_reviews() if output_path is None: output_path = f"./reviews_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" with open(output_path, 'w', encoding='utf-8') as f: json.dump(reviews, f, ensure_ascii=False, indent=2) return output_path # ============== 辅助方法 ============== def _list_dirs(self, path: Path) -> List[str]: """列出目录下的所有子目录""" if not path.exists(): return [] return [d.name for d in path.iterdir() if d.is_dir() and not d.name.startswith('.')] def _list_files(self, path: Path) -> List[str]: """列出目录下的所有文件""" if not path.exists(): return [] return [f.name for f in path.iterdir() if f.is_file() and not f.name.startswith('.')] # ============== 全局实例 ============== # 默认数据管理器实例 data_manager = DataManager()