JunSiang26's picture
Pure production deploy
9f2b6db
import os
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dataset import AudioDataset, collate_variable_length
from models import (
AASISTDetector,
Wav2Vec2SpoofDetector,
CQCCBaselineDetector,
ImprovedWav2Vec2CQCCDetector,
AblationWav2Vec2GraphDetector,
AblationCQCCGraphDetector,
AblationConcatGraphDetector,
AblationCrossAttnLinearDetector
)
from sklearn.metrics import roc_curve, auc
import numpy as np
import random
from tqdm import tqdm
def train_model(model, train_dataloader, criterion, optimizer, epochs=5, input_type='wav', device=None, val_dataloader=None, eval_interval=1, patience=2, model_save_path=None):
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
loss_history = []
best_val_metric = float('inf') # For min_dcf, lower is better
patience_counter = 0
best_epoch = 0
for epoch in range(epochs):
model.train()
epoch_loss = 0
correct = 0
total = 0
# Wrap the dataloader with tqdm for a progress bar
for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} - Training")):
wavs, cqccs, labels = batch
wavs = wavs.to(device)
cqccs = cqccs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
if input_type == 'wav':
outputs = model(wavs)
elif input_type == 'cqcc':
outputs = model(cqccs)
elif input_type == 'wav_and_cqcc':
outputs = model(wavs, cqccs)
else:
raise ValueError("invalid input_type")
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# Print intermediate progress within the epoch
if batch_idx % 500 == 0 and batch_idx > 0: # Report every 500 batches
current_acc = 100 * correct / total
current_loss = epoch_loss / (batch_idx + 1)
print(f" Batch {batch_idx}/{len(train_dataloader)} | Loss: {current_loss:.4f} | Acc: {current_acc:.2f}%")
acc = 100 * correct / total if total > 0 else 0
avg_loss = epoch_loss / len(train_dataloader)
loss_history.append(avg_loss)
print(f"Epoch {epoch+1}/{epochs} | Training Loss: {avg_loss:.4f} | Training Acc: {acc:.2f}%")
# Validation and Early Stopping
if val_dataloader is not None and (epoch + 1) % eval_interval == 0:
print(f"Epoch {epoch+1}/{epochs} - Evaluating on Validation Set...")
_, _, _, val_eer, val_min_dcf, val_accuracy = evaluate_model(
model, val_dataloader, input_type=input_type, device=device
)
print(f" Validation | EER={val_eer*100:.2f}% | minDCF={val_min_dcf:.4f} | Accuracy={val_accuracy:.2f}")
if val_min_dcf < best_val_metric:
best_val_metric = val_min_dcf
patience_counter = 0
best_epoch = epoch + 1
if model_save_path:
torch.save(model.state_dict(), model_save_path)
print(f" Saved best model to {model_save_path} (minDCF: {best_val_metric:.4f})")
else:
patience_counter += 1
print(f" Validation minDCF did not improve. Patience: {patience_counter}/{patience}")
if patience_counter >= patience:
print(f"Early stopping triggered after {epoch+1} epochs. Best minDCF: {best_val_metric:.4f} at epoch {best_epoch}")
if model_save_path:
print(f"Loading best model from {model_save_path}")
model.load_state_dict(torch.load(model_save_path))
return loss_history # Stop training
# ensure save path logic is intact even when loop ends naturally
if val_dataloader is None and model_save_path is not None:
torch.save(model.state_dict(), model_save_path)
print(f" Saved final model to {model_save_path}")
return loss_history
def evaluate_model(model, dataloader, input_type='wav', device=None):
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
all_labels = []
all_probs = []
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating"):
wavs, cqccs, labels = batch
wavs = wavs.to(device)
cqccs = cqccs.to(device)
labels = labels.to(device)
if input_type == 'wav':
outputs = model(wavs)
elif input_type == 'cqcc':
outputs = model(cqccs)
elif input_type == 'wav_and_cqcc':
outputs = model(wavs, cqccs)
else:
raise ValueError("invalid input_type")
probs = torch.softmax(outputs, dim=1)[:, 1]
all_labels.extend(labels.tolist())
all_probs.extend(probs.tolist())
fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
roc_auc = auc(fpr, tpr)
# ------------------
# EER (Equal Error Rate)
# ------------------
fnr = 1 - tpr
eer_index = np.nanargmin(np.absolute(fnr - fpr))
eer = fpr[eer_index]
# ------------------
# minDCF (Minimum Detection Cost Function)
# Parameters according to ASVspoof 5 Evaluation Plan (Track 1)
# ------------------
P_spoof = 0.05 # Prior probability of a spoofing attack (\pi_{spf})
P_bonafide = 0.95 # Prior probability of a real/bonafide utterance (1 - \pi_{spf})
C_miss = 1 # Cost of falsely rejecting a real voice (Miss)
C_fa = 10 # Cost of falsely accepting a spoof (False Alarm)
# In the dataset, 0 = real (bonafide), 1 = fake (spoof)
# fpr (False Positive Rate) = predicted fake (1) when true is real (0). This is a "miss" in ASVspoof.
# fnr (False Negative Rate) = predicted real (0) when true is fake (1). This is a "false alarm" in ASVspoof.
P_miss = fpr
P_fa = fnr
# Raw DCF = C_miss * P_bonafide * P_miss + C_fa * P_spoof * P_fa
# Normalized by the default DCF (min cost of predicting all bonafide vs all spoof)
dcf_default = min(C_miss * P_bonafide, C_fa * P_spoof)
dcf_array = (C_miss * P_bonafide * P_miss + C_fa * P_spoof * P_fa) / dcf_default
min_dcf = np.min(dcf_array)
# Overall Accuracy (using 0.5 threshold)
preds = [1 if p > 0.5 else 0 for p in all_probs]
correct = sum(1 for p, l in zip(preds, all_labels) if p == l)
accuracy = correct / len(all_labels) if len(all_labels) > 0 else 0
return fpr, tpr, roc_auc, eer, min_dcf, accuracy
def parse_args():
parser = argparse.ArgumentParser(description="Train spoof-detection models with optional CQCC caching.")
parser.add_argument(
"--data-dir",
default=None,
help="Path to dataset root containing original/ and fake/ folders."
)
parser.add_argument(
"--cqcc-cache-dir", # this is where cqcc is stored
default=os.path.join(os.path.dirname(__file__), "precomputed_features", "cqcc"),
help="Directory used to store and reuse precomputed CQCC tensors."
)
parser.add_argument(
"--precompute-cqcc-only",
action="store_true",
help="Only build the CQCC cache and exit without training."
)
parser.add_argument(
"--val-split",
type=float,
default=0.2,
help="Fraction of English training data to reserve for validation."
)
parser.add_argument(
"--force-rebuild-cqcc",
action="store_true",
help="Recompute cached CQCC files even if they already exist."
)
parser.add_argument(
"--smoke-test",
action="store_true",
help="Load one batch, run a forward pass through each model, and exit without training."
)
return parser.parse_args()
def run_smoke_test(dataloader, device):
print("\n--- Running Smoke Test ---")
batch = next(iter(dataloader))
wavs, cqccs, labels = batch
models_to_test = [
("Wav2Vec2 Baseline", Wav2Vec2SpoofDetector(num_classes=2).to(device), "wav"),
("AASIST Baseline", AASISTDetector(num_classes=2).to(device), "wav"),
("CQCC Baseline", CQCCBaselineDetector(num_classes=2).to(device), "cqcc"),
("Custom Fusion Model", ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device), "wav_and_cqcc"),
("Ablation W2V2+Graph", AblationWav2Vec2GraphDetector(num_classes=2).to(device), "wav"),
("Ablation CQCC+Graph", AblationCQCCGraphDetector(num_classes=2).to(device), "cqcc"),
("Ablation Concat+Graph", AblationConcatGraphDetector(num_classes=2).to(device), "wav_and_cqcc"),
("Ablation CrossAttn+Linear", AblationCrossAttnLinearDetector(num_classes=2).to(device), "wav_and_cqcc"),
]
with torch.no_grad():
for name, model, input_type in models_to_test:
model.eval()
if input_type == "wav":
outputs = model(wavs.to(device))
elif input_type == "cqcc":
outputs = model(cqccs.to(device))
elif input_type == "wav_and_cqcc":
outputs = model(wavs.to(device), cqccs.to(device))
else:
raise ValueError("invalid input_type")
print(f"{name}: input OK, output shape = {tuple(outputs.shape)}")
print(f"Labels shape = {tuple(labels.shape)}")
print("Smoke test complete. Cached CQCC loading and model forward passes succeeded.")
def main():
args = parse_args()
print(args)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
g = torch.Generator()
g.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("Loading English Dataset for training/validation...")
full_en_dataset = AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en")
total_en = len(full_en_dataset)
if total_en == 0:
raise ValueError("No English data found for target_lang='en'. Check data_dir and directory layout.")
val_split = min(max(args.val_split, 0.0), 0.5)
train_size = int((1.0 - val_split) * total_en)
val_size = total_en - train_size
indices = torch.randperm(total_en, generator=g).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]
train_dataset = torch.utils.data.Subset(
AudioDataset(data_dir=args.data_dir, augment=True, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en"),
train_indices
)
val_dataset = torch.utils.data.Subset(
AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en"),
val_indices
)
print("Loading German Dataset for Testing...")
test_dataset = AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="de")
if args.precompute_cqcc_only:
print("\n--- Starting CQCC Precomputation ---")
print(f"Dataset: {full_en_dataset.data_dir}")
print("Precomputing CQCC cache for English data...")
full_en_dataset.precompute_cqcc_cache(force=args.force_rebuild_cqcc)
test_dataset.precompute_cqcc_cache(force=args.force_rebuild_cqcc)
print("CQCC preprocessing complete. Exiting.")
return
train_loader = DataLoader(
train_dataset,
batch_size=8,
shuffle=True,
collate_fn=collate_variable_length,
num_workers=2,
pin_memory=True,
generator=g, # ensure reproducible shuffling
)
val_loader = DataLoader(
val_dataset,
batch_size=8,
shuffle=False,
collate_fn=collate_variable_length,
num_workers=2,
pin_memory=True
)
test_loader = DataLoader(
test_dataset,
batch_size=8,
shuffle=False,
collate_fn=collate_variable_length,
num_workers=2,
pin_memory=True
)
if args.smoke_test:
run_smoke_test(train_loader, device)
return
models_dir = os.path.join(os.path.dirname(__file__), "models")
os.makedirs(models_dir, exist_ok=True)
criterion = nn.CrossEntropyLoss()
# ============================================================
# 1 Wav2Vec2 Baseline
# ============================================================
print("\n--- Training Wav2Vec2 Baseline ---")
wav2vec_model = Wav2Vec2SpoofDetector(num_classes=2).to(device)
optimizer_wav2vec = torch.optim.Adam(wav2vec_model.parameters(), lr=1e-4)
wav2vec_loss = train_model(
wav2vec_model,
train_loader,
criterion,
optimizer_wav2vec,
input_type='wav',
device=device,
val_dataloader=val_loader,
model_save_path=os.path.join(models_dir, "wav2vec2.pth")
)
del wav2vec_model, optimizer_wav2vec
torch.cuda.empty_cache()
# ============================================================
# 2 AASIST Baseline
# ============================================================
print("\n--- Training AASIST Baseline ---")
aasist_model = AASISTDetector(num_classes=2).to(device)
optimizer_aasist = torch.optim.Adam(aasist_model.parameters(), lr=5e-4)
aasist_loss = train_model(
aasist_model,
train_loader,
criterion,
optimizer_aasist,
input_type='wav',
device=device,
val_dataloader=val_loader,
model_save_path=os.path.join(models_dir, "aasist.pth")
)
del aasist_model, optimizer_aasist
torch.cuda.empty_cache()
# ============================================================
# 3 CQCC Baseline
# ============================================================
print("\n--- Training CQCC Baseline ---")
cqcc_baseline = CQCCBaselineDetector(num_classes=2).to(device)
optimizer_cqcc = torch.optim.Adam(cqcc_baseline.parameters(), lr=1e-4)
cqcc_loss = train_model(
cqcc_baseline,
train_loader,
criterion,
optimizer_cqcc,
input_type='cqcc',
device=device,
val_dataloader=val_loader,
model_save_path=os.path.join(models_dir, "cqcc_baseline.pth")
)
del cqcc_baseline, optimizer_cqcc
torch.cuda.empty_cache()
# ============================================================
# 4 Custom Fusional Wav2Vec2 + CQCC with Cross-Attention + Graph
# ============================================================
print("\n--- Training Custom Fusion Detector ---")
custom_model = ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device)
optimizer_custom = torch.optim.Adam(custom_model.parameters(), lr=1e-4)
custom_loss = train_model(
custom_model,
train_loader,
criterion,
optimizer_custom,
input_type='wav_and_cqcc',
device=device,
val_dataloader=val_loader,
model_save_path=os.path.join(models_dir, "custom_hybrid.pth")
)
del custom_model, optimizer_custom
torch.cuda.empty_cache()
# ============================================================
# 5 Ablation Models
# ============================================================
print("\n--- Training Ablation 1 (Wav2Vec2 + Graph) ---")
ab1_model = AblationWav2Vec2GraphDetector(num_classes=2).to(device)
optimizer_ab1 = torch.optim.Adam(ab1_model.parameters(), lr=1e-4) # learning rate for wav2vec2-based
ab1_loss = train_model(ab1_model, train_loader, criterion, optimizer_ab1, input_type='wav', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_w2v2_graph.pth"))
del ab1_model, optimizer_ab1
torch.cuda.empty_cache()
print("\n--- Training Ablation 2 (CQCC + Graph) ---")
ab2_model = AblationCQCCGraphDetector(num_classes=2).to(device)
optimizer_ab2 = torch.optim.Adam(ab2_model.parameters(), lr=1e-4) # learning rate for CQCC-based
ab2_loss = train_model(ab2_model, train_loader, criterion, optimizer_ab2, input_type='cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_cqcc_graph.pth"))
del ab2_model, optimizer_ab2
torch.cuda.empty_cache()
print("\n--- Training Ablation 3 (Wav2Vec2 + CQCC + Simple Concat) ---")
ab3_model = AblationConcatGraphDetector(num_classes=2).to(device)
optimizer_ab3 = torch.optim.Adam(ab3_model.parameters(), lr=1e-4)
ab3_loss = train_model(ab3_model, train_loader, criterion, optimizer_ab3, input_type='wav_and_cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_concat_graph.pth"))
del ab3_model, optimizer_ab3
torch.cuda.empty_cache()
print("\n--- Training Ablation 4 (Wav2Vec2 + CQCC + Cross-Attn + Linear) ---")
ab4_model = AblationCrossAttnLinearDetector(num_classes=2).to(device)
optimizer_ab4 = torch.optim.Adam(ab4_model.parameters(), lr=1e-4)
ab4_loss = train_model(ab4_model, train_loader, criterion, optimizer_ab4, input_type='wav_and_cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_crossattn_linear.pth"))
del ab4_model, optimizer_ab4
torch.cuda.empty_cache()
# ============================================================
# Evaluation — reload one at a time
# ============================================================
print("\n--- Evaluating Models ---")
evals = []
models_to_eval = [
("Wav2Vec2 Baseline", Wav2Vec2SpoofDetector, "wav2vec2.pth", 'wav'),
("AASIST Baseline", AASISTDetector, "aasist.pth", 'wav'),
("CQCC Baseline", CQCCBaselineDetector, "cqcc_baseline.pth", 'cqcc'),
("Custom Fusion Model", ImprovedWav2Vec2CQCCDetector, "custom_hybrid.pth", 'wav_and_cqcc'),
("Ablation 1 (W2V2+Graph)", AblationWav2Vec2GraphDetector, "ablation_w2v2_graph.pth", 'wav'),
("Ablation 2 (CQCC+Graph)", AblationCQCCGraphDetector, "ablation_cqcc_graph.pth", 'cqcc'),
("Ablation 3 (Concat+Graph)", AblationConcatGraphDetector, "ablation_concat_graph.pth", 'wav_and_cqcc'),
("Ablation 4 (CrossAttn+Linear)", AblationCrossAttnLinearDetector, "ablation_crossattn_linear.pth", 'wav_and_cqcc'),
]
for name, model_class, filename, inp in models_to_eval:
model_path = os.path.join(models_dir, filename)
if not os.path.exists(model_path):
print(f"Skipping evaluation for {name} (Model weights not found at {model_path})")
continue
model_obj = model_class(num_classes=2).to(device)
model_obj.load_state_dict(torch.load(model_path, map_location=device))
model_obj.eval()
print(f"\n--- Metrics for {name} ---")
# 1. EVAL ON TRAIN SET
train_fpr, train_tpr, train_auc, train_eer, train_min_dcf, train_acc = evaluate_model(
model_obj, train_loader, input_type=inp, device=device
)
print(f"[Train] Acc={train_acc*100:.2f}% | EER={train_eer*100:.2f}% | minDCF={train_min_dcf:.4f}")
# 2. EVAL ON TEST SET
test_fpr, test_tpr, test_auc, test_eer, test_min_dcf, test_acc = evaluate_model(
model_obj, test_loader, input_type=inp, device=device
)
print(f"[Test ] Acc={test_acc*100:.2f}% | EER={test_eer*100:.2f}% | minDCF={test_min_dcf:.4f}")
del model_obj
torch.cuda.empty_cache()
if __name__ == "__main__":
main()