|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import os |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Optional |
|
|
|
|
|
import hydra |
|
|
from convert_to_tarred_audio_dataset import ASRTarredDatasetBuilder, ASRTarredDatasetMetadata |
|
|
from hydra.core.config_store import ConfigStore |
|
|
from joblib import Parallel, delayed |
|
|
from omegaconf import MISSING |
|
|
from tqdm import tqdm |
|
|
|
|
|
""" |
|
|
# Partial Tarred Audio Dataset Creator |
|
|
|
|
|
## Overview |
|
|
|
|
|
This script facilitates the creation of tarred and sharded audio datasets from existing tarred manifests. It allows you to select specific shards from a manifest file and then tar them separately. |
|
|
|
|
|
This is useful in several scenarios: |
|
|
- When you only need to process a specific subset of shards (e.g., for debugging or incremental dataset preparation). |
|
|
- When you want to parallelize shard creation across multiple SLURM jobs to accelerate the dataset generation process and overcome per-job time limits. |
|
|
|
|
|
## Prerequisites |
|
|
|
|
|
- Ensure that the `convert_to_tarred_audio_dataset` script is correctly configured and run with the `--only_manifests` flag to generate the necessary manifest files. |
|
|
- Make sure the paths to the manifest and metadata files are correct and accessible. |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Script Execution |
|
|
|
|
|
To run the script, use the following command: |
|
|
|
|
|
python partial_convertion_to_tarred_audio_dataset.py \
    # the path to the tarred manifest file that contains the entries for the shards you want to process. This option is mandatory.
    tarred_manifest_filepath=<path to the tarred manifest file> \
    # any other optional argument
    output_dir=<output directory for tarred shards> \
    shards_to_tar=<shard IDs to be tarred> \
    num_workers=-1 \
    dataset_metadata_filepath=<dataset metadata YAML filepath>
|
|
|
|
|
Example: |
|
|
python partial_convertion_to_tarred_audio_dataset.py \ |
|
|
tarred_manifest_filepath="path/to/manifest.json" \ |
|
|
shards_to_tar="0:3" |
|
|
""" |
|
|
|
|
|
|
|
|
def _parse_shard_selection(shards_to_tar):
    """
    Parse a shard selector into a membership test.

    Supported forms (combinable with commas): a single ID ("3"), a half-open
    range ("0:5" selects 0..4), an open-ended range ("2:" selects every ID >= 2,
    ":5" selects 0..4), and the literal "all".

    Returns:
        A callable taking a shard ID and returning True if it is selected.

    Raises:
        ValueError: If the selector is empty or a token is malformed.
    """
    spec = str(shards_to_tar).strip()
    if spec == "all":
        return lambda shard_id: True

    explicit_ids = set()  # set membership is O(1) per manifest line
    open_lower_bounds = []  # starts of ranges that have no upper bound ("2:")

    for token in spec.split(","):
        token = token.strip()
        if not token:
            raise ValueError(f"Invalid `shards_to_tar` value: {shards_to_tar!r}.")

        if ":" not in token:
            explicit_ids.add(int(token))
            continue

        bounds = [part.strip() for part in token.split(":")]
        if len(bounds) != 2:
            raise ValueError(
                f"Invalid shard range {token!r} in `shards_to_tar`; expected the form `start:end`."
            )
        start = int(bounds[0]) if bounds[0] else 0
        if bounds[1]:
            explicit_ids.update(range(start, int(bounds[1])))
        else:
            open_lower_bounds.append(start)

    # Several open-ended ranges collapse to the smallest lower bound.
    lowest_open = min(open_lower_bounds) if open_lower_bounds else None

    def is_selected(shard_id):
        if shard_id in explicit_ids:
            return True
        return lowest_open is not None and shard_id >= lowest_open

    return is_selected


def select_shards(manifest_filepath: str, shards_to_tar: str, slice_with_offset: bool = False):
    """
    Selects and returns a subset of shards from the tarred manifest file.

    Args:
        manifest_filepath (str): The path to the tarred manifest file (one JSON object per line,
            each carrying a `shard_id` key).
        shards_to_tar (str): Which shards to select: "all", a single ID ("3"), a half-open range
            ("0:5", "2:", ":5"), or a comma-separated combination of those ("0,1,2" or "0:3,7").
        slice_with_offset (bool, optional): If True, each selected entry has `abs_audio_filepath`
            moved to `audio_filepath` and `source_audio_offset` moved to `offset`. Defaults to False.

    Raises:
        FileNotFoundError: If the manifest file does not exist.
        ValueError: If `shards_to_tar` cannot be parsed.
        KeyError: If `slice_with_offset` is enabled but required fields are missing in the manifest entries.

    Returns:
        Dict[int, List[Dict[str, any]]]: A dictionary where the keys are shard IDs and the values are
        lists of entries for those shards, in manifest order.
    """
    is_selected = _parse_shard_selection(shards_to_tar)

    entries_to_shard = {}
    with open(manifest_filepath, 'r', encoding='utf-8') as manifest:
        for line in tqdm(manifest, desc="Selecting shards"):
            # Tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads.
            if not line.strip():
                continue

            entry = json.loads(line)
            shard_id = entry['shard_id']
            if not is_selected(shard_id):
                continue

            if slice_with_offset:
                if 'abs_audio_filepath' not in entry or 'source_audio_offset' not in entry:
                    raise KeyError(
                        f"`slice_with_offset` is enabled, but `abs_audio_filepath` and/or `source_audio_offset` are not found in the entry:\n{entry}."
                    )
                # Rename the source-audio fields to the keys the entry is tarred with.
                entry['audio_filepath'] = entry.pop('abs_audio_filepath')
                entry['offset'] = entry.pop('source_audio_offset')

            entries_to_shard.setdefault(shard_id, []).append(entry)

    return entries_to_shard
|
|
|
|
|
|
|
|
@dataclass
class PartialASRTarredDatasetConfig:
    """
    Configuration class for creating partial tarred audio dataset shards.

    Attributes:
        tarred_manifest_filepath (str): The path to the tarred manifest file. Mandatory.
        output_dir (Optional[str]): Directory where the output tarred shards will be saved.
        shards_to_tar (Optional[str]): A range or list of shard IDs to tar. Defaults to "all".
        num_workers (int): Number of parallel workers to use for tar file creation.
        dataset_metadata_filepath (Optional[str]): Path to the dataset metadata YAML file.
        dataset_metadata (ASRTarredDatasetMetadata): Dataset metadata configuration.
        slice_with_offset (bool): Whether entries should be sliced by audio offset. Defaults to False.
            NOTE(review): `create_shards` in this file takes the flag from the loaded dataset
            metadata rather than from this field — confirm whether this field is still needed.
    """

    tarred_manifest_filepath: str = MISSING
    output_dir: Optional[str] = None
    shards_to_tar: Optional[str] = "all"
    num_workers: int = 1
    dataset_metadata_filepath: Optional[str] = None
    # default_factory yields a fresh metadata *instance* per config; `default=` with the
    # class itself would make the class object the field's default value.
    dataset_metadata: ASRTarredDatasetMetadata = field(default_factory=ASRTarredDatasetMetadata)
    slice_with_offset: bool = False
|
|
|
|
|
|
|
|
def create_shards(cfg: PartialASRTarredDatasetConfig):
    """
    Build the tarred shards requested by `cfg`.

    The function validates the manifest path, fills in defaults for the metadata
    path and the output directory, loads the dataset metadata, selects the
    requested manifest entries, and hands each selected shard to
    `ASRTarredDatasetBuilder._create_shard` through a joblib worker pool.

    Args:
        cfg (PartialASRTarredDatasetConfig): The configuration object containing paths, shard IDs, and metadata.
            `output_dir`, `dataset_metadata_filepath` and `dataset_metadata` are updated in place.

    Raises:
        ValueError: If the `tarred_manifest_filepath` is None.
        FileNotFoundError: If the tarred manifest file or dataset metadata file does not exist.

    Notes:
        - When `dataset_metadata_filepath` is not set, `metadata.yaml` next to the manifest is used.
        - When `output_dir` is not set, shards are written to the manifest's directory.
    """
    manifest_path = cfg.tarred_manifest_filepath

    # Guard clauses: fail early on an unusable manifest path.
    if manifest_path is None:
        raise ValueError("The `tarred_manifest_filepath` cannot be `None`. Please check your configuration.")
    if not os.path.exists(manifest_path):
        raise FileNotFoundError(
            f"The `tarred_manifest_filepath` was not found: {cfg.tarred_manifest_filepath}. Please verify that the filepath is correct."
        )

    # Both optional paths default to locations relative to the manifest's directory.
    manifest_dir = os.path.dirname(manifest_path)
    if cfg.dataset_metadata_filepath is None:
        cfg.dataset_metadata_filepath = os.path.join(manifest_dir, "metadata.yaml")
    if cfg.output_dir is None:
        cfg.output_dir = manifest_dir

    if not os.path.exists(cfg.dataset_metadata_filepath):
        raise FileNotFoundError(
            f"The `dataset_metadata_filepath` was not found: {cfg.dataset_metadata_filepath}. Please verify that the filepath is correct."
        )
    cfg.dataset_metadata = ASRTarredDatasetMetadata.from_file(cfg.dataset_metadata_filepath)

    dataset_config = cfg.dataset_metadata.dataset_config
    selected_entries = select_shards(manifest_path, cfg.shards_to_tar, dataset_config.slice_with_offset)

    shard_builder = ASRTarredDatasetBuilder()
    shard_builder.configure(dataset_config)

    # One delayed task per selected shard; the generator is consumed lazily by joblib.
    shard_tasks = (
        delayed(shard_builder._create_shard)(
            entries=shard_entries,
            target_dir=cfg.output_dir,
            shard_id=shard_id,
        )
        for shard_id, shard_entries in selected_entries.items()
    )

    with Parallel(n_jobs=cfg.num_workers, verbose=len(selected_entries)) as worker_pool:
        worker_pool(shard_tasks)
|
|
|
|
|
|
|
|
@hydra.main(config_path=None, config_name="partial_tar_config")
def main(cfg: PartialASRTarredDatasetConfig):
    """Hydra entry point: forward the composed `partial_tar_config` to `create_shards`."""
    create_shards(cfg=cfg)
|
|
|
|
|
|
|
|
# Register the structured config under the name that `main`'s Hydra decorator refers to.
ConfigStore.instance().store(node=PartialASRTarredDatasetConfig, name="partial_tar_config")


if __name__ == "__main__":
    main()
|
|
|