| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import glob |
| | import logging |
| | import os |
| | from dataclasses import dataclass |
| |
|
| | import hydra |
| | from hydra.core.config_store import ConfigStore |
| | from joblib import Parallel, delayed |
| | from omegaconf import MISSING |
| |
|
| | try: |
| | from wds2idx import IndexCreator |
| |
|
| | INDEX_CREATOR_AVAILABLE = True |
| | except (ImportError, ModuleNotFoundError): |
| | INDEX_CREATOR_AVAILABLE = False |
| |
|
# Example command line for this script (Hydra-style `key=value` overrides).
"""
python create_dali_tarred_dataset_index.py \
tar_dir=<path to the directory which contains tarred dataset> \
workers=-1

"""

# Configure the root logger at INFO level so the `logging.info` progress
# message emitted by `main` is shown.
logging.basicConfig(level=logging.INFO)
| |
|
| |
|
@dataclass
class DALITarredIndexConfig:
    """Hydra structured config for building DALI index files from a tarred dataset."""

    # Directory whose `*.tar` files get indexed; MISSING means it has no default
    # and must be supplied (e.g. `tar_dir=...` on the command line).
    tar_dir: str = MISSING
    # Passed to joblib `Parallel(n_jobs=...)` in `main`; -1 uses all available CPUs.
    workers: int = -1
| |
|
| |
|
def process_index_path(tar_paths, index_dir):
    """
    Map each tarfile path to the path of its DALI index file inside `index_dir`.

    The directory part of every tar path is replaced by `index_dir`, and a
    trailing `.tar` extension is replaced by `.index` (names without a `.tar`
    extension simply get `.index` appended).

    Example:
        process_index_path(['/X/Y/Z/audio_0.tar'], '/X/Y/Z/dali_index')
        -> ['/X/Y/Z/dali_index/audio_0.index']

    Args:
        tar_paths: Iterable of paths to tarfiles.
        index_dir: Directory in which the index files will be placed.

    Returns:
        List of index file paths, in the same order as `tar_paths`.
    """
    suffix = '.tar'
    index_paths = []
    for path in tar_paths:
        filename = os.path.basename(path)
        # Strip only the trailing extension: `str.replace` would also rewrite any
        # '.tar' occurring earlier in the name (e.g. 'audio.tarred_0.tar').
        if filename.endswith(suffix):
            filename = filename[: -len(suffix)]
        index_paths.append(os.path.join(index_dir, filename + '.index'))
    return index_paths
| |
|
| |
|
def build_index(tarpath, indexfile):
    """
    Create the DALI index file for a single tarfile.

    Args:
        tarpath: Path to the source tarfile.
        indexfile: Destination path of the index file.
    """
    creator = IndexCreator(tarpath, indexfile)
    # IndexCreator is used as a context manager; the index is built inside it.
    with creator as idx:
        idx.create_index()
| |
|
| |
|
@hydra.main(config_path=None, config_name='index_config', version_base="1.1")
def main(cfg: DALITarredIndexConfig):
    """
    Build a DALI index file for every `*.tar` file directly inside `cfg.tar_dir`.

    Index files are written to `<cfg.tar_dir>/dali_index/<name>.index`, one per
    tarfile, using `cfg.workers` joblib workers.

    Exits with status 1 if `wds2idx` (shipped with NVIDIA DALI) is not importable.
    """
    if not INDEX_CREATOR_AVAILABLE:
        logging.error("`wds2idx` is not installed. Please install NVIDIA DALI >= 1.11")
        # `exit` is injected by the `site` module and is not guaranteed to exist
        # (e.g. under `python -S`); raising SystemExit is the portable equivalent.
        raise SystemExit(1)

    # Escape the directory so glob metacharacters in it ('[', '*', '?') are matched
    # literally; sort so files are processed in a deterministic order.
    tar_files = sorted(glob.glob(os.path.join(glob.escape(cfg.tar_dir), "*.tar")))
    if not tar_files:
        logging.warning("No `.tar` files found in %s; no index files were created.", cfg.tar_dir)
        return

    index_dir = os.path.join(cfg.tar_dir, "dali_index")
    # exist_ok=True already tolerates a pre-existing directory.
    os.makedirs(index_dir, exist_ok=True)

    index_paths = process_index_path(tar_files, index_dir)

    # NOTE(review): joblib's `verbose` is a verbosity level, not a task count; passing the
    # number of files keeps the original behaviour (more files -> more progress output).
    with Parallel(n_jobs=cfg.workers, verbose=len(tar_files)) as parallel:
        _ = parallel(delayed(build_index)(tarpath, indexfile) for tarpath, indexfile in zip(tar_files, index_paths))

    logging.info("Finished constructing index files !")
| |
|
| |
|
# Register the structured config under the name `@hydra.main(config_name=...)` looks up.
_config_store = ConfigStore.instance()
_config_store.store(name='index_config', node=DALITarredIndexConfig)

if __name__ == '__main__':
    # Hydra builds `cfg` from the registered config plus command-line overrides.
    main()
| |
|