Fill-Mask
Transformers
code
File size: 7,820 Bytes
8193465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import logging
import os
from argparse import ArgumentParser
from collections.abc import Generator
from shutil import rmtree

import datasets.config
from datasets.builder import DatasetBuilder
from datasets.commands import BaseDatasetsCLICommand
from datasets.download.download_manager import DownloadMode
from datasets.info import DatasetInfosDict
from datasets.load import dataset_module_factory, get_dataset_builder_class
from datasets.utils.info_utils import VerificationMode
from datasets.utils.logging import ERROR, get_logger


logger = get_logger(__name__)


def _test_command_factory(args):
    return TestCommand(
        args.dataset,
        args.name,
        args.cache_dir,
        args.data_dir,
        args.all_configs,
        args.save_info or args.save_infos,
        args.ignore_verifications,
        args.force_redownload,
        args.clear_cache,
        args.num_proc,
    )


class TestCommand(BaseDatasetsCLICommand):
    __test__ = False  # to tell pytest it's not a test class

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        test_parser = parser.add_parser("test", help="Test dataset loading.")
        test_parser.add_argument("--name", type=str, default=None, help="Dataset processing name")
        test_parser.add_argument(
            "--cache_dir",
            type=str,
            default=None,
            help="Cache directory where the datasets are stored.",
        )
        test_parser.add_argument(
            "--data_dir",
            type=str,
            default=None,
            help="Can be used to specify a manual directory to get the files from.",
        )
        test_parser.add_argument("--all_configs", action="store_true", help="Test all dataset configurations")
        test_parser.add_argument(
            "--save_info", action="store_true", help="Save the dataset infos in the dataset card (README.md)"
        )
        test_parser.add_argument(
            "--ignore_verifications",
            action="store_true",
            help="Run the test without checksums and splits checks.",
        )
        test_parser.add_argument("--force_redownload", action="store_true", help="Force dataset redownload")
        test_parser.add_argument(
            "--clear_cache",
            action="store_true",
            help="Remove downloaded files and cached datasets after each config test",
        )
        test_parser.add_argument("--num_proc", type=int, default=None, help="Number of processes")
        # aliases
        test_parser.add_argument("--save_infos", action="store_true", help="alias to save_info")
        test_parser.add_argument("dataset", type=str, help="Name of the dataset to download")
        test_parser.set_defaults(func=_test_command_factory)

    def __init__(
        self,
        dataset: str,
        name: str,
        cache_dir: str,
        data_dir: str,
        all_configs: bool,
        save_infos: bool,
        ignore_verifications: bool,
        force_redownload: bool,
        clear_cache: bool,
        num_proc: int,
    ):
        self._dataset = dataset
        self._name = name
        self._cache_dir = cache_dir
        self._data_dir = data_dir
        self._all_configs = all_configs
        self._save_infos = save_infos
        self._ignore_verifications = ignore_verifications
        self._force_redownload = force_redownload
        self._clear_cache = clear_cache
        self._num_proc = num_proc
        if clear_cache and not cache_dir:
            print(
                "When --clear_cache is used, specifying a cache directory is mandatory.\n"
                "The 'download' folder of the cache directory and the dataset builder cache will be deleted after each configuration test.\n"
                "Please provide a --cache_dir that will be used to test the dataset."
            )
            exit(1)
        if save_infos:
            self._ignore_verifications = True

    def run(self):
        logging.getLogger("filelock").setLevel(ERROR)
        if self._name is not None and self._all_configs:
            print("Both parameters `config` and `all_configs` can't be used at once.")
            exit(1)
        path, config_name = self._dataset, self._name
        module = dataset_module_factory(path)
        builder_cls = get_dataset_builder_class(module)
        n_builders = len(builder_cls.BUILDER_CONFIGS) if self._all_configs and builder_cls.BUILDER_CONFIGS else 1

        def get_builders() -> Generator[DatasetBuilder, None, None]:
            if self._all_configs and builder_cls.BUILDER_CONFIGS:
                for i, config in enumerate(builder_cls.BUILDER_CONFIGS):
                    if "config_name" in module.builder_kwargs:
                        yield builder_cls(
                            cache_dir=self._cache_dir,
                            data_dir=self._data_dir,
                            **module.builder_kwargs,
                        )
                    else:
                        yield builder_cls(
                            config_name=config.name,
                            cache_dir=self._cache_dir,
                            data_dir=self._data_dir,
                            **module.builder_kwargs,
                        )
            else:
                if "config_name" in module.builder_kwargs:
                    yield builder_cls(cache_dir=self._cache_dir, data_dir=self._data_dir, **module.builder_kwargs)
                else:
                    yield builder_cls(
                        config_name=config_name,
                        cache_dir=self._cache_dir,
                        data_dir=self._data_dir,
                        **module.builder_kwargs,
                    )

        for j, builder in enumerate(get_builders()):
            print(f"Testing builder '{builder.config.name}' ({j + 1}/{n_builders})")
            builder._record_infos = os.path.exists(
                os.path.join(builder.get_imported_module_dir(), datasets.config.DATASETDICT_INFOS_FILENAME)
            )  # record checksums only if we need to update a (deprecated) dataset_infos.json
            builder.download_and_prepare(
                download_mode=DownloadMode.REUSE_CACHE_IF_EXISTS
                if not self._force_redownload
                else DownloadMode.FORCE_REDOWNLOAD,
                verification_mode=VerificationMode.NO_CHECKS
                if self._ignore_verifications
                else VerificationMode.ALL_CHECKS,
                num_proc=self._num_proc,
            )
            builder.as_dataset()

            # If save_infos=True, we create the dataset card (README.md)
            # The dataset_infos are saved in the YAML part of the README.md
            # This is to allow the user to upload them on HF afterwards.
            if self._save_infos:
                save_infos_dir = os.path.basename(path) if not os.path.isdir(path) else path
                os.makedirs(save_infos_dir, exist_ok=True)
                DatasetInfosDict(**{builder.config.name: builder.info}).write_to_directory(save_infos_dir)
                print(f"Dataset card saved at {os.path.join(save_infos_dir, datasets.config.REPOCARD_FILENAME)}")

            # If clear_cache=True, the download folder and the dataset builder cache directory are deleted
            if self._clear_cache:
                if os.path.isdir(builder._cache_dir):
                    logger.warning(f"Clearing cache at {builder._cache_dir}")
                    rmtree(builder._cache_dir)
                download_dir = os.path.join(self._cache_dir, datasets.config.DOWNLOADED_DATASETS_DIR)
                if os.path.isdir(download_dir):
                    logger.warning(f"Clearing cache at {download_dir}")
                    rmtree(download_dir)

        print("Test successful.")