Harshasnade committed on
Commit
d33766a
·
verified ·
1 Parent(s): 28d680c

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. handler.py +81 -0
  2. models.py +197 -0
  3. requirements.txt +5 -0
  4. utils.py +37 -0
handler.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from io import BytesIO
3
+ from PIL import Image
4
+ import torch
5
+ import base64
6
+ import numpy as np
7
+ import cv2
8
+ import albumentations as A
9
+ from albumentations.pytorch import ToTensorV2
10
+ from safetensors.torch import load_file
11
+
12
+ # Import your model definition
13
+ from models import DeepfakeDetector
14
+
15
class EndpointHandler:
    """Inference endpoint handler for the DeepfakeDetector model.

    Loads weights from ``best_model.safetensors`` and classifies a single
    image as REAL or FAKE.
    """

    def __init__(self, path="."):
        # Run on GPU when available; all tensors below follow this device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        # Architecture only -- weights come from the safetensors file below.
        self.model = DeepfakeDetector(pretrained=False)

        # Load weights, falling back to the working directory when the
        # provided path does not contain the checkpoint.
        try:
            state_dict = load_file(f"{path}/best_model.safetensors")
        except Exception as e:
            print(f"Error loading weights from {path}: {e}")
            state_dict = load_file("best_model.safetensors")
        # strict=False tolerates missing/unexpected keys.
        # NOTE(review): this can also mask genuine architecture mismatches.
        self.model.load_state_dict(state_dict, strict=False)

        self.model.to(device)
        self.model.eval()

        # Preprocessing -- presumably mirrors the validation transform used
        # at training time (TODO confirm against the training pipeline).
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Classify one image.

        Args:
            data: Either a dict with an "inputs" key or the raw input itself.
                Accepted inputs: a PIL image, a base64 string (with or
                without a ``data:...;base64,`` prefix), or raw image bytes.

        Returns:
            A one-element list ``[{"label": "FAKE"|"REAL", "score": float}]``
            where ``score`` is the confidence of the predicted label, or
            ``[{"error": ...}]`` for unusable input.
        """
        # Bug fix: the original called data.pop() unconditionally, which
        # raises AttributeError for non-dict payloads (e.g. raw bytes).
        inputs = data.pop("inputs", data) if isinstance(data, dict) else data

        image = self._decode_image(inputs)
        if image is None:
            return [{"error": "Invalid input format"}]

        image = image.convert("RGB")
        image_np = np.array(image)

        # Albumentations operates on numpy HWC arrays.
        augmented = self.transform(image=image_np)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Inference -- single logit, sigmoid gives P(fake).
        with torch.no_grad():
            output = self.model(image_tensor)
            prob = torch.sigmoid(output).item()

        # Report the confidence of the chosen label, not raw P(fake).
        label = "FAKE" if prob > 0.5 else "REAL"
        score = prob if prob > 0.5 else 1 - prob

        return [{"label": label, "score": score}]

    def _decode_image(self, inputs):
        """Best-effort decode of supported input formats to a PIL image."""
        if isinstance(inputs, Image.Image):
            return inputs
        if isinstance(inputs, bytes):
            try:
                return Image.open(BytesIO(inputs))
            except OSError:
                # Not a recognizable image payload.
                return None
        if isinstance(inputs, str):
            # Strip an optional data-URI prefix, then try base64.
            try:
                if "base64," in inputs:
                    inputs = inputs.split("base64,")[1]
                image_bytes = base64.b64decode(inputs)
                return Image.open(BytesIO(image_bytes))
            # Bug fix: the original bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit; catch decode errors only
            # (binascii.Error subclasses ValueError; PIL raises OSError).
            except (ValueError, OSError):
                # NOTE(review): URL inputs are not supported yet.
                return None
        return None
models.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ import numpy as np
6
+ from src.utils import get_fft_feature
7
+
8
class RGBBranch(nn.Module):
    """Spatial (RGB) feature extractor built on EfficientNet V2 Small."""

    def __init__(self, pretrained=True):
        super().__init__()
        # EfficientNet V2 Small: robust and efficient spatial features.
        self.net = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        )
        # Keep the feature trunk and pooling; the classifier head is unused.
        self.features = self.net.features
        self.avgpool = self.net.avgpool
        # Feature width produced after global pooling.
        self.out_dim = 1280

    def forward(self, x):
        # Trunk -> global average pool -> flat 1280-d vector per sample.
        pooled = self.avgpool(self.features(x))
        return torch.flatten(pooled, 1)
24
+
25
class FreqBranch(nn.Module):
    """Small CNN that extracts features from frequency-domain input maps."""

    def __init__(self):
        super().__init__()
        # Two downsampling conv stages (3->32->64 channels), then a final
        # conv to 128 channels with a global average pool, so the module
        # emits a flat 128-d vector regardless of input resolution.
        stages = []
        in_ch = 3
        for out_ch in (32, 64):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            in_ch = out_ch
        stages += [
            nn.Conv2d(in_ch, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.net = nn.Sequential(*stages)
        self.out_dim = 128

    def forward(self, x):
        feats = self.net(x)
        return torch.flatten(feats, 1)
49
+
50
class PatchBranch(nn.Module):
    """Detects local inconsistencies by encoding 64x64 patches independently."""

    def __init__(self):
        super().__init__()
        # Lightweight CNN shared across every patch.
        self.patch_encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 64 -> 32
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32 -> 16
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.out_dim = 64

    def forward(self, x):
        # x: (B, 3, H, W); for 256x256 input this yields a 4x4 patch grid.
        # Tile into non-overlapping 64x64 patches:
        # (B, C, H/64, W/64, 64, 64).
        tiles = x.unfold(2, 64, 64).unfold(3, 64, 64)
        batch, channels, rows, cols, ph, pw = tiles.shape

        # Fold the grid into the batch dimension so all patches pass
        # through the shared encoder in one parallel forward.
        flat = (
            tiles.permute(0, 2, 3, 1, 4, 5)
            .contiguous()
            .view(batch * rows * cols, channels, ph, pw)
        )
        encoded = torch.flatten(self.patch_encoder(flat), 1)  # (B*P, 64)

        # Restore (B, n_patches, 64) and keep only the strongest per-feature
        # response across patches -- the "most fake" patch drives the signal.
        per_patch = encoded.view(batch, rows * cols, -1)
        return per_patch.max(dim=1).values
91
+
92
class ViTBranch(nn.Module):
    """Global-consistency branch based on Swin Transformer V2 Tiny."""

    def __init__(self, pretrained=True):
        super().__init__()
        # Swin V2 Tiny captures long-range dependencies across the image.
        self.net = models.swin_v2_t(
            weights=models.Swin_V2_T_Weights.DEFAULT if pretrained else None
        )
        # Record the feature width, then strip the classification head so
        # the module emits pooled features instead of class logits.
        self.out_dim = self.net.head.in_features
        self.net.head = nn.Identity()

    def forward(self, x):
        return self.net(x)
105
+
106
class DeepfakeDetector(nn.Module):
    """Four-branch deepfake classifier with a fused linear head.

    Branches: RGB (EfficientNet), frequency (FFT + CNN), patch (local
    inconsistency), and ViT (Swin, global consistency). Their pooled
    features are concatenated and mapped to a single logit; callers apply
    sigmoid to obtain P(fake).
    """

    def __init__(self, pretrained=True):
        """Build all four branches and the fusion classifier.

        Args:
            pretrained: Passed to the RGB and ViT branches to load
                torchvision pretrained backbone weights.
        """
        super().__init__()
        self.rgb_branch = RGBBranch(pretrained)
        self.freq_branch = FreqBranch()
        self.patch_branch = PatchBranch()
        self.vit_branch = ViTBranch(pretrained)

        # Fused feature width: sum of each branch's output dimension.
        input_dim = (self.rgb_branch.out_dim +
                     self.freq_branch.out_dim +
                     self.patch_branch.out_dim +
                     self.vit_branch.out_dim)

        # Confidence-based fusion head: concat -> 512 -> 1 logit.
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1)
        )

    def forward(self, x):
        """Return a (B, 1) logit tensor for a batch of RGB images.

        NOTE(review): PatchBranch unfolds 64x64 tiles, so x presumably must
        have spatial dims that are multiples of 64 (e.g. 256x256) -- confirm
        against the preprocessing pipeline.
        """
        # 1. Spatial Analysis
        rgb_feat = self.rgb_branch(x)

        # 2. Frequency Analysis -- FFT log-magnitude maps fed to a CNN.
        freq_img = get_fft_feature(x)
        freq_feat = self.freq_branch(freq_img)

        # 3. Patch Analysis (Local Inconsistencies)
        patch_feat = self.patch_branch(x)

        # 4. Global Consistency (ViT)
        vit_feat = self.vit_branch(x)

        # 5. Feature Fusion -- concatenate along the feature dimension.
        combined = torch.cat([rgb_feat, freq_feat, patch_feat, vit_feat], dim=1)

        return self.classifier(combined)

    def get_heatmap(self, x):
        """Generate a Grad-CAM heatmap for the input image.

        Uses the RGB branch's last convolutional block as the target layer,
        since it carries spatial features. Returns a 2-D numpy array in
        [0, 1] (relative to its own max).

        NOTE(review): the activations/gradients are taken from index [0],
        so only the first sample of the batch is visualized; callers
        presumably pass a batch of one -- confirm.
        """
        # We'll use the RGB branch for visualization as it contains spatial features
        # Enable gradients for the input if needed, though typically we hook into layers

        # 1. Forward pass through RGB branch
        # We need to register a hook on the last conv layer of the efficientnet features
        # Target layer: self.rgb_branch.features[-1] (the last block)

        gradients = []
        activations = []

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] is dLogit/dActivation for the target layer.
            gradients.append(grad_output[0])

        def forward_hook(module, input, output):
            activations.append(output)

        # Register hooks on the last convolutional layer of RGB branch
        target_layer = self.rgb_branch.features[-1]
        hook_b = target_layer.register_full_backward_hook(backward_hook)
        hook_f = target_layer.register_forward_hook(forward_hook)

        # Forward pass
        logits = self(x)
        # NOTE(review): pred_idx is unused -- the single-logit output makes
        # class selection unnecessary.
        pred_idx = 0  # Binary classification, output is scalar logic

        # Backward pass -- populates `gradients` via the backward hook.
        self.zero_grad()
        logits.backward(retain_graph=True)

        # Get gradients and activations
        # Channel weights: gradient averaged over batch and spatial dims.
        pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])
        # First sample's activation map, shape (C, H, W).
        activation = activations[0][0]

        # Weight activations by gradients (Grad-CAM)
        # NOTE(review): this modifies the hooked activation in place while
        # retain_graph=True holds the graph alive -- works here because no
        # further backward uses it, but fragile; verify before reuse.
        for i in range(activation.shape[0]):
            activation[i, :, :] *= pooled_gradients[i]

        # Channel-average, then ReLU: keep only positive evidence.
        heatmap = torch.mean(activation, dim=0).cpu().detach().numpy()
        heatmap = np.maximum(heatmap, 0)  # ReLU

        # Normalize to [0, 1]; guard against an all-zero map.
        if np.max(heatmap) != 0:
            heatmap /= np.max(heatmap)

        # Remove hooks
        hook_b.remove()
        hook_f.remove()

        return heatmap
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ albumentations
4
+ safetensors
5
+ opencv-python-headless
utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+
5
def get_fft_feature(x):
    """Return the centered log-magnitude spectrum of a batch of images.

    Args:
        x (torch.Tensor): Images of shape (B, C, H, W), or (C, H, W) for a
            single image (a batch dimension is added automatically).

    Returns:
        torch.Tensor: Log-magnitude spectrum of shape (B, C, H, W) with the
        zero-frequency (DC) component shifted to the spatial center.
    """
    # Promote a single image to a batch of one.
    if x.dim() == 3:
        x = x.unsqueeze(0)

    # Orthonormal 2D FFT -> magnitude -> log scale; the epsilon keeps
    # log() finite where the spectrum is exactly zero.
    spectrum = torch.abs(torch.fft.fft2(x, norm='ortho'))
    log_mag = torch.log(spectrum + 1e-6)

    # Move the DC component to the center of each spatial map.
    return torch.fft.fftshift(log_mag, dim=(-2, -1))
29
+
30
def min_max_normalize(tensor):
    """Rescale a tensor to the [0, 1] range via min-max normalization.

    A small epsilon in the denominator keeps the result finite when the
    tensor is constant (max == min).
    """
    lo, hi = tensor.min(), tensor.max()
    return (tensor - lo) / (hi - lo + 1e-8)
37
+