| """ |
| DeepfakeDetector: Two-branch fusion architecture for deepfake detection. |
| |
| Branch 1 — CLIP ViT vision encoder: |
| Extracts high-level semantic features from RGB pixels. CLIP's pretraining |
| on 400M image-text pairs gives strong cross-generator generalization |
| (GANs, diffusion models, face-swap). |
| |
| Branch 2 — FFT magnitude → lightweight CNN: |
| Computes the 2D FFT magnitude spectrum from the input, capturing |
| frequency-domain artifacts invisible in pixel space (GAN spectral peaks, |
| diffusion noise patterns, blending boundary artifacts). |
| |
| The two branch embeddings are concatenated and classified by a fusion MLP head. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from dataclasses import dataclass |
| from pathlib import Path |
| from transformers import CLIPVisionModel |
|
|
| |
| |
# Hugging Face Hub ID of the stock CLIP ViT-L/14 checkpoint. DeepfakeDetector.from_pretrained
# falls back to it when a local model directory has no "clip_vision" subfolder.
| CLIP_MODEL_ID = "openai/clip-vit-large-patch14" |
|
|
|
|
@dataclass
class DetectorOutput:
    """Result bundle returned by ``DeepfakeDetector.forward``.

    Attributes are documented inline; ``fused`` stays ``None`` unless the
    caller asked for the fused embedding.
    """

    # Raw (pre-softmax) class scores produced by the fusion classifier head.
    logits: torch.Tensor
    # Concatenated CLIP + FFT embedding; populated only on request.
    fused: torch.Tensor = None
|
|
|
|
class DeepfakeDetector(nn.Module):
    """Two-branch deepfake detector: CLIP (spatial) + FFT (frequency).

    The CLIP vision encoder yields a pooled semantic embedding of the RGB
    input; a small CNN embeds the centred log-magnitude FFT spectrum of the
    grayscale image. Both embeddings are concatenated and classified by an
    MLP head.

    Args:
        clip_vision: Vision backbone. It must accept ``pixel_values=...`` and
            return an object with ``pooler_output``, and expose
            ``config.hidden_size`` (directly or via ``base_model.model``).
        fft_embed_dim: Width of the frequency-branch embedding.
        num_classes: Number of output classes of the fusion head.
    """

    def __init__(self, clip_vision, fft_embed_dim: int = 128, num_classes: int = 2):
        super().__init__()
        self.clip_vision = clip_vision

        # A wrapped backbone (presumably PEFT/DoRA during training -- confirm)
        # keeps its config under base_model.model; a plain backbone exposes
        # it directly.
        try:
            clip_hidden = clip_vision.base_model.model.config.hidden_size
        except AttributeError:
            clip_hidden = clip_vision.config.hidden_size

        # NOTE: the layer order below defines the state_dict keys stored in
        # head_weights.pt; changing it breaks previously saved checkpoints.
        self.fft_branch = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, fft_embed_dim),
            nn.ReLU(inplace=True),
        )

        # Fusion head over the concatenated [CLIP | FFT] embedding.
        self.classifier = nn.Sequential(
            nn.Linear(clip_hidden + fft_embed_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

        self.id2label = {0: "Realism", 1: "Deepfake"}

    @property
    def fused_dim(self):
        """Dimension of the fused embedding (clip_hidden + fft_embed_dim)."""
        return self.classifier[0].in_features

    @staticmethod
    def _fft_magnitude(pixel_values: torch.Tensor) -> torch.Tensor:
        """Return the centred log-magnitude spectrum of the grayscale input.

        Args:
            pixel_values: Image batch laid out as ``(batch, channels, H, W)``.

        Returns:
            Tensor of shape ``(batch, 1, H, W)`` in the dtype of the input.
        """
        in_dtype = pixel_values.dtype
        # FFT kernels have limited half-precision support (power-of-two sizes
        # on CUDA only) and none for bfloat16, so transform in float32.
        if in_dtype in (torch.float16, torch.bfloat16):
            pixel_values = pixel_values.float()

        gray = pixel_values.mean(dim=1, keepdim=True)
        spectrum = torch.fft.fft2(gray)
        # Shift ONLY the two spatial axes. With no `dim`, fftshift rolls every
        # axis -- including the batch axis -- which would pair each sample's
        # spectrum with another sample's CLIP embedding.
        centered = torch.fft.fftshift(spectrum, dim=(-2, -1))
        return torch.log1p(torch.abs(centered)).to(in_dtype)

    @staticmethod
    def _final_out_features(state_dict, default: int) -> int:
        """Return the output width of the last linear layer in ``state_dict``.

        Falls back to ``default`` when no 2-D weight is present, so that a
        malformed checkpoint is reported by ``load_state_dict`` as before.
        """
        linear_weights = [
            tensor
            for key, tensor in state_dict.items()
            if key.endswith("weight") and tensor.dim() == 2
        ]
        if not linear_weights:
            return default
        return linear_weights[-1].shape[0]

    def forward(self, pixel_values: torch.Tensor, return_fused: bool = False):
        """Classify a batch of images.

        Args:
            pixel_values: Preprocessed image batch ``(batch, channels, H, W)``.
            return_fused: When true, also return the fused embedding.

        Returns:
            DetectorOutput with ``logits`` and, if requested, ``fused``
            (otherwise ``None``).
        """
        # Spatial branch: pooled CLIP embedding.
        clip_embed = self.clip_vision(pixel_values=pixel_values).pooler_output

        # Frequency branch: CNN over the log-magnitude spectrum.
        fft_embed = self.fft_branch(self._fft_magnitude(pixel_values))

        fused = torch.cat([clip_embed, fft_embed], dim=1)
        logits = self.classifier(fused)
        return DetectorOutput(logits=logits, fused=fused if return_fused else None)

    def save_model(self, save_dir):
        """Save the full model for inference (call after DoRA merge).

        Writes the backbone to ``<save_dir>/clip_vision`` via
        ``save_pretrained`` and the two heads plus ``id2label`` to
        ``<save_dir>/head_weights.pt``. The directory is created if needed.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        self.clip_vision.save_pretrained(save_dir / "clip_vision")
        torch.save(
            {
                "fft_branch": self.fft_branch.state_dict(),
                "classifier": self.classifier.state_dict(),
                "id2label": self.id2label,
            },
            save_dir / "head_weights.pt",
        )

    @classmethod
    def from_pretrained(cls, model_id_or_path, device="cpu"):
        """Load a trained model for inference.

        Accepts either a local directory path or a Hugging Face repo ID.

        * Local directory: the backbone is read from ``clip_vision/`` inside
          it, falling back to the stock ``CLIP_MODEL_ID`` checkpoint when that
          subfolder is absent; heads come from ``head_weights.pt``.
        * Hub repo: the backbone is read from the repo's
          ``model/clip_vision`` subfolder and the heads from
          ``model/head_weights.pt``.

        Head sizes (``fft_embed_dim``, ``num_classes``) are inferred from the
        checkpoint, so models trained with non-default sizes load correctly.

        Args:
            model_id_or_path: Local directory or Hub repo ID.
            device: Device the returned model is placed on.

        Returns:
            The model on ``device`` in eval mode.
        """
        local = Path(model_id_or_path)
        if local.exists():
            clip_dir = local / "clip_vision"
            clip_vision = CLIPVisionModel.from_pretrained(
                str(clip_dir) if clip_dir.exists() else CLIP_MODEL_ID
            )
            head_path = local / "head_weights.pt"
        else:
            # Imported lazily so local loading works without huggingface_hub.
            from huggingface_hub import hf_hub_download

            clip_vision = CLIPVisionModel.from_pretrained(
                model_id_or_path, subfolder="model/clip_vision"
            )
            head_path = hf_hub_download(
                repo_id=model_id_or_path, filename="model/head_weights.pt"
            )

        head_data = torch.load(head_path, map_location=device, weights_only=True)
        fft_state = head_data["fft_branch"]
        classifier_state = head_data["classifier"]

        model = cls(
            clip_vision,
            fft_embed_dim=cls._final_out_features(fft_state, 128),
            num_classes=cls._final_out_features(classifier_state, 2),
        )
        model.fft_branch.load_state_dict(fft_state)
        model.classifier.load_state_dict(classifier_state)
        model.id2label = head_data["id2label"]

        # `map_location` only places the loaded tensors; move the whole model
        # (backbone included) and disable Dropout / BatchNorm updates.
        model.to(device)
        model.eval()
        return model
|
|