""" VERIDEX — EfficientNet-B4 Deepfake Training Script ==================================================== Dataset: FaceForensics++ + DFDC + Custom (80k images) Expected folder structure (TWO options supported): Option A — class folders: data/ real/ ← real face images fake/ ← deepfake/AI-generated images Option B — train/val split folders: data/ train/ real/ fake/ val/ real/ fake/ Usage: python train_efficientnet.py --data_dir ./data --epochs 20 --batch_size 32 After training, weights saved to: weights/efficientnet_deepfake.pth weights/efficientnet_b4_meta.json """ import os, json, time, argparse, warnings import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, WeightedRandomSampler from torchvision import datasets, transforms from torch.cuda.amp import GradScaler, autocast import numpy as np warnings.filterwarnings("ignore") # ───────────────────────────────────────────────────────────── # Config # ───────────────────────────────────────────────────────────── IMG_SIZE = 380 NUM_CLASSES = 2 FAKE_LABEL = 0 # index 0 = fake, index 1 = real (alphabetical: fake < real) def get_transforms(is_train: bool): """ Deepfake-specific augmentation: - Compression artifacts (JPEG quality) — FaceForensics++ uses compressed videos - Gaussian noise — simulates video encoding - Horizontal flip — faces are symmetric - Color jitter — lighting variation - Random erasing — occlusion robustness """ if is_train: return transforms.Compose([ transforms.Resize((IMG_SIZE + 20, IMG_SIZE + 20)), transforms.RandomCrop(IMG_SIZE), transforms.RandomHorizontalFlip(p=0.5), transforms.ColorJitter( brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05 ), transforms.RandomGrayscale(p=0.05), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)), ]) else: return transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) def build_model(device: str) -> nn.Module: import timm model = timm.create_model("efficientnet_b4", pretrained=True) # Replace classifier head for binary deepfake detection model.classifier = nn.Sequential( nn.Dropout(0.4), nn.Linear(model.num_features, 512), nn.GELU(), nn.BatchNorm1d(512), nn.Dropout(0.3), nn.Linear(512, NUM_CLASSES), ) return model.to(device) def make_weighted_sampler(dataset) -> WeightedRandomSampler: """Balance fake/real classes during training.""" counts = np.bincount([label for _, label in dataset.samples]) weights_per_class = 1.0 / (counts + 1e-6) sample_weights = [weights_per_class[label] for _, label in dataset.samples] return WeightedRandomSampler( weights=sample_weights, num_samples=len(sample_weights), replacement=True ) def load_datasets(data_dir: str, val_split: float = 0.15): """ Auto-detect Option A (flat) or Option B (train/val split). """ train_dir = os.path.join(data_dir, "train") val_dir = os.path.join(data_dir, "val") if os.path.isdir(train_dir) and os.path.isdir(val_dir): # Option B: pre-split print(f"[TRAIN] Using pre-split: {train_dir} / {val_dir}") train_ds = datasets.ImageFolder(train_dir, transform=get_transforms(True)) val_ds = datasets.ImageFolder(val_dir, transform=get_transforms(False)) else: # Option A: flat folders — split automatically print(f"[TRAIN] Auto-splitting from: {data_dir}") full_ds = datasets.ImageFolder(data_dir, transform=get_transforms(True)) n_val = int(len(full_ds) * val_split) n_train = len(full_ds) - n_val from torch.utils.data import random_split, Subset indices = torch.randperm(len(full_ds)).tolist() train_idx, val_idx = indices[n_val:], indices[:n_val] train_ds = Subset(full_ds, train_idx) val_ds = Subset(full_ds, val_idx) val_ds.dataset = datasets.ImageFolder( data_dir, transform=get_transforms(False) ) # Fix: val uses val transforms class _ValSubset(torch.utils.data.Dataset): def __init__(self, base_dir, indices): self.ds = datasets.ImageFolder(base_dir, transform=get_transforms(False)) self.indices = indices self.classes = self.ds.classes self.class_to_idx = self.ds.class_to_idx self.samples = [self.ds.samples[i] for i in indices] def __len__(self): return len(self.indices) def __getitem__(self, i): return self.ds[self.indices[i]] train_ds = Subset(full_ds, train_idx) val_ds = _ValSubset(data_dir, val_idx) return train_ds, val_ds def train(args): device = "cuda" if torch.cuda.is_available() else "cpu" print(f"[TRAIN] Device: {device.upper()}") print(f"[TRAIN] Data: {args.data_dir}") print(f"[TRAIN] Epochs: {args.epochs} | Batch: {args.batch_size}") os.makedirs("weights", exist_ok=True) # ── Datasets ─────────────────────────────────────────────── train_ds, val_ds = load_datasets(args.data_dir, args.val_split) # Weighted sampler to handle class imbalance try: sampler = make_weighted_sampler(train_ds) train_loader = DataLoader( train_ds, batch_size=args.batch_size, sampler=sampler, num_workers=args.workers, pin_memory=(device == "cuda") ) except Exception: train_loader = DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=(device == "cuda") ) val_loader = DataLoader( val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=(device == "cuda") ) n_train = len(train_ds) n_val = len(val_ds) print(f"[TRAIN] Train: {n_train} | Val: {n_val}") # ── Model ────────────────────────────────────────────────── model = build_model(device) scaler = GradScaler(enabled=(device == "cuda")) # ── Loss: label smoothing helps generalization on deepfakes ─ criterion = nn.CrossEntropyLoss(label_smoothing=0.1) # ── Optimizer: 3-phase LR ────────────────────────────────── # Phase 1 (ep 1-5): warm-up, train only classifier head # Phase 2 (ep 6-15): fine-tune full network, lower LR # Phase 3 (ep 16+): cosine decay to near-zero optimizer = optim.AdamW( model.classifier.parameters(), lr=1e-3, weight_decay=1e-4 ) scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=1e-3, steps_per_epoch=len(train_loader), epochs=args.epochs, pct_start=0.1, anneal_strategy="cos", ) best_val_acc = 0.0 best_val_auc = 0.0 patience_counter = 0 for epoch in range(1, args.epochs + 1): # ── Unfreeze backbone after epoch 3 ─────────────────── if epoch == 4: print("[TRAIN] 🔓 Unfreezing backbone for full fine-tuning") optimizer = optim.AdamW( model.parameters(), lr=5e-5, weight_decay=1e-4 ) scheduler = optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs - 3, eta_min=1e-7 ) # ── Training loop ───────────────────────────────────── model.train() train_loss = 0.0 train_correct = 0 t0 = time.time() for batch_idx, (imgs, labels) in enumerate(train_loader): imgs, labels = imgs.to(device), labels.to(device) optimizer.zero_grad() with autocast(enabled=(device == "cuda")): outputs = model(imgs) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) scaler.step(optimizer) scaler.update() if epoch <= 3: scheduler.step() train_loss += loss.item() * imgs.size(0) train_correct += (outputs.argmax(1) == labels).sum().item() if (batch_idx + 1) % 100 == 0: pct = 100.0 * (batch_idx + 1) / len(train_loader) print(f" [{epoch}/{args.epochs}] {pct:.0f}% | " f"loss: {loss.item():.4f}", end="\r") if epoch > 3: scheduler.step() train_loss /= n_train train_acc = train_correct / n_train # ── Validation loop ─────────────────────────────────── model.eval() val_loss = 0.0 val_correct = 0 all_probs = [] all_labels = [] with torch.no_grad(): for imgs, labels in val_loader: imgs, labels = imgs.to(device), labels.to(device) with autocast(enabled=(device == "cuda")): outputs = model(imgs) loss = criterion(outputs, labels) probs = torch.softmax(outputs, dim=1)[:, FAKE_LABEL].cpu().numpy() all_probs.extend(probs.tolist()) all_labels.extend(labels.cpu().numpy().tolist()) val_loss += loss.item() * imgs.size(0) val_correct += (outputs.argmax(1) == labels).sum().item() val_loss /= n_val val_acc = val_correct / n_val # AUC try: from sklearn.metrics import roc_auc_score val_auc = roc_auc_score(all_labels, all_probs) except Exception: val_auc = 0.0 elapsed = time.time() - t0 print(f"\n[TRAIN] Epoch {epoch:02d}/{args.epochs} " f"| Train loss={train_loss:.4f} acc={train_acc:.4f} " f"| Val loss={val_loss:.4f} acc={val_acc:.4f} " f"| AUC={val_auc:.4f} | {elapsed:.1f}s") # ── Save best model ──────────────────────────────────── improved = val_acc > best_val_acc or ( val_acc == best_val_acc and val_auc > best_val_auc ) if improved: best_val_acc = val_acc best_val_auc = val_auc patience_counter = 0 # Detect fake_label from class_to_idx fake_idx = FAKE_LABEL try: c2i = train_ds.dataset.class_to_idx if hasattr(train_ds, "dataset") \ else train_ds.class_to_idx fake_idx = c2i.get("fake", c2i.get("Fake", FAKE_LABEL)) except Exception: pass torch.save(model.state_dict(), "weights/efficientnet_deepfake.pth") meta = { "fake_label": int(fake_idx), "img_size": IMG_SIZE, "best_val_acc": round(best_val_acc, 4), "best_val_auc": round(best_val_auc, 4), "epoch": epoch, "datasets": ["FaceForensics++", "DFDC", "Custom-80k"], } with open("weights/efficientnet_b4_meta.json", "w") as f: json.dump(meta, f, indent=2) print(f"[TRAIN] ✅ Best model saved! acc={best_val_acc:.4f} AUC={best_val_auc:.4f}") else: patience_counter += 1 # ── Early stopping ───────────────────────────────────── if args.patience > 0 and patience_counter >= args.patience: print(f"[TRAIN] ⏹ Early stopping (no improvement for {args.patience} epochs)") break print(f"\n[TRAIN] 🎉 Training complete!") print(f"[TRAIN] Best val accuracy : {best_val_acc:.4f} ({best_val_acc*100:.1f}%)") print(f"[TRAIN] Best val AUC : {best_val_auc:.4f}") print(f"[TRAIN] Weights saved to : weights/efficientnet_deepfake.pth") print(f"[TRAIN] Meta saved to : weights/efficientnet_b4_meta.json") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train EfficientNet-B4 Deepfake Detector") parser.add_argument("--data_dir", type=str, default="data", help="Root data directory (containing real/ and fake/ folders)") parser.add_argument("--epochs", type=int, default=20, help="Total training epochs (default: 20)") parser.add_argument("--batch_size", type=int, default=32, help="Batch size (default: 32; use 16 for 4GB GPU)") parser.add_argument("--val_split", type=float, default=0.15, help="Validation fraction if no val/ folder (default: 0.15)") parser.add_argument("--workers", type=int, default=4, help="DataLoader workers (default: 4)") parser.add_argument("--patience", type=int, default=5, help="Early stopping patience (default: 5, 0=disabled)") args = parser.parse_args() if not os.path.isdir(args.data_dir): print(f"[ERROR] Data directory not found: {args.data_dir}") print("[ERROR] Expected structure:") print(" data/real/ ← real face images") print(" data/fake/ ← deepfake images") exit(1) train(args)