File size: 13,948 Bytes
d23039a
 
 
 
3909c31
d23039a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b18758
 
 
 
 
 
 
 
 
 
d23039a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b18758
d23039a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b18758
 
 
 
 
 
 
d23039a
 
 
 
 
 
 
 
 
 
 
fba30db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3909c31
 
fba30db
3909c31
 
fba30db
 
 
3909c31
 
 
 
 
 
 
 
 
fba30db
3909c31
 
 
 
 
 
fba30db
3909c31
fba30db
3909c31
 
 
 
 
 
 
 
 
fba30db
3909c31
 
 
fba30db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3909c31
 
07ff735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d23039a
 
 
07ff735
3909c31
fba30db
3909c31
fba30db
3909c31
07ff735
 
 
 
 
 
 
 
 
 
 
 
 
 
3909c31
fba30db
 
3909c31
fba30db
 
 
 
 
 
 
 
 
3909c31
 
 
 
fba30db
3909c31
fba30db
3909c31
e126c62
 
 
 
 
 
 
 
 
d23039a
 
 
 
 
 
 
 
fba30db
 
 
 
d23039a
fba30db
d23039a
fba30db
 
 
 
d23039a
fba30db
d23039a
 
 
 
 
 
 
 
fba30db
 
 
 
d23039a
 
fba30db
d23039a
 
fba30db
d23039a
fba30db
d23039a
fba30db
d23039a
fba30db
d23039a
fba30db
 
d23039a
fba30db
d23039a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
from __future__ import annotations

import base64
import io
from typing import Literal, Optional

import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from config import settings
from models.model_loader import get_model_loader


class _HFLogitsWrapper(torch.nn.Module):
    """Adapter that exposes only the ``logits`` of a HuggingFace classifier.

    pytorch_grad_cam calls ``model(x)`` and works on the returned tensor, whereas
    HuggingFace models return an output object; this module unwraps it.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        # Stored as an attribute so it is registered as a child module.
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        hf_output = self.model(pixel_values=pixel_values)
        return hf_output.logits


def _vit_reshape_transform(tensor: torch.Tensor, height: int = 14, width: int = 14) -> torch.Tensor:
    """Turn ViT token activations into an image-like feature map for Grad-CAM.

    Input is ``(B, 1 + height*width, C)`` with the CLS token first; the CLS
    token is discarded and the remaining patch tokens are laid out row-major
    on a ``height x width`` grid. Output is ``(B, C, height, width)``.
    """
    patch_tokens = tensor[:, 1:, :]  # everything except the CLS token
    batch, _, channels = patch_tokens.shape
    grid = patch_tokens.reshape(batch, height, width, channels)
    return grid.permute(0, 3, 1, 2)  # channels-first for Grad-CAM


def _find_class_index(model: torch.nn.Module, label_tokens: tuple[str, ...]) -> Optional[int]:
    """Return the first class index whose label contains any of ``label_tokens``.

    Labels come from ``model.config.id2label`` and are lower-cased before a
    plain substring test (the tokens themselves are used as given). Keys are
    returned as ``int``. Returns ``None`` when the model has no config, no
    ``id2label`` mapping, or no matching label.
    """
    config = getattr(model, "config", None)
    id2label: dict[int, str] = getattr(config, "id2label", {}) or {}
    hits = (
        int(idx)
        for idx, label in id2label.items()
        if any(token in str(label).lower() for token in label_tokens)
    )
    return next(hits, None)


def _preprocess_for_cam(pil_img: Image.Image, processor) -> tuple[torch.Tensor, np.ndarray]:
    """Prepare an image for ViT Grad-CAM.

    The image is converted to RGB first, so palette, greyscale and RGBA inputs
    produce a 3-channel model input and a 3-channel overlay array. Previously
    those modes reached the processor unconverted and produced a 2-D or
    4-channel ``rgb`` array.

    Args:
        pil_img: Image to explain (any PIL mode).
        processor: HuggingFace image processor returning ``pixel_values``.

    Returns:
        ``(input_tensor, rgb_float)``:
          * ``input_tensor`` is the processor's ``pixel_values`` moved to ``settings.DEVICE``.
          * ``rgb_float`` is the image resized to the spatial size of ``input_tensor``,
            as a float32 ``(H, W, 3)`` array in [0, 1], for overlaying.
    """
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    inputs = processor(images=pil_img, return_tensors="pt")
    input_tensor = inputs["pixel_values"].to(settings.DEVICE)

    # Take the overlay geometry from the tensor the model actually sees instead
    # of processor.size, whose schema differs between processors (e.g.
    # {"height", "width"} vs {"shortest_edge"}) and previously fell back to 224.
    # Assumes channels-first (B, C, H, W) pixel_values, the HF default.
    h, w = (int(d) for d in input_tensor.shape[-2:])

    resized = pil_img.resize((w, h), Image.BILINEAR)
    rgb = np.array(resized).astype(np.float32) / 255.0  # (H, W, 3) in [0, 1]
    return input_tensor, rgb


def _encode_overlay_to_base64(overlay: np.ndarray) -> str:
    """Serialise a uint8 image array (RGB or RGBA) as a PNG ``data:`` URL.

    The PIL mode is inferred from the array by ``Image.fromarray``.
    """
    with io.BytesIO() as stream:
        Image.fromarray(overlay).save(stream, format="PNG")
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return "data:image/png;base64," + encoded


def _compute_gradcam_pp(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Run Grad-CAM++ on the HuggingFace ViT image model.

    The cam is taken from the ``layernorm_before`` modules of the last (up to)
    three encoder blocks. pytorch_grad_cam combines the per-layer maps into one,
    which gives a smoother heatmap than a single layer.

    Args:
        pil_img: Image to explain.
        target_class_idx: Output index to explain. When ``None``, the first
            ``id2label`` entry containing a fake-like token is used. If nothing
            matches, ``targets=None`` is handed to Grad-CAM++.

    Returns:
        ``(grayscale_cam, rgb_float)``:
          * ``grayscale_cam`` is an ``(H, W)`` map in [0, 1] at model-input resolution.
          * ``rgb_float`` is the resized float RGB image from _preprocess_for_cam.

    Note:
        Puts the shared model in eval mode and leaves ``requires_grad=True``
        on all of its parameters.
    """
    model, processor = get_model_loader().load_image_model()

    model.eval()
    for weight in model.parameters():
        weight.requires_grad_(True)  # Grad-CAM needs a backward pass

    input_tensor, rgb_float = _preprocess_for_cam(pil_img, processor)

    cfg = model.config
    patches_per_side = int(cfg.image_size / cfg.patch_size)

    encoder_blocks = model.vit.encoder.layer
    depth = min(3, len(encoder_blocks))
    # Deepest block first: layer[-1], layer[-2], layer[-3]
    target_layers = [encoder_blocks[-k].layernorm_before for k in range(1, depth + 1)]

    wrapped = _HFLogitsWrapper(model)

    class_idx = target_class_idx
    if class_idx is None:
        class_idx = _find_class_index(
            model,
            ("fake", "deepfake", "manipulated", "ai", "generated", "synthetic"),
        )
    targets = None if class_idx is None else [ClassifierOutputTarget(int(class_idx))]

    def _to_spatial(t: torch.Tensor) -> torch.Tensor:
        # Token sequence -> (B, C, grid, grid) feature map
        return _vit_reshape_transform(t, patches_per_side, patches_per_side)

    with GradCAMPlusPlus(
        model=wrapped,
        target_layers=target_layers,
        reshape_transform=_to_spatial,
    ) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    return grayscale_cam, rgb_float


def _face_bbox_from_detections(frame_data: dict, orig_h: int, orig_w: int) -> Optional[tuple[int,int,int,int]]:
    """Return the first detection in ``frame_data["detections"]`` as a pixel box.

    The first four values of the first detection are read as
    ``(ymin, xmin, ymax, xmax)``. They are clamped to ``[0, orig_h]`` /
    ``[0, orig_w]`` and truncated to ``int``. Returns ``None`` when there are
    no detections or the clamped box is empty.
    """
    detections = frame_data.get("detections", [])
    if not len(detections):
        return None
    best = detections[0]  # presumably the highest-confidence face; verify extractor ordering
    ymin, xmin = int(max(0, best[0])), int(max(0, best[1]))
    ymax, xmax = int(min(orig_h, best[2])), int(min(orig_w, best[3]))
    is_empty = ymax <= ymin or xmax <= xmin
    return None if is_empty else (ymin, xmin, ymax, xmax)


def _compute_gradcam_pp_efficientnet(
    pil_img: Image.Image,
) -> tuple[np.ndarray, Optional[tuple[int,int,int,int]], Literal["attention", "gradcam++"]]:
    """Grad-CAM++ for EfficientNetAutoAttB4 on the first detected face crop.

    Returns:
        ``(grayscale_cam, face_bbox, heatmap_source)``:
          * ``grayscale_cam`` lives in the coordinate space of the face-crop tensor.
          * ``face_bbox`` is ``(ymin, xmin, ymax, xmax)`` in original image pixels, or ``None``.
          * ``heatmap_source`` is always ``"gradcam++"``, including when Grad-CAM++ fails
            and a uniform 224x224 map of 0.5 is substituted.

    Raises:
        RuntimeError: the EfficientNet model is not loaded.
        ValueError: ``"no_face"`` when the face extractor returns no faces.
    """
    eff = get_model_loader().load_efficientnet()
    if eff is None:
        raise RuntimeError("EfficientNet not loaded")

    rgb_img = pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")
    frame = np.array(rgb_img)
    height, width = frame.shape[:2]

    frame_data = eff.face_extractor.process_image(img=frame)
    faces: list = frame_data.get("faces", [])
    if not faces:
        raise ValueError("no_face")

    face_bbox = _face_bbox_from_detections(frame_data, height, width)
    face_batch = eff._face_tensor(faces[0]).unsqueeze(0).to(eff.device)

    try:
        net = eff.net
        last_block = [net.efficientnet._blocks[-1]]
        face_batch.requires_grad_(True)
        for weight in net.parameters():
            weight.requires_grad_(True)
        with GradCAMPlusPlus(model=net, target_layers=last_block) as cam:
            heat = cam(input_tensor=face_batch, targets=None)[0]
    except Exception as e:
        # Deliberate best-effort: keep the response alive with a flat map.
        logger.warning(f"EfficientNet Grad-CAM++ failed ({e}), using uniform fallback")
        heat = 0.5 * np.ones((224, 224), dtype=np.float32)

    return heat, face_bbox, "gradcam++"


def _cam_to_full_image(
    grayscale_cam: np.ndarray,
    pil_img: Image.Image,
    face_bbox: Optional[tuple[int,int,int,int]] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Project a cam onto the original image's pixel grid.

    * With ``face_bbox`` (face-crop cam): the cam is bilinearly resized to the
      box and pasted at ``(ymin:ymax, xmin:xmax)``; everything else is 0.
    * Without it (full-image cam): the cam is bilinearly resized to the whole image.

    Returns:
        ``(cam_full, orig_np)``:
          * ``cam_full`` is an ``(H, W)`` float32 map at original resolution.
          * ``orig_np`` is the original image as an ``(H, W, 3)`` float32 array in [0, 1].
    """
    width, height = pil_img.size
    orig_np = np.array(pil_img.convert("RGB")).astype(np.float32) / 255.0

    if face_bbox is None:
        full = cv2.resize(grayscale_cam, (width, height), interpolation=cv2.INTER_LINEAR)
        return full, orig_np

    top, left, bottom, right = face_bbox
    full = np.zeros((height, width), dtype=np.float32)
    # cv2.resize takes (width, height)
    full[top:bottom, left:right] = cv2.resize(
        grayscale_cam, (right - left, bottom - top), interpolation=cv2.INTER_LINEAR
    )
    return full, orig_np


class _NegatedLogitWrapper(torch.nn.Module):
    """Return the negated output of the wrapped model.

    The DenseNet face model's logit is treated as "real" evidence (higher =
    more real). Negating it makes Grad-CAM++ explain the fake direction.
    """

    def __init__(self, m: torch.nn.Module) -> None:
        super().__init__()
        self.m = m

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return -self.m(x)


def _compute_gradcam_pp_densenet(
    pil_img: Image.Image,
) -> tuple[np.ndarray, str]:
    """Grad-CAM++ on the DenseNet121 face-GAN model, explaining the fake score.

    The explained signal is the negated logit, so fake probability is
    sigmoid(-logit). The target layer is ``features.norm5``, the final
    BatchNorm after the last DenseBlock (a 7x7x1024 map for DenseNet-121 at
    224 px input).

    Returns:
        ``(grayscale_cam, "gradcam++_densenet")``.

    Raises:
        RuntimeError: the DenseNet model is unavailable.
    """
    loaded = get_model_loader().load_densenet()
    if loaded is None:
        raise RuntimeError("DenseNet model unavailable")
    model, meta = loaded

    from services.densenet_service import _preprocess

    input_size = int(meta.get("image_size", 224))
    input_tensor = _preprocess(pil_img, input_size, settings.DEVICE)

    model.eval()
    for weight in model.parameters():
        weight.requires_grad_(True)

    # Equivalent to conv5_block16_concat in the Keras DenseNet121.
    target_layers = [model.features.norm5]

    with GradCAMPlusPlus(model=_NegatedLogitWrapper(model), target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0]  # (H,W) in [0,1]

    return grayscale_cam, "gradcam++_densenet"


def _vit_cam_full_res(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """ViT Grad-CAM++ resized to the original image size; returns (cam_full, orig_np)."""
    grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)
    return _cam_to_full_image(grayscale_cam, pil_img, None)


def generate_heatmap_base64(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
    model_family: Literal["vit", "efficientnet", "densenet"] = "vit",
) -> tuple[str, str]:
    """Produce a Grad-CAM++ overlay as a transparent RGBA PNG data URL at original resolution.

    Args:
        pil_img: Image to explain.
        target_class_idx: Class index for the ViT cam (also used by the ViT fallbacks).
        model_family: ``"densenet"`` and ``"efficientnet"`` use their own models and
            fall back to ViT. Any other value uses ViT directly.

    Returns:
        ``(base64_png, heatmap_source)``. On success ``heatmap_source`` is one of:
          * the tag reported by the family-specific helper;
          * ``"gradcam++"`` for ViT;
          * ``"vit_fallback"``.
        On failure the PNG is ``""`` and the source is ``"none"`` (the fallback
        also failed) or ``"fallback"`` (EfficientNet failed for a reason other
        than ``ValueError``).

    Raises:
        Any exception from the ViT path: that branch has no fallback.
    """
    if model_family == "densenet":
        try:
            grayscale_cam, source = _compute_gradcam_pp_densenet(pil_img)
            cam_full, _ = _cam_to_full_image(grayscale_cam, pil_img, None)
        except Exception as e:
            logger.warning(f"DenseNet heatmap failed ({e}) — falling back to ViT Grad-CAM++")
            try:
                cam_full, _ = _vit_cam_full_res(pil_img, target_class_idx)
                source = "vit_fallback"
            except Exception as fe:
                logger.warning(f"ViT fallback heatmap also failed: {fe}")
                return "", "none"
    elif model_family == "efficientnet":
        try:
            grayscale_cam, face_bbox, source = _compute_gradcam_pp_efficientnet(pil_img)
            cam_full, _ = _cam_to_full_image(grayscale_cam, pil_img, face_bbox)
        except ValueError:
            # _compute_gradcam_pp_efficientnet signals "no face" with ValueError("no_face").
            # NOTE(review): any other ValueError raised in the try also lands here.
            logger.info("EfficientNet heatmap: no face detected — falling back to ViT Grad-CAM++")
            try:
                cam_full, _ = _vit_cam_full_res(pil_img, target_class_idx)
                source = "vit_fallback"
            except Exception as fe:
                logger.warning(f"ViT fallback heatmap also failed: {fe}")
                return "", "none"
        except Exception as e:
            logger.warning(f"EfficientNet heatmap failed: {e}")
            return "", "fallback"
    else:
        cam_full, _ = _vit_cam_full_res(pil_img, target_class_idx)
        source = "gradcam++"

    # Clamp before scaling: numpy's float -> uint8 cast does not saturate, so an
    # out-of-range cam value would map to a wrong colour rather than the colormap's end.
    cam_unit = np.clip(cam_full, 0.0, 1.0)
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * cam_unit), cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Transparent RGBA overlay so CSS can blend it without darkening the base image.
    # Alpha rises with activation and saturates at cam ~= 0.56 (255 / (1.8 * 255)).
    alpha = np.clip(cam_unit * 1.8 * 255, 0, 255).astype(np.uint8)
    overlay_rgba = np.dstack((heatmap_colored, alpha))

    logger.info(f"Heatmap generated ({overlay_rgba.shape[1]}x{overlay_rgba.shape[0]}) source={source}")
    return _encode_overlay_to_base64(overlay_rgba), source


def generate_boxes_base64(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
    top_k: int = 5,
    threshold: float = 0.4,
) -> str:
    """Draw Grad-CAM++ activation bounding boxes on the full original image.

    The ViT cam (which covers the whole image) is bilinearly resized to the
    original resolution and thresholded. The ``top_k`` largest external
    contours are outlined with their bounding rectangles. Each box is coloured
    by the mean cam value inside it (red >= 0.7, orange >= 0.5, yellow
    otherwise) and labelled with that mean as a percentage.

    Args:
        pil_img: Source image; annotated in RGB at its original size.
        target_class_idx: Class to explain; ``None`` defers to the default
            target selection of _compute_gradcam_pp.
        top_k: Maximum number of regions drawn (largest contour area first).
        threshold: Cam value at or above which a pixel counts as active.

    Returns:
        A PNG data URL of the annotated image. If no pixel reaches
        ``threshold``, the unannotated image is returned.
    """
    grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)

    # Use original image as the canvas — resize cam to match
    orig_w, orig_h = pil_img.size
    base_img = np.array(pil_img.convert("RGB")).copy()
    cam_full = cv2.resize(grayscale_cam, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    binary = (cam_full >= threshold).astype(np.uint8) * 255
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.4/4.x return
    # (contours, hierarchy); [-2] selects the contours in every version.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    if not contours:
        logger.info("No significant activation regions found for bounding boxes")
        return _encode_overlay_to_base64(base_img)

    contours = sorted(contours, key=cv2.contourArea, reverse=True)[:top_k]

    # Scale stroke width and text size with the image width
    line_w = max(2, orig_w // 300)
    font_scale = max(0.5, orig_w / 1200)

    for cnt in contours:
        x, y, bw, bh = cv2.boundingRect(cnt)
        region_activation = cam_full[y:y + bh, x:x + bw].mean()

        if region_activation >= 0.7:
            color = (220, 40, 40)
        elif region_activation >= 0.5:
            color = (240, 140, 20)
        else:
            color = (230, 200, 40)

        cv2.rectangle(base_img, (x, y), (x + bw, y + bh), color, line_w)
        label = f"{region_activation * 100:.0f}%"
        # Keep the label inside the image when the box touches the top edge
        cv2.putText(base_img, label, (x, max(y - 6, 14)),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_w, cv2.LINE_AA)

    logger.info(f"Bounding boxes generated: {len(contours)} regions on {orig_w}x{orig_h} image")
    return _encode_overlay_to_base64(base_img)