Spaces:
Running
Running
| import os | |
| import hashlib | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| import librosa | |
| from scipy.fftpack import dct | |
def compute_cqcc(wav_np, n_bins, sample_rate=16000, hop_length=160, num_coeffs=20):
    """Derive CQCC features from a mono waveform given as a numpy array.

    The magnitude of the constant-Q transform is converted to decibels
    relative to its own maximum, a type-II orthonormal DCT is applied along
    axis 0, and the first ``num_coeffs`` rows of the result are kept.

    Args:
        wav_np: Mono waveform samples.
        n_bins: Number of constant-Q bins requested from ``librosa.cqt``.
        sample_rate: Sampling rate forwarded to ``librosa.cqt``.
        hop_length: Hop length forwarded to ``librosa.cqt``.
        num_coeffs: Number of leading DCT coefficients to keep.

    Returns:
        A float32 tensor with a leading singleton dimension added in front
        of the coefficient matrix. If any step raises, an all-zero tensor of
        shape ``(1, num_coeffs, 10)`` is returned instead.
    """
    # Width of the placeholder returned when the transform cannot be computed.
    fallback_frames = 10
    try:
        spectrum = librosa.cqt(
            wav_np,
            sr=sample_rate,
            n_bins=n_bins,
            hop_length=hop_length,
            fmin=librosa.note_to_hz('C1'),
        )
        magnitude = np.abs(spectrum)
        decibels = librosa.amplitude_to_db(magnitude, ref=np.max)
        coefficients = dct(decibels, type=2, axis=0, norm='ortho')
        features = torch.from_numpy(coefficients[:num_coeffs])
        return features.unsqueeze(0).float()
    except Exception:
        # Deliberate best-effort path for very short or invalid audio.
        return torch.zeros((1, num_coeffs, fallback_frames), dtype=torch.float32)
class AudioDataset(Dataset):
    """Dataset of real/fake audio clips yielding ``(wav, cqcc, label)``.

    Files are discovered under ``<data_dir>/original`` (or ``<data_dir>/real``
    when ``original`` does not exist) with label 0, and under
    ``<data_dir>/fake`` with label 1. Each item is loaded at 16 kHz mono,
    optionally augmented, cropped or zero-padded to 64600 samples, and paired
    with its CQCC features.

    Args:
        data_dir: Dataset root. When ``None``, ``../MLAAD-tiny`` relative to
            this file is used if it exists, otherwise ``../data``.
        n_bins: Number of constant-Q bins passed to ``compute_cqcc``.
        augment: Enables random waveform augmentation and random cropping.
        cqcc_cache_dir: Directory for cached CQCC tensors; ``None`` disables
            caching. The directory is created if missing.
        target_lang: When truthy, only files whose directory path (relative
            to the real/fake root, with ``/`` separators) starts with this
            string are kept.
    """

    # File-name suffixes (case-sensitive) that count as audio.
    _AUDIO_EXTENSIONS = ('.wav', '.flac')
    # Sampling rate every clip is loaded at.
    _SAMPLE_RATE = 16000
    # Fixed waveform length returned by __getitem__ (AASIST standard).
    _TARGET_LEN = 64600

    def __init__(self, data_dir=None, n_bins=60, augment=False, cqcc_cache_dir=None, target_lang=None):
        if data_dir is None:
            # Check if MLAAD-tiny exists, else fall back to 'data'.
            here = os.path.dirname(__file__)
            mlaad_dir = os.path.join(here, "..", "MLAAD-tiny")
            if os.path.exists(mlaad_dir):
                data_dir = mlaad_dir
            else:
                data_dir = os.path.join(here, "..", "data")

        self.data_dir = data_dir
        self.files = []
        self.labels = []
        self.n_bins = n_bins
        self.augment = augment
        self.cqcc_cache_dir = cqcc_cache_dir
        self.target_lang = target_lang

        real_path = os.path.join(data_dir, "original")
        if not os.path.exists(real_path):
            real_path = os.path.join(data_dir, "real")
        fake_path = os.path.join(data_dir, "fake")

        self._scan(real_path, label=0)  # 0 = Real
        self._scan(fake_path, label=1)  # 1 = Fake

        if self.cqcc_cache_dir is not None:
            os.makedirs(self.cqcc_cache_dir, exist_ok=True)

    def _scan(self, base_path, label):
        """Append every matching audio file under ``base_path`` with ``label``."""
        for root, dirs, files in os.walk(base_path):
            # In-place sort keeps the traversal order deterministic.
            dirs.sort()
            if self.target_lang:
                rel_root = os.path.relpath(root, base_path).replace('\\', '/')
                if not rel_root.startswith(self.target_lang):
                    # Still descend: a deeper directory may match the prefix.
                    continue
            for name in sorted(files):
                if name.endswith(self._AUDIO_EXTENSIONS):
                    self.files.append(os.path.join(root, name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of discovered audio files."""
        return len(self.files)

    def _cqcc_cache_path(self, audio_path):
        """Return the cache file path used for ``audio_path``.

        The name combines the path relative to ``data_dir`` (separators
        replaced by ``__``, extension dropped) with an MD5 of ``audio_path``.
        """
        # NOTE(review): n_bins is not part of the key, so datasets created
        # with different n_bins but the same cache dir share entries.
        rel_path = os.path.relpath(audio_path, start=self.data_dir)
        cache_key = hashlib.md5(audio_path.encode("utf-8")).hexdigest()
        rel_stem = os.path.splitext(rel_path)[0]
        safe_name = rel_stem.replace(os.sep, "__")
        return os.path.join(self.cqcc_cache_dir, f"{safe_name}_{cache_key}.pt")

    def _load_or_compute_cqcc(self, audio_path, wav_np, is_augmented=False):
        """Return CQCC features for ``wav_np``, using the disk cache if allowed.

        The cache is bypassed when no cache directory is configured or when
        the waveform was augmented. Otherwise a cached tensor is loaded if
        present; if not, the features are computed and saved.
        """
        if self.cqcc_cache_dir is None or is_augmented:
            return compute_cqcc(wav_np, n_bins=self.n_bins)

        cache_path = self._cqcc_cache_path(audio_path)
        if os.path.exists(cache_path):
            # NOTE(review): torch.load unpickles the file; only point
            # cqcc_cache_dir at directories you trust.
            return torch.load(cache_path, map_location="cpu")

        cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
        torch.save(cqcc, cache_path)
        return cqcc

    def precompute_cqcc_cache(self, force=False):
        """Materialize CQCC features to disk so training can reuse them.

        Args:
            force: Recompute and overwrite entries that already exist.

        Raises:
            ValueError: If ``cqcc_cache_dir`` is not set.
        """
        if self.cqcc_cache_dir is None:
            raise ValueError("cqcc_cache_dir must be set to precompute CQCC features.")

        # The progress bar is optional: any ImportError (tqdm or its notebook
        # dependencies missing) falls back to plain iteration.
        try:
            from tqdm.notebook import tqdm
            iterable_files = tqdm(self.files, desc="Precomputing CQCC Cache")
        except ImportError:
            iterable_files = self.files

        total = len(self.files)
        for idx, audio_path in enumerate(iterable_files):
            cache_path = self._cqcc_cache_path(audio_path)
            if not force and os.path.exists(cache_path):
                continue
            try:
                # NOTE(review): features here come from the full-length clip,
                # whereas __getitem__ computes them from the cropped/padded
                # clip — confirm this difference is intended.
                wav_np, _ = librosa.load(audio_path, sr=self._SAMPLE_RATE, mono=True)
                cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
                torch.save(cqcc, cache_path)
            except Exception as e:
                print(f"Error precomputing CQCC for {audio_path}: {e}")
            if (idx + 1) % 100 == 0 or idx + 1 == total:
                print(f"Precomputed CQCC {idx + 1}/{total}")

    def _maybe_augment(self, wav_np, sr):
        """Randomly apply at most one augmentation; return ``(wav, changed)``."""
        if not (self.augment and np.random.rand() < 0.3):
            return wav_np, False

        # Apply only ONE augmentation type per sample to avoid over-modification.
        aug_type = np.random.choice(['noise', 'speed', 'pitch'], p=[0.33, 0.33, 0.34])
        if aug_type == 'noise':
            # SNR-based additive noise; skipped for (near-)silent clips.
            signal_power = np.mean(wav_np**2)
            if signal_power > 1e-10:
                snr_db = np.random.uniform(10, 30)
                snr_linear = 10**(snr_db / 10)
                noise_power = signal_power / snr_linear
                noise = np.random.randn(len(wav_np)) * np.sqrt(noise_power)
                return wav_np + noise, True
            return wav_np, False
        if aug_type == 'speed':
            # Mild speed perturbation.
            speed_factor = np.random.uniform(0.95, 1.05)
            return librosa.effects.time_stretch(wav_np, rate=speed_factor), True
        # Remaining choice is 'pitch': subtle pitch shift.
        n_steps = np.random.uniform(-1, 1)
        return librosa.effects.pitch_shift(wav_np, sr=sr, n_steps=n_steps), True

    def _fit_length(self, wav_np):
        """Crop or zero-pad ``wav_np`` to exactly ``_TARGET_LEN`` samples."""
        target_len = self._TARGET_LEN
        n_samples = len(wav_np)
        if n_samples > target_len:
            if self.augment:
                # randint's upper bound is exclusive; +1 makes the last valid
                # start offset reachable.
                start = np.random.randint(0, n_samples - target_len + 1)
            else:
                # Center crop instead of taking just the start.
                start = (n_samples - target_len) // 2
            return wav_np[start:start + target_len]
        if n_samples < target_len:
            return np.pad(wav_np, (0, target_len - n_samples), 'constant')
        return wav_np

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(wav, cqcc, label)``.

        ``wav`` is a float32 tensor of shape ``(1, 64600)``; ``cqcc`` comes
        from the cache or ``compute_cqcc``; ``label`` is 0 (real) or 1 (fake).
        """
        audio_path = self.files[idx]
        wav_np, sr = librosa.load(audio_path, sr=self._SAMPLE_RATE, mono=True)

        # Augmentation on raw audio (data augmentation for generalizability).
        wav_np, is_augmented = self._maybe_augment(wav_np, sr)
        wav_np = self._fit_length(wav_np)

        wav = torch.from_numpy(wav_np).unsqueeze(0).float()
        # NOTE(review): a random crop alone does not set is_augmented, so with
        # augment=True a cached CQCC may stem from a different crop than wav.
        cqcc = self._load_or_compute_cqcc(audio_path, wav_np, is_augmented=is_augmented)
        return wav, cqcc, self.labels[idx]
def collate_variable_length(batch):
    """Collate ``(wav, cqcc, label)`` samples whose last dimension may differ.

    Waveforms and CQCC tensors are each right-padded with zeros along their
    last dimension to the longest item in the batch and then stacked on a new
    leading batch dimension. Labels are gathered into a single tensor.

    Returns:
        Tuple ``(wavs, cqccs, labels)`` of batched tensors.
    """
    def _pad_and_stack(tensors):
        # Right-pad the trailing dimension up to the batch maximum, then stack.
        longest = max(t.shape[-1] for t in tensors)
        padded = [
            torch.nn.functional.pad(t, (0, longest - t.shape[-1]))
            if t.shape[-1] < longest
            else t
            for t in tensors
        ]
        return torch.stack(padded, dim=0)

    wav_items, cqcc_items, label_items = zip(*batch)
    label_tensor = torch.tensor(label_items)
    wav_batch = _pad_and_stack(wav_items)
    cqcc_batch = _pad_and_stack(cqcc_items)
    return wav_batch, cqcc_batch, label_tensor