| from torch import nn |
| from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet |
| from konfai.network import network |
|
|
def _replace_unpicklable_identities(
    module: nn.Module,
    attr_names: tuple[str, ...] = ("skip", "nonlin2"),
) -> None:
    """Replace library lambdas such as ``lambda x: x`` with ``nn.Identity``.

    Some blocks of the backbone library store plain pass-through callables
    (e.g. ``self.skip = lambda x: x``) as instance attributes. Lambdas cannot be
    pickled, so such a model cannot be serialised as a whole. This walks
    *module* and all of its submodules and swaps those attributes, in place, for
    an equivalent, picklable ``nn.Identity``.

    Args:
        module: Root module; it and every submodule returned by
            ``module.modules()`` are inspected.
        attr_names: Instance attribute names to check on each submodule. The
            default covers the names the original implementation handled.

    Note:
        Every matching non-``nn.Module`` callable is assumed to be an identity
        function. A non-identity function stored under one of these names would
        be silently replaced.
    """
    # Snapshot the traversal up front: assigning an nn.Module attribute registers
    # a new submodule, so avoid mutating the tree while lazily walking it.
    for child in list(module.modules()):
        # ``vars`` only exposes attributes stored on the instance. Registered
        # submodules live in ``_modules`` and methods live on the class, so
        # neither can be mistaken for a stray lambda and clobbered.
        instance_attrs = vars(child)
        for name in attr_names:
            candidate = instance_attrs.get(name)
            if callable(candidate) and not isinstance(candidate, nn.Module):
                setattr(child, name, nn.Identity())
| |
class ResEnc(network.Network):
    """2-D residual-encoder U-Net exposed as a KonfAI ``Network``.

    The model wraps a single ``ResidualEncoderUNet`` backbone. The backbone is
    registered under the module name ``"DecoderOutputs"``. After registration,
    any lambda-valued ``skip`` / ``nonlin2`` attributes inside it are replaced
    with ``nn.Identity`` so the model stays picklable.

    Args:
        optimizer: Optimizer loader forwarded to ``network.Network``.
        schedulers: Mapping of scheduler keys to LR-scheduler loaders,
            forwarded to ``network.Network``.
        outputs_criterions: Mapping of output names to criterion loaders,
            forwarded to ``network.Network``.
        in_channels: Number of input channels, used by both the base class and
            the backbone.
        nb_class: Number of output channels (``num_classes``) of the backbone.
    """

    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        # NOTE(review): these defaults are evaluated once and shared between
        # instances. They are presumably kept as literals because KonfAI reads
        # constructor defaults to build its configuration — confirm before
        # switching to ``None`` sentinels.
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        in_channels: int = 5,
        nb_class: int = 132,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            in_channels=in_channels,
            dim=2,
        )

        # Six encoder stages: every per-stage tuple has six entries, and the
        # decoder tuple has one fewer (one entry per upsampling step).
        backbone_config = {
            "input_channels": in_channels,
            "n_stages": 6,
            "features_per_stage": (24, 48, 96, 192, 256, 256),
            "conv_op": nn.Conv2d,
            "kernel_sizes": 3,
            "strides": (1, 2, 2, 2, 2, 2),
            "n_blocks_per_stage": (1, 2, 2, 3, 3, 3),
            "num_classes": nb_class,
            "n_conv_per_stage_decoder": (1, 1, 1, 1, 1),
            "conv_bias": True,
            "norm_op": nn.InstanceNorm2d,
            "norm_op_kwargs": {"eps": 1e-5, "affine": True},
            "dropout_op": None,
            "dropout_op_kwargs": None,
            "nonlin": nn.LeakyReLU,
            "nonlin_kwargs": {"inplace": True},
            "deep_supervision": False,
        }
        self.add_module("DecoderOutputs", ResidualEncoderUNet(**backbone_config))

        # Patch the registered backbone (looked up by name after registration)
        # so it no longer carries unpicklable identity lambdas.
        _replace_unpicklable_identities(self.DecoderOutputs)