| from dataclasses import dataclass
|
| from pathlib import Path
|
| from typing import Optional
|
|
|
| import librosa
|
| import numpy as np
|
| import torch
|
| from lightning import LightningDataModule
|
| from torch.utils.data import DataLoader, Dataset
|
|
|
| from fish_speech.utils import RankedLogger
|
|
|
| logger = RankedLogger(__name__, rank_zero_only=False)
|
|
|
|
|
class VQGANDataset(Dataset):
    """Waveform dataset for VQGAN training, driven by a plain-text filelist.

    Each non-empty line of the filelist is an audio path resolved relative
    to the filelist's own directory. Items are loaded as mono waveforms at
    ``sample_rate`` and, when ``slice_frames`` is set, randomly cropped to
    at most ``slice_frames * hop_length`` samples.
    """

    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist_path = Path(filelist)
        base_dir = filelist_path.parent

        # Skip blank lines; entries are relative to the filelist location.
        entries = filelist_path.read_text(encoding="utf-8").splitlines()
        self.files = [base_dir / entry.strip() for entry in entries if entry.strip()]

        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        """Load and preprocess one item; returns None for empty audio."""
        path = self.files[idx]

        audio, _ = librosa.load(path, sr=self.sample_rate, mono=True)

        # Randomly crop clips longer than the requested frame budget.
        if self.slice_frames is not None:
            max_samples = self.slice_frames * self.hop_length
            if audio.shape[0] > max_samples:
                offset = np.random.randint(0, audio.shape[0] - max_samples)
                audio = audio[offset : offset + max_samples]

        if len(audio) == 0:
            return None

        # Peak-normalize only signals that clip beyond [-1, 1].
        peak = np.abs(audio).max()
        if peak > 1.0:
            audio = audio / peak

        return {
            "audio": torch.from_numpy(audio),
        }

    def __getitem__(self, idx):
        # A single corrupt file must not kill a long training run:
        # log the failure and let the collator drop the None.
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None
|
|
|
|
|
@dataclass
class VQGANCollator:
    """Collate variable-length audio items into a zero-padded batch.

    Items that failed to load (``None`` from ``VQGANDataset.__getitem__``)
    are dropped; the remaining waveforms are right-padded with zeros to the
    longest length in the batch.

    Returns a dict with:
        audios: float tensor of shape (batch, max_len)
        audio_lengths: int tensor of original (unpadded) lengths
    """

    def __call__(self, batch):
        # Drop items the dataset failed to load.
        batch = [x for x in batch if x is not None]
        if not batch:
            # Previously this fell through to .max() on an empty tensor,
            # raising an opaque RuntimeError; fail with a clear message.
            raise ValueError("All items in the batch failed to load")

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        audio_maxlen = audio_lengths.max()

        # Right-pad each waveform with zeros to the batch maximum.
        audios = [
            torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            for x in batch
        ]

        return {
            "audios": torch.stack(audios),
            "audio_lengths": audio_lengths,
        }
|
|
|
|
|
class VQGANDataModule(LightningDataModule):
    """Lightning data module wiring VQGAN train/val datasets to loaders."""

    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        # Validation falls back to the training batch size when unset.
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def _build_loader(self, dataset, batch_size, shuffle):
        # Shared DataLoader construction for both splits.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=shuffle,
            persistent_workers=True,
        )

    def train_dataloader(self):
        return self._build_loader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset, self.val_batch_size, shuffle=False)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: iterate one batch from the training filelist.
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # Fix: the collator only emits "audios" and "audio_lengths";
        # the old prints of batch["features"] / batch["feature_lengths"]
        # referenced keys that are never produced and raised KeyError.
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break
|
|
|