File size: 4,365 Bytes
c5377b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import cv2
import numpy as np
from PIL import Image
import torch.nn as nn

def get_resnet_gradcam(image_path, predictor, output_path):
    """Write a Grad-CAM overlay for the predictor's top-scoring class.

    Hooks ``predictor.model.model.layer4[-1]`` to capture its output and the
    gradient of the winning logit with respect to that output, builds the
    class-activation map from them, and blends it over the original image.

    Args:
        image_path: Path of the image to explain; opened with PIL and
            converted to RGB.
        predictor: Object exposing ``model``, ``device`` and
            ``test_transforms`` (the preprocessing applied to the PIL image).
        output_path: Destination passed to ``cv2.imwrite``.

    Returns:
        True when ``cv2.imwrite`` reports that the overlay was written,
        False otherwise.

    Note:
        The model is switched to eval mode and its gradients are zeroed as a
        side effect.
    """
    model = predictor.model
    device = predictor.device
    model.eval()

    features, gradients = [], []

    def forward_hook(module, inputs, output):
        features.append(output)

    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0])

    target_layer = model.model.layer4[-1]
    handle_fw = target_layer.register_forward_hook(forward_hook)
    handle_bw = target_layer.register_full_backward_hook(backward_hook)
    try:
        # Close the file handle promptly; convert() returns an independent
        # in-memory image that stays valid after the file is closed.
        with Image.open(image_path) as img:
            original_img = img.convert("RGB")
        input_tensor = predictor.test_transforms(original_img).unsqueeze(0).to(device)

        model.zero_grad()
        output = model(input_tensor)
        pred_class_idx = output.argmax(dim=1).item()

        # Back-propagate only the logit of the predicted class.
        output[0, pred_class_idx].backward()
    finally:
        # Always detach the hooks so a failure above does not leave them
        # attached to the shared model.
        handle_fw.remove()
        handle_bw.remove()

    # First batch element of the captured activation / gradient.
    acts = features[0].detach().cpu().numpy()[0]
    grads = gradients[0].detach().cpu().numpy()[0]

    # Channel importance = spatial mean of the gradients.
    weights = np.mean(grads, axis=(1, 2))

    # Weighted sum of the activation channels, accumulated in float32.
    cam = np.tensordot(
        weights.astype(np.float32), acts.astype(np.float32), axes=1
    )

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (original_img.width, original_img.height))
    cam_min = np.min(cam)
    # The epsilon keeps the division defined when the map is flat.
    cam = (cam - cam_min) / (np.max(cam) - cam_min + 1e-8)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)

    # OpenCV works in BGR, so convert the RGB image before blending.
    original_bgr = cv2.cvtColor(np.array(original_img), cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)

    # imwrite signals failure through its return value; pass that on
    # instead of reporting success unconditionally.
    return bool(cv2.imwrite(output_path, overlay))

def get_fusion_gradcam(image_path, predictor, output_path):
    """Write a Grad-CAM overlay for the two-branch (fusion) predictor.

    The map is computed on the last ``nn.Conv2d`` found inside
    ``predictor.model.eff_features``; the second branch only contributes to
    the logits that are back-propagated.

    Args:
        image_path: Path of the image to explain; opened with PIL and
            converted to RGB.
        predictor: Object exposing ``model``, ``device``, ``eff_normalize``
            (preprocessing for the first branch) and ``convnext_processor``
            (called with ``images=...`` and ``return_tensors="pt"``, and
            returning a mapping with a ``"pixel_values"`` entry).
        output_path: Destination passed to ``cv2.imwrite``.

    Returns:
        True when ``cv2.imwrite`` reports that the overlay was written,
        False otherwise.

    Raises:
        RuntimeError: If no ``Conv2d`` layer exists in ``eff_features``, or
            if the hooked layer produced no activation / gradient.

    Note:
        The model is switched to eval mode and its gradients are zeroed as a
        side effect.
    """
    model = predictor.model
    device = predictor.device
    model.eval()

    # ----- Find the last Conv2d layer inside the EfficientNet backbone -----
    def find_last_conv_layer(module):
        """Return the last Conv2d met in a depth-first walk, or None."""
        last_conv = None
        for child in module.children():
            if isinstance(child, nn.Conv2d):
                last_conv = child
            else:
                candidate = find_last_conv_layer(child)
                if candidate is not None:
                    last_conv = candidate
        return last_conv

    target_layer = find_last_conv_layer(model.eff_features)
    if target_layer is None:
        raise RuntimeError("No Conv2d layer found in EfficientNet backbone")

    # ----- Forward hook that stores the activation with gradient tracking -----
    activation = None

    def forward_hook(module, inp, out):
        nonlocal activation
        activation = out
        # Non-leaf tensors drop their .grad unless explicitly retained.
        activation.retain_grad()

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # ----- Preprocess image for both branches -----
        # Close the file handle promptly; convert() returns an independent
        # in-memory image that stays valid after the file is closed.
        with Image.open(image_path) as img:
            original_img = img.convert("RGB")

        # EfficientNet branch
        pixel_eff = predictor.eff_normalize(original_img).unsqueeze(0).to(device)

        # ConvNeXt branch
        inputs_cnx = predictor.convnext_processor(
            images=original_img, return_tensors="pt"
        )
        pixel_cnx = inputs_cnx["pixel_values"].to(device)

        # ----- Forward + backward on the predicted class logit -----
        model.zero_grad()
        output = model(pixel_eff, pixel_cnx)
        pred_class_idx = output.argmax(dim=1).item()
        output[0, pred_class_idx].backward()
    finally:
        # Always detach the hook so a failure above does not leave it
        # attached to the shared model.
        handle.remove()

    if activation is None or activation.grad is None:
        raise RuntimeError(
            "Grad-CAM target layer produced no activation or gradient"
        )

    # ----- Extract activation and gradients (first batch element) -----
    acts = activation[0].detach().cpu().numpy()        # (C, H, W)
    grads = activation.grad[0].detach().cpu().numpy()  # (C, H, W)

    # ----- Grad-CAM computation -----
    weights = np.mean(grads, axis=(1, 2))              # (C,)
    # Weighted sum of the activation channels, accumulated in float32.
    cam = np.tensordot(
        weights.astype(np.float32), acts.astype(np.float32), axes=1
    )

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (original_img.width, original_img.height))

    # Normalise to [0, 1]; a flat map becomes all zeros instead of dividing
    # by (almost) zero.
    cam_min = np.min(cam)
    cam_range = np.max(cam) - cam_min
    if cam_range > 1e-8:
        cam = (cam - cam_min) / cam_range
    else:
        cam = np.zeros_like(cam)

    # ----- Overlay and save -----
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # OpenCV works in BGR, so convert the RGB image before blending.
    original_bgr = cv2.cvtColor(np.array(original_img), cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)

    # imwrite signals failure through its return value; pass that on
    # instead of reporting success unconditionally.
    return bool(cv2.imwrite(output_path, overlay))