| """Utilities for preparing datasets and inference inputs for the binary coronary bundle.""" |
|
|
| import logging |
| import multiprocessing |
| import os |
| import resource |
|
|
| from typing import Dict, List, Optional, Union |
|
|
| import monai |
| import psutil |
| import torch |
|
|
|
|
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)


# Automatic mixed precision (torch.cuda.amp) requires PyTorch >= 1.6.
USE_AMP = monai.utils.get_torch_version_tuple() >= (1, 6)
|
|
|
|
def num_workers() -> int:
    """Return a conservative worker count based on CPU, RAM and file descriptor limits.

    Heuristics applied, in order:
      * start from ``cpu_count() - 1`` (keep one core free for the main process);
      * allow roughly one worker per 4 GB of *available* RAM;
      * if the soft open-file limit would be exceeded (~256 descriptors per
        worker), switch torch multiprocessing to ``file_system`` sharing,
        which does not consume a descriptor per shared tensor;
      * cap the result at 16.

    Returns:
        int: number of DataLoader workers to use (may be 0 on tiny machines).
    """
    n_workers = max(multiprocessing.cpu_count() - 1, 1)

    # Budget ~4 GB of RAM per worker. Use .available (not index 0, which is
    # the machine's *total* RAM) so an already-loaded host does not spawn
    # more workers than it can actually feed.
    available_ram_in_gb = psutil.virtual_memory().available / 1024**3
    max_workers = int(available_ram_in_gb // 4)
    if max_workers < n_workers:
        n_workers = max_workers

    # Each worker consumes file descriptors; reserve ~256 per worker.
    # NOTE(review): a stray setrlimit(RLIMIT_CORE, RLIM_INFINITY) call was
    # removed here — it only enables unlimited core dumps, is unrelated to
    # worker counting, and can raise ValueError on hosts with a finite hard
    # core limit.
    soft_limit, _hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    max_workers = soft_limit // 256

    if max_workers < n_workers:
        logger.info(
            "Open file handle limit (%s) constrains multiprocessing; switching to file_system sharing.",
            soft_limit,
        )
        # With file_system sharing the descriptor limit no longer binds, so
        # keep the RAM/CPU-derived count (still subject to the cap below).
        torch.multiprocessing.set_sharing_strategy("file_system")

    n_workers = min(16, n_workers)
    logger.info("using number of workers: %s", n_workers)

    return n_workers
|
|
|
|
# File extensions accepted when scanning a directory for inference inputs.
IMAGE_FILE_EXT = [".nii", ".nii.gz", ".nrrd", ".dcm"]


def parse_data_for_inference(fn_or_dir: Optional[str] = None) -> Union[None, List[Dict[str, str]]]:
    """Build a list of ``{"image": path}`` dicts suitable for inference.

    Args:
        fn_or_dir: path to a single image file, or a directory to scan for
            files whose names end in one of :data:`IMAGE_FILE_EXT`.

    Returns:
        ``None`` when no path is given; otherwise one dict per image, sorted
        by filename for directory inputs.

    Raises:
        FileNotFoundError: if ``fn_or_dir`` is neither a file nor a directory.
    """
    if not fn_or_dir:
        return None

    if os.path.isfile(fn_or_dir):
        return [{"image": fn_or_dir}]

    if os.path.isdir(fn_or_dir):
        # str.endswith accepts a tuple of suffixes, so a single call covers
        # every recognized extension (".nii.gz" included).
        suffixes = tuple(IMAGE_FILE_EXT)
        matches = sorted(name for name in os.listdir(fn_or_dir) if name.endswith(suffixes))
        return [{"image": os.path.join(fn_or_dir, name)} for name in matches]

    raise FileNotFoundError(fn_or_dir)
|
|