# type: ignore import inspect from enum import Enum from functools import partial from typing import Optional, Sequence, Tuple, Union import numpy as np import torch.nn as nn ### Inferno parts (adapted from inferno 0.4.2) def assert_(condition, message="", exception_type=AssertionError): """Like assert, but with arbitrary exception types.""" if not condition: raise exception_type(message) # proxy for generated classes in inferno generated_inferno_classes = {} def partial_cls(base_cls, name, fix=None, default=None): # helper function def insert_if_not_present(dict_a, dict_b): for kw, val in dict_b.items(): if kw not in dict_a: dict_a[kw] = val return dict_a # helper function def insert_call_if_present(dict_a, dict_b, callback): for kw, val in dict_b.items(): if kw not in dict_a: dict_a[kw] = val else: callback(kw) return dict_a # helper class class PartialCls(object): def __init__(self, base_cls, name, fix=None, default=None): self.base_cls = base_cls self.name = name self.fix = [fix, {}][fix is None] self.default = [default, {}][default is None] if self.fix.keys() & self.default.keys(): raise TypeError("fix and default share keys") # remove binded kw self._allowed_kw = self._get_allowed_kw() def _get_allowed_kw(self): argspec = inspect.getfullargspec(base_cls.__init__) args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = ( argspec ) if varargs is not None: raise TypeError( "partial_cls can only be used if __init__ has no varargs" ) if varkw is not None: raise TypeError("partial_cls can only be used if __init__ has no varkw") if kwonlyargs is not None and kwonlyargs != []: raise TypeError("partial_cls can only be used without kwonlyargs") if args is None or len(args) < 1: raise TypeError("seems like self is missing") return [kw for kw in args[1:] if kw not in self.fix] def _build_kw(self, args, kwargs): # handle *args if len(args) > len(self._allowed_kw): raise TypeError("to many arguments") all_args = {} for arg, akw in zip(args, self._allowed_kw): all_args[akw] = arg # handle **kwargs intersection = self.fix.keys() & kwargs.keys() if len(intersection) >= 1: kw = intersection.pop() raise TypeError( "`{}.__init__` got unexpected keyword argument '{}'".format( name, kw ) ) def raise_cb(kw): raise TypeError( "{}.__init__ got multiple values for argument '{}'".format(name, kw) ) all_args = insert_call_if_present(all_args, kwargs, raise_cb) # handle fixed arguments def raise_cb(kw): raise TypeError() all_args = insert_call_if_present(all_args, self.fix, raise_cb) # handle defaults all_args = insert_if_not_present(all_args, self.default) # handle fixed all_args.update(self.fix) return all_args def build_cls(self): def new_init(self_of_new_cls, *args, **kwargs): combined_args = self._build_kw(args=args, kwargs=kwargs) # call base cls init super(self_of_new_cls.__class__, self_of_new_cls).__init__( **combined_args ) return type(name, (self.base_cls,), {"__init__": new_init}) return PartialCls( base_cls=base_cls, name=name, fix=fix, default=default ).build_cls() def register_partial_cls(base_cls, name, fix=None, default=None): generatedClass = partial_cls(base_cls=base_cls, name=name, fix=fix, default=default) generated_inferno_classes[generatedClass.__name__] = generatedClass class Initializer(object): """ Base class for all initializers. """ # TODO Support LSTMs and GRUs VALID_LAYERS = { "Conv1d", "Conv2d", "Conv3d", "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d", "Linear", "Bilinear", "Embedding", } def __call__(self, module): module_class_name = module.__class__.__name__ if module_class_name in self.VALID_LAYERS: # Apply to weight and bias try: if hasattr(module, "weight"): self.call_on_weight(module.weight.data) except NotImplementedError: # Don't cry if it's not implemented pass try: if hasattr(module, "bias"): self.call_on_bias(module.bias.data) except NotImplementedError: pass return module def call_on_bias(self, tensor): return self.call_on_tensor(tensor) def call_on_weight(self, tensor): return self.call_on_tensor(tensor) def call_on_tensor(self, tensor): raise NotImplementedError @classmethod def initializes_weight(cls): return "call_on_tensor" in cls.__dict__ or "call_on_weight" in cls.__dict__ @classmethod def initializes_bias(cls): return "call_on_tensor" in cls.__dict__ or "call_on_bias" in cls.__dict__ class Initialization(Initializer): def __init__(self, weight_initializer=None, bias_initializer=None): if weight_initializer is None: self.weight_initializer = Initializer() else: if isinstance(weight_initializer, Initializer): assert weight_initializer.initializes_weight() self.weight_initializer = weight_initializer elif isinstance(weight_initializer, str): init_function = getattr(nn.init, weight_initializer, None) assert init_function is not None self.weight_initializer = WeightInitFunction( init_function=init_function ) else: # Provison for weight_initializer to be a function assert callable(weight_initializer) self.weight_initializer = WeightInitFunction( init_function=weight_initializer ) if bias_initializer is None: self.bias_initializer = Initializer() else: if isinstance(bias_initializer, Initializer): assert bias_initializer.initializes_bias self.bias_initializer = bias_initializer elif isinstance(bias_initializer, str): init_function = getattr(nn.init, bias_initializer, None) assert init_function is not None self.bias_initializer = BiasInitFunction(init_function=init_function) else: assert callable(bias_initializer) self.bias_initializer = BiasInitFunction(init_function=bias_initializer) def call_on_weight(self, tensor): return self.weight_initializer.call_on_weight(tensor) def call_on_bias(self, tensor): return self.bias_initializer.call_on_bias(tensor) class WeightInitFunction(Initializer): def __init__(self, init_function, *init_function_args, **init_function_kwargs): super(WeightInitFunction, self).__init__() assert callable(init_function) self.init_function = init_function self.init_function_args = init_function_args self.init_function_kwargs = init_function_kwargs def call_on_weight(self, tensor): return self.init_function( tensor, *self.init_function_args, **self.init_function_kwargs ) class BiasInitFunction(Initializer): def __init__(self, init_function, *init_function_args, **init_function_kwargs): super(BiasInitFunction, self).__init__() assert callable(init_function) self.init_function = init_function self.init_function_args = init_function_args self.init_function_kwargs = init_function_kwargs def call_on_bias(self, tensor): return self.init_function( tensor, *self.init_function_args, **self.init_function_kwargs ) class TensorInitFunction(Initializer): def __init__(self, init_function, *init_function_args, **init_function_kwargs): super(TensorInitFunction, self).__init__() assert callable(init_function) self.init_function = init_function self.init_function_args = init_function_args self.init_function_kwargs = init_function_kwargs def call_on_tensor(self, tensor): return self.init_function( tensor, *self.init_function_args, **self.init_function_kwargs ) class Constant(Initializer): """Initialize with a constant.""" def __init__(self, constant): self.constant = constant def call_on_tensor(self, tensor): tensor.fill_(self.constant) return tensor class NormalWeights(Initializer): """ Initialize weights with random numbers drawn from the normal distribution at `mean` and `stddev`. """ def __init__(self, mean=0.0, stddev=1.0, sqrt_gain_over_fan_in=None): self.mean = mean self.stddev = stddev self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in def compute_fan_in(self, tensor): if tensor.dim() == 2: return tensor.size(1) else: return np.prod(list(tensor.size())[1:]) def call_on_weight(self, tensor): # Compute stddev if required if self.sqrt_gain_over_fan_in is not None: stddev = self.stddev * np.sqrt( self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor) ) else: stddev = self.stddev # Init tensor.normal_(self.mean, stddev) class OrthogonalWeightsZeroBias(Initialization): def __init__(self, orthogonal_gain=1.0): # This prevents a deprecated warning in Pytorch 0.4+ orthogonal = getattr(nn.init, "orthogonal_", nn.init.orthogonal) super(OrthogonalWeightsZeroBias, self).__init__( weight_initializer=partial(orthogonal, gain=orthogonal_gain), bias_initializer=Constant(0.0), ) class KaimingNormalWeightsZeroBias(Initialization): def __init__(self, relu_leakage=0): # This prevents a deprecated warning in Pytorch 0.4+ kaiming_normal = getattr(nn.init, "kaiming_normal_", nn.init.kaiming_normal) super(KaimingNormalWeightsZeroBias, self).__init__( weight_initializer=partial(kaiming_normal, a=relu_leakage), bias_initializer=Constant(0.0), ) class SELUWeightsZeroBias(Initialization): def __init__(self): super(SELUWeightsZeroBias, self).__init__( weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.0), bias_initializer=Constant(0.0), ) class ELUWeightsZeroBias(Initialization): def __init__(self): super(ELUWeightsZeroBias, self).__init__( weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277), bias_initializer=Constant(0.0), ) class BatchNormND(nn.Module): def __init__( self, dim, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, ): super(BatchNormND, self).__init__() assert dim in [1, 2, 3] self.bn = getattr(nn, "BatchNorm{}d".format(dim))( num_features=num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats, ) def forward(self, x): return self.bn(x) class ConvActivation(nn.Module): """Convolutional layer with 'SAME' padding by default followed by an activation.""" def __init__( self, in_channels, out_channels, kernel_size, dim, activation, stride=1, dilation=1, groups=None, depthwise=False, bias=True, deconv=False, initialization=None, valid_conv=False, ): super(ConvActivation, self).__init__() # Validate dim assert_( dim in [1, 2, 3], "`dim` must be one of [1, 2, 3], got {}.".format(dim), ) self.dim = dim # Check if depthwise if depthwise: # We know that in_channels == out_channels, but we also want a consistent API. # As a compromise, we allow that out_channels be None or 'auto'. out_channels = ( in_channels if out_channels in [None, "auto"] else out_channels ) assert_( in_channels == out_channels, "For depthwise convolutions, number of input channels (given: {}) " "must equal the number of output channels (given {}).".format( in_channels, out_channels ), ValueError, ) assert_( groups is None or groups == in_channels, "For depthwise convolutions, groups (given: {}) must " "equal the number of channels (given: {}).".format(groups, in_channels), ) groups = in_channels else: groups = 1 if groups is None else groups self.depthwise = depthwise if valid_conv: self.conv = getattr(nn, "Conv{}d".format(self.dim))( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups, bias=bias, ) elif not deconv: # Get padding padding = self.get_padding(kernel_size, dilation) self.conv = getattr(nn, "Conv{}d".format(self.dim))( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation, groups=groups, bias=bias, ) else: self.conv = getattr(nn, "ConvTranspose{}d".format(self.dim))( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups, bias=bias, ) if initialization is None: pass elif isinstance(initialization, Initializer): self.conv.apply(initialization) else: raise NotImplementedError if isinstance(activation, str): self.activation = getattr(nn, activation)() elif isinstance(activation, nn.Module): self.activation = activation elif activation is None: self.activation = None else: raise NotImplementedError def forward(self, input): conved = self.conv(input) if self.activation is not None: activated = self.activation(conved) else: # No activation activated = conved return activated def _pair_or_triplet(self, object_): if isinstance(object_, (list, tuple)): assert len(object_) == self.dim return object_ else: object_ = [object_] * self.dim return object_ def _get_padding(self, _kernel_size, _dilation): assert isinstance(_kernel_size, int) assert isinstance(_dilation, int) assert _kernel_size % 2 == 1 return ((_kernel_size - 1) // 2) * _dilation def get_padding(self, kernel_size, dilation): kernel_size = self._pair_or_triplet(kernel_size) dilation = self._pair_or_triplet(dilation) padding = [ self._get_padding(_kernel_size, _dilation) for _kernel_size, _dilation in zip(kernel_size, dilation) ] return tuple(padding) # for consistency ConvActivationND = ConvActivation class _BNReLUSomeConv(object): def forward(self, input): normed = self.batchnorm(input) activated = self.activation(normed) conved = self.conv(activated) return conved class BNReLUConvBaseND(_BNReLUSomeConv, ConvActivation): def __init__( self, in_channels, out_channels, kernel_size, dim, stride=1, dilation=1, deconv=False, ): super(BNReLUConvBaseND, self).__init__( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, dim=dim, stride=stride, activation=nn.ReLU(inplace=True), dilation=dilation, deconv=deconv, initialization=KaimingNormalWeightsZeroBias(0), ) self.batchnorm = BatchNormND(dim, in_channels) def _register_bnr_conv_cls(conv_name, fix=None, default=None): if fix is None: fix = {} if default is None: default = {} for dim in [1, 2, 3]: cls_name = "BNReLU{}ND".format(conv_name) register_partial_cls(BNReLUConvBaseND, cls_name, fix=fix, default=default) for dim in [1, 2, 3]: cls_name = "BNReLU{}{}D".format(conv_name, dim) register_partial_cls( BNReLUConvBaseND, cls_name, fix={**fix, "dim": dim}, default=default ) def _register_conv_cls(conv_name, fix=None, default=None): if fix is None: fix = {} if default is None: default = {} # simple conv activation activations = ["ReLU", "ELU", "Sigmoid", "SELU", ""] init_map = {"ReLU": KaimingNormalWeightsZeroBias, "SELU": SELUWeightsZeroBias} for activation_str in activations: cls_name = cls_name = "{}{}ND".format(conv_name, activation_str) initialization_cls = init_map.get(activation_str, OrthogonalWeightsZeroBias) if activation_str == "": activation = None _fix = {**fix} _default = {"activation": None} elif activation_str == "SELU": activation = nn.SELU(inplace=True) _fix = {**fix, "activation": activation} _default = {**default} else: activation = activation_str _fix = {**fix, "activation": activation} _default = {**default} register_partial_cls( ConvActivation, cls_name, fix=_fix, default={**_default, "initialization": initialization_cls()}, ) for dim in [1, 2, 3]: cls_name = "{}{}{}D".format(conv_name, activation_str, dim) register_partial_cls( ConvActivation, cls_name, fix={**_fix, "dim": dim}, default={**_default, "initialization": initialization_cls()}, ) _register_conv_cls("Conv") _register_conv_cls("ValidConv", fix=dict(valid_conv=True)) Conv2D = generated_inferno_classes["Conv2D"] ValidConv3D = generated_inferno_classes["ValidConv3D"] ### HyLFM architecture class Crop(nn.Module): def __init__(self, *slices: slice): super().__init__() self.slices = slices def extra_repr(self): return str(self.slices) def forward(self, input): return input[self.slices] class ChannelFromLightField(nn.Module): def __init__(self, nnum: int): super().__init__() self.nnum = nnum def forward(self, tensor): assert len(tensor.shape) == 4, tensor.shape b, c, x, y = tensor.shape assert c == 1 assert x % self.nnum == 0, (x, self.nnum) assert y % self.nnum == 0, (y, self.nnum) return ( tensor.reshape(b, x // self.nnum, self.nnum, y // self.nnum, self.nnum) .transpose(1, 2) .transpose(2, 4) .transpose(3, 4) .reshape(b, self.nnum**2, x // self.nnum, y // self.nnum) ) class ResnetBlock(nn.Module): def __init__( self, in_n_filters, n_filters, kernel_size=(3, 3), batch_norm=False, conv_per_block=2, valid: bool = False, activation: str = "ReLU", ): super().__init__() if batch_norm and activation != "ReLU": raise NotImplementedError("batch_norm with non ReLU activation") assert isinstance(kernel_size, tuple), kernel_size assert conv_per_block >= 2 self.debug = False # sys.gettrace() is not None Conv = generated_inferno_classes[ f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv{'' if batch_norm else activation}{len(kernel_size)}D" ] FinalConv = generated_inferno_classes[ f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv{len(kernel_size)}D" ] layers = [] layers.append( Conv( in_channels=in_n_filters, out_channels=n_filters, kernel_size=kernel_size, ) ) for _ in range(conv_per_block - 2): layers.append(Conv(n_filters, n_filters, kernel_size)) layers.append(FinalConv(n_filters, n_filters, kernel_size)) self.block = nn.Sequential(*layers) if n_filters != in_n_filters: ProjConv = generated_inferno_classes[f"Conv{len(kernel_size)}D"] self.projection_layer = ProjConv(in_n_filters, n_filters, kernel_size=1) else: self.projection_layer = None if valid: crop_each_side = [conv_per_block * (ks // 2) for ks in kernel_size] self.crop = Crop(..., *[slice(c, -c) for c in crop_each_side]) else: self.crop = None self.relu = nn.ReLU() # determine shrinkage # self.shrinkage = (1, 1) + tuple([conv_per_block * (ks - 1) for ks in kernel_size]) def forward(self, input): x = self.block(input) if self.crop is not None: input = self.crop(input) if self.projection_layer is None: x = x + input else: projected = self.projection_layer(input) x = x + projected x = self.relu(x) return x class HyLFM_Net(nn.Module): class InitName(str, Enum): uniform_ = "uniform" normal_ = "normal" constant_ = "constant" eye_ = "eye" dirac_ = "dirac" xavier_uniform_ = "xavier_uniform" xavier_normal_ = "xavier_normal" kaiming_uniform_ = "kaiming_uniform" kaiming_normal_ = "kaiming_normal" orthogonal_ = "orthogonal" sparse_ = "sparse" def __init__( self, *, z_out: int, nnum: int, kernel2d: int = 3, conv_per_block2d: int = 2, c_res2d: Sequence[Union[int, str]] = (488, 488, "u244", 244), last_kernel2d: int = 1, c_in_3d: int = 7, kernel3d: int = 3, conv_per_block3d: int = 2, c_res3d: Sequence[str] = (7, "u7", 7, 7), init_fn: Union[InitName, str] = InitName.xavier_uniform_.value, final_activation: Optional[str] = None, ): super().__init__() self.channel_from_lf = ChannelFromLightField(nnum=nnum) init_fn = self.InitName(init_fn) if hasattr(nn.init, f"{init_fn.value}_"): # prevents deprecation warning init_fn = getattr(nn.init, f"{init_fn.value}_") else: init_fn = getattr(nn.init, init_fn.value) self.c_res2d = list(c_res2d) self.c_res3d = list(c_res3d) c_res3d = c_res3d self.nnum = nnum self.z_out = z_out if kernel3d != 3: raise NotImplementedError("z_out expansion for other res3d kernel") dz = 2 * conv_per_block3d * (kernel3d // 2) for c in c_res3d: if isinstance(c, int) or not c.startswith("u"): z_out += dz # z_out += 4 * (len(c_res3d) - 2 * sum([layer == "u" for layer in c_res3d])) # add z_out for valid 3d convs assert ( c_res2d[-1] != "u" ), "missing # output channels for upsampling in 'c_res2d'" assert ( c_res3d[-1] != "u" ), "missing # output channels for upsampling in 'c_res3d'" res2d = [] c_in = nnum**2 c_out = c_in for i in range(len(c_res2d)): if not isinstance(c_res2d[i], int) and c_res2d[i].startswith("u"): c_out = int(c_res2d[i][1:]) res2d.append( nn.ConvTranspose2d( in_channels=c_in, out_channels=c_out, kernel_size=2, stride=2, padding=0, output_padding=0, ) ) else: c_out = int(c_res2d[i]) res2d.append( ResnetBlock( in_n_filters=c_in, n_filters=c_out, kernel_size=(kernel2d, kernel2d), valid=False, conv_per_block=conv_per_block2d, ) ) c_in = c_out self.res2d = nn.Sequential(*res2d) if "gain" in inspect.signature(init_fn).parameters: init_fn_conv2d = partial(init_fn, gain=nn.init.calculate_gain("relu")) else: init_fn_conv2d = init_fn init = Initialization( weight_initializer=init_fn_conv2d, bias_initializer=Constant(0.0) ) self.conv2d = Conv2D( c_out, z_out * c_in_3d, last_kernel2d, activation="ReLU", initialization=init, ) self.c2z = lambda ipt, ip3=c_in_3d: ipt.view( ipt.shape[0], ip3, z_out, *ipt.shape[2:] ) res3d = [] c_in = c_in_3d c_out = c_in for i in range(len(c_res3d)): if not isinstance(c_res3d[i], int) and c_res3d[i].startswith("u"): c_out = int(c_res3d[i][1:]) res3d.append( nn.ConvTranspose3d( in_channels=c_in, out_channels=c_out, kernel_size=(3, 2, 2), stride=(1, 2, 2), padding=(1, 0, 0), output_padding=0, ) ) else: c_out = int(c_res3d[i]) res3d.append( ResnetBlock( in_n_filters=c_in, n_filters=c_out, kernel_size=(kernel3d, kernel3d, kernel3d), valid=True, conv_per_block=conv_per_block3d, ) ) c_in = c_out self.res3d = nn.Sequential(*res3d) if "gain" in inspect.signature(init_fn).parameters: init_fn_conv3d = partial(init_fn, gain=nn.init.calculate_gain("linear")) else: init_fn_conv3d = init_fn init = Initialization( weight_initializer=init_fn_conv3d, bias_initializer=Constant(0.0) ) self.conv3d = ValidConv3D(c_out, 1, (1, 1, 1), initialization=init) if final_activation is None: self.final_activation = None elif final_activation == "sigmoid": self.final_activation = nn.Sigmoid() else: raise NotImplementedError(final_activation) def forward(self, x): x = self.channel_from_lf(x) x = self.res2d(x) x = self.conv2d(x) x = self.c2z(x) x = self.res3d(x) x = self.conv3d(x) if self.final_activation is not None: x = self.final_activation(x) return x def get_scale(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int: s = max( 1, 2 * sum( isinstance(res2d, str) and res2d.startswith("u") for res2d in self.c_res2d ), ) * max( 1, 2 * sum( isinstance(res3d, str) and res3d.startswith("u") for res3d in self.c_res3d ), ) return s def get_shrink(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int: s = 0 for res in self.c_res3d: if isinstance(res, str) and res.startswith("u"): s *= 2 else: s += 2 return s def get_output_shape(self, ipt_shape: Tuple[int, int]) -> Tuple[int, int, int]: scale = self.get_scale(ipt_shape) shrink = self.get_shrink(ipt_shape) return (self.z_out,) + tuple(i * scale - 2 * shrink for i in ipt_shape) if __name__ == "__main__": # Example usage model = HyLFM_Net( z_out=9, nnum=5, kernel2d=3, conv_per_block2d=2, c_res2d=(12, 14, "u14", 8), last_kernel2d=1, c_in_3d=7, kernel3d=3, conv_per_block3d=2, c_res3d=(7, "u7", 7, 7), init_fn="xavier_uniform", final_activation="sigmoid", ) print(model) print(model.get_output_shape((64, 64)))