File size: 2,081 Bytes
65366a8 62785f9 65366a8 6c85a1c 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 62785f9 65366a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | """Utilities for preparing datasets and inference inputs for the binary coronary bundle."""
import logging
import multiprocessing
import os
import resource
from typing import Dict, List, Optional, Union
import monai # type: ignore
import psutil
import torch # type: ignore
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# True when the installed torch version (as reported by MONAI) is 1.6 or newer,
# the release that introduced native automatic mixed precision (torch.cuda.amp).
USE_AMP = (1, 6) <= monai.utils.get_torch_version_tuple()  # type: ignore
def num_workers() -> int:
    """Return a conservative worker count based on CPU, RAM and file descriptor limits.

    The count starts at ``cpu_count() - 1`` (never below 1). It is then limited to
    one worker per 4 GiB of *total* system RAM, and finally capped at 16. On
    machines with less than 4 GiB of RAM the result is therefore 0.

    Side effects (process-wide):
        * Attempts to raise the core-dump size limit (``RLIMIT_CORE``) to unlimited.
          If the OS refuses (e.g. the hard limit is finite and the process is
          unprivileged), the failure is logged at debug level and ignored.
        * If the open-file soft limit cannot cover 256 handles per worker, switches
          ``torch.multiprocessing`` to the ``"file_system"`` sharing strategy.

    Returns:
        The number of workers to use.
    """
    # Leave one CPU for the main process, but never go below one worker.
    n_workers = max(multiprocessing.cpu_count() - 1, 1)

    # virtual_memory().total is field 0 of psutil's named tuple: total physical RAM.
    total_ram_in_gb = psutil.virtual_memory().total / 1024**3
    # Budget roughly 4 GiB of RAM per worker (may yield 0 on small machines).
    n_workers = min(n_workers, int(total_ram_in_gb // 4))

    soft_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE)

    # NOTE(review): enabling unlimited core dumps is unrelated to the worker count;
    # confirm it is still wanted here. setrlimit raises ValueError when asked for
    # RLIM_INFINITY while the hard limit is finite (and OSError if the syscall
    # fails), so it is attempted best-effort instead of crashing the caller.
    try:
        resource.setrlimit(
            resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
        )
    except (ValueError, OSError) as err:
        logger.debug("Could not raise core dump size limit: %s", err)

    # Budget 256 open file handles per worker. If the soft limit cannot cover the
    # chosen worker count, switch torch to file-name-based tensor sharing instead
    # of reducing the number of workers.
    max_workers_by_files = soft_limit // 256
    if max_workers_by_files < n_workers:
        logger.info(
            "Open file handle limit (%s) constrains multiprocessing; switching to file_system sharing.",
            soft_limit,
        )
        torch.multiprocessing.set_sharing_strategy("file_system")

    # Cap at 16 workers max regardless of system resources.
    n_workers = min(16, n_workers)
    logger.info("using number of workers: %s", n_workers)
    return n_workers
# Filename suffixes treated as images when a directory is scanned for inference
# inputs. Matching is a case-sensitive suffix test, so ".nii.gz" has to be listed
# separately from ".nii".
IMAGE_FILE_EXT = [
    ".nii",
    ".nii.gz",
    ".nrrd",
    ".dcm",
]
def parse_data_for_inference(fn_or_dir: Optional[str] = None) -> Union[None, List[Dict[str, str]]]:
    """Convert a file or directory path into a list of data dicts for inference.

    Args:
        fn_or_dir: Path to a single image file or to a directory containing images.
            ``None`` or an empty string means "no input".

    Returns:
        ``None`` when ``fn_or_dir`` is falsy.
        ``[{"image": fn_or_dir}]`` when it is an existing file; the extension is not checked.
        For a directory, one ``{"image": <dir>/<name>}`` dict per entry whose name ends
        with a suffix in ``IMAGE_FILE_EXT``, ordered by name. Entries are not checked to
        be regular files.

    Raises:
        FileNotFoundError: if ``fn_or_dir`` is neither an existing file nor a directory.
    """
    if not fn_or_dir:
        return None

    if os.path.isfile(fn_or_dir):
        return [{"image": fn_or_dir}]

    if not os.path.isdir(fn_or_dir):
        raise FileNotFoundError(fn_or_dir)

    # str.endswith accepts a tuple, testing every suffix in one call.
    suffixes = tuple(IMAGE_FILE_EXT)
    image_names = sorted(name for name in os.listdir(fn_or_dir) if name.endswith(suffixes))
    return [{"image": os.path.join(fn_or_dir, name)} for name in image_names]
|