# common.py
"""Shared building blocks for patch-feature extraction and score rescaling.

Contents:
    * Preprocessing / MeanMapper: pool per-layer features to a common width.
    * Aggregator: pool a stacked feature tensor down to a target width.
    * RescaleSegmentor: upsample and smooth patch-level score maps.
    * NetworkFeatureAggregator / ForwardHook: collect intermediate backbone
      activations through forward hooks.
"""
import copy

import numpy as np
import scipy.ndimage as ndimage
import torch
import torch.nn.functional as F
from torch import nn


class Preprocessing(torch.nn.Module):
    """Map a list of per-layer features to a common dimensionality.

    One ``MeanMapper(output_dim)`` is created for every entry of
    ``input_dims``; the entries themselves are only counted, not inspected.
    """

    def __init__(self, input_dims, output_dim):
        """Create one mapper per input layer.

        Args:
            input_dims: Iterable with one entry per feature layer.
            output_dim: Width every layer's features are pooled to.
        """
        super().__init__()
        self.input_dims = input_dims
        self.output_dim = output_dim

        self.preprocessing_modules = torch.nn.ModuleList(
            [MeanMapper(output_dim) for _ in input_dims]
        )

    def forward(self, features):
        """Pool each layer's features and stack them along a new axis 1.

        Args:
            features: Iterable of tensors, paired positionally with the
                mappers. ``zip`` silently truncates to the shorter of the two.

        Returns:
            Tensor of shape ``(N, num_layers, output_dim)``, where ``N`` is
            the leading dimension shared by the inputs.
        """
        mapped = [
            module(feature)
            for module, feature in zip(self.preprocessing_modules, features)
        ]
        return torch.stack(mapped, dim=1)


class MeanMapper(torch.nn.Module):
    """Flatten every sample and adaptively average-pool it to a fixed width."""

    def __init__(self, preprocessing_dim):
        """Store the output width.

        Args:
            preprocessing_dim: Number of values per sample after pooling.
        """
        super().__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        """Return features pooled to shape ``(N, preprocessing_dim)``."""
        # (N, ...) -> (N, 1, D): adaptive_avg_pool1d pools over the last axis.
        features = features.reshape(len(features), 1, -1)
        return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1)


class Aggregator(torch.nn.Module):
    """Flatten every sample and adaptively average-pool it to ``target_dim``."""

    def __init__(self, target_dim):
        """Store the output width.

        Args:
            target_dim: Number of values per sample after pooling.
        """
        super().__init__()
        self.target_dim = target_dim

    def forward(self, features):
        """Returns reshaped and average pooled features.

        Args:
            features: Tensor whose leading dimension is the sample axis; all
                remaining dimensions are flattened before pooling.

        Returns:
            Tensor of shape ``(N, target_dim)``.
        """
        features = features.reshape(len(features), 1, -1)
        features = F.adaptive_avg_pool1d(features, self.target_dim)
        return features.reshape(len(features), -1)


class RescaleSegmentor:
    """Upsample patch-level score maps and smooth them with a Gaussian."""

    def __init__(self, device, target_size=288):
        """Store rescaling settings.

        Args:
            device: Device the interpolation is executed on.
            target_size: Output size handed to ``F.interpolate``.
        """
        self.device = device
        self.target_size = target_size
        # Sigma of the Gaussian filter applied to every upsampled map.
        self.smoothing = 4

    def convert_to_segmentation(self, patch_scores):
        """Turn patch score maps into smoothed maps of ``target_size``.

        Args:
            patch_scores: ``np.ndarray`` or ``torch.Tensor`` of score maps.
                A channel axis is inserted at position 1 before the bilinear
                interpolation, so the input is expected to be ``(N, H, W)``.

        Returns:
            List with one Gaussian-smoothed ``np.ndarray`` per input map.
        """
        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            _scores = patch_scores.to(self.device)
            _scores = _scores.unsqueeze(1)
            _scores = F.interpolate(
                _scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            _scores = _scores.squeeze(1)
            patch_scores = _scores.cpu().numpy()
        return [
            ndimage.gaussian_filter(patch_score, sigma=self.smoothing)
            for patch_score in patch_scores
        ]


class NetworkFeatureAggregator(torch.nn.Module):
    """Efficient extraction of network features.

    Forward hooks are registered on the requested backbone layers and store
    each layer's output in ``self.outputs``. The backbone forward pass is
    always executed in full: nothing in this module raises
    ``LastLayerToExtractReachedException`` to stop it early.
    """

    def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):
        """Extraction of network features.

        Any hooks previously recorded in ``backbone.hook_handles`` are removed
        first, so wrapping the same backbone again disables the hooks of the
        earlier wrapper.

        Args:
            backbone: torchvision.model
            layers_to_extract_from: [list of str]
            device: Device the aggregator (and its backbone) is moved to.
            train_backbone: If True, ``forward(..., eval=False)`` runs the
                backbone in train mode with gradients enabled.

        Raises:
            ValueError: If a requested layer cannot be found in the backbone.
        """
        super().__init__()
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.device = device
        self.train_backbone = train_backbone

        if not hasattr(backbone, "hook_handles"):
            self.backbone.hook_handles = []
        for handle in self.backbone.hook_handles:
            handle.remove()
        # Drop the removed handles so the list does not grow with stale
        # entries every time the backbone is wrapped again.
        self.backbone.hook_handles.clear()

        self.outputs = {}
        for extract_layer in layers_to_extract_from:
            self.register_hook(extract_layer)

        self.to(self.device)

    def forward(self, images, eval=True):
        """Run the backbone and return the hooked layer outputs.

        Args:
            images: Input batch passed to the backbone unchanged.
            eval: When False and ``train_backbone`` is set, the backbone runs
                in train mode with gradients; otherwise it runs in eval mode
                under ``torch.no_grad()``. The name shadows the builtin but is
                kept so keyword callers keep working.

        Returns:
            The ``self.outputs`` dict mapping layer name to output. The same
            dict object is cleared and refilled on every call.
        """
        self.outputs.clear()
        if self.train_backbone and not eval:
            self.backbone.train()
            self.backbone(images)
        else:
            self.backbone.eval()
            with torch.no_grad():
                self.backbone(images)
        return self.outputs

    def feature_dimensions(self, input_shape):
        """Computes the feature dimensions for all layers given input_shape.

        A single all-ones sample of shape ``[1] + input_shape`` is pushed
        through the backbone (in eval mode) and the size of axis 1 of every
        extracted layer output is reported.
        """
        _input = torch.ones([1] + list(input_shape)).to(self.device)
        _output = self(_input)
        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]

    def register_hook(self, layer_name):
        """Attach a ``ForwardHook`` to the backbone layer called ``layer_name``.

        Raises:
            ValueError: If no module with that name exists in the backbone.
        """
        module = self.find_module(self.backbone, layer_name)
        if module is None:
            raise ValueError(f"Module {layer_name} not found in the model")

        forward_hook = ForwardHook(
            self.outputs, layer_name, self.layers_to_extract_from[-1]
        )
        # For a Sequential the hook goes on its last child rather than on the
        # container itself.
        target = module[-1] if isinstance(module, torch.nn.Sequential) else module
        self.backbone.hook_handles.append(target.register_forward_hook(forward_hook))

    def find_module(self, model, module_name):
        """Return the sub-module of ``model`` named ``module_name`` or None.

        Names are matched against ``model.named_modules()``. For a dotted name
        the search descends into the module matching the first component as
        soon as it is encountered and resolves the remainder there.
        """
        # Split once up front; `dot` is empty when the name has no '.'.
        parent_name, dot, child_name = module_name.partition(".")
        for name, module in model.named_modules():
            if name == module_name:
                return module
            if dot and name == parent_name:
                return self.find_module(module, child_name)
        return None


class ForwardHook:
    """Forward hook that stores a layer's output under its name."""

    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        """Bind the hook to an output dict and a layer name.

        Args:
            hook_dict: Dict the layer output is written into.
            layer_name: Key used for this layer's output.
            last_layer_to_extract: Name of the final requested layer.
        """
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        # Recorded only; __call__ below never reads this flag.
        self.raise_exception_to_break = layer_name == last_layer_to_extract

    def __call__(self, module, input, output):
        """Store ``output``; returning None leaves the layer output unchanged."""
        self.hook_dict[self.layer_name] = output
        return None


class LastLayerToExtractReachedException(Exception):
    """Exception type kept for API compatibility; not raised in this module."""