# Source: Hugging Face repository of user "kbressem", commit 6c85a1c (verified).
# Uploaded as ct_binary_coronary_segmentation/scripts/utils.py via huggingface_hub.
"""Utilities for preparing datasets and inference inputs for the binary coronary bundle."""
import logging
import multiprocessing
import os
import resource
from typing import Dict, List, Optional, Union
import monai # type: ignore
import psutil
import torch # type: ignore
# Module-level logger following the standard `getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
# Automatic mixed precision requires torch >= 1.6; MONAI exposes the version tuple.
USE_AMP = monai.utils.get_torch_version_tuple() >= (1, 6)  # type: ignore
def num_workers() -> int:
    """Return a conservative DataLoader worker count.

    The count starts at ``cpu_count() - 1`` and is then reduced so that
    (a) roughly 4 GB of *available* RAM is budgeted per worker and
    (b) the process' open-file-descriptor soft limit (~256 descriptors per
    worker) is not exhausted.  When the descriptor limit is the constraint,
    PyTorch's ``file_system`` sharing strategy is enabled so workers do not
    hold one descriptor per shared tensor.  The result is capped at 16.

    Returns:
        Number of workers to pass to a DataLoader (0..16).
    """
    n_workers = max(multiprocessing.cpu_count() - 1, 1)

    # BUG FIX: ``psutil.virtual_memory()[0]`` is the *total* system RAM;
    # the 4 GB-per-worker budget must be computed from *available* memory,
    # otherwise workers are spawned that cannot actually allocate.
    available_ram_in_gb = psutil.virtual_memory().available / 1024**3
    n_workers = min(n_workers, int(available_ram_in_gb // 4))

    soft_limit, _hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    try:
        # Allow full core dumps (debugging aid for crashed workers).  This
        # raises ValueError/OSError for unprivileged processes whose hard
        # limit is finite; that must not break worker-count selection.
        resource.setrlimit(resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY))
    except (ValueError, OSError):
        logger.debug("could not raise RLIMIT_CORE; continuing")

    # Budget ~256 open file descriptors per worker.
    if soft_limit // 256 < n_workers:
        logger.info(
            "Open file handle limit (%s) constrains multiprocessing; switching to file_system sharing.",
            soft_limit,
        )
        torch.multiprocessing.set_sharing_strategy("file_system")

    # Cap at 16 workers max regardless of system resources.
    n_workers = min(16, n_workers)
    logger.info("using number of workers: %s", n_workers)
    return n_workers
# Recognized medical-image file extensions for inference inputs.
IMAGE_FILE_EXT = [".nii", ".nii.gz", ".nrrd", ".dcm"]


def parse_data_for_inference(fn_or_dir: Optional[str] = None) -> Union[None, List[Dict[str, str]]]:
    """Build a MONAI-style data list (``[{"image": path}, ...]``) from a path.

    Args:
        fn_or_dir: Path to a single image file, or to a directory whose
            image files (matching ``IMAGE_FILE_EXT``) are collected in
            sorted order.  Falsy input (``None`` or ``""``) yields ``None``.

    Returns:
        A list of single-key dicts mapping ``"image"`` to a file path,
        or ``None`` when no input was given.

    Raises:
        FileNotFoundError: If ``fn_or_dir`` is neither an existing file
            nor an existing directory.
    """
    if not fn_or_dir:
        return None
    if os.path.isfile(fn_or_dir):
        return [{"image": fn_or_dir}]
    if os.path.isdir(fn_or_dir):
        # str.endswith accepts a tuple of suffixes, so one call covers
        # every recognized extension.
        matches = (
            name for name in os.listdir(fn_or_dir) if name.endswith(tuple(IMAGE_FILE_EXT))
        )
        return [{"image": os.path.join(fn_or_dir, name)} for name in sorted(matches)]
    raise FileNotFoundError(fn_or_dir)