File size: 2,171 Bytes
982b8c4 2558e55 9d47104 2558e55 9d47104 2558e55 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | from torch import nn
from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from konfai.network import network
def _replace_unpicklable_identities(
    module: nn.Module,
    attr_names: tuple[str, ...] = ("skip", "nonlin2"),
) -> None:
    """Swap plain-callable attributes (e.g. ``lambda x: x``) for ``nn.Identity``.

    Library blocks may store identity shortcuts as bare lambdas rather than as
    modules. Lambdas cannot be pickled, so every such attribute found on
    ``module`` or any of its descendants is replaced in place by a fresh
    ``nn.Identity()`` submodule.

    Args:
        module: Root module; it and all of its descendants are patched in place.
        attr_names: Attribute names to inspect on each submodule. The default
            covers the two names this file needs to patch.

    Note:
        Any callable that is not an ``nn.Module`` stored under one of these
        names is replaced, not only identity lambdas. Callers should only use
        this where those attributes are known to be identities.
    """
    # Snapshot the traversal first. Assigning an nn.Module attribute registers
    # a new child in ``_modules``, so mutating while walking the live tree is
    # avoided.
    for child in list(module.modules()):
        for name in attr_names:
            value = getattr(child, name, None)
            if callable(value) and not isinstance(value, nn.Module):
                setattr(child, name, nn.Identity())
class ResEnc(network.Network):
    """2-D residual-encoder U-Net exposed as a KonfAI ``network.Network``.

    The whole model is a single ``ResidualEncoderUNet`` (2-D convolutions,
    instance norm, LeakyReLU, no deep supervision) registered under the
    submodule name ``"DecoderOutputs"``.

    Args:
        optimizer: Optimizer loader forwarded to ``network.Network``.
        schedulers: Scheduler loaders keyed by name, forwarded unchanged.
        outputs_criterions: Criterion loaders keyed by output name, forwarded
            unchanged.
        in_channels: Number of input channels of the U-Net.
        nb_class: Number of output classes (``num_classes`` of the U-Net).
    """

    # NOTE(review): the default loader objects below are evaluated once, at
    # class definition, and are shared by every instance that relies on them.
    # They are kept as-is because KonfAI presumably reads these defaults from
    # the signature; confirm before changing them.
    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        in_channels: int = 5,
        nb_class: int = 132,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            dim=2,
        )

        # Backbone hyperparameters: six stages. The first stage keeps full
        # resolution (stride 1) and every later stage halves it.
        unet_kwargs = dict(
            input_channels=in_channels,
            num_classes=nb_class,
            n_stages=6,
            features_per_stage=(24, 48, 96, 192, 256, 256),
            strides=(1, 2, 2, 2, 2, 2),
            n_blocks_per_stage=(1, 2, 2, 3, 3, 3),
            n_conv_per_stage_decoder=(1, 1, 1, 1, 1),
            kernel_sizes=3,
            conv_op=nn.Conv2d,
            conv_bias=True,
            norm_op=nn.InstanceNorm2d,
            norm_op_kwargs={"eps": 1e-5, "affine": True},
            dropout_op=None,
            dropout_op_kwargs=None,
            nonlin=nn.LeakyReLU,
            nonlin_kwargs={"inplace": True},
            deep_supervision=False,
        )
        self.add_module("DecoderOutputs", ResidualEncoderUNet(**unet_kwargs))

        # Lambda-based identity shortcuts inside the backbone would make the
        # network unpicklable; patch the registered submodule in place.
        _replace_unpicklable_identities(self.DecoderOutputs)