Commit: Initial Update
This view is limited to 50 files because it contains too many changes.
- README.md +5 -5
- SaRa/__pycache__/pySaliencyMap.cpython-39.pyc +0 -0
- SaRa/__pycache__/pySaliencyMapDefs.cpython-39.pyc +0 -0
- SaRa/__pycache__/saraRC1.cpython-39.pyc +0 -0
- SaRa/pySaliencyMap.py +288 -0
- SaRa/pySaliencyMapDefs.py +74 -0
- SaRa/saraRC1.py +1082 -0
- app.py +184 -0
- deepgaze_pytorch/__init__.py +3 -0
- deepgaze_pytorch/__pycache__/__init__.cpython-39.pyc +0 -0
- deepgaze_pytorch/__pycache__/deepgaze1.cpython-39.pyc +0 -0
- deepgaze_pytorch/__pycache__/deepgaze2e.cpython-39.pyc +0 -0
- deepgaze_pytorch/__pycache__/deepgaze3.cpython-39.pyc +0 -0
- deepgaze_pytorch/__pycache__/layers.cpython-39.pyc +0 -0
- deepgaze_pytorch/__pycache__/modules.cpython-39.pyc +0 -0
- deepgaze_pytorch/data.py +403 -0
- deepgaze_pytorch/deepgaze1.py +42 -0
- deepgaze_pytorch/deepgaze2e.py +151 -0
- deepgaze_pytorch/deepgaze3.py +110 -0
- deepgaze_pytorch/features/__init__.py +0 -0
- deepgaze_pytorch/features/__pycache__/__init__.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/__pycache__/alexnet.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/__pycache__/densenet.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/__pycache__/efficientnet.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/__pycache__/normalizer.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/__pycache__/resnext.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/__pycache__/shapenet.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/alexnet.py +18 -0
- deepgaze_pytorch/features/bagnet.py +192 -0
- deepgaze_pytorch/features/densenet.py +19 -0
- deepgaze_pytorch/features/efficientnet.py +31 -0
- deepgaze_pytorch/features/efficientnet_pytorch/__init__.py +10 -0
- deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/__init__.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/model.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/utils.cpython-39.pyc +0 -0
- deepgaze_pytorch/features/efficientnet_pytorch/model.py +229 -0
- deepgaze_pytorch/features/efficientnet_pytorch/utils.py +335 -0
- deepgaze_pytorch/features/inception.py +20 -0
- deepgaze_pytorch/features/mobilenet.py +17 -0
- deepgaze_pytorch/features/normalizer.py +28 -0
- deepgaze_pytorch/features/resnet.py +44 -0
- deepgaze_pytorch/features/resnext.py +27 -0
- deepgaze_pytorch/features/shapenet.py +89 -0
- deepgaze_pytorch/features/squeezenet.py +17 -0
- deepgaze_pytorch/features/swav.py +20 -0
- deepgaze_pytorch/features/uninformative.py +26 -0
- deepgaze_pytorch/features/vgg.py +86 -0
- deepgaze_pytorch/features/vggnet.py +24 -0
- deepgaze_pytorch/features/wsl.py +27 -0
- deepgaze_pytorch/layers.py +427 -0
README.md CHANGED
@@ -1,13 +1,13 @@
 ---
 title: Saliency Ranking
-emoji:
-colorFrom:
-colorTo:
+emoji: 📚
+colorFrom: red
+colorTo: indigo
 sdk: gradio
-sdk_version:
+sdk_version: 4.39.0
 app_file: app.py
 pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SaRa/__pycache__/pySaliencyMap.cpython-39.pyc ADDED
Binary file (7.79 kB).

SaRa/__pycache__/pySaliencyMapDefs.cpython-39.pyc ADDED
Binary file (2.01 kB).

SaRa/__pycache__/saraRC1.cpython-39.pyc ADDED
Binary file (18.5 kB).
SaRa/pySaliencyMap.py ADDED
@@ -0,0 +1,288 @@
#-------------------------------------------------------------------------------
# Name: pySaliencyMap
# Purpose: Extracting a saliency map from a single still image
#
# Author: Akisato Kimura <akisato@ieee.org>
#
# Created: April 24, 2014
# Copyright: (c) Akisato Kimura 2014-
# Licence: All rights reserved
#-------------------------------------------------------------------------------

import cv2
import numpy as np
import SaRa.pySaliencyMapDefs as pySaliencyMapDefs
import time

class pySaliencyMap:
    # initialization
    def __init__(self, width, height):
        self.width = width
        self.height = height
        self.prev_frame = None
        self.SM = None
        self.GaborKernel0 = np.array(pySaliencyMapDefs.GaborKernel_0)
        self.GaborKernel45 = np.array(pySaliencyMapDefs.GaborKernel_45)
        self.GaborKernel90 = np.array(pySaliencyMapDefs.GaborKernel_90)
        self.GaborKernel135 = np.array(pySaliencyMapDefs.GaborKernel_135)

    # extracting color channels
    def SMExtractRGBI(self, inputImage):
        # convert scale of array elements
        src = np.float32(inputImage) * 1./255
        # split
        (B, G, R) = cv2.split(src)
        # extract an intensity image
        I = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
        # return
        return R, G, B, I

    # feature maps
    ## constructing a Gaussian pyramid
    def FMCreateGaussianPyr(self, src):
        dst = list()
        dst.append(src)
        for i in range(1, 9):
            nowdst = cv2.pyrDown(dst[i-1])
            dst.append(nowdst)
        return dst
    ## taking center-surround differences
    def FMCenterSurroundDiff(self, GaussianMaps):
        dst = list()
        for s in range(2, 5):
            now_size = GaussianMaps[s].shape
            now_size = (now_size[1], now_size[0])  ## (width, height)
            tmp = cv2.resize(GaussianMaps[s+3], now_size, interpolation=cv2.INTER_LINEAR)
            nowdst = cv2.absdiff(GaussianMaps[s], tmp)
            dst.append(nowdst)
            tmp = cv2.resize(GaussianMaps[s+4], now_size, interpolation=cv2.INTER_LINEAR)
            nowdst = cv2.absdiff(GaussianMaps[s], tmp)
            dst.append(nowdst)
        return dst
    ## constructing a Gaussian pyramid + taking center-surround differences
    def FMGaussianPyrCSD(self, src):
        GaussianMaps = self.FMCreateGaussianPyr(src)
        dst = self.FMCenterSurroundDiff(GaussianMaps)
        return dst
    ## intensity feature maps
    def IFMGetFM(self, I):
        return self.FMGaussianPyrCSD(I)
    ## color feature maps
    def CFMGetFM(self, R, G, B):
        # max(R,G,B)
        tmp1 = cv2.max(R, G)
        RGBMax = cv2.max(B, tmp1)
        RGBMax[RGBMax <= 0] = 0.0001  # prevent dividing by 0
        # min(R,G)
        RGMin = cv2.min(R, G)
        # RG = (R-G)/max(R,G,B)
        RG = (R - G) / RGBMax
        # BY = (B-min(R,G))/max(R,G,B)
        BY = (B - RGMin) / RGBMax
        # clamp negative values to 0
        RG[RG < 0] = 0
        BY[BY < 0] = 0
        # obtain feature maps in the same way as intensity
        RGFM = self.FMGaussianPyrCSD(RG)
        BYFM = self.FMGaussianPyrCSD(BY)
        # return
        return RGFM, BYFM
    ## orientation feature maps
    def OFMGetFM(self, src):
        # creating a Gaussian pyramid
        GaussianI = self.FMCreateGaussianPyr(src)
        # convolving a Gabor filter with an intensity image to extract orientation features
        GaborOutput0 = [ np.empty((1,1)), np.empty((1,1)) ]  # dummy data: any kinds of np.array()s are OK
        GaborOutput45 = [ np.empty((1,1)), np.empty((1,1)) ]
        GaborOutput90 = [ np.empty((1,1)), np.empty((1,1)) ]
        GaborOutput135 = [ np.empty((1,1)), np.empty((1,1)) ]
        for j in range(2, 9):
            GaborOutput0.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel0) )
            GaborOutput45.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel45) )
            GaborOutput90.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel90) )
            GaborOutput135.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel135) )
        # calculating center-surround differences for every orientation
        CSD0 = self.FMCenterSurroundDiff(GaborOutput0)
        CSD45 = self.FMCenterSurroundDiff(GaborOutput45)
        CSD90 = self.FMCenterSurroundDiff(GaborOutput90)
        CSD135 = self.FMCenterSurroundDiff(GaborOutput135)
        # concatenate
        dst = list(CSD0)
        dst.extend(CSD45)
        dst.extend(CSD90)
        dst.extend(CSD135)
        # return
        return dst
    ## motion feature maps
    def MFMGetFM(self, src):
        # convert scale
        I8U = np.uint8(255 * src)
        # cv2.waitKey(10)
        # calculating optical flows
        if self.prev_frame is not None:
            farne_pyr_scale = pySaliencyMapDefs.farne_pyr_scale
            farne_levels = pySaliencyMapDefs.farne_levels
            farne_winsize = pySaliencyMapDefs.farne_winsize
            farne_iterations = pySaliencyMapDefs.farne_iterations
            farne_poly_n = pySaliencyMapDefs.farne_poly_n
            farne_poly_sigma = pySaliencyMapDefs.farne_poly_sigma
            farne_flags = pySaliencyMapDefs.farne_flags
            flow = cv2.calcOpticalFlowFarneback(\
                prev = self.prev_frame, \
                next = I8U, \
                pyr_scale = farne_pyr_scale, \
                levels = farne_levels, \
                winsize = farne_winsize, \
                iterations = farne_iterations, \
                poly_n = farne_poly_n, \
                poly_sigma = farne_poly_sigma, \
                flags = farne_flags, \
                flow = None \
            )
            flowx = flow[..., 0]
            flowy = flow[..., 1]
        else:
            flowx = np.zeros(I8U.shape)
            flowy = np.zeros(I8U.shape)
        # create Gaussian pyramids
        dst_x = self.FMGaussianPyrCSD(flowx)
        dst_y = self.FMGaussianPyrCSD(flowy)
        # update the current frame
        self.prev_frame = np.uint8(I8U)
        # return
        return dst_x, dst_y

    # conspicuity maps
    ## standard range normalization
    def SMRangeNormalize(self, src):
        minn, maxx, dummy1, dummy2 = cv2.minMaxLoc(src)
        if maxx != minn:
            dst = src/(maxx-minn) + minn/(minn-maxx)
        else:
            dst = src - minn
        return dst
    ## computing an average of local maxima
    def SMAvgLocalMax(self, src):
        # size
        stepsize = pySaliencyMapDefs.default_step_local
        width = src.shape[1]
        height = src.shape[0]
        # find local maxima
        numlocal = 0
        lmaxmean = 0
        for y in range(0, height-stepsize, stepsize):
            for x in range(0, width-stepsize, stepsize):
                localimg = src[y:y+stepsize, x:x+stepsize]
                lmin, lmax, dummy1, dummy2 = cv2.minMaxLoc(localimg)
                lmaxmean += lmax
                numlocal += 1
        # averaging over all the local regions (error checking for numlocal)
        if numlocal == 0:
            return 0
        else:
            return lmaxmean / numlocal
    ## normalization specific for the saliency map model
    def SMNormalization(self, src):
        dst = self.SMRangeNormalize(src)
        lmaxmean = self.SMAvgLocalMax(dst)
        normcoeff = (1-lmaxmean)*(1-lmaxmean)
        return dst * normcoeff
    ## normalizing feature maps
    def normalizeFeatureMaps(self, FM):
        NFM = list()
        for i in range(0, 6):
            normalizedImage = self.SMNormalization(FM[i])
            nownfm = cv2.resize(normalizedImage, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
            NFM.append(nownfm)
        return NFM
    ## intensity conspicuity map
    def ICMGetCM(self, IFM):
        NIFM = self.normalizeFeatureMaps(IFM)
        ICM = sum(NIFM)
        return ICM
    ## color conspicuity map
    def CCMGetCM(self, CFM_RG, CFM_BY):
        # extracting a conspicuity map for every color opponent pair
        CCM_RG = self.ICMGetCM(CFM_RG)
        CCM_BY = self.ICMGetCM(CFM_BY)
        # merge
        CCM = CCM_RG + CCM_BY
        # return
        return CCM
    ## orientation conspicuity map
    def OCMGetCM(self, OFM):
        OCM = np.zeros((self.height, self.width))
        for i in range(0, 4):
            # slicing
            nowofm = OFM[i*6:(i+1)*6]  # angle = i*45
            # extracting a conspicuity map for every angle
            NOFM = self.ICMGetCM(nowofm)
            # normalize
            NOFM2 = self.SMNormalization(NOFM)
            # accumulate
            OCM += NOFM2
        return OCM
    ## motion conspicuity map
    def MCMGetCM(self, MFM_X, MFM_Y):
        return self.CCMGetCM(MFM_X, MFM_Y)

    # core
    def SMGetSM(self, src):
        # definitions
        size = src.shape
        width = size[1]
        height = size[0]
        # check
        # if(width != self.width or height != self.height):
        #     sys.exit("size mismatch")
        # extracting individual color channels
        R, G, B, I = self.SMExtractRGBI(src)
        # extracting feature maps
        IFM = self.IFMGetFM(I)
        CFM_RG, CFM_BY = self.CFMGetFM(R, G, B)
        OFM = self.OFMGetFM(I)
        MFM_X, MFM_Y = self.MFMGetFM(I)
        # extracting conspicuity maps
        ICM = self.ICMGetCM(IFM)
        CCM = self.CCMGetCM(CFM_RG, CFM_BY)
        OCM = self.OCMGetCM(OFM)
        MCM = self.MCMGetCM(MFM_X, MFM_Y)
        # adding all the conspicuity maps to form a saliency map
        wi = pySaliencyMapDefs.weight_intensity
        wc = pySaliencyMapDefs.weight_color
        wo = pySaliencyMapDefs.weight_orientation
        wm = pySaliencyMapDefs.weight_motion
        SMMat = wi*ICM + wc*CCM + wo*OCM + wm*MCM
        # normalize
        normalizedSM = self.SMRangeNormalize(SMMat)
        normalizedSM2 = normalizedSM.astype(np.float32)
        smoothedSM = cv2.bilateralFilter(normalizedSM2, 7, 3, 1.55)
        self.SM = cv2.resize(smoothedSM, (width, height), interpolation=cv2.INTER_NEAREST)
        # return
        return self.SM

    def SMGetBinarizedSM(self, src):
        # get a saliency map
        if self.SM is None:
            self.SM = self.SMGetSM(src)
        # convert scale
        SM_I8U = np.uint8(255 * self.SM)
        # binarize
        thresh, binarized_SM = cv2.threshold(SM_I8U, thresh=0, maxval=255, type=cv2.THRESH_BINARY+cv2.THRESH_OTSU)
        return binarized_SM

    def SMGetSalientRegion(self, src):
        # get a binarized saliency map
        binarized_SM = self.SMGetBinarizedSM(src)
        # GrabCut
        img = src.copy()
        mask = np.where((binarized_SM != 0), cv2.GC_PR_FGD, cv2.GC_PR_BGD).astype('uint8')
        bgdmodel = np.zeros((1, 65), np.float64)
        fgdmodel = np.zeros((1, 65), np.float64)
        rect = (0, 0, 1, 1)  # dummy
        iterCount = 1
        cv2.grabCut(img, mask=mask, rect=rect, bgdModel=bgdmodel, fgdModel=fgdmodel, iterCount=iterCount, mode=cv2.GC_INIT_WITH_MASK)
        # post-processing
        mask_out = np.where((mask == cv2.GC_FGD) + (mask == cv2.GC_PR_FGD), 255, 0).astype('uint8')
        output = cv2.bitwise_and(img, img, mask=mask_out)
        return output
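For orientation while reading this diff, the snippet below is a minimal usage sketch of the class added above; it is not part of the commit, and 'input.jpg' is a placeholder path.

import cv2
from SaRa.pySaliencyMap import pySaliencyMap

# Build the generator at the image's own resolution (the constructor takes width, height).
img = cv2.imread('input.jpg')              # BGR image; placeholder path
height, width = img.shape[:2]
sm = pySaliencyMap(width, height)

# SMGetSM returns a float-valued map; rescale it to 0-255 for display,
# the same way saraRC1.py does before segmenting it.
saliency = sm.SMGetSM(img)
saliency_u8 = cv2.normalize(saliency, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
cv2.imwrite('saliency.png', saliency_u8)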
SaRa/pySaliencyMapDefs.py ADDED
@@ -0,0 +1,74 @@
#-------------------------------------------------------------------------------
# Name: pySaliencyMapDefs
# Purpose: Definitions for class pySaliencyMap
#
# Author: Akisato Kimura <akisato@ieee.org>
#
# Created: April 24, 2014
# Copyright: (c) Akisato Kimura 2014-
# Licence: All rights reserved
#-------------------------------------------------------------------------------

# parameters for computing optical flows using Gunnar Farneback's algorithm
farne_pyr_scale = 0.5
farne_levels = 3
farne_winsize = 15
farne_iterations = 3
farne_poly_n = 5
farne_poly_sigma = 1.2
farne_flags = 0

# parameters for detecting local maxima
default_step_local = 16

# feature weights
weight_intensity = 0.30
weight_color = 0.30
weight_orientation = 0.20
weight_motion = 0.20

# coefficients of Gabor filters
GaborKernel_0 = [\
    [ 1.85212E-06, 1.28181E-05, -0.000350433, -0.000136537, 0.002010422, -0.000136537, -0.000350433, 1.28181E-05, 1.85212E-06 ],\
    [ 2.80209E-05, 0.000193926, -0.005301717, -0.002065674, 0.030415784, -0.002065674, -0.005301717, 0.000193926, 2.80209E-05 ],\
    [ 0.000195076, 0.001350077, -0.036909595, -0.014380852, 0.211749204, -0.014380852, -0.036909595, 0.001350077, 0.000195076 ],\
    [ 0.000624940, 0.004325061, -0.118242318, -0.046070008, 0.678352526, -0.046070008, -0.118242318, 0.004325061, 0.000624940 ],\
    [ 0.000921261, 0.006375831, -0.174308068, -0.067914552, 1.000000000, -0.067914552, -0.174308068, 0.006375831, 0.000921261 ],\
    [ 0.000624940, 0.004325061, -0.118242318, -0.046070008, 0.678352526, -0.046070008, -0.118242318, 0.004325061, 0.000624940 ],\
    [ 0.000195076, 0.001350077, -0.036909595, -0.014380852, 0.211749204, -0.014380852, -0.036909595, 0.001350077, 0.000195076 ],\
    [ 2.80209E-05, 0.000193926, -0.005301717, -0.002065674, 0.030415784, -0.002065674, -0.005301717, 0.000193926, 2.80209E-05 ],\
    [ 1.85212E-06, 1.28181E-05, -0.000350433, -0.000136537, 0.002010422, -0.000136537, -0.000350433, 1.28181E-05, 1.85212E-06 ]\
]
GaborKernel_45 = [\
    [ 4.04180E-06, 2.25320E-05, -0.000279806, -0.001028923, 3.79931E-05, 0.000744712, 0.000132863, -9.04408E-06, -1.01551E-06 ],\
    [ 2.25320E-05, 0.000925120, 0.002373205, -0.013561362, -0.022947700, 0.000389916, 0.003516954, 0.000288732, -9.04408E-06 ],\
    [ -0.000279806, 0.002373205, 0.044837725, 0.052928748, -0.139178011, -0.108372072, 0.000847346, 0.003516954, 0.000132863 ],\
    [ -0.001028923, -0.013561362, 0.052928748, 0.460162150, 0.249959607, -0.302454279, -0.108372072, 0.000389916, 0.000744712 ],\
    [ 3.79931E-05, -0.022947700, -0.139178011, 0.249959607, 1.000000000, 0.249959607, -0.139178011, -0.022947700, 3.79931E-05 ],\
    [ 0.000744712, 0.003899160, -0.108372072, -0.302454279, 0.249959607, 0.460162150, 0.052928748, -0.013561362, -0.001028923 ],\
    [ 0.000132863, 0.003516954, 0.000847346, -0.108372072, -0.139178011, 0.052928748, 0.044837725, 0.002373205, -0.000279806 ],\
    [ -9.04408E-06, 0.000288732, 0.003516954, 0.000389916, -0.022947700, -0.013561362, 0.002373205, 0.000925120, 2.25320E-05 ],\
    [ -1.01551E-06, -9.04408E-06, 0.000132863, 0.000744712, 3.79931E-05, -0.001028923, -0.000279806, 2.25320E-05, 4.04180E-06 ]\
]
GaborKernel_90 = [\
    [ 1.85212E-06, 2.80209E-05, 0.000195076, 0.000624940, 0.000921261, 0.000624940, 0.000195076, 2.80209E-05, 1.85212E-06 ],\
    [ 1.28181E-05, 0.000193926, 0.001350077, 0.004325061, 0.006375831, 0.004325061, 0.001350077, 0.000193926, 1.28181E-05 ],\
    [ -0.000350433, -0.005301717, -0.036909595, -0.118242318, -0.174308068, -0.118242318, -0.036909595, -0.005301717, -0.000350433 ],\
    [ -0.000136537, -0.002065674, -0.014380852, -0.046070008, -0.067914552, -0.046070008, -0.014380852, -0.002065674, -0.000136537 ],\
    [ 0.002010422, 0.030415784, 0.211749204, 0.678352526, 1.000000000, 0.678352526, 0.211749204, 0.030415784, 0.002010422 ],\
    [ -0.000136537, -0.002065674, -0.014380852, -0.046070008, -0.067914552, -0.046070008, -0.014380852, -0.002065674, -0.000136537 ],\
    [ -0.000350433, -0.005301717, -0.036909595, -0.118242318, -0.174308068, -0.118242318, -0.036909595, -0.005301717, -0.000350433 ],\
    [ 1.28181E-05, 0.000193926, 0.001350077, 0.004325061, 0.006375831, 0.004325061, 0.001350077, 0.000193926, 1.28181E-05 ],\
    [ 1.85212E-06, 2.80209E-05, 0.000195076, 0.000624940, 0.000921261, 0.000624940, 0.000195076, 2.80209E-05, 1.85212E-06 ]
]
GaborKernel_135 = [\
    [ -1.01551E-06, -9.04408E-06, 0.000132863, 0.000744712, 3.79931E-05, -0.001028923, -0.000279806, 2.2532E-05, 4.0418E-06 ],\
    [ -9.04408E-06, 0.000288732, 0.003516954, 0.000389916, -0.022947700, -0.013561362, 0.002373205, 0.00092512, 2.2532E-05 ],\
    [ 0.000132863, 0.003516954, 0.000847346, -0.108372072, -0.139178011, 0.052928748, 0.044837725, 0.002373205, -0.000279806 ],\
    [ 0.000744712, 0.000389916, -0.108372072, -0.302454279, 0.249959607, 0.46016215, 0.052928748, -0.013561362, -0.001028923 ],\
    [ 3.79931E-05, -0.022947700, -0.139178011, 0.249959607, 1.000000000, 0.249959607, -0.139178011, -0.0229477, 3.79931E-05 ],\
    [ -0.001028923, -0.013561362, 0.052928748, 0.460162150, 0.249959607, -0.302454279, -0.108372072, 0.000389916, 0.000744712 ],\
    [ -0.000279806, 0.002373205, 0.044837725, 0.052928748, -0.139178011, -0.108372072, 0.000847346, 0.003516954, 0.000132863 ],\
    [ 2.25320E-05, 0.000925120, 0.002373205, -0.013561362, -0.022947700, 0.000389916, 0.003516954, 0.000288732, -9.04408E-06 ],\
    [ 4.04180E-06, 2.25320E-05, -0.000279806, -0.001028923, 3.79931E-05, 0.000744712, 0.000132863, -9.04408E-06, -1.01551E-06 ]\
]
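As a brief illustration (not part of the commit), the constants defined above are consumed by pySaliencyMap.py roughly as sketched below; the image path is a placeholder.

import cv2
import numpy as np
import SaRa.pySaliencyMapDefs as defs

# The Gabor coefficients are plain nested lists; pySaliencyMap wraps them in
# np.array and convolves them with levels of the intensity Gaussian pyramid.
kernel0 = np.array(defs.GaborKernel_0)
gray = np.float32(cv2.imread('input.jpg', cv2.IMREAD_GRAYSCALE)) / 255.0   # placeholder path
response0 = cv2.filter2D(gray, cv2.CV_32F, kernel0)

# The four weight_* constants sum to 1.0 and fix the contribution of the
# intensity, color, orientation and motion conspicuity maps in SMGetSM.
total = defs.weight_intensity + defs.weight_color + defs.weight_orientation + defs.weight_motion
print(total)  # 1.0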
SaRa/saraRC1.py ADDED
@@ -0,0 +1,1082 @@
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import math
|
| 4 |
+
import scipy.stats as st
|
| 5 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 6 |
+
from matplotlib.lines import Line2D
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import operator
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
from enum import Enum
|
| 12 |
+
import pandas as pd
|
| 13 |
+
|
| 14 |
+
# Akisato Kimura <akisato@ieee.org> implementation of Itti's Saliency Map Generator -- https://github.com/akisatok/pySaliencyMap
|
| 15 |
+
from SaRa.pySaliencyMap import pySaliencyMap
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Global Variables
|
| 19 |
+
|
| 20 |
+
# Entropy, sum, depth, centre-bias
|
| 21 |
+
WEIGHTS = (1, 1, 1, 1)
|
| 22 |
+
|
| 23 |
+
# segments_entropies = []
|
| 24 |
+
segments_scores = []
|
| 25 |
+
segments_coords = []
|
| 26 |
+
|
| 27 |
+
seg_dim = 0
|
| 28 |
+
segments = []
|
| 29 |
+
gt_segments = []
|
| 30 |
+
dws = []
|
| 31 |
+
sara_list = []
|
| 32 |
+
|
| 33 |
+
eval_list = []
|
| 34 |
+
labels_eval_list = ['Image', 'Index', 'Rank', 'Quartile', 'isGT', 'Outcome']
|
| 35 |
+
|
| 36 |
+
outcome_list = []
|
| 37 |
+
labels_outcome_list = ['Image', 'FN', 'FP', 'TN', 'TP']
|
| 38 |
+
|
| 39 |
+
dataframe_collection = {}
|
| 40 |
+
error_count = 0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# SaRa Initial Functions
|
| 44 |
+
def generate_segments(img, seg_count) -> list:
|
| 45 |
+
'''
|
| 46 |
+
Given an image img and the desired number of segments seg_count, this
|
| 47 |
+
function divides the image into segments and returns a list of segments.
|
| 48 |
+
'''
|
| 49 |
+
|
| 50 |
+
segments = []
|
| 51 |
+
segment_count = seg_count
|
| 52 |
+
index = 0
|
| 53 |
+
|
| 54 |
+
w_interval = int(img.shape[1] / segment_count)
|
| 55 |
+
h_interval = int(img.shape[0] / segment_count)
|
| 56 |
+
|
| 57 |
+
for i in range(segment_count):
|
| 58 |
+
for j in range(segment_count):
|
| 59 |
+
temp_segment = img[int(h_interval * i):int(h_interval * (i + 1)),
|
| 60 |
+
int(w_interval * j):int(w_interval * (j + 1))]
|
| 61 |
+
segments.append(temp_segment)
|
| 62 |
+
|
| 63 |
+
coord_tup = (index, int(w_interval * j), int(h_interval * i),
|
| 64 |
+
int(w_interval * (j + 1)), int(h_interval * (i + 1)))
|
| 65 |
+
segments_coords.append(coord_tup)
|
| 66 |
+
|
| 67 |
+
index += 1
|
| 68 |
+
|
| 69 |
+
return segments
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def return_saliency(img, generator='itti', deepgaze_model=None, emlnet_models=None, DEVICE='cpu'):
|
| 73 |
+
'''
|
| 74 |
+
Takes an image img as input and calculates the saliency map using the
|
| 75 |
+
Itti's Saliency Map Generator. It returns the saliency map.
|
| 76 |
+
'''
|
| 77 |
+
|
| 78 |
+
img_width, img_height = img.shape[1], img.shape[0]
|
| 79 |
+
|
| 80 |
+
if generator == 'itti':
|
| 81 |
+
|
| 82 |
+
sm = pySaliencyMap(img_width, img_height)
|
| 83 |
+
saliency_map = sm.SMGetSM(img)
|
| 84 |
+
|
| 85 |
+
# Scale pixel values to 0-255 instead of float (approx 0, hence black image)
|
| 86 |
+
# https://stackoverflow.com/questions/48331211/how-to-use-cv2-imshow-correctly-for-the-float-image-returned-by-cv2-distancet/48333272
|
| 87 |
+
saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
|
| 88 |
+
elif generator == 'deepgaze':
|
| 89 |
+
import numpy as np
|
| 90 |
+
from scipy.misc import face
|
| 91 |
+
from scipy.ndimage import zoom
|
| 92 |
+
from scipy.special import logsumexp
|
| 93 |
+
import torch
|
| 94 |
+
|
| 95 |
+
import deepgaze_pytorch
|
| 96 |
+
|
| 97 |
+
# you can use DeepGazeI or DeepGazeIIE
|
| 98 |
+
# model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
|
| 99 |
+
|
| 100 |
+
if deepgaze_model is None:
|
| 101 |
+
model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
|
| 102 |
+
else:
|
| 103 |
+
model = deepgaze_model
|
| 104 |
+
|
| 105 |
+
# image = face()
|
| 106 |
+
image = img
|
| 107 |
+
|
| 108 |
+
# load precomputed centerbias log density (from MIT1003) over a 1024x1024 image
|
| 109 |
+
# you can download the centerbias from https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/centerbias_mit1003.npy
|
| 110 |
+
# alternatively, you can use a uniform centerbias via `centerbias_template = np.zeros((1024, 1024))`.
|
| 111 |
+
# centerbias_template = np.load('centerbias_mit1003.npy')
|
| 112 |
+
centerbias_template = np.zeros((1024, 1024))
|
| 113 |
+
# rescale to match image size
|
| 114 |
+
centerbias = zoom(centerbias_template, (image.shape[0]/centerbias_template.shape[0], image.shape[1]/centerbias_template.shape[1]), order=0, mode='nearest')
|
| 115 |
+
# renormalize log density
|
| 116 |
+
centerbias -= logsumexp(centerbias)
|
| 117 |
+
|
| 118 |
+
image_tensor = torch.tensor([image.transpose(2, 0, 1)]).to(DEVICE)
|
| 119 |
+
centerbias_tensor = torch.tensor([centerbias]).to(DEVICE)
|
| 120 |
+
|
| 121 |
+
log_density_prediction = model(image_tensor, centerbias_tensor)
|
| 122 |
+
|
| 123 |
+
saliency_map = cv2.resize(log_density_prediction.detach().cpu().numpy()[0, 0], (img_width, img_height))
|
| 124 |
+
|
| 125 |
+
elif generator == 'fpn':
|
| 126 |
+
# Add ./fpn to the system path
|
| 127 |
+
import sys
|
| 128 |
+
sys.path.append('./fpn')
|
| 129 |
+
import inference as inf
|
| 130 |
+
|
| 131 |
+
results_dict = {}
|
| 132 |
+
rt_args = inf.parse_arguments(img)
|
| 133 |
+
|
| 134 |
+
# Call the run_inference function and capture the results
|
| 135 |
+
pred_masks_raw_list, pred_masks_round_list = inf.run_inference(rt_args)
|
| 136 |
+
|
| 137 |
+
# Store the results in the dictionary
|
| 138 |
+
results_dict['pred_masks_raw'] = pred_masks_raw_list
|
| 139 |
+
results_dict['pred_masks_round'] = pred_masks_round_list
|
| 140 |
+
|
| 141 |
+
saliency_map = results_dict['pred_masks_raw']
|
| 142 |
+
|
| 143 |
+
if img_width > img_height:
|
| 144 |
+
saliency_map = cv2.resize(saliency_map, (img_width, img_width))
|
| 145 |
+
|
| 146 |
+
diff = (img_width - img_height) // 2
|
| 147 |
+
|
| 148 |
+
saliency_map = saliency_map[diff:img_width - diff, 0:img_width]
|
| 149 |
+
else:
|
| 150 |
+
saliency_map = cv2.resize(saliency_map, (img_height, img_height))
|
| 151 |
+
|
| 152 |
+
diff = (img_height - img_width) // 2
|
| 153 |
+
|
| 154 |
+
saliency_map = saliency_map[0:img_height, diff:img_height - diff]
|
| 155 |
+
|
| 156 |
+
elif generator == 'emlnet':
|
| 157 |
+
from emlnet.eval_combined import main as eval_combined
|
| 158 |
+
saliency_map = eval_combined(img, emlnet_models)
|
| 159 |
+
|
| 160 |
+
# Resize to image size
|
| 161 |
+
saliency_map = cv2.resize(saliency_map, (img_width, img_height))
|
| 162 |
+
|
| 163 |
+
# Normalize saliency map
|
| 164 |
+
saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
|
| 165 |
+
|
| 166 |
+
saliency_map = cv2.GaussianBlur(saliency_map, (31, 31), 10)
|
| 167 |
+
return saliency_map
|
| 168 |
+
saliency_map = saliency_map // 16
|
| 169 |
+
|
| 170 |
+
return saliency_map
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def return_saliency_batch(images, generator='deepgaze', deepgaze_model=None, emlnet_models=None, DEVICE='cuda', BATCH_SIZE=1):
|
| 174 |
+
img_widths, img_heights = [], []
|
| 175 |
+
if generator == 'deepgaze':
|
| 176 |
+
import numpy as np
|
| 177 |
+
from scipy.misc import face
|
| 178 |
+
from scipy.ndimage import zoom
|
| 179 |
+
from scipy.special import logsumexp
|
| 180 |
+
import torch
|
| 181 |
+
|
| 182 |
+
import deepgaze_pytorch
|
| 183 |
+
|
| 184 |
+
# you can use DeepGazeI or DeepGazeIIE
|
| 185 |
+
# model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
|
| 186 |
+
|
| 187 |
+
if deepgaze_model is None:
|
| 188 |
+
model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
|
| 189 |
+
else:
|
| 190 |
+
model = deepgaze_model
|
| 191 |
+
|
| 192 |
+
# image = face()
|
| 193 |
+
# image = img
|
| 194 |
+
image_batch = torch.tensor([img.transpose(2, 0, 1) for img in images]).to(DEVICE)
|
| 195 |
+
centerbias_template = np.zeros((1024, 1024))
|
| 196 |
+
centerbias_tensors = []
|
| 197 |
+
|
| 198 |
+
for img in images:
|
| 199 |
+
centerbias = zoom(centerbias_template, (img.shape[0] / centerbias_template.shape[0], img.shape[1] / centerbias_template.shape[1]), order=0, mode='nearest')
|
| 200 |
+
centerbias -= logsumexp(centerbias)
|
| 201 |
+
centerbias_tensors.append(torch.tensor(centerbias).to(DEVICE))
|
| 202 |
+
|
| 203 |
+
# Set img_width and img_height
|
| 204 |
+
img_widths.append(img.shape[1])
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# rescale to match image size
|
| 208 |
+
# centerbias = zoom(centerbias_template, (image.shape[0]/centerbias_template.shape[0], image.shape[1]/centerbias_template.shape[1]), order=0, mode='nearest')
|
| 209 |
+
# # renormalize log density
|
| 210 |
+
# centerbias -= logsumexp(centerbias)
|
| 211 |
+
|
| 212 |
+
# image_tensor = torch.tensor([image.transpose(2, 0, 1)]).to(DEVICE)
|
| 213 |
+
# centerbias_tensor = torch.tensor([centerbias]).to(DEVICE)
|
| 214 |
+
with torch.no_grad():
|
| 215 |
+
# Process the batch of images in one forward pass
|
| 216 |
+
log_density_predictions = model(image_batch, torch.stack(centerbias_tensors))
|
| 217 |
+
|
| 218 |
+
# log_density_prediction = model(image_tensor, centerbias_tensor)
|
| 219 |
+
|
| 220 |
+
# saliency_map = cv2.resize(log_density_prediction.detach().cpu().numpy()[0, 0], (img_width, img_height))
|
| 221 |
+
|
| 222 |
+
saliency_maps = []
|
| 223 |
+
|
| 224 |
+
for i in range(len(images)):
|
| 225 |
+
saliency_map = cv2.resize(log_density_predictions[i, 0].cpu().numpy(), (img_widths[i], img_widths[i]))
|
| 226 |
+
|
| 227 |
+
saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
|
| 228 |
+
|
| 229 |
+
saliency_map = cv2.GaussianBlur(saliency_map, (31, 31), 10)
|
| 230 |
+
saliency_map = saliency_map // 16
|
| 231 |
+
|
| 232 |
+
saliency_maps.append(saliency_map)
|
| 233 |
+
|
| 234 |
+
return saliency_maps
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# def return_itti_saliency(img):
|
| 238 |
+
# '''
|
| 239 |
+
# Takes an image img as input and calculates the saliency map using the
|
| 240 |
+
# Itti's Saliency Map Generator. It returns the saliency map.
|
| 241 |
+
# '''
|
| 242 |
+
|
| 243 |
+
# img_width, img_height = img.shape[1], img.shape[0]
|
| 244 |
+
|
| 245 |
+
# sm = pySaliencyMap.pySaliencyMap(img_width, img_height)
|
| 246 |
+
# saliency_map = sm.SMGetSM(img)
|
| 247 |
+
|
| 248 |
+
# # Scale pixel values to 0-255 instead of float (approx 0, hence black image)
|
| 249 |
+
# # https://stackoverflow.com/questions/48331211/how-to-use-cv2-imshow-correctly-for-the-float-image-returned-by-cv2-distancet/48333272
|
| 250 |
+
# saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
|
| 251 |
+
|
| 252 |
+
# return saliency_map
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# Saliency Ranking
|
| 256 |
+
def calculate_pixel_frequency(img) -> dict:
|
| 257 |
+
'''
|
| 258 |
+
Calculates the frequency of each pixel value in the image img and
|
| 259 |
+
returns a dictionary containing the pixel frequencies.
|
| 260 |
+
'''
|
| 261 |
+
|
| 262 |
+
flt = img.flatten()
|
| 263 |
+
unique, counts = np.unique(flt, return_counts=True)
|
| 264 |
+
pixels_frequency = dict(zip(unique, counts))
|
| 265 |
+
|
| 266 |
+
return pixels_frequency
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def calculate_score(H, sum, ds, cb, w):
|
| 270 |
+
'''
|
| 271 |
+
Calculates the saliency score of an image img using the entropy H, depth score ds, centre-bias cb and weights w. It returns the saliency score.
|
| 272 |
+
'''
|
| 273 |
+
|
| 274 |
+
# Normalise H
|
| 275 |
+
# H = (H - 0) / (math.log(2, 256) - 0)
|
| 276 |
+
|
| 277 |
+
# H = wth root of H
|
| 278 |
+
H = H ** w[0]
|
| 279 |
+
|
| 280 |
+
if sum > 0:
|
| 281 |
+
sum = np.log(sum)
|
| 282 |
+
sum = sum ** w[1]
|
| 283 |
+
|
| 284 |
+
ds = ds ** w[2]
|
| 285 |
+
|
| 286 |
+
cb = (cb + 1) ** w[3]
|
| 287 |
+
|
| 288 |
+
return H + sum + ds + cb
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def calculate_entropy(img, w, dw) -> float:
|
| 292 |
+
'''
|
| 293 |
+
Calculates the entropy of an image img using the given weights w and
|
| 294 |
+
depth weights dw. It returns the entropy value.
|
| 295 |
+
'''
|
| 296 |
+
|
| 297 |
+
flt = img.flatten()
|
| 298 |
+
|
| 299 |
+
# c = flt.shape[0]
|
| 300 |
+
total_pixels = 0
|
| 301 |
+
t_prob = 0
|
| 302 |
+
# sum_of_probs = 0
|
| 303 |
+
entropy = 0
|
| 304 |
+
wt = w * 10
|
| 305 |
+
|
| 306 |
+
# if imgD=None then proceed normally
|
| 307 |
+
# else calculate its frequency and find max
|
| 308 |
+
# use this max value as a weight in entropy
|
| 309 |
+
|
| 310 |
+
pixels_frequency = calculate_pixel_frequency(flt)
|
| 311 |
+
|
| 312 |
+
total_pixels = sum(pixels_frequency.values())
|
| 313 |
+
|
| 314 |
+
for px in pixels_frequency:
|
| 315 |
+
t_prob = pixels_frequency[px] / total_pixels
|
| 316 |
+
|
| 317 |
+
if t_prob != 0:
|
| 318 |
+
entropy += (t_prob * math.log((1 / t_prob), 2))
|
| 319 |
+
|
| 320 |
+
# entropy = entropy * wt * dw
|
| 321 |
+
|
| 322 |
+
return entropy
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def find_most_salient_segment(segments, kernel, dws):
|
| 326 |
+
'''
|
| 327 |
+
Finds the most salient segment among the provided segments using a
|
| 328 |
+
given kernel and depth weights. It returns the maximum entropy value
|
| 329 |
+
and the index of the most salient segment.
|
| 330 |
+
'''
|
| 331 |
+
|
| 332 |
+
# max_entropy = 0
|
| 333 |
+
max_score = 0
|
| 334 |
+
index = 0
|
| 335 |
+
i = 0
|
| 336 |
+
|
| 337 |
+
for segment in segments:
|
| 338 |
+
temp_entropy = calculate_entropy(segment, kernel[i], dws[i])
|
| 339 |
+
# Normalise semgnet bweetn 0 and 255
|
| 340 |
+
segment = cv2.normalize(segment, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
|
| 341 |
+
temp_sum = np.sum(segment)
|
| 342 |
+
# temp_tup = (i, temp_entropy)
|
| 343 |
+
# segments_entropies.append(temp_tup)
|
| 344 |
+
|
| 345 |
+
w = WEIGHTS
|
| 346 |
+
|
| 347 |
+
temp_score = calculate_score(temp_entropy, temp_sum, dws[i], kernel[i], w)
|
| 348 |
+
|
| 349 |
+
temp_tup = (i, temp_score, temp_entropy ** w[0], temp_sum ** w[1], (kernel[i] + 1) ** w[2], dws[i] ** w[3])
|
| 350 |
+
|
| 351 |
+
# segments_scores.append((i, temp_score))
|
| 352 |
+
segments_scores.append(temp_tup)
|
| 353 |
+
|
| 354 |
+
# if temp_entropy > max_entropy:
|
| 355 |
+
# max_entropy = temp_entropy
|
| 356 |
+
# index = i
|
| 357 |
+
|
| 358 |
+
if temp_score > max_score:
|
| 359 |
+
max_score = temp_score
|
| 360 |
+
index = i
|
| 361 |
+
|
| 362 |
+
i += 1
|
| 363 |
+
|
| 364 |
+
# return max_entropy, index
|
| 365 |
+
return max_score, index
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def make_gaussian(size, fwhm=10, center=None):
|
| 369 |
+
'''
|
| 370 |
+
Generates a 2D Gaussian kernel with the specified size and full-width-half-maximum (fwhm). It returns the Gaussian kernel.
|
| 371 |
+
|
| 372 |
+
size: length of a side of the square
|
| 373 |
+
fwhm: full-width-half-maximum, which can be thought of as an effective
|
| 374 |
+
radius.
|
| 375 |
+
|
| 376 |
+
https://gist.github.com/andrewgiessel/4635563
|
| 377 |
+
'''
|
| 378 |
+
|
| 379 |
+
x = np.arange(0, size, 1, float)
|
| 380 |
+
y = x[:, np.newaxis]
|
| 381 |
+
|
| 382 |
+
if center is None:
|
| 383 |
+
x0 = y0 = size // 2
|
| 384 |
+
else:
|
| 385 |
+
x0 = center[0]
|
| 386 |
+
y0 = center[1]
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / fwhm ** 2)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def gen_depth_weights(d_segments, depth_map) -> list:
|
| 393 |
+
'''
|
| 394 |
+
Generates depth weights for the segments based on the depth map. It
|
| 395 |
+
returns a list of depth weights.
|
| 396 |
+
'''
|
| 397 |
+
|
| 398 |
+
hist_d, _ = np.histogram(depth_map, 256, [0, 256])
|
| 399 |
+
|
| 400 |
+
# Get first non-zero index
|
| 401 |
+
first_nz = next((i for i, x in enumerate(hist_d) if x), None)
|
| 402 |
+
|
| 403 |
+
# Get last non-zero index
|
| 404 |
+
rev = (len(hist_d) - idx for idx, item in enumerate(reversed(hist_d), 1) if item)
|
| 405 |
+
last_nz = next(rev, default=None)
|
| 406 |
+
|
| 407 |
+
mid = (first_nz + last_nz) / 2
|
| 408 |
+
|
| 409 |
+
for seg in d_segments:
|
| 410 |
+
hist, _ = np.histogram(seg, 256, [0, 256])
|
| 411 |
+
dw = 0
|
| 412 |
+
ind = 0
|
| 413 |
+
for s in hist:
|
| 414 |
+
if ind > mid:
|
| 415 |
+
dw = dw + (s * 1)
|
| 416 |
+
ind = ind + 1
|
| 417 |
+
dws.append(dw)
|
| 418 |
+
|
| 419 |
+
return dws
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def gen_blank_depth_weight(d_segments):
|
| 423 |
+
'''
|
| 424 |
+
Generates blank depth weights for the segments. It returns a list of
|
| 425 |
+
depth weights.
|
| 426 |
+
'''
|
| 427 |
+
|
| 428 |
+
for _ in d_segments:
|
| 429 |
+
dw = 1
|
| 430 |
+
dws.append(dw)
|
| 431 |
+
return dws
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# def generate_heatmap(img, mode, sorted_seg_scores, segments_coords) -> tuple:
|
| 435 |
+
# '''
|
| 436 |
+
# Generates a heatmap overlay on the input image img based on the
|
| 437 |
+
# provided sorted segment scores. The mode parameter determines the color
|
| 438 |
+
# scheme of the heatmap. It returns the image with the heatmap overlay
|
| 439 |
+
# and a list of segment scores.
|
| 440 |
+
|
| 441 |
+
# mode: 0 for white grid, 1 for color-coded grid
|
| 442 |
+
# '''
|
| 443 |
+
|
| 444 |
+
# font = cv2.FONT_HERSHEY_SIMPLEX
|
| 445 |
+
# # print_index = 0
|
| 446 |
+
# print_index = len(sorted_seg_scores) - 1
|
| 447 |
+
# set_value = int(0.25 * len(sorted_seg_scores))
|
| 448 |
+
# color = (0, 0, 0)
|
| 449 |
+
|
| 450 |
+
# max_x = 0
|
| 451 |
+
# max_y = 0
|
| 452 |
+
|
| 453 |
+
# overlay = np.zeros_like(img, dtype=np.uint8)
|
| 454 |
+
# text_overlay = np.zeros_like(img, dtype=np.uint8)
|
| 455 |
+
|
| 456 |
+
# sara_list_out = []
|
| 457 |
+
|
| 458 |
+
# for ent in reversed(sorted_seg_scores):
|
| 459 |
+
# quartile = 0
|
| 460 |
+
# if mode == 0:
|
| 461 |
+
# color = (255, 255, 255)
|
| 462 |
+
# t = 4
|
| 463 |
+
# elif mode == 1:
|
| 464 |
+
# if print_index + 1 <= set_value:
|
| 465 |
+
# color = (0, 0, 255, 255)
|
| 466 |
+
# t = 2
|
| 467 |
+
# quartile = 1
|
| 468 |
+
# elif print_index + 1 <= set_value * 2:
|
| 469 |
+
# color = (0, 128, 255, 192)
|
| 470 |
+
# t = 4
|
| 471 |
+
# quartile = 2
|
| 472 |
+
# elif print_index + 1 <= set_value * 3:
|
| 473 |
+
# color = (0, 255, 255, 128)
|
| 474 |
+
# t = 4
|
| 475 |
+
# t = 6
|
| 476 |
+
# quartile = 3
|
| 477 |
+
# # elif print_index + 1 <= set_value * 4:
|
| 478 |
+
# # color = (0, 250, 0, 64)
|
| 479 |
+
# # t = 8
|
| 480 |
+
# # quartile = 4
|
| 481 |
+
# else:
|
| 482 |
+
# color = (0, 250, 0, 64)
|
| 483 |
+
# t = 8
|
| 484 |
+
# quartile = 4
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# x1 = segments_coords[ent[0]][1]
|
| 488 |
+
# y1 = segments_coords[ent[0]][2]
|
| 489 |
+
# x2 = segments_coords[ent[0]][3]
|
| 490 |
+
# y2 = segments_coords[ent[0]][4]
|
| 491 |
+
|
| 492 |
+
# if x2 > max_x:
|
| 493 |
+
# max_x = x2
|
| 494 |
+
# if y2 > max_y:
|
| 495 |
+
# max_y = y2
|
| 496 |
+
|
| 497 |
+
# x = int((x1 + x2) / 2)
|
| 498 |
+
# y = int((y1 + y2) / 2)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# # fill rectangle
|
| 503 |
+
# cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
|
| 504 |
+
|
| 505 |
+
# cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1)
|
| 506 |
+
# # put text in the middle of the rectangle
|
| 507 |
+
|
| 508 |
+
# # white text
|
| 509 |
+
# cv2.putText(text_overlay, str(print_index), (x - 5, y),
|
| 510 |
+
# font, .4, (255, 255, 255), 1, cv2.LINE_AA)
|
| 511 |
+
|
| 512 |
+
# # Index, rank, score, entropy, entropy_sum, centre_bias, depth, quartile
|
| 513 |
+
# sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile)
|
| 514 |
+
# sara_list_out.append(sara_tuple)
|
| 515 |
+
# print_index -= 1
|
| 516 |
+
|
| 517 |
+
# # crop the overlay to up to x2 and y2
|
| 518 |
+
# overlay = overlay[0:max_y, 0:max_x]
|
| 519 |
+
# text_overlay = text_overlay[0:max_y, 0:max_x]
|
| 520 |
+
# img = img[0:max_y, 0:max_x]
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
# img = cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)
|
| 524 |
+
|
| 525 |
+
# img[text_overlay > 128] = text_overlay[text_overlay > 128]
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# return img, sara_list_out
|
| 529 |
+
def generate_heatmap(img, sorted_seg_scores, segments_coords, mode=1) -> tuple:
|
| 530 |
+
'''
|
| 531 |
+
Generates a more vibrant heatmap overlay on the input image img based on the
|
| 532 |
+
provided sorted segment scores. It returns the image with the heatmap overlay
|
| 533 |
+
and a list of segment scores with quartile information.
|
| 534 |
+
|
| 535 |
+
mode: 0 for white grid, 1 for color-coded grid, 2 for heatmap to be used as a feature
|
| 536 |
+
'''
|
| 537 |
+
alpha =0.3
|
| 538 |
+
if mode == 2:
|
| 539 |
+
|
| 540 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 541 |
+
print_index = len(sorted_seg_scores) - 1
|
| 542 |
+
set_value = int(0.25 * len(sorted_seg_scores))
|
| 543 |
+
|
| 544 |
+
max_x = 0
|
| 545 |
+
max_y = 0
|
| 546 |
+
|
| 547 |
+
overlay = np.zeros_like(img, dtype=np.uint8)
|
| 548 |
+
text_overlay = np.zeros_like(img, dtype=np.uint8)
|
| 549 |
+
|
| 550 |
+
sara_list_out = []
|
| 551 |
+
|
| 552 |
+
scores = [score[1] for score in sorted_seg_scores]
|
| 553 |
+
min_score = min(scores)
|
| 554 |
+
max_score = max(scores)
|
| 555 |
+
|
| 556 |
+
# Choose a colormap from matplotlib
|
| 557 |
+
colormap = plt.get_cmap('jet') # 'jet', 'viridis', 'plasma', 'magma', 'cividis, jet_r, viridis_r, plasma_r, magma_r, cividis_r
|
| 558 |
+
|
| 559 |
+
for ent in reversed(sorted_seg_scores):
|
| 560 |
+
score = ent[1]
|
| 561 |
+
normalized_score = (score - min_score) / (max_score - min_score)
|
| 562 |
+
color_weight = normalized_score * score # Weighted color based on the score
|
| 563 |
+
color = np.array(colormap(normalized_score)[:3]) * 255 #* color_weight
|
| 564 |
+
|
| 565 |
+
x1 = segments_coords[ent[0]][1]
|
| 566 |
+
y1 = segments_coords[ent[0]][2]
|
| 567 |
+
x2 = segments_coords[ent[0]][3]
|
| 568 |
+
y2 = segments_coords[ent[0]][4]
|
| 569 |
+
|
| 570 |
+
if x2 > max_x:
|
| 571 |
+
max_x = x2
|
| 572 |
+
if y2 > max_y:
|
| 573 |
+
max_y = y2
|
| 574 |
+
|
| 575 |
+
x = int((x1 + x2) / 2)
|
| 576 |
+
+            y = int((y1 + y2) / 2)
+
+            # fill rectangle
+            cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
+            # black border
+            # cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1)
+
+            # white text
+            # cv2.putText(text_overlay, str(print_index), (x - 5, y),
+            #             font, .4, (255, 255, 255), 1, cv2.LINE_AA)
+
+            # Determine quartile based on print_index
+            if print_index + 1 <= set_value:
+                quartile = 1
+            elif print_index + 1 <= set_value * 2:
+                quartile = 2
+            elif print_index + 1 <= set_value * 3:
+                quartile = 3
+            else:
+                quartile = 4
+
+            sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile)
+            sara_list_out.append(sara_tuple)
+            print_index -= 1
+
+        overlay = overlay[0:max_y, 0:max_x]
+        text_overlay = text_overlay[0:max_y, 0:max_x]
+        img = img[0:max_y, 0:max_x]
+
+        # Create a blank grayscale image with the same dimensions as the original image
+        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+        gray = cv2.merge([gray, gray, gray])
+
+        gray = cv2.addWeighted(overlay, alpha, gray, 1-alpha, 0, gray)
+        gray[text_overlay > 128] = text_overlay[text_overlay > 128]
+
+        return gray, sara_list_out
+    else:
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        # print_index = 0
+        print_index = len(sorted_seg_scores) - 1
+        set_value = int(0.25 * len(sorted_seg_scores))
+        color = (0, 0, 0)
+
+        max_x = 0
+        max_y = 0
+
+        overlay = np.zeros_like(img, dtype=np.uint8)
+        text_overlay = np.zeros_like(img, dtype=np.uint8)
+
+        sara_list_out = []
+
+        for ent in reversed(sorted_seg_scores):
+            quartile = 0
+            if mode == 0:
+                color = (255, 255, 255)
+                t = 4
+            elif mode == 1:
+                if print_index + 1 <= set_value:
+                    color = (0, 0, 255, 255)
+                    t = 2
+                    quartile = 1
+                elif print_index + 1 <= set_value * 2:
+                    color = (0, 128, 255, 192)
+                    t = 4
+                    quartile = 2
+                elif print_index + 1 <= set_value * 3:
+                    color = (0, 255, 255, 128)
+                    t = 4
+                    t = 6
+                    quartile = 3
+                # elif print_index + 1 <= set_value * 4:
+                #     color = (0, 250, 0, 64)
+                #     t = 8
+                #     quartile = 4
+                else:
+                    color = (0, 250, 0, 64)
+                    t = 8
+                    quartile = 4
+
+            x1 = segments_coords[ent[0]][1]
+            y1 = segments_coords[ent[0]][2]
+            x2 = segments_coords[ent[0]][3]
+            y2 = segments_coords[ent[0]][4]
+
+            if x2 > max_x:
+                max_x = x2
+            if y2 > max_y:
+                max_y = y2
+
+            x = int((x1 + x2) / 2)
+            y = int((y1 + y2) / 2)
+
+            # fill rectangle
+            cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
+
+            cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1)
+            # put text in the middle of the rectangle
+
+            # white text
+            cv2.putText(text_overlay, str(print_index), (x - 5, y),
+                        font, .4, (255, 255, 255), 1, cv2.LINE_AA)
+
+            # Index, rank, score, entropy, entropy_sum, centre_bias, depth, quartile
+            sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile)
+            sara_list_out.append(sara_tuple)
+            print_index -= 1
+
+        # crop the overlay to up to x2 and y2
+        overlay = overlay[0:max_y, 0:max_x]
+        text_overlay = text_overlay[0:max_y, 0:max_x]
+        img = img[0:max_y, 0:max_x]
+
+        img = cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)
+
+        img[text_overlay > 128] = text_overlay[text_overlay > 128]
+
+        return img, sara_list_out
+
+
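The quartile bucketing above splits the ranked segments into four equal bands of the grid. A small worked sketch (assuming a 9x9 grid, i.e. 81 segments; this number is illustrative, not from the source) makes the thresholds concrete:

```python
# Worked example of the quartile thresholds used above, assuming 81 segments (9x9 grid).
num_segments = 81
set_value = int(0.25 * num_segments)   # 20
# rank 0..19  -> quartile 1 (highest-ranked segments; red overlay when mode == 1)
# rank 20..39 -> quartile 2
# rank 40..59 -> quartile 3
# rank 60..80 -> quartile 4
```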
+def generate_sara(tex, tex_segments, mode=2):
+    '''
+    Generates the SaRa (Salient Region Annotation) output by calculating
+    saliency scores for the segments of the given texture image tex. It
+    returns the texture image with the heatmap overlay and a list of
+    segment scores.
+    '''
+
+    gaussian_kernel_array = make_gaussian(seg_dim)
+    gaussian1d = gaussian_kernel_array.ravel()
+
+    dws = gen_blank_depth_weight(tex_segments)
+
+    max_h, index = find_most_salient_segment(tex_segments, gaussian1d, dws)
+    # dict_entropies = dict(segments_entropies)
+    # segments_scores list with 5 elements, use index as key for dict and store rest as list of index
+    dict_scores = {}
+
+    for segment in segments_scores:
+        # Index: score, entropy, sum, depth, centre-bias
+        dict_scores[segment[0]] = [segment[1], segment[2], segment[3], segment[4], segment[5]]
+
+    # sorted_entropies = sorted(dict_entropies.items(),
+    #                           key=operator.itemgetter(1), reverse=True)
+
+    # sorted_scores = sorted(dict_scores.items(),
+    #                        key=operator.itemgetter(1), reverse=True)
+
+    # Sort by first value in value list
+    sorted_scores = sorted(dict_scores.items(), key=lambda x: x[1][0], reverse=True)
+
+    # flatten
+    sorted_scores = [[i[0], i[1][0], i[1][1], i[1][2], i[1][3], i[1][4]] for i in sorted_scores]
+
+    # tex_out, sara_list_out = generate_heatmap(
+    #     tex, 1, sorted_entropies, segments_coords)
+
+    tex_out, sara_list_out = generate_heatmap(
+        tex, sorted_scores, segments_coords, mode=mode)
+
+    sara_list_out = list(reversed(sara_list_out))
+
+    return tex_out, sara_list_out
+
+
+def return_sara(input_img, grid, generator='itti', saliency_map=None, mode=2):
+    '''
+    Computes the SaRa output for the given input image. It uses the
+    generate_sara function internally. It returns the SaRa output image and
+    a list of segment scores.
+    '''
+
+    global seg_dim
+    seg_dim = grid
+
+    if saliency_map is None:
+        saliency_map = return_saliency(input_img, generator)
+
+    tex_segments = generate_segments(saliency_map, seg_dim)
+
+    # tex_segments = generate_segments(input_img, seg_dim)
+    sara_output, sara_list_output = generate_sara(input_img, tex_segments, mode=mode)
+
+    return sara_output, sara_list_output
+
+
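A minimal usage sketch of this ranking entry point, mirroring how app.py later in this diff drives the module (the image path is a placeholder, not a file from the repository):

```python
import cv2
import SaRa.saraRC1 as sara

img = cv2.imread('example.jpg')      # placeholder input image
sara.reset()                         # clear the module-level state between runs
heatmap, ranks = sara.return_sara(img, grid=9, generator='itti', mode=1)
# Each entry in `ranks` is an
# (index, rank, score, entropy, entropy_sum, centre_bias, depth, quartile) tuple.
```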
+def mean_squared_error(image_a, image_b) -> float:
+    '''
+    Calculates the Mean Squared Error (MSE), i.e. the per-pixel mean of the
+    squared differences between two images image_a and image_b. It returns
+    the MSE value.
+
+    NOTE: The two images must have the same dimensions.
+    '''
+
+    err = np.sum((image_a.astype('float') - image_b.astype('float')) ** 2)
+    err /= float(image_a.shape[0] * image_a.shape[1])
+
+    return err
+
+
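As a quick numeric check of the formula (a hypothetical 2x2 pair of "images", not from the repository's tests):

```python
import numpy as np
import SaRa.saraRC1 as sara

a = np.zeros((2, 2), dtype=np.uint8)
b = a.copy()
b[0, 0] = 2
# sum of squared differences = 4, divided by 2 * 2 pixels -> 1.0
print(sara.mean_squared_error(a, b))  # 1.0
```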
+def reset():
+    '''
+    Resets all global variables to their default values.
+    '''
+
+    # global segments_entropies, segments_scores, segments_coords, seg_dim, segments, gt_segments, dws, sara_list
+
+    global segments_scores, segments_coords, seg_dim, segments, gt_segments, dws, sara_list
+
+    # segments_entropies = []
+    segments_scores = []
+    segments_coords = []
+
+    seg_dim = 0
+    segments = []
+    gt_segments = []
+    dws = []
+    sara_list = []
+
+
def resize_based_on_important_ranks(img, sara_info, grid_size, rate=0.3):
|
| 805 |
+
def generate_segments(image, seg_count) -> dict:
|
| 806 |
+
"""
|
| 807 |
+
Function to generate segments of an image
|
| 808 |
+
|
| 809 |
+
Args:
|
| 810 |
+
image: input image
|
| 811 |
+
seg_count: number of segments to generate
|
| 812 |
+
|
| 813 |
+
Returns:
|
| 814 |
+
segments: dictionary of segments
|
| 815 |
+
|
| 816 |
+
"""
|
| 817 |
+
# Initializing segments dictionary
|
| 818 |
+
segments = {}
|
| 819 |
+
# Initializing segment index and segment count
|
| 820 |
+
segment_count = seg_count
|
| 821 |
+
index = 0
|
| 822 |
+
|
| 823 |
+
# Retrieving image width and height
|
| 824 |
+
h, w = image.shape[:2]
|
| 825 |
+
|
| 826 |
+
# Calculating width and height intervals for segments from the segment count
|
| 827 |
+
w_interval = w // segment_count
|
| 828 |
+
h_interval = h // segment_count
|
| 829 |
+
|
| 830 |
+
# Iterating through the image and generating segments
|
| 831 |
+
for i in range(segment_count):
|
| 832 |
+
for j in range(segment_count):
|
| 833 |
+
# Calculating segment coordinates
|
| 834 |
+
x1, y1 = j * w_interval, i * h_interval
|
| 835 |
+
x2, y2 = x1 + w_interval, y1 + h_interval
|
| 836 |
+
|
| 837 |
+
# Adding segment coordinates to segments dictionary
|
| 838 |
+
segments[index] = (x1, y1, x2, y2)
|
| 839 |
+
|
| 840 |
+
# Incrementing segment index
|
| 841 |
+
index += 1
|
| 842 |
+
|
| 843 |
+
# Returning segments dictionary
|
| 844 |
+
return segments
|
| 845 |
+
|
| 846 |
+
# Retrieving important ranks from SaRa
|
| 847 |
+
sara_dict = {
|
| 848 |
+
info[0]: {
|
| 849 |
+
'score': info[2],
|
| 850 |
+
'index': info[1]
|
| 851 |
+
}
|
| 852 |
+
for info in sara_info[1]
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
# Sorting important ranks by score
|
| 856 |
+
sorted_sara_dict = sorted(sara_dict.items(), key=lambda item: item[1]['score'], reverse=True)
|
| 857 |
+
|
| 858 |
+
# Generating segments
|
| 859 |
+
index_info = generate_segments(img, grid_size)
|
| 860 |
+
|
| 861 |
+
# Initializing most important ranks image
|
| 862 |
+
most_imp_ranks = np.zeros_like(img)
|
| 863 |
+
|
| 864 |
+
# Calculating maximum rank
|
| 865 |
+
max_rank = int(grid_size * grid_size * rate)
|
| 866 |
+
count = 0
|
| 867 |
+
|
| 868 |
+
# Iterating through important ranks and adding them to most important ranks image
|
| 869 |
+
for rank, info in sorted_sara_dict:
|
| 870 |
+
# Checking if rank is within maximum rank
|
| 871 |
+
if count <= max_rank:
|
| 872 |
+
# Retrieving segment coordinates
|
| 873 |
+
coords = index_info[rank]
|
| 874 |
+
|
| 875 |
+
# Adding segment to most important ranks image by making it white
|
| 876 |
+
most_imp_ranks[coords[1]:coords[3], coords[0]:coords[2]] = 255
|
| 877 |
+
|
| 878 |
+
# Incrementing count
|
| 879 |
+
count += 1
|
| 880 |
+
else:
|
| 881 |
+
break
|
| 882 |
+
|
| 883 |
+
# Retrieving coordinates of most important ranks
|
| 884 |
+
coords = np.argwhere(most_imp_ranks == 255)
|
| 885 |
+
|
| 886 |
+
# Checking if no important ranks were found and returning original image
|
| 887 |
+
if coords.size == 0:
|
| 888 |
+
return img , most_imp_ranks, [0, 0, img.shape[0], img.shape[1]]
|
| 889 |
+
|
| 890 |
+
# Cropping image based on most important ranks
|
| 891 |
+
x0, y0 = coords.min(axis=0)[:2]
|
| 892 |
+
x1, y1 = coords.max(axis=0)[:2] + 1
|
| 893 |
+
cropped_img = img[x0:x1, y0:y1]
|
| 894 |
+
return cropped_img , most_imp_ranks, [x0, y0, x1, y1]
|
| 895 |
+
|
| 896 |
+
def sara_resize(img, sara_info, grid_size, rate=0.3, iterations=2):
|
| 897 |
+
"""
|
| 898 |
+
Function to resize an image based on SaRa
|
| 899 |
+
|
| 900 |
+
Args:
|
| 901 |
+
img: input image
|
| 902 |
+
sara_info: SaRa information
|
| 903 |
+
grid_size: size of the grid
|
| 904 |
+
rate: rate of important ranks
|
| 905 |
+
iterations: number of iterations to resize
|
| 906 |
+
|
| 907 |
+
Returns:
|
| 908 |
+
img: resized image
|
| 909 |
+
"""
|
| 910 |
+
# Iterating through iterations
|
| 911 |
+
for _ in range(iterations):
|
| 912 |
+
# Resizing image based on important ranks
|
| 913 |
+
img, most_imp_ranks, coords = resize_based_on_important_ranks(img, sara_info, grid_size, rate=rate)
|
| 914 |
+
|
| 915 |
+
# Returning resized image
|
| 916 |
+
return img, most_imp_ranks, coords
|
| 917 |
+
|
| 918 |
+
def plot_3D(img, sara_info, grid_size, rate=0.3):
|
| 919 |
+
def generate_segments(image, seg_count) -> dict:
|
| 920 |
+
"""
|
| 921 |
+
Function to generate segments of an image
|
| 922 |
+
|
| 923 |
+
Args:
|
| 924 |
+
image: input image
|
| 925 |
+
seg_count: number of segments to generate
|
| 926 |
+
|
| 927 |
+
Returns:
|
| 928 |
+
segments: dictionary of segments
|
| 929 |
+
|
| 930 |
+
"""
|
| 931 |
+
# Initializing segments dictionary
|
| 932 |
+
segments = {}
|
| 933 |
+
# Initializing segment index and segment count
|
| 934 |
+
segment_count = seg_count
|
| 935 |
+
index = 0
|
| 936 |
+
|
| 937 |
+
# Retrieving image width and height
|
| 938 |
+
h, w = image.shape[:2]
|
| 939 |
+
|
| 940 |
+
# Calculating width and height intervals for segments from the segment count
|
| 941 |
+
w_interval = w // segment_count
|
| 942 |
+
h_interval = h // segment_count
|
| 943 |
+
|
| 944 |
+
# Iterating through the image and generating segments
|
| 945 |
+
for i in range(segment_count):
|
| 946 |
+
for j in range(segment_count):
|
| 947 |
+
# Calculating segment coordinates
|
| 948 |
+
x1, y1 = j * w_interval, i * h_interval
|
| 949 |
+
x2, y2 = x1 + w_interval, y1 + h_interval
|
| 950 |
+
|
| 951 |
+
# Adding segment coordinates to segments dictionary
|
| 952 |
+
segments[index] = (x1, y1, x2, y2)
|
| 953 |
+
|
| 954 |
+
# Incrementing segment index
|
| 955 |
+
index += 1
|
| 956 |
+
|
| 957 |
+
# Returning segments dictionary
|
| 958 |
+
return segments
|
| 959 |
+
|
| 960 |
+
# Extracting heatmap from SaRa information
|
| 961 |
+
heatmap = sara_info[0]
|
| 962 |
+
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
| 963 |
+
|
| 964 |
+
# Retrieving important ranks from SaRa
|
| 965 |
+
sara_dict = {
|
| 966 |
+
info[0]: {
|
| 967 |
+
'score': info[2],
|
| 968 |
+
'index': info[1]
|
| 969 |
+
}
|
| 970 |
+
for info in sara_info[1]
|
| 971 |
+
}
|
| 972 |
+
|
| 973 |
+
# Sorting important ranks by score
|
| 974 |
+
sorted_sara_dict = sorted(sara_dict.items(), key=lambda item: item[1]['score'], reverse=True)
|
| 975 |
+
|
| 976 |
+
# Generating segments
|
| 977 |
+
index_info = generate_segments(img, grid_size)
|
| 978 |
+
|
| 979 |
+
# Calculating maximum rank
|
| 980 |
+
max_rank = int(grid_size * grid_size * rate)
|
| 981 |
+
count = 0
|
| 982 |
+
|
| 983 |
+
# Normalizing heatmap
|
| 984 |
+
heatmap = heatmap.astype(float) / 255.0
|
| 985 |
+
|
| 986 |
+
# Creating a figure
|
| 987 |
+
fig = plt.figure(figsize=(20, 10))
|
| 988 |
+
|
| 989 |
+
# Creating a 3D plot
|
| 990 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 991 |
+
|
| 992 |
+
# Defining the x and y coordinates for the heatmap
|
| 993 |
+
x_coords = np.linspace(0, 1, heatmap.shape[1])
|
| 994 |
+
y_coords = np.linspace(0, 1, heatmap.shape[0])
|
| 995 |
+
x, y = np.meshgrid(x_coords, y_coords)
|
| 996 |
+
|
| 997 |
+
# Defining the z-coordinate for the heatmap (a constant, such as -5)
|
| 998 |
+
z = np.asarray([[-10] * heatmap.shape[1]] * heatmap.shape[0])
|
| 999 |
+
|
| 1000 |
+
# Plotting the heatmap as a texture on the xy-plane
|
| 1001 |
+
ax.plot_surface(x, y, z, facecolors=heatmap, rstride=1, cstride=1, shade=False)
|
| 1002 |
+
|
| 1003 |
+
# Initializing the single distribution array
|
| 1004 |
+
single_distribution = np.asarray([[1e-6] * heatmap.shape[1]] * heatmap.shape[0], dtype=float)
|
| 1005 |
+
|
| 1006 |
+
importance = 0
|
| 1007 |
+
# Creating the single distribution by summing up Gaussian distributions for each segment
|
| 1008 |
+
for rank, info in sorted_sara_dict:
|
| 1009 |
+
# Retrieving segment coordinates
|
| 1010 |
+
coords = index_info[rank]
|
| 1011 |
+
|
| 1012 |
+
# Creating a Gaussian distribution for the whole segment, i.e., arrange all the pixels in the segment in a 3D Gaussian distribution
|
| 1013 |
+
x_temp = np.linspace(0, 1, coords[2] - coords[0])
|
| 1014 |
+
y_temp = np.linspace(0, 1, coords[3] - coords[1])
|
| 1015 |
+
|
| 1016 |
+
# Creating a meshgrid
|
| 1017 |
+
x_temp, y_temp = np.meshgrid(x_temp, y_temp)
|
| 1018 |
+
|
| 1019 |
+
# Calculating the Gaussian distribution
|
| 1020 |
+
distribution = np.exp(-((x_temp - 0.5) ** 2 + (y_temp - 0.5) ** 2) / 0.1) * ((grid_size ** 2 - importance) / grid_size ** 2) # (constant)
|
| 1021 |
+
|
| 1022 |
+
# Adding the Gaussian distribution to the single distribution
|
| 1023 |
+
single_distribution[coords[1]:coords[3], coords[0]:coords[2]] += distribution
|
| 1024 |
+
|
| 1025 |
+
# Incrementing importance
|
| 1026 |
+
importance +=1
|
| 1027 |
+
|
| 1028 |
+
# Based on the rate, calculating the minimum number for the most important ranks
|
| 1029 |
+
min_rank = int(grid_size * grid_size * rate)
|
| 1030 |
+
|
| 1031 |
+
# Calculating the scale factor for the single distribution
|
| 1032 |
+
scale_factor = ((grid_size ** 2 - min_rank) / grid_size ** 2) * 5
|
| 1033 |
+
|
| 1034 |
+
# Scaling the distribution
|
| 1035 |
+
single_distribution *= scale_factor
|
| 1036 |
+
|
| 1037 |
+
# Retrieving the max and min values of the single distribution
|
| 1038 |
+
max_value = np.max(single_distribution)
|
| 1039 |
+
min_value = np.min(single_distribution)
|
| 1040 |
+
|
| 1041 |
+
# Calculating the hyperplane
|
| 1042 |
+
hyperplane = np.asarray([[(max_value - min_value)* (1 - rate) + min_value] * heatmap.shape[1]] * heatmap.shape[0])
|
| 1043 |
+
|
| 1044 |
+
# Plotting a horizontal plane at the minimum rank level (hyperplane)
|
| 1045 |
+
ax.plot_surface(x, y, hyperplane, rstride=1, cstride=1, color='red', alpha=0.3, shade=False)
|
| 1046 |
+
|
| 1047 |
+
# Plotting the single distribution as a wireframe on the xy-plane
|
| 1048 |
+
ax.plot_surface(x, y, single_distribution, rstride=1, cstride=1, color='blue', shade=False)
|
| 1049 |
+
|
| 1050 |
+
# Setting the title
|
| 1051 |
+
ax.set_title('SaRa 3D Heatmap Plot', fontsize=20)
|
| 1052 |
+
|
| 1053 |
+
# Setting the labels
|
| 1054 |
+
ax.set_xlabel('X', fontsize=16)
|
| 1055 |
+
ax.set_ylabel('Y', fontsize=16)
|
| 1056 |
+
ax.set_zlabel('Z', fontsize=16)
|
| 1057 |
+
|
| 1058 |
+
# Setting the viewing angle to look from the y, x diagonal position
|
| 1059 |
+
ax.view_init(elev=30, azim=45) # Adjust the elevation (elev) and azimuth (azim) angles as needed
|
| 1060 |
+
# ax.view_init(elev=0, azim=0) # View from the top
|
| 1061 |
+
|
| 1062 |
+
# Adding legend to the plot
|
| 1063 |
+
# Creating Line2D objects for the legend
|
| 1064 |
+
legend_elements = [Line2D([0], [0], color='blue', lw=4, label='Rank Distribution'),
|
| 1065 |
+
Line2D([0], [0], color='red', lw=4, label='Threshold Hyperplane ({}%)'.format(rate*100)),
|
| 1066 |
+
Line2D([0], [0], color='green', lw=4, label='SaRa Heatmap')]
|
| 1067 |
+
|
| 1068 |
+
# Creating the legend
|
| 1069 |
+
plt.subplots_adjust(right=0.5)
|
| 1070 |
+
ax.legend(handles=legend_elements, fontsize=16, loc='center left', bbox_to_anchor=(1, 0.5))
|
| 1071 |
+
|
| 1072 |
+
# Inverting the x axis
|
| 1073 |
+
ax.invert_xaxis()
|
| 1074 |
+
|
| 1075 |
+
# Removing labels
|
| 1076 |
+
ax.set_xticks([])
|
| 1077 |
+
ax.set_yticks([])
|
| 1078 |
+
ax.set_zticks([])
|
| 1079 |
+
|
| 1080 |
+
# Showing the plot
|
| 1081 |
+
plt.show()
|
| 1082 |
+
|
app.py
ADDED
|
@@ -0,0 +1,184 @@
+from typing import Tuple
+import gradio as gr
+import numpy as np
+import cv2
+import SaRa.saraRC1 as sara
+import warnings
+warnings.filterwarnings("ignore")
+
+
+ALPHA = 0.4
+GENERATORS = ['itti', 'deepgaze']
+
+MARKDOWN = """
+<h1 style='text-align: center'>Saliency Ranking 📚</h1>
+
+Saliency Ranking is a fundamental 🌟 **Computer Vision** 🌟 process aimed at discerning the most visually significant features within an image 🖼️.
+
+🌟 This demo showcases the **SaRa (Saliency-Driven Object Ranking)** model for Saliency Ranking 🎯, which can efficiently rank the visual saliency of an image without requiring any training. 🖼️
+
+This technique is built on Itti's saliency map generator, which is modelled on the primate visual cortex 🧠, and can work with or without depth information 🔄.
+
+<div style="display: flex; align-items: center;">
+    <a href="https://github.com/dylanseychell/SaliencyRanking" style="margin-right: 10px;">
+        <img src="https://badges.aleen42.com/src/github.svg">
+    </a>
+    <a href="https://github.com/mbar0075/SaRa" style="margin-right: 10px;">
+        <img src="https://badges.aleen42.com/src/github.svg">
+    </a>
+    <a href="https://github.com/matthewkenely/ICT3909" style="margin-right: 10px;">
+        <img src="https://badges.aleen42.com/src/github.svg">
+    </a>
+</div>
+"""
+
+IMAGE_EXAMPLES = [
+    ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 32],
+    ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 32],
+    ['https://media.roboflow.com/supervision/image-examples/basketball-1.png', 32],
+]
+
+
+def detect_and_annotate(image,
+                        GRID_SIZE,
+                        generator,
+                        ALPHA=ALPHA,
+                        mode=1) -> Tuple[np.ndarray, np.ndarray]:
+    # Converting from PIL to OpenCV
+    image = np.array(image)
+    # Swap colour channels for OpenCV processing
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    # Copy the image for sara processing
+    sara_image = image.copy()
+    # sara_image = cv2.cvtColor(sara_image, cv2.COLOR_RGB2BGR)
+
+    # Resetting sara
+    sara.reset()
+
+    # Running sara (original implementation on itti)
+    sara_info = sara.return_sara(sara_image, GRID_SIZE, generator, mode=mode)
+
+    # Generate saliency map
+    saliency_map = sara.return_saliency(image, generator=generator)
+    # Resize saliency map to match the image size
+    saliency_map = cv2.resize(saliency_map, (image.shape[1], image.shape[0]))
+
+    # Apply color map and convert to RGB
+    saliency_map = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)
+    saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB)
+
+    # Overlay the saliency map on the original image
+    saliency_map = cv2.addWeighted(saliency_map, ALPHA, image, 1 - ALPHA, 0)
+
+    # Extract and convert the SaRa heatmap to RGB
+    heatmap = sara_info[0]
+    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+
+    return saliency_map, heatmap
+
+
+def process_image(
+    input_image: np.ndarray,
+    GRIDSIZE: int,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    # Validate GRID_SIZE
+    if GRIDSIZE is None or GRIDSIZE < 3:
+        GRIDSIZE = 9
+
+    itti_saliency_map, itti_heatmap = detect_and_annotate(
+        input_image, GRIDSIZE, 'itti')
+    _, itti_heatmap2 = detect_and_annotate(
+        input_image, GRIDSIZE, 'itti', mode=2)
+    # deepgaze_saliency_map, deepgaze_heatmap = detect_and_annotate(
+    #     input_image, GRIDSIZE, 'deepgaze')
+
+    return (
+        itti_saliency_map,
+        itti_heatmap,
+        itti_heatmap2,
+        # deepgaze_saliency_map,
+        # deepgaze_heatmap,
+    )
+
+
+grid_size_Component = gr.Slider(
+    minimum=3,
+    maximum=100,
+    value=32,
+    step=1,
+    label="Grid Size",
+    info=(
+        "The grid size for the Saliency Ranking (SaRa) model. The grid size determines "
+        "the number of regions the image is divided into. A higher grid size results in "
+        "more regions and a lower grid size results in fewer regions. The default grid "
+        "size is 9."
+    ))
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(MARKDOWN)
+    with gr.Accordion("Configuration", open=False):
+        with gr.Row():
+            grid_size_Component.render()
+    with gr.Row():
+        input_image_component = gr.Image(
+            type='pil',
+            label='Input'
+        )
+        itti_saliency_map = gr.Image(
+            type='pil',
+            label='Itti Saliency Map'
+        )
+    with gr.Row():
+        itti_heatmap = gr.Image(
+            type='pil',
+            label='Saliency Ranking Heatmap 1'
+        )
+        itti_heatmap2 = gr.Image(
+            type='pil',
+            label='Saliency Ranking Heatmap 2'
+        )
+    # with gr.Row():
+    #     deepgaze_saliency_map = gr.Image(
+    #         type='pil',
+    #         label='DeepGaze Saliency Map'
+    #     )
+    #     deepgaze_heatmap = gr.Image(
+    #         type='pil',
+    #         label='DeepGaze Saliency Ranking Heatmap'
+    #     )
+    submit_button_component = gr.Button(
+        value='Submit',
+        scale=1,
+        variant='primary'
+    )
+    gr.Examples(
+        fn=process_image,
+        examples=IMAGE_EXAMPLES,
+        inputs=[
+            input_image_component,
+            grid_size_Component,
+        ],
+        outputs=[
+            itti_saliency_map,
+            itti_heatmap,
+            itti_heatmap2,
+            # deepgaze_saliency_map,
+            # deepgaze_heatmap,
+        ]
+    )
+
+    submit_button_component.click(
+        fn=process_image,
+        inputs=[
+            input_image_component,
+            grid_size_Component,
+        ],
+        outputs=[
+            itti_saliency_map,
+            itti_heatmap,
+            itti_heatmap2,
+            # deepgaze_saliency_map,
+            # deepgaze_heatmap,
+        ]
+    )
+
+demo.launch(debug=False, show_error=True, max_threads=1)
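A hedged sketch of calling the processing function headlessly, outside the Gradio UI (assumes the definitions above are in scope, e.g. in an interactive session before demo.launch() runs; the image path is a placeholder):

```python
from PIL import Image

img = Image.open('example.jpg')                        # placeholder path
saliency, heatmap1, heatmap2 = process_image(img, 9)   # Itti saliency overlay + two SaRa rank heatmaps
```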
deepgaze_pytorch/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
+from .deepgaze1 import DeepGazeI
+from .deepgaze2e import DeepGazeIIE
+from .deepgaze3 import DeepGazeIII
deepgaze_pytorch/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (289 Bytes).

deepgaze_pytorch/__pycache__/deepgaze1.cpython-39.pyc
ADDED
Binary file (2.12 kB).

deepgaze_pytorch/__pycache__/deepgaze2e.cpython-39.pyc
ADDED
Binary file (4.29 kB).

deepgaze_pytorch/__pycache__/deepgaze3.cpython-39.pyc
ADDED
Binary file (3.48 kB).

deepgaze_pytorch/__pycache__/layers.cpython-39.pyc
ADDED
Binary file (13.7 kB).

deepgaze_pytorch/__pycache__/modules.cpython-39.pyc
ADDED
Binary file (10.4 kB).

deepgaze_pytorch/data.py
ADDED
|
@@ -0,0 +1,403 @@
| 1 |
+
from collections import Counter
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
from boltons.iterutils import chunked
|
| 8 |
+
import lmdb
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import pysaliency
|
| 12 |
+
from pysaliency.datasets import create_subset
|
| 13 |
+
from pysaliency.utils import remove_trailing_nans
|
| 14 |
+
import torch
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def ensure_color_image(image):
|
| 19 |
+
if len(image.shape) == 2:
|
| 20 |
+
return np.dstack([image, image, image])
|
| 21 |
+
return image
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def x_y_to_sparse_indices(xs, ys):
|
| 25 |
+
# Converts list of x and y coordinates into indices and values for sparse mask
|
| 26 |
+
x_inds = []
|
| 27 |
+
y_inds = []
|
| 28 |
+
values = []
|
| 29 |
+
pair_inds = {}
|
| 30 |
+
|
| 31 |
+
for x, y in zip(xs, ys):
|
| 32 |
+
key = (x, y)
|
| 33 |
+
if key not in pair_inds:
|
| 34 |
+
x_inds.append(x)
|
| 35 |
+
y_inds.append(y)
|
| 36 |
+
pair_inds[key] = len(x_inds) - 1
|
| 37 |
+
values.append(1)
|
| 38 |
+
else:
|
| 39 |
+
values[pair_inds[key]] += 1
|
| 40 |
+
|
| 41 |
+
return np.array([y_inds, x_inds]), values
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ImageDataset(torch.utils.data.Dataset):
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
stimuli,
|
| 48 |
+
fixations,
|
| 49 |
+
centerbias_model=None,
|
| 50 |
+
lmdb_path=None,
|
| 51 |
+
transform=None,
|
| 52 |
+
cached=None,
|
| 53 |
+
average='fixation'
|
| 54 |
+
):
|
| 55 |
+
self.stimuli = stimuli
|
| 56 |
+
self.fixations = fixations
|
| 57 |
+
self.centerbias_model = centerbias_model
|
| 58 |
+
self.lmdb_path = lmdb_path
|
| 59 |
+
self.transform = transform
|
| 60 |
+
self.average = average
|
| 61 |
+
|
| 62 |
+
# cache only short dataset
|
| 63 |
+
if cached is None:
|
| 64 |
+
cached = len(self.stimuli) < 100
|
| 65 |
+
|
| 66 |
+
cache_fixation_data = cached
|
| 67 |
+
|
| 68 |
+
if lmdb_path is not None:
|
| 69 |
+
_export_dataset_to_lmdb(stimuli, centerbias_model, lmdb_path)
|
| 70 |
+
self.lmdb_env = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path),
|
| 71 |
+
readonly=True, lock=False,
|
| 72 |
+
readahead=False, meminit=False
|
| 73 |
+
)
|
| 74 |
+
cached = False
|
| 75 |
+
cache_fixation_data = True
|
| 76 |
+
else:
|
| 77 |
+
self.lmdb_env = None
|
| 78 |
+
|
| 79 |
+
self.cached = cached
|
| 80 |
+
if cached:
|
| 81 |
+
self._cache = {}
|
| 82 |
+
self.cache_fixation_data = cache_fixation_data
|
| 83 |
+
if cache_fixation_data:
|
| 84 |
+
print("Populating fixations cache")
|
| 85 |
+
self._xs_cache = {}
|
| 86 |
+
self._ys_cache = {}
|
| 87 |
+
|
| 88 |
+
for x, y, n in zip(self.fixations.x_int, self.fixations.y_int, tqdm(self.fixations.n)):
|
| 89 |
+
self._xs_cache.setdefault(n, []).append(x)
|
| 90 |
+
self._ys_cache.setdefault(n, []).append(y)
|
| 91 |
+
|
| 92 |
+
for key in list(self._xs_cache):
|
| 93 |
+
self._xs_cache[key] = np.array(self._xs_cache[key], dtype=int)
|
| 94 |
+
for key in list(self._ys_cache):
|
| 95 |
+
self._ys_cache[key] = np.array(self._ys_cache[key], dtype=int)
|
| 96 |
+
|
| 97 |
+
def get_shapes(self):
|
| 98 |
+
return list(self.stimuli.sizes)
|
| 99 |
+
|
| 100 |
+
def _get_image_data(self, n):
|
| 101 |
+
if self.lmdb_env:
|
| 102 |
+
image, centerbias_prediction = _get_image_data_from_lmdb(self.lmdb_env, n)
|
| 103 |
+
else:
|
| 104 |
+
image = np.array(self.stimuli.stimuli[n])
|
| 105 |
+
centerbias_prediction = self.centerbias_model.log_density(image)
|
| 106 |
+
|
| 107 |
+
image = ensure_color_image(image).astype(np.float32)
|
| 108 |
+
image = image.transpose(2, 0, 1)
|
| 109 |
+
|
| 110 |
+
return image, centerbias_prediction
|
| 111 |
+
|
| 112 |
+
def __getitem__(self, key):
|
| 113 |
+
if not self.cached or key not in self._cache:
|
| 114 |
+
|
| 115 |
+
image, centerbias_prediction = self._get_image_data(key)
|
| 116 |
+
centerbias_prediction = centerbias_prediction.astype(np.float32)
|
| 117 |
+
|
| 118 |
+
if self.cache_fixation_data and self.cached:
|
| 119 |
+
xs = self._xs_cache.pop(key)
|
| 120 |
+
ys = self._ys_cache.pop(key)
|
| 121 |
+
elif self.cache_fixation_data and not self.cached:
|
| 122 |
+
xs = self._xs_cache[key]
|
| 123 |
+
ys = self._ys_cache[key]
|
| 124 |
+
else:
|
| 125 |
+
inds = self.fixations.n == key
|
| 126 |
+
xs = np.array(self.fixations.x_int[inds], dtype=int)
|
| 127 |
+
ys = np.array(self.fixations.y_int[inds], dtype=int)
|
| 128 |
+
|
| 129 |
+
data = {
|
| 130 |
+
"image": image,
|
| 131 |
+
"x": xs,
|
| 132 |
+
"y": ys,
|
| 133 |
+
"centerbias": centerbias_prediction,
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
if self.average == 'image':
|
| 137 |
+
data['weight'] = 1.0
|
| 138 |
+
else:
|
| 139 |
+
data['weight'] = float(len(xs))
|
| 140 |
+
|
| 141 |
+
if self.cached:
|
| 142 |
+
self._cache[key] = data
|
| 143 |
+
else:
|
| 144 |
+
data = self._cache[key]
|
| 145 |
+
|
| 146 |
+
if self.transform is not None:
|
| 147 |
+
return self.transform(dict(data))
|
| 148 |
+
|
| 149 |
+
return data
|
| 150 |
+
|
| 151 |
+
def __len__(self):
|
| 152 |
+
return len(self.stimuli)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class FixationDataset(torch.utils.data.Dataset):
|
| 156 |
+
def __init__(
|
| 157 |
+
self,
|
| 158 |
+
stimuli, fixations,
|
| 159 |
+
centerbias_model=None,
|
| 160 |
+
lmdb_path=None,
|
| 161 |
+
transform=None,
|
| 162 |
+
included_fixations=-2,
|
| 163 |
+
allow_missing_fixations=False,
|
| 164 |
+
average='fixation',
|
| 165 |
+
cache_image_data=False,
|
| 166 |
+
):
|
| 167 |
+
self.stimuli = stimuli
|
| 168 |
+
self.fixations = fixations
|
| 169 |
+
self.centerbias_model = centerbias_model
|
| 170 |
+
self.lmdb_path = lmdb_path
|
| 171 |
+
|
| 172 |
+
if lmdb_path is not None:
|
| 173 |
+
_export_dataset_to_lmdb(stimuli, centerbias_model, lmdb_path)
|
| 174 |
+
self.lmdb_env = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path),
|
| 175 |
+
readonly=True, lock=False,
|
| 176 |
+
readahead=False, meminit=False
|
| 177 |
+
)
|
| 178 |
+
cache_image_data=False
|
| 179 |
+
else:
|
| 180 |
+
self.lmdb_env = None
|
| 181 |
+
|
| 182 |
+
self.transform = transform
|
| 183 |
+
self.average = average
|
| 184 |
+
|
| 185 |
+
self._shapes = None
|
| 186 |
+
|
| 187 |
+
if isinstance(included_fixations, int):
|
| 188 |
+
if included_fixations < 0:
|
| 189 |
+
included_fixations = [-1 - i for i in range(-included_fixations)]
|
| 190 |
+
else:
|
| 191 |
+
raise NotImplementedError()
|
| 192 |
+
|
| 193 |
+
self.included_fixations = included_fixations
|
| 194 |
+
self.allow_missing_fixations = allow_missing_fixations
|
| 195 |
+
self.fixation_counts = Counter(fixations.n)
|
| 196 |
+
|
| 197 |
+
self.cache_image_data = cache_image_data
|
| 198 |
+
|
| 199 |
+
if self.cache_image_data:
|
| 200 |
+
self.image_data_cache = {}
|
| 201 |
+
|
| 202 |
+
print("Populating image cache")
|
| 203 |
+
for n in tqdm(range(len(self.stimuli))):
|
| 204 |
+
self.image_data_cache[n] = self._get_image_data(n)
|
| 205 |
+
|
| 206 |
+
def get_shapes(self):
|
| 207 |
+
if self._shapes is None:
|
| 208 |
+
shapes = list(self.stimuli.sizes)
|
| 209 |
+
self._shapes = [shapes[n] for n in self.fixations.n]
|
| 210 |
+
|
| 211 |
+
return self._shapes
|
| 212 |
+
|
| 213 |
+
def _get_image_data(self, n):
|
| 214 |
+
if self.lmdb_path:
|
| 215 |
+
return _get_image_data_from_lmdb(self.lmdb_env, n)
|
| 216 |
+
image = np.array(self.stimuli.stimuli[n])
|
| 217 |
+
centerbias_prediction = self.centerbias_model.log_density(image)
|
| 218 |
+
|
| 219 |
+
image = ensure_color_image(image).astype(np.float32)
|
| 220 |
+
image = image.transpose(2, 0, 1)
|
| 221 |
+
|
| 222 |
+
return image, centerbias_prediction
|
| 223 |
+
|
| 224 |
+
def __getitem__(self, key):
|
| 225 |
+
n = self.fixations.n[key]
|
| 226 |
+
|
| 227 |
+
if self.cache_image_data:
|
| 228 |
+
image, centerbias_prediction = self.image_data_cache[n]
|
| 229 |
+
else:
|
| 230 |
+
image, centerbias_prediction = self._get_image_data(n)
|
| 231 |
+
|
| 232 |
+
centerbias_prediction = centerbias_prediction.astype(np.float32)
|
| 233 |
+
|
| 234 |
+
x_hist = remove_trailing_nans(self.fixations.x_hist[key])
|
| 235 |
+
y_hist = remove_trailing_nans(self.fixations.y_hist[key])
|
| 236 |
+
|
| 237 |
+
if self.allow_missing_fixations:
|
| 238 |
+
_x_hist = []
|
| 239 |
+
_y_hist = []
|
| 240 |
+
for fixation_index in self.included_fixations:
|
| 241 |
+
if fixation_index < -len(x_hist):
|
| 242 |
+
_x_hist.append(np.nan)
|
| 243 |
+
_y_hist.append(np.nan)
|
| 244 |
+
else:
|
| 245 |
+
_x_hist.append(x_hist[fixation_index])
|
| 246 |
+
_y_hist.append(y_hist[fixation_index])
|
| 247 |
+
x_hist = np.array(_x_hist)
|
| 248 |
+
y_hist = np.array(_y_hist)
|
| 249 |
+
else:
|
| 250 |
+
print("Not missing")
|
| 251 |
+
x_hist = x_hist[self.included_fixations]
|
| 252 |
+
y_hist = y_hist[self.included_fixations]
|
| 253 |
+
|
| 254 |
+
data = {
|
| 255 |
+
"image": image,
|
| 256 |
+
"x": np.array([self.fixations.x_int[key]], dtype=int),
|
| 257 |
+
"y": np.array([self.fixations.y_int[key]], dtype=int),
|
| 258 |
+
"x_hist": x_hist,
|
| 259 |
+
"y_hist": y_hist,
|
| 260 |
+
"centerbias": centerbias_prediction,
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
if self.average == 'image':
|
| 264 |
+
data['weight'] = 1.0 / self.fixation_counts[n]
|
| 265 |
+
else:
|
| 266 |
+
data['weight'] = 1.0
|
| 267 |
+
|
| 268 |
+
if self.transform is not None:
|
| 269 |
+
return self.transform(data)
|
| 270 |
+
|
| 271 |
+
return data
|
| 272 |
+
|
| 273 |
+
def __len__(self):
|
| 274 |
+
return len(self.fixations)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class FixationMaskTransform(object):
|
| 278 |
+
def __init__(self, sparse=True):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.sparse = sparse
|
| 281 |
+
|
| 282 |
+
def __call__(self, item):
|
| 283 |
+
shape = torch.Size([item['image'].shape[1], item['image'].shape[2]])
|
| 284 |
+
x = item.pop('x')
|
| 285 |
+
y = item.pop('y')
|
| 286 |
+
|
| 287 |
+
# inds, values = x_y_to_sparse_indices(x, y)
|
| 288 |
+
inds = np.array([y, x])
|
| 289 |
+
values = np.ones(len(y), dtype=int)
|
| 290 |
+
|
| 291 |
+
mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape)
|
| 292 |
+
mask = mask.coalesce()
|
| 293 |
+
# sparse tensors don't work with workers...
|
| 294 |
+
if not self.sparse:
|
| 295 |
+
mask = mask.to_dense()
|
| 296 |
+
|
| 297 |
+
item['fixation_mask'] = mask
|
| 298 |
+
|
| 299 |
+
return item
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class ImageDatasetSampler(torch.utils.data.Sampler):
|
| 303 |
+
def __init__(self, data_source, batch_size=1, ratio_used=1.0, shuffle=True):
|
| 304 |
+
self.ratio_used = ratio_used
|
| 305 |
+
self.shuffle = shuffle
|
| 306 |
+
|
| 307 |
+
shapes = data_source.get_shapes()
|
| 308 |
+
unique_shapes = sorted(set(shapes))
|
| 309 |
+
|
| 310 |
+
shape_indices = [[] for shape in unique_shapes]
|
| 311 |
+
|
| 312 |
+
for k, shape in enumerate(shapes):
|
| 313 |
+
shape_indices[unique_shapes.index(shape)].append(k)
|
| 314 |
+
|
| 315 |
+
if self.shuffle:
|
| 316 |
+
for indices in shape_indices:
|
| 317 |
+
random.shuffle(indices)
|
| 318 |
+
|
| 319 |
+
self.batches = sum([chunked(indices, size=batch_size) for indices in shape_indices], [])
|
| 320 |
+
|
| 321 |
+
def __iter__(self):
|
| 322 |
+
if self.shuffle:
|
| 323 |
+
indices = torch.randperm(len(self.batches))
|
| 324 |
+
else:
|
| 325 |
+
indices = range(len(self.batches))
|
| 326 |
+
|
| 327 |
+
if self.ratio_used < 1.0:
|
| 328 |
+
indices = indices[:int(self.ratio_used * len(indices))]
|
| 329 |
+
|
| 330 |
+
return iter(self.batches[i] for i in indices)
|
| 331 |
+
|
| 332 |
+
def __len__(self):
|
| 333 |
+
return int(self.ratio_used * len(self.batches))
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def _export_dataset_to_lmdb(stimuli: pysaliency.FileStimuli, centerbias_model: pysaliency.Model, lmdb_path, write_frequency=100):
|
| 337 |
+
lmdb_path = os.path.expanduser(lmdb_path)
|
| 338 |
+
isdir = os.path.isdir(lmdb_path)
|
| 339 |
+
|
| 340 |
+
print("Generate LMDB to %s" % lmdb_path)
|
| 341 |
+
db = lmdb.open(lmdb_path, subdir=isdir,
|
| 342 |
+
map_size=1099511627776 * 2, readonly=False,
|
| 343 |
+
meminit=False, map_async=True)
|
| 344 |
+
|
| 345 |
+
txn = db.begin(write=True)
|
| 346 |
+
for idx, stimulus in enumerate(tqdm(stimuli)):
|
| 347 |
+
key = u'{}'.format(idx).encode('ascii')
|
| 348 |
+
|
| 349 |
+
previous_data = txn.get(key)
|
| 350 |
+
if previous_data:
|
| 351 |
+
continue
|
| 352 |
+
|
| 353 |
+
#timulus_data = stimulus.stimulus_data
|
| 354 |
+
stimulus_filename = stimuli.filenames[idx]
|
| 355 |
+
centerbias = centerbias_model.log_density(stimulus)
|
| 356 |
+
|
| 357 |
+
txn.put(
|
| 358 |
+
key,
|
| 359 |
+
_encode_filestimulus_item(stimulus_filename, centerbias)
|
| 360 |
+
)
|
| 361 |
+
if idx % write_frequency == 0:
|
| 362 |
+
#print("[%d/%d]" % (idx, len(stimuli)))
|
| 363 |
+
#print("stimulus ids", len(stimuli.stimulus_ids._cache))
|
| 364 |
+
#print("stimuli.cached", stimuli.cached)
|
| 365 |
+
#print("stimuli", len(stimuli.stimuli._cache))
|
| 366 |
+
#print("centerbias", len(centerbias_model._cache._cache))
|
| 367 |
+
txn.commit()
|
| 368 |
+
txn = db.begin(write=True)
|
| 369 |
+
|
| 370 |
+
# finish iterating through dataset
|
| 371 |
+
txn.commit()
|
| 372 |
+
#keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
|
| 373 |
+
#with db.begin(write=True) as txn:
|
| 374 |
+
# txn.put(b'__keys__', dumps_pyarrow(keys))
|
| 375 |
+
# txn.put(b'__len__', dumps_pyarrow(len(keys)))
|
| 376 |
+
|
| 377 |
+
print("Flushing database ...")
|
| 378 |
+
db.sync()
|
| 379 |
+
db.close()
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def _encode_filestimulus_item(filename, centerbias):
|
| 383 |
+
with open(filename, 'rb') as f:
|
| 384 |
+
image_bytes = f.read()
|
| 385 |
+
|
| 386 |
+
buffer = io.BytesIO()
|
| 387 |
+
pickle.dump({'image': image_bytes, 'centerbias': centerbias}, buffer)
|
| 388 |
+
buffer.seek(0)
|
| 389 |
+
return buffer.read()
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _get_image_data_from_lmdb(lmdb_env, n):
|
| 393 |
+
key = '{}'.format(n).encode('ascii')
|
| 394 |
+
with lmdb_env.begin(write=False) as txn:
|
| 395 |
+
byteflow = txn.get(key)
|
| 396 |
+
data = pickle.loads(byteflow)
|
| 397 |
+
buffer = io.BytesIO(data['image'])
|
| 398 |
+
buffer.seek(0)
|
| 399 |
+
image = np.array(Image.open(buffer).convert('RGB'))
|
| 400 |
+
centerbias_prediction = data['centerbias']
|
| 401 |
+
image = image.transpose(2, 0, 1)
|
| 402 |
+
|
| 403 |
+
return image, centerbias_prediction
|
deepgaze_pytorch/deepgaze1.py
ADDED
|
@@ -0,0 +1,42 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+from torch.utils import model_zoo
+
+from .features.alexnet import RGBalexnet
+from .modules import FeatureExtractor, Finalizer, DeepGazeII as TorchDeepGazeII
+
+
+class DeepGazeI(TorchDeepGazeII):
+    """DeepGaze I model
+
+    Please note that this version of DeepGaze I is not exactly the one from the original paper.
+    The original model used caffe for AlexNet and theano for the linear readout and was trained using the SFO optimizer.
+    Here, we use the torch implementation of AlexNet (without any adaptations), which doesn't use the two-stream architecture,
+    and the DeepGaze II torch implementation with a simple linear readout network.
+    The model has been retrained with Adam, but still on the same dataset (all images of MIT1003 which are of size 1024x768).
+    Also, we don't use the sparsity penalty anymore.
+
+    Reference:
+    Kümmerer, M., Theis, L., & Bethge, M. (2015). Deep Gaze I: Boosting Saliency Prediction with Feature Maps Trained on ImageNet. ICLR Workshop Track. http://arxiv.org/abs/1411.1045
+    """
+    def __init__(self, pretrained=True):
+        features = RGBalexnet()
+        feature_extractor = FeatureExtractor(features, ['1.features.10'])
+
+        readout_network = nn.Sequential(OrderedDict([
+            ('conv0', nn.Conv2d(256, 1, (1, 1), bias=False)),
+        ]))
+
+        super().__init__(
+            features=feature_extractor,
+            readout_network=readout_network,
+            downsample=2,
+            readout_factor=4,
+            saliency_map_factor=4,
+        )
+
+        if pretrained:
+            self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.01/deepgaze1.pth', map_location=torch.device('cpu')))
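A hedged usage sketch for DeepGazeI. The forward-call convention (an image tensor plus a log centerbias tensor) follows the upstream DeepGaze examples and is an assumption here, since modules.py is not shown in this diff; the tensor sizes are placeholders matching the MIT1003 resolution mentioned in the docstring:

```python
import torch
from deepgaze_pytorch import DeepGazeI

model = DeepGazeI(pretrained=True)
image = torch.zeros(1, 3, 768, 1024)      # dummy RGB batch (N, C, H, W)
centerbias = torch.zeros(1, 768, 1024)    # uniform log centerbias placeholder
log_density = model(image, centerbias)    # predicted log fixation density
```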
deepgaze_pytorch/deepgaze2e.py
ADDED
|
@@ -0,0 +1,151 @@
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import importlib
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from torch.utils import model_zoo
|
| 11 |
+
|
| 12 |
+
from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture, MixtureModel
|
| 13 |
+
|
| 14 |
+
from .layers import (
|
| 15 |
+
Conv2dMultiInput,
|
| 16 |
+
LayerNorm,
|
| 17 |
+
LayerNormMultiInput,
|
| 18 |
+
Bias,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
BACKBONES = [
|
| 23 |
+
{
|
| 24 |
+
'type': 'deepgaze_pytorch.features.shapenet.RGBShapeNetC',
|
| 25 |
+
'used_features': [
|
| 26 |
+
'1.module.layer3.0.conv2',
|
| 27 |
+
'1.module.layer3.3.conv2',
|
| 28 |
+
'1.module.layer3.5.conv1',
|
| 29 |
+
'1.module.layer3.5.conv2',
|
| 30 |
+
'1.module.layer4.1.conv2',
|
| 31 |
+
'1.module.layer4.2.conv2',
|
| 32 |
+
],
|
| 33 |
+
'channels': 2048,
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
'type': 'deepgaze_pytorch.features.efficientnet.RGBEfficientNetB5',
|
| 37 |
+
'used_features': [
|
| 38 |
+
'1._blocks.24._depthwise_conv',
|
| 39 |
+
'1._blocks.26._depthwise_conv',
|
| 40 |
+
'1._blocks.35._project_conv',
|
| 41 |
+
],
|
| 42 |
+
'channels': 2416,
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
'type': 'deepgaze_pytorch.features.densenet.RGBDenseNet201',
|
| 46 |
+
'used_features': [
|
| 47 |
+
'1.features.denseblock4.denselayer32.norm1',
|
| 48 |
+
'1.features.denseblock4.denselayer32.conv1',
|
| 49 |
+
'1.features.denseblock4.denselayer31.conv2',
|
| 50 |
+
],
|
| 51 |
+
'channels': 2048,
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
'type': 'deepgaze_pytorch.features.resnext.RGBResNext50',
|
| 55 |
+
'used_features': [
|
| 56 |
+
'1.layer3.5.conv1',
|
| 57 |
+
'1.layer3.5.conv2',
|
| 58 |
+
'1.layer3.4.conv2',
|
| 59 |
+
'1.layer4.2.conv2',
|
| 60 |
+
],
|
| 61 |
+
'channels': 2560,
|
| 62 |
+
},
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def build_saliency_network(input_channels):
|
| 67 |
+
return nn.Sequential(OrderedDict([
|
| 68 |
+
('layernorm0', LayerNorm(input_channels)),
|
| 69 |
+
('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False)),
|
| 70 |
+
('bias0', Bias(8)),
|
| 71 |
+
('softplus0', nn.Softplus()),
|
| 72 |
+
|
| 73 |
+
('layernorm1', LayerNorm(8)),
|
| 74 |
+
('conv1', nn.Conv2d(8, 16, (1, 1), bias=False)),
|
| 75 |
+
('bias1', Bias(16)),
|
| 76 |
+
('softplus1', nn.Softplus()),
|
| 77 |
+
|
| 78 |
+
('layernorm2', LayerNorm(16)),
|
| 79 |
+
('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
|
| 80 |
+
('bias2', Bias(1)),
|
| 81 |
+
('softplus3', nn.Softplus()),
|
| 82 |
+
]))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def build_fixation_selection_network():
|
| 86 |
+
return nn.Sequential(OrderedDict([
|
| 87 |
+
('layernorm0', LayerNormMultiInput([1, 0])),
|
| 88 |
+
('conv0', Conv2dMultiInput([1, 0], 128, (1, 1), bias=False)),
|
| 89 |
+
('bias0', Bias(128)),
|
| 90 |
+
('softplus0', nn.Softplus()),
|
| 91 |
+
|
| 92 |
+
('layernorm1', LayerNorm(128)),
|
| 93 |
+
('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
|
| 94 |
+
('bias1', Bias(16)),
|
| 95 |
+
('softplus1', nn.Softplus()),
|
| 96 |
+
|
| 97 |
+
('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
|
| 98 |
+
]))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def build_deepgaze_mixture(backbone_config, components=10):
|
| 102 |
+
feature_class = import_class(backbone_config['type'])
|
| 103 |
+
features = feature_class()
|
| 104 |
+
|
| 105 |
+
feature_extractor = FeatureExtractor(features, backbone_config['used_features'])
|
| 106 |
+
|
| 107 |
+
saliency_networks = []
|
| 108 |
+
scanpath_networks = []
|
| 109 |
+
fixation_selection_networks = []
|
| 110 |
+
finalizers = []
|
| 111 |
+
for component in range(components):
|
| 112 |
+
saliency_network = build_saliency_network(backbone_config['channels'])
|
| 113 |
+
fixation_selection_network = build_fixation_selection_network()
|
| 114 |
+
|
| 115 |
+
saliency_networks.append(saliency_network)
|
| 116 |
+
scanpath_networks.append(None)
|
| 117 |
+
fixation_selection_networks.append(fixation_selection_network)
|
| 118 |
+
finalizers.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=2))
|
| 119 |
+
|
| 120 |
+
return DeepGazeIIIMixture(
|
| 121 |
+
features=feature_extractor,
|
| 122 |
+
saliency_networks=saliency_networks,
|
| 123 |
+
scanpath_networks=scanpath_networks,
|
| 124 |
+
fixation_selection_networks=fixation_selection_networks,
|
| 125 |
+
finalizers=finalizers,
|
| 126 |
+
downsample=2,
|
| 127 |
+
readout_factor=16,
|
| 128 |
+
saliency_map_factor=2,
|
| 129 |
+
included_fixations=[],
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class DeepGazeIIE(MixtureModel):
|
| 134 |
+
"""DeepGazeIIE model
|
| 135 |
+
|
| 136 |
+
:note
|
| 137 |
+
See Linardos, A., Kümmerer, M., Press, O., & Bethge, M. (2021). Calibrated prediction in and out-of-domain for state-of-the-art saliency modeling. ArXiv:2105.12441 [Cs], http://arxiv.org/abs/2105.12441
|
| 138 |
+
"""
|
| 139 |
+
def __init__(self, pretrained=True):
|
| 140 |
+
# we average over 3 instances per backbone, each instance has 10 crossvalidation folds
|
| 141 |
+
backbone_models = [build_deepgaze_mixture(backbone_config, components=3 * 10) for backbone_config in BACKBONES]
|
| 142 |
+
super().__init__(backbone_models)
|
| 143 |
+
|
| 144 |
+
if pretrained:
|
| 145 |
+
self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/deepgaze2e.pth', map_location=torch.device('cpu')))
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def import_class(name):
|
| 149 |
+
module_name, class_name = name.rsplit('.', 1)
|
| 150 |
+
module = importlib.import_module(module_name)
|
| 151 |
+
return getattr(module, class_name)
|
deepgaze_pytorch/deepgaze3.py
ADDED
|
@@ -0,0 +1,110 @@
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from torch.utils import model_zoo
|
| 8 |
+
|
| 9 |
+
from .features.densenet import RGBDenseNet201
|
| 10 |
+
from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture
|
| 11 |
+
from .layers import FlexibleScanpathHistoryEncoding
|
| 12 |
+
|
| 13 |
+
from .layers import (
|
| 14 |
+
Conv2dMultiInput,
|
| 15 |
+
LayerNorm,
|
| 16 |
+
LayerNormMultiInput,
|
| 17 |
+
Bias,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_saliency_network(input_channels):
|
| 22 |
+
return nn.Sequential(OrderedDict([
|
| 23 |
+
('layernorm0', LayerNorm(input_channels)),
|
| 24 |
+
('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False)),
|
| 25 |
+
('bias0', Bias(8)),
|
| 26 |
+
('softplus0', nn.Softplus()),
|
| 27 |
+
|
| 28 |
+
('layernorm1', LayerNorm(8)),
|
| 29 |
+
('conv1', nn.Conv2d(8, 16, (1, 1), bias=False)),
|
| 30 |
+
('bias1', Bias(16)),
|
| 31 |
+
('softplus1', nn.Softplus()),
|
| 32 |
+
|
| 33 |
+
('layernorm2', LayerNorm(16)),
|
| 34 |
+
('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
|
| 35 |
+
('bias2', Bias(1)),
|
| 36 |
+
('softplus2', nn.Softplus()),
|
| 37 |
+
]))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def build_scanpath_network():
|
| 41 |
+
return nn.Sequential(OrderedDict([
|
| 42 |
+
('encoding0', FlexibleScanpathHistoryEncoding(in_fixations=4, channels_per_fixation=3, out_channels=128, kernel_size=[1, 1], bias=True)),
|
| 43 |
+
('softplus0', nn.Softplus()),
|
| 44 |
+
|
| 45 |
+
('layernorm1', LayerNorm(128)),
|
| 46 |
+
('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
|
| 47 |
+
('bias1', Bias(16)),
|
| 48 |
+
('softplus1', nn.Softplus()),
|
| 49 |
+
]))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def build_fixation_selection_network():
|
| 53 |
+
return nn.Sequential(OrderedDict([
|
| 54 |
+
('layernorm0', LayerNormMultiInput([1, 16])),
|
| 55 |
+
('conv0', Conv2dMultiInput([1, 16], 128, (1, 1), bias=False)),
|
| 56 |
+
('bias0', Bias(128)),
|
| 57 |
+
('softplus0', nn.Softplus()),
|
| 58 |
+
|
| 59 |
+
('layernorm1', LayerNorm(128)),
|
| 60 |
+
('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
|
| 61 |
+
('bias1', Bias(16)),
|
| 62 |
+
('softplus1', nn.Softplus()),
|
| 63 |
+
|
| 64 |
+
('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
|
| 65 |
+
]))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class DeepGazeIII(DeepGazeIIIMixture):
|
| 69 |
+
"""DeepGazeIII model
|
| 70 |
+
|
| 71 |
+
:note
|
| 72 |
+
See Kümmerer, M., Bethge, M., & Wallis, T.S.A. (2022). DeepGaze III: Modeling free-viewing human scanpaths with deep learning. Journal of Vision 2022, https://doi.org/10.1167/jov.22.5.7
|
| 73 |
+
"""
|
| 74 |
+
def __init__(self, pretrained=True):
|
| 75 |
+
features = RGBDenseNet201()
|
| 76 |
+
|
| 77 |
+
feature_extractor = FeatureExtractor(features, [
|
| 78 |
+
'1.features.denseblock4.denselayer32.norm1',
|
| 79 |
+
'1.features.denseblock4.denselayer32.conv1',
|
| 80 |
+
'1.features.denseblock4.denselayer31.conv2',
|
| 81 |
+
])
|
| 82 |
+
|
| 83 |
+
saliency_networks = []
|
| 84 |
+
scanpath_networks = []
|
| 85 |
+
fixation_selection_networks = []
|
| 86 |
+
finalizers = []
|
| 87 |
+
for component in range(10):
|
| 88 |
+
saliency_network = build_saliency_network(2048)
|
| 89 |
+
scanpath_network = build_scanpath_network()
|
| 90 |
+
fixation_selection_network = build_fixation_selection_network()
|
| 91 |
+
|
| 92 |
+
saliency_networks.append(saliency_network)
|
| 93 |
+
scanpath_networks.append(scanpath_network)
|
| 94 |
+
fixation_selection_networks.append(fixation_selection_network)
|
| 95 |
+
finalizers.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=4))
|
| 96 |
+
|
| 97 |
+
super().__init__(
|
| 98 |
+
features=feature_extractor,
|
| 99 |
+
saliency_networks=saliency_networks,
|
| 100 |
+
scanpath_networks=scanpath_networks,
|
| 101 |
+
fixation_selection_networks=fixation_selection_networks,
|
| 102 |
+
finalizers=finalizers,
|
| 103 |
+
downsample=2,
|
| 104 |
+
readout_factor=4,
|
| 105 |
+
saliency_map_factor=4,
|
| 106 |
+
included_fixations=[-1, -2, -3, -4]
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
if pretrained:
|
| 110 |
+
self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.1.0/deepgaze3.pth', map_location=torch.device('cpu')))
|
deepgaze_pytorch/features/__init__.py
ADDED
File without changes

deepgaze_pytorch/features/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (164 Bytes).

deepgaze_pytorch/features/__pycache__/alexnet.cpython-39.pyc
ADDED
Binary file (836 Bytes).

deepgaze_pytorch/features/__pycache__/densenet.cpython-39.pyc
ADDED
Binary file (852 Bytes).

deepgaze_pytorch/features/__pycache__/efficientnet.cpython-39.pyc
ADDED
Binary file (1.25 kB).

deepgaze_pytorch/features/__pycache__/normalizer.cpython-39.pyc
ADDED
Binary file (1.1 kB).

deepgaze_pytorch/features/__pycache__/resnext.cpython-39.pyc
ADDED
Binary file (1.23 kB).

deepgaze_pytorch/features/__pycache__/shapenet.cpython-39.pyc
ADDED
Binary file (3.7 kB).

deepgaze_pytorch/features/alexnet.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .normalizer import Normalizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RGBalexnet(nn.Sequential):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super(RGBalexnet, self).__init__()
|
| 15 |
+
self.model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)
|
| 16 |
+
self.normalizer = Normalizer()
|
| 17 |
+
super(RGBalexnet, self).__init__(self.normalizer, self.model)
|
| 18 |
+
|
deepgaze_pytorch/features/bagnet.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This code is adapted from: https://github.com/wielandbrendel/bag-of-local-features-models
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
from torch.utils import model_zoo
|
| 10 |
+
|
| 11 |
+
from .normalizer import Normalizer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
| 16 |
+
|
| 17 |
+
__all__ = ['bagnet9', 'bagnet17', 'bagnet33']
|
| 18 |
+
|
| 19 |
+
model_urls = {
|
| 20 |
+
'bagnet9': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet8-34f4ccd2.pth.tar',
|
| 21 |
+
'bagnet17': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet16-105524de.pth.tar',
|
| 22 |
+
'bagnet33': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet32-2ddd53ed.pth.tar',
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Bottleneck(nn.Module):
|
| 27 |
+
expansion = 4
|
| 28 |
+
|
| 29 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, kernel_size=1):
|
| 30 |
+
super(Bottleneck, self).__init__()
|
| 31 |
+
# print('Creating bottleneck with kernel size {} and stride {} with padding {}'.format(kernel_size, stride, (kernel_size - 1) // 2))
|
| 32 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 33 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 34 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride,
|
| 35 |
+
padding=0, bias=False) # changed padding from (kernel_size - 1) // 2
|
| 36 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 37 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
| 38 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
| 39 |
+
self.relu = nn.ReLU(inplace=True)
|
| 40 |
+
self.downsample = downsample
|
| 41 |
+
self.stride = stride
|
| 42 |
+
|
| 43 |
+
def forward(self, x, **kwargs):
|
| 44 |
+
residual = x
|
| 45 |
+
|
| 46 |
+
out = self.conv1(x)
|
| 47 |
+
out = self.bn1(out)
|
| 48 |
+
out = self.relu(out)
|
| 49 |
+
|
| 50 |
+
out = self.conv2(out)
|
| 51 |
+
out = self.bn2(out)
|
| 52 |
+
out = self.relu(out)
|
| 53 |
+
|
| 54 |
+
out = self.conv3(out)
|
| 55 |
+
out = self.bn3(out)
|
| 56 |
+
|
| 57 |
+
if self.downsample is not None:
|
| 58 |
+
residual = self.downsample(x)
|
| 59 |
+
|
| 60 |
+
if residual.size(-1) != out.size(-1):
|
| 61 |
+
diff = residual.size(-1) - out.size(-1)
|
| 62 |
+
residual = residual[:,:,:-diff,:-diff]
|
| 63 |
+
|
| 64 |
+
out += residual
|
| 65 |
+
out = self.relu(out)
|
| 66 |
+
|
| 67 |
+
return out
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class BagNet(nn.Module):
|
| 71 |
+
|
| 72 |
+
def __init__(self, block, layers, strides=[1, 2, 2, 2], kernel3=[0, 0, 0, 0], num_classes=1000, avg_pool=True):
|
| 73 |
+
self.inplanes = 64
|
| 74 |
+
super(BagNet, self).__init__()
|
| 75 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, padding=0,
|
| 76 |
+
bias=False)
|
| 77 |
+
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0,
|
| 78 |
+
bias=False)
|
| 79 |
+
self.bn1 = nn.BatchNorm2d(64, momentum=0.001)
|
| 80 |
+
self.relu = nn.ReLU(inplace=True)
|
| 81 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], kernel3=kernel3[0], prefix='layer1')
|
| 82 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], kernel3=kernel3[1], prefix='layer2')
|
| 83 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], kernel3=kernel3[2], prefix='layer3')
|
| 84 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], kernel3=kernel3[3], prefix='layer4')
|
| 85 |
+
self.avgpool = nn.AvgPool2d(1, stride=1)
|
| 86 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
| 87 |
+
self.avg_pool = avg_pool
|
| 88 |
+
self.block = block
|
| 89 |
+
|
| 90 |
+
for m in self.modules():
|
| 91 |
+
if isinstance(m, nn.Conv2d):
|
| 92 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 93 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 94 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 95 |
+
m.weight.data.fill_(1)
|
| 96 |
+
m.bias.data.zero_()
|
| 97 |
+
|
| 98 |
+
def _make_layer(self, block, planes, blocks, stride=1, kernel3=0, prefix=''):
|
| 99 |
+
downsample = None
|
| 100 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 101 |
+
downsample = nn.Sequential(
|
| 102 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 103 |
+
kernel_size=1, stride=stride, bias=False),
|
| 104 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
layers = []
|
| 108 |
+
kernel = 1 if kernel3 == 0 else 3
|
| 109 |
+
layers.append(block(self.inplanes, planes, stride, downsample, kernel_size=kernel))
|
| 110 |
+
self.inplanes = planes * block.expansion
|
| 111 |
+
for i in range(1, blocks):
|
| 112 |
+
kernel = 1 if kernel3 <= i else 3
|
| 113 |
+
layers.append(block(self.inplanes, planes, kernel_size=kernel))
|
| 114 |
+
|
| 115 |
+
return nn.Sequential(*layers)
|
| 116 |
+
|
| 117 |
+
def forward(self, x):
|
| 118 |
+
x = self.conv1(x)
|
| 119 |
+
x = self.conv2(x)
|
| 120 |
+
x = self.bn1(x)
|
| 121 |
+
x = self.relu(x)
|
| 122 |
+
|
| 123 |
+
x = self.layer1(x)
|
| 124 |
+
x = self.layer2(x)
|
| 125 |
+
x = self.layer3(x)
|
| 126 |
+
x = self.layer4(x)
|
| 127 |
+
|
| 128 |
+
if self.avg_pool:
|
| 129 |
+
x = nn.AvgPool2d(x.size()[2], stride=1)(x)
|
| 130 |
+
x = x.view(x.size(0), -1)
|
| 131 |
+
x = self.fc(x)
|
| 132 |
+
else:
|
| 133 |
+
x = x.permute(0,2,3,1)
|
| 134 |
+
x = self.fc(x)
|
| 135 |
+
|
| 136 |
+
return x
|
| 137 |
+
|
| 138 |
+
def bagnet33(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
|
| 139 |
+
"""Constructs a Bagnet-33 model.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 143 |
+
"""
|
| 144 |
+
model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,1], **kwargs)
|
| 145 |
+
if pretrained:
|
| 146 |
+
model.load_state_dict(model_zoo.load_url(model_urls['bagnet33']))
|
| 147 |
+
return model
|
| 148 |
+
|
| 149 |
+
def bagnet17(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
|
| 150 |
+
"""Constructs a Bagnet-17 model.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 154 |
+
"""
|
| 155 |
+
model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,0], **kwargs)
|
| 156 |
+
if pretrained:
|
| 157 |
+
model.load_state_dict(model_zoo.load_url(model_urls['bagnet17']))
|
| 158 |
+
return model
|
| 159 |
+
|
| 160 |
+
def bagnet9(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
|
| 161 |
+
"""Constructs a Bagnet-9 model.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 165 |
+
"""
|
| 166 |
+
model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,0,0], **kwargs)
|
| 167 |
+
if pretrained:
|
| 168 |
+
model.load_state_dict(model_zoo.load_url(model_urls['bagnet9']))
|
| 169 |
+
return model
|
| 170 |
+
|
| 171 |
+
# --- DeepGaze Adaptation ----
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class RGBBagNet17(nn.Sequential):
|
| 177 |
+
def __init__(self):
|
| 178 |
+
super(RGBBagNet17, self).__init__()
|
| 179 |
+
self.bagnet = bagnet17(pretrained=True, avg_pool=False)
|
| 180 |
+
self.normalizer = Normalizer()
|
| 181 |
+
super(RGBBagNet17, self).__init__(self.normalizer, self.bagnet)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class RGBBagNet33(nn.Sequential):
|
| 185 |
+
def __init__(self):
|
| 186 |
+
super(RGBBagNet33, self).__init__()
|
| 187 |
+
self.bagnet = bagnet33(pretrained=True, avg_pool=False)
|
| 188 |
+
self.normalizer = Normalizer()
|
| 189 |
+
super(RGBBagNet33, self).__init__(self.normalizer, self.bagnet)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
|
deepgaze_pytorch/features/densenet.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .normalizer import Normalizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RGBDenseNet201(nn.Sequential):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super(RGBDenseNet201, self).__init__()
|
| 15 |
+
self.densenet = torch.hub.load('pytorch/vision:v0.6.0', 'densenet201', pretrained=True)
|
| 16 |
+
self.normalizer = Normalizer()
|
| 17 |
+
super(RGBDenseNet201, self).__init__(self.normalizer, self.densenet)
|
| 18 |
+
|
| 19 |
+
|
deepgaze_pytorch/features/efficientnet.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .efficientnet_pytorch import EfficientNet
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from .normalizer import Normalizer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RGBEfficientNetB5(nn.Sequential):
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super(RGBEfficientNetB5, self).__init__()
|
| 18 |
+
self.efficientnet = EfficientNet.from_pretrained('efficientnet-b5')
|
| 19 |
+
self.normalizer = Normalizer()
|
| 20 |
+
super(RGBEfficientNetB5, self).__init__(self.normalizer, self.efficientnet)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class RGBEfficientNetB7(nn.Sequential):
|
| 25 |
+
def __init__(self):
|
| 26 |
+
super(RGBEfficientNetB7, self).__init__()
|
| 27 |
+
self.efficientnet = EfficientNet.from_pretrained('efficientnet-b7')
|
| 28 |
+
self.normalizer = Normalizer()
|
| 29 |
+
super(RGBEfficientNetB7, self).__init__(self.normalizer, self.efficientnet)
|
| 30 |
+
|
| 31 |
+
|
deepgaze_pytorch/features/efficientnet_pytorch/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.6.3"
|
| 2 |
+
from .model import EfficientNet
|
| 3 |
+
from .utils import (
|
| 4 |
+
GlobalParams,
|
| 5 |
+
BlockArgs,
|
| 6 |
+
BlockDecoder,
|
| 7 |
+
efficientnet,
|
| 8 |
+
get_model_params,
|
| 9 |
+
)
|
| 10 |
+
|
deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (383 Bytes). View file
|
|
|
deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (6.98 kB). View file
|
|
|
deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
deepgaze_pytorch/features/efficientnet_pytorch/model.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
from .utils import (
|
| 6 |
+
round_filters,
|
| 7 |
+
round_repeats,
|
| 8 |
+
drop_connect,
|
| 9 |
+
get_same_padding_conv2d,
|
| 10 |
+
get_model_params,
|
| 11 |
+
efficientnet_params,
|
| 12 |
+
load_pretrained_weights,
|
| 13 |
+
Swish,
|
| 14 |
+
MemoryEfficientSwish,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
class MBConvBlock(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Mobile Inverted Residual Bottleneck Block
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
block_args (namedtuple): BlockArgs, see above
|
| 23 |
+
global_params (namedtuple): GlobalParam, see above
|
| 24 |
+
|
| 25 |
+
Attributes:
|
| 26 |
+
has_se (bool): Whether the block contains a Squeeze and Excitation layer.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, block_args, global_params):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self._block_args = block_args
|
| 32 |
+
self._bn_mom = 1 - global_params.batch_norm_momentum
|
| 33 |
+
self._bn_eps = global_params.batch_norm_epsilon
|
| 34 |
+
self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
|
| 35 |
+
self.id_skip = block_args.id_skip # skip connection and drop connect
|
| 36 |
+
|
| 37 |
+
# Get static or dynamic convolution depending on image size
|
| 38 |
+
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
|
| 39 |
+
|
| 40 |
+
# Expansion phase
|
| 41 |
+
inp = self._block_args.input_filters # number of input channels
|
| 42 |
+
oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
|
| 43 |
+
if self._block_args.expand_ratio != 1:
|
| 44 |
+
self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
|
| 45 |
+
self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
|
| 46 |
+
|
| 47 |
+
# Depthwise convolution phase
|
| 48 |
+
k = self._block_args.kernel_size
|
| 49 |
+
s = self._block_args.stride
|
| 50 |
+
self._depthwise_conv = Conv2d(
|
| 51 |
+
in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
|
| 52 |
+
kernel_size=k, stride=s, bias=False)
|
| 53 |
+
self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
|
| 54 |
+
|
| 55 |
+
# Squeeze and Excitation layer, if desired
|
| 56 |
+
if self.has_se:
|
| 57 |
+
num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
|
| 58 |
+
self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
|
| 59 |
+
self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
|
| 60 |
+
|
| 61 |
+
# Output phase
|
| 62 |
+
final_oup = self._block_args.output_filters
|
| 63 |
+
self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
|
| 64 |
+
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
|
| 65 |
+
self._swish = MemoryEfficientSwish()
|
| 66 |
+
|
| 67 |
+
def forward(self, inputs, drop_connect_rate=None):
|
| 68 |
+
"""
|
| 69 |
+
:param inputs: input tensor
|
| 70 |
+
:param drop_connect_rate: drop connect rate (float, between 0 and 1)
|
| 71 |
+
:return: output of block
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
# Expansion and Depthwise Convolution
|
| 75 |
+
x = inputs
|
| 76 |
+
if self._block_args.expand_ratio != 1:
|
| 77 |
+
x = self._swish(self._bn0(self._expand_conv(inputs)))
|
| 78 |
+
x = self._swish(self._bn1(self._depthwise_conv(x)))
|
| 79 |
+
|
| 80 |
+
# Squeeze and Excitation
|
| 81 |
+
if self.has_se:
|
| 82 |
+
x_squeezed = F.adaptive_avg_pool2d(x, 1)
|
| 83 |
+
x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
|
| 84 |
+
x = torch.sigmoid(x_squeezed) * x
|
| 85 |
+
|
| 86 |
+
x = self._bn2(self._project_conv(x))
|
| 87 |
+
|
| 88 |
+
# Skip connection and drop connect
|
| 89 |
+
input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
|
| 90 |
+
if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
|
| 91 |
+
if drop_connect_rate:
|
| 92 |
+
x = drop_connect(x, p=drop_connect_rate, training=self.training)
|
| 93 |
+
x = x + inputs # skip connection
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
def set_swish(self, memory_efficient=True):
|
| 97 |
+
"""Sets swish function as memory efficient (for training) or standard (for export)"""
|
| 98 |
+
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class EfficientNet(nn.Module):
|
| 102 |
+
"""
|
| 103 |
+
An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
blocks_args (list): A list of BlockArgs to construct blocks
|
| 107 |
+
global_params (namedtuple): A set of GlobalParams shared between blocks
|
| 108 |
+
|
| 109 |
+
Example:
|
| 110 |
+
model = EfficientNet.from_pretrained('efficientnet-b0')
|
| 111 |
+
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def __init__(self, blocks_args=None, global_params=None):
|
| 115 |
+
super().__init__()
|
| 116 |
+
assert isinstance(blocks_args, list), 'blocks_args should be a list'
|
| 117 |
+
assert len(blocks_args) > 0, 'block args must be greater than 0'
|
| 118 |
+
self._global_params = global_params
|
| 119 |
+
self._blocks_args = blocks_args
|
| 120 |
+
|
| 121 |
+
# Get static or dynamic convolution depending on image size
|
| 122 |
+
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
|
| 123 |
+
|
| 124 |
+
# Batch norm parameters
|
| 125 |
+
bn_mom = 1 - self._global_params.batch_norm_momentum
|
| 126 |
+
bn_eps = self._global_params.batch_norm_epsilon
|
| 127 |
+
|
| 128 |
+
# Stem
|
| 129 |
+
in_channels = 3 # rgb
|
| 130 |
+
out_channels = round_filters(32, self._global_params) # number of output channels
|
| 131 |
+
self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
|
| 132 |
+
self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
|
| 133 |
+
|
| 134 |
+
# Build blocks
|
| 135 |
+
self._blocks = nn.ModuleList([])
|
| 136 |
+
for block_args in self._blocks_args:
|
| 137 |
+
|
| 138 |
+
# Update block input and output filters based on depth multiplier.
|
| 139 |
+
block_args = block_args._replace(
|
| 140 |
+
input_filters=round_filters(block_args.input_filters, self._global_params),
|
| 141 |
+
output_filters=round_filters(block_args.output_filters, self._global_params),
|
| 142 |
+
num_repeat=round_repeats(block_args.num_repeat, self._global_params)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# The first block needs to take care of stride and filter size increase.
|
| 146 |
+
self._blocks.append(MBConvBlock(block_args, self._global_params))
|
| 147 |
+
if block_args.num_repeat > 1:
|
| 148 |
+
block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
|
| 149 |
+
for _ in range(block_args.num_repeat - 1):
|
| 150 |
+
self._blocks.append(MBConvBlock(block_args, self._global_params))
|
| 151 |
+
|
| 152 |
+
# Head
|
| 153 |
+
in_channels = block_args.output_filters # output of final block
|
| 154 |
+
out_channels = round_filters(1280, self._global_params)
|
| 155 |
+
self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
| 156 |
+
self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
|
| 157 |
+
|
| 158 |
+
# Final linear layer
|
| 159 |
+
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
|
| 160 |
+
self._dropout = nn.Dropout(self._global_params.dropout_rate)
|
| 161 |
+
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
|
| 162 |
+
self._swish = MemoryEfficientSwish()
|
| 163 |
+
|
| 164 |
+
def set_swish(self, memory_efficient=True):
|
| 165 |
+
"""Sets swish function as memory efficient (for training) or standard (for export)"""
|
| 166 |
+
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
| 167 |
+
for block in self._blocks:
|
| 168 |
+
block.set_swish(memory_efficient)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def extract_features(self, inputs):
|
| 172 |
+
""" Returns output of the final convolution layer """
|
| 173 |
+
|
| 174 |
+
# Stem
|
| 175 |
+
x = self._swish(self._bn0(self._conv_stem(inputs)))
|
| 176 |
+
|
| 177 |
+
# Blocks
|
| 178 |
+
for idx, block in enumerate(self._blocks):
|
| 179 |
+
drop_connect_rate = self._global_params.drop_connect_rate
|
| 180 |
+
if drop_connect_rate:
|
| 181 |
+
drop_connect_rate *= float(idx) / len(self._blocks)
|
| 182 |
+
x = block(x, drop_connect_rate=drop_connect_rate)
|
| 183 |
+
|
| 184 |
+
# Head
|
| 185 |
+
x = self._swish(self._bn1(self._conv_head(x)))
|
| 186 |
+
|
| 187 |
+
return x
|
| 188 |
+
|
| 189 |
+
def forward(self, inputs):
|
| 190 |
+
""" Calls extract_features to extract features, applies final linear layer, and returns logits. """
|
| 191 |
+
bs = inputs.size(0)
|
| 192 |
+
# Convolution layers
|
| 193 |
+
x = self.extract_features(inputs)
|
| 194 |
+
|
| 195 |
+
# Pooling and final linear layer
|
| 196 |
+
x = self._avg_pooling(x)
|
| 197 |
+
x = x.view(bs, -1)
|
| 198 |
+
x = self._dropout(x)
|
| 199 |
+
x = self._fc(x)
|
| 200 |
+
return x
|
| 201 |
+
|
| 202 |
+
@classmethod
|
| 203 |
+
def from_name(cls, model_name, override_params=None):
|
| 204 |
+
cls._check_model_name_is_valid(model_name)
|
| 205 |
+
blocks_args, global_params = get_model_params(model_name, override_params)
|
| 206 |
+
return cls(blocks_args, global_params)
|
| 207 |
+
|
| 208 |
+
@classmethod
|
| 209 |
+
def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
|
| 210 |
+
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
|
| 211 |
+
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
|
| 212 |
+
if in_channels != 3:
|
| 213 |
+
Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
|
| 214 |
+
out_channels = round_filters(32, model._global_params)
|
| 215 |
+
model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
|
| 216 |
+
return model
|
| 217 |
+
|
| 218 |
+
@classmethod
|
| 219 |
+
def get_image_size(cls, model_name):
|
| 220 |
+
cls._check_model_name_is_valid(model_name)
|
| 221 |
+
_, _, res, _ = efficientnet_params(model_name)
|
| 222 |
+
return res
|
| 223 |
+
|
| 224 |
+
@classmethod
|
| 225 |
+
def _check_model_name_is_valid(cls, model_name):
|
| 226 |
+
""" Validates model name. """
|
| 227 |
+
valid_models = ['efficientnet-b'+str(i) for i in range(9)]
|
| 228 |
+
if model_name not in valid_models:
|
| 229 |
+
raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
|
deepgaze_pytorch/features/efficientnet_pytorch/utils.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file contains helper functions for building the model and for loading model parameters.
|
| 3 |
+
These helper functions are built to mirror those in the official TensorFlow implementation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import math
|
| 8 |
+
import collections
|
| 9 |
+
from functools import partial
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
from torch.nn import functional as F
|
| 13 |
+
from torch.utils import model_zoo
|
| 14 |
+
|
| 15 |
+
########################################################################
|
| 16 |
+
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
|
| 17 |
+
########################################################################
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Parameters for the entire model (stem, all blocks, and head)
|
| 21 |
+
GlobalParams = collections.namedtuple('GlobalParams', [
|
| 22 |
+
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
|
| 23 |
+
'num_classes', 'width_coefficient', 'depth_coefficient',
|
| 24 |
+
'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
|
| 25 |
+
|
| 26 |
+
# Parameters for an individual model block
|
| 27 |
+
BlockArgs = collections.namedtuple('BlockArgs', [
|
| 28 |
+
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
|
| 29 |
+
'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
|
| 30 |
+
|
| 31 |
+
# Change namedtuple defaults
|
| 32 |
+
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
|
| 33 |
+
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SwishImplementation(torch.autograd.Function):
|
| 37 |
+
@staticmethod
|
| 38 |
+
def forward(ctx, i):
|
| 39 |
+
result = i * torch.sigmoid(i)
|
| 40 |
+
ctx.save_for_backward(i)
|
| 41 |
+
return result
|
| 42 |
+
|
| 43 |
+
@staticmethod
|
| 44 |
+
def backward(ctx, grad_output):
|
| 45 |
+
i = ctx.saved_variables[0]
|
| 46 |
+
sigmoid_i = torch.sigmoid(i)
|
| 47 |
+
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class MemoryEfficientSwish(nn.Module):
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
return SwishImplementation.apply(x)
|
| 53 |
+
|
| 54 |
+
class Swish(nn.Module):
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
return x * torch.sigmoid(x)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def round_filters(filters, global_params):
|
| 60 |
+
""" Calculate and round number of filters based on depth multiplier. """
|
| 61 |
+
multiplier = global_params.width_coefficient
|
| 62 |
+
if not multiplier:
|
| 63 |
+
return filters
|
| 64 |
+
divisor = global_params.depth_divisor
|
| 65 |
+
min_depth = global_params.min_depth
|
| 66 |
+
filters *= multiplier
|
| 67 |
+
min_depth = min_depth or divisor
|
| 68 |
+
new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
|
| 69 |
+
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
|
| 70 |
+
new_filters += divisor
|
| 71 |
+
return int(new_filters)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def round_repeats(repeats, global_params):
|
| 75 |
+
""" Round number of filters based on depth multiplier. """
|
| 76 |
+
multiplier = global_params.depth_coefficient
|
| 77 |
+
if not multiplier:
|
| 78 |
+
return repeats
|
| 79 |
+
return int(math.ceil(multiplier * repeats))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def drop_connect(inputs, p, training):
|
| 83 |
+
""" Drop connect. """
|
| 84 |
+
if not training: return inputs
|
| 85 |
+
batch_size = inputs.shape[0]
|
| 86 |
+
keep_prob = 1 - p
|
| 87 |
+
random_tensor = keep_prob
|
| 88 |
+
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
|
| 89 |
+
binary_tensor = torch.floor(random_tensor)
|
| 90 |
+
output = inputs / keep_prob * binary_tensor
|
| 91 |
+
return output
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_same_padding_conv2d(image_size=None):
|
| 95 |
+
""" Chooses static padding if you have specified an image size, and dynamic padding otherwise.
|
| 96 |
+
Static padding is necessary for ONNX exporting of models. """
|
| 97 |
+
if image_size is None:
|
| 98 |
+
return Conv2dDynamicSamePadding
|
| 99 |
+
else:
|
| 100 |
+
return partial(Conv2dStaticSamePadding, image_size=image_size)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Conv2dDynamicSamePadding(nn.Conv2d):
|
| 104 |
+
""" 2D Convolutions like TensorFlow, for a dynamic image size """
|
| 105 |
+
|
| 106 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
|
| 107 |
+
super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
| 108 |
+
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
|
| 109 |
+
|
| 110 |
+
def forward(self, x):
|
| 111 |
+
ih, iw = x.size()[-2:]
|
| 112 |
+
kh, kw = self.weight.size()[-2:]
|
| 113 |
+
sh, sw = self.stride
|
| 114 |
+
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
| 115 |
+
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
|
| 116 |
+
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
|
| 117 |
+
if pad_h > 0 or pad_w > 0:
|
| 118 |
+
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
| 119 |
+
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Conv2dStaticSamePadding(nn.Conv2d):
|
| 123 |
+
""" 2D Convolutions like TensorFlow, for a fixed image size"""
|
| 124 |
+
|
| 125 |
+
def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
|
| 126 |
+
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
|
| 127 |
+
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
|
| 128 |
+
|
| 129 |
+
# Calculate padding based on image size and save it
|
| 130 |
+
assert image_size is not None
|
| 131 |
+
ih, iw = image_size if type(image_size) == list else [image_size, image_size]
|
| 132 |
+
kh, kw = self.weight.size()[-2:]
|
| 133 |
+
sh, sw = self.stride
|
| 134 |
+
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
| 135 |
+
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
|
| 136 |
+
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
|
| 137 |
+
if pad_h > 0 or pad_w > 0:
|
| 138 |
+
self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
|
| 139 |
+
else:
|
| 140 |
+
self.static_padding = Identity()
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
x = self.static_padding(x)
|
| 144 |
+
x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
| 145 |
+
return x
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Identity(nn.Module):
|
| 149 |
+
def __init__(self, ):
|
| 150 |
+
super(Identity, self).__init__()
|
| 151 |
+
|
| 152 |
+
def forward(self, input):
|
| 153 |
+
return input
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
########################################################################
|
| 157 |
+
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
|
| 158 |
+
########################################################################
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def efficientnet_params(model_name):
|
| 162 |
+
""" Map EfficientNet model name to parameter coefficients. """
|
| 163 |
+
params_dict = {
|
| 164 |
+
# Coefficients: width,depth,res,dropout
|
| 165 |
+
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
|
| 166 |
+
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
|
| 167 |
+
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
|
| 168 |
+
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
|
| 169 |
+
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
|
| 170 |
+
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
|
| 171 |
+
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
|
| 172 |
+
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
|
| 173 |
+
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
|
| 174 |
+
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
|
| 175 |
+
}
|
| 176 |
+
return params_dict[model_name]
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class BlockDecoder(object):
|
| 180 |
+
""" Block Decoder for readability, straight from the official TensorFlow repository """
|
| 181 |
+
|
| 182 |
+
@staticmethod
|
| 183 |
+
def _decode_block_string(block_string):
|
| 184 |
+
""" Gets a block through a string notation of arguments. """
|
| 185 |
+
assert isinstance(block_string, str)
|
| 186 |
+
|
| 187 |
+
ops = block_string.split('_')
|
| 188 |
+
options = {}
|
| 189 |
+
for op in ops:
|
| 190 |
+
splits = re.split(r'(\d.*)', op)
|
| 191 |
+
if len(splits) >= 2:
|
| 192 |
+
key, value = splits[:2]
|
| 193 |
+
options[key] = value
|
| 194 |
+
|
| 195 |
+
# Check stride
|
| 196 |
+
assert (('s' in options and len(options['s']) == 1) or
|
| 197 |
+
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
|
| 198 |
+
|
| 199 |
+
return BlockArgs(
|
| 200 |
+
kernel_size=int(options['k']),
|
| 201 |
+
num_repeat=int(options['r']),
|
| 202 |
+
input_filters=int(options['i']),
|
| 203 |
+
output_filters=int(options['o']),
|
| 204 |
+
expand_ratio=int(options['e']),
|
| 205 |
+
id_skip=('noskip' not in block_string),
|
| 206 |
+
se_ratio=float(options['se']) if 'se' in options else None,
|
| 207 |
+
stride=[int(options['s'][0])])
|
| 208 |
+
|
| 209 |
+
@staticmethod
|
| 210 |
+
def _encode_block_string(block):
|
| 211 |
+
"""Encodes a block to a string."""
|
| 212 |
+
args = [
|
| 213 |
+
'r%d' % block.num_repeat,
|
| 214 |
+
'k%d' % block.kernel_size,
|
| 215 |
+
's%d%d' % (block.strides[0], block.strides[1]),
|
| 216 |
+
'e%s' % block.expand_ratio,
|
| 217 |
+
'i%d' % block.input_filters,
|
| 218 |
+
'o%d' % block.output_filters
|
| 219 |
+
]
|
| 220 |
+
if 0 < block.se_ratio <= 1:
|
| 221 |
+
args.append('se%s' % block.se_ratio)
|
| 222 |
+
if block.id_skip is False:
|
| 223 |
+
args.append('noskip')
|
| 224 |
+
return '_'.join(args)
|
| 225 |
+
|
| 226 |
+
@staticmethod
|
| 227 |
+
def decode(string_list):
|
| 228 |
+
"""
|
| 229 |
+
Decodes a list of string notations to specify blocks inside the network.
|
| 230 |
+
|
| 231 |
+
:param string_list: a list of strings, each string is a notation of block
|
| 232 |
+
:return: a list of BlockArgs namedtuples of block args
|
| 233 |
+
"""
|
| 234 |
+
assert isinstance(string_list, list)
|
| 235 |
+
blocks_args = []
|
| 236 |
+
for block_string in string_list:
|
| 237 |
+
blocks_args.append(BlockDecoder._decode_block_string(block_string))
|
| 238 |
+
return blocks_args
|
| 239 |
+
|
| 240 |
+
@staticmethod
|
| 241 |
+
def encode(blocks_args):
|
| 242 |
+
"""
|
| 243 |
+
Encodes a list of BlockArgs to a list of strings.
|
| 244 |
+
|
| 245 |
+
:param blocks_args: a list of BlockArgs namedtuples of block args
|
| 246 |
+
:return: a list of strings, each string is a notation of block
|
| 247 |
+
"""
|
| 248 |
+
block_strings = []
|
| 249 |
+
for block in blocks_args:
|
| 250 |
+
block_strings.append(BlockDecoder._encode_block_string(block))
|
| 251 |
+
return block_strings
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
|
| 255 |
+
drop_connect_rate=0.2, image_size=None, num_classes=1000):
|
| 256 |
+
""" Creates a efficientnet model. """
|
| 257 |
+
|
| 258 |
+
blocks_args = [
|
| 259 |
+
'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
|
| 260 |
+
'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
|
| 261 |
+
'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
|
| 262 |
+
'r1_k3_s11_e6_i192_o320_se0.25',
|
| 263 |
+
]
|
| 264 |
+
blocks_args = BlockDecoder.decode(blocks_args)
|
| 265 |
+
|
| 266 |
+
global_params = GlobalParams(
|
| 267 |
+
batch_norm_momentum=0.99,
|
| 268 |
+
batch_norm_epsilon=1e-3,
|
| 269 |
+
dropout_rate=dropout_rate,
|
| 270 |
+
drop_connect_rate=drop_connect_rate,
|
| 271 |
+
# data_format='channels_last', # removed, this is always true in PyTorch
|
| 272 |
+
num_classes=num_classes,
|
| 273 |
+
width_coefficient=width_coefficient,
|
| 274 |
+
depth_coefficient=depth_coefficient,
|
| 275 |
+
depth_divisor=8,
|
| 276 |
+
min_depth=None,
|
| 277 |
+
image_size=image_size,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
return blocks_args, global_params
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def get_model_params(model_name, override_params):
|
| 284 |
+
""" Get the block args and global params for a given model """
|
| 285 |
+
if model_name.startswith('efficientnet'):
|
| 286 |
+
w, d, s, p = efficientnet_params(model_name)
|
| 287 |
+
# note: all models have drop connect rate = 0.2
|
| 288 |
+
blocks_args, global_params = efficientnet(
|
| 289 |
+
width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
|
| 290 |
+
else:
|
| 291 |
+
raise NotImplementedError('model name is not pre-defined: %s' % model_name)
|
| 292 |
+
if override_params:
|
| 293 |
+
# ValueError will be raised here if override_params has fields not included in global_params.
|
| 294 |
+
global_params = global_params._replace(**override_params)
|
| 295 |
+
return blocks_args, global_params
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
url_map = {
|
| 299 |
+
'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
|
| 300 |
+
'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
|
| 301 |
+
'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
|
| 302 |
+
'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
|
| 303 |
+
'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
|
| 304 |
+
'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
|
| 305 |
+
'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
|
| 306 |
+
'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
url_map_advprop = {
|
| 311 |
+
'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
|
| 312 |
+
'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
|
| 313 |
+
'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
|
| 314 |
+
'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
|
| 315 |
+
'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
|
| 316 |
+
'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
|
| 317 |
+
'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
|
| 318 |
+
'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
|
| 319 |
+
'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
|
| 324 |
+
""" Loads pretrained weights, and downloads if loading for the first time. """
|
| 325 |
+
# AutoAugment or Advprop (different preprocessing)
|
| 326 |
+
url_map_ = url_map_advprop if advprop else url_map
|
| 327 |
+
state_dict = model_zoo.load_url(url_map_[model_name])
|
| 328 |
+
if load_fc:
|
| 329 |
+
model.load_state_dict(state_dict)
|
| 330 |
+
else:
|
| 331 |
+
state_dict.pop('_fc.weight')
|
| 332 |
+
state_dict.pop('_fc.bias')
|
| 333 |
+
res = model.load_state_dict(state_dict, strict=False)
|
| 334 |
+
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
|
| 335 |
+
print('Loaded pretrained weights for {}'.format(model_name))
|
deepgaze_pytorch/features/inception.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .normalizer import Normalizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RGBInceptionV3(nn.Sequential):
|
| 14 |
+
def __init__(self):
|
| 15 |
+
super(RGBInceptionV3, self).__init__()
|
| 16 |
+
self.resnext = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
|
| 17 |
+
self.normalizer = Normalizer()
|
| 18 |
+
super(RGBInceptionV3, self).__init__(self.normalizer, self.resnext)
|
| 19 |
+
|
| 20 |
+
|
deepgaze_pytorch/features/mobilenet.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .normalizer import Normalizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RGBMobileNetV2(nn.Sequential):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super(RGBMobileNetV2, self).__init__()
|
| 15 |
+
self.mobilenet_v2 = torchvision.models.mobilenet_v2(pretrained=True)
|
| 16 |
+
self.normalizer = Normalizer()
|
| 17 |
+
super(RGBMobileNetV2, self).__init__(self.normalizer, self.mobilenet_v2)
|
deepgaze_pytorch/features/normalizer.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
class Normalizer(nn.Module):
|
| 9 |
+
def __init__(self):
|
| 10 |
+
super(Normalizer, self).__init__()
|
| 11 |
+
mean = np.array([0.485, 0.456, 0.406])
|
| 12 |
+
mean = mean[:, np.newaxis, np.newaxis]
|
| 13 |
+
|
| 14 |
+
std = np.array([0.229, 0.224, 0.225])
|
| 15 |
+
std = std[:, np.newaxis, np.newaxis]
|
| 16 |
+
|
| 17 |
+
# don't persist to keep old checkpoints working
|
| 18 |
+
self.register_buffer('mean', torch.tensor(mean), persistent=False)
|
| 19 |
+
self.register_buffer('std', torch.tensor(std), persistent=False)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def forward(self, tensor):
|
| 23 |
+
tensor = tensor / 255.0
|
| 24 |
+
|
| 25 |
+
tensor -= self.mean
|
| 26 |
+
tensor /= self.std
|
| 27 |
+
|
| 28 |
+
return tensor
|
deepgaze_pytorch/features/resnet.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .normalizer import Normalizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RGBResNet34(nn.Sequential):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super(RGBResNet34, self).__init__()
|
| 15 |
+
self.resnet = torchvision.models.resnet34(pretrained=True)
|
| 16 |
+
self.normalizer = Normalizer()
|
| 17 |
+
super(RGBResNet34, self).__init__(self.normalizer, self.resnet)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RGBResNet50(nn.Sequential):
|
| 21 |
+
def __init__(self):
|
| 22 |
+
super(RGBResNet50, self).__init__()
|
| 23 |
+
self.resnet = torchvision.models.resnet50(pretrained=True)
|
| 24 |
+
self.normalizer = Normalizer()
|
| 25 |
+
super(RGBResNet50, self).__init__(self.normalizer, self.resnet)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class RGBResNet50_alt(nn.Sequential):
|
| 29 |
+
def __init__(self):
|
| 30 |
+
super(RGBResNet50, self).__init__()
|
| 31 |
+
self.resnet = torchvision.models.resnet50(pretrained=True)
|
| 32 |
+
self.normalizer = Normalizer()
|
| 33 |
+
state_dict = torch.load("Resnet-AlternativePreTrain.pth")
|
| 34 |
+
model.load_state_dict(state_dict)
|
| 35 |
+
super(RGBResNet50, self).__init__(self.normalizer, self.resnet)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class RGBResNet101(nn.Sequential):
|
| 40 |
+
def __init__(self):
|
| 41 |
+
super(RGBResNet101, self).__init__()
|
| 42 |
+
self.resnet = torchvision.models.resnet101(pretrained=True)
|
| 43 |
+
self.normalizer = Normalizer()
|
| 44 |
+
super(RGBResNet101, self).__init__(self.normalizer, self.resnet)
|
deepgaze_pytorch/features/resnext.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .normalizer import Normalizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RGBResNext50(nn.Sequential):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super(RGBResNext50, self).__init__()
|
| 15 |
+
self.resnext = torch.hub.load('pytorch/vision:v0.6.0', 'resnext50_32x4d', pretrained=True)
|
| 16 |
+
self.normalizer = Normalizer()
|
| 17 |
+
super(RGBResNext50, self).__init__(self.normalizer, self.resnext)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RGBResNext101(nn.Sequential):
|
| 21 |
+
def __init__(self):
|
| 22 |
+
super(RGBResNext101, self).__init__()
|
| 23 |
+
self.resnext = torch.hub.load('pytorch/vision:v0.6.0', 'resnext101_32x8d', pretrained=True)
|
| 24 |
+
self.normalizer = Normalizer()
|
| 25 |
+
super(RGBResNext101, self).__init__(self.normalizer, self.resnext)
|
| 26 |
+
|
| 27 |
+
|
deepgaze_pytorch/features/shapenet.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This code was adapted from: https://github.com/rgeirhos/texture-vs-shape
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torchvision
|
| 10 |
+
import torchvision.models
|
| 11 |
+
from torch.utils import model_zoo
|
| 12 |
+
|
| 13 |
+
from .normalizer import Normalizer
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_model(model_name):
|
| 17 |
+
|
| 18 |
+
model_urls = {
|
| 19 |
+
'resnet50_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/6f41d2e86fc60566f78de64ecff35cc61eb6436f/resnet50_train_60_epochs-c8e5653e.pth.tar',
|
| 20 |
+
'resnet50_trained_on_SIN_and_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_train_45_epochs_combined_IN_SF-2a0d100e.pth.tar',
|
| 21 |
+
'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
|
| 22 |
+
'vgg16_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar',
|
| 23 |
+
'alexnet_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/alexnet_train_60_epochs_lr0.001-b4aa5238.pth.tar',
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
if "resnet50" in model_name:
|
| 27 |
+
#print("Using the ResNet50 architecture.")
|
| 28 |
+
model = torchvision.models.resnet50(pretrained=False)
|
| 29 |
+
#model = torch.nn.DataParallel(model) # .cuda()
|
| 30 |
+
# fake DataParallel structrue
|
| 31 |
+
model = torch.nn.Sequential(OrderedDict([('module', model)]))
|
| 32 |
+
checkpoint = model_zoo.load_url(model_urls[model_name], map_location=torch.device('cpu'))
|
| 33 |
+
elif "vgg16" in model_name:
|
| 34 |
+
#print("Using the VGG-16 architecture.")
|
| 35 |
+
|
| 36 |
+
# download model from URL manually and save to desired location
|
| 37 |
+
filepath = "./vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar"
|
| 38 |
+
|
| 39 |
+
assert os.path.exists(filepath), "Please download the VGG model yourself from the following link and save it locally: https://drive.google.com/drive/folders/1A0vUWyU6fTuc-xWgwQQeBvzbwi6geYQK (too large to be downloaded automatically like the other models)"
|
| 40 |
+
|
| 41 |
+
model = torchvision.models.vgg16(pretrained=False)
|
| 42 |
+
model.features = torch.nn.DataParallel(model.features)
|
| 43 |
+
model.cuda()
|
| 44 |
+
checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
elif "alexnet" in model_name:
|
| 48 |
+
#print("Using the AlexNet architecture.")
|
| 49 |
+
model = torchvision.models.alexnet(pretrained=False)
|
| 50 |
+
model.features = torch.nn.DataParallel(model.features)
|
| 51 |
+
model.cuda()
|
| 52 |
+
checkpoint = model_zoo.load_url(model_urls[model_name], map_location=torch.device('cpu'))
|
| 53 |
+
else:
|
| 54 |
+
raise ValueError("unknown model architecture.")
|
| 55 |
+
|
| 56 |
+
model.load_state_dict(checkpoint["state_dict"])
|
| 57 |
+
return model
|
| 58 |
+
|
| 59 |
+
# --- DeepGaze Adaptation ----
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class RGBShapeNetA(nn.Sequential):
|
| 65 |
+
def __init__(self):
|
| 66 |
+
super(RGBShapeNetA, self).__init__()
|
| 67 |
+
self.shapenet = load_model("resnet50_trained_on_SIN")
|
| 68 |
+
self.normalizer = Normalizer()
|
| 69 |
+
super(RGBShapeNetA, self).__init__(self.normalizer, self.shapenet)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class RGBShapeNetB(nn.Sequential):
|
| 74 |
+
def __init__(self):
|
| 75 |
+
super(RGBShapeNetB, self).__init__()
|
| 76 |
+
self.shapenet = load_model("resnet50_trained_on_SIN_and_IN")
|
| 77 |
+
self.normalizer = Normalizer()
|
| 78 |
+
super(RGBShapeNetB, self).__init__(self.normalizer, self.shapenet)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class RGBShapeNetC(nn.Sequential):
|
| 82 |
+
def __init__(self):
|
| 83 |
+
super(RGBShapeNetC, self).__init__()
|
| 84 |
+
self.shapenet = load_model("resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN")
|
| 85 |
+
self.normalizer = Normalizer()
|
| 86 |
+
super(RGBShapeNetC, self).__init__(self.normalizer, self.shapenet)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
deepgaze_pytorch/features/squeezenet.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .normalizer import Normalizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RGBSqueezeNet(nn.Sequential):
|
| 12 |
+
def __init__(self):
|
| 13 |
+
super(RGBSqueezeNet, self).__init__()
|
| 14 |
+
self.squeezenet = torch.hub.load('pytorch/vision:v0.6.0', 'squeezenet1_0', pretrained=True)
|
| 15 |
+
self.normalizer = Normalizer()
|
| 16 |
+
super(RGBSqueezeNet, self).__init__(self.normalizer, self.squeezenet)
|
| 17 |
+
|
deepgaze_pytorch/features/swav.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .normalizer import Normalizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RGBSwav(nn.Sequential):
|
| 14 |
+
def __init__(self):
|
| 15 |
+
super(RGBSwav, self).__init__()
|
| 16 |
+
self.swav = torch.hub.load('facebookresearch/swav', 'resnet50', pretrained=True)
|
| 17 |
+
self.normalizer = Normalizer()
|
| 18 |
+
super(RGBSwav, self).__init__(self.normalizer, self.swav)
|
| 19 |
+
|
| 20 |
+
|
deepgaze_pytorch/features/uninformative.py
ADDED
@@ -0,0 +1,26 @@
from collections import OrderedDict

import torch
import torch.nn as nn


class OnesLayer(nn.Module):
    def __init__(self, size=None):
        super().__init__()
        self.size = size

    def forward(self, tensor):
        shape = list(tensor.shape)
        shape[1] = 1  # return only one channel

        if self.size is not None:
            shape[2], shape[3] = self.size

        return torch.ones(shape, dtype=torch.float32, device=tensor.device)


class UninformativeFeatures(torch.nn.Sequential):
    def __init__(self):
        super().__init__(OrderedDict([
            ('ones', OnesLayer(size=(1, 1))),
        ]))

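OnesLayer discards the input values and emits a constant single-channel map, so UninformativeFeatures acts as an image-independent baseline backbone. A small sketch of what it returns (the batch size and image size here are arbitrary):

import torch

from deepgaze_pytorch.features.uninformative import UninformativeFeatures

features = UninformativeFeatures()
images = torch.rand(4, 3, 128, 128)

out = features(images)
print(out.shape)     # torch.Size([4, 1, 1, 1]): one constant channel of spatial size (1, 1)
print(out.unique())  # tensor([1.])
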
deepgaze_pytorch/features/vgg.py
ADDED
@@ -0,0 +1,86 @@
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torchvision


class VGGInputNormalization(torch.nn.Module):
    def __init__(self, inplace=True):
        super().__init__()

        self.inplace = inplace

        mean = np.array([0.485, 0.456, 0.406])
        mean = mean[:, np.newaxis, np.newaxis]

        std = np.array([0.229, 0.224, 0.225])
        std = std[:, np.newaxis, np.newaxis]
        self.register_buffer('mean', torch.tensor(mean))
        self.register_buffer('std', torch.tensor(std))

    def forward(self, tensor):
        if self.inplace:
            tensor /= 255.0
        else:
            tensor = tensor / 255.0

        tensor -= self.mean
        tensor /= self.std

        return tensor


class VGG19BNNamedFeatures(torch.nn.Sequential):
    def __init__(self):
        names = []
        for block in range(5):
            block_size = 2 if block < 2 else 4
            for layer in range(block_size):
                names.append(f'conv{block+1}_{layer+1}')
                names.append(f'bn{block+1}_{layer+1}')
                names.append(f'relu{block+1}_{layer+1}')
            names.append(f'pool{block+1}')

        vgg = torchvision.models.vgg19_bn(pretrained=True)
        vgg_features = vgg.features
        vgg.classifier = torch.nn.Sequential()

        assert len(names) == len(vgg_features)

        named_features = OrderedDict({'normalize': VGGInputNormalization()})

        for name, feature in zip(names, vgg_features):
            if isinstance(feature, nn.MaxPool2d):
                feature.ceil_mode = True
            named_features[name] = feature

        super().__init__(named_features)


class VGG19NamedFeatures(torch.nn.Sequential):
    def __init__(self):
        names = []
        for block in range(5):
            block_size = 2 if block < 2 else 4
            for layer in range(block_size):
                names.append(f'conv{block+1}_{layer+1}')
                names.append(f'relu{block+1}_{layer+1}')
            names.append(f'pool{block+1}')

        vgg = torchvision.models.vgg19(pretrained=True)
        vgg_features = vgg.features
        vgg.classifier = torch.nn.Sequential()

        assert len(names) == len(vgg_features)

        named_features = OrderedDict({'normalize': VGGInputNormalization()})

        for name, feature in zip(names, vgg_features):
            if isinstance(feature, nn.MaxPool2d):
                feature.ceil_mode = True

            named_features[name] = feature

        super().__init__(named_features)

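Because every VGG block is registered under a readable name, an intermediate activation can be pulled out by walking the named children and stopping at the layer of interest. A hedged sketch (the layer name, input size, and the `.float()` cast, which brings the float64 normalization buffers in line with a float32 input, are illustrative choices, not something this diff prescribes):

import torch

from deepgaze_pytorch.features.vgg import VGG19NamedFeatures

features = VGG19NamedFeatures().float().eval()

x = 255 * torch.rand(1, 3, 224, 224)  # 'normalize' rescales from [0, 255] and standardizes
with torch.no_grad():
    for name, module in features.named_children():
        x = module(x)
        if name == 'relu4_2':
            break

print(x.shape)  # activation of conv block 4, e.g. torch.Size([1, 512, 28, 28])
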
deepgaze_pytorch/features/vggnet.py
ADDED
@@ -0,0 +1,24 @@
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torchvision

from .normalizer import Normalizer


class RGBvgg19(nn.Sequential):
    def __init__(self):
        super(RGBvgg19, self).__init__()
        self.model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19', pretrained=True)
        self.normalizer = Normalizer()
        super(RGBvgg19, self).__init__(self.normalizer, self.model)


class RGBvgg11(nn.Sequential):
    def __init__(self):
        super(RGBvgg11, self).__init__()
        self.model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg11', pretrained=True)
        self.normalizer = Normalizer()
        super(RGBvgg11, self).__init__(self.normalizer, self.model)

deepgaze_pytorch/features/wsl.py
ADDED
@@ -0,0 +1,27 @@
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torchvision

from .normalizer import Normalizer


class RGBResNext50(nn.Sequential):
    def __init__(self):
        super(RGBResNext50, self).__init__()
        self.resnext = torch.hub.load('facebookresearch/WSL-Images', 'resnext50_32x16d_wsl')
        self.normalizer = Normalizer()
        super(RGBResNext50, self).__init__(self.normalizer, self.resnext)


class RGBResNext101(nn.Sequential):
    def __init__(self):
        super(RGBResNext101, self).__init__()
        self.resnext = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
        self.normalizer = Normalizer()
        super(RGBResNext101, self).__init__(self.normalizer, self.resnext)

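squeezenet.py, swav.py, vggnet.py and wsl.py all follow the same recipe: fetch a pretrained backbone through torch.hub and put the shared Normalizer in front of it. A hedged sketch of that recipe with a hypothetical backbone choice (the class name and the resnet18 entry point are illustrative, and network access for torch.hub is assumed):

import torch
import torch.nn as nn

from deepgaze_pytorch.features.normalizer import Normalizer


class RGBHubBackbone(nn.Sequential):
    """Hypothetical wrapper mirroring RGBSqueezeNet / RGBSwav / RGBvgg19 / RGBResNext101."""

    def __init__(self, repo='pytorch/vision:v0.6.0', name='resnet18'):
        backbone = torch.hub.load(repo, name, pretrained=True)
        # one __init__ call with the child modules registers them in order
        # ('0': Normalizer, '1': backbone)
        super().__init__(Normalizer(), backbone)


model = RGBHubBackbone().eval()
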
deepgaze_pytorch/layers.py
ADDED
@@ -0,0 +1,427 @@
# pylint: disable=missing-module-docstring,invalid-name
# pylint: disable=missing-docstring
# pylint: disable=line-too-long

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated separately over the last
    certain number dimensions which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which applies
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = nn.LayerNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = nn.LayerNorm(10)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """
    __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale']

    def __init__(self, features, eps=1e-12, center=True, scale=True):
        super(LayerNorm, self).__init__()
        self.features = features
        self.eps = eps
        self.center = center
        self.scale = scale

        if self.scale:
            self.weight = nn.Parameter(torch.Tensor(self.features))
        else:
            self.register_parameter('weight', None)

        if self.center:
            self.bias = nn.Parameter(torch.Tensor(self.features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.scale:
            nn.init.ones_(self.weight)

        if self.center:
            nn.init.zeros_(self.bias)

    def adjust_parameter(self, tensor, parameter):
        return torch.repeat_interleave(
            torch.repeat_interleave(
                parameter.view(-1, 1, 1),
                repeats=tensor.shape[2],
                dim=1),
            repeats=tensor.shape[3],
            dim=2
        )

    def forward(self, input):
        normalized_shape = (self.features, input.shape[2], input.shape[3])
        weight = self.adjust_parameter(input, self.weight)
        bias = self.adjust_parameter(input, self.bias)
        return F.layer_norm(
            input, normalized_shape, weight, bias, self.eps)

    def extra_repr(self):
        return '{features}, eps={eps}, ' \
            'center={center}, scale={scale}'.format(**self.__dict__)


def gaussian_filter_1d(tensor, dim, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0):
    sigma = torch.as_tensor(sigma, device=tensor.device, dtype=tensor.dtype)

    if kernel_size is not None:
        kernel_size = torch.as_tensor(kernel_size, device=tensor.device, dtype=torch.int64)
    else:
        kernel_size = torch.as_tensor(2 * torch.ceil(truncate * sigma) + 1, device=tensor.device, dtype=torch.int64)

    kernel_size = kernel_size.detach()

    kernel_size_int = kernel_size.detach().cpu().numpy()

    mean = (torch.as_tensor(kernel_size, dtype=tensor.dtype) - 1) / 2

    grid = torch.arange(kernel_size, device=tensor.device) - mean

    kernel_shape = (1, 1, kernel_size)
    grid = grid.view(kernel_shape)

    grid = grid.detach()

    source_shape = tensor.shape

    tensor = torch.movedim(tensor, dim, len(source_shape)-1)
    dim_last_shape = tensor.shape
    assert tensor.shape[-1] == source_shape[dim]

    # we need reshape instead of view for batches like B x C x H x W
    tensor = tensor.reshape(-1, 1, source_shape[dim])

    padding = (math.ceil((kernel_size_int - 1) / 2), math.ceil((kernel_size_int - 1) / 2))
    tensor_ = F.pad(tensor, padding, padding_mode, padding_value)

    # create gaussian kernel from grid using current sigma
    kernel = torch.exp(-0.5 * (grid / sigma) ** 2)
    kernel = kernel / kernel.sum()

    # convolve input with gaussian kernel
    tensor_ = F.conv1d(tensor_, kernel)
    tensor_ = tensor_.view(dim_last_shape)
    tensor_ = torch.movedim(tensor_, len(source_shape)-1, dim)

    assert tensor_.shape == source_shape

    return tensor_


class GaussianFilterNd(nn.Module):
    """A differentiable gaussian filter"""

    def __init__(self, dims, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0,
                 trainable=False):
        """Creates a 1d gaussian filter

        Args:
            dims ([int]): the dimensions to which the gaussian filter is applied. Negative values won't work
            sigma (float): standard deviation of the gaussian filter (blur size)
            input_dims (int, optional): number of input dimensions ignoring batch and channel dimension,
                i.e. use input_dims=2 for images (default: 2).
            truncate (float, optional): truncate the filter at this many standard deviations (default: 4.0).
                This has no effect if the `kernel_size` is explicitly set
            kernel_size (int): size of the gaussian kernel convolved with the input
            padding_mode (string, optional): Padding mode implemented by `torch.nn.functional.pad`.
            padding_value (float, optional): Value used for constant padding.
        """
        # IDEA determine input_dims dynamically for every input
        super(GaussianFilterNd, self).__init__()

        self.dims = dims
        self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float32), requires_grad=trainable)  # default: no optimization
        self.truncate = truncate
        self.kernel_size = kernel_size

        # setup padding
        self.padding_mode = padding_mode
        self.padding_value = padding_value

    def forward(self, tensor):
        """Applies the gaussian filter to the given tensor"""
        for dim in self.dims:
            tensor = gaussian_filter_1d(
                tensor,
                dim=dim,
                sigma=self.sigma,
                truncate=self.truncate,
                kernel_size=self.kernel_size,
                padding_mode=self.padding_mode,
                padding_value=self.padding_value,
            )

        return tensor


class Conv2dMultiInput(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        for k, _in_channels in enumerate(in_channels):
            if _in_channels:
                setattr(self, f'conv_part{k}', nn.Conv2d(_in_channels, out_channels, kernel_size, bias=bias))

    def forward(self, tensors):
        assert len(tensors) == len(self.in_channels)

        out = None
        for k, (count, tensor) in enumerate(zip(self.in_channels, tensors)):
            if not count:
                continue
            _out = getattr(self, f'conv_part{k}')(tensor)

            if out is None:
                out = _out
            else:
                out += _out

        return out

    # def extra_repr(self):
    #     return f'{self.in_channels}'


class LayerNormMultiInput(nn.Module):
    __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale']

    def __init__(self, features, eps=1e-12, center=True, scale=True):
        super().__init__()
        self.features = features
        self.eps = eps
        self.center = center
        self.scale = scale

        for k, _features in enumerate(features):
            if _features:
                setattr(self, f'layernorm_part{k}', LayerNorm(_features, eps=eps, center=center, scale=scale))

    def forward(self, tensors):
        assert len(tensors) == len(self.features)

        out = []
        for k, (count, tensor) in enumerate(zip(self.features, tensors)):
            if not count:
                assert tensor is None
                out.append(None)
                continue
            out.append(getattr(self, f'layernorm_part{k}')(tensor))

        return out


class Bias(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, tensor):
        return tensor + self.bias[np.newaxis, :, np.newaxis, np.newaxis]

    def extra_repr(self):
        return f'channels={self.channels}'


class SelfAttention(nn.Module):
    """ Self attention Layer

    adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3
    """

    def __init__(self, in_channels, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False, return_attention=True):
        super().__init__()
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels
        if key_channels is None:
            key_channels = in_channels // 8
        self.key_channels = key_channels
        self.activation = activation
        self.skip_connection_with_convolution = skip_connection_with_convolution
        if not self.skip_connection_with_convolution:
            if self.out_channels != self.in_channels:
                raise ValueError("out_channels has to be equal to in_channels with true skip connection!")
        self.return_attention = return_attention

        self.query_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        if self.skip_connection_with_convolution:
            self.skip_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps( B X C X W X H)
        returns :
            out : self attention value + input feature
            attention: B X N X N (N is Width*Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B X CX(N)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B X C x (*W*H)
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        attention = self.softmax(energy)  # BX (N) X (N)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, self.out_channels, width, height)

        if self.skip_connection_with_convolution:
            skip_connection = self.skip_conv(x)
        else:
            skip_connection = x
        out = self.gamma * out + skip_connection

        if self.activation is not None:
            out = self.activation(out)

        if self.return_attention:
            return out, attention

        return out


class MultiHeadSelfAttention(nn.Module):
    """ Self attention Layer

    adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3
    """

    def __init__(self, in_channels, heads, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False):
        super().__init__()
        self.heads = heads
        self.heads = nn.ModuleList([SelfAttention(
            in_channels=in_channels,
            out_channels=out_channels,
            key_channels=key_channels,
            activation=activation,
            skip_connection_with_convolution=skip_connection_with_convolution,
            return_attention=False,
        ) for _ in range(heads)])

    def forward(self, tensor):
        outs = [head(tensor) for head in self.heads]
        out = torch.cat(outs, dim=1)
        return out


class FlexibleScanpathHistoryEncoding(nn.Module):
    """
    a convolutional layer which works for different numbers of previous fixations.

    Nonexistent fixations will deactivate the respective convolutions
    the bias will be added per fixation (if the given fixation is present)
    """
    def __init__(self, in_fixations, channels_per_fixation, out_channels, kernel_size, bias=True,):
        super().__init__()
        self.in_fixations = in_fixations
        self.channels_per_fixation = channels_per_fixation
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.bias = bias
        self.convolutions = nn.ModuleList([
            nn.Conv2d(
                in_channels=self.channels_per_fixation,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                bias=self.bias
            ) for i in range(in_fixations)
        ])

    def forward(self, tensor):
        results = None
        valid_fixations = ~torch.isnan(
            tensor[:, :self.in_fixations, 0, 0]
        )
        # print("valid fix", valid_fixations)

        for fixation_index in range(self.in_fixations):
            valid_indices = valid_fixations[:, fixation_index]
            if not torch.any(valid_indices):
                continue
            this_input = tensor[
                valid_indices,
                fixation_index::self.in_fixations
            ]
            this_result = self.convolutions[fixation_index](
                this_input
            )
            # TODO: This will break if all data points
            # in the batch don't have a single fixation
            # but that's not a case I intend to train
            # anyway.
            if results is None:
                b, _, _, _ = tensor.shape
                _, _, h, w = this_result.shape
                results = torch.zeros(
                    (b, self.out_channels, h, w),
                    dtype=tensor.dtype,
                    device=tensor.device
                )
            results[valid_indices] += this_result

        return results

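A small sketch exercising some of the building blocks above in isolation; the tensor sizes are arbitrary and only meant to show the expected input/output shapes, not how the DeepGaze models wire these layers together.

import torch

from deepgaze_pytorch.layers import Bias, GaussianFilterNd, LayerNorm, SelfAttention

x = torch.rand(2, 8, 32, 32)

# blur the two spatial dimensions (dims 2 and 3) with a fixed, non-trainable sigma
blur = GaussianFilterNd(dims=[2, 3], sigma=2.0, trainable=False)
print(blur(x).shape)        # torch.Size([2, 8, 32, 32])

# spatial layer norm with a learnable per-channel scale and bias
norm = LayerNorm(features=8)
print(norm(x).shape)        # torch.Size([2, 8, 32, 32])

# learnable per-channel bias
bias = Bias(channels=8)
print(bias(x).shape)        # torch.Size([2, 8, 32, 32])

# self-attention returns the attended map plus the attention matrix by default
attention_layer = SelfAttention(in_channels=8)
out, attention = attention_layer(x)
print(out.shape, attention.shape)  # torch.Size([2, 8, 32, 32]) torch.Size([2, 1024, 1024])
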