import logging
import torch
import torch.nn.functional as F
from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
from typing import Tuple

try:
    import kaldiio
except ImportError:
    # kaldiio is optional at import time; ArkDataset checks for it when constructed.
    kaldiio = None
import warnings

logger = logging.getLogger(__name__)


class ArkDataset(RawAudioDataset):
    """Audio dataset indexed by a Kaldi ``wav.scp`` plus a per-utterance duration file.

    Utterances are loaded lazily through ``kaldiio.load_scp``. Per-utterance sizes
    (used by the base class for bucketing) are derived from ``dur_scp``.
    """

    def __init__(
        self,
        wav_scp,
        dur_scp,
        sr=24000,
        max_dur=20,
        num_buckets=0,
        normalize=False,
    ):
        """Build the dataset index.

        Args:
            wav_scp: Path to a Kaldi scp file, loaded with ``kaldiio.load_scp``.
            dur_scp: Path to a whitespace-separated text file. The first token on each
                line is the utterance key and the last token is its duration. Keys
                absent from ``wav_scp`` are ignored.
            sr: Sample rate. Used for ``max_sample_size`` and for converting
                durations to sizes.
            max_dur: Maximum duration; ``max_sample_size`` is ``max_dur * sr``.
            num_buckets: Forwarded to ``set_bucket_info``.
            normalize: Forwarded to the base class and stored on ``self``.

        Raises:
            ImportError: If ``kaldiio`` is not installed.
            KeyError: If some key in ``wav_scp`` has no entry in ``dur_scp``.
            ValueError: If a duration token in ``dur_scp`` is not a float.
        """
        if kaldiio is None:
            raise ImportError(
                "ArkDataset requires the 'kaldiio' package (pip install kaldiio)"
            )
        super().__init__(
            sample_rate=sr,
            max_sample_size=max_dur * sr,
            min_sample_size=1200,
            shuffle=True,
            pad=True,
            normalize=normalize,
            compute_mask=False,
        )
        self.sr = sr
        self.max_dur = max_dur
        self.normalize = normalize

        logger.info("Loading Kaldi scp files from {}".format(wav_scp))
        self.wav_data = kaldiio.load_scp(wav_scp)
        self.keys = list(self.wav_data.keys())

        # Only keep durations for utterances actually present in wav_scp.
        dur_data = {}
        keys_set = set(self.keys)
        with open(dur_scp, 'r') as f:
            for raw_line in f:
                fields = raw_line.strip().split()
                if not fields:
                    # Blank/whitespace-only line: skip instead of IndexError on fields[0].
                    continue
                if fields[0] in keys_set:
                    dur_data[fields[0]] = float(fields[-1])

        missing = [k for k in self.keys if k not in dur_data]
        if missing:
            raise KeyError(
                "{} utterance(s) from {} have no duration in {} (first: {!r})".format(
                    len(missing), wav_scp, dur_scp, missing[0]
                )
            )
        # NOTE(review): dividing by 100 presumably means durations are given in
        # 10 ms units (centiseconds), making sizes sample counts — confirm.
        self.sizes = [int(dur_data[k] * self.sr / 100) for k in self.keys]
        logger.info("Loading Kaldi scp files done")

        self.dataset_len = len(self.keys)
        self.set_bucket_info(num_buckets)

    def __len__(self):
        """Return the number of utterances indexed from ``wav_scp``."""
        return self.dataset_len

    def __getitem__(self, idx):
        # NOTE(review): stub — returns None as written; confirm intended implementation.
        pass

    def size(self, idx):
        # NOTE(review): stub — returns None and overrides the base-class method;
        # presumably should return self.sizes[idx] — confirm.
        pass

    def postprocess(self, wav):
        # NOTE(review): stub — returns None and overrides the base-class method.
        pass

    def collater(self, samples):
        # NOTE(review): stub — returns None and overrides the base-class method.
        pass