Upload 30 files
Browse files- PNAS/PNASnet.py +140 -0
- PNAS/__init__.py +1 -0
- PNAS/genotypes.py +33 -0
- PNAS/operations.py +115 -0
- README.md +1 -12
- checkpoints/Readme.txt +2 -0
- dataloader_clean.py +52 -0
- generate_volumes.py +57 -0
- inference.ipynb +0 -0
- loss.py +256 -0
- model.py +322 -0
- requirements.txt +15 -0
- testing/.DS_Store +0 -0
- testing/gt/COCO_val2014_000000000192.png +0 -0
- testing/gt/COCO_val2014_000000000192_0.png +0 -0
- testing/gt/COCO_val2014_000000000192_1.png +0 -0
- testing/gt/COCO_val2014_000000000192_2.png +0 -0
- testing/gt/COCO_val2014_000000000192_3.png +0 -0
- testing/gt/COCO_val2014_000000000192_4.png +0 -0
- testing/gt/COCO_val2014_000000000208.png +0 -0
- testing/gt/COCO_val2014_000000000208_0.png +0 -0
- testing/gt/COCO_val2014_000000000208_1.png +0 -0
- testing/gt/COCO_val2014_000000000208_2.png +0 -0
- testing/gt/COCO_val2014_000000000208_3.png +0 -0
- testing/gt/COCO_val2014_000000000208_4.png +0 -0
- testing/images/COCO_val2014_000000000192.jpg +0 -0
- testing/images/COCO_val2014_000000000208.jpg +0 -0
- testing/predictions/Readme.txt +1 -0
- train.py +215 -0
- utils.py +254 -0
PNAS/PNASnet.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from operations import *
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
# from utils import drop_path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Cell(nn.Module):
  """A single PNAS cell built from a fixed genotype.

  Each of `_steps` intermediate nodes sums the outputs of two ops applied
  to earlier states; the states listed in `concat` are concatenated
  channel-wise to form the cell output.
  """

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    # NOTE(review): debug print left in — consider removing or using logging.
    print(C_prev_prev, C_prev, C)
    self.reduction = reduction

    # Preprocess s0 according to how the cell two steps back ended:
    #   None  -> s0 is the raw stem output; pass through unchanged
    #   True  -> the previous cell reduced, so halve s0's spatial size
    #   False -> only project channels with a 1x1 conv
    if reduction_prev is None:
      self.preprocess0 = Identity()
    elif reduction_prev is True:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

    if reduction:
      op_names, indices = zip(*genotype.reduce)
      concat = genotype.reduce_concat
    else:
      op_names, indices = zip(*genotype.normal)
      concat = genotype.normal_concat

    assert len(op_names) == len(indices)
    self._steps = len(op_names) // 2  # two ops feed each intermediate node
    self._concat = concat
    self.multiplier = len(concat)  # number of node outputs concatenated

    self._ops = nn.ModuleList()
    for name, index in zip(op_names, indices):
      # Only ops that read the two cell inputs stride in a reduction cell.
      stride = 2 if reduction and index < 2 else 1
      if reduction_prev is None and index == 0:
        # s0 was not channel-projected, so this op maps C_prev_prev -> C.
        op = OPS[name](C_prev_prev, C, stride, True)
      else:
        op = OPS[name](C, C, stride, True)
      self._ops += [op]
    self._indices = indices

  def forward(self, s0, s1, drop_prob):
    # drop_prob is accepted for API compatibility; drop-path is disabled below.
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    for i in range(self._steps):
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      # if self.training and drop_prob > 0.:
      #   if not isinstance(op1, Identity):
      #     h1 = drop_path(h1, drop_prob)
      #   if not isinstance(op2, Identity):
      #     h2 = drop_path(h2, drop_prob)
      s = h1 + h2
      states += [s]
    # Concatenate the selected intermediate node outputs along channels.
    return torch.cat([states[i] for i in self._concat], dim=1)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class AuxiliaryHeadImageNet(nn.Module):
  """Auxiliary classifier head attached partway through the network.

  The original comment said "assuming input size 14x14"; the head only
  flattens to 768 features when the final 2x2 conv yields a 1x1 map, so
  the effective expected input is whatever reduces to that — verify
  against the attachment point in the backbone.
  """

  def __init__(self, C, num_classes):
    super(AuxiliaryHeadImageNet, self).__init__()
    stages = [
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      nn.BatchNorm2d(768),
      nn.ReLU(inplace=True),
    ]
    self.features = nn.Sequential(*stages)
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    feats = self.features(x)
    flattened = feats.view(feats.size(0), -1)
    return self.classifier(flattened)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class NetworkImageNet(nn.Module):
  """PNASNet-style backbone assembled from genotype cells.

  A stride-2 stem conv plus two reduction stem cells precede `layers`
  cells; channels double (with a reduction cell) at 1/3 and 2/3 depth.
  forward() returns (logits, logits_aux); logits_aux is None unless the
  auxiliary head is enabled and the model is in training mode.
  """

  def __init__(self, C, num_classes, layers, auxiliary, genotype):
    super(NetworkImageNet, self).__init__()
    self._layers = layers
    self._auxiliary = auxiliary
    # BUG FIX: forward() reads self.drop_path_prob, but it was never
    # initialized anywhere in this class, so any forward pass crashed with
    # AttributeError unless the caller assigned it manually. Default to
    # 0.0 (drop-path disabled); training scripts may still override it.
    self.drop_path_prob = 0.0

    self.conv0 = nn.Conv2d(3, 96, kernel_size=3, stride=2, padding=0, bias=False)
    self.conv0_bn = nn.BatchNorm2d(96, eps=1e-3)
    # stem1 consumes the raw conv0 features directly (reduction_prev=None);
    # stem2 follows a reduction cell, so reduction_prev=True.
    self.stem1 = Cell(genotype, 96, 96, C // 4, True, None)
    self.stem2 = Cell(genotype, 96, C * self.stem1.multiplier // 4, C // 2, True, True)

    C_prev_prev, C_prev, C_curr = C * self.stem1.multiplier // 4, C * self.stem2.multiplier // 2, C

    self.cells = nn.ModuleList()
    reduction_prev = True
    for i in range(layers):
      if i in [layers // 3, 2 * layers // 3]:
        # Reduction cell: halve resolution, double channels.
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
      if i == 2 * layers // 3:
        # Remember the channel count where the auxiliary head attaches.
        C_to_auxiliary = C_prev

    if auxiliary:
      self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
    self.relu = nn.ReLU(inplace=False)
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

  def forward(self, input):
    logits_aux = None
    s0 = self.conv0(input)
    s0 = self.conv0_bn(s0)
    s1 = self.stem1(s0, s0, self.drop_path_prob)
    s0, s1 = s1, self.stem2(s0, s1, self.drop_path_prob)
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
      if i == 2 * self._layers // 3:
        # Auxiliary classifier contributes only during training.
        if self._auxiliary and self.training:
          logits_aux = self.auxiliary_head(s1)
    s1 = self.relu(s1)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0), -1))
    return logits, logits_aux
|
| 140 |
+
|
PNAS/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
PNAS/genotypes.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import namedtuple

# An architecture description: lists of (op_name, input_state_index) pairs
# for the normal and reduction cells, plus which intermediate-state outputs
# are concatenated into each cell's output.
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# PNASNet cell. Note the normal and reduce specifications are identical —
# in PNAS the same cell topology is reused, and the reduction cell merely
# strides the ops that read the two cell inputs.
PNASNet = Genotype(
  normal = [
    ('sep_conv_5x5', 0),
    ('max_pool_3x3', 0),
    ('sep_conv_7x7', 1),
    ('max_pool_3x3', 1),
    ('sep_conv_5x5', 1),
    ('sep_conv_3x3', 1),
    ('sep_conv_3x3', 4),
    ('max_pool_3x3', 1),
    ('sep_conv_3x3', 0),
    ('skip_connect', 1),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    ('sep_conv_5x5', 0),
    ('max_pool_3x3', 0),
    ('sep_conv_7x7', 1),
    ('max_pool_3x3', 1),
    ('sep_conv_5x5', 1),
    ('sep_conv_3x3', 1),
    ('sep_conv_3x3', 4),
    ('max_pool_3x3', 1),
    ('sep_conv_3x3', 0),
    ('skip_connect', 1),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
)
|
| 33 |
+
|
PNAS/operations.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
# Factory table mapping genotype op names to module constructors.
# Each factory has signature (C_in, C_out, stride, affine) -> nn.Module.
# Pooling ops append a 1x1 conv + BN only when a channel change is needed.
OPS = {
  'none' : lambda C_in, C_out, stride, affine: Zero(stride),
  'avg_pool_3x3' : lambda C_in, C_out, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if C_in == C_out else nn.Sequential(
    nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
    nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
  ),
  'max_pool_3x3' : lambda C_in, C_out, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1) if C_in == C_out else nn.Sequential(
    nn.MaxPool2d(3, stride=stride, padding=1),
    nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
  ),
  # Identity when no stride is needed; otherwise a strided 1x1 projection.
  'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 else ReLUConvBN(C_in, C_out, 1, stride, 0, affine=affine),
  'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine),
  'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine),
  'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine),
  'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine),
  'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine),
  # Factorized 7x7: a 1x7 conv followed by a 7x1 conv.
  'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: nn.Sequential(
    nn.ReLU(inplace=False),
    nn.Conv2d(C_in, C_in, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
    nn.Conv2d(C_in, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
    nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
  ),
}
|
| 29 |
+
|
| 30 |
+
class ReLUConvBN(nn.Module):
  """ReLU -> Conv2d -> BatchNorm: the standard conv unit of this codebase."""

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(ReLUConvBN, self).__init__()
    activation = nn.ReLU(inplace=False)
    conv = nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False)
    norm = nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
    self.op = nn.Sequential(activation, conv, norm)

  def forward(self, x):
    """Run the ReLU-Conv-BN pipeline on x."""
    return self.op(x)
|
| 42 |
+
|
| 43 |
+
class DilConv(nn.Module):
  """Dilated depthwise-separable conv: ReLU -> dilated depthwise -> 1x1 pointwise -> BN."""

  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super(DilConv, self).__init__()
    depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False)
    pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False)
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      depthwise,
      pointwise,
      nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
    )

  def forward(self, x):
    """Apply the dilated separable convolution to x."""
    return self.op(x)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class SepConv(nn.Module):
  """Separable conv applied twice: (ReLU -> depthwise -> pointwise -> BN) x 2.

  Only the first depthwise conv carries the stride; the second pass keeps
  the spatial size and refines the C_out features.
  """

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(SepConv, self).__init__()
    first_pass = [
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
    ]
    second_pass = [
      nn.ReLU(inplace=False),
      nn.Conv2d(C_out, C_out, kernel_size=kernel_size, stride=1, padding=padding, groups=C_out, bias=False),
      nn.Conv2d(C_out, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, eps=1e-3, affine=affine),
    ]
    self.op = nn.Sequential(*(first_pass + second_pass))

  def forward(self, x):
    """Apply the doubled separable convolution to x."""
    return self.op(x)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Identity(nn.Module):
  """No-op module: forwards its input unchanged (used for skip connections)."""

  def __init__(self):
    super(Identity, self).__init__()

  def forward(self, x):
    """Return x untouched."""
    return x
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Zero(nn.Module):
  """Emit a zero tensor shaped like the (optionally strided) input.

  With stride > 1 the output mimics the spatial shape a strided op would
  produce, so 'none' edges stay shape-compatible with their siblings.
  """

  def __init__(self, stride):
    super(Zero, self).__init__()
    self.stride = stride

  def forward(self, x):
    if self.stride == 1:
      return x.mul(0.)
    subsampled = x[:, :, ::self.stride, ::self.stride]
    return subsampled.mul(0.)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class FactorizedReduce(nn.Module):
  """Halve spatial size while keeping all information.

  Two stride-2 1x1 convs sample the feature map on interleaved grids (the
  second on a one-pixel-shifted view); their outputs are concatenated so
  every input pixel contributes to the reduced map.
  """

  def __init__(self, C_in, C_out, affine=True):
    super(FactorizedReduce, self).__init__()
    assert C_out % 2 == 0
    self.relu = nn.ReLU(inplace=False)
    self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, eps=1e-3, affine=affine)
    self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

  def forward(self, x):
    activated = self.relu(x)
    # Pad bottom/right by one, then drop the first row/column: this shifts
    # the grid so conv_2 samples the pixels conv_1 skips.
    shifted = self.pad(activated)[:, :, 1:, 1:]
    halves = [self.conv_1(activated), self.conv_2(shifted)]
    return self.bn(torch.cat(halves, dim=1))
|
| 115 |
+
|
README.md
CHANGED
|
@@ -1,12 +1 @@
|
|
| 1 |
-
|
| 2 |
-
title: Tempsal
|
| 3 |
-
emoji: ⚡
|
| 4 |
-
colorFrom: gray
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 4.44.1
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
Download the model checkpoint from: https://drive.google.com/drive/folders/1W92oXYra_OPYkR1W56D80iDexWIR7f7Z?usp=sharing. Then follow the instructions in inference.ipynb; the notebook produces both temporal and image saliency predictions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
checkpoints/Readme.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Please download the checkpoint from the following link:
|
| 2 |
+
https://drive.google.com/drive/folders/1W92oXYra_OPYkR1W56D80iDexWIR7f7Z?usp=sharing
|
dataloader_clean.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import transforms
|
| 2 |
+
import torchvision.transforms.functional as TF
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import os, cv2
|
| 8 |
+
from utils import *
|
| 9 |
+
import json
|
| 10 |
+
import random
|
| 11 |
+
from pycocotools.coco import COCO
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SaliconDataset(DataLoader):
    """SALICON-style dataset yielding (image, saliency map, fixation map).

    Images are resized to 256x256 and normalized; ground-truth maps are
    scaled to [0, 1]; fixation maps are binarized at 0.5.
    """
    # NOTE(review): this subclasses torch.utils.data.DataLoader but only
    # implements the Dataset protocol (__getitem__/__len__);
    # torch.utils.data.Dataset is the conventional base — confirm callers
    # before changing.

    def __init__(self, img_dir, gt_dir, fix_dir, img_ids, exten='.png'):
        # img_dir: directory of input .jpg images
        # gt_dir:  directory of continuous saliency ground-truth maps
        # fix_dir: directory of binary fixation maps
        # img_ids: list of file stems (no extension)
        # exten:   extension of ground-truth / fixation files
        self.img_dir = img_dir
        self.gt_dir = gt_dir
        self.fix_dir = fix_dir
        self.img_ids = img_ids
        self.exten = exten
        self.img_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],
                                 [0.5, 0.5, 0.5])
        ])

    def __getitem__(self, idx):
        """Load and preprocess the sample at index idx."""
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_dir, img_id + '.jpg')
        gt_path = os.path.join(self.gt_dir, img_id + self.exten)
        fix_path = os.path.join(self.fix_dir, img_id + self.exten)

        img = Image.open(img_path).convert('RGB')
        img = self.img_transform(img)

        # Ground truth: grayscale, resized, rescaled to [0, 1] when stored
        # as 8-bit values.
        gt = np.array(Image.open(gt_path).convert('L'))
        gt = gt.astype('float')
        gt = cv2.resize(gt, (256,256))
        if np.max(gt) > 1.0:
            gt = gt / 255.0

        # Binarize the fixation map (any nonzero pixel counts as a fixation).
        fixations = np.array(Image.open(fix_path).convert('L'))
        fixations = fixations.astype('float')
        fixations = (fixations > 0.5).astype('float')

        assert np.min(gt)>=0.0 and np.max(gt)<=1.0
        # NOTE(review): this assert fails for images with zero fixations —
        # confirm every fixation map contains at least one fixation.
        assert np.min(fixations)==0.0 and np.max(fixations)==1.0
        return img, torch.FloatTensor(gt), torch.FloatTensor(fixations)

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.img_ids)
|
generate_volumes.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import *
|
| 2 |
+
from operator import itemgetter
|
| 3 |
+
from itertools import groupby
|
| 4 |
+
import cv2
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser()
parser.add_argument('--time_slices', default=5, type=int)

def generate_fixation_files(path, time_slices):
    """Build per-time-slice fixation maps and saliency volumes for a split.

    path: split subdirectory, e.g. 'train/' or 'val/' (trailing slash
    expected — it is concatenated directly onto directory paths).
    time_slices: number of temporal bins to spread fixations over.
    Globals FIXATION_PATH, SAL_VOL_PATH, FIX_VOL_PATH, TIMESPAN, H, W and
    helpers parse_fixations/get_saliency_volume/GaussianBlur2D come from
    `utils` (star import) — presumably constants of the SALICON setup.
    """
    print('Parsing fixations of ' + path + '...')
    filenames = [nm.split(".")[0] for nm in os.listdir(FIXATION_PATH + path)]

    def create_dirs(dir_path):
        # Create <dir_path>/<path> (both levels) and return it.
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)

        dir_path = dir_path + '/' + path
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)

        return dir_path

    sal_vol_path = create_dirs(SAL_VOL_PATH + str(time_slices))
    fix_vol_path = create_dirs(FIX_VOL_PATH + str(time_slices))

    # Gaussian blur module used to smooth fixations into saliency maps.
    # NOTE(review): requires a CUDA device.
    conv2D = GaussianBlur2D().cuda()

    print('Generating saliency volumes of ' + path + '...')
    for filename in tqdm(filenames):
        fixation_volume = parse_fixations([filename], FIXATION_PATH + path, progress_bar=False)[0]
        # Flatten to (timestamp, (x, y)) pairs ordered by time.
        fix_timestamps = sorted([fixation for fix_timestamps in fixation_volume
                                 for fixation in fix_timestamps], key=lambda x: x[0])
        # Map each timestamp to its temporal bin index in [0, time_slices).
        fix_timestamps = np.array([(min(int(ts * time_slices / TIMESPAN), time_slices-1), (x, y)) for (ts, (x, y)) in fix_timestamps])

        # Saving fixation map (fixation coords appear to be 1-based — TODO confirm).
        fix_vol = np.zeros(shape=(time_slices,H,W))
        for i, coords in fix_timestamps:
            fix_vol[i, coords[1] - 1, coords[0] - 1] = 1

        # Saving fixation list with timestamps (grouped per temporal bin;
        # groupby relies on fix_timestamps being sorted by bin index).
        compressed = np.array([(key, list(v[1] for v in valuesiter))
                               for key,valuesiter in groupby(fix_timestamps, key=itemgetter(0))])

        saliency_volume = get_saliency_volume(compressed, conv2D, time_slices)
        saliency_volume = saliency_volume.squeeze(0).squeeze(0).detach().cpu().numpy()

        # Write each temporal slice as an 8-bit PNG.
        for i, saliency_slice in enumerate(saliency_volume):
            cv2.imwrite(sal_vol_path + filename + '_' + str(i) + '.png', 255 * saliency_slice)
            cv2.imwrite(fix_vol_path + filename + '_' + str(i) + '.png', 255 * fix_vol[i])

# NOTE(review): runs at import time; consider an `if __name__ == '__main__':` guard.
args = parser.parse_args()
time_slices = args.time_slices
generate_fixation_files('train/', time_slices)
generate_fixation_files('val/', time_slices)
|
inference.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
loss.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
def kldiv(s_map, gt):
	"""Mean KL divergence KL(gt || s_map) over a batch of (B, H, W) maps.

	Each map is first rescaled to sum to one per sample, so inputs need not
	be pre-normalized.
	"""
	batch_size, dim1, dim2 = s_map.size(0), s_map.size(1), s_map.size(2)

	pred_totals = torch.sum(s_map.view(batch_size, -1), 1)
	pred_denom = pred_totals.view(batch_size, 1, 1).expand(batch_size, dim1, dim2)
	assert pred_denom.size() == s_map.size()

	gt_totals = torch.sum(gt.view(batch_size, -1), 1)
	gt_denom = gt_totals.view(batch_size, 1, 1).expand(batch_size, dim1, dim2)
	assert gt_denom.size() == gt.size()

	pred = (s_map / (pred_denom * 1.0)).view(batch_size, -1)
	target = (gt / (gt_denom * 1.0)).view(batch_size, -1)

	eps = 2.2204e-16  # MATLAB eps, kept for parity with the reference code
	pointwise = target * torch.log(eps + target / (pred + eps))
	return torch.mean(torch.sum(pointwise, 1))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def normalize_map(s_map):
	"""Min-max normalize each map in the batch to [0, 1] (MIT benchmark style)."""
	batch_size = s_map.size(0)
	rows, cols = s_map.size(1), s_map.size(2)

	flat = s_map.view(batch_size, -1)
	per_min = torch.min(flat, 1)[0].view(batch_size, 1, 1).expand(batch_size, rows, cols)
	per_max = torch.max(flat, 1)[0].view(batch_size, 1, 1).expand(batch_size, rows, cols)

	return (s_map - per_min) / (per_max - per_min * 1.0)
|
| 43 |
+
|
| 44 |
+
def similarity(s_map, gt):
	"""SIM (histogram intersection) between prediction and ground truth.

	Size of Image - WxH or 1xWxH; gt is the ground-truth saliency map.
	Both maps are min-max normalized, rescaled to probability
	distributions, and the summed elementwise minimum is averaged over the
	batch.
	"""
	batch_size = s_map.size(0)
	rows, cols = s_map.size(1), s_map.size(2)

	s_map = normalize_map(s_map)
	gt = normalize_map(gt)

	pred_totals = torch.sum(s_map.view(batch_size, -1), 1)
	pred_denom = pred_totals.view(batch_size, 1, 1).expand(batch_size, rows, cols)
	assert pred_denom.size() == s_map.size()

	gt_totals = torch.sum(gt.view(batch_size, -1), 1)
	gt_denom = gt_totals.view(batch_size, 1, 1).expand(batch_size, rows, cols)

	pred = (s_map / (pred_denom * 1.0)).view(batch_size, -1)
	target = (gt / (gt_denom * 1.0)).view(batch_size, -1)
	return torch.mean(torch.sum(torch.min(pred, target), 1))
|
| 70 |
+
|
| 71 |
+
def cc(s_map, gt):
	"""Pearson correlation coefficient between the two map batches."""
	batch_size = s_map.size(0)
	rows, cols = s_map.size(1), s_map.size(2)

	def _standardize(t):
		# Zero-mean, unit-variance per sample.
		flat = t.view(batch_size, -1)
		mean = torch.mean(flat, 1).view(batch_size, 1, 1).expand(batch_size, rows, cols)
		std = torch.std(flat, 1).view(batch_size, 1, 1).expand(batch_size, rows, cols)
		return (t - mean) / std

	a = _standardize(s_map)
	b = _standardize(gt)

	ab = torch.sum((a * b).view(batch_size, -1), 1)
	aa = torch.sum((a * a).view(batch_size, -1), 1)
	bb = torch.sum((b * b).view(batch_size, -1), 1)

	return torch.mean(ab / (torch.sqrt(aa * bb)))
|
| 90 |
+
|
| 91 |
+
def nss(s_map, gt):
	"""Normalized Scanpath Saliency: mean standardized saliency at fixations.

	gt is a binary fixation map. On a size mismatch the prediction is
	resized via cv2 and both tensors are moved to the GPU, matching the
	rest of the pipeline.
	"""
	if s_map.size() != gt.size():
		s_map = s_map.cpu().detach().numpy()
		s_map = torch.FloatTensor([cv2.resize(pred_slice, (gt.size(2), gt.size(1))) for pred_slice in s_map])
		s_map = s_map.cuda()
		gt = gt.cuda()

	assert s_map.size()==gt.size()
	batch_size = s_map.size(0)
	rows, cols = s_map.size(1), s_map.size(2)

	flat = s_map.view(batch_size, -1)
	mean = torch.mean(flat, 1).view(batch_size, 1, 1).expand(batch_size, rows, cols)
	std = torch.std(flat, 1).view(batch_size, 1, 1).expand(batch_size, rows, cols)

	eps = 2.2204e-16
	standardized = (s_map - mean) / (std + eps)

	fixation_scores = torch.sum((standardized * gt).view(batch_size, -1), 1)
	fixation_counts = torch.sum(gt.view(batch_size, -1), 1)
	return torch.mean(fixation_scores / fixation_counts)
|
| 111 |
+
|
| 112 |
+
def auc_judd(saliencyMap, fixationMap, jitter=True, normalize=False):
	"""AUC-Judd saliency metric.

	saliencyMap: predicted saliency map (torch tensor, optionally 1xHxW).
	fixationMap: binary human fixation map (same layout).
	jitter: add a tiny random constant to all locations so the ROC can be
	    computed robustly when the map has large uniform regions.
	normalize: min-max normalize the saliency map first.
	Returns the ROC area, or NaN when there are no fixations / the map is
	all-NaN.
	"""
	# Resize the prediction to the fixation map's resolution if needed.
	if saliencyMap.size() != fixationMap.size():
		saliencyMap = saliencyMap.cpu().squeeze(0).numpy()
		saliencyMap = torch.FloatTensor(cv2.resize(saliencyMap, (fixationMap.size(2), fixationMap.size(1)))).unsqueeze(0)
	if len(saliencyMap.size())==3:
		saliencyMap = saliencyMap[0,:,:]
		fixationMap = fixationMap[0,:,:]
	saliencyMap = saliencyMap.numpy()
	fixationMap = fixationMap.numpy()
	if normalize:
		# BUG FIX: the previous code called normalize_map (which expects a
		# batched torch tensor) on a numpy array and raised a TypeError.
		saliencyMap = (saliencyMap - saliencyMap.min()) / (saliencyMap.max() - saliencyMap.min())

	# If there are no fixations to predict, return NaN.
	if not fixationMap.any():
		print('Error: no fixationMap')
		score = float('nan')
		return score

	# Make the saliencyMap the size of fixationMap if a mismatch survived
	# the tensor-level resize above. BUG FIX: scipy.misc.imresize was
	# removed in SciPy >= 1.3; use cv2 (already imported in this module).
	if not np.shape(saliencyMap) == np.shape(fixationMap):
		saliencyMap = cv2.resize(saliencyMap, (np.shape(fixationMap)[1], np.shape(fixationMap)[0]))

	# Jitter saliency maps that come from models with many tied values.
	# If the map is Gaussian-smooth, jitter is unnecessary (and in fact
	# breaks the ordering of the smallest values).
	if jitter:
		saliencyMap = saliencyMap + np.random.random(np.shape(saliencyMap)) / 10 ** 7

	# Normalize the saliency map to [0, 1].
	saliencyMap = (saliencyMap - saliencyMap.min()) \
	              / (saliencyMap.max() - saliencyMap.min())

	if np.isnan(saliencyMap).all():
		print('NaN saliencyMap')
		score = float('nan')
		return score

	S = saliencyMap.flatten()
	F = fixationMap.flatten()

	Sth = S[F > 0]  # saliency values at fixation locations
	Nfixations = len(Sth)
	Npixels = len(S)

	allthreshes = sorted(Sth, reverse=True)  # sweep thresholds high -> low
	tp = np.zeros((Nfixations + 2))
	fp = np.zeros((Nfixations + 2))
	tp[0], tp[-1] = 0, 1
	fp[0], fp[-1] = 0, 1

	for i in range(Nfixations):
		thresh = allthreshes[i]
		aboveth = (S >= thresh).sum()  # saliency values above threshold
		# Ratio of fixated locations above threshold.
		tp[i + 1] = float(i + 1) / Nfixations
		# Ratio of non-fixated locations above threshold.
		fp[i + 1] = float(aboveth - i) / (Npixels - Nfixations)

	score = np.trapz(tp, x=fp)
	return score
|
| 186 |
+
|
| 187 |
+
def auc_shuff(s_map,gt,other_map,splits=100,stepsize=0.1):
	"""Shuffled AUC: false positives are sampled from fixations of OTHER images.

	s_map: predicted saliency map (torch tensor, optionally 1xHxW).
	gt: binary fixation map for this image.
	other_map: binary fixation map aggregated from other images, providing
	    the shuffled negative set.
	splits: number of random permutations of the negative fixations.
	stepsize: unused; kept for signature compatibility.
	Returns the mean AUC over all splits.
	"""
	if len(s_map.size())==3:
		s_map = s_map[0,:,:]
		gt = gt[0,:,:]
		other_map = other_map[0,:,:]

	s_map = s_map.numpy()
	# BUG FIX: the previous code called normalize_map (which expects a
	# batched torch tensor with .size()/.view) on this numpy array, which
	# raised a TypeError at runtime. Min-max normalize with numpy instead.
	s_map = (s_map - s_map.min()) / (s_map.max() - s_map.min())
	gt = gt.numpy()
	other_map = other_map.numpy()

	num_fixations = np.sum(gt)

	# Flatten the other-image fixation coordinates to linear indices.
	x,y = np.where(other_map==1)
	other_map_fixs = []
	for j in zip(x,y):
		other_map_fixs.append(j[0]*other_map.shape[0] + j[1])
	ind = len(other_map_fixs)
	assert ind==np.sum(other_map), 'something is wrong in auc shuffle'

	num_fixations_other = min(ind,num_fixations)

	num_pixels = s_map.shape[0]*s_map.shape[1]
	# Each split is a random permutation of the negative fixation indices.
	random_numbers = []
	for i in range(0,splits):
		temp_list = []
		t1 = np.random.permutation(ind)
		for k in t1:
			temp_list.append(other_map_fixs[k])
		random_numbers.append(temp_list)

	aucs = []
	# For each split, calculate the AUC.
	for i in random_numbers:
		r_sal_map = []
		for k in i:
			r_sal_map.append(s_map[k%s_map.shape[0]-1, int(k/s_map.shape[0])])
		# In these values, we need to find thresholds and calculate the AUC.
		thresholds = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]

		r_sal_map = np.array(r_sal_map)

		thresholds = sorted(set(thresholds))
		area = []
		area.append((0.0,0.0))
		for thresh in thresholds:
			# Keep only pixels with saliency above the threshold.
			temp = np.zeros(s_map.shape)
			temp[s_map>=thresh] = 1.0
			num_overlap = np.where(np.add(temp,gt)==2)[0].shape[0]
			tp = num_overlap/(num_fixations*1.0)

			# False positives: negative-set saliency values above threshold,
			# divided by the number of (positive) fixations.
			fp = len(np.where(r_sal_map>thresh)[0])/(num_fixations*1.0)

			area.append((round(tp,4),round(fp,4)))

		area.append((1.0,1.0))
		area.sort(key = lambda x:x[0])
		tp_list = [x[0] for x in area]
		fp_list = [x[1] for x in area]

		aucs.append(np.trapz(np.array(tp_list),np.array(fp_list)))

	return np.mean(aucs)
|
model.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision.models as models
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
import sys
|
| 6 |
+
from einops import rearrange, repeat
|
| 7 |
+
from einops.layers.torch import Rearrange
|
| 8 |
+
from scipy import ndimage
|
| 9 |
+
|
| 10 |
+
sys.path.append('./PNAS/')
|
| 11 |
+
from PNASnet import *
|
| 12 |
+
from genotypes import PNASNet
|
| 13 |
+
import torch.nn.functional as nnf
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PNASModel(nn.Module):
    """PNASNet-5-Large encoder + U-Net-style convolutional decoder that
    predicts a single saliency map per image.

    forward() returns a (B, H, W) tensor in [0, 1] (final Sigmoid, channel
    dim squeezed).
    """

    def __init__(self, num_channels=3, train_enc=False, load_weight=1):
        """
        Args:
            num_channels: kept for interface compatibility (input is assumed
                3-channel RGB by the PNAS stem — TODO confirm).
            train_enc: if True, encoder parameters stay trainable.
            load_weight: if truthy, load the ImageNet-pretrained PNAS weights.
        """
        super(PNASModel, self).__init__()

        # Fix: `self.path` was read below but never assigned, so
        # load_weight=1 raised AttributeError. Point it at the pretrained
        # PNASNet-5-Large checkpoint used by this repo.
        self.path = './PNAS/PNASNet-5_Large.pth'

        self.pnas = NetworkImageNet(216, 1001, 12, False, PNASNet)
        if load_weight:
            self.pnas.load_state_dict(torch.load(self.path))

        # Freeze (or unfreeze) the whole encoder in one pass.
        for param in self.pnas.parameters():
            param.requires_grad = train_enc

        # Pad right/bottom by one pixel so the stem feature map lines up
        # with the decoder's upsampled resolution.
        self.padding = nn.ConstantPad2d((0, 1, 0, 1), 0)
        self.drop_path_prob = 0

        self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)

        # Decoder stages: conv + ReLU + 2x bilinear upsample, each consuming
        # an encoder skip connection concatenated along channels.
        self.deconv_layer0 = nn.Sequential(
            nn.Conv2d(in_channels=4320, out_channels=512, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=512 + 2160, out_channels=256, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer2 = nn.Sequential(
            nn.Conv2d(in_channels=1080 + 256, out_channels=270, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer3 = nn.Sequential(
            nn.Conv2d(in_channels=540, out_channels=96, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer4 = nn.Sequential(
            nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        # Final head: 1-channel map squashed to [0, 1].
        self.deconv_layer5 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, padding=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, images):
        """Encode `images`, then decode with skip connections to a saliency map."""
        batch_size = images.size(0)

        # --- Encoder: collect intermediate features as skip connections. ---
        s0 = self.pnas.conv0(images)
        s0 = self.pnas.conv0_bn(s0)
        out1 = self.padding(s0)

        s1 = self.pnas.stem1(s0, s0, self.drop_path_prob)
        out2 = s1
        s0, s1 = s1, self.pnas.stem2(s0, s1, 0)

        # Tap cells 3 / 7 / 11 (end of each reduction stage).
        for i, cell in enumerate(self.pnas.cells):
            s0, s1 = s1, cell(s0, s1, 0)
            if i == 3:
                out3 = s1
            if i == 7:
                out4 = s1
            if i == 11:
                out5 = s1

        # --- Decoder: upsample, fusing one skip connection per stage. ---
        out5 = self.deconv_layer0(out5)

        x = torch.cat((out5, out4), 1)
        x = self.deconv_layer1(x)

        x = torch.cat((x, out3), 1)
        x = self.deconv_layer2(x)

        x = torch.cat((x, out2), 1)
        x = self.deconv_layer3(x)
        x = torch.cat((x, out1), 1)

        x = self.deconv_layer4(x)

        x = self.deconv_layer5(x)
        x = x.squeeze(1)  # (B, 1, H, W) -> (B, H, W)
        # print("PNAS pred actual pnas:", x.mean(),x.min(), x.max(), x.sum())

        return x
|
| 106 |
+
|
| 107 |
+
class PNASVolModellast(nn.Module):
    """PNASNet encoder + decoder predicting a temporal saliency *volume*
    (`time_slices` maps per image) instead of a single map.

    forward() returns (volume, [out1..out5]); the outs are the encoder skip
    features, reused downstream by PNASBoostedModelMultiLevel.
    """

    def __init__(self, time_slices, num_channels=3, train_enc=False, load_weight=1):
        """
        Args:
            time_slices: number of temporal slices (output channels of the head).
            num_channels: kept for interface compatibility.
            train_enc: if True, encoder parameters stay trainable.
            load_weight: if truthy, load pretrained PNAS weights from self.path.
        """
        super(PNASVolModellast, self).__init__()

        # Fix: `self.path` was read below but never assigned, so
        # load_weight=1 raised AttributeError.
        self.path = './PNAS/PNASNet-5_Large.pth'

        self.pnas = NetworkImageNet(216, 1001, 12, False, PNASNet)
        if load_weight:
            state_dict = torch.load(self.path)
            # Remap checkpoint keys so weights saved from wrapped models
            # ('module.'/'pnas.' prefixes) load into the bare encoder.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if 'module' in k:
                    k = 'module.pnas.' + k
                else:
                    k = k.replace('pnas.', '')
                new_state_dict[k] = v
            self.pnas.load_state_dict(new_state_dict, strict=False)

        for param in self.pnas.parameters():
            param.requires_grad = train_enc

        self.padding = nn.ConstantPad2d((0, 1, 0, 1), 0)
        self.drop_path_prob = 0

        self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)

        # Decoder: same skip-connected layout as PNASModel.
        self.deconv_layer0 = nn.Sequential(
            nn.Conv2d(in_channels=4320, out_channels=512, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=512 + 2160, out_channels=256, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer2 = nn.Sequential(
            nn.Conv2d(in_channels=1080 + 256, out_channels=270, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer3 = nn.Sequential(
            nn.Conv2d(in_channels=540, out_channels=96, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer4 = nn.Sequential(
            nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        # Head: one output channel per temporal slice.
        self.deconv_layer5 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=time_slices, kernel_size=3, padding=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, images):
        """Return (saliency volume, list of encoder skip features)."""
        # --- Encoder with skip taps (same scheme as PNASModel). ---
        s0 = self.pnas.conv0(images)
        s0 = self.pnas.conv0_bn(s0)
        out1 = self.padding(s0)

        s1 = self.pnas.stem1(s0, s0, self.drop_path_prob)
        out2 = s1
        s0, s1 = s1, self.pnas.stem2(s0, s1, 0)

        for i, cell in enumerate(self.pnas.cells):
            s0, s1 = s1, cell(s0, s1, 0)
            if i == 3:
                out3 = s1
            if i == 7:
                out4 = s1
            if i == 11:
                out5 = s1

        # --- Decoder. ---
        out5 = self.deconv_layer0(out5)

        x = torch.cat((out5, out4), 1)
        x = self.deconv_layer1(x)

        x = torch.cat((x, out3), 1)
        x = self.deconv_layer2(x)

        x = torch.cat((x, out2), 1)
        x = self.deconv_layer3(x)
        x = torch.cat((x, out1), 1)

        x = self.deconv_layer4(x)

        x = self.deconv_layer5(x)
        # Normalize the whole batch volume so its peak equals 1
        # (Sigmoid output is strictly positive, so max() > 0).
        x = x / x.max()

        return x, [out1, out2, out3, out4, out5]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class PNASBoostedModelMultiLevel(nn.Module):
    # Fuses a frozen single-map saliency model (pnas_sal) and a frozen
    # temporal-volume model (pnas_vol): their 1 + 5 = 6 prediction maps are
    # injected at every decoder scale and a small trainable decoder mixes them.
    #
    # NOTE(review): `model_vol_path` is accepted but never used — both
    # sub-models are restored from the combined checkpoint at `model_path`.
    # Confirm whether the parameter can be dropped or should be honored.

    def __init__(self, device, model_path, model_vol_path, time_slices, train_model=False, selected_slices=""):
        super(PNASBoostedModelMultiLevel, self).__init__()

        # Stored but not read inside this class.
        self.selected_slices = selected_slices

        self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)

        # Each decoder stage consumes the previous features + an encoder skip
        # + 6 resized guidance maps (1 saliency + 5 volume slices).
        self.deconv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels = 512+2160+6, out_channels = 256, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer2 = nn.Sequential(
            nn.Conv2d(in_channels = 1080+256+6, out_channels = 270, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer3 = nn.Sequential(
            nn.Conv2d(in_channels = 540+6, out_channels = 96, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer4 = nn.Sequential(
            nn.Conv2d(in_channels = 192+6, out_channels = 128, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )

        # Final fusion head: features + the 6 full-resolution guidance maps.
        self.deconv_mix = nn.Sequential(
            nn.Conv2d(in_channels = 128+6 , out_channels = 16, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels = 32, out_channels = 1, kernel_size = 3, padding = 1, bias = True),
            nn.Sigmoid()
        )
        # Restore the frozen volume sub-model from the combined checkpoint.
        model_vol = PNASVolModellast(time_slices=5, load_weight=0) #change this to time slices
        model_vol = nn.DataParallel(model_vol).cuda()
        state_dict = torch.load(model_path)
        vol_state_dict = OrderedDict()
        sal_state_dict = OrderedDict()
        smm_state_dict = OrderedDict()

        # Split the combined checkpoint into volume / saliency / fusion parts
        # by key prefix; prefixes are stripped so the bare modules can load them.
        for k, v in state_dict.items():
            if 'pnas_vol' in k:

                k = k.replace('pnas_vol.module.', '')
                vol_state_dict[k] = v
            elif 'pnas_sal' in k:
                k = k.replace('pnas_sal.module.', '')
                sal_state_dict[k] = v
            else:
                smm_state_dict[k] = v

        self.load_state_dict(smm_state_dict)
        model_vol.load_state_dict(vol_state_dict)
        # NOTE(review): model_vol is already DataParallel; wrapping it again
        # nests DataParallel modules — presumably kept so saved checkpoint
        # keys line up. Verify before refactoring.
        self.pnas_vol = nn.DataParallel(model_vol).cuda()

        for param in self.pnas_vol.parameters():
            param.requires_grad = False

        # Restore the frozen single-map saliency sub-model.
        model = PNASModel(load_weight=0)
        model = nn.DataParallel(model).cuda()

        model.load_state_dict(sal_state_dict, strict=True)
        self.pnas_sal = nn.DataParallel(model).to(device)

        for param in self.pnas_sal.parameters():
            param.requires_grad = False #train_model

    def forward(self, images):
        # print("IMAGES", images.shape)

        # Frozen sub-model predictions: (B, 1, H, W) and (B, 5, H, W) —
        # shapes assumed from the sub-model heads, TODO confirm.
        pnas_pred = self.pnas_sal(images).unsqueeze(1)
        pnas_vol_pred , outs = self.pnas_vol(images)

        out1 , out2, out3, out4, out5 = outs
        #print(pnas_vol_pred.shape)
        # 6 guidance maps, re-injected (bicubically resized) at every scale.
        x_maps = torch.cat((pnas_pred, pnas_vol_pred), 1)

        x = torch.cat((out5,out4), 1)
        x_maps16 = nnf.interpolate(x_maps, size=(16, 16), mode='bicubic', align_corners=False)

        x = torch.cat((x,x_maps16), 1)

        x = self.deconv_layer1(x)
        x = torch.cat((x,out3), 1)
        x_maps32 = nnf.interpolate(x_maps, size=(32, 32), mode='bicubic', align_corners=False)
        x = torch.cat((x,x_maps32), 1)

        x = self.deconv_layer2(x)
        x = torch.cat((x,out2), 1)
        x_maps64 = nnf.interpolate(x_maps, size=(64, 64), mode='bicubic', align_corners=False)
        x = torch.cat((x,x_maps64), 1)

        x = self.deconv_layer3(x)
        x = torch.cat((x,out1), 1)
        x_maps128 = nnf.interpolate(x_maps, size=(128, 128), mode='bicubic', align_corners=False)

        x = torch.cat((x,x_maps128), 1)

        x = self.deconv_layer4(x)
        x = torch.cat((x,x_maps), 1)

        x = self.deconv_mix(x)

        x = x.squeeze(1)

        return x, pnas_vol_pred
|
| 322 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb
|
| 2 |
+
pycocotools
|
| 3 |
+
torch==1.8.0+cu111
|
| 4 |
+
torchvision==0.9.0+cu111
|
| 5 |
+
torchaudio==0.8.0
|
| 6 |
+
# libgl1-mesa-glx is a system (apt) package, not a pip package — install it
# with `apt-get install libgl1-mesa-glx` (needed by OpenCV). Listing it here
# made `pip install -r requirements.txt` fail.
|
| 7 |
+
ftfy
|
| 8 |
+
regex
|
| 9 |
+
tqdm
|
| 10 |
+
ipywidgets
|
| 11 |
+
seaborn
|
| 12 |
+
einops
|
| 13 |
+
clip-anytorch
|
| 14 |
+
# pycocotools was listed twice; it already appears near the top of this file.
|
| 15 |
+
kornia==0.5.10
|
testing/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
testing/gt/COCO_val2014_000000000192.png
ADDED
|
testing/gt/COCO_val2014_000000000192_0.png
ADDED
|
testing/gt/COCO_val2014_000000000192_1.png
ADDED
|
testing/gt/COCO_val2014_000000000192_2.png
ADDED
|
testing/gt/COCO_val2014_000000000192_3.png
ADDED
|
testing/gt/COCO_val2014_000000000192_4.png
ADDED
|
testing/gt/COCO_val2014_000000000208.png
ADDED
|
testing/gt/COCO_val2014_000000000208_0.png
ADDED
|
testing/gt/COCO_val2014_000000000208_1.png
ADDED
|
testing/gt/COCO_val2014_000000000208_2.png
ADDED
|
testing/gt/COCO_val2014_000000000208_3.png
ADDED
|
testing/gt/COCO_val2014_000000000208_4.png
ADDED
|
testing/images/COCO_val2014_000000000192.jpg
ADDED
|
testing/images/COCO_val2014_000000000208.jpg
ADDED
|
testing/predictions/Readme.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Your predictions will appear in this folder after running the notebook.
|
train.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
import wandb
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from dataloader import SaliconDataset
|
| 10 |
+
from loss import *
|
| 11 |
+
from utils import AverageMeter
|
| 12 |
+
from utils import img_save
|
| 13 |
+
from torchvision import utils
|
| 14 |
+
import torch.nn.functional as nnf
|
| 15 |
+
from os.path import join
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Command-line configuration for training.
# NOTE(review): `type=bool` on argparse flags is misleading — argparse calls
# bool() on the raw string, so `--kldiv False` evaluates to True (any
# non-empty string is truthy). The defaults below work, but overriding a
# bool flag from the CLI does not behave as expected; confirm intent.
parser = argparse.ArgumentParser()
# Training schedule / loss selection.
parser.add_argument('--no_epochs',default=30, type=int)
parser.add_argument('--lr',default=1e-5, type=float)
parser.add_argument('--kldiv',default=True, type=bool)
parser.add_argument('--cc',default=True, type=bool)
parser.add_argument('--nss',default=False, type=bool)
parser.add_argument('--sim',default=False, type=bool)
parser.add_argument('--nss_emlnet',default=False, type=bool)
parser.add_argument('--nss_norm',default=False, type=bool)
parser.add_argument('--l1',default=False, type=bool)
parser.add_argument('--lr_sched',default=False, type=bool)
parser.add_argument('--dilation',default=False, type=bool)
parser.add_argument('--enc_model',default="pnas", type=str)
parser.add_argument('--optim',default="Adam", type=str)

# Loss coefficients (negative for metrics that are maximized, e.g. CC/SIM/NSS).
parser.add_argument('--load_weight',default=1, type=int)
parser.add_argument('--kldiv_coeff',default=1.0, type=float)
parser.add_argument('--step_size',default=5, type=int)
parser.add_argument('--cc_coeff',default=-1.0, type=float)
parser.add_argument('--sim_coeff',default=-1.0, type=float)
parser.add_argument('--nss_coeff',default=-1.0, type=float)
parser.add_argument('--nss_emlnet_coeff',default=1.0, type=float)
parser.add_argument('--nss_norm_coeff',default=1.0, type=float)
parser.add_argument('--l1_coeff',default=1.0, type=float)
parser.add_argument('--train_enc',default=1, type=int)

# Data / runtime settings.
parser.add_argument('--dataset_dir',default="../data/", type=str)
parser.add_argument('--batch_size',default=32, type=int)
parser.add_argument('--log_interval',default=60, type=int)
parser.add_argument('--no_workers',default=4, type=int)
parser.add_argument('--train_model',default=False, type=bool)
parser.add_argument('--time_slices',default=5, type=int)
parser.add_argument('--selected_slices',default="", type=str)
parser.add_argument('--results_dir',default="", type=str )

# Path to save the model weights
parser.add_argument('--model_val_path',default="model.pt", type=str)
# If the model type is pnas_boosted, specify the path of the pre-trained pnas model here
parser.add_argument('--model_path',default="", type=str)
# If the model type is pnas_boosted, specify the path of the pre-trained pnasvol model here
parser.add_argument('--model_vol_path',default="", type=str)

args = parser.parse_args()
|
| 63 |
+
|
| 64 |
+
# Dataset layout: images/, maps/ (ground-truth saliency), fixation_maps/,
# each with train/ and val/ subfolders under --dataset_dir.
train_img_dir = args.dataset_dir + "images/train/"
train_gt_dir = args.dataset_dir + "maps/train/"
train_fix_dir = args.dataset_dir + "fixation_maps/train/"

val_img_dir = args.dataset_dir + "images/val/"
val_gt_dir = args.dataset_dir + "maps/val/"
val_fix_dir = args.dataset_dir + "fixation_maps/val/"

# Use the GPU when available; the rest of the script moves tensors to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 73 |
+
|
| 74 |
+
# Instantiate the requested architecture.
if args.enc_model == "pnas":
    print("PNAS Model")
    from model import PNASModel
    model = PNASModel(train_enc=bool(args.train_enc), load_weight=args.load_weight)

elif args.enc_model == "pnas_boosted_multi":
    print("PNAS Boosted Model PNASBoostedModelMultilevel")
    # Fix: model.py defines the class with a capital 'L'
    # (PNASBoostedModelMultiLevel); the previous import used 'Multilevel'
    # and raised ImportError on this branch.
    from model import PNASBoostedModelMultiLevel
    model = PNASBoostedModelMultiLevel(device, args.model_path, args.model_vol_path,
                                       args.time_slices, train_model=args.train_model,
                                       selected_slices=args.selected_slices)


# Wrap in DataParallel only when several GPUs are present.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
model.to(device)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Image ids are filenames without extension; ground-truth maps and fixation
# maps are looked up by the same id inside SaliconDataset.
train_img_ids = [nm.split(".")[0] for nm in os.listdir(train_img_dir)]
val_img_ids = [nm.split(".")[0] for nm in os.listdir(val_img_dir)]

train_dataset = SaliconDataset(train_img_dir, train_gt_dir, train_fix_dir, train_img_ids)
val_dataset = SaliconDataset(val_img_dir, val_gt_dir, val_fix_dir, val_img_ids)

# Shuffle only the training split; validation order is fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.no_workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.no_workers)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def loss_func(pred_map, gt, fixations, args):
    """Weighted sum of the saliency losses enabled via CLI flags.

    Args:
        pred_map: predicted saliency map tensor.
        gt: ground-truth saliency map (same shape as pred_map).
        fixations: binary fixation map, used only when args.nss is set.
        args: parsed CLI namespace providing the loss flags and coefficients.

    Returns:
        A 1-element tensor on pred_map's device holding the combined loss.
    """
    # Fix: accumulate on the prediction's device instead of hard-coding
    # `.cuda()`, which crashed on CPU-only machines.
    loss = torch.zeros(1, device=pred_map.device, dtype=torch.float32)
    criterion = nn.L1Loss()
    if args.kldiv:
        loss += args.kldiv_coeff * kldiv(pred_map, gt)
    if args.cc:
        # cc/sim/nss are similarity metrics; negative coefficients turn
        # them into losses (see the CLI defaults).
        loss += args.cc_coeff * cc(pred_map, gt)
    if args.nss:
        loss += args.nss_coeff * nss(pred_map, fixations)
    if args.l1:
        loss += args.l1_coeff * criterion(pred_map, gt)
    if args.sim:
        loss += args.sim_coeff * similarity(pred_map, gt)
    #print("Loss: ", loss)
    return loss
|
| 117 |
+
|
| 118 |
+
def train(model, optimizer, loader, epoch, device, args):
    """Run one training epoch; return the mean loss over all batches.

    Logs a windowed average to stdout and wandb every args.log_interval batches.
    """
    model.train()

    start = time.time()

    running_total = 0.0   # loss summed over the whole epoch
    running_window = 0.0  # loss summed since the last log line

    for batch_idx, (img, gt, fixations) in enumerate(loader):
        img, gt, fixations = img.to(device), gt.to(device), fixations.to(device)

        optimizer.zero_grad()
        pred_map, vol_pred = model(img)

        assert pred_map.size() == gt.size()

        loss = loss_func(pred_map, gt, fixations, args)
        loss.backward()

        batch_loss = loss.item()
        running_total += batch_loss
        running_window += batch_loss

        optimizer.step()
        if batch_idx % args.log_interval == args.log_interval - 1:
            elapsed_min = (time.time() - start) / 60
            print('[{:2d}, {:5d}] avg_loss : {:.5f}, time:{:3f} minutes'.format(
                epoch, batch_idx, running_window / args.log_interval, elapsed_min))
            wandb.log({"loss": running_window / args.log_interval})
            running_window = 0.0
            sys.stdout.flush()

    print('[{:2d}, train] avg_loss : {:.5f}'.format(epoch, running_total / len(loader)))
    sys.stdout.flush()

    return running_total / len(loader)
|
| 153 |
+
|
| 154 |
+
def validate(model, loader, epoch, device, args):
    """Evaluate on the validation set.

    Returns (cc_avg, cc_meter, kldiv_meter, nss_meter, sim_meter).
    """
    model.eval()
    start = time.time()

    cc_meter = AverageMeter()
    kld_meter = AverageMeter()
    nss_meter = AverageMeter()
    sim_meter = AverageMeter()

    for img, gt, fixations in tqdm(loader):
        img, gt, fixations = img.to(device), gt.to(device), fixations.to(device)

        pred_map, vol_pred = model(img)

        cc_meter.update(cc(pred_map, gt))
        kld_meter.update(kldiv(pred_map, gt))
        nss_meter.update(nss(pred_map, fixations))
        sim_meter.update(similarity(pred_map, gt))

    elapsed_min = (time.time() - start) / 60
    print('[{:2d}, val] CC : {:.5f}, KLDIV : {:.5f}, NSS : {:.5f}, SIM : {:.5f} time:{:3f} minutes'.format(
        epoch, cc_meter.avg, kld_meter.avg, nss_meter.avg, sim_meter.avg, elapsed_min))
    wandb.log({"CC": cc_meter.avg, 'KLDIV': kld_meter.avg, 'NSS': nss_meter.avg, 'SIM': sim_meter.avg})
    sys.stdout.flush()

    return cc_meter.avg, cc_meter, kld_meter, nss_meter, sim_meter
|
| 179 |
+
|
| 180 |
+
# Optimize only parameters left trainable (frozen encoders are excluded).
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# Optimizer choice from --optim; an unknown value leaves `optimizer`
# undefined and the training loop will raise NameError.
if args.optim=="Adam":
    optimizer = torch.optim.Adam(params, lr=args.lr)
if args.optim=="Adagrad":
    optimizer = torch.optim.Adagrad(params, lr=args.lr)
if args.optim=="SGD":
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
# Optional step decay: lr *= 0.1 every --step_size epochs.
if args.lr_sched:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
|
| 190 |
+
|
| 191 |
+
print(device)
# Despite the name, `best_loss` tracks a score to MAXIMIZE:
# CC minus KLD (higher CC, lower KLD => larger value).
best_loss = 0
for epoch in range(0, args.no_epochs):
    loss = train(model, optimizer, train_loader, epoch, device, args)

    with torch.no_grad():
        cc_loss,cc_loss_obj,kldiv_loss,nss_loss,sim_loss = validate(model, val_loader, epoch, device, args)
        # Selection metric: CC average minus KLD average.
        cc_loss -=kldiv_loss.avg
        # Seed the best score with the first epoch's value.
        if epoch == 0 :
            best_loss = cc_loss
        # `<=` means the epoch-0 model is always saved at least once.
        if best_loss <= cc_loss:
            best_loss = cc_loss
            print('[{:2d}, save, {}]'.format(epoch, args.model_val_path))
            wandb.log({"Best/CC mean": cc_loss,"Best/CC median": cc_loss_obj.get_median(), "Best/CC std": cc_loss_obj.get_std(),
                "Best/KLD mean": kldiv_loss.avg,"Best/KLD median": kldiv_loss.get_median(), "Best/KLD std": kldiv_loss.get_std(),
                "Best/NSS mean": nss_loss.avg,"Best/NSS median": nss_loss.get_median(), "Best/NSS std": nss_loss.get_std(),
                "Best/SIM mean": sim_loss.avg,"Best/SIM median": sim_loss.get_median(), "Best/SIM std": sim_loss.get_std()})
            # Unwrap DataParallel so the checkpoint has un-prefixed keys.
            if torch.cuda.device_count() > 1:
                torch.save(model.module.state_dict(), args.model_val_path)
            else:
                torch.save(model.state_dict(), args.model_val_path)
    print()

    if args.lr_sched:
        scheduler.step()
|
utils.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fnmatch
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from scipy.spatial import distance
|
| 8 |
+
from math import pi, sqrt, exp
|
| 9 |
+
from os.path import join
|
| 10 |
+
from torchvision import utils
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import scipy.io as sio
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import matplotlib.animation as animation
|
| 19 |
+
|
| 20 |
+
# SALICON image dimensions in pixels.
W = 640
H = 480
# Viewing duration per image — presumably milliseconds (5 s), TODO confirm.
TIMESPAN = 5000
# Distance/weight constants for fixation-to-timestamp matching — units are
# pixels and a pixels-per-ms trade-off respectively; verify against usage.
MAX_PIXEL_DISTANCE = 800
ESTIMATED_TIMESTAMP_WEIGHT = 0.006
RATIO = 0.9

# Dataset locations; the *_VOL_ prefixes are completed with a slice count.
FIXATION_PATH = '../data/fixations/'
FIX_VOL_PATH = '../data/fixation_volumes_'
SAL_VOL_PATH = '../data/saliency_volumes_'
|
| 30 |
+
|
| 31 |
+
class bcolors:
    """ANSI escape sequences for colored/styled terminal output."""
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'      # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
|
| 41 |
+
|
| 42 |
+
def get_colored_value(value, ref_value, increasing=True):
    """Format `value` to 5 decimals, green when it beats `ref_value`.

    `increasing=True` means larger is better; False flips the comparison.
    """
    sign = 1 if increasing else -1
    is_worse = sign * ref_value > sign * value
    color = bcolors.FAIL if is_worse else bcolors.OKGREEN
    return color + '{:.5f}'.format(value) + bcolors.ENDC
|
| 45 |
+
|
| 46 |
+
def get_filenames(path):
    """Return the sorted entries of `path` whose names start with 'COCO_'."""
    entries = sorted(os.listdir(path))
    return [name for name in entries if fnmatch.fnmatch(name, 'COCO_*')]
|
| 48 |
+
|
| 49 |
+
def parse_fixations(filenames,
                    path_prefix,
                    etw=ESTIMATED_TIMESTAMP_WEIGHT, progress_bar=True):
    '''Load gaze recordings from .mat files and estimate a timestamp for
    every fixation.

    Parameters
    ----------
    filenames : iterable of str
        Base names (without the '.mat' extension) of the files to parse.
    path_prefix : str
        Directory/prefix prepended to every filename.
    etw : float
        Weight of the time difference relative to the pixel distance when
        matching a fixation to a raw gaze sample.
    progress_bar : bool
        Wrap the iteration in tqdm when True.

    Returns
    -------
    list
        One entry per file; each entry is a list (one per observer with at
        least one fixation) of [timestamp, [x, y]] pairs.
    '''
    fixation_volumes = []
    filenames = tqdm(filenames) if progress_bar else filenames

    for filename in filenames:
        # 1. Extracting data from .mat files
        mat = sio.loadmat(path_prefix + filename + '.mat')
        gaze = mat["gaze"]

        locations = []
        timestamps = []
        fixations = []

        # NOTE(review): assumes the MATLAB struct layout
        # gaze[i][0] == (locations, timestamps, fixations) per observer —
        # confirm against the dataset's .mat schema.
        for i in range(len(gaze)):
            locations.append(mat["gaze"][i][0][0])
            timestamps.append(mat["gaze"][i][0][1])
            fixations.append(mat["gaze"][i][0][2])

        # 2. Matching fixations with timestamps
        fixation_volume = []
        for i, observer in enumerate(fixations):
            fix_timestamps = []
            # Spread fixations evenly across TIMESPAN as the first guess
            # for each fixation's timestamp.
            fix_time = TIMESPAN / (len(observer) + 1)
            est_timestamp = fix_time

            for fixation in observer:
                # Pick the raw gaze sample minimizing a weighted mix of
                # pixel distance and time difference to the estimate.
                distances = distance.cdist([fixation], locations[i], 'euclidean')[0][..., np.newaxis]
                time_diffs = abs(timestamps[i] - est_timestamp)
                min_idx = (etw * time_diffs + distances).argmin()

                # Clamp the matched timestamp to the viewing duration.
                fix_timestamps.append([min(timestamps[i][min_idx][0], TIMESPAN), fixation.tolist()])
                est_timestamp += fix_time

            if (len(observer) > 0):
                fixation_volume.append(fix_timestamps)

        fixation_volumes.append(fixation_volume)

    return fixation_volumes
|
| 90 |
+
|
| 91 |
+
def get_saliency_volume(fixation_volume, conv2D, time_slices):
    '''Rasterize one image's fixations into a (time_slices, H, W) binary
    volume and blur it into a continuous saliency volume.

    Parameters
    ----------
    fixation_volume : iterable of (ts, coords)
        `ts` is used directly as an index into the first (time) dimension,
        so it must already be rescaled to [0, time_slices) — TODO confirm
        against the caller. `coords` holds 1-based (x, y) positions.
    conv2D : nn.Module
        Blur module applied to the stacked fixation maps.
    time_slices : int
        Temporal resolution of the volume.

    Returns
    -------
    torch.Tensor
        The blurred volume, peak-normalized to a maximum of 1.
        Requires CUDA (the map is allocated with torch.cuda.FloatTensor).
    '''
    fixation_map = torch.cuda.FloatTensor(time_slices,H,W).fill_(0)

    for ts, coords in fixation_volume:
        for (x, y) in coords:
            # Fixation coordinates are 1-based; the tensor is 0-based.
            fixation_map[ts,y-1,x-1] = 1

    saliency_volume = conv2D.forward(fixation_map)
    # Normalize so the volume's maximum value is exactly 1.
    return saliency_volume / saliency_volume.max()
|
| 100 |
+
|
| 101 |
+
def blur(img):
    '''Apply an 11x11 Gaussian blur to a numpy image and return the
    result as a torch.FloatTensor.'''
    kernel_size = (11, 11)
    blurred = cv2.GaussianBlur(img, kernel_size, 0)
    return torch.FloatTensor(blurred)
|
| 105 |
+
|
| 106 |
+
def visualize_model(model, loader, device, args):
    '''Run `model` over `loader` and save blurred prediction maps.

    Parameters
    ----------
    model : nn.Module
        Saliency model; may return either a map, or a tuple whose second
        element is the map.
    loader : iterable
        Yields (img, img_id, sz) items; sz is the original (width, height)
        of the source image.
    device : torch.device
        Device the input images are moved to before inference.
    args : namespace
        Must provide `results_dir`, where predictions are written.
    '''
    with torch.no_grad():
        model.eval()
        os.makedirs(args.results_dir, exist_ok=True)

        for (img, img_id, sz) in tqdm(loader):
            img = img.to(device)

            pred_map = model(img)
            # Some model variants return (aux, map); keep only the map.
            if isinstance(pred_map, tuple):
                pred_map = pred_map[1]
            pred_map = pred_map.cpu().squeeze(0).numpy()
            # Resize back to the source resolution (cv2 takes (w, h)).
            pred_map = cv2.resize(pred_map, (sz[0], sz[1]))

            # blur() already returns a FloatTensor; the original wrapped
            # it in torch.FloatTensor(...) a second time for no effect.
            pred_map = blur(pred_map)
            img_save(pred_map, join(args.results_dir, img_id[0]), normalize=True)
|
| 122 |
+
|
| 123 |
+
def img_save(tensor, fp, nrow=8, padding=2,
             normalize=False, range=None, scale_each=False, pad_value=0, format=None):
    '''Save a tensor image to `fp` as a single-channel PNG.

    The signature mirrors torchvision.utils.save_image; `range` and
    `format` intentionally shadow builtins to stay call-compatible with it.

    Bug fix: the previous version clobbered `fp` and always wrote the
    output to a hard-coded "1.png" (and printed a debug string).
    '''
    grid = utils.make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                           normalize=normalize, range=range, scale_each=scale_each)

    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer.
    ndarr = torch.round(grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)).to('cpu', torch.uint8).numpy()
    # Saliency maps are single-channel; keep only the first channel.
    im = Image.fromarray(ndarr[:,:,0])
    # Always write a PNG, regardless of the extension the caller passed.
    fp = os.path.splitext(fp)[0] + '.png'
    im.save(fp, format=format, compress_level=0)
|
| 136 |
+
|
| 137 |
+
class AverageMeter(object):

    '''Computes and stores the current value, a weighted running average,
    and the full history of updates (for std/median statistics).'''

    def __init__(self):
        self.reset()

    def reset(self):
        '''Clear all accumulated statistics.'''
        self.past = np.array([])   # history of every value passed to update()
        self.val = 0               # most recent value
        self.avg = 0               # weighted running average
        self.sum = 0               # weighted sum of values
        self.count = 0             # total weight seen so far

    def update(self, val, n = 1):
        '''Record a new measurement.

        Parameters
        ----------
        val : float or torch.Tensor
            The new value. Tensors are moved to CPU before being appended
            to the history.
        n : int
            Weight of this measurement (e.g. the batch size).
        '''
        self.val = val
        self.sum += val*n
        self.count += n
        self.avg = self.sum / self.count
        # Bug fix: the original called val.cpu() unconditionally, which
        # raised AttributeError for plain Python numbers.
        self.past = np.append(self.past, val.cpu() if hasattr(val, 'cpu') else val)

    def get_std (self):
        '''Standard deviation over all recorded values.'''
        return np.std(self.past)

    def get_median (self):
        '''Median over all recorded values.'''
        return np.median(self.past)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def im2heat(pred_dir, a, gt, exten='.png'):
    '''Overlay a predicted saliency map (read from disk) on a background
    image as a JET heatmap, alpha-blended by the prediction intensity.

    Parameters
    ----------
    pred_dir : str
        Directory holding the prediction images.
    a : str
        Base name of the prediction file (without extension).
    gt : ndarray
        Background image to blend with; presumably the same size as the
        prediction — TODO confirm at call sites.
    exten : str
        File extension of the prediction image.

    Returns
    -------
    np.uint8 ndarray
        The blended heatmap image.
    '''
    pred_nm = pred_dir + a + exten
    pred = cv2.imread(pred_nm, 0)  # grayscale prediction in [0, 255]
    heatmap_img = cv2.applyColorMap(pred, cv2.COLORMAP_JET)
    # NOTE(review): `convert` is not defined or imported anywhere in this
    # file's visible code — this line raises NameError unless `convert` is
    # provided elsewhere at runtime; confirm.
    heatmap_img = convert(heatmap_img)
    pred = np.stack((pred, pred, pred),2).astype('float32')
    pred = pred / 255.0

    # Per-pixel alpha blend: heatmap where the prediction is strong,
    # background where it is weak.
    return np.uint8(pred * heatmap_img + (1.0-pred) * gt)
|
| 174 |
+
|
| 175 |
+
def get_heat_image(image):
    '''Map a [0, 1] intensity image to an RGB "HOT" colormap image.'''
    scaled = np.uint8(255 * image)
    bgr = cv2.applyColorMap(scaled, cv2.COLORMAP_HOT)
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
| 177 |
+
|
| 178 |
+
def format_image(heatmap, image, max_value):
    '''Return (heatmap image, saturation/brightness-modulated image).

    The heatmap is normalized by `max_value`; the second output keeps the
    original image where the heatmap is strong and desaturates/darkens it
    where the heatmap is weak (brightness floor controlled by RATIO).
    '''
    norm_map = heatmap / max_value
    hsv_img = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    alpha = np.clip(2 * norm_map, 0, 1)
    hsv_img[:, :, 1] = np.uint8(alpha * hsv_img[:, :, 1])
    hsv_img[:, :, 2] = np.uint8((RATIO * alpha + (1 - RATIO)) * hsv_img[:, :, 2])

    heat = get_heat_image(norm_map[:, :, np.newaxis])
    return heat, cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB)
|
| 186 |
+
|
| 187 |
+
def animate(gt_vol, pred_vol, image):
    '''Build a matplotlib animation comparing ground-truth and predicted
    saliency volumes frame by frame.

    Each frame stacks three rows: ground truth (heatmap | overlay),
    prediction (heatmap | overlay), and a signed difference map beside the
    original image.
    '''
    fig = plt.figure(figsize=(16, 16))

    gt_peak = np.max(gt_vol)
    pred_peak = np.max(pred_vol)
    frames = []

    for gt_map, pred_map in zip(gt_vol, pred_vol):
        gt_heat, gt_overlay = format_image(gt_map, image, gt_peak)
        gt_row = np.concatenate((gt_heat, gt_overlay), 1)

        pred_heat, pred_overlay = format_image(pred_map, image, pred_peak)
        pred_row = np.concatenate((pred_heat, pred_overlay), 1)

        # Signed difference remapped from [-1, 1] to [0, 1] for the
        # diverging TWILIGHT colormap.
        diff = 0.5 + ((gt_map / gt_peak) - (pred_map / pred_peak)) / 2
        diff_img = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * diff[:, :, np.newaxis]), cv2.COLORMAP_TWILIGHT), cv2.COLOR_BGR2RGB)
        diff_row = np.concatenate((diff_img, image), 1)

        frames.append([plt.imshow(np.concatenate((gt_row, pred_row, diff_row), 0), animated=True)])

    return animation.ArtistAnimation(fig, frames, interval=500, blit=True, repeat_delay=1000)
|
| 207 |
+
|
| 208 |
+
def animate_single_heatmap(gt_vol, image):
    '''Build a borderless matplotlib animation of a single saliency
    volume's heatmaps.'''
    fig = plt.figure(figsize=(6.4, 4.8), frameon=False)
    # Full-figure axes with all decorations removed.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)

    peak = np.max(gt_vol)
    frames = []
    plt.axis('off')

    for gt_map in gt_vol:
        heat, _ = format_image(gt_map, image, peak)
        frames.append([ax.imshow(heat, animated=True)])

    return animation.ArtistAnimation(fig, frames, interval=1000, blit=True, repeat_delay=1000)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def gauss(n, sigma):
    '''Return a discrete 1-D Gaussian kernel as a list of floats.

    The taps cover the symmetric range [-n//2, n//2], giving n samples
    for odd n (and n + 1 for even n, matching the original behavior).
    The kernel is not normalized to sum to 1.
    '''
    half = int(n / 2)
    norm = 1 / (sigma * sqrt(2 * pi))
    return [norm * exp(-float(x) ** 2 / (2 * sigma ** 2)) for x in range(-half, half + 1)]
|
| 226 |
+
|
| 227 |
+
class GaussianBlur1D(nn.Module):
    '''Temporal (1-D) Gaussian blur along the time axis of a volume.

    The kernel width scales with the number of time slices
    (sigma = 2 * time_slices / 25). Requires CUDA — the kernel is built
    with torch.cuda.FloatTensor.
    '''
    def __init__(self, time_slices):
        super(GaussianBlur1D, self).__init__()
        sigma = 2 * time_slices / 25
        # Standard ~4-sigma support, forced to an odd number of taps.
        self.size = 2 * int(4 * sigma + 0.5) + 1
        kernel = gauss(self.size, sigma)
        kernel = torch.cuda.FloatTensor(kernel)
        # Fixed (non-trainable) blur weights.
        self.weight = nn.Parameter(data=kernel, requires_grad=False)

    def forward(self, x):
        # NOTE(review): the kernel is reshaped to 5-D before F.conv1d, so
        # x is presumably a 5-D (N, C, T, H, W) volume; `padding=pad` pads
        # all trailing dims and the unwanted spatial padding is cropped
        # afterwards. F.conv1d ordinarily expects 3-D input — confirm this
        # works on the project's torch version.
        pad = int(self.size/2)
        temp = F.conv1d(x, self.weight.view(1, 1, -1, 1, 1), padding=pad)
        return temp[:,:,:,pad:-pad,pad:-pad]
|
| 240 |
+
|
| 241 |
+
class GaussianBlur2D(nn.Module):
    '''Spatial Gaussian blur applied as two separable 1-D passes
    (one per spatial axis) with fixed sigma 25 and 201 taps.

    Requires CUDA — the kernel is built with torch.cuda.FloatTensor.
    '''
    def __init__(self):
        super(GaussianBlur2D, self).__init__()
        self.size = 201
        kernel = gauss(self.size, 25)
        kernel = torch.cuda.FloatTensor(kernel)
        # Fixed (non-trainable) blur weights, shared by both passes.
        self.weight = nn.Parameter(data=kernel, requires_grad=False)

    def forward(self, x):
        # NOTE(review): x appears to be a 3-D (T, H, W) volume that is
        # expanded to 5-D via the two unsqueeze calls; padding introduced
        # on the non-convolved dims is cropped back out after each pass.
        # Confirm F.conv1d accepts 5-D tensors on the project's torch
        # version.
        pad = int(self.size/2)
        temp = F.conv1d(x.unsqueeze(0).unsqueeze(0), self.weight.view(1, 1, 1, -1, 1), padding=pad)
        temp = temp[:,:,pad:-pad,:,pad:-pad]
        temp = F.conv1d(temp, self.weight.view(1, 1, 1, 1, -1), padding=pad)
        return temp[:,:,pad:-pad,pad:-pad]
|