from __future__ import annotations

import base64
import io
from typing import Literal, Optional

import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from config import settings
from models.model_loader import get_model_loader

# Substrings that mark a classifier label as the "fake" class.
_FAKE_LABEL_TOKENS: tuple[str, ...] = (
    "fake",
    "deepfake",
    "manipulated",
    "ai",
    "generated",
    "synthetic",
)

# Fallback side length used when the processor does not expose a size dict.
_DEFAULT_INPUT_SIDE = 224

# Alpha gain applied to the CAM so mid-strength activations stay visible.
_ALPHA_GAIN = 1.8

# (minimum mean activation, RGB colour) pairs, checked from strongest to weakest.
_BOX_COLOR_BANDS: tuple[tuple[float, tuple[int, int, int]], ...] = (
    (0.7, (220, 40, 40)),
    (0.5, (240, 140, 20)),
)
_BOX_COLOR_DEFAULT: tuple[int, int, int] = (230, 200, 40)


class _HFLogitsWrapper(torch.nn.Module):
    """Adapt a HuggingFace image classifier so ``forward`` returns raw logits.

    pytorch_grad_cam expects tensor outputs, not dicts/dataclasses, so this
    wrapper unwraps the ``.logits`` field of the model output.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return self.model(pixel_values=pixel_values).logits


class _NegatedLogitWrapper(torch.nn.Module):
    """Wrap a model so ``forward`` returns the negated output of the inner model.

    Used for the DenseNet path, where the raw output is treated as a
    "real" logit and the negation points gradients at fake evidence.
    """

    def __init__(self, m: torch.nn.Module) -> None:
        super().__init__()
        self.m = m

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return -self.m(x)  # negative logit -> gradient points at fake evidence


def _enable_param_grads(module: torch.nn.Module) -> None:
    """Turn on ``requires_grad`` for every parameter of ``module``.

    The flags are intentionally left enabled afterwards: the operation is
    idempotent, whereas restoring the previous state could disable gradients
    underneath another Grad-CAM pass running on the same shared model.
    """
    for param in module.parameters():
        param.requires_grad_(True)


def _vit_reshape_transform(tensor: torch.Tensor, height: int = 14, width: int = 14) -> torch.Tensor:
    """Turn ViT token activations into a spatial feature map.

    Grad-CAM expects (B, C, H, W); ViT hidden states are (B, 1+H*W, C).
    The first (CLS) token is dropped and the remaining tokens are laid out on
    a ``height`` x ``width`` grid.
    """
    patch_tokens = tensor[:, 1:, :]
    batch, _, channels = patch_tokens.shape
    grid = patch_tokens.reshape(batch, height, width, channels)
    return grid.permute(0, 3, 1, 2)  # (B, C, H, W)


def _find_class_index(model: torch.nn.Module, label_tokens: tuple[str, ...]) -> Optional[int]:
    """Find the first class index whose label contains one of ``label_tokens``.

    Matching is a case-insensitive substring test against ``model.config.id2label``.
    Returns ``None`` when the model has no label map or nothing matches.
    NOTE(review): short tokens such as "ai" also match inside unrelated words
    (e.g. "hair"); confirm this is acceptable for the deployed label sets.
    """
    id2label: dict[int, str] = getattr(getattr(model, "config", None), "id2label", {}) or {}
    for idx, label in id2label.items():
        lowered = str(label).lower()
        if any(token in lowered for token in label_tokens):
            return int(idx)
    return None


def _preprocess_for_cam(pil_img: Image.Image, processor) -> tuple[torch.Tensor, np.ndarray]:
    """Build the model input tensor and a matching RGB array for overlaying.

    Returns ``(input_tensor, rgb_float)`` where ``input_tensor`` is the
    processor's ``pixel_values`` moved to ``settings.DEVICE`` and ``rgb_float``
    is an (H, W, 3) float32 array in [0, 1] resized to the processor's
    ``size`` (224x224 when the processor exposes no height/width).
    """
    # Convert up front so the overlay array is always 3-channel, as documented;
    # RGBA / L / P inputs would otherwise yield (H,W,4) or (H,W) arrays.
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    inputs = processor(images=pil_img, return_tensors="pt")
    input_tensor = inputs["pixel_values"].to(settings.DEVICE)

    size = getattr(processor, "size", None)
    if isinstance(size, dict):
        h = size.get("height", _DEFAULT_INPUT_SIDE)
        w = size.get("width", _DEFAULT_INPUT_SIDE)
    else:
        h = w = _DEFAULT_INPUT_SIDE

    resized = pil_img.resize((w, h), Image.BILINEAR)
    rgb = np.array(resized).astype(np.float32) / 255.0  # (H,W,3) in [0,1]
    return input_tensor, rgb


def _encode_overlay_to_base64(overlay: np.ndarray) -> str:
    """Encode a uint8 RGB/RGBA overlay to a base64 data-URL PNG."""
    buf = io.BytesIO()
    Image.fromarray(overlay).save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{b64}"


def _compute_gradcam_pp(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute Grad-CAM++ averaged across the last 3 ViT encoder layers.

    When ``target_class_idx`` is ``None`` the "fake"-like class is looked up
    from the model's label map; if none is found, ``targets=None`` is passed
    and pytorch_grad_cam picks the target itself.

    Returns ``(grayscale_cam, rgb_float)`` where ``grayscale_cam`` is (H,W)
    in [0,1] and ``rgb_float`` is the resized input from ``_preprocess_for_cam``.
    """
    loader = get_model_loader()
    model, processor = loader.load_image_model()
    model.eval()
    _enable_param_grads(model)

    input_tensor, rgb_float = _preprocess_for_cam(pil_img, processor)
    grid = model.config.image_size // model.config.patch_size

    # Average across last 3 ViT encoder layers for smoother heatmaps
    num_layers = len(model.vit.encoder.layer)
    last_n = min(3, num_layers)
    target_layers = [
        model.vit.encoder.layer[-(i + 1)].layernorm_before for i in range(last_n)
    ]

    wrapped = _HFLogitsWrapper(model)

    if target_class_idx is None:
        target_class_idx = _find_class_index(model, _FAKE_LABEL_TOKENS)
    targets = (
        [ClassifierOutputTarget(int(target_class_idx))]
        if target_class_idx is not None
        else None
    )

    with GradCAMPlusPlus(
        model=wrapped,
        target_layers=target_layers,
        reshape_transform=lambda t: _vit_reshape_transform(t, grid, grid),
    ) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]  # (H,W) in [0,1]

    return grayscale_cam, rgb_float


def _face_bbox_from_detections(
    frame_data: dict, orig_h: int, orig_w: int
) -> Optional[tuple[int, int, int, int]]:
    """Extract (ymin, xmin, ymax, xmax) in pixel coords from BlazeFace frame_data.

    Only the first detection is used. Coordinates are clamped to the image
    bounds; ``None`` is returned when there is no detection or the clamped
    box is empty.
    """
    detections = frame_data.get("detections")
    # ``len`` (not truthiness) because detections may be an array-like.
    if detections is None or len(detections) == 0:
        return None

    d = detections[0]  # first (highest-confidence) face
    ymin = int(max(0, d[0]))
    xmin = int(max(0, d[1]))
    ymax = int(min(orig_h, d[2]))
    xmax = int(min(orig_w, d[3]))
    if ymax <= ymin or xmax <= xmin:
        return None
    return ymin, xmin, ymax, xmax


def _compute_gradcam_pp_efficientnet(
    pil_img: Image.Image,
) -> tuple[np.ndarray, Optional[tuple[int, int, int, int]], Literal["attention", "gradcam++"]]:
    """Grad-CAM++ for EfficientNetAutoAttB4.

    Returns (grayscale_cam_224, face_bbox_pixels_or_None, heatmap_source).
    grayscale_cam_224 is in the 224x224 coordinate space of the face crop.
    face_bbox_pixels is (ymin, xmin, ymax, xmax) in original image pixels.

    Raises:
        RuntimeError: the EfficientNet model is not loaded.
        ValueError: ``"no_face"`` when the face extractor returns no faces.
    """
    loader = get_model_loader()
    eff = loader.load_efficientnet()
    if eff is None:
        raise RuntimeError("EfficientNet not loaded")

    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    img_np = np.array(pil_img)
    orig_h, orig_w = img_np.shape[:2]

    frame_data = eff.face_extractor.process_image(img=img_np)
    faces: list = frame_data.get("faces", [])
    if not faces:
        raise ValueError("no_face")

    face_bbox = _face_bbox_from_detections(frame_data, orig_h, orig_w)
    face_t = eff._face_tensor(faces[0]).unsqueeze(0).to(eff.device)

    try:
        net = eff.net
        target_layers = [net.efficientnet._blocks[-1]]
        face_t.requires_grad_(True)
        _enable_param_grads(net)
        with GradCAMPlusPlus(model=net, target_layers=target_layers) as cam:
            grayscale_cam = cam(input_tensor=face_t, targets=None)[0]
        return grayscale_cam, face_bbox, "gradcam++"
    except Exception as e:
        # Best-effort: a flat mid-intensity map keeps the pipeline alive.
        logger.warning(f"EfficientNet Grad-CAM++ failed ({e}), using uniform fallback")
        grayscale_cam = np.full((224, 224), 0.5, dtype=np.float32)
        # NOTE(review): the uniform fallback is still tagged "gradcam++", which
        # mislabels it for consumers; kept as-is because downstream handling of
        # other tags is not visible here.
        return grayscale_cam, face_bbox, "gradcam++"


def _cam_to_full_image(
    grayscale_cam: np.ndarray,
    pil_img: Image.Image,
    face_bbox: Optional[tuple[int, int, int, int]] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Resize grayscale_cam to the original image dimensions.

    For EfficientNet (face-crop cam + known bbox): places the cam activation
    at the face location; background activation is 0.
    For ViT (full-image cam): bilinear resize to original dims.

    Returns ``(cam_full, orig_np)``: ``cam_full`` is an [H,W] float32 map and
    ``orig_np`` is the original image as [H,W,3] float32 in [0,1].
    """
    orig_w, orig_h = pil_img.size
    orig_np = np.array(pil_img.convert("RGB")).astype(np.float32) / 255.0

    if face_bbox is None:
        cam_full = cv2.resize(grayscale_cam, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
        return cam_full, orig_np

    ymin, xmin, ymax, xmax = face_bbox
    face_h, face_w = ymax - ymin, xmax - xmin
    cam_full = np.zeros((orig_h, orig_w), dtype=np.float32)
    cam_resized = cv2.resize(grayscale_cam, (face_w, face_h), interpolation=cv2.INTER_LINEAR)
    cam_full[ymin:ymax, xmin:xmax] = cam_resized
    return cam_full, orig_np


def _compute_gradcam_pp_densenet(
    pil_img: Image.Image,
) -> tuple[np.ndarray, str]:
    """Grad-CAM++ on the DenseNet121 face-GAN model.

    Target signal = fake probability = sigmoid(-logit), so we maximise the
    negated logit. Target layer = features.norm5 (final BN after last
    DenseBlock, 7x7x1024 activation map).

    Returns (grayscale_cam, source_tag).

    Raises:
        RuntimeError: the DenseNet model is unavailable.
    """
    loader = get_model_loader()
    result = loader.load_densenet()
    if result is None:
        raise RuntimeError("DenseNet model unavailable")

    model, meta = result

    # Imported lazily, as in the original code (presumably to avoid an import cycle).
    from services.densenet_service import _preprocess

    image_size = int(meta.get("image_size", 224))
    input_tensor = _preprocess(pil_img, image_size, settings.DEVICE)

    model.eval()
    _enable_param_grads(model)

    # Target = last BN after all DenseBlocks (equivalent to conv5_block16_concat in Keras)
    target_layers = [model.features.norm5]

    # Negate logit so Grad-CAM gradients flow toward the FAKE class
    # (model output = real_probability logit; higher = more real)
    wrapped = _NegatedLogitWrapper(model)

    with GradCAMPlusPlus(model=wrapped, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0]  # (H,W) in [0,1]

    return grayscale_cam, "gradcam++_densenet"


def _vit_fallback_cam(
    pil_img: Image.Image, target_class_idx: Optional[int]
) -> Optional[np.ndarray]:
    """Compute the full-resolution ViT CAM as a fallback; ``None`` if it fails."""
    try:
        grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)
        cam_full, _ = _cam_to_full_image(grayscale_cam, pil_img, None)
        return cam_full
    except Exception as fe:
        logger.warning(f"ViT fallback heatmap also failed: {fe}")
        return None


def _render_rgba_overlay(cam_full: np.ndarray) -> np.ndarray:
    """Colour-map a CAM into a uint8 RGBA overlay whose alpha follows activation."""
    # Sanitise before the uint8 cast: NaN / out-of-range values would wrap around.
    cam = np.clip(np.nan_to_num(cam_full, nan=0.0), 0.0, 1.0)
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    alpha = np.clip(cam * _ALPHA_GAIN * 255, 0, 255).astype(np.uint8)
    return np.dstack((heatmap_colored, alpha))


def generate_heatmap_base64(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
    model_family: Literal["vit", "efficientnet", "densenet"] = "vit",
) -> tuple[str, str]:
    """Produce a base64 data-URL PNG of the Grad-CAM++ overlay at original
    image resolution.

    Returns (base64_png, heatmap_source). On failure in the DenseNet or
    EfficientNet paths the image string is empty and the source is ``"none"``
    (ViT fallback also failed) or ``"fallback"`` (EfficientNet failed for a
    reason other than ``ValueError``). Errors in the plain ViT path propagate
    to the caller.
    """
    if model_family == "densenet":
        try:
            grayscale_cam, source = _compute_gradcam_pp_densenet(pil_img)
            cam_full, _ = _cam_to_full_image(grayscale_cam, pil_img, None)
        except Exception as e:
            logger.warning(f"DenseNet heatmap failed ({e}) — falling back to ViT Grad-CAM++")
            cam_full = _vit_fallback_cam(pil_img, target_class_idx)
            if cam_full is None:
                return "", "none"
            source = "vit_fallback"
    elif model_family == "efficientnet":
        try:
            grayscale_cam, face_bbox, source = _compute_gradcam_pp_efficientnet(pil_img)
            cam_full, _ = _cam_to_full_image(grayscale_cam, pil_img, face_bbox)
        except ValueError:
            # BlazeFace found no face — fall back to ViT Grad-CAM on the full image.
            # NOTE(review): any ValueError (not only "no_face") takes this path.
            logger.info("EfficientNet heatmap: no face detected — falling back to ViT Grad-CAM++")
            cam_full = _vit_fallback_cam(pil_img, target_class_idx)
            if cam_full is None:
                return "", "none"
            source = "vit_fallback"
        except Exception as e:
            logger.warning(f"EfficientNet heatmap failed: {e}")
            return "", "fallback"
    else:
        grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)
        source = "gradcam++"
        cam_full, _ = _cam_to_full_image(grayscale_cam, pil_img, None)

    # Generate transparent RGBA overlay so CSS can blend it without darkening the base image
    overlay_rgba = _render_rgba_overlay(cam_full)

    logger.info(f"Heatmap generated ({overlay_rgba.shape[1]}x{overlay_rgba.shape[0]}) source={source}")
    return _encode_overlay_to_base64(overlay_rgba), source


def _box_color(activation: float) -> tuple[int, int, int]:
    """Pick the RGB box colour for a region's mean activation."""
    for minimum, color in _BOX_COLOR_BANDS:
        if activation >= minimum:
            return color
    return _BOX_COLOR_DEFAULT


def generate_boxes_base64(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
    top_k: int = 5,
    threshold: float = 0.4,
) -> str:
    """Draw Grad-CAM++ activation bounding boxes on the full original image.

    Uses the ViT cam (full-image coverage), resizes it to original dimensions,
    finds contours of the region where activation >= ``threshold``, and draws
    boxes for the ``top_k`` largest contours, each labelled with its mean
    activation as a percentage.

    Returns a base64 data-URL PNG; the unmodified image when no region
    reaches the threshold.
    """
    grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)

    # Use original image as the canvas — resize cam to match
    orig_w, orig_h = pil_img.size
    base_img = np.array(pil_img.convert("RGB"))  # np.array already yields a writable copy
    cam_full = cv2.resize(grayscale_cam, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    binary = (cam_full >= threshold).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        logger.info("No significant activation regions found for bounding boxes")
        return _encode_overlay_to_base64(base_img)

    # Clamp so a negative top_k cannot slice from the end of the list.
    contours = sorted(contours, key=cv2.contourArea, reverse=True)[: max(top_k, 0)]

    # Scale line width to image size
    line_w = max(2, orig_w // 300)
    font_scale = max(0.5, orig_w / 1200)

    for cnt in contours:
        x, y, bw, bh = cv2.boundingRect(cnt)
        region_activation = cam_full[y:y + bh, x:x + bw].mean()
        color = _box_color(region_activation)

        cv2.rectangle(base_img, (x, y), (x + bw, y + bh), color, line_w)
        label = f"{region_activation * 100:.0f}%"
        # Keep the label inside the image when the box touches the top edge.
        cv2.putText(
            base_img,
            label,
            (x, max(y - 6, 14)),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            color,
            line_w,
            cv2.LINE_AA,
        )

    logger.info(f"Bounding boxes generated: {len(contours)} regions on {orig_w}x{orig_h} image")
    return _encode_overlay_to_base64(base_img)