| import json |
| import logging |
| import os |
| import shutil |
| from pathlib import Path |
|
|
| import torchaudio |
| from speechbrain.utils.data_utils import download_file |
| from tqdm import tqdm |
|
|
| from hw_asr.base.base_dataset import BaseDataset |
| from hw_asr.utils import ROOT_PATH |
|
|
# Module-level logger; configured by the application, not here.
logger = logging.getLogger(__name__)


# Official OpenSLR mirror URLs for each LibriSpeech split archive.
# Keys double as the valid `part` names accepted by LibrispeechDataset.
URL_LINKS = {
    "dev-clean": "https://www.openslr.org/resources/12/dev-clean.tar.gz",
    "dev-other": "https://www.openslr.org/resources/12/dev-other.tar.gz",
    "test-clean": "https://www.openslr.org/resources/12/test-clean.tar.gz",
    "test-other": "https://www.openslr.org/resources/12/test-other.tar.gz",
    "train-clean-100": "https://www.openslr.org/resources/12/train-clean-100.tar.gz",
    "train-clean-360": "https://www.openslr.org/resources/12/train-clean-360.tar.gz",
    "train-other-500": "https://www.openslr.org/resources/12/train-other-500.tar.gz",
}
|
|
|
|
class LibrispeechDataset(BaseDataset):
    """LibriSpeech dataset wrapper.

    On first use, downloads the requested split from OpenSLR, extracts it
    under ``data_dir``, and builds a JSON index of
    ``{"path", "text", "audio_len"}`` entries which is cached on disk so
    later runs skip the expensive directory scan.
    """

    def __init__(self, part, data_dir=None, *args, **kwargs):
        """
        Args:
            part: one of the ``URL_LINKS`` split names, or ``"train_all"``
                to merge every ``train-*`` split into a single index.
            data_dir: root directory for downloaded/extracted data; defaults
                to ``<ROOT_PATH>/data/datasets/librispeech``. A plain string
                path is accepted and converted to ``Path``.
            *args, **kwargs: forwarded to ``BaseDataset.__init__``.

        Raises:
            ValueError: if ``part`` is not a known split name.
        """
        # `assert` is stripped when Python runs with -O, so validate
        # the split name explicitly.
        if part not in URL_LINKS and part != "train_all":
            raise ValueError(
                f"Unknown librispeech part {part!r}; "
                f"expected one of {sorted(URL_LINKS)} or 'train_all'"
            )

        if data_dir is None:
            data_dir = ROOT_PATH / "data" / "datasets" / "librispeech"
        # Normalize: a caller-supplied str would break `data_dir / name`
        # later; also make sure the directory exists either way.
        data_dir = Path(data_dir)
        data_dir.mkdir(exist_ok=True, parents=True)
        self._data_dir = data_dir

        if part == "train_all":
            # Concatenate the indices of all training splits.
            index = sum(
                (self._get_or_load_index(p) for p in URL_LINKS if "train" in p),
                [],
            )
        else:
            index = self._get_or_load_index(part)

        super().__init__(index, *args, **kwargs)

    def _load_part(self, part):
        """Download and extract one split archive into ``self._data_dir``."""
        arch_path = self._data_dir / f"{part}.tar.gz"
        print(f"Loading part {part}")
        download_file(URL_LINKS[part], arch_path)
        shutil.unpack_archive(arch_path, self._data_dir)
        # The archive extracts into a top-level "LibriSpeech/" folder;
        # flatten it so each split lives directly under data_dir.
        for fpath in (self._data_dir / "LibriSpeech").iterdir():
            shutil.move(str(fpath), str(self._data_dir / fpath.name))
        arch_path.unlink()
        shutil.rmtree(str(self._data_dir / "LibriSpeech"))

    def _get_or_load_index(self, part):
        """Return the cached index for ``part``, creating and saving it if absent."""
        index_path = self._data_dir / f"{part}_index.json"
        if index_path.exists():
            with index_path.open() as f:
                index = json.load(f)
        else:
            index = self._create_index(part)
            with index_path.open("w") as f:
                json.dump(index, f, indent=2)
        return index

    def _create_index(self, part):
        """Scan a split directory and build its index.

        Downloads the split first if it is not present on disk. Each entry
        holds the absolute flac path, the lower-cased transcription, and the
        audio duration in seconds.
        """
        index = []
        split_dir = self._data_dir / part
        if not split_dir.exists():
            self._load_part(part)

        # Collect only the leaf directories that actually contain audio.
        flac_dirs = set()
        for dirpath, dirnames, filenames in os.walk(str(split_dir)):
            if any(f.endswith(".flac") for f in filenames):
                flac_dirs.add(dirpath)
        for flac_dir in tqdm(
            list(flac_dirs), desc=f"Preparing librispeech folders: {part}"
        ):
            flac_dir = Path(flac_dir)
            # Each leaf dir ships exactly one "<spk>-<chap>.trans.txt" file
            # whose lines are "<utterance-id> <transcription...>".
            trans_path = list(flac_dir.glob("*.trans.txt"))[0]
            with trans_path.open() as f:
                for line in f:
                    parts = line.split()
                    if not parts:
                        # Skip blank lines instead of crashing on parts[0].
                        continue
                    f_id = parts[0]
                    f_text = " ".join(parts[1:]).strip()
                    flac_path = flac_dir / f"{f_id}.flac"
                    t_info = torchaudio.info(str(flac_path))
                    length = t_info.num_frames / t_info.sample_rate
                    index.append(
                        {
                            "path": str(flac_path.absolute().resolve()),
                            "text": f_text.lower(),
                            "audio_len": length,
                        }
                    )
        return index
|
|