diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..d3ed83ee7cfea6fc87800c63d9b4ada18601db8c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/overview.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/image_with_v.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/0.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/1.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/10.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/2.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/3.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/4.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/5.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/6.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/7.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/8.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/9.png filter=lfs diff=lfs merge=lfs -text +ShapeID/out/2d/V.png filter=lfs diff=lfs merge=lfs -text diff --git a/Generator/__init__.py b/Generator/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..656da50fb51f13cba60d7b11e47335f58917d397 --- /dev/null +++ b/Generator/__init__.py @@ -0,0 +1,22 @@ + +""" +Datasets interface. +""" +from .constants import dataset_setups +from .datasets import BaseGen, BrainIDGen + + + +dataset_options = { + 'default': BaseGen, + 'brain_id': BrainIDGen, +} + + + + +def build_datasets(gen_args, device): + """Helper function to build dataset for different splits ('train' or 'test').""" + datasets = {'all': dataset_options[gen_args.dataset_option](gen_args, device)} + return datasets + diff --git a/Generator/config.py b/Generator/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0dcfe23ded0dc5a036f7d84f86f40e5bc3e38c --- /dev/null +++ b/Generator/config.py @@ -0,0 +1,181 @@ + +"""Config utilities for yml file.""" +import os +from argparse import Namespace +import collections +import functools +import os +import re + +import yaml +# from imaginaire.utils.distributed import master_only_print as print + + +class AttrDict(dict): + """Dict as attribute trick.""" + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + for key, value in self.__dict__.items(): + if isinstance(value, dict): + self.__dict__[key] = AttrDict(value) + elif isinstance(value, (list, tuple)): + if isinstance(value[0], dict): + self.__dict__[key] = [AttrDict(item) for item in value] + else: + self.__dict__[key] = value + + def yaml(self): + """Convert object to yaml dict and return.""" + yaml_dict = {} + for key, value in self.__dict__.items(): + if isinstance(value, AttrDict): + yaml_dict[key] = value.yaml() + elif isinstance(value, list): + if isinstance(value[0], AttrDict): + new_l = [] + for item in value: + new_l.append(item.yaml()) + yaml_dict[key] = new_l + else: + yaml_dict[key] = value + else: + yaml_dict[key] = value + return yaml_dict + + def __repr__(self): + """Print all variables.""" + ret_str = [] + for key, value in 
self.__dict__.items(): + if isinstance(value, AttrDict): + ret_str.append('{}:'.format(key)) + child_ret_str = value.__repr__().split('\n') + for item in child_ret_str: + ret_str.append(' ' + item) + elif isinstance(value, list): + if isinstance(value[0], AttrDict): + ret_str.append('{}:'.format(key)) + for item in value: + # Treat as AttrDict above. + child_ret_str = item.__repr__().split('\n') + for item in child_ret_str: + ret_str.append(' ' + item) + else: + ret_str.append('{}: {}'.format(key, value)) + else: + ret_str.append('{}: {}'.format(key, value)) + return '\n'.join(ret_str) + + +class Config(AttrDict): + r"""Configuration class. This should include every human specifiable + hyperparameter values for your training.""" + + def __init__(self, filename=None, verbose=False): + super(Config, self).__init__() + + # Update with given configurations. + if os.path.exists(filename): + + loader = yaml.SafeLoader + loader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + try: + with open(filename, 'r') as f: + cfg_dict = yaml.load(f, Loader=loader) + except EnvironmentError: + print('Please check the file with name of "%s"', filename) + recursive_update(self, cfg_dict) + else: + raise ValueError('Provided config path not existed: %s' % filename) + + if verbose: + print(' imaginaire config '.center(80, '-')) + print(self.__repr__()) + print(''.center(80, '-')) + + +def rsetattr(obj, attr, val): + """Recursively find object and set value""" + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + """Recursively find object and return value""" + + def _getattr(obj, attr): + r"""Get attribute.""" + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split('.')) + + +def recursive_update(d, u): + """Recursively update AttrDict d with AttrDict u""" + if u is not None: + for key, value in u.items(): + if isinstance(value, collections.abc.Mapping): + d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value) + elif isinstance(value, (list, tuple)): + if len(value) > 0 and isinstance(value[0], dict): + d.__dict__[key] = [AttrDict(item) for item in value] + else: + d.__dict__[key] = value + else: + d.__dict__[key] = value + return d + + +def merge_and_update_from_dict(cfg, dct): + """ + (Compatible for submitit's Dict as attribute trick) + Merge dict as dict() to config as CfgNode(). 
+ Args: + cfg: dict + dct: dict + """ + if dct is not None: + for key, value in dct.items(): + if isinstance(value, dict): + if key in cfg.keys(): + sub_cfgnode = cfg[key] + else: + sub_cfgnode = dict() + cfg.__setattr__(key, sub_cfgnode) + sub_cfgnode = merge_and_update_from_dict(sub_cfgnode, value) + else: + cfg[key] = value + return cfg + + +def load_config(cfg_files = [], cfg_dir = ''): + cfg = Config(cfg_files[0]) + for cfg_file in cfg_files[1:]: + add_cfg = Config(cfg_file) + cfg = merge_and_update_from_dict(cfg, add_cfg) + return cfg + + +def nested_dict_to_namespace(dictionary): + namespace = dictionary + if isinstance(dictionary, dict): + namespace = Namespace(**dictionary) + for key, value in dictionary.items(): + setattr(namespace, key, nested_dict_to_namespace(value)) + return namespace + + +def preprocess_cfg(cfg_files, cfg_dir = ''): + config = load_config(cfg_files, cfg_dir) + args = nested_dict_to_namespace(config) + return args \ No newline at end of file diff --git a/Generator/constants.py b/Generator/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..b43eb1fc73482d892b64534fa7e4519ac0167fa6 --- /dev/null +++ b/Generator/constants.py @@ -0,0 +1,290 @@ +import os, glob + +from .utils import * + +augmentation_funcs = { + 'gamma': add_gamma_transform, + 'bias_field': add_bias_field, + 'resample': resample_resolution, + 'noise': add_noise, +} + +processing_funcs = { + 'T1': read_and_deform_image, + 'T2': read_and_deform_image, + 'FLAIR': read_and_deform_image, + 'CT': read_and_deform_CT, + 'segmentation': read_and_deform_segmentation, + 'surface': read_and_deform_surface, + 'distance': read_and_deform_distance, + 'bias_field': read_and_deform_bias_field, + 'registration': read_and_deform_registration, + 'pathology': read_and_deform_pathology, +} + + +dataset_setups = { + + 'ADHD': { + 'root': '/autofs/space/yogurt_001/users/pl629/data/adhd200_crop', + 'pathology_type': None, + 'train': 'train.txt', + 'test': 'test.txt', 
+ 'modalities': ['T1'], + + 'paths':{ + # for synth + 'Gen': 'label_maps_generation', + 'Dmaps': None, + 'DmapsBag': None, + + # real images + 'T1': 'T1', + 'T2': None, + 'FLAIR': None, + 'CT': None, + + # processed ground truths + 'surface': None, #'surfaces', TODO + 'distance': None, + 'segmentation': 'label_maps_segmentation', + 'bias_field': None, + 'pathology': None, + 'pathology_prob': None, + } + }, + + 'HCP': { + 'root': '/autofs/space/yogurt_001/users/pl629/data/hcp_crop', + 'pathology_type': None, + 'train': 'train.txt', + 'test': 'test.txt', + 'modalities': ['T1', 'T2'], + + 'paths':{ + # for synth + 'Gen': 'label_maps_generation', + 'Dmaps': None, + 'DmapsBag': None, + + # real images + 'T1': 'T1', + 'T2': 'T2', + 'FLAIR': None, + 'CT': None, + + # processed ground truths + 'surface': None, #'surfaces', + 'distance': None, + 'segmentation': 'label_maps_segmentation', + 'bias_field': None, + 'pathology': None, + 'pathology_prob': None, + } + }, + + 'AIBL': { + 'root': '/autofs/space/yogurt_001/users/pl629/data/aibl_crop', + 'pathology_type': None, + 'train': 'train.txt', + 'test': 'test.txt', + 'modalities': ['T1', 'T2', 'FLAIR'], + + 'paths':{ + # for synth + 'Gen': 'label_maps_generation', + 'Dmaps': None, + 'DmapsBag': None, + + # real images + 'T1': 'T1', + 'T2': 'T2', + 'FLAIR': 'FLAIR', + 'CT': None, + + # processed ground truths + 'surface': None, #'surfaces', + 'distance': None, + 'segmentation': 'label_maps_segmentation', + 'bias_field': None, + 'pathology': None, + 'pathology_prob': None, + } + }, + + 'OASIS': { + 'root': '/autofs/space/yogurt_001/users/pl629/data/oasis3', + 'pathology_type': None, + 'train': 'train.txt', + 'test': 'test.txt', + 'modalities': ['T1', 'CT'], + + 'paths':{ + # for synth + 'Gen': 'label_maps_generation', + 'Dmaps': None, + 'DmapsBag': None, + + # real images + 'T1': 'T1', + 'T2': None, + 'FLAIR': None, + 'CT': 'CT', + + # processed ground truths + 'surface': None, #'surfaces', + 'distance': None, + 'segmentation': 
'label_maps_segmentation', + 'bias_field': None, + 'pathology': None, + 'pathology_prob': None, + } + }, + + 'ADNI': { + 'root': '/autofs/space/yogurt_001/users/pl629/data/adni_crop', + 'pathology_type': None, #'wmh', + 'train': 'train.txt', + 'test': 'test.txt', + 'modalities': ['T1'], + + 'paths':{ + # for synth + 'Gen': 'label_maps_generation', + 'Dmaps': 'Dmaps', + 'DmapsBag': 'DmapsBag', + + # real images + 'T1': 'T1', + 'T2': None, + 'FLAIR': None, + 'CT': None, + + # processed ground truths + 'surface': 'surfaces', + 'distance': 'Dmaps', + 'segmentation': 'label_maps_segmentation', + 'bias_field': None, + 'pathology': 'pathology_maps_segmentation', + 'pathology_prob': 'pathology_probability', + } + }, + + 'ADNI3': { + 'root': '/autofs/space/yogurt_001/users/pl629/data/adni3_crop', + 'pathology_type': None, # 'wmh', + 'train': 'train.txt', + 'test': 'test.txt', + 'modalities': ['T1', 'FLAIR'], + + 'paths':{ + # for synth + 'Gen': 'label_maps_generation', + 'Dmaps': None, + 'DmapsBag': None, + + # real images + 'T1': 'T1', + 'T2': None, + 'FLAIR': 'FLAIR', + 'CT': None, + + # processed ground truths + 'surface': None, #'surfaces', TODO + 'distance': None, + 'segmentation': 'label_maps_segmentation', + 'bias_field': None, + 'pathology': 'pathology_maps_segmentation', + 'pathology_prob': 'pathology_probability', + } + }, + + 'ATLAS': { + 'root': '/autofs/space/yogurt_001/users/pl629/data/atlas_crop', + 'pathology_type': 'stroke', + 'train': 'train.txt', + 'test': 'test.txt', + 'modalities': ['T1'], + + 'paths':{ + # for synth + 'Gen': 'label_maps_generation', + 'Dmaps': None, + 'DmapsBag': None, + + # real images + 'T1': 'T1', + 'T2': None, + 'FLAIR': None, + 'CT': None, + + # processed ground truths + 'surface': None, #'surfaces', TODO + 'distance': None, + 'segmentation': 'label_maps_segmentation', + 'bias_field': None, + 'pathology': 'pathology_maps_segmentation', + 'pathology_prob': 'pathology_probability', + } + }, + + 'ISLES': { + 'root': 
'/autofs/space/yogurt_001/users/pl629/data/isles2022_crop', + 'pathology_type': 'stroke', + 'train': 'train.txt', + 'test': 'test.txt', + 'modalities': ['FLAIR'], + + 'paths':{ + # for synth + 'Gen': 'label_maps_generation', + 'Dmaps': None, + 'DmapsBag': None, + + # real images + 'T1': None, + 'T2': None, + 'FLAIR': 'FLAIR', + 'CT': None, + + # processed ground truths + 'surface': None, #'surfaces', TODO + 'distance': None, + 'segmentation': 'label_maps_segmentation', + 'bias_field': None, + 'pathology': 'pathology_maps_segmentation', + 'pathology_prob': 'pathology_probability', + } + }, +} + + +all_dataset_names = dataset_setups.keys() + + +# get all pathologies +pathology_paths = [] +pathology_prob_paths = [] +for name, dict in dataset_setups.items(): + # TODO: select what kind of shapes? + if dict['paths']['pathology'] is not None and dict['pathology_type'] is not None and dict['pathology_type'] == 'stroke': + pathology_paths += glob.glob(os.path.join(dict['root'], dict['paths']['pathology'], '*.nii.gz')) \ + + glob.glob(os.path.join(dict['root'], dict['paths']['pathology'], '*.nii')) + pathology_prob_paths += glob.glob(os.path.join(dict['root'], dict['paths']['pathology_prob'], '*.nii.gz')) \ + + glob.glob(os.path.join(dict['root'], dict['paths']['pathology_prob'], '*.nii')) +n_pathology = len(pathology_paths) + + +# with csf # NOTE old version (FreeSurfer standard), non-vast +label_list_segmentation = [0,14,15,16,24,77,85, 2, 3, 4, 7, 8, 10,11,12,13,17,18,26,28, 41,42,43,46,47,49,50,51,52,53,54,58,60] # 33 +n_neutral_labels = 7 + + +## NEW VAST synth +label_list_segmentation_brainseg_with_extracerebral = [0, 11, 12, 13, 16, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 14, 15, 17, 47, 49, 51, 53, 55, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 48, 50, 52, 54, 56] +n_neutral_labels_brainseg_with_extracerebral = 20 + +label_list_segmentation_brainseg_left = [0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 17, 31, 34, 
36, 38, 40, 42] + diff --git a/Generator/datasets.py b/Generator/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd67774e95cfe9c5f697a53659db7208f3d0505 --- /dev/null +++ b/Generator/datasets.py @@ -0,0 +1,757 @@ +import os, sys, glob +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from collections import defaultdict +import random + +import torch +import numpy as np +import nibabel as nib +from torch.utils.data import Dataset + + +from .utils import * +from .constants import n_pathology, pathology_paths, pathology_prob_paths, \ + n_neutral_labels_brainseg_with_extracerebral, label_list_segmentation_brainseg_with_extracerebral, \ + label_list_segmentation_brainseg_left, augmentation_funcs, processing_funcs +import utils.interpol as interpol + +from utils.misc import viewVolume + + +from ShapeID.DiffEqs.pde import AdvDiffPDE + + + +class BaseGen(Dataset): + """ + BaseGen dataset + """ + def __init__(self, gen_args, device='cpu'): + + self.gen_args = gen_args + self.split = gen_args.split + + self.synth_args = self.gen_args.generator + self.shape_gen_args = gen_args.pathology_shape_generator + self.real_image_args = gen_args.real_image_generator + self.synth_image_args = gen_args.synth_image_generator + self.augmentation_steps = vars(gen_args.augmentation_steps) + self.input_prob = vars(gen_args.modality_probs) + self.device = device + + self.prepare_tasks() + self.prepare_paths() + self.prepare_grid() + self.prepare_one_hot() + + + def __len__(self): + return sum([len(self.names[i]) for i in range(len(self.names))]) + + + def idx_to_path(self, idx): + cnt = 0 + for i, l in enumerate(self.datasets_len): + if idx >= cnt and idx < cnt + l: + dataset_name = self.datasets[i] + age = self.ages[i][os.path.basename(self.names[i][idx - cnt]).split('.T1w')[0]] if len(self.ages) > 0 else None + return dataset_name, vars(self.input_prob[dataset_name]), self.names[i][idx - cnt], age + else: + cnt += l + + + def 
prepare_paths(self): + + # Collect list of available images, per dataset + if len(self.gen_args.dataset_names) < 1: + datasets = [] + g = glob.glob(os.path.join(self.gen_args.data_root, '*' + 'T1w.nii')) + for i in range(len(g)): + filename = os.path.basename(g[i]) + dataset = filename[:filename.find('.')] + found = False + for d in datasets: + if dataset == d: + found = True + if found is False: + datasets.append(dataset) + print('Found ' + str(len(datasets)) + ' datasets with ' + str(len(g)) + ' scans in total') + else: + datasets = self.gen_args.dataset_names + print('Dataset list', datasets) + + + names = [] + if 'age' in self.tasks: + self.split = self.split + '_age' + if self.gen_args.split_root is not None: + split_file = open(os.path.join(self.gen_args.split_root, self.split + '.txt'), 'r') + split_names = [] + for subj in split_file.readlines(): + split_names.append(subj.strip()) + + for i in range(len(datasets)): + names.append([name for name in split_names if os.path.basename(name).startswith(datasets[i])]) + #else: + # for i in range(len(datasets)): + # names.append(glob.glob(os.path.join(self.gen_args.data_root, datasets[i] + '.*' + 'T1w.nii'))) + + # read brain age + ages = [] + if 'age' in self.tasks: + age_file = open(os.path.join(self.gen_args.split_root, 'participants_age.txt'), 'r') + subj_name_age = [] + for line in age_file.readlines(): # 'subj age\n' + subj_name_age.append(line.strip().split(' ')) + for i in range(len(datasets)): + ages.append({}) + for [name, age] in subj_name_age: + if name.startswith(datasets[i]): + ages[-1][name] = float(age) + print('Age info', self.split, len(ages[0].items()), min(ages[0].values()), max(ages[0].values())) + + self.ages = ages + self.names = names + self.datasets = datasets + self.datasets_num = len(datasets) + self.datasets_len = [len(self.names[i]) for i in range(len(self.names))] + print('Num of data', sum([len(self.names[i]) for i in range(len(self.names))])) + + self.pathology_type = None 
#setup_dict['pathology_type'] + + + def prepare_tasks(self): + self.tasks = [key for (key, value) in vars(self.gen_args.task).items() if value] + if 'bias_field' in self.tasks and 'segmentation' not in self.tasks: + # add segmentation mask for computing bias_field_soft_mask + self.tasks += ['segmentation'] + if 'pathology' in self.tasks and self.synth_args.augment_pathology and self.synth_args.random_shape_prob < 1.: + self.t = torch.from_numpy(np.arange(self.shape_gen_args.max_nt) * self.shape_gen_args.dt).to(self.device) + with torch.no_grad(): + self.adv_pde = AdvDiffPDE(data_spacing=[1., 1., 1.], + perf_pattern='adv', + V_type='vector_div_free', + V_dict={}, + BC=self.shape_gen_args.bc, + dt=self.shape_gen_args.dt, + device=self.device + ) + else: + self.t, self.adv_pde = None, None + for task_name in self.tasks: + if task_name not in processing_funcs.keys(): + print('Warning: Function for task "%s" not found' % task_name) + + + def prepare_grid(self): + self.size = self.synth_args.size + + # Get resolution of training data + #aff = nib.load(os.path.join(self.modalities['Gen'], self.names[0])).affine + #self.res_training_data = np.sqrt(np.sum(abs(aff[:-1, :-1]), axis=0)) + + self.res_training_data = np.array([1.0, 1.0, 1.0]) + + xx, yy, zz = np.meshgrid(range(self.size[0]), range(self.size[1]), range(self.size[2]), sparse=False, indexing='ij') + self.xx = torch.tensor(xx, dtype=torch.float, device=self.device) + self.yy = torch.tensor(yy, dtype=torch.float, device=self.device) + self.zz = torch.tensor(zz, dtype=torch.float, device=self.device) + self.c = torch.tensor((np.array(self.size) - 1) / 2, dtype=torch.float, device=self.device) + self.xc = self.xx - self.c[0] + self.yc = self.yy - self.c[1] + self.zc = self.zz - self.c[2] + return + + def prepare_one_hot(self): + if self.synth_args.left_hemis_only: + n_labels = len(label_list_segmentation_brainseg_left) + label_list_segmentation = label_list_segmentation_brainseg_left + else: + # Matrix for one-hot 
encoding (includes a lookup-table) + n_labels = len(label_list_segmentation_brainseg_with_extracerebral) + label_list_segmentation = label_list_segmentation_brainseg_with_extracerebral + + self.lut = torch.zeros(10000, dtype=torch.long, device=self.device) + for l in range(n_labels): + self.lut[label_list_segmentation[l]] = l + self.onehotmatrix = torch.eye(n_labels, dtype=torch.float, device=self.device) + + # useless for left_hemis_only + nlat = int((n_labels - n_neutral_labels_brainseg_with_extracerebral) / 2.0) + self.vflip = np.concatenate([np.array(range(n_neutral_labels_brainseg_with_extracerebral)), + np.array(range(n_neutral_labels_brainseg_with_extracerebral + nlat, n_labels)), + np.array(range(n_neutral_labels_brainseg_with_extracerebral, n_neutral_labels_brainseg_with_extracerebral + nlat))]) + return + + + def random_affine_transform(self, shp): + rotations = (2 * self.synth_args.max_rotation * np.random.rand(3) - self.synth_args.max_rotation) / 180.0 * np.pi + shears = (2 * self.synth_args.max_shear * np.random.rand(3) - self.synth_args.max_shear) + scalings = 1 + (2 * self.synth_args.max_scaling * np.random.rand(3) - self.synth_args.max_scaling) + scaling_factor_distances = np.prod(scalings) ** .33333333333 + A = torch.tensor(make_affine_matrix(rotations, shears, scalings), dtype=torch.float, device=self.device) + + # sample center + if self.synth_args.random_shift: + max_shift = (torch.tensor(np.array(shp[0:3]) - self.size, dtype=torch.float, device=self.device)) / 2 + max_shift[max_shift < 0] = 0 + c2 = torch.tensor((np.array(shp[0:3]) - 1)/2, dtype=torch.float, device=self.device) + (2 * (max_shift * torch.rand(3, dtype=float, device=self.device)) - max_shift) + else: + c2 = torch.tensor((np.array(shp[0:3]) - 1)/2, dtype=torch.float, device=self.device) + return scaling_factor_distances, A, c2 + + def random_nonlinear_transform(self, photo_mode, spac): + nonlin_scale = self.synth_args.nonlin_scale_min + np.random.rand(1) * 
(self.synth_args.nonlin_scale_max - self.synth_args.nonlin_scale_min) + size_F_small = np.round(nonlin_scale * np.array(self.size)).astype(int).tolist() + if photo_mode: + size_F_small[1] = np.round(self.size[1]/spac).astype(int) + nonlin_std = self.synth_args.nonlin_std_max * np.random.rand() + Fsmall = nonlin_std * torch.randn([*size_F_small, 3], dtype=torch.float, device=self.device) + F = myzoom_torch(Fsmall, np.array(self.size) / size_F_small) + if photo_mode: + F[:, :, :, 1] = 0 + + if 'surface' in self.tasks: # TODO need to integrate the non-linear deformation fields for inverse + steplength = 1.0 / (2.0 ** self.synth_args.n_steps_svf_integration) + Fsvf = F * steplength + for _ in range(self.synth_args.n_steps_svf_integration): + Fsvf += fast_3D_interp_torch(Fsvf, self.xx + Fsvf[:, :, :, 0], self.yy + Fsvf[:, :, :, 1], self.zz + Fsvf[:, :, :, 2], 'linear') + Fsvf_neg = -F * steplength + for _ in range(self.synth_args.n_steps_svf_integration): + Fsvf_neg += fast_3D_interp_torch(Fsvf_neg, self.xx + Fsvf_neg[:, :, :, 0], self.yy + Fsvf_neg[:, :, :, 1], self.zz + Fsvf_neg[:, :, :, 2], 'linear') + F = Fsvf + Fneg = Fsvf_neg + else: + Fneg = None + return F, Fneg + + def generate_deformation(self, setups, shp): + + # generate affine deformation + scaling_factor_distances, A, c2 = self.random_affine_transform(shp) + + # generate nonlinear deformation + if self.synth_args.nonlinear_transform: + F, Fneg = self.random_nonlinear_transform(setups['photo_mode'], setups['spac']) + else: + F, Fneg = None, None + + # deform the image grid + xx2, yy2, zz2, x1, y1, z1, x2, y2, z2 = self.deform_grid(shp, A, c2, F) + + return {'scaling_factor_distances': scaling_factor_distances, + 'A': A, + 'c2': c2, + 'F': F, + 'Fneg': Fneg, + 'grid': [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2], + } + + + def get_left_hemis_mask(self, grid): + [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = grid + + if self.synth_args.left_hemis_only: + S, aff, res = read_image(self.modalities['segmentation']) # read 
seg map + S = torch.squeeze(torch.from_numpy(S.get_fdata()[x1:x2, y1:y2, z1:z2].astype(int))).to(self.device) + S = self.lut[S.int()] # mask out non-left labels + X, aff, res = read_image(self.modalities['registration'][0]) # read_mni_coord_X + X = torch.squeeze(torch.from_numpy(X.get_fdata()[x1:x2, y1:y2, z1:z2])).to(self.device) + self.hemis_mask = ((S > 0) & (X < 0)).int() + else: + self.hemis_mask = None + + def deform_grid(self, shp, A, c2, F): + if F is not None: + # deform the images (we do nonlinear "first" ie after so we can do heavy coronal deformations in photo mode) + xx1 = self.xc + F[:, :, :, 0] + yy1 = self.yc + F[:, :, :, 1] + zz1 = self.zc + F[:, :, :, 2] + else: + xx1 = self.xc + yy1 = self.yc + zz1 = self.zc + + xx2 = A[0, 0] * xx1 + A[0, 1] * yy1 + A[0, 2] * zz1 + c2[0] + yy2 = A[1, 0] * xx1 + A[1, 1] * yy1 + A[1, 2] * zz1 + c2[1] + zz2 = A[2, 0] * xx1 + A[2, 1] * yy1 + A[2, 2] * zz1 + c2[2] + xx2[xx2 < 0] = 0 + yy2[yy2 < 0] = 0 + zz2[zz2 < 0] = 0 + xx2[xx2 > (shp[0] - 1)] = shp[0] - 1 + yy2[yy2 > (shp[1] - 1)] = shp[1] - 1 + zz2[zz2 > (shp[2] - 1)] = shp[2] - 1 + + # Get the margins for reading images + x1 = torch.floor(torch.min(xx2)) + y1 = torch.floor(torch.min(yy2)) + z1 = torch.floor(torch.min(zz2)) + x2 = 1+torch.ceil(torch.max(xx2)) + y2 = 1 + torch.ceil(torch.max(yy2)) + z2 = 1 + torch.ceil(torch.max(zz2)) + xx2 -= x1 + yy2 -= y1 + zz2 -= z1 + + x1 = x1.cpu().numpy().astype(int) + y1 = y1.cpu().numpy().astype(int) + z1 = z1.cpu().numpy().astype(int) + x2 = x2.cpu().numpy().astype(int) + y2 = y2.cpu().numpy().astype(int) + z2 = z2.cpu().numpy().astype(int) + + return xx2, yy2, zz2, x1, y1, z1, x2, y2, z2 + + + def augment_sample(self, name, I_def, setups, deform_dict, res, target, pathol_direction = None, input_mode = 'synth'): + + sample = {} + [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid'] + + if not isinstance(I_def, torch.Tensor): + I_def = torch.squeeze(torch.tensor(I_def.get_fdata()[x1:x2, y1:y2, 
z1:z2].astype(float), dtype=torch.float, device=self.device)) + if self.hemis_mask is not None: + I_def[self.hemis_mask == 0] = 0 + # Deform grid + I_def = fast_3D_interp_torch(I_def, xx2, yy2, zz2, 'linear') + + if input_mode == 'CT': + I_def = torch.clamp(I_def, min = 0., max = 80.) + + if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0: + I_def = self.encode_pathology(I_def, target['pathology'], target['pathology_prob'], pathol_direction) + I_def[I_def < 0.] = 0. + else: + target['pathology'] = 0. + target['pathology_prob'] = 0. + + # Augment sample + aux_dict = {} + augmentation_steps = self.augmentation_steps['synth'] if input_mode == 'synth' else self.augmentation_steps['real'] + for func_name in augmentation_steps: + I_def, aux_dict = augmentation_funcs[func_name](I = I_def, aux_dict = aux_dict, cfg = self.gen_args.generator, + input_mode = input_mode, setups = setups, size = self.size, res = res, device = self.device) + + + # Back to original resolution + if self.synth_args.bspline_zooming: + I_def = interpol.resize(I_def, shape=self.size, anchor='edge', interpolation=3, bound='dct2', prefilter=True) + else: + I_def = myzoom_torch(I_def, 1 / aux_dict['factors']) + + maxi = torch.max(I_def) + I_final = I_def / maxi + + if 'super_resolution' in self.tasks: + SRresidual = aux_dict['high_res'] / maxi - I_final + sample.update({'high_res_residual': torch.flip(SRresidual, [0])[None] if setups['flip'] else SRresidual[None]}) + + + sample.update({'input': torch.flip(I_final, [0])[None] if setups['flip'] else I_final[None]}) + if 'bias_field' in self.tasks and input_mode != 'CT': + sample.update({'bias_field_log': torch.flip(aux_dict['BFlog'], [0])[None] if setups['flip'] else aux_dict['BFlog'][None]}) + + return sample + + + def generate_sample(self, name, G, setups, deform_dict, res, target): + + [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid'] + + # Generate contrasts + mus, sigmas = 
self.get_contrast(setups['photo_mode']) + + G = torch.squeeze(torch.tensor(G.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device)) + #G[G > 255] = 0 # kill extracerebral regions + G[G == 77] = 2 # merge WM lesion to white matter region + if self.hemis_mask is not None: + G[self.hemis_mask == 0] = 0 + Gr = torch.round(G).long() + + SYN = mus[Gr] + sigmas[Gr] * torch.randn(Gr.shape, dtype=torch.float, device=self.device) + SYN[SYN < 0] = 0 + #SYN /= mus[2] # normalize by WM + #SYN = gaussian_blur_3d(SYN, 0.5*np.ones(3), self.device) # cosmetic + + SYN = fast_3D_interp_torch(SYN, xx2, yy2, zz2) + + # Make random linear combinations + if np.random.rand() < self.gen_args.mix_synth_prob: + v = torch.rand(4) + v[2] = 0 if 'T2' not in self.modalities else v[2] + v[3] = 0 if 'FLAIR' not in self.modalities else v[3] + v /= torch.sum(v) + SYN = v[0] * SYN + v[1] * target['T1'][0] + if 'T2' in self.modalities: + SYN += v[2] * target['T2'][0] + if 'FLAIR' in self.modalities: + SYN += v[3] * target['FLAIR'][0] + + if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0: + SYN_cerebral = SYN.clone() + SYN_cerebral[Gr == 0] = 0 + SYN_cerebral = fast_3D_interp_torch(SYN_cerebral, xx2, yy2, zz2)[None] + + wm_mask = (Gr==2) | (Gr==41) + wm_mean = (SYN * wm_mask).sum() / wm_mask.sum() + gm_mask = (Gr!=0) & (Gr!=2) & (Gr!=41) + gm_mean = (SYN * gm_mask).sum() / gm_mask.sum() + + target['pathology'][SYN_cerebral == 0] = 0 + target['pathology_prob'][SYN_cerebral == 0] = 0 + # determine to be T1-resembled or T2-resembled + #if pathol_direction: lesion should be brigher than WM.mean() + # pathol_direction: +1: T2-like; -1: T1-like + pathol_direction = self.get_pathology_direction('synth', gm_mean > wm_mean) + else: + pathol_direction = None + target['pathology'] = 0. + target['pathology_prob'] = 0. + + SYN[SYN < 0.] = 0. 
+ return target['pathology'], target['pathology_prob'], self.augment_sample(name, SYN, setups, deform_dict, res, target, pathol_direction = pathol_direction) + + def get_pathology_direction(self, input_mode, pathol_direction = None): + #if np.random.rand() < 0.1: # in some (rare) cases, randomly pick the direction + # return random.choice([True, False]) + + if pathol_direction is not None: # for synth image + return pathol_direction + + if input_mode in ['T1', 'CT']: + return False + + if input_mode in ['T2', 'FLAIR']: + return True + + return random.choice([True, False]) + + + def get_contrast(self, photo_mode): + # Sample Gaussian image + mus = 25 + 200 * torch.rand(256, dtype=torch.float, device=self.device) + sigmas = 5 + 20 * torch.rand(256, dtype=torch.float, device=self.device) + + if np.random.rand() < self.synth_args.ct_prob: + darker = 25 + 10 * torch.rand(1, dtype=torch.float, device=self.device)[0] + for l in ct_brightness_group['darker']: + mus[l] = darker + dark = 90 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0] + for l in ct_brightness_group['dark']: + mus[l] = dark + bright = 110 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0] + for l in ct_brightness_group['bright']: + mus[l] = bright + brighter = 150 + 50 * torch.rand(1, dtype=torch.float, device=self.device)[0] + for l in ct_brightness_group['brighter']: + mus[l] = brighter + + if photo_mode or np.random.rand(1)<0.5: # set the background to zero every once in a while (or always in photo mode) + mus[0] = 0 + + # partial volume + # 1 = lesion, 2 = WM, 3 = GM, 4 = CSF + v = 0.02 * torch.arange(50).to(self.device) + mus[100:150] = mus[1] * (1 - v) + mus[2] * v + mus[150:200] = mus[2] * (1 - v) + mus[3] * v + mus[200:250] = mus[3] * (1 - v) + mus[4] * v + mus[250] = mus[4] + sigmas[100:150] = torch.sqrt(sigmas[1]**2 * (1 - v) + sigmas[2]**2 * v) + sigmas[150:200] = torch.sqrt(sigmas[2]**2 * (1 - v) + sigmas[3]**2 * v) + sigmas[200:250] = torch.sqrt(sigmas[3]**2 * (1 - 
v) + sigmas[4]**2 * v) + sigmas[250] = sigmas[4] + + return mus, sigmas + + def get_setup_params(self): + + if self.synth_args.left_hemis_only: + hemis = 'left' + else: + hemis = 'both' + + if self.synth_args.low_res_only: + photo_mode = False + elif self.synth_args.left_hemis_only: + photo_mode = True + else: + photo_mode = np.random.rand() < self.synth_args.photo_prob + + pathol_mode = np.random.rand() < self.synth_args.pathology_prob + pathol_random_shape = np.random.rand() < self.synth_args.random_shape_prob + spac = 2.5 + 10 * np.random.rand() if photo_mode else None + flip = np.random.randn() < self.synth_args.flip_prob if not self.synth_args.left_hemis_only else False + + if photo_mode: + resolution = np.array([self.res_training_data[0], spac, self.res_training_data[2]]) + thickness = np.array([self.res_training_data[0], 0.1, self.res_training_data[2]]) + else: + resolution, thickness = resolution_sampler(self.synth_args.low_res_only) + return {'resolution': resolution, 'thickness': thickness, + 'photo_mode': photo_mode, 'pathol_mode': pathol_mode, + 'pathol_random_shape': pathol_random_shape, + 'spac': spac, 'flip': flip, 'hemis': hemis} + + + def encode_pathology(self, I, P, Pprob, pathol_direction = None): + + + if pathol_direction is None: # True: T2/FLAIR-resembled, False: T1-resembled + pathol_direction = random.choice([True, False]) + + P, Pprob = torch.squeeze(P), torch.squeeze(Pprob) + I_mu = (I * P).sum() / P.sum() + + p_mask = torch.round(P).long() + #pth_mus = I_mu/4 + I_mu/2 * torch.rand(10000, dtype=torch.float, device=self.device) + pth_mus = 3*I_mu/4 + I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device) # enforce the pathology pattern harder! 
+ pth_mus = pth_mus if pathol_direction else -pth_mus + pth_sigmas = I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device) + I += Pprob * (pth_mus[p_mask] + pth_sigmas[p_mask] * torch.randn(p_mask.shape, dtype=torch.float, device=self.device)) + I[I < 0] = 0 + + #print('encode', P.shape, P.mean()) + #print('pre', I_mu) + #I_mu = (I * P).sum() / P.sum() + #print('post', I_mu) + + return I + + def get_info(self, t1): + + t1dm = t1[:-7] + 'T1w.defacingmask.nii' + t2 = t1[:-7] + 'T2w.nii' + t2dm = t1[:-7] + 'T2w.defacingmask.nii' + flair = t1[:-7] + 'FLAIR.nii' + flairdm = t1[:-7] + 'FLAIR.defacingmask.nii' + ct = t1[:-7] + 'CT.nii' + ctdm = t1[:-7] + 'CT.defacingmask.nii' + generation_labels = t1[:-7] + 'generation_labels.nii' + segmentation_labels = t1[:-7] + self.gen_args.segment_prefix + '.nii' + #brain_dist_map = t1[:-7] + 'brain_dist_map.nii' + lp_dist_map = t1[:-7] + 'lp_dist_map.nii' + rp_dist_map = t1[:-7] + 'rp_dist_map.nii' + lw_dist_map = t1[:-7] + 'lw_dist_map.nii' + rw_dist_map = t1[:-7] + 'rw_dist_map.nii' + mni_reg_x = t1[:-7] + 'mni_reg.x.nii' + mni_reg_y = t1[:-7] + 'mni_reg.y.nii' + mni_reg_z = t1[:-7] + 'mni_reg.z.nii' + + + self.modalities = {'T1': t1, 'Gen': generation_labels, 'segmentation': segmentation_labels, + 'distance': [lp_dist_map, lw_dist_map, rp_dist_map, rw_dist_map], + 'registration': [mni_reg_x, mni_reg_y, mni_reg_z]} + + if os.path.isfile(t1dm): + self.modalities.update({'T1_DM': t1dm}) + if os.path.isfile(t2): + self.modalities.update({'T2': t2}) + if os.path.isfile(t2dm): + self.modalities.update({'T2_DM': t2dm}) + if os.path.isfile(flair): + self.modalities.update({'FLAIR': flair}) + if os.path.isfile(flairdm): + self.modalities.update({'FLAIR_DM': flairdm}) + if os.path.isfile(ct): + self.modalities.update({'CT': ct}) + if os.path.isfile(ctdm): + self.modalities.update({'CT_DM': ctdm}) + + return self.modalities + + + def read_input(self, idx): + """ + determine input type according to prob (in 
generator/constants.py) + Logic: if np.random.rand() < real_image_prob and is real_image_exist --> input real images; otherwise, synthesize images. + """ + dataset_name, input_prob, t1_path, age = self.idx_to_path(idx) + case_name = os.path.basename(t1_path).split('.T1w.nii')[0] + self.modalities = self.get_info(t1_path) + + prob = np.random.rand() + if prob < input_prob['T1'] and 'T1' in self.modalities: + input_mode = 'T1' + img, aff, res = read_image(self.modalities['T1']) + elif prob < input_prob['T2'] and 'T2' in self.modalities: + input_mode = 'T2' + img, aff, res = read_image(self.modalities['T2']) + elif prob < input_prob['FLAIR'] and 'FLAIR' in self.modalities: + input_mode = 'FLAIR' + img, aff, res = read_image(self.modalities['FLAIR']) + elif prob < input_prob['CT'] and 'CT' in self.modalities: + input_mode = 'CT' + img, aff, res = read_image(self.modalities['CT']) + else: + input_mode = 'synth' + img, aff, res = read_image(self.modalities['Gen']) + + return dataset_name, case_name, input_mode, img, aff, res, age + + + def read_and_deform_target(self, idx, exist_keys, task_name, input_mode, setups, deform_dict, linear_weights = None): + current_target = {} + p_prob_path, augment, thres = None, False, 0.1 + + if task_name == 'pathology': + # NOTE: for now - encode pathology only for healthy cases + # TODO: what to do if the case has pathology itself? 
-- inconsistency between encoded pathol and the output + if self.pathology_type is None: # healthy + if setups['pathol_mode']: # and input_mode == 'synth': + if setups['pathol_random_shape']: + p_prob_path = 'random_shape' + augment, thres = False, self.shape_gen_args.pathol_thres + else: + p_prob_path = random.choice(pathology_prob_paths) + augment, thres = self.synth_args.augment_pathology, self.shape_gen_args.pathol_thres + else: + pass + #p_prob_path = self.modalities['pathology_prob'] + + current_target = processing_funcs[task_name](exist_keys, task_name, p_prob_path, setups, deform_dict, self.device, + mask = self.hemis_mask, + augment = augment, + pde_func = self.adv_pde, + t = self.t, + shape_gen_args = self.shape_gen_args, + thres = thres + ) + + else: + if task_name in self.modalities: + current_target = processing_funcs[task_name](exist_keys, task_name, self.modalities[task_name], + setups, deform_dict, self.device, + mask = self.hemis_mask, + cfg = self.gen_args, + onehotmatrix = self.onehotmatrix, + lut = self.lut, vflip = self.vflip + ) + else: + current_target = {task_name: 0.} + return current_target + + + def update_gen_args(self, new_args): + for key, value in vars(new_args).items(): + vars(self.gen_args.generator)[key] = value + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + # read input: real or synthesized image, according to customized prob + dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx) + + # generate random values + setups = self.get_setup_params() + + # sample random deformation + deform_dict = self.generate_deformation(setups, img.shape) + + # get left_hemis_mask if needed + self.get_left_hemis_mask(deform_dict['grid']) + + # read and deform target according to the assigned tasks + target = defaultdict(lambda: None) + target['name'] = case_name + target.update(self.read_and_deform_target(idx, target.keys(), 'T1', input_mode, setups, deform_dict)) + 
target.update(self.read_and_deform_target(idx, target.keys(), 'T2', input_mode, setups, deform_dict)) + target.update(self.read_and_deform_target(idx, target.keys(), 'FLAIR', input_mode, setups, deform_dict)) + for task_name in self.tasks: + if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']: + target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict)) + + + # process or generate input sample + if input_mode == 'synth': + self.update_gen_args(self.synth_image_args) # severe noise injection for real images + target['pathology'], target['pathology_prob'], sample = \ + self.generate_sample(case_name, img, setups, deform_dict, res, target) + else: + self.update_gen_args(self.real_image_args) # milder noise injection for real images + sample = self.augment_sample(case_name, img, setups, deform_dict, res, target, + pathol_direction = self.get_pathology_direction(input_mode),input_mode = input_mode) + + if setups['flip'] and isinstance(target['pathology'], torch.Tensor): # flipping should happen after P has been encoded + target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1]) + + if age is not None: + target['age'] = age + + return self.datasets_num, dataset_name, input_mode, target, sample + + + + +# An example of customized dataset from BaseSynth +class BrainIDGen(BaseGen): + """ + BrainIDGen dataset + BrainIDGen enables intra-subject augmentation, i.e., each subject will have multiple augmentations + """ + def __init__(self, gen_args, device='cpu'): + super(BrainIDGen, self).__init__(gen_args, device) + + self.all_samples = gen_args.generator.all_samples + self.mild_samples = gen_args.generator.mild_samples + self.mild_generator_args = gen_args.mild_generator + self.severe_generator_args = gen_args.severe_generator + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + # read input: real or 
synthesized image, according to customized prob + dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx) + + # generate random values + setups = self.get_setup_params() + + # sample random deformation + deform_dict = self.generate_deformation(setups, img.shape) + + # get left_hemis_mask if needed + self.get_left_hemis_mask(deform_dict['grid']) + + # read and deform target according to the assigned tasks + target = defaultdict(lambda: 1.) + target['name'] = case_name + target.update(self.read_and_deform_target(idx, target.keys(), 'T1', input_mode, setups, deform_dict)) + target.update(self.read_and_deform_target(idx, target.keys(), 'T2', input_mode, setups, deform_dict)) + target.update(self.read_and_deform_target(idx, target.keys(), 'FLAIR', input_mode, setups, deform_dict)) + for task_name in self.tasks: + if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']: + target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict)) + + # process or generate intra-subject input samples + samples = [] + for i_sample in range(self.all_samples): + if i_sample < self.mild_samples: + self.update_gen_args(self.mild_generator_args) + if input_mode == 'synth': + self.update_gen_args(self.synth_image_args) + target['pathology'], target['pathology_prob'], sample = \ + self.generate_sample(case_name, img, setups, deform_dict, res, target) + else: + self.update_gen_args(self.real_image_args) + sample = self.augment_sample(case_name, img, setups, deform_dict, res, target, + pathol_direction = self.get_pathology_direction(input_mode),input_mode = input_mode) + else: + self.update_gen_args(self.severe_generator_args) + if input_mode == 'synth': + self.update_gen_args(self.synth_image_args) + target['pathology'], target['pathology_prob'], sample = \ + self.generate_sample(case_name, img, setups, deform_dict, res, target) + else: + self.update_gen_args(self.real_image_args) + sample = 
self.augment_sample(case_name, img, setups, deform_dict, res, target, + pathol_direction = self.get_pathology_direction(input_mode),input_mode = input_mode) + + samples.append(sample) + + if setups['flip'] and isinstance(target['pathology'], torch.Tensor): # flipping should happen after P has been encoded + target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1]) + + if age is not None: + target['age'] = age + return self.datasets_num, dataset_name, input_mode, target, samples \ No newline at end of file diff --git a/Generator/interpol/__init__.py b/Generator/interpol/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb4adda01be9ca4a18facfbc2c09a9c9a0d1b1c --- /dev/null +++ b/Generator/interpol/__init__.py @@ -0,0 +1,7 @@ +from .api import * +from .resize import * +from .restrict import * +from . import backend + +from . import _version +__version__ = _version.get_versions()['version'] diff --git a/Generator/interpol/_version.py b/Generator/interpol/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..bf96c7aa9aba4320b760ff7443fa2f4199e98abe --- /dev/null +++ b/Generator/interpol/_version.py @@ -0,0 +1,623 @@ + +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. 
Generated by +# versioneer-0.20 (https://github.com/python-versioneer/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). + git_refnames = " (HEAD -> main, tag: 0.2.3)" + git_full = "414ed52c973b9d32e3e6a5a75c91cd5aab064f23" + git_date = "2023-04-17 20:36:50 -0400" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: # pylint: disable=too-few-public-methods + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "pep440" + cfg.tag_prefix = "" + cfg.parentdir_prefix = "" + cfg.versionfile_source = "interpol/_version.py" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +# pylint:disable=too-many-arguments,consider-using-with # noqa +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + for command in commands: + try: + dispcmd = 
str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, process.returncode + return stdout, process.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. 
This function is not used from + # _version.py. + keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. 
We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r'\d', r)} + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
+ date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post0.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post0.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post0.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. 
+ + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for _ in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} diff --git a/Generator/interpol/api.py b/Generator/interpol/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c0066ea08c74ae8ff4beaf8b694051a0dded78 --- /dev/null +++ b/Generator/interpol/api.py @@ -0,0 +1,560 @@ +"""High level interpolation API""" + +__all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad', + 'spline_coeff', 'spline_coeff_nd', + 'identity_grid', 'add_identity_grid', 'add_identity_grid_'] + +import torch +from .utils import expanded_shape, matvec +from .jit_utils import movedim1, meshgrid +from .autograd import (GridPull, GridPush, GridCount, GridGrad, + SplineCoeff, SplineCoeffND) +from . import backend, jitfields + +_doc_interpolation = \ +"""`interpolation` can be an int, a string or an InterpolationType. + Possible values are: + - 0 or 'nearest' + - 1 or 'linear' + - 2 or 'quadratic' + - 3 or 'cubic' + - 4 or 'fourth' + - 5 or 'fifth' + - etc. 
+ A list of values can be provided, in the order [W, H, D], + to specify dimension-specific interpolation orders.""" + +_doc_bound = \ +"""`bound` can be an int, a string or a BoundType. + Possible values are: + - 'replicate' or 'nearest' : a a a | a b c d | d d d + - 'dct1' or 'mirror' : d c b | a b c d | c b a + - 'dct2' or 'reflect' : c b a | a b c d | d c b + - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c + - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b + - 'dft' or 'wrap' : b c d | a b c d | a b c + - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0 + A list of values can be provided, in the order [W, H, D], + to specify dimension-specific boundary conditions. + Note that + - `dft` corresponds to circular padding + - `dct2` corresponds to Neumann boundary conditions (symmetric) + - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric) + See https://en.wikipedia.org/wiki/Discrete_cosine_transform + https://en.wikipedia.org/wiki/Discrete_sine_transform""" + +_doc_bound_coeff = \ +"""`bound` can be an int, a string or a BoundType. + Possible values are: + - 'replicate' or 'nearest' : a a a | a b c d | d d d + - 'dct1' or 'mirror' : d c b | a b c d | c b a + - 'dct2' or 'reflect' : c b a | a b c d | d c b + - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c + - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b + - 'dft' or 'wrap' : b c d | a b c d | a b c + - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0 + A list of values can be provided, in the order [W, H, D], + to specify dimension-specific boundary conditions. 
def _preproc(grid, input=None, mode=None):
    """Preprocess tensors for pull/push/count/grad

    Low level bindings expect inputs of shape
    [batch, channel, *spatial] and [batch, *spatial, dim], whereas
    the high level python API accepts inputs of shape
    [..., [channel], *spatial] and [..., *spatial, dim].

    This function broadcasts and reshapes the input tensors accordingly.

    /!\\ This *can* trigger large allocations /!\\
    """
    # The number of spatial dimensions is encoded in the grid's last axis.
    dim = grid.shape[-1]
    if input is None:
        # grid-only variant (used by grid_count): flatten all leading
        # batch dimensions into a single one.
        spatial = grid.shape[-dim-1:-1]
        batch = grid.shape[:-dim-1]
        grid = grid.reshape([-1, *spatial, dim])
        info = dict(batch=batch, channel=[1] if batch else [], dim=dim)
        return grid, info

    grid_spatial = grid.shape[-dim-1:-1]
    grid_batch = grid.shape[:-dim-1]
    input_spatial = input.shape[-dim:]
    # channel == 0 encodes "the input has no explicit channel axis"
    channel = 0 if input.dim() == dim else input.shape[-dim-1]
    input_batch = input.shape[:-dim-1]

    if mode == 'push':
        # for splatting, the grid and the input live on the same lattice,
        # so their spatial shapes are broadcast against each other
        grid_spatial = input_spatial = expanded_shape(grid_spatial, input_spatial)

    # broadcast and reshape
    batch = expanded_shape(grid_batch, input_batch)
    grid = grid.expand([*batch, *grid_spatial, dim])
    grid = grid.reshape([-1, *grid_spatial, dim])
    input = input.expand([*batch, channel or 1, *input_spatial])
    input = input.reshape([-1, channel or 1, *input_spatial])

    # `info` lets _postproc restore the caller's batch/channel layout
    out_channel = [channel] if channel else ([1] if batch else [])
    info = dict(batch=batch, channel=out_channel, dim=dim)
    return grid, input, info


def _postproc(out, shape_info, mode):
    """Postprocess tensors for pull/push/count/grad

    Undoes the batch flattening performed by `_preproc`, restoring the
    batch/channel dimensions recorded in `shape_info`.
    """
    dim = shape_info['dim']
    if mode != 'grad':
        spatial = out.shape[-dim:]
        feat = []
    else:
        # 'grad' outputs carry an extra trailing gradient axis of size dim
        spatial = out.shape[-dim-1:-1]
        feat = [out.shape[-1]]
    batch = shape_info['batch']
    channel = shape_info['channel']

    out = out.reshape([*batch, *channel, *spatial, *feat])
    return out
def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample an image with respect to a deformation field.

    Notes
    -----
    {interpolation}

    {bound}

    If the input dtype is not a floating point type, the input image is
    assumed to contain labels. Then, unique labels are extracted
    and resampled individually, making them soft labels. Finally,
    the label map is reconstructed from the individual soft labels by
    assigning the label with maximum soft value.

    Parameters
    ----------
    input : (..., [channel], *inshape) tensor
        Input image.
    grid : (..., *outshape, dim) tensor
        Transformation field.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType or sequence[BoundType], default='zero'
        Boundary conditions.
    extrapolate : bool or int, default=False
        Extrapolate out-of-bound data.
    prefilter : bool, default=False
        Apply spline pre-filter (= interpolates the input)

    Returns
    -------
    output : (..., [channel], *outshape) tensor
        Deformed image.

    """
    if backend.jitfields and jitfields.available:
        return jitfields.grid_pull(input, grid, interpolation, bound,
                                   extrapolate, prefilter)

    grid, input, shape_info = _preproc(grid, input)
    batch, channel = input.shape[:2]
    dim = grid.shape[-1]

    if not input.dtype.is_floating_point:
        # label map -> resample each label's soft (one-hot) mask
        # individually, then keep the label with the largest soft value
        out = input.new_zeros([batch, channel, *grid.shape[1:-1]])
        pmax = grid.new_zeros([batch, channel, *grid.shape[1:-1]])
        for label in input.unique():
            soft = (input == label).to(grid.dtype)
            if prefilter:
                # BUG FIX: the prefiltered coefficients were previously
                # assigned to `input`, which clobbered the label map
                # mid-loop and discarded the filter for the pulled mask.
                soft = spline_coeff_nd(soft, interpolation=interpolation,
                                       bound=bound, dim=dim, inplace=True)
            soft = GridPull.apply(soft, grid, interpolation, bound, extrapolate)
            out[soft > pmax] = label
            pmax = torch.max(pmax, soft)
    else:
        if prefilter:
            input = spline_coeff_nd(input, interpolation=interpolation,
                                    bound=bound, dim=dim)
        out = GridPull.apply(input, grid, interpolation, bound, extrapolate)

    return _postproc(out, shape_info, mode='pull')
def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Splat an image with respect to a deformation field (pull adjoint).

    Notes
    -----
    {interpolation}

    {bound}

    Parameters
    ----------
    input : (..., [channel], *inshape) tensor
        Input image.
    grid : (..., *inshape, dim) tensor
        Transformation field.
    shape : sequence[int], default=inshape
        Output shape
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType, or sequence[BoundType], default='zero'
        Boundary conditions.
    extrapolate : bool or int, default=False
        Extrapolate out-of-bound data.
    prefilter : bool, default=False
        Apply spline pre-filter.

    Returns
    -------
    output : (..., [channel], *shape) tensor
        Splatted image.

    """
    if backend.jitfields and jitfields.available:
        return jitfields.grid_push(input, grid, shape, interpolation, bound,
                                   extrapolate, prefilter)

    grid, input, shape_info = _preproc(grid, input, mode='push')
    dim = grid.shape[-1]

    if shape is None:
        # default output lattice: the (broadcast) input spatial shape
        shape = tuple(input.shape[2:])

    out = GridPush.apply(input, grid, shape, interpolation, bound, extrapolate)
    if prefilter:
        # push is the adjoint of pull, so the spline filter is applied to
        # the *output* (adjoint of prefiltering the pulled input)
        out = spline_coeff_nd(out, interpolation=interpolation, bound=bound,
                              dim=dim, inplace=True)
    return _postproc(out, shape_info, mode='push')
+ + """ + if backend.jitfields and jitfields.available: + return jitfields.grid_count(grid, shape, interpolation, bound, extrapolate) + + grid, shape_info = _preproc(grid) + out = GridCount.apply(grid, shape, interpolation, bound, extrapolate) + return _postproc(out, shape_info, mode='count') + + +def grid_grad(input, grid, interpolation='linear', bound='zero', + extrapolate=False, prefilter=False): + """Sample spatial gradients of an image with respect to a deformation field. + + Notes + ----- + {interpolation} + + {bound} + + Parameters + ---------- + input : (..., [channel], *inshape) tensor + Input image. + grid : (..., *inshape, dim) tensor + Transformation field. + shape : sequence[int], default=inshape + Output shape + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType, or sequence[BoundType], default='zero' + Boundary conditions. + extrapolate : bool or int, default=True + Extrapolate out-of-bound data. + prefilter : bool, default=False + Apply spline pre-filter (= interpolates the input) + + Returns + ------- + output : (..., [channel], *shape, dim) tensor + Sampled gradients. + + """ + if backend.jitfields and jitfields.available: + return jitfields.grid_grad(input, grid, interpolation, bound, + extrapolate, prefilter) + + grid, input, shape_info = _preproc(grid, input) + dim = grid.shape[-1] + if prefilter: + input = spline_coeff_nd(input, interpolation, bound, dim) + out = GridGrad.apply(input, grid, interpolation, bound, extrapolate) + return _postproc(out, shape_info, mode='grad') + + +def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1, + inplace=False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary conditions, along a single dimension. + + Notes + ----- + {interpolation} + + {bound} + + References + ---------- + {ref} + + + Parameters + ---------- + input : tensor + Input image. 
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
                 inplace=False):
    """Compute the interpolating spline coefficients, for a given spline order
    and boundary conditions, along a single dimension.

    Notes
    -----
    {interpolation}

    {bound}

    References
    ----------
    {ref}


    Parameters
    ----------
    input : tensor
        Input image.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType or sequence[BoundType], default='dct2'
        Boundary conditions.
    dim : int, default=-1
        Dimension along which to process
    inplace : bool, default=False
        Process the volume in place.

    Returns
    -------
    output : tensor
        Coefficient image.

    """
    # Doc fix: the previous docstring documented the `bound` default as
    # 'dct1', but the signature default is 'dct2'.
    # This implementation is based on the file bsplines.c in SPM12, written
    # by John Ashburner, which is itself based on the file coeff.c,
    # written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation
    # . DCT1 boundary conditions were derived by Thevenaz and Unser.
    # . DFT boundary conditions were derived by John Ashburner.
    # SPM12 is released under the GNU-GPL v2 license.
    # Philippe Thevenaz's code does not have an explicit license as far
    # as we know.
    if backend.jitfields and jitfields.available:
        return jitfields.spline_coeff(input, interpolation, bound,
                                      dim, inplace)

    # NOTE: the autograd wrapper takes (bound, interpolation), i.e. swapped
    # with respect to this function's keyword order.
    out = SplineCoeff.apply(input, bound, interpolation, dim, inplace)
    return out
+ + """ + # This implementation is based on the file bsplines.c in SPM12, written + # by John Ashburner, which is itself based on the file coeff.c, + # written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation + # . DCT1 boundary conditions were derived by Thevenaz and Unser. + # . DFT boundary conditions were derived by John Ashburner. + # SPM12 is released under the GNU-GPL v2 license. + # Philippe Thevenaz's code does not have an explicit license as far + # as we know. + if backend.jitfields and jitfields.available: + return jitfields.spline_coeff_nd(input, interpolation, bound, + dim, inplace) + + out = SplineCoeffND.apply(input, bound, interpolation, dim, inplace) + return out + + +grid_pull.__doc__ = grid_pull.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_push.__doc__ = grid_push.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_count.__doc__ = grid_count.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_grad.__doc__ = grid_grad.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +spline_coeff.__doc__ = spline_coeff.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff) +spline_coeff_nd.__doc__ = spline_coeff_nd.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff) + +# aliases +pull = grid_pull +push = grid_push +count = grid_count + + +def identity_grid(shape, dtype=None, device=None): + """Returns an identity deformation field. + + Parameters + ---------- + shape : (dim,) sequence of int + Spatial dimension of the field. + dtype : torch.dtype, default=`get_default_dtype()` + Data type. + device torch.device, optional + Device. 
def identity_grid(shape, dtype=None, device=None):
    """Returns an identity deformation field.

    Parameters
    ----------
    shape : (dim,) sequence of int
        Spatial dimension of the field.
    dtype : torch.dtype, default=`get_default_dtype()`
        Data type.
    device : torch.device, optional
        Device.

    Returns
    -------
    grid : (*shape, dim) tensor
        Transformation field

    """
    # One 1D coordinate vector per spatial dimension, stacked into a
    # (*shape, dim) grid of voxel coordinates.
    mesh1d = [torch.arange(float(s), dtype=dtype, device=device)
              for s in shape]
    grid = torch.stack(meshgrid(mesh1d), dim=-1)
    return grid


@torch.jit.script
def add_identity_grid_(disp):
    """Adds the identity grid to a displacement field, inplace.

    Parameters
    ----------
    disp : (..., *spatial, dim) tensor
        Displacement field

    Returns
    -------
    grid : (..., *spatial, dim) tensor
        Transformation field

    """
    dim = disp.shape[-1]
    spatial = disp.shape[-dim-1:-1]
    mesh1d = [torch.arange(s, dtype=disp.dtype, device=disp.device)
              for s in spatial]
    grid = meshgrid(mesh1d)
    # Move the coordinate axis to the front so each component can be
    # updated in place against its 1D mesh (broadcast along the rest).
    disp = movedim1(disp, -1, 0)
    for i, grid1 in enumerate(grid):
        disp[i].add_(grid1)
    disp = movedim1(disp, 0, -1)
    return disp


@torch.jit.script
def add_identity_grid(disp):
    """Adds the identity grid to a displacement field.

    Parameters
    ----------
    disp : (..., *spatial, dim) tensor
        Displacement field

    Returns
    -------
    grid : (..., *spatial, dim) tensor
        Transformation field

    """
    # Out-of-place variant: clone, then reuse the in-place implementation.
    return add_identity_grid_(disp.clone())
def affine_grid(mat, shape):
    """Create a dense transformation grid from an affine matrix.

    Parameters
    ----------
    mat : (..., D[+1], D+1) tensor
        Affine matrix (or matrices).
    shape : (D,) sequence[int]
        Shape of the grid, with length D.

    Returns
    -------
    grid : (..., *shape, D) tensor
        Dense transformation grid

    Raises
    ------
    ValueError
        If the matrix dimension and `shape` length disagree, or the matrix
        has an invalid number of rows.

    """
    mat = torch.as_tensor(mat)
    shape = list(shape)
    nb_dim = mat.shape[-1] - 1
    if nb_dim != len(shape):
        raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
                         'are not the same.'.format(nb_dim, len(shape)))
    if mat.shape[-2] not in (nb_dim, nb_dim+1):
        # BUG FIX: the original format string contained '{1]' (mismatched
        # bracket), so str.format itself raised instead of producing this
        # message. Also fixed the 'matrces' typo while repairing the string.
        raise ValueError('First argument should be matrices of shape '
                         '(..., {0}, {1}) or (..., {1}, {1}) but got {2}.'
                         .format(nb_dim, nb_dim+1, mat.shape))
    batch_shape = mat.shape[:-2]
    grid = identity_grid(shape, mat.dtype, mat.device)
    if batch_shape:
        # Align batch axes: the grid gains leading singleton batch axes,
        # the matrix gains singleton *spatial* axes before its two matrix
        # axes so that the slicing below still addresses rows/columns.
        for _ in range(len(batch_shape)):
            grid = grid.unsqueeze(0)
        # BUG FIX (review): was `mat.unsqueeze(-1)`, which inserted the
        # singletons *after* the matrix axes and broke the
        # [..., :nb_dim, :nb_dim] slicing for batched matrices.
        for _ in range(nb_dim):
            mat = mat.unsqueeze(-3)
    lin = mat[..., :nb_dim, :nb_dim]
    off = mat[..., :nb_dim, -1]
    # assumes matvec broadcasts (..., D, D) against (..., D) -- TODO confirm
    grid = matvec(lin, grid) + off
    return grid
def bound_to_nitorch(bound, as_type='str'):
    """Convert boundary type to niTorch's convention.

    Parameters
    ----------
    bound : [list of] str or bound_like
        Boundary condition in any convention
    as_type : {'str', 'enum', 'int'}, default='str'
        Return BoundType or int rather than str

    Returns
    -------
    bound : [list of] str or BoundType
        Boundary condition in NITorch's convention

    """
    intype = type(bound)
    items = bound if isinstance(bound, (list, tuple)) else [bound]

    # Canonical name -> accepted aliases (strings are matched lowercase).
    alias_table = (
        ('replicate', ('replicate', 'repeat', 'border', 'nearest',
                       BoundType.replicate)),
        ('zero', ('zero', 'zeros', 'constant', BoundType.zero)),
        ('dct2', ('dct2', 'reflect', 'reflection', 'neumann', BoundType.dct2)),
        ('dct1', ('dct1', 'mirror', BoundType.dct1)),
        ('dft', ('dft', 'wrap', 'circular', BoundType.dft)),
        ('dst2', ('dst2', 'antireflect', 'dirichlet', BoundType.dst2)),
        ('dst1', ('dst1', 'antimirror', BoundType.dst1)),
    )

    obound = []
    for b in items:
        b = b.lower() if isinstance(b, str) else b
        for canonical, aliases in alias_table:
            if b in aliases:
                obound.append(canonical)
                break
        else:
            # Raw integers (never equal to a plain-Enum member) fall
            # through to here and are kept as-is for the BoundType() call.
            if isinstance(b, int):
                obound.append(b)
            else:
                raise ValueError(f'Unknown boundary condition {b}')

    obound = [getattr(BoundType, b) if isinstance(b, str) else BoundType(b)
              for b in obound]
    if as_type in ('int', int):
        obound = [b.value for b in obound]
    if as_type in ('str', str):
        obound = [b.name for b in obound]

    if issubclass(intype, (list, tuple)):
        return intype(obound)
    return obound[0]
def inter_to_nitorch(inter, as_type='str'):
    """Convert interpolation order to NITorch's convention.

    Parameters
    ----------
    inter : [sequence of] int or str or InterpolationType
        Interpolation order(s) in any convention.
    as_type : {'str', 'enum', 'int'}, default='str'
        Output representation.

    Returns
    -------
    inter : [sequence of] int or str or InterpolationType
        Interpolation order(s) in NITorch's convention; a list/tuple input
        is returned as the same container type.

    Raises
    ------
    ValueError
        If an order is not recognized.

    """
    # Doc fix: the docstring previously advertised default='int', but the
    # signature default is (and remains) 'str'.
    intype = type(inter)
    if not isinstance(inter, (list, tuple)):
        inter = [inter]
    ointer = []
    for o in inter:
        o = o.lower() if isinstance(o, str) else o
        if o in (0, 'nearest', InterpolationType.nearest):
            ointer.append(0)
        elif o in (1, 'linear', InterpolationType.linear):
            ointer.append(1)
        elif o in (2, 'quadratic', InterpolationType.quadratic):
            ointer.append(2)
        elif o in (3, 'cubic', InterpolationType.cubic):
            ointer.append(3)
        elif o in (4, 'fourth', InterpolationType.fourth):
            ointer.append(4)
        elif o in (5, 'fifth', InterpolationType.fifth):
            ointer.append(5)
        elif o in (6, 'sixth', InterpolationType.sixth):
            ointer.append(6)
        elif o in (7, 'seventh', InterpolationType.seventh):
            ointer.append(7)
        else:
            raise ValueError(f'Unknown interpolation order {o}')
    if as_type in ('enum', 'str', str):
        ointer = list(map(InterpolationType, ointer))
        if as_type in ('str', str):
            ointer = [o.name for o in ointer]
    if issubclass(intype, (list, tuple)):
        ointer = intype(ointer)
    else:
        ointer = ointer[0]
    return ointer
class GridPull(torch.autograd.Function):
    """Autograd wrapper around `grid_pull` / `grid_pull_backward`."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, grid, interpolation, bound, extrapolate):
        # Normalize options to the integer convention used by the backend.
        opt = (bound_to_nitorch(make_list(bound), as_type='int'),
               inter_to_nitorch(make_list(interpolation), as_type='int'),
               int(extrapolate))

        output = grid_pull(input, grid, *opt)

        ctx.opt = opt
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        saved = ctx.saved_tensors
        grad_input, grad_grid = grid_pull_backward(grad, *saved, *ctx.opt)
        return grad_input, grad_grid, None, None, None


class GridPush(torch.autograd.Function):
    """Autograd wrapper around `grid_push` / `grid_push_backward`."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
        opt = (bound_to_nitorch(make_list(bound), as_type='int'),
               inter_to_nitorch(make_list(interpolation), as_type='int'),
               int(extrapolate))

        output = grid_push(input, grid, shape, *opt)

        ctx.opt = opt
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        saved = ctx.saved_tensors
        grad_input, grad_grid = grid_push_backward(grad, *saved, *ctx.opt)
        # `shape` (third forward input) gets no gradient.
        return grad_input, grad_grid, None, None, None, None


class GridCount(torch.autograd.Function):
    """Autograd wrapper around `grid_count` / `grid_count_backward`."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, grid, shape, interpolation, bound, extrapolate):
        opt = (bound_to_nitorch(make_list(bound), as_type='int'),
               inter_to_nitorch(make_list(interpolation), as_type='int'),
               int(extrapolate))

        output = grid_count(grid, shape, *opt)

        ctx.opt = opt
        ctx.save_for_backward(grid)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        if not ctx.needs_input_grad[0]:
            return None, None, None, None, None
        saved = ctx.saved_tensors
        grad_grid = grid_count_backward(grad, *saved, *ctx.opt)
        return grad_grid, None, None, None, None
class GridGrad(torch.autograd.Function):
    """Autograd wrapper around `grid_grad` / `grid_grad_backward`."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, grid, interpolation, bound, extrapolate):

        bound = bound_to_nitorch(make_list(bound), as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
        extrapolate = int(extrapolate)
        opt = (bound, interpolation, extrapolate)

        # Sample spatial gradients
        output = grid_grad(input, grid, *opt)

        # Context
        ctx.opt = opt
        ctx.save_for_backward(input, grid)

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        var = ctx.saved_tensors
        opt = ctx.opt
        grad_input = grad_grid = None
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            grads = grid_grad_backward(grad, *var, *opt)
            grad_input, grad_grid = grads
        return grad_input, grad_grid, None, None, None


class SplineCoeff(torch.autograd.Function):
    """Autograd wrapper around the 1D spline-coefficient filter."""

    @staticmethod
    @custom_fwd
    def forward(ctx, input, bound, interpolation, dim, inplace):

        bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation)[0], as_type='int')
        opt = (bound, interpolation, dim, inplace)

        output = spline_coeff(input, *opt)

        # `ctx.opt` is only needed (and only set) when a backward pass can
        # actually be requested.
        if input.requires_grad:
            ctx.opt = opt

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        # symmetric filter -> backward == forward
        # (I don't know if I can write into grad, so inplace=False to be safe)
        grad = spline_coeff(grad, *ctx.opt[:-1], inplace=False)
        # BUG FIX: previously returned a *list* ([grad] + [None] * 4);
        # torch.autograd expects backward to return one gradient per forward
        # input as a tuple (a list is wrapped as a single output and trips
        # the gradient-count check). Now consistent with SplineCoeffND.
        return grad, None, None, None, None


class SplineCoeffND(torch.autograd.Function):
    """Autograd wrapper around the N-D spline-coefficient filter."""

    @staticmethod
    @custom_fwd
    def forward(ctx, input, bound, interpolation, dim, inplace):

        bound = bound_to_nitorch(make_list(bound), as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
        opt = (bound, interpolation, dim, inplace)

        output = spline_coeff_nd(input, *opt)

        if input.requires_grad:
            ctx.opt = opt

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        # symmetric filter -> backward == forward
        # (I don't know if I can write into grad, so inplace=False to be safe)
        grad = spline_coeff_nd(grad, *ctx.opt[:-1], inplace=False)
        return grad, None, None, None, None
import torch
from enum import Enum
from typing import Optional
from .jit_utils import floor_div
Tensor = torch.Tensor


class BoundType(Enum):
    # Each condition carries two aliases: the DCT/DST naming and the
    # common (padding-style) name.
    zero = zeros = 0
    replicate = nearest = 1
    dct1 = mirror = 2
    dct2 = reflect = 3
    dst1 = antimirror = 4
    dst2 = antireflect = 5
    dft = wrap = 6


class ExtrapolateType(Enum):
    no = 0    # threshold: (0, n-1)
    yes = 1
    hist = 2  # threshold: (-0.5, n-0.5)


@torch.jit.script
class Bound:
    """Fold out-of-bounds integer indices back into [0, n-1].

    `index` returns the folded index; `transform` returns an optional
    int8 multiplier (+1 / -1 / 0) to apply to the sampled value, or None
    when no sign change / zeroing is needed for this boundary type.
    """

    def __init__(self, bound_type: int = 3):
        # Integer codes follow BoundType values; default 3 == dct2.
        self.type = bound_type

    def index(self, i, n: int):
        # Map an arbitrary integer index `i` into the valid range [0, n-1]
        # according to the boundary's folding rule.
        if self.type in (0, 1):  # zero / replicate
            return i.clamp(min=0, max=n-1)
        elif self.type in (3, 5):  # dct2 / dst2
            # period 2*n, mirrored about the edges of the first/last voxel
            n2 = n * 2
            i = torch.where(i < 0, (-i-1).remainder(n2).neg().add(n2 - 1),
                            i.remainder(n2))
            i = torch.where(i >= n, -i + (n2 - 1), i)
            return i
        elif self.type == 2:  # dct1
            if n == 1:
                return torch.zeros(i.shape, dtype=i.dtype, device=i.device)
            else:
                # period 2*(n-1), mirrored about the voxel centers
                n2 = (n - 1) * 2
                i = i.abs().remainder(n2)
                i = torch.where(i >= n, -i + n2, i)
                return i
        elif self.type == 4:  # dst1
            # period 2*(n+1), with implicit zeros at -1 and n
            n2 = 2 * (n + 1)
            first = torch.zeros([1], dtype=i.dtype, device=i.device)
            last = torch.full([1], n - 1, dtype=i.dtype, device=i.device)
            i = torch.where(i < 0, -i - 2, i)
            i = i.remainder(n2)
            i = torch.where(i > n, -i + (n2 - 2), i)
            i = torch.where(i == -1, first, i)
            i = torch.where(i == n, last, i)
            return i
        elif self.type == 6:  # dft
            # circular wrap
            return i.remainder(n)
        else:
            return i

    def transform(self, i, n: int) -> Optional[Tensor]:
        # Sign/zero multiplier matching `index`'s folding; None means the
        # sampled value is used unchanged.
        if self.type == 4:  # dst1
            if n == 1:
                return None
            one = torch.ones([1], dtype=torch.int8, device=i.device)
            zero = torch.zeros([1], dtype=torch.int8, device=i.device)
            n2 = 2 * (n + 1)
            i = torch.where(i < 0, -i + (n-1), i)
            i = i.remainder(n2)
            x = torch.where(i == 0, zero, one)
            x = torch.where(i.remainder(n + 1) == n, zero, x)
            i = floor_div(i, n+1)
            # every other period flips the sign (antisymmetric extension)
            x = torch.where(torch.remainder(i, 2) > 0, -x, x)
            return x
        elif self.type == 5:  # dst2
            i = torch.where(i < 0, n - 1 - i, i)
            x = torch.ones([1], dtype=torch.int8, device=i.device)
            i = floor_div(i, n)
            x = torch.where(torch.remainder(i, 2) > 0, -x, x)
            return x
        elif self.type == 0:  # zero
            # out-of-bounds samples contribute nothing
            one = torch.ones([1], dtype=torch.int8, device=i.device)
            zero = torch.zeros([1], dtype=torch.int8, device=i.device)
            outbounds = ((i < 0) | (i >= n))
            x = torch.where(outbounds, zero, one)
            return x
        else:
            return None
+ "B-Spline Signal Processing: Part II-Efficient Design and Applications," + IEEE Transactions on Signal Processing 41(2):834-848 (1993). +..[3] M. Unser. + "Splines: A Perfect Fit for Signal and Image Processing," + IEEE Signal Processing Magazine 16(6):22-38 (1999). +""" +import torch +import math +from typing import List, Optional +from .jit_utils import movedim1 +from .pushpull import pad_list_int + + +@torch.jit.script +def get_poles(order: int) -> List[float]: + empty: List[float] = [] + if order in (0, 1): + return empty + if order == 2: + return [math.sqrt(8.) - 3.] + if order == 3: + return [math.sqrt(3.) - 2.] + if order == 4: + return [math.sqrt(664. - math.sqrt(438976.)) + math.sqrt(304.) - 19., + math.sqrt(664. + math.sqrt(438976.)) - math.sqrt(304.) - 19.] + if order == 5: + return [math.sqrt(67.5 - math.sqrt(4436.25)) + math.sqrt(26.25) - 6.5, + math.sqrt(67.5 + math.sqrt(4436.25)) - math.sqrt(26.25) - 6.5] + if order == 6: + return [-0.488294589303044755130118038883789062112279161239377608394, + -0.081679271076237512597937765737059080653379610398148178525368, + -0.00141415180832581775108724397655859252786416905534669851652709] + if order == 7: + return [-0.5352804307964381655424037816816460718339231523426924148812, + -0.122554615192326690515272264359357343605486549427295558490763, + -0.0091486948096082769285930216516478534156925639545994482648003] + raise NotImplementedError + + +@torch.jit.script +def get_gain(poles: List[float]) -> float: + lam: float = 1. + for pole in poles: + lam *= (1. - pole) * (1. 
@torch.jit.script
def get_gain(poles: List[float]) -> float:
    """Return the overall gain of the filter cascade: prod (1-z)(1-1/z)."""
    lam: float = 1.
    for pole in poles:
        lam *= (1. - pole) * (1. - 1./pole)
    return lam


@torch.jit.script
def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Initial (causal) condition of the recursive filter, DFT boundary.

    The recursion is truncated after `max_iter` terms (contributions decay
    geometrically with |pole| < 1).
    """
    assert inp.shape[dim] > 1
    # number of terms needed for ~1e-13 relative accuracy
    max_iter: int = int(math.ceil(-30./math.log(abs(pole))))
    max_iter = min(max_iter, inp.shape[dim])

    poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
    poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device))
    poles = poles.flip(0)

    inp = movedim1(inp, dim, 0)
    inp0 = inp[0]
    # wrap-around: the last (max_iter - 1) samples feed the initial value
    inp = inp[1-max_iter:]
    inp = movedim1(inp, 0, -1)
    out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
    out = out + inp0.unsqueeze(-1)
    if keepdim:
        out = movedim1(out, -1, dim)
    else:
        out = out.squeeze(-1)

    pole = pole ** max_iter
    out = out / (1 - pole)
    return out


@torch.jit.script
def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Initial (causal) condition of the recursive filter, DCT1 boundary.

    Short signals (n <= max_iter) use the exact mirrored-sum formula;
    long signals use the truncated geometric series.
    """
    n = inp.shape[dim]
    max_iter: int = int(math.ceil(-30./math.log(abs(pole))))

    if max_iter < n:

        poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
        poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device))

        inp = movedim1(inp, dim, 0)
        inp0 = inp[0]
        inp = inp[1:max_iter]
        inp = movedim1(inp, 0, -1)
        out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
        out = out + inp0.unsqueeze(-1)
        if keepdim:
            out = movedim1(out, -1, dim)
        else:
            out = out.squeeze(-1)

    else:
        max_iter = n

        # NOTE(review): unlike the truncated branch above, this branch
        # indexes inp[0]/inp[-1] without first moving `dim` to the front.
        # Latent in this file because `filter` always calls with dim=0
        # after movedim1 -- confirm before calling with another dim.
        polen = pole ** (n - 1)
        inp0 = inp[0] + polen * inp[-1]
        inp = inp[1:-1]
        inp = movedim1(inp, 0, -1)

        poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
        poles = poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device))
        # forward and mirrored (reflected past the last voxel) contributions
        poles = poles + (polen * polen) / poles

        out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
        out = out + inp0.unsqueeze(-1)
        if keepdim:
            out = movedim1(out, -1, dim)
        else:
            out = out.squeeze(-1)

        pole = pole ** (max_iter - 1)
        out = out / (1 - pole * pole)

    return out


@torch.jit.script
def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Initial (causal) condition of the recursive filter, DCT2 boundary."""
    # Ported from scipy:
    # https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c
    #
    # I (YB) unwarped and simplied the terms so that I could use a dot
    # product instead of a loop.
    # It should certainly be possible to derive a version for max_iter < n,
    # as JA did for DCT1, to avoid long recursions when `n` is large. But
    # I think it would require a more complicated anticausal/final condition.
    # NOTE(review): indexes inp[0]/inp[-1] without moving `dim` to the
    # front first; latent because `filter` always passes dim=0 -- confirm.

    n = inp.shape[dim]

    polen = pole ** n
    pole_last = polen * (1 + 1/(pole + polen * polen))
    inp00 = inp[0]
    inp0 = inp[0] + pole_last * inp[-1]
    inp = inp[1:-1]
    inp = movedim1(inp, 0, -1)

    poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
    poles = (poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device)) +
             poles.pow(torch.arange(2*n-2, n, -1, dtype=inp.dtype, device=inp.device)))

    out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)

    out = out + inp0.unsqueeze(-1)
    out = out * (pole / (1 - polen * polen))
    out = out + inp00.unsqueeze(-1)

    if keepdim:
        out = movedim1(out, -1, dim)
    else:
        out = out.squeeze(-1)

    return out


@torch.jit.script
def dft_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Final (anticausal) condition of the recursive filter, DFT boundary."""
    assert inp.shape[dim] > 1
    max_iter: int = int(math.ceil(-30./math.log(abs(pole))))
    max_iter = min(max_iter, inp.shape[dim])

    poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
    poles = poles.pow(torch.arange(2, max_iter+1, dtype=inp.dtype, device=inp.device))

    inp = movedim1(inp, dim, 0)
    inp0 = inp[-1]
    # wrap-around: the first (max_iter - 1) samples feed the final value
    inp = inp[:max_iter-1]
    inp = movedim1(inp, 0, -1)
    out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
    out = out.add(inp0.unsqueeze(-1), alpha=pole)
    if keepdim:
        out = movedim1(out, -1, dim)
    else:
        out = out.squeeze(-1)

    pole = pole ** max_iter
    out = out / (pole - 1)
    return out
+@torch.jit.script
+def dct1_final(inp, pole: float, dim: int = -1, keepdim: bool = False):  # anticausal init, mirror (DCT-I) boundary
+    inp = movedim1(inp, dim, 0)
+    out = pole * inp[-2] + inp[-1]
+    out = out * (pole / (pole*pole - 1))
+    if keepdim:
+        out = movedim1(out.unsqueeze(0), 0, dim)
+    return out
+
+
+@torch.jit.script
+def dct2_final(inp, pole: float, dim: int = -1, keepdim: bool = False):  # anticausal init, DCT-II boundary
+    # Ported from scipy:
+    # https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c
+    inp = movedim1(inp, dim, 0)
+    out = inp[-1] * (pole / (pole - 1))
+    if keepdim:
+        out = movedim1(out.unsqueeze(0), 0, dim)
+    return out
+
+
+@torch.jit.script
+class CoeffBound:  # dispatches boundary-specific initial/final recursion conditions
+
+    def __init__(self, bound: int):
+        self.bound = bound  # integer boundary code (see initial/final branches below)
+
+    def initial(self, inp, pole: float, dim: int = -1, keepdim: bool = False):
+        if self.bound in (0, 2):  # zero, dct1
+            return dct1_initial(inp, pole, dim, keepdim)
+        elif self.bound in (1, 3):  # nearest, dct2
+            return dct2_initial(inp, pole, dim, keepdim)
+        elif self.bound == 6:  # dft
+            return dft_initial(inp, pole, dim, keepdim)
+        else:
+            raise NotImplementedError
+
+    def final(self, inp, pole: float, dim: int = -1, keepdim: bool = False):
+        if self.bound in (0, 2):  # zero, dct1
+            return dct1_final(inp, pole, dim, keepdim)
+        elif self.bound in (1, 3):  # nearest, dct2
+            return dct2_final(inp, pole, dim, keepdim)
+        elif self.bound == 6:  # dft
+            return dft_final(inp, pole, dim, keepdim)
+        else:
+            raise NotImplementedError
+
+
+@torch.jit.script
+def filter(inp, bound: CoeffBound, poles: List[float],  # NOTE: shadows builtin filter inside this module
+           dim: int = -1, inplace: bool = False):
+
+    if not inplace:
+        inp = inp.clone()
+
+    if inp.shape[dim] == 1:
+        return inp  # single sample: coefficient equals the sample
+
+    gain = get_gain(poles)
+    inp *= gain
+    inp = movedim1(inp, dim, 0)
+    n = inp.shape[0]
+
+    for pole in poles:  # one causal + anticausal IIR pass per pole
+        inp[0] = bound.initial(inp, pole, dim=0, keepdim=False)
+
+        for i in range(1, n):
+            inp[i].add_(inp[i-1], alpha=pole)  # causal: c[i] += pole * c[i-1]
+
+        inp[-1] = bound.final(inp, pole, dim=0, keepdim=False)
+
+        for i in range(n-2, -1, -1):
+            inp[i].neg_().add_(inp[i+1]).mul_(pole)  # anticausal: c[i] = pole * (c[i+1] - c[i])
+
+    inp = movedim1(inp, 0, dim)
+    return inp
+
+
+@torch.jit.script
+def spline_coeff(inp, bound: int, order: int, dim: int = -1,
+                 inplace: bool = False):
+    """Compute the interpolating spline coefficients, for a given spline order
+    and boundary conditions, along a single dimension.
+
+    Parameters
+    ----------
+    inp : tensor
+    bound : {2: dct1, 6: dft}
+    order : {0..7}
+    dim : int, default=-1
+    inplace : bool, default=False
+
+    Returns
+    -------
+    out : tensor
+
+    """
+    if not inplace:
+        inp = inp.clone()
+
+    if order in (0, 1):
+        return inp  # orders 0/1 need no prefilter (poles are empty)
+
+    poles = get_poles(order)
+    return filter(inp, CoeffBound(bound), poles, dim=dim, inplace=True)
+
+
+@torch.jit.script
+def spline_coeff_nd(inp, bound: List[int], order: List[int],
+                    dim: Optional[int] = None, inplace: bool = False):
+    """Compute the interpolating spline coefficients, for a given spline order
+    and boundary condition, along the last `dim` dimensions.
+
+    Parameters
+    ----------
+    inp : (..., *spatial) tensor
+    bound : List[{2: dct1, 6: dft}]
+    order : List[{0..7}]
+    dim : int, default=`inp.dim()`
+    inplace : bool, default=False
+
+    Returns
+    -------
+    out : (..., *spatial) tensor
+
+    """
+    if not inplace:
+        inp = inp.clone()
+
+    if dim is None:
+        dim = inp.dim()
+
+    bound = pad_list_int(bound, dim)  # broadcast scalars to one value per dimension
+    order = pad_list_int(order, dim)
+
+    for d, b, o in zip(range(dim), bound, order):
+        inp = spline_coeff(inp, b, o, dim=-dim + d, inplace=True)  # separable: filter each dim in turn
+
+    return inp
diff --git a/Generator/interpol/iso0.py b/Generator/interpol/iso0.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f43a81be437ac96cb9dd39d9f735ca6890c6f95
--- /dev/null
+++ b/Generator/interpol/iso0.py
@@ -0,0 +1,368 @@
+"""Isotropic 0-th order splines ("nearest neighbor")"""
+import torch
+from .bounds import Bound
+from .jit_utils import (sub2ind_list, make_sign,
+                        inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d)
+from typing import List, Optional
+Tensor = torch.Tensor
+
+
+@torch.jit.script
+def get_indices(g, n: int, bound:
Bound):
+    g0 = g.round().long()  # nearest-neighbor: round to closest voxel index
+    sign0 = bound.transform(g0, n)  # optional sign flip from the boundary condition
+    g0 = bound.index(g0, n)  # fold out-of-bounds indices back into [0, n)
+    return g0, sign0
+
+
+# ======================================================================
+# 3D
+# ======================================================================
+
+
+@torch.jit.script
+def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
+    """
+    inp: (B, C, iX, iY, iZ) tensor
+    g: (B, oX, oY, oZ, 3) tensor
+    bound: List{3}[Bound] tensor
+    extrapolate: ExtrapolateType
+    returns: (B, C, oX, oY, oZ) tensor
+    """
+    dim = 3
+    boundx, boundy, boundz = bound
+    oshape = g.shape[-dim-1:-1]
+    g = g.reshape([g.shape[0], 1, -1, dim])  # flatten spatial dims -> (B, 1, N, 3)
+    gx, gy, gz = g.unbind(-1)
+    batch = max(inp.shape[0], gx.shape[0])  # broadcast batch between input and grid
+    channel = inp.shape[1]
+    shape = inp.shape[-dim:]
+    nx, ny, nz = shape
+
+    # mask of inbounds voxels
+    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
+
+    # nearest integer coordinates
+    gx, signx = get_indices(gx, nx, boundx)
+    gy, signy = get_indices(gy, ny, boundy)
+    gz, signz = get_indices(gz, nz, boundz)
+
+    # gather
+    inp = inp.reshape(inp.shape[:2] + [-1])  # flatten volume -> (B, C, iX*iY*iZ)
+    idx = sub2ind_list([gx, gy, gz], shape)  # 3D subscripts -> linear indices
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out = inp.gather(-1, idx)
+    sign = make_sign([signx, signy, signz])
+    if sign is not None:
+        out *= sign
+    if mask is not None:
+        out *= mask  # zero samples that fell outside the field of view
+    out = out.reshape(out.shape[:2] + oshape)
+    return out
+
+
+@torch.jit.script
+def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
+           extrapolate: int = 1):
+    """
+    inp: (B, C, iX, iY, iZ) tensor
+    g: (B, iX, iY, iZ, 3) tensor
+    shape: List{3}[int], optional
+    bound: List{3}[Bound] tensor
+    extrapolate: ExtrapolateType
+    returns: (B, C, *shape) tensor
+    """
+    dim = 3
+    boundx, boundy, boundz = bound
+    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
+        raise ValueError('Input and grid should have the same spatial shape')
+    ishape = inp.shape[-dim:]
+    g = g.reshape([g.shape[0], 1, -1, dim])  # flatten spatial dims -> (B, 1, N, 3)
+    gx, gy, gz = torch.unbind(g, -1)
+    inp = inp.reshape(inp.shape[:2] + [-1])
+    batch = max(inp.shape[0], gx.shape[0])
+    channel = inp.shape[1]
+
+    if shape is None:
+        shape = ishape  # default: push back onto the input lattice
+    nx, ny, nz = shape
+
+    # mask of inbounds voxels
+    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
+
+    # nearest integer coordinates
+    gx, signx = get_indices(gx, nx, boundx)
+    gy, signy = get_indices(gy, ny, boundy)
+    gz, signz = get_indices(gz, nz, boundz)
+
+    # scatter
+    out = torch.zeros([batch, channel, nx*ny*nz], dtype=inp.dtype, device=inp.device)
+    idx = sub2ind_list([gx, gy, gz], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    sign = make_sign([signx, signy, signz])
+    if sign is not None or mask is not None:
+        inp = inp.clone()  # avoid mutating the caller's tensor
+        if sign is not None:
+            inp *= sign
+        if mask is not None:
+            inp *= mask
+    out.scatter_add_(-1, idx, inp)  # splat (adjoint of gather)
+
+    out = out.reshape(out.shape[:2] + shape)
+    return out
+
+
+# ======================================================================
+# 2D
+# ======================================================================
+
+
+@torch.jit.script
+def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
+    """
+    inp: (B, C, iX, iY) tensor
+    g: (B, oX, oY, 2) tensor
+    bound: List{2}[Bound] tensor
+    extrapolate: ExtrapolateType
+    returns: (B, C, oX, oY) tensor
+    """
+    dim = 2
+    boundx, boundy = bound
+    oshape = g.shape[-dim-1:-1]
+    g = g.reshape([g.shape[0], 1, -1, dim])  # flatten spatial dims -> (B, 1, N, 2)
+    gx, gy = g.unbind(-1)
+    batch = max(inp.shape[0], gx.shape[0])
+    channel = inp.shape[1]
+    shape = inp.shape[-dim:]
+    nx, ny = shape
+
+    # mask of inbounds voxels
+    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)
+
+    # nearest integer coordinates
+    gx, signx = get_indices(gx, nx, boundx)
+    gy, signy = get_indices(gy, ny, boundy)
+
+    # gather
+    inp = inp.reshape(inp.shape[:2] + [-1])
+    idx = sub2ind_list([gx, gy], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out = inp.gather(-1, idx)
+    sign = make_sign([signx, signy])
+    if sign is not None:
+        out = out * sign
+    if mask is not None:
+ out = mask * mask + out = out.reshape(out.shape[:2] + oshape) + return out + + +@torch.jit.script +def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 2 + boundx, boundy = bound + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = inp.shape[-dim:] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = torch.unbind(g, -1) + inp = inp.reshape(inp.shape[:2] + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + gy, signy = get_indices(gy, ny, boundy) + + # scatter + out = torch.zeros([batch, channel, nx*ny], dtype=inp.dtype, device=inp.device) + idx = sub2ind_list([gx, gy], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + sign = make_sign([signx, signy]) + if sign is not None or mask is not None: + inp = inp.clone() + if sign is not None: + inp = inp * sign + if mask is not None: + inp = inp * mask + out.scatter_add_(-1, idx, inp) + + out = out.reshape(out.shape[:2] + shape) + return out + + +# ====================================================================== +# 1D +# ====================================================================== + + +@torch.jit.script +def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, oX, 1) tensor + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX) tensor + """ + dim = 1 + boundx = bound[0] + oshape = g.shape[-dim-1:-1] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + batch = 
max(inp.shape[0], gx.shape[0])
+    channel = inp.shape[1]
+    shape = inp.shape[-dim:]
+    nx = shape[0]
+
+    # mask of inbounds voxels
+    mask = inbounds_mask_1d(extrapolate, gx, nx)
+
+    # nearest integer coordinates
+    gx, signx = get_indices(gx, nx, boundx)
+
+    # gather
+    inp = inp.reshape(inp.shape[:2] + [-1])
+    idx = gx  # 1D: subscripts are already linear indices
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out = inp.gather(-1, idx)
+    sign = signx
+    if sign is not None:
+        out = out * sign
+    if mask is not None:
+        out = out * mask  # zero samples that fell outside the field of view
+    out = out.reshape(out.shape[:2] + oshape)
+    return out
+
+
+@torch.jit.script
+def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
+           extrapolate: int = 1):
+    """
+    inp: (B, C, iX) tensor
+    g: (B, iX, 1) tensor
+    shape: List{1}[int], optional
+    bound: List{1}[Bound] tensor
+    extrapolate: ExtrapolateType
+    returns: (B, C, *shape) tensor
+    """
+    dim = 1
+    boundx = bound[0]
+    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
+        raise ValueError('Input and grid should have the same spatial shape')
+    ishape = inp.shape[-dim:]
+    g = g.reshape([g.shape[0], 1, -1, dim])
+    gx = g.squeeze(-1)
+    inp = inp.reshape(inp.shape[:2] + [-1])
+    batch = max(inp.shape[0], gx.shape[0])  # broadcast batch between input and grid
+    channel = inp.shape[1]
+
+    if shape is None:
+        shape = ishape  # default: push back onto the input lattice
+    nx = shape[0]
+
+    # mask of inbounds voxels
+    mask = inbounds_mask_1d(extrapolate, gx, nx)
+
+    # nearest integer coordinates
+    gx, signx = get_indices(gx, nx, boundx)
+
+    # scatter
+    out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device)
+    idx = gx
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    sign = signx
+    if sign is not None or mask is not None:
+        inp = inp.clone()  # avoid mutating the caller's tensor
+        if sign is not None:
+            inp = inp * sign
+        if mask is not None:
+            inp = inp * mask
+    out.scatter_add_(-1, idx, inp)  # splat (adjoint of gather)
+
+    out = out.reshape(out.shape[:2] + shape)
+    return out
+
+
+# ======================================================================
+# ND
+# ======================================================================
+
+
+@torch.jit.script
+def grad(inp, g, bound: List[Bound], extrapolate: int = 1):
+    """
+    inp: (B, C, *ishape) tensor
+    g: (B, *oshape, D) tensor
+    bound: List{D}[Bound] tensor
+    extrapolate: ExtrapolateType
+    returns: (B, C, *oshape, D) tensor
+    """
+    dim = g.shape[-1]
+    oshape = list(g.shape[-dim-1:-1])
+    batch = max(inp.shape[0], g.shape[0])
+    channel = inp.shape[1]
+
+    return torch.zeros([batch, channel] + oshape + [dim],  # nearest-neighbor is piecewise constant: gradient is zero (a.e.)
+                       dtype=inp.dtype, device=inp.device)
+
+
+@torch.jit.script
+def pushgrad(inp, g, shape: Optional[List[int]], bound: List[Bound],
+             extrapolate: int = 1):
+    """
+    inp: (B, C, *ishape, D) tensor
+    g: (B, *ishape, D) tensor
+    shape: List{D}[int], optional
+    bound: List{D}[Bound] tensor
+    extrapolate: ExtrapolateType
+    returns: (B, C, *shape) tensor
+    """
+    dim = g.shape[-1]
+    if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
+        raise ValueError('Input and grid should have the same spatial shape')
+    ishape = inp.shape[-dim-1:-1]
+    batch = max(inp.shape[0], g.shape[0])
+    channel = inp.shape[1]
+
+    if shape is None:
+        shape = ishape
+    shape = list(shape)
+
+    return torch.zeros([batch, channel] + shape,  # adjoint of the zero gradient: zero
+                       dtype=inp.dtype, device=inp.device)
+
+
+@torch.jit.script
+def hess(inp, g, bound: List[Bound], extrapolate: int = 1):
+    """
+    inp: (B, C, *ishape) tensor
+    g: (B, *oshape, D) tensor
+    bound: List{D}[Bound] tensor
+    extrapolate: ExtrapolateType
+    returns: (B, C, *oshape, D, D) tensor
+    """
+    dim = g.shape[-1]
+    oshape = list(g.shape[-dim-1:-1])
+    g = g.reshape([g.shape[0], 1, -1, dim])
+    batch = max(inp.shape[0], g.shape[0])
+    channel = inp.shape[1]
+
+    return torch.zeros([batch, channel] + oshape + [dim, dim],  # second derivatives of a 0-order spline: zero (a.e.)
+                       dtype=inp.dtype, device=inp.device)
diff --git a/Generator/interpol/iso1.py b/Generator/interpol/iso1.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa21f12d9a68532cec2ebdcb4ce6ef7c75d6d6a6
--- /dev/null
+++ b/Generator/interpol/iso1.py
@@ -0,0 +1,1339 @@
+"""Isotropic 1-st order splines ("linear/bilinear/trilinear")"""
+import
torch +from .bounds import Bound +from .jit_utils import (sub2ind_list, make_sign, + inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d) +from typing import List, Tuple, Optional +Tensor = torch.Tensor + + +@torch.jit.script +def get_weights_and_indices(g, n: int, bound: Bound) \ + -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + g0 = g.floor().long() + g1 = g0 + 1 + sign1 = bound.transform(g1, n) + sign0 = bound.transform(g0, n) + g1 = bound.index(g1, n) + g0 = bound.index(g0, n) + g = g - g.floor() + return g, g0, g1, sign0, sign1 + + +# ====================================================================== +# 3D +# ====================================================================== + + +@torch.jit.script +def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, oX, oY, oZ, 3) tensor + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, oZ) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = g.unbind(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out = out * sign + out = out * ((1 - gx) * (1 - gy) * (1 - gz)) + # - corner 
001 + idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * ((1 - gx) * (1 - gy) * gz) + out = out + out1 + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * ((1 - gx) * gy * (1 - gz)) + out = out + out1 + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * ((1 - gx) * gy * gz) + out = out + out1 + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * (1 - gy) * (1 - gz)) + out = out + out1 + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * (1 - gy) * gz) + out = out + out1 + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * gy * (1 - gz)) + out = out + out1 + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * gy * gz) + 
out = out + out1 + + if mask is not None: + out *= mask + out = out.reshape(list(out.shape[:2]) + oshape) + return out + + +@torch.jit.script +def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, iX, iY, iZ, 3) tensor + shape: List{3}[int], optional + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim:]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = torch.unbind(g, -1) + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # scatter + out = torch.zeros([batch, channel, nx*ny*nz], + dtype=inp.dtype, device=inp.device) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * ((1 - gx) * (1 - gy) * (1 - gz)) + out.scatter_add_(-1, idx, out1) + # - corner 001 + idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 = out1 * sign + if mask is not None: 
+ out1 = out1 * mask + out1 = out1 * ((1 - gx) * (1 - gy) * gz) + out.scatter_add_(-1, idx, out1) + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * ((1 - gx) * gy * (1 - gz)) + out.scatter_add_(-1, idx, out1) + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * ((1 - gx) * gy * gz) + out.scatter_add_(-1, idx, out1) + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * (gx * (1 - gy) * (1 - gz)) + out.scatter_add_(-1, idx, out1) + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * (gx * (1 - gy) * gz) + out.scatter_add_(-1, idx, out1) + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * (gx * gy * (1 - gz)) + out.scatter_add_(-1, idx, out1) + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1, signz1]) + if sign is 
not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * (gx * gy * gz) + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def grad3d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, oX, oY, oZ, 3) tensor + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, oZ, 3) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = torch.unbind(g, -1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel] + list(g.shape[-2:]), + dtype=inp.dtype, device=inp.device) + outx, outy, outz = out.unbind(-1) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outx) + outy.copy_(outx) + outz.copy_(outx) + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out *= sign.unsqueeze(-1) + outx *= - (1 - gy) * (1 - gz) + outy *= - (1 - gx) * (1 - gz) + outz *= - (1 - gx) * (1 - gy) + # - corner 001 + idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, - (1 - gy) * gz) + 
outy.addcmul_(out1, - (1 - gx) * gz) + outz.addcmul_(out1, (1 - gx) * (1 - gy)) + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, - gy * (1 - gz)) + outy.addcmul_(out1, (1 - gx) * (1 - gz)) + outz.addcmul_(out1, - (1 - gx) * gy) + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, - gy * gz) + outy.addcmul_(out1, (1 - gx) * gz) + outz.addcmul_(out1, (1 - gx) * gy) + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, (1 - gy) * (1 - gz)) + outy.addcmul_(out1, - gx * (1 - gz)) + outz.addcmul_(out1, - gx * (1 - gy)) + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, (1 - gy) * gz) + outy.addcmul_(out1, - gx * gz) + outz.addcmul_(out1, gx * (1 - gy)) + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, gy * (1 - gz)) + outy.addcmul_(out1, gx * (1 - gz)) + outz.addcmul_(out1, - gx * gy) + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 *= sign 
+ outx.addcmul_(out1, gy * gz) + outy.addcmul_(out1, gx * gz) + outz.addcmul_(out1, gx * gy) + + if mask is not None: + out *= mask.unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [3]) + return out + + +@torch.jit.script +def pushgrad3d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ, 3) tensor + g: (B, iX, iY, iZ, 3) tensor + shape: List{3}[int], optional + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = g.unbind(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1, dim]) + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # scatter + out = torch.zeros([batch, channel, nx*ny*nz], + dtype=inp.dtype, device=inp.device) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y, out1z = out1.unbind(-1) + out1x *= - (1 - gy) * (1 - gz) + out1y *= - (1 - gx) * (1 - gz) + out1z *= - (1 - gx) * (1 - gy) + out.scatter_add_(-1, idx, out1x + out1y + out1z) + # - corner 
001 + idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y, out1z = out1.unbind(-1) + out1x *= - (1 - gy) * gz + out1y *= - (1 - gx) * gz + out1z *= (1 - gx) * (1 - gy) + out.scatter_add_(-1, idx, out1x + out1y + out1z) + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y, out1z = out1.unbind(-1) + out1x *= - gy * (1 - gz) + out1y *= (1 - gx) * (1 - gz) + out1z *= - (1 - gx) * gy + out.scatter_add_(-1, idx, out1x + out1y + out1z) + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y, out1z = out1.unbind(-1) + out1x *= - gy * gz + out1y *= (1 - gx) * gz + out1z *= (1 - gx) * gy + out.scatter_add_(-1, idx, out1x + out1y + out1z) + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y, out1z = out1.unbind(-1) + out1x *= (1 - gy) * (1 - gz) + out1y *= - gx * (1 - gz) + out1z *= - gx * (1 - gy) + out.scatter_add_(-1, idx, out1x + out1y + out1z) + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0, signz1]) + if 
sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y, out1z = out1.unbind(-1) + out1x *= (1 - gy) * gz + out1y *= - gx * gz + out1z *= gx * (1 - gy) + out.scatter_add_(-1, idx, out1x + out1y + out1z) + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y, out1z = out1.unbind(-1) + out1x *= gy * (1 - gz) + out1y *= gx * (1 - gz) + out1z *= - gx * gy + out.scatter_add_(-1, idx, out1x + out1y + out1z) + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y, out1z = out1.unbind(-1) + out1x *= gy * gz + out1y *= gx * gz + out1z *= gx * gy + out.scatter_add_(-1, idx, out1x + out1y + out1z) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def hess3d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, oX, oY, oZ, 3) tensor + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, oZ, 3, 3) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = torch.unbind(g, -1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + 
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel, g.shape[-2], dim, dim], + dtype=inp.dtype, device=inp.device) + outx, outy, outz = out.unbind(-1) + outxx, outyx, outzx = outx.unbind(-1) + outxy, outyy, outzy = outy.unbind(-1) + outxz, outyz, outzz = outz.unbind(-1) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outxy) + outxz.copy_(outxy) + outyz.copy_(outxy) + outxx.zero_() + outyy.zero_() + outzz.zero_() + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out *= sign.unsqueeze(-1).unsqueeze(-1) + outxy *= (1 - gz) + outxz *= (1 - gy) + outyz *= (1 - gx) + # - corner 001 + idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, gz) + outxz.addcmul_(out1, - (1 - gy)) + outyz.addcmul_(out1, - (1 - gx)) + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, - (1 - gz)) + outxz.addcmul_(out1, gy) + outyz.addcmul_(out1, - (1 - gx)) + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, - gz) + outxz.addcmul_(out1, - gy) + outyz.addcmul_(out1, (1 - gx)) + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = 
make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, - (1 - gz)) + outxz.addcmul_(out1, - (1 - gy)) + outyz.addcmul_(out1, gx) + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, - gz) + outxz.addcmul_(out1, (1 - gy)) + outyz.addcmul_(out1, - gx) + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, (1 - gz)) + outxz.addcmul_(out1, - gy) + outyz.addcmul_(out1, - gx) + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, gz) + outxz.addcmul_(out1, gy) + outyz.addcmul_(out1, gx) + + outyx.copy_(outxy) + outzx.copy_(outxz) + outzy.copy_(outyz) + + if mask is not None: + out *= mask.unsqueeze(-1).unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim]) + return out + + +# ====================================================================== +# 2D +# ====================================================================== + + +@torch.jit.script +def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, oX, oY, 2) tensor + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY) tensor + """ + dim = 2 + boundx, boundy = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = g.unbind(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny = shape + + # mask of 
inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = make_sign([signx0, signy0]) + if sign is not None: + out = out * sign + out = out * ((1 - gx) * (1 - gy)) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * ((1 - gx) * gy) + out = out + out1 + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * (1 - gy)) + out = out + out1 + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * gy) + out = out + out1 + + if mask is not None: + out *= mask + out = out.reshape(list(out.shape[:2]) + oshape) + return out + + +@torch.jit.script +def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 2 + boundx, boundy = bound + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = 
list(inp.shape[-dim:]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = torch.unbind(g, -1) + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # scatter + out = torch.zeros([batch, channel, nx*ny], + dtype=inp.dtype, device=inp.device) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) * (1 - gy) + out.scatter_add_(-1, idx, out1) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) * gy + out.scatter_add_(-1, idx, out1) + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx * (1 - gy) + out.scatter_add_(-1, idx, out1) + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx * gy + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def 
grad2d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, oX, oY, 2) tensor + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, 2) tensor + """ + dim = 2 + boundx, boundy = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = torch.unbind(g, -1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel] + list(g.shape[-2:]), + dtype=inp.dtype, device=inp.device) + outx, outy = out.unbind(-1) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outx) + outy.copy_(outx) + sign = make_sign([signx0, signy0]) + if sign is not None: + out *= sign.unsqueeze(-1) + outx *= - (1 - gy) + outy *= - (1 - gx) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, - gy) + outy.addcmul_(out1, (1 - gx)) + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, (1 - gy)) + outy.addcmul_(out1, - gx) + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, 
signy1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, gy) + outy.addcmul_(out1, gx) + + if mask is not None: + out *= mask.unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim]) + return out + + +@torch.jit.script +def pushgrad2d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, 2) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 2 + boundx, boundy = bound + if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = g.unbind(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1, dim]) + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # scatter + out = torch.zeros([batch, channel, nx*ny], + dtype=inp.dtype, device=inp.device) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= - (1 - gy) + out1y *= - (1 - gx) + out.scatter_add_(-1, idx, out1x + out1y) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= 
sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= - gy + out1y *= (1 - gx) + out.scatter_add_(-1, idx, out1x + out1y) + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= (1 - gy) + out1y *= - gx + out.scatter_add_(-1, idx, out1x + out1y) + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= gy + out1y *= gx + out.scatter_add_(-1, idx, out1x + out1y) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def hess2d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, oX, oY, 2) tensor + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, 2, 2) tensor + """ + dim = 2 + boundx, boundy = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = torch.unbind(g, -1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel, g.shape[-2], dim, dim], + dtype=inp.dtype, device=inp.device) + 
outx, outy = out.unbind(-1) + outxx, outyx = outx.unbind(-1) + outxy, outyy = outy.unbind(-1) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outxy) + outxx.zero_() + outyy.zero_() + sign = make_sign([signx0, signy0]) + if sign is not None: + out *= sign.unsqueeze(-1).unsqueeze(-1) + outxy *= 1 + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= sign + outxy.add_(out1, alpha=-1) + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 *= sign + outxy.add_(out1, alpha=-1) + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1]) + if sign is not None: + out1 *= sign + outxy.add_(out1) + + outyx.copy_(outxy) + + if mask is not None: + out *= mask.unsqueeze(-1).unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim]) + return out + + +# ====================================================================== +# 1D +# ====================================================================== + + +@torch.jit.script +def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, oX, 1) tensor + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX) tensor + """ + dim = 1 + boundx = bound[0] + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # corners + 
# (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + # - corner 0 + idx = gx0 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = signx0 + if sign is not None: + out = out * sign + out = out * (1 - gx) + # - corner 1 + idx = gx1 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = signx1 + if sign is not None: + out1 = out1 * sign + out1 = out1 * gx + out = out + out1 + + if mask is not None: + out *= mask + out = out.reshape(list(out.shape[:2]) + oshape) + return out + + +@torch.jit.script +def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 1 + boundx = bound[0] + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim:]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + + # scatter + out = torch.zeros([batch, channel, nx], + dtype=inp.dtype, device=inp.device) + # - corner 0 + idx = gx0 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = signx0 + if sign is not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * (1 - 
gx) + out.scatter_add_(-1, idx, out1) + # - corner 1 + idx = gx1 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = signx1 + if sign is not None: + out1 = out1 * sign + if mask is not None: + out1 = out1 * mask + out1 = out1 * gx + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def grad1d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, oX, 1) tensor + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, 1) tensor + """ + dim = 1 + boundx = bound[0] + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel] + list(g.shape[-2:]), + dtype=inp.dtype, device=inp.device) + outx = out.squeeze(-1) + # - corner 0 + idx = gx0 + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outx) + sign = signx0 + if sign is not None: + out *= sign.unsqueeze(-1) + outx.neg_() + # - corner 1 + idx = gx1 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = signx1 + if sign is not None: + out1 *= sign + outx.add_(out1) + + if mask is not None: + out *= mask.unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim]) + return out + + +@torch.jit.script +def pushgrad1d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, 1) tensor + g: (B, iX, 1) tensor + shape: List{1}[int], optional + bound: List{1}[Bound] 
tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 1 + boundx = bound[0] + if inp.shape[-2] != g.shape[-2]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1, dim]) + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + + # scatter + out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device) + # - corner 000 + idx = gx0 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = signx0 + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x = out1.squeeze(-1) + out1x.neg_() + out.scatter_add_(-1, idx, out1x) + # - corner 100 + idx = gx1 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = signx1 + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x = out1.squeeze(-1) + out.scatter_add_(-1, idx, out1x) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def hess1d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, oX, 1) tensor + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, 1, 1) tensor + """ + batch = max(inp.shape[0], g.shape[0]) + return torch.zeros([batch, inp.shape[1], g.shape[1], 1, 1], + dtype=inp.dtype, device=inp.device) \ No newline at end of file diff --git a/Generator/interpol/jit_utils.py b/Generator/interpol/jit_utils.py new file mode 100644 
index 0000000000000000000000000000000000000000..cc2765af659eb553ef060513c2fad71fe48daadd --- /dev/null +++ b/Generator/interpol/jit_utils.py @@ -0,0 +1,443 @@ +"""A lot of utility functions for TorchScript""" +import torch +import os +from typing import List, Tuple, Optional +from .utils import torch_version +from torch import Tensor + + +@torch.jit.script +def pad_list_int(x: List[int], dim: int) -> List[int]: + if len(x) < dim: + x = x + x[-1:] * (dim - len(x)) + if len(x) > dim: + x = x[:dim] + return x + + +@torch.jit.script +def pad_list_float(x: List[float], dim: int) -> List[float]: + if len(x) < dim: + x = x + x[-1:] * (dim - len(x)) + if len(x) > dim: + x = x[:dim] + return x + + +@torch.jit.script +def pad_list_str(x: List[str], dim: int) -> List[str]: + if len(x) < dim: + x = x + x[-1:] * (dim - len(x)) + if len(x) > dim: + x = x[:dim] + return x + + +@torch.jit.script +def list_any(x: List[bool]) -> bool: + for elem in x: + if elem: + return True + return False + + +@torch.jit.script +def list_all(x: List[bool]) -> bool: + for elem in x: + if not elem: + return False + return True + + +@torch.jit.script +def list_prod_int(x: List[int]) -> int: + if len(x) == 0: + return 1 + x0 = x[0] + for x1 in x[1:]: + x0 = x0 * x1 + return x0 + + +@torch.jit.script +def list_sum_int(x: List[int]) -> int: + if len(x) == 0: + return 1 + x0 = x[0] + for x1 in x[1:]: + x0 = x0 + x1 + return x0 + + +@torch.jit.script +def list_prod_tensor(x: List[Tensor]) -> Tensor: + if len(x) == 0: + empty: List[int] = [] + return torch.ones(empty) + x0 = x[0] + for x1 in x[1:]: + x0 = x0 * x1 + return x0 + + +@torch.jit.script +def list_sum_tensor(x: List[Tensor]) -> Tensor: + if len(x) == 0: + empty: List[int] = [] + return torch.ones(empty) + x0 = x[0] + for x1 in x[1:]: + x0 = x0 + x1 + return x0 + + +@torch.jit.script +def list_reverse_int(x: List[int]) -> List[int]: + if len(x) == 0: + return x + return [x[i] for i in range(-1, -len(x)-1, -1)] + + +@torch.jit.script +def 
list_cumprod_int(x: List[int], reverse: bool = False, + exclusive: bool = False) -> List[int]: + if len(x) == 0: + lx: List[int] = [] + return lx + if reverse: + x = list_reverse_int(x) + + x0 = 1 if exclusive else x[0] + lx = [x0] + all_x = x[:-1] if exclusive else x[1:] + for x1 in all_x: + x0 = x0 * x1 + lx.append(x0) + if reverse: + lx = list_reverse_int(lx) + return lx + + +@torch.jit.script +def movedim1(x, source: int, destination: int): + dim = x.dim() + source = dim + source if source < 0 else source + destination = dim + destination if destination < 0 else destination + permutation = [d for d in range(dim)] + permutation = permutation[:source] + permutation[source+1:] + permutation = permutation[:destination] + [source] + permutation[destination:] + return x.permute(permutation) + + +@torch.jit.script +def sub2ind(subs, shape: List[int]): + """Convert sub indices (i, j, k) into linear indices. + + The rightmost dimension is the most rapidly changing one + -> if shape == [D, H, W], the strides are therefore [H*W, W, 1] + + Parameters + ---------- + subs : (D, ...) tensor + List of sub-indices. The first dimension is the number of dimension. + Each element should have the same number of elements and shape. + shape : (D,) list[int] + Size of each dimension. Its length should be the same as the + first dimension of ``subs``. + + Returns + ------- + ind : (...) tensor + Linear indices + """ + subs = subs.unbind(0) + ind = subs[-1] + subs = subs[:-1] + ind = ind.clone() + stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False) + for i, s in zip(subs, stride): + ind += i * s + return ind + + +@torch.jit.script +def sub2ind_list(subs: List[Tensor], shape: List[int]): + """Convert sub indices (i, j, k) into linear indices. + + The rightmost dimension is the most rapidly changing one + -> if shape == [D, H, W], the strides are therefore [H*W, W, 1] + + Parameters + ---------- + subs : (D,) list[tensor] + List of sub-indices. 
The first dimension is the number of dimension. + Each element should have the same number of elements and shape. + shape : (D,) list[int] + Size of each dimension. Its length should be the same as the + first dimension of ``subs``. + + Returns + ------- + ind : (...) tensor + Linear indices + """ + ind = subs[-1] + subs = subs[:-1] + ind = ind.clone() + stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False) + for i, s in zip(subs, stride): + ind += i * s + return ind + +# floor_divide returns wrong results for negative values, because it truncates +# instead of performing a proper floor. In recent version of pytorch, it is +# advised to use div(..., rounding_mode='trunc'|'floor') instead. +# Here, we only use floor_divide on positive values so we do not care. +if torch_version('>=', [1, 8]): + @torch.jit.script + def floor_div(x, y) -> torch.Tensor: + return torch.div(x, y, rounding_mode='floor') + @torch.jit.script + def floor_div_int(x, y: int) -> torch.Tensor: + return torch.div(x, y, rounding_mode='floor') +else: + @torch.jit.script + def floor_div(x, y) -> torch.Tensor: + return (x / y).floor_() + @torch.jit.script + def floor_div_int(x, y: int) -> torch.Tensor: + return (x / y).floor_() + + +@torch.jit.script +def ind2sub(ind, shape: List[int]): + """Convert linear indices into sub indices (i, j, k). + + The rightmost dimension is the most rapidly changing one + -> if shape == [D, H, W], the strides are therefore [H*W, W, 1] + + Parameters + ---------- + ind : tensor_like + Linear indices + shape : (D,) vector_like + Size of each dimension. + + Returns + ------- + subs : (D, ...) tensor + Sub-indices. 
+ """ + stride = list_cumprod_int(shape, reverse=True, exclusive=True) + sub = ind.new_empty([len(shape)] + ind.shape) + sub.copy_(ind) + for d in range(len(shape)): + if d > 0: + sub[d] = torch.remainder(sub[d], stride[d-1]) + sub[d] = floor_div_int(sub[d], stride[d]) + return sub + + +@torch.jit.script +def inbounds_mask_3d(extrapolate: int, gx, gy, gz, nx: int, ny: int, nz: int) \ + -> Optional[Tensor]: + # mask of inbounds voxels + mask: Optional[Tensor] = None + if extrapolate in (0, 2): # no / hist + tiny = 5e-2 + threshold = tiny + if extrapolate == 2: + threshold = 0.5 + tiny + mask = ((gx > -threshold) & (gx < nx - 1 + threshold) & + (gy > -threshold) & (gy < ny - 1 + threshold) & + (gz > -threshold) & (gz < nz - 1 + threshold)) + return mask + return mask + + +@torch.jit.script +def inbounds_mask_2d(extrapolate: int, gx, gy, nx: int, ny: int) \ + -> Optional[Tensor]: + # mask of inbounds voxels + mask: Optional[Tensor] = None + if extrapolate in (0, 2): # no / hist + tiny = 5e-2 + threshold = tiny + if extrapolate == 2: + threshold = 0.5 + tiny + mask = ((gx > -threshold) & (gx < nx - 1 + threshold) & + (gy > -threshold) & (gy < ny - 1 + threshold)) + return mask + return mask + + +@torch.jit.script +def inbounds_mask_1d(extrapolate: int, gx, nx: int) -> Optional[Tensor]: + # mask of inbounds voxels + mask: Optional[Tensor] = None + if extrapolate in (0, 2): # no / hist + tiny = 5e-2 + threshold = tiny + if extrapolate == 2: + threshold = 0.5 + tiny + mask = (gx > -threshold) & (gx < nx - 1 + threshold) + return mask + return mask + + +@torch.jit.script +def make_sign(sign: List[Optional[Tensor]]) -> Optional[Tensor]: + is_none : List[bool] = [s is None for s in sign] + if list_all(is_none): + return None + filt_sign: List[Tensor] = [] + for s in sign: + if s is not None: + filt_sign.append(s) + return list_prod_tensor(filt_sign) + + +@torch.jit.script +def square(x): + return x * x + + +@torch.jit.script +def square_(x): + return x.mul_(x) + + 
@torch.jit.script
def cube(x):
    """Return x**3 (out of place)."""
    return x * x * x


@torch.jit.script
def cube_(x):
    """Return x**3 (in place)."""
    return square_(x).mul_(x)


@torch.jit.script
def pow4(x):
    """Return x**4 (out of place)."""
    return square(square(x))


@torch.jit.script
def pow4_(x):
    """Return x**4 (in place)."""
    return square_(square_(x))


@torch.jit.script
def pow5(x):
    """Return x**5 (out of place)."""
    return x * pow4(x)


@torch.jit.script
def pow5_(x):
    """Return x**5 (in place)."""
    return pow4_(x).mul_(x)


@torch.jit.script
def pow6(x):
    """Return x**6 (out of place)."""
    return square(cube(x))


@torch.jit.script
def pow6_(x):
    """Return x**6 (in place)."""
    return square_(cube_(x))


@torch.jit.script
def pow7(x):
    """Return x**7 (out of place)."""
    return pow6(x) * x


@torch.jit.script
def pow7_(x):
    """Return x**7 (in place)."""
    return pow6_(x).mul_(x)


@torch.jit.script
def dot(x, y, dim: int = -1, keepdim: bool = False):
    """(Batched) dot product along a dimension.

    Moves `dim` to the end of both operands and contracts it with a
    batched matmul; all remaining dimensions broadcast.
    """
    x = movedim1(x, dim, -1).unsqueeze(-2)
    y = movedim1(y, dim, -1).unsqueeze(-1)
    d = torch.matmul(x, y).squeeze(-1).squeeze(-1)
    if keepdim:
        # BUG FIX: `unsqueeze` is out of place; the previous code
        # discarded its result so `keepdim=True` had no effect.
        d = d.unsqueeze(dim)
    return d


@torch.jit.script
def dot_multi(x, y, dim: List[int], keepdim: bool = False):
    """(Batched) dot product along several dimensions.

    Moves every dimension in `dim` to the end of both operands,
    flattens them, and contracts with a batched matmul.
    """
    for d in dim:
        x = movedim1(x, d, -1)
        y = movedim1(y, d, -1)
    x = x.reshape(x.shape[:-len(dim)] + [1, -1])
    # BUG FIX: the reduced shape must come from `y`, not `x`.
    y = y.reshape(y.shape[:-len(dim)] + [-1, 1])
    dt = torch.matmul(x, y).squeeze(-1).squeeze(-1)
    if keepdim:
        for d in dim:
            # BUG FIX: rebind the out-of-place `unsqueeze` result
            # (it was previously discarded, so `keepdim` did nothing).
            dt = dt.unsqueeze(d)
    return dt


# cartesian_prod takes multiple input tensors as input in eager mode
# but takes a list of tensors in jit mode. This is a helper that works
# in both cases.
if not int(os.environ.get('PYTORCH_JIT', '1')):
    # TorchScript disabled: eager-mode variants that unpack the list of
    # tensors, since eager `cartesian_prod`/`meshgrid` take *varargs.
    cartesian_prod = lambda x: torch.cartesian_prod(*x)

    if torch_version('>=', (1, 10)):
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # Wrap in list() so the return type matches the annotation
            # (eager meshgrid returns a tuple).
            return list(torch.meshgrid(*x, indexing='ij'))

        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return list(torch.meshgrid(*x, indexing='xy'))
    else:
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # torch < 1.10: meshgrid has no `indexing` argument and
            # always uses 'ij' convention.
            return list(torch.meshgrid(*x))

        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # BUG FIX: eager meshgrid returns a tuple, which does not
            # support item assignment; copy into a list before swapping
            # the first two grids to emulate 'xy' indexing.
            grid = list(torch.meshgrid(*x))
            if len(grid) > 1:
                grid[0] = grid[0].transpose(0, 1)
                grid[1] = grid[1].transpose(0, 1)
            return grid

else:
    cartesian_prod = torch.cartesian_prod

    if torch_version('>=', (1, 10)):
        @torch.jit.script
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x, indexing='ij')

        @torch.jit.script
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x, indexing='xy')
    else:
        @torch.jit.script
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x)

        # BUG FIX: this function was previously misspelled `meshgrid_xyt`,
        # leaving `meshgrid_xy` undefined on torch < 1.10 in jit mode.
        @torch.jit.script
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # Emulate 'xy' indexing by transposing the first two grids.
            grid = torch.meshgrid(x)
            if len(grid) > 1:
                grid[0] = grid[0].transpose(0, 1)
                grid[1] = grid[1].transpose(0, 1)
            return grid


# 'ij' indexing is the convention used throughout this package.
meshgrid = meshgrid_ij


# In torch < 1.6, div applied to integer tensor performed a floor_divide
# In torch > 1.6, it performs a true divide.
# Floor division must be done using `floor_divide`, but it was buggy
# until torch 1.13 (it was doing a trunc divide instead of a floor divide).
# There was at some point a deprecation warning for floor_divide, but it
# seems to have been lifted afterwards. In torch >= 1.13, floor_divide
# performs a correct floor division.
# Since we only apply floor_divide to positive values, we are fine.
+if torch_version('<', (1, 6)): + floor_div = torch.div +else: + floor_div = torch.floor_divide \ No newline at end of file diff --git a/Generator/interpol/jitfields.py b/Generator/interpol/jitfields.py new file mode 100644 index 0000000000000000000000000000000000000000..b758a8085de51847d2aef8b0b1795a1720a92136 --- /dev/null +++ b/Generator/interpol/jitfields.py @@ -0,0 +1,95 @@ +try: + import jitfields + available = True +except (ImportError, ModuleNotFoundError): + jitfields = None + available = False +from .utils import make_list +import torch + + +def first2last(input, ndim): + insert = input.dim() <= ndim + if insert: + input = input.unsqueeze(-1) + else: + input = torch.movedim(input, -ndim-1, -1) + return input, insert + + +def last2first(input, ndim, inserted, grad=False): + if inserted: + input = input.squeeze(-1 - grad) + else: + input = torch.movedim(input, -1 - grad, -ndim-1 - grad) + return input + + +def grid_pull(input, grid, interpolation='linear', bound='zero', + extrapolate=False, prefilter=False): + ndim = grid.shape[-1] + input, inserted = first2last(input, ndim) + input = jitfields.pull(input, grid, order=interpolation, bound=bound, + extrapolate=extrapolate, prefilter=prefilter) + input = last2first(input, ndim, inserted) + return input + + +def grid_push(input, grid, shape=None, interpolation='linear', bound='zero', + extrapolate=False, prefilter=False): + ndim = grid.shape[-1] + input, inserted = first2last(input, ndim) + input = jitfields.push(input, grid, shape, order=interpolation, bound=bound, + extrapolate=extrapolate, prefilter=prefilter) + input = last2first(input, ndim, inserted) + return input + + +def grid_count(grid, shape=None, interpolation='linear', bound='zero', + extrapolate=False): + return jitfields.count(grid, shape, order=interpolation, bound=bound, + extrapolate=extrapolate) + + +def grid_grad(input, grid, interpolation='linear', bound='zero', + extrapolate=False, prefilter=False): + ndim = grid.shape[-1] + input, 
inserted = first2last(input, ndim) + input = jitfields.grad(input, grid, order=interpolation, bound=bound, + extrapolate=extrapolate, prefilter=prefilter) + input = last2first(input, ndim, inserted, True) + return input + + +def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1, + inplace=False): + func = jitfields.spline_coeff_ if inplace else jitfields.spline_coeff + return func(input, interpolation, bound=bound, dim=dim) + + +def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None, + inplace=False): + func = jitfields.spline_coeff_nd_ if inplace else jitfields.spline_coeff_nd + return func(input, interpolation, bound=bound, ndim=dim) + + +def resize(image, factor=None, shape=None, anchor='c', + interpolation=1, prefilter=True, **kwargs): + kwargs.setdefault('bound', 'nearest') + ndim = max(len(make_list(factor or [])), + len(make_list(shape or [])), + len(make_list(anchor or []))) or (image.dim() - 2) + return jitfields.resize(image, factor=factor, shape=shape, ndim=ndim, + anchor=anchor, order=interpolation, + bound=kwargs['bound'], prefilter=prefilter) + + +def restrict(image, factor=None, shape=None, anchor='c', + interpolation=1, reduce_sum=False, **kwargs): + kwargs.setdefault('bound', 'nearest') + ndim = max(len(make_list(factor or [])), + len(make_list(shape or [])), + len(make_list(anchor or []))) or (image.dim() - 2) + return jitfields.restrict(image, factor=factor, shape=shape, ndim=ndim, + anchor=anchor, order=interpolation, + bound=kwargs['bound'], reduce_sum=reduce_sum) diff --git a/Generator/interpol/nd.py b/Generator/interpol/nd.py new file mode 100644 index 0000000000000000000000000000000000000000..1a366ff2e8ca3c07f15defb01ff6df9fa3990ed6 --- /dev/null +++ b/Generator/interpol/nd.py @@ -0,0 +1,464 @@ +"""Generic N-dimensional version: any combination of spline orders""" +import torch +from typing import List, Optional, Tuple +from .bounds import Bound +from .splines import Spline +from .jit_utils import 
sub2ind_list, make_sign, list_prod_int, cartesian_prod +Tensor = torch.Tensor + + +@torch.jit.script +def inbounds_mask(extrapolate: int, grid, shape: List[int])\ + -> Optional[Tensor]: + # mask of inbounds voxels + mask: Optional[Tensor] = None + if extrapolate in (0, 2): # no / hist + grid = grid.unsqueeze(1) + tiny = 5e-2 + threshold = tiny + if extrapolate == 2: + threshold = 0.5 + tiny + mask = torch.ones(grid.shape[:-1], + dtype=torch.bool, device=grid.device) + for grid1, shape1 in zip(grid.unbind(-1), shape): + mask = mask & (grid1 > -threshold) + mask = mask & (grid1 < shape1 - 1 + threshold) + return mask + return mask + + +@torch.jit.script +def get_weights(grid, bound: List[Bound], spline: List[Spline], + shape: List[int], grad: bool = False, hess: bool = False) \ + -> Tuple[List[List[Tensor]], + List[List[Optional[Tensor]]], + List[List[Optional[Tensor]]], + List[List[Tensor]], + List[List[Optional[Tensor]]]]: + + weights: List[List[Tensor]] = [] + grads: List[List[Optional[Tensor]]] = [] + hesss: List[List[Optional[Tensor]]] = [] + coords: List[List[Tensor]] = [] + signs: List[List[Optional[Tensor]]] = [] + for g, b, s, n in zip(grid.unbind(-1), bound, spline, shape): + grid0 = (g - (s.order-1)/2).floor() + dist0 = g - grid0 + grid0 = grid0.long() + nb_nodes = s.order + 1 + subweights: List[Tensor] = [] + subcoords: List[Tensor] = [] + subgrads: List[Optional[Tensor]] = [] + subhesss: List[Optional[Tensor]] = [] + subsigns: List[Optional[Tensor]] = [] + for node in range(nb_nodes): + grid1 = grid0 + node + sign1: Optional[Tensor] = b.transform(grid1, n) + subsigns.append(sign1) + grid1 = b.index(grid1, n) + subcoords.append(grid1) + dist1 = dist0 - node + weight1 = s.fastweight(dist1) + subweights.append(weight1) + grad1: Optional[Tensor] = None + if grad: + grad1 = s.fastgrad(dist1) + subgrads.append(grad1) + hess1: Optional[Tensor] = None + if hess: + hess1 = s.fasthess(dist1) + subhesss.append(hess1) + weights.append(subweights) + 
coords.append(subcoords) + signs.append(subsigns) + grads.append(subgrads) + hesss.append(subhesss) + + return weights, grads, hesss, coords, signs + + +@torch.jit.script +def pull(inp, grid, bound: List[Bound], spline: List[Spline], + extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + g: (B, *oshape, D) tensor + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape) tensor + """ + + dim = grid.shape[-1] + shape = list(inp.shape[-dim:]) + oshape = list(grid.shape[-dim-1:-1]) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, _, _, coords, signs = get_weights(grid, bound, spline, shape, False, False) + + # initialize + out = torch.zeros([batch, channel, grid.shape[1]], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + if dim == 1: + # cartesian_prod does not work as expected when only one + # element is provided + all_nodes = range_nodes[0].unsqueeze(-1) + else: + all_nodes = cartesian_prod(range_nodes) + for nodes in all_nodes: + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape).unsqueeze(1) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + + # apply sign + sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] + sign1: Optional[Tensor] = make_sign(sign0) + if sign1 is not None: + out1 = out1 * sign1.unsqueeze(1) + + # apply weights + for weight, n in zip(weights, nodes): + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out = out + out1 + + # out-of-bounds mask + if mask is not None: + out = out * mask + + out = out.reshape(list(out.shape[:2]) + oshape) 
+ return out + + +@torch.jit.script +def push(inp, grid, shape: Optional[List[int]], bound: List[Bound], + spline: List[Spline], extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + g: (B, *ishape, D) tensor + shape: List{D}[int], optional + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape) tensor + """ + + dim = grid.shape[-1] + ishape = list(grid.shape[-dim - 1:-1]) + if shape is None: + shape = ishape + shape = list(shape) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, _, _, coords, signs = get_weights(grid, bound, spline, shape) + + # initialize + out = torch.zeros([batch, channel, list_prod_int(shape)], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + if dim == 1: + # cartesian_prod does not work as expected when only one + # element is provided + all_nodes = range_nodes[0].unsqueeze(-1) + else: + all_nodes = cartesian_prod(range_nodes) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape).unsqueeze(1) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + + # apply sign + sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] + sign1: Optional[Tensor] = make_sign(sign0) + if sign1 is not None: + out1 = out1 * sign1.unsqueeze(1) + + # out-of-bounds mask + if mask is not None: + out1 = out1 * mask + + # apply weights + for weight, n in zip(weights, nodes): + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def 
grad(inp, grid, bound: List[Bound], spline: List[Spline], + extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + grid: (B, *oshape, D) tensor + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape, D) tensor + """ + + dim = grid.shape[-1] + shape = list(inp.shape[-dim:]) + oshape = list(grid.shape[-dim-1:-1]) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape, + grad=True) + + # initialize + out = torch.zeros([batch, channel, grid.shape[1], dim], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + if dim == 1: + # cartesian_prod does not work as expected when only one + # element is provided + all_nodes = range_nodes[0].unsqueeze(-1) + else: + all_nodes = cartesian_prod(range_nodes) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape).unsqueeze(1) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out0 = inp.gather(-1, idx) + + # apply sign + sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] + sign1: Optional[Tensor] = make_sign(sign0) + if sign1 is not None: + out0 = out0 * sign1.unsqueeze(1) + + for d in range(dim): + out1 = out0.clone() + # apply weights + for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)): + if d == dd: + grad11 = grad1[n] + if grad11 is not None: + out1 = out1 * grad11.unsqueeze(1) + else: + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out.unbind(-1)[d].add_(out1) + + # out-of-bounds mask + if mask is not None: + out = out * mask.unsqueeze(-1) + + 
out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:])) + return out + + +@torch.jit.script +def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound], + spline: List[Spline], extrapolate: int = 1): + """ + inp: (B, C, *ishape, D) tensor + g: (B, *ishape, D) tensor + shape: List{D}[int], optional + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *shape) tensor + """ + dim = grid.shape[-1] + oshape = list(grid.shape[-dim-1:-1]) + if shape is None: + shape = oshape + shape = list(shape) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1, dim]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape, grad=True) + + # initialize + out = torch.zeros([batch, channel, list_prod_int(shape)], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + if dim == 1: + # cartesian_prod does not work as expected when only one + # element is provided + all_nodes = range_nodes[0].unsqueeze(-1) + else: + all_nodes = cartesian_prod(range_nodes) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape).unsqueeze(1) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out0 = inp.clone() + + # apply sign + sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] + sign1: Optional[Tensor] = make_sign(sign0) + if sign1 is not None: + out0 = out0 * sign1.unsqueeze(1).unsqueeze(-1) + + # out-of-bounds mask + if mask is not None: + out0 = out0 * mask.unsqueeze(-1) + + for d in range(dim): + out1 = out0.unbind(-1)[d].clone() + # apply weights + for dd, (weight, grad1, n) in 
enumerate(zip(weights, grads, nodes)): + if d == dd: + grad11 = grad1[n] + if grad11 is not None: + out1 = out1 * grad11.unsqueeze(1) + else: + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def hess(inp, grid, bound: List[Bound], spline: List[Spline], + extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + grid: (B, *oshape, D) tensor + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape, D, D) tensor + """ + + dim = grid.shape[-1] + shape = list(inp.shape[-dim:]) + oshape = list(grid.shape[-dim-1:-1]) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, grads, hesss, coords, signs \ + = get_weights(grid, bound, spline, shape, grad=True, hess=True) + + # initialize + out = torch.zeros([batch, channel, grid.shape[1], dim, dim], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + if dim == 1: + # cartesian_prod does not work as expected when only one + # element is provided + all_nodes = range_nodes[0].unsqueeze(-1) + else: + all_nodes = cartesian_prod(range_nodes) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape).unsqueeze(1) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out0 = inp.gather(-1, idx) + + # apply sign + sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] + sign1: Optional[Tensor] = make_sign(sign0) + if sign1 is not None: + out0 = out0 * sign1.unsqueeze(1) + + for d in range(dim): + # -- diagonal -- + out1 = 
out0.clone() + + # apply weights + for dd, (weight, hess1, n) \ + in enumerate(zip(weights, hesss, nodes)): + if d == dd: + hess11 = hess1[n] + if hess11 is not None: + out1 = out1 * hess11.unsqueeze(1) + else: + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out.unbind(-1)[d].unbind(-1)[d].add_(out1) + + # -- off diagonal -- + for d2 in range(d+1, dim): + out1 = out0.clone() + + # apply weights + for dd, (weight, grad1, n) \ + in enumerate(zip(weights, grads, nodes)): + if dd in (d, d2): + grad11 = grad1[n] + if grad11 is not None: + out1 = out1 * grad11.unsqueeze(1) + else: + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out.unbind(-1)[d].unbind(-1)[d2].add_(out1) + + # out-of-bounds mask + if mask is not None: + out = out * mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + + # fill lower triangle + for d in range(dim): + for d2 in range(d+1, dim): + out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2]) + + out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:])) + return out diff --git a/Generator/interpol/pushpull.py b/Generator/interpol/pushpull.py new file mode 100644 index 0000000000000000000000000000000000000000..d37b2d3e4815f6544c0b5f2c37b40a9087196bbc --- /dev/null +++ b/Generator/interpol/pushpull.py @@ -0,0 +1,325 @@ +""" +Non-differentiable forward/backward components. +These components are put together in `interpol.autograd` to generate +differentiable functions. + +Note +---- +.. I removed @torch.jit.script from these entry-points because compiling + all possible combinations of bound+interpolation made the first call + extremely slow. +.. I am not using the dot/multi_dot helpers even though they should be + more efficient that "multiply and sum" because I haven't had the time + to test them. It would be worth doing it. 
+""" +import torch +from typing import List, Optional, Tuple +from .jit_utils import list_all, dot, dot_multi, pad_list_int +from .bounds import Bound +from .splines import Spline +from . import iso0, iso1, nd +Tensor = torch.Tensor + + +@torch.jit.script +def make_bound(bound: List[int]) -> List[Bound]: + return [Bound(b) for b in bound] + + +@torch.jit.script +def make_spline(spline: List[int]) -> List[Spline]: + return [Spline(s) for s in spline] + + +# @torch.jit.script +def grid_pull(inp, grid, bound: List[int], interpolation: List[int], + extrapolate: int): + """ + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_out) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.pull3d(inp, grid, bound_fn, extrapolate) + elif dim == 2: + return iso1.pull2d(inp, grid, bound_fn, extrapolate) + elif dim == 1: + return iso1.pull1d(inp, grid, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + if dim == 3: + return iso0.pull3d(inp, grid, bound_fn, extrapolate) + elif dim == 2: + return iso0.pull2d(inp, grid, bound_fn, extrapolate) + elif dim == 1: + return iso0.pull1d(inp, grid, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.pull(inp, grid, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_push(inp, grid, shape: Optional[List[int]], bound: List[int], + interpolation: List[int], extrapolate: int): + """ + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_in, D) tensor + shape: List{D}[int] tensor, optional, default=spatial_in + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *shape) tensor + """ + dim = 
grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.push3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return iso1.push2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso1.push1d(inp, grid, shape, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + if dim == 3: + return iso0.push3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return iso0.push2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso0.push1d(inp, grid, shape, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.push(inp, grid, shape, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_count(grid, shape: Optional[List[int]], bound: List[int], + interpolation: List[int], extrapolate: int): + """ + grid: (B, *spatial_in, D) tensor + shape: List{D}[int] tensor, optional, default=spatial_in + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, 1, *shape) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + gshape = list(grid.shape[-dim-1:-1]) + if shape is None: + shape = gshape + inp = torch.ones([], dtype=grid.dtype, device=grid.device) + inp = inp.expand([len(grid), 1] + gshape) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.push3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return iso1.push2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso1.push1d(inp, grid, shape, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + if dim == 3: + return iso0.push3d(inp, grid, shape, bound_fn, 
extrapolate) + elif dim == 2: + return iso0.push2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso0.push1d(inp, grid, shape, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.push(inp, grid, shape, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_grad(inp, grid, bound: List[int], interpolation: List[int], + extrapolate: int): + """ + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_out, D) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.grad3d(inp, grid, bound_fn, extrapolate) + elif dim == 2: + return iso1.grad2d(inp, grid, bound_fn, extrapolate) + elif dim == 1: + return iso1.grad1d(inp, grid, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + return iso0.grad(inp, grid, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.grad(inp, grid, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_pushgrad(inp, grid, shape: List[int], bound: List[int], + interpolation: List[int], extrapolate: int): + """ /!\ Used only in backward pass of grid_grad + inp: (B, C, *spatial_in, D) tensor + grid: (B, *spatial_in, D) tensor + shape: List{D}[int], optional + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *shape) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.pushgrad3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return 
iso1.pushgrad2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso1.pushgrad1d(inp, grid, shape, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + return iso0.pushgrad(inp, grid, shape, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.pushgrad(inp, grid, shape, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_hess(inp, grid, bound: List[int], interpolation: List[int], + extrapolate: int): + """ /!\ Used only in backward pass of grid_grad + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_out, D, D) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.hess3d(inp, grid, bound_fn, extrapolate) + if dim == 2: + return iso1.hess2d(inp, grid, bound_fn, extrapolate) + if dim == 1: + return iso1.hess1d(inp, grid, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + return iso0.hess(inp, grid, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.hess(inp, grid, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_pull_backward(grad, inp, grid, bound: List[int], + interpolation: List[int], extrapolate: int) \ + -> Tuple[Optional[Tensor], Optional[Tensor], ]: + """ + grad: (B, C, *spatial_out) tensor + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_in) tensor, (B, *spatial_out, D) + """ + dim = grid.shape[-1] + grad_inp: Optional[Tensor] = None + grad_grid: Optional[Tensor] = None + if inp.requires_grad: + grad_inp = 
grid_push(grad, grid, inp.shape[-dim:], bound, interpolation, extrapolate) + if grid.requires_grad: + grad_grid = grid_grad(inp, grid, bound, interpolation, extrapolate) + # grad_grid = dot(grad_grid, grad.unsqueeze(-1), dim=1) + grad_grid = (grad_grid * grad.unsqueeze(-1)).sum(dim=1) + return grad_inp, grad_grid + + +# @torch.jit.script +def grid_push_backward(grad, inp, grid, bound: List[int], + interpolation: List[int], extrapolate: int) \ + -> Tuple[Optional[Tensor], Optional[Tensor], ]: + """ + grad: (B, C, *spatial_out) tensor + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_in, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D) + """ + grad_inp: Optional[Tensor] = None + grad_grid: Optional[Tensor] = None + if inp.requires_grad: + grad_inp = grid_pull(grad, grid, bound, interpolation, extrapolate) + if grid.requires_grad: + grad_grid = grid_grad(grad, grid, bound, interpolation, extrapolate) + # grad_grid = dot(grad_grid, inp.unsqueeze(-1), dim=1) + grad_grid = (grad_grid * inp.unsqueeze(-1)).sum(dim=1) + return grad_inp, grad_grid + + +# @torch.jit.script +def grid_count_backward(grad, grid, bound: List[int], + interpolation: List[int], extrapolate: int) \ + -> Optional[Tensor]: + """ + grad: (B, C, *spatial_out) tensor + grid: (B, *spatial_in, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D) + """ + if grid.requires_grad: + return grid_grad(grad, grid, bound, interpolation, extrapolate).sum(1) + return None + + +# @torch.jit.script +def grid_grad_backward(grad, inp, grid, bound: List[int], + interpolation: List[int], extrapolate: int) \ + -> Tuple[Optional[Tensor], Optional[Tensor]]: + """ + grad: (B, C, *spatial_out, D) tensor + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + 
extrapolate: int + returns: (B, C, *spatial_in, D) tensor, (B, *spatial_out, D) + """ + dim = grid.shape[-1] + shape = inp.shape[-dim:] + grad_inp: Optional[Tensor] = None + grad_grid: Optional[Tensor] = None + if inp.requires_grad: + grad_inp = grid_pushgrad(grad, grid, shape, bound, interpolation, extrapolate) + if grid.requires_grad: + grad_grid = grid_hess(inp, grid, bound, interpolation, extrapolate) + # grad_grid = dot_multi(grad_grid, grad.unsqueeze(-1), dim=[1, -2]) + grad_grid = (grad_grid * grad.unsqueeze(-1)).sum(dim=[1, -2]) + return grad_inp, grad_grid diff --git a/Generator/interpol/resize.py b/Generator/interpol/resize.py new file mode 100644 index 0000000000000000000000000000000000000000..9b505624ef795437f516d578465f702b07a4d7ae --- /dev/null +++ b/Generator/interpol/resize.py @@ -0,0 +1,120 @@ +""" +Resize functions (equivalent to scipy's zoom, pytorch's interpolate) +based on grid_pull. +""" +__all__ = ['resize'] + +from .api import grid_pull +from .utils import make_list, meshgrid_ij +from . import backend, jitfields +import torch + + +def resize(image, factor=None, shape=None, anchor='c', + interpolation=1, prefilter=True, **kwargs): + """Resize an image by a factor or to a specific shape. + + Notes + ----- + .. A least one of `factor` and `shape` must be specified + .. If `anchor in ('centers', 'edges')`, exactly one of `factor` or + `shape must be specified. + .. If `anchor in ('first', 'last')`, `factor` must be provided even + if `shape` is specified. + .. Because of rounding, it is in general not assured that + `resize(resize(x, f), 1/f)` returns a tensor with the same shape as x. + + edges centers first last + e - + - + - e + - + - + - + + - + - + - + + - + - + - + + | . | . | . | | c | . | c | | f | . | . | | . | . | . | + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + | . | . | . | | . | . | . | | . | . | . | | . | . | . | + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + | . | . | . | | c | . | c | | . | . | . | | . 
| . | l | + e _ + _ + _ e + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + + Parameters + ---------- + image : (batch, channel, *inshape) tensor + Image to resize + factor : float or list[float], optional + Resizing factor + * > 1 : larger image <-> smaller voxels + * < 1 : smaller image <-> larger voxels + shape : (ndim,) list[int], optional + Output shape + anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers' + * In cases 'c' and 'e', the volume shape is multiplied by the + zoom factor (and eventually truncated), and two anchor points + are used to determine the voxel size. + * In cases 'f' and 'l', a single anchor point is used so that + the voxel size is exactly divided by the zoom factor. + This case with an integer factor corresponds to subslicing + the volume (e.g., `vol[::f, ::f, ::f]`). + * A list of anchors (one per dimension) can also be provided. + interpolation : int or sequence[int], default=1 + Interpolation order. + prefilter : bool, default=True + Apply spline pre-filter (= interpolates the input) + + Returns + ------- + resized : (batch, channel, *shape) tensor + Resized image + + """ + if backend.jitfields and jitfields.available: + return jitfields.resize(image, factor, shape, anchor, + interpolation, prefilter, **kwargs) + + factor = make_list(factor) if factor else [] + shape = make_list(shape) if shape else [] + anchor = make_list(anchor) + nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2) + anchor = [a[0].lower() for a in make_list(anchor, nb_dim)] + bck = dict(dtype=image.dtype, device=image.device) + + # compute output shape + inshape = image.shape[-nb_dim:] + if factor: + factor = make_list(factor, nb_dim) + elif not shape: + raise ValueError('One of `factor` or `shape` must be provided') + if shape: + shape = make_list(shape, nb_dim) + else: + shape = [int(i*f) for i, f in zip(inshape, factor)] + + if not factor: + factor = [o/i for o, i in zip(shape, inshape)] + + # compute transformation grid + lin 
def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Restrict (downsample) an image by a factor or to a specific shape.

    This is the adjoint of `resize`: values are scattered (`grid_push`)
    into the smaller output grid instead of being pulled from the input.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. If `anchor in ('centers', 'edges')`, exactly one of `factor` or
       `shape` must be specified.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not assured that
       `resize(restrict(x, f), f)` returns a tensor with the same shape as x.

    Parameters
    ----------
    image : (batch, channel, *inshape) tensor
        Image to restrict.
    factor : float or list[float], optional
        Restriction factor (output is roughly `inshape / factor`).
    shape : (ndim,) list[int], optional
        Output shape.
    anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
        Anchor point(s) used to align input and output grids; a list
        (one anchor per dimension) can also be provided.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    reduce_sum : bool, default=False
        Do not normalize by the number of accumulated values per voxel.

    Returns
    -------
    restricted : (batch, channel, *shape) tensor
        Restricted image.
    """
    # Fast path: delegate to the compiled jitfields implementation if present.
    if backend.jitfields and jitfields.available:
        return jitfields.restrict(image, factor, shape, anchor,
                                  interpolation, reduce_sum, **kwargs)

    factor = make_list(factor) if factor else []
    shape = make_list(shape) if shape else []
    anchor = make_list(anchor)
    # Number of spatial dims: longest of the option lists, else all but
    # the leading (batch, channel) dims of the image.
    nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
    # Keep only the first letter of each anchor ('centers' -> 'c', ...).
    anchor = [a[0].lower() for a in make_list(anchor, nb_dim)]
    bck = dict(dtype=image.dtype, device=image.device)

    # compute output shape
    inshape = image.shape[-nb_dim:]
    if factor:
        factor = make_list(factor, nb_dim)
    elif not shape:
        raise ValueError('One of `factor` or `shape` must be provided')
    if shape:
        shape = make_list(shape, nb_dim)
    else:
        shape = [int(i/f) for i, f in zip(inshape, factor)]

    if not factor:
        factor = [i/o for o, i in zip(shape, inshape)]

    # compute transformation grid
    # `lin` holds, per dimension, the output-space coordinate of each
    # input voxel; `fullscale` accumulates the product of per-dimension
    # scales, used to normalize the scattered sum below.
    lin = []
    fullscale = 1
    for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape):
        if anch == 'c':    # centers
            # NOTE(review): divides by (outshp - 1) / (inshp - 1); an output
            # or input size of 1 along this dimension divides by zero — confirm
            # callers never request that with anchor='centers'.
            lin.append(torch.linspace(0, outshp - 1, inshp, **bck))
            fullscale *= (inshp - 1) / (outshp - 1)
        elif anch == 'e':  # edges
            scale = outshp / inshp
            shift = 0.5 * (scale - 1)
            fullscale *= scale
            lin.append(torch.arange(0., inshp, **bck) * scale + shift)
        elif anch == 'f':  # first voxel
            # scale = 1/f
            # shift = 0
            fullscale *= 1/f
            lin.append(torch.arange(0., inshp, **bck) / f)
        elif anch == 'l':  # last voxel
            # scale = 1/f
            shift = (outshp - 1) - (inshp - 1) / f
            fullscale *= 1/f
            lin.append(torch.arange(0., inshp, **bck) / f + shift)
        else:
            raise ValueError('Unknown anchor {}'.format(anch))

    # scatter
    kwargs.setdefault('bound', 'nearest')
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('interpolation', interpolation)
    kwargs.setdefault('prefilter', False)
    grid = torch.stack(meshgrid_ij(*lin), dim=-1)
    resized = grid_push(image, grid, shape, **kwargs)
    if not reduce_sum:
        # Average instead of sum: divide by the number of input voxels
        # that map into each output voxel (product of scales).
        resized /= fullscale

    return resized
"""Weights and derivatives of spline orders 0 to 7."""
import torch
from enum import Enum
from .jit_utils import square, cube, pow4, pow5, pow6, pow7


class InterpolationType(Enum):
    # Numeric value is the polynomial order of the spline; aliases give
    # both the common name and the ordinal name.
    nearest = zeroth = 0
    linear = first = 1
    quadratic = second = 2
    cubic = third = 3
    fourth = 4
    fifth = 5
    sixth = 6
    seventh = 7


@torch.jit.script
class Spline:
    """B-spline basis of a given order.

    `weight`/`grad`/`hess` evaluate the basis function and its first two
    derivatives at (signed) distance `x` from a node, zeroing everything
    outside the support `|x| < (order + 1) / 2`. The `fast*` variants skip
    that out-of-support masking and assume the caller only passes
    in-support distances. Closed-form piecewise polynomials are hard-coded
    per order; do not alter the constants.
    """

    def __init__(self, order: int = 1):
        self.order = order

    def weight(self, x):
        """Basis weight at distance `x`, zero outside the support."""
        w = self.fastweight(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        w = torch.where(x.abs() >= (self.order + 1)/2, zero, w)
        return w

    def fastweight(self, x):
        """Basis weight at `x`, assuming `|x| < (order + 1) / 2`."""
        if self.order == 0:
            return torch.ones(x.shape, dtype=x.dtype, device=x.device)
        x = x.abs()  # all bases are even functions
        if self.order == 1:
            return 1 - x
        if self.order == 2:
            x_low = 0.75 - square(x)
            x_up = 0.5 * square(1.5 - x)
            return torch.where(x < 0.5, x_low, x_up)
        if self.order == 3:
            x_low = (x * x * (x - 2.) * 3. + 4.) / 6.
            x_up = cube(2. - x) / 6.
            return torch.where(x < 1., x_low, x_up)
        if self.order == 4:
            x_low = square(x)
            x_low = x_low * (x_low * 0.25 - 0.625) + 115. / 192.
            x_mid = x * (x * (x * (5. - x) / 6. - 1.25) + 5./24.) + 55./96.
            x_up = pow4(x - 2.5) / 24.
            return torch.where(x < 0.5, x_low, torch.where(x < 1.5, x_mid, x_up))
        if self.order == 5:
            x_low = square(x)
            x_low = x_low * (x_low * (0.25 - x / 12.) - 0.5) + 0.55
            x_mid = x * (x * (x * (x * (x / 24. - 0.375) + 1.25) - 1.75) + 0.625) + 0.425
            x_up = pow5(3 - x) / 120.
            return torch.where(x < 1., x_low, torch.where(x < 2., x_mid, x_up))
        if self.order == 6:
            x_low = square(x)
            x_low = x_low * (x_low * (7./48. - x_low/36.) - 77./192.) + 5887./11520.
            x_mid_low = (x * (x * (x * (x * (x * (x / 48. - 7./48.) + 0.328125)
                         - 35./288.) - 91./256.) - 7./768.) + 7861./15360.)
            x_mid_up = (x * (x * (x * (x * (x * (7./60. - x / 120.) - 0.65625)
                        + 133./72.) - 2.5703125) + 1267./960.) + 1379./7680.)
            x_up = pow6(x - 3.5) / 720.
            return torch.where(x < .5, x_low,
                   torch.where(x < 1.5, x_mid_low,
                   torch.where(x < 2.5, x_mid_up, x_up)))
        if self.order == 7:
            x_low = square(x)
            x_low = (x_low * (x_low * (x_low * (x / 144. - 1./36.)
                     + 1./9.) - 1./3.) + 151./315.)
            x_mid_low = (x * (x * (x * (x * (x * (x * (0.05 - x/240.) - 7./30.)
                         + 0.5) - 7./18.) - 0.1) - 7./90.) + 103./210.)
            x_mid_up = (x * (x * (x * (x * (x * (x * (x / 720. - 1./36.)
                        + 7./30.) - 19./18.) + 49./18.) - 23./6.) + 217./90.)
                        - 139./630.)
            x_up = pow7(4 - x) / 5040.
            return torch.where(x < 1., x_low,
                   torch.where(x < 2., x_mid_low,
                   torch.where(x < 3., x_mid_up, x_up)))
        raise NotImplementedError

    def grad(self, x):
        """First derivative at `x`, zero outside the support."""
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        g = self.fastgrad(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        g = torch.where(x.abs() >= (self.order + 1)/2, zero, g)
        return g

    def fastgrad(self, x):
        """First derivative at `x`, assuming `|x| < (order + 1) / 2`."""
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        # Derivative of an even function is odd: evaluate on |x| and
        # restore the sign.
        return self._fastgrad(x.abs()).mul(x.sign())

    def _fastgrad(self, x):
        # Derivative polynomials for non-negative x only.
        if self.order == 1:
            return torch.ones(x.shape, dtype=x.dtype, device=x.device)
        if self.order == 2:
            return torch.where(x < 0.5, -2*x, x - 1.5)
        if self.order == 3:
            g_low = x * (x * 1.5 - 2)
            g_up = -0.5 * square(2 - x)
            return torch.where(x < 1, g_low, g_up)
        if self.order == 4:
            g_low = x * (square(x) - 1.25)
            g_mid = x * (x * (x * (-2./3.) + 2.5) - 2.5) + 5./24.
            g_up = cube(2. * x - 5.) / 48.
            return torch.where(x < 0.5, g_low,
                   torch.where(x < 1.5, g_mid, g_up))
        if self.order == 5:
            g_low = x * (x * (x * (x * (-5./12.) + 1.)) - 1.)
            g_mid = x * (x * (x * (x * (5./24.) - 1.5) + 3.75) - 3.5) + 0.625
            g_up = pow4(x - 3.) / (-24.)
            return torch.where(x < 1, g_low,
                   torch.where(x < 2, g_mid, g_up))
        if self.order == 6:
            g_low = square(x)
            g_low = x * (g_low * (7./12.) - square(g_low) / 6. - 77./96.)
            g_mid_low = (x * (x * (x * (x * (x * 0.125 - 35./48.) + 1.3125)
                         - 35./96.) - 0.7109375) - 7./768.)
            g_mid_up = (x * (x * (x * (x * (x / (-20.) + 7./12.) - 2.625)
                        + 133./24.) - 5.140625) + 1267./960.)
            g_up = pow5(2*x - 7) / 3840.
            return torch.where(x < 0.5, g_low,
                   torch.where(x < 1.5, g_mid_low,
                   torch.where(x < 2.5, g_mid_up,
                               g_up)))
        if self.order == 7:
            g_low = square(x)
            g_low = x * (g_low * (g_low * (x * (7./144.) - 1./6.) + 4./9.) - 2./3.)
            g_mid_low = (x * (x * (x * (x * (x * (x * (-7./240.) + 3./10.)
                         - 7./6.) + 2.) - 7./6.) - 1./5.) - 7./90.)
            g_mid_up = (x * (x * (x * (x * (x * (x * (7./720.) - 1./6.)
                        + 7./6.) - 38./9.) + 49./6.) - 23./3.) + 217./90.)
            g_up = pow6(x - 4) / (-720.)
            return torch.where(x < 1, g_low,
                   torch.where(x < 2, g_mid_low,
                   torch.where(x < 3, g_mid_up, g_up)))
        raise NotImplementedError

    def hess(self, x):
        """Second derivative at `x`, zero outside the support."""
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        h = self.fasthess(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        h = torch.where(x.abs() >= (self.order + 1)/2, zero, h)
        return h

    def fasthess(self, x):
        """Second derivative at `x`, assuming `|x| < (order + 1) / 2`."""
        if self.order in (0, 1):
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        x = x.abs()  # second derivative of an even function is even
        if self.order == 2:
            one = torch.ones([1], dtype=x.dtype, device=x.device)
            return torch.where(x < 0.5, -2 * one, one)
        if self.order == 3:
            return torch.where(x < 1, 3. * x - 2., 2. - x)
        if self.order == 4:
            return torch.where(x < 0.5, 3. * square(x) - 1.25,
                   torch.where(x < 1.5, x * (-2. * x + 5.) - 2.5,
                               square(2. * x - 5.) / 8.))
        if self.order == 5:
            h_low = square(x)
            h_low = - h_low * (x * (5./3.) - 3.) - 1.
            h_mid = x * (x * (x * (5./6.) - 9./2.) + 15./2.) - 7./2.
            h_up = 9./2. - x * (x * (x/6. - 3./2.) + 9./2.)
            return torch.where(x < 1, h_low,
                   torch.where(x < 2, h_mid, h_up))
        if self.order == 6:
            h_low = square(x)
            h_low = - h_low * (h_low * (5./6) - 7./4.) - 77./96.
            h_mid_low = (x * (x * (x * (x * (5./8.) - 35./12.) + 63./16.)
                         - 35./48.) - 91./128.)
            h_mid_up = -(x * (x * (x * (x/4. - 7./3.) + 63./8.) - 133./12.)
                         + 329./64.)
            h_up = (x * (x * (x * (x/24. - 7./12.) + 49./16.) - 343./48.)
                    + 2401./384.)
            return torch.where(x < 0.5, h_low,
                   torch.where(x < 1.5, h_mid_low,
                   torch.where(x < 2.5, h_mid_up,
                               h_up)))
        if self.order == 7:
            h_low = square(x)
            h_low = h_low * (h_low*(x * (7./24.) - 5./6.) + 4./3.) - 2./3.
            h_mid_low = - (x * (x * (x * (x * (x * (7./40.) - 3./2.) + 14./3.)
                           - 6.) + 7./3.) + 1./5.)
            h_mid_up = (x * (x * (x * (x * (x * (7./120.) - 5./6.) + 14./3.)
                        - 38./3.) + 49./3.) - 23./3.)
            h_up = - (x * (x * (x * (x * (x/120. - 1./6.) + 4./3.) - 16./3.)
                      + 32./3.) - 128./15.)
            return torch.where(x < 1, h_low,
                   torch.where(x < 2, h_mid_low,
                   torch.where(x < 3, h_mid_up,
                               h_up)))
        raise NotImplementedError
def init_device(device):
    """Configure and return a torch.device from a test spec.

    `device` is either a string ('cpu' or 'cuda') or a (name, param) pair
    where `param` is the thread count for CPU and the device index for
    CUDA. Defaults: 1 thread for CPU, index 0 for CUDA.
    """
    if isinstance(device, (list, tuple)):
        name, param = device
    else:
        name = device
        param = 1 if name == 'cpu' else 0

    if name == 'cuda':
        torch.cuda.set_device(param)
        torch.cuda.init()
        # Cache clearing may fail on some drivers; best-effort only.
        try:
            torch.cuda.empty_cache()
        except RuntimeError:
            pass
        return torch.device('{}:{}'.format(name, param))

    assert name == 'cpu'
    torch.set_num_threads(param)
    return torch.device(name)
def make_list(x, n=None, **kwargs):
    """Ensure that the input is a list (of a given minimum size).

    Parameters
    ----------
    x : list or tuple or scalar
        Input object
    n : int, optional
        Required length; the list is right-padded (never truncated).
    default : scalar, optional
        Value to right-pad with. Uses the last value of the input by default.

    Returns
    -------
    x : list

    Raises
    ------
    IndexError
        If `x` is empty, padding is requested, and no explicit
        `default` is provided.
    """
    if not isinstance(x, (list, tuple)):
        x = [x]
    x = list(x)
    if n and len(x) < n:
        # BUG FIX: the original used `kwargs.get('default', x[-1])`, which
        # evaluates `x[-1]` eagerly and raised IndexError on an empty input
        # even when an explicit `default=` was supplied. Look it up lazily.
        default = kwargs['default'] if 'default' in kwargs else x[-1]
        x = x + [default] * max(0, n - len(x))
    return x
def expanded_shape(*shapes, side='left'):
    """Expand input shapes according to broadcasting rules.

    Parameters
    ----------
    *shapes : sequence[int]
        Input shapes
    side : {'left', 'right'}, default='left'
        Side to add singleton dimensions.

    Returns
    -------
    shape : tuple[int]
        Output (broadcast) shape

    Raises
    ------
    ValueError
        If shapes are not compatible for broadcast.
    """
    def _incompatible(a, b):
        raise ValueError('Incompatible shapes for broadcasting: {} and {}.'
                         .format(a, b))

    # Target rank is the longest input shape (0 if no shapes given).
    ndim = max((len(s) for s in shapes), default=0)

    out = [1] * ndim
    for cur in shapes:
        # Pad the shorter shape with singletons on the requested side.
        ones = [1] * (ndim - len(cur))
        padded = [*ones, *cur] if side == 'left' else [*cur, *ones]
        merged = []
        for a, b in zip(out, padded):
            if a == 1 or b == 1 or a == b:
                merged.append(max(a, b))
            else:
                _incompatible(a, b)
        out = merged

    return tuple(out)
+ + Returns + ------- + mv : (..., M) tensor + Matrix vector product of the inputs + + """ + vec = vec[..., None] + if out is not None: + out = out[..., None] + + mv = torch.matmul(mat, vec, out=out) + mv = mv[..., 0] + if out is not None: + out = out[..., 0] + + return mv + + +def _compare_versions(version1, mode, version2): + for v1, v2 in zip(version1, version2): + if mode in ('gt', '>'): + if v1 > v2: + return True + elif v1 < v2: + return False + elif mode in ('ge', '>='): + if v1 > v2: + return True + elif v1 < v2: + return False + elif mode in ('lt', '<'): + if v1 < v2: + return True + elif v1 > v2: + return False + elif mode in ('le', '<='): + if v1 < v2: + return True + elif v1 > v2: + return False + if mode in ('gt', 'lt', '>', '<'): + return False + else: + return True + + +def torch_version(mode, version): + """Check torch version + + Parameters + ---------- + mode : {'<', '<=', '>', '>='} + version : tuple[int] + + Returns + ------- + True if "torch.version version" + + """ + current_version, *cuda_variant = torch.__version__.split('+') + major, minor, patch, *_ = current_version.split('.') + # strip alpha tags + for x in 'abcdefghijklmnopqrstuvwxy': + if x in patch: + patch = patch[:patch.index(x)] + current_version = (int(major), int(minor), int(patch)) + version = make_list(version) + return _compare_versions(current_version, mode, version) + + +if torch_version('>=', (1, 10)): + meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij') + meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy') +else: + meshgrid_ij = lambda *x: torch.meshgrid(*x) + def meshgrid_xy(*x): + grid = list(torch.meshgrid(*x)) + if len(grid) > 1: + grid[0] = grid[0].transpose(0, 1) + grid[1] = grid[1].transpose(0, 1) + return grid + + +meshgrid = meshgrid_ij \ No newline at end of file diff --git a/Generator/utils.py b/Generator/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd365a3bce5c7878fd5fabd56ba97f15bc9fa36 --- /dev/null +++ 
def resolution_sampler(low_res_only = False):
    """Draw a random (resolution, thickness) pair in mm.

    Four regimes, each with probability 1/4: isotropic 1 mm, clinical
    (one low-resolution axis), low-field stock axial, and low-field
    isotropic-ish. With `low_res_only`, the draw is restricted to the two
    low-field regimes. Returns two (3,) numpy arrays.
    """
    # Single uniform draw selects the regime; low_res_only maps it
    # into [0.5, 1] so only the last two regimes can be chosen.
    u = np.random.rand()
    if low_res_only:
        u = 0.5 + 0.5 * u

    if u < 0.25:
        # 1mm isotropic
        resolution = np.ones(3)
        thickness = np.ones(3)
    elif u < 0.5:
        # clinical: low resolution along one random axis
        resolution = np.ones(3)
        thickness = np.ones(3)
        axis = np.random.randint(3)
        resolution[axis] = 2.5 + 6 * np.random.rand()
        thickness[axis] = min(resolution[axis], 4.0 + 2.0 * np.random.rand())
    elif u < 0.75:
        # low-field: stock sequences (always axial)
        resolution = np.array([1.3, 1.3, 4.8]) + 0.4 * np.random.rand(3)
        thickness = resolution.copy()
    else:
        # low-field: isotropic-ish (also good for scouts)
        resolution = 2.0 + 3.0 * np.random.rand(3)
        thickness = resolution.copy()

    return resolution, thickness
def make_affine_matrix(rot, sh, s):
    """Compose a 3x3 affine matrix from rotation, shear and scaling.

    Applies, right to left: rotations about x, y, z, then shears, then
    per-axis scaling of the rows: A = SHx @ SHy @ SHz @ Rx @ Ry @ Rz,
    with row i multiplied by s[i].

    Parameters
    ----------
    rot : (3,) rotation angles in radians (about x, y, z)
    sh : (3,) shear coefficients
    s : (3,) per-axis scaling factors
    """
    cx, sx = np.cos(rot[0]), np.sin(rot[0])
    cy, sy = np.cos(rot[1]), np.sin(rot[1])
    cz, sz = np.cos(rot[2]), np.sin(rot[2])

    rot_x = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    rot_y = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    rot_z = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])

    shear_x = np.array([[1, 0, 0], [sh[1], 1, 0], [sh[2], 0, 1]])
    shear_y = np.array([[1, sh[0], 0], [0, 1, 0], [0, sh[2], 1]])
    shear_z = np.array([[1, 0, sh[0]], [0, 1, sh[1]], [0, 0, 1]])

    A = shear_x @ shear_y @ shear_z @ rot_x @ rot_y @ rot_z
    # Scale row i by s[i] (equivalent to A[i, :] *= s[i]).
    A *= np.asarray(s)[:, None]

    return A
def myzoom_torch(X, factor, aff=None):
    """Resample a 3D (or 3D multi-channel) volume by per-axis factors.

    Separable linear interpolation: the volume is interpolated one axis
    at a time (x, then y, then z). Coordinates are placed so that voxel
    *edges* stay aligned between input and output grids.

    Parameters
    ----------
    X : (x, y, z) or (x, y, z, channels) tensor
        Input volume.
    factor : (3,) array-like
        Zoom factor per axis (> 1 upsamples).
    aff : (4, 4) ndarray, optional
        Voxel-to-world affine of `X`; if given, the updated affine of the
        zoomed volume is also returned.

    Returns
    -------
    Y : tensor (channel axis squeezed if singleton)
    aff_new : ndarray, only if `aff` is not None
    """

    # Work with an explicit channel axis throughout.
    if len(X.shape)==3:
        X = X[..., None]

    # Half-voxel offset that keeps edges aligned after zooming.
    delta = (1.0 - factor) / (2.0 * factor)
    newsize = np.round(X.shape[:-1] * factor).astype(int)

    # Input-space coordinate of each output voxel along each axis
    # (truncated to exactly `newsize` samples).
    vx = torch.arange(delta[0], delta[0] + newsize[0] / factor[0], 1 / factor[0], dtype=torch.float, device=X.device)[:newsize[0]]
    vy = torch.arange(delta[1], delta[1] + newsize[1] / factor[1], 1 / factor[1], dtype=torch.float, device=X.device)[:newsize[1]]
    vz = torch.arange(delta[2], delta[2] + newsize[2] / factor[2], 1 / factor[2], dtype=torch.float, device=X.device)[:newsize[2]]

    # Clamp coordinates to the valid input range (border replication).
    vx[vx < 0] = 0
    vy[vy < 0] = 0
    vz[vz < 0] = 0
    vx[vx > (X.shape[0]-1)] = (X.shape[0]-1)
    vy[vy > (X.shape[1] - 1)] = (X.shape[1] - 1)
    vz[vz > (X.shape[2] - 1)] = (X.shape[2] - 1)

    # Floor/ceil indices and linear weights per axis.
    fx = torch.floor(vx).int()
    cx = fx + 1
    cx[cx > (X.shape[0]-1)] = (X.shape[0]-1)
    wcx = (vx - fx)
    wfx = 1 - wcx

    fy = torch.floor(vy).int()
    cy = fy + 1
    cy[cy > (X.shape[1]-1)] = (X.shape[1]-1)
    wcy = (vy - fy)
    wfy = 1 - wcy

    fz = torch.floor(vz).int()
    cz = fz + 1
    cz[cz > (X.shape[2]-1)] = (X.shape[2]-1)
    wcz = (vz - fz)
    wfz = 1 - wcz

    Y = torch.zeros([newsize[0], newsize[1], newsize[2], X.shape[3]], dtype=torch.float, device=X.device)

    # Separable interpolation: resample x, then y, then z.
    tmp1 = torch.zeros([newsize[0], X.shape[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device)
    for i in range(newsize[0]):
        tmp1[i, :, :] = wfx[i] * X[fx[i], :, :] + wcx[i] * X[cx[i], :, :]
    tmp2 = torch.zeros([newsize[0], newsize[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device)
    for j in range(newsize[1]):
        tmp2[:, j, :] = wfy[j] * tmp1[:, fy[j], :] + wcy[j] * tmp1[:, cy[j], :]
    for k in range(newsize[2]):
        Y[:, :, k] = wfz[k] * tmp2[:, :, fz[k]] + wcz[k] * tmp2[:, :, cz[k]]

    # Drop a singleton channel axis again.
    if Y.shape[3] == 1:
        Y = Y[:,:,:, 0]

    if aff is not None:
        # Update the affine: voxel size shrinks by `factor`, and the
        # translation is shifted by the half-voxel edge offset.
        aff_new = aff.copy()
        aff_new[:-1] = aff_new[:-1] / factor
        aff_new[:-1, -1] = aff_new[:-1, -1] - aff[:-1, :-1] @ (0.5 - 0.5 / (factor * np.ones(3)))
        return Y, aff_new
    else:
        return Y
def deform_image(I, deform_dict, device, default_value_linear_mode=None, deform_mode='linear'):
    """Crop an image to the deformation bounding box and resample it.

    Parameters
    ----------
    I : nibabel image, torch.Tensor, or None
        Input volume; returned unchanged when None.
    deform_dict : dict
        Must contain 'grid' = [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2]:
        sampling coordinates plus crop bounds.
    device : torch device for the computation.
    default_value_linear_mode : {'max', None}
        Fill value policy for out-of-bounds samples in linear mode.
    deform_mode : {'linear', 'nearest'}
        Interpolation mode passed to `fast_3D_interp_torch`.

    Returns
    -------
    Deformed tensor (or None if `I` was None).
    """
    if I is None:
        return I

    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    if not isinstance(I, torch.Tensor):
        # nibabel image: read and crop the voxel data.
        I = torch.squeeze(torch.tensor(
            I.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float),
            dtype=torch.float, device=device))
    else:
        # BUG FIX: tensors have no `.astype`, and `torch.squeeze` accepts no
        # dtype/device keywords — the original branch raised at runtime.
        # Convert with `.to(...)` instead.
        I = torch.squeeze(I[x1:x2, y1:y2, z1:z2].to(dtype=torch.float, device=device))
    I = torch.nan_to_num(I)

    if default_value_linear_mode is not None:
        if default_value_linear_mode == 'max':
            default_value_linear = torch.max(I)
        else:
            raise ValueError('Not support default_value_linear_mode:', default_value_linear_mode)
    else:
        default_value_linear = 0.

    Idef = fast_3D_interp_torch(I, xx2, yy2, zz2, deform_mode, default_value_linear)

    return Idef
def read_and_deform(file_name, dtype, deform_dict, device, mask, default_value_linear_mode=None, deform_mode='linear', mean=0., scale=1.):
    """Load a NIfTI volume, normalize, mask and deform it.

    Parameters
    ----------
    file_name : str
        Path to the image; if loading fails, `file_name + '.gz'` is tried.
    dtype : torch dtype for the loaded data.
    deform_dict : dict with 'grid' = [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2].
    device : torch device.
    mask : tensor or None
        Voxels where `mask == 0` are zeroed before deformation.
    default_value_linear_mode : {'max', None}
        Fill value policy for out-of-bounds samples in linear mode.
    deform_mode : {'linear', 'nearest'}
    mean, scale : float
        The data is normalized as `(I - mean) / scale` before masking.

    Returns
    -------
    (Idef, res) : deformed tensor and the voxel resolution of the source.
    """
    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    try:
        Iimg = nib.load(file_name)
    except Exception:
        # Fall back to the gzipped sibling. Narrowed from a bare `except:`
        # so KeyboardInterrupt/SystemExit are no longer swallowed.
        Iimg = nib.load(file_name + '.gz')
    # Voxel sizes from the affine column norms.
    res = np.sqrt(np.sum(abs(Iimg.affine[:-1, :-1]), axis=0))
    I = torch.squeeze(torch.tensor(Iimg.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=dtype, device=device))
    I = torch.nan_to_num(I)

    I -= mean
    I /= scale

    if mask is not None:
        I[mask == 0] = 0

    if default_value_linear_mode is not None:
        if default_value_linear_mode == 'max':
            default_value_linear = torch.max(I)
        else:
            raise ValueError('Not support default_value_linear_mode:', default_value_linear_mode)
    else:
        default_value_linear = 0.

    Idef = fast_3D_interp_torch(I, xx2, yy2, zz2, deform_mode, default_value_linear)
    return Idef, res
def read_and_deform_image(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
    """Load, deform and min-max normalize an intensity image.

    Also loads the optional `<name>.defacingmask.nii` sibling when it
    exists, storing it under `task_name + '_DM'`.

    Returns
    -------
    dict : {task_name: (1, *shape) tensor[, task_name + '_DM': ...]}
    """
    Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask)
    # Min-max normalize to [0, 1].
    Idef -= torch.min(Idef)
    Idef /= torch.max(Idef)
    if setups['flip']:
        Idef = torch.flip(Idef, [0])
    update_dict = {task_name: Idef[None]}

    if os.path.isfile(file_name[:-4] + '.defacingmask.nii'):
        Idef_DM, _ = read_and_deform(file_name[:-4] + '.defacingmask.nii', torch.float, deform_dict, device, mask)
        Idef_DM = torch.clamp(Idef_DM, min=0.)
        Idef_DM /= torch.max(Idef_DM)
        if setups['flip']:
            # BUG FIX: the flip result was assigned to `Idef` (already stored
            # in update_dict above), so the defacing mask was stored unflipped.
            Idef_DM = torch.flip(Idef_DM, [0])
        update_dict.update({task_name + '_DM': Idef_DM[None]})

    return update_dict
def read_and_deform_distance(exist_keys, task_name, file_names, setups, deform_dict, device, mask, cfg, **kwargs):
    """Load, deform and scale the four cortical-surface distance maps.

    `file_names` holds the (left pial, left white, right pial, right white)
    distance maps. Each map is normalized as (I - 128) / 20 at load time,
    then divided by the deformation's distance scaling factor and clamped
    to +/- cfg.max_surf_distance.

    Returns
    -------
    dict : {'distance': (2 or 4, *shape) tensor}
        Two channels when `mask` is given (left hemisphere only),
        otherwise four (lp, lw, rp, rw).
    """
    [lp_dist_map, lw_dist_map, rp_dist_map, rw_dist_map] = file_names


    # Out-of-bounds voxels get the map's max (i.e. "far from surface").
    lp, _ = read_and_deform(lp_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)
    lw, _ = read_and_deform(lw_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)

    if mask is not None: # left_hemis_only
        Idef = torch.stack([lp, lw], dim = 0)
    else:
        rp, _ = read_and_deform(rp_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)
        rw, _ = read_and_deform(rw_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)

        if setups['flip']:
            # A left-right flip swaps hemispheres: flip each map along x
            # and exchange the left/right channels.
            aux = torch.flip(lp, [0])
            lp = torch.flip(rp, [0])
            rp = aux
            aux = torch.flip(lw, [0])
            lw = torch.flip(rw, [0])
            rw = aux

        Idef = torch.stack([lp, lw, rp, rw], dim = 0)

    # Distances shrink/grow with the spatial deformation; compensate.
    Idef /= deform_dict['scaling_factor_distances']
    Idef = torch.clamp(Idef, min=-cfg.max_surf_distance, max=cfg.max_surf_distance)

    return {'distance': Idef}
Sdef_OneHot = torch.flip(Sdef_OneHot, [0])[:, :, :, vflip] + + # prepare for input + Sdef_OneHot = Sdef_OneHot.permute([3, 0, 1, 2]) + + #update_dict = {'label': Sdef[None], 'segmentation': Sdef_OneHot} + update_dict = {'segmentation': Sdef_OneHot} + + #if not 'brain_mask' in exist_keys: + # mask = torch.ones_like(Sdef) + # mask[Sdef <= 0.] = 0. + # update_dict.update({'brain_mask': mask[None]}) + return update_dict + + + +def read_and_deform_pathology(exist_keys, task_name, file_name, setups, deform_dict, device, mask = None, + augment = False, pde_func = None, t = None, + shape_gen_args = None, thres = 0., **kwargs): + # NOTE does not support left_hemis for now + + [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid'] + + if file_name is None: + return {'pathology': torch.zeros(xx2.shape)[None].to(device), 'pathology_prob': torch.zeros(xx2.shape)[None].to(device)} + + if file_name == 'random_shape': # generate random shape + percentile = np.random.uniform(shape_gen_args.mask_percentile_min, shape_gen_args.mask_percentile_max) + _, Pdef = generate_shape_3d(xx2.shape, shape_gen_args.perlin_res, percentile, device) + else: # read from existing shape + Pdef, _ = read_and_deform(file_name, torch.float, deform_dict, device) + + if augment: + Pdef = augment_pathology(Pdef, pde_func, t, shape_gen_args, device) + + #if setups['flip']: # flipping should happen after P has been encoded + # Pdef = torch.flip(Pdef, [0]) + + P = binarize(Pdef, thres) + if P.mean() <= shape_gen_args.pathol_tol: + return {'pathology': torch.zeros(xx2.shape)[None].to(device), 'pathology_prob': torch.zeros(xx2.shape)[None].to(device)} + #print('process', P.mean(), shape_gen_args.pathol_tol) + + return {'pathology': P[None], 'pathology_prob': Pdef[None]} + + +def read_and_deform_registration(exist_keys, task_name, file_names, setups, deform_dict, device, mask, **kwargs): + [mni_reg_x, mni_reg_y, mni_reg_z] = file_names + regx, _ = read_and_deform(mni_reg_x, torch.float, deform_dict, device, 
mask, scale = 10000) + regy, _ = read_and_deform(mni_reg_y, torch.float, deform_dict, device, mask, scale = 10000) + regz, _ = read_and_deform(mni_reg_z, torch.float, deform_dict, device, mask, scale = 10000) + + if setups['flip']: + regx = -torch.flip(regx, [0]) # NOTE: careful with switching sign + regy = torch.flip(regy, [0]) + regz = torch.flip(regz, [0]) + + Idef = torch.stack([regx, regy, regz], dim = 0) + + return {'registration': Idef} + +def read_and_deform_bias_field(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs): + Idef, _ = read_and_deform(file_name, torch.float, deform_dict, mask, device) + if setups['flip']: + Idef = torch.flip(Idef, [0]) + return {'bias_field': Idef[None]} + +def read_and_deform_surface(exist_keys, task_name, file_name, setups, deform_dict, device, mask, size): + Fneg, A, c2 = deform_dict['Fneg'], deform_dict['A'], deform_dict['c2'] + # NOTE does not support left_hemis for now + + mat = loadmat(file_name.split('.nii')[0] + '.mat') + + Vlw = torch.tensor(mat['Vlw'], dtype=torch.float, device=device) + Flw = torch.tensor(mat['Flw'], dtype=torch.int, device=device) + Vrw = torch.tensor(mat['Vrw'], dtype=torch.float, device=device) + Frw = torch.tensor(mat['Frw'], dtype=torch.int, device=device) + Vlp = torch.tensor(mat['Vlp'], dtype=torch.float, device=device) + Flp = torch.tensor(mat['Flp'], dtype=torch.int, device=device) + Vrp = torch.tensor(mat['Vrp'], dtype=torch.float, device=device) + Frp = torch.tensor(mat['Frp'], dtype=torch.int, device=device) + + Ainv = torch.inverse(A) + Vlw -= c2[None, :] + Vlw = Vlw @ torch.transpose(Ainv, 0, 1) + Vlw += fast_3D_interp_torch(Fneg, Vlw[:, 0] + c2[0], Vlw[:, 1]+c2[1], Vlw[:, 2] + c2[2]) + Vlw += c2[None, :] + Vrw -= c2[None, :] + Vrw = Vrw @ torch.transpose(Ainv, 0, 1) + Vrw += fast_3D_interp_torch(Fneg, Vrw[:, 0] + c2[0], Vrw[:, 1]+c2[1], Vrw[:, 2] + c2[2]) + Vrw += c2[None, :] + Vlp -= c2[None, :] + Vlp = Vlp @ torch.transpose(Ainv, 0, 1) + Vlp += 
fast_3D_interp_torch(Fneg, Vlp[:, 0] + c2[0], Vlp[:, 1] + c2[1], Vlp[:, 2] + c2[2]) + Vlp += c2[None, :] + Vrp -= c2[None, :] + Vrp = Vrp @ torch.transpose(Ainv, 0, 1) + Vrp += fast_3D_interp_torch(Fneg, Vrp[:, 0] + c2[0], Vrp[:, 1] + c2[1], Vrp[:, 2] + c2[2]) + Vrp += c2[None, :] + + if setups['flip']: + Vlw[:, 0] = size[0] - 1 - Vlw[:, 0] + Vrw[:, 0] = size[0] - 1 - Vrw[:, 0] + Vlp[:, 0] = size[0] - 1 - Vlp[:, 0] + Vrp[:, 0] = size[0] - 1 - Vrp[:, 0] + Vlw, Vrw = Vrw, Vlw + Vlp, Vrp = Vrp, Vlp + Flw, Frw = Frw, Flw + Flp, Frp = Frp, Flp + + print(Vlw.shape) # 131148 + print(Vlp.shape) # 131148 + + print(Vrw.shape) # 131720 + print(Vrp.shape) # 131720 + + print(Flw.shape) # 262292 + print(Flp.shape) # 262292 + + print(Frw.shape) # 263436 + print(Frp.shape) # 263436 + #return torch.stack([Vlw, Flw, Vrw, Frw, Vlp, Flp, Vrp, Frp]) + return {'Vlw': Vlw, 'Flw': Flw, 'Vrw': Vrw, 'Frw': Frw, 'Vlp': Vlp, 'Flp': Flp, 'Vrp': Vrp, 'Frp': Frp} + + +##################################### +######### Pathology Shape ######### +##################################### + + +def augment_pathology(Pprob, pde_func, t, shape_gen_args, device): + Pprob = torch.squeeze(Pprob) + + nt = np.random.randint(1, shape_gen_args.max_nt+1) + if nt <= 1: + return Pprob + + pde_func.V_dict = generate_velocity_3d(Pprob.shape, shape_gen_args.perlin_res, shape_gen_args.V_multiplier, device) + + #start_time = time.time() + Pprob = odeint(pde_func, Pprob[None], t[:nt], + shape_gen_args.dt, + method = shape_gen_args.integ_method)[-1, 0] # (last_t, n_batch=1, s, r, c) + # total_time = time.time() - start_time + #total_time_str = str(datetime.timedelta(seconds=int(total_time))) + #print('Time {} for {} time points'.format(total_time_str, nt)) + + + return Pprob + + +##################################### +######### Augmentation Func ######### +##################################### + + +def add_gamma_transform(I, aux_dict, cfg, device, **kwargs): + gamma = torch.tensor(np.exp(cfg.gamma_std * 
np.random.randn(1)[0]), dtype=float, device=device) + I_gamma = 300.0 * (I / 300.0) ** gamma + #aux_dict.update({'gamma': gamma}) # uncomment if you want to save gamma for later use + return I_gamma, aux_dict + +def add_bias_field(I, aux_dict, cfg, input_mode, setups, size, device, **kwargs): + if input_mode == 'CT': + aux_dict.update({'high_res': I}) + return I, aux_dict + + bf_scale = cfg.bf_scale_min + np.random.rand(1) * (cfg.bf_scale_max - cfg.bf_scale_min) + size_BF_small = np.round(bf_scale * np.array(size)).astype(int).tolist() + if setups['photo_mode']: + size_BF_small[1] = np.round(size[1]/setups['spac']).astype(int) + BFsmall = torch.tensor(cfg.bf_std_min + (cfg.bf_std_max - cfg.bf_std_min) * np.random.rand(1), dtype=torch.float, device=device) * \ + torch.randn(size_BF_small, dtype=torch.float, device=device) + BFlog = myzoom_torch(BFsmall, np.array(size) / size_BF_small) + BF = torch.exp(BFlog) + I_bf = I * BF + aux_dict.update({'BFlog': BFlog, 'high_res': I_bf}) + return I_bf, aux_dict + +def resample_resolution(I, aux_dict, setups, res, size, device, **kwargs): + stds = (0.85 + 0.3 * np.random.rand()) * np.log(5) /np.pi * setups['thickness'] / res + stds[setups['thickness']<=res] = 0.0 # no blur if thickness is equal to the resolution of the training data + I_blur = gaussian_blur_3d(I, stds, device) + new_size = (np.array(size) * res / setups['resolution']).astype(int) + + factors = np.array(new_size) / np.array(size) + delta = (1.0 - factors) / (2.0 * factors) + vx = np.arange(delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0])[:new_size[0]] + vy = np.arange(delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1])[:new_size[1]] + vz = np.arange(delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2])[:new_size[2]] + II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing='ij') + II = torch.tensor(II, dtype=torch.float, device=device) + JJ = torch.tensor(JJ, dtype=torch.float, device=device) + KK = torch.tensor(KK, 
dtype=torch.float, device=device) + + I_small = fast_3D_interp_torch(I_blur, II, JJ, KK) + aux_dict.update({'factors': factors}) + return I_small, aux_dict + + +def resample_resolution_photo(I, aux_dict, setups, res, size, device, **kwargs): + stds = (0.85 + 0.3 * np.random.rand()) * np.log(5) /np.pi * setups['thickness'] / res + stds[setups['thickness']<=res] = 0.0 # no blur if thickness is equal to the resolution of the training data + I_blur = gaussian_blur_3d(I, stds, device) + new_size = (np.array(size) * res / setups['resolution']).astype(int) + + factors = np.array(new_size) / np.array(size) + delta = (1.0 - factors) / (2.0 * factors) + vx = np.arange(delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0])[:new_size[0]] + vy = np.arange(delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1])[:new_size[1]] + vz = np.arange(delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2])[:new_size[2]] + II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing='ij') + II = torch.tensor(II, dtype=torch.float, device=device) + JJ = torch.tensor(JJ, dtype=torch.float, device=device) + KK = torch.tensor(KK, dtype=torch.float, device=device) + + I_small = fast_3D_interp_torch(I_blur, II, JJ, KK) + aux_dict.update({'factors': factors}) + return I_small, aux_dict + + +def add_noise(I, aux_dict, cfg, device, **kwargs): + noise_std = torch.tensor(cfg.noise_std_min + (cfg.noise_std_max - cfg.noise_std_min) * np.random.rand(1), dtype=torch.float, device=device) + I_noisy = I + noise_std * torch.randn(I.shape, dtype=torch.float, device=device) + I_noisy[I_noisy < 0] = 0 + #aux_dict.update({'noise_std': noise_std}) # uncomment if you want to save noise_std for later use + return I_noisy, aux_dict + + +##################################### +##################################### + + +# map SynthSeg right to left labels for contrast synthesis +right_to_left_dict = { + 41: 2, + 42: 3, + 43: 4, + 44: 5, + 46: 7, + 47: 8, + 49: 10, + 50: 11, + 51: 12, + 52: 13, + 
53: 17, + 54: 18, + 58: 26, + 60: 28 +} + +# based on merged left & right SynthSeg labels +ct_brightness_group = { + 'darker': [4, 5, 14, 15, 24, 31, 72], # ventricles, CSF + 'dark': [2, 7, 16, 77, 30], # white matter + 'bright': [3, 8, 17, 18, 28, 10, 11, 12, 13, 26], # grey matter (cortex, hippocampus, amggdala, ventral DC), thalamus, ganglia (nucleus (putamen, pallidus, accumbens), caudate) + 'brighter': [], # skull, pineal gland, choroid plexus +} diff --git a/README.md b/README.md index 7b95401dc46245ac339fc25059d4a56d90b4cde5..0cacaf9e2c1cb749056cd0e1f62042677938a819 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,91 @@ ---- -license: apache-2.0 ---- + +##

[A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging](https://arxiv.org/abs/2509.00549)

+ +**

Peirong Liu1,2, Oula Puonti2, Xiaoling Hu2, Karthik Gopinath2, Annabel Sorby-Adams2, Daniel C. Alexander3, Juan Eugenio Iglesias2,3,4

** + + +

+1Johns Hopkins University
+2Harvard Medical School and Massachusetts General Hospital
+3University College London
+4Massachusetts Institute of Technology +

+ +

+ drawing +

+ + +This is the official repository for our preprint: A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging [[arXiv]](https://arxiv.org/abs/2509.00549)
+More detailed and organized instructions are coming soon... + +## Environment +Training and evaluation environment: Python 3.11.4, PyTorch 2.0.1, CUDA 12.2. Run the following command to install required packages. +``` +conda create -n pre python=3.11 +conda activate pre + +git clone https://github.com/jhuldr/BrainFM +cd /path/to/brainfm +pip install -r requirements.txt +``` + + +## Generator +``` +cd scripts +python demo_generator.py +``` + +### Generator setups +Setups are in cfgs/generator, default setups are in default.yaml. A customized setup example can be found in train/brain_id.yaml, where several Brain-ID-specific setups are added. During Config reading/implementation, customized yaml will overwrite default.yaml if they have the same keys. + +dataset_setups: information for all datasets, in Generator/constants.py
+augmentation_funcs: augmentation functions and steps, in Generator/constants.py
+processing_funcs: image processing functions for each modality/task, in Generator/constants.py
+ +dataset_names: dataset name list, paths setups in Generator/constants.py
+mix_synth_prob: if the input mode is synthesizing, then probability for blending synth with real images
+dataset_option: generator types, could be BaseGen or customized generator
+task: switch on/off individual training tasks + +### Base generator module +``` +cd Generator +python datasets.py +``` +The dataset paths setups are in constants.py. In datasets.py, different datasets been used are fomulated as a list of dataset names. + +A customized data generator module example can be found in datasets.py -- BrainIDGen. + + +Refer to "__getitem__" function. Specifically, it includes:
+(1) read original input: could be either generation labels or real images;
+(2) generate augmentation setups and deformation fields;
+(3) read target(s) according to the assigned tasks -- here I seperate the processing functions for each item/modality, in case we want different processing steps for them;
+(4) augment input sample: either synthesized or real image input. + + + +(Some of the functions are leaved blank for now.) + + + +## Trainer +``` +cd scripts +python train.py +``` + +## Downloads +The pre-trained model weight is available on [OneDrive](https://livejohnshopkins-my.sharepoint.com/:u:/g/personal/pliu53_jh_edu/EZ_BJ7K6pMJEj9hZ8SA51GYBxH_Nan4fA3a-s4udwvVRog?e=nwZ7JC). + + +## Citation +```bibtex +@article{Liu_2025_BrainFM, + author = {Liu, Peirong and Puonti, Oula and Hu, Xiaoling and Gopinath, Karthik and Sorby-Adams, Annabel and Alexander, Daniel C. and Iglesias, Juan E.}, + title = {A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging}, + booktitle = {arXiv preprint arXiv:2509.00549}, + year = {2025}, +} diff --git a/ShapeID/DiffEqs/FD.py b/ShapeID/DiffEqs/FD.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c11968a1b23e3854c8afd5b6193ae11da176cf --- /dev/null +++ b/ShapeID/DiffEqs/FD.py @@ -0,0 +1,525 @@ +""" +*finite_difference.py* is the main package to compute finite differences in +1D, 2D, and 3D on numpy arrays (class FD_np) and pytorch tensors (class FD_torch). +The package supports first and second order derivatives and Neumann and linear extrapolation +boundary conditions (though the latter have not been tested extensively yet). +""" +from __future__ import absolute_import + +# from builtins import object +from abc import ABCMeta, abstractmethod + +import torch +from torch.autograd import Variable +import numpy as np +from future.utils import with_metaclass + +class FD(with_metaclass(ABCMeta, object)): + """ + *FD* is the abstract class for finite differences. It includes most of the actual finite difference code, + but requires the definition (in a derived class) of the methods *get_dimension*, *create_zero_array*, and *get_size_of_array*. + In this way the numpy and pytorch versions can easily be derived. 
All the method expect BxXxYxZ format (i.e., they process a batch at a time) + """ + + def __init__(self, spacing, bcNeumannZero=True): + """ + Constructor + :param spacing: 1D numpy array defining the spatial spacing, e.g., [0.1,0.1,0.1] for a 3D image + :param bcNeumannZero: Defines the boundary condition. If set to *True* (default) zero Neumann boundary conditions + are imposed. If set to *False* linear extrapolation is used (this is still experimental, but may be beneficial + for better boundary behavior) + """ + + self.dim = len(spacing) # In my code, data_spacing is a list # spacing.size + """spatial dimension""" + self.spacing = np.ones(self.dim) + """spacing""" + self.bcNeumannZero = bcNeumannZero # if false then linear interpolation + """should Neumann boundary conditions be used? (otherwise linear extrapolation)""" + if len(spacing) == 1: #spacing.size==1: + self.spacing[0] = spacing[0] + elif len(spacing) == 2: # spacing.size==2: + self.spacing[0] = spacing[0] + self.spacing[1] = spacing[1] + elif len(spacing) == 3: # spacing.size==3: + self.spacing[0] = spacing[0] + self.spacing[1] = spacing[1] + self.spacing[2] = spacing[2] + else: + print('Current dimension:', len(spacing)) + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + + def dXb(self,I): + """ + Backward difference in x direction: + :math:`\\frac{dI(i)}{dx}\\approx\\frac{I_i-I_{i-1}}{h_x}` + :param I: Input image + :return: Returns the first derivative in x direction using backward differences + """ + return (I-self.xm(I))/self.spacing[0] + + def dXf(self,I): + """ + Forward difference in x direction: + :math:`\\frac{dI(i)}{dx}\\approx\\frac{I_{i+1}-I_{i}}{h_x}` + + :param I: Input image + :return: Returns the first derivative in x direction using forward differences + """ + return (self.xp(I)-I)/self.spacing[0] + + def dXc(self,I): + """ + Central difference in x direction: + :math:`\\frac{dI(i)}{dx}\\approx\\frac{I_{i+1}-I_{i-1}}{2h_x}` + + :param I: Input image + 
:return: Returns the first derivative in x direction using central differences + """ + return (self.xp(I)-self.xm(I))/(2*self.spacing[0]) + + def ddXc(self,I): + """ + Second deriative in x direction + + :param I: Input image + :return: Returns the second derivative in x direction + """ + return (self.xp(I)-2*I+self.xm(I))/(self.spacing[0]**2) + + def dYb(self,I): + """ + Same as dXb, but for the y direction + + :param I: Input image + :return: Returns the first derivative in y direction using backward differences + """ + return (I-self.ym(I))/self.spacing[1] + + def dYf(self,I): + """ + Same as dXf, but for the y direction + + :param I: Input image + :return: Returns the first derivative in y direction using forward differences + """ + return (self.yp(I)-I)/self.spacing[1] + + def dYc(self,I): + """ + Same as dXc, but for the y direction + + :param I: Input image + :return: Returns the first derivative in y direction using central differences + """ + return (self.yp(I)-self.ym(I))/(2*self.spacing[1]) + + def ddYc(self,I): + """ + Same as ddXc, but for the y direction + + :param I: Input image + :return: Returns the second derivative in the y direction + """ + return (self.yp(I)-2*I+self.ym(I))/(self.spacing[1]**2) + + def dZb(self,I): + """ + Same as dXb, but for the z direction + + :param I: Input image + :return: Returns the first derivative in the z direction using backward differences + """ + return (I - self.zm(I))/self.spacing[2] + + def dZf(self, I): + """ + Same as dXf, but for the z direction + + :param I: Input image + :return: Returns the first derivative in the z direction using forward differences + """ + return (self.zp(I)-I)/self.spacing[2] + + def dZc(self, I): + """ + Same as dXc, but for the z direction + + :param I: Input image + :return: Returns the first derivative in the z direction using central differences + """ + return (self.zp(I)-self.zm(I))/(2*self.spacing[2]) + + def ddZc(self,I): + """ + Same as ddXc, but for the z direction + + 
:param I: Input iamge + :return: Returns the second derivative in the z direction + """ + return (self.zp(I)-2*I+self.zm(I))/(self.spacing[2]**2) + + def lap(self, I): + """ + Compute the Lapacian of an image + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. + + :param I: Input image [batch, X,Y,Z] + :return: Returns the Laplacian + """ + ndim = self.getdimension(I) + if ndim == 1+1: + return self.ddXc(I) + elif ndim == 2+1: + return (self.ddXc(I) + self.ddYc(I)) + elif ndim == 3+1: + return (self.ddXc(I) + self.ddYc(I) + self.ddZc(I)) + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + + def grad_norm_sqr_c(self, I): + """ + Computes the gradient norm of an image + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. + :param I: Input image [batch, X,Y,Z] + :return: returns ||grad I||^2 + """ + ndim = self.getdimension(I) + if ndim == 1 + 1: + return self.dXc(I)**2 + elif ndim == 2 + 1: + return (self.dXc(I)**2 + self.dYc(I)**2) + elif ndim == 3 + 1: + return (self.dXc(I)**2 + self.dYc(I)**2 + self.dZc(I)**2) + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + + def grad_norm_sqr_f(self, I): + """ + Computes the gradient norm of an image + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. 
+ :param I: Input image [batch, X,Y,Z] + :return: returns ||grad I||^2 + """ + ndim = self.getdimension(I) + if ndim == 1 + 1: + return self.dXf(I)**2 + elif ndim == 2 + 1: + return (self.dXf(I)**2 + self.dYf(I)**2) + elif ndim == 3 + 1: + return (self.dXf(I)**2 + self.dYf(I)**2 + self.dZf(I)**2) + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + + def grad_norm_sqr_b(self, I): + """ + Computes the gradient norm of an image + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. + :param I: Input image [batch, X,Y,Z] + :return: returns ||grad I||^2 + """ + ndim = self.getdimension(I) + if ndim == 1 + 1: + return self.dXb(I)**2 + elif ndim == 2 + 1: + return (self.dXb(I)**2 + self.dYb(I)**2) + elif ndim == 3 + 1: + return (self.dXb(I)**2 + self.dYb(I)**2 + self.dZb(I)**2) + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + + @abstractmethod + def getdimension(self,I): + """ + Abstract method to return the dimension of an input image I + + :param I: Input image + :return: Returns the dimension of the image I + """ + pass + + @abstractmethod + def create_zero_array(self, sz): + """ + Abstract method to create a zero array of a given size, sz. E.g., sz=[10,2,5] + + :param sz: Size array + :return: Returns a zero array of the specified size + """ + pass + + @abstractmethod + def get_size_of_array(self, A): + """ + Abstract method to return the size of an array (as a vector) + + :param A: Input array + :return: Returns its size (e.g., [5,10] or [3,4,6] + """ + pass + + def xp(self,I): + """ + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. 
+ Returns the values for x-index incremented by one (to the right in 1D) + + :param I: Input image [batch, X, Y,Z] + :return: Image with values at an x-index one larger + """ + rxp = self.create_zero_array( self.get_size_of_array( I ) ) + ndim = self.getdimension(I) + if ndim == 1+1: + rxp[:,0:-1] = I[:,1:] + if self.bcNeumannZero: + rxp[:,-1] = I[:,-1] + else: + rxp[:,-1] = 2*I[:,-1]-I[:,-2] + elif ndim == 2+1: + rxp[:,0:-1,:] = I[:,1:,:] + if self.bcNeumannZero: + rxp[:,-1,:] = I[:,-1,:] + else: + rxp[:,-1,:] = 2*I[:,-1,:]-I[:,-2,:] + elif ndim == 3+1: + rxp[:,0:-1,:,:] = I[:,1:,:,:] + if self.bcNeumannZero: + rxp[:,-1,:,:] = I[:,-1,:,:] + else: + rxp[:,-1,:,:] = 2*I[:,-1,:,:]-I[:,-2,:,:] + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + return rxp + + def xm(self,I): + """ + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. + Returns the values for x-index decremented by one (to the left in 1D) + + :param I: Input image [batch, X, Y, Z] + :return: Image with values at an x-index one smaller + """ + rxm = self.create_zero_array( self.get_size_of_array( I ) ) + ndim = self.getdimension(I) + if ndim == 1+1: + rxm[:,1:] = I[:,0:-1] + if self.bcNeumannZero: + rxm[:,0] = I[:,0] + else: + rxm[:,0] = 2*I[:,0]-I[:,1] + elif ndim == 2+1: + rxm[:,1:,:] = I[:,0:-1,:] + if self.bcNeumannZero: + rxm[:,0,:] = I[:,0,:] + else: + rxm[:,0,:] = 2*I[:,0,:]-I[:,1,:] + elif ndim == 3+1: + rxm[:,1:,:,:] = I[:,0:-1,:,:] + if self.bcNeumannZero: + rxm[:,0,:,:] = I[:,0,:,:] + else: + rxm[:,0,:,:] = 2*I[:,0,:,:]-I[:,1,:,:] + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + return rxm + + def yp(self, I): + """ + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. 
+ Same as xp, but for the y direction + + :param I: Input image + :return: Image with values at y-index one larger + """ + ryp = self.create_zero_array( self.get_size_of_array( I ) ) + ndim = self.getdimension(I) + if ndim == 2+1: + ryp[:,:,0:-1] = I[:,:,1:] + if self.bcNeumannZero: + ryp[:,:,-1] = I[:,:,-1] + else: + ryp[:,:,-1] = 2*I[:,:,-1]-I[:,:,-2] + elif ndim == 3+1: + ryp[:,:,0:-1,:] = I[:,:,1:,:] + if self.bcNeumannZero: + ryp[:,:,-1,:] = I[:,:,-1,:] + else: + ryp[:,:,-1,:] = 2*I[:,:,-1,:]-I[:,:,-2,:] + else: + print('Current dimension:', ndim-1) + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + return ryp + + def ym(self, I): + """ + Same as xm, but for the y direction + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. + Returns the values for x-index decremented by one (to the left in 1D) + :param I: Input image [batch, X, Y, Z] + :return: Image with values at y-index one smaller + """ + rym = self.create_zero_array( self.get_size_of_array( I ) ) + ndim = self.getdimension(I) + if ndim == 2+1: + rym[:,:,1:] = I[:,:,0:-1] + if self.bcNeumannZero: + rym[:,:,0] = I[:,:,0] + else: + rym[:,:,0] = 2*I[:,:,0]-I[:,:,1] + elif ndim == 3+1: + rym[:,:,1:,:] = I[:,:,0:-1,:] + if self.bcNeumannZero: + rym[:,:,0,:] = I[:,:,0,:] + else: + rym[:,:,0,:] = 2*I[:,:,0,:]-I[:,:,1,:] + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + return rym + + def zp(self, I): + """ + Same as xp, but for the z direction + + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. 
+ Returns the values for x-index decremented by one (to the left in 1D) + :param I: Input image [batch, X, Y, Z] + :return: Image with values at z-index one larger + """ + rzp = self.create_zero_array( self.get_size_of_array( I ) ) + ndim = self.getdimension(I) + if ndim == 3+1: + rzp[:,:,:,0:-1] = I[:,:,:,1:] + if self.bcNeumannZero: + rzp[:,:,:,-1] = I[:,:,:,-1] + else: + rzp[:,:,:,-1] = 2*I[:,:,:,-1]-I[:,:,:,-2] + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + return rzp + + def zm(self, I): + """ + Same as xm, but for the z direction + + !!!!!!!!!!! + IMPORTANT: + ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION. + THIS IS FOR COMPUTATIONAL EFFICIENCY. + Returns the values for x-index decremented by one (to the left in 1D) + :param I: Input image [batch, X, Y, Z] + :return: Image with values at z-index one smaller + """ + rzm = self.create_zero_array( self.get_size_of_array( I ) ) + ndim = self.getdimension(I) + if ndim == 3+1: + rzm[:,:,:,1:] = I[:,:,:,0:-1] + if self.bcNeumannZero: + rzm[:,:,:,0] = I[:,:,:,0] + else: + rzm[:,:,:,0] = 2*I[:,:,:,0]-I[:,:,:,1] + else: + raise ValueError('Finite differences are only supported in dimensions 1 to 3') + return rzm + + +class FD_np(FD): + """ + Defnitions of the abstract methods for numpy + """ + + def __init__(self,spacing,bcNeumannZero=True): + """ + Constructor for numpy finite differences + :param spacing: spatial spacing (array with as many entries as there are spatial dimensions) + :param bcNeumannZero: Specifies if zero Neumann conditions should be used (if not, uses linear extrapolation) + """ + super(FD_np, self).__init__(spacing,bcNeumannZero) + + def getdimension(self,I): + """ + Returns the dimension of an image + :param I: input image + :return: dimension of the input image + """ + return I.ndim + + def create_zero_array(self, sz): + """ + Creates a zero array + :param sz: size of the zero array, e.g., [3,4,2] + :return: the 
zero array + """ + return np.zeros( sz ) + + def get_size_of_array(self, A): + """ + Returns the size (shape in numpy) of an array + :param A: input array + :return: shape/size + """ + return A.shape + + +class FD_torch(FD): + """ + Defnitions of the abstract methods for torch + """ + + def __init__(self,spacing,device,bcNeumannZero=True): + """ + Constructor for torch finite differences + :param spacing: spatial spacing (array with as many entries as there are spatial dimensions) + :param bcNeumannZero: Specifies if zero Neumann conditions should be used (if not, uses linear extrapolation) + """ + super(FD_torch, self).__init__(spacing,bcNeumannZero) + self.device = device + + def getdimension(self,I): + """ + Returns the dimension of an image + :param I: input image + :return: dimension of the input image + """ + return I.dim() + + def create_zero_array(self, sz): + """ + Creats a zero array + :param sz: size of the array, e.g., [3,4,2] + :return: the zero array + """ + return torch.zeros(sz).float().to(self.device) + + def get_size_of_array(self, A): + """ + Returns the size (size()) of an array + :param A: input array + :return: shape/size + """ + return A.size() \ No newline at end of file diff --git a/ShapeID/DiffEqs/adams.py b/ShapeID/DiffEqs/adams.py new file mode 100644 index 0000000000000000000000000000000000000000..19f1acd17d2808a6c46719848c4bd8d5c9fcfca0 --- /dev/null +++ b/ShapeID/DiffEqs/adams.py @@ -0,0 +1,170 @@ +import collections +import torch +from ShapeID.DiffEqs.solvers import AdaptiveStepsizeODESolver +from ShapeID.DiffEqs.misc import ( + _handle_unused_kwargs, _select_initial_step, _convert_to_tensor, _scaled_dot_product, _is_iterable, + _optimal_step_size, _compute_error_ratio +) + +_MIN_ORDER = 1 +_MAX_ORDER = 12 + +gamma_star = [ + 1, -1 / 2, -1 / 12, -1 / 24, -19 / 720, -3 / 160, -863 / 60480, -275 / 24192, -33953 / 3628800, -0.00789255, + -0.00678585, -0.00592406, -0.00523669, -0.0046775, -0.00421495, -0.0038269 +] + + +class 
_VCABMState(collections.namedtuple('_VCABMState', 'y_n, prev_f, prev_t, next_t, phi, order')): + """Saved state of the variable step size Adams-Bashforth-Moulton solver as described in + + Solving Ordinary Differential Equations I - Nonstiff Problems III.5 + by Ernst Hairer, Gerhard Wanner, and Syvert P Norsett. + """ + + +def g_and_explicit_phi(prev_t, next_t, implicit_phi, k): + curr_t = prev_t[0] + dt = next_t - prev_t[0] + + g = torch.empty(k + 1).to(prev_t[0]) + explicit_phi = collections.deque(maxlen=k) + beta = torch.tensor(1).to(prev_t[0]) + + g[0] = 1 + c = 1 / torch.arange(1, k + 2).to(prev_t[0]) + explicit_phi.append(implicit_phi[0]) + + for j in range(1, k): + beta = (next_t - prev_t[j - 1]) / (curr_t - prev_t[j]) * beta + beat_cast = beta.to(implicit_phi[j][0]) + explicit_phi.append(tuple(iphi_ * beat_cast for iphi_ in implicit_phi[j])) + + c = c[:-1] - c[1:] if j == 1 else c[:-1] - c[1:] * dt / (next_t - prev_t[j - 1]) + g[j] = c[0] + + c = c[:-1] - c[1:] * dt / (next_t - prev_t[k - 1]) + g[k] = c[0] + + return g, explicit_phi + + +def compute_implicit_phi(explicit_phi, f_n, k): + k = min(len(explicit_phi) + 1, k) + implicit_phi = collections.deque(maxlen=k) + implicit_phi.append(f_n) + for j in range(1, k): + implicit_phi.append(tuple(iphi_ - ephi_ for iphi_, ephi_ in zip(implicit_phi[j - 1], explicit_phi[j - 1]))) + return implicit_phi + + +class VariableCoefficientAdamsBashforth(AdaptiveStepsizeODESolver): + + def __init__( + self, func, y0, rtol, atol, implicit=True, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0, dfactor=0.2, + **unused_kwargs + ): + _handle_unused_kwargs(self, unused_kwargs) + del unused_kwargs + + self.func = func + self.y0 = y0 + self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) + self.atol = atol if _is_iterable(atol) else [atol] * len(y0) + self.implicit = implicit + self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER))) + self.safety = _convert_to_tensor(safety, dtype=torch.float64, 
device=y0[0].device) + self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) + self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) + + def before_integrate(self, t): + prev_f = collections.deque(maxlen=self.max_order + 1) + prev_t = collections.deque(maxlen=self.max_order + 1) + phi = collections.deque(maxlen=self.max_order) + + t0 = t[0] + f0 = self.func(t0.type_as(self.y0[0]), self.y0) + prev_t.appendleft(t0) + prev_f.appendleft(f0) + phi.appendleft(f0) + first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t) + + self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1) + + def advance(self, final_t): + final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0]) + while final_t > self.vcabm_state.prev_t[0]: + self.vcabm_state = self._adaptive_adams_step(self.vcabm_state, final_t) + assert final_t == self.vcabm_state.prev_t[0] + return self.vcabm_state.y_n + + def _adaptive_adams_step(self, vcabm_state, final_t): + y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state + if next_t > final_t: + next_t = final_t + dt = (next_t - prev_t[0]) + dt_cast = dt.to(y0[0]) + + # Explicit predictor step. + g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order) + g = g.to(y0[0]) + p_next = tuple( + y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)]) + for y0_, phi_ in zip(y0, tuple(zip(*phi))) + ) + + # Update phi to implicit. + next_f0 = self.func(next_t.to(p_next[0]), p_next) + implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1) + + # Implicit corrector step. + y_next = tuple( + p_next_ + dt_cast * g[order - 1] * iphi_ for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1]) + ) + + # Error estimation. 
+ tolerance = tuple( + atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_)) + for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next) + ) + local_error = tuple(dt_cast * (g[order] - g[order - 1]) * iphi_ for iphi_ in implicit_phi_p[order]) + error_k = _compute_error_ratio(local_error, tolerance) + accept_step = (torch.tensor(error_k) <= 1).all() + + if not accept_step: + # Retry with adjusted step size if step is rejected. + dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order) + return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order) + + # We accept the step. Evaluate f and update phi. + next_f0 = self.func(next_t.to(p_next[0]), y_next) + implicit_phi = compute_implicit_phi(phi, next_f0, order + 2) + + next_order = order + + if len(prev_t) <= 4 or order < 3: + next_order = min(order + 1, 3, self.max_order) + else: + error_km1 = _compute_error_ratio( + tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_ for iphi_ in implicit_phi_p[order - 1]), tolerance + ) + error_km2 = _compute_error_ratio( + tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_ for iphi_ in implicit_phi_p[order - 2]), tolerance + ) + if min(error_km1 + error_km2) < max(error_k): + next_order = order - 1 + elif order < self.max_order: + error_kp1 = _compute_error_ratio( + tuple(dt_cast * gamma_star[order] * iphi_ for iphi_ in implicit_phi_p[order]), tolerance + ) + if max(error_kp1) < max(error_k): + next_order = order + 1 + + # Keep step size constant if increasing order. Else use adaptive step size. 
+ dt_next = dt if next_order > order else _optimal_step_size( + dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1 + ) + + prev_f.appendleft(next_f0) + prev_t.appendleft(next_t) + return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order) diff --git a/ShapeID/DiffEqs/adjoint.py b/ShapeID/DiffEqs/adjoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d2730fbbcbcf15b93f0d4e87a14a145681b12800 --- /dev/null +++ b/ShapeID/DiffEqs/adjoint.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +from ShapeID.DiffEqs.odeint import odeint +from ShapeID.DiffEqs.misc import _flatten, _flatten_convert_none_to_zeros + + +class OdeintAdjointMethod(torch.autograd.Function): + + @staticmethod + def forward(ctx, *args): + assert len(args) >= 8, 'Internal error: all arguments required.' + y0, func, t, dt, flat_params, rtol, atol, method, options = \ + args[:-8], args[-8], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1] + + ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options + + with torch.no_grad(): + ans = odeint(func, y0, t, dt, rtol=rtol, atol=atol, method=method, options=options) + ctx.save_for_backward(t, flat_params, *ans) + return ans + + @staticmethod + def backward(ctx, *grad_output): + + t, flat_params, *ans = ctx.saved_tensors + ans = tuple(ans) + func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options + n_tensors = len(ans) + f_params = tuple(func.parameters()) + + # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives. + def augmented_dynamics(t, y_aug): + # Dynamics of the original system augmented with + # the adjoint wrt y, and an integrator wrt t and args. + y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] # Ignore adj_time and adj_params. 
+ + with torch.set_grad_enabled(True): + t = t.to(y[0].device).detach().requires_grad_(True) + y = tuple(y_.detach().requires_grad_(True) for y_ in y) + func_eval = func(t, y) + vjp_t, *vjp_y_and_params = torch.autograd.grad( + func_eval, (t,) + y + f_params, + tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True + ) + vjp_y = vjp_y_and_params[:n_tensors] + vjp_params = vjp_y_and_params[n_tensors:] + + # autograd.grad returns None if no gradient, set to zero. + vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t + vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y)) + vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params) + + if len(f_params) == 0: + vjp_params = torch.tensor(0.).to(vjp_y[0]) + return (*func_eval, *vjp_y, vjp_t, vjp_params) + + T = ans[0].shape[0] + with torch.no_grad(): + adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output) + adj_params = torch.zeros_like(flat_params) + adj_time = torch.tensor(0.).to(t) + time_vjps = [] + for i in range(T - 1, 0, -1): + + ans_i = tuple(ans_[i] for ans_ in ans) + grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output) + func_i = func(t[i], ans_i) + + # Compute the effect of moving the current time measurement point. + dLd_cur_t = sum( + torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1) + for func_i_, grad_output_i_ in zip(func_i, grad_output_i) + ) + adj_time = adj_time - dLd_cur_t + time_vjps.append(dLd_cur_t) + + # Run the augmented system backwards in time. + if adj_params.numel() == 0: + adj_params = torch.tensor(0.).to(adj_y[0]) + aug_y0 = (*ans_i, *adj_y, adj_time, adj_params) + aug_ans = odeint( + augmented_dynamics, aug_y0, + torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options + ) + + # Unpack aug_ans. 
+ adj_y = aug_ans[n_tensors:2 * n_tensors] + adj_time = aug_ans[2 * n_tensors] + adj_params = aug_ans[2 * n_tensors + 1] + + adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y) + if len(adj_time) > 0: adj_time = adj_time[1] + if len(adj_params) > 0: adj_params = adj_params[1] + + adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output)) + + del aug_y0, aug_ans + + time_vjps.append(adj_time) + time_vjps = torch.cat(time_vjps[::-1]) + + return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None, None) # Add a None (TODO, futher check) + + +def odeint_adjoint(func, y0, t, dt, rtol=1e-6, atol=1e-12, method=None, options=None): + + # We need this in order to access the variables inside this module, + # since we have no other way of getting variables along the execution path. + if not isinstance(func, nn.Module): + raise ValueError('func is required to be an instance of nn.Module.') + + tensor_input = False + if torch.is_tensor(y0): + + class TupleFunc(nn.Module): + + def __init__(self, base_func): + super(TupleFunc, self).__init__() + self.base_func = base_func + + def forward(self, t, y): + return (self.base_func(t, y[0]),) + + tensor_input = True + y0 = (y0,) + func = TupleFunc(func) + + flat_params = _flatten(func.parameters()) + ys = OdeintAdjointMethod.apply(*y0, func, t, dt, flat_params, rtol, atol, method, options) + + if tensor_input: + ys = ys[0] + return ys \ No newline at end of file diff --git a/ShapeID/DiffEqs/dopri5.py b/ShapeID/DiffEqs/dopri5.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a4f7fd77b2e441a864a37f001e5c6d2c7f607a --- /dev/null +++ b/ShapeID/DiffEqs/dopri5.py @@ -0,0 +1,172 @@ +import torch +from .misc import ( + _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable, + _optimal_step_size, _compute_error_ratio +) +from .solvers import AdaptiveStepsizeODESolver, set_BC_2D, set_BC_3D, 
add_dBC_2D, add_dBC_3D +from .interp import _interp_fit, _interp_evaluate +from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step + + +_DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( + alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], + beta=[ + [1 / 5], + [3 / 40, 9 / 40], + [44 / 45, -56 / 15, 32 / 9], + [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], + [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], + [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], + ], + c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], + c_error=[ + 35 / 384 - 1951 / 21600, + 0, + 500 / 1113 - 22642 / 50085, + 125 / 192 - 451 / 720, + -2187 / 6784 - -12231 / 42400, + 11 / 84 - 649 / 6300, + -1. / 60., + ], +) + +DPS_C_MID = [ + 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, + 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 +] + + +def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU): + """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" + dt = dt.type_as(y0[0]) + y_mid = tuple(y0_ + _scaled_dot_product(dt, DPS_C_MID, k_) for y0_, k_ in zip(y0, k)) + f0 = tuple(k_[0] for k_ in k) + f1 = tuple(k_[-1] for k_ in k) + return _interp_fit(y0, y1, y_mid, f0, f1, dt) + + +def _abs_square(x): + return torch.mul(x, x) + + +def _ta_append(list_of_tensors, value): + """Append a value to the end of a list of PyTorch tensors.""" + list_of_tensors.append(value) + return list_of_tensors + + +class Dopri5Solver(AdaptiveStepsizeODESolver): + + def __init__( + self, func, y0, rtol, atol, dt, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, + options = None + #**unused_kwargs + ): + #_handle_unused_kwargs(self, unused_kwargs) + #del unused_kwargs + + self.func = func + self.y0 = y0 + + self.dt = dt #options.dt + '''if 'dirichlet' in options.BC or 'cauchy' in 
options.BC and options.contours is not None: + self.contours = options.contours # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape) + self.BC_size = self.contours.size(3) + self.set_BC = set_BC_2D if self.contours.size(2) == 4 else set_BC_3D + else: + self.contours = None + if 'source' in options.BC and options.dcontours is not None: + self.dcontours = options.dcontours # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape) + self.BC_size = self.dcontours.size(3) + self.add_dBC = add_dBC_2D if self.dcontours.size(2) == 4 else add_dBC_3D + else: + self.dcontours = None''' + + #self.adjoint = options.adjoint + + self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) + self.atol = atol if _is_iterable(atol) else [atol] * len(y0) + self.first_step = first_step + self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) + self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) + self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) + self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) + #self.n_step_record=[] + + def before_integrate(self, t): + f0 = self.func(t[0].type_as(self.y0[0]), self.y0) + #print("first_step is {}".format(self.first_step)) + if self.first_step is None: + first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0).to(t) + else: + first_step = _convert_to_tensor(0.01, dtype=t.dtype, device=t.device) + # if first_step>0.2: + # print("warning the first step of dopri5 {} is too big, set to 0.2".format(first_step)) + # first_step = _convert_to_tensor(0.2, dtype=torch.float64, device=self.y0[0].device) + + self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5) + + def advance(self, next_t): + """Interpolate through the next time point, integrating as necessary.""" + n_steps = 0 + while next_t > self.rk_state.t1: + assert n_steps < 
self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) + self.rk_state = self._adaptive_dopri5_step(self.rk_state) + n_steps += 1 + # if len(self.n_step_record)==100: + # print("this dopri5 step info will print every 100 calls, the current average step is {}".format(sum(self.n_step_record)/100)) + # self.n_step_record=[] + # else: + # self.n_step_record.append(n_steps) + + return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) + + def _adaptive_dopri5_step(self, rk_state): + """Take an adaptive Runge-Kutta step to integrate the DiffEqs.""" + y0, f0, _, t0, dt, interp_coeff = rk_state + ######################################################## + # Assertions # + ######################################################## + assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) + # for y0_ in y0: + # #assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) + # is_finite= _is_finite(torch.abs(y0_)) + # if not is_finite: + # print(" non-finite elements exist, try to fix") + # y0_[y0_ != y0_] = 0. + # y0_[y0_ == float("Inf")] = 0. 
+ + y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU) + + ######################################################## + # Error Ratio # + ######################################################## + mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) + accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all() + + ######################################################## + # Update RK State # + ######################################################## + dt_next = _optimal_step_size( + dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5) + tol_min_dt = 0.2 * self.dt if 0.1 * self.dt >= 0.01 else 0.01 + #print('tol min', tol_min_dt) + if not (dt_next< tol_min_dt or dt_next>0.1): #(dt_next<0.01 or dt_next>0.1): #(dt_next<0.02): #not (dt_next<0.02 or dt_next>0.1): + y_next = y1 if accept_step else y0 + f_next = f1 if accept_step else f0 + t_next = t0 + dt if accept_step else t0 + interp_coeff = _interp_fit_dopri5(y0, y_next, k, dt) if accept_step else interp_coeff + else: + if dt_next< tol_min_dt: #dt_next<0.01: # 0.01 + #print("Dopri5 step %.3f too small, set to %.3f" % (dt_next, 0.2 * self.dt)) + dt_next = _convert_to_tensor(tol_min_dt, dtype=torch.float64, device=y0[0].device) + if dt_next>0.1: + #print("Dopri5 step %.8f is too big, set to 0.1" % (dt_next)) + dt_next = _convert_to_tensor(0.1, dtype=torch.float64, device=y0[0].device) + y_next = y1 + f_next = f1 + t_next = t0 + dt + interp_coeff = _interp_fit_dopri5(y0, y1, k, dt) + rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) + #print('dt_next', dt_next) + return rk_state \ No newline at end of file diff --git a/ShapeID/DiffEqs/fixed_adams.py b/ShapeID/DiffEqs/fixed_adams.py new file mode 100644 index 0000000000000000000000000000000000000000..b71d91360c4fdd9f6835a3a5d96cb7004688b280 --- /dev/null +++ b/ShapeID/DiffEqs/fixed_adams.py @@ 
-0,0 +1,211 @@ +import sys +import collections +from ShapeID.DiffEqs.solvers import FixedGridODESolver +from ShapeID.DiffEqs.misc import _scaled_dot_product, _has_converged +import ShapeID.DiffEqs.rk_common + +_BASHFORTH_COEFFICIENTS = [ + [], # order 0 + [11], + [3, -1], + [23, -16, 5], + [55, -59, 37, -9], + [1901, -2774, 2616, -1274, 251], + [4277, -7923, 9982, -7298, 2877, -475], + [198721, -447288, 705549, -688256, 407139, -134472, 19087], + [434241, -1152169, 2183877, -2664477, 2102243, -1041723, 295767, -36799], + [14097247, -43125206, 95476786, -139855262, 137968480, -91172642, 38833486, -9664106, 1070017], + [30277247, -104995189, 265932680, -454661776, 538363838, -444772162, 252618224, -94307320, 20884811, -2082753], + [ + 2132509567, -8271795124, 23591063805, -46113029016, 63716378958, -63176201472, 44857168434, -22329634920, + 7417904451, -1479574348, 134211265 + ], + [ + 4527766399, -19433810163, 61633227185, -135579356757, 214139355366, -247741639374, 211103573298, -131365867290, + 58189107627, -17410248271, 3158642445, -262747265 + ], + [ + 13064406523627, -61497552797274, 214696591002612, -524924579905150, 932884546055895, -1233589244941764, + 1226443086129408, -915883387152444, 507140369728425, -202322913738370, 55060974662412, -9160551085734, + 703604254357 + ], + [ + 27511554976875, -140970750679621, 537247052515662, -1445313351681906, 2854429571790805, -4246767353305755, + 4825671323488452, -4204551925534524, 2793869602879077, -1393306307155755, 505586141196430, -126174972681906, + 19382853593787, -1382741929621 + ], + [ + 173233498598849, -960122866404112, 3966421670215481, -11643637530577472, 25298910337081429, -41825269932507728, + 53471026659940509, -53246738660646912, 41280216336284259, -24704503655607728, 11205849753515179, + -3728807256577472, 859236476684231, -122594813904112, 8164168737599 + ], + [ + 362555126427073, -2161567671248849, 9622096909515337, -30607373860520569, 72558117072259733, + -131963191940828581, 187463140112902893, 
-210020588912321949, 186087544263596643, -129930094104237331, + 70724351582843483, -29417910911251819, 9038571752734087, -1934443196892599, 257650275915823, -16088129229375 + ], + [ + 192996103681340479, -1231887339593444974, 5878428128276811750, -20141834622844109630, 51733880057282977010, + -102651404730855807942, 160414858999474733422, -199694296833704562550, 199061418623907202560, + -158848144481581407370, 100878076849144434322, -50353311405771659322, 19338911944324897550, + -5518639984393844930, 1102560345141059610, -137692773163513234, 8092989203533249 + ], + [ + 401972381695456831, -2735437642844079789, 13930159965811142228, -51150187791975812900, 141500575026572531760, + -304188128232928718008, 518600355541383671092, -710171024091234303204, 786600875277595877750, + -706174326992944287370, 512538584122114046748, -298477260353977522892, 137563142659866897224, + -49070094880794267600, 13071639236569712860, -2448689255584545196, 287848942064256339, -15980174332775873 + ], + [ + 333374427829017307697, -2409687649238345289684, 13044139139831833251471, -51099831122607588046344, + 151474888613495715415020, -350702929608291455167896, 647758157491921902292692, -967713746544629658690408, + 1179078743786280451953222, -1176161829956768365219840, 960377035444205950813626, -639182123082298748001432, + 343690461612471516746028, -147118738993288163742312, 48988597853073465932820, -12236035290567356418552, + 2157574942881818312049, -239560589366324764716, 12600467236042756559 + ], + [ + 691668239157222107697, -5292843584961252933125, 30349492858024727686755, -126346544855927856134295, + 399537307669842150996468, -991168450545135070835076, 1971629028083798845750380, -3191065388846318679544380, + 4241614331208149947151790, -4654326468801478894406214, 4222756879776354065593786, -3161821089800186539248210, + 1943018818982002395655620, -970350191086531368649620, 387739787034699092364924, -121059601023985433003532, + 28462032496476316665705, -4740335757093710713245, 
498669220956647866875, -24919383499187492303 + ], +] + +_MOULTON_COEFFICIENTS = [ + [], # order 0 + [1], + [1, 1], + [5, 8, -1], + [9, 19, -5, 1], + [251, 646, -264, 106, -19], + [475, 1427, -798, 482, -173, 27], + [19087, 65112, -46461, 37504, -20211, 6312, -863], + [36799, 139849, -121797, 123133, -88547, 41499, -11351, 1375], + [1070017, 4467094, -4604594, 5595358, -5033120, 3146338, -1291214, 312874, -33953], + [2082753, 9449717, -11271304, 16002320, -17283646, 13510082, -7394032, 2687864, -583435, 57281], + [ + 134211265, 656185652, -890175549, 1446205080, -1823311566, 1710774528, -1170597042, 567450984, -184776195, + 36284876, -3250433 + ], + [ + 262747265, 1374799219, -2092490673, 3828828885, -5519460582, 6043521486, -4963166514, 3007739418, -1305971115, + 384709327, -68928781, 5675265 + ], + [ + 703604254357, 3917551216986, -6616420957428, 13465774256510, -21847538039895, 27345870698436, -26204344465152, + 19058185652796, -10344711794985, 4063327863170, -1092096992268, 179842822566, -13695779093 + ], + [ + 1382741929621, 8153167962181, -15141235084110, 33928990133618, -61188680131285, 86180228689563, -94393338653892, + 80101021029180, -52177910882661, 25620259777835, -9181635605134, 2268078814386, -345457086395, 24466579093 + ], + [ + 8164168737599, 50770967534864, -102885148956217, 251724894607936, -499547203754837, 781911618071632, + -963605400824733, 934600833490944, -710312834197347, 418551804601264, -187504936597931, 61759426692544, + -14110480969927, 1998759236336, -132282840127 + ], + [ + 16088129229375, 105145058757073, -230992163723849, 612744541065337, -1326978663058069, 2285168598349733, + -3129453071993581, 3414941728852893, -2966365730265699, 2039345879546643, -1096355235402331, 451403108933483, + -137515713789319, 29219384284087, -3867689367599, 240208245823 + ], + [ + 8092989203533249, 55415287221275246, -131240807912923110, 375195469874202430, -880520318434977010, + 1654462865819232198, -2492570347928318318, 3022404969160106870, 
-2953729295811279360, 2320851086013919370, + -1455690451266780818, 719242466216944698, -273894214307914510, 77597639915764930, -15407325991235610, + 1913813460537746, -111956703448001 + ], + [ + 15980174332775873, 114329243705491117, -290470969929371220, 890337710266029860, -2250854333681641520, + 4582441343348851896, -7532171919277411636, 10047287575124288740, -10910555637627652470, 9644799218032932490, + -6913858539337636636, 3985516155854664396, -1821304040326216520, 645008976643217360, -170761422500096220, + 31816981024600492, -3722582669836627, 205804074290625 + ], + [ + 12600467236042756559, 93965550344204933076, -255007751875033918095, 834286388106402145800, + -2260420115705863623660, 4956655592790542146968, -8827052559979384209108, 12845814402199484797800, + -15345231910046032448070, 15072781455122686545920, -12155867625610599812538, 8008520809622324571288, + -4269779992576330506540, 1814584564159445787240, -600505972582990474260, 149186846171741510136, + -26182538841925312881, 2895045518506940460, -151711881512390095 + ], + [ + 24919383499187492303, 193280569173472261637, -558160720115629395555, 1941395668950986461335, + -5612131802364455926260, 13187185898439270330756, -25293146116627869170796, 39878419226784442421820, + -51970649453670274135470, 56154678684618739939910, -50320851025594566473146, 37297227252822858381906, + -22726350407538133839300, 11268210124987992327060, -4474886658024166985340, 1389665263296211699212, + -325187970422032795497, 53935307402575440285, -5652892248087175675, 281550972898020815 + ], +] + +_DIVISOR = [ + None, 11, 2, 12, 24, 720, 1440, 60480, 120960, 3628800, 7257600, 479001600, 958003200, 2615348736000, 5230697472000, + 31384184832000, 62768369664000, 32011868528640000, 64023737057280000, 51090942171709440000, 102181884343418880000 +] + +_MIN_ORDER = 4 +_MAX_ORDER = 12 +_MAX_ITERS = 4 + + +class AdamsBashforthMoulton(FixedGridODESolver): + + def __init__( + self, func, y0, rtol=1e-3, atol=1e-4, implicit=True, 
max_iters=_MAX_ITERS, max_order=_MAX_ORDER, **kwargs + ): + super(AdamsBashforthMoulton, self).__init__(func, y0, **kwargs) + + self.rtol = rtol + self.atol = atol + self.implicit = implicit + self.max_iters = max_iters + self.max_order = int(min(max_order, _MAX_ORDER)) + self.prev_f = collections.deque(maxlen=self.max_order - 1) + self.prev_t = None + + def _update_history(self, t, f): + if self.prev_t is None or self.prev_t != t: + self.prev_f.appendleft(f) + self.prev_t = t + + def step_func(self, func, t, dt, y): + self._update_history(t, func(t, y)) + order = min(len(self.prev_f), self.max_order - 1) + if order < _MIN_ORDER - 1: + # Compute using RK4. + dy = rk_common.rk4_alt_step_func(func, t, dt, y, k1=self.prev_f[0]) + return dy + else: + # Adams-Bashforth predictor. + bashforth_coeffs = _BASHFORTH_COEFFICIENTS[order] + ab_div = _DIVISOR[order] + dy = tuple(dt * _scaled_dot_product(1 / ab_div, bashforth_coeffs, f_) for f_ in zip(*self.prev_f)) + + # Adams-Moulton corrector. + if self.implicit: + moulton_coeffs = _MOULTON_COEFFICIENTS[order + 1] + am_div = _DIVISOR[order + 1] + delta = tuple(dt * _scaled_dot_product(1 / am_div, moulton_coeffs[1:], f_) for f_ in zip(*self.prev_f)) + converged = False + for _ in range(self.max_iters): + dy_old = dy + f = func(t + dt, tuple(y_ + dy_ for y_, dy_ in zip(y, dy))) + dy = tuple(dt * (moulton_coeffs[0] / am_div) * f_ + delta_ for f_, delta_ in zip(f, delta)) + converged = _has_converged(dy_old, dy, self.rtol, self.atol) + if converged: + break + if not converged: + print('Warning: Functional iteration did not converge. 
Solution may be incorrect.', file=sys.stderr) + self.prev_f.pop() + self._update_history(t, f) + return dy + + @property + def order(self): + return 4 + + +class AdamsBashforth(AdamsBashforthMoulton): + + def __init__(self, func, y0, **kwargs): + super(AdamsBashforth, self).__init__(func, y0, implicit=False, **kwargs) diff --git a/ShapeID/DiffEqs/fixed_grid.py b/ShapeID/DiffEqs/fixed_grid.py new file mode 100644 index 0000000000000000000000000000000000000000..22cdb88ef3bba99492ba4226999b7f770b98ea8f --- /dev/null +++ b/ShapeID/DiffEqs/fixed_grid.py @@ -0,0 +1,33 @@ +from ShapeID.DiffEqs.solvers import FixedGridODESolver +import ShapeID.DiffEqs.rk_common as rk_common + + +class Euler(FixedGridODESolver): + + def step_func(self, func, t, dt, y): + return tuple(dt * f_ for f_ in func(t, y)) + + @property + def order(self): + return 1 + + +class Midpoint(FixedGridODESolver): + + def step_func(self, func, t, dt, y): + y_mid = tuple(y_ + f_ * dt / 2 for y_, f_ in zip(y, func(t, y))) + return tuple(dt * f_ for f_ in func(t + dt / 2, y_mid)) + + @property + def order(self): + return 2 + + +class RK4(FixedGridODESolver): + + def step_func(self, func, t, dt, y): + return rk_common.rk4_alt_step_func(func, t, dt, y) + + @property + def order(self): + return 4 diff --git a/ShapeID/DiffEqs/interp.py b/ShapeID/DiffEqs/interp.py new file mode 100644 index 0000000000000000000000000000000000000000..a0acd862b2a2db2d365e18e53a0dddb66a1c90a2 --- /dev/null +++ b/ShapeID/DiffEqs/interp.py @@ -0,0 +1,65 @@ +import torch +from ShapeID.DiffEqs.misc import _convert_to_tensor, _dot_product + + +def _interp_fit(y0, y1, y_mid, f0, f1, dt): + """Fit coefficients for 4th order polynomial interpolation. + + Args: + y0: function value at the start of the interval. + y1: function value at the end of the interval. + y_mid: function value at the mid-point of the interval. + f0: derivative value at the start of the interval. + f1: derivative value at the end of the interval. 
+ dt: width of the interval. + + Returns: + List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial + `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x` + between 0 (start of interval) and 1 (end of interval). + """ + a = tuple( + _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0_, f1_, y0_, y1_, y_mid_]) + for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) + ) + b = tuple( + _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0_, f1_, y0_, y1_, y_mid_]) + for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) + ) + c = tuple( + _dot_product([-4 * dt, dt, -11, -5, 16], [f0_, f1_, y0_, y1_, y_mid_]) + for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) + ) + d = tuple(dt * f0_ for f0_ in f0) + e = y0 + return [a, b, c, d, e] + + +def _interp_evaluate(coefficients, t0, t1, t): + """Evaluate polynomial interpolation at the given time point. + + Args: + coefficients: list of Tensor coefficients as created by `interp_fit`. + t0: scalar float64 Tensor giving the start of the interval. + t1: scalar float64 Tensor giving the end of the interval. + t: scalar float64 Tensor giving the desired interpolation point. + + Returns: + Polynomial interpolation of the coefficients at time `t`. 
+ """ + + dtype = coefficients[0][0].dtype + device = coefficients[0][0].device + + t0 = _convert_to_tensor(t0, dtype=dtype, device=device) + t1 = _convert_to_tensor(t1, dtype=dtype, device=device) + t = _convert_to_tensor(t, dtype=dtype, device=device) + + assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1) + x = ((t - t0) / (t1 - t0)).type(dtype).to(device) + + xs = [torch.tensor(1).type(dtype).to(device), x] + for _ in range(2, len(coefficients)): + xs.append(xs[-1] * x) + + return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients)) diff --git a/ShapeID/DiffEqs/misc.py b/ShapeID/DiffEqs/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d1df333d7ba2675c1b4d2c300c9a473f8a79f1fe --- /dev/null +++ b/ShapeID/DiffEqs/misc.py @@ -0,0 +1,195 @@ +import warnings +import torch + + +def _flatten(sequence): + flat = [p.contiguous().view(-1) for p in sequence] + return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) + + +def _flatten_convert_none_to_zeros(sequence, like_sequence): + flat = [ + p.contiguous().view(-1) if p is not None else torch.zeros_like(q).view(-1) + for p, q in zip(sequence, like_sequence) + ] + return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) + + +def _possibly_nonzero(x): + return isinstance(x, torch.Tensor) or x != 0 + + +def _scaled_dot_product(scale, xs, ys): + """Calculate a scaled, vector inner product between lists of Tensors.""" + # Using _possibly_nonzero lets us avoid wasted computation. 
+ return sum([(scale * x) * y for x, y in zip(xs, ys) if _possibly_nonzero(x) or _possibly_nonzero(y)]) + + +def _dot_product(xs, ys): + """Calculate the vector inner product between two lists of Tensors.""" + return sum([x * y for x, y in zip(xs, ys)]) + + +def _has_converged(y0, y1, rtol, atol): + """Checks that each element is within the error tolerance.""" + error_tol = tuple(atol + rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1)) + error = tuple(torch.abs(y0_ - y1_) for y0_, y1_ in zip(y0, y1)) + return all((error_ < error_tol_).all() for error_, error_tol_ in zip(error, error_tol)) + + +def _convert_to_tensor(a, dtype=None, device=None): + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + if dtype is not None: + a = a.type(dtype) + if device is not None: + a = a.to(device) + return a + + +def _is_finite(tensor): + _check = (tensor == float('inf')) + (tensor == float('-inf')) + torch.isnan(tensor) + return not _check.any() + + +def _decreasing(t): + return (t[1:] < t[:-1]).all() + + +def _assert_increasing(t): + assert (t[1:] > t[:-1]).all(), 't must be strictly increasing or decrasing' + + +def _is_iterable(inputs): + try: + iter(inputs) + return True + except TypeError: + return False + + +def _norm(x): + """Compute RMS norm.""" + if torch.is_tensor(x): + return x.norm() / (x.numel()**0.5) + else: + return torch.sqrt(sum(x_.norm()**2 for x_ in x) / sum(x_.numel() for x_ in x)) + + +def _handle_unused_kwargs(solver, unused_kwargs): + if len(unused_kwargs) > 0: + warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs)) + + +def _select_initial_step(fun, t0, y0, order, rtol, atol, f0=None): + """Empirically select a good initial step. + + The algorithm is described in [1]_. + + Parameters + ---------- + fun : callable + Right-hand side of the system. + t0 : float + Initial value of the independent variable. + y0 : ndarray, shape (n,) + Initial value of the dependent variable. 
+ direction : float + Integration direction. + order : float + Method order. + rtol : float + Desired relative tolerance. + atol : float + Desired absolute tolerance. + + Returns + ------- + h_abs : float + Absolute value of the suggested initial step. + + References + ---------- + .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential + Equations I: Nonstiff Problems", Sec. II.4. + """ + t0 = t0.to(y0[0]) + if f0 is None: + f0 = fun(t0, y0) + + rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) + atol = atol if _is_iterable(atol) else [atol] * len(y0) + + scale = tuple(atol_ + torch.abs(y0_) * rtol_ for y0_, atol_, rtol_ in zip(y0, atol, rtol)) + + d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale)) + d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale)) + + if max(d0).item() < 1e-5 or max(d1).item() < 1e-5: + h0 = torch.tensor(1e-6).to(t0) + else: + h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1)) + + y1 = tuple(y0_ + h0 * f0_ for y0_, f0_ in zip(y0, f0)) + f1 = fun(t0 + h0, y1) + + d2 = tuple(_norm((f1_ - f0_) / scale_) / h0 for f1_, f0_, scale_ in zip(f1, f0, scale)) + + if max(d1).item() <= 1e-15 and max(d2).item() <= 1e-15: + h1 = torch.max(torch.tensor(1e-6).to(h0), h0 * 1e-3) + else: + h1 = (0.01 / max(d1 + d2))**(1. 
/ float(order + 1)) + + return torch.min(100 * h0, h1) + + +def _compute_error_ratio(error_estimate, error_tol=None, rtol=None, atol=None, y0=None, y1=None): + if error_tol is None: + assert rtol is not None and atol is not None and y0 is not None and y1 is not None + rtol if _is_iterable(rtol) else [rtol] * len(y0) + atol if _is_iterable(atol) else [atol] * len(y0) + error_tol = tuple( + atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_)) + for atol_, rtol_, y0_, y1_ in zip(atol, rtol, y0, y1) + ) + error_ratio = tuple(error_estimate_ / error_tol_ for error_estimate_, error_tol_ in zip(error_estimate, error_tol)) + mean_sq_error_ratio = tuple(torch.mean(error_ratio_ * error_ratio_) for error_ratio_ in error_ratio) + return mean_sq_error_ratio + + +def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5): + """Calculate the optimal size for the next step.""" + mean_error_ratio = max(mean_error_ratio) # Compute step size based on highest ratio. 
    # (continuation of _optimal_step_size)
    if mean_error_ratio == 0:
        # No error at all: grow the step by the maximum allowed factor.
        return last_step * ifactor
    if mean_error_ratio < 1:
        # Error already within tolerance: neutralize dfactor so the step is never shrunk.
        dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
    error_ratio = torch.sqrt(mean_error_ratio).to(last_step)
    exponent = torch.tensor(1 / order).to(last_step)
    # Clamp the growth/shrink factor to [1/ifactor, 1/dfactor] around the classic
    # safety * ratio^(1/order) controller.
    factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
    return last_step / factor


def _check_inputs(func, y0, t):
    # Normalize user inputs for odeint: wrap a plain-Tensor state into a 1-tuple
    # (wrapping func to match), and flip a decreasing time grid to increasing by
    # integrating the negated dynamics at negated times.
    tensor_input = False
    if torch.is_tensor(y0):
        tensor_input = True
        y0 = (y0,)
        _base_nontuple_func_ = func
        func = lambda t, y: (_base_nontuple_func_(t, y[0]),)
    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for y0_ in y0:
        assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(type(y0_))

    if _decreasing(t):
        t = -t
        _base_reverse_func = func
        func = lambda t, y: tuple(-f_ for f_ in _base_reverse_func(-t, y))

    # Integration in integer dtypes would silently truncate; insist on floats.
    for y0_ in y0:
        if not torch.is_floating_point(y0_):
            raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))

    return tensor_input, func, y0, t
diff --git a/ShapeID/DiffEqs/odeint.py b/ShapeID/DiffEqs/odeint.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5be84bb0e4de1f90f32ba17a70dc3fce48914d1
--- /dev/null
+++ b/ShapeID/DiffEqs/odeint.py
@@ -0,0 +1,75 @@
from ShapeID.DiffEqs.tsit5 import Tsit5Solver
from ShapeID.DiffEqs.dopri5 import Dopri5Solver
from ShapeID.DiffEqs.fixed_grid import Euler, Midpoint, RK4
from ShapeID.DiffEqs.fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from ShapeID.DiffEqs.adams import VariableCoefficientAdamsBashforth
from ShapeID.DiffEqs.misc import _check_inputs

# Registry mapping user-facing method names to solver classes.
SOLVERS = {
    'explicit_adams': AdamsBashforth,
    'fixed_adams': AdamsBashforthMoulton,
    'adams': VariableCoefficientAdamsBashforth,
    'tsit5': Tsit5Solver,
    'dopri5': Dopri5Solver,
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RK4,
}


def odeint(func, y0, t, dt, step_size = None, rtol = 1e-7, atol = 1e-9, method = None, options = None):
    """Integrate a system of ordinary differential equations.

    Solves the initial value problem for a non-stiff system of first order ODEs:
    ```
    dy/dt = func(t, y), y(t[0]) = y0
    ```
    where y is a Tensor of any shape.

    Output dtypes and numerical precision are based on the dtypes of the inputs `y0`.

    Args:
      func: Function that maps a Tensor holding the state `y` and a scalar Tensor
        `t` into a Tensor of state derivatives with respect to time.
      y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
        have any floating point or complex dtype.
      t: 1-D Tensor holding a sequence of time points for which to solve for
        `y`. The initial time point should be the first element of this sequence,
        and each time must be larger than the previous time. May have any floating
        point dtype.
      dt: step size forwarded to the solver constructor.
      step_size: unused; retained for backward compatibility of the signature.
      rtol: optional float64 Tensor specifying an upper bound on relative error,
        per element of `y`.
      atol: optional float64 Tensor specifying an upper bound on absolute error,
        per element of `y`.
      method: optional string indicating the integration method to use
        (a key of `SOLVERS`); defaults to 'dopri5'.
      options: optional dict of configuring options for the indicated integration
        method. Can only be provided if a `method` is explicitly set.

    Returns:
      y: Tensor, where the first dimension corresponds to different
        time points. Contains the solved value of y for each desired time point in
        `t`, with the initial value `y0` being the first element along the first
        dimension.

    Raises:
      ValueError: if an invalid `method` is provided.
      TypeError: if `t` or `y0` has an invalid dtype.
    """

    tensor_input, func, y0, t = _check_inputs(func, y0, t)

    # NOTE(review): truthiness test means an *empty* options dict with no method
    # slips through silently — confirm that is intended.
    if options and method is None:
        raise ValueError('cannot supply `options` without specifying `method`')

    if method is None:
        method = 'dopri5'

    #solver = SOLVERS[method](func, y0, rtol = rtol, atol = atol, **options)
    solver = SOLVERS[method](func, y0, rtol = rtol, atol = atol, dt = dt, options = options)
    solution = solver.integrate(t)

    if tensor_input:
        # Unwrap the 1-tuple added by _check_inputs for plain-Tensor inputs.
        solution = solution[0]
    return solution
diff --git a/ShapeID/DiffEqs/pde.py b/ShapeID/DiffEqs/pde.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4835f34c138de422d8b18aa9ccce3805c9634e1
--- /dev/null
+++ b/ShapeID/DiffEqs/pde.py
@@ -0,0 +1,643 @@
# ported from https://github.com/pvigier/perlin-numpy

import math

import numpy as np

import torch
import torch.nn as nn


def gradient_f(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction
    Upper-boundaries: Backward Difference
    Non-boundaries & Upper-boundaries: Forward Difference
    if X is batched: (n_batch, ...);
    else: (...)
def gradient_f(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction.
    Upper-boundaries: Backward Difference
    Non-boundaries & Lower-boundaries: Forward Difference
    if X is batched: (n_batch, ...); else: (...)
    Returns dX with a trailing axis of size dim for dim > 1 (same shape as X for dim == 1).
    '''
    device = X.device
    dim = len(X.size()) - 1 if batched else len(X.size())
    if dim == 1:
        dX = torch.zeros(X.size(), dtype = torch.float, device = device)
        # Move the batch axis last so spatial indexing below is uniform.
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[-1] = X[-1] - X[-2]    # Backward Difference at the upper boundary
        dX[:-1] = X[1:] - X[:-1]  # Forward Difference elsewhere

        dX = dX.permute(1, 0) if batched else dX
        dX /= delta_lst[0]
    elif dim == 2:
        dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX  # put batch to last dim
        dX[-1, :, 0] = X[-1, :] - X[-2, :]  # Backward Difference
        dX[:-1, :, 0] = X[1:] - X[:-1]      # Forward Difference

        dX[:, -1, 1] = X[:, -1] - X[:, -2]      # Backward Difference
        dX[:, :-1, 1] = X[:, 1:] - X[:, :-1]    # Forward Difference

        dX = dX.permute(3, 0, 1, 2) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
        dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :]  # Backward Difference
        dX[:-1, :, :, 0] = X[1:] - X[:-1]            # Forward Difference

        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]        # Backward Difference
        dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1]      # Forward Difference

        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]      # Backward Difference
        dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1]    # Forward Difference

        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX


def gradient_b(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction.
    Non-boundaries & Upper-boundaries: Backward Difference
    Lower-boundaries: Forward Difference
    if X is batched: (n_batch, ...); else: (...)
    '''
    device = X.device
    dim = len(X.size()) - 1 if batched else len(X.size())
    if dim == 1:
        dX = torch.zeros(X.size(), dtype = torch.float, device = device)
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[1:] = X[1:] - X[:-1]  # Backward Difference
        dX[0] = X[1] - X[0]      # Forward Difference

        dX = dX.permute(1, 0) if batched else dX
        dX /= delta_lst[0]
    elif dim == 2:
        dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX  # put batch to last dim
        dX[1:, :, 0] = X[1:, :] - X[:-1, :]  # Backward Difference
        dX[0, :, 0] = X[1] - X[0]            # Forward Difference

        dX[:, 1:, 1] = X[:, 1:] - X[:, :-1]  # Backward Difference
        dX[:, 0, 1] = X[:, 1] - X[:, 0]      # Forward Difference

        dX = dX.permute(3, 0, 1, 2) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
        dX[1:, :, :, 0] = X[1:, :, :] - X[:-1, :, :]  # Backward Difference
        dX[0, :, :, 0] = X[1] - X[0]                  # Forward Difference

        dX[:, 1:, :, 1] = X[:, 1:] - X[:, :-1]        # Backward Difference
        dX[:, 0, :, 1] = X[:, 1] - X[:, 0]            # Forward Difference

        dX[:, :, 1:, 2] = X[:, :, 1:] - X[:, :, :-1]  # Backward Difference
        dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]      # Forward Difference

        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX


def gradient_c(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction.
    Non-boundaries: Central Difference
    Upper-boundaries: Backward Difference
    Lower-boundaries: Forward Difference
    if X is batched: (n_batch, ...); else: (...)
    '''
    device = X.device
    dim = len(X.size()) - 1 if batched else len(X.size())
    if dim == 1:
        dX = torch.zeros(X.size(), dtype = torch.float, device = device)
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[1:-1] = (X[2:] - X[:-2]) / 2  # Central Difference
        dX[0] = X[1] - X[0]              # Forward Difference
        dX[-1] = X[-1] - X[-2]           # Backward Difference

        dX = dX.permute(1, 0) if batched else dX
        dX /= delta_lst[0]
    elif dim == 2:
        dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX  # put batch to last dim
        dX[1:-1, :, 0] = (X[2:, :] - X[:-2, :]) / 2
        dX[0, :, 0] = X[1] - X[0]
        dX[-1, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, 1] = (X[:, 2:] - X[:, :-2]) / 2
        dX[:, 0, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, 1] = X[:, -1] - X[:, -2]

        dX = dX.permute(3, 0, 1, 2) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
        dX[1:-1, :, :, 0] = (X[2:, :, :] - X[:-2, :, :]) / 2
        dX[0, :, :, 0] = X[1] - X[0]
        dX[-1, :, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, :, 1] = (X[:, 2:, :] - X[:, :-2, :]) / 2
        dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
        dX[:, :, 1:-1, 2] = (X[:, :, 2:] - X[:, :, :-2]) / 2
        dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]

        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX


def gradient_c_numpy(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a Numpy array "X" in each direction.
    Non-boundaries: Central Difference
    Upper-boundaries: Backward Difference
    Lower-boundaries: Forward Difference
    if X is batched: (n_batch, ...); else: (...)
    '''
    dim = len(X.shape) - 1 if batched else len(X.shape)
    if dim == 1:
        X = np.transpose(X, (1, 0)) if batched else X
        # BUG FIX: was `X.shapee`, which raises AttributeError.
        dX = np.zeros(X.shape).astype(float)
        dX[1:-1] = (X[2:] - X[:-2]) / 2  # Central Difference
        dX[0] = X[1] - X[0]              # Forward Difference
        dX[-1] = X[-1] - X[-2]           # Backward Difference

        # BUG FIX: was `np.transpose(X, (1, 0))`, returning the (transposed) input
        # instead of the computed gradient in the batched case.
        dX = np.transpose(dX, (1, 0)) if batched else dX
        dX /= delta_lst[0]
    elif dim == 2:
        dX = np.zeros(X.shape + tuple([2])).astype(float)
        X = np.transpose(X, (1, 2, 0)) if batched else X
        dX = np.transpose(dX, (1, 2, 3, 0)) if batched else dX  # put batch to last dim
        dX[1:-1, :, 0] = (X[2:, :] - X[:-2, :]) / 2
        dX[0, :, 0] = X[1] - X[0]
        dX[-1, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, 1] = (X[:, 2:] - X[:, :-2]) / 2
        dX[:, 0, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, 1] = X[:, -1] - X[:, -2]

        dX = np.transpose(dX, (3, 0, 1, 2)) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = np.zeros(X.shape + tuple([3])).astype(float)
        X = np.transpose(X, (1, 2, 3, 0)) if batched else X
        dX = np.transpose(dX, (1, 2, 3, 4, 0)) if batched else dX  # put batch to last dim
        dX[1:-1, :, :, 0] = (X[2:, :, :] - X[:-2, :, :]) / 2
        dX[0, :, :, 0] = X[1] - X[0]
        dX[-1, :, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, :, 1] = (X[:, 2:, :] - X[:, :-2, :]) / 2
        dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
        dX[:, :, 1:-1, 2] = (X[:, :, 2:] - X[:, :, :-2]) / 2
        dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]

        dX = np.transpose(dX, (4, 0, 1, 2, 3)) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX


def gradient_f_numpy(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a Numpy array "X" in each direction.
    Upper-boundaries: Backward Difference
    Non-boundaries & Lower-boundaries: Forward Difference
    if X is batched: (n_batch, ...); else: (...)
    '''
    dim = len(X.shape) - 1 if batched else len(X.shape)
    if dim == 1:
        X = np.transpose(X, (1, 0)) if batched else X
        # BUG FIX: was `X.shapee`, which raises AttributeError.
        dX = np.zeros(X.shape).astype(float)
        dX[-1] = X[-1] - X[-2]    # Backward Difference
        dX[:-1] = X[1:] - X[:-1]  # Forward Difference

        # BUG FIX: was `np.transpose(X, (1, 0))` — returned the input, not the gradient.
        dX = np.transpose(dX, (1, 0)) if batched else dX
        dX /= delta_lst[0]
    elif dim == 2:
        dX = np.zeros(X.shape + tuple([2])).astype(float)
        X = np.transpose(X, (1, 2, 0)) if batched else X
        dX = np.transpose(dX, (1, 2, 3, 0)) if batched else dX  # put batch to last dim
        dX[-1, :, 0] = X[-1, :] - X[-2, :]  # Backward Difference
        dX[:-1, :, 0] = X[1:] - X[:-1]      # Forward Difference

        dX[:, -1, 1] = X[:, -1] - X[:, -2]    # Backward Difference
        dX[:, :-1, 1] = X[:, 1:] - X[:, :-1]  # Forward Difference

        dX = np.transpose(dX, (3, 0, 1, 2)) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = np.zeros(X.shape + tuple([3])).astype(float)
        X = np.transpose(X, (1, 2, 3, 0)) if batched else X
        dX = np.transpose(dX, (1, 2, 3, 4, 0)) if batched else dX  # put batch to last dim
        dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :]  # Backward Difference
        dX[:-1, :, :, 0] = X[1:] - X[:-1]            # Forward Difference

        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]        # Backward Difference
        dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1]      # Forward Difference

        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]      # Backward Difference
        dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1]    # Forward Difference

        dX = np.transpose(dX, (4, 0, 1, 2, 3)) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX
# (the delta_lst scaling tail and return of gradient_f_numpy are reconstructed with
#  its definition above)


class Upwind(object):
    '''
    Direction-aware (upwind) first differences of a field U:
    backward difference where the flow/gradient term is > 0, forward difference otherwise.
    '''
    def __init__(self, U, data_spacing = [1., 1, 1.], batched = True):
        self.U = U  # (s, r, c)
        self.batched = batched
        self.data_spacing = data_spacing
        self.dim = len(self.U.size()) - 1 if batched else len(self.U.size())
        # All-ones mask used to select the forward-difference branch below.
        self.I = torch.ones(self.U.size(), dtype = torch.float, device = U.device)

    def dX(self, FGx):
        # Blend forward/backward x-differences per element based on the sign of FGx.
        dXf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 0]
        dXb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 0]
        Xflag = (FGx > 0).float()
        return dXf * (self.I - Xflag) + dXb * Xflag

    def dY(self, FGy):
        dYf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 1]
        dYb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 1]
        Yflag = (FGy > 0).float()
        return dYf * (self.I - Yflag) + dYb * Yflag

    def dZ(self, FGz):
        dZf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 2]
        dZb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 2]
        Zflag = (FGz > 0).float()
        return dZf * (self.I - Zflag) + dZb * Zflag


class AdvDiffPartial(nn.Module):
    '''Spatial partial-derivative terms of an advection-diffusion PDE on a regular grid.'''
    def __init__(self, data_spacing, device):
        super(AdvDiffPartial, self).__init__()
        self.dimension = len(data_spacing)  # (slc, row, col)
        self.device = device
        self.data_spacing = data_spacing

    @property
    def Grad_Ds(self):
        # Dispatch table: diffusivity parameterization -> diffusion-term evaluator.
        return {
            'constant': self.Grad_constantD,
            'scalar': self.Grad_scalarD,
            'diag': self.Grad_diagD,
            'full': self.Grad_fullD,
            'full_dual': self.Grad_fullD,
            'full_spectral': self.Grad_fullD,
            'full_cholesky': self.Grad_fullD,
            'full_symmetric': self.Grad_fullD
        }

    @property
    def Grad_Vs(self):
        # Dispatch table: velocity parameterization -> advection-term evaluator.
        return {
            'constant': self.Grad_constantV,
            'scalar': self.Grad_scalarV,
            'vector': self.Grad_vectorV,  # For general V w/o div-free TODO self.Grad_vectorV
            'vector_div_free': self.Grad_div_free_vectorV,
            'vector_div_free_clebsch': self.Grad_div_free_vectorV,
            'vector_div_free_stream': self.Grad_div_free_vectorV,
            'vector_div_free_stream_gauge': self.Grad_div_free_vectorV,
        }

    def Grad_constantD(self, C, Dlst):
        # D * Laplacian(C), with a single scalar diffusivity Dlst['D'].
        if self.dimension == 1:
            return Dlst['D'] * (self.ddXc(C))
        elif self.dimension == 2:
            return Dlst['D'] * (self.ddXc(C) + self.ddYc(C))
        elif self.dimension == 3:
            return Dlst['D'] * (self.ddXc(C) + self.ddYc(C) + self.ddZc(C))

    def Grad_constant_tensorD(self, C, Dlst):
        # div(D grad C) for a constant (spatially uniform) tensor diffusivity.
        if self.dimension == 1:
            raise NotImplementedError
        elif self.dimension == 2:
            # NOTE: the original also computed dC_c = self.dc(C) here, which was
            # never used — removed to avoid a wasted full central-gradient pass.
            dC_f = self.df(C)
            return Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
                Dlst['Dxy'] * self.dXb(dC_f[..., 1]) + Dlst['Dxy'] * self.dYb(dC_f[..., 0]) +\
                Dlst['Dyy'] * self.dYb(dC_f[..., 1])
        elif self.dimension == 3:
            dC_f = self.df(C)
            return Dlst['Dxx'] * self.dXb(dC_f[..., 0]) + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) + Dlst['Dzz'] * self.dZb(dC_f[..., 2]) + \
                Dlst['Dxy'] * (self.dXb(dC_f[..., 1]) + self.dYb(dC_f[..., 0])) + \
                Dlst['Dyz'] * (self.dYb(dC_f[..., 2]) + self.dZb(dC_f[..., 1])) + \
                Dlst['Dxz'] * (self.dZb(dC_f[..., 0]) + self.dXb(dC_f[..., 2]))

    def Grad_scalarD(self, C, Dlst):  # batch_C: (batch_size, (slc), row, col)
        # Expanded version: \nabla (D \nabla C) => \nabla D \cdot \nabla C (part (a)) + D \Delta C (part (b)) #
        # NOTE: Work better than Central Differences !!! #
        # Nested Forward-Backward Difference Scheme in part (b) #
        if self.dimension == 1:
            dC = gradient_c(C, batched = True, delta_lst = self.data_spacing)
            return gradient_c(Dlst['D'], batched = True, delta_lst = self.data_spacing) * dC + \
                Dlst['D'] * gradient_c(dC, batched = True, delta_lst = self.data_spacing)
        else:  # (dimension = 2 or 3)
            dC_c = gradient_c(C, batched = True, delta_lst = self.data_spacing)
            dC_f = gradient_f(C, batched = True, delta_lst = self.data_spacing)
            dD_c = gradient_c(Dlst['D'], batched = True, delta_lst = self.data_spacing)
            out = (dD_c * dC_c).sum(-1)
            for dim in range(dC_f.size(-1)):
                out += Dlst['D'] * gradient_b(dC_f[..., dim], batched = True, delta_lst = self.data_spacing)[..., dim]
            return out

    def Grad_diagD(self, C, Dlst):
        # Expanded version for a diagonal diffusion tensor (Dxx, Dyy[, Dzz]). #
        if self.dimension == 1:
            raise NotImplementedError('diag_D is not supported for 1D version of diffusivity')
        elif self.dimension == 2:
            dC_c = self.dc(C)
            dC_f = self.df(C)
            return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
                self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1])
        elif self.dimension == 3:
            dC_c = self.dc(C)
            dC_f = self.df(C)
            return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
                self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) +\
                self.dZc(Dlst['Dzz']) * dC_c[..., 2] + Dlst['Dzz'] * self.dZb(dC_f[..., 2])

    def Grad_fullD(self, C, Dlst):
        # Expanded version for a full symmetric diffusion tensor. #
        '''https://github.com/uncbiag/PIANOinD/blob/master/Doc/PIANOinD.pdf'''
        if self.dimension == 1:
            raise NotImplementedError('full_D is not supported for 1D version of diffusivity')
        elif self.dimension == 2:
            dC_c = self.dc(C)
            dC_f = self.df(C)
            return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
                self.dXc(Dlst['Dxy']) * dC_c[..., 1] + Dlst['Dxy'] * self.dXb(dC_f[..., 1]) +\
                self.dYc(Dlst['Dxy']) * dC_c[..., 0] + Dlst['Dxy'] * self.dYb(dC_f[..., 0]) +\
                self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1])
        elif self.dimension == 3:
            dC_c = self.dc(C)
            dC_f = self.df(C)
            return (self.dXc(Dlst['Dxx']) + self.dYc(Dlst['Dxy']) + self.dZc(Dlst['Dxz'])) * dC_c[..., 0] + \
                (self.dXc(Dlst['Dxy']) + self.dYc(Dlst['Dyy']) + self.dZc(Dlst['Dyz'])) * dC_c[..., 1] + \
                (self.dXc(Dlst['Dxz']) + self.dYc(Dlst['Dyz']) + self.dZc(Dlst['Dzz'])) * dC_c[..., 2] + \
                Dlst['Dxx'] * self.dXb(dC_f[..., 0]) + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) + Dlst['Dzz'] * self.dZb(dC_f[..., 2]) + \
                Dlst['Dxy'] * (self.dXb(dC_f[..., 1]) + self.dYb(dC_f[..., 0])) + \
                Dlst['Dyz'] * (self.dYb(dC_f[..., 2]) + self.dZb(dC_f[..., 1])) + \
                Dlst['Dxz'] * (self.dZb(dC_f[..., 0]) + self.dXb(dC_f[..., 2]))

    def Grad_constantV(self, C, Vlst):
        # Advection by a single constant speed: upwind by the sign of V.
        if len(Vlst['V'].size()) == 1:
            if self.dimension == 1:
                return - Vlst['V'] * self.dXb(C) if Vlst['V'] > 0 else - Vlst['V'] * self.dXf(C)
            elif self.dimension == 2:
                return - Vlst['V'] * (self.dXb(C) + self.dYb(C)) if Vlst['V'] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C))
            elif self.dimension == 3:
                return - Vlst['V'] * (self.dXb(C) + self.dYb(C) + self.dZb(C)) if Vlst['V'] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C) + self.dZf(C))
        else:
            # Spatially shaped but uniform V: sample one entry for the sign test.
            if self.dimension == 1:
                return - Vlst['V'] * self.dXb(C) if Vlst['V'][0, 0] > 0 else - Vlst['V'] * self.dXf(C)
            elif self.dimension == 2:
                return - Vlst['V'] * (self.dXb(C) + self.dYb(C)) if Vlst['V'][0, 0, 0] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C))
            elif self.dimension == 3:
                return - Vlst['V'] * (self.dXb(C) + self.dYb(C) + self.dZb(C)) if Vlst['V'][0, 0, 0, 0] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C) + self.dZf(C))

    def Grad_constant_vectorV(self, C, Vlst):
        # Advection by a constant vector field (per-component upwinding).
        if self.dimension == 1:
            raise NotImplementedError
        elif self.dimension == 2:
            out_x = - Vlst['Vx'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vx'][0, 0, 0] > 0 else - Vlst['Vx'] * (self.dXf(C) + self.dYf(C))
            out_y = - Vlst['Vy'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vy'][0, 0, 0] > 0 else - Vlst['Vy'] * (self.dXf(C) + self.dYf(C))
            return out_x + out_y
        elif self.dimension == 3:
            out_x = - Vlst['Vx'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vx'][0, 0, 0] > 0 else - Vlst['Vx'] * (self.dXf(C) + self.dYf(C))
            out_y = - Vlst['Vy'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vy'][0, 0, 0] > 0 else - Vlst['Vy'] * (self.dXf(C) + self.dYf(C))
            out_z = - Vlst['Vz'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vz'][0, 0, 0] > 0 else - Vlst['Vz'] * (self.dXf(C) + self.dYf(C))
            return out_x + out_y + out_z

    def Grad_SimscalarV(self, C, Vlst):
        # Advection -V . grad(C) with upwind differences, no divergence term.
        V = Vlst['V']
        Upwind_C = Upwind(C, self.data_spacing)
        if self.dimension == 1:
            C_x = Upwind_C.dX(V)
            return - V * C_x
        if self.dimension == 2:
            C_x, C_y = Upwind_C.dX(V), Upwind_C.dY(V)
            return - V * (C_x + C_y)
        if self.dimension == 3:
            C_x, C_y, C_z = Upwind_C.dX(V), Upwind_C.dY(V), Upwind_C.dZ(V)
            return - V * (C_x + C_y + C_z)

    def Grad_scalarV(self, C, Vlst):
        # Conservative form: -V . grad(C) - C * div(V), scalar V per direction.
        V = Vlst['V']
        Upwind_C = Upwind(C, self.data_spacing)
        dV = gradient_c(V, batched = True, delta_lst = self.data_spacing)
        if self.dimension == 1:
            C_x = Upwind_C.dX(V)
            return - V * C_x - C * dV
        elif self.dimension == 2:
            C_x, C_y = Upwind_C.dX(V), Upwind_C.dY(V)
            return - V * (C_x + C_y) - C * dV.sum(-1)
        elif self.dimension == 3:
            C_x, C_y, C_z = Upwind_C.dX(V), Upwind_C.dY(V), Upwind_C.dZ(V)
            return - V * (C_x + C_y + C_z) - C * dV.sum(-1)

    def Grad_div_free_vectorV(self, C, Vlst):
        ''' For divergence-free-by-definition velocity'''
        if self.dimension == 1:
            raise NotImplementedError('clebschVector is not supported for 1D version of velocity')
        Upwind_C = Upwind(C, self.data_spacing)
        C_x, C_y = Upwind_C.dX(Vlst['Vx']), Upwind_C.dY(Vlst['Vy'])
        if self.dimension == 2:
            return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y)
        elif self.dimension == 3:
            C_z = Upwind_C.dZ(Vlst['Vz'])
            return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y + Vlst['Vz'] * C_z)

    def Grad_vectorV(self, C, Vlst):
        ''' For general velocity'''
        if self.dimension == 1:
            raise NotImplementedError('vector is not supported for 1D version of velocity')
        Upwind_C = Upwind(C, self.data_spacing)
        C_x, C_y = Upwind_C.dX(Vlst['Vx']), Upwind_C.dY(Vlst['Vy'])
        Vx_x = self.dXc(Vlst['Vx'])
        Vy_y = self.dYc(Vlst['Vy'])
        if self.dimension == 2:
            return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y) - C * (Vx_x + Vy_y)
        if self.dimension == 3:
            C_z = Upwind_C.dZ(Vlst['Vz'])
            Vz_z = self.dZc(Vlst['Vz'])
            return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y + Vlst['Vz'] * C_z) - C * (Vx_x + Vy_y + Vz_z)

    ################# Utilities #################
    def db(self, X):
        return gradient_b(X, batched = True, delta_lst = self.data_spacing)
    def df(self, X):
        return gradient_f(X, batched = True, delta_lst = self.data_spacing)
    def dc(self, X):
        return gradient_c(X, batched = True, delta_lst = self.data_spacing)
    def dXb(self, X):
        return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 0]
    def dXf(self, X):
        return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 0]
    def dXc(self, X):
        return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 0]
    def dYb(self, X):
        return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 1]
    def dYf(self, X):
        return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 1]
    def dYc(self, X):
        return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 1]
    def dZb(self, X):
        return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 2]
    def dZf(self, X):
        return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 2]
    def dZc(self, X):
        return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 2]
    def ddXc(self, X):
        # Second derivative: backward difference of the forward difference.
        return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 0],
                          batched = True, delta_lst = self.data_spacing)[..., 0]
    def ddYc(self, X):
        return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 1],
                          batched = True, delta_lst = self.data_spacing)[..., 1]
    def ddZc(self, X):
        return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 2],
                          batched = True, delta_lst = self.data_spacing)[..., 2]


class AdvDiffPDE(nn.Module):
    '''
    Plain advection-diffusion PDE solver for pre-set V_lst and D_lst (1D, 2D, 3D) for forward time series simulation
    '''
    # NOTE(review): V_dict/D_dict use mutable {} defaults — fine if they are only
    # read, but confirm no caller mutates them.
    def __init__(self, data_spacing, perf_pattern, D_type='scalar', V_type='vector', BC=None, dt=0.1, V_dict={}, D_dict={}, stochastic=False, device='cpu'):
        super(AdvDiffPDE, self).__init__()
        self.BC = BC
        self.dt = dt
        self.dimension = len(data_spacing)
        self.perf_pattern = perf_pattern
        self.partials = AdvDiffPartial(data_spacing, device)
        self.D_type, self.V_type = D_type, V_type
        self.stochastic = stochastic
        self.V_dict, self.D_dict = V_dict, D_dict
        self.Sigma, self.Sigma_V, self.Sigma_D = 0., 0., 0.
# Only for initialization # + if self.dimension == 1: + self.neumann_BC = torch.nn.ReplicationPad1d(1) + elif self.dimension == 2: + self.neumann_BC = torch.nn.ReplicationPad2d(1) + elif self.dimension == 3: + self.neumann_BC = torch.nn.ReplicationPad3d(1) + else: + raise ValueError('Unsupported dimension: %d' % self.dimension) + + @property + def set_BC(self): + # NOTE For bondary condition of mass concentration # + '''X: (n_batch, spatial_shape)''' + if self.BC == 'neumann' or self.BC == 'cauchy': + if self.dimension == 1: + return lambda X: self.neumann_BC(X[:, 1:-1].unsqueeze(dim=1))[:,0] + elif self.dimension == 2: + return lambda X: self.neumann_BC(X[:, 1:-1, 1:-1].unsqueeze(dim=1))[:,0] + elif self.dimension == 3: + return lambda X: self.neumann_BC(X[:, 1:-1, 1:-1, 1:-1].unsqueeze(dim=1))[:,0] + else: + raise NotImplementedError('Unsupported B.C.!') + elif self.BC == 'dirichlet_neumann' or self.BC == 'source_neumann': + ctrl_wdth = 1 + if self.dimension == 1: + self.dirichlet_BC = torch.nn.ReplicationPad1d(ctrl_wdth) + return lambda X: self.dirichlet_BC(X[:, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0] + elif self.dimension == 2: + self.dirichlet_BC = torch.nn.ReplicationPad2d(ctrl_wdth) + return lambda X: self.dirichlet_BC(X[:, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0] + elif self.dimension == 3: + self.dirichlet_BC = torch.nn.ReplicationPad3d(ctrl_wdth) + return lambda X: self.neumann_dirichlet_BCBC(X[:, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0] + else: + raise NotImplementedError('Unsupported B.C.!') + else: + return lambda X: X + + def forward(self, t, batch_C): + ''' + t: (batch_size,) + batch_C: (batch_size, (slc,) row, col) + ''' + batch_size = batch_C.size(0) + batch_C = self.set_BC(batch_C) + if 'diff' not in self.perf_pattern: + out = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict) + if self.stochastic: + out = out + self.Sigma * math.sqrt(self.dt) * 
torch.randn_like(batch_C).to(batch_C) + elif 'adv' not in self.perf_pattern: + out = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict) + if self.stochastic: + out = out + self.Sigma * math.sqrt(self.dt) * torch.randn_like(batch_C).to(batch_C) + else: + if self.stochastic: + out_D = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict) + out_V = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict) + out = out_D + out_V + self.Sigma * math.sqrt(self.dt) * torch.randn_like(batch_C).to(batch_C) + else: + out_V = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict) + out_D = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict) + out = out_V + out_D + return out + + + diff --git a/ShapeID/DiffEqs/rk_common.py b/ShapeID/DiffEqs/rk_common.py new file mode 100644 index 0000000000000000000000000000000000000000..f77e9547c51224408b3a6e7078755ece61f0fe76 --- /dev/null +++ b/ShapeID/DiffEqs/rk_common.py @@ -0,0 +1,78 @@ +# Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate +import collections +from ShapeID.DiffEqs.misc import _scaled_dot_product, _convert_to_tensor + +_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha beta c_sol c_error') + + +class _RungeKuttaState(collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')): + """Saved state of the Runge Kutta solver. + + Attributes: + y1: Tensor giving the function value at the end of the last time step. + f1: Tensor giving derivative at the end of the last time step. + t0: scalar float64 Tensor giving start of the last time step. + t1: scalar float64 Tensor giving end of the last time step. + dt: scalar float64 Tensor giving the size for the next time step. + interp_coef: list of Tensors giving coefficients for polynomial + interpolation between `t0` and `t1`. + """ + + +def _runge_kutta_step(func, y0, f0, t0, dt, tableau): + """Take an arbitrary Runge-Kutta step and estimate error. 
# Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate
_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha beta c_sol c_error')


class _RungeKuttaState(collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
    """Saved state of the Runge-Kutta solver.

    Attributes:
        y1: state at the end of the last accepted step.
        f1: derivative at the end of the last accepted step.
        t0, t1: start / end time of the last step (float64 scalars).
        dt: proposed size of the next step.
        interp_coeff: per-component stage lists used for dense interpolation
            between ``t0`` and ``t1``.
    """


def _runge_kutta_step(func, y0, f0, t0, dt, tableau):
    """Take one explicit Runge-Kutta step described by a Butcher tableau.

    Args:
        func: callable(t, y) -> dy/dt, where ``y`` is a tuple of tensors.
        y0: tuple of tensors, state at ``t0``.
        f0: tuple of tensors, ``func(t0, y0)`` pre-computed by the caller.
        t0: scalar start time of the step.
        dt: scalar step size.
        tableau: ``_ButcherTableau`` holding ``alpha``, ``beta``, ``c_sol``,
            ``c_error`` coefficients.

    Returns:
        ``(y1, f1, y1_error, k)``: state, derivative and error estimate at
        ``t1 = t0 + dt``, plus the list of stage derivatives ``k``.
    """
    dtype = y0[0].dtype
    device = y0[0].device
    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    dt = _convert_to_tensor(dt, dtype=dtype, device=device)

    # One growing list of stage derivatives per state component, seeded with f0.
    k = tuple([f0_] for f0_ in f0)
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
        ti = t0 + alpha_i * dt
        yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
        for k_, f_ in zip(k, func(ti, yi)):
            k_.append(f_)

    if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
        # The FSAL property (true for Dormand-Prince) makes the final stage
        # coincide with the solution; otherwise combine all stages explicitly.
        yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))

    y1 = yi
    f1 = tuple(k_[-1] for k_ in k)
    y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
    return (y1, f1, y1_error, k)


def rk4_step_func(func, t, dt, y, k1=None):
    """Classic fourth-order Runge-Kutta increment for a single step of size dt."""
    if k1 is None:
        k1 = func(t, y)
    k2 = func(t + dt / 2, tuple(y_ + dt * k1_ / 2 for y_, k1_ in zip(y, k1)))
    k3 = func(t + dt / 2, tuple(y_ + dt * k2_ / 2 for y_, k2_ in zip(y, k2)))
    k4 = func(t + dt, tuple(y_ + dt * k3_ for y_, k3_ in zip(y, k3)))
    return tuple((k1_ + 2 * k2_ + 2 * k3_ + k4_) * (dt / 6)
                 for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))


def rk4_alt_step_func(func, t, dt, y, k1=None):
    """Kutta's 3/8-rule variant: smaller error with slightly more compute."""
    if k1 is None:
        k1 = func(t, y)
    k2 = func(t + dt / 3, tuple(y_ + dt * k1_ / 3 for y_, k1_ in zip(y, k1)))
    k3 = func(t + dt * 2 / 3, tuple(y_ + dt * (k1_ / -3 + k2_) for y_, k1_, k2_ in zip(y, k1, k2)))
    k4 = func(t + dt, tuple(y_ + dt * (k1_ - k2_ + k3_) for y_, k1_, k2_, k3_ in zip(y, k1, k2, k3)))
    return tuple((k1_ + 3 * k2_ + 3 * k3_ + k4_) * (dt / 8)
                 for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))


def set_BC_2D(X, BCs):
    """Overwrite the boundary strips of a 2-D field with prescribed values.

    X:   (n_batch, r, c), modified in place and returned.
    BCs: (n_batch, 4, BC_size, data_dim) ordered [row-low, row-high, col-low, col-high].
    """
    width = BCs.size(2)
    X[:, : width] = BCs[:, 0]
    X[:, - width :] = BCs[:, 1]
    X[:, :, : width] = BCs[:, 2].permute(0, 2, 1)   # (batch, BC, r) -> (batch, r, BC)
    X[:, :, - width :] = BCs[:, 3].permute(0, 2, 1)
    return X


def set_BC_3D(X, BCs):
    """Overwrite the boundary slabs of a 3-D field with prescribed values.

    X:   (n_batch, s, r, c), modified in place and returned.
    BCs: (n_batch, 6, BC_size, dim_a, dim_b), one slab per face.
    """
    width = BCs.size(2)
    X[:, : width] = BCs[:, 0]
    X[:, - width :] = BCs[:, 1]
    X[:, :, : width] = BCs[:, 2].permute(0, 2, 1, 3)    # (batch, BC, s, c) -> (batch, s, BC, c)
    X[:, :, - width :] = BCs[:, 3].permute(0, 2, 1, 3)
    X[:, :, :, : width] = BCs[:, 4].permute(0, 2, 3, 1)  # (batch, BC, s, r) -> (batch, s, r, BC)
    X[:, :, :, - width :] = BCs[:, 5].permute(0, 2, 3, 1)
    return X
''' X[t] = X[t] + dBC[t] (dBC[t] = BC[t+1] - BC[t]) '''
def add_dBC_2D(X, dBCs):
    """Accumulate boundary-condition increments onto the edges of a 2-D field.

    X:    (n_batch, r, c), modified in place and returned.
    dBCs: (n_batch, 4, BC_size, data_dim) ordered [row-low, row-high, col-low, col-high].
    """
    width = dBCs.size(2)
    X[:, : width] += dBCs[:, 0]
    X[:, - width :] += dBCs[:, 1]
    X[:, :, : width] += dBCs[:, 2].permute(0, 2, 1)   # (batch, BC, r) -> (batch, r, BC)
    X[:, :, - width :] += dBCs[:, 3].permute(0, 2, 1)
    return X


def add_dBC_3D(X, dBCs):
    """Accumulate boundary-condition increments onto the faces of a 3-D field.

    X:    (n_batch, s, r, c), modified in place and returned.
    dBCs: (n_batch, 6, BC_size, dim_a, dim_b), one slab per face.
    """
    width = dBCs.size(2)
    X[:, : width] += dBCs[:, 0]
    X[:, - width :] += dBCs[:, 1]
    X[:, :, : width] += dBCs[:, 2].permute(0, 2, 1, 3)    # (batch, BC, s, c) -> (batch, s, BC, c)
    X[:, :, - width :] += dBCs[:, 3].permute(0, 2, 1, 3)
    X[:, :, :, : width] += dBCs[:, 4].permute(0, 2, 3, 1)  # (batch, BC, s, r) -> (batch, s, r, BC)
    X[:, :, :, - width :] += dBCs[:, 5].permute(0, 2, 3, 1)
    return X


class AdaptiveStepsizeODESolver(object):
    """Base class for adaptive solvers: subclasses provide ``advance``."""
    __metaclass__ = abc.ABCMeta

    def __init__(self, func, y0, atol, rtol, options=None):
        self.func = func
        self.y0 = y0
        self.atol = atol
        self.rtol = rtol

    def before_integrate(self, t):
        # Hook for subclasses (e.g. initial-step selection); default no-op.
        pass

    @abc.abstractmethod
    def advance(self, next_t):
        raise NotImplementedError

    def integrate(self, t):
        """Integrate from t[0] through each requested time; returns stacked states."""
        _assert_increasing(t)
        solution = [self.y0]
        t = t.to(self.y0[0].device, torch.float64)
        self.before_integrate(t)
        for i in range(1, len(t)):
            solution.append(self.advance(t[i]))
        return tuple(map(torch.stack, tuple(zip(*solution))))


class FixedGridODESolver(object):
    """Base class for fixed-grid solvers; subclasses provide ``step_func``/``order``."""
    __metaclass__ = abc.ABCMeta

    def __init__(self, func, y0, step_size=None, grid_constructor=None, atol=None, rtol=None, dt=None, options=None):
        self.func = func
        self.y0 = y0

        if step_size is not None and grid_constructor is None:
            self.grid_constructor = self._grid_constructor_from_step_size(step_size)
        elif grid_constructor is None:
            # Integrate on exactly the requested output times.
            self.grid_constructor = lambda f, y0, t: t
        else:
            raise ValueError("step_size and grid_constructor are exclusive arguments.")

    def _grid_constructor_from_step_size(self, step_size):
        # Build a uniform grid of spacing ``step_size`` covering [t[0], t[-1]].
        def _grid_constructor(func, y0, t):
            start_time = t[0]
            end_time = t[-1]
            niters = torch.ceil((end_time - start_time) / step_size + 1).item()
            t_infer = torch.arange(0, niters).to(t) * step_size + start_time
            if t_infer[-1] > t[-1]:
                t_infer[-1] = t[-1]
            return t_infer
        return _grid_constructor

    @property
    @abc.abstractmethod
    def order(self):
        pass

    @abc.abstractmethod
    def step_func(self, func, t, dt, y):
        pass

    def integrate(self, t):
        """Step across the constructed grid, emitting linearly-interpolated
        values at each requested output time. Returns stacked states."""
        _assert_increasing(t)
        t = t.type_as(self.y0[0])  # (n_time,)
        time_grid = self.grid_constructor(self.func, self.y0, t)
        assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
        time_grid = time_grid.to(self.y0[0])

        solution = [self.y0]
        j = 1
        y0 = self.y0
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self.step_func(self.func, t0, t1 - t0, y0)
            y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))
            y0 = y1
            while j < len(t) and t1 >= t[j]:
                solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                j += 1
        return tuple(map(torch.stack, tuple(zip(*solution))))

    def _linear_interp(self, t0, t1, y0, y1, t):
        # Linear interpolation of the tuple-state between (t0, y0) and (t1, y1).
        if t == t0:
            return y0
        if t == t1:
            return y1
        t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0])
        slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1))
        return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope))
# Parameters from Tsitouras (2011).
_TSITOURAS_TABLEAU = _ButcherTableau(
    alpha=[0.161, 0.327, 0.9, 0.9800255409045097, 1., 1.],
    beta=[
        [0.161],
        [-0.008480655492357, 0.3354806554923570],
        [2.897153057105494, -6.359448489975075, 4.362295432869581],
        [5.32586482843925895, -11.74888356406283, 7.495539342889836, -0.09249506636175525],
        [5.86145544294642038, -12.92096931784711, 8.159367898576159, -0.071584973281401006, -0.02826905039406838],
        [0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774],
    ],
    c_sol=[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774, 0],
    c_error=[
        0.09646076681806523 - 0.001780011052226,
        0.01 - 0.000816434459657,
        0.4798896504144996 - -0.007880878010262,
        1.379008574103742 - 0.144711007173263,
        -3.290069515436081 - -0.582357165452555,
        2.324710524099774 - 0.458082105929187,
        -1 / 66,
    ],
)


def _interp_coeff_tsit5(t0, dt, eval_t):
    """Dense-output polynomial coefficients for Tsit5 at normalized time
    t = (eval_t - t0) / dt in [0, 1]."""
    t = float((eval_t - t0) / dt)
    b1 = -1.0530884977290216 * t * (t - 1.3299890189751412) * (t**2 - 1.4364028541716351 * t + 0.7139816917074209)
    b2 = 0.1017 * t**2 * (t**2 - 2.1966568338249754 * t + 1.2949852507374631)
    b3 = 2.490627285651252793 * t**2 * (t**2 - 2.38535645472061657 * t + 1.57803468208092486)
    b4 = -16.54810288924490272 * (t - 1.21712927295533244) * (t - 0.61620406037800089) * t**2
    b5 = 47.37952196281928122 * (t - 1.203071208372362603) * (t - 0.658047292653547382) * t**2
    b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2
    b7 = 2.5 * (t - 1) * (t - 0.6) * t**2
    return [b1, b2, b3, b4, b5, b6, b7]


def _interp_eval_tsit5(t0, t1, k, eval_t):
    """Evaluate the Tsit5 dense-output interpolant at ``eval_t`` in [t0, t1]."""
    dt = t1 - t0
    y0 = tuple(k_[0] for k_ in k)
    interp_coeff = _interp_coeff_tsit5(t0, dt, eval_t)
    y_t = tuple(y0_ + _scaled_dot_product(dt, interp_coeff, k_) for y0_, k_ in zip(y0, k))
    return y_t


def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
    """Calculate the optimal size for the next Runge-Kutta step.

    Standard PI-free controller: grow by at most ``ifactor`` when the error is
    zero, never shrink below ``dfactor`` of the last step, and never grow when
    the error ratio already exceeds 1.
    """
    if mean_error_ratio == 0:
        return last_step * ifactor
    if mean_error_ratio < 1:
        # Step was accepted comfortably: do not shrink.
        dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
    error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step)
    exponent = torch.tensor(1 / order).type_as(last_step)
    factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
    return last_step / factor


def _abs_square(x):
    """Elementwise |x|^2."""
    return torch.mul(x, x)


class Tsit5Solver(AdaptiveStepsizeODESolver):
    """Adaptive 5(4) Runge-Kutta solver with the Tsitouras (2011) tableau."""

    def __init__(
        self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
        **unused_kwargs
    ):
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs

        self.func = func
        self.y0 = y0
        self.rtol = rtol
        self.atol = atol
        self.first_step = first_step
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
        self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)

    def before_integrate(self, t):
        """Initialize the Runge-Kutta state, selecting the initial step size."""
        if self.first_step is None:
            first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol).to(t)
        else:
            # BUG FIX: the original hard-coded 0.01 here, silently ignoring the
            # user-supplied ``first_step`` argument.
            first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
        self.rk_state = _RungeKuttaState(
            self.y0,
            self.func(t[0].type_as(self.y0[0]), self.y0), t[0], t[0], first_step,
            tuple(map(lambda x: [x] * 7, self.y0))
        )

    def advance(self, next_t):
        """Interpolate through the next time point, integrating as necessary."""
        n_steps = 0
        while next_t > self.rk_state.t1:
            assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
            self.rk_state = self._adaptive_tsit5_step(self.rk_state)
            n_steps += 1
        return _interp_eval_tsit5(self.rk_state.t0, self.rk_state.t1, self.rk_state.interp_coeff, next_t)

    def _adaptive_tsit5_step(self, rk_state):
        """Take one trial RK step; accept or reject it based on the error estimate."""
        y0, f0, _, t0, dt, _ = rk_state
        ########################################################
        #                      Assertions                      #
        ########################################################
        assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
        for y0_ in y0:
            assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
        y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU)

        ########################################################
        #                     Error Ratio                      #
        ########################################################
        # Mixed absolute/relative tolerance; accept when the RMS ratio <= 1.
        error_tol = tuple(self.atol + self.rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
        tensor_error_ratio = tuple(y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol))
        sq_error_ratio = tuple(
            torch.mul(tensor_error_ratio_, tensor_error_ratio_) for tensor_error_ratio_ in tensor_error_ratio
        )
        mean_error_ratio = (
            sum(torch.sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio) /
            sum(sq_error_ratio_.numel() for sq_error_ratio_ in sq_error_ratio)
        )
        accept_step = mean_error_ratio <= 1

        ########################################################
        #                   Update RK State                    #
        ########################################################
        # On rejection keep the old state but retry with the adapted (smaller) dt.
        y_next = y1 if accept_step else y0
        f_next = f1 if accept_step else f0
        t_next = t0 + dt if accept_step else t0
        dt_next = _optimal_step_size(dt, mean_error_ratio, self.safety, self.ifactor, self.dfactor)
        k_next = k if accept_step else self.rk_state.interp_coeff
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, k_next)
        return rk_state
# ported from https://github.com/pvigier/perlin-numpy

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import time, datetime

import torch
import numpy as np
import matplotlib.pyplot as plt

from misc import stream_3D, V_plot, center_crop
from utils.misc import viewVolume, make_dir, read_image

from ShapeID.DiffEqs.adjoint import odeint_adjoint as odeint
from ShapeID.DiffEqs.pde import AdvDiffPDE

from perlin3d import *


if __name__ == '__main__':

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    percentile = 80

    # Load a pathology-probability volume as the initial condition and crop it
    # to a cube the solver can handle.
    mask_image, aff = read_image('/autofs/space/yogurt_001/users/pl629/data/isles2022/pathology_probability/sub-strokecase0127.nii.gz')
    mask_image, _, _ = center_crop(torch.from_numpy(mask_image), win_size=[128, 128, 128])
    mask_image = mask_image[0, 0].numpy()

    shape = mask_image.shape

    # Divergence-free velocity field: curl of three Perlin-noise potentials.
    curl_a, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile=percentile)
    curl_b, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile=percentile)
    curl_c, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile=percentile)
    dx, dy, dz = stream_3D(torch.from_numpy(curl_a), torch.from_numpy(curl_b), torch.from_numpy(curl_c))

    dt = 0.1
    nt = 10
    integ_method = 'dopri5'  # choices=['dopri5', 'adams', 'rk4', 'euler']
    t = torch.from_numpy(np.arange(nt) * dt).to(device)
    thres = 0.5

    initial = torch.from_numpy(mask_image)[None]  # (batch=1, h, w)
    Vx, Vy, Vz = dx * 500, dy * 500, dz * 500
    print(abs(Vx).mean(), abs(Vy).mean(), abs(Vz).mean())

    forward_pde = AdvDiffPDE(data_spacing=[1., 1., 1.],
                             perf_pattern='adv',
                             V_type='vector_div_free',
                             V_dict={'Vx': Vx, 'Vy': Vy, 'Vz': Vz},
                             BC='neumann',
                             dt=dt,
                             device=device)

    start_time = time.time()
    noise_progression = odeint(forward_pde,
                               initial,
                               t, dt, method=integ_method
                               )[:, 0]  # (nt, n_batch, h, w)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Time {}'.format(total_time_str))

    # Keep every other frame for visualization.
    noise_progression = noise_progression[::2]
    noise_progression = noise_progression.numpy()
    make_dir('out/3d/progression')

    for i, noise_t in enumerate(noise_progression):
        # NOTE(review): demo2d thresholds with `noise_t > thres`; here the upper
        # clamp uses `> 1` — confirm this asymmetry is intended.
        noise_t[noise_t > 1] = 1
        noise_t[noise_t <= thres] = 0
        print(i, noise_t.mean())
        viewVolume(noise_t, names=['noise_%s' % i], save_dir='/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d/progression')

        # Binary mask of everything that survived thresholding.
        noise_t[noise_t > 0.] = 1
        viewVolume(noise_t, names=['noise_%s_mask' % i], save_dir='/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d/progression')
def center_crop(img, win_size=[220, 220, 220]):
    """Center-crop a 3-D volume (optionally with a trailing channel dim).

    img: (s, r, c) or (s, r, c, channels) tensor.
    Returns (cropped, crop_start, orig_shape) where ``cropped`` is 5-D
    (1, ch, *win_size) for plain volumes or (1, *win_size, ch) when the input
    carried a trailing channel dim. ``win_size=None`` skips cropping.
    """
    if len(img.shape) == 4:
        img = torch.permute(img, (3, 0, 1, 2))  # channels-last -> channels-first
        img = img[None]
        channels_last = True
    else:
        assert len(img.shape) == 3
        img = img[None, None]
        channels_last = False

    orig_shp = img.shape[2:]  # (1, ch, s, r, c) -> (s, r, c)
    if win_size is None:
        if channels_last:
            return torch.permute(img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp
        return img, [0, 0, 0], orig_shp
    elif orig_shp[0] > win_size[0] or orig_shp[1] > win_size[1] or orig_shp[2] > win_size[2]:
        crop_start = [max((orig_shp[i] - win_size[i]), 0) // 2 for i in range(3)]
        crop_img = img[:, :, crop_start[0]: crop_start[0] + win_size[0],
                       crop_start[1]: crop_start[1] + win_size[1],
                       crop_start[2]: crop_start[2] + win_size[2]]
        if channels_last:
            # NOTE(review): this branch returns [0, 0, 0] rather than crop_start
            # even though a crop occurred — confirm intended.
            return torch.permute(crop_img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp
        return crop_img, crop_start, orig_shp
    else:
        if channels_last:
            return torch.permute(img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp
        return img, [0, 0, 0], orig_shp


def V_plot(Vx, Vy, save_path):
    """Save a streamline plot of the 2-D vector field (Vx, Vy) to ``save_path``."""
    X, Y = np.meshgrid(np.arange(0, Vx.shape[0], 1), np.arange(0, Vx.shape[1], 1))
    Ex = Vx
    Ey = Vy
    plt.figure()
    plt.streamplot(X, Y, Ex, Ey, density=1.4, linewidth=None, color='orange')
    plt.axis('off')
    plt.savefig(save_path)


def stream_2D(Phi, batched=False, delta_lst=[1., 1.]):
    """Curl of a 2-D scalar stream function Phi (divergence-free by construction).

    Phi: (r, c) or (n_batch, r, c). Returns (Vx, Vy) with Vx = -dPhi/dy, Vy = dPhi/dx.
    """
    dD = gradient_c(Phi, batched=batched, delta_lst=delta_lst)
    return - dD[..., 1], dD[..., 0]


def stream_3D(Phi_a, Phi_b, Phi_c, batched=False, delta_lst=[1., 1., 1.]):
    """Curl of a 3-D vector potential (Phi_a, Phi_b, Phi_c); inputs (batch, s, r, c)."""
    dDa = gradient_c(Phi_a, batched=batched, delta_lst=delta_lst)
    dDb = gradient_c(Phi_b, batched=batched, delta_lst=delta_lst)
    dDc = gradient_c(Phi_c, batched=batched, delta_lst=delta_lst)
    Va_x, Va_y, Va_z = dDa[..., 0], dDa[..., 1], dDa[..., 2]
    Vb_x, Vb_y, Vb_z = dDb[..., 0], dDb[..., 1], dDb[..., 2]
    Vc_x, Vc_y, Vc_z = dDc[..., 0], dDc[..., 1], dDc[..., 2]
    return Vc_y - Vb_z, Va_z - Vc_x, Vb_x - Va_y


def gradient_f(X, batched=False, delta_lst=[1., 1., 1.]):
    """Per-direction gradient of X using forward differences (backward at the
    upper boundary). Batched inputs are (n_batch, ...); plain inputs (...).
    Returns a tensor with a trailing axis of size ``ndim`` (omitted for 1-D).
    """
    device = X.device
    ndim = len(X.size()) - 1 if batched else len(X.size())
    if ndim == 1:
        dX = torch.zeros(X.size(), dtype=torch.float, device=device)
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[-1] = X[-1] - X[-2]      # backward difference at the top edge
        dX[:-1] = X[1:] - X[:-1]    # forward difference elsewhere
        dX = dX.permute(1, 0) if batched else dX
        dX /= delta_lst[0]
    elif ndim == 2:
        dX = torch.zeros(X.size() + tuple([2]), dtype=torch.float, device=device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX  # batch to last dim
        dX[-1, :, 0] = X[-1, :] - X[-2, :]
        dX[:-1, :, 0] = X[1:] - X[:-1]
        dX[:, -1, 1] = X[:, -1] - X[:, -2]
        dX[:, :-1, 1] = X[:, 1:] - X[:, :-1]
        dX = dX.permute(3, 0, 1, 2) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif ndim == 3:
        dX = torch.zeros(X.size() + tuple([3]), dtype=torch.float, device=device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
        dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :]
        dX[:-1, :, :, 0] = X[1:] - X[:-1]
        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
        dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1]
        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]
        dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1]
        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX
def gradient_b(X, batched=False, delta_lst=[1., 1., 1.]):
    """Per-direction gradient of X using backward differences (forward at the
    lower boundary). Batched inputs are (n_batch, ...); plain inputs (...).
    Returns a tensor with a trailing axis of size ``ndim`` (omitted for 1-D).
    """
    device = X.device
    ndim = len(X.size()) - 1 if batched else len(X.size())
    if ndim == 1:
        dX = torch.zeros(X.size(), dtype=torch.float, device=device)
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[1:] = X[1:] - X[:-1]   # backward difference
        dX[0] = X[1] - X[0]       # forward difference at the bottom edge
        dX = dX.permute(1, 0) if batched else dX
        dX /= delta_lst[0]
    elif ndim == 2:
        dX = torch.zeros(X.size() + tuple([2]), dtype=torch.float, device=device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX  # batch to last dim
        dX[1:, :, 0] = X[1:, :] - X[:-1, :]
        dX[0, :, 0] = X[1] - X[0]
        dX[:, 1:, 1] = X[:, 1:] - X[:, :-1]
        dX[:, 0, 1] = X[:, 1] - X[:, 0]
        dX = dX.permute(3, 0, 1, 2) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif ndim == 3:
        dX = torch.zeros(X.size() + tuple([3]), dtype=torch.float, device=device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
        dX[1:, :, :, 0] = X[1:, :, :] - X[:-1, :, :]
        dX[0, :, :, 0] = X[1] - X[0]
        dX[:, 1:, :, 1] = X[:, 1:] - X[:, :-1]
        dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
        dX[:, :, 1:, 2] = X[:, :, 1:] - X[:, :, :-1]
        dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX


def gradient_c(X, batched=False, delta_lst=[1., 1., 1.]):
    """Per-direction gradient of X using central differences in the interior,
    one-sided differences at the boundaries. Batched inputs are (n_batch, ...);
    plain inputs (...). Returns a tensor with a trailing axis of size ``ndim``
    (omitted for 1-D).
    """
    device = X.device
    ndim = len(X.size()) - 1 if batched else len(X.size())
    if ndim == 1:
        dX = torch.zeros(X.size(), dtype=torch.float, device=device)
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[1:-1] = (X[2:] - X[:-2]) / 2   # central difference
        dX[0] = X[1] - X[0]               # forward at the bottom edge
        dX[-1] = X[-1] - X[-2]            # backward at the top edge
        dX = dX.permute(1, 0) if batched else dX
        dX /= delta_lst[0]
    elif ndim == 2:
        dX = torch.zeros(X.size() + tuple([2]), dtype=torch.float, device=device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX  # batch to last dim
        dX[1:-1, :, 0] = (X[2:, :] - X[:-2, :]) / 2
        dX[0, :, 0] = X[1] - X[0]
        dX[-1, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, 1] = (X[:, 2:] - X[:, :-2]) / 2
        dX[:, 0, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, 1] = X[:, -1] - X[:, -2]
        dX = dX.permute(3, 0, 1, 2) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif ndim == 3:
        dX = torch.zeros(X.size() + tuple([3]), dtype=torch.float, device=device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
        dX[1:-1, :, :, 0] = (X[2:, :, :] - X[:-2, :, :]) / 2
        dX[0, :, :, 0] = X[1] - X[0]
        dX[-1, :, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, :, 1] = (X[:, 2:, :] - X[:, :-2, :]) / 2
        dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
        dX[:, :, 1:-1, 2] = (X[:, :, 2:] - X[:, :, :-2]) / 2
        dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]
        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX
delta_lst[2] + + return dX + + diff --git a/ShapeID/out/2d/V.png b/ShapeID/out/2d/V.png new file mode 100644 index 0000000000000000000000000000000000000000..ba10af802dc95c2e510ee4c78c89bdaee1259caf --- /dev/null +++ b/ShapeID/out/2d/V.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddfa181038abcb687f69cb20d9aea1caf3e9d474895de4dca01ee39a6afd8b82 +size 135770 diff --git a/ShapeID/out/2d/curl.png b/ShapeID/out/2d/curl.png new file mode 100644 index 0000000000000000000000000000000000000000..9aab1f522590264a8c42cc066ea444827a4405b1 Binary files /dev/null and b/ShapeID/out/2d/curl.png differ diff --git a/ShapeID/out/2d/image.png b/ShapeID/out/2d/image.png new file mode 100644 index 0000000000000000000000000000000000000000..a07cf92ef6846c3b02c05af31fbdde0828ff5243 Binary files /dev/null and b/ShapeID/out/2d/image.png differ diff --git a/ShapeID/out/2d/image_with_v.png b/ShapeID/out/2d/image_with_v.png new file mode 100644 index 0000000000000000000000000000000000000000..0378967736eaa50634f46e141b9eeb02318339ff --- /dev/null +++ b/ShapeID/out/2d/image_with_v.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f9309254daadba376969a414084601ee466c9d8e24e9e975d6ca3b93243dbf0 +size 129820 diff --git a/ShapeID/out/2d/mask_curl.png b/ShapeID/out/2d/mask_curl.png new file mode 100644 index 0000000000000000000000000000000000000000..8596c18e9a05b43cae29da9650274e5e303ba337 Binary files /dev/null and b/ShapeID/out/2d/mask_curl.png differ diff --git a/ShapeID/out/2d/mask_image.png b/ShapeID/out/2d/mask_image.png new file mode 100644 index 0000000000000000000000000000000000000000..6905e505782e84a2567ad91c2d944b2da1a1f2f6 Binary files /dev/null and b/ShapeID/out/2d/mask_image.png differ diff --git a/ShapeID/out/2d/progression/New Folder With Items/0.png b/ShapeID/out/2d/progression/New Folder With Items/0.png new file mode 100644 index 0000000000000000000000000000000000000000..0378967736eaa50634f46e141b9eeb02318339ff --- /dev/null 
+++ b/ShapeID/out/2d/progression/New Folder With Items/0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f9309254daadba376969a414084601ee466c9d8e24e9e975d6ca3b93243dbf0 +size 129820 diff --git a/ShapeID/out/2d/progression/New Folder With Items/1.png b/ShapeID/out/2d/progression/New Folder With Items/1.png new file mode 100644 index 0000000000000000000000000000000000000000..74d2a9818453f72bad8893b6beb86b029d8ac57f --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62b3b2cab8bffb63debd309c0152825b0c63d808ba80d2032fe6f14a88f83776 +size 129382 diff --git a/ShapeID/out/2d/progression/New Folder With Items/10.png b/ShapeID/out/2d/progression/New Folder With Items/10.png new file mode 100644 index 0000000000000000000000000000000000000000..ed4cbeba59268849dcbbd38b00d954ac8cce5384 --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/10.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:988889ec6081f53d2dad6711a01e6378bc211b263ba51be74ab6c7ec17f9c166 +size 127392 diff --git a/ShapeID/out/2d/progression/New Folder With Items/2.png b/ShapeID/out/2d/progression/New Folder With Items/2.png new file mode 100644 index 0000000000000000000000000000000000000000..c5df2364d787cf73cde8ab9f0585c7cccca35e6e --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb03a91920fbdcb7d44024b8e5b71123de9e95499c5fa41a1f84b768ef30a513 +size 129175 diff --git a/ShapeID/out/2d/progression/New Folder With Items/3.png b/ShapeID/out/2d/progression/New Folder With Items/3.png new file mode 100644 index 0000000000000000000000000000000000000000..2949a4fa9afbde2e766dc07d992cf5798812cde8 --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:fa8342fa0061c77e36f8f4a067723f8908688a0a681fe321f094b15ba3350bde +size 128632 diff --git a/ShapeID/out/2d/progression/New Folder With Items/4.png b/ShapeID/out/2d/progression/New Folder With Items/4.png new file mode 100644 index 0000000000000000000000000000000000000000..40b58583f9db4d8784186718c8035d84217705b9 --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38defe398c16079b092d57f10547a66a4b43bbfa3dffa9251b69ee3033ed641e +size 128770 diff --git a/ShapeID/out/2d/progression/New Folder With Items/5.png b/ShapeID/out/2d/progression/New Folder With Items/5.png new file mode 100644 index 0000000000000000000000000000000000000000..abebaee3ee2f7219e0bb229227739cba851bc262 --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d17d81aa68e4b82dd4ad55dfcde57b71a9dfd903822740a7497aace51f7f701 +size 128631 diff --git a/ShapeID/out/2d/progression/New Folder With Items/6.png b/ShapeID/out/2d/progression/New Folder With Items/6.png new file mode 100644 index 0000000000000000000000000000000000000000..40230310739f45e9ca7b8ed19678627678b4a602 --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b234b7316d8e3996d5230dd0228e0e64694d21888c4818bb8ddd5f5ecc1e9152 +size 128204 diff --git a/ShapeID/out/2d/progression/New Folder With Items/7.png b/ShapeID/out/2d/progression/New Folder With Items/7.png new file mode 100644 index 0000000000000000000000000000000000000000..affe6c3d80230f7b147df663d3b6fa1304ca7c4c --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0599491ec0eb5c1342419ff6d3233e9b36f75fb1577885ff5ccf933d95d91af +size 127900 diff --git a/ShapeID/out/2d/progression/New Folder With Items/8.png 
b/ShapeID/out/2d/progression/New Folder With Items/8.png new file mode 100644 index 0000000000000000000000000000000000000000..689cee1824f5230800eb5064307e3b6819839b0c --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0731e0d4d0f76b96ec920a1529ffabeba48d7296f2afcaffe1e9586c64a2ca4 +size 127770 diff --git a/ShapeID/out/2d/progression/New Folder With Items/9.png b/ShapeID/out/2d/progression/New Folder With Items/9.png new file mode 100644 index 0000000000000000000000000000000000000000..b282721369dcaf1910b0da2e30d95e7309d4c891 --- /dev/null +++ b/ShapeID/out/2d/progression/New Folder With Items/9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45c29c66908a041571b4cbeee0f75df74aaa58ccdd37d6b4a4f39e5019ffbd97 +size 127713 diff --git a/ShapeID/out/2d/progression/noise_progression.nii.gz b/ShapeID/out/2d/progression/noise_progression.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..dde45040937e41aa228d4fded81103cc74951458 --- /dev/null +++ b/ShapeID/out/2d/progression/noise_progression.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d4ba4426d352b688d9f3cde650600585f41ea01a83981f564898efbba949e82 +size 35245 diff --git a/ShapeID/out/3d/image.nii.gz b/ShapeID/out/3d/image.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..7348937fad009f52a5ef33efb702c9d93949853c --- /dev/null +++ b/ShapeID/out/3d/image.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cca2d6569746883aa83fe7dd183af3dee2030a706c0f2bc69c84caf0babb98c +size 7656439 diff --git a/ShapeID/out/3d/mask_image.nii.gz b/ShapeID/out/3d/mask_image.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..af0de3369ec7e583559e9363fea7aaa451f080b4 --- /dev/null +++ b/ShapeID/out/3d/mask_image.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:b5bf62a2186c0924987ff89e8e891feeffb44fddbe4fc69e0e6ad4d99470b686 +size 104722 diff --git a/ShapeID/out/3d/progression/noise_0.nii.gz b/ShapeID/out/3d/progression/noise_0.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..93d1624ab81b229a3a15ff22c2a0b10d041e0a55 --- /dev/null +++ b/ShapeID/out/3d/progression/noise_0.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1e0cefe610081bef5b22881751b8e3bae1d452b6622f384112d31857403b103 +size 469584 diff --git a/ShapeID/out/3d/progression/noise_0_mask.nii.gz b/ShapeID/out/3d/progression/noise_0_mask.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..39c8ca22007b4d537a2ba6f764075e4aa4cb45b3 --- /dev/null +++ b/ShapeID/out/3d/progression/noise_0_mask.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58906767955d0dc6588992f01e85800d0da44dbc40b8757fbe48fb01937c3c39 +size 72723 diff --git a/ShapeID/out/3d/progression/noise_1.nii.gz b/ShapeID/out/3d/progression/noise_1.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..fe282228b486af27430adb5b7c14081378936fd2 --- /dev/null +++ b/ShapeID/out/3d/progression/noise_1.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec4fe9176f932eab177599644be6f77e9bfd9e74d4b6517378000c461641e7c0 +size 436758 diff --git a/ShapeID/out/3d/progression/noise_1_mask.nii.gz b/ShapeID/out/3d/progression/noise_1_mask.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..f73c6b438270f7cc75fc258dde801947614b708a --- /dev/null +++ b/ShapeID/out/3d/progression/noise_1_mask.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9467b45f870cdc60c0d9ac96cbfade609760e5b88dd41eaa5a02c406a18adbb3 +size 69852 diff --git a/ShapeID/out/3d/progression/noise_2.nii.gz b/ShapeID/out/3d/progression/noise_2.nii.gz new file mode 100644 index 
0000000000000000000000000000000000000000..a34b9dd7f6db410f940e014bb4c4ff2fec125781 --- /dev/null +++ b/ShapeID/out/3d/progression/noise_2.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2f36e85ece3c0b213e6a809d56d9b81512134e2756e25673958ee17a151cd17 +size 428978 diff --git a/ShapeID/out/3d/progression/noise_2_mask.nii.gz b/ShapeID/out/3d/progression/noise_2_mask.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..f9664a7c8b171c578eadd19a8794976aba7ed6e8 --- /dev/null +++ b/ShapeID/out/3d/progression/noise_2_mask.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f7aed1808ca6d8e0dd46ff99831a9fac61200a26041d768cf3f4b49505e69c4 +size 68371 diff --git a/ShapeID/out/3d/progression/noise_3.nii.gz b/ShapeID/out/3d/progression/noise_3.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..728b278d33b504e6310c14c0641783e684dba15a --- /dev/null +++ b/ShapeID/out/3d/progression/noise_3.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cc021d369bc9756c2e809599d64ce9e506ddd494a70c69b7145be10f0477cb8 +size 424842 diff --git a/ShapeID/out/3d/progression/noise_3_mask.nii.gz b/ShapeID/out/3d/progression/noise_3_mask.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..23da22182547f4ec96c0ac5200ac9fbe23c8dda8 --- /dev/null +++ b/ShapeID/out/3d/progression/noise_3_mask.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4d2905da7a1989c708f0d8f8e95828a235671905fda0e50f5afb639b4c64121 +size 67611 diff --git a/ShapeID/out/3d/progression/noise_4.nii.gz b/ShapeID/out/3d/progression/noise_4.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..4adf58eadfb0b3cbb2e4bf8bb2453d6b052c73b1 --- /dev/null +++ b/ShapeID/out/3d/progression/noise_4.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c33961ae13e9945ca1ef87522d86d45bb8d0276c0bd615e9659cc8ee655634bc +size 
# ported from https://github.com/pvigier/perlin-numpy

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np


def interpolant(t):
    """Quintic fade curve 6t^5 - 15t^4 + 10t^3, C2-continuous, used to blend lattice values."""
    return t*t*t*(t*(t*6 - 15) + 10)


def generate_perlin_noise_2d(
    shape, res, tileable=(False, False), interpolant=interpolant, percentile=None,
):
    """Generate a 2D numpy array of perlin noise.

    Args:
        shape: The shape of the generated array (tuple of two ints).
            This must be a multiple of res.
        res: The number of periods of noise to generate along each
            axis (tuple of two ints). Note shape must be a multiple of res.
        tileable: If the noise should be tileable along each axis
            (tuple of two bools). Defaults to (False, False).
        interpolant: The interpolation function, defaults to
            t*t*t*(t*(t*6 - 15) + 10).
        percentile: If given, threshold the noise at this percentile and
            zero out everything below it.

    Returns:
        If percentile is None: a numpy array of shape `shape` with the noise.
        Otherwise: a (noise, mask) pair, where mask is the binary
        above-threshold map and noise has been multiplied by it.

    Raises:
        ValueError: If shape is not a multiple of res.
    """
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])
    # Fractional lattice coordinates of every pixel.
    grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]]\
        .transpose(1, 2, 0) % 1
    # Random unit gradients on the lattice corners.
    angles = 2*np.pi*np.random.rand(res[0]+1, res[1]+1)
    gradients = np.dstack((np.cos(angles), np.sin(angles)))
    if tileable[0]:
        gradients[-1,:] = gradients[0,:]
    if tileable[1]:
        gradients[:,-1] = gradients[:,0]
    gradients = gradients.repeat(d[0], 0).repeat(d[1], 1)
    g00 = gradients[    :-d[0],    :-d[1]]
    g10 = gradients[d[0]:     ,    :-d[1]]
    g01 = gradients[    :-d[0],d[1]:     ]
    g11 = gradients[d[0]:     ,d[1]:     ]
    # Ramps: dot products of corner gradients with pixel offset vectors.
    n00 = np.sum(np.dstack((grid[:,:,0]  , grid[:,:,1]  )) * g00, 2)
    n10 = np.sum(np.dstack((grid[:,:,0]-1, grid[:,:,1]  )) * g10, 2)
    n01 = np.sum(np.dstack((grid[:,:,0]  , grid[:,:,1]-1)) * g01, 2)
    n11 = np.sum(np.dstack((grid[:,:,0]-1, grid[:,:,1]-1)) * g11, 2)
    # Interpolation
    t = interpolant(grid)
    n0 = n00*(1-t[:,:,0]) + t[:,:,0]*n10
    n1 = n01*(1-t[:,:,0]) + t[:,:,0]*n11

    noise = np.sqrt(2)*((1-t[:,:,1])*n0 + t[:,:,1]*n1)
    if percentile is None:
        return noise
    shres = np.percentile(noise, percentile)
    mask = np.zeros_like(noise)
    mask[noise >= shres] = 1.
    noise *= mask
    return noise, mask


def generate_fractal_noise_2d(
    shape, res, octaves=1, persistence=0.5,
    lacunarity=2, tileable=(False, False),
    interpolant=interpolant, percentile=None
):
    """Generate a 2D numpy array of fractal noise.

    Args:
        shape: The shape of the generated array (tuple of two ints).
            This must be a multiple of lacunarity**(octaves-1)*res.
        res: The number of periods of noise to generate along each
            axis (tuple of two ints). Note shape must be a multiple of
            (lacunarity**(octaves-1)*res).
        octaves: The number of octaves in the noise. Defaults to 1.
        persistence: The scaling factor between two octaves.
        lacunarity: The frequency factor between two octaves.
        tileable: If the noise should be tileable along each axis
            (tuple of two bools). Defaults to (False, False).
        interpolant: The interpolation function, defaults to
            t*t*t*(t*(t*6 - 15) + 10).
        percentile: If given, threshold the combined noise at this
            percentile and zero out everything below it.

    Returns:
        If percentile is None: a numpy array of fractal noise of shape `shape`.
        Otherwise: a (noise, mask) pair as in generate_perlin_noise_2d.

    Raises:
        ValueError: If shape is not a multiple of
            (lacunarity**(octaves-1)*res).
    """
    noise = np.zeros(shape)
    frequency = 1
    amplitude = 1
    for _ in range(octaves):
        noise += amplitude * generate_perlin_noise_2d(
            shape, (frequency*res[0], frequency*res[1]), tileable, interpolant
        )
        frequency *= lacunarity
        amplitude *= persistence
    if percentile is None:
        return noise
    shres = np.percentile(noise, percentile)
    mask = np.zeros_like(noise)
    mask[noise >= shres] = 1.
    noise *= mask
    return noise, mask


# generate multiple noises, assign argmax as label

def generate_fractal2d_batch(num, shape, res, octaves=5):
    """Draw `num` fractal-noise fields and label each pixel with the argmax field index.

    BUGFIX: the original did `noise, _ = generate_fractal_noise_2d(shape, res, octaves)`,
    but without `percentile` that call returns a single ndarray (the tuple is only
    returned when percentile is given), so the unpack split the array's rows instead
    (ValueError unless shape[0] == 2).
    """
    noises = [generate_fractal_noise_2d(shape, res, octaves) for _ in range(num)]
    return np.argmax(np.array(noises), axis = 0)


def generate_perlin2d_batch(num, shape, res):
    """Draw `num` Perlin-noise fields and label each pixel with the argmax field index.

    Same unpack fix as generate_fractal2d_batch: the percentile-less call
    returns one array, not a (noise, mask) pair.
    """
    noises = [generate_perlin_noise_2d(shape, res) for _ in range(num)]
    return np.argmax(np.array(noises), axis = 0)
+ + Args: + shape: The shape of the generated array (tuple of three ints). + This must be a multiple of res. + res: The number of periods of noise to generate along each + axis (tuple of three ints). Note shape must be a multiple + of res. + tileable: If the noise should be tileable along each axis + (tuple of three bools). Defaults to (False, False, False). + interpolant: The interpolation function, defaults to + t*t*t*(t*(t*6 - 15) + 10). + + Returns: + A numpy array of shape with the generated noise. + + Raises: + ValueError: If shape is not a multiple of res. + """ + delta = (res[0] / shape[0], res[1] / shape[1], res[2] / shape[2]) + d = (shape[0] // res[0], shape[1] // res[1], shape[2] // res[2]) + grid = np.mgrid[0:res[0]:delta[0],0:res[1]:delta[1],0:res[2]:delta[2]] + grid = np.mgrid[0:res[0]:delta[0],0:res[1]:delta[1],0:res[2]:delta[2]] + grid = grid.transpose(1, 2, 3, 0) % 1 + # Gradients + theta = 2*np.pi*np.random.rand(res[0] + 1, res[1] + 1, res[2] + 1) + phi = 2*np.pi*np.random.rand(res[0] + 1, res[1] + 1, res[2] + 1) + gradients = np.stack( + (np.sin(phi)*np.cos(theta), np.sin(phi)*np.sin(theta), np.cos(phi)), + axis=3 + ) + if tileable[0]: + gradients[-1,:,:] = gradients[0,:,:] + if tileable[1]: + gradients[:,-1,:] = gradients[:,0,:] + if tileable[2]: + gradients[:,:,-1] = gradients[:,:,0] + gradients = gradients.repeat(d[0], 0).repeat(d[1], 1).repeat(d[2], 2) + g000 = gradients[ :-d[0], :-d[1], :-d[2]] + g100 = gradients[d[0]: , :-d[1], :-d[2]] + g010 = gradients[ :-d[0],d[1]: , :-d[2]] + g110 = gradients[d[0]: ,d[1]: , :-d[2]] + g001 = gradients[ :-d[0], :-d[1],d[2]: ] + g101 = gradients[d[0]: , :-d[1],d[2]: ] + g011 = gradients[ :-d[0],d[1]: ,d[2]: ] + g111 = gradients[d[0]: ,d[1]: ,d[2]: ] + # Ramps + n000 = np.sum(np.stack((grid[:,:,:,0] , grid[:,:,:,1] , grid[:,:,:,2] ), axis=3) * g000, 3) + n100 = np.sum(np.stack((grid[:,:,:,0]-1, grid[:,:,:,1] , grid[:,:,:,2] ), axis=3) * g100, 3) + n010 = np.sum(np.stack((grid[:,:,:,0] , grid[:,:,:,1]-1, 
grid[:,:,:,2] ), axis=3) * g010, 3) + n110 = np.sum(np.stack((grid[:,:,:,0]-1, grid[:,:,:,1]-1, grid[:,:,:,2] ), axis=3) * g110, 3) + n001 = np.sum(np.stack((grid[:,:,:,0] , grid[:,:,:,1] , grid[:,:,:,2]-1), axis=3) * g001, 3) + n101 = np.sum(np.stack((grid[:,:,:,0]-1, grid[:,:,:,1] , grid[:,:,:,2]-1), axis=3) * g101, 3) + n011 = np.sum(np.stack((grid[:,:,:,0] , grid[:,:,:,1]-1, grid[:,:,:,2]-1), axis=3) * g011, 3) + n111 = np.sum(np.stack((grid[:,:,:,0]-1, grid[:,:,:,1]-1, grid[:,:,:,2]-1), axis=3) * g111, 3) + # Interpolation + t = interpolant(grid) + n00 = n000*(1-t[:,:,:,0]) + t[:,:,:,0]*n100 + n10 = n010*(1-t[:,:,:,0]) + t[:,:,:,0]*n110 + n01 = n001*(1-t[:,:,:,0]) + t[:,:,:,0]*n101 + n11 = n011*(1-t[:,:,:,0]) + t[:,:,:,0]*n111 + n0 = (1-t[:,:,:,1])*n00 + t[:,:,:,1]*n10 + n1 = (1-t[:,:,:,1])*n01 + t[:,:,:,1]*n11 + + noise = ((1-t[:,:,:,2])*n0 + t[:,:,:,2]*n1) + if percentile is None: + return noise + shres = np.percentile(noise, percentile) + mask = np.zeros_like(noise) + mask[noise >= shres] = 1. + noise *= mask + return noise, mask + + + +def generate_fractal_noise_3d( + shape, res, octaves=1, persistence=0.5, lacunarity=2, + tileable=(False, False, False), interpolant=interpolant, percentile=None, +): + """Generate a 3D numpy array of fractal noise. + + Args: + shape: The shape of the generated array (tuple of three ints). + This must be a multiple of lacunarity**(octaves-1)*res. + res: The number of periods of noise to generate along each + axis (tuple of three ints). Note shape must be a multiple of + (lacunarity**(octaves-1)*res). + octaves: The number of octaves in the noise. Defaults to 1. + persistence: The scaling factor between two octaves. + lacunarity: The frequency factor between two octaves. + tileable: If the noise should be tileable along each axis + (tuple of three bools). Defaults to (False, False, False). + interpolant: The, interpolation function, defaults to + t*t*t*(t*(t*6 - 15) + 10). 
+ + Returns: + A numpy array of fractal noise and of shape generated by + combining several octaves of perlin noise. + + Raises: + ValueError: If shape is not a multiple of + (lacunarity**(octaves-1)*res). + """ + noise = np.zeros(shape) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * generate_perlin_noise_3d( + shape, + (frequency*res[0], frequency*res[1], frequency*res[2]), + tileable, + interpolant + ) + frequency *= lacunarity + amplitude *= persistence + + if percentile is None: + return noise + shres = np.percentile(noise, percentile) + mask = np.zeros_like(noise) + mask[noise >= shres] = 1. + noise *= mask + return noise, mask + + +def generate_shape_3d(shape, perlin_res, percentile, device): + pprob, p = generate_perlin_noise_3d(shape, perlin_res, tileable=(True, False, False), percentile=percentile) + return torch.from_numpy(p).to(device), torch.from_numpy(pprob).to(device) + + +def generate_velocity_3d(shape, perlin_res, V_multiplier, device): + curl_a = generate_perlin_noise_3d(shape, perlin_res, tileable=(True, False, False)) + curl_b = generate_perlin_noise_3d(shape, perlin_res, tileable=(True, False, False)) + curl_c = generate_perlin_noise_3d(shape, perlin_res, tileable=(True, False, False)) + Vx, Vy, Vz = stream_3D(torch.from_numpy(curl_a).to(device), + torch.from_numpy(curl_b).to(device), + torch.from_numpy(curl_c).to(device)) + return {'Vx': (Vx * V_multiplier), 'Vy': (Vy * V_multiplier).to(device), 'Vz': (Vz * V_multiplier)} + + diff --git a/Trainer/__init__.py b/Trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Trainer/engine.py b/Trainer/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e2ebefd9f9b9145911a37bcdc6c69f324b359d --- /dev/null +++ b/Trainer/engine.py @@ -0,0 +1,319 @@ + +""" +Train and eval functions +""" +import os, random +import math +import time + +import torch +import 
def make_results(target, samples, outputs, out_dir):
    """Dump ground-truth and predicted volumes for one case under <out_dir>/<case>/results.

    Returns `outputs` unchanged so callers can keep chaining.
    """
    case_names = target['name']
    results = outputs
    case_out_dir = utils.make_dir(os.path.join(out_dir, case_names[0], 'results'))

    aff = target['aff'][0] if 'aff' in target else None

    # Ground-truth volumes shared across all samples of the case.
    for key in ('label', 'image', 'image_orig'):
        if key in target:
            utils.viewVolume(target[key], aff = aff, names = [key], prefix = 'gt_', save_dir = case_out_dir)

    for i_sample, sample in enumerate(samples):
        tag = '_#%d' % i_sample

        if 'bias_field_log' in sample:
            # Stored in log space; exponentiate before saving.
            utils.viewVolume(torch.exp(sample['bias_field_log']), aff = aff, names = ['bflog'], prefix = 'gt_', postfix = tag, save_dir = case_out_dir)
            utils.viewVolume(torch.exp(outputs[i_sample]['bias_field_log']), aff = aff, names = ['bflog'], prefix = 'pd_', postfix = tag, save_dir = case_out_dir)

        if 'input' in sample:
            utils.viewVolume(sample['input'], aff = aff, names = ['input'], prefix = '', postfix = tag, save_dir = case_out_dir)

        if 'orig' in sample:
            utils.viewVolume(sample['orig'], aff = aff, names = ['orig'], prefix = 'gt_', postfix = tag, save_dir = case_out_dir)

        if 'source' in sample:
            # Registration pair: deformed prediction is saved under the opposite name.
            utils.viewVolume(sample['source'], aff = aff, names = ['source'], prefix = 'gt_', postfix = tag, save_dir = case_out_dir)
            utils.viewVolume(sample['target'], aff = aff, names = ['target'], prefix = 'gt_', postfix = tag, save_dir = case_out_dir)
            utils.viewVolume(outputs[i_sample]['tgt_def'], aff = aff, names = ['source'], prefix = 'pd_', postfix = tag, save_dir = case_out_dir)
            utils.viewVolume(outputs[i_sample]['src_def'], aff = aff, names = ['target'], prefix = 'pd_', postfix = tag, save_dir = case_out_dir)

        if 'label' in outputs[i_sample]:
            utils.viewVolume(outputs[i_sample]['label'], aff = aff, names = ['label'], prefix = 'pd_', postfix = tag, save_dir = case_out_dir)

        if 'image' in outputs[i_sample]:
            utils.viewVolume(outputs[i_sample]['image'], aff = aff, names = ['image'], prefix = 'pd_', postfix = tag, save_dir = case_out_dir)

    return results
def train_one_epoch(epoch, gen_args, train_args, model, processors, criterion, data_loader_dict,
                    scaler, optimizer, lr_scheduler, wd_scheduler,
                    postprocessor, visualizers, output_dir, device = 'cpu'):
    """Run one mixed-precision training epoch over a dict of data loaders.

    Returns a dict {meter_name: global_avg} of the epoch's logged metrics.
    Non-finite or inapplicable losses skip the iteration instead of aborting.
    """
    model.train()
    criterion.train()

    # Re-seed from wall clock so augmentation differs between epochs/restarts.
    seed = int(time.time())
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)

    metric_logger = utils.MetricLogger(
        train_args.log_itr,
        delimiter=" ",
        debug=train_args.debug)
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.8f}'))

    header = 'Epoch: [{}/{}]'.format(epoch, train_args.n_epochs)

    max_len = max([len(v) for v in data_loader_dict.values()])
    # BUGFIX: the original read `probs = probs if gen_args.dataset_probs else ...`,
    # hitting UnboundLocalError whenever dataset_probs was set; use the configured
    # probabilities in that case, else sample datasets uniformly.
    probs = gen_args.dataset_probs if gen_args.dataset_probs else [1 / len(data_loader_dict)] * len(data_loader_dict)

    for itr, (dataset_num, curr_dataset, input_mode, target, samples) in enumerate(
            metric_logger.log_every(data_loader_dict, max_len, probs, epoch,
                                    header=header, train_limit=train_args.train_itr_limit)):

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            # Update weight decay and learning rate according to their schedule.
            curr_itr = max_len * epoch + itr  # global training iteration
            for i, param_group in enumerate(optimizer.param_groups):
                param_group["lr"] = lr_scheduler[curr_itr]
                param_group["weight_decay"] = wd_scheduler[curr_itr]

            samples = utils.nested_dict_to_device(samples, device)
            target = utils.nested_dict_to_device(target, device)

            # Build optional conditioning inputs (anomaly mask and/or flipped image).
            cond = []
            if train_args.condition is not None:
                for i in range(len(samples)):
                    curr_cond = None
                    if 'mask' in train_args.condition:
                        samples[i]['input'] *= 1 - target['pathology']  # mask out anomaly # (b, 1, s, r, c)
                        curr_cond = target['pathology'].to(samples[0]['input'].dtype)
                    if 'flip' in train_args.condition:
                        samples[i]['input_flip'] = torch.flip(samples[i]['input'], dims = [2])
                        curr_cond = torch.concat([samples[i]['input_flip'], curr_cond], dim = 1) if curr_cond is not None else samples[i]['input_flip']
                    cond.append(curr_cond)

            outputs, _ = model(samples, cond = cond)
            for processor in processors:
                outputs = processor(outputs, target, curr_dataset)

            loss_dict = criterion(outputs, target, samples)

            weight_dict = criterion.weight_dict
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

            # Reduce losses over all GPUs for logging purposes.
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            loss_dict_reduced_unscaled = {
                f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
            loss_dict_reduced_scaled = {
                k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
            losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        # `sum(...)` over an empty dict yields the int 0, which has no .item();
        # treat that as "no applicable loss" and skip (narrowed from bare except).
        try:
            loss_value = losses_reduced_scaled.item()
        except AttributeError:
            logger.info('This iteration does not have any loss applicable, skipping')
            torch.cuda.empty_cache()
            continue

        if not math.isfinite(loss_value):
            logger.info(f"Loss is {loss_value}, skipping this iteration")
            logger.info(loss_dict_reduced)
            logger.info(f"Case is {curr_dataset} - {target['name']}, skipping this iteration")
            torch.cuda.empty_cache()
            continue

        # Scaled backward + unscale before clipping, per torch.cuda.amp recipe.
        scaler.scale(losses).backward()
        scaler.unscale_(optimizer)
        if train_args.clip_max_norm > 0:
            utils.clip_gradients(model, train_args.clip_max_norm)
        # NOTE(review): placed outside the clip guard (DINO-style schedule);
        # confirm against the original intent.
        utils.cancel_gradients_last_layer(epoch, model, train_args.freeze_last_layer)
        scaler.step(optimizer)
        scaler.update()

        # Logging.
        if utils.get_world_size() > 1:
            torch.cuda.synchronize()
        metric_logger.update(loss = loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled
                             )
        metric_logger.update(lr = optimizer.param_groups[0]["lr"])
        metric_logger.update(wd = optimizer.param_groups[0]["weight_decay"])

        # BUGFIX: `or` binds looser than `and`, so debug mode entered this block
        # even when `visualizers` was None and crashed on the subscript below;
        # parenthesize so the None/main-process guards always apply.
        if (train_args.debug or itr % train_args.vis_itr < dataset_num) and visualizers is not None and utils.is_main_process():
            vis_itr = itr - itr % train_args.vis_itr
            epoch_vis_dir = utils.make_dir(os.path.join(output_dir, str(epoch), str(vis_itr), curr_dataset + '-' + input_mode)) if epoch is not None else output_dir

            if postprocessor is not None:
                outputs, samples, target = postprocessor(gen_args, train_args, outputs, samples, target = target, feats = None, tasks = gen_args.tasks)
            if train_args.visualizer.make_results:
                make_results(target, samples, outputs, out_dir = epoch_vis_dir)

            visualizers['result'].visualize_all(target, samples, outputs, epoch_vis_dir,
                    output_names = train_args.output_names + train_args.aux_output_names, target_names = train_args.target_names)

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    logger.info("Averaged stats: {}".format(metric_logger))

    if train_args.debug:
        exit()

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def train_one_epoch_twostage(epoch, gen_args, train_args, pathol_model, task_model, pathol_processors, task_processors,
                             criterion, data_loader_dict, scaler, optimizer, lr_scheduler, wd_scheduler,
                             postprocessor, visualizers, output_dir, device = 'cpu'):
    """Run one training epoch of the two-stage pipeline:
    stage 0 predicts a pathology mask, stage 1 performs pathology-mask-conditioned
    inpainting tasks on the masked input.

    Returns a dict {meter_name: global_avg} of the epoch's logged metrics.
    """
    pathol_model.train()
    task_model.train()
    criterion.train()

    # Re-seed from wall clock so augmentation differs between epochs/restarts.
    seed = int(time.time())
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)

    metric_logger = utils.MetricLogger(
        train_args.log_itr,
        delimiter=" ",
        debug=train_args.debug)
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.8f}'))

    header = 'Epoch: [{}/{}]'.format(epoch, train_args.n_epochs)

    max_len = max([len(v) for v in data_loader_dict.values()])
    # BUGFIX: same UnboundLocalError as train_one_epoch — `probs` was read before
    # assignment when gen_args.dataset_probs was set.
    probs = gen_args.dataset_probs if gen_args.dataset_probs else [1 / len(data_loader_dict)] * len(data_loader_dict)

    for itr, (dataset_num, curr_dataset, input_mode, target, samples) in enumerate(
            metric_logger.log_every(data_loader_dict, max_len, probs, epoch,
                                    header=header, train_limit=train_args.train_itr_limit)):

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            # Update weight decay and learning rate according to their schedule.
            curr_itr = max_len * epoch + itr  # global training iteration
            for i, param_group in enumerate(optimizer.param_groups):
                param_group["lr"] = lr_scheduler[curr_itr]
                param_group["weight_decay"] = wd_scheduler[curr_itr]

            samples = utils.nested_dict_to_device(samples, device)
            target = utils.nested_dict_to_device(target, device)

            # Stage-0: pathology segmentation prediction.
            outputs_pathol, _ = pathol_model(samples)
            for processor in pathol_processors:
                outputs_pathol = processor(outputs_pathol, target, curr_dataset)

            # Stage-1: pathology-mask-conditioned inpainting tasks prediction.
            cond = []
            for i in range(len(samples)):
                samples[i]['input_masked'] = samples[i]['input'] * (1 - outputs_pathol[i]['pathology'])  # mask out anomaly # (b, 1, s, r, c)
                curr_cond = target['pathology'].to(samples[0]['input'].dtype)
                # Robustness: train_args.condition may be None (train_one_epoch guards
                # this); `'flip' in None` would raise TypeError.
                if train_args.condition is not None and 'flip' in train_args.condition:
                    samples[i]['input_flip'] = torch.flip(samples[i]['input'], dims = [2])
                    curr_cond = torch.concat([samples[i]['input_flip'], curr_cond], dim = 1)
                cond.append(curr_cond)

            outputs_task, _ = task_model(samples, input_name = 'input_masked', cond = cond)
            for processor in task_processors:
                outputs_task = processor(outputs_task, target, curr_dataset)

            outputs = utils.merge_list_of_dict(outputs_task, outputs_pathol)
            loss_dict = criterion(outputs, target, samples)

            weight_dict = criterion.weight_dict
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

            # Reduce losses over all GPUs for logging purposes.
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            loss_dict_reduced_unscaled = {
                f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
            loss_dict_reduced_scaled = {
                k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
            losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        # `sum(...)` of an empty dict is the int 0 (no .item()); skip such iterations
        # (narrowed from a bare except).
        try:
            loss_value = losses_reduced_scaled.item()
        except AttributeError:
            logger.info('This iteration does not have any loss applicable, skipping')
            torch.cuda.empty_cache()
            continue

        if not math.isfinite(loss_value):
            logger.info(f"Loss is {loss_value}, skipping this iteration")
            logger.info(loss_dict_reduced)
            logger.info(f"Case is {curr_dataset} - {target['name']}, skipping this iteration")
            torch.cuda.empty_cache()
            continue

        # Scaled backward + unscale before clipping, per torch.cuda.amp recipe.
        scaler.scale(losses).backward()
        scaler.unscale_(optimizer)
        if train_args.clip_max_norm > 0:
            utils.clip_gradients(pathol_model, train_args.clip_max_norm)
            utils.clip_gradients(task_model, train_args.clip_max_norm)
        scaler.step(optimizer)
        scaler.update()

        # Logging.
        if utils.get_world_size() > 1:
            torch.cuda.synchronize()
        metric_logger.update(loss = loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled
                             )
        metric_logger.update(lr = optimizer.param_groups[0]["lr"])
        metric_logger.update(wd = optimizer.param_groups[0]["weight_decay"])

        # BUGFIX: `or` binds looser than `and`; parenthesize so the
        # visualizers-is-None and main-process guards also apply in debug mode.
        if (train_args.debug or itr % train_args.vis_itr < dataset_num) and visualizers is not None and utils.is_main_process():
            vis_itr = itr - itr % train_args.vis_itr
            epoch_vis_dir = utils.make_dir(os.path.join(output_dir, str(epoch), str(vis_itr), curr_dataset + '-' + input_mode)) if epoch is not None else output_dir

            if postprocessor is not None:
                outputs, samples, target = postprocessor(gen_args, train_args, outputs, samples, target = target, feats = None, tasks = gen_args.tasks)

            visualizers['result'].visualize_all(target, samples, outputs, epoch_vis_dir,
                    output_names = train_args.output_names + train_args.aux_output_names, target_names = train_args.target_names)

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    logger.info("Averaged stats: {}".format(metric_logger))

    if train_args.debug:
        exit()

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+""" +import torch + +from .backbone import build_backbone +from .criterion import * +from .evaluator import Evaluator +from .head import get_head +from .joiner import get_processors, get_joiner +import utils.misc as utils + + +######################################### + +# some constants +label_list_segmentation_brainseg_left = [0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 17, 31, 34, 36, 38, 40, 42] +n_labels_brainseg_left = len(label_list_segmentation_brainseg_left) + +label_list_segmentation_brainseg_with_extracerebral = [0, 11, 12, 13, 16, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 14, 15, 17, 47, 49, 51, 53, 55, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 48, 50, 52, 54, 56] +n_neutral_labels_brainseg_with_extracerebral = 20 +n_labels_brainseg_with_extracerebral = len(label_list_segmentation_brainseg_with_extracerebral) +nlat = int((n_labels_brainseg_with_extracerebral - n_neutral_labels_brainseg_with_extracerebral) / 2.0) +vflip = np.concatenate([np.array(range(n_neutral_labels_brainseg_with_extracerebral)), + np.array(range(n_neutral_labels_brainseg_with_extracerebral + nlat, n_labels_brainseg_with_extracerebral)), + np.array(range(n_neutral_labels_brainseg_with_extracerebral, n_neutral_labels_brainseg_with_extracerebral + nlat))]) + + +############################################ +############# helper functions ############# +############################################ + +def process_args(gen_args, train_args, task): + """ + task options: feat-anat, feat-seg, feat-anat-seg, anat, seg, reg, sr, bf + """ + gen_args.tasks = [key for (key, value) in vars(task).items() if value] + + gen_args.generator.size = gen_args.generator.size # update real sample size (if sample_size is downsampled) + train_args.size = gen_args.generator.size + + if gen_args.generator.left_hemis_only: + gen_args.label_list_segmentation = label_list_segmentation_brainseg_left + gen_args.n_labels = n_labels_brainseg_left + else: + 
gen_args.label_list_segmentation = label_list_segmentation_brainseg_with_extracerebral + gen_args.n_labels = n_labels_brainseg_with_extracerebral + + train_args.out_channels = {} + train_args.output_names = [] + train_args.aux_output_names = [] + train_args.target_names = [] + if not 'contrastive' in gen_args.tasks: + if 'T1' in gen_args.tasks: + train_args.out_channels['T1'] = 2 if train_args.losses.uncertainty is not None else 1 + train_args.output_names += ['T1'] + train_args.target_names += ['T1'] + if train_args.losses.uncertainty is not None: + train_args.aux_output_names += ['T1_sigma'] + if 'T2' in gen_args.tasks: + train_args.out_channels['T2'] = 2 if train_args.losses.uncertainty is not None else 1 + train_args.output_names += ['T2'] + train_args.target_names += ['T2'] + if train_args.losses.uncertainty is not None: + train_args.aux_output_names += ['T2_sigma'] + if 'FLAIR' in gen_args.tasks: + train_args.out_channels['FLAIR'] = 2 if train_args.losses.uncertainty is not None else 1 + train_args.output_names += ['FLAIR'] + train_args.target_names += ['FLAIR'] + if train_args.losses.uncertainty is not None: + train_args.aux_output_names += ['FLAIR_sigma'] + if 'CT' in gen_args.tasks: + train_args.out_channels['CT'] = 2 if train_args.losses.uncertainty is not None else 1 + train_args.output_names += ['CT'] + train_args.target_names += ['CT'] + if train_args.losses.uncertainty is not None: # TODO + train_args.aux_output_names += ['CT_sigma'] + if 'bias_field' in gen_args.tasks: + train_args.out_channels['bias_field_log'] = 2 if train_args.losses.uncertainty is not None else 1 + train_args.output_names += ['bias_field'] + train_args.target_names += ['bias_field'] + if 'segmentation' in gen_args.tasks: + train_args.out_channels['segmentation'] = gen_args.n_labels + train_args.output_names += ['label'] + train_args.target_names += ['label'] + if 'distance' in gen_args.tasks: + if gen_args.generator.left_hemis_only: + train_args.out_channels['distance'] = 2 + 
train_args.output_names += ['distance', 'lp', 'lw'] + train_args.target_names += ['distance', 'lp', 'lw'] + else: + train_args.out_channels['distance'] = 4 + train_args.output_names += ['distance', 'lp', 'lw', 'rp', 'rw'] + train_args.target_names += ['distance', 'lp', 'lw', 'rp', 'rw'] + if 'registration' in gen_args.tasks: + train_args.out_channels['registration'] = 3 + train_args.output_names += ['registration', 'regx', 'regy', 'regz'] + train_args.target_names += ['registration', 'regx', 'regy', 'regz'] + if 'surface' in gen_args.tasks: + train_args.out_channels['surface'] = 8 + train_args.output_names += ['surface'] + train_args.target_names += ['surface'] + if 'super_resolution' in gen_args.tasks: + train_args.out_channels['high_res_residual'] = 2 if train_args.losses.uncertainty is not None else 1 + train_args.output_names += ['high_res', 'high_res_residual'] + train_args.target_names += ['high_res', 'high_res_residual'] + if 'pathology' in gen_args.tasks: + train_args.out_channels['pathology'] = 1 + train_args.output_names += ['pathology'] + train_args.target_names += ['pathology'] + + if 'age' in gen_args.tasks: + train_args.out_channels['age'] = -1 + + if train_args.losses.implicit_pathol: # TODO + train_args.output_names += ['implicit_pathol_orig'] + train_args.output_names += ['implicit_pathol_pred'] + + #assert len(train_args.output_names) > 0 + + return gen_args, train_args + +############################################ +################ CRITERIONS ################ +############################################ + +def get_evaluator(args, task, device): + """ + task options: sr, seg, anat, reg + """ + metric_names = [] + if 'T1' in task or 'T2' in task or 'FLAIR' in task or 'CT' in task: + metric_names += ['feat_ssim', 'feat_ms_ssim', 'feat_l1'] + else: + if 'T1' in task: # TODO + metric_names += ['recon_l1', 'recon_psnr', 'recon_ssim', 'recon_ms_ssim'] + if 'super_resolution' in task: + metric_names += ['sr_l1', 'sr_psnr', 'sr_ssim', 'sr_ms_ssim'] + 
if 'bias_field' in task: + metric_names += ['bf_normalized_l2', 'bf_corrected_l1'] + if 'segmentation' in task: + metric_names += ['seg_dice'] + if 'pathology' in task: + metric_names += ['pathol_dice'] + + assert len(metric_names) > 0 + + evaluator = Evaluator( + args = args, + metric_names = metric_names, + device = device, + ) + + return evaluator + + + +def get_criterion(gen_args, train_args, tasks, device, exclude_keys = []): + """ + task options: sr, seg, anat, reg + """ + loss_names = [] + weight_dict = {} + + if 'contrastive' in tasks: + loss_names += ['contrastive'] + weight_dict['loss_contrastive'] = train_args.weights.contrastive + return SetCriterion( + gen_args = gen_args, + train_args = train_args, + weight_dict = weight_dict, + loss_names = loss_names, + device = device, + ) + + + for task in tasks: + + if 'T1' in task or 'T2' in task or 'FLAIR' in task or 'CT' in task: + name = task + + loss_names += [name] + weight_dict.update({'loss_%s' % name: train_args.weights.image}) + if train_args.losses.image_grad: + loss_names += ['%s_grad' % name] + weight_dict['loss_%s_grad' % name] = train_args.weights.image_grad + + if 'segmentation' in task: + loss_names += ['seg_ce', 'seg_dice'] + weight_dict.update( { + 'loss_seg_ce': train_args.weights.seg_ce, + 'loss_seg_dice': train_args.weights.seg_dice, + } ) + + if 'bias_field' in task: + loss_names += ['bias_field_log'] + weight_dict.update( { + 'loss_bias_field_log': train_args.weights.bias_field_log, + } ) + + if 'super_resolution' in task: + loss_names += ['SR'] + weight_dict.update( { + 'loss_SR': train_args.weights.image, + } ) + if train_args.losses.image_grad: + loss_names += ['SR_grad'] + weight_dict['loss_SR_grad'] = train_args.weights.image_grad + + if 'distance' in task: + loss_names += ['distance'] + weight_dict.update( { + 'loss_distance': train_args.weights.distance, + } ) + + if 'registration' in task: + loss_names += ['registration'] + weight_dict.update( { + 'loss_registration': 
train_args.weights.registration, + } ) + if train_args.losses.registration_grad: + loss_names += ['registration_grad'] + weight_dict['loss_registration_grad'] = train_args.weights.registration_grad + if train_args.losses.registration_smooth: + loss_names += ['registration_smooth'] + weight_dict['loss_registration_smooth'] = train_args.weights.registration_smooth + if train_args.losses.registration_hessian: + loss_names += ['registration_hessian'] + weight_dict['loss_registration_hessian'] = train_args.weights.registration_hessian + + if 'surface' in task: + loss_names += ['surface'] + weight_dict['loss_surface'] = train_args.weights.surface + + if 'age' in task: + loss_names += ['age'] + weight_dict['loss_age'] = train_args.weights.age + + if 'pathology' in task and 'pathology' not in exclude_keys: + loss_names += ['pathol_ce', 'pathol_dice'] + weight_dict.update( { + 'loss_pathol_ce': train_args.weights.pathol_ce, + 'loss_pathol_dice': train_args.weights.pathol_dice, + } ) + + if train_args.losses.implicit_pathol: + loss_names += ['implicit_pathol_ce', 'implicit_pathol_dice'] + weight_dict.update( { + 'loss_implicit_pathol_ce': train_args.weights.implicit_pathol_ce, + 'loss_implicit_pathol_dice': train_args.weights.implicit_pathol_dice, + } ) + + assert len(loss_names) > 0 + + criterion = SetMultiCriterion( + gen_args = gen_args, + train_args = train_args, + weight_dict = weight_dict, + loss_names = loss_names, + device = device, + ) + + return criterion + + + + +def get_postprocessor(gen_args, train_args, outputs, samples, target, feats, tasks): + """ + output: list of output dict + feat: list of output dict from pre-trained feat extractor + """ + + if 'distance' in tasks and target is not None: + if gen_args.generator.left_hemis_only: + target.update({'lp': target['distance'][:, 0][:, None], + 'lw': target['distance'][:, 1][:, None]}) + else: + target.update({'lp': target['distance'][:, 0][:, None], + 'lw': target['distance'][:, 1][:, None], + 'rp': 
target['distance'][:, 2][:, None], + 'rw': target['distance'][:, 3][:, None]}) + del target['distance'] + + if 'registration' in tasks and target is not None: + target.update({'regx': target['registration'][:, 0][:, None], + 'regy': target['registration'][:, 1][:, None], + 'regz': target['registration'][:, 2][:, None]}) + del target['registration'] + + if 'CT' in tasks and target is not None: + target['CT'] = target['CT'] * 1000 + + if 'segmentation' in tasks and target is not None: + target['label'] = torch.tensor(gen_args.label_list_segmentation, + device = target['segmentation'].device)[torch.argmax(target['segmentation'], 1, keepdim = True)] # (b, n_labels, s, r, c) -> (b, s, r, c) + + for i, output in enumerate(outputs): + + if feats is not None: + output.update({'feat': feats[i]['feat']}) + + if 'super_resolution' in tasks: + output.update({'high_res': output['high_res_residual'] + samples[i]['input']}) + if 'high_res_residual' in samples[i]: + samples[i].update({'high_res': samples[i]['high_res_residual'] + samples[i]['input']}) + + if 'bias_field' in tasks: + output.update({'bias_field': torch.exp(output['bias_field_log'])}) + del output['bias_field_log'] + + if 'bias_field_log' in samples[i]: + samples[i].update({'bias_field': torch.exp(samples[i]['bias_field_log'])}) + del samples[i]['bias_field_log'] + + if 'distance' in tasks: + + a = 2 + + if gen_args.generator.left_hemis_only: + output.update({'lp': output['distance'][:, 0][:, None], + 'lw': output['distance'][:, 1][:, None]}) + fake = 70 * (1 - (torch.tanh(a * (output['lw'] + 0.3)) + 1) / 2) + 40 * (1 - (torch.tanh(a * output['lp']) + 1) / 2) + else: + output.update({'lp': output['distance'][:, 0][:, None], + 'lw': output['distance'][:, 1][:, None], + 'rp': output['distance'][:, 2][:, None], + 'rw': output['distance'][:, 3][:, None]}) + + fakeL = 70 * (1 - (torch.tanh(a * (output['lw'] + 0.3)) + 1) / 2) + 40 * (1 - (torch.tanh(a * output['lp']) + 1) / 2) + fakeR = 70 * (1 - (torch.tanh(a * 
(output['rw'] + 0.3)) + 1) / 2) + 40 * (1 - (torch.tanh(a * output['rp']) + 1) / 2) + fake = fakeL + fakeR + + output.update({'fake_cortical': fake}) + del output['distance'] + + if 'registration' in tasks: + output.update({'regx': output['registration'][:, 0][:, None], + 'regy': output['registration'][:, 1][:, None], + 'regz': output['registration'][:, 2][:, None]}) + del output['registration'] + + if 'segmentation' in tasks: + output['label'] = torch.tensor(gen_args.label_list_segmentation, + device = output['segmentation'].device)[torch.argmax(output['segmentation'], 1, keepdim = True)] # (b, n_labels, s, r, c) -> (b, s, r, c) + + if 'CT' in tasks: + output['CT'] = output['CT'] * 1000 + + return outputs, samples, target + + +############################################# +################ OPTIMIZERS ################# +############################################# + + +def build_optimizer(train_args, params_groups): + if train_args.optimizer == "adam": + return torch.optim.Adam(params_groups) + elif train_args.optimizer == "adamw": + return torch.optim.AdamW(params_groups) # to use with ViTs + elif train_args.optimizer == "sgd": + return torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler + elif train_args.optimizer == "lars": + return utils.LARS(params_groups) # to use with convnet and large batches + else: + ValueError('optim type {args.optimizer.type} supported!') + + +def build_schedulers(train_args, itr_per_epoch, lr, min_lr): + if train_args.lr_scheduler == "cosine": + lr_scheduler = utils.cosine_scheduler( + lr, # * (args.batch_size * utils.get_world_size()) / 256., # linear scaling rule + min_lr, + train_args.n_epochs, itr_per_epoch, + warmup_epochs=train_args.warmup_epochs + ) + elif train_args.lr_scheduler == "multistep": + lr_scheduler = utils.multistep_scheduler( + lr, + train_args.lr_drops, + train_args.n_epochs, itr_per_epoch, + warmup_epochs=train_args.warmup_epochs, + gamma=train_args.lr_drop_multi + ) + wd_scheduler = 
utils.cosine_scheduler( + train_args.weight_decay, # set as 0 to disable it + train_args.weight_decay_end, + train_args.n_epochs, itr_per_epoch + ) + return lr_scheduler, wd_scheduler + + +############################################ +################## MODELS ################## +############################################ + + +def build_model(gen_args, train_args, device = 'cpu'): + gen_args, train_args = process_args(gen_args, train_args, task = gen_args.task) + + backbone = build_backbone(train_args, train_args.backbone) + head = get_head(train_args, train_args.task_f_maps, train_args.out_channels, True, -1) + model = get_joiner(gen_args.tasks, backbone, head, device) + + processors = get_processors(gen_args, train_args, gen_args.tasks, device) + + criterion = get_criterion(gen_args, train_args, gen_args.tasks, device) + + criterion.to(device) + + model.to(device) + postprocessor = get_postprocessor + + return gen_args, train_args, model, processors, criterion, postprocessor + + +def build_conditioned_model(gen_args, train_args, device = 'cpu'): # mask-conditioned inpaiting + gen_args, train_args = process_args(gen_args, train_args, task = gen_args.task) + + backbone = build_backbone(train_args, train_args.backbone, num_cond = len(train_args.condition.split('+'))) + head = get_head(train_args, train_args.task_f_maps, train_args.out_channels, True, -1, stage = 1, exclude_keys = ['pathology']) + model = get_joiner(gen_args.tasks, backbone, head, device) + processors = get_processors(gen_args, train_args, gen_args.tasks, device, exclude_keys = ['pathology']) + + criterion = get_criterion(gen_args, train_args, gen_args.tasks, device, exclude_keys = ['pathology']) + criterion.to(device) + + model.to(device) + postprocessor = get_postprocessor + + return gen_args, train_args, model, processors, criterion, postprocessor + + + +def build_inpaint_model(gen_args, train_args, device = 'cpu'): # two-stage inpainting + gen_args, train_args = process_args(gen_args, 
train_args, task = gen_args.task) + + # stage-0: pathology mask prediction + pathol_backbone = build_backbone(train_args, train_args.backbone.split('+')[0], num_cond = 0) + pathol_head = get_head(train_args, train_args.task_f_maps, train_args.out_channels, True, -1, stage = 0) + pathol_model = get_joiner(gen_args.tasks, pathol_backbone, pathol_head, device, postfix = '_pathol') + pathol_processors = get_processors(train_args, ['pathology'], device) + + # stage-1: pathology-mask-conditioned task prediction (inpainting) + task_backbone = build_backbone(train_args, train_args.backbone.split('+')[1], num_cond = 1) + task_head = get_head(train_args, train_args.task_f_maps, train_args.out_channels, True, -1, stage = 1) + task_model = get_joiner(gen_args.tasks, task_backbone, task_head, device, postfix = '_task') + task_processors = get_processors(gen_args, train_args, gen_args.tasks, device, exclude_keys = ['pathology']) + + criterion = get_criterion(gen_args, train_args, gen_args.tasks, device) + criterion.to(device) + + pathol_model.to(device) + task_model.to(device) + postprocessor = get_postprocessor + + return gen_args, train_args, pathol_model, task_model, pathol_processors, task_processors, criterion, postprocessor + + \ No newline at end of file diff --git a/Trainer/models/backbone.py b/Trainer/models/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..383174c82d7620fd67170d1e51c868e50b47d75d --- /dev/null +++ b/Trainer/models/backbone.py @@ -0,0 +1,27 @@ +""" +Backbone modules. 
+""" + +from Trainer.models.unet3d.model import UNet3D, UNet2D, UNet3DSep +#from Trainer.models.guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults + + +backbone_options = { + 'unet2d': UNet2D, + 'unet3d': UNet3D, + 'unet3d_2stage': UNet3D, + 'unet3d_sep': UNet3DSep, +} + + + +#################################### + + +def build_backbone(args, backbone, num_cond=0): + backbone = backbone_options[backbone](args.in_channels + num_cond, args.f_maps, + args.layer_order, args.num_groups, args.num_levels, + args.unit_feat, + ) + return backbone + \ No newline at end of file diff --git a/Trainer/models/criterion.py b/Trainer/models/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..f3642bcefc2110722d0b318350c8b1494ba867f3 --- /dev/null +++ b/Trainer/models/criterion.py @@ -0,0 +1,365 @@ +""" +Criterion modules. +""" + +import numpy as np +import torch +import torch.nn as nn + +from Trainer.models.losses import GradientLoss, SmoothnessLoss, HessianLoss, gaussian_loss, laplace_loss, l1_loss +from utils.misc import viewVolume + +uncertainty_loss = {'gaussian': gaussian_loss, 'laplace': laplace_loss} + + +class SetCriterion(nn.Module): + """ + This class computes the loss for BrainID. + """ + def __init__(self, gen_args, train_args, weight_dict, loss_names, device): + """ Create the criterion. + Parameters: + args: general exp cfg + weight_dict: dict containing as key the names of the losses and as values their + relative weight. + loss_names: list of all the losses to be applied. See get_loss for list of + available loss_names. 
+ """ + super(SetCriterion, self).__init__() + self.gen_args = gen_args + self.train_args = train_args + self.weight_dict = weight_dict + self.loss_names = loss_names + + self.mse = nn.MSELoss() + + self.loss_regression_type = train_args.losses.uncertainty if train_args.losses.uncertainty is not None else 'l1' + self.loss_regression = uncertainty_loss[train_args.losses.uncertainty] if train_args.losses.uncertainty is not None else l1_loss + + self.grad = GradientLoss('l1') + self.smoothness = SmoothnessLoss('l2') + self.hessian = HessianLoss('l2') + + self.bflog_loss = nn.L1Loss() if train_args.losses.bias_field_log_type == 'l1' else self.mse + + if 'contrastive' in self.loss_names: + self.temp_alpha = train_args.contrastive_temperatures.alpha + self.temp_beta = train_args.contrastive_temperatures.beta + self.temp_gamma = train_args.contrastive_temperatures.gamma + + # initialize weights # NOTE all = 1 for now + weights_brainseg = torch.ones(gen_args.n_labels).to(device) + weights_brainseg[gen_args.label_list_segmentation_with_csf==77] = train_args.relative_weight_lesions # give (more) importance to lesions + weights_brainseg = weights_brainseg / torch.sum(weights_brainseg) + + self.weights_ce = weights_brainseg[None, :, None, None, None] + self.weights_dice = weights_brainseg[None, :] + + # archived + #self.csf_ind = torch.tensor(np.where(np.array(gen_args.label_list_segmentation)==24)[0][0]) + #self.csf_v = torch.tensor(np.concatenate([np.arange(0, self.csf_ind), np.arange(self.csf_ind+1, gen_args.n_labels)])) + + self.loss_map = { + 'seg_ce': self.loss_seg_ce, + 'seg_dice': self.loss_seg_dice, + 'pathol_ce': self.loss_pathol_ce, + 'pathol_dice': self.loss_pathol_dice, + 'implicit_pathol_ce': self.loss_implicit_pathol_ce, + 'implicit_pathol_dice': self.loss_implicit_pathol_dice, + 'implicit_aux_pathol_ce': self.loss_implicit_aux_pathol_ce, + 'implicit_aux_pathol_dice': self.loss_implicit_aux_pathol_dice, + + 'T1': self.loss_T1, + 'T1_grad': self.loss_T1_grad, + 
'T2': self.loss_T2, + 'T2_grad': self.loss_T2_grad, + 'FLAIR': self.loss_FLAIR, + 'FLAIR_grad': self.loss_FLAIR_grad, + 'CT': self.loss_CT, + 'CT_grad': self.loss_CT_grad, + 'SR': self.loss_SR, + 'SR_grad': self.loss_SR_grad, + + "age": self.loss_age, + "distance": self.loss_distance, + "registration": self.loss_registration, + "registration_grad": self.loss_registration_grad, + "registration_hessian": self.loss_registration_hessian, + "registration_smooth": self.loss_registration_smooth, + "bias_field_log": self.loss_bias_field_log, + 'contrastive': self.loss_feat_contrastive, + + "surface": self.loss_surface, # TODO + #'supervised_seg': self.loss_supervised_seg, # archived + } + + def loss_feat_contrastive(self, outputs, *kwargs): + """ + outputs: [feat1, feat2] + feat shape: (b, feat_dim, s, r, c) + """ + feat1, feat2 = outputs[0]['feat'][-1], outputs[1]['feat'][-1] + num = torch.sum(torch.exp(feat1 * feat2 / self.temp_alpha), dim = 1) + den = torch.zeros_like(feat1[:, 0]) + for i in range(feat1.shape[1]): + den1 = torch.exp(feat1[:, i] ** 2 / self.temp_beta) + den2 = torch.exp((torch.sum(feat1[:, i][:, None] * feat1, dim = 1) - feat1[:, i] ** 2) / self.temp_gamma) + den += den1 + den2 + loss_contrastive = torch.mean(- torch.log(num / den)) + return {'loss_contrastive': loss_contrastive} + + def loss_seg_ce(self, outputs, targets, *kwargs): + """ + Cross entropy of segmentation + """ + loss_seg_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['segmentation'], min=1e-5)) * self.weights_ce * targets['segmentation'], dim=1)) + return {'loss_seg_ce': loss_seg_ce} + + def loss_seg_dice(self, outputs, targets, *kwargs): + """ + Dice of segmentation + """ + loss_seg_dice = torch.sum(self.weights_dice * (1.0 - 2.0 * ((outputs['segmentation'] * targets['segmentation']).sum(dim=[2, 3, 4])) + / torch.clamp((outputs['segmentation'] + targets['segmentation']).sum(dim=[2, 3, 4]), min=1e-5))) + return {'loss_seg_dice': loss_seg_dice} + + def 
    def loss_implicit_pathol_ce(self, outputs, targets, samples, *kwargs):
        """
        Cross entropy of pathology segmentation, comparing the implicit
        prediction against the implicit "orig" output; 0 when absent.
        """
        if 'implicit_pathol_pred' in outputs:
            loss_implicit_pathol_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['implicit_pathol_pred'], min=1e-5)) * outputs['implicit_pathol_orig'], dim=1))
        else: # no GT image exists
            loss_implicit_pathol_ce = 0.
        return {'loss_implicit_pathol_ce': loss_implicit_pathol_ce}

    def loss_implicit_pathol_dice(self, outputs, targets, samples, *kwargs):
        """
        Dice of pathology segmentation (implicit pred vs. implicit orig);
        0 when absent. Unweighted, unlike the seg_dice loss.
        """
        if 'implicit_pathol_pred' in outputs:
            loss_implicit_pathol_dice = torch.sum((1.0 - 2.0 * ((outputs['implicit_pathol_pred'] * outputs['implicit_pathol_orig']).sum(dim=[2, 3, 4]))
                                                   / torch.clamp((outputs['implicit_pathol_pred'] + outputs['implicit_pathol_orig']).sum(dim=[2, 3, 4]), min=1e-5)))
        else:
            loss_implicit_pathol_dice = 0.
        return {'loss_implicit_pathol_dice': loss_implicit_pathol_dice}


    def loss_implicit_aux_pathol_ce(self, outputs, targets, samples):
        """
        Cross entropy of auxiliary implicit pathology segmentation; weighted
        by self.weights_ce (unlike the non-aux variant). 0 when absent.
        """
        if 'implicit_aux_pathol_pred' in outputs:
            loss_implicit_aux_pathol_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['implicit_aux_pathol_pred'], min=1e-5)) * self.weights_ce * outputs['implicit_aux_pathol_orig'], dim=1))
        else:
            loss_implicit_aux_pathol_ce = 0.
        return {'loss_implicit_aux_pathol_ce': loss_implicit_aux_pathol_ce}

    def loss_implicit_aux_pathol_dice(self, outputs, targets, samples):
        """
        Dice of auxiliary implicit pathology segmentation, weighted by
        self.weights_dice. 0 when absent.
        """
        if 'implicit_aux_pathol_pred' in outputs:
            loss_implicit_aux_pathol_dice = torch.sum(self.weights_dice * (1.0 - 2.0 * ((outputs['implicit_aux_pathol_pred'] * outputs['implicit_aux_pathol_orig']).sum(dim=[2, 3, 4]))
                                                      / torch.clamp((outputs['implicit_aux_pathol_pred'] + outputs['implicit_aux_pathol_orig']).sum(dim=[2, 3, 4]), min=1e-5)))
        else:
            loss_implicit_aux_pathol_dice = 0.
        return {'loss_implicit_aux_pathol_dice': loss_implicit_aux_pathol_dice}

    # Thin wrappers that route packed multi-channel predictions through the
    # shared regression helpers below.
    def loss_surface(self, outputs, targets, *kwargs):
        return {'loss_surface': self.loss_image(outputs['surface'], targets['surface'])}

    def loss_distance(self, outputs, targets, *kwargs):
        return {'loss_distance': self.loss_image(outputs['distance'], targets['distance'])}

    def loss_registration(self, outputs, targets, *kwargs):
        return {'loss_registration': self.loss_image(outputs['registration'], targets['registration'])}

    def loss_registration_grad(self, outputs, targets, *kwargs):
        return {'loss_registration_grad': self.loss_image_grad(outputs['registration'], targets['registration'])}

    def loss_registration_smooth(self, outputs, *kwargs):
        return {'loss_registration_smooth': self.smoothness(outputs['registration'])}

    def loss_registration_hessian(self, outputs, *kwargs):
        return {'loss_registration_hessian': self.hessian(outputs['registration'])}

    def loss_pathol_ce(self, outputs, targets, samples):
        """
        Cross entropy of pathology segmentation; 0 when the prediction is
        missing or its shape disagrees with the target.
        """
        if 'pathology' in outputs and outputs['pathology'].shape == targets['pathology'].shape:
            loss_pathol_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['pathology'], min=1e-5)) * targets['pathology'], dim=1))
        else:
            loss_pathol_ce = 0.
        return {'loss_pathol_ce': loss_pathol_ce}

    def loss_pathol_dice(self, outputs, targets, samples):
        """
        Dice of pathology segmentation; 0 when missing or shape-mismatched.
        """
        if 'pathology' in outputs and outputs['pathology'].shape == targets['pathology'].shape:
            loss_pathol_dice = torch.sum((1.0 - 2.0 * ((outputs['pathology'] * targets['pathology']).sum(dim=[2, 3, 4]))
                                          / torch.clamp((outputs['pathology'] + targets['pathology']).sum(dim=[2, 3, 4]), min=1e-5)))
        else:
            loss_pathol_dice = 0.
        return {'loss_pathol_dice': loss_pathol_dice}


    # Per-modality regression losses. Each down-weights voxels by a '*_DM'
    # target map when present (presumably a distance/defect mask — TODO
    # confirm its semantics against the data generator).
    def loss_T1(self, outputs, targets, *kwargs):
        weights = 1. - targets['T1_DM'] if 'T1_DM' in targets else 1.
        return {'loss_T1': self.loss_image(outputs['T1'], targets['T1'], outputs['T1_sigma'] if 'T1_sigma' in outputs else None, weights = weights)}
    def loss_T1_grad(self, outputs, targets, *kwargs):
        weights = 1. - targets['T1_DM'] if 'T1_DM' in targets else 1.
        return {'loss_T1_grad': self.loss_image_grad(outputs['T1'], targets['T1'], weights)}

    def loss_T2(self, outputs, targets, *kwargs):
        weights = 1. - targets['T2_DM'] if 'T2_DM' in targets else 1.
        return {'loss_T2': self.loss_image(outputs['T2'], targets['T2'], outputs['T2_sigma'] if 'T2_sigma' in outputs else None, weights)}
    def loss_T2_grad(self, outputs, targets, *kwargs):
        weights = 1. - targets['T2_DM'] if 'T2_DM' in targets else 1.
        return {'loss_T2_grad': self.loss_image_grad(outputs['T2'], targets['T2'], weights)}

    def loss_FLAIR(self, outputs, targets, *kwargs):
        weights = 1. - targets['FLAIR_DM'] if 'FLAIR_DM' in targets else 1.
        return {'loss_FLAIR': self.loss_image(outputs['FLAIR'], targets['FLAIR'], outputs['FLAIR_sigma'] if 'FLAIR_sigma' in outputs else None, weights)}
    def loss_FLAIR_grad(self, outputs, targets, *kwargs):
        weights = 1. - targets['FLAIR_DM'] if 'FLAIR_DM' in targets else 1.
        return {'loss_FLAIR_grad': self.loss_image_grad(outputs['FLAIR'], targets['FLAIR'], weights)}

    def loss_CT(self, outputs, targets, *kwargs):
        weights = 1. - targets['CT_DM'] if 'CT_DM' in targets else 1.
        return {'loss_CT': self.loss_image(outputs['CT'], targets['CT'], outputs['CT_sigma'] if 'CT_sigma' in outputs else None, weights)}
    def loss_CT_grad(self, outputs, targets, *kwargs):
        weights = 1. - targets['CT_DM'] if 'CT_DM' in targets else 1.
        return {'loss_CT_grad': self.loss_image_grad(outputs['CT'], targets['CT'], weights)}

    def loss_SR(self, outputs, targets, samples):
        # super-resolution is supervised against the residual stored in samples
        loss_SR = self.loss_image(outputs['high_res_residual'], samples['high_res_residual'])
        return {'loss_SR': loss_SR}

    def loss_SR_grad(self, outputs, targets, samples):
        loss_SR_grad = self.loss_image_grad(outputs['high_res_residual'], samples['high_res_residual'])
        return {'loss_SR_grad': loss_SR_grad}

    def loss_bias_field_log(self, outputs, targets, samples):
        """
        Log-bias-field loss, masked to non-background voxels
        (1 - background channel of the one-hot segmentation). 0 when the
        sample carries no bias field.
        """
        if 'bias_field_log' in samples:
            bf_soft_mask = 1. - targets['segmentation'][:, 0]
            loss_bias_field_log = self.bflog_loss(outputs['bias_field_log'] * bf_soft_mask, samples['bias_field_log'] * bf_soft_mask)
        else:
            loss_bias_field_log = 0.
        return {'loss_bias_field_log': loss_bias_field_log}


    def loss_age(self, outputs, targets, *kwargs):
        # plain absolute error on the scalar age prediction
        loss_age = abs(outputs['age'] - targets['age'])
        return {'loss_age': loss_age}


    def loss_image(self, output, target, output_sigma = None, weights = 1., *kwargs):
        """
        Shared regression loss; 0 when output/target shapes disagree.
        With a sigma map, uses the configured uncertainty loss; otherwise the
        weighted l1/regression loss.
        NOTE(review): `if output_sigma:` applies truthiness to a tensor, which
        raises for multi-element tensors — `is not None` is presumably
        intended; confirm before relying on the uncertainty path.
        """
        if output.shape == target.shape:
            if output_sigma:
                loss_image = self.loss_regression(output, output_sigma, target)
            else:
                loss_image = self.loss_regression(output, target, weights)
        else:
            loss_image = 0.
        return loss_image

    def loss_image_grad(self, output, target, weights = 1., *kwargs):
        # gradient-domain loss; 0 on shape mismatch
        return self.grad(output, target, weights) if output.shape == target.shape else 0.


    def loss_supervised_seg(self, outputs, targets, *kwargs):
        """
        Supervised segmentation differences (for dataset_name == synth).
        NOTE(review): archived — references self.csf_v / self.csf_ind /
        self.weights_dice_sup, which are commented out in __init__; this
        method would raise AttributeError if re-enabled without restoring them.
        """
        onehot_withoutcsf = targets['segmentation'].clone()
        onehot_withoutcsf = onehot_withoutcsf[:, self.csf_v, ...]
        onehot_withoutcsf[:, 0, :, :, :] = onehot_withoutcsf[:, 0, :, :, :] + targets['segmentation'][:, self.csf_ind, :, :, :]

        loss_supervised_seg = torch.sum(self.weights_dice_sup * (1.0 - 2.0 * ((outputs['supervised_seg'] * onehot_withoutcsf).sum(dim=[2, 3, 4]))
                                        / torch.clamp((outputs['supervised_seg'] + onehot_withoutcsf).sum(dim=[2, 3, 4]), min=1e-5)))

        return {'loss_supervised_seg': loss_supervised_seg}

    def get_loss(self, loss_name, outputs, targets, *kwargs):
        # dispatch through the table built in __init__
        assert loss_name in self.loss_map, f'do you really want to compute {loss_name} loss?'
        return self.loss_map[loss_name](outputs, targets, *kwargs)
    def get_loss(self, loss_name, outputs_list, targets, samples_list):
        # Accumulate one named loss over every intermediate output sample,
        # then average. Each element of outputs_list is one model pass;
        # samples_list[i] carries the matching generator sample.
        assert loss_name in self.loss_map, f'do you really want to compute {loss_name} loss?'
        total_loss = 0.
        for i_sample, outputs in enumerate(outputs_list):
            total_loss += self.loss_map[loss_name](outputs, targets, samples_list[i_sample])['loss_' + loss_name]
        # NOTE(review): averages by the configured self.all_samples rather than
        # len(outputs_list) — assumes the two always agree; confirm against the
        # generator configuration.
        return {'loss_' + loss_name: total_loss / self.all_samples}
# Segmentation label protocol: 33 FreeSurfer-style labels, the first
# n_neutral_labels are midline/neutral, the remainder split into left/right
# halves; vflip maps each lateral label to its mirrored counterpart.
label_list_segmentation = [0, 14, 15, 16, 24, 77, 85, 2, 3, 4, 7, 8, 10, 11, 12, 13, 17, 18, 26, 28, 41,
                           42, 43, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60]  # 33
n_neutral_labels = 7
n_labels = len(label_list_segmentation)
nlat = int((n_labels - n_neutral_labels) / 2.0)
vflip = np.concatenate([np.array(range(n_neutral_labels)),
                        np.array(range(n_neutral_labels + nlat, n_labels)),
                        np.array(range(n_neutral_labels, n_neutral_labels + nlat))])

def get_onehot(label, device):
    """One-hot encode a 3-D label volume.

    Maps raw label values through a lookup table onto channel indices and
    returns a float tensor of shape (n_labels, D, H, W) on `device`.
    Unlisted label values fall into channel 0 (the lut defaults to 0).
    """
    # Lookup table: raw label value -> dense channel index.
    lut = torch.zeros(10000, dtype=torch.long, device=device)
    for channel, raw_value in enumerate(label_list_segmentation):
        lut[raw_value] = channel
    eye = torch.eye(n_labels, dtype=torch.float, device=device)

    volume = torch.from_numpy(np.squeeze(label))
    onehot = eye[lut[volume.long()]]          # (D, H, W, n_labels)
    return onehot.permute([3, 0, 1, 2])       # channels first
nda1, nda2 + + + +class Evaluator: + """ + This class computes the evaluation scores for BrainID. + """ + def __init__(self, args, metric_names, device): + + self.args = args + self.metric_names = metric_names + self.device = device + + self.mse = nn.MSELoss() + self.l1 = nn.L1Loss() + self.win_sigma = args.ssim_win_sigma + + self.metric_map = { + 'seg_dice': self.get_dice, + 'pathol_dice': self.get_dice, + + 'feat_l1': self.get_l1, + 'recon_l1': self.get_l1, + 'sr_l1': self.get_l1, + + 'bf_normalized_l2': self.get_normalized_l2, + 'bf_corrected_l1': self.get_l1, + + 'recon_psnr': self.get_psnr, + 'sr_psnr': self.get_psnr, + + 'feat_ssim': self.get_ssim, + 'recon_ssim': self.get_ssim, + 'sr_ssim': self.get_ssim, + + 'feat_ms_ssim': self.get_ms_ssim, + 'recon_ms_ssim': self.get_ms_ssim, + 'sr_ms_ssim': self.get_ms_ssim, + } + + def get_dice(self, metric_name, output, target, *kwargs): + """ + Dice of segmentation + """ + dice = torch.mean((2.0 * ((output * target).sum(dim=[2, 3, 4])) + / torch.clamp((output + target).sum(dim=[2, 3, 4]), min=1e-5))) + return {metric_name: dice.cpu().numpy()} + + def get_normalized_l2(self, metric_name, output, target, *kwargs): + w = torch.sum(output * target) / (torch.sum(output ** 2) + 1e-7) + l2 = 0. 
+ torch.sqrt( torch.sum( (w * output - target) ** 2 ) / (torch.sum(target ** 2) + 1e-7) ) + return {metric_name: l2.cpu().numpy()} + + def get_l1(self, metric_name, output, target, nonzero_only=False, *kwargs): + if nonzero_only: # compute only within face_aware_region # + nonzero_mask = target!=0 + l1 = (abs(target - output) * nonzero_mask).sum(dim=0) / nonzero_mask.sum(dim=0) + else: + l1 = self.l1(output, target) + return {metric_name: l1.cpu().numpy()} + + def get_psnr(self, metric_name, output, target, *kwargs): + mse = self.mse(output, target).cpu().numpy() + if mse == 0: + psnr = float('inf') + else: + psnr = 20 * math.log10(np.max(target.cpu().numpy()) / math.sqrt(mse)) + return {metric_name: psnr} + + def get_ssim(self, metric_name, output, target, *kwargs): + ''' + Ref: https://github.com/jorge-pessoa/pytorch-msssim + ''' + output = (output - output.min()) / (output.max() - output.min()) + target = (target - target.min()) / (target.max() - target.min()) + ss = ssim(output, target, data_range = 1.0, size_average = False, win_sigma = self.win_sigma) + return {metric_name: ss.mean().cpu().numpy()} + + def get_ms_ssim(self, metric_name, output, target, *kwargs): + ''' + Ref: https://github.com/jorge-pessoa/pytorch-msssim + ''' + output = (output - output.min()) / (output.max() - output.min()) + target = (target - target.min()) / (target.max() - target.min()) + try: + ms_ss = ms_ssim(output, target, data_range = 1.0, size_average = False, win_sigma = self.win_sigma) + return {metric_name: ms_ss.mean().cpu().numpy()} + except: + print('Error in MS-SSIM: Image too small for Multi-scale SSIM computation. Skipping...') + return {metric_name: float('nan')} + + def get_score(self, metric_name, output, target, **kwargs): + assert metric_name in self.metric_map, f'do you really want to compute {metric_name} metric?' 
    def eval(self, pred_path, target_path, clamp = False, is_seg = False, normalize = False, add_mask = False, flip = False, kill_target_labels = [], **kwargs):
        # End-to-end file-based evaluation: load both volumes, harmonize them,
        # then run every configured metric on 5-D (B, C, D, H, W) tensors.
        # NOTE(review): kill_target_labels=[] is a shared mutable default —
        # callers must not mutate the passed list.

        # Filenames containing 'label' are read as integer volumes.
        pred = MRIread(pred_path, im_only=True, dtype='int' if 'label' in os.path.basename(pred_path) else 'float')
        target, aff = MRIread(target_path, im_only=False, dtype='int' if 'label' in os.path.basename(target_path) else 'float')

        #print(pred.shape, target.shape)
        # Crop both to the common shape when they disagree.
        pred, target = align_shape(pred, target)

        if flip:
            pred = np.flip(pred, 0)

        # Zero out labels excluded from scoring in both volumes.
        for label in kill_target_labels:
            target[target == label] = 0
            pred[pred == label] = 0

        if add_mask and '_masked' not in pred_path:
            # Mask prediction by the target's support, clip negatives, and
            # persist the masked volume next to the original.
            pred[target == 0] = 0
            pred[pred < 0] = 0
            MRIwrite(pred, aff, pred_path.split('.')[0] + '_masked.nii.gz')

        if normalize:
            # Min-max normalize the prediction only (target left untouched).
            pred = (pred - np.min(pred)) / (np.max(pred) - np.min(pred))

        if is_seg:
            # One-hot encode both volumes for Dice-style metrics.
            pred = get_onehot(pred.copy(), self.device)
            target = get_onehot(target, self.device)
        else:
            pred = torch.tensor(np.squeeze(pred), dtype=torch.float32, device=self.device)
            target = torch.tensor(np.squeeze(target), dtype=torch.float32, device=self.device)

        if clamp:
            pred = torch.clamp(pred, min = 0., max = 1.)
            target = torch.clamp(target, min = 0., max = 1.)

        # Promote to 5-D: 3-D intensity volumes get batch+channel axes,
        # 4-D one-hot segmentations get a batch axis.
        if len(pred.shape) == 3:
            pred = pred[None, None]
            target = target[None, None]
        elif len(pred.shape) == 4: # seg
            pred = pred[None]
            target = target[None]
        assert len(pred.shape) == len(target.shape) == 5

        # Run every requested metric and merge the {name: value} results.
        score = {}
        for metric_name in self.metric_names:
            score.update(self.get_score(metric_name, pred, target, **kwargs))

        return score
def exists(val):
    """True when *val* is not None."""
    return val is not None


def uniq(arr):
    """Order-preserving de-duplication; returns a dict keys view."""
    return {el: True for el in arr}.keys()


def default(val, d):
    """Return *val* when set, otherwise the fallback *d* (called if callable)."""
    return val if exists(val) else (d() if isfunction(d) else d)


def max_neg_value(t):
    """Most negative finite value representable in t's dtype."""
    info = torch.finfo(t.dtype)
    return -info.max


def init_(tensor):
    """Uniform init in [-1/sqrt(dim), 1/sqrt(dim)] of the last dimension."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor


# feedforward
class GEGLU(nn.Module):
    """Gated GELU: project to 2x width, gate one half with GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Transformer MLP block: expand by *mult*, optional GEGLU gating, dropout."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            entry = GEGLU(dim, hidden)
        else:
            entry = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(entry, nn.Dropout(dropout), nn.Linear(hidden, dim_out))

    def forward(self, x):
        return self.net(x)
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) 
-> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = 
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image

    The output is residual: the transformed features are added back onto the
    input, and `proj_out` is zero-initialized so the block starts as identity.
    """

    def __init__(
        self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
    ):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        # 1x1 conv lifts image channels to the transformer width.
        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
                )
                for d in range(depth)
            ]
        )

        # Zero-init so the residual branch contributes nothing at start.
        self.proj_out = zero_module(
            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        # Flatten the spatial grid into a token sequence.
        x = rearrange(x, "b c h w -> b (h w) c")
        for block in self.transformer_blocks:
            x = block(x, context=context)
        # Restore the image layout and close the residual connection.
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)
        return x + x_in
class BRATSDataset(torch.utils.data.Dataset):
    """BraTS volume dataset for inpainting-style diffusion training.

    Walks one level of `directory`; each subdirectory is a datapoint whose
    files encode their sequence type as the last '-'-separated token of the
    filename (before the extension).

    test_flag=True  -> items are (image, path, slicedict) with channels
                       [voided, mask], pad/cropped to (224, 224, 224) and
                       intensity-normalized.
    test_flag=False -> items are (image, label, path, slicedict) with image
                       channels [healthyvoided, healthy] and label = t1n.

    Fixes vs. original: `slice_range` is re-initialized per datapoint (a
    directory without a 'mask' file previously raised NameError or reused
    the previous entry's slices); paths are built with os.path.join; an
    unknown seqtype in test mode now raises instead of printing and then
    using an undefined variable.
    """

    def __init__(self, directory, test_flag=True):
        super().__init__()
        self.directory = os.path.expanduser(directory)

        self.test_flag = test_flag
        if test_flag:
            self.seqtypes = ["voided", "mask"]
        else:
            #self.seqtypes = ["diseased", "mask", "healthy"]
            self.seqtypes = ["healthyvoided", "healthy", "t1n"]

        self.seqtypes_set = set(self.seqtypes)
        self.database = []
        self.mask_vis = []
        for root, dirs, files in os.walk(self.directory):
            for dir_id in sorted(dirs):
                datapoint = dict()
                # Fresh slice list per datapoint (fix: was only assigned when
                # a 'mask' file was found).
                slice_range = []
                for ro, di, fi in os.walk(os.path.join(root, str(dir_id))):
                    for f in sorted(fi):
                        seqtype = f.split("-")[-1].split(".")[0]
                        datapoint[seqtype] = os.path.join(root, dir_id, f)
                        if seqtype == "mask":
                            # Record which axial slices contain mask voxels,
                            # after the same pad/crop used at load time.
                            # NOTE(review): pad/crop applied only in test mode,
                            # mirroring the test-time __getitem__ preprocessing
                            # — confirm train-mode volumes are already aligned.
                            mask_vol = np.array(
                                nibabel.load(datapoint["mask"]).dataobj
                            )
                            if test_flag:
                                mask_vol = np.pad(
                                    mask_vol, ((0, 0), (0, 0), (34, 35))
                                )
                                mask_vol = mask_vol[8:-8, 8:-8, :]
                            for i in range(0, 224):
                                if np.sum(mask_vol[:, :, i]) != 0:
                                    slice_range.append(i)

                    # assert (
                    #     set(datapoint.keys()) == self.seqtypes_set
                    # ), f"datapoint is incomplete, keys are {datapoint.keys()}"
                    self.database.append(datapoint)
                    self.mask_vis.append(slice_range)

            break  # only inspect the top level of `directory`

    def __getitem__(self, x):
        filedict = self.database[x]
        slicedict = self.mask_vis[x]

        out_single = []

        if self.test_flag:
            for seqtype in self.seqtypes:
                if seqtype == "voided":
                    nib_img = np.array(nibabel.load(filedict[seqtype]).dataobj).astype(
                        np.float32
                    )
                    path = filedict[seqtype]
                    # Pad/crop to (224, 224, 224), then robust-clip and
                    # min-max normalize intensities.
                    t1_numpy_pad = np.pad(nib_img, ((0, 0), (0, 0), (34, 35)))
                    t1_numpy_crop = t1_numpy_pad[8:-8, 8:-8, :]
                    t1_clipped = np.clip(
                        t1_numpy_crop,
                        np.quantile(t1_numpy_crop, 0.001),
                        np.quantile(t1_numpy_crop, 0.999),
                    )
                    t1_normalized = (t1_clipped - np.min(t1_clipped)) / (
                        np.max(t1_clipped) - np.min(t1_clipped)
                    )
                    img_preprocessed = torch.tensor(t1_normalized)
                elif seqtype == "mask":
                    nib_img = np.array(nibabel.load(filedict[seqtype]).dataobj).astype(
                        np.float32
                    )
                    path = filedict[seqtype]
                    # Same geometry as 'voided', but no intensity changes.
                    mask_numpy_pad = np.pad(nib_img, ((0, 0), (0, 0), (34, 35)))
                    mask_numpy_crop = mask_numpy_pad[8:-8, 8:-8, :]
                    img_preprocessed = torch.tensor(mask_numpy_crop)
                else:
                    # Fix: previously printed and then used a stale/undefined
                    # tensor; fail loudly instead.
                    raise ValueError(f"unknown seqtype: {seqtype}")

                out_single.append(img_preprocessed)

            out_single = torch.stack(out_single)

            image = out_single[0:2, ...]
            path = filedict[seqtype]

            return (image, path, slicedict)

        else:
            for seqtype in self.seqtypes:
                nib_img = np.array(nibabel.load(filedict[seqtype]).dataobj).astype(
                    np.float32
                )
                path = filedict[seqtype]
                img_preprocessed = torch.tensor(nib_img)

                out_single.append(img_preprocessed)

            out_single = torch.stack(out_single)  # "healthyvoided", "healthy", "t1n"

            image = out_single[0:2, ...]   # conditioning channels
            label = out_single[2, ...]     # regression/denoising target
            label = label.unsqueeze(0)
            path = filedict[seqtype]

            return (image, label, path, slicedict)

    def __len__(self):
        return len(self.database)
+ """ + if dist.is_initialized(): + return + #os.environ["CUDA_VISIBLE_DEVICES"] = "1" + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + backend = "gloo" if not torch.cuda.is_available() else "nccl" + + if backend == "gloo": + hostname = "localhost" + else: + hostname = socket.gethostbyname(socket.getfqdn()) + os.environ["MASTER_ADDR"] = "127.0.1.1" # comm.bcast(hostname, root=0) + os.environ["RANK"] = "0" # str(comm.rank) + os.environ["WORLD_SIZE"] = "1" # str(comm.size) + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + s.close() + os.environ["MASTER_PORT"] = str(port) + dist.init_process_group(backend=backend, init_method="env://") + + +def dev(): + """ + Get the device to use for torch.distributed. + """ + if torch.cuda.is_available(): + return torch.device(f"cuda") + return torch.device("cpu") + + +def load_state_dict(path, **kwargs): + """ + Load a PyTorch file without redundant fetches across MPI ranks. + """ + mpigetrank = 0 + if mpigetrank == 0: + with bf.BlobFile(path, "rb") as f: + data = f.read() + else: + data = None + return torch.load(io.BytesIO(data), **kwargs) + + +def sync_params(params): + """ + Synchronize a sequence of Tensors across ranks from rank 0. + """ + for p in params: + with torch.no_grad(): + dist.broadcast(p, 0) + + +def _find_free_port(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + finally: + s.close() diff --git a/Trainer/models/guided_diffusion/fp16_util.py b/Trainer/models/guided_diffusion/fp16_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca07decddacd4fe8f89c2ea98c868528728656d --- /dev/null +++ b/Trainer/models/guided_diffusion/fp16_util.py @@ -0,0 +1,236 @@ +""" +Helpers to train with 16-bit precision. 
+""" + +import numpy as np +import torch +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +import logger + +INITIAL_LOG_LOSS_SCALE = 20.0 + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def make_master_params(param_groups_and_shapes): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = [] + for param_group, shape in param_groups_and_shapes: + master_param = nn.Parameter( + _flatten_dense_tensors( + [param.detach().float() for (_, param) in param_group] + ).view(shape) + ) + master_param.requires_grad = True + master_params.append(master_param) + return master_params + + +def model_grads_to_master_grads(param_groups_and_shapes, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + for master_param, (param_group, shape) in zip( + master_params, param_groups_and_shapes + ): + master_param.grad = _flatten_dense_tensors( + [param_grad_or_zeros(param) for (_, param) in param_group] + ).view(shape) + + +def master_params_to_model_params(param_groups_and_shapes, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. 
def get_param_groups_and_shapes(named_model_params):
    """Partition named parameters for flat master-parameter storage.

    Returns [(scalar/vector params, -1), (matrix params, (1, -1))]: the
    shapes are what each group's flattened master tensor is viewed as.
    """
    params = list(named_model_params)
    low_rank = [(name, p) for name, p in params if p.ndim <= 1]
    high_rank = [(name, p) for name, p in params if p.ndim > 1]
    return [(low_rank, (-1)), (high_rank, (1, -1))]
def param_grad_or_zeros(param):
    """Return the detached gradient data of *param*, or a zero tensor of the
    same shape when no gradient has been accumulated yet."""
    if param.grad is None:
        return torch.zeros_like(param)
    return param.grad.data.detach()
zero_master_grads(self.master_params) + master_params_to_model_params(self.param_groups_and_shapes, self.master_params) + self.lg_loss_scale += self.fp16_scale_growth + return True + + def _optimize_normal(self, opt: torch.optim.Optimizer): + grad_norm, param_norm = self._compute_norms() + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + opt.step() + return True + + def _compute_norms(self, grad_scale=1.0): + grad_norm = 0.0 + param_norm = 0.0 + for p in self.master_params: + with torch.no_grad(): + param_norm += torch.norm(p, p=2, dtype=torch.float32).item() ** 2 + if p.grad is not None: + grad_norm += torch.norm(p.grad, p=2, dtype=torch.float32).item() ** 2 + return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) + + def master_params_to_state_dict(self, master_params): + return master_params_to_state_dict( + self.model, self.param_groups_and_shapes, master_params, self.use_fp16 + ) + + def state_dict_to_master_params(self, state_dict): + return state_dict_to_master_params(self.model, state_dict, self.use_fp16) + + +def check_overflow(value): + return (value == float("inf")) or (value == -float("inf")) or (value != value) diff --git a/Trainer/models/guided_diffusion/gaussian_diffusion.py b/Trainer/models/guided_diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..4edd7071058d9b7bce76672c49022829cfdbb316 --- /dev/null +++ b/Trainer/models/guided_diffusion/gaussian_diffusion.py @@ -0,0 +1,1067 @@ +""" +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 
+""" +from torch.autograd import Variable +import enum +import torch.nn.functional as F +from torchvision.utils import save_image +import torch +import math +import numpy as np +import torch +from train_util import visualize +from nn import mean_flat +from losses import normal_kl, discretized_gaussian_log_likelihood +from scipy import ndimage +from torchvision import transforms + + +def standardize(img): + mean = torch.mean(img) + std = torch.std(img) + img = (img - mean) / std + return img + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. 
+ :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. 
+ self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
+ """ + if noise is None: + noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. 
+ """ + if model_kwargs is None: + model_kwargs = {} + B, C = x.shape[:2] + C = 1 + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + x = x[ + :, -1:, ... + ] # loss is only calculated on the last channel, not on the input brain MR image + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: # THIS + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = torch.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = torch.exp(model_log_variance) + else: # THIS + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = torch.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: # THIS + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: # THIS + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) 
+ + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, org, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + a, gradient = cond_fn(x, self._scale_timesteps(t), org, **model_kwargs) + + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return a, new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). 
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + + eps = eps.detach() - (1 - alpha_bar).sqrt() * p_mean_var["update"] * 0 + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x.detach(), t.detach(), eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out, eps + + def sample_known(self, img, batch_size=1): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop_known( + model, (batch_size, channels, image_size, image_size), img + ) + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. 
+ """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = torch.randn_like(x[:, -1:, ...]) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise + + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. 
+ """ + final = None + print("apparently we use this function") + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_known( + self, + model, + shape, + img, + org=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + conditioner=None, + classifier=None, + ): + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + img = img.to(device) + noise = torch.randn_like(img[:, :1, ...]).to(device) + x_noisy = torch.cat( + (img[:, :-1, ...], noise), dim=1 + ) # add noise as the last channel + + img = img.to(device) + + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=x_noisy, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + org=org, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + + return final["sample"], x_noisy, img + + def p_sample_loop_progressive( + self, + model, + shape, + time=1000, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + org=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = torch.randn(*shape, device=device) + indices = list(range(time))[::-1] + + org_MRI = img[:, :-1, ...] 
# original brain MR image + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + else: + # i_list = [0, 1, 2, 3, 4, 100, 250, 500, 999] + i_exceptions = [0, 1, 2] + for i in indices: + t = torch.tensor([i] * shape[0], device=device) + + imarr = np.asarray(img.cpu().detach()) + + with torch.no_grad(): + if img.shape != (imarr.shape[0], 3, 224, 224): + img = torch.cat( + (org_MRI, img), dim=1 + ) # in every step, make sure to concatenate the original image to the sample + + out = self.p_sample( + model, + img.float(), + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + yield out + + if i in i_exceptions: + img = out["pred_xstart"] + + else: + img = out["sample"] + + # if i in i_list: + # print('sampling step out ', i, np.min(np.asarray(img.detach().cpu())), np.max(np.asarray(img.detach().cpu()))) + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * torch.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. 
+ noise = torch.randn_like(x[:, -1:, ...]) + + mean_pred = ( + out["pred_xstart"] * torch.sqrt(alpha_bar_prev) + + torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. 
reversed + mean_pred = ( + out["pred_xstart"] * torch.sqrt(alpha_bar_next) + + torch.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop_interpolation( + self, + model, + shape, + img1, + img2, + lambdaint, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + b = shape[0] + t = torch.randint(499, 500, (b,), device=device).long().to(device) + + img1 = torch.tensor(img1).to(device) + img2 = torch.tensor(img2).to(device) + + noise = torch.randn_like(img1).to(device) + x_noisy1 = self.q_sample(x_start=img1, t=t, noise=noise).to(device) + x_noisy2 = self.q_sample(x_start=img2, t=t, noise=noise).to(device) + interpol = lambdaint * x_noisy1 + (1 - lambdaint) * x_noisy2 + + for sample in self.ddim_sample_loop_progressive( + model, + shape, + time=t, + noise=interpol, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"], interpol, img1, img2 + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). 
+ """ + final = None + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + b = shape[0] + t = torch.randint(99, 100, (b,), device=device).long().to(device) + + for sample in self.ddim_sample_loop_progressive( + model, + shape, + time=t, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + + return final["sample"] + + def ddim_sample_loop_known( + self, + model, + shape, + img, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + b = shape[0] + + img = img.to(device) + + t = torch.randint(499, 500, (b,), device=device).long().to(device) + noise = torch.randn_like(img[:, :1, ...]).to(device) + + x_noisy = torch.cat((img[:, :-1, ...], noise), dim=1).float() + img = img.to(device) + + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + time=t, + noise=x_noisy, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + + return final["sample"], x_noisy, img + + def ddim_sample_loop_progressive( + self, + model, + shape, + time=1000, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). 
+ """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = torch.randn(*shape, device=device) + indices = list(range(time - 1))[::-1] + orghigh = img[:, :-1, ...] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = torch.tensor([i] * shape[0], device=device) + with torch.no_grad(): + if img.shape != (1, 5, 224, 224): + img = torch.cat((orghigh, img), dim=1).float() + + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. 
+ """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = torch.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses_segmentation( + self, model, classifier, x_start, t, model_kwargs=None, noise=None + ): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = torch.randn_like(x_start[:, -1:, ...]) + + goal = x_start[:, -1:, ...] # whole image + + res_t = self.q_sample( + goal, t, noise=noise + ) # during q, noise is only added to the ground truth! + + x_t = x_start.float() + + x_t[:, -1:, ...] 
= res_t.float() # replace last channel by noisy GT # ( void + mask + noisy_GT ) + + terms = {} + + if self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) # x_t - (16, 3, 224, 224); out - (16, 2, 224, 224) + print('model_output', model_output.shape) #(mu, var) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + C = 1 + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) # (16, 2, 224, 224): (mu, var) + model_output, model_var_values = torch.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=goal, + x_t=res_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=goal, x_t=res_t, t=t + )[0], + ModelMeanType.START_X: goal, + ModelMeanType.EPSILON: noise, # THIS # the GT noise + }[self.model_mean_type] + terms["mse"] = mean_flat((target - model_output) ** 2) # noise MSE + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + + else: + raise NotImplementedError(self.loss_type) + + return (terms, model_output) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. 
+ """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
+ """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = torch.tensor([t] * batch_size, device=device) + noise = torch.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + + # Calculate VLB term at the current timestep + with torch.no_grad(): + out = self._vb_terms_bptimestepsd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = torch.stack(vb, dim=1) + xstart_mse = torch.stack(xstart_mse, dim=1) + mse = torch.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/Trainer/models/guided_diffusion/logger.py b/Trainer/models/guided_diffusion/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..aefb21897515bfad6d0c6755818282503cde8f07 --- /dev/null +++ b/Trainer/models/guided_diffusion/logger.py @@ -0,0 +1,494 @@ +""" +Logger copied from OpenAI baselines to avoid extra RL-based dependencies: +https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py +""" + +import os +import sys +import shutil +import os.path as osp +import json +import time +import datetime +import tempfile +import warnings +from collections import defaultdict +from contextlib import contextmanager + +DEBUG = 10 +INFO = 20 +WARN = 30 +ERROR = 40 + +DISABLED = 50 + + +class KVWriter(object): + def writekvs(self, kvs): + raise NotImplementedError + + +class SeqWriter(object): + def writeseq(self, seq): + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + def __init__(self, filename_or_file): + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, "wt") + self.own_file = True + else: + assert hasattr(filename_or_file, "read"), ( + "expected file or str, got %s" % filename_or_file + ) + self.file = filename_or_file + self.own_file = False + + def writekvs(self, kvs): + # Create strings for printing + key2str = {} + for key, val in sorted(kvs.items()): + if hasattr(val, "__float__"): + valstr = "%-8.3g" % val + else: + valstr = str(val) + key2str[self._truncate(key)] = self._truncate(valstr) + + # Find max widths + if len(key2str) == 0: + print("WARNING: tried to write empty key-value dict") + return + else: + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) + + # Write out the data + dashes = "-" * (keywidth + 
class CSVOutputFormat(KVWriter):
    """
    Key/value writer that appends one CSV row per dump.

    When a previously-unseen key appears, the whole file is rewritten in
    place: the header gains the new columns and every existing row is
    back-filled with empty trailing cells so columns stay aligned.
    """

    def __init__(self, filename):
        self.file = open(filename, "w+t")
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        # Keys we have never seen force a header rewrite.
        new_keys = sorted(kvs.keys() - self.keys)
        if new_keys:
            self.keys.extend(new_keys)
            self.file.seek(0)
            previous_rows = self.file.readlines()
            self.file.seek(0)
            # Re-emit the header with the enlarged key set.
            self.file.write(self.sep.join(self.keys))
            self.file.write("\n")
            # Pad every old data row (header row excluded) with blank cells.
            for row in previous_rows[1:]:
                self.file.write(row[:-1])
                self.file.write(self.sep * len(new_keys))
                self.file.write("\n")
        # Emit the current row in header order; absent keys become blanks.
        cells = []
        for key in self.keys:
            value = kvs.get(key)
            cells.append("" if value is None else str(value))
        self.file.write(self.sep.join(cells))
        self.file.write("\n")
        self.file.flush()

    def close(self):
        self.file.close()
TensorBoard's numeric format. + """ + + def __init__(self, dir): + os.makedirs(dir, exist_ok=True) + self.dir = dir + self.step = 1 + prefix = "events" + path = osp.join(osp.abspath(dir), prefix) + import tensorflow as tf + from tensorflow.python import pywrap_tensorflow + from tensorflow.core.util import event_pb2 + from tensorflow.python.util import compat + + self.tf = tf + self.event_pb2 = event_pb2 + self.pywrap_tensorflow = pywrap_tensorflow + self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) + + def writekvs(self, kvs): + def summary_val(k, v): + kwargs = {"tag": k, "simple_value": float(v)} + return self.tf.Summary.Value(**kwargs) + + summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) + event = self.event_pb2.Event(wall_time=time.time(), summary=summary) + event.step = ( + self.step + ) # is there any reason why you'd want to specify the step? + self.writer.WriteEvent(event) + self.writer.Flush() + self.step += 1 + + def close(self): + if self.writer: + self.writer.Close() + self.writer = None + + +def make_output_format(format, ev_dir, log_suffix=""): + os.makedirs(ev_dir, exist_ok=True) + if format == "stdout": + return HumanOutputFormat(sys.stdout) + elif format == "log": + return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) + elif format == "json": + return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) + elif format == "csv": + return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) + elif format == "tensorboard": + return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) + else: + raise ValueError("Unknown format specified: %s" % (format,)) + + +# ================================================================ +# API +# ================================================================ + + +def logkv(key, val): + """ + Log a value of some diagnostic + Call this once for each diagnostic quantity, each iteration + If called many times, last value 
will be used. + """ + get_current().logkv(key, val) + + +def logkv_mean(key, val): + """ + The same as logkv(), but if called many times, values averaged. + """ + get_current().logkv_mean(key, val) + + +def logkvs(d): + """ + Log a dictionary of key-value pairs + """ + for k, v in d.items(): + logkv(k, v) + + +def dumpkvs(): + """ + Write all of the diagnostics from the current iteration + """ + return get_current().dumpkvs() + + +def getkvs(): + return get_current().name2val + + +def log(*args, level=INFO): + """ + Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). + """ + get_current().log(*args, level=level) + + +def debug(*args): + log(*args, level=DEBUG) + + +def info(*args): + log(*args, level=INFO) + + +def warn(*args): + log(*args, level=WARN) + + +def error(*args): + log(*args, level=ERROR) + + +def set_level(level): + """ + Set logging threshold on current logger. + """ + get_current().set_level(level) + + +def set_comm(comm): + get_current().set_comm(comm) + + +def get_dir(): + """ + Get directory that log files are being written to. 
+ will be None if there is no output directory (i.e., if you didn't call start) + """ + return get_current().get_dir() + + +record_tabular = logkv +dump_tabular = dumpkvs + + +@contextmanager +def profile_kv(scopename): + logkey = "wait_" + scopename + tstart = time.time() + try: + yield + finally: + get_current().name2val[logkey] += time.time() - tstart + + +def profile(n): + """ + Usage: + @profile("my_func") + def my_func(): code + """ + + def decorator_with_name(func): + def func_wrapper(*args, **kwargs): + with profile_kv(n): + return func(*args, **kwargs) + + return func_wrapper + + return decorator_with_name + + +# ================================================================ +# Backend +# ================================================================ + + +def get_current(): + if Logger.CURRENT is None: + _configure_default_logger() + + return Logger.CURRENT + + +class Logger(object): + DEFAULT = None # A logger with no output files. (See right below class definition) + # So that you can still log to the terminal without setting up any output files + CURRENT = None # Current logger being used by the free functions above + + def __init__(self, dir, output_formats, comm=None): + self.name2val = defaultdict(float) # values this iteration + self.name2cnt = defaultdict(int) + self.level = INFO + self.dir = dir + self.output_formats = output_formats + self.comm = comm + + # Logging API, forwarded + # ---------------------------------------- + def logkv(self, key, val): + self.name2val[key] = val + + def logkv_mean(self, key, val): + oldval, cnt = self.name2val[key], self.name2cnt[key] + self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) + self.name2cnt[key] = cnt + 1 + + def dumpkvs(self): + if self.comm is None: + d = self.name2val + else: + d = mpi_weighted_mean( + self.comm, + { + name: (val, self.name2cnt.get(name, 1)) + for (name, val) in self.name2val.items() + }, + ) + if self.comm.rank != 0: + d["dummy"] = 1 # so we don't get a warning 
about empty dict + out = d.copy() # Return the dict for unit testing purposes + for fmt in self.output_formats: + if isinstance(fmt, KVWriter): + fmt.writekvs(d) + self.name2val.clear() + self.name2cnt.clear() + return out + + def log(self, *args, level=INFO): + if self.level <= level: + self._do_log(args) + + # Configuration + # ---------------------------------------- + def set_level(self, level): + self.level = level + + def set_comm(self, comm): + self.comm = comm + + def get_dir(self): + return self.dir + + def close(self): + for fmt in self.output_formats: + fmt.close() + + # Misc + # ---------------------------------------- + def _do_log(self, args): + for fmt in self.output_formats: + if isinstance(fmt, SeqWriter): + fmt.writeseq(map(str, args)) + + +def get_rank_without_mpi_import(): + # check environment variables here instead of importing mpi4py + # to avoid calling MPI_Init() when this module is imported + for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: + if varname in os.environ: + return int(os.environ[varname]) + return 0 + + +def mpi_weighted_mean(comm, local_name2valcount): + """ + Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 + Perform a weighted average over dicts that are each on a different node + Input: local_name2valcount: dict mapping key -> (value, count) + Returns: key -> mean + """ + all_name2valcount = comm.gather(local_name2valcount) + if comm.rank == 0: + name2sum = defaultdict(float) + name2count = defaultdict(float) + for n2vc in all_name2valcount: + for name, (val, count) in n2vc.items(): + try: + val = float(val) + except ValueError: + if comm.rank == 0: + warnings.warn( + "WARNING: tried to compute mean on non-float {}={}".format( + name, val + ) + ) + else: + name2sum[name] += val * count + name2count[name] += count + return {name: name2sum[name] / name2count[name] for name in name2sum} + else: + return {} + + +def configure(dir="results", 
format_strs=None, comm=None, log_suffix=""): + """ + If comm is provided, average all numerical stats across that comm + """ + if dir is None: + dir = os.getenv("OPENAI_LOGDIR") + if dir is None: + dir = osp.join( + tempfile.gettempdir(), + datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), + ) + assert isinstance(dir, str) + dir = os.path.expanduser(dir) + os.makedirs(os.path.expanduser(dir), exist_ok=True) + + rank = get_rank_without_mpi_import() + if rank > 0: + log_suffix = log_suffix + "-rank%03i" % rank + + if format_strs is None: + if rank == 0: + format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") + else: + format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") + format_strs = filter(None, format_strs) + output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] + + Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) + if output_formats: + log("Logging to %s" % dir) + + +def _configure_default_logger(): + configure() + Logger.DEFAULT = Logger.CURRENT + + +def reset(): + if Logger.CURRENT is not Logger.DEFAULT: + Logger.CURRENT.close() + Logger.CURRENT = Logger.DEFAULT + log("Reset logger") + + +@contextmanager +def scoped_configure(dir=None, format_strs=None, comm=None): + prevlogger = Logger.CURRENT + configure(dir=dir, format_strs=format_strs, comm=comm) + try: + yield + finally: + Logger.CURRENT.close() + Logger.CURRENT = prevlogger diff --git a/Trainer/models/guided_diffusion/losses.py b/Trainer/models/guided_diffusion/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..9281e1daa495ab4aa0cadb406705bb70b95652dd --- /dev/null +++ b/Trainer/models/guided_diffusion/losses.py @@ -0,0 +1,77 @@ +""" +Helpers for various likelihood-based losses. These are ported from the original +Ho et al. 
diffusion models codebase: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py +""" + +import numpy as np + +import torch + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x < -0.999, + log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/Trainer/models/guided_diffusion/misc.py b/Trainer/models/guided_diffusion/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..2b64febc415a45d094eda566e7913d3d5465936b --- /dev/null +++ b/Trainer/models/guided_diffusion/misc.py @@ -0,0 +1,248 @@ +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. 
+""" +import os +import shutil +import numpy as np +import nibabel as nib +from pathlib import Path +import SimpleITK as sitk + +import torch + + + +'''if float(torchvision.__version__[:3]) < 0.7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size''' + + +def make_dir(dir_name, parents = True, exist_ok = True, reset = False): + if reset and os.path.isdir(dir_name): + shutil.rmtree(dir_name) + dir_name = Path(dir_name) + dir_name.mkdir(parents=parents, exist_ok=exist_ok) + return dir_name + + +def read_image(img_path, save_path = None): + img = nib.load(img_path) + nda = img.get_fdata() + affine = img.affine + if save_path: + ni_img = nib.Nifti1Image(nda, affine) + nib.save(ni_img, save_path) + return np.squeeze(nda), affine + +def save_image(nda, affine, save_path): + ni_img = nib.Nifti1Image(nda, affine) + nib.save(ni_img, save_path) + return save_path + +def img2nda(img_path, save_path = None): + img = sitk.ReadImage(img_path) + nda = sitk.GetArrayFromImage(img) + if save_path: + np.save(save_path, nda) + return nda, img.GetOrigin(), img.GetSpacing(), img.GetDirection() + +def to3d(img_path, save_path = None): + nda, o, s, d = img2nda(img_path) + save_path = img_path if save_path is None else save_path + if len(o) > 3: + nda2img(nda, o[:3], s[:3], d[:3] + d[4:7] + d[8:11], save_path) + return save_path + +def nda2img(nda, origin = None, spacing = None, direction = None, save_path = None, isVector = None): + if type(nda) == torch.Tensor: + nda = nda.cpu().detach().numpy() + nda = np.squeeze(np.array(nda)) + isVector = isVector if isVector else len(nda.shape) > 3 + img = sitk.GetImageFromArray(nda, isVector = isVector) + if origin: + img.SetOrigin(origin) + if spacing: + img.SetSpacing(spacing) + if direction: + img.SetDirection(direction) + if save_path: + sitk.WriteImage(img, save_path) + return img + + + +def cropping(img_path, tol = 0, crop_range_lst = None, spare = 0, save_path = None): + + img = 
def crop_and_pad(orig_nda, crop_idx = [], tol = 1e-7, pad_size = [224, 224, 224], to_print = True):
    """
    Crop an array to its non-background bounding box, then zero-pad it to a
    fixed target shape.

    :param orig_nda: 3D numpy array (or 4D, where the bounding box is taken
        from the first volume along the last axis — see crop()).
    :param crop_idx: optional precomputed bounds [[x0, y0, z0], [x1, y1, z1]];
        when fewer than two entries are given, bounds are computed via crop().
    :param tol: background-intensity threshold forwarded to crop().
    :param pad_size: target shape after zero-padding.
    :param to_print: if True, print bounding-box / padding diagnostics.
    :return: tuple of (padded array, [[x0, y0, z0], [x1, y1, z1]]).

    NOTE: the mutable defaults (crop_idx, pad_size) are never mutated here so
    they are safe in practice, but callers should not rely on their identity.
    """
    if len(crop_idx) < 2:
        # BUG FIX: `tol` was accepted but never forwarded, so a caller-supplied
        # threshold was silently ignored and crop() always used its own default.
        [[x0, y0, z0], [x1, y1, z1]] = crop(orig_nda, tol = tol, to_print = to_print)
    else:
        [[x0, y0, z0], [x1, y1, z1]] = crop_idx
    nda = orig_nda[x0:x1, y0:y1, z0:z1]
    nda = pad(nda, pad_size, to_print = to_print)
    return nda, [[x0, y0, z0], [x1, y1, z1]]
prefix + names[n] + postfix + ext) + except: + save_path = os.path.join(save_dir, prefix + str(n) + postfix + ext) + MRIwrite(vol, aff, save_path) + #cmd = cmd + ' ' + save_path + + #os.system(cmd + ' &') + return save_path + +###############################3 + +def MRIwrite(volume, aff, filename, dtype=None): + + if dtype is not None: + volume = volume.astype(dtype=dtype) + + if aff is None: + aff = np.eye(4) + header = nib.Nifti1Header() + nifty = nib.Nifti1Image(volume, aff, header) + + nib.save(nifty, filename) + +############################### + +def MRIread(filename, dtype=None, im_only=False): + # dtype example: 'int', 'float' + assert filename.endswith(('.nii', '.nii.gz', '.mgz')), 'Unknown data file: %s' % filename + + x = nib.load(filename) + volume = x.get_fdata() + aff = x.affine + + if dtype is not None: + volume = volume.astype(dtype=dtype) + + if im_only: + return volume + else: + return volume, aff + +############## + \ No newline at end of file diff --git a/Trainer/models/guided_diffusion/nn.py b/Trainer/models/guided_diffusion/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..3386c2e327e1313164a66a152c86d5b4cac78674 --- /dev/null +++ b/Trainer/models/guided_diffusion/nn.py @@ -0,0 +1,171 @@ +""" +Various utilities for neural networks. +""" + +import math + +import torch +import torch.nn as nn + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. 
+ """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([th.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/Trainer/models/guided_diffusion/openaimodel_pseudo3D.py b/Trainer/models/guided_diffusion/openaimodel_pseudo3D.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ed135c5ed467d38ed77a8f5118546ea6bea5cb --- /dev/null +++ b/Trainer/models/guided_diffusion/openaimodel_pseudo3D.py @@ -0,0 +1,1088 @@ +# From Zhu, L. et al. (2023). Make-A-Volume: Leveraging Latent Diffusion Models for Cross-Modality 3D Brain MRI Synthesis. In: Greenspan, H., et al. Medical Image Computing and Computer Assisted Intervention – MICCAI 2023. MICCAI 2023. Lecture Notes in Computer Science, vol 14229. Springer, Cham. https://doi.org/10.1007/978-3-031-43999-5_56 + + +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +from einops import rearrange +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) + +from attention import SpatialTransformer + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +class PseudoConv3d(nn.Module): + def __init__( + self, dim, dim_out=None, kernel_size=3, *, temporal_kernel_size=None, **kwargs + ): + super().__init__() + dim_out = default(dim_out, dim) + temporal_kernel_size = default(temporal_kernel_size, kernel_size) + + # self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = 
kernel_size // 2) + self.temporal_conv = ( + nn.Conv1d( + dim_out, + dim_out, + kernel_size=temporal_kernel_size, + padding=temporal_kernel_size // 2, + ) + if kernel_size > 1 + else None + ) + + if exists(self.temporal_conv): + nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity + nn.init.zeros_(self.temporal_conv.bias.data) + + def forward(self, x, enable_time=True): + _, _, h, w = x.shape + + # x = self.spatial_conv(x) + + if not enable_time or not exists(self.temporal_conv): + return x + + # Here f is hard-coded. + # Could be set as the slice num in one volume + # or smaller window size (requires autoregressively sampling) + if x.shape[0] >= 16: + x = rearrange(x, "(b f) c h w -> (b h w) c f", f=16) + else: + x = rearrange(x, "(b f) c h w -> (b h w) c f", f=x.shape[0]) + x = self.temporal_conv(x) + + x = rearrange(x, "(b h w) c f -> (b f) c h w", h=h, w=w) + + return x + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. 
+ """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. 
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # for 3D inputs only height/width are strided; the depth axis is kept
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            # average pooling cannot change the channel count
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    In this pseudo-3D variant, each spatial conv stack (in_layers / out_layers)
    is followed by a PseudoConv3d temporal mixer (in_layers_tem / out_layers_tem).

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # spatial conv stack, then a temporal (pseudo-3D) mixer
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.in_layers_tem = PseudoConv3d(channels, self.out_channels, 3)

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # projects the timestep embedding; doubled width for scale/shift (FiLM)
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # zero-initialized so the residual branch starts as an identity
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
            # conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.out_layers_tem = PseudoConv3d(self.out_channels, self.out_channels, 3)

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            # resample *between* the norm/act and the conv of in_layers
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
            h = self.in_layers_tem(h)
        else:
            h = self.in_layers(x)
            h = self.in_layers_tem(h)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # broadcast the embedding over all spatial dims
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
            h = self.out_layers_tem(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
            h = self.out_layers_tem(h)
        return self.skip_connection(x) + h
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. 
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention +
    input/output heads shaping (heads are split *before* q/k/v).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # fold heads into the batch dim, then split q/k/v along channels
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        # scaling q and k by ch**-0.25 each is more stable with f16 than
        # dividing the product afterwards
        logits = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = torch.einsum("bts,bcs->bct", probs, v)
        return attended.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order
    (q/k/v are separated *before* the heads are split).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        # pre-scale q and k separately for f16 stability, then fold heads
        q_heads = (q * scale).view(bs * self.n_heads, ch, length)
        k_heads = (k * scale).view(bs * self.n_heads, ch, length)
        logits = torch.einsum("bct,bcs->bts", q_heads, k_heads)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = torch.einsum(
            "bts,bcs->bct", probs, v.reshape(bs * self.n_heads, ch, length)
        )
        return attended.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        # spatial transformer and context_dim must be enabled together
        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # timestep embedding MLP: model_channels -> 4*model_channels
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # ---------------- encoder (down) path ----------------
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsample rate
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        )
                        if not use_spatial_transformer
                        else SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            use_checkpoint=use_checkpoint,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # downsample between levels (except after the last one)
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---------------- bottleneck ----------------
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                use_checkpoint=use_checkpoint,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ---------------- decoder (up) path ----------------
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # pop the matching skip-connection channel count
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        )
                        if not use_spatial_transformer
                        else SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            use_checkpoint=use_checkpoint,
                        )
                    )
                if level and i == num_res_blocks:
                    # upsample at the end of each level except the outermost
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # final projection; zero-initialized conv so training starts near zero output
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )
convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + print('x UNet', x.shape, emb.shape, context) # (16, 3, 224, 224) + + h = x.type(self.dtype) + for module in self.input_blocks: + print('-- input_blocks in', h.shape) + h = module(h, emb, context) + print('-- input_blocks out', h.shape) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + print('-- output_blocks in 1', h.shape) + h = torch.cat([h, hs.pop()], dim=1) + print('-- output_blocks in 2', h.shape) + h = module(h, emb, context) + print('-- output_blocks out', h.shape) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: # THIS + print('-- UNet out', h.shape, self.out(h).shape) + return self.out(h) # (16, 2, 224, 224) + + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and 
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    Encoder-only variant: the down path and bottleneck of UNetModel followed
    by a configurable pooling head. For usage, see UNet.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        # timestep embedding MLP: model_channels -> 4*model_channels
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # ---------------- encoder (down) path ----------------
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsample rate
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # downsample between levels (except after the last one)
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---------------- bottleneck ----------------
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ---------------- pooling head ----------------
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")
    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                # spatial pooling collects a mean feature from every stage
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = torch.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)


# --- Trainer/models/guided_diffusion/preprocess.py ---
# One-off preprocessing script for the BraTS 2023 local-synthesis data:
# crops/pads each case to a fixed size and writes a "healthy-voided" T1n.
# NOTE(review): the crop/pad stage is currently disabled (kept as the big
# triple-quoted string below), so `datapoint` is never populated and
# `database` ends up as a list of empty dicts; `nibabel`, `sli_dict`,
# `mask_vis` and `test_flag` are unused. Paths are hard-coded to a
# site-specific filesystem.
import os
import nibabel

import numpy as np
from misc import crop_and_pad, viewVolume, MRIread, make_dir


test_flag = False

directory = '/autofs/space/yogurt_004/users/pl629/ASNR-MICCAI-BraTS2023-Local-Synthesis-Challenge-Training'
new_directory = make_dir('/autofs/space/yogurt_004/users/pl629/ASNR-MICCAI-BraTS2023-Local-Synthesis-Challenge-Training_CropPad')

database = []
mask_vis = []
pad_size = [224, 224, 224]

for root, dirs, files in os.walk(directory):
    dirs_sorted = sorted(dirs)
    for dir_id in dirs_sorted:
        datapoint = dict()
        sli_dict = dict()
        for ro, di, fi in os.walk(root + "/" + str(dir_id)):
            fi_sorted = sorted(fi)

            # every case directory must contain the native T1 volume
            assert os.path.isfile(os.path.join(root, dir_id, dir_id + '-t1n.nii.gz'))

            new_dir = make_dir(os.path.join(new_directory, dir_id))
            print('Create new case dir:', new_dir)
            '''try:
                to_crop, aff = MRIread(os.path.join(root, dir_id, dir_id + '-t1n.nii.gz'), im_only=False, dtype='float')
                _, crop_idx = crop_and_pad(to_crop, pad_size = pad_size, to_print = True)
                print('-- crop_idx:', crop_idx)
            except:
                raise NotImplementedError


            for f in fi_sorted:
                seqtype = f.split("-")[-1].split(".")[0]
                datapoint[seqtype] = os.path.join(root, dir_id, f)
                print('-- current filename:', f)
                print('-- current seqtype:', seqtype)
                print('-- to save in new_dir:', os.path.join(new_dir, f.split('.')[0] + '.nii.gz'))

                curr_nda, _ = MRIread(os.path.join(root, dir_id, f), im_only=False, dtype='float')
                new_curr_nda, _ = crop_and_pad(curr_nda, crop_idx, pad_size = pad_size, to_print = False)
                viewVolume(new_curr_nda, aff, names = [f.split('.')[0]], save_dir = new_dir)'''

            # zero out the healthy-mask region of the T1n and save the result
            nda, aff = MRIread(os.path.join(new_dir, dir_id + '-t1n.nii.gz'), im_only=False, dtype='float')
            mask, _ = MRIread(os.path.join(new_dir, dir_id + '-mask-healthy.nii.gz'), im_only=False, dtype='float')
            viewVolume(nda * (1 - mask), aff, names = [dir_id + '-t1n-healthyvoided'], save_dir = new_dir)

            database.append(datapoint)

            #exit()

    # only walk the top level of the dataset directory
    break

print('Total num of cases:', len(database))


# --- Trainer/models/guided_diffusion/resample.py ---
from abc import ABC, abstractmethod

import numpy as np
import torch
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion, maxt):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler ("uniform" or "loss-second-moment").
    :param diffusion: the diffusion object to sample for.
    :param maxt: the number of timesteps the uniform sampler draws from.
    """
    if name == "uniform":
        return UniformSampler(diffusion, maxt)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = torch.from_numpy(indices_np).long().to(device)
        # unbiased importance weights: 1 / (N * p_i)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = torch.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    """Samples every timestep in [0, maxt) with equal probability."""

    def __init__(self, diffusion, maxt):
        self.diffusion = diffusion
        self._weights = np.ones([maxt])

    def weights(self):
        return self._weights


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            torch.tensor([0], dtype=torch.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        # Fix: the original used the undefined name `th` (this module imports
        # `torch`, not `torch as th`), which raised NameError at runtime.
        timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes]
        loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        # drop the padding before handing the values to the subclass
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/Trainer/models/guided_diffusion/respace.py b/Trainer/models/guided_diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..bae7444da9d439f3bf4cd0c0ae0962287ada0825 --- /dev/null +++ b/Trainer/models/guided_diffusion/respace.py @@ -0,0 +1,150 @@ +import numpy as np +import torch + +from gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. 
+ + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. 
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        # Maps indices in the shortened chain back to original timesteps.
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])
        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        # Recompute betas for the retained steps only, so the shortened chain
        # reproduces the base chain's cumulative alpha at each kept step:
        # beta_new = 1 - acp_i / acp_prev_kept.
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        # NOTE(review): self.rescale_timesteps (used in _wrap_model below) is
        # presumably set by GaussianDiffusion.__init__ from kwargs — confirm.
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        # Wrap the model so it sees original (possibly rescaled) timesteps.
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model2(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        # NOTE(review): condition_mean wraps with _WrappedModel2 (extra `org`
        # argument) but condition_score with _WrappedModel — confirm this
        # asymmetry is intentional.
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        # Idempotent: an already-wrapped model is returned unchanged.
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _wrap_model2(self, model):
        if isinstance(model, _WrappedModel2):
            return model
        return _WrappedModel2(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    """Remaps shortened-chain timesteps to original timesteps before calling
    the underlying model (optionally rescaling them to a 0-1000 range)."""

    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        # Translate spaced indices back to the base process's timesteps.
        map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)


class _WrappedModel2:
    """Same remapping as _WrappedModel, for callables that take an extra
    positional `org` argument."""

    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, org, **kwargs):
        map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, org, **kwargs)
+ """ + return dict( + learn_sigma=True, # False + diffusion_steps=1000, + noise_schedule="linear", + timestep_respacing="", + use_kl=False, + predict_xstart=False, + rescale_timesteps=False, + rescale_learned_sigmas=False, + ) + + +def classifier_defaults(): + """ + Defaults for classifier models. + """ + return dict( + image_size=64, + classifier_use_fp16=False, + classifier_width=128, + classifier_depth=2, + classifier_attention_resolutions="32,16,8", # 16 + classifier_use_scale_shift_norm=True, # False + classifier_resblock_updown=True, # False + classifier_pool="spatial", + ) + + +def model_and_diffusion_defaults(): + """ + Defaults for image training. + """ + res = dict( + image_size=128, # 256 + num_channels=128, + num_res_blocks=2, + num_heads=1, # 4 + num_heads_upsample=-1, + num_head_channels=-1, + attention_resolutions="16", # "16,8" + channel_mult="", + dropout=0.0, + class_cond=False, + use_checkpoint=False, + use_scale_shift_norm=False, # True + resblock_updown=False, + use_fp16=False, + use_new_attention_order=False, + ) + res.update(diffusion_defaults()) + return res + + +def classifier_and_diffusion_defaults(): + res = classifier_defaults() + res.update(diffusion_defaults()) + return res + + +def create_model_and_diffusion( + image_size, + class_cond, + learn_sigma, + num_channels, + num_res_blocks, + channel_mult, + num_heads, + num_head_channels, + num_heads_upsample, + attention_resolutions, + dropout, + diffusion_steps, + noise_schedule, + timestep_respacing, + use_kl, + predict_xstart, + rescale_timesteps, + rescale_learned_sigmas, + use_checkpoint, + use_scale_shift_norm, + resblock_updown, + use_fp16, + use_new_attention_order, +): + model = create_model( + image_size, + num_channels, + num_res_blocks, + channel_mult=channel_mult, + learn_sigma=learn_sigma, + class_cond=class_cond, + use_checkpoint=use_checkpoint, + attention_resolutions=attention_resolutions, + num_heads=num_heads, + num_head_channels=num_head_channels, + 
def create_model_and_diffusion(
    image_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    use_new_attention_order,
):
    """
    Build the UNet model and its Gaussian diffusion process from flat
    hyper-parameters (see model_and_diffusion_defaults for the keys).

    :return: a (model, diffusion) tuple.
    """
    model = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        channel_mult=channel_mult,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
        use_new_attention_order=use_new_attention_order,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return model, diffusion


def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    """
    Construct the (pseudo-3D) UNetModel for the given image size and
    hyper-parameters.

    :param channel_mult: comma-separated multipliers; "" selects a preset
        based on image_size.
    :param attention_resolutions: comma-separated feature-map resolutions at
        which attention is applied (converted to downsample factors below).
    :return: a UNetModel instance.
    """
    if channel_mult == "":
        # Presets per supported resolution.
        if image_size == 512:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 256:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 128:
            channel_mult = (1, 1, 2, 3, 4)
        elif image_size == 64:
            channel_mult = (1, 2, 3, 4)
        else:
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))

    # Convert resolutions to downsampling factors relative to image_size.
    attention_ds = []
    for res in attention_resolutions.split(","):
        attention_ds.append(image_size // int(res))

    # NOTE(review): in_channels is fixed at 3 and out_channels at 2 regardless
    # of learn_sigma (the upstream rule would be 3 or 6, see the comment) —
    # presumably intentional for this 2-channel inpainting task; confirm.
    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=2,  # (3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )


def create_classifier_and_diffusion(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
    learn_sigma,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
):
    """
    Build the classifier model and a matching diffusion process.

    :return: a (classifier, diffusion) tuple.
    """
    classifier = create_classifier(
        image_size,
        classifier_use_fp16,
        classifier_width,
        classifier_depth,
        classifier_attention_resolutions,
        classifier_use_scale_shift_norm,
        classifier_resblock_updown,
        classifier_pool,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return classifier, diffusion


def create_classifier(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
):
    """
    Construct the EncoderUNetModel classifier for the given image size.

    :return: an EncoderUNetModel instance.
    """
    if image_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif image_size == 128:
        channel_mult = (1, 1, 2, 3, 4)
    elif image_size == 64:
        channel_mult = (1, 2, 3, 4)
    else:
        raise ValueError(f"unsupported image size: {image_size}")

    attention_ds = []
    for res in classifier_attention_resolutions.split(","):
        attention_ds.append(image_size // int(res))

    # out_channels=2 matches NUM_CLASSES for this project (upstream used 1000).
    return EncoderUNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=classifier_width,
        out_channels=2,  # 1000,
        num_res_blocks=classifier_depth,
        attention_resolutions=tuple(attention_ds),
        channel_mult=channel_mult,
        use_fp16=classifier_use_fp16,
        num_head_channels=64,
        use_scale_shift_norm=classifier_use_scale_shift_norm,
        resblock_updown=classifier_resblock_updown,
        pool=classifier_pool,
    )


def sr_model_and_diffusion_defaults():
    """Super-resolution defaults: the image-training defaults restricted to
    sr_create_model_and_diffusion's signature, plus large/small sizes."""
    res = model_and_diffusion_defaults()
    res["large_size"] = 256
    res["small_size"] = 64
    # Drop keys the SR factory does not accept.
    arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
    for k in res.copy().keys():
        if k not in arg_names:
            del res[k]
    return res


def sr_create_model_and_diffusion(
    large_size,
    small_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
):
    """
    Build the super-resolution model and its diffusion process.

    :return: a (model, diffusion) tuple.
    """
    model = sr_create_model(
        large_size,
        small_size,
        num_channels,
        num_res_blocks,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return model, diffusion


def sr_create_model(
    large_size,
    small_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
    resblock_updown,
    use_fp16,
):
    """
    Construct the SuperResModel for the given large/small sizes.

    NOTE(review): SuperResModel is NOT imported in this module (the
    `from .unet import SuperResModel, ...` line is commented out above), so
    calling this function raises NameError — confirm whether SR support is
    intended here or this is dead code.
    """
    _ = small_size  # hack to prevent unused variable

    if large_size == 512:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif large_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif large_size == 64:
        channel_mult = (1, 2, 3, 4)
    else:
        raise ValueError(f"unsupported large size: {large_size}")

    attention_ds = []
    for res in attention_resolutions.split(","):
        attention_ds.append(large_size // int(res))

    return SuperResModel(
        image_size=large_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )


def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    """
    Build a SpacedDiffusion from flat hyper-parameters.

    :param timestep_respacing: respacing spec for space_timesteps; "" means
        use all `steps` timesteps.
    :return: a SpacedDiffusion instance.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE  # THIS (path taken with project defaults)
    if not timestep_respacing:
        timestep_respacing = [steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE  # THIS (learn_sigma defaults True)
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )
+ if v is None: + v_type = str + elif isinstance(v, bool): + v_type = str2bool + parser.add_argument(f"--{k}", default=v, type=v_type) + + +def args_to_dict(args, keys): + return {k: getattr(args, k) for k in keys} + + +def str2bool(v): + """ + https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("boolean value expected") + + + +############################ + + +def get_model_and_diffusion(): + res = model_and_diffusion_defaults() + model, diffusion = create_model_and_diffusion(**res) + return model, diffusion diff --git a/Trainer/models/guided_diffusion/train_util.py b/Trainer/models/guided_diffusion/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a742dd1d43ffc52e778623566cad2d2e10e676d9 --- /dev/null +++ b/Trainer/models/guided_diffusion/train_util.py @@ -0,0 +1,456 @@ +import copy +import functools +import os + +import math +import numpy as np +import nibabel as nib +import blobfile as bf +import torch +import torch.distributed as dist +from torch.nn.parallel.distributed import DistributedDataParallel as DDP +from torch.optim import AdamW + +import dist_util, logger +from fp16_util import MixedPrecisionTrainer +from nn import update_ema +from resample import LossAwareSampler, UniformSampler + + +# For ImageNet experiments, this was a good default value. +# We found that the lg_loss_scale quickly climbed to +# 20-21 within the first ~1K steps of training. 
INITIAL_LOG_LOSS_SCALE = 20.0


def visualize(img):
    """Min-max normalize `img` into [0, 1]."""
    _min = img.min()
    _max = img.max()
    normalized_img = (img - _min) / (_max - _min)
    return normalized_img


class TrainLoop:
    """
    Diffusion training loop over volumes of 2-D slices.

    Each dataloader item is one volume (batch, cond, path, slicedict);
    `run_loop` cuts it into chunks of up to 16 slices taken from `slicedict`
    and performs one forward/backward/optimizer step per chunk.

    Changes vs. the original:
    - BUGFIX: UniformSampler requires a `maxt` argument; the fallback call
      omitted it and raised TypeError.
    - BUGFIX: forward_backward referenced `losses` before assignment when a
      LossAwareSampler was used.
    - The near-identical volume-chunking code duplicated across run_loop's
      try/except branches is factored into _train_volume (unused `p_s`/`i`
      locals and debug prints were dropped).
    """

    def __init__(
        self,
        *,
        model,
        classifier,
        diffusion,
        data,
        dataloader,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        self.model = model
        self.dataloader = dataloader
        self.classifier = classifier
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        # Accept a single float or a comma-separated string of rates.
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        # BUGFIX: UniformSampler(diffusion) raised TypeError (missing maxt).
        self.schedule_sampler = schedule_sampler or UniformSampler(
            diffusion, diffusion.num_timesteps
        )
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.sync_cuda = torch.cuda.is_available()

        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if torch.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        """Load model weights from a resume checkpoint (if any) and broadcast
        them to all ranks."""
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    )
                )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Load EMA params for `rate` from the checkpoint dir, falling back to
        a copy of the current master params; synced across ranks."""
        ema_params = copy.deepcopy(self.mp_trainer.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore the AdamW state saved next to the resume checkpoint."""
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def run_loop(self):
        """Iterate the dataloader (restarting it at epoch end) until
        lr_anneal_steps is reached, logging/checkpointing periodically."""
        data_iter = iter(self.dataloader)
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            try:
                batch, cond, path, slicedict = next(data_iter)
            except StopIteration:
                # Dataset exhausted: reinitialize the loader for a new epoch.
                data_iter = iter(self.dataloader)
                batch, cond, path, slicedict = next(data_iter)

            self._train_volume(batch, cond, slicedict)

            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def _train_volume(self, batch, cond, slicedict):
        """
        Cut one volume into chunks of up to `batch_size_vol` slices (indices
        from `slicedict`, i.e. only slices within the to-inpaint mask) and run
        a training step per chunk.

        Shapes per original comments: batch (1, 2, w, h, d), cond
        (1, 1, w, h, d); each chunk becomes (chunk, C, w, h) after squeezing.
        """
        batch_size_vol = 16
        nr_batches = math.ceil(len(slicedict) / batch_size_vol)
        for b in range(nr_batches):
            chunk = slicedict[b * batch_size_vol : (b + 1) * batch_size_vol]
            out_batch = torch.stack([torch.tensor(batch[..., s]) for s in chunk])
            out_cond = torch.stack([torch.tensor(cond[..., s]) for s in chunk])
            # Drop the singleton loader-batch dim, then the trailing singleton
            # (same squeeze(1)/squeeze(4) sequence as the original).
            out_batch = out_batch.squeeze(1).squeeze(4)
            out_cond = out_cond.squeeze(1).squeeze(4)
            self.run_step(out_batch, out_cond)

    def run_step(self, batch, cond):
        """One optimizer step on a chunk: concatenate the conditioning channel
        onto the input, forward/backward, then EMA/lr/log upkeep.

        :return: the model sample from forward_backward.
        """
        batch = torch.cat((batch, cond), dim=1)  # (N, 2+1, w, h)
        cond = {}
        sample = self.forward_backward(batch, cond)
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()
        return sample

    def forward_backward(self, batch, cond):
        """Compute segmentation training losses for one chunk, backprop, and
        return the produced sample."""
        self.mp_trainer.zero_grad()

        micro = batch.to(dist_util.dev())
        t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

        compute_losses = functools.partial(
            self.diffusion.training_losses_segmentation,
            self.ddp_model,  # UNet
            self.classifier,
            micro,  # (N, 3, w, h)
            t,
        )
        losses1 = compute_losses()
        losses = losses1[0]
        sample = losses1[1]

        # BUGFIX: the original called update_with_local_losses with `losses`
        # before `losses` was assigned from losses1.
        if isinstance(self.schedule_sampler, LossAwareSampler):
            self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach())

        loss = (losses["loss"] * weights).mean()
        log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()})
        self.mp_trainer.backward(loss)
        return sample

    def _update_ema(self):
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer.master_params, rate=rate)

    def _anneal_lr(self):
        # Linear LR decay towards zero over lr_anneal_steps (no-op if 0).
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """Save the model, each EMA copy, and the optimizer state (rank 0)."""

        def save_checkpoint(rate, params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"savedmodel{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = (
                        f"emasavedmodel_{rate}_{(self.step+self.resume_step):06d}.pt"
                    )
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    torch.save(state_dict, f)

        save_checkpoint(0, self.mp_trainer.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(
                    get_blob_logdir(),
                    f"optsavedmodel{(self.step+self.resume_step):06d}.pt",
                ),
                "wb",
            ) as f:
                torch.save(self.opt.state_dict(), f)

        dist.barrier()


def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.
    """
    split = filename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0


def get_blob_logdir():
    # You can change this to be a separate path to save checkpoints to
    # a blobstore or some external drive.
    return logger.get_dir()
+ return None + + +def find_ema_checkpoint(main_checkpoint, step, rate): + if main_checkpoint is None: + return None + filename = f"ema_{rate}_{(step):06d}.pt" + path = bf.join(bf.dirname(main_checkpoint), filename) + if bf.exists(path): + return path + return None + + +def log_loss_dict(diffusion, ts, losses): + for key, values in losses.items(): + logger.logkv_mean(key, values.mean().item()) + # Log the quantiles (four quartiles, in particular). + for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): + quartile = int(4 * sub_t / diffusion.num_timesteps) + logger.logkv_mean(f"{key}_q{quartile}", sub_loss) + + + \ No newline at end of file diff --git a/Trainer/models/guided_diffusion/unet.py b/Trainer/models/guided_diffusion/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c66765ab2f7265b50fe929efaa2088946919dc --- /dev/null +++ b/Trainer/models/guided_diffusion/unet.py @@ -0,0 +1,904 @@ +from abc import abstractmethod +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from fp16_util import convert_module_to_f16, convert_module_to_f32 +from nn import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = 
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the optional conv (defaults to
        `channels`).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # For 3D signals only the two innermost dims are upsampled; the
            # depth dimension is left unchanged.
            depth, height, width = x.shape[2], x.shape[3], x.shape[4]
            x = F.interpolate(x, (depth, height * 2, width * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the optional conv (defaults to
        `channels`; pooling cannot change the channel count).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep their depth dimension; only H and W are halved.
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. 
Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. 
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + 
dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + layers.append( + 
AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb) + h = h.type(x.dtype) + return self.out(h) + + +class SuperResModel(UNetModel): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. + """ + + def __init__(self, image_size, in_channels, *args, **kwargs): + super().__init__(image_size, in_channels * 2, *args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") + x = torch.cat([x, upsampled], dim=1) + return super().forward(x, timesteps, **kwargs) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + + For usage, see UNet. 
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + 
self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + self.gap = nn.AvgPool2d((8, 8)) # global average pooling + self.cam_feature_maps = None + print("pool", pool) + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Linear(256, self.out_channels) + + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def 
convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + + if self.pool.startswith("spatial"): + self.cam_feature_maps = h + h = self.gap(h) + N = h.shape[0] + h = h.reshape(N, -1) + print("h1", h.shape) + return self.out(h) + else: + h = h.type(x.dtype) + self.cam_feature_maps = h + return self.out(h) diff --git a/Trainer/models/guided_diffusion/util.py b/Trainer/models/guided_diffusion/util.py new file mode 100644 index 0000000000000000000000000000000000000000..d378d9deabfacc79890902134da2019a333fbf90 --- /dev/null +++ b/Trainer/models/guided_diffusion/util.py @@ -0,0 +1,303 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """
    Build a diffusion beta schedule as a float64 numpy array.

    :param schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
    :param n_timestep: number of diffusion steps (length of the result).
    :param linear_start: first beta for the linear-family schedules.
    :param linear_end: last beta for the linear-family schedules.
    :param cosine_s: small offset used by the cosine schedule.
    :return: a numpy array of shape (n_timestep,).
    :raises ValueError: if `schedule` is not a known name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, as in the DDPM/LDM implementations.
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clip to avoid singular betas near t = T. np.clip on a torch
        # tensor falls back to numpy's array conversion and returns an
        # ndarray, not a tensor.
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    # Bug fix: the cosine branch yields an ndarray (via np.clip), which has
    # no .numpy(); the previous unconditional `betas.numpy()` therefore
    # raised AttributeError for schedule="cosine".
    if isinstance(betas, torch.Tensor):
        betas = betas.numpy()
    return betas
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the (sigma, alpha, alpha_prev) schedule used by a DDIM sampler.

    :param alphacums: 1-D array of cumulative alpha products over all DDPM steps.
    :param ddim_timesteps: integer indices of the DDPM steps the sampler visits.
    :param eta: DDIM eta; 0 gives deterministic sampling.
    :param verbose: if True, print the selected schedule.
    :return: (sigmas, alphas, alphas_prev) numpy arrays.
    """
    # Select the cumulative alphas at the sampled steps; each step's
    # "previous" alpha is the one selected before it, with alphacums[0]
    # standing in for the very first step.
    alphas = alphacums[ddim_timesteps]
    prev_values = [alphacums[0]]
    prev_values += alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray(prev_values)

    # Sigma formula from https://arxiv.org/abs/2010.02502
    variance_ratio = (1 - alphas_prev) / (1 - alphas)
    sigmas = eta * np.sqrt(variance_ratio * (1 - alphas / alphas_prev))
    if verbose:
        print(
            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
        )
        print(
            f"For the chosen value of eta, which is {eta}, "
            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
        )
    return sigmas, alphas, alphas_prev
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, tile the raw timestep `dim` times instead
                        of building sinusoids.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)
    half = dim // 2
    # Geometric frequency ladder running from 1 down to ~1/max_period.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dims get a zero pad so the output width is exactly `dim`.
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: 1, 2, or 3.
    :raises ValueError: for any other dimensionality.
    """
    conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    cls = conv_classes.get(dims)
    if cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return cls(*args, **kwargs)
def removekey(d, keys):
    """Return a shallow copy of dict ``d`` with every key in ``keys`` removed.

    Raises KeyError for absent keys, mirroring ``del``.
    """
    r = dict(d)
    for k in keys:
        del r[k]
    return r


class TaskHead(nn.Module):
    """
    Task-specific head that takes a list of sample features as inputs.

    ``out_channels`` maps output name -> channel count. A positive count adds a
    1x1 conv producing a dense per-voxel map; a non-positive count adds a
    pool+MLP branch regressing ``-count`` scalar values per sample (age
    prediction), which requires ``args.size``.
    """

    def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level=-1, exclude_keys=(), *_args):
        # FIX: ``exclude_keys`` default changed from the mutable ``[]`` to ``()``
        # (same behavior, no shared-default hazard). The trailing ``*_args``
        # only absorbs extra positional arguments (historically misnamed
        # ``*kwargs``); it is never used.
        super(TaskHead, self).__init__()
        self.out_feat_level = out_feat_level

        # Additional same-size-output 3x3 conv layers before final_conv,
        # used when len(f_maps_list) > 1.
        layers = []
        for i, in_feature_num in enumerate(f_maps_list[:-1]):
            layers.append(ConvBlock(in_feature_num, f_maps_list[i + 1], stride=1, is_3d=is_3d))
        self.layers = nn.ModuleList(layers)

        conv = nn.Conv3d if is_3d else nn.Conv2d
        fc = nn.Linear

        self.out_channels = removekey(out_channels, exclude_keys)
        self.out_names = self.out_channels.keys()
        for out_name, out_channels_num in self.out_channels.items():
            if out_channels_num > 0:
                self.add_module("final_conv_%s" % out_name, conv(f_maps_list[-1], out_channels_num, 1))
            else:  # single-value output (e.g. age prediction): -out_channels_num scalars
                # NOTE(review): this branch hard-codes MaxPool3d even when
                # is_3d=False, and a second scalar output would overwrite
                # self.pool_layers -- confirm scalar outputs are 3-D and unique.
                pool_layers = [nn.MaxPool3d(4, 4),  # (160 -> 40)
                               ConvBlock(f_maps_list[-1], 16, stride=1, is_3d=is_3d),
                               nn.MaxPool3d(4, 4),  # (40 -> 10)
                               ConvBlock(16, 4, stride=1, is_3d=is_3d)
                               ]
                self.pool_layers = nn.ModuleList(pool_layers)
                self.add_module("final_linear1_%s" % out_name, fc(4 * args.size[0] // 16 * args.size[1] // 16 * args.size[2] // 16, 160, 1))
                self.add_module("final_linear2_%s" % out_name, fc(160, 10, 1))
                self.add_module("final_linear3_%s" % out_name, fc(10, - out_channels_num, 1))

    def forward(self, x, *_args):
        """x: list of feature maps; uses x[self.out_feat_level]. Returns {name: tensor}."""
        x = x[self.out_feat_level]
        for layer in self.layers:
            x = layer(x)
        out = {}
        for (name, n_channels) in self.out_channels.items():
            if n_channels > 0:
                out[name] = getattr(self, f"final_conv_{name}")(x)
            else:
                # Scalar branch re-binds x; assumes at most one such output.
                for layer in self.pool_layers:
                    x = layer(x)
                x = torch.flatten(x, 1)  # flatten all dimensions except batch
                x = F.relu(getattr(self, f"final_linear1_{name}")(x))
                x = F.relu(getattr(self, f"final_linear2_{name}")(x))
                out[name] = torch.squeeze(getattr(self, f"final_linear3_{name}")(x), dim=1)
        return out


class DepHead(nn.Module):
    """
    Task-specific head for contrast-DEPENDENT tasks: the input image/contrast
    is concatenated onto the features (one extra channel) before the convs.
    """

    def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level=-1, *_args):
        super(DepHead, self).__init__()
        self.out_feat_level = out_feat_level

        # FIX: the original did ``f_maps_list[0] += 1``, mutating the caller's
        # list in place; build an adjusted copy instead (one extra channel for
        # the concatenated image/contrast).
        f_maps_list = [f_maps_list[0] + 1] + list(f_maps_list[1:])

        # Additional same-size-output 3x3 conv layers before final_conv.
        layers = []
        for i, in_feature_num in enumerate(f_maps_list[:-1]):
            layers.append(ConvBlock(in_feature_num, f_maps_list[i + 1], stride=1, is_3d=is_3d))
        self.layers = nn.ModuleList(layers)

        conv = nn.Conv3d if is_3d else nn.Conv2d
        self.out_names = out_channels.keys()
        for out_name, out_channels_num in out_channels.items():
            self.add_module("final_conv_%s" % out_name, conv(f_maps_list[-1], out_channels_num, 1))

    def forward(self, x, image):
        """x: list of feature maps; image: tensor concatenated on channel dim 1."""
        x = x[self.out_feat_level]
        x = torch.cat([x, image], dim=1)
        for layer in self.layers:
            x = layer(x)
        out = {}
        for name in self.out_names:
            out[name] = getattr(self, f"final_conv_{name}")(x)
        return out


class MultiInputDepHead(DepHead):
    """DepHead applied independently to each (features, image) pair of a list."""

    def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level=-1, *_args):
        super(MultiInputDepHead, self).__init__(args, f_maps_list, out_channels, is_3d, out_feat_level)

    def forward(self, feat_list, image_list):
        outs = []
        for i, x in enumerate(feat_list):
            x = x[self.out_feat_level]
            x = torch.cat([x, image_list[i]], dim=1)
            for layer in self.layers:
                x = layer(x)
            out = {}
            for name in self.out_names:
                out[name] = getattr(self, f"final_conv_{name}")(x)
            outs.append(out)
        return outs


class MultiInputTaskHead(TaskHead):
    """
    TaskHead applied independently to each feature list (contrast-independent
    tasks). Only dense (positive-channel) outputs are produced here.
    """

    def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level=-1, *_args):
        super(MultiInputTaskHead, self).__init__(args, f_maps_list, out_channels, is_3d, out_feat_level)

    def forward(self, feat_list, *_args):
        outs = []
        for x in feat_list:
            x = x[self.out_feat_level]
            for layer in self.layers:
                x = layer(x)
            out = {}
            for name in self.out_names:
                out[name] = getattr(self, f"final_conv_{name}")(x)
            outs.append(out)
        return outs


class ConvBlock(nn.Module):
    """
    Same-size-output 3x3 convolutional block followed by LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels, stride=1, is_3d=True):
        super().__init__()

        conv = nn.Conv3d if is_3d else nn.Conv2d
        self.main = conv(in_channels, out_channels, 3, stride, 1)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.main(x)
        out = self.activation(out)
        return out


################################


def get_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level, stage=0, exclude_keys=()):
    """Build the task head matching ``train_args.backbone`` (sep / two-stage / plain)."""
    if 'sep' in train_args.backbone:  # separate decoder and head for anomaly/pathology segmentation
        return get_sep_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level)
    elif '+' in train_args.backbone:  # two-stage network for inpainting
        if stage == 0:
            return TaskHead(train_args, f_maps_list, {'pathology': 1}, is_3d, out_feat_level)
        else:
            return TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, exclude_keys=['pathology'])
    return TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, exclude_keys)


def get_sep_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level):
    """Two TaskHeads: one for all non-pathology outputs, one for pathology only.

    NOTE(review): assumes 'pathology' is a key of ``out_channels`` (otherwise
    removekey raises KeyError).
    """
    head_normal = TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, ['pathology'])
    head_pathol = TaskHead(train_args, f_maps_list, {'pathology': 1}, is_3d, out_feat_level)
    return {'normal': head_normal, 'pathology': head_pathol}
def build_supersynth_model(device = 'cpu'):
    """Build the SuperSynth segmentation joiner (39 classes) and its softmax processor."""
    # 33 + 4 + 1 + 1 = 39 (SuperSynth)
    backbone = UNet3D(1, f_maps=64, layer_order='gcl', num_groups=8, num_levels=5, is3d=True)
    head = TaskHead(None, f_maps_list = [64], out_channels ={'segmentation': 39}, is_3d = True, out_feat_level = -1)
    # BUG FIX: get_joiner(task, backbone, head, device, ...) requires `device`;
    # the original call omitted it and raised TypeError.
    model = get_joiner('segmentation', backbone, head, device)
    processor = SegProcessor().to(device)
    model.to(device)
    return model, processor


def build_pathol_model(device = 'cpu'):
    """Build feature extractor + pathology head + sigmoid processor for pathology seg."""
    backbone = UNet3D(1, f_maps=64, layer_order='gcl', num_groups=8, num_levels=5, is3d=True)
    # BUG FIX: pass `device` (required positional arg of get_joiner).
    # head=None: this joiner only extracts features.
    feat_model = get_joiner('segmentation', backbone, None, device)
    task_model = MultiInputTaskHead(None, [64], {'pathology': 1}, True, -1)
    processor = PatholProcessor().to(device)
    feat_model.to(device)
    task_model.to(device)
    return feat_model, task_model, processor
class UncertaintyProcessor(nn.Module):
    """Splits 2-channel 'image' outputs into mean (channel 0) and sigma (channel 1)."""

    def __init__(self, output_names):
        super(UncertaintyProcessor, self).__init__()
        self.output_names = output_names

    def forward(self, outputs, *_args):
        for output_name in self.output_names:
            if 'image' in output_name:
                for output in outputs:
                    output[output_name + '_sigma'] = output[output_name][:, 1][:, None]
                    output[output_name] = output[output_name][:, 0][:, None]
        return outputs


class AgeProcessor(nn.Module):
    """Maps the raw 'age' output onto the non-negative range via abs()."""

    def __init__(self):
        super(AgeProcessor, self).__init__()

    def forward(self, outputs, *_args):
        for output in outputs:
            #output['age'] = output['age'] ** 2
            output['age'] = abs(output['age'])
        return outputs


class SegProcessor(nn.Module):
    """Applies channel-wise softmax to 'segmentation' outputs."""

    def __init__(self):
        super(SegProcessor, self).__init__()
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, outputs, *_args):
        for output in outputs:
            output['segmentation'] = self.softmax(output['segmentation'])
        return outputs


class PatholProcessor(nn.Module):
    """Applies sigmoid to 'pathology' outputs."""

    def __init__(self):
        super(PatholProcessor, self).__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, outputs, *_args):
        for output in outputs:
            output['pathology'] = self.sigmoid(output['pathology'])
        return outputs


class PatholSeg(nn.Module):
    """
    Frozen, pretrained pathology-segmentation models used as implicit critics:
    adds 'implicit_pathol_pred/orig' (and aux variants) to each output dict.
    """

    def __init__(self, args):
        super(PatholSeg, self).__init__()
        self.sigmoid = nn.Sigmoid()

        paths = args.supervised_pathol_seg_ckp_path
        self.feat_model, self.task_model, self.processor = build_pathol_model()
        load_checkpoint(paths.feat, [self.feat_model], model_keys = ['model'], to_print = False)
        load_checkpoint(paths.task, [self.task_model], model_keys = ['model'], to_print = False)
        # Crucial!!!! We backprop through these models, but their weights must not change.
        for param in self.feat_model.parameters():
            param.requires_grad = False
        for param in self.task_model.parameters():
            param.requires_grad = False

        aux_paths = args.supervised_aux_pathol_seg_ckp_path
        if args.aux_modality is not None:
            self.aux_feat_model, self.aux_task_model, self.aux_processor = build_pathol_model()
            load_checkpoint(aux_paths.feat, [self.aux_feat_model], model_keys = ['model'], to_print = False)
            load_checkpoint(aux_paths.task, [self.aux_task_model], model_keys = ['model'], to_print = False)
            for param in self.aux_feat_model.parameters():
                param.requires_grad = False
            for param in self.aux_task_model.parameters():
                param.requires_grad = False
        else:
            self.aux_feat_model, self.aux_task_model, self.aux_processor = None, None, None

    def forward(self, outputs, target, curr_dataset, *_args):
        for output in outputs:
            if output['image'].shape == target['image'].shape:
                samples = [ { 'input': output['image'] }, { 'input': target['image'] } ]
                feats, inputs = self.feat_model(samples)
                preds = self.task_model([feat['feat'] for feat in feats], inputs)
                preds = self.processor(preds, samples)
                output['implicit_pathol_pred'] = preds[0]['pathology']
                output['implicit_pathol_orig'] = preds[1]['pathology']
            if self.aux_feat_model is not None:
                if output['aux_image'].shape == target['aux_image'].shape:
                    samples = [ { 'input': output['aux_image'] }, { 'input': target['aux_image'] } ]
                    feats, inputs = self.aux_feat_model(samples)
                    preds = self.aux_task_model([feat['feat'] for feat in feats], inputs)
                    # NOTE(review): uses self.processor rather than
                    # self.aux_processor; both are PatholProcessor instances so
                    # behavior matches, but confirm this was intended.
                    preds = self.processor(preds, samples)
                    output['implicit_aux_pathol_pred'] = preds[0]['pathology']
                    output['implicit_aux_pathol_orig'] = preds[1]['pathology']
        return outputs


class ContrastiveProcessor(nn.Module):
    """L2-normalizes the last feature map (SimCLR-style projection).

    Ref: https://openreview.net/forum?id=2oCb0q5TA4Y
    """

    def __init__(self):
        super(ContrastiveProcessor, self).__init__()
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, outputs, *_args):
        for output in outputs:
            output['feat'][-1] = F.normalize(output['feat'][-1], dim = 1)
        return outputs


class DistProcessor(nn.Module):
    """Clamps 'distance' outputs to [-max_surf_distance, +max_surf_distance]."""

    def __init__(self, gen_args):
        super(DistProcessor, self).__init__()
        self.gen_args = gen_args

    def forward(self, outputs, *_args):
        for output in outputs:
            output['distance'] = torch.clamp(output['distance'], min = - self.gen_args.max_surf_distance, max = self.gen_args.max_surf_distance)
        return outputs


##############################################################################


class MultiInputIndepJoiner(nn.Module):
    """
    Perform forward pass separately on each augmented input.
    """

    def __init__(self, backbone, head, device, postfix = ''):
        super(MultiInputIndepJoiner, self).__init__()

        self.backbone = backbone.to(device)
        # BUG FIX: head may legitimately be None (feature-extractor-only joiner,
        # see build_pathol_model); the original unconditionally called
        # head.to(device) and raised AttributeError.
        self.head = head.to(device) if head is not None else None
        self.postfix = postfix

    def forward(self, input_list, input_name = 'input', cond = ()):
        # FIX: `cond` default changed from mutable [] to () (never mutated here).
        outs = []
        for i, x in enumerate(input_list):
            if len(cond) > 0:
                feat = self.backbone.get_feature(torch.concat([x[input_name], cond[i]], dim = 1))
            else:
                feat = self.backbone.get_feature(x[input_name])
            out = {'feat' + self.postfix: feat}
            if self.head is not None:
                out.update( self.head(feat) )
            outs.append(out)
        return outs, [inp[input_name] for inp in input_list]


class MultiInputSepDecIndepJoiner(nn.Module):
    """
    Perform forward pass separately on each augmented input, with one head per
    feature key. NOTE: keys in head_dict must equal the backbone's feat_dict.

    NOTE(review): head_dict is kept as a plain dict, so its heads are NOT
    registered as submodules (missing from parameters()/state_dict());
    confirm they are optimized/saved elsewhere.
    """

    def __init__(self, backbone, head_dict, device):
        super(MultiInputSepDecIndepJoiner, self).__init__()

        self.backbone = backbone.to(device)
        self.head_dict = head_dict
        for k in self.head_dict.keys():
            self.head_dict[k].to(device)

    def forward(self, input_list):
        outs = []
        for x in input_list:
            feat_dict = self.backbone.get_feature(x['input'])
            out = {'feat_%s' % k: feat_dict[k] for k in feat_dict.keys()}
            for k in feat_dict.keys():
                if self.head_dict is not None:
                    out.update( self.head_dict[k](feat_dict[k]) )
            outs.append(out)
        return outs, [inp['input'] for inp in input_list]


class MultiInputDepJoiner(nn.Module):
    """
    Perform forward pass separately on each augmented input; the head also
    receives the raw sample dict (contrast-dependent heads).
    """

    def __init__(self, backbone, head, device):
        super(MultiInputDepJoiner, self).__init__()

        self.backbone = backbone.to(device)
        # Same None-guard as MultiInputIndepJoiner (forward already checks).
        self.head = head.to(device) if head is not None else None

    def forward(self, input_list):
        outs = []
        for x in input_list:
            feat = self.backbone.get_feature(x['input'])
            out = {'feat': feat}
            if self.head is not None:
                out.update( self.head( feat, x) )
            outs.append(out)
        return outs, [inp['input'] for inp in input_list]


################################


def get_processors(gen_args, train_args, tasks, device, exclude_keys = ()):
    """Assemble the ordered list of output post-processors for the active tasks."""
    # FIX: exclude_keys default changed from mutable [] to () (only membership-tested).
    processors = []
    if train_args.losses.uncertainty is not None:
        processors.append(UncertaintyProcessor(train_args.output_names).to(device))
    if train_args.losses.implicit_pathol:
        processors.append(PatholSeg(train_args).to(device))

    if 'contrastive' in tasks:
        processors.append(ContrastiveProcessor().to(device))
    if 'age' in tasks:
        processors.append(AgeProcessor().to(device))
    if 'segmentation' in tasks and 'segmentation' not in exclude_keys:
        processors.append(SegProcessor().to(device))
    if 'distance' in tasks:
        processors.append(DistProcessor(gen_args).to(device))
    if 'pathology' in tasks and 'pathology' not in exclude_keys:
        processors.append(PatholProcessor().to(device))

    return processors
def get_joiner(task, backbone, head, device, postfix = ''):
    """Return a joiner wrapping backbone+head; a dict head selects the separate-decoder joiner."""
    if isinstance(head, dict):
        return get_sep_joiner(task, backbone, head, device)

    return MultiInputIndepJoiner(backbone, head, device, postfix = postfix)


def get_sep_joiner(task, backbone, head_dict, device):
    """Joiner with one head per feature key (separate decoders)."""
    return MultiInputSepDecIndepJoiner(backbone, head_dict, device)


def l1_loss(outputs, targets, weights = 1.):
    """Weighted mean absolute error."""
    return torch.mean(abs(outputs - targets) * weights)


def l2_loss(outputs, targets, weights = 1.):
    """Weighted mean squared error."""
    return torch.mean((outputs - targets)**2 * weights)


def gaussian_loss(outputs_mu, outputs_sigma, targets, weights = 1.):
    """Weighted Gaussian negative log-likelihood; outputs_sigma is log-variance."""
    variance = torch.exp(outputs_sigma)
    minusloglhood = 0.5 * torch.log(2 * torch.pi * variance) + 0.5 * ((targets - outputs_mu) ** 2) / variance
    return torch.mean(minusloglhood * weights)


def laplace_loss(outputs_mu, outputs_sigma, targets, weights = 1.):
    """Weighted Laplace negative log-likelihood; outputs_sigma is log-scale."""
    b = torch.exp(outputs_sigma)
    minusloglhood = torch.log(2 * b) + torch.abs(targets - outputs_mu) / b
    # BUG FIX: the original returned torch.mean(minusloglhood, weights), which
    # passes `weights` as torch.mean's `dim` argument. Apply it as a
    # multiplicative weight like the other losses.
    return torch.mean(minusloglhood * weights)


class GradientLoss(nn.Module):
    """Compares forward-difference spatial gradients of input and target volumes."""

    def __init__(self, mode = 'l1', mask = False):
        super(GradientLoss, self).__init__()
        self.mask = mask
        if mode == 'l1':
            self.loss_func = l1_loss
        elif mode == 'l2':
            self.loss_func = l2_loss
        else:
            raise ValueError('Not supported loss_func for GradientLoss:', mode)

    def gradient(self, x):
        # x: (b, c, s, r, c) --> dx, dy, dz: (b, c, s, r, c)
        # Forward differences along the three trailing axes; the final slice of
        # each axis has no forward neighbor and is explicitly zeroed.
        back = F.pad(x, [0, 1, 0, 0, 0, 0])[:, :, :, :, 1:]
        right = F.pad(x, [0, 0, 0, 1, 0, 0])[:, :, :, 1:, :]
        bottom = F.pad(x, [0, 0, 0, 0, 0, 1])[:, :, 1:, :, :]

        dx, dy, dz = back - x, right - x, bottom - x
        dx[:, :, :, :, -1] = 0
        dy[:, :, :, -1] = 0
        dz[:, :, -1] = 0
        return dx, dy, dz

    def forward_archive(self, input, target):
        # Older variant: masks input gradients where the target gradient is zero.
        dx_i, dy_i, dz_i = self.gradient(input)
        dx_t, dy_t, dz_t = self.gradient(target)
        if self.mask:
            dx_i[dx_t == 0.] = 0.
            dy_i[dy_t == 0.] = 0.
            dz_i[dz_t == 0.] = 0.
        return (self.loss_func(dx_i, dx_t) + self.loss_func(dy_i, dy_t) + self.loss_func(dz_i, dz_t)).mean()

    def forward(self, input, target, weights = 1.):
        dx_i, dy_i, dz_i = self.gradient(input)
        dx_t, dy_t, dz_t = self.gradient(target)
        if self.mask:
            # NOTE(review): the masked branch ignores self.loss_func and
            # `weights`, returning a plain L1 of the masked differences --
            # confirm this asymmetry is intended.
            diff_dx = abs(dx_i - dx_t)
            diff_dy = abs(dy_i - dy_t)
            diff_dz = abs(dz_i - dz_t)
            diff_dx[target == 0.] = 0.
            diff_dy[target == 0.] = 0.
            diff_dz[target == 0.] = 0.
            return (diff_dx + diff_dy + diff_dz).mean()
        return (self.loss_func(dx_i, dx_t, weights) + self.loss_func(dy_i, dy_t, weights) + self.loss_func(dz_i, dz_t, weights)).mean()
class SmoothnessLoss(nn.Module):
    """Penalizes spatial variation: mean of |grad| (l1) or grad^2 (l2) over a 5-D volume."""

    def __init__(self, mode = 'l2'):
        super(SmoothnessLoss, self).__init__()
        self.mode = mode
        if mode == 'l1':
            self.loss_func = l1_loss
        elif mode == 'l2':
            self.loss_func = l2_loss
        else:
            raise ValueError('Not supported loss_func for SmoothnessLoss:', mode)

    def gradient(self, x):
        # x: (b, c, s, r, c) --> dx, dy, dz: (b, c, s, r, c)
        # Shift each trailing spatial axis by one (zero-padded), subtract to get
        # forward differences, then zero the last slice of each axis (no neighbor).
        shifted_last = F.pad(x, [0, 1, 0, 0, 0, 0])[:, :, :, :, 1:]
        shifted_mid = F.pad(x, [0, 0, 0, 1, 0, 0])[:, :, :, 1:, :]
        shifted_first = F.pad(x, [0, 0, 0, 0, 0, 1])[:, :, 1:, :, :]

        dx = shifted_last - x
        dy = shifted_mid - x
        dz = shifted_first - x
        dx[:, :, :, :, -1] = 0
        dy[:, :, :, -1] = 0
        dz[:, :, -1] = 0
        return dx, dy, dz

    def forward(self, input):
        grads = self.gradient(input)
        if self.mode == 'l1':
            return sum(g.abs() for g in grads).mean()
        if self.mode == 'l2':
            return sum(g * g for g in grads).mean()
        raise NotImplementedError('Not supported loss mode for SmoothnessLoss:', self.mode)
class HessianLoss(nn.Module):
    """Penalizes the determinant of the discrete Hessian of a 5-D volume (l1 or l2 of det)."""

    def __init__(self, mode = 'l2'):
        super(HessianLoss, self).__init__()
        self.mode = mode
        if mode == 'l1':
            self.loss_func = l1_loss
        elif mode == 'l2':
            self.loss_func = l2_loss
        else:
            # NOTE(review): message says "SmoothnessLoss" -- copy-paste from that class.
            raise ValueError('Not supported loss_func for SmoothnessLoss:', mode)

    def gradient(self, x): # gradient_c
        # x: (b, c, s, r, c) --> dx, dy, dz: (b, c, s, r, c)
        # Forward differences along the three trailing axes; the last slice of
        # each axis is zeroed since it has no forward neighbor.
        back = F.pad(x, [0, 1, 0, 0, 0, 0])[:, :, :, :, 1:]
        right = F.pad(x, [0, 0, 0, 1, 0, 0])[:, :, :, 1:, :]
        bottom = F.pad(x, [0, 0, 0, 0, 0, 1])[:, :, 1:, :, :]

        dx, dy, dz = back - x, right - x, bottom - x
        dx[:, :, :, :, -1] = 0
        dy[:, :, :, -1] = 0
        dz[:, :, -1] = 0
        return dx, dy, dz

    def forward(self, input): # det of the Hessian
        dx, dy, dz = self.gradient(input)
        # NOTE(review): ddxy/ddxz/ddyz from the later gradient() calls overwrite
        # the earlier ones -- relies on (approximate) symmetry of the discrete
        # mixed partials; confirm intended.
        ddxx, ddxy, ddxz = self.gradient(dx)
        ddxy, ddyy, ddyz = self.gradient(dy)
        ddxz, ddyz, ddzz = self.gradient(dz)
        det_hessian = ddxx * (ddyy * ddzz - ddyz ** 2) - ddxy * (ddxy * ddzz - ddxz * ddyz) + ddxz * (ddxy * ddyz - ddxz * ddyy)
        if self.mode == 'l1':
            # NOTE: sum (not mean), unlike the other losses.
            return abs(det_hessian).sum()
        elif self.mode == 'l2':
            return (det_hessian ** 2).sum()
        else:
            raise NotImplementedError('Not supported loss mode for SmoothnessLoss:', self.mode)


def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, is3d):
    """
    Create a list of modules with together constitute a single conv layer with non-linearity
    and optional batchnorm/groupnorm.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size(int or tuple): size of the convolving kernel
        order (string): order of things, e.g.
            'cr' -> conv + ReLU
            'gcr' -> groupnorm + conv + ReLU
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
            'bcr' -> batchnorm + conv + ReLU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): add zero-padding added to all three sides of the input
        is3d (bool): is3d (bool): if True use Conv3d, otherwise use Conv2d
    Return:
        list of tuple (name, module)
    """
    assert 'c' in order, "Conv layer MUST be present"
    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'

    modules = []
    for i, char in enumerate(order):
        if char == 'r':
            modules.append(('ReLU', nn.ReLU(inplace=True)))
        elif char == 'l':
            modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
        elif char == 'e':
            modules.append(('ELU', nn.ELU(inplace=True)))
        elif char == 'c':
            # add learnable bias only in the absence of batchnorm/groupnorm
            bias = not ('g' in order or 'b' in order)
            if is3d:
                conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
            else:
                conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)

            modules.append(('conv', conv))
        elif char == 'g':
            # norm placed before the conv normalizes the input channels,
            # after the conv it normalizes the output channels
            is_before_conv = i < order.index('c')
            if is_before_conv:
                num_channels = in_channels
            else:
                num_channels = out_channels

            # use only one group if the given number of groups is greater than the number of channels
            if num_channels < num_groups:
                num_groups = 1

            assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
            modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
        elif char == 'b':
            is_before_conv = i < order.index('c')
            if is3d:
                bn = nn.BatchNorm3d
            else:
                bn = nn.BatchNorm2d

            if is_before_conv:
                modules.append(('batchnorm', bn(in_channels)))
            else:
                modules.append(('batchnorm', bn(out_channels)))
        else:
            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")

    return modules


class SingleConv(nn.Sequential):
    """
    Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
    of operations can be specified via the `order` parameter

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): add zero-padding
        is3d (bool): if True use Conv3d, otherwise use Conv2d
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, padding=1, is3d=True):
        super(SingleConv, self).__init__()

        # register each (name, module) pair in the requested order
        for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, is3d):
            self.add_module(name, module)


class DoubleConv(nn.Sequential):
    """
    A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
    We use (Conv3d+ReLU+GroupNorm3d) by default.
    This can be changed however by providing the 'order' argument, e.g. in order
    to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
    Use padded convolutions to make sure that the output (H_out, W_out) is the same
    as (H_in, W_in), so that you don't have to crop in the decoder path.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): add zero-padding added to all three sides of the input
        is3d (bool): if True use Conv3d instead of Conv2d layers
    """

    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', num_groups=8, padding=1,
                 is3d=True):
        super(DoubleConv, self).__init__()
        if encoder:
            # we're in the encoder path
            conv1_in_channels = in_channels
            conv1_out_channels = out_channels // 2
            # never shrink below the input width in the first conv
            if conv1_out_channels < in_channels:
                conv1_out_channels = in_channels
            conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
        else:
            # we're in the decoder path, decrease the number of channels in the 1st convolution
            conv1_in_channels, conv1_out_channels = in_channels, out_channels
            conv2_in_channels, conv2_out_channels = out_channels, out_channels

        # conv1
        self.add_module('SingleConv1',
                        SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
                                   padding=padding, is3d=is3d))
        # conv2
        self.add_module('SingleConv2',
                        SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
                                   padding=padding, is3d=is3d))
class Encoder(nn.Module):
    """
    A single module from the encoder path consisting of the optional max
    pooling layer (one may specify the MaxPool kernel_size to be different
    from the standard (2,2,2), e.g. if the volumetric data is anisotropic
    (make sure to use complementary scale_factor in the decoder path) followed by
    a basic module (DoubleConv or ResNetBlock).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        apply_pooling (bool): if True use MaxPool3d before DoubleConv
        pool_kernel_size (int or tuple): the size of the window
        pool_type (str): pooling layer: 'max' or 'avg'
        basic_module(nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): add zero-padding added to all three sides of the input
        is3d (bool): use 3d or 2d convolutions/pooling operation
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                 pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
                 num_groups=8, padding=1, is3d=True):
        super(Encoder, self).__init__()
        assert pool_type in ['max', 'avg']
        # select the pooling op by type and dimensionality, or skip pooling entirely
        if apply_pooling:
            if pool_type == 'max':
                if is3d:
                    self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
                else:
                    self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size)
            else:
                if is3d:
                    self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
                else:
                    self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size)
        else:
            self.pooling = None

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=True,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding,
                                         is3d=is3d)

    def forward(self, x):
        # pool first (if configured), then apply the conv block
        if self.pooling is not None:
            x = self.pooling(x)
        x = self.basic_module(x)
        return x


class Decoder(nn.Module):
    """
    A single module for decoder path consisting of the upsampling layer
    (either learned ConvTranspose3d or nearest neighbor interpolation)
    followed by a basic module (DoubleConv or ResNetBlock).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        scale_factor (tuple): used as the multiplier for the image H/W/D in
            case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
            from the corresponding encoder
        basic_module(nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): add zero-padding added to all three sides of the input
        upsample (bool): should the input be upsampled
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
                 conv_layer_order='gcr', num_groups=8, mode='nearest', padding=1, upsample=True, is3d=True):
        super(Decoder, self).__init__()

        if upsample:
            if basic_module == DoubleConv:
                # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
                self.upsampling = InterpolateUpsampling(mode=mode)
                # concat joining
                self.joining = partial(self._joining, concat=True)
            else:
                # if basic_module=ResNetBlock use transposed convolution upsampling and summation joining
                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor)
                # sum joining
                self.joining = partial(self._joining, concat=False)
                # adapt the number of in_channels for the ResNetBlock
                in_channels = out_channels
        else:
            # no upsampling
            self.upsampling = NoUpsampling()
            # concat joining
            self.joining = partial(self._joining, concat=True)

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=False,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding,
                                         is3d=is3d)

    def forward(self, encoder_features, x):
        # upsample x to the encoder feature size, join (concat or sum), then convolve
        x = self.upsampling(encoder_features=encoder_features, x=x)
        x = self.joining(encoder_features, x)
        x = self.basic_module(x)
        return x

    @staticmethod
    def _joining(encoder_features, x, concat):
        if concat:
            return torch.cat((encoder_features, x), dim=1)
        else:
            return encoder_features + x


def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups,
                    pool_kernel_size, is3d):
    # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
    encoders = []
    for i, out_feature_num in enumerate(f_maps):
        if i == 0:
            # apply conv_coord only in the first encoder if any
            encoder = Encoder(in_channels, out_feature_num,
                              apply_pooling=False,  # skip pooling in the firs encoder
                              basic_module=basic_module,
                              conv_layer_order=layer_order,
                              conv_kernel_size=conv_kernel_size,
                              num_groups=num_groups,
                              padding=conv_padding,
                              is3d=is3d)
        else:
            encoder = Encoder(f_maps[i - 1], out_feature_num,
                              basic_module=basic_module,
                              conv_layer_order=layer_order,
                              conv_kernel_size=conv_kernel_size,
                              num_groups=num_groups,
                              pool_kernel_size=pool_kernel_size,
                              padding=conv_padding,
                              is3d=is3d)

        encoders.append(encoder)

    return nn.ModuleList(encoders)


def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, is3d):
    # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
    decoders = []
    reversed_f_maps = list(reversed(f_maps))
    for i in range(len(reversed_f_maps) - 1):
        if basic_module == DoubleConv:
            # DoubleConv decoders concat-join, so inputs are skip + upsampled channels
            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
        else:
            # sum-joining modules keep the channel count
            in_feature_num = reversed_f_maps[i]

        out_feature_num = reversed_f_maps[i + 1]

        decoder = Decoder(in_feature_num, out_feature_num,
                          basic_module=basic_module,
                          conv_layer_order=layer_order,
                          conv_kernel_size=conv_kernel_size,
                          num_groups=num_groups,
                          padding=conv_padding,
                          is3d=is3d)
        decoders.append(decoder)
    return nn.ModuleList(decoders)
The length of the decoder list is equal to `len(f_maps) - 1` + decoders = [] + reversed_f_maps = list(reversed(f_maps)) + for i in range(len(reversed_f_maps) - 1): + if basic_module == DoubleConv: + in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] + else: + in_feature_num = reversed_f_maps[i] + + out_feature_num = reversed_f_maps[i + 1] + + decoder = Decoder(in_feature_num, out_feature_num, + basic_module=basic_module, + conv_layer_order=layer_order, + conv_kernel_size=conv_kernel_size, + num_groups=num_groups, + padding=conv_padding, + is3d=is3d) + decoders.append(decoder) + return nn.ModuleList(decoders) + + +class AbstractUpsampling(nn.Module): + """ + Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either + interpolation or learned transposed convolution. + """ + + def __init__(self, upsample): + super(AbstractUpsampling, self).__init__() + self.upsample = upsample + + def forward(self, encoder_features, x): + # get the spatial dimensions of the output given the encoder_features + output_size = encoder_features.size()[2:] + # upsample the input and return + return self.upsample(x, output_size) + + +class InterpolateUpsampling(AbstractUpsampling): + """ + Args: + mode (str): algorithm used for upsampling: + 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. 
Default: 'nearest' + used only if transposed_conv is False + """ + + def __init__(self, mode='nearest'): + upsample = partial(self._interpolate, mode=mode) + super().__init__(upsample) + + @staticmethod + def _interpolate(x, size, mode): + return F.interpolate(x, size=size, mode=mode) + + +class TransposeConvUpsampling(AbstractUpsampling): + """ + Args: + in_channels (int): number of input channels for transposed conv + used only if transposed_conv is True + out_channels (int): number of output channels for transpose conv + used only if transposed_conv is True + kernel_size (int or tuple): size of the convolving kernel + used only if transposed_conv is True + scale_factor (int or tuple): stride of the convolution + used only if transposed_conv is True + + """ + + def __init__(self, in_channels=None, out_channels=None, kernel_size=3, scale_factor=(2, 2, 2)): + # make sure that the output size reverses the MaxPool3d from the corresponding encoder + upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, + padding=1) + super().__init__(upsample) + + +class NoUpsampling(AbstractUpsampling): + def __init__(self): + super().__init__(self._no_upsampling) + + @staticmethod + def _no_upsampling(x, size): + return x diff --git a/Trainer/models/unet3d/losses.py b/Trainer/models/unet3d/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..99a094fd807c482b25f60f1332b1d0ffd8bf37c6 --- /dev/null +++ b/Trainer/models/unet3d/losses.py @@ -0,0 +1,347 @@ +import torch +import torch.nn.functional as F +from torch import nn as nn +from torch.autograd import Variable +from torch.nn import MSELoss, SmoothL1Loss, L1Loss + +from .utils import expand_as_one_hot + + +def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None): + """ + Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. + Assumes the input is a normalized probability, e.g. 
a result of Sigmoid or Softmax function. + + Args: + input (torch.Tensor): NxCxSpatial input tensor + target (torch.Tensor): NxCxSpatial target tensor + epsilon (float): prevents division by zero + weight (torch.Tensor): Cx1 tensor of weight per channel/class + """ + + # input and target shapes must match + assert input.size() == target.size(), "'input' and 'target' must have the same shape" + + input = flatten(input) + target = flatten(target) + target = target.float() + + # compute per channel Dice Coefficient + intersect = (input * target).sum(-1) + if weight is not None: + intersect = weight * intersect + + # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1) + denominator = (input * input).sum(-1) + (target * target).sum(-1) + return 2 * (intersect / denominator.clamp(min=epsilon)) + + +class _MaskingLossWrapper(nn.Module): + """ + Loss wrapper which prevents the gradient of the loss to be computed where target is equal to `ignore_index`. 
+ """ + + def __init__(self, loss, ignore_index): + super(_MaskingLossWrapper, self).__init__() + assert ignore_index is not None, 'ignore_index cannot be None' + self.loss = loss + self.ignore_index = ignore_index + + def forward(self, input, target): + mask = target.clone().ne_(self.ignore_index) + mask.requires_grad = False + + # mask out input/target so that the gradient is zero where on the mask + input = input * mask + target = target * mask + + # forward masked input and target to the loss + return self.loss(input, target) + + +class SkipLastTargetChannelWrapper(nn.Module): + """ + Loss wrapper which removes additional target channel + """ + + def __init__(self, loss, squeeze_channel=False): + super(SkipLastTargetChannelWrapper, self).__init__() + self.loss = loss + self.squeeze_channel = squeeze_channel + + def forward(self, input, target): + assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel' + + # skips last target channel if needed + target = target[:, :-1, ...] + + if self.squeeze_channel: + # squeeze channel dimension if singleton + target = torch.squeeze(target, dim=1) + return self.loss(input, target) + + +class _AbstractDiceLoss(nn.Module): + """ + Base class for different implementations of Dice loss. + """ + + def __init__(self, weight=None, normalization='sigmoid'): + super(_AbstractDiceLoss, self).__init__() + self.register_buffer('weight', weight) + # The output from the network during training is assumed to be un-normalized probabilities and we would + # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data, + # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems. 
+ # However if one would like to apply Softmax in order to get the proper probability distribution from the + # output, just specify `normalization=Softmax` + assert normalization in ['sigmoid', 'softmax', 'none'] + if normalization == 'sigmoid': + self.normalization = nn.Sigmoid() + elif normalization == 'softmax': + self.normalization = nn.Softmax(dim=1) + else: + self.normalization = lambda x: x + + def dice(self, input, target, weight): + # actual Dice score computation; to be implemented by the subclass + raise NotImplementedError + + def forward(self, input, target): + # get probabilities from logits + input = self.normalization(input) + + # compute per channel Dice coefficient + per_channel_dice = self.dice(input, target, weight=self.weight) + + # average Dice score across all channels/classes + return 1. - torch.mean(per_channel_dice) + + +class DiceLoss(_AbstractDiceLoss): + """Computes Dice Loss according to https://arxiv.org/abs/1606.04797. + For multi-class segmentation `weight` parameter can be used to assign different weights per class. + The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function. + """ + + def __init__(self, weight=None, normalization='sigmoid'): + super().__init__(weight, normalization) + + def dice(self, input, target, weight): + return compute_per_channel_dice(input, target, weight=self.weight) + + +class GeneralizedDiceLoss(_AbstractDiceLoss): + """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf. 
+ """ + + def __init__(self, normalization='sigmoid', epsilon=1e-6): + super().__init__(weight=None, normalization=normalization) + self.epsilon = epsilon + + def dice(self, input, target, weight): + assert input.size() == target.size(), "'input' and 'target' must have the same shape" + + input = flatten(input) + target = flatten(target) + target = target.float() + + if input.size(0) == 1: + # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf) + # put foreground and background voxels in separate channels + input = torch.cat((input, 1 - input), dim=0) + target = torch.cat((target, 1 - target), dim=0) + + # GDL weighting: the contribution of each label is corrected by the inverse of its volume + w_l = target.sum(-1) + w_l = 1 / (w_l * w_l).clamp(min=self.epsilon) + w_l.requires_grad = False + + intersect = (input * target).sum(-1) + intersect = intersect * w_l + + denominator = (input + target).sum(-1) + denominator = (denominator * w_l).clamp(min=self.epsilon) + + return 2 * (intersect.sum() / denominator.sum()) + + +class BCEDiceLoss(nn.Module): + """Linear combination of BCE and Dice losses""" + + def __init__(self, alpha, beta): + super(BCEDiceLoss, self).__init__() + self.alpha = alpha + self.bce = nn.BCEWithLogitsLoss() + self.beta = beta + self.dice = DiceLoss() + + def forward(self, input, target): + return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target) + + +class WeightedCrossEntropyLoss(nn.Module): + """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf + """ + + def __init__(self, ignore_index=-1): + super(WeightedCrossEntropyLoss, self).__init__() + self.ignore_index = ignore_index + + def forward(self, input, target): + weight = self._class_weights(input) + return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index) + + @staticmethod + def _class_weights(input): + # normalize the input first + input = F.softmax(input, dim=1) + 
flattened = flatten(input) + nominator = (1. - flattened).sum(-1) + denominator = flattened.sum(-1) + class_weights = Variable(nominator / denominator, requires_grad=False) + return class_weights + + +class PixelWiseCrossEntropyLoss(nn.Module): + def __init__(self, class_weights=None, ignore_index=None): + super(PixelWiseCrossEntropyLoss, self).__init__() + self.register_buffer('class_weights', class_weights) + self.ignore_index = ignore_index + self.log_softmax = nn.LogSoftmax(dim=1) + + def forward(self, input, target, weights): + assert target.size() == weights.size() + # normalize the input + log_probabilities = self.log_softmax(input) + # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW) + target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index) + # expand weights + weights = weights.unsqueeze(1) + weights = weights.expand_as(input) + + # create default class_weights if None + if self.class_weights is None: + class_weights = torch.ones(input.size()[1]).float().cuda() + else: + class_weights = self.class_weights + + # resize class_weights to be broadcastable into the weights + class_weights = class_weights.view(1, -1, 1, 1, 1) + + # multiply weights tensor by class weights + weights = class_weights * weights + + # compute the losses + result = -weights * target * log_probabilities + # average the losses + return result.mean() + + +class WeightedSmoothL1Loss(nn.SmoothL1Loss): + def __init__(self, threshold, initial_weight, apply_below_threshold=True): + super().__init__(reduction="none") + self.threshold = threshold + self.apply_below_threshold = apply_below_threshold + self.weight = initial_weight + + def forward(self, input, target): + l1 = super().forward(input, target) + + if self.apply_below_threshold: + mask = target < self.threshold + else: + mask = target >= self.threshold + + l1[mask] = l1[mask] * self.weight + + return l1.mean() + + +def flatten(tensor): + """Flattens a 
given tensor such that the channel axis is first. + The shapes are transformed as follows: + (N, C, D, H, W) -> (C, N * D * H * W) + """ + # number of channels + C = tensor.size(1) + # new axis order + axis_order = (1, 0) + tuple(range(2, tensor.dim())) + # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) + transposed = tensor.permute(axis_order) + # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) + return transposed.contiguous().view(C, -1) + + +def get_loss_criterion(config): + """ + Returns the loss function based on provided configuration + :param config: (dict) a top level configuration object containing the 'loss' key + :return: an instance of the loss function + """ + assert 'loss' in config, 'Could not find loss function configuration' + loss_config = config['loss'] + name = loss_config.pop('name') + + ignore_index = loss_config.pop('ignore_index', None) + skip_last_target = loss_config.pop('skip_last_target', False) + weight = loss_config.pop('weight', None) + + if weight is not None: + weight = torch.tensor(weight) + + pos_weight = loss_config.pop('pos_weight', None) + if pos_weight is not None: + pos_weight = torch.tensor(pos_weight) + + loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight) + + if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']): + # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly + loss = _MaskingLossWrapper(loss, ignore_index) + + if skip_last_target: + loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False)) + + if torch.cuda.is_available(): + loss = loss.cuda() + + return loss + + +####################################################################################################################### + +def _create_loss(name, loss_config, weight, ignore_index, pos_weight): + if name == 'BCEWithLogitsLoss': + return nn.BCEWithLogitsLoss(pos_weight=pos_weight) + elif name == 'BCEDiceLoss': + 
alpha = loss_config.get('alphs', 1.) + beta = loss_config.get('beta', 1.) + return BCEDiceLoss(alpha, beta) + elif name == 'CrossEntropyLoss': + if ignore_index is None: + ignore_index = -100 # use the default 'ignore_index' as defined in the CrossEntropyLoss + return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index) + elif name == 'WeightedCrossEntropyLoss': + if ignore_index is None: + ignore_index = -100 # use the default 'ignore_index' as defined in the CrossEntropyLoss + return WeightedCrossEntropyLoss(ignore_index=ignore_index) + elif name == 'PixelWiseCrossEntropyLoss': + return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index) + elif name == 'GeneralizedDiceLoss': + normalization = loss_config.get('normalization', 'sigmoid') + return GeneralizedDiceLoss(normalization=normalization) + elif name == 'DiceLoss': + normalization = loss_config.get('normalization', 'sigmoid') + return DiceLoss(weight=weight, normalization=normalization) + elif name == 'MSELoss': + return MSELoss() + elif name == 'SmoothL1Loss': + return SmoothL1Loss() + elif name == 'L1Loss': + return L1Loss() + elif name == 'WeightedSmoothL1Loss': + return WeightedSmoothL1Loss(threshold=loss_config['threshold'], + initial_weight=loss_config['initial_weight'], + apply_below_threshold=loss_config.get('apply_below_threshold', True)) + else: + raise RuntimeError(f"Unsupported loss function: '{name}'") diff --git a/Trainer/models/unet3d/model.py b/Trainer/models/unet3d/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1b9745c9d66e5b2e8896f4e67c2f83879e3b10 --- /dev/null +++ b/Trainer/models/unet3d/model.py @@ -0,0 +1,281 @@ +import torch.nn as nn +import torch.nn.functional as F + +from .buildingblocks import DoubleConv, create_decoders, create_encoders +from .utils import get_class, number_of_features_per_level + + +class AbstractUNetSep(nn.Module): + """ + Base class for standard and residual UNet. 
+ + Args: + in_channels (int): number of input channels + out_channels (int): number of output segmentation masks; + Note that the of out_channels might correspond to either + different semantic classes or to different binary segmentation mask. + It's up to the user of the class to interpret the out_channels and + use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) + or BCEWithLogitsLoss (two-class) respectively) + f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number + of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 + final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution, + otherwise apply nn.Softmax. In effect only if `self.training == False`, i.e. during validation/testing + basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....) + layer_order (string): determines the order of layers in `SingleConv` module. + E.g. 'crg' stands for GroupNorm3d+Conv3d+ReLU. 
See `SingleConv` for more info + num_groups (int): number of groups for the GroupNorm + num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) + default: 4 + is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied + after the final convolution; if False (regression problem) the normalization layer is skipped + conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module + pool_kernel_size (int or tuple): the size of the window + conv_padding (int or tuple): add zero-padding added to all three sides of the input + is_3d (bool): if True the model is 3D, otherwise 2D, default: True + """ + + def __init__(self, in_channels, basic_module, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, conv_kernel_size=3, pool_kernel_size=2, + conv_padding=1, is_unit_vector = False, is_3d=True): + super(AbstractUNetSep, self).__init__() + + if isinstance(f_maps, int): + self.f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) + else: + assert isinstance(self.f_maps, list) or isinstance(self.f_maps, tuple) + self.f_maps = f_maps + + assert len(self.f_maps) > 1, "Required at least 2 levels in the U-Net" + if 'g' in layer_order: + assert num_groups is not None, "num_groups must be specified if GroupNorm is used" + + # create encoder path + self.encoders = create_encoders(in_channels, self.f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, + num_groups, pool_kernel_size, is_3d) + + # create decoder path #1 + self.decoders_normal = create_decoders(self.f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, + is_3d) + # create decoder path #2 + self.decoders_pathol = create_decoders(self.f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, + is_3d) + + self.is_unit_vector = is_unit_vector + + def forward(self, x): + # encoder part + encoders_features = [] + for i, encoder_a in 
enumerate(self.encoders): + x = encoder_a(x) + # reverse the encoder outputs to be aligned with the decoder + encoders_features.insert(0, x) + + # remove the last encoder's output from the list + # !!remember: it's the 1st in the list + encoders_features = encoders_features[1:] + x_normal, x_pathol = x.clone(), x.clone() + + # decoder part + for decoder_normal, decoder_pathol, encoder_features in zip(self.decoders_normal, self.decoders_pathol, encoders_features): + # pass the output from the corresponding encoder and the output + # of the previous decoder + x_normal = decoder_normal(encoder_features, x_normal) + x_pathol = decoder_pathol(encoder_features, x_pathol) + + if self.is_unit_vector: + x_normal = F.normalize(x_normal, dim=1) + x_pathol = F.normalize(x_pathol, dim=1) + + return {'normal': x_normal, 'pathology': x_pathol} + + + def get_feature(self, x): + encoders_features = [] + for i, encoder_a in enumerate(self.encoders): + x = encoder_a(x) + encoders_features.insert(0, x) + encoders_features = encoders_features[1:] + + x_normal, x_pathol = x.clone(), x.clone() + decoders_normal_features, decoders_pathol_features = [x_normal], [x_pathol] + for decoder_normal, decoder_pathol, encoder_features in zip(self.decoders_normal, self.decoders_pathol, encoders_features): + # pass the output from the corresponding encoder and the output + # of the previous decoder + x_normal = decoder_normal(encoder_features, x_normal) + x_pathol = decoder_normal(encoder_features, x_pathol) + decoders_normal_features.append(x_normal) + decoders_pathol_features.append(x_pathol) + if self.is_unit_vector: + decoders_normal_features[-1] = F.normalize(decoders_normal_features[-1], dim=1) + decoders_pathol_features[-1] = F.normalize(decoders_pathol_features[-1], dim=1) + return {'normal': decoders_normal_features, 'pathology': decoders_pathol_features} + + + +class AbstractUNet(nn.Module): + """ + Base class for standard and residual UNet. 
+ + Args: + in_channels (int): number of input channels + out_channels (int): number of output segmentation masks; + Note that the of out_channels might correspond to either + different semantic classes or to different binary segmentation mask. + It's up to the user of the class to interpret the out_channels and + use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) + or BCEWithLogitsLoss (two-class) respectively) + f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number + of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 + final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution, + otherwise apply nn.Softmax. In effect only if `self.training == False`, i.e. during validation/testing + basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....) + layer_order (string): determines the order of layers in `SingleConv` module. + E.g. 'crg' stands for GroupNorm3d+Conv3d+ReLU. 
See `SingleConv` for more info + num_groups (int): number of groups for the GroupNorm + num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) + default: 4 + is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied + after the final convolution; if False (regression problem) the normalization layer is skipped + conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module + pool_kernel_size (int or tuple): the size of the window + conv_padding (int or tuple): add zero-padding added to all three sides of the input + is_3d (bool): if True the model is 3D, otherwise 2D, default: True + """ + + def __init__(self, in_channels, basic_module, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, conv_kernel_size=3, pool_kernel_size=2, + conv_padding=1, is_unit_vector = False, is_3d=True): + super(AbstractUNet, self).__init__() + + if isinstance(f_maps, int): + self.f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) + else: + assert isinstance(self.f_maps, list) or isinstance(self.f_maps, tuple) + self.f_maps = f_maps + + assert len(self.f_maps) > 1, "Required at least 2 levels in the U-Net" + if 'g' in layer_order: + assert num_groups is not None, "num_groups must be specified if GroupNorm is used" + + # create encoder path + self.encoders = create_encoders(in_channels, self.f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, + num_groups, pool_kernel_size, is_3d) + + # create decoder path + self.decoders = create_decoders(self.f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, + is_3d) + + self.is_unit_vector = is_unit_vector + + def forward(self, x): + # encoder part + encoders_features = [] + for encoder in self.encoders: + x = encoder(x) + # reverse the encoder outputs to be aligned with the decoder + encoders_features.insert(0, x) + + # remove the last encoder's output from the list + # 
!!remember: it's the 1st in the list + encoders_features = encoders_features[1:] + + # decoder part + for decoder, encoder_features in zip(self.decoders, encoders_features): + # pass the output from the corresponding encoder and the output + # of the previous decoder + x = decoder(encoder_features, x) + + if self.is_unit_vector: + x = F.normalize(x, dim=1) + + return x + + + def get_feature(self, x): + + encoders_features = [] + for encoder in self.encoders: + x = encoder(x) + encoders_features.insert(0, x) + encoders_features = encoders_features[1:] + + decoders_features = [x] + for decoder, encoder_features in zip(self.decoders, encoders_features): + x = decoder(encoder_features, x) + decoders_features.append(x) + if self.is_unit_vector: + decoders_features[-1] = F.normalize(decoders_features[-1], dim=1) + return decoders_features + + + +class UNet3D(AbstractUNet): + """ + 3DUnet model from + `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" + `. + + Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder + """ + + def __init__(self, in_channels, f_maps, layer_order='gcl', num_groups=8, num_levels=5, is_unit_vector=False, conv_padding=1, **kwargs): + + super(UNet3D, self).__init__(in_channels=in_channels, + basic_module=DoubleConv, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_unit_vector=is_unit_vector, + conv_padding=conv_padding, + is_3d=True) + + +class UNet3DSep(AbstractUNetSep): + """ + 3DUnet model from + `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" + `. 
+ + Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder + """ + + def __init__(self, in_channels, f_maps, layer_order='gcl', num_groups=8, num_levels=5, is_unit_vector=False, conv_padding=1, **kwargs): + + super(UNet3DSep, self).__init__(in_channels=in_channels, + basic_module=DoubleConv, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_unit_vector=is_unit_vector, + conv_padding=conv_padding, + is_3d=True) + + +class UNet2D(AbstractUNet): + """ + 2DUnet model from + `"U-Net: Convolutional Networks for Biomedical Image Segmentation" ` + """ + + def __init__(self, args, in_channels, f_maps, conv_padding=1, **kwargs): + + super(UNet2D, self).__init__(in_channels=in_channels, + basic_module=DoubleConv, + f_maps=f_maps, + layer_order=args.layer_order, + num_groups=args.num_groups, + num_levels=args.num_levels, + conv_padding=conv_padding, + is_3d=True) + + + +def get_model(model_config): + model_class = get_class(model_config['name'], modules=[ + 'pytorch3dunet.unet3d.model' + ]) + return model_class(**model_config) + diff --git a/Trainer/models/unet3d/utils.py b/Trainer/models/unet3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4c173f757472883ff6f9e9b839dbb930a976d290 --- /dev/null +++ b/Trainer/models/unet3d/utils.py @@ -0,0 +1,300 @@ +import importlib +import logging +import os +import shutil +import sys + +import h5py +import numpy as np +import torch +from torch import optim + + +def save_checkpoint(state, is_best, checkpoint_dir): + """Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'. + If is_best==True saves '{checkpoint_dir}/best_checkpoint.pytorch' as well. 
+ + Args: + state (dict): contains model's state_dict, optimizer's state_dict, epoch + and best evaluation metric value so far + is_best (bool): if True state contains the best model seen so far + checkpoint_dir (string): directory where the checkpoint are to be saved + """ + + if not os.path.exists(checkpoint_dir): + os.mkdir(checkpoint_dir) + + last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch') + torch.save(state, last_file_path) + if is_best: + best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch') + shutil.copyfile(last_file_path, best_file_path) + + +def load_checkpoint(checkpoint_path, model, optimizer=None, + model_key='model_state_dict', optimizer_key='optimizer_state_dict'): + """Loads model and training parameters from a given checkpoint_path + If optimizer is provided, loads optimizer's state_dict of as well. + + Args: + checkpoint_path (string): path to the checkpoint to be loaded + model (torch.nn.Module): model into which the parameters are to be copied + optimizer (torch.optim.Optimizer) optional: optimizer instance into + which the parameters are to be copied + + Returns: + state + """ + if not os.path.exists(checkpoint_path): + raise IOError(f"Checkpoint '{checkpoint_path}' does not exist") + + state = torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(state[model_key]) + + if optimizer is not None: + optimizer.load_state_dict(state[optimizer_key]) + + return state + + +def save_network_output(output_path, output, logger=None): + if logger is not None: + print(f'Saving network output to: {output_path}...') + output = output.detach().cpu()[0] + with h5py.File(output_path, 'w') as f: + f.create_dataset('predictions', data=output, compression='gzip') + + +loggers = {} + + +def get_logger(name, level=logging.INFO): + global loggers + if loggers.get(name) is not None: + return loggers[name] + else: + logger = logging.getLogger(name) + logger.setLevel(level) + # Logging to console + 
stream_handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter( + '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s') + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + loggers[name] = logger + + return logger + + +def get_number_of_learnable_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class RunningAverage: + """Computes and stores the average + """ + + def __init__(self): + self.count = 0 + self.sum = 0 + self.avg = 0 + + def update(self, value, n=1): + self.count += n + self.sum += value * n + self.avg = self.sum / self.count + + +def number_of_features_per_level(init_channel_number, num_levels): + return [init_channel_number * 2 ** k for k in range(num_levels)] + + +class _TensorboardFormatter: + """ + Tensorboard formatters converts a given batch of images (be it input/output to the network or the target segmentation + image) to a series of images that can be displayed in tensorboard. This is the parent class for all tensorboard + formatters which ensures that returned images are in the 'CHW' format. + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, name, batch): + """ + Transform a batch to a series of tuples of the form (tag, img), where `tag` corresponds to the image tag + and `img` is the image itself. 
+ + Args: + name (str): one of 'inputs'/'targets'/'predictions' + batch (torch.tensor): 4D or 5D torch tensor + """ + + def _check_img(tag_img): + tag, img = tag_img + + assert img.ndim == 2 or img.ndim == 3, 'Only 2D (HW) and 3D (CHW) images are accepted for display' + + if img.ndim == 2: + img = np.expand_dims(img, axis=0) + else: + C = img.shape[0] + assert C == 1 or C == 3, 'Only (1, H, W) or (3, H, W) images are supported' + + return tag, img + + tagged_images = self.process_batch(name, batch) + + return list(map(_check_img, tagged_images)) + + def process_batch(self, name, batch): + raise NotImplementedError + + +class DefaultTensorboardFormatter(_TensorboardFormatter): + def __init__(self, skip_last_target=False, **kwargs): + super().__init__(**kwargs) + self.skip_last_target = skip_last_target + + def process_batch(self, name, batch): + if name == 'targets' and self.skip_last_target: + batch = batch[:, :-1, ...] + + tag_template = '{}/batch_{}/channel_{}/slice_{}' + + tagged_images = [] + + if batch.ndim == 5: + # NCDHW + slice_idx = batch.shape[2] // 2 # get the middle slice + for batch_idx in range(batch.shape[0]): + for channel_idx in range(batch.shape[1]): + tag = tag_template.format(name, batch_idx, channel_idx, slice_idx) + img = batch[batch_idx, channel_idx, slice_idx, ...] + tagged_images.append((tag, self._normalize_img(img))) + else: + # batch has no channel dim: NDHW + slice_idx = batch.shape[1] // 2 # get the middle slice + for batch_idx in range(batch.shape[0]): + tag = tag_template.format(name, batch_idx, 0, slice_idx) + img = batch[batch_idx, slice_idx, ...] 
+ tagged_images.append((tag, self._normalize_img(img))) + + return tagged_images + + @staticmethod + def _normalize_img(img): + return np.nan_to_num((img - np.min(img)) / np.ptp(img)) + + +def _find_masks(batch, min_size=10): + """Center the z-slice in the 'middle' of a given instance, given a batch of instances + + Args: + batch (ndarray): 5d numpy tensor (NCDHW) + """ + result = [] + for b in batch: + assert b.shape[0] == 1 + patch = b[0] + z_sum = patch.sum(axis=(1, 2)) + coords = np.where(z_sum > min_size)[0] + if len(coords) > 0: + ind = coords[len(coords) // 2] + result.append(b[:, ind:ind + 1, ...]) + else: + ind = b.shape[1] // 2 + result.append(b[:, ind:ind + 1, ...]) + + return np.stack(result, axis=0) + + +def get_tensorboard_formatter(formatter_config): + if formatter_config is None: + return DefaultTensorboardFormatter() + + class_name = formatter_config['name'] + m = importlib.import_module('pytorch3dunet.unet3d.utils') + clazz = getattr(m, class_name) + return clazz(**formatter_config) + + +def expand_as_one_hot(input, C, ignore_index=None): + """ + Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector. + It is assumed that the batch dimension is present. 
+ Args: + input (torch.Tensor): 3D/4D input image + C (int): number of channels/labels + ignore_index (int): ignore index to be kept during the expansion + Returns: + 4D/5D output torch.Tensor (NxCxSPATIAL) + """ + assert input.dim() == 4 + + # expand the input tensor to Nx1xSPATIAL before scattering + input = input.unsqueeze(1) + # create output tensor shape (NxCxSPATIAL) + shape = list(input.size()) + shape[1] = C + + if ignore_index is not None: + # create ignore_index mask for the result + mask = input.expand(shape) == ignore_index + # clone the src tensor and zero out ignore_index in the input + input = input.clone() + input[input == ignore_index] = 0 + # scatter to get the one-hot tensor + result = torch.zeros(shape).to(input.device).scatter_(1, input, 1) + # bring back the ignore_index in the result + result[mask] = ignore_index + return result + else: + # scatter to get the one-hot tensor + return torch.zeros(shape).to(input.device).scatter_(1, input, 1) + + +def convert_to_numpy(*inputs): + """ + Coverts input tensors to numpy ndarrays + + Args: + inputs (iteable of torch.Tensor): torch tensor + + Returns: + tuple of ndarrays + """ + + def _to_numpy(i): + assert isinstance(i, torch.Tensor), "Expected input to be torch.Tensor" + return i.detach().cpu().numpy() + + return (_to_numpy(i) for i in inputs) + + +def create_optimizer(optimizer_config, model): + learning_rate = optimizer_config['learning_rate'] + weight_decay = optimizer_config.get('weight_decay', 0) + betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) + optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas, weight_decay=weight_decay) + return optimizer + + +def create_lr_scheduler(lr_config, optimizer): + if lr_config is None: + return None + class_name = lr_config.pop('name') + m = importlib.import_module('torch.optim.lr_scheduler') + clazz = getattr(m, class_name) + # add optimizer to the config + lr_config['optimizer'] = optimizer + return clazz(**lr_config) + + +def 
get_class(class_name, modules): + for module in modules: + m = importlib.import_module(module) + clazz = getattr(m, class_name, None) + if clazz is not None: + return clazz + raise RuntimeError(f'Unsupported dataset class: {class_name}') diff --git a/Trainer/visualizer.py b/Trainer/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..aacf7a1d7abfc4844d29e81fde9db82bbcf49e09 --- /dev/null +++ b/Trainer/visualizer.py @@ -0,0 +1,333 @@ + +""" +Visualization modules +""" +import os +import numpy as np +from math import ceil +import torch +import torch.nn.functional as F +from PIL import Image +from collections import defaultdict + +from utils.misc import make_dir + + +def match_shape(array, shape): + # array: (channel_dim, *orig_shape) + array = array[None] + if list(array.shape[2:]) != list(shape): + array = F.interpolate(array, size=shape) + return array[0] + +def pad_shape(array_list): + max_shape = [0] * len(array_list[0].shape) + + for array in array_list: + max_shape = [max(max_shape[dim], array.shape[dim]) for dim in range(len(max_shape))] + pad_array_list = [] + for array in array_list: + start = [(max_shape[dim] - array.shape[dim]) // 2 for dim in range(len(max_shape))] + if len(start) == 2: + pad_array = np.zeros((max_shape[0], max_shape[1])) + pad_array[start[0] : start[0] + array.shape[0], start[1] : start[1] + array.shape[1]] = array + elif len(start) == 3: + pad_array = np.zeros((max_shape[0], max_shape[1], max_shape[2])) + pad_array[start[0] : start[0] + array.shape[0], start[1] : start[1] + array.shape[1], start[2] : start[2] + array.shape[2]] = array + elif len(start) == 4: + pad_array = np.zeros((max_shape[0], max_shape[1], max_shape[2], max_shape[3])) + pad_array[start[0] : start[0] + array.shape[0], start[1] : start[1] + array.shape[1], start[2] : start[2] + array.shape[2], start[3] : start[3] + array.shape[3]] = array + + pad_array_list.append(pad_array) + return pad_array_list + + +def even_sample(orig_len, num): + 
idx = [] + length = float(orig_len) + for i in range(num): + idx.append(int(ceil(i * length / num))) + return idx + + +def normalize(nda, channel = None): + if channel is not None: + nda_max = np.max(nda, axis = channel, keepdims = True) + nda_min = np.min(nda, axis = channel, keepdims = True) + else: + nda_max = np.max(nda) + nda_min = np.min(nda) + return (nda - nda_min) / (nda_max - nda_min + 1e-7) + + +############################################## + + +class BaseVisualizer(object): + + def __init__(self, gen_args, train_args, draw_border=False): + + self.tasks = [key for (key, value) in vars(gen_args.task).items() if value] + + self.args = train_args + self.draw_border = draw_border + self.vis_spacing = self.args.visualizer.spacing + + + def create_image_row(self, images): + if self.draw_border: + images = np.copy(images) + images[:, :, [0, -1]] = (1, 1, 1) + images[:, :, [0, -1]] = (1, 1, 1) + return np.concatenate(list(images), axis=1) + + def create_image_grid(self, *args): + out = [] + for arg in args: + out.append(normalize(self.create_image_row(arg))) + return np.concatenate(out, axis=0) + + def prepare_for_itk(self, array): # (s, r, c, *) + return array[:, ::-1, :] + + def prepare_for_png(self, array, normalize = False): # (s, r, c, *) + slc = array[::self.vis_spacing[0]] # (s', r, c *) + row = array[:, ::self.vis_spacing[1]].transpose((1, 0, 2, 3))[:, ::-1] # (s, r', c, *) -> (r', s, c, *) + col = array[:, :, ::self.vis_spacing[2]].transpose((2, 0, 1, 3))[:, ::-1] # (s, r, c', *) -> (c', s, r, *) + + if normalize: + slc = (slc - np.min(slc)) / (np.max(slc) - np.min(slc)) + row = (slc - np.min(slc)) / (np.max(slc) - np.min(row)) + col = (slc - np.min(slc)) / (np.max(slc) - np.min(col)) + return slc, row, col + + + +class FeatVisualizer(BaseVisualizer): + + def __init__(self, gen_args, train_args, draw_border=False): + BaseVisualizer.__init__(self, gen_args, train_args, draw_border) + self.feat_vis_num = train_args.visualizer.feat_vis_num + + def 
visualize_all_multi(self, subjects, multi_inputs, multi_outputs, out_dir):
        """
        For med-id student input samples: n_samples * [ (batch_size, channel_dim, *img_shp) ]
        For med-id student output features: n_samples * [ n_levels * (batch_size, channel_dim, *img_shp) ]
        """

        names = [name.split('.nii')[0] for name in subjects['name']]
        multi_inputs = [x['input'] for x in multi_inputs] # n_samples * (b, d, s, r, c)
        for k in multi_outputs[0].keys():
            # visualize every output entry whose key contains 'feat'
            if 'feat' in k:
                multi_features = [x[k] for x in multi_outputs]
                self.visualize_all_multi_features(names , multi_features, multi_inputs, out_dir, prefix = k)

    def visualize_all_multi_features(self, names, multi_features, multi_inputs, out_dir, prefix = 'feat'):
        """Save one PNG per subject per feature level, concatenating all samples horizontally."""

        n_samples = len(multi_inputs)
        n_levels = len(multi_features[0])

        # Reorganize from sample-major to subject-major nesting.
        multi_inputs_reorg = [] # batch_size * [ n_samples * (channel_dim, *img_shp) ]
        multi_features_reorg = [] # batch_size * [ n_samples * [ n_levels * (channel_dim, *img_shp) ] ]
        for i_name, _ in enumerate(names):
            multi_features_reorg.append([[multi_features[i_sample][i_level][i_name] for i_level in range(n_levels)] for i_sample in range(n_samples)])
            multi_inputs_reorg.append([multi_inputs[i_sample][i_name] for i_sample in range(n_samples)])

        for i_name, name in enumerate(names):

            inputs = multi_inputs_reorg[i_name]
            features = multi_features_reorg[i_name]

            # level -> list of per-sample panels (with gaps interleaved)
            all_sample_results = defaultdict(list)
            for i_sample in range(n_samples):

                curr_input = inputs[i_sample].data.cpu().numpy() # ( d=1, s, r, c)
                curr_input = self.prepare_for_itk(curr_input.transpose(3, 2, 1, 0)) # (d, x, y, z) -> (z, y, x, d)

                curr_feat = features[i_sample] # n_levels * (channel_dim, s, r, c)
                curr_level_feats = []

                for l in range(n_levels):
                    curr_level_feat = curr_feat[l] # (channel_dim, s, r, c)

                    # evenly sub-sample feat_vis_num channels for display
                    sub_idx = even_sample(curr_level_feat.shape[0], self.feat_vis_num)
                    curr_level_feat = torch.stack([curr_level_feat[idx] for idx in sub_idx], dim = 0) # (sub_channel_dim, s, r, c)

                    # resize features to the input's spatial shape
                    curr_level_feat = match_shape(curr_level_feat, list(curr_input.shape[:-1]))
                    curr_level_feats.append(self.prepare_for_itk((curr_level_feat.data.cpu().numpy().transpose((3, 2, 1, 0)))))

                all_results = self.gather(curr_input, curr_level_feats)

                for l, result in enumerate(all_results): # n_level * (r, c)
                    # black gap column between samples, sized to one slice-panel width
                    gap = np.zeros_like(result[:, :int( result.shape[1] / (curr_input.shape[0] / self.vis_spacing[0]) )])
                    all_sample_results[l] += [result] + [gap]

            for l in all_sample_results.keys():
                # drop the trailing gap, then concatenate samples horizontally
                curr_level_all_sample_feats = np.concatenate(list(all_sample_results[l][:-1]), axis=1) # (s, n_samples * c)
                Image.fromarray(curr_level_all_sample_feats).save(os.path.join(make_dir(os.path.join(out_dir, name)), name + '_%s_l%s.png' % (prefix, str(l))))


    def visualize_all(self, names, inputs, features):
        """
        For general (single-sample) inputs: (batch_size, channel_dim, *img_shp)
        For general (single-sample) output features: n_levels * (batch_size, channel_dim, *img_shp)
        """

        inputs = inputs.data.cpu().numpy() # (b, d=1, s, r, c)
        n_levels = len(features) # n_levels * (b, channel_dim, s, r, c)

        for i_name, name in enumerate(names):
            curr_input = self.prepare_for_itk(inputs[i_name].transpose((3, 2, 1, 0))) # (d, x, y, z) -> (z, y, x, d)
            curr_level_feats = []
            for l in range(n_levels):
                curr_feat = features[l][i_name] # (channel_dim, s, r, c)

                # evenly sub-sample feat_vis_num channels for display
                sub_idx = even_sample(curr_feat.shape[0], self.feat_vis_num)
                curr_feat = torch.stack([curr_feat[idx] for idx in sub_idx], dim = 0) # (sub_channel_dim, s, r, c)

                curr_feat = match_shape(curr_feat, list(curr_input.shape[:-1]))
                curr_level_feats.append(self.prepare_for_itk((curr_feat.data.cpu().numpy().transpose((3, 2, 1, 0)))))

            # NOTE(review): gather's return value is discarded here — presumably
            # only used for its side effects elsewhere; confirm against callers.
            self.gather(curr_input, curr_level_feats)


    def gather(self, input, feats):
        """Build one uint8 grid image per feature level (axial slices only)."""

        input_slc = self.prepare_for_png(input, normalize = False)[0][..., 0] # (sub_s, r, c)
        all_images = []
        for l, feat in enumerate(feats):
            slc_images = [input_slc] # only plot along axial
slc_feat = normalize(feat[::self.vis_spacing[0]].transpose(3, 0, 1, 2), channel = 1) # (sub_s, r, c, sub_channel_dim) -> (sub_channel_dim, sub_s, r, c) + slc_images = [input_slc, np.zeros_like(input_slc)] + list(slc_feat) # (1 + 1 + s', r, c *) + slc_images = pad_shape(slc_images) + + slc_image = self.create_image_grid(*slc_images) + slc_image = (255 * slc_image).astype(np.uint8) + all_images.append(slc_image) + + return all_images + + + +class TaskVisualizer(BaseVisualizer): + + def __init__(self, gen_args, train_args, draw_border=False): + BaseVisualizer.__init__(self, gen_args, train_args, draw_border) + + def visualize_all(self, subjects, samples, outputs, out_dir, output_names = ['image'], target_names = ['image']): + + if len(output_names) == 0: + return + + n_samples = len(samples) + + names = [name.split('.nii')[0] for name in subjects['name']] + + inputs = [x['input'].data.cpu().numpy() for x in samples] # n_samples * (b, d, s, r, c) + if 'input_flip' in samples[0].keys(): + inputs_flip = [x['input_flip'].data.cpu().numpy() for x in samples] # n_samples * (b, d, s, r, c) + + out_images = {} + for output_name in output_names: + if output_name in outputs[0].keys(): + out_images[output_name] = [x[output_name].data.cpu().numpy() for x in outputs] # n_samples * (b, d, s, r, c) + + for i, name in enumerate(names): + #case_out_dir = make_dir(os.path.join(out_dir, name)) + curr_inputs = [self.prepare_for_itk(inputs[i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] # n_samples * (d, x, y, z) -> n_samples (z, y, x, d) + if 'input_flip' in samples[0].keys(): + curr_inputs_flip = [self.prepare_for_itk(inputs_flip[i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] # n_samples * (d, x, y, z) -> n_samples (z, y, x, d) + + # Plot all inputs + #self.visualize_sample(name, curr_inputs, out_dir, postfix = '_input') + + if len(out_images) > 0: + curr_target = {} + if 'bias_field' in samples[0]: + curr_target['bias_field'] = 
[self.prepare_for_itk(samples[i_sample]['bias_field'][i].data.cpu().numpy().transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] + if 'high_res' in samples[0]: + curr_target['high_res'] = [self.prepare_for_itk(samples[i_sample]['high_res'][i].data.cpu().numpy().transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] + + for target_name in target_names: + if target_name in subjects and target_name not in curr_target.keys(): + try: + curr_target[target_name] = self.prepare_for_itk(subjects[target_name][i].data.cpu().numpy().transpose((3, 2, 1, 0))) # (d=1, s, r, c) -> (z, y, x, d) + except: + pass + #print(target_name, 'failed in visualization') + + curr_outputs = {} + for output_name in output_names: + if output_name in outputs[0].keys(): + #print('output name', output_name) + curr_outputs[output_name] = [self.prepare_for_itk(out_images[output_name][i_sample][i].transpose((3, 2, 1, 0))) for i_sample in range(n_samples)] # n_samples * (d, x, y, z) -> n_samples (z, y, x, d) + + all_images = [] + + for i_sample, curr_input in enumerate(curr_inputs): + target_list = [curr_input] + if 'input_flip' in samples[0].keys(): + target_list.append(curr_inputs_flip[i_sample]) + for target_name in target_names: + if target_name in curr_target: + #print('target name', target_name) + if 'bias_field' in target_name or 'high_res' in target_name: + target_list.append(curr_target[target_name][i_sample]) + else: + target_list.append(curr_target[target_name]) + + output_list = [] + for ouput_name in output_names: + if ouput_name in curr_outputs.keys(): + output_list.append(curr_outputs[ouput_name][i_sample]) + + all_image = self.gather(target_list, output_list) # (row, col) + all_images.append(all_image) # n_sample * (row, col) + all_images = np.concatenate(all_images, axis=1).astype(np.uint8) # (row, n_sample * col) + Image.fromarray(all_images).save(os.path.join(out_dir, name + '_all_outputs.png')) + + def visualize_sample(self, name, input, out_dir, postfix = '_input'): + + 
n_samples = len(input) + + slc_images, row_images, col_images = [], [], [] + for i_sample in range(n_samples): + input_slc, input_row, input_col = self.prepare_for_png(input[i_sample], normalize = False) + + slc_images.append(input_slc) + row_images.append(input_row) + col_images.append(input_col) + + # add row gap + gap = [np.zeros_like(slc_images[0])] + all_images = slc_images + gap + row_images + gap + col_images + all_images = pad_shape(all_images) + all_image = self.create_image_grid(*all_images) + all_image = (255 * all_image).astype(np.uint8) + Image.fromarray(all_image[:, :, 0]).save(os.path.join(out_dir, name + '_all' + postfix + '.png')) # grey scale image last channel == 1 + return + + def gather(self, target_list = [], output_list = []): + + slc_images, row_images, col_images = [], [], [] + + for add_target in target_list: + add_target_slc, add_target_row, add_target_col = self.prepare_for_png(add_target, normalize = False) + slc_images += [add_target_slc] + row_images += [add_target_row] + col_images += [add_target_col] + + for add_output in output_list: + add_output_slc, add_output_row, add_output_col = self.prepare_for_png(add_output, normalize = False) + slc_images += [add_output_slc] + row_images += [add_output_row] + col_images += [add_output_col] + + # add row gap + gap = [np.zeros_like(add_target_slc)] + all_images = slc_images + gap + row_images + gap + col_images + all_images = pad_shape(all_images) + all_image = self.create_image_grid(*all_images) + + all_image = (255 * all_image).astype(np.uint8) + return all_image[:, :, 0] # shrink last channel dimension (d=1) diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/assets/.DS_Store b/assets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..f7321d09e855180cc11c506d47c9e6e726f76559 Binary files /dev/null and b/assets/.DS_Store differ diff --git 
a/assets/overview.png b/assets/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..77e863fa912053f0a13a50f62e6dbee169fce386 --- /dev/null +++ b/assets/overview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ea34001c23d83e604ce6d7e24fee314eca662cfe9489464a24045042b50b5b0 +size 306611 diff --git a/cfgs/generator/default.yaml b/cfgs/generator/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..01a1f4c2f90cb7391567cbd96f7451601f6ca57c --- /dev/null +++ b/cfgs/generator/default.yaml @@ -0,0 +1,149 @@ +device_generator: + +data_root: /autofs/vast/lemon/data_curated/brain_mris_QCed + + +split: train # train or test +split_root: /autofs/vast/lemon/temp_stuff/peirong/train_test_split +train_txt: /autofs/vast/lemon/temp_stuff/peirong/train_test_split/train.txt +test_txt: /autofs/vast/lemon/temp_stuff/peirong/train_test_split/test.txt + + +dataset_names: ['ADNI'] # list of datasets +dataset_probs: # [1.] +modality_probs: { # default + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 
0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0. # mix synth with real images +dataset_option: default # or brain_id +segment_prefix: brainseg_with_extracerebral + + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: True + + segmentation: True + distance: True + bias_field: True + registration: True + + super_resolution: False + surface: False + pathology: False + + contrastive: False + + +# setups for augmentation functions to apply +augmentation_steps: {'synth': ['gamma', 'bias_field', 'resample', 'noise'], 'real': ['gamma', 'bias_field', 'resample', 'noise']} + + +# setups for generator +generator: + + size: [128, 128, 128] + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + + pv: True + random_shift: False + deform_one_hots: False + + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + left_hemis_only: False + low_res_only: False + ct_prob: 0 + flip_prob: 0.5 + + pathology_prob: 0. # pathology_prob when synth + random_shape_prob: 0. # initialize pathol shape from random noise (instead of existing shapes) + augment_pathology: False + + +synth_image_generator: + noise_std_min: 5. + noise_std_max: 15. +real_image_generator: + noise_std_min: 0. 
+ noise_std_max: 0.02 + + +pathology_shape_generator: + perlin_res: [2, 2, 2] # shape must be a multiple of res + mask_percentile_min: 85 + mask_percentile_max: 99.9 + integ_method: dopri5 # choices=['dopri5', 'adams', 'rk4', 'euler'] + bc: neumann # choices=['neumann', 'cauchy', 'dirichlet', 'source_neumann', 'dirichlet_neumann'] + V_multiplier: 500 + dt: 0.1 + max_nt: 10 # >= 2 + pathol_thres: 0.5 + pathol_tol: 0.0000001 # if pathol mean < tol, skip + + +### some constants + +max_surf_distance: 3. # clamp at plus / minus this number (both the ground truth and the prediction) + +## NEW VAST synth +label_list_segmentation_brainseg_with_extracerebral: [0, 11, 12, 13, 16, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 14, 15, 17, 47, 49, 51, 53, 55, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 48, 50, 52, 54, 56] +n_neutral_labels_brainseg_with_extracerebral: 20 + +## synth +label_list_segmentation_with_csf: [0,14,15,16,24,77,85, 2, 3, 4, 7, 8, 10,11,12,13,17,18,26,28, 41,42,43,46,47,49,50,51,52,53,54,58,60] # 33 +n_neutral_labels_with_csf: 7 +label_list_segmentation_without_csf: [0,14,15,16,77,85, 2, 3, 4, 7, 8, 10,11,12,13,17,18,26,28, 41,42,43,46,47,49,50,51,52,53,54,58,60] +n_neutral_labels_without_csf: 6 + + +## synth_hemi +# without cerebellum and brainstem +label_list_segmentation: [0, 2, 3, 4, 10, 11, 12, 13, 17, 18, 26, 28, 77] +n_neutral_labels: 6 + +# with cerebellum and brainstem +label_list_segmentation_with_cb: [0, 2, 3, 4, 7, 8, 10, 11, 12, 13, 16, 17, 18, 26, 28, 77] diff --git a/cfgs/generator/test/demo_synth.yaml b/cfgs/generator/test/demo_synth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c9d857c4e2b9803683c477b24be1a1ea99d175bb --- /dev/null +++ b/cfgs/generator/test/demo_synth.yaml @@ -0,0 +1,146 @@ +device_generator: +out_dir: /autofs/space/yogurt_002/users/pl629/results/BrainID/demo_synth + + +test_itr_limit: 10 # n_subjects + +num_deformations: 1 # 
n_deformations for each subj +all_contrasts: 10 # n_deformations for each deformation: >= 1, <= all_samples + +mild_samples: 1 +all_samples: 1 # n_samples within each subject +test_mild_samples: 2 +test_all_samples: 4 # n_samples within each subject + + + +split: test # train or test +dataset_names: ['HCP'] # list of datasets +dataset_probs: +modality_probs: { + 'ADHD': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0., 'CT': 0.6667, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0. # mix synth with real images +dataset_option: brain_id + +# setups for training/testing tasks +task: + T1: True + T2: False + FLAIR: False + CT: False + + segmentation: False + registration: False + surface: False + distance: False + bias_field: False + + pathology: True + super_resolution: False + + contrastive: False + +# setups for augmentation functions to apply +augmentation_steps: ['gamma', 'bias_field', 'resample', 'noise'] + +# setups for generator +generator: + + size: [128, 128, 128] + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + 
bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + + ct_prob: 0 + flip_prob: 0. + + pathology_prob: 0.5 + augment_pathology: True + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 2 + all_samples: 4 + all_contrasts: 4 # >= 1, <= all_samples + num_deformations: 1 + + + + + +pathology_shape_generator: + perlin_res: [2, 2, 2] # shape must be a multiple of res + integ_method: dopri5 # choices=['dopri5', 'adams', 'rk4', 'euler'] + bc: neumann # choices=['neumann', 'cauchy', 'dirichlet', 'source_neumann', 'dirichlet_neumann'] + V_multiplier: 500 + dt: 0.1 + max_nt: 10 # >= 2 + pathol_thres: 0.2 + + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 \ No newline at end of file diff --git a/cfgs/generator/test/demo_test.yaml b/cfgs/generator/test/demo_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a3c06a865540b36fc140d551a43df6cb7460040d --- /dev/null +++ b/cfgs/generator/test/demo_test.yaml @@ -0,0 +1,137 @@ +device_generator: + +split: train # train or test + +dataset_names: ['HCP'] # list of datasets +dataset_probs: +modality_probs: { + 'ADHD': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0., 'CT': 0.6667, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 
0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0. # mix synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: True + + segmentation: True + distance: True + bias_field: True + registration: True + + super_resolution: True + + surface: False + pathology: False + contrastive: False + + +# setups for augmentation functions to apply +augmentation_steps: ['gamma', 'bias_field', 'resample', 'noise'] + +# setups for generator +generator: + + size: [160, 160, 160] + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. # 0.15 # 15 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + + left_hemis_only: False # NOTE + low_res_only: False + ct_prob: 0 + flip_prob: 0. 
+ + pathology_prob: 0.5 + augment_pathology: True + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 2 + all_samples: 4 + all_contrasts: 4 # >= 1, <= all_samples + num_deformations: 1 + + + + + +pathology_shape_generator: + perlin_res: [2, 2, 2] # shape must be a multiple of res + integ_method: dopri5 # choices=['dopri5', 'adams', 'rk4', 'euler'] + bc: neumann # choices=['neumann', 'cauchy', 'dirichlet', 'source_neumann', 'dirichlet_neumann'] + V_multiplier: 500 + dt: 0.1 + max_nt: 10 # >= 2 + pathol_thres: 0.2 + + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/test/demo_test_hemis.yaml b/cfgs/generator/test/demo_test_hemis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af8c0bbafc498bd4554cbf7c57d6bb72d5b5083d --- /dev/null +++ b/cfgs/generator/test/demo_test_hemis.yaml @@ -0,0 +1,137 @@ +device_generator: + +split: train # train or test + +dataset_names: ['HCP'] # list of datasets +dataset_probs: +modality_probs: { + 'ADHD': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0., 'CT': 0.6667, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0. 
# mix synth with real images +dataset_option: brain_id +segment_prefix: brainseg # NOTE: set to 'brainseg' for photo_hemis setup + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: False # NOTE: turn off for photo_hemis setup + + segmentation: True + distance: True + bias_field: True + registration: True + + super_resolution: False + + surface: False + pathology: False + contrastive: False + + +# setups for augmentation functions to apply +augmentation_steps: ['gamma', 'bias_field', 'resample', 'noise'] + +# setups for generator +generator: + + size: [160, 160, 160] + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. # 0.15 # 15 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + + left_hemis_only: True # NOTE + low_res_only: False # NOTE + ct_prob: 0 + flip_prob: 0. 
+ + pathology_prob: 0.5 + augment_pathology: True + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 2 + all_samples: 4 + all_contrasts: 4 # >= 1, <= all_samples + num_deformations: 1 + + + + + +pathology_shape_generator: + perlin_res: [2, 2, 2] # shape must be a multiple of res + integ_method: dopri5 # choices=['dopri5', 'adams', 'rk4', 'euler'] + bc: neumann # choices=['neumann', 'cauchy', 'dirichlet', 'source_neumann', 'dirichlet_neumann'] + V_multiplier: 500 + dt: 0.1 + max_nt: 10 # >= 2 + pathol_thres: 0.2 + + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id.yaml b/cfgs/generator/train/brain_id.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e4c56e913e87074195530f7c0a988f015ebee611 --- /dev/null +++ b/cfgs/generator/train/brain_id.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral # NOTE: set to 'brainseg' for photo_hemis setup + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: True # NOTE: turn off for photo_hemis setup + + segmentation: True + distance: True + bias_field: True + registration: True + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + contrastive: False + 
+ +# setups for generator +generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_age.yaml b/cfgs/generator/train/brain_id_age.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1497fa8b9865ea1d4e8402b8bf8f12a44b286043 --- /dev/null +++ b/cfgs/generator/train/brain_id_age.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test, train_age + +dataset_names: ['ADHD200'] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral + +# setups for training/testing tasks +task: + T1: False + T2: False + FLAIR: False + CT: False # NOTE: turn off for photo_hemis setup + + segmentation: False + distance: False + bias_field: False + registration: False + + # downstream + super_resolution: False + age: True + + surface: False + pathology: False + contrastive: False + + +# 
setups for generator +generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False # UNDER TESTING + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0. # disable shearing to preserve the original brain shape + max_scaling: 0. # disable scaling to preserve the original brain size + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: False # age prediction: disable non-linear and affine (only rigid) + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_bf.yaml b/cfgs/generator/train/brain_id_bf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..add8d8c18b09f05828706feb9cd80aad7fb9a39c --- /dev/null +++ b/cfgs/generator/train/brain_id_bf.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral + +# setups for training/testing tasks +task: + T1: False + T2: False + FLAIR: False + CT: False # NOTE: turn off for photo_hemis setup + + segmentation: False + distance: False + bias_field: True + registration: False + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + contrastive: False + + +# setups for generator 
+generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False # UNDER TESTING + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_dist.yaml b/cfgs/generator/train/brain_id_dist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..00e80676533b118fc08b852064c19c8511fbb3ca --- /dev/null +++ b/cfgs/generator/train/brain_id_dist.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral + +# setups for training/testing tasks +task: + T1: False + T2: False + FLAIR: False + CT: False # NOTE: turn off for photo_hemis setup + + segmentation: False + distance: True + bias_field: False + registration: False + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + contrastive: False + + +# setups for generator 
+generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False # UNDER TESTING + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_hemis.yaml b/cfgs/generator/train/brain_id_hemis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6974da940743dd7676211096a4ad716384b0de00 --- /dev/null +++ b/cfgs/generator/train/brain_id_hemis.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg # NOTE: set to 'brainseg' for photo_hemis setup + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: False # NOTE: turn off for photo_hemis setup + + segmentation: True + distance: True + bias_field: True + registration: True + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + contrastive: False + 
+ +# setups for generator +generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: True # NOTE: True for photo_hemis setup + low_res_only: False + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_lowres.yaml b/cfgs/generator/train/brain_id_lowres.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ff8b5b2650e9237c7e5f3e3d5af2836e578e7f9 --- /dev/null +++ b/cfgs/generator/train/brain_id_lowres.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral # NOTE: set to 'brainseg' for photo_hemis setup + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: True # NOTE: turn off for photo_hemis setup + + segmentation: True + distance: True + bias_field: True + registration: True + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + 
contrastive: False + + +# setups for generator +generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False + low_res_only: True # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_lowres_shift.yaml b/cfgs/generator/train/brain_id_lowres_shift.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ff8b5b2650e9237c7e5f3e3d5af2836e578e7f9 --- /dev/null +++ b/cfgs/generator/train/brain_id_lowres_shift.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral # NOTE: set to 'brainseg' for photo_hemis setup + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: True # NOTE: turn off for photo_hemis setup + + segmentation: True + distance: True + bias_field: True + registration: True + + # downstream + super_resolution: False + age: False + + surface: False + 
pathology: False + contrastive: False + + +# setups for generator +generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False + low_res_only: True # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_reg.yaml b/cfgs/generator/train/brain_id_reg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98cad5ab55babdba5c79ea2ec13c56acf65417ca --- /dev/null +++ b/cfgs/generator/train/brain_id_reg.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral + +# setups for training/testing tasks +task: + T1: False + T2: False + FLAIR: False + CT: False # NOTE: turn off for photo_hemis setup + + segmentation: False + distance: False + bias_field: False + registration: True + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + contrastive: False + + +# setups for generator 
+generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False # UNDER TESTING + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_seg.yaml b/cfgs/generator/train/brain_id_seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17a4cf8bb054c59551cb28c40281c254f7137917 --- /dev/null +++ b/cfgs/generator/train/brain_id_seg.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral + +# setups for training/testing tasks +task: + T1: False + T2: False + FLAIR: False + CT: False # NOTE: turn off for photo_hemis setup + + segmentation: True + distance: False + bias_field: False + registration: False + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + contrastive: False + + +# setups for generator 
+generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False # UNDER TESTING + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_shift.yaml b/cfgs/generator/train/brain_id_shift.yaml new file mode 100644 index 0000000000000000000000000000000000000000..64f8ae657a96de6fbca0a8fc62ef8a35152323f7 --- /dev/null +++ b/cfgs/generator/train/brain_id_shift.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral # NOTE: set to 'brainseg' for photo_hemis setup + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: True # NOTE: turn off for photo_hemis setup + + segmentation: True + distance: True + bias_field: True + registration: True + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + 
contrastive: False + + +# setups for generator +generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: True + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_sr.yaml b/cfgs/generator/train/brain_id_sr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f4cfce3bf34bb2d5fd1dc6d5fd52b8ef96beb17 --- /dev/null +++ b/cfgs/generator/train/brain_id_sr.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test, train_age + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral + +# setups for training/testing tasks +task: + T1: False + T2: False + FLAIR: False + CT: False # NOTE: turn off for photo_hemis setup + + segmentation: False + distance: False + bias_field: False + registration: False + + # downstream + super_resolution: True + age: False + + surface: False + pathology: False + contrastive: False + + +# setups for 
generator +generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False # UNDER TESTING + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/brain_id_synth.yaml b/cfgs/generator/train/brain_id_synth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..674d67a42df4dda2a392e7d04df8fdc25ff57ecb --- /dev/null +++ b/cfgs/generator/train/brain_id_synth.yaml @@ -0,0 +1,126 @@ +device_generator: + +split: train # NOTE: train or test + +dataset_names: [] # NOTE: None for all, 'ADHD200' for age estimation +dataset_probs: +modality_probs: { + 'ADHD200': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'HCP': {'T1': 0.3333, 'T2': 0.6667, 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, + 'OASIS3': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0., 'CT': 0.75, 'synth': 1.}, + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ADNI3': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0.6667, 'CT': 0., 'synth': 1.}, + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, + + # TODO + 'ABIDE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Buckner40': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'COBRE': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'Chinese-HCP': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'ISBI2015': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, + 'MCIC': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, +} +mix_synth_prob: 0.2 # blend synth with real images +dataset_option: brain_id +segment_prefix: brainseg_with_extracerebral + +# setups for training/testing tasks +task: + T1: True + T2: True + FLAIR: True + CT: True # NOTE: turn off for photo_hemis setup + + segmentation: False + distance: False + bias_field: False + registration: False + + # downstream + super_resolution: False + age: False + + surface: False + pathology: False + contrastive: False + + +# setups for generator 
+generator: + + size: [160, 160, 160] # [128, 128, 128], [160, 160, 160] + + left_hemis_only: False # UNDER TESTING + low_res_only: False # UNDER TESTING + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 + noise_std_max: 1. + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0.5 + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 0 # 2 + all_samples: 1 # 4 + all_contrasts: 1 # 4 # >= 1, <= all_samples + num_deformations: 1 + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.05 # 0.05 # 5 + noise_std_max: 1. 
# 0.15 # 15 \ No newline at end of file diff --git a/cfgs/generator/train/inpaint.yaml b/cfgs/generator/train/inpaint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a6e259eae1e7ac51c6dd55b6ce742b56fef1972 --- /dev/null +++ b/cfgs/generator/train/inpaint.yaml @@ -0,0 +1,134 @@ +device_generator: #cuda:1 + +split: train # train or test + +dataset_names: ['ADHD'] # list of datasets +dataset_probs: +modality_probs: { + 'ADHD': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy + 'HCP': {'T1': 0., 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, # healthy + 'OASIS': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0., 'CT': 0.6667, 'synth': 1.}, # healthy + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy / wmh + 'ADNI3': {'T1': 0., 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # wmh + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # stroke + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, # isles +} +mix_synth_prob: 0. 
# blend synth with real images +dataset_option: brain_id + +# setups for training/testing tasks +task: + T1: True + T2: False + FLAIR: False + CT: False + pathology: True + + super_resolution: False + segmentation: False + registration: False + surface: False + distance: False + bias_field: False + contrastive: False + + +# setups for augmentation functions to apply +augmentation_steps: ['gamma', 'bias_field', 'resample', 'noise'] + +# setups for generator +generator: + + #size: [128, 128, 128] + size: [100, 100, 100] + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0. + + pathology_prob: 1. # pathology_prob when synth + random_shape_prob: 1. # initialize pathol shape from random noise (v.s. 
existing shapes) + augment_pathology: True + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 2 + all_samples: 4 + all_contrasts: 4 # >= 1, <= all_samples + num_deformations: 1 + + + + +pathology_shape_generator: + perlin_res: [2, 2, 2] # shape must be a multiple of res + mask_percentile_min: 85 + mask_percentile_max: 99.6 + integ_method: dopri5 # choices=['dopri5', 'adams', 'rk4', 'euler'] + bc: neumann # choices=['neumann', 'cauchy', 'dirichlet', 'source_neumann', 'dirichlet_neumann'] + V_multiplier: 500 + dt: 0.1 + max_nt: 10 # >= 2 + pathol_thres: 0.2 + pathol_tol: 0.000001 # if pathol mean < tol, skip + + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 \ No newline at end of file diff --git a/cfgs/generator/train/shape_id.yaml b/cfgs/generator/train/shape_id.yaml new file mode 100644 index 0000000000000000000000000000000000000000..379cc05a6b3cd68a273eeb5a1b17fc251ade0bb8 --- /dev/null +++ b/cfgs/generator/train/shape_id.yaml @@ -0,0 +1,134 @@ +device_generator: cuda:1 + +split: train # train or test + +dataset_names: ['ADHD'] # list of datasets +dataset_probs: +modality_probs: { + 'ADHD': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy + 'HCP': {'T1': 0., 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, # healthy + 'OASIS': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0., 'CT': 0.6667, 'synth': 1.}, # healthy + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy / wmh + 'ADNI3': {'T1': 0., 
'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # wmh + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # stroke + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, # isles +} +mix_synth_prob: 0. # blend synth with real images +dataset_option: brain_id + +# setups for training/testing tasks +task: + T1: False + T2: False + FLAIR: False + CT: False + pathology: True + + super_resolution: False + segmentation: False + registration: False + surface: False + distance: False + bias_field: False + contrastive: False + + +# setups for augmentation functions to apply +augmentation_steps: ['gamma', 'bias_field', 'resample', 'noise'] + +# setups for generator +generator: + + #size: [128, 128, 128] + size: [100, 100, 100] + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0. + + pathology_prob: 1. # pathology_prob when synth + random_shape_prob: 1. # initialize pathol shape from random noise (v.s. 
existing shapes) + augment_pathology: True + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 2 + all_samples: 4 + all_contrasts: 4 # >= 1, <= all_samples + num_deformations: 1 + + + + +pathology_shape_generator: + perlin_res: [2, 2, 2] # shape must be a multiple of res + mask_percentile_min: 85 + mask_percentile_max: 99.6 + integ_method: dopri5 # choices=['dopri5', 'adams', 'rk4', 'euler'] + bc: neumann # choices=['neumann', 'cauchy', 'dirichlet', 'source_neumann', 'dirichlet_neumann'] + V_multiplier: 500 + dt: 0.1 + max_nt: 10 # >= 2 + pathol_thres: 0.2 + pathol_tol: 0.000001 # if pathol mean < tol, skip + + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 \ No newline at end of file diff --git a/cfgs/generator/train/twostage.yaml b/cfgs/generator/train/twostage.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a6e259eae1e7ac51c6dd55b6ce742b56fef1972 --- /dev/null +++ b/cfgs/generator/train/twostage.yaml @@ -0,0 +1,134 @@ +device_generator: #cuda:1 + +split: train # train or test + +dataset_names: ['ADHD'] # list of datasets +dataset_probs: +modality_probs: { + 'ADHD': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy + 'HCP': {'T1': 0., 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy + 'AIBL': {'T1': 0.25, 'T2': 0.5, 'FLAIR': 0.75, 'CT': 0., 'synth': 1.}, # healthy + 'OASIS': {'T1': 0.3333, 'T2': 0., 'FLAIR': 0., 'CT': 0.6667, 'synth': 1.}, # healthy + 'ADNI': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # healthy / wmh + 'ADNI3': {'T1': 
0., 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # wmh + 'ATLAS': {'T1': 0.5, 'T2': 0., 'FLAIR': 0., 'CT': 0., 'synth': 1.}, # stroke + 'ISLES': {'T1': 0., 'T2': 0., 'FLAIR': 0.5, 'CT': 0., 'synth': 1.}, # isles +} +mix_synth_prob: 0. # blend synth with real images +dataset_option: brain_id + +# setups for training/testing tasks +task: + T1: True + T2: False + FLAIR: False + CT: False + pathology: True + + super_resolution: False + segmentation: False + registration: False + surface: False + distance: False + bias_field: False + contrastive: False + + +# setups for augmentation functions to apply +augmentation_steps: ['gamma', 'bias_field', 'resample', 'noise'] + +# setups for generator +generator: + + #size: [128, 128, 128] + size: [100, 100, 100] + + photo_prob: 0.2 + max_rotation: 15 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 4 + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + random_shift: False + deform_one_hots: False + integrate_deformation_fields: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + ct_prob: 0 + flip_prob: 0. + + pathology_prob: 1. # pathology_prob when synth + random_shape_prob: 1. # initialize pathol shape from random noise (v.s. 
existing shapes) + augment_pathology: True + + + # brain-id customized setups + + # mild-to-severe intra-subject aug params + mild_samples: 2 + all_samples: 4 + all_contrasts: 4 # >= 1, <= all_samples + num_deformations: 1 + + + + +pathology_shape_generator: + perlin_res: [2, 2, 2] # shape must be a multiple of res + mask_percentile_min: 85 + mask_percentile_max: 99.6 + integ_method: dopri5 # choices=['dopri5', 'adams', 'rk4', 'euler'] + bc: neumann # choices=['neumann', 'cauchy', 'dirichlet', 'source_neumann', 'dirichlet_neumann'] + V_multiplier: 500 + dt: 0.1 + max_nt: 10 # >= 2 + pathol_thres: 0.2 + pathol_tol: 0.000001 # if pathol mean < tol, skip + + + + +# brain-id customized setups + +mild_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. + noise_std_max: 0.02 + + +severe_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 5 + noise_std_max: 15 \ No newline at end of file diff --git a/cfgs/submit.yaml b/cfgs/submit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2012eeec7529ec3485bc416e60302f60ac23eceb --- /dev/null +++ b/cfgs/submit.yaml @@ -0,0 +1,46 @@ +# Number of gpus to request on each node +num_gpus: 1 +num_workers: 0 +vram: 12GB +# memory allocated per GPU in GB +mem_per_gpu: 256 +# Number of nodes to request +nodes: 1 +# Duration of the job +timeout: 4320 +# Job dir. Leave empty for automatic. +job_dir: '' +# Use to run jobs locally. ('debug', 'local', 'slurm') +cluster: debug +# Partition. Leave empty for automatic. +slurm_partition: '' +# Constraint. Leave empty for automatic. 
+slurm_constraint: '' +slurm_comment: '' +slurm_gres: '' +slurm_exclude: '' +cpus_per_task: 4 + +# devices +partition: "rtx8000" # options: {"rtx8000", "dgx-a100"} +shard_id: 0 # int: shard id for the current machine. Starts from 0 to num_shards - 1. If single machine is used, then set shard id to 0. +num_shards: 1 # int: number of shards using by the job. +init_method: "tcp://localhost:9999" # "tcp://localhost:9999" # str: initialization method to launch the job with multiple + # devices. Options includes TCP or shared file-system for initialization. + # details can be find in https://pytorch.org/docs/stable/distributed.html#tcp-initialization +opts: # list: provide addtional options from the command line, it overwrites the config loaded from file. +dist_backend: "nccl" +# url used to set up distributed training +dist_url: env:// + +rank: 0 +gpu: 0 # current gpu id +seed: 42 +world_size: 1 # number of distributed processes + + + + +device: + + diff --git a/cfgs/trainer/default_train.yaml b/cfgs/trainer/default_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c4f87ff8c5a3d76631e43cb24ac5c0409e9810f4 --- /dev/null +++ b/cfgs/trainer/default_train.yaml @@ -0,0 +1,161 @@ + +job_name: default +exp_name: default + +## paths +out_root: lemon # yogurt or lemon +root_dir_lemon: /autofs/vast/lemon/temp_stuff/peirong # mlsc run +root_dir_yogurt_out: /autofs/space/yogurt_002/users/pl629 # desktop run +out_dir: results/MTBrainID + + +supervised_seg_cnn_weights: /autofs/space/yogurt_002/users/pl629/ckp/supervised_segmentation_model.pth +feat_ext_ckp_path: +ckp_path: + +supervised_pathol_seg_ckp_path: {'feat': '/autofs/space/yogurt_002/users/pl629/ckp/Supv/supv_adni3_flair2pathol_feat_epoch_35.pth', + 'task': '/autofs/space/yogurt_002/users/pl629/ckp/Supv/supv_adni3_flair2pathol_epoch_35.pth'} +supervised_aux_pathol_seg_ckp_path: {'feat': '', + 'task': ''} + + + +device_segmenter: + + +task_f_maps: [64] +unit_feat: True + + +## losses and weights 
+losses: + image_grad: True + registration_grad: True + registration_smooth: False + registration_hessian: False + bias_field_log_type: l2 # l1 or l2 + # archived # + image_grad_mask: False + uncertainty: # only for recon or sr (regression tasks); options: { gaussian, laplace } + implicit_pathol: False + +weights: + seg_ce: 1. + seg_dice: 1. + pathol_ce: 1. + pathol_dice: 1. + implicit_pathol_ce: 1. + implicit_pathol_dice: 1. + dist: 1. + image: 1. + image_grad: 1. + seg_supervised: 1. + bias_field_log: 1. + reg: 1. + reg_grad: 1. + contrastive: 1. + + age: 1. + surface: 1. + distance: 1. + registration: 1. + registration_grad: 1. + registration_smooth: 1. + registration_hessian: 1. + + +## training params +start_epoch: 0 +train_itr_limit: # if not None, it sets the max itr per epoch +train_subset: #0.2 +train_txt: + +resume: False +reset_epoch: False +resume_optim: True +resume_lr_scheduler: True +freeze_feat: False + + +batch_size: 1 + +n_epochs: 400 +lr_scheduler: multistep # cosine, multistep +lr_drops: [250,300] +lr_drop_multi: 0.1 +lr: 0.0001 # Learning rate at the end of linear warmup (highest LR used during training) +min_lr: 0.000001 # Target LR at the end of optimization. We use a cosine LR schedule with linear warmup +warmup_epochs: 1 # Number of epochs for the linear learning-rate warm up + +feat_opt: + lr: 0.0001 # Learning rate at the end of linear warmup (highest LR used during training) + min_lr: 0.000001 # Target LR at the end of optimization. We use a cosine LR schedule with linear warmup + + +optimizer: adamw # adam, adamw +weight_decay: 0 # 0.04 # Final value of the weight decay. A cosine schedule for WD and using a larger decay by the end of training improves performance for ViTs. +weight_decay_end: 0 # 0.4 # Final value of the weight decay. A cosine schedule for WD and using a larger decay by the end of training improves performance for ViTs. +momentum: 1 # 1 as disabling momentum + +# gradient clipping max norm +clip_max_norm: 0. 
+freeze_last_layer: 0 #Number of epochs during which we keep the output layer fixed. Typically doing so during the first epoch helps training. Try increasing this value if the loss does not decrease + + + +## testing params +eval_only: False +debug: False +test_itr_limit: +test_subset: +test_txt: + + + +################################ +########### Backbone ########### +################################ + +condition: # mask, flip, mask+flip + +backbone: unet3d # options: unet2d, unet3d, unet3d_sep, unet3d+unet3d + + +### UNet setting ### +in_channels: 1 +f_maps: 64 +f_maps_supervised_seg_cnn: 32 +num_groups: 8 +num_levels: 5 +layer_order: 'gcl' +final_sigmoid: False + + +### DDPM_Pseudo3D setting ### +ema_rate: 0.9999 +use_fp16: False +fp16_scale_growth: 0.001 + + + + + + +## UNDER TESTING ## +relative_weight_lesions: 1.0 # for now... + + +## visualizer params +visualizer: + make_results: False + save_image: False + spacing: [8, 8, 8] + + feat_vis: True + feat_vis_num: 10 # fixed number of feature channels to plot + +## logging intervals +val_epoch: 100000 # must be multiplicable by save_model_epoch + +log_itr: 100 +vis_itr: 1000 \ No newline at end of file diff --git a/cfgs/trainer/default_val.yaml b/cfgs/trainer/default_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..11038c3477a8f92b6a9935f32527605165470dd1 --- /dev/null +++ b/cfgs/trainer/default_val.yaml @@ -0,0 +1,126 @@ +# testing set up # + + +test_pass: 1 # 1 or 2 + +test_itr_limit: 1 # n_subjects +test_mild_samples: 1 +test_all_samples: 1 # n_samples within each subject + +max_test_win_size: [220, 220, 220] +test_win_partition: False + + +#### IF we want to augment the testing dataset: + +base_test_generator: + + data_augmentation: True + + apply_deformation: True + nonlinear_transform: False + integrate_deformation_fields: True + + # below setups are effective ONLY IF data_augmentation is True: + + apply_gamma_transform: True + apply_bias_field: True + apply_resampling: True + 
hyperfine_prob: 0. + apply_noises: True + + ######### + ct_prob: 0. + + pathology_prob: 0. + pathology_thres_max: 1. + pathology_mu_multi: 500. + pathology_sig_multi: 50. + + + noise_std_min: 0.01 # 5 # should be small if inputs are real images + noise_std_max: 0.1 # 15 + + ############################ + + ## synth + label_list_segmentation_with_csf: [0,14,15,16,24,77,85, 2, 3, 4, 7, 8, 10,11,12,13,17,18,26,28, 41,42,43,46,47,49,50,51,52,53,54,58,60] + n_neutral_labels_with_csf: 7 + label_list_segmentation_without_csf: [0,14,15,16,77,85, 2, 3, 4, 7, 8, 10,11,12,13,17,18,26,28, 41,42,43,46,47,49,50,51,52,53,54,58,60] + n_neutral_labels_without_csf: 6 + + + ## synth_hemi + # without cerebellum and brainstem + label_list_segmentation: [0, 2, 3, 4, 10, 11, 12, 13, 17, 18, 26, 28, 77] + n_neutral_labels: 6 + + # with cerebellum and brainstem + label_list_segmentation_with_cb: [0, 2, 3, 4, 7, 8, 10, 11, 12, 13, 16, 17, 18, 26, 28, 77] + + max_surf_distance: 2.0 # clamp at plus / minus this number (both the ground truth and the prediction) + + size: [128, 128, 128] + photo_prob: 0.2 + max_rotation: 10 + max_shear: 0.2 + max_scaling: 0.2 + nonlin_scale_min: 0.03 + nonlin_scale_max: 0.06 + nonlin_std_max: 2 + + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.1 + gamma_std: 0.05 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + pv: True + deform_one_hots: False + produce_surfaces: False + bspline_zooming: False + n_steps_svf_integration: 8 + nonlinear_transform: True + + +#### For ID-Synth #### + +## mild generator set up +mild_test_generator: + bag_prob: 0.1 + bag_scale_min: 0.01 + bag_scale_max: 0.02 + bf_scale_min: 0.01 + bf_scale_max: 0.02 + bf_std_min: 0. + bf_std_max: 0.02 + gamma_std: 0.01 + noise_std_min: 0. 
+ noise_std_max: 0.02 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + +## severe generator set up +# real data should not have too severe corruptions as synthetic data +severe_test_generator: + bag_prob: 0.5 + bag_scale_min: 0.02 + bag_scale_max: 0.08 + bf_scale_min: 0.02 + bf_scale_max: 0.04 + bf_std_min: 0.1 + bf_std_max: 0.6 + gamma_std: 0.1 + noise_std_min: 0.1 + noise_std_max: 0.5 + exvixo_prob: 0.25 + exvixo_prob_vs_photo: 0.66666666666666 + + + + diff --git a/cfgs/trainer/test/demo_test.yaml b/cfgs/trainer/test/demo_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8193437098c2f638446e8df0a7beac81d95d466 --- /dev/null +++ b/cfgs/trainer/test/demo_test.yaml @@ -0,0 +1,10 @@ +## job specific set ups ## +init_method: "tcp://localhost:9999" + +eval_only: True +debug: False + + +backbone: unet3d #+unet3d # options: unet2d, unet3d, unet3d_sep, unet3d+unet3d +num_levels: 6 # 5 (1024), 6 (2054), 7 (4096) +condition: #mask #+flip # mask, flip, mask+flip diff --git a/cfgs/trainer/train/joint.yaml b/cfgs/trainer/train/joint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..368cb98e9c87fc5f0d421c50522a4584fa633ea2 --- /dev/null +++ b/cfgs/trainer/train/joint.yaml @@ -0,0 +1,26 @@ +## job specific set ups ## +exp_name: wosr_reggrad #wosr_reggrad #hemis #wosr_reggrad #wosr_reggrad_lowres # age_pool +job_name: l6_16 +init_method: "tcp://localhost:9998" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad/l6_16/1128-0936/ckp/checkpoint_latest.pth + + +n_epochs: 5000 +#lr_drops: [1600] # [70, 90] # [120] # [70, 90] +lr_drops: [2500] # [70, 90] # [120] # [70, 90] +#lr_drops: [660] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2048), 7 (4096) \ No newline at end of file diff --git 
a/cfgs/trainer/train/joint_age.yaml b/cfgs/trainer/train/joint_age.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13a168bd02d335795967b0103f592592f668ee6b --- /dev/null +++ b/cfgs/trainer/train/joint_age.yaml @@ -0,0 +1,24 @@ +## job specific set ups ## +exp_name: age_rigid #wosr_reggrad #hemis # wosr_reggrad_lowres # age_pool +job_name: l6_16 +init_method: "tcp://localhost:9981" + +eval_only: False +debug: False + +resume: True +reset_epoch: True # NOTE +resume_optim: False +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/age_pool/l6_16/0729-1611/ckp/checkpoint_latest.pth + + +n_epochs: 5000 +lr_drops: [2000] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2054), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_bf.yaml b/cfgs/trainer/train/joint_bf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c540bdcd293859842099aba9849ef38c0991b5e8 --- /dev/null +++ b/cfgs/trainer/train/joint_bf.yaml @@ -0,0 +1,25 @@ +## job specific set ups ## +exp_name: bf +job_name: l6_16 +init_method: "tcp://localhost:9992" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/bf/l6_16/0806-1026/ckp/checkpoint_latest.pth + + + +n_epochs: 1600 +lr_drops: [1200] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2054), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_dist.yaml b/cfgs/trainer/train/joint_dist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..142cf34df649c8220787e4ffc1b561834c51f10a --- /dev/null +++ b/cfgs/trainer/train/joint_dist.yaml @@ -0,0 +1,25 @@ +## job specific set ups ## +exp_name: dist #wosr_reggrad #hemis # wosr_reggrad_lowres # 
age_pool +job_name: l6_16 +init_method: "tcp://localhost:9993" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/dist/l6_16/0806-1024/ckp/checkpoint_latest.pth + + + +n_epochs: 2000 +lr_drops: [1200] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2054), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_hemis.yaml b/cfgs/trainer/train/joint_hemis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..90ab69c944fcc17c034f892679d17ae1ba92abd6 --- /dev/null +++ b/cfgs/trainer/train/joint_hemis.yaml @@ -0,0 +1,27 @@ +## job specific set ups ## +exp_name: hemis #wosr_reggrad #hemis # wosr_reggrad_lowres # age_pool +job_name: l6_16 +init_method: "tcp://localhost:9991" + +eval_only: False +debug: False + +resume: True +reset_epoch: True +resume_optim: False +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad/l6_16/0904-1645/ckp/checkpoint_latest.pth + + + +n_epochs: 5000 +#lr_drops: [1600] # [70, 90] # [120] # [70, 90] +lr_drops: [2000] # [70, 90] # [120] # [70, 90] +#lr_drops: [660] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2048), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_lowres.yaml b/cfgs/trainer/train/joint_lowres.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27ab18c58a50198c365de16f3ec9bb4921bafc41 --- /dev/null +++ b/cfgs/trainer/train/joint_lowres.yaml @@ -0,0 +1,26 @@ +## job specific set ups ## +exp_name: wosr_reggrad_lowres #wosr_reggrad #hemis #wosr_reggrad #wosr_reggrad_lowres # age_pool +job_name: l6_16 +init_method: "tcp://localhost:9999" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True 
+ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad_lowres/l6_16/1101-1744/ckp/checkpoint_latest.pth + + +n_epochs: 5000 +#lr_drops: [1600] # [70, 90] # [120] # [70, 90] +lr_drops: [2500] # [70, 90] # [120] # [70, 90] +#lr_drops: [660] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2048), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_lowres_shift.yaml b/cfgs/trainer/train/joint_lowres_shift.yaml new file mode 100644 index 0000000000000000000000000000000000000000..efd470c1588d853486f6f97e8fcbdc38916f86b3 --- /dev/null +++ b/cfgs/trainer/train/joint_lowres_shift.yaml @@ -0,0 +1,26 @@ +## job specific set ups ## +exp_name: wosr_reggrad_lowres_shift #wosr_reggrad #hemis #wosr_reggrad #wosr_reggrad_lowres # age_pool +job_name: l6_16 +init_method: "tcp://localhost:9989" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad_lowres_shift/l6_16/1102-1007/ckp/checkpoint_latest.pth + + +n_epochs: 5000 +#lr_drops: [1600] # [70, 90] # [120] # [70, 90] +lr_drops: [3000] # [70, 90] # [120] # [70, 90] +#lr_drops: [660] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2048), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_reg.yaml b/cfgs/trainer/train/joint_reg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..294f12fb0ff7fcd7ab2a4a2efbac2be24af3f6ea --- /dev/null +++ b/cfgs/trainer/train/joint_reg.yaml @@ -0,0 +1,25 @@ +## job specific set ups ## +exp_name: reg +job_name: l6_16 +init_method: "tcp://localhost:9991" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: 
/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/reg/l6_16/0806-1029/ckp/checkpoint_latest.pth + + + +n_epochs: 1600 +lr_drops: [1200] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2054), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_seg.yaml b/cfgs/trainer/train/joint_seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1dc8558812736e04d6d33054d58c98fcb4e85c8 --- /dev/null +++ b/cfgs/trainer/train/joint_seg.yaml @@ -0,0 +1,25 @@ +## job specific set ups ## +exp_name: seg +job_name: l6_16 +init_method: "tcp://localhost:9990" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/seg/l6_16/0806-1021/ckp/checkpoint_latest.pth + + + +n_epochs: 2000 +lr_drops: [1200] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2054), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_shift.yaml b/cfgs/trainer/train/joint_shift.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56b003de83e894031f591afd37010e9518a0c886 --- /dev/null +++ b/cfgs/trainer/train/joint_shift.yaml @@ -0,0 +1,26 @@ +## job specific set ups ## +exp_name: wosr_reggrad_shift #wosr_reggrad #hemis #wosr_reggrad #wosr_reggrad_lowres # age_pool +job_name: l6_16 +init_method: "tcp://localhost:9988" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad_shift/l6_16/1102-1006/ckp/checkpoint_latest.pth + + +n_epochs: 5000 +#lr_drops: [1600] # [70, 90] # [120] # [70, 90] +lr_drops: [3000] # [70, 90] # [120] # [70, 90] +#lr_drops: [660] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, 
mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2048), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_sr.yaml b/cfgs/trainer/train/joint_sr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..415d1905f098a761391cd4e3a1bd79cb71402065 --- /dev/null +++ b/cfgs/trainer/train/joint_sr.yaml @@ -0,0 +1,25 @@ +## job specific set ups ## +exp_name: sr #wosr_reggrad #hemis # wosr_reggrad_lowres # age_pool +job_name: l6_16 +init_method: "tcp://localhost:9997" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/sr/l6_16/0926-2035/ckp/checkpoint_latest.pth + + +n_epochs: 1600 +#lr_drops: [1200] # [70, 90] # [120] # [70, 90] +lr_drops: [1600] # [70, 90] # [120] # [70, 90] # lowres + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2054), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/joint_synth.yaml b/cfgs/trainer/train/joint_synth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..beed0959ed39bb4b39f1e4e0c81afee5ef41099f --- /dev/null +++ b/cfgs/trainer/train/joint_synth.yaml @@ -0,0 +1,25 @@ +## job specific set ups ## +exp_name: synth +job_name: l6_16 +init_method: "tcp://localhost:9993" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/synth/l6_16/0826-1637/ckp/checkpoint_latest.pth + + + +n_epochs: 2000 +lr_drops: [800] # [70, 90] # [120] # [70, 90] + +condition: #mask+flip # mask, flip, mask+flip + +log_itr: 200 +vis_itr: 6000 + + +num_levels: 6 # 5 (1024), 6 (2054), 7 (4096) \ No newline at end of file diff --git a/cfgs/trainer/train/sep.yaml b/cfgs/trainer/train/sep.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c289124429700e6f3f8746ee5442b12e053d4355 
--- /dev/null +++ b/cfgs/trainer/train/sep.yaml @@ -0,0 +1,22 @@ +## job specific set ups ## +exp_name: shape-random-disentangle +job_name: adhd-synth +init_method: "tcp://localhost:9999" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +ckp_path: /autofs/space/yogurt_002/users/pl629/results/BrainID/shape-random-disentangle/adhd-synth/0426-1110/ckp/checkpoint_latest.pth + +n_epochs: 1500 +lr_drops: [] # [70, 90] # [120] # [70, 90] + +# Backbone +backbone: unet3d_sep + + +#log_itr: 1 +#vis_itr: 1 \ No newline at end of file diff --git a/cfgs/trainer/train/twostage.yaml b/cfgs/trainer/train/twostage.yaml new file mode 100644 index 0000000000000000000000000000000000000000..790060c2d1277245e87e979f2607eb94a997e790 --- /dev/null +++ b/cfgs/trainer/train/twostage.yaml @@ -0,0 +1,24 @@ +## job specific set ups ## +exp_name: two-stage-inpaint-masked +job_name: adhd +init_method: "tcp://localhost:9997" + +eval_only: False +debug: False + +resume: True +reset_epoch: False +resume_optim: True +pathol_ckp_path: /autofs/space/yogurt_002/users/pl629/results/BrainID/two-stage-inpaint-masked/adhd/0430-1027/ckp/checkpoint_latest_pathol.pth +task_ckp_path: /autofs/space/yogurt_002/users/pl629/results/BrainID/two-stage-inpaint-masked/adhd/0430-1027/ckp/checkpoint_latest_task.pth + +n_epochs: 1500 +lr_drops: [] # [70, 90] # [120] # [70, 90] + +backbone: unet3d+unet3d # options: unet2d, unet3d, unet3d_sep, unet3d+unet3d + +condition: mask+flip # mask, flip, mask+flip + +#log_itr: 1 +#vis_itr: 1 + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1257183544f2651dbff1ac1ac6a002d5ac64a137 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +versioned-hdf5==1.6.0 +iopath==0.1.10 +matplotlib==3.7.1 +nibabel==5.2.0 +numpy==1.24.3 +pandas==1.5.3 +Pillow==9.4.0 +pytest==7.4.0 +pytorch_msssim==1.0.0 +pytz==2022.7 +PyYAML==6.0 +scipy==1.10.1 +seaborn==0.13.0 +setuptools==68.0.0 
+SimpleITK==2.3.0 +simplejson==3.19.1 +scikit-image==0.20.0 +tabulate==0.8.10 +torch==2.0.1 +torchvision==0.15.2 +tqdm==4.65.0 +visdom==0.2.4 diff --git a/scripts/.DS_Store b/scripts/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/scripts/.DS_Store differ diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/demo_generator.py b/scripts/demo_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..2b105b5102e60ec7b1483527e67ff56bc70c4538 --- /dev/null +++ b/scripts/demo_generator.py @@ -0,0 +1,124 @@ +############################### +#### Synthetic Data Demo #### +############################### + + +import datetime +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import time + +import torch + +import utils.misc as utils + + +from Generator import build_datasets + + + +# default & gpu cfg # +default_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/default.yaml' +demo_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_synth.yaml' + + +def map_back_orig(img, idx, shp): + if idx is None or shp is None: + return img + if len(img.shape) == 3: + img = img[None, None] + elif len(img.shape) == 4: + img = img[None] + return img[:, :, idx[0]:idx[0] + shp[0], idx[1]:idx[1] + shp[1], idx[2]:idx[2] + shp[2]] + + +def generate(args): + + _, gen_args, _ = args + + if gen_args.device_generator: + device = gen_args.device_generator + elif torch.cuda.is_available(): + device = torch.cuda.current_device() + else: + device = 'cpu' + print('device: %s' % device) + + print('out_dir:', gen_args.out_dir) + + # ============ preparing data ... 
============ + dataset_dict = build_datasets(gen_args, device = gen_args.device_generator if gen_args.device_generator is not None else device) + dataset = dataset_dict[gen_args.dataset_names[0]] + + tasks = [key for (key, value) in vars(gen_args.task).items() if value] + + print("Start generating") + start_time = time.time() + + + dataset.mild_samples = gen_args.mild_samples + dataset.all_samples = gen_args.all_samples + for itr in range(min(gen_args.test_itr_limit, len(dataset.names))): + + subj_name = os.path.basename(dataset.names[itr]).split('.nii')[0] + + save_dir = utils.make_dir(os.path.join(gen_args.out_dir, subj_name)) + + print('Processing image (%d/%d): %s' % (itr, len(dataset), dataset.names[itr])) + + for i_deform in range(gen_args.num_deformations): + def_save_dir = utils.make_dir(os.path.join(save_dir, 'deform-%s' % i_deform)) + + (_, subjects, samples) = dataset.__getitem__(itr) + + if 'aff' in subjects: + aff = subjects['aff'] + shp = subjects['shp'] + loc_idx = subjects['loc_idx'] + else: + aff = torch.eye((4)) + shp = loc_idx = None + + print('num samples:', len(samples)) + print(' deform:', i_deform) + + #print(subjects.keys()) + + if 'T1' in subjects: + utils.viewVolume(subjects['T1'], aff, names = ['T1'], save_dir = def_save_dir) + if 'T2' in subjects: + utils.viewVolume(subjects['T2'], aff, names = ['T2'], save_dir = def_save_dir) + if 'FLAIR' in subjects: + utils.viewVolume(subjects['FLAIR'], aff, names = ['FLAIR'], save_dir = def_save_dir) + if 'CT' in subjects: + utils.viewVolume(subjects['CT'], aff, names = ['CT'], save_dir = def_save_dir) + if 'pathology' in tasks: + utils.viewVolume(subjects['pathology'], aff, names = ['pathology'], save_dir = def_save_dir) + if 'segmentation' in tasks: + utils.viewVolume(subjects['segmentation']['label'], aff, names = ['label'], save_dir = def_save_dir) + + for i_sample, sample in enumerate(samples): + print(' sample:', i_sample) + sample_save_dir = utils.make_dir(os.path.join(def_save_dir, 
'sample-%s' % i_sample)) + + #print(sample.keys()) + + if 'input' in sample: + utils.viewVolume(map_back_orig(sample['input'], loc_idx, shp), aff, names = ['input'], save_dir = sample_save_dir) + if 'super_resolution' in tasks: + utils.viewVolume(map_back_orig(sample['orig'], loc_idx, shp), aff, names = ['high_reso'], save_dir = sample_save_dir) + if 'bias_field' in tasks: + utils.viewVolume(map_back_orig(torch.exp(sample['bias_field_log']), loc_idx, shp), aff, names = ['bias_field'], save_dir = sample_save_dir) + + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Generation time {}'.format(total_time_str)) + + +##################################################################################### + + +if __name__ == '__main__': + gen_args = utils.preprocess_cfg([default_gen_cfg_file, demo_gen_cfg_file]) + utils.launch_job(submit_cfg = None, gen_cfg = gen_args, train_cfg = None, func = generate) \ No newline at end of file diff --git a/scripts/test.py b/scripts/test.py new file mode 100644 index 0000000000000000000000000000000000000000..079fafc1eba5bfe0e346f9c8fda5eef43aef9a02 --- /dev/null +++ b/scripts/test.py @@ -0,0 +1,233 @@ +############################### +#### Brani-ID Inference ##### +############################### + +import os, sys, warnings, shutil, glob, time, datetime +warnings.filterwarnings("ignore") +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from collections import defaultdict + +import torch +import numpy as np + +from utils.misc import make_dir, viewVolume, MRIread +import utils.test_utils as utils +from Generator.utils import fast_3D_interp_torch + +device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + + +################################# +### For hemisphere prediction ### +################################# +label_list_left_segmentation = [0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 17, 31, 34, 36, 38, 40, 42] +lut = 
torch.zeros(10000, dtype=torch.long, device=device) +for l in range(len(label_list_left_segmentation)): + lut[label_list_left_segmentation[l]] = l + +# get left hemis mask +#S = read_brainseg.nii +#S = lut[X.astype(np.int)] +#X = read_mni_coord_X +#M = (S > 0) & (X < 0) + +# apply hemis mask for all image I +#I[M==0] = 0 +################################# +################################# + + +def prepare_paths(data_root, split_txt): + + # Collect list of available images, per dataset + datasets = [] + g = glob.glob(os.path.join(data_root, '*' + 'T1w.nii')) + for i in range(len(g)): + filename = os.path.basename(g[i]) + dataset = filename[:filename.find('.')] + found = False + for d in datasets: + if dataset == d: + found = True + if found is False: + datasets.append(dataset) + print('Found ' + str(len(datasets)) + ' datasets with ' + str(len(g)) + ' scans in total') + print('Dataset list', datasets) + names = [] + + split_file = open(split_txt, 'r') + split_names = [] + for subj in split_file.readlines(): + split_names.append(subj.strip()) + + for i in range(len(datasets)): + names.append([name for name in split_names if os.path.basename(name).startswith(datasets[i])]) + + datasets_num = len(datasets) + datasets_len = [len(names[i]) for i in range(len(names))] + print('Num of testing data', sum([len(names[i]) for i in range(len(names))])) + + return names, datasets + + +def get_info(t1): + + t2 = t1[:-7] + 'T2w.nii' + flair = t1[:-7] + 'FLAIR.nii' + ct = t1[:-7] + 'CT.nii' + cerebral_labels = t1[:-7] + 'brainseg.nii' + segmentation_labels = t1[:-7] + 'brainseg_with_extracerebral.nii' + brain_dist_map = t1[:-7] + 'brain_dist_map.nii' + lp_dist_map = t1[:-7] + 'lp_dist_map.nii' + rp_dist_map = t1[:-7] + 'rp_dist_map.nii' + lw_dist_map = t1[:-7] + 'lw_dist_map.nii' + rw_dist_map = t1[:-7] + 'rw_dist_map.nii' + mni_reg_x = t1[:-7] + 'mni_reg.x.nii' + mni_reg_y = t1[:-7] + 'mni_reg.y.nii' + mni_reg_z = t1[:-7] + 'mni_reg.z.nii' + + modalities = {'T1': t1} + if 
os.path.isfile(t2): + modalities.update({'T2': t2}) + if os.path.isfile(flair): + modalities.update({'FLAIR': flair}) + if os.path.isfile(ct): + modalities.update({'CT': ct}) + + aux = {'label': segmentation_labels, 'cerebral_label': cerebral_labels, 'distance': brain_dist_map, + 'regx': mni_reg_x, 'regy': mni_reg_y, 'regz': mni_reg_z, + 'lp': lp_dist_map, 'lw': lw_dist_map, 'rp': rp_dist_map, 'rw': rw_dist_map} + + return modalities, aux + + +################################# + + +gen_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_test.yaml' +gen_hemis_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_test_hemis.yaml' +model_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/test/demo_test.yaml' + +#win_size = [192, 192, 192] +win_size = [160, 160, 160] +mask_output = False + + +exclude_keys = ['segmentation'] +data_root = '/autofs/vast/lemon/data_curated/brain_mris_QCed' +split_txt = '/autofs/vast/lemon/temp_stuff/peirong/train_test_split/test.txt' +names, datasets = prepare_paths(data_root, split_txt) + + +max_num_test_dataset = None #1 +max_num_per_dataset = None #5 + +zero_crop = False + +main_save_dir = make_dir('/autofs/space/yogurt_002/users/pl629/results/MTBrainID/test/', reset = False) + +models = [ + #('test_reggr', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad/l6_16/1025-1744/ckp/checkpoint_latest.pth'), + #('test_lowres', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad_lowres/l6_16/1025-1746/ckp/checkpoint_latest.pth'), + ('test_sr', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/sr/l6_16/0926-2035/ckp/checkpoint_latest.pth'), + #('test_sr_lowres', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/sr_lowres/l6_16/0926-2025/ckp/checkpoint_latest.pth'), + #('test_hemis', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/hemis/l6_16/0806-1008/ckp/checkpoint_latest.pth'), + + #('comp_synth', 
'/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/synth/l6_16/0924-0929/ckp/checkpoint_latest.pth'), + #('comp_dist', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/dist/l6_16/0806-1024/ckp/checkpoint_latest.pth'), + #('comp_reg', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/reg/l6_16/0806-1029/ckp/checkpoint_latest.pth'), + #('comp_bf', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/bf/l6_16/0806-1026/ckp/checkpoint_latest.pth'), +] + +#spacing = [1.5, 1.5, 5] # [1, 1, 1], [1.5, 1.5, 5], [3, 3, 3], None +#add_bf = False +setups = [ + #([1, 1, 1], False), + #([1, 1, 1], True), + ([1.5, 1.5, 5], False), + #([1.5, 1.5, 5], True), +] + + + +all_start_time = time.time() +for postfix, ckp_path in models: + + for spacing, add_bf in setups: + curr_postfix = postfix + '_BF' if add_bf else postfix + curr_postfix += '_%s-%s-%s' % (str(spacing[0]), str(spacing[1]), str(spacing[2])) if spacing is not None else '_1-1-1' + save_dir = make_dir(os.path.join(main_save_dir, curr_postfix), reset = True) + print('\nSave at: %s\n' % save_dir) + + curr_gen_cfg = gen_hemis_cfg if 'hemis' in postfix else gen_cfg + + + for i, curr_dataset in enumerate(names): + curr_dataset.sort() + print('Dataset: %s (%d/%d) -- %d total cases' % (datasets[i], i+1, len(datasets), len(curr_dataset))) + + #''' + if max_num_test_dataset is not None and i >= max_num_test_dataset: + break + + start_time = time.time() + for j, t1_name in enumerate(curr_dataset): + + if max_num_per_dataset is not None and j >= max_num_per_dataset: + break + + subj_name = os.path.basename(t1_name).split('.T1w')[0] + subj_dir = make_dir(os.path.join(save_dir, subj_name)) + print('Now testing: %s (%d/%d)' % (t1_name, j+1, len(curr_dataset))) + + modalities, aux = get_info(t1_name) + + S_cerebral = torch.squeeze(utils.prepare_image(aux['cerebral_label'], win_size = win_size, zero_crop = zero_crop, spacing = spacing, rescale = False, im_only = True, device = device)) # read seg map + + if 'hemis' 
in postfix: + S = utils.prepare_image(aux['cerebral_label'], win_size = win_size, zero_crop = zero_crop, spacing = spacing, rescale = False, im_only = True, device = device) # read seg map + S = lut[S.int()] # mask out non-left labels + X = utils.prepare_image(aux['regx'], win_size = win_size, zero_crop = zero_crop, spacing = spacing, rescale = False, im_only = True, device = device) # read_mni_coord_X + hemis_mask = (S > 0) & (X < 0).int() # (1, 1, s, r, c) + viewVolume(hemis_mask, names = ['.'.join(os.path.basename(aux['label']).split('.')[:2]) + '.hemis_mask'], save_dir = subj_dir) + else: + hemis_mask = None + + # save all GT + for mod in modalities.keys(): + final, orig, high_res, bf, _, _, _ = utils.prepare_image(modalities[mod], win_size = win_size, zero_crop = zero_crop, spacing = spacing, add_bf = add_bf, is_CT = 'CT' in mod, rescale = False, hemis_mask = hemis_mask, im_only = False, device = device) + viewVolume(orig, names = [os.path.basename(modalities[mod])[:-4]], save_dir = subj_dir) + viewVolume(final, names = [os.path.basename(modalities[mod])[:-4] + '.input'], save_dir = subj_dir) + viewVolume(high_res, names = [os.path.basename(modalities[mod])[:-4] + '.high_res'], save_dir = subj_dir) + if bf is not None: + viewVolume(bf, names = [os.path.basename(modalities[mod])[:-4] + '.bias_field'], save_dir = subj_dir) + for mod in aux.keys(): + im = utils.prepare_image(aux[mod], win_size = win_size, zero_crop = zero_crop, is_label = 'label' in mod, rescale = False, hemis_mask = hemis_mask, im_only = True, device = device) + viewVolume(im, names = [os.path.basename(aux[mod])[:-4]], save_dir = subj_dir) + + # testing + for mod in modalities.keys(): + test_dir = make_dir(os.path.join(subj_dir, 'input_' + mod)) + im = utils.prepare_image(os.path.join(subj_dir, os.path.basename(modalities[mod])[:-4] + '.input.nii.gz'), win_size = win_size, zero_crop = zero_crop, is_CT = 'CT' in mod, hemis_mask = hemis_mask, im_only = True, device = device) + outs = 
utils.evaluate_image(im, ckp_path = ckp_path, feature_only = False, device = device, gen_cfg = curr_gen_cfg, model_cfg = model_cfg) + + if mask_output: + mask = im.clone() + mask[im != 0.] = 1. + + for k, v in outs.items(): + if 'feat' not in k and k not in exclude_keys: + viewVolume(v * mask if mask_output else v, names = [ 'out_' + k], save_dir = test_dir) + + print(S_cerebral.shape, outs['regx'].shape) + deformed_atlas = utils.get_deformed_atlas(S_cerebral, torch.squeeze(outs['regx']), torch.squeeze(outs['regy']), torch.squeeze(outs['regz'])) + viewVolume(deformed_atlas * mask if mask_output else deformed_atlas, names = [ 'out_deformed_atlas'], save_dir = test_dir) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Testing time for {}: {}'.format(total_time_str, datasets[i])) + + all_total_time = time.time() - all_start_time + all_total_time_str = str(datetime.timedelta(seconds=int(all_total_time))) +print('Total testing time: {}'.format(total_time_str)) +#''' \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..005db7576e6b3639e09a7b62dab29c8e4ec906b6 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +#SBATCH --job-name=test_sr # test_lr +#SBATCH --gpus=1 +#SBATCH --partition=rtx8000 # rtx8000, lcnrtx, dgx-a100, lcnv100, rtx6000 + +#SBATCH --mail-type=FAIL +#SBATCH --account=mlsclemon +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=24 +#SBATCH --mem=256G +#SBATCH --time=6-23:59:59 +#SBATCH --output=/autofs/vast/lemon/temp_stuff/peirong/logs/%j.log # Standard output and error log + + +# exp-specific cfg # +#exp_cfg_file='/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/test/feat.yaml' + + +date;hostname;pwd +python /autofs/space/yogurt_003/users/pl629/code/MTBrainID/scripts/test.py #$exp_cfg_file +date diff --git a/scripts/train.py b/scripts/train.py new 
file mode 100644 index 0000000000000000000000000000000000000000..459c2ac2fb5186f363e972f4c53560719e65b1ca --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,249 @@ + +import datetime +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import glob +import yaml +import json +import random +import time +from argparse import Namespace +from pathlib import Path + + +import numpy as np +import torch +import torch.nn as nn + +from torch.utils.data import DataLoader + +from utils.checkpoint import load_checkpoint +import utils.logging as logging +import utils.misc as utils + +from Generator import build_datasets +from Trainer.visualizer import TaskVisualizer, FeatVisualizer +from Trainer.models import build_model, build_optimizer, build_schedulers +from Trainer.engine import train_one_epoch + + + +logger = logging.get_logger(__name__) + + +# default & gpu cfg # +submit_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/submit.yaml' + +default_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/default.yaml' + +default_train_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/default_train.yaml' +default_val_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/default_val.yaml' + +gen_cfg_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/train' +train_cfg_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/train' + + +def get_params_groups(model): + all = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + # we do not regularize biases nor Norm parameters + all.append(param) + return [{'params': all}] + + +def train(args): + + """ + args: list of configs + """ + + submit_args, gen_args, train_args = args + + utils.init_distributed_mode(submit_args) + if torch.cuda.is_available(): + if submit_args.num_gpus > torch.cuda.device_count(): + submit_args.num_gpus = 
torch.cuda.device_count() + assert ( + submit_args.num_gpus <= torch.cuda.device_count() + ), "Cannot use more GPU devices than available" + else: + submit_args.num_gpus = 0 + + if train_args.debug: + submit_args.num_workers = 0 + + output_dir = utils.make_dir(train_args.out_dir) + cfg_dir = utils.make_dir(os.path.join(output_dir, "cfg")) + plt_dir = utils.make_dir(os.path.join(output_dir, "plt")) + vis_train_dir = utils.make_dir(os.path.join(output_dir, "vis-train")) + ckp_output_dir = utils.make_dir(os.path.join(output_dir, "ckp")) + #ckp_epoch_dir = utils.make_dir(os.path.join(ckp_output_dir, "epoch")) + + yaml.dump( + vars(submit_args), + open(cfg_dir / 'config_submit.yaml', 'w'), allow_unicode=True) + yaml.dump( + vars(gen_args), + open(cfg_dir / 'config_generator.yaml', 'w'), allow_unicode=True) + yaml.dump( + vars(train_args), + open(cfg_dir / 'config_trainer.yaml', 'w'), allow_unicode=True) + + # ============ setup logging ... ============ + logging.setup_logging(output_dir) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(submit_args)).items()))) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(gen_args)).items()))) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(train_args)).items()))) + log_path = os.path.join(output_dir, 'log.txt') + + if submit_args.device is not None: # assign to specified device + device = submit_args.device + elif torch.cuda.is_available(): + device = torch.cuda.current_device() + else: + device = 'cpu' + logger.info('device: %s' % device) + + # fix the seed for reproducibility + #seed = submit_args.seed + utils.get_rank() + seed = int(time.time()) + + os.environ['PYTHONHASHSEED'] = str(seed) + + np.random.seed(seed) + random.seed(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + # ============ preparing data ... 
============ + dataset_dict = build_datasets(gen_args, device = gen_args.device_generator if gen_args.device_generator is not None else device) + data_loader_dict = {} + data_total = 0 + for name in dataset_dict.keys(): + if submit_args.num_gpus>1: + sampler_train = utils.DistributedWeightedSampler(dataset_dict[name]) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_dict[name]) + + data_loader_dict[name] = DataLoader( + dataset_dict[name], + batch_sampler=torch.utils.data.BatchSampler(sampler_train, train_args.batch_size, drop_last=True), + #collate_fn=utils.collate_fn, # apply custom data cooker if needed + num_workers=submit_args.num_workers) + data_total += len(data_loader_dict[name]) + logger.info('Dataset: {}'.format(name)) + logger.info('Num of total training data: {}'.format(data_total)) + + visualizers = {'result': TaskVisualizer(gen_args, train_args)} + if train_args.visualizer.feat_vis: + visualizers['feature'] = FeatVisualizer(gen_args, train_args) + + # ============ building model ... ============ + gen_args, train_args, model, processors, criterion, postprocessor = build_model(gen_args, train_args, device = device) # train: True; test: False + + model_without_ddp = model + # Use multi-process data parallel model in the multi-gpu setting + if submit_args.num_gpus > 1: + logger.info('currect device: %s' % str(torch.cuda.current_device())) + # Make model replica operate on the current device + model = torch.nn.parallel.DistributedDataParallel( + module=model, device_ids=[device], output_device=device, + find_unused_parameters=True + ) + model_without_ddp = model.module # unwarp the model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info('Num of trainable model params: {}'.format(n_parameters)) + + + # ============ preparing optimizer ... 
============ + scaler = torch.cuda.amp.GradScaler() + param_dicts = get_params_groups(model_without_ddp) + optimizer = build_optimizer(train_args, param_dicts) + + # ============ init schedulers ... ============ + lr_scheduler, wd_scheduler = build_schedulers(train_args, data_total, train_args.lr, train_args.min_lr) + logger.info(f"Optimizer and schedulers ready.") + + + best_val_stats = None + train_args.start_epoch = 0 + # Load weights if provided + if train_args.resume or train_args.eval_only: + if train_args.ckp_path: + ckp_path = train_args.ckp_path + else: + ckp_path = sorted(glob.glob(ckp_output_dir + '/*.pth')) + + train_args.start_epoch, best_val_stats = load_checkpoint(ckp_path, [model_without_ddp], optimizer, ['model'], exclude_key = 'supervised_seg') + logger.info(f"Resume epoch: {train_args.start_epoch}") + else: + logger.info('Starting from scratch') + if train_args.reset_epoch: + train_args.start_epoch = 0 + logger.info(f"Start epoch: {train_args.start_epoch}") + + # ============ start training ... ============ + + logger.info("Start training") + start_time = time.time() + + for epoch in range(train_args.start_epoch, train_args.n_epochs): + + if os.path.isfile(os.path.join(ckp_output_dir,'checkpoint_latest.pth')): + os.rename(os.path.join(ckp_output_dir,'checkpoint_latest.pth'), os.path.join(ckp_output_dir,'checkpoint_latest_bk.pth')) + + checkpoint_paths = [ckp_output_dir / 'checkpoint_latest.pth'] + + # ============ save model ... ============ + #checkpoint_paths.append(ckp_epoch_dir / f"checkpoint_epoch_{epoch}.pth") + + for checkpoint_path in checkpoint_paths: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'submit_args': submit_args, + 'gen_args': gen_args, + 'train_args': train_args, + 'best_val_stats': best_val_stats + }, checkpoint_path) + + # ============ training one epoch ... 
============ + if submit_args.num_gpus > 1: + sampler_train.set_epoch(epoch) + log_stats = train_one_epoch(epoch, gen_args, train_args, model_without_ddp, processors, criterion, data_loader_dict, + scaler, optimizer, lr_scheduler, wd_scheduler, postprocessor, visualizers, vis_train_dir, device) + + # ============ writing logs ... ============ + if utils.is_main_process(): + with (Path(output_dir) / "log.txt").open("a") as f: + f.write('epoch %s - ' % str(epoch).zfill(5)) + f.write(json.dumps(log_stats) + "\n") + + # ============ plot training losses ... ============ + if os.path.isfile(log_path): + sum_losses = [0.] * (epoch + 1) + for loss_name in criterion.loss_names: + curr_epoches, curr_losses = utils.read_log(log_path, 'loss_' + loss_name) + sum_losses = [sum_losses[i] + curr_losses[i] for i in range(len(curr_losses))] + utils.plot_loss(curr_losses, os.path.join(utils.make_dir(plt_dir), 'loss_%s.png' % loss_name)) + utils.plot_loss(sum_losses, os.path.join(utils.make_dir(plt_dir), 'loss_all.png')) + + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info('Training time {}'.format(total_time_str)) + + +##################################################################################### + +if __name__ == '__main__': + submit_args = utils.preprocess_cfg([submit_cfg_file]) + gen_args = utils.preprocess_cfg([default_gen_cfg_file, sys.argv[1]], cfg_dir = gen_cfg_dir) + train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, sys.argv[2]], cfg_dir = train_cfg_dir) + utils.launch_job(submit_args, gen_args, train_args, train) \ No newline at end of file diff --git a/scripts/train.sh b/scripts/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..797bd163277a397064760d003632c2a4ca948e03 --- /dev/null +++ b/scripts/train.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +#SBATCH --job-name=lr # reggr hemis lowres age_pool sr_lowres synth +#SBATCH --gpus=1 +#SBATCH 
--partition=lcnrtx # lcna100, lcnrtx, lcna40, rtx8000, dgx-a100, lcnv100, rtx6000 + +#SBATCH --mail-type=FAIL +#SBATCH --account=lcnlemon #lcnrtx, lcnlemon, mlsclemon +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=64G # 128G +#SBATCH --time=29-23:59:59 +#SBATCH --output=/autofs/vast/lemon/temp_stuff/peirong/logs/%j.log # Standard output and error log + + +# exp-specific cfg # +gen_cfg_file=/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/train/brain_id_lowres.yaml +train_cfg_file=/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/train/joint_lowres.yaml + + +date;hostname;pwd +python /autofs/space/yogurt_003/users/pl629/code/MTBrainID/scripts/train.py $gen_cfg_file $train_cfg_file +date \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..b161a63bb3b07bf20589787d4d08c0da50aae88f --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +from setuptools import setup, find_packages + +# run it: python setup.py install +setup(name='Brain-ID', version='1.0', packages=find_packages()) \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/checkpoint.py b/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7f09aee43f81a9d1a93b3e4fabd23cf1287fab --- /dev/null +++ b/utils/checkpoint.py @@ -0,0 +1,686 @@ +#!/usr/bin/env python3 + +"""Functions that handle saving and loading of checkpoints.""" + +import os + +import torch +import torch.nn as nn + +import utils.distributed as du +import utils.logging as logging +from utils.env import checkpoint_pathmgr as pathmgr + +from tabulate import tabulate + +logger = logging.get_logger(__name__) + + +import copy +import logging +import re +from typing import Dict, List +import torch +from tabulate import 
tabulate + + +def convert_basic_c2_names(original_keys): + """ + Apply some basic name conversion to names in C2 weights. + It only deals with typical backbone models. + + Args: + original_keys (list[str]): + Returns: + list[str]: The same number of strings matching those in original_keys. + """ + layer_keys = copy.deepcopy(original_keys) + layer_keys = [ + {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys + ] # some hard-coded mappings + + layer_keys = [k.replace("_", ".") for k in layer_keys] + layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys] + layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys] + # Uniform both bn and gn names to "norm" + layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys] + layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys] + layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys] + + # stem + layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys] + # to avoid mis-matching with "conv1" in other components (e.g. 
detection head) + layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys] + + # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5) + # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys] + # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys] + # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys] + # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys] + + # blocks + layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys] + layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys] + layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys] + layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys] + + # DensePose substitutions + layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys] + layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys] + layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys] + layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys] + layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys] + return layer_keys + + +def convert_c2_detectron_names(weights): + """ + Map Caffe2 Detectron weight names to Detectron2 names. 
+ + Args: + weights (dict): name -> tensor + + Returns: + dict: detectron2 names -> tensor + dict: detectron2 names -> C2 names + """ + logger = logging.getLogger(__name__) + logger.info("Renaming Caffe2 weights ......") + original_keys = sorted(weights.keys()) + layer_keys = copy.deepcopy(original_keys) + + layer_keys = convert_basic_c2_names(layer_keys) + + # -------------------------------------------------------------------------- + # RPN hidden representation conv + # -------------------------------------------------------------------------- + # FPN case + # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then + # shared for all other levels, hence the appearance of "fpn2" + layer_keys = [ + k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys + ] + # Non-FPN case + layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys] + + # -------------------------------------------------------------------------- + # RPN box transformation conv + # -------------------------------------------------------------------------- + # FPN case (see note above about "fpn2") + layer_keys = [ + k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas") + for k in layer_keys + ] + layer_keys = [ + k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits") + for k in layer_keys + ] + # Non-FPN case + layer_keys = [ + k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys + ] + layer_keys = [ + k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits") + for k in layer_keys + ] + + # -------------------------------------------------------------------------- + # Fast R-CNN box head + # -------------------------------------------------------------------------- + layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys] + layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in 
layer_keys] + layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys] + layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys] + # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s + layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys] + + # -------------------------------------------------------------------------- + # FPN lateral and output convolutions + # -------------------------------------------------------------------------- + def fpn_map(name): + """ + Look for keys with the following patterns: + 1) Starts with "fpn.inner." + Example: "fpn.inner.res2.2.sum.lateral.weight" + Meaning: These are lateral pathway convolutions + 2) Starts with "fpn.res" + Example: "fpn.res2.2.sum.weight" + Meaning: These are FPN output convolutions + """ + splits = name.split(".") + norm = ".norm" if "norm" in splits else "" + if name.startswith("fpn.inner."): + # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight'] + stage = int(splits[2][len("res") :]) + return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1]) + elif name.startswith("fpn.res"): + # splits example: ['fpn', 'res2', '2', 'sum', 'weight'] + stage = int(splits[1][len("res") :]) + return "fpn_output{}{}.{}".format(stage, norm, splits[-1]) + return name + + layer_keys = [fpn_map(k) for k in layer_keys] + + # -------------------------------------------------------------------------- + # Mask R-CNN mask head + # -------------------------------------------------------------------------- + # roi_heads.StandardROIHeads case + layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys] + layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys] + layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys] + # roi_heads.Res5ROIHeads case + layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys] + + # 
-------------------------------------------------------------------------- + # Keypoint R-CNN head + # -------------------------------------------------------------------------- + # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX" + layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys] + layer_keys = [ + k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys + ] + layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys] + + # -------------------------------------------------------------------------- + # Done with replacements + # -------------------------------------------------------------------------- + assert len(set(layer_keys)) == len(layer_keys) + assert len(original_keys) == len(layer_keys) + + new_weights = {} + new_keys_to_original_keys = {} + for orig, renamed in zip(original_keys, layer_keys): + new_keys_to_original_keys[renamed] = orig + if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."): + # remove the meaningless prediction weight for background class + new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1 + new_weights[renamed] = weights[orig][new_start_idx:] + logger.info( + "Remove prediction weight for background class in {}. 
The shape changes from " + "{} to {}.".format( + renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape) + ) + ) + elif renamed.startswith("cls_score."): + # move weights of bg class from original index 0 to last index + logger.info( + "Move classification weights for background class in {} from index 0 to " + "index {}.".format(renamed, weights[orig].shape[0] - 1) + ) + new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]]) + else: + new_weights[renamed] = weights[orig] + + return new_weights, new_keys_to_original_keys + + + + +def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]): + """ + Params in the same submodule are grouped together. + + Args: + keys: names of all parameters + original_names: mapping from parameter name to their name in the checkpoint + + Returns: + dict[name -> all other names in the same group] + """ + + def _submodule_name(key): + pos = key.rfind(".") + if pos < 0: + return None + prefix = key[: pos + 1] + return prefix + + all_submodules = [_submodule_name(k) for k in keys] + all_submodules = [x for x in all_submodules if x] + all_submodules = sorted(all_submodules, key=len) + + ret = {} + for prefix in all_submodules: + group = [k for k in keys if k.startswith(prefix)] + if len(group) <= 1: + continue + original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group]) + if len(original_name_lcp) == 0: + # don't group weights if original names don't share prefix + continue + + for k in group: + if k in ret: + continue + ret[k] = group + return ret + + +def _longest_common_prefix(names): + """ + ["abc.zfg", "abc.zef"] -> "abc." + """ + names = [n.split(".") for n in names] + m1, m2 = min(names), max(names) + ret = [a for a, b in zip(m1, m2) if a == b] + ret = ".".join(ret) + "." 
if len(ret) else "" + return ret + + +def _longest_common_prefix_str(names): + m1, m2 = min(names), max(names) + lcp = [] + for a, b in zip(m1, m2): + if a == b: + lcp.append(a) + else: + break + lcp = "".join(lcp) + return lcp + +def _group_str(names): + """ + Turn "common1", "common2", "common3" into "common{1,2,3}" + """ + lcp = _longest_common_prefix_str(names) + rest = [x[len(lcp) :] for x in names] + rest = "{" + ",".join(rest) + "}" + ret = lcp + rest + + # add some simplification for BN specifically + ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*") + ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*") + return ret + +def make_checkpoint_dir(path_to_job): + """ + Creates the checkpoint directory (if not present already). + Args: + path_to_job (string): the path to the folder of the current job. + """ + checkpoint_dir = os.path.join(path_to_job, "checkpoints") + # Create the checkpoint dir from the master process + if du.is_master_proc() and not pathmgr.exists(checkpoint_dir): + try: + pathmgr.mkdirs(checkpoint_dir) + except Exception: + pass + return checkpoint_dir + + +def get_checkpoint_dir(path_to_job): + """ + Get path for storing checkpoints. + Args: + path_to_job (string): the path to the folder of the current job. + """ + return os.path.join(path_to_job, "checkpoints") + + +def get_path_to_checkpoint(path_to_job, epoch): + """ + Get the full path to a checkpoint file. + Args: + path_to_job (string): the path to the folder of the current job. + epoch (int): the number of epoch for the checkpoint. + """ + name = "checkpoint_epoch_{:05d}.pyth".format(epoch) + return os.path.join(get_checkpoint_dir(path_to_job), name) + + +def get_last_checkpoint(path_to_job): + """ + Get the last checkpoint from the checkpointing folder. + Args: + path_to_job (string): the path to the folder of the current job. 
+ """ + name = "checkpoint_latest.pyth" + return os.path.join(get_checkpoint_dir(path_to_job), name) + + +def has_checkpoint(path_to_job): + """ + Determines if the given directory contains a checkpoint. + Args: + path_to_job (string): the path to the folder of the current job. + """ + d = get_checkpoint_dir(path_to_job) + files = pathmgr.ls(d) if pathmgr.exists(d) else [] + return any("checkpoint" in f for f in files) + + +def is_checkpoint_epoch(cfg, cur_iter): + """ + Determine if a checkpoint should be saved on current epoch. + Args: + cfg (CfgNode): configs to save. + cur_epoch (int): current number of epoch of the model. + """ + if cur_iter + 1 == cfg.SOLVER.MAX_EPOCH: + return True + + return (cur_iter + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0 + + +def save_checkpoint(path_to_job, model, optimizer, iter, cfg, scaler=None): + """ + Save a checkpoint. + Args: + model (model): model to save the weight to the checkpoint. + optimizer (optim): optimizer to save the historical state. + epoch (int): current number of epoch of the model. + cfg (CfgNode): configs to save. + scaler (GradScaler): the mixed precision scale. + """ + # Save checkpoints only from the master process. + if not du.is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS): + return + # Ensure that the checkpoint dir exists. + pathmgr.mkdirs(get_checkpoint_dir(path_to_job)) + # Omit the DDP wrapper in the multi-gpu setting. + sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict() + + # Record the state. 
+ checkpoint = { + "epoch": iter, + "model_state": sd, + "optimizer_state": optimizer.state_dict(), + "cfg": cfg.dump(), + } + if scaler is not None: + checkpoint["scaler_state"] = scaler.state_dict() + # Write the current epoch checkpoint & update the latest epoch checkpoint + path_to_checkpoint = get_path_to_checkpoint(path_to_job, iter + 1) + with pathmgr.open(path_to_checkpoint, "wb") as f: + torch.save(checkpoint, f) + path_to_latest_checkpoint = get_last_checkpoint(path_to_job) + with pathmgr.open(path_to_latest_checkpoint, "wb") as f: + torch.save(checkpoint, f) + return path_to_checkpoint + + +def load_checkpoint( + path_to_checkpoint, + models, + optimizer = None, + model_keys = ['model'], + exclude_key = None, + to_match = {}, + to_print = True, +): + """ + Load the checkpoint from the given file. + """ + assert pathmgr.exists(path_to_checkpoint), "Checkpoint '{}' not found".format( + path_to_checkpoint + ) + if to_print: + logger.info("Loading network weights from {}.".format(path_to_checkpoint)) + + + # Load the checkpoint on CPU to avoid GPU mem spike. + + def find_model_key(keys, model_key): + for k in keys: + if model_key in k: + return k + for k in keys: + if 'model' in k: + if to_print: + logger.info('Have not found model state_dict according to the given key, but using the "model" as key instead!') + return k + + + with pathmgr.open(path_to_checkpoint, "rb") as f: + checkpoint = torch.load(f, map_location="cpu") + + for i, model in enumerate(models): + ms = model + #ms = model.module if data_parallel else model # Account for the DDP wrapper in the multi-gpu setting. 
+ model_dict = ms.state_dict() + + k = find_model_key(checkpoint.keys(), model_keys[i]) + pre_train_dict = checkpoint[k] + + ms.load_state_dict(align_and_update_state_dicts(model_dict, pre_train_dict, exclude_key = exclude_key, to_print = to_print, to_match = to_match), strict=False) + + if optimizer and 'optimizer' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + best_val_stats = checkpoint['best_val_stats'] if 'best_val_stats' in checkpoint else None + return checkpoint['epoch'], best_val_stats + + + +def load_test_checkpoint(cfg, model): + """ + Loading checkpoint logic for testing. + """ + # Load a checkpoint to test if applicable. + if cfg.TEST.CHECKPOINT_FILE_PATH != "": + load_checkpoint( + cfg.TEST.CHECKPOINT_FILE_PATH, + model, + cfg.NUM_GPUS > 1, + None, + squeeze_temporal=cfg.TEST.CHECKPOINT_SQUEEZE_TEMPORAL, + ) + elif has_checkpoint(cfg.OUTPUT_DIR): + last_checkpoint = get_last_checkpoint(cfg.OUTPUT_DIR) + load_checkpoint(last_checkpoint, model, cfg.NUM_GPUS > 1) + elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "": + # If no checkpoint found in TEST.CHECKPOINT_FILE_PATH or in the current + # checkpoint folder, try to load checkpoint from + # TRAIN.CHECKPOINT_FILE_PATH and test it. + load_checkpoint( + cfg.TRAIN.CHECKPOINT_FILE_PATH, + model, + cfg.NUM_GPUS > 1, + None, + ) + else: + logger.info( + "Unknown way of loading checkpoint. Using random initialization, only for debugging." + ) + + +def load_train_checkpoint(cfg, model, optimizer, scaler=None): + """ + Loading checkpoint logic for training. 
+ """ + if cfg.TRAIN.AUTO_RESUME and has_checkpoint(cfg.OUTPUT_DIR): + last_checkpoint = get_last_checkpoint(cfg.OUTPUT_DIR) + logger.info("Load from last checkpoint, {}.".format(last_checkpoint)) + checkpoint_epoch = load_checkpoint( + last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer, scaler=scaler + ) + start_epoch = checkpoint_epoch + 1 + elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "" and cfg.TRAIN.FINETUNE: + logger.info("Finetune from given checkpoint file.") + checkpoint_epoch = load_checkpoint( + cfg.TRAIN.CHECKPOINT_FILE_PATH, + model, + cfg.NUM_GPUS > 1, + optimizer, + scaler=scaler, + epoch_reset=cfg.TRAIN.CHECKPOINT_EPOCH_RESET, + freeze_pretrain=cfg.TRAIN.FREEZE_PRETRAIN, + ) + start_epoch = checkpoint_epoch + 1 if cfg.TRAIN.FINETUNE_START_EPOCH == 0 else cfg.TRAIN.FINETUNE_START_EPOCH + elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "": + logger.info("Load from given checkpoint file.") + checkpoint_epoch = load_checkpoint( + cfg.TRAIN.CHECKPOINT_FILE_PATH, + model, + cfg.NUM_GPUS > 1, + optimizer, + scaler=scaler, + epoch_reset=cfg.TRAIN.CHECKPOINT_EPOCH_RESET, + ) + start_epoch = checkpoint_epoch + 1 + else: + start_epoch = 0 + + return start_epoch + + + + + +# Note the current matching is not symmetric. +# it assumes model_state_dict will have longer names. +def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, exclude_key = None, to_print = True, to_match = {}): + """ + Match names between the two state-dict, and returns a new chkpt_state_dict with names + converted to match model_state_dict with heuristics. The returned dict can be later + loaded with fvcore checkpointer. 
+ """ + if exclude_key is not None: + model_keys = sorted([k for k in model_state_dict.keys() if exclude_key not in k]) + else: + model_keys = sorted(model_state_dict.keys()) + original_keys = {x: x for x in ckpt_state_dict.keys()} + ckpt_keys = sorted(ckpt_state_dict.keys()) + + def in_to_match(a, b): + for k in to_match.keys(): + c = b.replace(k, to_match[k]) + if a == c or a.endswith("." + c): + return True + return False + + def match(a, b): + if (a == b or a.endswith("." + b) or in_to_match(a, b)) and to_print: + print('matched') + print(a, '--', b) + return a == b or a.endswith("." + b) or in_to_match(a, b) + + # get a matrix of string matches, where each (i, j) entry correspond to the size of the + # ckpt_key string, if it matches + match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys] + match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys)) + # use the matched one with longest size in case of multiple matches + max_match_size, idxs = match_matrix.max(1) + # remove indices that correspond to no-match + idxs[max_match_size == 0] = -1 + + #logger = logging.getLogger(__name__) + # matched_pairs (matched checkpoint key --> matched model key) + matched_keys = {} + result_state_dict = {} + for idx_model, idx_ckpt in enumerate(idxs.tolist()): + if idx_ckpt == -1: + continue + key_model = model_keys[idx_model] + key_ckpt = ckpt_keys[idx_ckpt] + value_ckpt = ckpt_state_dict[key_ckpt] + shape_in_model = model_state_dict[key_model].shape + + if shape_in_model != value_ckpt.shape: + logger.warning( + "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format( + key_ckpt, value_ckpt.shape, key_model, shape_in_model + ) + ) + if shape_in_model[0] != value_ckpt.shape[0] and len(shape_in_model) == len(value_ckpt.shape): # different embed_dim setup + logger.warning( + "{} will not be loaded. 
Please double check and see if this is desired.".format( + key_ckpt + ) + ) + logger.warning('--- shape_in_model: {}'.format(shape_in_model)) + logger.warning('--- ckpt shape: {}'.format(value_ckpt.shape)) + else: + logger.warning( + "{} will be loaded for the center frame with the weights from the 2D conv layers in pre-trained models and\ + initialize other weights as zero. Please double check and see if this is desired.".format( + key_ckpt + ) + ) + assert key_model not in result_state_dict + logger.warning('--- shape_in_model: {}'.format(shape_in_model)) + logger.warning('--- ckpt shape: {}'.format(value_ckpt.shape)) + # load pre-trained 2D weights on the parameters' center termporal frame while others as 0. (B, C, (T,) H, W) + nn.init.constant_(model_state_dict[key_model], 0.0) + model_state_dict[key_model][:, :, int(shape_in_model[2] / 2)] = value_ckpt + result_state_dict[key_model] = model_state_dict[key_model] + logger.warning('--- loaded to T: {}'.format(int(shape_in_model[2] / 2))) + logger.warning('--- reshaped ckpt: {}'.format(result_state_dict[key_model].shape)) + matched_keys[key_ckpt] = key_model + else: + assert key_model not in result_state_dict + result_state_dict[key_model] = value_ckpt + if key_ckpt in matched_keys: # already added to matched_keys + logger.error( + "Ambiguity found for {} in checkpoint!" 
+ "It matches at least two keys in the model ({} and {}).".format( + key_ckpt, key_model, matched_keys[key_ckpt] + ) + ) + raise ValueError("Cannot match one checkpoint key to multiple keys in the model.") + if to_print: + logger.info('Matching {} to {}'.format(key_ckpt, key_model)) + matched_keys[key_ckpt] = key_model + + # logging: + matched_model_keys = sorted(matched_keys.values()) + + if len(matched_model_keys) == 0: + logger.warning("No weights in checkpoint matched with model.") + return ckpt_state_dict + common_prefix = _longest_common_prefix(matched_model_keys) + rev_matched_keys = {v: k for k, v in matched_keys.items()} + original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys} + + model_key_groups = _group_keys_by_module(matched_model_keys, original_keys) + + table = [] + memo = set() + for key_model in matched_model_keys: + if to_print: + print(' matched:', key_model) + if key_model in memo: + continue + if key_model in model_key_groups: + group = model_key_groups[key_model] + memo |= set(group) + shapes = [tuple(model_state_dict[k].shape) for k in group] + table.append( + ( + _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*", + _group_str([original_keys[k] for k in group]), + " ".join([str(x).replace(" ", "") for x in shapes]), + ) + ) + else: + key_checkpoint = original_keys[key_model] + shape = str(tuple(model_state_dict[key_model].shape)) + table.append((key_model[len(common_prefix) :], key_checkpoint, shape)) + table_str = tabulate( + table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"] + ) + if to_print: + logger.info( + "Following weights matched with " + + (f"submodule {common_prefix[:-1]}" if common_prefix else "model") + + ":\n" + + table_str + ) + + unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())] + unmatched_model_keys = [k for k in model_keys if k not in set(matched_keys.values())] + #for k in unmatched_ckpt_keys: + 
#result_state_dict[k] = ckpt_state_dict[k] + #result_state_dict[k] = model_state_dict[k] + #logger.info('unmatched:', k) + for k in unmatched_model_keys: + #logger.info('unmatched:', k) + result_state_dict[k] = model_state_dict[k] + + return result_state_dict \ No newline at end of file diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8da7b66f150cc3f3b8fca6a53e20ae2daeba8ea4 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,135 @@ + +"""Config utilities for yml file.""" + +import collections +import functools +import os +import re + +import yaml +# from imaginaire.utils.distributed import master_only_print as print + + +class AttrDict(dict): + """Dict as attribute trick.""" + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + for key, value in self.__dict__.items(): + if isinstance(value, dict): + self.__dict__[key] = AttrDict(value) + elif isinstance(value, (list, tuple)): + if value and isinstance(value[0], dict): + self.__dict__[key] = [AttrDict(item) for item in value] + else: + self.__dict__[key] = value + + def yaml(self): + """Convert object to yaml dict and return.""" + yaml_dict = {} + for key, value in self.__dict__.items(): + if isinstance(value, AttrDict): + yaml_dict[key] = value.yaml() + elif isinstance(value, list): + if value and isinstance(value[0], AttrDict): + new_l = [] + for item in value: + new_l.append(item.yaml()) + yaml_dict[key] = new_l + else: + yaml_dict[key] = value + else: + yaml_dict[key] = value + return yaml_dict + + def __repr__(self): + """Print all variables.""" + ret_str = [] + for key, value in self.__dict__.items(): + if isinstance(value, AttrDict): + ret_str.append('{}:'.format(key)) + child_ret_str = value.__repr__().split('\n') + for item in child_ret_str: + ret_str.append(' ' + item) + elif isinstance(value, list): + if value and isinstance(value[0], AttrDict): + ret_str.append('{}:'.format(key)) + for item in value: + # Treat 
as AttrDict above. + child_ret_str = item.__repr__().split('\n') + for item in child_ret_str: + ret_str.append(' ' + item) + else: + ret_str.append('{}: {}'.format(key, value)) + else: + ret_str.append('{}: {}'.format(key, value)) + return '\n'.join(ret_str) + + +class Config(AttrDict): + r"""Configuration class. This should include every human specifiable + hyperparameter values for your training.""" + + def __init__(self, filename=None, verbose=False): + super(Config, self).__init__() + + # Update with given configurations. + if os.path.exists(filename): + + loader = yaml.SafeLoader + loader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + try: + with open(filename, 'r') as f: + cfg_dict = yaml.load(f, Loader=loader) + except EnvironmentError: + print('Please check the file with name of "%s"', filename) + recursive_update(self, cfg_dict) + else: + raise ValueError('Provided config path not existed: %s' % filename) + + if verbose: + print(' imaginaire config '.center(80, '-')) + print(self.__repr__()) + print(''.center(80, '-')) + + +def rsetattr(obj, attr, val): + """Recursively find object and set value""" + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + """Recursively find object and return value""" + + def _getattr(obj, attr): + r"""Get attribute.""" + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split('.')) + + +def recursive_update(d, u): + """Recursively update AttrDict d with AttrDict u""" + if u is not None: + for key, value in u.items(): + if isinstance(value, collections.abc.Mapping): + d.__dict__[key] = recursive_update(d.get(key, 
AttrDict({})), value) + elif isinstance(value, (list, tuple)): + if len(value) > 0 and isinstance(value[0], dict): + d.__dict__[key] = [AttrDict(item) for item in value] + else: + d.__dict__[key] = value + else: + d.__dict__[key] = value + return d diff --git a/utils/distributed.py b/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..a5808fdf42815e8369628fa43f0efcd0292fa3ff --- /dev/null +++ b/utils/distributed.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +"""Distributed helpers.""" + +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None + + +def all_reduce(tensors, average=True): + """ + All reduce the provided tensors from all processes across machines. + Args: + tensors (list): tensors to perform all reduce across all processes in + all machines. + average (bool): scales the reduced tensor by the number of overall + processes across all machines. + """ + + for tensor in tensors: + dist.all_reduce(tensor, async_op=False) + if average: + world_size = dist.get_world_size() + for tensor in tensors: + tensor.mul_(1.0 / world_size) + return tensors + + +def is_master_proc(num_gpus=8): + """ + Determines if the current process is the master process. + """ + if torch.distributed.is_initialized(): + return dist.get_rank() % num_gpus == 0 + else: + return True + + +def get_world_size(): + """ + Get the size of the world. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def init_distributed_training(cfg): + """ + Initialize variables needed for distributed training. 
+ """ + if cfg.NUM_GPUS <= 1: + return + num_gpus_per_machine = cfg.NUM_GPUS + num_machines = dist.get_world_size() // num_gpus_per_machine + for i in range(num_machines): + ranks_on_i = list( + range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) + ) + pg = dist.new_group(ranks_on_i) + if i == cfg.SHARD_ID: + global _LOCAL_PROCESS_GROUP + _LOCAL_PROCESS_GROUP = pg diff --git a/utils/env.py b/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..78d1764c924291a6d3eba4031eaabbbffe985b06 --- /dev/null +++ b/utils/env.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +"""Set up Environment.""" + +from iopath.common.file_io import PathManagerFactory + +_ENV_SETUP_DONE = False +pathmgr = PathManagerFactory.get(key="pyslowfast") +checkpoint_pathmgr = PathManagerFactory.get(key="pyslowfast_checkpoint") + + +def setup_environment(): + global _ENV_SETUP_DONE + if _ENV_SETUP_DONE: + return + _ENV_SETUP_DONE = True diff --git a/utils/interpol/__init__.py b/utils/interpol/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb4adda01be9ca4a18facfbc2c09a9c9a0d1b1c --- /dev/null +++ b/utils/interpol/__init__.py @@ -0,0 +1,7 @@ +from .api import * +from .resize import * +from .restrict import * +from . import backend + +from . import _version +__version__ = _version.get_versions()['version'] diff --git a/utils/interpol/_version.py b/utils/interpol/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..bf96c7aa9aba4320b760ff7443fa2f4199e98abe --- /dev/null +++ b/utils/interpol/_version.py @@ -0,0 +1,623 @@ + +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. 
"""Git implementation of _version.py.

Generated by versioneer-0.20; computes a version number either from strings
substituted by ``git archive`` or by querying git in a source checkout.
"""

import errno
import os
import re
import subprocess
import sys


def get_keywords():
    """Get the keywords needed to look up the version information."""
    # These strings are rewritten by git during ``git archive``.  They must
    # each live on a line of their own because setup.py/versioneer.py greps
    # for the variable names.
    git_refnames = " (HEAD -> main, tag: 0.2.3)"
    git_full = "414ed52c973b9d32e3e6a5a75c91cd5aab064f23"
    git_date = "2023-04-17 20:36:50 -0400"
    return {"refnames": git_refnames, "full": git_full, "date": git_date}


class VersioneerConfig:  # pylint: disable=too-few-public-methods
    """Container for Versioneer configuration parameters."""


def get_config():
    """Create, populate and return the VersioneerConfig() object."""
    # These values are filled in when 'setup.py versioneer' creates
    # _version.py.
    cfg = VersioneerConfig()
    cfg.VCS = "git"
    cfg.style = "pep440"
    cfg.tag_prefix = ""
    cfg.parentdir_prefix = ""
    cfg.versionfile_source = "interpol/_version.py"
    cfg.verbose = False
    return cfg


class NotThisMethod(Exception):
    """Exception raised if a method is not valid for the current scenario."""


LONG_VERSION_PY = {}
HANDLERS = {}


def register_vcs_handler(vcs, method):  # decorator
    """Create a decorator that stores ``f`` in ``HANDLERS[vcs][method]``."""
    def decorate(f):
        HANDLERS.setdefault(vcs, {})[method] = f
        return f
    return decorate


# pylint:disable=too-many-arguments,consider-using-with # noqa
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
                env=None):
    """Try each candidate executable in ``commands`` with ``args``.

    Returns a ``(stdout, returncode)`` pair, or ``(None, None)`` when no
    candidate could be launched.
    """
    assert isinstance(commands, list)
    process = None
    dispcmd = None
    for command in commands:
        dispcmd = str([command] + args)
        try:
            # remember shell=False, so use git.cmd on windows, not just git
            process = subprocess.Popen(
                [command] + args, cwd=cwd, env=env,
                stdout=subprocess.PIPE,
                stderr=(subprocess.PIPE if hide_stderr else None))
            break
        except EnvironmentError:
            err = sys.exc_info()[1]
            if err.errno == errno.ENOENT:
                continue  # executable not found: try the next candidate
            if verbose:
                print("unable to run %s" % dispcmd)
                print(err)
            return None, None
    else:
        if verbose:
            print("unable to find command, tried %s" % (commands,))
        return None, None
    stdout = process.communicate()[0].strip().decode()
    if process.returncode != 0:
        if verbose:
            print("unable to run %s (error)" % dispcmd)
            print("stdout was %s" % stdout)
        return None, process.returncode
    return stdout, process.returncode


def versions_from_parentdir(parentdir_prefix, root, verbose):
    """Try to determine the version from the parent directory name.

    Source tarballs conventionally unpack into ``<name>-<version>/``; walk
    up to two levels above ``root`` looking for a directory whose name
    starts with ``parentdir_prefix``.

    Raises:
        NotThisMethod: when no matching directory is found.
    """
    tried = []
    for _ in range(3):
        basename = os.path.basename(root)
        if basename.startswith(parentdir_prefix):
            return {"version": basename[len(parentdir_prefix):],
                    "full-revisionid": None,
                    "dirty": False, "error": None, "date": None}
        tried.append(root)
        root = os.path.dirname(root)  # up a level
    if verbose:
        print("Tried directories %s but none started with prefix %s" %
              (str(tried), parentdir_prefix))
    raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
    """Extract version information from the given file.

    Fetches the ``git_refnames``/``git_full``/``git_date`` values with a
    regexp so that setup.py never has to import _version.py; this function
    is not used from _version.py itself.
    """
    keywords = {}
    fields = {"git_refnames =": "refnames",
              "git_full =": "full",
              "git_date =": "date"}
    try:
        with open(versionfile_abs, "r") as fobj:
            for line in fobj:
                stripped = line.strip()
                for prefix, key in fields.items():
                    if stripped.startswith(prefix):
                        mo = re.search(r'=\s*"(.*)"', line)
                        if mo:
                            keywords[key] = mo.group(1)
    except EnvironmentError:
        pass
    return keywords


@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
    """Get version information from git keywords.

    Raises:
        NotThisMethod: when the keywords are missing or unexpanded.
    """
    if "refnames" not in keywords:
        raise NotThisMethod("Short version file found")
    date = keywords.get("date")
    if date is not None:
        # Use only the last line: earlier lines may carry a GPG signature.
        date = date.splitlines()[-1]
        # git's "%ci" output is only ISO-8601-like ("%cI" would be exact but
        # needs git >= 2.2.0); make it compliant by turning the first space
        # into a "T" and dropping the second.
        date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
    refnames = keywords["refnames"].strip()
    if refnames.startswith("$Format"):
        if verbose:
            print("keywords are unexpanded, not using")
        raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
    refs = {r.strip() for r in refnames.strip("()").split(",")}
    # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
    # just "foo-1.0"; prefer refs carrying that prefix when present.
    TAG = "tag: "
    tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
    if not tags:
        # Either git < 1.8.3, or there really are no tags.  Heuristic:
        # version tags contain a digit, which filters out common branch
        # names ("release", "stabilization") as well as "HEAD"/"master".
        tags = {r for r in refs if re.search(r'\d', r)}
        if verbose:
            print("discarding '%s', no digits" % ",".join(refs - tags))
    if verbose:
        print("likely tags: %s" % ",".join(sorted(tags)))
    for ref in sorted(tags):
        # sorting prefers e.g. "2.0" over "2.0rc1"
        if not ref.startswith(tag_prefix):
            continue
        r = ref[len(tag_prefix):]
        # Skip refs that exactly match the prefix or that don't start with
        # a number once the prefix is stripped (mostly a concern when the
        # prefix is '').
        if not re.match(r'\d', r):
            continue
        if verbose:
            print("picking %s" % r)
        return {"version": r,
                "full-revisionid": keywords["full"].strip(),
                "dirty": False, "error": None,
                "date": date}
    # no suitable tags, so version is "0+unknown", but full hex is still there
    if verbose:
        print("no suitable tags, using unknown + full revision id")
    return {"version": "0+unknown",
            "full-revisionid": keywords["full"].strip(),
            "dirty": False, "error": "no suitable tags", "date": None}
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
+ date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post0.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
def render_pep440_pre(pieces):
    """TAG[.post0.devDISTANCE] -- No -dirty.

    Exceptions:
    1: no tags. 0.post0.devDISTANCE
    """
    tag = pieces["closest-tag"]
    if not tag:
        # exception #1
        return "0.post0.dev%d" % pieces["distance"]
    if pieces["distance"]:
        return tag + ".post0.dev%d" % pieces["distance"]
    return tag


def render_pep440_post(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX] .

    ".dev0" means dirty; it sorts backwards (a dirty tree appears "older"
    than the corresponding clean one), but you shouldn't be releasing
    software with -dirty anyways.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    if not pieces["closest-tag"]:
        # exception #1
        rendered = "0.post%d" % pieces["distance"]
        if pieces["dirty"]:
            rendered += ".dev0"
        rendered += "+g%s" % pieces["short"]
        return rendered
    rendered = pieces["closest-tag"]
    if pieces["distance"] or pieces["dirty"]:
        rendered += ".post%d" % pieces["distance"]
        if pieces["dirty"]:
            rendered += ".dev0"
        rendered += plus_or_dot(pieces)
        rendered += "g%s" % pieces["short"]
    return rendered


def render_pep440_post_branch(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .

    ".dev0" means not master branch.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
    """
    if not pieces["closest-tag"]:
        # exception #1
        rendered = "0.post%d" % pieces["distance"]
        if pieces["branch"] != "master":
            rendered += ".dev0"
        rendered += "+g%s" % pieces["short"]
        if pieces["dirty"]:
            rendered += ".dirty"
        return rendered
    rendered = pieces["closest-tag"]
    if pieces["distance"] or pieces["dirty"]:
        rendered += ".post%d" % pieces["distance"]
        if pieces["branch"] != "master":
            rendered += ".dev0"
        rendered += plus_or_dot(pieces)
        rendered += "g%s" % pieces["short"]
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered


def render_pep440_old(pieces):
    """TAG[.postDISTANCE[.dev0]] .

    The ".dev0" means dirty.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += ".post%d" % pieces["distance"]
            if pieces["dirty"]:
                rendered += ".dev0"
        return rendered
    # exception #1
    rendered = "0.post%d" % pieces["distance"]
    if pieces["dirty"]:
        rendered += ".dev0"
    return rendered


def render_git_describe(pieces):
    """TAG[-DISTANCE-gHEX][-dirty].

    Like 'git describe --tags --dirty --always'.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"]:
            rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
    else:
        # exception #1
        rendered = pieces["short"]
    return rendered + ("-dirty" if pieces["dirty"] else "")


def render_git_describe_long(pieces):
    """TAG-DISTANCE-gHEX[-dirty].

    Like 'git describe --tags --dirty --always -long'.
    The distance/hash is unconditional.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    if pieces["closest-tag"]:
        rendered = "%s-%d-g%s" % (pieces["closest-tag"],
                                  pieces["distance"], pieces["short"])
    else:
        # exception #1
        rendered = pieces["short"]
    return rendered + ("-dirty" if pieces["dirty"] else "")


def render(pieces, style):
    """Render the given version pieces into the requested style."""
    if pieces["error"]:
        return {"version": "unknown",
                "full-revisionid": pieces.get("long"),
                "dirty": None,
                "error": pieces["error"],
                "date": None}

    if not style or style == "default":
        style = "pep440"  # the default

    # Kept as an if/elif chain (not a dispatch dict) so each renderer name
    # is only resolved when its style is actually requested.
    if style == "pep440":
        rendered = render_pep440(pieces)
    elif style == "pep440-branch":
        rendered = render_pep440_branch(pieces)
    elif style == "pep440-pre":
        rendered = render_pep440_pre(pieces)
    elif style == "pep440-post":
        rendered = render_pep440_post(pieces)
    elif style == "pep440-post-branch":
        rendered = render_pep440_post_branch(pieces)
    elif style == "pep440-old":
        rendered = render_pep440_old(pieces)
    elif style == "git-describe":
        rendered = render_git_describe(pieces)
    elif style == "git-describe-long":
        rendered = render_git_describe_long(pieces)
    else:
        raise ValueError("unknown style '%s'" % style)

    return {"version": rendered, "full-revisionid": pieces["long"],
            "dirty": pieces["dirty"], "error": None,
            "date": pieces.get("date")}
def get_versions():
    """Get version information or return default if unable to do so."""
    # This file lives at ROOT/VERSIONFILE_SOURCE; when __file__ exists we
    # can walk back up from here to the project root.  Some py2exe /
    # bbfreeze / non-CPython setups don't provide __file__, in which case
    # only the expanded git-archive keywords can be used.
    cfg = get_config()
    verbose = cfg.verbose

    try:
        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
                                          verbose)
    except NotThisMethod:
        pass

    try:
        root = os.path.realpath(__file__)
        # versionfile_source is the relative path from the top of the
        # source tree (where .git might live) to this file; invert it to
        # find the root from __file__.
        for _ in cfg.versionfile_source.split('/'):
            root = os.path.dirname(root)
    except NameError:
        return {"version": "0+unknown", "full-revisionid": None,
                "dirty": None,
                "error": "unable to find root of source tree",
                "date": None}

    try:
        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
        return render(pieces, cfg.style)
    except NotThisMethod:
        pass

    try:
        if cfg.parentdir_prefix:
            return versions_from_parentdir(cfg.parentdir_prefix, root,
                                           verbose)
    except NotThisMethod:
        pass

    return {"version": "0+unknown", "full-revisionid": None,
            "dirty": None,
            "error": "unable to compute version", "date": None}
+ A list of values can be provided, in the order [W, H, D], + to specify dimension-specific interpolation orders.""" + +_doc_bound = \ +"""`bound` can be an int, a string or a BoundType. + Possible values are: + - 'replicate' or 'nearest' : a a a | a b c d | d d d + - 'dct1' or 'mirror' : d c b | a b c d | c b a + - 'dct2' or 'reflect' : c b a | a b c d | d c b + - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c + - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b + - 'dft' or 'wrap' : b c d | a b c d | a b c + - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0 + A list of values can be provided, in the order [W, H, D], + to specify dimension-specific boundary conditions. + Note that + - `dft` corresponds to circular padding + - `dct2` corresponds to Neumann boundary conditions (symmetric) + - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric) + See https://en.wikipedia.org/wiki/Discrete_cosine_transform + https://en.wikipedia.org/wiki/Discrete_sine_transform""" + +_doc_bound_coeff = \ +"""`bound` can be an int, a string or a BoundType. + Possible values are: + - 'replicate' or 'nearest' : a a a | a b c d | d d d + - 'dct1' or 'mirror' : d c b | a b c d | c b a + - 'dct2' or 'reflect' : c b a | a b c d | d c b + - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c + - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b + - 'dft' or 'wrap' : b c d | a b c d | a b c + - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0 + A list of values can be provided, in the order [W, H, D], + to specify dimension-specific boundary conditions. 
def _preproc(grid, input=None, mode=None):
    """Preprocess tensors for pull/push/count/grad.

    The low-level bindings expect inputs shaped [batch, channel, *spatial]
    and [batch, *spatial, dim], whereas the high-level Python API accepts
    [..., [channel], *spatial] and [..., *spatial, dim].  This broadcasts
    and reshapes accordingly.

    /!\\ This *can* trigger large allocations /!\\
    """
    dim = grid.shape[-1]
    if input is None:
        # count-style call: only the sampling grid is given
        spatial = grid.shape[-dim - 1:-1]
        batch = grid.shape[:-dim - 1]
        grid = grid.reshape([-1, *spatial, dim])
        return grid, dict(batch=batch,
                          channel=[1] if batch else [],
                          dim=dim)

    grid_spatial = grid.shape[-dim - 1:-1]
    grid_batch = grid.shape[:-dim - 1]
    input_spatial = input.shape[-dim:]
    # a channel axis is present only when input has more than `dim` axes
    channel = 0 if input.dim() == dim else input.shape[-dim - 1]
    input_batch = input.shape[:-dim - 1]

    if mode == 'push':
        # in push mode grid and input must share their spatial shape
        grid_spatial = input_spatial = expanded_shape(grid_spatial,
                                                      input_spatial)

    # broadcast batch dimensions, then flatten them into a single axis
    batch = expanded_shape(grid_batch, input_batch)
    grid = grid.expand([*batch, *grid_spatial, dim])
    grid = grid.reshape([-1, *grid_spatial, dim])
    input = input.expand([*batch, channel or 1, *input_spatial])
    input = input.reshape([-1, channel or 1, *input_spatial])

    out_channel = [channel] if channel else ([1] if batch else [])
    return grid, input, dict(batch=batch, channel=out_channel, dim=dim)


def _postproc(out, shape_info, mode):
    """Undo _preproc: restore the caller's batch/channel layout."""
    dim = shape_info['dim']
    if mode != 'grad':
        spatial = out.shape[-dim:]
        feat = []
    else:
        # grad outputs carry a trailing gradient-direction axis
        spatial = out.shape[-dim - 1:-1]
        feat = [out.shape[-1]]

    return out.reshape([*shape_info['batch'], *shape_info['channel'],
                        *spatial, *feat])
def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample an image with respect to a deformation field.

    Notes
    -----
    {interpolation}

    {bound}

    If the input dtype is not a floating point type, the input image is
    assumed to contain labels. Then, unique labels are extracted
    and resampled individually, making them soft labels. Finally,
    the label map is reconstructed from the individual soft labels by
    assigning the label with maximum soft value.

    Parameters
    ----------
    input : (..., [channel], *inshape) tensor
        Input image.
    grid : (..., *outshape, dim) tensor
        Transformation field.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType or sequence[BoundType], default='zero'
        Boundary conditions.
    extrapolate : bool or int, default=False
        Extrapolate out-of-bound data.
    prefilter : bool, default=False
        Apply spline pre-filter (= interpolates the input)

    Returns
    -------
    output : (..., [channel], *outshape) tensor
        Deformed image.

    """
    if backend.jitfields and jitfields.available:
        return jitfields.grid_pull(input, grid, interpolation, bound,
                                   extrapolate, prefilter)

    grid, input, shape_info = _preproc(grid, input)
    batch, channel = input.shape[:2]
    dim = grid.shape[-1]

    if not input.dtype.is_floating_point:
        # label map -> resample each label as a soft (one-hot) map and
        # keep, per voxel, the label with the largest soft value
        out = input.new_zeros([batch, channel, *grid.shape[1:-1]])
        pmax = grid.new_zeros([batch, channel, *grid.shape[1:-1]])
        for label in input.unique():
            soft = (input == label).to(grid.dtype)
            if prefilter:
                # BUG FIX: previously the pre-filtered coefficients were
                # assigned to `input` (clobbering the label map for the
                # remaining iterations) while the *unfiltered* `soft` was
                # sampled; filter `soft` itself instead.
                soft = spline_coeff_nd(soft, interpolation=interpolation,
                                       bound=bound, dim=dim, inplace=True)
            soft = GridPull.apply(soft, grid, interpolation, bound,
                                  extrapolate)
            out[soft > pmax] = label
            pmax = torch.max(pmax, soft)
    else:
        if prefilter:
            input = spline_coeff_nd(input, interpolation=interpolation,
                                    bound=bound, dim=dim)
        out = GridPull.apply(input, grid, interpolation, bound, extrapolate)

    return _postproc(out, shape_info, mode='pull')


def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Splat an image with respect to a deformation field (pull adjoint).

    Notes
    -----
    {interpolation}

    {bound}

    Parameters
    ----------
    input : (..., [channel], *inshape) tensor
        Input image.
    grid : (..., *inshape, dim) tensor
        Transformation field.
    shape : sequence[int], default=inshape
        Output shape
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType, or sequence[BoundType], default='zero'
        Boundary conditions.
    extrapolate : bool or int, default=False
        Extrapolate out-of-bound data.
    prefilter : bool, default=False
        Apply spline pre-filter.

    Returns
    -------
    output : (..., [channel], *shape) tensor
        Splatted image.

    """
    if backend.jitfields and jitfields.available:
        return jitfields.grid_push(input, grid, shape, interpolation, bound,
                                   extrapolate, prefilter)

    grid, input, shape_info = _preproc(grid, input, mode='push')
    dim = grid.shape[-1]

    if shape is None:
        shape = tuple(input.shape[2:])

    out = GridPush.apply(input, grid, shape, interpolation, bound,
                         extrapolate)
    if prefilter:
        # push is the adjoint of pull, so the filter is applied after
        out = spline_coeff_nd(out, interpolation=interpolation, bound=bound,
                              dim=dim, inplace=True)
    return _postproc(out, shape_info, mode='push')
+ + """ + if backend.jitfields and jitfields.available: + return jitfields.grid_count(grid, shape, interpolation, bound, extrapolate) + + grid, shape_info = _preproc(grid) + out = GridCount.apply(grid, shape, interpolation, bound, extrapolate) + return _postproc(out, shape_info, mode='count') + + +def grid_grad(input, grid, interpolation='linear', bound='zero', + extrapolate=False, prefilter=False): + """Sample spatial gradients of an image with respect to a deformation field. + + Notes + ----- + {interpolation} + + {bound} + + Parameters + ---------- + input : (..., [channel], *inshape) tensor + Input image. + grid : (..., *inshape, dim) tensor + Transformation field. + shape : sequence[int], default=inshape + Output shape + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType, or sequence[BoundType], default='zero' + Boundary conditions. + extrapolate : bool or int, default=True + Extrapolate out-of-bound data. + prefilter : bool, default=False + Apply spline pre-filter (= interpolates the input) + + Returns + ------- + output : (..., [channel], *shape, dim) tensor + Sampled gradients. + + """ + if backend.jitfields and jitfields.available: + return jitfields.grid_grad(input, grid, interpolation, bound, + extrapolate, prefilter) + + grid, input, shape_info = _preproc(grid, input) + dim = grid.shape[-1] + if prefilter: + input = spline_coeff_nd(input, interpolation, bound, dim) + out = GridGrad.apply(input, grid, interpolation, bound, extrapolate) + return _postproc(out, shape_info, mode='grad') + + +def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1, + inplace=False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary conditions, along a single dimension. + + Notes + ----- + {interpolation} + + {bound} + + References + ---------- + {ref} + + + Parameters + ---------- + input : tensor + Input image. 
+ interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType or sequence[BoundType], default='dct1' + Boundary conditions. + dim : int, default=-1 + Dimension along which to process + inplace : bool, default=False + Process the volume in place. + + Returns + ------- + output : tensor + Coefficient image. + + """ + # This implementation is based on the file bsplines.c in SPM12, written + # by John Ashburner, which is itself based on the file coeff.c, + # written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation + # . DCT1 boundary conditions were derived by Thevenaz and Unser. + # . DFT boundary conditions were derived by John Ashburner. + # SPM12 is released under the GNU-GPL v2 license. + # Philippe Thevenaz's code does not have an explicit license as far + # as we know. + if backend.jitfields and jitfields.available: + return jitfields.spline_coeff(input, interpolation, bound, + dim, inplace) + + out = SplineCoeff.apply(input, bound, interpolation, dim, inplace) + return out + + +def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None, + inplace=False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary conditions, along the last `dim` dimensions. + + Notes + ----- + {interpolation} + + {bound} + + References + ---------- + {ref} + + Parameters + ---------- + input : (..., *spatial) tensor + Input image. + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType or sequence[BoundType], default='dct1' + Boundary conditions. + dim : int, default=-1 + Number of spatial dimensions + inplace : bool, default=False + Process the volume in place. + + Returns + ------- + output : (..., *spatial) tensor + Coefficient image. 
+ + """ + # This implementation is based on the file bsplines.c in SPM12, written + # by John Ashburner, which is itself based on the file coeff.c, + # written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation + # . DCT1 boundary conditions were derived by Thevenaz and Unser. + # . DFT boundary conditions were derived by John Ashburner. + # SPM12 is released under the GNU-GPL v2 license. + # Philippe Thevenaz's code does not have an explicit license as far + # as we know. + if backend.jitfields and jitfields.available: + return jitfields.spline_coeff_nd(input, interpolation, bound, + dim, inplace) + + out = SplineCoeffND.apply(input, bound, interpolation, dim, inplace) + return out + + +grid_pull.__doc__ = grid_pull.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_push.__doc__ = grid_push.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_count.__doc__ = grid_count.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_grad.__doc__ = grid_grad.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +spline_coeff.__doc__ = spline_coeff.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff) +spline_coeff_nd.__doc__ = spline_coeff_nd.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff) + +# aliases +pull = grid_pull +push = grid_push +count = grid_count + + +def identity_grid(shape, dtype=None, device=None): + """Returns an identity deformation field. + + Parameters + ---------- + shape : (dim,) sequence of int + Spatial dimension of the field. + dtype : torch.dtype, default=`get_default_dtype()` + Data type. + device torch.device, optional + Device. 
def affine_grid(mat, shape):
    """Create a dense transformation grid from an affine matrix.

    Parameters
    ----------
    mat : (..., D[+1], D+1) tensor
        Affine matrix (or matrices).
    shape : (D,) sequence[int]
        Shape of the grid, with length D.

    Returns
    -------
    grid : (..., *shape, D) tensor
        Dense transformation grid

    Raises
    ------
    ValueError
        If the matrix dimension does not match ``len(shape)``, or if the
        matrix does not have D or D+1 rows.

    """
    mat = torch.as_tensor(mat)
    shape = list(shape)
    nb_dim = mat.shape[-1] - 1
    if nb_dim != len(shape):
        raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
                         'are not the same.'.format(nb_dim, len(shape)))
    if mat.shape[-2] not in (nb_dim, nb_dim+1):
        # BUG FIX: the original format string contained the malformed field
        # '{1]', so str.format itself raised instead of producing this
        # message; also fixed the 'matrces' typo.
        raise ValueError('First argument should be matrices of shape '
                         '(..., {0}, {1}) or (..., {1}, {1}) but got {2}.'
                         .format(nb_dim, nb_dim+1, mat.shape))
    batch_shape = mat.shape[:-2]
    grid = identity_grid(shape, mat.dtype, mat.device)
    if batch_shape:
        # Prepend singleton dims so the identity grid broadcasts against
        # the batch dimensions of `mat`.
        for _ in range(len(batch_shape)):
            grid = grid.unsqueeze(0)
        for _ in range(nb_dim):
            # NOTE(review): unsqueeze(-1) appends trailing singleton dims
            # after the matrix dims; broadcasting against the grid's spatial
            # dims would normally need unsqueeze(-3) here. Depends on how
            # `matvec` broadcasts -- TODO confirm against its definition.
            mat = mat.unsqueeze(-1)
    lin = mat[..., :nb_dim, :nb_dim]
    off = mat[..., :nb_dim, -1]
    grid = matvec(lin, grid) + off
    return grid
def inter_to_nitorch(inter, as_type='str'):
    """Convert interpolation order to NITorch's convention.

    Parameters
    ----------
    inter : [sequence of] int or str or InterpolationType
        Interpolation order in any accepted convention.
    as_type : {'str', 'enum', 'int'}, default='str'
        Type of the returned value(s).

    Returns
    -------
    inter : [sequence of] int or str or InterpolationType
        Interpolation order in NITorch's convention.

    Raises
    ------
    ValueError
        If an order is not recognized.

    """
    # Remember the input container type so a list/tuple input is returned
    # as the same container and a scalar input as a scalar.
    intype = type(inter)
    if not isinstance(inter, (list, tuple)):
        inter = [inter]
    ointer = []
    for o in inter:
        o = o.lower() if isinstance(o, str) else o
        if o in (0, 'nearest', InterpolationType.nearest):
            ointer.append(0)
        elif o in (1, 'linear', InterpolationType.linear):
            ointer.append(1)
        elif o in (2, 'quadratic', InterpolationType.quadratic):
            ointer.append(2)
        elif o in (3, 'cubic', InterpolationType.cubic):
            ointer.append(3)
        elif o in (4, 'fourth', InterpolationType.fourth):
            ointer.append(4)
        elif o in (5, 'fifth', InterpolationType.fifth):
            ointer.append(5)
        elif o in (6, 'sixth', InterpolationType.sixth):
            ointer.append(6)
        elif o in (7, 'seventh', InterpolationType.seventh):
            ointer.append(7)
        else:
            raise ValueError(f'Unknown interpolation order {o}')
    # 'str' goes through the enum first, then takes the member names.
    if as_type in ('enum', 'str', str):
        ointer = list(map(InterpolationType, ointer))
    if as_type in ('str', str):
        ointer = [o.name for o in ointer]
    if issubclass(intype, (list, tuple)):
        ointer = intype(ointer)
    else:
        ointer = ointer[0]
    return ointer
grad_input, grad_grid, None, None, None + + +class GridPush(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input, grid, shape, interpolation, bound, extrapolate): + + bound = bound_to_nitorch(make_list(bound), as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation), as_type='int') + extrapolate = int(extrapolate) + opt = (bound, interpolation, extrapolate) + + # Push + output = grid_push(input, grid, shape, *opt) + + # Context + ctx.opt = opt + ctx.save_for_backward(input, grid) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad): + var = ctx.saved_tensors + opt = ctx.opt + grads = grid_push_backward(grad, *var, *opt) + grad_input, grad_grid = grads + return grad_input, grad_grid, None, None, None, None + + +class GridCount(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid, shape, interpolation, bound, extrapolate): + + bound = bound_to_nitorch(make_list(bound), as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation), as_type='int') + extrapolate = int(extrapolate) + opt = (bound, interpolation, extrapolate) + + # Push + output = grid_count(grid, shape, *opt) + + # Context + ctx.opt = opt + ctx.save_for_backward(grid) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad): + var = ctx.saved_tensors + opt = ctx.opt + grad_grid = None + if ctx.needs_input_grad[0]: + grad_grid = grid_count_backward(grad, *var, *opt) + return grad_grid, None, None, None, None + + +class GridGrad(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input, grid, interpolation, bound, extrapolate): + + bound = bound_to_nitorch(make_list(bound), as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation), as_type='int') + extrapolate = int(extrapolate) + opt = (bound, interpolation, extrapolate) + + # Pull + output = 
grid_grad(input, grid, *opt) + + # Context + ctx.opt = opt + ctx.save_for_backward(input, grid) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad): + var = ctx.saved_tensors + opt = ctx.opt + grad_input = grad_grid = None + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grads = grid_grad_backward(grad, *var, *opt) + grad_input, grad_grid = grads + return grad_input, grad_grid, None, None, None + + +class SplineCoeff(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, input, bound, interpolation, dim, inplace): + + bound = bound_to_nitorch(make_list(bound)[0], as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation)[0], as_type='int') + opt = (bound, interpolation, dim, inplace) + + # Pull + output = spline_coeff(input, *opt) + + # Context + if input.requires_grad: + ctx.opt = opt + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad): + # symmetric filter -> backward == forward + # (I don't know if I can write into grad, so inplace=False to be safe) + grad = spline_coeff(grad, *ctx.opt[:-1], inplace=False) + return [grad] + [None] * 4 + + +class SplineCoeffND(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, input, bound, interpolation, dim, inplace): + + bound = bound_to_nitorch(make_list(bound), as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation), as_type='int') + opt = (bound, interpolation, dim, inplace) + + # Pull + output = spline_coeff_nd(input, *opt) + + # Context + if input.requires_grad: + ctx.opt = opt + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad): + # symmetric filter -> backward == forward + # (I don't know if I can write into grad, so inplace=False to be safe) + grad = spline_coeff_nd(grad, *ctx.opt[:-1], inplace=False) + return grad, None, None, None, None diff --git a/utils/interpol/backend.py b/utils/interpol/backend.py new file mode 100644 index 
0000000000000000000000000000000000000000..a5e3a8386a8f6f4c23c3932039d1e8540f6b7135 --- /dev/null +++ b/utils/interpol/backend.py @@ -0,0 +1 @@ +jitfields = False # Whether to use jitfields if available diff --git a/utils/interpol/bounds.py b/utils/interpol/bounds.py new file mode 100644 index 0000000000000000000000000000000000000000..67ece9415195b3547eee53a7e21e632ffe7c5a79 --- /dev/null +++ b/utils/interpol/bounds.py @@ -0,0 +1,89 @@ +import torch +from enum import Enum +from typing import Optional +from .jit_utils import floor_div +Tensor = torch.Tensor + + +class BoundType(Enum): + zero = zeros = 0 + replicate = nearest = 1 + dct1 = mirror = 2 + dct2 = reflect = 3 + dst1 = antimirror = 4 + dst2 = antireflect = 5 + dft = wrap = 6 + + +class ExtrapolateType(Enum): + no = 0 # threshold: (0, n-1) + yes = 1 + hist = 2 # threshold: (-0.5, n-0.5) + + +@torch.jit.script +class Bound: + + def __init__(self, bound_type: int = 3): + self.type = bound_type + + def index(self, i, n: int): + if self.type in (0, 1): # zero / replicate + return i.clamp(min=0, max=n-1) + elif self.type in (3, 5): # dct2 / dst2 + n2 = n * 2 + i = torch.where(i < 0, (-i-1).remainder(n2).neg().add(n2 - 1), + i.remainder(n2)) + i = torch.where(i >= n, -i + (n2 - 1), i) + return i + elif self.type == 2: # dct1 + if n == 1: + return torch.zeros(i.shape, dtype=i.dtype, device=i.device) + else: + n2 = (n - 1) * 2 + i = i.abs().remainder(n2) + i = torch.where(i >= n, -i + n2, i) + return i + elif self.type == 4: # dst1 + n2 = 2 * (n + 1) + first = torch.zeros([1], dtype=i.dtype, device=i.device) + last = torch.full([1], n - 1, dtype=i.dtype, device=i.device) + i = torch.where(i < 0, -i - 2, i) + i = i.remainder(n2) + i = torch.where(i > n, -i + (n2 - 2), i) + i = torch.where(i == -1, first, i) + i = torch.where(i == n, last, i) + return i + elif self.type == 6: # dft + return i.remainder(n) + else: + return i + + def transform(self, i, n: int) -> Optional[Tensor]: + if self.type == 4: # dst1 + if n == 1: 
+ return None + one = torch.ones([1], dtype=torch.int8, device=i.device) + zero = torch.zeros([1], dtype=torch.int8, device=i.device) + n2 = 2 * (n + 1) + i = torch.where(i < 0, -i + (n-1), i) + i = i.remainder(n2) + x = torch.where(i == 0, zero, one) + x = torch.where(i.remainder(n + 1) == n, zero, x) + i = floor_div(i, n+1) + x = torch.where(torch.remainder(i, 2) > 0, -x, x) + return x + elif self.type == 5: # dst2 + i = torch.where(i < 0, n - 1 - i, i) + x = torch.ones([1], dtype=torch.int8, device=i.device) + i = floor_div(i, n) + x = torch.where(torch.remainder(i, 2) > 0, -x, x) + return x + elif self.type == 0: # zero + one = torch.ones([1], dtype=torch.int8, device=i.device) + zero = torch.zeros([1], dtype=torch.int8, device=i.device) + outbounds = ((i < 0) | (i >= n)) + x = torch.where(outbounds, zero, one) + return x + else: + return None diff --git a/utils/interpol/coeff.py b/utils/interpol/coeff.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d6c047d090f1e60e18cf858f6ffe8454488f71 --- /dev/null +++ b/utils/interpol/coeff.py @@ -0,0 +1,344 @@ +"""Compute spline interpolating coefficients + +These functions are ported from the C routines in SPM's bsplines.c +by John Ashburner, which are themselves ports from Philippe Thevenaz's +code. JA furthermore derived the initial conditions for the DFT ("wrap around") +boundary conditions. + +Note that similar routines are available in scipy with boundary conditions +DCT1 ("mirror"), DCT2 ("reflect") and DFT ("wrap"); all derived by P. Thevenaz, +according to the comments. Our DCT2 boundary conditions are ported from +scipy. + +Only boundary conditions DCT1, DCT2 and DFT are implemented. + +References +---------- +..[1] M. Unser, A. Aldroubi and M. Eden. + "B-Spline Signal Processing: Part I-Theory," + IEEE Transactions on Signal Processing 41(2):821-832 (1993). +..[2] M. Unser, A. Aldroubi and M. Eden. 
+ "B-Spline Signal Processing: Part II-Efficient Design and Applications," + IEEE Transactions on Signal Processing 41(2):834-848 (1993). +..[3] M. Unser. + "Splines: A Perfect Fit for Signal and Image Processing," + IEEE Signal Processing Magazine 16(6):22-38 (1999). +""" +import torch +import math +from typing import List, Optional +from .jit_utils import movedim1 +from .pushpull import pad_list_int + + +@torch.jit.script +def get_poles(order: int) -> List[float]: + empty: List[float] = [] + if order in (0, 1): + return empty + if order == 2: + return [math.sqrt(8.) - 3.] + if order == 3: + return [math.sqrt(3.) - 2.] + if order == 4: + return [math.sqrt(664. - math.sqrt(438976.)) + math.sqrt(304.) - 19., + math.sqrt(664. + math.sqrt(438976.)) - math.sqrt(304.) - 19.] + if order == 5: + return [math.sqrt(67.5 - math.sqrt(4436.25)) + math.sqrt(26.25) - 6.5, + math.sqrt(67.5 + math.sqrt(4436.25)) - math.sqrt(26.25) - 6.5] + if order == 6: + return [-0.488294589303044755130118038883789062112279161239377608394, + -0.081679271076237512597937765737059080653379610398148178525368, + -0.00141415180832581775108724397655859252786416905534669851652709] + if order == 7: + return [-0.5352804307964381655424037816816460718339231523426924148812, + -0.122554615192326690515272264359357343605486549427295558490763, + -0.0091486948096082769285930216516478534156925639545994482648003] + raise NotImplementedError + + +@torch.jit.script +def get_gain(poles: List[float]) -> float: + lam: float = 1. + for pole in poles: + lam *= (1. - pole) * (1. 
- 1./pole) + return lam + + +@torch.jit.script +def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): + + assert inp.shape[dim] > 1 + max_iter: int = int(math.ceil(-30./math.log(abs(pole)))) + max_iter = min(max_iter, inp.shape[dim]) + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device)) + poles = poles.flip(0) + + inp = movedim1(inp, dim, 0) + inp0 = inp[0] + inp = inp[1-max_iter:] + inp = movedim1(inp, 0, -1) + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + out = out + inp0.unsqueeze(-1) + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + pole = pole ** max_iter + out = out / (1 - pole) + return out + + +@torch.jit.script +def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): + + n = inp.shape[dim] + max_iter: int = int(math.ceil(-30./math.log(abs(pole)))) + + if max_iter < n: + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device)) + + inp = movedim1(inp, dim, 0) + inp0 = inp[0] + inp = inp[1:max_iter] + inp = movedim1(inp, 0, -1) + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + out = out + inp0.unsqueeze(-1) + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + else: + max_iter = n + + polen = pole ** (n - 1) + inp0 = inp[0] + polen * inp[-1] + inp = inp[1:-1] + inp = movedim1(inp, 0, -1) + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device)) + poles = poles + (polen * polen) / poles + + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + out = out + inp0.unsqueeze(-1) + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + pole = pole ** (max_iter - 1) + out = out / (1 - pole * pole) + + 
return out + + +@torch.jit.script +def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): + # Ported from scipy: + # https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c + # + # I (YB) unwarped and simplied the terms so that I could use a dot + # product instead of a loop. + # It should certainly be possible to derive a version for max_iter < n, + # as JA did for DCT1, to avoid long recursions when `n` is large. But + # I think it would require a more complicated anticausal/final condition. + + n = inp.shape[dim] + + polen = pole ** n + pole_last = polen * (1 + 1/(pole + polen * polen)) + inp00 = inp[0] + inp0 = inp[0] + pole_last * inp[-1] + inp = inp[1:-1] + inp = movedim1(inp, 0, -1) + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = (poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device)) + + poles.pow(torch.arange(2*n-2, n, -1, dtype=inp.dtype, device=inp.device))) + + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + + out = out + inp0.unsqueeze(-1) + out = out * (pole / (1 - polen * polen)) + out = out + inp00.unsqueeze(-1) + + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + return out + + +@torch.jit.script +def dft_final(inp, pole: float, dim: int = -1, keepdim: bool = False): + + assert inp.shape[dim] > 1 + max_iter: int = int(math.ceil(-30./math.log(abs(pole)))) + max_iter = min(max_iter, inp.shape[dim]) + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = poles.pow(torch.arange(2, max_iter+1, dtype=inp.dtype, device=inp.device)) + + inp = movedim1(inp, dim, 0) + inp0 = inp[-1] + inp = inp[:max_iter-1] + inp = movedim1(inp, 0, -1) + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + out = out.add(inp0.unsqueeze(-1), alpha=pole) + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + pole = pole ** max_iter + out = out / (pole - 1) + return out + + 
+@torch.jit.script +def dct1_final(inp, pole: float, dim: int = -1, keepdim: bool = False): + inp = movedim1(inp, dim, 0) + out = pole * inp[-2] + inp[-1] + out = out * (pole / (pole*pole - 1)) + if keepdim: + out = movedim1(out.unsqueeze(0), 0, dim) + return out + + +@torch.jit.script +def dct2_final(inp, pole: float, dim: int = -1, keepdim: bool = False): + # Ported from scipy: + # https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c + inp = movedim1(inp, dim, 0) + out = inp[-1] * (pole / (pole - 1)) + if keepdim: + out = movedim1(out.unsqueeze(0), 0, dim) + return out + + +@torch.jit.script +class CoeffBound: + + def __init__(self, bound: int): + self.bound = bound + + def initial(self, inp, pole: float, dim: int = -1, keepdim: bool = False): + if self.bound in (0, 2): # zero, dct1 + return dct1_initial(inp, pole, dim, keepdim) + elif self.bound in (1, 3): # nearest, dct2 + return dct2_initial(inp, pole, dim, keepdim) + elif self.bound == 6: # dft + return dft_initial(inp, pole, dim, keepdim) + else: + raise NotImplementedError + + def final(self, inp, pole: float, dim: int = -1, keepdim: bool = False): + if self.bound in (0, 2): # zero, dct1 + return dct1_final(inp, pole, dim, keepdim) + elif self.bound in (1, 3): # nearest, dct2 + return dct2_final(inp, pole, dim, keepdim) + elif self.bound == 6: # dft + return dft_final(inp, pole, dim, keepdim) + else: + raise NotImplementedError + + +@torch.jit.script +def filter(inp, bound: CoeffBound, poles: List[float], + dim: int = -1, inplace: bool = False): + + if not inplace: + inp = inp.clone() + + if inp.shape[dim] == 1: + return inp + + gain = get_gain(poles) + inp *= gain + inp = movedim1(inp, dim, 0) + n = inp.shape[0] + + for pole in poles: + inp[0] = bound.initial(inp, pole, dim=0, keepdim=False) + + for i in range(1, n): + inp[i].add_(inp[i-1], alpha=pole) + + inp[-1] = bound.final(inp, pole, dim=0, keepdim=False) + + for i in range(n-2, -1, -1): + inp[i].neg_().add_(inp[i+1]).mul_(pole) + 
+ inp = movedim1(inp, 0, dim) + return inp + + +@torch.jit.script +def spline_coeff(inp, bound: int, order: int, dim: int = -1, + inplace: bool = False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary conditions, along a single dimension. + + Parameters + ---------- + inp : tensor + bound : {2: dct1, 6: dft} + order : {0..7} + dim : int, default=-1 + inplace : bool, default=False + + Returns + ------- + out : tensor + + """ + if not inplace: + inp = inp.clone() + + if order in (0, 1): + return inp + + poles = get_poles(order) + return filter(inp, CoeffBound(bound), poles, dim=dim, inplace=True) + + +@torch.jit.script +def spline_coeff_nd(inp, bound: List[int], order: List[int], + dim: Optional[int] = None, inplace: bool = False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary condition, along the last `dim` dimensions. + + Parameters + ---------- + inp : (..., *spatial) tensor + bound : List[{2: dct1, 6: dft}] + order : List[{0..7}] + dim : int, default=`inp.dim()` + inplace : bool, default=False + + Returns + ------- + out : (..., *spatial) tensor + + """ + if not inplace: + inp = inp.clone() + + if dim is None: + dim = inp.dim() + + bound = pad_list_int(bound, dim) + order = pad_list_int(order, dim) + + for d, b, o in zip(range(dim), bound, order): + inp = spline_coeff(inp, b, o, dim=-dim + d, inplace=True) + + return inp diff --git a/utils/interpol/iso0.py b/utils/interpol/iso0.py new file mode 100644 index 0000000000000000000000000000000000000000..7f43a81be437ac96cb9dd39d9f735ca6890c6f95 --- /dev/null +++ b/utils/interpol/iso0.py @@ -0,0 +1,368 @@ +"""Isotropic 0-th order splines ("nearest neighbor")""" +import torch +from .bounds import Bound +from .jit_utils import (sub2ind_list, make_sign, + inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d) +from typing import List, Optional +Tensor = torch.Tensor + + +@torch.jit.script +def get_indices(g, n: int, bound: Bound): + 
g0 = g.round().long() + sign0 = bound.transform(g0, n) + g0 = bound.index(g0, n) + return g0, sign0 + + +# ====================================================================== +# 3D +# ====================================================================== + + +@torch.jit.script +def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, oX, oY, oZ, 3) tensor + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, oZ) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + oshape = g.shape[-dim-1:-1] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = g.unbind(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = inp.shape[-dim:] + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + gy, signy = get_indices(gy, ny, boundy) + gz, signz = get_indices(gz, nz, boundz) + + # gather + inp = inp.reshape(inp.shape[:2] + [-1]) + idx = sub2ind_list([gx, gy, gz], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = make_sign([signx, signy, signz]) + if sign is not None: + out *= sign + if mask is not None: + out *= mask + out = out.reshape(out.shape[:2] + oshape) + return out + + +@torch.jit.script +def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, iX, iY, iZ, 3) tensor + shape: List{3}[int], optional + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = inp.shape[-dim:] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = torch.unbind(g, -1) + inp = inp.reshape(inp.shape[:2] + 
@torch.jit.script
def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """
    inp: (B, C, iX, iY) tensor
    g: (B, oX, oY, 2) tensor
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY) tensor
    """
    dim = 2
    boundx, boundy = bound
    oshape = g.shape[-dim-1:-1]
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = g.unbind(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = inp.shape[-dim:]
    nx, ny = shape

    # mask of inbounds voxels
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # nearest integer coordinates
    gx, signx = get_indices(gx, nx, boundx)
    gy, signy = get_indices(gy, ny, boundy)

    # gather
    inp = inp.reshape(inp.shape[:2] + [-1])
    idx = sub2ind_list([gx, gy], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out = inp.gather(-1, idx)
    sign = make_sign([signx, signy])
    if sign is not None:
        out = out * sign
    if mask is not None:
        # BUG FIX: was `out = mask * mask`, which discarded the gathered
        # values and returned the squared mask instead of the masked
        # output (cf. pull3d/pull1d, which multiply `out` by `mask`).
        out = out * mask
    out = out.reshape(out.shape[:2] + oshape)
    return out
channel = inp.shape[1] + shape = inp.shape[-dim:] + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + + # gather + inp = inp.reshape(inp.shape[:2] + [-1]) + idx = gx + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = signx + if sign is not None: + out = out * sign + if mask is not None: + out = out * mask + out = out.reshape(out.shape[:2] + oshape) + return out + + +@torch.jit.script +def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, iX, 1) tensor + shape: List{1}[int], optional + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 1 + boundx = bound[0] + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = inp.shape[-dim:] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + inp = inp.reshape(inp.shape[:2] + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + + # scatter + out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device) + idx = gx + idx = idx.expand([batch, channel, idx.shape[-1]]) + sign = signx + if sign is not None or mask is not None: + inp = inp.clone() + if sign is not None: + inp = inp * sign + if mask is not None: + inp = inp * mask + out.scatter_add_(-1, idx, inp) + + out = out.reshape(out.shape[:2] + shape) + return out + + +# ====================================================================== +# ND +# ====================================================================== + + +@torch.jit.script +def grad(inp, g, bound: 
List[Bound], extrapolate: int = 1):
    """
    inp: (B, C, *ishape) tensor
    g: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *oshape, D) tensor
    """
    dim = g.shape[-1]
    oshape = list(g.shape[-dim-1:-1])
    batch = max(inp.shape[0], g.shape[0])
    channel = inp.shape[1]

    # Returns an all-zero tensor — presumably because a 0-th-order (nearest)
    # interpolant is piecewise constant, so its gradient is zero a.e.
    return torch.zeros([batch, channel] + oshape + [dim],
                       dtype=inp.dtype, device=inp.device)


@torch.jit.script
def pushgrad(inp, g, shape: Optional[List[int]], bound: List[Bound],
             extrapolate: int = 1):
    """
    inp: (B, C, *ishape, D) tensor
    g: (B, *ishape, D) tensor
    shape: List{D}[int], optional
    bound: List{D}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = g.shape[-1]
    if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = inp.shape[-dim-1:-1]
    batch = max(inp.shape[0], g.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)

    # All zeros: the adjoint of a zero gradient operator is zero.
    return torch.zeros([batch, channel] + shape,
                       dtype=inp.dtype, device=inp.device)


@torch.jit.script
def hess(inp, g, bound: List[Bound], extrapolate: int = 1):
    """
    inp: (B, C, *ishape) tensor
    g: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *oshape, D, D) tensor
    """
    dim = g.shape[-1]
    oshape = list(g.shape[-dim-1:-1])
    # NOTE(review): this reshape has no effect on the returned zeros
    # (g.shape[0] is preserved) — looks like dead code kept for symmetry.
    g = g.reshape([g.shape[0], 1, -1, dim])
    batch = max(inp.shape[0], g.shape[0])
    channel = inp.shape[1]

    return torch.zeros([batch, channel] + oshape + [dim, dim],
                       dtype=inp.dtype, device=inp.device)
diff --git a/utils/interpol/iso1.py b/utils/interpol/iso1.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa21f12d9a68532cec2ebdcb4ce6ef7c75d6d6a6
--- /dev/null
+++ b/utils/interpol/iso1.py
@@ -0,0 +1,1339 @@
"""Isotropic 1-st order splines ("linear/bilinear/trilinear")"""
import torch
from .bounds import Bound
+from .jit_utils import (sub2ind_list, make_sign, + inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d) +from typing import List, Tuple, Optional +Tensor = torch.Tensor + + +@torch.jit.script +def get_weights_and_indices(g, n: int, bound: Bound) \ + -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + g0 = g.floor().long() + g1 = g0 + 1 + sign1 = bound.transform(g1, n) + sign0 = bound.transform(g0, n) + g1 = bound.index(g1, n) + g0 = bound.index(g0, n) + g = g - g.floor() + return g, g0, g1, sign0, sign1 + + +# ====================================================================== +# 3D +# ====================================================================== + + +@torch.jit.script +def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, oX, oY, oZ, 3) tensor + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, oZ) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = g.unbind(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out = out * sign + out = out * ((1 - gx) * (1 - gy) * (1 - gz)) + # - corner 001 + idx = sub2ind_list([gx0, 
gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * ((1 - gx) * (1 - gy) * gz) + out = out + out1 + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * ((1 - gx) * gy * (1 - gz)) + out = out + out1 + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * ((1 - gx) * gy * gz) + out = out + out1 + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * (1 - gy) * (1 - gz)) + out = out + out1 + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * (1 - gy) * gz) + out = out + out1 + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * gy * (1 - gz)) + out = out + out1 + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 = out1 * sign + out1 = out1 * (gx * gy * gz) + out = out + out1 + + if mask is 
not None:
        out *= mask
    out = out.reshape(list(out.shape[:2]) + oshape)
    return out


@torch.jit.script
def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """
    Splat: scatter-add each value into its 8 neighbouring voxels with
    trilinear weights (same weights as the pull3d gather).

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, iX, iY, iZ, 3) tensor
    shape: List{3}[int], optional
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim:])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = torch.unbind(g, -1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx, ny, nz = shape

    # mask of inbounds voxels
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # scatter
    out = torch.zeros([batch, channel, nx*ny*nz],
                      dtype=inp.dtype, device=inp.device)
    # - corner 000
    idx = sub2ind_list([gx0, gy0, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy0, signz0])
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * ((1 - gx) * (1 - gy) * (1 - gz))
    out.scatter_add_(-1, idx, out1)
    # - corner 001
    idx = sub2ind_list([gx0, gy0, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy0, signz1])
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * ((1 - gx) * (1 - gy) * gz)
    out.scatter_add_(-1, idx, out1)
    # - corner 010
    idx = sub2ind_list([gx0, gy1, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy1, signz0])
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * ((1 - gx) * gy * (1 - gz))
    out.scatter_add_(-1, idx, out1)
    # - corner 011
    idx = sub2ind_list([gx0, gy1, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy1, signz1])
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * ((1 - gx) * gy * gz)
    out.scatter_add_(-1, idx, out1)
    # - corner 100
    idx = sub2ind_list([gx1, gy0, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy0, signz0])
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * (gx * (1 - gy) * (1 - gz))
    out.scatter_add_(-1, idx, out1)
    # - corner 101
    idx = sub2ind_list([gx1, gy0, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy0, signz1])
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * (gx * (1 - gy) * gz)
    out.scatter_add_(-1, idx, out1)
    # - corner 110
    idx = sub2ind_list([gx1, gy1, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy1, signz0])
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * (gx * gy * (1 - gz))
    out.scatter_add_(-1, idx, out1)
    # - corner 111
    idx = sub2ind_list([gx1, gy1, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy1, signz1])
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * (gx * gy * gz)
    out.scatter_add_(-1, idx, out1)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out


@torch.jit.script
def grad3d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """
    Sample the spatial gradient of the trilinear interpolant; the last
    output dimension indexes the derivative direction (x, y, z). Each
    corner contributes the derivative of its trilinear weight.

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, oX, oY, oZ, 3) tensor
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, oZ, 3) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = torch.unbind(g, -1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny, nz = shape

    # mask of inbounds voxels
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel] + list(g.shape[-2:]),
                      dtype=inp.dtype, device=inp.device)
    # outx/outy/outz are views into `out`; accumulation below writes in place
    outx, outy, outz = out.unbind(-1)
    # - corner 000
    idx = sub2ind_list([gx0, gy0, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    # NOTE(review): gather with out= into a non-contiguous unbind view —
    # relies on this being supported by the target torch version; confirm.
    torch.gather(inp, -1, idx, out=outx)
    outy.copy_(outx)
    outz.copy_(outx)
    sign = make_sign([signx0, signy0, signz0])
    if sign is not None:
        out *= sign.unsqueeze(-1)
    outx *= - (1 - gy) * (1 - gz)
    outy *= - (1 - gx) * (1 - gz)
    outz *= - (1 - gx) * (1 - gy)
    # - corner 001
    idx = sub2ind_list([gx0, gy0, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy0, signz1])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, - (1 - gy) * gz)
    outy.addcmul_(out1, - (1 - gx) * gz)
    outz.addcmul_(out1, (1 - gx) * (1 - gy))
    # - corner 010
    idx = sub2ind_list([gx0, gy1, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz0])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, - gy * (1 - gz))
    outy.addcmul_(out1, (1 - gx) * (1 - gz))
    outz.addcmul_(out1, - (1 - gx) * gy)
    # - corner 011
    idx = sub2ind_list([gx0, gy1, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz1])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, - gy * gz)
    outy.addcmul_(out1, (1 - gx) * gz)
    outz.addcmul_(out1, (1 - gx) * gy)
    # - corner 100
    idx = sub2ind_list([gx1, gy0, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz0])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, (1 - gy) * (1 - gz))
    outy.addcmul_(out1, - gx * (1 - gz))
    outz.addcmul_(out1, - gx * (1 - gy))
    # - corner 101
    idx = sub2ind_list([gx1, gy0, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz1])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, (1 - gy) * gz)
    outy.addcmul_(out1, - gx * gz)
    outz.addcmul_(out1, gx * (1 - gy))
    # - corner 110
    idx = sub2ind_list([gx1, gy1, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz0])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, gy * (1 - gz))
    outy.addcmul_(out1, gx * (1 - gz))
    outz.addcmul_(out1, - gx * gy)
    # - corner 111
    idx = sub2ind_list([gx1, gy1, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz1])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, gy * gz)
    outy.addcmul_(out1, gx * gz)
    outz.addcmul_(out1, gx * gy)

    if mask is not None:
        out *= mask.unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [3])
    return out


@torch.jit.script
def pushgrad3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
               extrapolate: int = 1):
    """
    Scatter a gradient-valued input back onto the voxel grid; the corner
    weights mirror those used in grad3d.

    inp: (B, C, iX, iY, iZ, 3) tensor
    g: (B, iX, iY, iZ, 3) tensor
    shape: List{3}[int], optional
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = g.unbind(-1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1, dim])
    batch = max(inp.shape[0], g.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx, ny, nz = shape

    # mask of inbounds voxels
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # scatter
    out = torch.zeros([batch, channel, nx*ny*nz],
                      dtype=inp.dtype, device=inp.device)
    # - corner 000
    idx = sub2ind_list([gx0, gy0, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy0, signz0])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    # weight each gradient component, then scatter their sum
    out1x, out1y, out1z = out1.unbind(-1)
    out1x *= - (1 - gy) * (1 - gz)
    out1y *= - (1 - gx) * (1 - gz)
    out1z *= - (1 - gx) * (1 - gy)
    out.scatter_add_(-1, idx, out1x + out1y + out1z)
    # - corner 001
    idx = sub2ind_list([gx0, gy0, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy0, signz1])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y, out1z = out1.unbind(-1)
    out1x *= - (1 - gy) * gz
    out1y *= - (1 - gx) * gz
    out1z *= (1 - gx) * (1 - gy)
    out.scatter_add_(-1, idx, out1x + out1y + out1z)
    # - corner 010
    idx = sub2ind_list([gx0, gy1, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy1, signz0])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y, out1z = out1.unbind(-1)
    out1x *= - gy * (1 - gz)
    out1y *= (1 - gx) * (1 - gz)
    out1z *= - (1 - gx) * gy
    out.scatter_add_(-1, idx, out1x + out1y + out1z)
    # - corner 011
    idx = sub2ind_list([gx0, gy1, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy1, signz1])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y, out1z = out1.unbind(-1)
    out1x *= - gy * gz
    out1y *= (1 - gx) * gz
    out1z *= (1 - gx) * gy
    out.scatter_add_(-1, idx, out1x + out1y + out1z)
    # - corner 100
    idx = sub2ind_list([gx1, gy0, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy0, signz0])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y, out1z = out1.unbind(-1)
    out1x *= (1 - gy) * (1 - gz)
    out1y *= - gx * (1 - gz)
    out1z *= - gx * (1 - gy)
    out.scatter_add_(-1, idx, out1x + out1y + out1z)
    # - corner 101
    idx = sub2ind_list([gx1, gy0, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy0, signz1])
    if sign is not None:
        out1 *=
sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y, out1z = out1.unbind(-1)
    out1x *= (1 - gy) * gz
    out1y *= - gx * gz
    out1z *= gx * (1 - gy)
    out.scatter_add_(-1, idx, out1x + out1y + out1z)
    # - corner 110
    idx = sub2ind_list([gx1, gy1, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy1, signz0])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y, out1z = out1.unbind(-1)
    out1x *= gy * (1 - gz)
    out1y *= gx * (1 - gz)
    out1z *= - gx * gy
    out.scatter_add_(-1, idx, out1x + out1y + out1z)
    # - corner 111
    idx = sub2ind_list([gx1, gy1, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy1, signz1])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y, out1z = out1.unbind(-1)
    out1x *= gy * gz
    out1y *= gx * gz
    out1z *= gx * gy
    out.scatter_add_(-1, idx, out1x + out1y + out1z)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out


@torch.jit.script
def hess3d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """
    Sample the 3x3 Hessian of the trilinear interpolant. The diagonal
    (pure second-order) entries are explicitly zeroed; only the mixed
    partial derivatives are accumulated and then symmetrized.

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, oX, oY, oZ, 3) tensor
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, oZ, 3, 3) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = torch.unbind(g, -1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny, nz = shape

    # mask of inbounds voxels
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel, g.shape[-2], dim, dim],
                      dtype=inp.dtype, device=inp.device)
    # views into the (dim, dim) Hessian matrix; written in place below
    outx, outy, outz = out.unbind(-1)
    outxx, outyx, outzx = outx.unbind(-1)
    outxy, outyy, outzy = outy.unbind(-1)
    outxz, outyz, outzz = outz.unbind(-1)
    # - corner 000
    idx = sub2ind_list([gx0, gy0, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    # NOTE(review): gather with out= into a non-contiguous unbind view —
    # relies on this being supported by the target torch version; confirm.
    torch.gather(inp, -1, idx, out=outxy)
    outxz.copy_(outxy)
    outyz.copy_(outxy)
    outxx.zero_()
    outyy.zero_()
    outzz.zero_()
    sign = make_sign([signx0, signy0, signz0])
    if sign is not None:
        out *= sign.unsqueeze(-1).unsqueeze(-1)
    outxy *= (1 - gz)
    outxz *= (1 - gy)
    outyz *= (1 - gx)
    # - corner 001
    idx = sub2ind_list([gx0, gy0, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy0, signz1])
    if sign is not None:
        out1 *= sign
    outxy.addcmul_(out1, gz)
    outxz.addcmul_(out1, - (1 - gy))
    outyz.addcmul_(out1, - (1 - gx))
    # - corner 010
    idx = sub2ind_list([gx0, gy1, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz0])
    if sign is not None:
        out1 *= sign
    outxy.addcmul_(out1, - (1 - gz))
    outxz.addcmul_(out1, gy)
    outyz.addcmul_(out1, - (1 - gx))
    # - corner 011
    idx = sub2ind_list([gx0, gy1, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz1])
    if sign is not None:
        out1 *= sign
    outxy.addcmul_(out1, - gz)
    outxz.addcmul_(out1, - gy)
    outyz.addcmul_(out1, (1 - gx))
    # - corner 100
    idx = sub2ind_list([gx1, gy0, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz0])
    if sign is not None:
        out1 *= sign
    outxy.addcmul_(out1, - (1 - gz))
    outxz.addcmul_(out1, - (1 - gy))
    outyz.addcmul_(out1, gx)
    # - corner 101
    idx = sub2ind_list([gx1, gy0, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz1])
    if sign is not None:
        out1 *= sign
    outxy.addcmul_(out1, - gz)
    outxz.addcmul_(out1, (1 - gy))
    outyz.addcmul_(out1, - gx)
    # - corner 110
    idx = sub2ind_list([gx1, gy1, gz0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz0])
    if sign is not None:
        out1 *= sign
    outxy.addcmul_(out1, (1 - gz))
    outxz.addcmul_(out1, - gy)
    outyz.addcmul_(out1, - gx)
    # - corner 111
    idx = sub2ind_list([gx1, gy1, gz1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz1])
    if sign is not None:
        out1 *= sign
    outxy.addcmul_(out1, gz)
    outxz.addcmul_(out1, gy)
    outyz.addcmul_(out1, gx)

    # symmetrize: copy the upper triangle into the lower triangle
    outyx.copy_(outxy)
    outzx.copy_(outxz)
    outzy.copy_(outyz)

    if mask is not None:
        out *= mask.unsqueeze(-1).unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim])
    return out


# ======================================================================
# 2D
# ======================================================================


@torch.jit.script
def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """
    Bilinear sampling: weighted sum of the 4 neighbouring corner values.

    inp: (B, C, iX, iY) tensor
    g: (B, oX, oY, 2) tensor
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY) tensor
    """
    dim = 2
    boundx, boundy = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = g.unbind(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny = shape

    # mask of inbounds voxels
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    # - corner 00
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out = inp.gather(-1, idx)
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out = out * sign
    out = out * ((1 - gx) * (1 - gy))
    # - corner 01
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 = out1 * sign
    out1 = out1 * ((1 - gx) * gy)
    out = out + out1
    # - corner 10
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 = out1 * sign
    out1 = out1 * (gx * (1 - gy))
    out = out + out1
    # - corner 11
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 = out1 * sign
    out1 = out1 * (gx * gy)
    out = out + out1

    if mask is not None:
        out *= mask
    out = out.reshape(list(out.shape[:2]) + oshape)
    return out


@torch.jit.script
def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """
    Splat: scatter-add each value into its 4 neighbouring pixels with
    bilinear weights (same weights as the pull2d gather).

    inp: (B, C, iX, iY) tensor
    g: (B, iX, iY, 2) tensor
    shape: List{2}[int], optional
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 2
    boundx, boundy = bound
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim:])
    g =
g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = torch.unbind(g, -1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx, ny = shape

    # mask of inbounds voxels
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # scatter
    out = torch.zeros([batch, channel, nx*ny],
                      dtype=inp.dtype, device=inp.device)
    # - corner 00
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out1 *= sign
    if mask is not None:
        out1 *= mask
    out1 *= (1 - gx) * (1 - gy)
    out.scatter_add_(-1, idx, out1)
    # - corner 01
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 *= sign
    if mask is not None:
        out1 *= mask
    out1 *= (1 - gx) * gy
    out.scatter_add_(-1, idx, out1)
    # - corner 10
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 *= sign
    if mask is not None:
        out1 *= mask
    out1 *= gx * (1 - gy)
    out.scatter_add_(-1, idx, out1)
    # - corner 11
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 *= sign
    if mask is not None:
        out1 *= mask
    out1 *= gx * gy
    out.scatter_add_(-1, idx, out1)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out


@torch.jit.script
def grad2d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """
    Sample the spatial gradient of the bilinear interpolant; the last
    output dimension indexes the derivative direction (x, y). Each corner
    contributes the derivative of its bilinear weight.

    inp: (B, C, iX, iY) tensor
    g: (B, oX, oY, 2) tensor
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, 2) tensor
    """
    dim = 2
    boundx, boundy = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = torch.unbind(g, -1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny = shape

    # mask of inbounds voxels
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel] + list(g.shape[-2:]),
                      dtype=inp.dtype, device=inp.device)
    # outx/outy are views into `out`; accumulation below writes in place
    outx, outy = out.unbind(-1)
    # - corner 00
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    # NOTE(review): gather with out= into a non-contiguous unbind view —
    # relies on this being supported by the target torch version; confirm.
    torch.gather(inp, -1, idx, out=outx)
    outy.copy_(outx)
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out *= sign.unsqueeze(-1)
    outx *= - (1 - gy)
    outy *= - (1 - gx)
    # - corner 01
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, - gy)
    outy.addcmul_(out1, (1 - gx))
    # - corner 10
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, (1 - gy))
    outy.addcmul_(out1, - gx)
    # - corner 11
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 *=
sign + outx.addcmul_(out1, gy) + outy.addcmul_(out1, gx) + + if mask is not None: + out *= mask.unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim]) + return out + + +@torch.jit.script +def pushgrad2d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, 2) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 2 + boundx, boundy = bound + if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = g.unbind(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1, dim]) + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # scatter + out = torch.zeros([batch, channel, nx*ny], + dtype=inp.dtype, device=inp.device) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= - (1 - gy) + out1y *= - (1 - gx) + out.scatter_add_(-1, idx, out1x + out1y) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: 
@torch.jit.script
def hess2d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Sample the spatial Hessian of a 2D image under linear interpolation.

    For bilinear interpolation the pure second derivatives (xx, yy) are
    identically zero; only the mixed derivative (xy == yx) is nonzero.

    inp: (B, C, iX, iY) tensor
    g: (B, oX, oY, 2) tensor - sampling coordinates, in voxels
    bound: List{2}[Bound]
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, 2, 2) tensor
    """
    dim = 2
    boundx, boundy = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = torch.unbind(g, -1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny = shape

    # mask of inbounds voxels
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel, g.shape[-2], dim, dim],
                      dtype=inp.dtype, device=inp.device)
    # unbind twice to obtain writable views into each Hessian component
    outx, outy = out.unbind(-1)
    outxx, outyx = outx.unbind(-1)
    outxy, outyy = outy.unbind(-1)
    # - corner 00: gather straight into the xy view; diagonals are zero
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    torch.gather(inp, -1, idx, out=outxy)
    outxx.zero_()
    outyy.zero_()
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out *= sign.unsqueeze(-1).unsqueeze(-1)
    # deliberate no-op: mixed-derivative weight at corner 00 is +1
    # (kept for symmetry with the other corners' +/-1 weights)
    outxy *= 1
    # - corner 01: mixed-derivative weight -1
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 *= sign
    outxy.add_(out1, alpha=-1)
    # - corner 10: mixed-derivative weight -1
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 *= sign
    outxy.add_(out1, alpha=-1)
    # - corner 11: mixed-derivative weight +1
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 *= sign
    outxy.add_(out1)

    # the Hessian is symmetric: yx mirrors xy
    outyx.copy_(outxy)

    if mask is not None:
        out *= mask.unsqueeze(-1).unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim])
    return out
@torch.jit.script
def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Sample a 1D signal at (fractional) coordinates by linear interpolation.

    inp: (B, C, iX) tensor
    g: (B, oX, 1) tensor - sampling coordinates, in voxels
    bound: List{1}[Bound]
    extrapolate: ExtrapolateType
    returns: (B, C, oX) tensor
    """
    dim = 1
    boundx = bound[0]
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx = shape[0]

    # mask of inbounds voxels (None when extrapolating everywhere)
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)

    # gather the two neighbouring samples and blend linearly
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    # - corner 0: weight (1 - gx)
    idx = gx0
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out = inp.gather(-1, idx)
    sign = signx0
    if sign is not None:
        out = out * sign
    out = out * (1 - gx)
    # - corner 1: weight gx
    idx = gx1
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = signx1
    if sign is not None:
        out1 = out1 * sign
    out1 = out1 * gx
    out = out + out1

    if mask is not None:
        out *= mask
    out = out.reshape(list(out.shape[:2]) + oshape)
    return out
@torch.jit.script
def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """Splat (adjoint of `pull1d`) a 1D signal onto a grid.

    NOTE: the original docstring listed 2D shapes (copy-paste); corrected.

    inp: (B, C, iX) tensor
    g: (B, iX, 1) tensor - target coordinates, in voxels
    shape: List{1}[int], optional - output spatial shape (defaults to input's)
    bound: List{1}[Bound]
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 1
    boundx = bound[0]
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim:])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx = shape[0]

    # mask of inbounds voxels
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)

    # scatter each value to its two neighbours, weighted linearly
    out = torch.zeros([batch, channel, nx],
                      dtype=inp.dtype, device=inp.device)
    # - corner 0: weight (1 - gx)
    idx = gx0
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = signx0
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * (1 - gx)
    out.scatter_add_(-1, idx, out1)
    # - corner 1: weight gx
    idx = gx1
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = signx1
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * gx
    out.scatter_add_(-1, idx, out1)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
@torch.jit.script
def grad1d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Sample the spatial gradient of a 1D signal (linear interpolation).

    The derivative of the linear weights is -1 at the lower corner and +1
    at the upper corner, so the gradient is simply f(x1) - f(x0).

    inp: (B, C, iX) tensor
    g: (B, oX, 1) tensor - sampling coordinates, in voxels
    bound: List{1}[Bound]
    extrapolate: ExtrapolateType
    returns: (B, C, oX, 1) tensor
    """
    dim = 1
    boundx = bound[0]
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx = shape[0]

    # mask of inbounds voxels
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel] + list(g.shape[-2:]),
                      dtype=inp.dtype, device=inp.device)
    # outx is a writable view into out: in-place ops below update out
    outx = out.squeeze(-1)
    # - corner 0: weight derivative -1 (gather then negate in place)
    idx = gx0
    idx = idx.expand([batch, channel, idx.shape[-1]])
    torch.gather(inp, -1, idx, out=outx)
    sign = signx0
    if sign is not None:
        out *= sign.unsqueeze(-1)
    outx.neg_()
    # - corner 1: weight derivative +1
    idx = gx1
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = signx1
    if sign is not None:
        out1 *= sign
    outx.add_(out1)

    if mask is not None:
        out *= mask.unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [dim])
    return out


@torch.jit.script
def pushgrad1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
               extrapolate: int = 1):
    """Adjoint of `grad1d`: splat per-voxel gradients onto a 1D grid.

    inp: (B, C, iX, 1) tensor
    g: (B, iX, 1) tensor - target coordinates, in voxels
    shape: List{1}[int], optional
    bound: List{1}[Bound]
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 1
    boundx = bound[0]
    if inp.shape[-2] != g.shape[-2]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1, dim])
    batch = max(inp.shape[0], g.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx = shape[0]

    # mask of inbounds voxels
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)

    # scatter
    out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device)
    # - corner 0 (comment said "000": 3D leftover): weight derivative -1
    idx = gx0
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = signx0
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x = out1.squeeze(-1)
    out1x.neg_()
    out.scatter_add_(-1, idx, out1x)
    # - corner 1 (comment said "100": 3D leftover): weight derivative +1
    idx = gx1
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = signx1
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x = out1.squeeze(-1)
    out.scatter_add_(-1, idx, out1x)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
@torch.jit.script
def hess1d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Hessian of a linearly-interpolated 1D signal: identically zero.

    Linear interpolation has no curvature, so the second derivative
    vanishes everywhere; only the output shape depends on the inputs.

    inp: (B, C, iX) tensor
    g: (B, oX, 1) tensor
    bound: List{1}[Bound] (unused, kept for a uniform kernel signature)
    extrapolate: ExtrapolateType (unused)
    returns: (B, C, oX, 1, 1) tensor of zeros
    """
    nbatch = max(inp.shape[0], g.shape[0])
    out_shape = [nbatch, inp.shape[1], g.shape[1], 1, 1]
    # new_zeros inherits dtype and device from `inp`
    return inp.new_zeros(out_shape)
"""A lot of utility functions for TorchScript"""
import torch
import os
from typing import List, Tuple, Optional
from .utils import torch_version
from torch import Tensor


@torch.jit.script
def pad_list_int(x: List[int], dim: int) -> List[int]:
    """Pad (by replicating the last element) or crop `x` to length `dim`."""
    if len(x) < dim:
        x = x + x[-1:] * (dim - len(x))
    if len(x) > dim:
        x = x[:dim]
    return x


@torch.jit.script
def pad_list_float(x: List[float], dim: int) -> List[float]:
    """Pad (by replicating the last element) or crop `x` to length `dim`."""
    if len(x) < dim:
        x = x + x[-1:] * (dim - len(x))
    if len(x) > dim:
        x = x[:dim]
    return x


@torch.jit.script
def pad_list_str(x: List[str], dim: int) -> List[str]:
    """Pad (by replicating the last element) or crop `x` to length `dim`."""
    if len(x) < dim:
        x = x + x[-1:] * (dim - len(x))
    if len(x) > dim:
        x = x[:dim]
    return x


@torch.jit.script
def list_any(x: List[bool]) -> bool:
    """TorchScript-compatible `any` over a list of bools."""
    for elem in x:
        if elem:
            return True
    return False


@torch.jit.script
def list_all(x: List[bool]) -> bool:
    """TorchScript-compatible `all` over a list of bools."""
    for elem in x:
        if not elem:
            return False
    return True


@torch.jit.script
def list_prod_int(x: List[int]) -> int:
    """Product of a list of ints (1 for an empty list)."""
    if len(x) == 0:
        return 1
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 * x1
    return x0


@torch.jit.script
def list_sum_int(x: List[int]) -> int:
    """Sum of a list of ints (0 for an empty list)."""
    if len(x) == 0:
        # BUG FIX: previously returned 1 (the multiplicative identity,
        # copy-pasted from list_prod_int); the empty sum is 0.
        return 0
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 + x1
    return x0


@torch.jit.script
def list_prod_tensor(x: List[Tensor]) -> Tensor:
    """Elementwise product of a list of tensors (scalar 1 if empty)."""
    if len(x) == 0:
        empty: List[int] = []
        return torch.ones(empty)
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 * x1
    return x0


@torch.jit.script
def list_sum_tensor(x: List[Tensor]) -> Tensor:
    """Elementwise sum of a list of tensors (scalar 0 if empty)."""
    if len(x) == 0:
        empty: List[int] = []
        # BUG FIX: previously returned ones (copy-pasted from
        # list_prod_tensor); the empty sum is zero.
        return torch.zeros(empty)
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 + x1
    return x0


@torch.jit.script
def list_reverse_int(x: List[int]) -> List[int]:
    """Return a reversed copy of a list of ints."""
    if len(x) == 0:
        return x
    return [x[i] for i in range(-1, -len(x)-1, -1)]
@torch.jit.script
def movedim1(x, source: int, destination: int):
    """TorchScript-compatible `torch.movedim` for a single dimension.

    Moves dimension `source` of `x` to position `destination` (negative
    indices allowed), implemented as a single permute.
    """
    dim = x.dim()
    # normalize negative indices
    source = dim + source if source < 0 else source
    destination = dim + destination if destination < 0 else destination
    permutation = [d for d in range(dim)]
    # remove `source`, then re-insert it at `destination`
    permutation = permutation[:source] + permutation[source+1:]
    permutation = permutation[:destination] + [source] + permutation[destination:]
    return x.permute(permutation)


@torch.jit.script
def sub2ind(subs, shape: List[int]):
    """Convert sub indices (i, j, k) into linear indices.

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    subs : (D, ...) tensor
        List of sub-indices. The first dimension is the number of dimension.
        Each element should have the same number of elements and shape.
    shape : (D,) list[int]
        Size of each dimension. Its length should be the same as the
        first dimension of ``subs``.

    Returns
    -------
    ind : (...) tensor
        Linear indices
    """
    subs = subs.unbind(0)
    # last coordinate has stride 1: start from it and add strided terms
    ind = subs[-1]
    subs = subs[:-1]
    ind = ind.clone()
    stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False)
    for i, s in zip(subs, stride):
        ind += i * s
    return ind
@torch.jit.script
def sub2ind_list(subs: List[Tensor], shape: List[int]):
    """Convert sub indices (i, j, k) into linear indices.

    Same as `sub2ind`, but takes the coordinates as a list of tensors
    instead of a single stacked tensor.

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    subs : (D,) list[tensor]
        List of sub-indices. The first dimension is the number of dimension.
        Each element should have the same number of elements and shape.
    shape : (D,) list[int]
        Size of each dimension. Its length should be the same as the
        first dimension of ``subs``.

    Returns
    -------
    ind : (...) tensor
        Linear indices
    """
    # last coordinate has stride 1: start from it and add strided terms
    ind = subs[-1]
    subs = subs[:-1]
    ind = ind.clone()
    stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False)
    for i, s in zip(subs, stride):
        ind += i * s
    return ind

# floor_divide returns wrong results for negative values, because it truncates
# instead of performing a proper floor. In recent version of pytorch, it is
# advised to use div(..., rounding_mode='trunc'|'floor') instead.
# Here, we only use floor_divide on positive values so we do not care.
# NOTE(review): `floor_div` is re-bound at the bottom of this module
# (to torch.div / torch.floor_divide), which overrides the scripted
# version defined here - confirm which binding is intended to win.
if torch_version('>=', [1, 8]):
    @torch.jit.script
    def floor_div(x, y) -> torch.Tensor:
        return torch.div(x, y, rounding_mode='floor')
    @torch.jit.script
    def floor_div_int(x, y: int) -> torch.Tensor:
        return torch.div(x, y, rounding_mode='floor')
else:
    @torch.jit.script
    def floor_div(x, y) -> torch.Tensor:
        return (x / y).floor_()
    @torch.jit.script
    def floor_div_int(x, y: int) -> torch.Tensor:
        return (x / y).floor_()
@torch.jit.script
def ind2sub(ind, shape: List[int]):
    """Convert linear indices into sub indices (i, j, k).

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    ind : tensor_like
        Linear indices
    shape : (D,) vector_like
        Size of each dimension.

    Returns
    -------
    subs : (D, ...) tensor
        Sub-indices.
    """
    stride = list_cumprod_int(shape, reverse=True, exclusive=True)
    sub = ind.new_empty([len(shape)] + ind.shape)
    sub.copy_(ind)
    for d in range(len(shape)):
        if d > 0:
            # strip the contribution of the previous (slower) dimensions
            sub[d] = torch.remainder(sub[d], stride[d-1])
        sub[d] = floor_div_int(sub[d], stride[d])
    return sub


@torch.jit.script
def inbounds_mask_3d(extrapolate: int, gx, gy, gz, nx: int, ny: int, nz: int) \
        -> Optional[Tensor]:
    """Mask of inbounds voxels; None when extrapolating everywhere.

    extrapolate: 1 = extrapolate (no mask), 0 = no extrapolation,
    2 = "hist" mode with a half-voxel margin.
    """
    mask: Optional[Tensor] = None
    if extrapolate in (0, 2):  # no / hist
        # small tolerance keeps voxels exactly on the boundary
        tiny = 5e-2
        threshold = tiny
        if extrapolate == 2:
            threshold = 0.5 + tiny
        mask = ((gx > -threshold) & (gx < nx - 1 + threshold) &
                (gy > -threshold) & (gy < ny - 1 + threshold) &
                (gz > -threshold) & (gz < nz - 1 + threshold))
        return mask
    return mask


@torch.jit.script
def inbounds_mask_2d(extrapolate: int, gx, gy, nx: int, ny: int) \
        -> Optional[Tensor]:
    """2D counterpart of `inbounds_mask_3d`."""
    mask: Optional[Tensor] = None
    if extrapolate in (0, 2):  # no / hist
        tiny = 5e-2
        threshold = tiny
        if extrapolate == 2:
            threshold = 0.5 + tiny
        mask = ((gx > -threshold) & (gx < nx - 1 + threshold) &
                (gy > -threshold) & (gy < ny - 1 + threshold))
        return mask
    return mask


@torch.jit.script
def inbounds_mask_1d(extrapolate: int, gx, nx: int) -> Optional[Tensor]:
    """1D counterpart of `inbounds_mask_3d`."""
    mask: Optional[Tensor] = None
    if extrapolate in (0, 2):  # no / hist
        tiny = 5e-2
        threshold = tiny
        if extrapolate == 2:
            threshold = 0.5 + tiny
        mask = (gx > -threshold) & (gx < nx - 1 + threshold)
        return mask
    return mask


@torch.jit.script
def make_sign(sign: List[Optional[Tensor]]) -> Optional[Tensor]:
    """Combine optional per-dimension sign tensors into a single product.

    Returns None when every entry is None (no sign flip needed).
    """
    is_none : List[bool] = [s is None for s in sign]
    if list_all(is_none):
        return None
    filt_sign: List[Tensor] = []
    for s in sign:
        if s is not None:
            filt_sign.append(s)
    return list_prod_tensor(filt_sign)


@torch.jit.script
def square(x):
    """x**2 (out of place)."""
    return x * x


@torch.jit.script
def square_(x):
    """x**2 (in place)."""
    return x.mul_(x)
@torch.jit.script
def cube(x):
    """x**3 (out of place)."""
    return x * x * x


@torch.jit.script
def cube_(x):
    """x**3 (in place)."""
    return square_(x).mul_(x)


@torch.jit.script
def pow4(x):
    """x**4 (out of place)."""
    return square(square(x))


@torch.jit.script
def pow4_(x):
    """x**4 (in place)."""
    return square_(square_(x))


@torch.jit.script
def pow5(x):
    """x**5 (out of place)."""
    return x * pow4(x)


@torch.jit.script
def pow5_(x):
    """x**5 (in place)."""
    return pow4_(x).mul_(x)


@torch.jit.script
def pow6(x):
    """x**6 (out of place)."""
    return square(cube(x))


@torch.jit.script
def pow6_(x):
    """x**6 (in place)."""
    return square_(cube_(x))


@torch.jit.script
def pow7(x):
    """x**7 (out of place)."""
    return pow6(x) * x


@torch.jit.script
def pow7_(x):
    """x**7 (in place)."""
    return pow6_(x).mul_(x)


@torch.jit.script
def dot(x, y, dim: int = -1, keepdim: bool = False):
    """(Batched) dot product along a dimension"""
    x = movedim1(x, dim, -1).unsqueeze(-2)
    y = movedim1(y, dim, -1).unsqueeze(-1)
    d = torch.matmul(x, y).squeeze(-1).squeeze(-1)
    if keepdim:
        # BUG FIX: unsqueeze is not in-place; the result was discarded,
        # so keepdim=True behaved exactly like keepdim=False.
        d = d.unsqueeze(dim)
    return d


@torch.jit.script
def dot_multi(x, y, dim: List[int], keepdim: bool = False):
    """(Batched) dot product along multiple dimensions"""
    for d in dim:
        x = movedim1(x, d, -1)
        y = movedim1(y, d, -1)
    # NOTE(review): y is reshaped using x.shape - assumes x and y have
    # identical shapes (no broadcasting); confirm against callers.
    x = x.reshape(x.shape[:-len(dim)] + [1, -1])
    y = y.reshape(x.shape[:-len(dim)] + [-1, 1])
    dt = torch.matmul(x, y).squeeze(-1).squeeze(-1)
    if keepdim:
        for d in dim:
            # BUG FIX: unsqueeze is not in-place; re-assign the result.
            dt = dt.unsqueeze(d)
    return dt
# cartesian_prod takes multiple input tensors as input in eager mode
# but takes a list of tensors in jit mode. This is a helper that works
# in both cases.
if not int(os.environ.get('PYTORCH_JIT', '1')):
    cartesian_prod = lambda x: torch.cartesian_prod(*x)
    if torch_version('>=', (1, 10)):
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x, indexing='ij')
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x, indexing='xy')
    else:
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x)
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # BUG FIX: in eager mode torch.meshgrid returns a *tuple*,
            # which does not support item assignment -> copy to a list
            # before transposing the first two grids.
            grid = list(torch.meshgrid(*x))
            if len(grid) > 1:
                grid[0] = grid[0].transpose(0, 1)
                grid[1] = grid[1].transpose(0, 1)
            return grid

else:
    cartesian_prod = torch.cartesian_prod
    if torch_version('>=', (1, 10)):
        @torch.jit.script
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x, indexing='ij')
        @torch.jit.script
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x, indexing='xy')
    else:
        @torch.jit.script
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x)
        @torch.jit.script
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # BUG FIX: was named `meshgrid_xyt` (typo), leaving
            # `meshgrid_xy` undefined in jit mode on torch < 1.10.
            grid = torch.meshgrid(x)
            if len(grid) > 1:
                grid[0] = grid[0].transpose(0, 1)
                grid[1] = grid[1].transpose(0, 1)
            return grid


meshgrid = meshgrid_ij


# In torch < 1.6, div applied to integer tensor performed a floor_divide
# In torch > 1.6, it performs a true divide.
# Floor division must be done using `floor_divide`, but it was buggy
# until torch 1.13 (it was doing a trunc divide instead of a floor divide).
# There was at some point a deprecation warning for floor_divide, but it
# seems to have been lifted afterwards. In torch >= 1.13, floor_divide
# performs a correct floor division.
# Since we only apply floor_divide to positive values, we are fine.
if torch_version('<', (1, 6)):
    floor_div = torch.div
else:
    floor_div = torch.floor_divide
# NOTE(review): this rebinding overrides the @torch.jit.script `floor_div`
# defined earlier in this module - confirm which binding is intended.
try:
    import jitfields
    available = True
except (ImportError, ModuleNotFoundError):
    jitfields = None
    available = False
from .utils import make_list
import torch


def first2last(input, ndim):
    """Move the channel dim (first after batch) to the last position.

    If the input has no channel dim (dim <= ndim), insert a trailing
    singleton instead; returns (moved_input, inserted_flag) so the
    transform can be undone by `last2first`.
    """
    insert = input.dim() <= ndim
    if insert:
        input = input.unsqueeze(-1)
    else:
        input = torch.movedim(input, -ndim-1, -1)
    return input, insert


def last2first(input, ndim, inserted, grad=False):
    """Undo `first2last`. When `grad` is truthy, the tensor carries an
    extra trailing gradient dim, so offsets shift by one (bool-as-int)."""
    if inserted:
        input = input.squeeze(-1 - grad)
    else:
        input = torch.movedim(input, -1 - grad, -ndim-1 - grad)
    return input


def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample `input` at `grid` coordinates via the jitfields backend.

    jitfields expects channels-last; convert, call, convert back.
    """
    ndim = grid.shape[-1]
    input, inserted = first2last(input, ndim)
    input = jitfields.pull(input, grid, order=interpolation, bound=bound,
                           extrapolate=extrapolate, prefilter=prefilter)
    input = last2first(input, ndim, inserted)
    return input


def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Splat (adjoint of `grid_pull`) via the jitfields backend."""
    ndim = grid.shape[-1]
    input, inserted = first2last(input, ndim)
    input = jitfields.push(input, grid, shape, order=interpolation, bound=bound,
                           extrapolate=extrapolate, prefilter=prefilter)
    input = last2first(input, ndim, inserted)
    return input


def grid_count(grid, shape=None, interpolation='linear', bound='zero',
               extrapolate=False):
    """Splat a field of ones (density of pushed voxels) via jitfields."""
    return jitfields.count(grid, shape, order=interpolation, bound=bound,
                           extrapolate=extrapolate)
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
                 inplace=False):
    """Compute spline interpolating coefficients along one dimension
    (jitfields backend); `inplace` selects the in-place variant."""
    func = jitfields.spline_coeff_ if inplace else jitfields.spline_coeff
    return func(input, interpolation, bound=bound, dim=dim)


def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
                    inplace=False):
    """N-dimensional version of `spline_coeff` (jitfields backend).
    Note: local `dim` is forwarded as jitfields' `ndim` argument."""
    func = jitfields.spline_coeff_nd_ if inplace else jitfields.spline_coeff_nd
    return func(input, interpolation, bound=bound, ndim=dim)


def resize(image, factor=None, shape=None, anchor='c',
           interpolation=1, prefilter=True, **kwargs):
    """Resize an image via the jitfields backend.

    The number of spatial dims is inferred from the longest of
    factor/shape/anchor, falling back to image.dim() - 2 (batch+channel).
    NOTE(review): `factor or []` treats a falsy factor (0, 0.0) like None
    - presumably factor is never 0; confirm.
    """
    kwargs.setdefault('bound', 'nearest')
    ndim = max(len(make_list(factor or [])),
               len(make_list(shape or [])),
               len(make_list(anchor or []))) or (image.dim() - 2)
    return jitfields.resize(image, factor=factor, shape=shape, ndim=ndim,
                            anchor=anchor, order=interpolation,
                            bound=kwargs['bound'], prefilter=prefilter)


def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Restrict (downsample, adjoint of `resize`) via jitfields;
    same ndim inference as `resize`."""
    kwargs.setdefault('bound', 'nearest')
    ndim = max(len(make_list(factor or [])),
               len(make_list(shape or [])),
               len(make_list(anchor or []))) or (image.dim() - 2)
    return jitfields.restrict(image, factor=factor, shape=shape, ndim=ndim,
                              anchor=anchor, order=interpolation,
                              bound=kwargs['bound'], reduce_sum=reduce_sum)
@torch.jit.script
def inbounds_mask(extrapolate: int, grid, shape: List[int])\
        -> Optional[Tensor]:
    """Boolean mask of voxels that fall inside the field of view.

    extrapolate: 1 = extrapolate everywhere (returns None),
                 0 = mask strictly out-of-bounds voxels,
                 2 = "hist" mode, with an extra half-voxel margin.
    grid: (B, N, D) tensor of voxel coordinates
    shape: List{D}[int] spatial size of the sampled field
    returns: (B, 1, N) bool tensor, or None when no masking is needed
    """
    result: Optional[Tensor] = None
    if extrapolate != 0 and extrapolate != 2:
        return result
    # small tolerance keeps voxels sitting exactly on the boundary
    eps = 5e-2
    limit = (0.5 + eps) if extrapolate == 2 else eps
    # insert a broadcastable channel axis before reducing over dimensions
    coords = grid.unsqueeze(1)
    result = torch.ones(coords.shape[:-1],
                        dtype=torch.bool, device=coords.device)
    for coord1, size1 in zip(coords.unbind(-1), shape):
        result = result & (coord1 > -limit)
        result = result & (coord1 < size1 - 1 + limit)
    return result
@torch.jit.script
def get_weights(grid, bound: List[Bound], spline: List[Spline],
                shape: List[int], grad: bool = False, hess: bool = False) \
        -> Tuple[List[List[Tensor]],
                 List[List[Optional[Tensor]]],
                 List[List[Optional[Tensor]]],
                 List[List[Tensor]],
                 List[List[Optional[Tensor]]]]:
    """Precompute, per dimension and per spline node:
    weights, optional weight gradients, optional weight hessians,
    (bound-wrapped) integer coordinates, and optional boundary signs.
    Outer list index = dimension, inner list index = node (order+1 nodes).
    """
    weights: List[List[Tensor]] = []
    grads: List[List[Optional[Tensor]]] = []
    hesss: List[List[Optional[Tensor]]] = []
    coords: List[List[Tensor]] = []
    signs: List[List[Optional[Tensor]]] = []
    for g, b, s, n in zip(grid.unbind(-1), bound, spline, shape):
        # first node of the spline support, and fractional offset to it
        grid0 = (g - (s.order-1)/2).floor()
        dist0 = g - grid0
        grid0 = grid0.long()
        nb_nodes = s.order + 1
        subweights: List[Tensor] = []
        subcoords: List[Tensor] = []
        subgrads: List[Optional[Tensor]] = []
        subhesss: List[Optional[Tensor]] = []
        subsigns: List[Optional[Tensor]] = []
        for node in range(nb_nodes):
            grid1 = grid0 + node
            # boundary sign (e.g. flips for antisymmetric bounds)
            sign1: Optional[Tensor] = b.transform(grid1, n)
            subsigns.append(sign1)
            # wrap the coordinate back inside [0, n)
            grid1 = b.index(grid1, n)
            subcoords.append(grid1)
            dist1 = dist0 - node
            weight1 = s.fastweight(dist1)
            subweights.append(weight1)
            grad1: Optional[Tensor] = None
            if grad:
                grad1 = s.fastgrad(dist1)
            subgrads.append(grad1)
            hess1: Optional[Tensor] = None
            if hess:
                hess1 = s.fasthess(dist1)
            subhesss.append(hess1)
        weights.append(subweights)
        coords.append(subcoords)
        signs.append(subsigns)
        grads.append(subgrads)
        hesss.append(subhesss)

    return weights, grads, hesss, coords, signs


@torch.jit.script
def pull(inp, grid, bound: List[Bound], spline: List[Spline],
         extrapolate: int = 1):
    """Sample an N-D field at arbitrary coordinates (any spline order).

    inp: (B, C, *ishape) tensor
    grid: (B, *oshape, D) tensor - sampling coordinates, in voxels
    bound: List{D}[Bound]
    spline: List{D}[Spline]
    extrapolate: int (1 = yes, 0 = no, 2 = hist)
    returns: (B, C, *oshape) tensor
    """

    dim = grid.shape[-1]
    shape = list(inp.shape[-dim:])
    oshape = list(grid.shape[-dim-1:-1])
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten spatial dimensions
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights along each dimension
    weights, _, _, coords, signs = get_weights(grid, bound, spline, shape, False, False)

    # initialize
    out = torch.zeros([batch, channel, grid.shape[1]],
                      dtype=inp.dtype, device=inp.device)

    # iterate across nodes/corners of the D-dimensional spline support
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # cartesian_prod does not work as expected when only one
        # element is provided
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:
        # gather the sample at this corner
        idx = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(idx, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out1 = inp.gather(-1, idx)

        # apply sign
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out1 = out1 * sign1.unsqueeze(1)

        # apply the separable spline weights
        for weight, n in zip(weights, nodes):
            out1 = out1 * weight[n].unsqueeze(1)

        # accumulate
        out = out + out1

    # out-of-bounds mask
    if mask is not None:
        out = out * mask

    out = out.reshape(list(out.shape[:2]) + oshape)
    return out
@torch.jit.script
def push(inp, grid, shape: Optional[List[int]], bound: List[Bound],
         spline: List[Spline], extrapolate: int = 1):
    """Splat (adjoint of `pull`) an N-D field onto a grid (any spline order).

    inp: (B, C, *ishape) tensor
    grid: (B, *ishape, D) tensor - target coordinates, in voxels
    shape: List{D}[int], optional - output spatial shape (defaults to grid's)
    bound: List{D}[Bound]
    spline: List{D}[Spline]
    extrapolate: int
    returns: (B, C, *shape) tensor
    """

    dim = grid.shape[-1]
    ishape = list(grid.shape[-dim - 1:-1])
    if shape is None:
        shape = ishape
    shape = list(shape)
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten spatial dimensions
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights along each dimension
    weights, _, _, coords, signs = get_weights(grid, bound, spline, shape)

    # initialize flat output buffer
    out = torch.zeros([batch, channel, list_prod_int(shape)],
                      dtype=inp.dtype, device=inp.device)

    # iterate across nodes/corners of the D-dimensional spline support
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # cartesian_prod does not work as expected when only one
        # element is provided
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # linear index of this corner for every voxel
        idx = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(idx, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out1 = inp.clone()

        # apply sign
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out1 = out1 * sign1.unsqueeze(1)

        # out-of-bounds mask
        if mask is not None:
            out1 = out1 * mask

        # apply the separable spline weights
        for weight, n in zip(weights, nodes):
            out1 = out1 * weight[n].unsqueeze(1)

        # accumulate into the output with scatter-add
        out.scatter_add_(-1, idx, out1)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
@torch.jit.script
def grad(inp, grid, bound: List[Bound], spline: List[Spline],
         extrapolate: int = 1):
    """Sample the spatial gradient of an N-D field (any spline order).

    inp: (B, C, *ishape) tensor
    grid: (B, *oshape, D) tensor - sampling coordinates, in voxels
    bound: List{D}[Bound]
    spline: List{D}[Spline]
    extrapolate: int
    returns: (B, C, *oshape, D) tensor
    """

    dim = grid.shape[-1]
    shape = list(inp.shape[-dim:])
    oshape = list(grid.shape[-dim-1:-1])
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten spatial dimensions
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights and their derivatives along each dimension
    weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape,
                                                   grad=True)

    # initialize
    out = torch.zeros([batch, channel, grid.shape[1], dim],
                      dtype=inp.dtype, device=inp.device)

    # iterate across nodes/corners of the D-dimensional spline support
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # cartesian_prod does not work as expected when only one
        # element is provided
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # gather
        idx = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(idx, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out0 = inp.gather(-1, idx)

        # apply sign
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out0 = out0 * sign1.unsqueeze(1)

        for d in range(dim):
            out1 = out0.clone()
            # for component d: use the weight *derivative* along d,
            # the plain weight along every other dimension
            for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
                if d == dd:
                    grad11 = grad1[n]
                    if grad11 is not None:
                        out1 = out1 * grad11.unsqueeze(1)
                else:
                    out1 = out1 * weight[n].unsqueeze(1)

            # unbind returns views: in-place add_ updates `out`
            out.unbind(-1)[d].add_(out1)

    # out-of-bounds mask
    if mask is not None:
        out = out * mask.unsqueeze(-1)

    out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:]))
    return out
out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:])) + return out + + +@torch.jit.script +def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound], + spline: List[Spline], extrapolate: int = 1): + """ + inp: (B, C, *ishape, D) tensor + g: (B, *ishape, D) tensor + shape: List{D}[int], optional + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *shape) tensor + """ + dim = grid.shape[-1] + oshape = list(grid.shape[-dim-1:-1]) + if shape is None: + shape = oshape + shape = list(shape) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1, dim]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape, grad=True) + + # initialize + out = torch.zeros([batch, channel, list_prod_int(shape)], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + if dim == 1: + # cartesian_prod does not work as expected when only one + # element is provided + all_nodes = range_nodes[0].unsqueeze(-1) + else: + all_nodes = cartesian_prod(range_nodes) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape).unsqueeze(1) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out0 = inp.clone() + + # apply sign + sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] + sign1: Optional[Tensor] = make_sign(sign0) + if sign1 is not None: + out0 = out0 * sign1.unsqueeze(1).unsqueeze(-1) + + # out-of-bounds mask + if mask is not None: + out0 = out0 * mask.unsqueeze(-1) + + for d in range(dim): + out1 = out0.unbind(-1)[d].clone() + # apply weights + for dd, (weight, grad1, n) in 
enumerate(zip(weights, grads, nodes)): + if d == dd: + grad11 = grad1[n] + if grad11 is not None: + out1 = out1 * grad11.unsqueeze(1) + else: + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def hess(inp, grid, bound: List[Bound], spline: List[Spline], + extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + grid: (B, *oshape, D) tensor + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape, D, D) tensor + """ + + dim = grid.shape[-1] + shape = list(inp.shape[-dim:]) + oshape = list(grid.shape[-dim-1:-1]) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, grads, hesss, coords, signs \ + = get_weights(grid, bound, spline, shape, grad=True, hess=True) + + # initialize + out = torch.zeros([batch, channel, grid.shape[1], dim, dim], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + if dim == 1: + # cartesian_prod does not work as expected when only one + # element is provided + all_nodes = range_nodes[0].unsqueeze(-1) + else: + all_nodes = cartesian_prod(range_nodes) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape).unsqueeze(1) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out0 = inp.gather(-1, idx) + + # apply sign + sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)] + sign1: Optional[Tensor] = make_sign(sign0) + if sign1 is not None: + out0 = out0 * sign1.unsqueeze(1) + + for d in range(dim): + # -- diagonal -- + out1 = 
out0.clone() + + # apply weights + for dd, (weight, hess1, n) \ + in enumerate(zip(weights, hesss, nodes)): + if d == dd: + hess11 = hess1[n] + if hess11 is not None: + out1 = out1 * hess11.unsqueeze(1) + else: + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out.unbind(-1)[d].unbind(-1)[d].add_(out1) + + # -- off diagonal -- + for d2 in range(d+1, dim): + out1 = out0.clone() + + # apply weights + for dd, (weight, grad1, n) \ + in enumerate(zip(weights, grads, nodes)): + if dd in (d, d2): + grad11 = grad1[n] + if grad11 is not None: + out1 = out1 * grad11.unsqueeze(1) + else: + out1 = out1 * weight[n].unsqueeze(1) + + # accumulate + out.unbind(-1)[d].unbind(-1)[d2].add_(out1) + + # out-of-bounds mask + if mask is not None: + out = out * mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + + # fill lower triangle + for d in range(dim): + for d2 in range(d+1, dim): + out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2]) + + out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:])) + return out diff --git a/utils/interpol/pushpull.py b/utils/interpol/pushpull.py new file mode 100644 index 0000000000000000000000000000000000000000..d37b2d3e4815f6544c0b5f2c37b40a9087196bbc --- /dev/null +++ b/utils/interpol/pushpull.py @@ -0,0 +1,325 @@ +""" +Non-differentiable forward/backward components. +These components are put together in `interpol.autograd` to generate +differentiable functions. + +Note +---- +.. I removed @torch.jit.script from these entry-points because compiling + all possible combinations of bound+interpolation made the first call + extremely slow. +.. I am not using the dot/multi_dot helpers even though they should be + more efficient that "multiply and sum" because I haven't had the time + to test them. It would be worth doing it. +""" +import torch +from typing import List, Optional, Tuple +from .jit_utils import list_all, dot, dot_multi, pad_list_int +from .bounds import Bound +from .splines import Spline +from . 
# Short alias used by the Optional[Tensor] annotations below.
Tensor = torch.Tensor


@torch.jit.script
def make_bound(bound: List[int]) -> List[Bound]:
    """Wrap per-dimension boundary codes into TorchScript `Bound` objects."""
    return [Bound(b) for b in bound]


@torch.jit.script
def make_spline(spline: List[int]) -> List[Spline]:
    """Wrap per-dimension spline orders into TorchScript `Spline` objects."""
    return [Spline(s) for s in spline]


# @torch.jit.script
def grid_pull(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Sample `inp` at the voxel coordinates `grid`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    # fast path: dedicated kernels when all dims are linear (iso1) or
    # nearest (iso0); otherwise fall back to the generic nd implementation
    is_iso1 = list_all([order == 1 for order in interpolation])
    if is_iso1:
        if dim == 3:
            return iso1.pull3d(inp, grid, bound_fn, extrapolate)
        elif dim == 2:
            return iso1.pull2d(inp, grid, bound_fn, extrapolate)
        elif dim == 1:
            return iso1.pull1d(inp, grid, bound_fn, extrapolate)
    is_iso0 = list_all([order == 0 for order in interpolation])
    if is_iso0:
        if dim == 3:
            return iso0.pull3d(inp, grid, bound_fn, extrapolate)
        elif dim == 2:
            return iso0.pull2d(inp, grid, bound_fn, extrapolate)
        elif dim == 1:
            return iso0.pull1d(inp, grid, bound_fn, extrapolate)
    spline_fn = make_spline(interpolation)
    return nd.pull(inp, grid, bound_fn, spline_fn, extrapolate)


# @torch.jit.script
def grid_push(inp, grid, shape: Optional[List[int]], bound: List[int],
              interpolation: List[int], extrapolate: int):
    """Splat `inp` at the voxel coordinates `grid` (adjoint of grid_pull).

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int] tensor, optional, default=spatial_in
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    # same dispatch strategy as grid_pull
    is_iso1 = list_all([order == 1 for order in interpolation])
    if is_iso1:
        if dim == 3:
            return iso1.push3d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 2:
            return iso1.push2d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 1:
            return iso1.push1d(inp, grid, shape, bound_fn, extrapolate)
    is_iso0 = list_all([order == 0 for order in interpolation])
    if is_iso0:
        if dim == 3:
            return iso0.push3d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 2:
            return iso0.push2d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 1:
            return iso0.push1d(inp, grid, shape, bound_fn, extrapolate)
    spline_fn = make_spline(interpolation)
    return nd.push(inp, grid, shape, bound_fn, spline_fn, extrapolate)


# @torch.jit.script
def grid_count(grid, shape: Optional[List[int]], bound: List[int],
               interpolation: List[int], extrapolate: int):
    """Splat a volume of ones: per-voxel count of pushed contributions.

    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int] tensor, optional, default=spatial_in
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, 1, *shape) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    gshape = list(grid.shape[-dim-1:-1])
    if shape is None:
        shape = gshape
    # expand a scalar one instead of materialising a full volume
    inp = torch.ones([], dtype=grid.dtype, device=grid.device)
    inp = inp.expand([len(grid), 1] + gshape)
    is_iso1 = list_all([order == 1 for order in interpolation])
    if is_iso1:
        if dim == 3:
            return iso1.push3d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 2:
            return iso1.push2d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 1:
            return iso1.push1d(inp, grid, shape, bound_fn, extrapolate)
    is_iso0 = list_all([order == 0 for order in interpolation])
    if is_iso0:
        if dim == 3:
            return iso0.push3d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 2:
            return iso0.push2d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 1:
            return iso0.push1d(inp, grid, shape, bound_fn, extrapolate)
    spline_fn = make_spline(interpolation)
    return nd.push(inp, grid, shape, bound_fn, spline_fn, extrapolate)


# @torch.jit.script
def grid_grad(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Sample the spatial gradient of `inp` at the coordinates `grid`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out, D) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    is_iso1 = list_all([order == 1 for order in interpolation])
    if is_iso1:
        if dim == 3:
            return iso1.grad3d(inp, grid, bound_fn, extrapolate)
        elif dim == 2:
            return iso1.grad2d(inp, grid, bound_fn, extrapolate)
        elif dim == 1:
            return iso1.grad1d(inp, grid, bound_fn, extrapolate)
    is_iso0 = list_all([order == 0 for order in interpolation])
    if is_iso0:
        # nearest-neighbour gradient is dimension-agnostic (and zero a.e.)
        return iso0.grad(inp, grid, bound_fn, extrapolate)
    spline_fn = make_spline(interpolation)
    return nd.grad(inp, grid, bound_fn, spline_fn, extrapolate)


# @torch.jit.script
def grid_pushgrad(inp, grid, shape: List[int], bound: List[int],
                  interpolation: List[int], extrapolate: int):
    """ /!\ Used only in backward pass of grid_grad
    inp: (B, C, *spatial_in, D) tensor
    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int], optional
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    is_iso1 = list_all([order == 1 for order in interpolation])
    if is_iso1:
        if dim == 3:
            return iso1.pushgrad3d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 2:
            return iso1.pushgrad2d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 1:
            return iso1.pushgrad1d(inp, grid, shape, bound_fn, extrapolate)
    is_iso0 = list_all([order == 0 for order in interpolation])
    if is_iso0:
        return iso0.pushgrad(inp, grid, shape, bound_fn, extrapolate)
    spline_fn = make_spline(interpolation)
    return nd.pushgrad(inp, grid, shape, bound_fn, spline_fn, extrapolate)


# @torch.jit.script
def grid_hess(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """ /!\ Used only in backward pass of grid_grad
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out, D, D) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    is_iso1 = list_all([order == 1 for order in interpolation])
    if is_iso1:
        if dim == 3:
            return iso1.hess3d(inp, grid, bound_fn, extrapolate)
        if dim == 2:
            return iso1.hess2d(inp, grid, bound_fn, extrapolate)
        if dim == 1:
            return iso1.hess1d(inp, grid, bound_fn, extrapolate)
    is_iso0 = list_all([order == 0 for order in interpolation])
    if is_iso0:
        return iso0.hess(inp, grid, bound_fn, extrapolate)
    spline_fn = make_spline(interpolation)
    return nd.hess(inp, grid, bound_fn, spline_fn, extrapolate)


# @torch.jit.script
def grid_pull_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor], ]:
    """Backward of grid_pull: push the incoming grad, and dot the sampled
    input gradient with it for the grid gradient.

    grad: (B, C, *spatial_out) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_out, D)
    """
    dim = grid.shape[-1]
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None
    if inp.requires_grad:
        grad_inp = grid_push(grad, grid, inp.shape[-dim:], bound, interpolation, extrapolate)
    if grid.requires_grad:
        grad_grid = grid_grad(inp, grid, bound, interpolation, extrapolate)
        # grad_grid = dot(grad_grid, grad.unsqueeze(-1), dim=1)
        # reduce over channels: dL/dgrid = sum_c dL/dout_c * dout_c/dgrid
        grad_grid = (grad_grid * grad.unsqueeze(-1)).sum(dim=1)
    return grad_inp, grad_grid


# @torch.jit.script
def grid_push_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor], ]:
    """Backward of grid_push: pull the incoming grad, and dot its sampled
    gradient with the original input for the grid gradient.

    grad: (B, C, *spatial_out) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_in, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D)
    """
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None
    if inp.requires_grad:
        grad_inp = grid_pull(grad, grid, bound, interpolation, extrapolate)
    if grid.requires_grad:
        grad_grid = grid_grad(grad, grid, bound, interpolation, extrapolate)
        # grad_grid = dot(grad_grid, inp.unsqueeze(-1), dim=1)
        grad_grid = (grad_grid * inp.unsqueeze(-1)).sum(dim=1)
    return grad_inp, grad_grid


# @torch.jit.script
def grid_count_backward(grad, grid, bound: List[int],
                        interpolation: List[int], extrapolate: int) \
        -> Optional[Tensor]:
    """Backward of grid_count: only the grid can require gradients.

    grad: (B, C, *spatial_out) tensor
    grid: (B, *spatial_in, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D)
    """
    if grid.requires_grad:
        return grid_grad(grad, grid, bound, interpolation, extrapolate).sum(1)
    return None


# @torch.jit.script
def grid_grad_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Backward of grid_grad: pushgrad for the input, Hessian-dot for the grid.

    grad: (B, C, *spatial_out, D) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in, D) tensor, (B, *spatial_out, D)
    """
    dim = grid.shape[-1]
    shape = inp.shape[-dim:]
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None
    if inp.requires_grad:
        grad_inp = grid_pushgrad(grad, grid, shape, bound, interpolation, extrapolate)
    if grid.requires_grad:
        grad_grid = grid_hess(inp, grid, bound, interpolation, extrapolate)
        # grad_grid = dot_multi(grad_grid, grad.unsqueeze(-1), dim=[1, -2])
        # contract the Hessian with the incoming gradient over channel and
        # the second spatial-derivative axis
        grad_grid = (grad_grid * grad.unsqueeze(-1)).sum(dim=[1, -2])
    return grad_inp, grad_grid
def resize(image, factor=None, shape=None, anchor='c',
           interpolation=1, prefilter=True, **kwargs):
    """Resize an image by a factor or to a specific shape.

    Notes
    -----
    .. A least one of `factor` and `shape` must be specified
    .. If `anchor in ('centers', 'edges')`, exactly one of `factor` or
       `shape must be specified.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not assured that
       `resize(resize(x, f), 1/f)` returns a tensor with the same shape as x.

        edges          centers        first          last
    e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
    | . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
    e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +

    Parameters
    ----------
    image : (batch, channel, *inshape) tensor
        Image to resize
    factor : float or list[float], optional
        Resizing factor
        * > 1 : larger image <-> smaller voxels
        * < 1 : smaller image <-> larger voxels
    shape : (ndim,) list[int], optional
        Output shape
    anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
        * In cases 'c' and 'e', the volume shape is multiplied by the
          zoom factor (and eventually truncated), and two anchor points
          are used to determine the voxel size.
        * In cases 'f' and 'l', a single anchor point is used so that
          the voxel size is exactly divided by the zoom factor.
          This case with an integer factor corresponds to subslicing
          the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    prefilter : bool, default=True
        Apply spline pre-filter (= interpolates the input)

    Returns
    -------
    resized : (batch, channel, *shape) tensor
        Resized image

    """
    # delegate to the compiled jitfields backend when it is enabled
    if backend.jitfields and jitfields.available:
        return jitfields.resize(image, factor, shape, anchor,
                                interpolation, prefilter, **kwargs)

    factor = make_list(factor) if factor else []
    shape = make_list(shape) if shape else []
    anchor = make_list(anchor)
    # number of spatial dims: longest of the three specs, else trailing dims
    nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
    anchor = [a[0].lower() for a in make_list(anchor, nb_dim)]
    bck = dict(dtype=image.dtype, device=image.device)

    # compute output shape
    inshape = image.shape[-nb_dim:]
    if factor:
        factor = make_list(factor, nb_dim)
    elif not shape:
        raise ValueError('One of `factor` or `shape` must be provided')
    if shape:
        shape = make_list(shape, nb_dim)
    else:
        # truncate: int(i*f)
        shape = [int(i*f) for i, f in zip(inshape, factor)]

    if not factor:
        factor = [o/i for o, i in zip(shape, inshape)]

    # compute transformation grid: one 1d coordinate ramp per dimension,
    # expressed in input-voxel coordinates
    lin = []
    for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape):
        if anch == 'c':    # centers
            lin.append(torch.linspace(0, inshp - 1, outshp, **bck))
        elif anch == 'e':  # edges
            scale = inshp / outshp
            shift = 0.5 * (scale - 1)
            lin.append(torch.arange(0., outshp, **bck) * scale + shift)
        elif anch == 'f':  # first voxel
            # scale = 1/f
            # shift = 0
            lin.append(torch.arange(0., outshp, **bck) / f)
        elif anch == 'l':  # last voxel
            # scale = 1/f
            shift = (inshp - 1) - (outshp - 1) / f
            lin.append(torch.arange(0., outshp, **bck) / f + shift)
        else:
            raise ValueError('Unknown anchor {}'.format(anch))

    # interpolate
    kwargs.setdefault('bound', 'nearest')
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('interpolation', interpolation)
    kwargs.setdefault('prefilter', prefilter)
    grid = torch.stack(meshgrid_ij(*lin), dim=-1)
    resized = grid_pull(image, grid, **kwargs)

    return resized


def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Restrict an image by a factor or to a specific shape.

    Notes
    -----
    .. A least one of `factor` and `shape` must be specified
    .. If `anchor in ('centers', 'edges')`, exactly one of `factor` or
       `shape must be specified.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not assured that
       `resize(resize(x, f), 1/f)` returns a tensor with the same shape as x.

        edges          centers        first          last
    e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
    | . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
    e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +

    Parameters
    ----------
    image : (batch, channel, *inshape) tensor
        Image to resize
    factor : float or list[float], optional
        Resizing factor
        * > 1 : larger image <-> smaller voxels
        * < 1 : smaller image <-> larger voxels
    shape : (ndim,) list[int], optional
        Output shape
    anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
        * In cases 'c' and 'e', the volume shape is multiplied by the
          zoom factor (and eventually truncated), and two anchor points
          are used to determine the voxel size.
        * In cases 'f' and 'l', a single anchor point is used so that
          the voxel size is exactly divided by the zoom factor.
          This case with an integer factor corresponds to subslicing
          the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    reduce_sum : bool, default=False
        Do not normalize by the number of accumulated values per voxel

    Returns
    -------
    restricted : (batch, channel, *shape) tensor
        Restricted image

    """
    # delegate to the compiled jitfields backend when it is enabled
    if backend.jitfields and jitfields.available:
        return jitfields.restrict(image, factor, shape, anchor,
                                  interpolation, reduce_sum, **kwargs)

    factor = make_list(factor) if factor else []
    shape = make_list(shape) if shape else []
    anchor = make_list(anchor)
    nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
    anchor = [a[0].lower() for a in make_list(anchor, nb_dim)]
    bck = dict(dtype=image.dtype, device=image.device)

    # compute output shape (note: restriction divides by the factor)
    inshape = image.shape[-nb_dim:]
    if factor:
        factor = make_list(factor, nb_dim)
    elif not shape:
        raise ValueError('One of `factor` or `shape` must be provided')
    if shape:
        shape = make_list(shape, nb_dim)
    else:
        shape = [int(i/f) for i, f in zip(inshape, factor)]

    if not factor:
        factor = [i/o for o, i in zip(shape, inshape)]

    # compute transformation grid; `fullscale` accumulates the total number
    # of input voxels mapped per output voxel, used for normalization below
    lin = []
    fullscale = 1
    for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape):
        if anch == 'c':    # centers
            lin.append(torch.linspace(0, outshp - 1, inshp, **bck))
            fullscale *= (inshp - 1) / (outshp - 1)
        elif anch == 'e':  # edges
            scale = outshp / inshp
            shift = 0.5 * (scale - 1)
            fullscale *= scale
            lin.append(torch.arange(0., inshp, **bck) * scale + shift)
        elif anch == 'f':  # first voxel
            # scale = 1/f
            # shift = 0
            fullscale *= 1/f
            lin.append(torch.arange(0., inshp, **bck) / f)
        elif anch == 'l':  # last voxel
            # scale = 1/f
            shift = (outshp - 1) - (inshp - 1) / f
            fullscale *= 1/f
            lin.append(torch.arange(0., inshp, **bck) / f + shift)
        else:
            raise ValueError('Unknown anchor {}'.format(anch))

    # scatter (grid_push is the adjoint of grid_pull)
    kwargs.setdefault('bound', 'nearest')
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('interpolation', interpolation)
    kwargs.setdefault('prefilter', False)
    grid = torch.stack(meshgrid_ij(*lin), dim=-1)
    resized = grid_push(image, grid, shape, **kwargs)
    if not reduce_sum:
        # NOTE(review): fullscale is 1/prod(factor); this divides in place —
        # presumably safe because grid_push returns a fresh tensor
        resized /= fullscale

    return resized
class InterpolationType(Enum):
    # numeric values double as the spline order
    nearest = zeroth = 0
    linear = first = 1
    quadratic = second = 2
    cubic = third = 3
    fourth = 4
    fifth = 5
    sixth = 6
    seventh = 7


@torch.jit.script
class Spline:
    """TorchScript B-spline basis of order 0..7.

    `weight`/`grad`/`hess` evaluate the basis, its first and second
    derivative at (signed) distance `x` from a node, returning zero
    outside the support [-(order+1)/2, (order+1)/2]. The `fast*`
    variants skip that support check (valid when the caller guarantees
    `|x| < (order+1)/2`). Piecewise polynomials are evaluated in Horner
    form via the `square`/`cube`/`pow4`..`pow7` helpers.
    """

    def __init__(self, order: int = 1):
        # spline order (0 = nearest, 1 = linear, ...)
        self.order = order

    def weight(self, x):
        # basis value, zeroed outside the support
        w = self.fastweight(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        w = torch.where(x.abs() >= (self.order + 1)/2, zero, w)
        return w

    def fastweight(self, x):
        # basis value without the out-of-support check
        if self.order == 0:
            return torch.ones(x.shape, dtype=x.dtype, device=x.device)
        x = x.abs()  # all bases are even functions
        if self.order == 1:
            return 1 - x
        if self.order == 2:
            x_low = 0.75 - square(x)
            x_up = 0.5 * square(1.5 - x)
            return torch.where(x < 0.5, x_low, x_up)
        if self.order == 3:
            x_low = (x * x * (x - 2.) * 3. + 4.) / 6.
            x_up = cube(2. - x) / 6.
            return torch.where(x < 1., x_low, x_up)
        if self.order == 4:
            x_low = square(x)
            x_low = x_low * (x_low * 0.25 - 0.625) + 115. / 192.
            x_mid = x * (x * (x * (5. - x) / 6. - 1.25) + 5./24.) + 55./96.
            x_up = pow4(x - 2.5) / 24.
            return torch.where(x < 0.5, x_low, torch.where(x < 1.5, x_mid, x_up))
        if self.order == 5:
            x_low = square(x)
            x_low = x_low * (x_low * (0.25 - x / 12.) - 0.5) + 0.55
            x_mid = x * (x * (x * (x * (x / 24. - 0.375) + 1.25) - 1.75) + 0.625) + 0.425
            x_up = pow5(3 - x) / 120.
            return torch.where(x < 1., x_low, torch.where(x < 2., x_mid, x_up))
        if self.order == 6:
            x_low = square(x)
            x_low = x_low * (x_low * (7./48. - x_low/36.) - 77./192.) + 5887./11520.
            x_mid_low = (x * (x * (x * (x * (x * (x / 48. - 7./48.) + 0.328125)
                         - 35./288.) - 91./256.) - 7./768.) + 7861./15360.)
            x_mid_up = (x * (x * (x * (x * (x * (7./60. - x / 120.) - 0.65625)
                        + 133./72.) - 2.5703125) + 1267./960.) + 1379./7680.)
            x_up = pow6(x - 3.5) / 720.
            return torch.where(x < .5, x_low,
                               torch.where(x < 1.5, x_mid_low,
                                           torch.where(x < 2.5, x_mid_up, x_up)))
        if self.order == 7:
            x_low = square(x)
            x_low = (x_low * (x_low * (x_low * (x / 144. - 1./36.)
                     + 1./9.) - 1./3.) + 151./315.)
            x_mid_low = (x * (x * (x * (x * (x * (x * (0.05 - x/240.) - 7./30.)
                         + 0.5) - 7./18.) - 0.1) - 7./90.) + 103./210.)
            x_mid_up = (x * (x * (x * (x * (x * (x * (x / 720. - 1./36.)
                        + 7./30.) - 19./18.) + 49./18.) - 23./6.) + 217./90.)
                        - 139./630.)
            x_up = pow7(4 - x) / 5040.
            return torch.where(x < 1., x_low,
                               torch.where(x < 2., x_mid_low,
                                           torch.where(x < 3., x_mid_up, x_up)))
        raise NotImplementedError

    def grad(self, x):
        # first derivative, zeroed outside the support
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        g = self.fastgrad(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        g = torch.where(x.abs() >= (self.order + 1)/2, zero, g)
        return g

    def fastgrad(self, x):
        # first derivative without the out-of-support check;
        # derivative is odd, so evaluate at |x| and restore the sign
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        return self._fastgrad(x.abs()).mul(x.sign())

    def _fastgrad(self, x):
        # derivative of the basis on x >= 0
        if self.order == 1:
            return torch.ones(x.shape, dtype=x.dtype, device=x.device)
        if self.order == 2:
            return torch.where(x < 0.5, -2*x, x - 1.5)
        if self.order == 3:
            g_low = x * (x * 1.5 - 2)
            g_up = -0.5 * square(2 - x)
            return torch.where(x < 1, g_low, g_up)
        if self.order == 4:
            g_low = x * (square(x) - 1.25)
            g_mid = x * (x * (x * (-2./3.) + 2.5) - 2.5) + 5./24.
            g_up = cube(2. * x - 5.) / 48.
            return torch.where(x < 0.5, g_low,
                               torch.where(x < 1.5, g_mid, g_up))
        if self.order == 5:
            g_low = x * (x * (x * (x * (-5./12.) + 1.)) - 1.)
            g_mid = x * (x * (x * (x * (5./24.) - 1.5) + 3.75) - 3.5) + 0.625
            g_up = pow4(x - 3.) / (-24.)
            return torch.where(x < 1, g_low,
                               torch.where(x < 2, g_mid, g_up))
        if self.order == 6:
            g_low = square(x)
            g_low = x * (g_low * (7./12.) - square(g_low) / 6. - 77./96.)
            g_mid_low = (x * (x * (x * (x * (x * 0.125 - 35./48.) + 1.3125)
                         - 35./96.) - 0.7109375) - 7./768.)
            g_mid_up = (x * (x * (x * (x * (x / (-20.) + 7./12.) - 2.625)
                        + 133./24.) - 5.140625) + 1267./960.)
            g_up = pow5(2*x - 7) / 3840.
            return torch.where(x < 0.5, g_low,
                               torch.where(x < 1.5, g_mid_low,
                                           torch.where(x < 2.5, g_mid_up,
                                                       g_up)))
        if self.order == 7:
            g_low = square(x)
            g_low = x * (g_low * (g_low * (x * (7./144.) - 1./6.) + 4./9.) - 2./3.)
            g_mid_low = (x * (x * (x * (x * (x * (x * (-7./240.) + 3./10.)
                         - 7./6.) + 2.) - 7./6.) - 1./5.) - 7./90.)
            g_mid_up = (x * (x * (x * (x * (x * (x * (7./720.) - 1./6.)
                        + 7./6.) - 38./9.) + 49./6.) - 23./3.) + 217./90.)
            g_up = pow6(x - 4) / (-720.)
            return torch.where(x < 1, g_low,
                               torch.where(x < 2, g_mid_low,
                                           torch.where(x < 3, g_mid_up, g_up)))
        raise NotImplementedError

    def hess(self, x):
        # second derivative, zeroed outside the support
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        h = self.fasthess(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        h = torch.where(x.abs() >= (self.order + 1)/2, zero, h)
        return h

    def fasthess(self, x):
        # second derivative without the out-of-support check
        # (even function, so no sign restoration is needed)
        if self.order in (0, 1):
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        x = x.abs()
        if self.order == 2:
            one = torch.ones([1], dtype=x.dtype, device=x.device)
            return torch.where(x < 0.5, -2 * one, one)
        if self.order == 3:
            return torch.where(x < 1, 3. * x - 2., 2. - x)
        if self.order == 4:
            return torch.where(x < 0.5, 3. * square(x) - 1.25,
                               torch.where(x < 1.5, x * (-2. * x + 5.) - 2.5,
                                           square(2. * x - 5.) / 8.))
        if self.order == 5:
            h_low = square(x)
            h_low = - h_low * (x * (5./3.) - 3.) - 1.
            h_mid = x * (x * (x * (5./6.) - 9./2.) + 15./2.) - 7./2.
            h_up = 9./2. - x * (x * (x/6. - 3./2.) + 9./2.)
            return torch.where(x < 1, h_low,
                               torch.where(x < 2, h_mid, h_up))
        if self.order == 6:
            h_low = square(x)
            h_low = - h_low * (h_low * (5./6) - 7./4.) - 77./96.
            h_mid_low = (x * (x * (x * (x * (5./8.) - 35./12.) + 63./16.)
                         - 35./48.) - 91./128.)
            h_mid_up = -(x * (x * (x * (x/4. - 7./3.) + 63./8.) - 133./12.)
                         + 329./64.)
            h_up = (x * (x * (x * (x/24. - 7./12.) + 49./16.) - 343./48.)
                    + 2401./384.)
            return torch.where(x < 0.5, h_low,
                               torch.where(x < 1.5, h_mid_low,
                                           torch.where(x < 2.5, h_mid_up,
                                                       h_up)))
        if self.order == 7:
            h_low = square(x)
            h_low = h_low * (h_low*(x * (7./24.) - 5./6.) + 4./3.) - 2./3.
            h_mid_low = - (x * (x * (x * (x * (x * (7./40.) - 3./2.) + 14./3.)
                           - 6.) + 7./3.) + 1./5.)
            h_mid_up = (x * (x * (x * (x * (x * (7./120.) - 5./6.) + 14./3.)
                        - 38./3.) + 49./3.) - 23./3.)
            h_up = - (x * (x * (x * (x * (x/120. - 1./6.) + 4./3.) - 16./3.)
                      + 32./3.) - 128./15.)
            return torch.where(x < 1, h_low,
                               torch.where(x < 2, h_mid_low,
                                           torch.where(x < 3, h_mid_up,
                                                       h_up)))
        raise NotImplementedError
def init_device(device):
    """Normalise a test device spec into a ``torch.device``.

    ``device`` is either ``'cpu'``/``'cuda'`` or a ``(name, param)`` pair,
    where ``param`` is the CUDA device index (for ``'cuda'``) or the number
    of CPU threads (for ``'cpu'``).
    """
    if isinstance(device, (list, tuple)):
        device, param = device
    else:
        param = 1 if device == 'cpu' else 0
    if device == 'cuda':
        torch.cuda.set_device(param)
        torch.cuda.init()
        try:
            torch.cuda.empty_cache()
        except RuntimeError:
            pass
        return torch.device(f'{device}:{param}')
    assert device == 'cpu'
    torch.set_num_threads(param)
    return torch.device(device)


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_grad(device, dim, bound, interpolation):
    """gradcheck of grid_grad w.r.t. both the volume and the grid."""
    print(f'grad_{dim}d({interpolation}, {bound}) on {device}')
    device = init_device(device)
    vol, grid = make_data((shape1,) * dim, device, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    assert gradcheck(grid_grad, (vol, grid, interpolation, bound, extrapolate),
                     **kwargs)


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_pull(device, dim, bound, interpolation):
    """gradcheck of grid_pull w.r.t. both the volume and the grid."""
    print(f'pull_{dim}d({interpolation}, {bound}) on {device}')
    device = init_device(device)
    vol, grid = make_data((shape1,) * dim, device, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    assert gradcheck(grid_pull, (vol, grid, interpolation, bound, extrapolate),
                     **kwargs)


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_push(device, dim, bound, interpolation):
    """gradcheck of grid_push (needs the output shape as extra argument)."""
    print(f'push_{dim}d({interpolation}, {bound}) on {device}')
    device = init_device(device)
    shape = (shape1,) * dim
    vol, grid = make_data(shape, device, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    assert gradcheck(grid_push, (vol, grid, shape, interpolation, bound, extrapolate),
                     **kwargs)


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_count(device, dim, bound, interpolation):
    """gradcheck of grid_count (only the grid carries gradients)."""
    print(f'count_{dim}d({interpolation}, {bound}) on {device}')
    device = init_device(device)
    shape = (shape1,) * dim
    _, grid = make_data(shape, device, dtype)
    grid.requires_grad_()
    assert gradcheck(grid_count, (grid, shape, interpolation, bound, extrapolate),
                     **kwargs)


def fake_decorator(*a, **k):
    """No-op decorator: returns the decorated object (or itself when called
    with decorator arguments)."""
    if len(a) == 1 and not k:
        return a[0]
    return fake_decorator


def make_list(x, n=None, **kwargs):
    """Ensure that the input is a list (of a given size)

    Parameters
    ----------
    x : list or tuple or scalar
        Input object
    n : int, optional
        Required length
    default : scalar, optional
        Value to right-pad with. Use last value of the input by default.

    Returns
    -------
    x : list
    """
    x = list(x) if isinstance(x, (list, tuple)) else [x]
    if n and len(x) < n:
        pad = kwargs.get('default', x[-1])
        x += [pad] * (n - len(x))
    return x
+ + Returns + ------- + x : list + """ + if not isinstance(x, (list, tuple)): + x = [x] + x = list(x) + if n and len(x) < n: + default = kwargs.get('default', x[-1]) + x = x + [default] * max(0, n - len(x)) + return x + + +def expanded_shape(*shapes, side='left'): + """Expand input shapes according to broadcasting rules + + Parameters + ---------- + *shapes : sequence[int] + Input shapes + side : {'left', 'right'}, default='left' + Side to add singleton dimensions. + + Returns + ------- + shape : tuple[int] + Output shape + + Raises + ------ + ValueError + If shapes are not compatible for broadcast. + + """ + def error(s0, s1): + raise ValueError('Incompatible shapes for broadcasting: {} and {}.' + .format(s0, s1)) + + # 1. nb dimensions + nb_dim = 0 + for shape in shapes: + nb_dim = max(nb_dim, len(shape)) + + # 2. enumerate + shape = [1] * nb_dim + for i, shape1 in enumerate(shapes): + pad_size = nb_dim - len(shape1) + ones = [1] * pad_size + if side == 'left': + shape1 = [*ones, *shape1] + else: + shape1 = [*shape1, *ones] + shape = [max(s0, s1) if s0 == 1 or s1 == 1 or s0 == s1 + else error(s0, s1) for s0, s1 in zip(shape, shape1)] + + return tuple(shape) + + +def matvec(mat, vec, out=None): + """Matrix-vector product (supports broadcasting) + + Parameters + ---------- + mat : (..., M, N) tensor + Input matrix. + vec : (..., N) tensor + Input vector. + out : (..., M) tensor, optional + Placeholder for the output tensor. 
+ + Returns + ------- + mv : (..., M) tensor + Matrix vector product of the inputs + + """ + vec = vec[..., None] + if out is not None: + out = out[..., None] + + mv = torch.matmul(mat, vec, out=out) + mv = mv[..., 0] + if out is not None: + out = out[..., 0] + + return mv + + +def _compare_versions(version1, mode, version2): + for v1, v2 in zip(version1, version2): + if mode in ('gt', '>'): + if v1 > v2: + return True + elif v1 < v2: + return False + elif mode in ('ge', '>='): + if v1 > v2: + return True + elif v1 < v2: + return False + elif mode in ('lt', '<'): + if v1 < v2: + return True + elif v1 > v2: + return False + elif mode in ('le', '<='): + if v1 < v2: + return True + elif v1 > v2: + return False + if mode in ('gt', 'lt', '>', '<'): + return False + else: + return True + + +def torch_version(mode, version): + """Check torch version + + Parameters + ---------- + mode : {'<', '<=', '>', '>='} + version : tuple[int] + + Returns + ------- + True if "torch.version version" + + """ + current_version, *cuda_variant = torch.__version__.split('+') + major, minor, patch, *_ = current_version.split('.') + # strip alpha tags + for x in 'abcdefghijklmnopqrstuvwxy': + if x in patch: + patch = patch[:patch.index(x)] + current_version = (int(major), int(minor), int(patch)) + version = make_list(version) + return _compare_versions(current_version, mode, version) + + +if torch_version('>=', (1, 10)): + meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij') + meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy') +else: + meshgrid_ij = lambda *x: torch.meshgrid(*x) + def meshgrid_xy(*x): + grid = list(torch.meshgrid(*x)) + if len(grid) > 1: + grid[0] = grid[0].transpose(0, 1) + grid[1] = grid[1].transpose(0, 1) + return grid + + +meshgrid = meshgrid_ij \ No newline at end of file diff --git a/utils/logging.py b/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..92cc250f3b50717c1c860f2ac80388d427b7b71d --- /dev/null +++ 
def _suppress_print():
    """
    Suppresses printing from the current process.
    """

    def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
        pass

    builtins.print = print_pass


@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open the log file once per filename and close it at interpreter exit."""
    # Use a 1K buffer when writing to cloud storage ("scheme://..." paths).
    io = pathmgr.open(filename, "a", buffering=1024 if "://" in filename else -1)
    atexit.register(io.close)
    return io


def setup_logging(output_dir=None):
    """
    Sets up the logging for multiple processes. Only enable the logging for the
    master process, and suppress logging for the non-master processes.
    """
    if du.is_master_proc():
        logging.root.handlers = []   # start from a clean root logger
    else:
        _suppress_print()            # silence non-master workers entirely

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.propagate = False
    plain_formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )

    if du.is_master_proc():
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    # Only one process per job writes the shared stdout.log file.
    if output_dir is not None and du.is_master_proc(du.get_world_size()):
        filename = os.path.join(output_dir, "stdout.log")
        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.INFO)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)


def get_logger(name):
    """
    Retrieve the logger with the specified name or, if name is None, return a
    logger which is the root logger of the hierarchy.
    Args:
        name (string): name of the logger.
    """
    return logging.getLogger(name)


def log_json_stats(stats):
    """
    Logs json stats.
    Args:
        stats (dict): a dictionary of statistical information to log.
    """
    # Render floats as fixed 5-decimal Decimals so the JSON output is stable.
    rounded = {
        k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v
        for k, v in stats.items()
    }
    json_stats = simplejson.dumps(rounded, sort_keys=True, use_decimal=True)
    get_logger(__name__).info("json_stats: {:s}".format(json_stats))


def get_lr_at_epoch(cfg, cur_epoch):
    """
    Retrieve the learning rate of the current epoch with the option to perform
    warm up in the beginning of the training stage.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    """
    policy = get_lr_func(cfg.SOLVER.LR_POLICY)
    lr = policy(cfg, cur_epoch)
    # Linear warm-up from WARMUP_START_LR to the scheduled lr at WARMUP_EPOCHS.
    if cur_epoch < cfg.SOLVER.WARMUP_EPOCHS:
        lr_start = cfg.SOLVER.WARMUP_START_LR
        lr_end = policy(cfg, cfg.SOLVER.WARMUP_EPOCHS)
        alpha = (lr_end - lr_start) / cfg.SOLVER.WARMUP_EPOCHS
        lr = cur_epoch * alpha + lr_start
    return lr
def lr_func_cosine(cfg, cur_epoch):
    """
    Retrieve the learning rate to specified values at specified epoch with the
    cosine learning rate schedule. Details can be found in:
    Ilya Loshchilov, and Frank Hutter
    SGDR: Stochastic Gradient Descent With Warm Restarts.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    """
    offset = cfg.SOLVER.WARMUP_EPOCHS if cfg.SOLVER.COSINE_AFTER_WARMUP else 0.0
    assert cfg.SOLVER.COSINE_END_LR < cfg.SOLVER.BASE_LR
    base_lr = cfg.SOLVER.BASE_LR
    end_lr = cfg.SOLVER.COSINE_END_LR
    cosine = math.cos(math.pi * (cur_epoch - offset) / (cfg.SOLVER.MAX_EPOCH - offset))
    return end_lr + (base_lr - end_lr) * (cosine + 1.0) * 0.5


def lr_func_steps_with_relative_lrs(cfg, cur_epoch):
    """
    Retrieve the learning rate to specified values at specified epoch with the
    steps with relative learning rate schedule.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    """
    ind = get_step_index(cfg, cur_epoch)
    return cfg.SOLVER.LRS[ind] * cfg.SOLVER.BASE_LR


def get_step_index(cfg, cur_epoch):
    """
    Retrieves the lr step index for the given epoch.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    """
    # MAX_EPOCH acts as a sentinel so the loop always breaks.
    steps = cfg.SOLVER.STEPS + [cfg.SOLVER.MAX_EPOCH]
    for ind, step in enumerate(steps):  # NoQA
        if cur_epoch < step:
            break
    return ind - 1


def get_lr_func(lr_policy):
    """
    Given the configs, retrieve the specified lr policy function.
    Args:
        lr_policy (string): the learning rate policy to use for the job.
    """
    policy = "lr_func_" + lr_policy
    if policy not in globals():
        raise NotImplementedError("Unknown LR policy: {}".format(lr_policy))
    return globals()[policy]


def make_dir(dir_name, parents=True, exist_ok=True, reset=False):
    """Create a directory (optionally wiping it first) and return it as a Path."""
    if reset and os.path.isdir(dir_name):
        shutil.rmtree(dir_name)
    path = Path(dir_name)
    path.mkdir(parents=parents, exist_ok=exist_ok)
    return path
def read_image(img_path, save_path=None):
    """Read a NIfTI image; optionally re-save it; return (squeezed data, affine)."""
    img = nib.load(img_path)
    nda = img.get_fdata()
    affine = img.affine
    if save_path:
        nib.save(nib.Nifti1Image(nda, affine), save_path)
    return np.squeeze(nda), affine


def save_image(nda, affine, save_path):
    """Save an array + affine as NIfTI; return the output path."""
    nib.save(nib.Nifti1Image(nda, affine), save_path)
    return save_path


def img2nda(img_path, save_path=None):
    """Read an image with SimpleITK; return (array, origin, spacing, direction).

    If `save_path` is given, the raw array is additionally saved as .npy.
    """
    img = sitk.ReadImage(img_path)
    nda = sitk.GetArrayFromImage(img)
    if save_path:
        np.save(save_path, nda)
    return nda, img.GetOrigin(), img.GetSpacing(), img.GetDirection()


def to3d(img_path, save_path=None):
    """Collapse a 4D image header to 3D; returns the path written (or the input)."""
    nda, o, s, d = img2nda(img_path)
    save_path = img_path if save_path is None else save_path
    if len(o) > 3:
        # Keep the spatial 3x3 sub-matrix of the direction cosines.
        # NOTE(review): assumes `d` is a flattened 4x4 matrix — confirm
        # against the images this is used on.
        nda2img(nda, o[:3], s[:3], d[:3] + d[4:7] + d[8:11], save_path)
    return save_path


def nda2img(nda, origin=None, spacing=None, direction=None, save_path=None, isVector=None):
    """Build (and optionally write) a SimpleITK image from an array/tensor.

    Args:
        isVector: treat the last axis as vector components. Auto-detected
            from the rank (> 3 dims) only when left as None.
    """
    if isinstance(nda, torch.Tensor):
        nda = nda.cpu().detach().numpy()
    nda = np.squeeze(np.array(nda))
    if isVector is None:
        # BUGFIX: `isVector if isVector else ...` used to discard an explicit
        # `isVector=False`; only auto-detect when the caller did not decide.
        isVector = nda.ndim > 3
    img = sitk.GetImageFromArray(nda, isVector=isVector)
    if origin:
        img.SetOrigin(origin)
    if spacing:
        img.SetSpacing(spacing)
    if direction:
        img.SetDirection(direction)
    if save_path:
        sitk.WriteImage(img, save_path)
    return img


def cropping(img_path, tol=0, crop_range_lst=None, spare=0, save_path=None):
    """Crop an image to the bounding box of voxels > `tol`, plus a margin.

    Args:
        img_path: input image path (3D, or 4D with a trailing channel axis).
        tol: intensity threshold defining foreground.
        crop_range_lst: [[x0, y0, z0], [x1, y1, z1]] to apply a precomputed
            crop instead of recomputing the bounding box.
        spare: margin (in voxels) added around the bounding box.
        save_path: optional output path.

    Returns:
        (cropped image, [[x0, y0, z0], [x1, y1, z1]], new origin)
    """
    img = sitk.ReadImage(img_path)
    orig_nda = sitk.GetArrayFromImage(img)
    if len(orig_nda.shape) > 3:
        nda = orig_nda[..., 0]   # 4D data: use the first frame to find the box
    else:
        nda = np.copy(orig_nda)

    if crop_range_lst is None:
        # Bounding box of non-background voxels (upper bound exclusive).
        coords = np.argwhere(nda > tol)
        x0, y0, z0 = coords.min(axis=0)
        x1, y1, z1 = coords.max(axis=0) + 1
        # Grow the box by `spare`, clamped to the volume bounds.
        x0 = x0 - spare if x0 > spare else x0
        y0 = y0 - spare if y0 > spare else y0
        z0 = z0 - spare if z0 > spare else z0
        x1 = x1 + spare if x1 < orig_nda.shape[0] - spare else x1
        y1 = y1 + spare if y1 < orig_nda.shape[1] - spare else y1
        z1 = z1 + spare if z1 < orig_nda.shape[2] - spare else z1
    else:
        [[x0, y0, z0], [x1, y1, z1]] = crop_range_lst

    cropped_nda = orig_nda[x0:x1, y0:y1, z0:z1]
    # numpy's axis order is reversed w.r.t. sitk's physical axes, hence z0
    # pairs with the first physical coordinate.
    new_origin = [img.GetOrigin()[0] + img.GetSpacing()[0] * z0,
                  img.GetOrigin()[1] + img.GetSpacing()[1] * y0,
                  img.GetOrigin()[2] + img.GetSpacing()[2] * x0]
    cropped_img = sitk.GetImageFromArray(cropped_nda, isVector=len(orig_nda.shape) > 3)
    cropped_img.SetOrigin(new_origin)
    cropped_img.SetSpacing(img.GetSpacing())
    cropped_img.SetDirection(img.GetDirection())
    if save_path:
        sitk.WriteImage(cropped_img, save_path)

    return cropped_img, [[x0, y0, z0], [x1, y1, z1]], new_origin


def memory_stats():
    """Log currently allocated / reserved CUDA memory in MiB."""
    logger.info(torch.cuda.memory_allocated() / 1024 ** 2)
    logger.info(torch.cuda.memory_reserved() / 1024 ** 2)


def viewVolume(x, aff=None, prefix='', postfix='', names=[], ext='.nii.gz', save_dir='/tmp'):
    """Dump one or several volumes under `save_dir` for visual inspection.

    `x` may be a tensor/array, a list of them, or a dict (keys become names).
    Returns the path of the last volume written.

    NOTE(review): the mutable default `names=[]` is kept for interface
    compatibility; it is only read, never mutated here.
    """
    if aff is None:
        aff = np.eye(4)
    elif isinstance(aff, torch.Tensor):
        aff = aff.cpu().detach().numpy()

    if isinstance(x, dict):
        names = list(x.keys())
        x = [x[k] for k in x]
    if not isinstance(x, list):
        x = [x]

    for n in range(len(x)):
        vol = x[n]
        if vol is None:
            continue
        if isinstance(vol, torch.Tensor):
            vol = vol.cpu().detach().numpy()
        vol = np.squeeze(np.array(vol))
        try:
            save_path = os.path.join(make_dir(save_dir), prefix + names[n] + postfix + ext)
        except (IndexError, TypeError):
            # BUGFIX: narrowed from a bare `except:`; fall back to the index
            # when no (usable) name is available.
            save_path = os.path.join(make_dir(save_dir), prefix + str(n) + postfix + ext)
        MRIwrite(vol, aff, save_path)
    return save_path
def MRIwrite(volume, aff, filename, dtype=None):
    """Write `volume` (with affine `aff`) as a NIfTI file.

    An identity affine is used when `aff` is None.
    """
    if dtype is not None:
        volume = volume.astype(dtype=dtype)
    if aff is None:
        aff = np.eye(4)
    header = nib.Nifti1Header()
    nib.save(nib.Nifti1Image(volume, aff, header), filename)


def MRIread(filename, dtype=None, im_only=False):
    """Read a NIfTI/MGZ file; return the data (and affine unless `im_only`)."""
    # dtype example: 'int', 'float'
    assert filename.endswith(('.nii', '.nii.gz', '.mgz')), 'Unknown data file: %s' % filename
    x = nib.load(filename)
    volume = x.get_fdata()
    if dtype is not None:
        volume = volume.astype(dtype=dtype)
    return volume if im_only else (volume, x.affine)


def get_ras_axes(aff, n_dims=3):
    """This function finds the RAS axes corresponding to each dimension of a volume, based on its affine matrix.
    :param aff: affine matrix Can be a 2d numpy array of size n_dims*n_dims, n_dims+1*n_dims+1, or n_dims*n_dims+1.
    :param n_dims: number of dimensions (excluding channels) of the volume corresponding to the provided affine matrix.
    :return: numpy 1d array of length n_dims with the axes corresponding to RAS orientations.
    """
    inv_aff = np.linalg.inv(aff)
    return np.argmax(np.absolute(inv_aff[:n_dims, :n_dims]), axis=0)


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # Serialise the object into a byte tensor on the GPU.
    buffer = pickle.dumps(data)
    tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(buffer)).to("cuda")

    # Exchange the payload size of every rank.
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # torch all_gather only supports equal shapes: pad to the largest payload.
    tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device="cuda")
                   for _ in size_list]
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    # Unpickle, trimming each payload back to its true size.
    data_list = []
    for size, t in zip(size_list, tensor_list):
        data_list.append(pickle.loads(t.cpu().numpy().tobytes()[:size]))
    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sort keys so every process reduces values in the same order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        return dict(zip(names, values))
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = {k: torch.stack([dict[k] for dict in batch]) for k in batch[0]} # switch from batch of dict to dict of batch + return batch + #v = {k: [dic[k] for dic in LD] for k in LD[0]} + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def launch_job(submit_cfg, gen_cfg, train_cfg, func, daemon = False): + """ + Run 'func' on one or more GPUs, specified in cfg + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + init_method (str): initialization method to launch the job with multiple + devices. + func (function): job to run on GPU(s) + daemon (bool): The spawned processes’ daemon flag. 
If set to True, + daemonic processes will be created + """ + if submit_cfg is not None and submit_cfg.num_gpus > 1: + logger.info('num_gpus:', submit_cfg.num_gpus) + torch.multiprocessing.spawn( + mpu.run, + nprocs=submit_cfg.num_gpus, + args=( + submit_cfg.num_gpus, + func, + submit_cfg.init_method, + submit_cfg.shard_id, + submit_cfg.num_shards, + submit_cfg.dist_backend, + submit_cfg, + ), + daemon = daemon, + ) + else: + logger.info('num_gpus: 1') + func([submit_cfg, gen_cfg, train_cfg]) + + +def preprocess_cfg(cfg_files, cfg_dir = ''): + config = load_config(cfg_files[0], cfg_files[1:], cfg_dir) + args = nested_dict_to_namespace(config) + return args + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + if not is_master: + def line(*args, **kwargs): + pass + def images(*args, **kwargs): + pass + Visdom.line = line + Visdom.images = images + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(cfg): + """ + Initialize variables needed for distributed training. 
+ """ + if cfg.num_gpus <= 1: + return + num_gpus_per_machine = cfg.num_gpus + num_machines = dist.get_world_size() // num_gpus_per_machine + logger.info(num_gpus_per_machine, dist.get_world_size()) + for i in range(num_machines): + ranks_on_i = list( + range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) + ) + pg = dist.new_group(ranks_on_i) + if i == cfg.shard_id: + global _LOCAL_PROCESS_GROUP + _LOCAL_PROCESS_GROUP = pg + +'''def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + #args.rank = int(os.environ["RANK"]) + #args.world_size = int(os.environ['WORLD_SIZE']) + #args.gpu = int(os.environ['LOCAL_RANK']) + pass + elif 'SLURM_PROCID' in os.environ and 'SLURM_PTY_PORT' not in os.environ: + # slurm process but not interactive + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + elif args.num_gpus < 1: + print('Not using distributed mode') + #args.distributed = False + return + + args.world_size = int(args.num_gpus * args.nodes) + + #args.distributed = True + + torch.cuda.set_device(args.gpu) + #args.dist_backend = 'nccl' + print(f'| distributed init (rank {args.rank}): {args.dist_url}', flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + # torch.distributed.barrier() + setup_for_distributed(args.rank == 0)''' + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, 
scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +class DistributedWeightedSampler(torch.utils.data.DistributedSampler): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, replacement=True): + super(DistributedWeightedSampler, self).__init__(dataset, num_replicas, rank, shuffle) + + assert replacement + + self.replacement = replacement + + def __iter__(self): + iter_indices = super(DistributedWeightedSampler, self).__iter__() + if hasattr(self.dataset, 'sample_weight'): + indices = list(iter_indices) + + weights = torch.tensor([self.dataset.sample_weight(idx) for idx in indices]) + + g = torch.Generator() + g.manual_seed(self.epoch) + + weight_indices = torch.multinomial( + weights, self.num_samples, self.replacement, generator=g) + indices = torch.tensor(indices)[weight_indices] + + iter_indices = iter(indices.tolist()) + return iter_indices + + def __len__(self): + return self.num_samples + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1/x2) + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. 
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, query_mask=None, reduction=True):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: probability the model assigns to the true class of each element.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if not reduction:
        return loss

    if query_mask is not None:
        # Average only over the valid (unmasked) queries of each sample.
        loss = torch.stack([l[m].mean(0) for l, m in zip(loss, query_mask)])
        return loss.sum() / num_boxes
    return loss.mean(1).sum() / num_boxes


def nested_dict_to_namespace(dictionary):
    """Recursively convert (nested) dicts into argparse Namespaces."""
    namespace = dictionary
    if isinstance(dictionary, dict):
        namespace = Namespace(**dictionary)
        for key, value in dictionary.items():
            setattr(namespace, key, nested_dict_to_namespace(value))
    return namespace


def nested_dict_to_device(dictionary, device):
    """Recursively move tensors in a nested dict/list structure to `device`.

    Strings and other non-movable leaves are returned unchanged.
    """
    if isinstance(dictionary, dict):
        return {key: nested_dict_to_device(value, device)
                for key, value in dictionary.items()}
    if isinstance(dictionary, str):
        return dictionary
    if isinstance(dictionary, list):
        return [nested_dict_to_device(d, device) for d in dictionary]
    try:
        return dictionary.to(device)
    except Exception:
        # BUGFIX: narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt / SystemExit.
        return dictionary


def merge_list_of_dict(dict_list_a, dict_list_b):
    """Merge two equal-length lists of dicts element-wise (in place; b wins)."""
    assert len(dict_list_a) == len(dict_list_b)
    for i in range(len(dict_list_a)):
        dict_list_a[i].update(dict_list_b[i])
    return dict_list_a
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + try: + return self.total / self.count + except: + return 0. + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + + +class MetricLogger(object): + def __init__(self, print_freq, delimiter="\t", debug=False, sample_freq=None): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.print_freq = print_freq + self.debug = debug + self.sample_freq = sample_freq + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + 
try: + loss_str.append(f"{name}: {meter}") + except: + loss_str = '' + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterables, max_len, probs, epoch=None, header=None, is_test=False, train_limit=None, test_limit=None): + # iterables: dict = {dataset_name: dataloader} + + i = 0 + if header is None: + header = 'Epoch: [{}]'.format(epoch) + + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(max_len))) + 'd' + MB = 1024.0 * 1024.0 + + generator_dict = {} + for k, v in iterables.items(): + generator_dict[k] = iter(v) + + for i in range(max_len): + chosen_dataset = np.random.choice(len(iterables), 1, p=probs)[0] + curr_dataset = list(iterables.keys())[chosen_dataset] + + if train_limit and i >= train_limit and not is_test: # train sub-set + break + if test_limit and i >= test_limit and is_test: # limit test iterations (1000) + break + + data_time.update(time.time() - end) + + try: + (dataset_num, dataset_name, input_mode, target, samples) = next(generator_dict[curr_dataset]) + except StopIteration: + logger.info('Re-iterate: {}'.format(curr_dataset)) + generator_dict[curr_dataset] = iter(iterables[curr_dataset]) + (dataset_num, dataset_name, input_mode, target, samples) = next(generator_dict[curr_dataset]) + dataset_name = dataset_name[0] + yield dataset_num, dataset_name, input_mode[0], target, samples + iter_time.update(time.time() - end) + + + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'dataset: {}'.format(dataset_name), + 'mode: {}'.format(input_mode[0]), + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}', + ]) + else: + log_msg = 
self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'dataset: {}'.format(dataset_name), + 'mode: {}'.format(input_mode[0]), + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data_time: {data}', + ]) + + if i % self.print_freq == 0 or i == max_len - 1: + eta_seconds = iter_time.global_avg * (max_len - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + logger.info(log_msg.format( + i , max_len, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + logger.info(log_msg.format( + i, max_len, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + + if self.debug and i % self.print_freq == 0: + break + + i += 1 + end = time.time() + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / max_len)) + + +######################### Synth ######################### + +def myzoom_torch_slow(X, factor, device, aff=None): + + if len(X.shape)==3: + X = X[..., None] + + delta = (1.0 - factor) / (2.0 * factor) + newsize = np.round(X.shape[:-1] * factor).astype(int) + + vx = torch.arange(delta[0], delta[0] + newsize[0] / factor[0], 1 / factor[0], dtype=torch.float, device=device)[:newsize[0]] + vy = torch.arange(delta[1], delta[1] + newsize[1] / factor[1], 1 / factor[1], dtype=torch.float, device=device)[:newsize[1]] + vz = torch.arange(delta[2], delta[2] + newsize[2] / factor[2], 1 / factor[2], dtype=torch.float, device=device)[:newsize[2]] + + vx[vx < 0] = 0 + vy[vy < 0] = 0 + vz[vz < 0] = 0 + vx[vx > (X.shape[0]-1)] = (X.shape[0]-1) + vy[vy > (X.shape[1] - 1)] = (X.shape[1] - 1) + vz[vz > (X.shape[2] - 1)] = (X.shape[2] - 1) + + fx = torch.floor(vx).int() + cx = fx + 1 + cx[cx > (X.shape[0]-1)] = (X.shape[0]-1) + wcx = vx - fx + wfx = 1 - wcx + + fy = 
torch.floor(vy).int() + cy = fy + 1 + cy[cy > (X.shape[1]-1)] = (X.shape[1]-1) + wcy = vy - fy + wfy = 1 - wcy + + fz = torch.floor(vz).int() + cz = fz + 1 + cz[cz > (X.shape[2]-1)] = (X.shape[2]-1) + wcz = vz - fz + wfz = 1 - wcz + + Y = torch.zeros([newsize[0], newsize[1], newsize[2], X.shape[3]], dtype=torch.float, device=device) + + for channel in range(X.shape[3]): + Xc = X[:,:,:,channel] + + tmp1 = torch.zeros([newsize[0], Xc.shape[1], Xc.shape[2]], dtype=torch.float, device=device) + for i in range(newsize[0]): + tmp1[i, :, :] = wfx[i] * Xc[fx[i], :, :] + wcx[i] * Xc[cx[i], :, :] + tmp2 = torch.zeros([newsize[0], newsize[1], Xc.shape[2]], dtype=torch.float, device=device) + for j in range(newsize[1]): + tmp2[:, j, :] = wfy[j] * tmp1[:, fy[j], :] + wcy[j] * tmp1[:, cy[j], :] + for k in range(newsize[2]): + Y[:, :, k, channel] = wfz[k] * tmp2[:, :, fz[k]] + wcz[k] * tmp2[:, :, cz[k]] + + if Y.shape[3] == 1: + Y = Y[:,:,:, 0] + + if aff is not None: + aff_new = aff.copy() + for c in range(3): + aff_new[:-1, c] = aff_new[:-1, c] / factor + aff_new[:-1, -1] = aff_new[:-1, -1] - aff[:-1, :-1] @ (0.5 - 0.5 / (factor * np.ones(3))) + return Y, aff_new + else: + return Y + +def myzoom_torch(X, factor, aff=None): + + if len(X.shape)==3: + X = X[..., None] + + delta = (1.0 - factor) / (2.0 * factor) + newsize = np.round(X.shape[:-1] * factor).astype(int) + + vx = torch.arange(delta[0], delta[0] + newsize[0] / factor[0], 1 / factor[0], dtype=torch.float, device=X.device)[:newsize[0]] + vy = torch.arange(delta[1], delta[1] + newsize[1] / factor[1], 1 / factor[1], dtype=torch.float, device=X.device)[:newsize[1]] + vz = torch.arange(delta[2], delta[2] + newsize[2] / factor[2], 1 / factor[2], dtype=torch.float, device=X.device)[:newsize[2]] + + vx[vx < 0] = 0 + vy[vy < 0] = 0 + vz[vz < 0] = 0 + vx[vx > (X.shape[0]-1)] = (X.shape[0]-1) + vy[vy > (X.shape[1] - 1)] = (X.shape[1] - 1) + vz[vz > (X.shape[2] - 1)] = (X.shape[2] - 1) + + fx = torch.floor(vx).int() + cx = fx + 1 + 
cx[cx > (X.shape[0]-1)] = (X.shape[0]-1) + wcx = (vx - fx) + wfx = 1 - wcx + + fy = torch.floor(vy).int() + cy = fy + 1 + cy[cy > (X.shape[1]-1)] = (X.shape[1]-1) + wcy = (vy - fy) + wfy = 1 - wcy + + fz = torch.floor(vz).int() + cz = fz + 1 + cz[cz > (X.shape[2]-1)] = (X.shape[2]-1) + wcz = (vz - fz) + wfz = 1 - wcz + + Y = torch.zeros([newsize[0], newsize[1], newsize[2], X.shape[3]], dtype=torch.float, device=X.device) + + tmp1 = torch.zeros([newsize[0], X.shape[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device) + for i in range(newsize[0]): + tmp1[i, :, :] = wfx[i] * X[fx[i], :, :] + wcx[i] * X[cx[i], :, :] + tmp2 = torch.zeros([newsize[0], newsize[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device) + for j in range(newsize[1]): + tmp2[:, j, :] = wfy[j] * tmp1[:, fy[j], :] + wcy[j] * tmp1[:, cy[j], :] + for k in range(newsize[2]): + Y[:, :, k] = wfz[k] * tmp2[:, :, fz[k]] + wcz[k] * tmp2[:, :, cz[k]] + + if Y.shape[3] == 1: + Y = Y[:,:,:, 0] + + if aff is not None: + aff_new = aff.copy() + aff_new[:-1] = aff_new[:-1] / factor + aff_new[:-1, -1] = aff_new[:-1, -1] - aff[:-1, :-1] @ (0.5 - 0.5 / (factor * np.ones(3))) + return Y, aff_new + else: + return Y + +def myzoom_torch_test(X, factor, aff=None): + time.sleep(3) + + start_time = time.time() + Y2 = myzoom_torch_slow(X, factor, aff) + print('slow', X.shape[-1], time.time() - start_time) + + time.sleep(3) + + start_time = time.time() + Y1 = myzoom_torch(X, factor, aff) + print('fast', X.shape[-1], time.time() - start_time) + + time.sleep(3) + + print('diff', (Y2 - Y1).mean(), (Y2 - Y1).max()) + return Y1 + +def myzoom_torch_anisotropic_slow(X, aff, newsize, device): + + if len(X.shape)==3: + X = X[..., None] + + factors = np.array(newsize) / np.array(X.shape[:-1]) + delta = (1.0 - factors) / (2.0 * factors) + + vx = torch.arange(delta[0], delta[0] + newsize[0] / factors[0], 1 / factors[0], dtype=torch.float, device=device)[:newsize[0]] + vy = torch.arange(delta[1], delta[1] + 
newsize[1] / factors[1], 1 / factors[1], dtype=torch.float, device=device)[:newsize[1]]
    vz = torch.arange(delta[2], delta[2] + newsize[2] / factors[2], 1 / factors[2], dtype=torch.float, device=device)[:newsize[2]]

    # Clamp sample coordinates to the valid index range.
    vx[vx < 0] = 0
    vy[vy < 0] = 0
    vz[vz < 0] = 0
    vx[vx > (X.shape[0]-1)] = (X.shape[0]-1)
    vy[vy > (X.shape[1] - 1)] = (X.shape[1] - 1)
    vz[vz > (X.shape[2] - 1)] = (X.shape[2] - 1)

    # Floor/ceil indices and interpolation weights per axis.
    fx = torch.floor(vx).int()
    cx = fx + 1
    cx[cx > (X.shape[0]-1)] = (X.shape[0]-1)
    wcx = vx - fx
    wfx = 1 - wcx

    fy = torch.floor(vy).int()
    cy = fy + 1
    cy[cy > (X.shape[1]-1)] = (X.shape[1]-1)
    wcy = vy - fy
    wfy = 1 - wcy

    fz = torch.floor(vz).int()
    cz = fz + 1
    cz[cz > (X.shape[2]-1)] = (X.shape[2]-1)
    wcz = vz - fz
    wfz = 1 - wcz

    Y = torch.zeros([newsize[0], newsize[1], newsize[2], X.shape[3]], dtype=torch.float, device=device)

    # Separable interpolation per channel: x, then y, then z.
    dtype = X.dtype
    for channel in range(X.shape[3]):
        Xc = X[:,:,:,channel]

        tmp1 = torch.zeros([newsize[0], Xc.shape[1], Xc.shape[2]], dtype=dtype, device=device)
        for i in range(newsize[0]):
            tmp1[i, :, :] = wfx[i] * Xc[fx[i], :, :] + wcx[i] * Xc[cx[i], :, :]
        tmp2 = torch.zeros([newsize[0], newsize[1], Xc.shape[2]], dtype=dtype, device=device)
        for j in range(newsize[1]):
            tmp2[:, j, :] = wfy[j] * tmp1[:, fy[j], :] + wcy[j] * tmp1[:, cy[j], :]
        for k in range(newsize[2]):
            Y[:, :, k, channel] = wfz[k] * tmp2[:, :, fz[k]] + wcz[k] * tmp2[:, :, cz[k]]

    if Y.shape[3] == 1:
        Y = Y[:,:,:, 0]

    if aff is not None:
        # Update the affine: scale each voxel-direction column, shift the origin.
        aff_new = aff.copy()
        for c in range(3):
            aff_new[:-1, c] = aff_new[:-1, c] / factors[c]
        aff_new[:-1, -1] = aff_new[:-1, -1] - aff[:-1, :-1] @ (0.5 - 0.5 / factors)
        return Y, aff_new
    else:
        return Y



def myzoom_torch_anisotropic(X, aff, newsize):
    # Same as myzoom_torch_anisotropic_slow but runs on X.device.

    device = X.device

    if len(X.shape)==3:
        X = X[..., None]

    factors = np.array(newsize) / np.array(X.shape[:-1])
    delta = (1.0 - factors) / (2.0 * factors)

    vx = torch.arange(delta[0], delta[0] + newsize[0] / factors[0], 1 / factors[0], dtype=torch.float, device=device)[:newsize[0]]
    vy = torch.arange(delta[1], delta[1] + newsize[1] / factors[1], 1 / factors[1], dtype=torch.float, device=device)[:newsize[1]]
    vz = torch.arange(delta[2], delta[2] + newsize[2] / factors[2], 1 / factors[2], dtype=torch.float, device=device)[:newsize[2]]

    vx[vx < 0] = 0
    vy[vy < 0] = 0
    vz[vz < 0] = 0
    vx[vx > (X.shape[0]-1)] = (X.shape[0]-1)
    vy[vy > (X.shape[1] - 1)] = (X.shape[1] - 1)
    vz[vz > (X.shape[2] - 1)] = (X.shape[2] - 1)

    fx = torch.floor(vx).int()
    cx = fx + 1
    cx[cx > (X.shape[0]-1)] = (X.shape[0]-1)
    wcx = vx - fx
    wfx = 1 - wcx

    fy = torch.floor(vy).int()
    cy = fy + 1
    cy[cy > (X.shape[1]-1)] = (X.shape[1]-1)
    wcy = vy - fy
    wfy = 1 - wcy

    fz = torch.floor(vz).int()
    cz = fz + 1
    cz[cz > (X.shape[2]-1)] = (X.shape[2]-1)
    wcz = vz - fz
    wfz = 1 - wcz

    Y = torch.zeros([newsize[0], newsize[1], newsize[2], X.shape[3]], dtype=torch.float, device=device)

    dtype = X.dtype
    for channel in range(X.shape[3]):
        Xc = X[:,:,:,channel]

        tmp1 = torch.zeros([newsize[0], Xc.shape[1], Xc.shape[2]], dtype=dtype, device=device)
        for i in range(newsize[0]):
            tmp1[i, :, :] = wfx[i] * Xc[fx[i], :, :] + wcx[i] * Xc[cx[i], :, :]
        tmp2 = torch.zeros([newsize[0], newsize[1], Xc.shape[2]], dtype=dtype, device=device)
        for j in range(newsize[1]):
            tmp2[:, j, :] = wfy[j] * tmp1[:, fy[j], :] + wcy[j] * tmp1[:, cy[j], :]
        for k in range(newsize[2]):
            Y[:, :, k, channel] = wfz[k] * tmp2[:, :, fz[k]] + wcz[k] * tmp2[:, :, cz[k]]

    if Y.shape[3] == 1:
        Y = Y[:,:,:, 0]

    if aff is not None:
        aff_new = aff.copy()
        for c in range(3):
            aff_new[:-1, c] = aff_new[:-1, c] / factors[c]
        aff_new[:-1, -1] = aff_new[:-1, -1] - aff[:-1, :-1] @ (0.5 - 0.5 / factors)
        return Y, aff_new
    else:
        return Y

def torch_resize(I, aff, resolution, power_factor_at_half_width=5, dtype=torch.float32, slow=False):
    # Gaussian-smooth (anti-alias) then resample volume I to target `resolution`
    # (mm), returning (resampled, new_affine). Recurses once under no_grad.

    if torch.is_grad_enabled():
        with torch.no_grad():
            return torch_resize(I, aff, resolution, power_factor_at_half_width, dtype, slow)

    # NOTE(review): `I.device == 'cpu'` compares a torch.device to a str, and
    # numpy inputs have no `.device` -- confirm intended inputs are tensors.
    slow = slow or (I.device == 'cpu')
    voxsize = np.sqrt(np.sum(aff[:-1, :-1] ** 2, axis=0))
    newsize = np.round(I.shape[0:3] * (voxsize / resolution)).astype(int)
    factors = np.array(I.shape[0:3]) / np.array(newsize)
    # Smoothing sigma per axis; zero when not downsampling along that axis.
    k = np.log(power_factor_at_half_width) / np.pi
    sigmas = k * factors
    sigmas[sigmas<=k] = 0

    if len(I.shape) not in (3, 4):
        raise Exception('torch_resize works with 3D or 3D+label volumes')
    no_channels = len(I.shape) == 3
    if no_channels:
        I = I[:, :, :, None]
    if torch.is_tensor(I):
        I = I.permute([3, 0, 1, 2])
    else:
        I = I.transpose([3, 0, 1, 2])

    It_lowres = None
    for c in range(len(I)):
        It = torch.as_tensor(I[c], device=I.device, dtype=dtype)[None, None]
        # Smoothen if needed
        for d in range(3):
            # Rotate the target axis into the last position so the 1D kernel applies to it.
            It = It.permute([0, 1, 3, 4, 2])
            if sigmas[d]>0:
                sl = np.ceil(sigmas[d] * 2.5).astype(int)
                v = np.arange(-sl, sl + 1)
                gauss = np.exp((-(v / sigmas[d]) ** 2 / 2))
                kernel = gauss / np.sum(gauss)
                kernel = torch.tensor(kernel, device=I.device, dtype=dtype)
                if slow:
                    It = conv_slow_fallback(It, kernel)
                else:
                    kernel = kernel[None, None, None, None, :]
                    It = torch.conv3d(It, kernel, bias=None, stride=1, padding=[0, 0, int((kernel.shape[-1] - 1) / 2)])


        It = torch.squeeze(It)
        It, aff2 = myzoom_torch_anisotropic(It, aff, newsize)
        It = It.detach()
        if torch.is_tensor(I):
            It = It.to(I.device)
        else:
            It = It.cpu().numpy()
        if len(I) == 1:
            It_lowres = It[None]
        else:
            if It_lowres is None:
                if torch.is_tensor(It):
                    It_lowres = It.new_empty([len(I), *It.shape])
                else:
                    It_lowres = np.empty_like(It, shape=[len(I), *It.shape])
            It_lowres[c] = It

        torch.cuda.empty_cache()

    if not no_channels:
        if torch.is_tensor(I):
            It_lowres = It_lowres.permute([1, 2, 3, 0])
        else:
            It_lowres = It_lowres.transpose([1, 2, 3, 0])
    else:
        It_lowres = It_lowres[0]

    return It_lowres, aff2

###############

@torch.jit.script
def conv_slow_fallback(x, kernel):
    """1D Conv along the last dimension with padding"""
    # Accumulates shifted copies instead of calling conv3d; used when the
    # fast conv path is unavailable (see torch_resize's `slow` flag).
    y = torch.zeros_like(x)
    x = torch.nn.functional.pad(x, [(len(kernel) - 1) // 2]*2)
    # unfold -> sliding windows of kernel length; movedim puts the tap index first.
    x = x.unfold(-1, size=len(kernel), step=1)
    x = x.movedim(-1, 0)
    for i in range(len(kernel)):
        y = y.addcmul_(x[i], kernel[i])
    return y



###############


def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=3):
    """This function aligns a volume to a reference orientation (axis and direction) specified by an affine matrix.
    :param volume: a numpy array
    :param aff: affine matrix of the floating volume
    :param aff_ref: (optional) affine matrix of the target orientation. Default is identity matrix.
    :param return_aff: (optional) whether to return the affine matrix of the aligned volume
    :param n_dims: number of dimensions (excluding channels) of the volume corresponding to the provided affine matrix.
    :return: aligned volume, with corresponding affine matrix if return_aff is True.
+ """ + + # work on copy + aff_flo = aff.copy() + + # default value for aff_ref + if aff_ref is None: + aff_ref = np.eye(4) + + # extract ras axes + ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims) + ras_axes_flo = get_ras_axes(aff_flo, n_dims=n_dims) + + # align axes + aff_flo[:, ras_axes_ref] = aff_flo[:, ras_axes_flo] + for i in range(n_dims): + if ras_axes_flo[i] != ras_axes_ref[i]: + volume = torch.swapaxes(volume, ras_axes_flo[i], ras_axes_ref[i]) + swapped_axis_idx = np.where(ras_axes_flo == ras_axes_ref[i]) + ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ras_axes_flo[i], ras_axes_flo[swapped_axis_idx] + + # align directions + dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0) + for i in range(n_dims): + if dot_products[i] < 0: + volume = torch.flip(volume, [i]) + aff_flo[:, i] = - aff_flo[:, i] + aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (volume.shape[i] - 1) + + if return_aff: + return volume, aff_flo + else: + return volume + + + +def multistep_scheduler(base_value, lr_drops, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0, gamma=0.1): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + schedule = np.ones(epochs * niter_per_ep - warmup_iters) * base_value + for milestone in lr_drops: + schedule[milestone * niter_per_ep :] *= gamma + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * 
iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule + + +class LARS(torch.optim.Optimizer): + """ + Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py + """ + def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, + weight_decay_filter=None, lars_adaptation_filter=None): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, + eta=eta, weight_decay_filter=weight_decay_filter, + lars_adaptation_filter=lars_adaptation_filter) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for g in self.param_groups: + for p in g['params']: + dp = p.grad + + if dp is None: + continue + + if p.ndim != 1: + dp = dp.add(p, alpha=g['weight_decay']) + + if p.ndim != 1: + param_norm = torch.norm(p) + update_norm = torch.norm(dp) + one = torch.ones_like(param_norm) + q = torch.where(param_norm > 0., + torch.where(update_norm > 0, + (g['eta'] * param_norm / update_norm), one), one) + dp = dp.mul(q) + + param_state = self.state[p] + if 'mu' not in param_state: + param_state['mu'] = torch.zeros_like(p) + mu = param_state['mu'] + mu.mul_(g['momentum']).add_(dp) + + p.add_(mu, alpha=-g['lr']) + + + +def cancel_gradients_last_layer(epoch, model, freeze_last_layer): + if epoch >= freeze_last_layer: + return + for n, p in model.named_parameters(): + if "last_layer" in n: + p.grad = None + + +def clip_gradients(model, clip): + norms = [] + for name, p in model.named_parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + norms.append(param_norm.item()) + clip_coef = clip / (param_norm + 1e-6) + if clip_coef < 1: + p.grad.data.mul_(clip_coef) + return norms + + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on 
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False + + +def read_log(log_path, loss_name = 'loss'): + log_file = open(log_path, 'r') + lines = log_file.readlines() + epoches = [] + losses = [] + num_epoches = 0 + for i, line in enumerate(lines): + #print("Line{}: {}".format(i, line.strip())) + if len(line) <= 1: + break + num_epoches += 1 + epoches.append(int(line.split(' - ')[0].split('epoch ')[1])) + losses.append(float(line.split('"%s": ' % loss_name)[1].split(',')[0])) + #print('num_epoches:', num_epoches) + return epoches, losses + +def plot_loss(loss_lst, 
save_path): + fig = plt.figure() + ax = fig.add_subplot(111) + t = list(np.arange(len(loss_lst))) + + ax.plot(t, np.array(loss_lst), 'r--') + ax.set_xlabel('Epoch') + ax.set_ylabel('Loss') + #ax.set_yscale('log') + #ax.legend() + #ax.title.set_text(loss_name) + plt.savefig(save_path) + plt.close(fig) + return + +############################### + +# map SynthSeg right to left labels for contrast synthesis +right_to_left_dict = { + 41: 2, + 42: 3, + 43: 4, + 44: 5, + 46: 7, + 47: 8, + 49: 10, + 50: 11, + 51: 12, + 52: 13, + 53: 17, + 54: 18, + 58: 26, + 60: 28 +} + +# based on merged left & right SynthSeg labels +ct_brightness_group = { + 'darker': [4, 5, 14, 15, 24, 31, 72], # ventricles, CSF + 'dark': [2, 7, 16, 77, 30], # white matter + 'bright': [3, 8, 17, 18, 28, 10, 11, 12, 13, 26], # grey matter (cortex, hippocampus, amggdala, ventral DC), thalamus, ganglia (nucleus (putamen, pallidus, accumbens), caudate) + 'brighter': [], # skull, pineal gland, choroid plexus +} diff --git a/utils/multiprocessing.py b/utils/multiprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..c847c49eaa66a08a9e723e53fca53d2f1341848b --- /dev/null +++ b/utils/multiprocessing.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +"""Multiprocessing helpers.""" + +import torch + + +def run( + local_rank, + num_proc, + func, + init_method, + shard_id, + num_shards, + backend, + cfgs, + output_queue=None, +): + """ + Runs a function from a child process. + Args: + local_rank (int): rank of the current process on the current machine. + num_proc (int): number of processes per machine. + func (function): function to execute on each of the process. + init_method (string): method to initialize the distributed training. + TCP initialization: equiring a network address reachable from all + processes followed by the port. + Shared file-system initialization: makes use of a file system that + is shared and visible from all machines. 
The URL should start with
            file:// and contain a path to a non-existent file on a shared file
            system.
        shard_id (int): the rank of the current machine.
        num_shards (int): number of overall machines for the distributed
            training job.
        backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
            supports, each with different capabilities. Details can be found
            here:
            https://pytorch.org/docs/stable/distributed.html
        cfg (CfgNode): list of configs. Details can be found in
            slowfast/config/defaults.py
        output_queue (queue): can optionally be used to return values from the
            master process.
    """
    # Initialize the process group.
    # Global rank = machine offset (shard_id * num_proc) + local rank.
    world_size = num_proc * num_shards
    rank = shard_id * num_proc + local_rank

    try:
        torch.distributed.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
    except Exception as e:
        # NOTE(review): `raise e` re-raises unchanged; kept for parity with upstream.
        raise e

    torch.cuda.set_device(local_rank)
    ret = func(cfgs)
    # Only the machine-local rank-0 process reports its result.
    if output_queue is not None and local_rank == 0:
        output_queue.put(ret)
diff --git a/utils/plot_utils.py b/utils/plot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bd430f8d084fdb53fb00831213e608c1ced9b98
--- /dev/null
+++ b/utils/plot_utils.py
@@ -0,0 +1,122 @@
"""
Plotting utilities to visualize training logs.
"""
from pathlib import Path, PurePath

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas


def fig_to_numpy(fig):
    # Rasterize a matplotlib figure into an (H, W, 3) uint8 RGB array.
    # NOTE(review): canvas.tostring_rgb() was removed in matplotlib 3.10
    # (use buffer_rgba) -- confirm the pinned matplotlib version.
    w, h = fig.get_size_inches() * fig.dpi
    w = int(w.item())
    h = int(h.item())
    canvas = FigureCanvas(fig)
    canvas.draw()
    numpy_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
    return np.copy(numpy_image)


def get_vis_win_names(vis_dict):
    # Extract the visdom window name (`.win`) from a two-level dict of plots.
    vis_win_names = {
        outer_k: {
            inner_k: inner_v.win
            for inner_k, inner_v in outer_v.items()
        }
        for outer_k, outer_v in vis_dict.items()
    }
    return vis_win_names


def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.

    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field.
              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
              - log_name = optional, name of log file if different than default 'log.txt'.

    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
               - solid lines are training results, dashed lines are test results.
+ + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # verify valid dir(s) and that every item in list is Path object + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if dir.exists(): + continue + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + 
# precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs diff --git a/utils/process_cfg.py b/utils/process_cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..94a9390ebf9d1a78b3b8b606da8eb04a13a39f5a --- /dev/null +++ b/utils/process_cfg.py @@ -0,0 +1,69 @@ +"""Wrapper to train/test models.""" + +import os +import pytz +from datetime import datetime + +from utils.config import Config + +def update_config(cfg, exp_name='', job_name=''): + """ + Update some configs. + Args: + cfg: from submit_config.config + """ + tz_NY = pytz.timezone('America/New_York') + + if 'lemon' in cfg.out_root: + cfg.out_dir = os.path.join(cfg.root_dir_lemon, cfg.out_dir) + else: + cfg.out_dir = os.path.join(cfg.root_dir_yogurt_out, cfg.out_dir) + + cfg.vis_itr = int(cfg.vis_itr) + + + if cfg.eval_only: + cfg.out_dir = os.path.join(cfg.out_dir, 'Test', exp_name, job_name, datetime.now(tz_NY).strftime("%m%d-%H%M")) + else: + cfg.out_dir = os.path.join(cfg.out_dir, exp_name, job_name, datetime.now(tz_NY).strftime("%m%d-%H%M")) + return cfg + + +def merge_and_update_from_dict(cfg, dct): + """ + (Compatible for submitit's Dict as attribute trick) + Merge dict as dict() to config as CfgNode(). 
+ Args: + cfg: dict + dct: dict + """ + if dct is not None: + for key, value in dct.items(): + if isinstance(value, dict): + if key in cfg.keys(): + sub_cfgnode = cfg[key] + else: + sub_cfgnode = dict() + cfg.__setattr__(key, sub_cfgnode) + sub_cfgnode = merge_and_update_from_dict(sub_cfgnode, value) + else: + cfg[key] = value + return cfg + + +def load_config(default_cfg_file, add_cfg_files = [], cfg_dir = ''): + cfg = Config(default_cfg_file) + for cfg_file in add_cfg_files: + if os.path.isabs(cfg_file): + add_cfg = Config(cfg_file) + else: + assert os.path.isabs(cfg_dir) + if not cfg_file.endswith('.yaml'): + cfg_file += '.yaml' + add_cfg = Config(os.path.join(cfg_dir, cfg_file)) + cfg = merge_and_update_from_dict(cfg, add_cfg) + if "exp_name" in cfg: + return update_config(cfg, exp_name=cfg["exp_name"], job_name = cfg["job_name"]) + else: + return cfg + diff --git a/utils/test_utils.py b/utils/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9595bd7b64268cc6c46abd2e9eed33e2d3903b66 --- /dev/null +++ b/utils/test_utils.py @@ -0,0 +1,404 @@ +import os +import numpy as np +import torch + +from Generator.utils import fast_3D_interp_torch, myzoom_torch +from Trainer.models import build_model, build_inpaint_model +from utils.checkpoint import load_checkpoint +import utils.misc as utils + +device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + + +# default & gpu cfg # + +#submit_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/submit.yaml' +#default_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/default.yaml' + +#default_train_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/default_train.yaml' +#default_val_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/default_val.yaml' + +#gen_cfg_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test' +#train_cfg_dir = 
def zero_crop(orig, tol = 0, crop_range_lst = None, save_path = None):
    """Crop a 3-D volume to the bounding box of its above-threshold voxels.

    Args:
        orig: 3-D tensor to crop.
        tol: voxels with value > tol count as foreground.
        crop_range_lst: optional precomputed [[x0, y0, z0], [x1, y1, z1]]
            box (upper bounds exclusive); when given, tol is ignored and
            the box is applied directly.
        save_path: unused; kept for interface compatibility.

    Returns:
        The cropped view ``orig[x0:x1, y0:y1, z0:z1]``.
    """
    if crop_range_lst is not None:
        (x0, y0, z0), (x1, y1, z1) = crop_range_lst
    else:
        # Bounding box of foreground voxels; +1 makes the upper bound
        # exclusive so it can be used directly as a slice end.
        fg_coords = torch.argwhere(orig > tol)
        x0, y0, z0 = fg_coords.min(dim=0)[0]
        x1, y1, z1 = fg_coords.max(dim=0)[0] + 1

    return orig[x0:x1, y0:y1, z0:z1]
def center_crop(img, win_size = [220, 220, 220], zero_crop_first = False, aff=np.eye(4)):
    """Center-crop a volume to ``win_size``, tracking the affine.

    Args:
        img: 3-D tensor (s, r, c) or 4-D tensor (s, r, c, channels); a
            4-D input is moved to channel-first internally and moved back
            on return.
        win_size: target spatial size, or None to skip cropping.
        zero_crop_first: crop away all-zero borders before center-cropping.
        aff: 4x4 voxel-to-world affine; its translation is updated by the
            crop offset.  The caller's array is never mutated (a copy is
            taken).

    Returns:
        (volume, crop_start, orig_shp, aff) where volume is 5-D
        (1, C, s, r, c) -- or (1, s, r, c, C) for a 4-D input -- and
        crop_start is the crop offset into the (possibly zero-cropped)
        input ([0, 0, 0] when no crop was applied).
    """
    # BUGFIX: the original mutated `aff` in place -- including the shared
    # np.eye(4) default argument, and prepare_image passes the same array
    # in repeatedly, compounding the translation.  Work on a copy.
    aff = np.array(aff, copy=True)

    # center crop
    if len(img.shape) == 4:
        img = torch.permute(img, (3, 0, 1, 2))  # (move last dim to first)
        img = img[None]
        permuted = True
    else:
        assert len(img.shape) == 3
        img = img[None, None]
        permuted = False

    orig_shp = img.shape[2:]  # spatial dims of the (1, C, s, r, c) tensor

    # first, crop all zeros -> get *actual* shape
    # NOTE(review): img[0, 0] drops all but the first channel for permuted
    # (4-D) inputs -- confirm zero_crop_first is only used with 3-D inputs.
    if zero_crop_first:
        print(' before zero croppping:', orig_shp)
        img = zero_crop(img[0, 0])[None, None]
        orig_shp = img.shape[2:]
        print(' after zero croppping:', orig_shp)

    if win_size is None:
        if permuted:
            # BUGFIX: this return was missing `aff`, handing callers a
            # 3-tuple while every other path returns a 4-tuple.
            return torch.permute(img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp, aff
        return img, [0, 0, 0], orig_shp, aff

    elif orig_shp[0] > win_size[0] or orig_shp[1] > win_size[1] or orig_shp[2] > win_size[2]:
        crop_start = [max((orig_shp[i] - win_size[i]), 0) // 2 for i in range(3)]
        # Shift the world-space origin by the crop offset.
        aff[:-1, -1] = aff[:-1, -1] + aff[:-1, :-1] @ np.array(crop_start)
        crop_img = img[:, :, crop_start[0] : crop_start[0] + win_size[0],
                             crop_start[1] : crop_start[1] + win_size[1],
                             crop_start[2] : crop_start[2] + win_size[2]]
        if permuted:
            # NOTE(review): crop_start is reported as [0, 0, 0] here but as
            # the real offset on the non-permuted path -- confirm intended.
            return torch.permute(crop_img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp, aff
        return crop_img, crop_start, orig_shp, aff

    else:
        if permuted:
            return torch.permute(img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp, aff
        return img, [0, 0, 0], orig_shp, aff
def read_image(img_path, is_label = False, device = 'cpu'):
    """Load a volume from disk as a torch tensor.

    Labels are read and stored as integers, images as float32; singleton
    dimensions are squeezed and NaNs replaced with zeros.  The affine
    returned by MRIread is discarded.
    """
    if is_label:
        read_dtype, torch_dtype = 'int', torch.int
    else:
        read_dtype, torch_dtype = 'float', torch.float32

    volume, _ = utils.MRIread(img_path, im_only=False, dtype=read_dtype)
    volume = torch.tensor(np.squeeze(volume), dtype=torch_dtype, device=device)
    return torch.nan_to_num(volume)
@torch.no_grad()
def evaluate_image(inputs, ckp_path, feature_only = True, device = 'cpu', gen_cfg = None, model_cfg = None):
    """Run the feature model on an in-memory batch.

    Args:
        inputs: torch.Tensor -- (batch_size, 1, s, r, c) per the original
            author's note; confirm with callers.
        ckp_path: checkpoint file loaded into the model (key 'model').
        feature_only: when True return only the last feature map,
            (batch_size, 64, s, r, c); otherwise the full output dict.
        gen_cfg / model_cfg: optional config files merged on top of the
            module-level defaults.
    """
    # ============ prepare ... ============
    # Generator/trainer configs: defaults plus the optional overrides.
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)

    batch = [{'input': inputs}]

    # ============ testing ... ============
    # Build the model, restore weights, then a single forward pass.
    gen_args, train_args, feat_model, processors, criterion, postprocessor = build_model(gen_args, train_args, device)
    load_checkpoint(ckp_path, [feat_model], model_keys = ['model'], to_print = False)
    outputs, _ = feat_model(batch)  # dict with features

    for proc in processors:
        outputs = proc(outputs, batch)
    if postprocessor is not None:
        outputs, _, _ = postprocessor(gen_args, train_args, outputs, batch, target = None, feats = None, tasks = gen_args.tasks)

    first = outputs[0]
    return first['feat'][-1] if feature_only else first
@torch.no_grad()
def evaluate_path(input_paths, save_dir, ckp_path, win_size = [220, 220, 220],
                  save_input = False, aux_paths = {}, save_aux = False, exclude_keys = [],
                  mask_output = False, ext = '.nii.gz', device = 'cpu',
                  gen_cfg = None, model_cfg = None):
    """Run the model over a list of image files and save outputs to disk.

    Each input is loaded and center-cropped to ``win_size`` via
    prepare_image, passed through the model built from the module-level
    default configs (plus ``gen_cfg``/``model_cfg`` overrides), and every
    tensor output whose key is not in ``exclude_keys`` is written to
    ``save_dir/<input basename>/out_<key><ext>``.

    Args:
        input_paths: list of image paths to evaluate.
        save_dir: root output directory; one subdirectory per input.
        ckp_path: checkpoint loaded into the model (key 'model').
        win_size: center-crop size handed to prepare_image.
        save_input: also save the prepared input volume as 'input'.
        aux_paths: dict name -> list of auxiliary file paths, aligned by
            index with input_paths; names containing 'label' are loaded
            as label volumes.
        save_aux: also save the prepared auxiliary volumes.
        exclude_keys: output dict keys to skip when saving.
        mask_output: multiply outputs by a foreground mask -- taken from
            an aux input whose name contains 'mask', else from the
            nonzero voxels of the prepared input.
        ext: output file extension.
        gen_cfg / model_cfg: optional config overrides.

    NOTE(review): aux_paths={} and exclude_keys=[] are mutable default
    arguments; safe only while they are never mutated here.
    """

    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)

    # ============ loading ... ============
    gen_args, train_args, model, processors, criterion, postprocessor = build_model(gen_args, train_args, device)
    load_checkpoint(ckp_path, [model], model_keys = ['model'], to_print = False)

    for i, input_path in enumerate(input_paths):
        print('Now testing: %s (%d/%d)' % (input_path, i+1, len(input_paths)))
        print(' ckp:', ckp_path)
        # One output subdirectory per input, named after the file stem.
        curr_save_dir = utils.make_dir(os.path.join(save_dir, os.path.basename(input_path).split('.nii')[0]))

        # ============ prepare ... ============
        mask = None
        im, orig, high_res, bf, aff, crop_start, orig_shp = prepare_image(input_path, win_size, device = device)
        if save_input:
            print(' Input: saved in - %s' % (os.path.join(curr_save_dir, 'input' + ext)))
            utils.viewVolume(im, aff, names = ['input'], ext = ext, save_dir = curr_save_dir)
        for k in aux_paths.keys():
            # Aux volumes are prepared identically to the input (same crop).
            im_k, _, _, _, _, _, _ = prepare_image(aux_paths[k][i], win_size, is_label = 'label' in k, device = device)
            if save_aux:
                print(' Aux input: %s - saved in - %s' % (k, os.path.join(curr_save_dir, k + ext)))
                utils.viewVolume(im_k, aff, names = [k], ext = ext, save_dir = curr_save_dir)
            if mask_output and 'mask' in k:
                # Binarize the aux 'mask' volume for masking the outputs.
                mask = im_k.clone()
                mask[im_k != 0.] = 1.
        samples = [ { 'input': im } ]

        # ============ testing ... ============
        outputs, _ = model(samples) # dict with features

        for processor in processors:
            outputs = processor(outputs, samples)
        if postprocessor is not None:
            outputs, _, _ = postprocessor(gen_args, train_args, outputs, samples, target = None, feats = None, tasks = gen_args.tasks)

        out = outputs[0]
        if mask_output and mask is None:
            # Fallback mask: nonzero voxels of the prepared input.
            mask = torch.zeros_like(im)
            mask[im != 0.] = 1.
        for key in out.keys():
            if key not in exclude_keys and isinstance(out[key], torch.Tensor):
                print(' Output: %s - saved in - %s' % (key, os.path.join(curr_save_dir, 'out_' + key + ext)))
                # Clamp negatives to zero before saving.
                out[key][out[key] < 0.] = 0.
                utils.viewVolume(out[key] * mask if mask_output else out[key], aff, names = ['out_'+key], ext = ext, save_dir = curr_save_dir)