File size: 11,318 Bytes
f7f0207 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
import random
import sys
import unittest
import warnings
from os import environ
from datasets import Dataset, DatasetDict
from mmengine.config import read_base
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
warnings.filterwarnings('ignore', category=DeprecationWarning)
def reload_datasets():
    """Re-import all dataset configs and return them as a single flat list.

    Every module already cached under the ``configs.datasets`` prefix is
    dropped from ``sys.modules`` first, so the imports below are executed
    again rather than served from the import cache.
    NOTE(review): presumably the config modules read the ``DATASET_SOURCE``
    environment variable at import time, which is why a fresh import is
    forced -- confirm against the config files.

    Returns:
        list: Concatenation of every local variable of this function whose
        name ends with ``_datasets``.
    """
    # Materialise the names first: sys.modules must not change size while it
    # is being iterated.
    stale = [name for name in sys.modules if name.startswith('configs.datasets')]
    for name in stale:
        del sys.modules[name]

    with read_base():
        from configs.datasets.ceval.ceval_gen import ceval_datasets
        from configs.datasets.gsm8k.gsm8k_gen import gsm8k_datasets
        from configs.datasets.cmmlu.cmmlu_gen import cmmlu_datasets
        from configs.datasets.ARC_c.ARC_c_gen import ARC_c_datasets
        from configs.datasets.ARC_e.ARC_e_gen import ARC_e_datasets
        from configs.datasets.humaneval.humaneval_gen import humaneval_datasets
        from configs.datasets.humaneval.humaneval_repeat10_gen_8e312c import humaneval_datasets as humaneval_repeat10_datasets
        from configs.datasets.race.race_ppl import race_datasets
        from configs.datasets.commonsenseqa.commonsenseqa_gen import commonsenseqa_datasets
        from configs.datasets.mmlu.mmlu_gen import mmlu_datasets
        from configs.datasets.strategyqa.strategyqa_gen import strategyqa_datasets
        from configs.datasets.bbh.bbh_gen import bbh_datasets
        from configs.datasets.Xsum.Xsum_gen import Xsum_datasets
        from configs.datasets.winogrande.winogrande_gen import winogrande_datasets
        from configs.datasets.winogrande.winogrande_ll import winogrande_datasets as winogrande_ll_datasets
        from configs.datasets.winogrande.winogrande_5shot_ll_252f01 import winogrande_datasets as winogrande_5shot_ll_datasets
        from configs.datasets.obqa.obqa_gen import obqa_datasets
        from configs.datasets.obqa.obqa_ppl_6aac9e import obqa_datasets as obqa_ppl_datasets
        from configs.datasets.agieval.agieval_gen import agieval_datasets as agieval_v2_datasets
        # from configs.datasets.agieval.agieval_gen_a0c741 import agieval_datasets as agieval_v1_datasets
        from configs.datasets.siqa.siqa_gen import siqa_datasets as siqa_v2_datasets
        from configs.datasets.siqa.siqa_gen_18632c import siqa_datasets as siqa_v3_datasets
        from configs.datasets.siqa.siqa_ppl_42bc6e import siqa_datasets as siqa_ppl_datasets
        from configs.datasets.storycloze.storycloze_gen import storycloze_datasets
        from configs.datasets.storycloze.storycloze_ppl import storycloze_datasets as storycloze_ppl_datasets
        from configs.datasets.summedits.summedits_gen import summedits_datasets as summedits_v2_datasets
        from configs.datasets.hellaswag.hellaswag_gen import hellaswag_datasets as hellaswag_v2_datasets
        from configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import hellaswag_datasets as hellaswag_ice_datasets
        from configs.datasets.hellaswag.hellaswag_ppl_9dbb12 import hellaswag_datasets as hellaswag_v1_datasets
        from configs.datasets.hellaswag.hellaswag_ppl_a6e128 import hellaswag_datasets as hellaswag_v3_datasets
        from configs.datasets.mbpp.mbpp_gen import mbpp_datasets as mbpp_v1_datasets
        from configs.datasets.mbpp.mbpp_passk_gen_830460 import mbpp_datasets as mbpp_v2_datasets
        from configs.datasets.mbpp.sanitized_mbpp_gen_830460 import sanitized_mbpp_datasets
        from configs.datasets.nq.nq_gen import nq_datasets
        from configs.datasets.lcsts.lcsts_gen import lcsts_datasets
        from configs.datasets.math.math_gen import math_datasets
        from configs.datasets.piqa.piqa_gen import piqa_datasets as piqa_v2_datasets
        from configs.datasets.piqa.piqa_ppl import piqa_datasets as piqa_v1_datasets
        from configs.datasets.piqa.piqa_ppl_0cfff2 import piqa_datasets as piqa_v3_datasets
        from configs.datasets.lambada.lambada_gen import lambada_datasets
        from configs.datasets.tydiqa.tydiqa_gen import tydiqa_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_gen import GaokaoBench_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_mixed import GaokaoBench_datasets as GaokaoBench_mixed_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_no_subjective_gen_4c31db import GaokaoBench_datasets as GaokaoBench_no_subjective_datasets
        from configs.datasets.triviaqa.triviaqa_gen import triviaqa_datasets
        from configs.datasets.triviaqa.triviaqa_wiki_1shot_gen_20a989 import triviaqa_datasets as triviaqa_wiki_1shot_datasets
        from configs.datasets.CLUE_cmnli.CLUE_cmnli_gen import cmnli_datasets
        from configs.datasets.CLUE_cmnli.CLUE_cmnli_ppl import cmnli_datasets as cmnli_ppl_datasets
        from configs.datasets.CLUE_ocnli.CLUE_ocnli_gen import ocnli_datasets
        from configs.datasets.ceval.ceval_clean_ppl import ceval_datasets as ceval_clean_datasets
        from configs.datasets.ARC_c.ARC_c_clean_ppl import ARC_c_datasets as ARC_c_clean_datasets
        from configs.datasets.mmlu.mmlu_clean_ppl import mmlu_datasets as mmlu_clean_datasets
        from configs.datasets.hellaswag.hellaswag_clean_ppl import hellaswag_datasets as hellaswag_clean_datasets
        from configs.datasets.FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_gen import ocnli_fc_datasets

    # Gather the imported lists by naming convention. None of the helper
    # locals used here may end in "_datasets", or they would be picked up too.
    merged = []
    for key, value in tuple(locals().items()):
        if key.endswith('_datasets'):
            merged = merged + value
    return merged
def load_datasets_conf(source):
    """Select *source* as the dataset source and rebuild the config list.

    Args:
        source: Value written to the ``DATASET_SOURCE`` environment variable
            before the dataset configs are re-imported.

    Returns:
        The value returned by ``reload_datasets()``.
    """
    # The variable must be set before the configs are re-imported.
    environ['DATASET_SOURCE'] = source
    return reload_datasets()
def load_datasets(source, conf):
    """Load the dataset described by *conf* after selecting *source*.

    ``conf['type'].load`` is called with a keyword set chosen by the first
    of these keys found in *conf*: ``lang``, ``setting_name``, ``name``,
    ``local_mode``. When none is present, ``load(path=...)`` is tried and,
    if anything in that attempt raises, ``load(**conf)`` is used instead.

    Args:
        source: Value written to the ``DATASET_SOURCE`` environment variable.
        conf: Mapping holding at least ``type`` (an object exposing
            ``load``) plus the keys required by the selected call.

    Returns:
        Whatever the selected ``load`` call returns.
    """
    environ['DATASET_SOURCE'] = source

    # (trigger key, keys forwarded as keyword arguments); first match wins,
    # so the order of this table is significant.
    keyword_plans = (
        ('lang', ('path', 'lang')),
        ('setting_name', ('path', 'name', 'setting_name')),
        ('name', ('path', 'name')),
        ('local_mode', ('path', 'local_mode')),
    )
    for trigger, forwarded in keyword_plans:
        if trigger in conf:
            loader = conf['type'].load
            return loader(**{key: conf[key] for key in forwarded})

    # No trigger key: try the path-only call first. The lookups stay inside
    # the ``try`` so that a conf without 'path' also reaches the fallback.
    try:
        return conf['type'].load(path=conf['path'])
    except Exception:
        return conf['type'].load(**conf)
def clean_string(value):
    """Normalise the whitespace of *value* when it is a string.

    Leading and trailing whitespace is removed and every internal run of
    whitespace is collapsed to one space. Non-string values are returned
    unchanged.
    """
    if not isinstance(value, str):
        return value
    # split() with no argument discards empty fields, which both trims the
    # ends and merges consecutive whitespace characters.
    words = value.split()
    return ' '.join(words)
class TestingLocalDatasets(unittest.TestCase):
    """Smoke test reporting which configured datasets load from 'Local'."""

    def test_datasets(self):
        """Try to load every local dataset config and print a summary.

        Each config is loaded on a worker thread. The outcome of every load
        is recorded as a success or failure message and a report is printed
        at the end. Only loadability is checked; ``_check_data`` is not
        invoked.

        NOTE(review): the test only prints failures and never asserts, so it
        passes even when datasets cannot be loaded -- confirm this is
        intended.
        """
        # Load the dataset configs for the 'Local' source.
        local_datasets_conf = load_datasets_conf('Local')

        # Messages for the datasets that loaded / failed to load.
        successful_comparisons = []
        failed_comparisons = []

        def compare_datasets(local_conf):
            """Load one dataset and return a ``(status, message)`` pair."""
            local_path_name = f"{local_conf.get('path')}/{local_conf.get('name', '')}\t{local_conf.get('lang', '')}"
            # The load happens exactly once and inside the ``try``; an
            # exception raised outside it would propagate through
            # ``future.result()`` and abort the whole run instead of being
            # recorded as a failure.
            try:
                load_datasets('Local', local_conf)
            except Exception as exception:
                # Keep the reason so the final report is actionable.
                return 'failure', f"can't load {local_path_name}: {exception!r}"
            return 'success', f'{local_path_name}'

        with ThreadPoolExecutor(16) as executor:
            futures = [
                executor.submit(compare_datasets, local_conf)
                for local_conf in local_datasets_conf
            ]

            for future in tqdm(as_completed(futures), total=len(futures)):
                result, message = future.result()
                if result == 'success':
                    successful_comparisons.append(message)
                else:
                    failed_comparisons.append(message)

        # Print the test summary.
        total_datasets = len(local_datasets_conf)
        print(f"All {total_datasets} datasets")
        print(f"OK {len(successful_comparisons)} datasets")
        for success in successful_comparisons:
            print(f"  {success}")
        print(f"Fail {len(failed_comparisons)} datasets")
        for failure in failed_comparisons:
            print(f"  {failure}")
def _check_data(ms_dataset: Dataset | DatasetDict,
                oc_dataset: Dataset | DatasetDict,
                sample_size):
    """Check that two datasets share structure and sampled content.

    For a ``DatasetDict`` the key sets must match and every split is checked
    recursively. For a ``Dataset`` the column names (as sets) and row counts
    must match; then up to *sample_size* randomly chosen rows are compared
    column by column after converting each value to ``str`` and normalising
    whitespace with ``clean_string``.

    Args:
        ms_dataset: First dataset; it supplies the column order and the
            sampling range.
        oc_dataset: Second dataset, which must have the same type as
            *ms_dataset*.
        sample_size: Maximum number of rows to compare per ``Dataset``.

    Raises:
        AssertionError: If types, keys, columns, row counts or any sampled
            value differ.
        ValueError: If *oc_dataset* is neither a ``DatasetDict`` nor a
            ``Dataset``.
    """
    assert type(ms_dataset) is type(
        oc_dataset
    ), f'Dataset type not match: {type(ms_dataset)} != {type(oc_dataset)}'

    # match DatasetDict
    if isinstance(oc_dataset, DatasetDict):
        assert ms_dataset.keys() == oc_dataset.keys(
        ), f'DatasetDict not match: {ms_dataset.keys()} != {oc_dataset.keys()}'

        for key in ms_dataset:
            _check_data(ms_dataset[key], oc_dataset[key], sample_size=sample_size)

    elif isinstance(oc_dataset, Dataset):
        # match by cols
        assert set(ms_dataset.column_names) == set(
            oc_dataset.column_names
        ), f'Column names do not match: {ms_dataset.column_names} != {oc_dataset.column_names}'

        # Check that the number of rows is the same
        assert len(ms_dataset) == len(
            oc_dataset
        ), f'Number of rows do not match: {len(ms_dataset)} != {len(oc_dataset)}'

        # Randomly sample indices
        sample_indices = random.sample(range(len(ms_dataset)),
                                       min(sample_size, len(ms_dataset)))

        for idx in sample_indices:
            # Fetch each sampled row once instead of re-reading a whole
            # column for every (row, column) pair.
            ms_row = ms_dataset[idx]
            oc_row = oc_dataset[idx]
            for col in ms_dataset.column_names:
                ms_value = clean_string(str(ms_row[col]))
                oc_value = clean_string(str(oc_row[col]))
                if ms_value != oc_value:
                    # Dump both rows before failing to ease debugging.
                    print(f"Assertion failed for column '{col}', index {idx}")
                    print(f"ms_data: {ms_row}")
                    print(f'oc_data: {oc_row}')
                    print(f'ms_value: {ms_value} ({type(ms_value)})')
                    print(f'oc_value: {oc_value} ({type(oc_value)})')
                    # Raised explicitly so the check survives ``python -O``.
                    raise AssertionError(
                        f"Value mismatch in column '{col}', index {idx}: {ms_value} != {oc_value}"
                    )
    else:
        raise ValueError(f'Datasets type not supported {type(ms_dataset)}')
# Script entry point: run the unittest test cases defined in this module.
if __name__ == '__main__':
    # Module-level value; no active code in this file reads it.
    # NOTE(review): presumably the row-sample count meant for `_check_data`
    # -- confirm before removing.
    sample_size = 100
    unittest.main()
|