|
|
| import sys |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from transformers import PreTrainedModel |
|
|
| from .ProbUNet_model import InjectionConvEncoder2D, InjectionUNet2D, InjectionConvEncoder3D, InjectionUNet3D, ProbabilisticSegmentationNet |
| from .PULASkiConfigs import ProbUNetConfig |
|
|
class ProbUNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ProbabilisticSegmentationNet``.

    The wrapped network is built from a ``ProbUNetConfig`` in ``__init__`` and
    stored as ``self.model``; ``forward`` delegates to it.
    """

    config_class = ProbUNetConfig

    def __init__(self, config):
        """Build the probabilistic segmentation network described by *config*.

        Args:
            config: A ``ProbUNetConfig``. The fields read here are ``dim``,
                ``latent_distribution``, ``in_channels``, ``out_channels``,
                ``num_feature_maps``, ``latent_size``, ``depth``,
                ``no_outact_op`` and ``prob_injection_at``.

        Raises:
            ValueError: If ``config.dim`` is neither 2 nor 3, or if
                ``config.latent_distribution`` is not ``"normal"``.
        """
        super().__init__(config)

        # Pick the 2D or 3D variants of the task network and the prior/posterior
        # encoders. Raise (rather than sys.exit) so that a bad config surfaces
        # as a catchable error instead of terminating the interpreter.
        if config.dim == 2:
            task_op = InjectionUNet2D
            prior_op = InjectionConvEncoder2D
            posterior_op = InjectionConvEncoder2D
        elif config.dim == 3:
            task_op = InjectionUNet3D
            prior_op = InjectionConvEncoder3D
            posterior_op = InjectionConvEncoder3D
        else:
            raise ValueError("Invalid dim! Only configured for dim 2 and 3.")

        if config.latent_distribution == "normal":
            latent_distribution = torch.distributions.Normal
        else:
            raise ValueError("Invalid latent_distribution. Only normal has been implemented.")

        # Each *_kwargs gets its own fresh dicts so the sub-networks never share
        # a mutable kwargs object.
        self.model = ProbabilisticSegmentationNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            num_feature_maps=config.num_feature_maps,
            latent_size=config.latent_size,
            depth=config.depth,
            latent_distribution=latent_distribution,
            task_op=task_op,
            task_kwargs={
                # no_outact_op=True leaves the output un-activated (identity);
                # otherwise a sigmoid is applied.
                "output_activation_op": nn.Identity if config.no_outact_op else nn.Sigmoid,
                "activation_kwargs": {"inplace": True},
                "injection_at": config.prob_injection_at,
            },
            prior_op=prior_op,
            prior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
            posterior_op=posterior_op,
            posterior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
        )

    def forward(self, x, *args, **kwargs):
        """Run the wrapped ``ProbabilisticSegmentationNet`` on *x*.

        Any extra positional or keyword arguments are passed through unchanged
        to ``self.model``. Which extra arguments it accepts is defined by
        ``ProbabilisticSegmentationNet.forward``, which is not shown here.
        Calling with only *x* behaves exactly as before.

        Returns:
            Whatever ``self.model(x, ...)`` returns.
        """
        return self.model(x, *args, **kwargs)