# warisqr007's picture
# Add application file
# eb9c81a
from typing import Any
import random
from pathlib import Path
import librosa
import numpy as np
import torch
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
def load_audio(full_path, sampling_rate=16000):
    """Read an audio file via ``librosa.load``.

    Args:
        full_path: Location of the audio file; forwarded unchanged to
            ``librosa.load``.
        sampling_rate: Value handed to ``librosa.load`` as ``sr``
            (16000 when omitted).

    Returns:
        The ``(data, sampling_rate)`` pair that ``librosa.load`` produced.
    """
    waveform, loaded_rate = librosa.load(full_path, sr=sampling_rate)
    return waveform, loaded_rate
class ResynthesisDataset(Dataset):
    """Dataset of randomly cropped, peak-normalised waveform segments.

    Indexing returns ``(audio, path)`` where ``audio`` is a 1-D float tensor
    cropped at a random offset from the file at ``path`` and ``path`` is that
    file's path as a string.

    Args:
        training_files: Indexable, sized collection of audio file paths.
        segment_size: Number of samples per returned crop. When ``<= 0`` the
            whole (trimmed) signal is returned instead of a crop.
        code_hop_size: Hop size in samples; every signal is trimmed to a
            whole multiple of it.
        sampling_rate: Sampling rate requested from ``load_audio``.
    """

    def __init__(
        self,
        training_files,
        segment_size,
        code_hop_size,
        sampling_rate
    ):
        self.audio_files = training_files
        self.segment_size = segment_size
        self.code_hop_size = code_hop_size
        self.sampling_rate = sampling_rate
        # NOTE: this reseeds the process-wide ``random`` module, not a
        # generator private to this dataset.
        random.seed(1234)

    def _sample_interval(self, seqs, seq_len=None):
        """Cut the same random time window out of every sequence in ``seqs``.

        Time is taken to be the last axis. The longest sequence defines the
        reference resolution; shorter ones are indexed at a proportionally
        coarser step (assumes each length divides the longest one evenly --
        TODO confirm for multi-sequence callers).

        Args:
            seqs: List of tensors/arrays sliceable along their last axis.
            seq_len: Window length in samples of the longest sequence.
                Defaults to ``self.segment_size``, or to the full length
                when ``segment_size <= 0``.

        Returns:
            A list with one cropped sequence per input, in the same order.

        Raises:
            ValueError: From ``random.randint`` when the longest sequence is
                shorter than ``seq_len``.
        """
        longest = max(v.shape[-1] for v in seqs)
        if seq_len is None:
            seq_len = self.segment_size if self.segment_size > 0 else longest

        # hops[i] is how many steps of the longest sequence one step of
        # seqs[i] covers; windows are aligned on the lcm of all hops so each
        # sequence is cut on one of its own step boundaries.
        hops = [longest // v.shape[-1] for v in seqs]
        lcm = int(np.lcm.reduce(hops))

        window = seq_len // lcm
        # Pick the window start uniformly among all positions that still fit.
        start_step = random.randint(0, longest // lcm - window)

        cropped = []
        for hop, v in zip(hops, seqs):
            scale = lcm // hop
            cropped.append(v[..., start_step * scale:(start_step + window) * scale])
        return cropped

    def _prepare_segment(self, audio):
        """Normalise, trim, pad and randomly crop one waveform.

        Args:
            audio: 1-D numpy array of samples.

        Returns:
            A 1-D ``torch.FloatTensor``; its length is ``segment_size`` when
            ``segment_size > 0``, otherwise the trimmed signal length.

        Raises:
            ValueError: If ``audio`` holds fewer than ``code_hop_size``
                samples, so nothing is left after trimming.
        """
        audio = np.asarray(audio)
        if audio.shape[0] == 0:
            raise ValueError("audio is empty")

        # Scale so the peak magnitude is 0.9; the epsilon guards against
        # division by zero on all-silent input.
        audio = audio / (np.abs(audio).max() + 0.00001) * 0.9

        # Trim the tail so the length is a whole number of code hops.
        code_length = audio.shape[0] // self.code_hop_size
        audio = audio[:code_length * self.code_hop_size]
        if audio.shape[0] == 0:
            # Doubling an empty array below would never terminate.
            raise ValueError(
                f"audio is shorter than one code hop ({self.code_hop_size} samples)"
            )

        # Short clips are repeated back to back until they can fill a segment.
        while audio.shape[0] < self.segment_size:
            audio = np.hstack([audio, audio])

        audio = torch.FloatTensor(audio).unsqueeze(0)
        assert audio.size(1) >= self.segment_size, "Padding not supported!!"

        # _sample_interval returns a list with one entry per input sequence.
        (audio,) = self._sample_interval([audio])
        return audio.squeeze(0)

    def __getitem__(self, index):
        """Load file ``index`` and return ``(audio_segment, path_string)``.

        Raises:
            ValueError: If the file holds too little audio to build a
                segment; the message is prefixed with the file path.
        """
        wav_fpath = self.audio_files[index]
        audio, sampling_rate = load_audio(wav_fpath, self.sampling_rate)

        if sampling_rate != self.sampling_rate:
            # Fallback in case the loader did not deliver the requested rate.
            import resampy
            audio = resampy.resample(audio, sampling_rate, self.sampling_rate)

        try:
            segment = self._prepare_segment(audio)
        except ValueError as err:
            raise ValueError(f"{wav_fpath}: {err}") from err
        return segment, str(wav_fpath)

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.audio_files)
class PasrMultilingualDataModule(pl.LightningDataModule):
    """LightningDataModule serving waveform crops from a tree of ``.wav`` files.

    ``setup`` recursively collects every ``*.wav`` below ``data_dir``, splits
    the list into train / validation / test subsets with fixed seeds and
    wraps each subset in a ``ResynthesisDataset``. The ``*_dataloader``
    methods expose those datasets using the configured batch size, worker
    count and pin-memory flag.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = "data",
        batch_size: int = 16,
        num_workers: int = 4,
        pin_memory: bool = True,
        segment_size: int = 20480,
        code_hop_size: int = 320,
        sampling_rate: int = 16000,
    ):
        """Store the configuration; no data is touched here.

        Args:
            data_dir: Root directory searched recursively for ``*.wav`` files.
            batch_size: Batch size of every dataloader.
            num_workers: Worker count of every dataloader.
            pin_memory: ``pin_memory`` flag of every dataloader.
            segment_size: Crop length in samples passed to the datasets.
            code_hop_size: Hop size in samples passed to the datasets.
            sampling_rate: Sampling rate passed to the datasets.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters()

        # Populated by setup(); None means "not built yet".
        self.data_train: Dataset = None
        self.data_val: Dataset = None
        self.data_test: Dataset = None

    @property
    def num_classes(self):
        """Return the ``num_codes`` hyperparameter, or ``None`` if absent.

        ``__init__`` does not define ``num_codes``, so a plain attribute
        lookup on ``self.hparams`` would raise ``AttributeError``.
        """
        return self.hparams.get("num_codes")

    def prepare_data(self):
        """Download data if needed.

        Do not use it to assign state (self.x = y).
        """
        pass

    def _build_dataset(self, files):
        """Wrap ``files`` in a ``ResynthesisDataset`` using the stored hparams."""
        return ResynthesisDataset(
            training_files=files,
            segment_size=self.hparams.segment_size,
            code_hop_size=self.hparams.code_hop_size,
            sampling_rate=self.hparams.sampling_rate,
        )

    def _build_loader(self, dataset, shuffle):
        """Create a ``DataLoader`` over ``dataset`` using the stored hparams."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def setup(self, stage: str = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning with both `trainer.fit()` and
        `trainer.test()`, so the file scan and the split run only on the
        first call; later calls return immediately.

        Args:
            stage: Stage name supplied by the caller; not used here.

        Raises:
            ValueError: From ``train_test_split`` when too few ``.wav`` files
                are found to form non-empty splits.
        """
        already_built = (
            self.data_train is not None
            or self.data_val is not None
            or self.data_test is not None
        )
        if already_built:
            return

        # rglob order depends on the filesystem; sorting makes the seeded
        # split below reproducible across runs and machines.
        wav_files = sorted(Path(self.hparams.data_dir).rglob("*.wav"))

        # Hold out 0.1% of all files for validation ...
        remaining, self.validation_files = train_test_split(
            wav_files, test_size=0.001, random_state=42
        )
        # ... and 0.01% of what is left for testing.
        self.training_files, self.test_files = train_test_split(
            remaining, test_size=0.0001, random_state=42
        )

        self.data_train = self._build_dataset(self.training_files)
        self.data_val = self._build_dataset(self.validation_files)
        self.data_test = self._build_dataset(self.test_files)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return self._build_loader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return self._build_loader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test split."""
        return self._build_loader(self.data_test, shuffle=False)

    def teardown(self, stage: str = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
if __name__ == "__main__":
    # Smoke test: build the data module with its default settings and
    # inspect the first training batch.
    dm = PasrMultilingualDataModule()
    dm.prepare_data()
    dm.setup()
    for batch in dm.train_dataloader():
        # batch[0] holds the stacked audio segments.
        print(batch[0].shape)
        # batch[1] holds the file paths as strings, which have no ``.shape``
        # (presumably collated into a list -- confirm against default_collate).
        print(len(batch[1]))
        break