# vineetshukla.work@gmail.com
# final commit
# c5c9261
"""
Training Engine — Research-grade training loop for deepfake audio detection.
Features: Focal Loss, Cosine Annealing, Mixed Precision, EMA, full logging.
"""
import os
import sys
import copy
import yaml
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import librosa
import logging
import argparse
import random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, LinearLR, SequentialLR
from augment import AugmentationPipeline, AugConfig
from models.backbone import BackboneLoader
from models.classifier import DeepfakeClassifier
from models.ensemble import EnsembleDetector
# Configure the root logger once at import time: INFO level, with records
# rendered as "timestamp | level | message".
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
# Module-level logger shared by the dataset and training code below.
logger = logging.getLogger(__name__)
# ===============================================================
# Focal Loss (handles class imbalance better than cross-entropy)
# ===============================================================
class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced multi-class classification.

    Each sample's cross-entropy is scaled by ``alpha_t * (1 - p_t) ** gamma``,
    where ``p_t`` is the softmax probability assigned to the true class, so
    confidently-correct samples contribute less to the total.

    Args:
        gamma: Focusing exponent; ``0`` disables the focal modulation.
        alpha: Either a single number applied to every sample (a plain scale
            factor), or a per-class sequence/tensor of weights indexed by the
            target class id. Per-class weights are what actually rebalances
            classes; a scalar only rescales the loss.
        reduction: ``"mean"`` (default), ``"sum"`` or ``"none"``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    def __init__(self, gamma: float = 2.0, alpha=0.25, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(
                f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}"
            )
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the focal loss of ``logits`` against integer class ``targets``.

        The class dimension of ``logits`` is the last one, as in the original
        one-hot formulation.
        """
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        # cross_entropy is -log(p_t), so p_t is recovered from the same
        # (numerically stable) log-softmax pass instead of a second softmax.
        pt = torch.exp(-ce_loss)
        focal_weight = (1 - pt) ** self.gamma

        if isinstance(self.alpha, (int, float)):
            alpha_t = self.alpha
        else:
            # Per-class weights: pick the weight of each sample's true class.
            class_weights = torch.as_tensor(
                self.alpha, dtype=ce_loss.dtype, device=ce_loss.device
            )
            alpha_t = class_weights[targets]

        loss = alpha_t * focal_weight * ce_loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
# ===============================================================
# EMA (Exponential Moving Average) for smoother final weights
# ===============================================================
class EMAModel:
    """Maintains an exponential moving average (EMA) of model parameters.

    Only parameters with ``requires_grad=True`` at construction time are
    tracked. ``shadow`` maps parameter name -> EMA tensor; ``backup`` holds
    the live weights while the EMA weights are swapped into the model.

    Args:
        model: Module whose trainable parameters are averaged.
        decay: EMA decay; each update computes
            ``shadow = decay * shadow + (1 - decay) * param``.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    @torch.no_grad()
    def update(self, model: nn.Module):
        """Fold the model's current parameters into the running average."""
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                # In-place form of decay * shadow + (1 - decay) * param;
                # avoids allocating a new tensor per parameter per step.
                self.shadow[name].mul_(self.decay).add_(
                    param.detach(), alpha=1 - self.decay
                )

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        """Replace model params with EMA params (undo with ``restore``)."""
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                # Keep the first backup if called twice without restore();
                # otherwise the live training weights would be overwritten
                # by a copy of the EMA weights and lost.
                if name not in self.backup:
                    self.backup[name] = param.detach().clone()
                # Copy values instead of rebinding param.data to the shadow
                # tensor: sharing storage would let any in-place write to the
                # parameter silently corrupt the EMA state.
                param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        """Restore the params saved by ``apply_shadow`` and clear the backup."""
        for name, param in model.named_parameters():
            if name in self.backup:
                param.copy_(self.backup[name])
        self.backup = {}
# ===============================================================
# Audio Dataset
# ===============================================================
class AudioDataset(Dataset):
    """Map-style dataset that yields fixed-length, peak-normalised waveforms.

    The metadata CSV must provide a ``file`` column (path handed to
    ``librosa.load``) and a ``label`` column (converted with ``int``).

    Args:
        metadata_csv: Path of the CSV listing the samples.
        cfg: Config dict; reads ``cfg["data"]["sample_rate"]``,
            ``cfg["data"]["max_duration_sec"]`` and
            ``cfg["augmentation"]["enabled"]``.
        augment: Request on-the-fly augmentation (still requires the config
            flag to be enabled).
        max_samples: If truthy, keep only the first ``max_samples`` rows.
    """

    def __init__(self, metadata_csv: str, cfg: dict,
                 augment: bool = False, max_samples: int = None):
        frame = pd.read_csv(metadata_csv)
        self.df = frame.head(max_samples) if max_samples else frame

        data_cfg = cfg["data"]
        self.sr = data_cfg["sample_rate"]
        # Target length in samples; every item is cut or zero-padded to it.
        self.max_len = int(data_cfg["max_duration_sec"] * self.sr)

        wants_augmentation = augment and cfg["augmentation"]["enabled"]
        self.augmenter = (
            AugmentationPipeline(AugConfig.from_dict(cfg), sr=self.sr)
            if wants_augmentation
            else None
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"input_values": float32 tensor, "labels": long tensor}``."""
        record = self.df.iloc[idx]
        filepath = record["file"]
        label = int(record["label"])

        # Decode as mono at the configured rate; an unreadable file becomes
        # one second of silence so a bad sample does not abort the epoch.
        try:
            waveform, _ = librosa.load(filepath, sr=self.sr, mono=True)
        except Exception as e:
            logger.warning(f"Failed to load {filepath}: {e}. Returning silence.")
            waveform = np.zeros(self.sr, dtype=np.float32)

        # The augmenter only exists for datasets built with augmentation on.
        if self.augmenter is not None:
            waveform = self.augmenter(waveform)

        # Force the fixed length: truncate the tail or right-pad with zeros.
        shortfall = self.max_len - len(waveform)
        if shortfall < 0:
            waveform = waveform[:self.max_len]
        elif shortfall > 0:
            waveform = np.pad(waveform, (0, shortfall), mode="constant")

        # Peak-normalise; all-zero audio is left untouched to avoid 0/0.
        peak = np.max(np.abs(waveform))
        if peak > 0:
            waveform = waveform / peak

        return {
            "input_values": torch.tensor(waveform, dtype=torch.float32),
            "labels": torch.tensor(label, dtype=torch.long),
        }
# ===============================================================
# Training Loop
# ===============================================================
def set_seed(seed: int):
    """Seed every RNG used by training and request deterministic kernels.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).
    Deterministic algorithms are enabled in warn-only mode where the installed
    torch supports it, so an operation without a deterministic implementation
    emits a warning instead of raising in the middle of a training run.

    Args:
        seed: Seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if not hasattr(torch, "use_deterministic_algorithms"):
        return
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except TypeError:
        # Older torch releases do not accept ``warn_only``; fall back to the
        # strict mode the original call requested.
        try:
            torch.use_deterministic_algorithms(True)
        except RuntimeError as exc:
            logger.warning(f"Could not enable deterministic algorithms: {exc}")
    except RuntimeError as exc:
        # Best effort: report the failure rather than silently ignoring it.
        logger.warning(f"Could not enable deterministic algorithms: {exc}")
def build_model(cfg: dict, device: str):
    """Construct the detector described by ``cfg["model"]``.

    ``cfg["model"]["mode"]`` is either ``"ensemble"`` (one classifier per
    backbone in the order wav2vec2, hubert, wavlm, combined through
    ``EnsembleDetector.create``) or the key of a single entry in
    ``cfg["model"]["backbones"]``.

    Args:
        cfg: Full config dict.
        device: Device string the backbone(s) and classifier(s) are placed on.

    Returns:
        ``(model, feature_extractor)``. In ensemble mode the feature extractor
        is the one returned for the last backbone loaded.
    """
    model_cfg = cfg["model"]
    mode = model_cfg["mode"]

    def _make_classifier(kind):
        # Load one backbone and wrap it in the two-class classification head.
        spec = model_cfg["backbones"][kind]
        backbone, extractor, hidden_size = BackboneLoader.load(
            kind, spec["name"], spec["freeze_layers"], device
        )
        head_cfg = model_cfg["classifier"]
        classifier = DeepfakeClassifier(
            backbone, hidden_size,
            num_labels=2,
            classifier_hidden=head_cfg["hidden_dim"],
            dropout=head_cfg["dropout"],
            pooling_type=head_cfg["pooling"],
        ).to(device)
        return classifier, extractor

    # Single backbone: the mode string doubles as the backbone key.
    if mode != "ensemble":
        return _make_classifier(mode)

    # Ensemble: build all three backbones, then combine them.
    members = []
    extractor = None
    for kind in ["wav2vec2", "hubert", "wavlm"]:
        classifier, extractor = _make_classifier(kind)
        members.append(classifier)

    ensemble_cfg = model_cfg["ensemble"]
    ensemble = EnsembleDetector.create(
        members,
        strategy=ensemble_cfg["strategy"],
        weights=ensemble_cfg.get("weights"),
    ).to(device)
    return ensemble, extractor  # Use last feature extractor
def _compute_eer(labels, scores):
    """Return the equal error rate of ``scores`` against binary ``labels``.

    The EER is taken at the ROC point where the false-positive and
    false-negative rates are closest, as the mean of the two. When fewer than
    two distinct labels are present (including empty input) the ROC curve is
    undefined and ``float("nan")`` is returned instead of raising.
    """
    if len(set(labels)) < 2:
        return float("nan")
    # Imported lazily so scikit-learn is only needed when an EER is computed.
    from sklearn.metrics import roc_curve
    fpr, tpr, _ = roc_curve(labels, scores)
    fnr = 1 - tpr
    eer_idx = np.nanargmin(np.abs(fpr - fnr))
    return float((fpr[eer_idx] + fnr[eer_idx]) / 2)


def evaluate(model, dataloader, criterion, device):
    """Run evaluation and return loss, accuracy and EER.

    Puts ``model`` in eval mode (it is not switched back) and iterates
    ``dataloader`` without gradient tracking.

    Args:
        model: Callable mapping a batch of inputs to class logits.
        dataloader: Iterable of dicts with ``"input_values"`` and ``"labels"``.
        criterion: Loss callable ``criterion(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        Dict with the sample-weighted mean ``loss``, ``accuracy``, ``eer``
        (NaN when the labels seen contain fewer than two classes), and
        per-sample ``predictions``, ``labels`` and ``probabilities`` (softmax
        probability of class index 1) as lists of plain Python numbers.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch["input_values"].to(device)
            labels = batch["labels"].to(device)
            logits = model(inputs)
            loss = criterion(logits, labels)
            # Weight by batch size so a smaller final batch is not over-counted.
            total_loss += loss.item() * labels.size(0)
            probs = F.softmax(logits, dim=-1)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            all_probs.extend(probs[:, 1].cpu().tolist())  # P(AI_GENERATED)
    avg_loss = total_loss / max(total, 1)
    accuracy = correct / max(total, 1)
    return {
        "loss": avg_loss,
        "accuracy": accuracy,
        "eer": _compute_eer(all_labels, all_probs),
        "predictions": all_preds,
        "labels": all_labels,
        "probabilities": all_probs,
    }
def train(cfg: dict):
    """Main training function.

    Trains the model described by ``cfg`` on ``<output_dir>/metadata/train.csv``,
    validates on ``val.csv`` after every epoch, saves the best weights to
    ``<output_dir>/best_model/model.pt`` and writes the per-epoch history to
    ``<output_dir>/training_history.json``.

    Reads from ``cfg["training"]``: ``seed``, ``batch_size``, ``num_epochs``,
    ``loss`` (with ``focal_gamma``/``focal_alpha``), ``learning_rate``,
    ``weight_decay``, ``warmup_ratio``, ``lr_scheduler``, ``fp16``, ``ema``,
    ``gradient_accumulation_steps``, ``metric_for_best_model``,
    ``early_stopping_patience`` and the optional ``max_steps`` (cap on the
    number of optimizer updates, intended for smoke tests).

    One scheduler step, one EMA update and one ``max_steps`` tick happen per
    optimizer update, i.e. once every ``gradient_accumulation_steps``
    micro-batches (plus once for a trailing partial window at epoch end).

    Args:
        cfg: Parsed configuration dict.

    Returns:
        None. Returns early, after logging an error, when the metadata CSVs
        are missing or the training loader yields no batches.
    """
    train_cfg = cfg["training"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"🖥️ Device: {device}")
    set_seed(train_cfg["seed"])

    # Paths
    metadata_dir = os.path.join(cfg["paths"]["output_dir"], "metadata")
    train_csv = os.path.join(metadata_dir, "train.csv")
    val_csv = os.path.join(metadata_dir, "val.csv")
    if not os.path.exists(train_csv):
        logger.error("❌ Training metadata not found. Run prepare_data.py first!")
        return
    if not os.path.exists(val_csv):
        logger.error("❌ Validation metadata not found. Run prepare_data.py first!")
        return

    # Datasets
    logger.info("📂 Loading datasets...")
    train_dataset = AudioDataset(train_csv, cfg, augment=True)
    val_dataset = AudioDataset(val_csv, cfg, augment=False)
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just produces a warning.
    pin_memory = device == "cuda"
    train_loader = DataLoader(
        train_dataset, batch_size=train_cfg["batch_size"],
        shuffle=True, num_workers=0, pin_memory=pin_memory, drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=train_cfg["batch_size"],
        shuffle=False, num_workers=0, pin_memory=pin_memory,
    )
    logger.info(f" Train samples: {len(train_dataset)}")
    logger.info(f" Val samples: {len(val_dataset)}")

    num_batches = len(train_loader)
    if num_batches == 0:
        # drop_last=True leaves nothing when the dataset is below one batch.
        logger.error("❌ Training set is smaller than one batch; nothing to train on.")
        return

    # Model (the feature extractor returned alongside it is not used here)
    logger.info("🏗️ Building model...")
    model, _ = build_model(cfg, device)

    # Count parameters
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    logger.info(f" Parameters: {trainable:,} trainable / {total:,} total")

    # Loss
    if train_cfg["loss"] == "focal":
        criterion = FocalLoss(
            gamma=train_cfg["focal_gamma"],
            alpha=train_cfg["focal_alpha"],
        )
        logger.info(" Loss: Focal Loss")
    else:
        criterion = nn.CrossEntropyLoss()
        logger.info(" Loss: Cross Entropy")

    # Optimizer
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=train_cfg["learning_rate"],
        weight_decay=train_cfg["weight_decay"],
    )

    # Scheduler. It is stepped once per optimizer update, so every length is
    # expressed in updates (micro-batches / grad_accum), not micro-batches.
    grad_accum = max(1, int(train_cfg["gradient_accumulation_steps"]))
    steps_per_epoch = (num_batches + grad_accum - 1) // grad_accum
    total_steps = steps_per_epoch * train_cfg["num_epochs"]
    max_steps = train_cfg.get("max_steps")
    if max_steps:
        total_steps = min(total_steps, max_steps)
    warmup_steps = int(total_steps * train_cfg["warmup_ratio"])

    warmup_scheduler = None
    if warmup_steps > 0:
        warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
    if train_cfg["lr_scheduler"] == "cosine_with_restarts":
        main_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=steps_per_epoch, T_mult=2)
    else:
        main_scheduler = CosineAnnealingLR(optimizer, T_max=max(1, total_steps - warmup_steps))
    if warmup_scheduler is None:
        # No warmup requested: skip the zero-length LinearLR phase entirely.
        scheduler = main_scheduler
    else:
        scheduler = SequentialLR(
            optimizer, [warmup_scheduler, main_scheduler], milestones=[warmup_steps]
        )

    # Mixed precision
    use_fp16 = train_cfg["fp16"] and device == "cuda"
    scaler = GradScaler() if use_fp16 else None
    logger.info(f" Mixed precision: {'ON' if use_fp16 else 'OFF'}")

    # EMA
    ema = None
    if train_cfg["ema"]["enabled"]:
        ema = EMAModel(model, decay=train_cfg["ema"]["decay"])
        logger.info(f" EMA: ON (decay={train_cfg['ema']['decay']})")

    # Training state
    output_dir = cfg["paths"]["output_dir"]
    os.makedirs(output_dir, exist_ok=True)
    metric_key = train_cfg["metric_for_best_model"]
    # Error-style metrics improve downwards; everything else (accuracy, or an
    # unknown key that falls back to accuracy) improves upwards.
    lower_is_better = metric_key in ("eer", "loss")
    best_metric = float("inf") if lower_is_better else float("-inf")
    patience_counter = 0
    history = []
    global_step = 0
    reached_max_steps = False

    # ============ Training Loop ============
    logger.info("=" * 60)
    logger.info(" 🔥 TRAINING STARTED")
    logger.info("=" * 60)
    for epoch in range(train_cfg["num_epochs"]):
        model.train()
        epoch_loss = 0.0
        epoch_correct = 0
        epoch_total = 0
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            inputs = batch["input_values"].to(device)
            labels = batch["labels"].to(device)

            # Forward/backward. The loss is divided by grad_accum so the
            # accumulated gradient is an average over the window.
            if use_fp16:
                with autocast():
                    logits = model(inputs)
                    loss = criterion(logits, labels) / grad_accum
                scaler.scale(loss).backward()
            else:
                logits = model(inputs)
                loss = criterion(logits, labels) / grad_accum
                loss.backward()

            # Update at the end of each accumulation window, and also on the
            # last batch so a trailing partial window is not thrown away by
            # the next epoch's zero_grad().
            is_update_step = (step + 1) % grad_accum == 0 or (step + 1) == num_batches
            if is_update_step:
                if use_fp16:
                    # Unscale first so clipping sees true gradient magnitudes.
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                if use_fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                # EMA follows the weights, so it only moves when they do.
                if ema is not None:
                    ema.update(model)

            # Track metrics (undo the grad_accum scaling for reporting)
            epoch_loss += loss.item() * grad_accum * labels.size(0)
            preds = logits.argmax(dim=-1)
            epoch_correct += (preds == labels).sum().item()
            epoch_total += labels.size(0)

            if max_steps and global_step >= max_steps:
                reached_max_steps = True
                break

        # Epoch summary
        train_loss = epoch_loss / max(epoch_total, 1)
        train_acc = epoch_correct / max(epoch_total, 1)

        # Validation (on the EMA weights when EMA is enabled)
        if ema is not None:
            ema.apply_shadow(model)
        val_metrics = evaluate(model, val_loader, criterion, device)
        if ema is not None:
            ema.restore(model)

        current_lr = optimizer.param_groups[0]["lr"]
        logger.info(
            f"Epoch {epoch+1}/{train_cfg['num_epochs']} | "
            f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
            f"Val Loss: {val_metrics['loss']:.4f} Acc: {val_metrics['accuracy']:.4f} "
            f"EER: {val_metrics['eer']:.4f} | "
            f"LR: {current_lr:.2e}"
        )
        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_metrics["loss"],
            "val_acc": val_metrics["accuracy"],
            # Plain float keeps the history JSON-serialisable.
            "val_eer": float(val_metrics["eer"]),
            "lr": current_lr,
        })

        # Best model check
        current_val = val_metrics.get(metric_key, val_metrics["accuracy"])
        if lower_is_better:
            is_better = current_val < best_metric
        else:
            is_better = current_val > best_metric
        if is_better:
            best_metric = current_val
            patience_counter = 0
            # Save best model (EMA weights when EMA is enabled)
            best_path = os.path.join(output_dir, "best_model")
            os.makedirs(best_path, exist_ok=True)
            if ema is not None:
                ema.apply_shadow(model)
            torch.save(model.state_dict(), os.path.join(best_path, "model.pt"))
            if ema is not None:
                ema.restore(model)
            logger.info(f" ✅ New best! {metric_key}={current_val:.4f} → saved to {best_path}")
        else:
            patience_counter += 1
            if patience_counter >= train_cfg["early_stopping_patience"]:
                logger.info(f" ⏹️ Early stopping after {patience_counter} epochs without improvement")
                break

        if reached_max_steps:
            break

    # Save training history
    with open(os.path.join(output_dir, "training_history.json"), "w") as f:
        json.dump(history, f, indent=2)

    logger.info("=" * 60)
    logger.info(" 🎉 TRAINING COMPLETE!")
    logger.info(f" Best {train_cfg['metric_for_best_model']}: {best_metric:.4f}")
    logger.info(f" Model saved to: {os.path.join(output_dir, 'best_model')}")
    logger.info("=" * 60)
# ===============================================================
# Main Entry Point
# ===============================================================
def main():
    """Parse CLI arguments, load the YAML config and start training.

    ``--config`` is the path of the YAML config (default ``config.yaml``).
    ``--max_steps`` turns on debug mode: the run is limited to one epoch and
    the requested step limit is stored under ``cfg["training"]["max_steps"]``
    so it travels with the config handed to ``train``.
    """
    parser = argparse.ArgumentParser(description="World-Class Deepfake Audio Detection Trainer")
    parser.add_argument("--config", type=str, default="config.yaml")
    parser.add_argument("--max_steps", type=int, default=None,
                        help="Limit training for testing (overrides epochs)")
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # An empty or malformed file makes safe_load return None or a non-mapping;
    # fail with a clear message instead of a TypeError/KeyError further down.
    if not isinstance(cfg, dict) or not isinstance(cfg.get("training"), dict):
        parser.error(f"{args.config} does not define a 'training' section")

    if args.max_steps is not None and args.max_steps > 0:
        cfg["training"]["num_epochs"] = 1
        # Previously the value was only logged; record it in the config so
        # the limit is actually available to the training code.
        cfg["training"]["max_steps"] = args.max_steps
        logger.info(f"⚠️ Debug mode: limited to {args.max_steps} steps")

    train(cfg)


if __name__ == "__main__":
    main()