File size: 8,184 Bytes
9f2b6db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
import hashlib
import torch
import torchaudio
import numpy as np
from torch.utils.data import Dataset
import librosa
from scipy.fftpack import dct

def compute_cqcc(wav_np, n_bins, sample_rate=16000, hop_length=160, num_coeffs=20):
    """Compute CQCC features from a mono waveform numpy array."""
    try:
        cqt = np.abs(
            librosa.cqt(
                wav_np,
                sr=sample_rate,
                n_bins=n_bins,
                hop_length=hop_length,
                fmin=librosa.note_to_hz('C1')
            )
        )
        log_power = librosa.amplitude_to_db(cqt, ref=np.max)
        cqcc = dct(log_power, type=2, axis=0, norm='ortho')[:num_coeffs]
        return torch.from_numpy(cqcc).unsqueeze(0).float()
    except Exception:
        # Fallback for very short or invalid audio.
        return torch.zeros((1, num_coeffs, 10), dtype=torch.float32)

class AudioDataset(Dataset):
    def __init__(self, data_dir=None, n_bins=60, augment=False, cqcc_cache_dir=None, target_lang=None):
        if data_dir is None:
            # Check if MLAAD-tiny exists, else fallback to 'data'
            mlaad_dir = os.path.join(os.path.dirname(__file__), "..", "MLAAD-tiny")
            if os.path.exists(mlaad_dir):
                data_dir = mlaad_dir
            else:
                data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
                
        self.data_dir = data_dir
        self.files = []
        self.labels = []
        self.n_bins = n_bins
        self.augment = augment
        self.cqcc_cache_dir = cqcc_cache_dir
        self.target_lang = target_lang

        real_path = os.path.join(data_dir, "original")
        if not os.path.exists(real_path):
            real_path = os.path.join(data_dir, "real")
            
        fake_path = os.path.join(data_dir, "fake")
            
        for root, dirs, files in os.walk(real_path):
            dirs.sort()
            files.sort()
            for f in files:
                if f.endswith('.wav') or f.endswith('.flac'):
                    if self.target_lang:
                        rel_root = os.path.relpath(root, real_path).replace('\\', '/')
                        if not rel_root.startswith(self.target_lang):
                            continue
                    self.files.append(os.path.join(root, f))
                    self.labels.append(0) # 0 = Real

        for root, dirs, files in os.walk(fake_path):
            dirs.sort()
            files.sort()
            for f in files:
                if f.endswith('.wav') or f.endswith('.flac'):
                    if self.target_lang:
                        rel_root = os.path.relpath(root, fake_path).replace('\\', '/')
                        if not rel_root.startswith(self.target_lang):
                            continue
                    self.files.append(os.path.join(root, f))
                    self.labels.append(1) # 1 = Fake

        if self.cqcc_cache_dir is not None:
            os.makedirs(self.cqcc_cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.files)

    def _cqcc_cache_path(self, audio_path):
        rel_path = os.path.relpath(audio_path, start=self.data_dir)
        cache_key = hashlib.md5(audio_path.encode("utf-8")).hexdigest()
        rel_stem = os.path.splitext(rel_path)[0]
        safe_name = rel_stem.replace(os.sep, "__")
        return os.path.join(self.cqcc_cache_dir, f"{safe_name}_{cache_key}.pt")

    def _load_or_compute_cqcc(self, audio_path, wav_np, is_augmented=False):
        if self.cqcc_cache_dir is None or is_augmented:
            return compute_cqcc(wav_np, n_bins=self.n_bins)

        cache_path = self._cqcc_cache_path(audio_path)
        if os.path.exists(cache_path):
            return torch.load(cache_path, map_location="cpu")

        cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
        torch.save(cqcc, cache_path)
        return cqcc

    def precompute_cqcc_cache(self, force=False):
        """Materialize CQCC features to disk so training can reuse them."""
        import tqdm
        if self.cqcc_cache_dir is None:
            raise ValueError("cqcc_cache_dir must be set to precompute CQCC features.")

        try:
            from tqdm.notebook import tqdm
            iterable_files = tqdm(self.files, desc="Precomputing CQCC Cache")
        except ImportError:
            iterable_files = self.files

        total = len(self.files)
        for idx, audio_path in enumerate(iterable_files):
            cache_path = self._cqcc_cache_path(audio_path)
            if not force and os.path.exists(cache_path):
                continue

            try:
                wav_np, _ = librosa.load(audio_path, sr=16000, mono=True)
                cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
                torch.save(cqcc, cache_path)
            except Exception as e:
                print(f"Error precomputing CQCC for {audio_path}: {e}")


            if (idx + 1) % 100 == 0 or idx + 1 == total:
                print(f"Precomputed CQCC {idx + 1}/{total}")

    def __getitem__(self, idx):
        audio_path = self.files[idx]
        wav_np, sr = librosa.load(audio_path, sr=16000, mono=True)

        is_augmented = False
        # Augmentation on raw audio (Data Augmentation for generalizability)
        if self.augment and np.random.rand() < 0.3:
            # Apply only ONE augmentation type per sample to avoid over-modification
            aug_type = np.random.choice(['noise', 'speed', 'pitch'], p=[0.33, 0.33, 0.34])

            if aug_type == 'noise':
                # SNR-based noise addition (reverted to original robust method)
                signal_power = np.mean(wav_np**2)
                if signal_power > 1e-10:
                    snr_db = np.random.uniform(10, 30)
                    snr_linear = 10**(snr_db / 10)
                    noise_power = signal_power / snr_linear
                    noise = np.random.randn(len(wav_np)) * np.sqrt(noise_power)
                    wav_np = wav_np + noise
                is_augmented = True
            elif aug_type == 'speed':
                # Mild speed perturbation
                speed_factor = np.random.uniform(0.95, 1.05)
                wav_np = librosa.effects.time_stretch(wav_np, rate=speed_factor)
                is_augmented = True
            elif aug_type == 'pitch':
                # Subtle pitch shift
                n_steps = np.random.uniform(-1, 1)
                wav_np = librosa.effects.pitch_shift(wav_np, sr=sr, n_steps=n_steps)
                is_augmented = True

        # Crop or pad to exactly 64600 samples (AASIST standard)
        target_len = 64600
        if len(wav_np) > target_len:
            # Center crop or random crop for augment instead of taking just the start.
            if self.augment:
                start = np.random.randint(0, len(wav_np) - target_len)
            else:
                start = (len(wav_np) - target_len) // 2
            wav_np = wav_np[start:start+target_len]
        elif len(wav_np) < target_len:
            pad = target_len - len(wav_np)  
            wav_np = np.pad(wav_np, (0, pad), 'constant')

        wav = torch.from_numpy(wav_np).unsqueeze(0).float()

        cqcc = self._load_or_compute_cqcc(audio_path, wav_np, is_augmented=is_augmented)

        return wav, cqcc, self.labels[idx]


def collate_variable_length(batch):

    wavs, cqccs, labels = zip(*batch)
    labels = torch.tensor(labels)

    # ---------- WAVE ----------
    max_wav_len = max(w.shape[-1] for w in wavs)

    wavs_padded = []
    for w in wavs:
        if w.shape[-1] < max_wav_len:
            pad = max_wav_len - w.shape[-1]
            w = torch.nn.functional.pad(w, (0, pad))
        wavs_padded.append(w)

    wavs = torch.stack(wavs_padded, dim=0)
    
    # ---------- CQCC ----------
    max_cqcc_len = max(c.shape[-1] for c in cqccs)

    cqccs_padded = []
    for c in cqccs:
        if c.shape[-1] < max_cqcc_len:
            pad = max_cqcc_len - c.shape[-1]
            c = torch.nn.functional.pad(c, (0, pad))
        cqccs_padded.append(c)

    cqccs = torch.stack(cqccs_padded, dim=0)

    return wavs, cqccs, labels