| import numpy as np | |
| import torch | |
| from skimage import filters | |
| from torchvision.transforms.functional import resize | |
| from utils.saliency import decoder, resnet | |
def get_smap(image, path, filter_size=15):
    """Compute a normalised saliency map of an RGB image using EMLNet.

    Reference: https://arxiv.org/abs/1805.01047
    Reference: https://github.com/SenJia/EML-NET-Saliency

    The three EMLNet networks are loaded from ``path`` and moved to CUDA on
    every call, so a CUDA device is required.

    Args:
        image: Channel-first tensor whose first dimension is 3. The remaining
            dimensions are used as the output size, so a ``(3, H, W)`` layout
            is assumed -- TODO confirm expected dtype/value range with callers.
        path: Directory containing ``emlnet/res_imagenet.pth``,
            ``emlnet/res_places.pth`` and ``emlnet/res_decoder.pth``.
        filter_size: Forwarded as the second positional argument (the sigma)
            of ``skimage.filters.gaussian`` when smoothing the prediction.

    Returns:
        A ``np.float32`` array with the spatial size of ``image``, shifted and
        scaled so that its minimum is 0 and its maximum is 1. A constant
        prediction yields an all-zero map.

    Raises:
        ValueError: If ``image.shape[0]`` is not 3.
    """
    if image.shape[0] != 3:
        raise ValueError("Saliency prediction only supports RGB images")

    # Working resolution handed to the decoder and used for the network input.
    sod_res = (480, 640)

    imagenet_model = resnet.resnet50(f"{path}/emlnet/res_imagenet.pth").cuda().eval()
    places_model = resnet.resnet50(f"{path}/emlnet/res_places.pth").cuda().eval()
    decoder_model = (
        decoder.build_decoder(f"{path}/emlnet/res_decoder.pth", sod_res, 5, 5)
        .cuda()
        .eval()
    )

    # The networks live on the GPU, so the input batch has to be there too;
    # this is a no-op when ``image`` is already a CUDA tensor.
    image_sod = resize(image, sod_res).unsqueeze(0).cuda()

    with torch.no_grad():
        imagenet_feat = imagenet_model(image_sod, decode=True)
        places_feat = places_model(image_sod, decode=True)
        smap = decoder_model([imagenet_feat, places_feat])

    # Drop the batch dimension, bring the prediction back to the CPU, resize
    # it to the input's spatial size and drop the leading dimension again.
    smap = resize(smap.squeeze(0).detach().cpu(), image.shape[1:]).squeeze(0)

    # Smooth, then min-max normalise.
    smap = filters.gaussian(smap.numpy(), filter_size)
    smap -= smap.min()
    peak = smap.max()
    # After the shift a constant map has a peak of 0; dividing by it would
    # fill the result with NaNs, so leave such a map at zero instead.
    if peak > 0:
        smap /= peak
    return smap.astype(np.float32)