| from konfai.network import network |
| import segmentation_models_pytorch as smp |
| import torch |
|
|
class Head(network.ModuleArgsDict):
    """Module container that holds a single hyperbolic-tangent layer.

    The constructor registers one ``torch.nn.Tanh`` instance under the
    name ``"Tanh"``; no other module is added here.
    """

    def __init__(self):
        """Initialise the container and register the ``"Tanh"`` module."""
        super().__init__()
        tanh_layer = torch.nn.Tanh()
        self.add_module("Tanh", tanh_layer)
|
|
class UNetpp(network.Network):
    """Network made of an ``smp.UnetPlusPlus`` model followed by a ``Head``.

    Two modules are registered, in this order:

    * ``"model"`` -- ``smp.UnetPlusPlus`` with a ``resnet34`` encoder,
      ``encoder_weights=None``, ``nb_channel`` input channels, one output
      class and no built-in activation.
    * ``"Head"`` -- a ``Head`` instance.
    """

    # NOTE(review): the default values below are evaluated once, when the
    # class is defined, so the same loader/dict objects are shared by every
    # call that relies on them. They are kept as-is because the signature is
    # this block's interface -- confirm whether the framework reads these
    # defaults before changing them.
    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {
            "default": network.TargetCriterionsLoader()
        },
        nb_channel: int = 5,
    ):
        """Configure the base network and register the two sub-modules.

        Args:
            optimizer: Forwarded unchanged to the base class constructor.
            schedulers: Forwarded unchanged to the base class constructor.
            outputs_criterions: Forwarded unchanged to the base class
                constructor.
            nb_channel: Number of input channels; passed to the base class
                as ``in_channels`` and to ``smp.UnetPlusPlus`` as
                ``in_channels``.
        """
        # The base class is always given dim=2 (presumably "2D data" --
        # TODO confirm against network.Network).
        base_settings = {
            "in_channels": nb_channel,
            "optimizer": optimizer,
            "schedulers": schedulers,
            "outputs_criterions": outputs_criterions,
            "dim": 2,
        }
        super().__init__(**base_settings)

        backbone = smp.UnetPlusPlus(
            encoder_name="resnet34",
            encoder_weights=None,
            in_channels=nb_channel,
            classes=1,
            activation=None,
        )
        self.add_module("model", backbone)
        self.add_module("Head", Head())