from __future__ import annotations

import base64
import io
from typing import Literal, Optional

import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from config import settings
from models.model_loader import get_model_loader


class _HFLogitsWrapper(torch.nn.Module):
    """Wrap a HuggingFace image classification model so forward() returns logits
    as a plain tensor (pytorch_grad_cam expects tensor outputs, not dicts/dataclasses).
    """

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return self.model(pixel_values=pixel_values).logits


def _vit_reshape_transform(tensor: torch.Tensor, height: int = 14, width: int = 14) -> torch.Tensor:
    """Grad-CAM expects (B, C, H, W); ViT hidden states are (B, 1+H*W, C).

    Drop the CLS token and reshape tokens into a spatial grid.
    """
    result = tensor[:, 1:, :]  # drop CLS token
    b, n, c = result.shape
    result = result.reshape(b, height, width, c)
    result = result.permute(0, 3, 1, 2)  # (B, C, H, W)
    return result
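
# Example of the reshape, assuming a ViT-B/16 style backbone at 224x224 input
# (illustrative shapes only, not tied to a specific checkpoint):
#   hidden = torch.randn(1, 197, 768)                 # 1 CLS token + 14*14 = 196 patch tokens
#   spatial = _vit_reshape_transform(hidden, 14, 14)  # -> torch.Size([1, 768, 14, 14])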


def _find_class_index(model: torch.nn.Module, label_tokens: tuple[str, ...]) -> Optional[int]:
    """Find the first class index whose label contains one of `label_tokens`."""
    id2label: dict[int, str] = getattr(getattr(model, "config", None), "id2label", {}) or {}
    for idx, label in id2label.items():
        lowered = str(label).lower()
        if any(token in lowered for token in label_tokens):
            return int(idx)
    return None
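
# Illustrative behaviour, assuming a hypothetical config with id2label = {0: "Real", 1: "Deepfake"}:
#   _find_class_index(model, ("fake", "synthetic"))  # -> 1 ("deepfake" contains "fake")
#   _find_class_index(model, ("gan",))               # -> None (no label matches)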


def _preprocess_for_cam(pil_img: Image.Image, processor) -> tuple[torch.Tensor, np.ndarray]:
    """Return (input_tensor, rgb_float_224) where rgb_float_224 is a (H, W, 3) float
    array in [0, 1] matching the model input geometry — needed for overlaying.
    """
    inputs = processor(images=pil_img, return_tensors="pt")
    input_tensor = inputs["pixel_values"].to(settings.DEVICE)
    size = getattr(processor, "size", {"height": 224, "width": 224})
    h = size.get("height", 224) if isinstance(size, dict) else 224
    w = size.get("width", 224) if isinstance(size, dict) else 224
    resized = pil_img.resize((w, h), Image.BILINEAR)
    rgb = np.array(resized).astype(np.float32) / 255.0  # (H, W, 3) in [0, 1]
    return input_tensor, rgb


def _encode_overlay_to_base64(overlay: np.ndarray) -> str:
    """Encode a uint8 RGB/RGBA overlay to a base64 data-URL PNG."""
    buf = io.BytesIO()
    Image.fromarray(overlay).save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{b64}"


def _compute_gradcam_pp(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute Grad-CAM++ averaged across the last 3 ViT encoder layers.

    Returns (grayscale_cam, rgb_float) where grayscale_cam is (H, W) in [0, 1].
    """
    loader = get_model_loader()
    model, processor = loader.load_image_model()
    model.eval()
    for p in model.parameters():
        p.requires_grad_(True)
    input_tensor, rgb_float = _preprocess_for_cam(pil_img, processor)
    grid = model.config.image_size // model.config.patch_size

    # Average across the last 3 ViT encoder layers for smoother heatmaps
    num_layers = len(model.vit.encoder.layer)
    last_n = min(3, num_layers)
    target_layers = [
        model.vit.encoder.layer[-(i + 1)].layernorm_before
        for i in range(last_n)
    ]
    wrapped = _HFLogitsWrapper(model)

    if target_class_idx is None:
        target_class_idx = _find_class_index(
            model,
            ("fake", "deepfake", "manipulated", "ai", "generated", "synthetic"),
        )
    targets = [ClassifierOutputTarget(int(target_class_idx))] if target_class_idx is not None else None

    with GradCAMPlusPlus(
        model=wrapped,
        target_layers=target_layers,
        reshape_transform=lambda t: _vit_reshape_transform(t, grid, grid),
    ) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]  # (H, W) in [0, 1]
    return grayscale_cam, rgb_float


def _face_bbox_from_detections(frame_data: dict, orig_h: int, orig_w: int) -> Optional[tuple[int, int, int, int]]:
    """Extract (ymin, xmin, ymax, xmax) in pixel coords from BlazeFace frame_data."""
    detections = frame_data.get("detections", [])
    if len(detections) == 0:
        return None
    d = detections[0]  # first (highest-confidence) face
    ymin = int(max(0, d[0]))
    xmin = int(max(0, d[1]))
    ymax = int(min(orig_h, d[2]))
    xmax = int(min(orig_w, d[3]))
    if ymax <= ymin or xmax <= xmin:
        return None
    return ymin, xmin, ymax, xmax


def _compute_gradcam_pp_efficientnet(
    pil_img: Image.Image,
) -> tuple[np.ndarray, Optional[tuple[int, int, int, int]], Literal["gradcam++", "uniform"]]:
    """Grad-CAM++ for EfficientNetAutoAttB4.

    Returns (grayscale_cam_224, face_bbox_pixels_or_None, heatmap_source).
    grayscale_cam_224 is in the 224x224 coordinate space of the face crop.
    face_bbox_pixels is (ymin, xmin, ymax, xmax) in original image pixels.
    heatmap_source is "gradcam++", or "uniform" when Grad-CAM++ fails and a flat
    fallback map is returned instead.
    """
    loader = get_model_loader()
    eff = loader.load_efficientnet()
    if eff is None:
        raise RuntimeError("EfficientNet not loaded")
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    img_np = np.array(pil_img)
    orig_h, orig_w = img_np.shape[:2]

    frame_data = eff.face_extractor.process_image(img=img_np)
    faces: list = frame_data.get("faces", [])
    if not faces:
        raise ValueError("no_face")
    face_bbox = _face_bbox_from_detections(frame_data, orig_h, orig_w)
    face_t = eff._face_tensor(faces[0]).unsqueeze(0).to(eff.device)

    try:
        net = eff.net
        target_layers = [net.efficientnet._blocks[-1]]
        face_t.requires_grad_(True)
        for p in net.parameters():
            p.requires_grad_(True)
        with GradCAMPlusPlus(model=net, target_layers=target_layers) as cam:
            grayscale_cam = cam(input_tensor=face_t, targets=None)[0]
        return grayscale_cam, face_bbox, "gradcam++"
    except Exception as e:
        logger.warning(f"EfficientNet Grad-CAM++ failed ({e}), using uniform fallback")
        grayscale_cam = np.ones((224, 224), dtype=np.float32) * 0.5
        return grayscale_cam, face_bbox, "uniform"


def _cam_to_full_image(
    grayscale_cam: np.ndarray,
    pil_img: Image.Image,
    face_bbox: Optional[tuple[int, int, int, int]] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Resize grayscale_cam to the original image dimensions.

    For EfficientNet (face-crop cam + known bbox): places the cam activation
    at the face location; background activation is 0.
    For ViT (full-image cam): bilinear resize to original dims.
    Returns (cam_full [H, W] float32, orig_np [H, W, 3] float32 in [0, 1]).
    """
    orig_w, orig_h = pil_img.size
    orig_np = np.array(pil_img.convert("RGB")).astype(np.float32) / 255.0
    if face_bbox is not None:
        ymin, xmin, ymax, xmax = face_bbox
        face_h, face_w = ymax - ymin, xmax - xmin
        cam_full = np.zeros((orig_h, orig_w), dtype=np.float32)
        # cv2.resize takes (width, height); paste the face cam into the bbox region
        cam_resized = cv2.resize(grayscale_cam, (face_w, face_h), interpolation=cv2.INTER_LINEAR)
        cam_full[ymin:ymax, xmin:xmax] = cam_resized
    else:
        cam_full = cv2.resize(grayscale_cam, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    return cam_full, orig_np
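
# Shape sketch for the bbox path (hypothetical numbers): a 224x224 face cam placed into a
# 1080x1920 original with face_bbox = (200, 700, 600, 1100) yields
#   cam_full.shape == (1080, 1920), non-zero only inside rows 200:600 and cols 700:1100.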


def _compute_gradcam_pp_densenet(
    pil_img: Image.Image,
) -> tuple[np.ndarray, str]:
    """Grad-CAM++ on the DenseNet121 face-GAN model.

    Target signal = fake probability = sigmoid(-logit), so we maximise the
    negated logit. Target layer = features.norm5 (final BN after the last DenseBlock,
    7×7×1024 activation map). Returns (grayscale_cam, source_tag).
    """
    loader = get_model_loader()
    result = loader.load_densenet()
    if result is None:
        raise RuntimeError("DenseNet model unavailable")
    model, meta = result
    from services.densenet_service import _preprocess

    image_size = int(meta.get("image_size", 224))
    input_tensor = _preprocess(pil_img, image_size, settings.DEVICE)
    model.eval()
    for p in model.parameters():
        p.requires_grad_(True)

    # Target = last BN after all DenseBlocks (equivalent to conv5_block16_concat in Keras)
    target_layers = [model.features.norm5]

    # Negate the logit so Grad-CAM gradients flow toward the FAKE class
    # (model output = real-probability logit; higher = more real)
    class _NegatedLogitWrapper(torch.nn.Module):
        def __init__(self, m: torch.nn.Module) -> None:
            super().__init__()
            self.m = m

        def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
            return -self.m(x)  # negative logit → gradient points at fake evidence

    wrapped = _NegatedLogitWrapper(model)
    with GradCAMPlusPlus(model=wrapped, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0]  # (H, W) in [0, 1]
    return grayscale_cam, "gradcam++_densenet"
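
# Why the negation works (the model's single logit scores "real", per the docstring):
#   fake_prob = sigmoid(-logit) = 1 - sigmoid(logit)
#   e.g. logit = +2.0 -> real_prob ≈ 0.88, fake_prob = sigmoid(-2.0) ≈ 0.12
# so gradients of the negated logit w.r.t. the feature maps highlight the regions
# that push the prediction toward "fake", which is what the heatmap should show.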


def generate_heatmap_base64(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
    model_family: Literal["vit", "efficientnet", "densenet"] = "vit",
) -> tuple[str, str]:
    """Produce a base64 data-URL PNG of the Grad-CAM++ overlay at original image resolution.

    Returns (base64_png, heatmap_source).
    """
    if model_family == "densenet":
        try:
            grayscale_cam, source = _compute_gradcam_pp_densenet(pil_img)
            cam_full, orig_np = _cam_to_full_image(grayscale_cam, pil_img, None)
        except Exception as e:
            logger.warning(f"DenseNet heatmap failed ({e}) — falling back to ViT Grad-CAM++")
            try:
                grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)
                cam_full, orig_np = _cam_to_full_image(grayscale_cam, pil_img, None)
                source = "vit_fallback"
            except Exception as fe:
                logger.warning(f"ViT fallback heatmap also failed: {fe}")
                return "", "none"
    elif model_family == "efficientnet":
        try:
            grayscale_cam, face_bbox, source = _compute_gradcam_pp_efficientnet(pil_img)
            cam_full, orig_np = _cam_to_full_image(grayscale_cam, pil_img, face_bbox)
        except ValueError:
            # BlazeFace found no face — fall back to ViT Grad-CAM++ on the full image.
            logger.info("EfficientNet heatmap: no face detected — falling back to ViT Grad-CAM++")
            try:
                grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)
                cam_full, orig_np = _cam_to_full_image(grayscale_cam, pil_img, None)
                source = "vit_fallback"
            except Exception as fe:
                logger.warning(f"ViT fallback heatmap also failed: {fe}")
                return "", "none"
        except Exception as e:
            logger.warning(f"EfficientNet heatmap failed: {e}")
            return "", "none"
    else:
        grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)
        source = "gradcam++"
        cam_full, orig_np = _cam_to_full_image(grayscale_cam, pil_img, None)

    # Generate a transparent RGBA overlay so CSS can blend it without darkening the base image
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * cam_full), cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    alpha = np.clip(cam_full * 1.8 * 255, 0, 255).astype(np.uint8)
    overlay_rgba = np.dstack((heatmap_colored, alpha))
    logger.info(f"Heatmap generated ({overlay_rgba.shape[1]}x{overlay_rgba.shape[0]}) source={source}")
    return _encode_overlay_to_base64(overlay_rgba), source


def generate_boxes_base64(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
    top_k: int = 5,
    threshold: float = 0.4,
) -> str:
    """Draw Grad-CAM++ activation bounding boxes on the full original image.

    Uses the ViT cam (full-image coverage), resizes it to original dimensions,
    finds contours, and draws boxes at the correct pixel locations.
    """
    grayscale_cam, _ = _compute_gradcam_pp(pil_img, target_class_idx)

    # Use the original image as the canvas — resize the cam to match
    orig_w, orig_h = pil_img.size
    base_img = np.array(pil_img.convert("RGB")).copy()
    cam_full = cv2.resize(grayscale_cam, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    binary = (cam_full >= threshold).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        logger.info("No significant activation regions found for bounding boxes")
        return _encode_overlay_to_base64(base_img)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)[:top_k]

    # Scale line width and font size to the image size
    line_w = max(2, orig_w // 300)
    font_scale = max(0.5, orig_w / 1200)
    for cnt in contours:
        x, y, bw, bh = cv2.boundingRect(cnt)
        region_activation = cam_full[y:y + bh, x:x + bw].mean()
        # Color-code boxes by mean activation (RGB tuples, since the canvas is an RGB array)
        if region_activation >= 0.7:
            color = (220, 40, 40)
        elif region_activation >= 0.5:
            color = (240, 140, 20)
        else:
            color = (230, 200, 40)
        cv2.rectangle(base_img, (x, y), (x + bw, y + bh), color, line_w)
        label = f"{region_activation * 100:.0f}%"
        cv2.putText(base_img, label, (x, max(y - 6, 14)),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_w, cv2.LINE_AA)

    logger.info(f"Bounding boxes generated: {len(contours)} regions on {orig_w}x{orig_h} image")
    return _encode_overlay_to_base64(base_img)
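

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, assuming the project's models are available
    # and that "sample.jpg" is a hypothetical test image in the working directory).
    import sys

    path = sys.argv[1] if len(sys.argv) > 1 else "sample.jpg"
    img = Image.open(path)

    overlay_b64, source = generate_heatmap_base64(img, model_family="vit")
    print(f"heatmap source={source}, data-url length={len(overlay_b64)}")

    boxes_b64 = generate_boxes_base64(img)
    print(f"boxes data-url length={len(boxes_b64)}")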