# Uploaded by: thefynnbe
# Upload 1.3 with bioimageio.spec 0.5.7.1
# Commit 9af5e69 (verified)
# type: ignore
import inspect
from enum import Enum
from functools import partial
from typing import Optional, Sequence, Tuple, Union
import numpy as np
import torch.nn as nn
### Inferno parts (adapted from inferno 0.4.2)
def assert_(condition, message="", exception_type=AssertionError):
    """Like assert, but with arbitrary exception types.

    Unlike the ``assert`` statement, this check is not stripped under ``python -O``.

    Raises:
        exception_type: constructed with ``message`` when ``condition`` is falsy.
    """
    if condition:
        return
    raise exception_type(message)
# Proxy for the classes inferno generates dynamically: maps class name -> class,
# filled by ``register_partial_cls``.
generated_inferno_classes = dict()
def partial_cls(base_cls, name, fix=None, default=None):
    """Create a subclass of ``base_cls`` named ``name`` with some ``__init__`` arguments bound.

    Args:
        base_cls: class whose ``__init__`` takes only positional-or-keyword
            parameters (no ``*args``, ``**kwargs`` or keyword-only parameters).
        name: ``__name__`` of the generated class.
        fix: mapping of arguments that are always passed with the given value;
            callers of the generated class may not supply them.
        default: mapping of arguments used unless the caller supplies them.

    The generated ``__init__`` maps positional arguments, in order, onto the
    parameters of ``base_cls.__init__`` (after ``self``) that are not in ``fix``.

    Returns:
        The generated class.

    Raises:
        TypeError: if ``fix`` and ``default`` share keys, or if
            ``base_cls.__init__`` has an unsupported signature. The generated
            ``__init__`` raises ``TypeError`` for too many positional
            arguments, for a fixed argument passed by keyword, or for an
            argument given both positionally and by keyword.
    """
    fix = {} if fix is None else dict(fix)
    default = {} if default is None else dict(default)
    if fix.keys() & default.keys():
        raise TypeError("fix and default share keys")

    spec = inspect.getfullargspec(base_cls.__init__)
    if spec.varargs is not None:
        raise TypeError("partial_cls can only be used if __init__ has no varargs")
    if spec.varkw is not None:
        raise TypeError("partial_cls can only be used if __init__ has no varkw")
    if spec.kwonlyargs:
        raise TypeError("partial_cls can only be used without kwonlyargs")
    if not spec.args:
        raise TypeError("seems like self is missing")
    # Parameters (after ``self``) that callers may still set, in positional order.
    allowed_kw = [kw for kw in spec.args[1:] if kw not in fix]

    def build_kwargs(args, kwargs):
        """Merge positional args, keyword args, defaults and fixed values."""
        if len(args) > len(allowed_kw):
            raise TypeError("too many arguments")
        combined = dict(zip(allowed_kw, args))
        fixed_by_caller = fix.keys() & kwargs.keys()
        if fixed_by_caller:
            raise TypeError(
                "`{}.__init__` got unexpected keyword argument '{}'".format(
                    name, fixed_by_caller.pop()
                )
            )
        for kw, val in kwargs.items():
            if kw in combined:
                raise TypeError(
                    "{}.__init__ got multiple values for argument '{}'".format(name, kw)
                )
            combined[kw] = val
        for kw, val in default.items():
            combined.setdefault(kw, val)
        # Fixed keys cannot collide with anything above (they are excluded from
        # ``allowed_kw``, rejected in ``kwargs`` and disjoint from ``default``).
        combined.update(fix)
        return combined

    def new_init(self, *args, **kwargs):
        # Bind super() to the generated class itself rather than ``type(self)``:
        # the latter recurses forever as soon as the generated class is subclassed.
        super(new_cls, self).__init__(**build_kwargs(args, kwargs))

    new_cls = type(name, (base_cls,), {"__init__": new_init})
    return new_cls
def register_partial_cls(base_cls, name, fix=None, default=None):
    """Build a class with ``partial_cls`` and store it in ``generated_inferno_classes``.

    The class is stored under its ``__name__`` (i.e. ``name``).
    """
    new_cls = partial_cls(base_cls=base_cls, name=name, fix=fix, default=default)
    generated_inferno_classes[new_cls.__name__] = new_cls
class Initializer(object):
    """
    Base class for all initializers.

    Calling an initializer on a module (e.g. ``module.apply(initializer)``)
    initializes, in place, the ``weight`` and ``bias`` tensors of modules whose
    class name is in ``VALID_LAYERS``. Tensors that are absent or ``None``
    (e.g. ``bias=False`` layers) are skipped. ``call_on_weight`` and
    ``call_on_bias`` both default to ``call_on_tensor``; subclasses override
    one or more of them. A ``NotImplementedError`` raised by any of them means
    "leave this tensor untouched".
    """

    # TODO Support LSTMs and GRUs
    VALID_LAYERS = {
        "Conv1d",
        "Conv2d",
        "Conv3d",
        "ConvTranspose1d",
        "ConvTranspose2d",
        "ConvTranspose3d",
        "Linear",
        "Bilinear",
        "Embedding",
    }

    def __call__(self, module):
        if module.__class__.__name__ in self.VALID_LAYERS:
            # ``bias=False`` layers register ``bias`` as None: hasattr() alone is not enough.
            weight = getattr(module, "weight", None)
            if weight is not None:
                try:
                    self.call_on_weight(weight.data)
                except NotImplementedError:
                    # Don't cry if it's not implemented
                    pass
            bias = getattr(module, "bias", None)
            if bias is not None:
                try:
                    self.call_on_bias(bias.data)
                except NotImplementedError:
                    pass
        return module

    def call_on_bias(self, tensor):
        return self.call_on_tensor(tensor)

    def call_on_weight(self, tensor):
        return self.call_on_tensor(tensor)

    def call_on_tensor(self, tensor):
        raise NotImplementedError

    @classmethod
    def initializes_weight(cls):
        """Whether ``cls`` (directly or through a base class) overrides the weight hooks."""
        return (
            cls.call_on_tensor is not Initializer.call_on_tensor
            or cls.call_on_weight is not Initializer.call_on_weight
        )

    @classmethod
    def initializes_bias(cls):
        """Whether ``cls`` (directly or through a base class) overrides the bias hooks."""
        return (
            cls.call_on_tensor is not Initializer.call_on_tensor
            or cls.call_on_bias is not Initializer.call_on_bias
        )
class Initialization(Initializer):
    """Initializer delegating weights and biases to two separate initializers.

    Args:
        weight_initializer: one of
            * ``None``: weights are left untouched,
            * an ``Initializer`` whose ``initializes_weight()`` is true,
            * the name of a function in ``torch.nn.init``,
            * any callable, invoked as ``fn(weight_tensor)``.
        bias_initializer: same options, applied to the bias tensor
            (an ``Initializer`` must satisfy ``initializes_bias()``).
    """

    def __init__(self, weight_initializer=None, bias_initializer=None):
        if weight_initializer is None:
            self.weight_initializer = Initializer()
        elif isinstance(weight_initializer, Initializer):
            assert weight_initializer.initializes_weight()
            self.weight_initializer = weight_initializer
        elif isinstance(weight_initializer, str):
            init_function = getattr(nn.init, weight_initializer, None)
            assert init_function is not None
            self.weight_initializer = WeightInitFunction(init_function=init_function)
        else:
            # Provision for weight_initializer to be a function
            assert callable(weight_initializer)
            self.weight_initializer = WeightInitFunction(
                init_function=weight_initializer
            )

        if bias_initializer is None:
            self.bias_initializer = Initializer()
        elif isinstance(bias_initializer, Initializer):
            # Call the classmethod: the bare bound method is always truthy and
            # would make this check a no-op.
            assert bias_initializer.initializes_bias()
            self.bias_initializer = bias_initializer
        elif isinstance(bias_initializer, str):
            init_function = getattr(nn.init, bias_initializer, None)
            assert init_function is not None
            self.bias_initializer = BiasInitFunction(init_function=init_function)
        else:
            assert callable(bias_initializer)
            self.bias_initializer = BiasInitFunction(init_function=bias_initializer)

    def call_on_weight(self, tensor):
        return self.weight_initializer.call_on_weight(tensor)

    def call_on_bias(self, tensor):
        return self.bias_initializer.call_on_bias(tensor)
class WeightInitFunction(Initializer):
    """Initializer that only touches weights, via ``init_function(tensor, *args, **kwargs)``.

    Extra positional and keyword arguments given at construction are forwarded
    to ``init_function`` on every call; its return value is passed through.
    """

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super().__init__()
        assert callable(init_function)
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_weight(self, tensor):
        fn = self.init_function
        return fn(tensor, *self.init_function_args, **self.init_function_kwargs)
class BiasInitFunction(Initializer):
    """Initializer that only touches biases, via ``init_function(tensor, *args, **kwargs)``.

    Extra positional and keyword arguments given at construction are forwarded
    to ``init_function`` on every call; its return value is passed through.
    """

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super().__init__()
        assert callable(init_function)
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_bias(self, tensor):
        fn = self.init_function
        return fn(tensor, *self.init_function_args, **self.init_function_kwargs)
class TensorInitFunction(Initializer):
    """Initializer applying ``init_function(tensor, *args, **kwargs)`` to weights and biases.

    Extra positional and keyword arguments given at construction are forwarded
    to ``init_function`` on every call; its return value is passed through.
    """

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super().__init__()
        assert callable(init_function)
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_tensor(self, tensor):
        fn = self.init_function
        return fn(tensor, *self.init_function_args, **self.init_function_kwargs)
class Constant(Initializer):
    """Initialize with a constant: every targeted tensor is filled in place with ``constant``."""

    def __init__(self, constant):
        self.constant = constant

    def call_on_tensor(self, tensor):
        # ``fill_`` works in place and returns the same tensor object.
        return tensor.fill_(self.constant)
class NormalWeights(Initializer):
    """
    Initialize weights with random numbers drawn from the normal distribution at
    `mean` and `stddev`.

    If ``sqrt_gain_over_fan_in`` is given, the standard deviation actually used is
    ``stddev * sqrt(sqrt_gain_over_fan_in / fan_in)``, where ``fan_in`` is the
    product of all weight dimensions except the first. Biases are not touched.
    """

    def __init__(self, mean=0.0, stddev=1.0, sqrt_gain_over_fan_in=None):
        self.mean = mean
        self.stddev = stddev
        self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in

    def compute_fan_in(self, tensor):
        """Product of all dimensions of ``tensor`` except the first."""
        if tensor.dim() == 2:
            return tensor.size(1)
        return np.prod(list(tensor.size())[1:])

    def call_on_weight(self, tensor):
        # Compute stddev if required
        if self.sqrt_gain_over_fan_in is not None:
            stddev = self.stddev * np.sqrt(
                self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor)
            )
        else:
            stddev = self.stddev
        # Init in place; return the tensor like the other initializers do.
        tensor.normal_(self.mean, stddev)
        return tensor
class OrthogonalWeightsZeroBias(Initialization):
    """Orthogonal weights (scaled by ``orthogonal_gain``) and zero biases."""

    def __init__(self, orthogonal_gain=1.0):
        # Prefer the in-place ``orthogonal_`` (prevents a deprecated warning in
        # PyTorch 0.4+). The legacy name is only looked up if ``orthogonal_`` is
        # missing, so this does not depend on the deprecated alias existing.
        orthogonal = getattr(nn.init, "orthogonal_", None) or nn.init.orthogonal
        super(OrthogonalWeightsZeroBias, self).__init__(
            weight_initializer=partial(orthogonal, gain=orthogonal_gain),
            bias_initializer=Constant(0.0),
        )
class KaimingNormalWeightsZeroBias(Initialization):
    """Kaiming-normal weights (negative slope ``relu_leakage``) and zero biases."""

    def __init__(self, relu_leakage=0):
        # Prefer the in-place ``kaiming_normal_`` (prevents a deprecated warning in
        # PyTorch 0.4+). The legacy name is only looked up if ``kaiming_normal_`` is
        # missing, so this does not depend on the deprecated alias existing.
        kaiming_normal = (
            getattr(nn.init, "kaiming_normal_", None) or nn.init.kaiming_normal
        )
        super(KaimingNormalWeightsZeroBias, self).__init__(
            weight_initializer=partial(kaiming_normal, a=relu_leakage),
            bias_initializer=Constant(0.0),
        )
class SELUWeightsZeroBias(Initialization):
    """Zero biases and normal weights with std ``sqrt(1 / fan_in)`` (see ``NormalWeights``)."""

    def __init__(self):
        weight_init = NormalWeights(sqrt_gain_over_fan_in=1.0)
        bias_init = Constant(0.0)
        super().__init__(weight_initializer=weight_init, bias_initializer=bias_init)
class ELUWeightsZeroBias(Initialization):
    """Zero biases and normal weights with std ``sqrt(1.5505188080679277 / fan_in)``.

    NOTE(review): the constant is presumably a variance gain tuned for ELU — confirm.
    """

    def __init__(self):
        weight_init = NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277)
        bias_init = Constant(0.0)
        super().__init__(weight_initializer=weight_init, bias_initializer=bias_init)
class BatchNormND(nn.Module):
    """Dimension-generic batch norm: wraps ``nn.BatchNorm{dim}d`` as ``self.bn``.

    Args:
        dim: 1, 2 or 3; selects ``BatchNorm1d`` / ``BatchNorm2d`` / ``BatchNorm3d``.
        num_features, eps, momentum, affine, track_running_stats: forwarded unchanged.
    """

    def __init__(
        self,
        dim,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()
        assert dim in [1, 2, 3]
        norm_cls = getattr(nn, "BatchNorm{}d".format(dim))
        self.bn = norm_cls(
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        normalized = self.bn(x)
        return normalized
class ConvActivation(nn.Module):
    """Convolutional layer with 'SAME' padding by default followed by an activation.

    Args:
        in_channels, out_channels: channel counts. With ``depthwise=True``,
            ``out_channels`` may be ``None`` or ``"auto"`` and otherwise must
            equal ``in_channels``.
        kernel_size, dilation: int, or a list/tuple with one entry per spatial axis.
            'SAME' padding requires odd kernel sizes.
        dim: number of spatial dimensions (1, 2 or 3).
        activation: name of a ``torch.nn`` class (instantiated without arguments),
            an ``nn.Module`` instance, or ``None`` for no activation.
        stride, bias: forwarded to the convolution.
        groups: convolution groups (default 1; forced to ``in_channels`` when depthwise).
        depthwise: build a depthwise convolution.
        deconv: use ``ConvTranspose{dim}d`` (without padding) instead.
        initialization: optional ``Initializer`` applied to the convolution.
        valid_conv: use an unpadded ``Conv{dim}d``; checked before ``deconv``.

    Raises:
        AssertionError / ValueError: for an invalid ``dim`` or inconsistent
            depthwise settings.
        NotImplementedError: for unsupported ``initialization`` / ``activation`` types.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dim,
        activation,
        stride=1,
        dilation=1,
        groups=None,
        depthwise=False,
        bias=True,
        deconv=False,
        initialization=None,
        valid_conv=False,
    ):
        super(ConvActivation, self).__init__()
        assert_(
            dim in [1, 2, 3],
            "`dim` must be one of [1, 2, 3], got {}.".format(dim),
        )
        self.dim = dim

        if not depthwise:
            if groups is None:
                groups = 1
        else:
            # We know that in_channels == out_channels, but we also want a consistent API.
            # As a compromise, we allow that out_channels be None or 'auto'.
            if out_channels in [None, "auto"]:
                out_channels = in_channels
            assert_(
                in_channels == out_channels,
                "For depthwise convolutions, number of input channels (given: {}) "
                "must equal the number of output channels (given {}).".format(
                    in_channels, out_channels
                ),
                ValueError,
            )
            assert_(
                groups is None or groups == in_channels,
                "For depthwise convolutions, groups (given: {}) must "
                "equal the number of channels (given: {}).".format(groups, in_channels),
            )
            groups = in_channels
        self.depthwise = depthwise

        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        if valid_conv:
            conv_cls = getattr(nn, "Conv{}d".format(self.dim))
        elif deconv:
            conv_cls = getattr(nn, "ConvTranspose{}d".format(self.dim))
        else:
            # 'SAME' padding: output spatial size equals input size for stride 1.
            conv_kwargs["padding"] = self.get_padding(kernel_size, dilation)
            conv_cls = getattr(nn, "Conv{}d".format(self.dim))
        self.conv = conv_cls(**conv_kwargs)

        if initialization is not None:
            if not isinstance(initialization, Initializer):
                raise NotImplementedError
            self.conv.apply(initialization)

        if activation is None:
            self.activation = None
        elif isinstance(activation, str):
            self.activation = getattr(nn, activation)()
        elif isinstance(activation, nn.Module):
            self.activation = activation
        else:
            raise NotImplementedError

    def forward(self, input):
        out = self.conv(input)
        if self.activation is None:
            # No activation
            return out
        return self.activation(out)

    def _pair_or_triplet(self, object_):
        """Broadcast a scalar to ``self.dim`` entries; sequences must already have that length."""
        if not isinstance(object_, (list, tuple)):
            return [object_] * self.dim
        assert len(object_) == self.dim
        return object_

    def _get_padding(self, _kernel_size, _dilation):
        """'SAME' padding for one axis; requires an odd integer kernel size."""
        assert isinstance(_kernel_size, int)
        assert isinstance(_dilation, int)
        assert _kernel_size % 2 == 1
        return _dilation * ((_kernel_size - 1) // 2)

    def get_padding(self, kernel_size, dilation):
        """Per-axis 'SAME' padding as a tuple of length ``self.dim``."""
        kernel_sizes = self._pair_or_triplet(kernel_size)
        dilations = self._pair_or_triplet(dilation)
        return tuple(self._get_padding(k, d) for k, d in zip(kernel_sizes, dilations))


# for consistency
ConvActivationND = ConvActivation
class _BNReLUSomeConv(object):
    """Mixin providing a pre-activation ``forward``: batchnorm -> activation -> conv.

    Expects the host class to define ``batchnorm``, ``activation`` and ``conv``.
    """

    def forward(self, input):
        x = self.batchnorm(input)
        x = self.activation(x)
        return self.conv(x)
class BNReLUConvBaseND(_BNReLUSomeConv, ConvActivation):
    """Pre-activation block: ``BatchNormND`` -> in-place ReLU -> convolution.

    ``ConvActivation`` builds the convolution (Kaiming-normal weights, zero
    bias) and stores the ReLU. ``forward`` comes from ``_BNReLUSomeConv``. The
    batch norm normalizes the ``in_channels`` input channels.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dim,
        stride=1,
        dilation=1,
        deconv=False,
    ):
        relu = nn.ReLU(inplace=True)
        kaiming_init = KaimingNormalWeightsZeroBias(0)
        super(BNReLUConvBaseND, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dim=dim,
            activation=relu,
            stride=stride,
            dilation=dilation,
            deconv=deconv,
            initialization=kaiming_init,
        )
        self.batchnorm = BatchNormND(dim, in_channels)
def _register_bnr_conv_cls(conv_name, fix=None, default=None):
    """Register BatchNorm-ReLU-conv variants of ``BNReLUConvBaseND``.

    Registers ``BNReLU<conv_name>ND`` (``dim`` remains a constructor argument) and
    ``BNReLU<conv_name><d>D`` with ``dim=d`` fixed, for ``d`` in 1, 2, 3, in
    ``generated_inferno_classes``.

    Args:
        conv_name: class-name stem, e.g. ``"Conv"``.
        fix: constructor arguments fixed for every variant.
        default: constructor defaults for every variant.
    """
    if fix is None:
        fix = {}
    if default is None:
        default = {}
    # The N-D variant does not depend on ``dim``, so register it exactly once.
    register_partial_cls(
        BNReLUConvBaseND, "BNReLU{}ND".format(conv_name), fix=fix, default=default
    )
    for dim in [1, 2, 3]:
        cls_name = "BNReLU{}{}D".format(conv_name, dim)
        register_partial_cls(
            BNReLUConvBaseND, cls_name, fix={**fix, "dim": dim}, default=default
        )
def _register_conv_cls(conv_name, fix=None, default=None):
    """Register ``ConvActivation`` variants ``<conv_name><activation>{ND,1D,2D,3D}``.

    Covers the activations ``ReLU``, ``ELU``, ``Sigmoid``, ``SELU`` and ``""`` (none).

    * ``<conv_name><activation>ND`` keeps ``dim`` as a constructor argument.
    * ``<conv_name><activation><d>D`` fixes ``dim=d`` for ``d`` in 1, 2, 3.

    Named activations are fixed. The ``""`` variant instead defaults
    ``activation`` to ``None``, so callers may still pass one. The default
    ``initialization`` is ``KaimingNormalWeightsZeroBias`` for ReLU,
    ``SELUWeightsZeroBias`` for SELU and ``OrthogonalWeightsZeroBias`` otherwise.

    Args:
        conv_name: class-name stem, e.g. ``"Conv"`` or ``"ValidConv"``.
        fix: constructor arguments fixed for every variant.
        default: constructor defaults for every variant.
    """
    if fix is None:
        fix = {}
    if default is None:
        default = {}
    # simple conv activation
    activations = ["ReLU", "ELU", "Sigmoid", "SELU", ""]
    init_map = {"ReLU": KaimingNormalWeightsZeroBias, "SELU": SELUWeightsZeroBias}
    for activation_str in activations:
        initialization_cls = init_map.get(activation_str, OrthogonalWeightsZeroBias)
        if activation_str == "":
            # No fixed activation. Keep the caller's defaults as well (they used to
            # be dropped for this variant only).
            _fix = {**fix}
            _default = {**default, "activation": None}
        else:
            # NOTE: one SELU module instance is shared by all SELU variants registered
            # here and by every instance of them; other activations are passed by
            # name and instantiated per layer by ConvActivation.
            if activation_str == "SELU":
                activation = nn.SELU(inplace=True)
            else:
                activation = activation_str
            _fix = {**fix, "activation": activation}
            _default = {**default}
        register_partial_cls(
            ConvActivation,
            "{}{}ND".format(conv_name, activation_str),
            fix=_fix,
            default={**_default, "initialization": initialization_cls()},
        )
        for dim in [1, 2, 3]:
            register_partial_cls(
                ConvActivation,
                "{}{}{}D".format(conv_name, activation_str, dim),
                fix={**_fix, "dim": dim},
                default={**_default, "initialization": initialization_cls()},
            )
# Populate ``generated_inferno_classes`` with plain ("Conv...") and unpadded
# ("ValidConv...") conv + activation variants, e.g. "ConvReLU2D", "ValidConv3D".
_register_conv_cls("Conv")
_register_conv_cls("ValidConv", fix={"valid_conv": True})
# Module-level aliases for the two variants used by HyLFM_Net.
Conv2D = generated_inferno_classes["Conv2D"]
ValidConv3D = generated_inferno_classes["ValidConv3D"]
### HyLFM architecture
class Crop(nn.Module):
    """Index the input with a fixed tuple of slices (``...`` is allowed as an entry)."""

    def __init__(self, *slices: slice):
        super().__init__()
        self.slices = slices

    def extra_repr(self):
        return "{}".format(self.slices)

    def forward(self, input):
        cropped = input[self.slices]
        return cropped
class ChannelFromLightField(nn.Module):
    """Rearrange a light-field image so that each lenslet position becomes a channel.

    Input ``(b, 1, x, y)`` with ``x`` and ``y`` divisible by ``nnum`` becomes
    ``(b, nnum**2, x // nnum, y // nnum)``. Channel ``i * nnum + j`` holds the
    pixel at offset ``(i, j)`` inside every ``nnum x nnum`` block.
    """

    def __init__(self, nnum: int):
        super().__init__()
        self.nnum = nnum

    def forward(self, tensor):
        assert len(tensor.shape) == 4, tensor.shape
        b, c, x, y = tensor.shape
        assert c == 1
        n = self.nnum
        assert x % n == 0, (x, n)
        assert y % n == 0, (y, n)
        rows, cols = x // n, y // n
        # (b, rows, i, cols, j) -> (b, i, j, rows, cols) -> (b, i * n + j, rows, cols)
        blocks = tensor.reshape(b, rows, n, cols, n)
        return blocks.permute(0, 2, 4, 1, 3).reshape(b, n**2, rows, cols)
class ResnetBlock(nn.Module):
    """Residual block of ``conv_per_block`` convolutions plus a shortcut, followed by ReLU.

    Main path (``self.block``):
        * without ``batch_norm``: ``conv_per_block - 1`` conv + ``activation`` layers,
          then one conv without activation;
        * with ``batch_norm``: ``conv_per_block`` BNReLU-conv layers.

    Shortcut: the input itself, or a 1x1 conv (``self.projection_layer``) when
    the channel count changes. With ``valid=True`` the convolutions are unpadded,
    so the shortcut is center-cropped by ``conv_per_block * (k // 2)`` per side
    and axis. This assumes odd kernel sizes.

    Args:
        in_n_filters, n_filters: input / output channels.
        kernel_size: tuple with one entry per spatial axis (its length selects 1D/2D/3D).
        batch_norm: use the ``BNReLU*`` conv classes (ReLU activation only).
        conv_per_block: number of convolutions on the main path (>= 2).
        valid: use unpadded ("valid") convolutions.
        activation: activation name of the inner convolutions.

    Raises:
        NotImplementedError: if ``batch_norm`` is combined with a non-ReLU activation.
        KeyError: if the required class is missing from ``generated_inferno_classes``.
            NOTE(review): the module-level code visible here does not register the
            ``BNReLU*`` classes, so ``batch_norm=True`` presumably fails this way — confirm.
    """

    def __init__(
        self,
        in_n_filters,
        n_filters,
        kernel_size=(3, 3),
        batch_norm=False,
        conv_per_block=2,
        valid: bool = False,
        activation: str = "ReLU",
    ):
        super().__init__()
        if batch_norm and activation != "ReLU":
            raise NotImplementedError("batch_norm with non ReLU activation")
        assert isinstance(kernel_size, tuple), kernel_size
        assert conv_per_block >= 2
        self.debug = False  # sys.gettrace() is not None

        n_dims = len(kernel_size)
        prefix = f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv"
        # Inner layers carry the activation (BNReLU classes already include BN + ReLU).
        # The last layer has none: the ReLU is applied after the residual sum.
        Conv = generated_inferno_classes[
            f"{prefix}{'' if batch_norm else activation}{n_dims}D"
        ]
        FinalConv = generated_inferno_classes[f"{prefix}{n_dims}D"]

        layers = [
            Conv(
                in_channels=in_n_filters,
                out_channels=n_filters,
                kernel_size=kernel_size,
            )
        ]
        for _ in range(conv_per_block - 2):
            layers.append(Conv(n_filters, n_filters, kernel_size))
        layers.append(FinalConv(n_filters, n_filters, kernel_size))
        self.block = nn.Sequential(*layers)

        if n_filters != in_n_filters:
            # 1x1 conv (no activation) to match channel counts on the shortcut.
            ProjConv = generated_inferno_classes[f"Conv{n_dims}D"]
            self.projection_layer = ProjConv(in_n_filters, n_filters, kernel_size=1)
        else:
            self.projection_layer = None

        if valid:
            crop_each_side = [conv_per_block * (ks // 2) for ks in kernel_size]
            # ``-c or None``: a zero crop must keep the full axis; slice(0, -0) would be empty.
            self.crop = Crop(..., *[slice(c, -c or None) for c in crop_each_side])
        else:
            self.crop = None

        self.relu = nn.ReLU()

    def forward(self, input):
        x = self.block(input)
        if self.crop is not None:
            input = self.crop(input)
        if self.projection_layer is None:
            shortcut = input
        else:
            shortcut = self.projection_layer(input)
        return self.relu(x + shortcut)
class HyLFM_Net(nn.Module):
    """HyLFM network mapping a light-field image ``(b, 1, x, y)`` to a 3D volume.

    ``forward`` applies, in order:

    1. ``channel_from_lf``: each ``nnum x nnum`` lenslet block becomes ``nnum**2``
       channels (spatial size divided by ``nnum``);
    2. ``res2d``: 2D residual blocks ('SAME' padding) and 2x transposed-conv upsamplings;
    3. ``conv2d``: convolution to ``z * c_in_3d`` channels followed by ReLU;
    4. ``c2z``: view those channels as ``(c_in_3d, z)``;
    5. ``res3d``: unpadded ("valid") 3D residual blocks and transposed-conv
       upsamplings (2x in the last two axes, z unchanged);
    6. ``conv3d``: 1x1x1 convolution to a single output channel;
    7. the optional ``final_activation``.

    Each entry of ``c_res2d`` / ``c_res3d`` is either a channel count (residual
    block) or ``"u<channels>"`` (upsampling layer with ``<channels>`` outputs).
    ``z`` is chosen larger than ``z_out`` so that the valid 3D blocks crop it
    back to exactly ``z_out``.

    Args:
        z_out: number of output z planes.
        nnum: lenslet size in pixels; input height/width must be divisible by it.
        kernel2d, conv_per_block2d: kernel size / convolutions per 2D residual block.
        c_res2d: 2D layer specification (see above).
        last_kernel2d: kernel size of ``conv2d``.
        c_in_3d: channels per z plane entering the 3D part.
        kernel3d, conv_per_block3d: kernel size / convolutions per 3D residual
            block (only ``kernel3d == 3`` is supported).
        c_res3d: 3D layer specification (see above).
        init_fn: name of the ``torch.nn.init`` function used for ``conv2d``/``conv3d``.
        final_activation: ``None`` or ``"sigmoid"``.
    """

    class InitName(str, Enum):
        uniform_ = "uniform"
        normal_ = "normal"
        constant_ = "constant"
        eye_ = "eye"
        dirac_ = "dirac"
        xavier_uniform_ = "xavier_uniform"
        xavier_normal_ = "xavier_normal"
        kaiming_uniform_ = "kaiming_uniform"
        kaiming_normal_ = "kaiming_normal"
        orthogonal_ = "orthogonal"
        sparse_ = "sparse"

    def __init__(
        self,
        *,
        z_out: int,
        nnum: int,
        kernel2d: int = 3,
        conv_per_block2d: int = 2,
        c_res2d: Sequence[Union[int, str]] = (488, 488, "u244", 244),
        last_kernel2d: int = 1,
        c_in_3d: int = 7,
        kernel3d: int = 3,
        conv_per_block3d: int = 2,
        c_res3d: Sequence[Union[int, str]] = (7, "u7", 7, 7),
        init_fn: Union[InitName, str] = InitName.xavier_uniform_.value,
        final_activation: Optional[str] = None,
    ):
        super().__init__()
        self.channel_from_lf = ChannelFromLightField(nnum=nnum)
        init_name = self.InitName(init_fn).value
        # Prefer the in-place variant (trailing underscore): prevents deprecation warning.
        init_fn = getattr(nn.init, f"{init_name}_", None) or getattr(nn.init, init_name)

        self.c_res2d = list(c_res2d)
        self.c_res3d = list(c_res3d)
        self.nnum = nnum
        self.z_out = z_out

        if kernel3d != 3:
            raise NotImplementedError("z_out expansion for other res3d kernel")
        # Pixels removed per side (in every spatial axis) by one valid 3D residual block.
        self._shrink_per_res3d = conv_per_block3d * (kernel3d // 2)
        # Enlarge the z extent produced by conv2d so the valid 3D blocks crop it to z_out.
        n_valid_res3d = sum(not self._is_upsampling(c) for c in self.c_res3d)
        z_out += 2 * self._shrink_per_res3d * n_valid_res3d

        assert (
            not self.c_res2d or self.c_res2d[-1] != "u"
        ), "missing # output channels for upsampling in 'c_res2d'"
        assert (
            not self.c_res3d or self.c_res3d[-1] != "u"
        ), "missing # output channels for upsampling in 'c_res3d'"

        res2d = []
        c_in = nnum**2
        c_out = c_in
        for c in self.c_res2d:
            if self._is_upsampling(c):
                c_out = int(c[1:])
                res2d.append(
                    nn.ConvTranspose2d(
                        in_channels=c_in,
                        out_channels=c_out,
                        kernel_size=2,
                        stride=2,
                        padding=0,
                        output_padding=0,
                    )
                )
            else:
                c_out = int(c)
                res2d.append(
                    ResnetBlock(
                        in_n_filters=c_in,
                        n_filters=c_out,
                        kernel_size=(kernel2d, kernel2d),
                        valid=False,
                        conv_per_block=conv_per_block2d,
                    )
                )
            c_in = c_out
        self.res2d = nn.Sequential(*res2d)

        init = Initialization(
            weight_initializer=self._with_gain(init_fn, "relu"),
            bias_initializer=Constant(0.0),
        )
        self.conv2d = Conv2D(
            c_out,
            z_out * c_in_3d,
            last_kernel2d,
            activation="ReLU",
            initialization=init,
        )
        self.c2z = lambda ipt, ip3=c_in_3d: ipt.view(
            ipt.shape[0], ip3, z_out, *ipt.shape[2:]
        )

        res3d = []
        c_in = c_in_3d
        c_out = c_in
        for c in self.c_res3d:
            if self._is_upsampling(c):
                c_out = int(c[1:])
                res3d.append(
                    nn.ConvTranspose3d(
                        in_channels=c_in,
                        out_channels=c_out,
                        kernel_size=(3, 2, 2),
                        stride=(1, 2, 2),
                        padding=(1, 0, 0),
                        output_padding=0,
                    )
                )
            else:
                c_out = int(c)
                res3d.append(
                    ResnetBlock(
                        in_n_filters=c_in,
                        n_filters=c_out,
                        kernel_size=(kernel3d, kernel3d, kernel3d),
                        valid=True,
                        conv_per_block=conv_per_block3d,
                    )
                )
            c_in = c_out
        self.res3d = nn.Sequential(*res3d)

        init = Initialization(
            weight_initializer=self._with_gain(init_fn, "linear"),
            bias_initializer=Constant(0.0),
        )
        self.conv3d = ValidConv3D(c_out, 1, (1, 1, 1), initialization=init)

        if final_activation is None:
            self.final_activation = None
        elif final_activation == "sigmoid":
            self.final_activation = nn.Sigmoid()
        else:
            raise NotImplementedError(final_activation)

    @staticmethod
    def _is_upsampling(layer_spec) -> bool:
        """Whether a ``c_res2d``/``c_res3d`` entry is an upsampling layer (``"u<channels>"``)."""
        return isinstance(layer_spec, str) and layer_spec.startswith("u")

    @staticmethod
    def _with_gain(init_fn, nonlinearity: str):
        """Bind ``gain=calculate_gain(nonlinearity)`` if ``init_fn`` has a ``gain`` parameter."""
        if "gain" in inspect.signature(init_fn).parameters:
            return partial(init_fn, gain=nn.init.calculate_gain(nonlinearity))
        return init_fn

    def forward(self, x):
        x = self.channel_from_lf(x)
        x = self.res2d(x)
        x = self.conv2d(x)
        x = self.c2z(x)
        x = self.res3d(x)
        x = self.conv3d(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        return x

    def get_scale(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int:
        """Total upsampling factor in the last two axes; ``ipt_shape`` is unused.

        Every ``"u..."`` entry in ``c_res2d`` and ``c_res3d`` doubles the
        resolution, so the factor is ``2 ** n_upsamplings``.
        """
        n_up = sum(self._is_upsampling(c) for c in self.c_res2d) + sum(
            self._is_upsampling(c) for c in self.c_res3d
        )
        return 2**n_up

    def get_shrink(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int:
        """Pixels lost per side in the last two axes, at output resolution.

        Each valid 3D residual block removes ``conv_per_block3d * (kernel3d // 2)``
        pixels per side. A later upsampling layer doubles everything lost before it.
        ``ipt_shape`` is unused.
        """
        shrink = 0
        for c in self.c_res3d:
            if self._is_upsampling(c):
                shrink *= 2
            else:
                shrink += self._shrink_per_res3d
        return shrink

    def get_output_shape(self, ipt_shape: Tuple[int, int]) -> Tuple[int, int, int]:
        scale = self.get_scale(ipt_shape)
        shrink = self.get_shrink(ipt_shape)
        return (self.z_out,) + tuple(i * scale - 2 * shrink for i in ipt_shape)
if __name__ == "__main__":
    # Example usage: build a small HyLFM_Net, print its layers and the output
    # shape reported for ipt_shape=(64, 64).
    example_kwargs = dict(
        z_out=9,
        nnum=5,
        kernel2d=3,
        conv_per_block2d=2,
        c_res2d=(12, 14, "u14", 8),
        last_kernel2d=1,
        c_in_3d=7,
        kernel3d=3,
        conv_per_block3d=2,
        c_res3d=(7, "u7", 7, 7),
        init_fn="xavier_uniform",
        final_activation="sigmoid",
    )
    model = HyLFM_Net(**example_kwargs)
    print(model)
    print(model.get_output_shape((64, 64)))