|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchvision.models as models |
|
|
import numpy as np |
|
|
from src.utils import get_fft_feature |
|
|
|
|
|
class RGBBranch(nn.Module):
    """Global RGB feature extractor built on an EfficientNetV2-S backbone.

    Only the convolutional feature stack and the global average pool of the
    backbone are used; the classifier head is discarded.

    Args:
        pretrained: if True, load the default ImageNet weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        weights = models.EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        self.net = models.efficientnet_v2_s(weights=weights)

        self.features = self.net.features
        self.avgpool = self.net.avgpool
        # Derive the feature width from the backbone's classifier head
        # (Dropout, Linear) instead of hard-coding 1280, so a backbone swap
        # cannot silently desynchronize `out_dim`.
        self.out_dim = self.net.classifier[1].in_features

    def forward(self, x):
        """Return a (B, out_dim) feature vector for an image batch x."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x
|
|
|
|
|
class FreqBranch(nn.Module): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
|
|
|
self.net = nn.Sequential( |
|
|
nn.Conv2d(3, 32, kernel_size=3, padding=1), |
|
|
nn.BatchNorm2d(32), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(2), |
|
|
|
|
|
nn.Conv2d(32, 64, kernel_size=3, padding=1), |
|
|
nn.BatchNorm2d(64), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(2), |
|
|
|
|
|
nn.Conv2d(64, 128, kernel_size=3, padding=1), |
|
|
nn.BatchNorm2d(128), |
|
|
nn.ReLU(), |
|
|
nn.AdaptiveAvgPool2d((1,1)) |
|
|
) |
|
|
self.out_dim = 128 |
|
|
|
|
|
def forward(self, x): |
|
|
return torch.flatten(self.net(x), 1) |
|
|
|
|
|
class PatchBranch(nn.Module): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.patch_encoder = nn.Sequential( |
|
|
nn.Conv2d(3, 16, kernel_size=3, padding=1), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(2), |
|
|
nn.Conv2d(16, 32, kernel_size=3, padding=1), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(2), |
|
|
nn.Conv2d(32, 64, kernel_size=3, padding=1), |
|
|
nn.ReLU(), |
|
|
nn.AdaptiveAvgPool2d((1,1)) |
|
|
) |
|
|
self.out_dim = 64 |
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
|
|
|
|
patches = x.unfold(2, 64, 64).unfold(3, 64, 64) |
|
|
|
|
|
B, C, H_grid, W_grid, H_patch, W_patch = patches.shape |
|
|
|
|
|
|
|
|
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous() |
|
|
patches = patches.view(B * H_grid * W_grid, C, H_patch, W_patch) |
|
|
|
|
|
|
|
|
feats = self.patch_encoder(patches) |
|
|
feats = torch.flatten(feats, 1) |
|
|
|
|
|
|
|
|
feats = feats.view(B, H_grid * W_grid, -1) |
|
|
|
|
|
|
|
|
feats_max, _ = torch.max(feats, dim=1) |
|
|
|
|
|
return feats_max |
|
|
|
|
|
class ViTBranch(nn.Module):
    """Global-context branch built on a Swin-V2 Tiny transformer backbone.

    The classification head is swapped for an identity module, so the branch
    emits the pooled backbone features directly. ``out_dim`` records the
    head's former input width (captured before the head is replaced).

    Args:
        pretrained: if True, load the default ImageNet weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        backbone = models.swin_v2_t(
            weights=models.Swin_V2_T_Weights.DEFAULT if pretrained else None
        )

        # Capture the feature width first, then discard the classifier.
        self.out_dim = backbone.head.in_features
        backbone.head = nn.Identity()
        self.net = backbone

    def forward(self, x):
        """Return (B, out_dim) pooled transformer features for images x."""
        features = self.net(x)
        return features
|
|
|
|
|
class DeepfakeDetector(nn.Module):
    """Four-branch deepfake detector.

    Concatenates features from an RGB CNN branch, a frequency (FFT) branch,
    a local-patch branch, and a Swin transformer branch, then classifies
    with a small MLP emitting a single logit per image.

    Args:
        pretrained: forwarded to the RGB and ViT branches to control
            ImageNet weight loading.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.rgb_branch = RGBBranch(pretrained)
        self.freq_branch = FreqBranch()
        self.patch_branch = PatchBranch()
        self.vit_branch = ViTBranch(pretrained)

        input_dim = (self.rgb_branch.out_dim +
                     self.freq_branch.out_dim +
                     self.patch_branch.out_dim +
                     self.vit_branch.out_dim)

        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1)
        )

    def forward(self, x):
        """Return a (B, 1) logit tensor for an image batch x.

        The FFT representation is computed via ``get_fft_feature`` (from
        src.utils) and fed to the frequency branch; the raw image feeds the
        other three branches.
        """
        rgb_feat = self.rgb_branch(x)

        freq_img = get_fft_feature(x)
        freq_feat = self.freq_branch(freq_img)

        patch_feat = self.patch_branch(x)

        vit_feat = self.vit_branch(x)

        combined = torch.cat([rgb_feat, freq_feat, patch_feat, vit_feat], dim=1)

        return self.classifier(combined)

    def get_heatmap(self, x):
        """Generate a Grad-CAM heatmap for the input image.

        Hooks the last stage of the RGB branch's feature stack, runs a
        forward + backward pass, and weights the captured activations by
        their channel-averaged gradients.

        Args:
            x: image batch (B, 3, H, W); the map is built from the first
               image's activations. NOTE(review): BatchNorm1d in the
               classifier requires B > 1 in training mode — call under
               ``model.eval()`` for single images.

        Returns:
            2-D numpy array, ReLU'd and normalized to [0, 1] (all zeros if
            no positive response).
        """
        gradients = []
        activations = []

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0])

        def forward_hook(module, input, output):
            activations.append(output)

        target_layer = self.rgb_branch.features[-1]
        hook_b = target_layer.register_full_backward_hook(backward_hook)
        hook_f = target_layer.register_forward_hook(forward_hook)

        # try/finally so the hooks are always removed, even if the forward
        # or backward pass raises (the original leaked them on error).
        try:
            logits = self(x)

            self.zero_grad()
            # Reduce to a scalar before backward: Tensor.backward() on a
            # non-scalar (B, 1) logit tensor raises for B > 1.
            logits.sum().backward(retain_graph=True)

            # Channel weights: gradient averaged over batch and space.
            pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])

            # Detach + clone before the in-place weighting: the original
            # mutated a tensor still referenced by the autograd graph.
            activation = activations[0][0].detach().clone()
            activation *= pooled_gradients[:, None, None]

            heatmap = torch.mean(activation, dim=0).cpu().numpy()
            heatmap = np.maximum(heatmap, 0)

            peak = np.max(heatmap)
            if peak != 0:
                heatmap /= peak

            return heatmap
        finally:
            hook_b.remove()
            hook_f.remove()
|
|
|