|
|
import random |
|
|
import sys |
|
|
import unittest |
|
|
import warnings |
|
|
from os import environ |
|
|
|
|
|
from datasets import Dataset, DatasetDict |
|
|
from mmengine.config import read_base |
|
|
from tqdm import tqdm |
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
# Install a filter that ignores warnings of category DeprecationWarning.
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
|
|
|
|
|
|
def reload_datasets():
    """Re-import the dataset config modules and merge their dataset lists.

    Every entry of ``sys.modules`` whose name starts with
    ``configs.datasets`` is evicted first, so the imports below run those
    modules again rather than reusing cached copies (presumably so they pick
    up the current ``DATASET_SOURCE`` environment value -- confirm against
    the config modules).

    Returns:
        list: the concatenation of every local variable of this function
        whose name ends with ``_datasets``.
    """
    stale_names = [
        name for name in sys.modules if name.startswith('configs.datasets')
    ]
    for name in stale_names:
        sys.modules.pop(name)

    with read_base():
        from configs.datasets.ceval.ceval_gen import ceval_datasets
        from configs.datasets.gsm8k.gsm8k_gen import gsm8k_datasets
        from configs.datasets.cmmlu.cmmlu_gen import cmmlu_datasets
        from configs.datasets.ARC_c.ARC_c_gen import ARC_c_datasets
        from configs.datasets.ARC_e.ARC_e_gen import ARC_e_datasets
        from configs.datasets.humaneval.humaneval_gen import humaneval_datasets
        from configs.datasets.humaneval.humaneval_repeat10_gen_8e312c import humaneval_datasets as humaneval_repeat10_datasets
        from configs.datasets.race.race_ppl import race_datasets
        from configs.datasets.commonsenseqa.commonsenseqa_gen import commonsenseqa_datasets

        from configs.datasets.mmlu.mmlu_gen import mmlu_datasets
        from configs.datasets.strategyqa.strategyqa_gen import strategyqa_datasets
        from configs.datasets.bbh.bbh_gen import bbh_datasets
        from configs.datasets.Xsum.Xsum_gen import Xsum_datasets
        from configs.datasets.winogrande.winogrande_gen import winogrande_datasets
        from configs.datasets.winogrande.winogrande_ll import winogrande_datasets as winogrande_ll_datasets
        from configs.datasets.winogrande.winogrande_5shot_ll_252f01 import winogrande_datasets as winogrande_5shot_ll_datasets
        from configs.datasets.obqa.obqa_gen import obqa_datasets
        from configs.datasets.obqa.obqa_ppl_6aac9e import obqa_datasets as obqa_ppl_datasets
        from configs.datasets.agieval.agieval_gen import agieval_datasets as agieval_v2_datasets

        from configs.datasets.siqa.siqa_gen import siqa_datasets as siqa_v2_datasets
        from configs.datasets.siqa.siqa_gen_18632c import siqa_datasets as siqa_v3_datasets
        from configs.datasets.siqa.siqa_ppl_42bc6e import siqa_datasets as siqa_ppl_datasets
        from configs.datasets.storycloze.storycloze_gen import storycloze_datasets
        from configs.datasets.storycloze.storycloze_ppl import storycloze_datasets as storycloze_ppl_datasets
        from configs.datasets.summedits.summedits_gen import summedits_datasets as summedits_v2_datasets

        from configs.datasets.hellaswag.hellaswag_gen import hellaswag_datasets as hellaswag_v2_datasets
        from configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import hellaswag_datasets as hellaswag_ice_datasets
        from configs.datasets.hellaswag.hellaswag_ppl_9dbb12 import hellaswag_datasets as hellaswag_v1_datasets
        from configs.datasets.hellaswag.hellaswag_ppl_a6e128 import hellaswag_datasets as hellaswag_v3_datasets
        from configs.datasets.mbpp.mbpp_gen import mbpp_datasets as mbpp_v1_datasets
        from configs.datasets.mbpp.mbpp_passk_gen_830460 import mbpp_datasets as mbpp_v2_datasets
        from configs.datasets.mbpp.sanitized_mbpp_gen_830460 import sanitized_mbpp_datasets
        from configs.datasets.nq.nq_gen import nq_datasets
        from configs.datasets.lcsts.lcsts_gen import lcsts_datasets
        from configs.datasets.math.math_gen import math_datasets
        from configs.datasets.piqa.piqa_gen import piqa_datasets as piqa_v2_datasets
        from configs.datasets.piqa.piqa_ppl import piqa_datasets as piqa_v1_datasets
        from configs.datasets.piqa.piqa_ppl_0cfff2 import piqa_datasets as piqa_v3_datasets
        from configs.datasets.lambada.lambada_gen import lambada_datasets
        from configs.datasets.tydiqa.tydiqa_gen import tydiqa_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_gen import GaokaoBench_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_mixed import GaokaoBench_datasets as GaokaoBench_mixed_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_no_subjective_gen_4c31db import GaokaoBench_datasets as GaokaoBench_no_subjective_datasets
        from configs.datasets.triviaqa.triviaqa_gen import triviaqa_datasets
        from configs.datasets.triviaqa.triviaqa_wiki_1shot_gen_20a989 import triviaqa_datasets as triviaqa_wiki_1shot_datasets

        from configs.datasets.CLUE_cmnli.CLUE_cmnli_gen import cmnli_datasets
        from configs.datasets.CLUE_cmnli.CLUE_cmnli_ppl import cmnli_datasets as cmnli_ppl_datasets
        from configs.datasets.CLUE_ocnli.CLUE_ocnli_gen import ocnli_datasets

        from configs.datasets.ceval.ceval_clean_ppl import ceval_datasets as ceval_clean_datasets
        from configs.datasets.ARC_c.ARC_c_clean_ppl import ARC_c_datasets as ARC_c_clean_datasets
        from configs.datasets.mmlu.mmlu_clean_ppl import mmlu_datasets as mmlu_clean_datasets
        from configs.datasets.hellaswag.hellaswag_clean_ppl import hellaswag_datasets as hellaswag_clean_datasets
        from configs.datasets.FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_gen import ocnli_fc_datasets

    # Snapshot this function's namespace in one step, before any further
    # helper name is bound; only the *_datasets names matter below.
    bound = list(locals().items())
    dataset_groups = [
        value for key, value in bound if key.endswith('_datasets')
    ]
    return sum(dataset_groups, [])
|
|
|
|
|
|
|
|
def load_datasets_conf(source):
    """Return the dataset configs obtained with ``source`` as the data source.

    ``source`` is written to the ``DATASET_SOURCE`` environment variable
    before the dataset config modules are re-imported via
    ``reload_datasets``.
    """
    environ['DATASET_SOURCE'] = source
    return reload_datasets()
|
|
|
|
|
|
|
|
def load_datasets(source, conf):
    """Load one dataset through the loader class named in its config.

    ``DATASET_SOURCE`` is set to ``source`` in the environment, then
    ``conf['type'].load`` is called with keyword arguments chosen by the
    first of these keys that is present in ``conf``:

    * ``lang``          -> ``path``, ``lang``
    * ``setting_name``  -> ``path``, ``name``, ``setting_name``
    * ``name``          -> ``path``, ``name``
    * ``local_mode``    -> ``path``, ``local_mode``

    If none of them is present, the loader is tried with ``path`` alone
    and, should that raise, once more with the whole config unpacked as
    keyword arguments.

    Returns:
        Whatever the selected ``load`` call returns.
    """
    environ['DATASET_SOURCE'] = source

    # Ordered: the first marker key found in ``conf`` decides the call.
    keyword_sets = (
        ('lang', ('path', 'lang')),
        ('setting_name', ('path', 'name', 'setting_name')),
        ('name', ('path', 'name')),
        ('local_mode', ('path', 'local_mode')),
    )
    for marker, keywords in keyword_sets:
        if marker in conf:
            return conf['type'].load(**{kw: conf[kw] for kw in keywords})

    try:
        return conf['type'].load(path=conf['path'])
    except Exception:
        # Fallback for loaders that cannot work from ``path`` alone.
        return conf['type'].load(**conf)
|
|
|
|
|
|
|
|
def clean_string(value):
    """Normalize the whitespace of a string; pass other values through.

    For ``str`` input, leading and trailing whitespace is removed and each
    interior run of whitespace characters is collapsed to one space.
    Values that are not ``str`` are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    words = value.split()
    return ' '.join(words)
|
|
|
|
|
|
|
|
class TestingLocalDatasets(unittest.TestCase):
    """Checks that each configured dataset loads from the 'Local' source."""

    def test_datasets(self):
        """Load every dataset config concurrently and report the outcome.

        A summary of the datasets that loaded and those that did not is
        printed, after which the test fails if any dataset could not be
        loaded.
        """
        local_datasets_conf = load_datasets_conf('Local')

        successful_comparisons = []
        failed_comparisons = []

        def compare_datasets(local_conf):
            """Try to load one dataset; return a ``(status, label)`` pair."""
            local_path_name = f"{local_conf.get('path')}/{local_conf.get('name', '')}\t{local_conf.get('lang', '')}"
            try:
                # Load exactly once, inside the try, so a broken dataset is
                # recorded as a failure rather than escaping through
                # future.result() and aborting the whole run.
                load_datasets('Local', local_conf)
            except Exception:
                return 'failure', f'can\'t load {local_path_name}'
            return 'success', f'{local_path_name}'

        with ThreadPoolExecutor(16) as executor:
            futures = {
                executor.submit(compare_datasets, local_conf): local_conf
                for local_conf in local_datasets_conf
            }

            for future in tqdm(as_completed(futures), total=len(futures)):
                result, message = future.result()
                if result == 'success':
                    successful_comparisons.append(message)
                else:
                    failed_comparisons.append(message)

        total_datasets = len(local_datasets_conf)
        print(f"All {total_datasets} datasets")
        print(f"OK {len(successful_comparisons)} datasets")
        for success in successful_comparisons:
            print(f" {success}")
        print(f"Fail {len(failed_comparisons)} datasets")
        for failure in failed_comparisons:
            print(f" {failure}")

        # Load errors are now collected instead of raised, so the failure
        # signal has to be restored explicitly once the summary is printed.
        self.assertFalse(
            failed_comparisons,
            f'{len(failed_comparisons)} dataset(s) failed to load')
|
|
|
|
|
|
|
|
def _check_data(ms_dataset: Dataset | DatasetDict,
                oc_dataset: Dataset | DatasetDict,
                sample_size):
    """Assert that two datasets agree in structure and in sampled values.

    Args:
        ms_dataset: First dataset, or dict of splits, to compare.
        oc_dataset: Second dataset; must have exactly the same type as
            ``ms_dataset``.
        sample_size: Upper bound on the number of rows compared per
            ``Dataset``. Rows are drawn at random without replacement.

    Raises:
        AssertionError: If the types, split keys, column names, row counts
            or any sampled value differ. Values are compared as
            whitespace-normalized strings.
        ValueError: If the datasets are neither ``Dataset`` nor
            ``DatasetDict``.
    """
    # Exact type equality (not isinstance): both sides must be the same class.
    assert type(ms_dataset) == type(
        oc_dataset
    ), f'Dataset type not match: {type(ms_dataset)} != {type(oc_dataset)}'

    if isinstance(oc_dataset, DatasetDict):
        assert ms_dataset.keys() == oc_dataset.keys(
        ), f'DatasetDict not match: {ms_dataset.keys()} != {oc_dataset.keys()}'

        # Compare each split recursively.
        for key in ms_dataset.keys():
            _check_data(ms_dataset[key], oc_dataset[key], sample_size=sample_size)

    elif isinstance(oc_dataset, Dataset):
        # Column order is ignored; only the set of names has to match.
        assert set(ms_dataset.column_names) == set(
            oc_dataset.column_names
        ), f'Column names do not match: {ms_dataset.column_names} != {oc_dataset.column_names}'

        assert len(ms_dataset) == len(
            oc_dataset
        ), f'Number of rows do not match: {len(ms_dataset)} != {len(oc_dataset)}'

        sample_indices = random.sample(range(len(ms_dataset)),
                                       min(sample_size, len(ms_dataset)))
        column_names = ms_dataset.column_names

        for idx in sample_indices:
            # Fetch each sampled row once instead of indexing a whole
            # column for every single cell that is compared.
            ms_row = ms_dataset[idx]
            oc_row = oc_dataset[idx]
            for col in column_names:
                ms_value = clean_string(str(ms_row[col]))
                oc_value = clean_string(str(oc_row[col]))
                if ms_value != oc_value:
                    # Dump both rows before failing to ease debugging.
                    print(f"Assertion failed for column '{col}', index {idx}")
                    print(f"ms_data: {ms_row}")
                    print(f'oc_data: {oc_row}')
                    print(f'ms_value: {ms_value} ({type(ms_value)})')
                    print(f'oc_value: {oc_value} ({type(oc_value)})')
                    raise AssertionError(
                        f"Value mismatch in column '{col}', index {idx}: {ms_value} != {oc_value}"
                    )
    else:
        raise ValueError(f'Datasets type not supported {type(ms_dataset)}')
|
|
|
|
|
|
|
|
# Script entry point: run this module's test cases with unittest.
if __name__ == '__main__':
    # NOTE(review): no code visible in this file reads this module-level
    # value (_check_data takes its own sample_size argument) -- confirm
    # whether it is still needed.
    sample_size = 100
    unittest.main()
|
|
|