"""
Data management module: loads the chart dataset and manages review records.
"""
import os
import json
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
from pathlib import Path
# ============== 数据类定义 ==============
@dataclass
class LabelInfo:
    """Label metadata describing one chart.

    The capitalised field names presumably mirror the keys of the label JSON
    files (TODO confirm against dataset/label/*.json). Only ``Other`` is optional.
    """

    Number: str
    Type: str
    Source: str
    Weblink: str
    Topic: str
    Describe: str
    Other: str = ""  # optional extra information, empty by default
@dataclass
class QAItem:
    """A single question/answer pair attached to a chart."""

    id: str        # identifier of the QA pair
    chart: str     # id of the chart the question is about
    question: str
    answer: str
@dataclass
class ReviewRecord:
    """One reviewer's verdict on a single QA pair.

    The fields match, in name and order, the keys of the review dicts that
    DataManager.save_review writes to reviews.json.
    """

    review_id: str
    chart_id: str
    qa_id: str
    source: str
    chart_type: str
    model: str
    original_question: str
    original_answer: str
    status: str  # correct, incorrect, needs_modification, pending
    modified_question: str = ""
    modified_answer: str = ""
    issue_type: str = ""  # question_ambiguous, answer_wrong, chart_unclear, other
    comment: str = ""
    reviewer: str = "default"
    review_time: str = ""  # presumably an ISO-8601 timestamp (save_review uses datetime.isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain ``{field_name: value}`` dict."""
        return asdict(self)
# ============== 数据管理器 ==============
class DataManager:
    """Load the chart dataset from disk and persist human review records.

    Directory layout used by this class (built from the constructor arguments)::

        <dataset_path>/web/<source>/<chart_type>/<chart_id>.html
        <dataset_path>/label/<source>/<chart_type>/<chart_id>.json
        <dataset_path>/question_answer/<source>/<chart_type>/<model>/*.json
        <reviews_path>/reviews.json   (JSON list of review dicts)
    """

    def __init__(self, dataset_path: str = "./dataset", reviews_path: str = "./reviews"):
        """Initialise the manager and make sure the review store exists.

        Args:
            dataset_path: Root directory of the dataset.
            reviews_path: Directory where review records are stored. It is
                created if missing, and an empty ``reviews.json`` is written on
                first use.
        """
        self.dataset_path = Path(dataset_path)
        self.reviews_path = Path(reviews_path)
        self.web_path = self.dataset_path / "web"
        self.label_path = self.dataset_path / "label"
        self.qa_path = self.dataset_path / "question_answer"
        self.reviews_file = self.reviews_path / "reviews.json"

        # Ensure the review store exists so later reads/writes can assume it.
        self.reviews_path.mkdir(parents=True, exist_ok=True)
        if not self.reviews_file.exists():
            self._save_reviews([])

    # ============== Dataset structure ==============

    def get_dataset_structure(self) -> Dict[str, Any]:
        """Return the dataset tree as ``source -> chart_type -> stats``.

        Returns:
            ``{"sources": {source: {"chart_types": {chart_type: {"chart_count": int,
            "reviewed_count": int, "models": [str, ...]}}}}}``, where
            ``chart_count`` is the number of ``.html`` files, ``reviewed_count``
            the number of distinct chart ids that have at least one review
            record, and ``models`` the sub-directories of the matching QA
            directory. Sources and chart types are listed in sorted order.
        """
        structure: Dict[str, Any] = {"sources": {}}
        if not self.web_path.exists():
            return structure

        # Read the review file once (instead of once per chart type) and index
        # the reviewed chart ids by (source, chart_type).
        reviewed: Dict[tuple, set] = {}
        for r in self.get_all_reviews():
            key = (r.get('source'), r.get('chart_type'))
            reviewed.setdefault(key, set()).add(r.get('chart_id'))

        for source_name in self._list_dirs(self.web_path):
            source_web_path = self.web_path / source_name
            source_qa_path = self.qa_path / source_name
            chart_types: Dict[str, Any] = {}
            structure["sources"][source_name] = {"chart_types": chart_types}

            for chart_type_name in self._list_dirs(source_web_path):
                chart_files = [
                    f for f in self._list_files(source_web_path / chart_type_name)
                    if f.endswith('.html')
                ]
                # _list_dirs already returns [] for a missing QA directory.
                models = self._list_dirs(source_qa_path / chart_type_name)
                chart_types[chart_type_name] = {
                    "chart_count": len(chart_files),
                    "reviewed_count": len(reviewed.get((source_name, chart_type_name), ())),
                    "models": models,
                }
        return structure

    def get_chart_list(self, source: str, chart_type: str) -> List[str]:
        """Return the sorted chart ids (``.html`` file names without the extension)."""
        chart_type_path = self.web_path / source / chart_type
        suffix = '.html'
        # Strip only the trailing extension. str.replace would also drop a
        # '.html' inside the name, and the resulting id would no longer map
        # back to its file in get_chart_data.
        return [
            f[:-len(suffix)]
            for f in self._list_files(chart_type_path)
            if f.endswith(suffix)
        ]

    def get_all_chart_paths(self) -> List[Dict[str, str]]:
        """Flatten the dataset into one entry per (chart, model) pair, for navigation.

        Returns:
            A list of ``{"source", "chart_type", "chart_id", "model"}`` dicts,
            ordered by source, chart type, chart id and model name. Chart types
            without any model directory contribute no entries.
        """
        paths = []
        structure = self.get_dataset_structure()
        for source_name, source_data in structure["sources"].items():
            for chart_type_name, chart_type_data in source_data["chart_types"].items():
                models = chart_type_data.get("models", [])
                for chart_id in self.get_chart_list(source_name, chart_type_name):
                    for model in models:
                        paths.append({
                            "source": source_name,
                            "chart_type": chart_type_name,
                            "chart_id": chart_id,
                            "model": model,
                        })
        return paths

    # ============== Chart data ==============

    def get_chart_data(self, source: str, chart_type: str, chart_id: str) -> Dict[str, Any]:
        """Return a chart's HTML content and its label metadata.

        Args:
            source: Data source name.
            chart_type: Chart type name.
            chart_id: Chart id (HTML/label file name without extension).

        Returns:
            ``{"html_content": str, "html_path": str, "label_info": parsed JSON or None}``.
            ``html_content`` and ``html_path`` are ``""`` when the HTML file is
            missing. ``label_info`` is ``None`` when the label file is missing or
            cannot be read or parsed.
        """
        html_path = self.web_path / source / chart_type / f"{chart_id}.html"
        label_path = self.label_path / source / chart_type / f"{chart_id}.json"
        html_exists = html_path.exists()

        result = {
            "html_content": "",
            "html_path": str(html_path) if html_exists else "",
            "label_info": None
        }

        if html_exists:
            with open(html_path, 'r', encoding='utf-8') as f:
                result["html_content"] = f.read()

        if label_path.exists():
            try:
                with open(label_path, 'r', encoding='utf-8') as f:
                    result["label_info"] = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers malformed JSON and bad UTF-8. The label is
                # optional, so report the problem and carry on without it.
                print(f"Error reading label file: {e}")

        return result

    # ============== QA data ==============

    def get_qa_list(self, source: str, chart_type: str, model: str, chart_id: str) -> List["QAItem"]:
        """Return the QA pairs produced by *model* for one chart.

        Every ``*.json`` file in ``question_answer/<source>/<chart_type>/<model>/``
        is read as a single QA object. Files that cannot be read or parsed are
        reported and skipped.

        Args:
            source: Data source name.
            chart_type: Chart type name.
            model: Model name (QA sub-directory).
            chart_id: Chart id to filter on.

        Returns:
            QAItem list, in sorted file-name order; empty if the directory is missing.
        """
        qa_model_path = self.qa_path / source / chart_type / model
        qa_list = []
        for qa_file in self._list_files(qa_model_path):
            if not qa_file.endswith('.json'):
                continue
            try:
                with open(qa_model_path / qa_file, 'r', encoding='utf-8') as f:
                    qa_data = json.load(f)
                # A QA belongs to this chart if its "chart" field matches, or if
                # the chart id occurs in its "id".
                # NOTE(review): the id test is a substring match, so chart "c1"
                # also matches an id such as "c10_q1"; tighten once the id
                # format is confirmed.
                if qa_data.get('chart') == chart_id or chart_id in qa_data.get('id', ''):
                    qa_list.append(QAItem(
                        id=qa_data.get('id', ''),
                        chart=qa_data.get('chart', ''),
                        question=qa_data.get('question', ''),
                        answer=qa_data.get('answer', '')
                    ))
            except Exception as e:
                # Best effort: one malformed file must not hide the rest.
                print(f"Error reading QA file {qa_file}: {e}")
        return qa_list

    def get_all_qa_for_chart(self, source: str, chart_type: str, chart_id: str) -> Dict[str, List["QAItem"]]:
        """Return the chart's QA pairs for every model that has any.

        Returns:
            ``{model_name: [QAItem, ...]}``. Models with no matching QA are omitted.
        """
        chart_qa_path = self.qa_path / source / chart_type
        result = {}
        for model in self._list_dirs(chart_qa_path):
            qa_list = self.get_qa_list(source, chart_type, model, chart_id)
            if qa_list:
                result[model] = qa_list
        return result

    # ============== Review records ==============

    def _save_reviews(self, reviews: List[Dict]):
        """Write *reviews* to ``reviews.json`` atomically.

        The JSON is written to a sibling temp file, which then replaces the real
        file via ``os.replace``. An interrupted write therefore cannot leave a
        truncated ``reviews.json``. Such a file would read back as an empty list,
        and the next save would discard every earlier record.
        """
        tmp_file = self.reviews_file.with_name(self.reviews_file.name + ".tmp")
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(reviews, f, ensure_ascii=False, indent=2)
            os.replace(tmp_file, self.reviews_file)
        finally:
            # After a successful replace the temp file is gone. After a failure,
            # remove the partial file.
            if tmp_file.exists():
                tmp_file.unlink()

    def get_all_reviews(self) -> List[Dict]:
        """Return every stored review record.

        Returns ``[]`` when the review file is missing, unreadable or not valid
        JSON. The read error is printed rather than silently swallowed.
        """
        if not self.reviews_file.exists():
            return []
        try:
            with open(self.reviews_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error reading reviews file {self.reviews_file}: {e}")
            return []

    def get_review_by_qa_id(self, qa_id: str) -> Optional[Dict]:
        """Return the first review record for *qa_id*, or ``None``."""
        return next((r for r in self.get_all_reviews() if r.get('qa_id') == qa_id), None)

    def get_reviews_by_chart(self, chart_id: str, model: Optional[str] = None) -> List[Dict]:
        """Return the review records for *chart_id*, optionally limited to one *model*."""
        return [
            r for r in self.get_all_reviews()
            if r.get('chart_id') == chart_id and (model is None or r.get('model') == model)
        ]

    def save_review(self, review_data: Dict) -> Dict:
        """Insert or update the review record for ``review_data['qa_id']``.

        A record is identified by its ``qa_id``: an existing record with the same
        ``qa_id`` is replaced in place, otherwise the record is appended. On
        update the stored ``review_id`` is kept unless the caller supplies one.

        Args:
            review_data: Review fields. Missing keys fall back to defaults.

        Returns:
            The record as written, with ``review_time`` set to now.
        """
        reviews = self.get_all_reviews()
        qa_id = review_data.get('qa_id', '')

        existing_index = next(
            (i for i, r in enumerate(reviews) if r.get('qa_id') == qa_id), None
        )

        # Re-reviewing the same QA keeps its review_id stable. Only a brand-new
        # record gets a freshly generated id.
        review_id = review_data.get('review_id')
        if not review_id and existing_index is not None:
            review_id = reviews[existing_index].get('review_id')

        review = {
            "review_id": review_id or str(uuid.uuid4()),
            "chart_id": review_data.get('chart_id', ''),
            "qa_id": qa_id,
            "source": review_data.get('source', ''),
            "chart_type": review_data.get('chart_type', ''),
            "model": review_data.get('model', ''),
            "original_question": review_data.get('original_question', ''),
            "original_answer": review_data.get('original_answer', ''),
            "status": review_data.get('status', 'pending'),
            "modified_question": review_data.get('modified_question', ''),
            "modified_answer": review_data.get('modified_answer', ''),
            "issue_type": review_data.get('issue_type', ''),
            "comment": review_data.get('comment', ''),
            "reviewer": review_data.get('reviewer', 'default'),
            "review_time": datetime.now().isoformat()
        }

        if existing_index is not None:
            reviews[existing_index] = review
        else:
            reviews.append(review)

        self._save_reviews(reviews)
        return review

    def get_review_stats(self) -> Dict[str, int]:
        """Return the total number of reviews and a count per known status."""
        reviews = self.get_all_reviews()
        counts = dict.fromkeys(("correct", "incorrect", "needs_modification", "pending"), 0)
        for r in reviews:
            status = r.get('status')
            if status in counts:
                counts[status] += 1
        return {"total": len(reviews), **counts}

    def export_reviews(self, output_path: Optional[str] = None) -> str:
        """Write a copy of all review records to *output_path*.

        Args:
            output_path: Destination file. If ``None``, a timestamped
                ``reviews_export_YYYYmmdd_HHMMSS.json`` is created in the
                current directory.

        Returns:
            The path of the exported file.
        """
        reviews = self.get_all_reviews()
        if output_path is None:
            output_path = f"./reviews_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(reviews, f, ensure_ascii=False, indent=2)
        return output_path

    # ============== Helpers ==============

    def _list_dirs(self, path: Path) -> List[str]:
        """Return the sorted names of non-hidden sub-directories (``[]`` if *path* is missing)."""
        if not path.exists():
            return []
        # Sorted so that navigation order is deterministic; iterdir() order is arbitrary.
        return sorted(d.name for d in path.iterdir() if d.is_dir() and not d.name.startswith('.'))

    def _list_files(self, path: Path) -> List[str]:
        """Return the sorted names of non-hidden regular files (``[]`` if *path* is missing)."""
        if not path.exists():
            return []
        return sorted(f.name for f in path.iterdir() if f.is_file() and not f.name.startswith('.'))
# ============== Global instance ==============
# Shared module-level manager using the default "./dataset" and "./reviews"
# paths, resolved against the current working directory. Constructing it at
# import time creates the reviews directory and an empty reviews.json if they
# are missing.
data_manager = DataManager(dataset_path="./dataset", reviews_path="./reviews")