import os

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2

# Import GradCAM tools
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

from src.models.model import VisionGuardModel


class VisionGuardPredictor:
    """Run single-image FAKE/REAL inference and produce a Grad-CAM overlay.

    The predictor loads a ``VisionGuardModel`` checkpoint, wires a GradCAM
    instance onto the last backbone block, and exposes :meth:`predict`, which
    returns the verdict, confidence, per-class probabilities and a heatmap.
    """

    # Side length (pixels) of the square input used for the model and overlay.
    IMAGE_SIZE = 224

    def __init__(self, model_path, config_path="configs/config.yaml"):
        """Load the checkpoint and set up GradCAM and preprocessing.

        Args:
            model_path: Path to a checkpoint whose content is passed directly
                to ``load_state_dict``.
            config_path: Currently unused; kept so existing callers that pass
                it keep working.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🚀 Loading Inference Engine on: {self.device}")

        # 1. Load Model
        self.model = VisionGuardModel(num_classes=2, pretrained=False)
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from trusted sources (consider ``weights_only=True`` on PyTorch
        # versions that support it).
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()

        # 2. Setup GradCAM (The Explainability Tool)
        # We target the first normalization layer of the last backbone block.
        target_layers = [self.model.backbone.blocks[-1].norm1]
        self.cam = GradCAM(
            model=self.model,
            target_layers=target_layers,
            reshape_transform=self._reshape_transform,
        )

        # 3. Preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        self.labels = ['FAKE', 'REAL']

    @staticmethod
    def _reshape_transform(tensor):
        """Turn a token sequence into a channels-first 2D activation map.

        GradCAM needs ``[Batch, Channels, Height, Width]`` activations, while
        the targeted layer emits ``[Batch, Tokens, Channels]``. The first
        token is dropped (treated as the CLS token) and the remaining patch
        tokens are laid out on a square grid.

        Per the original notes the backbone is DINOv2-S/14, so a 224x224
        input yields 257 tokens -> 256 patches -> a 16x16 grid. The grid side
        is derived from the token count instead of being hard-coded; a
        non-square patch count makes ``reshape`` raise, as before.
        """
        patches = tensor[:, 1:, :]
        grid_size = round(patches.size(1) ** 0.5)
        grid = patches.reshape(
            tensor.size(0), grid_size, grid_size, tensor.size(2)
        )
        # [Batch, H, W, Channels] -> [Batch, Channels, H, W]
        return grid.permute(0, 3, 1, 2)

    def predict(self, image_path):
        """Classify one image and build its Grad-CAM visualization.

        Args:
            image_path: Anything accepted by ``PIL.Image.open``.

        Returns:
            A dict with keys ``"verdict"`` (label string), ``"confidence"``
            (percentage rounded to 2 decimals), ``"probabilities"`` (label ->
            probability rounded to 4 decimals) and ``"heatmap"`` (PIL image of
            the CAM overlaid on the resized input).
        """
        size = (self.IMAGE_SIZE, self.IMAGE_SIZE)

        # 1. Load Image
        # The context manager releases the file handle as soon as the pixel
        # data has been copied out by ``convert``.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        # Keep a clean copy for visualization, scaled to [0, 1] as expected
        # by show_cam_on_image.
        vis_image = np.float32(image.resize(size)) / 255.0

        # 2. Transform for Model
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # 3. Inference
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probs = F.softmax(outputs, dim=1)
            confidence, predicted_class = torch.max(probs, 1)

        # 4. Generate Heatmap
        # targets=None lets GradCAM pick the highest-scoring class itself.
        # It runs its own forward/backward pass, so this call must stay
        # outside the no_grad block above.
        grayscale_cam = self.cam(input_tensor=input_tensor, targets=None)
        grayscale_cam = grayscale_cam[0, :]

        # Overlay heatmap on image
        visualization = show_cam_on_image(vis_image, grayscale_cam, use_rgb=True)

        # Convert back to PIL for Gradio
        heatmap_pil = Image.fromarray(visualization)

        # 5. Format Output
        idx = predicted_class.item()
        class_probs = probs[0].tolist()

        return {
            "verdict": self.labels[idx],
            "confidence": round(float(confidence.item()) * 100, 2),
            "probabilities": {
                label: round(float(p), 4)
                for label, p in zip(self.labels, class_probs)
            },
            "heatmap": heatmap_pil,
        }