baharay committed on
Commit b728ad5 · verified · 1 Parent(s): a3fc39c

Upload 30 files

PNAS/PNASnet.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ import torch.nn as nn
+ from operations import *
+ from torch.autograd import Variable
+ # from utils import drop_path
+
+
+ class Cell(nn.Module):
+
+     def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
+         super(Cell, self).__init__()
+         print(C_prev_prev, C_prev, C)
+         self.reduction = reduction
+
+         if reduction_prev is None:
+             self.preprocess0 = Identity()
+         elif reduction_prev is True:
+             self.preprocess0 = FactorizedReduce(C_prev_prev, C)
+         else:
+             self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
+         self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
+
+         if reduction:
+             op_names, indices = zip(*genotype.reduce)
+             concat = genotype.reduce_concat
+         else:
+             op_names, indices = zip(*genotype.normal)
+             concat = genotype.normal_concat
+
+         assert len(op_names) == len(indices)
+         self._steps = len(op_names) // 2
+         self._concat = concat
+         self.multiplier = len(concat)
+
+         self._ops = nn.ModuleList()
+         for name, index in zip(op_names, indices):
+             stride = 2 if reduction and index < 2 else 1
+             if reduction_prev is None and index == 0:
+                 op = OPS[name](C_prev_prev, C, stride, True)
+             else:
+                 op = OPS[name](C, C, stride, True)
+             self._ops += [op]
+         self._indices = indices
+
+     def forward(self, s0, s1, drop_prob):
+         s0 = self.preprocess0(s0)
+         s1 = self.preprocess1(s1)
+
+         states = [s0, s1]
+         for i in range(self._steps):
+             h1 = states[self._indices[2 * i]]
+             h2 = states[self._indices[2 * i + 1]]
+             op1 = self._ops[2 * i]
+             op2 = self._ops[2 * i + 1]
+             h1 = op1(h1)
+             h2 = op2(h2)
+             # if self.training and drop_prob > 0.:
+             #     if not isinstance(op1, Identity):
+             #         h1 = drop_path(h1, drop_prob)
+             #     if not isinstance(op2, Identity):
+             #         h2 = drop_path(h2, drop_prob)
+             s = h1 + h2
+             states += [s]
+         return torch.cat([states[i] for i in self._concat], dim=1)
+
+
+ class AuxiliaryHeadImageNet(nn.Module):
+
+     def __init__(self, C, num_classes):
+         """assuming input size 14x14"""
+         super(AuxiliaryHeadImageNet, self).__init__()
+         self.features = nn.Sequential(
+             nn.ReLU(inplace=True),
+             nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
+             nn.Conv2d(C, 128, 1, bias=False),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(128, 768, 2, bias=False),
+             nn.BatchNorm2d(768),
+             nn.ReLU(inplace=True)
+         )
+         self.classifier = nn.Linear(768, num_classes)
+
+     def forward(self, x):
+         x = self.features(x)
+         x = self.classifier(x.view(x.size(0), -1))
+         return x
+
+
+ class NetworkImageNet(nn.Module):
+
+     def __init__(self, C, num_classes, layers, auxiliary, genotype):
+         super(NetworkImageNet, self).__init__()
+         self._layers = layers
+         self._auxiliary = auxiliary
+         # forward() reads self.drop_path_prob; initialize it so the module is
+         # usable stand-alone (the training loop may overwrite it per epoch).
+         self.drop_path_prob = 0.
+
+         self.conv0 = nn.Conv2d(3, 96, kernel_size=3, stride=2, padding=0, bias=False)
+         self.conv0_bn = nn.BatchNorm2d(96, eps=1e-3)
+         self.stem1 = Cell(genotype, 96, 96, C // 4, True, None)
+         self.stem2 = Cell(genotype, 96, C * self.stem1.multiplier // 4, C // 2, True, True)
+
+         C_prev_prev, C_prev, C_curr = C * self.stem1.multiplier // 4, C * self.stem2.multiplier // 2, C
+
+         self.cells = nn.ModuleList()
+         reduction_prev = True
+         for i in range(layers):
+             if i in [layers // 3, 2 * layers // 3]:
+                 C_curr *= 2
+                 reduction = True
+             else:
+                 reduction = False
+             cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+             reduction_prev = reduction
+             self.cells += [cell]
+             C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
+             if i == 2 * layers // 3:
+                 C_to_auxiliary = C_prev
+
+         if auxiliary:
+             self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
+         self.relu = nn.ReLU(inplace=False)
+         self.global_pooling = nn.AdaptiveAvgPool2d(1)
+         self.classifier = nn.Linear(C_prev, num_classes)
+
+     def forward(self, input):
+         logits_aux = None
+         s0 = self.conv0(input)
+         s0 = self.conv0_bn(s0)
+         s1 = self.stem1(s0, s0, self.drop_path_prob)
+         s0, s1 = s1, self.stem2(s0, s1, self.drop_path_prob)
+         for i, cell in enumerate(self.cells):
+             s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
+             if i == 2 * self._layers // 3:
+                 if self._auxiliary and self.training:
+                     logits_aux = self.auxiliary_head(s1)
+         s1 = self.relu(s1)
+         out = self.global_pooling(s1)
+         logits = self.classifier(out.view(out.size(0), -1))
+         return logits, logits_aux
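For context, model.py later in this commit builds this backbone with the searched PNASNet genotype. A minimal construction sketch (assuming the repository root is the working directory; no pretrained weights are loaded here):

    import sys
    sys.path.append('./PNAS/')
    from PNASnet import NetworkImageNet
    from genotypes import PNASNet

    # Same arguments as model.py: 216 initial channels, 1001 classes,
    # 12 cells, no auxiliary head, the searched PNASNet genotype.
    net = NetworkImageNet(216, 1001, 12, False, PNASNet)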
PNAS/__init__.py ADDED
@@ -0,0 +1 @@
+
PNAS/genotypes.py ADDED
@@ -0,0 +1,33 @@
+ from collections import namedtuple
+
+ Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
+
+ PNASNet = Genotype(
+     normal = [
+         ('sep_conv_5x5', 0),
+         ('max_pool_3x3', 0),
+         ('sep_conv_7x7', 1),
+         ('max_pool_3x3', 1),
+         ('sep_conv_5x5', 1),
+         ('sep_conv_3x3', 1),
+         ('sep_conv_3x3', 4),
+         ('max_pool_3x3', 1),
+         ('sep_conv_3x3', 0),
+         ('skip_connect', 1),
+     ],
+     normal_concat = [2, 3, 4, 5, 6],
+     reduce = [
+         ('sep_conv_5x5', 0),
+         ('max_pool_3x3', 0),
+         ('sep_conv_7x7', 1),
+         ('max_pool_3x3', 1),
+         ('sep_conv_5x5', 1),
+         ('sep_conv_3x3', 1),
+         ('sep_conv_3x3', 4),
+         ('max_pool_3x3', 1),
+         ('sep_conv_3x3', 0),
+         ('skip_connect', 1),
+     ],
+     reduce_concat = [2, 3, 4, 5, 6],
+ )
+
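Each consecutive pair of (op_name, input_index) entries defines one of a cell's five intermediate nodes; indices 0 and 1 are the two cell inputs, and 2 onwards refer to earlier intermediate nodes. This is exactly how Cell.__init__ in PNASnet.py consumes the genotype, as the small illustration below shows (assuming PNAS/ is on the import path, as model.py arranges via sys.path):

    from genotypes import PNASNet

    op_names, indices = zip(*PNASNet.normal)
    steps = len(op_names) // 2           # 5 intermediate nodes per cell
    print(steps, PNASNet.normal_concat)  # 5 [2, 3, 4, 5, 6]: these node outputs are concatenated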
PNAS/operations.py ADDED
@@ -0,0 +1,115 @@
+ import torch
+ import torch.nn as nn
+
+ OPS = {
+     'none' : lambda C_in, C_out, stride, affine: Zero(stride),
+     'avg_pool_3x3' : lambda C_in, C_out, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if C_in == C_out else nn.Sequential(
+         nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
+         nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
+         nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
+     ),
+     'max_pool_3x3' : lambda C_in, C_out, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1) if C_in == C_out else nn.Sequential(
+         nn.MaxPool2d(3, stride=stride, padding=1),
+         nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
+         nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
+     ),
+     'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 else ReLUConvBN(C_in, C_out, 1, stride, 0, affine=affine),
+     'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine),
+     'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine),
+     'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine),
+     'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine),
+     'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine),
+     'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: nn.Sequential(
+         nn.ReLU(inplace=False),
+         nn.Conv2d(C_in, C_in, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
+         nn.Conv2d(C_in, C_out, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
+         nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
+     ),
+ }
+
+ class ReLUConvBN(nn.Module):
+
+     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
+         super(ReLUConvBN, self).__init__()
+         self.op = nn.Sequential(
+             nn.ReLU(inplace=False),
+             nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
+             nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
+         )
+
+     def forward(self, x):
+         return self.op(x)
+
+ class DilConv(nn.Module):
+
+     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
+         super(DilConv, self).__init__()
+         self.op = nn.Sequential(
+             nn.ReLU(inplace=False),
+             nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
+             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
+             nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
+         )
+
+     def forward(self, x):
+         return self.op(x)
+
+
+ class SepConv(nn.Module):
+
+     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
+         super(SepConv, self).__init__()
+         self.op = nn.Sequential(
+             nn.ReLU(inplace=False),
+             nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
+             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
+             nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
+             nn.ReLU(inplace=False),
+             nn.Conv2d(C_out, C_out, kernel_size=kernel_size, stride=1, padding=padding, groups=C_out, bias=False),
+             nn.Conv2d(C_out, C_out, kernel_size=1, padding=0, bias=False),
+             nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
+         )
+
+     def forward(self, x):
+         return self.op(x)
+
+
+ class Identity(nn.Module):
+
+     def __init__(self):
+         super(Identity, self).__init__()
+
+     def forward(self, x):
+         return x
+
+
+ class Zero(nn.Module):
+
+     def __init__(self, stride):
+         super(Zero, self).__init__()
+         self.stride = stride
+
+     def forward(self, x):
+         if self.stride == 1:
+             return x.mul(0.)
+         return x[:, :, ::self.stride, ::self.stride].mul(0.)
+
+
+ class FactorizedReduce(nn.Module):
+
+     def __init__(self, C_in, C_out, affine=True):
+         super(FactorizedReduce, self).__init__()
+         assert C_out % 2 == 0
+         self.relu = nn.ReLU(inplace=False)
+         self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+         self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+         self.bn = nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
+         self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
+
+     def forward(self, x):
+         x = self.relu(x)
+         y = self.pad(x)
+         out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1)
+         out = self.bn(out)
+         return out
+
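Every value in OPS is a factory with the signature (C_in, C_out, stride, affine) returning an nn.Module, which is what Cell.__init__ relies on. A quick sanity check, assuming operations.py is on the import path:

    import torch
    from operations import OPS

    op = OPS['sep_conv_3x3'](16, 16, 1, True)  # separable 3x3 conv, stride 1
    x = torch.randn(2, 16, 32, 32)
    print(op(x).shape)                         # torch.Size([2, 16, 32, 32])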
README.md CHANGED
@@ -1,12 +1 @@
- ---
- title: Tempsal
- emoji: ⚡
- colorFrom: gray
- colorTo: red
- sdk: gradio
- sdk_version: 4.44.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Download the model checkpoint from https://drive.google.com/drive/folders/1W92oXYra_OPYkR1W56D80iDexWIR7f7Z?usp=sharing and follow the instructions in inference.ipynb. The notebook produces temporal and image saliency predictions together.
checkpoints/Readme.txt ADDED
@@ -0,0 +1,2 @@
+ Please download the checkpoint from the following link:
+ https://drive.google.com/drive/folders/1W92oXYra_OPYkR1W56D80iDexWIR7f7Z?usp=sharing
dataloader_clean.py ADDED
@@ -0,0 +1,52 @@
+ from torchvision import transforms
+ import torchvision.transforms.functional as TF
+ from PIL import Image
+ # Subclass Dataset, not DataLoader: instances of this class are indexed by a DataLoader in train.py.
+ from torch.utils.data import Dataset
+ import numpy as np
+ import torch
+ import os, cv2
+ from utils import *
+ import json
+ import random
+ from pycocotools.coco import COCO
+
+
+ class SaliconDataset(Dataset):
+     def __init__(self, img_dir, gt_dir, fix_dir, img_ids, exten='.png'):
+         self.img_dir = img_dir
+         self.gt_dir = gt_dir
+         self.fix_dir = fix_dir
+         self.img_ids = img_ids
+         self.exten = exten
+         self.img_transform = transforms.Compose([
+             transforms.Resize((256, 256)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5, 0.5, 0.5],
+                                  [0.5, 0.5, 0.5])
+         ])
+
+     def __getitem__(self, idx):
+         img_id = self.img_ids[idx]
+         img_path = os.path.join(self.img_dir, img_id + '.jpg')
+         gt_path = os.path.join(self.gt_dir, img_id + self.exten)
+         fix_path = os.path.join(self.fix_dir, img_id + self.exten)
+
+         img = Image.open(img_path).convert('RGB')
+         img = self.img_transform(img)
+
+         gt = np.array(Image.open(gt_path).convert('L'))
+         gt = gt.astype('float')
+         gt = cv2.resize(gt, (256, 256))
+         if np.max(gt) > 1.0:
+             gt = gt / 255.0
+
+         fixations = np.array(Image.open(fix_path).convert('L'))
+         fixations = fixations.astype('float')
+         fixations = (fixations > 0.5).astype('float')
+
+         assert np.min(gt) >= 0.0 and np.max(gt) <= 1.0
+         assert np.min(fixations) == 0.0 and np.max(fixations) == 1.0
+         return img, torch.FloatTensor(gt), torch.FloatTensor(fixations)
+
+     def __len__(self):
+         return len(self.img_ids)
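train.py wires this dataset to a DataLoader; a minimal sketch, assuming the ../data/ layout that train.py expects (images/, maps/, fixation_maps/, each with train/ and val/ splits):

    import os
    from torch.utils.data import DataLoader
    from dataloader_clean import SaliconDataset

    img_dir, gt_dir, fix_dir = '../data/images/val/', '../data/maps/val/', '../data/fixation_maps/val/'
    ids = [nm.split('.')[0] for nm in os.listdir(img_dir)]
    loader = DataLoader(SaliconDataset(img_dir, gt_dir, fix_dir, ids), batch_size=32, shuffle=False)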
generate_volumes.py ADDED
@@ -0,0 +1,57 @@
+ from utils import *
+ from operator import itemgetter
+ from itertools import groupby
+ import cv2
+ import argparse
+
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--time_slices', default=5, type=int)
+
+ def generate_fixation_files(path, time_slices):
+     print('Parsing fixations of ' + path + '...')
+     filenames = [nm.split(".")[0] for nm in os.listdir(FIXATION_PATH + path)]
+
+     def create_dirs(dir_path):
+         if not os.path.exists(dir_path):
+             os.makedirs(dir_path)
+
+         dir_path = dir_path + '/' + path
+         if not os.path.exists(dir_path):
+             os.makedirs(dir_path)
+
+         return dir_path
+
+     sal_vol_path = create_dirs(SAL_VOL_PATH + str(time_slices))
+     fix_vol_path = create_dirs(FIX_VOL_PATH + str(time_slices))
+
+     conv2D = GaussianBlur2D().cuda()
+
+     print('Generating saliency volumes of ' + path + '...')
+     for filename in tqdm(filenames):
+         fixation_volume = parse_fixations([filename], FIXATION_PATH + path, progress_bar=False)[0]
+         fix_timestamps = sorted([fixation for fix_timestamps in fixation_volume
+                                  for fixation in fix_timestamps], key=lambda x: x[0])
+         fix_timestamps = np.array([(min(int(ts * time_slices / TIMESPAN), time_slices - 1), (x, y)) for (ts, (x, y)) in fix_timestamps])
+
+         # Saving fixation map
+         fix_vol = np.zeros(shape=(time_slices, H, W))
+         for i, coords in fix_timestamps:
+             fix_vol[i, coords[1] - 1, coords[0] - 1] = 1
+
+         # Saving fixation list with timestamps
+         compressed = np.array([(key, list(v[1] for v in valuesiter))
+                                for key, valuesiter in groupby(fix_timestamps, key=itemgetter(0))])
+
+         saliency_volume = get_saliency_volume(compressed, conv2D, time_slices)
+         saliency_volume = saliency_volume.squeeze(0).squeeze(0).detach().cpu().numpy()
+
+         for i, saliency_slice in enumerate(saliency_volume):
+             cv2.imwrite(sal_vol_path + filename + '_' + str(i) + '.png', 255 * saliency_slice)
+             cv2.imwrite(fix_vol_path + filename + '_' + str(i) + '.png', 255 * fix_vol[i])
+
+ args = parser.parse_args()
+ time_slices = args.time_slices
+ generate_fixation_files('train/', time_slices)
+ generate_fixation_files('val/', time_slices)
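The timestamp-to-slice mapping is the core of this script: a fixation at ts milliseconds (with TIMESPAN = 5000 from utils.py) lands in slice min(int(ts * time_slices / TIMESPAN), time_slices - 1). A worked example with five slices:

    TIMESPAN, time_slices = 5000, 5
    for ts in [0, 999, 1000, 4999, 5000]:
        print(ts, min(int(ts * time_slices / TIMESPAN), time_slices - 1))
    # 0 -> 0, 999 -> 0, 1000 -> 1, 4999 -> 4, 5000 -> 4 (clamped into the last slice)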
inference.ipynb ADDED
The diff for this file is too large to render.
 
loss.py ADDED
@@ -0,0 +1,256 @@
+ import torch
+ import numpy as np
+ import cv2
+
+ def kldiv(s_map, gt):
+     batch_size = s_map.size(0)
+     w = s_map.size(1)
+     h = s_map.size(2)
+
+     sum_s_map = torch.sum(s_map.view(batch_size, -1), 1)
+     expand_s_map = sum_s_map.view(batch_size, 1, 1).expand(batch_size, w, h)
+
+     assert expand_s_map.size() == s_map.size()
+
+     sum_gt = torch.sum(gt.view(batch_size, -1), 1)
+     expand_gt = sum_gt.view(batch_size, 1, 1).expand(batch_size, w, h)
+
+     assert expand_gt.size() == gt.size()
+
+     s_map = s_map / (expand_s_map * 1.0)
+     gt = gt / (expand_gt * 1.0)
+
+     s_map = s_map.view(batch_size, -1)
+     gt = gt.view(batch_size, -1)
+
+     eps = 2.2204e-16
+     result = gt * torch.log(eps + gt / (s_map + eps))
+     # print(torch.log(eps + gt/(s_map + eps)))
+     return torch.mean(torch.sum(result, 1))
+
+
+ def normalize_map(s_map):
+     # normalize the salience map (as done in MIT code)
+     batch_size = s_map.size(0)
+     w = s_map.size(1)
+     h = s_map.size(2)
+
+     min_s_map = torch.min(s_map.view(batch_size, -1), 1)[0].view(batch_size, 1, 1).expand(batch_size, w, h)
+     max_s_map = torch.max(s_map.view(batch_size, -1), 1)[0].view(batch_size, 1, 1).expand(batch_size, w, h)
+
+     norm_s_map = (s_map - min_s_map) / (max_s_map - min_s_map * 1.0)
+     return norm_s_map
+
+ def similarity(s_map, gt):
+     ''' For single image metric
+         Size of Image - WxH or 1xWxH
+         gt is ground truth saliency map
+     '''
+     batch_size = s_map.size(0)
+     w = s_map.size(1)
+     h = s_map.size(2)
+
+     s_map = normalize_map(s_map)
+     gt = normalize_map(gt)
+
+     sum_s_map = torch.sum(s_map.view(batch_size, -1), 1)
+     expand_s_map = sum_s_map.view(batch_size, 1, 1).expand(batch_size, w, h)
+
+     assert expand_s_map.size() == s_map.size()
+
+     sum_gt = torch.sum(gt.view(batch_size, -1), 1)
+     expand_gt = sum_gt.view(batch_size, 1, 1).expand(batch_size, w, h)
+
+     s_map = s_map / (expand_s_map * 1.0)
+     gt = gt / (expand_gt * 1.0)
+
+     s_map = s_map.view(batch_size, -1)
+     gt = gt.view(batch_size, -1)
+     return torch.mean(torch.sum(torch.min(s_map, gt), 1))
+
+ def cc(s_map, gt):
+     batch_size = s_map.size(0)
+     w = s_map.size(1)
+     h = s_map.size(2)
+
+     mean_s_map = torch.mean(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h)
+     std_s_map = torch.std(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h)
+
+     mean_gt = torch.mean(gt.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h)
+     std_gt = torch.std(gt.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h)
+
+     s_map = (s_map - mean_s_map) / std_s_map
+     gt = (gt - mean_gt) / std_gt
+
+     ab = torch.sum((s_map * gt).view(batch_size, -1), 1)
+     aa = torch.sum((s_map * s_map).view(batch_size, -1), 1)
+     bb = torch.sum((gt * gt).view(batch_size, -1), 1)
+
+     return torch.mean(ab / (torch.sqrt(aa * bb)))
+
+ def nss(s_map, gt):
+     if s_map.size() != gt.size():
+         s_map = s_map.cpu().detach().numpy()
+         s_map = torch.FloatTensor([cv2.resize(m, (gt.size(2), gt.size(1))) for m in s_map])
+         s_map = s_map.cuda()
+         gt = gt.cuda()
+
+     assert s_map.size() == gt.size()
+     batch_size = s_map.size(0)
+     w = s_map.size(1)
+     h = s_map.size(2)
+     mean_s_map = torch.mean(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h)
+     std_s_map = torch.std(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h)
+
+     eps = 2.2204e-16
+     s_map = (s_map - mean_s_map) / (std_s_map + eps)
+
+     s_map = torch.sum((s_map * gt).view(batch_size, -1), 1)
+     count = torch.sum(gt.view(batch_size, -1), 1)
+     return torch.mean(s_map / count)
+
+ def auc_judd(saliencyMap, fixationMap, jitter=True, normalize=False):
+     # saliencyMap is the saliency map
+     # fixationMap is the human fixation map (binary matrix)
+     # jitter=True adds a tiny non-zero random constant to all map locations so the
+     # ROC can be calculated robustly (avoids ties on uniform regions)
+
+     # If there are no fixations to predict, return NaN
+     if saliencyMap.size() != fixationMap.size():
+         saliencyMap = saliencyMap.cpu().squeeze(0).numpy()
+         saliencyMap = torch.FloatTensor(cv2.resize(saliencyMap, (fixationMap.size(2), fixationMap.size(1)))).unsqueeze(0)
+         # saliencyMap = saliencyMap.cuda()
+         # fixationMap = fixationMap.cuda()
+     if len(saliencyMap.size()) == 3:
+         saliencyMap = saliencyMap[0, :, :]
+         fixationMap = fixationMap[0, :, :]
+     saliencyMap = saliencyMap.numpy()
+     fixationMap = fixationMap.numpy()
+     if normalize:
+         # min-max normalize in numpy (normalize_map above expects a batched tensor)
+         saliencyMap = (saliencyMap - saliencyMap.min()) / (saliencyMap.max() - saliencyMap.min())
+
+     if not fixationMap.any():
+         print('Error: no fixationMap')
+         score = float('nan')
+         return score
+
+     # make the saliencyMap the size of the image of fixationMap
+     if not np.shape(saliencyMap) == np.shape(fixationMap):
+         # scipy.misc.imresize has been removed from SciPy; cv2.resize is used instead
+         saliencyMap = cv2.resize(saliencyMap, (np.shape(fixationMap)[1], np.shape(fixationMap)[0]))
+
+     # jitter saliency maps that come from saliency models that have a lot of zero values.
+     # If the saliency map is made with a Gaussian then it does not need to be jittered as
+     # the values are varied and there is not a large patch of the same value. In fact
+     # jittering breaks the ordering in the small values!
+     if jitter:
+         # jitter the saliency map slightly to disrupt ties of the same numbers
+         saliencyMap = saliencyMap + np.random.random(np.shape(saliencyMap)) / 10 ** 7
+
+     # normalize saliency map
+     saliencyMap = (saliencyMap - saliencyMap.min()) \
+                   / (saliencyMap.max() - saliencyMap.min())
+
+     if np.isnan(saliencyMap).all():
+         print('NaN saliencyMap')
+         score = float('nan')
+         return score
+
+     S = saliencyMap.flatten()
+     F = fixationMap.flatten()
+
+     Sth = S[F > 0]  # sal map values at fixation locations
+     Nfixations = len(Sth)
+     Npixels = len(S)
+
+     allthreshes = sorted(Sth, reverse=True)  # sort sal map values, to sweep through values
+     tp = np.zeros((Nfixations + 2))
+     fp = np.zeros((Nfixations + 2))
+     tp[0], tp[-1] = 0, 1
+     fp[0], fp[-1] = 0, 1
+
+     for i in range(Nfixations):
+         thresh = allthreshes[i]
+         aboveth = (S >= thresh).sum()  # total number of sal map values above threshold
+         tp[i + 1] = float(i + 1) / Nfixations  # ratio of sal map values at fixation locations above threshold
+         fp[i + 1] = float(aboveth - i) / (Npixels - Nfixations)  # ratio of other sal map values above threshold
+
+     score = np.trapz(tp, x=fp)
+     allthreshes = np.insert(allthreshes, 0, 0)
+     allthreshes = np.append(allthreshes, 1)
+
+     return score
+
+ def auc_shuff(s_map, gt, other_map, splits=100, stepsize=0.1):
+
+     if len(s_map.size()) == 3:
+         s_map = s_map[0, :, :]
+         gt = gt[0, :, :]
+         other_map = other_map[0, :, :]
+
+     s_map = s_map.numpy()
+     # min-max normalize in numpy (normalize_map above expects a batched tensor)
+     s_map = (s_map - s_map.min()) / (s_map.max() - s_map.min())
+     gt = gt.numpy()
+     other_map = other_map.numpy()
+
+     num_fixations = np.sum(gt)
+
+     x, y = np.where(other_map == 1)
+     other_map_fixs = []
+     for j in zip(x, y):
+         other_map_fixs.append(j[0] * other_map.shape[0] + j[1])
+     ind = len(other_map_fixs)
+     assert ind == np.sum(other_map), 'something is wrong in auc shuffle'
+
+     num_fixations_other = min(ind, num_fixations)
+
+     num_pixels = s_map.shape[0] * s_map.shape[1]
+     random_numbers = []
+     for i in range(0, splits):
+         temp_list = []
+         t1 = np.random.permutation(ind)
+         for k in t1:
+             temp_list.append(other_map_fixs[k])
+         random_numbers.append(temp_list)
+
+     aucs = []
+     # for each split, calculate auc
+     for i in random_numbers:
+         r_sal_map = []
+         for k in i:
+             r_sal_map.append(s_map[k % s_map.shape[0] - 1, int(k / s_map.shape[0])])
+         # in these values, we need to find thresholds and calculate auc
+         thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
+
+         r_sal_map = np.array(r_sal_map)
+
+         # once threshs are got
+         thresholds = sorted(set(thresholds))
+         area = []
+         area.append((0.0, 0.0))
+         for thresh in thresholds:
+             # in the salience map, keep only those pixels with values above threshold
+             temp = np.zeros(s_map.shape)
+             temp[s_map >= thresh] = 1.0
+             num_overlap = np.where(np.add(temp, gt) == 2)[0].shape[0]
+             tp = num_overlap / (num_fixations * 1.0)
+
+             # fp = (np.sum(temp) - num_overlap)/((np.shape(gt)[0] * np.shape(gt)[1]) - num_fixations)
+             # number of values in r_sal_map above the threshold, divided by the number
+             # of random locations (= number of fixations)
+             fp = len(np.where(r_sal_map > thresh)[0]) / (num_fixations * 1.0)
+
+             area.append((round(tp, 4), round(fp, 4)))
+
+         area.append((1.0, 1.0))
+         area.sort(key=lambda x: x[0])
+         tp_list = [x[0] for x in area]
+         fp_list = [x[1] for x in area]
+
+         aucs.append(np.trapz(np.array(tp_list), np.array(fp_list)))
+
+     return np.mean(aucs)
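The batched metrics all take H×W maps with a leading batch dimension; a quick smoke test on random tensors, assuming loss.py is importable:

    import torch
    from loss import kldiv, cc, similarity

    pred = torch.rand(4, 256, 256)
    gt = torch.rand(4, 256, 256)
    print(kldiv(pred, gt).item(), cc(pred, gt).item(), similarity(pred, gt).item())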
model.py ADDED
@@ -0,0 +1,322 @@
+ import torchvision.models as models
+ import torch
+ import torch.nn as nn
+ from collections import OrderedDict
+ import sys
+ from einops import rearrange, repeat
+ from einops.layers.torch import Rearrange
+ from scipy import ndimage
+
+ sys.path.append('./PNAS/')
+ from PNASnet import *
+ from genotypes import PNASNet
+ import torch.nn.functional as nnf
+ import numpy as np
+
+
+ class PNASModel(nn.Module):
+
+     def __init__(self, num_channels=3, train_enc=False, load_weight=1):
+         super(PNASModel, self).__init__()
+         self.pnas = NetworkImageNet(216, 1001, 12, False, PNASNet)
+         if load_weight:
+             # NOTE: self.path (the pretrained PNASNet checkpoint path) is never
+             # assigned in this file; set it before constructing with load_weight=1.
+             self.pnas.load_state_dict(torch.load(self.path))
+
+         for param in self.pnas.parameters():
+             param.requires_grad = train_enc
+
+         self.padding = nn.ConstantPad2d((0, 1, 0, 1), 0)
+         self.drop_path_prob = 0
+
+         self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)
+
+         self.deconv_layer0 = nn.Sequential(
+             nn.Conv2d(in_channels=4320, out_channels=512, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+
+         self.deconv_layer1 = nn.Sequential(
+             nn.Conv2d(in_channels=512 + 2160, out_channels=256, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer2 = nn.Sequential(
+             nn.Conv2d(in_channels=1080 + 256, out_channels=270, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer3 = nn.Sequential(
+             nn.Conv2d(in_channels=540, out_channels=96, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer4 = nn.Sequential(
+             nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer5 = nn.Sequential(
+             nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, padding=1, bias=True),
+             nn.Sigmoid()
+         )
+
+     def forward(self, images):
+         batch_size = images.size(0)
+
+         s0 = self.pnas.conv0(images)
+         s0 = self.pnas.conv0_bn(s0)
+         out1 = self.padding(s0)
+
+         s1 = self.pnas.stem1(s0, s0, self.drop_path_prob)
+         out2 = s1
+         s0, s1 = s1, self.pnas.stem2(s0, s1, 0)
+
+         for i, cell in enumerate(self.pnas.cells):
+             s0, s1 = s1, cell(s0, s1, 0)
+             if i == 3:
+                 out3 = s1
+             if i == 7:
+                 out4 = s1
+             if i == 11:
+                 out5 = s1
+
+         out5 = self.deconv_layer0(out5)
+
+         x = torch.cat((out5, out4), 1)
+         x = self.deconv_layer1(x)
+
+         x = torch.cat((x, out3), 1)
+         x = self.deconv_layer2(x)
+
+         x = torch.cat((x, out2), 1)
+         x = self.deconv_layer3(x)
+         x = torch.cat((x, out1), 1)
+
+         x = self.deconv_layer4(x)
+
+         x = self.deconv_layer5(x)
+         x = x.squeeze(1)
+         # print("PNAS pred actual pnas:", x.mean(), x.min(), x.max(), x.sum())
+
+         return x
+
+ class PNASVolModellast(nn.Module):
+
+     def __init__(self, time_slices, num_channels=3, train_enc=False, load_weight=1):
+         super(PNASVolModellast, self).__init__()
+
+         self.pnas = NetworkImageNet(216, 1001, 12, False, PNASNet)
+         if load_weight:
+             # NOTE: as above, self.path must be assigned externally for load_weight=1.
+             state_dict = torch.load(self.path)
+             new_state_dict = OrderedDict()
+             for k, v in state_dict.items():
+                 if 'module' in k:
+                     k = 'module.pnas.' + k
+                 else:
+                     k = k.replace('pnas.', '')
+                 new_state_dict[k] = v
+             self.pnas.load_state_dict(new_state_dict, strict=False)
+
+         for param in self.pnas.parameters():
+             param.requires_grad = train_enc
+
+         self.padding = nn.ConstantPad2d((0, 1, 0, 1), 0)
+         self.drop_path_prob = 0
+
+         self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)
+
+         self.deconv_layer0 = nn.Sequential(
+             nn.Conv2d(in_channels=4320, out_channels=512, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+
+         self.deconv_layer1 = nn.Sequential(
+             nn.Conv2d(in_channels=512 + 2160, out_channels=256, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer2 = nn.Sequential(
+             nn.Conv2d(in_channels=1080 + 256, out_channels=270, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer3 = nn.Sequential(
+             nn.Conv2d(in_channels=540, out_channels=96, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer4 = nn.Sequential(
+             nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+
+         self.deconv_layer5 = nn.Sequential(
+             nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels=32, out_channels=time_slices, kernel_size=3, padding=1, bias=True),
+             nn.Sigmoid()
+         )
+
+     def forward(self, images):
+         s0 = self.pnas.conv0(images)
+         s0 = self.pnas.conv0_bn(s0)
+         out1 = self.padding(s0)
+
+         s1 = self.pnas.stem1(s0, s0, self.drop_path_prob)
+         out2 = s1
+         s0, s1 = s1, self.pnas.stem2(s0, s1, 0)
+
+         for i, cell in enumerate(self.pnas.cells):
+             s0, s1 = s1, cell(s0, s1, 0)
+             if i == 3:
+                 out3 = s1
+             if i == 7:
+                 out4 = s1
+             if i == 11:
+                 out5 = s1
+
+         out5 = self.deconv_layer0(out5)
+
+         x = torch.cat((out5, out4), 1)
+         x = self.deconv_layer1(x)
+
+         x = torch.cat((x, out3), 1)
+         x = self.deconv_layer2(x)
+
+         x = torch.cat((x, out2), 1)
+         x = self.deconv_layer3(x)
+         x = torch.cat((x, out1), 1)
+
+         x = self.deconv_layer4(x)
+
+         x = self.deconv_layer5(x)
+         x = x / x.max()
+
+         return x, [out1, out2, out3, out4, out5]
+
+
+ class PNASBoostedModelMultiLevel(nn.Module):
+
+     def __init__(self, device, model_path, model_vol_path, time_slices, train_model=False, selected_slices=""):
+         super(PNASBoostedModelMultiLevel, self).__init__()
+
+         self.selected_slices = selected_slices
+
+         self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)
+
+         self.deconv_layer1 = nn.Sequential(
+             nn.Conv2d(in_channels=512 + 2160 + 6, out_channels=256, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer2 = nn.Sequential(
+             nn.Conv2d(in_channels=1080 + 256 + 6, out_channels=270, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer3 = nn.Sequential(
+             nn.Conv2d(in_channels=540 + 6, out_channels=96, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+         self.deconv_layer4 = nn.Sequential(
+             nn.Conv2d(in_channels=192 + 6, out_channels=128, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             self.linear_upsampling
+         )
+
+         self.deconv_mix = nn.Sequential(
+             nn.Conv2d(in_channels=128 + 6, out_channels=16, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1, bias=True),
+             nn.Sigmoid()
+         )
+         model_vol = PNASVolModellast(time_slices=5, load_weight=0)  # change this to time_slices
+         model_vol = nn.DataParallel(model_vol).cuda()
+         # NOTE: model_vol_path is unused; every sub-network below is restored from model_path.
+         state_dict = torch.load(model_path)
+         vol_state_dict = OrderedDict()
+         sal_state_dict = OrderedDict()
+         smm_state_dict = OrderedDict()
+
+         for k, v in state_dict.items():
+             if 'pnas_vol' in k:
+                 k = k.replace('pnas_vol.module.', '')
+                 vol_state_dict[k] = v
+             elif 'pnas_sal' in k:
+                 k = k.replace('pnas_sal.module.', '')
+                 sal_state_dict[k] = v
+             else:
+                 smm_state_dict[k] = v
+
+         self.load_state_dict(smm_state_dict)
+         model_vol.load_state_dict(vol_state_dict)
+         self.pnas_vol = nn.DataParallel(model_vol).cuda()
+
+         for param in self.pnas_vol.parameters():
+             param.requires_grad = False
+
+         model = PNASModel(load_weight=0)
+         model = nn.DataParallel(model).cuda()
+
+         model.load_state_dict(sal_state_dict, strict=True)
+         self.pnas_sal = nn.DataParallel(model).to(device)
+
+         for param in self.pnas_sal.parameters():
+             param.requires_grad = False  # train_model
+
+     def forward(self, images):
+         # print("IMAGES", images.shape)
+
+         pnas_pred = self.pnas_sal(images).unsqueeze(1)
+         pnas_vol_pred, outs = self.pnas_vol(images)
+
+         out1, out2, out3, out4, out5 = outs
+         # print(pnas_vol_pred.shape)
+         x_maps = torch.cat((pnas_pred, pnas_vol_pred), 1)
+
+         x = torch.cat((out5, out4), 1)
+         x_maps16 = nnf.interpolate(x_maps, size=(16, 16), mode='bicubic', align_corners=False)
+
+         x = torch.cat((x, x_maps16), 1)
+
+         x = self.deconv_layer1(x)
+         x = torch.cat((x, out3), 1)
+         x_maps32 = nnf.interpolate(x_maps, size=(32, 32), mode='bicubic', align_corners=False)
+         x = torch.cat((x, x_maps32), 1)
+
+         x = self.deconv_layer2(x)
+         x = torch.cat((x, out2), 1)
+         x_maps64 = nnf.interpolate(x_maps, size=(64, 64), mode='bicubic', align_corners=False)
+         x = torch.cat((x, x_maps64), 1)
+
+         x = self.deconv_layer3(x)
+         x = torch.cat((x, out1), 1)
+         x_maps128 = nnf.interpolate(x_maps, size=(128, 128), mode='bicubic', align_corners=False)
+
+         x = torch.cat((x, x_maps128), 1)
+
+         x = self.deconv_layer4(x)
+         x = torch.cat((x, x_maps), 1)
+
+         x = self.deconv_mix(x)
+
+         x = x.squeeze(1)
+
+         return x, pnas_vol_pred
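A minimal end-to-end sketch of the base saliency branch (untested; load_weight=0 avoids the missing checkpoint path, and the 256×256 input matches the dataloader's resize):

    import torch
    from model import PNASModel

    model = PNASModel(load_weight=0).eval()
    with torch.no_grad():
        sal = model(torch.randn(1, 3, 256, 256))
    print(sal.shape)  # expected torch.Size([1, 256, 256]), values in [0, 1] via the final Sigmoid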
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ wandb
+ pycocotools
+ torch==1.8.0+cu111
+ torchvision==0.9.0+cu111
+ torchaudio==0.8.0
+ libgl1-mesa-glx  # system (apt) package, not installable via pip
+ ftfy
+ regex
+ tqdm
+ ipywidgets
+ seaborn
+ einops
+ clip-anytorch
+ pycocotools  # duplicate of the entry above; harmless to pip
+ kornia==0.5.10
testing/.DS_Store ADDED
Binary file (6.15 kB).
 
testing/gt/COCO_val2014_000000000192.png ADDED
testing/gt/COCO_val2014_000000000192_0.png ADDED
testing/gt/COCO_val2014_000000000192_1.png ADDED
testing/gt/COCO_val2014_000000000192_2.png ADDED
testing/gt/COCO_val2014_000000000192_3.png ADDED
testing/gt/COCO_val2014_000000000192_4.png ADDED
testing/gt/COCO_val2014_000000000208.png ADDED
testing/gt/COCO_val2014_000000000208_0.png ADDED
testing/gt/COCO_val2014_000000000208_1.png ADDED
testing/gt/COCO_val2014_000000000208_2.png ADDED
testing/gt/COCO_val2014_000000000208_3.png ADDED
testing/gt/COCO_val2014_000000000208_4.png ADDED
testing/images/COCO_val2014_000000000192.jpg ADDED
testing/images/COCO_val2014_000000000208.jpg ADDED
testing/predictions/Readme.txt ADDED
@@ -0,0 +1 @@
+ Your predictions will appear in this folder after running the notebook.
train.py ADDED
@@ -0,0 +1,215 @@
+ import argparse
+ import os
+ import torch
+ import sys
+ import time
+ import wandb
+ import torch.nn as nn
+ from tqdm import tqdm
+ # This commit ships dataloader_clean.py (there is no dataloader.py), so import from it.
+ from dataloader_clean import SaliconDataset
+ from loss import *
+ from utils import AverageMeter
+ from utils import img_save
+ from torchvision import utils
+ import torch.nn.functional as nnf
+ from os.path import join
+ from PIL import Image
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--no_epochs', default=30, type=int)
+ parser.add_argument('--lr', default=1e-5, type=float)
+ # NOTE: argparse's type=bool turns any non-empty string into True;
+ # leave these flags unset to keep their defaults.
+ parser.add_argument('--kldiv', default=True, type=bool)
+ parser.add_argument('--cc', default=True, type=bool)
+ parser.add_argument('--nss', default=False, type=bool)
+ parser.add_argument('--sim', default=False, type=bool)
+ parser.add_argument('--nss_emlnet', default=False, type=bool)
+ parser.add_argument('--nss_norm', default=False, type=bool)
+ parser.add_argument('--l1', default=False, type=bool)
+ parser.add_argument('--lr_sched', default=False, type=bool)
+ parser.add_argument('--dilation', default=False, type=bool)
+ parser.add_argument('--enc_model', default="pnas", type=str)
+ parser.add_argument('--optim', default="Adam", type=str)
+
+ parser.add_argument('--load_weight', default=1, type=int)
+ parser.add_argument('--kldiv_coeff', default=1.0, type=float)
+ parser.add_argument('--step_size', default=5, type=int)
+ parser.add_argument('--cc_coeff', default=-1.0, type=float)
+ parser.add_argument('--sim_coeff', default=-1.0, type=float)
+ parser.add_argument('--nss_coeff', default=-1.0, type=float)
+ parser.add_argument('--nss_emlnet_coeff', default=1.0, type=float)
+ parser.add_argument('--nss_norm_coeff', default=1.0, type=float)
+ parser.add_argument('--l1_coeff', default=1.0, type=float)
+ parser.add_argument('--train_enc', default=1, type=int)
+
+ parser.add_argument('--dataset_dir', default="../data/", type=str)
+ parser.add_argument('--batch_size', default=32, type=int)
+ parser.add_argument('--log_interval', default=60, type=int)
+ parser.add_argument('--no_workers', default=4, type=int)
+ parser.add_argument('--train_model', default=False, type=bool)
+ parser.add_argument('--time_slices', default=5, type=int)
+ parser.add_argument('--selected_slices', default="", type=str)
+ parser.add_argument('--results_dir', default="", type=str)
+
+ # Path to save the model weights
+ parser.add_argument('--model_val_path', default="model.pt", type=str)
+ # If the model type is pnas_boosted, specify the path of the pre-trained pnas model here
+ parser.add_argument('--model_path', default="", type=str)
+ # If the model type is pnas_boosted, specify the path of the pre-trained pnasvol model here
+ parser.add_argument('--model_vol_path', default="", type=str)
+
+
+ args = parser.parse_args()
+
+ train_img_dir = args.dataset_dir + "images/train/"
+ train_gt_dir = args.dataset_dir + "maps/train/"
+ train_fix_dir = args.dataset_dir + "fixation_maps/train/"
+
+ val_img_dir = args.dataset_dir + "images/val/"
+ val_gt_dir = args.dataset_dir + "maps/val/"
+ val_fix_dir = args.dataset_dir + "fixation_maps/val/"
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ if args.enc_model == "pnas":
+     print("PNAS Model")
+     from model import PNASModel
+     model = PNASModel(train_enc=bool(args.train_enc), load_weight=args.load_weight)
+
+ elif args.enc_model == "pnas_boosted_multi":
+     print("PNAS Boosted Model PNASBoostedModelMultiLevel")
+     # class name matches model.py (PNASBoostedModelMultiLevel, capital L)
+     from model import PNASBoostedModelMultiLevel
+     model = PNASBoostedModelMultiLevel(device, args.model_path, args.model_vol_path, args.time_slices, train_model=args.train_model, selected_slices=args.selected_slices)
+
+
+ if torch.cuda.device_count() > 1:
+     print("Let's use", torch.cuda.device_count(), "GPUs!")
+     model = nn.DataParallel(model)
+ model.to(device)
+
+
+ train_img_ids = [nm.split(".")[0] for nm in os.listdir(train_img_dir)]
+ val_img_ids = [nm.split(".")[0] for nm in os.listdir(val_img_dir)]
+
+ train_dataset = SaliconDataset(train_img_dir, train_gt_dir, train_fix_dir, train_img_ids)
+ val_dataset = SaliconDataset(val_img_dir, val_gt_dir, val_fix_dir, val_img_ids)
+
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.no_workers)
+ val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.no_workers)
+
+
+ def loss_func(pred_map, gt, fixations, args):
+     loss = torch.FloatTensor([0.0]).cuda()
+     criterion = nn.L1Loss()
+     if args.kldiv:
+         loss += args.kldiv_coeff * kldiv(pred_map, gt)
+     if args.cc:
+         loss += args.cc_coeff * cc(pred_map, gt)
+     if args.nss:
+         loss += args.nss_coeff * nss(pred_map, fixations)
+     if args.l1:
+         loss += args.l1_coeff * criterion(pred_map, gt)
+     if args.sim:
+         loss += args.sim_coeff * similarity(pred_map, gt)
+     # print("Loss: ", loss)
+     return loss
+
+ def train(model, optimizer, loader, epoch, device, args):
+     model.train()
+
+     tic = time.time()
+
+     total_loss = 0.0
+     cur_loss = 0.0
+
+     for idx, (img, gt, fixations) in enumerate(loader):
+         img = img.to(device)
+         gt = gt.to(device)
+         fixations = fixations.to(device)
+
+         optimizer.zero_grad()
+         pred_map, vol_pred = model(img)
+
+         assert pred_map.size() == gt.size()
+
+         loss = loss_func(pred_map, gt, fixations, args)
+         loss.backward()
+
+         total_loss += loss.item()
+         cur_loss += loss.item()
+
+         optimizer.step()
+         if idx % args.log_interval == (args.log_interval - 1):
+             print('[{:2d}, {:5d}] avg_loss : {:.5f}, time:{:3f} minutes'.format(epoch, idx, cur_loss / args.log_interval, (time.time() - tic) / 60))
+             wandb.log({"loss": cur_loss / args.log_interval})
+             cur_loss = 0.0
+             sys.stdout.flush()
+
+     print('[{:2d}, train] avg_loss : {:.5f}'.format(epoch, total_loss / len(loader)))
+     sys.stdout.flush()
+
+     return total_loss / len(loader)
+
+ def validate(model, loader, epoch, device, args):
+     model.eval()
+     tic = time.time()
+     cc_loss = AverageMeter()
+     kldiv_loss = AverageMeter()
+     nss_loss = AverageMeter()
+     sim_loss = AverageMeter()
+
+     for (img, gt, fixations) in tqdm(loader):
+         img = img.to(device)
+         gt = gt.to(device)
+         fixations = fixations.to(device)
+
+         pred_map, vol_pred = model(img)
+
+         cc_loss.update(cc(pred_map, gt))
+         kldiv_loss.update(kldiv(pred_map, gt))
+         nss_loss.update(nss(pred_map, fixations))
+         sim_loss.update(similarity(pred_map, gt))
+
+     print('[{:2d}, val] CC : {:.5f}, KLDIV : {:.5f}, NSS : {:.5f}, SIM : {:.5f} time:{:3f} minutes'.format(epoch, cc_loss.avg, kldiv_loss.avg, nss_loss.avg, sim_loss.avg, (time.time() - tic) / 60))
+     wandb.log({"CC": cc_loss.avg, 'KLDIV': kldiv_loss.avg, 'NSS': nss_loss.avg, 'SIM': sim_loss.avg})
+     sys.stdout.flush()
+
+     return cc_loss.avg, cc_loss, kldiv_loss, nss_loss, sim_loss
+
+ params = list(filter(lambda p: p.requires_grad, model.parameters()))
+
+ if args.optim == "Adam":
+     optimizer = torch.optim.Adam(params, lr=args.lr)
+ if args.optim == "Adagrad":
+     optimizer = torch.optim.Adagrad(params, lr=args.lr)
+ if args.optim == "SGD":
+     optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
+ if args.lr_sched:
+     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
+
+ print(device)
+ best_loss = 0
+ for epoch in range(0, args.no_epochs):
+     loss = train(model, optimizer, train_loader, epoch, device, args)
+
+     with torch.no_grad():
+         cc_loss, cc_loss_obj, kldiv_loss, nss_loss, sim_loss = validate(model, val_loader, epoch, device, args)
+         cc_loss -= kldiv_loss.avg
+         if epoch == 0:
+             best_loss = cc_loss
+         if best_loss <= cc_loss:
+             best_loss = cc_loss
+             print('[{:2d}, save, {}]'.format(epoch, args.model_val_path))
+             wandb.log({"Best/CC mean": cc_loss, "Best/CC median": cc_loss_obj.get_median(), "Best/CC std": cc_loss_obj.get_std(),
+                        "Best/KLD mean": kldiv_loss.avg, "Best/KLD median": kldiv_loss.get_median(), "Best/KLD std": kldiv_loss.get_std(),
+                        "Best/NSS mean": nss_loss.avg, "Best/NSS median": nss_loss.get_median(), "Best/NSS std": nss_loss.get_std(),
+                        "Best/SIM mean": sim_loss.avg, "Best/SIM median": sim_loss.get_median(), "Best/SIM std": sim_loss.get_std()})
+             if torch.cuda.device_count() > 1:
+                 torch.save(model.module.state_dict(), args.model_val_path)
+             else:
+                 torch.save(model.state_dict(), args.model_val_path)
+     print()
+
+     if args.lr_sched:
+         scheduler.step()
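By default the objective combines KL-divergence with a negatively weighted CC term (kldiv_coeff=1.0, cc_coeff=-1.0), so loss_func effectively computes the following; minimizing it drives the KL down while pushing the correlation up:

    from loss import kldiv, cc
    loss = 1.0 * kldiv(pred_map, gt) - 1.0 * cc(pred_map, gt)  # default coefficients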
utils.py ADDED
@@ -0,0 +1,254 @@
+ import fnmatch
+ import os
+ import torch
+ import cv2
+
+ from tqdm import tqdm
+ from scipy.spatial import distance
+ from math import pi, sqrt, exp
+ from os.path import join
+ from torchvision import utils
+ from PIL import Image
+
+ import numpy as np
+ import scipy.io as sio
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import matplotlib.pyplot as plt
+ import matplotlib.animation as animation
+
+ W = 640
+ H = 480
+ TIMESPAN = 5000
+ MAX_PIXEL_DISTANCE = 800
+ ESTIMATED_TIMESTAMP_WEIGHT = 0.006
+ RATIO = 0.9
+
+ FIXATION_PATH = '../data/fixations/'
+ FIX_VOL_PATH = '../data/fixation_volumes_'
+ SAL_VOL_PATH = '../data/saliency_volumes_'
+
+ class bcolors:
+     HEADER = '\033[95m'
+     OKBLUE = '\033[94m'
+     OKCYAN = '\033[96m'
+     OKGREEN = '\033[92m'
+     WARNING = '\033[93m'
+     FAIL = '\033[91m'
+     ENDC = '\033[0m'
+     BOLD = '\033[1m'
+     UNDERLINE = '\033[4m'
+
+ def get_colored_value(value, ref_value, increasing=True):
+     coef = 1 if increasing else -1
+     return (bcolors.FAIL if (coef * ref_value > coef * value) else bcolors.OKGREEN) + '{:.5f}'.format(value) + bcolors.ENDC
+
+ def get_filenames(path):
+     return [file for file in sorted(os.listdir(path)) if fnmatch.fnmatch(file, 'COCO_*')]
+
+ def parse_fixations(filenames,
+                     path_prefix,
+                     etw=ESTIMATED_TIMESTAMP_WEIGHT, progress_bar=True):
+     fixation_volumes = []
+     filenames = tqdm(filenames) if progress_bar else filenames
+
+     for filename in filenames:
+         # 1. Extracting data from .mat files
+         mat = sio.loadmat(path_prefix + filename + '.mat')
+         gaze = mat["gaze"]
+
+         locations = []
+         timestamps = []
+         fixations = []
+
+         for i in range(len(gaze)):
+             locations.append(mat["gaze"][i][0][0])
+             timestamps.append(mat["gaze"][i][0][1])
+             fixations.append(mat["gaze"][i][0][2])
+
+         # 2. Matching fixations with timestamps
+         fixation_volume = []
+         for i, observer in enumerate(fixations):
+             fix_timestamps = []
+             fix_time = TIMESPAN / (len(observer) + 1)
+             est_timestamp = fix_time
+
+             for fixation in observer:
+                 distances = distance.cdist([fixation], locations[i], 'euclidean')[0][..., np.newaxis]
+                 time_diffs = abs(timestamps[i] - est_timestamp)
+                 min_idx = (etw * time_diffs + distances).argmin()
+
+                 fix_timestamps.append([min(timestamps[i][min_idx][0], TIMESPAN), fixation.tolist()])
+                 est_timestamp += fix_time
+
+             if (len(observer) > 0):
+                 fixation_volume.append(fix_timestamps)
+
+         fixation_volumes.append(fixation_volume)
+
+     return fixation_volumes
+
+ def get_saliency_volume(fixation_volume, conv2D, time_slices):
+     fixation_map = torch.cuda.FloatTensor(time_slices, H, W).fill_(0)
+
+     for ts, coords in fixation_volume:
+         for (x, y) in coords:
+             fixation_map[ts, y - 1, x - 1] = 1
+
+     saliency_volume = conv2D.forward(fixation_map)
+     return saliency_volume / saliency_volume.max()
+
+ def blur(img):
+     k_size = 11
+     bl = cv2.GaussianBlur(img, (k_size, k_size), 0)
+     return torch.FloatTensor(bl)
+
+ def visualize_model(model, loader, device, args):
+     with torch.no_grad():
+         model.eval()
+         os.makedirs(args.results_dir, exist_ok=True)
+
+         for (img, img_id, sz) in tqdm(loader):
+             img = img.to(device)
+
+             pred_map = model(img)
+             if type(pred_map) is tuple:
+                 pred_map = pred_map[1]
+             pred_map = pred_map.cpu().squeeze(0).numpy()
+             pred_map = cv2.resize(pred_map, (sz[0], sz[1]))
+
+             pred_map = torch.FloatTensor(blur(pred_map))
+             img_save(pred_map, join(args.results_dir, img_id[0]), normalize=True)
+
+ def img_save(tensor, fp, nrow=8, padding=2,
+              normalize=False, range=None, scale_each=False, pad_value=0, format=None):
+     grid = utils.make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
+                            normalize=normalize, range=range, scale_each=scale_each)
+
+     ''' Add 0.5 after unnormalizing to [0, 255] to round to nearest integer '''
+     ndarr = torch.round(grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)).to('cpu', torch.uint8).numpy()
+     im = Image.fromarray(ndarr[:, :, 0])
+     # save to the requested path as .png (a hard-coded debug save to "1.png" was removed)
+     fp = fp[:-4] + '.png'
+     im.save(fp, format=format, compress_level=0)
+
+ class AverageMeter(object):
+
+     '''Computes and stores the average and current value'''
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.past = np.array([])
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+         self.past = np.append(self.past, val.cpu())
+
+     def get_std(self):
+         return np.std(self.past)
+
+     def get_median(self):
+         return np.median(self.past)
+
+
+ def im2heat(pred_dir, a, gt, exten='.png'):
+     pred_nm = pred_dir + a + exten
+     pred = cv2.imread(pred_nm, 0)
+     heatmap_img = cv2.applyColorMap(pred, cv2.COLORMAP_JET)
+     heatmap_img = convert(heatmap_img)  # NOTE: convert() is not defined in this module
+     pred = np.stack((pred, pred, pred), 2).astype('float32')
+     pred = pred / 255.0
+
+     return np.uint8(pred * heatmap_img + (1.0 - pred) * gt)
+
+ def get_heat_image(image):
+     return cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * image), cv2.COLORMAP_HOT), cv2.COLOR_BGR2RGB)
+
+ def format_image(heatmap, image, max_value):
+     extended_map = heatmap / max_value
+     hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
+     factors = np.clip(2 * extended_map, 0, 1)
+     hsv[:, :, 1] = np.uint8(factors * hsv[:, :, 1])
+     hsv[:, :, 2] = np.uint8((RATIO * factors + (1 - RATIO)) * hsv[:, :, 2])
+
+     return get_heat_image(extended_map[:, :, np.newaxis]), cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
+
+ def animate(gt_vol, pred_vol, image):
+     fig = plt.figure(figsize=(16, 16))
+
+     gt_max = np.max(gt_vol)
+     pred_max = np.max(pred_vol)
+     formatted_images = []
+
+     for (gt_map, pred_map) in zip(gt_vol, pred_vol):
+         gt_heatmap, gt_image_heatmap = format_image(gt_map, image, gt_max)
+         gt_im = np.concatenate((gt_heatmap, gt_image_heatmap), 1)
+
+         pred_heatmap, pred_image_heatmap = format_image(pred_map, image, pred_max)
+         pred_im = np.concatenate((pred_heatmap, pred_image_heatmap), 1)
+
+         diff = 0.5 + ((gt_map / gt_max) - (pred_map / pred_max)) / 2
+         diff_im = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * diff[:, :, np.newaxis]), cv2.COLORMAP_TWILIGHT), cv2.COLOR_BGR2RGB)
+         diff_im = np.concatenate((diff_im, image), 1)
+         formatted_images.append([plt.imshow(np.concatenate((gt_im, pred_im, diff_im), 0), animated=True)])
+
+     return animation.ArtistAnimation(fig, formatted_images, interval=500, blit=True, repeat_delay=1000)
+
+ def animate_single_heatmap(gt_vol, image):
+     fig = plt.figure(figsize=(6.4, 4.8), frameon=False)
+     ax = plt.Axes(fig, [0., 0., 1., 1.])
+     ax.set_axis_off()
+     fig.add_axes(ax)
+     gt_max = np.max(gt_vol)
+     formatted_images = []
+     plt.axis('off')
+     for gt_map in gt_vol:
+         gt_heatmap, gt_image_heatmap = format_image(gt_map, image, gt_max)
+         gt_im = gt_heatmap
+         formatted_images.append([ax.imshow(gt_im, animated=True)])
+     return animation.ArtistAnimation(fig, formatted_images, interval=1000, blit=True, repeat_delay=1000)
+
+
+ def gauss(n, sigma):
+     r = range(-int(n / 2), int(n / 2) + 1)
+     return [1 / (sigma * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * sigma ** 2)) for x in r]
+
+ class GaussianBlur1D(nn.Module):
+     def __init__(self, time_slices):
+         super(GaussianBlur1D, self).__init__()
+         sigma = 2 * time_slices / 25
+         self.size = 2 * int(4 * sigma + 0.5) + 1
+         kernel = gauss(self.size, sigma)
+         kernel = torch.cuda.FloatTensor(kernel)
+         self.weight = nn.Parameter(data=kernel, requires_grad=False)
+
+     def forward(self, x):
+         pad = int(self.size / 2)
+         # The 5-D input and weight make this a 3-D convolution; older torch versions
+         # accepted F.conv1d here, current ones require the explicit conv3d.
+         temp = F.conv3d(x, self.weight.view(1, 1, -1, 1, 1), padding=pad)
+         return temp[:, :, :, pad:-pad, pad:-pad]
+
+ class GaussianBlur2D(nn.Module):
+     def __init__(self):
+         super(GaussianBlur2D, self).__init__()
+         self.size = 201
+         kernel = gauss(self.size, 25)
+         kernel = torch.cuda.FloatTensor(kernel)
+         self.weight = nn.Parameter(data=kernel, requires_grad=False)
+
+     def forward(self, x):
+         pad = int(self.size / 2)
+         # As in GaussianBlur1D: a separable Gaussian, applied along H then W via conv3d.
+         temp = F.conv3d(x.unsqueeze(0).unsqueeze(0), self.weight.view(1, 1, 1, -1, 1), padding=pad)
+         temp = temp[:, :, pad:-pad, :, pad:-pad]
+         temp = F.conv3d(temp, self.weight.view(1, 1, 1, 1, -1), padding=pad)
+         return temp[:, :, pad:-pad, pad:-pad]
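AverageMeter keeps the full history in self.past so validate() in train.py can report medians and standard deviations alongside the running mean. A minimal usage sketch (update() expects tensors, since it calls .cpu() on them):

    import torch
    from utils import AverageMeter

    meter = AverageMeter()
    for v in (torch.tensor(0.5), torch.tensor(0.7)):
        meter.update(v)
    print(meter.avg, meter.get_median(), meter.get_std())  # ~0.6, 0.6, 0.1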