# Chart / data_manager.py
# Author: adddrett — commit "clean init" (9fce90e)
"""
数据管理模块 - 负责加载数据集和管理审核记录
"""
import os
import json
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
from pathlib import Path
# ============== 数据类定义 ==============
@dataclass
class LabelInfo:
    """Label metadata describing a single chart.

    NOTE(review): DataManager in this file does not use this class;
    get_chart_data returns the parsed label JSON as a raw dict. The field names
    presumably mirror the keys of ``label/<source>/<chart_type>/<chart_id>.json``;
    TODO confirm against the dataset.
    """
    Number: str
    Type: str
    Source: str
    Weblink: str
    Topic: str
    Describe: str
    Other: str = ""  # optional free-form extra info; empty by default
@dataclass
class QAItem:
    """A single question/answer pair generated for a chart.

    DataManager.get_qa_list builds these from the ``id``, ``chart``,
    ``question`` and ``answer`` keys of each QA JSON file, using "" for any
    missing key.
    """
    id: str
    chart: str  # chart id this QA belongs to; get_qa_list compares it with the requested chart_id
    question: str
    answer: str
@dataclass
class ReviewRecord:
    """A reviewer's verdict on one QA pair.

    The fields match, one-to-one and in order, the keys of the review dicts that
    DataManager.save_review writes to reviews.json.
    """
    review_id: str
    chart_id: str
    qa_id: str
    source: str
    chart_type: str
    model: str
    original_question: str
    original_answer: str
    status: str  # expected values: correct, incorrect, needs_modification, pending
    modified_question: str = ""
    modified_answer: str = ""
    issue_type: str = ""  # expected values: question_ambiguous, answer_wrong, chart_unclear, other
    comment: str = ""
    reviewer: str = "default"
    review_time: str = ""

    def to_dict(self) -> Dict:
        """Return this record as a plain dictionary keyed by field name."""
        as_plain_dict = asdict(self)
        return as_plain_dict
# ============== 数据管理器 ==============
class DataManager:
    """Loads chart datasets from disk and persists review records as JSON.

    Directory layout read by this class::

        <dataset_path>/web/<source>/<chart_type>/<chart_id>.html
        <dataset_path>/label/<source>/<chart_type>/<chart_id>.json
        <dataset_path>/question_answer/<source>/<chart_type>/<model>/*.json
        <reviews_path>/reviews.json      (a JSON list of review dicts)

    Directory listings are sorted, so chart, model and QA ordering (and any
    navigation built on it) is deterministic rather than depending on the
    filesystem's iteration order. Hidden entries (names starting with ".")
    are ignored.
    """

    # Suffix identifying chart files under ``web/``.
    _HTML_SUFFIX = ".html"

    def __init__(self, dataset_path: str = "./dataset", reviews_path: str = "./reviews"):
        """Initialise the manager and make sure the review store exists.

        Args:
            dataset_path: Root directory of the dataset.
            reviews_path: Directory holding ``reviews.json``. It is created if
                missing, and ``reviews.json`` is seeded with an empty list if
                it does not exist yet.
        """
        self.dataset_path = Path(dataset_path)
        self.reviews_path = Path(reviews_path)
        self.web_path = self.dataset_path / "web"
        self.label_path = self.dataset_path / "label"
        self.qa_path = self.dataset_path / "question_answer"
        self.reviews_file = self.reviews_path / "reviews.json"

        # Ensure the review store exists so later reads start from a valid file.
        self.reviews_path.mkdir(parents=True, exist_ok=True)
        if not self.reviews_file.exists():
            self._save_reviews([])

    # ============== Dataset structure ==============

    def get_dataset_structure(self) -> Dict[str, Any]:
        """Return the dataset tree as ``source -> chart_type -> info``.

        Returns:
            ``{"sources": {source: {"chart_types": {chart_type: {
            "chart_count": int, "reviewed_count": int, "models": [str]}}}}}``.
            ``chart_count`` is the number of ``*.html`` files for that chart type.
            ``reviewed_count`` is the number of distinct chart ids with at least
            one review for that source/chart type. ``models`` lists the
            sub-directories of ``question_answer/<source>/<chart_type>``.
            An empty ``{"sources": {}}`` is returned if ``web/`` is missing.
        """
        structure: Dict[str, Any] = {"sources": {}}
        if not self.web_path.exists():
            return structure

        # Read reviews.json once and index reviewed chart ids by
        # (source, chart_type), instead of re-reading and re-parsing the file
        # for every chart type.
        reviewed_by_group: Dict[tuple, set] = {}
        for review in self.get_all_reviews():
            group = (review.get('source'), review.get('chart_type'))
            reviewed_by_group.setdefault(group, set()).add(review.get('chart_id'))

        for source_name in self._list_dirs(self.web_path):
            source_web_path = self.web_path / source_name
            source_qa_path = self.qa_path / source_name
            chart_types: Dict[str, Any] = {}
            structure["sources"][source_name] = {"chart_types": chart_types}

            for chart_type_name in self._list_dirs(source_web_path):
                chart_files = [
                    f for f in self._list_files(source_web_path / chart_type_name)
                    if f.endswith(self._HTML_SUFFIX)
                ]
                reviewed = reviewed_by_group.get((source_name, chart_type_name), set())
                chart_types[chart_type_name] = {
                    "chart_count": len(chart_files),
                    "reviewed_count": len(reviewed),
                    # _list_dirs returns [] when the QA directory does not exist.
                    "models": self._list_dirs(source_qa_path / chart_type_name),
                }
        return structure

    def get_chart_list(self, source: str, chart_type: str) -> List[str]:
        """Return the ids of all charts under ``web/<source>/<chart_type>``.

        A chart id is its ``*.html`` filename with only the trailing ``.html``
        removed (``a.html.html`` -> ``a.html``). Ids come in sorted filename
        order, and an empty list is returned if the directory is missing.
        """
        chart_type_path = self.web_path / source / chart_type
        suffix_len = len(self._HTML_SUFFIX)
        return [
            f[:-suffix_len]
            for f in self._list_files(chart_type_path)
            if f.endswith(self._HTML_SUFFIX)
        ]

    def get_all_chart_paths(self) -> List[Dict[str, str]]:
        """Return every (source, chart_type, chart_id, model) combination, for navigation.

        One entry is produced per chart per model directory, so a chart type
        with no model directory under ``question_answer`` contributes nothing.

        Returns:
            A list of ``{"source", "chart_type", "chart_id", "model"}`` dicts,
            ordered by source, chart type, chart id and then model.
        """
        paths: List[Dict[str, str]] = []
        structure = self.get_dataset_structure()
        for source_name, source_data in structure["sources"].items():
            for chart_type_name, chart_type_data in source_data["chart_types"].items():
                models = chart_type_data.get("models", [])
                for chart_id in self.get_chart_list(source_name, chart_type_name):
                    for model in models:
                        paths.append({
                            "source": source_name,
                            "chart_type": chart_type_name,
                            "chart_id": chart_id,
                            "model": model,
                        })
        return paths

    # ============== Chart data ==============

    def get_chart_data(self, source: str, chart_type: str, chart_id: str) -> Dict[str, Any]:
        """Load a chart's HTML content and its label information.

        Args:
            source: Data source name.
            chart_type: Chart type name.
            chart_id: Chart id (HTML filename without ``.html``).

        Returns:
            ``{"html_content": str, "html_path": str, "label_info": Any}``.
            ``html_content`` and ``html_path`` are "" when the HTML file is missing.
            ``label_info`` is the parsed label JSON, or None when the label file is
            missing or unreadable (a message is printed in the unreadable case).

        Raises:
            OSError / UnicodeDecodeError: if the HTML file exists but cannot be
            read as UTF-8. Only label-file errors are suppressed.
        """
        html_path = self.web_path / source / chart_type / f"{chart_id}.html"
        label_path = self.label_path / source / chart_type / f"{chart_id}.json"

        html_exists = html_path.exists()
        result: Dict[str, Any] = {
            "html_content": "",
            "html_path": str(html_path) if html_exists else "",
            "label_info": None,
        }

        if html_exists:
            with open(html_path, 'r', encoding='utf-8') as f:
                result["html_content"] = f.read()

        if label_path.exists():
            try:
                with open(label_path, 'r', encoding='utf-8') as f:
                    result["label_info"] = json.load(f)
            except (OSError, ValueError) as e:  # ValueError covers JSON and UTF-8 decode errors
                print(f"Error reading label file: {e}")

        return result

    # ============== QA data ==============

    def get_qa_list(self, source: str, chart_type: str, model: str, chart_id: str) -> List['QAItem']:
        """Return the QA pairs one model generated for one chart.

        Every ``*.json`` file in ``question_answer/<source>/<chart_type>/<model>``
        is read as a single QA object. A file is kept when its ``chart`` equals
        ``chart_id``, or when ``chart_id`` occurs inside its string ``id``.
        Unreadable or non-object files are reported with a printed message and
        skipped.

        Args:
            source: Data source name.
            chart_type: Chart type name.
            model: Model directory name.
            chart_id: Chart id to filter on.

        Returns:
            A list of QAItem in sorted filename order. It is empty if the model
            directory is missing.
        """
        qa_model_path = self.qa_path / source / chart_type / model
        if not qa_model_path.exists():
            return []

        qa_list = []
        for qa_file in self._list_files(qa_model_path):
            if not qa_file.endswith('.json'):
                continue
            try:
                with open(qa_model_path / qa_file, 'r', encoding='utf-8') as f:
                    qa_data = json.load(f)
            except (OSError, ValueError) as e:
                print(f"Error reading QA file {qa_file}: {e}")
                continue
            if not isinstance(qa_data, dict):
                print(f"Error reading QA file {qa_file}: expected a JSON object")
                continue

            qa_id = qa_data.get('id', '')
            # NOTE(review): the id fallback is a substring test, so chart "bar_1"
            # also matches ids of chart "bar_12". Tighten this once the id format
            # is confirmed.
            belongs_to_chart = qa_data.get('chart') == chart_id or (
                isinstance(qa_id, str) and chart_id in qa_id
            )
            if belongs_to_chart:
                qa_list.append(QAItem(
                    id=qa_id,
                    chart=qa_data.get('chart', ''),
                    question=qa_data.get('question', ''),
                    answer=qa_data.get('answer', ''),
                ))
        return qa_list

    def get_all_qa_for_chart(self, source: str, chart_type: str, chart_id: str) -> Dict[str, List['QAItem']]:
        """Return the QA pairs for one chart, grouped by model.

        Returns:
            ``{model_name: [QAItem, ...]}``. Models with no QA for this chart are
            omitted, and an empty dict is returned if the chart-type QA
            directory is missing.
        """
        chart_qa_path = self.qa_path / source / chart_type
        if not chart_qa_path.exists():
            return {}

        result = {}
        for model in self._list_dirs(chart_qa_path):
            qa_list = self.get_qa_list(source, chart_type, model, chart_id)
            if qa_list:
                result[model] = qa_list
        return result

    # ============== Review records ==============

    def _save_reviews(self, reviews: List[Dict]):
        """Persist *reviews* to ``reviews.json`` (UTF-8, 2-space indent).

        The JSON goes to a sibling temporary file first, which is then moved
        over the real file with ``os.replace``. An interrupted or failed write
        therefore cannot leave a truncated ``reviews.json``. A truncated file
        would load as empty, and the next save would then wipe every earlier
        review.
        """
        tmp_file = self.reviews_file.with_name(self.reviews_file.name + ".tmp")
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(reviews, f, ensure_ascii=False, indent=2)
            os.replace(tmp_file, self.reviews_file)
        except BaseException:
            # Remove the partial temp file and leave the previous reviews.json intact.
            if tmp_file.exists():
                tmp_file.unlink()
            raise

    def get_all_reviews(self) -> List[Dict]:
        """Return every stored review record.

        Returns an empty list when ``reviews.json`` is missing. If the file is
        unreadable, is not valid JSON, or does not hold a JSON list, a message
        is printed and an empty list is returned.
        """
        try:
            with open(self.reviews_file, 'r', encoding='utf-8') as f:
                reviews = json.load(f)
        except FileNotFoundError:
            return []
        except (OSError, ValueError) as e:
            print(f"Error reading reviews file {self.reviews_file}: {e}")
            return []
        if not isinstance(reviews, list):
            print(f"Error reading reviews file {self.reviews_file}: expected a JSON list")
            return []
        return reviews

    def get_review_by_qa_id(self, qa_id: str) -> Optional[Dict]:
        """Return the first review whose ``qa_id`` matches, or None."""
        return next((r for r in self.get_all_reviews() if r.get('qa_id') == qa_id), None)

    def get_reviews_by_chart(self, chart_id: str, model: Optional[str] = None) -> List[Dict]:
        """Return the reviews for *chart_id*, optionally restricted to one *model*."""
        return [
            r for r in self.get_all_reviews()
            if r.get('chart_id') == chart_id and (model is None or r.get('model') == model)
        ]

    def save_review(self, review_data: Dict) -> Dict:
        """Create or replace the review for ``review_data['qa_id']``.

        Missing keys fall back to defaults: a fresh UUID4 ``review_id``,
        ``status`` "pending", ``reviewer`` "default", and "" for everything
        else. ``review_time`` is always set to the current local time in ISO
        format.

        Args:
            review_data: Review fields supplied by the caller.

        Returns:
            The normalised review dict that was stored.
        """
        reviews = self.get_all_reviews()

        review = {
            "review_id": review_data.get('review_id') or str(uuid.uuid4()),
            "chart_id": review_data.get('chart_id', ''),
            "qa_id": review_data.get('qa_id', ''),
            "source": review_data.get('source', ''),
            "chart_type": review_data.get('chart_type', ''),
            "model": review_data.get('model', ''),
            "original_question": review_data.get('original_question', ''),
            "original_answer": review_data.get('original_answer', ''),
            "status": review_data.get('status', 'pending'),
            "modified_question": review_data.get('modified_question', ''),
            "modified_answer": review_data.get('modified_answer', ''),
            "issue_type": review_data.get('issue_type', ''),
            "comment": review_data.get('comment', ''),
            "reviewer": review_data.get('reviewer', 'default'),
            "review_time": datetime.now().isoformat(),
        }

        # Reviews are keyed by qa_id alone: an existing entry with the same
        # qa_id is replaced in place. Otherwise the new review is appended.
        # NOTE(review): if QA ids are reused across models, reviewing one
        # model's QA replaces another model's review. Confirm the ids are
        # globally unique.
        existing_index = next(
            (i for i, r in enumerate(reviews) if r.get('qa_id') == review['qa_id']),
            None,
        )
        if existing_index is not None:
            reviews[existing_index] = review
        else:
            reviews.append(review)

        self._save_reviews(reviews)
        return review

    def get_review_stats(self) -> Dict[str, int]:
        """Count reviews in total and per status.

        Returns:
            ``{"total", "correct", "incorrect", "needs_modification", "pending"}``
            counts. Reviews with any other status count only toward ``total``.
        """
        reviews = self.get_all_reviews()
        counts = dict.fromkeys(("correct", "incorrect", "needs_modification", "pending"), 0)
        for r in reviews:  # single pass instead of one scan per status
            status = r.get('status')
            if isinstance(status, str) and status in counts:
                counts[status] += 1
        return {"total": len(reviews), **counts}

    def export_reviews(self, output_path: Optional[str] = None) -> str:
        """Write all reviews to a JSON file.

        Args:
            output_path: Destination file. If None, the file is written as
                ``./reviews_export_<YYYYmmdd_HHMMSS>.json`` relative to the
                current working directory.

        Returns:
            The path the reviews were written to.
        """
        reviews = self.get_all_reviews()
        if output_path is None:
            output_path = f"./reviews_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(reviews, f, ensure_ascii=False, indent=2)
        return output_path

    # ============== Helpers ==============

    def _list_dirs(self, path: Path) -> List[str]:
        """Return the sorted names of the non-hidden sub-directories of *path* ([] if missing)."""
        if not path.exists():
            return []
        return sorted(d.name for d in path.iterdir() if d.is_dir() and not d.name.startswith('.'))

    def _list_files(self, path: Path) -> List[str]:
        """Return the sorted names of the non-hidden regular files in *path* ([] if missing)."""
        if not path.exists():
            return []
        return sorted(f.name for f in path.iterdir() if f.is_file() and not f.name.startswith('.'))
# ============== Global instance ==============
# Default DataManager shared by modules that import this one.
# NOTE: it is constructed at import time with the default paths, so importing
# this module creates ./reviews (relative to the current working directory)
# and seeds ./reviews/reviews.json with an empty list if it does not exist.
data_manager = DataManager()