# OdioCheck-Backend / backend / dataset.py
# Author: JunSiang26
# Commit message: Pure production deploy
# Commit: 9f2b6db
import os
import hashlib
import torch
import torchaudio
import numpy as np
from torch.utils.data import Dataset
import librosa
from scipy.fftpack import dct
def compute_cqcc(wav_np, n_bins, sample_rate=16000, hop_length=160, num_coeffs=20):
    """Compute CQCC features from a mono waveform numpy array.

    Pipeline: constant-Q transform magnitude -> dB scaling relative to the
    peak magnitude -> orthonormal type-II DCT along axis 0 -> keep the first
    ``num_coeffs`` rows.

    Args:
        wav_np: Mono waveform as a numpy array.
        n_bins: Number of constant-Q bins requested from ``librosa.cqt``.
        sample_rate: Sampling rate of ``wav_np`` in Hz.
        hop_length: Hop between successive CQT frames, in samples.
        num_coeffs: Number of leading DCT coefficients to keep.

    Returns:
        A float32 tensor with a leading singleton dimension added in front of
        the truncated DCT matrix. If any step fails, a zero tensor of shape
        ``(1, num_coeffs, 10)`` is returned instead and a ``RuntimeWarning``
        describing the failure is emitted.
    """
    import warnings

    try:
        magnitude = np.abs(
            librosa.cqt(
                wav_np,
                sr=sample_rate,
                n_bins=n_bins,
                hop_length=hop_length,
                fmin=librosa.note_to_hz('C1'),
            )
        )
        log_power = librosa.amplitude_to_db(magnitude, ref=np.max)
        cqcc = dct(log_power, type=2, axis=0, norm='ortho')[:num_coeffs]
        return torch.from_numpy(cqcc).unsqueeze(0).float()
    except Exception as exc:
        # Fallback for very short or invalid audio. The zero placeholder keeps
        # the pipeline running, but the cause is surfaced so that a systematic
        # problem (bad parameters, broken install) is not mistaken for data.
        warnings.warn(
            f"compute_cqcc failed ({type(exc).__name__}: {exc}); "
            "returning an all-zero CQCC placeholder.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.zeros((1, num_coeffs, 10), dtype=torch.float32)
class AudioDataset(Dataset):
    """Real-vs-fake audio dataset yielding ``(waveform, cqcc, label)`` items.

    Files are discovered under ``<data_dir>/original`` (or ``<data_dir>/real``
    when ``original`` does not exist) with label 0, and under
    ``<data_dir>/fake`` with label 1. Only file names ending in ``.wav`` or
    ``.flac`` are kept. Directories and files are visited in sorted order so
    the index -> file mapping is deterministic.

    Args:
        data_dir: Dataset root. When ``None``, ``../MLAAD-tiny`` relative to
            this file is used if it exists, otherwise ``../data``.
        n_bins: Number of CQT bins forwarded to ``compute_cqcc``.
        augment: When true, a random 30% of samples receive one waveform
            augmentation, and over-long clips are cropped at a random offset
            instead of the center.
        cqcc_cache_dir: Directory for cached CQCC tensors (created if
            missing). ``None`` disables caching.
        target_lang: Optional prefix filter. When set, a file is kept only if
            its directory path relative to the split root starts with this
            string; files located directly in the split root (relative path
            ``"."``) are therefore skipped.
    """

    # File-name suffixes accepted during discovery (case-sensitive).
    _AUDIO_EXTENSIONS = ('.wav', '.flac')

    def __init__(self, data_dir=None, n_bins=60, augment=False, cqcc_cache_dir=None, target_lang=None):
        if data_dir is None:
            # Check if MLAAD-tiny exists, else fall back to 'data'.
            here = os.path.dirname(__file__)
            mlaad_dir = os.path.join(here, "..", "MLAAD-tiny")
            if os.path.exists(mlaad_dir):
                data_dir = mlaad_dir
            else:
                data_dir = os.path.join(here, "..", "data")

        self.data_dir = data_dir
        self.files = []
        self.labels = []
        self.n_bins = n_bins
        self.augment = augment
        self.cqcc_cache_dir = cqcc_cache_dir
        self.target_lang = target_lang

        real_path = os.path.join(data_dir, "original")
        if not os.path.exists(real_path):
            real_path = os.path.join(data_dir, "real")
        fake_path = os.path.join(data_dir, "fake")

        self._scan_split(real_path, 0)  # 0 = Real
        self._scan_split(fake_path, 1)  # 1 = Fake

        if self.cqcc_cache_dir is not None:
            os.makedirs(self.cqcc_cache_dir, exist_ok=True)

    def _scan_split(self, split_root, label):
        """Append every accepted audio file under ``split_root`` with ``label``."""
        for root, dirs, files in os.walk(split_root):
            # In-place sorts make os.walk descend in a deterministic order.
            dirs.sort()
            files.sort()
            if self.target_lang:
                # Normalize Windows separators before the prefix comparison.
                rel_root = os.path.relpath(root, split_root).replace('\\', '/')
                if not rel_root.startswith(self.target_lang):
                    # Skip this directory's files; subdirectories are still walked.
                    continue
            for fname in files:
                if fname.endswith(self._AUDIO_EXTENSIONS):
                    self.files.append(os.path.join(root, fname))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of discovered audio files."""
        return len(self.files)

    def _cqcc_cache_path(self, audio_path):
        """Return the cache file path for ``audio_path``.

        The name combines the path relative to ``data_dir`` (extension
        dropped, separators replaced by ``__``) with an MD5 digest of
        ``audio_path`` exactly as given.
        """
        rel_path = os.path.relpath(audio_path, start=self.data_dir)
        cache_key = hashlib.md5(audio_path.encode("utf-8")).hexdigest()
        rel_stem = os.path.splitext(rel_path)[0]
        safe_name = rel_stem.replace(os.sep, "__")
        return os.path.join(self.cqcc_cache_dir, f"{safe_name}_{cache_key}.pt")

    def _load_or_compute_cqcc(self, audio_path, wav_np, is_augmented=False):
        """Return CQCC features for ``wav_np``, using the disk cache if allowed.

        The cache is bypassed when no cache directory is configured or when
        the waveform was augmented (augmented audio no longer matches the
        file on disk).
        """
        if self.cqcc_cache_dir is None or is_augmented:
            return compute_cqcc(wav_np, n_bins=self.n_bins)

        # NOTE(review): the cache key depends only on the audio path, so cached
        # features do not track n_bins or which crop of the clip was used.
        # Entries written by precompute_cqcc_cache cover the whole file, while
        # entries written here cover the cropped/padded waveform.
        cache_path = self._cqcc_cache_path(audio_path)
        if os.path.exists(cache_path):
            return torch.load(cache_path, map_location="cpu")

        cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
        torch.save(cqcc, cache_path)
        return cqcc

    def precompute_cqcc_cache(self, force=False):
        """Materialize CQCC features to disk so training can reuse them.

        Args:
            force: Recompute and overwrite entries that already exist.

        Raises:
            ValueError: If ``cqcc_cache_dir`` was not configured.
        """
        if self.cqcc_cache_dir is None:
            raise ValueError("cqcc_cache_dir must be set to precompute CQCC features.")

        # The progress bar is optional: without tqdm (or its notebook widget
        # dependencies) fall back to plain iteration.
        try:
            from tqdm.notebook import tqdm
            iterable_files = tqdm(self.files, desc="Precomputing CQCC Cache")
        except ImportError:
            iterable_files = self.files

        total = len(self.files)
        for idx, audio_path in enumerate(iterable_files):
            cache_path = self._cqcc_cache_path(audio_path)
            if not force and os.path.exists(cache_path):
                continue
            try:
                wav_np, _ = librosa.load(audio_path, sr=16000, mono=True)
                cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
                torch.save(cqcc, cache_path)
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole run.
                print(f"Error precomputing CQCC for {audio_path}: {e}")
            if (idx + 1) % 100 == 0 or idx + 1 == total:
                print(f"Precomputed CQCC {idx + 1}/{total}")

    @staticmethod
    def _augment_waveform(wav_np, sr):
        """Apply exactly one randomly chosen augmentation to ``wav_np``.

        Returns:
            ``(waveform, applied)`` where ``applied`` is false only when the
            noise branch was chosen for a (near-)silent signal.
        """
        # Apply only ONE augmentation type per sample to avoid over-modification.
        aug_type = np.random.choice(['noise', 'speed', 'pitch'], p=[0.33, 0.33, 0.34])
        if aug_type == 'noise':
            # SNR-based additive Gaussian noise; skipped for near-silent audio
            # because the signal power would give a meaningless noise level.
            signal_power = np.mean(wav_np**2)
            if signal_power <= 1e-10:
                return wav_np, False
            snr_db = np.random.uniform(10, 30)
            snr_linear = 10**(snr_db / 10)
            noise_power = signal_power / snr_linear
            noise = np.random.randn(len(wav_np)) * np.sqrt(noise_power)
            return wav_np + noise, True
        if aug_type == 'speed':
            # Mild speed perturbation.
            speed_factor = np.random.uniform(0.95, 1.05)
            return librosa.effects.time_stretch(wav_np, rate=speed_factor), True
        if aug_type == 'pitch':
            # Subtle pitch shift.
            n_steps = np.random.uniform(-1, 1)
            return librosa.effects.pitch_shift(wav_np, sr=sr, n_steps=n_steps), True
        return wav_np, False

    @staticmethod
    def _fit_length(wav_np, target_len, random_crop):
        """Crop or zero-pad ``wav_np`` to exactly ``target_len`` samples.

        Over-long input is cropped at a random offset when ``random_crop`` is
        true, otherwise around the center. Short input is padded with zeros
        at the end.
        """
        n_samples = len(wav_np)
        if n_samples > target_len:
            if random_crop:
                # randint's upper bound is exclusive; +1 keeps the last valid
                # start offset reachable.
                start = np.random.randint(0, n_samples - target_len + 1)
            else:
                start = (n_samples - target_len) // 2
            return wav_np[start:start + target_len]
        if n_samples < target_len:
            return np.pad(wav_np, (0, target_len - n_samples), 'constant')
        return wav_np

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(wav, cqcc, label)``.

        ``wav`` is a float32 tensor with a leading singleton dimension holding
        the clip resampled to 16 kHz and fitted to 64600 samples; ``cqcc`` is
        the matching (possibly cached) CQCC tensor; ``label`` is 0 for real
        and 1 for fake.
        """
        audio_path = self.files[idx]
        wav_np, sr = librosa.load(audio_path, sr=16000, mono=True)

        # Augmentation on raw audio (data augmentation for generalizability).
        is_augmented = False
        if self.augment and np.random.rand() < 0.3:
            wav_np, is_augmented = self._augment_waveform(wav_np, sr)

        # Crop or pad to exactly 64600 samples (AASIST standard).
        wav_np = self._fit_length(wav_np, 64600, random_crop=self.augment)

        wav = torch.from_numpy(wav_np).unsqueeze(0).float()
        cqcc = self._load_or_compute_cqcc(audio_path, wav_np, is_augmented=is_augmented)
        return wav, cqcc, self.labels[idx]
def collate_variable_length(batch):
    """Collate ``(wav, cqcc, label)`` samples whose last dimension may differ.

    Each waveform and each CQCC tensor is zero-padded on the right of its
    last dimension up to the longest one in the batch, then stacked along a
    new leading batch dimension. Labels are gathered into a single tensor.

    Args:
        batch: Sequence of ``(wav, cqcc, label)`` tuples.

    Returns:
        ``(wavs, cqccs, labels)`` as batched tensors.
    """

    def _pad_and_stack(tensors):
        # Right-pad the trailing dimension of each tensor to the batch maximum.
        longest = max(t.shape[-1] for t in tensors)
        padded = [
            torch.nn.functional.pad(t, (0, longest - t.shape[-1]))
            if t.shape[-1] < longest
            else t
            for t in tensors
        ]
        return torch.stack(padded, dim=0)

    wav_items, cqcc_items, label_items = zip(*batch)
    label_tensor = torch.tensor(label_items)
    wav_batch = _pad_and_stack(wav_items)
    cqcc_batch = _pad_and_stack(cqcc_items)
    return wav_batch, cqcc_batch, label_tensor