import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from src.utils import get_fft_feature
| |
|
class RGBBranch(nn.Module):
    """EfficientNetV2-S backbone that maps an RGB image to a 1280-d embedding."""

    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        self.net = models.efficientnet_v2_s(weights=weights)

        # Expose the backbone pieces as attributes so external callers
        # (e.g. Grad-CAM hooks) can reach the convolutional stages directly.
        self.features = self.net.features
        self.avgpool = self.net.avgpool
        self.out_dim = 1280

    def forward(self, x):
        fmap = self.features(x)
        pooled = self.avgpool(fmap)
        return pooled.flatten(1)
| |
|
class FreqBranch(nn.Module):
    """Shallow CNN that embeds a 3-channel frequency-domain image into a 128-d vector."""

    def __init__(self):
        super().__init__()
        stages = [(3, 32), (32, 64), (64, 128)]
        layers = []
        for idx, (c_in, c_out) in enumerate(stages):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            # Halve the spatial size after the first two stages;
            # collapse to 1x1 after the last one.
            if idx < len(stages) - 1:
                layers.append(nn.MaxPool2d(2))
            else:
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.net = nn.Sequential(*layers)
        self.out_dim = 128

    def forward(self, x):
        pooled = self.net(x)
        return pooled.flatten(1)
| |
|
class PatchBranch(nn.Module):
    """Encodes non-overlapping square patches of the input independently and
    max-pools the per-patch embeddings into a single 64-d vector.

    Args:
        patch_size: side length (pixels) of the square patches the image is
            cut into. Height and width should be multiples of this value;
            any remainder on the right/bottom edge is silently dropped by
            ``Tensor.unfold``. Defaults to 64, matching the original
            hard-coded behavior.
    """

    def __init__(self, patch_size=64):
        super().__init__()
        self.patch_size = patch_size

        # Lightweight per-patch encoder: 3 conv stages, global-average pooled.
        self.patch_encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.out_dim = 64

    def forward(self, x):
        """Return a (B, 64) feature: max over per-patch embeddings of ``x`` (B, 3, H, W)."""
        p = self.patch_size

        # Cut the image into a grid of non-overlapping p x p patches:
        # (B, C, H, W) -> (B, C, H_grid, W_grid, p, p)
        patches = x.unfold(2, p, p).unfold(3, p, p)
        B, C, H_grid, W_grid, H_patch, W_patch = patches.shape

        # Fold the grid dimensions into the batch dimension so every patch
        # is encoded independently by the shared encoder.
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        patches = patches.view(B * H_grid * W_grid, C, H_patch, W_patch)

        feats = self.patch_encoder(patches)
        feats = torch.flatten(feats, 1)

        # Restore the batch dimension: (B, num_patches, 64)
        feats = feats.view(B, H_grid * W_grid, -1)

        # Max over patches keeps the strongest response per feature dimension.
        feats_max, _ = torch.max(feats, dim=1)
        return feats_max
| |
|
class ViTBranch(nn.Module):
    """Swin-V2-Tiny transformer backbone used as a global feature extractor."""

    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.Swin_V2_T_Weights.DEFAULT if pretrained else None
        self.net = models.swin_v2_t(weights=weights)

        # Record the embedding width before discarding the classification
        # head, so the backbone emits raw features instead of class logits.
        self.out_dim = self.net.head.in_features
        self.net.head = nn.Identity()

    def forward(self, x):
        features = self.net(x)
        return features
| |
|
class DeepfakeDetector(nn.Module):
    """Four-branch deepfake classifier.

    Fuses RGB (EfficientNetV2-S), frequency-domain, patch-level, and
    transformer (Swin-V2-T) features, then classifies with a small MLP
    emitting a single logit per image.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.rgb_branch = RGBBranch(pretrained)
        self.freq_branch = FreqBranch()
        self.patch_branch = PatchBranch()
        self.vit_branch = ViTBranch(pretrained)

        # Fused feature width is the concatenation of all branch outputs.
        input_dim = (self.rgb_branch.out_dim +
                     self.freq_branch.out_dim +
                     self.patch_branch.out_dim +
                     self.vit_branch.out_dim)

        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        """Return a (B, 1) logit tensor for a batch of RGB images ``x``."""
        rgb_feat = self.rgb_branch(x)

        # Frequency branch consumes an FFT-derived image, not the raw input.
        freq_img = get_fft_feature(x)
        freq_feat = self.freq_branch(freq_img)

        patch_feat = self.patch_branch(x)
        vit_feat = self.vit_branch(x)

        combined = torch.cat([rgb_feat, freq_feat, patch_feat, vit_feat], dim=1)
        return self.classifier(combined)

    def get_heatmap(self, x):
        """Generate a Grad-CAM heatmap over the RGB branch's last conv stage.

        Args:
            x: input image batch; the heatmap is computed from the first
               sample's activations.

        Returns:
            A 2-D numpy array normalized to [0, 1] (all-zero if the ReLU'd
            CAM is identically zero).
        """
        gradients = []
        activations = []

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0])

        def forward_hook(module, input, output):
            activations.append(output)

        # Hook the deepest convolutional stage of the RGB backbone.
        target_layer = self.rgb_branch.features[-1]
        hook_b = target_layer.register_full_backward_hook(backward_hook)
        hook_f = target_layer.register_forward_hook(forward_hook)

        # Ensure hooks are removed even if the forward/backward pass raises,
        # otherwise they would leak into every subsequent forward call.
        try:
            logits = self(x)

            self.zero_grad()
            # Sum over the batch so backward() works for any batch size
            # (backward on a non-scalar tensor requires an explicit gradient).
            logits.sum().backward(retain_graph=True)

            # Global-average the gradients over batch and spatial dims to get
            # one importance weight per channel.
            pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])

            # First sample's activation map; clone so we don't mutate the
            # hooked tensor in place.
            activation = activations[0][0].detach().clone()

            # Weight each channel by its pooled gradient (broadcast multiply).
            activation *= pooled_gradients[:, None, None]

            heatmap = torch.mean(activation, dim=0).cpu().numpy()
            heatmap = np.maximum(heatmap, 0)

            # Normalize to [0, 1]; guard against an all-zero map.
            max_val = np.max(heatmap)
            if max_val != 0:
                heatmap /= max_val
        finally:
            hook_b.remove()
            hook_f.remove()

        return heatmap
| |
|