# cond_gen/MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py
# Leon299's picture
# Add files using upload-large-folder tool
# ec0bc9b verified
import logging
import torch
import torch.nn.functional as F
from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
from typing import Tuple
# kaldiio is an optional dependency: when it cannot be imported the name is
# bound to None so this module can still be imported. Only ImportError is
# caught, so unrelated failures (and KeyboardInterrupt/SystemExit) are not
# silently swallowed.
try:
    import kaldiio
except ImportError:
    kaldiio = None
import warnings
logger = logging.getLogger(__name__)
class ArkDataset(RawAudioDataset):
    """Raw-audio dataset indexed by a Kaldi ``scp`` file.

    Utterance keys come from ``wav_scp`` (loaded through ``kaldiio.load_scp``)
    and the per-utterance sizes are derived from a separate duration file.

    NOTE(review): ``__getitem__``, ``size``, ``postprocess`` and ``collater``
    have empty bodies in this file, so they currently return ``None``.
    """

    def __init__(
        self,
        wav_scp,
        dur_scp,
        sr = 24000,
        max_dur = 20,
        num_buckets=0,
        normalize=False,
    ):
        """Load the scp index and compute one size per utterance.

        Args:
            wav_scp: Path handed to ``kaldiio.load_scp``.
            dur_scp: Text file with one whitespace-separated record per line;
                the first field is the utterance key and the last field is
                parsed as a float duration. Blank lines are ignored, and
                lines whose key is not in ``wav_scp`` are skipped.
            sr: Sample rate, forwarded to the base class as ``sample_rate``.
            max_dur: Multiplied by ``sr`` to obtain ``max_sample_size``
                (presumably seconds -- confirm against callers).
            num_buckets: Forwarded to ``set_bucket_info``.
            normalize: Forwarded to the base class and stored on the instance.

        Raises:
            ImportError: If ``kaldiio`` could not be imported.
            KeyError: If a key of ``wav_scp`` has no entry in ``dur_scp``.
        """
        # Fail early with an actionable message instead of an AttributeError
        # on ``None.load_scp`` further down.
        if kaldiio is None:
            raise ImportError(
                "kaldiio is required by ArkDataset but could not be imported"
            )

        super().__init__(
            sample_rate=sr,
            max_sample_size=max_dur * sr,
            min_sample_size=1200,
            shuffle=True,
            pad=True,
            normalize=normalize,
            compute_mask=False,
        )
        self.sr = sr
        self.max_dur = max_dur
        self.normalize = normalize

        logger.info("Loading Kaldi scp files from %s", wav_scp)
        self.wav_data = kaldiio.load_scp(wav_scp)
        self.keys = list(self.wav_data.keys())

        # Only keep durations for keys that are present in the wav index.
        known_keys = set(self.keys)
        dur_data = {}
        with open(dur_scp, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # A blank line would otherwise raise IndexError below.
                    continue
                if fields[0] in known_keys:
                    dur_data[fields[0]] = float(fields[-1])

        missing = [k for k in self.keys if k not in dur_data]
        if missing:
            raise KeyError(
                "{} key(s) from {} have no duration in {} (first missing: {})".format(
                    len(missing), wav_scp, dur_scp, missing[0]
                )
            )

        # NOTE(review): the /100 suggests durations are stored in 1/100-second
        # units; confirm against whatever produces dur_scp.
        self.sizes = [int(dur_data[k] * self.sr / 100) for k in self.keys]
        logger.info("Loading Kaldi scp files done")

        self.dataset_len = len(self.keys)
        self.set_bucket_info(num_buckets)

    def __len__(self):
        """Return the number of keys loaded from ``wav_scp``."""
        return self.dataset_len

    def __getitem__(self, idx):
        """Unimplemented in this file: the body is empty and returns ``None``."""
        pass

    def size(self, idx):
        """Unimplemented in this file: the body is empty and returns ``None``."""
        pass

    def postprocess(self, wav):
        """Unimplemented in this file: the body is empty and returns ``None``."""
        pass

    def collater(self, samples):
        """Unimplemented in this file: the body is empty and returns ``None``."""
        pass