Spaces:
Sleeping
Sleeping
| import functools | |
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .layers import GaussianFilterNd | |
def encode_scanpath_features(x_hist, y_hist, size, device=None, include_x=True, include_y=True, include_duration=False):
    """Encode fixation histories as dense spatial feature maps.

    For every previous fixation, three maps over the full image grid are
    produced: the horizontal offset (x - x_fix), the vertical offset
    (y - y_fix) and the Euclidean distance to the fixation.

    Args:
        x_hist (torch.Tensor): x coordinates of previous fixations, shape
            (batch, fixations).
        y_hist (torch.Tensor): y coordinates of previous fixations, shape
            (batch, fixations).
        size (tuple): (height, width) of the stimulus.
        device: device for the coordinate grids (default: None).
        include_x / include_y (bool): must be True (only this combination
            is implemented).
        include_duration (bool): must be False (not implemented).

    Returns:
        torch.Tensor of shape (batch, 3 * fixations, height, width):
        x offsets, then y offsets, then distances, concatenated along the
        channel dimension.
    """
    # Only the x/y-offset + distance encoding is implemented.
    assert include_x
    assert include_y
    assert not include_duration

    height, width = size

    xs = torch.arange(width, dtype=torch.float32, device=device)
    ys = torch.arange(height, dtype=torch.float32, device=device)
    YS, XS = torch.meshgrid(ys, xs, indexing='ij')

    # Broadcast the (H, W) grids against the (B, F) histories to get
    # per-fixation offset maps of shape (B, F, H, W). This replaces the
    # previous explicit repeat_interleave expansion with equivalent, cheaper
    # broadcasting.
    XS = XS[None, None, :, :] - x_hist[:, :, None, None]
    YS = YS[None, None, :, :] - y_hist[:, :, None, None]

    distances = torch.sqrt(XS**2 + YS**2)

    return torch.cat((XS, YS, distances), dim=1)
class FeatureExtractor(torch.nn.Module):
    """Runs a feature network and collects the activations of selected layers.

    Forward hooks are registered on the modules named in `targets`; a forward
    pass through `features` then fills `self.outputs`, and `forward` returns
    the captured activations in the order of `targets`.
    """

    def __init__(self, features, targets):
        super().__init__()
        self.features = features
        self.targets = targets
        self.outputs = {}

        # Resolve each target name to its submodule and attach a capture hook.
        named = dict(self.features.named_modules())
        for target in targets:
            named[target].register_forward_hook(self.save_outputs_hook(target))

    def save_outputs_hook(self, layer_id: str):
        # Returns a hook closure that stores a copy of the layer's output
        # under `layer_id`.
        def hook(_module, _inputs, output):
            self.outputs[layer_id] = output.clone()
        return hook

    def forward(self, x):
        self.outputs.clear()
        # Return value of the backbone is ignored; the hooks populate outputs.
        self.features(x)
        return [self.outputs[name] for name in self.targets]
def upscale(tensor, size):
    """Upsample a (B, C, H, W) tensor to at least `size` by integer
    repetition along both spatial dimensions, then crop to exactly
    (size[0], size[1])."""
    current = torch.tensor(tensor.shape[2:]).type(torch.float32)
    wanted = torch.tensor(size).type(torch.float32)

    # Single integer factor large enough to cover both spatial dims.
    factor = torch.max(torch.ceil(wanted / current)).type(torch.int64).to(tensor.device)
    assert factor >= 1

    out = torch.repeat_interleave(tensor, factor, dim=2)
    out = torch.repeat_interleave(out, factor, dim=3)

    # Crop any overshoot from the integer repetition.
    return out[:, :, :size[0], :size[1]]
class Finalizer(nn.Module):
    """Transforms a readout into a gaze prediction

    A readout network returns a single, spatial map of probable gaze locations.
    This module bundles the common processing steps necessary to transform this into
    the predicted gaze distribution:

     - resizing to the stimulus size
     - smoothing of the prediction using a gaussian filter
     - removing of channel and time dimension
     - weighted addition of the center bias
     - normalization
    """

    def __init__(
        self,
        sigma,
        kernel_size=None,
        learn_sigma=False,
        center_bias_weight=1.0,
        learn_center_bias_weight=True,
        saliency_map_factor=4,
    ):
        """Creates a new finalizer

        Args:
            sigma (float): standard deviation of the gaussian kernel used for smoothing
            kernel_size (int, optional): size of the gaussian kernel
                (NOTE(review): currently unused by this constructor — confirm)
            learn_sigma (bool, optional): If True, the standard deviation of the gaussian kernel will
                be learned (default: False)
            center_bias_weight (float, optional): initial weight of the center bias
            learn_center_bias_weight (bool, optional): If True, the center bias weight will be
                learned (default: True)
            saliency_map_factor (int, optional): internal downscaling factor; smoothing and
                center-bias addition are performed at 1/saliency_map_factor of the
                stimulus resolution (default: 4)
        """
        super(Finalizer, self).__init__()

        self.saliency_map_factor = saliency_map_factor

        # Gaussian blur over the two spatial dimensions (dims 2 and 3) of a
        # (B, C, H, W) tensor; sigma becomes trainable when learn_sigma=True.
        self.gauss = GaussianFilterNd([2, 3], sigma, truncate=3, trainable=learn_sigma)
        # Scalar mixing weight for the center bias, optionally trainable.
        self.center_bias_weight = nn.Parameter(torch.Tensor([center_bias_weight]), requires_grad=learn_center_bias_weight)

    def forward(self, readout, centerbias):
        """Applies the finalization steps to the given readout"""
        # Work at reduced resolution: downscale the center bias by
        # saliency_map_factor (add/remove a channel dim around interpolate).
        downscaled_centerbias = F.interpolate(
            centerbias.view(centerbias.shape[0], 1, centerbias.shape[1], centerbias.shape[2]),
            scale_factor=1 / self.saliency_map_factor,
            recompute_scale_factor=False,
        )[:, 0, :, :]

        # Bring the readout to the same (downscaled) spatial size.
        out = F.interpolate(
            readout,
            size=[downscaled_centerbias.shape[1], downscaled_centerbias.shape[2]]
        )

        # apply gaussian filter
        out = self.gauss(out)

        # remove channel dimension
        out = out[:, 0, :, :]

        # add to center bias
        out = out + self.center_bias_weight * downscaled_centerbias

        # Upscale back to the original center-bias / stimulus resolution.
        out = F.interpolate(out[:, np.newaxis, :, :], size=[centerbias.shape[1], centerbias.shape[2]])[:, 0, :, :]

        # normalize: subtract the spatial logsumexp so each map is a proper
        # log-probability distribution over pixels.
        out = out - out.logsumexp(dim=(1, 2), keepdim=True)

        return out
class DeepGazeII(torch.nn.Module):
    """DeepGaze II saliency model: frozen feature backbone plus a trainable
    readout network and finalizer.

    The backbone's parameters are frozen (requires_grad=False) and the module
    is kept in eval mode; only the readout network and the finalizer train.
    """

    def __init__(self, features, readout_network, downsample=2, readout_factor=16, saliency_map_factor=2, initial_sigma=8.0):
        # features: backbone; called as features(x) and expected to return a
        #   list of feature tensors (see forward) — e.g. a FeatureExtractor.
        # readout_network: maps the concatenated feature maps to a readout.
        # downsample: input images are downscaled by this factor before the backbone.
        # readout_factor: additional downscaling factor for the readout resolution.
        # saliency_map_factor: working-resolution factor of the Finalizer.
        # initial_sigma: initial (learnable) stddev of the finalizer's Gaussian blur.
        super().__init__()

        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor

        self.features = features

        # Freeze the backbone: only readout network and finalizer are trained.
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()

        self.readout_network = readout_network
        self.finalizer = Finalizer(
            sigma=initial_sigma,
            learn_sigma=True,
            saliency_map_factor=self.saliency_map_factor,
        )
        self.downsample = downsample

    def forward(self, x, centerbias):
        # x: image batch (B, C, H, W); centerbias: (B, H, W) log center bias.
        orig_shape = x.shape
        x = F.interpolate(
            x,
            scale_factor=1 / self.downsample,
            recompute_scale_factor=False,
        )
        x = self.features(x)

        # Target spatial size of the readout, derived from the original image
        # size, the input downsampling and the readout factor.
        readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)]
        # Resize every feature map to the readout size and stack along channels.
        x = [F.interpolate(item, readout_shape) for item in x]

        x = torch.cat(x, dim=1)
        x = self.readout_network(x)
        # Finalizer resizes, blurs, adds center bias and normalizes to a
        # log-density (see Finalizer.forward).
        x = self.finalizer(x, centerbias)

        return x

    def train(self, mode=True):
        # Keep the backbone frozen in eval mode regardless of `mode`.
        # NOTE(review): super().train() is not called, so self.training is not
        # updated on this wrapper module — confirm this is intended.
        self.features.eval()
        self.readout_network.train(mode=mode)
        self.finalizer.train(mode=mode)
class DeepGazeIII(torch.nn.Module):
    """DeepGaze III: scanpath-aware saliency model.

    Combines a frozen feature backbone with a trainable saliency network, an
    optional scanpath network (consuming encoded fixation histories) and a
    fixation selection network, followed by a Finalizer that produces a
    normalized log-density over image locations.
    """

    def __init__(self, features, saliency_network, scanpath_network, fixation_selection_network, downsample=2, readout_factor=2, saliency_map_factor=2, included_fixations=-2, initial_sigma=8.0):
        # features: backbone; called as features(x), expected to return a list
        #   of feature tensors (see forward).
        # saliency_network: readout over the concatenated backbone features.
        # scanpath_network: readout over encoded scanpath features, or None to
        #   disable scanpath conditioning.
        # fixation_selection_network: combines (saliency, scanpath) readouts.
        # downsample / readout_factor / saliency_map_factor: resolution factors.
        # included_fixations: which history fixations are used (stored only;
        #   not read in this class).
        # initial_sigma: initial (learnable) stddev of the finalizer's blur.
        super().__init__()

        self.downsample = downsample
        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor
        self.included_fixations = included_fixations

        self.features = features

        # Freeze the backbone: only the readout networks and finalizer train.
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()

        self.saliency_network = saliency_network
        self.scanpath_network = scanpath_network
        self.fixation_selection_network = fixation_selection_network

        self.finalizer = Finalizer(
            sigma=initial_sigma,
            learn_sigma=True,
            saliency_map_factor=self.saliency_map_factor,
        )

    def forward(self, x, centerbias, x_hist=None, y_hist=None, durations=None):
        # x: image batch (B, C, H, W); centerbias: (B, H, W) log center bias;
        # x_hist / y_hist: previous fixation coordinates (required when a
        # scanpath network is configured); durations: unused here.
        orig_shape = x.shape
        # Fix: pass recompute_scale_factor=False for consistency with
        # DeepGazeII and DeepGazeIIIMixture (was missing here).
        x = F.interpolate(
            x,
            scale_factor=1 / self.downsample,
            recompute_scale_factor=False,
        )
        x = self.features(x)

        # Target spatial size of the readout.
        readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)]
        x = [F.interpolate(item, readout_shape) for item in x]

        x = torch.cat(x, dim=1)
        x = self.saliency_network(x)

        if self.scanpath_network is not None:
            # Encode the fixation history at full stimulus resolution, then
            # bring it to the readout resolution.
            scanpath_features = encode_scanpath_features(x_hist, y_hist, size=(orig_shape[2], orig_shape[3]), device=x.device)
            scanpath_features = F.interpolate(scanpath_features, readout_shape)
            y = self.scanpath_network(scanpath_features)
        else:
            y = None

        x = self.fixation_selection_network((x, y))

        x = self.finalizer(x, centerbias)

        return x

    def train(self, mode=True):
        # Keep the backbone frozen in eval mode regardless of `mode`.
        # NOTE(review): super().train() is not called, so self.training is not
        # updated on this wrapper module — confirm this is intended.
        self.features.eval()
        self.saliency_network.train(mode=mode)
        if self.scanpath_network is not None:
            self.scanpath_network.train(mode=mode)
        self.fixation_selection_network.train(mode=mode)
        self.finalizer.train(mode=mode)
class DeepGazeIIIMixture(torch.nn.Module):
    """Uniform mixture of several DeepGaze III readout heads sharing one
    frozen feature backbone.

    Each head is a (saliency_network, scanpath_network,
    fixation_selection_network, finalizer) quadruple; the per-head
    log-densities are combined with equal weights via logsumexp.
    """

    def __init__(self, features, saliency_networks, scanpath_networks, fixation_selection_networks, finalizers, downsample=2, readout_factor=2, saliency_map_factor=2, included_fixations=-2, initial_sigma=8.0):
        # features: shared backbone; called as features(x), expected to return
        #   a list of feature tensors (see forward).
        # saliency_networks / scanpath_networks / fixation_selection_networks /
        #   finalizers: parallel lists, one entry per mixture component; a
        #   scanpath network entry may be None to disable scanpath input.
        # included_fixations / initial_sigma: stored / accepted but not read
        #   in this class (the provided finalizers carry their own sigmas).
        super().__init__()

        self.downsample = downsample
        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor
        self.included_fixations = included_fixations

        self.features = features

        # Freeze the shared backbone.
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()

        self.saliency_networks = torch.nn.ModuleList(saliency_networks)
        self.scanpath_networks = torch.nn.ModuleList(scanpath_networks)
        self.fixation_selection_networks = torch.nn.ModuleList(fixation_selection_networks)
        self.finalizers = torch.nn.ModuleList(finalizers)

    def forward(self, x, centerbias, x_hist=None, y_hist=None, durations=None):
        # x: image batch (B, C, H, W); centerbias: (B, H, W) log center bias.
        orig_shape = x.shape
        x = F.interpolate(
            x,
            scale_factor=1 / self.downsample,
            recompute_scale_factor=False,
        )
        # Backbone features are computed once and shared by all heads.
        x = self.features(x)

        readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)]
        x = [F.interpolate(item, readout_shape) for item in x]

        x = torch.cat(x, dim=1)

        predictions = []

        readout_input = x

        for saliency_network, scanpath_network, fixation_selection_network, finalizer in zip(
            self.saliency_networks, self.scanpath_networks, self.fixation_selection_networks, self.finalizers
        ):

            x = saliency_network(readout_input)

            if scanpath_network is not None:
                scanpath_features = encode_scanpath_features(x_hist, y_hist, size=(orig_shape[2], orig_shape[3]), device=x.device)
                scanpath_features = F.interpolate(scanpath_features, readout_shape)
                y = scanpath_network(scanpath_features)
            else:
                y = None

            x = fixation_selection_network((x, y))

            x = finalizer(x, centerbias)

            predictions.append(x[:, np.newaxis, :, :])

        # Uniform mixture in log space: subtract log(K), then logsumexp over
        # the component dimension. Result keeps a singleton channel dim.
        predictions = torch.cat(predictions, dim=1) - np.log(len(self.saliency_networks))

        prediction = predictions.logsumexp(dim=(1), keepdim=True)

        return prediction
class MixtureModel(torch.nn.Module):
    """Uniform mixture over several log-density models.

    Each sub-model is called with the same arguments; their outputs are
    concatenated along dim 1, shifted by -log(K) for equal mixture weights,
    and reduced with logsumexp (keeping a singleton dim 1).
    """

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)

    def forward(self, *args, **kwargs):
        component_outputs = [model.forward(*args, **kwargs) for model in self.models]
        stacked = torch.cat(component_outputs, dim=1)
        # Equal weights: subtract log(K) in log space before combining.
        weighted = stacked - np.log(len(self.models))
        return weighted.logsumexp(dim=(1), keepdim=True)