# VERIDEX.V1 / backend / train_efficientnet.py
# NOTE: GitHub web-UI residue commented out — this copy came from
# commit 81f9dfe by shadow55gh ("fix: remove node_modules and cache
# from tracking"); these lines are not Python and broke the file.
"""
VERIDEX β€” EfficientNet-B4 Deepfake Training Script
====================================================
Dataset: FaceForensics++ + DFDC + Custom (80k images)
Expected folder structure (TWO options supported):
Option A β€” class folders:
data/
real/ ← real face images
fake/ ← deepfake/AI-generated images
Option B β€” train/val split folders:
data/
train/
real/
fake/
val/
real/
fake/
Usage:
python train_efficientnet.py --data_dir ./data --epochs 20 --batch_size 32
After training, weights saved to:
weights/efficientnet_deepfake.pth
weights/efficientnet_b4_meta.json
"""
import os, json, time, argparse, warnings
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScaler, autocast
import numpy as np
warnings.filterwarnings("ignore")
# ─────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────
IMG_SIZE = 380
NUM_CLASSES = 2
FAKE_LABEL = 0 # index 0 = fake, index 1 = real (alphabetical: fake < real)
def get_transforms(is_train: bool):
"""
Deepfake-specific augmentation:
- Compression artifacts (JPEG quality) β€” FaceForensics++ uses compressed videos
- Gaussian noise β€” simulates video encoding
- Horizontal flip β€” faces are symmetric
- Color jitter β€” lighting variation
- Random erasing β€” occlusion robustness
"""
if is_train:
return transforms.Compose([
transforms.Resize((IMG_SIZE + 20, IMG_SIZE + 20)),
transforms.RandomCrop(IMG_SIZE),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(
brightness=0.3, contrast=0.3,
saturation=0.2, hue=0.05
),
transforms.RandomGrayscale(p=0.05),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
])
else:
return transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
])
def build_model(device: str) -> nn.Module:
import timm
model = timm.create_model("efficientnet_b4", pretrained=True)
# Replace classifier head for binary deepfake detection
model.classifier = nn.Sequential(
nn.Dropout(0.4),
nn.Linear(model.num_features, 512),
nn.GELU(),
nn.BatchNorm1d(512),
nn.Dropout(0.3),
nn.Linear(512, NUM_CLASSES),
)
return model.to(device)
def make_weighted_sampler(dataset) -> WeightedRandomSampler:
"""Balance fake/real classes during training."""
counts = np.bincount([label for _, label in dataset.samples])
weights_per_class = 1.0 / (counts + 1e-6)
sample_weights = [weights_per_class[label] for _, label in dataset.samples]
return WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True
)
def load_datasets(data_dir: str, val_split: float = 0.15):
"""
Auto-detect Option A (flat) or Option B (train/val split).
"""
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
if os.path.isdir(train_dir) and os.path.isdir(val_dir):
# Option B: pre-split
print(f"[TRAIN] Using pre-split: {train_dir} / {val_dir}")
train_ds = datasets.ImageFolder(train_dir, transform=get_transforms(True))
val_ds = datasets.ImageFolder(val_dir, transform=get_transforms(False))
else:
# Option A: flat folders β€” split automatically
print(f"[TRAIN] Auto-splitting from: {data_dir}")
full_ds = datasets.ImageFolder(data_dir, transform=get_transforms(True))
n_val = int(len(full_ds) * val_split)
n_train = len(full_ds) - n_val
from torch.utils.data import random_split, Subset
indices = torch.randperm(len(full_ds)).tolist()
train_idx, val_idx = indices[n_val:], indices[:n_val]
train_ds = Subset(full_ds, train_idx)
val_ds = Subset(full_ds, val_idx)
val_ds.dataset = datasets.ImageFolder(
data_dir, transform=get_transforms(False)
)
# Fix: val uses val transforms
class _ValSubset(torch.utils.data.Dataset):
def __init__(self, base_dir, indices):
self.ds = datasets.ImageFolder(base_dir, transform=get_transforms(False))
self.indices = indices
self.classes = self.ds.classes
self.class_to_idx = self.ds.class_to_idx
self.samples = [self.ds.samples[i] for i in indices]
def __len__(self): return len(self.indices)
def __getitem__(self, i): return self.ds[self.indices[i]]
train_ds = Subset(full_ds, train_idx)
val_ds = _ValSubset(data_dir, val_idx)
return train_ds, val_ds
def train(args):
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[TRAIN] Device: {device.upper()}")
print(f"[TRAIN] Data: {args.data_dir}")
print(f"[TRAIN] Epochs: {args.epochs} | Batch: {args.batch_size}")
os.makedirs("weights", exist_ok=True)
# ── Datasets ───────────────────────────────────────────────
train_ds, val_ds = load_datasets(args.data_dir, args.val_split)
# Weighted sampler to handle class imbalance
try:
sampler = make_weighted_sampler(train_ds)
train_loader = DataLoader(
train_ds, batch_size=args.batch_size,
sampler=sampler, num_workers=args.workers,
pin_memory=(device == "cuda")
)
except Exception:
train_loader = DataLoader(
train_ds, batch_size=args.batch_size,
shuffle=True, num_workers=args.workers,
pin_memory=(device == "cuda")
)
val_loader = DataLoader(
val_ds, batch_size=args.batch_size,
shuffle=False, num_workers=args.workers,
pin_memory=(device == "cuda")
)
n_train = len(train_ds)
n_val = len(val_ds)
print(f"[TRAIN] Train: {n_train} | Val: {n_val}")
# ── Model ──────────────────────────────────────────────────
model = build_model(device)
scaler = GradScaler(enabled=(device == "cuda"))
# ── Loss: label smoothing helps generalization on deepfakes ─
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# ── Optimizer: 3-phase LR ──────────────────────────────────
# Phase 1 (ep 1-5): warm-up, train only classifier head
# Phase 2 (ep 6-15): fine-tune full network, lower LR
# Phase 3 (ep 16+): cosine decay to near-zero
optimizer = optim.AdamW(
model.classifier.parameters(),
lr=1e-3, weight_decay=1e-4
)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=1e-3,
steps_per_epoch=len(train_loader),
epochs=args.epochs,
pct_start=0.1,
anneal_strategy="cos",
)
best_val_acc = 0.0
best_val_auc = 0.0
patience_counter = 0
for epoch in range(1, args.epochs + 1):
# ── Unfreeze backbone after epoch 3 ───────────────────
if epoch == 4:
print("[TRAIN] πŸ”“ Unfreezing backbone for full fine-tuning")
optimizer = optim.AdamW(
model.parameters(), lr=5e-5, weight_decay=1e-4
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs - 3, eta_min=1e-7
)
# ── Training loop ─────────────────────────────────────
model.train()
train_loss = 0.0
train_correct = 0
t0 = time.time()
for batch_idx, (imgs, labels) in enumerate(train_loader):
imgs, labels = imgs.to(device), labels.to(device)
optimizer.zero_grad()
with autocast(enabled=(device == "cuda")):
outputs = model(imgs)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
if epoch <= 3:
scheduler.step()
train_loss += loss.item() * imgs.size(0)
train_correct += (outputs.argmax(1) == labels).sum().item()
if (batch_idx + 1) % 100 == 0:
pct = 100.0 * (batch_idx + 1) / len(train_loader)
print(f" [{epoch}/{args.epochs}] {pct:.0f}% | "
f"loss: {loss.item():.4f}", end="\r")
if epoch > 3:
scheduler.step()
train_loss /= n_train
train_acc = train_correct / n_train
# ── Validation loop ───────────────────────────────────
model.eval()
val_loss = 0.0
val_correct = 0
all_probs = []
all_labels = []
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
with autocast(enabled=(device == "cuda")):
outputs = model(imgs)
loss = criterion(outputs, labels)
probs = torch.softmax(outputs, dim=1)[:, FAKE_LABEL].cpu().numpy()
all_probs.extend(probs.tolist())
all_labels.extend(labels.cpu().numpy().tolist())
val_loss += loss.item() * imgs.size(0)
val_correct += (outputs.argmax(1) == labels).sum().item()
val_loss /= n_val
val_acc = val_correct / n_val
# AUC
try:
from sklearn.metrics import roc_auc_score
val_auc = roc_auc_score(all_labels, all_probs)
except Exception:
val_auc = 0.0
elapsed = time.time() - t0
print(f"\n[TRAIN] Epoch {epoch:02d}/{args.epochs} "
f"| Train loss={train_loss:.4f} acc={train_acc:.4f} "
f"| Val loss={val_loss:.4f} acc={val_acc:.4f} "
f"| AUC={val_auc:.4f} | {elapsed:.1f}s")
# ── Save best model ────────────────────────────────────
improved = val_acc > best_val_acc or (
val_acc == best_val_acc and val_auc > best_val_auc
)
if improved:
best_val_acc = val_acc
best_val_auc = val_auc
patience_counter = 0
# Detect fake_label from class_to_idx
fake_idx = FAKE_LABEL
try:
c2i = train_ds.dataset.class_to_idx if hasattr(train_ds, "dataset") \
else train_ds.class_to_idx
fake_idx = c2i.get("fake", c2i.get("Fake", FAKE_LABEL))
except Exception:
pass
torch.save(model.state_dict(), "weights/efficientnet_deepfake.pth")
meta = {
"fake_label": int(fake_idx),
"img_size": IMG_SIZE,
"best_val_acc": round(best_val_acc, 4),
"best_val_auc": round(best_val_auc, 4),
"epoch": epoch,
"datasets": ["FaceForensics++", "DFDC", "Custom-80k"],
}
with open("weights/efficientnet_b4_meta.json", "w") as f:
json.dump(meta, f, indent=2)
print(f"[TRAIN] βœ… Best model saved! acc={best_val_acc:.4f} AUC={best_val_auc:.4f}")
else:
patience_counter += 1
# ── Early stopping ─────────────────────────────────────
if args.patience > 0 and patience_counter >= args.patience:
print(f"[TRAIN] ⏹ Early stopping (no improvement for {args.patience} epochs)")
break
print(f"\n[TRAIN] πŸŽ‰ Training complete!")
print(f"[TRAIN] Best val accuracy : {best_val_acc:.4f} ({best_val_acc*100:.1f}%)")
print(f"[TRAIN] Best val AUC : {best_val_auc:.4f}")
print(f"[TRAIN] Weights saved to : weights/efficientnet_deepfake.pth")
print(f"[TRAIN] Meta saved to : weights/efficientnet_b4_meta.json")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train EfficientNet-B4 Deepfake Detector")
parser.add_argument("--data_dir", type=str, default="data",
help="Root data directory (containing real/ and fake/ folders)")
parser.add_argument("--epochs", type=int, default=20,
help="Total training epochs (default: 20)")
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size (default: 32; use 16 for 4GB GPU)")
parser.add_argument("--val_split", type=float, default=0.15,
help="Validation fraction if no val/ folder (default: 0.15)")
parser.add_argument("--workers", type=int, default=4,
help="DataLoader workers (default: 4)")
parser.add_argument("--patience", type=int, default=5,
help="Early stopping patience (default: 5, 0=disabled)")
args = parser.parse_args()
if not os.path.isdir(args.data_dir):
print(f"[ERROR] Data directory not found: {args.data_dir}")
print("[ERROR] Expected structure:")
print(" data/real/ ← real face images")
print(" data/fake/ ← deepfake images")
exit(1)
train(args)