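"""Mock AudioSet data pipeline for smoke-testing training code.

Provides MockAudioSetDataset, which yields random-noise waveforms with random
multi-hot targets, and MockAudioSetDataModule, which wraps it with configurable
pad/truncate batch collation.
"""
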
import torch
from torch.utils.data import Dataset, DataLoader
import lightning as L
from typing import Optional, List, Dict, Any, Union
from functools import partial


class MockAudioSetDataset(Dataset):
    """
    Mock Dataset for AudioSet data that generates random noise.
    """

    def __init__(
        self,
        length: int = 100,
        max_length: int = 160000,
        target_sample_rate: int = 16000,
    ):
        self.length = length
        self.max_length = max_length
        # Kept for interface parity with a real AudioSet dataset; the mock
        # never resamples, it simply generates noise at this nominal rate.
        self.target_sample_rate = target_sample_rate

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, str, int]]:
        # Generate a fixed-length random waveform of shape [1, T]. Real clips
        # vary in length; a fixed max_length keeps the mock simple.
        waveform = torch.randn(1, self.max_length)

        # Fake multi-hot target; AudioSet has 527 classes.
        target = torch.zeros(527)
        # Mark three random classes as present (duplicate indices are
        # possible, which is harmless in a mock).
        indices = torch.randint(0, 527, (3,))
        target[indices] = 1.0

        audio_name = f"mock_audio_{idx}"

        return {
            "waveform": waveform,
            "target": target,
            "audio_name": audio_name,
            "index": idx,
        }


class MockAudioSetDataModule(L.LightningDataModule):
    """
    LightningDataModule for Mock AudioSet.
    """

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        pin_memory: bool = True,
        max_audio_length_sec: float = 10.0,
        target_sample_rate: int = 16000,
        collate_mode: str = "pad",
    ):
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.collate_mode = collate_mode

        self.max_audio_length = int(max_audio_length_sec * target_sample_rate)

        self.train_dataset: Optional[MockAudioSetDataset] = None
        self.val_dataset: Optional[MockAudioSetDataset] = None
        self.test_dataset: Optional[MockAudioSetDataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        if stage in ("fit", None):
            self.train_dataset = MockAudioSetDataset(
                length=1000,  # fake dataset size
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )
        # Also build the val set for `trainer.validate()`, which calls
        # setup(stage="validate") without ever passing "fit".
        if stage in ("fit", "validate", None):
            self.val_dataset = MockAudioSetDataset(
                length=100,
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )

        if stage in ("test", None):
            self.test_dataset = MockAudioSetDataset(
                length=50,
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]], mode: str = "pad") -> Dict[str, Any]:
        """
        Collate function to pad or truncate waveforms.
        """
        waveforms = [item["waveform"] for item in batch]  # List of [1, T]
        targets = torch.stack([item["target"] for item in batch])
        audio_names = [item["audio_name"] for item in batch]
        indices = [item["index"] for item in batch]

        # Find max or min length in the batch
        lengths = [w.shape[-1] for w in waveforms]

        if mode == "pad":
            target_wave_len = max(lengths)
        elif mode == "truncate":
            target_wave_len = min(lengths)
        else:
            raise ValueError(f"Unknown collate mode: {mode}")

        # Pad or truncate each waveform to the common target length
        processed_waveforms = []
        for w in waveforms:
            current_len = w.shape[-1]
            if current_len < target_wave_len:
                pad_amount = target_wave_len - current_len
                # Pad at the end
                w_padded = torch.nn.functional.pad(w, (0, pad_amount))
                processed_waveforms.append(w_padded)
            elif current_len > target_wave_len:
                # Truncate
                w_truncated = w[..., :target_wave_len]
                processed_waveforms.append(w_truncated)
            else:
                processed_waveforms.append(w)

        processed_waveforms = torch.stack(processed_waveforms)

        return {
            "waveform": processed_waveforms,
            "target": targets,
            "audio_name": audio_names,
            "index": indices,
        }
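

# Minimal usage sketch / smoke test: exercises the default "pad" pipeline via
# the DataModule, then the "truncate" collate mode directly on variable-length
# items. Shapes assume the defaults above (16 kHz sample rate, 527 classes).
if __name__ == "__main__":
    dm = MockAudioSetDataModule(batch_size=4, max_audio_length_sec=2.0)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    # Mock clips are fixed-length, so "pad" collation is a no-op here:
    # 2.0 s * 16000 Hz = 32000 samples per clip.
    assert batch["waveform"].shape == (4, 1, 32000)
    assert batch["target"].shape == (4, 527)

    # "truncate" crops every waveform to the shortest one in the batch.
    items = [
        {
            "waveform": torch.randn(1, n),
            "target": torch.zeros(527),
            "audio_name": f"clip_{i}",
            "index": i,
        }
        for i, n in enumerate([16000, 24000, 32000])
    ]
    truncated = MockAudioSetDataModule.collate_fn(items, mode="truncate")
    assert truncated["waveform"].shape == (3, 1, 16000)
    print("smoke test passed")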