Spaces:
Sleeping
Sleeping
| from typing import Any | |
| import random | |
| from pathlib import Path | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from sklearn.model_selection import train_test_split | |
| import pytorch_lightning as pl | |
| from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split | |
def load_audio(full_path, sampling_rate=16000):
    """Load an audio file via librosa, resampled to ``sampling_rate``.

    Returns:
        tuple: (waveform as a float ndarray, the sampling rate used).
    """
    waveform, sr = librosa.load(full_path, sr=sampling_rate)
    return waveform, sr
class ResynthesisDataset(Dataset):
    """Dataset of wav files that yields fixed-length, peak-normalized
    audio segments aligned to a code/token hop grid."""

    def __init__(
        self,
        training_files,
        segment_size,
        code_hop_size,
        sampling_rate
    ):
        """
        Args:
            training_files: list of paths to .wav files.
            segment_size: samples per returned segment (<= 0 keeps full length).
            code_hop_size: hop size in samples of the code sequence the audio
                must align to; audio length is trimmed to a multiple of it.
            sampling_rate: target sampling rate; mismatched files are resampled.
        """
        self.audio_files = training_files
        self.segment_size = segment_size
        self.code_hop_size = code_hop_size
        self.sampling_rate = sampling_rate
        # Fixed seed so random segment sampling is reproducible across runs.
        random.seed(1234)

    def _sample_interval(self, seqs, seq_len=None):
        """Cut the same (hop-aligned) random window out of every sequence.

        Args:
            seqs: tensors whose last dims are integer divisors of the longest
                one (e.g. waveform plus downsampled code sequences).
            seq_len: window length in samples of the longest sequence;
                defaults to ``self.segment_size``.

        Returns:
            list: one windowed tensor per input, in the same order.
        """
        N = max(v.shape[-1] for v in seqs)
        if seq_len is None:
            seq_len = self.segment_size if self.segment_size > 0 else N

        # Hop of each sequence relative to the longest one; windows must
        # start on a multiple of their least common multiple to stay aligned.
        hops = [N // v.shape[-1] for v in seqs]
        lcm = np.lcm.reduce(hops)

        # Randomly pick up the batch_max_steps-length part on the lcm grid.
        interval_start = 0
        interval_end = N // lcm - seq_len // lcm
        start_step = random.randint(interval_start, interval_end)

        new_seqs = []
        for i, v in enumerate(seqs):
            start = start_step * (lcm // hops[i])
            end = (start_step + seq_len // lcm) * (lcm // hops[i])
            new_seqs.append(v[..., start:end])
        return new_seqs

    def __getitem__(self, index):
        """Return (segment_size-sample FloatTensor, wav path string)."""
        wav_fpath = self.audio_files[index]
        audio, sampling_rate = load_audio(wav_fpath, self.sampling_rate)
        if sampling_rate != self.sampling_rate:
            # Local import: resampy is only needed when rates mismatch.
            import resampy
            audio = resampy.resample(audio, sampling_rate, self.sampling_rate)

        # Peak-normalize to ~0.9; epsilon avoids division by zero on silence.
        audio = audio / (np.max(np.abs(audio)) + 0.00001) * 0.9

        # Trim audio ending so its length is an exact multiple of the hop.
        # BUG FIX: the original computed min(..., tokens.shape[-1]) with an
        # undefined name `tokens`, raising NameError on every item.
        code_length = audio.shape[0] // self.code_hop_size
        audio = audio[:code_length * self.code_hop_size]

        # Guard: an empty clip would make the tiling loop below spin forever.
        if audio.shape[0] == 0:
            raise ValueError(
                f"{wav_fpath} is shorter than one code hop "
                f"({self.code_hop_size} samples)"
            )

        # Tile short clips until they reach segment_size.
        while audio.shape[0] < self.segment_size:
            audio = np.hstack([audio, audio])

        audio = torch.FloatTensor(audio)
        audio = audio.unsqueeze(0)
        assert audio.size(1) >= self.segment_size, "Padding not supported!!"

        # BUG FIX: _sample_interval returns a list; the original called
        # .squeeze(0) directly on the list, raising AttributeError.
        audio = self._sample_interval([audio])[0]
        return audio.squeeze(0), str(wav_fpath)

    def __len__(self):
        return len(self.audio_files)
class PasrMultilingualDataModule(pl.LightningDataModule):
    """LightningDataModule that discovers ``*.wav`` files under ``data_dir``
    and serves them as ``ResynthesisDataset`` train/val/test splits.

    A DataModule implements 5 key methods:
        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = "data",
        batch_size: int = 16,
        num_workers: int = 4,
        pin_memory: bool = True,
        segment_size: int = 20480,
        code_hop_size: int = 320,
        sampling_rate: int = 16000,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters()

        # data transformations
        # self.transforms = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

        self.data_train: Dataset = None
        self.data_val: Dataset = None
        self.data_test: Dataset = None

    def num_classes(self):
        # NOTE(review): `num_codes` is not an __init__ parameter, so it is not
        # in self.hparams and this call raises AttributeError — confirm where
        # num_codes is supposed to come from before relying on this method.
        return self.hparams.num_codes

    def prepare_data(self):
        """Download data if needed.

        Do not use it to assign state (self.x = y).
        """
        pass

    def setup(self, stage: str = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute things like random split twice!
        """
        # Carve off ~0.1% of files for validation, then ~0.01% of the
        # remainder for test (fixed seed for reproducible splits).
        training_files = list(Path(self.hparams.data_dir).rglob("*.wav"))
        training_files, self.validation_files, _, _ = train_test_split(
            training_files, training_files, test_size=0.001, random_state=42
        )
        self.training_files, self.test_files, _, _ = train_test_split(
            training_files, training_files, test_size=0.0001, random_state=42
        )

        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            # BUG FIX: the original instantiated `PasrMultilingualDataset`,
            # which is not defined anywhere in this file; the dataset class
            # defined above is `ResynthesisDataset`.
            self.data_train = ResynthesisDataset(
                training_files=self.training_files,
                segment_size=self.hparams.segment_size,
                code_hop_size=self.hparams.code_hop_size,
                sampling_rate=self.hparams.sampling_rate,
            )
            self.data_val = ResynthesisDataset(
                training_files=self.validation_files,
                segment_size=self.hparams.segment_size,
                code_hop_size=self.hparams.code_hop_size,
                sampling_rate=self.hparams.sampling_rate,
            )
            self.data_test = ResynthesisDataset(
                training_files=self.test_files,
                segment_size=self.hparams.segment_size,
                code_hop_size=self.hparams.code_hop_size,
                sampling_rate=self.hparams.sampling_rate,
            )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: str = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
if __name__ == "__main__":
    # Smoke test: build the DataModule and pull one batch.
    # BUG FIX: the original constructed ResynthesisDataset() with none of its
    # four required arguments and then called DataModule-only methods
    # (prepare_data/setup/train_dataloader) on it.
    dm = PasrMultilingualDataModule()
    dm.prepare_data()
    dm.setup()
    for batch in dm.train_dataloader():
        audio, paths = batch
        print(audio.shape)
        # BUG FIX: the second batch element is a list of path strings after
        # default collation, so it has no .shape; report its length instead.
        print(len(paths))
        break