File size: 3,877 Bytes
26c2a4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import os

# Import GradCAM tools
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

from src.models.model import VisionGuardModel

class VisionGuardPredictor:
    """Inference wrapper around ``VisionGuardModel`` returning a verdict and a GradCAM heatmap.

    Loads a 2-class checkpoint, classifies a single image as FAKE or REAL, and
    produces per-class probabilities plus a GradCAM overlay (a PIL image, e.g.
    for display in Gradio).
    """

    # Side length (pixels) used both for the model input and for the overlay image.
    IMAGE_SIZE = 224

    def __init__(self, model_path, config_path="configs/config.yaml"):
        """Load the checkpoint, build the GradCAM explainer and the preprocessing pipeline.

        Args:
            model_path: Path to a checkpoint that ``torch.load`` returns as a
                state dict for ``VisionGuardModel(num_classes=2)``.
            config_path: Currently unused; kept for interface compatibility.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🚀 Loading Inference Engine on: {self.device}")

        # 1. Load model.
        self.model = VisionGuardModel(num_classes=2, pretrained=False)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (or pass weights_only=True on torch >= 1.13).
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()

        # 2. GradCAM. The target is ``norm1`` (the first LayerNorm) of the last
        # transformer block of the backbone. Its output is a token sequence, so
        # a reshape transform turns it back into a spatial feature map.
        target_layers = [self.model.backbone.blocks[-1].norm1]
        self.cam = GradCAM(
            model=self.model,
            target_layers=target_layers,
            reshape_transform=self._reshape_transform,
        )

        # 3. Preprocessing for the model input (ImageNet mean/std normalization).
        self.transform = transforms.Compose([
            transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Class index -> label. Presumably matches the label order used in
        # training; verify against the training pipeline.
        self.labels = ['FAKE', 'REAL']

    @staticmethod
    def _reshape_transform(tensor):
        """Convert a ViT token sequence into a channels-first spatial map for GradCAM.

        Args:
            tensor: Activations shaped ``(batch, 1 + num_patches, channels)``
                whose first token is the CLS token.

        Returns:
            Tensor shaped ``(batch, channels, grid, grid)`` with
            ``grid * grid == num_patches``. For example, ``grid`` is 16 for a
            224x224 input with patch size 14.

        Raises:
            ValueError: If the patch tokens do not form a square grid.
        """
        # Drop the CLS token (index 0); the remaining tokens are image patches.
        patches = tensor[:, 1:, :]
        batch, num_patches, channels = patches.shape
        # Derive the grid side from the token count instead of hard-coding it,
        # so other input resolutions keep working.
        grid_size = int(round(num_patches ** 0.5))
        if grid_size * grid_size != num_patches:
            raise ValueError(
                f"Expected a square number of patch tokens, got {num_patches}"
            )
        patches = patches.reshape(batch, grid_size, grid_size, channels)
        # (batch, H, W, C) -> (batch, C, H, W)
        return patches.permute(0, 3, 1, 2)

    def _build_result(self, probs, heatmap):
        """Turn a ``(1, num_classes)`` probability tensor into the public result dict.

        Args:
            probs: Softmax probabilities for a single image, shaped ``(1, num_classes)``.
            heatmap: Object stored under ``"heatmap"`` (the GradCAM overlay).

        Returns:
            dict with ``"verdict"``, ``"confidence"`` (percent, 2 decimals),
            ``"probabilities"`` (label -> probability, 4 decimals) and ``"heatmap"``.
        """
        row = probs[0]
        confidence, idx = torch.max(row, dim=0)
        return {
            "verdict": self.labels[idx.item()],
            "confidence": round(float(confidence.item()) * 100, 2),
            # Keyed by self.labels so keys and indices cannot drift apart.
            "probabilities": {
                label: round(float(p), 4)
                for label, p in zip(self.labels, row.tolist())
            },
            "heatmap": heatmap,
        }

    def predict(self, image_path):
        """Classify one image and explain the decision with a GradCAM heatmap.

        Args:
            image_path: Path to an image file, or an already loaded ``PIL.Image.Image``.

        Returns:
            A dict with four keys:

            * ``"verdict"``: the label string.
            * ``"confidence"``: the top-class probability in percent.
            * ``"probabilities"``: a mapping from label to probability.
            * ``"heatmap"``: a PIL image with the GradCAM overlay at
              ``IMAGE_SIZE`` x ``IMAGE_SIZE``.
        """
        # 1. Load image. The ``with`` block closes the file handle Image.open
        # keeps; convert() returns an independent image.
        if isinstance(image_path, Image.Image):
            image = image_path.convert('RGB')
        else:
            with Image.open(image_path) as raw:
                image = raw.convert('RGB')

        # Clean copy for the overlay, scaled to float32 in [0, 1] as
        # show_cam_on_image expects.
        size = (self.IMAGE_SIZE, self.IMAGE_SIZE)
        vis_image = np.asarray(image.resize(size), dtype=np.float32) / 255.0

        # 2. Model input: (1, 3, H, W) on the inference device.
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # 3. Inference (no autograd bookkeeping needed here).
        with torch.no_grad():
            probs = F.softmax(self.model(input_tensor), dim=1)

        # 4. Heatmap. GradCAM needs gradients, so it runs outside no_grad.
        # targets=None lets pytorch-grad-cam explain the highest-scoring class,
        # i.e. the predicted class.
        grayscale_cam = self.cam(input_tensor=input_tensor, targets=None)[0, :]
        visualization = show_cam_on_image(vis_image, grayscale_cam, use_rgb=True)
        # Convert back to PIL for Gradio.
        heatmap_pil = Image.fromarray(visualization)

        # 5. Format output.
        return self._build_result(probs, heatmap_pil)