Spaces:
Running
Running
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import lightning as L | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| from ..util import get_logger | |
| from .dataset import AudioItem, ChunkedAudioDataset, pad_audio | |
# Module-level logger obtained from the package's ``get_logger`` helper.
# NOTE(review): nothing in the visible part of this module uses it - confirm
# whether it is still needed.
logger = get_logger()
| class AudioBatch: | |
| waveform: torch.Tensor # [batch, channels, samples] | |
| audio_ids: list[str] | |
| paths: list[Path] | |
| sample_rates: list[int] | |
| frame_offsets: list[int] | None # For chunked audio | |
| class AudioDataConfig: | |
| csv_path: str | |
| audio_root: str | |
| # Audio processing | |
| sample_rate: int | None = 16000 | |
| mono: bool = True | |
| normalize: bool = True | |
| # Chunking options | |
| chunk_size: int | None = None | |
| chunk_hop_size: int | None = None | |
| # DataLoader options | |
| batch_size: int = 32 | |
| num_workers: int = 4 | |
| pin_memory: bool = False | |
| persistent_workers: bool = False | |
| shuffle: bool = False | |
| drop_last: bool = False | |
def audio_collate_fn(batch: list[AudioItem]) -> AudioBatch:
    """Merge a list of ``AudioItem`` objects into one ``AudioBatch``.

    When the waveforms do not all share the same size along dimension 1,
    every waveform is passed through ``pad_audio`` with the largest size as
    the target. The waveforms are then stacked along a new leading dimension
    and the per-item metadata is collected into parallel lists.

    Args:
        batch: Items to collate.

    Returns:
        The collated batch.
    """
    lengths = [item.waveform.shape[1] for item in batch]
    longest = max(lengths)

    # More than one distinct length means at least one item is shorter
    # than the longest, so the whole batch goes through pad_audio.
    if len(set(lengths)) > 1:
        tensors = [pad_audio(item.waveform, longest) for item in batch]
    else:
        tensors = [item.waveform for item in batch]

    ids = []
    locations = []
    rates = []
    offsets = []
    for item in batch:
        ids.append(item.audio_id)
        locations.append(item.path)
        rates.append(item.sample_rate)
        offsets.append(item.frame_offset)

    return AudioBatch(
        waveform=torch.stack(tensors),
        audio_ids=ids,
        paths=locations,
        sample_rates=rates,
        frame_offsets=offsets,
    )
class AudioDataModule(L.LightningDataModule):
    """Lightning data module that builds chunked audio datasets and loaders.

    Each split (train / val / test) has its own ``AudioDataConfig``. A split
    without an explicit config falls back to the previous one: validation
    reuses the training config, and test reuses the validation config.
    The predict loader shares the test dataset and test config.
    """

    def __init__(
        self,
        train_config: AudioDataConfig,
        val_config: AudioDataConfig | None = None,
        test_config: AudioDataConfig | None = None,
    ):
        """Store the per-split configs; datasets are created in ``setup()``.

        Args:
            train_config: Config for the training split.
            val_config: Config for the validation split; defaults to
                ``train_config``.
            test_config: Config for the test/predict split; defaults to the
                resolved validation config.
        """
        super().__init__()
        self.train_config = train_config
        self.val_config = val_config or train_config
        self.test_config = test_config or self.val_config

        # Set to be initialized in setup()
        self.train_dataset: Dataset | None = None
        self.val_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

    def _create_dataset(self, config: AudioDataConfig) -> Dataset:
        """Build a ``ChunkedAudioDataset`` from ``config``."""
        return ChunkedAudioDataset(
            csv_path=config.csv_path,
            audio_root=config.audio_root,
            chunk_size=config.chunk_size,
            hop_size=config.chunk_hop_size,
            mono=config.mono,
            normalize=config.normalize,
            target_sample_rate=config.sample_rate,
        )

    def setup(self, stage: str | None = None) -> None:
        """Create the datasets needed for ``stage``.

        ``"fit"`` (or ``None``) creates the train and validation datasets,
        ``"validate"`` creates only the validation dataset, and ``"test"`` or
        ``"predict"`` creates the test dataset. Any other value is a no-op.
        """
        if stage == "fit" or stage is None:
            self.train_dataset = self._create_dataset(self.train_config)
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "validate":
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "test" or stage == "predict":
            self.test_dataset = self._create_dataset(self.test_config)

    def _make_dataloader(
        self,
        dataset: Dataset | None,
        config: AudioDataConfig,
        split: str,
        *,
        shuffle: bool = False,
        drop_last: bool = False,
    ) -> DataLoader:
        """Build the DataLoader for one split from its dataset and config.

        Raises:
            RuntimeError: If ``dataset`` is still ``None``, i.e. ``setup()``
                has not been run for a stage that creates this split.
        """
        if dataset is None:
            # Fail here with a clear message instead of handing None to
            # DataLoader and getting an obscure error later.
            raise RuntimeError(
                f"The {split} dataset is not initialized; "
                "call setup() with a matching stage first."
            )
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            # Persistent workers are only requested when workers exist.
            persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=audio_collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training loader; honours the config's shuffle/drop_last."""
        return self._make_dataloader(
            self.train_dataset,
            self.train_config,
            "train",
            shuffle=self.train_config.shuffle,
            drop_last=self.train_config.drop_last,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (never shuffled, keeps the last batch)."""
        return self._make_dataloader(self.val_dataset, self.val_config, "val")

    def test_dataloader(self) -> DataLoader:
        """Return the test loader (never shuffled, keeps the last batch)."""
        return self._make_dataloader(self.test_dataset, self.test_config, "test")

    def predict_dataloader(self) -> DataLoader:
        """Return the predict loader, which reuses the test dataset and config."""
        return self._make_dataloader(self.test_dataset, self.test_config, "test")