import torch
import torch.nn as nn

from .ProbUNet_utils import make_onehot as make_onehot_segmentation, make_slices, match_to
|
|
|
|
def is_conv(op):
    """Return True if op is a convolution class or instance (1D/2D/3D, regular or transposed)."""
    conv_types = (nn.Conv1d,
                  nn.Conv2d,
                  nn.Conv3d,
                  nn.ConvTranspose1d,
                  nn.ConvTranspose2d,
                  nn.ConvTranspose3d)
    if isinstance(op, type):
        return issubclass(op, conv_types)
    return isinstance(op, conv_types)
|
|
|
|
|
|
class ConvModule(nn.Module):
    """Base class that adds weight and bias initialization helpers for convolutional layers."""

    def __init__(self, *args, **kwargs):
        super(ConvModule, self).__init__()

    def init_weights(self, init_fn, *args, **kwargs):
        # init_fn is expected to modify and return the weight tensor
        # (e.g. the in-place nn.init functions).

        class init_(object):

            def __init__(self):
                self.fn = init_fn
                self.args = args
                self.kwargs = kwargs

            def __call__(self, module):
                if is_conv(type(module)):
                    module.weight = self.fn(module.weight, *self.args, **self.kwargs)

        _init_ = init_()
        self.apply(_init_)

    def init_bias(self, init_fn, *args, **kwargs):

        class init_(object):

            def __init__(self):
                self.fn = init_fn
                self.args = args
                self.kwargs = kwargs

            def __call__(self, module):
                if is_conv(type(module)) and module.bias is not None:
                    module.bias = self.fn(module.bias, *self.args, **self.kwargs)

        _init_ = init_()
        self.apply(_init_)
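

# Minimal usage sketch for the initialization helpers, assuming init_fn is one
# of the in-place nn.init functions (which return the tensor they modify).
# TinyNet is a hypothetical example module, not part of this library.
def _demo_init_helpers():
    class TinyNet(ConvModule):
        def __init__(self):
            super(TinyNet, self).__init__()
            self.conv = nn.Conv2d(3, 8, 3, padding=1)

    net = TinyNet()
    net.init_weights(nn.init.kaiming_normal_, nonlinearity="leaky_relu")
    net.init_bias(nn.init.constant_, 0.0)
    return net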
|
|
|
|
|
|
class ConcatCoords(nn.Module):
    """Append one normalized coordinate channel per spatial dimension (CoordConv-style)."""

    def forward(self, input_):
        dim = input_.dim() - 2
        coord_channels = []
        for i in range(dim):
            view = [1, ] * dim
            view[i] = -1
            repeat = list(input_.shape[2:])
            repeat[i] = 1
            coord_channels.append(
                torch.linspace(-0.5, 0.5, input_.shape[i+2])
                .view(*view)
                .repeat(*repeat)
                .to(device=input_.device, dtype=input_.dtype))
        coord_channels = torch.stack(coord_channels).unsqueeze(0)
        repeat = [1, ] * input_.dim()
        repeat[0] = input_.shape[0]
        coord_channels = coord_channels.repeat(*repeat).contiguous()

        return torch.cat([input_, coord_channels], 1)
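

# Shape sketch: for a (B, C, H, W) input, ConcatCoords yields (B, C + 2, H, W),
# with the extra channels holding positions in [-0.5, 0.5] along each axis.
def _demo_concat_coords():
    x = torch.randn(2, 3, 16, 16)
    out = ConcatCoords()(x)
    assert out.shape == (2, 5, 16, 16)
    return out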
|
|
|
|
|
|
class InjectionConvEncoder(ConvModule):

    _default_activation_kwargs = dict(inplace=True)
    _default_norm_kwargs = dict()
    _default_conv_kwargs = dict(kernel_size=3, padding=1)
    _default_pool_kwargs = dict(kernel_size=2)
    _default_dropout_kwargs = dict()
    _default_global_pool_kwargs = dict()

    def __init__(self,
                 in_channels=1,
                 out_channels=6,
                 depth=4,
                 injection_depth="last",
                 injection_channels=0,
                 block_depth=2,
                 num_feature_maps=24,
                 feature_map_multiplier=2,
                 activation_op=nn.LeakyReLU,
                 activation_kwargs=None,
                 norm_op=nn.InstanceNorm2d,
                 norm_kwargs=None,
                 norm_depth=0,
                 conv_op=nn.Conv2d,
                 conv_kwargs=None,
                 pool_op=nn.AvgPool2d,
                 pool_kwargs=None,
                 dropout_op=None,
                 dropout_kwargs=None,
                 global_pool_op=nn.AdaptiveAvgPool2d,
                 global_pool_kwargs=None,
                 **kwargs):

        super(InjectionConvEncoder, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.depth = depth
        self.injection_depth = depth - 1 if injection_depth == "last" else injection_depth
        self.injection_channels = injection_channels
        self.block_depth = block_depth
        self.num_feature_maps = num_feature_maps
        self.feature_map_multiplier = feature_map_multiplier

        # Copy the class-level defaults before updating, so that per-instance
        # kwargs don't mutate state shared by all instances.
        self.activation_op = activation_op
        self.activation_kwargs = dict(self._default_activation_kwargs)
        if activation_kwargs is not None:
            self.activation_kwargs.update(activation_kwargs)

        self.norm_op = norm_op
        self.norm_kwargs = dict(self._default_norm_kwargs)
        if norm_kwargs is not None:
            self.norm_kwargs.update(norm_kwargs)
        self.norm_depth = depth if norm_depth == "full" else norm_depth

        self.conv_op = conv_op
        self.conv_kwargs = dict(self._default_conv_kwargs)
        if conv_kwargs is not None:
            self.conv_kwargs.update(conv_kwargs)

        self.pool_op = pool_op
        self.pool_kwargs = dict(self._default_pool_kwargs)
        if pool_kwargs is not None:
            self.pool_kwargs.update(pool_kwargs)

        self.dropout_op = dropout_op
        self.dropout_kwargs = dict(self._default_dropout_kwargs)
        if dropout_kwargs is not None:
            self.dropout_kwargs.update(dropout_kwargs)

        self.global_pool_op = global_pool_op
        self.global_pool_kwargs = dict(self._default_global_pool_kwargs)
        if global_pool_kwargs is not None:
            self.global_pool_kwargs.update(global_pool_kwargs)

        for d in range(self.depth):

            in_ = self.in_channels if d == 0 else self.num_feature_maps * (self.feature_map_multiplier**(d-1))
            out_ = self.num_feature_maps * (self.feature_map_multiplier**d)

            # The block directly after the injection level receives the extra channels.
            if d == self.injection_depth + 1:
                in_ += self.injection_channels

            layers = []
            if d > 0:
                layers.append(self.pool_op(**self.pool_kwargs))
            for b in range(self.block_depth):
                current_in = in_ if b == 0 else out_
                layers.append(self.conv_op(current_in, out_, **self.conv_kwargs))
                if self.norm_op is not None and d < self.norm_depth:
                    layers.append(self.norm_op(out_, **self.norm_kwargs))
                if self.activation_op is not None:
                    layers.append(self.activation_op(**self.activation_kwargs))
                if self.dropout_op is not None:
                    layers.append(self.dropout_op(**self.dropout_kwargs))
            if d == self.depth - 1:
                current_conv_kwargs = self.conv_kwargs.copy()
                current_conv_kwargs["kernel_size"] = 1
                current_conv_kwargs["padding"] = 0
                current_conv_kwargs["bias"] = False
                layers.append(self.conv_op(out_, out_channels, **current_conv_kwargs))

            self.add_module("encode_{}".format(d), nn.Sequential(*layers))

        if self.global_pool_op is not None:
            self.add_module("global_pool", self.global_pool_op(1, **self.global_pool_kwargs))

    def forward(self, x, injection=None):

        for d in range(self.depth):
            x = self._modules["encode_{}".format(d)](x)
            # Concatenate the (broadcast) injection vector after the chosen level.
            if d == self.injection_depth and self.injection_channels > 0:
                injection = match_to(injection, x, self.injection_channels)
                x = torch.cat([x, injection], 1)
        if hasattr(self, "global_pool"):
            x = self.global_pool(x)

        return x
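

# Minimal shape sketch for the encoder with its 2D defaults: a (B, 1, 64, 64)
# image is reduced to a (B, out_channels, 1, 1) code by the global pool. With
# injection_channels > 0, a latent vector would additionally be concatenated
# at the chosen depth via match_to from ProbUNet_utils.
def _demo_injection_conv_encoder():
    enc = InjectionConvEncoder(in_channels=1, out_channels=6, depth=4)
    x = torch.randn(2, 1, 64, 64)
    code = enc(x)
    assert code.shape == (2, 6, 1, 1)
    return code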
|
|
|
|
class InjectionConvEncoder3D(InjectionConvEncoder):

    def __init__(self, *args, **kwargs):

        update_kwargs = dict(
            norm_op=nn.InstanceNorm3d,
            conv_op=nn.Conv3d,
            pool_op=nn.AvgPool3d,
            global_pool_op=nn.AdaptiveAvgPool3d
        )

        for (arg, val) in update_kwargs.items():
            kwargs.setdefault(arg, val)

        super(InjectionConvEncoder3D, self).__init__(*args, **kwargs)


class InjectionConvEncoder2D(InjectionConvEncoder):

    def __init__(self, *args, **kwargs):

        update_kwargs = dict(
            norm_op=nn.InstanceNorm2d,
            conv_op=nn.Conv2d,
            pool_op=nn.AvgPool2d,
            global_pool_op=nn.AdaptiveAvgPool2d
        )

        for (arg, val) in update_kwargs.items():
            kwargs.setdefault(arg, val)

        super(InjectionConvEncoder2D, self).__init__(*args, **kwargs)
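

# The 3D variant only swaps the 2D ops for their 3D counterparts, so a
# (B, 1, 32, 32, 32) volume yields a (B, out_channels, 1, 1, 1) code.
def _demo_injection_conv_encoder_3d():
    enc = InjectionConvEncoder3D(in_channels=1, out_channels=6, depth=4)
    x = torch.randn(1, 1, 32, 32, 32)
    code = enc(x)
    assert code.shape == (1, 6, 1, 1, 1)
    return code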
|
|
class InjectionUNet(ConvModule):

    def __init__(
            self,
            depth=5,
            in_channels=4,
            out_channels=4,
            kernel_size=3,
            dilation=1,
            num_feature_maps=24,
            block_depth=2,
            num_1x1_at_end=3,
            injection_channels=3,
            injection_at="end",
            activation_op=nn.LeakyReLU,
            activation_kwargs=None,
            pool_op=nn.AvgPool2d,
            pool_kwargs=dict(kernel_size=2),
            dropout_op=None,
            dropout_kwargs=None,
            norm_op=nn.InstanceNorm2d,
            norm_kwargs=None,
            conv_op=nn.Conv2d,
            conv_kwargs=None,
            upconv_op=nn.ConvTranspose2d,
            upconv_kwargs=None,
            output_activation_op=None,
            output_activation_kwargs=None,
            return_bottom=False,
            coords=False,
            coords_dim=2,
            **kwargs
    ):

        super(InjectionUNet, self).__init__(**kwargs)

        self.depth = depth
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = (self.kernel_size + (self.kernel_size-1) * (self.dilation-1)) // 2
        self.num_feature_maps = num_feature_maps
        self.block_depth = block_depth
        self.num_1x1_at_end = num_1x1_at_end
        self.injection_channels = injection_channels
        self.injection_at = injection_at
        self.activation_op = activation_op
        self.activation_kwargs = {} if activation_kwargs is None else activation_kwargs
        self.pool_op = pool_op
        self.pool_kwargs = {} if pool_kwargs is None else pool_kwargs
        self.dropout_op = dropout_op
        self.dropout_kwargs = {} if dropout_kwargs is None else dropout_kwargs
        self.norm_op = norm_op
        self.norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv_op = conv_op
        self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs
        self.upconv_op = upconv_op
        self.upconv_kwargs = {} if upconv_kwargs is None else upconv_kwargs
        self.output_activation_op = output_activation_op
        self.output_activation_kwargs = {} if output_activation_kwargs is None else output_activation_kwargs
        self.return_bottom = return_bottom
        # self.coords is a pair of lists: the depths at which coordinate
        # channels are concatenated in the encoder and decoder, respectively.
        if not coords:
            self.coords = [[], []]
        elif coords is True:
            self.coords = [list(range(depth)), []]
        else:
            self.coords = coords
        self.coords_dim = coords_dim

        self.last_activations = None
|
|
        # Build the encoder path.
        for d in range(self.depth):

            block = []
            if d > 0:
                block.append(self.pool_op(**self.pool_kwargs))

            for i in range(self.block_depth):

                # The bottom level uses a single conv block.
                if d == self.depth - 1 and i > 0:
                    continue

                out_size = self.num_feature_maps * 2**d
                if d == 0 and i == 0:
                    in_size = self.in_channels
                elif i == 0:
                    in_size = self.num_feature_maps * 2**(d - 1)
                else:
                    in_size = out_size

                if d in self.coords[0] and i == 0:
                    block.append(ConcatCoords())
                    in_size += self.coords_dim

                block.append(self.conv_op(in_size,
                                          out_size,
                                          self.kernel_size,
                                          padding=self.padding,
                                          dilation=self.dilation,
                                          **self.conv_kwargs))
                if self.dropout_op is not None:
                    block.append(self.dropout_op(**self.dropout_kwargs))
                if self.norm_op is not None:
                    block.append(self.norm_op(out_size, **self.norm_kwargs))
                block.append(self.activation_op(**self.activation_kwargs))

            self.add_module("encode-{}".format(d), nn.Sequential(*block))

        # Build the decoder path.
        for d in reversed(range(self.depth)):

            block = []

            for i in range(self.block_depth):

                # The bottom level uses a single conv block.
                if d == self.depth - 1 and i > 0:
                    continue

                out_size = self.num_feature_maps * 2**d
                if i == 0 and d < self.depth - 1:
                    in_size = self.num_feature_maps * 2**(d+1)
                elif i == 0 and self.injection_at == "bottom":
                    in_size = out_size + self.injection_channels
                else:
                    in_size = out_size

                # The second list in self.coords holds the decoder depths.
                if d in self.coords[1] and i == 0 and d < self.depth - 1:
                    block.append(ConcatCoords())
                    in_size += self.coords_dim

                block.append(self.conv_op(in_size,
                                          out_size,
                                          self.kernel_size,
                                          padding=self.padding,
                                          dilation=self.dilation,
                                          **self.conv_kwargs))
                if self.dropout_op is not None:
                    block.append(self.dropout_op(**self.dropout_kwargs))
                if self.norm_op is not None:
                    block.append(self.norm_op(out_size, **self.norm_kwargs))
                block.append(self.activation_op(**self.activation_kwargs))

            if d > 0:
                block.append(self.upconv_op(out_size,
                                            out_size // 2,
                                            self.kernel_size,
                                            2,
                                            padding=self.padding,
                                            dilation=self.dilation,
                                            output_padding=1,
                                            **self.upconv_kwargs))

            self.add_module("decode-{}".format(d), nn.Sequential(*block))

        if self.injection_at == "end":
            out_size += self.injection_channels
        in_size = out_size
        for i in range(self.num_1x1_at_end):
            if i == self.num_1x1_at_end - 1:
                out_size = self.out_channels
            current_conv_kwargs = self.conv_kwargs.copy()
            current_conv_kwargs["bias"] = True
            self.add_module("reduce-{}".format(i), self.conv_op(in_size, out_size, 1, **current_conv_kwargs))
            if i != self.num_1x1_at_end - 1:
                self.add_module("reduce-{}-nonlin".format(i), self.activation_op(**self.activation_kwargs))
        if self.output_activation_op is not None:
            self.add_module("output-activation", self.output_activation_op(**self.output_activation_kwargs))
|
|
    def reset(self):

        self.last_activations = None

    def forward(self, x, injection=None, reuse_last_activations=False, store_activations=False):

        if self.injection_at == "bottom":
            reuse_last_activations = False
            store_activations = False

        if self.last_activations is None or reuse_last_activations is False:

            enc = [x]

            for i in range(self.depth - 1):
                enc.append(self._modules["encode-{}".format(i)](enc[-1]))

            bottom_rep = self._modules["encode-{}".format(self.depth - 1)](enc[-1])

            if self.injection_at == "bottom" and self.injection_channels > 0:
                injection = match_to(injection, bottom_rep, (0, 1))
                bottom_rep = torch.cat((bottom_rep, injection), 1)

            x = self._modules["decode-{}".format(self.depth - 1)](bottom_rep)

            # Decode with skip connections from the matching encoder levels.
            for i in reversed(range(self.depth - 1)):
                x = self._modules["decode-{}".format(i)](torch.cat((enc[-(self.depth - 1 - i)], x), 1))

            if store_activations:
                self.last_activations = x.detach()

        else:

            x = self.last_activations

        if self.injection_at == "end" and self.injection_channels > 0:
            injection = match_to(injection, x, (0, 1))
            x = torch.cat((x, injection), 1)

        for i in range(self.num_1x1_at_end):
            x = self._modules["reduce-{}".format(i)](x)
        if self.output_activation_op is not None:
            x = self._modules["output-activation"](x)

        if self.return_bottom and not reuse_last_activations:
            return x, bottom_rep
        else:
            return x
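

# Minimal forward sketch with the 2D defaults: the latent injection vector is
# broadcast by match_to and concatenated just before the final 1x1 convolutions
# (injection_at="end"). The (1, 3, 1, 1) sample mimics the pooled latent code
# produced by InjectionConvEncoder; input size must be divisible by
# 2**(depth - 1).
def _demo_injection_unet():
    unet = InjectionUNet(depth=5, in_channels=4, out_channels=4, injection_channels=3)
    x = torch.randn(1, 4, 32, 32)
    z = torch.randn(1, 3, 1, 1)
    return unet(x, injection=z)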
|
|
|
|
|
|
class InjectionUNet3D(InjectionUNet):

    def __init__(self, *args, **kwargs):

        update_kwargs = dict(
            pool_op=nn.AvgPool3d,
            norm_op=nn.InstanceNorm3d,
            conv_op=nn.Conv3d,
            upconv_op=nn.ConvTranspose3d,
            coords_dim=3
        )

        for (arg, val) in update_kwargs.items():
            kwargs.setdefault(arg, val)

        super(InjectionUNet3D, self).__init__(*args, **kwargs)


class InjectionUNet2D(InjectionUNet):

    def __init__(self, *args, **kwargs):

        update_kwargs = dict(
            pool_op=nn.AvgPool2d,
            norm_op=nn.InstanceNorm2d,
            conv_op=nn.Conv2d,
            upconv_op=nn.ConvTranspose2d,
            coords_dim=2
        )

        for (arg, val) in update_kwargs.items():
            kwargs.setdefault(arg, val)

        super(InjectionUNet2D, self).__init__(*args, **kwargs)
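

# Sketch of the activation-caching path used at sampling time: with
# injection_at="end", the U-Net trunk can be run once (store_activations=True)
# and further latent samples only pay for the final 1x1 layers.
def _demo_unet_activation_reuse():
    unet = InjectionUNet2D(depth=4, in_channels=1, out_channels=2, injection_channels=3)
    x = torch.randn(1, 1, 16, 16)
    y1 = unet(x, injection=torch.randn(1, 3, 1, 1), store_activations=True)
    y2 = unet(x, injection=torch.randn(1, 3, 1, 1), reuse_last_activations=True)
    unet.reset()
    return y1, y2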
|
|
class ProbabilisticSegmentationNet(ConvModule):

    def __init__(self,
                 in_channels=4,
                 out_channels=4,
                 num_feature_maps=24,
                 latent_size=3,
                 depth=5,
                 latent_distribution=torch.distributions.Normal,
                 task_op=InjectionUNet3D,
                 task_kwargs=None,
                 prior_op=InjectionConvEncoder3D,
                 prior_kwargs=None,
                 posterior_op=InjectionConvEncoder3D,
                 posterior_kwargs=None,
                 **kwargs):

        super(ProbabilisticSegmentationNet, self).__init__(**kwargs)

        self.task_op = task_op
        self.task_kwargs = {} if task_kwargs is None else task_kwargs
        self.prior_op = prior_op
        self.prior_kwargs = {} if prior_kwargs is None else prior_kwargs
        self.posterior_op = posterior_op
        self.posterior_kwargs = {} if posterior_kwargs is None else posterior_kwargs

        default_task_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            num_feature_maps=num_feature_maps,
            injection_channels=latent_size,
            depth=depth
        )

        # Prior and posterior encoders output mean and log-variance stacked
        # along the channel axis, hence latent_size * 2 output channels.
        default_prior_kwargs = dict(
            in_channels=in_channels,
            out_channels=latent_size*2,
            num_feature_maps=num_feature_maps,
            depth=depth
        )

        default_posterior_kwargs = dict(
            in_channels=in_channels+out_channels,
            out_channels=latent_size*2,
            num_feature_maps=num_feature_maps,
            depth=depth
        )

        default_task_kwargs.update(self.task_kwargs)
        self.task_kwargs = default_task_kwargs
        default_prior_kwargs.update(self.prior_kwargs)
        self.prior_kwargs = default_prior_kwargs
        default_posterior_kwargs.update(self.posterior_kwargs)
        self.posterior_kwargs = default_posterior_kwargs

        self.latent_distribution = latent_distribution
        self._prior = None
        self._posterior = None

        self.make_modules()
|
|
    def make_modules(self):

        # Ops may be passed either as classes (instantiated here from their
        # kwargs) or as ready-made modules.
        if isinstance(self.task_op, type):
            self.add_module("task_net", self.task_op(**self.task_kwargs))
        else:
            self.add_module("task_net", self.task_op)
        if isinstance(self.prior_op, type):
            self.add_module("prior_net", self.prior_op(**self.prior_kwargs))
        else:
            self.add_module("prior_net", self.prior_op)
        if isinstance(self.posterior_op, type):
            self.add_module("posterior_net", self.posterior_op(**self.posterior_kwargs))
        else:
            self.add_module("posterior_net", self.posterior_op)

    @property
    def prior(self):
        return self._prior

    @property
    def posterior(self):
        return self._posterior

    @property
    def last_activations(self):
        return self.task_net.last_activations

    def train(self, mode=True):

        super(ProbabilisticSegmentationNet, self).train(mode)
        self.reset()

    def reset(self):

        self.task_net.reset()
        self._prior = None
        self._posterior = None
|
|
    def forward(self, input_, seg=None, make_onehot=True, make_onehot_classes=None, newaxis=False, distlossN=0):
        """Forward pass uses reparametrized posterior samples during training; otherwise it takes the prior mean."""

        self.encode_prior(input_)

        if distlossN == 0:
            if self.training:
                self.encode_posterior(input_, seg, make_onehot, make_onehot_classes, newaxis)
                sample = self.posterior.rsample()
            else:
                sample = self.prior.loc
            return self.task_net(input_, sample, store_activations=not self.training)
        else:
            # Draw distlossN posterior samples and return one prediction each.
            if self.training:
                self.encode_posterior(input_, seg, make_onehot, make_onehot_classes, newaxis)
                segs = []
                for i in range(distlossN):
                    sample = self.posterior.rsample()
                    segs.append(self.task_net(input_, sample, store_activations=not self.training))
                return segs
            else:
                sample = self.prior.loc
                return self.task_net(input_, sample, store_activations=not self.training)
|
|
|
    def encode_prior(self, input_):

        rep = self.prior_net(input_)
        # The encoder returns either a (mean, logvar) tuple or a single tensor
        # with both stacked along the channel axis.
        if isinstance(rep, tuple):
            mean, logvar = rep
        elif torch.is_tensor(rep):
            mean, logvar = torch.split(rep, rep.shape[1] // 2, dim=1)
        self._prior = self.latent_distribution(mean, logvar.mul(0.5).exp())
        return self._prior

    def encode_posterior(self, input_, seg, make_onehot=True, make_onehot_classes=None, newaxis=False):

        if make_onehot:
            if make_onehot_classes is None:
                make_onehot_classes = tuple(range(self.posterior_net.in_channels - input_.shape[1]))
            seg = make_onehot_segmentation(seg, make_onehot_classes, newaxis=newaxis)
        rep = self.posterior_net(torch.cat((input_, seg.float()), 1))
        if isinstance(rep, tuple):
            mean, logvar = rep
        elif torch.is_tensor(rep):
            mean, logvar = torch.split(rep, rep.shape[1] // 2, dim=1)
        self._posterior = self.latent_distribution(mean, logvar.mul(0.5).exp())
        return self._posterior
|
|
    def sample_prior(self, N=1, out_device=None, input_=None, pred_with_mean=False):
        """Draw multiple samples from the current prior.

        * input_ is required if no activations are stored in task_net.
        * If input_ is given, the prior will automatically be encoded again.
        * Returns either a single sample or a list of samples.

        """

        if out_device is None:
            if self.last_activations is not None:
                out_device = self.last_activations.device
            elif input_ is not None:
                out_device = input_.device
            else:
                out_device = next(self.task_net.parameters()).device
        with torch.no_grad():
            if self.prior is None or input_ is not None:
                self.encode_prior(input_)
            result = []

            # The first sample stores the trunk activations so that further
            # samples only need to rerun the final 1x1 layers.
            if input_ is not None:
                result.append(self.task_net(input_, self.prior.sample(), reuse_last_activations=False, store_activations=True).to(device=out_device))
            while len(result) < N:
                result.append(self.task_net(input_,
                                            self.prior.sample(),
                                            reuse_last_activations=self.last_activations is not None,
                                            store_activations=False).to(device=out_device))
            if pred_with_mean:
                result.append(self.task_net(input_, self.prior.mean, reuse_last_activations=False, store_activations=True).to(device=out_device))

        if len(result) == 1:
            return result[0]
        else:
            return result
|
|
    def reconstruct(self, sample=None, use_posterior_mean=True, out_device=None, input_=None):
        """Reconstruct from a given sample or the current posterior mean. Will not compute gradients!"""

        if self.posterior is None and sample is None:
            raise ValueError("'posterior' is currently None. Please pass an input and a segmentation first.")
        if out_device is None:
            out_device = next(self.task_net.parameters()).device
        if sample is None:
            if use_posterior_mean:
                sample = self.posterior.loc
            else:
                sample = self.posterior.sample()
        else:
            sample = sample.to(next(self.task_net.parameters()).device)
        with torch.no_grad():
            return self.task_net(input_, sample, reuse_last_activations=True).to(device=out_device)

    def kl_divergence(self):
        """Compute the current KL divergence; requires existing prior and posterior."""

        if self.posterior is None or self.prior is None:
            raise ValueError("'prior' and 'posterior' must not be None, but prior={} and posterior={}".format(self.prior, self.posterior))
        return torch.distributions.kl_divergence(self.posterior, self.prior).sum()
|
|
    def elbo(self, seg, input_=None, nll_reduction="sum", beta=1.0, make_onehot=True, make_onehot_classes=None, newaxis=False):
        """Compute the ELBO with seg as ground truth; beta weights the reconstruction (NLL) term.

        * The prior is expected to exist and will not be encoded.
        * If input_ is given, the posterior will automatically be encoded.
        * Either input_ or stored activations must be available.
        * The reconstruction term uses nn.NLLLoss, so the task net should
          output log-probabilities (e.g. output_activation_op=nn.LogSoftmax).

        """

        if self.last_activations is None:
            raise ValueError("'last_activations' is currently None. Please pass an input first.")
        if input_ is not None:
            with torch.no_grad():
                self.encode_posterior(input_, seg, make_onehot=make_onehot, make_onehot_classes=make_onehot_classes, newaxis=newaxis)
        # NLLLoss expects class indices, so undo any one-hot / channel encoding.
        if make_onehot and newaxis:
            pass
        elif make_onehot and not newaxis:
            seg = seg[:, 0]
        else:
            seg = torch.argmax(seg, 1, keepdim=False)
        kl = self.kl_divergence()
        nll = nn.NLLLoss(reduction=nll_reduction)(self.reconstruct(sample=None, use_posterior_mean=True, out_device=None), seg.long())
        return - (beta * nll + kl)
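

# End-to-end sketch (2D for speed): one training step and one test-time
# sampling call. Shapes assume make_onehot_segmentation and match_to from
# ProbUNet_utils broadcast as used above. The task net gets a LogSoftmax
# output so that nn.NLLLoss receives log-probabilities. Since reconstruct()
# (and therefore elbo()) runs under torch.no_grad(), the differentiable
# reconstruction term is taken from the forward pass instead.
def _demo_prob_unet_training_step():
    net = ProbabilisticSegmentationNet(
        in_channels=1, out_channels=2, latent_size=3, depth=4,
        task_op=InjectionUNet2D,
        task_kwargs=dict(output_activation_op=nn.LogSoftmax,
                         output_activation_kwargs=dict(dim=1)),
        prior_op=InjectionConvEncoder2D,
        posterior_op=InjectionConvEncoder2D)
    x = torch.randn(1, 1, 32, 32)
    seg = torch.randint(0, 2, (1, 1, 32, 32))

    net.train()
    opt = torch.optim.Adam(net.parameters(), lr=1e-4)
    pred = net(x, seg)                                   # posterior sample -> task net
    nll = nn.NLLLoss(reduction="sum")(pred, seg[:, 0].long())
    loss = nll + net.kl_divergence()                     # negative ELBO (beta = 1)
    opt.zero_grad()
    loss.backward()
    opt.step()

    net.eval()
    samples = net.sample_prior(N=4, input_=x)            # 4 plausible segmentations
    return samples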