| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| """ |
| This file computes SSL (wav2vec2) features of the LibriSpeech dataset. |
| It looks for manifests in the directory data/manifests. |
| |
| The generated SSL features are saved in data/ssl. |
| """ |
|
|
| import logging |
| import os |
| from pathlib import Path |
|
|
| import torch |
| from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig |
| from lhotse.recipes.utils import read_manifests_if_cached |
|
|
| from icefall.utils import get_executor |
|
|
| |
| |
| |
| |
# Restrict PyTorch to one intra-op thread and one inter-op thread.
# These calls sit at module level, so they take effect as soon as the
# module is imported, not only when it is run as a script.
# NOTE(review): presumably this avoids CPU over-subscription - confirm.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
|
|
|
|
def compute_ssl_librispeech():
    """Extract wav2vec2 SSL features for the LibriSpeech partitions.

    Manifests are read from ``data/manifests``. For each partition the
    features are stored under ``data/ssl/librispeech_feats_<partition>``
    and the resulting cut set is written to
    ``data/ssl/librispeech_cuts_<partition>.jsonl.gz``. A partition whose
    cut-set file already exists is skipped, so a rerun only processes the
    partitions that are not done yet.

    The extractor runs on CUDA when it is available and falls back to the
    CPU (with a warning) otherwise.

    Raises:
        AssertionError: If no manifests could be read, or if the number of
            manifests found differs from the number of requested
            partitions.
    """
    src_dir = Path("data/manifests")
    output_dir = Path("data/ssl")

    dataset_parts = (
        "dev-clean",
        "dev-other",
        "test-clean",
        "test-other",
        "train-clean-100",
    )
    prefix = "librispeech"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    # The tuple is the assertion message: it shows what was found versus
    # what was requested when a partition's manifest is missing.
    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        logging.warning("CUDA is not available - extracting SSL features on CPU.")

    # One extractor instance is shared by all partitions.
    extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device=device))

    # Make sure the destination exists before any cut set is written to it,
    # independently of what the feature writer creates.
    output_dir.mkdir(parents=True, exist_ok=True)

    for partition, m in manifests.items():
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
        if (output_dir / cuts_filename).is_file():
            logging.info(f"{partition} already exists - skipping.")
            continue
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=m["recordings"],
            supervisions=m["supervisions"],
        )
        # No executor or job count is passed, so the library defaults apply.
        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
            storage_type=NumpyFilesWriter,
        )
        cut_set.to_file(output_dir / cuts_filename)
|
|
|
|
if __name__ == "__main__":
    # Log at INFO level with a timestamp, the level name and the source
    # location of each message, then run the extraction.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    compute_ssl_librispeech()
|
|