from konfai.network import network import segmentation_models_pytorch as smp import torch class Head(network.ModuleArgsDict): def __init__(self): super().__init__() self.add_module("Tanh", torch.nn.Tanh()) class UNetpp(network.Network): def __init__(self, optimizer : network.OptimizerLoader = network.OptimizerLoader(), schedulers: dict[str, network.LRSchedulersLoader] = { "default:ReduceLROnPlateau": network.LRSchedulersLoader(0) }, outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()}, pretrained: bool = False): super().__init__(in_channels = 3, optimizer = optimizer, schedulers = schedulers, outputs_criterions = outputs_criterions, dim = 2) self.add_module("model", smp.UnetPlusPlus( encoder_name="resnet34", encoder_weights=None if not pretrained else "imagenet", in_channels=3, classes=1, activation=None )) self.add_module("Head", Head())