import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from src.utils import get_fft_feature
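

# `get_fft_feature` is implemented in src.utils and not shown here. For
# context, a minimal sketch of what it is assumed to compute: a centered
# log-magnitude 2D FFT per channel, min-max normalized per sample.
# `_fft_feature_sketch` is a hypothetical illustration, not part of the
# original module, and is unused below.
def _fft_feature_sketch(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """(B, 3, H, W) -> (B, 3, H, W) log-magnitude spectrum."""
    fft = torch.fft.fftshift(torch.fft.fft2(x, norm="ortho"), dim=(-2, -1))
    mag = torch.log(fft.abs() + eps)
    # Min-max normalize per sample so the downstream CNN sees a bounded input
    flat = mag.flatten(1)
    lo = flat.min(dim=1).values.view(-1, 1, 1, 1)
    hi = flat.max(dim=1).values.view(-1, 1, 1, 1)
    return (mag - lo) / (hi - lo + eps)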


class RGBBranch(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        # EfficientNet V2 Small: robust and efficient spatial features
        weights = models.EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        self.net = models.efficientnet_v2_s(weights=weights)
        # Extract features before the classification head
        self.features = self.net.features
        self.avgpool = self.net.avgpool
        self.out_dim = 1280

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x


class FreqBranch(nn.Module):
    def __init__(self):
        super().__init__()
        # Simple CNN to analyze frequency-domain patterns
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.out_dim = 128

    def forward(self, x):
        return torch.flatten(self.net(x), 1)


class PatchBranch(nn.Module):
    def __init__(self):
        super().__init__()
        # Analyzes local patches for inconsistencies.
        # A shared lightweight CNN encodes every patch.
        self.patch_encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 64 -> 32
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32 -> 16
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.out_dim = 64

    def forward(self, x):
        # x: (B, 3, 256, 256)
        # Create a 4x4 grid of 16 patches, each 64x64
        # (unfold with kernel_size=64, stride=64)
        patches = x.unfold(2, 64, 64).unfold(3, 64, 64)
        # patches shape: (B, 3, 4, 4, 64, 64)
        B, C, H_grid, W_grid, H_patch, W_patch = patches.shape
        # Merge batch and grid dimensions for parallel processing
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        patches = patches.view(B * H_grid * W_grid, C, H_patch, W_patch)
        # Encode
        feats = self.patch_encoder(patches)  # (B*16, 64, 1, 1)
        feats = torch.flatten(feats, 1)      # (B*16, 64)
        # Aggregate back to B
        feats = feats.view(B, H_grid * W_grid, -1)  # (B, 16, 64)
        # Max-pool over patches to capture the "most fake" patch signal
        feats_max, _ = torch.max(feats, dim=1)  # (B, 64)
        return feats_max
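

# Standalone sanity check for the unfold-based tiling above (a minimal sketch
# assuming the documented 256x256 input; `_check_patch_shapes` is a
# hypothetical helper, not part of the original module):
def _check_patch_shapes() -> None:
    x = torch.randn(2, 3, 256, 256)
    patches = x.unfold(2, 64, 64).unfold(3, 64, 64)
    assert patches.shape == (2, 3, 4, 4, 64, 64)  # 4x4 grid of 64x64 tiles
    flat = patches.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, 3, 64, 64)
    assert flat.shape == (32, 3, 64, 64)  # B * 16 patches, batched for the encoder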


class ViTBranch(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        # Swin Transformer V2 Tiny: captures long-range dependencies
        weights = models.Swin_V2_T_Weights.DEFAULT if pretrained else None
        self.net = models.swin_v2_t(weights=weights)
        # Replace the head with Identity to get pooled features
        self.out_dim = self.net.head.in_features
        self.net.head = nn.Identity()

    def forward(self, x):
        return self.net(x)


class DeepfakeDetector(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.rgb_branch = RGBBranch(pretrained)
        self.freq_branch = FreqBranch()
        self.patch_branch = PatchBranch()
        self.vit_branch = ViTBranch(pretrained)
        input_dim = (self.rgb_branch.out_dim +
                     self.freq_branch.out_dim +
                     self.patch_branch.out_dim +
                     self.vit_branch.out_dim)
        # Fusion head: an MLP over the concatenated branch features
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        # 1. Spatial analysis
        rgb_feat = self.rgb_branch(x)
        # 2. Frequency analysis
        freq_img = get_fft_feature(x)
        freq_feat = self.freq_branch(freq_img)
        # 3. Patch analysis (local inconsistencies)
        patch_feat = self.patch_branch(x)
        # 4. Global consistency (ViT)
        vit_feat = self.vit_branch(x)
        # 5. Feature fusion
        combined = torch.cat([rgb_feat, freq_feat, patch_feat, vit_feat], dim=1)
        return self.classifier(combined)

    def get_heatmap(self, x):
        """Generate a Grad-CAM heatmap for the input image.

        Uses the RGB branch for visualization, since it carries the spatial
        features. Hooks capture activations and gradients at the last
        convolutional block of the EfficientNet feature extractor.
        """
        gradients = []
        activations = []

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0])

        def forward_hook(module, input, output):
            activations.append(output)

        # Register hooks on the last convolutional block of the RGB branch
        target_layer = self.rgb_branch.features[-1]
        hook_b = target_layer.register_full_backward_hook(backward_hook)
        hook_f = target_layer.register_forward_hook(forward_hook)

        try:
            # Forward pass
            logits = self(x)
            # Backward pass: the head outputs a single logit per sample, so
            # summing gives the scalar that backward() requires
            self.zero_grad()
            logits.sum().backward()

            # Pool gradients over batch and spatial dims -> per-channel weights
            pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])
            # First sample's activations, detached so the in-place weighting
            # below does not touch the autograd graph
            activation = activations[0][0].detach()

            # Weight each activation channel by its pooled gradient (Grad-CAM)
            for i in range(activation.shape[0]):
                activation[i, :, :] *= pooled_gradients[i]

            heatmap = torch.mean(activation, dim=0).cpu().numpy()
            heatmap = np.maximum(heatmap, 0)  # ReLU
            # Normalize to [0, 1]
            if np.max(heatmap) != 0:
                heatmap /= np.max(heatmap)
        finally:
            # Remove hooks even if the forward/backward pass fails
            hook_b.remove()
            hook_f.remove()

        return heatmap
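

if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes a 256x256 input, as documented in
    # PatchBranch, and that src.utils.get_fft_feature is importable).
    model = DeepfakeDetector(pretrained=False)
    model.eval()  # BatchNorm1d in the classifier needs eval mode for batch size 1
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        logits = model(x)  # (1, 1) raw logit
    print("fake probability:", torch.sigmoid(logits).item())
    # Grad-CAM needs gradients, so no torch.no_grad() here
    heatmap = model.get_heatmap(x)
    print("heatmap shape:", heatmap.shape)  # spatial grid of the last conv block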