ImpactSynth / MR /Model.py
VBoussot's picture
Update MR/Model.py
86404cd verified
from konfai.network import network
import segmentation_models_pytorch as smp
import torch
class Head(network.ModuleArgsDict):
    """Output head made of a single ``Tanh`` activation.

    The activation is registered under the module name ``"Tanh"``.
    """

    def __init__(self):
        super().__init__()
        # Tanh squashes its input element-wise into the open interval (-1, 1).
        activation = torch.nn.Tanh()
        self.add_module("Tanh", activation)
class UNetpp(network.Network):
    """2D UNet++ (ResNet-34 encoder) followed by a ``Tanh`` head.

    Two sub-modules are registered, in this order:

    * ``"model"`` -- ``smp.UnetPlusPlus`` with 3 input channels, 1 output
      channel and no built-in activation.
    * ``"Head"`` -- the ``Head`` module.

    NOTE(review): presumably the base network runs the registered modules in
    registration order, so ``Head`` is applied to the UNet++ output -- confirm
    against ``network.Network``.

    Args:
        optimizer: Optimizer loader forwarded to the base ``Network``.
        schedulers: Learning-rate scheduler loaders keyed by name, forwarded
            to the base ``Network``.
        outputs_criterions: Criterion loaders keyed by output name, forwarded
            to the base ``Network``.
        pretrained: When true the encoder is created with ``"imagenet"``
            weights; otherwise no pretrained encoder weights are requested.
    """

    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {
            "default": network.TargetCriterionsLoader()
        },
        pretrained: bool = False,
    ):
        # The same channel count is given to the base network and to the
        # UNet++ backbone.
        input_channels = 3

        super().__init__(
            in_channels=input_channels,
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            dim=2,
        )

        # Encoder weights are only requested when explicitly asked for.
        weights_name = "imagenet" if pretrained else None

        backbone = smp.UnetPlusPlus(
            encoder_name="resnet34",
            encoder_weights=weights_name,
            in_channels=input_channels,
            classes=1,
            activation=None,  # the activation is provided by the "Head" module
        )
        self.add_module("model", backbone)
        self.add_module("Head", Head())