# FelixzeroSun's picture
# Upload folder using huggingface_hub
# 19c1f58 verified
import torch
from torch import nn
import numpy as np
from typing import Union, Type, List, Tuple
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
class PlainConvEncoder(nn.Module):
    """
    Plain convolutional encoder (the contracting path of a U-Net-style network).

    Builds `n_stages` stages, each an `nn.Sequential`. Depending on `pool`:
      - 'conv': downsampling is done by the first (strided) convolution of the stage;
      - 'max' / 'avg': an explicit pooling op (only added when the stage stride != 1)
        downsamples, and the convolutions run with stride 1.
    Each stage ends in a `StackedConvBlocks` with `n_conv_per_stage[s]` conv blocks.
    """

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 return_skips: bool = False,
                 nonlin_first: bool = False,
                 pool: str = 'conv'
                 ):
        """
        :param input_channels: number of channels of the network input
        :param n_stages: number of resolution stages
        :param features_per_stage: output channels per stage (a scalar is broadcast to all stages)
        :param conv_op: convolution class to use (e.g. nn.Conv2d / nn.Conv3d)
        :param kernel_sizes: conv kernel size(s) per stage (a scalar is broadcast)
        :param strides: downsampling stride per stage (a scalar is broadcast). The first entry is
            recommended to be 1, else we run a strided conv/pool directly on the input.
        :param n_conv_per_stage: number of conv blocks per stage (a scalar is broadcast)
        :param return_skips: stored for decoder use; NOTE(review): the modified forward() of this
            file ignores it and always returns all skips (see forward docstring)
        :param pool: 'conv' (strided convolution), 'max' or 'avg' (explicit pooling op)
        :raises RuntimeError: if `pool` is not one of 'conv', 'max', 'avg'
        """
        super().__init__()
        # Broadcast scalar hyperparameters to one entry per stage.
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv directly on the input"

        stages = []
        for s in range(n_stages):
            stage_modules = []
            if pool == 'max' or pool == 'avg':
                # Explicit pooling op performs the downsampling, but only when this stage
                # actually downsamples (stride != 1); convs then run with stride 1.
                if (isinstance(strides[s], int) and strides[s] != 1) or \
                        isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
                    stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
                conv_stride = 1
            elif pool == 'conv':
                # Downsampling via strided convolution in the first conv of the stage.
                conv_stride = strides[s]
            else:
                # Fix: the original raised a bare RuntimeError() with no message.
                raise RuntimeError(f"Invalid pool type: {pool}. Allowed values: 'conv', 'max', 'avg'")
            stage_modules.append(StackedConvBlocks(
                n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ))
            stages.append(nn.Sequential(*stage_modules))
            # Output channels of this stage feed the next one.
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        """
        Run all stages sequentially and return the list of per-stage outputs (skips).

        NOTE(review): this modified version always returns the full list of feature maps
        and ignores self.return_skips (upstream nnU-Net returns only the last feature map
        when return_skips is False). Kept as-is since downstream code in this repo likely
        expects the list — confirm before "fixing".
        """
        ret = []
        for stage in self.stages:
            x = stage(x)
            ret.append(x)
        return ret

    def compute_conv_feature_map_size(self, input_size):
        """
        Estimate the total number of feature-map elements produced by the encoder for a
        given input spatial size (used for memory estimation).

        :param input_size: spatial size of the input, WITHOUT batch/channel dimensions;
            must have one entry per spatial axis of `self.strides[s]`
        :return: np.int64 total count across all stages
        """
        output = np.int64(0)
        for s in range(len(self.stages)):
            if isinstance(self.stages[s], nn.Sequential):
                for sq in self.stages[s]:
                    if hasattr(sq, 'compute_conv_feature_map_size'):
                        # Fix: query the submodule we just tested (sq). The original used
                        # self.stages[s][-1], which would repeatedly count only the LAST
                        # submodule if more than one submodule had this method.
                        output += sq.compute_conv_feature_map_size(input_size)
            else:
                output += self.stages[s].compute_conv_feature_map_size(input_size)
            # Each stage downsamples the spatial size by its stride.
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]
        return output