| | |
| | import inspect |
| | from enum import Enum |
| | from functools import partial |
| | from typing import Optional, Sequence, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch.nn as nn |
| |
|
| | |
| |
|
| |
|
def assert_(condition, message="", exception_type=AssertionError):
    """Like ``assert``, but raises an arbitrary exception type."""
    if condition:
        return
    raise exception_type(message)
| |
|
| |
|
| | |
# Registry of dynamically generated conv classes, keyed by class name
# (filled by register_partial_cls / the _register_*_cls helpers below).
generated_inferno_classes = {}
| |
|
| |
|
def partial_cls(base_cls, name, fix=None, default=None):
    """Create a subclass of ``base_cls`` named ``name`` with a specialized
    ``__init__``: arguments in ``fix`` are hard-wired (passing them raises
    ``TypeError``), arguments in ``default`` get a new default value but stay
    overridable.

    Parameters
    ----------
    base_cls : type
        Class to specialize. Its ``__init__`` must use only plain
        positional/keyword parameters (no ``*args``, ``**kwargs`` or
        keyword-only arguments).
    name : str
        Name of the generated class.
    fix : dict, optional
        Keyword -> value pairs that are fixed in the subclass.
    default : dict, optional
        Keyword -> value pairs used when the caller omits them.

    Returns
    -------
    type
        The generated subclass of ``base_cls``.

    Raises
    ------
    TypeError
        If ``fix`` and ``default`` share keys, or ``base_cls.__init__`` has an
        unsupported signature.
    """

    def insert_if_not_present(dict_a, dict_b):
        # Copy entries of dict_b into dict_a unless the key is already set.
        for kw, val in dict_b.items():
            if kw not in dict_a:
                dict_a[kw] = val
        return dict_a

    def insert_call_if_present(dict_a, dict_b, callback):
        # Like insert_if_not_present, but report key clashes via callback.
        for kw, val in dict_b.items():
            if kw not in dict_a:
                dict_a[kw] = val
            else:
                callback(kw)
        return dict_a

    class PartialCls(object):
        def __init__(self, base_cls, name, fix=None, default=None):
            self.base_cls = base_cls
            self.name = name
            self.fix = {} if fix is None else fix
            self.default = {} if default is None else default

            if self.fix.keys() & self.default.keys():
                raise TypeError("fix and default share keys")

            # Keyword names callers of the generated class may still pass.
            self._allowed_kw = self._get_allowed_kw()

        def _get_allowed_kw(self):
            # Validate base_cls.__init__'s signature and collect the parameter
            # names that remain user-settable (everything not fixed).
            argspec = inspect.getfullargspec(self.base_cls.__init__)

            if argspec.varargs is not None:
                raise TypeError(
                    "partial_cls can only be used if __init__ has no varargs"
                )
            if argspec.varkw is not None:
                raise TypeError("partial_cls can only be used if __init__ has no varkw")
            if argspec.kwonlyargs:
                raise TypeError("partial_cls can only be used without kwonlyargs")

            args = argspec.args
            if args is None or len(args) < 1:
                raise TypeError("seems like self is missing")

            # args[0] is `self`; fixed arguments may not be passed again.
            return [kw for kw in args[1:] if kw not in self.fix]

        def _build_kw(self, args, kwargs):
            """Merge positional args, kwargs, defaults and fixed values into a
            single keyword dict for base_cls.__init__."""
            if len(args) > len(self._allowed_kw):
                # BUGFIX: error message previously read "to many arguments".
                raise TypeError("too many arguments")

            # Map positional arguments onto the allowed keyword names in order.
            all_args = dict(zip(self._allowed_kw, args))

            # Fixed arguments must not be supplied by the caller.
            clash = self.fix.keys() & kwargs.keys()
            if clash:
                kw = next(iter(clash))
                raise TypeError(
                    "`{}.__init__` got unexpected keyword argument '{}'".format(
                        self.name, kw
                    )
                )

            def raise_duplicate(kw):
                raise TypeError(
                    "{}.__init__ got multiple values for argument '{}'".format(
                        self.name, kw
                    )
                )

            all_args = insert_call_if_present(all_args, kwargs, raise_duplicate)

            def raise_fixed(kw):
                # BUGFIX: previously raised a bare TypeError() with no message.
                raise TypeError(
                    "{}.__init__ cannot override fixed argument '{}'".format(
                        self.name, kw
                    )
                )

            all_args = insert_call_if_present(all_args, self.fix, raise_fixed)

            # Defaults fill remaining gaps; fixed values always win.
            all_args = insert_if_not_present(all_args, self.default)
            all_args.update(self.fix)

            return all_args

        def build_cls(self):
            """Create the subclass whose __init__ forwards the merged kwargs."""

            def new_init(self_of_new_cls, *args, **kwargs):
                combined_args = self._build_kw(args=args, kwargs=kwargs)
                super(self_of_new_cls.__class__, self_of_new_cls).__init__(
                    **combined_args
                )

            return type(self.name, (self.base_cls,), {"__init__": new_init})

    return PartialCls(base_cls=base_cls, name=name, fix=fix, default=default).build_cls()
| |
|
| |
|
def register_partial_cls(base_cls, name, fix=None, default=None):
    """Generate a partial class and record it in ``generated_inferno_classes``."""
    new_cls = partial_cls(base_cls=base_cls, name=name, fix=fix, default=default)
    generated_inferno_classes[new_cls.__name__] = new_cls
| |
|
| |
|
class Initializer(object):
    """
    Base class for all initializers.

    Calling an instance on an ``nn.Module`` initializes the module's
    ``weight`` / ``bias`` tensors in place, but only for the layer types
    listed in ``VALID_LAYERS``. Subclasses implement ``call_on_tensor`` (both
    tensors), ``call_on_weight`` and/or ``call_on_bias``.
    """

    # Module class names whose parameters this initializer may touch.
    VALID_LAYERS = {
        "Conv1d",
        "Conv2d",
        "Conv3d",
        "ConvTranspose1d",
        "ConvTranspose2d",
        "ConvTranspose3d",
        "Linear",
        "Bilinear",
        "Embedding",
    }

    @staticmethod
    def _init_tensor_attribute(module, attr_name, init_fn):
        # A NotImplementedError from the hook means the subclass does not
        # handle this kind of tensor — leave it untouched.
        try:
            if hasattr(module, attr_name):
                init_fn(getattr(module, attr_name).data)
        except NotImplementedError:
            pass

    def __call__(self, module):
        if module.__class__.__name__ in self.VALID_LAYERS:
            self._init_tensor_attribute(module, "weight", self.call_on_weight)
            self._init_tensor_attribute(module, "bias", self.call_on_bias)
        return module

    def call_on_bias(self, tensor):
        return self.call_on_tensor(tensor)

    def call_on_weight(self, tensor):
        return self.call_on_tensor(tensor)

    def call_on_tensor(self, tensor):
        raise NotImplementedError

    @classmethod
    def initializes_weight(cls):
        # True when the subclass provides its own weight (or generic) hook.
        return "call_on_tensor" in cls.__dict__ or "call_on_weight" in cls.__dict__

    @classmethod
    def initializes_bias(cls):
        return "call_on_tensor" in cls.__dict__ or "call_on_bias" in cls.__dict__
| |
|
| |
|
class Initialization(Initializer):
    """Pairs a weight initializer with a bias initializer.

    Each of ``weight_initializer`` / ``bias_initializer`` may be:

    * ``None`` — no-op (plain base ``Initializer``),
    * an ``Initializer`` instance that handles the respective tensor,
    * the name of a function in ``torch.nn.init`` (str),
    * any callable taking a tensor.
    """

    def __init__(self, weight_initializer=None, bias_initializer=None):
        if weight_initializer is None:
            self.weight_initializer = Initializer()
        elif isinstance(weight_initializer, Initializer):
            assert weight_initializer.initializes_weight()
            self.weight_initializer = weight_initializer
        elif isinstance(weight_initializer, str):
            # Look the named init function up in torch.nn.init.
            init_function = getattr(nn.init, weight_initializer, None)
            assert init_function is not None
            self.weight_initializer = WeightInitFunction(init_function=init_function)
        else:
            assert callable(weight_initializer)
            self.weight_initializer = WeightInitFunction(
                init_function=weight_initializer
            )

        if bias_initializer is None:
            self.bias_initializer = Initializer()
        elif isinstance(bias_initializer, Initializer):
            # BUGFIX: was `assert bias_initializer.initializes_bias` — a bound
            # method is always truthy, so the check never actually ran.
            assert bias_initializer.initializes_bias()
            self.bias_initializer = bias_initializer
        elif isinstance(bias_initializer, str):
            init_function = getattr(nn.init, bias_initializer, None)
            assert init_function is not None
            self.bias_initializer = BiasInitFunction(init_function=init_function)
        else:
            assert callable(bias_initializer)
            self.bias_initializer = BiasInitFunction(init_function=bias_initializer)

    def call_on_weight(self, tensor):
        return self.weight_initializer.call_on_weight(tensor)

    def call_on_bias(self, tensor):
        return self.bias_initializer.call_on_bias(tensor)
| |
|
| |
|
class WeightInitFunction(Initializer):
    """Initializer applying ``init_function(tensor, *args, **kwargs)`` to
    weight tensors only (biases are left untouched)."""

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        assert callable(init_function)
        super(WeightInitFunction, self).__init__()
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_weight(self, tensor):
        result = self.init_function(
            tensor, *self.init_function_args, **self.init_function_kwargs
        )
        return result
| |
|
| |
|
class BiasInitFunction(Initializer):
    """Initializer applying ``init_function(tensor, *args, **kwargs)`` to
    bias tensors only (weights are left untouched)."""

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        assert callable(init_function)
        super(BiasInitFunction, self).__init__()
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_bias(self, tensor):
        result = self.init_function(
            tensor, *self.init_function_args, **self.init_function_kwargs
        )
        return result
| |
|
| |
|
class TensorInitFunction(Initializer):
    """Initializer applying ``init_function(tensor, *args, **kwargs)`` to
    both weight and bias tensors."""

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        assert callable(init_function)
        super(TensorInitFunction, self).__init__()
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_tensor(self, tensor):
        result = self.init_function(
            tensor, *self.init_function_args, **self.init_function_kwargs
        )
        return result
| |
|
| |
|
class Constant(Initializer):
    """Initialize every element with the constant ``constant``."""

    def __init__(self, constant):
        self.constant = constant

    def call_on_tensor(self, tensor):
        # In-place fill; fill_ returns the tensor itself.
        return tensor.fill_(self.constant)
| |
|
| |
|
class NormalWeights(Initializer):
    """
    Initialize weights with random numbers drawn from the normal distribution at
    `mean` and `stddev`.

    If ``sqrt_gain_over_fan_in`` is given, ``stddev`` is additionally scaled by
    ``sqrt(sqrt_gain_over_fan_in / fan_in)`` (used for the SELU/ELU inits below).
    """

    def __init__(self, mean=0.0, stddev=1.0, sqrt_gain_over_fan_in=None):
        self.mean = mean
        self.stddev = stddev
        self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in

    def compute_fan_in(self, tensor):
        # Linear weights (out, in): fan-in is dim 1.
        # Conv weights (out, in, *kernel): product of all dims after the first.
        if tensor.dim() == 2:
            return tensor.size(1)
        else:
            return np.prod(list(tensor.size())[1:])

    def call_on_weight(self, tensor):
        if self.sqrt_gain_over_fan_in is not None:
            stddev = self.stddev * np.sqrt(
                self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor)
            )
        else:
            stddev = self.stddev
        # In-place init. CONSISTENCY FIX: return the tensor like the other
        # initializers do (previously returned None).
        tensor.normal_(self.mean, stddev)
        return tensor
| |
|
| |
|
class OrthogonalWeightsZeroBias(Initialization):
    """Orthogonal weight init (with ``gain``), biases set to zero."""

    def __init__(self, orthogonal_gain=1.0):
        # Prefer the in-place `orthogonal_`; fall back for older torch versions.
        orthogonal = getattr(nn.init, "orthogonal_", nn.init.orthogonal)
        weight_init = partial(orthogonal, gain=orthogonal_gain)
        super(OrthogonalWeightsZeroBias, self).__init__(
            weight_initializer=weight_init,
            bias_initializer=Constant(0.0),
        )
| |
|
| |
|
class KaimingNormalWeightsZeroBias(Initialization):
    """Kaiming-normal weight init (``a`` = ReLU leakage), biases set to zero."""

    def __init__(self, relu_leakage=0):
        # Prefer the in-place `kaiming_normal_`; fall back for older torch versions.
        kaiming_normal = getattr(nn.init, "kaiming_normal_", nn.init.kaiming_normal)
        weight_init = partial(kaiming_normal, a=relu_leakage)
        super(KaimingNormalWeightsZeroBias, self).__init__(
            weight_initializer=weight_init,
            bias_initializer=Constant(0.0),
        )
| |
|
| |
|
class SELUWeightsZeroBias(Initialization):
    """SELU-style init: normal weights with stddev sqrt(1 / fan_in), zero biases."""

    def __init__(self):
        weight_init = NormalWeights(sqrt_gain_over_fan_in=1.0)
        super(SELUWeightsZeroBias, self).__init__(
            weight_initializer=weight_init,
            bias_initializer=Constant(0.0),
        )
| |
|
| |
|
class ELUWeightsZeroBias(Initialization):
    """ELU-style init: normal weights scaled by sqrt(gain / fan_in), zero biases."""

    def __init__(self):
        # Gain constant for ELU, kept verbatim from the original implementation.
        weight_init = NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277)
        super(ELUWeightsZeroBias, self).__init__(
            weight_initializer=weight_init,
            bias_initializer=Constant(0.0),
        )
| |
|
| |
|
class BatchNormND(nn.Module):
    """Dimension-parameterized wrapper around ``nn.BatchNorm{1,2,3}d``."""

    def __init__(
        self,
        dim,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(BatchNormND, self).__init__()
        assert dim in [1, 2, 3]
        # Pick the concrete BatchNorm class matching the spatial dimensionality.
        bn_cls = getattr(nn, "BatchNorm{}d".format(dim))
        self.bn = bn_cls(
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        return self.bn(x)
| |
|
| |
|
class ConvActivation(nn.Module):
    """Convolutional layer with 'SAME' padding by default followed by an activation.

    Parameters
    ----------
    in_channels, out_channels : int
        Channel counts. With ``depthwise=True``, ``out_channels`` may be
        ``None``/``"auto"`` and then defaults to ``in_channels``.
    kernel_size : int or tuple
        Must be odd (per dimension) for the 'SAME' padding computation.
    dim : int
        Spatial dimensionality, one of 1, 2, 3.
    activation : str, nn.Module or None
        Name of a class in ``torch.nn`` (instantiated with no args), a module
        instance, or ``None`` for no activation.
    groups : int or None
        Conv groups; forced to ``in_channels`` when ``depthwise=True``,
        defaults to 1 otherwise.
    deconv : bool
        Use ``ConvTranspose{dim}d`` (no padding applied) instead of ``Conv``.
    initialization : Initializer or None
        Applied to the conv module via ``nn.Module.apply``.
    valid_conv : bool
        Use zero padding ('VALID') instead of 'SAME'; takes precedence over
        ``deconv``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dim,
        activation,
        stride=1,
        dilation=1,
        groups=None,
        depthwise=False,
        bias=True,
        deconv=False,
        initialization=None,
        valid_conv=False,
    ):
        super(ConvActivation, self).__init__()
        assert_(
            dim in [1, 2, 3],
            "`dim` must be one of [1, 2, 3], got {}.".format(dim),
        )
        self.dim = dim
        if depthwise:
            # Depthwise: one filter per channel, so in == out channels and
            # groups == in_channels.
            out_channels = (
                in_channels if out_channels in [None, "auto"] else out_channels
            )
            assert_(
                in_channels == out_channels,
                "For depthwise convolutions, number of input channels (given: {}) "
                "must equal the number of output channels (given {}).".format(
                    in_channels, out_channels
                ),
                ValueError,
            )
            assert_(
                groups is None or groups == in_channels,
                "For depthwise convolutions, groups (given: {}) must "
                "equal the number of channels (given: {}).".format(groups, in_channels),
            )
            groups = in_channels
        else:
            groups = 1 if groups is None else groups
        self.depthwise = depthwise
        if valid_conv:
            # 'VALID' conv: no padding at all.
            self.conv = getattr(nn, "Conv{}d".format(self.dim))(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        elif not deconv:
            # 'SAME' conv: pad so spatial size is preserved (at stride 1).
            padding = self.get_padding(kernel_size, dilation)
            self.conv = getattr(nn, "Conv{}d".format(self.dim))(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        else:
            # Transposed conv; no padding argument is passed.
            self.conv = getattr(nn, "ConvTranspose{}d".format(self.dim))(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        if initialization is None:
            pass
        elif isinstance(initialization, Initializer):
            self.conv.apply(initialization)
        else:
            raise NotImplementedError

        if isinstance(activation, str):
            # Instantiate e.g. nn.ReLU() from its class name.
            self.activation = getattr(nn, activation)()
        elif isinstance(activation, nn.Module):
            self.activation = activation
        elif activation is None:
            self.activation = None
        else:
            raise NotImplementedError

    def forward(self, input):
        conved = self.conv(input)
        if self.activation is not None:
            activated = self.activation(conved)
        else:
            # No activation configured: pass the conv output through unchanged.
            activated = conved
        return activated

    def _pair_or_triplet(self, object_):
        # Broadcast a scalar to a per-dimension tuple/list of length self.dim.
        if isinstance(object_, (list, tuple)):
            assert len(object_) == self.dim
            return object_
        else:
            object_ = [object_] * self.dim
            return object_

    def _get_padding(self, _kernel_size, _dilation):
        # 'SAME' padding for one dimension; requires an odd kernel size.
        assert isinstance(_kernel_size, int)
        assert isinstance(_dilation, int)
        assert _kernel_size % 2 == 1
        return ((_kernel_size - 1) // 2) * _dilation

    def get_padding(self, kernel_size, dilation):
        """Per-dimension 'SAME' padding for the given kernel size and dilation."""
        kernel_size = self._pair_or_triplet(kernel_size)
        dilation = self._pair_or_triplet(dilation)
        padding = [
            self._get_padding(_kernel_size, _dilation)
            for _kernel_size, _dilation in zip(kernel_size, dilation)
        ]
        return tuple(padding)
| |
|
| |
|
| | |
# Alias: ConvActivation already takes `dim` and thus handles 1D/2D/3D itself.
ConvActivationND = ConvActivation
| |
|
| |
|
class _BNReLUSomeConv(object):
    """Mixin forward pass: batchnorm -> activation -> convolution.

    The class it is mixed into must provide ``batchnorm``, ``activation``
    and ``conv`` attributes.
    """

    def forward(self, input):
        return self.conv(self.activation(self.batchnorm(input)))
| |
|
| |
|
class BNReLUConvBaseND(_BNReLUSomeConv, ConvActivation):
    """Pre-activation conv block: BN -> ReLU -> Conv (ordering via the mixin)."""

    def __init__(
        self, in_channels, out_channels, kernel_size, dim, stride=1, dilation=1, deconv=False
    ):
        super(BNReLUConvBaseND, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dim=dim,
            stride=stride,
            dilation=dilation,
            deconv=deconv,
            activation=nn.ReLU(inplace=True),
            initialization=KaimingNormalWeightsZeroBias(0),
        )
        # BN runs on the *input* of the block, hence in_channels features.
        self.batchnorm = BatchNormND(dim, in_channels)
| |
|
| |
|
def _register_bnr_conv_cls(conv_name, fix=None, default=None):
    """Register ``BNReLU<conv_name>ND`` plus fixed-dimension 1/2/3D variants
    of ``BNReLUConvBaseND`` in the class registry."""
    if fix is None:
        fix = {}
    if default is None:
        default = {}

    # Dimension-generic variant.
    # BUGFIX: this registration used to sit inside a `for dim in [1, 2, 3]`
    # loop even though the name does not depend on `dim`, registering the
    # identical class name three times. Register it once.
    cls_name = "BNReLU{}ND".format(conv_name)
    register_partial_cls(BNReLUConvBaseND, cls_name, fix=fix, default=default)

    # Fixed-dimension variants, e.g. "BNReLUConv2D".
    for dim in [1, 2, 3]:
        cls_name = "BNReLU{}{}D".format(conv_name, dim)
        register_partial_cls(
            BNReLUConvBaseND, cls_name, fix={**fix, "dim": dim}, default=default
        )
| |
|
| |
|
def _register_conv_cls(conv_name, fix=None, default=None):
    """Register ``<conv_name><Activation>{N,1,2,3}D`` partial classes of
    ``ConvActivation`` for each activation in {ReLU, ELU, Sigmoid, SELU, none},
    each paired with a suitable default weight initialization (Kaiming for
    ReLU, SELU init for SELU, orthogonal otherwise)."""
    if fix is None:
        fix = {}
    if default is None:
        default = {}

    activations = ["ReLU", "ELU", "Sigmoid", "SELU", ""]
    init_map = {"ReLU": KaimingNormalWeightsZeroBias, "SELU": SELUWeightsZeroBias}
    for activation_str in activations:
        # BUGFIX: was `cls_name = cls_name = ...` (duplicated assignment).
        cls_name = "{}{}ND".format(conv_name, activation_str)
        initialization_cls = init_map.get(activation_str, OrthogonalWeightsZeroBias)
        if activation_str == "":
            # Plain conv: activation defaults to None but stays overridable.
            # BUGFIX: the caller-supplied `default` dict was silently dropped
            # in this branch; merge it in like the other branches do.
            _fix = {**fix}
            _default = {**default, "activation": None}
        elif activation_str == "SELU":
            # nn.SELU supports `inplace`, so fix a configured module instance.
            _fix = {**fix, "activation": nn.SELU(inplace=True)}
            _default = {**default}
        else:
            # Activation passed by name; ConvActivation instantiates it.
            _fix = {**fix, "activation": activation_str}
            _default = {**default}

        register_partial_cls(
            ConvActivation,
            cls_name,
            fix=_fix,
            default={**_default, "initialization": initialization_cls()},
        )
        for dim in [1, 2, 3]:
            cls_name = "{}{}{}D".format(conv_name, activation_str, dim)
            register_partial_cls(
                ConvActivation,
                cls_name,
                fix={**_fix, "dim": dim},
                default={**_default, "initialization": initialization_cls()},
            )
| |
|
| |
|
# Populate the registry with 'SAME'-padded Conv* classes and unpadded
# ValidConv* classes for all activation/dimension combinations.
_register_conv_cls("Conv")
_register_conv_cls("ValidConv", fix=dict(valid_conv=True))

# Generated classes used directly by the model code below.
Conv2D = generated_inferno_classes["Conv2D"]
ValidConv3D = generated_inferno_classes["ValidConv3D"]
| |
|
| |
|
| | |
class Crop(nn.Module):
    """Module that indexes its input with the given slices,
    e.g. ``Crop(..., slice(1, -1), slice(1, -1))``."""

    def __init__(self, *slices: slice):
        super().__init__()
        self.slices = slices

    def extra_repr(self):
        # Shown in the module's repr so the crop is visible when printing.
        return str(self.slices)

    def forward(self, input):
        cropped = input[self.slices]
        return cropped
| |
|
| |
|
class ChannelFromLightField(nn.Module):
    """Rearrange a (b, 1, x, y) light-field image into
    (b, nnum**2, x // nnum, y // nnum): each pixel position (n1, n2) inside an
    nnum-by-nnum macro-pixel becomes channel ``n1 * nnum + n2``."""

    def __init__(self, nnum: int):
        super().__init__()
        self.nnum = nnum

    def forward(self, tensor):
        assert len(tensor.shape) == 4, tensor.shape
        b, c, x, y = tensor.shape
        assert c == 1
        assert x % self.nnum == 0, (x, self.nnum)
        assert y % self.nnum == 0, (y, self.nnum)
        n = self.nnum
        # Split each spatial axis into (macro, intra) factors, move the two
        # intra-macro-pixel axes up front, then merge them into channels:
        # (b, X, n1, Y, n2) -> (b, n1, n2, X, Y) -> (b, n1*n2, X, Y).
        expanded = tensor.reshape(b, x // n, n, y // n, n)
        return expanded.permute(0, 2, 4, 1, 3).reshape(b, n**2, x // n, y // n)
| |
|
| |
|
class ResnetBlock(nn.Module):
    """Residual block: ``conv_per_block`` convolutions plus a skip connection,
    followed by a ReLU on the sum.

    The concrete conv classes are looked up by name in
    ``generated_inferno_classes`` (dimensionality taken from
    ``len(kernel_size)``); the final conv has no activation since the ReLU is
    applied after the residual addition.

    Parameters
    ----------
    in_n_filters, n_filters : int
        Input/output channel counts. If they differ, a 1x1 projection conv is
        added on the skip path.
    kernel_size : tuple
        Per-dimension kernel size; also determines 2D vs 3D.
    batch_norm : bool
        Use BNReLU pre-activation convs (only supported with ReLU).
    valid : bool
        Use unpadded ('VALID') convs and crop the identity path to match.
    """

    def __init__(
        self,
        in_n_filters,
        n_filters,
        kernel_size=(3, 3),
        batch_norm=False,
        conv_per_block=2,
        valid: bool = False,
        activation: str = "ReLU",
    ):
        super().__init__()
        if batch_norm and activation != "ReLU":
            raise NotImplementedError("batch_norm with non ReLU activation")

        assert isinstance(kernel_size, tuple), kernel_size
        assert conv_per_block >= 2
        self.debug = False

        # Registry lookup, e.g. "ConvReLU2D", "ValidConvReLU3D" or
        # "BNReLUConv2D". NOTE(review): the batch_norm variants must have been
        # registered elsewhere (e.g. via _register_bnr_conv_cls) — only the
        # plain Conv*/ValidConv* registrations are visible in this file.
        Conv = generated_inferno_classes[
            f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv{'' if batch_norm else activation}{len(kernel_size)}D"
        ]
        # Final conv carries no activation name -> no activation before the sum.
        FinalConv = generated_inferno_classes[
            f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv{len(kernel_size)}D"
        ]

        layers = []
        layers.append(
            Conv(
                in_channels=in_n_filters,
                out_channels=n_filters,
                kernel_size=kernel_size,
            )
        )

        # Middle convs (conv_per_block - 2 of them) keep the channel count.
        for _ in range(conv_per_block - 2):
            layers.append(Conv(n_filters, n_filters, kernel_size))

        layers.append(FinalConv(n_filters, n_filters, kernel_size))

        self.block = nn.Sequential(*layers)

        if n_filters != in_n_filters:
            # 1x1 conv projects the skip path to the new channel count.
            ProjConv = generated_inferno_classes[f"Conv{len(kernel_size)}D"]
            self.projection_layer = ProjConv(in_n_filters, n_filters, kernel_size=1)
        else:
            self.projection_layer = None

        if valid:
            # Each valid conv shrinks every spatial dim by ks // 2 per side;
            # crop the identity path by the accumulated amount to match.
            crop_each_side = [conv_per_block * (ks // 2) for ks in kernel_size]
            self.crop = Crop(..., *[slice(c, -c) for c in crop_each_side])
        else:
            self.crop = None

        self.relu = nn.ReLU()

    def forward(self, input):
        x = self.block(input)
        if self.crop is not None:
            # Align the identity path with the (shrunken) conv output.
            input = self.crop(input)

        if self.projection_layer is None:
            x = x + input
        else:
            projected = self.projection_layer(input)
            x = x + projected

        x = self.relu(x)
        return x
| |
|
| |
|
class HyLFM_Net(nn.Module):
    """Light-field to volume network: a 2D ResNet trunk on the rearranged
    light-field channels, a channel-to-depth reshape, then a 3D ResNet trunk.

    ``c_res2d`` / ``c_res3d`` describe the trunks: an ``int`` entry adds a
    ResnetBlock with that many output channels; a ``"u<channels>"`` string
    entry adds a transposed-conv upsampling to that channel count.
    """

    class InitName(str, Enum):
        # Allowed init names; each maps to a torch.nn.init function (the
        # in-place trailing-underscore variant is preferred when available).
        uniform_ = "uniform"
        normal_ = "normal"
        constant_ = "constant"
        eye_ = "eye"
        dirac_ = "dirac"
        xavier_uniform_ = "xavier_uniform"
        xavier_normal_ = "xavier_normal"
        kaiming_uniform_ = "kaiming_uniform"
        kaiming_normal_ = "kaiming_normal"
        orthogonal_ = "orthogonal"
        sparse_ = "sparse"

    def __init__(
        self,
        *,
        z_out: int,
        nnum: int,
        kernel2d: int = 3,
        conv_per_block2d: int = 2,
        c_res2d: Sequence[Union[int, str]] = (488, 488, "u244", 244),
        last_kernel2d: int = 1,
        c_in_3d: int = 7,
        kernel3d: int = 3,
        conv_per_block3d: int = 2,
        c_res3d: Sequence[str] = (7, "u7", 7, 7),
        init_fn: Union[InitName, str] = InitName.xavier_uniform_.value,
        final_activation: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        z_out : int
            Number of output z-planes (before compensation for valid 3D convs).
        nnum : int
            Micro-lens array pitch; input channels become nnum**2.
        c_in_3d : int
            Channels of the 3D trunk; conv2d output is z * c_in_3d, reshaped.
        init_fn : InitName or str
            Name of the torch.nn.init function used for conv weights.
        final_activation : str or None
            ``None`` or ``"sigmoid"``.
        """
        super().__init__()
        self.channel_from_lf = ChannelFromLightField(nnum=nnum)
        # Validate the name (raises ValueError on unknown names) ...
        init_fn = self.InitName(init_fn)

        if hasattr(nn.init, f"{init_fn.value}_"):
            # ... then resolve to the in-place variant when available.
            init_fn = getattr(nn.init, f"{init_fn.value}_")
        else:
            init_fn = getattr(nn.init, init_fn.value)

        self.c_res2d = list(c_res2d)
        self.c_res3d = list(c_res3d)
        c_res3d = c_res3d  # NOTE(review): no-op assignment; appears vestigial.
        self.nnum = nnum
        self.z_out = z_out
        if kernel3d != 3:
            raise NotImplementedError("z_out expansion for other res3d kernel")

        # Each non-upsampling 3D ResnetBlock uses valid convs and shrinks z by
        # 2 * conv_per_block3d * (kernel3d // 2); enlarge the intermediate
        # z so the final output still has self.z_out planes.
        dz = 2 * conv_per_block3d * (kernel3d // 2)
        for c in c_res3d:
            if isinstance(c, int) or not c.startswith("u"):
                z_out += dz

        # A bare "u" would leave the upsampled channel count unspecified.
        assert (
            c_res2d[-1] != "u"
        ), "missing # output channels for upsampling in 'c_res2d'"
        assert (
            c_res3d[-1] != "u"
        ), "missing # output channels for upsampling in 'c_res3d'"

        # ---- 2D trunk ----
        res2d = []
        c_in = nnum**2
        c_out = c_in
        for i in range(len(c_res2d)):
            if not isinstance(c_res2d[i], int) and c_res2d[i].startswith("u"):
                # "u<channels>": 2x spatial upsampling via transposed conv.
                c_out = int(c_res2d[i][1:])
                res2d.append(
                    nn.ConvTranspose2d(
                        in_channels=c_in,
                        out_channels=c_out,
                        kernel_size=2,
                        stride=2,
                        padding=0,
                        output_padding=0,
                    )
                )
            else:
                c_out = int(c_res2d[i])
                res2d.append(
                    ResnetBlock(
                        in_n_filters=c_in,
                        n_filters=c_out,
                        kernel_size=(kernel2d, kernel2d),
                        valid=False,
                        conv_per_block=conv_per_block2d,
                    )
                )

            c_in = c_out

        self.res2d = nn.Sequential(*res2d)

        # Only some init functions (e.g. xavier_*) accept a `gain` argument.
        if "gain" in inspect.signature(init_fn).parameters:
            init_fn_conv2d = partial(init_fn, gain=nn.init.calculate_gain("relu"))
        else:
            init_fn_conv2d = init_fn

        init = Initialization(
            weight_initializer=init_fn_conv2d, bias_initializer=Constant(0.0)
        )
        # 2D head producing z_out * c_in_3d channels (one stack per 3D channel).
        self.conv2d = Conv2D(
            c_out,
            z_out * c_in_3d,
            last_kernel2d,
            activation="ReLU",
            initialization=init,
        )

        # Channels -> (channel, depth): (b, z*c3, x, y) -> (b, c3, z, x, y).
        self.c2z = lambda ipt, ip3=c_in_3d: ipt.view(
            ipt.shape[0], ip3, z_out, *ipt.shape[2:]
        )

        # ---- 3D trunk ----
        res3d = []
        c_in = c_in_3d
        c_out = c_in
        for i in range(len(c_res3d)):
            if not isinstance(c_res3d[i], int) and c_res3d[i].startswith("u"):
                # 2x upsampling in x/y only; z size preserved (kernel 3, pad 1).
                c_out = int(c_res3d[i][1:])
                res3d.append(
                    nn.ConvTranspose3d(
                        in_channels=c_in,
                        out_channels=c_out,
                        kernel_size=(3, 2, 2),
                        stride=(1, 2, 2),
                        padding=(1, 0, 0),
                        output_padding=0,
                    )
                )
            else:
                c_out = int(c_res3d[i])
                res3d.append(
                    ResnetBlock(
                        in_n_filters=c_in,
                        n_filters=c_out,
                        kernel_size=(kernel3d, kernel3d, kernel3d),
                        valid=True,
                        conv_per_block=conv_per_block3d,
                    )
                )

            c_in = c_out

        self.res3d = nn.Sequential(*res3d)

        if "gain" in inspect.signature(init_fn).parameters:
            init_fn_conv3d = partial(init_fn, gain=nn.init.calculate_gain("linear"))
        else:
            init_fn_conv3d = init_fn

        init = Initialization(
            weight_initializer=init_fn_conv3d, bias_initializer=Constant(0.0)
        )
        # Final 1x1x1 conv collapsing the 3D channels to a single volume.
        self.conv3d = ValidConv3D(c_out, 1, (1, 1, 1), initialization=init)

        if final_activation is None:
            self.final_activation = None
        elif final_activation == "sigmoid":
            self.final_activation = nn.Sigmoid()
        else:
            raise NotImplementedError(final_activation)

    def forward(self, x):
        # (b, 1, X, Y) -> (b, nnum^2, x, y) -> ... -> (b, 1, z_out, x', y')
        x = self.channel_from_lf(x)
        x = self.res2d(x)
        x = self.conv2d(x)
        x = self.c2z(x)
        x = self.res3d(x)
        x = self.conv3d(x)

        if self.final_activation is not None:
            x = self.final_activation(x)

        return x

    def get_scale(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int:
        """Overall spatial magnification from the upsampling stages.

        NOTE(review): this computes 2 * (#upsamplings) per trunk rather than
        2 ** (#upsamplings); the two agree for 0-2 upsamplings per trunk but
        diverge for 3+ — confirm intent before relying on more upsamplings.
        """
        s = max(
            1,
            2
            * sum(
                isinstance(res2d, str) and res2d.startswith("u")
                for res2d in self.c_res2d
            ),
        ) * max(
            1,
            2
            * sum(
                isinstance(res3d, str) and res3d.startswith("u")
                for res3d in self.c_res3d
            ),
        )
        return s

    def get_shrink(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int:
        """Per-side spatial shrinkage (in output pixels) caused by the valid
        3D ResnetBlocks; shrinkage accrued before an upsampling is doubled."""
        s = 0
        for res in self.c_res3d:
            if isinstance(res, str) and res.startswith("u"):
                s *= 2
            else:
                s += 2

        return s

    def get_output_shape(self, ipt_shape: Tuple[int, int]) -> Tuple[int, int, int]:
        """Predicted (z, x, y) output shape for a 2D input of ``ipt_shape``
        (in rearranged-light-field pixels)."""
        scale = self.get_scale(ipt_shape)
        shrink = self.get_shrink(ipt_shape)
        return (self.z_out,) + tuple(i * scale - 2 * shrink for i in ipt_shape)
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build a small network and print its structure and the
    # predicted output shape for a 64x64 (rearranged) input.
    model = HyLFM_Net(
        z_out=9,
        nnum=5,
        kernel2d=3,
        conv_per_block2d=2,
        c_res2d=(12, 14, "u14", 8),
        last_kernel2d=1,
        c_in_3d=7,
        kernel3d=3,
        conv_per_block3d=2,
        c_res3d=(7, "u7", 7, 7),
        init_fn="xavier_uniform",
        final_activation="sigmoid",
    )
    print(model)
    print(model.get_output_shape((64, 64)))
| |
|