| from typing import Callable, Optional, Union |
| from tqdm import tqdm |
| import os |
| import torch |
| import torchaudio |
| import torchaudio.functional as F |
| from torch.utils.data import Dataset, DataLoader, IterableDataset, random_split |
| from pytorch_lightning import LightningDataModule |
| import webdataset |
|
|
|
|
class VLSP2020Dataset(Dataset):
    """Map-style dataset over a flat directory of transcript/audio pairs.

    Inside ``root`` every ``<name>.txt`` file is a transcript and every other
    regular file ``<name>.<ext>`` is taken as the audio for the same
    ``<name>``. Samples are ordered by ``<name>`` so that an index always
    refers to the same sample, regardless of the order in which the operating
    system lists the directory.
    """

    def __init__(self, root: str, sample_rate: int = 16000):
        """Index ``root``; no audio is loaded until a sample is requested.

        Args:
            root: Directory holding the transcript and audio files.
            sample_rate: Rate each clip is resampled to in ``__getitem__``.

        Raises:
            ValueError: If a sample has a duplicate or a missing file.
        """
        super().__init__()

        self.sample_rate = sample_rate
        # name -> {"transcript": path, "audio": path}, ordered by name.
        self.memory = self._prepare_data(root)
        # Positional view of the same data, used for integer indexing.
        self._memory = tuple(
            (v["transcript"], v["audio"]) for v in self.memory.values()
        )

    @staticmethod
    def _prepare_data(root: str):
        """Pair transcript and audio files found directly inside ``root``.

        Args:
            root: Directory to scan (not recursive).

        Returns:
            A dict mapping each base file name to
            ``{"transcript": <path>, "audio": <path>}``, with keys in sorted
            order.

        Raises:
            ValueError: If two files claim the same role for one base name, or
                if a base name lacks its audio or its transcript.
        """
        # The context manager releases the directory handle right away.
        # Entries are sorted because os.scandir yields them in arbitrary
        # order, and sub-directories are skipped since they are neither
        # transcripts nor audio.
        with os.scandir(root) as entries:
            files = sorted(
                (entry for entry in entries if entry.is_file()),
                key=lambda entry: entry.name,
            )

        memory = {}
        for f in files:
            file_name, file_ext = os.path.splitext(f.name)
            kind = "transcript" if file_ext == ".txt" else "audio"

            sample = memory.setdefault(file_name, {})
            if kind in sample:
                raise ValueError(f"Duplicate {kind} for {f.path}")
            sample[kind] = f.path

        # Re-key in sorted order so iteration order equals index order.
        ordered = {key: memory[key] for key in sorted(memory)}

        for key, value in ordered.items():
            if "audio" not in value:
                raise ValueError(f"Missing audio for {key}")
            if "transcript" not in value:
                raise ValueError(f"Missing transcript for {key}")

        return ordered

    def __len__(self):
        """Return the number of transcript/audio pairs."""
        return len(self.memory)

    def __getitem__(self, index: int):
        """Load one sample from disk.

        Args:
            index: Position of the sample in name-sorted order.

        Returns:
            A ``(transcript, audio)`` tuple: the full text of the transcript
            file and the waveform returned by ``torchaudio.load`` after
            resampling to ``self.sample_rate``.
        """
        transcript_path, audio_path = self._memory[index]

        # Decode explicitly as UTF-8: the platform default encoding is not
        # guaranteed to handle Vietnamese text.
        with open(transcript_path, "r", encoding="utf-8") as f:
            transcript = f.read()

        audio, sample_rate = torchaudio.load(audio_path)
        audio = F.resample(audio, sample_rate, self.sample_rate)

        return transcript, audio
|
|
|
|
class VLSP2020TarDataset:
    """Write a ``VLSP2020Dataset`` to a webdataset tar archive and read it back."""

    def __init__(self, outpath: str):
        """Remember the archive location.

        Args:
            outpath: Path passed to ``webdataset.TarWriter`` when converting
                and to ``webdataset.WebDataset`` when loading.
        """
        self.outpath = outpath

    def convert(self, dataset: VLSP2020Dataset):
        """Serialize every sample of ``dataset`` into the archive at ``outpath``.

        Each sample is written with a zero-padded 8-digit index as its key,
        the transcript under ``"txt"`` and the audio tensor under ``"pth"``.

        Args:
            dataset: Source dataset yielding ``(transcript, audio)`` pairs.
        """
        writer = webdataset.TarWriter(self.outpath)

        # Close the writer even when loading or writing a sample fails, so the
        # archive file handle is never leaked.
        try:
            for idx, (transcript, audio) in enumerate(tqdm(dataset, colour="green")):
                writer.write(
                    {
                        "__key__": f"{idx:08d}",
                        "txt": transcript,
                        "pth": audio,
                    }
                )
        finally:
            writer.close()

    def load(self) -> webdataset.WebDataset:
        """Open the archive as a pipeline of ``(transcript, audio)`` tuples.

        The pipeline is also stored on ``self.data``.

        Returns:
            The decoding pipeline built on ``webdataset.WebDataset(outpath)``.
        """
        self.data = (
            webdataset.WebDataset(self.outpath)
            .decode(
                # Transcripts are stored as bytes; decode them as UTF-8 text.
                webdataset.handle_extension("txt", lambda x: x.decode("utf-8")),
                webdataset.torch_audio,
            )
            .to_tuple("txt", "pth")
        )

        return self.data
|
|
|
|
class _AudioCollator:
    """Collate ``(transcript, audio)`` samples into tuples of batch items.

    This is a module-level class rather than a closure so that the
    ``DataLoader`` can pickle it when worker processes are started with the
    ``spawn`` method. A ``target_transform`` must itself be picklable in that
    case.
    """

    def __init__(
        self,
        return_transcript: bool = False,
        target_transform: Optional[Callable] = None,
    ):
        self.return_transcript = return_transcript
        self.target_transform = target_transform

    @staticmethod
    def _get_audio(item):
        """Validate one sample's audio and drop its leading channel axis."""
        audio = item[1]

        # Explicit raises instead of ``assert`` so the checks survive ``-O``.
        if not isinstance(audio, torch.Tensor):
            raise TypeError(
                f"Expected audio to be a torch.Tensor, got {type(audio).__name__}"
            )
        if audio.ndim != 2 or audio.size(0) != 1:
            raise ValueError(
                "Expected audio of shape (1, num_frames), "
                f"got {tuple(audio.shape)}"
            )

        return audio.squeeze(0)

    def __call__(self, batch):
        """Return ``audio`` or ``(transcript, audio)`` tuples for ``batch``."""
        audio = tuple(self._get_audio(item) for item in batch)

        if not self.return_transcript:
            return audio

        if self.target_transform is not None:
            transcript = tuple(self.target_transform(item[0]) for item in batch)
        else:
            transcript = tuple(item[0] for item in batch)

        return transcript, audio


def get_dataloader(
    dataset: Union[VLSP2020Dataset, webdataset.WebDataset],
    return_transcript: bool = False,
    target_transform: Optional[Callable] = None,
    batch_size: int = 32,
    num_workers: int = 2,
):
    """Wrap ``dataset`` in a ``DataLoader`` that batches variable-length audio.

    Clips are not padded or stacked: each batch is a tuple of 1-D tensors,
    one per sample.

    Args:
        dataset: Source of ``(transcript, audio)`` samples, where ``audio`` is
            a tensor of shape ``(1, num_frames)``.
        return_transcript: If true, batches are ``(transcripts, audios)``;
            otherwise only the audio tuple is returned.
        target_transform: Optional callable applied to each transcript; only
            used when ``return_transcript`` is true.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``torch.utils.data.DataLoader``.

    Raises:
        TypeError: While iterating, if a sample's audio is not a tensor.
        ValueError: While iterating, if a sample's audio is not ``(1, n)``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=_AudioCollator(return_transcript, target_transform),
    )
|
|