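# NOTE: the next three lines are residue from the page this file was copied
# from (uploader caption, commit message, commit hash), kept here as comments:
#   cjayic's picture
#   init
#   f4b9544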
from pathlib import Path
import math
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchaudio
import torchaudio.transforms as transforms
class LogMelSpectrogram(torch.nn.Module):
    """Compute a log-compressed magnitude mel spectrogram of a waveform.

    The waveform is reflect-padded along its last dimension, passed through
    ``torchaudio.transforms.MelSpectrogram`` (``center=False``, ``power=1.0``,
    Slaney mel scale and normalisation), floored at ``1e-5`` and mapped through
    the natural logarithm.

    Args:
        sample_rate: Sample rate of the input audio in Hz.
        n_fft: FFT size; also used as the analysis window length.
        hop_length: Number of samples between successive frames.
        n_mels: Number of mel filterbank channels.

    Raises:
        ValueError: If ``hop_length`` is not in the range ``1..n_fft``.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        n_fft: int = 1024,
        hop_length: int = 160,
        n_mels: int = 128,
    ):
        super().__init__()
        if not 0 < hop_length <= n_fft:
            # A hop larger than the window would make the padding below
            # negative, which F.pad would silently interpret as cropping.
            raise ValueError(
                f"hop_length must be in 1..n_fft, got hop_length={hop_length}, n_fft={n_fft}"
            )

        # Framing is done with center=False, so the waveform is padded manually
        # in forward(). With the defaults (1024 - 160 is even) an input of
        # k * hop_length samples yields exactly k frames.
        self.padding = (n_fft - hop_length) // 2

        # NOTE: the attribute name (including its spelling) is kept unchanged so
        # existing attribute access and any saved module state keep working.
        self.melspctrogram = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            center=False,
            power=1.0,
            norm="slaney",
            onesided=True,
            n_mels=n_mels,
            mel_scale="slaney",
        )

    def forward(self, wav):
        """Return the log-mel spectrogram of ``wav``.

        Args:
            wav: Waveform tensor; padding and framing act on the last
                dimension. Reflect padding needs a batched input, e.g.
                ``(batch, channels, samples)`` -- TODO confirm against callers.

        Returns:
            Tensor of natural-log mel magnitudes; the mel and frame axes are
            the last two dimensions as produced by ``MelSpectrogram``.
        """
        wav = F.pad(wav, (self.padding, self.padding), "reflect")
        mel = self.melspctrogram(wav)
        # Floor the magnitudes before the log so silent bins do not give -inf.
        return torch.log(torch.clamp(mel, min=1e-5))
class MelDataset(Dataset):
    """Dataset of waveform segments paired with log-mel spectrograms.

    Files are discovered under ``root / "wavs"`` (``*.wav``) or, when
    ``finetune`` is set, under ``root / "mels"`` (``*.npy``), inside a
    ``train`` or ``dev`` directory depending on ``train``.

    Each item is a tuple ``(wav, src_logmel, tgt_logmel)``:

    * ``wav`` -- audio read with ``torchaudio.load``: a random
      ``segment_length``-sample crop when training, the whole file otherwise;
      zero-padded on the right if shorter than ``segment_length``.
    * ``tgt_logmel`` -- log-mel spectrogram computed from ``wav``.
    * ``src_logmel`` -- when fine-tuning, the spectrogram loaded from the
      matching ``.npy`` file (cropped to the region of ``wav`` while
      training); otherwise a copy of ``tgt_logmel``.

    Args:
        root: Dataset root containing the ``wavs`` (and ``mels``) directories.
        segment_length: Training segment length in samples.
        sample_rate: Sample rate every wav file is required to have.
        hop_length: Hop size in samples relating mel frames to wav samples.
        train: Use the ``train`` split with random cropping/augmentation;
            otherwise use the ``dev`` split.
        finetune: Pair audio with precomputed spectrograms from ``mels``.
    """

    def __init__(
        self,
        root: Path,
        segment_length: int,
        sample_rate: int,
        hop_length: int,
        train: bool = True,
        finetune: bool = False,
    ):
        self.wavs_dir = root / "wavs"
        self.mels_dir = root / "mels"
        self.data_dir = self.mels_dir if finetune else self.wavs_dir

        self.segment_length = segment_length
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.train = train
        self.finetune = finetune

        suffix = ".npy" if finetune else ".wav"
        pattern = self._glob_pattern(train, suffix)
        # Entries are paths relative to data_dir with the extension removed,
        # so the same entry addresses both the wav and the mel file. Sorting
        # makes the index -> file mapping independent of filesystem order.
        self.metadata = sorted(
            path.relative_to(self.data_dir).with_suffix("")
            for path in self.data_dir.rglob(pattern)
        )

        self.logmel = LogMelSpectrogram()

    @staticmethod
    def _glob_pattern(train: bool, suffix: str) -> str:
        """Return the glob pattern selecting the files of one split."""
        # Both branches must interpolate the suffix; a plain (non-f) string
        # here would search for the literal text "{suffix}".
        split = "train" if train else "dev"
        return f"{split}/**/*{suffix}"

    @staticmethod
    def _with_extension(path: Path, extension: str) -> Path:
        """Append ``extension`` to ``path``, keeping any dots in its name."""
        # Path.with_suffix would replace everything after the last dot, turning
        # a stem such as "utt.v2" into "utt.wav" instead of "utt.v2.wav".
        return path.with_name(path.name + extension)

    def __len__(self):
        """Return the number of discovered files."""
        return len(self.metadata)

    def __getitem__(self, index):
        """Load one ``(wav, src_logmel, tgt_logmel)`` example.

        Raises:
            ValueError: If the wav file's sample rate differs from
                ``self.sample_rate``.
        """
        path = self.metadata[index]
        wav_path = self._with_extension(self.wavs_dir / path, ".wav")

        info = torchaudio.info(wav_path)
        if info.sample_rate != self.sample_rate:
            raise ValueError(
                f"Sample rate {info.sample_rate} doesn't match target of {self.sample_rate}"
            )

        if self.finetune:
            mel_path = self._with_extension(self.mels_dir / path, ".npy")
            # Leading axis added to match the layout of tgt_logmel; presumably
            # the stored array is (n_mels, frames) -- TODO confirm.
            src_logmel = torch.from_numpy(np.load(mel_path)).unsqueeze(0)

            # Choose the crop on the mel grid, then convert it to samples so
            # the audio crop starts on a frame boundary.
            mel_frames_per_segment = math.ceil(self.segment_length / self.hop_length)
            mel_diff = src_logmel.size(-1) - mel_frames_per_segment if self.train else 0
            mel_offset = random.randint(0, max(mel_diff, 0))
            frame_offset = self.hop_length * mel_offset
        else:
            frame_diff = info.num_frames - self.segment_length
            frame_offset = random.randint(0, max(frame_diff, 0))

        # The path is passed positionally because the name of torchaudio.load's
        # first parameter differs between torchaudio releases.
        wav, _ = torchaudio.load(
            wav_path,
            frame_offset=frame_offset if self.train else 0,
            num_frames=self.segment_length if self.train else -1,
        )

        if wav.size(-1) < self.segment_length:
            wav = F.pad(wav, (0, self.segment_length - wav.size(-1)))

        if not self.finetune and self.train:
            # Augmentation: peak-normalise, scale by a random gain in
            # [0.4, 0.99) and randomly invert the polarity.
            gain = random.random() * (0.99 - 0.4) + 0.4
            flip = -1 if random.random() > 0.5 else 1
            # Clamp the peak so an all-zero segment stays silent instead of
            # turning into NaNs through a division by zero.
            peak = wav.abs().max().clamp(min=1e-8)
            wav = flip * gain * wav / peak

        tgt_logmel = self.logmel(wav.unsqueeze(0)).squeeze(0)

        if self.finetune:
            if self.train:
                src_logmel = src_logmel[
                    :, :, mel_offset : mel_offset + mel_frames_per_segment
                ]

            if src_logmel.size(-1) < mel_frames_per_segment:
                # Pad short spectrograms with their own minimum value, i.e.
                # the quietest level present in the log-mel.
                src_logmel = F.pad(
                    src_logmel,
                    (0, mel_frames_per_segment - src_logmel.size(-1)),
                    "constant",
                    src_logmel.min().item(),
                )
        else:
            src_logmel = tgt_logmel.clone()

        return wav, src_logmel, tgt_logmel