|
|
import matplotlib.pyplot as plt |
|
|
from torchcam.utils import overlay_mask |
|
|
import numpy as np |
|
|
from torch import tensor |
|
|
import operator |
|
|
|
|
|
|
|
|
from torchvision.models import resnet18 |
|
|
# Build a pretrained ResNet-18 and switch it to evaluation mode.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` -- confirm against the installed version before changing it.
model = resnet18(pretrained=True)
model.eval()

from torchcam.methods import SmoothGradCAMpp

# Module-level Smooth Grad-CAM++ extractor constructed from the model above.
cam_extractor = SmoothGradCAMpp(model)
|
|
|
|
|
from torchvision.io.image import read_image |
|
|
from torchvision.transforms.functional import normalize, resize, to_pil_image |
|
|
from torchvision.models import resnet18 |
|
|
from torchcam.methods import SmoothGradCAMpp |
|
|
import pickle |
|
|
|
|
|
# Rebind `model` to a newly constructed pretrained ResNet-18 in evaluation
# mode; get_coordinates reads this module-level name on every call.
model = resnet18(pretrained=True)
model.eval()

# One record per processed image: get_coordinates appends to this list and
# dump_CAM_data pickles it.
CAM_data = []
|
|
|
|
|
def dump_CAM_data(path='CAM_data.pkl'):
    """Pickle the module-level ``CAM_data`` list to disk.

    Args:
        path: Destination file. Defaults to ``'CAM_data.pkl'`` (relative to
            the current working directory), which was previously the
            hard-coded location. An existing file is overwritten.
    """
    # Binary mode is required by pickle; the context manager closes the
    # handle even if serialization raises.
    with open(path, 'wb') as f:
        pickle.dump(CAM_data, f)
|
|
|
|
|
def get_coordinates(img_path, threshold=0.2):
    """Return the bounding box of the strongly activated CAM region of an image.

    The image is read from ``img_path``, resized to 224x224, scaled to
    [0, 1] and normalized with the per-channel mean/std constants below,
    then classified by the module-level ``model``. A Smooth Grad-CAM++ map
    is extracted for the top-scoring class, and the smallest box enclosing
    every map cell whose value is at least ``threshold`` is converted to
    pixel coordinates of the original (un-resized) image.

    Side effect: every call appends a ``{'x_', 'y_', 'ten_map'}`` record to
    the module-level ``CAM_data`` list.

    Args:
        img_path: Path of the image file, handed to ``read_image``.
        threshold: Minimum map value for a cell to count as active.
            Defaults to 0.2, the value that used to be hard-coded.

    Returns:
        ``(left, right, top, bottom)`` as Python ints. Each value is a map
        cell index multiplied by the integer pixels-per-cell factor of its
        axis. If no cell reaches ``threshold`` the indices stay at -1, so
        the result is ``(-x_, -x_, -y_, -y_)``.
    """
    img = read_image(img_path)
    rescaled = resize(img, (224, 224)) / 255.
    input_tensor = normalize(rescaled, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # The extractor only needs to live for this forward pass and CAM
    # computation; the context manager tears it down afterwards.
    with SmoothGradCAMpp(model) as cam_extractor:
        out = model(input_tensor.unsqueeze(0))
        class_idx = out.squeeze(0).argmax().item()
        activation_map = cam_extractor(class_idx, out)

    # Assumes the extractor returns a list whose first entry is indexed
    # (batch, rows, cols) -- TODO confirm against the installed torchcam.
    cam_map = activation_map[0][0]
    # The NumPy round trip yields an independent CPU tensor for storage.
    ten_map = tensor(np.array(cam_map.cpu()))

    # Integer number of original-image pixels covered by one map cell.
    # The x factor uses the map's column count (previously the row count was
    # used for both axes, which is only equivalent for square maps).
    x_ = img.shape[2] // ten_map.shape[1]
    y_ = img.shape[1] // ten_map.shape[0]

    CAM_data.append({'x_': x_, 'y_': y_, 'ten_map': ten_map})

    top, bottom, left, right = _active_bounds(ten_map, threshold)

    # NOTE(review): right/bottom are the scaled *index* of the last active
    # cell, i.e. they do not add that cell's own extent -- confirm this is
    # what callers expect.
    return left * x_, right * x_, top * y_, bottom * y_


def _active_bounds(cam, threshold):
    """Return (top, bottom, left, right) cell indices of entries >= threshold.

    ``cam`` is a 2-D tensor. All four values are -1 when no entry reaches
    ``threshold``.
    """
    active = cam >= threshold
    # Indices of rows / columns containing at least one active cell, in
    # ascending order.
    rows = active.any(dim=1).nonzero().flatten().tolist()
    cols = active.any(dim=0).nonzero().flatten().tolist()
    if not rows:
        return -1, -1, -1, -1
    return rows[0], rows[-1], cols[0], cols[-1]