File size: 11,318 Bytes
f7f0207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import random
import sys
import unittest
import warnings
from os import environ

from datasets import Dataset, DatasetDict
from mmengine.config import read_base
from tqdm import tqdm

from concurrent.futures import ThreadPoolExecutor, as_completed

warnings.filterwarnings('ignore', category=DeprecationWarning)


def reload_datasets():
    """Purge cached ``configs.datasets`` modules and re-import every dataset config.

    The dataset config modules read the ``DATASET_SOURCE`` environment variable
    at import time, so they must be evicted from ``sys.modules`` and imported
    again for a new source setting to take effect.

    Returns:
        list: concatenation of every imported ``*_datasets`` config list.
    """
    # Evict every previously imported dataset config module so the imports
    # below re-execute instead of hitting the module cache.
    modules_to_remove = [
        module_name for module_name in sys.modules
        if module_name.startswith('configs.datasets')
    ]

    for module_name in modules_to_remove:
        del sys.modules[module_name]

    # All of these imports bind a ``*_datasets`` list into locals(); the
    # trailing sum() relies on that naming convention to collect them.
    with read_base():
        from configs.datasets.ceval.ceval_gen import ceval_datasets
        from configs.datasets.gsm8k.gsm8k_gen import gsm8k_datasets
        from configs.datasets.cmmlu.cmmlu_gen import cmmlu_datasets
        from configs.datasets.ARC_c.ARC_c_gen import ARC_c_datasets
        from configs.datasets.ARC_e.ARC_e_gen import ARC_e_datasets
        from configs.datasets.humaneval.humaneval_gen import humaneval_datasets
        from configs.datasets.humaneval.humaneval_repeat10_gen_8e312c import humaneval_datasets as humaneval_repeat10_datasets
        from configs.datasets.race.race_ppl import race_datasets
        from configs.datasets.commonsenseqa.commonsenseqa_gen import commonsenseqa_datasets
        
        from configs.datasets.mmlu.mmlu_gen import mmlu_datasets
        from configs.datasets.strategyqa.strategyqa_gen import strategyqa_datasets
        from configs.datasets.bbh.bbh_gen import bbh_datasets
        from configs.datasets.Xsum.Xsum_gen import Xsum_datasets
        from configs.datasets.winogrande.winogrande_gen import winogrande_datasets
        from configs.datasets.winogrande.winogrande_ll import winogrande_datasets as winogrande_ll_datasets
        from configs.datasets.winogrande.winogrande_5shot_ll_252f01 import winogrande_datasets as winogrande_5shot_ll_datasets
        from configs.datasets.obqa.obqa_gen import obqa_datasets
        from configs.datasets.obqa.obqa_ppl_6aac9e import obqa_datasets as obqa_ppl_datasets
        from configs.datasets.agieval.agieval_gen import agieval_datasets as agieval_v2_datasets
        # from configs.datasets.agieval.agieval_gen_a0c741 import agieval_datasets as agieval_v1_datasets
        from configs.datasets.siqa.siqa_gen import siqa_datasets as siqa_v2_datasets
        from configs.datasets.siqa.siqa_gen_18632c import siqa_datasets as siqa_v3_datasets
        from configs.datasets.siqa.siqa_ppl_42bc6e import siqa_datasets as siqa_ppl_datasets
        from configs.datasets.storycloze.storycloze_gen import storycloze_datasets
        from configs.datasets.storycloze.storycloze_ppl import storycloze_datasets as storycloze_ppl_datasets
        from configs.datasets.summedits.summedits_gen import summedits_datasets as summedits_v2_datasets
        
        from configs.datasets.hellaswag.hellaswag_gen import hellaswag_datasets as hellaswag_v2_datasets
        from configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import hellaswag_datasets as hellaswag_ice_datasets
        from configs.datasets.hellaswag.hellaswag_ppl_9dbb12 import hellaswag_datasets as hellaswag_v1_datasets
        from configs.datasets.hellaswag.hellaswag_ppl_a6e128 import hellaswag_datasets as hellaswag_v3_datasets
        from configs.datasets.mbpp.mbpp_gen import mbpp_datasets as mbpp_v1_datasets
        from configs.datasets.mbpp.mbpp_passk_gen_830460 import mbpp_datasets as mbpp_v2_datasets
        from configs.datasets.mbpp.sanitized_mbpp_gen_830460 import sanitized_mbpp_datasets
        from configs.datasets.nq.nq_gen import nq_datasets
        from configs.datasets.lcsts.lcsts_gen import lcsts_datasets
        from configs.datasets.math.math_gen import math_datasets
        from configs.datasets.piqa.piqa_gen import piqa_datasets as piqa_v2_datasets
        from configs.datasets.piqa.piqa_ppl import piqa_datasets as piqa_v1_datasets
        from configs.datasets.piqa.piqa_ppl_0cfff2 import piqa_datasets as piqa_v3_datasets
        from configs.datasets.lambada.lambada_gen import lambada_datasets
        from configs.datasets.tydiqa.tydiqa_gen import tydiqa_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_gen import GaokaoBench_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_mixed import GaokaoBench_datasets as GaokaoBench_mixed_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_no_subjective_gen_4c31db import GaokaoBench_datasets as GaokaoBench_no_subjective_datasets
        from configs.datasets.triviaqa.triviaqa_gen import triviaqa_datasets
        from configs.datasets.triviaqa.triviaqa_wiki_1shot_gen_20a989 import triviaqa_datasets as triviaqa_wiki_1shot_datasets
        
        from configs.datasets.CLUE_cmnli.CLUE_cmnli_gen import cmnli_datasets
        from configs.datasets.CLUE_cmnli.CLUE_cmnli_ppl import cmnli_datasets as cmnli_ppl_datasets
        from configs.datasets.CLUE_ocnli.CLUE_ocnli_gen import ocnli_datasets
        
        from configs.datasets.ceval.ceval_clean_ppl import ceval_datasets as ceval_clean_datasets
        from configs.datasets.ARC_c.ARC_c_clean_ppl import ARC_c_datasets as ARC_c_clean_datasets
        from configs.datasets.mmlu.mmlu_clean_ppl import mmlu_datasets as mmlu_clean_datasets
        from configs.datasets.hellaswag.hellaswag_clean_ppl import hellaswag_datasets as hellaswag_clean_datasets
        from configs.datasets.FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_gen import ocnli_fc_datasets

    # Flatten every ``*_datasets`` list bound above into a single list.
    return sum((v for k, v in locals().items() if k.endswith('_datasets')), [])


def load_datasets_conf(source):
    """Select the dataset source backend and return the reloaded configs.

    Args:
        source: value written to the ``DATASET_SOURCE`` environment variable
            (e.g. ``'Local'`` or ``'ModelScope'``) before reloading.

    Returns:
        list: every dataset config gathered by :func:`reload_datasets`.
    """
    environ['DATASET_SOURCE'] = source
    return reload_datasets()


def load_datasets(source, conf):
    """Load one dataset described by config dict *conf* from *source*.

    Sets the ``DATASET_SOURCE`` environment variable, then dispatches to the
    config's ``type.load`` with whichever optional keys the config carries,
    checked in priority order: ``lang``, then ``setting_name`` (with ``name``),
    then ``name``, then ``local_mode``. With none of those present it tries a
    plain path load and, as a last resort, forwards the whole config as
    keyword arguments.

    Args:
        source: dataset source backend name.
        conf: dataset config dict; must contain ``type`` and ``path``.

    Returns:
        The object returned by ``conf['type'].load``.
    """
    environ['DATASET_SOURCE'] = source
    loader = conf['type'].load
    path = conf['path']

    if 'lang' in conf:
        return loader(path=path, lang=conf['lang'])
    if 'setting_name' in conf:
        return loader(path=path,
                      name=conf['name'],
                      setting_name=conf['setting_name'])
    if 'name' in conf:
        return loader(path=path, name=conf['name'])
    if 'local_mode' in conf:
        return loader(path=path, local_mode=conf['local_mode'])

    try:
        return loader(path=path)
    except Exception:
        # Fallback: some loaders want the full config as keyword arguments.
        return loader(**conf)


def clean_string(value):
    """Normalize string data for comparison.

    Collapses every run of whitespace to a single space and strips leading and
    trailing whitespace. Non-string inputs are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    return ' '.join(value.split())


class TestingLocalDatasets(unittest.TestCase):
    """Smoke test: every locally-sourced dataset config must load successfully."""

    def test_datasets(self):
        # Gather all dataset configs with the 'Local' source backend.
        local_datasets_conf = load_datasets_conf('Local')

        # Messages for datasets that loaded / failed to load.
        successful_comparisons = []
        failed_comparisons = []

        def compare_datasets(local_conf):
            """Try to load one dataset; return ('success'|'failure', message)."""
            local_path_name = f"{local_conf.get('path')}/{local_conf.get('name', '')}\t{local_conf.get('lang', '')}"
            try:
                # Bug fix: the original loaded the dataset once OUTSIDE this
                # try block (and again inside), so a load error escaped the
                # worker and crashed future.result() instead of being recorded
                # as a failure. Load exactly once, inside the try.
                load_datasets('Local', local_conf)
                return 'success', f'{local_path_name}'
            except Exception:
                return 'failure', f'can\'t load {local_path_name}'

        # Fan the loads out over a small thread pool; loading is I/O-bound.
        with ThreadPoolExecutor(16) as executor:
            futures = {
                executor.submit(compare_datasets, local_conf): local_conf
                for local_conf in local_datasets_conf
            }

            for future in tqdm(as_completed(futures), total=len(futures)):
                result, message = future.result()
                if result == 'success':
                    successful_comparisons.append(message)
                else:
                    failed_comparisons.append(message)

        # Print a summary of the run.
        total_datasets = len(local_datasets_conf)
        print(f"All {total_datasets} datasets")
        print(f"OK {len(successful_comparisons)} datasets")
        for success in successful_comparisons:
            print(f"  {success}")
        print(f"Fail {len(failed_comparisons)} datasets")
        for failure in failed_comparisons:
            print(f"  {failure}")


def _check_data(ms_dataset: Dataset | DatasetDict,
                oc_dataset: Dataset | DatasetDict,
                sample_size: int) -> None:
    """Recursively assert that two datasets hold identical content.

    For ``DatasetDict`` inputs the check recurses split by split; for
    ``Dataset`` inputs it asserts matching columns and row counts, then
    compares ``sample_size`` randomly sampled rows value by value (after
    :func:`clean_string` whitespace normalization of the stringified values).

    Args:
        ms_dataset: dataset loaded from the ModelScope source.
        oc_dataset: dataset loaded from the local/OpenCompass source.
        sample_size: number of rows to sample per split for value comparison.

    Raises:
        AssertionError: on any type, key, column, length or value mismatch.
        ValueError: if the inputs are neither ``Dataset`` nor ``DatasetDict``.
    """
    # Exact type match is intended here (Dataset vs DatasetDict), so compare
    # with `is` rather than isinstance.
    assert type(ms_dataset) is type(
        oc_dataset
    ), f'Dataset type not match: {type(ms_dataset)} != {type(oc_dataset)}'

    # match DatasetDict: same splits, then recurse into each split
    if isinstance(oc_dataset, DatasetDict):
        assert ms_dataset.keys() == oc_dataset.keys(
        ), f'DatasetDict not match: {ms_dataset.keys()} != {oc_dataset.keys()}'

        for key in ms_dataset.keys():
            _check_data(ms_dataset[key], oc_dataset[key], sample_size=sample_size)

    elif isinstance(oc_dataset, Dataset):
        # match by cols
        assert set(ms_dataset.column_names) == set(
            oc_dataset.column_names
        ), f'Column names do not match: {ms_dataset.column_names} != {oc_dataset.column_names}'

        # Check that the number of rows is the same
        assert len(ms_dataset) == len(
            oc_dataset
        ), f'Number of rows do not match: {len(ms_dataset)} != {len(oc_dataset)}'

        # Randomly sample indices
        sample_indices = random.sample(range(len(ms_dataset)),
                                       min(sample_size, len(ms_dataset)))

        for idx in sample_indices:
            # Perf fix: fetch each sampled ROW once. The original indexed
            # `dataset[col][idx]`, which materializes the entire column on
            # every single value access.
            ms_row = ms_dataset[idx]
            oc_row = oc_dataset[idx]
            for col in ms_dataset.column_names:
                ms_value = clean_string(str(ms_row[col]))
                oc_value = clean_string(str(oc_row[col]))
                try:
                    assert ms_value == oc_value, f"Value mismatch in column '{col}', index {idx}: {ms_value} != {oc_value}"
                except AssertionError as e:
                    # Dump full context before re-raising to aid debugging.
                    print(f"Assertion failed for column '{col}', index {idx}")
                    print(f"ms_data: {ms_row}")
                    print(f'oc_data: {oc_row}')
                    print(f'ms_value: {ms_value} ({type(ms_value)})')
                    print(f'oc_value: {oc_value} ({type(oc_value)})')
                    raise e
    else:
        raise ValueError(f'Datasets type not supported {type(ms_dataset)}')


if __name__ == '__main__':
    # NOTE(review): `sample_size` is only read by the ModelScope-vs-Local
    # comparison path (`_check_data`), which is currently commented out in
    # the test above — the active test run does not consume it.
    sample_size = 100
    unittest.main()