# ImpactSeg / body / Model.py
# Author: VBoussot
# Commit 982b8c4: "Update body model" (2.17 kB)
from torch import nn
from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from konfai.network import network
def _replace_unpicklable_identities(
    module: nn.Module,
    attr_names: tuple[str, ...] = ("skip", "nonlin2"),
) -> None:
    """Replace library lambdas such as ``lambda x: x`` with ``nn.Identity``.

    Some building blocks store plain callables (lambdas) as attributes, e.g.
    ``skip`` or ``nonlin2``. Lambdas cannot be pickled, so a model holding them
    cannot be serialized with ``pickle`` (e.g. a whole-module ``torch.save``).
    This walks *module* and every submodule and, for each name in
    *attr_names*, replaces a callable attribute that is not already an
    ``nn.Module`` with a fresh ``nn.Identity()``. Non-callable attributes and
    attributes that are already modules are left untouched.

    NOTE(review): the replacement assumes those callables are identity
    functions. A non-identity callable stored under one of these names would
    silently become a pass-through. Confirm against the library version in use.

    Args:
        module: Root module, patched in place.
        attr_names: Attribute names to inspect on every submodule. Defaults
            to ``("skip", "nonlin2")``.
    """
    # Materialize the traversal first. Assigning an nn.Module attribute
    # registers it in ``child._modules``, so mutating while the lazy
    # ``modules()`` generator is still running would rely on its internal
    # ordering.
    for child in list(module.modules()):
        for name in attr_names:
            candidate = getattr(child, name, None)
            if callable(candidate) and not isinstance(candidate, nn.Module):
                # nn.Module.__setattr__ drops the plain attribute and
                # registers the Identity as a proper submodule.
                setattr(child, name, nn.Identity())
class ResEnc(network.Network):
    """2-D residual-encoder U-Net exposed as a KonfAI ``network.Network``.

    The backbone is a ``ResidualEncoderUNet`` with six stages, instance
    normalisation, LeakyReLU activations, no dropout and deep supervision
    disabled. It is registered as the submodule ``"DecoderOutputs"``. Its
    lambda-valued ``skip`` / ``nonlin2`` attributes are then swapped for
    ``nn.Identity`` by ``_replace_unpicklable_identities``.

    Args:
        optimizer: Optimizer loader forwarded to ``network.Network``.
        schedulers: Mapping of learning-rate scheduler loaders, forwarded
            unchanged. The key format (e.g. ``"default:ReduceLROnPlateau"``)
            is interpreted by KonfAI; confirm against its docs.
        outputs_criterions: Mapping of target-criterion loaders, forwarded
            unchanged.
        in_channels: Number of input channels (default 5).
        nb_class: ``num_classes`` of the backbone's output layer (default 132).

    NOTE(review): the default loader instances and dicts are evaluated once,
    at class-definition time, and are shared by every instance. They are kept
    as-is because KonfAI presumably reads these signature defaults for its
    configuration. Confirm before changing them.
    """

    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        in_channels: int = 5,
        nb_class: int = 132,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            dim=2,
        )

        # Per-stage encoder layout; all three tuples have one entry per stage.
        stage_features = (24, 48, 96, 192, 256, 256)
        stage_strides = (1, 2, 2, 2, 2, 2)
        blocks_per_stage = (1, 2, 2, 3, 3, 3)
        # The decoder takes one entry per stage transition (n_stages - 1).
        decoder_convs = (1,) * (len(stage_features) - 1)

        backbone = ResidualEncoderUNet(
            input_channels=in_channels,
            n_stages=len(stage_features),
            features_per_stage=stage_features,
            conv_op=nn.Conv2d,
            kernel_sizes=3,
            strides=stage_strides,
            n_blocks_per_stage=blocks_per_stage,
            num_classes=nb_class,
            n_conv_per_stage_decoder=decoder_convs,
            conv_bias=True,
            norm_op=nn.InstanceNorm2d,
            norm_op_kwargs={"eps": 1e-5, "affine": True},
            dropout_op=None,
            dropout_op_kwargs=None,
            nonlin=nn.LeakyReLU,
            nonlin_kwargs={"inplace": True},
            deep_supervision=False,
        )
        self.add_module("DecoderOutputs", backbone)

        # Patch through the registered attribute, after add_module, in case
        # network.Network.add_module does its own bookkeeping (unverified).
        _replace_unpicklable_identities(self.DecoderOutputs)