| """ |
| 数据管理模块 - 负责加载数据集和管理审核记录 |
| """ |
| import os |
| import json |
| import uuid |
| from datetime import datetime |
| from typing import Dict, List, Optional, Any |
| from dataclasses import dataclass, asdict |
| from pathlib import Path |
|
|
|
|
| |
|
|
@dataclass()
class LabelInfo:
    """Label metadata attached to a chart.

    The capitalised field names are part of the public interface; presumably
    they mirror the keys used in the label JSON files — TODO confirm.
    """

    Number: str
    Type: str
    Source: str
    Weblink: str
    Topic: str
    Describe: str
    Other: str = ""  # the only optional field; empty string when absent
|
|
|
|
@dataclass()
class QAItem:
    """A single question/answer pair associated with a chart."""

    id: str        # identifier of the QA pair
    chart: str     # identifier of the chart the pair refers to
    question: str
    answer: str
|
|
|
|
@dataclass()
class ReviewRecord:
    """One reviewer's verdict on a QA pair."""

    # Which QA pair was reviewed, and where it lives in the dataset.
    review_id: str
    chart_id: str
    qa_id: str
    source: str
    chart_type: str
    model: str
    # The question/answer as they were before review.
    original_question: str
    original_answer: str
    # Review outcome, e.g. 'correct', 'incorrect', 'needs_modification', 'pending'.
    status: str
    # Optional details; all default to empty except the reviewer name.
    modified_question: str = ""
    modified_answer: str = ""
    issue_type: str = ""
    comment: str = ""
    reviewer: str = "default"
    review_time: str = ""

    def to_dict(self) -> Dict:
        """Return the record as a plain ``{field_name: value}`` dict."""
        as_plain_dict = asdict(self)
        return as_plain_dict
|
|
|
|
| |
|
|
class DataManager:
    """Load the chart dataset from disk and persist review records.

    On-disk layout, as read by the methods below::

        <dataset_path>/web/<source>/<chart_type>/<chart_id>.html
        <dataset_path>/label/<source>/<chart_type>/<chart_id>.json
        <dataset_path>/question_answer/<source>/<chart_type>/<model>/*.json
        <reviews_path>/reviews.json   (JSON list of review dicts)
    """

    def __init__(self, dataset_path: str = "./dataset", reviews_path: str = "./reviews"):
        """
        Initialise the data manager.

        Creates ``reviews_path`` and an empty ``reviews.json`` if they are missing.

        Args:
            dataset_path: Root directory of the dataset.
            reviews_path: Directory where review records are stored.
        """
        self.dataset_path = Path(dataset_path)
        self.reviews_path = Path(reviews_path)
        self.web_path = self.dataset_path / "web"
        self.label_path = self.dataset_path / "label"
        self.qa_path = self.dataset_path / "question_answer"
        self.reviews_file = self.reviews_path / "reviews.json"

        self.reviews_path.mkdir(parents=True, exist_ok=True)
        if not self.reviews_file.exists():
            self._save_reviews([])

    # ------------------------------------------------------------------ #
    # Dataset browsing
    # ------------------------------------------------------------------ #

    def get_dataset_structure(self) -> Dict[str, Any]:
        """
        Return the dataset tree (source -> chart_type) with per-type counts.

        Returns:
            ``{"sources": {source: {"chart_types": {chart_type: {
            "chart_count": int, "reviewed_count": int, "models": [str]}}}}}``.
            ``chart_count`` is the number of ``.html`` files; ``reviewed_count``
            is the number of distinct ``chart_id`` values among saved reviews
            for that source/chart_type; ``models`` lists the model directories
            under ``question_answer/<source>/<chart_type>``.
        """
        structure: Dict[str, Any] = {"sources": {}}

        if not self.web_path.exists():
            return structure

        # Read the reviews file once, not once per chart type.
        reviewed_charts: Dict[tuple, set] = {}
        for r in self.get_all_reviews():
            key = (r.get('source'), r.get('chart_type'))
            reviewed_charts.setdefault(key, set()).add(r.get('chart_id'))

        for source_name in self._list_dirs(self.web_path):
            source_web_path = self.web_path / source_name
            source_qa_path = self.qa_path / source_name
            chart_types: Dict[str, Any] = {}
            structure["sources"][source_name] = {"chart_types": chart_types}

            for chart_type_name in self._list_dirs(source_web_path):
                chart_files = [f for f in self._list_files(source_web_path / chart_type_name)
                               if f.endswith('.html')]
                chart_types[chart_type_name] = {
                    "chart_count": len(chart_files),
                    "reviewed_count": len(reviewed_charts.get((source_name, chart_type_name), ())),
                    # _list_dirs returns [] when the QA directory does not exist.
                    "models": self._list_dirs(source_qa_path / chart_type_name),
                }

        return structure

    def get_chart_list(self, source: str, chart_type: str) -> List[str]:
        """Return the chart IDs (``.html`` file names without the suffix), sorted by file name."""
        suffix = '.html'
        chart_type_path = self.web_path / source / chart_type
        # Strip only the trailing suffix; str.replace would also remove
        # '.html' occurring inside the name (e.g. 'x.html_v2.html').
        return [f[:-len(suffix)] for f in self._list_files(chart_type_path) if f.endswith(suffix)]

    def get_all_chart_paths(self) -> List[Dict[str, str]]:
        """
        Return every chart/model combination, for navigation.

        Each chart is paired with every model directory found for its chart type.

        Returns:
            List of ``{"source", "chart_type", "chart_id", "model"}`` dicts.
        """
        paths = []
        structure = self.get_dataset_structure()

        for source_name, source_data in structure["sources"].items():
            for chart_type_name, chart_type_data in source_data["chart_types"].items():
                models = chart_type_data.get("models", [])
                for chart_id in self.get_chart_list(source_name, chart_type_name):
                    for model in models:
                        paths.append({
                            "source": source_name,
                            "chart_type": chart_type_name,
                            "chart_id": chart_id,
                            "model": model,
                        })

        return paths

    def get_chart_data(self, source: str, chart_type: str, chart_id: str) -> Dict[str, Any]:
        """
        Return a chart's HTML content and its label information.

        Args:
            source: Data source name.
            chart_type: Chart type name.
            chart_id: Chart ID (file name without extension).

        Returns:
            ``{"html_content": str, "html_path": str, "label_info": dict | None}``.
            ``html_content``/``html_path`` are empty strings when the HTML file
            is missing; ``label_info`` is None when the label file is missing or
            cannot be parsed.
        """
        html_path = self.web_path / source / chart_type / f"{chart_id}.html"
        label_path = self.label_path / source / chart_type / f"{chart_id}.json"

        result = {
            "html_content": "",
            "html_path": str(html_path) if html_path.exists() else "",
            "label_info": None
        }

        if html_path.exists():
            with open(html_path, 'r', encoding='utf-8') as f:
                result["html_content"] = f.read()

        if label_path.exists():
            try:
                with open(label_path, 'r', encoding='utf-8') as f:
                    result["label_info"] = json.load(f)
            except Exception as e:
                # Best effort: a broken label file should not hide the chart itself.
                print(f"Error reading label file: {e}")

        return result

    def get_qa_list(self, source: str, chart_type: str, model: str, chart_id: str) -> List["QAItem"]:
        """
        Return the QA pairs of one model for one chart.

        Every ``*.json`` file in ``question_answer/<source>/<chart_type>/<model>``
        is parsed; a file is kept when its ``chart`` equals ``chart_id`` or its
        ``id`` contains ``chart_id``. Unreadable files are reported and skipped.

        Args:
            source: Data source name.
            chart_type: Chart type name.
            model: Model directory name.
            chart_id: Chart ID.

        Returns:
            List of QAItem.
        """
        qa_model_path = self.qa_path / source / chart_type / model

        if not qa_model_path.exists():
            return []

        qa_list = []
        for qa_file in self._list_files(qa_model_path):
            if not qa_file.endswith('.json'):
                continue

            try:
                with open(qa_model_path / qa_file, 'r', encoding='utf-8') as f:
                    qa_data = json.load(f)

                # NOTE(review): the fallback is a substring test, so chart "bar_1" also
                # matches QA ids containing "bar_10"; confirm the QA id format and tighten.
                # An empty chart_id would match every id, so it only uses the exact test.
                if qa_data.get('chart') == chart_id or (chart_id and chart_id in qa_data.get('id', '')):
                    qa_list.append(QAItem(
                        id=qa_data.get('id', ''),
                        chart=qa_data.get('chart', ''),
                        question=qa_data.get('question', ''),
                        answer=qa_data.get('answer', '')
                    ))
            except Exception as e:
                # Best effort per file: one malformed QA file must not hide the others.
                print(f"Error reading QA file {qa_file}: {e}")

        return qa_list

    def get_all_qa_for_chart(self, source: str, chart_type: str, chart_id: str) -> Dict[str, List["QAItem"]]:
        """
        Return the QA pairs of every model for one chart.

        Returns:
            ``{model_name: [QAItem, ...]}``; models with no matching QA are omitted.
        """
        chart_qa_path = self.qa_path / source / chart_type

        if not chart_qa_path.exists():
            return {}

        result = {}
        for model in self._list_dirs(chart_qa_path):
            qa_list = self.get_qa_list(source, chart_type, model, chart_id)
            if qa_list:
                result[model] = qa_list

        return result

    # ------------------------------------------------------------------ #
    # Review records
    # ------------------------------------------------------------------ #

    def _save_reviews(self, reviews: List[Dict]) -> None:
        """
        Write ``reviews`` to the reviews file atomically.

        The JSON goes to a sibling temp file first and is then moved over the
        real file with ``os.replace``, so an interrupted write cannot leave a
        truncated ``reviews.json`` behind.
        """
        tmp_file = self.reviews_file.with_name(self.reviews_file.name + '.tmp')
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(reviews, f, ensure_ascii=False, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_file, self.reviews_file)
        except BaseException:
            if tmp_file.exists():
                tmp_file.unlink()
            raise

    def _read_reviews_file(self) -> List[Dict]:
        """Parse the reviews file; raise OSError/ValueError if it is unreadable or not a JSON list."""
        with open(self.reviews_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(f"{self.reviews_file} does not contain a JSON list")
        return data

    def _backup_reviews_file(self) -> None:
        """Move the current reviews file aside to ``reviews.json.corrupt-<timestamp>``."""
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        backup = self.reviews_file.with_name(f"{self.reviews_file.name}.corrupt-{stamp}")
        os.replace(self.reviews_file, backup)

    def _load_reviews_for_update(self) -> List[Dict]:
        """
        Load reviews before a write.

        If the file exists but cannot be parsed, it is backed up first so the
        subsequent write does not silently destroy the existing records.
        """
        if not self.reviews_file.exists():
            return []
        try:
            return self._read_reviews_file()
        except (OSError, ValueError):
            self._backup_reviews_file()
            return []

    def get_all_reviews(self) -> List[Dict]:
        """Return all review records; ``[]`` if the file is missing or unreadable."""
        if not self.reviews_file.exists():
            return []

        try:
            return self._read_reviews_file()
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return []

    def get_review_by_qa_id(self, qa_id: str) -> Optional[Dict]:
        """Return the first review whose ``qa_id`` matches, or None."""
        return next((r for r in self.get_all_reviews() if r.get('qa_id') == qa_id), None)

    def get_reviews_by_chart(self, chart_id: str, model: Optional[str] = None) -> List[Dict]:
        """Return the reviews of one chart, optionally restricted to one model."""
        return [
            r for r in self.get_all_reviews()
            if r.get('chart_id') == chart_id and (model is None or r.get('model') == model)
        ]

    def save_review(self, review_data: Dict) -> Dict:
        """
        Insert or update a review record.

        A record with the same ``qa_id`` is replaced; otherwise the record is
        appended. On update, the existing ``review_id`` is kept unless
        ``review_data`` supplies one. ``review_time`` is always set to now.

        Args:
            review_data: Review fields; missing keys get defaults.

        Returns:
            The record as saved.
        """
        reviews = self._load_reviews_for_update()
        qa_id = review_data.get('qa_id', '')

        # NOTE(review): records are keyed by qa_id alone; if QA ids can repeat
        # across models (or are empty), different QAs overwrite each other — confirm.
        existing_index = next(
            (i for i, r in enumerate(reviews) if r.get('qa_id') == qa_id), None)
        previous_id = reviews[existing_index].get('review_id') if existing_index is not None else None

        review = {
            "review_id": review_data.get('review_id') or previous_id or str(uuid.uuid4()),
            "chart_id": review_data.get('chart_id', ''),
            "qa_id": qa_id,
            "source": review_data.get('source', ''),
            "chart_type": review_data.get('chart_type', ''),
            "model": review_data.get('model', ''),
            "original_question": review_data.get('original_question', ''),
            "original_answer": review_data.get('original_answer', ''),
            "status": review_data.get('status', 'pending'),
            "modified_question": review_data.get('modified_question', ''),
            "modified_answer": review_data.get('modified_answer', ''),
            "issue_type": review_data.get('issue_type', ''),
            "comment": review_data.get('comment', ''),
            "reviewer": review_data.get('reviewer', 'default'),
            "review_time": datetime.now().isoformat()
        }

        if existing_index is not None:
            reviews[existing_index] = review
        else:
            reviews.append(review)

        self._save_reviews(reviews)
        return review

    def get_review_stats(self) -> Dict[str, int]:
        """Return the total number of reviews and the count per known status."""
        reviews = self.get_all_reviews()
        stats = {"total": len(reviews), "correct": 0, "incorrect": 0,
                 "needs_modification": 0, "pending": 0}
        for r in reviews:
            status = r.get('status')
            # Unknown statuses are ignored; 'total' is not a status.
            if status != 'total' and status in stats:
                stats[status] += 1
        return stats

    def export_reviews(self, output_path: Optional[str] = None) -> str:
        """
        Export all review records to a JSON file.

        Args:
            output_path: Destination file; when None, a timestamped
                ``./reviews_export_<YYYYmmdd_HHMMSS>.json`` is used.

        Returns:
            The path written to.
        """
        reviews = self.get_all_reviews()

        if output_path is None:
            output_path = f"./reviews_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(reviews, f, ensure_ascii=False, indent=2)

        return output_path

    # ------------------------------------------------------------------ #
    # Filesystem helpers
    # ------------------------------------------------------------------ #

    def _list_dirs(self, path: Path) -> List[str]:
        """Return the sorted names of non-hidden subdirectories; [] if ``path`` is missing."""
        if not path.exists():
            return []
        # iterdir() yields in arbitrary order; sort for deterministic navigation.
        return sorted(d.name for d in path.iterdir() if d.is_dir() and not d.name.startswith('.'))

    def _list_files(self, path: Path) -> List[str]:
        """Return the sorted names of non-hidden regular files; [] if ``path`` is missing."""
        if not path.exists():
            return []
        return sorted(f.name for f in path.iterdir() if f.is_file() and not f.name.startswith('.'))
|
|
|
|
| |
|
|
| |
# Module-level shared instance using the default paths ("./dataset", "./reviews").
# NOTE: constructing it at import time creates ./reviews and an empty reviews.json
# (relative to the current working directory) as a side effect of importing this module.
data_manager = DataManager()
|
|