File size: 5,030 Bytes
2cba492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from dataclasses import dataclass
from pathlib import Path

import lightning as L
import torch
from torch.utils.data import DataLoader, Dataset

from ..util import get_logger
from .dataset import AudioItem, ChunkedAudioDataset, pad_audio

logger = get_logger()


@dataclass
class AudioBatch:
    """A collated batch of audio examples, as produced by ``audio_collate_fn``.

    Waveforms are right-padded to the longest sample in the batch before
    stacking; the remaining fields are per-element metadata lists aligned
    with the batch dimension.
    """

    waveform: torch.Tensor  # [batch, channels, samples]
    audio_ids: list[str]  # one identifier per batch element
    paths: list[Path]  # source audio file path per element
    sample_rates: list[int]  # per-element sample rate; presumably uniform when resampling is enabled — verify
    frame_offsets: list[int] | None  # For chunked audio


@dataclass
class AudioDataConfig:
    """Configuration for one dataset split (CSV manifest + loading options)."""

    csv_path: str  # manifest CSV listing the audio items
    audio_root: str  # directory that audio paths in the CSV are relative to

    # Audio processing
    sample_rate: int | None = 16000  # target rate; None presumably keeps native rate — confirm in ChunkedAudioDataset
    mono: bool = True  # downmix to a single channel
    normalize: bool = True  # amplitude normalization (semantics defined by ChunkedAudioDataset)

    # Chunking options
    chunk_size: int | None = None  # samples per chunk; None disables chunking
    chunk_hop_size: int | None = None  # hop between consecutive chunks

    # DataLoader options
    batch_size: int = 32
    num_workers: int = 4
    pin_memory: bool = False
    persistent_workers: bool = False  # only honored when num_workers > 0
    shuffle: bool = False  # NOTE(review): defaults to False — training configs must opt in explicitly
    drop_last: bool = False


def audio_collate_fn(batch: list[AudioItem]) -> AudioBatch:
    """Collate ``AudioItem``s into an ``AudioBatch``.

    Waveforms are padded (via ``pad_audio``) up to the longest waveform in
    the batch so they can be stacked into a single tensor; all metadata is
    gathered into parallel lists.
    """
    waves = [example.waveform for example in batch]

    # Only pad when lengths actually differ; otherwise stack as-is.
    target_len = max(w.shape[1] for w in waves)
    needs_padding = not all(w.shape[1] == target_len for w in waves)
    if needs_padding:
        waves = [pad_audio(w, target_len) for w in waves]

    return AudioBatch(
        waveform=torch.stack(waves),
        audio_ids=[example.audio_id for example in batch],
        paths=[example.path for example in batch],
        sample_rates=[example.sample_rate for example in batch],
        frame_offsets=[example.frame_offset for example in batch],
    )


class AudioDataModule(L.LightningDataModule):
    """Lightning data module that serves ``ChunkedAudioDataset`` splits.

    Each split (train/val/test) has its own ``AudioDataConfig``; missing
    configs fall back val -> train and test -> val. Datasets are built in
    ``setup()`` per Lightning's stage protocol.
    """

    def __init__(
        self,
        train_config: AudioDataConfig,
        val_config: AudioDataConfig | None = None,
        test_config: AudioDataConfig | None = None,
    ):
        super().__init__()
        self.train_config = train_config
        # Fallback chain: val defaults to train's config, test to val's.
        self.val_config = val_config or train_config
        self.test_config = test_config or self.val_config

        # Set to be initialized in setup()
        self.train_dataset: Dataset | None = None
        self.val_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

    def _create_dataset(self, config: AudioDataConfig) -> Dataset:
        """Build a ``ChunkedAudioDataset`` from a split config."""
        return ChunkedAudioDataset(
            csv_path=config.csv_path,
            audio_root=config.audio_root,
            chunk_size=config.chunk_size,
            hop_size=config.chunk_hop_size,
            mono=config.mono,
            normalize=config.normalize,
            target_sample_rate=config.sample_rate,
        )

    def setup(self, stage: str | None = None):
        """Create the datasets required for the given Lightning stage."""
        if stage == "fit" or stage is None:
            self.train_dataset = self._create_dataset(self.train_config)
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "validate":
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage in ("test", "predict"):
            self.test_dataset = self._create_dataset(self.test_config)

    def _make_dataloader(
        self,
        dataset: Dataset | None,
        config: AudioDataConfig,
        *,
        train: bool = False,
    ) -> DataLoader:
        """Shared DataLoader construction for all splits.

        Only the training loader honors ``shuffle``/``drop_last`` from the
        config; evaluation loaders always use deterministic, complete
        iteration.
        """
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            # persistent_workers requires num_workers > 0; disable otherwise.
            persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
            shuffle=config.shuffle if train else False,
            drop_last=config.drop_last if train else False,
            collate_fn=audio_collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_dataloader(self.train_dataset, self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_dataloader(self.val_dataset, self.val_config)

    def test_dataloader(self) -> DataLoader:
        return self._make_dataloader(self.test_dataset, self.test_config)

    def predict_dataloader(self) -> DataLoader:
        # Prediction reuses the test split's dataset and config.
        return self._make_dataloader(self.test_dataset, self.test_config)