"""Utilities for preparing datasets and inference inputs for the binary coronary bundle.""" import logging import multiprocessing import os import resource from typing import Dict, List, Optional, Union import monai # type: ignore import psutil import torch # type: ignore logger = logging.getLogger(__name__) USE_AMP = monai.utils.get_torch_version_tuple() >= (1, 6) # type: ignore def num_workers() -> int: """Return a conservative worker count based on CPU, RAM and file descriptor limits.""" n_workers = max(multiprocessing.cpu_count() - 1, 1) available_ram_in_gb = psutil.virtual_memory()[0] / 1024**3 max_workers = int(available_ram_in_gb // 4) if max_workers < n_workers: n_workers = max_workers soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) resource.setrlimit(resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)) max_workers = soft_limit // 256 if max_workers < n_workers: logger.info( "Open file handle limit (%s) constrains multiprocessing; switching to file_system sharing.", soft_limit, ) n_workers = min(16, n_workers) torch.multiprocessing.set_sharing_strategy("file_system") # Cap at 16 workers max regardless of system resources n_workers = min(16, n_workers) logger.info("using number of workers: %s", n_workers) return n_workers IMAGE_FILE_EXT = [".nii", ".nii.gz", ".nrrd", ".dcm"] def parse_data_for_inference(fn_or_dir: Optional[str] = None) -> Union[None, List[Dict[str, str]]]: """Convert filepath to data_dict for inference runs.""" if not fn_or_dir: return None if os.path.isfile(fn_or_dir): data_dict = [{"image": fn_or_dir}] elif os.path.isdir(fn_or_dir): files = sorted( [fn for fn in os.listdir(fn_or_dir) if any(fn.endswith(ext) for ext in IMAGE_FILE_EXT)] ) data_dict = [{"image": os.path.join(fn_or_dir, fn)} for fn in files] else: raise FileNotFoundError(fn_or_dir) return data_dict