import numpy as np
import torch
from skimage import filters
from torchvision.transforms.functional import resize

from utils.saliency import decoder, resnet


def _post_process(smap, filter_size):
    """Blur a raw saliency map and rescale it to the range [0, 1].

    Args:
        smap: 2-D NumPy array holding the raw saliency prediction.
        filter_size: Forwarded as the second positional argument (``sigma``)
            of ``skimage.filters.gaussian``.

    Returns:
        The blurred, min-max normalized map. If the blurred map is constant
        (peak of zero after subtracting the minimum), an all-zero map is
        returned instead of dividing by zero.
    """
    smap = filters.gaussian(smap, filter_size)
    smap -= smap.min()
    peak = smap.max()
    # A constant map has peak == 0 here; dividing would fill the output
    # with NaNs, so leave it as all zeros in that case.
    if peak > 0:
        smap /= peak
    return smap


def get_smap(image, path, filter_size=15):
    """
    Compute the saliency map of the target image using EMLNet.

    Reference: https://arxiv.org/abs/1805.01047
    Reference: https://github.com/SenJia/EML-NET-Saliency

    Args:
        image: Tensor whose first dimension has size 3 (RGB channels); the
            remaining dimensions (``image.shape[1:]``) are used as the
            output resolution.
        path: Directory containing the ``emlnet`` folder with the weight
            files ``res_imagenet.pth``, ``res_places.pth`` and
            ``res_decoder.pth``.
        filter_size: Smoothing strength handed to
            ``skimage.filters.gaussian`` (as its ``sigma``) during
            post-processing.

    Returns:
        ``np.float32`` array with the saliency map resized back to
        ``image.shape[1:]``, blurred and min-max normalized to [0, 1].
        Assumes the decoder emits a single-channel map -- TODO confirm.

    Raises:
        ValueError: If ``image.shape[0]`` is not 3.
    """
    if image.shape[0] != 3:
        raise ValueError("Saliency prediction only supports RGB images")

    # Resolution the input is resized to before it is fed to the networks.
    sod_res = (480, 640)

    # NOTE(review): all three networks are rebuilt from disk on every call;
    # callers that process many images may want to cache them.
    imagenet_model = resnet.resnet50(f"{path}/emlnet/res_imagenet.pth").cuda().eval()
    places_model = resnet.resnet50(f"{path}/emlnet/res_places.pth").cuda().eval()
    decoder_model = (
        decoder.build_decoder(f"{path}/emlnet/res_decoder.pth", sod_res, 5, 5)
        .cuda()
        .eval()
    )

    # The models live on the GPU, so the input batch must too; ``.cuda()``
    # is a no-op for tensors that are already on the current CUDA device.
    image_sod = resize(image, sod_res).unsqueeze(0).cuda()

    with torch.no_grad():
        imagenet_feat = imagenet_model(image_sod, decode=True)
        places_feat = places_model(image_sod, decode=True)
        smap = decoder_model([imagenet_feat, places_feat])

    # Drop the batch dimension, bring the prediction back to the CPU and
    # restore the spatial size of the original image.
    smap = resize(smap.squeeze(0).detach().cpu(), image.shape[1:]).squeeze(0)

    return _post_process(smap.numpy(), filter_size).astype(np.float32)