| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import csv |
| import gc |
| import logging |
| import multiprocessing as mp |
| import os |
| import shutil |
| import sys |
| import time |
| import warnings |
| from datetime import datetime, timedelta |
| from pathlib import Path |
| from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union |
|
|
| import numpy as np |
| import psutil |
| import torch |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| import yaml |
| from torch.amp import GradScaler, autocast |
| from torch.nn.parallel import DistributedDataParallel |
| from torch.utils.data.distributed import DistributedSampler |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| from monai.apps.auto3dseg.auto_runner import logger |
| from monai.apps.auto3dseg.transforms import EnsureSameShaped |
| from monai.auto3dseg.utils import datafold_read |
| from monai.bundle.config_parser import ConfigParser |
| from monai.config import KeysCollection |
| from monai.data import CacheDataset, DataLoader, Dataset, DistributedSampler, decollate_batch, list_data_collate |
| from monai.inferers import SlidingWindowInfererAdapt |
| from monai.losses import DeepSupervisionLoss |
| from monai.metrics import CumulativeAverage, DiceHelper |
| from monai.networks.layers.factories import split_args |
| from monai.optimizers.lr_scheduler import WarmupCosineSchedule |
| from monai.transforms import ( |
| AsDiscreted, |
| CastToTyped, |
| ClassesToIndicesd, |
| Compose, |
| ConcatItemsd, |
| CopyItemsd, |
| CropForegroundd, |
| DataStatsd, |
| DeleteItemsd, |
| EnsureTyped, |
| Identityd, |
| Invertd, |
| Lambdad, |
| LoadImaged, |
| NormalizeIntensityd, |
| Orientationd, |
| RandAdjustContrastd, |
| RandAffined, |
| RandCropByLabelClassesd, |
| RandFlipd, |
| RandGaussianNoised, |
| RandGaussianSmoothd, |
| RandHistogramShiftd, |
| RandIdentity, |
| RandRotate90d, |
| RandScaleIntensityd, |
| RandScaleIntensityFixedMeand, |
| RandShiftIntensityd, |
| RandSpatialCropd, |
| ResampleToMatchd, |
| SaveImaged, |
| ScaleIntensityRanged, |
| Spacingd, |
| SpatialPadd, |
| ToDeviced, |
| ) |
| from monai.transforms.transform import MapTransform |
| from monai.utils import ImageMetaKey, convert_to_dst_type, optional_import, set_determinism |
|
|
# Optional MLflow integration: `mlflow_is_imported` records whether the import succeeded.
mlflow, mlflow_is_imported = optional_import("mlflow")


# Configure the CUDA caching allocator through its environment variable.
# NOTE(review): presumably this has to be set before the first CUDA allocation to take effect - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:2048"
# Rebind the module-level name `print`, so every print(...) call in this file is routed to logger.debug.
print = logger.debug
# Optional progress-bar dependency; `has_tqdm` records whether it is available.
tqdm, has_tqdm = optional_import("tqdm", name="tqdm")


# Import the sibling `utils` module absolutely when this file runs outside a package, relatively otherwise.
if __package__ in (None, ""):
    from utils import auto_adjust_network_settings, logger_configure
else:
    from .utils import auto_adjust_network_settings, logger_configure
|
|
|
|
class LabelEmbedClassIndex(MapTransform):
    """
    Re-encode label tensors according to ``class_index``.

    ``class_index`` is a list of groups of label values. For every group one output
    slice is produced by summing the masks ``x == value`` over the values of the group;
    the slices are concatenated along dimension 0.
    """

    def __init__(
        self, keys: KeysCollection = "label", allow_missing_keys: bool = False, class_index: Optional[List] = None
    ) -> None:
        """
        Args:
            keys: keys of the items to re-encode.
            allow_missing_keys: do not raise exception if key is missing.
            class_index: a list of groups of label values; when ``None`` the transform is a no-op.
        """
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        self.class_index = class_index

    def label_mapping(self, x: torch.Tensor) -> torch.Tensor:
        """
        Build one slice per group in ``class_index`` and concatenate them along dim 0.

        The result is cast back to the dtype of ``x``.
        NOTE(review): presumably ``x`` is a channel-first label with a single channel - confirm.
        """
        source_dtype = x.dtype
        per_group = []
        for group in self.class_index:
            # summing the equality masks merges every value of the group into one slice
            per_group.append(sum([x == value for value in group]))
        stacked = torch.cat(per_group, dim=0)
        return stacked.to(dtype=source_dtype)

    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        """Return a shallow copy of ``data`` with every selected key re-encoded."""
        out = dict(data)
        if self.class_index is None:
            return out
        for key in self.key_iterator(out):
            out[key] = self.label_mapping(out[key])
        return out
|
|
|
|
def schedule_validation_epochs(num_epochs, num_epochs_per_validation=None, fraction=0.16) -> list:
    """
    Schedule of epochs to validate (progressively more frequently).

    Args:
        num_epochs: total number of epochs.
        num_epochs_per_validation: if provided, use a linear schedule with this step
            (a step that is not smaller than ``num_epochs`` yields the single entry ``[step]``).
        fraction: only used for the progressive schedule; ``max(10, int(fraction * num_epochs))``
            sample points are taken on a quarter sine wave.

    Returns:
        A list of epoch numbers. ``[0]`` is returned when the schedule would otherwise be empty.
    """
    if num_epochs_per_validation is None:
        num_points = max(10, int(fraction * num_epochs))
        x = (np.sin(np.linspace(0, np.pi / 2, num_points)) * num_epochs).astype(int)
        # gaps between the distinct sample values, largest first, accumulated into epoch numbers:
        # validation is sparse at the beginning and gets denser towards the end
        x = np.cumsum(np.sort(np.diff(np.unique(x)))[::-1])
        # the array is empty when all samples collapse to one value (e.g. num_epochs == 0);
        # indexing it unconditionally raised IndexError before the fallback below could run
        if x.size > 0:
            x[-1] = num_epochs
        x = x.tolist()
    elif num_epochs_per_validation >= num_epochs:
        x = [num_epochs_per_validation]
    else:
        x = list(range(num_epochs_per_validation, num_epochs, num_epochs_per_validation))

    if len(x) == 0:
        x = [0]

    return x
|
|
|
|
class DataTransformBuilder:
    """
    Assemble the dictionary-transform pipeline stage by stage: load, resample, normalize,
    crop, augment and final transforms.

    Every stage first looks up ``custom_transforms[<stage name>]`` (e.g. ``"load_transforms"``);
    a non-empty entry replaces the built-in stage completely. Otherwise the built-in transforms
    are used and the optional ``"after_<stage name>"`` custom transforms are appended.
    """

    def __init__(
        self,
        roi_size: list,
        image_key: str = "image",
        label_key: str = "label",
        resample: bool = False,
        resample_resolution: Optional[list] = None,
        normalize_mode: str = "meanstd",
        normalize_params: Optional[dict] = None,
        crop_mode: str = "ratio",
        crop_params: Optional[dict] = None,
        extra_modalities: Optional[dict] = None,
        custom_transforms=None,
        augment_params: Optional[dict] = None,
        debug: bool = False,
        rank: int = 0,
        class_index=None,
        **kwargs,
    ) -> None:
        """
        Args:
            roi_size: spatial size used for padding, cropping and affine augmentation.
            image_key: dictionary key of the image.
            label_key: dictionary key of the label.
            resample: whether to resample to ``resample_resolution``.
            resample_resolution: target pixdim, required when ``resample`` is True.
            normalize_mode: one of "none", "range"/"ct", "meanstd"/"mri", "meanstdtanh", "pet".
            normalize_params: may contain "intensity_bounds", "label_dtype", "image_dtype".
            crop_mode: "ratio" or "rand".
            crop_params: options of the "ratio" crop mode ("output_classes", "crop_ratios",
                "cache_class_indices", "max_samples_per_class", "crop_add_background",
                "num_crops_per_image").
            extra_modalities: mapping of additional image keys to their normalize mode.
            custom_transforms: mapping of stage name to a list of transforms or transform configs.
            augment_params: may contain "augment_mode", "augment_flips", "augment_rots".
            debug: enables extra debug output.
            rank: stored on the instance; not used by the methods of this class.
            class_index: optional class grouping, only consulted when computing the number of crop classes.
            kwargs: extra options; "orientation_ras", "crop_foreground" and "sigmoid" are read.
        """
        self.roi_size, self.image_key, self.label_key = roi_size, image_key, label_key

        self.resample, self.resample_resolution = resample, resample_resolution
        self.normalize_mode = normalize_mode
        self.normalize_params = normalize_params if normalize_params is not None else {}
        self.crop_mode = crop_mode
        self.crop_params = crop_params if crop_params is not None else {}
        self.augment_params = augment_params if augment_params is not None else {}

        self.extra_modalities = extra_modalities if extra_modalities is not None else {}
        self.custom_transforms = custom_transforms if custom_transforms is not None else {}

        self.extra_options = kwargs
        self.debug = debug
        self.rank = rank
        self.class_index = class_index

    def get_custom(self, name, **kwargs):
        """
        Return the custom transforms registered under ``name`` (empty list if none).

        Dict entries are treated as configs: ``kwargs`` are merged in and the result is
        instantiated through ``ConfigParser``. Any other entry is returned unchanged.
        """
        tr = []
        for t in self.custom_transforms.get(name, []):
            if isinstance(t, dict):
                # merge into a copy: updating the stored dict in place would leak the
                # call-specific kwargs into the shared configuration
                t = ConfigParser({**t, **kwargs}).get_parsed_content(instantiate=True)
            tr.append(t)

        return tr

    def get_load_transforms(self):
        """Transforms that load image, label and extra modalities as float tensors."""
        ts = self.get_custom("load_transforms")
        if len(ts) > 0:
            return ts

        keys = [self.image_key, self.label_key] + list(self.extra_modalities)
        ts.append(
            LoadImaged(keys=keys, ensure_channel_first=True, dtype=None, allow_missing_keys=True, image_only=True)
        )
        ts.append(EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float, allow_missing_keys=True))
        ts.append(
            EnsureSameShaped(keys=self.label_key, source_key=self.image_key, allow_missing_keys=True, warn=self.debug)
        )

        ts.extend(self.get_custom("after_load_transforms"))

        return ts

    def get_resample_transforms(self, resample_label=True):
        """
        Orientation, foreground crop and spacing transforms.

        Args:
            resample_label: when True the label is processed together with the image,
                otherwise only the image is.

        Raises:
            ValueError: if resampling is enabled but ``resample_resolution`` is None.
        """
        ts = self.get_custom("resample_transforms", resample_label=resample_label)
        if len(ts) > 0:
            return ts

        keys = [self.image_key, self.label_key] if resample_label else [self.image_key]
        mode = ["bilinear", "nearest"] if resample_label else ["bilinear"]
        extra_keys = list(self.extra_modalities)

        if self.extra_options.get("orientation_ras", False):
            ts.append(Orientationd(keys=keys, axcodes="RAS", labels=(("L", "R"), ("P", "A"), ("I", "S"))))

        # foreground cropping is skipped when extra modalities are present
        if self.extra_options.get("crop_foreground", False) and len(extra_keys) == 0:
            ts.append(
                CropForegroundd(
                    keys=keys, source_key=self.image_key, allow_missing_keys=True, margin=10, allow_smaller=True
                )
            )
        if self.resample:
            if self.resample_resolution is None:
                raise ValueError("resample_resolution is not provided")

            pixdim = self.resample_resolution
            ts.append(
                Spacingd(
                    keys=keys,
                    pixdim=pixdim,
                    mode=mode,
                    dtype=torch.float,
                    min_pixdim=np.array(pixdim) * 0.75,
                    max_pixdim=np.array(pixdim) * 1.25,
                    allow_missing_keys=True,
                )
            )

            if resample_label:
                ts.append(
                    EnsureSameShaped(
                        keys=self.label_key, source_key=self.image_key, allow_missing_keys=True, warn=self.debug
                    )
                )

        # extra modalities are resampled onto the grid of the main image
        for extra_key in extra_keys:
            ts.append(ResampleToMatchd(keys=extra_key, key_dst=self.image_key, dtype=np.float32))

        ts.extend(self.get_custom("after_resample_transforms", resample_label=resample_label))

        return ts

    def get_normalize_transforms(self):
        """
        Dtype casts and intensity normalization for the image and every extra modality.

        Raises:
            ValueError: for an unsupported normalize mode.
        """
        ts = self.get_custom("normalize_transforms")
        if len(ts) > 0:
            return ts

        label_dtype = self.normalize_params.get("label_dtype", None)
        if label_dtype is not None:
            ts.append(CastToTyped(keys=self.label_key, dtype=label_dtype, allow_missing_keys=True))
        image_dtype = self.normalize_params.get("image_dtype", None)
        if image_dtype is not None:
            ts.append(CastToTyped(keys=self.image_key, dtype=image_dtype, allow_missing_keys=True))
            # NOTE(review): presumably the random no-op marks where dataset caching stops,
            # so the compact image dtype is what gets cached - confirm.
            ts.append(RandIdentity())
            ts.append(CastToTyped(keys=self.image_key, dtype=torch.float, allow_missing_keys=True))

        modalities = {self.image_key: self.normalize_mode}
        modalities.update(self.extra_modalities)

        for key, normalize_mode in modalities.items():
            if normalize_mode == "none":
                pass
            elif normalize_mode in ["range", "ct"]:
                intensity_bounds = self.normalize_params.get("intensity_bounds", None)
                if intensity_bounds is None:
                    intensity_bounds = [-250, 250]
                    warnings.warn(f"intensity_bounds is not specified, assuming {intensity_bounds}")

                ts.append(
                    ScaleIntensityRanged(
                        keys=key, a_min=intensity_bounds[0], a_max=intensity_bounds[1], b_min=-1, b_max=1, clip=False
                    )
                )
                ts.append(Lambdad(keys=key, func=lambda x: torch.sigmoid(x)))
            elif normalize_mode in ["meanstd", "mri"]:
                ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))
            elif normalize_mode in ["meanstdtanh"]:
                ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))
                ts.append(Lambdad(keys=key, func=lambda x: 3 * torch.tanh(x / 3)))
            elif normalize_mode in ["pet"]:
                ts.append(Lambdad(keys=key, func=lambda x: torch.sigmoid((x - x.min()) / x.std())))
            else:
                raise ValueError(f"Unsupported normalize_mode: {normalize_mode}")

        if len(self.extra_modalities) > 0:
            # merge all modalities into the image item and drop the separate entries
            ts.append(ConcatItemsd(keys=list(modalities), name=self.image_key))
            ts.append(DeleteItemsd(keys=list(self.extra_modalities)))

        ts.extend(self.get_custom("after_normalize_transforms"))
        return ts

    def get_crop_transforms(self):
        """
        Padding to ``roi_size`` followed by random cropping ("ratio" or "rand" mode).

        Raises:
            ValueError: if ``roi_size`` is None, if "output_classes" is missing in the
                "ratio" mode, or for an unsupported crop mode.
        """
        ts = self.get_custom("crop_transforms")
        if len(ts) > 0:
            return ts

        if self.roi_size is None:
            raise ValueError("roi_size is not specified")

        keys = [self.image_key, self.label_key]
        ts.append(SpatialPadd(keys=keys, spatial_size=self.roi_size))

        if self.crop_mode == "ratio":
            output_classes = self.crop_params.get("output_classes", None)
            if output_classes is None:
                raise ValueError("crop_params option output_classes must be specified")

            crop_ratios = self.crop_params.get("crop_ratios", None)
            cache_class_indices = self.crop_params.get("cache_class_indices", False)
            max_samples_per_class = self.crop_params.get("max_samples_per_class", None)
            # the option may be absent (None); only a non-positive number means "no limit"
            if max_samples_per_class is not None and max_samples_per_class <= 0:
                max_samples_per_class = None
            indices_key = None

            sigmoid = self.extra_options.get("sigmoid", False)
            crop_add_background = self.crop_params.get("crop_add_background", False)

            if crop_ratios is None:
                crop_classes = output_classes
                if sigmoid and crop_add_background and self.class_index is not None and len(self.class_index) > 1:
                    crop_classes = crop_classes + 1
            else:
                crop_classes = len(crop_ratios)

            if self.debug:
                print(
                    f"Cropping with classes {crop_classes} and crop_add_background {crop_add_background} ratios {crop_ratios}"
                )

            if cache_class_indices:
                ts.append(
                    ClassesToIndicesd(
                        keys=self.label_key,
                        num_classes=crop_classes,
                        indices_postfix="_cls_indices",
                        max_samples_per_class=max_samples_per_class,
                    )
                )
                indices_key = self.label_key + "_cls_indices"

            num_crops_per_image = self.crop_params.get("num_crops_per_image", 1)

            ts.append(
                RandCropByLabelClassesd(
                    keys=keys,
                    label_key=self.label_key,
                    num_classes=crop_classes,
                    spatial_size=self.roi_size,
                    num_samples=num_crops_per_image,
                    ratios=crop_ratios,
                    indices_key=indices_key,
                    warn=False,
                )
            )
        elif self.crop_mode == "rand":
            ts.append(RandSpatialCropd(keys=keys, roi_size=self.roi_size, random_size=False))
        else:
            raise ValueError(f"Unsupported crop mode: {self.crop_mode}")

        ts.extend(self.get_custom("after_crop_transforms"))

        return ts

    def _rand_affine(self, prob, rotate_range, scale_range):
        """Random affine of image (bilinear) and label (nearest) resampled to ``roi_size``."""
        return RandAffined(
            keys=[self.image_key, self.label_key],
            prob=prob,
            rotate_range=rotate_range,
            scale_range=scale_range,
            mode=["bilinear", "nearest"],
            spatial_size=self.roi_size,
            cache_grid=True,
            padding_mode="border",
        )

    def get_augment_transforms(self):
        """
        Random augmentation transforms selected by ``augment_params["augment_mode"]``
        (None/"default", "none", "ct_ax_1", "mri_1"), followed by random flips and
        optional random 90-degree rotations.

        Raises:
            ValueError: if ``roi_size`` is None or the augment mode is unsupported.
        """
        ts = self.get_custom("augment_transforms")
        if len(ts) > 0:
            return ts

        if self.roi_size is None:
            raise ValueError("roi_size is not specified")

        augment_mode = self.augment_params.get("augment_mode", None)
        augment_flips = self.augment_params.get("augment_flips", None)
        augment_rots = self.augment_params.get("augment_rots", None)

        if self.debug:
            print(f"Using augment_mode {augment_mode}, augment_flips {augment_flips} augment_rots {augment_rots}")

        if augment_mode is None or augment_mode == "default":
            ts.append(self._rand_affine(prob=0.2, rotate_range=[0.26, 0.26, 0.26], scale_range=[0.2, 0.2, 0.2]))
            ts.append(
                RandGaussianSmoothd(
                    keys=self.image_key, prob=0.2, sigma_x=[0.5, 1.0], sigma_y=[0.5, 1.0], sigma_z=[0.5, 1.0]
                )
            )
            ts.append(RandScaleIntensityd(keys=self.image_key, prob=0.5, factors=0.3))
            ts.append(RandShiftIntensityd(keys=self.image_key, prob=0.5, offsets=0.1))
            ts.append(RandGaussianNoised(keys=self.image_key, prob=0.2, mean=0.0, std=0.1))

        elif augment_mode == "none":
            # no intensity/affine augmentation, and flips/rotations are disabled as well
            augment_flips = []
            augment_rots = []

        elif augment_mode == "ct_ax_1":
            # use the configured image key (was the literal "image") to agree with the other modes
            ts.append(RandHistogramShiftd(keys=self.image_key, prob=0.5, num_control_points=16))
            ts.append(RandAdjustContrastd(keys=self.image_key, prob=0.2, gamma=[0.5, 3.0]))
            ts.append(self._rand_affine(prob=0.5, rotate_range=[0, 0, 0.26], scale_range=[0.35, 0.35, 0]))

        elif augment_mode == "mri_1":
            ts.append(self._rand_affine(prob=0.2, rotate_range=[0.26, 0.26, 0.26], scale_range=[0.2, 0.2, 0.2]))
            ts.append(RandGaussianNoised(keys=self.image_key, prob=0.2, mean=0.0, std=0.1))
            ts.append(
                RandGaussianSmoothd(
                    keys=self.image_key, prob=0.2, sigma_x=[0.5, 1.0], sigma_y=[0.5, 1.0], sigma_z=[0.5, 1.0]
                )
            )
            # use the configured image key (was the literal "image") to agree with the other modes
            ts.append(RandScaleIntensityFixedMeand(keys=self.image_key, prob=0.2, fixed_mean=True, factors=0.3))
            ts.append(
                RandAdjustContrastd(
                    keys=self.image_key, prob=0.2, gamma=[0.7, 1.5], retain_stats=True, invert_image=False
                )
            )
            ts.append(
                RandAdjustContrastd(
                    keys=self.image_key, prob=0.2, gamma=[0.7, 1.5], retain_stats=True, invert_image=True
                )
            )

        else:
            raise ValueError("Unsupported augment_mode: " + str(augment_mode))

        # flips default to all three spatial axes when not configured
        if augment_flips is None:
            augment_flips = [0, 1, 2]
        for sa in augment_flips:
            ts.append(RandFlipd(keys=[self.image_key, self.label_key], prob=0.5, spatial_axis=sa))

        # rotations are only added when explicitly configured
        if augment_rots is not None:
            for sa in augment_rots:
                ts.append(RandRotate90d(keys=[self.image_key, self.label_key], prob=0.5, spatial_axes=sa))

        ts.extend(self.get_custom("after_augment_transforms"))

        return ts

    def get_final_transforms(self):
        """Return the custom "final_transforms" (there is no built-in final stage)."""
        return self.get_custom("final_transforms")

    @classmethod
    def get_postprocess_transform(
        cls,
        save_mask=False,
        invert=False,
        transform=None,
        sigmoid=False,
        output_path=None,
        resample=False,
        data_root_dir="",
        output_dtype=np.uint8,
        save_mask_mode=None,
    ) -> Compose:
        """
        Compose the post-processing applied to the "pred" item.

        Args:
            save_mask: save the result to ``output_path`` (only when ``output_path`` is given).
            invert: invert ``transform`` on the prediction (only when ``transform`` is given).
            transform: the pre-processing transform to invert.
            sigmoid: discretize by thresholding at 0.5 instead of argmax.
            output_path: output directory of the saved masks.
            resample: accepted for interface compatibility; not used by this method
                (the saver is always created with ``resample=False``).
            data_root_dir: forwarded to the saver.
            output_dtype: dtype of the saved mask; forced to float32 when ``save_mask_mode == "prob"``.
            save_mask_mode: "prob" saves the copied prediction without discretization.
        """
        ts = []
        if invert and transform is not None:
            ts.append(Invertd(keys="pred", orig_keys="image", transform=transform, nearest_interp=False))

        if save_mask and output_path is not None:
            # work on a copy so "pred" itself stays untouched
            ts.append(CopyItemsd(keys="pred", times=1, names="seg"))
            if save_mask_mode == "prob":
                output_dtype = np.float32
            else:
                ts.append(
                    AsDiscreted(keys="seg", argmax=True) if not sigmoid else AsDiscreted(keys="seg", threshold=0.5)
                )
            ts.append(
                SaveImaged(
                    keys=["seg"],
                    output_dir=output_path,
                    output_postfix="",
                    data_root_dir=data_root_dir,
                    output_dtype=output_dtype,
                    separate_folder=False,
                    squeeze_end_dims=True,
                    resample=False,
                    print_log=False,
                )
            )

        return Compose(ts)

    def __call__(self, augment=False, resample_label=False) -> Compose:
        """
        Build the complete pipeline.

        Args:
            augment: also include the crop and augment stages.
            resample_label: forwarded to the resample stage.
        """
        ts = []
        ts.extend(self.get_load_transforms())
        ts.extend(self.get_resample_transforms(resample_label=resample_label))
        ts.extend(self.get_normalize_transforms())

        if augment:
            ts.extend(self.get_crop_transforms())
            ts.extend(self.get_augment_transforms())

        ts.extend(self.get_final_transforms())

        return Compose(ts)

    def __repr__(self) -> str:
        """Multi-line summary of the builder settings and the registered custom transforms."""
        parts = [
            f"DataTransformBuilder: with image_key: {self.image_key}, label_key: {self.label_key} \n",
            f"roi_size {self.roi_size} resample {self.resample} resample_resolution {self.resample_resolution} \n",
            f"normalize_mode {self.normalize_mode} normalize_params {self.normalize_params} \n",
            f"crop_mode {self.crop_mode} crop_params {self.crop_params} \n",
            f"extra_modalities {self.extra_modalities} \n",
        ]
        parts.extend(f"Custom {k} : {str(trs)} \n" for k, trs in self.custom_transforms.items())
        return "".join(parts)
|
|
|
|
| class Segmenter: |
def __init__(
    self,
    config_file: Optional[Union[str, Sequence[str]]] = None,
    config_dict: Optional[Dict] = None,
    rank: int = 0,
    global_rank: int = 0,
) -> None:
    """
    Parse the configuration and set up logging, device, model, loss, metric,
    gradient scaler and sliding-window inferer.

    Args:
        config_file: path (or sequence of paths) of the config file(s); may be None when
            the config is taken from a checkpoint.
        config_dict: overrides applied on top of the loaded config; None means no overrides.
        rank: local rank, used as the device index.
        global_rank: global rank; informational output is only produced on rank 0.
    """
    # a literal `{}` default would be one dict shared by every call; use a fresh one instead
    config_dict = {} if config_dict is None else config_dict

    self.rank = rank
    self.global_rank = global_rank
    self.distributed = dist.is_initialized()

    if self.global_rank == 0:
        print(f"Segmenter started config_file: {config_file}, config_dict: {config_dict}")

    np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
    logging.getLogger("torch.nn.parallel.distributed").setLevel(logging.WARNING)

    config = self.parse_input_config(config_file=config_file, override=config_dict)
    self.config = config
    # when several config files are given, the first one is remembered as the main config
    self.config_file = config_file if not isinstance(config_file, (list, tuple)) else config_file[0]
    self.override = config_dict

    if config["ckpt_path"] is not None and not os.path.exists(config["ckpt_path"]):
        os.makedirs(config["ckpt_path"], exist_ok=True)

    # default the log file into the checkpoint folder; without a checkpoint folder the
    # path stays None (joining None with a file name raised TypeError)
    # NOTE(review): assumes logger_configure accepts log_output_file=None - confirm in utils.
    if config["log_output_file"] is None and config["ckpt_path"] is not None:
        config["log_output_file"] = os.path.join(config["ckpt_path"], "training.log")
    logger_configure(log_output_file=config["log_output_file"], debug=config["debug"], global_rank=self.global_rank)

    if config["fork"] and "fork" in mp.get_all_start_methods():
        mp.set_start_method("fork", force=True)
    else:
        warnings.warn(
            "Multiprocessing method fork is not available, some non-picklable objects (e.g. lambda ) may fail"
        )

    if config["cuda"] and torch.cuda.is_available():
        self.device = torch.device(self.rank)
        if self.distributed and dist.get_backend() == dist.Backend.NCCL:
            torch.cuda.set_device(rank)
    else:
        self.device = torch.device("cpu")

    if self.global_rank == 0:
        print(yaml.dump(config))

    if config["determ"]:
        set_determinism(seed=0)
    elif torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    if config["notf32"]:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        print("!!!disabling tf32")
    if config.get("float32_precision", None) is not None:
        torch.set_float32_matmul_precision(config["float32_precision"])
        print(f"!!!setting matmul precession {config['float32_precision']}")

    # optionally let the helper pick roi size, batch size and number of filters
    if config["auto_scale_allowed"]:
        if config["auto_scale_batch"] or config["auto_scale_roi"] or config["auto_scale_filters"]:
            roi_size, _, init_filters, batch_size = auto_adjust_network_settings(
                auto_scale_batch=config["auto_scale_batch"],
                auto_scale_roi=config["auto_scale_roi"],
                auto_scale_filters=config["auto_scale_filters"],
                image_size_mm=config["image_size_mm_median"],
                spacing=config["resample_resolution"],
                anisotropic_scales=config["anisotropic_scales"],
                levels=len(config["network"]["blocks_down"]),
                output_classes=config["output_classes"],
            )

            config["roi_size"] = roi_size
            if config["auto_scale_batch"]:
                config["batch_size"] = batch_size
            # the filter count of a pretrained network must not be changed
            if config["auto_scale_filters"] and config["pretrained_ckpt_name"] is None:
                config["network"]["init_filters"] = init_filters

    self.model = self.setup_model(pretrained_ckpt_name=config["pretrained_ckpt_name"])

    loss_function = ConfigParser(config["loss"]).get_parsed_content(instantiate=True)
    self.loss_function = DeepSupervisionLoss(loss_function)

    dice_ignore_empty = config.get("dice_ignore_empty", True)
    self.acc_function = DiceHelper(threshold=config["sigmoid"], ignore_empty=dice_ignore_empty)
    self.amp_device_type = "cuda" if torch.cuda.is_available() else "cpu"
    self.grad_scaler = GradScaler(self.amp_device_type, enabled=config["amp"])

    if config.get("sliding_inferrer") is not None:
        self.sliding_inferrer = ConfigParser(config["sliding_inferrer"]).get_parsed_content()
    else:
        self.sliding_inferrer = SlidingWindowInfererAdapt(
            roi_size=config["roi_size"],
            sw_batch_size=1,
            overlap=0.625,
            mode="gaussian",
            cache_roi_weight_map=True,
            progress=False,
        )

    # created lazily by get_data_transform_builder()
    self._data_transform_builder: Optional[DataTransformBuilder] = None
    self.lr_scheduler = None
    self.optimizer = None
|
|
def get_custom_transforms(self):
    """
    Collect the user-defined transforms of the "custom_data_transforms" config section,
    grouped by their "key".

    The "path" of every entry is appended to ``sys.path`` (as an absolute path) if it is
    not there yet. When the config holds a non-empty "class_index" list, a
    ``LabelEmbedClassIndex`` transform is appended to the "final_transforms" group.

    Raises:
        ValueError: if an entry lacks one of "key", "path", "transform".
    """
    cfg = self.config
    required_keys = ("key", "path", "transform")
    grouped = {}

    for entry in cfg.get("custom_data_transforms", []):
        if any(k not in entry for k in required_keys):
            raise ValueError("custom transform must include " + str(required_keys))

        # make the module that defines the transform importable
        module_dir = os.path.abspath(entry["path"])
        if module_dir not in sys.path:
            sys.path.append(module_dir)

        grouped.setdefault(entry["key"], []).append(entry["transform"])

    if self.global_rank == 0 and len(grouped) > 0:
        print(f"Using custom transforms {grouped}")

    class_index = cfg["class_index"]
    if isinstance(class_index, list) and len(class_index) > 0:
        label_embed = LabelEmbedClassIndex(keys="label", class_index=class_index, allow_missing_keys=True)
        grouped.setdefault("final_transforms", []).append(label_embed)

    return grouped
|
|
def get_data_transform_builder(self):
    """
    Return the ``DataTransformBuilder`` of this segmenter, creating it from the
    config on first use and reusing the cached instance afterwards.
    """
    if self._data_transform_builder is not None:
        return self._data_transform_builder

    cfg = self.config
    custom_transforms = self.get_custom_transforms()

    # NOTE(review): the label dtype is chosen from "input_channels" - confirm this is intended.
    builder_kwargs = dict(
        roi_size=cfg["roi_size"],
        resample=cfg["resample"],
        resample_resolution=cfg["resample_resolution"],
        normalize_mode=cfg["normalize_mode"],
        normalize_params={
            "intensity_bounds": cfg["intensity_bounds"],
            "label_dtype": torch.uint8 if cfg["input_channels"] < 255 else torch.int16,
            "image_dtype": torch.int16 if cfg.get("cache_image_int16", False) else None,
        },
        crop_mode=cfg["crop_mode"],
        crop_params={
            "output_classes": cfg["output_classes"],
            "input_channels": cfg["input_channels"],
            "crop_ratios": cfg["crop_ratios"],
            "cache_class_indices": cfg["cache_class_indices"],
            "num_crops_per_image": cfg["num_crops_per_image"],
            "max_samples_per_class": cfg["max_samples_per_class"],
            "crop_add_background": cfg["crop_add_background"],
        },
        augment_params={
            "augment_mode": cfg.get("augment_mode", None),
            "augment_flips": cfg.get("augment_flips", None),
            "augment_rots": cfg.get("augment_rots", None),
        },
        extra_modalities=cfg["extra_modalities"],
        custom_transforms=custom_transforms,
        crop_foreground=cfg.get("crop_foreground", True),
        sigmoid=cfg["sigmoid"],
        orientation_ras=cfg.get("orientation_ras", False),
        class_index=cfg["class_index"],
        debug=cfg["debug"],
    )
    self._data_transform_builder = DataTransformBuilder(**builder_kwargs)
    return self._data_transform_builder
|
|
def setup_model(self, pretrained_ckpt_name=None):
    """
    Instantiate the network described by ``config["network"]``, optionally load a
    checkpoint, move it to the device and wrap it for distributed training.

    Args:
        pretrained_ckpt_name: checkpoint to load into the freshly built model, if given.

    Returns:
        The model (wrapped in ``DistributedDataParallel`` when running distributed
        and inference is not enabled).
    """
    cfg = self.config
    net_cfg = cfg["network"]
    is_main = self.global_rank == 0

    spatial_dims = net_cfg.get("spatial_dims", 3)
    norm_name, norm_args = split_args(net_cfg.get("norm", ""))
    norm_name = norm_name.upper()

    if norm_name == "INSTANCE_NVFUSER":
        _, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
        if not (has_nvfuser and spatial_dims == 3):
            # fall back to the regular instance norm
            norm_name = "INSTANCE"
        else:
            act = net_cfg.get("act", "relu")
            if isinstance(act, str):
                net_cfg["act"] = [act, {"inplace": False}]

    if len(norm_name) > 0:
        net_cfg["norm"] = [norm_name, norm_args] if len(norm_args) != 0 else norm_name

    if spatial_dims == 3 and cfg.get("anisotropic_scales", False) and "SegResNetDS" in net_cfg["_target_"]:
        net_cfg["resolution"] = copy.deepcopy(cfg["resample_resolution"])
        if is_main:
            print(f"Using anisotropic scales {net_cfg}")

    model = ConfigParser(net_cfg).get_parsed_content()

    if is_main:
        print(str(model))

    if pretrained_ckpt_name is not None:
        self.checkpoint_load(ckpt=pretrained_ckpt_name, model=model)

    model = model.to(self.device)

    if spatial_dims == 3:
        if cfg["channels_last"]:
            memory_format = torch.channels_last_3d
        else:
            memory_format = torch.preserve_format
        model = model.to(memory_format=memory_format)

    if self.distributed and not cfg["infer"]["enabled"]:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DistributedDataParallel(
            module=model, device_ids=[self.rank], output_device=self.rank, find_unused_parameters=False
        )

    if is_main:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total parameters count: {trainable_params} distributed: {self.distributed}")

    return model
|
|
def parse_input_config(
    self, config_file: Optional[Union[str, Sequence[str]]] = None, override: Optional[Dict] = None
) -> Dict:
    """
    Build the run configuration from a checkpoint and/or config file(s), apply the
    overrides and fill in defaults.

    Args:
        config_file: path (or sequence of paths) of the config file(s). When None, or when
            the override "use_ckpt_config" is set, the config stored in a checkpoint is tried first.
        override: key/value pairs applied on top of the loaded config; None means no overrides.

    Returns:
        The resolved config dictionary.

    Raises:
        ValueError: if "data_file_base_dir" or "data_list_file_path" is missing.
    """
    # a literal `{}` default would be one dict shared by every call
    override = {} if override is None else override

    config = {}
    if config_file is None or override.get("use_ckpt_config", False):
        # try to restore the config that was saved inside one of the given checkpoints
        for ckpt_key in ["pretrained_ckpt_name", "validate#ckpt_name", "infer#ckpt_name", "finetune#ckpt_name"]:
            ckpt = override.get(ckpt_key, None)
            if ckpt and os.path.exists(ckpt):
                # SECURITY NOTE(review): torch.load unpickles the file; only load trusted checkpoints
                checkpoint = torch.load(ckpt, map_location="cpu")
                config = checkpoint.get("config", {})
                if self.global_rank == 0:
                    print(f"Initializing config from the checkpoint {ckpt}: {yaml.dump(config)}")

        if len(config) == 0 and config_file is None:
            warnings.warn("No input config_file provided, and no valid checkpoints found")

    if config_file is not None and len(config) == 0:
        config = ConfigParser.load_config_files(config_file)
        config.setdefault("finetune", {"enabled": False, "ckpt_name": None})
        config.setdefault(
            "validate", {"enabled": False, "ckpt_name": None, "save_mask": False, "output_path": None}
        )
        config.setdefault("infer", {"enabled": False, "ckpt_name": None})

    parser = ConfigParser(config=config)
    parser.update(pairs=override)
    config = parser.config

    if config.get("data_file_base_dir", None) is None or config.get("data_list_file_path", None) is None:
        raise ValueError("CONFIG: data_file_base_dir and data_list_file_path must be provided")

    if config.get("bundle_root", None) is None:
        config["bundle_root"] = str(Path(__file__).parent.parent)

    if "modality" not in config:
        if self.global_rank == 0:
            warnings.warn("CONFIG: modality is not provided, assuming MRI")
        config["modality"] = "mri"

    if "normalize_mode" not in config:
        config["normalize_mode"] = "range" if config["modality"].lower() == "ct" else "meanstd"
        if self.global_rank == 0:
            print(f"CONFIG: normalize_mode is not provided, assuming: {config['normalize_mode']}")

    # defaults for options that are not set; built per call so mutable values are never shared
    defaults = {
        "debug": False,
        "loss": None,
        "acc": None,
        "amp": True,
        "cuda": True,
        "fold": 0,
        "batch_size": 1,
        "determ": False,
        "quick": False,
        "sigmoid": False,
        "cache_rate": None,
        "cache_class_indices": None,
        "crop_add_background": True,
        "orientation_ras": False,
        "channels_last": True,
        "fork": True,
        "num_epochs": 300,
        "num_warmup_epochs": 3,
        "num_epochs_per_validation": None,
        "num_epochs_per_saving": 10,
        "num_steps_per_image": None,
        "num_crops_per_image": 1,
        "max_samples_per_class": None,
        "calc_val_loss": False,
        "validate_final_original_res": False,
        "early_stopping_fraction": 0,
        "start_epoch": 0,
        "ckpt_path": None,
        "ckpt_save": True,
        "log_output_file": None,
        "crop_mode": "ratio",
        "crop_ratios": None,
        "resample_resolution": [1.0, 1.0, 1.0],
        "resample": False,
        "roi_size": [128, 128, 128],
        "num_workers": 4,
        "extra_modalities": {},
        "intensity_bounds": [-250, 250],
        "stop_on_lowacc": True,
        "float32_precision": None,
        "notf32": False,
        "class_index": None,
        "class_names": [],
    }
    for key, value in defaults.items():
        config.setdefault(key, value)

    if not isinstance(config["class_names"], (list, tuple)):
        config["class_names"] = []

    if len(config["class_names"]) == 0:
        # generate generic names; without sigmoid one of the output classes is not counted
        n_foreground_classes = int(config["output_classes"])
        if not config["sigmoid"]:
            n_foreground_classes -= 1
        config["class_names"] = ["acc_" + str(i) for i in range(n_foreground_classes)]

    # without an explicit pretrained checkpoint, use the one of the enabled mode
    pretrained_ckpt_name = config.get("pretrained_ckpt_name", None)
    if pretrained_ckpt_name is None:
        if config["validate"]["enabled"]:
            pretrained_ckpt_name = config["validate"]["ckpt_name"]
        elif config["infer"]["enabled"]:
            pretrained_ckpt_name = config["infer"]["ckpt_name"]
        elif config["finetune"]["enabled"]:
            pretrained_ckpt_name = config["finetune"]["ckpt_name"]
        config["pretrained_ckpt_name"] = pretrained_ckpt_name

    for key in ("auto_scale_allowed", "auto_scale_batch", "auto_scale_roi", "auto_scale_filters"):
        config.setdefault(key, False)

    # roi size and filter count are not auto-scaled when starting from a checkpoint
    if pretrained_ckpt_name is not None:
        config["auto_scale_roi"] = False
        config["auto_scale_filters"] = False

    if config["max_samples_per_class"] is None:
        config["max_samples_per_class"] = 10 * config["num_epochs"]

    if not torch.cuda.is_available() and config["cuda"]:
        print("No cuda is available.! Running on CPU!!!")
        config["cuda"] = False

    config["amp"] = config["amp"] and config["cuda"]
    config["rank"] = self.rank
    config["global_rank"] = self.global_rank

    # resolve the config entries, but keep everything that describes an object
    # ("_target_") as a config instead of instantiating it
    for k, v in config.items():
        if isinstance(v, dict) and "_target_" in v:
            config[k] = parser.get_parsed_content(k, instantiate=False).config
        elif "_target_" in str(v):
            config[k] = copy.deepcopy(v)
        else:
            config[k] = parser.get_parsed_content(k)

    return config
|
|
def config_save_updated(self, save_path=None):
    """Re-export the input config with the current (possibly auto-scaled) values.

    Nothing happens unless this is global rank 0 and ``auto_scale_allowed``
    is set in the active config.

    Args:
        save_path: destination file; defaults to ``self.config_file``.
    """
    if self.global_rank != 0 or not self.config["auto_scale_allowed"]:
        return

    active = self.config

    # start again from the files on disk plus the overrides, so that only the
    # keys copied below differ from what the user supplied
    parser = ConfigParser(config=ConfigParser.load_config_files(self.config_file))
    parser.update(pairs=self.override)
    exported = parser.config

    for key in ("batch_size", "roi_size", "num_crops_per_image"):
        exported[key] = active[key]

    if "init_filters" in active["network"]:
        exported["network"]["init_filters"] = active["network"]["init_filters"]

    target_path = self.config_file if save_path is None else save_path

    print(f"Re-saving main config to {target_path}.")
    ConfigParser.export_config_file(exported, target_path, fmt="yaml", default_flow_style=None, sort_keys=False)
|
|
def config_with_relpath(self, config=None):
    """Return a deep copy of the config with bundle-root paths made relative.

    String values that start with ``bundle_root`` are rewritten as a
    ``$@bundle_root + '/<relative path>'`` expression. This is applied to the
    top level and to the ``finetune``, ``validate`` and ``infer`` sections;
    ``bundle_root`` itself is restored to its original value afterwards.

    Args:
        config: config dict to convert; defaults to ``self.config``.

    Returns:
        The converted copy (the input dict is not modified).
    """
    source = self.config if config is None else config
    converted = copy.deepcopy(source)
    root = converted["bundle_root"]

    def _relativize(section):
        for key, value in section.items():
            if not isinstance(value, str) or not value.startswith(root):
                continue
            relative = os.path.relpath(value, root)
            section[key] = f"$@bundle_root + '/{relative}'"

    _relativize(converted)
    for section_name in ("finetune", "validate", "infer"):
        _relativize(converted[section_name])

    # the top-level pass rewrote bundle_root as well; put the real value back
    converted["bundle_root"] = root

    return converted
|
|
def checkpoint_save(self, ckpt: str, model: torch.nn.Module, **kwargs):
    """Write the model weights, the relative-path config and ``kwargs`` to ``ckpt``.

    Args:
        ckpt: destination file path.
        model: network to save; a ``DistributedDataParallel`` wrapper is
            unwrapped so the stored keys come from the inner module.
        kwargs: extra entries stored at the top level of the checkpoint dict.

    Returns:
        Time in seconds spent saving.
    """
    started = time.time()

    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    weights = (model.module if is_ddp else model).state_dict()

    payload = {"state_dict": weights, "config": self.config_with_relpath(), **kwargs}
    torch.save(payload, ckpt)

    save_time = time.time() - started
    print(f"Saving checkpoint process: {ckpt}, {kwargs}, save_time {save_time:.2f}s")

    return save_time
|
|
def checkpoint_load(self, ckpt: str, model: torch.nn.Module, **kwargs):
    """Load weights from ``ckpt`` into ``model`` and update the resume state.

    A missing file only produces a warning (on global rank 0). Otherwise the
    ``state_dict`` entry is loaded strictly. If the config has a truthy
    ``continue`` entry (which is removed by this call), ``start_epoch`` and
    ``best_metric`` are taken from the checkpoint when present. In all loaded
    cases ``start_epoch`` is finally incremented by one.

    Args:
        ckpt: checkpoint file path.
        model: network receiving the weights.
        kwargs: accepted but not used here.
    """
    if not os.path.isfile(ckpt):
        if self.global_rank == 0:
            warnings.warn("Invalid checkpoint file: " + str(ckpt))
        return

    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    stored = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(stored["state_dict"], strict=True)
    epoch = stored.get("epoch", 0)
    best_metric = stored.get("best_metric", 0)

    if self.config.pop("continue", False):
        # resume: copy the bookkeeping values that the checkpoint actually has
        for stored_key, config_key in (("epoch", "start_epoch"), ("best_metric", "best_metric")):
            if stored_key in stored:
                self.config[config_key] = stored[stored_key]

    print(
        f"=> loaded checkpoint {ckpt} (epoch {epoch}) (best_metric {best_metric}) setting start_epoch {self.config['start_epoch']}"
    )
    self.config["start_epoch"] = int(self.config["start_epoch"]) + 1
|
|
def get_shared_memory_list(self, length=0):
    """Create a manager-backed list of ``length`` empty slots, shared between ranks when distributed.

    Every process creates its own ``mp.Manager().list``. In distributed mode
    the list object is then broadcast, so that the returned list is the one
    created by the broadcasting source rank (rank 0 for a single node, or the
    first rank of each node's sub-group for multi-node runs).

    Args:
        length: number of ``None`` entries in the list.

    Returns:
        The list proxy to use (the local one when not distributed).
    """
    # fixed authkey on the current process; presumably needed so that other ranks
    # can connect to the manager that owns the broadcast list - TODO confirm
    mp.current_process().authkey = np.arange(32, dtype=np.uint8).tobytes()
    shl0 = mp.Manager().list([None] * length)

    if self.distributed:
        # multi-node runs need a per-node process group; assume single node until shown otherwise
        is_multinode = False

        # dist_launched() is defined elsewhere; the environment variables read below are
        # only accessed when it returns true
        if dist_launched():
            local_world_size = int(os.getenv("LOCAL_WORLD_SIZE"))
            world_size = int(os.getenv("WORLD_SIZE"))
            group_rank = int(os.getenv("GROUP_RANK"))
            if world_size > local_world_size:
                is_multinode = True
                # gather the local world size reported by every rank
                lw = torch.tensor(local_world_size, dtype=torch.int, device=self.device)
                lw_sizes = [torch.zeros_like(lw) for _ in range(world_size)]
                dist.all_gather(tensor_list=lw_sizes, tensor=lw)

                # src walks over the first rank of each group, g_rank counts the groups
                src = g_rank = 0
                while src < world_size:
                    # every rank creates (and later destroys) every sub-group, but only
                    # the ranks whose GROUP_RANK matches take part in the broadcast
                    group = dist.new_group(ranks=list(range(src, src + local_world_size)))
                    if group_rank == g_rank:
                        shl_list = [shl0]
                        dist.broadcast_object_list(shl_list, src=src, group=group, device=self.device)
                        shl = shl_list[0]
                    dist.destroy_process_group(group)
                    # advance by the local world size that rank `src` reported
                    src = src + lw_sizes[src].item()
                    g_rank += 1

        if not is_multinode:
            # single node (or not launched via dist_launched): rank 0's list is used by all ranks
            shl_list = [shl0]
            dist.broadcast_object_list(shl_list, src=0, device=self.device)
            shl = shl_list[0]

    else:
        shl = shl0

    return shl
|
|
def get_train_loader(self, data, cache_rate=0, persistent_workers=False):
    """Build the training ``DataLoader`` for ``data``.

    Args:
        data: list of training items passed to the dataset.
        cache_rate: when positive, a ``CacheDataset`` backed by the shared
            runtime cache list is used instead of a plain ``Dataset``.
        persistent_workers: requested worker persistence; only honoured when
            ``num_workers`` is positive.

    Returns:
        The training data loader (shuffled, or using a shuffling
        ``DistributedSampler`` in distributed mode).
    """
    use_sampler = self.distributed
    workers = self.config["num_workers"]
    per_step = self.config["batch_size"]

    transform = self.get_data_transform_builder()(augment=True, resample_label=True)

    if cache_rate > 0:
        dataset = CacheDataset(
            data=data,
            transform=transform,
            copy_cache=False,
            cache_rate=cache_rate,
            runtime_cache=self.get_shared_memory_list(length=len(data)),
        )
    else:
        dataset = Dataset(data=data, transform=transform)

    sampler = DistributedSampler(dataset, shuffle=True) if use_sampler else None

    # the loader shuffles on its own only when no sampler does it
    return DataLoader(
        dataset,
        batch_size=per_step,
        shuffle=(sampler is None),
        num_workers=workers,
        sampler=sampler,
        persistent_workers=persistent_workers and workers > 0,
        pin_memory=True,
    )
|
|
def get_val_loader(self, data, cache_rate=0, resample_label=False, persistent_workers=False):
    """Build the validation/inference ``DataLoader`` for ``data``.

    Args:
        data: list of items passed to the dataset.
        cache_rate: when positive, a ``CacheDataset`` backed by the shared
            runtime cache list is used instead of a plain ``Dataset``.
        resample_label: forwarded to the transform builder.
        persistent_workers: requested worker persistence; only honoured when
            ``num_workers`` is positive.

    Returns:
        A non-shuffling data loader with batch size 1.
    """
    use_sampler = self.distributed
    workers = self.config["num_workers"]

    transform = self.get_data_transform_builder()(augment=False, resample_label=resample_label)

    if cache_rate > 0:
        dataset = CacheDataset(
            data=data,
            transform=transform,
            copy_cache=False,
            cache_rate=cache_rate,
            runtime_cache=self.get_shared_memory_list(length=len(data)),
        )
    else:
        dataset = Dataset(data=data, transform=transform)

    sampler = DistributedSampler(dataset, shuffle=False) if use_sampler else None

    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        sampler=sampler,
        persistent_workers=persistent_workers and workers > 0,
        pin_memory=True,
    )
|
|
def train(self):
    """Run the full training loop and return the best validation metric.

    The method reads the data list, derives the epoch/validation schedule from
    the config (including the optional auto-scaling of crops per image), builds
    loaders, optimizer and LR scheduler, then alternates training and scheduled
    validation epochs. Global rank 0 additionally writes Tensorboard/MLflow
    logs, ``progress.yaml``, ``accuracy_history.csv`` and checkpoints when
    ``ckpt_path`` is configured. Optionally the best checkpoint is re-validated
    at the original resolution at the end.

    A missing ``ckpt_path`` (``None``) is supported: no checkpoint paths are
    built and the final original-resolution validation is skipped because no
    checkpoint can exist.

    Returns:
        The best metric seen by this process (``-1`` if none was recorded), or
        the value returned by the original-resolution validation when it ran.

    Raises:
        ValueError: if the loss or accuracy function is missing, if
            ``optim_name`` is unsupported, or if the validation accuracy stays
            below 0.01 late in training while ``stop_on_lowacc`` is set.
    """
    if self.global_rank == 0:
        print("Segmenter train called")

    if self.loss_function is None:
        raise ValueError("CONFIG loss function is not provided")
    if self.acc_function is None:
        raise ValueError("CONFIG accuracy function is not provided")

    config = self.config
    distributed = self.distributed
    sliding_inferrer = self.sliding_inferrer

    loss_function = self.loss_function
    acc_function = self.acc_function
    grad_scaler = self.grad_scaler

    use_amp = config["amp"]
    use_cuda = config["cuda"]
    ckpt_path = config["ckpt_path"]
    sigmoid = config["sigmoid"]
    channels_last = config["channels_last"]
    calc_val_loss = config["calc_val_loss"]

    # a relative data list path is resolved against the bundle root
    data_list_file_path = config["data_list_file_path"]
    if not os.path.isabs(data_list_file_path):
        data_list_file_path = os.path.abspath(os.path.join(config["bundle_root"], data_list_file_path))

    if config.get("validation_key", None) is not None:
        # a dedicated validation key: training and validation lists are read separately
        train_files, _ = datafold_read(datalist=data_list_file_path, basedir=config["data_file_base_dir"], fold=-1)
        validation_files, _ = datafold_read(
            datalist=data_list_file_path,
            basedir=config["data_file_base_dir"],
            fold=-1,
            key=config["validation_key"],
        )
    else:
        train_files, validation_files = datafold_read(
            datalist=data_list_file_path, basedir=config["data_file_base_dir"], fold=config["fold"]
        )

    if config["quick"]:
        # quick mode: keep only the first 8 files of each list
        train_files, validation_files = train_files[:8], validation_files[:8]
    if self.global_rank == 0:
        print(f"train_files files {len(train_files)}, validation files {len(validation_files)}")

    if len(validation_files) == 0:
        warnings.warn("No validation files found!")

    cache_rate_train, cache_rate_val = self.get_cache_rate(
        train_cases=len(train_files), validation_cases=len(validation_files)
    )

    if config["cache_class_indices"] is None:
        config["cache_class_indices"] = cache_rate_train > 0

    if self.global_rank == 0:
        print(
            f"Auto setting max_samples_per_class: {config['max_samples_per_class']} cache_class_indices: {config['cache_class_indices']}"
        )

    num_steps_per_image = config["num_steps_per_image"]
    if config["auto_scale_allowed"] and num_steps_per_image is None:
        be = config["batch_size"]  # batch size before it is possibly reset to 1 below

        if config["crop_mode"] == "ratio":
            # move the batch dimension into multiple crops of a single image
            config["num_crops_per_image"] = config["batch_size"]
            config["batch_size"] = 1
        else:
            config["num_crops_per_image"] = 1

        if cache_rate_train < 0.75:
            num_steps_per_image = max(1, 4 // be)
        else:
            num_steps_per_image = 1

    elif num_steps_per_image is None:
        num_steps_per_image = 1

    # epoch-based settings are rescaled by the number of crops per image
    num_crops_per_image = int(config["num_crops_per_image"])
    num_epochs_per_saving = max(1, config["num_epochs_per_saving"] // num_crops_per_image)
    num_warmup_epochs = max(3, config["num_warmup_epochs"] // num_crops_per_image)
    num_epochs_per_validation = config["num_epochs_per_validation"]
    num_epochs = max(1, config["num_epochs"] // min(3, num_crops_per_image))
    if self.global_rank == 0:
        print(
            f"Given num_crops_per_image {num_crops_per_image}, num_epochs was adjusted {config['num_epochs']} => {num_epochs}"
        )

    if num_epochs_per_validation is not None:
        num_epochs_per_validation = max(1, num_epochs_per_validation // num_crops_per_image)

    val_schedule_list = schedule_validation_epochs(
        num_epochs=num_epochs,
        num_epochs_per_validation=num_epochs_per_validation,
        fraction=min(0.3, 0.16 * num_crops_per_image),
    )
    if self.global_rank == 0:
        print(f"Scheduling validation loops at epochs: {val_schedule_list}")

    train_loader = self.get_train_loader(data=train_files, cache_rate=cache_rate_train, persistent_workers=True)

    val_loader = self.get_val_loader(
        data=validation_files, cache_rate=cache_rate_val, resample_label=True, persistent_workers=True
    )

    # optimizer: an explicit optim_name wins, then the config-defined optimizer, then self.optimizer
    optim_name = config.get("optim_name", None)
    if optim_name is not None:
        if self.global_rank == 0:
            print(f"Using optimizer: {optim_name}")
        if optim_name == "fusednovograd":
            import apex

            optimizer = apex.optimizers.FusedNovoGrad(
                params=self.model.parameters(), lr=config["learning_rate"], weight_decay=1.0e-5
            )
        elif optim_name == "sgd":
            momentum = config.get("sgd_momentum", 0.9)
            optimizer = torch.optim.SGD(
                params=self.model.parameters(), lr=config["learning_rate"], weight_decay=1.0e-5, momentum=momentum
            )
            if self.global_rank == 0:
                print(f"Using momentum: {momentum}")
        else:
            raise ValueError("Unsupported optim_name" + str(optim_name))

    elif self.optimizer is None:
        optimizer_part = ConfigParser(config["optimizer"]).get_parsed_content(instantiate=False)
        optimizer = optimizer_part.instantiate(params=self.model.parameters())
    else:
        optimizer = self.optimizer

    tb_writer = None
    csv_path = progress_path = None

    if self.global_rank == 0 and ckpt_path is not None:
        # only global rank 0 writes progress files and logs
        progress_path = os.path.join(ckpt_path, "progress.yaml")

        tb_writer = SummaryWriter(log_dir=ckpt_path)
        print(f"Writing Tensorboard logs to {tb_writer.log_dir}")

        if mlflow_is_imported:
            mlflow.set_tracking_uri(config["mlflow_tracking_uri"])
            mlflow.set_experiment(config["mlflow_experiment_name"])
            mlflow.start_run(run_name=f'segresnet - fold{config["fold"]} - train')

        csv_path = os.path.join(ckpt_path, "accuracy_history.csv")
        self.save_history_csv(
            csv_path=csv_path,
            header=["epoch", "metric", "loss", "iter", "time", "train_time", "validation_time", "epoch_time"],
        )

    do_torch_save = (self.global_rank == 0) and ckpt_path is not None and config["ckpt_save"]

    # ckpt_path may be None (it is checked for None above), in which case
    # os.path.join would raise TypeError; leave the checkpoint paths unset then
    best_ckpt_path = intermediate_ckpt_path = None
    if ckpt_path is not None:
        best_ckpt_path = os.path.join(ckpt_path, "model.pt")
        intermediate_ckpt_path = os.path.join(ckpt_path, "model_final.pt")

    best_metric = -1
    best_metric_epoch = -1
    pre_loop_time = time.time()
    report_num_epochs = num_epochs * num_crops_per_image
    train_time = validation_time = 0
    val_acc_history = []

    start_epoch = config["start_epoch"]
    if "best_metric" in config:
        best_metric = float(config["best_metric"])

    # start_epoch is stored in reported epochs; convert to loop epochs
    start_epoch = start_epoch // num_crops_per_image
    if start_epoch > 0:
        # drop validation points that lie before the resume point
        val_schedule_list = [v for v in val_schedule_list if v >= start_epoch]
        if len(val_schedule_list) == 0:
            val_schedule_list = [start_epoch]
        print(f"adjusted schedule_list {val_schedule_list}")

    if self.global_rank == 0:
        print(
            f"Using num_epochs => {num_epochs}\n "
            f"Using start_epoch => {start_epoch}\n "
            f"batch_size => {config['batch_size']} \n "
            f"num_crops_per_image => {config['num_crops_per_image']} \n "
            f"num_steps_per_image => {num_steps_per_image} \n "
            f"num_warmup_epochs => {num_warmup_epochs} \n "
        )

    if self.lr_scheduler is None:
        lr_scheduler = WarmupCosineSchedule(
            optimizer=optimizer, warmup_steps=num_warmup_epochs, warmup_multiplier=0.1, t_total=num_epochs
        )
    else:
        lr_scheduler = self.lr_scheduler
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.last_epoch = start_epoch

    range_num_epochs = range(start_epoch, num_epochs)
    if self.global_rank == 0 and has_tqdm and not config["debug"]:
        range_num_epochs = tqdm(
            range(start_epoch, num_epochs),
            desc=str(os.path.basename(config["bundle_root"])) + " - training",
            unit="epoch",
        )

    if distributed:
        dist.barrier()
    # overwrites the main input config file with the values adjusted above
    self.config_save_updated(save_path=self.config_file)

    for epoch in range_num_epochs:
        # reported epoch numbers are scaled back by the crops per image
        report_epoch = epoch * num_crops_per_image

        if distributed:
            if isinstance(train_loader.sampler, DistributedSampler):
                train_loader.sampler.set_epoch(epoch)
            dist.barrier()

        epoch_time = start_time = time.time()

        train_loss, train_acc = 0, 0
        if not config.get("skip_train", False):
            train_loss, train_acc = self.train_epoch(
                model=self.model,
                train_loader=train_loader,
                optimizer=optimizer,
                loss_function=loss_function,
                acc_function=acc_function,
                grad_scaler=grad_scaler,
                epoch=report_epoch,
                rank=self.rank,
                global_rank=self.global_rank,
                num_epochs=report_num_epochs,
                sigmoid=sigmoid,
                use_amp=use_amp,
                use_cuda=use_cuda,
                channels_last=channels_last,
                num_steps_per_image=num_steps_per_image,
            )

        train_time = time.time() - start_time

        if self.global_rank == 0:
            print(
                f"Final training {report_epoch}/{report_num_epochs - 1} "
                f"loss: {train_loss:.4f} acc_avg: {np.mean(train_acc):.4f} "
                f"acc {train_acc} time {train_time:.2f}s "
                f"lr: {optimizer.param_groups[0]['lr']:.4e}"
            )

            if tb_writer is not None:
                tb_writer.add_scalar("train/loss", train_loss, report_epoch)
                tb_writer.add_scalar("train/acc", np.mean(train_acc), report_epoch)
                if mlflow_is_imported:
                    mlflow.log_metric("train/loss", train_loss, step=report_epoch)

        # validate when the next scheduled validation epoch has been reached
        val_acc_mean = -1
        if (
            len(val_schedule_list) > 0
            and epoch + 1 >= val_schedule_list[0]
            and val_loader is not None
            and len(val_loader) > 0
        ):
            val_schedule_list.pop(0)

            start_time = time.time()
            torch.cuda.empty_cache()

            val_loss, val_acc = self.val_epoch(
                model=self.model,
                val_loader=val_loader,
                sliding_inferrer=sliding_inferrer,
                loss_function=loss_function,
                acc_function=acc_function,
                epoch=report_epoch,
                rank=self.rank,
                global_rank=self.global_rank,
                num_epochs=report_num_epochs,
                sigmoid=sigmoid,
                use_amp=use_amp,
                use_cuda=use_cuda,
                channels_last=channels_last,
                calc_val_loss=calc_val_loss,
            )

            torch.cuda.empty_cache()
            validation_time = time.time() - start_time

            val_acc_mean = float(np.mean(val_acc))
            val_acc_history.append((report_epoch, val_acc_mean))

            if self.global_rank == 0:
                print(
                    f"Final validation {report_epoch}/{report_num_epochs - 1} "
                    f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc: {val_acc} time: {validation_time:.2f}s"
                )

                if tb_writer is not None:
                    tb_writer.add_scalar("val/acc", val_acc_mean, report_epoch)
                    if mlflow_is_imported:
                        mlflow.log_metric("val/acc", val_acc_mean, step=report_epoch)

                    # per-class accuracy, limited to the classes that have a name
                    for i in range(min(len(config["class_names"]), len(val_acc))):
                        tb_writer.add_scalar("val_class/" + config["class_names"][i], val_acc[i], report_epoch)
                        if mlflow_is_imported:
                            mlflow.log_metric(
                                "val_class/" + config["class_names"][i], val_acc[i], step=report_epoch
                            )

                    if calc_val_loss:
                        tb_writer.add_scalar("val/loss", val_loss, report_epoch)

                timing_dict = dict(
                    time="{:.2f} hr".format((time.time() - pre_loop_time) / 3600),
                    train_time="{:.2f}s".format(train_time),
                    validation_time="{:.2f}s".format(validation_time),
                    epoch_time="{:.2f}s".format(time.time() - epoch_time),
                )

                if val_acc_mean > best_metric:
                    print(f"New best metric ({best_metric:.6f} --> {val_acc_mean:.6f}). ")
                    best_metric, best_metric_epoch = val_acc_mean, report_epoch
                    save_time = 0
                    if do_torch_save:
                        save_time = self.checkpoint_save(
                            ckpt=best_ckpt_path, model=self.model, epoch=best_metric_epoch, best_metric=best_metric
                        )

                    if progress_path is not None:
                        self.save_progress_yaml(
                            progress_path=progress_path,
                            ckpt=best_ckpt_path if do_torch_save else None,
                            best_avg_dice_score_epoch=best_metric_epoch,
                            best_avg_dice_score=best_metric,
                            save_time=save_time,
                            **timing_dict,
                        )
                if csv_path is not None:
                    self.save_history_csv(
                        csv_path=csv_path,
                        epoch=report_epoch,
                        metric="{:.4f}".format(val_acc_mean),
                        loss="{:.4f}".format(train_loss),
                        iter=report_epoch * len(train_loader.dataset),
                        **timing_dict,
                    )

            # sanity check: a near-zero accuracy this late is treated as divergence
            if epoch > max(20, num_epochs / 4) and 0 <= val_acc_mean < 0.01 and config["stop_on_lowacc"]:
                raise ValueError(
                    f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. "
                    f"Most likely optimization diverged, try setting a smaller learning_rate than {config['learning_rate']}"
                )

            # early stopping: stop once the recent accuracies vary by less than the configured fraction
            if config["early_stopping_fraction"] > 0 and epoch > num_epochs / 2 and len(val_acc_history) > 10:
                check_interval = int(0.1 * num_epochs * num_crops_per_image)
                check_stats = [
                    va[1] for va in val_acc_history if report_epoch - va[0] < check_interval
                ]
                if len(check_stats) < 10:
                    # fall back to the last 10 validation results
                    check_stats = [va[1] for va in val_acc_history[-10:]]
                mac, mic = max(check_stats), min(check_stats)

                early_stopping_fraction = (mac - mic) / (abs(mac) + 1e-8)
                if mac > 0 and mic > 0 and early_stopping_fraction < config["early_stopping_fraction"]:
                    if self.global_rank == 0:
                        print(
                            f"Early stopping at epoch {report_epoch} fraction {early_stopping_fraction} !!! max {mac} min {mic} samples count {len(check_stats)} {check_stats[-50:]}"
                        )
                    break
                else:
                    if self.global_rank == 0:
                        print(
                            f"No stopping at epoch {report_epoch} fraction {early_stopping_fraction} !!! max {mac} min {mic} samples count {len(check_stats)} {check_stats[-50:]}"
                        )

        # periodic and final intermediate checkpoint
        if do_torch_save and ((epoch + 1) % num_epochs_per_saving == 0 or (epoch + 1) >= num_epochs):
            if report_epoch != best_metric_epoch:
                self.checkpoint_save(
                    ckpt=intermediate_ckpt_path, model=self.model, epoch=report_epoch, best_metric=val_acc_mean
                )
            else:
                # this epoch was already saved as the best checkpoint; reuse that file
                shutil.copyfile(best_ckpt_path, intermediate_ckpt_path)

        if lr_scheduler is not None:
            lr_scheduler.step()

        if self.global_rank == 0:
            # remaining-time estimate from the latest train/validation durations
            time_remaining_estimate = train_time * (num_epochs - epoch)
            if val_loader is not None and len(val_loader) > 0:
                if validation_time == 0:
                    validation_time = train_time
                time_remaining_estimate += validation_time * len(val_schedule_list)

            print(
                f"Estimated remaining training time for the current model fold {config['fold']} is "
                f"{time_remaining_estimate/3600:.2f} hr, "
                f"running time {(time.time() - pre_loop_time)/3600:.2f} hr, "
                f"est total time {(time.time() - pre_loop_time + time_remaining_estimate)/3600:.2f} hr \n"
            )

    # end of the epoch loop: drop the references to loaders and optimizer
    train_loader = val_loader = optimizer = None

    # optionally validate the best checkpoint at the original image resolution
    orig_res = config["resample"] == False
    if config["validate_final_original_res"] and config["resample"]:
        # without a ckpt_path there is no checkpoint to look for
        pretrained_ckpt_name = None
        if ckpt_path is not None:
            pretrained_ckpt_name = best_ckpt_path if os.path.exists(best_ckpt_path) else intermediate_ckpt_path
        if pretrained_ckpt_name is not None and os.path.exists(pretrained_ckpt_name):
            # release the current model before it is rebuilt from the checkpoint
            self.model = None
            gc.collect()
            torch.cuda.empty_cache()

            best_metric = self.original_resolution_validate(
                pretrained_ckpt_name=pretrained_ckpt_name,
                progress_path=progress_path,
                best_metric_epoch=best_metric_epoch,
                pre_loop_time=pre_loop_time,
            )
            orig_res = True
        else:
            if self.global_rank == 0:
                print(
                    f"Unable to validate at the original res since no model checkpoints found {best_ckpt_path}, {intermediate_ckpt_path}"
                )

    if tb_writer is not None:
        tb_writer.flush()
        tb_writer.close()

        if mlflow_is_imported:
            mlflow.end_run()

    if self.global_rank == 0:
        print(
            f"=== DONE: best_metric: {best_metric:.4f} at epoch: {best_metric_epoch} of {report_num_epochs} orig_res {orig_res}. Training time {(time.time() - pre_loop_time)/3600:.2f} hr."
        )

    return best_metric
|
|
def original_resolution_validate(self, pretrained_ckpt_name, progress_path, best_metric_epoch, pre_loop_time):
    """Re-validate a saved checkpoint through ``self.validate()`` and record the result.

    The model is rebuilt from ``pretrained_ckpt_name`` and stored in
    ``self.model``. When ``progress_path`` is given, the new mean accuracy is
    written with ``inverted_best_validation=True``.

    Args:
        pretrained_ckpt_name: checkpoint file used to set up the model.
        progress_path: progress file to update, or ``None`` to skip the update.
        best_metric_epoch: epoch number stored alongside the new score.
        pre_loop_time: start timestamp used to report the total elapsed time.

    Returns:
        The mean validation accuracy as a float, or ``-1.0`` when
        ``self.validate()`` produced no result (it returns ``None`` when there
        are no validation files).
    """
    if self.global_rank == 0:
        print("Running final best model validation on the original image resolution!")

    self.model = self.setup_model(pretrained_ckpt_name=pretrained_ckpt_name)

    start_time = time.time()
    result = self.validate()
    if result is None:
        # validate() returns nothing without validation files; unpacking it would
        # raise TypeError, so report the "no metric" value instead
        if self.global_rank == 0:
            print("Original resolution validation skipped: validation returned no results")
        return -1.0

    _, val_loss, val_acc = result
    validation_time = "{:.2f}s".format(time.time() - start_time)
    val_acc_mean = float(np.mean(val_acc))
    if self.global_rank == 0:
        print(
            f"Original resolution validation: "
            f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} "
            f"acc {val_acc} time {validation_time}"
        )

    if progress_path is not None:
        self.save_progress_yaml(
            progress_path=progress_path,
            ckpt=pretrained_ckpt_name,
            best_avg_dice_score_epoch=best_metric_epoch,
            best_avg_dice_score=val_acc_mean,
            validation_time=validation_time,
            inverted_best_validation=True,
            time="{:.2f} hr".format((time.time() - pre_loop_time) / 3600),
        )

    return val_acc_mean
|
|
def validate(self, validation_files=None):
    """Run one validation pass and return its accuracy and loss.

    Args:
        validation_files: items to validate; when ``None`` they are read from
            the data list (using ``validation_key`` if configured, otherwise
            the validation part of the configured fold).

    Returns:
        ``(val_acc_mean, val_loss, val_acc)``, or ``None`` when there are no
        validation files.
    """
    cfg = self.config
    resample = cfg["resample"]

    val_cfg = self.config["validate"]
    output_path = val_cfg.get("output_path", None)
    save_mask = val_cfg.get("save_mask", False) and output_path is not None
    invert = val_cfg.get("invert", True)

    # a relative data list path is resolved against the bundle root
    datalist_path = cfg["data_list_file_path"]
    if not os.path.isabs(datalist_path):
        datalist_path = os.path.abspath(os.path.join(cfg["bundle_root"], datalist_path))

    if validation_files is None:
        if cfg.get("validation_key", None) is None:
            _, validation_files = datafold_read(
                datalist=datalist_path, basedir=cfg["data_file_base_dir"], fold=cfg["fold"]
            )
        else:
            validation_files, _ = datafold_read(
                datalist=datalist_path,
                basedir=cfg["data_file_base_dir"],
                fold=-1,
                key=cfg["validation_key"],
            )

    if self.global_rank == 0:
        print(f"validation files {len(validation_files)}")

    if len(validation_files) == 0:
        warnings.warn("No validation files found!")
        return

    # labels are resampled only when predictions are not inverted
    val_loader = self.get_val_loader(data=validation_files, resample_label=not invert)
    val_transform = val_loader.dataset.transform

    post_transforms = None
    if save_mask or invert:
        post_transforms = DataTransformBuilder.get_postprocess_transform(
            save_mask=save_mask,
            invert=invert,
            transform=val_transform,
            sigmoid=self.config["sigmoid"],
            output_path=output_path,
            resample=resample,
            data_root_dir=self.config["data_file_base_dir"],
            output_dtype=np.uint8 if self.config["output_classes"] < 255 else np.uint16,
            save_mask_mode=self.config.get("save_mask_mode", None),
        )

    started = time.time()
    val_loss, val_acc = self.val_epoch(
        model=self.model,
        val_loader=val_loader,
        sliding_inferrer=self.sliding_inferrer,
        loss_function=self.loss_function,
        acc_function=self.acc_function,
        rank=self.rank,
        global_rank=self.global_rank,
        sigmoid=self.config["sigmoid"],
        use_amp=self.config["amp"],
        use_cuda=self.config["cuda"],
        post_transforms=post_transforms,
        channels_last=self.config["channels_last"],
        calc_val_loss=self.config["calc_val_loss"],
    )
    val_acc_mean = float(np.mean(val_acc))

    if self.global_rank == 0:
        print(
            f"Validation complete, loss_avg: {val_loss:.4f} "
            f"acc_avg: {val_acc_mean:.4f} acc {val_acc} time {time.time() - started:.2f}s"
        )

    return val_acc_mean, val_loss, val_acc
|
|
def infer(self, testing_files=None):
    """Run inference on the testing files and save the predicted masks.

    Args:
        testing_files: items to process; when ``None`` they are read from the
            data list under the ``data_list_key`` of the ``infer`` section
            (default ``"testing"``).

    Returns:
        ``None``. The method returns early when no output path is configured
        or when there are no files to process.
    """
    infer_cfg = self.config["infer"]
    output_path = infer_cfg.get("output_path", None)
    testing_key = self.config["infer"].get("data_list_key", "testing")

    if output_path is None:
        if self.global_rank == 0:
            print("Inference output_path is not specified")
        return

    if testing_files is None:
        # a relative data list path is resolved against the bundle root
        datalist_path = self.config["data_list_file_path"]
        if not os.path.isabs(datalist_path):
            datalist_path = os.path.abspath(os.path.join(self.config["bundle_root"], datalist_path))

        testing_files, _ = datafold_read(
            datalist=datalist_path, basedir=self.config["data_file_base_dir"], fold=-1, key=testing_key
        )

    if self.global_rank == 0:
        print(f"testing_files files {len(testing_files)}")

    if len(testing_files) == 0:
        warnings.warn("No testing_files files found!")
        return

    inf_loader = self.get_val_loader(data=testing_files, resample_label=False)

    # masks are always saved and inverted in this mode
    post_transforms = DataTransformBuilder.get_postprocess_transform(
        save_mask=True,
        invert=True,
        transform=inf_loader.dataset.transform,
        sigmoid=self.config["sigmoid"],
        output_path=output_path,
        resample=self.config["resample"],
        data_root_dir=self.config["data_file_base_dir"],
        output_dtype=np.uint8 if self.config["output_classes"] < 255 else np.uint16,
        save_mask_mode=self.config.get("save_mask_mode", None),
    )

    started = time.time()
    self.val_epoch(
        model=self.model,
        val_loader=inf_loader,
        sliding_inferrer=self.sliding_inferrer,
        rank=self.rank,
        global_rank=self.global_rank,
        sigmoid=self.config["sigmoid"],
        use_amp=self.config["amp"],
        use_cuda=self.config["cuda"],
        post_transforms=post_transforms,
        channels_last=self.config["channels_last"],
        calc_val_loss=self.config["calc_val_loss"],
    )

    if self.global_rank == 0:
        print(f"Inference complete, time {time.time() - started:.2f}s")
|
|
@torch.no_grad()
def infer_image(self, image_file):
    """Run inference on a single image and return the post-processed prediction.

    Args:
        image_file: the image input handed to the inference transform.

    Returns:
        The ``"pred"`` entry produced by the post-processing transforms for
        the first (only) item of the batch.
    """
    self.model.eval()

    cfg = self.config
    infer_cfg = cfg["infer"]
    output_path = infer_cfg.get("output_path", None)
    save_mask = infer_cfg.get("save_mask", False) and output_path is not None
    invert_on_gpu = infer_cfg.get("invert_on_gpu", False)

    started = time.time()
    sigmoid = cfg["sigmoid"]
    resample = cfg["resample"]
    channels_last = cfg["channels_last"]

    inf_transform = self.get_data_transform_builder()(augment=False, resample_label=False)

    # transform the single input and collate it into a batch of one
    batch_data = list_data_collate([inf_transform([image_file])])

    memory_format = torch.preserve_format if not channels_last else torch.channels_last_3d
    inputs = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=self.device)

    with autocast(self.amp_device_type, enabled=cfg["amp"]):
        logits = self.sliding_inferrer(inputs=inputs, network=self.model)

    inputs = None  # drop the reference to the input tensor

    try:
        pred = self.logits2pred(logits, sigmoid=sigmoid)
    except RuntimeError as err:
        if not logits.is_cuda:
            raise err
        # the conversion failed on GPU; repeat it on CPU
        print(f"logits2pred failed on GPU pred retrying on CPU {logits.shape}")
        logits = logits.cpu()
        pred = self.logits2pred(logits, sigmoid=sigmoid)

    logits = None  # drop the reference to the logits

    if not invert_on_gpu:
        pred = pred.cpu()

    post_transforms = DataTransformBuilder.get_postprocess_transform(
        save_mask=save_mask,
        invert=True,
        transform=inf_transform,
        sigmoid=sigmoid,
        output_path=output_path,
        resample=resample,
        data_root_dir=cfg["data_file_base_dir"],
        output_dtype=np.uint8 if cfg["output_classes"] < 255 else np.uint16,
        save_mask_mode=cfg.get("save_mask_mode", None),
    )

    converted = convert_to_dst_type(pred, batch_data["image"], dtype=pred.dtype, device=pred.device)
    batch_data["pred"] = converted[0]
    processed = [post_transforms(item)["pred"] for item in decollate_batch(batch_data)]

    pred = processed[0]

    print(f"Inference complete, time {time.time() - started:.2f}s shape {pred.shape} {image_file}")

    return pred
|
|
def train_epoch(
    self,
    model,
    train_loader,
    optimizer,
    loss_function,
    acc_function,
    grad_scaler,
    epoch,
    rank,
    global_rank=0,
    num_epochs=0,
    sigmoid=False,
    use_amp=True,
    use_cuda=True,
    channels_last=False,
    num_steps_per_image=1,
):
    """Train the model for one pass over ``train_loader``.

    Each loaded batch is optionally split into ``num_steps_per_image`` chunks
    and one optimizer step is taken per chunk.

    Args:
        model: network to train.
        train_loader: iterable of batches with ``"image"`` and ``"label"`` entries.
        optimizer: optimizer stepped through ``grad_scaler``.
        loss_function: called as ``loss_function(logits, target)``.
        acc_function: called as ``acc_function(pred, target)``; may return a
            ``(value, count)`` pair.
        grad_scaler: gradient scaler used for backward and step.
        epoch: epoch number, used for printing only.
        rank: device index used when ``use_cuda`` is set.
        global_rank: only rank 0 prints progress.
        num_epochs: total epochs, used for printing only.
        sigmoid: forwarded to ``self.logits2pred``.
        use_amp: enables autocast for the forward pass.
        use_cuda: run on ``torch.device(rank)`` instead of the CPU.
        channels_last: move tensors to ``torch.channels_last_3d`` memory format.
        num_steps_per_image: number of chunks (optimizer steps) per loaded batch.

    Returns:
        ``(avg_loss, avg_acc)`` as aggregated after the last batch (both ``0``
        if the loader yielded nothing).
    """
    model.train()
    device = torch.device("cpu") if not use_cuda else torch.device(rank)
    memory_format = torch.preserve_format if not channels_last else torch.channels_last_3d

    loss_meter = CumulativeAverage()
    acc_meter = CumulativeAverage()

    def _clear_grads():
        # dropping the gradients by setting them to None
        for param in model.parameters():
            param.grad = None

    tick = time.time()
    avg_loss = avg_acc = 0
    for idx, batch_data in enumerate(train_loader):
        images = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
        labels = batch_data["label"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)

        if num_steps_per_image > 1:
            image_chunks = images.chunk(num_steps_per_image)
            label_chunks = labels.chunk(num_steps_per_image)
        else:
            image_chunks = [images]
            label_chunks = [labels]

        n_steps = min(num_steps_per_image, len(image_chunks))
        for step in range(n_steps):
            data = image_chunks[step]
            target = label_chunks[step]

            _clear_grads()

            with autocast(self.amp_device_type, enabled=use_amp):
                logits = model(data)

            loss = loss_function(logits, target)
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

            with torch.no_grad():
                pred = self.logits2pred(logits, sigmoid=sigmoid, skip_softmax=True)
                acc = acc_function(pred, target)

            # the accuracy function may supply its own sample count
            batch_size = data.shape[0]
            acc_count = batch_size
            if isinstance(acc, (list, tuple)):
                acc, acc_count = acc

            loss_meter.append(loss, count=batch_size)
            acc_meter.append(acc, count=acc_count)

        avg_loss = loss_meter.aggregate()
        avg_acc = acc_meter.aggregate()

        if global_rank == 0:
            print(
                f"Epoch {epoch}/{num_epochs} {idx}/{len(train_loader)} "
                f"loss: {avg_loss:.4f} acc {avg_acc} time {time.time() - tick:.2f}s "
            )
            tick = time.time()

    _clear_grads()

    # release the references to the last batch before returning
    images = labels = data = target = None
    image_chunks = label_chunks = None
    batch_data = None

    return avg_loss, avg_acc
|
|
@torch.no_grad()
def val_epoch(
    self,
    model,
    val_loader,
    sliding_inferrer,
    loss_function=None,
    acc_function=None,
    epoch=0,
    rank=0,
    global_rank=0,
    num_epochs=0,
    sigmoid=False,
    use_amp=True,
    use_cuda=True,
    post_transforms=None,
    channels_last=False,
    calc_val_loss=False,
):
    """Run inference over ``val_loader`` and return the aggregated ``(loss, accuracy)``.

    Args:
        model: network to evaluate; switched to eval mode here.
        val_loader: loader of batches with an ``"image"`` entry and optionally ``"label"``.
        sliding_inferrer: callable invoked as ``sliding_inferrer(inputs=..., network=model)``.
        loss_function: optional callable ``(logits, target) -> loss``.
        acc_function: optional callable ``(pred, target)`` returning an accuracy tensor or
            an ``(accuracy, count)`` pair. Metrics are computed only if both
            ``loss_function`` and ``acc_function`` are given and the batch has a label.
        epoch, num_epochs: only used in the progress message.
        rank: device index passed to ``torch.device`` when ``use_cuda`` is set.
        global_rank: progress is printed only when this equals 0.
        sigmoid: forwarded to ``logits2pred``.
        use_amp: enables autocast around the sliding-window inference.
        use_cuda: run on ``torch.device(rank)`` instead of the CPU.
        post_transforms: optional callable applied to every decollated item; its ``"pred"``
            entry replaces the prediction.
        channels_last: move inputs to ``torch.channels_last_3d`` memory format.
        calc_val_loss: also accumulate the validation loss while logits are still available.

    Returns:
        ``(avg_loss, avg_acc)`` taken from the running averages after the last batch.
    """
    model.eval()
    device = torch.device(rank) if use_cuda else torch.device("cpu")
    memory_format = torch.channels_last_3d if channels_last else torch.preserve_format
    distributed = dist.is_initialized()

    run_loss = CumulativeAverage()
    run_acc = CumulativeAverage()
    # seed the loss meter with a zero-weight entry so it can be aggregated even if no loss is appended
    run_loss.append(torch.tensor(0, device=device), count=0)

    avg_loss = avg_acc = 0
    start_time = time.time()

    # A distributed sampler without drop_last hands some ranks repeated items; count how many
    # items of this rank are genuine so the repeats can be given zero weight below.
    # NOTE(review): the comparison with ``idx`` further down presumably assumes batch size 1 - confirm.
    nonrepeated_data_length = len(val_loader.dataset)
    sampler = val_loader.sampler
    # fix: the original tested the function object ``dist.is_initialized`` (always truthy)
    if distributed and isinstance(sampler, DistributedSampler) and not sampler.drop_last:
        nonrepeated_data_length = len(range(sampler.rank, len(sampler.dataset), sampler.num_replicas))

    for idx, batch_data in enumerate(val_loader):
        data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
        filename = batch_data["image"].meta[ImageMetaKey.FILENAME_OR_OBJ]
        batch_size = data.shape[0]

        with autocast(self.amp_device_type, enabled=use_amp):
            logits = sliding_inferrer(inputs=data, network=model)

        data = None  # input no longer needed, drop the reference early

        if post_transforms:
            try:
                pred = self.logits2pred(logits, sigmoid=sigmoid)
            except RuntimeError:
                if not logits.is_cuda:
                    raise
                print(f"logits2pred failed on GPU pred retrying on CPU {logits.shape} {filename}")
                logits = logits.cpu()
                pred = self.logits2pred(logits, sigmoid=sigmoid)

            if not calc_val_loss:
                logits = None

            batch_data["pred"] = convert_to_dst_type(
                pred, batch_data["image"], dtype=pred.dtype, device=pred.device
            )[0]
            pred = None

            try:
                pred = torch.stack([post_transforms(x)["pred"] for x in decollate_batch(batch_data)])
            except RuntimeError:
                if not batch_data["pred"].is_cuda:
                    raise
                print(f"post_transforms failed on GPU pred retrying on CPU {batch_data['pred'].shape}")
                batch_data["pred"] = batch_data["pred"].cpu()
                pred = torch.stack([post_transforms(x)["pred"] for x in decollate_batch(batch_data)])

            batch_data["pred"] = None
            # the loss cannot be computed if post-processing changed the prediction shape
            if logits is not None and pred.shape != logits.shape:
                logits = None
        else:
            pred = self.logits2pred(logits, sigmoid=sigmoid, skip_softmax=True)

        if "label" in batch_data and loss_function is not None and acc_function is not None:
            loss = acc = None
            target = batch_data["label"].as_subclass(torch.Tensor)

            if calc_val_loss:
                if logits is not None:
                    loss = loss_function(logits, target.to(device=logits.device))
                    run_loss.append(loss.to(device=device), count=batch_size)
                    logits = None

            with torch.no_grad():
                try:
                    acc = acc_function(pred.to(device=device), target.to(device=device))
                except RuntimeError as e:
                    # only an out-of-memory failure is retried on the CPU
                    if "OutOfMemoryError" not in str(type(e).__name__):
                        raise
                    print(
                        f"acc_function val failed on GPU pred: {pred.shape} on {pred.device}, target: {target.shape} on {target.device}. retrying on CPU"
                    )
                    acc = acc_function(pred.cpu(), target.cpu())

                batch_size_adjusted = batch_size
                if isinstance(acc, (list, tuple)):
                    acc, batch_size_adjusted = acc
                acc = acc.detach().clone()

            if idx < nonrepeated_data_length:
                run_acc.append(acc.to(device=device), count=batch_size_adjusted)
            else:
                # repeated item: append with zero weight; fix: the count is a plain int when
                # acc_function did not return its own count, and zeros_like rejects ints
                if isinstance(batch_size_adjusted, torch.Tensor):
                    zero_count = torch.zeros_like(batch_size_adjusted)
                else:
                    zero_count = 0
                run_acc.append(torch.zeros_like(acc, device=device), count=zero_count)

            avg_loss = loss.cpu() if loss is not None else 0
            avg_acc = acc.cpu().numpy() if acc is not None else 0
            pred, target = None, None

            if global_rank == 0:
                print(
                    f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} loss: {avg_loss:.4f} "
                    f"acc {avg_acc} time {time.time() - start_time:.2f}s {filename}"
                )

        else:
            if global_rank == 0:
                print(f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} time {time.time() - start_time:.2f}s")

        start_time = time.time()

    pred = target = data = batch_data = None

    if distributed:
        dist.barrier()

    avg_loss = run_loss.aggregate()
    avg_acc = run_acc.aggregate()

    if np.any(avg_acc < 0):
        # fix: a barrier without an initialized process group raises, so only sync when distributed
        if distributed:
            dist.barrier()
        warnings.warn("Avg dice accuracy is negative, something went wrong!!!!!")

    return avg_loss, avg_acc
|
|
def logits2pred(self, logits, sigmoid=False, dim=1, skip_softmax=False):
    """Convert network output into predictions.

    Args:
        logits: a tensor, or a list/tuple of tensors of which only the first is used.
        sigmoid: apply an element-wise sigmoid instead of a softmax.
        dim: dimension the softmax is taken over.
        skip_softmax: when ``sigmoid`` is False, return the logits untouched.

    Returns:
        The sigmoid output, the unchanged logits, or a float32 softmax that was
        evaluated in double precision.
    """
    # for multi-output networks only the first entry is converted
    head = logits[0] if isinstance(logits, (list, tuple)) else logits

    if sigmoid:
        return torch.sigmoid(head)
    if skip_softmax:
        return head
    return torch.softmax(head, dim=dim, dtype=torch.double).float()
|
|
def get_avail_cpu_memory(self):
    """Return the available CPU memory in bytes, capped by a container memory limit.

    The starting point is ``psutil.virtual_memory().available``. If a cgroup limit
    file is readable and holds a number, the result is the smaller of the two.

    Returns:
        Number of bytes considered available.
    """
    avail_memory = psutil.virtual_memory().available

    # Containers expose their memory cap through cgroup files: v1 uses
    # memory/memory.limit_in_bytes, v2 uses memory.max (which reads "max" when unlimited).
    cgroup_limit_files = ("/sys/fs/cgroup/memory/memory.limit_in_bytes", "/sys/fs/cgroup/memory.max")
    for limit_filename in cgroup_limit_files:
        try:
            with open(limit_filename, "r") as f:
                raw_limit = f.read().strip()
        except OSError:
            # file is absent or unreadable: no limit to apply from this hierarchy
            continue
        if raw_limit.isdigit():
            avail_memory = min(int(raw_limit), avail_memory)

    return avail_memory
|
|
def get_cache_rate(self, train_cases=0, validation_cases=0, prioritise_train=True):
    """Decide which fraction of the training and validation data to cache in RAM.

    If ``config["cache_rate"]`` is ``None`` the rate is estimated from the expected
    dataset size and the available CPU memory; otherwise the configured value is used.

    Args:
        train_cases: number of training cases.
        validation_cases: number of validation cases.
        prioritise_train: when the overall rate is strictly between 0 and 1, spend the
            cache on training cases first and give validation only what is left.

    Returns:
        ``(cache_rate_train, cache_rate_val)``.
    """
    config = self.config
    cache_rate = config["cache_rate"]
    avail_memory = None

    total_cases = train_cases + validation_cases

    image_size_mm_90 = config.get("image_size_mm_90", None)
    if config["resample"] and image_size_mm_90 is not None:
        # physical size divided by the target spacing gives the image size in voxels
        image_size = (
            (np.array(image_size_mm_90) / np.array(config["resample_resolution"])).astype(np.int32).tolist()
        )
    else:
        image_size = config["image_size"]

    if cache_rate is None:
        cache_rate = 0

        if image_size is not None:
            # fix: the estimate is computed only once image_size is known to be usable;
            # previously np.prod(None) crashed before the fallback branch could be reached.
            # NOTE(review): (4 * channels + 1) presumably means 4 bytes per image voxel per
            # channel plus 1 byte per label voxel - confirm.
            approx_data_cache_required = int(
                (4 * config["input_channels"] + 1) * np.prod(image_size) * total_cases
            )
            approx_os_cache_required = 50 * 1024**3  # 50 GiB added on top of the data estimate

            avail_memory = self.get_avail_cpu_memory()
            cache_rate = min(avail_memory / float(approx_data_cache_required + approx_os_cache_required), 1.0)
            if cache_rate < 0.1:
                cache_rate = 0.0  # rates below 10% are rounded down to no caching

            if self.global_rank == 0:
                print(
                    f"Calculating cache required {approx_data_cache_required >> 30}GB, available RAM {avail_memory >> 30}GB given avg image size {image_size}."
                )
                if cache_rate < 1:
                    print(
                        f"Available RAM is not enought to cache full dataset, caching a fraction {cache_rate:.2f}"
                    )
                else:
                    print("Caching full dataset in RAM")
        else:
            print("Cant calculate cache_rate since image_size is not provided!!!!")

    else:
        if self.global_rank == 0:
            print(f"Using user specified cache_rate={cache_rate} to cache data in RAM")

    # by default both splits share the same rate
    cache_rate_train = cache_rate_val = cache_rate

    if prioritise_train:
        if cache_rate > 0 and cache_rate < 1:
            cache_num = cache_rate * total_cases
            cache_rate_train = min(1.0, cache_num / train_cases) if train_cases > 0 else 0
            if (cache_rate_train < 1 and train_cases > 0) or validation_cases == 0:
                cache_rate_val = 0
            else:
                # whatever is left after fully caching the training cases goes to validation
                cache_rate_val = (cache_num - cache_rate_train * train_cases) / validation_cases

            if self.global_rank == 0:
                print(f"Prioritizing cache_rate training {cache_rate_train} validation {cache_rate_val}")

    return cache_rate_train, cache_rate_val
|
|
def save_history_csv(self, csv_path=None, header=None, **kwargs):
    """Append a header row and/or a row of values to a tab-separated history file.

    Args:
        csv_path: file to append to; nothing happens when it is ``None``.
        header: optional sequence written as one row before the values.
        **kwargs: their values (in the order given) are written as one row.
    """
    if csv_path is None or (header is None and not kwargs):
        return

    # newline="" lets the csv module control line endings itself; without it every row
    # ends in "\r\r\n" on platforms that translate "\n" on write
    with open(csv_path, "a", newline="") as csv_file:
        writer = csv.writer(csv_file, delimiter="\t")
        if header is not None:
            writer.writerow(header)
        if kwargs:
            writer.writerow(list(kwargs.values()))
|
|
def save_progress_yaml(self, progress_path=None, ckpt=None, **report):
    """Append a progress record to a YAML file and echo it to stdout.

    Args:
        progress_path: YAML file to append to; the file is not touched when ``None``.
        ckpt: if given, stored in the record under the ``"model"`` key.
        **report: remaining fields of the record; a ``"date"`` field is always added.
    """

    def _represent_float(dumper, value):
        # floats are written with four decimals
        return dumper.represent_scalar("tag:yaml.org,2002:float", "{0:.4f}".format(value))

    if ckpt is not None:
        report["model"] = ckpt

    # timestamp truncated to whole seconds ("YYYY-MM-DD HH:MM:SS")
    report["date"] = str(datetime.now())[:19]

    if progress_path is not None:
        # NOTE(review): add_representer is called without a Dumper argument, so this looks
        # like a process-wide registration repeated on every call - confirm.
        yaml.add_representer(float, _represent_float)
        with open(progress_path, "a") as progress_file:
            yaml.dump([report], stream=progress_file, allow_unicode=True, default_flow_style=None, sort_keys=False)

    summary = ",".join(f" {k}: {v}" for k, v in report.items())
    print("Progress:" + summary)
|
|
def run(self):
    """Run the stage selected in the config and return its result.

    ``config["validate"]["enabled"]`` takes precedence over
    ``config["infer"]["enabled"]``; when neither is set, training runs.

    Returns:
        Whatever the selected ``validate()``, ``infer()`` or ``train()`` call returns.
        The result used to be discarded, so callers that store the return value of
        ``run()`` always received ``None``.
    """
    if self.config["validate"]["enabled"]:
        return self.validate()
    if self.config["infer"]["enabled"]:
        return self.infer()
    return self.train()
|
|
|
|
def run_segmenter_worker(
    rank=0, config_file: Optional[Union[str, Sequence[str]]] = None, override: Optional[Dict] = None
):
    """Set up (optional) distributed state for one process and run a ``Segmenter``.

    Args:
        rank: rank of this process; replaced by ``LOCAL_RANK`` when an external launcher is detected.
        config_file: a config path, a comma-separated string of paths, or a sequence of paths.
        override: config overrides handed to ``Segmenter`` as ``config_dict``. An ``"mgpu"``
            entry, if present, is passed as keyword arguments to ``dist.init_process_group``.
            A dict passed in is updated in place with the resolved rank information.

    Returns:
        The value returned by ``Segmenter.run()``.
    """
    # fix: a ``{}`` default would be shared (and mutated) across calls
    if override is None:
        override = {}

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    dist_available = dist.is_available()
    global_rank = rank

    # a comma-separated string stands for several config files
    if isinstance(config_file, str) and "," in config_file:
        config_file = config_file.split(",")

    if dist_available:
        mgpu = override.get("mgpu", None)
        if mgpu is not None:
            # this process was spawned by us: join the process group described by ``mgpu``
            logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING)
            dist.init_process_group(backend="nccl", rank=rank, timeout=timedelta(seconds=5400), **mgpu)
            mgpu.update({"rank": rank, "global_rank": rank})
            if rank == 0:
                print(f"Distributed: initializing multi-gpu tcp:// process group {mgpu}")

        elif dist_launched() and torch.cuda.device_count() > 1:
            # an external launcher started this process: ranks come from the environment
            rank = int(os.getenv("LOCAL_RANK"))
            global_rank = int(os.getenv("RANK"))
            # NOTE(review): world_size is read from LOCAL_WORLD_SIZE, not WORLD_SIZE - confirm
            # this is intended for multi-node runs.
            world_size = int(os.getenv("LOCAL_WORLD_SIZE"))
            logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING)
            dist.init_process_group(backend="nccl", init_method="env://")
            override["mgpu"] = {"world_size": world_size, "rank": rank, "global_rank": global_rank}

            print(f"Distributed launched: initializing multi-gpu env:// process group {override['mgpu']}")

    segmenter = Segmenter(config_file=config_file, config_dict=override, rank=rank, global_rank=global_rank)
    best_metric = segmenter.run()
    segmenter = None  # drop the reference before tearing the process group down

    if dist_available and dist.is_initialized():
        dist.destroy_process_group()

    return best_metric
|
|
|
|
def dist_launched() -> bool:
    """Tell whether an external launcher already started the distributed processes.

    Returns:
        True if torch reports a torchelastic launch, or if the ``NGC_ARRAY_SIZE``
        environment variable is set to an integer greater than 1.
    """
    if dist.is_torchelastic_launched():
        return True

    array_size = os.getenv("NGC_ARRAY_SIZE")
    if array_size is None:
        return False
    return int(array_size) > 1
|
|
|
|
def run_segmenter(config_file: Optional[Union[str, Sequence[str]]] = None, **kwargs):
    """Entry point: run the segmenter, spawning one worker per GPU when needed.

    If more than one CUDA device is visible and no external launcher is detected,
    one ``run_segmenter_worker`` process is spawned per device. Otherwise a single
    worker runs in the current process with rank 0.

    Args:
        config_file: a config path, a comma-separated string of paths, or a sequence of paths.
        **kwargs: config overrides forwarded to the worker. ``init_method`` (default
            ``"tcp://127.0.0.1:23456"``) is used for the spawned process group.
    """
    nprocs = torch.cuda.device_count()

    if nprocs > 1 and not dist_launched():
        # fix: the message was missing its f-prefix and printed "{nprocs}" literally
        print(f"Manually spawning processes {nprocs}")
        kwargs["mgpu"] = {"world_size": nprocs, "init_method": kwargs.get("init_method", "tcp://127.0.0.1:23456")}
        torch.multiprocessing.spawn(run_segmenter_worker, nprocs=nprocs, args=(config_file, kwargs))
    else:
        # fix: same missing f-prefix as above
        print(f"Not spawning processes, dist is already launched {nprocs}")
        run_segmenter_worker(0, config_file, kwargs)
|
|
|
|
if __name__ == "__main__":
    # Use the `fire` command-line front end when it can be imported;
    # otherwise fall back to the default config file.
    fire, fire_is_imported = optional_import("fire")
    if not fire_is_imported:
        warnings.warn("Fire commandline parser cannot be imported, using options from config/hyper_parameters.yaml")
        run_segmenter(config_file="config/hyper_parameters.yaml")
    else:
        fire.Fire(run_segmenter)
|
|