| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | import torch.nn.functional as F |
| | from collections import OrderedDict |
| |
|
class _SimpleSegmentationModel(nn.Module):
    """Generic two-stage segmentation network: ``classifier(backbone(x))``.

    The classifier output is resized with bilinear interpolation
    (``align_corners=False``) to the spatial size (last two dims) of the
    input, so the prediction map matches the input resolution.

    Args:
        backbone (nn.Module): feature extractor applied to the input; its
            return value is handed unchanged to ``classifier``.
        classifier (nn.Module): head applied to the backbone output; it is
            assumed to return a 4-D map (N x C x h x w) suitable for
            ``F.interpolate`` -- confirm for custom heads.
    """

    def __init__(self, backbone, classifier):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        # Record the input's spatial size; the backbone may change it.
        target_size = x.shape[-2:]
        logits = self.classifier(self.backbone(x))
        return F.interpolate(
            logits,
            size=target_size,
            mode='bilinear',
            align_corners=False,
        )
| |
|
| |
|
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Children registered after the last requested layer are not kept, so
    they are never executed by ``forward``.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
            A private copy is stored, so later changes to the caller's
            dict do not affect this module.
        hrnet_flag (bool): enable the HRNet-specific handling in ``forward``.
            Children whose names start with ``transition`` build the list
            of branches: ``transition1`` is iterated and each sub-module is
            applied to the current input, and every later transition adds
            one branch computed from the last existing branch. When
            ``stage4`` is requested, its list output is fused by bilinearly
            resizing every branch to the size of the first one and
            concatenating along dim 1.

    Raises:
        ValueError: if a key of ``return_layers`` is not the name of a
            direct child of ``model``.

    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = IntermediateLayerGetter(m, {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        [('feat1', torch.Size([1, 64, 56, 56])),
         ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    def __init__(self, model, return_layers, hrnet_flag=False):
        child_names = {name for name, _ in model.named_children()}
        if not set(return_layers).issubset(child_names):
            raise ValueError("return_layers are not present in model")

        # Keep children up to and including the last requested one; anything
        # registered after it is not needed to produce the outputs.
        remaining = set(return_layers)
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            remaining.discard(name)
            if not remaining:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        # Plain attributes are assigned only after nn.Module.__init__ so the
        # Module bookkeeping exists before Module.__setattr__ runs.
        self.hrnet_flag = hrnet_flag
        # Private copy: mutating the caller's dict later must not change
        # which activations forward() returns.
        self.return_layers = dict(return_layers)

    @staticmethod
    def _fuse_branches(branches):
        """Resize every branch to the first branch's (H, W) and concat on dim 1."""
        first = branches[0]
        size = (first.size(2), first.size(3))
        resized = [first] + [
            F.interpolate(branch, size=size, mode='bilinear', align_corners=False)
            for branch in branches[1:]
        ]
        return torch.cat(resized, dim=1)

    def forward(self, x):
        """Run the kept children in order and collect the requested outputs.

        Returns:
            OrderedDict mapping each ``return_layers`` value to the output of
            the corresponding child, in child-registration order.
        """
        out = OrderedDict()
        for name, module in self.named_children():
            if self.hrnet_flag and name.startswith('transition'):
                if name == 'transition1':
                    # Each sub-module of transition1 produces one branch
                    # from the same input.
                    x = [trans(x) for trans in module]
                else:
                    # Build a NEW list rather than appending in place: the
                    # previous stage's list may already be stored in `out`,
                    # and in-place append would silently alter that output.
                    x = list(x)
                    x.append(module(x[-1]))
            else:
                x = module(x)

            if name in self.return_layers:
                out_name = self.return_layers[name]
                if name == 'stage4' and self.hrnet_flag:
                    # The fused tensor also becomes the input of any later child.
                    x = self._fuse_branches(x)
                out[out_name] = x
        return out
| |
|