"""
DeepfakeDetector: two-branch fusion architecture for deepfake detection.

Branch 1 — CLIP ViT vision encoder:
    Extracts high-level semantic features from RGB pixels. CLIP's pretraining
    on 400M image-text pairs gives strong cross-generator generalization
    (GANs, diffusion models, face-swap).

Branch 2 — FFT magnitude → lightweight CNN:
    Computes the 2D FFT log-magnitude spectrum of the grayscale input,
    capturing frequency-domain artifacts invisible in pixel space (GAN
    spectral peaks, diffusion noise patterns, blending-boundary artifacts).

The two branch embeddings are concatenated and classified by a fusion MLP head.
"""
import torch
import torch.nn as nn
from dataclasses import dataclass
from pathlib import Path
from transformers import CLIPVisionModel
# Use ViT-L/14 — the same backbone as UnivFD (Ojha et al., CVPR 2023).
# Switch to "openai/clip-vit-base-patch16" if MPS memory is tight.
# Also used by DeepfakeDetector.from_pretrained as the backbone fallback when a
# local model directory has no clip_vision/ subfolder.
CLIP_MODEL_ID: str = "openai/clip-vit-large-patch14"
@dataclass
class DetectorOutput:
    """Result of a DeepfakeDetector forward pass.

    Attributes:
        logits: Classifier logits, one row per input image.
        fused: Pre-classifier fused embedding (for SupCon). ``None`` unless the
            forward pass was asked to return it.
    """

    logits: torch.Tensor
    # Optional: the annotation admits None because None is the default. Kept as
    # a string so it does not require Python 3.10's runtime `X | None` support.
    fused: "torch.Tensor | None" = None
class DeepfakeDetector(nn.Module):
    """Two-branch deepfake detector: CLIP (spatial) + FFT (frequency).

    Args:
        clip_vision: CLIP vision backbone, either plain or peft-wrapped. It is
            called as ``clip_vision(pixel_values=...)`` and must return an
            object with a ``pooler_output`` attribute.
        fft_embed_dim: Width of the FFT-branch embedding.
        num_classes: Number of output logits.
    """

    def __init__(self, clip_vision, fft_embed_dim=128, num_classes=2):
        super().__init__()
        self.clip_vision = clip_vision

        # Resolve hidden size for both regular and peft-wrapped models.
        try:
            # PeftModel stores the base config at base_model.model.config
            clip_hidden = clip_vision.base_model.model.config.hidden_size
        except AttributeError:
            clip_hidden = clip_vision.config.hidden_size

        # Branch 2: lightweight CNN on the FFT magnitude spectrum.
        # NOTE: from_pretrained() reads the final Linear's weight at index 9
        # to recover fft_embed_dim; keep it in sync if the layers change.
        self.fft_branch = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, fft_embed_dim),
            nn.ReLU(inplace=True),
        )

        # Fusion classification head.
        # NOTE: from_pretrained() reads the output Linear at index 3 to
        # recover num_classes.
        self.classifier = nn.Sequential(
            nn.Linear(clip_hidden + fft_embed_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )
        self.id2label = {0: "Realism", 1: "Deepfake"}

    @property
    def fused_dim(self):
        """Dimension of the fused embedding (clip_hidden + fft_embed_dim)."""
        return self.classifier[0].in_features

    @staticmethod
    def _fft_magnitude(pixel_values):
        """Return the centred log-magnitude spectrum of the grayscale image.

        Args:
            pixel_values: Image batch ``[B, C, H, W]``.

        Returns:
            ``[B, 1, H, W]`` tensor ``log1p(|fftshift(fft2(gray))|)`` with the
            zero-frequency term at ``(H // 2, W // 2)``, in the input's dtype.
        """
        # Average the channels into one grayscale plane: [B, 1, H, W].
        gray = pixel_values.mean(dim=1, keepdim=True)
        # torch.fft has no bfloat16 kernels and only limited float16 support,
        # so transform reduced-precision input in float32.
        if gray.dtype in (torch.float16, torch.bfloat16):
            gray = gray.float()
        spectrum = torch.fft.fft2(gray)
        # Shift only the spatial axes. fftshift's default (dim=None) would also
        # roll the batch axis, pairing sample i's spectrum with a different
        # sample's CLIP embedding whenever B > 1.
        spectrum = torch.fft.fftshift(spectrum, dim=(-2, -1))
        return torch.log1p(spectrum.abs()).to(pixel_values.dtype)

    def forward(self, pixel_values, return_fused=False):
        """Classify a batch of images.

        Args:
            pixel_values: Preprocessed image batch ``[B, C, H, W]`` fed to both
                branches.
            return_fused: If True, also return the fused embedding.

        Returns:
            DetectorOutput with ``logits`` ``[B, num_classes]`` and ``fused``
            ``[B, fused_dim]`` (or ``None`` when ``return_fused`` is False).
        """
        # Branch 1: CLIP spatial features.
        clip_out = self.clip_vision(pixel_values=pixel_values)
        clip_embed = clip_out.pooler_output  # [B, clip_hidden]

        # Branch 2: FFT frequency features.
        fft_mag = self._fft_magnitude(pixel_values)  # [B, 1, H, W]
        fft_embed = self.fft_branch(fft_mag)  # [B, fft_embed_dim]

        # Fusion
        fused = torch.cat([clip_embed, fft_embed], dim=1)
        logits = self.classifier(fused)
        return DetectorOutput(logits=logits, fused=fused if return_fused else None)

    # ── Save / Load ─────────────────────────────────────────────────────────

    def save_model(self, save_dir):
        """Save the full model for inference.

        Call after the DoRA merge, so ``clip_vision`` can be reloaded by
        :meth:`from_pretrained` with ``CLIPVisionModel.from_pretrained``.
        Writes ``<save_dir>/clip_vision/`` via ``save_pretrained``. Also writes
        ``<save_dir>/head_weights.pt``, which holds the ``fft_branch`` and
        ``classifier`` state dicts plus ``id2label``.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        self.clip_vision.save_pretrained(save_dir / "clip_vision")
        torch.save({
            "fft_branch": self.fft_branch.state_dict(),
            "classifier": self.classifier.state_dict(),
            "id2label": self.id2label,
        }, save_dir / "head_weights.pt")

    @classmethod
    def from_pretrained(cls, model_id_or_path, device="cpu"):
        """Load a trained model for inference.

        Accepts either a local directory (as written by :meth:`save_model`) or
        a Hugging Face repo ID.

        - Local directory: the backbone is loaded from ``clip_vision/``. If
          that subfolder is missing, it falls back to the pretrained
          ``CLIP_MODEL_ID`` weights.
        - Repo ID: the backbone is read from the repo's ``model/clip_vision``
          subfolder and the head from ``model/head_weights.pt``.

        Args:
            model_id_or_path: Local directory path or Hugging Face repo ID.
            device: Device the returned model is placed on.

        Returns:
            A DeepfakeDetector on ``device``, in eval mode.
        """
        local = Path(model_id_or_path)
        if local.exists():
            clip_dir = local / "clip_vision"
            clip_vision = CLIPVisionModel.from_pretrained(
                str(clip_dir) if clip_dir.exists() else CLIP_MODEL_ID
            )
            head_path = local / "head_weights.pt"
        else:
            from huggingface_hub import hf_hub_download
            clip_vision = CLIPVisionModel.from_pretrained(
                model_id_or_path, subfolder="model/clip_vision"
            )
            head_path = hf_hub_download(
                repo_id=model_id_or_path, filename="model/head_weights.pt"
            )

        # weights_only=True refuses to unpickle arbitrary objects.
        head_data = torch.load(head_path, map_location=device, weights_only=True)
        fft_state = head_data["fft_branch"]
        classifier_state = head_data["classifier"]
        # Rebuild the heads with the saved widths, so checkpoints trained with
        # a non-default fft_embed_dim or num_classes still load. The keys follow
        # the nn.Sequential indices defined in __init__.
        model = cls(
            clip_vision,
            fft_embed_dim=fft_state["9.weight"].shape[0],
            num_classes=classifier_state["3.weight"].shape[0],
        )
        model.fft_branch.load_state_dict(fft_state)
        model.classifier.load_state_dict(classifier_state)
        model.id2label = head_data["id2label"]
        # load_state_dict copies into the existing (CPU) parameters, so the
        # model must be moved explicitly. eval() disables Dropout and makes
        # BatchNorm use running stats; a freshly built module is in train mode.
        return model.to(device).eval()
|