from torch import nn
from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from konfai.network import network

# Attributes on dynamic_network_architectures blocks that may hold a plain Python
# callable (e.g. ``lambda x: x``) instead of an ``nn.Module``; such lambdas cannot be pickled.
_IDENTITY_ATTRIBUTES = ("skip", "nonlin2")


def _is_identity_callable(fn: object) -> bool:
    """Return True if ``fn`` is a non-Module callable that returns its argument unchanged.

    The check calls ``fn`` once with a fresh sentinel object. An identity function such as
    ``lambda x: x`` hands the sentinel straight back. Anything that performs a real operation
    either raises on the non-tensor input or returns a different object, and is left alone.
    """
    if isinstance(fn, nn.Module) or not callable(fn):
        return False
    probe = object()
    try:
        return fn(probe) is probe
    except Exception:
        # Deliberately broad: any callable that cannot accept an arbitrary object is, by
        # definition, not a pure identity, so it must not be swapped for nn.Identity.
        return False


def _replace_unpicklable_identities(module: nn.Module) -> None:
    """Replace library lambdas such as ``lambda x: x`` with ``nn.Identity``.

    Walks every submodule of ``module``. For each one, any ``skip`` / ``nonlin2`` attribute
    holding a plain (non-Module) identity callable is swapped for an ``nn.Identity``, which is
    picklable. Non-identity callables are left untouched, so the forward computation is unchanged.

    Args:
        module: Root module to patch in place.
    """
    # Materialise the traversal first: assigning an nn.Module attribute registers a new
    # submodule, so the module tree is mutated while it is being walked.
    for child in list(module.modules()):
        for attr in _IDENTITY_ATTRIBUTES:
            if _is_identity_callable(getattr(child, attr, None)):
                setattr(child, attr, nn.Identity())


class ResEnc(network.Network):
    """2-D ``ResidualEncoderUNet`` (from dynamic_network_architectures) wrapped as a konfai network.

    The U-Net is registered under the module name ``"DecoderOutputs"``. It uses six stages with
    InstanceNorm, LeakyReLU and no dropout, and deep supervision is disabled. After construction,
    identity lambdas inside the library blocks are replaced with ``nn.Identity`` so the network
    can be pickled.
    """

    # NOTE(review): the optimizer / schedulers / outputs_criterions defaults are evaluated once at
    # import time and shared by every instance that relies on them. They are kept as-is because
    # konfai appears to read constructor defaults for configuration — confirm before changing.
    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        in_channels: int = 5,
        nb_class: int = 132,
    ) -> None:
        """Build the network.

        Args:
            optimizer: konfai optimizer loader, forwarded to ``network.Network``.
            schedulers: konfai LR-scheduler loaders keyed by name, forwarded to ``network.Network``.
            outputs_criterions: konfai loss/target loaders keyed by output, forwarded to ``network.Network``.
            in_channels: Number of input channels (passed as ``input_channels`` to the U-Net).
            nb_class: Number of output channels (passed as ``num_classes`` to the U-Net).
        """
        super().__init__(
            in_channels=in_channels,
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            dim=2,
        )
        self.add_module(
            "DecoderOutputs",
            ResidualEncoderUNet(
                input_channels=in_channels,
                n_stages=6,
                features_per_stage=(24, 48, 96, 192, 256, 256),
                conv_op=nn.Conv2d,
                kernel_sizes=3,
                strides=(1, 2, 2, 2, 2, 2),
                n_blocks_per_stage=(1, 2, 2, 3, 3, 3),
                num_classes=nb_class,
                # One decoder conv per stage; one entry per upsampling step (n_stages - 1).
                n_conv_per_stage_decoder=(1, 1, 1, 1, 1),
                conv_bias=True,
                norm_op=nn.InstanceNorm2d,
                norm_op_kwargs={"eps": 1e-5, "affine": True},
                dropout_op=None,
                dropout_op_kwargs=None,
                nonlin=nn.LeakyReLU,
                nonlin_kwargs={"inplace": True},
                deep_supervision=False,
            ),
        )
        # The library may store ``lambda x: x`` in some blocks; swap those for nn.Identity so the
        # whole network is picklable.
        _replace_unpicklable_identities(self.DecoderOutputs)