Spaces:
Sleeping
Sleeping
"""
VERIDEX — EfficientNet-B4 Deepfake Training Script
====================================================
Dataset: FaceForensics++ + DFDC + Custom (80k images)

Expected folder structure (TWO options supported):

Option A — class folders:
    data/
        real/   -> real face images
        fake/   -> deepfake/AI-generated images

Option B — train/val split folders:
    data/
        train/
            real/
            fake/
        val/
            real/
            fake/

Usage:
    python train_efficientnet.py --data_dir ./data --epochs 20 --batch_size 32

After training, weights saved to:
    weights/efficientnet_deepfake.pth
    weights/efficientnet_b4_meta.json
"""
| import os, json, time, argparse, warnings | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, WeightedRandomSampler | |
| from torchvision import datasets, transforms | |
| from torch.cuda.amp import GradScaler, autocast | |
| import numpy as np | |
| warnings.filterwarnings("ignore") | |
# ─────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────
IMG_SIZE = 380     # EfficientNet-B4 native input resolution
NUM_CLASSES = 2    # binary task: fake vs real
FAKE_LABEL = 0     # ImageFolder assigns class indices alphabetically: "fake" < "real"
def get_transforms(is_train: bool):
    """
    Build the preprocessing/augmentation pipeline.

    Training augmentations (exactly what is applied below):
      - Resize + random crop  -> small translation/scale robustness
      - Horizontal flip       -> faces are roughly symmetric
      - Color jitter          -> lighting variation
      - Random grayscale      -> reduces reliance on color cues
      - Random erasing        -> occlusion robustness (applied post-normalize,
                                 as RandomErasing operates on tensors)

    NOTE(review): an earlier docstring claimed JPEG-compression and
    Gaussian-noise augmentations (relevant for FaceForensics++ compressed
    video), but neither is implemented here — add them explicitly if
    compression robustness is required.

    Validation/eval: deterministic resize + ImageNet normalization only.

    Args:
        is_train: True for the training pipeline, False for eval.

    Returns:
        A torchvision transforms.Compose pipeline.
    """
    if is_train:
        return transforms.Compose([
            transforms.Resize((IMG_SIZE + 20, IMG_SIZE + 20)),
            transforms.RandomCrop(IMG_SIZE),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3,
                saturation=0.2, hue=0.05
            ),
            transforms.RandomGrayscale(p=0.05),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
        ])
    else:
        return transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
def build_model(device: str) -> nn.Module:
    """Create an ImageNet-pretrained EfficientNet-B4 for binary
    deepfake classification and move it to `device`.

    The stock classifier is swapped for a two-layer head with dropout
    and batch-norm regularization.
    """
    import timm  # local import: timm is only needed when building the model
    net = timm.create_model("efficientnet_b4", pretrained=True)
    head = [
        nn.Dropout(0.4),
        nn.Linear(net.num_features, 512),
        nn.GELU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.3),
        nn.Linear(512, NUM_CLASSES),
    ]
    net.classifier = nn.Sequential(*head)
    return net.to(device)
def make_weighted_sampler(dataset) -> WeightedRandomSampler:
    """Build a sampler that rebalances fake/real classes during training.

    Each sample is weighted by the inverse frequency of its class, so
    minority-class images are drawn more often (with replacement).
    Requires `dataset.samples` (the ImageFolder (path, label) list).
    """
    labels = np.array([lbl for _, lbl in dataset.samples])
    class_counts = np.bincount(labels)
    inv_freq = 1.0 / (class_counts + 1e-6)  # epsilon guards a zero-count class
    per_sample_weights = inv_freq[labels].tolist()
    return WeightedRandomSampler(
        weights=per_sample_weights,
        num_samples=len(per_sample_weights),
        replacement=True,
    )
def load_datasets(data_dir: str, val_split: float = 0.15):
    """
    Load train/val datasets, auto-detecting the folder layout.

    Option B (data_dir/train and data_dir/val both exist): use the
    pre-split folders directly. Otherwise Option A: index data_dir once
    with train-time transforms, randomly split off `val_split` of the
    indices, and serve the val portion through a second ImageFolder view
    that applies eval-time transforms.

    Fix vs. earlier revision: the dead intermediate Subset assignments
    and the wasted extra ImageFolder construction (whose result was
    immediately discarded) are removed; only one val-side ImageFolder is
    built, inside _ValSubset.

    Args:
        data_dir:  root directory of the dataset.
        val_split: fraction held out for validation in Option A.

    Returns:
        (train_ds, val_ds) dataset objects usable with DataLoader.
    """
    train_dir = os.path.join(data_dir, "train")
    val_dir = os.path.join(data_dir, "val")
    if os.path.isdir(train_dir) and os.path.isdir(val_dir):
        # Option B: pre-split folders.
        print(f"[TRAIN] Using pre-split: {train_dir} / {val_dir}")
        train_ds = datasets.ImageFolder(train_dir, transform=get_transforms(True))
        val_ds = datasets.ImageFolder(val_dir, transform=get_transforms(False))
        return train_ds, val_ds

    # Option A: flat folders — split automatically.
    print(f"[TRAIN] Auto-splitting from: {data_dir}")
    from torch.utils.data import Subset
    full_ds = datasets.ImageFolder(data_dir, transform=get_transforms(True))
    n_val = int(len(full_ds) * val_split)
    # NOTE(review): the permutation is unseeded, so the split differs
    # across runs — seed torch before calling if reproducibility matters.
    indices = torch.randperm(len(full_ds)).tolist()
    train_idx, val_idx = indices[n_val:], indices[:n_val]

    class _ValSubset(torch.utils.data.Dataset):
        """Subset view that re-opens the folder with eval transforms."""

        def __init__(self, base_dir, idx):
            self.ds = datasets.ImageFolder(base_dir, transform=get_transforms(False))
            self.indices = idx
            # Mirror ImageFolder metadata so callers can inspect classes.
            self.classes = self.ds.classes
            self.class_to_idx = self.ds.class_to_idx
            self.samples = [self.ds.samples[i] for i in idx]

        def __len__(self):
            return len(self.indices)

        def __getitem__(self, i):
            return self.ds[self.indices[i]]

    return Subset(full_ds, train_idx), _ValSubset(data_dir, val_idx)
def train(args):
    """Run the full training loop and persist the best checkpoint.

    Two-phase schedule:
      - Epochs 1-3: only the classifier head is optimized (the backbone is
        effectively frozen because its parameters are not in the optimizer),
        OneCycleLR stepped per batch.
      - Epoch 4+: all parameters are optimized at a lower LR with
        per-epoch cosine annealing.

    The best model — by val accuracy, with AUC as tie-breaker — is saved
    to weights/efficientnet_deepfake.pth, with a metadata JSON alongside.

    Args:
        args: argparse namespace with data_dir, epochs, batch_size,
              val_split, workers, patience.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[TRAIN] Device: {device.upper()}")
    print(f"[TRAIN] Data: {args.data_dir}")
    print(f"[TRAIN] Epochs: {args.epochs} | Batch: {args.batch_size}")
    os.makedirs("weights", exist_ok=True)

    # ── Datasets ────────────────────────────────────────────────
    train_ds, val_ds = load_datasets(args.data_dir, args.val_split)
    # Weighted sampler to handle class imbalance. Falls back to plain
    # shuffling when the dataset doesn't expose .samples (e.g. a Subset
    # from the auto-split path).
    try:
        sampler = make_weighted_sampler(train_ds)
        train_loader = DataLoader(
            train_ds, batch_size=args.batch_size,
            sampler=sampler, num_workers=args.workers,
            pin_memory=(device == "cuda")
        )
    except Exception:
        train_loader = DataLoader(
            train_ds, batch_size=args.batch_size,
            shuffle=True, num_workers=args.workers,
            pin_memory=(device == "cuda")
        )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size,
        shuffle=False, num_workers=args.workers,
        pin_memory=(device == "cuda")
    )
    n_train = len(train_ds)
    n_val = len(val_ds)
    print(f"[TRAIN] Train: {n_train} | Val: {n_val}")

    # ── Model ───────────────────────────────────────────────────
    model = build_model(device)
    scaler = GradScaler(enabled=(device == "cuda"))

    # ── Loss: label smoothing helps generalization on deepfakes ─
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # ── Optimizer: phased LR ────────────────────────────────────
    # Phase 1 (ep 1-3): warm-up, train only classifier head.
    # Phase 2 (ep 4+):  fine-tune full network, lower LR, cosine decay.
    optimizer = optim.AdamW(
        model.classifier.parameters(),
        lr=1e-3, weight_decay=1e-4
    )
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-3,
        steps_per_epoch=len(train_loader),
        epochs=args.epochs,
        pct_start=0.1,
        anneal_strategy="cos",
    )

    best_val_acc = 0.0
    best_val_auc = 0.0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        # ── Unfreeze backbone after epoch 3 ─────────────────────
        if epoch == 4:
            print("[TRAIN] π Unfreezing backbone for full fine-tuning")
            optimizer = optim.AdamW(
                model.parameters(), lr=5e-5, weight_decay=1e-4
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=args.epochs - 3, eta_min=1e-7
            )

        # ── Training loop ───────────────────────────────────────
        model.train()
        train_loss = 0.0
        train_correct = 0
        t0 = time.time()
        for batch_idx, (imgs, labels) in enumerate(train_loader):
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            with autocast(enabled=(device == "cuda")):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is computed on true grads.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            if epoch <= 3:
                # OneCycleLR is a per-batch scheduler.
                scheduler.step()
            train_loss += loss.item() * imgs.size(0)
            train_correct += (outputs.argmax(1) == labels).sum().item()
            if (batch_idx + 1) % 100 == 0:
                pct = 100.0 * (batch_idx + 1) / len(train_loader)
                print(f"  [{epoch}/{args.epochs}] {pct:.0f}% | "
                      f"loss: {loss.item():.4f}", end="\r")
        if epoch > 3:
            # CosineAnnealingLR is a per-epoch scheduler.
            scheduler.step()
        train_loss /= n_train
        train_acc = train_correct / n_train

        # ── Validation loop ─────────────────────────────────────
        model.eval()
        val_loss = 0.0
        val_correct = 0
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                with autocast(enabled=(device == "cuda")):
                    outputs = model(imgs)
                    loss = criterion(outputs, labels)
                # FIX: roc_auc_score treats label 1 ("real") as the positive
                # class, so rank by P(real) = softmax[:, 1 - FAKE_LABEL].
                # The previous code passed P(fake) (class 0), which inverted
                # the reported AUC (it printed 1 - true AUC).
                probs = torch.softmax(outputs, dim=1)[:, 1 - FAKE_LABEL].cpu().numpy()
                all_probs.extend(probs.tolist())
                all_labels.extend(labels.cpu().numpy().tolist())
                val_loss += loss.item() * imgs.size(0)
                val_correct += (outputs.argmax(1) == labels).sum().item()
        val_loss /= n_val
        val_acc = val_correct / n_val
        # AUC (best-effort: sklearn may be absent or a class may be missing
        # from the val set, in which case report 0.0).
        try:
            from sklearn.metrics import roc_auc_score
            val_auc = roc_auc_score(all_labels, all_probs)
        except Exception:
            val_auc = 0.0
        elapsed = time.time() - t0
        print(f"\n[TRAIN] Epoch {epoch:02d}/{args.epochs} "
              f"| Train loss={train_loss:.4f} acc={train_acc:.4f} "
              f"| Val loss={val_loss:.4f} acc={val_acc:.4f} "
              f"| AUC={val_auc:.4f} | {elapsed:.1f}s")

        # ── Save best model ─────────────────────────────────────
        improved = val_acc > best_val_acc or (
            val_acc == best_val_acc and val_auc > best_val_auc
        )
        if improved:
            best_val_acc = val_acc
            best_val_auc = val_auc
            patience_counter = 0
            # Detect the actual fake-class index from class_to_idx so the
            # saved metadata is correct even with non-standard class names.
            fake_idx = FAKE_LABEL
            try:
                c2i = train_ds.dataset.class_to_idx if hasattr(train_ds, "dataset") \
                    else train_ds.class_to_idx
                fake_idx = c2i.get("fake", c2i.get("Fake", FAKE_LABEL))
            except Exception:
                pass
            torch.save(model.state_dict(), "weights/efficientnet_deepfake.pth")
            meta = {
                "fake_label": int(fake_idx),
                "img_size": IMG_SIZE,
                "best_val_acc": round(best_val_acc, 4),
                "best_val_auc": round(best_val_auc, 4),
                "epoch": epoch,
                "datasets": ["FaceForensics++", "DFDC", "Custom-80k"],
            }
            with open("weights/efficientnet_b4_meta.json", "w") as f:
                json.dump(meta, f, indent=2)
            print(f"[TRAIN] β Best model saved! acc={best_val_acc:.4f} AUC={best_val_auc:.4f}")
        else:
            patience_counter += 1

        # ── Early stopping ──────────────────────────────────────
        if args.patience > 0 and patience_counter >= args.patience:
            print(f"[TRAIN] βΉ Early stopping (no improvement for {args.patience} epochs)")
            break

    print(f"\n[TRAIN] π Training complete!")
    print(f"[TRAIN] Best val accuracy : {best_val_acc:.4f} ({best_val_acc*100:.1f}%)")
    print(f"[TRAIN] Best val AUC      : {best_val_auc:.4f}")
    print(f"[TRAIN] Weights saved to  : weights/efficientnet_deepfake.pth")
    print(f"[TRAIN] Meta saved to     : weights/efficientnet_b4_meta.json")
if __name__ == "__main__":
    # CLI entry point: parse options, sanity-check the data dir, train.
    parser = argparse.ArgumentParser(description="Train EfficientNet-B4 Deepfake Detector")
    cli_options = [
        ("--data_dir", str, "data",
         "Root data directory (containing real/ and fake/ folders)"),
        ("--epochs", int, 20,
         "Total training epochs (default: 20)"),
        ("--batch_size", int, 32,
         "Batch size (default: 32; use 16 for 4GB GPU)"),
        ("--val_split", float, 0.15,
         "Validation fraction if no val/ folder (default: 0.15)"),
        ("--workers", int, 4,
         "DataLoader workers (default: 4)"),
        ("--patience", int, 5,
         "Early stopping patience (default: 5, 0=disabled)"),
    ]
    for flag, opt_type, default, help_text in cli_options:
        parser.add_argument(flag, type=opt_type, default=default, help=help_text)
    args = parser.parse_args()
    if os.path.isdir(args.data_dir):
        train(args)
    else:
        print(f"[ERROR] Data directory not found: {args.data_dir}")
        print("[ERROR] Expected structure:")
        print(" data/real/ β real face images")
        print(" data/fake/ β deepfake images")
        exit(1)