|
|
import torch |
|
|
from torch import nn |
|
|
import numpy as np |
|
|
from typing import Union, Type, List, Tuple |
|
|
|
|
|
from torch.nn.modules.conv import _ConvNd |
|
|
from torch.nn.modules.dropout import _DropoutNd |
|
|
from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks |
|
|
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op |
|
|
|
|
|
|
|
|
class PlainConvEncoder(nn.Module):
    """
    Plain convolutional encoder made of ``n_stages`` sequential resolution stages.

    Each stage is an ``nn.Sequential`` of an optional pooling op followed by a
    ``StackedConvBlocks``. Downsampling is realized either by strided convolution
    (``pool='conv'``) or by an explicit max/avg pooling op (``pool='max'`` /
    ``pool='avg'``), in which case the convolutions of that stage run with stride 1.
    """

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 return_skips: bool = False,
                 nonlin_first: bool = False,
                 pool: str = 'conv'
                 ):
        """
        :param input_channels: number of channels of the network input
        :param n_stages: number of resolution stages
        :param features_per_stage: output channels per stage; a scalar is broadcast to all stages
        :param conv_op: convolution class (e.g. ``nn.Conv2d`` / ``nn.Conv3d``)
        :param kernel_sizes: conv kernel size per stage; a scalar is broadcast
        :param strides: downsampling stride per stage; a scalar is broadcast.
            The first entry is recommended to be 1, otherwise a strided conv
            (or pool) is applied directly to the input.
        :param n_conv_per_stage: number of conv blocks per stage; a scalar is broadcast
        :param conv_bias: whether the convolutions use a bias term
        :param norm_op: normalization class, or None to disable normalization
        :param norm_op_kwargs: kwargs for ``norm_op``
        :param dropout_op: dropout class, or None to disable dropout
        :param dropout_op_kwargs: kwargs for ``dropout_op``
        :param nonlin: nonlinearity class, or None to disable it
        :param nonlin_kwargs: kwargs for ``nonlin``
        :param return_skips: if True, ``forward`` returns the outputs of all
            stages (for use as U-Net skip connections); otherwise only the
            output of the last stage
        :param nonlin_first: if True, nonlinearity is applied before normalization
        :param pool: 'conv' (strided convolution), 'max' or 'avg' (explicit pooling)
        :raises RuntimeError: if ``pool`` is not one of 'conv', 'max', 'avg'
        """
        super().__init__()
        # Broadcast scalar hyperparameters to one entry per stage.
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv drectly on the input"

        stages = []
        for s in range(n_stages):
            stage_modules = []
            if pool == 'max' or pool == 'avg':
                # Explicit pooling handles the downsampling; only add a pool op
                # when this stage actually downsamples (stride != 1 somewhere).
                if (isinstance(strides[s], int) and strides[s] != 1) or \
                        isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
                    stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
                # Convolutions must not downsample again after the pool.
                conv_stride = 1
            elif pool == 'conv':
                # Downsampling is done by the first conv of the stage.
                conv_stride = strides[s]
            else:
                raise RuntimeError(f"Invalid value for pool: {pool}. Options are 'conv', 'max', 'avg'")
            stage_modules.append(StackedConvBlocks(
                n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ))
            stages.append(nn.Sequential(*stage_modules))
            # The output channels of this stage are the input channels of the next.
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        # Kept for decoders that need to mirror this encoder's configuration.
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        """
        Run all stages sequentially.

        :param x: input tensor accepted by ``conv_op``
        :return: list of per-stage outputs (skips) if ``self.return_skips``,
            otherwise only the output of the last stage
        """
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        # BUGFIX: the flag was previously stored but never consulted, so the
        # full skip list was always returned regardless of return_skips.
        if self.return_skips:
            return ret
        else:
            return ret[-1]

    def compute_conv_feature_map_size(self, input_size):
        """
        Sum the number of feature-map elements produced by the conv blocks of
        all stages for an input of spatial size ``input_size`` (no batch/channel
        dims). Used for memory estimation.
        """
        output = np.int64(0)
        for s in range(len(self.stages)):
            if isinstance(self.stages[s], nn.Sequential):
                for sq in self.stages[s]:
                    # Pooling ops have no compute_conv_feature_map_size; skip them.
                    if hasattr(sq, 'compute_conv_feature_map_size'):
                        # BUGFIX: previously queried self.stages[s][-1] instead of
                        # sq, which double-counted the last submodule whenever more
                        # than one submodule implemented the method.
                        output += sq.compute_conv_feature_map_size(input_size)
            else:
                output += self.stages[s].compute_conv_feature_map_size(input_size)
            # Spatial size shrinks by the stage's stride before the next stage.
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]
        return output
|
|
|
|
|
|