Upload 187 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +14 -0
- Generator/__init__.py +22 -0
- Generator/config.py +181 -0
- Generator/constants.py +290 -0
- Generator/datasets.py +757 -0
- Generator/interpol/__init__.py +7 -0
- Generator/interpol/_version.py +623 -0
- Generator/interpol/api.py +560 -0
- Generator/interpol/autograd.py +301 -0
- Generator/interpol/backend.py +1 -0
- Generator/interpol/bounds.py +89 -0
- Generator/interpol/coeff.py +344 -0
- Generator/interpol/iso0.py +368 -0
- Generator/interpol/iso1.py +1339 -0
- Generator/interpol/jit_utils.py +443 -0
- Generator/interpol/jitfields.py +95 -0
- Generator/interpol/nd.py +464 -0
- Generator/interpol/pushpull.py +325 -0
- Generator/interpol/resize.py +120 -0
- Generator/interpol/restrict.py +122 -0
- Generator/interpol/splines.py +196 -0
- Generator/interpol/tests/__init__.py +0 -0
- Generator/interpol/tests/test_gradcheck_pushpull.py +125 -0
- Generator/interpol/utils.py +176 -0
- Generator/utils.py +669 -0
- README.md +91 -3
- ShapeID/DiffEqs/FD.py +525 -0
- ShapeID/DiffEqs/adams.py +170 -0
- ShapeID/DiffEqs/adjoint.py +133 -0
- ShapeID/DiffEqs/dopri5.py +172 -0
- ShapeID/DiffEqs/fixed_adams.py +211 -0
- ShapeID/DiffEqs/fixed_grid.py +33 -0
- ShapeID/DiffEqs/interp.py +65 -0
- ShapeID/DiffEqs/misc.py +195 -0
- ShapeID/DiffEqs/odeint.py +75 -0
- ShapeID/DiffEqs/pde.py +643 -0
- ShapeID/DiffEqs/rk_common.py +78 -0
- ShapeID/DiffEqs/solvers.py +216 -0
- ShapeID/DiffEqs/tsit5.py +139 -0
- ShapeID/__init__.py +1 -0
- ShapeID/demo2d.py +102 -0
- ShapeID/demo3d.py +91 -0
- ShapeID/misc.py +261 -0
- ShapeID/out/2d/V.png +3 -0
- ShapeID/out/2d/curl.png +0 -0
- ShapeID/out/2d/image.png +0 -0
- ShapeID/out/2d/image_with_v.png +3 -0
- ShapeID/out/2d/mask_curl.png +0 -0
- ShapeID/out/2d/mask_image.png +0 -0
- ShapeID/out/2d/progression/New Folder With Items/0.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/overview.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
ShapeID/out/2d/image_with_v.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/0.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/1.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/10.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/2.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/3.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/4.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/5.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/6.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/7.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/8.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/9.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
ShapeID/out/2d/V.png filter=lfs diff=lfs merge=lfs -text
|
Generator/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""
|
| 3 |
+
Datasets interface.
|
| 4 |
+
"""
|
| 5 |
+
from .constants import dataset_setups
|
| 6 |
+
from .datasets import BaseGen, BrainIDGen
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
dataset_options = {
|
| 11 |
+
'default': BaseGen,
|
| 12 |
+
'brain_id': BrainIDGen,
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def build_datasets(gen_args, device):
|
| 19 |
+
"""Helper function to build dataset for different splits ('train' or 'test')."""
|
| 20 |
+
datasets = {'all': dataset_options[gen_args.dataset_option](gen_args, device)}
|
| 21 |
+
return datasets
|
| 22 |
+
|
Generator/config.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""Config utilities for yml file."""
|
| 3 |
+
import os
|
| 4 |
+
from argparse import Namespace
|
| 5 |
+
import collections
|
| 6 |
+
import functools
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
import yaml
|
| 11 |
+
# from imaginaire.utils.distributed import master_only_print as print
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AttrDict(dict):
|
| 15 |
+
"""Dict as attribute trick."""
|
| 16 |
+
|
| 17 |
+
def __init__(self, *args, **kwargs):
|
| 18 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
| 19 |
+
self.__dict__ = self
|
| 20 |
+
for key, value in self.__dict__.items():
|
| 21 |
+
if isinstance(value, dict):
|
| 22 |
+
self.__dict__[key] = AttrDict(value)
|
| 23 |
+
elif isinstance(value, (list, tuple)):
|
| 24 |
+
if isinstance(value[0], dict):
|
| 25 |
+
self.__dict__[key] = [AttrDict(item) for item in value]
|
| 26 |
+
else:
|
| 27 |
+
self.__dict__[key] = value
|
| 28 |
+
|
| 29 |
+
def yaml(self):
|
| 30 |
+
"""Convert object to yaml dict and return."""
|
| 31 |
+
yaml_dict = {}
|
| 32 |
+
for key, value in self.__dict__.items():
|
| 33 |
+
if isinstance(value, AttrDict):
|
| 34 |
+
yaml_dict[key] = value.yaml()
|
| 35 |
+
elif isinstance(value, list):
|
| 36 |
+
if isinstance(value[0], AttrDict):
|
| 37 |
+
new_l = []
|
| 38 |
+
for item in value:
|
| 39 |
+
new_l.append(item.yaml())
|
| 40 |
+
yaml_dict[key] = new_l
|
| 41 |
+
else:
|
| 42 |
+
yaml_dict[key] = value
|
| 43 |
+
else:
|
| 44 |
+
yaml_dict[key] = value
|
| 45 |
+
return yaml_dict
|
| 46 |
+
|
| 47 |
+
def __repr__(self):
|
| 48 |
+
"""Print all variables."""
|
| 49 |
+
ret_str = []
|
| 50 |
+
for key, value in self.__dict__.items():
|
| 51 |
+
if isinstance(value, AttrDict):
|
| 52 |
+
ret_str.append('{}:'.format(key))
|
| 53 |
+
child_ret_str = value.__repr__().split('\n')
|
| 54 |
+
for item in child_ret_str:
|
| 55 |
+
ret_str.append(' ' + item)
|
| 56 |
+
elif isinstance(value, list):
|
| 57 |
+
if isinstance(value[0], AttrDict):
|
| 58 |
+
ret_str.append('{}:'.format(key))
|
| 59 |
+
for item in value:
|
| 60 |
+
# Treat as AttrDict above.
|
| 61 |
+
child_ret_str = item.__repr__().split('\n')
|
| 62 |
+
for item in child_ret_str:
|
| 63 |
+
ret_str.append(' ' + item)
|
| 64 |
+
else:
|
| 65 |
+
ret_str.append('{}: {}'.format(key, value))
|
| 66 |
+
else:
|
| 67 |
+
ret_str.append('{}: {}'.format(key, value))
|
| 68 |
+
return '\n'.join(ret_str)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Config(AttrDict):
|
| 72 |
+
r"""Configuration class. This should include every human specifiable
|
| 73 |
+
hyperparameter values for your training."""
|
| 74 |
+
|
| 75 |
+
def __init__(self, filename=None, verbose=False):
|
| 76 |
+
super(Config, self).__init__()
|
| 77 |
+
|
| 78 |
+
# Update with given configurations.
|
| 79 |
+
if os.path.exists(filename):
|
| 80 |
+
|
| 81 |
+
loader = yaml.SafeLoader
|
| 82 |
+
loader.add_implicit_resolver(
|
| 83 |
+
u'tag:yaml.org,2002:float',
|
| 84 |
+
re.compile(u'''^(?:
|
| 85 |
+
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
| 86 |
+
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
| 87 |
+
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
| 88 |
+
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
| 89 |
+
|[-+]?\\.(?:inf|Inf|INF)
|
| 90 |
+
|\\.(?:nan|NaN|NAN))$''', re.X),
|
| 91 |
+
list(u'-+0123456789.'))
|
| 92 |
+
try:
|
| 93 |
+
with open(filename, 'r') as f:
|
| 94 |
+
cfg_dict = yaml.load(f, Loader=loader)
|
| 95 |
+
except EnvironmentError:
|
| 96 |
+
print('Please check the file with name of "%s"', filename)
|
| 97 |
+
recursive_update(self, cfg_dict)
|
| 98 |
+
else:
|
| 99 |
+
raise ValueError('Provided config path not existed: %s' % filename)
|
| 100 |
+
|
| 101 |
+
if verbose:
|
| 102 |
+
print(' imaginaire config '.center(80, '-'))
|
| 103 |
+
print(self.__repr__())
|
| 104 |
+
print(''.center(80, '-'))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def rsetattr(obj, attr, val):
|
| 108 |
+
"""Recursively find object and set value"""
|
| 109 |
+
pre, _, post = attr.rpartition('.')
|
| 110 |
+
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def rgetattr(obj, attr, *args):
|
| 114 |
+
"""Recursively find object and return value"""
|
| 115 |
+
|
| 116 |
+
def _getattr(obj, attr):
|
| 117 |
+
r"""Get attribute."""
|
| 118 |
+
return getattr(obj, attr, *args)
|
| 119 |
+
|
| 120 |
+
return functools.reduce(_getattr, [obj] + attr.split('.'))
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def recursive_update(d, u):
|
| 124 |
+
"""Recursively update AttrDict d with AttrDict u"""
|
| 125 |
+
if u is not None:
|
| 126 |
+
for key, value in u.items():
|
| 127 |
+
if isinstance(value, collections.abc.Mapping):
|
| 128 |
+
d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
|
| 129 |
+
elif isinstance(value, (list, tuple)):
|
| 130 |
+
if len(value) > 0 and isinstance(value[0], dict):
|
| 131 |
+
d.__dict__[key] = [AttrDict(item) for item in value]
|
| 132 |
+
else:
|
| 133 |
+
d.__dict__[key] = value
|
| 134 |
+
else:
|
| 135 |
+
d.__dict__[key] = value
|
| 136 |
+
return d
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def merge_and_update_from_dict(cfg, dct):
|
| 140 |
+
"""
|
| 141 |
+
(Compatible for submitit's Dict as attribute trick)
|
| 142 |
+
Merge dict as dict() to config as CfgNode().
|
| 143 |
+
Args:
|
| 144 |
+
cfg: dict
|
| 145 |
+
dct: dict
|
| 146 |
+
"""
|
| 147 |
+
if dct is not None:
|
| 148 |
+
for key, value in dct.items():
|
| 149 |
+
if isinstance(value, dict):
|
| 150 |
+
if key in cfg.keys():
|
| 151 |
+
sub_cfgnode = cfg[key]
|
| 152 |
+
else:
|
| 153 |
+
sub_cfgnode = dict()
|
| 154 |
+
cfg.__setattr__(key, sub_cfgnode)
|
| 155 |
+
sub_cfgnode = merge_and_update_from_dict(sub_cfgnode, value)
|
| 156 |
+
else:
|
| 157 |
+
cfg[key] = value
|
| 158 |
+
return cfg
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def load_config(cfg_files = [], cfg_dir = ''):
|
| 162 |
+
cfg = Config(cfg_files[0])
|
| 163 |
+
for cfg_file in cfg_files[1:]:
|
| 164 |
+
add_cfg = Config(cfg_file)
|
| 165 |
+
cfg = merge_and_update_from_dict(cfg, add_cfg)
|
| 166 |
+
return cfg
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def nested_dict_to_namespace(dictionary):
|
| 170 |
+
namespace = dictionary
|
| 171 |
+
if isinstance(dictionary, dict):
|
| 172 |
+
namespace = Namespace(**dictionary)
|
| 173 |
+
for key, value in dictionary.items():
|
| 174 |
+
setattr(namespace, key, nested_dict_to_namespace(value))
|
| 175 |
+
return namespace
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def preprocess_cfg(cfg_files, cfg_dir = ''):
|
| 179 |
+
config = load_config(cfg_files, cfg_dir)
|
| 180 |
+
args = nested_dict_to_namespace(config)
|
| 181 |
+
return args
|
Generator/constants.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, glob
|
| 2 |
+
|
| 3 |
+
from .utils import *
|
| 4 |
+
|
| 5 |
+
augmentation_funcs = {
|
| 6 |
+
'gamma': add_gamma_transform,
|
| 7 |
+
'bias_field': add_bias_field,
|
| 8 |
+
'resample': resample_resolution,
|
| 9 |
+
'noise': add_noise,
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
processing_funcs = {
|
| 13 |
+
'T1': read_and_deform_image,
|
| 14 |
+
'T2': read_and_deform_image,
|
| 15 |
+
'FLAIR': read_and_deform_image,
|
| 16 |
+
'CT': read_and_deform_CT,
|
| 17 |
+
'segmentation': read_and_deform_segmentation,
|
| 18 |
+
'surface': read_and_deform_surface,
|
| 19 |
+
'distance': read_and_deform_distance,
|
| 20 |
+
'bias_field': read_and_deform_bias_field,
|
| 21 |
+
'registration': read_and_deform_registration,
|
| 22 |
+
'pathology': read_and_deform_pathology,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
dataset_setups = {
|
| 27 |
+
|
| 28 |
+
'ADHD': {
|
| 29 |
+
'root': '/autofs/space/yogurt_001/users/pl629/data/adhd200_crop',
|
| 30 |
+
'pathology_type': None,
|
| 31 |
+
'train': 'train.txt',
|
| 32 |
+
'test': 'test.txt',
|
| 33 |
+
'modalities': ['T1'],
|
| 34 |
+
|
| 35 |
+
'paths':{
|
| 36 |
+
# for synth
|
| 37 |
+
'Gen': 'label_maps_generation',
|
| 38 |
+
'Dmaps': None,
|
| 39 |
+
'DmapsBag': None,
|
| 40 |
+
|
| 41 |
+
# real images
|
| 42 |
+
'T1': 'T1',
|
| 43 |
+
'T2': None,
|
| 44 |
+
'FLAIR': None,
|
| 45 |
+
'CT': None,
|
| 46 |
+
|
| 47 |
+
# processed ground truths
|
| 48 |
+
'surface': None, #'surfaces', TODO
|
| 49 |
+
'distance': None,
|
| 50 |
+
'segmentation': 'label_maps_segmentation',
|
| 51 |
+
'bias_field': None,
|
| 52 |
+
'pathology': None,
|
| 53 |
+
'pathology_prob': None,
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
|
| 57 |
+
'HCP': {
|
| 58 |
+
'root': '/autofs/space/yogurt_001/users/pl629/data/hcp_crop',
|
| 59 |
+
'pathology_type': None,
|
| 60 |
+
'train': 'train.txt',
|
| 61 |
+
'test': 'test.txt',
|
| 62 |
+
'modalities': ['T1', 'T2'],
|
| 63 |
+
|
| 64 |
+
'paths':{
|
| 65 |
+
# for synth
|
| 66 |
+
'Gen': 'label_maps_generation',
|
| 67 |
+
'Dmaps': None,
|
| 68 |
+
'DmapsBag': None,
|
| 69 |
+
|
| 70 |
+
# real images
|
| 71 |
+
'T1': 'T1',
|
| 72 |
+
'T2': 'T2',
|
| 73 |
+
'FLAIR': None,
|
| 74 |
+
'CT': None,
|
| 75 |
+
|
| 76 |
+
# processed ground truths
|
| 77 |
+
'surface': None, #'surfaces',
|
| 78 |
+
'distance': None,
|
| 79 |
+
'segmentation': 'label_maps_segmentation',
|
| 80 |
+
'bias_field': None,
|
| 81 |
+
'pathology': None,
|
| 82 |
+
'pathology_prob': None,
|
| 83 |
+
}
|
| 84 |
+
},
|
| 85 |
+
|
| 86 |
+
'AIBL': {
|
| 87 |
+
'root': '/autofs/space/yogurt_001/users/pl629/data/aibl_crop',
|
| 88 |
+
'pathology_type': None,
|
| 89 |
+
'train': 'train.txt',
|
| 90 |
+
'test': 'test.txt',
|
| 91 |
+
'modalities': ['T1', 'T2', 'FLAIR'],
|
| 92 |
+
|
| 93 |
+
'paths':{
|
| 94 |
+
# for synth
|
| 95 |
+
'Gen': 'label_maps_generation',
|
| 96 |
+
'Dmaps': None,
|
| 97 |
+
'DmapsBag': None,
|
| 98 |
+
|
| 99 |
+
# real images
|
| 100 |
+
'T1': 'T1',
|
| 101 |
+
'T2': 'T2',
|
| 102 |
+
'FLAIR': 'FLAIR',
|
| 103 |
+
'CT': None,
|
| 104 |
+
|
| 105 |
+
# processed ground truths
|
| 106 |
+
'surface': None, #'surfaces',
|
| 107 |
+
'distance': None,
|
| 108 |
+
'segmentation': 'label_maps_segmentation',
|
| 109 |
+
'bias_field': None,
|
| 110 |
+
'pathology': None,
|
| 111 |
+
'pathology_prob': None,
|
| 112 |
+
}
|
| 113 |
+
},
|
| 114 |
+
|
| 115 |
+
'OASIS': {
|
| 116 |
+
'root': '/autofs/space/yogurt_001/users/pl629/data/oasis3',
|
| 117 |
+
'pathology_type': None,
|
| 118 |
+
'train': 'train.txt',
|
| 119 |
+
'test': 'test.txt',
|
| 120 |
+
'modalities': ['T1', 'CT'],
|
| 121 |
+
|
| 122 |
+
'paths':{
|
| 123 |
+
# for synth
|
| 124 |
+
'Gen': 'label_maps_generation',
|
| 125 |
+
'Dmaps': None,
|
| 126 |
+
'DmapsBag': None,
|
| 127 |
+
|
| 128 |
+
# real images
|
| 129 |
+
'T1': 'T1',
|
| 130 |
+
'T2': None,
|
| 131 |
+
'FLAIR': None,
|
| 132 |
+
'CT': 'CT',
|
| 133 |
+
|
| 134 |
+
# processed ground truths
|
| 135 |
+
'surface': None, #'surfaces',
|
| 136 |
+
'distance': None,
|
| 137 |
+
'segmentation': 'label_maps_segmentation',
|
| 138 |
+
'bias_field': None,
|
| 139 |
+
'pathology': None,
|
| 140 |
+
'pathology_prob': None,
|
| 141 |
+
}
|
| 142 |
+
},
|
| 143 |
+
|
| 144 |
+
'ADNI': {
|
| 145 |
+
'root': '/autofs/space/yogurt_001/users/pl629/data/adni_crop',
|
| 146 |
+
'pathology_type': None, #'wmh',
|
| 147 |
+
'train': 'train.txt',
|
| 148 |
+
'test': 'test.txt',
|
| 149 |
+
'modalities': ['T1'],
|
| 150 |
+
|
| 151 |
+
'paths':{
|
| 152 |
+
# for synth
|
| 153 |
+
'Gen': 'label_maps_generation',
|
| 154 |
+
'Dmaps': 'Dmaps',
|
| 155 |
+
'DmapsBag': 'DmapsBag',
|
| 156 |
+
|
| 157 |
+
# real images
|
| 158 |
+
'T1': 'T1',
|
| 159 |
+
'T2': None,
|
| 160 |
+
'FLAIR': None,
|
| 161 |
+
'CT': None,
|
| 162 |
+
|
| 163 |
+
# processed ground truths
|
| 164 |
+
'surface': 'surfaces',
|
| 165 |
+
'distance': 'Dmaps',
|
| 166 |
+
'segmentation': 'label_maps_segmentation',
|
| 167 |
+
'bias_field': None,
|
| 168 |
+
'pathology': 'pathology_maps_segmentation',
|
| 169 |
+
'pathology_prob': 'pathology_probability',
|
| 170 |
+
}
|
| 171 |
+
},
|
| 172 |
+
|
| 173 |
+
'ADNI3': {
|
| 174 |
+
'root': '/autofs/space/yogurt_001/users/pl629/data/adni3_crop',
|
| 175 |
+
'pathology_type': None, # 'wmh',
|
| 176 |
+
'train': 'train.txt',
|
| 177 |
+
'test': 'test.txt',
|
| 178 |
+
'modalities': ['T1', 'FLAIR'],
|
| 179 |
+
|
| 180 |
+
'paths':{
|
| 181 |
+
# for synth
|
| 182 |
+
'Gen': 'label_maps_generation',
|
| 183 |
+
'Dmaps': None,
|
| 184 |
+
'DmapsBag': None,
|
| 185 |
+
|
| 186 |
+
# real images
|
| 187 |
+
'T1': 'T1',
|
| 188 |
+
'T2': None,
|
| 189 |
+
'FLAIR': 'FLAIR',
|
| 190 |
+
'CT': None,
|
| 191 |
+
|
| 192 |
+
# processed ground truths
|
| 193 |
+
'surface': None, #'surfaces', TODO
|
| 194 |
+
'distance': None,
|
| 195 |
+
'segmentation': 'label_maps_segmentation',
|
| 196 |
+
'bias_field': None,
|
| 197 |
+
'pathology': 'pathology_maps_segmentation',
|
| 198 |
+
'pathology_prob': 'pathology_probability',
|
| 199 |
+
}
|
| 200 |
+
},
|
| 201 |
+
|
| 202 |
+
'ATLAS': {
|
| 203 |
+
'root': '/autofs/space/yogurt_001/users/pl629/data/atlas_crop',
|
| 204 |
+
'pathology_type': 'stroke',
|
| 205 |
+
'train': 'train.txt',
|
| 206 |
+
'test': 'test.txt',
|
| 207 |
+
'modalities': ['T1'],
|
| 208 |
+
|
| 209 |
+
'paths':{
|
| 210 |
+
# for synth
|
| 211 |
+
'Gen': 'label_maps_generation',
|
| 212 |
+
'Dmaps': None,
|
| 213 |
+
'DmapsBag': None,
|
| 214 |
+
|
| 215 |
+
# real images
|
| 216 |
+
'T1': 'T1',
|
| 217 |
+
'T2': None,
|
| 218 |
+
'FLAIR': None,
|
| 219 |
+
'CT': None,
|
| 220 |
+
|
| 221 |
+
# processed ground truths
|
| 222 |
+
'surface': None, #'surfaces', TODO
|
| 223 |
+
'distance': None,
|
| 224 |
+
'segmentation': 'label_maps_segmentation',
|
| 225 |
+
'bias_field': None,
|
| 226 |
+
'pathology': 'pathology_maps_segmentation',
|
| 227 |
+
'pathology_prob': 'pathology_probability',
|
| 228 |
+
}
|
| 229 |
+
},
|
| 230 |
+
|
| 231 |
+
'ISLES': {
|
| 232 |
+
'root': '/autofs/space/yogurt_001/users/pl629/data/isles2022_crop',
|
| 233 |
+
'pathology_type': 'stroke',
|
| 234 |
+
'train': 'train.txt',
|
| 235 |
+
'test': 'test.txt',
|
| 236 |
+
'modalities': ['FLAIR'],
|
| 237 |
+
|
| 238 |
+
'paths':{
|
| 239 |
+
# for synth
|
| 240 |
+
'Gen': 'label_maps_generation',
|
| 241 |
+
'Dmaps': None,
|
| 242 |
+
'DmapsBag': None,
|
| 243 |
+
|
| 244 |
+
# real images
|
| 245 |
+
'T1': None,
|
| 246 |
+
'T2': None,
|
| 247 |
+
'FLAIR': 'FLAIR',
|
| 248 |
+
'CT': None,
|
| 249 |
+
|
| 250 |
+
# processed ground truths
|
| 251 |
+
'surface': None, #'surfaces', TODO
|
| 252 |
+
'distance': None,
|
| 253 |
+
'segmentation': 'label_maps_segmentation',
|
| 254 |
+
'bias_field': None,
|
| 255 |
+
'pathology': 'pathology_maps_segmentation',
|
| 256 |
+
'pathology_prob': 'pathology_probability',
|
| 257 |
+
}
|
| 258 |
+
},
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
all_dataset_names = dataset_setups.keys()
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# get all pathologies
|
| 266 |
+
pathology_paths = []
|
| 267 |
+
pathology_prob_paths = []
|
| 268 |
+
for name, dict in dataset_setups.items():
|
| 269 |
+
# TODO: select what kind of shapes?
|
| 270 |
+
if dict['paths']['pathology'] is not None and dict['pathology_type'] is not None and dict['pathology_type'] == 'stroke':
|
| 271 |
+
pathology_paths += glob.glob(os.path.join(dict['root'], dict['paths']['pathology'], '*.nii.gz')) \
|
| 272 |
+
+ glob.glob(os.path.join(dict['root'], dict['paths']['pathology'], '*.nii'))
|
| 273 |
+
pathology_prob_paths += glob.glob(os.path.join(dict['root'], dict['paths']['pathology_prob'], '*.nii.gz')) \
|
| 274 |
+
+ glob.glob(os.path.join(dict['root'], dict['paths']['pathology_prob'], '*.nii'))
|
| 275 |
+
n_pathology = len(pathology_paths)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# with csf # NOTE old version (FreeSurfer standard), non-vast
|
| 279 |
+
label_list_segmentation = [0,14,15,16,24,77,85, 2, 3, 4, 7, 8, 10,11,12,13,17,18,26,28, 41,42,43,46,47,49,50,51,52,53,54,58,60] # 33
|
| 280 |
+
n_neutral_labels = 7
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
## NEW VAST synth
|
| 284 |
+
label_list_segmentation_brainseg_with_extracerebral = [0, 11, 12, 13, 16, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46,
|
| 285 |
+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 14, 15, 17, 47, 49, 51, 53, 55,
|
| 286 |
+
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 48, 50, 52, 54, 56]
|
| 287 |
+
n_neutral_labels_brainseg_with_extracerebral = 20
|
| 288 |
+
|
| 289 |
+
label_list_segmentation_brainseg_left = [0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 17, 31, 34, 36, 38, 40, 42]
|
| 290 |
+
|
Generator/datasets.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, glob
|
| 2 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import nibabel as nib
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from .utils import *
|
| 13 |
+
from .constants import n_pathology, pathology_paths, pathology_prob_paths, \
|
| 14 |
+
n_neutral_labels_brainseg_with_extracerebral, label_list_segmentation_brainseg_with_extracerebral, \
|
| 15 |
+
label_list_segmentation_brainseg_left, augmentation_funcs, processing_funcs
|
| 16 |
+
import utils.interpol as interpol
|
| 17 |
+
|
| 18 |
+
from utils.misc import viewVolume
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
from ShapeID.DiffEqs.pde import AdvDiffPDE
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BaseGen(Dataset):
|
| 26 |
+
"""
|
| 27 |
+
BaseGen dataset
|
| 28 |
+
"""
|
| 29 |
+
def __init__(self, gen_args, device='cpu'):
|
| 30 |
+
|
| 31 |
+
self.gen_args = gen_args
|
| 32 |
+
self.split = gen_args.split
|
| 33 |
+
|
| 34 |
+
self.synth_args = self.gen_args.generator
|
| 35 |
+
self.shape_gen_args = gen_args.pathology_shape_generator
|
| 36 |
+
self.real_image_args = gen_args.real_image_generator
|
| 37 |
+
self.synth_image_args = gen_args.synth_image_generator
|
| 38 |
+
self.augmentation_steps = vars(gen_args.augmentation_steps)
|
| 39 |
+
self.input_prob = vars(gen_args.modality_probs)
|
| 40 |
+
self.device = device
|
| 41 |
+
|
| 42 |
+
self.prepare_tasks()
|
| 43 |
+
self.prepare_paths()
|
| 44 |
+
self.prepare_grid()
|
| 45 |
+
self.prepare_one_hot()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def __len__(self):
|
| 49 |
+
return sum([len(self.names[i]) for i in range(len(self.names))])
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def idx_to_path(self, idx):
|
| 53 |
+
cnt = 0
|
| 54 |
+
for i, l in enumerate(self.datasets_len):
|
| 55 |
+
if idx >= cnt and idx < cnt + l:
|
| 56 |
+
dataset_name = self.datasets[i]
|
| 57 |
+
age = self.ages[i][os.path.basename(self.names[i][idx - cnt]).split('.T1w')[0]] if len(self.ages) > 0 else None
|
| 58 |
+
return dataset_name, vars(self.input_prob[dataset_name]), self.names[i][idx - cnt], age
|
| 59 |
+
else:
|
| 60 |
+
cnt += l
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def prepare_paths(self):
|
| 64 |
+
|
| 65 |
+
# Collect list of available images, per dataset
|
| 66 |
+
if len(self.gen_args.dataset_names) < 1:
|
| 67 |
+
datasets = []
|
| 68 |
+
g = glob.glob(os.path.join(self.gen_args.data_root, '*' + 'T1w.nii'))
|
| 69 |
+
for i in range(len(g)):
|
| 70 |
+
filename = os.path.basename(g[i])
|
| 71 |
+
dataset = filename[:filename.find('.')]
|
| 72 |
+
found = False
|
| 73 |
+
for d in datasets:
|
| 74 |
+
if dataset == d:
|
| 75 |
+
found = True
|
| 76 |
+
if found is False:
|
| 77 |
+
datasets.append(dataset)
|
| 78 |
+
print('Found ' + str(len(datasets)) + ' datasets with ' + str(len(g)) + ' scans in total')
|
| 79 |
+
else:
|
| 80 |
+
datasets = self.gen_args.dataset_names
|
| 81 |
+
print('Dataset list', datasets)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
names = []
|
| 85 |
+
if 'age' in self.tasks:
|
| 86 |
+
self.split = self.split + '_age'
|
| 87 |
+
if self.gen_args.split_root is not None:
|
| 88 |
+
split_file = open(os.path.join(self.gen_args.split_root, self.split + '.txt'), 'r')
|
| 89 |
+
split_names = []
|
| 90 |
+
for subj in split_file.readlines():
|
| 91 |
+
split_names.append(subj.strip())
|
| 92 |
+
|
| 93 |
+
for i in range(len(datasets)):
|
| 94 |
+
names.append([name for name in split_names if os.path.basename(name).startswith(datasets[i])])
|
| 95 |
+
#else:
|
| 96 |
+
# for i in range(len(datasets)):
|
| 97 |
+
# names.append(glob.glob(os.path.join(self.gen_args.data_root, datasets[i] + '.*' + 'T1w.nii')))
|
| 98 |
+
|
| 99 |
+
# read brain age
|
| 100 |
+
ages = []
|
| 101 |
+
if 'age' in self.tasks:
|
| 102 |
+
age_file = open(os.path.join(self.gen_args.split_root, 'participants_age.txt'), 'r')
|
| 103 |
+
subj_name_age = []
|
| 104 |
+
for line in age_file.readlines(): # 'subj age\n'
|
| 105 |
+
subj_name_age.append(line.strip().split(' '))
|
| 106 |
+
for i in range(len(datasets)):
|
| 107 |
+
ages.append({})
|
| 108 |
+
for [name, age] in subj_name_age:
|
| 109 |
+
if name.startswith(datasets[i]):
|
| 110 |
+
ages[-1][name] = float(age)
|
| 111 |
+
print('Age info', self.split, len(ages[0].items()), min(ages[0].values()), max(ages[0].values()))
|
| 112 |
+
|
| 113 |
+
self.ages = ages
|
| 114 |
+
self.names = names
|
| 115 |
+
self.datasets = datasets
|
| 116 |
+
self.datasets_num = len(datasets)
|
| 117 |
+
self.datasets_len = [len(self.names[i]) for i in range(len(self.names))]
|
| 118 |
+
print('Num of data', sum([len(self.names[i]) for i in range(len(self.names))]))
|
| 119 |
+
|
| 120 |
+
self.pathology_type = None #setup_dict['pathology_type']
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def prepare_tasks(self):
|
| 124 |
+
self.tasks = [key for (key, value) in vars(self.gen_args.task).items() if value]
|
| 125 |
+
if 'bias_field' in self.tasks and 'segmentation' not in self.tasks:
|
| 126 |
+
# add segmentation mask for computing bias_field_soft_mask
|
| 127 |
+
self.tasks += ['segmentation']
|
| 128 |
+
if 'pathology' in self.tasks and self.synth_args.augment_pathology and self.synth_args.random_shape_prob < 1.:
|
| 129 |
+
self.t = torch.from_numpy(np.arange(self.shape_gen_args.max_nt) * self.shape_gen_args.dt).to(self.device)
|
| 130 |
+
with torch.no_grad():
|
| 131 |
+
self.adv_pde = AdvDiffPDE(data_spacing=[1., 1., 1.],
|
| 132 |
+
perf_pattern='adv',
|
| 133 |
+
V_type='vector_div_free',
|
| 134 |
+
V_dict={},
|
| 135 |
+
BC=self.shape_gen_args.bc,
|
| 136 |
+
dt=self.shape_gen_args.dt,
|
| 137 |
+
device=self.device
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
self.t, self.adv_pde = None, None
|
| 141 |
+
for task_name in self.tasks:
|
| 142 |
+
if task_name not in processing_funcs.keys():
|
| 143 |
+
print('Warning: Function for task "%s" not found' % task_name)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def prepare_grid(self):
|
| 147 |
+
self.size = self.synth_args.size
|
| 148 |
+
|
| 149 |
+
# Get resolution of training data
|
| 150 |
+
#aff = nib.load(os.path.join(self.modalities['Gen'], self.names[0])).affine
|
| 151 |
+
#self.res_training_data = np.sqrt(np.sum(abs(aff[:-1, :-1]), axis=0))
|
| 152 |
+
|
| 153 |
+
self.res_training_data = np.array([1.0, 1.0, 1.0])
|
| 154 |
+
|
| 155 |
+
xx, yy, zz = np.meshgrid(range(self.size[0]), range(self.size[1]), range(self.size[2]), sparse=False, indexing='ij')
|
| 156 |
+
self.xx = torch.tensor(xx, dtype=torch.float, device=self.device)
|
| 157 |
+
self.yy = torch.tensor(yy, dtype=torch.float, device=self.device)
|
| 158 |
+
self.zz = torch.tensor(zz, dtype=torch.float, device=self.device)
|
| 159 |
+
self.c = torch.tensor((np.array(self.size) - 1) / 2, dtype=torch.float, device=self.device)
|
| 160 |
+
self.xc = self.xx - self.c[0]
|
| 161 |
+
self.yc = self.yy - self.c[1]
|
| 162 |
+
self.zc = self.zz - self.c[2]
|
| 163 |
+
return
|
| 164 |
+
|
| 165 |
+
def prepare_one_hot(self):
|
| 166 |
+
if self.synth_args.left_hemis_only:
|
| 167 |
+
n_labels = len(label_list_segmentation_brainseg_left)
|
| 168 |
+
label_list_segmentation = label_list_segmentation_brainseg_left
|
| 169 |
+
else:
|
| 170 |
+
# Matrix for one-hot encoding (includes a lookup-table)
|
| 171 |
+
n_labels = len(label_list_segmentation_brainseg_with_extracerebral)
|
| 172 |
+
label_list_segmentation = label_list_segmentation_brainseg_with_extracerebral
|
| 173 |
+
|
| 174 |
+
self.lut = torch.zeros(10000, dtype=torch.long, device=self.device)
|
| 175 |
+
for l in range(n_labels):
|
| 176 |
+
self.lut[label_list_segmentation[l]] = l
|
| 177 |
+
self.onehotmatrix = torch.eye(n_labels, dtype=torch.float, device=self.device)
|
| 178 |
+
|
| 179 |
+
# useless for left_hemis_only
|
| 180 |
+
nlat = int((n_labels - n_neutral_labels_brainseg_with_extracerebral) / 2.0)
|
| 181 |
+
self.vflip = np.concatenate([np.array(range(n_neutral_labels_brainseg_with_extracerebral)),
|
| 182 |
+
np.array(range(n_neutral_labels_brainseg_with_extracerebral + nlat, n_labels)),
|
| 183 |
+
np.array(range(n_neutral_labels_brainseg_with_extracerebral, n_neutral_labels_brainseg_with_extracerebral + nlat))])
|
| 184 |
+
return
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def random_affine_transform(self, shp):
|
| 188 |
+
rotations = (2 * self.synth_args.max_rotation * np.random.rand(3) - self.synth_args.max_rotation) / 180.0 * np.pi
|
| 189 |
+
shears = (2 * self.synth_args.max_shear * np.random.rand(3) - self.synth_args.max_shear)
|
| 190 |
+
scalings = 1 + (2 * self.synth_args.max_scaling * np.random.rand(3) - self.synth_args.max_scaling)
|
| 191 |
+
scaling_factor_distances = np.prod(scalings) ** .33333333333
|
| 192 |
+
A = torch.tensor(make_affine_matrix(rotations, shears, scalings), dtype=torch.float, device=self.device)
|
| 193 |
+
|
| 194 |
+
# sample center
|
| 195 |
+
if self.synth_args.random_shift:
|
| 196 |
+
max_shift = (torch.tensor(np.array(shp[0:3]) - self.size, dtype=torch.float, device=self.device)) / 2
|
| 197 |
+
max_shift[max_shift < 0] = 0
|
| 198 |
+
c2 = torch.tensor((np.array(shp[0:3]) - 1)/2, dtype=torch.float, device=self.device) + (2 * (max_shift * torch.rand(3, dtype=float, device=self.device)) - max_shift)
|
| 199 |
+
else:
|
| 200 |
+
c2 = torch.tensor((np.array(shp[0:3]) - 1)/2, dtype=torch.float, device=self.device)
|
| 201 |
+
return scaling_factor_distances, A, c2
|
| 202 |
+
|
| 203 |
+
def random_nonlinear_transform(self, photo_mode, spac):
|
| 204 |
+
nonlin_scale = self.synth_args.nonlin_scale_min + np.random.rand(1) * (self.synth_args.nonlin_scale_max - self.synth_args.nonlin_scale_min)
|
| 205 |
+
size_F_small = np.round(nonlin_scale * np.array(self.size)).astype(int).tolist()
|
| 206 |
+
if photo_mode:
|
| 207 |
+
size_F_small[1] = np.round(self.size[1]/spac).astype(int)
|
| 208 |
+
nonlin_std = self.synth_args.nonlin_std_max * np.random.rand()
|
| 209 |
+
Fsmall = nonlin_std * torch.randn([*size_F_small, 3], dtype=torch.float, device=self.device)
|
| 210 |
+
F = myzoom_torch(Fsmall, np.array(self.size) / size_F_small)
|
| 211 |
+
if photo_mode:
|
| 212 |
+
F[:, :, :, 1] = 0
|
| 213 |
+
|
| 214 |
+
if 'surface' in self.tasks: # TODO need to integrate the non-linear deformation fields for inverse
|
| 215 |
+
steplength = 1.0 / (2.0 ** self.synth_args.n_steps_svf_integration)
|
| 216 |
+
Fsvf = F * steplength
|
| 217 |
+
for _ in range(self.synth_args.n_steps_svf_integration):
|
| 218 |
+
Fsvf += fast_3D_interp_torch(Fsvf, self.xx + Fsvf[:, :, :, 0], self.yy + Fsvf[:, :, :, 1], self.zz + Fsvf[:, :, :, 2], 'linear')
|
| 219 |
+
Fsvf_neg = -F * steplength
|
| 220 |
+
for _ in range(self.synth_args.n_steps_svf_integration):
|
| 221 |
+
Fsvf_neg += fast_3D_interp_torch(Fsvf_neg, self.xx + Fsvf_neg[:, :, :, 0], self.yy + Fsvf_neg[:, :, :, 1], self.zz + Fsvf_neg[:, :, :, 2], 'linear')
|
| 222 |
+
F = Fsvf
|
| 223 |
+
Fneg = Fsvf_neg
|
| 224 |
+
else:
|
| 225 |
+
Fneg = None
|
| 226 |
+
return F, Fneg
|
| 227 |
+
|
| 228 |
+
def generate_deformation(self, setups, shp):
|
| 229 |
+
|
| 230 |
+
# generate affine deformation
|
| 231 |
+
scaling_factor_distances, A, c2 = self.random_affine_transform(shp)
|
| 232 |
+
|
| 233 |
+
# generate nonlinear deformation
|
| 234 |
+
if self.synth_args.nonlinear_transform:
|
| 235 |
+
F, Fneg = self.random_nonlinear_transform(setups['photo_mode'], setups['spac'])
|
| 236 |
+
else:
|
| 237 |
+
F, Fneg = None, None
|
| 238 |
+
|
| 239 |
+
# deform the image grid
|
| 240 |
+
xx2, yy2, zz2, x1, y1, z1, x2, y2, z2 = self.deform_grid(shp, A, c2, F)
|
| 241 |
+
|
| 242 |
+
return {'scaling_factor_distances': scaling_factor_distances,
|
| 243 |
+
'A': A,
|
| 244 |
+
'c2': c2,
|
| 245 |
+
'F': F,
|
| 246 |
+
'Fneg': Fneg,
|
| 247 |
+
'grid': [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2],
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def get_left_hemis_mask(self, grid):
|
| 252 |
+
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = grid
|
| 253 |
+
|
| 254 |
+
if self.synth_args.left_hemis_only:
|
| 255 |
+
S, aff, res = read_image(self.modalities['segmentation']) # read seg map
|
| 256 |
+
S = torch.squeeze(torch.from_numpy(S.get_fdata()[x1:x2, y1:y2, z1:z2].astype(int))).to(self.device)
|
| 257 |
+
S = self.lut[S.int()] # mask out non-left labels
|
| 258 |
+
X, aff, res = read_image(self.modalities['registration'][0]) # read_mni_coord_X
|
| 259 |
+
X = torch.squeeze(torch.from_numpy(X.get_fdata()[x1:x2, y1:y2, z1:z2])).to(self.device)
|
| 260 |
+
self.hemis_mask = ((S > 0) & (X < 0)).int()
|
| 261 |
+
else:
|
| 262 |
+
self.hemis_mask = None
|
| 263 |
+
|
| 264 |
+
def deform_grid(self, shp, A, c2, F):
|
| 265 |
+
if F is not None:
|
| 266 |
+
# deform the images (we do nonlinear "first" ie after so we can do heavy coronal deformations in photo mode)
|
| 267 |
+
xx1 = self.xc + F[:, :, :, 0]
|
| 268 |
+
yy1 = self.yc + F[:, :, :, 1]
|
| 269 |
+
zz1 = self.zc + F[:, :, :, 2]
|
| 270 |
+
else:
|
| 271 |
+
xx1 = self.xc
|
| 272 |
+
yy1 = self.yc
|
| 273 |
+
zz1 = self.zc
|
| 274 |
+
|
| 275 |
+
xx2 = A[0, 0] * xx1 + A[0, 1] * yy1 + A[0, 2] * zz1 + c2[0]
|
| 276 |
+
yy2 = A[1, 0] * xx1 + A[1, 1] * yy1 + A[1, 2] * zz1 + c2[1]
|
| 277 |
+
zz2 = A[2, 0] * xx1 + A[2, 1] * yy1 + A[2, 2] * zz1 + c2[2]
|
| 278 |
+
xx2[xx2 < 0] = 0
|
| 279 |
+
yy2[yy2 < 0] = 0
|
| 280 |
+
zz2[zz2 < 0] = 0
|
| 281 |
+
xx2[xx2 > (shp[0] - 1)] = shp[0] - 1
|
| 282 |
+
yy2[yy2 > (shp[1] - 1)] = shp[1] - 1
|
| 283 |
+
zz2[zz2 > (shp[2] - 1)] = shp[2] - 1
|
| 284 |
+
|
| 285 |
+
# Get the margins for reading images
|
| 286 |
+
x1 = torch.floor(torch.min(xx2))
|
| 287 |
+
y1 = torch.floor(torch.min(yy2))
|
| 288 |
+
z1 = torch.floor(torch.min(zz2))
|
| 289 |
+
x2 = 1+torch.ceil(torch.max(xx2))
|
| 290 |
+
y2 = 1 + torch.ceil(torch.max(yy2))
|
| 291 |
+
z2 = 1 + torch.ceil(torch.max(zz2))
|
| 292 |
+
xx2 -= x1
|
| 293 |
+
yy2 -= y1
|
| 294 |
+
zz2 -= z1
|
| 295 |
+
|
| 296 |
+
x1 = x1.cpu().numpy().astype(int)
|
| 297 |
+
y1 = y1.cpu().numpy().astype(int)
|
| 298 |
+
z1 = z1.cpu().numpy().astype(int)
|
| 299 |
+
x2 = x2.cpu().numpy().astype(int)
|
| 300 |
+
y2 = y2.cpu().numpy().astype(int)
|
| 301 |
+
z2 = z2.cpu().numpy().astype(int)
|
| 302 |
+
|
| 303 |
+
return xx2, yy2, zz2, x1, y1, z1, x2, y2, z2
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def augment_sample(self, name, I_def, setups, deform_dict, res, target, pathol_direction = None, input_mode = 'synth'):
|
| 307 |
+
|
| 308 |
+
sample = {}
|
| 309 |
+
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']
|
| 310 |
+
|
| 311 |
+
if not isinstance(I_def, torch.Tensor):
|
| 312 |
+
I_def = torch.squeeze(torch.tensor(I_def.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device))
|
| 313 |
+
if self.hemis_mask is not None:
|
| 314 |
+
I_def[self.hemis_mask == 0] = 0
|
| 315 |
+
# Deform grid
|
| 316 |
+
I_def = fast_3D_interp_torch(I_def, xx2, yy2, zz2, 'linear')
|
| 317 |
+
|
| 318 |
+
if input_mode == 'CT':
|
| 319 |
+
I_def = torch.clamp(I_def, min = 0., max = 80.)
|
| 320 |
+
|
| 321 |
+
if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0:
|
| 322 |
+
I_def = self.encode_pathology(I_def, target['pathology'], target['pathology_prob'], pathol_direction)
|
| 323 |
+
I_def[I_def < 0.] = 0.
|
| 324 |
+
else:
|
| 325 |
+
target['pathology'] = 0.
|
| 326 |
+
target['pathology_prob'] = 0.
|
| 327 |
+
|
| 328 |
+
# Augment sample
|
| 329 |
+
aux_dict = {}
|
| 330 |
+
augmentation_steps = self.augmentation_steps['synth'] if input_mode == 'synth' else self.augmentation_steps['real']
|
| 331 |
+
for func_name in augmentation_steps:
|
| 332 |
+
I_def, aux_dict = augmentation_funcs[func_name](I = I_def, aux_dict = aux_dict, cfg = self.gen_args.generator,
|
| 333 |
+
input_mode = input_mode, setups = setups, size = self.size, res = res, device = self.device)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# Back to original resolution
|
| 337 |
+
if self.synth_args.bspline_zooming:
|
| 338 |
+
I_def = interpol.resize(I_def, shape=self.size, anchor='edge', interpolation=3, bound='dct2', prefilter=True)
|
| 339 |
+
else:
|
| 340 |
+
I_def = myzoom_torch(I_def, 1 / aux_dict['factors'])
|
| 341 |
+
|
| 342 |
+
maxi = torch.max(I_def)
|
| 343 |
+
I_final = I_def / maxi
|
| 344 |
+
|
| 345 |
+
if 'super_resolution' in self.tasks:
|
| 346 |
+
SRresidual = aux_dict['high_res'] / maxi - I_final
|
| 347 |
+
sample.update({'high_res_residual': torch.flip(SRresidual, [0])[None] if setups['flip'] else SRresidual[None]})
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
sample.update({'input': torch.flip(I_final, [0])[None] if setups['flip'] else I_final[None]})
|
| 351 |
+
if 'bias_field' in self.tasks and input_mode != 'CT':
|
| 352 |
+
sample.update({'bias_field_log': torch.flip(aux_dict['BFlog'], [0])[None] if setups['flip'] else aux_dict['BFlog'][None]})
|
| 353 |
+
|
| 354 |
+
return sample
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def generate_sample(self, name, G, setups, deform_dict, res, target):
|
| 358 |
+
|
| 359 |
+
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']
|
| 360 |
+
|
| 361 |
+
# Generate contrasts
|
| 362 |
+
mus, sigmas = self.get_contrast(setups['photo_mode'])
|
| 363 |
+
|
| 364 |
+
G = torch.squeeze(torch.tensor(G.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device))
|
| 365 |
+
#G[G > 255] = 0 # kill extracerebral regions
|
| 366 |
+
G[G == 77] = 2 # merge WM lesion to white matter region
|
| 367 |
+
if self.hemis_mask is not None:
|
| 368 |
+
G[self.hemis_mask == 0] = 0
|
| 369 |
+
Gr = torch.round(G).long()
|
| 370 |
+
|
| 371 |
+
SYN = mus[Gr] + sigmas[Gr] * torch.randn(Gr.shape, dtype=torch.float, device=self.device)
|
| 372 |
+
SYN[SYN < 0] = 0
|
| 373 |
+
#SYN /= mus[2] # normalize by WM
|
| 374 |
+
#SYN = gaussian_blur_3d(SYN, 0.5*np.ones(3), self.device) # cosmetic
|
| 375 |
+
|
| 376 |
+
SYN = fast_3D_interp_torch(SYN, xx2, yy2, zz2)
|
| 377 |
+
|
| 378 |
+
# Make random linear combinations
|
| 379 |
+
if np.random.rand() < self.gen_args.mix_synth_prob:
|
| 380 |
+
v = torch.rand(4)
|
| 381 |
+
v[2] = 0 if 'T2' not in self.modalities else v[2]
|
| 382 |
+
v[3] = 0 if 'FLAIR' not in self.modalities else v[3]
|
| 383 |
+
v /= torch.sum(v)
|
| 384 |
+
SYN = v[0] * SYN + v[1] * target['T1'][0]
|
| 385 |
+
if 'T2' in self.modalities:
|
| 386 |
+
SYN += v[2] * target['T2'][0]
|
| 387 |
+
if 'FLAIR' in self.modalities:
|
| 388 |
+
SYN += v[3] * target['FLAIR'][0]
|
| 389 |
+
|
| 390 |
+
if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0:
|
| 391 |
+
SYN_cerebral = SYN.clone()
|
| 392 |
+
SYN_cerebral[Gr == 0] = 0
|
| 393 |
+
SYN_cerebral = fast_3D_interp_torch(SYN_cerebral, xx2, yy2, zz2)[None]
|
| 394 |
+
|
| 395 |
+
wm_mask = (Gr==2) | (Gr==41)
|
| 396 |
+
wm_mean = (SYN * wm_mask).sum() / wm_mask.sum()
|
| 397 |
+
gm_mask = (Gr!=0) & (Gr!=2) & (Gr!=41)
|
| 398 |
+
gm_mean = (SYN * gm_mask).sum() / gm_mask.sum()
|
| 399 |
+
|
| 400 |
+
target['pathology'][SYN_cerebral == 0] = 0
|
| 401 |
+
target['pathology_prob'][SYN_cerebral == 0] = 0
|
| 402 |
+
# determine to be T1-resembled or T2-resembled
|
| 403 |
+
#if pathol_direction: lesion should be brigher than WM.mean()
|
| 404 |
+
# pathol_direction: +1: T2-like; -1: T1-like
|
| 405 |
+
pathol_direction = self.get_pathology_direction('synth', gm_mean > wm_mean)
|
| 406 |
+
else:
|
| 407 |
+
pathol_direction = None
|
| 408 |
+
target['pathology'] = 0.
|
| 409 |
+
target['pathology_prob'] = 0.
|
| 410 |
+
|
| 411 |
+
SYN[SYN < 0.] = 0.
|
| 412 |
+
return target['pathology'], target['pathology_prob'], self.augment_sample(name, SYN, setups, deform_dict, res, target, pathol_direction = pathol_direction)
|
| 413 |
+
|
| 414 |
+
def get_pathology_direction(self, input_mode, pathol_direction = None):
|
| 415 |
+
#if np.random.rand() < 0.1: # in some (rare) cases, randomly pick the direction
|
| 416 |
+
# return random.choice([True, False])
|
| 417 |
+
|
| 418 |
+
if pathol_direction is not None: # for synth image
|
| 419 |
+
return pathol_direction
|
| 420 |
+
|
| 421 |
+
if input_mode in ['T1', 'CT']:
|
| 422 |
+
return False
|
| 423 |
+
|
| 424 |
+
if input_mode in ['T2', 'FLAIR']:
|
| 425 |
+
return True
|
| 426 |
+
|
| 427 |
+
return random.choice([True, False])
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def get_contrast(self, photo_mode):
|
| 431 |
+
# Sample Gaussian image
|
| 432 |
+
mus = 25 + 200 * torch.rand(256, dtype=torch.float, device=self.device)
|
| 433 |
+
sigmas = 5 + 20 * torch.rand(256, dtype=torch.float, device=self.device)
|
| 434 |
+
|
| 435 |
+
if np.random.rand() < self.synth_args.ct_prob:
|
| 436 |
+
darker = 25 + 10 * torch.rand(1, dtype=torch.float, device=self.device)[0]
|
| 437 |
+
for l in ct_brightness_group['darker']:
|
| 438 |
+
mus[l] = darker
|
| 439 |
+
dark = 90 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0]
|
| 440 |
+
for l in ct_brightness_group['dark']:
|
| 441 |
+
mus[l] = dark
|
| 442 |
+
bright = 110 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0]
|
| 443 |
+
for l in ct_brightness_group['bright']:
|
| 444 |
+
mus[l] = bright
|
| 445 |
+
brighter = 150 + 50 * torch.rand(1, dtype=torch.float, device=self.device)[0]
|
| 446 |
+
for l in ct_brightness_group['brighter']:
|
| 447 |
+
mus[l] = brighter
|
| 448 |
+
|
| 449 |
+
if photo_mode or np.random.rand(1)<0.5: # set the background to zero every once in a while (or always in photo mode)
|
| 450 |
+
mus[0] = 0
|
| 451 |
+
|
| 452 |
+
# partial volume
|
| 453 |
+
# 1 = lesion, 2 = WM, 3 = GM, 4 = CSF
|
| 454 |
+
v = 0.02 * torch.arange(50).to(self.device)
|
| 455 |
+
mus[100:150] = mus[1] * (1 - v) + mus[2] * v
|
| 456 |
+
mus[150:200] = mus[2] * (1 - v) + mus[3] * v
|
| 457 |
+
mus[200:250] = mus[3] * (1 - v) + mus[4] * v
|
| 458 |
+
mus[250] = mus[4]
|
| 459 |
+
sigmas[100:150] = torch.sqrt(sigmas[1]**2 * (1 - v) + sigmas[2]**2 * v)
|
| 460 |
+
sigmas[150:200] = torch.sqrt(sigmas[2]**2 * (1 - v) + sigmas[3]**2 * v)
|
| 461 |
+
sigmas[200:250] = torch.sqrt(sigmas[3]**2 * (1 - v) + sigmas[4]**2 * v)
|
| 462 |
+
sigmas[250] = sigmas[4]
|
| 463 |
+
|
| 464 |
+
return mus, sigmas
|
| 465 |
+
|
| 466 |
+
def get_setup_params(self):
|
| 467 |
+
|
| 468 |
+
if self.synth_args.left_hemis_only:
|
| 469 |
+
hemis = 'left'
|
| 470 |
+
else:
|
| 471 |
+
hemis = 'both'
|
| 472 |
+
|
| 473 |
+
if self.synth_args.low_res_only:
|
| 474 |
+
photo_mode = False
|
| 475 |
+
elif self.synth_args.left_hemis_only:
|
| 476 |
+
photo_mode = True
|
| 477 |
+
else:
|
| 478 |
+
photo_mode = np.random.rand() < self.synth_args.photo_prob
|
| 479 |
+
|
| 480 |
+
pathol_mode = np.random.rand() < self.synth_args.pathology_prob
|
| 481 |
+
pathol_random_shape = np.random.rand() < self.synth_args.random_shape_prob
|
| 482 |
+
spac = 2.5 + 10 * np.random.rand() if photo_mode else None
|
| 483 |
+
flip = np.random.randn() < self.synth_args.flip_prob if not self.synth_args.left_hemis_only else False
|
| 484 |
+
|
| 485 |
+
if photo_mode:
|
| 486 |
+
resolution = np.array([self.res_training_data[0], spac, self.res_training_data[2]])
|
| 487 |
+
thickness = np.array([self.res_training_data[0], 0.1, self.res_training_data[2]])
|
| 488 |
+
else:
|
| 489 |
+
resolution, thickness = resolution_sampler(self.synth_args.low_res_only)
|
| 490 |
+
return {'resolution': resolution, 'thickness': thickness,
|
| 491 |
+
'photo_mode': photo_mode, 'pathol_mode': pathol_mode,
|
| 492 |
+
'pathol_random_shape': pathol_random_shape,
|
| 493 |
+
'spac': spac, 'flip': flip, 'hemis': hemis}
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def encode_pathology(self, I, P, Pprob, pathol_direction = None):
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
if pathol_direction is None: # True: T2/FLAIR-resembled, False: T1-resembled
|
| 500 |
+
pathol_direction = random.choice([True, False])
|
| 501 |
+
|
| 502 |
+
P, Pprob = torch.squeeze(P), torch.squeeze(Pprob)
|
| 503 |
+
I_mu = (I * P).sum() / P.sum()
|
| 504 |
+
|
| 505 |
+
p_mask = torch.round(P).long()
|
| 506 |
+
#pth_mus = I_mu/4 + I_mu/2 * torch.rand(10000, dtype=torch.float, device=self.device)
|
| 507 |
+
pth_mus = 3*I_mu/4 + I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device) # enforce the pathology pattern harder!
|
| 508 |
+
pth_mus = pth_mus if pathol_direction else -pth_mus
|
| 509 |
+
pth_sigmas = I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device)
|
| 510 |
+
I += Pprob * (pth_mus[p_mask] + pth_sigmas[p_mask] * torch.randn(p_mask.shape, dtype=torch.float, device=self.device))
|
| 511 |
+
I[I < 0] = 0
|
| 512 |
+
|
| 513 |
+
#print('encode', P.shape, P.mean())
|
| 514 |
+
#print('pre', I_mu)
|
| 515 |
+
#I_mu = (I * P).sum() / P.sum()
|
| 516 |
+
#print('post', I_mu)
|
| 517 |
+
|
| 518 |
+
return I
|
| 519 |
+
|
| 520 |
+
def get_info(self, t1):
|
| 521 |
+
|
| 522 |
+
t1dm = t1[:-7] + 'T1w.defacingmask.nii'
|
| 523 |
+
t2 = t1[:-7] + 'T2w.nii'
|
| 524 |
+
t2dm = t1[:-7] + 'T2w.defacingmask.nii'
|
| 525 |
+
flair = t1[:-7] + 'FLAIR.nii'
|
| 526 |
+
flairdm = t1[:-7] + 'FLAIR.defacingmask.nii'
|
| 527 |
+
ct = t1[:-7] + 'CT.nii'
|
| 528 |
+
ctdm = t1[:-7] + 'CT.defacingmask.nii'
|
| 529 |
+
generation_labels = t1[:-7] + 'generation_labels.nii'
|
| 530 |
+
segmentation_labels = t1[:-7] + self.gen_args.segment_prefix + '.nii'
|
| 531 |
+
#brain_dist_map = t1[:-7] + 'brain_dist_map.nii'
|
| 532 |
+
lp_dist_map = t1[:-7] + 'lp_dist_map.nii'
|
| 533 |
+
rp_dist_map = t1[:-7] + 'rp_dist_map.nii'
|
| 534 |
+
lw_dist_map = t1[:-7] + 'lw_dist_map.nii'
|
| 535 |
+
rw_dist_map = t1[:-7] + 'rw_dist_map.nii'
|
| 536 |
+
mni_reg_x = t1[:-7] + 'mni_reg.x.nii'
|
| 537 |
+
mni_reg_y = t1[:-7] + 'mni_reg.y.nii'
|
| 538 |
+
mni_reg_z = t1[:-7] + 'mni_reg.z.nii'
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
self.modalities = {'T1': t1, 'Gen': generation_labels, 'segmentation': segmentation_labels,
|
| 542 |
+
'distance': [lp_dist_map, lw_dist_map, rp_dist_map, rw_dist_map],
|
| 543 |
+
'registration': [mni_reg_x, mni_reg_y, mni_reg_z]}
|
| 544 |
+
|
| 545 |
+
if os.path.isfile(t1dm):
|
| 546 |
+
self.modalities.update({'T1_DM': t1dm})
|
| 547 |
+
if os.path.isfile(t2):
|
| 548 |
+
self.modalities.update({'T2': t2})
|
| 549 |
+
if os.path.isfile(t2dm):
|
| 550 |
+
self.modalities.update({'T2_DM': t2dm})
|
| 551 |
+
if os.path.isfile(flair):
|
| 552 |
+
self.modalities.update({'FLAIR': flair})
|
| 553 |
+
if os.path.isfile(flairdm):
|
| 554 |
+
self.modalities.update({'FLAIR_DM': flairdm})
|
| 555 |
+
if os.path.isfile(ct):
|
| 556 |
+
self.modalities.update({'CT': ct})
|
| 557 |
+
if os.path.isfile(ctdm):
|
| 558 |
+
self.modalities.update({'CT_DM': ctdm})
|
| 559 |
+
|
| 560 |
+
return self.modalities
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def read_input(self, idx):
|
| 564 |
+
"""
|
| 565 |
+
determine input type according to prob (in generator/constants.py)
|
| 566 |
+
Logic: if np.random.rand() < real_image_prob and is real_image_exist --> input real images; otherwise, synthesize images.
|
| 567 |
+
"""
|
| 568 |
+
dataset_name, input_prob, t1_path, age = self.idx_to_path(idx)
|
| 569 |
+
case_name = os.path.basename(t1_path).split('.T1w.nii')[0]
|
| 570 |
+
self.modalities = self.get_info(t1_path)
|
| 571 |
+
|
| 572 |
+
prob = np.random.rand()
|
| 573 |
+
if prob < input_prob['T1'] and 'T1' in self.modalities:
|
| 574 |
+
input_mode = 'T1'
|
| 575 |
+
img, aff, res = read_image(self.modalities['T1'])
|
| 576 |
+
elif prob < input_prob['T2'] and 'T2' in self.modalities:
|
| 577 |
+
input_mode = 'T2'
|
| 578 |
+
img, aff, res = read_image(self.modalities['T2'])
|
| 579 |
+
elif prob < input_prob['FLAIR'] and 'FLAIR' in self.modalities:
|
| 580 |
+
input_mode = 'FLAIR'
|
| 581 |
+
img, aff, res = read_image(self.modalities['FLAIR'])
|
| 582 |
+
elif prob < input_prob['CT'] and 'CT' in self.modalities:
|
| 583 |
+
input_mode = 'CT'
|
| 584 |
+
img, aff, res = read_image(self.modalities['CT'])
|
| 585 |
+
else:
|
| 586 |
+
input_mode = 'synth'
|
| 587 |
+
img, aff, res = read_image(self.modalities['Gen'])
|
| 588 |
+
|
| 589 |
+
return dataset_name, case_name, input_mode, img, aff, res, age
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def read_and_deform_target(self, idx, exist_keys, task_name, input_mode, setups, deform_dict, linear_weights = None):
|
| 593 |
+
current_target = {}
|
| 594 |
+
p_prob_path, augment, thres = None, False, 0.1
|
| 595 |
+
|
| 596 |
+
if task_name == 'pathology':
|
| 597 |
+
# NOTE: for now - encode pathology only for healthy cases
|
| 598 |
+
# TODO: what to do if the case has pathology itself? -- inconsistency between encoded pathol and the output
|
| 599 |
+
if self.pathology_type is None: # healthy
|
| 600 |
+
if setups['pathol_mode']: # and input_mode == 'synth':
|
| 601 |
+
if setups['pathol_random_shape']:
|
| 602 |
+
p_prob_path = 'random_shape'
|
| 603 |
+
augment, thres = False, self.shape_gen_args.pathol_thres
|
| 604 |
+
else:
|
| 605 |
+
p_prob_path = random.choice(pathology_prob_paths)
|
| 606 |
+
augment, thres = self.synth_args.augment_pathology, self.shape_gen_args.pathol_thres
|
| 607 |
+
else:
|
| 608 |
+
pass
|
| 609 |
+
#p_prob_path = self.modalities['pathology_prob']
|
| 610 |
+
|
| 611 |
+
current_target = processing_funcs[task_name](exist_keys, task_name, p_prob_path, setups, deform_dict, self.device,
|
| 612 |
+
mask = self.hemis_mask,
|
| 613 |
+
augment = augment,
|
| 614 |
+
pde_func = self.adv_pde,
|
| 615 |
+
t = self.t,
|
| 616 |
+
shape_gen_args = self.shape_gen_args,
|
| 617 |
+
thres = thres
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
else:
|
| 621 |
+
if task_name in self.modalities:
|
| 622 |
+
current_target = processing_funcs[task_name](exist_keys, task_name, self.modalities[task_name],
|
| 623 |
+
setups, deform_dict, self.device,
|
| 624 |
+
mask = self.hemis_mask,
|
| 625 |
+
cfg = self.gen_args,
|
| 626 |
+
onehotmatrix = self.onehotmatrix,
|
| 627 |
+
lut = self.lut, vflip = self.vflip
|
| 628 |
+
)
|
| 629 |
+
else:
|
| 630 |
+
current_target = {task_name: 0.}
|
| 631 |
+
return current_target
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def update_gen_args(self, new_args):
|
| 635 |
+
for key, value in vars(new_args).items():
|
| 636 |
+
vars(self.gen_args.generator)[key] = value
|
| 637 |
+
|
| 638 |
+
def __getitem__(self, idx):
|
| 639 |
+
if torch.is_tensor(idx):
|
| 640 |
+
idx = idx.tolist()
|
| 641 |
+
|
| 642 |
+
# read input: real or synthesized image, according to customized prob
|
| 643 |
+
dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx)
|
| 644 |
+
|
| 645 |
+
# generate random values
|
| 646 |
+
setups = self.get_setup_params()
|
| 647 |
+
|
| 648 |
+
# sample random deformation
|
| 649 |
+
deform_dict = self.generate_deformation(setups, img.shape)
|
| 650 |
+
|
| 651 |
+
# get left_hemis_mask if needed
|
| 652 |
+
self.get_left_hemis_mask(deform_dict['grid'])
|
| 653 |
+
|
| 654 |
+
# read and deform target according to the assigned tasks
|
| 655 |
+
target = defaultdict(lambda: None)
|
| 656 |
+
target['name'] = case_name
|
| 657 |
+
target.update(self.read_and_deform_target(idx, target.keys(), 'T1', input_mode, setups, deform_dict))
|
| 658 |
+
target.update(self.read_and_deform_target(idx, target.keys(), 'T2', input_mode, setups, deform_dict))
|
| 659 |
+
target.update(self.read_and_deform_target(idx, target.keys(), 'FLAIR', input_mode, setups, deform_dict))
|
| 660 |
+
for task_name in self.tasks:
|
| 661 |
+
if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']:
|
| 662 |
+
target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
# process or generate input sample
|
| 666 |
+
if input_mode == 'synth':
|
| 667 |
+
self.update_gen_args(self.synth_image_args) # severe noise injection for real images
|
| 668 |
+
target['pathology'], target['pathology_prob'], sample = \
|
| 669 |
+
self.generate_sample(case_name, img, setups, deform_dict, res, target)
|
| 670 |
+
else:
|
| 671 |
+
self.update_gen_args(self.real_image_args) # milder noise injection for real images
|
| 672 |
+
sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
|
| 673 |
+
pathol_direction = self.get_pathology_direction(input_mode),input_mode = input_mode)
|
| 674 |
+
|
| 675 |
+
if setups['flip'] and isinstance(target['pathology'], torch.Tensor): # flipping should happen after P has been encoded
|
| 676 |
+
target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1])
|
| 677 |
+
|
| 678 |
+
if age is not None:
|
| 679 |
+
target['age'] = age
|
| 680 |
+
|
| 681 |
+
return self.datasets_num, dataset_name, input_mode, target, sample
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
# An example of customized dataset from BaseSynth
|
| 687 |
+
class BrainIDGen(BaseGen):
|
| 688 |
+
"""
|
| 689 |
+
BrainIDGen dataset
|
| 690 |
+
BrainIDGen enables intra-subject augmentation, i.e., each subject will have multiple augmentations
|
| 691 |
+
"""
|
| 692 |
+
def __init__(self, gen_args, device='cpu'):
|
| 693 |
+
super(BrainIDGen, self).__init__(gen_args, device)
|
| 694 |
+
|
| 695 |
+
self.all_samples = gen_args.generator.all_samples
|
| 696 |
+
self.mild_samples = gen_args.generator.mild_samples
|
| 697 |
+
self.mild_generator_args = gen_args.mild_generator
|
| 698 |
+
self.severe_generator_args = gen_args.severe_generator
|
| 699 |
+
|
| 700 |
+
def __getitem__(self, idx):
|
| 701 |
+
if torch.is_tensor(idx):
|
| 702 |
+
idx = idx.tolist()
|
| 703 |
+
|
| 704 |
+
# read input: real or synthesized image, according to customized prob
|
| 705 |
+
dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx)
|
| 706 |
+
|
| 707 |
+
# generate random values
|
| 708 |
+
setups = self.get_setup_params()
|
| 709 |
+
|
| 710 |
+
# sample random deformation
|
| 711 |
+
deform_dict = self.generate_deformation(setups, img.shape)
|
| 712 |
+
|
| 713 |
+
# get left_hemis_mask if needed
|
| 714 |
+
self.get_left_hemis_mask(deform_dict['grid'])
|
| 715 |
+
|
| 716 |
+
# read and deform target according to the assigned tasks
|
| 717 |
+
target = defaultdict(lambda: 1.)
|
| 718 |
+
target['name'] = case_name
|
| 719 |
+
target.update(self.read_and_deform_target(idx, target.keys(), 'T1', input_mode, setups, deform_dict))
|
| 720 |
+
target.update(self.read_and_deform_target(idx, target.keys(), 'T2', input_mode, setups, deform_dict))
|
| 721 |
+
target.update(self.read_and_deform_target(idx, target.keys(), 'FLAIR', input_mode, setups, deform_dict))
|
| 722 |
+
for task_name in self.tasks:
|
| 723 |
+
if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']:
|
| 724 |
+
target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))
|
| 725 |
+
|
| 726 |
+
# process or generate intra-subject input samples
|
| 727 |
+
samples = []
|
| 728 |
+
for i_sample in range(self.all_samples):
|
| 729 |
+
if i_sample < self.mild_samples:
|
| 730 |
+
self.update_gen_args(self.mild_generator_args)
|
| 731 |
+
if input_mode == 'synth':
|
| 732 |
+
self.update_gen_args(self.synth_image_args)
|
| 733 |
+
target['pathology'], target['pathology_prob'], sample = \
|
| 734 |
+
self.generate_sample(case_name, img, setups, deform_dict, res, target)
|
| 735 |
+
else:
|
| 736 |
+
self.update_gen_args(self.real_image_args)
|
| 737 |
+
sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
|
| 738 |
+
pathol_direction = self.get_pathology_direction(input_mode),input_mode = input_mode)
|
| 739 |
+
else:
|
| 740 |
+
self.update_gen_args(self.severe_generator_args)
|
| 741 |
+
if input_mode == 'synth':
|
| 742 |
+
self.update_gen_args(self.synth_image_args)
|
| 743 |
+
target['pathology'], target['pathology_prob'], sample = \
|
| 744 |
+
self.generate_sample(case_name, img, setups, deform_dict, res, target)
|
| 745 |
+
else:
|
| 746 |
+
self.update_gen_args(self.real_image_args)
|
| 747 |
+
sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
|
| 748 |
+
pathol_direction = self.get_pathology_direction(input_mode),input_mode = input_mode)
|
| 749 |
+
|
| 750 |
+
samples.append(sample)
|
| 751 |
+
|
| 752 |
+
if setups['flip'] and isinstance(target['pathology'], torch.Tensor): # flipping should happen after P has been encoded
|
| 753 |
+
target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1])
|
| 754 |
+
|
| 755 |
+
if age is not None:
|
| 756 |
+
target['age'] = age
|
| 757 |
+
return self.datasets_num, dataset_name, input_mode, target, samples
|
Generator/interpol/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .api import *
|
| 2 |
+
from .resize import *
|
| 3 |
+
from .restrict import *
|
| 4 |
+
from . import backend
|
| 5 |
+
|
| 6 |
+
from . import _version
|
| 7 |
+
__version__ = _version.get_versions()['version']
|
Generator/interpol/_version.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# This file helps to compute a version number in source trees obtained from
|
| 3 |
+
# git-archive tarball (such as those provided by githubs download-from-tag
|
| 4 |
+
# feature). Distribution tarballs (built by setup.py sdist) and build
|
| 5 |
+
# directories (produced by setup.py build) will contain a much shorter file
|
| 6 |
+
# that just contains the computed version number.
|
| 7 |
+
|
| 8 |
+
# This file is released into the public domain. Generated by
|
| 9 |
+
# versioneer-0.20 (https://github.com/python-versioneer/python-versioneer)
|
| 10 |
+
|
| 11 |
+
"""Git implementation of _version.py."""
|
| 12 |
+
|
| 13 |
+
import errno
|
| 14 |
+
import os
|
| 15 |
+
import re
|
| 16 |
+
import subprocess
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_keywords():
|
| 21 |
+
"""Get the keywords needed to look up the version information."""
|
| 22 |
+
# these strings will be replaced by git during git-archive.
|
| 23 |
+
# setup.py/versioneer.py will grep for the variable names, so they must
|
| 24 |
+
# each be defined on a line of their own. _version.py will just call
|
| 25 |
+
# get_keywords().
|
| 26 |
+
git_refnames = " (HEAD -> main, tag: 0.2.3)"
|
| 27 |
+
git_full = "414ed52c973b9d32e3e6a5a75c91cd5aab064f23"
|
| 28 |
+
git_date = "2023-04-17 20:36:50 -0400"
|
| 29 |
+
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
+
return keywords
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class VersioneerConfig: # pylint: disable=too-few-public-methods
|
| 34 |
+
"""Container for Versioneer configuration parameters."""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_config():
|
| 38 |
+
"""Create, populate and return the VersioneerConfig() object."""
|
| 39 |
+
# these strings are filled in when 'setup.py versioneer' creates
|
| 40 |
+
# _version.py
|
| 41 |
+
cfg = VersioneerConfig()
|
| 42 |
+
cfg.VCS = "git"
|
| 43 |
+
cfg.style = "pep440"
|
| 44 |
+
cfg.tag_prefix = ""
|
| 45 |
+
cfg.parentdir_prefix = ""
|
| 46 |
+
cfg.versionfile_source = "interpol/_version.py"
|
| 47 |
+
cfg.verbose = False
|
| 48 |
+
return cfg
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class NotThisMethod(Exception):
|
| 52 |
+
"""Exception raised if a method is not valid for the current scenario."""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
LONG_VERSION_PY = {}
|
| 56 |
+
HANDLERS = {}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def register_vcs_handler(vcs, method): # decorator
|
| 60 |
+
"""Create decorator to mark a method as the handler of a VCS."""
|
| 61 |
+
def decorate(f):
|
| 62 |
+
"""Store f in HANDLERS[vcs][method]."""
|
| 63 |
+
if vcs not in HANDLERS:
|
| 64 |
+
HANDLERS[vcs] = {}
|
| 65 |
+
HANDLERS[vcs][method] = f
|
| 66 |
+
return f
|
| 67 |
+
return decorate
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# pylint:disable=too-many-arguments,consider-using-with # noqa
|
| 71 |
+
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
|
| 72 |
+
env=None):
|
| 73 |
+
"""Call the given command(s)."""
|
| 74 |
+
assert isinstance(commands, list)
|
| 75 |
+
process = None
|
| 76 |
+
for command in commands:
|
| 77 |
+
try:
|
| 78 |
+
dispcmd = str([command] + args)
|
| 79 |
+
# remember shell=False, so use git.cmd on windows, not just git
|
| 80 |
+
process = subprocess.Popen([command] + args, cwd=cwd, env=env,
|
| 81 |
+
stdout=subprocess.PIPE,
|
| 82 |
+
stderr=(subprocess.PIPE if hide_stderr
|
| 83 |
+
else None))
|
| 84 |
+
break
|
| 85 |
+
except EnvironmentError:
|
| 86 |
+
e = sys.exc_info()[1]
|
| 87 |
+
if e.errno == errno.ENOENT:
|
| 88 |
+
continue
|
| 89 |
+
if verbose:
|
| 90 |
+
print("unable to run %s" % dispcmd)
|
| 91 |
+
print(e)
|
| 92 |
+
return None, None
|
| 93 |
+
else:
|
| 94 |
+
if verbose:
|
| 95 |
+
print("unable to find command, tried %s" % (commands,))
|
| 96 |
+
return None, None
|
| 97 |
+
stdout = process.communicate()[0].strip().decode()
|
| 98 |
+
if process.returncode != 0:
|
| 99 |
+
if verbose:
|
| 100 |
+
print("unable to run %s (error)" % dispcmd)
|
| 101 |
+
print("stdout was %s" % stdout)
|
| 102 |
+
return None, process.returncode
|
| 103 |
+
return stdout, process.returncode
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def versions_from_parentdir(parentdir_prefix, root, verbose):
|
| 107 |
+
"""Try to determine the version from the parent directory name.
|
| 108 |
+
|
| 109 |
+
Source tarballs conventionally unpack into a directory that includes both
|
| 110 |
+
the project name and a version string. We will also support searching up
|
| 111 |
+
two directory levels for an appropriately named parent directory
|
| 112 |
+
"""
|
| 113 |
+
rootdirs = []
|
| 114 |
+
|
| 115 |
+
for _ in range(3):
|
| 116 |
+
dirname = os.path.basename(root)
|
| 117 |
+
if dirname.startswith(parentdir_prefix):
|
| 118 |
+
return {"version": dirname[len(parentdir_prefix):],
|
| 119 |
+
"full-revisionid": None,
|
| 120 |
+
"dirty": False, "error": None, "date": None}
|
| 121 |
+
rootdirs.append(root)
|
| 122 |
+
root = os.path.dirname(root) # up a level
|
| 123 |
+
|
| 124 |
+
if verbose:
|
| 125 |
+
print("Tried directories %s but none started with prefix %s" %
|
| 126 |
+
(str(rootdirs), parentdir_prefix))
|
| 127 |
+
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@register_vcs_handler("git", "get_keywords")
|
| 131 |
+
def git_get_keywords(versionfile_abs):
|
| 132 |
+
"""Extract version information from the given file."""
|
| 133 |
+
# the code embedded in _version.py can just fetch the value of these
|
| 134 |
+
# keywords. When used from setup.py, we don't want to import _version.py,
|
| 135 |
+
# so we do it with a regexp instead. This function is not used from
|
| 136 |
+
# _version.py.
|
| 137 |
+
keywords = {}
|
| 138 |
+
try:
|
| 139 |
+
with open(versionfile_abs, "r") as fobj:
|
| 140 |
+
for line in fobj:
|
| 141 |
+
if line.strip().startswith("git_refnames ="):
|
| 142 |
+
mo = re.search(r'=\s*"(.*)"', line)
|
| 143 |
+
if mo:
|
| 144 |
+
keywords["refnames"] = mo.group(1)
|
| 145 |
+
if line.strip().startswith("git_full ="):
|
| 146 |
+
mo = re.search(r'=\s*"(.*)"', line)
|
| 147 |
+
if mo:
|
| 148 |
+
keywords["full"] = mo.group(1)
|
| 149 |
+
if line.strip().startswith("git_date ="):
|
| 150 |
+
mo = re.search(r'=\s*"(.*)"', line)
|
| 151 |
+
if mo:
|
| 152 |
+
keywords["date"] = mo.group(1)
|
| 153 |
+
except EnvironmentError:
|
| 154 |
+
pass
|
| 155 |
+
return keywords
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@register_vcs_handler("git", "keywords")
|
| 159 |
+
def git_versions_from_keywords(keywords, tag_prefix, verbose):
|
| 160 |
+
"""Get version information from git keywords."""
|
| 161 |
+
if "refnames" not in keywords:
|
| 162 |
+
raise NotThisMethod("Short version file found")
|
| 163 |
+
date = keywords.get("date")
|
| 164 |
+
if date is not None:
|
| 165 |
+
# Use only the last line. Previous lines may contain GPG signature
|
| 166 |
+
# information.
|
| 167 |
+
date = date.splitlines()[-1]
|
| 168 |
+
|
| 169 |
+
# git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
|
| 170 |
+
# datestamp. However we prefer "%ci" (which expands to an "ISO-8601
|
| 171 |
+
# -like" string, which we must then edit to make compliant), because
|
| 172 |
+
# it's been around since git-1.5.3, and it's too difficult to
|
| 173 |
+
# discover which version we're using, or to work around using an
|
| 174 |
+
# older one.
|
| 175 |
+
date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
|
| 176 |
+
refnames = keywords["refnames"].strip()
|
| 177 |
+
if refnames.startswith("$Format"):
|
| 178 |
+
if verbose:
|
| 179 |
+
print("keywords are unexpanded, not using")
|
| 180 |
+
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
|
| 181 |
+
refs = {r.strip() for r in refnames.strip("()").split(",")}
|
| 182 |
+
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
|
| 183 |
+
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
|
| 184 |
+
TAG = "tag: "
|
| 185 |
+
tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
|
| 186 |
+
if not tags:
|
| 187 |
+
# Either we're using git < 1.8.3, or there really are no tags. We use
|
| 188 |
+
# a heuristic: assume all version tags have a digit. The old git %d
|
| 189 |
+
# expansion behaves like git log --decorate=short and strips out the
|
| 190 |
+
# refs/heads/ and refs/tags/ prefixes that would let us distinguish
|
| 191 |
+
# between branches and tags. By ignoring refnames without digits, we
|
| 192 |
+
# filter out many common branch names like "release" and
|
| 193 |
+
# "stabilization", as well as "HEAD" and "master".
|
| 194 |
+
tags = {r for r in refs if re.search(r'\d', r)}
|
| 195 |
+
if verbose:
|
| 196 |
+
print("discarding '%s', no digits" % ",".join(refs - tags))
|
| 197 |
+
if verbose:
|
| 198 |
+
print("likely tags: %s" % ",".join(sorted(tags)))
|
| 199 |
+
for ref in sorted(tags):
|
| 200 |
+
# sorting will prefer e.g. "2.0" over "2.0rc1"
|
| 201 |
+
if ref.startswith(tag_prefix):
|
| 202 |
+
r = ref[len(tag_prefix):]
|
| 203 |
+
# Filter out refs that exactly match prefix or that don't start
|
| 204 |
+
# with a number once the prefix is stripped (mostly a concern
|
| 205 |
+
# when prefix is '')
|
| 206 |
+
if not re.match(r'\d', r):
|
| 207 |
+
continue
|
| 208 |
+
if verbose:
|
| 209 |
+
print("picking %s" % r)
|
| 210 |
+
return {"version": r,
|
| 211 |
+
"full-revisionid": keywords["full"].strip(),
|
| 212 |
+
"dirty": False, "error": None,
|
| 213 |
+
"date": date}
|
| 214 |
+
# no suitable tags, so version is "0+unknown", but full hex is still there
|
| 215 |
+
if verbose:
|
| 216 |
+
print("no suitable tags, using unknown + full revision id")
|
| 217 |
+
return {"version": "0+unknown",
|
| 218 |
+
"full-revisionid": keywords["full"].strip(),
|
| 219 |
+
"dirty": False, "error": "no suitable tags", "date": None}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@register_vcs_handler("git", "pieces_from_vcs")
|
| 223 |
+
def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
|
| 224 |
+
"""Get version from 'git describe' in the root of the source tree.
|
| 225 |
+
|
| 226 |
+
This only gets called if the git-archive 'subst' keywords were *not*
|
| 227 |
+
expanded, and _version.py hasn't already been rewritten with a short
|
| 228 |
+
version string, meaning we're inside a checked out source tree.
|
| 229 |
+
"""
|
| 230 |
+
GITS = ["git"]
|
| 231 |
+
if sys.platform == "win32":
|
| 232 |
+
GITS = ["git.cmd", "git.exe"]
|
| 233 |
+
|
| 234 |
+
_, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
|
| 235 |
+
hide_stderr=True)
|
| 236 |
+
if rc != 0:
|
| 237 |
+
if verbose:
|
| 238 |
+
print("Directory %s not under git control" % root)
|
| 239 |
+
raise NotThisMethod("'git rev-parse --git-dir' returned error")
|
| 240 |
+
|
| 241 |
+
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
|
| 242 |
+
# if there isn't one, this yields HEX[-dirty] (no NUM)
|
| 243 |
+
describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty",
|
| 244 |
+
"--always", "--long",
|
| 245 |
+
"--match", "%s*" % tag_prefix],
|
| 246 |
+
cwd=root)
|
| 247 |
+
# --long was added in git-1.5.5
|
| 248 |
+
if describe_out is None:
|
| 249 |
+
raise NotThisMethod("'git describe' failed")
|
| 250 |
+
describe_out = describe_out.strip()
|
| 251 |
+
full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
|
| 252 |
+
if full_out is None:
|
| 253 |
+
raise NotThisMethod("'git rev-parse' failed")
|
| 254 |
+
full_out = full_out.strip()
|
| 255 |
+
|
| 256 |
+
pieces = {}
|
| 257 |
+
pieces["long"] = full_out
|
| 258 |
+
pieces["short"] = full_out[:7] # maybe improved later
|
| 259 |
+
pieces["error"] = None
|
| 260 |
+
|
| 261 |
+
branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
|
| 262 |
+
cwd=root)
|
| 263 |
+
# --abbrev-ref was added in git-1.6.3
|
| 264 |
+
if rc != 0 or branch_name is None:
|
| 265 |
+
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
|
| 266 |
+
branch_name = branch_name.strip()
|
| 267 |
+
|
| 268 |
+
if branch_name == "HEAD":
|
| 269 |
+
# If we aren't exactly on a branch, pick a branch which represents
|
| 270 |
+
# the current commit. If all else fails, we are on a branchless
|
| 271 |
+
# commit.
|
| 272 |
+
branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
|
| 273 |
+
# --contains was added in git-1.5.4
|
| 274 |
+
if rc != 0 or branches is None:
|
| 275 |
+
raise NotThisMethod("'git branch --contains' returned error")
|
| 276 |
+
branches = branches.split("\n")
|
| 277 |
+
|
| 278 |
+
# Remove the first line if we're running detached
|
| 279 |
+
if "(" in branches[0]:
|
| 280 |
+
branches.pop(0)
|
| 281 |
+
|
| 282 |
+
# Strip off the leading "* " from the list of branches.
|
| 283 |
+
branches = [branch[2:] for branch in branches]
|
| 284 |
+
if "master" in branches:
|
| 285 |
+
branch_name = "master"
|
| 286 |
+
elif not branches:
|
| 287 |
+
branch_name = None
|
| 288 |
+
else:
|
| 289 |
+
# Pick the first branch that is returned. Good or bad.
|
| 290 |
+
branch_name = branches[0]
|
| 291 |
+
|
| 292 |
+
pieces["branch"] = branch_name
|
| 293 |
+
|
| 294 |
+
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
|
| 295 |
+
# TAG might have hyphens.
|
| 296 |
+
git_describe = describe_out
|
| 297 |
+
|
| 298 |
+
# look for -dirty suffix
|
| 299 |
+
dirty = git_describe.endswith("-dirty")
|
| 300 |
+
pieces["dirty"] = dirty
|
| 301 |
+
if dirty:
|
| 302 |
+
git_describe = git_describe[:git_describe.rindex("-dirty")]
|
| 303 |
+
|
| 304 |
+
# now we have TAG-NUM-gHEX or HEX
|
| 305 |
+
|
| 306 |
+
if "-" in git_describe:
|
| 307 |
+
# TAG-NUM-gHEX
|
| 308 |
+
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
|
| 309 |
+
if not mo:
|
| 310 |
+
# unparseable. Maybe git-describe is misbehaving?
|
| 311 |
+
pieces["error"] = ("unable to parse git-describe output: '%s'"
|
| 312 |
+
% describe_out)
|
| 313 |
+
return pieces
|
| 314 |
+
|
| 315 |
+
# tag
|
| 316 |
+
full_tag = mo.group(1)
|
| 317 |
+
if not full_tag.startswith(tag_prefix):
|
| 318 |
+
if verbose:
|
| 319 |
+
fmt = "tag '%s' doesn't start with prefix '%s'"
|
| 320 |
+
print(fmt % (full_tag, tag_prefix))
|
| 321 |
+
pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
|
| 322 |
+
% (full_tag, tag_prefix))
|
| 323 |
+
return pieces
|
| 324 |
+
pieces["closest-tag"] = full_tag[len(tag_prefix):]
|
| 325 |
+
|
| 326 |
+
# distance: number of commits since tag
|
| 327 |
+
pieces["distance"] = int(mo.group(2))
|
| 328 |
+
|
| 329 |
+
# commit: short hex revision ID
|
| 330 |
+
pieces["short"] = mo.group(3)
|
| 331 |
+
|
| 332 |
+
else:
|
| 333 |
+
# HEX: no tags
|
| 334 |
+
pieces["closest-tag"] = None
|
| 335 |
+
count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
|
| 336 |
+
pieces["distance"] = int(count_out) # total number of commits
|
| 337 |
+
|
| 338 |
+
# commit date: see ISO-8601 comment in git_versions_from_keywords()
|
| 339 |
+
date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
|
| 340 |
+
# Use only the last line. Previous lines may contain GPG signature
|
| 341 |
+
# information.
|
| 342 |
+
date = date.splitlines()[-1]
|
| 343 |
+
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
|
| 344 |
+
|
| 345 |
+
return pieces
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def plus_or_dot(pieces):
|
| 349 |
+
"""Return a + if we don't already have one, else return a ."""
|
| 350 |
+
if "+" in pieces.get("closest-tag", ""):
|
| 351 |
+
return "."
|
| 352 |
+
return "+"
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def render_pep440(pieces):
|
| 356 |
+
"""Build up version string, with post-release "local version identifier".
|
| 357 |
+
|
| 358 |
+
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
|
| 359 |
+
get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
|
| 360 |
+
|
| 361 |
+
Exceptions:
|
| 362 |
+
1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
|
| 363 |
+
"""
|
| 364 |
+
if pieces["closest-tag"]:
|
| 365 |
+
rendered = pieces["closest-tag"]
|
| 366 |
+
if pieces["distance"] or pieces["dirty"]:
|
| 367 |
+
rendered += plus_or_dot(pieces)
|
| 368 |
+
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
|
| 369 |
+
if pieces["dirty"]:
|
| 370 |
+
rendered += ".dirty"
|
| 371 |
+
else:
|
| 372 |
+
# exception #1
|
| 373 |
+
rendered = "0+untagged.%d.g%s" % (pieces["distance"],
|
| 374 |
+
pieces["short"])
|
| 375 |
+
if pieces["dirty"]:
|
| 376 |
+
rendered += ".dirty"
|
| 377 |
+
return rendered
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def render_pep440_branch(pieces):
|
| 381 |
+
"""TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
|
| 382 |
+
|
| 383 |
+
The ".dev0" means not master branch. Note that .dev0 sorts backwards
|
| 384 |
+
(a feature branch will appear "older" than the master branch).
|
| 385 |
+
|
| 386 |
+
Exceptions:
|
| 387 |
+
1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
|
| 388 |
+
"""
|
| 389 |
+
if pieces["closest-tag"]:
|
| 390 |
+
rendered = pieces["closest-tag"]
|
| 391 |
+
if pieces["distance"] or pieces["dirty"]:
|
| 392 |
+
if pieces["branch"] != "master":
|
| 393 |
+
rendered += ".dev0"
|
| 394 |
+
rendered += plus_or_dot(pieces)
|
| 395 |
+
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
|
| 396 |
+
if pieces["dirty"]:
|
| 397 |
+
rendered += ".dirty"
|
| 398 |
+
else:
|
| 399 |
+
# exception #1
|
| 400 |
+
rendered = "0"
|
| 401 |
+
if pieces["branch"] != "master":
|
| 402 |
+
rendered += ".dev0"
|
| 403 |
+
rendered += "+untagged.%d.g%s" % (pieces["distance"],
|
| 404 |
+
pieces["short"])
|
| 405 |
+
if pieces["dirty"]:
|
| 406 |
+
rendered += ".dirty"
|
| 407 |
+
return rendered
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def render_pep440_pre(pieces):
|
| 411 |
+
"""TAG[.post0.devDISTANCE] -- No -dirty.
|
| 412 |
+
|
| 413 |
+
Exceptions:
|
| 414 |
+
1: no tags. 0.post0.devDISTANCE
|
| 415 |
+
"""
|
| 416 |
+
if pieces["closest-tag"]:
|
| 417 |
+
rendered = pieces["closest-tag"]
|
| 418 |
+
if pieces["distance"]:
|
| 419 |
+
rendered += ".post0.dev%d" % pieces["distance"]
|
| 420 |
+
else:
|
| 421 |
+
# exception #1
|
| 422 |
+
rendered = "0.post0.dev%d" % pieces["distance"]
|
| 423 |
+
return rendered
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def render_pep440_post(pieces):
|
| 427 |
+
"""TAG[.postDISTANCE[.dev0]+gHEX] .
|
| 428 |
+
|
| 429 |
+
The ".dev0" means dirty. Note that .dev0 sorts backwards
|
| 430 |
+
(a dirty tree will appear "older" than the corresponding clean one),
|
| 431 |
+
but you shouldn't be releasing software with -dirty anyways.
|
| 432 |
+
|
| 433 |
+
Exceptions:
|
| 434 |
+
1: no tags. 0.postDISTANCE[.dev0]
|
| 435 |
+
"""
|
| 436 |
+
if pieces["closest-tag"]:
|
| 437 |
+
rendered = pieces["closest-tag"]
|
| 438 |
+
if pieces["distance"] or pieces["dirty"]:
|
| 439 |
+
rendered += ".post%d" % pieces["distance"]
|
| 440 |
+
if pieces["dirty"]:
|
| 441 |
+
rendered += ".dev0"
|
| 442 |
+
rendered += plus_or_dot(pieces)
|
| 443 |
+
rendered += "g%s" % pieces["short"]
|
| 444 |
+
else:
|
| 445 |
+
# exception #1
|
| 446 |
+
rendered = "0.post%d" % pieces["distance"]
|
| 447 |
+
if pieces["dirty"]:
|
| 448 |
+
rendered += ".dev0"
|
| 449 |
+
rendered += "+g%s" % pieces["short"]
|
| 450 |
+
return rendered
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def render_pep440_post_branch(pieces):
|
| 454 |
+
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
|
| 455 |
+
|
| 456 |
+
The ".dev0" means not master branch.
|
| 457 |
+
|
| 458 |
+
Exceptions:
|
| 459 |
+
1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
|
| 460 |
+
"""
|
| 461 |
+
if pieces["closest-tag"]:
|
| 462 |
+
rendered = pieces["closest-tag"]
|
| 463 |
+
if pieces["distance"] or pieces["dirty"]:
|
| 464 |
+
rendered += ".post%d" % pieces["distance"]
|
| 465 |
+
if pieces["branch"] != "master":
|
| 466 |
+
rendered += ".dev0"
|
| 467 |
+
rendered += plus_or_dot(pieces)
|
| 468 |
+
rendered += "g%s" % pieces["short"]
|
| 469 |
+
if pieces["dirty"]:
|
| 470 |
+
rendered += ".dirty"
|
| 471 |
+
else:
|
| 472 |
+
# exception #1
|
| 473 |
+
rendered = "0.post%d" % pieces["distance"]
|
| 474 |
+
if pieces["branch"] != "master":
|
| 475 |
+
rendered += ".dev0"
|
| 476 |
+
rendered += "+g%s" % pieces["short"]
|
| 477 |
+
if pieces["dirty"]:
|
| 478 |
+
rendered += ".dirty"
|
| 479 |
+
return rendered
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def render_pep440_old(pieces):
|
| 483 |
+
"""TAG[.postDISTANCE[.dev0]] .
|
| 484 |
+
|
| 485 |
+
The ".dev0" means dirty.
|
| 486 |
+
|
| 487 |
+
Exceptions:
|
| 488 |
+
1: no tags. 0.postDISTANCE[.dev0]
|
| 489 |
+
"""
|
| 490 |
+
if pieces["closest-tag"]:
|
| 491 |
+
rendered = pieces["closest-tag"]
|
| 492 |
+
if pieces["distance"] or pieces["dirty"]:
|
| 493 |
+
rendered += ".post%d" % pieces["distance"]
|
| 494 |
+
if pieces["dirty"]:
|
| 495 |
+
rendered += ".dev0"
|
| 496 |
+
else:
|
| 497 |
+
# exception #1
|
| 498 |
+
rendered = "0.post%d" % pieces["distance"]
|
| 499 |
+
if pieces["dirty"]:
|
| 500 |
+
rendered += ".dev0"
|
| 501 |
+
return rendered
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def render_git_describe(pieces):
|
| 505 |
+
"""TAG[-DISTANCE-gHEX][-dirty].
|
| 506 |
+
|
| 507 |
+
Like 'git describe --tags --dirty --always'.
|
| 508 |
+
|
| 509 |
+
Exceptions:
|
| 510 |
+
1: no tags. HEX[-dirty] (note: no 'g' prefix)
|
| 511 |
+
"""
|
| 512 |
+
if pieces["closest-tag"]:
|
| 513 |
+
rendered = pieces["closest-tag"]
|
| 514 |
+
if pieces["distance"]:
|
| 515 |
+
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
|
| 516 |
+
else:
|
| 517 |
+
# exception #1
|
| 518 |
+
rendered = pieces["short"]
|
| 519 |
+
if pieces["dirty"]:
|
| 520 |
+
rendered += "-dirty"
|
| 521 |
+
return rendered
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def render_git_describe_long(pieces):
|
| 525 |
+
"""TAG-DISTANCE-gHEX[-dirty].
|
| 526 |
+
|
| 527 |
+
Like 'git describe --tags --dirty --always -long'.
|
| 528 |
+
The distance/hash is unconditional.
|
| 529 |
+
|
| 530 |
+
Exceptions:
|
| 531 |
+
1: no tags. HEX[-dirty] (note: no 'g' prefix)
|
| 532 |
+
"""
|
| 533 |
+
if pieces["closest-tag"]:
|
| 534 |
+
rendered = pieces["closest-tag"]
|
| 535 |
+
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
|
| 536 |
+
else:
|
| 537 |
+
# exception #1
|
| 538 |
+
rendered = pieces["short"]
|
| 539 |
+
if pieces["dirty"]:
|
| 540 |
+
rendered += "-dirty"
|
| 541 |
+
return rendered
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def render(pieces, style):
|
| 545 |
+
"""Render the given version pieces into the requested style."""
|
| 546 |
+
if pieces["error"]:
|
| 547 |
+
return {"version": "unknown",
|
| 548 |
+
"full-revisionid": pieces.get("long"),
|
| 549 |
+
"dirty": None,
|
| 550 |
+
"error": pieces["error"],
|
| 551 |
+
"date": None}
|
| 552 |
+
|
| 553 |
+
if not style or style == "default":
|
| 554 |
+
style = "pep440" # the default
|
| 555 |
+
|
| 556 |
+
if style == "pep440":
|
| 557 |
+
rendered = render_pep440(pieces)
|
| 558 |
+
elif style == "pep440-branch":
|
| 559 |
+
rendered = render_pep440_branch(pieces)
|
| 560 |
+
elif style == "pep440-pre":
|
| 561 |
+
rendered = render_pep440_pre(pieces)
|
| 562 |
+
elif style == "pep440-post":
|
| 563 |
+
rendered = render_pep440_post(pieces)
|
| 564 |
+
elif style == "pep440-post-branch":
|
| 565 |
+
rendered = render_pep440_post_branch(pieces)
|
| 566 |
+
elif style == "pep440-old":
|
| 567 |
+
rendered = render_pep440_old(pieces)
|
| 568 |
+
elif style == "git-describe":
|
| 569 |
+
rendered = render_git_describe(pieces)
|
| 570 |
+
elif style == "git-describe-long":
|
| 571 |
+
rendered = render_git_describe_long(pieces)
|
| 572 |
+
else:
|
| 573 |
+
raise ValueError("unknown style '%s'" % style)
|
| 574 |
+
|
| 575 |
+
return {"version": rendered, "full-revisionid": pieces["long"],
|
| 576 |
+
"dirty": pieces["dirty"], "error": None,
|
| 577 |
+
"date": pieces.get("date")}
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def get_versions():
|
| 581 |
+
"""Get version information or return default if unable to do so."""
|
| 582 |
+
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
|
| 583 |
+
# __file__, we can work backwards from there to the root. Some
|
| 584 |
+
# py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
|
| 585 |
+
# case we can only use expanded keywords.
|
| 586 |
+
|
| 587 |
+
cfg = get_config()
|
| 588 |
+
verbose = cfg.verbose
|
| 589 |
+
|
| 590 |
+
try:
|
| 591 |
+
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
|
| 592 |
+
verbose)
|
| 593 |
+
except NotThisMethod:
|
| 594 |
+
pass
|
| 595 |
+
|
| 596 |
+
try:
|
| 597 |
+
root = os.path.realpath(__file__)
|
| 598 |
+
# versionfile_source is the relative path from the top of the source
|
| 599 |
+
# tree (where the .git directory might live) to this file. Invert
|
| 600 |
+
# this to find the root from __file__.
|
| 601 |
+
for _ in cfg.versionfile_source.split('/'):
|
| 602 |
+
root = os.path.dirname(root)
|
| 603 |
+
except NameError:
|
| 604 |
+
return {"version": "0+unknown", "full-revisionid": None,
|
| 605 |
+
"dirty": None,
|
| 606 |
+
"error": "unable to find root of source tree",
|
| 607 |
+
"date": None}
|
| 608 |
+
|
| 609 |
+
try:
|
| 610 |
+
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
|
| 611 |
+
return render(pieces, cfg.style)
|
| 612 |
+
except NotThisMethod:
|
| 613 |
+
pass
|
| 614 |
+
|
| 615 |
+
try:
|
| 616 |
+
if cfg.parentdir_prefix:
|
| 617 |
+
return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
|
| 618 |
+
except NotThisMethod:
|
| 619 |
+
pass
|
| 620 |
+
|
| 621 |
+
return {"version": "0+unknown", "full-revisionid": None,
|
| 622 |
+
"dirty": None,
|
| 623 |
+
"error": "unable to compute version", "date": None}
|
Generator/interpol/api.py
ADDED
|
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""High level interpolation API"""
|
| 2 |
+
|
| 3 |
+
__all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad',
|
| 4 |
+
'spline_coeff', 'spline_coeff_nd',
|
| 5 |
+
'identity_grid', 'add_identity_grid', 'add_identity_grid_']
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from .utils import expanded_shape, matvec
|
| 9 |
+
from .jit_utils import movedim1, meshgrid
|
| 10 |
+
from .autograd import (GridPull, GridPush, GridCount, GridGrad,
|
| 11 |
+
SplineCoeff, SplineCoeffND)
|
| 12 |
+
from . import backend, jitfields
|
| 13 |
+
|
| 14 |
+
_doc_interpolation = \
|
| 15 |
+
"""`interpolation` can be an int, a string or an InterpolationType.
|
| 16 |
+
Possible values are:
|
| 17 |
+
- 0 or 'nearest'
|
| 18 |
+
- 1 or 'linear'
|
| 19 |
+
- 2 or 'quadratic'
|
| 20 |
+
- 3 or 'cubic'
|
| 21 |
+
- 4 or 'fourth'
|
| 22 |
+
- 5 or 'fifth'
|
| 23 |
+
- etc.
|
| 24 |
+
A list of values can be provided, in the order [W, H, D],
|
| 25 |
+
to specify dimension-specific interpolation orders."""
|
| 26 |
+
|
| 27 |
+
_doc_bound = \
|
| 28 |
+
"""`bound` can be an int, a string or a BoundType.
|
| 29 |
+
Possible values are:
|
| 30 |
+
- 'replicate' or 'nearest' : a a a | a b c d | d d d
|
| 31 |
+
- 'dct1' or 'mirror' : d c b | a b c d | c b a
|
| 32 |
+
- 'dct2' or 'reflect' : c b a | a b c d | d c b
|
| 33 |
+
- 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c
|
| 34 |
+
- 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b
|
| 35 |
+
- 'dft' or 'wrap' : b c d | a b c d | a b c
|
| 36 |
+
- 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0
|
| 37 |
+
A list of values can be provided, in the order [W, H, D],
|
| 38 |
+
to specify dimension-specific boundary conditions.
|
| 39 |
+
Note that
|
| 40 |
+
- `dft` corresponds to circular padding
|
| 41 |
+
- `dct2` corresponds to Neumann boundary conditions (symmetric)
|
| 42 |
+
- `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)
|
| 43 |
+
See https://en.wikipedia.org/wiki/Discrete_cosine_transform
|
| 44 |
+
https://en.wikipedia.org/wiki/Discrete_sine_transform"""
|
| 45 |
+
|
| 46 |
+
_doc_bound_coeff = \
|
| 47 |
+
"""`bound` can be an int, a string or a BoundType.
|
| 48 |
+
Possible values are:
|
| 49 |
+
- 'replicate' or 'nearest' : a a a | a b c d | d d d
|
| 50 |
+
- 'dct1' or 'mirror' : d c b | a b c d | c b a
|
| 51 |
+
- 'dct2' or 'reflect' : c b a | a b c d | d c b
|
| 52 |
+
- 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c
|
| 53 |
+
- 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b
|
| 54 |
+
- 'dft' or 'wrap' : b c d | a b c d | a b c
|
| 55 |
+
- 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0
|
| 56 |
+
A list of values can be provided, in the order [W, H, D],
|
| 57 |
+
to specify dimension-specific boundary conditions.
|
| 58 |
+
Note that
|
| 59 |
+
- `dft` corresponds to circular padding
|
| 60 |
+
- `dct1` corresponds to mirroring about the center of the first/last voxel
|
| 61 |
+
- `dct2` corresponds to mirroring about the edge of the first/last voxel
|
| 62 |
+
See https://en.wikipedia.org/wiki/Discrete_cosine_transform
|
| 63 |
+
https://en.wikipedia.org/wiki/Discrete_sine_transform
|
| 64 |
+
|
| 65 |
+
/!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation
|
| 66 |
+
orders >= 6."""
|
| 67 |
+
|
| 68 |
+
_ref_coeff = \
|
| 69 |
+
"""..[1] M. Unser, A. Aldroubi and M. Eden.
|
| 70 |
+
"B-Spline Signal Processing: Part I-Theory,"
|
| 71 |
+
IEEE Transactions on Signal Processing 41(2):821-832 (1993).
|
| 72 |
+
..[2] M. Unser, A. Aldroubi and M. Eden.
|
| 73 |
+
"B-Spline Signal Processing: Part II-Efficient Design and Applications,"
|
| 74 |
+
IEEE Transactions on Signal Processing 41(2):834-848 (1993).
|
| 75 |
+
..[3] M. Unser.
|
| 76 |
+
"Splines: A Perfect Fit for Signal and Image Processing,"
|
| 77 |
+
IEEE Signal Processing Magazine 16(6):22-38 (1999).
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _preproc(grid, input=None, mode=None):
|
| 82 |
+
"""Preprocess tensors for pull/push/count/grad
|
| 83 |
+
|
| 84 |
+
Low level bindings expect inputs of shape
|
| 85 |
+
[batch, channel, *spatial] and [batch, *spatial, dim], whereas
|
| 86 |
+
the high level python API accepts inputs of shape
|
| 87 |
+
[..., [channel], *spatial] and [..., *spatial, dim].
|
| 88 |
+
|
| 89 |
+
This function broadcasts and reshapes the input tensors accordingly.
|
| 90 |
+
/!\\ This *can* trigger large allocations /!\\
|
| 91 |
+
"""
|
| 92 |
+
dim = grid.shape[-1]
|
| 93 |
+
if input is None:
|
| 94 |
+
spatial = grid.shape[-dim-1:-1]
|
| 95 |
+
batch = grid.shape[:-dim-1]
|
| 96 |
+
grid = grid.reshape([-1, *spatial, dim])
|
| 97 |
+
info = dict(batch=batch, channel=[1] if batch else [], dim=dim)
|
| 98 |
+
return grid, info
|
| 99 |
+
|
| 100 |
+
grid_spatial = grid.shape[-dim-1:-1]
|
| 101 |
+
grid_batch = grid.shape[:-dim-1]
|
| 102 |
+
input_spatial = input.shape[-dim:]
|
| 103 |
+
channel = 0 if input.dim() == dim else input.shape[-dim-1]
|
| 104 |
+
input_batch = input.shape[:-dim-1]
|
| 105 |
+
|
| 106 |
+
if mode == 'push':
|
| 107 |
+
grid_spatial = input_spatial = expanded_shape(grid_spatial, input_spatial)
|
| 108 |
+
|
| 109 |
+
# broadcast and reshape
|
| 110 |
+
batch = expanded_shape(grid_batch, input_batch)
|
| 111 |
+
grid = grid.expand([*batch, *grid_spatial, dim])
|
| 112 |
+
grid = grid.reshape([-1, *grid_spatial, dim])
|
| 113 |
+
input = input.expand([*batch, channel or 1, *input_spatial])
|
| 114 |
+
input = input.reshape([-1, channel or 1, *input_spatial])
|
| 115 |
+
|
| 116 |
+
out_channel = [channel] if channel else ([1] if batch else [])
|
| 117 |
+
info = dict(batch=batch, channel=out_channel, dim=dim)
|
| 118 |
+
return grid, input, info
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _postproc(out, shape_info, mode):
|
| 122 |
+
"""Postprocess tensors for pull/push/count/grad"""
|
| 123 |
+
dim = shape_info['dim']
|
| 124 |
+
if mode != 'grad':
|
| 125 |
+
spatial = out.shape[-dim:]
|
| 126 |
+
feat = []
|
| 127 |
+
else:
|
| 128 |
+
spatial = out.shape[-dim-1:-1]
|
| 129 |
+
feat = [out.shape[-1]]
|
| 130 |
+
batch = shape_info['batch']
|
| 131 |
+
channel = shape_info['channel']
|
| 132 |
+
|
| 133 |
+
out = out.reshape([*batch, *channel, *spatial, *feat])
|
| 134 |
+
return out
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def grid_pull(input, grid, interpolation='linear', bound='zero',
|
| 138 |
+
extrapolate=False, prefilter=False):
|
| 139 |
+
"""Sample an image with respect to a deformation field.
|
| 140 |
+
|
| 141 |
+
Notes
|
| 142 |
+
-----
|
| 143 |
+
{interpolation}
|
| 144 |
+
|
| 145 |
+
{bound}
|
| 146 |
+
|
| 147 |
+
If the input dtype is not a floating point type, the input image is
|
| 148 |
+
assumed to contain labels. Then, unique labels are extracted
|
| 149 |
+
and resampled individually, making them soft labels. Finally,
|
| 150 |
+
the label map is reconstructed from the individual soft labels by
|
| 151 |
+
assigning the label with maximum soft value.
|
| 152 |
+
|
| 153 |
+
Parameters
|
| 154 |
+
----------
|
| 155 |
+
input : (..., [channel], *inshape) tensor
|
| 156 |
+
Input image.
|
| 157 |
+
grid : (..., *outshape, dim) tensor
|
| 158 |
+
Transformation field.
|
| 159 |
+
interpolation : int or sequence[int], default=1
|
| 160 |
+
Interpolation order.
|
| 161 |
+
bound : BoundType or sequence[BoundType], default='zero'
|
| 162 |
+
Boundary conditions.
|
| 163 |
+
extrapolate : bool or int, default=True
|
| 164 |
+
Extrapolate out-of-bound data.
|
| 165 |
+
prefilter : bool, default=False
|
| 166 |
+
Apply spline pre-filter (= interpolates the input)
|
| 167 |
+
|
| 168 |
+
Returns
|
| 169 |
+
-------
|
| 170 |
+
output : (..., [channel], *outshape) tensor
|
| 171 |
+
Deformed image.
|
| 172 |
+
|
| 173 |
+
"""
|
| 174 |
+
if backend.jitfields and jitfields.available:
|
| 175 |
+
return jitfields.grid_pull(input, grid, interpolation, bound,
|
| 176 |
+
extrapolate, prefilter)
|
| 177 |
+
|
| 178 |
+
grid, input, shape_info = _preproc(grid, input)
|
| 179 |
+
batch, channel = input.shape[:2]
|
| 180 |
+
dim = grid.shape[-1]
|
| 181 |
+
|
| 182 |
+
if not input.dtype.is_floating_point:
|
| 183 |
+
# label map -> specific processing
|
| 184 |
+
out = input.new_zeros([batch, channel, *grid.shape[1:-1]])
|
| 185 |
+
pmax = grid.new_zeros([batch, channel, *grid.shape[1:-1]])
|
| 186 |
+
for label in input.unique():
|
| 187 |
+
soft = (input == label).to(grid.dtype)
|
| 188 |
+
if prefilter:
|
| 189 |
+
input = spline_coeff_nd(soft, interpolation=interpolation,
|
| 190 |
+
bound=bound, dim=dim, inplace=True)
|
| 191 |
+
soft = GridPull.apply(soft, grid, interpolation, bound, extrapolate)
|
| 192 |
+
out[soft > pmax] = label
|
| 193 |
+
pmax = torch.max(pmax, soft)
|
| 194 |
+
else:
|
| 195 |
+
if prefilter:
|
| 196 |
+
input = spline_coeff_nd(input, interpolation=interpolation,
|
| 197 |
+
bound=bound, dim=dim)
|
| 198 |
+
out = GridPull.apply(input, grid, interpolation, bound, extrapolate)
|
| 199 |
+
|
| 200 |
+
return _postproc(out, shape_info, mode='pull')
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
|
| 204 |
+
extrapolate=False, prefilter=False):
|
| 205 |
+
"""Splat an image with respect to a deformation field (pull adjoint).
|
| 206 |
+
|
| 207 |
+
Notes
|
| 208 |
+
-----
|
| 209 |
+
{interpolation}
|
| 210 |
+
|
| 211 |
+
{bound}
|
| 212 |
+
|
| 213 |
+
Parameters
|
| 214 |
+
----------
|
| 215 |
+
input : (..., [channel], *inshape) tensor
|
| 216 |
+
Input image.
|
| 217 |
+
grid : (..., *inshape, dim) tensor
|
| 218 |
+
Transformation field.
|
| 219 |
+
shape : sequence[int], default=inshape
|
| 220 |
+
Output shape
|
| 221 |
+
interpolation : int or sequence[int], default=1
|
| 222 |
+
Interpolation order.
|
| 223 |
+
bound : BoundType, or sequence[BoundType], default='zero'
|
| 224 |
+
Boundary conditions.
|
| 225 |
+
extrapolate : bool or int, default=True
|
| 226 |
+
Extrapolate out-of-bound data.
|
| 227 |
+
prefilter : bool, default=False
|
| 228 |
+
Apply spline pre-filter.
|
| 229 |
+
|
| 230 |
+
Returns
|
| 231 |
+
-------
|
| 232 |
+
output : (..., [channel], *shape) tensor
|
| 233 |
+
Spatted image.
|
| 234 |
+
|
| 235 |
+
"""
|
| 236 |
+
if backend.jitfields and jitfields.available:
|
| 237 |
+
return jitfields.grid_push(input, grid, shape, interpolation, bound,
|
| 238 |
+
extrapolate, prefilter)
|
| 239 |
+
|
| 240 |
+
grid, input, shape_info = _preproc(grid, input, mode='push')
|
| 241 |
+
dim = grid.shape[-1]
|
| 242 |
+
|
| 243 |
+
if shape is None:
|
| 244 |
+
shape = tuple(input.shape[2:])
|
| 245 |
+
|
| 246 |
+
out = GridPush.apply(input, grid, shape, interpolation, bound, extrapolate)
|
| 247 |
+
if prefilter:
|
| 248 |
+
out = spline_coeff_nd(out, interpolation=interpolation, bound=bound,
|
| 249 |
+
dim=dim, inplace=True)
|
| 250 |
+
return _postproc(out, shape_info, mode='push')
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def grid_count(grid, shape=None, interpolation='linear', bound='zero',
|
| 254 |
+
extrapolate=False):
|
| 255 |
+
"""Splatting weights with respect to a deformation field (pull adjoint).
|
| 256 |
+
|
| 257 |
+
Notes
|
| 258 |
+
-----
|
| 259 |
+
{interpolation}
|
| 260 |
+
|
| 261 |
+
{bound}
|
| 262 |
+
|
| 263 |
+
Parameters
|
| 264 |
+
----------
|
| 265 |
+
grid : (..., *inshape, dim) tensor
|
| 266 |
+
Transformation field.
|
| 267 |
+
shape : sequence[int], default=inshape
|
| 268 |
+
Output shape
|
| 269 |
+
interpolation : int or sequence[int], default=1
|
| 270 |
+
Interpolation order.
|
| 271 |
+
bound : BoundType, or sequence[BoundType], default='zero'
|
| 272 |
+
Boundary conditions.
|
| 273 |
+
extrapolate : bool or int, default=True
|
| 274 |
+
Extrapolate out-of-bound data.
|
| 275 |
+
|
| 276 |
+
Returns
|
| 277 |
+
-------
|
| 278 |
+
output : (..., [1], *shape) tensor
|
| 279 |
+
Splatted weights.
|
| 280 |
+
|
| 281 |
+
"""
|
| 282 |
+
if backend.jitfields and jitfields.available:
|
| 283 |
+
return jitfields.grid_count(grid, shape, interpolation, bound, extrapolate)
|
| 284 |
+
|
| 285 |
+
grid, shape_info = _preproc(grid)
|
| 286 |
+
out = GridCount.apply(grid, shape, interpolation, bound, extrapolate)
|
| 287 |
+
return _postproc(out, shape_info, mode='count')
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def grid_grad(input, grid, interpolation='linear', bound='zero',
|
| 291 |
+
extrapolate=False, prefilter=False):
|
| 292 |
+
"""Sample spatial gradients of an image with respect to a deformation field.
|
| 293 |
+
|
| 294 |
+
Notes
|
| 295 |
+
-----
|
| 296 |
+
{interpolation}
|
| 297 |
+
|
| 298 |
+
{bound}
|
| 299 |
+
|
| 300 |
+
Parameters
|
| 301 |
+
----------
|
| 302 |
+
input : (..., [channel], *inshape) tensor
|
| 303 |
+
Input image.
|
| 304 |
+
grid : (..., *inshape, dim) tensor
|
| 305 |
+
Transformation field.
|
| 306 |
+
shape : sequence[int], default=inshape
|
| 307 |
+
Output shape
|
| 308 |
+
interpolation : int or sequence[int], default=1
|
| 309 |
+
Interpolation order.
|
| 310 |
+
bound : BoundType, or sequence[BoundType], default='zero'
|
| 311 |
+
Boundary conditions.
|
| 312 |
+
extrapolate : bool or int, default=True
|
| 313 |
+
Extrapolate out-of-bound data.
|
| 314 |
+
prefilter : bool, default=False
|
| 315 |
+
Apply spline pre-filter (= interpolates the input)
|
| 316 |
+
|
| 317 |
+
Returns
|
| 318 |
+
-------
|
| 319 |
+
output : (..., [channel], *shape, dim) tensor
|
| 320 |
+
Sampled gradients.
|
| 321 |
+
|
| 322 |
+
"""
|
| 323 |
+
if backend.jitfields and jitfields.available:
|
| 324 |
+
return jitfields.grid_grad(input, grid, interpolation, bound,
|
| 325 |
+
extrapolate, prefilter)
|
| 326 |
+
|
| 327 |
+
grid, input, shape_info = _preproc(grid, input)
|
| 328 |
+
dim = grid.shape[-1]
|
| 329 |
+
if prefilter:
|
| 330 |
+
input = spline_coeff_nd(input, interpolation, bound, dim)
|
| 331 |
+
out = GridGrad.apply(input, grid, interpolation, bound, extrapolate)
|
| 332 |
+
return _postproc(out, shape_info, mode='grad')
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
|
| 336 |
+
inplace=False):
|
| 337 |
+
"""Compute the interpolating spline coefficients, for a given spline order
|
| 338 |
+
and boundary conditions, along a single dimension.
|
| 339 |
+
|
| 340 |
+
Notes
|
| 341 |
+
-----
|
| 342 |
+
{interpolation}
|
| 343 |
+
|
| 344 |
+
{bound}
|
| 345 |
+
|
| 346 |
+
References
|
| 347 |
+
----------
|
| 348 |
+
{ref}
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
Parameters
|
| 352 |
+
----------
|
| 353 |
+
input : tensor
|
| 354 |
+
Input image.
|
| 355 |
+
interpolation : int or sequence[int], default=1
|
| 356 |
+
Interpolation order.
|
| 357 |
+
bound : BoundType or sequence[BoundType], default='dct1'
|
| 358 |
+
Boundary conditions.
|
| 359 |
+
dim : int, default=-1
|
| 360 |
+
Dimension along which to process
|
| 361 |
+
inplace : bool, default=False
|
| 362 |
+
Process the volume in place.
|
| 363 |
+
|
| 364 |
+
Returns
|
| 365 |
+
-------
|
| 366 |
+
output : tensor
|
| 367 |
+
Coefficient image.
|
| 368 |
+
|
| 369 |
+
"""
|
| 370 |
+
# This implementation is based on the file bsplines.c in SPM12, written
|
| 371 |
+
# by John Ashburner, which is itself based on the file coeff.c,
|
| 372 |
+
# written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation
|
| 373 |
+
# . DCT1 boundary conditions were derived by Thevenaz and Unser.
|
| 374 |
+
# . DFT boundary conditions were derived by John Ashburner.
|
| 375 |
+
# SPM12 is released under the GNU-GPL v2 license.
|
| 376 |
+
# Philippe Thevenaz's code does not have an explicit license as far
|
| 377 |
+
# as we know.
|
| 378 |
+
if backend.jitfields and jitfields.available:
|
| 379 |
+
return jitfields.spline_coeff(input, interpolation, bound,
|
| 380 |
+
dim, inplace)
|
| 381 |
+
|
| 382 |
+
out = SplineCoeff.apply(input, bound, interpolation, dim, inplace)
|
| 383 |
+
return out
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
|
| 387 |
+
inplace=False):
|
| 388 |
+
"""Compute the interpolating spline coefficients, for a given spline order
|
| 389 |
+
and boundary conditions, along the last `dim` dimensions.
|
| 390 |
+
|
| 391 |
+
Notes
|
| 392 |
+
-----
|
| 393 |
+
{interpolation}
|
| 394 |
+
|
| 395 |
+
{bound}
|
| 396 |
+
|
| 397 |
+
References
|
| 398 |
+
----------
|
| 399 |
+
{ref}
|
| 400 |
+
|
| 401 |
+
Parameters
|
| 402 |
+
----------
|
| 403 |
+
input : (..., *spatial) tensor
|
| 404 |
+
Input image.
|
| 405 |
+
interpolation : int or sequence[int], default=1
|
| 406 |
+
Interpolation order.
|
| 407 |
+
bound : BoundType or sequence[BoundType], default='dct1'
|
| 408 |
+
Boundary conditions.
|
| 409 |
+
dim : int, default=-1
|
| 410 |
+
Number of spatial dimensions
|
| 411 |
+
inplace : bool, default=False
|
| 412 |
+
Process the volume in place.
|
| 413 |
+
|
| 414 |
+
Returns
|
| 415 |
+
-------
|
| 416 |
+
output : (..., *spatial) tensor
|
| 417 |
+
Coefficient image.
|
| 418 |
+
|
| 419 |
+
"""
|
| 420 |
+
# This implementation is based on the file bsplines.c in SPM12, written
|
| 421 |
+
# by John Ashburner, which is itself based on the file coeff.c,
|
| 422 |
+
# written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation
|
| 423 |
+
# . DCT1 boundary conditions were derived by Thevenaz and Unser.
|
| 424 |
+
# . DFT boundary conditions were derived by John Ashburner.
|
| 425 |
+
# SPM12 is released under the GNU-GPL v2 license.
|
| 426 |
+
# Philippe Thevenaz's code does not have an explicit license as far
|
| 427 |
+
# as we know.
|
| 428 |
+
if backend.jitfields and jitfields.available:
|
| 429 |
+
return jitfields.spline_coeff_nd(input, interpolation, bound,
|
| 430 |
+
dim, inplace)
|
| 431 |
+
|
| 432 |
+
out = SplineCoeffND.apply(input, bound, interpolation, dim, inplace)
|
| 433 |
+
return out
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
grid_pull.__doc__ = grid_pull.__doc__.format(
|
| 437 |
+
interpolation=_doc_interpolation, bound=_doc_bound)
|
| 438 |
+
grid_push.__doc__ = grid_push.__doc__.format(
|
| 439 |
+
interpolation=_doc_interpolation, bound=_doc_bound)
|
| 440 |
+
grid_count.__doc__ = grid_count.__doc__.format(
|
| 441 |
+
interpolation=_doc_interpolation, bound=_doc_bound)
|
| 442 |
+
grid_grad.__doc__ = grid_grad.__doc__.format(
|
| 443 |
+
interpolation=_doc_interpolation, bound=_doc_bound)
|
| 444 |
+
spline_coeff.__doc__ = spline_coeff.__doc__.format(
|
| 445 |
+
interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff)
|
| 446 |
+
spline_coeff_nd.__doc__ = spline_coeff_nd.__doc__.format(
|
| 447 |
+
interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff)
|
| 448 |
+
|
| 449 |
+
# aliases
|
| 450 |
+
pull = grid_pull
|
| 451 |
+
push = grid_push
|
| 452 |
+
count = grid_count
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def identity_grid(shape, dtype=None, device=None):
|
| 456 |
+
"""Returns an identity deformation field.
|
| 457 |
+
|
| 458 |
+
Parameters
|
| 459 |
+
----------
|
| 460 |
+
shape : (dim,) sequence of int
|
| 461 |
+
Spatial dimension of the field.
|
| 462 |
+
dtype : torch.dtype, default=`get_default_dtype()`
|
| 463 |
+
Data type.
|
| 464 |
+
device torch.device, optional
|
| 465 |
+
Device.
|
| 466 |
+
|
| 467 |
+
Returns
|
| 468 |
+
-------
|
| 469 |
+
grid : (*shape, dim) tensor
|
| 470 |
+
Transformation field
|
| 471 |
+
|
| 472 |
+
"""
|
| 473 |
+
mesh1d = [torch.arange(float(s), dtype=dtype, device=device)
|
| 474 |
+
for s in shape]
|
| 475 |
+
grid = torch.stack(meshgrid(mesh1d), dim=-1)
|
| 476 |
+
return grid
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
@torch.jit.script
|
| 480 |
+
def add_identity_grid_(disp):
|
| 481 |
+
"""Adds the identity grid to a displacement field, inplace.
|
| 482 |
+
|
| 483 |
+
Parameters
|
| 484 |
+
----------
|
| 485 |
+
disp : (..., *spatial, dim) tensor
|
| 486 |
+
Displacement field
|
| 487 |
+
|
| 488 |
+
Returns
|
| 489 |
+
-------
|
| 490 |
+
grid : (..., *spatial, dim) tensor
|
| 491 |
+
Transformation field
|
| 492 |
+
|
| 493 |
+
"""
|
| 494 |
+
dim = disp.shape[-1]
|
| 495 |
+
spatial = disp.shape[-dim-1:-1]
|
| 496 |
+
mesh1d = [torch.arange(s, dtype=disp.dtype, device=disp.device)
|
| 497 |
+
for s in spatial]
|
| 498 |
+
grid = meshgrid(mesh1d)
|
| 499 |
+
disp = movedim1(disp, -1, 0)
|
| 500 |
+
for i, grid1 in enumerate(grid):
|
| 501 |
+
disp[i].add_(grid1)
|
| 502 |
+
disp = movedim1(disp, 0, -1)
|
| 503 |
+
return disp
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
@torch.jit.script
|
| 507 |
+
def add_identity_grid(disp):
|
| 508 |
+
"""Adds the identity grid to a displacement field.
|
| 509 |
+
|
| 510 |
+
Parameters
|
| 511 |
+
----------
|
| 512 |
+
disp : (..., *spatial, dim) tensor
|
| 513 |
+
Displacement field
|
| 514 |
+
|
| 515 |
+
Returns
|
| 516 |
+
-------
|
| 517 |
+
grid : (..., *spatial, dim) tensor
|
| 518 |
+
Transformation field
|
| 519 |
+
|
| 520 |
+
"""
|
| 521 |
+
return add_identity_grid_(disp.clone())
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def affine_grid(mat, shape):
|
| 525 |
+
"""Create a dense transformation grid from an affine matrix.
|
| 526 |
+
|
| 527 |
+
Parameters
|
| 528 |
+
----------
|
| 529 |
+
mat : (..., D[+1], D+1) tensor
|
| 530 |
+
Affine matrix (or matrices).
|
| 531 |
+
shape : (D,) sequence[int]
|
| 532 |
+
Shape of the grid, with length D.
|
| 533 |
+
|
| 534 |
+
Returns
|
| 535 |
+
-------
|
| 536 |
+
grid : (..., *shape, D) tensor
|
| 537 |
+
Dense transformation grid
|
| 538 |
+
|
| 539 |
+
"""
|
| 540 |
+
mat = torch.as_tensor(mat)
|
| 541 |
+
shape = list(shape)
|
| 542 |
+
nb_dim = mat.shape[-1] - 1
|
| 543 |
+
if nb_dim != len(shape):
|
| 544 |
+
raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
|
| 545 |
+
'are not the same.'.format(nb_dim, len(shape)))
|
| 546 |
+
if mat.shape[-2] not in (nb_dim, nb_dim+1):
|
| 547 |
+
raise ValueError('First argument should be matrces of shape '
|
| 548 |
+
'(..., {0}, {1}) or (..., {1], {1}) but got {2}.'
|
| 549 |
+
.format(nb_dim, nb_dim+1, mat.shape))
|
| 550 |
+
batch_shape = mat.shape[:-2]
|
| 551 |
+
grid = identity_grid(shape, mat.dtype, mat.device)
|
| 552 |
+
if batch_shape:
|
| 553 |
+
for _ in range(len(batch_shape)):
|
| 554 |
+
grid = grid.unsqueeze(0)
|
| 555 |
+
for _ in range(nb_dim):
|
| 556 |
+
mat = mat.unsqueeze(-1)
|
| 557 |
+
lin = mat[..., :nb_dim, :nb_dim]
|
| 558 |
+
off = mat[..., :nb_dim, -1]
|
| 559 |
+
grid = matvec(lin, grid) + off
|
| 560 |
+
return grid
|
Generator/interpol/autograd.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AutoGrad version of pull/push/count/grad"""
|
| 2 |
+
import torch
|
| 3 |
+
from .coeff import spline_coeff_nd, spline_coeff
|
| 4 |
+
from .bounds import BoundType
|
| 5 |
+
from .splines import InterpolationType
|
| 6 |
+
from .pushpull import (
|
| 7 |
+
grid_pull, grid_pull_backward,
|
| 8 |
+
grid_push, grid_push_backward,
|
| 9 |
+
grid_count, grid_count_backward,
|
| 10 |
+
grid_grad, grid_grad_backward)
|
| 11 |
+
from .utils import fake_decorator
|
| 12 |
+
try:
|
| 13 |
+
from torch.cuda.amp import custom_fwd, custom_bwd
|
| 14 |
+
except (ModuleNotFoundError, ImportError):
|
| 15 |
+
custom_fwd = custom_bwd = fake_decorator
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def make_list(x):
|
| 19 |
+
if not isinstance(x, (list, tuple)):
|
| 20 |
+
x = [x]
|
| 21 |
+
return list(x)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def bound_to_nitorch(bound, as_type='str'):
|
| 25 |
+
"""Convert boundary type to niTorch's convention.
|
| 26 |
+
|
| 27 |
+
Parameters
|
| 28 |
+
----------
|
| 29 |
+
bound : [list of] str or bound_like
|
| 30 |
+
Boundary condition in any convention
|
| 31 |
+
as_type : {'str', 'enum', 'int'}, default='str'
|
| 32 |
+
Return BoundType or int rather than str
|
| 33 |
+
|
| 34 |
+
Returns
|
| 35 |
+
-------
|
| 36 |
+
bound : [list of] str or BoundType
|
| 37 |
+
Boundary condition in NITorch's convention
|
| 38 |
+
|
| 39 |
+
"""
|
| 40 |
+
intype = type(bound)
|
| 41 |
+
if not isinstance(bound, (list, tuple)):
|
| 42 |
+
bound = [bound]
|
| 43 |
+
obound = []
|
| 44 |
+
for b in bound:
|
| 45 |
+
b = b.lower() if isinstance(b, str) else b
|
| 46 |
+
if b in ('replicate', 'repeat', 'border', 'nearest', BoundType.replicate):
|
| 47 |
+
obound.append('replicate')
|
| 48 |
+
elif b in ('zero', 'zeros', 'constant', BoundType.zero):
|
| 49 |
+
obound.append('zero')
|
| 50 |
+
elif b in ('dct2', 'reflect', 'reflection', 'neumann', BoundType.dct2):
|
| 51 |
+
obound.append('dct2')
|
| 52 |
+
elif b in ('dct1', 'mirror', BoundType.dct1):
|
| 53 |
+
obound.append('dct1')
|
| 54 |
+
elif b in ('dft', 'wrap', 'circular', BoundType.dft):
|
| 55 |
+
obound.append('dft')
|
| 56 |
+
elif b in ('dst2', 'antireflect', 'dirichlet', BoundType.dst2):
|
| 57 |
+
obound.append('dst2')
|
| 58 |
+
elif b in ('dst1', 'antimirror', BoundType.dst1):
|
| 59 |
+
obound.append('dst1')
|
| 60 |
+
elif isinstance(b, int):
|
| 61 |
+
obound.append(b)
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f'Unknown boundary condition {b}')
|
| 64 |
+
obound = list(map(lambda b: getattr(BoundType, b) if isinstance(b, str)
|
| 65 |
+
else BoundType(b), obound))
|
| 66 |
+
if as_type in ('int', int):
|
| 67 |
+
obound = [b.value for b in obound]
|
| 68 |
+
if as_type in ('str', str):
|
| 69 |
+
obound = [b.name for b in obound]
|
| 70 |
+
if issubclass(intype, (list, tuple)):
|
| 71 |
+
obound = intype(obound)
|
| 72 |
+
else:
|
| 73 |
+
obound = obound[0]
|
| 74 |
+
return obound
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def inter_to_nitorch(inter, as_type='str'):
|
| 78 |
+
"""Convert interpolation order to NITorch's convention.
|
| 79 |
+
|
| 80 |
+
Parameters
|
| 81 |
+
----------
|
| 82 |
+
inter : [sequence of] int or str or InterpolationType
|
| 83 |
+
as_type : {'str', 'enum', 'int'}, default='int'
|
| 84 |
+
|
| 85 |
+
Returns
|
| 86 |
+
-------
|
| 87 |
+
inter : [sequence of] int or InterpolationType
|
| 88 |
+
|
| 89 |
+
"""
|
| 90 |
+
intype = type(inter)
|
| 91 |
+
if not isinstance(inter, (list, tuple)):
|
| 92 |
+
inter = [inter]
|
| 93 |
+
ointer = []
|
| 94 |
+
for o in inter:
|
| 95 |
+
o = o.lower() if isinstance(o, str) else o
|
| 96 |
+
if o in (0, 'nearest', InterpolationType.nearest):
|
| 97 |
+
ointer.append(0)
|
| 98 |
+
elif o in (1, 'linear', InterpolationType.linear):
|
| 99 |
+
ointer.append(1)
|
| 100 |
+
elif o in (2, 'quadratic', InterpolationType.quadratic):
|
| 101 |
+
ointer.append(2)
|
| 102 |
+
elif o in (3, 'cubic', InterpolationType.cubic):
|
| 103 |
+
ointer.append(3)
|
| 104 |
+
elif o in (4, 'fourth', InterpolationType.fourth):
|
| 105 |
+
ointer.append(4)
|
| 106 |
+
elif o in (5, 'fifth', InterpolationType.fifth):
|
| 107 |
+
ointer.append(5)
|
| 108 |
+
elif o in (6, 'sixth', InterpolationType.sixth):
|
| 109 |
+
ointer.append(6)
|
| 110 |
+
elif o in (7, 'seventh', InterpolationType.seventh):
|
| 111 |
+
ointer.append(7)
|
| 112 |
+
else:
|
| 113 |
+
raise ValueError(f'Unknown interpolation order {o}')
|
| 114 |
+
if as_type in ('enum', 'str', str):
|
| 115 |
+
ointer = list(map(InterpolationType, ointer))
|
| 116 |
+
if as_type in ('str', str):
|
| 117 |
+
ointer = [o.name for o in ointer]
|
| 118 |
+
if issubclass(intype, (list, tuple)):
|
| 119 |
+
ointer = intype(ointer)
|
| 120 |
+
else:
|
| 121 |
+
ointer = ointer[0]
|
| 122 |
+
return ointer
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class GridPull(torch.autograd.Function):
|
| 126 |
+
|
| 127 |
+
@staticmethod
|
| 128 |
+
@custom_fwd(cast_inputs=torch.float32)
|
| 129 |
+
def forward(ctx, input, grid, interpolation, bound, extrapolate):
|
| 130 |
+
|
| 131 |
+
bound = bound_to_nitorch(make_list(bound), as_type='int')
|
| 132 |
+
interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
|
| 133 |
+
extrapolate = int(extrapolate)
|
| 134 |
+
opt = (bound, interpolation, extrapolate)
|
| 135 |
+
|
| 136 |
+
# Pull
|
| 137 |
+
output = grid_pull(input, grid, *opt)
|
| 138 |
+
|
| 139 |
+
# Context
|
| 140 |
+
ctx.opt = opt
|
| 141 |
+
ctx.save_for_backward(input, grid)
|
| 142 |
+
|
| 143 |
+
return output
|
| 144 |
+
|
| 145 |
+
@staticmethod
|
| 146 |
+
@custom_bwd
|
| 147 |
+
def backward(ctx, grad):
|
| 148 |
+
var = ctx.saved_tensors
|
| 149 |
+
opt = ctx.opt
|
| 150 |
+
grads = grid_pull_backward(grad, *var, *opt)
|
| 151 |
+
grad_input, grad_grid = grads
|
| 152 |
+
return grad_input, grad_grid, None, None, None
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class GridPush(torch.autograd.Function):
|
| 156 |
+
|
| 157 |
+
@staticmethod
|
| 158 |
+
@custom_fwd(cast_inputs=torch.float32)
|
| 159 |
+
def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
|
| 160 |
+
|
| 161 |
+
bound = bound_to_nitorch(make_list(bound), as_type='int')
|
| 162 |
+
interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
|
| 163 |
+
extrapolate = int(extrapolate)
|
| 164 |
+
opt = (bound, interpolation, extrapolate)
|
| 165 |
+
|
| 166 |
+
# Push
|
| 167 |
+
output = grid_push(input, grid, shape, *opt)
|
| 168 |
+
|
| 169 |
+
# Context
|
| 170 |
+
ctx.opt = opt
|
| 171 |
+
ctx.save_for_backward(input, grid)
|
| 172 |
+
|
| 173 |
+
return output
|
| 174 |
+
|
| 175 |
+
@staticmethod
|
| 176 |
+
@custom_bwd
|
| 177 |
+
def backward(ctx, grad):
|
| 178 |
+
var = ctx.saved_tensors
|
| 179 |
+
opt = ctx.opt
|
| 180 |
+
grads = grid_push_backward(grad, *var, *opt)
|
| 181 |
+
grad_input, grad_grid = grads
|
| 182 |
+
return grad_input, grad_grid, None, None, None, None
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class GridCount(torch.autograd.Function):
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
@custom_fwd(cast_inputs=torch.float32)
|
| 189 |
+
def forward(ctx, grid, shape, interpolation, bound, extrapolate):
|
| 190 |
+
|
| 191 |
+
bound = bound_to_nitorch(make_list(bound), as_type='int')
|
| 192 |
+
interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
|
| 193 |
+
extrapolate = int(extrapolate)
|
| 194 |
+
opt = (bound, interpolation, extrapolate)
|
| 195 |
+
|
| 196 |
+
# Push
|
| 197 |
+
output = grid_count(grid, shape, *opt)
|
| 198 |
+
|
| 199 |
+
# Context
|
| 200 |
+
ctx.opt = opt
|
| 201 |
+
ctx.save_for_backward(grid)
|
| 202 |
+
|
| 203 |
+
return output
|
| 204 |
+
|
| 205 |
+
@staticmethod
|
| 206 |
+
@custom_bwd
|
| 207 |
+
def backward(ctx, grad):
|
| 208 |
+
var = ctx.saved_tensors
|
| 209 |
+
opt = ctx.opt
|
| 210 |
+
grad_grid = None
|
| 211 |
+
if ctx.needs_input_grad[0]:
|
| 212 |
+
grad_grid = grid_count_backward(grad, *var, *opt)
|
| 213 |
+
return grad_grid, None, None, None, None
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class GridGrad(torch.autograd.Function):
|
| 217 |
+
|
| 218 |
+
@staticmethod
|
| 219 |
+
@custom_fwd(cast_inputs=torch.float32)
|
| 220 |
+
def forward(ctx, input, grid, interpolation, bound, extrapolate):
|
| 221 |
+
|
| 222 |
+
bound = bound_to_nitorch(make_list(bound), as_type='int')
|
| 223 |
+
interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
|
| 224 |
+
extrapolate = int(extrapolate)
|
| 225 |
+
opt = (bound, interpolation, extrapolate)
|
| 226 |
+
|
| 227 |
+
# Pull
|
| 228 |
+
output = grid_grad(input, grid, *opt)
|
| 229 |
+
|
| 230 |
+
# Context
|
| 231 |
+
ctx.opt = opt
|
| 232 |
+
ctx.save_for_backward(input, grid)
|
| 233 |
+
|
| 234 |
+
return output
|
| 235 |
+
|
| 236 |
+
@staticmethod
|
| 237 |
+
@custom_bwd
|
| 238 |
+
def backward(ctx, grad):
|
| 239 |
+
var = ctx.saved_tensors
|
| 240 |
+
opt = ctx.opt
|
| 241 |
+
grad_input = grad_grid = None
|
| 242 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
| 243 |
+
grads = grid_grad_backward(grad, *var, *opt)
|
| 244 |
+
grad_input, grad_grid = grads
|
| 245 |
+
return grad_input, grad_grid, None, None, None
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class SplineCoeff(torch.autograd.Function):
|
| 249 |
+
|
| 250 |
+
@staticmethod
|
| 251 |
+
@custom_fwd
|
| 252 |
+
def forward(ctx, input, bound, interpolation, dim, inplace):
|
| 253 |
+
|
| 254 |
+
bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
|
| 255 |
+
interpolation = inter_to_nitorch(make_list(interpolation)[0], as_type='int')
|
| 256 |
+
opt = (bound, interpolation, dim, inplace)
|
| 257 |
+
|
| 258 |
+
# Pull
|
| 259 |
+
output = spline_coeff(input, *opt)
|
| 260 |
+
|
| 261 |
+
# Context
|
| 262 |
+
if input.requires_grad:
|
| 263 |
+
ctx.opt = opt
|
| 264 |
+
|
| 265 |
+
return output
|
| 266 |
+
|
| 267 |
+
@staticmethod
|
| 268 |
+
@custom_bwd
|
| 269 |
+
def backward(ctx, grad):
|
| 270 |
+
# symmetric filter -> backward == forward
|
| 271 |
+
# (I don't know if I can write into grad, so inplace=False to be safe)
|
| 272 |
+
grad = spline_coeff(grad, *ctx.opt[:-1], inplace=False)
|
| 273 |
+
return [grad] + [None] * 4
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class SplineCoeffND(torch.autograd.Function):
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
@custom_fwd
|
| 280 |
+
def forward(ctx, input, bound, interpolation, dim, inplace):
|
| 281 |
+
|
| 282 |
+
bound = bound_to_nitorch(make_list(bound), as_type='int')
|
| 283 |
+
interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
|
| 284 |
+
opt = (bound, interpolation, dim, inplace)
|
| 285 |
+
|
| 286 |
+
# Pull
|
| 287 |
+
output = spline_coeff_nd(input, *opt)
|
| 288 |
+
|
| 289 |
+
# Context
|
| 290 |
+
if input.requires_grad:
|
| 291 |
+
ctx.opt = opt
|
| 292 |
+
|
| 293 |
+
return output
|
| 294 |
+
|
| 295 |
+
@staticmethod
|
| 296 |
+
@custom_bwd
|
| 297 |
+
def backward(ctx, grad):
|
| 298 |
+
# symmetric filter -> backward == forward
|
| 299 |
+
# (I don't know if I can write into grad, so inplace=False to be safe)
|
| 300 |
+
grad = spline_coeff_nd(grad, *ctx.opt[:-1], inplace=False)
|
| 301 |
+
return grad, None, None, None, None
|
Generator/interpol/backend.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
jitfields = False # Whether to use jitfields if available
|
Generator/interpol/bounds.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from .jit_utils import floor_div
|
| 5 |
+
Tensor = torch.Tensor
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BoundType(Enum):
|
| 9 |
+
zero = zeros = 0
|
| 10 |
+
replicate = nearest = 1
|
| 11 |
+
dct1 = mirror = 2
|
| 12 |
+
dct2 = reflect = 3
|
| 13 |
+
dst1 = antimirror = 4
|
| 14 |
+
dst2 = antireflect = 5
|
| 15 |
+
dft = wrap = 6
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ExtrapolateType(Enum):
|
| 19 |
+
no = 0 # threshold: (0, n-1)
|
| 20 |
+
yes = 1
|
| 21 |
+
hist = 2 # threshold: (-0.5, n-0.5)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@torch.jit.script
|
| 25 |
+
class Bound:
|
| 26 |
+
|
| 27 |
+
def __init__(self, bound_type: int = 3):
|
| 28 |
+
self.type = bound_type
|
| 29 |
+
|
| 30 |
+
def index(self, i, n: int):
|
| 31 |
+
if self.type in (0, 1): # zero / replicate
|
| 32 |
+
return i.clamp(min=0, max=n-1)
|
| 33 |
+
elif self.type in (3, 5): # dct2 / dst2
|
| 34 |
+
n2 = n * 2
|
| 35 |
+
i = torch.where(i < 0, (-i-1).remainder(n2).neg().add(n2 - 1),
|
| 36 |
+
i.remainder(n2))
|
| 37 |
+
i = torch.where(i >= n, -i + (n2 - 1), i)
|
| 38 |
+
return i
|
| 39 |
+
elif self.type == 2: # dct1
|
| 40 |
+
if n == 1:
|
| 41 |
+
return torch.zeros(i.shape, dtype=i.dtype, device=i.device)
|
| 42 |
+
else:
|
| 43 |
+
n2 = (n - 1) * 2
|
| 44 |
+
i = i.abs().remainder(n2)
|
| 45 |
+
i = torch.where(i >= n, -i + n2, i)
|
| 46 |
+
return i
|
| 47 |
+
elif self.type == 4: # dst1
|
| 48 |
+
n2 = 2 * (n + 1)
|
| 49 |
+
first = torch.zeros([1], dtype=i.dtype, device=i.device)
|
| 50 |
+
last = torch.full([1], n - 1, dtype=i.dtype, device=i.device)
|
| 51 |
+
i = torch.where(i < 0, -i - 2, i)
|
| 52 |
+
i = i.remainder(n2)
|
| 53 |
+
i = torch.where(i > n, -i + (n2 - 2), i)
|
| 54 |
+
i = torch.where(i == -1, first, i)
|
| 55 |
+
i = torch.where(i == n, last, i)
|
| 56 |
+
return i
|
| 57 |
+
elif self.type == 6: # dft
|
| 58 |
+
return i.remainder(n)
|
| 59 |
+
else:
|
| 60 |
+
return i
|
| 61 |
+
|
| 62 |
+
def transform(self, i, n: int) -> Optional[Tensor]:
|
| 63 |
+
if self.type == 4: # dst1
|
| 64 |
+
if n == 1:
|
| 65 |
+
return None
|
| 66 |
+
one = torch.ones([1], dtype=torch.int8, device=i.device)
|
| 67 |
+
zero = torch.zeros([1], dtype=torch.int8, device=i.device)
|
| 68 |
+
n2 = 2 * (n + 1)
|
| 69 |
+
i = torch.where(i < 0, -i + (n-1), i)
|
| 70 |
+
i = i.remainder(n2)
|
| 71 |
+
x = torch.where(i == 0, zero, one)
|
| 72 |
+
x = torch.where(i.remainder(n + 1) == n, zero, x)
|
| 73 |
+
i = floor_div(i, n+1)
|
| 74 |
+
x = torch.where(torch.remainder(i, 2) > 0, -x, x)
|
| 75 |
+
return x
|
| 76 |
+
elif self.type == 5: # dst2
|
| 77 |
+
i = torch.where(i < 0, n - 1 - i, i)
|
| 78 |
+
x = torch.ones([1], dtype=torch.int8, device=i.device)
|
| 79 |
+
i = floor_div(i, n)
|
| 80 |
+
x = torch.where(torch.remainder(i, 2) > 0, -x, x)
|
| 81 |
+
return x
|
| 82 |
+
elif self.type == 0: # zero
|
| 83 |
+
one = torch.ones([1], dtype=torch.int8, device=i.device)
|
| 84 |
+
zero = torch.zeros([1], dtype=torch.int8, device=i.device)
|
| 85 |
+
outbounds = ((i < 0) | (i >= n))
|
| 86 |
+
x = torch.where(outbounds, zero, one)
|
| 87 |
+
return x
|
| 88 |
+
else:
|
| 89 |
+
return None
|
Generator/interpol/coeff.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compute spline interpolating coefficients
|
| 2 |
+
|
| 3 |
+
These functions are ported from the C routines in SPM's bsplines.c
|
| 4 |
+
by John Ashburner, which are themselves ports from Philippe Thevenaz's
|
| 5 |
+
code. JA furthermore derived the initial conditions for the DFT ("wrap around")
|
| 6 |
+
boundary conditions.
|
| 7 |
+
|
| 8 |
+
Note that similar routines are available in scipy with boundary conditions
|
| 9 |
+
DCT1 ("mirror"), DCT2 ("reflect") and DFT ("wrap"); all derived by P. Thevenaz,
|
| 10 |
+
according to the comments. Our DCT2 boundary conditions are ported from
|
| 11 |
+
scipy.
|
| 12 |
+
|
| 13 |
+
Only boundary conditions DCT1, DCT2 and DFT are implemented.
|
| 14 |
+
|
| 15 |
+
References
|
| 16 |
+
----------
|
| 17 |
+
..[1] M. Unser, A. Aldroubi and M. Eden.
|
| 18 |
+
"B-Spline Signal Processing: Part I-Theory,"
|
| 19 |
+
IEEE Transactions on Signal Processing 41(2):821-832 (1993).
|
| 20 |
+
..[2] M. Unser, A. Aldroubi and M. Eden.
|
| 21 |
+
"B-Spline Signal Processing: Part II-Efficient Design and Applications,"
|
| 22 |
+
IEEE Transactions on Signal Processing 41(2):834-848 (1993).
|
| 23 |
+
..[3] M. Unser.
|
| 24 |
+
"Splines: A Perfect Fit for Signal and Image Processing,"
|
| 25 |
+
IEEE Signal Processing Magazine 16(6):22-38 (1999).
|
| 26 |
+
"""
|
| 27 |
+
import torch
|
| 28 |
+
import math
|
| 29 |
+
from typing import List, Optional
|
| 30 |
+
from .jit_utils import movedim1
|
| 31 |
+
from .pushpull import pad_list_int
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@torch.jit.script
|
| 35 |
+
def get_poles(order: int) -> List[float]:
|
| 36 |
+
empty: List[float] = []
|
| 37 |
+
if order in (0, 1):
|
| 38 |
+
return empty
|
| 39 |
+
if order == 2:
|
| 40 |
+
return [math.sqrt(8.) - 3.]
|
| 41 |
+
if order == 3:
|
| 42 |
+
return [math.sqrt(3.) - 2.]
|
| 43 |
+
if order == 4:
|
| 44 |
+
return [math.sqrt(664. - math.sqrt(438976.)) + math.sqrt(304.) - 19.,
|
| 45 |
+
math.sqrt(664. + math.sqrt(438976.)) - math.sqrt(304.) - 19.]
|
| 46 |
+
if order == 5:
|
| 47 |
+
return [math.sqrt(67.5 - math.sqrt(4436.25)) + math.sqrt(26.25) - 6.5,
|
| 48 |
+
math.sqrt(67.5 + math.sqrt(4436.25)) - math.sqrt(26.25) - 6.5]
|
| 49 |
+
if order == 6:
|
| 50 |
+
return [-0.488294589303044755130118038883789062112279161239377608394,
|
| 51 |
+
-0.081679271076237512597937765737059080653379610398148178525368,
|
| 52 |
+
-0.00141415180832581775108724397655859252786416905534669851652709]
|
| 53 |
+
if order == 7:
|
| 54 |
+
return [-0.5352804307964381655424037816816460718339231523426924148812,
|
| 55 |
+
-0.122554615192326690515272264359357343605486549427295558490763,
|
| 56 |
+
-0.0091486948096082769285930216516478534156925639545994482648003]
|
| 57 |
+
raise NotImplementedError
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@torch.jit.script
|
| 61 |
+
def get_gain(poles: List[float]) -> float:
|
| 62 |
+
lam: float = 1.
|
| 63 |
+
for pole in poles:
|
| 64 |
+
lam *= (1. - pole) * (1. - 1./pole)
|
| 65 |
+
return lam
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@torch.jit.script
|
| 69 |
+
def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
|
| 70 |
+
|
| 71 |
+
assert inp.shape[dim] > 1
|
| 72 |
+
max_iter: int = int(math.ceil(-30./math.log(abs(pole))))
|
| 73 |
+
max_iter = min(max_iter, inp.shape[dim])
|
| 74 |
+
|
| 75 |
+
poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
|
| 76 |
+
poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device))
|
| 77 |
+
poles = poles.flip(0)
|
| 78 |
+
|
| 79 |
+
inp = movedim1(inp, dim, 0)
|
| 80 |
+
inp0 = inp[0]
|
| 81 |
+
inp = inp[1-max_iter:]
|
| 82 |
+
inp = movedim1(inp, 0, -1)
|
| 83 |
+
out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
|
| 84 |
+
out = out + inp0.unsqueeze(-1)
|
| 85 |
+
if keepdim:
|
| 86 |
+
out = movedim1(out, -1, dim)
|
| 87 |
+
else:
|
| 88 |
+
out = out.squeeze(-1)
|
| 89 |
+
|
| 90 |
+
pole = pole ** max_iter
|
| 91 |
+
out = out / (1 - pole)
|
| 92 |
+
return out
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@torch.jit.script
|
| 96 |
+
def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
|
| 97 |
+
|
| 98 |
+
n = inp.shape[dim]
|
| 99 |
+
max_iter: int = int(math.ceil(-30./math.log(abs(pole))))
|
| 100 |
+
|
| 101 |
+
if max_iter < n:
|
| 102 |
+
|
| 103 |
+
poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
|
| 104 |
+
poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device))
|
| 105 |
+
|
| 106 |
+
inp = movedim1(inp, dim, 0)
|
| 107 |
+
inp0 = inp[0]
|
| 108 |
+
inp = inp[1:max_iter]
|
| 109 |
+
inp = movedim1(inp, 0, -1)
|
| 110 |
+
out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
|
| 111 |
+
out = out + inp0.unsqueeze(-1)
|
| 112 |
+
if keepdim:
|
| 113 |
+
out = movedim1(out, -1, dim)
|
| 114 |
+
else:
|
| 115 |
+
out = out.squeeze(-1)
|
| 116 |
+
|
| 117 |
+
else:
|
| 118 |
+
max_iter = n
|
| 119 |
+
|
| 120 |
+
polen = pole ** (n - 1)
|
| 121 |
+
inp0 = inp[0] + polen * inp[-1]
|
| 122 |
+
inp = inp[1:-1]
|
| 123 |
+
inp = movedim1(inp, 0, -1)
|
| 124 |
+
|
| 125 |
+
poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
|
| 126 |
+
poles = poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device))
|
| 127 |
+
poles = poles + (polen * polen) / poles
|
| 128 |
+
|
| 129 |
+
out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
|
| 130 |
+
out = out + inp0.unsqueeze(-1)
|
| 131 |
+
if keepdim:
|
| 132 |
+
out = movedim1(out, -1, dim)
|
| 133 |
+
else:
|
| 134 |
+
out = out.squeeze(-1)
|
| 135 |
+
|
| 136 |
+
pole = pole ** (max_iter - 1)
|
| 137 |
+
out = out / (1 - pole * pole)
|
| 138 |
+
|
| 139 |
+
return out
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@torch.jit.script
|
| 143 |
+
def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
|
| 144 |
+
# Ported from scipy:
|
| 145 |
+
# https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c
|
| 146 |
+
#
|
| 147 |
+
# I (YB) unwarped and simplied the terms so that I could use a dot
|
| 148 |
+
# product instead of a loop.
|
| 149 |
+
# It should certainly be possible to derive a version for max_iter < n,
|
| 150 |
+
# as JA did for DCT1, to avoid long recursions when `n` is large. But
|
| 151 |
+
# I think it would require a more complicated anticausal/final condition.
|
| 152 |
+
|
| 153 |
+
n = inp.shape[dim]
|
| 154 |
+
|
| 155 |
+
polen = pole ** n
|
| 156 |
+
pole_last = polen * (1 + 1/(pole + polen * polen))
|
| 157 |
+
inp00 = inp[0]
|
| 158 |
+
inp0 = inp[0] + pole_last * inp[-1]
|
| 159 |
+
inp = inp[1:-1]
|
| 160 |
+
inp = movedim1(inp, 0, -1)
|
| 161 |
+
|
| 162 |
+
poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
|
| 163 |
+
poles = (poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device)) +
|
| 164 |
+
poles.pow(torch.arange(2*n-2, n, -1, dtype=inp.dtype, device=inp.device)))
|
| 165 |
+
|
| 166 |
+
out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
|
| 167 |
+
|
| 168 |
+
out = out + inp0.unsqueeze(-1)
|
| 169 |
+
out = out * (pole / (1 - polen * polen))
|
| 170 |
+
out = out + inp00.unsqueeze(-1)
|
| 171 |
+
|
| 172 |
+
if keepdim:
|
| 173 |
+
out = movedim1(out, -1, dim)
|
| 174 |
+
else:
|
| 175 |
+
out = out.squeeze(-1)
|
| 176 |
+
|
| 177 |
+
return out
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@torch.jit.script
|
| 181 |
+
def dft_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
|
| 182 |
+
|
| 183 |
+
assert inp.shape[dim] > 1
|
| 184 |
+
max_iter: int = int(math.ceil(-30./math.log(abs(pole))))
|
| 185 |
+
max_iter = min(max_iter, inp.shape[dim])
|
| 186 |
+
|
| 187 |
+
poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
|
| 188 |
+
poles = poles.pow(torch.arange(2, max_iter+1, dtype=inp.dtype, device=inp.device))
|
| 189 |
+
|
| 190 |
+
inp = movedim1(inp, dim, 0)
|
| 191 |
+
inp0 = inp[-1]
|
| 192 |
+
inp = inp[:max_iter-1]
|
| 193 |
+
inp = movedim1(inp, 0, -1)
|
| 194 |
+
out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
|
| 195 |
+
out = out.add(inp0.unsqueeze(-1), alpha=pole)
|
| 196 |
+
if keepdim:
|
| 197 |
+
out = movedim1(out, -1, dim)
|
| 198 |
+
else:
|
| 199 |
+
out = out.squeeze(-1)
|
| 200 |
+
|
| 201 |
+
pole = pole ** max_iter
|
| 202 |
+
out = out / (pole - 1)
|
| 203 |
+
return out
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@torch.jit.script
|
| 207 |
+
def dct1_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
|
| 208 |
+
inp = movedim1(inp, dim, 0)
|
| 209 |
+
out = pole * inp[-2] + inp[-1]
|
| 210 |
+
out = out * (pole / (pole*pole - 1))
|
| 211 |
+
if keepdim:
|
| 212 |
+
out = movedim1(out.unsqueeze(0), 0, dim)
|
| 213 |
+
return out
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@torch.jit.script
|
| 217 |
+
def dct2_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
|
| 218 |
+
# Ported from scipy:
|
| 219 |
+
# https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c
|
| 220 |
+
inp = movedim1(inp, dim, 0)
|
| 221 |
+
out = inp[-1] * (pole / (pole - 1))
|
| 222 |
+
if keepdim:
|
| 223 |
+
out = movedim1(out.unsqueeze(0), 0, dim)
|
| 224 |
+
return out
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@torch.jit.script
|
| 228 |
+
class CoeffBound:
|
| 229 |
+
|
| 230 |
+
def __init__(self, bound: int):
|
| 231 |
+
self.bound = bound
|
| 232 |
+
|
| 233 |
+
def initial(self, inp, pole: float, dim: int = -1, keepdim: bool = False):
|
| 234 |
+
if self.bound in (0, 2): # zero, dct1
|
| 235 |
+
return dct1_initial(inp, pole, dim, keepdim)
|
| 236 |
+
elif self.bound in (1, 3): # nearest, dct2
|
| 237 |
+
return dct2_initial(inp, pole, dim, keepdim)
|
| 238 |
+
elif self.bound == 6: # dft
|
| 239 |
+
return dft_initial(inp, pole, dim, keepdim)
|
| 240 |
+
else:
|
| 241 |
+
raise NotImplementedError
|
| 242 |
+
|
| 243 |
+
def final(self, inp, pole: float, dim: int = -1, keepdim: bool = False):
|
| 244 |
+
if self.bound in (0, 2): # zero, dct1
|
| 245 |
+
return dct1_final(inp, pole, dim, keepdim)
|
| 246 |
+
elif self.bound in (1, 3): # nearest, dct2
|
| 247 |
+
return dct2_final(inp, pole, dim, keepdim)
|
| 248 |
+
elif self.bound == 6: # dft
|
| 249 |
+
return dft_final(inp, pole, dim, keepdim)
|
| 250 |
+
else:
|
| 251 |
+
raise NotImplementedError
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
@torch.jit.script
|
| 255 |
+
def filter(inp, bound: CoeffBound, poles: List[float],
|
| 256 |
+
dim: int = -1, inplace: bool = False):
|
| 257 |
+
|
| 258 |
+
if not inplace:
|
| 259 |
+
inp = inp.clone()
|
| 260 |
+
|
| 261 |
+
if inp.shape[dim] == 1:
|
| 262 |
+
return inp
|
| 263 |
+
|
| 264 |
+
gain = get_gain(poles)
|
| 265 |
+
inp *= gain
|
| 266 |
+
inp = movedim1(inp, dim, 0)
|
| 267 |
+
n = inp.shape[0]
|
| 268 |
+
|
| 269 |
+
for pole in poles:
|
| 270 |
+
inp[0] = bound.initial(inp, pole, dim=0, keepdim=False)
|
| 271 |
+
|
| 272 |
+
for i in range(1, n):
|
| 273 |
+
inp[i].add_(inp[i-1], alpha=pole)
|
| 274 |
+
|
| 275 |
+
inp[-1] = bound.final(inp, pole, dim=0, keepdim=False)
|
| 276 |
+
|
| 277 |
+
for i in range(n-2, -1, -1):
|
| 278 |
+
inp[i].neg_().add_(inp[i+1]).mul_(pole)
|
| 279 |
+
|
| 280 |
+
inp = movedim1(inp, 0, dim)
|
| 281 |
+
return inp
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@torch.jit.script
|
| 285 |
+
def spline_coeff(inp, bound: int, order: int, dim: int = -1,
|
| 286 |
+
inplace: bool = False):
|
| 287 |
+
"""Compute the interpolating spline coefficients, for a given spline order
|
| 288 |
+
and boundary conditions, along a single dimension.
|
| 289 |
+
|
| 290 |
+
Parameters
|
| 291 |
+
----------
|
| 292 |
+
inp : tensor
|
| 293 |
+
bound : {2: dct1, 6: dft}
|
| 294 |
+
order : {0..7}
|
| 295 |
+
dim : int, default=-1
|
| 296 |
+
inplace : bool, default=False
|
| 297 |
+
|
| 298 |
+
Returns
|
| 299 |
+
-------
|
| 300 |
+
out : tensor
|
| 301 |
+
|
| 302 |
+
"""
|
| 303 |
+
if not inplace:
|
| 304 |
+
inp = inp.clone()
|
| 305 |
+
|
| 306 |
+
if order in (0, 1):
|
| 307 |
+
return inp
|
| 308 |
+
|
| 309 |
+
poles = get_poles(order)
|
| 310 |
+
return filter(inp, CoeffBound(bound), poles, dim=dim, inplace=True)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@torch.jit.script
|
| 314 |
+
def spline_coeff_nd(inp, bound: List[int], order: List[int],
|
| 315 |
+
dim: Optional[int] = None, inplace: bool = False):
|
| 316 |
+
"""Compute the interpolating spline coefficients, for a given spline order
|
| 317 |
+
and boundary condition, along the last `dim` dimensions.
|
| 318 |
+
|
| 319 |
+
Parameters
|
| 320 |
+
----------
|
| 321 |
+
inp : (..., *spatial) tensor
|
| 322 |
+
bound : List[{2: dct1, 6: dft}]
|
| 323 |
+
order : List[{0..7}]
|
| 324 |
+
dim : int, default=`inp.dim()`
|
| 325 |
+
inplace : bool, default=False
|
| 326 |
+
|
| 327 |
+
Returns
|
| 328 |
+
-------
|
| 329 |
+
out : (..., *spatial) tensor
|
| 330 |
+
|
| 331 |
+
"""
|
| 332 |
+
if not inplace:
|
| 333 |
+
inp = inp.clone()
|
| 334 |
+
|
| 335 |
+
if dim is None:
|
| 336 |
+
dim = inp.dim()
|
| 337 |
+
|
| 338 |
+
bound = pad_list_int(bound, dim)
|
| 339 |
+
order = pad_list_int(order, dim)
|
| 340 |
+
|
| 341 |
+
for d, b, o in zip(range(dim), bound, order):
|
| 342 |
+
inp = spline_coeff(inp, b, o, dim=-dim + d, inplace=True)
|
| 343 |
+
|
| 344 |
+
return inp
|
Generator/interpol/iso0.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Isotropic 0-th order splines ("nearest neighbor")"""
|
| 2 |
+
import torch
|
| 3 |
+
from .bounds import Bound
|
| 4 |
+
from .jit_utils import (sub2ind_list, make_sign,
|
| 5 |
+
inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d)
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
Tensor = torch.Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.jit.script
|
| 11 |
+
def get_indices(g, n: int, bound: Bound):
|
| 12 |
+
g0 = g.round().long()
|
| 13 |
+
sign0 = bound.transform(g0, n)
|
| 14 |
+
g0 = bound.index(g0, n)
|
| 15 |
+
return g0, sign0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ======================================================================
|
| 19 |
+
# 3D
|
| 20 |
+
# ======================================================================
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@torch.jit.script
|
| 24 |
+
def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 25 |
+
"""
|
| 26 |
+
inp: (B, C, iX, iY, iZ) tensor
|
| 27 |
+
g: (B, oX, oY, oZ, 3) tensor
|
| 28 |
+
bound: List{3}[Bound] tensor
|
| 29 |
+
extrapolate: ExtrapolateType
|
| 30 |
+
returns: (B, C, oX, oY, oZ) tensor
|
| 31 |
+
"""
|
| 32 |
+
dim = 3
|
| 33 |
+
boundx, boundy, boundz = bound
|
| 34 |
+
oshape = g.shape[-dim-1:-1]
|
| 35 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 36 |
+
gx, gy, gz = g.unbind(-1)
|
| 37 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 38 |
+
channel = inp.shape[1]
|
| 39 |
+
shape = inp.shape[-dim:]
|
| 40 |
+
nx, ny, nz = shape
|
| 41 |
+
|
| 42 |
+
# mask of inbounds voxels
|
| 43 |
+
mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
|
| 44 |
+
|
| 45 |
+
# nearest integer coordinates
|
| 46 |
+
gx, signx = get_indices(gx, nx, boundx)
|
| 47 |
+
gy, signy = get_indices(gy, ny, boundy)
|
| 48 |
+
gz, signz = get_indices(gz, nz, boundz)
|
| 49 |
+
|
| 50 |
+
# gather
|
| 51 |
+
inp = inp.reshape(inp.shape[:2] + [-1])
|
| 52 |
+
idx = sub2ind_list([gx, gy, gz], shape)
|
| 53 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 54 |
+
out = inp.gather(-1, idx)
|
| 55 |
+
sign = make_sign([signx, signy, signz])
|
| 56 |
+
if sign is not None:
|
| 57 |
+
out *= sign
|
| 58 |
+
if mask is not None:
|
| 59 |
+
out *= mask
|
| 60 |
+
out = out.reshape(out.shape[:2] + oshape)
|
| 61 |
+
return out
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@torch.jit.script
|
| 65 |
+
def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 66 |
+
extrapolate: int = 1):
|
| 67 |
+
"""
|
| 68 |
+
inp: (B, C, iX, iY, iZ) tensor
|
| 69 |
+
g: (B, iX, iY, iZ, 3) tensor
|
| 70 |
+
shape: List{3}[int], optional
|
| 71 |
+
bound: List{3}[Bound] tensor
|
| 72 |
+
extrapolate: ExtrapolateType
|
| 73 |
+
returns: (B, C, *shape) tensor
|
| 74 |
+
"""
|
| 75 |
+
dim = 3
|
| 76 |
+
boundx, boundy, boundz = bound
|
| 77 |
+
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
|
| 78 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 79 |
+
ishape = inp.shape[-dim:]
|
| 80 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 81 |
+
gx, gy, gz = torch.unbind(g, -1)
|
| 82 |
+
inp = inp.reshape(inp.shape[:2] + [-1])
|
| 83 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 84 |
+
channel = inp.shape[1]
|
| 85 |
+
|
| 86 |
+
if shape is None:
|
| 87 |
+
shape = ishape
|
| 88 |
+
nx, ny, nz = shape
|
| 89 |
+
|
| 90 |
+
# mask of inbounds voxels
|
| 91 |
+
mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
|
| 92 |
+
|
| 93 |
+
# nearest integer coordinates
|
| 94 |
+
gx, signx = get_indices(gx, nx, boundx)
|
| 95 |
+
gy, signy = get_indices(gy, ny, boundy)
|
| 96 |
+
gz, signz = get_indices(gz, nz, boundz)
|
| 97 |
+
|
| 98 |
+
# scatter
|
| 99 |
+
out = torch.zeros([batch, channel, nx*ny*nz], dtype=inp.dtype, device=inp.device)
|
| 100 |
+
idx = sub2ind_list([gx, gy, gz], shape)
|
| 101 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 102 |
+
sign = make_sign([signx, signy, signz])
|
| 103 |
+
if sign is not None or mask is not None:
|
| 104 |
+
inp = inp.clone()
|
| 105 |
+
if sign is not None:
|
| 106 |
+
inp *= sign
|
| 107 |
+
if mask is not None:
|
| 108 |
+
inp *= mask
|
| 109 |
+
out.scatter_add_(-1, idx, inp)
|
| 110 |
+
|
| 111 |
+
out = out.reshape(out.shape[:2] + shape)
|
| 112 |
+
return out
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ======================================================================
|
| 116 |
+
# 2D
|
| 117 |
+
# ======================================================================
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@torch.jit.script
|
| 121 |
+
def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 122 |
+
"""
|
| 123 |
+
inp: (B, C, iX, iY) tensor
|
| 124 |
+
g: (B, oX, oY, 2) tensor
|
| 125 |
+
bound: List{2}[Bound] tensor
|
| 126 |
+
extrapolate: ExtrapolateType
|
| 127 |
+
returns: (B, C, oX, oY) tensor
|
| 128 |
+
"""
|
| 129 |
+
dim = 2
|
| 130 |
+
boundx, boundy = bound
|
| 131 |
+
oshape = g.shape[-dim-1:-1]
|
| 132 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 133 |
+
gx, gy = g.unbind(-1)
|
| 134 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 135 |
+
channel = inp.shape[1]
|
| 136 |
+
shape = inp.shape[-dim:]
|
| 137 |
+
nx, ny = shape
|
| 138 |
+
|
| 139 |
+
# mask of inbounds voxels
|
| 140 |
+
mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)
|
| 141 |
+
|
| 142 |
+
# nearest integer coordinates
|
| 143 |
+
gx, signx = get_indices(gx, nx, boundx)
|
| 144 |
+
gy, signy = get_indices(gy, ny, boundy)
|
| 145 |
+
|
| 146 |
+
# gather
|
| 147 |
+
inp = inp.reshape(inp.shape[:2] + [-1])
|
| 148 |
+
idx = sub2ind_list([gx, gy], shape)
|
| 149 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 150 |
+
out = inp.gather(-1, idx)
|
| 151 |
+
sign = make_sign([signx, signy])
|
| 152 |
+
if sign is not None:
|
| 153 |
+
out = out * sign
|
| 154 |
+
if mask is not None:
|
| 155 |
+
out = mask * mask
|
| 156 |
+
out = out.reshape(out.shape[:2] + oshape)
|
| 157 |
+
return out
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@torch.jit.script
|
| 161 |
+
def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 162 |
+
extrapolate: int = 1):
|
| 163 |
+
"""
|
| 164 |
+
inp: (B, C, iX, iY) tensor
|
| 165 |
+
g: (B, iX, iY, 2) tensor
|
| 166 |
+
shape: List{2}[int], optional
|
| 167 |
+
bound: List{2}[Bound] tensor
|
| 168 |
+
extrapolate: ExtrapolateType
|
| 169 |
+
returns: (B, C, *shape) tensor
|
| 170 |
+
"""
|
| 171 |
+
dim = 2
|
| 172 |
+
boundx, boundy = bound
|
| 173 |
+
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
|
| 174 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 175 |
+
ishape = inp.shape[-dim:]
|
| 176 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 177 |
+
gx, gy = torch.unbind(g, -1)
|
| 178 |
+
inp = inp.reshape(inp.shape[:2] + [-1])
|
| 179 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 180 |
+
channel = inp.shape[1]
|
| 181 |
+
|
| 182 |
+
if shape is None:
|
| 183 |
+
shape = ishape
|
| 184 |
+
nx, ny = shape
|
| 185 |
+
|
| 186 |
+
# mask of inbounds voxels
|
| 187 |
+
mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)
|
| 188 |
+
|
| 189 |
+
# nearest integer coordinates
|
| 190 |
+
gx, signx = get_indices(gx, nx, boundx)
|
| 191 |
+
gy, signy = get_indices(gy, ny, boundy)
|
| 192 |
+
|
| 193 |
+
# scatter
|
| 194 |
+
out = torch.zeros([batch, channel, nx*ny], dtype=inp.dtype, device=inp.device)
|
| 195 |
+
idx = sub2ind_list([gx, gy], shape)
|
| 196 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 197 |
+
sign = make_sign([signx, signy])
|
| 198 |
+
if sign is not None or mask is not None:
|
| 199 |
+
inp = inp.clone()
|
| 200 |
+
if sign is not None:
|
| 201 |
+
inp = inp * sign
|
| 202 |
+
if mask is not None:
|
| 203 |
+
inp = inp * mask
|
| 204 |
+
out.scatter_add_(-1, idx, inp)
|
| 205 |
+
|
| 206 |
+
out = out.reshape(out.shape[:2] + shape)
|
| 207 |
+
return out
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ======================================================================
|
| 211 |
+
# 1D
|
| 212 |
+
# ======================================================================
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@torch.jit.script
|
| 216 |
+
def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 217 |
+
"""
|
| 218 |
+
inp: (B, C, iX) tensor
|
| 219 |
+
g: (B, oX, 1) tensor
|
| 220 |
+
bound: List{1}[Bound] tensor
|
| 221 |
+
extrapolate: ExtrapolateType
|
| 222 |
+
returns: (B, C, oX) tensor
|
| 223 |
+
"""
|
| 224 |
+
dim = 1
|
| 225 |
+
boundx = bound[0]
|
| 226 |
+
oshape = g.shape[-dim-1:-1]
|
| 227 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 228 |
+
gx = g.squeeze(-1)
|
| 229 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 230 |
+
channel = inp.shape[1]
|
| 231 |
+
shape = inp.shape[-dim:]
|
| 232 |
+
nx = shape[0]
|
| 233 |
+
|
| 234 |
+
# mask of inbounds voxels
|
| 235 |
+
mask = inbounds_mask_1d(extrapolate, gx, nx)
|
| 236 |
+
|
| 237 |
+
# nearest integer coordinates
|
| 238 |
+
gx, signx = get_indices(gx, nx, boundx)
|
| 239 |
+
|
| 240 |
+
# gather
|
| 241 |
+
inp = inp.reshape(inp.shape[:2] + [-1])
|
| 242 |
+
idx = gx
|
| 243 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 244 |
+
out = inp.gather(-1, idx)
|
| 245 |
+
sign = signx
|
| 246 |
+
if sign is not None:
|
| 247 |
+
out = out * sign
|
| 248 |
+
if mask is not None:
|
| 249 |
+
out = out * mask
|
| 250 |
+
out = out.reshape(out.shape[:2] + oshape)
|
| 251 |
+
return out
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
@torch.jit.script
|
| 255 |
+
def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 256 |
+
extrapolate: int = 1):
|
| 257 |
+
"""
|
| 258 |
+
inp: (B, C, iX) tensor
|
| 259 |
+
g: (B, iX, 1) tensor
|
| 260 |
+
shape: List{1}[int], optional
|
| 261 |
+
bound: List{1}[Bound] tensor
|
| 262 |
+
extrapolate: ExtrapolateType
|
| 263 |
+
returns: (B, C, *shape) tensor
|
| 264 |
+
"""
|
| 265 |
+
dim = 1
|
| 266 |
+
boundx = bound[0]
|
| 267 |
+
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
|
| 268 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 269 |
+
ishape = inp.shape[-dim:]
|
| 270 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 271 |
+
gx = g.squeeze(-1)
|
| 272 |
+
inp = inp.reshape(inp.shape[:2] + [-1])
|
| 273 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 274 |
+
channel = inp.shape[1]
|
| 275 |
+
|
| 276 |
+
if shape is None:
|
| 277 |
+
shape = ishape
|
| 278 |
+
nx = shape[0]
|
| 279 |
+
|
| 280 |
+
# mask of inbounds voxels
|
| 281 |
+
mask = inbounds_mask_1d(extrapolate, gx, nx)
|
| 282 |
+
|
| 283 |
+
# nearest integer coordinates
|
| 284 |
+
gx, signx = get_indices(gx, nx, boundx)
|
| 285 |
+
|
| 286 |
+
# scatter
|
| 287 |
+
out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device)
|
| 288 |
+
idx = gx
|
| 289 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 290 |
+
sign = signx
|
| 291 |
+
if sign is not None or mask is not None:
|
| 292 |
+
inp = inp.clone()
|
| 293 |
+
if sign is not None:
|
| 294 |
+
inp = inp * sign
|
| 295 |
+
if mask is not None:
|
| 296 |
+
inp = inp * mask
|
| 297 |
+
out.scatter_add_(-1, idx, inp)
|
| 298 |
+
|
| 299 |
+
out = out.reshape(out.shape[:2] + shape)
|
| 300 |
+
return out
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# ======================================================================
|
| 304 |
+
# ND
|
| 305 |
+
# ======================================================================
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@torch.jit.script
|
| 309 |
+
def grad(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 310 |
+
"""
|
| 311 |
+
inp: (B, C, *ishape) tensor
|
| 312 |
+
g: (B, *oshape, D) tensor
|
| 313 |
+
bound: List{D}[Bound] tensor
|
| 314 |
+
extrapolate: ExtrapolateType
|
| 315 |
+
returns: (B, C, *oshape, D) tensor
|
| 316 |
+
"""
|
| 317 |
+
dim = g.shape[-1]
|
| 318 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 319 |
+
batch = max(inp.shape[0], g.shape[0])
|
| 320 |
+
channel = inp.shape[1]
|
| 321 |
+
|
| 322 |
+
return torch.zeros([batch, channel] + oshape + [dim],
|
| 323 |
+
dtype=inp.dtype, device=inp.device)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
@torch.jit.script
|
| 327 |
+
def pushgrad(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 328 |
+
extrapolate: int = 1):
|
| 329 |
+
"""
|
| 330 |
+
inp: (B, C, *ishape, D) tensor
|
| 331 |
+
g: (B, *ishape, D) tensor
|
| 332 |
+
shape: List{D}[int], optional, optional
|
| 333 |
+
bound: List{D}[Bound] tensor
|
| 334 |
+
extrapolate: ExtrapolateType
|
| 335 |
+
returns: (B, C, *shape) tensor
|
| 336 |
+
"""
|
| 337 |
+
dim = g.shape[-1]
|
| 338 |
+
if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
|
| 339 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 340 |
+
ishape = inp.shape[-dim-1:-1]
|
| 341 |
+
batch = max(inp.shape[0], g.shape[0])
|
| 342 |
+
channel = inp.shape[1]
|
| 343 |
+
|
| 344 |
+
if shape is None:
|
| 345 |
+
shape = ishape
|
| 346 |
+
shape = list(shape)
|
| 347 |
+
|
| 348 |
+
return torch.zeros([batch, channel] + shape,
|
| 349 |
+
dtype=inp.dtype, device=inp.device)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
@torch.jit.script
|
| 353 |
+
def hess(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 354 |
+
"""
|
| 355 |
+
inp: (B, C, *ishape) tensor
|
| 356 |
+
g: (B, *oshape, D) tensor
|
| 357 |
+
bound: List{D}[Bound] tensor
|
| 358 |
+
extrapolate: ExtrapolateType
|
| 359 |
+
returns: (B, C, *oshape, D, D) tensor
|
| 360 |
+
"""
|
| 361 |
+
dim = g.shape[-1]
|
| 362 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 363 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 364 |
+
batch = max(inp.shape[0], g.shape[0])
|
| 365 |
+
channel = inp.shape[1]
|
| 366 |
+
|
| 367 |
+
return torch.zeros([batch, channel] + oshape + [dim, dim],
|
| 368 |
+
dtype=inp.dtype, device=inp.device)
|
Generator/interpol/iso1.py
ADDED
|
@@ -0,0 +1,1339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Isotropic 1-st order splines ("linear/bilinear/trilinear")"""
|
| 2 |
+
import torch
|
| 3 |
+
from .bounds import Bound
|
| 4 |
+
from .jit_utils import (sub2ind_list, make_sign,
|
| 5 |
+
inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d)
|
| 6 |
+
from typing import List, Tuple, Optional
|
| 7 |
+
Tensor = torch.Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.jit.script
|
| 11 |
+
def get_weights_and_indices(g, n: int, bound: Bound) \
|
| 12 |
+
-> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
| 13 |
+
g0 = g.floor().long()
|
| 14 |
+
g1 = g0 + 1
|
| 15 |
+
sign1 = bound.transform(g1, n)
|
| 16 |
+
sign0 = bound.transform(g0, n)
|
| 17 |
+
g1 = bound.index(g1, n)
|
| 18 |
+
g0 = bound.index(g0, n)
|
| 19 |
+
g = g - g.floor()
|
| 20 |
+
return g, g0, g1, sign0, sign1
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ======================================================================
|
| 24 |
+
# 3D
|
| 25 |
+
# ======================================================================
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@torch.jit.script
|
| 29 |
+
def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 30 |
+
"""
|
| 31 |
+
inp: (B, C, iX, iY, iZ) tensor
|
| 32 |
+
g: (B, oX, oY, oZ, 3) tensor
|
| 33 |
+
bound: List{3}[Bound] tensor
|
| 34 |
+
extrapolate: ExtrapolateType
|
| 35 |
+
returns: (B, C, oX, oY, oZ) tensor
|
| 36 |
+
"""
|
| 37 |
+
dim = 3
|
| 38 |
+
boundx, boundy, boundz = bound
|
| 39 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 40 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 41 |
+
gx, gy, gz = g.unbind(-1)
|
| 42 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 43 |
+
channel = inp.shape[1]
|
| 44 |
+
shape = list(inp.shape[-dim:])
|
| 45 |
+
nx, ny, nz = shape
|
| 46 |
+
|
| 47 |
+
# mask of inbounds voxels
|
| 48 |
+
mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
|
| 49 |
+
|
| 50 |
+
# corners
|
| 51 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 52 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 53 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 54 |
+
gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)
|
| 55 |
+
|
| 56 |
+
# gather
|
| 57 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 58 |
+
# - corner 000
|
| 59 |
+
idx = sub2ind_list([gx0, gy0, gz0], shape)
|
| 60 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 61 |
+
out = inp.gather(-1, idx)
|
| 62 |
+
sign = make_sign([signx0, signy0, signz0])
|
| 63 |
+
if sign is not None:
|
| 64 |
+
out = out * sign
|
| 65 |
+
out = out * ((1 - gx) * (1 - gy) * (1 - gz))
|
| 66 |
+
# - corner 001
|
| 67 |
+
idx = sub2ind_list([gx0, gy0, gz1], shape)
|
| 68 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 69 |
+
out1 = inp.gather(-1, idx)
|
| 70 |
+
sign = make_sign([signx0, signy0, signz1])
|
| 71 |
+
if sign is not None:
|
| 72 |
+
out1 = out1 * sign
|
| 73 |
+
out1 = out1 * ((1 - gx) * (1 - gy) * gz)
|
| 74 |
+
out = out + out1
|
| 75 |
+
# - corner 010
|
| 76 |
+
idx = sub2ind_list([gx0, gy1, gz0], shape)
|
| 77 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 78 |
+
out1 = inp.gather(-1, idx)
|
| 79 |
+
sign = make_sign([signx0, signy1, signz0])
|
| 80 |
+
if sign is not None:
|
| 81 |
+
out1 = out1 * sign
|
| 82 |
+
out1 = out1 * ((1 - gx) * gy * (1 - gz))
|
| 83 |
+
out = out + out1
|
| 84 |
+
# - corner 011
|
| 85 |
+
idx = sub2ind_list([gx0, gy1, gz1], shape)
|
| 86 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 87 |
+
out1 = inp.gather(-1, idx)
|
| 88 |
+
sign = make_sign([signx0, signy1, signz1])
|
| 89 |
+
if sign is not None:
|
| 90 |
+
out1 = out1 * sign
|
| 91 |
+
out1 = out1 * ((1 - gx) * gy * gz)
|
| 92 |
+
out = out + out1
|
| 93 |
+
# - corner 100
|
| 94 |
+
idx = sub2ind_list([gx1, gy0, gz0], shape)
|
| 95 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 96 |
+
out1 = inp.gather(-1, idx)
|
| 97 |
+
sign = make_sign([signx1, signy0, signz0])
|
| 98 |
+
if sign is not None:
|
| 99 |
+
out1 = out1 * sign
|
| 100 |
+
out1 = out1 * (gx * (1 - gy) * (1 - gz))
|
| 101 |
+
out = out + out1
|
| 102 |
+
# - corner 101
|
| 103 |
+
idx = sub2ind_list([gx1, gy0, gz1], shape)
|
| 104 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 105 |
+
out1 = inp.gather(-1, idx)
|
| 106 |
+
sign = make_sign([signx1, signy0, signz1])
|
| 107 |
+
if sign is not None:
|
| 108 |
+
out1 = out1 * sign
|
| 109 |
+
out1 = out1 * (gx * (1 - gy) * gz)
|
| 110 |
+
out = out + out1
|
| 111 |
+
# - corner 110
|
| 112 |
+
idx = sub2ind_list([gx1, gy1, gz0], shape)
|
| 113 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 114 |
+
out1 = inp.gather(-1, idx)
|
| 115 |
+
sign = make_sign([signx1, signy1, signz0])
|
| 116 |
+
if sign is not None:
|
| 117 |
+
out1 = out1 * sign
|
| 118 |
+
out1 = out1 * (gx * gy * (1 - gz))
|
| 119 |
+
out = out + out1
|
| 120 |
+
# - corner 111
|
| 121 |
+
idx = sub2ind_list([gx1, gy1, gz1], shape)
|
| 122 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 123 |
+
out1 = inp.gather(-1, idx)
|
| 124 |
+
sign = make_sign([signx1, signy1, signz1])
|
| 125 |
+
if sign is not None:
|
| 126 |
+
out1 = out1 * sign
|
| 127 |
+
out1 = out1 * (gx * gy * gz)
|
| 128 |
+
out = out + out1
|
| 129 |
+
|
| 130 |
+
if mask is not None:
|
| 131 |
+
out *= mask
|
| 132 |
+
out = out.reshape(list(out.shape[:2]) + oshape)
|
| 133 |
+
return out
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@torch.jit.script
|
| 137 |
+
def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 138 |
+
extrapolate: int = 1):
|
| 139 |
+
"""
|
| 140 |
+
inp: (B, C, iX, iY, iZ) tensor
|
| 141 |
+
g: (B, iX, iY, iZ, 3) tensor
|
| 142 |
+
shape: List{3}[int], optional
|
| 143 |
+
bound: List{3}[Bound] tensor
|
| 144 |
+
extrapolate: ExtrapolateType
|
| 145 |
+
returns: (B, C, *shape) tensor
|
| 146 |
+
"""
|
| 147 |
+
dim = 3
|
| 148 |
+
boundx, boundy, boundz = bound
|
| 149 |
+
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
|
| 150 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 151 |
+
ishape = list(inp.shape[-dim:])
|
| 152 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 153 |
+
gx, gy, gz = torch.unbind(g, -1)
|
| 154 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 155 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 156 |
+
channel = inp.shape[1]
|
| 157 |
+
|
| 158 |
+
if shape is None:
|
| 159 |
+
shape = ishape
|
| 160 |
+
shape = list(shape)
|
| 161 |
+
nx, ny, nz = shape
|
| 162 |
+
|
| 163 |
+
# mask of inbounds voxels
|
| 164 |
+
mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
|
| 165 |
+
|
| 166 |
+
# corners
|
| 167 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 168 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 169 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 170 |
+
gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)
|
| 171 |
+
|
| 172 |
+
# scatter
|
| 173 |
+
out = torch.zeros([batch, channel, nx*ny*nz],
|
| 174 |
+
dtype=inp.dtype, device=inp.device)
|
| 175 |
+
# - corner 000
|
| 176 |
+
idx = sub2ind_list([gx0, gy0, gz0], shape)
|
| 177 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 178 |
+
out1 = inp.clone()
|
| 179 |
+
sign = make_sign([signx0, signy0, signz0])
|
| 180 |
+
if sign is not None:
|
| 181 |
+
out1 = out1 * sign
|
| 182 |
+
if mask is not None:
|
| 183 |
+
out1 = out1 * mask
|
| 184 |
+
out1 = out1 * ((1 - gx) * (1 - gy) * (1 - gz))
|
| 185 |
+
out.scatter_add_(-1, idx, out1)
|
| 186 |
+
# - corner 001
|
| 187 |
+
idx = sub2ind_list([gx0, gy0, gz1], shape)
|
| 188 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 189 |
+
out1 = inp.clone()
|
| 190 |
+
sign = make_sign([signx0, signy0, signz1])
|
| 191 |
+
if sign is not None:
|
| 192 |
+
out1 = out1 * sign
|
| 193 |
+
if mask is not None:
|
| 194 |
+
out1 = out1 * mask
|
| 195 |
+
out1 = out1 * ((1 - gx) * (1 - gy) * gz)
|
| 196 |
+
out.scatter_add_(-1, idx, out1)
|
| 197 |
+
# - corner 010
|
| 198 |
+
idx = sub2ind_list([gx0, gy1, gz0], shape)
|
| 199 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 200 |
+
out1 = inp.clone()
|
| 201 |
+
sign = make_sign([signx0, signy1, signz0])
|
| 202 |
+
if sign is not None:
|
| 203 |
+
out1 = out1 * sign
|
| 204 |
+
if mask is not None:
|
| 205 |
+
out1 = out1 * mask
|
| 206 |
+
out1 = out1 * ((1 - gx) * gy * (1 - gz))
|
| 207 |
+
out.scatter_add_(-1, idx, out1)
|
| 208 |
+
# - corner 011
|
| 209 |
+
idx = sub2ind_list([gx0, gy1, gz1], shape)
|
| 210 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 211 |
+
out1 = inp.clone()
|
| 212 |
+
sign = make_sign([signx0, signy1, signz1])
|
| 213 |
+
if sign is not None:
|
| 214 |
+
out1 = out1 * sign
|
| 215 |
+
if mask is not None:
|
| 216 |
+
out1 = out1 * mask
|
| 217 |
+
out1 = out1 * ((1 - gx) * gy * gz)
|
| 218 |
+
out.scatter_add_(-1, idx, out1)
|
| 219 |
+
# - corner 100
|
| 220 |
+
idx = sub2ind_list([gx1, gy0, gz0], shape)
|
| 221 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 222 |
+
out1 = inp.clone()
|
| 223 |
+
sign = make_sign([signx1, signy0, signz0])
|
| 224 |
+
if sign is not None:
|
| 225 |
+
out1 = out1 * sign
|
| 226 |
+
if mask is not None:
|
| 227 |
+
out1 = out1 * mask
|
| 228 |
+
out1 = out1 * (gx * (1 - gy) * (1 - gz))
|
| 229 |
+
out.scatter_add_(-1, idx, out1)
|
| 230 |
+
# - corner 101
|
| 231 |
+
idx = sub2ind_list([gx1, gy0, gz1], shape)
|
| 232 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 233 |
+
out1 = inp.clone()
|
| 234 |
+
sign = make_sign([signx1, signy0, signz1])
|
| 235 |
+
if sign is not None:
|
| 236 |
+
out1 = out1 * sign
|
| 237 |
+
if mask is not None:
|
| 238 |
+
out1 = out1 * mask
|
| 239 |
+
out1 = out1 * (gx * (1 - gy) * gz)
|
| 240 |
+
out.scatter_add_(-1, idx, out1)
|
| 241 |
+
# - corner 110
|
| 242 |
+
idx = sub2ind_list([gx1, gy1, gz0], shape)
|
| 243 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 244 |
+
out1 = inp.clone()
|
| 245 |
+
sign = make_sign([signx1, signy1, signz0])
|
| 246 |
+
if sign is not None:
|
| 247 |
+
out1 = out1 * sign
|
| 248 |
+
if mask is not None:
|
| 249 |
+
out1 = out1 * mask
|
| 250 |
+
out1 = out1 * (gx * gy * (1 - gz))
|
| 251 |
+
out.scatter_add_(-1, idx, out1)
|
| 252 |
+
# - corner 111
|
| 253 |
+
idx = sub2ind_list([gx1, gy1, gz1], shape)
|
| 254 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 255 |
+
out1 = inp.clone()
|
| 256 |
+
sign = make_sign([signx1, signy1, signz1])
|
| 257 |
+
if sign is not None:
|
| 258 |
+
out1 = out1 * sign
|
| 259 |
+
if mask is not None:
|
| 260 |
+
out1 = out1 * mask
|
| 261 |
+
out1 = out1 * (gx * gy * gz)
|
| 262 |
+
out.scatter_add_(-1, idx, out1)
|
| 263 |
+
|
| 264 |
+
out = out.reshape(list(out.shape[:2]) + shape)
|
| 265 |
+
return out
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@torch.jit.script
|
| 269 |
+
def grad3d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 270 |
+
"""
|
| 271 |
+
inp: (B, C, iX, iY, iZ) tensor
|
| 272 |
+
g: (B, oX, oY, oZ, 3) tensor
|
| 273 |
+
bound: List{3}[Bound] tensor
|
| 274 |
+
extrapolate: ExtrapolateType
|
| 275 |
+
returns: (B, C, oX, oY, oZ, 3) tensor
|
| 276 |
+
"""
|
| 277 |
+
dim = 3
|
| 278 |
+
boundx, boundy, boundz = bound
|
| 279 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 280 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 281 |
+
gx, gy, gz = torch.unbind(g, -1)
|
| 282 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 283 |
+
channel = inp.shape[1]
|
| 284 |
+
shape = list(inp.shape[-dim:])
|
| 285 |
+
nx, ny, nz = shape
|
| 286 |
+
|
| 287 |
+
# mask of inbounds voxels
|
| 288 |
+
mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
|
| 289 |
+
|
| 290 |
+
# corners
|
| 291 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 292 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 293 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 294 |
+
gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)
|
| 295 |
+
|
| 296 |
+
# gather
|
| 297 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 298 |
+
out = torch.empty([batch, channel] + list(g.shape[-2:]),
|
| 299 |
+
dtype=inp.dtype, device=inp.device)
|
| 300 |
+
outx, outy, outz = out.unbind(-1)
|
| 301 |
+
# - corner 000
|
| 302 |
+
idx = sub2ind_list([gx0, gy0, gz0], shape)
|
| 303 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 304 |
+
torch.gather(inp, -1, idx, out=outx)
|
| 305 |
+
outy.copy_(outx)
|
| 306 |
+
outz.copy_(outx)
|
| 307 |
+
sign = make_sign([signx0, signy0, signz0])
|
| 308 |
+
if sign is not None:
|
| 309 |
+
out *= sign.unsqueeze(-1)
|
| 310 |
+
outx *= - (1 - gy) * (1 - gz)
|
| 311 |
+
outy *= - (1 - gx) * (1 - gz)
|
| 312 |
+
outz *= - (1 - gx) * (1 - gy)
|
| 313 |
+
# - corner 001
|
| 314 |
+
idx = sub2ind_list([gx0, gy0, gz1], shape)
|
| 315 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 316 |
+
out1 = inp.gather(-1, idx)
|
| 317 |
+
sign = make_sign([signx0, signy0, signz1])
|
| 318 |
+
if sign is not None:
|
| 319 |
+
out1 *= sign
|
| 320 |
+
outx.addcmul_(out1, - (1 - gy) * gz)
|
| 321 |
+
outy.addcmul_(out1, - (1 - gx) * gz)
|
| 322 |
+
outz.addcmul_(out1, (1 - gx) * (1 - gy))
|
| 323 |
+
# - corner 010
|
| 324 |
+
idx = sub2ind_list([gx0, gy1, gz0], shape)
|
| 325 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 326 |
+
out1 = inp.gather(-1, idx)
|
| 327 |
+
sign = make_sign([signx0, signy1, signz0])
|
| 328 |
+
if sign is not None:
|
| 329 |
+
out1 *= sign
|
| 330 |
+
outx.addcmul_(out1, - gy * (1 - gz))
|
| 331 |
+
outy.addcmul_(out1, (1 - gx) * (1 - gz))
|
| 332 |
+
outz.addcmul_(out1, - (1 - gx) * gy)
|
| 333 |
+
# - corner 011
|
| 334 |
+
idx = sub2ind_list([gx0, gy1, gz1], shape)
|
| 335 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 336 |
+
out1 = inp.gather(-1, idx)
|
| 337 |
+
sign = make_sign([signx0, signy1, signz1])
|
| 338 |
+
if sign is not None:
|
| 339 |
+
out1 *= sign
|
| 340 |
+
outx.addcmul_(out1, - gy * gz)
|
| 341 |
+
outy.addcmul_(out1, (1 - gx) * gz)
|
| 342 |
+
outz.addcmul_(out1, (1 - gx) * gy)
|
| 343 |
+
# - corner 100
|
| 344 |
+
idx = sub2ind_list([gx1, gy0, gz0], shape)
|
| 345 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 346 |
+
out1 = inp.gather(-1, idx)
|
| 347 |
+
sign = make_sign([signx1, signy0, signz0])
|
| 348 |
+
if sign is not None:
|
| 349 |
+
out1 *= sign
|
| 350 |
+
outx.addcmul_(out1, (1 - gy) * (1 - gz))
|
| 351 |
+
outy.addcmul_(out1, - gx * (1 - gz))
|
| 352 |
+
outz.addcmul_(out1, - gx * (1 - gy))
|
| 353 |
+
# - corner 101
|
| 354 |
+
idx = sub2ind_list([gx1, gy0, gz1], shape)
|
| 355 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 356 |
+
out1 = inp.gather(-1, idx)
|
| 357 |
+
sign = make_sign([signx1, signy0, signz1])
|
| 358 |
+
if sign is not None:
|
| 359 |
+
out1 *= sign
|
| 360 |
+
outx.addcmul_(out1, (1 - gy) * gz)
|
| 361 |
+
outy.addcmul_(out1, - gx * gz)
|
| 362 |
+
outz.addcmul_(out1, gx * (1 - gy))
|
| 363 |
+
# - corner 110
|
| 364 |
+
idx = sub2ind_list([gx1, gy1, gz0], shape)
|
| 365 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 366 |
+
out1 = inp.gather(-1, idx)
|
| 367 |
+
sign = make_sign([signx1, signy1, signz0])
|
| 368 |
+
if sign is not None:
|
| 369 |
+
out1 *= sign
|
| 370 |
+
outx.addcmul_(out1, gy * (1 - gz))
|
| 371 |
+
outy.addcmul_(out1, gx * (1 - gz))
|
| 372 |
+
outz.addcmul_(out1, - gx * gy)
|
| 373 |
+
# - corner 111
|
| 374 |
+
idx = sub2ind_list([gx1, gy1, gz1], shape)
|
| 375 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 376 |
+
out1 = inp.gather(-1, idx)
|
| 377 |
+
sign = make_sign([signx1, signy1, signz1])
|
| 378 |
+
if sign is not None:
|
| 379 |
+
out1 *= sign
|
| 380 |
+
outx.addcmul_(out1, gy * gz)
|
| 381 |
+
outy.addcmul_(out1, gx * gz)
|
| 382 |
+
outz.addcmul_(out1, gx * gy)
|
| 383 |
+
|
| 384 |
+
if mask is not None:
|
| 385 |
+
out *= mask.unsqueeze(-1)
|
| 386 |
+
out = out.reshape(list(out.shape[:2]) + oshape + [3])
|
| 387 |
+
return out
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
@torch.jit.script
|
| 391 |
+
def pushgrad3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 392 |
+
extrapolate: int = 1):
|
| 393 |
+
"""
|
| 394 |
+
inp: (B, C, iX, iY, iZ, 3) tensor
|
| 395 |
+
g: (B, iX, iY, iZ, 3) tensor
|
| 396 |
+
shape: List{3}[int], optional
|
| 397 |
+
bound: List{3}[Bound] tensor
|
| 398 |
+
extrapolate: ExtrapolateType
|
| 399 |
+
returns: (B, C, *shape) tensor
|
| 400 |
+
"""
|
| 401 |
+
dim = 3
|
| 402 |
+
boundx, boundy, boundz = bound
|
| 403 |
+
if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
|
| 404 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 405 |
+
ishape = list(inp.shape[-dim-1:-1])
|
| 406 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 407 |
+
gx, gy, gz = g.unbind(-1)
|
| 408 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1, dim])
|
| 409 |
+
batch = max(inp.shape[0], g.shape[0])
|
| 410 |
+
channel = inp.shape[1]
|
| 411 |
+
|
| 412 |
+
if shape is None:
|
| 413 |
+
shape = ishape
|
| 414 |
+
shape = list(shape)
|
| 415 |
+
nx, ny, nz = shape
|
| 416 |
+
|
| 417 |
+
# mask of inbounds voxels
|
| 418 |
+
mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
|
| 419 |
+
|
| 420 |
+
# corners
|
| 421 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 422 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 423 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 424 |
+
gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)
|
| 425 |
+
|
| 426 |
+
# scatter
|
| 427 |
+
out = torch.zeros([batch, channel, nx*ny*nz],
|
| 428 |
+
dtype=inp.dtype, device=inp.device)
|
| 429 |
+
# - corner 000
|
| 430 |
+
idx = sub2ind_list([gx0, gy0, gz0], shape)
|
| 431 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 432 |
+
out1 = inp.clone()
|
| 433 |
+
sign = make_sign([signx0, signy0, signz0])
|
| 434 |
+
if sign is not None:
|
| 435 |
+
out1 *= sign.unsqueeze(-1)
|
| 436 |
+
if mask is not None:
|
| 437 |
+
out1 *= mask.unsqueeze(-1)
|
| 438 |
+
out1x, out1y, out1z = out1.unbind(-1)
|
| 439 |
+
out1x *= - (1 - gy) * (1 - gz)
|
| 440 |
+
out1y *= - (1 - gx) * (1 - gz)
|
| 441 |
+
out1z *= - (1 - gx) * (1 - gy)
|
| 442 |
+
out.scatter_add_(-1, idx, out1x + out1y + out1z)
|
| 443 |
+
# - corner 001
|
| 444 |
+
idx = sub2ind_list([gx0, gy0, gz1], shape)
|
| 445 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 446 |
+
out1 = inp.clone()
|
| 447 |
+
sign = make_sign([signx0, signy0, signz1])
|
| 448 |
+
if sign is not None:
|
| 449 |
+
out1 *= sign.unsqueeze(-1)
|
| 450 |
+
if mask is not None:
|
| 451 |
+
out1 *= mask.unsqueeze(-1)
|
| 452 |
+
out1x, out1y, out1z = out1.unbind(-1)
|
| 453 |
+
out1x *= - (1 - gy) * gz
|
| 454 |
+
out1y *= - (1 - gx) * gz
|
| 455 |
+
out1z *= (1 - gx) * (1 - gy)
|
| 456 |
+
out.scatter_add_(-1, idx, out1x + out1y + out1z)
|
| 457 |
+
# - corner 010
|
| 458 |
+
idx = sub2ind_list([gx0, gy1, gz0], shape)
|
| 459 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 460 |
+
out1 = inp.clone()
|
| 461 |
+
sign = make_sign([signx0, signy1, signz0])
|
| 462 |
+
if sign is not None:
|
| 463 |
+
out1 *= sign.unsqueeze(-1)
|
| 464 |
+
if mask is not None:
|
| 465 |
+
out1 *= mask.unsqueeze(-1)
|
| 466 |
+
out1x, out1y, out1z = out1.unbind(-1)
|
| 467 |
+
out1x *= - gy * (1 - gz)
|
| 468 |
+
out1y *= (1 - gx) * (1 - gz)
|
| 469 |
+
out1z *= - (1 - gx) * gy
|
| 470 |
+
out.scatter_add_(-1, idx, out1x + out1y + out1z)
|
| 471 |
+
# - corner 011
|
| 472 |
+
idx = sub2ind_list([gx0, gy1, gz1], shape)
|
| 473 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 474 |
+
out1 = inp.clone()
|
| 475 |
+
sign = make_sign([signx0, signy1, signz1])
|
| 476 |
+
if sign is not None:
|
| 477 |
+
out1 *= sign.unsqueeze(-1)
|
| 478 |
+
if mask is not None:
|
| 479 |
+
out1 *= mask.unsqueeze(-1)
|
| 480 |
+
out1x, out1y, out1z = out1.unbind(-1)
|
| 481 |
+
out1x *= - gy * gz
|
| 482 |
+
out1y *= (1 - gx) * gz
|
| 483 |
+
out1z *= (1 - gx) * gy
|
| 484 |
+
out.scatter_add_(-1, idx, out1x + out1y + out1z)
|
| 485 |
+
# - corner 100
|
| 486 |
+
idx = sub2ind_list([gx1, gy0, gz0], shape)
|
| 487 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 488 |
+
out1 = inp.clone()
|
| 489 |
+
sign = make_sign([signx1, signy0, signz0])
|
| 490 |
+
if sign is not None:
|
| 491 |
+
out1 *= sign.unsqueeze(-1)
|
| 492 |
+
if mask is not None:
|
| 493 |
+
out1 *= mask.unsqueeze(-1)
|
| 494 |
+
out1x, out1y, out1z = out1.unbind(-1)
|
| 495 |
+
out1x *= (1 - gy) * (1 - gz)
|
| 496 |
+
out1y *= - gx * (1 - gz)
|
| 497 |
+
out1z *= - gx * (1 - gy)
|
| 498 |
+
out.scatter_add_(-1, idx, out1x + out1y + out1z)
|
| 499 |
+
# - corner 101
|
| 500 |
+
idx = sub2ind_list([gx1, gy0, gz1], shape)
|
| 501 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 502 |
+
out1 = inp.clone()
|
| 503 |
+
sign = make_sign([signx1, signy0, signz1])
|
| 504 |
+
if sign is not None:
|
| 505 |
+
out1 *= sign.unsqueeze(-1)
|
| 506 |
+
if mask is not None:
|
| 507 |
+
out1 *= mask.unsqueeze(-1)
|
| 508 |
+
out1x, out1y, out1z = out1.unbind(-1)
|
| 509 |
+
out1x *= (1 - gy) * gz
|
| 510 |
+
out1y *= - gx * gz
|
| 511 |
+
out1z *= gx * (1 - gy)
|
| 512 |
+
out.scatter_add_(-1, idx, out1x + out1y + out1z)
|
| 513 |
+
# - corner 110
|
| 514 |
+
idx = sub2ind_list([gx1, gy1, gz0], shape)
|
| 515 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 516 |
+
out1 = inp.clone()
|
| 517 |
+
sign = make_sign([signx1, signy1, signz0])
|
| 518 |
+
if sign is not None:
|
| 519 |
+
out1 *= sign.unsqueeze(-1)
|
| 520 |
+
if mask is not None:
|
| 521 |
+
out1 *= mask.unsqueeze(-1)
|
| 522 |
+
out1x, out1y, out1z = out1.unbind(-1)
|
| 523 |
+
out1x *= gy * (1 - gz)
|
| 524 |
+
out1y *= gx * (1 - gz)
|
| 525 |
+
out1z *= - gx * gy
|
| 526 |
+
out.scatter_add_(-1, idx, out1x + out1y + out1z)
|
| 527 |
+
# - corner 111
|
| 528 |
+
idx = sub2ind_list([gx1, gy1, gz1], shape)
|
| 529 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 530 |
+
out1 = inp.clone()
|
| 531 |
+
sign = make_sign([signx1, signy1, signz1])
|
| 532 |
+
if sign is not None:
|
| 533 |
+
out1 *= sign.unsqueeze(-1)
|
| 534 |
+
if mask is not None:
|
| 535 |
+
out1 *= mask.unsqueeze(-1)
|
| 536 |
+
out1x, out1y, out1z = out1.unbind(-1)
|
| 537 |
+
out1x *= gy * gz
|
| 538 |
+
out1y *= gx * gz
|
| 539 |
+
out1z *= gx * gy
|
| 540 |
+
out.scatter_add_(-1, idx, out1x + out1y + out1z)
|
| 541 |
+
|
| 542 |
+
out = out.reshape(list(out.shape[:2]) + shape)
|
| 543 |
+
return out
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
@torch.jit.script
|
| 547 |
+
def hess3d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 548 |
+
"""
|
| 549 |
+
inp: (B, C, iX, iY, iZ) tensor
|
| 550 |
+
g: (B, oX, oY, oZ, 3) tensor
|
| 551 |
+
bound: List{3}[Bound] tensor
|
| 552 |
+
extrapolate: ExtrapolateType
|
| 553 |
+
returns: (B, C, oX, oY, oZ, 3, 3) tensor
|
| 554 |
+
"""
|
| 555 |
+
dim = 3
|
| 556 |
+
boundx, boundy, boundz = bound
|
| 557 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 558 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 559 |
+
gx, gy, gz = torch.unbind(g, -1)
|
| 560 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 561 |
+
channel = inp.shape[1]
|
| 562 |
+
shape = list(inp.shape[-dim:])
|
| 563 |
+
nx, ny, nz = shape
|
| 564 |
+
|
| 565 |
+
# mask of inbounds voxels
|
| 566 |
+
mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
|
| 567 |
+
|
| 568 |
+
# corners
|
| 569 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 570 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 571 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 572 |
+
gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)
|
| 573 |
+
|
| 574 |
+
# gather
|
| 575 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 576 |
+
out = torch.empty([batch, channel, g.shape[-2], dim, dim],
|
| 577 |
+
dtype=inp.dtype, device=inp.device)
|
| 578 |
+
outx, outy, outz = out.unbind(-1)
|
| 579 |
+
outxx, outyx, outzx = outx.unbind(-1)
|
| 580 |
+
outxy, outyy, outzy = outy.unbind(-1)
|
| 581 |
+
outxz, outyz, outzz = outz.unbind(-1)
|
| 582 |
+
# - corner 000
|
| 583 |
+
idx = sub2ind_list([gx0, gy0, gz0], shape)
|
| 584 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 585 |
+
torch.gather(inp, -1, idx, out=outxy)
|
| 586 |
+
outxz.copy_(outxy)
|
| 587 |
+
outyz.copy_(outxy)
|
| 588 |
+
outxx.zero_()
|
| 589 |
+
outyy.zero_()
|
| 590 |
+
outzz.zero_()
|
| 591 |
+
sign = make_sign([signx0, signy0, signz0])
|
| 592 |
+
if sign is not None:
|
| 593 |
+
out *= sign.unsqueeze(-1).unsqueeze(-1)
|
| 594 |
+
outxy *= (1 - gz)
|
| 595 |
+
outxz *= (1 - gy)
|
| 596 |
+
outyz *= (1 - gx)
|
| 597 |
+
# - corner 001
|
| 598 |
+
idx = sub2ind_list([gx0, gy0, gz1], shape)
|
| 599 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 600 |
+
out1 = inp.gather(-1, idx)
|
| 601 |
+
sign = make_sign([signx0, signy0, signz1])
|
| 602 |
+
if sign is not None:
|
| 603 |
+
out1 *= sign
|
| 604 |
+
outxy.addcmul_(out1, gz)
|
| 605 |
+
outxz.addcmul_(out1, - (1 - gy))
|
| 606 |
+
outyz.addcmul_(out1, - (1 - gx))
|
| 607 |
+
# - corner 010
|
| 608 |
+
idx = sub2ind_list([gx0, gy1, gz0], shape)
|
| 609 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 610 |
+
out1 = inp.gather(-1, idx)
|
| 611 |
+
sign = make_sign([signx0, signy1, signz0])
|
| 612 |
+
if sign is not None:
|
| 613 |
+
out1 *= sign
|
| 614 |
+
outxy.addcmul_(out1, - (1 - gz))
|
| 615 |
+
outxz.addcmul_(out1, gy)
|
| 616 |
+
outyz.addcmul_(out1, - (1 - gx))
|
| 617 |
+
# - corner 011
|
| 618 |
+
idx = sub2ind_list([gx0, gy1, gz1], shape)
|
| 619 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 620 |
+
out1 = inp.gather(-1, idx)
|
| 621 |
+
sign = make_sign([signx0, signy1, signz1])
|
| 622 |
+
if sign is not None:
|
| 623 |
+
out1 *= sign
|
| 624 |
+
outxy.addcmul_(out1, - gz)
|
| 625 |
+
outxz.addcmul_(out1, - gy)
|
| 626 |
+
outyz.addcmul_(out1, (1 - gx))
|
| 627 |
+
# - corner 100
|
| 628 |
+
idx = sub2ind_list([gx1, gy0, gz0], shape)
|
| 629 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 630 |
+
out1 = inp.gather(-1, idx)
|
| 631 |
+
sign = make_sign([signx1, signy0, signz0])
|
| 632 |
+
if sign is not None:
|
| 633 |
+
out1 *= sign
|
| 634 |
+
outxy.addcmul_(out1, - (1 - gz))
|
| 635 |
+
outxz.addcmul_(out1, - (1 - gy))
|
| 636 |
+
outyz.addcmul_(out1, gx)
|
| 637 |
+
# - corner 101
|
| 638 |
+
idx = sub2ind_list([gx1, gy0, gz1], shape)
|
| 639 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 640 |
+
out1 = inp.gather(-1, idx)
|
| 641 |
+
sign = make_sign([signx1, signy0, signz1])
|
| 642 |
+
if sign is not None:
|
| 643 |
+
out1 *= sign
|
| 644 |
+
outxy.addcmul_(out1, - gz)
|
| 645 |
+
outxz.addcmul_(out1, (1 - gy))
|
| 646 |
+
outyz.addcmul_(out1, - gx)
|
| 647 |
+
# - corner 110
|
| 648 |
+
idx = sub2ind_list([gx1, gy1, gz0], shape)
|
| 649 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 650 |
+
out1 = inp.gather(-1, idx)
|
| 651 |
+
sign = make_sign([signx1, signy1, signz0])
|
| 652 |
+
if sign is not None:
|
| 653 |
+
out1 *= sign
|
| 654 |
+
outxy.addcmul_(out1, (1 - gz))
|
| 655 |
+
outxz.addcmul_(out1, - gy)
|
| 656 |
+
outyz.addcmul_(out1, - gx)
|
| 657 |
+
# - corner 111
|
| 658 |
+
idx = sub2ind_list([gx1, gy1, gz1], shape)
|
| 659 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 660 |
+
out1 = inp.gather(-1, idx)
|
| 661 |
+
sign = make_sign([signx1, signy1, signz1])
|
| 662 |
+
if sign is not None:
|
| 663 |
+
out1 *= sign
|
| 664 |
+
outxy.addcmul_(out1, gz)
|
| 665 |
+
outxz.addcmul_(out1, gy)
|
| 666 |
+
outyz.addcmul_(out1, gx)
|
| 667 |
+
|
| 668 |
+
outyx.copy_(outxy)
|
| 669 |
+
outzx.copy_(outxz)
|
| 670 |
+
outzy.copy_(outyz)
|
| 671 |
+
|
| 672 |
+
if mask is not None:
|
| 673 |
+
out *= mask.unsqueeze(-1).unsqueeze(-1)
|
| 674 |
+
out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim])
|
| 675 |
+
return out
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# ======================================================================
|
| 679 |
+
# 2D
|
| 680 |
+
# ======================================================================
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
@torch.jit.script
|
| 684 |
+
def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 685 |
+
"""
|
| 686 |
+
inp: (B, C, iX, iY) tensor
|
| 687 |
+
g: (B, oX, oY, 2) tensor
|
| 688 |
+
bound: List{2}[Bound] tensor
|
| 689 |
+
extrapolate: ExtrapolateType
|
| 690 |
+
returns: (B, C, oX, oY) tensor
|
| 691 |
+
"""
|
| 692 |
+
dim = 2
|
| 693 |
+
boundx, boundy = bound
|
| 694 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 695 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 696 |
+
gx, gy = g.unbind(-1)
|
| 697 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 698 |
+
channel = inp.shape[1]
|
| 699 |
+
shape = list(inp.shape[-dim:])
|
| 700 |
+
nx, ny = shape
|
| 701 |
+
|
| 702 |
+
# mask of inbounds voxels
|
| 703 |
+
mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)
|
| 704 |
+
|
| 705 |
+
# corners
|
| 706 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 707 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 708 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 709 |
+
|
| 710 |
+
# gather
|
| 711 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 712 |
+
# - corner 00
|
| 713 |
+
idx = sub2ind_list([gx0, gy0], shape)
|
| 714 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 715 |
+
out = inp.gather(-1, idx)
|
| 716 |
+
sign = make_sign([signx0, signy0])
|
| 717 |
+
if sign is not None:
|
| 718 |
+
out = out * sign
|
| 719 |
+
out = out * ((1 - gx) * (1 - gy))
|
| 720 |
+
# - corner 01
|
| 721 |
+
idx = sub2ind_list([gx0, gy1], shape)
|
| 722 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 723 |
+
out1 = inp.gather(-1, idx)
|
| 724 |
+
sign = make_sign([signx0, signy1])
|
| 725 |
+
if sign is not None:
|
| 726 |
+
out1 = out1 * sign
|
| 727 |
+
out1 = out1 * ((1 - gx) * gy)
|
| 728 |
+
out = out + out1
|
| 729 |
+
# - corner 10
|
| 730 |
+
idx = sub2ind_list([gx1, gy0], shape)
|
| 731 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 732 |
+
out1 = inp.gather(-1, idx)
|
| 733 |
+
sign = make_sign([signx1, signy0])
|
| 734 |
+
if sign is not None:
|
| 735 |
+
out1 = out1 * sign
|
| 736 |
+
out1 = out1 * (gx * (1 - gy))
|
| 737 |
+
out = out + out1
|
| 738 |
+
# - corner 11
|
| 739 |
+
idx = sub2ind_list([gx1, gy1], shape)
|
| 740 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 741 |
+
out1 = inp.gather(-1, idx)
|
| 742 |
+
sign = make_sign([signx1, signy1])
|
| 743 |
+
if sign is not None:
|
| 744 |
+
out1 = out1 * sign
|
| 745 |
+
out1 = out1 * (gx * gy)
|
| 746 |
+
out = out + out1
|
| 747 |
+
|
| 748 |
+
if mask is not None:
|
| 749 |
+
out *= mask
|
| 750 |
+
out = out.reshape(list(out.shape[:2]) + oshape)
|
| 751 |
+
return out
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
@torch.jit.script
|
| 755 |
+
def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 756 |
+
extrapolate: int = 1):
|
| 757 |
+
"""
|
| 758 |
+
inp: (B, C, iX, iY) tensor
|
| 759 |
+
g: (B, iX, iY, 2) tensor
|
| 760 |
+
shape: List{2}[int], optional
|
| 761 |
+
bound: List{2}[Bound] tensor
|
| 762 |
+
extrapolate: ExtrapolateType
|
| 763 |
+
returns: (B, C, *shape) tensor
|
| 764 |
+
"""
|
| 765 |
+
dim = 2
|
| 766 |
+
boundx, boundy = bound
|
| 767 |
+
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
|
| 768 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 769 |
+
ishape = list(inp.shape[-dim:])
|
| 770 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 771 |
+
gx, gy = torch.unbind(g, -1)
|
| 772 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 773 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 774 |
+
channel = inp.shape[1]
|
| 775 |
+
|
| 776 |
+
if shape is None:
|
| 777 |
+
shape = ishape
|
| 778 |
+
shape = list(shape)
|
| 779 |
+
nx, ny = shape
|
| 780 |
+
|
| 781 |
+
# mask of inbounds voxels
|
| 782 |
+
mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)
|
| 783 |
+
|
| 784 |
+
# corners
|
| 785 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 786 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 787 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 788 |
+
|
| 789 |
+
# scatter
|
| 790 |
+
out = torch.zeros([batch, channel, nx*ny],
|
| 791 |
+
dtype=inp.dtype, device=inp.device)
|
| 792 |
+
# - corner 00
|
| 793 |
+
idx = sub2ind_list([gx0, gy0], shape)
|
| 794 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 795 |
+
out1 = inp.clone()
|
| 796 |
+
sign = make_sign([signx0, signy0])
|
| 797 |
+
if sign is not None:
|
| 798 |
+
out1 *= sign
|
| 799 |
+
if mask is not None:
|
| 800 |
+
out1 *= mask
|
| 801 |
+
out1 *= (1 - gx) * (1 - gy)
|
| 802 |
+
out.scatter_add_(-1, idx, out1)
|
| 803 |
+
# - corner 01
|
| 804 |
+
idx = sub2ind_list([gx0, gy1], shape)
|
| 805 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 806 |
+
out1 = inp.clone()
|
| 807 |
+
sign = make_sign([signx0, signy1])
|
| 808 |
+
if sign is not None:
|
| 809 |
+
out1 *= sign
|
| 810 |
+
if mask is not None:
|
| 811 |
+
out1 *= mask
|
| 812 |
+
out1 *= (1 - gx) * gy
|
| 813 |
+
out.scatter_add_(-1, idx, out1)
|
| 814 |
+
# - corner 10
|
| 815 |
+
idx = sub2ind_list([gx1, gy0], shape)
|
| 816 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 817 |
+
out1 = inp.clone()
|
| 818 |
+
sign = make_sign([signx1, signy0])
|
| 819 |
+
if sign is not None:
|
| 820 |
+
out1 *= sign
|
| 821 |
+
if mask is not None:
|
| 822 |
+
out1 *= mask
|
| 823 |
+
out1 *= gx * (1 - gy)
|
| 824 |
+
out.scatter_add_(-1, idx, out1)
|
| 825 |
+
# - corner 11
|
| 826 |
+
idx = sub2ind_list([gx1, gy1], shape)
|
| 827 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 828 |
+
out1 = inp.clone()
|
| 829 |
+
sign = make_sign([signx1, signy1])
|
| 830 |
+
if sign is not None:
|
| 831 |
+
out1 *= sign
|
| 832 |
+
if mask is not None:
|
| 833 |
+
out1 *= mask
|
| 834 |
+
out1 *= gx * gy
|
| 835 |
+
out.scatter_add_(-1, idx, out1)
|
| 836 |
+
|
| 837 |
+
out = out.reshape(list(out.shape[:2]) + shape)
|
| 838 |
+
return out
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
@torch.jit.script
|
| 842 |
+
def grad2d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 843 |
+
"""
|
| 844 |
+
inp: (B, C, iX, iY) tensor
|
| 845 |
+
g: (B, oX, oY, 2) tensor
|
| 846 |
+
bound: List{2}[Bound] tensor
|
| 847 |
+
extrapolate: ExtrapolateType
|
| 848 |
+
returns: (B, C, oX, oY, 2) tensor
|
| 849 |
+
"""
|
| 850 |
+
dim = 2
|
| 851 |
+
boundx, boundy = bound
|
| 852 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 853 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 854 |
+
gx, gy = torch.unbind(g, -1)
|
| 855 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 856 |
+
channel = inp.shape[1]
|
| 857 |
+
shape = list(inp.shape[-dim:])
|
| 858 |
+
nx, ny = shape
|
| 859 |
+
|
| 860 |
+
# mask of inbounds voxels
|
| 861 |
+
mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)
|
| 862 |
+
|
| 863 |
+
# corners
|
| 864 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 865 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 866 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 867 |
+
|
| 868 |
+
# gather
|
| 869 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 870 |
+
out = torch.empty([batch, channel] + list(g.shape[-2:]),
|
| 871 |
+
dtype=inp.dtype, device=inp.device)
|
| 872 |
+
outx, outy = out.unbind(-1)
|
| 873 |
+
# - corner 00
|
| 874 |
+
idx = sub2ind_list([gx0, gy0], shape)
|
| 875 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 876 |
+
torch.gather(inp, -1, idx, out=outx)
|
| 877 |
+
outy.copy_(outx)
|
| 878 |
+
sign = make_sign([signx0, signy0])
|
| 879 |
+
if sign is not None:
|
| 880 |
+
out *= sign.unsqueeze(-1)
|
| 881 |
+
outx *= - (1 - gy)
|
| 882 |
+
outy *= - (1 - gx)
|
| 883 |
+
# - corner 01
|
| 884 |
+
idx = sub2ind_list([gx0, gy1], shape)
|
| 885 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 886 |
+
out1 = inp.gather(-1, idx)
|
| 887 |
+
sign = make_sign([signx0, signy1])
|
| 888 |
+
if sign is not None:
|
| 889 |
+
out1 *= sign
|
| 890 |
+
outx.addcmul_(out1, - gy)
|
| 891 |
+
outy.addcmul_(out1, (1 - gx))
|
| 892 |
+
# - corner 10
|
| 893 |
+
idx = sub2ind_list([gx1, gy0], shape)
|
| 894 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 895 |
+
out1 = inp.gather(-1, idx)
|
| 896 |
+
sign = make_sign([signx1, signy0])
|
| 897 |
+
if sign is not None:
|
| 898 |
+
out1 *= sign
|
| 899 |
+
outx.addcmul_(out1, (1 - gy))
|
| 900 |
+
outy.addcmul_(out1, - gx)
|
| 901 |
+
# - corner 11
|
| 902 |
+
idx = sub2ind_list([gx1, gy1], shape)
|
| 903 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 904 |
+
out1 = inp.gather(-1, idx)
|
| 905 |
+
sign = make_sign([signx1, signy1])
|
| 906 |
+
if sign is not None:
|
| 907 |
+
out1 *= sign
|
| 908 |
+
outx.addcmul_(out1, gy)
|
| 909 |
+
outy.addcmul_(out1, gx)
|
| 910 |
+
|
| 911 |
+
if mask is not None:
|
| 912 |
+
out *= mask.unsqueeze(-1)
|
| 913 |
+
out = out.reshape(list(out.shape[:2]) + oshape + [dim])
|
| 914 |
+
return out
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
@torch.jit.script
|
| 918 |
+
def pushgrad2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 919 |
+
extrapolate: int = 1):
|
| 920 |
+
"""
|
| 921 |
+
inp: (B, C, iX, iY, 2) tensor
|
| 922 |
+
g: (B, iX, iY, 2) tensor
|
| 923 |
+
shape: List{2}[int], optional
|
| 924 |
+
bound: List{2}[Bound] tensor
|
| 925 |
+
extrapolate: ExtrapolateType
|
| 926 |
+
returns: (B, C, *shape) tensor
|
| 927 |
+
"""
|
| 928 |
+
dim = 2
|
| 929 |
+
boundx, boundy = bound
|
| 930 |
+
if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
|
| 931 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 932 |
+
ishape = list(inp.shape[-dim-1:-1])
|
| 933 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 934 |
+
gx, gy = g.unbind(-1)
|
| 935 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1, dim])
|
| 936 |
+
batch = max(inp.shape[0], g.shape[0])
|
| 937 |
+
channel = inp.shape[1]
|
| 938 |
+
|
| 939 |
+
if shape is None:
|
| 940 |
+
shape = ishape
|
| 941 |
+
shape = list(shape)
|
| 942 |
+
nx, ny = shape
|
| 943 |
+
|
| 944 |
+
# mask of inbounds voxels
|
| 945 |
+
mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)
|
| 946 |
+
|
| 947 |
+
# corners
|
| 948 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 949 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 950 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 951 |
+
|
| 952 |
+
# scatter
|
| 953 |
+
out = torch.zeros([batch, channel, nx*ny],
|
| 954 |
+
dtype=inp.dtype, device=inp.device)
|
| 955 |
+
# - corner 00
|
| 956 |
+
idx = sub2ind_list([gx0, gy0], shape)
|
| 957 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 958 |
+
out1 = inp.clone()
|
| 959 |
+
sign = make_sign([signx0, signy0])
|
| 960 |
+
if sign is not None:
|
| 961 |
+
out1 *= sign.unsqueeze(-1)
|
| 962 |
+
if mask is not None:
|
| 963 |
+
out1 *= mask.unsqueeze(-1)
|
| 964 |
+
out1x, out1y = out1.unbind(-1)
|
| 965 |
+
out1x *= - (1 - gy)
|
| 966 |
+
out1y *= - (1 - gx)
|
| 967 |
+
out.scatter_add_(-1, idx, out1x + out1y)
|
| 968 |
+
# - corner 01
|
| 969 |
+
idx = sub2ind_list([gx0, gy1], shape)
|
| 970 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 971 |
+
out1 = inp.clone()
|
| 972 |
+
sign = make_sign([signx0, signy1])
|
| 973 |
+
if sign is not None:
|
| 974 |
+
out1 *= sign.unsqueeze(-1)
|
| 975 |
+
if mask is not None:
|
| 976 |
+
out1 *= mask.unsqueeze(-1)
|
| 977 |
+
out1x, out1y = out1.unbind(-1)
|
| 978 |
+
out1x *= - gy
|
| 979 |
+
out1y *= (1 - gx)
|
| 980 |
+
out.scatter_add_(-1, idx, out1x + out1y)
|
| 981 |
+
# - corner 10
|
| 982 |
+
idx = sub2ind_list([gx1, gy0], shape)
|
| 983 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 984 |
+
out1 = inp.clone()
|
| 985 |
+
sign = make_sign([signx1, signy0])
|
| 986 |
+
if sign is not None:
|
| 987 |
+
out1 *= sign.unsqueeze(-1)
|
| 988 |
+
if mask is not None:
|
| 989 |
+
out1 *= mask.unsqueeze(-1)
|
| 990 |
+
out1x, out1y = out1.unbind(-1)
|
| 991 |
+
out1x *= (1 - gy)
|
| 992 |
+
out1y *= - gx
|
| 993 |
+
out.scatter_add_(-1, idx, out1x + out1y)
|
| 994 |
+
# - corner 11
|
| 995 |
+
idx = sub2ind_list([gx1, gy1], shape)
|
| 996 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 997 |
+
out1 = inp.clone()
|
| 998 |
+
sign = make_sign([signx1, signy1])
|
| 999 |
+
if sign is not None:
|
| 1000 |
+
out1 *= sign.unsqueeze(-1)
|
| 1001 |
+
if mask is not None:
|
| 1002 |
+
out1 *= mask.unsqueeze(-1)
|
| 1003 |
+
out1x, out1y = out1.unbind(-1)
|
| 1004 |
+
out1x *= gy
|
| 1005 |
+
out1y *= gx
|
| 1006 |
+
out.scatter_add_(-1, idx, out1x + out1y)
|
| 1007 |
+
|
| 1008 |
+
out = out.reshape(list(out.shape[:2]) + shape)
|
| 1009 |
+
return out
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
@torch.jit.script
|
| 1013 |
+
def hess2d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 1014 |
+
"""
|
| 1015 |
+
inp: (B, C, iX, iY) tensor
|
| 1016 |
+
g: (B, oX, oY, 2) tensor
|
| 1017 |
+
bound: List{2}[Bound] tensor
|
| 1018 |
+
extrapolate: ExtrapolateType
|
| 1019 |
+
returns: (B, C, oX, oY, 2, 2) tensor
|
| 1020 |
+
"""
|
| 1021 |
+
dim = 2
|
| 1022 |
+
boundx, boundy = bound
|
| 1023 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 1024 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 1025 |
+
gx, gy = torch.unbind(g, -1)
|
| 1026 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 1027 |
+
channel = inp.shape[1]
|
| 1028 |
+
shape = list(inp.shape[-dim:])
|
| 1029 |
+
nx, ny = shape
|
| 1030 |
+
|
| 1031 |
+
# mask of inbounds voxels
|
| 1032 |
+
mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)
|
| 1033 |
+
|
| 1034 |
+
# corners
|
| 1035 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 1036 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 1037 |
+
gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
|
| 1038 |
+
|
| 1039 |
+
# gather
|
| 1040 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 1041 |
+
out = torch.empty([batch, channel, g.shape[-2], dim, dim],
|
| 1042 |
+
dtype=inp.dtype, device=inp.device)
|
| 1043 |
+
outx, outy = out.unbind(-1)
|
| 1044 |
+
outxx, outyx = outx.unbind(-1)
|
| 1045 |
+
outxy, outyy = outy.unbind(-1)
|
| 1046 |
+
# - corner 00
|
| 1047 |
+
idx = sub2ind_list([gx0, gy0], shape)
|
| 1048 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1049 |
+
torch.gather(inp, -1, idx, out=outxy)
|
| 1050 |
+
outxx.zero_()
|
| 1051 |
+
outyy.zero_()
|
| 1052 |
+
sign = make_sign([signx0, signy0])
|
| 1053 |
+
if sign is not None:
|
| 1054 |
+
out *= sign.unsqueeze(-1).unsqueeze(-1)
|
| 1055 |
+
outxy *= 1
|
| 1056 |
+
# - corner 01
|
| 1057 |
+
idx = sub2ind_list([gx0, gy1], shape)
|
| 1058 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1059 |
+
out1 = inp.gather(-1, idx)
|
| 1060 |
+
sign = make_sign([signx0, signy1])
|
| 1061 |
+
if sign is not None:
|
| 1062 |
+
out1 *= sign
|
| 1063 |
+
outxy.add_(out1, alpha=-1)
|
| 1064 |
+
# - corner 10
|
| 1065 |
+
idx = sub2ind_list([gx1, gy0], shape)
|
| 1066 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1067 |
+
out1 = inp.gather(-1, idx)
|
| 1068 |
+
sign = make_sign([signx1, signy0])
|
| 1069 |
+
if sign is not None:
|
| 1070 |
+
out1 *= sign
|
| 1071 |
+
outxy.add_(out1, alpha=-1)
|
| 1072 |
+
# - corner 11
|
| 1073 |
+
idx = sub2ind_list([gx1, gy1], shape)
|
| 1074 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1075 |
+
out1 = inp.gather(-1, idx)
|
| 1076 |
+
sign = make_sign([signx1, signy1])
|
| 1077 |
+
if sign is not None:
|
| 1078 |
+
out1 *= sign
|
| 1079 |
+
outxy.add_(out1)
|
| 1080 |
+
|
| 1081 |
+
outyx.copy_(outxy)
|
| 1082 |
+
|
| 1083 |
+
if mask is not None:
|
| 1084 |
+
out *= mask.unsqueeze(-1).unsqueeze(-1)
|
| 1085 |
+
out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim])
|
| 1086 |
+
return out
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
# ======================================================================
|
| 1090 |
+
# 1D
|
| 1091 |
+
# ======================================================================
|
| 1092 |
+
|
| 1093 |
+
|
| 1094 |
+
@torch.jit.script
|
| 1095 |
+
def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 1096 |
+
"""
|
| 1097 |
+
inp: (B, C, iX) tensor
|
| 1098 |
+
g: (B, oX, 1) tensor
|
| 1099 |
+
bound: List{1}[Bound] tensor
|
| 1100 |
+
extrapolate: ExtrapolateType
|
| 1101 |
+
returns: (B, C, oX) tensor
|
| 1102 |
+
"""
|
| 1103 |
+
dim = 1
|
| 1104 |
+
boundx = bound[0]
|
| 1105 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 1106 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 1107 |
+
gx = g.squeeze(-1)
|
| 1108 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 1109 |
+
channel = inp.shape[1]
|
| 1110 |
+
shape = list(inp.shape[-dim:])
|
| 1111 |
+
nx = shape[0]
|
| 1112 |
+
|
| 1113 |
+
# mask of inbounds voxels
|
| 1114 |
+
mask = inbounds_mask_1d(extrapolate, gx, nx)
|
| 1115 |
+
|
| 1116 |
+
# corners
|
| 1117 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 1118 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 1119 |
+
|
| 1120 |
+
# gather
|
| 1121 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 1122 |
+
# - corner 0
|
| 1123 |
+
idx = gx0
|
| 1124 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1125 |
+
out = inp.gather(-1, idx)
|
| 1126 |
+
sign = signx0
|
| 1127 |
+
if sign is not None:
|
| 1128 |
+
out = out * sign
|
| 1129 |
+
out = out * (1 - gx)
|
| 1130 |
+
# - corner 1
|
| 1131 |
+
idx = gx1
|
| 1132 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1133 |
+
out1 = inp.gather(-1, idx)
|
| 1134 |
+
sign = signx1
|
| 1135 |
+
if sign is not None:
|
| 1136 |
+
out1 = out1 * sign
|
| 1137 |
+
out1 = out1 * gx
|
| 1138 |
+
out = out + out1
|
| 1139 |
+
|
| 1140 |
+
if mask is not None:
|
| 1141 |
+
out *= mask
|
| 1142 |
+
out = out.reshape(list(out.shape[:2]) + oshape)
|
| 1143 |
+
return out
|
| 1144 |
+
|
| 1145 |
+
|
| 1146 |
+
@torch.jit.script
|
| 1147 |
+
def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 1148 |
+
extrapolate: int = 1):
|
| 1149 |
+
"""
|
| 1150 |
+
inp: (B, C, iX, iY) tensor
|
| 1151 |
+
g: (B, iX, iY, 2) tensor
|
| 1152 |
+
shape: List{2}[int], optional
|
| 1153 |
+
bound: List{2}[Bound] tensor
|
| 1154 |
+
extrapolate: ExtrapolateType
|
| 1155 |
+
returns: (B, C, *shape) tensor
|
| 1156 |
+
"""
|
| 1157 |
+
dim = 1
|
| 1158 |
+
boundx = bound[0]
|
| 1159 |
+
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
|
| 1160 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 1161 |
+
ishape = list(inp.shape[-dim:])
|
| 1162 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 1163 |
+
gx = g.squeeze(-1)
|
| 1164 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 1165 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 1166 |
+
channel = inp.shape[1]
|
| 1167 |
+
|
| 1168 |
+
if shape is None:
|
| 1169 |
+
shape = ishape
|
| 1170 |
+
shape = list(shape)
|
| 1171 |
+
nx = shape[0]
|
| 1172 |
+
|
| 1173 |
+
# mask of inbounds voxels
|
| 1174 |
+
mask = inbounds_mask_1d(extrapolate, gx, nx)
|
| 1175 |
+
|
| 1176 |
+
# corners
|
| 1177 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 1178 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 1179 |
+
|
| 1180 |
+
# scatter
|
| 1181 |
+
out = torch.zeros([batch, channel, nx],
|
| 1182 |
+
dtype=inp.dtype, device=inp.device)
|
| 1183 |
+
# - corner 0
|
| 1184 |
+
idx = gx0
|
| 1185 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1186 |
+
out1 = inp.clone()
|
| 1187 |
+
sign = signx0
|
| 1188 |
+
if sign is not None:
|
| 1189 |
+
out1 = out1 * sign
|
| 1190 |
+
if mask is not None:
|
| 1191 |
+
out1 = out1 * mask
|
| 1192 |
+
out1 = out1 * (1 - gx)
|
| 1193 |
+
out.scatter_add_(-1, idx, out1)
|
| 1194 |
+
# - corner 1
|
| 1195 |
+
idx = gx1
|
| 1196 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1197 |
+
out1 = inp.clone()
|
| 1198 |
+
sign = signx1
|
| 1199 |
+
if sign is not None:
|
| 1200 |
+
out1 = out1 * sign
|
| 1201 |
+
if mask is not None:
|
| 1202 |
+
out1 = out1 * mask
|
| 1203 |
+
out1 = out1 * gx
|
| 1204 |
+
out.scatter_add_(-1, idx, out1)
|
| 1205 |
+
|
| 1206 |
+
out = out.reshape(list(out.shape[:2]) + shape)
|
| 1207 |
+
return out
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
@torch.jit.script
|
| 1211 |
+
def grad1d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 1212 |
+
"""
|
| 1213 |
+
inp: (B, C, iX) tensor
|
| 1214 |
+
g: (B, oX, 1) tensor
|
| 1215 |
+
bound: List{1}[Bound] tensor
|
| 1216 |
+
extrapolate: ExtrapolateType
|
| 1217 |
+
returns: (B, C, oX, 1) tensor
|
| 1218 |
+
"""
|
| 1219 |
+
dim = 1
|
| 1220 |
+
boundx = bound[0]
|
| 1221 |
+
oshape = list(g.shape[-dim-1:-1])
|
| 1222 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 1223 |
+
gx = g.squeeze(-1)
|
| 1224 |
+
batch = max(inp.shape[0], gx.shape[0])
|
| 1225 |
+
channel = inp.shape[1]
|
| 1226 |
+
shape = list(inp.shape[-dim:])
|
| 1227 |
+
nx = shape[0]
|
| 1228 |
+
|
| 1229 |
+
# mask of inbounds voxels
|
| 1230 |
+
mask = inbounds_mask_1d(extrapolate, gx, nx)
|
| 1231 |
+
|
| 1232 |
+
# corners
|
| 1233 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 1234 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 1235 |
+
|
| 1236 |
+
# gather
|
| 1237 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
|
| 1238 |
+
out = torch.empty([batch, channel] + list(g.shape[-2:]),
|
| 1239 |
+
dtype=inp.dtype, device=inp.device)
|
| 1240 |
+
outx = out.squeeze(-1)
|
| 1241 |
+
# - corner 0
|
| 1242 |
+
idx = gx0
|
| 1243 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1244 |
+
torch.gather(inp, -1, idx, out=outx)
|
| 1245 |
+
sign = signx0
|
| 1246 |
+
if sign is not None:
|
| 1247 |
+
out *= sign.unsqueeze(-1)
|
| 1248 |
+
outx.neg_()
|
| 1249 |
+
# - corner 1
|
| 1250 |
+
idx = gx1
|
| 1251 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1252 |
+
out1 = inp.gather(-1, idx)
|
| 1253 |
+
sign = signx1
|
| 1254 |
+
if sign is not None:
|
| 1255 |
+
out1 *= sign
|
| 1256 |
+
outx.add_(out1)
|
| 1257 |
+
|
| 1258 |
+
if mask is not None:
|
| 1259 |
+
out *= mask.unsqueeze(-1)
|
| 1260 |
+
out = out.reshape(list(out.shape[:2]) + oshape + [dim])
|
| 1261 |
+
return out
|
| 1262 |
+
|
| 1263 |
+
|
| 1264 |
+
@torch.jit.script
|
| 1265 |
+
def pushgrad1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
|
| 1266 |
+
extrapolate: int = 1):
|
| 1267 |
+
"""
|
| 1268 |
+
inp: (B, C, iX, 1) tensor
|
| 1269 |
+
g: (B, iX, 1) tensor
|
| 1270 |
+
shape: List{1}[int], optional
|
| 1271 |
+
bound: List{1}[Bound] tensor
|
| 1272 |
+
extrapolate: ExtrapolateType
|
| 1273 |
+
returns: (B, C, *shape) tensor
|
| 1274 |
+
"""
|
| 1275 |
+
dim = 1
|
| 1276 |
+
boundx = bound[0]
|
| 1277 |
+
if inp.shape[-2] != g.shape[-2]:
|
| 1278 |
+
raise ValueError('Input and grid should have the same spatial shape')
|
| 1279 |
+
ishape = list(inp.shape[-dim-1:-1])
|
| 1280 |
+
g = g.reshape([g.shape[0], 1, -1, dim])
|
| 1281 |
+
gx = g.squeeze(-1)
|
| 1282 |
+
inp = inp.reshape(list(inp.shape[:2]) + [-1, dim])
|
| 1283 |
+
batch = max(inp.shape[0], g.shape[0])
|
| 1284 |
+
channel = inp.shape[1]
|
| 1285 |
+
|
| 1286 |
+
if shape is None:
|
| 1287 |
+
shape = ishape
|
| 1288 |
+
shape = list(shape)
|
| 1289 |
+
nx = shape[0]
|
| 1290 |
+
|
| 1291 |
+
# mask of inbounds voxels
|
| 1292 |
+
mask = inbounds_mask_1d(extrapolate, gx, nx)
|
| 1293 |
+
|
| 1294 |
+
# corners
|
| 1295 |
+
# (upper weight, lower corner, upper corner, lower sign, upper sign)
|
| 1296 |
+
gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
|
| 1297 |
+
|
| 1298 |
+
# scatter
|
| 1299 |
+
out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device)
|
| 1300 |
+
# - corner 000
|
| 1301 |
+
idx = gx0
|
| 1302 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1303 |
+
out1 = inp.clone()
|
| 1304 |
+
sign = signx0
|
| 1305 |
+
if sign is not None:
|
| 1306 |
+
out1 *= sign.unsqueeze(-1)
|
| 1307 |
+
if mask is not None:
|
| 1308 |
+
out1 *= mask.unsqueeze(-1)
|
| 1309 |
+
out1x = out1.squeeze(-1)
|
| 1310 |
+
out1x.neg_()
|
| 1311 |
+
out.scatter_add_(-1, idx, out1x)
|
| 1312 |
+
# - corner 100
|
| 1313 |
+
idx = gx1
|
| 1314 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 1315 |
+
out1 = inp.clone()
|
| 1316 |
+
sign = signx1
|
| 1317 |
+
if sign is not None:
|
| 1318 |
+
out1 *= sign.unsqueeze(-1)
|
| 1319 |
+
if mask is not None:
|
| 1320 |
+
out1 *= mask.unsqueeze(-1)
|
| 1321 |
+
out1x = out1.squeeze(-1)
|
| 1322 |
+
out.scatter_add_(-1, idx, out1x)
|
| 1323 |
+
|
| 1324 |
+
out = out.reshape(list(out.shape[:2]) + shape)
|
| 1325 |
+
return out
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
@torch.jit.script
|
| 1329 |
+
def hess1d(inp, g, bound: List[Bound], extrapolate: int = 1):
|
| 1330 |
+
"""
|
| 1331 |
+
inp: (B, C, iX) tensor
|
| 1332 |
+
g: (B, oX, 1) tensor
|
| 1333 |
+
bound: List{1}[Bound] tensor
|
| 1334 |
+
extrapolate: ExtrapolateType
|
| 1335 |
+
returns: (B, C, oX, 1, 1) tensor
|
| 1336 |
+
"""
|
| 1337 |
+
batch = max(inp.shape[0], g.shape[0])
|
| 1338 |
+
return torch.zeros([batch, inp.shape[1], g.shape[1], 1, 1],
|
| 1339 |
+
dtype=inp.dtype, device=inp.device)
|
Generator/interpol/jit_utils.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A lot of utility functions for TorchScript"""
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Tuple, Optional
|
| 5 |
+
from .utils import torch_version
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@torch.jit.script
|
| 10 |
+
def pad_list_int(x: List[int], dim: int) -> List[int]:
|
| 11 |
+
if len(x) < dim:
|
| 12 |
+
x = x + x[-1:] * (dim - len(x))
|
| 13 |
+
if len(x) > dim:
|
| 14 |
+
x = x[:dim]
|
| 15 |
+
return x
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@torch.jit.script
|
| 19 |
+
def pad_list_float(x: List[float], dim: int) -> List[float]:
|
| 20 |
+
if len(x) < dim:
|
| 21 |
+
x = x + x[-1:] * (dim - len(x))
|
| 22 |
+
if len(x) > dim:
|
| 23 |
+
x = x[:dim]
|
| 24 |
+
return x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@torch.jit.script
|
| 28 |
+
def pad_list_str(x: List[str], dim: int) -> List[str]:
|
| 29 |
+
if len(x) < dim:
|
| 30 |
+
x = x + x[-1:] * (dim - len(x))
|
| 31 |
+
if len(x) > dim:
|
| 32 |
+
x = x[:dim]
|
| 33 |
+
return x
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@torch.jit.script
|
| 37 |
+
def list_any(x: List[bool]) -> bool:
|
| 38 |
+
for elem in x:
|
| 39 |
+
if elem:
|
| 40 |
+
return True
|
| 41 |
+
return False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@torch.jit.script
|
| 45 |
+
def list_all(x: List[bool]) -> bool:
|
| 46 |
+
for elem in x:
|
| 47 |
+
if not elem:
|
| 48 |
+
return False
|
| 49 |
+
return True
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@torch.jit.script
|
| 53 |
+
def list_prod_int(x: List[int]) -> int:
|
| 54 |
+
if len(x) == 0:
|
| 55 |
+
return 1
|
| 56 |
+
x0 = x[0]
|
| 57 |
+
for x1 in x[1:]:
|
| 58 |
+
x0 = x0 * x1
|
| 59 |
+
return x0
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@torch.jit.script
|
| 63 |
+
def list_sum_int(x: List[int]) -> int:
|
| 64 |
+
if len(x) == 0:
|
| 65 |
+
return 1
|
| 66 |
+
x0 = x[0]
|
| 67 |
+
for x1 in x[1:]:
|
| 68 |
+
x0 = x0 + x1
|
| 69 |
+
return x0
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@torch.jit.script
|
| 73 |
+
def list_prod_tensor(x: List[Tensor]) -> Tensor:
|
| 74 |
+
if len(x) == 0:
|
| 75 |
+
empty: List[int] = []
|
| 76 |
+
return torch.ones(empty)
|
| 77 |
+
x0 = x[0]
|
| 78 |
+
for x1 in x[1:]:
|
| 79 |
+
x0 = x0 * x1
|
| 80 |
+
return x0
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@torch.jit.script
|
| 84 |
+
def list_sum_tensor(x: List[Tensor]) -> Tensor:
|
| 85 |
+
if len(x) == 0:
|
| 86 |
+
empty: List[int] = []
|
| 87 |
+
return torch.ones(empty)
|
| 88 |
+
x0 = x[0]
|
| 89 |
+
for x1 in x[1:]:
|
| 90 |
+
x0 = x0 + x1
|
| 91 |
+
return x0
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@torch.jit.script
|
| 95 |
+
def list_reverse_int(x: List[int]) -> List[int]:
|
| 96 |
+
if len(x) == 0:
|
| 97 |
+
return x
|
| 98 |
+
return [x[i] for i in range(-1, -len(x)-1, -1)]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@torch.jit.script
|
| 102 |
+
def list_cumprod_int(x: List[int], reverse: bool = False,
|
| 103 |
+
exclusive: bool = False) -> List[int]:
|
| 104 |
+
if len(x) == 0:
|
| 105 |
+
lx: List[int] = []
|
| 106 |
+
return lx
|
| 107 |
+
if reverse:
|
| 108 |
+
x = list_reverse_int(x)
|
| 109 |
+
|
| 110 |
+
x0 = 1 if exclusive else x[0]
|
| 111 |
+
lx = [x0]
|
| 112 |
+
all_x = x[:-1] if exclusive else x[1:]
|
| 113 |
+
for x1 in all_x:
|
| 114 |
+
x0 = x0 * x1
|
| 115 |
+
lx.append(x0)
|
| 116 |
+
if reverse:
|
| 117 |
+
lx = list_reverse_int(lx)
|
| 118 |
+
return lx
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@torch.jit.script
|
| 122 |
+
def movedim1(x, source: int, destination: int):
|
| 123 |
+
dim = x.dim()
|
| 124 |
+
source = dim + source if source < 0 else source
|
| 125 |
+
destination = dim + destination if destination < 0 else destination
|
| 126 |
+
permutation = [d for d in range(dim)]
|
| 127 |
+
permutation = permutation[:source] + permutation[source+1:]
|
| 128 |
+
permutation = permutation[:destination] + [source] + permutation[destination:]
|
| 129 |
+
return x.permute(permutation)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@torch.jit.script
|
| 133 |
+
def sub2ind(subs, shape: List[int]):
|
| 134 |
+
"""Convert sub indices (i, j, k) into linear indices.
|
| 135 |
+
|
| 136 |
+
The rightmost dimension is the most rapidly changing one
|
| 137 |
+
-> if shape == [D, H, W], the strides are therefore [H*W, W, 1]
|
| 138 |
+
|
| 139 |
+
Parameters
|
| 140 |
+
----------
|
| 141 |
+
subs : (D, ...) tensor
|
| 142 |
+
List of sub-indices. The first dimension is the number of dimension.
|
| 143 |
+
Each element should have the same number of elements and shape.
|
| 144 |
+
shape : (D,) list[int]
|
| 145 |
+
Size of each dimension. Its length should be the same as the
|
| 146 |
+
first dimension of ``subs``.
|
| 147 |
+
|
| 148 |
+
Returns
|
| 149 |
+
-------
|
| 150 |
+
ind : (...) tensor
|
| 151 |
+
Linear indices
|
| 152 |
+
"""
|
| 153 |
+
subs = subs.unbind(0)
|
| 154 |
+
ind = subs[-1]
|
| 155 |
+
subs = subs[:-1]
|
| 156 |
+
ind = ind.clone()
|
| 157 |
+
stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False)
|
| 158 |
+
for i, s in zip(subs, stride):
|
| 159 |
+
ind += i * s
|
| 160 |
+
return ind
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@torch.jit.script
|
| 164 |
+
def sub2ind_list(subs: List[Tensor], shape: List[int]):
|
| 165 |
+
"""Convert sub indices (i, j, k) into linear indices.
|
| 166 |
+
|
| 167 |
+
The rightmost dimension is the most rapidly changing one
|
| 168 |
+
-> if shape == [D, H, W], the strides are therefore [H*W, W, 1]
|
| 169 |
+
|
| 170 |
+
Parameters
|
| 171 |
+
----------
|
| 172 |
+
subs : (D,) list[tensor]
|
| 173 |
+
List of sub-indices. The first dimension is the number of dimension.
|
| 174 |
+
Each element should have the same number of elements and shape.
|
| 175 |
+
shape : (D,) list[int]
|
| 176 |
+
Size of each dimension. Its length should be the same as the
|
| 177 |
+
first dimension of ``subs``.
|
| 178 |
+
|
| 179 |
+
Returns
|
| 180 |
+
-------
|
| 181 |
+
ind : (...) tensor
|
| 182 |
+
Linear indices
|
| 183 |
+
"""
|
| 184 |
+
ind = subs[-1]
|
| 185 |
+
subs = subs[:-1]
|
| 186 |
+
ind = ind.clone()
|
| 187 |
+
stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False)
|
| 188 |
+
for i, s in zip(subs, stride):
|
| 189 |
+
ind += i * s
|
| 190 |
+
return ind
|
| 191 |
+
|
| 192 |
+
# floor_divide returns wrong results for negative values, because it truncates
|
| 193 |
+
# instead of performing a proper floor. In recent version of pytorch, it is
|
| 194 |
+
# advised to use div(..., rounding_mode='trunc'|'floor') instead.
|
| 195 |
+
# Here, we only use floor_divide on positive values so we do not care.
|
| 196 |
+
if torch_version('>=', [1, 8]):
|
| 197 |
+
@torch.jit.script
|
| 198 |
+
def floor_div(x, y) -> torch.Tensor:
|
| 199 |
+
return torch.div(x, y, rounding_mode='floor')
|
| 200 |
+
@torch.jit.script
|
| 201 |
+
def floor_div_int(x, y: int) -> torch.Tensor:
|
| 202 |
+
return torch.div(x, y, rounding_mode='floor')
|
| 203 |
+
else:
|
| 204 |
+
@torch.jit.script
|
| 205 |
+
def floor_div(x, y) -> torch.Tensor:
|
| 206 |
+
return (x / y).floor_()
|
| 207 |
+
@torch.jit.script
|
| 208 |
+
def floor_div_int(x, y: int) -> torch.Tensor:
|
| 209 |
+
return (x / y).floor_()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@torch.jit.script
|
| 213 |
+
def ind2sub(ind, shape: List[int]):
|
| 214 |
+
"""Convert linear indices into sub indices (i, j, k).
|
| 215 |
+
|
| 216 |
+
The rightmost dimension is the most rapidly changing one
|
| 217 |
+
-> if shape == [D, H, W], the strides are therefore [H*W, W, 1]
|
| 218 |
+
|
| 219 |
+
Parameters
|
| 220 |
+
----------
|
| 221 |
+
ind : tensor_like
|
| 222 |
+
Linear indices
|
| 223 |
+
shape : (D,) vector_like
|
| 224 |
+
Size of each dimension.
|
| 225 |
+
|
| 226 |
+
Returns
|
| 227 |
+
-------
|
| 228 |
+
subs : (D, ...) tensor
|
| 229 |
+
Sub-indices.
|
| 230 |
+
"""
|
| 231 |
+
stride = list_cumprod_int(shape, reverse=True, exclusive=True)
|
| 232 |
+
sub = ind.new_empty([len(shape)] + ind.shape)
|
| 233 |
+
sub.copy_(ind)
|
| 234 |
+
for d in range(len(shape)):
|
| 235 |
+
if d > 0:
|
| 236 |
+
sub[d] = torch.remainder(sub[d], stride[d-1])
|
| 237 |
+
sub[d] = floor_div_int(sub[d], stride[d])
|
| 238 |
+
return sub
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
@torch.jit.script
|
| 242 |
+
def inbounds_mask_3d(extrapolate: int, gx, gy, gz, nx: int, ny: int, nz: int) \
|
| 243 |
+
-> Optional[Tensor]:
|
| 244 |
+
# mask of inbounds voxels
|
| 245 |
+
mask: Optional[Tensor] = None
|
| 246 |
+
if extrapolate in (0, 2): # no / hist
|
| 247 |
+
tiny = 5e-2
|
| 248 |
+
threshold = tiny
|
| 249 |
+
if extrapolate == 2:
|
| 250 |
+
threshold = 0.5 + tiny
|
| 251 |
+
mask = ((gx > -threshold) & (gx < nx - 1 + threshold) &
|
| 252 |
+
(gy > -threshold) & (gy < ny - 1 + threshold) &
|
| 253 |
+
(gz > -threshold) & (gz < nz - 1 + threshold))
|
| 254 |
+
return mask
|
| 255 |
+
return mask
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@torch.jit.script
|
| 259 |
+
def inbounds_mask_2d(extrapolate: int, gx, gy, nx: int, ny: int) \
|
| 260 |
+
-> Optional[Tensor]:
|
| 261 |
+
# mask of inbounds voxels
|
| 262 |
+
mask: Optional[Tensor] = None
|
| 263 |
+
if extrapolate in (0, 2): # no / hist
|
| 264 |
+
tiny = 5e-2
|
| 265 |
+
threshold = tiny
|
| 266 |
+
if extrapolate == 2:
|
| 267 |
+
threshold = 0.5 + tiny
|
| 268 |
+
mask = ((gx > -threshold) & (gx < nx - 1 + threshold) &
|
| 269 |
+
(gy > -threshold) & (gy < ny - 1 + threshold))
|
| 270 |
+
return mask
|
| 271 |
+
return mask
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@torch.jit.script
|
| 275 |
+
def inbounds_mask_1d(extrapolate: int, gx, nx: int) -> Optional[Tensor]:
|
| 276 |
+
# mask of inbounds voxels
|
| 277 |
+
mask: Optional[Tensor] = None
|
| 278 |
+
if extrapolate in (0, 2): # no / hist
|
| 279 |
+
tiny = 5e-2
|
| 280 |
+
threshold = tiny
|
| 281 |
+
if extrapolate == 2:
|
| 282 |
+
threshold = 0.5 + tiny
|
| 283 |
+
mask = (gx > -threshold) & (gx < nx - 1 + threshold)
|
| 284 |
+
return mask
|
| 285 |
+
return mask
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@torch.jit.script
|
| 289 |
+
def make_sign(sign: List[Optional[Tensor]]) -> Optional[Tensor]:
|
| 290 |
+
is_none : List[bool] = [s is None for s in sign]
|
| 291 |
+
if list_all(is_none):
|
| 292 |
+
return None
|
| 293 |
+
filt_sign: List[Tensor] = []
|
| 294 |
+
for s in sign:
|
| 295 |
+
if s is not None:
|
| 296 |
+
filt_sign.append(s)
|
| 297 |
+
return list_prod_tensor(filt_sign)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@torch.jit.script
|
| 301 |
+
def square(x):
|
| 302 |
+
return x * x
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
@torch.jit.script
|
| 306 |
+
def square_(x):
|
| 307 |
+
return x.mul_(x)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@torch.jit.script
|
| 311 |
+
def cube(x):
|
| 312 |
+
return x * x * x
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@torch.jit.script
|
| 316 |
+
def cube_(x):
|
| 317 |
+
return square_(x).mul_(x)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@torch.jit.script
|
| 321 |
+
def pow4(x):
|
| 322 |
+
return square(square(x))
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
@torch.jit.script
|
| 326 |
+
def pow4_(x):
|
| 327 |
+
return square_(square_(x))
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
@torch.jit.script
|
| 331 |
+
def pow5(x):
|
| 332 |
+
return x * pow4(x)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
@torch.jit.script
|
| 336 |
+
def pow5_(x):
|
| 337 |
+
return pow4_(x).mul_(x)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@torch.jit.script
|
| 341 |
+
def pow6(x):
|
| 342 |
+
return square(cube(x))
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
@torch.jit.script
|
| 346 |
+
def pow6_(x):
|
| 347 |
+
return square_(cube_(x))
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
@torch.jit.script
|
| 351 |
+
def pow7(x):
|
| 352 |
+
return pow6(x) * x
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@torch.jit.script
|
| 356 |
+
def pow7_(x):
|
| 357 |
+
return pow6_(x).mul_(x)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
@torch.jit.script
|
| 361 |
+
def dot(x, y, dim: int = -1, keepdim: bool = False):
|
| 362 |
+
"""(Batched) dot product along a dimension"""
|
| 363 |
+
x = movedim1(x, dim, -1).unsqueeze(-2)
|
| 364 |
+
y = movedim1(y, dim, -1).unsqueeze(-1)
|
| 365 |
+
d = torch.matmul(x, y).squeeze(-1).squeeze(-1)
|
| 366 |
+
if keepdim:
|
| 367 |
+
d.unsqueeze(dim)
|
| 368 |
+
return d
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@torch.jit.script
|
| 372 |
+
def dot_multi(x, y, dim: List[int], keepdim: bool = False):
|
| 373 |
+
"""(Batched) dot product along a dimension"""
|
| 374 |
+
for d in dim:
|
| 375 |
+
x = movedim1(x, d, -1)
|
| 376 |
+
y = movedim1(y, d, -1)
|
| 377 |
+
x = x.reshape(x.shape[:-len(dim)] + [1, -1])
|
| 378 |
+
y = y.reshape(x.shape[:-len(dim)] + [-1, 1])
|
| 379 |
+
dt = torch.matmul(x, y).squeeze(-1).squeeze(-1)
|
| 380 |
+
if keepdim:
|
| 381 |
+
for d in dim:
|
| 382 |
+
dt.unsqueeze(d)
|
| 383 |
+
return dt
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# cartesian_prod takes multiple inout tensors as input in eager mode
|
| 388 |
+
# but takes a list of tensor in jit mode. This is a helper that works
|
| 389 |
+
# in both cases.
|
| 390 |
+
if not int(os.environ.get('PYTORCH_JIT', '1')):
|
| 391 |
+
cartesian_prod = lambda x: torch.cartesian_prod(*x)
|
| 392 |
+
if torch_version('>=', (1, 10)):
|
| 393 |
+
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 394 |
+
return torch.meshgrid(*x, indexing='ij')
|
| 395 |
+
def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 396 |
+
return torch.meshgrid(*x, indexing='xy')
|
| 397 |
+
else:
|
| 398 |
+
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 399 |
+
return torch.meshgrid(*x)
|
| 400 |
+
def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 401 |
+
grid = torch.meshgrid(*x)
|
| 402 |
+
if len(grid) > 1:
|
| 403 |
+
grid[0] = grid[0].transpose(0, 1)
|
| 404 |
+
grid[1] = grid[1].transpose(0, 1)
|
| 405 |
+
return grid
|
| 406 |
+
|
| 407 |
+
else:
|
| 408 |
+
cartesian_prod = torch.cartesian_prod
|
| 409 |
+
if torch_version('>=', (1, 10)):
|
| 410 |
+
@torch.jit.script
|
| 411 |
+
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 412 |
+
return torch.meshgrid(x, indexing='ij')
|
| 413 |
+
@torch.jit.script
|
| 414 |
+
def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 415 |
+
return torch.meshgrid(x, indexing='xy')
|
| 416 |
+
else:
|
| 417 |
+
@torch.jit.script
|
| 418 |
+
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 419 |
+
return torch.meshgrid(x)
|
| 420 |
+
@torch.jit.script
|
| 421 |
+
def meshgrid_xyt(x: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 422 |
+
grid = torch.meshgrid(x)
|
| 423 |
+
if len(grid) > 1:
|
| 424 |
+
grid[0] = grid[0].transpose(0, 1)
|
| 425 |
+
grid[1] = grid[1].transpose(0, 1)
|
| 426 |
+
return grid
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
meshgrid = meshgrid_ij
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
# In torch < 1.6, div applied to integer tensor performed a floor_divide
|
| 433 |
+
# In torch > 1.6, it performs a true divide.
|
| 434 |
+
# Floor division must be done using `floor_divide`, but it was buggy
|
| 435 |
+
# until torch 1.13 (it was doing a trunc divide instead of a floor divide).
|
| 436 |
+
# There was at some point a deprecation warning for floor_divide, but it
|
| 437 |
+
# seems to have been lifted afterwards. In torch >= 1.13, floor_divide
|
| 438 |
+
# performs a correct floor division.
|
| 439 |
+
# Since we only apply floor_divide ot positive values, we are fine.
|
| 440 |
+
if torch_version('<', (1, 6)):
|
| 441 |
+
floor_div = torch.div
|
| 442 |
+
else:
|
| 443 |
+
floor_div = torch.floor_divide
|
Generator/interpol/jitfields.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
import jitfields
|
| 3 |
+
available = True
|
| 4 |
+
except (ImportError, ModuleNotFoundError):
|
| 5 |
+
jitfields = None
|
| 6 |
+
available = False
|
| 7 |
+
from .utils import make_list
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def first2last(input, ndim):
|
| 12 |
+
insert = input.dim() <= ndim
|
| 13 |
+
if insert:
|
| 14 |
+
input = input.unsqueeze(-1)
|
| 15 |
+
else:
|
| 16 |
+
input = torch.movedim(input, -ndim-1, -1)
|
| 17 |
+
return input, insert
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def last2first(input, ndim, inserted, grad=False):
|
| 21 |
+
if inserted:
|
| 22 |
+
input = input.squeeze(-1 - grad)
|
| 23 |
+
else:
|
| 24 |
+
input = torch.movedim(input, -1 - grad, -ndim-1 - grad)
|
| 25 |
+
return input
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def grid_pull(input, grid, interpolation='linear', bound='zero',
|
| 29 |
+
extrapolate=False, prefilter=False):
|
| 30 |
+
ndim = grid.shape[-1]
|
| 31 |
+
input, inserted = first2last(input, ndim)
|
| 32 |
+
input = jitfields.pull(input, grid, order=interpolation, bound=bound,
|
| 33 |
+
extrapolate=extrapolate, prefilter=prefilter)
|
| 34 |
+
input = last2first(input, ndim, inserted)
|
| 35 |
+
return input
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
|
| 39 |
+
extrapolate=False, prefilter=False):
|
| 40 |
+
ndim = grid.shape[-1]
|
| 41 |
+
input, inserted = first2last(input, ndim)
|
| 42 |
+
input = jitfields.push(input, grid, shape, order=interpolation, bound=bound,
|
| 43 |
+
extrapolate=extrapolate, prefilter=prefilter)
|
| 44 |
+
input = last2first(input, ndim, inserted)
|
| 45 |
+
return input
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def grid_count(grid, shape=None, interpolation='linear', bound='zero',
|
| 49 |
+
extrapolate=False):
|
| 50 |
+
return jitfields.count(grid, shape, order=interpolation, bound=bound,
|
| 51 |
+
extrapolate=extrapolate)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def grid_grad(input, grid, interpolation='linear', bound='zero',
|
| 55 |
+
extrapolate=False, prefilter=False):
|
| 56 |
+
ndim = grid.shape[-1]
|
| 57 |
+
input, inserted = first2last(input, ndim)
|
| 58 |
+
input = jitfields.grad(input, grid, order=interpolation, bound=bound,
|
| 59 |
+
extrapolate=extrapolate, prefilter=prefilter)
|
| 60 |
+
input = last2first(input, ndim, inserted, True)
|
| 61 |
+
return input
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
|
| 65 |
+
inplace=False):
|
| 66 |
+
func = jitfields.spline_coeff_ if inplace else jitfields.spline_coeff
|
| 67 |
+
return func(input, interpolation, bound=bound, dim=dim)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
|
| 71 |
+
inplace=False):
|
| 72 |
+
func = jitfields.spline_coeff_nd_ if inplace else jitfields.spline_coeff_nd
|
| 73 |
+
return func(input, interpolation, bound=bound, ndim=dim)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def resize(image, factor=None, shape=None, anchor='c',
|
| 77 |
+
interpolation=1, prefilter=True, **kwargs):
|
| 78 |
+
kwargs.setdefault('bound', 'nearest')
|
| 79 |
+
ndim = max(len(make_list(factor or [])),
|
| 80 |
+
len(make_list(shape or [])),
|
| 81 |
+
len(make_list(anchor or []))) or (image.dim() - 2)
|
| 82 |
+
return jitfields.resize(image, factor=factor, shape=shape, ndim=ndim,
|
| 83 |
+
anchor=anchor, order=interpolation,
|
| 84 |
+
bound=kwargs['bound'], prefilter=prefilter)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def restrict(image, factor=None, shape=None, anchor='c',
|
| 88 |
+
interpolation=1, reduce_sum=False, **kwargs):
|
| 89 |
+
kwargs.setdefault('bound', 'nearest')
|
| 90 |
+
ndim = max(len(make_list(factor or [])),
|
| 91 |
+
len(make_list(shape or [])),
|
| 92 |
+
len(make_list(anchor or []))) or (image.dim() - 2)
|
| 93 |
+
return jitfields.restrict(image, factor=factor, shape=shape, ndim=ndim,
|
| 94 |
+
anchor=anchor, order=interpolation,
|
| 95 |
+
bound=kwargs['bound'], reduce_sum=reduce_sum)
|
Generator/interpol/nd.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generic N-dimensional version: any combination of spline orders"""
|
| 2 |
+
import torch
|
| 3 |
+
from typing import List, Optional, Tuple
|
| 4 |
+
from .bounds import Bound
|
| 5 |
+
from .splines import Spline
|
| 6 |
+
from .jit_utils import sub2ind_list, make_sign, list_prod_int, cartesian_prod
|
| 7 |
+
Tensor = torch.Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.jit.script
|
| 11 |
+
def inbounds_mask(extrapolate: int, grid, shape: List[int])\
|
| 12 |
+
-> Optional[Tensor]:
|
| 13 |
+
# mask of inbounds voxels
|
| 14 |
+
mask: Optional[Tensor] = None
|
| 15 |
+
if extrapolate in (0, 2): # no / hist
|
| 16 |
+
grid = grid.unsqueeze(1)
|
| 17 |
+
tiny = 5e-2
|
| 18 |
+
threshold = tiny
|
| 19 |
+
if extrapolate == 2:
|
| 20 |
+
threshold = 0.5 + tiny
|
| 21 |
+
mask = torch.ones(grid.shape[:-1],
|
| 22 |
+
dtype=torch.bool, device=grid.device)
|
| 23 |
+
for grid1, shape1 in zip(grid.unbind(-1), shape):
|
| 24 |
+
mask = mask & (grid1 > -threshold)
|
| 25 |
+
mask = mask & (grid1 < shape1 - 1 + threshold)
|
| 26 |
+
return mask
|
| 27 |
+
return mask
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@torch.jit.script
|
| 31 |
+
def get_weights(grid, bound: List[Bound], spline: List[Spline],
|
| 32 |
+
shape: List[int], grad: bool = False, hess: bool = False) \
|
| 33 |
+
-> Tuple[List[List[Tensor]],
|
| 34 |
+
List[List[Optional[Tensor]]],
|
| 35 |
+
List[List[Optional[Tensor]]],
|
| 36 |
+
List[List[Tensor]],
|
| 37 |
+
List[List[Optional[Tensor]]]]:
|
| 38 |
+
|
| 39 |
+
weights: List[List[Tensor]] = []
|
| 40 |
+
grads: List[List[Optional[Tensor]]] = []
|
| 41 |
+
hesss: List[List[Optional[Tensor]]] = []
|
| 42 |
+
coords: List[List[Tensor]] = []
|
| 43 |
+
signs: List[List[Optional[Tensor]]] = []
|
| 44 |
+
for g, b, s, n in zip(grid.unbind(-1), bound, spline, shape):
|
| 45 |
+
grid0 = (g - (s.order-1)/2).floor()
|
| 46 |
+
dist0 = g - grid0
|
| 47 |
+
grid0 = grid0.long()
|
| 48 |
+
nb_nodes = s.order + 1
|
| 49 |
+
subweights: List[Tensor] = []
|
| 50 |
+
subcoords: List[Tensor] = []
|
| 51 |
+
subgrads: List[Optional[Tensor]] = []
|
| 52 |
+
subhesss: List[Optional[Tensor]] = []
|
| 53 |
+
subsigns: List[Optional[Tensor]] = []
|
| 54 |
+
for node in range(nb_nodes):
|
| 55 |
+
grid1 = grid0 + node
|
| 56 |
+
sign1: Optional[Tensor] = b.transform(grid1, n)
|
| 57 |
+
subsigns.append(sign1)
|
| 58 |
+
grid1 = b.index(grid1, n)
|
| 59 |
+
subcoords.append(grid1)
|
| 60 |
+
dist1 = dist0 - node
|
| 61 |
+
weight1 = s.fastweight(dist1)
|
| 62 |
+
subweights.append(weight1)
|
| 63 |
+
grad1: Optional[Tensor] = None
|
| 64 |
+
if grad:
|
| 65 |
+
grad1 = s.fastgrad(dist1)
|
| 66 |
+
subgrads.append(grad1)
|
| 67 |
+
hess1: Optional[Tensor] = None
|
| 68 |
+
if hess:
|
| 69 |
+
hess1 = s.fasthess(dist1)
|
| 70 |
+
subhesss.append(hess1)
|
| 71 |
+
weights.append(subweights)
|
| 72 |
+
coords.append(subcoords)
|
| 73 |
+
signs.append(subsigns)
|
| 74 |
+
grads.append(subgrads)
|
| 75 |
+
hesss.append(subhesss)
|
| 76 |
+
|
| 77 |
+
return weights, grads, hesss, coords, signs
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@torch.jit.script
|
| 81 |
+
def pull(inp, grid, bound: List[Bound], spline: List[Spline],
|
| 82 |
+
extrapolate: int = 1):
|
| 83 |
+
"""
|
| 84 |
+
inp: (B, C, *ishape) tensor
|
| 85 |
+
g: (B, *oshape, D) tensor
|
| 86 |
+
bound: List{D}[Bound] tensor
|
| 87 |
+
spline: List{D}[Spline] tensor
|
| 88 |
+
extrapolate: int
|
| 89 |
+
returns: (B, C, *oshape) tensor
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
dim = grid.shape[-1]
|
| 93 |
+
shape = list(inp.shape[-dim:])
|
| 94 |
+
oshape = list(grid.shape[-dim-1:-1])
|
| 95 |
+
batch = max(inp.shape[0], grid.shape[0])
|
| 96 |
+
channel = inp.shape[1]
|
| 97 |
+
|
| 98 |
+
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
|
| 99 |
+
inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
|
| 100 |
+
mask = inbounds_mask(extrapolate, grid, shape)
|
| 101 |
+
|
| 102 |
+
# precompute weights along each dimension
|
| 103 |
+
weights, _, _, coords, signs = get_weights(grid, bound, spline, shape, False, False)
|
| 104 |
+
|
| 105 |
+
# initialize
|
| 106 |
+
out = torch.zeros([batch, channel, grid.shape[1]],
|
| 107 |
+
dtype=inp.dtype, device=inp.device)
|
| 108 |
+
|
| 109 |
+
# iterate across nodes/corners
|
| 110 |
+
range_nodes = [torch.as_tensor([d for d in range(n)])
|
| 111 |
+
for n in [s.order + 1 for s in spline]]
|
| 112 |
+
if dim == 1:
|
| 113 |
+
# cartesian_prod does not work as expected when only one
|
| 114 |
+
# element is provided
|
| 115 |
+
all_nodes = range_nodes[0].unsqueeze(-1)
|
| 116 |
+
else:
|
| 117 |
+
all_nodes = cartesian_prod(range_nodes)
|
| 118 |
+
for nodes in all_nodes:
|
| 119 |
+
# gather
|
| 120 |
+
idx = [c[n] for c, n in zip(coords, nodes)]
|
| 121 |
+
idx = sub2ind_list(idx, shape).unsqueeze(1)
|
| 122 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 123 |
+
out1 = inp.gather(-1, idx)
|
| 124 |
+
|
| 125 |
+
# apply sign
|
| 126 |
+
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
|
| 127 |
+
sign1: Optional[Tensor] = make_sign(sign0)
|
| 128 |
+
if sign1 is not None:
|
| 129 |
+
out1 = out1 * sign1.unsqueeze(1)
|
| 130 |
+
|
| 131 |
+
# apply weights
|
| 132 |
+
for weight, n in zip(weights, nodes):
|
| 133 |
+
out1 = out1 * weight[n].unsqueeze(1)
|
| 134 |
+
|
| 135 |
+
# accumulate
|
| 136 |
+
out = out + out1
|
| 137 |
+
|
| 138 |
+
# out-of-bounds mask
|
| 139 |
+
if mask is not None:
|
| 140 |
+
out = out * mask
|
| 141 |
+
|
| 142 |
+
out = out.reshape(list(out.shape[:2]) + oshape)
|
| 143 |
+
return out
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@torch.jit.script
|
| 147 |
+
def push(inp, grid, shape: Optional[List[int]], bound: List[Bound],
|
| 148 |
+
spline: List[Spline], extrapolate: int = 1):
|
| 149 |
+
"""
|
| 150 |
+
inp: (B, C, *ishape) tensor
|
| 151 |
+
g: (B, *ishape, D) tensor
|
| 152 |
+
shape: List{D}[int], optional
|
| 153 |
+
bound: List{D}[Bound] tensor
|
| 154 |
+
spline: List{D}[Spline] tensor
|
| 155 |
+
extrapolate: int
|
| 156 |
+
returns: (B, C, *oshape) tensor
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
dim = grid.shape[-1]
|
| 160 |
+
ishape = list(grid.shape[-dim - 1:-1])
|
| 161 |
+
if shape is None:
|
| 162 |
+
shape = ishape
|
| 163 |
+
shape = list(shape)
|
| 164 |
+
batch = max(inp.shape[0], grid.shape[0])
|
| 165 |
+
channel = inp.shape[1]
|
| 166 |
+
|
| 167 |
+
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
|
| 168 |
+
inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
|
| 169 |
+
mask = inbounds_mask(extrapolate, grid, shape)
|
| 170 |
+
|
| 171 |
+
# precompute weights along each dimension
|
| 172 |
+
weights, _, _, coords, signs = get_weights(grid, bound, spline, shape)
|
| 173 |
+
|
| 174 |
+
# initialize
|
| 175 |
+
out = torch.zeros([batch, channel, list_prod_int(shape)],
|
| 176 |
+
dtype=inp.dtype, device=inp.device)
|
| 177 |
+
|
| 178 |
+
# iterate across nodes/corners
|
| 179 |
+
range_nodes = [torch.as_tensor([d for d in range(n)])
|
| 180 |
+
for n in [s.order + 1 for s in spline]]
|
| 181 |
+
if dim == 1:
|
| 182 |
+
# cartesian_prod does not work as expected when only one
|
| 183 |
+
# element is provided
|
| 184 |
+
all_nodes = range_nodes[0].unsqueeze(-1)
|
| 185 |
+
else:
|
| 186 |
+
all_nodes = cartesian_prod(range_nodes)
|
| 187 |
+
for nodes in all_nodes:
|
| 188 |
+
|
| 189 |
+
# gather
|
| 190 |
+
idx = [c[n] for c, n in zip(coords, nodes)]
|
| 191 |
+
idx = sub2ind_list(idx, shape).unsqueeze(1)
|
| 192 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 193 |
+
out1 = inp.clone()
|
| 194 |
+
|
| 195 |
+
# apply sign
|
| 196 |
+
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
|
| 197 |
+
sign1: Optional[Tensor] = make_sign(sign0)
|
| 198 |
+
if sign1 is not None:
|
| 199 |
+
out1 = out1 * sign1.unsqueeze(1)
|
| 200 |
+
|
| 201 |
+
# out-of-bounds mask
|
| 202 |
+
if mask is not None:
|
| 203 |
+
out1 = out1 * mask
|
| 204 |
+
|
| 205 |
+
# apply weights
|
| 206 |
+
for weight, n in zip(weights, nodes):
|
| 207 |
+
out1 = out1 * weight[n].unsqueeze(1)
|
| 208 |
+
|
| 209 |
+
# accumulate
|
| 210 |
+
out.scatter_add_(-1, idx, out1)
|
| 211 |
+
|
| 212 |
+
out = out.reshape(list(out.shape[:2]) + shape)
|
| 213 |
+
return out
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@torch.jit.script
|
| 217 |
+
def grad(inp, grid, bound: List[Bound], spline: List[Spline],
|
| 218 |
+
extrapolate: int = 1):
|
| 219 |
+
"""
|
| 220 |
+
inp: (B, C, *ishape) tensor
|
| 221 |
+
grid: (B, *oshape, D) tensor
|
| 222 |
+
bound: List{D}[Bound] tensor
|
| 223 |
+
spline: List{D}[Spline] tensor
|
| 224 |
+
extrapolate: int
|
| 225 |
+
returns: (B, C, *oshape, D) tensor
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
dim = grid.shape[-1]
|
| 229 |
+
shape = list(inp.shape[-dim:])
|
| 230 |
+
oshape = list(grid.shape[-dim-1:-1])
|
| 231 |
+
batch = max(inp.shape[0], grid.shape[0])
|
| 232 |
+
channel = inp.shape[1]
|
| 233 |
+
|
| 234 |
+
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
|
| 235 |
+
inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
|
| 236 |
+
mask = inbounds_mask(extrapolate, grid, shape)
|
| 237 |
+
|
| 238 |
+
# precompute weights along each dimension
|
| 239 |
+
weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape,
|
| 240 |
+
grad=True)
|
| 241 |
+
|
| 242 |
+
# initialize
|
| 243 |
+
out = torch.zeros([batch, channel, grid.shape[1], dim],
|
| 244 |
+
dtype=inp.dtype, device=inp.device)
|
| 245 |
+
|
| 246 |
+
# iterate across nodes/corners
|
| 247 |
+
range_nodes = [torch.as_tensor([d for d in range(n)])
|
| 248 |
+
for n in [s.order + 1 for s in spline]]
|
| 249 |
+
if dim == 1:
|
| 250 |
+
# cartesian_prod does not work as expected when only one
|
| 251 |
+
# element is provided
|
| 252 |
+
all_nodes = range_nodes[0].unsqueeze(-1)
|
| 253 |
+
else:
|
| 254 |
+
all_nodes = cartesian_prod(range_nodes)
|
| 255 |
+
for nodes in all_nodes:
|
| 256 |
+
|
| 257 |
+
# gather
|
| 258 |
+
idx = [c[n] for c, n in zip(coords, nodes)]
|
| 259 |
+
idx = sub2ind_list(idx, shape).unsqueeze(1)
|
| 260 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 261 |
+
out0 = inp.gather(-1, idx)
|
| 262 |
+
|
| 263 |
+
# apply sign
|
| 264 |
+
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
|
| 265 |
+
sign1: Optional[Tensor] = make_sign(sign0)
|
| 266 |
+
if sign1 is not None:
|
| 267 |
+
out0 = out0 * sign1.unsqueeze(1)
|
| 268 |
+
|
| 269 |
+
for d in range(dim):
|
| 270 |
+
out1 = out0.clone()
|
| 271 |
+
# apply weights
|
| 272 |
+
for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
|
| 273 |
+
if d == dd:
|
| 274 |
+
grad11 = grad1[n]
|
| 275 |
+
if grad11 is not None:
|
| 276 |
+
out1 = out1 * grad11.unsqueeze(1)
|
| 277 |
+
else:
|
| 278 |
+
out1 = out1 * weight[n].unsqueeze(1)
|
| 279 |
+
|
| 280 |
+
# accumulate
|
| 281 |
+
out.unbind(-1)[d].add_(out1)
|
| 282 |
+
|
| 283 |
+
# out-of-bounds mask
|
| 284 |
+
if mask is not None:
|
| 285 |
+
out = out * mask.unsqueeze(-1)
|
| 286 |
+
|
| 287 |
+
out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:]))
|
| 288 |
+
return out
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@torch.jit.script
|
| 292 |
+
def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound],
|
| 293 |
+
spline: List[Spline], extrapolate: int = 1):
|
| 294 |
+
"""
|
| 295 |
+
inp: (B, C, *ishape, D) tensor
|
| 296 |
+
g: (B, *ishape, D) tensor
|
| 297 |
+
shape: List{D}[int], optional
|
| 298 |
+
bound: List{D}[Bound] tensor
|
| 299 |
+
spline: List{D}[Spline] tensor
|
| 300 |
+
extrapolate: int
|
| 301 |
+
returns: (B, C, *shape) tensor
|
| 302 |
+
"""
|
| 303 |
+
dim = grid.shape[-1]
|
| 304 |
+
oshape = list(grid.shape[-dim-1:-1])
|
| 305 |
+
if shape is None:
|
| 306 |
+
shape = oshape
|
| 307 |
+
shape = list(shape)
|
| 308 |
+
batch = max(inp.shape[0], grid.shape[0])
|
| 309 |
+
channel = inp.shape[1]
|
| 310 |
+
|
| 311 |
+
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
|
| 312 |
+
inp = inp.reshape([inp.shape[0], inp.shape[1], -1, dim])
|
| 313 |
+
mask = inbounds_mask(extrapolate, grid, shape)
|
| 314 |
+
|
| 315 |
+
# precompute weights along each dimension
|
| 316 |
+
weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape, grad=True)
|
| 317 |
+
|
| 318 |
+
# initialize
|
| 319 |
+
out = torch.zeros([batch, channel, list_prod_int(shape)],
|
| 320 |
+
dtype=inp.dtype, device=inp.device)
|
| 321 |
+
|
| 322 |
+
# iterate across nodes/corners
|
| 323 |
+
range_nodes = [torch.as_tensor([d for d in range(n)])
|
| 324 |
+
for n in [s.order + 1 for s in spline]]
|
| 325 |
+
if dim == 1:
|
| 326 |
+
# cartesian_prod does not work as expected when only one
|
| 327 |
+
# element is provided
|
| 328 |
+
all_nodes = range_nodes[0].unsqueeze(-1)
|
| 329 |
+
else:
|
| 330 |
+
all_nodes = cartesian_prod(range_nodes)
|
| 331 |
+
for nodes in all_nodes:
|
| 332 |
+
|
| 333 |
+
# gather
|
| 334 |
+
idx = [c[n] for c, n in zip(coords, nodes)]
|
| 335 |
+
idx = sub2ind_list(idx, shape).unsqueeze(1)
|
| 336 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 337 |
+
out0 = inp.clone()
|
| 338 |
+
|
| 339 |
+
# apply sign
|
| 340 |
+
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
|
| 341 |
+
sign1: Optional[Tensor] = make_sign(sign0)
|
| 342 |
+
if sign1 is not None:
|
| 343 |
+
out0 = out0 * sign1.unsqueeze(1).unsqueeze(-1)
|
| 344 |
+
|
| 345 |
+
# out-of-bounds mask
|
| 346 |
+
if mask is not None:
|
| 347 |
+
out0 = out0 * mask.unsqueeze(-1)
|
| 348 |
+
|
| 349 |
+
for d in range(dim):
|
| 350 |
+
out1 = out0.unbind(-1)[d].clone()
|
| 351 |
+
# apply weights
|
| 352 |
+
for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
|
| 353 |
+
if d == dd:
|
| 354 |
+
grad11 = grad1[n]
|
| 355 |
+
if grad11 is not None:
|
| 356 |
+
out1 = out1 * grad11.unsqueeze(1)
|
| 357 |
+
else:
|
| 358 |
+
out1 = out1 * weight[n].unsqueeze(1)
|
| 359 |
+
|
| 360 |
+
# accumulate
|
| 361 |
+
out.scatter_add_(-1, idx, out1)
|
| 362 |
+
|
| 363 |
+
out = out.reshape(list(out.shape[:2]) + shape)
|
| 364 |
+
return out
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
@torch.jit.script
|
| 368 |
+
def hess(inp, grid, bound: List[Bound], spline: List[Spline],
|
| 369 |
+
extrapolate: int = 1):
|
| 370 |
+
"""
|
| 371 |
+
inp: (B, C, *ishape) tensor
|
| 372 |
+
grid: (B, *oshape, D) tensor
|
| 373 |
+
bound: List{D}[Bound] tensor
|
| 374 |
+
spline: List{D}[Spline] tensor
|
| 375 |
+
extrapolate: int
|
| 376 |
+
returns: (B, C, *oshape, D, D) tensor
|
| 377 |
+
"""
|
| 378 |
+
|
| 379 |
+
dim = grid.shape[-1]
|
| 380 |
+
shape = list(inp.shape[-dim:])
|
| 381 |
+
oshape = list(grid.shape[-dim-1:-1])
|
| 382 |
+
batch = max(inp.shape[0], grid.shape[0])
|
| 383 |
+
channel = inp.shape[1]
|
| 384 |
+
|
| 385 |
+
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
|
| 386 |
+
inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
|
| 387 |
+
mask = inbounds_mask(extrapolate, grid, shape)
|
| 388 |
+
|
| 389 |
+
# precompute weights along each dimension
|
| 390 |
+
weights, grads, hesss, coords, signs \
|
| 391 |
+
= get_weights(grid, bound, spline, shape, grad=True, hess=True)
|
| 392 |
+
|
| 393 |
+
# initialize
|
| 394 |
+
out = torch.zeros([batch, channel, grid.shape[1], dim, dim],
|
| 395 |
+
dtype=inp.dtype, device=inp.device)
|
| 396 |
+
|
| 397 |
+
# iterate across nodes/corners
|
| 398 |
+
range_nodes = [torch.as_tensor([d for d in range(n)])
|
| 399 |
+
for n in [s.order + 1 for s in spline]]
|
| 400 |
+
if dim == 1:
|
| 401 |
+
# cartesian_prod does not work as expected when only one
|
| 402 |
+
# element is provided
|
| 403 |
+
all_nodes = range_nodes[0].unsqueeze(-1)
|
| 404 |
+
else:
|
| 405 |
+
all_nodes = cartesian_prod(range_nodes)
|
| 406 |
+
for nodes in all_nodes:
|
| 407 |
+
|
| 408 |
+
# gather
|
| 409 |
+
idx = [c[n] for c, n in zip(coords, nodes)]
|
| 410 |
+
idx = sub2ind_list(idx, shape).unsqueeze(1)
|
| 411 |
+
idx = idx.expand([batch, channel, idx.shape[-1]])
|
| 412 |
+
out0 = inp.gather(-1, idx)
|
| 413 |
+
|
| 414 |
+
# apply sign
|
| 415 |
+
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
|
| 416 |
+
sign1: Optional[Tensor] = make_sign(sign0)
|
| 417 |
+
if sign1 is not None:
|
| 418 |
+
out0 = out0 * sign1.unsqueeze(1)
|
| 419 |
+
|
| 420 |
+
for d in range(dim):
|
| 421 |
+
# -- diagonal --
|
| 422 |
+
out1 = out0.clone()
|
| 423 |
+
|
| 424 |
+
# apply weights
|
| 425 |
+
for dd, (weight, hess1, n) \
|
| 426 |
+
in enumerate(zip(weights, hesss, nodes)):
|
| 427 |
+
if d == dd:
|
| 428 |
+
hess11 = hess1[n]
|
| 429 |
+
if hess11 is not None:
|
| 430 |
+
out1 = out1 * hess11.unsqueeze(1)
|
| 431 |
+
else:
|
| 432 |
+
out1 = out1 * weight[n].unsqueeze(1)
|
| 433 |
+
|
| 434 |
+
# accumulate
|
| 435 |
+
out.unbind(-1)[d].unbind(-1)[d].add_(out1)
|
| 436 |
+
|
| 437 |
+
# -- off diagonal --
|
| 438 |
+
for d2 in range(d+1, dim):
|
| 439 |
+
out1 = out0.clone()
|
| 440 |
+
|
| 441 |
+
# apply weights
|
| 442 |
+
for dd, (weight, grad1, n) \
|
| 443 |
+
in enumerate(zip(weights, grads, nodes)):
|
| 444 |
+
if dd in (d, d2):
|
| 445 |
+
grad11 = grad1[n]
|
| 446 |
+
if grad11 is not None:
|
| 447 |
+
out1 = out1 * grad11.unsqueeze(1)
|
| 448 |
+
else:
|
| 449 |
+
out1 = out1 * weight[n].unsqueeze(1)
|
| 450 |
+
|
| 451 |
+
# accumulate
|
| 452 |
+
out.unbind(-1)[d].unbind(-1)[d2].add_(out1)
|
| 453 |
+
|
| 454 |
+
# out-of-bounds mask
|
| 455 |
+
if mask is not None:
|
| 456 |
+
out = out * mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
| 457 |
+
|
| 458 |
+
# fill lower triangle
|
| 459 |
+
for d in range(dim):
|
| 460 |
+
for d2 in range(d+1, dim):
|
| 461 |
+
out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2])
|
| 462 |
+
|
| 463 |
+
out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:]))
|
| 464 |
+
return out
|
Generator/interpol/pushpull.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Non-differentiable forward/backward components.
|
| 3 |
+
These components are put together in `interpol.autograd` to generate
|
| 4 |
+
differentiable functions.
|
| 5 |
+
|
| 6 |
+
Note
|
| 7 |
+
----
|
| 8 |
+
.. I removed @torch.jit.script from these entry-points because compiling
|
| 9 |
+
all possible combinations of bound+interpolation made the first call
|
| 10 |
+
extremely slow.
|
| 11 |
+
.. I am not using the dot/multi_dot helpers even though they should be
|
| 12 |
+
more efficient that "multiply and sum" because I haven't had the time
|
| 13 |
+
to test them. It would be worth doing it.
|
| 14 |
+
"""
|
| 15 |
+
import torch
|
| 16 |
+
from typing import List, Optional, Tuple
|
| 17 |
+
from .jit_utils import list_all, dot, dot_multi, pad_list_int
|
| 18 |
+
from .bounds import Bound
|
| 19 |
+
from .splines import Spline
|
| 20 |
+
from . import iso0, iso1, nd
|
| 21 |
+
Tensor = torch.Tensor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@torch.jit.script
|
| 25 |
+
def make_bound(bound: List[int]) -> List[Bound]:
|
| 26 |
+
return [Bound(b) for b in bound]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@torch.jit.script
|
| 30 |
+
def make_spline(spline: List[int]) -> List[Spline]:
|
| 31 |
+
return [Spline(s) for s in spline]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# @torch.jit.script
|
| 35 |
+
def grid_pull(inp, grid, bound: List[int], interpolation: List[int],
|
| 36 |
+
extrapolate: int):
|
| 37 |
+
"""
|
| 38 |
+
inp: (B, C, *spatial_in) tensor
|
| 39 |
+
grid: (B, *spatial_out, D) tensor
|
| 40 |
+
bound: List{D}[int] tensor
|
| 41 |
+
interpolation: List{D}[int]
|
| 42 |
+
extrapolate: int
|
| 43 |
+
returns: (B, C, *spatial_out) tensor
|
| 44 |
+
"""
|
| 45 |
+
dim = grid.shape[-1]
|
| 46 |
+
bound = pad_list_int(bound, dim)
|
| 47 |
+
interpolation = pad_list_int(interpolation, dim)
|
| 48 |
+
bound_fn = make_bound(bound)
|
| 49 |
+
is_iso1 = list_all([order == 1 for order in interpolation])
|
| 50 |
+
if is_iso1:
|
| 51 |
+
if dim == 3:
|
| 52 |
+
return iso1.pull3d(inp, grid, bound_fn, extrapolate)
|
| 53 |
+
elif dim == 2:
|
| 54 |
+
return iso1.pull2d(inp, grid, bound_fn, extrapolate)
|
| 55 |
+
elif dim == 1:
|
| 56 |
+
return iso1.pull1d(inp, grid, bound_fn, extrapolate)
|
| 57 |
+
is_iso0 = list_all([order == 0 for order in interpolation])
|
| 58 |
+
if is_iso0:
|
| 59 |
+
if dim == 3:
|
| 60 |
+
return iso0.pull3d(inp, grid, bound_fn, extrapolate)
|
| 61 |
+
elif dim == 2:
|
| 62 |
+
return iso0.pull2d(inp, grid, bound_fn, extrapolate)
|
| 63 |
+
elif dim == 1:
|
| 64 |
+
return iso0.pull1d(inp, grid, bound_fn, extrapolate)
|
| 65 |
+
spline_fn = make_spline(interpolation)
|
| 66 |
+
return nd.pull(inp, grid, bound_fn, spline_fn, extrapolate)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# @torch.jit.script
|
| 70 |
+
def grid_push(inp, grid, shape: Optional[List[int]], bound: List[int],
|
| 71 |
+
interpolation: List[int], extrapolate: int):
|
| 72 |
+
"""
|
| 73 |
+
inp: (B, C, *spatial_in) tensor
|
| 74 |
+
grid: (B, *spatial_in, D) tensor
|
| 75 |
+
shape: List{D}[int] tensor, optional, default=spatial_in
|
| 76 |
+
bound: List{D}[int] tensor
|
| 77 |
+
interpolation: List{D}[int]
|
| 78 |
+
extrapolate: int
|
| 79 |
+
returns: (B, C, *shape) tensor
|
| 80 |
+
"""
|
| 81 |
+
dim = grid.shape[-1]
|
| 82 |
+
bound = pad_list_int(bound, dim)
|
| 83 |
+
interpolation = pad_list_int(interpolation, dim)
|
| 84 |
+
bound_fn = make_bound(bound)
|
| 85 |
+
is_iso1 = list_all([order == 1 for order in interpolation])
|
| 86 |
+
if is_iso1:
|
| 87 |
+
if dim == 3:
|
| 88 |
+
return iso1.push3d(inp, grid, shape, bound_fn, extrapolate)
|
| 89 |
+
elif dim == 2:
|
| 90 |
+
return iso1.push2d(inp, grid, shape, bound_fn, extrapolate)
|
| 91 |
+
elif dim == 1:
|
| 92 |
+
return iso1.push1d(inp, grid, shape, bound_fn, extrapolate)
|
| 93 |
+
is_iso0 = list_all([order == 0 for order in interpolation])
|
| 94 |
+
if is_iso0:
|
| 95 |
+
if dim == 3:
|
| 96 |
+
return iso0.push3d(inp, grid, shape, bound_fn, extrapolate)
|
| 97 |
+
elif dim == 2:
|
| 98 |
+
return iso0.push2d(inp, grid, shape, bound_fn, extrapolate)
|
| 99 |
+
elif dim == 1:
|
| 100 |
+
return iso0.push1d(inp, grid, shape, bound_fn, extrapolate)
|
| 101 |
+
spline_fn = make_spline(interpolation)
|
| 102 |
+
return nd.push(inp, grid, shape, bound_fn, spline_fn, extrapolate)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# @torch.jit.script
|
| 106 |
+
def grid_count(grid, shape: Optional[List[int]], bound: List[int],
|
| 107 |
+
interpolation: List[int], extrapolate: int):
|
| 108 |
+
"""
|
| 109 |
+
grid: (B, *spatial_in, D) tensor
|
| 110 |
+
shape: List{D}[int] tensor, optional, default=spatial_in
|
| 111 |
+
bound: List{D}[int] tensor
|
| 112 |
+
interpolation: List{D}[int]
|
| 113 |
+
extrapolate: int
|
| 114 |
+
returns: (B, 1, *shape) tensor
|
| 115 |
+
"""
|
| 116 |
+
dim = grid.shape[-1]
|
| 117 |
+
bound = pad_list_int(bound, dim)
|
| 118 |
+
interpolation = pad_list_int(interpolation, dim)
|
| 119 |
+
bound_fn = make_bound(bound)
|
| 120 |
+
gshape = list(grid.shape[-dim-1:-1])
|
| 121 |
+
if shape is None:
|
| 122 |
+
shape = gshape
|
| 123 |
+
inp = torch.ones([], dtype=grid.dtype, device=grid.device)
|
| 124 |
+
inp = inp.expand([len(grid), 1] + gshape)
|
| 125 |
+
is_iso1 = list_all([order == 1 for order in interpolation])
|
| 126 |
+
if is_iso1:
|
| 127 |
+
if dim == 3:
|
| 128 |
+
return iso1.push3d(inp, grid, shape, bound_fn, extrapolate)
|
| 129 |
+
elif dim == 2:
|
| 130 |
+
return iso1.push2d(inp, grid, shape, bound_fn, extrapolate)
|
| 131 |
+
elif dim == 1:
|
| 132 |
+
return iso1.push1d(inp, grid, shape, bound_fn, extrapolate)
|
| 133 |
+
is_iso0 = list_all([order == 0 for order in interpolation])
|
| 134 |
+
if is_iso0:
|
| 135 |
+
if dim == 3:
|
| 136 |
+
return iso0.push3d(inp, grid, shape, bound_fn, extrapolate)
|
| 137 |
+
elif dim == 2:
|
| 138 |
+
return iso0.push2d(inp, grid, shape, bound_fn, extrapolate)
|
| 139 |
+
elif dim == 1:
|
| 140 |
+
return iso0.push1d(inp, grid, shape, bound_fn, extrapolate)
|
| 141 |
+
spline_fn = make_spline(interpolation)
|
| 142 |
+
return nd.push(inp, grid, shape, bound_fn, spline_fn, extrapolate)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# @torch.jit.script
|
| 146 |
+
def grid_grad(inp, grid, bound: List[int], interpolation: List[int],
|
| 147 |
+
extrapolate: int):
|
| 148 |
+
"""
|
| 149 |
+
inp: (B, C, *spatial_in) tensor
|
| 150 |
+
grid: (B, *spatial_out, D) tensor
|
| 151 |
+
bound: List{D}[int] tensor
|
| 152 |
+
interpolation: List{D}[int]
|
| 153 |
+
extrapolate: int
|
| 154 |
+
returns: (B, C, *spatial_out, D) tensor
|
| 155 |
+
"""
|
| 156 |
+
dim = grid.shape[-1]
|
| 157 |
+
bound = pad_list_int(bound, dim)
|
| 158 |
+
interpolation = pad_list_int(interpolation, dim)
|
| 159 |
+
bound_fn = make_bound(bound)
|
| 160 |
+
is_iso1 = list_all([order == 1 for order in interpolation])
|
| 161 |
+
if is_iso1:
|
| 162 |
+
if dim == 3:
|
| 163 |
+
return iso1.grad3d(inp, grid, bound_fn, extrapolate)
|
| 164 |
+
elif dim == 2:
|
| 165 |
+
return iso1.grad2d(inp, grid, bound_fn, extrapolate)
|
| 166 |
+
elif dim == 1:
|
| 167 |
+
return iso1.grad1d(inp, grid, bound_fn, extrapolate)
|
| 168 |
+
is_iso0 = list_all([order == 0 for order in interpolation])
|
| 169 |
+
if is_iso0:
|
| 170 |
+
return iso0.grad(inp, grid, bound_fn, extrapolate)
|
| 171 |
+
spline_fn = make_spline(interpolation)
|
| 172 |
+
return nd.grad(inp, grid, bound_fn, spline_fn, extrapolate)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# @torch.jit.script
|
| 176 |
+
def grid_pushgrad(inp, grid, shape: List[int], bound: List[int],
|
| 177 |
+
interpolation: List[int], extrapolate: int):
|
| 178 |
+
""" /!\ Used only in backward pass of grid_grad
|
| 179 |
+
inp: (B, C, *spatial_in, D) tensor
|
| 180 |
+
grid: (B, *spatial_in, D) tensor
|
| 181 |
+
shape: List{D}[int], optional
|
| 182 |
+
bound: List{D}[int] tensor
|
| 183 |
+
interpolation: List{D}[int]
|
| 184 |
+
extrapolate: int
|
| 185 |
+
returns: (B, C, *shape) tensor
|
| 186 |
+
"""
|
| 187 |
+
dim = grid.shape[-1]
|
| 188 |
+
bound = pad_list_int(bound, dim)
|
| 189 |
+
interpolation = pad_list_int(interpolation, dim)
|
| 190 |
+
bound_fn = make_bound(bound)
|
| 191 |
+
is_iso1 = list_all([order == 1 for order in interpolation])
|
| 192 |
+
if is_iso1:
|
| 193 |
+
if dim == 3:
|
| 194 |
+
return iso1.pushgrad3d(inp, grid, shape, bound_fn, extrapolate)
|
| 195 |
+
elif dim == 2:
|
| 196 |
+
return iso1.pushgrad2d(inp, grid, shape, bound_fn, extrapolate)
|
| 197 |
+
elif dim == 1:
|
| 198 |
+
return iso1.pushgrad1d(inp, grid, shape, bound_fn, extrapolate)
|
| 199 |
+
is_iso0 = list_all([order == 0 for order in interpolation])
|
| 200 |
+
if is_iso0:
|
| 201 |
+
return iso0.pushgrad(inp, grid, shape, bound_fn, extrapolate)
|
| 202 |
+
spline_fn = make_spline(interpolation)
|
| 203 |
+
return nd.pushgrad(inp, grid, shape, bound_fn, spline_fn, extrapolate)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# @torch.jit.script
|
| 207 |
+
def grid_hess(inp, grid, bound: List[int], interpolation: List[int],
|
| 208 |
+
extrapolate: int):
|
| 209 |
+
""" /!\ Used only in backward pass of grid_grad
|
| 210 |
+
inp: (B, C, *spatial_in) tensor
|
| 211 |
+
grid: (B, *spatial_out, D) tensor
|
| 212 |
+
bound: List{D}[int] tensor
|
| 213 |
+
interpolation: List{D}[int]
|
| 214 |
+
extrapolate: int
|
| 215 |
+
returns: (B, C, *spatial_out, D, D) tensor
|
| 216 |
+
"""
|
| 217 |
+
dim = grid.shape[-1]
|
| 218 |
+
bound = pad_list_int(bound, dim)
|
| 219 |
+
interpolation = pad_list_int(interpolation, dim)
|
| 220 |
+
bound_fn = make_bound(bound)
|
| 221 |
+
is_iso1 = list_all([order == 1 for order in interpolation])
|
| 222 |
+
if is_iso1:
|
| 223 |
+
if dim == 3:
|
| 224 |
+
return iso1.hess3d(inp, grid, bound_fn, extrapolate)
|
| 225 |
+
if dim == 2:
|
| 226 |
+
return iso1.hess2d(inp, grid, bound_fn, extrapolate)
|
| 227 |
+
if dim == 1:
|
| 228 |
+
return iso1.hess1d(inp, grid, bound_fn, extrapolate)
|
| 229 |
+
is_iso0 = list_all([order == 0 for order in interpolation])
|
| 230 |
+
if is_iso0:
|
| 231 |
+
return iso0.hess(inp, grid, bound_fn, extrapolate)
|
| 232 |
+
spline_fn = make_spline(interpolation)
|
| 233 |
+
return nd.hess(inp, grid, bound_fn, spline_fn, extrapolate)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# @torch.jit.script
|
| 237 |
+
def grid_pull_backward(grad, inp, grid, bound: List[int],
|
| 238 |
+
interpolation: List[int], extrapolate: int) \
|
| 239 |
+
-> Tuple[Optional[Tensor], Optional[Tensor], ]:
|
| 240 |
+
"""
|
| 241 |
+
grad: (B, C, *spatial_out) tensor
|
| 242 |
+
inp: (B, C, *spatial_in) tensor
|
| 243 |
+
grid: (B, *spatial_out, D) tensor
|
| 244 |
+
bound: List{D}[int] tensor
|
| 245 |
+
interpolation: List{D}[int]
|
| 246 |
+
extrapolate: int
|
| 247 |
+
returns: (B, C, *spatial_in) tensor, (B, *spatial_out, D)
|
| 248 |
+
"""
|
| 249 |
+
dim = grid.shape[-1]
|
| 250 |
+
grad_inp: Optional[Tensor] = None
|
| 251 |
+
grad_grid: Optional[Tensor] = None
|
| 252 |
+
if inp.requires_grad:
|
| 253 |
+
grad_inp = grid_push(grad, grid, inp.shape[-dim:], bound, interpolation, extrapolate)
|
| 254 |
+
if grid.requires_grad:
|
| 255 |
+
grad_grid = grid_grad(inp, grid, bound, interpolation, extrapolate)
|
| 256 |
+
# grad_grid = dot(grad_grid, grad.unsqueeze(-1), dim=1)
|
| 257 |
+
grad_grid = (grad_grid * grad.unsqueeze(-1)).sum(dim=1)
|
| 258 |
+
return grad_inp, grad_grid
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# @torch.jit.script
|
| 262 |
+
def grid_push_backward(grad, inp, grid, bound: List[int],
|
| 263 |
+
interpolation: List[int], extrapolate: int) \
|
| 264 |
+
-> Tuple[Optional[Tensor], Optional[Tensor], ]:
|
| 265 |
+
"""
|
| 266 |
+
grad: (B, C, *spatial_out) tensor
|
| 267 |
+
inp: (B, C, *spatial_in) tensor
|
| 268 |
+
grid: (B, *spatial_in, D) tensor
|
| 269 |
+
bound: List{D}[int] tensor
|
| 270 |
+
interpolation: List{D}[int]
|
| 271 |
+
extrapolate: int
|
| 272 |
+
returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D)
|
| 273 |
+
"""
|
| 274 |
+
grad_inp: Optional[Tensor] = None
|
| 275 |
+
grad_grid: Optional[Tensor] = None
|
| 276 |
+
if inp.requires_grad:
|
| 277 |
+
grad_inp = grid_pull(grad, grid, bound, interpolation, extrapolate)
|
| 278 |
+
if grid.requires_grad:
|
| 279 |
+
grad_grid = grid_grad(grad, grid, bound, interpolation, extrapolate)
|
| 280 |
+
# grad_grid = dot(grad_grid, inp.unsqueeze(-1), dim=1)
|
| 281 |
+
grad_grid = (grad_grid * inp.unsqueeze(-1)).sum(dim=1)
|
| 282 |
+
return grad_inp, grad_grid
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# @torch.jit.script
|
| 286 |
+
def grid_count_backward(grad, grid, bound: List[int],
|
| 287 |
+
interpolation: List[int], extrapolate: int) \
|
| 288 |
+
-> Optional[Tensor]:
|
| 289 |
+
"""
|
| 290 |
+
grad: (B, C, *spatial_out) tensor
|
| 291 |
+
grid: (B, *spatial_in, D) tensor
|
| 292 |
+
bound: List{D}[int] tensor
|
| 293 |
+
interpolation: List{D}[int]
|
| 294 |
+
extrapolate: int
|
| 295 |
+
returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D)
|
| 296 |
+
"""
|
| 297 |
+
if grid.requires_grad:
|
| 298 |
+
return grid_grad(grad, grid, bound, interpolation, extrapolate).sum(1)
|
| 299 |
+
return None
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# @torch.jit.script
|
| 303 |
+
def grid_grad_backward(grad, inp, grid, bound: List[int],
|
| 304 |
+
interpolation: List[int], extrapolate: int) \
|
| 305 |
+
-> Tuple[Optional[Tensor], Optional[Tensor]]:
|
| 306 |
+
"""
|
| 307 |
+
grad: (B, C, *spatial_out, D) tensor
|
| 308 |
+
inp: (B, C, *spatial_in) tensor
|
| 309 |
+
grid: (B, *spatial_out, D) tensor
|
| 310 |
+
bound: List{D}[int] tensor
|
| 311 |
+
interpolation: List{D}[int]
|
| 312 |
+
extrapolate: int
|
| 313 |
+
returns: (B, C, *spatial_in, D) tensor, (B, *spatial_out, D)
|
| 314 |
+
"""
|
| 315 |
+
dim = grid.shape[-1]
|
| 316 |
+
shape = inp.shape[-dim:]
|
| 317 |
+
grad_inp: Optional[Tensor] = None
|
| 318 |
+
grad_grid: Optional[Tensor] = None
|
| 319 |
+
if inp.requires_grad:
|
| 320 |
+
grad_inp = grid_pushgrad(grad, grid, shape, bound, interpolation, extrapolate)
|
| 321 |
+
if grid.requires_grad:
|
| 322 |
+
grad_grid = grid_hess(inp, grid, bound, interpolation, extrapolate)
|
| 323 |
+
# grad_grid = dot_multi(grad_grid, grad.unsqueeze(-1), dim=[1, -2])
|
| 324 |
+
grad_grid = (grad_grid * grad.unsqueeze(-1)).sum(dim=[1, -2])
|
| 325 |
+
return grad_inp, grad_grid
|
Generator/interpol/resize.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Resize functions (equivalent to scipy's zoom, pytorch's interpolate)
|
| 3 |
+
based on grid_pull.
|
| 4 |
+
"""
|
| 5 |
+
__all__ = ['resize']
|
| 6 |
+
|
| 7 |
+
from .api import grid_pull
|
| 8 |
+
from .utils import make_list, meshgrid_ij
|
| 9 |
+
from . import backend, jitfields
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def resize(image, factor=None, shape=None, anchor='c',
|
| 14 |
+
interpolation=1, prefilter=True, **kwargs):
|
| 15 |
+
"""Resize an image by a factor or to a specific shape.
|
| 16 |
+
|
| 17 |
+
Notes
|
| 18 |
+
-----
|
| 19 |
+
.. A least one of `factor` and `shape` must be specified
|
| 20 |
+
.. If `anchor in ('centers', 'edges')`, exactly one of `factor` or
|
| 21 |
+
`shape must be specified.
|
| 22 |
+
.. If `anchor in ('first', 'last')`, `factor` must be provided even
|
| 23 |
+
if `shape` is specified.
|
| 24 |
+
.. Because of rounding, it is in general not assured that
|
| 25 |
+
`resize(resize(x, f), 1/f)` returns a tensor with the same shape as x.
|
| 26 |
+
|
| 27 |
+
edges centers first last
|
| 28 |
+
e - + - + - e + - + - + - + + - + - + - + + - + - + - +
|
| 29 |
+
| . | . | . | | c | . | c | | f | . | . | | . | . | . |
|
| 30 |
+
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
|
| 31 |
+
| . | . | . | | . | . | . | | . | . | . | | . | . | . |
|
| 32 |
+
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
|
| 33 |
+
| . | . | . | | c | . | c | | . | . | . | | . | . | l |
|
| 34 |
+
e _ + _ + _ e + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
|
| 35 |
+
|
| 36 |
+
Parameters
|
| 37 |
+
----------
|
| 38 |
+
image : (batch, channel, *inshape) tensor
|
| 39 |
+
Image to resize
|
| 40 |
+
factor : float or list[float], optional
|
| 41 |
+
Resizing factor
|
| 42 |
+
* > 1 : larger image <-> smaller voxels
|
| 43 |
+
* < 1 : smaller image <-> larger voxels
|
| 44 |
+
shape : (ndim,) list[int], optional
|
| 45 |
+
Output shape
|
| 46 |
+
anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
|
| 47 |
+
* In cases 'c' and 'e', the volume shape is multiplied by the
|
| 48 |
+
zoom factor (and eventually truncated), and two anchor points
|
| 49 |
+
are used to determine the voxel size.
|
| 50 |
+
* In cases 'f' and 'l', a single anchor point is used so that
|
| 51 |
+
the voxel size is exactly divided by the zoom factor.
|
| 52 |
+
This case with an integer factor corresponds to subslicing
|
| 53 |
+
the volume (e.g., `vol[::f, ::f, ::f]`).
|
| 54 |
+
* A list of anchors (one per dimension) can also be provided.
|
| 55 |
+
interpolation : int or sequence[int], default=1
|
| 56 |
+
Interpolation order.
|
| 57 |
+
prefilter : bool, default=True
|
| 58 |
+
Apply spline pre-filter (= interpolates the input)
|
| 59 |
+
|
| 60 |
+
Returns
|
| 61 |
+
-------
|
| 62 |
+
resized : (batch, channel, *shape) tensor
|
| 63 |
+
Resized image
|
| 64 |
+
|
| 65 |
+
"""
|
| 66 |
+
if backend.jitfields and jitfields.available:
|
| 67 |
+
return jitfields.resize(image, factor, shape, anchor,
|
| 68 |
+
interpolation, prefilter, **kwargs)
|
| 69 |
+
|
| 70 |
+
factor = make_list(factor) if factor else []
|
| 71 |
+
shape = make_list(shape) if shape else []
|
| 72 |
+
anchor = make_list(anchor)
|
| 73 |
+
nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
|
| 74 |
+
anchor = [a[0].lower() for a in make_list(anchor, nb_dim)]
|
| 75 |
+
bck = dict(dtype=image.dtype, device=image.device)
|
| 76 |
+
|
| 77 |
+
# compute output shape
|
| 78 |
+
inshape = image.shape[-nb_dim:]
|
| 79 |
+
if factor:
|
| 80 |
+
factor = make_list(factor, nb_dim)
|
| 81 |
+
elif not shape:
|
| 82 |
+
raise ValueError('One of `factor` or `shape` must be provided')
|
| 83 |
+
if shape:
|
| 84 |
+
shape = make_list(shape, nb_dim)
|
| 85 |
+
else:
|
| 86 |
+
shape = [int(i*f) for i, f in zip(inshape, factor)]
|
| 87 |
+
|
| 88 |
+
if not factor:
|
| 89 |
+
factor = [o/i for o, i in zip(shape, inshape)]
|
| 90 |
+
|
| 91 |
+
# compute transformation grid
|
| 92 |
+
lin = []
|
| 93 |
+
for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape):
|
| 94 |
+
if anch == 'c': # centers
|
| 95 |
+
lin.append(torch.linspace(0, inshp - 1, outshp, **bck))
|
| 96 |
+
elif anch == 'e': # edges
|
| 97 |
+
scale = inshp / outshp
|
| 98 |
+
shift = 0.5 * (scale - 1)
|
| 99 |
+
lin.append(torch.arange(0., outshp, **bck) * scale + shift)
|
| 100 |
+
elif anch == 'f': # first voxel
|
| 101 |
+
# scale = 1/f
|
| 102 |
+
# shift = 0
|
| 103 |
+
lin.append(torch.arange(0., outshp, **bck) / f)
|
| 104 |
+
elif anch == 'l': # last voxel
|
| 105 |
+
# scale = 1/f
|
| 106 |
+
shift = (inshp - 1) - (outshp - 1) / f
|
| 107 |
+
lin.append(torch.arange(0., outshp, **bck) / f + shift)
|
| 108 |
+
else:
|
| 109 |
+
raise ValueError('Unknown anchor {}'.format(anch))
|
| 110 |
+
|
| 111 |
+
# interpolate
|
| 112 |
+
kwargs.setdefault('bound', 'nearest')
|
| 113 |
+
kwargs.setdefault('extrapolate', True)
|
| 114 |
+
kwargs.setdefault('interpolation', interpolation)
|
| 115 |
+
kwargs.setdefault('prefilter', prefilter)
|
| 116 |
+
grid = torch.stack(meshgrid_ij(*lin), dim=-1)
|
| 117 |
+
resized = grid_pull(image, grid, **kwargs)
|
| 118 |
+
|
| 119 |
+
return resized
|
| 120 |
+
|
Generator/interpol/restrict.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__all__ = ['restrict']
|
| 2 |
+
|
| 3 |
+
from .api import grid_push
|
| 4 |
+
from .utils import make_list, meshgrid_ij
|
| 5 |
+
from . import backend, jitfields
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def restrict(image, factor=None, shape=None, anchor='c',
|
| 10 |
+
interpolation=1, reduce_sum=False, **kwargs):
|
| 11 |
+
"""Restrict an image by a factor or to a specific shape.
|
| 12 |
+
|
| 13 |
+
Notes
|
| 14 |
+
-----
|
| 15 |
+
.. A least one of `factor` and `shape` must be specified
|
| 16 |
+
.. If `anchor in ('centers', 'edges')`, exactly one of `factor` or
|
| 17 |
+
`shape must be specified.
|
| 18 |
+
.. If `anchor in ('first', 'last')`, `factor` must be provided even
|
| 19 |
+
if `shape` is specified.
|
| 20 |
+
.. Because of rounding, it is in general not assured that
|
| 21 |
+
`resize(resize(x, f), 1/f)` returns a tensor with the same shape as x.
|
| 22 |
+
|
| 23 |
+
edges centers first last
|
| 24 |
+
e - + - + - e + - + - + - + + - + - + - + + - + - + - +
|
| 25 |
+
| . | . | . | | c | . | c | | f | . | . | | . | . | . |
|
| 26 |
+
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
|
| 27 |
+
| . | . | . | | . | . | . | | . | . | . | | . | . | . |
|
| 28 |
+
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
|
| 29 |
+
| . | . | . | | c | . | c | | . | . | . | | . | . | l |
|
| 30 |
+
e _ + _ + _ e + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
|
| 31 |
+
|
| 32 |
+
Parameters
|
| 33 |
+
----------
|
| 34 |
+
image : (batch, channel, *inshape) tensor
|
| 35 |
+
Image to resize
|
| 36 |
+
factor : float or list[float], optional
|
| 37 |
+
Resizing factor
|
| 38 |
+
* > 1 : larger image <-> smaller voxels
|
| 39 |
+
* < 1 : smaller image <-> larger voxels
|
| 40 |
+
shape : (ndim,) list[int], optional
|
| 41 |
+
Output shape
|
| 42 |
+
anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
|
| 43 |
+
* In cases 'c' and 'e', the volume shape is multiplied by the
|
| 44 |
+
zoom factor (and eventually truncated), and two anchor points
|
| 45 |
+
are used to determine the voxel size.
|
| 46 |
+
* In cases 'f' and 'l', a single anchor point is used so that
|
| 47 |
+
the voxel size is exactly divided by the zoom factor.
|
| 48 |
+
This case with an integer factor corresponds to subslicing
|
| 49 |
+
the volume (e.g., `vol[::f, ::f, ::f]`).
|
| 50 |
+
* A list of anchors (one per dimension) can also be provided.
|
| 51 |
+
interpolation : int or sequence[int], default=1
|
| 52 |
+
Interpolation order.
|
| 53 |
+
reduce_sum : bool, default=False
|
| 54 |
+
Do not normalize by the number of accumulated values per voxel
|
| 55 |
+
|
| 56 |
+
Returns
|
| 57 |
+
-------
|
| 58 |
+
restricted : (batch, channel, *shape) tensor
|
| 59 |
+
Restricted image
|
| 60 |
+
|
| 61 |
+
"""
|
| 62 |
+
if backend.jitfields and jitfields.available:
|
| 63 |
+
return jitfields.restrict(image, factor, shape, anchor,
|
| 64 |
+
interpolation, reduce_sum, **kwargs)
|
| 65 |
+
|
| 66 |
+
factor = make_list(factor) if factor else []
|
| 67 |
+
shape = make_list(shape) if shape else []
|
| 68 |
+
anchor = make_list(anchor)
|
| 69 |
+
nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
|
| 70 |
+
anchor = [a[0].lower() for a in make_list(anchor, nb_dim)]
|
| 71 |
+
bck = dict(dtype=image.dtype, device=image.device)
|
| 72 |
+
|
| 73 |
+
# compute output shape
|
| 74 |
+
inshape = image.shape[-nb_dim:]
|
| 75 |
+
if factor:
|
| 76 |
+
factor = make_list(factor, nb_dim)
|
| 77 |
+
elif not shape:
|
| 78 |
+
raise ValueError('One of `factor` or `shape` must be provided')
|
| 79 |
+
if shape:
|
| 80 |
+
shape = make_list(shape, nb_dim)
|
| 81 |
+
else:
|
| 82 |
+
shape = [int(i/f) for i, f in zip(inshape, factor)]
|
| 83 |
+
|
| 84 |
+
if not factor:
|
| 85 |
+
factor = [i/o for o, i in zip(shape, inshape)]
|
| 86 |
+
|
| 87 |
+
# compute transformation grid
|
| 88 |
+
lin = []
|
| 89 |
+
fullscale = 1
|
| 90 |
+
for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape):
|
| 91 |
+
if anch == 'c': # centers
|
| 92 |
+
lin.append(torch.linspace(0, outshp - 1, inshp, **bck))
|
| 93 |
+
fullscale *= (inshp - 1) / (outshp - 1)
|
| 94 |
+
elif anch == 'e': # edges
|
| 95 |
+
scale = outshp / inshp
|
| 96 |
+
shift = 0.5 * (scale - 1)
|
| 97 |
+
fullscale *= scale
|
| 98 |
+
lin.append(torch.arange(0., inshp, **bck) * scale + shift)
|
| 99 |
+
elif anch == 'f': # first voxel
|
| 100 |
+
# scale = 1/f
|
| 101 |
+
# shift = 0
|
| 102 |
+
fullscale *= 1/f
|
| 103 |
+
lin.append(torch.arange(0., inshp, **bck) / f)
|
| 104 |
+
elif anch == 'l': # last voxel
|
| 105 |
+
# scale = 1/f
|
| 106 |
+
shift = (outshp - 1) - (inshp - 1) / f
|
| 107 |
+
fullscale *= 1/f
|
| 108 |
+
lin.append(torch.arange(0., inshp, **bck) / f + shift)
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError('Unknown anchor {}'.format(anch))
|
| 111 |
+
|
| 112 |
+
# scatter
|
| 113 |
+
kwargs.setdefault('bound', 'nearest')
|
| 114 |
+
kwargs.setdefault('extrapolate', True)
|
| 115 |
+
kwargs.setdefault('interpolation', interpolation)
|
| 116 |
+
kwargs.setdefault('prefilter', False)
|
| 117 |
+
grid = torch.stack(meshgrid_ij(*lin), dim=-1)
|
| 118 |
+
resized = grid_push(image, grid, shape, **kwargs)
|
| 119 |
+
if not reduce_sum:
|
| 120 |
+
resized /= fullscale
|
| 121 |
+
|
| 122 |
+
return resized
|
Generator/interpol/splines.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Weights and derivatives of spline orders 0 to 7."""
|
| 2 |
+
import torch
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from .jit_utils import square, cube, pow4, pow5, pow6, pow7
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class InterpolationType(Enum):
|
| 8 |
+
nearest = zeroth = 0
|
| 9 |
+
linear = first = 1
|
| 10 |
+
quadratic = second = 2
|
| 11 |
+
cubic = third = 3
|
| 12 |
+
fourth = 4
|
| 13 |
+
fifth = 5
|
| 14 |
+
sixth = 6
|
| 15 |
+
seventh = 7
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@torch.jit.script
|
| 19 |
+
class Spline:
|
| 20 |
+
|
| 21 |
+
def __init__(self, order: int = 1):
|
| 22 |
+
self.order = order
|
| 23 |
+
|
| 24 |
+
def weight(self, x):
|
| 25 |
+
w = self.fastweight(x)
|
| 26 |
+
zero = torch.zeros([1], dtype=x.dtype, device=x.device)
|
| 27 |
+
w = torch.where(x.abs() >= (self.order + 1)/2, zero, w)
|
| 28 |
+
return w
|
| 29 |
+
|
| 30 |
+
def fastweight(self, x):
|
| 31 |
+
if self.order == 0:
|
| 32 |
+
return torch.ones(x.shape, dtype=x.dtype, device=x.device)
|
| 33 |
+
x = x.abs()
|
| 34 |
+
if self.order == 1:
|
| 35 |
+
return 1 - x
|
| 36 |
+
if self.order == 2:
|
| 37 |
+
x_low = 0.75 - square(x)
|
| 38 |
+
x_up = 0.5 * square(1.5 - x)
|
| 39 |
+
return torch.where(x < 0.5, x_low, x_up)
|
| 40 |
+
if self.order == 3:
|
| 41 |
+
x_low = (x * x * (x - 2.) * 3. + 4.) / 6.
|
| 42 |
+
x_up = cube(2. - x) / 6.
|
| 43 |
+
return torch.where(x < 1., x_low, x_up)
|
| 44 |
+
if self.order == 4:
|
| 45 |
+
x_low = square(x)
|
| 46 |
+
x_low = x_low * (x_low * 0.25 - 0.625) + 115. / 192.
|
| 47 |
+
x_mid = x * (x * (x * (5. - x) / 6. - 1.25) + 5./24.) + 55./96.
|
| 48 |
+
x_up = pow4(x - 2.5) / 24.
|
| 49 |
+
return torch.where(x < 0.5, x_low, torch.where(x < 1.5, x_mid, x_up))
|
| 50 |
+
if self.order == 5:
|
| 51 |
+
x_low = square(x)
|
| 52 |
+
x_low = x_low * (x_low * (0.25 - x / 12.) - 0.5) + 0.55
|
| 53 |
+
x_mid = x * (x * (x * (x * (x / 24. - 0.375) + 1.25) - 1.75) + 0.625) + 0.425
|
| 54 |
+
x_up = pow5(3 - x) / 120.
|
| 55 |
+
return torch.where(x < 1., x_low, torch.where(x < 2., x_mid, x_up))
|
| 56 |
+
if self.order == 6:
|
| 57 |
+
x_low = square(x)
|
| 58 |
+
x_low = x_low * (x_low * (7./48. - x_low/36.) - 77./192.) + 5887./11520.
|
| 59 |
+
x_mid_low = (x * (x * (x * (x * (x * (x / 48. - 7./48.) + 0.328125)
|
| 60 |
+
- 35./288.) - 91./256.) - 7./768.) + 7861./15360.)
|
| 61 |
+
x_mid_up = (x * (x * (x * (x * (x * (7./60. - x / 120.) - 0.65625)
|
| 62 |
+
+ 133./72.) - 2.5703125) + 1267./960.) + 1379./7680.)
|
| 63 |
+
x_up = pow6(x - 3.5) / 720.
|
| 64 |
+
return torch.where(x < .5, x_low,
|
| 65 |
+
torch.where(x < 1.5, x_mid_low,
|
| 66 |
+
torch.where(x < 2.5, x_mid_up, x_up)))
|
| 67 |
+
if self.order == 7:
|
| 68 |
+
x_low = square(x)
|
| 69 |
+
x_low = (x_low * (x_low * (x_low * (x / 144. - 1./36.)
|
| 70 |
+
+ 1./9.) - 1./3.) + 151./315.)
|
| 71 |
+
x_mid_low = (x * (x * (x * (x * (x * (x * (0.05 - x/240.) - 7./30.)
|
| 72 |
+
+ 0.5) - 7./18.) - 0.1) - 7./90.) + 103./210.)
|
| 73 |
+
x_mid_up = (x * (x * (x * (x * (x * (x * (x / 720. - 1./36.)
|
| 74 |
+
+ 7./30.) - 19./18.) + 49./18.) - 23./6.) + 217./90.)
|
| 75 |
+
- 139./630.)
|
| 76 |
+
x_up = pow7(4 - x) / 5040.
|
| 77 |
+
return torch.where(x < 1., x_low,
|
| 78 |
+
torch.where(x < 2., x_mid_low,
|
| 79 |
+
torch.where(x < 3., x_mid_up, x_up)))
|
| 80 |
+
raise NotImplementedError
|
| 81 |
+
|
| 82 |
+
def grad(self, x):
|
| 83 |
+
if self.order == 0:
|
| 84 |
+
return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
|
| 85 |
+
g = self.fastgrad(x)
|
| 86 |
+
zero = torch.zeros([1], dtype=x.dtype, device=x.device)
|
| 87 |
+
g = torch.where(x.abs() >= (self.order + 1)/2, zero, g)
|
| 88 |
+
return g
|
| 89 |
+
|
| 90 |
+
def fastgrad(self, x):
|
| 91 |
+
if self.order == 0:
|
| 92 |
+
return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
|
| 93 |
+
return self._fastgrad(x.abs()).mul(x.sign())
|
| 94 |
+
|
| 95 |
+
def _fastgrad(self, x):
|
| 96 |
+
if self.order == 1:
|
| 97 |
+
return torch.ones(x.shape, dtype=x.dtype, device=x.device)
|
| 98 |
+
if self.order == 2:
|
| 99 |
+
return torch.where(x < 0.5, -2*x, x - 1.5)
|
| 100 |
+
if self.order == 3:
|
| 101 |
+
g_low = x * (x * 1.5 - 2)
|
| 102 |
+
g_up = -0.5 * square(2 - x)
|
| 103 |
+
return torch.where(x < 1, g_low, g_up)
|
| 104 |
+
if self.order == 4:
|
| 105 |
+
g_low = x * (square(x) - 1.25)
|
| 106 |
+
g_mid = x * (x * (x * (-2./3.) + 2.5) - 2.5) + 5./24.
|
| 107 |
+
g_up = cube(2. * x - 5.) / 48.
|
| 108 |
+
return torch.where(x < 0.5, g_low,
|
| 109 |
+
torch.where(x < 1.5, g_mid, g_up))
|
| 110 |
+
if self.order == 5:
|
| 111 |
+
g_low = x * (x * (x * (x * (-5./12.) + 1.)) - 1.)
|
| 112 |
+
g_mid = x * (x * (x * (x * (5./24.) - 1.5) + 3.75) - 3.5) + 0.625
|
| 113 |
+
g_up = pow4(x - 3.) / (-24.)
|
| 114 |
+
return torch.where(x < 1, g_low,
|
| 115 |
+
torch.where(x < 2, g_mid, g_up))
|
| 116 |
+
if self.order == 6:
|
| 117 |
+
g_low = square(x)
|
| 118 |
+
g_low = x * (g_low * (7./12.) - square(g_low) / 6. - 77./96.)
|
| 119 |
+
g_mid_low = (x * (x * (x * (x * (x * 0.125 - 35./48.) + 1.3125)
|
| 120 |
+
- 35./96.) - 0.7109375) - 7./768.)
|
| 121 |
+
g_mid_up = (x * (x * (x * (x * (x / (-20.) + 7./12.) - 2.625)
|
| 122 |
+
+ 133./24.) - 5.140625) + 1267./960.)
|
| 123 |
+
g_up = pow5(2*x - 7) / 3840.
|
| 124 |
+
return torch.where(x < 0.5, g_low,
|
| 125 |
+
torch.where(x < 1.5, g_mid_low,
|
| 126 |
+
torch.where(x < 2.5, g_mid_up,
|
| 127 |
+
g_up)))
|
| 128 |
+
if self.order == 7:
|
| 129 |
+
g_low = square(x)
|
| 130 |
+
g_low = x * (g_low * (g_low * (x * (7./144.) - 1./6.) + 4./9.) - 2./3.)
|
| 131 |
+
g_mid_low = (x * (x * (x * (x * (x * (x * (-7./240.) + 3./10.)
|
| 132 |
+
- 7./6.) + 2.) - 7./6.) - 1./5.) - 7./90.)
|
| 133 |
+
g_mid_up = (x * (x * (x * (x * (x * (x * (7./720.) - 1./6.)
|
| 134 |
+
+ 7./6.) - 38./9.) + 49./6.) - 23./3.) + 217./90.)
|
| 135 |
+
g_up = pow6(x - 4) / (-720.)
|
| 136 |
+
return torch.where(x < 1, g_low,
|
| 137 |
+
torch.where(x < 2, g_mid_low,
|
| 138 |
+
torch.where(x < 3, g_mid_up, g_up)))
|
| 139 |
+
raise NotImplementedError
|
| 140 |
+
|
| 141 |
+
def hess(self, x):
|
| 142 |
+
if self.order == 0:
|
| 143 |
+
return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
|
| 144 |
+
h = self.fasthess(x)
|
| 145 |
+
zero = torch.zeros([1], dtype=x.dtype, device=x.device)
|
| 146 |
+
h = torch.where(x.abs() >= (self.order + 1)/2, zero, h)
|
| 147 |
+
return h
|
| 148 |
+
|
| 149 |
+
def fasthess(self, x):
|
| 150 |
+
if self.order in (0, 1):
|
| 151 |
+
return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
|
| 152 |
+
x = x.abs()
|
| 153 |
+
if self.order == 2:
|
| 154 |
+
one = torch.ones([1], dtype=x.dtype, device=x.device)
|
| 155 |
+
return torch.where(x < 0.5, -2 * one, one)
|
| 156 |
+
if self.order == 3:
|
| 157 |
+
return torch.where(x < 1, 3. * x - 2., 2. - x)
|
| 158 |
+
if self.order == 4:
|
| 159 |
+
return torch.where(x < 0.5, 3. * square(x) - 1.25,
|
| 160 |
+
torch.where(x < 1.5, x * (-2. * x + 5.) - 2.5,
|
| 161 |
+
square(2. * x - 5.) / 8.))
|
| 162 |
+
if self.order == 5:
|
| 163 |
+
h_low = square(x)
|
| 164 |
+
h_low = - h_low * (x * (5./3.) - 3.) - 1.
|
| 165 |
+
h_mid = x * (x * (x * (5./6.) - 9./2.) + 15./2.) - 7./2.
|
| 166 |
+
h_up = 9./2. - x * (x * (x/6. - 3./2.) + 9./2.)
|
| 167 |
+
return torch.where(x < 1, h_low,
|
| 168 |
+
torch.where(x < 2, h_mid, h_up))
|
| 169 |
+
if self.order == 6:
|
| 170 |
+
h_low = square(x)
|
| 171 |
+
h_low = - h_low * (h_low * (5./6) - 7./4.) - 77./96.
|
| 172 |
+
h_mid_low = (x * (x * (x * (x * (5./8.) - 35./12.) + 63./16.)
|
| 173 |
+
- 35./48.) - 91./128.)
|
| 174 |
+
h_mid_up = -(x * (x * (x * (x/4. - 7./3.) + 63./8.) - 133./12.)
|
| 175 |
+
+ 329./64.)
|
| 176 |
+
h_up = (x * (x * (x * (x/24. - 7./12.) + 49./16.) - 343./48.)
|
| 177 |
+
+ 2401./384.)
|
| 178 |
+
return torch.where(x < 0.5, h_low,
|
| 179 |
+
torch.where(x < 1.5, h_mid_low,
|
| 180 |
+
torch.where(x < 2.5, h_mid_up,
|
| 181 |
+
h_up)))
|
| 182 |
+
if self.order == 7:
|
| 183 |
+
h_low = square(x)
|
| 184 |
+
h_low = h_low * (h_low*(x * (7./24.) - 5./6.) + 4./3.) - 2./3.
|
| 185 |
+
h_mid_low = - (x * (x * (x * (x * (x * (7./40.) - 3./2.) + 14./3.)
|
| 186 |
+
- 6.) + 7./3.) + 1./5.)
|
| 187 |
+
h_mid_up = (x * (x * (x * (x * (x * (7./120.) - 5./6.) + 14./3.)
|
| 188 |
+
- 38./3.) + 49./3.) - 23./3.)
|
| 189 |
+
h_up = - (x * (x * (x * (x * (x/120. - 1./6.) + 4./3.) - 16./3.)
|
| 190 |
+
+ 32./3.) - 128./15.)
|
| 191 |
+
return torch.where(x < 1, h_low,
|
| 192 |
+
torch.where(x < 2, h_mid_low,
|
| 193 |
+
torch.where(x < 3, h_mid_up,
|
| 194 |
+
h_up)))
|
| 195 |
+
raise NotImplementedError
|
| 196 |
+
|
Generator/interpol/tests/__init__.py
ADDED
|
File without changes
|
Generator/interpol/tests/test_gradcheck_pushpull.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.autograd import gradcheck
|
| 3 |
+
from interpol import grid_pull, grid_push, grid_count, grid_grad, add_identity_grid_
|
| 4 |
+
import pytest
|
| 5 |
+
import inspect
|
| 6 |
+
|
| 7 |
+
# global parameters
|
| 8 |
+
dtype = torch.double # data type (double advised to check gradients)
|
| 9 |
+
shape1 = 3 # size along each dimension
|
| 10 |
+
extrapolate = True
|
| 11 |
+
|
| 12 |
+
if hasattr(torch, 'use_deterministic_algorithms'):
|
| 13 |
+
torch.use_deterministic_algorithms(True)
|
| 14 |
+
kwargs = dict(rtol=1., raise_exception=True)
|
| 15 |
+
if 'check_undefined_grad' in inspect.signature(gradcheck).parameters:
|
| 16 |
+
kwargs['check_undefined_grad'] = False
|
| 17 |
+
if 'nondet_tol' in inspect.signature(gradcheck).parameters:
|
| 18 |
+
kwargs['nondet_tol'] = 1e-3
|
| 19 |
+
|
| 20 |
+
# parameters
|
| 21 |
+
devices = [('cpu', 1)]
|
| 22 |
+
if torch.backends.openmp.is_available() or torch.backends.mkl.is_available():
|
| 23 |
+
print('parallel backend available')
|
| 24 |
+
devices.append(('cpu', 10))
|
| 25 |
+
if torch.cuda.is_available():
|
| 26 |
+
print('cuda backend available')
|
| 27 |
+
devices.append('cuda')
|
| 28 |
+
|
| 29 |
+
dims = [1, 2, 3]
|
| 30 |
+
bounds = list(range(7))
|
| 31 |
+
order_bounds = []
|
| 32 |
+
for o in range(3):
|
| 33 |
+
for b in bounds:
|
| 34 |
+
order_bounds += [(o, b)]
|
| 35 |
+
for o in range(3, 8):
|
| 36 |
+
order_bounds += [(o, 3)] # only test dc2 for order > 2
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def make_data(shape, device, dtype):
|
| 40 |
+
grid = torch.randn([2, *shape, len(shape)], device=device, dtype=dtype)
|
| 41 |
+
grid = add_identity_grid_(grid)
|
| 42 |
+
vol = torch.randn((2, 1,) + shape, device=device, dtype=dtype)
|
| 43 |
+
return vol, grid
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def init_device(device):
|
| 47 |
+
if isinstance(device, (list, tuple)):
|
| 48 |
+
device, param = device
|
| 49 |
+
else:
|
| 50 |
+
param = 1 if device == 'cpu' else 0
|
| 51 |
+
if device == 'cuda':
|
| 52 |
+
torch.cuda.set_device(param)
|
| 53 |
+
torch.cuda.init()
|
| 54 |
+
try:
|
| 55 |
+
torch.cuda.empty_cache()
|
| 56 |
+
except RuntimeError:
|
| 57 |
+
pass
|
| 58 |
+
device = '{}:{}'.format(device, param)
|
| 59 |
+
else:
|
| 60 |
+
assert device == 'cpu'
|
| 61 |
+
torch.set_num_threads(param)
|
| 62 |
+
return torch.device(device)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@pytest.mark.parametrize("device", devices)
|
| 66 |
+
@pytest.mark.parametrize("dim", dims)
|
| 67 |
+
# @pytest.mark.parametrize("bound", bounds)
|
| 68 |
+
# @pytest.mark.parametrize("interpolation", orders)
|
| 69 |
+
@pytest.mark.parametrize("interpolation,bound", order_bounds)
|
| 70 |
+
def test_gradcheck_grad(device, dim, bound, interpolation):
|
| 71 |
+
print(f'grad_{dim}d({interpolation}, {bound}) on {device}')
|
| 72 |
+
device = init_device(device)
|
| 73 |
+
shape = (shape1,) * dim
|
| 74 |
+
vol, grid = make_data(shape, device, dtype)
|
| 75 |
+
vol.requires_grad = True
|
| 76 |
+
grid.requires_grad = True
|
| 77 |
+
assert gradcheck(grid_grad, (vol, grid, interpolation, bound, extrapolate),
|
| 78 |
+
**kwargs)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@pytest.mark.parametrize("device", devices)
|
| 82 |
+
@pytest.mark.parametrize("dim", dims)
|
| 83 |
+
# @pytest.mark.parametrize("bound", bounds)
|
| 84 |
+
# @pytest.mark.parametrize("interpolation", orders)
|
| 85 |
+
@pytest.mark.parametrize("interpolation,bound", order_bounds)
|
| 86 |
+
def test_gradcheck_pull(device, dim, bound, interpolation):
|
| 87 |
+
print(f'pull_{dim}d({interpolation}, {bound}) on {device}')
|
| 88 |
+
device = init_device(device)
|
| 89 |
+
shape = (shape1,) * dim
|
| 90 |
+
vol, grid = make_data(shape, device, dtype)
|
| 91 |
+
vol.requires_grad = True
|
| 92 |
+
grid.requires_grad = True
|
| 93 |
+
assert gradcheck(grid_pull, (vol, grid, interpolation, bound, extrapolate),
|
| 94 |
+
**kwargs)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@pytest.mark.parametrize("device", devices)
|
| 98 |
+
@pytest.mark.parametrize("dim", dims)
|
| 99 |
+
# @pytest.mark.parametrize("bound", bounds)
|
| 100 |
+
# @pytest.mark.parametrize("interpolation", orders)
|
| 101 |
+
@pytest.mark.parametrize("interpolation,bound", order_bounds)
|
| 102 |
+
def test_gradcheck_push(device, dim, bound, interpolation):
|
| 103 |
+
print(f'push_{dim}d({interpolation}, {bound}) on {device}')
|
| 104 |
+
device = init_device(device)
|
| 105 |
+
shape = (shape1,) * dim
|
| 106 |
+
vol, grid = make_data(shape, device, dtype)
|
| 107 |
+
vol.requires_grad = True
|
| 108 |
+
grid.requires_grad = True
|
| 109 |
+
assert gradcheck(grid_push, (vol, grid, shape, interpolation, bound, extrapolate),
|
| 110 |
+
**kwargs)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@pytest.mark.parametrize("device", devices)
|
| 114 |
+
@pytest.mark.parametrize("dim", dims)
|
| 115 |
+
# @pytest.mark.parametrize("bound", bounds)
|
| 116 |
+
# @pytest.mark.parametrize("interpolation", orders)
|
| 117 |
+
@pytest.mark.parametrize("interpolation,bound", order_bounds)
|
| 118 |
+
def test_gradcheck_count(device, dim, bound, interpolation):
|
| 119 |
+
print(f'count_{dim}d({interpolation}, {bound}) on {device}')
|
| 120 |
+
device = init_device(device)
|
| 121 |
+
shape = (shape1,) * dim
|
| 122 |
+
_, grid = make_data(shape, device, dtype)
|
| 123 |
+
grid.requires_grad = True
|
| 124 |
+
assert gradcheck(grid_count, (grid, shape, interpolation, bound, extrapolate),
|
| 125 |
+
**kwargs)
|
Generator/interpol/utils.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def fake_decorator(*a, **k):
|
| 5 |
+
if len(a) == 1 and not k:
|
| 6 |
+
return a[0]
|
| 7 |
+
else:
|
| 8 |
+
return fake_decorator
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_list(x, n=None, **kwargs):
|
| 12 |
+
"""Ensure that the input is a list (of a given size)
|
| 13 |
+
|
| 14 |
+
Parameters
|
| 15 |
+
----------
|
| 16 |
+
x : list or tuple or scalar
|
| 17 |
+
Input object
|
| 18 |
+
n : int, optional
|
| 19 |
+
Required length
|
| 20 |
+
default : scalar, optional
|
| 21 |
+
Value to right-pad with. Use last value of the input by default.
|
| 22 |
+
|
| 23 |
+
Returns
|
| 24 |
+
-------
|
| 25 |
+
x : list
|
| 26 |
+
"""
|
| 27 |
+
if not isinstance(x, (list, tuple)):
|
| 28 |
+
x = [x]
|
| 29 |
+
x = list(x)
|
| 30 |
+
if n and len(x) < n:
|
| 31 |
+
default = kwargs.get('default', x[-1])
|
| 32 |
+
x = x + [default] * max(0, n - len(x))
|
| 33 |
+
return x
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def expanded_shape(*shapes, side='left'):
|
| 37 |
+
"""Expand input shapes according to broadcasting rules
|
| 38 |
+
|
| 39 |
+
Parameters
|
| 40 |
+
----------
|
| 41 |
+
*shapes : sequence[int]
|
| 42 |
+
Input shapes
|
| 43 |
+
side : {'left', 'right'}, default='left'
|
| 44 |
+
Side to add singleton dimensions.
|
| 45 |
+
|
| 46 |
+
Returns
|
| 47 |
+
-------
|
| 48 |
+
shape : tuple[int]
|
| 49 |
+
Output shape
|
| 50 |
+
|
| 51 |
+
Raises
|
| 52 |
+
------
|
| 53 |
+
ValueError
|
| 54 |
+
If shapes are not compatible for broadcast.
|
| 55 |
+
|
| 56 |
+
"""
|
| 57 |
+
def error(s0, s1):
|
| 58 |
+
raise ValueError('Incompatible shapes for broadcasting: {} and {}.'
|
| 59 |
+
.format(s0, s1))
|
| 60 |
+
|
| 61 |
+
# 1. nb dimensions
|
| 62 |
+
nb_dim = 0
|
| 63 |
+
for shape in shapes:
|
| 64 |
+
nb_dim = max(nb_dim, len(shape))
|
| 65 |
+
|
| 66 |
+
# 2. enumerate
|
| 67 |
+
shape = [1] * nb_dim
|
| 68 |
+
for i, shape1 in enumerate(shapes):
|
| 69 |
+
pad_size = nb_dim - len(shape1)
|
| 70 |
+
ones = [1] * pad_size
|
| 71 |
+
if side == 'left':
|
| 72 |
+
shape1 = [*ones, *shape1]
|
| 73 |
+
else:
|
| 74 |
+
shape1 = [*shape1, *ones]
|
| 75 |
+
shape = [max(s0, s1) if s0 == 1 or s1 == 1 or s0 == s1
|
| 76 |
+
else error(s0, s1) for s0, s1 in zip(shape, shape1)]
|
| 77 |
+
|
| 78 |
+
return tuple(shape)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def matvec(mat, vec, out=None):
|
| 82 |
+
"""Matrix-vector product (supports broadcasting)
|
| 83 |
+
|
| 84 |
+
Parameters
|
| 85 |
+
----------
|
| 86 |
+
mat : (..., M, N) tensor
|
| 87 |
+
Input matrix.
|
| 88 |
+
vec : (..., N) tensor
|
| 89 |
+
Input vector.
|
| 90 |
+
out : (..., M) tensor, optional
|
| 91 |
+
Placeholder for the output tensor.
|
| 92 |
+
|
| 93 |
+
Returns
|
| 94 |
+
-------
|
| 95 |
+
mv : (..., M) tensor
|
| 96 |
+
Matrix vector product of the inputs
|
| 97 |
+
|
| 98 |
+
"""
|
| 99 |
+
vec = vec[..., None]
|
| 100 |
+
if out is not None:
|
| 101 |
+
out = out[..., None]
|
| 102 |
+
|
| 103 |
+
mv = torch.matmul(mat, vec, out=out)
|
| 104 |
+
mv = mv[..., 0]
|
| 105 |
+
if out is not None:
|
| 106 |
+
out = out[..., 0]
|
| 107 |
+
|
| 108 |
+
return mv
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _compare_versions(version1, mode, version2):
|
| 112 |
+
for v1, v2 in zip(version1, version2):
|
| 113 |
+
if mode in ('gt', '>'):
|
| 114 |
+
if v1 > v2:
|
| 115 |
+
return True
|
| 116 |
+
elif v1 < v2:
|
| 117 |
+
return False
|
| 118 |
+
elif mode in ('ge', '>='):
|
| 119 |
+
if v1 > v2:
|
| 120 |
+
return True
|
| 121 |
+
elif v1 < v2:
|
| 122 |
+
return False
|
| 123 |
+
elif mode in ('lt', '<'):
|
| 124 |
+
if v1 < v2:
|
| 125 |
+
return True
|
| 126 |
+
elif v1 > v2:
|
| 127 |
+
return False
|
| 128 |
+
elif mode in ('le', '<='):
|
| 129 |
+
if v1 < v2:
|
| 130 |
+
return True
|
| 131 |
+
elif v1 > v2:
|
| 132 |
+
return False
|
| 133 |
+
if mode in ('gt', 'lt', '>', '<'):
|
| 134 |
+
return False
|
| 135 |
+
else:
|
| 136 |
+
return True
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def torch_version(mode, version):
|
| 140 |
+
"""Check torch version
|
| 141 |
+
|
| 142 |
+
Parameters
|
| 143 |
+
----------
|
| 144 |
+
mode : {'<', '<=', '>', '>='}
|
| 145 |
+
version : tuple[int]
|
| 146 |
+
|
| 147 |
+
Returns
|
| 148 |
+
-------
|
| 149 |
+
True if "torch.version <mode> version"
|
| 150 |
+
|
| 151 |
+
"""
|
| 152 |
+
current_version, *cuda_variant = torch.__version__.split('+')
|
| 153 |
+
major, minor, patch, *_ = current_version.split('.')
|
| 154 |
+
# strip alpha tags
|
| 155 |
+
for x in 'abcdefghijklmnopqrstuvwxy':
|
| 156 |
+
if x in patch:
|
| 157 |
+
patch = patch[:patch.index(x)]
|
| 158 |
+
current_version = (int(major), int(minor), int(patch))
|
| 159 |
+
version = make_list(version)
|
| 160 |
+
return _compare_versions(current_version, mode, version)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
if torch_version('>=', (1, 10)):
|
| 164 |
+
meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij')
|
| 165 |
+
meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy')
|
| 166 |
+
else:
|
| 167 |
+
meshgrid_ij = lambda *x: torch.meshgrid(*x)
|
| 168 |
+
def meshgrid_xy(*x):
|
| 169 |
+
grid = list(torch.meshgrid(*x))
|
| 170 |
+
if len(grid) > 1:
|
| 171 |
+
grid[0] = grid[0].transpose(0, 1)
|
| 172 |
+
grid[1] = grid[1].transpose(0, 1)
|
| 173 |
+
return grid
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
meshgrid = meshgrid_ij
|
Generator/utils.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import nibabel as nib
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.nn.functional import conv3d
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
|
| 9 |
+
from scipy.io.matlab import loadmat
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import time, datetime
|
| 13 |
+
|
| 14 |
+
from ShapeID.DiffEqs.adjoint import odeint_adjoint as odeint
|
| 15 |
+
from ShapeID.perlin3d import generate_velocity_3d , generate_shape_3d
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ConcatDataset(Dataset):
|
| 19 |
+
def __init__(self,dataset_list, probs=None):
|
| 20 |
+
self.datasets = dataset_list
|
| 21 |
+
self.probs = probs if probs else [1/len(self.datasets)] * len(self.datasets)
|
| 22 |
+
|
| 23 |
+
def __getitem__(self, i):
|
| 24 |
+
chosen_dataset = np.random.choice(self.datasets, 1, p=self.probs)[0]
|
| 25 |
+
i = i % len(chosen_dataset)
|
| 26 |
+
return chosen_dataset[i]
|
| 27 |
+
|
| 28 |
+
def __len__(self):
|
| 29 |
+
return max(len(d) for d in self.datasets)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Prepare generator
|
| 34 |
+
def resolution_sampler(low_res_only = False):
|
| 35 |
+
|
| 36 |
+
if low_res_only:
|
| 37 |
+
r = (np.random.rand() * 0.5) + 0.5 # in [0.5, 1]
|
| 38 |
+
else:
|
| 39 |
+
r = np.random.rand() # in [0, 1]
|
| 40 |
+
|
| 41 |
+
if r < 0.25: # 1mm isotropic
|
| 42 |
+
resolution = np.array([1.0, 1.0, 1.0])
|
| 43 |
+
thickness = np.array([1.0, 1.0, 1.0])
|
| 44 |
+
elif r < 0.5: # clinical (low-res in one dimension)
|
| 45 |
+
resolution = np.array([1.0, 1.0, 1.0])
|
| 46 |
+
thickness = np.array([1.0, 1.0, 1.0])
|
| 47 |
+
idx = np.random.randint(3)
|
| 48 |
+
resolution[idx] = 2.5 + 6 * np.random.rand()
|
| 49 |
+
thickness[idx] = np.min([resolution[idx], 4.0 + 2.0 * np.random.rand()])
|
| 50 |
+
elif r < 0.75: # low-field: stock sequences (always axial)
|
| 51 |
+
resolution = np.array([1.3, 1.3, 4.8]) + 0.4 * np.random.rand(3)
|
| 52 |
+
thickness = resolution.copy()
|
| 53 |
+
else: # low-field: isotropic-ish (also good for scouts)
|
| 54 |
+
resolution = 2.0 + 3.0 * np.random.rand(3)
|
| 55 |
+
thickness = resolution.copy()
|
| 56 |
+
|
| 57 |
+
return resolution, thickness
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
#####################################
|
| 61 |
+
############ Utility Func ###########
|
| 62 |
+
#####################################
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def binarize(p, thres):
|
| 66 |
+
# TODO: what is the optimal thresholding strategy?
|
| 67 |
+
thres = thres * p.max()
|
| 68 |
+
|
| 69 |
+
bin = p.clone()
|
| 70 |
+
bin[p < thres] = 0.
|
| 71 |
+
bin[p >= thres] = 1.
|
| 72 |
+
return bin
|
| 73 |
+
|
| 74 |
+
def make_gaussian_kernel(sigma, device):
|
| 75 |
+
|
| 76 |
+
sl = int(np.ceil(3 * sigma))
|
| 77 |
+
ts = torch.linspace(-sl, sl, 2*sl+1, dtype=torch.float, device=device)
|
| 78 |
+
gauss = torch.exp((-(ts / sigma)**2 / 2))
|
| 79 |
+
kernel = gauss / gauss.sum()
|
| 80 |
+
|
| 81 |
+
return kernel
|
| 82 |
+
|
| 83 |
+
def gaussian_blur_3d(input, stds, device):
|
| 84 |
+
blurred = input[None, None, :, :, :]
|
| 85 |
+
if stds[0]>0:
|
| 86 |
+
kx = make_gaussian_kernel(stds[0], device=device)
|
| 87 |
+
blurred = conv3d(blurred, kx[None, None, :, None, None], stride=1, padding=(len(kx) // 2, 0, 0))
|
| 88 |
+
if stds[1]>0:
|
| 89 |
+
ky = make_gaussian_kernel(stds[1], device=device)
|
| 90 |
+
blurred = conv3d(blurred, ky[None, None, None, :, None], stride=1, padding=(0, len(ky) // 2, 0))
|
| 91 |
+
if stds[2]>0:
|
| 92 |
+
kz = make_gaussian_kernel(stds[2], device=device)
|
| 93 |
+
blurred = conv3d(blurred, kz[None, None, None, None, :], stride=1, padding=(0, 0, len(kz) // 2))
|
| 94 |
+
return torch.squeeze(blurred)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
#####################################
|
| 99 |
+
######### Deformation Func ##########
|
| 100 |
+
#####################################
|
| 101 |
+
|
| 102 |
+
def make_affine_matrix(rot, sh, s):
|
| 103 |
+
Rx = np.array([[1, 0, 0], [0, np.cos(rot[0]), -np.sin(rot[0])], [0, np.sin(rot[0]), np.cos(rot[0])]])
|
| 104 |
+
Ry = np.array([[np.cos(rot[1]), 0, np.sin(rot[1])], [0, 1, 0], [-np.sin(rot[1]), 0, np.cos(rot[1])]])
|
| 105 |
+
Rz = np.array([[np.cos(rot[2]), -np.sin(rot[2]), 0], [np.sin(rot[2]), np.cos(rot[2]), 0], [0, 0, 1]])
|
| 106 |
+
|
| 107 |
+
SHx = np.array([[1, 0, 0], [sh[1], 1, 0], [sh[2], 0, 1]])
|
| 108 |
+
SHy = np.array([[1, sh[0], 0], [0, 1, 0], [0, sh[2], 1]])
|
| 109 |
+
SHz = np.array([[1, 0, sh[0]], [0, 1, sh[1]], [0, 0, 1]])
|
| 110 |
+
|
| 111 |
+
A = SHx @ SHy @ SHz @ Rx @ Ry @ Rz
|
| 112 |
+
A[0, :] = A[0, :] * s[0]
|
| 113 |
+
A[1, :] = A[1, :] * s[1]
|
| 114 |
+
A[2, :] = A[2, :] * s[2]
|
| 115 |
+
|
| 116 |
+
return A
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def fast_3D_interp_torch(X, II, JJ, KK, mode='linear', default_value_linear=0.0):
|
| 120 |
+
|
| 121 |
+
if II is None:
|
| 122 |
+
return X
|
| 123 |
+
|
| 124 |
+
if mode=='nearest':
|
| 125 |
+
IIr = torch.round(II).long()
|
| 126 |
+
JJr = torch.round(JJ).long()
|
| 127 |
+
KKr = torch.round(KK).long()
|
| 128 |
+
IIr[IIr < 0] = 0
|
| 129 |
+
JJr[JJr < 0] = 0
|
| 130 |
+
KKr[KKr < 0] = 0
|
| 131 |
+
IIr[IIr > (X.shape[0] - 1)] = (X.shape[0] - 1)
|
| 132 |
+
JJr[JJr > (X.shape[1] - 1)] = (X.shape[1] - 1)
|
| 133 |
+
KKr[KKr > (X.shape[2] - 1)] = (X.shape[2] - 1)
|
| 134 |
+
if len(X.shape)==3:
|
| 135 |
+
X = X[..., None]
|
| 136 |
+
Y = X[IIr, JJr, KKr]
|
| 137 |
+
if Y.shape[3] == 1:
|
| 138 |
+
Y = Y[:, :, :, 0]
|
| 139 |
+
|
| 140 |
+
elif mode=='linear':
|
| 141 |
+
ok = (II>0) & (JJ>0) & (KK>0) & (II<=X.shape[0]-1) & (JJ<=X.shape[1]-1) & (KK<=X.shape[2]-1)
|
| 142 |
+
|
| 143 |
+
IIv = II[ok]
|
| 144 |
+
JJv = JJ[ok]
|
| 145 |
+
KKv = KK[ok]
|
| 146 |
+
|
| 147 |
+
fx = torch.floor(IIv).long()
|
| 148 |
+
cx = fx + 1
|
| 149 |
+
cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
|
| 150 |
+
wcx = (IIv - fx)[..., None]
|
| 151 |
+
wfx = 1 - wcx
|
| 152 |
+
|
| 153 |
+
fy = torch.floor(JJv).long()
|
| 154 |
+
cy = fy + 1
|
| 155 |
+
cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
|
| 156 |
+
wcy = (JJv - fy)[..., None]
|
| 157 |
+
wfy = 1 - wcy
|
| 158 |
+
|
| 159 |
+
fz = torch.floor(KKv).long()
|
| 160 |
+
cz = fz + 1
|
| 161 |
+
cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
|
| 162 |
+
wcz = (KKv - fz)[..., None]
|
| 163 |
+
wfz = 1 - wcz
|
| 164 |
+
|
| 165 |
+
if len(X.shape)==3:
|
| 166 |
+
X = X[..., None]
|
| 167 |
+
|
| 168 |
+
c000 = X[fx, fy, fz]
|
| 169 |
+
c100 = X[cx, fy, fz]
|
| 170 |
+
c010 = X[fx, cy, fz]
|
| 171 |
+
c110 = X[cx, cy, fz]
|
| 172 |
+
c001 = X[fx, fy, cz]
|
| 173 |
+
c101 = X[cx, fy, cz]
|
| 174 |
+
c011 = X[fx, cy, cz]
|
| 175 |
+
c111 = X[cx, cy, cz]
|
| 176 |
+
|
| 177 |
+
c00 = c000 * wfx + c100 * wcx
|
| 178 |
+
c01 = c001 * wfx + c101 * wcx
|
| 179 |
+
c10 = c010 * wfx + c110 * wcx
|
| 180 |
+
c11 = c011 * wfx + c111 * wcx
|
| 181 |
+
|
| 182 |
+
c0 = c00 * wfy + c10 * wcy
|
| 183 |
+
c1 = c01 * wfy + c11 * wcy
|
| 184 |
+
|
| 185 |
+
c = c0 * wfz + c1 * wcz
|
| 186 |
+
|
| 187 |
+
Y = torch.zeros([*II.shape, X.shape[3]], device=X.device)
|
| 188 |
+
Y[ok] = c.float()
|
| 189 |
+
Y[~ok] = default_value_linear
|
| 190 |
+
|
| 191 |
+
if Y.shape[-1]==1:
|
| 192 |
+
Y = Y[...,0]
|
| 193 |
+
else:
|
| 194 |
+
raise Exception('mode must be linear or nearest')
|
| 195 |
+
|
| 196 |
+
return Y
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def myzoom_torch(X, factor, aff=None):
|
| 201 |
+
|
| 202 |
+
if len(X.shape)==3:
|
| 203 |
+
X = X[..., None]
|
| 204 |
+
|
| 205 |
+
delta = (1.0 - factor) / (2.0 * factor)
|
| 206 |
+
newsize = np.round(X.shape[:-1] * factor).astype(int)
|
| 207 |
+
|
| 208 |
+
vx = torch.arange(delta[0], delta[0] + newsize[0] / factor[0], 1 / factor[0], dtype=torch.float, device=X.device)[:newsize[0]]
|
| 209 |
+
vy = torch.arange(delta[1], delta[1] + newsize[1] / factor[1], 1 / factor[1], dtype=torch.float, device=X.device)[:newsize[1]]
|
| 210 |
+
vz = torch.arange(delta[2], delta[2] + newsize[2] / factor[2], 1 / factor[2], dtype=torch.float, device=X.device)[:newsize[2]]
|
| 211 |
+
|
| 212 |
+
vx[vx < 0] = 0
|
| 213 |
+
vy[vy < 0] = 0
|
| 214 |
+
vz[vz < 0] = 0
|
| 215 |
+
vx[vx > (X.shape[0]-1)] = (X.shape[0]-1)
|
| 216 |
+
vy[vy > (X.shape[1] - 1)] = (X.shape[1] - 1)
|
| 217 |
+
vz[vz > (X.shape[2] - 1)] = (X.shape[2] - 1)
|
| 218 |
+
|
| 219 |
+
fx = torch.floor(vx).int()
|
| 220 |
+
cx = fx + 1
|
| 221 |
+
cx[cx > (X.shape[0]-1)] = (X.shape[0]-1)
|
| 222 |
+
wcx = (vx - fx)
|
| 223 |
+
wfx = 1 - wcx
|
| 224 |
+
|
| 225 |
+
fy = torch.floor(vy).int()
|
| 226 |
+
cy = fy + 1
|
| 227 |
+
cy[cy > (X.shape[1]-1)] = (X.shape[1]-1)
|
| 228 |
+
wcy = (vy - fy)
|
| 229 |
+
wfy = 1 - wcy
|
| 230 |
+
|
| 231 |
+
fz = torch.floor(vz).int()
|
| 232 |
+
cz = fz + 1
|
| 233 |
+
cz[cz > (X.shape[2]-1)] = (X.shape[2]-1)
|
| 234 |
+
wcz = (vz - fz)
|
| 235 |
+
wfz = 1 - wcz
|
| 236 |
+
|
| 237 |
+
Y = torch.zeros([newsize[0], newsize[1], newsize[2], X.shape[3]], dtype=torch.float, device=X.device)
|
| 238 |
+
|
| 239 |
+
tmp1 = torch.zeros([newsize[0], X.shape[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device)
|
| 240 |
+
for i in range(newsize[0]):
|
| 241 |
+
tmp1[i, :, :] = wfx[i] * X[fx[i], :, :] + wcx[i] * X[cx[i], :, :]
|
| 242 |
+
tmp2 = torch.zeros([newsize[0], newsize[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device)
|
| 243 |
+
for j in range(newsize[1]):
|
| 244 |
+
tmp2[:, j, :] = wfy[j] * tmp1[:, fy[j], :] + wcy[j] * tmp1[:, cy[j], :]
|
| 245 |
+
for k in range(newsize[2]):
|
| 246 |
+
Y[:, :, k] = wfz[k] * tmp2[:, :, fz[k]] + wcz[k] * tmp2[:, :, cz[k]]
|
| 247 |
+
|
| 248 |
+
if Y.shape[3] == 1:
|
| 249 |
+
Y = Y[:,:,:, 0]
|
| 250 |
+
|
| 251 |
+
if aff is not None:
|
| 252 |
+
aff_new = aff.copy()
|
| 253 |
+
aff_new[:-1] = aff_new[:-1] / factor
|
| 254 |
+
aff_new[:-1, -1] = aff_new[:-1, -1] - aff[:-1, :-1] @ (0.5 - 0.5 / (factor * np.ones(3)))
|
| 255 |
+
return Y, aff_new
|
| 256 |
+
else:
|
| 257 |
+
return Y
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
#####################################
|
| 263 |
+
############ Reading Func ###########
|
| 264 |
+
#####################################
|
| 265 |
+
|
| 266 |
+
def read_image(file_name):
|
| 267 |
+
img = nib.load(file_name)
|
| 268 |
+
aff = img.affine
|
| 269 |
+
res = np.sqrt(np.sum(abs(aff[:-1, :-1]), axis=0))
|
| 270 |
+
return img, aff, res
|
| 271 |
+
|
| 272 |
+
def deform_image(I, deform_dict, device, default_value_linear_mode=None, deform_mode = 'linear'):
|
| 273 |
+
if I is None:
|
| 274 |
+
return I
|
| 275 |
+
|
| 276 |
+
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']
|
| 277 |
+
|
| 278 |
+
if not isinstance(I, torch.Tensor):
|
| 279 |
+
I = torch.squeeze(torch.tensor(I.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=device))
|
| 280 |
+
else:
|
| 281 |
+
I = torch.squeeze(I[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=device)
|
| 282 |
+
I = torch.nan_to_num(I)
|
| 283 |
+
|
| 284 |
+
if default_value_linear_mode is not None:
|
| 285 |
+
if default_value_linear_mode == 'max':
|
| 286 |
+
default_value_linear = torch.max(I)
|
| 287 |
+
else:
|
| 288 |
+
raise ValueError('Not support default_value_linear_mode:', default_value_linear_mode)
|
| 289 |
+
else:
|
| 290 |
+
default_value_linear = 0.
|
| 291 |
+
Idef = fast_3D_interp_torch(I, xx2, yy2, zz2, deform_mode, default_value_linear)
|
| 292 |
+
|
| 293 |
+
return Idef
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def read_and_deform(file_name, dtype, deform_dict, device, mask, default_value_linear_mode=None, deform_mode = 'linear', mean = 0., scale = 1.):
|
| 297 |
+
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']
|
| 298 |
+
|
| 299 |
+
try:
|
| 300 |
+
Iimg = nib.load(file_name)
|
| 301 |
+
except:
|
| 302 |
+
Iimg = nib.load(file_name + '.gz')
|
| 303 |
+
res = np.sqrt(np.sum(abs(Iimg.affine[:-1, :-1]), axis=0))
|
| 304 |
+
I = torch.squeeze(torch.tensor(Iimg.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=dtype, device=device))
|
| 305 |
+
I = torch.nan_to_num(I)
|
| 306 |
+
|
| 307 |
+
I -= mean
|
| 308 |
+
I /= scale
|
| 309 |
+
|
| 310 |
+
if mask is not None:
|
| 311 |
+
I[mask == 0] = 0
|
| 312 |
+
|
| 313 |
+
if default_value_linear_mode is not None:
|
| 314 |
+
if default_value_linear_mode == 'max':
|
| 315 |
+
default_value_linear = torch.max(I)
|
| 316 |
+
else:
|
| 317 |
+
raise ValueError('Not support default_value_linear_mode:', default_value_linear_mode)
|
| 318 |
+
else:
|
| 319 |
+
default_value_linear = 0.
|
| 320 |
+
Idef = fast_3D_interp_torch(I, xx2, yy2, zz2, deform_mode, default_value_linear)
|
| 321 |
+
return Idef, res
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def read_and_deform_image(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
|
| 325 |
+
Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask)
|
| 326 |
+
Idef -= torch.min(Idef)
|
| 327 |
+
Idef /= torch.max(Idef)
|
| 328 |
+
if setups['flip']:
|
| 329 |
+
Idef = torch.flip(Idef, [0])
|
| 330 |
+
update_dict = {task_name: Idef[None]}
|
| 331 |
+
|
| 332 |
+
if os.path.isfile(file_name[:-4] + '.defacingmask.nii'):
|
| 333 |
+
Idef_DM, _ = read_and_deform(file_name[:-4] + '.defacingmask.nii', torch.float, deform_dict, device, mask)
|
| 334 |
+
Idef_DM = torch.clamp(Idef_DM, min = 0.)
|
| 335 |
+
Idef_DM /= torch.max(Idef_DM)
|
| 336 |
+
if setups['flip']:
|
| 337 |
+
Idef = torch.flip(Idef_DM, [0])
|
| 338 |
+
update_dict.update({task_name + '_DM': Idef_DM[None]})
|
| 339 |
+
#if not 'brain_mask' in exist_keys:
|
| 340 |
+
# mask = torch.ones_like(Idef)
|
| 341 |
+
# mask[Idef <= 0.] = 0.
|
| 342 |
+
# update_dict.update({'brain_mask': mask[None]})
|
| 343 |
+
return update_dict
|
| 344 |
+
|
| 345 |
+
def read_and_deform_CT(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
|
| 346 |
+
Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask, scale = 1000)
|
| 347 |
+
#Idef = torch.clamp(Idef, min = 0., max = 80.) # No clamping for inference/GT
|
| 348 |
+
#Idef /= torch.max(Idef)
|
| 349 |
+
if setups['flip']:
|
| 350 |
+
Idef = torch.flip(Idef, [0])
|
| 351 |
+
update_dict = {'CT': Idef[None]}
|
| 352 |
+
|
| 353 |
+
if os.path.isfile(file_name[:-4] + '.defacingmask.nii'):
|
| 354 |
+
Idef_DM, _ = read_and_deform(file_name[:-4] + '.defacingmask.nii', torch.float, deform_dict, device, mask)
|
| 355 |
+
Idef_DM = torch.clamp(Idef_DM, min = 0.)
|
| 356 |
+
Idef_DM /= torch.max(Idef_DM)
|
| 357 |
+
if setups['flip']:
|
| 358 |
+
Idef = torch.flip(Idef_DM, [0])
|
| 359 |
+
update_dict.update({task_name + '_DM': Idef_DM[None]})
|
| 360 |
+
#if not 'brain_mask' in exist_keys:
|
| 361 |
+
# mask = torch.ones_like(Idef)
|
| 362 |
+
# mask[Idef <= 0.] = 0.
|
| 363 |
+
# update_dict.update({'brain_mask': mask[None]})
|
| 364 |
+
return update_dict
|
| 365 |
+
|
| 366 |
+
def read_and_deform_distance(exist_keys, task_name, file_names, setups, deform_dict, device, mask, cfg, **kwargs):
|
| 367 |
+
[lp_dist_map, lw_dist_map, rp_dist_map, rw_dist_map] = file_names
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
lp, _ = read_and_deform(lp_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)
|
| 371 |
+
lw, _ = read_and_deform(lw_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)
|
| 372 |
+
|
| 373 |
+
if mask is not None: # left_hemis_only
|
| 374 |
+
Idef = torch.stack([lp, lw], dim = 0)
|
| 375 |
+
else:
|
| 376 |
+
rp, _ = read_and_deform(rp_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)
|
| 377 |
+
rw, _ = read_and_deform(rw_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)
|
| 378 |
+
|
| 379 |
+
if setups['flip']:
|
| 380 |
+
aux = torch.flip(lp, [0])
|
| 381 |
+
lp = torch.flip(rp, [0])
|
| 382 |
+
rp = aux
|
| 383 |
+
aux = torch.flip(lw, [0])
|
| 384 |
+
lw = torch.flip(rw, [0])
|
| 385 |
+
rw = aux
|
| 386 |
+
|
| 387 |
+
Idef = torch.stack([lp, lw, rp, rw], dim = 0)
|
| 388 |
+
|
| 389 |
+
Idef /= deform_dict['scaling_factor_distances']
|
| 390 |
+
Idef = torch.clamp(Idef, min=-cfg.max_surf_distance, max=cfg.max_surf_distance)
|
| 391 |
+
|
| 392 |
+
return {'distance': Idef}
|
| 393 |
+
|
| 394 |
+
def read_and_deform_segmentation(exist_keys, task_name, file_name, setups, deform_dict, device, mask, cfg, onehotmatrix, lut, vflip, **kwargs):
|
| 395 |
+
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']
|
| 396 |
+
|
| 397 |
+
Simg = nib.load(file_name)
|
| 398 |
+
S = torch.squeeze(torch.tensor(Simg.get_fdata()[x1:x2, y1:y2, z1:z2].astype(int), dtype=torch.int, device=device))
|
| 399 |
+
|
| 400 |
+
if mask is not None:
|
| 401 |
+
S[mask == 0] = 0
|
| 402 |
+
|
| 403 |
+
Sdef = fast_3D_interp_torch(S, xx2, yy2, zz2, 'nearest')
|
| 404 |
+
if cfg.generator.deform_one_hots:
|
| 405 |
+
Sonehot = onehotmatrix[lut[S.long()]]
|
| 406 |
+
Sdef_OneHot = fast_3D_interp_torch(Sonehot, xx2, yy2, zz2)
|
| 407 |
+
else:
|
| 408 |
+
Sdef_OneHot = onehotmatrix[lut[Sdef.long()]]
|
| 409 |
+
|
| 410 |
+
if setups['flip']:
|
| 411 |
+
#Sdef = torch.flip(Sdef, [0])
|
| 412 |
+
Sdef_OneHot = torch.flip(Sdef_OneHot, [0])[:, :, :, vflip]
|
| 413 |
+
|
| 414 |
+
# prepare for input
|
| 415 |
+
Sdef_OneHot = Sdef_OneHot.permute([3, 0, 1, 2])
|
| 416 |
+
|
| 417 |
+
#update_dict = {'label': Sdef[None], 'segmentation': Sdef_OneHot}
|
| 418 |
+
update_dict = {'segmentation': Sdef_OneHot}
|
| 419 |
+
|
| 420 |
+
#if not 'brain_mask' in exist_keys:
|
| 421 |
+
# mask = torch.ones_like(Sdef)
|
| 422 |
+
# mask[Sdef <= 0.] = 0.
|
| 423 |
+
# update_dict.update({'brain_mask': mask[None]})
|
| 424 |
+
return update_dict
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def read_and_deform_pathology(exist_keys, task_name, file_name, setups, deform_dict, device, mask = None,
|
| 429 |
+
augment = False, pde_func = None, t = None,
|
| 430 |
+
shape_gen_args = None, thres = 0., **kwargs):
|
| 431 |
+
# NOTE does not support left_hemis for now
|
| 432 |
+
|
| 433 |
+
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']
|
| 434 |
+
|
| 435 |
+
if file_name is None:
|
| 436 |
+
return {'pathology': torch.zeros(xx2.shape)[None].to(device), 'pathology_prob': torch.zeros(xx2.shape)[None].to(device)}
|
| 437 |
+
|
| 438 |
+
if file_name == 'random_shape': # generate random shape
|
| 439 |
+
percentile = np.random.uniform(shape_gen_args.mask_percentile_min, shape_gen_args.mask_percentile_max)
|
| 440 |
+
_, Pdef = generate_shape_3d(xx2.shape, shape_gen_args.perlin_res, percentile, device)
|
| 441 |
+
else: # read from existing shape
|
| 442 |
+
Pdef, _ = read_and_deform(file_name, torch.float, deform_dict, device)
|
| 443 |
+
|
| 444 |
+
if augment:
|
| 445 |
+
Pdef = augment_pathology(Pdef, pde_func, t, shape_gen_args, device)
|
| 446 |
+
|
| 447 |
+
#if setups['flip']: # flipping should happen after P has been encoded
|
| 448 |
+
# Pdef = torch.flip(Pdef, [0])
|
| 449 |
+
|
| 450 |
+
P = binarize(Pdef, thres)
|
| 451 |
+
if P.mean() <= shape_gen_args.pathol_tol:
|
| 452 |
+
return {'pathology': torch.zeros(xx2.shape)[None].to(device), 'pathology_prob': torch.zeros(xx2.shape)[None].to(device)}
|
| 453 |
+
#print('process', P.mean(), shape_gen_args.pathol_tol)
|
| 454 |
+
|
| 455 |
+
return {'pathology': P[None], 'pathology_prob': Pdef[None]}
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def read_and_deform_registration(exist_keys, task_name, file_names, setups, deform_dict, device, mask, **kwargs):
|
| 459 |
+
[mni_reg_x, mni_reg_y, mni_reg_z] = file_names
|
| 460 |
+
regx, _ = read_and_deform(mni_reg_x, torch.float, deform_dict, device, mask, scale = 10000)
|
| 461 |
+
regy, _ = read_and_deform(mni_reg_y, torch.float, deform_dict, device, mask, scale = 10000)
|
| 462 |
+
regz, _ = read_and_deform(mni_reg_z, torch.float, deform_dict, device, mask, scale = 10000)
|
| 463 |
+
|
| 464 |
+
if setups['flip']:
|
| 465 |
+
regx = -torch.flip(regx, [0]) # NOTE: careful with switching sign
|
| 466 |
+
regy = torch.flip(regy, [0])
|
| 467 |
+
regz = torch.flip(regz, [0])
|
| 468 |
+
|
| 469 |
+
Idef = torch.stack([regx, regy, regz], dim = 0)
|
| 470 |
+
|
| 471 |
+
return {'registration': Idef}
|
| 472 |
+
|
| 473 |
+
def read_and_deform_bias_field(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
|
| 474 |
+
Idef, _ = read_and_deform(file_name, torch.float, deform_dict, mask, device)
|
| 475 |
+
if setups['flip']:
|
| 476 |
+
Idef = torch.flip(Idef, [0])
|
| 477 |
+
return {'bias_field': Idef[None]}
|
| 478 |
+
|
| 479 |
+
def read_and_deform_surface(exist_keys, task_name, file_name, setups, deform_dict, device, mask, size):
|
| 480 |
+
Fneg, A, c2 = deform_dict['Fneg'], deform_dict['A'], deform_dict['c2']
|
| 481 |
+
# NOTE does not support left_hemis for now
|
| 482 |
+
|
| 483 |
+
mat = loadmat(file_name.split('.nii')[0] + '.mat')
|
| 484 |
+
|
| 485 |
+
Vlw = torch.tensor(mat['Vlw'], dtype=torch.float, device=device)
|
| 486 |
+
Flw = torch.tensor(mat['Flw'], dtype=torch.int, device=device)
|
| 487 |
+
Vrw = torch.tensor(mat['Vrw'], dtype=torch.float, device=device)
|
| 488 |
+
Frw = torch.tensor(mat['Frw'], dtype=torch.int, device=device)
|
| 489 |
+
Vlp = torch.tensor(mat['Vlp'], dtype=torch.float, device=device)
|
| 490 |
+
Flp = torch.tensor(mat['Flp'], dtype=torch.int, device=device)
|
| 491 |
+
Vrp = torch.tensor(mat['Vrp'], dtype=torch.float, device=device)
|
| 492 |
+
Frp = torch.tensor(mat['Frp'], dtype=torch.int, device=device)
|
| 493 |
+
|
| 494 |
+
Ainv = torch.inverse(A)
|
| 495 |
+
Vlw -= c2[None, :]
|
| 496 |
+
Vlw = Vlw @ torch.transpose(Ainv, 0, 1)
|
| 497 |
+
Vlw += fast_3D_interp_torch(Fneg, Vlw[:, 0] + c2[0], Vlw[:, 1]+c2[1], Vlw[:, 2] + c2[2])
|
| 498 |
+
Vlw += c2[None, :]
|
| 499 |
+
Vrw -= c2[None, :]
|
| 500 |
+
Vrw = Vrw @ torch.transpose(Ainv, 0, 1)
|
| 501 |
+
Vrw += fast_3D_interp_torch(Fneg, Vrw[:, 0] + c2[0], Vrw[:, 1]+c2[1], Vrw[:, 2] + c2[2])
|
| 502 |
+
Vrw += c2[None, :]
|
| 503 |
+
Vlp -= c2[None, :]
|
| 504 |
+
Vlp = Vlp @ torch.transpose(Ainv, 0, 1)
|
| 505 |
+
Vlp += fast_3D_interp_torch(Fneg, Vlp[:, 0] + c2[0], Vlp[:, 1] + c2[1], Vlp[:, 2] + c2[2])
|
| 506 |
+
Vlp += c2[None, :]
|
| 507 |
+
Vrp -= c2[None, :]
|
| 508 |
+
Vrp = Vrp @ torch.transpose(Ainv, 0, 1)
|
| 509 |
+
Vrp += fast_3D_interp_torch(Fneg, Vrp[:, 0] + c2[0], Vrp[:, 1] + c2[1], Vrp[:, 2] + c2[2])
|
| 510 |
+
Vrp += c2[None, :]
|
| 511 |
+
|
| 512 |
+
if setups['flip']:
|
| 513 |
+
Vlw[:, 0] = size[0] - 1 - Vlw[:, 0]
|
| 514 |
+
Vrw[:, 0] = size[0] - 1 - Vrw[:, 0]
|
| 515 |
+
Vlp[:, 0] = size[0] - 1 - Vlp[:, 0]
|
| 516 |
+
Vrp[:, 0] = size[0] - 1 - Vrp[:, 0]
|
| 517 |
+
Vlw, Vrw = Vrw, Vlw
|
| 518 |
+
Vlp, Vrp = Vrp, Vlp
|
| 519 |
+
Flw, Frw = Frw, Flw
|
| 520 |
+
Flp, Frp = Frp, Flp
|
| 521 |
+
|
| 522 |
+
print(Vlw.shape) # 131148
|
| 523 |
+
print(Vlp.shape) # 131148
|
| 524 |
+
|
| 525 |
+
print(Vrw.shape) # 131720
|
| 526 |
+
print(Vrp.shape) # 131720
|
| 527 |
+
|
| 528 |
+
print(Flw.shape) # 262292
|
| 529 |
+
print(Flp.shape) # 262292
|
| 530 |
+
|
| 531 |
+
print(Frw.shape) # 263436
|
| 532 |
+
print(Frp.shape) # 263436
|
| 533 |
+
#return torch.stack([Vlw, Flw, Vrw, Frw, Vlp, Flp, Vrp, Frp])
|
| 534 |
+
return {'Vlw': Vlw, 'Flw': Flw, 'Vrw': Vrw, 'Frw': Frw, 'Vlp': Vlp, 'Flp': Flp, 'Vrp': Vrp, 'Frp': Frp}
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
#####################################
|
| 538 |
+
######### Pathology Shape #########
|
| 539 |
+
#####################################
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def augment_pathology(Pprob, pde_func, t, shape_gen_args, device):
|
| 543 |
+
Pprob = torch.squeeze(Pprob)
|
| 544 |
+
|
| 545 |
+
nt = np.random.randint(1, shape_gen_args.max_nt+1)
|
| 546 |
+
if nt <= 1:
|
| 547 |
+
return Pprob
|
| 548 |
+
|
| 549 |
+
pde_func.V_dict = generate_velocity_3d(Pprob.shape, shape_gen_args.perlin_res, shape_gen_args.V_multiplier, device)
|
| 550 |
+
|
| 551 |
+
#start_time = time.time()
|
| 552 |
+
Pprob = odeint(pde_func, Pprob[None], t[:nt],
|
| 553 |
+
shape_gen_args.dt,
|
| 554 |
+
method = shape_gen_args.integ_method)[-1, 0] # (last_t, n_batch=1, s, r, c)
|
| 555 |
+
# total_time = time.time() - start_time
|
| 556 |
+
#total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 557 |
+
#print('Time {} for {} time points'.format(total_time_str, nt))
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
return Pprob
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
#####################################
|
| 564 |
+
######### Augmentation Func #########
|
| 565 |
+
#####################################
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def add_gamma_transform(I, aux_dict, cfg, device, **kwargs):
|
| 569 |
+
gamma = torch.tensor(np.exp(cfg.gamma_std * np.random.randn(1)[0]), dtype=float, device=device)
|
| 570 |
+
I_gamma = 300.0 * (I / 300.0) ** gamma
|
| 571 |
+
#aux_dict.update({'gamma': gamma}) # uncomment if you want to save gamma for later use
|
| 572 |
+
return I_gamma, aux_dict
|
| 573 |
+
|
| 574 |
+
def add_bias_field(I, aux_dict, cfg, input_mode, setups, size, device, **kwargs):
|
| 575 |
+
if input_mode == 'CT':
|
| 576 |
+
aux_dict.update({'high_res': I})
|
| 577 |
+
return I, aux_dict
|
| 578 |
+
|
| 579 |
+
bf_scale = cfg.bf_scale_min + np.random.rand(1) * (cfg.bf_scale_max - cfg.bf_scale_min)
|
| 580 |
+
size_BF_small = np.round(bf_scale * np.array(size)).astype(int).tolist()
|
| 581 |
+
if setups['photo_mode']:
|
| 582 |
+
size_BF_small[1] = np.round(size[1]/setups['spac']).astype(int)
|
| 583 |
+
BFsmall = torch.tensor(cfg.bf_std_min + (cfg.bf_std_max - cfg.bf_std_min) * np.random.rand(1), dtype=torch.float, device=device) * \
|
| 584 |
+
torch.randn(size_BF_small, dtype=torch.float, device=device)
|
| 585 |
+
BFlog = myzoom_torch(BFsmall, np.array(size) / size_BF_small)
|
| 586 |
+
BF = torch.exp(BFlog)
|
| 587 |
+
I_bf = I * BF
|
| 588 |
+
aux_dict.update({'BFlog': BFlog, 'high_res': I_bf})
|
| 589 |
+
return I_bf, aux_dict
|
| 590 |
+
|
| 591 |
+
def resample_resolution(I, aux_dict, setups, res, size, device, **kwargs):
|
| 592 |
+
stds = (0.85 + 0.3 * np.random.rand()) * np.log(5) /np.pi * setups['thickness'] / res
|
| 593 |
+
stds[setups['thickness']<=res] = 0.0 # no blur if thickness is equal to the resolution of the training data
|
| 594 |
+
I_blur = gaussian_blur_3d(I, stds, device)
|
| 595 |
+
new_size = (np.array(size) * res / setups['resolution']).astype(int)
|
| 596 |
+
|
| 597 |
+
factors = np.array(new_size) / np.array(size)
|
| 598 |
+
delta = (1.0 - factors) / (2.0 * factors)
|
| 599 |
+
vx = np.arange(delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0])[:new_size[0]]
|
| 600 |
+
vy = np.arange(delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1])[:new_size[1]]
|
| 601 |
+
vz = np.arange(delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2])[:new_size[2]]
|
| 602 |
+
II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing='ij')
|
| 603 |
+
II = torch.tensor(II, dtype=torch.float, device=device)
|
| 604 |
+
JJ = torch.tensor(JJ, dtype=torch.float, device=device)
|
| 605 |
+
KK = torch.tensor(KK, dtype=torch.float, device=device)
|
| 606 |
+
|
| 607 |
+
I_small = fast_3D_interp_torch(I_blur, II, JJ, KK)
|
| 608 |
+
aux_dict.update({'factors': factors})
|
| 609 |
+
return I_small, aux_dict
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def resample_resolution_photo(I, aux_dict, setups, res, size, device, **kwargs):
|
| 613 |
+
stds = (0.85 + 0.3 * np.random.rand()) * np.log(5) /np.pi * setups['thickness'] / res
|
| 614 |
+
stds[setups['thickness']<=res] = 0.0 # no blur if thickness is equal to the resolution of the training data
|
| 615 |
+
I_blur = gaussian_blur_3d(I, stds, device)
|
| 616 |
+
new_size = (np.array(size) * res / setups['resolution']).astype(int)
|
| 617 |
+
|
| 618 |
+
factors = np.array(new_size) / np.array(size)
|
| 619 |
+
delta = (1.0 - factors) / (2.0 * factors)
|
| 620 |
+
vx = np.arange(delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0])[:new_size[0]]
|
| 621 |
+
vy = np.arange(delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1])[:new_size[1]]
|
| 622 |
+
vz = np.arange(delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2])[:new_size[2]]
|
| 623 |
+
II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing='ij')
|
| 624 |
+
II = torch.tensor(II, dtype=torch.float, device=device)
|
| 625 |
+
JJ = torch.tensor(JJ, dtype=torch.float, device=device)
|
| 626 |
+
KK = torch.tensor(KK, dtype=torch.float, device=device)
|
| 627 |
+
|
| 628 |
+
I_small = fast_3D_interp_torch(I_blur, II, JJ, KK)
|
| 629 |
+
aux_dict.update({'factors': factors})
|
| 630 |
+
return I_small, aux_dict
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def add_noise(I, aux_dict, cfg, device, **kwargs):
|
| 634 |
+
noise_std = torch.tensor(cfg.noise_std_min + (cfg.noise_std_max - cfg.noise_std_min) * np.random.rand(1), dtype=torch.float, device=device)
|
| 635 |
+
I_noisy = I + noise_std * torch.randn(I.shape, dtype=torch.float, device=device)
|
| 636 |
+
I_noisy[I_noisy < 0] = 0
|
| 637 |
+
#aux_dict.update({'noise_std': noise_std}) # uncomment if you want to save noise_std for later use
|
| 638 |
+
return I_noisy, aux_dict
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
#####################################
|
| 642 |
+
#####################################
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
# map SynthSeg right to left labels for contrast synthesis
|
| 646 |
+
right_to_left_dict = {
|
| 647 |
+
41: 2,
|
| 648 |
+
42: 3,
|
| 649 |
+
43: 4,
|
| 650 |
+
44: 5,
|
| 651 |
+
46: 7,
|
| 652 |
+
47: 8,
|
| 653 |
+
49: 10,
|
| 654 |
+
50: 11,
|
| 655 |
+
51: 12,
|
| 656 |
+
52: 13,
|
| 657 |
+
53: 17,
|
| 658 |
+
54: 18,
|
| 659 |
+
58: 26,
|
| 660 |
+
60: 28
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
# based on merged left & right SynthSeg labels
|
| 664 |
+
ct_brightness_group = {
|
| 665 |
+
'darker': [4, 5, 14, 15, 24, 31, 72], # ventricles, CSF
|
| 666 |
+
'dark': [2, 7, 16, 77, 30], # white matter
|
| 667 |
+
'bright': [3, 8, 17, 18, 28, 10, 11, 12, 13, 26], # grey matter (cortex, hippocampus, amggdala, ventral DC), thalamus, ganglia (nucleus (putamen, pallidus, accumbens), caudate)
|
| 668 |
+
'brighter': [], # skull, pineal gland, choroid plexus
|
| 669 |
+
}
|
README.md
CHANGED
|
@@ -1,3 +1,91 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
## <p align="center">[A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging](https://arxiv.org/abs/2509.00549)</p>
|
| 3 |
+
|
| 4 |
+
**<p align="center">Peirong Liu<sup>1,2</sup>, Oula Puonti<sup>2</sup>, Xiaoling Hu<sup>2</sup>, Karthik Gopinath<sup>2</sup>, Annabel Sorby-Adams<sup>2</sup>, Daniel C. Alexander<sup>3</sup>, Juan Eugenio Iglesias<sup>2,3,4</sup></p>**
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<sup>1</sup>Johns Hopkins University<br />
|
| 9 |
+
<sup>2</sup>Harvard Medical School and Massachusetts General Hospital<br />
|
| 10 |
+
<sup>3</sup>University College London <br />
|
| 11 |
+
<sup>4</sup>Massachusetts Institute of Technology
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
<img src="./assets/overview.png" alt="drawing", width="650"/>
|
| 16 |
+
</p>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
This is the official repository for our preprint: A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging [[arXiv]](https://arxiv.org/abs/2509.00549)<br />
|
| 20 |
+
More detailed and organized instructions are coming soon...
|
| 21 |
+
|
| 22 |
+
## Environment
|
| 23 |
+
Training and evaluation environment: Python 3.11.4, PyTorch 2.0.1, CUDA 12.2. Run the following command to install required packages.
|
| 24 |
+
```
|
| 25 |
+
conda create -n pre python=3.11
|
| 26 |
+
conda activate pre
|
| 27 |
+
|
| 28 |
+
git clone https://github.com/jhuldr/BrainFM
|
| 29 |
+
cd /path/to/brainfm
|
| 30 |
+
pip install -r requirements.txt
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## Generator
|
| 35 |
+
```
|
| 36 |
+
cd scripts
|
| 37 |
+
python demo_generator.py
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
### Generator setups
|
| 41 |
+
Setups are in cfgs/generator, default setups are in default.yaml. A customized setup example can be found in train/brain_id.yaml, where several Brain-ID-specific setups are added. During Config reading/implementation, customized yaml will overwrite default.yaml if they have the same keys.
|
| 42 |
+
|
| 43 |
+
dataset_setups: information for all datasets, in Generator/constants.py<br>
|
| 44 |
+
augmentation_funcs: augmentation functions and steps, in Generator/constants.py<br>
|
| 45 |
+
processing_funcs: image processing functions for each modality/task, in Generator/constants.py<br>
|
| 46 |
+
|
| 47 |
+
dataset_names: dataset name list, paths setups in Generator/constants.py<br>
|
| 48 |
+
mix_synth_prob: if the input mode is synthesizing, then probability for blending synth with real images<br>
|
| 49 |
+
dataset_option: generator types, could be BaseGen or customized generator<br>
|
| 50 |
+
task: switch on/off individual training tasks
|
| 51 |
+
|
| 52 |
+
### Base generator module
|
| 53 |
+
```
|
| 54 |
+
cd Generator
|
| 55 |
+
python datasets.py
|
| 56 |
+
```
|
| 57 |
+
The dataset paths setups are in constants.py. In datasets.py, different datasets been used are fomulated as a list of dataset names.
|
| 58 |
+
|
| 59 |
+
A customized data generator module example can be found in datasets.py -- BrainIDGen.
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
Refer to "__getitem__" function. Specifically, it includes: <br>
|
| 63 |
+
(1) read original input: could be either generation labels or real images;<br>
|
| 64 |
+
(2) generate augmentation setups and deformation fields; <br>
|
| 65 |
+
(3) read target(s) according to the assigned tasks -- here I seperate the processing functions for each item/modality, in case we want different processing steps for them; <br>
|
| 66 |
+
(4) augment input sample: either synthesized or real image input.
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
(Some of the functions are leaved blank for now.)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
## Trainer
|
| 75 |
+
```
|
| 76 |
+
cd scripts
|
| 77 |
+
python train.py
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Downloads
|
| 81 |
+
The pre-trained model weight is available on [OneDrive](https://livejohnshopkins-my.sharepoint.com/:u:/g/personal/pliu53_jh_edu/EZ_BJ7K6pMJEj9hZ8SA51GYBxH_Nan4fA3a-s4udwvVRog?e=nwZ7JC).
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
## Citation
|
| 85 |
+
```bibtex
|
| 86 |
+
@article{Liu_2025_BrainFM,
|
| 87 |
+
author = {Liu, Peirong and Puonti, Oula and Hu, Xiaoling and Gopinath, Karthik and Sorby-Adams, Annabel and Alexander, Daniel C. and Iglesias, Juan E.},
|
| 88 |
+
title = {A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging},
|
| 89 |
+
booktitle = {arXiv preprint arXiv:2509.00549},
|
| 90 |
+
year = {2025},
|
| 91 |
+
}
|
ShapeID/DiffEqs/FD.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
*finite_difference.py* is the main package to compute finite differences in
|
| 3 |
+
1D, 2D, and 3D on numpy arrays (class FD_np) and pytorch tensors (class FD_torch).
|
| 4 |
+
The package supports first and second order derivatives and Neumann and linear extrapolation
|
| 5 |
+
boundary conditions (though the latter have not been tested extensively yet).
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import absolute_import
|
| 8 |
+
|
| 9 |
+
# from builtins import object
|
| 10 |
+
from abc import ABCMeta, abstractmethod
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.autograd import Variable
|
| 14 |
+
import numpy as np
|
| 15 |
+
from future.utils import with_metaclass
|
| 16 |
+
|
| 17 |
+
class FD(with_metaclass(ABCMeta, object)):
|
| 18 |
+
"""
|
| 19 |
+
*FD* is the abstract class for finite differences. It includes most of the actual finite difference code,
|
| 20 |
+
but requires the definition (in a derived class) of the methods *get_dimension*, *create_zero_array*, and *get_size_of_array*.
|
| 21 |
+
In this way the numpy and pytorch versions can easily be derived. All the method expect BxXxYxZ format (i.e., they process a batch at a time)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, spacing, bcNeumannZero=True):
|
| 25 |
+
"""
|
| 26 |
+
Constructor
|
| 27 |
+
:param spacing: 1D numpy array defining the spatial spacing, e.g., [0.1,0.1,0.1] for a 3D image
|
| 28 |
+
:param bcNeumannZero: Defines the boundary condition. If set to *True* (default) zero Neumann boundary conditions
|
| 29 |
+
are imposed. If set to *False* linear extrapolation is used (this is still experimental, but may be beneficial
|
| 30 |
+
for better boundary behavior)
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
self.dim = len(spacing) # In my code, data_spacing is a list # spacing.size
|
| 34 |
+
"""spatial dimension"""
|
| 35 |
+
self.spacing = np.ones(self.dim)
|
| 36 |
+
"""spacing"""
|
| 37 |
+
self.bcNeumannZero = bcNeumannZero # if false then linear interpolation
|
| 38 |
+
"""should Neumann boundary conditions be used? (otherwise linear extrapolation)"""
|
| 39 |
+
if len(spacing) == 1: #spacing.size==1:
|
| 40 |
+
self.spacing[0] = spacing[0]
|
| 41 |
+
elif len(spacing) == 2: # spacing.size==2:
|
| 42 |
+
self.spacing[0] = spacing[0]
|
| 43 |
+
self.spacing[1] = spacing[1]
|
| 44 |
+
elif len(spacing) == 3: # spacing.size==3:
|
| 45 |
+
self.spacing[0] = spacing[0]
|
| 46 |
+
self.spacing[1] = spacing[1]
|
| 47 |
+
self.spacing[2] = spacing[2]
|
| 48 |
+
else:
|
| 49 |
+
print('Current dimension:', len(spacing))
|
| 50 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 51 |
+
|
| 52 |
+
def dXb(self,I):
|
| 53 |
+
"""
|
| 54 |
+
Backward difference in x direction:
|
| 55 |
+
:math:`\\frac{dI(i)}{dx}\\approx\\frac{I_i-I_{i-1}}{h_x}`
|
| 56 |
+
:param I: Input image
|
| 57 |
+
:return: Returns the first derivative in x direction using backward differences
|
| 58 |
+
"""
|
| 59 |
+
return (I-self.xm(I))/self.spacing[0]
|
| 60 |
+
|
| 61 |
+
def dXf(self,I):
|
| 62 |
+
"""
|
| 63 |
+
Forward difference in x direction:
|
| 64 |
+
:math:`\\frac{dI(i)}{dx}\\approx\\frac{I_{i+1}-I_{i}}{h_x}`
|
| 65 |
+
|
| 66 |
+
:param I: Input image
|
| 67 |
+
:return: Returns the first derivative in x direction using forward differences
|
| 68 |
+
"""
|
| 69 |
+
return (self.xp(I)-I)/self.spacing[0]
|
| 70 |
+
|
| 71 |
+
def dXc(self,I):
|
| 72 |
+
"""
|
| 73 |
+
Central difference in x direction:
|
| 74 |
+
:math:`\\frac{dI(i)}{dx}\\approx\\frac{I_{i+1}-I_{i-1}}{2h_x}`
|
| 75 |
+
|
| 76 |
+
:param I: Input image
|
| 77 |
+
:return: Returns the first derivative in x direction using central differences
|
| 78 |
+
"""
|
| 79 |
+
return (self.xp(I)-self.xm(I))/(2*self.spacing[0])
|
| 80 |
+
|
| 81 |
+
def ddXc(self,I):
|
| 82 |
+
"""
|
| 83 |
+
Second deriative in x direction
|
| 84 |
+
|
| 85 |
+
:param I: Input image
|
| 86 |
+
:return: Returns the second derivative in x direction
|
| 87 |
+
"""
|
| 88 |
+
return (self.xp(I)-2*I+self.xm(I))/(self.spacing[0]**2)
|
| 89 |
+
|
| 90 |
+
def dYb(self,I):
|
| 91 |
+
"""
|
| 92 |
+
Same as dXb, but for the y direction
|
| 93 |
+
|
| 94 |
+
:param I: Input image
|
| 95 |
+
:return: Returns the first derivative in y direction using backward differences
|
| 96 |
+
"""
|
| 97 |
+
return (I-self.ym(I))/self.spacing[1]
|
| 98 |
+
|
| 99 |
+
def dYf(self,I):
|
| 100 |
+
"""
|
| 101 |
+
Same as dXf, but for the y direction
|
| 102 |
+
|
| 103 |
+
:param I: Input image
|
| 104 |
+
:return: Returns the first derivative in y direction using forward differences
|
| 105 |
+
"""
|
| 106 |
+
return (self.yp(I)-I)/self.spacing[1]
|
| 107 |
+
|
| 108 |
+
def dYc(self,I):
|
| 109 |
+
"""
|
| 110 |
+
Same as dXc, but for the y direction
|
| 111 |
+
|
| 112 |
+
:param I: Input image
|
| 113 |
+
:return: Returns the first derivative in y direction using central differences
|
| 114 |
+
"""
|
| 115 |
+
return (self.yp(I)-self.ym(I))/(2*self.spacing[1])
|
| 116 |
+
|
| 117 |
+
def ddYc(self,I):
|
| 118 |
+
"""
|
| 119 |
+
Same as ddXc, but for the y direction
|
| 120 |
+
|
| 121 |
+
:param I: Input image
|
| 122 |
+
:return: Returns the second derivative in the y direction
|
| 123 |
+
"""
|
| 124 |
+
return (self.yp(I)-2*I+self.ym(I))/(self.spacing[1]**2)
|
| 125 |
+
|
| 126 |
+
def dZb(self,I):
|
| 127 |
+
"""
|
| 128 |
+
Same as dXb, but for the z direction
|
| 129 |
+
|
| 130 |
+
:param I: Input image
|
| 131 |
+
:return: Returns the first derivative in the z direction using backward differences
|
| 132 |
+
"""
|
| 133 |
+
return (I - self.zm(I))/self.spacing[2]
|
| 134 |
+
|
| 135 |
+
def dZf(self, I):
|
| 136 |
+
"""
|
| 137 |
+
Same as dXf, but for the z direction
|
| 138 |
+
|
| 139 |
+
:param I: Input image
|
| 140 |
+
:return: Returns the first derivative in the z direction using forward differences
|
| 141 |
+
"""
|
| 142 |
+
return (self.zp(I)-I)/self.spacing[2]
|
| 143 |
+
|
| 144 |
+
def dZc(self, I):
|
| 145 |
+
"""
|
| 146 |
+
Same as dXc, but for the z direction
|
| 147 |
+
|
| 148 |
+
:param I: Input image
|
| 149 |
+
:return: Returns the first derivative in the z direction using central differences
|
| 150 |
+
"""
|
| 151 |
+
return (self.zp(I)-self.zm(I))/(2*self.spacing[2])
|
| 152 |
+
|
| 153 |
+
def ddZc(self,I):
|
| 154 |
+
"""
|
| 155 |
+
Same as ddXc, but for the z direction
|
| 156 |
+
|
| 157 |
+
:param I: Input iamge
|
| 158 |
+
:return: Returns the second derivative in the z direction
|
| 159 |
+
"""
|
| 160 |
+
return (self.zp(I)-2*I+self.zm(I))/(self.spacing[2]**2)
|
| 161 |
+
|
| 162 |
+
def lap(self, I):
|
| 163 |
+
"""
|
| 164 |
+
Compute the Lapacian of an image
|
| 165 |
+
!!!!!!!!!!!
|
| 166 |
+
IMPORTANT:
|
| 167 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 168 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 169 |
+
|
| 170 |
+
:param I: Input image [batch, X,Y,Z]
|
| 171 |
+
:return: Returns the Laplacian
|
| 172 |
+
"""
|
| 173 |
+
ndim = self.getdimension(I)
|
| 174 |
+
if ndim == 1+1:
|
| 175 |
+
return self.ddXc(I)
|
| 176 |
+
elif ndim == 2+1:
|
| 177 |
+
return (self.ddXc(I) + self.ddYc(I))
|
| 178 |
+
elif ndim == 3+1:
|
| 179 |
+
return (self.ddXc(I) + self.ddYc(I) + self.ddZc(I))
|
| 180 |
+
else:
|
| 181 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 182 |
+
|
| 183 |
+
def grad_norm_sqr_c(self, I):
|
| 184 |
+
"""
|
| 185 |
+
Computes the gradient norm of an image
|
| 186 |
+
!!!!!!!!!!!
|
| 187 |
+
IMPORTANT:
|
| 188 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 189 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 190 |
+
:param I: Input image [batch, X,Y,Z]
|
| 191 |
+
:return: returns ||grad I||^2
|
| 192 |
+
"""
|
| 193 |
+
ndim = self.getdimension(I)
|
| 194 |
+
if ndim == 1 + 1:
|
| 195 |
+
return self.dXc(I)**2
|
| 196 |
+
elif ndim == 2 + 1:
|
| 197 |
+
return (self.dXc(I)**2 + self.dYc(I)**2)
|
| 198 |
+
elif ndim == 3 + 1:
|
| 199 |
+
return (self.dXc(I)**2 + self.dYc(I)**2 + self.dZc(I)**2)
|
| 200 |
+
else:
|
| 201 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 202 |
+
|
| 203 |
+
def grad_norm_sqr_f(self, I):
|
| 204 |
+
"""
|
| 205 |
+
Computes the gradient norm of an image
|
| 206 |
+
!!!!!!!!!!!
|
| 207 |
+
IMPORTANT:
|
| 208 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 209 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 210 |
+
:param I: Input image [batch, X,Y,Z]
|
| 211 |
+
:return: returns ||grad I||^2
|
| 212 |
+
"""
|
| 213 |
+
ndim = self.getdimension(I)
|
| 214 |
+
if ndim == 1 + 1:
|
| 215 |
+
return self.dXf(I)**2
|
| 216 |
+
elif ndim == 2 + 1:
|
| 217 |
+
return (self.dXf(I)**2 + self.dYf(I)**2)
|
| 218 |
+
elif ndim == 3 + 1:
|
| 219 |
+
return (self.dXf(I)**2 + self.dYf(I)**2 + self.dZf(I)**2)
|
| 220 |
+
else:
|
| 221 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 222 |
+
|
| 223 |
+
def grad_norm_sqr_b(self, I):
|
| 224 |
+
"""
|
| 225 |
+
Computes the gradient norm of an image
|
| 226 |
+
!!!!!!!!!!!
|
| 227 |
+
IMPORTANT:
|
| 228 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 229 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 230 |
+
:param I: Input image [batch, X,Y,Z]
|
| 231 |
+
:return: returns ||grad I||^2
|
| 232 |
+
"""
|
| 233 |
+
ndim = self.getdimension(I)
|
| 234 |
+
if ndim == 1 + 1:
|
| 235 |
+
return self.dXb(I)**2
|
| 236 |
+
elif ndim == 2 + 1:
|
| 237 |
+
return (self.dXb(I)**2 + self.dYb(I)**2)
|
| 238 |
+
elif ndim == 3 + 1:
|
| 239 |
+
return (self.dXb(I)**2 + self.dYb(I)**2 + self.dZb(I)**2)
|
| 240 |
+
else:
|
| 241 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 242 |
+
|
| 243 |
+
@abstractmethod
|
| 244 |
+
def getdimension(self,I):
|
| 245 |
+
"""
|
| 246 |
+
Abstract method to return the dimension of an input image I
|
| 247 |
+
|
| 248 |
+
:param I: Input image
|
| 249 |
+
:return: Returns the dimension of the image I
|
| 250 |
+
"""
|
| 251 |
+
pass
|
| 252 |
+
|
| 253 |
+
@abstractmethod
|
| 254 |
+
def create_zero_array(self, sz):
|
| 255 |
+
"""
|
| 256 |
+
Abstract method to create a zero array of a given size, sz. E.g., sz=[10,2,5]
|
| 257 |
+
|
| 258 |
+
:param sz: Size array
|
| 259 |
+
:return: Returns a zero array of the specified size
|
| 260 |
+
"""
|
| 261 |
+
pass
|
| 262 |
+
|
| 263 |
+
@abstractmethod
|
| 264 |
+
def get_size_of_array(self, A):
|
| 265 |
+
"""
|
| 266 |
+
Abstract method to return the size of an array (as a vector)
|
| 267 |
+
|
| 268 |
+
:param A: Input array
|
| 269 |
+
:return: Returns its size (e.g., [5,10] or [3,4,6]
|
| 270 |
+
"""
|
| 271 |
+
pass
|
| 272 |
+
|
| 273 |
+
def xp(self,I):
|
| 274 |
+
"""
|
| 275 |
+
!!!!!!!!!!!
|
| 276 |
+
IMPORTANT:
|
| 277 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 278 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 279 |
+
Returns the values for x-index incremented by one (to the right in 1D)
|
| 280 |
+
|
| 281 |
+
:param I: Input image [batch, X, Y,Z]
|
| 282 |
+
:return: Image with values at an x-index one larger
|
| 283 |
+
"""
|
| 284 |
+
rxp = self.create_zero_array( self.get_size_of_array( I ) )
|
| 285 |
+
ndim = self.getdimension(I)
|
| 286 |
+
if ndim == 1+1:
|
| 287 |
+
rxp[:,0:-1] = I[:,1:]
|
| 288 |
+
if self.bcNeumannZero:
|
| 289 |
+
rxp[:,-1] = I[:,-1]
|
| 290 |
+
else:
|
| 291 |
+
rxp[:,-1] = 2*I[:,-1]-I[:,-2]
|
| 292 |
+
elif ndim == 2+1:
|
| 293 |
+
rxp[:,0:-1,:] = I[:,1:,:]
|
| 294 |
+
if self.bcNeumannZero:
|
| 295 |
+
rxp[:,-1,:] = I[:,-1,:]
|
| 296 |
+
else:
|
| 297 |
+
rxp[:,-1,:] = 2*I[:,-1,:]-I[:,-2,:]
|
| 298 |
+
elif ndim == 3+1:
|
| 299 |
+
rxp[:,0:-1,:,:] = I[:,1:,:,:]
|
| 300 |
+
if self.bcNeumannZero:
|
| 301 |
+
rxp[:,-1,:,:] = I[:,-1,:,:]
|
| 302 |
+
else:
|
| 303 |
+
rxp[:,-1,:,:] = 2*I[:,-1,:,:]-I[:,-2,:,:]
|
| 304 |
+
else:
|
| 305 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 306 |
+
return rxp
|
| 307 |
+
|
| 308 |
+
def xm(self,I):
|
| 309 |
+
"""
|
| 310 |
+
!!!!!!!!!!!
|
| 311 |
+
IMPORTANT:
|
| 312 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 313 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 314 |
+
Returns the values for x-index decremented by one (to the left in 1D)
|
| 315 |
+
|
| 316 |
+
:param I: Input image [batch, X, Y, Z]
|
| 317 |
+
:return: Image with values at an x-index one smaller
|
| 318 |
+
"""
|
| 319 |
+
rxm = self.create_zero_array( self.get_size_of_array( I ) )
|
| 320 |
+
ndim = self.getdimension(I)
|
| 321 |
+
if ndim == 1+1:
|
| 322 |
+
rxm[:,1:] = I[:,0:-1]
|
| 323 |
+
if self.bcNeumannZero:
|
| 324 |
+
rxm[:,0] = I[:,0]
|
| 325 |
+
else:
|
| 326 |
+
rxm[:,0] = 2*I[:,0]-I[:,1]
|
| 327 |
+
elif ndim == 2+1:
|
| 328 |
+
rxm[:,1:,:] = I[:,0:-1,:]
|
| 329 |
+
if self.bcNeumannZero:
|
| 330 |
+
rxm[:,0,:] = I[:,0,:]
|
| 331 |
+
else:
|
| 332 |
+
rxm[:,0,:] = 2*I[:,0,:]-I[:,1,:]
|
| 333 |
+
elif ndim == 3+1:
|
| 334 |
+
rxm[:,1:,:,:] = I[:,0:-1,:,:]
|
| 335 |
+
if self.bcNeumannZero:
|
| 336 |
+
rxm[:,0,:,:] = I[:,0,:,:]
|
| 337 |
+
else:
|
| 338 |
+
rxm[:,0,:,:] = 2*I[:,0,:,:]-I[:,1,:,:]
|
| 339 |
+
else:
|
| 340 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 341 |
+
return rxm
|
| 342 |
+
|
| 343 |
+
def yp(self, I):
|
| 344 |
+
"""
|
| 345 |
+
!!!!!!!!!!!
|
| 346 |
+
IMPORTANT:
|
| 347 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 348 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 349 |
+
Same as xp, but for the y direction
|
| 350 |
+
|
| 351 |
+
:param I: Input image
|
| 352 |
+
:return: Image with values at y-index one larger
|
| 353 |
+
"""
|
| 354 |
+
ryp = self.create_zero_array( self.get_size_of_array( I ) )
|
| 355 |
+
ndim = self.getdimension(I)
|
| 356 |
+
if ndim == 2+1:
|
| 357 |
+
ryp[:,:,0:-1] = I[:,:,1:]
|
| 358 |
+
if self.bcNeumannZero:
|
| 359 |
+
ryp[:,:,-1] = I[:,:,-1]
|
| 360 |
+
else:
|
| 361 |
+
ryp[:,:,-1] = 2*I[:,:,-1]-I[:,:,-2]
|
| 362 |
+
elif ndim == 3+1:
|
| 363 |
+
ryp[:,:,0:-1,:] = I[:,:,1:,:]
|
| 364 |
+
if self.bcNeumannZero:
|
| 365 |
+
ryp[:,:,-1,:] = I[:,:,-1,:]
|
| 366 |
+
else:
|
| 367 |
+
ryp[:,:,-1,:] = 2*I[:,:,-1,:]-I[:,:,-2,:]
|
| 368 |
+
else:
|
| 369 |
+
print('Current dimension:', ndim-1)
|
| 370 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 371 |
+
return ryp
|
| 372 |
+
|
| 373 |
+
def ym(self, I):
|
| 374 |
+
"""
|
| 375 |
+
Same as xm, but for the y direction
|
| 376 |
+
!!!!!!!!!!!
|
| 377 |
+
IMPORTANT:
|
| 378 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 379 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 380 |
+
Returns the values for x-index decremented by one (to the left in 1D)
|
| 381 |
+
:param I: Input image [batch, X, Y, Z]
|
| 382 |
+
:return: Image with values at y-index one smaller
|
| 383 |
+
"""
|
| 384 |
+
rym = self.create_zero_array( self.get_size_of_array( I ) )
|
| 385 |
+
ndim = self.getdimension(I)
|
| 386 |
+
if ndim == 2+1:
|
| 387 |
+
rym[:,:,1:] = I[:,:,0:-1]
|
| 388 |
+
if self.bcNeumannZero:
|
| 389 |
+
rym[:,:,0] = I[:,:,0]
|
| 390 |
+
else:
|
| 391 |
+
rym[:,:,0] = 2*I[:,:,0]-I[:,:,1]
|
| 392 |
+
elif ndim == 3+1:
|
| 393 |
+
rym[:,:,1:,:] = I[:,:,0:-1,:]
|
| 394 |
+
if self.bcNeumannZero:
|
| 395 |
+
rym[:,:,0,:] = I[:,:,0,:]
|
| 396 |
+
else:
|
| 397 |
+
rym[:,:,0,:] = 2*I[:,:,0,:]-I[:,:,1,:]
|
| 398 |
+
else:
|
| 399 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 400 |
+
return rym
|
| 401 |
+
|
| 402 |
+
def zp(self, I):
|
| 403 |
+
"""
|
| 404 |
+
Same as xp, but for the z direction
|
| 405 |
+
|
| 406 |
+
!!!!!!!!!!!
|
| 407 |
+
IMPORTANT:
|
| 408 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 409 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 410 |
+
Returns the values for x-index decremented by one (to the left in 1D)
|
| 411 |
+
:param I: Input image [batch, X, Y, Z]
|
| 412 |
+
:return: Image with values at z-index one larger
|
| 413 |
+
"""
|
| 414 |
+
rzp = self.create_zero_array( self.get_size_of_array( I ) )
|
| 415 |
+
ndim = self.getdimension(I)
|
| 416 |
+
if ndim == 3+1:
|
| 417 |
+
rzp[:,:,:,0:-1] = I[:,:,:,1:]
|
| 418 |
+
if self.bcNeumannZero:
|
| 419 |
+
rzp[:,:,:,-1] = I[:,:,:,-1]
|
| 420 |
+
else:
|
| 421 |
+
rzp[:,:,:,-1] = 2*I[:,:,:,-1]-I[:,:,:,-2]
|
| 422 |
+
else:
|
| 423 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 424 |
+
return rzp
|
| 425 |
+
|
| 426 |
+
def zm(self, I):
|
| 427 |
+
"""
|
| 428 |
+
Same as xm, but for the z direction
|
| 429 |
+
|
| 430 |
+
!!!!!!!!!!!
|
| 431 |
+
IMPORTANT:
|
| 432 |
+
ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
|
| 433 |
+
THIS IS FOR COMPUTATIONAL EFFICIENCY.
|
| 434 |
+
Returns the values for x-index decremented by one (to the left in 1D)
|
| 435 |
+
:param I: Input image [batch, X, Y, Z]
|
| 436 |
+
:return: Image with values at z-index one smaller
|
| 437 |
+
"""
|
| 438 |
+
rzm = self.create_zero_array( self.get_size_of_array( I ) )
|
| 439 |
+
ndim = self.getdimension(I)
|
| 440 |
+
if ndim == 3+1:
|
| 441 |
+
rzm[:,:,:,1:] = I[:,:,:,0:-1]
|
| 442 |
+
if self.bcNeumannZero:
|
| 443 |
+
rzm[:,:,:,0] = I[:,:,:,0]
|
| 444 |
+
else:
|
| 445 |
+
rzm[:,:,:,0] = 2*I[:,:,:,0]-I[:,:,:,1]
|
| 446 |
+
else:
|
| 447 |
+
raise ValueError('Finite differences are only supported in dimensions 1 to 3')
|
| 448 |
+
return rzm
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class FD_np(FD):
|
| 452 |
+
"""
|
| 453 |
+
Defnitions of the abstract methods for numpy
|
| 454 |
+
"""
|
| 455 |
+
|
| 456 |
+
def __init__(self,spacing,bcNeumannZero=True):
|
| 457 |
+
"""
|
| 458 |
+
Constructor for numpy finite differences
|
| 459 |
+
:param spacing: spatial spacing (array with as many entries as there are spatial dimensions)
|
| 460 |
+
:param bcNeumannZero: Specifies if zero Neumann conditions should be used (if not, uses linear extrapolation)
|
| 461 |
+
"""
|
| 462 |
+
super(FD_np, self).__init__(spacing,bcNeumannZero)
|
| 463 |
+
|
| 464 |
+
def getdimension(self,I):
|
| 465 |
+
"""
|
| 466 |
+
Returns the dimension of an image
|
| 467 |
+
:param I: input image
|
| 468 |
+
:return: dimension of the input image
|
| 469 |
+
"""
|
| 470 |
+
return I.ndim
|
| 471 |
+
|
| 472 |
+
def create_zero_array(self, sz):
|
| 473 |
+
"""
|
| 474 |
+
Creates a zero array
|
| 475 |
+
:param sz: size of the zero array, e.g., [3,4,2]
|
| 476 |
+
:return: the zero array
|
| 477 |
+
"""
|
| 478 |
+
return np.zeros( sz )
|
| 479 |
+
|
| 480 |
+
def get_size_of_array(self, A):
|
| 481 |
+
"""
|
| 482 |
+
Returns the size (shape in numpy) of an array
|
| 483 |
+
:param A: input array
|
| 484 |
+
:return: shape/size
|
| 485 |
+
"""
|
| 486 |
+
return A.shape
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
class FD_torch(FD):
|
| 490 |
+
"""
|
| 491 |
+
Defnitions of the abstract methods for torch
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
def __init__(self,spacing,device,bcNeumannZero=True):
|
| 495 |
+
"""
|
| 496 |
+
Constructor for torch finite differences
|
| 497 |
+
:param spacing: spatial spacing (array with as many entries as there are spatial dimensions)
|
| 498 |
+
:param bcNeumannZero: Specifies if zero Neumann conditions should be used (if not, uses linear extrapolation)
|
| 499 |
+
"""
|
| 500 |
+
super(FD_torch, self).__init__(spacing,bcNeumannZero)
|
| 501 |
+
self.device = device
|
| 502 |
+
|
| 503 |
+
def getdimension(self,I):
|
| 504 |
+
"""
|
| 505 |
+
Returns the dimension of an image
|
| 506 |
+
:param I: input image
|
| 507 |
+
:return: dimension of the input image
|
| 508 |
+
"""
|
| 509 |
+
return I.dim()
|
| 510 |
+
|
| 511 |
+
def create_zero_array(self, sz):
|
| 512 |
+
"""
|
| 513 |
+
Creats a zero array
|
| 514 |
+
:param sz: size of the array, e.g., [3,4,2]
|
| 515 |
+
:return: the zero array
|
| 516 |
+
"""
|
| 517 |
+
return torch.zeros(sz).float().to(self.device)
|
| 518 |
+
|
| 519 |
+
def get_size_of_array(self, A):
|
| 520 |
+
"""
|
| 521 |
+
Returns the size (size()) of an array
|
| 522 |
+
:param A: input array
|
| 523 |
+
:return: shape/size
|
| 524 |
+
"""
|
| 525 |
+
return A.size()
|
ShapeID/DiffEqs/adams.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import torch
|
| 3 |
+
from ShapeID.DiffEqs.solvers import AdaptiveStepsizeODESolver
|
| 4 |
+
from ShapeID.DiffEqs.misc import (
|
| 5 |
+
_handle_unused_kwargs, _select_initial_step, _convert_to_tensor, _scaled_dot_product, _is_iterable,
|
| 6 |
+
_optimal_step_size, _compute_error_ratio
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
_MIN_ORDER = 1
|
| 10 |
+
_MAX_ORDER = 12
|
| 11 |
+
|
| 12 |
+
gamma_star = [
|
| 13 |
+
1, -1 / 2, -1 / 12, -1 / 24, -19 / 720, -3 / 160, -863 / 60480, -275 / 24192, -33953 / 3628800, -0.00789255,
|
| 14 |
+
-0.00678585, -0.00592406, -0.00523669, -0.0046775, -0.00421495, -0.0038269
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _VCABMState(collections.namedtuple('_VCABMState', 'y_n, prev_f, prev_t, next_t, phi, order')):
|
| 19 |
+
"""Saved state of the variable step size Adams-Bashforth-Moulton solver as described in
|
| 20 |
+
|
| 21 |
+
Solving Ordinary Differential Equations I - Nonstiff Problems III.5
|
| 22 |
+
by Ernst Hairer, Gerhard Wanner, and Syvert P Norsett.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def g_and_explicit_phi(prev_t, next_t, implicit_phi, k):
|
| 27 |
+
curr_t = prev_t[0]
|
| 28 |
+
dt = next_t - prev_t[0]
|
| 29 |
+
|
| 30 |
+
g = torch.empty(k + 1).to(prev_t[0])
|
| 31 |
+
explicit_phi = collections.deque(maxlen=k)
|
| 32 |
+
beta = torch.tensor(1).to(prev_t[0])
|
| 33 |
+
|
| 34 |
+
g[0] = 1
|
| 35 |
+
c = 1 / torch.arange(1, k + 2).to(prev_t[0])
|
| 36 |
+
explicit_phi.append(implicit_phi[0])
|
| 37 |
+
|
| 38 |
+
for j in range(1, k):
|
| 39 |
+
beta = (next_t - prev_t[j - 1]) / (curr_t - prev_t[j]) * beta
|
| 40 |
+
beat_cast = beta.to(implicit_phi[j][0])
|
| 41 |
+
explicit_phi.append(tuple(iphi_ * beat_cast for iphi_ in implicit_phi[j]))
|
| 42 |
+
|
| 43 |
+
c = c[:-1] - c[1:] if j == 1 else c[:-1] - c[1:] * dt / (next_t - prev_t[j - 1])
|
| 44 |
+
g[j] = c[0]
|
| 45 |
+
|
| 46 |
+
c = c[:-1] - c[1:] * dt / (next_t - prev_t[k - 1])
|
| 47 |
+
g[k] = c[0]
|
| 48 |
+
|
| 49 |
+
return g, explicit_phi
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def compute_implicit_phi(explicit_phi, f_n, k):
|
| 53 |
+
k = min(len(explicit_phi) + 1, k)
|
| 54 |
+
implicit_phi = collections.deque(maxlen=k)
|
| 55 |
+
implicit_phi.append(f_n)
|
| 56 |
+
for j in range(1, k):
|
| 57 |
+
implicit_phi.append(tuple(iphi_ - ephi_ for iphi_, ephi_ in zip(implicit_phi[j - 1], explicit_phi[j - 1])))
|
| 58 |
+
return implicit_phi
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class VariableCoefficientAdamsBashforth(AdaptiveStepsizeODESolver):
|
| 62 |
+
|
| 63 |
+
def __init__(
|
| 64 |
+
self, func, y0, rtol, atol, implicit=True, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0, dfactor=0.2,
|
| 65 |
+
**unused_kwargs
|
| 66 |
+
):
|
| 67 |
+
_handle_unused_kwargs(self, unused_kwargs)
|
| 68 |
+
del unused_kwargs
|
| 69 |
+
|
| 70 |
+
self.func = func
|
| 71 |
+
self.y0 = y0
|
| 72 |
+
self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
|
| 73 |
+
self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
|
| 74 |
+
self.implicit = implicit
|
| 75 |
+
self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER)))
|
| 76 |
+
self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
|
| 77 |
+
self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
|
| 78 |
+
self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
|
| 79 |
+
|
| 80 |
+
def before_integrate(self, t):
|
| 81 |
+
prev_f = collections.deque(maxlen=self.max_order + 1)
|
| 82 |
+
prev_t = collections.deque(maxlen=self.max_order + 1)
|
| 83 |
+
phi = collections.deque(maxlen=self.max_order)
|
| 84 |
+
|
| 85 |
+
t0 = t[0]
|
| 86 |
+
f0 = self.func(t0.type_as(self.y0[0]), self.y0)
|
| 87 |
+
prev_t.appendleft(t0)
|
| 88 |
+
prev_f.appendleft(f0)
|
| 89 |
+
phi.appendleft(f0)
|
| 90 |
+
first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t)
|
| 91 |
+
|
| 92 |
+
self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1)
|
| 93 |
+
|
| 94 |
+
def advance(self, final_t):
|
| 95 |
+
final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0])
|
| 96 |
+
while final_t > self.vcabm_state.prev_t[0]:
|
| 97 |
+
self.vcabm_state = self._adaptive_adams_step(self.vcabm_state, final_t)
|
| 98 |
+
assert final_t == self.vcabm_state.prev_t[0]
|
| 99 |
+
return self.vcabm_state.y_n
|
| 100 |
+
|
| 101 |
+
def _adaptive_adams_step(self, vcabm_state, final_t):
|
| 102 |
+
y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state
|
| 103 |
+
if next_t > final_t:
|
| 104 |
+
next_t = final_t
|
| 105 |
+
dt = (next_t - prev_t[0])
|
| 106 |
+
dt_cast = dt.to(y0[0])
|
| 107 |
+
|
| 108 |
+
# Explicit predictor step.
|
| 109 |
+
g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order)
|
| 110 |
+
g = g.to(y0[0])
|
| 111 |
+
p_next = tuple(
|
| 112 |
+
y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)])
|
| 113 |
+
for y0_, phi_ in zip(y0, tuple(zip(*phi)))
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Update phi to implicit.
|
| 117 |
+
next_f0 = self.func(next_t.to(p_next[0]), p_next)
|
| 118 |
+
implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1)
|
| 119 |
+
|
| 120 |
+
# Implicit corrector step.
|
| 121 |
+
y_next = tuple(
|
| 122 |
+
p_next_ + dt_cast * g[order - 1] * iphi_ for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1])
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Error estimation.
|
| 126 |
+
tolerance = tuple(
|
| 127 |
+
atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_))
|
| 128 |
+
for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next)
|
| 129 |
+
)
|
| 130 |
+
local_error = tuple(dt_cast * (g[order] - g[order - 1]) * iphi_ for iphi_ in implicit_phi_p[order])
|
| 131 |
+
error_k = _compute_error_ratio(local_error, tolerance)
|
| 132 |
+
accept_step = (torch.tensor(error_k) <= 1).all()
|
| 133 |
+
|
| 134 |
+
if not accept_step:
|
| 135 |
+
# Retry with adjusted step size if step is rejected.
|
| 136 |
+
dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order)
|
| 137 |
+
return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order)
|
| 138 |
+
|
| 139 |
+
# We accept the step. Evaluate f and update phi.
|
| 140 |
+
next_f0 = self.func(next_t.to(p_next[0]), y_next)
|
| 141 |
+
implicit_phi = compute_implicit_phi(phi, next_f0, order + 2)
|
| 142 |
+
|
| 143 |
+
next_order = order
|
| 144 |
+
|
| 145 |
+
if len(prev_t) <= 4 or order < 3:
|
| 146 |
+
next_order = min(order + 1, 3, self.max_order)
|
| 147 |
+
else:
|
| 148 |
+
error_km1 = _compute_error_ratio(
|
| 149 |
+
tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_ for iphi_ in implicit_phi_p[order - 1]), tolerance
|
| 150 |
+
)
|
| 151 |
+
error_km2 = _compute_error_ratio(
|
| 152 |
+
tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_ for iphi_ in implicit_phi_p[order - 2]), tolerance
|
| 153 |
+
)
|
| 154 |
+
if min(error_km1 + error_km2) < max(error_k):
|
| 155 |
+
next_order = order - 1
|
| 156 |
+
elif order < self.max_order:
|
| 157 |
+
error_kp1 = _compute_error_ratio(
|
| 158 |
+
tuple(dt_cast * gamma_star[order] * iphi_ for iphi_ in implicit_phi_p[order]), tolerance
|
| 159 |
+
)
|
| 160 |
+
if max(error_kp1) < max(error_k):
|
| 161 |
+
next_order = order + 1
|
| 162 |
+
|
| 163 |
+
# Keep step size constant if increasing order. Else use adaptive step size.
|
| 164 |
+
dt_next = dt if next_order > order else _optimal_step_size(
|
| 165 |
+
dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
prev_f.appendleft(next_f0)
|
| 169 |
+
prev_t.appendleft(next_t)
|
| 170 |
+
return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order)
|
ShapeID/DiffEqs/adjoint.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from ShapeID.DiffEqs.odeint import odeint
|
| 4 |
+
from ShapeID.DiffEqs.misc import _flatten, _flatten_convert_none_to_zeros
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class OdeintAdjointMethod(torch.autograd.Function):
|
| 8 |
+
|
| 9 |
+
@staticmethod
|
| 10 |
+
def forward(ctx, *args):
|
| 11 |
+
assert len(args) >= 8, 'Internal error: all arguments required.'
|
| 12 |
+
y0, func, t, dt, flat_params, rtol, atol, method, options = \
|
| 13 |
+
args[:-8], args[-8], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]
|
| 14 |
+
|
| 15 |
+
ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options
|
| 16 |
+
|
| 17 |
+
with torch.no_grad():
|
| 18 |
+
ans = odeint(func, y0, t, dt, rtol=rtol, atol=atol, method=method, options=options)
|
| 19 |
+
ctx.save_for_backward(t, flat_params, *ans)
|
| 20 |
+
return ans
|
| 21 |
+
|
| 22 |
+
@staticmethod
|
| 23 |
+
def backward(ctx, *grad_output):
|
| 24 |
+
|
| 25 |
+
t, flat_params, *ans = ctx.saved_tensors
|
| 26 |
+
ans = tuple(ans)
|
| 27 |
+
func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
|
| 28 |
+
n_tensors = len(ans)
|
| 29 |
+
f_params = tuple(func.parameters())
|
| 30 |
+
|
| 31 |
+
# TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
|
| 32 |
+
def augmented_dynamics(t, y_aug):
|
| 33 |
+
# Dynamics of the original system augmented with
|
| 34 |
+
# the adjoint wrt y, and an integrator wrt t and args.
|
| 35 |
+
y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] # Ignore adj_time and adj_params.
|
| 36 |
+
|
| 37 |
+
with torch.set_grad_enabled(True):
|
| 38 |
+
t = t.to(y[0].device).detach().requires_grad_(True)
|
| 39 |
+
y = tuple(y_.detach().requires_grad_(True) for y_ in y)
|
| 40 |
+
func_eval = func(t, y)
|
| 41 |
+
vjp_t, *vjp_y_and_params = torch.autograd.grad(
|
| 42 |
+
func_eval, (t,) + y + f_params,
|
| 43 |
+
tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
|
| 44 |
+
)
|
| 45 |
+
vjp_y = vjp_y_and_params[:n_tensors]
|
| 46 |
+
vjp_params = vjp_y_and_params[n_tensors:]
|
| 47 |
+
|
| 48 |
+
# autograd.grad returns None if no gradient, set to zero.
|
| 49 |
+
vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
|
| 50 |
+
vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
|
| 51 |
+
vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)
|
| 52 |
+
|
| 53 |
+
if len(f_params) == 0:
|
| 54 |
+
vjp_params = torch.tensor(0.).to(vjp_y[0])
|
| 55 |
+
return (*func_eval, *vjp_y, vjp_t, vjp_params)
|
| 56 |
+
|
| 57 |
+
T = ans[0].shape[0]
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
|
| 60 |
+
adj_params = torch.zeros_like(flat_params)
|
| 61 |
+
adj_time = torch.tensor(0.).to(t)
|
| 62 |
+
time_vjps = []
|
| 63 |
+
for i in range(T - 1, 0, -1):
|
| 64 |
+
|
| 65 |
+
ans_i = tuple(ans_[i] for ans_ in ans)
|
| 66 |
+
grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
|
| 67 |
+
func_i = func(t[i], ans_i)
|
| 68 |
+
|
| 69 |
+
# Compute the effect of moving the current time measurement point.
|
| 70 |
+
dLd_cur_t = sum(
|
| 71 |
+
torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1)
|
| 72 |
+
for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
|
| 73 |
+
)
|
| 74 |
+
adj_time = adj_time - dLd_cur_t
|
| 75 |
+
time_vjps.append(dLd_cur_t)
|
| 76 |
+
|
| 77 |
+
# Run the augmented system backwards in time.
|
| 78 |
+
if adj_params.numel() == 0:
|
| 79 |
+
adj_params = torch.tensor(0.).to(adj_y[0])
|
| 80 |
+
aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
|
| 81 |
+
aug_ans = odeint(
|
| 82 |
+
augmented_dynamics, aug_y0,
|
| 83 |
+
torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Unpack aug_ans.
|
| 87 |
+
adj_y = aug_ans[n_tensors:2 * n_tensors]
|
| 88 |
+
adj_time = aug_ans[2 * n_tensors]
|
| 89 |
+
adj_params = aug_ans[2 * n_tensors + 1]
|
| 90 |
+
|
| 91 |
+
adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
|
| 92 |
+
if len(adj_time) > 0: adj_time = adj_time[1]
|
| 93 |
+
if len(adj_params) > 0: adj_params = adj_params[1]
|
| 94 |
+
|
| 95 |
+
adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))
|
| 96 |
+
|
| 97 |
+
del aug_y0, aug_ans
|
| 98 |
+
|
| 99 |
+
time_vjps.append(adj_time)
|
| 100 |
+
time_vjps = torch.cat(time_vjps[::-1])
|
| 101 |
+
|
| 102 |
+
return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None, None) # Add a None (TODO, futher check)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def odeint_adjoint(func, y0, t, dt, rtol=1e-6, atol=1e-12, method=None, options=None):
|
| 106 |
+
|
| 107 |
+
# We need this in order to access the variables inside this module,
|
| 108 |
+
# since we have no other way of getting variables along the execution path.
|
| 109 |
+
if not isinstance(func, nn.Module):
|
| 110 |
+
raise ValueError('func is required to be an instance of nn.Module.')
|
| 111 |
+
|
| 112 |
+
tensor_input = False
|
| 113 |
+
if torch.is_tensor(y0):
|
| 114 |
+
|
| 115 |
+
class TupleFunc(nn.Module):
|
| 116 |
+
|
| 117 |
+
def __init__(self, base_func):
|
| 118 |
+
super(TupleFunc, self).__init__()
|
| 119 |
+
self.base_func = base_func
|
| 120 |
+
|
| 121 |
+
def forward(self, t, y):
|
| 122 |
+
return (self.base_func(t, y[0]),)
|
| 123 |
+
|
| 124 |
+
tensor_input = True
|
| 125 |
+
y0 = (y0,)
|
| 126 |
+
func = TupleFunc(func)
|
| 127 |
+
|
| 128 |
+
flat_params = _flatten(func.parameters())
|
| 129 |
+
ys = OdeintAdjointMethod.apply(*y0, func, t, dt, flat_params, rtol, atol, method, options)
|
| 130 |
+
|
| 131 |
+
if tensor_input:
|
| 132 |
+
ys = ys[0]
|
| 133 |
+
return ys
|
ShapeID/DiffEqs/dopri5.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .misc import (
|
| 3 |
+
_scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable,
|
| 4 |
+
_optimal_step_size, _compute_error_ratio
|
| 5 |
+
)
|
| 6 |
+
from .solvers import AdaptiveStepsizeODESolver, set_BC_2D, set_BC_3D, add_dBC_2D, add_dBC_3D
|
| 7 |
+
from .interp import _interp_fit, _interp_evaluate
|
| 8 |
+
from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
_DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau(
|
| 12 |
+
alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.],
|
| 13 |
+
beta=[
|
| 14 |
+
[1 / 5],
|
| 15 |
+
[3 / 40, 9 / 40],
|
| 16 |
+
[44 / 45, -56 / 15, 32 / 9],
|
| 17 |
+
[19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
|
| 18 |
+
[9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
|
| 19 |
+
[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
|
| 20 |
+
],
|
| 21 |
+
c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0],
|
| 22 |
+
c_error=[
|
| 23 |
+
35 / 384 - 1951 / 21600,
|
| 24 |
+
0,
|
| 25 |
+
500 / 1113 - 22642 / 50085,
|
| 26 |
+
125 / 192 - 451 / 720,
|
| 27 |
+
-2187 / 6784 - -12231 / 42400,
|
| 28 |
+
11 / 84 - 649 / 6300,
|
| 29 |
+
-1. / 60.,
|
| 30 |
+
],
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
DPS_C_MID = [
|
| 34 |
+
6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2,
|
| 35 |
+
187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU):
|
| 40 |
+
"""Fit an interpolating polynomial to the results of a Runge-Kutta step."""
|
| 41 |
+
dt = dt.type_as(y0[0])
|
| 42 |
+
y_mid = tuple(y0_ + _scaled_dot_product(dt, DPS_C_MID, k_) for y0_, k_ in zip(y0, k))
|
| 43 |
+
f0 = tuple(k_[0] for k_ in k)
|
| 44 |
+
f1 = tuple(k_[-1] for k_ in k)
|
| 45 |
+
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _abs_square(x):
|
| 49 |
+
return torch.mul(x, x)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _ta_append(list_of_tensors, value):
|
| 53 |
+
"""Append a value to the end of a list of PyTorch tensors."""
|
| 54 |
+
list_of_tensors.append(value)
|
| 55 |
+
return list_of_tensors
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Dopri5Solver(AdaptiveStepsizeODESolver):
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self, func, y0, rtol, atol, dt, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
|
| 62 |
+
options = None
|
| 63 |
+
#**unused_kwargs
|
| 64 |
+
):
|
| 65 |
+
#_handle_unused_kwargs(self, unused_kwargs)
|
| 66 |
+
#del unused_kwargs
|
| 67 |
+
|
| 68 |
+
self.func = func
|
| 69 |
+
self.y0 = y0
|
| 70 |
+
|
| 71 |
+
self.dt = dt #options.dt
|
| 72 |
+
'''if 'dirichlet' in options.BC or 'cauchy' in options.BC and options.contours is not None:
|
| 73 |
+
self.contours = options.contours # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape)
|
| 74 |
+
self.BC_size = self.contours.size(3)
|
| 75 |
+
self.set_BC = set_BC_2D if self.contours.size(2) == 4 else set_BC_3D
|
| 76 |
+
else:
|
| 77 |
+
self.contours = None
|
| 78 |
+
if 'source' in options.BC and options.dcontours is not None:
|
| 79 |
+
self.dcontours = options.dcontours # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape)
|
| 80 |
+
self.BC_size = self.dcontours.size(3)
|
| 81 |
+
self.add_dBC = add_dBC_2D if self.dcontours.size(2) == 4 else add_dBC_3D
|
| 82 |
+
else:
|
| 83 |
+
self.dcontours = None'''
|
| 84 |
+
|
| 85 |
+
#self.adjoint = options.adjoint
|
| 86 |
+
|
| 87 |
+
self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
|
| 88 |
+
self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
|
| 89 |
+
self.first_step = first_step
|
| 90 |
+
self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
|
| 91 |
+
self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
|
| 92 |
+
self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
|
| 93 |
+
self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)
|
| 94 |
+
#self.n_step_record=[]
|
| 95 |
+
|
| 96 |
+
def before_integrate(self, t):
|
| 97 |
+
f0 = self.func(t[0].type_as(self.y0[0]), self.y0)
|
| 98 |
+
#print("first_step is {}".format(self.first_step))
|
| 99 |
+
if self.first_step is None:
|
| 100 |
+
first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0).to(t)
|
| 101 |
+
else:
|
| 102 |
+
first_step = _convert_to_tensor(0.01, dtype=t.dtype, device=t.device)
|
| 103 |
+
# if first_step>0.2:
|
| 104 |
+
# print("warning the first step of dopri5 {} is too big, set to 0.2".format(first_step))
|
| 105 |
+
# first_step = _convert_to_tensor(0.2, dtype=torch.float64, device=self.y0[0].device)
|
| 106 |
+
|
| 107 |
+
self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5)
|
| 108 |
+
|
| 109 |
+
def advance(self, next_t):
|
| 110 |
+
"""Interpolate through the next time point, integrating as necessary."""
|
| 111 |
+
n_steps = 0
|
| 112 |
+
while next_t > self.rk_state.t1:
|
| 113 |
+
assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
|
| 114 |
+
self.rk_state = self._adaptive_dopri5_step(self.rk_state)
|
| 115 |
+
n_steps += 1
|
| 116 |
+
# if len(self.n_step_record)==100:
|
| 117 |
+
# print("this dopri5 step info will print every 100 calls, the current average step is {}".format(sum(self.n_step_record)/100))
|
| 118 |
+
# self.n_step_record=[]
|
| 119 |
+
# else:
|
| 120 |
+
# self.n_step_record.append(n_steps)
|
| 121 |
+
|
| 122 |
+
return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t)
|
| 123 |
+
|
| 124 |
+
def _adaptive_dopri5_step(self, rk_state):
|
| 125 |
+
"""Take an adaptive Runge-Kutta step to integrate the DiffEqs."""
|
| 126 |
+
y0, f0, _, t0, dt, interp_coeff = rk_state
|
| 127 |
+
########################################################
|
| 128 |
+
# Assertions #
|
| 129 |
+
########################################################
|
| 130 |
+
assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
|
| 131 |
+
# for y0_ in y0:
|
| 132 |
+
# #assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
|
| 133 |
+
# is_finite= _is_finite(torch.abs(y0_))
|
| 134 |
+
# if not is_finite:
|
| 135 |
+
# print(" non-finite elements exist, try to fix")
|
| 136 |
+
# y0_[y0_ != y0_] = 0.
|
| 137 |
+
# y0_[y0_ == float("Inf")] = 0.
|
| 138 |
+
|
| 139 |
+
y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU)
|
| 140 |
+
|
| 141 |
+
########################################################
|
| 142 |
+
# Error Ratio #
|
| 143 |
+
########################################################
|
| 144 |
+
mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
|
| 145 |
+
accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all()
|
| 146 |
+
|
| 147 |
+
########################################################
|
| 148 |
+
# Update RK State #
|
| 149 |
+
########################################################
|
| 150 |
+
dt_next = _optimal_step_size(
|
| 151 |
+
dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5)
|
| 152 |
+
tol_min_dt = 0.2 * self.dt if 0.1 * self.dt >= 0.01 else 0.01
|
| 153 |
+
#print('tol min', tol_min_dt)
|
| 154 |
+
if not (dt_next< tol_min_dt or dt_next>0.1): #(dt_next<0.01 or dt_next>0.1): #(dt_next<0.02): #not (dt_next<0.02 or dt_next>0.1):
|
| 155 |
+
y_next = y1 if accept_step else y0
|
| 156 |
+
f_next = f1 if accept_step else f0
|
| 157 |
+
t_next = t0 + dt if accept_step else t0
|
| 158 |
+
interp_coeff = _interp_fit_dopri5(y0, y_next, k, dt) if accept_step else interp_coeff
|
| 159 |
+
else:
|
| 160 |
+
if dt_next< tol_min_dt: #dt_next<0.01: # 0.01
|
| 161 |
+
#print("Dopri5 step %.3f too small, set to %.3f" % (dt_next, 0.2 * self.dt))
|
| 162 |
+
dt_next = _convert_to_tensor(tol_min_dt, dtype=torch.float64, device=y0[0].device)
|
| 163 |
+
if dt_next>0.1:
|
| 164 |
+
#print("Dopri5 step %.8f is too big, set to 0.1" % (dt_next))
|
| 165 |
+
dt_next = _convert_to_tensor(0.1, dtype=torch.float64, device=y0[0].device)
|
| 166 |
+
y_next = y1
|
| 167 |
+
f_next = f1
|
| 168 |
+
t_next = t0 + dt
|
| 169 |
+
interp_coeff = _interp_fit_dopri5(y0, y1, k, dt)
|
| 170 |
+
rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
|
| 171 |
+
#print('dt_next', dt_next)
|
| 172 |
+
return rk_state
|
ShapeID/DiffEqs/fixed_adams.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import collections
|
| 3 |
+
from ShapeID.DiffEqs.solvers import FixedGridODESolver
|
| 4 |
+
from ShapeID.DiffEqs.misc import _scaled_dot_product, _has_converged
|
| 5 |
+
import ShapeID.DiffEqs.rk_common
|
| 6 |
+
|
| 7 |
+
_BASHFORTH_COEFFICIENTS = [
|
| 8 |
+
[], # order 0
|
| 9 |
+
[11],
|
| 10 |
+
[3, -1],
|
| 11 |
+
[23, -16, 5],
|
| 12 |
+
[55, -59, 37, -9],
|
| 13 |
+
[1901, -2774, 2616, -1274, 251],
|
| 14 |
+
[4277, -7923, 9982, -7298, 2877, -475],
|
| 15 |
+
[198721, -447288, 705549, -688256, 407139, -134472, 19087],
|
| 16 |
+
[434241, -1152169, 2183877, -2664477, 2102243, -1041723, 295767, -36799],
|
| 17 |
+
[14097247, -43125206, 95476786, -139855262, 137968480, -91172642, 38833486, -9664106, 1070017],
|
| 18 |
+
[30277247, -104995189, 265932680, -454661776, 538363838, -444772162, 252618224, -94307320, 20884811, -2082753],
|
| 19 |
+
[
|
| 20 |
+
2132509567, -8271795124, 23591063805, -46113029016, 63716378958, -63176201472, 44857168434, -22329634920,
|
| 21 |
+
7417904451, -1479574348, 134211265
|
| 22 |
+
],
|
| 23 |
+
[
|
| 24 |
+
4527766399, -19433810163, 61633227185, -135579356757, 214139355366, -247741639374, 211103573298, -131365867290,
|
| 25 |
+
58189107627, -17410248271, 3158642445, -262747265
|
| 26 |
+
],
|
| 27 |
+
[
|
| 28 |
+
13064406523627, -61497552797274, 214696591002612, -524924579905150, 932884546055895, -1233589244941764,
|
| 29 |
+
1226443086129408, -915883387152444, 507140369728425, -202322913738370, 55060974662412, -9160551085734,
|
| 30 |
+
703604254357
|
| 31 |
+
],
|
| 32 |
+
[
|
| 33 |
+
27511554976875, -140970750679621, 537247052515662, -1445313351681906, 2854429571790805, -4246767353305755,
|
| 34 |
+
4825671323488452, -4204551925534524, 2793869602879077, -1393306307155755, 505586141196430, -126174972681906,
|
| 35 |
+
19382853593787, -1382741929621
|
| 36 |
+
],
|
| 37 |
+
[
|
| 38 |
+
173233498598849, -960122866404112, 3966421670215481, -11643637530577472, 25298910337081429, -41825269932507728,
|
| 39 |
+
53471026659940509, -53246738660646912, 41280216336284259, -24704503655607728, 11205849753515179,
|
| 40 |
+
-3728807256577472, 859236476684231, -122594813904112, 8164168737599
|
| 41 |
+
],
|
| 42 |
+
[
|
| 43 |
+
362555126427073, -2161567671248849, 9622096909515337, -30607373860520569, 72558117072259733,
|
| 44 |
+
-131963191940828581, 187463140112902893, -210020588912321949, 186087544263596643, -129930094104237331,
|
| 45 |
+
70724351582843483, -29417910911251819, 9038571752734087, -1934443196892599, 257650275915823, -16088129229375
|
| 46 |
+
],
|
| 47 |
+
[
|
| 48 |
+
192996103681340479, -1231887339593444974, 5878428128276811750, -20141834622844109630, 51733880057282977010,
|
| 49 |
+
-102651404730855807942, 160414858999474733422, -199694296833704562550, 199061418623907202560,
|
| 50 |
+
-158848144481581407370, 100878076849144434322, -50353311405771659322, 19338911944324897550,
|
| 51 |
+
-5518639984393844930, 1102560345141059610, -137692773163513234, 8092989203533249
|
| 52 |
+
],
|
| 53 |
+
[
|
| 54 |
+
401972381695456831, -2735437642844079789, 13930159965811142228, -51150187791975812900, 141500575026572531760,
|
| 55 |
+
-304188128232928718008, 518600355541383671092, -710171024091234303204, 786600875277595877750,
|
| 56 |
+
-706174326992944287370, 512538584122114046748, -298477260353977522892, 137563142659866897224,
|
| 57 |
+
-49070094880794267600, 13071639236569712860, -2448689255584545196, 287848942064256339, -15980174332775873
|
| 58 |
+
],
|
| 59 |
+
[
|
| 60 |
+
333374427829017307697, -2409687649238345289684, 13044139139831833251471, -51099831122607588046344,
|
| 61 |
+
151474888613495715415020, -350702929608291455167896, 647758157491921902292692, -967713746544629658690408,
|
| 62 |
+
1179078743786280451953222, -1176161829956768365219840, 960377035444205950813626, -639182123082298748001432,
|
| 63 |
+
343690461612471516746028, -147118738993288163742312, 48988597853073465932820, -12236035290567356418552,
|
| 64 |
+
2157574942881818312049, -239560589366324764716, 12600467236042756559
|
| 65 |
+
],
|
| 66 |
+
[
|
| 67 |
+
691668239157222107697, -5292843584961252933125, 30349492858024727686755, -126346544855927856134295,
|
| 68 |
+
399537307669842150996468, -991168450545135070835076, 1971629028083798845750380, -3191065388846318679544380,
|
| 69 |
+
4241614331208149947151790, -4654326468801478894406214, 4222756879776354065593786, -3161821089800186539248210,
|
| 70 |
+
1943018818982002395655620, -970350191086531368649620, 387739787034699092364924, -121059601023985433003532,
|
| 71 |
+
28462032496476316665705, -4740335757093710713245, 498669220956647866875, -24919383499187492303
|
| 72 |
+
],
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
_MOULTON_COEFFICIENTS = [
|
| 76 |
+
[], # order 0
|
| 77 |
+
[1],
|
| 78 |
+
[1, 1],
|
| 79 |
+
[5, 8, -1],
|
| 80 |
+
[9, 19, -5, 1],
|
| 81 |
+
[251, 646, -264, 106, -19],
|
| 82 |
+
[475, 1427, -798, 482, -173, 27],
|
| 83 |
+
[19087, 65112, -46461, 37504, -20211, 6312, -863],
|
| 84 |
+
[36799, 139849, -121797, 123133, -88547, 41499, -11351, 1375],
|
| 85 |
+
[1070017, 4467094, -4604594, 5595358, -5033120, 3146338, -1291214, 312874, -33953],
|
| 86 |
+
[2082753, 9449717, -11271304, 16002320, -17283646, 13510082, -7394032, 2687864, -583435, 57281],
|
| 87 |
+
[
|
| 88 |
+
134211265, 656185652, -890175549, 1446205080, -1823311566, 1710774528, -1170597042, 567450984, -184776195,
|
| 89 |
+
36284876, -3250433
|
| 90 |
+
],
|
| 91 |
+
[
|
| 92 |
+
262747265, 1374799219, -2092490673, 3828828885, -5519460582, 6043521486, -4963166514, 3007739418, -1305971115,
|
| 93 |
+
384709327, -68928781, 5675265
|
| 94 |
+
],
|
| 95 |
+
[
|
| 96 |
+
703604254357, 3917551216986, -6616420957428, 13465774256510, -21847538039895, 27345870698436, -26204344465152,
|
| 97 |
+
19058185652796, -10344711794985, 4063327863170, -1092096992268, 179842822566, -13695779093
|
| 98 |
+
],
|
| 99 |
+
[
|
| 100 |
+
1382741929621, 8153167962181, -15141235084110, 33928990133618, -61188680131285, 86180228689563, -94393338653892,
|
| 101 |
+
80101021029180, -52177910882661, 25620259777835, -9181635605134, 2268078814386, -345457086395, 24466579093
|
| 102 |
+
],
|
| 103 |
+
[
|
| 104 |
+
8164168737599, 50770967534864, -102885148956217, 251724894607936, -499547203754837, 781911618071632,
|
| 105 |
+
-963605400824733, 934600833490944, -710312834197347, 418551804601264, -187504936597931, 61759426692544,
|
| 106 |
+
-14110480969927, 1998759236336, -132282840127
|
| 107 |
+
],
|
| 108 |
+
[
|
| 109 |
+
16088129229375, 105145058757073, -230992163723849, 612744541065337, -1326978663058069, 2285168598349733,
|
| 110 |
+
-3129453071993581, 3414941728852893, -2966365730265699, 2039345879546643, -1096355235402331, 451403108933483,
|
| 111 |
+
-137515713789319, 29219384284087, -3867689367599, 240208245823
|
| 112 |
+
],
|
| 113 |
+
[
|
| 114 |
+
8092989203533249, 55415287221275246, -131240807912923110, 375195469874202430, -880520318434977010,
|
| 115 |
+
1654462865819232198, -2492570347928318318, 3022404969160106870, -2953729295811279360, 2320851086013919370,
|
| 116 |
+
-1455690451266780818, 719242466216944698, -273894214307914510, 77597639915764930, -15407325991235610,
|
| 117 |
+
1913813460537746, -111956703448001
|
| 118 |
+
],
|
| 119 |
+
[
|
| 120 |
+
15980174332775873, 114329243705491117, -290470969929371220, 890337710266029860, -2250854333681641520,
|
| 121 |
+
4582441343348851896, -7532171919277411636, 10047287575124288740, -10910555637627652470, 9644799218032932490,
|
| 122 |
+
-6913858539337636636, 3985516155854664396, -1821304040326216520, 645008976643217360, -170761422500096220,
|
| 123 |
+
31816981024600492, -3722582669836627, 205804074290625
|
| 124 |
+
],
|
| 125 |
+
[
|
| 126 |
+
12600467236042756559, 93965550344204933076, -255007751875033918095, 834286388106402145800,
|
| 127 |
+
-2260420115705863623660, 4956655592790542146968, -8827052559979384209108, 12845814402199484797800,
|
| 128 |
+
-15345231910046032448070, 15072781455122686545920, -12155867625610599812538, 8008520809622324571288,
|
| 129 |
+
-4269779992576330506540, 1814584564159445787240, -600505972582990474260, 149186846171741510136,
|
| 130 |
+
-26182538841925312881, 2895045518506940460, -151711881512390095
|
| 131 |
+
],
|
| 132 |
+
[
|
| 133 |
+
24919383499187492303, 193280569173472261637, -558160720115629395555, 1941395668950986461335,
|
| 134 |
+
-5612131802364455926260, 13187185898439270330756, -25293146116627869170796, 39878419226784442421820,
|
| 135 |
+
-51970649453670274135470, 56154678684618739939910, -50320851025594566473146, 37297227252822858381906,
|
| 136 |
+
-22726350407538133839300, 11268210124987992327060, -4474886658024166985340, 1389665263296211699212,
|
| 137 |
+
-325187970422032795497, 53935307402575440285, -5652892248087175675, 281550972898020815
|
| 138 |
+
],
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
_DIVISOR = [
|
| 142 |
+
None, 11, 2, 12, 24, 720, 1440, 60480, 120960, 3628800, 7257600, 479001600, 958003200, 2615348736000, 5230697472000,
|
| 143 |
+
31384184832000, 62768369664000, 32011868528640000, 64023737057280000, 51090942171709440000, 102181884343418880000
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
_MIN_ORDER = 4
|
| 147 |
+
_MAX_ORDER = 12
|
| 148 |
+
_MAX_ITERS = 4
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class AdamsBashforthMoulton(FixedGridODESolver):
|
| 152 |
+
|
| 153 |
+
def __init__(
|
| 154 |
+
self, func, y0, rtol=1e-3, atol=1e-4, implicit=True, max_iters=_MAX_ITERS, max_order=_MAX_ORDER, **kwargs
|
| 155 |
+
):
|
| 156 |
+
super(AdamsBashforthMoulton, self).__init__(func, y0, **kwargs)
|
| 157 |
+
|
| 158 |
+
self.rtol = rtol
|
| 159 |
+
self.atol = atol
|
| 160 |
+
self.implicit = implicit
|
| 161 |
+
self.max_iters = max_iters
|
| 162 |
+
self.max_order = int(min(max_order, _MAX_ORDER))
|
| 163 |
+
self.prev_f = collections.deque(maxlen=self.max_order - 1)
|
| 164 |
+
self.prev_t = None
|
| 165 |
+
|
| 166 |
+
def _update_history(self, t, f):
|
| 167 |
+
if self.prev_t is None or self.prev_t != t:
|
| 168 |
+
self.prev_f.appendleft(f)
|
| 169 |
+
self.prev_t = t
|
| 170 |
+
|
| 171 |
+
def step_func(self, func, t, dt, y):
|
| 172 |
+
self._update_history(t, func(t, y))
|
| 173 |
+
order = min(len(self.prev_f), self.max_order - 1)
|
| 174 |
+
if order < _MIN_ORDER - 1:
|
| 175 |
+
# Compute using RK4.
|
| 176 |
+
dy = rk_common.rk4_alt_step_func(func, t, dt, y, k1=self.prev_f[0])
|
| 177 |
+
return dy
|
| 178 |
+
else:
|
| 179 |
+
# Adams-Bashforth predictor.
|
| 180 |
+
bashforth_coeffs = _BASHFORTH_COEFFICIENTS[order]
|
| 181 |
+
ab_div = _DIVISOR[order]
|
| 182 |
+
dy = tuple(dt * _scaled_dot_product(1 / ab_div, bashforth_coeffs, f_) for f_ in zip(*self.prev_f))
|
| 183 |
+
|
| 184 |
+
# Adams-Moulton corrector.
|
| 185 |
+
if self.implicit:
|
| 186 |
+
moulton_coeffs = _MOULTON_COEFFICIENTS[order + 1]
|
| 187 |
+
am_div = _DIVISOR[order + 1]
|
| 188 |
+
delta = tuple(dt * _scaled_dot_product(1 / am_div, moulton_coeffs[1:], f_) for f_ in zip(*self.prev_f))
|
| 189 |
+
converged = False
|
| 190 |
+
for _ in range(self.max_iters):
|
| 191 |
+
dy_old = dy
|
| 192 |
+
f = func(t + dt, tuple(y_ + dy_ for y_, dy_ in zip(y, dy)))
|
| 193 |
+
dy = tuple(dt * (moulton_coeffs[0] / am_div) * f_ + delta_ for f_, delta_ in zip(f, delta))
|
| 194 |
+
converged = _has_converged(dy_old, dy, self.rtol, self.atol)
|
| 195 |
+
if converged:
|
| 196 |
+
break
|
| 197 |
+
if not converged:
|
| 198 |
+
print('Warning: Functional iteration did not converge. Solution may be incorrect.', file=sys.stderr)
|
| 199 |
+
self.prev_f.pop()
|
| 200 |
+
self._update_history(t, f)
|
| 201 |
+
return dy
|
| 202 |
+
|
| 203 |
+
@property
|
| 204 |
+
def order(self):
|
| 205 |
+
return 4
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class AdamsBashforth(AdamsBashforthMoulton):
|
| 209 |
+
|
| 210 |
+
def __init__(self, func, y0, **kwargs):
|
| 211 |
+
super(AdamsBashforth, self).__init__(func, y0, implicit=False, **kwargs)
|
ShapeID/DiffEqs/fixed_grid.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ShapeID.DiffEqs.solvers import FixedGridODESolver
|
| 2 |
+
import ShapeID.DiffEqs.rk_common as rk_common
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Euler(FixedGridODESolver):
|
| 6 |
+
|
| 7 |
+
def step_func(self, func, t, dt, y):
|
| 8 |
+
return tuple(dt * f_ for f_ in func(t, y))
|
| 9 |
+
|
| 10 |
+
@property
|
| 11 |
+
def order(self):
|
| 12 |
+
return 1
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Midpoint(FixedGridODESolver):
|
| 16 |
+
|
| 17 |
+
def step_func(self, func, t, dt, y):
|
| 18 |
+
y_mid = tuple(y_ + f_ * dt / 2 for y_, f_ in zip(y, func(t, y)))
|
| 19 |
+
return tuple(dt * f_ for f_ in func(t + dt / 2, y_mid))
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def order(self):
|
| 23 |
+
return 2
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RK4(FixedGridODESolver):
|
| 27 |
+
|
| 28 |
+
def step_func(self, func, t, dt, y):
|
| 29 |
+
return rk_common.rk4_alt_step_func(func, t, dt, y)
|
| 30 |
+
|
| 31 |
+
@property
|
| 32 |
+
def order(self):
|
| 33 |
+
return 4
|
ShapeID/DiffEqs/interp.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ShapeID.DiffEqs.misc import _convert_to_tensor, _dot_product
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
|
| 6 |
+
"""Fit coefficients for 4th order polynomial interpolation.
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
y0: function value at the start of the interval.
|
| 10 |
+
y1: function value at the end of the interval.
|
| 11 |
+
y_mid: function value at the mid-point of the interval.
|
| 12 |
+
f0: derivative value at the start of the interval.
|
| 13 |
+
f1: derivative value at the end of the interval.
|
| 14 |
+
dt: width of the interval.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
|
| 18 |
+
`p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
|
| 19 |
+
between 0 (start of interval) and 1 (end of interval).
|
| 20 |
+
"""
|
| 21 |
+
a = tuple(
|
| 22 |
+
_dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0_, f1_, y0_, y1_, y_mid_])
|
| 23 |
+
for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
|
| 24 |
+
)
|
| 25 |
+
b = tuple(
|
| 26 |
+
_dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0_, f1_, y0_, y1_, y_mid_])
|
| 27 |
+
for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
|
| 28 |
+
)
|
| 29 |
+
c = tuple(
|
| 30 |
+
_dot_product([-4 * dt, dt, -11, -5, 16], [f0_, f1_, y0_, y1_, y_mid_])
|
| 31 |
+
for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
|
| 32 |
+
)
|
| 33 |
+
d = tuple(dt * f0_ for f0_ in f0)
|
| 34 |
+
e = y0
|
| 35 |
+
return [a, b, c, d, e]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _interp_evaluate(coefficients, t0, t1, t):
|
| 39 |
+
"""Evaluate polynomial interpolation at the given time point.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
coefficients: list of Tensor coefficients as created by `interp_fit`.
|
| 43 |
+
t0: scalar float64 Tensor giving the start of the interval.
|
| 44 |
+
t1: scalar float64 Tensor giving the end of the interval.
|
| 45 |
+
t: scalar float64 Tensor giving the desired interpolation point.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Polynomial interpolation of the coefficients at time `t`.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
dtype = coefficients[0][0].dtype
|
| 52 |
+
device = coefficients[0][0].device
|
| 53 |
+
|
| 54 |
+
t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
|
| 55 |
+
t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
|
| 56 |
+
t = _convert_to_tensor(t, dtype=dtype, device=device)
|
| 57 |
+
|
| 58 |
+
assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)
|
| 59 |
+
x = ((t - t0) / (t1 - t0)).type(dtype).to(device)
|
| 60 |
+
|
| 61 |
+
xs = [torch.tensor(1).type(dtype).to(device), x]
|
| 62 |
+
for _ in range(2, len(coefficients)):
|
| 63 |
+
xs.append(xs[-1] * x)
|
| 64 |
+
|
| 65 |
+
return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients))
|
ShapeID/DiffEqs/misc.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _flatten(sequence):
|
| 6 |
+
flat = [p.contiguous().view(-1) for p in sequence]
|
| 7 |
+
return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _flatten_convert_none_to_zeros(sequence, like_sequence):
|
| 11 |
+
flat = [
|
| 12 |
+
p.contiguous().view(-1) if p is not None else torch.zeros_like(q).view(-1)
|
| 13 |
+
for p, q in zip(sequence, like_sequence)
|
| 14 |
+
]
|
| 15 |
+
return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _possibly_nonzero(x):
|
| 19 |
+
return isinstance(x, torch.Tensor) or x != 0
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _scaled_dot_product(scale, xs, ys):
|
| 23 |
+
"""Calculate a scaled, vector inner product between lists of Tensors."""
|
| 24 |
+
# Using _possibly_nonzero lets us avoid wasted computation.
|
| 25 |
+
return sum([(scale * x) * y for x, y in zip(xs, ys) if _possibly_nonzero(x) or _possibly_nonzero(y)])
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _dot_product(xs, ys):
|
| 29 |
+
"""Calculate the vector inner product between two lists of Tensors."""
|
| 30 |
+
return sum([x * y for x, y in zip(xs, ys)])
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _has_converged(y0, y1, rtol, atol):
|
| 34 |
+
"""Checks that each element is within the error tolerance."""
|
| 35 |
+
error_tol = tuple(atol + rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
|
| 36 |
+
error = tuple(torch.abs(y0_ - y1_) for y0_, y1_ in zip(y0, y1))
|
| 37 |
+
return all((error_ < error_tol_).all() for error_, error_tol_ in zip(error, error_tol))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _convert_to_tensor(a, dtype=None, device=None):
|
| 41 |
+
if not isinstance(a, torch.Tensor):
|
| 42 |
+
a = torch.tensor(a)
|
| 43 |
+
if dtype is not None:
|
| 44 |
+
a = a.type(dtype)
|
| 45 |
+
if device is not None:
|
| 46 |
+
a = a.to(device)
|
| 47 |
+
return a
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _is_finite(tensor):
|
| 51 |
+
_check = (tensor == float('inf')) + (tensor == float('-inf')) + torch.isnan(tensor)
|
| 52 |
+
return not _check.any()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _decreasing(t):
|
| 56 |
+
return (t[1:] < t[:-1]).all()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _assert_increasing(t):
|
| 60 |
+
assert (t[1:] > t[:-1]).all(), 't must be strictly increasing or decrasing'
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _is_iterable(inputs):
|
| 64 |
+
try:
|
| 65 |
+
iter(inputs)
|
| 66 |
+
return True
|
| 67 |
+
except TypeError:
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _norm(x):
|
| 72 |
+
"""Compute RMS norm."""
|
| 73 |
+
if torch.is_tensor(x):
|
| 74 |
+
return x.norm() / (x.numel()**0.5)
|
| 75 |
+
else:
|
| 76 |
+
return torch.sqrt(sum(x_.norm()**2 for x_ in x) / sum(x_.numel() for x_ in x))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _handle_unused_kwargs(solver, unused_kwargs):
|
| 80 |
+
if len(unused_kwargs) > 0:
|
| 81 |
+
warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _select_initial_step(fun, t0, y0, order, rtol, atol, f0=None):
|
| 85 |
+
"""Empirically select a good initial step.
|
| 86 |
+
|
| 87 |
+
The algorithm is described in [1]_.
|
| 88 |
+
|
| 89 |
+
Parameters
|
| 90 |
+
----------
|
| 91 |
+
fun : callable
|
| 92 |
+
Right-hand side of the system.
|
| 93 |
+
t0 : float
|
| 94 |
+
Initial value of the independent variable.
|
| 95 |
+
y0 : ndarray, shape (n,)
|
| 96 |
+
Initial value of the dependent variable.
|
| 97 |
+
direction : float
|
| 98 |
+
Integration direction.
|
| 99 |
+
order : float
|
| 100 |
+
Method order.
|
| 101 |
+
rtol : float
|
| 102 |
+
Desired relative tolerance.
|
| 103 |
+
atol : float
|
| 104 |
+
Desired absolute tolerance.
|
| 105 |
+
|
| 106 |
+
Returns
|
| 107 |
+
-------
|
| 108 |
+
h_abs : float
|
| 109 |
+
Absolute value of the suggested initial step.
|
| 110 |
+
|
| 111 |
+
References
|
| 112 |
+
----------
|
| 113 |
+
.. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
|
| 114 |
+
Equations I: Nonstiff Problems", Sec. II.4.
|
| 115 |
+
"""
|
| 116 |
+
t0 = t0.to(y0[0])
|
| 117 |
+
if f0 is None:
|
| 118 |
+
f0 = fun(t0, y0)
|
| 119 |
+
|
| 120 |
+
rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
|
| 121 |
+
atol = atol if _is_iterable(atol) else [atol] * len(y0)
|
| 122 |
+
|
| 123 |
+
scale = tuple(atol_ + torch.abs(y0_) * rtol_ for y0_, atol_, rtol_ in zip(y0, atol, rtol))
|
| 124 |
+
|
| 125 |
+
d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale))
|
| 126 |
+
d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale))
|
| 127 |
+
|
| 128 |
+
if max(d0).item() < 1e-5 or max(d1).item() < 1e-5:
|
| 129 |
+
h0 = torch.tensor(1e-6).to(t0)
|
| 130 |
+
else:
|
| 131 |
+
h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1))
|
| 132 |
+
|
| 133 |
+
y1 = tuple(y0_ + h0 * f0_ for y0_, f0_ in zip(y0, f0))
|
| 134 |
+
f1 = fun(t0 + h0, y1)
|
| 135 |
+
|
| 136 |
+
d2 = tuple(_norm((f1_ - f0_) / scale_) / h0 for f1_, f0_, scale_ in zip(f1, f0, scale))
|
| 137 |
+
|
| 138 |
+
if max(d1).item() <= 1e-15 and max(d2).item() <= 1e-15:
|
| 139 |
+
h1 = torch.max(torch.tensor(1e-6).to(h0), h0 * 1e-3)
|
| 140 |
+
else:
|
| 141 |
+
h1 = (0.01 / max(d1 + d2))**(1. / float(order + 1))
|
| 142 |
+
|
| 143 |
+
return torch.min(100 * h0, h1)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _compute_error_ratio(error_estimate, error_tol=None, rtol=None, atol=None, y0=None, y1=None):
|
| 147 |
+
if error_tol is None:
|
| 148 |
+
assert rtol is not None and atol is not None and y0 is not None and y1 is not None
|
| 149 |
+
rtol if _is_iterable(rtol) else [rtol] * len(y0)
|
| 150 |
+
atol if _is_iterable(atol) else [atol] * len(y0)
|
| 151 |
+
error_tol = tuple(
|
| 152 |
+
atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_))
|
| 153 |
+
for atol_, rtol_, y0_, y1_ in zip(atol, rtol, y0, y1)
|
| 154 |
+
)
|
| 155 |
+
error_ratio = tuple(error_estimate_ / error_tol_ for error_estimate_, error_tol_ in zip(error_estimate, error_tol))
|
| 156 |
+
mean_sq_error_ratio = tuple(torch.mean(error_ratio_ * error_ratio_) for error_ratio_ in error_ratio)
|
| 157 |
+
return mean_sq_error_ratio
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
|
| 161 |
+
"""Calculate the optimal size for the next step."""
|
| 162 |
+
mean_error_ratio = max(mean_error_ratio) # Compute step size based on highest ratio.
|
| 163 |
+
if mean_error_ratio == 0:
|
| 164 |
+
return last_step * ifactor
|
| 165 |
+
if mean_error_ratio < 1:
|
| 166 |
+
dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
|
| 167 |
+
error_ratio = torch.sqrt(mean_error_ratio).to(last_step)
|
| 168 |
+
exponent = torch.tensor(1 / order).to(last_step)
|
| 169 |
+
factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
|
| 170 |
+
return last_step / factor
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _check_inputs(func, y0, t):
|
| 174 |
+
tensor_input = False
|
| 175 |
+
if torch.is_tensor(y0):
|
| 176 |
+
tensor_input = True
|
| 177 |
+
y0 = (y0,)
|
| 178 |
+
_base_nontuple_func_ = func
|
| 179 |
+
func = lambda t, y: (_base_nontuple_func_(t, y[0]),)
|
| 180 |
+
assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
|
| 181 |
+
for y0_ in y0:
|
| 182 |
+
assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(type(y0_))
|
| 183 |
+
|
| 184 |
+
if _decreasing(t):
|
| 185 |
+
t = -t
|
| 186 |
+
_base_reverse_func = func
|
| 187 |
+
func = lambda t, y: tuple(-f_ for f_ in _base_reverse_func(-t, y))
|
| 188 |
+
|
| 189 |
+
for y0_ in y0:
|
| 190 |
+
if not torch.is_floating_point(y0_):
|
| 191 |
+
raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type()))
|
| 192 |
+
if not torch.is_floating_point(t):
|
| 193 |
+
raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))
|
| 194 |
+
|
| 195 |
+
return tensor_input, func, y0, t
|
ShapeID/DiffEqs/odeint.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ShapeID.DiffEqs.tsit5 import Tsit5Solver
|
| 2 |
+
from ShapeID.DiffEqs.dopri5 import Dopri5Solver
|
| 3 |
+
from ShapeID.DiffEqs.fixed_grid import Euler, Midpoint, RK4
|
| 4 |
+
from ShapeID.DiffEqs.fixed_adams import AdamsBashforth, AdamsBashforthMoulton
|
| 5 |
+
from ShapeID.DiffEqs.adams import VariableCoefficientAdamsBashforth
|
| 6 |
+
from ShapeID.DiffEqs.misc import _check_inputs
|
| 7 |
+
|
| 8 |
+
SOLVERS = {
|
| 9 |
+
'explicit_adams': AdamsBashforth,
|
| 10 |
+
'fixed_adams': AdamsBashforthMoulton,
|
| 11 |
+
'adams': VariableCoefficientAdamsBashforth,
|
| 12 |
+
'tsit5': Tsit5Solver,
|
| 13 |
+
'dopri5': Dopri5Solver,
|
| 14 |
+
'euler': Euler,
|
| 15 |
+
'midpoint': Midpoint,
|
| 16 |
+
'rk4': RK4,
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def odeint(func, y0, t, dt, step_size = None, rtol = 1e-7, atol = 1e-9, method = None, options = None):
|
| 21 |
+
"""Integrate a system of ordinary differential equations.
|
| 22 |
+
|
| 23 |
+
Solves the initial value problem for a non-stiff system of first order ODEs:
|
| 24 |
+
```
|
| 25 |
+
dy/dt = func(t, y), y(t[0]) = y0
|
| 26 |
+
```
|
| 27 |
+
where y is a Tensor of any shape.
|
| 28 |
+
|
| 29 |
+
Output dtypes and numerical precision are based on the dtypes of the inputs `y0`.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
func: Function that maps a Tensor holding the state `y` and a scalar Tensor
|
| 33 |
+
`t` into a Tensor of state derivatives with respect to time.
|
| 34 |
+
y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
|
| 35 |
+
have any floating point or complex dtype.
|
| 36 |
+
t: 1-D Tensor holding a sequence of time points for which to solve for
|
| 37 |
+
`y`. The initial time point should be the first element of this sequence,
|
| 38 |
+
and each time must be larger than the previous time. May have any floating
|
| 39 |
+
point dtype. Converted to a Tensor with float64 dtype.
|
| 40 |
+
rtol: optional float64 Tensor specifying an upper bound on relative error,
|
| 41 |
+
per element of `y`.
|
| 42 |
+
atol: optional float64 Tensor specifying an upper bound on absolute error,
|
| 43 |
+
per element of `y`.
|
| 44 |
+
method: optional string indicating the integration method to use.
|
| 45 |
+
options: optional dict of configuring options for the indicated integration
|
| 46 |
+
method. Can only be provided if a `method` is explicitly set.
|
| 47 |
+
name: Optional name for this operation.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
y: Tensor, where the first dimension corresponds to different
|
| 51 |
+
time points. Contains the solved value of y for each desired time point in
|
| 52 |
+
`t`, with the initial value `y0` being the first element along the first
|
| 53 |
+
dimension.
|
| 54 |
+
|
| 55 |
+
Raises:
|
| 56 |
+
ValueError: if an invalid `method` is provided.
|
| 57 |
+
TypeError: if `options` is supplied without `method`, or if `t` or `y0` has
|
| 58 |
+
an invalid dtype.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
tensor_input, func, y0, t = _check_inputs(func, y0, t)
|
| 62 |
+
|
| 63 |
+
if options and method is None:
|
| 64 |
+
raise ValueError('cannot supply `options` without specifying `method`')
|
| 65 |
+
|
| 66 |
+
if method is None:
|
| 67 |
+
method = 'dopri5'
|
| 68 |
+
|
| 69 |
+
#solver = SOLVERS[method](func, y0, rtol = rtol, atol = atol, **options)
|
| 70 |
+
solver = SOLVERS[method](func, y0, rtol = rtol, atol = atol, dt = dt, options = options)
|
| 71 |
+
solution = solver.integrate(t)
|
| 72 |
+
|
| 73 |
+
if tensor_input:
|
| 74 |
+
solution = solution[0]
|
| 75 |
+
return solution
|
ShapeID/DiffEqs/pde.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ported from https://github.com/pvigier/perlin-numpy
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def gradient_f(X, batched = False, delta_lst = [1., 1., 1.]):
|
| 14 |
+
'''
|
| 15 |
+
Compute gradient of a torch tensor "X" in each direction
|
| 16 |
+
Upper-boundaries: Backward Difference
|
| 17 |
+
Non-boundaries & Upper-boundaries: Forward Difference
|
| 18 |
+
if X is batched: (n_batch, ...);
|
| 19 |
+
else: (...)
|
| 20 |
+
'''
|
| 21 |
+
device = X.device
|
| 22 |
+
dim = len(X.size()) - 1 if batched else len(X.size())
|
| 23 |
+
#print(batched)
|
| 24 |
+
#print(dim)
|
| 25 |
+
if dim == 1:
|
| 26 |
+
#print('dim = 1')
|
| 27 |
+
dX = torch.zeros(X.size(), dtype = torch.float, device = device)
|
| 28 |
+
X = X.permute(1, 0) if batched else X
|
| 29 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 30 |
+
dX[-1] = X[-1] - X[-2] # Backward Difference
|
| 31 |
+
dX[:-1] = X[1:] - X[:-1] # Forward Difference
|
| 32 |
+
|
| 33 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 34 |
+
dX /= delta_lst[0]
|
| 35 |
+
elif dim == 2:
|
| 36 |
+
#print('dim = 2')
|
| 37 |
+
dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
|
| 38 |
+
X = X.permute(1, 2, 0) if batched else X
|
| 39 |
+
dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
|
| 40 |
+
dX[-1, :, 0] = X[-1, :] - X[-2, :] # Backward Difference
|
| 41 |
+
dX[:-1, :, 0] = X[1:] - X[:-1] # Forward Difference
|
| 42 |
+
|
| 43 |
+
dX[:, -1, 1] = X[:, -1] - X[:, -2] # Backward Difference
|
| 44 |
+
dX[:, :-1, 1] = X[:, 1:] - X[:, :-1] # Forward Difference
|
| 45 |
+
|
| 46 |
+
dX = dX.permute(3, 0, 1, 2) if batched else dX
|
| 47 |
+
dX[..., 0] /= delta_lst[0]
|
| 48 |
+
dX[..., 1] /= delta_lst[1]
|
| 49 |
+
elif dim == 3:
|
| 50 |
+
#print('dim = 3')
|
| 51 |
+
dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
|
| 52 |
+
X = X.permute(1, 2, 3, 0) if batched else X
|
| 53 |
+
dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
|
| 54 |
+
dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :] # Backward Difference
|
| 55 |
+
dX[:-1, :, :, 0] = X[1:] - X[:-1] # Forward Difference
|
| 56 |
+
|
| 57 |
+
dX[:, -1, :, 1] = X[:, -1] - X[:, -2] # Backward Difference
|
| 58 |
+
dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1] # Forward Difference
|
| 59 |
+
|
| 60 |
+
dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2] # Backward Difference
|
| 61 |
+
dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1] # Forward Difference
|
| 62 |
+
|
| 63 |
+
dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
|
| 64 |
+
dX[..., 0] /= delta_lst[0]
|
| 65 |
+
dX[..., 1] /= delta_lst[1]
|
| 66 |
+
dX[..., 2] /= delta_lst[2]
|
| 67 |
+
return dX
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def gradient_b(X, batched = False, delta_lst = [1., 1., 1.]):
|
| 71 |
+
'''
|
| 72 |
+
Compute gradient of a torch tensor "X" in each direction
|
| 73 |
+
Non-boundaries & Upper-boundaries: Backward Difference
|
| 74 |
+
Lower-boundaries: Forward Difference
|
| 75 |
+
if X is batched: (n_batch, ...);
|
| 76 |
+
else: (...)
|
| 77 |
+
'''
|
| 78 |
+
device = X.device
|
| 79 |
+
dim = len(X.size()) - 1 if batched else len(X.size())
|
| 80 |
+
#print(batched)
|
| 81 |
+
#print(dim)
|
| 82 |
+
if dim == 1:
|
| 83 |
+
#print('dim = 1')
|
| 84 |
+
dX = torch.zeros(X.size(), dtype = torch.float, device = device)
|
| 85 |
+
X = X.permute(1, 0) if batched else X
|
| 86 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 87 |
+
dX[1:] = X[1:] - X[:-1] # Backward Difference
|
| 88 |
+
dX[0] = X[1] - X[0] # Forward Difference
|
| 89 |
+
|
| 90 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 91 |
+
dX /= delta_lst[0]
|
| 92 |
+
elif dim == 2:
|
| 93 |
+
#print('dim = 2')
|
| 94 |
+
dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
|
| 95 |
+
X = X.permute(1, 2, 0) if batched else X
|
| 96 |
+
dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
|
| 97 |
+
dX[1:, :, 0] = X[1:, :] - X[:-1, :] # Backward Difference
|
| 98 |
+
dX[0, :, 0] = X[1] - X[0] # Forward Difference
|
| 99 |
+
|
| 100 |
+
dX[:, 1:, 1] = X[:, 1:] - X[:, :-1] # Backward Difference
|
| 101 |
+
dX[:, 0, 1] = X[:, 1] - X[:, 0] # Forward Difference
|
| 102 |
+
|
| 103 |
+
dX = dX.permute(3, 0, 1, 2) if batched else dX
|
| 104 |
+
dX[..., 0] /= delta_lst[0]
|
| 105 |
+
dX[..., 1] /= delta_lst[1]
|
| 106 |
+
elif dim == 3:
|
| 107 |
+
#print('dim = 3')
|
| 108 |
+
dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
|
| 109 |
+
X = X.permute(1, 2, 3, 0) if batched else X
|
| 110 |
+
dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
|
| 111 |
+
dX[1:, :, :, 0] = X[1:, :, :] - X[:-1, :, :] # Backward Difference
|
| 112 |
+
dX[0, :, :, 0] = X[1] - X[0] # Forward Difference
|
| 113 |
+
|
| 114 |
+
dX[:, 1:, :, 1] = X[:, 1:] - X[:, :-1] # Backward Difference
|
| 115 |
+
dX[:, 0, :, 1] = X[:, 1] - X[:, 0] # Forward Difference
|
| 116 |
+
|
| 117 |
+
dX[:, :, 1:, 2] = X[:, :, 1:] - X[:, :, :-1] # Backward Difference
|
| 118 |
+
dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0] # Forward Difference
|
| 119 |
+
|
| 120 |
+
dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
|
| 121 |
+
dX[..., 0] /= delta_lst[0]
|
| 122 |
+
dX[..., 1] /= delta_lst[1]
|
| 123 |
+
dX[..., 2] /= delta_lst[2]
|
| 124 |
+
return dX
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def gradient_c(X, batched = False, delta_lst = [1., 1., 1.]):
|
| 128 |
+
'''
|
| 129 |
+
Compute gradient of a torch tensor "X" in each direction
|
| 130 |
+
Non-boundaries: Central Difference
|
| 131 |
+
Upper-boundaries: Backward Difference
|
| 132 |
+
Lower-boundaries: Forward Difference
|
| 133 |
+
if X is batched: (n_batch, ...);
|
| 134 |
+
else: (...)
|
| 135 |
+
'''
|
| 136 |
+
device = X.device
|
| 137 |
+
dim = len(X.size()) - 1 if batched else len(X.size())
|
| 138 |
+
#print(X.size())
|
| 139 |
+
#print(batched)
|
| 140 |
+
#print(dim)
|
| 141 |
+
if dim == 1:
|
| 142 |
+
#print('dim = 1')
|
| 143 |
+
dX = torch.zeros(X.size(), dtype = torch.float, device = device)
|
| 144 |
+
X = X.permute(1, 0) if batched else X
|
| 145 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 146 |
+
dX[1:-1] = (X[2:] - X[:-2]) / 2 # Central Difference
|
| 147 |
+
dX[0] = X[1] - X[0] # Forward Difference
|
| 148 |
+
dX[-1] = X[-1] - X[-2] # Backward Difference
|
| 149 |
+
|
| 150 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 151 |
+
dX /= delta_lst[0]
|
| 152 |
+
elif dim == 2:
|
| 153 |
+
#print('dim = 2')
|
| 154 |
+
dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
|
| 155 |
+
X = X.permute(1, 2, 0) if batched else X
|
| 156 |
+
dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
|
| 157 |
+
dX[1:-1, :, 0] = (X[2:, :] - X[:-2, :]) / 2
|
| 158 |
+
dX[0, :, 0] = X[1] - X[0]
|
| 159 |
+
dX[-1, :, 0] = X[-1] - X[-2]
|
| 160 |
+
dX[:, 1:-1, 1] = (X[:, 2:] - X[:, :-2]) / 2
|
| 161 |
+
dX[:, 0, 1] = X[:, 1] - X[:, 0]
|
| 162 |
+
dX[:, -1, 1] = X[:, -1] - X[:, -2]
|
| 163 |
+
|
| 164 |
+
dX = dX.permute(3, 0, 1, 2) if batched else dX
|
| 165 |
+
dX[..., 0] /= delta_lst[0]
|
| 166 |
+
dX[..., 1] /= delta_lst[1]
|
| 167 |
+
elif dim == 3:
|
| 168 |
+
#print('dim = 3')
|
| 169 |
+
dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
|
| 170 |
+
X = X.permute(1, 2, 3, 0) if batched else X
|
| 171 |
+
dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
|
| 172 |
+
dX[1:-1, :, :, 0] = (X[2:, :, :] - X[:-2, :, :]) / 2
|
| 173 |
+
dX[0, :, :, 0] = X[1] - X[0]
|
| 174 |
+
dX[-1, :, :, 0] = X[-1] - X[-2]
|
| 175 |
+
dX[:, 1:-1, :, 1] = (X[:, 2:, :] - X[:, :-2, :]) / 2
|
| 176 |
+
dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
|
| 177 |
+
dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
|
| 178 |
+
dX[:, :, 1:-1, 2] = (X[:, :, 2:] - X[:, :, :-2]) / 2
|
| 179 |
+
dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
|
| 180 |
+
dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]
|
| 181 |
+
|
| 182 |
+
dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
|
| 183 |
+
dX[..., 0] /= delta_lst[0]
|
| 184 |
+
dX[..., 1] /= delta_lst[1]
|
| 185 |
+
dX[..., 2] /= delta_lst[2]
|
| 186 |
+
return dX
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def gradient_c_numpy(X, batched = False, delta_lst = [1., 1., 1.]):
|
| 190 |
+
'''
|
| 191 |
+
Compute gradient of a Numpy array "X" in each direction
|
| 192 |
+
Non-boundaries: Central Difference
|
| 193 |
+
Upper-boundaries: Backward Difference
|
| 194 |
+
Lower-boundaries: Forward Difference
|
| 195 |
+
if X is batched: (n_batch, ...);
|
| 196 |
+
else: (...)
|
| 197 |
+
'''
|
| 198 |
+
dim = len(X.shape) - 1 if batched else len(X.shape)
|
| 199 |
+
#print(dim)
|
| 200 |
+
if dim == 1:
|
| 201 |
+
#print('dim = 1')
|
| 202 |
+
X = np.transpose(X, (1, 0)) if batched else X
|
| 203 |
+
dX = np.zeros(X.shapee).astype(float)
|
| 204 |
+
dX[1:-1] = (X[2:] - X[:-2]) / 2 # Central Difference
|
| 205 |
+
dX[0] = X[1] - X[0] # Forward Difference
|
| 206 |
+
dX[-1] = X[-1] - X[-2] # Backward Difference
|
| 207 |
+
|
| 208 |
+
dX = np.transpose(X, (1, 0)) if batched else dX
|
| 209 |
+
dX /= delta_lst[0]
|
| 210 |
+
elif dim == 2:
|
| 211 |
+
#print('dim = 2')
|
| 212 |
+
dX = np.zeros(X.shape + tuple([2])).astype(float)
|
| 213 |
+
X = np.transpose(X, (1, 2, 0)) if batched else X
|
| 214 |
+
dX = np.transpose(dX, (1, 2, 3, 0)) if batched else dX # put batch to last dim
|
| 215 |
+
dX[1:-1, :, 0] = (X[2:, :] - X[:-2, :]) / 2
|
| 216 |
+
dX[0, :, 0] = X[1] - X[0]
|
| 217 |
+
dX[-1, :, 0] = X[-1] - X[-2]
|
| 218 |
+
dX[:, 1:-1, 1] = (X[:, 2:] - X[:, :-2]) / 2
|
| 219 |
+
dX[:, 0, 1] = X[:, 1] - X[:, 0]
|
| 220 |
+
dX[:, -1, 1] = X[:, -1] - X[:, -2]
|
| 221 |
+
|
| 222 |
+
dX = np.transpose(dX, (3, 0, 1, 2)) if batched else dX
|
| 223 |
+
dX[..., 0] /= delta_lst[0]
|
| 224 |
+
dX[..., 1] /= delta_lst[1]
|
| 225 |
+
elif dim == 3:
|
| 226 |
+
#print('dim = 3')
|
| 227 |
+
dX = np.zeros(X.shape + tuple([3])).astype(float)
|
| 228 |
+
X = np.transpose(X, (1, 2, 3, 0)) if batched else X
|
| 229 |
+
dX = np.transpose(dX, (1, 2, 3, 4, 0)) if batched else dX # put batch to last dim
|
| 230 |
+
dX[1:-1, :, :, 0] = (X[2:, :, :] - X[:-2, :, :]) / 2
|
| 231 |
+
dX[0, :, :, 0] = X[1] - X[0]
|
| 232 |
+
dX[-1, :, :, 0] = X[-1] - X[-2]
|
| 233 |
+
dX[:, 1:-1, :, 1] = (X[:, 2:, :] - X[:, :-2, :]) / 2
|
| 234 |
+
dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
|
| 235 |
+
dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
|
| 236 |
+
dX[:, :, 1:-1, 2] = (X[:, :, 2:] - X[:, :, :-2]) / 2
|
| 237 |
+
dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
|
| 238 |
+
dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]
|
| 239 |
+
|
| 240 |
+
dX = np.transpose(dX, (4, 0, 1, 2, 3)) if batched else dX
|
| 241 |
+
dX[..., 0] /= delta_lst[0]
|
| 242 |
+
dX[..., 1] /= delta_lst[1]
|
| 243 |
+
dX[..., 2] /= delta_lst[2]
|
| 244 |
+
return dX
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def gradient_f_numpy(X, batched = False, delta_lst = [1., 1., 1.]):
|
| 248 |
+
'''
|
| 249 |
+
Compute gradient of a torch tensor "X" in each direction
|
| 250 |
+
Upper-boundaries: Backward Difference
|
| 251 |
+
Non-boundaries & Upper-boundaries: Forward Difference
|
| 252 |
+
if X is batched: (n_batch, ...);
|
| 253 |
+
else: (...)
|
| 254 |
+
'''
|
| 255 |
+
dim = len(X.shape) - 1 if batched else len(X.shape)
|
| 256 |
+
#print(dim)
|
| 257 |
+
if dim == 1:
|
| 258 |
+
#print('dim = 1')
|
| 259 |
+
X = np.transpose(X, (1, 0)) if batched else X
|
| 260 |
+
dX = np.zeros(X.shapee).astype(float)
|
| 261 |
+
dX[-1] = X[-1] - X[-2] # Backward Difference
|
| 262 |
+
dX[:-1] = X[1:] - X[:-1] # Forward Difference
|
| 263 |
+
|
| 264 |
+
dX = np.transpose(X, (1, 0)) if batched else dX
|
| 265 |
+
dX /= delta_lst[0]
|
| 266 |
+
elif dim == 2:
|
| 267 |
+
#print('dim = 2')
|
| 268 |
+
dX = np.zeros(X.shape + tuple([2])).astype(float)
|
| 269 |
+
X = np.transpose(X, (1, 2, 0)) if batched else X
|
| 270 |
+
dX = np.transpose(dX, (1, 2, 3, 0)) if batched else dX # put batch to last dim
|
| 271 |
+
dX[-1, :, 0] = X[-1, :] - X[-2, :] # Backward Difference
|
| 272 |
+
dX[:-1, :, 0] = X[1:] - X[:-1] # Forward Difference
|
| 273 |
+
|
| 274 |
+
dX[:, -1, 1] = X[:, -1] - X[:, -2] # Backward Difference
|
| 275 |
+
dX[:, :-1, 1] = X[:, 1:] - X[:, :-1] # Forward Difference
|
| 276 |
+
|
| 277 |
+
dX = np.transpose(dX, (3, 0, 1, 2)) if batched else dX
|
| 278 |
+
dX[..., 0] /= delta_lst[0]
|
| 279 |
+
dX[..., 1] /= delta_lst[1]
|
| 280 |
+
elif dim == 3:
|
| 281 |
+
#print('dim = 3')
|
| 282 |
+
dX = np.zeros(X.shape + tuple([3])).astype(float)
|
| 283 |
+
X = np.transpose(X, (1, 2, 3, 0)) if batched else X
|
| 284 |
+
dX = np.transpose(dX, (1, 2, 3, 4, 0)) if batched else dX # put batch to last dim
|
| 285 |
+
dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :] # Backward Difference
|
| 286 |
+
dX[:-1, :, :, 0] = X[1:] - X[:-1] # Forward Difference
|
| 287 |
+
|
| 288 |
+
dX[:, -1, :, 1] = X[:, -1] - X[:, -2] # Backward Difference
|
| 289 |
+
dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1] # Forward Difference
|
| 290 |
+
|
| 291 |
+
dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2] # Backward Difference
|
| 292 |
+
dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1] # Forward Difference
|
| 293 |
+
|
| 294 |
+
dX = np.transpose(dX, (4, 0, 1, 2, 3)) if batched else dX
|
| 295 |
+
dX[..., 0] /= delta_lst[0]
|
| 296 |
+
dX[..., 1] /= delta_lst[1]
|
| 297 |
+
dX[..., 2] /= delta_lst[2]
|
| 298 |
+
return dX
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class Upwind(object):
|
| 302 |
+
'''
|
| 303 |
+
Backward if > 0, forward if <= 0
|
| 304 |
+
'''
|
| 305 |
+
def __init__(self, U, data_spacing = [1., 1, 1.], batched = True):
|
| 306 |
+
self.U = U # (s, r, c)
|
| 307 |
+
self.batched = batched
|
| 308 |
+
self.data_spacing = data_spacing
|
| 309 |
+
self.dim = len(self.U.size()) - 1 if batched else len(self.U.size())
|
| 310 |
+
self.I = torch.ones(self.U.size(), dtype = torch.float, device = U.device)
|
| 311 |
+
|
| 312 |
+
def dX(self, FGx):
|
| 313 |
+
dXf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 0]
|
| 314 |
+
dXb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 0]
|
| 315 |
+
Xflag = (FGx > 0).float()
|
| 316 |
+
return dXf * (self.I - Xflag) + dXb * Xflag
|
| 317 |
+
|
| 318 |
+
def dY(self, FGy):
|
| 319 |
+
dYf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 1]
|
| 320 |
+
dYb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 1]
|
| 321 |
+
Yflag = (FGy > 0).float()
|
| 322 |
+
return dYf * (self.I - Yflag) + dYb * Yflag
|
| 323 |
+
|
| 324 |
+
def dZ(self, FGz):
|
| 325 |
+
dZf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 2]
|
| 326 |
+
dZb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 2]
|
| 327 |
+
Zflag = (FGz > 0).float()
|
| 328 |
+
return dZf * (self.I - Zflag) + dZb * Zflag
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class AdvDiffPartial(nn.Module):
|
| 332 |
+
def __init__(self, data_spacing, device):
|
| 333 |
+
super(AdvDiffPartial, self).__init__()
|
| 334 |
+
self.dimension = len(data_spacing) # (slc, row, col)
|
| 335 |
+
self.device = device
|
| 336 |
+
self.data_spacing = data_spacing
|
| 337 |
+
|
| 338 |
+
@property
|
| 339 |
+
def Grad_Ds(self):
|
| 340 |
+
return {
|
| 341 |
+
'constant': self.Grad_constantD,
|
| 342 |
+
'scalar': self.Grad_scalarD,
|
| 343 |
+
'diag': self.Grad_diagD,
|
| 344 |
+
'full': self.Grad_fullD,
|
| 345 |
+
'full_dual': self.Grad_fullD,
|
| 346 |
+
'full_spectral':self.Grad_fullD,
|
| 347 |
+
'full_cholesky': self.Grad_fullD,
|
| 348 |
+
'full_symmetric': self.Grad_fullD
|
| 349 |
+
}
|
| 350 |
+
@property
|
| 351 |
+
def Grad_Vs(self):
|
| 352 |
+
return {
|
| 353 |
+
'constant': self.Grad_constantV,
|
| 354 |
+
'scalar': self.Grad_scalarV,
|
| 355 |
+
'vector': self.Grad_vectorV, # For general V w/o div-free TODO self.Grad_vectorV
|
| 356 |
+
'vector_div_free': self.Grad_div_free_vectorV,
|
| 357 |
+
'vector_div_free_clebsch': self.Grad_div_free_vectorV,
|
| 358 |
+
'vector_div_free_stream': self.Grad_div_free_vectorV,
|
| 359 |
+
'vector_div_free_stream_gauge': self.Grad_div_free_vectorV,
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
def Grad_constantD(self, C, Dlst):
|
| 363 |
+
if self.dimension == 1:
|
| 364 |
+
return Dlst['D'] * (self.ddXc(C))
|
| 365 |
+
elif self.dimension == 2:
|
| 366 |
+
return Dlst['D'] * (self.ddXc(C) + self.ddYc(C))
|
| 367 |
+
elif self.dimension == 3:
|
| 368 |
+
return Dlst['D'] * (self.ddXc(C) + self.ddYc(C) + self.ddZc(C))
|
| 369 |
+
|
| 370 |
+
def Grad_constant_tensorD(self, C, Dlst):
|
| 371 |
+
if self.dimension == 1:
|
| 372 |
+
raise NotImplementedError
|
| 373 |
+
elif self.dimension == 2:
|
| 374 |
+
dC_c = self.dc(C)
|
| 375 |
+
dC_f = self.df(C)
|
| 376 |
+
return Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
|
| 377 |
+
Dlst['Dxy'] * self.dXb(dC_f[..., 1]) + Dlst['Dxy'] * self.dYb(dC_f[..., 0]) +\
|
| 378 |
+
Dlst['Dyy'] * self.dYb(dC_f[..., 1])
|
| 379 |
+
elif self.dimension == 3:
|
| 380 |
+
dC_c = self.dc(C)
|
| 381 |
+
dC_f = self.df(C)
|
| 382 |
+
return Dlst['Dxx'] * self.dXb(dC_f[..., 0]) + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) + Dlst['Dzz'] * self.dZb(dC_f[..., 2]) + \
|
| 383 |
+
Dlst['Dxy'] * (self.dXb(dC_f[..., 1]) + self.dYb(dC_f[..., 0])) + \
|
| 384 |
+
Dlst['Dyz'] * (self.dYb(dC_f[..., 2]) + self.dZb(dC_f[..., 1])) + \
|
| 385 |
+
Dlst['Dxz'] * (self.dZb(dC_f[..., 0]) + self.dXb(dC_f[..., 2]))
|
| 386 |
+
|
| 387 |
+
def Grad_scalarD(self, C, Dlst): # batch_C: (batch_size, (slc), row, col)
|
| 388 |
+
# Expanded version: \nabla (D \nabla C) => \nabla D \cdot \nabla C (part (a)) + D \Delta C (part (b)) #
|
| 389 |
+
# NOTE: Work better than Central Differences !!! #
|
| 390 |
+
# Nested Forward-Backward Difference Scheme in part (b)#
|
| 391 |
+
if self.dimension == 1:
|
| 392 |
+
dC = gradient_c(C, batched = True, delta_lst = self.data_spacing)
|
| 393 |
+
return gradient_c(Dlst['D'], batched = True, delta_lst = self.data_spacing) * dC + \
|
| 394 |
+
Dlst['D'] * gradient_c(dC, batched = True, delta_lst = self.data_spacing)
|
| 395 |
+
else: # (dimension = 2 or 3)
|
| 396 |
+
dC_c = gradient_c(C, batched = True, delta_lst = self.data_spacing)
|
| 397 |
+
dC_f = gradient_f(C, batched = True, delta_lst = self.data_spacing)
|
| 398 |
+
dD_c = gradient_c(Dlst['D'], batched = True, delta_lst = self.data_spacing)
|
| 399 |
+
out = (dD_c * dC_c).sum(-1)
|
| 400 |
+
for dim in range(dC_f.size(-1)):
|
| 401 |
+
out += Dlst['D'] * gradient_b(dC_f[..., dim], batched = True, delta_lst = self.data_spacing)[..., dim]
|
| 402 |
+
return out
|
| 403 |
+
|
| 404 |
+
def Grad_diagD(self, C, Dlst):
|
| 405 |
+
# Expanded version #
|
| 406 |
+
if self.dimension == 1:
|
| 407 |
+
raise NotImplementedError('diag_D is not supported for 1D version of diffusivity')
|
| 408 |
+
elif self.dimension == 2:
|
| 409 |
+
dC_c = self.dc(C)
|
| 410 |
+
dC_f = self.df(C)
|
| 411 |
+
return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
|
| 412 |
+
self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1])
|
| 413 |
+
elif self.dimension == 3:
|
| 414 |
+
dC_c = self.dc(C)
|
| 415 |
+
dC_f = self.df(C)
|
| 416 |
+
return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
|
| 417 |
+
self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) +\
|
| 418 |
+
self.dZc(Dlst['Dzz']) * dC_c[..., 2] + Dlst['Dzz'] * self.dZb(dC_f[..., 2])
|
| 419 |
+
|
| 420 |
+
def Grad_fullD(self, C, Dlst):
|
| 421 |
+
# Expanded version #
|
| 422 |
+
'''https://github.com/uncbiag/PIANOinD/blob/master/Doc/PIANOinD.pdf'''
|
| 423 |
+
if self.dimension == 1:
|
| 424 |
+
raise NotImplementedError('full_D is not supported for 1D version of diffusivity')
|
| 425 |
+
elif self.dimension == 2:
|
| 426 |
+
dC_c = self.dc(C)
|
| 427 |
+
dC_f = self.df(C)
|
| 428 |
+
return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
|
| 429 |
+
self.dXc(Dlst['Dxy']) * dC_c[..., 1] + Dlst['Dxy'] * self.dXb(dC_f[..., 1]) +\
|
| 430 |
+
self.dYc(Dlst['Dxy']) * dC_c[..., 0] + Dlst['Dxy'] * self.dYb(dC_f[..., 0]) +\
|
| 431 |
+
self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1])
|
| 432 |
+
elif self.dimension == 3:
|
| 433 |
+
dC_c = self.dc(C)
|
| 434 |
+
dC_f = self.df(C)
|
| 435 |
+
return (self.dXc(Dlst['Dxx']) + self.dYc(Dlst['Dxy']) + self.dZc(Dlst['Dxz'])) * dC_c[..., 0] + \
|
| 436 |
+
(self.dXc(Dlst['Dxy']) + self.dYc(Dlst['Dyy']) + self.dZc(Dlst['Dyz'])) * dC_c[..., 1] + \
|
| 437 |
+
(self.dXc(Dlst['Dxz']) + self.dYc(Dlst['Dyz']) + self.dZc(Dlst['Dzz'])) * dC_c[..., 2] + \
|
| 438 |
+
Dlst['Dxx'] * self.dXb(dC_f[..., 0]) + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) + Dlst['Dzz'] * self.dZb(dC_f[..., 2]) + \
|
| 439 |
+
Dlst['Dxy'] * (self.dXb(dC_f[..., 1]) + self.dYb(dC_f[..., 0])) + \
|
| 440 |
+
Dlst['Dyz'] * (self.dYb(dC_f[..., 2]) + self.dZb(dC_f[..., 1])) + \
|
| 441 |
+
Dlst['Dxz'] * (self.dZb(dC_f[..., 0]) + self.dXb(dC_f[..., 2]))
|
| 442 |
+
|
| 443 |
+
def Grad_constantV(self, C, Vlst):
|
| 444 |
+
if len(Vlst['V'].size()) == 1:
|
| 445 |
+
if self.dimension == 1:
|
| 446 |
+
return - Vlst['V'] * self.dXb(C) if Vlst['V'] > 0 else - Vlst['V'] * self.dXf(C)
|
| 447 |
+
elif self.dimension == 2:
|
| 448 |
+
return - Vlst['V'] * (self.dXb(C) + self.dYb(C)) if Vlst['V'] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C))
|
| 449 |
+
elif self.dimension == 3:
|
| 450 |
+
return - Vlst['V'] * (self.dXb(C) + self.dYb(C) + self.dZb(C)) if Vlst['V'] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C) + self.dZf(C))
|
| 451 |
+
else:
|
| 452 |
+
if self.dimension == 1:
|
| 453 |
+
return - Vlst['V'] * self.dXb(C) if Vlst['V'][0, 0] > 0 else - Vlst['V'] * self.dXf(C)
|
| 454 |
+
elif self.dimension == 2:
|
| 455 |
+
return - Vlst['V'] * (self.dXb(C) + self.dYb(C)) if Vlst['V'][0, 0, 0] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C))
|
| 456 |
+
elif self.dimension == 3:
|
| 457 |
+
return - Vlst['V'] * (self.dXb(C) + self.dYb(C) + self.dZb(C)) if Vlst['V'][0, 0, 0, 0] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C) + self.dZf(C))
|
| 458 |
+
|
| 459 |
+
def Grad_constant_vectorV(self, C, Vlst):
|
| 460 |
+
if self.dimension == 1:
|
| 461 |
+
raise NotImplementedError
|
| 462 |
+
elif self.dimension == 2:
|
| 463 |
+
out_x = - Vlst['Vx'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vx'][0, 0, 0] > 0 else - Vlst['Vx'] * (self.dXf(C) + self.dYf(C))
|
| 464 |
+
out_y = - Vlst['Vy'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vy'][0, 0, 0] > 0 else - Vlst['Vy'] * (self.dXf(C) + self.dYf(C))
|
| 465 |
+
return out_x + out_y
|
| 466 |
+
elif self.dimension == 3:
|
| 467 |
+
out_x = - Vlst['Vx'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vx'][0, 0, 0] > 0 else - Vlst['Vx'] * (self.dXf(C) + self.dYf(C))
|
| 468 |
+
out_y = - Vlst['Vy'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vy'][0, 0, 0] > 0 else - Vlst['Vy'] * (self.dXf(C) + self.dYf(C))
|
| 469 |
+
out_z = - Vlst['Vz'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vz'][0, 0, 0] > 0 else - Vlst['Vz'] * (self.dXf(C) + self.dYf(C))
|
| 470 |
+
return out_x + out_y + out_z
|
| 471 |
+
|
| 472 |
+
def Grad_SimscalarV(self, C, Vlst):
|
| 473 |
+
V = Vlst['V']
|
| 474 |
+
Upwind_C = Upwind(C, self.data_spacing)
|
| 475 |
+
if self.dimension == 1:
|
| 476 |
+
C_x = Upwind_C.dX(V)
|
| 477 |
+
return - V * C_x
|
| 478 |
+
if self.dimension == 2:
|
| 479 |
+
C_x, C_y = Upwind_C.dX(V), Upwind_C.dY(V)
|
| 480 |
+
return - V * (C_x + C_y)
|
| 481 |
+
if self.dimension == 3:
|
| 482 |
+
C_x, C_y, C_z = Upwind_C.dX(V), Upwind_C.dY(V), Upwind_C.dZ(V)
|
| 483 |
+
return - V * (C_x + C_y + C_z)
|
| 484 |
+
|
| 485 |
+
def Grad_scalarV(self, C, Vlst):
|
| 486 |
+
V = Vlst['V']
|
| 487 |
+
Upwind_C = Upwind(C, self.data_spacing)
|
| 488 |
+
dV = gradient_c(V, batched = True, delta_lst = self.data_spacing)
|
| 489 |
+
if self.dimension == 1:
|
| 490 |
+
C_x = Upwind_C.dX(V)
|
| 491 |
+
return - V * C_x - C * dV
|
| 492 |
+
elif self.dimension == 2:
|
| 493 |
+
C_x, C_y = Upwind_C.dX(V), Upwind_C.dY(V)
|
| 494 |
+
return - V * (C_x + C_y) - C * dV.sum(-1)
|
| 495 |
+
elif self.dimension == 3:
|
| 496 |
+
C_x, C_y, C_z = Upwind_C.dX(V), Upwind_C.dY(V), Upwind_C.dZ(V)
|
| 497 |
+
return - V * (C_x + C_y + C_z) - C * dV.sum(-1)
|
| 498 |
+
|
| 499 |
+
def Grad_div_free_vectorV(self, C, Vlst):
|
| 500 |
+
''' For divergence-free-by-definition velocity'''
|
| 501 |
+
if self.dimension == 1:
|
| 502 |
+
raise NotImplementedError('clebschVector is not supported for 1D version of velocity')
|
| 503 |
+
Upwind_C = Upwind(C, self.data_spacing)
|
| 504 |
+
C_x, C_y = Upwind_C.dX(Vlst['Vx']), Upwind_C.dY(Vlst['Vy'])
|
| 505 |
+
if self.dimension == 2:
|
| 506 |
+
return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y)
|
| 507 |
+
elif self.dimension == 3:
|
| 508 |
+
C_z = Upwind_C.dZ(Vlst['Vz'])
|
| 509 |
+
return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y + Vlst['Vz'] * C_z)
|
| 510 |
+
|
| 511 |
+
def Grad_vectorV(self, C, Vlst):
|
| 512 |
+
''' For general velocity'''
|
| 513 |
+
if self.dimension == 1:
|
| 514 |
+
raise NotImplementedError('vector is not supported for 1D version of velocity')
|
| 515 |
+
Upwind_C = Upwind(C, self.data_spacing)
|
| 516 |
+
C_x, C_y = Upwind_C.dX(Vlst['Vx']), Upwind_C.dY(Vlst['Vy'])
|
| 517 |
+
Vx_x = self.dXc(Vlst['Vx'])
|
| 518 |
+
Vy_y = self.dYc(Vlst['Vy'])
|
| 519 |
+
if self.dimension == 2:
|
| 520 |
+
return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y) - C * (Vx_x + Vy_y)
|
| 521 |
+
if self.dimension == 3:
|
| 522 |
+
C_z = Upwind_C.dZ(Vlst['Vz'])
|
| 523 |
+
Vz_z = self.dZc(Vlst['Vz'])
|
| 524 |
+
return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y + Vlst['Vz'] * C_z) - C * (Vx_x + Vy_y + Vz_z)
|
| 525 |
+
|
| 526 |
+
################# Utilities #################
|
| 527 |
+
def db(self, X):
|
| 528 |
+
return gradient_b(X, batched = True, delta_lst = self.data_spacing)
|
| 529 |
+
def df(self, X):
|
| 530 |
+
return gradient_f(X, batched = True, delta_lst = self.data_spacing)
|
| 531 |
+
def dc(self, X):
|
| 532 |
+
return gradient_c(X, batched = True, delta_lst = self.data_spacing)
|
| 533 |
+
def dXb(self, X):
|
| 534 |
+
return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 0]
|
| 535 |
+
def dXf(self, X):
|
| 536 |
+
return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 0]
|
| 537 |
+
def dXc(self, X):
|
| 538 |
+
return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 0]
|
| 539 |
+
def dYb(self, X):
|
| 540 |
+
return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 1]
|
| 541 |
+
def dYf(self, X):
|
| 542 |
+
return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 1]
|
| 543 |
+
def dYc(self, X):
|
| 544 |
+
return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 1]
|
| 545 |
+
def dZb(self, X):
|
| 546 |
+
return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 2]
|
| 547 |
+
def dZf(self, X):
|
| 548 |
+
return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 2]
|
| 549 |
+
def dZc(self, X):
|
| 550 |
+
return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 2]
|
| 551 |
+
def ddXc(self, X):
|
| 552 |
+
return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 0],
|
| 553 |
+
batched = True, delta_lst = self.data_spacing)[..., 0]
|
| 554 |
+
def ddYc(self, X):
|
| 555 |
+
return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 1],
|
| 556 |
+
batched = True, delta_lst = self.data_spacing)[..., 1]
|
| 557 |
+
def ddZc(self, X):
|
| 558 |
+
return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 2],
|
| 559 |
+
batched = True, delta_lst = self.data_spacing)[..., 2]
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class AdvDiffPDE(nn.Module):
|
| 564 |
+
'''
|
| 565 |
+
Plain advection-diffusion PDE solver for pre-set V_lst and D_lst (1D, 2D, 3D) for forward time series simulation
|
| 566 |
+
'''
|
| 567 |
+
def __init__(self, data_spacing, perf_pattern, D_type='scalar', V_type='vector', BC=None, dt=0.1, V_dict={}, D_dict={}, stochastic=False, device='cpu'):
|
| 568 |
+
super(AdvDiffPDE, self).__init__()
|
| 569 |
+
self.BC = BC
|
| 570 |
+
self.dt = dt
|
| 571 |
+
self.dimension = len(data_spacing)
|
| 572 |
+
self.perf_pattern = perf_pattern
|
| 573 |
+
self.partials = AdvDiffPartial(data_spacing, device)
|
| 574 |
+
self.D_type, self.V_type = D_type, V_type
|
| 575 |
+
self.stochastic = stochastic
|
| 576 |
+
self.V_dict, self.D_dict = V_dict, D_dict
|
| 577 |
+
self.Sigma, self.Sigma_V, self.Sigma_D = 0., 0., 0. # Only for initialization #
|
| 578 |
+
if self.dimension == 1:
|
| 579 |
+
self.neumann_BC = torch.nn.ReplicationPad1d(1)
|
| 580 |
+
elif self.dimension == 2:
|
| 581 |
+
self.neumann_BC = torch.nn.ReplicationPad2d(1)
|
| 582 |
+
elif self.dimension == 3:
|
| 583 |
+
self.neumann_BC = torch.nn.ReplicationPad3d(1)
|
| 584 |
+
else:
|
| 585 |
+
raise ValueError('Unsupported dimension: %d' % self.dimension)
|
| 586 |
+
|
| 587 |
+
@property
|
| 588 |
+
def set_BC(self):
|
| 589 |
+
# NOTE For bondary condition of mass concentration #
|
| 590 |
+
'''X: (n_batch, spatial_shape)'''
|
| 591 |
+
if self.BC == 'neumann' or self.BC == 'cauchy':
|
| 592 |
+
if self.dimension == 1:
|
| 593 |
+
return lambda X: self.neumann_BC(X[:, 1:-1].unsqueeze(dim=1))[:,0]
|
| 594 |
+
elif self.dimension == 2:
|
| 595 |
+
return lambda X: self.neumann_BC(X[:, 1:-1, 1:-1].unsqueeze(dim=1))[:,0]
|
| 596 |
+
elif self.dimension == 3:
|
| 597 |
+
return lambda X: self.neumann_BC(X[:, 1:-1, 1:-1, 1:-1].unsqueeze(dim=1))[:,0]
|
| 598 |
+
else:
|
| 599 |
+
raise NotImplementedError('Unsupported B.C.!')
|
| 600 |
+
elif self.BC == 'dirichlet_neumann' or self.BC == 'source_neumann':
|
| 601 |
+
ctrl_wdth = 1
|
| 602 |
+
if self.dimension == 1:
|
| 603 |
+
self.dirichlet_BC = torch.nn.ReplicationPad1d(ctrl_wdth)
|
| 604 |
+
return lambda X: self.dirichlet_BC(X[:, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0]
|
| 605 |
+
elif self.dimension == 2:
|
| 606 |
+
self.dirichlet_BC = torch.nn.ReplicationPad2d(ctrl_wdth)
|
| 607 |
+
return lambda X: self.dirichlet_BC(X[:, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0]
|
| 608 |
+
elif self.dimension == 3:
|
| 609 |
+
self.dirichlet_BC = torch.nn.ReplicationPad3d(ctrl_wdth)
|
| 610 |
+
return lambda X: self.neumann_dirichlet_BCBC(X[:, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0]
|
| 611 |
+
else:
|
| 612 |
+
raise NotImplementedError('Unsupported B.C.!')
|
| 613 |
+
else:
|
| 614 |
+
return lambda X: X
|
| 615 |
+
|
| 616 |
+
def forward(self, t, batch_C):
|
| 617 |
+
'''
|
| 618 |
+
t: (batch_size,)
|
| 619 |
+
batch_C: (batch_size, (slc,) row, col)
|
| 620 |
+
'''
|
| 621 |
+
batch_size = batch_C.size(0)
|
| 622 |
+
batch_C = self.set_BC(batch_C)
|
| 623 |
+
if 'diff' not in self.perf_pattern:
|
| 624 |
+
out = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict)
|
| 625 |
+
if self.stochastic:
|
| 626 |
+
out = out + self.Sigma * math.sqrt(self.dt) * torch.randn_like(batch_C).to(batch_C)
|
| 627 |
+
elif 'adv' not in self.perf_pattern:
|
| 628 |
+
out = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict)
|
| 629 |
+
if self.stochastic:
|
| 630 |
+
out = out + self.Sigma * math.sqrt(self.dt) * torch.randn_like(batch_C).to(batch_C)
|
| 631 |
+
else:
|
| 632 |
+
if self.stochastic:
|
| 633 |
+
out_D = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict)
|
| 634 |
+
out_V = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict)
|
| 635 |
+
out = out_D + out_V + self.Sigma * math.sqrt(self.dt) * torch.randn_like(batch_C).to(batch_C)
|
| 636 |
+
else:
|
| 637 |
+
out_V = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict)
|
| 638 |
+
out_D = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict)
|
| 639 |
+
out = out_V + out_D
|
| 640 |
+
return out
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
|
ShapeID/DiffEqs/rk_common.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate
|
| 2 |
+
import collections
|
| 3 |
+
from ShapeID.DiffEqs.misc import _scaled_dot_product, _convert_to_tensor
|
| 4 |
+
|
| 5 |
+
_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha beta c_sol c_error')
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class _RungeKuttaState(collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
|
| 9 |
+
"""Saved state of the Runge Kutta solver.
|
| 10 |
+
|
| 11 |
+
Attributes:
|
| 12 |
+
y1: Tensor giving the function value at the end of the last time step.
|
| 13 |
+
f1: Tensor giving derivative at the end of the last time step.
|
| 14 |
+
t0: scalar float64 Tensor giving start of the last time step.
|
| 15 |
+
t1: scalar float64 Tensor giving end of the last time step.
|
| 16 |
+
dt: scalar float64 Tensor giving the size for the next time step.
|
| 17 |
+
interp_coef: list of Tensors giving coefficients for polynomial
|
| 18 |
+
interpolation between `t0` and `t1`.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _runge_kutta_step(func, y0, f0, t0, dt, tableau):
|
| 23 |
+
"""Take an arbitrary Runge-Kutta step and estimate error.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
func: Function to evaluate like `func(t, y)` to compute the time derivative
|
| 27 |
+
of `y`.
|
| 28 |
+
y0: Tensor initial value for the state.
|
| 29 |
+
f0: Tensor initial value for the derivative, computed from `func(t0, y0)`.
|
| 30 |
+
t0: float64 scalar Tensor giving the initial time.
|
| 31 |
+
dt: float64 scalar Tensor giving the size of the desired time step.
|
| 32 |
+
tableau: optional _ButcherTableau describing how to take the Runge-Kutta
|
| 33 |
+
step.
|
| 34 |
+
name: optional name for the operation.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
|
| 38 |
+
the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
|
| 39 |
+
estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
|
| 40 |
+
calculating these terms.
|
| 41 |
+
"""
|
| 42 |
+
dtype = y0[0].dtype
|
| 43 |
+
device = y0[0].device
|
| 44 |
+
|
| 45 |
+
t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
|
| 46 |
+
dt = _convert_to_tensor(dt, dtype=dtype, device=device)
|
| 47 |
+
|
| 48 |
+
k = tuple(map(lambda x: [x], f0))
|
| 49 |
+
for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
|
| 50 |
+
ti = t0 + alpha_i * dt
|
| 51 |
+
yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
|
| 52 |
+
tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi)))
|
| 53 |
+
|
| 54 |
+
if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
|
| 55 |
+
# This property (true for Dormand-Prince) lets us save a few FLOPs.
|
| 56 |
+
yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))
|
| 57 |
+
|
| 58 |
+
y1 = yi
|
| 59 |
+
f1 = tuple(k_[-1] for k_ in k)
|
| 60 |
+
y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
|
| 61 |
+
return (y1, f1, y1_error, k)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def rk4_step_func(func, t, dt, y, k1=None):
|
| 65 |
+
if k1 is None: k1 = func(t, y)
|
| 66 |
+
k2 = func(t + dt / 2, tuple(y_ + dt * k1_ / 2 for y_, k1_ in zip(y, k1)))
|
| 67 |
+
k3 = func(t + dt / 2, tuple(y_ + dt * k2_ / 2 for y_, k2_ in zip(y, k2)))
|
| 68 |
+
k4 = func(t + dt, tuple(y_ + dt * k3_ for y_, k3_ in zip(y, k3)))
|
| 69 |
+
return tuple((k1_ + 2 * k2_ + 2 * k3_ + k4_) * (dt / 6) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def rk4_alt_step_func(func, t, dt, y, k1=None):
|
| 73 |
+
"""Smaller error with slightly more compute."""
|
| 74 |
+
if k1 is None: k1 = func(t, y)
|
| 75 |
+
k2 = func(t + dt / 3, tuple(y_ + dt * k1_ / 3 for y_, k1_ in zip(y, k1)))
|
| 76 |
+
k3 = func(t + dt * 2 / 3, tuple(y_ + dt * (k1_ / -3 + k2_) for y_, k1_, k2_ in zip(y, k1, k2)))
|
| 77 |
+
k4 = func(t + dt, tuple(y_ + dt * (k1_ - k2_ + k3_) for y_, k1_, k2_, k3_ in zip(y, k1, k2, k3)))
|
| 78 |
+
return tuple((k1_ + 3 * k2_ + 3 * k3_ + k4_) * (dt / 8) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))
|
ShapeID/DiffEqs/solvers.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import torch
|
| 3 |
+
from ShapeID.DiffEqs.misc import _assert_increasing, _handle_unused_kwargs
|
| 4 |
+
|
| 5 |
+
def set_BC_2D(X, BCs): # X: (n_batch, spatial_size); BCs: (batch, 4, BC_shape, data_dim)
|
| 6 |
+
BC_size = BCs.size(2)
|
| 7 |
+
X[:, : BC_size] = BCs[:, 0]
|
| 8 |
+
X[:, - BC_size :] = BCs[:, 1]
|
| 9 |
+
X[:, :, : BC_size] = BCs[:, 2].permute(0, 2, 1) # (batch, BC_shape, r) -> (batch, r, BC_shape)
|
| 10 |
+
X[:, :, - BC_size :] = BCs[:, 3].permute(0, 2, 1) # (batch, BC_shape, r) -> (batch, r, BC_shape)
|
| 11 |
+
del BCs
|
| 12 |
+
return X
|
| 13 |
+
def set_BC_3D(X, BCs): # X: (n_batch, spatial_size); BCs: (batch, 6, BC_shape, data_dim, dta_dim)
|
| 14 |
+
BC_size = BCs.size(2)
|
| 15 |
+
X[:, : BC_size] = BCs[:, 0]
|
| 16 |
+
X[:, - BC_size :] = BCs[:, 1]
|
| 17 |
+
X[:, :, : BC_size] = BCs[:, 2].permute(0, 2, 1, 3) # (batch, BC_shape, s, c) -> (batch, s, BC_shape, c)
|
| 18 |
+
X[:, :, - BC_size :] = BCs[:, 3].permute(0, 2, 1, 3) # (batch, BC_shape, s, c) -> (batch, s, BC_shape, c)
|
| 19 |
+
X[:, :, :, : BC_size] = BCs[:, 4].permute(0, 2, 3, 1) # (batch, BC_shape, s, r) -> (batch, s, r, BC_shape)
|
| 20 |
+
X[:, :, :, - BC_size :] = BCs[:, 5].permute(0, 2, 3, 1) # (batch, BC_shape, s, r) -> (batch, s, r, BC_shape)
|
| 21 |
+
del BCs
|
| 22 |
+
return X
|
| 23 |
+
|
| 24 |
+
''' X[t] = X[t] + dBC[t] (dBC[t] = BC[t+1] - BC[t]) '''
|
| 25 |
+
def add_dBC_2D(X, dBCs): # X: (n_batch, spatial_size); BCs: (batch, 4, BC_shape, data_dim)
|
| 26 |
+
BC_size = dBCs.size(2)
|
| 27 |
+
X[:, : BC_size] += dBCs[:, 0]
|
| 28 |
+
X[:, - BC_size :] += dBCs[:, 1]
|
| 29 |
+
X[:, :, : BC_size] += dBCs[:, 2].permute(0, 2, 1) # (batch, BC_shape, r) -> (batch, r, BC_shape)
|
| 30 |
+
X[:, :, - BC_size :] += dBCs[:, 3].permute(0, 2, 1) # (batch, BC_shape, r) -> (batch, r, BC_shape)
|
| 31 |
+
del dBCs
|
| 32 |
+
return X
|
| 33 |
+
def add_dBC_3D(X, dBCs): # X: (n_batch, spatial_size); BCs: (batch, 6, BC_shape, data_dim, dta_dim)
|
| 34 |
+
BC_size = dBCs.size(2)
|
| 35 |
+
X[:, : BC_size] += dBCs[:, 0]
|
| 36 |
+
X[:, - BC_size :] += dBCs[:, 1]
|
| 37 |
+
X[:, :, : BC_size] += dBCs[:, 2].permute(0, 2, 1, 3) # (batch, BC_shape, s, c) -> (batch, s, BC_shape, c)
|
| 38 |
+
X[:, :, - BC_size :] += dBCs[:, 3].permute(0, 2, 1, 3) # (batch, BC_shape, s, c) -> (batch, s, BC_shape, c)
|
| 39 |
+
X[:, :, :, : BC_size] += dBCs[:, 4].permute(0, 2, 3, 1) # (batch, BC_shape, s, r) -> (batch, s, r, BC_shape)
|
| 40 |
+
X[:, :, :, - BC_size :] += dBCs[:, 5].permute(0, 2, 3, 1) # (batch, BC_shape, s, r) -> (batch, s, r, BC_shape)
|
| 41 |
+
del dBCs
|
| 42 |
+
return X
|
| 43 |
+
|
| 44 |
+
class AdaptiveStepsizeODESolver(object):
|
| 45 |
+
__metaclass__ = abc.ABCMeta
|
| 46 |
+
|
| 47 |
+
def __init__(self, func, y0, atol, rtol, options= None):
|
| 48 |
+
|
| 49 |
+
# _handle_unused_kwargs(self, options)
|
| 50 |
+
#del options
|
| 51 |
+
self.func = func
|
| 52 |
+
self.y0 = y0
|
| 53 |
+
self.atol = atol
|
| 54 |
+
self.rtol = rtol
|
| 55 |
+
|
| 56 |
+
def before_integrate(self, t):
|
| 57 |
+
pass
|
| 58 |
+
|
| 59 |
+
@abc.abstractmethod
|
| 60 |
+
def advance(self, next_t):
|
| 61 |
+
raise NotImplementedError
|
| 62 |
+
|
| 63 |
+
def integrate(self, t):
|
| 64 |
+
_assert_increasing(t)
|
| 65 |
+
solution = [self.y0]
|
| 66 |
+
t = t.to(self.y0[0].device, torch.float64)
|
| 67 |
+
self.before_integrate(t)
|
| 68 |
+
for i in range(1, len(t)):
|
| 69 |
+
y = self.advance(t[i])
|
| 70 |
+
solution.append(y)
|
| 71 |
+
'''if self.contours is not None: # contours: (n_batch, nT, 4 / 6, BC_size, c)
|
| 72 |
+
if self.adjoint:
|
| 73 |
+
for i in range(1, len(t)):
|
| 74 |
+
ys = list(self.advance(t[i])) # tuple: (y0, **back_grad) -> y0: (n_batch, spatial_shape)
|
| 75 |
+
#print(len(t))
|
| 76 |
+
#print(ys[0].size())
|
| 77 |
+
#print(self.contours.size())
|
| 78 |
+
ys[0] = self.set_BC(ys[0], self.contours[:, i]) # (n_batch, 4 / 6, BC_size, c)
|
| 79 |
+
solution.append(tuple(ys))
|
| 80 |
+
else:
|
| 81 |
+
for i in range(1, len(t)):
|
| 82 |
+
y = torch.stack(self.advance(t[i])) # y: (n_batch, 1, spatial_shape)
|
| 83 |
+
y = self.set_BC(y[:, 0], self.contours[:, i]).unsqueeze(1)
|
| 84 |
+
solution.append(tuple(y))
|
| 85 |
+
elif self.dcontours is not None: # dcontours: (n_batch, nT, 4 / 6, BC_size, c)
|
| 86 |
+
if self.adjoint:
|
| 87 |
+
for i in range(1, len(t)):
|
| 88 |
+
ys = list(self.advance(t[i])) # ys - tuple: (y0, **back_grad) -> y0: (n_batch, spatial_shape)
|
| 89 |
+
ys[0] = self.add_dBC(ys[0], self.dcontours[:, i]) # (n_batch, 4 / 6, BC_size, c)
|
| 90 |
+
solution.append(tuple(ys))
|
| 91 |
+
else:
|
| 92 |
+
for i in range(1, len(t)):
|
| 93 |
+
y = torch.stack(self.advance(t[i])) # (n_batch, 1, spatial_shape)
|
| 94 |
+
y = self.add_dBC(y[:, 0], self.dcontours[:, i]).unsqueeze(1)
|
| 95 |
+
solution.append(tuple(y))
|
| 96 |
+
else:
|
| 97 |
+
for i in range(1, len(t)):
|
| 98 |
+
y = self.advance(t[i])
|
| 99 |
+
solution.append(y)'''
|
| 100 |
+
return tuple(map(torch.stack, tuple(zip(*solution))))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class FixedGridODESolver(object):
|
| 104 |
+
__metaclass__ = abc.ABCMeta
|
| 105 |
+
|
| 106 |
+
def __init__(self, func, y0, step_size=None, grid_constructor=None, atol=None, rtol=None, dt=None, options = None):
|
| 107 |
+
'''if 'dirichlet' in options.BC or 'cauchy' in options.BC and options.contours is not None:
|
| 108 |
+
self.contours = options.contours # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape)
|
| 109 |
+
self.BC_size = self.contours.size(3)
|
| 110 |
+
self.set_BC = set_BC_2D if self.contours.size(2) == 4 else set_BC_3D
|
| 111 |
+
else:
|
| 112 |
+
self.contours = None
|
| 113 |
+
if 'source' in options.BC and options.dcontours is not None:
|
| 114 |
+
self.dcontours = options.dcontours # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape)
|
| 115 |
+
self.BC_size = self.dcontours.size(3)
|
| 116 |
+
self.add_dBC = add_dBC_2D if self.dcontours.size(2) == 4 else add_dBC_3D
|
| 117 |
+
else:
|
| 118 |
+
self.dcontours = None'''
|
| 119 |
+
#self.adjoint = options.adjoint
|
| 120 |
+
#options.pop('rtol', None)
|
| 121 |
+
#options.pop('atol', None)
|
| 122 |
+
#_handle_unused_kwargs(self, options)
|
| 123 |
+
#del options
|
| 124 |
+
|
| 125 |
+
self.func = func
|
| 126 |
+
self.y0 = y0
|
| 127 |
+
|
| 128 |
+
if step_size is not None and grid_constructor is None:
|
| 129 |
+
self.grid_constructor = self._grid_constructor_from_step_size(step_size)
|
| 130 |
+
elif grid_constructor is None:
|
| 131 |
+
self.grid_constructor = lambda f, y0, t: t # Same time step as time interval
|
| 132 |
+
else:
|
| 133 |
+
raise ValueError("step_size and grid_constructor are exclusive arguments.")
|
| 134 |
+
|
| 135 |
+
def _grid_constructor_from_step_size(self, step_size):
|
| 136 |
+
|
| 137 |
+
def _grid_constructor(func, y0, t):
|
| 138 |
+
start_time = t[0]
|
| 139 |
+
end_time = t[-1]
|
| 140 |
+
|
| 141 |
+
niters = torch.ceil((end_time - start_time) / step_size + 1).item()
|
| 142 |
+
t_infer = torch.arange(0, niters).to(t) * step_size + start_time
|
| 143 |
+
if t_infer[-1] > t[-1]:
|
| 144 |
+
t_infer[-1] = t[-1]
|
| 145 |
+
return t_infer
|
| 146 |
+
|
| 147 |
+
return _grid_constructor
|
| 148 |
+
|
| 149 |
+
@property
|
| 150 |
+
@abc.abstractmethod
|
| 151 |
+
def order(self):
|
| 152 |
+
pass
|
| 153 |
+
|
| 154 |
+
@abc.abstractmethod
|
| 155 |
+
def step_func(self, func, t, dt, y):
|
| 156 |
+
pass
|
| 157 |
+
|
| 158 |
+
def integrate(self, t):
|
| 159 |
+
_assert_increasing(t)
|
| 160 |
+
t = t.type_as(self.y0[0]) # (n_time, )
|
| 161 |
+
time_grid = self.grid_constructor(self.func, self.y0, t)
|
| 162 |
+
#print('time_grid:', time_grid.size())
|
| 163 |
+
#print('t:', t.size())
|
| 164 |
+
assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
|
| 165 |
+
time_grid = time_grid.to(self.y0[0])
|
| 166 |
+
|
| 167 |
+
solution = [self.y0]
|
| 168 |
+
|
| 169 |
+
j = 1
|
| 170 |
+
y0 = self.y0
|
| 171 |
+
for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
|
| 172 |
+
dy = self.step_func(self.func, t0, t1 - t0, y0)
|
| 173 |
+
y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))
|
| 174 |
+
y0 = y1
|
| 175 |
+
while j < len(t) and t1 >= t[j]:
|
| 176 |
+
solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
|
| 177 |
+
j += 1
|
| 178 |
+
'''if self.contours is not None:
|
| 179 |
+
if self.adjoint:
|
| 180 |
+
for i in range(1, len(t)):
|
| 181 |
+
ys = list(self._linear_interp(t0, t1, y0, y1, t[j])) # tuple: (y0, **back_grad) -> y0: (n_batch, spatial_shape)
|
| 182 |
+
ys[0] = self.set_BC(ys[0], self.contours[:, i]) # (n_batch, 4 / 6, BC_size, c)
|
| 183 |
+
solution.append(tuple(ys))
|
| 184 |
+
j += 1
|
| 185 |
+
else:
|
| 186 |
+
while j < len(t) and t1 >= t[j]:
|
| 187 |
+
y = torch.stack(self._linear_interp(t0, t1, y0, y1, t[j])) # (n_batch, 1, spatial_shape)
|
| 188 |
+
y = self.set_BC(y[:, 0], self.contours[:, j]).unsqueeze(1)
|
| 189 |
+
solution.append(tuple(y))
|
| 190 |
+
j += 1
|
| 191 |
+
elif self.dcontours is not None:
|
| 192 |
+
if self.adjoint:
|
| 193 |
+
for i in range(1, len(t)):
|
| 194 |
+
ys = list(self._linear_interp(t0, t1, y0, y1, t[j])) # tuple: (y0, **back_grad) -> y0: (n_batch, spatial_shape)
|
| 195 |
+
ys[0] = self.add_dBC(ys[0], self.dcontours[:, j]) # (n_batch, 4 / 6, BC_size, c)
|
| 196 |
+
solution.append(tuple(ys))
|
| 197 |
+
else:
|
| 198 |
+
while j < len(t) and t1 >= t[j]:
|
| 199 |
+
y = torch.stack(self._linear_interp(t0, t1, y0, y1, t[j])) # (n_batch, 1, spatial_shape)
|
| 200 |
+
y = self.add_dBC(y[:, 0], self.dcontours[:, j]).unsqueeze(1)
|
| 201 |
+
solution.append(tuple(y))
|
| 202 |
+
j += 1
|
| 203 |
+
else:
|
| 204 |
+
while j < len(t) and t1 >= t[j]:
|
| 205 |
+
solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
|
| 206 |
+
j += 1'''
|
| 207 |
+
return tuple(map(torch.stack, tuple(zip(*solution)))) # (batch, time)
|
| 208 |
+
|
| 209 |
+
def _linear_interp(self, t0, t1, y0, y1, t):
|
| 210 |
+
if t == t0:
|
| 211 |
+
return y0
|
| 212 |
+
if t == t1:
|
| 213 |
+
return y1
|
| 214 |
+
t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0])
|
| 215 |
+
slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1))
|
| 216 |
+
return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope))
|
ShapeID/DiffEqs/tsit5.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ShapeID.DiffEqs.misc import _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs
|
| 3 |
+
from ShapeID.DiffEqs.solvers import AdaptiveStepsizeODESolver
|
| 4 |
+
from ShapeID.DiffEqs.rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step
|
| 5 |
+
|
| 6 |
+
# Parameters from Tsitouras (2011).
|
| 7 |
+
_TSITOURAS_TABLEAU = _ButcherTableau(
|
| 8 |
+
alpha=[0.161, 0.327, 0.9, 0.9800255409045097, 1., 1.],
|
| 9 |
+
beta=[
|
| 10 |
+
[0.161],
|
| 11 |
+
[-0.008480655492357, 0.3354806554923570],
|
| 12 |
+
[2.897153057105494, -6.359448489975075, 4.362295432869581],
|
| 13 |
+
[5.32586482843925895, -11.74888356406283, 7.495539342889836, -0.09249506636175525],
|
| 14 |
+
[5.86145544294642038, -12.92096931784711, 8.159367898576159, -0.071584973281401006, -0.02826905039406838],
|
| 15 |
+
[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774],
|
| 16 |
+
],
|
| 17 |
+
c_sol=[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774, 0],
|
| 18 |
+
c_error=[
|
| 19 |
+
0.09646076681806523 - 0.001780011052226,
|
| 20 |
+
0.01 - 0.000816434459657,
|
| 21 |
+
0.4798896504144996 - -0.007880878010262,
|
| 22 |
+
1.379008574103742 - 0.144711007173263,
|
| 23 |
+
-3.290069515436081 - -0.582357165452555,
|
| 24 |
+
2.324710524099774 - 0.458082105929187,
|
| 25 |
+
-1 / 66,
|
| 26 |
+
],
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _interp_coeff_tsit5(t0, dt, eval_t):
|
| 31 |
+
t = float((eval_t - t0) / dt)
|
| 32 |
+
b1 = -1.0530884977290216 * t * (t - 1.3299890189751412) * (t**2 - 1.4364028541716351 * t + 0.7139816917074209)
|
| 33 |
+
b2 = 0.1017 * t**2 * (t**2 - 2.1966568338249754 * t + 1.2949852507374631)
|
| 34 |
+
b3 = 2.490627285651252793 * t**2 * (t**2 - 2.38535645472061657 * t + 1.57803468208092486)
|
| 35 |
+
b4 = -16.54810288924490272 * (t - 1.21712927295533244) * (t - 0.61620406037800089) * t**2
|
| 36 |
+
b5 = 47.37952196281928122 * (t - 1.203071208372362603) * (t - 0.658047292653547382) * t**2
|
| 37 |
+
b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2
|
| 38 |
+
b7 = 2.5 * (t - 1) * (t - 0.6) * t**2
|
| 39 |
+
return [b1, b2, b3, b4, b5, b6, b7]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _interp_eval_tsit5(t0, t1, k, eval_t):
|
| 43 |
+
dt = t1 - t0
|
| 44 |
+
y0 = tuple(k_[0] for k_ in k)
|
| 45 |
+
interp_coeff = _interp_coeff_tsit5(t0, dt, eval_t)
|
| 46 |
+
y_t = tuple(y0_ + _scaled_dot_product(dt, interp_coeff, k_) for y0_, k_ in zip(y0, k))
|
| 47 |
+
return y_t
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
|
| 51 |
+
"""Calculate the optimal size for the next Runge-Kutta step."""
|
| 52 |
+
if mean_error_ratio == 0:
|
| 53 |
+
return last_step * ifactor
|
| 54 |
+
if mean_error_ratio < 1:
|
| 55 |
+
dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
|
| 56 |
+
error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step)
|
| 57 |
+
exponent = torch.tensor(1 / order).type_as(last_step)
|
| 58 |
+
factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
|
| 59 |
+
return last_step / factor
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _abs_square(x):
|
| 63 |
+
return torch.mul(x, x)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Tsit5Solver(AdaptiveStepsizeODESolver):
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
|
| 70 |
+
**unused_kwargs
|
| 71 |
+
):
|
| 72 |
+
_handle_unused_kwargs(self, unused_kwargs)
|
| 73 |
+
del unused_kwargs
|
| 74 |
+
|
| 75 |
+
self.func = func
|
| 76 |
+
self.y0 = y0
|
| 77 |
+
self.rtol = rtol
|
| 78 |
+
self.atol = atol
|
| 79 |
+
self.first_step = first_step
|
| 80 |
+
self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
|
| 81 |
+
self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
|
| 82 |
+
self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
|
| 83 |
+
self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)
|
| 84 |
+
|
| 85 |
+
def before_integrate(self, t):
|
| 86 |
+
if self.first_step is None:
|
| 87 |
+
first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol).to(t)
|
| 88 |
+
else:
|
| 89 |
+
first_step = _convert_to_tensor(0.01, dtype=t.dtype, device=t.device)
|
| 90 |
+
self.rk_state = _RungeKuttaState(
|
| 91 |
+
self.y0,
|
| 92 |
+
self.func(t[0].type_as(self.y0[0]), self.y0), t[0], t[0], first_step,
|
| 93 |
+
tuple(map(lambda x: [x] * 7, self.y0))
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
def advance(self, next_t):
|
| 97 |
+
"""Interpolate through the next time point, integrating as necessary."""
|
| 98 |
+
n_steps = 0
|
| 99 |
+
while next_t > self.rk_state.t1:
|
| 100 |
+
assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
|
| 101 |
+
self.rk_state = self._adaptive_tsit5_step(self.rk_state)
|
| 102 |
+
n_steps += 1
|
| 103 |
+
return _interp_eval_tsit5(self.rk_state.t0, self.rk_state.t1, self.rk_state.interp_coeff, next_t)
|
| 104 |
+
|
| 105 |
+
def _adaptive_tsit5_step(self, rk_state):
|
| 106 |
+
"""Take an adaptive Runge-Kutta step to integrate the DiffEqs."""
|
| 107 |
+
y0, f0, _, t0, dt, _ = rk_state
|
| 108 |
+
########################################################
|
| 109 |
+
# Assertions #
|
| 110 |
+
########################################################
|
| 111 |
+
assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
|
| 112 |
+
for y0_ in y0:
|
| 113 |
+
assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
|
| 114 |
+
y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU)
|
| 115 |
+
|
| 116 |
+
########################################################
|
| 117 |
+
# Error Ratio #
|
| 118 |
+
########################################################
|
| 119 |
+
error_tol = tuple(self.atol + self.rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
|
| 120 |
+
tensor_error_ratio = tuple(y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol))
|
| 121 |
+
sq_error_ratio = tuple(
|
| 122 |
+
torch.mul(tensor_error_ratio_, tensor_error_ratio_) for tensor_error_ratio_ in tensor_error_ratio
|
| 123 |
+
)
|
| 124 |
+
mean_error_ratio = (
|
| 125 |
+
sum(torch.sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio) /
|
| 126 |
+
sum(sq_error_ratio_.numel() for sq_error_ratio_ in sq_error_ratio)
|
| 127 |
+
)
|
| 128 |
+
accept_step = mean_error_ratio <= 1
|
| 129 |
+
|
| 130 |
+
########################################################
|
| 131 |
+
# Update RK State #
|
| 132 |
+
########################################################
|
| 133 |
+
y_next = y1 if accept_step else y0
|
| 134 |
+
f_next = f1 if accept_step else f0
|
| 135 |
+
t_next = t0 + dt if accept_step else t0
|
| 136 |
+
dt_next = _optimal_step_size(dt, mean_error_ratio, self.safety, self.ifactor, self.dfactor)
|
| 137 |
+
k_next = k if accept_step else self.rk_state.interp_coeff
|
| 138 |
+
rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, k_next)
|
| 139 |
+
return rk_state
|
ShapeID/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from utils import *
|
ShapeID/demo2d.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ported from https://github.com/pvigier/perlin-numpy
|
| 2 |
+
|
| 3 |
+
import os, sys
|
| 4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 5 |
+
|
| 6 |
+
import time, datetime
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
from misc import stream_2D, V_plot
|
| 13 |
+
from utils.misc import viewVolume, make_dir
|
| 14 |
+
|
| 15 |
+
from perlin2d import *
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
#from ShapeID.DiffEqs.odeint import odeint
|
| 19 |
+
from ShapeID.DiffEqs.adjoint import odeint_adjoint as odeint
|
| 20 |
+
from ShapeID.DiffEqs.pde import AdvDiffPDE
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if __name__ == '__main__':
|
| 26 |
+
|
| 27 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
image, mask_image = generate_perlin_noise_2d([256, 256], [2, 2], percentile = 80)
|
| 31 |
+
plt.imshow(image, cmap='gray') #, interpolation='lanczos')
|
| 32 |
+
plt.axis('off')
|
| 33 |
+
plt.savefig('out/2d/image.png')
|
| 34 |
+
plt.imshow(mask_image, cmap='gray') #, interpolation='lanczos')
|
| 35 |
+
plt.axis('off')
|
| 36 |
+
plt.savefig('out/2d/mask_image.png')
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
curl, mask_curl = generate_perlin_noise_2d([256, 256], [2, 2], percentile = 80)
|
| 41 |
+
plt.imshow(curl, cmap='gray') #, interpolation='lanczos')
|
| 42 |
+
plt.axis('off')
|
| 43 |
+
plt.savefig('out/2d/curl.png')
|
| 44 |
+
plt.imshow(mask_curl, cmap='gray') #, interpolation='lanczos')
|
| 45 |
+
plt.axis('off')
|
| 46 |
+
plt.savefig('out/2d/mask_curl.png')
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
dx, dy = stream_2D(torch.from_numpy(curl))
|
| 50 |
+
V_plot(dx.numpy(), dy.numpy(), 'out/2d/V.png')
|
| 51 |
+
|
| 52 |
+
plt.imshow(mask_image, cmap='gray') #, interpolation='lanczos')
|
| 53 |
+
plt.axis('off')
|
| 54 |
+
plt.savefig('out/2d/image_with_v.png')
|
| 55 |
+
#plt.close()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
dt = 0.15
|
| 59 |
+
nt = 21
|
| 60 |
+
integ_method = 'dopri5' # choices=['dopri5', 'adams', 'rk4', 'euler']
|
| 61 |
+
t = torch.from_numpy(np.arange(nt) * dt).to(device)
|
| 62 |
+
thres = 0.9
|
| 63 |
+
|
| 64 |
+
initial = torch.from_numpy(mask_image)
|
| 65 |
+
Vx, Vy = dx * 1000, dy * 1000
|
| 66 |
+
|
| 67 |
+
forward_pde = AdvDiffPDE(data_spacing=[1., 1.],
|
| 68 |
+
perf_pattern='adv',
|
| 69 |
+
V_type='vector_div_free',
|
| 70 |
+
V_dict={'Vx': Vx, 'Vy': Vy},
|
| 71 |
+
BC='neumann',
|
| 72 |
+
dt=dt,
|
| 73 |
+
device=device
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
start_time = time.time()
|
| 78 |
+
noise_progression = odeint(forward_pde,
|
| 79 |
+
initial.unsqueeze(0),
|
| 80 |
+
t, dt, method = integ_method
|
| 81 |
+
)[:, 0]
|
| 82 |
+
total_time = time.time() - start_time
|
| 83 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 84 |
+
print('Time {}'.format(total_time_str))
|
| 85 |
+
noise_progression = noise_progression[::2]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
noise_progression = noise_progression.numpy()
|
| 89 |
+
make_dir('out/2d/progression')
|
| 90 |
+
|
| 91 |
+
for i, noise_t in enumerate(noise_progression):
|
| 92 |
+
print(i, noise_t.mean())
|
| 93 |
+
|
| 94 |
+
noise_t[noise_t > thres] = 1
|
| 95 |
+
noise_t[noise_t <= thres] = 0
|
| 96 |
+
|
| 97 |
+
#fig = plt.figure()
|
| 98 |
+
plt.imshow(noise_t, cmap='gray') #, interpolation='lanczos')
|
| 99 |
+
plt.savefig('out/2d/progression/%d.png' % i)
|
| 100 |
+
#plt.close()
|
| 101 |
+
|
| 102 |
+
viewVolume(noise_progression, names = ['noise_progression'], save_dir = 'out/2d/progression')
|
ShapeID/demo3d.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ported from https://github.com/pvigier/perlin-numpy
|
| 2 |
+
|
| 3 |
+
import os, sys
|
| 4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 5 |
+
|
| 6 |
+
import time, datetime
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
from misc import stream_3D, V_plot, center_crop
|
| 13 |
+
from utils.misc import viewVolume, make_dir, read_image
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
#from ShapeID.DiffEqs.odeint import odeint
|
| 17 |
+
from ShapeID.DiffEqs.adjoint import odeint_adjoint as odeint
|
| 18 |
+
from ShapeID.DiffEqs.pde import AdvDiffPDE
|
| 19 |
+
|
| 20 |
+
from perlin3d import *
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if __name__ == '__main__':
|
| 26 |
+
|
| 27 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 28 |
+
|
| 29 |
+
percentile = 80
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
#image, mask_image = generate_perlin_noise_3d([128, 128, 128], [2, 2, 2], tileable=(True, False, False), percentile = percentile)
|
| 33 |
+
#viewVolume(image, names = ['image'], save_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d')
|
| 34 |
+
#viewVolume(mask_image, names = ['mask_image'], save_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d')
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
#mask_image, aff = read_image('/autofs/space/yogurt_001/users/pl629/data/adni/pathology_probability/subject_193441.nii.gz')
|
| 38 |
+
mask_image, aff = read_image('/autofs/space/yogurt_001/users/pl629/data/isles2022/pathology_probability/sub-strokecase0127.nii.gz')
|
| 39 |
+
mask_image, _, _ = center_crop(torch.from_numpy(mask_image), win_size = [128, 128, 128])
|
| 40 |
+
mask_image = mask_image[0, 0].numpy()
|
| 41 |
+
|
| 42 |
+
shape = mask_image.shape
|
| 43 |
+
|
| 44 |
+
curl_a, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile = percentile)
|
| 45 |
+
curl_b, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile = percentile)
|
| 46 |
+
curl_c, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile = percentile)
|
| 47 |
+
dx, dy, dz = stream_3D(torch.from_numpy(curl_a), torch.from_numpy(curl_b), torch.from_numpy(curl_c))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
dt = 0.1
|
| 51 |
+
nt = 10
|
| 52 |
+
integ_method = 'dopri5' # choices=['dopri5', 'adams', 'rk4', 'euler']
|
| 53 |
+
t = torch.from_numpy(np.arange(nt) * dt).to(device)
|
| 54 |
+
thres = 0.5
|
| 55 |
+
|
| 56 |
+
initial = torch.from_numpy(mask_image)[None] # (batch=1, h, w)
|
| 57 |
+
Vx, Vy, Vz = dx * 500, dy * 500, dz * 500
|
| 58 |
+
print(abs(Vx).mean(), abs(Vy).mean(), abs(Vz).mean())
|
| 59 |
+
|
| 60 |
+
forward_pde = AdvDiffPDE(data_spacing=[1., 1., 1.],
|
| 61 |
+
perf_pattern='adv',
|
| 62 |
+
V_type='vector_div_free',
|
| 63 |
+
V_dict={'Vx': Vx, 'Vy': Vy, 'Vz': Vz},
|
| 64 |
+
BC='neumann',
|
| 65 |
+
dt=dt,
|
| 66 |
+
device=device
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
start_time = time.time()
|
| 71 |
+
noise_progression = odeint(forward_pde,
|
| 72 |
+
initial,
|
| 73 |
+
t, dt, method = integ_method
|
| 74 |
+
)[:, 0] # (nt, n_batch, h, w)
|
| 75 |
+
total_time = time.time() - start_time
|
| 76 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 77 |
+
print('Time {}'.format(total_time_str))
|
| 78 |
+
|
| 79 |
+
noise_progression = noise_progression[::2]
|
| 80 |
+
noise_progression = noise_progression.numpy()
|
| 81 |
+
make_dir('out/3d/progression')
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
for i, noise_t in enumerate(noise_progression):
|
| 85 |
+
noise_t[noise_t > 1] = 1
|
| 86 |
+
noise_t[noise_t <= thres] = 0
|
| 87 |
+
print(i, noise_t.mean())
|
| 88 |
+
viewVolume(noise_t, names = ['noise_%s' % i], save_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d/progression')
|
| 89 |
+
|
| 90 |
+
noise_t[noise_t > 0.] = 1
|
| 91 |
+
viewVolume(noise_t, names = ['noise_%s_mask' % i], save_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d/progression')
|
ShapeID/misc.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ported from https://github.com/pvigier/perlin-numpy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def center_crop(img, win_size = [220, 220, 220]):
|
| 11 |
+
# center crop
|
| 12 |
+
if len(img.shape) == 4:
|
| 13 |
+
img = torch.permute(img, (3, 0, 1, 2)) # (move last dim to first)
|
| 14 |
+
img = img[None]
|
| 15 |
+
permuted = True
|
| 16 |
+
else:
|
| 17 |
+
assert len(img.shape) == 3
|
| 18 |
+
img = img[None, None]
|
| 19 |
+
permuted = False
|
| 20 |
+
|
| 21 |
+
orig_shp = img.shape[2:] # (1, d, s, r, c)
|
| 22 |
+
if win_size is None:
|
| 23 |
+
if permuted:
|
| 24 |
+
return torch.permute(img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp
|
| 25 |
+
return img, [0, 0, 0], orig_shp
|
| 26 |
+
elif orig_shp[0] > win_size[0] or orig_shp[1] > win_size[1] or orig_shp[2] > win_size[2]:
|
| 27 |
+
crop_start = [ max((orig_shp[i] - win_size[i]), 0) // 2 for i in range(3) ]
|
| 28 |
+
crop_img = img[ :, :, crop_start[0] : crop_start[0] + win_size[0],
|
| 29 |
+
crop_start[1] : crop_start[1] + win_size[1],
|
| 30 |
+
crop_start[2] : crop_start[2] + win_size[2]]
|
| 31 |
+
if permuted:
|
| 32 |
+
return torch.permute(crop_img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp
|
| 33 |
+
return crop_img, crop_start, orig_shp
|
| 34 |
+
else:
|
| 35 |
+
if permuted:
|
| 36 |
+
return torch.permute(img, (0, 2, 3, 4, 1)), [0, 0, 0], orig_shp
|
| 37 |
+
return img, [0, 0, 0], orig_shp
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def V_plot(Vx, Vy, save_path):
|
| 42 |
+
# Meshgrid
|
| 43 |
+
X,Y = np.meshgrid(np.arange(0, Vx.shape[0], 1), np.arange(0, Vx.shape[1], 1))
|
| 44 |
+
# Assign vector directions
|
| 45 |
+
Ex = Vx
|
| 46 |
+
Ey = Vy
|
| 47 |
+
|
| 48 |
+
# Depict illustration
|
| 49 |
+
plt.figure()
|
| 50 |
+
plt.streamplot(X,Y,Ex,Ey, density=1.4, linewidth=None, color='orange')
|
| 51 |
+
plt.axis('off')
|
| 52 |
+
plt.savefig(save_path)
|
| 53 |
+
#plt.show()
|
| 54 |
+
|
| 55 |
+
def stream_2D(Phi, batched = False, delta_lst = [1., 1.]):
|
| 56 |
+
'''
|
| 57 |
+
input: Phi as a scalar field in 2D grid: (r, c) or (n_batch, r, c)
|
| 58 |
+
output: curl of Phi (divergence-free by definition)
|
| 59 |
+
'''
|
| 60 |
+
dD = gradient_c(Phi, batched = batched, delta_lst = delta_lst)
|
| 61 |
+
Vx = - dD[..., 1]
|
| 62 |
+
Vy = dD[..., 0]
|
| 63 |
+
return Vx, Vy
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def stream_3D(Phi_a, Phi_b, Phi_c, batched = False, delta_lst = [1., 1., 1.]):
|
| 67 |
+
'''
|
| 68 |
+
input: (batch, s, r, c)
|
| 69 |
+
'''
|
| 70 |
+
device = Phi_a.device
|
| 71 |
+
dDa = gradient_c(Phi_a, batched = batched, delta_lst = delta_lst)
|
| 72 |
+
dDb = gradient_c(Phi_b, batched = batched, delta_lst = delta_lst)
|
| 73 |
+
dDc = gradient_c(Phi_c, batched = batched, delta_lst = delta_lst)
|
| 74 |
+
Va_x, Va_y, Va_z = dDa[..., 0], dDa[..., 1], dDa[..., 2]
|
| 75 |
+
Vb_x, Vb_y, Vb_z = dDb[..., 0], dDb[..., 1], dDb[..., 2]
|
| 76 |
+
Vc_x, Vc_y, Vc_z = dDc[..., 0], dDc[..., 1], dDc[..., 2]
|
| 77 |
+
Vx = Vc_y - Vb_z
|
| 78 |
+
Vy = Va_z - Vc_x
|
| 79 |
+
Vz = Vb_x - Va_y
|
| 80 |
+
return Vx, Vy, Vz
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def gradient_f(X, batched = False, delta_lst = [1., 1., 1.]):
|
| 85 |
+
'''
|
| 86 |
+
Compute gradient of a torch tensor "X" in each direction
|
| 87 |
+
Upper-boundaries: Backward Difference
|
| 88 |
+
Non-boundaries & Upper-boundaries: Forward Difference
|
| 89 |
+
if X is batched: (n_batch, ...);
|
| 90 |
+
else: (...)
|
| 91 |
+
'''
|
| 92 |
+
device = X.device
|
| 93 |
+
dim = len(X.size()) - 1 if batched else len(X.size())
|
| 94 |
+
#print(batched)
|
| 95 |
+
#print(dim)
|
| 96 |
+
if dim == 1:
|
| 97 |
+
#print('dim = 1')
|
| 98 |
+
dX = torch.zeros(X.size(), dtype = torch.float, device = device)
|
| 99 |
+
X = X.permute(1, 0) if batched else X
|
| 100 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 101 |
+
dX[-1] = X[-1] - X[-2] # Backward Difference
|
| 102 |
+
dX[:-1] = X[1:] - X[:-1] # Forward Difference
|
| 103 |
+
|
| 104 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 105 |
+
dX /= delta_lst[0]
|
| 106 |
+
elif dim == 2:
|
| 107 |
+
#print('dim = 2')
|
| 108 |
+
dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
|
| 109 |
+
X = X.permute(1, 2, 0) if batched else X
|
| 110 |
+
dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
|
| 111 |
+
dX[-1, :, 0] = X[-1, :] - X[-2, :] # Backward Difference
|
| 112 |
+
dX[:-1, :, 0] = X[1:] - X[:-1] # Forward Difference
|
| 113 |
+
|
| 114 |
+
dX[:, -1, 1] = X[:, -1] - X[:, -2] # Backward Difference
|
| 115 |
+
dX[:, :-1, 1] = X[:, 1:] - X[:, :-1] # Forward Difference
|
| 116 |
+
|
| 117 |
+
dX = dX.permute(3, 0, 1, 2) if batched else dX
|
| 118 |
+
dX[..., 0] /= delta_lst[0]
|
| 119 |
+
dX[..., 1] /= delta_lst[1]
|
| 120 |
+
elif dim == 3:
|
| 121 |
+
#print('dim = 3')
|
| 122 |
+
dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
|
| 123 |
+
X = X.permute(1, 2, 3, 0) if batched else X
|
| 124 |
+
dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
|
| 125 |
+
dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :] # Backward Difference
|
| 126 |
+
dX[:-1, :, :, 0] = X[1:] - X[:-1] # Forward Difference
|
| 127 |
+
|
| 128 |
+
dX[:, -1, :, 1] = X[:, -1] - X[:, -2] # Backward Difference
|
| 129 |
+
dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1] # Forward Difference
|
| 130 |
+
|
| 131 |
+
dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2] # Backward Difference
|
| 132 |
+
dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1] # Forward Difference
|
| 133 |
+
|
| 134 |
+
dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
|
| 135 |
+
dX[..., 0] /= delta_lst[0]
|
| 136 |
+
dX[..., 1] /= delta_lst[1]
|
| 137 |
+
dX[..., 2] /= delta_lst[2]
|
| 138 |
+
return dX
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def gradient_b(X, batched = False, delta_lst = [1., 1., 1.]):
|
| 142 |
+
'''
|
| 143 |
+
Compute gradient of a torch tensor "X" in each direction
|
| 144 |
+
Non-boundaries & Upper-boundaries: Backward Difference
|
| 145 |
+
Lower-boundaries: Forward Difference
|
| 146 |
+
if X is batched: (n_batch, ...);
|
| 147 |
+
else: (...)
|
| 148 |
+
'''
|
| 149 |
+
device = X.device
|
| 150 |
+
dim = len(X.size()) - 1 if batched else len(X.size())
|
| 151 |
+
#print(batched)
|
| 152 |
+
#print(dim)
|
| 153 |
+
if dim == 1:
|
| 154 |
+
#print('dim = 1')
|
| 155 |
+
dX = torch.zeros(X.size(), dtype = torch.float, device = device)
|
| 156 |
+
X = X.permute(1, 0) if batched else X
|
| 157 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 158 |
+
dX[1:] = X[1:] - X[:-1] # Backward Difference
|
| 159 |
+
dX[0] = X[1] - X[0] # Forward Difference
|
| 160 |
+
|
| 161 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 162 |
+
dX /= delta_lst[0]
|
| 163 |
+
elif dim == 2:
|
| 164 |
+
#print('dim = 2')
|
| 165 |
+
dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
|
| 166 |
+
X = X.permute(1, 2, 0) if batched else X
|
| 167 |
+
dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
|
| 168 |
+
dX[1:, :, 0] = X[1:, :] - X[:-1, :] # Backward Difference
|
| 169 |
+
dX[0, :, 0] = X[1] - X[0] # Forward Difference
|
| 170 |
+
|
| 171 |
+
dX[:, 1:, 1] = X[:, 1:] - X[:, :-1] # Backward Difference
|
| 172 |
+
dX[:, 0, 1] = X[:, 1] - X[:, 0] # Forward Difference
|
| 173 |
+
|
| 174 |
+
dX = dX.permute(3, 0, 1, 2) if batched else dX
|
| 175 |
+
dX[..., 0] /= delta_lst[0]
|
| 176 |
+
dX[..., 1] /= delta_lst[1]
|
| 177 |
+
elif dim == 3:
|
| 178 |
+
#print('dim = 3')
|
| 179 |
+
dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
|
| 180 |
+
X = X.permute(1, 2, 3, 0) if batched else X
|
| 181 |
+
dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
|
| 182 |
+
dX[1:, :, :, 0] = X[1:, :, :] - X[:-1, :, :] # Backward Difference
|
| 183 |
+
dX[0, :, :, 0] = X[1] - X[0] # Forward Difference
|
| 184 |
+
|
| 185 |
+
dX[:, 1:, :, 1] = X[:, 1:] - X[:, :-1] # Backward Difference
|
| 186 |
+
dX[:, 0, :, 1] = X[:, 1] - X[:, 0] # Forward Difference
|
| 187 |
+
|
| 188 |
+
dX[:, :, 1:, 2] = X[:, :, 1:] - X[:, :, :-1] # Backward Difference
|
| 189 |
+
dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0] # Forward Difference
|
| 190 |
+
|
| 191 |
+
dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
|
| 192 |
+
dX[..., 0] /= delta_lst[0]
|
| 193 |
+
dX[..., 1] /= delta_lst[1]
|
| 194 |
+
dX[..., 2] /= delta_lst[2]
|
| 195 |
+
return dX
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def gradient_c(X, batched = False, delta_lst = [1., 1., 1.]):
|
| 199 |
+
'''
|
| 200 |
+
Compute gradient of a torch tensor "X" in each direction
|
| 201 |
+
Non-boundaries: Central Difference
|
| 202 |
+
Upper-boundaries: Backward Difference
|
| 203 |
+
Lower-boundaries: Forward Difference
|
| 204 |
+
if X is batched: (n_batch, ...);
|
| 205 |
+
else: (...)
|
| 206 |
+
'''
|
| 207 |
+
|
| 208 |
+
device = X.device
|
| 209 |
+
dim = len(X.size()) - 1 if batched else len(X.size())
|
| 210 |
+
#print(X.size())
|
| 211 |
+
#print(batched)
|
| 212 |
+
#print(dim)
|
| 213 |
+
if dim == 1:
|
| 214 |
+
#print('dim = 1')
|
| 215 |
+
dX = torch.zeros(X.size(), dtype = torch.float, device = device)
|
| 216 |
+
X = X.permute(1, 0) if batched else X
|
| 217 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 218 |
+
dX[1:-1] = (X[2:] - X[:-2]) / 2 # Central Difference
|
| 219 |
+
dX[0] = X[1] - X[0] # Forward Difference
|
| 220 |
+
dX[-1] = X[-1] - X[-2] # Backward Difference
|
| 221 |
+
|
| 222 |
+
dX = dX.permute(1, 0) if batched else dX
|
| 223 |
+
dX /= delta_lst[0]
|
| 224 |
+
elif dim == 2:
|
| 225 |
+
#print('dim = 2')
|
| 226 |
+
dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
|
| 227 |
+
X = X.permute(1, 2, 0) if batched else X
|
| 228 |
+
dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
|
| 229 |
+
dX[1:-1, :, 0] = (X[2:, :] - X[:-2, :]) / 2
|
| 230 |
+
dX[0, :, 0] = X[1] - X[0]
|
| 231 |
+
dX[-1, :, 0] = X[-1] - X[-2]
|
| 232 |
+
dX[:, 1:-1, 1] = (X[:, 2:] - X[:, :-2]) / 2
|
| 233 |
+
dX[:, 0, 1] = X[:, 1] - X[:, 0]
|
| 234 |
+
dX[:, -1, 1] = X[:, -1] - X[:, -2]
|
| 235 |
+
|
| 236 |
+
dX = dX.permute(3, 0, 1, 2) if batched else dX
|
| 237 |
+
dX[..., 0] /= delta_lst[0]
|
| 238 |
+
dX[..., 1] /= delta_lst[1]
|
| 239 |
+
elif dim == 3:
|
| 240 |
+
#print('dim = 3')
|
| 241 |
+
dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
|
| 242 |
+
X = X.permute(1, 2, 3, 0) if batched else X
|
| 243 |
+
dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
|
| 244 |
+
dX[1:-1, :, :, 0] = (X[2:, :, :] - X[:-2, :, :]) / 2
|
| 245 |
+
dX[0, :, :, 0] = X[1] - X[0]
|
| 246 |
+
dX[-1, :, :, 0] = X[-1] - X[-2]
|
| 247 |
+
dX[:, 1:-1, :, 1] = (X[:, 2:, :] - X[:, :-2, :]) / 2
|
| 248 |
+
dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
|
| 249 |
+
dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
|
| 250 |
+
dX[:, :, 1:-1, 2] = (X[:, :, 2:] - X[:, :, :-2]) / 2
|
| 251 |
+
dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
|
| 252 |
+
dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]
|
| 253 |
+
|
| 254 |
+
dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
|
| 255 |
+
dX[..., 0] /= delta_lst[0]
|
| 256 |
+
dX[..., 1] /= delta_lst[1]
|
| 257 |
+
dX[..., 2] /= delta_lst[2]
|
| 258 |
+
|
| 259 |
+
return dX
|
| 260 |
+
|
| 261 |
+
|
ShapeID/out/2d/V.png
ADDED
|
Git LFS Details
|
ShapeID/out/2d/curl.png
ADDED
|
ShapeID/out/2d/image.png
ADDED
|
ShapeID/out/2d/image_with_v.png
ADDED
|
Git LFS Details
|
ShapeID/out/2d/mask_curl.png
ADDED
|
ShapeID/out/2d/mask_image.png
ADDED
|
ShapeID/out/2d/progression/New Folder With Items/0.png
ADDED
|
Git LFS Details
|