Dalzymodderever
Initial Commit
2cba492
from dataclasses import dataclass
from pathlib import Path
import lightning as L
import torch
from torch.utils.data import DataLoader, Dataset
from ..util import get_logger
from .dataset import AudioItem, ChunkedAudioDataset, pad_audio
logger = get_logger()
@dataclass
class AudioBatch:
waveform: torch.Tensor # [batch, channels, samples]
audio_ids: list[str]
paths: list[Path]
sample_rates: list[int]
frame_offsets: list[int] | None # For chunked audio
@dataclass
class AudioDataConfig:
csv_path: str
audio_root: str
# Audio processing
sample_rate: int | None = 16000
mono: bool = True
normalize: bool = True
# Chunking options
chunk_size: int | None = None
chunk_hop_size: int | None = None
# DataLoader options
batch_size: int = 32
num_workers: int = 4
pin_memory: bool = False
persistent_workers: bool = False
shuffle: bool = False
drop_last: bool = False
def audio_collate_fn(batch: list[AudioItem]) -> AudioBatch:
    """Combine individual audio items into a single ``AudioBatch``.

    Waveforms are brought to a common length along dimension 1 (via
    ``pad_audio``) whenever their lengths differ, then stacked along a new
    leading batch dimension. The per-item metadata is gathered into parallel
    lists that follow the order of ``batch``.

    Args:
        batch: Items to collate.

    Returns:
        The collated batch.

    Raises:
        ValueError: If ``batch`` is empty (raised by ``max``).
    """
    lengths = [item.waveform.shape[1] for item in batch]
    target_length = max(lengths)

    if min(lengths) == target_length:
        # Every waveform already has the same length; stack them untouched.
        tensors = [item.waveform for item in batch]
    else:
        tensors = [pad_audio(item.waveform, target_length) for item in batch]

    ids, locations, rates, offsets = [], [], [], []
    for item in batch:
        ids.append(item.audio_id)
        locations.append(item.path)
        rates.append(item.sample_rate)
        offsets.append(item.frame_offset)

    return AudioBatch(
        waveform=torch.stack(tensors),
        audio_ids=ids,
        paths=locations,
        sample_rates=rates,
        frame_offsets=offsets,
    )
class AudioDataModule(L.LightningDataModule):
    """Lightning data module that serves ``ChunkedAudioDataset`` splits.

    Each split (train / validation / test) is described by an
    ``AudioDataConfig``. Missing configs fall back in a chain: validation
    falls back to the training config, and test falls back to the (possibly
    already defaulted) validation config. The predict loader reuses the test
    dataset and test config.
    """

    def __init__(
        self,
        train_config: AudioDataConfig,
        val_config: AudioDataConfig | None = None,
        test_config: AudioDataConfig | None = None,
    ):
        """Store the per-split configs; datasets are created later in ``setup``.

        Args:
            train_config: Configuration for the training split.
            val_config: Configuration for the validation split; defaults to
                ``train_config``.
            test_config: Configuration for the test/predict split; defaults
                to the resolved validation config.
        """
        super().__init__()
        self.train_config = train_config
        self.val_config = val_config or train_config
        self.test_config = test_config or self.val_config

        # Populated by setup(); None until the matching stage has run.
        self.train_dataset: Dataset | None = None
        self.val_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

    def _create_dataset(self, config: AudioDataConfig) -> Dataset:
        """Build a ``ChunkedAudioDataset`` from the dataset fields of ``config``."""
        return ChunkedAudioDataset(
            csv_path=config.csv_path,
            audio_root=config.audio_root,
            chunk_size=config.chunk_size,
            hop_size=config.chunk_hop_size,
            mono=config.mono,
            normalize=config.normalize,
            target_sample_rate=config.sample_rate,
        )

    def setup(self, stage: str | None = None):
        """Create the datasets required by ``stage``.

        Args:
            stage: ``"fit"`` (or ``None``) builds the train and validation
                datasets, ``"validate"`` builds only the validation dataset,
                and ``"test"`` / ``"predict"`` build the test dataset. Any
                other value does nothing.
        """
        # NOTE(review): stage=None is treated like "fit", so the test dataset
        # is NOT built in that case; test/predict need an explicit stage.
        if stage == "fit" or stage is None:
            self.train_dataset = self._create_dataset(self.train_config)
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "validate":
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "test" or stage == "predict":
            self.test_dataset = self._create_dataset(self.test_config)

    def _build_dataloader(
        self,
        dataset: Dataset | None,
        config: AudioDataConfig,
        *,
        split: str,
        shuffle: bool = False,
        drop_last: bool = False,
    ) -> DataLoader:
        """Construct the DataLoader shared by all ``*_dataloader`` hooks.

        Args:
            dataset: Dataset to serve; must already have been created.
            config: Source of the batch size and worker/memory options.
            split: Human-readable split name, used only in the error message.
            shuffle: Whether the loader shuffles the dataset.
            drop_last: Whether the loader drops the last incomplete batch.

        Returns:
            A DataLoader that collates items with ``audio_collate_fn``.

        Raises:
            RuntimeError: If ``dataset`` is None, i.e. ``setup`` has not been
                run for a stage that creates it.
        """
        if dataset is None:
            # Fail early with an actionable message instead of handing None
            # to DataLoader.
            raise RuntimeError(
                f"The {split} dataset has not been created; call setup() with "
                "a matching stage before requesting this dataloader."
            )
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            # Persistent workers are only requested when workers exist.
            persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=audio_collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training loader; honours the config's shuffle/drop_last."""
        return self._build_dataloader(
            self.train_dataset,
            self.train_config,
            split="train",
            shuffle=self.train_config.shuffle,
            drop_last=self.train_config.drop_last,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (never shuffled, keeps the last batch)."""
        return self._build_dataloader(self.val_dataset, self.val_config, split="validation")

    def test_dataloader(self) -> DataLoader:
        """Return the test loader (never shuffled, keeps the last batch)."""
        return self._build_dataloader(self.test_dataset, self.test_config, split="test")

    def predict_dataloader(self) -> DataLoader:
        """Return the predict loader, which serves the test dataset and config."""
        return self._build_dataloader(self.test_dataset, self.test_config, split="test")