File size: 8,229 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass, field
from typing import Optional

import hydra
from convert_to_tarred_audio_dataset import ASRTarredDatasetBuilder, ASRTarredDatasetMetadata
from hydra.core.config_store import ConfigStore
from joblib import Parallel, delayed
from omegaconf import MISSING
from tqdm import tqdm

"""
# Partial Tarred Audio Dataset Creator

## Overview

This script facilitates the creation of tarred and sharded audio datasets from existing tarred manifests. It allows you to select specific shards from a manifest file and then tar them separately. 

This is useful in several scenarios:
- When you only need to process a specific subset of shards (e.g., for debugging or incremental dataset preparation).
- When you want to parallelize shard creation across multiple SLURM jobs to accelerate the dataset generation process and overcome per-job time limits.

## Prerequisites

- Ensure that the `convert_to_tarred_audio_dataset` script is correctly configured and run with the `--only_manifests` flag to generate the necessary manifest files.
- Make sure the paths to the manifest and metadata files are correct and accessible.

## Usage

### Script Execution

To run the script, pass the arguments as Hydra-style `key=value` overrides (no leading `--`):

python partial_convertion_to_tarred_audio_dataset.py \
    tarred_manifest_filepath=<path to the tarred manifest file> \
    output_dir=<output directory for tarred shards> \
    shards_to_tar=<shard IDs to be tarred> \
    num_workers=-1 \
    dataset_metadata_filepath=<dataset metadata YAML filepath>

`tarred_manifest_filepath` is mandatory: it is the path to the tarred manifest file that contains the entries for the shards you want to process. Every other argument is optional.

Example:
python partial_convertion_to_tarred_audio_dataset.py \
    tarred_manifest_filepath="path/to/manifest.json" \
    shards_to_tar="0:3"
"""


def _parse_shard_spec(shards_to_tar: str):
    """Parses a shard selection string into explicit shard IDs plus an optional open-ended lower bound.

    Returns:
        Tuple[Set[int], Optional[int]]: The explicitly listed shard IDs, and the lowest `start` of any
        open-ended `start:` range (None when the spec contains no open-ended range).
    """
    shard_ids = set()
    open_start = None
    for token in shards_to_tar.split(","):
        if ":" not in token:
            # `int` tolerates surrounding whitespace and raises ValueError on empty/garbage tokens.
            shard_ids.add(int(token))
            continue

        bounds = [part.strip() for part in token.split(":")]
        if len(bounds) != 2:
            raise ValueError(f"Invalid shard range `{token.strip()}` in `shards_to_tar`: expected `start:end`.")

        start = int(bounds[0]) if bounds[0] else 0
        if bounds[1]:
            shard_ids.update(range(start, int(bounds[1])))
        elif open_start is None or start < open_start:
            # The number of shards is unknown before the manifest is read, so an open end is kept
            # as a lower bound instead of being expanded into explicit IDs.
            open_start = start

    return shard_ids, open_start


def select_shards(manifest_filepath: str, shards_to_tar: str, slice_with_offset: bool = False):
    """
    Selects and returns a subset of shards from the tarred manifest file.

    Args:
        manifest_filepath (str): The path to the tarred manifest file (one JSON object per line).
        shards_to_tar (str): Which shards to select. Either "all", or a comma-separated list whose items are
            single shard IDs or half-open `start:end` ranges, e.g., "3", "0:5", "0,1,2" or "0:2,7".
            A missing start defaults to 0 (":5"); a missing end selects every shard ID >= start ("5:").
        slice_with_offset (bool, optional): If True, each selected entry's `abs_audio_filepath` and
            `source_audio_offset` are moved to `audio_filepath` and `offset`. Defaults to False.

    Raises:
        FileNotFoundError: If the manifest file does not exist.
        KeyError: If `slice_with_offset` is enabled but required fields are missing in the manifest entries.
        ValueError: If `shards_to_tar` is neither "all" nor a well-formed list of IDs / ranges.

    Returns:
        Dict[int, List[Dict[str, Any]]]: A dictionary where the keys are shard IDs and the values are lists of entries for those shards.
    """
    select_all = shards_to_tar == "all"
    shard_ids, open_start = (set(), None) if select_all else _parse_shard_spec(shards_to_tar)

    entries_to_shard = {}
    # JSON manifests are UTF-8; do not depend on the platform's default locale encoding.
    with open(manifest_filepath, 'r', encoding='utf-8') as manifest:
        for line in tqdm(manifest, desc="Selecting shards"):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at the end of the manifest).
                continue

            entry = json.loads(line)
            shard_id = entry['shard_id']
            is_selected = (
                select_all or shard_id in shard_ids or (open_start is not None and shard_id >= open_start)
            )
            if not is_selected:
                continue

            if slice_with_offset:
                if 'abs_audio_filepath' not in entry or 'source_audio_offset' not in entry:
                    raise KeyError(
                        f"`slice_with_offset` is enabled, but `abs_audio_filepath` and/or `source_audio_offset` are not found in the entry:\n{entry}."
                    )
                entry['audio_filepath'] = entry.pop('abs_audio_filepath')
                entry['offset'] = entry.pop('source_audio_offset')

            entries_to_shard.setdefault(shard_id, []).append(entry)

    return entries_to_shard


@dataclass
class PartialASRTarredDatasetConfig:
    """
    Configuration class for creating partial tarred audio dataset shards.

    Attributes:
        tarred_manifest_filepath (str): The path to the tarred manifest file. Mandatory.
        output_dir (Optional[str]): Directory where the output tarred shards will be saved.
            When None, `create_shards` uses the directory of the tarred manifest.
        shards_to_tar (Optional[str]): A range or list of shard IDs to tar, or "all" (the default).
        num_workers (int): Number of parallel workers to use for tar file creation (passed to joblib as `n_jobs`).
        dataset_metadata_filepath (Optional[str]): Path to the dataset metadata YAML file.
            When None, `create_shards` looks for `metadata.yaml` next to the tarred manifest.
        dataset_metadata (ASRTarredDatasetMetadata): Dataset metadata configuration.
            `create_shards` replaces it with the metadata loaded from `dataset_metadata_filepath`.
        slice_with_offset (bool): Offset-slicing flag. Note that `create_shards` does not read this field;
            it takes `slice_with_offset` from the loaded metadata's `dataset_config` instead.
    """

    tarred_manifest_filepath: str = MISSING
    output_dir: Optional[str] = None
    shards_to_tar: Optional[str] = "all"
    num_workers: int = 1
    dataset_metadata_filepath: Optional[str] = None
    # `default_factory` gives every config its own metadata *instance*; `default=ASRTarredDatasetMetadata`
    # would store the class object itself as the field value.
    dataset_metadata: ASRTarredDatasetMetadata = field(default_factory=ASRTarredDatasetMetadata)
    slice_with_offset: bool = False

def create_shards(cfg: PartialASRTarredDatasetConfig):
    """
    Builds the tar files for the shards selected by the configuration.

    The function fills in missing paths on `cfg` in place (metadata file and output directory default
    to the manifest's directory), loads the dataset metadata from file, selects the requested shard
    entries from the tarred manifest and creates one tarred shard per selected shard ID in parallel.

    Args:
        cfg (PartialASRTarredDatasetConfig): The configuration object containing paths, shard IDs, and metadata.

    Raises:
        ValueError: If the `tarred_manifest_filepath` is None.
        FileNotFoundError: If the tarred manifest file or dataset metadata file does not exist.
    """
    manifest_path = cfg.tarred_manifest_filepath
    if manifest_path is None:
        raise ValueError("The `tarred_manifest_filepath` cannot be `None`. Please check your configuration.")
    if not os.path.exists(manifest_path):
        raise FileNotFoundError(
            f"The `tarred_manifest_filepath` was not found: {manifest_path}. Please verify that the filepath is correct."
        )

    # Anything not given explicitly is expected to live next to the manifest.
    manifest_dir = os.path.dirname(manifest_path)
    if cfg.dataset_metadata_filepath is None:
        cfg.dataset_metadata_filepath = os.path.join(manifest_dir, "metadata.yaml")
    if cfg.output_dir is None:
        cfg.output_dir = manifest_dir

    if not os.path.exists(cfg.dataset_metadata_filepath):
        raise FileNotFoundError(
            f"The `dataset_metadata_filepath` was not found: {cfg.dataset_metadata_filepath}. Please verify that the filepath is correct."
        )
    cfg.dataset_metadata = ASRTarredDatasetMetadata.from_file(cfg.dataset_metadata_filepath)

    # The offset-slicing flag comes from the stored dataset config, not from `cfg.slice_with_offset`.
    dataset_config = cfg.dataset_metadata.dataset_config
    entries_to_shard = select_shards(manifest_path, cfg.shards_to_tar, dataset_config.slice_with_offset)

    builder = ASRTarredDatasetBuilder()
    builder.configure(dataset_config)

    # One delayed task per selected shard; the generator is consumed lazily by joblib.
    shard_jobs = (
        delayed(builder._create_shard)(
            entries=shard_entries,
            target_dir=cfg.output_dir,
            shard_id=shard_id,
        )
        for shard_id, shard_entries in entries_to_shard.items()
    )
    with Parallel(n_jobs=cfg.num_workers, verbose=len(entries_to_shard)) as parallel:
        parallel(shard_jobs)
@hydra.main(config_path=None, config_name='partial_tar_config')
def main(cfg: PartialASRTarredDatasetConfig):
    """Hydra entry point: receives the composed `partial_tar_config` and creates the requested shards."""
    create_shards(cfg)


# Register the structured config under the same name that `@hydra.main(config_name=...)` refers to.
# This runs at import time, i.e. before `main()` is invoked below.
ConfigStore.instance().store(name='partial_tar_config', node=PartialASRTarredDatasetConfig)

# Run the Hydra entry point only when the file is executed as a script.
if __name__ == '__main__':
    main()