|
|
|
|
| import copy
|
| import numpy as np
|
| import scipy.ndimage as ndimage
|
| import torch
|
| import torch.nn.functional as F
|
| from torch import nn
|
|
|
|
|
class Preprocessing(torch.nn.Module):
    """Pool each incoming feature tensor to a common length and stack them.

    One ``MeanMapper`` is built per entry of ``input_dims``; only the number of
    entries is used, not their values.

    Args:
        input_dims: sequence with one entry per feature tensor passed to
            ``forward``.
        output_dim: length every feature tensor is pooled down to.
    """

    def __init__(self, input_dims, output_dim):
        super().__init__()
        self.input_dims = input_dims
        self.output_dim = output_dim
        # One mapper per expected input, kept in a ModuleList so they are
        # registered as submodules.
        self.preprocessing_modules = torch.nn.ModuleList(
            MeanMapper(output_dim) for _ in input_dims
        )

    def forward(self, features):
        """Apply the i-th mapper to the i-th feature and stack along dim 1.

        Pairing uses ``zip``, so any surplus features (or mappers) beyond the
        shorter of the two sequences are silently ignored.
        """
        mapped = [
            mapper(feature)
            for mapper, feature in zip(self.preprocessing_modules, features)
        ]
        return torch.stack(mapped, dim=1)
|
|
|
|
|
class MeanMapper(torch.nn.Module):
    """Average-pool each sample's flattened features to a fixed length.

    Args:
        preprocessing_dim: number of values kept per sample.
    """

    def __init__(self, preprocessing_dim):
        super().__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        """Return a ``(N, preprocessing_dim)`` tensor for an N-sample input.

        Everything after the first axis is flattened into a single-channel
        1-D signal, which is then adaptively average-pooled.
        """
        batch_size = len(features)
        as_signal = features.reshape(batch_size, 1, -1)
        pooled = F.adaptive_avg_pool1d(as_signal, self.preprocessing_dim)
        # Drop the temporary channel axis added above.
        return pooled.squeeze(1)
|
|
|
|
|
class Aggregator(torch.nn.Module):
    """Flatten each sample and average-pool it to ``target_dim`` values.

    Args:
        target_dim: number of values per sample in the output.
    """

    def __init__(self, target_dim):
        super().__init__()
        self.target_dim = target_dim

    def forward(self, features):
        """Returns reshaped and average pooled features.

        The result has shape ``(N, target_dim)`` for an N-sample input.
        """
        n_samples = len(features)
        # Treat every sample as a single-channel 1-D signal for pooling.
        signal = features.reshape(n_samples, 1, -1)
        pooled = F.adaptive_avg_pool1d(signal, self.target_dim)
        return pooled.reshape(len(pooled), -1)
|
|
|
|
|
class RescaleSegmentor:
    """Upsample per-patch score maps to ``target_size`` and Gaussian-smooth them.

    Args:
        device: device the bilinear upsampling runs on.
        target_size: output spatial size passed to ``F.interpolate`` (an int
            for a square map, or an ``(H, W)`` pair).
        smoothing: sigma of the Gaussian filter applied after upsampling
            (defaults to 4, the previously hard-coded value).
    """

    def __init__(self, device, target_size=288, smoothing=4):
        self.device = device
        self.target_size = target_size
        self.smoothing = smoothing

    def convert_to_segmentation(self, patch_scores):
        """Turn a batch of score maps into a list of smoothed numpy maps.

        Args:
            patch_scores: numpy array or tensor expected to be 3-D
                ``(N, H, W)``; a channel axis is inserted before 2-D bilinear
                interpolation, which requires a 4-D input.

        Returns:
            list of N numpy arrays, each of spatial size ``target_size``.
        """
        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            _scores = patch_scores.to(self.device)
            # Bilinear interpolation is not implemented for integer dtypes.
            if not torch.is_floating_point(_scores):
                _scores = _scores.float()
            _scores = _scores.unsqueeze(1)
            _scores = F.interpolate(
                _scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            patch_scores = _scores.squeeze(1).cpu().numpy()

        return [
            ndimage.gaussian_filter(patch_score, sigma=self.smoothing)
            for patch_score in patch_scores
        ]
|
|
|
|
|
class NetworkFeatureAggregator(torch.nn.Module):
    """Efficient extraction of network features.

    A forward hook is registered on every requested backbone layer. Each call
    to ``forward`` runs the complete backbone (no hook stops it early) and
    returns a dict mapping layer name to that layer's output.

    Args:
        backbone: ``torch.nn.Module`` (e.g. a torchvision model) whose
            intermediate outputs are collected.
        layers_to_extract_from: list of module names as produced by
            ``backbone.named_modules()``; dotted names such as ``"layer2.0"``
            are allowed.
        device: device this module, and therefore the backbone, is moved to.
        train_backbone: if True, ``forward(images, eval=False)`` runs the
            backbone in train mode with autograd enabled.

    Raises:
        ValueError: if a requested layer name is not found in ``backbone``.
    """

    def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):
        super(NetworkFeatureAggregator, self).__init__()
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.device = device
        self.train_backbone = train_backbone

        # Hooks left by a previous aggregator on the same backbone would keep
        # firing into that aggregator's dict: remove them, then empty the list
        # so stale handles do not pile up across re-creations.
        if not hasattr(backbone, "hook_handles"):
            self.backbone.hook_handles = []
        for handle in self.backbone.hook_handles:
            handle.remove()
        self.backbone.hook_handles.clear()
        self.outputs = {}

        for extract_layer in layers_to_extract_from:
            self.register_hook(extract_layer)

        self.to(self.device)

    def forward(self, images, eval=True):
        """Run the backbone on ``images`` and return the captured outputs.

        Args:
            images: input batch fed directly to the backbone.
            eval: when False and ``train_backbone`` is set, the backbone runs
                in train mode with gradients; otherwise it runs in eval mode
                under ``torch.no_grad()``.

        Returns:
            ``self.outputs``: the same dict object on every call, cleared at
            the start of each call and refilled by the hooks.
        """
        self.outputs.clear()
        if self.train_backbone and not eval:
            self.backbone.train()
            self.backbone(images)
        else:
            self.backbone.eval()
            with torch.no_grad():
                self.backbone(images)
        return self.outputs

    def feature_dimensions(self, input_shape):
        """Computes the feature dimensions for all layers given input_shape.

        A dummy batch of one all-ones sample with ``input_shape`` is pushed
        through the backbone; size 1 of each captured output is returned, in
        the order of ``layers_to_extract_from``.
        """
        _input = torch.ones([1] + list(input_shape)).to(self.device)
        _output = self(_input)
        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]

    def register_hook(self, layer_name):
        """Attach an output-capturing hook to the backbone module ``layer_name``."""
        module = self.find_module(self.backbone, layer_name)
        if module is not None:
            forward_hook = ForwardHook(self.outputs, layer_name, self.layers_to_extract_from[-1])
            # For a Sequential container the hook goes on its last child,
            # whose output is what the container returns.
            if isinstance(module, torch.nn.Sequential):
                hook = module[-1].register_forward_hook(forward_hook)
            else:
                hook = module.register_forward_hook(forward_hook)
            # Remember the handle so a later aggregator can remove this hook.
            self.backbone.hook_handles.append(hook)
        else:
            raise ValueError(f"Module {layer_name} not found in the model")

    def find_module(self, model, module_name):
        """Return the submodule of ``model`` named ``module_name``, or None.

        Names are matched against ``model.named_modules()``. A dotted name is
        also resolved by descending into its first component and recursively
        looking up the remainder there.
        """
        for name, module in model.named_modules():
            if name == module_name:
                return module
            elif '.' in module_name:
                father, child = module_name.split('.', 1)
                if name == father:
                    return self.find_module(module, child)
        return None
|
|
|
|
|
class ForwardHook:
    """Forward hook that stores a module's output in a shared dict.

    Args:
        hook_dict: dict the output is written into, keyed by ``layer_name``.
        layer_name: key under which this module's output is stored.
        last_layer_to_extract: name of the final requested layer; only used
            to set ``raise_exception_to_break``.
    """

    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        # True for the hook on the final requested layer. A bool is immutable,
        # so no copy is needed. NOTE: __call__ never acts on this flag, so the
        # hooked forward pass is not interrupted.
        self.raise_exception_to_break = layer_name == last_layer_to_extract

    def __call__(self, module, input, output):
        """Record ``output`` under ``layer_name``; returning None keeps it unchanged."""
        self.hook_dict[self.layer_name] = output
        return None
|
|
|
|
|
class LastLayerToExtractReachedException(Exception):
    """Exception type signalling that the last requested layer was reached."""
|
|
|