File size: 5,342 Bytes
563f896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
DeepfakeDetector: Two-branch fusion architecture for deepfake detection.

Branch 1 — CLIP ViT vision encoder:
    Extracts high-level semantic features from RGB pixels. CLIP's pretraining
    on 400M image-text pairs gives strong cross-generator generalization
    (GANs, diffusion models, face-swap).

Branch 2 — FFT magnitude → lightweight CNN:
    Computes the 2D FFT magnitude spectrum from the input, capturing
    frequency-domain artifacts invisible in pixel space (GAN spectral peaks,
    diffusion noise patterns, blending boundary artifacts).

The two branch embeddings are concatenated and classified by a fusion MLP head.
"""

import torch
import torch.nn as nn
from dataclasses import dataclass
from pathlib import Path
from transformers import CLIPVisionModel

# Use ViT-L/14 — same backbone as UnivFD (Ojha et al., CVPR 2023).
# Switch to "openai/clip-vit-base-patch16" if MPS memory is tight.
# DeepfakeDetector.from_pretrained falls back to this checkpoint when a local
# model directory has no "clip_vision" subdirectory.
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"


@dataclass
class DetectorOutput:
    """Result bundle returned by ``DeepfakeDetector.forward``."""

    # Classifier output produced by the fusion head.
    logits: torch.Tensor
    # Concatenated branch embeddings fed to the classifier. Defaults to None;
    # forward only fills it in when called with return_fused=True.
    fused: torch.Tensor = None  # pre-classifier fused embedding (for SupCon)


class DeepfakeDetector(nn.Module):
    """Two-branch deepfake detector: CLIP (spatial) + FFT (frequency).

    The CLIP vision encoder's pooled output and the embedding produced by a
    small CNN over the log-magnitude FFT spectrum are concatenated and passed
    to an MLP classification head.
    """

    # state_dict keys (inside the respective nn.Sequential) of the layers whose
    # output size is configurable. Used to recover the constructor arguments
    # from a saved checkpoint. They must track the layer order in __init__.
    _FFT_PROJ_WEIGHT_KEY = "9.weight"
    _CLS_OUT_WEIGHT_KEY = "3.weight"

    def __init__(self, clip_vision, fft_embed_dim=128, num_classes=2):
        """Build the detector around an already-instantiated CLIP vision model.

        Args:
            clip_vision: CLIP vision encoder (plain or peft-wrapped). It must
                expose ``config.hidden_size`` and return an object with a
                ``pooler_output`` attribute when called with ``pixel_values``.
            fft_embed_dim: Size of the frequency-branch embedding.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.clip_vision = clip_vision

        # Resolve hidden size for both regular and peft-wrapped models
        try:
            # PeftModel stores base config at base_model.model.config
            clip_hidden = clip_vision.base_model.model.config.hidden_size
        except AttributeError:
            clip_hidden = clip_vision.config.hidden_size

        # Branch 2: lightweight CNN on FFT magnitude spectrum.
        # NOTE: do not reorder these layers; the Sequential indices are the
        # state_dict keys stored in saved checkpoints.
        self.fft_branch = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, fft_embed_dim),
            nn.ReLU(inplace=True),
        )

        # Fusion classification head
        self.classifier = nn.Sequential(
            nn.Linear(clip_hidden + fft_embed_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

        self.id2label = {0: "Realism", 1: "Deepfake"}

    @property
    def fused_dim(self):
        """Dimension of the fused embedding (clip_hidden + fft_embed_dim)."""
        return self.classifier[0].in_features

    @staticmethod
    def _fft_log_magnitude(pixel_values):
        """Return the centred log-magnitude spectrum of the channel-mean image.

        Args:
            pixel_values: Image batch laid out as [B, C, H, W].

        Returns:
            Tensor of shape [B, 1, H, W] in the dtype of ``pixel_values``.
        """
        gray = pixel_values.mean(dim=1, keepdim=True)  # [B, 1, H, W]

        # Half-precision FFT is restricted (power-of-two sizes only on CUDA)
        # and the raw magnitudes can overflow fp16, so transform in float32.
        if gray.dtype in (torch.float16, torch.bfloat16):
            gray = gray.float()

        spectrum = torch.fft.fft2(gray)
        # Shift ONLY the two spatial axes. Without ``dim`` fftshift rolls every
        # axis, including the batch axis, which pairs each sample's spectrum
        # with a different sample's CLIP embedding whenever B > 1.
        centred = torch.fft.fftshift(spectrum, dim=(-2, -1))
        log_mag = torch.log1p(torch.abs(centred))
        return log_mag.to(pixel_values.dtype)

    def forward(self, pixel_values, return_fused=False):
        """Classify a batch of images.

        Args:
            pixel_values: Image batch laid out as [B, C, H, W], passed to the
                CLIP encoder as-is and also used to compute the FFT input.
            return_fused: When true, include the pre-classifier fused
                embedding in the output; otherwise ``fused`` is None.

        Returns:
            DetectorOutput with ``logits`` and, optionally, ``fused``.
        """
        # Branch 1: CLIP spatial features
        clip_out = self.clip_vision(pixel_values=pixel_values)
        clip_embed = clip_out.pooler_output  # [B, clip_hidden]

        # Branch 2: FFT frequency features
        fft_mag = self._fft_log_magnitude(pixel_values)  # [B, 1, H, W]
        fft_embed = self.fft_branch(fft_mag)  # [B, fft_embed_dim]

        # Fusion
        fused = torch.cat([clip_embed, fft_embed], dim=1)
        logits = self.classifier(fused)
        return DetectorOutput(logits=logits, fused=fused if return_fused else None)

    # -- Save / Load ----------------------------------------------------------

    def save_model(self, save_dir):
        """Save the full model for inference (call after DoRA merge).

        Writes the CLIP encoder to ``<save_dir>/clip_vision`` via its own
        ``save_pretrained`` and the FFT branch, classifier and label map to
        ``<save_dir>/head_weights.pt``. The directory is created if missing.

        Args:
            save_dir: Target directory (str or Path).
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        self.clip_vision.save_pretrained(save_dir / "clip_vision")
        torch.save({
            "fft_branch": self.fft_branch.state_dict(),
            "classifier": self.classifier.state_dict(),
            "id2label": self.id2label,
        }, save_dir / "head_weights.pt")

    @classmethod
    def from_pretrained(cls, model_id_or_path, device="cpu"):
        """Load a trained model for inference.

        Accepts either a local directory path or a Hugging Face repo ID.

        * Local directory: the CLIP backbone is read from its ``clip_vision``
          subdirectory, falling back to ``CLIP_MODEL_ID`` when that
          subdirectory is absent; head weights come from ``head_weights.pt``.
        * Hub repo: the backbone is read from the repo's ``model/clip_vision``
          subfolder and the head weights from ``model/head_weights.pt``.

        ``fft_embed_dim`` and ``num_classes`` are inferred from the saved
        weights, so checkpoints trained with non-default sizes load correctly.

        Args:
            model_id_or_path: Local directory or Hub repo ID.
            device: Device the returned model is placed on.

        Returns:
            The detector, moved to ``device`` and switched to eval mode.
        """
        local = Path(model_id_or_path)
        if local.exists():
            clip_dir = local / "clip_vision"
            clip_vision = CLIPVisionModel.from_pretrained(
                str(clip_dir) if clip_dir.exists() else CLIP_MODEL_ID
            )
            head_path = local / "head_weights.pt"
        else:
            from huggingface_hub import hf_hub_download
            clip_vision = CLIPVisionModel.from_pretrained(
                model_id_or_path, subfolder="model/clip_vision"
            )
            head_path = hf_hub_download(
                repo_id=model_id_or_path, filename="model/head_weights.pt"
            )

        head_data = torch.load(head_path, map_location=device, weights_only=True)
        fft_state = head_data["fft_branch"]
        classifier_state = head_data["classifier"]

        # Rebuild the heads with the sizes the checkpoint was trained with
        # rather than assuming the constructor defaults.
        model = cls(
            clip_vision,
            fft_embed_dim=fft_state[cls._FFT_PROJ_WEIGHT_KEY].shape[0],
            num_classes=classifier_state[cls._CLS_OUT_WEIGHT_KEY].shape[0],
        )
        model.fft_branch.load_state_dict(fft_state)
        model.classifier.load_state_dict(classifier_state)
        model.id2label = head_data["id2label"]

        # Freshly constructed modules start in training mode; inference needs
        # Dropout disabled and BatchNorm using its running statistics.
        model.to(device)
        model.eval()
        return model