File size: 2,343 Bytes
aa27d2d
 
 
 
cd801e4
 
aa27d2d
cd801e4
 
 
 
 
aa27d2d
 
 
cd801e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa27d2d
cd801e4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Identifier passed to `from_pretrained` for both the image processor and the
# classifier (presumably a Hugging Face Hub repo id -- confirm).
model_id = "prithivMLmods/Deep-Fake-Detector-Model"
# NOTE: both objects are created at import time, so importing this module
# triggers the `from_pretrained` loading work.
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id)
# Switch the model to evaluation mode (disables training-only behaviour such
# as dropout); it is shared by every call to `generate_heatmap`.
model.eval()

def generate_heatmap(image_path: str, save_path: str):
    """Render a Grad-CAM-style heatmap overlay for an image and save it.

    The image is run through the module-level ``model``; the logit of the
    predicted (arg-max) class is back-propagated, and the activations of
    ``model.vision_model.encoder.layers[-1].layer_norm1`` are weighted by
    their token-averaged gradients to form a class-activation map. The map is
    resized to the image, colour-mapped and blended over the original.

    Args:
        image_path: Path of the image to analyse.
        save_path: Path where the blended overlay is written.

    Returns:
        ``save_path``, once the overlay has been written.

    Raises:
        ValueError: If the number of tokens produced by the target layer is
            not a perfect square and so cannot be laid out as a square grid.
        OSError: If OpenCV reports that the overlay could not be written.
        Exception: Any other failure is logged with its traceback and
            re-raised unchanged.
    """
    try:
        # Close the file handle deterministically; convert() returns a new
        # image that does not depend on the open file.
        with Image.open(image_path) as source_image:
            img = source_image.convert("RGB")
        original_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

        inputs = processor(images=img, return_tensors="pt")
        x = inputs["pixel_values"]

        target_layer = model.vision_model.encoder.layers[-1].layer_norm1

        activations = []
        gradients = []

        def forward_hook(module, args, output):
            activations.append(output.detach())

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0].detach())

        # The model is a shared module-level object: hooks must be removed
        # even when the forward/backward pass raises, otherwise they pile up
        # on the layer across calls.
        handles = []
        try:
            handles.append(target_layer.register_forward_hook(forward_hook))
            handles.append(
                target_layer.register_full_backward_hook(backward_hook)
            )

            # Gradients are required here even if the caller is running
            # inside a torch.no_grad() block.
            with torch.enable_grad():
                outputs = model(x)
                logits = outputs.logits
                target_class = logits.argmax(dim=-1).item()
                score = logits[0, target_class]

                model.zero_grad()
                score.backward()
        finally:
            for handle in handles:
                handle.remove()

        # First batch element; treated as (tokens, channels).
        act = activations[0][0]
        grad = gradients[0][0]
        # One weight per channel: the gradient averaged over all tokens.
        weights = torch.mean(grad, dim=0)

        # Weighted sum over channels for every token, as a single
        # matrix-vector product instead of a Python loop over channels.
        cam = (act @ weights).float().cpu().numpy()

        # Derive the square grid side from the token count rather than
        # assuming a fixed 14x14 layout.
        token_count = cam.shape[0]
        side = int(round(token_count ** 0.5))
        if side * side != token_count:
            raise ValueError(
                f"Cannot arrange {token_count} tokens into a square grid"
            )
        cam = cam.reshape(side, side)

        cam = np.maximum(cam, 0)
        cam = cam - np.min(cam)
        if np.max(cam) != 0:
            cam = cam / np.max(cam)

        heatmap_resized = cv2.resize(
            cam, (img.width, img.height), interpolation=cv2.INTER_CUBIC
        )
        # Cubic interpolation can overshoot outside [0, 1]; without clipping
        # the uint8 cast below would wrap around and corrupt the colour map.
        heatmap_resized = np.clip(heatmap_resized, 0.0, 1.0)

        heatmap_uint8 = np.uint8(255 * heatmap_resized)
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_INFERNO)

        overlay = cv2.addWeighted(original_cv, 0.6, heatmap_colored, 0.4, 0)
        # cv2.imwrite signals failure through its return value, not an
        # exception; do not report a path that was never written.
        if not cv2.imwrite(save_path, overlay):
            raise OSError(f"Failed to write heatmap overlay to {save_path}")

        return save_path

    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Failed to generate heatmap: %s", e)
        raise