| """ |
| https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py |
| |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
| from collections import OrderedDict |
| from typing import Dict, List |
|
|
|
|
| import torch.nn as nn |
|
|
|
|
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    The direct children of ``model`` are copied in registration order up to
    and including the last requested layer; children registered after it are
    dropped. ``forward`` runs the kept children one after another and returns
    a list with the output of each requested layer, ordered by registration
    order (not by the order of ``return_layers``).

    Args:
        model: module whose direct children are executed in sequence.
        return_layers: names of direct children whose outputs are returned.
            Entries are converted with ``str`` so that, e.g., integer indices
            of an ``nn.Sequential`` (whose children are named "0", "1", ...)
            are accepted. The names are copied, so later changes to the
            caller's list do not affect this module.

    Raises:
        ValueError: if any requested name is not a direct child of ``model``.
    """

    _version = 3

    def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
        # Normalize once: copy the caller's names and stringify them so they
        # compare equal to the (string) keys yielded by named_children().
        requested = [str(name) for name in return_layers]
        children = list(model.named_children())
        available = [name for name, _ in children]
        if not set(requested).issubset(available):
            raise ValueError(
                "return_layers are not present in model. {}".format(available)
            )

        # Keep children up to and including the last requested one; anything
        # registered later can never contribute to the returned outputs.
        # With an empty `return_layers` only the first child is kept.
        remaining = set(requested)
        layers = OrderedDict()
        for name, module in children:
            layers[name] = module
            remaining.discard(name)
            if not remaining:
                break

        super().__init__(layers)
        self.return_layers = requested

    def forward(self, x):
        """Run the kept children in order, feeding each the previous output.

        Returns:
            A list holding the output of every child whose name is in
            ``self.return_layers``, in registration order.
        """
        outputs = []
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                outputs.append(x)
        return outputs
|
|