from typing import Union, Tuple, List

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from torch import nn


class nnInteractiveTrainer_stub:
    """Stub trainer exposing only the network-construction hook.

    The class does not inherit from ``nnUNetTrainer``; its single static
    method forwards to ``nnUNetTrainer.build_network_architecture`` with
    adjusted input/output channel counts.
    """

    def __init__(self, *args, **kwargs):
        """Accept any arguments and ignore them; no state is initialised."""

    @staticmethod
    def build_network_architecture(architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        """Build the network via ``nnUNetTrainer.build_network_architecture``.

        All arguments are forwarded positionally and unchanged, except:

        * the input channel count is increased by 7, and
        * ``num_output_channels`` is ignored; 2 is always passed instead.

        Args:
            architecture_class_name: Forwarded as-is.
            arch_init_kwargs: Forwarded as-is.
            arch_init_kwargs_req_import: Forwarded as-is.
            num_input_channels: Base channel count; 7 is added before
                forwarding.
            num_output_channels: Accepted for signature compatibility but
                not used.
            enable_deep_supervision: Forwarded as-is.

        Returns:
            Whatever ``nnUNetTrainer.build_network_architecture`` returns
            (annotated as ``nn.Module``).
        """
        # NOTE(review): the 7 extra channels are presumably the additional
        # interaction inputs fed alongside the image -- confirm with the owner.
        widened_input_channels = num_input_channels + 7

        # nnunet handles one class segmentation still as CE so we need 2 outputs.
        fixed_output_channels = 2

        return nnUNetTrainer.build_network_architecture(
            architecture_class_name,
            arch_init_kwargs,
            arch_init_kwargs_req_import,
            widened_input_channels,
            fixed_output_channels,
            enable_deep_supervision,
        )