kailinjiang commited on
Commit
1353b71
·
verified ·
1 Parent(s): b7e509b

Upload benchmark_load.py

Browse files
Files changed (1) hide show
  1. benchmark_load.py +354 -0
benchmark_load.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import pandas as pd
4
+ import warnings
5
+ from typing import Dict, List, Optional, Union
6
+ import hashlib
7
+ from tqdm import tqdm
8
+ import urllib.request
9
+ import json
10
+ import base64
11
+ from PIL import Image
12
+ import io
13
+
14
# Dataset registry: maps each benchmark name to its download URL, the expected
# MD5 checksum of the TSV file, and its task type:
#   'Y/N' = yes/no questions, 'MCQ' = multiple choice,
#   'VQA' = free-form visual QA, 'MT'  = multi-turn dialogue.
DATASET_CONFIG = {
    'MME': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
        'md5': 'b36b43c3f09801f5d368627fb92187c3',
        'type': 'Y/N'
    },
    'HallusionBench': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
        'md5': '0c23ac0dc9ef46832d7a24504f2a0c7c',
        'type': 'Y/N'
    },
    'POPE': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
        'md5': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
        'type': 'Y/N'
    },
    'AMBER': {
        'url': 'https://huggingface.co/datasets/yifanzhang114/AMBER_base64/resolve/main/AMBER.tsv',
        'md5': '970d94c0410916166e0a76ba75da7934',
        'type': 'Y/N'
    },
    'MMBench_DEV_EN': {
        'url': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN.tsv',
        'md5': 'b6caf1133a01c6bb705cf753bb527ed8',
        'type': 'MCQ'
    },
    'SEEDBench2_Plus': {
        'url': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench2_Plus.tsv',
        'md5': 'e32d3216dc4f452b0fe497a52015d1fd',
        'type': 'MCQ'
    },
    'ScienceQA_VAL': {
        'url': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_VAL.tsv',
        'md5': '96320d05e142e585e7204e72affd29f3',
        'type': 'MCQ'
    },
    'MMMU_TEST': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
        'md5': 'c19875d11a2d348d07e5eb4bdf33166d',
        'type': 'MCQ'
    },
    'OCRBench': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
        'md5': 'e953d98a987cc6e26ef717b61260b778',
        'type': 'VQA'
    },
    'MathVista_MINI': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv',
        'md5': 'f199b98e178e5a2a20e7048f5dcb0464',
        'type': 'VQA'
    },
    'MathVision': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision.tsv',
        'md5': '93f6de14f7916e598aa1b7165589831e',
        'type': 'VQA'
    },
    'MMDU': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/MMDU.tsv',
        'md5': '848b635a88a078f49aebcc6e39792061',
        'type': 'MT'
    },
    # NOTE: the key is 'MIA-Bench' but the remote file is 'Mia-Bench.tsv' —
    # the local cache file will be named after the URL, not the key.
    'MIA-Bench': {
        'url': 'https://opencompass.openxlab.space/utils/VLMEval/Mia-Bench.tsv',
        'md5': '0b9de595f4dd40af18a69b94d89aba82',
        'type': 'VQA'
    }
}
82
+
83
class DownloadProgressBar(tqdm):
    """tqdm progress bar adapted to urllib's ``reporthook`` callback signature."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar given (block count, block size, total size)."""
        if tsize is not None:
            self.total = tsize
        downloaded = b * bsize
        # tqdm.update takes a delta, so subtract what was already counted.
        self.update(downloaded - self.n)
89
+
90
def download_file(url: str, filename: str) -> str:
    """Download *url* to *filename* with a progress bar and return *filename*.

    If the primary download fails and the URL points at huggingface.co, a
    hf-mirror.com mirror is tried once (via recursion; the mirror URL no
    longer matches 'huggingface.co', so recursion is bounded at depth 1).

    Raises:
        RuntimeError: when every download attempt fails.  The original
            exception is chained via ``from`` so the root cause is kept
            (the original code raised a bare ``Exception`` and dropped it;
            RuntimeError is a subclass of Exception, so callers that catch
            ``Exception`` are unaffected).
    """
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=filename) as t:
            urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
    except Exception as e:
        warnings.warn(f'下载失败: {e}')
        if 'huggingface.co' not in url:
            raise RuntimeError(f'无法下载 {url}') from e
        # Fall back to the hf-mirror.com mirror for huggingface.co URLs.
        mirror_url = url.replace('huggingface.co', 'hf-mirror.com')
        try:
            return download_file(mirror_url, filename)
        except Exception as mirror_err:
            warnings.warn(f'镜像下载也失败: {mirror_err}')
            raise RuntimeError(f'无法下载 {url}') from mirror_err

    return filename
110
+
111
def md5(file_path: str) -> str:
    """Return the hexadecimal MD5 digest of the file at *file_path*."""
    digest = hashlib.md5()
    # Read in 4 KiB chunks so large files never sit fully in memory.
    with open(file_path, "rb") as fh:
        chunk = fh.read(4096)
        while chunk:
            digest.update(chunk)
            chunk = fh.read(4096)
    return digest.hexdigest()
118
+
119
def decode_base64_to_image_file(base64_str: str, output_path: str):
    """Decode a base64 string (optionally a ``data:...;base64,`` URI) to a file."""
    # Drop a data-URI prefix: the payload sits after the first comma.
    payload = base64_str.split(',')[1] if ',' in base64_str else base64_str
    raw = base64.b64decode(payload)
    with open(output_path, 'wb') as out:
        out.write(raw)
131
+
132
def read_ok(file_path: str) -> bool:
    """Return True when *file_path* exists and has non-zero size."""
    if not os.path.exists(file_path):
        return False
    return os.path.getsize(file_path) > 0
135
+
136
def toliststr(x):
    """Coerce *x* into a list of strings.

    A string becomes a one-element list, a list is stringified
    element-wise, and any other value is wrapped as ``[str(x)]``.
    """
    if isinstance(x, str):
        return [x]
    if isinstance(x, list):
        return list(map(str, x))
    return [str(x)]
144
+
145
class DatasetProcessor:
    """Downloads VLMEval benchmark TSVs, verifies their checksums, and
    extracts embedded base64 images.

    Files are cached under *data_root*; decoded images are written to
    ``<data_root>/images/<dataset_name>/``.
    """

    def __init__(self, data_root: str):
        """Create the processor and ensure the cache directories exist."""
        self.data_root = data_root
        os.makedirs(self.data_root, exist_ok=True)
        self.img_root = os.path.join(self.data_root, 'images')
        os.makedirs(self.img_root, exist_ok=True)

    def download_dataset(self, dataset_name: str) -> str:
        """Download one dataset TSV (or reuse a cached copy) and return its path.

        Raises:
            ValueError: if *dataset_name* is unknown, or the downloaded
                file fails MD5 verification.
        """
        if dataset_name not in DATASET_CONFIG:
            raise ValueError(f"不支持的数据集: {dataset_name}")

        config = DATASET_CONFIG[dataset_name]
        url = config['url']
        file_md5 = config['md5']

        # Cache file is named after the last URL path component.
        file_name = url.split('/')[-1]
        data_path = os.path.join(self.data_root, file_name)

        # Reuse the cached file only when its checksum still matches.
        if os.path.exists(data_path):
            if md5(data_path) == file_md5:
                print(f"✓ 数据集 {dataset_name} 已存在且MD5正确")
                return data_path
            else:
                print(f"⚠ 数据集 {dataset_name} 存在但MD5不匹配,重新下载")

        print(f"正在下载数据集 {dataset_name}...")
        download_file(url, data_path)

        # Verify the freshly downloaded file.
        if md5(data_path) != file_md5:
            raise ValueError(f"数据集 {dataset_name} MD5验证失败")

        print(f"✓ 数据集 {dataset_name} 下载成功")
        return data_path

    def extract_images(self, dataset_name: str, data: pd.DataFrame) -> Dict[str, str]:
        """Decode the base64 'image' column into JPEG files on disk.

        Returns a mapping from sample index (as str, or ``"<index>_<i>"``
        for multi-image rows, 1-based) to the written image path.
        """
        dataset_img_root = os.path.join(self.img_root, dataset_name)
        os.makedirs(dataset_img_root, exist_ok=True)

        image_paths = {}

        if 'image' in data.columns:
            print(f"正在提取 {dataset_name} 的图像...")
            for idx, row in tqdm(data.iterrows(), total=len(data), desc=f"提取 {dataset_name} 图像"):
                index = row['index']
                image_data = row['image']

                # BUG FIX: pd.isna() on a list returns an array, so the
                # original bare `if pd.isna(image_data)` raised ValueError
                # for multi-image rows and the list branch below could
                # never run.  Only NA-check scalar cells.
                if not isinstance(image_data, list) and pd.isna(image_data):
                    continue

                if isinstance(image_data, str) and len(image_data) > 64:
                    # Heuristic: long strings are base64-encoded images.
                    image_path = os.path.join(dataset_img_root, f"{index}.jpg")
                    if not read_ok(image_path):
                        try:
                            decode_base64_to_image_file(image_data, image_path)
                        except Exception as e:
                            print(f"⚠ 解码图像失败 (索引 {index}): {e}")
                            continue
                    image_paths[str(index)] = image_path
                elif isinstance(image_data, list):
                    # Multi-image sample: one file per image, 1-based suffix.
                    for i, img in enumerate(image_data):
                        if isinstance(img, str) and len(img) > 64:
                            image_path = os.path.join(dataset_img_root, f"{index}_{i+1}.jpg")
                            if not read_ok(image_path):
                                try:
                                    decode_base64_to_image_file(img, image_path)
                                except Exception as e:
                                    print(f"⚠ 解码图像失败 (索引 {index}_{i+1}): {e}")
                                    continue
                            image_paths[f"{index}_{i+1}"] = image_path

        print(f"✓ 提取了 {len(image_paths)} 张图像")
        return image_paths

    def process_dataset(self, dataset_name: str) -> Dict:
        """Download, load, and summarize a single dataset.

        Returns a dict with the dataset type, sample/image counts, the
        loaded DataFrame, the image-path map, and up to 3 example samples.
        """
        print(f"\n=== 处理数据集: {dataset_name} ===")

        # Download (or reuse) the TSV.
        data_path = self.download_dataset(dataset_name)

        # Load the tab-separated data.
        data = pd.read_csv(data_path, sep='\t')
        print(f"✓ 加载了 {len(data)} 个样本")

        # Decode embedded images to files.
        image_paths = self.extract_images(dataset_name, data)

        config = DATASET_CONFIG[dataset_name]
        result = {
            'dataset_name': dataset_name,
            'dataset_type': config['type'],
            'total_samples': len(data),
            'image_count': len(image_paths),
            'data': data,
            'image_paths': image_paths,
            'columns': list(data.columns)
        }

        # Attach up to 3 preview samples for quick inspection.
        sample_data = []
        for idx, row in data.head(3).iterrows():
            sample = {
                'index': row['index'],
                'question': row.get('question', 'N/A'),
                'answer': row.get('answer', 'N/A')
            }

            # Include choice columns for MCQ datasets.
            if config['type'] == 'MCQ':
                options = {}
                for col in ['A', 'B', 'C', 'D', 'E']:
                    if col in row and not pd.isna(row[col]):
                        options[col] = row[col]
                if options:
                    sample['options'] = options

            # Include the decoded image path when one exists.
            if str(row['index']) in image_paths:
                sample['image_path'] = image_paths[str(row['index'])]

            sample_data.append(sample)

        result['sample_data'] = sample_data

        return result

    def process_datasets(self, dataset_names: List[str]) -> Dict[str, Dict]:
        """Process several datasets; failed ones map to None in the result.

        Raises:
            ValueError: if the list is empty or contains unknown names.
        """
        if not dataset_names:
            raise ValueError("数据集名称列表不能为空")

        # Validate every name up front so nothing is half-processed.
        invalid_datasets = [name for name in dataset_names if name not in DATASET_CONFIG]
        if invalid_datasets:
            raise ValueError(f"不支持的数据集: {invalid_datasets}")

        print(f"开始处理 {len(dataset_names)} 个数据集: {dataset_names}")

        results = {}

        for dataset_name in dataset_names:
            try:
                result = self.process_dataset(dataset_name)
                results[dataset_name] = result
                print(f"✓ 数据集 {dataset_name} 处理完成")
            except Exception as e:
                # Best-effort batch: record the failure and keep going.
                print(f"✗ 处理数据集 {dataset_name} 失败: {e}")
                results[dataset_name] = None

        return results
309
+
310
def process_vlmeval_datasets(dataset_names: List[str], data_root: str) -> Dict[str, Dict]:
    """
    Entry point: download and process a batch of VLMEval datasets.

    Args:
        dataset_names: names of the datasets to process; supported values:
            - Y/N type: MME, HallusionBench, POPE, AMBER
            - MCQ type: MMBench_DEV_EN, SEEDBench2_Plus, ScienceQA_VAL, MMMU_TEST
            - VQA type: OCRBench, MathVista_MINI, MathVision, MIA-Bench
            - MT type: MMDU
        data_root: root directory where data and images are cached.

    Returns:
        Mapping from dataset name to its processing-result dict
        (None for datasets that failed).
    """
    return DatasetProcessor(data_root).process_datasets(dataset_names)
327
+
328
# Usage examples (downloads data, so only runs as a script).
if __name__ == "__main__":
    # Example 1: process a single dataset.
    print("=== 示例1:处理单个数据集 ===")
    result = process_vlmeval_datasets(['MMBench_DEV_EN'], data_root='/media/raid/workspace/jiangkailin/data_and_ckpt/dataset/cache/vlmeval')

    for dataset_name, dataset_result in result.items():
        if dataset_result:
            print(f"\n数据集: {dataset_name}")
            print(f"类型: {dataset_result['dataset_type']}")
            print(f"样本数: {dataset_result['total_samples']}")
            print(f"图像数: {dataset_result['image_count']}")
            print("前3个样本:")
            for i, sample in enumerate(dataset_result['sample_data'], 1):
                print(f"  样本{i}: 索引={sample['index']}, 问题={sample['question'][:50]}...")

    # # Example 2: process several datasets at once.
    # print("\n=== 示例2:处理多个数据集 ===")
    # datasets = ['MME', 'POPE', 'OCRBench']
    # results = process_vlmeval_datasets(datasets, data_root='/home/jiangkailin/mydisk/iclr26_evoke_dynamic_null_space/cache/vlmeval')

    # print(f"\n处理完成,共处理 {len(results)} 个数据集:")
    # for dataset_name, dataset_result in results.items():
    #     if dataset_result:
    #         print(f"  ✓ {dataset_name}: {dataset_result['total_samples']} 样本, {dataset_result['image_count']} 图像")
    #     else:
    #         print(f"  ✗ {dataset_name}: 处理失败")