File size: 3,816 Bytes
6fcadef
 
68b4742
 
6fcadef
 
68b4742
6fcadef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68b4742
 
 
 
 
 
 
 
 
 
 
 
 
6fcadef
68b4742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import cv2
import numpy as np
import torch
from pipeline import preprocess_disc_crop, preprocess_for_efficientnet, TRAIN_MEAN, TRAIN_STD


def make_visualization_images(image_bgr: np.ndarray, result: dict, eff_model, device) -> dict:
    """Build the set of display images for one processed input image.

    Args:
        image_bgr: Input image; it is converted with ``cv2.COLOR_BGR2RGB``,
            so it is expected to be in OpenCV's BGR channel order.
        result: Pipeline output dict. The keys read here are
            ``"yolo_bbox"`` (four values cast to int and drawn as the
            rectangle corners), ``"crop"`` (BGR crop), ``"pred_map"``
            (label map; labels 1 and 2 are coloured), ``"od_mask"`` and
            ``"oc_mask"`` (used to index the 512x512 crop -- presumably
            boolean masks of that size; confirm against the producer).
        eff_model: Model forwarded to ``compute_gradcam``.
        device: Device forwarded to ``compute_gradcam``.

    Returns:
        Dict with the keys ``img_yolo``, ``crop_rgb``, ``img_proc``,
        ``seg_rgb``, ``overlay_rgb``, ``img_eff_input``, ``img_eff_proc``,
        ``gradcam_proc`` and ``gradcam_orig``; all values are RGB arrays.
    """

    def _to_uint8(unit_range_img):
        # Scale a [0, 1] float image up to 8-bit.
        return (unit_range_img * 255).astype(np.uint8)

    green = [0, 200, 0]
    red = [200, 0, 0]

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # --- Detection: full image with the bounding box drawn on a copy. ---
    img_yolo = image_rgb.copy()
    left, top, right, bottom = tuple(int(v) for v in result["yolo_bbox"])
    cv2.rectangle(img_yolo, (left, top), (right, bottom), (255, 215, 0), 3)

    # --- Crop, raw and as seen after the segmentation preprocessing. ---
    crop_bgr = result["crop"]
    crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)

    disc_hwc = preprocess_disc_crop(crop_bgr, img_size=512).permute(1, 2, 0).numpy()
    # Undo the training normalisation so the tensor is viewable again.
    disc_denorm = disc_hwc * TRAIN_STD + TRAIN_MEAN
    img_proc = np.clip(disc_denorm * 255, 0, 255).astype(np.uint8)

    # --- Segmentation: label 1 -> green, label 2 -> red, rest black. ---
    pred_map = result["pred_map"]
    seg_rgb = np.zeros((*pred_map.shape, 3), dtype=np.uint8)
    for label, colour in ((1, green), (2, red)):
        seg_rgb[pred_map == label] = colour

    # --- 50/50 blend of the masks over the resized crop. ---
    crop_512 = cv2.resize(crop_rgb, (512, 512), interpolation=cv2.INTER_LINEAR)
    blended = crop_512.astype(np.float32)
    # Order matters: "oc_mask" is blended after (and on top of) "od_mask".
    for mask_key, colour in (("od_mask", green), ("oc_mask", red)):
        mask = result[mask_key]
        blended[mask] = blended[mask] * 0.5 + np.array(colour) * 0.5
    overlay_rgb = blended.astype(np.uint8)

    # --- Classifier input, de-normalised for display. ---
    mean_eff = np.array([0.485, 0.456, 0.406])
    std_eff = np.array([0.229, 0.224, 0.225])
    tensor_eff = preprocess_for_efficientnet(image_bgr, img_size=300)
    eff_hwc = tensor_eff.permute(1, 2, 0).numpy()
    img_eff_display = np.clip(eff_hwc * std_eff + mean_eff, 0, 1)

    # --- Grad-CAM over both the preprocessed and the original image. ---
    cam, _prob = compute_gradcam(eff_model, tensor_eff, device)
    overlay_proc = overlay_gradcam(img_eff_display, cam, alpha=0.4)
    original_300 = cv2.resize(image_rgb, (300, 300)) / 255.0
    overlay_orig = overlay_gradcam(original_300, cam, alpha=0.4)

    images = {}
    images["img_yolo"] = img_yolo
    images["crop_rgb"] = crop_rgb
    images["img_proc"] = img_proc
    images["seg_rgb"] = seg_rgb
    images["overlay_rgb"] = overlay_rgb
    images["img_eff_input"] = image_rgb
    images["img_eff_proc"] = _to_uint8(img_eff_display)
    images["gradcam_proc"] = _to_uint8(overlay_proc)
    images["gradcam_orig"] = _to_uint8(overlay_orig)
    return images


def compute_gradcam(eff_model, tensor_eff, device):
    """Compute a Grad-CAM heatmap on the last layer of ``backbone.features``.

    Args:
        eff_model: Model exposing ``backbone.features`` (indexable; its last
            element is the layer that gets hooked). Its output for a single
            image must contain exactly one element, because it is read with
            ``.item()`` and back-propagated without an explicit gradient.
        tensor_eff: Preprocessed image tensor without a batch dimension; a
            batch axis is added here.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        ``(cam, prob)`` where ``cam`` is a 2-D NumPy array shifted to a zero
        minimum and, when its maximum is positive, scaled to a maximum of 1;
        ``prob`` is the sigmoid of the model output as a Python float.

    Side effects:
        Switches ``eff_model`` to eval mode, zeroes its gradients and then
        runs a backward pass through it.
    """
    captured = {}

    def forward_hook(module, inputs, output):
        captured["activations"] = output.detach()

    def backward_hook(module, grad_input, grad_output):
        captured["gradients"] = grad_output[0].detach()

    target_layer = eff_model.backbone.features[-1]
    handles = [target_layer.register_forward_hook(forward_hook)]
    try:
        handles.append(target_layer.register_full_backward_hook(backward_hook))

        input_tensor = tensor_eff.unsqueeze(0).to(device).requires_grad_(True)
        eff_model.eval()
        output = eff_model(input_tensor)
        prob = torch.sigmoid(output).item()
        eff_model.zero_grad()
        output.backward()
    finally:
        # Always detach the hooks, even if the forward/backward pass raised;
        # otherwise they would stay registered on the layer and fire on
        # every later use of the model.
        for handle in handles:
            handle.remove()

    # Channel weights are the spatially averaged gradients.
    weights = captured["gradients"].mean(dim=[2, 3], keepdim=True)
    cam = (weights * captured["activations"]).sum(dim=1, keepdim=True)
    # Drop only the batch and channel axes (both of size 1). A bare
    # ``squeeze()`` would also drop a spatial axis of size 1 and hand the
    # caller a 0-d/1-d array instead of a 2-D map.
    cam = torch.relu(cam)[0, 0].cpu().numpy()
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak

    return cam, prob


def overlay_gradcam(original_img_np: np.ndarray, cam: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    """Blend a JET-coloured heatmap of ``cam`` over an image.

    Args:
        original_img_np: Image to draw on. It is mixed with a heatmap scaled
            to [0, 1] and the result is clipped to [0, 1], so it is expected
            to be a float RGB image in that range (the callers in this file
            pass exactly that).
        cam: 2-D activation map; it is resized to the image's height/width
            and multiplied by 255 before colouring, so values are expected
            to lie in [0, 1].
        alpha: Weight of the heatmap; the image gets ``1 - alpha``.

    Returns:
        The blended image, clipped to [0, 1].
    """
    height = original_img_np.shape[0]
    width = original_img_np.shape[1]

    # cv2.resize takes the target size as (width, height).
    cam_at_image_size = cv2.resize(cam, (width, height))
    cam_u8 = (cam_at_image_size * 255).astype(np.uint8)

    # applyColorMap yields BGR; convert so it matches the RGB image.
    jet_bgr = cv2.applyColorMap(cam_u8, cv2.COLORMAP_JET)
    jet_rgb = cv2.cvtColor(jet_bgr, cv2.COLOR_BGR2RGB) / 255.0

    blended = (1 - alpha) * original_img_np + alpha * jet_rgb
    return np.clip(blended, 0, 1)