File size: 2,171 Bytes
982b8c4
2558e55
 
 
 
 
 
 
 
 
 
 
9d47104
2558e55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d47104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2558e55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from torch import nn
from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from konfai.network import network

def _replace_unpicklable_identities(
    module: nn.Module,
    attr_names: tuple[str, ...] = ("skip", "nonlin2"),
) -> None:
    """Replace library lambdas such as ``lambda x: x`` with ``nn.Identity``.

    Visits ``module`` and every descendant module. For each name in
    ``attr_names``, an attribute that is callable but is not itself an
    ``nn.Module`` (for example a lambda or a plain function) is overwritten
    with a fresh ``nn.Identity()``. Attributes that are missing, not
    callable, or already ``nn.Module`` instances are left untouched.

    The modification is done in place.

    Args:
        module: Root of the module tree to patch.
        attr_names: Attribute names to inspect on every module. The default
            covers the two names this file needs patched.

    Returns:
        None.
    """
    # Snapshot the tree before mutating it: assigning an ``nn.Identity``
    # registers a new submodule on ``child``, and the walk should not depend
    # on how the live ``modules()`` generator reacts to that.
    for child in list(module.modules()):
        for attr_name in attr_names:
            current = getattr(child, attr_name, None)
            if callable(current) and not isinstance(current, nn.Module):
                setattr(child, attr_name, nn.Identity())
    
class ResEnc(network.Network):
    """``network.Network`` wrapper around a 2D ``ResidualEncoderUNet``.

    The wrapped U-Net is registered under the module name
    ``"DecoderOutputs"``. It is built with ``nn.Conv2d`` convolutions,
    affine ``nn.InstanceNorm2d`` normalisation, in-place ``nn.LeakyReLU``
    activations, no dropout and deep supervision disabled. After
    registration, its non-module ``skip`` / ``nonlin2`` callables are
    swapped for ``nn.Identity`` via ``_replace_unpicklable_identities``.
    """

    # NOTE(review): the loader defaults below are created once, when the
    # signature is evaluated, and are therefore shared between instances.
    # They are kept exactly as written because they are part of the public
    # signature — confirm against the framework before changing them.
    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        in_channels: int = 5,
        nb_class: int = 132,
    ) -> None:
        """Build the network.

        Args:
            optimizer: Optimizer loader forwarded to the base class.
            schedulers: Learning-rate scheduler loaders forwarded to the
                base class.
            outputs_criterions: Criterion loaders forwarded to the base
                class.
            in_channels: Number of input channels; passed both to the base
                class and to the U-Net.
            nb_class: Number of output classes of the U-Net.
        """
        # dim=2 matches the 2D convolution / normalisation ops used below.
        super().__init__(
            dim=2,
            in_channels=in_channels,
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
        )

        # One entry per encoder stage; the stage count is derived from it.
        stage_features = (24, 48, 96, 192, 256, 256)

        unet_settings = {
            "input_channels": in_channels,
            "n_stages": len(stage_features),
            "features_per_stage": stage_features,
            "conv_op": nn.Conv2d,
            "kernel_sizes": 3,
            "strides": (1, 2, 2, 2, 2, 2),
            "n_blocks_per_stage": (1, 2, 2, 3, 3, 3),
            "num_classes": nb_class,
            "n_conv_per_stage_decoder": (1, 1, 1, 1, 1),
            "conv_bias": True,
            "norm_op": nn.InstanceNorm2d,
            "norm_op_kwargs": {"eps": 1e-5, "affine": True},
            "dropout_op": None,
            "dropout_op_kwargs": None,
            "nonlin": nn.LeakyReLU,
            "nonlin_kwargs": {"inplace": True},
            "deep_supervision": False,
        }
        self.add_module("DecoderOutputs", ResidualEncoderUNet(**unet_settings))

        # Patch the registered U-Net in place once it is attached.
        _replace_unpicklable_identities(self.DecoderOutputs)