| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import lightning as L |
| import torch |
| from torch.utils.data import DataLoader, Dataset |
|
|
| from ..util import get_logger |
| from .dataset import AudioItem, ChunkedAudioDataset, pad_audio |
|
|
| logger = get_logger() |
|
|
|
|
| @dataclass |
| class AudioBatch: |
| waveform: torch.Tensor |
| audio_ids: list[str] |
| paths: list[Path] |
| sample_rates: list[int] |
| frame_offsets: list[int] | None |
|
|
|
|
| @dataclass |
| class AudioDataConfig: |
| csv_path: str |
| audio_root: str |
|
|
| |
| sample_rate: int | None = 16000 |
| mono: bool = True |
| normalize: bool = True |
|
|
| |
| chunk_size: int | None = None |
| chunk_hop_size: int | None = None |
|
|
| |
| batch_size: int = 32 |
| num_workers: int = 4 |
| pin_memory: bool = False |
| persistent_workers: bool = False |
| shuffle: bool = False |
| drop_last: bool = False |
|
|
|
|
def audio_collate_fn(batch: list[AudioItem]) -> AudioBatch:
    """Merge a list of ``AudioItem`` objects into a single ``AudioBatch``.

    Waveform lengths are compared along dimension 1. When they are not all
    equal, every waveform is passed through ``pad_audio`` with the longest
    length as the target; the waveforms are then stacked into one tensor.
    Per-item metadata is gathered into lists that follow batch order.

    Args:
        batch: Items to combine. An empty list raises ``ValueError`` because
            the longest length cannot be determined.

    Returns:
        The collated batch.
    """
    waveforms = [item.waveform for item in batch]

    # Dimension 1 is the one that is equalised before stacking.
    # NOTE(review): assumes waveforms are laid out as (channels, frames) --
    # confirm against AudioItem.
    lengths = [wave.shape[1] for wave in waveforms]
    longest = max(lengths)

    all_same_length = all(length == longest for length in lengths)
    if not all_same_length:
        waveforms = [pad_audio(wave, longest) for wave in waveforms]

    stacked = torch.stack(waveforms)
    audio_ids = [item.audio_id for item in batch]
    paths = [item.path for item in batch]
    sample_rates = [item.sample_rate for item in batch]
    frame_offsets = [item.frame_offset for item in batch]

    return AudioBatch(
        waveform=stacked,
        audio_ids=audio_ids,
        paths=paths,
        sample_rates=sample_rates,
        frame_offsets=frame_offsets,
    )
|
|
|
|
class AudioDataModule(L.LightningDataModule):
    """Lightning data module that serves chunked audio for every stage.

    Each split has its own ``AudioDataConfig``. A missing validation config
    falls back to the training config, and a missing test config falls back
    to the (possibly already defaulted) validation config. Prediction reuses
    the test dataset and the test config.

    Datasets are created in ``setup`` and are ``None`` until the matching
    stage has been set up.
    """

    def __init__(
        self,
        train_config: AudioDataConfig,
        val_config: AudioDataConfig | None = None,
        test_config: AudioDataConfig | None = None,
    ):
        """Store the per-split configurations.

        Args:
            train_config: Configuration for the training split.
            val_config: Configuration for the validation split; defaults to
                ``train_config`` when omitted.
            test_config: Configuration for the test/predict split; defaults
                to the resolved validation configuration when omitted.
        """
        super().__init__()
        self.train_config = train_config
        self.val_config = val_config if val_config is not None else train_config
        self.test_config = test_config if test_config is not None else self.val_config

        # Populated by setup(); None means "stage not set up yet".
        self.train_dataset: Dataset | None = None
        self.val_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

    def _create_dataset(self, config: AudioDataConfig) -> Dataset:
        """Build a ``ChunkedAudioDataset`` from the dataset-related settings."""
        return ChunkedAudioDataset(
            csv_path=config.csv_path,
            audio_root=config.audio_root,
            chunk_size=config.chunk_size,
            hop_size=config.chunk_hop_size,
            mono=config.mono,
            normalize=config.normalize,
            target_sample_rate=config.sample_rate,
        )

    def _build_dataloader(
        self,
        dataset: Dataset | None,
        config: AudioDataConfig,
        *,
        split: str,
        shuffle: bool = False,
        drop_last: bool = False,
    ) -> DataLoader:
        """Create the ``DataLoader`` shared by all four dataloader hooks.

        Args:
            dataset: Dataset to serve, or ``None`` if it was never set up.
            config: Source of the loader settings (batch size, workers, ...).
            split: Human-readable split name, used only in the error message.
            shuffle: Whether to shuffle; only the training loader enables it.
            drop_last: Whether to drop the final incomplete batch; only the
                training loader enables it.

        Raises:
            RuntimeError: If ``dataset`` is ``None``, i.e. ``setup`` has not
                been run for the stage that creates it.
        """
        if dataset is None:
            # Fail here with a clear message rather than letting DataLoader
            # raise an obscure TypeError about NoneType later on.
            raise RuntimeError(
                f"The {split} dataset has not been created; call setup() "
                "for the corresponding stage before requesting its dataloader."
            )

        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            # DataLoader rejects persistent_workers=True without worker
            # processes, so force it off when num_workers is 0.
            persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=audio_collate_fn,
        )

    def setup(self, stage: str | None = None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` or ``None`` creates the training and validation
                datasets; ``"validate"`` creates only the validation dataset;
                ``"test"`` and ``"predict"`` create the test dataset. Any
                other value does nothing. Note that ``None`` does not create
                the test dataset.
        """
        if stage is None or stage == "fit":
            self.train_dataset = self._create_dataset(self.train_config)
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "validate":
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage in ("test", "predict"):
            self.test_dataset = self._create_dataset(self.test_config)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader; honours ``shuffle`` and ``drop_last``."""
        return self._build_dataloader(
            self.train_dataset,
            self.train_config,
            split="train",
            shuffle=self.train_config.shuffle,
            drop_last=self.train_config.drop_last,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (never shuffled, keeps all batches)."""
        return self._build_dataloader(self.val_dataset, self.val_config, split="validation")

    def test_dataloader(self) -> DataLoader:
        """Return the test loader (never shuffled, keeps all batches)."""
        return self._build_dataloader(self.test_dataset, self.test_config, split="test")

    def predict_dataloader(self) -> DataLoader:
        """Return the predict loader, which serves the test dataset."""
        return self._build_dataloader(self.test_dataset, self.test_config, split="test/predict")
|
|