File size: 2,774 Bytes
83e35a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask
import numpy as np
from torch import tensor
import operator


import pickle

from torchcam.methods import SmoothGradCAMpp
from torchvision.io.image import read_image
from torchvision.models import resnet18
from torchvision.transforms.functional import normalize, resize, to_pil_image

# Pretrained ResNet-18 in inference mode, used by get_coordinates().
# Built exactly once: a second instance would only cost another weight load,
# and a module-level SmoothGradCAMpp is deliberately NOT created here because
# get_coordinates() opens its own extractor per call; hooking this model twice
# would make both extractors fire on every forward pass.
model = resnet18(pretrained=True).eval()

# One record is appended per get_coordinates() call
# ({'x_': int, 'y_': int, 'ten_map': tensor}); dump_CAM_data() pickles it.
CAM_data = []

def dump_CAM_data(path='CAM_data.pkl', data=None):
    """Pickle the collected CAM records to ``path``.

    Args:
        path: destination file. Defaults to ``'CAM_data.pkl'`` in the current
            working directory; an existing file is overwritten.
        data: object to pickle. Defaults to the module-level ``CAM_data`` list
            that get_coordinates() appends to.
    """
    if data is None:
        data = CAM_data
    with open(path, 'wb') as f:
        pickle.dump(data, f)

def _active_cell_extent(cam, threshold):
    """Return inclusive (top, bottom, left, right) indices of cells >= threshold.

    Args:
        cam: 2-D tensor indexed ``cam[row, col]``.
        threshold: minimum value for a cell to count as active.

    Returns:
        ``(top, bottom, left, right)`` cell indices, or ``None`` when no cell
        reaches ``threshold`` (NaN cells never count as active).
    """
    active = cam >= threshold
    rows = active.any(dim=1).nonzero().flatten()  # rows with >= 1 active cell
    cols = active.any(dim=0).nonzero().flatten()  # cols with >= 1 active cell
    if rows.numel() == 0:
        return None
    return int(rows[0]), int(rows[-1]), int(cols[0]), int(cols[-1])


def get_coordinates(img_path, threshold=0.2):
    """Return the pixel bounding box of the CAM region above ``threshold``.

    The image is resized to 224x224, normalized with the mean/std constants
    below (the usual ImageNet values), and passed through ``model``. A
    SmoothGrad-CAM++ map is then extracted for the arg-max class, and the
    bounding box of all CAM cells whose value is >= ``threshold`` is mapped
    back onto the original image size.

    Side effect: appends ``{'x_', 'y_', 'ten_map'}`` to the module-level
    ``CAM_data`` list. ``x_``/``y_`` are the integer pixel width/height of
    one CAM cell; ``ten_map`` is a detached CPU copy of the CAM.

    Args:
        img_path: path of an image readable by ``torchvision.io.read_image``.
            It is decoded as 3-channel RGB, so grayscale and RGBA files also
            match the 3-value normalization.
        threshold: minimum CAM value for a cell to be included in the box.

    Returns:
        ``(left, right, top, bottom)`` in original-image pixels. ``right`` and
        ``bottom`` are exclusive edges, so the box covers the full extent of
        the last active cell. Returns ``(-1, -1, -1, -1)`` when no cell
        reaches ``threshold``.
    """
    from torchvision.io.image import ImageReadMode

    img = read_image(img_path, mode=ImageReadMode.RGB)
    img_h, img_w = img.shape[1], img.shape[2]
    input_tensor = normalize(
        resize(img, (224, 224)) / 255.,
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225],
    )

    with SmoothGradCAMpp(model) as cam_extractor:
        out = model(input_tensor.unsqueeze(0))
        # CAM for the highest-scoring class of this single-image batch.
        activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)

    # First target layer, first (only) batch element.
    ten_map = activation_map[0][0].detach().cpu().clone()
    cam_h, cam_w = ten_map.shape

    CAM_data.append({'x_': img_w // cam_w, 'y_': img_h // cam_h, 'ten_map': ten_map})

    extent = _active_cell_extent(ten_map, threshold)
    if extent is None:
        return -1, -1, -1, -1
    top, bottom, left, right = extent

    # Map cell boundaries proportionally, rather than multiplying by a
    # floor-divided cell size, so rounding error does not accumulate across
    # cells. The last cell's far edge lands exactly on img_w / img_h.
    left_px = left * img_w // cam_w
    right_px = (right + 1) * img_w // cam_w
    top_px = top * img_h // cam_h
    bottom_px = (bottom + 1) * img_h // cam_h
    return left_px, right_px, top_px, bottom_px