| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torchaudio |
| | import json |
| | import os |
| | import numpy as np |
| | import librosa |
| | from torch.nn.utils.rnn import pad_sequence |
| | from modules import whisper_extractor as whisper |
| |
|
| |
|
class TorchaudioDataset(torch.utils.data.Dataset):
    """Utterance dataset that decodes audio files with ``torchaudio``.

    Indexing yields ``(utt_info, wav, length)``: the metadata entry, the
    waveform reduced to a single 1-D channel at ``self.sr``, and the number
    of samples in that waveform.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """
        Args:
            cfg: config; ``cfg.preprocess`` is only read when ``metadata``
                is None
            dataset: dataset name (asserted to be a ``str``)
            sr: target sampling rate used by ``__getitem__``
            accelerator: optional object whose ``.device`` becomes
                ``self.device``
            metadata: optional pre-built sequence of utterance entries; when
                given, no metadata file is read
        """
        assert isinstance(dataset, str)

        self.sr = sr
        self.cfg = cfg

        if metadata is not None:
            self.metadata = metadata
        else:
            # Both metadata files live under <processed_dir>/<dataset>/.
            prep = cfg.preprocess
            self.train_metadata_path = os.path.join(
                prep.processed_dir, dataset, prep.train_file
            )
            self.valid_metadata_path = os.path.join(
                prep.processed_dir, dataset, prep.valid_file
            )
            self.metadata = self.get_metadata()

        # Prefer the accelerator's device; otherwise CUDA when present.
        if accelerator is not None:
            self.device = accelerator.device
        else:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )

    def get_metadata(self):
        """Return the train entries followed by the valid entries."""
        entries = []
        for json_path in (self.train_metadata_path, self.valid_metadata_path):
            with open(json_path, "r", encoding="utf-8") as fp:
                entries.extend(json.load(fp))
        return entries

    def __len__(self):
        """Number of utterance entries."""
        return len(self.metadata)

    def __getitem__(self, index):
        """Load one utterance as ``(utt_info, wav, length)``."""
        entry = self.metadata[index]
        waveform, native_sr = torchaudio.load(entry["Path"])

        # Bring the audio to the target sampling rate first.
        if native_sr != self.sr:
            waveform = torchaudio.functional.resample(
                waveform, native_sr, self.sr
            )

        # Average over dim 0 when it holds more than one row
        # (presumably the channel axis -- confirm against torchaudio.load).
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        assert waveform.shape[0] == 1

        # Drop the leading singleton dim so the result is 1-D.
        waveform = waveform.squeeze(0)
        return entry, waveform, waveform.shape[0]
| |
|
| |
|
class LibrosaDataset(TorchaudioDataset):
    """Same dataset contract as the parent, but audio is read via ``librosa``."""

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(
            cfg, dataset, sr, accelerator=accelerator, metadata=metadata
        )

    def __getitem__(self, index):
        """Load one utterance as ``(utt_info, wav, length)``."""
        entry = self.metadata[index]

        # librosa is asked for the target rate directly via ``sr=``.
        samples, _ = librosa.load(entry["Path"], sr=self.sr)

        # Wrap the numpy array as a tensor.
        waveform = torch.from_numpy(samples)
        return entry, waveform, waveform.shape[0]
| |
|
| |
|
class FFmpegDataset(TorchaudioDataset):
    """Same dataset contract as the parent, but audio is read via
    ``whisper.load_audio``.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(
            cfg, dataset, sr, accelerator=accelerator, metadata=metadata
        )

    def __getitem__(self, index):
        """Load one utterance as ``(utt_info, wav, length)``."""
        entry = self.metadata[index]

        # NOTE(review): ``self.sr`` is not passed here, so the sampling rate
        # is whatever ``whisper.load_audio`` produces by default -- confirm
        # that this matches the ``sr`` given to the constructor.
        samples = whisper.load_audio(entry["Path"])

        # Wrap the numpy array as a tensor.
        waveform = torch.from_numpy(samples)
        return entry, waveform, waveform.shape[0]
| |
|
| |
|
def collate_batch(batch_list):
    """Merge dataset items into one batch.

    Args:
        batch_list: list of (metadata, wav, length)

    Returns:
        ``(metadata, wavs, lens)``: the list of metadata entries, the
        waveforms stacked batch-first by ``pad_sequence``, and the list of
        per-item lengths, all in input order.
    """
    infos, waveforms, lengths = [], [], []
    for sample in batch_list:
        infos.append(sample[0])
        waveforms.append(sample[1])
        lengths.append(sample[2])

    # pad_sequence fills shorter waveforms up to the longest one in the batch.
    padded = pad_sequence(waveforms, batch_first=True)
    return infos, padded, lengths
| |
|