peirong26 commited on
Commit
2571f24
·
verified ·
1 Parent(s): 56cda19

Upload 187 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +14 -0
  2. Generator/__init__.py +22 -0
  3. Generator/config.py +181 -0
  4. Generator/constants.py +290 -0
  5. Generator/datasets.py +757 -0
  6. Generator/interpol/__init__.py +7 -0
  7. Generator/interpol/_version.py +623 -0
  8. Generator/interpol/api.py +560 -0
  9. Generator/interpol/autograd.py +301 -0
  10. Generator/interpol/backend.py +1 -0
  11. Generator/interpol/bounds.py +89 -0
  12. Generator/interpol/coeff.py +344 -0
  13. Generator/interpol/iso0.py +368 -0
  14. Generator/interpol/iso1.py +1339 -0
  15. Generator/interpol/jit_utils.py +443 -0
  16. Generator/interpol/jitfields.py +95 -0
  17. Generator/interpol/nd.py +464 -0
  18. Generator/interpol/pushpull.py +325 -0
  19. Generator/interpol/resize.py +120 -0
  20. Generator/interpol/restrict.py +122 -0
  21. Generator/interpol/splines.py +196 -0
  22. Generator/interpol/tests/__init__.py +0 -0
  23. Generator/interpol/tests/test_gradcheck_pushpull.py +125 -0
  24. Generator/interpol/utils.py +176 -0
  25. Generator/utils.py +669 -0
  26. README.md +91 -3
  27. ShapeID/DiffEqs/FD.py +525 -0
  28. ShapeID/DiffEqs/adams.py +170 -0
  29. ShapeID/DiffEqs/adjoint.py +133 -0
  30. ShapeID/DiffEqs/dopri5.py +172 -0
  31. ShapeID/DiffEqs/fixed_adams.py +211 -0
  32. ShapeID/DiffEqs/fixed_grid.py +33 -0
  33. ShapeID/DiffEqs/interp.py +65 -0
  34. ShapeID/DiffEqs/misc.py +195 -0
  35. ShapeID/DiffEqs/odeint.py +75 -0
  36. ShapeID/DiffEqs/pde.py +643 -0
  37. ShapeID/DiffEqs/rk_common.py +78 -0
  38. ShapeID/DiffEqs/solvers.py +216 -0
  39. ShapeID/DiffEqs/tsit5.py +139 -0
  40. ShapeID/__init__.py +1 -0
  41. ShapeID/demo2d.py +102 -0
  42. ShapeID/demo3d.py +91 -0
  43. ShapeID/misc.py +261 -0
  44. ShapeID/out/2d/V.png +3 -0
  45. ShapeID/out/2d/curl.png +0 -0
  46. ShapeID/out/2d/image.png +0 -0
  47. ShapeID/out/2d/image_with_v.png +3 -0
  48. ShapeID/out/2d/mask_curl.png +0 -0
  49. ShapeID/out/2d/mask_image.png +0 -0
  50. ShapeID/out/2d/progression/New Folder With Items/0.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/overview.png filter=lfs diff=lfs merge=lfs -text
37
+ ShapeID/out/2d/image_with_v.png filter=lfs diff=lfs merge=lfs -text
38
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/0.png filter=lfs diff=lfs merge=lfs -text
39
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/1.png filter=lfs diff=lfs merge=lfs -text
40
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/10.png filter=lfs diff=lfs merge=lfs -text
41
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/2.png filter=lfs diff=lfs merge=lfs -text
42
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/3.png filter=lfs diff=lfs merge=lfs -text
43
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/4.png filter=lfs diff=lfs merge=lfs -text
44
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/5.png filter=lfs diff=lfs merge=lfs -text
45
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/6.png filter=lfs diff=lfs merge=lfs -text
46
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/7.png filter=lfs diff=lfs merge=lfs -text
47
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/8.png filter=lfs diff=lfs merge=lfs -text
48
+ ShapeID/out/2d/progression/New[[:space:]]Folder[[:space:]]With[[:space:]]Items/9.png filter=lfs diff=lfs merge=lfs -text
49
+ ShapeID/out/2d/V.png filter=lfs diff=lfs merge=lfs -text
Generator/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Datasets interface.

Exposes the generator registry and a helper to instantiate the configured
dataset class.
"""
from .constants import dataset_setups
from .datasets import BaseGen, BrainIDGen


# Registry: config `dataset_option` string -> generator class.
dataset_options = {
    'default': BaseGen,
    'brain_id': BrainIDGen,
}


def build_datasets(gen_args, device):
    """Helper function to build dataset for different splits ('train' or 'test')."""
    generator_cls = dataset_options[gen_args.dataset_option]
    return {'all': generator_cls(gen_args, device)}
Generator/config.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Config utilities for yml file."""
3
+ import os
4
+ from argparse import Namespace
5
+ import collections
6
+ import functools
7
+ import os
8
+ import re
9
+
10
+ import yaml
11
+ # from imaginaire.utils.distributed import master_only_print as print
12
+
13
+
14
class AttrDict(dict):
    """Dict whose keys are also accessible as attributes.

    Nested dicts (and dicts inside lists/tuples) are recursively wrapped so
    that ``cfg.a.b`` works like ``cfg['a']['b']``.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Aliasing __dict__ to self makes attribute access hit dict storage.
        self.__dict__ = self
        for key, value in self.__dict__.items():
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # Fix: original indexed value[0] unconditionally and raised
                # IndexError on an empty list/tuple.
                if len(value) > 0 and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]
                else:
                    self.__dict__[key] = value

    def yaml(self):
        """Convert object to yaml dict and return."""
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], AttrDict):
                # Fix: guard against empty lists before peeking at value[0].
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables, one `key: value` per line, nested keys indented."""
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                for line in value.__repr__().split('\n'):
                    ret_str.append('    ' + line)
            elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], AttrDict):
                # Fix: empty-list guard (original crashed on value[0]).
                ret_str.append('{}:'.format(key))
                for item in value:
                    # Treat as AttrDict above.
                    for line in item.__repr__().split('\n'):
                        ret_str.append('    ' + line)
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)
69
+
70
+
71
class Config(AttrDict):
    r"""Configuration class. This should include every human specifiable
    hyperparameter values for your training."""

    def __init__(self, filename=None, verbose=False):
        """Load a yml file into this config.

        Args:
            filename: path to a yml config file (must exist).
            verbose: if True, pretty-print the loaded config.
        Raises:
            ValueError: if the path does not exist.
        """
        super(Config, self).__init__()

        # Guard clause instead of wrapping everything in the else branch.
        if filename is None or not os.path.exists(filename):
            raise ValueError('Provided config path not existed: %s' % filename)

        # Teach SafeLoader to resolve scientific notation (e.g. `1e-4`)
        # as floats; PyYAML's default resolver misses some forms.
        loader = yaml.SafeLoader
        loader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        try:
            with open(filename, 'r') as f:
                cfg_dict = yaml.load(f, Loader=loader)
        except EnvironmentError:
            # Fix: original only printed (with unsupported lazy %-args) and
            # then used the undefined `cfg_dict`, raising a confusing
            # NameError. Report and re-raise instead.
            print('Please check the file with name of "%s"' % filename)
            raise
        recursive_update(self, cfg_dict)

        if verbose:
            print(' imaginaire config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))
105
+
106
+
107
def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path `attr` on `obj` to `val`."""
    parent_path, _, leaf = attr.rpartition('.')
    # Resolve the parent object first (identity when the path has no dots).
    target = rgetattr(obj, parent_path) if parent_path else obj
    return setattr(target, leaf, val)
111
+
112
+
113
def rgetattr(obj, attr, *args):
    """Get a possibly dotted attribute path `attr` from `obj`.

    An optional extra positional argument is used as the default at every
    step of the path (mirrors the builtin getattr signature).
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current
121
+
122
+
123
def recursive_update(d, u):
    """Recursively update AttrDict d with AttrDict u; returns d."""
    if u is None:
        return d
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            # Merge nested mappings into the existing sub-node (or a new one).
            d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
        elif isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], dict):
            # Lists of dicts are wrapped element-wise, not merged.
            d.__dict__[key] = [AttrDict(item) for item in value]
        else:
            d.__dict__[key] = value
    return d
137
+
138
+
139
def merge_and_update_from_dict(cfg, dct):
    """
    (Compatible for submitit's Dict as attribute trick)
    Merge dict as dict() to config as CfgNode().

    Nested dicts in `dct` are merged recursively into the corresponding
    sub-node of `cfg`; scalar/list values overwrite existing entries.

    Args:
        cfg: dict
        dct: dict
    Returns:
        cfg, updated in place.
    """
    if dct is not None:
        for key, value in dct.items():
            if isinstance(value, dict):
                # Reuse the existing sub-node if present, otherwise start empty.
                if key in cfg.keys():
                    sub_cfgnode = cfg[key]
                else:
                    sub_cfgnode = dict()
                # NOTE(review): __setattr__ assumes cfg supports attribute
                # assignment (e.g. AttrDict/Namespace) — a plain dict would
                # raise AttributeError here; confirm callers pass AttrDicts.
                cfg.__setattr__(key, sub_cfgnode)
                sub_cfgnode = merge_and_update_from_dict(sub_cfgnode, value)
            else:
                cfg[key] = value
    return cfg
159
+
160
+
161
def load_config(cfg_files=None, cfg_dir=''):
    """Load the first config file and merge the remaining ones into it.

    Args:
        cfg_files: list of yml config paths; later files override earlier ones.
        cfg_dir: unused, kept for backward compatibility with existing callers.
    Returns:
        A merged Config object.
    Raises:
        ValueError: if no config files are given (original raised an opaque
        IndexError; also avoids the mutable-default-argument anti-pattern).
    """
    if not cfg_files:
        raise ValueError('load_config requires at least one config file path')
    cfg = Config(cfg_files[0])
    for cfg_file in cfg_files[1:]:
        add_cfg = Config(cfg_file)
        cfg = merge_and_update_from_dict(cfg, add_cfg)
    return cfg
167
+
168
+
169
def nested_dict_to_namespace(dictionary):
    """Recursively convert a dict into an argparse.Namespace tree.

    Non-dict values are returned unchanged.
    """
    if not isinstance(dictionary, dict):
        return dictionary
    namespace = Namespace(**dictionary)
    for key, value in dictionary.items():
        setattr(namespace, key, nested_dict_to_namespace(value))
    return namespace
176
+
177
+
178
def preprocess_cfg(cfg_files, cfg_dir=''):
    """Load config files and expose the merged result as a nested Namespace."""
    merged = load_config(cfg_files, cfg_dir)
    return nested_dict_to_namespace(merged)
Generator/constants.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os, glob

from .utils import *

# Registry: augmentation-step name (from the config) -> augmentation function
# defined in .utils.
augmentation_funcs = {
    'gamma': add_gamma_transform,
    'bias_field': add_bias_field,
    'resample': resample_resolution,
    'noise': add_noise,
}

# Registry: task/modality name -> reader function from .utils. The names
# suggest each function reads the corresponding volume and applies a spatial
# deformation — TODO confirm against Generator/utils.py.
processing_funcs = {
    'T1': read_and_deform_image,
    'T2': read_and_deform_image,
    'FLAIR': read_and_deform_image,
    'CT': read_and_deform_CT,
    'segmentation': read_and_deform_segmentation,
    'surface': read_and_deform_surface,
    'distance': read_and_deform_distance,
    'bias_field': read_and_deform_bias_field,
    'registration': read_and_deform_registration,
    'pathology': read_and_deform_pathology,
}
24
+
25
+
26
+ dataset_setups = {
27
+
28
+ 'ADHD': {
29
+ 'root': '/autofs/space/yogurt_001/users/pl629/data/adhd200_crop',
30
+ 'pathology_type': None,
31
+ 'train': 'train.txt',
32
+ 'test': 'test.txt',
33
+ 'modalities': ['T1'],
34
+
35
+ 'paths':{
36
+ # for synth
37
+ 'Gen': 'label_maps_generation',
38
+ 'Dmaps': None,
39
+ 'DmapsBag': None,
40
+
41
+ # real images
42
+ 'T1': 'T1',
43
+ 'T2': None,
44
+ 'FLAIR': None,
45
+ 'CT': None,
46
+
47
+ # processed ground truths
48
+ 'surface': None, #'surfaces', TODO
49
+ 'distance': None,
50
+ 'segmentation': 'label_maps_segmentation',
51
+ 'bias_field': None,
52
+ 'pathology': None,
53
+ 'pathology_prob': None,
54
+ }
55
+ },
56
+
57
+ 'HCP': {
58
+ 'root': '/autofs/space/yogurt_001/users/pl629/data/hcp_crop',
59
+ 'pathology_type': None,
60
+ 'train': 'train.txt',
61
+ 'test': 'test.txt',
62
+ 'modalities': ['T1', 'T2'],
63
+
64
+ 'paths':{
65
+ # for synth
66
+ 'Gen': 'label_maps_generation',
67
+ 'Dmaps': None,
68
+ 'DmapsBag': None,
69
+
70
+ # real images
71
+ 'T1': 'T1',
72
+ 'T2': 'T2',
73
+ 'FLAIR': None,
74
+ 'CT': None,
75
+
76
+ # processed ground truths
77
+ 'surface': None, #'surfaces',
78
+ 'distance': None,
79
+ 'segmentation': 'label_maps_segmentation',
80
+ 'bias_field': None,
81
+ 'pathology': None,
82
+ 'pathology_prob': None,
83
+ }
84
+ },
85
+
86
+ 'AIBL': {
87
+ 'root': '/autofs/space/yogurt_001/users/pl629/data/aibl_crop',
88
+ 'pathology_type': None,
89
+ 'train': 'train.txt',
90
+ 'test': 'test.txt',
91
+ 'modalities': ['T1', 'T2', 'FLAIR'],
92
+
93
+ 'paths':{
94
+ # for synth
95
+ 'Gen': 'label_maps_generation',
96
+ 'Dmaps': None,
97
+ 'DmapsBag': None,
98
+
99
+ # real images
100
+ 'T1': 'T1',
101
+ 'T2': 'T2',
102
+ 'FLAIR': 'FLAIR',
103
+ 'CT': None,
104
+
105
+ # processed ground truths
106
+ 'surface': None, #'surfaces',
107
+ 'distance': None,
108
+ 'segmentation': 'label_maps_segmentation',
109
+ 'bias_field': None,
110
+ 'pathology': None,
111
+ 'pathology_prob': None,
112
+ }
113
+ },
114
+
115
+ 'OASIS': {
116
+ 'root': '/autofs/space/yogurt_001/users/pl629/data/oasis3',
117
+ 'pathology_type': None,
118
+ 'train': 'train.txt',
119
+ 'test': 'test.txt',
120
+ 'modalities': ['T1', 'CT'],
121
+
122
+ 'paths':{
123
+ # for synth
124
+ 'Gen': 'label_maps_generation',
125
+ 'Dmaps': None,
126
+ 'DmapsBag': None,
127
+
128
+ # real images
129
+ 'T1': 'T1',
130
+ 'T2': None,
131
+ 'FLAIR': None,
132
+ 'CT': 'CT',
133
+
134
+ # processed ground truths
135
+ 'surface': None, #'surfaces',
136
+ 'distance': None,
137
+ 'segmentation': 'label_maps_segmentation',
138
+ 'bias_field': None,
139
+ 'pathology': None,
140
+ 'pathology_prob': None,
141
+ }
142
+ },
143
+
144
+ 'ADNI': {
145
+ 'root': '/autofs/space/yogurt_001/users/pl629/data/adni_crop',
146
+ 'pathology_type': None, #'wmh',
147
+ 'train': 'train.txt',
148
+ 'test': 'test.txt',
149
+ 'modalities': ['T1'],
150
+
151
+ 'paths':{
152
+ # for synth
153
+ 'Gen': 'label_maps_generation',
154
+ 'Dmaps': 'Dmaps',
155
+ 'DmapsBag': 'DmapsBag',
156
+
157
+ # real images
158
+ 'T1': 'T1',
159
+ 'T2': None,
160
+ 'FLAIR': None,
161
+ 'CT': None,
162
+
163
+ # processed ground truths
164
+ 'surface': 'surfaces',
165
+ 'distance': 'Dmaps',
166
+ 'segmentation': 'label_maps_segmentation',
167
+ 'bias_field': None,
168
+ 'pathology': 'pathology_maps_segmentation',
169
+ 'pathology_prob': 'pathology_probability',
170
+ }
171
+ },
172
+
173
+ 'ADNI3': {
174
+ 'root': '/autofs/space/yogurt_001/users/pl629/data/adni3_crop',
175
+ 'pathology_type': None, # 'wmh',
176
+ 'train': 'train.txt',
177
+ 'test': 'test.txt',
178
+ 'modalities': ['T1', 'FLAIR'],
179
+
180
+ 'paths':{
181
+ # for synth
182
+ 'Gen': 'label_maps_generation',
183
+ 'Dmaps': None,
184
+ 'DmapsBag': None,
185
+
186
+ # real images
187
+ 'T1': 'T1',
188
+ 'T2': None,
189
+ 'FLAIR': 'FLAIR',
190
+ 'CT': None,
191
+
192
+ # processed ground truths
193
+ 'surface': None, #'surfaces', TODO
194
+ 'distance': None,
195
+ 'segmentation': 'label_maps_segmentation',
196
+ 'bias_field': None,
197
+ 'pathology': 'pathology_maps_segmentation',
198
+ 'pathology_prob': 'pathology_probability',
199
+ }
200
+ },
201
+
202
+ 'ATLAS': {
203
+ 'root': '/autofs/space/yogurt_001/users/pl629/data/atlas_crop',
204
+ 'pathology_type': 'stroke',
205
+ 'train': 'train.txt',
206
+ 'test': 'test.txt',
207
+ 'modalities': ['T1'],
208
+
209
+ 'paths':{
210
+ # for synth
211
+ 'Gen': 'label_maps_generation',
212
+ 'Dmaps': None,
213
+ 'DmapsBag': None,
214
+
215
+ # real images
216
+ 'T1': 'T1',
217
+ 'T2': None,
218
+ 'FLAIR': None,
219
+ 'CT': None,
220
+
221
+ # processed ground truths
222
+ 'surface': None, #'surfaces', TODO
223
+ 'distance': None,
224
+ 'segmentation': 'label_maps_segmentation',
225
+ 'bias_field': None,
226
+ 'pathology': 'pathology_maps_segmentation',
227
+ 'pathology_prob': 'pathology_probability',
228
+ }
229
+ },
230
+
231
+ 'ISLES': {
232
+ 'root': '/autofs/space/yogurt_001/users/pl629/data/isles2022_crop',
233
+ 'pathology_type': 'stroke',
234
+ 'train': 'train.txt',
235
+ 'test': 'test.txt',
236
+ 'modalities': ['FLAIR'],
237
+
238
+ 'paths':{
239
+ # for synth
240
+ 'Gen': 'label_maps_generation',
241
+ 'Dmaps': None,
242
+ 'DmapsBag': None,
243
+
244
+ # real images
245
+ 'T1': None,
246
+ 'T2': None,
247
+ 'FLAIR': 'FLAIR',
248
+ 'CT': None,
249
+
250
+ # processed ground truths
251
+ 'surface': None, #'surfaces', TODO
252
+ 'distance': None,
253
+ 'segmentation': 'label_maps_segmentation',
254
+ 'bias_field': None,
255
+ 'pathology': 'pathology_maps_segmentation',
256
+ 'pathology_prob': 'pathology_probability',
257
+ }
258
+ },
259
+ }
260
+
261
+
262
all_dataset_names = dataset_setups.keys()


# get all pathologies: collect pathology label maps and probability maps from
# every dataset that provides them.
pathology_paths = []
pathology_prob_paths = []
# Fix: loop variable was named `dict`, shadowing the builtin.
for name, setup in dataset_setups.items():
    # TODO: select what kind of shapes?
    # (the `is not None` check before the 'stroke' comparison was redundant)
    if setup['paths']['pathology'] is not None and setup['pathology_type'] == 'stroke':
        pathology_paths += glob.glob(os.path.join(setup['root'], setup['paths']['pathology'], '*.nii.gz')) \
            + glob.glob(os.path.join(setup['root'], setup['paths']['pathology'], '*.nii'))
        pathology_prob_paths += glob.glob(os.path.join(setup['root'], setup['paths']['pathology_prob'], '*.nii.gz')) \
            + glob.glob(os.path.join(setup['root'], setup['paths']['pathology_prob'], '*.nii'))
n_pathology = len(pathology_paths)


# with csf # NOTE old version (FreeSurfer standard), non-vast
label_list_segmentation = [0,14,15,16,24,77,85, 2, 3, 4, 7, 8, 10,11,12,13,17,18,26,28, 41,42,43,46,47,49,50,51,52,53,54,58,60] # 33
n_neutral_labels = 7


## NEW VAST synth
label_list_segmentation_brainseg_with_extracerebral = [0, 11, 12, 13, 16, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46,
                                                       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 14, 15, 17, 47, 49, 51, 53, 55,
                                                       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 48, 50, 52, 54, 56]
n_neutral_labels_brainseg_with_extracerebral = 20

label_list_segmentation_brainseg_left = [0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 17, 31, 34, 36, 38, 40, 42]
290
+
Generator/datasets.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, glob
2
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
3
+ from collections import defaultdict
4
+ import random
5
+
6
+ import torch
7
+ import numpy as np
8
+ import nibabel as nib
9
+ from torch.utils.data import Dataset
10
+
11
+
12
+ from .utils import *
13
+ from .constants import n_pathology, pathology_paths, pathology_prob_paths, \
14
+ n_neutral_labels_brainseg_with_extracerebral, label_list_segmentation_brainseg_with_extracerebral, \
15
+ label_list_segmentation_brainseg_left, augmentation_funcs, processing_funcs
16
+ import utils.interpol as interpol
17
+
18
+ from utils.misc import viewVolume
19
+
20
+
21
+ from ShapeID.DiffEqs.pde import AdvDiffPDE
22
+
23
+
24
+
25
+ class BaseGen(Dataset):
26
+ """
27
+ BaseGen dataset
28
+ """
29
    def __init__(self, gen_args, device='cpu'):
        """Set up the generator from the parsed config.

        Args:
            gen_args: nested Namespace built from the yml config
                (see Generator/config.py's preprocess_cfg).
            device: torch device string; grids and LUTs live on this device.
        """
        self.gen_args = gen_args
        self.split = gen_args.split  # e.g. 'train' or 'test'

        # Sub-configs for the individual generation/augmentation stages.
        self.synth_args = self.gen_args.generator
        self.shape_gen_args = gen_args.pathology_shape_generator
        self.real_image_args = gen_args.real_image_generator
        self.synth_image_args = gen_args.synth_image_generator
        self.augmentation_steps = vars(gen_args.augmentation_steps)
        self.input_prob = vars(gen_args.modality_probs)
        self.device = device

        # Precompute everything shared across samples (order matters:
        # prepare_paths reads self.tasks set by prepare_tasks).
        self.prepare_tasks()
        self.prepare_paths()
        self.prepare_grid()
        self.prepare_one_hot()
46
+
47
+
48
+ def __len__(self):
49
+ return sum([len(self.names[i]) for i in range(len(self.names))])
50
+
51
+
52
    def idx_to_path(self, idx):
        """Map a flat sample index to (dataset_name, modality_probs, path, age).

        Walks the concatenation of per-dataset name lists to find which
        dataset the flat index falls into.
        NOTE(review): returns None implicitly when idx >= len(self) —
        callers presumably always pass a valid index; confirm.
        """
        cnt = 0
        for i, l in enumerate(self.datasets_len):
            if idx >= cnt and idx < cnt + l:
                dataset_name = self.datasets[i]
                # Subject id = filename up to the '.T1w' suffix; keys the
                # per-dataset age dict when the 'age' task is active.
                age = self.ages[i][os.path.basename(self.names[i][idx - cnt]).split('.T1w')[0]] if len(self.ages) > 0 else None
                return dataset_name, vars(self.input_prob[dataset_name]), self.names[i][idx - cnt], age
            else:
                cnt += l
61
+
62
+
63
+ def prepare_paths(self):
64
+
65
+ # Collect list of available images, per dataset
66
+ if len(self.gen_args.dataset_names) < 1:
67
+ datasets = []
68
+ g = glob.glob(os.path.join(self.gen_args.data_root, '*' + 'T1w.nii'))
69
+ for i in range(len(g)):
70
+ filename = os.path.basename(g[i])
71
+ dataset = filename[:filename.find('.')]
72
+ found = False
73
+ for d in datasets:
74
+ if dataset == d:
75
+ found = True
76
+ if found is False:
77
+ datasets.append(dataset)
78
+ print('Found ' + str(len(datasets)) + ' datasets with ' + str(len(g)) + ' scans in total')
79
+ else:
80
+ datasets = self.gen_args.dataset_names
81
+ print('Dataset list', datasets)
82
+
83
+
84
+ names = []
85
+ if 'age' in self.tasks:
86
+ self.split = self.split + '_age'
87
+ if self.gen_args.split_root is not None:
88
+ split_file = open(os.path.join(self.gen_args.split_root, self.split + '.txt'), 'r')
89
+ split_names = []
90
+ for subj in split_file.readlines():
91
+ split_names.append(subj.strip())
92
+
93
+ for i in range(len(datasets)):
94
+ names.append([name for name in split_names if os.path.basename(name).startswith(datasets[i])])
95
+ #else:
96
+ # for i in range(len(datasets)):
97
+ # names.append(glob.glob(os.path.join(self.gen_args.data_root, datasets[i] + '.*' + 'T1w.nii')))
98
+
99
+ # read brain age
100
+ ages = []
101
+ if 'age' in self.tasks:
102
+ age_file = open(os.path.join(self.gen_args.split_root, 'participants_age.txt'), 'r')
103
+ subj_name_age = []
104
+ for line in age_file.readlines(): # 'subj age\n'
105
+ subj_name_age.append(line.strip().split(' '))
106
+ for i in range(len(datasets)):
107
+ ages.append({})
108
+ for [name, age] in subj_name_age:
109
+ if name.startswith(datasets[i]):
110
+ ages[-1][name] = float(age)
111
+ print('Age info', self.split, len(ages[0].items()), min(ages[0].values()), max(ages[0].values()))
112
+
113
+ self.ages = ages
114
+ self.names = names
115
+ self.datasets = datasets
116
+ self.datasets_num = len(datasets)
117
+ self.datasets_len = [len(self.names[i]) for i in range(len(self.names))]
118
+ print('Num of data', sum([len(self.names[i]) for i in range(len(self.names))]))
119
+
120
+ self.pathology_type = None #setup_dict['pathology_type']
121
+
122
+
123
    def prepare_tasks(self):
        """Determine active tasks and set up the pathology-shape PDE if needed.

        Sets self.tasks (names of truthy flags under gen_args.task) and,
        when pathology augmentation is active, self.t / self.adv_pde for
        advecting pathology shapes; otherwise both are None.
        """
        self.tasks = [key for (key, value) in vars(self.gen_args.task).items() if value]
        if 'bias_field' in self.tasks and 'segmentation' not in self.tasks:
            # add segmentation mask for computing bias_field_soft_mask
            self.tasks += ['segmentation']
        if 'pathology' in self.tasks and self.synth_args.augment_pathology and self.synth_args.random_shape_prob < 1.:
            # Time steps (max_nt steps of size dt) for the advection PDE below.
            self.t = torch.from_numpy(np.arange(self.shape_gen_args.max_nt) * self.shape_gen_args.dt).to(self.device)
            with torch.no_grad():
                # Advection-diffusion PDE with a divergence-free velocity field,
                # used to deform sampled pathology shapes.
                self.adv_pde = AdvDiffPDE(data_spacing=[1., 1., 1.],
                                          perf_pattern='adv',
                                          V_type='vector_div_free',
                                          V_dict={},
                                          BC=self.shape_gen_args.bc,
                                          dt=self.shape_gen_args.dt,
                                          device=self.device
                                          )
        else:
            self.t, self.adv_pde = None, None
        for task_name in self.tasks:
            # Tasks without a registered processing function are skipped later;
            # warn so misconfigured task names are visible.
            if task_name not in processing_funcs.keys():
                print('Warning: Function for task "%s" not found' % task_name)
144
+
145
+
146
+ def prepare_grid(self):
147
+ self.size = self.synth_args.size
148
+
149
+ # Get resolution of training data
150
+ #aff = nib.load(os.path.join(self.modalities['Gen'], self.names[0])).affine
151
+ #self.res_training_data = np.sqrt(np.sum(abs(aff[:-1, :-1]), axis=0))
152
+
153
+ self.res_training_data = np.array([1.0, 1.0, 1.0])
154
+
155
+ xx, yy, zz = np.meshgrid(range(self.size[0]), range(self.size[1]), range(self.size[2]), sparse=False, indexing='ij')
156
+ self.xx = torch.tensor(xx, dtype=torch.float, device=self.device)
157
+ self.yy = torch.tensor(yy, dtype=torch.float, device=self.device)
158
+ self.zz = torch.tensor(zz, dtype=torch.float, device=self.device)
159
+ self.c = torch.tensor((np.array(self.size) - 1) / 2, dtype=torch.float, device=self.device)
160
+ self.xc = self.xx - self.c[0]
161
+ self.yc = self.yy - self.c[1]
162
+ self.zc = self.zz - self.c[2]
163
+ return
164
+
165
+ def prepare_one_hot(self):
166
+ if self.synth_args.left_hemis_only:
167
+ n_labels = len(label_list_segmentation_brainseg_left)
168
+ label_list_segmentation = label_list_segmentation_brainseg_left
169
+ else:
170
+ # Matrix for one-hot encoding (includes a lookup-table)
171
+ n_labels = len(label_list_segmentation_brainseg_with_extracerebral)
172
+ label_list_segmentation = label_list_segmentation_brainseg_with_extracerebral
173
+
174
+ self.lut = torch.zeros(10000, dtype=torch.long, device=self.device)
175
+ for l in range(n_labels):
176
+ self.lut[label_list_segmentation[l]] = l
177
+ self.onehotmatrix = torch.eye(n_labels, dtype=torch.float, device=self.device)
178
+
179
+ # useless for left_hemis_only
180
+ nlat = int((n_labels - n_neutral_labels_brainseg_with_extracerebral) / 2.0)
181
+ self.vflip = np.concatenate([np.array(range(n_neutral_labels_brainseg_with_extracerebral)),
182
+ np.array(range(n_neutral_labels_brainseg_with_extracerebral + nlat, n_labels)),
183
+ np.array(range(n_neutral_labels_brainseg_with_extracerebral, n_neutral_labels_brainseg_with_extracerebral + nlat))])
184
+ return
185
+
186
+
187
    def random_affine_transform(self, shp):
        """Sample a random affine (rotation/shear/scaling) and a crop center.

        Args:
            shp: shape of the source volume; only the first three entries are used.
        Returns:
            (scaling_factor_distances, A, c2): isotropic distance-rescaling
            factor, 3x3 affine matrix A, and the sampled crop center c2 in
            source-volume coordinates.
        """
        # Uniform angles in [-max_rotation, max_rotation] degrees, converted to radians.
        rotations = (2 * self.synth_args.max_rotation * np.random.rand(3) - self.synth_args.max_rotation) / 180.0 * np.pi
        shears = (2 * self.synth_args.max_shear * np.random.rand(3) - self.synth_args.max_shear)
        scalings = 1 + (2 * self.synth_args.max_scaling * np.random.rand(3) - self.synth_args.max_scaling)
        # Cube root of the volume change: isotropic factor for distance maps.
        scaling_factor_distances = np.prod(scalings) ** .33333333333
        A = torch.tensor(make_affine_matrix(rotations, shears, scalings), dtype=torch.float, device=self.device)

        # sample center
        if self.synth_args.random_shift:
            # Largest shift that keeps the output crop inside the source volume.
            max_shift = (torch.tensor(np.array(shp[0:3]) - self.size, dtype=torch.float, device=self.device)) / 2
            max_shift[max_shift < 0] = 0
            c2 = torch.tensor((np.array(shp[0:3]) - 1)/2, dtype=torch.float, device=self.device) + (2 * (max_shift * torch.rand(3, dtype=float, device=self.device)) - max_shift)
        else:
            c2 = torch.tensor((np.array(shp[0:3]) - 1)/2, dtype=torch.float, device=self.device)
        return scaling_factor_distances, A, c2
202
+
203
    def random_nonlinear_transform(self, photo_mode, spac):
        """Sample a smooth random displacement field over the output grid.

        Args:
            photo_mode: if True, zero the second (coronal) component and tie
                the small-field resolution in that axis to `spac`.
            spac: slice spacing used in photo mode.
        Returns:
            (F, Fneg): forward displacement field of shape [*self.size, 3] and,
            when 'surface' is an active task, an approximate inverse field;
            otherwise Fneg is None.
        """
        # Low-resolution field size: a random fraction of the output size.
        nonlin_scale = self.synth_args.nonlin_scale_min + np.random.rand(1) * (self.synth_args.nonlin_scale_max - self.synth_args.nonlin_scale_min)
        size_F_small = np.round(nonlin_scale * np.array(self.size)).astype(int).tolist()
        if photo_mode:
            size_F_small[1] = np.round(self.size[1]/spac).astype(int)
        nonlin_std = self.synth_args.nonlin_std_max * np.random.rand()
        Fsmall = nonlin_std * torch.randn([*size_F_small, 3], dtype=torch.float, device=self.device)
        # Upsample the small random field to full output resolution.
        F = myzoom_torch(Fsmall, np.array(self.size) / size_F_small)
        if photo_mode:
            F[:, :, :, 1] = 0

        if 'surface' in self.tasks: # TODO need to integrate the non-linear deformation fields for inverse
            # Scaling-and-squaring style integration of the field (treated as a
            # stationary velocity field) and of its negation for the inverse.
            steplength = 1.0 / (2.0 ** self.synth_args.n_steps_svf_integration)
            Fsvf = F * steplength
            for _ in range(self.synth_args.n_steps_svf_integration):
                Fsvf += fast_3D_interp_torch(Fsvf, self.xx + Fsvf[:, :, :, 0], self.yy + Fsvf[:, :, :, 1], self.zz + Fsvf[:, :, :, 2], 'linear')
            Fsvf_neg = -F * steplength
            for _ in range(self.synth_args.n_steps_svf_integration):
                Fsvf_neg += fast_3D_interp_torch(Fsvf_neg, self.xx + Fsvf_neg[:, :, :, 0], self.yy + Fsvf_neg[:, :, :, 1], self.zz + Fsvf_neg[:, :, :, 2], 'linear')
            F = Fsvf
            Fneg = Fsvf_neg
        else:
            Fneg = None
        return F, Fneg
227
+
228
+ def generate_deformation(self, setups, shp):
229
+
230
+ # generate affine deformation
231
+ scaling_factor_distances, A, c2 = self.random_affine_transform(shp)
232
+
233
+ # generate nonlinear deformation
234
+ if self.synth_args.nonlinear_transform:
235
+ F, Fneg = self.random_nonlinear_transform(setups['photo_mode'], setups['spac'])
236
+ else:
237
+ F, Fneg = None, None
238
+
239
+ # deform the image grid
240
+ xx2, yy2, zz2, x1, y1, z1, x2, y2, z2 = self.deform_grid(shp, A, c2, F)
241
+
242
+ return {'scaling_factor_distances': scaling_factor_distances,
243
+ 'A': A,
244
+ 'c2': c2,
245
+ 'F': F,
246
+ 'Fneg': Fneg,
247
+ 'grid': [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2],
248
+ }
249
+
250
+
251
    def get_left_hemis_mask(self, grid):
        """Compute a binary left-hemisphere mask; stores it in self.hemis_mask.

        Combines the segmentation (non-background after LUT mapping) with the
        registration x-coordinate map (x < 0 presumably means left in MNI
        space — TODO confirm), restricted to the crop margins in `grid`.
        NOTE(review): reads self.modalities, which is not set in the visible
        part of this class — confirm which caller populates it first.
        """
        [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = grid

        if self.synth_args.left_hemis_only:
            S, aff, res = read_image(self.modalities['segmentation']) # read seg map
            S = torch.squeeze(torch.from_numpy(S.get_fdata()[x1:x2, y1:y2, z1:z2].astype(int))).to(self.device)
            S = self.lut[S.int()] # mask out non-left labels
            X, aff, res = read_image(self.modalities['registration'][0]) # read_mni_coord_X
            X = torch.squeeze(torch.from_numpy(X.get_fdata()[x1:x2, y1:y2, z1:z2])).to(self.device)
            self.hemis_mask = ((S > 0) & (X < 0)).int()
        else:
            self.hemis_mask = None
263
+
264
+ def deform_grid(self, shp, A, c2, F):
265
+ if F is not None:
266
+ # deform the images (we do nonlinear "first" ie after so we can do heavy coronal deformations in photo mode)
267
+ xx1 = self.xc + F[:, :, :, 0]
268
+ yy1 = self.yc + F[:, :, :, 1]
269
+ zz1 = self.zc + F[:, :, :, 2]
270
+ else:
271
+ xx1 = self.xc
272
+ yy1 = self.yc
273
+ zz1 = self.zc
274
+
275
+ xx2 = A[0, 0] * xx1 + A[0, 1] * yy1 + A[0, 2] * zz1 + c2[0]
276
+ yy2 = A[1, 0] * xx1 + A[1, 1] * yy1 + A[1, 2] * zz1 + c2[1]
277
+ zz2 = A[2, 0] * xx1 + A[2, 1] * yy1 + A[2, 2] * zz1 + c2[2]
278
+ xx2[xx2 < 0] = 0
279
+ yy2[yy2 < 0] = 0
280
+ zz2[zz2 < 0] = 0
281
+ xx2[xx2 > (shp[0] - 1)] = shp[0] - 1
282
+ yy2[yy2 > (shp[1] - 1)] = shp[1] - 1
283
+ zz2[zz2 > (shp[2] - 1)] = shp[2] - 1
284
+
285
+ # Get the margins for reading images
286
+ x1 = torch.floor(torch.min(xx2))
287
+ y1 = torch.floor(torch.min(yy2))
288
+ z1 = torch.floor(torch.min(zz2))
289
+ x2 = 1+torch.ceil(torch.max(xx2))
290
+ y2 = 1 + torch.ceil(torch.max(yy2))
291
+ z2 = 1 + torch.ceil(torch.max(zz2))
292
+ xx2 -= x1
293
+ yy2 -= y1
294
+ zz2 -= z1
295
+
296
+ x1 = x1.cpu().numpy().astype(int)
297
+ y1 = y1.cpu().numpy().astype(int)
298
+ z1 = z1.cpu().numpy().astype(int)
299
+ x2 = x2.cpu().numpy().astype(int)
300
+ y2 = y2.cpu().numpy().astype(int)
301
+ z2 = z2.cpu().numpy().astype(int)
302
+
303
+ return xx2, yy2, zz2, x1, y1, z1, x2, y2, z2
304
+
305
+
306
    def augment_sample(self, name, I_def, setups, deform_dict, res, target, pathol_direction = None, input_mode = 'synth'):
        """Turn a (real or already-synthesized) volume into a training input.

        If `I_def` is not yet a tensor (real-image path: a nibabel-like object),
        it is cropped, hemisphere-masked and deformed first. Pathology is then
        optionally encoded, the configured augmentation chain is run, and the
        result is resized back to `self.size` and max-normalized.

        Returns a dict with 'input' and, depending on the tasks, optional
        'high_res_residual' and 'bias_field_log' entries (each flipped along
        axis 0 when setups['flip'] is set).
        """
        sample = {}
        [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

        # Real-image path: synth images arrive as tensors already deformed by
        # generate_sample(), so this conversion/deformation is skipped for them.
        if not isinstance(I_def, torch.Tensor):
            I_def = torch.squeeze(torch.tensor(I_def.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device))
            if self.hemis_mask is not None:
                I_def[self.hemis_mask == 0] = 0
            # Deform grid
            I_def = fast_3D_interp_torch(I_def, xx2, yy2, zz2, 'linear')

        # CT intensities are clipped to the soft-tissue window [0, 80] HU.
        if input_mode == 'CT':
            I_def = torch.clamp(I_def, min = 0., max = 80.)

        # Encode a synthetic lesion when a non-empty pathology target exists;
        # otherwise zero-out the pathology targets for this item.
        if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0:
            I_def = self.encode_pathology(I_def, target['pathology'], target['pathology_prob'], pathol_direction)
            I_def[I_def < 0.] = 0.
        else:
            target['pathology'] = 0.
            target['pathology_prob'] = 0.

        # Augment sample: each step returns the image plus an aux dict that is
        # threaded through (e.g. 'factors', 'high_res', 'BFlog' — produced by
        # the augmentation functions; assumed keys, confirm in utils).
        aux_dict = {}
        augmentation_steps = self.augmentation_steps['synth'] if input_mode == 'synth' else self.augmentation_steps['real']
        for func_name in augmentation_steps:
            I_def, aux_dict = augmentation_funcs[func_name](I = I_def, aux_dict = aux_dict, cfg = self.gen_args.generator,
                        input_mode = input_mode, setups = setups, size = self.size, res = res, device = self.device)

        # Back to original resolution
        if self.synth_args.bspline_zooming:
            I_def = interpol.resize(I_def, shape=self.size, anchor='edge', interpolation=3, bound='dct2', prefilter=True)
        else:
            I_def = myzoom_torch(I_def, 1 / aux_dict['factors'])

        # Max-normalize; the same factor is reused for the SR residual below.
        maxi = torch.max(I_def)
        I_final = I_def / maxi

        if 'super_resolution' in self.tasks:
            SRresidual = aux_dict['high_res'] / maxi - I_final
            sample.update({'high_res_residual': torch.flip(SRresidual, [0])[None] if setups['flip'] else SRresidual[None]})

        sample.update({'input': torch.flip(I_final, [0])[None] if setups['flip'] else I_final[None]})
        # Bias field is not simulated for CT inputs.
        if 'bias_field' in self.tasks and input_mode != 'CT':
            sample.update({'bias_field_log': torch.flip(aux_dict['BFlog'], [0])[None] if setups['flip'] else aux_dict['BFlog'][None]})

        return sample
355
+
356
+
357
    def generate_sample(self, name, G, setups, deform_dict, res, target):
        """Synthesize an image from the generation-label volume `G` and hand it
        to augment_sample().

        Draws per-label Gaussian intensities, optionally mixes the synthetic
        image with real contrasts, restricts the pathology targets to the
        cerebral support, and picks the lesion intensity direction.

        Returns (pathology, pathology_prob, sample_dict).
        """
        [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

        # Generate contrasts
        mus, sigmas = self.get_contrast(setups['photo_mode'])

        G = torch.squeeze(torch.tensor(G.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device))
        #G[G > 255] = 0 # kill extracerebral regions
        G[G == 77] = 2 # merge WM lesion to white matter region
        if self.hemis_mask is not None:
            G[self.hemis_mask == 0] = 0
        Gr = torch.round(G).long()

        # Per-voxel Gaussian sample, indexed by rounded label.
        SYN = mus[Gr] + sigmas[Gr] * torch.randn(Gr.shape, dtype=torch.float, device=self.device)
        SYN[SYN < 0] = 0
        #SYN /= mus[2] # normalize by WM
        #SYN = gaussian_blur_3d(SYN, 0.5*np.ones(3), self.device) # cosmetic

        SYN = fast_3D_interp_torch(SYN, xx2, yy2, zz2)

        # Make random linear combinations
        if np.random.rand() < self.gen_args.mix_synth_prob:
            v = torch.rand(4)
            v[2] = 0 if 'T2' not in self.modalities else v[2]
            v[3] = 0 if 'FLAIR' not in self.modalities else v[3]
            v /= torch.sum(v)  # weights sum to 1
            SYN = v[0] * SYN + v[1] * target['T1'][0]
            if 'T2' in self.modalities:
                SYN += v[2] * target['T2'][0]
            if 'FLAIR' in self.modalities:
                SYN += v[3] * target['FLAIR'][0]

        if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0:
            # NOTE(review): SYN has already been interpolated above, yet
            # SYN_cerebral is masked with the *label-space* Gr and interpolated
            # a second time — this only works if Gr and the deformed grid share
            # a shape; confirm intent (possible double deformation).
            SYN_cerebral = SYN.clone()
            SYN_cerebral[Gr == 0] = 0
            SYN_cerebral = fast_3D_interp_torch(SYN_cerebral, xx2, yy2, zz2)[None]

            # Mean intensities of WM (labels 2/41) vs everything else labeled.
            wm_mask = (Gr==2) | (Gr==41)
            wm_mean = (SYN * wm_mask).sum() / wm_mask.sum()
            gm_mask = (Gr!=0) & (Gr!=2) & (Gr!=41)
            gm_mean = (SYN * gm_mask).sum() / gm_mask.sum()

            # Restrict pathology to voxels with cerebral signal.
            target['pathology'][SYN_cerebral == 0] = 0
            target['pathology_prob'][SYN_cerebral == 0] = 0
            # determine to be T1-resembled or T2-resembled
            #if pathol_direction: lesion should be brigher than WM.mean()
            # pathol_direction: +1: T2-like; -1: T1-like
            pathol_direction = self.get_pathology_direction('synth', gm_mean > wm_mean)
        else:
            pathol_direction = None
            target['pathology'] = 0.
            target['pathology_prob'] = 0.

        SYN[SYN < 0.] = 0.
        return target['pathology'], target['pathology_prob'], self.augment_sample(name, SYN, setups, deform_dict, res, target, pathol_direction = pathol_direction)
413
+
414
+ def get_pathology_direction(self, input_mode, pathol_direction = None):
415
+ #if np.random.rand() < 0.1: # in some (rare) cases, randomly pick the direction
416
+ # return random.choice([True, False])
417
+
418
+ if pathol_direction is not None: # for synth image
419
+ return pathol_direction
420
+
421
+ if input_mode in ['T1', 'CT']:
422
+ return False
423
+
424
+ if input_mode in ['T2', 'FLAIR']:
425
+ return True
426
+
427
+ return random.choice([True, False])
428
+
429
+
430
    def get_contrast(self, photo_mode):
        """Sample per-label Gaussian intensity parameters for synthesis.

        Returns two 256-entry float tensors (mus, sigmas) indexed by
        generation label. With probability ct_prob the means of a few label
        groups are pinned to CT-like brightness bands; labels 100-250 are
        overwritten with partial-volume mixes of labels 1-4.
        """
        # Sample Gaussian image
        mus = 25 + 200 * torch.rand(256, dtype=torch.float, device=self.device)
        sigmas = 5 + 20 * torch.rand(256, dtype=torch.float, device=self.device)

        # CT-like contrast: one shared mean per brightness group.
        if np.random.rand() < self.synth_args.ct_prob:
            darker = 25 + 10 * torch.rand(1, dtype=torch.float, device=self.device)[0]
            for l in ct_brightness_group['darker']:
                mus[l] = darker
            dark = 90 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0]
            for l in ct_brightness_group['dark']:
                mus[l] = dark
            bright = 110 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0]
            for l in ct_brightness_group['bright']:
                mus[l] = bright
            brighter = 150 + 50 * torch.rand(1, dtype=torch.float, device=self.device)[0]
            for l in ct_brightness_group['brighter']:
                mus[l] = brighter

        if photo_mode or np.random.rand(1)<0.5: # set the background to zero every once in a while (or always in photo mode)
            mus[0] = 0

        # partial volume
        # 1 = lesion, 2 = WM, 3 = GM, 4 = CSF
        # v ramps 0 -> 0.98 over 50 steps: labels 100-149 mix labels 1&2,
        # 150-199 mix 2&3, 200-249 mix 3&4, and 250 is pure label 4.
        v = 0.02 * torch.arange(50).to(self.device)
        mus[100:150] = mus[1] * (1 - v) + mus[2] * v
        mus[150:200] = mus[2] * (1 - v) + mus[3] * v
        mus[200:250] = mus[3] * (1 - v) + mus[4] * v
        mus[250] = mus[4]
        # Mix variances (not std-devs) so the blended noise level is consistent.
        sigmas[100:150] = torch.sqrt(sigmas[1]**2 * (1 - v) + sigmas[2]**2 * v)
        sigmas[150:200] = torch.sqrt(sigmas[2]**2 * (1 - v) + sigmas[3]**2 * v)
        sigmas[200:250] = torch.sqrt(sigmas[3]**2 * (1 - v) + sigmas[4]**2 * v)
        sigmas[250] = sigmas[4]

        return mus, sigmas
465
+
466
+ def get_setup_params(self):
467
+
468
+ if self.synth_args.left_hemis_only:
469
+ hemis = 'left'
470
+ else:
471
+ hemis = 'both'
472
+
473
+ if self.synth_args.low_res_only:
474
+ photo_mode = False
475
+ elif self.synth_args.left_hemis_only:
476
+ photo_mode = True
477
+ else:
478
+ photo_mode = np.random.rand() < self.synth_args.photo_prob
479
+
480
+ pathol_mode = np.random.rand() < self.synth_args.pathology_prob
481
+ pathol_random_shape = np.random.rand() < self.synth_args.random_shape_prob
482
+ spac = 2.5 + 10 * np.random.rand() if photo_mode else None
483
+ flip = np.random.randn() < self.synth_args.flip_prob if not self.synth_args.left_hemis_only else False
484
+
485
+ if photo_mode:
486
+ resolution = np.array([self.res_training_data[0], spac, self.res_training_data[2]])
487
+ thickness = np.array([self.res_training_data[0], 0.1, self.res_training_data[2]])
488
+ else:
489
+ resolution, thickness = resolution_sampler(self.synth_args.low_res_only)
490
+ return {'resolution': resolution, 'thickness': thickness,
491
+ 'photo_mode': photo_mode, 'pathol_mode': pathol_mode,
492
+ 'pathol_random_shape': pathol_random_shape,
493
+ 'spac': spac, 'flip': flip, 'hemis': hemis}
494
+
495
+
496
+ def encode_pathology(self, I, P, Pprob, pathol_direction = None):
497
+
498
+
499
+ if pathol_direction is None: # True: T2/FLAIR-resembled, False: T1-resembled
500
+ pathol_direction = random.choice([True, False])
501
+
502
+ P, Pprob = torch.squeeze(P), torch.squeeze(Pprob)
503
+ I_mu = (I * P).sum() / P.sum()
504
+
505
+ p_mask = torch.round(P).long()
506
+ #pth_mus = I_mu/4 + I_mu/2 * torch.rand(10000, dtype=torch.float, device=self.device)
507
+ pth_mus = 3*I_mu/4 + I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device) # enforce the pathology pattern harder!
508
+ pth_mus = pth_mus if pathol_direction else -pth_mus
509
+ pth_sigmas = I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device)
510
+ I += Pprob * (pth_mus[p_mask] + pth_sigmas[p_mask] * torch.randn(p_mask.shape, dtype=torch.float, device=self.device))
511
+ I[I < 0] = 0
512
+
513
+ #print('encode', P.shape, P.mean())
514
+ #print('pre', I_mu)
515
+ #I_mu = (I * P).sum() / P.sum()
516
+ #print('post', I_mu)
517
+
518
+ return I
519
+
520
+ def get_info(self, t1):
521
+
522
+ t1dm = t1[:-7] + 'T1w.defacingmask.nii'
523
+ t2 = t1[:-7] + 'T2w.nii'
524
+ t2dm = t1[:-7] + 'T2w.defacingmask.nii'
525
+ flair = t1[:-7] + 'FLAIR.nii'
526
+ flairdm = t1[:-7] + 'FLAIR.defacingmask.nii'
527
+ ct = t1[:-7] + 'CT.nii'
528
+ ctdm = t1[:-7] + 'CT.defacingmask.nii'
529
+ generation_labels = t1[:-7] + 'generation_labels.nii'
530
+ segmentation_labels = t1[:-7] + self.gen_args.segment_prefix + '.nii'
531
+ #brain_dist_map = t1[:-7] + 'brain_dist_map.nii'
532
+ lp_dist_map = t1[:-7] + 'lp_dist_map.nii'
533
+ rp_dist_map = t1[:-7] + 'rp_dist_map.nii'
534
+ lw_dist_map = t1[:-7] + 'lw_dist_map.nii'
535
+ rw_dist_map = t1[:-7] + 'rw_dist_map.nii'
536
+ mni_reg_x = t1[:-7] + 'mni_reg.x.nii'
537
+ mni_reg_y = t1[:-7] + 'mni_reg.y.nii'
538
+ mni_reg_z = t1[:-7] + 'mni_reg.z.nii'
539
+
540
+
541
+ self.modalities = {'T1': t1, 'Gen': generation_labels, 'segmentation': segmentation_labels,
542
+ 'distance': [lp_dist_map, lw_dist_map, rp_dist_map, rw_dist_map],
543
+ 'registration': [mni_reg_x, mni_reg_y, mni_reg_z]}
544
+
545
+ if os.path.isfile(t1dm):
546
+ self.modalities.update({'T1_DM': t1dm})
547
+ if os.path.isfile(t2):
548
+ self.modalities.update({'T2': t2})
549
+ if os.path.isfile(t2dm):
550
+ self.modalities.update({'T2_DM': t2dm})
551
+ if os.path.isfile(flair):
552
+ self.modalities.update({'FLAIR': flair})
553
+ if os.path.isfile(flairdm):
554
+ self.modalities.update({'FLAIR_DM': flairdm})
555
+ if os.path.isfile(ct):
556
+ self.modalities.update({'CT': ct})
557
+ if os.path.isfile(ctdm):
558
+ self.modalities.update({'CT_DM': ctdm})
559
+
560
+ return self.modalities
561
+
562
+
563
+ def read_input(self, idx):
564
+ """
565
+ determine input type according to prob (in generator/constants.py)
566
+ Logic: if np.random.rand() < real_image_prob and is real_image_exist --> input real images; otherwise, synthesize images.
567
+ """
568
+ dataset_name, input_prob, t1_path, age = self.idx_to_path(idx)
569
+ case_name = os.path.basename(t1_path).split('.T1w.nii')[0]
570
+ self.modalities = self.get_info(t1_path)
571
+
572
+ prob = np.random.rand()
573
+ if prob < input_prob['T1'] and 'T1' in self.modalities:
574
+ input_mode = 'T1'
575
+ img, aff, res = read_image(self.modalities['T1'])
576
+ elif prob < input_prob['T2'] and 'T2' in self.modalities:
577
+ input_mode = 'T2'
578
+ img, aff, res = read_image(self.modalities['T2'])
579
+ elif prob < input_prob['FLAIR'] and 'FLAIR' in self.modalities:
580
+ input_mode = 'FLAIR'
581
+ img, aff, res = read_image(self.modalities['FLAIR'])
582
+ elif prob < input_prob['CT'] and 'CT' in self.modalities:
583
+ input_mode = 'CT'
584
+ img, aff, res = read_image(self.modalities['CT'])
585
+ else:
586
+ input_mode = 'synth'
587
+ img, aff, res = read_image(self.modalities['Gen'])
588
+
589
+ return dataset_name, case_name, input_mode, img, aff, res, age
590
+
591
+
592
    def read_and_deform_target(self, idx, exist_keys, task_name, input_mode, setups, deform_dict, linear_weights = None):
        """Read, deform and return the target(s) for one task as a dict.

        Pathology targets are special-cased: a lesion probability source is
        chosen (random shape or a sampled lesion-prob file) and handed to the
        pathology processing function. All other tasks dispatch to their
        processing function only when the corresponding modality file exists.
        """
        current_target = {}
        # Defaults: no pathology source, no augmentation, threshold 0.1.
        p_prob_path, augment, thres = None, False, 0.1

        if task_name == 'pathology':
            # NOTE: for now - encode pathology only for healthy cases
            # TODO: what to do if the case has pathology itself? -- inconsistency between encoded pathol and the output
            if self.pathology_type is None: # healthy
                if setups['pathol_mode']: # and input_mode == 'synth':
                    if setups['pathol_random_shape']:
                        # Procedurally generated lesion shape (no augmentation).
                        p_prob_path = 'random_shape'
                        augment, thres = False, self.shape_gen_args.pathol_thres
                    else:
                        # Sample a real lesion probability map from the pool.
                        p_prob_path = random.choice(pathology_prob_paths)
                        augment, thres = self.synth_args.augment_pathology, self.shape_gen_args.pathol_thres
            else:
                pass
                #p_prob_path = self.modalities['pathology_prob']

            current_target = processing_funcs[task_name](exist_keys, task_name, p_prob_path, setups, deform_dict, self.device,
                                    mask = self.hemis_mask,
                                    augment = augment,
                                    pde_func = self.adv_pde,
                                    t = self.t,
                                    shape_gen_args = self.shape_gen_args,
                                    thres = thres
                                    )

        else:
            if task_name in self.modalities:
                current_target = processing_funcs[task_name](exist_keys, task_name, self.modalities[task_name],
                                    setups, deform_dict, self.device,
                                    mask = self.hemis_mask,
                                    cfg = self.gen_args,
                                    onehotmatrix = self.onehotmatrix,
                                    lut = self.lut, vflip = self.vflip
                                    )
            else:
                # Missing modality: emit a scalar placeholder target.
                current_target = {task_name: 0.}
        return current_target
632
+
633
+
634
+ def update_gen_args(self, new_args):
635
+ for key, value in vars(new_args).items():
636
+ vars(self.gen_args.generator)[key] = value
637
+
638
    def __getitem__(self, idx):
        """Build one training item: pick a real/synth input, sample a random
        deformation, read+deform all task targets, then produce the augmented
        input sample.

        Returns (datasets_num, dataset_name, input_mode, target, sample).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # read input: real or synthesized image, according to customized prob
        dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx)

        # generate random values
        setups = self.get_setup_params()

        # sample random deformation
        deform_dict = self.generate_deformation(setups, img.shape)

        # get left_hemis_mask if needed
        self.get_left_hemis_mask(deform_dict['grid'])

        # read and deform target according to the assigned tasks
        # (missing keys default to None for this base class)
        target = defaultdict(lambda: None)
        target['name'] = case_name
        target.update(self.read_and_deform_target(idx, target.keys(), 'T1', input_mode, setups, deform_dict))
        target.update(self.read_and_deform_target(idx, target.keys(), 'T2', input_mode, setups, deform_dict))
        target.update(self.read_and_deform_target(idx, target.keys(), 'FLAIR', input_mode, setups, deform_dict))
        for task_name in self.tasks:
            if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']:
                target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))


        # process or generate input sample
        if input_mode == 'synth':
            self.update_gen_args(self.synth_image_args) # severe noise injection for synthesized images
            target['pathology'], target['pathology_prob'], sample = \
                self.generate_sample(case_name, img, setups, deform_dict, res, target)
        else:
            self.update_gen_args(self.real_image_args) # milder noise injection for real images
            sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
                            pathol_direction = self.get_pathology_direction(input_mode),input_mode = input_mode)

        if setups['flip'] and isinstance(target['pathology'], torch.Tensor): # flipping should happen after P has been encoded
            target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1])

        if age is not None:
            target['age'] = age

        return self.datasets_num, dataset_name, input_mode, target, sample
682
+
683
+
684
+
685
+
686
# An example of customized dataset from BaseSynth
class BrainIDGen(BaseGen):
    """
    BrainIDGen dataset
    BrainIDGen enables intra-subject augmentation, i.e., each subject will have multiple augmentations
    """
    def __init__(self, gen_args, device='cpu'):
        super(BrainIDGen, self).__init__(gen_args, device)

        # Number of augmented samples per subject: the first `mild_samples`
        # use the mild augmentation profile, the remainder the severe one.
        self.all_samples = gen_args.generator.all_samples
        self.mild_samples = gen_args.generator.mild_samples
        self.mild_generator_args = gen_args.mild_generator
        self.severe_generator_args = gen_args.severe_generator

    def _make_sample(self, case_name, input_mode, img, setups, deform_dict, res, target):
        """Produce one augmented sample; for synth inputs also updates the
        pathology targets in `target`."""
        if input_mode == 'synth':
            self.update_gen_args(self.synth_image_args)  # severe noise injection for synthesized images
            target['pathology'], target['pathology_prob'], sample = \
                self.generate_sample(case_name, img, setups, deform_dict, res, target)
        else:
            self.update_gen_args(self.real_image_args)  # milder noise injection for real images
            sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
                            pathol_direction = self.get_pathology_direction(input_mode), input_mode = input_mode)
        return sample

    def __getitem__(self, idx):
        """Build one training item with `all_samples` intra-subject augmentations.

        Returns (datasets_num, dataset_name, input_mode, target, samples),
        where `samples` is a list of per-augmentation sample dicts.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # read input: real or synthesized image, according to customized prob
        dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx)

        # generate random values
        setups = self.get_setup_params()

        # sample random deformation
        deform_dict = self.generate_deformation(setups, img.shape)

        # get left_hemis_mask if needed
        self.get_left_hemis_mask(deform_dict['grid'])

        # read and deform target according to the assigned tasks
        target = defaultdict(lambda: 1.)
        target['name'] = case_name
        for task_name in ('T1', 'T2', 'FLAIR'):
            target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))
        for task_name in self.tasks:
            if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']:
                target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))

        # process or generate intra-subject input samples. The mild/severe
        # branches of the original were duplicated verbatim; only the profile
        # applied first differs, so select it and share the rest.
        samples = []
        for i_sample in range(self.all_samples):
            profile = self.mild_generator_args if i_sample < self.mild_samples else self.severe_generator_args
            self.update_gen_args(profile)
            samples.append(self._make_sample(case_name, input_mode, img, setups, deform_dict, res, target))

        if setups['flip'] and isinstance(target['pathology'], torch.Tensor): # flipping should happen after P has been encoded
            target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1])

        if age is not None:
            target['age'] = age
        return self.datasets_num, dataset_name, input_mode, target, samples
Generator/interpol/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .api import *
2
+ from .resize import *
3
+ from .restrict import *
4
+ from . import backend
5
+
6
+ from . import _version
7
+ __version__ = _version.get_versions()['version']
Generator/interpol/_version.py ADDED
@@ -0,0 +1,623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # This file helps to compute a version number in source trees obtained from
3
+ # git-archive tarball (such as those provided by github's download-from-tag
4
+ # feature). Distribution tarballs (built by setup.py sdist) and build
5
+ # directories (produced by setup.py build) will contain a much shorter file
6
+ # that just contains the computed version number.
7
+
8
+ # This file is released into the public domain. Generated by
9
+ # versioneer-0.20 (https://github.com/python-versioneer/python-versioneer)
10
+
11
+ """Git implementation of _version.py."""
12
+
13
+ import errno
14
+ import os
15
+ import re
16
+ import subprocess
17
+ import sys
18
+
19
+
20
def get_keywords():
    """Get the keywords needed to look up the version information."""
    # These strings are substituted by git during `git archive`.
    # setup.py/versioneer.py greps for the variable names, so each assignment
    # must stay on a line of its own; _version.py just calls get_keywords().
    git_refnames = " (HEAD -> main, tag: 0.2.3)"
    git_full = "414ed52c973b9d32e3e6a5a75c91cd5aab064f23"
    git_date = "2023-04-17 20:36:50 -0400"
    return {"refnames": git_refnames, "full": git_full, "date": git_date}
31
+
32
+
33
# Plain attribute bag; get_config() below populates VCS/style/tag_prefix/etc.
class VersioneerConfig:  # pylint: disable=too-few-public-methods
    """Container for Versioneer configuration parameters."""
35
+
36
+
37
def get_config():
    """Create, populate and return the VersioneerConfig() object."""
    # these strings are filled in when 'setup.py versioneer' creates
    # _version.py
    cfg = VersioneerConfig()
    cfg.VCS = "git"                  # version-control backend
    cfg.style = "pep440"             # rendering style for version strings
    cfg.tag_prefix = ""              # tags are bare versions, e.g. "0.2.3"
    cfg.parentdir_prefix = ""
    cfg.versionfile_source = "interpol/_version.py"
    cfg.verbose = False
    return cfg
49
+
50
+
51
# Control-flow exception raised by the version-lookup strategies below
# (e.g. versions_from_parentdir) to signal "this strategy does not apply".
class NotThisMethod(Exception):
    """Exception raised if a method is not valid for the current scenario."""
53
+
54
+
55
# HANDLERS maps VCS name -> {method name -> handler fn}, filled in by the
# @register_vcs_handler decorator. LONG_VERSION_PY is left empty here;
# presumably versioneer's template store — not read in this file.
LONG_VERSION_PY = {}
HANDLERS = {}
57
+
58
+
59
def register_vcs_handler(vcs, method):  # decorator
    """Create decorator to mark a method as the handler of a VCS."""
    def decorate(f):
        """Store f in HANDLERS[vcs][method]."""
        HANDLERS.setdefault(vcs, {})[method] = f
        return f
    return decorate
68
+
69
+
70
# pylint:disable=too-many-arguments,consider-using-with # noqa
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
                env=None):
    """Call the given command(s).

    Tries each candidate name in `commands` until one can be launched.
    Returns (stdout, returncode) on success, (None, None) when no command
    could be launched, and (None, returncode) on a nonzero exit status.
    """
    assert isinstance(commands, list)
    process = None
    for command in commands:
        try:
            dispcmd = str([command] + args)
            # remember shell=False, so use git.cmd on windows, not just git
            process = subprocess.Popen([command] + args, cwd=cwd, env=env,
                                       stdout=subprocess.PIPE,
                                       stderr=(subprocess.PIPE if hide_stderr
                                               else None))
            break
        except EnvironmentError:
            e = sys.exc_info()[1]
            if e.errno == errno.ENOENT:
                # Command not found: try the next candidate name.
                continue
            if verbose:
                print("unable to run %s" % dispcmd)
                print(e)
            return None, None
    else:
        # for/else: no candidate was launched at all.
        if verbose:
            print("unable to find command, tried %s" % (commands,))
        return None, None
    stdout = process.communicate()[0].strip().decode()
    if process.returncode != 0:
        if verbose:
            print("unable to run %s (error)" % dispcmd)
            print("stdout was %s" % stdout)
        return None, process.returncode
    return stdout, process.returncode
104
+
105
+
106
def versions_from_parentdir(parentdir_prefix, root, verbose):
    """Try to determine the version from the parent directory name.

    Source tarballs conventionally unpack into a directory that includes both
    the project name and a version string. We will also support searching up
    two directory levels for an appropriately named parent directory
    """
    tried = []

    # Check `root` itself plus up to two ancestors.
    for _ in range(3):
        leaf = os.path.basename(root)
        if leaf.startswith(parentdir_prefix):
            # Everything after the prefix is taken as the version string.
            return {"version": leaf[len(parentdir_prefix):],
                    "full-revisionid": None,
                    "dirty": False, "error": None, "date": None}
        tried.append(root)
        root = os.path.dirname(root)  # up a level

    if verbose:
        print("Tried directories %s but none started with prefix %s" %
              (str(tried), parentdir_prefix))
    raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
128
+
129
+
130
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
    """Extract version information from the given file."""
    # The code embedded in _version.py can just read these variables directly.
    # setup.py, however, must not import _version.py, so it scrapes the values
    # out with a regexp instead. This function is not used from _version.py.
    wanted = (
        ("git_refnames =", "refnames"),
        ("git_full =", "full"),
        ("git_date =", "date"),
    )
    keywords = {}
    try:
        with open(versionfile_abs, "r") as fobj:
            for line in fobj:
                stripped = line.strip()
                for prefix, key in wanted:
                    if stripped.startswith(prefix):
                        mo = re.search(r'=\s*"(.*)"', line)
                        if mo:
                            keywords[key] = mo.group(1)
    except EnvironmentError:
        # Missing/unreadable file: return whatever (possibly nothing) we got.
        pass
    return keywords
156
+
157
+
158
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
    """Get version information from git keywords.

    Picks the best-matching tag out of the expanded git-archive refnames and
    returns the standard versioneer result dict; raises NotThisMethod when the
    keywords are absent or unexpanded.
    """
    if "refnames" not in keywords:
        raise NotThisMethod("Short version file found")
    date = keywords.get("date")
    if date is not None:
        # Use only the last line. Previous lines may contain GPG signature
        # information.
        date = date.splitlines()[-1]

        # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
        # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
        # -like" string, which we must then edit to make compliant), because
        # it's been around since git-1.5.3, and it's too difficult to
        # discover which version we're using, or to work around using an
        # older one.
        date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
    refnames = keywords["refnames"].strip()
    if refnames.startswith("$Format"):
        # Keywords were never expanded (not a git-archive tarball).
        if verbose:
            print("keywords are unexpanded, not using")
        raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
    refs = {r.strip() for r in refnames.strip("()").split(",")}
    # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
    # just "foo-1.0". If we see a "tag: " prefix, prefer those.
    TAG = "tag: "
    tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
    if not tags:
        # Either we're using git < 1.8.3, or there really are no tags. We use
        # a heuristic: assume all version tags have a digit. The old git %d
        # expansion behaves like git log --decorate=short and strips out the
        # refs/heads/ and refs/tags/ prefixes that would let us distinguish
        # between branches and tags. By ignoring refnames without digits, we
        # filter out many common branch names like "release" and
        # "stabilization", as well as "HEAD" and "master".
        tags = {r for r in refs if re.search(r'\d', r)}
        if verbose:
            print("discarding '%s', no digits" % ",".join(refs - tags))
    if verbose:
        print("likely tags: %s" % ",".join(sorted(tags)))
    for ref in sorted(tags):
        # sorting will prefer e.g. "2.0" over "2.0rc1"
        if ref.startswith(tag_prefix):
            r = ref[len(tag_prefix):]
            # Filter out refs that exactly match prefix or that don't start
            # with a number once the prefix is stripped (mostly a concern
            # when prefix is '')
            if not re.match(r'\d', r):
                continue
            if verbose:
                print("picking %s" % r)
            return {"version": r,
                    "full-revisionid": keywords["full"].strip(),
                    "dirty": False, "error": None,
                    "date": date}
    # no suitable tags, so version is "0+unknown", but full hex is still there
    if verbose:
        print("no suitable tags, using unknown + full revision id")
    return {"version": "0+unknown",
            "full-revisionid": keywords["full"].strip(),
            "dirty": False, "error": "no suitable tags", "date": None}
220
+
221
+
222
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
    """Get version from 'git describe' in the root of the source tree.

    This only gets called if the git-archive 'subst' keywords were *not*
    expanded, and _version.py hasn't already been rewritten with a short
    version string, meaning we're inside a checked out source tree.

    Parameters
    ----------
    tag_prefix : str
        Prefix carried by release tags; it is stripped from the reported tag.
    root : str
        Directory expected to be under git control.
    verbose : bool
        Print diagnostics to stdout.
    runner : callable
        Command runner used to invoke git (defaults to ``run_command``);
        expected to return ``(stdout_or_None, returncode)``.

    Returns
    -------
    dict
        "Pieces" of the version: keys ``long``, ``short``, ``closest-tag``,
        ``distance``, ``dirty``, ``branch``, ``date`` and ``error``.

    Raises
    ------
    NotThisMethod
        If ``root`` is not under git control or git queries fail.
    """
    GITS = ["git"]
    if sys.platform == "win32":
        GITS = ["git.cmd", "git.exe"]

    # Probe first: a non-repo directory means this whole method is unusable.
    _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
                   hide_stderr=True)
    if rc != 0:
        if verbose:
            print("Directory %s not under git control" % root)
        raise NotThisMethod("'git rev-parse --git-dir' returned error")

    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
    # if there isn't one, this yields HEX[-dirty] (no NUM)
    describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty",
                                     "--always", "--long",
                                     "--match", "%s*" % tag_prefix],
                              cwd=root)
    # --long was added in git-1.5.5
    if describe_out is None:
        raise NotThisMethod("'git describe' failed")
    describe_out = describe_out.strip()
    full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
    if full_out is None:
        raise NotThisMethod("'git rev-parse' failed")
    full_out = full_out.strip()

    pieces = {}
    pieces["long"] = full_out
    pieces["short"] = full_out[:7]  # maybe improved later
    pieces["error"] = None

    branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
                             cwd=root)
    # --abbrev-ref was added in git-1.6.3
    if rc != 0 or branch_name is None:
        raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
    branch_name = branch_name.strip()

    if branch_name == "HEAD":
        # If we aren't exactly on a branch, pick a branch which represents
        # the current commit. If all else fails, we are on a branchless
        # commit.
        branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
        # --contains was added in git-1.5.4
        if rc != 0 or branches is None:
            raise NotThisMethod("'git branch --contains' returned error")
        branches = branches.split("\n")

        # Remove the first line if we're running detached
        if "(" in branches[0]:
            branches.pop(0)

        # Strip off the leading "* " from the list of branches.
        branches = [branch[2:] for branch in branches]
        if "master" in branches:
            branch_name = "master"
        elif not branches:
            branch_name = None
        else:
            # Pick the first branch that is returned. Good or bad.
            branch_name = branches[0]

    pieces["branch"] = branch_name

    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
    # TAG might have hyphens.
    git_describe = describe_out

    # look for -dirty suffix
    dirty = git_describe.endswith("-dirty")
    pieces["dirty"] = dirty
    if dirty:
        git_describe = git_describe[:git_describe.rindex("-dirty")]

    # now we have TAG-NUM-gHEX or HEX

    if "-" in git_describe:
        # TAG-NUM-gHEX
        mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
        if not mo:
            # unparseable. Maybe git-describe is misbehaving?
            pieces["error"] = ("unable to parse git-describe output: '%s'"
                              % describe_out)
            return pieces

        # tag
        full_tag = mo.group(1)
        if not full_tag.startswith(tag_prefix):
            if verbose:
                fmt = "tag '%s' doesn't start with prefix '%s'"
                print(fmt % (full_tag, tag_prefix))
            pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
                              % (full_tag, tag_prefix))
            return pieces
        pieces["closest-tag"] = full_tag[len(tag_prefix):]

        # distance: number of commits since tag
        pieces["distance"] = int(mo.group(2))

        # commit: short hex revision ID
        pieces["short"] = mo.group(3)

    else:
        # HEX: no tags
        pieces["closest-tag"] = None
        count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
        pieces["distance"] = int(count_out)  # total number of commits

    # commit date: see ISO-8601 comment in git_versions_from_keywords()
    date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
    # Use only the last line. Previous lines may contain GPG signature
    # information.
    date = date.splitlines()[-1]
    pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)

    return pieces
346
+
347
+
348
def plus_or_dot(pieces):
    """Return a + if we don't already have one, else return a ."""
    tag = pieces.get("closest-tag", "")
    return "." if "+" in tag else "+"
353
+
354
+
355
def render_pep440(pieces):
    """Build up version string, with post-release "local version identifier".

    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty

    Exceptions:
    1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
    """
    tag = pieces["closest-tag"]
    if tag:
        version = tag
        if pieces["distance"] or pieces["dirty"]:
            # local-version suffix only when we moved past (or dirtied) the tag
            version += "%s%d.g%s" % (plus_or_dot(pieces),
                                     pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                version += ".dirty"
    else:
        # exception #1
        version = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            version += ".dirty"
    return version
378
+
379
+
380
def render_pep440_branch(pieces):
    """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .

    The ".dev0" means not master branch. Note that .dev0 sorts backwards
    (a feature branch will appear "older" than the master branch).

    Exceptions:
    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
    """
    off_master = pieces["branch"] != "master"
    if pieces["closest-tag"]:
        version = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            if off_master:
                version += ".dev0"
            version += plus_or_dot(pieces)
            version += "%d.g%s" % (pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                version += ".dirty"
    else:
        # exception #1
        version = "0"
        if off_master:
            version += ".dev0"
        version += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            version += ".dirty"
    return version
408
+
409
+
410
def render_pep440_pre(pieces):
    """TAG[.post0.devDISTANCE] -- No -dirty.

    Exceptions:
    1: no tags. 0.post0.devDISTANCE
    """
    tag = pieces["closest-tag"]
    if not tag:
        # exception #1
        return "0.post0.dev%d" % pieces["distance"]
    if not pieces["distance"]:
        return tag
    return tag + ".post0.dev%d" % pieces["distance"]
424
+
425
+
426
def render_pep440_post(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX] .

    The ".dev0" means dirty. Note that .dev0 sorts backwards
    (a dirty tree will appear "older" than the corresponding clean one),
    but you shouldn't be releasing software with -dirty anyways.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    distance, dirty = pieces["distance"], pieces["dirty"]
    if pieces["closest-tag"]:
        version = pieces["closest-tag"]
        if distance or dirty:
            version += ".post%d" % distance
            if dirty:
                version += ".dev0"
            version += plus_or_dot(pieces) + "g%s" % pieces["short"]
    else:
        # exception #1
        version = "0.post%d" % distance
        if dirty:
            version += ".dev0"
        version += "+g%s" % pieces["short"]
    return version
451
+
452
+
453
def render_pep440_post_branch(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .

    The ".dev0" means not master branch.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
    """
    off_master = pieces["branch"] != "master"
    if pieces["closest-tag"]:
        version = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            version += ".post%d" % pieces["distance"]
            if off_master:
                version += ".dev0"
            version += plus_or_dot(pieces) + "g%s" % pieces["short"]
            if pieces["dirty"]:
                version += ".dirty"
    else:
        # exception #1
        version = "0.post%d" % pieces["distance"]
        if off_master:
            version += ".dev0"
        version += "+g%s" % pieces["short"]
        if pieces["dirty"]:
            version += ".dirty"
    return version
480
+
481
+
482
def render_pep440_old(pieces):
    """TAG[.postDISTANCE[.dev0]] .

    The ".dev0" means dirty.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    dev = ".dev0" if pieces["dirty"] else ""
    tag = pieces["closest-tag"]
    if tag:
        if not (pieces["distance"] or pieces["dirty"]):
            return tag
        return tag + ".post%d" % pieces["distance"] + dev
    # exception #1
    return "0.post%d" % pieces["distance"] + dev
502
+
503
+
504
def render_git_describe(pieces):
    """TAG[-DISTANCE-gHEX][-dirty].

    Like 'git describe --tags --dirty --always'.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    tag = pieces["closest-tag"]
    if tag:
        version = tag
        if pieces["distance"]:
            version = "%s-%d-g%s" % (tag, pieces["distance"], pieces["short"])
    else:
        # exception #1
        version = pieces["short"]
    return version + ("-dirty" if pieces["dirty"] else "")
522
+
523
+
524
def render_git_describe_long(pieces):
    """TAG-DISTANCE-gHEX[-dirty].

    Like 'git describe --tags --dirty --always -long'.
    The distance/hash is unconditional.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    tag = pieces["closest-tag"]
    if tag:
        version = "%s-%d-g%s" % (tag, pieces["distance"], pieces["short"])
    else:
        # exception #1
        version = pieces["short"]
    return version + ("-dirty" if pieces["dirty"] else "")
542
+
543
+
544
def render(pieces, style):
    """Render the given version pieces into the requested style."""
    if pieces["error"]:
        # an earlier stage recorded a failure: surface it as-is
        return {"version": "unknown",
                "full-revisionid": pieces.get("long"),
                "dirty": None,
                "error": pieces["error"],
                "date": None}

    if not style or style == "default":
        style = "pep440"  # the default

    renderers = {
        "pep440": render_pep440,
        "pep440-branch": render_pep440_branch,
        "pep440-pre": render_pep440_pre,
        "pep440-post": render_pep440_post,
        "pep440-post-branch": render_pep440_post_branch,
        "pep440-old": render_pep440_old,
        "git-describe": render_git_describe,
        "git-describe-long": render_git_describe_long,
    }
    renderer = renderers.get(style)
    if renderer is None:
        raise ValueError("unknown style '%s'" % style)
    rendered = renderer(pieces)

    return {"version": rendered, "full-revisionid": pieces["long"],
            "dirty": pieces["dirty"], "error": None,
            "date": pieces.get("date")}
578
+
579
+
580
def get_versions():
    """Get version information or return default if unable to do so."""
    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
    # __file__, we can work backwards from there to the root. Some
    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
    # case we can only use expanded keywords.
    cfg = get_config()
    verbose = cfg.verbose

    # 1) try the git-archive expanded keywords
    try:
        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
                                          verbose)
    except NotThisMethod:
        pass

    # 2) locate the source-tree root by walking up from this file
    try:
        root = os.path.realpath(__file__)
        # versionfile_source is the relative path from the top of the source
        # tree (where the .git directory might live) to this file. Invert
        # this to find the root from __file__.
        for _ in cfg.versionfile_source.split('/'):
            root = os.path.dirname(root)
    except NameError:
        return {"version": "0+unknown", "full-revisionid": None,
                "dirty": None,
                "error": "unable to find root of source tree",
                "date": None}

    # 3) ask git directly
    try:
        return render(git_pieces_from_vcs(cfg.tag_prefix, root, verbose),
                      cfg.style)
    except NotThisMethod:
        pass

    # 4) fall back to the parent-directory naming convention
    try:
        if cfg.parentdir_prefix:
            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
    except NotThisMethod:
        pass

    return {"version": "0+unknown", "full-revisionid": None,
            "dirty": None,
            "error": "unable to compute version", "date": None}
Generator/interpol/api.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """High level interpolation API"""
2
+
3
+ __all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad',
4
+ 'spline_coeff', 'spline_coeff_nd',
5
+ 'identity_grid', 'add_identity_grid', 'add_identity_grid_']
6
+
7
+ import torch
8
+ from .utils import expanded_shape, matvec
9
+ from .jit_utils import movedim1, meshgrid
10
+ from .autograd import (GridPull, GridPush, GridCount, GridGrad,
11
+ SplineCoeff, SplineCoeffND)
12
+ from . import backend, jitfields
13
+
14
+ _doc_interpolation = \
15
+ """`interpolation` can be an int, a string or an InterpolationType.
16
+ Possible values are:
17
+ - 0 or 'nearest'
18
+ - 1 or 'linear'
19
+ - 2 or 'quadratic'
20
+ - 3 or 'cubic'
21
+ - 4 or 'fourth'
22
+ - 5 or 'fifth'
23
+ - etc.
24
+ A list of values can be provided, in the order [W, H, D],
25
+ to specify dimension-specific interpolation orders."""
26
+
27
+ _doc_bound = \
28
+ """`bound` can be an int, a string or a BoundType.
29
+ Possible values are:
30
+ - 'replicate' or 'nearest' : a a a | a b c d | d d d
31
+ - 'dct1' or 'mirror' : d c b | a b c d | c b a
32
+ - 'dct2' or 'reflect' : c b a | a b c d | d c b
33
+ - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c
34
+ - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b
35
+ - 'dft' or 'wrap' : b c d | a b c d | a b c
36
+ - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0
37
+ A list of values can be provided, in the order [W, H, D],
38
+ to specify dimension-specific boundary conditions.
39
+ Note that
40
+ - `dft` corresponds to circular padding
41
+ - `dct2` corresponds to Neumann boundary conditions (symmetric)
42
+ - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)
43
+ See https://en.wikipedia.org/wiki/Discrete_cosine_transform
44
+ https://en.wikipedia.org/wiki/Discrete_sine_transform"""
45
+
46
+ _doc_bound_coeff = \
47
+ """`bound` can be an int, a string or a BoundType.
48
+ Possible values are:
49
+ - 'replicate' or 'nearest' : a a a | a b c d | d d d
50
+ - 'dct1' or 'mirror' : d c b | a b c d | c b a
51
+ - 'dct2' or 'reflect' : c b a | a b c d | d c b
52
+ - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c
53
+ - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b
54
+ - 'dft' or 'wrap' : b c d | a b c d | a b c
55
+ - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0
56
+ A list of values can be provided, in the order [W, H, D],
57
+ to specify dimension-specific boundary conditions.
58
+ Note that
59
+ - `dft` corresponds to circular padding
60
+ - `dct1` corresponds to mirroring about the center of the first/last voxel
61
+ - `dct2` corresponds to mirroring about the edge of the first/last voxel
62
+ See https://en.wikipedia.org/wiki/Discrete_cosine_transform
63
+ https://en.wikipedia.org/wiki/Discrete_sine_transform
64
+
65
+ /!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation
66
+ orders >= 6."""
67
+
68
+ _ref_coeff = \
69
+ """..[1] M. Unser, A. Aldroubi and M. Eden.
70
+ "B-Spline Signal Processing: Part I-Theory,"
71
+ IEEE Transactions on Signal Processing 41(2):821-832 (1993).
72
+ ..[2] M. Unser, A. Aldroubi and M. Eden.
73
+ "B-Spline Signal Processing: Part II-Efficient Design and Applications,"
74
+ IEEE Transactions on Signal Processing 41(2):834-848 (1993).
75
+ ..[3] M. Unser.
76
+ "Splines: A Perfect Fit for Signal and Image Processing,"
77
+ IEEE Signal Processing Magazine 16(6):22-38 (1999).
78
+ """
79
+
80
+
81
def _preproc(grid, input=None, mode=None):
    """Preprocess tensors for pull/push/count/grad

    Low level bindings expect inputs of shape
    [batch, channel, *spatial] and [batch, *spatial, dim], whereas
    the high level python API accepts inputs of shape
    [..., [channel], *spatial] and [..., *spatial, dim].

    This function broadcasts and reshapes the input tensors accordingly.
    /!\\ This *can* trigger large allocations /!\\

    Parameters
    ----------
    grid : (..., *spatial, dim) tensor
        Transformation field; its last dimension defines `dim`.
    input : (..., [channel], *spatial) tensor, optional
        Image to resample. If None, only the grid is processed (count case).
    mode : str, optional
        'push' broadcasts the grid and input spatial shapes together.

    Returns
    -------
    (grid, info) if input is None, else (grid, input, info)
        `info` records the original batch/channel layout so that
        `_postproc` can restore it.
    """
    dim = grid.shape[-1]
    if input is None:
        # count case: collapse all leading (batch) dims into one
        spatial = grid.shape[-dim-1:-1]
        batch = grid.shape[:-dim-1]
        grid = grid.reshape([-1, *spatial, dim])
        info = dict(batch=batch, channel=[1] if batch else [], dim=dim)
        return grid, info

    grid_spatial = grid.shape[-dim-1:-1]
    grid_batch = grid.shape[:-dim-1]
    input_spatial = input.shape[-dim:]
    # no explicit channel axis when the input has exactly `dim` dimensions
    channel = 0 if input.dim() == dim else input.shape[-dim-1]
    input_batch = input.shape[:-dim-1]

    if mode == 'push':
        # push requires grid and input to live on the same (broadcast) lattice
        grid_spatial = input_spatial = expanded_shape(grid_spatial, input_spatial)

    # broadcast and reshape
    batch = expanded_shape(grid_batch, input_batch)
    grid = grid.expand([*batch, *grid_spatial, dim])
    grid = grid.reshape([-1, *grid_spatial, dim])
    input = input.expand([*batch, channel or 1, *input_spatial])
    input = input.reshape([-1, channel or 1, *input_spatial])

    out_channel = [channel] if channel else ([1] if batch else [])
    info = dict(batch=batch, channel=out_channel, dim=dim)
    return grid, input, info
119
+
120
+
121
+ def _postproc(out, shape_info, mode):
122
+ """Postprocess tensors for pull/push/count/grad"""
123
+ dim = shape_info['dim']
124
+ if mode != 'grad':
125
+ spatial = out.shape[-dim:]
126
+ feat = []
127
+ else:
128
+ spatial = out.shape[-dim-1:-1]
129
+ feat = [out.shape[-1]]
130
+ batch = shape_info['batch']
131
+ channel = shape_info['channel']
132
+
133
+ out = out.reshape([*batch, *channel, *spatial, *feat])
134
+ return out
135
+
136
+
137
def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample an image with respect to a deformation field.

    Notes
    -----
    {interpolation}

    {bound}

    If the input dtype is not a floating point type, the input image is
    assumed to contain labels. Then, unique labels are extracted
    and resampled individually, making them soft labels. Finally,
    the label map is reconstructed from the individual soft labels by
    assigning the label with maximum soft value.

    Parameters
    ----------
    input : (..., [channel], *inshape) tensor
        Input image.
    grid : (..., *outshape, dim) tensor
        Transformation field.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType or sequence[BoundType], default='zero'
        Boundary conditions.
    extrapolate : bool or int, default=False
        Extrapolate out-of-bound data.
    prefilter : bool, default=False
        Apply spline pre-filter (= interpolates the input)

    Returns
    -------
    output : (..., [channel], *outshape) tensor
        Deformed image.

    """
    if backend.jitfields and jitfields.available:
        return jitfields.grid_pull(input, grid, interpolation, bound,
                                   extrapolate, prefilter)

    grid, input, shape_info = _preproc(grid, input)
    batch, channel = input.shape[:2]
    dim = grid.shape[-1]

    if not input.dtype.is_floating_point:
        # label map -> resample each label's indicator image and keep, at
        # every voxel, the label whose soft (interpolated) value is largest
        out = input.new_zeros([batch, channel, *grid.shape[1:-1]])
        pmax = grid.new_zeros([batch, channel, *grid.shape[1:-1]])
        for label in input.unique():
            soft = (input == label).to(grid.dtype)
            if prefilter:
                # BUGFIX: the prefiltered coefficients used to be assigned
                # to `input`, which both discarded them and corrupted the
                # label map used by subsequent iterations.
                soft = spline_coeff_nd(soft, interpolation=interpolation,
                                       bound=bound, dim=dim, inplace=True)
            soft = GridPull.apply(soft, grid, interpolation, bound, extrapolate)
            out[soft > pmax] = label
            pmax = torch.max(pmax, soft)
    else:
        if prefilter:
            input = spline_coeff_nd(input, interpolation=interpolation,
                                    bound=bound, dim=dim)
        out = GridPull.apply(input, grid, interpolation, bound, extrapolate)

    return _postproc(out, shape_info, mode='pull')
201
+
202
+
203
def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Splat an image with respect to a deformation field (pull adjoint).

    Notes
    -----
    {interpolation}

    {bound}

    Parameters
    ----------
    input : (..., [channel], *inshape) tensor
        Input image.
    grid : (..., *inshape, dim) tensor
        Transformation field.
    shape : sequence[int], default=inshape
        Output shape
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType, or sequence[BoundType], default='zero'
        Boundary conditions.
    extrapolate : bool or int, default=False
        Extrapolate out-of-bound data.
    prefilter : bool, default=False
        Apply spline pre-filter.

    Returns
    -------
    output : (..., [channel], *shape) tensor
        Splatted image.

    """
    if backend.jitfields and jitfields.available:
        return jitfields.grid_push(input, grid, shape, interpolation, bound,
                                   extrapolate, prefilter)

    grid, input, shape_info = _preproc(grid, input, mode='push')
    dim = grid.shape[-1]

    if shape is None:
        # default output lattice = (broadcast) input spatial shape
        shape = tuple(input.shape[2:])

    out = GridPush.apply(input, grid, shape, interpolation, bound, extrapolate)
    if prefilter:
        # adjoint of pull's prefilter: filter after splatting
        out = spline_coeff_nd(out, interpolation=interpolation, bound=bound,
                              dim=dim, inplace=True)
    return _postproc(out, shape_info, mode='push')
251
+
252
+
253
def grid_count(grid, shape=None, interpolation='linear', bound='zero',
               extrapolate=False):
    """Splatting weights with respect to a deformation field (pull adjoint).

    Notes
    -----
    {interpolation}

    {bound}

    Parameters
    ----------
    grid : (..., *inshape, dim) tensor
        Transformation field.
    shape : sequence[int], default=inshape
        Output shape
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType, or sequence[BoundType], default='zero'
        Boundary conditions.
    extrapolate : bool or int, default=False
        Extrapolate out-of-bound data.

    Returns
    -------
    output : (..., [1], *shape) tensor
        Splatted weights.

    """
    if backend.jitfields and jitfields.available:
        return jitfields.grid_count(grid, shape, interpolation, bound, extrapolate)

    grid, shape_info = _preproc(grid)
    # NOTE(review): unlike grid_push, a None `shape` is forwarded directly to
    # GridCount — presumably defaulted downstream; verify against GridCount.
    out = GridCount.apply(grid, shape, interpolation, bound, extrapolate)
    return _postproc(out, shape_info, mode='count')
288
+
289
+
290
def grid_grad(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample spatial gradients of an image with respect to a deformation field.

    Notes
    -----
    {interpolation}

    {bound}

    Parameters
    ----------
    input : (..., [channel], *inshape) tensor
        Input image.
    grid : (..., *outshape, dim) tensor
        Transformation field.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType, or sequence[BoundType], default='zero'
        Boundary conditions.
    extrapolate : bool or int, default=False
        Extrapolate out-of-bound data.
    prefilter : bool, default=False
        Apply spline pre-filter (= interpolates the input)

    Returns
    -------
    output : (..., [channel], *outshape, dim) tensor
        Sampled gradients.

    """
    if backend.jitfields and jitfields.available:
        return jitfields.grid_grad(input, grid, interpolation, bound,
                                   extrapolate, prefilter)

    grid, input, shape_info = _preproc(grid, input)
    dim = grid.shape[-1]
    if prefilter:
        input = spline_coeff_nd(input, interpolation, bound, dim)
    out = GridGrad.apply(input, grid, interpolation, bound, extrapolate)
    return _postproc(out, shape_info, mode='grad')
333
+
334
+
335
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
                 inplace=False):
    """Compute the interpolating spline coefficients, for a given spline order
    and boundary conditions, along a single dimension.

    Notes
    -----
    {interpolation}

    {bound}

    References
    ----------
    {ref}


    Parameters
    ----------
    input : tensor
        Input image.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType or sequence[BoundType], default='dct2'
        Boundary conditions.
    dim : int, default=-1
        Dimension along which to process
    inplace : bool, default=False
        Process the volume in place.

    Returns
    -------
    output : tensor
        Coefficient image.

    """
    # This implementation is based on the file bsplines.c in SPM12, written
    # by John Ashburner, which is itself based on the file coeff.c,
    # written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation
    # . DCT1 boundary conditions were derived by Thevenaz and Unser.
    # . DFT boundary conditions were derived by John Ashburner.
    # SPM12 is released under the GNU-GPL v2 license.
    # Philippe Thevenaz's code does not have an explicit license as far
    # as we know.
    if backend.jitfields and jitfields.available:
        return jitfields.spline_coeff(input, interpolation, bound,
                                      dim, inplace)

    # NOTE: the low-level autograd function takes (bound, interpolation),
    # i.e. the reverse of this wrapper's keyword order.
    out = SplineCoeff.apply(input, bound, interpolation, dim, inplace)
    return out
384
+
385
+
386
def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
                    inplace=False):
    """Compute the interpolating spline coefficients, for a given spline order
    and boundary conditions, along the last `dim` dimensions.

    Notes
    -----
    {interpolation}

    {bound}

    References
    ----------
    {ref}

    Parameters
    ----------
    input : (..., *spatial) tensor
        Input image.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    bound : BoundType or sequence[BoundType], default='dct2'
        Boundary conditions.
    dim : int, default=None
        Number of spatial dimensions
    inplace : bool, default=False
        Process the volume in place.

    Returns
    -------
    output : (..., *spatial) tensor
        Coefficient image.

    """
    # This implementation is based on the file bsplines.c in SPM12, written
    # by John Ashburner, which is itself based on the file coeff.c,
    # written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation
    # . DCT1 boundary conditions were derived by Thevenaz and Unser.
    # . DFT boundary conditions were derived by John Ashburner.
    # SPM12 is released under the GNU-GPL v2 license.
    # Philippe Thevenaz's code does not have an explicit license as far
    # as we know.
    if backend.jitfields and jitfields.available:
        return jitfields.spline_coeff_nd(input, interpolation, bound,
                                         dim, inplace)

    # NOTE: the low-level autograd function takes (bound, interpolation),
    # i.e. the reverse of this wrapper's keyword order.
    out = SplineCoeffND.apply(input, bound, interpolation, dim, inplace)
    return out
434
+
435
+
436
# Inject the shared interpolation/bound/reference docs into each docstring.
grid_pull.__doc__ = grid_pull.__doc__.format(
    interpolation=_doc_interpolation, bound=_doc_bound)
grid_push.__doc__ = grid_push.__doc__.format(
    interpolation=_doc_interpolation, bound=_doc_bound)
grid_count.__doc__ = grid_count.__doc__.format(
    interpolation=_doc_interpolation, bound=_doc_bound)
grid_grad.__doc__ = grid_grad.__doc__.format(
    interpolation=_doc_interpolation, bound=_doc_bound)
spline_coeff.__doc__ = spline_coeff.__doc__.format(
    interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff)
spline_coeff_nd.__doc__ = spline_coeff_nd.__doc__.format(
    interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff)

# aliases (short names kept for convenience / backward compatibility)
pull = grid_pull
push = grid_push
count = grid_count
453
+
454
+
455
def identity_grid(shape, dtype=None, device=None):
    """Returns an identity deformation field.

    Parameters
    ----------
    shape : (dim,) sequence of int
        Spatial dimension of the field.
    dtype : torch.dtype, default=`get_default_dtype()`
        Data type.
    device : torch.device, optional
        Device.

    Returns
    -------
    grid : (*shape, dim) tensor
        Transformation field

    """
    axes = [torch.arange(float(n), dtype=dtype, device=device)
            for n in shape]
    return torch.stack(meshgrid(axes), dim=-1)
477
+
478
+
479
@torch.jit.script
def add_identity_grid_(disp):
    """Adds the identity grid to a displacement field, inplace.

    Parameters
    ----------
    disp : (..., *spatial, dim) tensor
        Displacement field

    Returns
    -------
    grid : (..., *spatial, dim) tensor
        Transformation field

    """
    dim = disp.shape[-1]
    spatial = disp.shape[-dim-1:-1]
    mesh1d = [torch.arange(s, dtype=disp.dtype, device=disp.device)
              for s in spatial]
    grid = meshgrid(mesh1d)
    # bring the coordinate axis first so each component can be updated in place
    disp = movedim1(disp, -1, 0)
    for i, grid1 in enumerate(grid):
        disp[i].add_(grid1)
    disp = movedim1(disp, 0, -1)
    return disp
504
+
505
+
506
@torch.jit.script
def add_identity_grid(disp):
    """Adds the identity grid to a displacement field.

    Parameters
    ----------
    disp : (..., *spatial, dim) tensor
        Displacement field

    Returns
    -------
    grid : (..., *spatial, dim) tensor
        Transformation field

    """
    # out-of-place variant: clone first, then defer to the in-place version
    return add_identity_grid_(disp.clone())
522
+
523
+
524
def affine_grid(mat, shape):
    """Create a dense transformation grid from an affine matrix.

    Parameters
    ----------
    mat : (..., D[+1], D+1) tensor
        Affine matrix (or matrices).
    shape : (D,) sequence[int]
        Shape of the grid, with length D.

    Returns
    -------
    grid : (..., *shape, D) tensor
        Dense transformation grid

    Raises
    ------
    ValueError
        If the matrix dimension does not match `len(shape)`, or if the
        matrix is not of shape (..., D, D+1) or (..., D+1, D+1).

    """
    mat = torch.as_tensor(mat)
    shape = list(shape)
    nb_dim = mat.shape[-1] - 1
    if nb_dim != len(shape):
        raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
                         'are not the same.'.format(nb_dim, len(shape)))
    if mat.shape[-2] not in (nb_dim, nb_dim+1):
        # BUGFIX: the format string used to contain the broken placeholder
        # '{1]', which made `.format` itself raise a ValueError instead of
        # producing this message; also fixed the "matrces" typo.
        raise ValueError('First argument should be matrices of shape '
                         '(..., {0}, {1}) or (..., {1}, {1}) but got {2}.'
                         .format(nb_dim, nb_dim+1, mat.shape))
    batch_shape = mat.shape[:-2]
    grid = identity_grid(shape, mat.dtype, mat.device)
    if batch_shape:
        # align batch dims of the grid with spatial dims of the matrix
        for _ in range(len(batch_shape)):
            grid = grid.unsqueeze(0)
        for _ in range(nb_dim):
            mat = mat.unsqueeze(-1)
    lin = mat[..., :nb_dim, :nb_dim]
    off = mat[..., :nb_dim, -1]
    grid = matvec(lin, grid) + off
    return grid
Generator/interpol/autograd.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AutoGrad version of pull/push/count/grad"""
2
+ import torch
3
+ from .coeff import spline_coeff_nd, spline_coeff
4
+ from .bounds import BoundType
5
+ from .splines import InterpolationType
6
+ from .pushpull import (
7
+ grid_pull, grid_pull_backward,
8
+ grid_push, grid_push_backward,
9
+ grid_count, grid_count_backward,
10
+ grid_grad, grid_grad_backward)
11
+ from .utils import fake_decorator
12
+ try:
13
+ from torch.cuda.amp import custom_fwd, custom_bwd
14
+ except (ModuleNotFoundError, ImportError):
15
+ custom_fwd = custom_bwd = fake_decorator
16
+
17
+
18
def make_list(x):
    """Return `x` as a list, wrapping scalars in a singleton list."""
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]
22
+
23
+
24
def bound_to_nitorch(bound, as_type='str'):
    """Convert boundary type to niTorch's convention.

    Parameters
    ----------
    bound : [list of] str or bound_like
        Boundary condition in any convention
    as_type : {'str', 'enum', 'int'}, default='str'
        Return BoundType or int rather than str

    Returns
    -------
    bound : [list of] str or BoundType
        Boundary condition in NITorch's convention

    """
    intype = type(bound)  # remember the input container type
    if not isinstance(bound, (list, tuple)):
        bound = [bound]
    obound = []
    for b in bound:
        b = b.lower() if isinstance(b, str) else b
        # Accept aliases from several conventions (scipy / PyTorch /
        # DCT-DST naming) and normalize to the niTorch name.
        if b in ('replicate', 'repeat', 'border', 'nearest', BoundType.replicate):
            obound.append('replicate')
        elif b in ('zero', 'zeros', 'constant', BoundType.zero):
            obound.append('zero')
        elif b in ('dct2', 'reflect', 'reflection', 'neumann', BoundType.dct2):
            obound.append('dct2')
        elif b in ('dct1', 'mirror', BoundType.dct1):
            obound.append('dct1')
        elif b in ('dft', 'wrap', 'circular', BoundType.dft):
            obound.append('dft')
        elif b in ('dst2', 'antireflect', 'dirichlet', BoundType.dst2):
            obound.append('dst2')
        elif b in ('dst1', 'antimirror', BoundType.dst1):
            obound.append('dst1')
        elif isinstance(b, int):
            # raw ints are assumed to already follow BoundType's values
            obound.append(b)
        else:
            raise ValueError(f'Unknown boundary condition {b}')
    # normalize everything to BoundType, then to the requested form
    obound = list(map(lambda b: getattr(BoundType, b) if isinstance(b, str)
                      else BoundType(b), obound))
    if as_type in ('int', int):
        obound = [b.value for b in obound]
    if as_type in ('str', str):
        obound = [b.name for b in obound]
    if issubclass(intype, (list, tuple)):
        obound = intype(obound)  # preserve the input container type
    else:
        obound = obound[0]  # scalar in -> scalar out
    return obound
75
+
76
+
77
def inter_to_nitorch(inter, as_type='str'):
    """Convert interpolation order to NITorch's convention.

    Parameters
    ----------
    inter : [sequence of] int or str or InterpolationType
    as_type : {'str', 'enum', 'int'}, default='str'

    Returns
    -------
    inter : [sequence of] int or str or InterpolationType

    """
    intype = type(inter)  # remember the input container type
    if not isinstance(inter, (list, tuple)):
        inter = [inter]
    ointer = []
    for o in inter:
        o = o.lower() if isinstance(o, str) else o
        # map each order given as int, name or enum to its integer order
        if o in (0, 'nearest', InterpolationType.nearest):
            ointer.append(0)
        elif o in (1, 'linear', InterpolationType.linear):
            ointer.append(1)
        elif o in (2, 'quadratic', InterpolationType.quadratic):
            ointer.append(2)
        elif o in (3, 'cubic', InterpolationType.cubic):
            ointer.append(3)
        elif o in (4, 'fourth', InterpolationType.fourth):
            ointer.append(4)
        elif o in (5, 'fifth', InterpolationType.fifth):
            ointer.append(5)
        elif o in (6, 'sixth', InterpolationType.sixth):
            ointer.append(6)
        elif o in (7, 'seventh', InterpolationType.seventh):
            ointer.append(7)
        else:
            raise ValueError(f'Unknown interpolation order {o}')
    if as_type in ('enum', 'str', str):
        ointer = list(map(InterpolationType, ointer))
    if as_type in ('str', str):
        ointer = [o.name for o in ointer]
    if issubclass(intype, (list, tuple)):
        ointer = intype(ointer)  # preserve the input container type
    else:
        ointer = ointer[0]  # scalar in -> scalar out
    return ointer
123
+
124
+
125
class GridPull(torch.autograd.Function):
    """Autograd wrapper around `grid_pull` (grid sampling)."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, grid, interpolation, bound, extrapolate):
        # normalize options to the integer codes used by the backend
        bound = bound_to_nitorch(make_list(bound), as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
        extrapolate = int(extrapolate)
        opt = (bound, interpolation, extrapolate)

        # Pull
        output = grid_pull(input, grid, *opt)

        # Context
        ctx.opt = opt
        ctx.save_for_backward(input, grid)

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        var = ctx.saved_tensors
        opt = ctx.opt
        grads = grid_pull_backward(grad, *var, *opt)
        grad_input, grad_grid = grads
        # one gradient per forward input; the option args get None
        return grad_input, grad_grid, None, None, None
153
+
154
+
155
class GridPush(torch.autograd.Function):
    """Autograd wrapper around `grid_push` (splatting onto a grid)."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
        # normalize options to the integer codes used by the backend
        bound = bound_to_nitorch(make_list(bound), as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
        extrapolate = int(extrapolate)
        opt = (bound, interpolation, extrapolate)

        # Push
        output = grid_push(input, grid, shape, *opt)

        # Context
        ctx.opt = opt
        ctx.save_for_backward(input, grid)

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        var = ctx.saved_tensors
        opt = ctx.opt
        grads = grid_push_backward(grad, *var, *opt)
        grad_input, grad_grid = grads
        # one gradient per forward input; shape and options get None
        return grad_input, grad_grid, None, None, None, None
183
+
184
+
185
class GridCount(torch.autograd.Function):
    """Autograd wrapper around `grid_count` (splatting a field of ones)."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, grid, shape, interpolation, bound, extrapolate):
        # normalize options to the integer codes used by the backend
        bound = bound_to_nitorch(make_list(bound), as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
        extrapolate = int(extrapolate)
        opt = (bound, interpolation, extrapolate)

        # Push
        output = grid_count(grid, shape, *opt)

        # Context
        ctx.opt = opt
        ctx.save_for_backward(grid)

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        var = ctx.saved_tensors
        opt = ctx.opt
        grad_grid = None
        # only compute the (expensive) grid gradient when it is needed
        if ctx.needs_input_grad[0]:
            grad_grid = grid_count_backward(grad, *var, *opt)
        return grad_grid, None, None, None, None
214
+
215
+
216
class GridGrad(torch.autograd.Function):
    """Autograd wrapper around `grid_grad` (sampling of spatial gradients)."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, grid, interpolation, bound, extrapolate):
        # normalize options to the integer codes used by the backend
        bound = bound_to_nitorch(make_list(bound), as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
        extrapolate = int(extrapolate)
        opt = (bound, interpolation, extrapolate)

        # Pull
        output = grid_grad(input, grid, *opt)

        # Context
        ctx.opt = opt
        ctx.save_for_backward(input, grid)

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        var = ctx.saved_tensors
        opt = ctx.opt
        grad_input = grad_grid = None
        # skip the backward kernel entirely when no input needs a gradient
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            grads = grid_grad_backward(grad, *var, *opt)
            grad_input, grad_grid = grads
        return grad_input, grad_grid, None, None, None
246
+
247
+
248
class SplineCoeff(torch.autograd.Function):
    """Autograd wrapper around `spline_coeff` (1D spline prefilter).

    The prefilter is a symmetric all-pole filter, so the backward pass
    applies the exact same filter to the incoming gradient.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, bound, interpolation, dim, inplace):
        # normalize options to the integer codes used by the backend
        bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation)[0], as_type='int')
        opt = (bound, interpolation, dim, inplace)

        # Filter
        output = spline_coeff(input, *opt)

        # Context (options are only needed if a backward pass will run)
        if input.requires_grad:
            ctx.opt = opt

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        # symmetric filter -> backward == forward
        # (I don't know if I can write into grad, so inplace=False to be safe)
        grad = spline_coeff(grad, *ctx.opt[:-1], inplace=False)
        # BUG FIX: must return a *tuple* with one entry per forward input;
        # the original returned a list, which autograd treats as a single
        # output ("expected 5 gradients, got 1"). Matches SplineCoeffND.
        return grad, None, None, None, None
274
+
275
+
276
class SplineCoeffND(torch.autograd.Function):
    """Autograd wrapper around `spline_coeff_nd` (N-D spline prefilter).

    The prefilter is a symmetric all-pole filter, so the backward pass
    applies the exact same filter to the incoming gradient.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, bound, interpolation, dim, inplace):
        # normalize options to the integer codes used by the backend
        bound = bound_to_nitorch(make_list(bound), as_type='int')
        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
        opt = (bound, interpolation, dim, inplace)

        # Filter
        output = spline_coeff_nd(input, *opt)

        # Context (options are only needed if a backward pass will run)
        if input.requires_grad:
            ctx.opt = opt

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        # symmetric filter -> backward == forward
        # (I don't know if I can write into grad, so inplace=False to be safe)
        grad = spline_coeff_nd(grad, *ctx.opt[:-1], inplace=False)
        return grad, None, None, None, None
Generator/interpol/backend.py ADDED
@@ -0,0 +1 @@
 
 
1
# NOTE(review): presumably toggles the optional `jitfields` backend over the
# pure TorchScript one -- confirm against the modules that read this flag.
jitfields = False  # Whether to use jitfields if available
Generator/interpol/bounds.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from enum import Enum
3
+ from typing import Optional
4
+ from .jit_utils import floor_div
5
+ Tensor = torch.Tensor
6
+
7
+
8
class BoundType(Enum):
    """Boundary conditions, with aliases (e.g. `zero` == `zeros`).

    The integer values are the codes consumed by the TorchScript `Bound`
    class in this module.
    """
    zero = zeros = 0
    replicate = nearest = 1
    dct1 = mirror = 2
    dct2 = reflect = 3
    dst1 = antimirror = 4
    dst2 = antireflect = 5
    dft = wrap = 6
16
+
17
+
18
class ExtrapolateType(Enum):
    """How samples that fall outside the field of view are handled."""
    no = 0     # threshold: (0, n-1)
    yes = 1
    hist = 2   # threshold: (-0.5, n-0.5)
22
+
23
+
24
@torch.jit.script
class Bound:
    """Fold out-of-bounds integer indices back into [0, n).

    `bound_type` uses the integer codes of `BoundType` (0=zero,
    1=replicate, 2=dct1, 3=dct2, 4=dst1, 5=dst2, 6=dft).
    """

    def __init__(self, bound_type: int = 3):
        self.type = bound_type

    def index(self, i, n: int):
        """Map integer coordinates `i` onto valid indices in [0, n-1]."""
        if self.type in (0, 1): # zero / replicate
            return i.clamp(min=0, max=n-1)
        elif self.type in (3, 5): # dct2 / dst2
            # reflection that repeats the edge samples (period 2n)
            n2 = n * 2
            i = torch.where(i < 0, (-i-1).remainder(n2).neg().add(n2 - 1),
                            i.remainder(n2))
            i = torch.where(i >= n, -i + (n2 - 1), i)
            return i
        elif self.type == 2: # dct1
            # mirror about the edge samples (period 2(n-1))
            if n == 1:
                return torch.zeros(i.shape, dtype=i.dtype, device=i.device)
            else:
                n2 = (n - 1) * 2
                i = i.abs().remainder(n2)
                i = torch.where(i >= n, -i + n2, i)
                return i
        elif self.type == 4: # dst1
            # antimirror with implicit zero samples at -1 and n
            # (period 2(n+1)); the -1/n cases are patched afterwards
            n2 = 2 * (n + 1)
            first = torch.zeros([1], dtype=i.dtype, device=i.device)
            last = torch.full([1], n - 1, dtype=i.dtype, device=i.device)
            i = torch.where(i < 0, -i - 2, i)
            i = i.remainder(n2)
            i = torch.where(i > n, -i + (n2 - 2), i)
            i = torch.where(i == -1, first, i)
            i = torch.where(i == n, last, i)
            return i
        elif self.type == 6: # dft
            # periodic wrap
            return i.remainder(n)
        else:
            return i

    def transform(self, i, n: int) -> Optional[Tensor]:
        """Sign (int8: +1/-1/0) to apply to the value sampled at `i`.

        Returns None for boundary types that never flip or zero the
        sampled value (replicate, dct1, dct2, dft).
        """
        if self.type == 4: # dst1
            if n == 1:
                return None
            one = torch.ones([1], dtype=torch.int8, device=i.device)
            zero = torch.zeros([1], dtype=torch.int8, device=i.device)
            n2 = 2 * (n + 1)
            i = torch.where(i < 0, -i + (n-1), i)
            i = i.remainder(n2)
            # zero at the implicit node samples, alternate sign per period
            x = torch.where(i == 0, zero, one)
            x = torch.where(i.remainder(n + 1) == n, zero, x)
            i = floor_div(i, n+1)
            x = torch.where(torch.remainder(i, 2) > 0, -x, x)
            return x
        elif self.type == 5: # dst2
            # alternate sign on every reflection
            i = torch.where(i < 0, n - 1 - i, i)
            x = torch.ones([1], dtype=torch.int8, device=i.device)
            i = floor_div(i, n)
            x = torch.where(torch.remainder(i, 2) > 0, -x, x)
            return x
        elif self.type == 0: # zero
            # zero outside the field of view
            one = torch.ones([1], dtype=torch.int8, device=i.device)
            zero = torch.zeros([1], dtype=torch.int8, device=i.device)
            outbounds = ((i < 0) | (i >= n))
            x = torch.where(outbounds, zero, one)
            return x
        else:
            return None
Generator/interpol/coeff.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compute spline interpolating coefficients
2
+
3
+ These functions are ported from the C routines in SPM's bsplines.c
4
+ by John Ashburner, which are themselves ports from Philippe Thevenaz's
5
+ code. JA furthermore derived the initial conditions for the DFT ("wrap around")
6
+ boundary conditions.
7
+
8
+ Note that similar routines are available in scipy with boundary conditions
9
+ DCT1 ("mirror"), DCT2 ("reflect") and DFT ("wrap"); all derived by P. Thevenaz,
10
+ according to the comments. Our DCT2 boundary conditions are ported from
11
+ scipy.
12
+
13
+ Only boundary conditions DCT1, DCT2 and DFT are implemented.
14
+
15
+ References
16
+ ----------
17
+ ..[1] M. Unser, A. Aldroubi and M. Eden.
18
+ "B-Spline Signal Processing: Part I-Theory,"
19
+ IEEE Transactions on Signal Processing 41(2):821-832 (1993).
20
+ ..[2] M. Unser, A. Aldroubi and M. Eden.
21
+ "B-Spline Signal Processing: Part II-Efficient Design and Applications,"
22
+ IEEE Transactions on Signal Processing 41(2):834-848 (1993).
23
+ ..[3] M. Unser.
24
+ "Splines: A Perfect Fit for Signal and Image Processing,"
25
+ IEEE Signal Processing Magazine 16(6):22-38 (1999).
26
+ """
27
+ import torch
28
+ import math
29
+ from typing import List, Optional
30
+ from .jit_utils import movedim1
31
+ from .pushpull import pad_list_int
32
+
33
+
34
@torch.jit.script
def get_poles(order: int) -> List[float]:
    """Return the poles of the recursive prefilter for a B-spline order.

    Pole values follow Unser et al. (see the module docstring); orders
    0 and 1 are already interpolating and have no poles.
    """
    empty: List[float] = []
    if order in (0, 1):
        return empty
    if order == 2:
        return [math.sqrt(8.) - 3.]
    if order == 3:
        return [math.sqrt(3.) - 2.]
    if order == 4:
        return [math.sqrt(664. - math.sqrt(438976.)) + math.sqrt(304.) - 19.,
                math.sqrt(664. + math.sqrt(438976.)) - math.sqrt(304.) - 19.]
    if order == 5:
        return [math.sqrt(67.5 - math.sqrt(4436.25)) + math.sqrt(26.25) - 6.5,
                math.sqrt(67.5 + math.sqrt(4436.25)) - math.sqrt(26.25) - 6.5]
    if order == 6:
        return [-0.488294589303044755130118038883789062112279161239377608394,
                -0.081679271076237512597937765737059080653379610398148178525368,
                -0.00141415180832581775108724397655859252786416905534669851652709]
    if order == 7:
        return [-0.5352804307964381655424037816816460718339231523426924148812,
                -0.122554615192326690515272264359357343605486549427295558490763,
                -0.0091486948096082769285930216516478534156925639545994482648003]
    raise NotImplementedError
58
+
59
+
60
@torch.jit.script
def get_gain(poles: List[float]) -> float:
    """Overall gain of the recursive prefilter with the given poles.

    Each pole `z` contributes a factor (1 - z) * (1 - 1/z); an empty
    pole list yields a gain of 1.
    """
    gain: float = 1.
    for z in poles:
        gain = gain * ((1. - z) * (1. - 1. / z))
    return gain
66
+
67
+
68
@torch.jit.script
def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Initial condition of the causal filter pass under DFT (wrap)
    boundary conditions, along dimension `dim`.

    The recursion is truncated after `max_iter` terms, once the pole's
    contribution drops below exp(-30).
    """

    assert inp.shape[dim] > 1
    max_iter: int = int(math.ceil(-30./math.log(abs(pole))))
    max_iter = min(max_iter, inp.shape[dim])

    # powers of the pole, in increasing weight order (flipped)
    poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
    poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device))
    poles = poles.flip(0)

    inp = movedim1(inp, dim, 0)
    inp0 = inp[0]
    inp = inp[1-max_iter:]
    inp = movedim1(inp, 0, -1)
    # weighted sum of the trailing samples == dot product with pole powers
    out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
    out = out + inp0.unsqueeze(-1)
    if keepdim:
        out = movedim1(out, -1, dim)
    else:
        out = out.squeeze(-1)

    pole = pole ** max_iter
    out = out / (1 - pole)
    return out
93
+
94
+
95
@torch.jit.script
def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Initial condition of the causal filter pass under DCT1 (mirror)
    boundary conditions, along dimension `dim`."""

    n = inp.shape[dim]
    # number of terms until the pole's contribution drops below exp(-30)
    max_iter: int = int(math.ceil(-30./math.log(abs(pole))))

    if max_iter < n:
        # truncated recursion: remaining terms are negligible

        poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
        poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device))

        inp = movedim1(inp, dim, 0)
        inp0 = inp[0]
        inp = inp[1:max_iter]
        inp = movedim1(inp, 0, -1)
        # weighted sum of the leading samples == dot product
        out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
        out = out + inp0.unsqueeze(-1)
        if keepdim:
            out = movedim1(out, -1, dim)
        else:
            out = out.squeeze(-1)

    else:
        # exact variant over the full mirror period.
        # NOTE(review): this branch indexes `inp[0]`/`inp[-1]` along dim 0
        # directly; callers pass dim=0 (see `filter`) -- confirm before
        # using with other dims.
        max_iter = n

        polen = pole ** (n - 1)
        inp0 = inp[0] + polen * inp[-1]
        inp = inp[1:-1]
        inp = movedim1(inp, 0, -1)

        poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
        poles = poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device))
        poles = poles + (polen * polen) / poles

        out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
        out = out + inp0.unsqueeze(-1)
        if keepdim:
            out = movedim1(out, -1, dim)
        else:
            out = out.squeeze(-1)

    pole = pole ** (max_iter - 1)
    out = out / (1 - pole * pole)

    return out
140
+
141
+
142
@torch.jit.script
def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Initial condition of the causal filter pass under DCT2 (reflect)
    boundary conditions, along dimension `dim`."""
    # Ported from scipy:
    # https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c
    #
    # I (YB) unwarped and simplied the terms so that I could use a dot
    # product instead of a loop.
    # It should certainly be possible to derive a version for max_iter < n,
    # as JA did for DCT1, to avoid long recursions when `n` is large. But
    # I think it would require a more complicated anticausal/final condition.
    #
    # NOTE(review): samples are indexed along dim 0 directly; callers pass
    # dim=0 (see `filter`) -- confirm before using with other dims.

    n = inp.shape[dim]

    polen = pole ** n
    pole_last = polen * (1 + 1/(pole + polen * polen))
    inp00 = inp[0]
    inp0 = inp[0] + pole_last * inp[-1]
    inp = inp[1:-1]
    inp = movedim1(inp, 0, -1)

    # combined forward + reflected pole powers over the full period
    poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
    poles = (poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device)) +
             poles.pow(torch.arange(2*n-2, n, -1, dtype=inp.dtype, device=inp.device)))

    out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)

    out = out + inp0.unsqueeze(-1)
    out = out * (pole / (1 - polen * polen))
    out = out + inp00.unsqueeze(-1)

    if keepdim:
        out = movedim1(out, -1, dim)
    else:
        out = out.squeeze(-1)

    return out
178
+
179
+
180
@torch.jit.script
def dft_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Final condition of the anticausal filter pass under DFT (wrap)
    boundary conditions, along dimension `dim`.

    The recursion is truncated after `max_iter` terms, once the pole's
    contribution drops below exp(-30).
    """

    assert inp.shape[dim] > 1
    max_iter: int = int(math.ceil(-30./math.log(abs(pole))))
    max_iter = min(max_iter, inp.shape[dim])

    poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
    poles = poles.pow(torch.arange(2, max_iter+1, dtype=inp.dtype, device=inp.device))

    inp = movedim1(inp, dim, 0)
    inp0 = inp[-1]
    inp = inp[:max_iter-1]
    inp = movedim1(inp, 0, -1)
    # weighted sum of the leading samples == dot product with pole powers
    out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
    out = out.add(inp0.unsqueeze(-1), alpha=pole)
    if keepdim:
        out = movedim1(out, -1, dim)
    else:
        out = out.squeeze(-1)

    pole = pole ** max_iter
    out = out / (pole - 1)
    return out
204
+
205
+
206
@torch.jit.script
def dct1_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Final condition of the anticausal filter pass under DCT1 (mirror)
    boundary conditions: closed form using the last two samples."""
    inp = movedim1(inp, dim, 0)
    out = pole * inp[-2] + inp[-1]
    out = out * (pole / (pole*pole - 1))
    if keepdim:
        out = movedim1(out.unsqueeze(0), 0, dim)
    return out
214
+
215
+
216
@torch.jit.script
def dct2_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
    """Final condition of the anticausal filter pass under DCT2 (reflect)
    boundary conditions: closed form using the last sample."""
    # Ported from scipy:
    # https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c
    inp = movedim1(inp, dim, 0)
    out = inp[-1] * (pole / (pole - 1))
    if keepdim:
        out = movedim1(out.unsqueeze(0), 0, dim)
    return out
225
+
226
+
227
@torch.jit.script
class CoeffBound:
    """Dispatch the filter's initial/final conditions for a boundary code.

    Only dct1-like (0, 2), dct2-like (1, 3) and dft (6) boundaries are
    supported by the prefilter.
    """

    def __init__(self, bound: int):
        self.bound = bound

    def initial(self, inp, pole: float, dim: int = -1, keepdim: bool = False):
        """Initial condition of the causal (forward) pass."""
        if self.bound in (0, 2): # zero, dct1
            return dct1_initial(inp, pole, dim, keepdim)
        elif self.bound in (1, 3): # nearest, dct2
            return dct2_initial(inp, pole, dim, keepdim)
        elif self.bound == 6: # dft
            return dft_initial(inp, pole, dim, keepdim)
        else:
            raise NotImplementedError

    def final(self, inp, pole: float, dim: int = -1, keepdim: bool = False):
        """Initial condition of the anticausal (backward) pass."""
        if self.bound in (0, 2): # zero, dct1
            return dct1_final(inp, pole, dim, keepdim)
        elif self.bound in (1, 3): # nearest, dct2
            return dct2_final(inp, pole, dim, keepdim)
        elif self.bound == 6: # dft
            return dft_final(inp, pole, dim, keepdim)
        else:
            raise NotImplementedError
252
+
253
+
254
@torch.jit.script
def filter(inp, bound: CoeffBound, poles: List[float],
           dim: int = -1, inplace: bool = False):
    """Apply the recursive prefilter along `dim`, one pole at a time.

    For each pole: a causal pass (first to last sample) followed by an
    anticausal pass (last to first), each seeded with boundary-specific
    initial conditions. NOTE: shadows the builtin `filter`.
    """

    if not inplace:
        inp = inp.clone()

    # nothing to filter along a singleton dimension
    if inp.shape[dim] == 1:
        return inp

    gain = get_gain(poles)
    inp *= gain
    inp = movedim1(inp, dim, 0)
    n = inp.shape[0]

    for pole in poles:
        inp[0] = bound.initial(inp, pole, dim=0, keepdim=False)

        # causal recursion: c[i] += pole * c[i-1]
        for i in range(1, n):
            inp[i].add_(inp[i-1], alpha=pole)

        inp[-1] = bound.final(inp, pole, dim=0, keepdim=False)

        # anticausal recursion: c[i] = pole * (c[i+1] - c[i])
        for i in range(n-2, -1, -1):
            inp[i].neg_().add_(inp[i+1]).mul_(pole)

    inp = movedim1(inp, 0, dim)
    return inp
282
+
283
+
284
@torch.jit.script
def spline_coeff(inp, bound: int, order: int, dim: int = -1,
                 inplace: bool = False):
    """Compute the interpolating spline coefficients, for a given spline order
    and boundary conditions, along a single dimension.

    Parameters
    ----------
    inp : tensor
    bound : int
        Boundary code: {0, 2: dct1-like, 1, 3: dct2-like, 6: dft};
        other codes raise NotImplementedError inside the filter.
    order : {0..7}
        Spline order; orders 0 and 1 are interpolating already (no-op).
    dim : int, default=-1
    inplace : bool, default=False

    Returns
    -------
    out : tensor

    """
    if not inplace:
        inp = inp.clone()

    # order 0/1 splines interpolate directly: nothing to do
    if order in (0, 1):
        return inp

    poles = get_poles(order)
    return filter(inp, CoeffBound(bound), poles, dim=dim, inplace=True)
311
+
312
+
313
@torch.jit.script
def spline_coeff_nd(inp, bound: List[int], order: List[int],
                    dim: Optional[int] = None, inplace: bool = False):
    """Compute the interpolating spline coefficients, for a given spline order
    and boundary condition, along the last `dim` dimensions.

    Parameters
    ----------
    inp : (..., *spatial) tensor
    bound : List[int]
        Per-dimension boundary code (padded to length `dim`):
        {0, 2: dct1-like, 1, 3: dct2-like, 6: dft}
    order : List[{0..7}]
        Per-dimension spline order (padded to length `dim`)
    dim : int, default=`inp.dim()`
    inplace : bool, default=False

    Returns
    -------
    out : (..., *spatial) tensor

    """
    if not inplace:
        inp = inp.clone()

    if dim is None:
        dim = inp.dim()

    # broadcast the per-dimension options to exactly `dim` entries
    bound = pad_list_int(bound, dim)
    order = pad_list_int(order, dim)

    # filter each of the last `dim` dimensions in turn (separable filter)
    for d, b, o in zip(range(dim), bound, order):
        inp = spline_coeff(inp, b, o, dim=-dim + d, inplace=True)

    return inp
Generator/interpol/iso0.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Isotropic 0-th order splines ("nearest neighbor")"""
2
+ import torch
3
+ from .bounds import Bound
4
+ from .jit_utils import (sub2ind_list, make_sign,
5
+ inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d)
6
+ from typing import List, Optional
7
+ Tensor = torch.Tensor
8
+
9
+
10
@torch.jit.script
def get_indices(g, n: int, bound: Bound):
    """Round coordinates to the nearest voxel and fold them into [0, n).

    Returns the folded integer indices and the optional sign map implied
    by the boundary condition (None when no sign flip/zeroing applies).
    """
    g0 = g.round().long()
    sign0 = bound.transform(g0, n)
    g0 = bound.index(g0, n)
    return g0, sign0
16
+
17
+
18
+ # ======================================================================
19
+ # 3D
20
+ # ======================================================================
21
+
22
+
23
@torch.jit.script
def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Nearest-neighbour sampling of a 3D volume at grid locations.

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, oX, oY, oZ, 3) tensor
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, oZ) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    oshape = g.shape[-dim-1:-1]
    # flatten the spatial dimensions: work on (B, 1, N, 3)
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = g.unbind(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = inp.shape[-dim:]
    nx, ny, nz = shape

    # mask of inbounds voxels
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # nearest integer coordinates
    gx, signx = get_indices(gx, nx, boundx)
    gy, signy = get_indices(gy, ny, boundy)
    gz, signz = get_indices(gz, nz, boundz)

    # gather via linear indices into the flattened volume
    inp = inp.reshape(inp.shape[:2] + [-1])
    idx = sub2ind_list([gx, gy, gz], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out = inp.gather(-1, idx)
    sign = make_sign([signx, signy, signz])
    if sign is not None:
        out *= sign
    if mask is not None:
        out *= mask
    out = out.reshape(out.shape[:2] + oshape)
    return out
62
+
63
+
64
@torch.jit.script
def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """Nearest-neighbour splatting of a 3D volume (adjoint of pull3d).

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, iX, iY, iZ, 3) tensor
    shape: List{3}[int], optional
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = inp.shape[-dim:]
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = torch.unbind(g, -1)
    inp = inp.reshape(inp.shape[:2] + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape  # default: push back onto the input's own grid
    nx, ny, nz = shape

    # mask of inbounds voxels
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # nearest integer coordinates
    gx, signx = get_indices(gx, nx, boundx)
    gy, signy = get_indices(gy, ny, boundy)
    gz, signz = get_indices(gz, nz, boundz)

    # scatter-add the (sign/mask weighted) values at their target voxels
    out = torch.zeros([batch, channel, nx*ny*nz], dtype=inp.dtype, device=inp.device)
    idx = sub2ind_list([gx, gy, gz], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    sign = make_sign([signx, signy, signz])
    if sign is not None or mask is not None:
        inp = inp.clone()  # do not modify the caller's tensor in-place
    if sign is not None:
        inp *= sign
    if mask is not None:
        inp *= mask
    out.scatter_add_(-1, idx, inp)

    out = out.reshape(out.shape[:2] + shape)
    return out
113
+
114
+
115
+ # ======================================================================
116
+ # 2D
117
+ # ======================================================================
118
+
119
+
120
@torch.jit.script
def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Nearest-neighbour sampling of a 2D image at grid locations.

    inp: (B, C, iX, iY) tensor
    g: (B, oX, oY, 2) tensor
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY) tensor
    """
    dim = 2
    boundx, boundy = bound
    oshape = g.shape[-dim-1:-1]
    # flatten the spatial dimensions: work on (B, 1, N, 2)
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = g.unbind(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = inp.shape[-dim:]
    nx, ny = shape

    # mask of inbounds voxels
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # nearest integer coordinates
    gx, signx = get_indices(gx, nx, boundx)
    gy, signy = get_indices(gy, ny, boundy)

    # gather via linear indices into the flattened image
    inp = inp.reshape(inp.shape[:2] + [-1])
    idx = sub2ind_list([gx, gy], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out = inp.gather(-1, idx)
    sign = make_sign([signx, signy])
    if sign is not None:
        out = out * sign
    if mask is not None:
        # BUG FIX: was `out = mask * mask`, which discarded the gathered
        # values and returned the (squared) mask. The 1d/3d versions
        # multiply the output by the mask, as done here.
        out = out * mask
    out = out.reshape(out.shape[:2] + oshape)
    return out
158
+
159
+
160
@torch.jit.script
def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """Nearest-neighbour splatting of a 2D image (adjoint of pull2d).

    inp: (B, C, iX, iY) tensor
    g: (B, iX, iY, 2) tensor
    shape: List{2}[int], optional
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 2
    boundx, boundy = bound
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = inp.shape[-dim:]
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = torch.unbind(g, -1)
    inp = inp.reshape(inp.shape[:2] + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape  # default: push back onto the input's own grid
    nx, ny = shape

    # mask of inbounds voxels
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # nearest integer coordinates
    gx, signx = get_indices(gx, nx, boundx)
    gy, signy = get_indices(gy, ny, boundy)

    # scatter-add the (sign/mask weighted) values at their target voxels
    out = torch.zeros([batch, channel, nx*ny], dtype=inp.dtype, device=inp.device)
    idx = sub2ind_list([gx, gy], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    sign = make_sign([signx, signy])
    if sign is not None or mask is not None:
        inp = inp.clone()  # do not modify the caller's tensor in-place
    if sign is not None:
        inp = inp * sign
    if mask is not None:
        inp = inp * mask
    out.scatter_add_(-1, idx, inp)

    out = out.reshape(out.shape[:2] + shape)
    return out
208
+
209
+
210
+ # ======================================================================
211
+ # 1D
212
+ # ======================================================================
213
+
214
+
215
@torch.jit.script
def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Sample a 1D signal at the (rounded) coordinates in `g`.

    NOTE(review): the rounding behaviour comes from `get_indices`, defined
    elsewhere in this file — confirm there whether it is floor or nearest.

    inp: (B, C, iX) tensor
    g: (B, oX, 1) tensor
    bound: List{1}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX) tensor
    """
    dim = 1
    boundx = bound[0]
    # list() wrappers keep shape arithmetic valid in eager mode too
    # (torch.Size + list raises outside TorchScript); matches iso1 style.
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx = shape[0]

    # mask of inbounds voxels (None when extrapolation is enabled)
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # nearest integer coordinates (+ boundary sign for flipping bounds)
    gx, signx = get_indices(gx, nx, boundx)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    idx = gx.expand([batch, channel, -1])
    out = inp.gather(-1, idx)
    if signx is not None:
        out = out * signx
    if mask is not None:
        out = out * mask
    out = out.reshape(list(out.shape[:2]) + oshape)
    return out
252
+
253
+
254
@torch.jit.script
def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """Splat a 1D signal onto a grid (adjoint of `pull1d`).

    inp: (B, C, iX) tensor
    g: (B, iX, 1) tensor
    shape: List{1}[int], optional
    bound: List{1}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 1
    boundx = bound[0]
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    # list() wrappers keep shape arithmetic valid in eager mode too
    # (torch.Size + list raises outside TorchScript); matches iso1 style.
    ishape = list(inp.shape[-dim:])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx = shape[0]

    # mask of inbounds voxels (None when extrapolation is enabled)
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # nearest integer coordinates (+ boundary sign for flipping bounds)
    gx, signx = get_indices(gx, nx, boundx)

    # scatter
    # NOTE: the original cloned `inp` before these multiplies; the clone was
    # redundant because every operation below is out-of-place.
    out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device)
    idx = gx.expand([batch, channel, -1])
    val = inp
    if signx is not None:
        val = val * signx
    if mask is not None:
        val = val * mask
    out.scatter_add_(-1, idx, val)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
301
+
302
+
303
+ # ======================================================================
304
+ # ND
305
+ # ======================================================================
306
+
307
+
308
@torch.jit.script
def grad(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Sample the spatial gradient of `inp` at the locations in `g`.

    For this interpolation order the interpolant is piecewise constant
    (presumably the 0th-order file, judging by the matching iso1 module),
    so the gradient vanishes almost everywhere and a zero tensor of the
    correct shape is returned; `bound` and `extrapolate` are unused.

    inp: (B, C, *ishape) tensor
    g: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *oshape, D) tensor
    """
    ndim = g.shape[-1]
    spatial = list(g.shape[-ndim-1:-1])
    nbatch = max(inp.shape[0], g.shape[0])
    nchannel = inp.shape[1]
    return torch.zeros([nbatch, nchannel] + spatial + [ndim],
                       dtype=inp.dtype, device=inp.device)
324
+
325
+
326
@torch.jit.script
def pushgrad(inp, g, shape: Optional[List[int]], bound: List[Bound],
             extrapolate: int = 1):
    """Adjoint of `grad` for this interpolation order.

    Since the corresponding `grad` is identically zero, the adjoint is the
    zero map as well: after validating shapes, a zero tensor is returned.

    inp: (B, C, *ishape, D) tensor
    g: (B, *ishape, D) tensor
    shape: List{D}[int], optional
    bound: List{D}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    ndim = g.shape[-1]
    if inp.shape[-ndim-1:-1] != g.shape[-ndim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = inp.shape[-ndim-1:-1]
    nbatch = max(inp.shape[0], g.shape[0])
    nchannel = inp.shape[1]

    if shape is None:
        shape = ishape
    out_shape = list(shape)

    return torch.zeros([nbatch, nchannel] + out_shape,
                       dtype=inp.dtype, device=inp.device)
350
+
351
+
352
@torch.jit.script
def hess(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Sample the Hessian of `inp` at the locations in `g`.

    The second derivatives of this interpolation order are identically
    zero, so only a zero tensor of the right shape is produced;
    `bound` and `extrapolate` are unused.

    inp: (B, C, *ishape) tensor
    g: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *oshape, D, D) tensor
    """
    ndim = g.shape[-1]
    spatial = list(g.shape[-ndim-1:-1])
    # reshape kept from the original implementation (it only validates that
    # the grid is well-formed; the result is otherwise unused)
    g = g.reshape([g.shape[0], 1, -1, ndim])
    nbatch = max(inp.shape[0], g.shape[0])
    nchannel = inp.shape[1]
    return torch.zeros([nbatch, nchannel] + spatial + [ndim, ndim],
                       dtype=inp.dtype, device=inp.device)
Generator/interpol/iso1.py ADDED
@@ -0,0 +1,1339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Isotropic 1-st order splines ("linear/bilinear/trilinear")"""
2
+ import torch
3
+ from .bounds import Bound
4
+ from .jit_utils import (sub2ind_list, make_sign,
5
+ inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d)
6
+ from typing import List, Tuple, Optional
7
+ Tensor = torch.Tensor
8
+
9
+
10
@torch.jit.script
def get_weights_and_indices(g, n: int, bound: Bound) \
        -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    """Compute linear-interpolation corners for one spatial dimension.

    Parameters
    ----------
    g : tensor of (possibly fractional) coordinates along this dimension
    n : size of the input along this dimension
    bound : boundary-condition object (its `transform`/`index` methods map
        out-of-bounds indices back in bounds and return an optional sign)

    Returns
    -------
    (upper weight, lower corner, upper corner, lower sign, upper sign)
    """
    # floor is computed once and reused (the original called g.floor() twice)
    g_floor = g.floor()
    g0 = g_floor.long()      # lower corner
    g1 = g0 + 1              # upper corner
    sign1 = bound.transform(g1, n)
    sign0 = bound.transform(g0, n)
    g1 = bound.index(g1, n)
    g0 = bound.index(g0, n)
    g = g - g_floor          # fractional part = weight of the upper corner
    return g, g0, g1, sign0, sign1
21
+
22
+
23
+ # ======================================================================
24
+ # 3D
25
+ # ======================================================================
26
+
27
+
28
@torch.jit.script
def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Trilinear sampling of a 3D volume at the coordinates in `g`.

    The value at each sampling point is the weighted sum of the 8 integer
    corners surrounding it.

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, oX, oY, oZ, 3) tensor
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, oZ) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = g.unbind(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny, nz = shape

    # mask of inbounds voxels (None when extrapolation is enabled)
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # per-dimension corners:
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # flatten spatial dimensions, then gather and accumulate the 8 corners
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    # corner (0,0,0)
    idx = sub2ind_list([gx0, gy0, gz0], shape).expand([batch, channel, -1])
    out = inp.gather(-1, idx)
    sign = make_sign([signx0, signy0, signz0])
    if sign is not None:
        out = out * sign
    out = out * ((1 - gx) * (1 - gy) * (1 - gz))
    # corner (0,0,1)
    idx = sub2ind_list([gx0, gy0, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy0, signz1])
    if sign is not None:
        val = val * sign
    out = out + val * ((1 - gx) * (1 - gy) * gz)
    # corner (0,1,0)
    idx = sub2ind_list([gx0, gy1, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz0])
    if sign is not None:
        val = val * sign
    out = out + val * ((1 - gx) * gy * (1 - gz))
    # corner (0,1,1)
    idx = sub2ind_list([gx0, gy1, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz1])
    if sign is not None:
        val = val * sign
    out = out + val * ((1 - gx) * gy * gz)
    # corner (1,0,0)
    idx = sub2ind_list([gx1, gy0, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz0])
    if sign is not None:
        val = val * sign
    out = out + val * (gx * (1 - gy) * (1 - gz))
    # corner (1,0,1)
    idx = sub2ind_list([gx1, gy0, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz1])
    if sign is not None:
        val = val * sign
    out = out + val * (gx * (1 - gy) * gz)
    # corner (1,1,0)
    idx = sub2ind_list([gx1, gy1, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz0])
    if sign is not None:
        val = val * sign
    out = out + val * (gx * gy * (1 - gz))
    # corner (1,1,1)
    idx = sub2ind_list([gx1, gy1, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz1])
    if sign is not None:
        val = val * sign
    out = out + val * (gx * gy * gz)

    if mask is not None:
        out *= mask
    out = out.reshape(list(out.shape[:2]) + oshape)
    return out
134
+
135
+
136
@torch.jit.script
def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """Trilinear splatting of a 3D volume (adjoint of `pull3d`).

    Each input value is distributed over the 8 integer corners surrounding
    its target coordinate, weighted by the trilinear weights.

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, iX, iY, iZ, 3) tensor
    shape: List{3}[int], optional
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim:])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = torch.unbind(g, -1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx, ny, nz = shape

    # mask of inbounds voxels (None when extrapolation is enabled)
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # per-dimension corners:
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # Apply the in-bounds mask once, up front, instead of once per corner.
    # The mask only contains zeros and ones, so this reordering is exact.
    # (The original also cloned `inp` at every corner; the clones were
    # redundant because all multiplications below are out-of-place.)
    if mask is not None:
        inp = inp * mask

    # scatter
    out = torch.zeros([batch, channel, nx*ny*nz],
                      dtype=inp.dtype, device=inp.device)
    # - corner 000
    idx = sub2ind_list([gx0, gy0, gz0], shape).expand([batch, channel, -1])
    sign = make_sign([signx0, signy0, signz0])
    val = inp
    if sign is not None:
        val = val * sign
    out.scatter_add_(-1, idx, val * ((1 - gx) * (1 - gy) * (1 - gz)))
    # - corner 001
    idx = sub2ind_list([gx0, gy0, gz1], shape).expand([batch, channel, -1])
    sign = make_sign([signx0, signy0, signz1])
    val = inp
    if sign is not None:
        val = val * sign
    out.scatter_add_(-1, idx, val * ((1 - gx) * (1 - gy) * gz))
    # - corner 010
    idx = sub2ind_list([gx0, gy1, gz0], shape).expand([batch, channel, -1])
    sign = make_sign([signx0, signy1, signz0])
    val = inp
    if sign is not None:
        val = val * sign
    out.scatter_add_(-1, idx, val * ((1 - gx) * gy * (1 - gz)))
    # - corner 011
    idx = sub2ind_list([gx0, gy1, gz1], shape).expand([batch, channel, -1])
    sign = make_sign([signx0, signy1, signz1])
    val = inp
    if sign is not None:
        val = val * sign
    out.scatter_add_(-1, idx, val * ((1 - gx) * gy * gz))
    # - corner 100
    idx = sub2ind_list([gx1, gy0, gz0], shape).expand([batch, channel, -1])
    sign = make_sign([signx1, signy0, signz0])
    val = inp
    if sign is not None:
        val = val * sign
    out.scatter_add_(-1, idx, val * (gx * (1 - gy) * (1 - gz)))
    # - corner 101
    idx = sub2ind_list([gx1, gy0, gz1], shape).expand([batch, channel, -1])
    sign = make_sign([signx1, signy0, signz1])
    val = inp
    if sign is not None:
        val = val * sign
    out.scatter_add_(-1, idx, val * (gx * (1 - gy) * gz))
    # - corner 110
    idx = sub2ind_list([gx1, gy1, gz0], shape).expand([batch, channel, -1])
    sign = make_sign([signx1, signy1, signz0])
    val = inp
    if sign is not None:
        val = val * sign
    out.scatter_add_(-1, idx, val * (gx * gy * (1 - gz)))
    # - corner 111
    idx = sub2ind_list([gx1, gy1, gz1], shape).expand([batch, channel, -1])
    sign = make_sign([signx1, signy1, signz1])
    val = inp
    if sign is not None:
        val = val * sign
    out.scatter_add_(-1, idx, val * (gx * gy * gz))

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
266
+
267
+
268
@torch.jit.script
def grad3d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Spatial gradient of the trilinear interpolant, sampled at `g`.

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, oX, oY, oZ, 3) tensor
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, oZ, 3) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = torch.unbind(g, -1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny, nz = shape

    # mask of inbounds voxels (None when extrapolation is enabled)
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # per-dimension corners:
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # gather; outx/outy/outz are views into `out`, so the in-place writes
    # below fill the output tensor component by component
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel] + list(g.shape[-2:]),
                      dtype=inp.dtype, device=inp.device)
    outx, outy, outz = out.unbind(-1)
    # - corner 000 (initialises all three components)
    idx = sub2ind_list([gx0, gy0, gz0], shape).expand([batch, channel, -1])
    torch.gather(inp, -1, idx, out=outx)
    outy.copy_(outx)
    outz.copy_(outx)
    sign = make_sign([signx0, signy0, signz0])
    if sign is not None:
        out *= sign.unsqueeze(-1)
    outx *= - (1 - gy) * (1 - gz)
    outy *= - (1 - gx) * (1 - gz)
    outz *= - (1 - gx) * (1 - gy)
    # - corner 001
    idx = sub2ind_list([gx0, gy0, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy0, signz1])
    if sign is not None:
        val *= sign
    outx.addcmul_(val, - (1 - gy) * gz)
    outy.addcmul_(val, - (1 - gx) * gz)
    outz.addcmul_(val, (1 - gx) * (1 - gy))
    # - corner 010
    idx = sub2ind_list([gx0, gy1, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz0])
    if sign is not None:
        val *= sign
    outx.addcmul_(val, - gy * (1 - gz))
    outy.addcmul_(val, (1 - gx) * (1 - gz))
    outz.addcmul_(val, - (1 - gx) * gy)
    # - corner 011
    idx = sub2ind_list([gx0, gy1, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz1])
    if sign is not None:
        val *= sign
    outx.addcmul_(val, - gy * gz)
    outy.addcmul_(val, (1 - gx) * gz)
    outz.addcmul_(val, (1 - gx) * gy)
    # - corner 100
    idx = sub2ind_list([gx1, gy0, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz0])
    if sign is not None:
        val *= sign
    outx.addcmul_(val, (1 - gy) * (1 - gz))
    outy.addcmul_(val, - gx * (1 - gz))
    outz.addcmul_(val, - gx * (1 - gy))
    # - corner 101
    idx = sub2ind_list([gx1, gy0, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz1])
    if sign is not None:
        val *= sign
    outx.addcmul_(val, (1 - gy) * gz)
    outy.addcmul_(val, - gx * gz)
    outz.addcmul_(val, gx * (1 - gy))
    # - corner 110
    idx = sub2ind_list([gx1, gy1, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz0])
    if sign is not None:
        val *= sign
    outx.addcmul_(val, gy * (1 - gz))
    outy.addcmul_(val, gx * (1 - gz))
    outz.addcmul_(val, - gx * gy)
    # - corner 111
    idx = sub2ind_list([gx1, gy1, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz1])
    if sign is not None:
        val *= sign
    outx.addcmul_(val, gy * gz)
    outy.addcmul_(val, gx * gz)
    outz.addcmul_(val, gx * gy)

    if mask is not None:
        out *= mask.unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [3])
    return out
388
+
389
+
390
@torch.jit.script
def pushgrad3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
               extrapolate: int = 1):
    """Adjoint of `grad3d`: splat a gradient field back onto a grid.

    inp: (B, C, iX, iY, iZ, 3) tensor
    g: (B, iX, iY, iZ, 3) tensor
    shape: List{3}[int], optional
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = g.unbind(-1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1, dim])
    batch = max(inp.shape[0], g.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx, ny, nz = shape

    # mask of inbounds voxels (None when extrapolation is enabled)
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # per-dimension corners:
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # scatter; `val` is cloned at every corner because its x/y/z components
    # (views obtained through unbind) are modified in place
    out = torch.zeros([batch, channel, nx*ny*nz],
                      dtype=inp.dtype, device=inp.device)
    # - corner 000
    idx = sub2ind_list([gx0, gy0, gz0], shape).expand([batch, channel, -1])
    val = inp.clone()
    sign = make_sign([signx0, signy0, signz0])
    if sign is not None:
        val *= sign.unsqueeze(-1)
    if mask is not None:
        val *= mask.unsqueeze(-1)
    valx, valy, valz = val.unbind(-1)
    valx *= - (1 - gy) * (1 - gz)
    valy *= - (1 - gx) * (1 - gz)
    valz *= - (1 - gx) * (1 - gy)
    out.scatter_add_(-1, idx, valx + valy + valz)
    # - corner 001
    idx = sub2ind_list([gx0, gy0, gz1], shape).expand([batch, channel, -1])
    val = inp.clone()
    sign = make_sign([signx0, signy0, signz1])
    if sign is not None:
        val *= sign.unsqueeze(-1)
    if mask is not None:
        val *= mask.unsqueeze(-1)
    valx, valy, valz = val.unbind(-1)
    valx *= - (1 - gy) * gz
    valy *= - (1 - gx) * gz
    valz *= (1 - gx) * (1 - gy)
    out.scatter_add_(-1, idx, valx + valy + valz)
    # - corner 010
    idx = sub2ind_list([gx0, gy1, gz0], shape).expand([batch, channel, -1])
    val = inp.clone()
    sign = make_sign([signx0, signy1, signz0])
    if sign is not None:
        val *= sign.unsqueeze(-1)
    if mask is not None:
        val *= mask.unsqueeze(-1)
    valx, valy, valz = val.unbind(-1)
    valx *= - gy * (1 - gz)
    valy *= (1 - gx) * (1 - gz)
    valz *= - (1 - gx) * gy
    out.scatter_add_(-1, idx, valx + valy + valz)
    # - corner 011
    idx = sub2ind_list([gx0, gy1, gz1], shape).expand([batch, channel, -1])
    val = inp.clone()
    sign = make_sign([signx0, signy1, signz1])
    if sign is not None:
        val *= sign.unsqueeze(-1)
    if mask is not None:
        val *= mask.unsqueeze(-1)
    valx, valy, valz = val.unbind(-1)
    valx *= - gy * gz
    valy *= (1 - gx) * gz
    valz *= (1 - gx) * gy
    out.scatter_add_(-1, idx, valx + valy + valz)
    # - corner 100
    idx = sub2ind_list([gx1, gy0, gz0], shape).expand([batch, channel, -1])
    val = inp.clone()
    sign = make_sign([signx1, signy0, signz0])
    if sign is not None:
        val *= sign.unsqueeze(-1)
    if mask is not None:
        val *= mask.unsqueeze(-1)
    valx, valy, valz = val.unbind(-1)
    valx *= (1 - gy) * (1 - gz)
    valy *= - gx * (1 - gz)
    valz *= - gx * (1 - gy)
    out.scatter_add_(-1, idx, valx + valy + valz)
    # - corner 101
    idx = sub2ind_list([gx1, gy0, gz1], shape).expand([batch, channel, -1])
    val = inp.clone()
    sign = make_sign([signx1, signy0, signz1])
    if sign is not None:
        val *= sign.unsqueeze(-1)
    if mask is not None:
        val *= mask.unsqueeze(-1)
    valx, valy, valz = val.unbind(-1)
    valx *= (1 - gy) * gz
    valy *= - gx * gz
    valz *= gx * (1 - gy)
    out.scatter_add_(-1, idx, valx + valy + valz)
    # - corner 110
    idx = sub2ind_list([gx1, gy1, gz0], shape).expand([batch, channel, -1])
    val = inp.clone()
    sign = make_sign([signx1, signy1, signz0])
    if sign is not None:
        val *= sign.unsqueeze(-1)
    if mask is not None:
        val *= mask.unsqueeze(-1)
    valx, valy, valz = val.unbind(-1)
    valx *= gy * (1 - gz)
    valy *= gx * (1 - gz)
    valz *= - gx * gy
    out.scatter_add_(-1, idx, valx + valy + valz)
    # - corner 111
    idx = sub2ind_list([gx1, gy1, gz1], shape).expand([batch, channel, -1])
    val = inp.clone()
    sign = make_sign([signx1, signy1, signz1])
    if sign is not None:
        val *= sign.unsqueeze(-1)
    if mask is not None:
        val *= mask.unsqueeze(-1)
    valx, valy, valz = val.unbind(-1)
    valx *= gy * gz
    valy *= gx * gz
    valz *= gx * gy
    out.scatter_add_(-1, idx, valx + valy + valz)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
544
+
545
+
546
@torch.jit.script
def hess3d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Hessian of the trilinear interpolant, sampled at `g`.

    For a trilinear interpolant the diagonal second derivatives are zero;
    only the mixed terms (xy, xz, yz) are accumulated, then mirrored into
    the lower triangle to produce a symmetric matrix.

    inp: (B, C, iX, iY, iZ) tensor
    g: (B, oX, oY, oZ, 3) tensor
    bound: List{3}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, oZ, 3, 3) tensor
    """
    dim = 3
    boundx, boundy, boundz = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy, gz = torch.unbind(g, -1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny, nz = shape

    # mask of inbounds voxels (None when extrapolation is enabled)
    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)

    # per-dimension corners:
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)

    # gather; every out?? name below is a view into `out`, so the in-place
    # writes fill the (dim, dim) matrix element by element
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel, g.shape[-2], dim, dim],
                      dtype=inp.dtype, device=inp.device)
    outx, outy, outz = out.unbind(-1)
    outxx, outyx, outzx = outx.unbind(-1)
    outxy, outyy, outzy = outy.unbind(-1)
    outxz, outyz, outzz = outz.unbind(-1)
    # - corner 000 (initialises the mixed terms, zeroes the diagonal)
    idx = sub2ind_list([gx0, gy0, gz0], shape).expand([batch, channel, -1])
    torch.gather(inp, -1, idx, out=outxy)
    outxz.copy_(outxy)
    outyz.copy_(outxy)
    outxx.zero_()
    outyy.zero_()
    outzz.zero_()
    sign = make_sign([signx0, signy0, signz0])
    if sign is not None:
        out *= sign.unsqueeze(-1).unsqueeze(-1)
    outxy *= (1 - gz)
    outxz *= (1 - gy)
    outyz *= (1 - gx)
    # - corner 001
    idx = sub2ind_list([gx0, gy0, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy0, signz1])
    if sign is not None:
        val *= sign
    outxy.addcmul_(val, gz)
    outxz.addcmul_(val, - (1 - gy))
    outyz.addcmul_(val, - (1 - gx))
    # - corner 010
    idx = sub2ind_list([gx0, gy1, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz0])
    if sign is not None:
        val *= sign
    outxy.addcmul_(val, - (1 - gz))
    outxz.addcmul_(val, gy)
    outyz.addcmul_(val, - (1 - gx))
    # - corner 011
    idx = sub2ind_list([gx0, gy1, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1, signz1])
    if sign is not None:
        val *= sign
    outxy.addcmul_(val, - gz)
    outxz.addcmul_(val, - gy)
    outyz.addcmul_(val, (1 - gx))
    # - corner 100
    idx = sub2ind_list([gx1, gy0, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz0])
    if sign is not None:
        val *= sign
    outxy.addcmul_(val, - (1 - gz))
    outxz.addcmul_(val, - (1 - gy))
    outyz.addcmul_(val, gx)
    # - corner 101
    idx = sub2ind_list([gx1, gy0, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0, signz1])
    if sign is not None:
        val *= sign
    outxy.addcmul_(val, - gz)
    outxz.addcmul_(val, (1 - gy))
    outyz.addcmul_(val, - gx)
    # - corner 110
    idx = sub2ind_list([gx1, gy1, gz0], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz0])
    if sign is not None:
        val *= sign
    outxy.addcmul_(val, (1 - gz))
    outxz.addcmul_(val, - gy)
    outyz.addcmul_(val, - gx)
    # - corner 111
    idx = sub2ind_list([gx1, gy1, gz1], shape).expand([batch, channel, -1])
    val = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1, signz1])
    if sign is not None:
        val *= sign
    outxy.addcmul_(val, gz)
    outxz.addcmul_(val, gy)
    outyz.addcmul_(val, gx)

    # mirror the upper triangle into the lower one (symmetric Hessian)
    outyx.copy_(outxy)
    outzx.copy_(outxz)
    outzy.copy_(outyz)

    if mask is not None:
        out *= mask.unsqueeze(-1).unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim])
    return out
676
+
677
+
678
+ # ======================================================================
679
+ # 2D
680
+ # ======================================================================
681
+
682
+
683
@torch.jit.script
def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Bilinear (order-1) sampling of a 2D image at the voxel coordinates `g`.

    inp: (B, C, iX, iY) tensor
    g: (B, oX, oY, 2) tensor of voxel coordinates
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType (0 = no, 1 = yes, 2 = hist)
    returns: (B, C, oX, oY) tensor
    """
    dim = 2
    boundx, boundy = bound
    oshape = list(g.shape[-dim-1:-1])
    # flatten spatial dims: g -> (B, 1, N, 2), one coordinate pair per point
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = g.unbind(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny = shape

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # gather: accumulate the 4 corners weighted by the bilinear weights
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    # - corner 00
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out = inp.gather(-1, idx)
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out = out * sign
    out = out * ((1 - gx) * (1 - gy))
    # - corner 01
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 = out1 * sign
    out1 = out1 * ((1 - gx) * gy)
    out = out + out1
    # - corner 10
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 = out1 * sign
    out1 = out1 * (gx * (1 - gy))
    out = out + out1
    # - corner 11
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 = out1 * sign
    out1 = out1 * (gx * gy)
    out = out + out1

    if mask is not None:
        # zero-out samples taken outside the field of view
        out *= mask
    out = out.reshape(list(out.shape[:2]) + oshape)
    return out
752
+
753
+
754
@torch.jit.script
def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """Splat (adjoint of `pull2d`) a 2D image onto a grid of size `shape`.

    inp: (B, C, iX, iY) tensor
    g: (B, iX, iY, 2) tensor of voxel coordinates
    shape: List{2}[int], optional (defaults to the input spatial shape)
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 2
    boundx, boundy = bound
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim:])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = torch.unbind(g, -1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx, ny = shape

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # scatter: each sample is distributed to its 4 neighbouring voxels
    out = torch.zeros([batch, channel, nx*ny],
                      dtype=inp.dtype, device=inp.device)
    # - corner 00
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out1 *= sign
    if mask is not None:
        out1 *= mask
    out1 *= (1 - gx) * (1 - gy)
    out.scatter_add_(-1, idx, out1)
    # - corner 01
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 *= sign
    if mask is not None:
        out1 *= mask
    out1 *= (1 - gx) * gy
    out.scatter_add_(-1, idx, out1)
    # - corner 10
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 *= sign
    if mask is not None:
        out1 *= mask
    out1 *= gx * (1 - gy)
    out.scatter_add_(-1, idx, out1)
    # - corner 11
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 *= sign
    if mask is not None:
        out1 *= mask
    out1 *= gx * gy
    out.scatter_add_(-1, idx, out1)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
839
+
840
+
841
@torch.jit.script
def grad2d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Spatial gradient of a bilinearly-interpolated 2D image at points `g`.

    The weights are the derivatives of the bilinear weights, e.g.
    d/dx[(1-gx)(1-gy)] = -(1-gy) for corner 00.

    inp: (B, C, iX, iY) tensor
    g: (B, oX, oY, 2) tensor of voxel coordinates
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, 2) tensor
    """
    dim = 2
    boundx, boundy = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = torch.unbind(g, -1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny = shape

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel] + list(g.shape[-2:]),
                      dtype=inp.dtype, device=inp.device)
    # NOTE: outx/outy are views into `out`; the in-place ops below fill `out`.
    outx, outy = out.unbind(-1)
    # - corner 00
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    torch.gather(inp, -1, idx, out=outx)
    outy.copy_(outx)
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out *= sign.unsqueeze(-1)
    outx *= - (1 - gy)
    outy *= - (1 - gx)
    # - corner 01
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, - gy)
    outy.addcmul_(out1, (1 - gx))
    # - corner 10
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, (1 - gy))
    outy.addcmul_(out1, - gx)
    # - corner 11
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 *= sign
    outx.addcmul_(out1, gy)
    outy.addcmul_(out1, gx)

    if mask is not None:
        out *= mask.unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [dim])
    return out
915
+
916
+
917
@torch.jit.script
def pushgrad2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
               extrapolate: int = 1):
    """Adjoint of `grad2d`: splat a gradient field onto a grid.

    inp: (B, C, iX, iY, 2) tensor (last dim holds the x/y components)
    g: (B, iX, iY, 2) tensor of voxel coordinates
    shape: List{2}[int], optional (defaults to the input spatial shape)
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 2
    boundx, boundy = bound
    if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = g.unbind(-1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1, dim])
    batch = max(inp.shape[0], g.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx, ny = shape

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # scatter: each corner receives <inp, d(weight)/d(x,y)>
    out = torch.zeros([batch, channel, nx*ny],
                      dtype=inp.dtype, device=inp.device)
    # - corner 00
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    # out1x/out1y are views into out1; in-place scaling before summation
    out1x, out1y = out1.unbind(-1)
    out1x *= - (1 - gy)
    out1y *= - (1 - gx)
    out.scatter_add_(-1, idx, out1x + out1y)
    # - corner 01
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y = out1.unbind(-1)
    out1x *= - gy
    out1y *= (1 - gx)
    out.scatter_add_(-1, idx, out1x + out1y)
    # - corner 10
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y = out1.unbind(-1)
    out1x *= (1 - gy)
    out1y *= - gx
    out.scatter_add_(-1, idx, out1x + out1y)
    # - corner 11
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x, out1y = out1.unbind(-1)
    out1x *= gy
    out1y *= gx
    out.scatter_add_(-1, idx, out1x + out1y)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
1010
+
1011
+
1012
@torch.jit.script
def hess2d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Hessian of a bilinearly-interpolated 2D image at points `g`.

    For order-1 interpolation the pure second derivatives (xx, yy) are
    identically zero; only the mixed term (xy == yx) is non-zero, with
    corner weights +1/-1/-1/+1.

    inp: (B, C, iX, iY) tensor
    g: (B, oX, oY, 2) tensor of voxel coordinates
    bound: List{2}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, oY, 2, 2) tensor
    """
    dim = 2
    boundx, boundy = bound
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx, gy = torch.unbind(g, -1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx, ny = shape

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel, g.shape[-2], dim, dim],
                      dtype=inp.dtype, device=inp.device)
    # NOTE: all outٍ* names below are views into `out` (nested unbinds).
    outx, outy = out.unbind(-1)
    outxx, outyx = outx.unbind(-1)
    outxy, outyy = outy.unbind(-1)
    # - corner 00
    idx = sub2ind_list([gx0, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    torch.gather(inp, -1, idx, out=outxy)
    outxx.zero_()
    outyy.zero_()
    sign = make_sign([signx0, signy0])
    if sign is not None:
        out *= sign.unsqueeze(-1).unsqueeze(-1)
    outxy *= 1  # corner-00 weight is +1 (no-op, kept for symmetry)
    # - corner 01
    idx = sub2ind_list([gx0, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx0, signy1])
    if sign is not None:
        out1 *= sign
    outxy.add_(out1, alpha=-1)
    # - corner 10
    idx = sub2ind_list([gx1, gy0], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy0])
    if sign is not None:
        out1 *= sign
    outxy.add_(out1, alpha=-1)
    # - corner 11
    idx = sub2ind_list([gx1, gy1], shape)
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = make_sign([signx1, signy1])
    if sign is not None:
        out1 *= sign
    outxy.add_(out1)

    # the Hessian is symmetric: mirror xy into yx
    outyx.copy_(outxy)

    if mask is not None:
        out *= mask.unsqueeze(-1).unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim])
    return out
1087
+
1088
+
1089
+ # ======================================================================
1090
+ # 1D
1091
+ # ======================================================================
1092
+
1093
+
1094
@torch.jit.script
def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Linear (order-1) sampling of a 1D signal at the coordinates `g`.

    inp: (B, C, iX) tensor
    g: (B, oX, 1) tensor of voxel coordinates
    bound: List{1}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX) tensor
    """
    dim = 1
    boundx = bound[0]
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx = shape[0]

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)

    # gather: linear blend of the two neighbours
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    # - corner 0
    idx = gx0
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out = inp.gather(-1, idx)
    sign = signx0
    if sign is not None:
        out = out * sign
    out = out * (1 - gx)
    # - corner 1
    idx = gx1
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = signx1
    if sign is not None:
        out1 = out1 * sign
    out1 = out1 * gx
    out = out + out1

    if mask is not None:
        out *= mask
    out = out.reshape(list(out.shape[:2]) + oshape)
    return out
1144
+
1145
+
1146
@torch.jit.script
def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
           extrapolate: int = 1):
    """Splat (adjoint of `pull1d`) a 1D signal onto a grid of size `shape`.

    inp: (B, C, iX) tensor
    g: (B, iX, 1) tensor of voxel coordinates
    shape: List{1}[int], optional (defaults to the input spatial shape)
    bound: List{1}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 1
    boundx = bound[0]
    if inp.shape[-dim:] != g.shape[-dim-1:-1]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim:])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx = shape[0]

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)

    # scatter: each sample is distributed to its 2 neighbours
    out = torch.zeros([batch, channel, nx],
                      dtype=inp.dtype, device=inp.device)
    # - corner 0
    idx = gx0
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = signx0
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * (1 - gx)
    out.scatter_add_(-1, idx, out1)
    # - corner 1
    idx = gx1
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = signx1
    if sign is not None:
        out1 = out1 * sign
    if mask is not None:
        out1 = out1 * mask
    out1 = out1 * gx
    out.scatter_add_(-1, idx, out1)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
1208
+
1209
+
1210
@torch.jit.script
def grad1d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Spatial gradient of a linearly-interpolated 1D signal at points `g`.

    The gradient of a linear blend is simply (corner1 - corner0).

    inp: (B, C, iX) tensor
    g: (B, oX, 1) tensor of voxel coordinates
    bound: List{1}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, 1) tensor
    """
    dim = 1
    boundx = bound[0]
    oshape = list(g.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    batch = max(inp.shape[0], gx.shape[0])
    channel = inp.shape[1]
    shape = list(inp.shape[-dim:])
    nx = shape[0]

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)

    # gather
    inp = inp.reshape(list(inp.shape[:2]) + [-1])
    out = torch.empty([batch, channel] + list(g.shape[-2:]),
                      dtype=inp.dtype, device=inp.device)
    # NOTE: outx is a view into `out`; in-place ops below fill `out`.
    outx = out.squeeze(-1)
    # - corner 0
    idx = gx0
    idx = idx.expand([batch, channel, idx.shape[-1]])
    torch.gather(inp, -1, idx, out=outx)
    sign = signx0
    if sign is not None:
        out *= sign.unsqueeze(-1)
    outx.neg_()
    # - corner 1
    idx = gx1
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.gather(-1, idx)
    sign = signx1
    if sign is not None:
        out1 *= sign
    outx.add_(out1)

    if mask is not None:
        out *= mask.unsqueeze(-1)
    out = out.reshape(list(out.shape[:2]) + oshape + [dim])
    return out
1262
+
1263
+
1264
@torch.jit.script
def pushgrad1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
               extrapolate: int = 1):
    """Adjoint of `grad1d`: splat a 1D gradient field onto a grid.

    inp: (B, C, iX, 1) tensor
    g: (B, iX, 1) tensor of voxel coordinates
    shape: List{1}[int], optional (defaults to the input spatial shape)
    bound: List{1}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, *shape) tensor
    """
    dim = 1
    boundx = bound[0]
    if inp.shape[-2] != g.shape[-2]:
        raise ValueError('Input and grid should have the same spatial shape')
    ishape = list(inp.shape[-dim-1:-1])
    g = g.reshape([g.shape[0], 1, -1, dim])
    gx = g.squeeze(-1)
    inp = inp.reshape(list(inp.shape[:2]) + [-1, dim])
    batch = max(inp.shape[0], g.shape[0])
    channel = inp.shape[1]

    if shape is None:
        shape = ishape
    shape = list(shape)
    nx = shape[0]

    # mask of inbounds voxels (None when extrapolation is allowed)
    mask = inbounds_mask_1d(extrapolate, gx, nx)

    # corners
    # (upper weight, lower corner, upper corner, lower sign, upper sign)
    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)

    # scatter: corner 0 receives -inp, corner 1 receives +inp
    out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device)
    # - corner 0
    idx = gx0
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = signx0
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x = out1.squeeze(-1)
    out1x.neg_()
    out.scatter_add_(-1, idx, out1x)
    # - corner 1
    idx = gx1
    idx = idx.expand([batch, channel, idx.shape[-1]])
    out1 = inp.clone()
    sign = signx1
    if sign is not None:
        out1 *= sign.unsqueeze(-1)
    if mask is not None:
        out1 *= mask.unsqueeze(-1)
    out1x = out1.squeeze(-1)
    out.scatter_add_(-1, idx, out1x)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
1326
+
1327
+
1328
@torch.jit.script
def hess1d(inp, g, bound: List[Bound], extrapolate: int = 1):
    """Hessian of a linearly-interpolated 1D signal at points `g`.

    Linear interpolation is piecewise affine, so its second derivative is
    identically zero; `bound` and `extrapolate` are therefore unused.

    inp: (B, C, iX) tensor
    g: (B, oX, 1) tensor
    bound: List{1}[Bound] tensor
    extrapolate: ExtrapolateType
    returns: (B, C, oX, 1, 1) tensor of zeros
    """
    batch = max(inp.shape[0], g.shape[0])
    return torch.zeros([batch, inp.shape[1], g.shape[1], 1, 1],
                       dtype=inp.dtype, device=inp.device)
Generator/interpol/jit_utils.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A lot of utility functions for TorchScript"""
2
+ import torch
3
+ import os
4
+ from typing import List, Tuple, Optional
5
+ from .utils import torch_version
6
+ from torch import Tensor
7
+
8
+
9
@torch.jit.script
def pad_list_int(x: List[int], dim: int) -> List[int]:
    """Pad (by replicating the last element) or crop `x` to length `dim`."""
    n = len(x)
    if n < dim:
        x = x + x[-1:] * (dim - n)
    elif n > dim:
        x = x[:dim]
    return x
16
+
17
+
18
@torch.jit.script
def pad_list_float(x: List[float], dim: int) -> List[float]:
    """Pad (by replicating the last element) or crop `x` to length `dim`."""
    n = len(x)
    if n < dim:
        x = x + x[-1:] * (dim - n)
    elif n > dim:
        x = x[:dim]
    return x
25
+
26
+
27
@torch.jit.script
def pad_list_str(x: List[str], dim: int) -> List[str]:
    """Pad (by replicating the last element) or crop `x` to length `dim`."""
    n = len(x)
    if n < dim:
        x = x + x[-1:] * (dim - n)
    elif n > dim:
        x = x[:dim]
    return x
34
+
35
+
36
@torch.jit.script
def list_any(x: List[bool]) -> bool:
    """TorchScript equivalent of the builtin `any` over a list of bools."""
    result = False
    for flag in x:
        result = result or flag
    return result
42
+
43
+
44
@torch.jit.script
def list_all(x: List[bool]) -> bool:
    """TorchScript equivalent of the builtin `all` over a list of bools."""
    result = True
    for flag in x:
        result = result and flag
    return result
50
+
51
+
52
@torch.jit.script
def list_prod_int(x: List[int]) -> int:
    """Product of a list of ints (1 for an empty list)."""
    acc = 1
    for value in x:
        acc = acc * value
    return acc
60
+
61
+
62
@torch.jit.script
def list_sum_int(x: List[int]) -> int:
    """Sum of a list of ints (0 for an empty list).

    BUGFIX: the empty list previously returned 1 — a copy/paste of the
    multiplicative identity from ``list_prod_int``. The identity of a sum
    is 0.
    """
    if len(x) == 0:
        return 0
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 + x1
    return x0
70
+
71
+
72
@torch.jit.script
def list_prod_tensor(x: List[Tensor]) -> Tensor:
    """Elementwise product of a list of tensors (scalar 1 for an empty list)."""
    if len(x) == 0:
        no_shape: List[int] = []
        return torch.ones(no_shape)
    acc = x[0]
    for i in range(1, len(x)):
        acc = acc * x[i]
    return acc
81
+
82
+
83
@torch.jit.script
def list_sum_tensor(x: List[Tensor]) -> Tensor:
    """Elementwise sum of a list of tensors (scalar 0 for an empty list).

    BUGFIX: the empty list previously returned ``torch.ones([])`` — a
    copy/paste from ``list_prod_tensor``. The identity of a sum is 0.
    """
    if len(x) == 0:
        empty: List[int] = []
        return torch.zeros(empty)
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 + x1
    return x0
92
+
93
+
94
@torch.jit.script
def list_reverse_int(x: List[int]) -> List[int]:
    """Return a reversed copy of a list of ints."""
    n = len(x)
    out: List[int] = []
    for i in range(n):
        out.append(x[n - 1 - i])
    return out
99
+
100
+
101
@torch.jit.script
def list_cumprod_int(x: List[int], reverse: bool = False,
                     exclusive: bool = False) -> List[int]:
    """Cumulative product of a list of ints.

    reverse: accumulate from the right instead of the left.
    exclusive: start the accumulation at 1 and drop the full product.
    """
    out: List[int] = []
    if len(x) == 0:
        return out
    if reverse:
        # reverse the input (inlined to keep this function self-contained)
        flipped: List[int] = []
        for i in range(len(x)):
            flipped.append(x[len(x) - 1 - i])
        x = flipped
    running = 1 if exclusive else x[0]
    out.append(running)
    rest = x[:-1] if exclusive else x[1:]
    for value in rest:
        running = running * value
        out.append(running)
    if reverse:
        flipped2: List[int] = []
        for i in range(len(out)):
            flipped2.append(out[len(out) - 1 - i])
        out = flipped2
    return out
119
+
120
+
121
@torch.jit.script
def movedim1(x, source: int, destination: int):
    """Move a single dimension of `x` from `source` to `destination`."""
    ndim = x.dim()
    if source < 0:
        source = ndim + source
    if destination < 0:
        destination = ndim + destination
    perm: List[int] = []
    for d in range(ndim):
        if d != source:
            perm.append(d)
    perm.insert(destination, source)
    return x.permute(perm)
130
+
131
+
132
@torch.jit.script
def sub2ind(subs, shape: List[int]):
    """Convert sub indices (i, j, k) into linear indices.

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    subs : (D, ...) tensor
        List of sub-indices. The first dimension is the number of dimension.
        Each element should have the same number of elements and shape.
    shape : (D,) list[int]
        Size of each dimension. Its length should be the same as the
        first dimension of ``subs``.

    Returns
    -------
    ind : (...) tensor
        Linear indices
    """
    subs = subs.unbind(0)
    # start from the last (stride-1) dimension ...
    ind = subs[-1]
    subs = subs[:-1]
    # ... cloned, so the in-place accumulation below cannot alias the input
    ind = ind.clone()
    # strides of the leading dims: reverse cumprod of shape[1:]
    stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False)
    for i, s in zip(subs, stride):
        ind += i * s
    return ind
161
+
162
+
163
@torch.jit.script
def sub2ind_list(subs: List[Tensor], shape: List[int]):
    """Convert sub indices (i, j, k) into linear indices.

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    subs : (D,) list[tensor]
        List of sub-indices. The first dimension is the number of dimension.
        Each element should have the same number of elements and shape.
    shape : (D,) list[int]
        Size of each dimension. Its length should be the same as the
        first dimension of ``subs``.

    Returns
    -------
    ind : (...) tensor
        Linear indices
    """
    # start from the last (stride-1) dimension, cloned so the in-place
    # accumulation below cannot alias the caller's tensor
    ind = subs[-1]
    subs = subs[:-1]
    ind = ind.clone()
    # strides of the leading dims: reverse cumprod of shape[1:]
    stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False)
    for i, s in zip(subs, stride):
        ind += i * s
    return ind
191
+
192
# floor_divide returns wrong results for negative values, because it truncates
# instead of performing a proper floor. In recent version of pytorch, it is
# advised to use div(..., rounding_mode='trunc'|'floor') instead.
# Here, we only use floor_divide on positive values so we do not care.
# NOTE(review): `floor_div` is re-bound near the bottom of this module
# (to torch.div / torch.floor_divide), which shadows the scripted version
# defined here — confirm which binding is intended to win.
if torch_version('>=', [1, 8]):
    @torch.jit.script
    def floor_div(x, y) -> torch.Tensor:
        return torch.div(x, y, rounding_mode='floor')
    @torch.jit.script
    def floor_div_int(x, y: int) -> torch.Tensor:
        return torch.div(x, y, rounding_mode='floor')
else:
    @torch.jit.script
    def floor_div(x, y) -> torch.Tensor:
        return (x / y).floor_()
    @torch.jit.script
    def floor_div_int(x, y: int) -> torch.Tensor:
        return (x / y).floor_()
210
+
211
+
212
@torch.jit.script
def ind2sub(ind, shape: List[int]):
    """Convert linear indices into sub indices (i, j, k).

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    ind : tensor_like
        Linear indices
    shape : (D,) vector_like
        Size of each dimension.

    Returns
    -------
    subs : (D, ...) tensor
        Sub-indices.
    """
    # exclusive reverse cumprod == row-major strides: [H*W, W, 1]
    stride = list_cumprod_int(shape, reverse=True, exclusive=True)
    # NOTE: `[len(shape)] + ind.shape` only concatenates under TorchScript,
    # where Tensor.shape is a List[int] — this line would not run in eager mode.
    sub = ind.new_empty([len(shape)] + ind.shape)
    sub.copy_(ind)
    for d in range(len(shape)):
        if d > 0:
            # strip the contribution of the higher dimensions first
            sub[d] = torch.remainder(sub[d], stride[d-1])
        sub[d] = floor_div_int(sub[d], stride[d])
    return sub
239
+
240
+
241
@torch.jit.script
def inbounds_mask_3d(extrapolate: int, gx, gy, gz, nx: int, ny: int, nz: int) \
        -> Optional[Tensor]:
    """Boolean mask of in-bounds voxels, or None when extrapolation is on.

    extrapolate: 0 = no extrapolation, 2 = histogram mode (half-voxel slack).
    """
    mask: Optional[Tensor] = None
    if extrapolate == 0 or extrapolate == 2:
        tiny = 5e-2
        threshold = 0.5 + tiny if extrapolate == 2 else tiny
        mask = ((gx > -threshold) & (gx < nx - 1 + threshold) &
                (gy > -threshold) & (gy < ny - 1 + threshold) &
                (gz > -threshold) & (gz < nz - 1 + threshold))
    return mask
256
+
257
+
258
@torch.jit.script
def inbounds_mask_2d(extrapolate: int, gx, gy, nx: int, ny: int) \
        -> Optional[Tensor]:
    """Boolean mask of in-bounds voxels, or None when extrapolation is on.

    extrapolate: 0 = no extrapolation, 2 = histogram mode (half-voxel slack).
    """
    mask: Optional[Tensor] = None
    if extrapolate == 0 or extrapolate == 2:
        tiny = 5e-2
        threshold = 0.5 + tiny if extrapolate == 2 else tiny
        mask = ((gx > -threshold) & (gx < nx - 1 + threshold) &
                (gy > -threshold) & (gy < ny - 1 + threshold))
    return mask
272
+
273
+
274
@torch.jit.script
def inbounds_mask_1d(extrapolate: int, gx, nx: int) -> Optional[Tensor]:
    """Boolean mask of in-bounds voxels, or None when extrapolation is on.

    extrapolate: 0 = no extrapolation, 2 = histogram mode (half-voxel slack).
    """
    mask: Optional[Tensor] = None
    if extrapolate == 0 or extrapolate == 2:
        tiny = 5e-2
        threshold = 0.5 + tiny if extrapolate == 2 else tiny
        mask = (gx > -threshold) & (gx < nx - 1 + threshold)
    return mask
286
+
287
+
288
@torch.jit.script
def make_sign(sign: List[Optional[Tensor]]) -> Optional[Tensor]:
    """Combine per-dimension sign maps into one (product), or None if all absent."""
    present: List[Tensor] = []
    for s in sign:
        if s is not None:
            present.append(s)
    if len(present) == 0:
        return None
    acc = present[0]
    for i in range(1, len(present)):
        acc = acc * present[i]
    return acc
298
+
299
+
300
@torch.jit.script
def square(x):
    """Out-of-place square."""
    return x.mul(x)
303
+
304
+
305
@torch.jit.script
def square_(x):
    """In-place square: overwrites `x` and returns it."""
    return x.mul_(x)
308
+
309
+
310
@torch.jit.script
def cube(x):
    """Out-of-place cube."""
    return x.mul(x).mul(x)
313
+
314
+
315
@torch.jit.script
def cube_(x):
    """In-place cube: overwrites `x` and returns it."""
    return square_(x).mul_(x)
318
+
319
+
320
@torch.jit.script
def pow4(x):
    """Out-of-place fourth power, computed as (x*x)*(x*x)."""
    x2 = x.mul(x)
    return x2.mul(x2)
323
+
324
+
325
@torch.jit.script
def pow4_(x):
    """In-place fourth power: overwrites `x` and returns it."""
    return square_(square_(x))
328
+
329
+
330
@torch.jit.script
def pow5(x):
    """Out-of-place fifth power: x**4 * x."""
    x2 = x.mul(x)
    x4 = x2.mul(x2)
    return x4.mul(x)
333
+
334
+
335
@torch.jit.script
def pow5_(x):
    """In-place fifth power: overwrites `x` and returns it."""
    return pow4_(x).mul_(x)
338
+
339
+
340
@torch.jit.script
def pow6(x):
    """Out-of-place sixth power: (x**3)**2."""
    x3 = x.mul(x).mul(x)
    return x3.mul(x3)
343
+
344
+
345
@torch.jit.script
def pow6_(x):
    """In-place sixth power: overwrites `x` and returns it."""
    return square_(cube_(x))
348
+
349
+
350
@torch.jit.script
def pow7(x):
    """Out-of-place seventh power: x**6 * x."""
    x3 = x.mul(x).mul(x)
    x6 = x3.mul(x3)
    return x6.mul(x)
353
+
354
+
355
@torch.jit.script
def pow7_(x):
    """In-place seventh power: overwrites `x` and returns it."""
    return pow6_(x).mul_(x)
358
+
359
+
360
@torch.jit.script
def dot(x, y, dim: int = -1, keepdim: bool = False):
    """(Batched) dot product along a dimension.

    The reduced dimension is moved to the end and contracted via a batched
    matrix product.
    """
    x = movedim1(x, dim, -1).unsqueeze(-2)
    y = movedim1(y, dim, -1).unsqueeze(-1)
    d = torch.matmul(x, y).squeeze(-1).squeeze(-1)
    if keepdim:
        # BUGFIX: Tensor.unsqueeze is not in-place; the result must be
        # rebound, otherwise `keepdim=True` silently behaved like False.
        d = d.unsqueeze(dim)
    return d
369
+
370
+
371
@torch.jit.script
def dot_multi(x, y, dim: List[int], keepdim: bool = False):
    """(Batched) dot product over several dimensions.

    The reduced dimensions are moved to the end, flattened, and contracted
    via a batched matrix product.
    """
    for d in dim:
        x = movedim1(x, d, -1)
        y = movedim1(y, d, -1)
    # BUGFIX: capture the batch shape BEFORE reshaping `x`; the original
    # sliced `x.shape` after the reshape, which only happened to be correct
    # for len(dim) == 2 and mis-shaped `y` otherwise.
    batch_shape = x.shape[:-len(dim)]
    x = x.reshape(batch_shape + [1, -1])
    y = y.reshape(batch_shape + [-1, 1])
    dt = torch.matmul(x, y).squeeze(-1).squeeze(-1)
    if keepdim:
        for d in dim:
            # BUGFIX: rebind — Tensor.unsqueeze returns a new tensor.
            # NOTE(review): restoring several dims assumes `dim` is sorted
            # ascending — confirm against callers.
            dt = dt.unsqueeze(d)
    return dt
384
+
385
+
386
+
387
# cartesian_prod takes multiple input tensors as input in eager mode
# but takes a list of tensors in jit mode. This is a helper that works
# in both cases.
if not int(os.environ.get('PYTORCH_JIT', '1')):
    cartesian_prod = lambda x: torch.cartesian_prod(*x)
    if torch_version('>=', (1, 10)):
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x, indexing='ij')
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x, indexing='xy')
    else:
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x)
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # BUGFIX: torch.meshgrid returns a tuple in eager mode; convert
            # to a list so the transposed tensors can be written back.
            grid = list(torch.meshgrid(*x))
            if len(grid) > 1:
                grid[0] = grid[0].transpose(0, 1)
                grid[1] = grid[1].transpose(0, 1)
            return grid

else:
    cartesian_prod = torch.cartesian_prod
    if torch_version('>=', (1, 10)):
        @torch.jit.script
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x, indexing='ij')
        @torch.jit.script
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x, indexing='xy')
    else:
        @torch.jit.script
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x)
        # BUGFIX: this function was misspelled `meshgrid_xyt`, leaving
        # `meshgrid_xy` undefined under jit with torch < 1.10.
        @torch.jit.script
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            grid = torch.meshgrid(x)
            if len(grid) > 1:
                grid[0] = grid[0].transpose(0, 1)
                grid[1] = grid[1].transpose(0, 1)
            return grid


meshgrid = meshgrid_ij
430
+
431
+
432
# In torch < 1.6, div applied to integer tensor performed a floor_divide
# In torch > 1.6, it performs a true divide.
# Floor division must be done using `floor_divide`, but it was buggy
# until torch 1.13 (it was doing a trunc divide instead of a floor divide).
# There was at some point a deprecation warning for floor_divide, but it
# seems to have been lifted afterwards. In torch >= 1.13, floor_divide
# performs a correct floor division.
# Since we only apply floor_divide ot positive values, we are fine.
# NOTE(review): this rebinding shadows the scripted `floor_div` defined
# earlier in this module — confirm the override is intentional.
if torch_version('<', (1, 6)):
    floor_div = torch.div
else:
    floor_div = torch.floor_divide
Generator/interpol/jitfields.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Optional fast backend: `jitfields` presumably provides compiled
# implementations of the same resampling routines -- TODO confirm.
# Fall back gracefully when it is not installed.
try:
    import jitfields
    available = True  # backend importable: the wrappers below are usable
except (ImportError, ModuleNotFoundError):
    jitfields = None
    available = False  # wrappers below must not be called in this case
from .utils import make_list
import torch
9
+
10
+
11
def first2last(input, ndim):
    """Move the channel dimension to the last position.

    The channel axis is assumed to sit just before the `ndim` spatial
    axes (position ``-ndim-1``). If the tensor has no channel axis at
    all (``dim() <= ndim``), a singleton channel is appended instead.

    Returns
    -------
    (tensor, bool) : the channel-last tensor, and whether a singleton
        channel was inserted (so `last2first` can undo it).
    """
    has_channel = input.dim() > ndim
    if has_channel:
        return torch.movedim(input, -ndim - 1, -1), False
    return input.unsqueeze(-1), True
18
+
19
+
20
def last2first(input, ndim, inserted, grad=False):
    """Undo `first2last`: move the channel back before the spatial axes.

    When ``inserted`` is True the singleton channel is squeezed away;
    otherwise the channel axis is moved back to position ``-ndim-1``.
    With ``grad=True`` the last dimension holds gradient components, so
    every offset is shifted left by one.
    """
    offset = int(grad)
    if inserted:
        return input.squeeze(-1 - offset)
    return torch.movedim(input, -1 - offset, -ndim - 1 - offset)
26
+
27
+
28
def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample `input` at the coordinates in `grid` via `jitfields.pull`.

    The channel axis is temporarily moved last (jitfields expects a
    channel-last layout) and restored afterwards.
    """
    ndim = grid.shape[-1]
    moved, inserted = first2last(input, ndim)
    pulled = jitfields.pull(moved, grid, order=interpolation, bound=bound,
                            extrapolate=extrapolate, prefilter=prefilter)
    return last2first(pulled, ndim, inserted)
36
+
37
+
38
def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Splat `input` at the coordinates in `grid` via `jitfields.push`
    (the adjoint of `grid_pull`), handling the channel-last round trip.
    """
    ndim = grid.shape[-1]
    moved, inserted = first2last(input, ndim)
    pushed = jitfields.push(moved, grid, shape, order=interpolation,
                            bound=bound, extrapolate=extrapolate,
                            prefilter=prefilter)
    return last2first(pushed, ndim, inserted)
46
+
47
+
48
def grid_count(grid, shape=None, interpolation='linear', bound='zero',
               extrapolate=False):
    """Splat counts (an implicit image of ones); thin delegate to
    `jitfields.count` -- no channel axis to shuffle here.
    """
    counted = jitfields.count(grid, shape, order=interpolation,
                              bound=bound, extrapolate=extrapolate)
    return counted
52
+
53
+
54
def grid_grad(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample spatial gradients of `input` at `grid` via `jitfields.grad`.

    `last2first` is called with ``grad=True`` because the result carries
    an extra trailing dimension of gradient components.
    """
    ndim = grid.shape[-1]
    moved, inserted = first2last(input, ndim)
    sampled = jitfields.grad(moved, grid, order=interpolation, bound=bound,
                             extrapolate=extrapolate, prefilter=prefilter)
    return last2first(sampled, ndim, inserted, True)
62
+
63
+
64
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
                 inplace=False):
    """Compute interpolating spline coefficients along one dimension,
    delegating to the in-place or out-of-place jitfields variant.
    """
    if inplace:
        return jitfields.spline_coeff_(input, interpolation, bound=bound, dim=dim)
    return jitfields.spline_coeff(input, interpolation, bound=bound, dim=dim)
68
+
69
+
70
def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
                    inplace=False):
    """Compute interpolating spline coefficients along the last `dim`
    dimensions, delegating to the in-place or out-of-place variant.
    """
    if inplace:
        return jitfields.spline_coeff_nd_(input, interpolation, bound=bound, ndim=dim)
    return jitfields.spline_coeff_nd(input, interpolation, bound=bound, ndim=dim)
74
+
75
+
76
def resize(image, factor=None, shape=None, anchor='c',
           interpolation=1, prefilter=True, **kwargs):
    """Resize `image` through `jitfields.resize`.

    The number of spatial dimensions is inferred from the longest of
    `factor` / `shape` / `anchor`, falling back to ``image.dim() - 2``
    (batch + channel assumed).
    """
    bound = kwargs.get('bound', 'nearest')
    lengths = [len(make_list(arg or [])) for arg in (factor, shape, anchor)]
    ndim = max(lengths) or (image.dim() - 2)
    return jitfields.resize(image, factor=factor, shape=shape, ndim=ndim,
                            anchor=anchor, order=interpolation,
                            bound=bound, prefilter=prefilter)
85
+
86
+
87
def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Restrict (downsample, adjoint of resize) `image` through
    `jitfields.restrict`, inferring dimensionality as in `resize`.
    """
    bound = kwargs.get('bound', 'nearest')
    lengths = [len(make_list(arg or [])) for arg in (factor, shape, anchor)]
    ndim = max(lengths) or (image.dim() - 2)
    return jitfields.restrict(image, factor=factor, shape=shape, ndim=ndim,
                              anchor=anchor, order=interpolation,
                              bound=bound, reduce_sum=reduce_sum)
Generator/interpol/nd.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generic N-dimensional version: any combination of spline orders"""
2
+ import torch
3
+ from typing import List, Optional, Tuple
4
+ from .bounds import Bound
5
+ from .splines import Spline
6
+ from .jit_utils import sub2ind_list, make_sign, list_prod_int, cartesian_prod
7
+ Tensor = torch.Tensor
8
+
9
+
10
@torch.jit.script
def inbounds_mask(extrapolate: int, grid, shape: List[int])\
        -> Optional[Tensor]:
    """Boolean mask of in-bounds voxels, or None when extrapolating.

    extrapolate: 0 = no extrapolation (tight threshold),
                 2 = 'hist' mode (half-voxel threshold),
                 anything else = extrapolate everywhere (no mask).
    grid: (B, N, D) coordinates; shape: spatial size per dimension.
    Returns a (B, 1, N) boolean mask (broadcastable over channels).
    """
    if extrapolate not in (0, 2):
        return None
    grid = grid.unsqueeze(1)  # insert a channel axis for broadcasting
    tiny = 5e-2
    threshold = 0.5 + tiny if extrapolate == 2 else tiny
    mask = torch.ones(grid.shape[:-1],
                      dtype=torch.bool, device=grid.device)
    for g, n in zip(grid.unbind(-1), shape):
        mask = mask & (g > -threshold) & (g < n - 1 + threshold)
    return mask
28
+
29
+
30
@torch.jit.script
def get_weights(grid, bound: List[Bound], spline: List[Spline],
                shape: List[int], grad: bool = False, hess: bool = False) \
        -> Tuple[List[List[Tensor]],
                 List[List[Optional[Tensor]]],
                 List[List[Optional[Tensor]]],
                 List[List[Tensor]],
                 List[List[Optional[Tensor]]]]:
    """Precompute, per dimension and per spline node: interpolation
    weights, (optionally) their first/second derivatives, the
    bound-folded integer coordinates, and the boundary signs.

    grid: (B, N, D) sampling coordinates
    bound/spline/shape: per-dimension boundary, spline order, field size
    grad/hess: also compute 1st/2nd derivative weights (else None)
    Returns five lists indexed as [dim][node].
    """
    weights: List[List[Tensor]] = []
    grads: List[List[Optional[Tensor]]] = []
    hesss: List[List[Optional[Tensor]]] = []
    coords: List[List[Tensor]] = []
    signs: List[List[Optional[Tensor]]] = []
    for g, b, s, n in zip(grid.unbind(-1), bound, spline, shape):
        # leftmost node of the spline support, and distance to it
        grid0 = (g - (s.order-1)/2).floor()
        dist0 = g - grid0
        grid0 = grid0.long()
        # a spline of order o has o+1 nodes in its support
        nb_nodes = s.order + 1
        subweights: List[Tensor] = []
        subcoords: List[Tensor] = []
        subgrads: List[Optional[Tensor]] = []
        subhesss: List[Optional[Tensor]] = []
        subsigns: List[Optional[Tensor]] = []
        for node in range(nb_nodes):
            grid1 = grid0 + node
            # sign flip (or zeroing) induced by the boundary condition
            sign1: Optional[Tensor] = b.transform(grid1, n)
            subsigns.append(sign1)
            # fold the index back inside the field of view
            grid1 = b.index(grid1, n)
            subcoords.append(grid1)
            dist1 = dist0 - node
            weight1 = s.fastweight(dist1)
            subweights.append(weight1)
            grad1: Optional[Tensor] = None
            if grad:
                grad1 = s.fastgrad(dist1)
            subgrads.append(grad1)
            hess1: Optional[Tensor] = None
            if hess:
                hess1 = s.fasthess(dist1)
            subhesss.append(hess1)
        weights.append(subweights)
        coords.append(subcoords)
        signs.append(subsigns)
        grads.append(subgrads)
        hesss.append(subhesss)

    return weights, grads, hesss, coords, signs
78
+
79
+
80
@torch.jit.script
def pull(inp, grid, bound: List[Bound], spline: List[Spline],
         extrapolate: int = 1):
    """Interpolate (sample) `inp` at the coordinates stored in `grid`.

    inp: (B, C, *ishape) tensor
    g: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *oshape) tensor
    """

    dim = grid.shape[-1]
    shape = list(inp.shape[-dim:])
    oshape = list(grid.shape[-dim-1:-1])
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten spatial dimensions: work on (B, N, D) / (B, C, N)
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights along each dimension
    weights, _, _, coords, signs = get_weights(grid, bound, spline, shape, False, False)

    # initialize
    out = torch.zeros([batch, channel, grid.shape[1]],
                      dtype=inp.dtype, device=inp.device)

    # iterate across nodes/corners
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # cartesian_prod does not work as expected when only one
        # element is provided
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:
        # gather: linear indices into the flattened input
        idx = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(idx, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out1 = inp.gather(-1, idx)

        # apply sign (boundary-induced flips/zeros)
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out1 = out1 * sign1.unsqueeze(1)

        # apply weights (separable spline: product across dimensions)
        for weight, n in zip(weights, nodes):
            out1 = out1 * weight[n].unsqueeze(1)

        # accumulate
        out = out + out1

    # out-of-bounds mask
    if mask is not None:
        out = out * mask

    out = out.reshape(list(out.shape[:2]) + oshape)
    return out
144
+
145
+
146
@torch.jit.script
def push(inp, grid, shape: Optional[List[int]], bound: List[Bound],
         spline: List[Spline], extrapolate: int = 1):
    """Splat `inp` at the coordinates in `grid` (adjoint of `pull`).

    inp: (B, C, *ishape) tensor
    g: (B, *ishape, D) tensor
    shape: List{D}[int], optional (defaults to the grid's spatial shape)
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *shape) tensor
    """

    dim = grid.shape[-1]
    ishape = list(grid.shape[-dim - 1:-1])
    if shape is None:
        shape = ishape
    shape = list(shape)
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten spatial dimensions: work on (B, N, D) / (B, C, N)
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights along each dimension
    weights, _, _, coords, signs = get_weights(grid, bound, spline, shape)

    # initialize
    out = torch.zeros([batch, channel, list_prod_int(shape)],
                      dtype=inp.dtype, device=inp.device)

    # iterate across nodes/corners
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # cartesian_prod does not work as expected when only one
        # element is provided
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # gather: linear indices into the flattened output
        idx = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(idx, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out1 = inp.clone()

        # apply sign (boundary-induced flips/zeros)
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out1 = out1 * sign1.unsqueeze(1)

        # out-of-bounds mask (zero contributions from out-of-FOV samples)
        if mask is not None:
            out1 = out1 * mask

        # apply weights (separable spline: product across dimensions)
        for weight, n in zip(weights, nodes):
            out1 = out1 * weight[n].unsqueeze(1)

        # accumulate (scatter-add resolves colliding indices)
        out.scatter_add_(-1, idx, out1)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
214
+
215
+
216
@torch.jit.script
def grad(inp, grid, bound: List[Bound], spline: List[Spline],
         extrapolate: int = 1):
    """Sample the spatial gradient of `inp` at the coordinates in `grid`.

    inp: (B, C, *ishape) tensor
    grid: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *oshape, D) tensor
    """

    dim = grid.shape[-1]
    shape = list(inp.shape[-dim:])
    oshape = list(grid.shape[-dim-1:-1])
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten spatial dimensions: work on (B, N, D) / (B, C, N)
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights along each dimension
    weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape,
                                                   grad=True)

    # initialize
    out = torch.zeros([batch, channel, grid.shape[1], dim],
                      dtype=inp.dtype, device=inp.device)

    # iterate across nodes/corners
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # cartesian_prod does not work as expected when only one
        # element is provided
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # gather
        idx = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(idx, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out0 = inp.gather(-1, idx)

        # apply sign
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out0 = out0 * sign1.unsqueeze(1)

        for d in range(dim):
            out1 = out0.clone()
            # apply weights: along the differentiated dimension `d` use
            # the derivative weight, elsewhere the interpolation weight
            for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
                if d == dd:
                    grad11 = grad1[n]
                    if grad11 is not None:
                        out1 = out1 * grad11.unsqueeze(1)
                else:
                    out1 = out1 * weight[n].unsqueeze(1)

            # accumulate in place into the d-th component (unbind -> view)
            out.unbind(-1)[d].add_(out1)

    # out-of-bounds mask
    if mask is not None:
        out = out * mask.unsqueeze(-1)

    out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:]))
    return out
289
+
290
+
291
@torch.jit.script
def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound],
             spline: List[Spline], extrapolate: int = 1):
    """Splat gradient components (adjoint of `grad`).

    inp: (B, C, *ishape, D) tensor
    g: (B, *ishape, D) tensor
    shape: List{D}[int], optional
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    dim = grid.shape[-1]
    oshape = list(grid.shape[-dim-1:-1])
    if shape is None:
        shape = oshape
    shape = list(shape)
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten spatial dimensions: (B, N, D) grid, (B, C, N, D) input
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1, dim])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights along each dimension
    weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape, grad=True)

    # initialize
    out = torch.zeros([batch, channel, list_prod_int(shape)],
                      dtype=inp.dtype, device=inp.device)

    # iterate across nodes/corners
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # cartesian_prod does not work as expected when only one
        # element is provided
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # gather
        idx = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(idx, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out0 = inp.clone()

        # apply sign (extra unsqueeze: last axis holds gradient components)
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out0 = out0 * sign1.unsqueeze(1).unsqueeze(-1)

        # out-of-bounds mask
        if mask is not None:
            out0 = out0 * mask.unsqueeze(-1)

        for d in range(dim):
            out1 = out0.unbind(-1)[d].clone()
            # apply weights: derivative weight along dimension d, plain
            # interpolation weight along the others
            for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
                if d == dd:
                    grad11 = grad1[n]
                    if grad11 is not None:
                        out1 = out1 * grad11.unsqueeze(1)
                else:
                    out1 = out1 * weight[n].unsqueeze(1)

            # accumulate (scatter-add resolves colliding indices)
            out.scatter_add_(-1, idx, out1)

    out = out.reshape(list(out.shape[:2]) + shape)
    return out
365
+
366
+
367
@torch.jit.script
def hess(inp, grid, bound: List[Bound], spline: List[Spline],
         extrapolate: int = 1):
    """Sample the spatial Hessian of `inp` at the coordinates in `grid`.

    inp: (B, C, *ishape) tensor
    grid: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *oshape, D, D) tensor (symmetric in the last two dims)
    """

    dim = grid.shape[-1]
    shape = list(inp.shape[-dim:])
    oshape = list(grid.shape[-dim-1:-1])
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten spatial dimensions: work on (B, N, D) / (B, C, N)
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights along each dimension
    weights, grads, hesss, coords, signs \
        = get_weights(grid, bound, spline, shape, grad=True, hess=True)

    # initialize
    out = torch.zeros([batch, channel, grid.shape[1], dim, dim],
                      dtype=inp.dtype, device=inp.device)

    # iterate across nodes/corners
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # cartesian_prod does not work as expected when only one
        # element is provided
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # gather
        idx = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(idx, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out0 = inp.gather(-1, idx)

        # apply sign
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out0 = out0 * sign1.unsqueeze(1)

        for d in range(dim):
            # -- diagonal -- second-derivative weight along d
            out1 = out0.clone()

            # apply weights
            for dd, (weight, hess1, n) \
                    in enumerate(zip(weights, hesss, nodes)):
                if d == dd:
                    hess11 = hess1[n]
                    if hess11 is not None:
                        out1 = out1 * hess11.unsqueeze(1)
                else:
                    out1 = out1 * weight[n].unsqueeze(1)

            # accumulate in place (unbind returns views into `out`)
            out.unbind(-1)[d].unbind(-1)[d].add_(out1)

            # -- off diagonal -- first-derivative weight along d and d2
            for d2 in range(d+1, dim):
                out1 = out0.clone()

                # apply weights
                for dd, (weight, grad1, n) \
                        in enumerate(zip(weights, grads, nodes)):
                    if dd in (d, d2):
                        grad11 = grad1[n]
                        if grad11 is not None:
                            out1 = out1 * grad11.unsqueeze(1)
                    else:
                        out1 = out1 * weight[n].unsqueeze(1)

                # accumulate
                out.unbind(-1)[d].unbind(-1)[d2].add_(out1)

    # out-of-bounds mask
    if mask is not None:
        # BUG FIX: `mask` is (B, 1, N) and `out` is (B, C, N, D, D).
        # The previous `mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)`
        # produced (B, 1, 1, N, 1, 1), which does not broadcast against
        # `out`; align the trailing (D, D) dims only (cf. `grad`, which
        # uses `mask.unsqueeze(-1)`).
        out = out * mask.unsqueeze(-1).unsqueeze(-1)

    # fill lower triangle (the Hessian is symmetric)
    for d in range(dim):
        for d2 in range(d+1, dim):
            out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2])

    out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:]))
    return out
Generator/interpol/pushpull.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Non-differentiable forward/backward components.
3
+ These components are put together in `interpol.autograd` to generate
4
+ differentiable functions.
5
+
6
+ Note
7
+ ----
8
+ .. I removed @torch.jit.script from these entry-points because compiling
9
+ all possible combinations of bound+interpolation made the first call
10
+ extremely slow.
11
+ .. I am not using the dot/multi_dot helpers even though they should be
12
+ more efficient that "multiply and sum" because I haven't had the time
13
+ to test them. It would be worth doing it.
14
+ """
15
+ import torch
16
+ from typing import List, Optional, Tuple
17
+ from .jit_utils import list_all, dot, dot_multi, pad_list_int
18
+ from .bounds import Bound
19
+ from .splines import Spline
20
+ from . import iso0, iso1, nd
21
+ Tensor = torch.Tensor
22
+
23
+
24
@torch.jit.script
def make_bound(bound: List[int]) -> List[Bound]:
    """Wrap per-dimension integer boundary codes into `Bound` objects."""
    out: List[Bound] = []
    for code in bound:
        out.append(Bound(code))
    return out
27
+
28
+
29
@torch.jit.script
def make_spline(spline: List[int]) -> List[Spline]:
    """Wrap per-dimension integer spline orders into `Spline` objects."""
    out: List[Spline] = []
    for order in spline:
        out.append(Spline(order))
    return out
32
+
33
+
34
# @torch.jit.script
def grid_pull(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Sample `inp` at the locations in `grid`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    # fast paths: dedicated 1d/2d/3d kernels for pure linear (order 1)
    # and pure nearest (order 0) interpolation
    if list_all([order == 1 for order in interpolation]):
        impl = {1: iso1.pull1d, 2: iso1.pull2d, 3: iso1.pull3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, bound_fn, extrapolate)
    if list_all([order == 0 for order in interpolation]):
        impl = {1: iso0.pull1d, 2: iso0.pull2d, 3: iso0.pull3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, bound_fn, extrapolate)
    # generic n-d kernel for any combination of spline orders
    return nd.pull(inp, grid, bound_fn, make_spline(interpolation), extrapolate)
67
+
68
+
69
# @torch.jit.script
def grid_push(inp, grid, shape: Optional[List[int]], bound: List[int],
              interpolation: List[int], extrapolate: int):
    """Splat `inp` at the locations in `grid` (adjoint of `grid_pull`).

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int] tensor, optional, default=spatial_in
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    # fast paths for pure linear / pure nearest interpolation
    if list_all([order == 1 for order in interpolation]):
        impl = {1: iso1.push1d, 2: iso1.push2d, 3: iso1.push3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, shape, bound_fn, extrapolate)
    if list_all([order == 0 for order in interpolation]):
        impl = {1: iso0.push1d, 2: iso0.push2d, 3: iso0.push3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, shape, bound_fn, extrapolate)
    # generic n-d kernel
    return nd.push(inp, grid, shape, bound_fn, make_spline(interpolation), extrapolate)
103
+
104
+
105
# @torch.jit.script
def grid_count(grid, shape: Optional[List[int]], bound: List[int],
               interpolation: List[int], extrapolate: int):
    """Splat an implicit image of ones (splat-weight / density image).

    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int] tensor, optional, default=spatial_in
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, 1, *shape) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    gshape = list(grid.shape[-dim-1:-1])
    if shape is None:
        shape = gshape
    # push a broadcast (memory-free) image of ones through the splat
    inp = torch.ones([], dtype=grid.dtype, device=grid.device)
    inp = inp.expand([len(grid), 1] + gshape)
    if list_all([order == 1 for order in interpolation]):
        impl = {1: iso1.push1d, 2: iso1.push2d, 3: iso1.push3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, shape, bound_fn, extrapolate)
    if list_all([order == 0 for order in interpolation]):
        impl = {1: iso0.push1d, 2: iso0.push2d, 3: iso0.push3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, shape, bound_fn, extrapolate)
    return nd.push(inp, grid, shape, bound_fn, make_spline(interpolation), extrapolate)
143
+
144
+
145
# @torch.jit.script
def grid_grad(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Sample spatial gradients of `inp` at the locations in `grid`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out, D) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    if list_all([order == 1 for order in interpolation]):
        impl = {1: iso1.grad1d, 2: iso1.grad2d, 3: iso1.grad3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, bound_fn, extrapolate)
    if list_all([order == 0 for order in interpolation]):
        # nearest-neighbour gradient has a single dimension-generic kernel
        return iso0.grad(inp, grid, bound_fn, extrapolate)
    return nd.grad(inp, grid, bound_fn, make_spline(interpolation), extrapolate)
173
+
174
+
175
# @torch.jit.script
def grid_pushgrad(inp, grid, shape: List[int], bound: List[int],
                  interpolation: List[int], extrapolate: int):
    """ /!\ Used only in backward pass of grid_grad
    Splat gradient components (adjoint of `grid_grad`).

    inp: (B, C, *spatial_in, D) tensor
    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int], optional
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    if list_all([order == 1 for order in interpolation]):
        impl = {1: iso1.pushgrad1d, 2: iso1.pushgrad2d, 3: iso1.pushgrad3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, shape, bound_fn, extrapolate)
    if list_all([order == 0 for order in interpolation]):
        return iso0.pushgrad(inp, grid, shape, bound_fn, extrapolate)
    return nd.pushgrad(inp, grid, shape, bound_fn, make_spline(interpolation), extrapolate)
204
+
205
+
206
# @torch.jit.script
def grid_hess(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """ /!\ Used only in backward pass of grid_grad
    Sample spatial Hessians of `inp` at the locations in `grid`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out, D, D) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    if list_all([order == 1 for order in interpolation]):
        impl = {1: iso1.hess1d, 2: iso1.hess2d, 3: iso1.hess3d}.get(dim)
        if impl is not None:
            return impl(inp, grid, bound_fn, extrapolate)
    if list_all([order == 0 for order in interpolation]):
        return iso0.hess(inp, grid, bound_fn, extrapolate)
    return nd.hess(inp, grid, bound_fn, make_spline(interpolation), extrapolate)
234
+
235
+
236
# @torch.jit.script
def grid_pull_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor], ]:
    """Backward pass of `grid_pull` w.r.t. `inp` and `grid`.

    grad: (B, C, *spatial_out) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_out, D)
    """
    dim = grid.shape[-1]
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None
    if inp.requires_grad:
        # the adjoint of pull is push
        grad_inp = grid_push(grad, grid, inp.shape[-dim:], bound,
                             interpolation, extrapolate)
    if grid.requires_grad:
        # sample spatial gradients, then contract the channel dimension
        spatial = grid_grad(inp, grid, bound, interpolation, extrapolate)
        grad_grid = (spatial * grad.unsqueeze(-1)).sum(dim=1)
    return grad_inp, grad_grid
259
+
260
+
261
# @torch.jit.script
def grid_push_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor], ]:
    """Backward pass of `grid_push` w.r.t. `inp` and `grid`.

    grad: (B, C, *spatial_out) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_in, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D)
    """
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None
    if inp.requires_grad:
        # the adjoint of push is pull
        grad_inp = grid_pull(grad, grid, bound, interpolation, extrapolate)
    if grid.requires_grad:
        # sample gradients of the incoming gradient, contract channels
        spatial = grid_grad(grad, grid, bound, interpolation, extrapolate)
        grad_grid = (spatial * inp.unsqueeze(-1)).sum(dim=1)
    return grad_inp, grad_grid
283
+
284
+
285
# @torch.jit.script
def grid_count_backward(grad, grid, bound: List[int],
                        interpolation: List[int], extrapolate: int) \
        -> Optional[Tensor]:
    """Backward pass of `grid_count` w.r.t. `grid`.

    grad: (B, C, *spatial_out) tensor
    grid: (B, *spatial_in, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, *spatial_in, D) tensor, or None if no gradient is needed
    """
    if not grid.requires_grad:
        return None
    return grid_grad(grad, grid, bound, interpolation, extrapolate).sum(1)
300
+
301
+
302
# @torch.jit.script
def grid_grad_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Backward pass of `grid_grad` w.r.t. `inp` and `grid`.

    grad: (B, C, *spatial_out, D) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in, D) tensor, (B, *spatial_out, D)
    """
    dim = grid.shape[-1]
    shape = inp.shape[-dim:]
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None
    if inp.requires_grad:
        # the adjoint of grad is pushgrad
        grad_inp = grid_pushgrad(grad, grid, shape, bound, interpolation, extrapolate)
    if grid.requires_grad:
        # sample Hessians, contract channel and one spatial-derivative axis
        hess = grid_hess(inp, grid, bound, interpolation, extrapolate)
        grad_grid = (hess * grad.unsqueeze(-1)).sum(dim=[1, -2])
    return grad_inp, grad_grid
Generator/interpol/resize.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Resize functions (equivalent to scipy's zoom, pytorch's interpolate)
3
+ based on grid_pull.
4
+ """
5
+ __all__ = ['resize']
6
+
7
+ from .api import grid_pull
8
+ from .utils import make_list, meshgrid_ij
9
+ from . import backend, jitfields
10
+ import torch
11
+
12
+
13
def resize(image, factor=None, shape=None, anchor='c',
           interpolation=1, prefilter=True, **kwargs):
    """Resize an image by a factor or to a specific shape, via `grid_pull`.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. With anchors 'centers'/'edges', exactly one of `factor` or `shape`
       is needed; with 'first'/'last', `factor` must be provided even
       when `shape` is given.
    .. Because of rounding, `resize(resize(x, f), 1/f)` does not always
       recover a tensor with the same shape as x.

    edges          centers          first           last
    e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
    | . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
    e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +

    Parameters
    ----------
    image : (batch, channel, *inshape) tensor
        Image to resize.
    factor : float or list[float], optional
        Resizing factor (> 1 : larger image <-> smaller voxels,
        < 1 : smaller image <-> larger voxels).
    shape : (ndim,) list[int], optional
        Output shape.
    anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
        With 'c'/'e', the shape is multiplied by the factor (and possibly
        truncated) and two anchor points fix the voxel size. With 'f'/'l',
        a single anchor point is used so the voxel size is exactly divided
        by the factor (an integer factor then corresponds to subslicing,
        e.g. `vol[::f, ::f, ::f]`). A per-dimension list is accepted.
    interpolation : int or sequence[int], default=1
        Interpolation (spline) order.
    prefilter : bool, default=True
        Apply the spline pre-filter (= interpolates the input).

    Returns
    -------
    resized : (batch, channel, *shape) tensor
        Resized image.
    """
    if backend.jitfields and jitfields.available:
        return jitfields.resize(image, factor, shape, anchor,
                                interpolation, prefilter, **kwargs)

    factor = make_list(factor) if factor else []
    shape = make_list(shape) if shape else []
    anchor = make_list(anchor)
    ndim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
    anchor = [a[0].lower() for a in make_list(anchor, ndim)]
    bck = dict(dtype=image.dtype, device=image.device)

    # resolve output shape / factor from whichever was provided
    inshape = image.shape[-ndim:]
    if factor:
        factor = make_list(factor, ndim)
    elif not shape:
        raise ValueError('One of `factor` or `shape` must be provided')
    shape = make_list(shape, ndim) if shape \
        else [int(n * f) for n, f in zip(inshape, factor)]
    if not factor:
        factor = [o / i for o, i in zip(shape, inshape)]

    # one sampling axis per dimension, mapping output voxels to input voxels
    axes = []
    for mode, f, n_in, n_out in zip(anchor, factor, inshape, shape):
        if mode == 'c':    # centers: align first and last voxel centers
            axes.append(torch.linspace(0, n_in - 1, n_out, **bck))
        elif mode == 'e':  # edges: align the field-of-view boundaries
            step = n_in / n_out
            axes.append(torch.arange(0., n_out, **bck) * step + 0.5 * (step - 1))
        elif mode == 'f':  # first voxel: scale = 1/f, shift = 0
            axes.append(torch.arange(0., n_out, **bck) / f)
        elif mode == 'l':  # last voxel: scale = 1/f
            shift = (n_in - 1) - (n_out - 1) / f
            axes.append(torch.arange(0., n_out, **bck) / f + shift)
        else:
            raise ValueError('Unknown anchor {}'.format(mode))

    # interpolate
    kwargs.setdefault('bound', 'nearest')
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('interpolation', interpolation)
    kwargs.setdefault('prefilter', prefilter)
    grid = torch.stack(meshgrid_ij(*axes), dim=-1)
    return grid_pull(image, grid, **kwargs)
+
Generator/interpol/restrict.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = ['restrict']
2
+
3
+ from .api import grid_push
4
+ from .utils import make_list, meshgrid_ij
5
+ from . import backend, jitfields
6
+ import torch
7
+
8
+
9
def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Restrict (downsample by scattering) an image by a factor or to a shape.

    This is the adjoint of `resize`: input voxels are pushed into the
    coarser output grid with `grid_push`.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. With anchors 'centers'/'edges', exactly one of `factor` or `shape`
       is needed; with 'first'/'last', `factor` must be provided even
       when `shape` is given.
    .. Because of rounding, `resize(resize(x, f), 1/f)` does not always
       recover a tensor with the same shape as x.

    edges          centers          first           last
    e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
    | . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
    e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +

    Parameters
    ----------
    image : (batch, channel, *inshape) tensor
        Image to resize.
    factor : float or list[float], optional
        Resizing factor (> 1 : larger image <-> smaller voxels).
    shape : (ndim,) list[int], optional
        Output shape.
    anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
        Anchor point(s) fixing the voxel-size mapping; see `resize`.
        A per-dimension list is accepted.
    interpolation : int or sequence[int], default=1
        Interpolation (spline) order.
    reduce_sum : bool, default=False
        Do not normalize by the number of accumulated values per voxel.

    Returns
    -------
    restricted : (batch, channel, *shape) tensor
        Restricted image.
    """
    if backend.jitfields and jitfields.available:
        return jitfields.restrict(image, factor, shape, anchor,
                                  interpolation, reduce_sum, **kwargs)

    factor = make_list(factor) if factor else []
    shape = make_list(shape) if shape else []
    anchor = make_list(anchor)
    ndim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
    anchor = [a[0].lower() for a in make_list(anchor, ndim)]
    bck = dict(dtype=image.dtype, device=image.device)

    # resolve output shape / factor from whichever was provided
    inshape = image.shape[-ndim:]
    if factor:
        factor = make_list(factor, ndim)
    elif not shape:
        raise ValueError('One of `factor` or `shape` must be provided')
    shape = make_list(shape, ndim) if shape \
        else [int(n / f) for n, f in zip(inshape, factor)]
    if not factor:
        factor = [n_in / n_out for n_out, n_in in zip(shape, inshape)]

    # one axis per dimension, mapping input voxels onto the output grid;
    # `fullscale` tracks the average number of input voxels per output voxel
    axes = []
    fullscale = 1
    for mode, f, n_in, n_out in zip(anchor, factor, inshape, shape):
        if mode == 'c':    # centers
            axes.append(torch.linspace(0, n_out - 1, n_in, **bck))
            fullscale *= (n_in - 1) / (n_out - 1)
        elif mode == 'e':  # edges
            step = n_out / n_in
            axes.append(torch.arange(0., n_in, **bck) * step + 0.5 * (step - 1))
            fullscale *= step
        elif mode == 'f':  # first voxel: scale = 1/f, shift = 0
            axes.append(torch.arange(0., n_in, **bck) / f)
            fullscale *= 1 / f
        elif mode == 'l':  # last voxel: scale = 1/f
            shift = (n_out - 1) - (n_in - 1) / f
            axes.append(torch.arange(0., n_in, **bck) / f + shift)
            fullscale *= 1 / f
        else:
            raise ValueError('Unknown anchor {}'.format(mode))

    # scatter
    kwargs.setdefault('bound', 'nearest')
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('interpolation', interpolation)
    kwargs.setdefault('prefilter', False)
    grid = torch.stack(meshgrid_ij(*axes), dim=-1)
    restricted = grid_push(image, grid, shape, **kwargs)
    if not reduce_sum:
        restricted /= fullscale

    return restricted
Generator/interpol/splines.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Weights and derivatives of spline orders 0 to 7."""
2
+ import torch
3
+ from enum import Enum
4
+ from .jit_utils import square, cube, pow4, pow5, pow6, pow7
5
+
6
+
7
class InterpolationType(Enum):
    """Spline interpolation orders 0-7, with common-name aliases.

    The first name on each line is the canonical member name; the second
    is an alias for the same value (e.g. `nearest` == `zeroth` == 0).
    """
    nearest = zeroth = 0
    linear = first = 1
    quadratic = second = 2
    cubic = third = 3
    fourth = 4
    fifth = 5
    sixth = 6
    seventh = 7
16
+
17
+
18
@torch.jit.script
class Spline:
    """B-spline basis of order 0..7, with first and second derivatives.

    `weight`/`grad`/`hess` zero the result outside the spline's compact
    support |x| < (order + 1) / 2; the `fast*` variants assume the caller
    only evaluates inside the support and skip that masking.
    The piecewise polynomial coefficients below are the standard closed
    forms for centered cardinal B-splines of each order.
    """

    def __init__(self, order: int = 1):
        # spline order: 0 = nearest, 1 = linear, ..., up to 7
        self.order = order

    def weight(self, x):
        """Basis weight at (signed) offset `x`, zeroed outside the support."""
        w = self.fastweight(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        w = torch.where(x.abs() >= (self.order + 1)/2, zero, w)
        return w

    def fastweight(self, x):
        """Basis weight at `x`, assuming |x| lies inside the support."""
        if self.order == 0:
            return torch.ones(x.shape, dtype=x.dtype, device=x.device)
        x = x.abs()  # all bases below are even functions of x
        if self.order == 1:
            return 1 - x
        if self.order == 2:
            x_low = 0.75 - square(x)
            x_up = 0.5 * square(1.5 - x)
            return torch.where(x < 0.5, x_low, x_up)
        if self.order == 3:
            x_low = (x * x * (x - 2.) * 3. + 4.) / 6.
            x_up = cube(2. - x) / 6.
            return torch.where(x < 1., x_low, x_up)
        if self.order == 4:
            x_low = square(x)
            x_low = x_low * (x_low * 0.25 - 0.625) + 115. / 192.
            x_mid = x * (x * (x * (5. - x) / 6. - 1.25) + 5./24.) + 55./96.
            x_up = pow4(x - 2.5) / 24.
            return torch.where(x < 0.5, x_low, torch.where(x < 1.5, x_mid, x_up))
        if self.order == 5:
            x_low = square(x)
            x_low = x_low * (x_low * (0.25 - x / 12.) - 0.5) + 0.55
            x_mid = x * (x * (x * (x * (x / 24. - 0.375) + 1.25) - 1.75) + 0.625) + 0.425
            x_up = pow5(3 - x) / 120.
            return torch.where(x < 1., x_low, torch.where(x < 2., x_mid, x_up))
        if self.order == 6:
            x_low = square(x)
            x_low = x_low * (x_low * (7./48. - x_low/36.) - 77./192.) + 5887./11520.
            x_mid_low = (x * (x * (x * (x * (x * (x / 48. - 7./48.) + 0.328125)
                         - 35./288.) - 91./256.) - 7./768.) + 7861./15360.)
            x_mid_up = (x * (x * (x * (x * (x * (7./60. - x / 120.) - 0.65625)
                        + 133./72.) - 2.5703125) + 1267./960.) + 1379./7680.)
            x_up = pow6(x - 3.5) / 720.
            return torch.where(x < .5, x_low,
                   torch.where(x < 1.5, x_mid_low,
                   torch.where(x < 2.5, x_mid_up, x_up)))
        if self.order == 7:
            x_low = square(x)
            x_low = (x_low * (x_low * (x_low * (x / 144. - 1./36.)
                     + 1./9.) - 1./3.) + 151./315.)
            x_mid_low = (x * (x * (x * (x * (x * (x * (0.05 - x/240.) - 7./30.)
                         + 0.5) - 7./18.) - 0.1) - 7./90.) + 103./210.)
            x_mid_up = (x * (x * (x * (x * (x * (x * (x / 720. - 1./36.)
                        + 7./30.) - 19./18.) + 49./18.) - 23./6.) + 217./90.)
                        - 139./630.)
            x_up = pow7(4 - x) / 5040.
            return torch.where(x < 1., x_low,
                   torch.where(x < 2., x_mid_low,
                   torch.where(x < 3., x_mid_up, x_up)))
        raise NotImplementedError

    def grad(self, x):
        """First derivative at `x`, zeroed outside the support."""
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        g = self.fastgrad(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        g = torch.where(x.abs() >= (self.order + 1)/2, zero, g)
        return g

    def fastgrad(self, x):
        """First derivative at `x`, assuming |x| lies inside the support."""
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        # derivative of an even function is odd: evaluate on |x|,
        # then restore the sign
        return self._fastgrad(x.abs()).mul(x.sign())

    def _fastgrad(self, x):
        # first derivative of the basis, for non-negative x only
        if self.order == 1:
            return torch.ones(x.shape, dtype=x.dtype, device=x.device)
        if self.order == 2:
            return torch.where(x < 0.5, -2*x, x - 1.5)
        if self.order == 3:
            g_low = x * (x * 1.5 - 2)
            g_up = -0.5 * square(2 - x)
            return torch.where(x < 1, g_low, g_up)
        if self.order == 4:
            g_low = x * (square(x) - 1.25)
            g_mid = x * (x * (x * (-2./3.) + 2.5) - 2.5) + 5./24.
            g_up = cube(2. * x - 5.) / 48.
            return torch.where(x < 0.5, g_low,
                   torch.where(x < 1.5, g_mid, g_up))
        if self.order == 5:
            g_low = x * (x * (x * (x * (-5./12.) + 1.)) - 1.)
            g_mid = x * (x * (x * (x * (5./24.) - 1.5) + 3.75) - 3.5) + 0.625
            g_up = pow4(x - 3.) / (-24.)
            return torch.where(x < 1, g_low,
                   torch.where(x < 2, g_mid, g_up))
        if self.order == 6:
            g_low = square(x)
            g_low = x * (g_low * (7./12.) - square(g_low) / 6. - 77./96.)
            g_mid_low = (x * (x * (x * (x * (x * 0.125 - 35./48.) + 1.3125)
                         - 35./96.) - 0.7109375) - 7./768.)
            g_mid_up = (x * (x * (x * (x * (x / (-20.) + 7./12.) - 2.625)
                        + 133./24.) - 5.140625) + 1267./960.)
            g_up = pow5(2*x - 7) / 3840.
            return torch.where(x < 0.5, g_low,
                   torch.where(x < 1.5, g_mid_low,
                   torch.where(x < 2.5, g_mid_up,
                   g_up)))
        if self.order == 7:
            g_low = square(x)
            g_low = x * (g_low * (g_low * (x * (7./144.) - 1./6.) + 4./9.) - 2./3.)
            g_mid_low = (x * (x * (x * (x * (x * (x * (-7./240.) + 3./10.)
                         - 7./6.) + 2.) - 7./6.) - 1./5.) - 7./90.)
            g_mid_up = (x * (x * (x * (x * (x * (x * (7./720.) - 1./6.)
                        + 7./6.) - 38./9.) + 49./6.) - 23./3.) + 217./90.)
            g_up = pow6(x - 4) / (-720.)
            return torch.where(x < 1, g_low,
                   torch.where(x < 2, g_mid_low,
                   torch.where(x < 3, g_mid_up, g_up)))
        raise NotImplementedError

    def hess(self, x):
        """Second derivative at `x`, zeroed outside the support."""
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        h = self.fasthess(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        h = torch.where(x.abs() >= (self.order + 1)/2, zero, h)
        return h

    def fasthess(self, x):
        """Second derivative at `x`, assuming |x| lies inside the support."""
        if self.order in (0, 1):
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        x = x.abs()  # second derivative of an even function is even
        if self.order == 2:
            one = torch.ones([1], dtype=x.dtype, device=x.device)
            return torch.where(x < 0.5, -2 * one, one)
        if self.order == 3:
            return torch.where(x < 1, 3. * x - 2., 2. - x)
        if self.order == 4:
            return torch.where(x < 0.5, 3. * square(x) - 1.25,
                   torch.where(x < 1.5, x * (-2. * x + 5.) - 2.5,
                               square(2. * x - 5.) / 8.))
        if self.order == 5:
            h_low = square(x)
            h_low = - h_low * (x * (5./3.) - 3.) - 1.
            h_mid = x * (x * (x * (5./6.) - 9./2.) + 15./2.) - 7./2.
            h_up = 9./2. - x * (x * (x/6. - 3./2.) + 9./2.)
            return torch.where(x < 1, h_low,
                   torch.where(x < 2, h_mid, h_up))
        if self.order == 6:
            h_low = square(x)
            h_low = - h_low * (h_low * (5./6) - 7./4.) - 77./96.
            h_mid_low = (x * (x * (x * (x * (5./8.) - 35./12.) + 63./16.)
                         - 35./48.) - 91./128.)
            h_mid_up = -(x * (x * (x * (x/4. - 7./3.) + 63./8.) - 133./12.)
                        + 329./64.)
            h_up = (x * (x * (x * (x/24. - 7./12.) + 49./16.) - 343./48.)
                    + 2401./384.)
            return torch.where(x < 0.5, h_low,
                   torch.where(x < 1.5, h_mid_low,
                   torch.where(x < 2.5, h_mid_up,
                   h_up)))
        if self.order == 7:
            h_low = square(x)
            h_low = h_low * (h_low*(x * (7./24.) - 5./6.) + 4./3.) - 2./3.
            h_mid_low = - (x * (x * (x * (x * (x * (7./40.) - 3./2.) + 14./3.)
                          - 6.) + 7./3.) + 1./5.)
            h_mid_up = (x * (x * (x * (x * (x * (7./120.) - 5./6.) + 14./3.)
                        - 38./3.) + 49./3.) - 23./3.)
            h_up = - (x * (x * (x * (x * (x/120. - 1./6.) + 4./3.) - 16./3.)
                     + 32./3.) - 128./15.)
            return torch.where(x < 1, h_low,
                   torch.where(x < 2, h_mid_low,
                   torch.where(x < 3, h_mid_up,
                   h_up)))
        raise NotImplementedError
+
Generator/interpol/tests/__init__.py ADDED
File without changes
Generator/interpol/tests/test_gradcheck_pushpull.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.autograd import gradcheck
3
+ from interpol import grid_pull, grid_push, grid_count, grid_grad, add_identity_grid_
4
+ import pytest
5
+ import inspect
6
+
7
# global parameters
dtype = torch.double # data type (double advised to check gradients)
shape1 = 3 # size along each dimension
extrapolate = True

# Ask torch for reproducible kernels when supported, and relax gradcheck
# (coarse rtol, no undefined-grad check, nondeterminism tolerance) so the
# finite-difference comparison is not broken by small numerical noise.
if hasattr(torch, 'use_deterministic_algorithms'):
    torch.use_deterministic_algorithms(True)
kwargs = dict(rtol=1., raise_exception=True)
if 'check_undefined_grad' in inspect.signature(gradcheck).parameters:
    kwargs['check_undefined_grad'] = False
if 'nondet_tol' in inspect.signature(gradcheck).parameters:
    kwargs['nondet_tol'] = 1e-3

# parameters
# Devices to test: ('cpu', nthreads) pairs, plus 'cuda' when available.
devices = [('cpu', 1)]
if torch.backends.openmp.is_available() or torch.backends.mkl.is_available():
    print('parallel backend available')
    devices.append(('cpu', 10))
if torch.cuda.is_available():
    print('cuda backend available')
    devices.append('cuda')

dims = [1, 2, 3]
bounds = list(range(7))
# (interpolation order, boundary) pairs: every boundary for orders 0-2,
# a single boundary (3) for the slower high orders.
order_bounds = []
for o in range(3):
    for b in bounds:
        order_bounds += [(o, b)]
for o in range(3, 8):
    order_bounds += [(o, 3)] # only test dc2 for order > 2
37
+
38
+
39
def make_data(shape, device, dtype):
    """Build a random test volume and a randomly perturbed identity grid.

    Returns (vol, grid) with vol of shape (2, 1, *shape) and grid of
    shape (2, *shape, ndim).
    """
    ndim = len(shape)
    # random displacement turned into absolute sampling coordinates
    grid = add_identity_grid_(
        torch.randn([2, *shape, ndim], device=device, dtype=dtype))
    vol = torch.randn([2, 1, *shape], device=device, dtype=dtype)
    return vol, grid
44
+
45
+
46
def init_device(device):
    """Normalize a device spec and configure the backend accordingly.

    `device` is either a name ('cpu'/'cuda') or a (name, param) pair,
    where param is the CPU thread count or the CUDA device index.
    Returns a `torch.device`.
    """
    name, param = (device if isinstance(device, (list, tuple))
                   else (device, 1 if device == 'cpu' else 0))
    if name == 'cuda':
        torch.cuda.set_device(param)
        torch.cuda.init()
        try:
            torch.cuda.empty_cache()
        except RuntimeError:
            pass
        return torch.device('{}:{}'.format(name, param))
    assert name == 'cpu'
    torch.set_num_threads(param)
    return torch.device(name)
63
+
64
+
65
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
# @pytest.mark.parametrize("bound", bounds)
# @pytest.mark.parametrize("interpolation", orders)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_grad(device, dim, bound, interpolation):
    """Finite-difference check of `grid_grad` w.r.t. volume and grid."""
    print(f'grad_{dim}d({interpolation}, {bound}) on {device}')
    device = init_device(device)
    shape = (shape1,) * dim
    vol, grid = make_data(shape, device, dtype)
    vol.requires_grad = True
    grid.requires_grad = True
    assert gradcheck(grid_grad, (vol, grid, interpolation, bound, extrapolate),
                     **kwargs)
79
+
80
+
81
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
# @pytest.mark.parametrize("bound", bounds)
# @pytest.mark.parametrize("interpolation", orders)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_pull(device, dim, bound, interpolation):
    """Finite-difference check of `grid_pull` w.r.t. volume and grid."""
    print(f'pull_{dim}d({interpolation}, {bound}) on {device}')
    device = init_device(device)
    shape = (shape1,) * dim
    vol, grid = make_data(shape, device, dtype)
    vol.requires_grad = True
    grid.requires_grad = True
    assert gradcheck(grid_pull, (vol, grid, interpolation, bound, extrapolate),
                     **kwargs)
95
+
96
+
97
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
# @pytest.mark.parametrize("bound", bounds)
# @pytest.mark.parametrize("interpolation", orders)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_push(device, dim, bound, interpolation):
    """Finite-difference check of `grid_push` w.r.t. volume and grid."""
    print(f'push_{dim}d({interpolation}, {bound}) on {device}')
    device = init_device(device)
    shape = (shape1,) * dim
    vol, grid = make_data(shape, device, dtype)
    vol.requires_grad = True
    grid.requires_grad = True
    assert gradcheck(grid_push, (vol, grid, shape, interpolation, bound, extrapolate),
                     **kwargs)
111
+
112
+
113
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
# @pytest.mark.parametrize("bound", bounds)
# @pytest.mark.parametrize("interpolation", orders)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_count(device, dim, bound, interpolation):
    """Finite-difference check of `grid_count` w.r.t. the grid only."""
    print(f'count_{dim}d({interpolation}, {bound}) on {device}')
    device = init_device(device)
    shape = (shape1,) * dim
    _, grid = make_data(shape, device, dtype)
    grid.requires_grad = True
    assert gradcheck(grid_count, (grid, shape, interpolation, bound, extrapolate),
                     **kwargs)
Generator/interpol/utils.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def fake_decorator(*a, **k):
    """No-op decorator: acts as identity whether used bare or with args."""
    used_bare = len(a) == 1 and not k
    return a[0] if used_bare else fake_decorator
9
+
10
+
11
def make_list(x, n=None, **kwargs):
    """Ensure that the input is a list (of a given size).

    Parameters
    ----------
    x : list or tuple or scalar
        Input object
    n : int, optional
        Required length
    default : scalar, optional
        Value to right-pad with. Use the last input value by default.

    Returns
    -------
    x : list
    """
    x = list(x) if isinstance(x, (list, tuple)) else [x]
    if n and len(x) < n:
        pad = kwargs.get('default', x[-1])
        x = x + [pad] * (n - len(x))
    return x
34
+
35
+
36
def expanded_shape(*shapes, side='left'):
    """Expand input shapes according to broadcasting rules.

    Parameters
    ----------
    *shapes : sequence[int]
        Input shapes
    side : {'left', 'right'}, default='left'
        Side on which singleton dimensions are added.

    Returns
    -------
    shape : tuple[int]
        Output (broadcast) shape

    Raises
    ------
    ValueError
        If shapes are not compatible for broadcast.

    """
    def error(s0, s1):
        raise ValueError('Incompatible shapes for broadcasting: {} and {}.'
                         .format(s0, s1))

    # maximum number of dimensions across all inputs
    ndim = max((len(s) for s in shapes), default=0)

    # fold each shape into the accumulator, dimension by dimension
    out = [1] * ndim
    for this in shapes:
        ones = [1] * (ndim - len(this))
        padded = [*ones, *this] if side == 'left' else [*this, *ones]
        out = [max(a, b) if a == 1 or b == 1 or a == b else error(a, b)
               for a, b in zip(out, padded)]

    return tuple(out)
79
+
80
+
81
def matvec(mat, vec, out=None):
    """Matrix-vector product (supports broadcasting).

    Parameters
    ----------
    mat : (..., M, N) tensor
        Input matrix.
    vec : (..., N) tensor
        Input vector.
    out : (..., M) tensor, optional
        Placeholder for the output tensor.

    Returns
    -------
    mv : (..., M) tensor
        Matrix-vector product of the inputs

    """
    # promote the vector to a (..., N, 1) column so matmul broadcasts,
    # then drop the trailing singleton from the (..., M, 1) result
    column = vec[..., None]
    target = None if out is None else out[..., None]
    return torch.matmul(mat, column, out=target)[..., 0]
109
+
110
+
111
+ def _compare_versions(version1, mode, version2):
112
+ for v1, v2 in zip(version1, version2):
113
+ if mode in ('gt', '>'):
114
+ if v1 > v2:
115
+ return True
116
+ elif v1 < v2:
117
+ return False
118
+ elif mode in ('ge', '>='):
119
+ if v1 > v2:
120
+ return True
121
+ elif v1 < v2:
122
+ return False
123
+ elif mode in ('lt', '<'):
124
+ if v1 < v2:
125
+ return True
126
+ elif v1 > v2:
127
+ return False
128
+ elif mode in ('le', '<='):
129
+ if v1 < v2:
130
+ return True
131
+ elif v1 > v2:
132
+ return False
133
+ if mode in ('gt', 'lt', '>', '<'):
134
+ return False
135
+ else:
136
+ return True
137
+
138
+
139
def torch_version(mode, version):
    """Check torch version

    Parameters
    ----------
    mode : {'<', '<=', '>', '>='}
    version : tuple[int]

    Returns
    -------
    True if "torch.version <mode> version"

    """
    # strip any '+cuXXX' local-build suffix
    release = torch.__version__.split('+')[0]
    major, minor, patch, *_ = release.split('.')
    # strip alpha/beta/rc tags glued onto the patch number (e.g. '0a0')
    for tag in 'abcdefghijklmnopqrstuvwxy':
        if tag in patch:
            patch = patch[:patch.index(tag)]
    current = (int(major), int(minor), int(patch))
    return _compare_versions(current, mode, make_list(version))
161
+
162
+
163
# torch >= 1.10 requires (and supports) an explicit `indexing` argument
# to meshgrid; older versions only implement 'ij' (matrix) indexing.
if torch_version('>=', (1, 10)):
    meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij')
    meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy')
else:
    meshgrid_ij = lambda *x: torch.meshgrid(*x)
    def meshgrid_xy(*x):
        # Emulate 'xy' indexing on old torch: swap the first two axes of
        # the 'ij' grids (only meaningful with two or more inputs).
        grid = list(torch.meshgrid(*x))
        if len(grid) > 1:
            grid[0] = grid[0].transpose(0, 1)
            grid[1] = grid[1].transpose(0, 1)
        return grid


# Default meshgrid follows the matrix ('ij') convention.
meshgrid = meshgrid_ij
Generator/utils.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import nibabel as nib
4
+
5
+ import torch
6
+ from torch.nn.functional import conv3d
7
+ from torch.utils.data import Dataset
8
+
9
+ from scipy.io.matlab import loadmat
10
+
11
+
12
+ import time, datetime
13
+
14
+ from ShapeID.DiffEqs.adjoint import odeint_adjoint as odeint
15
+ from ShapeID.perlin3d import generate_velocity_3d , generate_shape_3d
16
+
17
+
18
class ConcatDataset(Dataset):
    """Virtual concatenation of datasets, sampled at random per item.

    Each `__getitem__` call first picks one of the wrapped datasets
    (with probabilities `probs`, uniform by default), then indexes it
    modulo its own length.
    """

    def __init__(self, dataset_list, probs=None):
        self.datasets = dataset_list
        # fall back to a uniform distribution when probs is missing/empty
        self.probs = probs or [1 / len(self.datasets)] * len(self.datasets)

    def __getitem__(self, i):
        source = np.random.choice(self.datasets, 1, p=self.probs)[0]
        return source[i % len(source)]

    def __len__(self):
        # length of the longest wrapped dataset
        return max(len(d) for d in self.datasets)
30
+
31
+
32
+
33
# Prepare generator
def resolution_sampler(low_res_only=False):
    """Sample a random acquisition (resolution, thickness) pair, in mm.

    Four regimes, each with probability 1/4 (restricted to the last two
    when `low_res_only` is True): 1 mm isotropic; clinical (low-res along
    one random axis); low-field stock axial sequences; low-field
    isotropic-ish (also suitable for scouts).

    Returns
    -------
    resolution, thickness : (3,) numpy arrays
    """
    r = np.random.rand()           # in [0, 1]
    if low_res_only:
        r = r * 0.5 + 0.5          # restrict to [0.5, 1]

    if r < 0.25:
        # 1mm isotropic
        return np.array([1.0, 1.0, 1.0]), np.array([1.0, 1.0, 1.0])

    if r < 0.5:
        # clinical (low-res in one dimension)
        resolution = np.array([1.0, 1.0, 1.0])
        thickness = np.array([1.0, 1.0, 1.0])
        axis = np.random.randint(3)
        resolution[axis] = 2.5 + 6 * np.random.rand()
        thickness[axis] = np.min([resolution[axis], 4.0 + 2.0 * np.random.rand()])
        return resolution, thickness

    if r < 0.75:
        # low-field: stock sequences (always axial)
        resolution = np.array([1.3, 1.3, 4.8]) + 0.4 * np.random.rand(3)
        return resolution, resolution.copy()

    # low-field: isotropic-ish (also good for scouts)
    resolution = 2.0 + 3.0 * np.random.rand(3)
    return resolution, resolution.copy()
58
+
59
+
60
+ #####################################
61
+ ############ Utility Func ###########
62
+ #####################################
63
+
64
+
65
def binarize(p, thres):
    """Binarize a map at a fraction of its maximum value.

    Parameters
    ----------
    p : tensor
        Input (e.g. probability) map, any shape.
    thres : float
        Relative threshold; the cut-off is `thres * p.max()`.

    Returns
    -------
    tensor
        Same shape and dtype as `p`, with 1 where `p >= thres * p.max()`
        and 0 elsewhere.
    """
    # TODO: what is the optimal thresholding strategy?
    cutoff = thres * p.max()
    # Single comparison instead of two masked in-place assignments on a
    # clone; also avoids shadowing the builtin `bin`. (NaNs map to 0.)
    return (p >= cutoff).to(p.dtype)
73
+
74
def make_gaussian_kernel(sigma, device):
    """Normalized 1-D Gaussian kernel with radius ceil(3 * sigma).

    Returns a float tensor of length 2*ceil(3*sigma) + 1 on `device`
    whose entries sum to 1.
    """
    radius = int(np.ceil(3 * sigma))
    coords = torch.linspace(-radius, radius, 2 * radius + 1,
                            dtype=torch.float, device=device)
    weights = torch.exp((-(coords / sigma) ** 2 / 2))
    return weights / weights.sum()
82
+
83
def gaussian_blur_3d(input, stds, device):
    """Separable 3-D Gaussian blur of a (X, Y, Z) volume.

    One 1-D convolution per axis whose std is > 0; an axis with std 0 is
    left untouched. 'Same' padding keeps the spatial shape. Returns the
    blurred volume with singleton batch/channel dims squeezed out.
    """
    out = input[None, None]
    for axis, std in enumerate(stds):
        if std > 0:
            kernel = make_gaussian_kernel(std, device=device)
            # orient the 1-D kernel along the current spatial axis
            kshape = [1, 1, 1, 1, 1]
            kshape[2 + axis] = len(kernel)
            pad = [0, 0, 0]
            pad[axis] = len(kernel) // 2
            out = conv3d(out, kernel.reshape(kshape), stride=1,
                         padding=tuple(pad))
    return torch.squeeze(out)
95
+
96
+
97
+
98
+ #####################################
99
+ ######### Deformation Func ##########
100
+ #####################################
101
+
102
def make_affine_matrix(rot, sh, s):
    """Build a 3x3 affine matrix from rotations, shears and scalings.

    A = SHx @ SHy @ SHz @ Rx @ Ry @ Rz, followed by row-wise anisotropic
    scaling by `s`.

    Parameters
    ----------
    rot : sequence of 3 floats
        Rotation angles (radians) about x, y, z.
    sh : sequence of 3 floats
        Shear coefficients.
    s : sequence of 3 floats
        Per-axis scaling factors.
    """
    cx, sx = np.cos(rot[0]), np.sin(rot[0])
    cy, sy = np.cos(rot[1]), np.sin(rot[1])
    cz, sz = np.cos(rot[2]), np.sin(rot[2])
    Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])

    SHx = np.array([[1, 0, 0], [sh[1], 1, 0], [sh[2], 0, 1]])
    SHy = np.array([[1, sh[0], 0], [0, 1, 0], [0, sh[2], 1]])
    SHz = np.array([[1, 0, sh[0]], [0, 1, sh[1]], [0, 0, 1]])

    A = SHx @ SHy @ SHz @ Rx @ Ry @ Rz
    # row-wise anisotropic scaling
    A *= np.asarray(s)[:, None]
    return A
117
+
118
+
119
def fast_3D_interp_torch(X, II, JJ, KK, mode='linear', default_value_linear=0.0):
    """Sample volume X at (possibly fractional) voxel coordinates (II, JJ, KK).

    Args:
        X: source volume of shape (H, W, D) or (H, W, D, C).
        II, JJ, KK: coordinate tensors of identical shape giving sampling
            positions along each axis; if II is None, X is returned unchanged.
        mode: 'nearest' or 'linear' (trilinear).
        default_value_linear: fill value for out-of-bounds samples in
            'linear' mode.

    Returns:
        Tensor shaped like II (with a trailing channel axis when C > 1).
    """
    if II is None:
        return X

    if mode=='nearest':
        # Round to the closest voxel and clamp indices into the valid range.
        IIr = torch.round(II).long()
        JJr = torch.round(JJ).long()
        KKr = torch.round(KK).long()
        IIr[IIr < 0] = 0
        JJr[JJr < 0] = 0
        KKr[KKr < 0] = 0
        IIr[IIr > (X.shape[0] - 1)] = (X.shape[0] - 1)
        JJr[JJr > (X.shape[1] - 1)] = (X.shape[1] - 1)
        KKr[KKr > (X.shape[2] - 1)] = (X.shape[2] - 1)
        if len(X.shape)==3:
            X = X[..., None]
        Y = X[IIr, JJr, KKr]
        if Y.shape[3] == 1:
            Y = Y[:, :, :, 0]

    elif mode=='linear':
        # In-bounds mask. NOTE(review): the lower bound is strict (II > 0),
        # so samples in the very first voxel plane are treated as
        # out-of-bounds — confirm this is intended.
        ok = (II>0) & (JJ>0) & (KK>0) & (II<=X.shape[0]-1) & (JJ<=X.shape[1]-1) & (KK<=X.shape[2]-1)

        IIv = II[ok]
        JJv = JJ[ok]
        KKv = KK[ok]

        # Floor/ceil indices and the trilinear weights for each axis.
        fx = torch.floor(IIv).long()
        cx = fx + 1
        cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
        wcx = (IIv - fx)[..., None]
        wfx = 1 - wcx

        fy = torch.floor(JJv).long()
        cy = fy + 1
        cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
        wcy = (JJv - fy)[..., None]
        wfy = 1 - wcy

        fz = torch.floor(KKv).long()
        cz = fz + 1
        cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
        wcz = (KKv - fz)[..., None]
        wfz = 1 - wcz

        if len(X.shape)==3:
            X = X[..., None]

        # Gather the 8 corners surrounding every in-bounds sample.
        c000 = X[fx, fy, fz]
        c100 = X[cx, fy, fz]
        c010 = X[fx, cy, fz]
        c110 = X[cx, cy, fz]
        c001 = X[fx, fy, cz]
        c101 = X[cx, fy, cz]
        c011 = X[fx, cy, cz]
        c111 = X[cx, cy, cz]

        # Reduce along x, then y, then z.
        c00 = c000 * wfx + c100 * wcx
        c01 = c001 * wfx + c101 * wcx
        c10 = c010 * wfx + c110 * wcx
        c11 = c011 * wfx + c111 * wcx

        c0 = c00 * wfy + c10 * wcy
        c1 = c01 * wfy + c11 * wcy

        c = c0 * wfz + c1 * wcz

        # Scatter interpolated values back; out-of-bounds voxels get the
        # default fill value.
        Y = torch.zeros([*II.shape, X.shape[3]], device=X.device)
        Y[ok] = c.float()
        Y[~ok] = default_value_linear

        if Y.shape[-1]==1:
            Y = Y[...,0]
    else:
        raise Exception('mode must be linear or nearest')

    return Y
197
+
198
+
199
+
200
def myzoom_torch(X, factor, aff=None):
    """Resize a volume by per-axis zoom factors via separable linear interpolation.

    Args:
        X: tensor of shape (H, W, D) or (H, W, D, C).
        factor: array-like of 3 zoom factors (>1 upsamples, <1 downsamples).
        aff: optional 4x4 voxel-to-world affine; when given, the affine of
            the resized volume is also returned.

    Returns:
        The resized volume (and the updated affine when aff is not None).
    """

    if len(X.shape)==3:
        X = X[..., None]

    # Half-voxel offset so the new grid stays centered on the old one.
    delta = (1.0 - factor) / (2.0 * factor)
    newsize = np.round(X.shape[:-1] * factor).astype(int)

    # Per-axis sampling coordinates in source-voxel units.
    vx = torch.arange(delta[0], delta[0] + newsize[0] / factor[0], 1 / factor[0], dtype=torch.float, device=X.device)[:newsize[0]]
    vy = torch.arange(delta[1], delta[1] + newsize[1] / factor[1], 1 / factor[1], dtype=torch.float, device=X.device)[:newsize[1]]
    vz = torch.arange(delta[2], delta[2] + newsize[2] / factor[2], 1 / factor[2], dtype=torch.float, device=X.device)[:newsize[2]]

    # Clamp coordinates into the valid index range.
    vx[vx < 0] = 0
    vy[vy < 0] = 0
    vz[vz < 0] = 0
    vx[vx > (X.shape[0]-1)] = (X.shape[0]-1)
    vy[vy > (X.shape[1] - 1)] = (X.shape[1] - 1)
    vz[vz > (X.shape[2] - 1)] = (X.shape[2] - 1)

    # Floor/ceil indices and linear weights for each axis.
    fx = torch.floor(vx).int()
    cx = fx + 1
    cx[cx > (X.shape[0]-1)] = (X.shape[0]-1)
    wcx = (vx - fx)
    wfx = 1 - wcx

    fy = torch.floor(vy).int()
    cy = fy + 1
    cy[cy > (X.shape[1]-1)] = (X.shape[1]-1)
    wcy = (vy - fy)
    wfy = 1 - wcy

    fz = torch.floor(vz).int()
    cz = fz + 1
    cz[cz > (X.shape[2]-1)] = (X.shape[2]-1)
    wcz = (vz - fz)
    wfz = 1 - wcz

    Y = torch.zeros([newsize[0], newsize[1], newsize[2], X.shape[3]], dtype=torch.float, device=X.device)

    # Separable interpolation: reduce along x, then y, then z.
    tmp1 = torch.zeros([newsize[0], X.shape[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device)
    for i in range(newsize[0]):
        tmp1[i, :, :] = wfx[i] * X[fx[i], :, :] + wcx[i] * X[cx[i], :, :]
    tmp2 = torch.zeros([newsize[0], newsize[1], X.shape[2], X.shape[3]], dtype=torch.float, device=X.device)
    for j in range(newsize[1]):
        tmp2[:, j, :] = wfy[j] * tmp1[:, fy[j], :] + wcy[j] * tmp1[:, cy[j], :]
    for k in range(newsize[2]):
        Y[:, :, k] = wfz[k] * tmp2[:, :, fz[k]] + wcz[k] * tmp2[:, :, cz[k]]

    if Y.shape[3] == 1:
        Y = Y[:,:,:, 0]

    if aff is not None:
        # Scale the voxel axes and shift the origin so the resized volume
        # stays aligned with the original in world coordinates.
        aff_new = aff.copy()
        aff_new[:-1] = aff_new[:-1] / factor
        aff_new[:-1, -1] = aff_new[:-1, -1] - aff[:-1, :-1] @ (0.5 - 0.5 / (factor * np.ones(3)))
        return Y, aff_new
    else:
        return Y
258
+
259
+
260
+
261
+
262
+ #####################################
263
+ ############ Reading Func ###########
264
+ #####################################
265
+
266
def read_image(file_name):
    """Load a NIfTI file; return the image object, its affine, and voxel sizes.

    NOTE(review): the voxel size is computed as sqrt(sum(|a|)) over each
    column of the affine's 3x3 part, not the usual Euclidean column norm
    sqrt(sum(a^2)) — confirm this is intended.
    """
    img = nib.load(file_name)
    aff = img.affine
    res = np.sqrt(np.abs(aff[:-1, :-1]).sum(axis=0))
    return img, aff, res
271
+
272
def deform_image(I, deform_dict, device, default_value_linear_mode=None, deform_mode = 'linear'):
    """Crop a volume to the deformation window and resample it on the grid.

    Args:
        I: nibabel image, torch tensor, or None (returned unchanged).
        deform_dict: dict whose 'grid' entry is
            [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] — sampling coordinates
            plus crop bounds.
        device: torch device for the tensors.
        default_value_linear_mode: 'max' fills out-of-bounds samples with the
            volume maximum; None fills with 0.
        deform_mode: 'linear' or 'nearest' interpolation.

    Returns:
        The deformed volume (or I itself when it is None).

    Raises:
        ValueError: if default_value_linear_mode is unsupported.
    """
    if I is None:
        return I

    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    if not isinstance(I, torch.Tensor):
        # nibabel image: pull the voxel data for the crop window.
        I = torch.squeeze(torch.tensor(I.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=device))
    else:
        # BUG FIX: tensors have no .astype(), and torch.squeeze() takes no
        # dtype/device kwargs — crop, then convert with Tensor.to().
        I = torch.squeeze(I[x1:x2, y1:y2, z1:z2].to(dtype=torch.float, device=device))
    I = torch.nan_to_num(I)

    if default_value_linear_mode is not None:
        if default_value_linear_mode == 'max':
            default_value_linear = torch.max(I)
        else:
            raise ValueError('Not support default_value_linear_mode:', default_value_linear_mode)
    else:
        default_value_linear = 0.
    Idef = fast_3D_interp_torch(I, xx2, yy2, zz2, deform_mode, default_value_linear)

    return Idef
294
+
295
+
296
def read_and_deform(file_name, dtype, deform_dict, device, mask, default_value_linear_mode=None, deform_mode = 'linear', mean = 0., scale = 1.):
    """Load a NIfTI volume, crop, normalize, optionally mask, and deform it.

    Args:
        file_name: path to the volume (a '.gz' suffix is tried as fallback).
        dtype: torch dtype for the loaded data.
        deform_dict: dict whose 'grid' entry is
            [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2].
        device: torch device for the tensors.
        mask: optional tensor; voxels where mask == 0 are zeroed.
        default_value_linear_mode: 'max' fills out-of-bounds samples with the
            volume maximum; None fills with 0.
        deform_mode: 'linear' or 'nearest' interpolation.
        mean, scale: intensity normalization, I = (I - mean) / scale.

    Returns:
        (deformed volume, voxel resolution derived from the image affine)

    Raises:
        ValueError: if default_value_linear_mode is unsupported.
    """
    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    try:
        Iimg = nib.load(file_name)
    except Exception:
        # Some volumes are stored compressed; retry with the .gz suffix.
        # (Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.)
        Iimg = nib.load(file_name + '.gz')
    res = np.sqrt(np.sum(abs(Iimg.affine[:-1, :-1]), axis=0))
    I = torch.squeeze(torch.tensor(Iimg.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=dtype, device=device))
    I = torch.nan_to_num(I)

    I -= mean
    I /= scale

    if mask is not None:
        I[mask == 0] = 0

    if default_value_linear_mode is not None:
        if default_value_linear_mode == 'max':
            default_value_linear = torch.max(I)
        else:
            raise ValueError('Not support default_value_linear_mode:', default_value_linear_mode)
    else:
        default_value_linear = 0.
    Idef = fast_3D_interp_torch(I, xx2, yy2, zz2, deform_mode, default_value_linear)
    return Idef, res
322
+
323
+
324
def read_and_deform_image(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
    """Read and deform an intensity image, min-max normalized to [0, 1].

    When a companion defacing mask ('<name>.defacingmask.nii') exists, it is
    loaded, normalized to [0, 1], and returned under '<task_name>_DM'.
    """
    Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask)
    Idef -= torch.min(Idef)
    Idef /= torch.max(Idef)
    if setups['flip']:
        Idef = torch.flip(Idef, [0])
    update_dict = {task_name: Idef[None]}

    if os.path.isfile(file_name[:-4] + '.defacingmask.nii'):
        Idef_DM, _ = read_and_deform(file_name[:-4] + '.defacingmask.nii', torch.float, deform_dict, device, mask)
        Idef_DM = torch.clamp(Idef_DM, min = 0.)
        Idef_DM /= torch.max(Idef_DM)
        if setups['flip']:
            # BUG FIX: the flipped mask was previously assigned to Idef,
            # overwriting the image and leaving the stored mask unflipped.
            Idef_DM = torch.flip(Idef_DM, [0])
        update_dict.update({task_name + '_DM': Idef_DM[None]})
    return update_dict
344
+
345
def read_and_deform_CT(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
    """Read and deform a CT volume (intensities divided by 1000).

    No intensity clamping is applied (for inference/GT). When a companion
    defacing mask ('<name>.defacingmask.nii') exists, it is loaded,
    normalized, and returned under '<task_name>_DM'.
    """
    Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask, scale = 1000)
    if setups['flip']:
        Idef = torch.flip(Idef, [0])
    update_dict = {'CT': Idef[None]}

    if os.path.isfile(file_name[:-4] + '.defacingmask.nii'):
        Idef_DM, _ = read_and_deform(file_name[:-4] + '.defacingmask.nii', torch.float, deform_dict, device, mask)
        Idef_DM = torch.clamp(Idef_DM, min = 0.)
        Idef_DM /= torch.max(Idef_DM)
        if setups['flip']:
            # BUG FIX: the flipped mask was previously assigned to Idef,
            # clobbering the CT volume and leaving the stored mask unflipped.
            Idef_DM = torch.flip(Idef_DM, [0])
        update_dict.update({task_name + '_DM': Idef_DM[None]})
    return update_dict
365
+
366
def read_and_deform_distance(exist_keys, task_name, file_names, setups, deform_dict, device, mask, cfg, **kwargs):
    """Read the four surface distance maps, deform them, and stack.

    file_names is [lp, lw, rp, rw] — left/right maps (presumably pial 'p'
    and white 'w' surface distances — confirm with the data preparation).
    Maps are stored with an intensity offset/scale (mean 128, scale 20);
    out-of-bounds voxels are filled with the map maximum. When mask is given
    (left-hemisphere-only mode), only the left maps are used.
    """
    [lp_dist_map, lw_dist_map, rp_dist_map, rw_dist_map] = file_names


    lp, _ = read_and_deform(lp_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)
    lw, _ = read_and_deform(lw_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)

    if mask is not None: # left_hemis_only
        Idef = torch.stack([lp, lw], dim = 0)
    else:
        rp, _ = read_and_deform(rp_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)
        rw, _ = read_and_deform(rw_dist_map, torch.float, deform_dict, device, mask, default_value_linear_mode = 'max', mean = 128., scale = 20)

        if setups['flip']:
            # A left-right flip swaps hemispheres, so exchange the left and
            # right maps after flipping each one spatially.
            aux = torch.flip(lp, [0])
            lp = torch.flip(rp, [0])
            rp = aux
            aux = torch.flip(lw, [0])
            lw = torch.flip(rw, [0])
            rw = aux

        Idef = torch.stack([lp, lw, rp, rw], dim = 0)

    # Undo the deformation's global distance scaling, then clamp to the
    # configured maximum surface distance.
    Idef /= deform_dict['scaling_factor_distances']
    Idef = torch.clamp(Idef, min=-cfg.max_surf_distance, max=cfg.max_surf_distance)

    return {'distance': Idef}
393
+
394
def read_and_deform_segmentation(exist_keys, task_name, file_name, setups, deform_dict, device, mask, cfg, onehotmatrix, lut, vflip, **kwargs):
    """Read a label volume and return its deformed one-hot encoding.

    Args:
        file_name: path to the segmentation NIfTI.
        deform_dict: holds the sampling grid and crop bounds under 'grid'.
        mask: optional tensor; labels where mask == 0 are set to 0.
        onehotmatrix: row-lookup matrix used to one-hot encode label indices.
        lut: maps raw label values to contiguous channel indices.
        vflip: channel permutation applied after a left-right flip
            (presumably swaps left/right label channels — confirm).

    Returns:
        {'segmentation': one-hot tensor with channels first}.
    """
    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    Simg = nib.load(file_name)
    S = torch.squeeze(torch.tensor(Simg.get_fdata()[x1:x2, y1:y2, z1:z2].astype(int), dtype=torch.int, device=device))

    if mask is not None:
        S[mask == 0] = 0

    # Hard labels are deformed with nearest-neighbour interpolation.
    Sdef = fast_3D_interp_torch(S, xx2, yy2, zz2, 'nearest')
    if cfg.generator.deform_one_hots:
        # Encode first, then deform the (soft) one-hot channels linearly.
        Sonehot = onehotmatrix[lut[S.long()]]
        Sdef_OneHot = fast_3D_interp_torch(Sonehot, xx2, yy2, zz2)
    else:
        # Deform the hard labels, then encode.
        Sdef_OneHot = onehotmatrix[lut[Sdef.long()]]

    if setups['flip']:
        # Flip spatially and permute channels with vflip.
        Sdef_OneHot = torch.flip(Sdef_OneHot, [0])[:, :, :, vflip]

    # prepare for input: channels first
    Sdef_OneHot = Sdef_OneHot.permute([3, 0, 1, 2])

    update_dict = {'segmentation': Sdef_OneHot}

    return update_dict
425
+
426
+
427
+
428
def read_and_deform_pathology(exist_keys, task_name, file_name, setups, deform_dict, device, mask = None,
                augment = False, pde_func = None, t = None,
                shape_gen_args = None, thres = 0., **kwargs):
    """Produce a (binary, probability) pathology pair on the deformed grid.

    file_name may be None (returns all-zero maps), 'random_shape' (a
    synthetic shape is generated), or a path to an existing probability map.
    When `augment` is set, the probability map is advected with
    augment_pathology before binarization. Near-empty masks (mean below
    shape_gen_args.pathol_tol) are replaced by all-zero maps.
    """
    # NOTE does not support left_hemis for now

    [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

    if file_name is None:
        return {'pathology': torch.zeros(xx2.shape)[None].to(device), 'pathology_prob': torch.zeros(xx2.shape)[None].to(device)}

    if file_name == 'random_shape': # generate random shape
        percentile = np.random.uniform(shape_gen_args.mask_percentile_min, shape_gen_args.mask_percentile_max)
        _, Pdef = generate_shape_3d(xx2.shape, shape_gen_args.perlin_res, percentile, device)
    else: # read from existing shape
        Pdef, _ = read_and_deform(file_name, torch.float, deform_dict, device)

    if augment:
        Pdef = augment_pathology(Pdef, pde_func, t, shape_gen_args, device)

    #if setups['flip']: # flipping should happen after P has been encoded
    #    Pdef = torch.flip(Pdef, [0])

    P = binarize(Pdef, thres)
    # Treat a near-empty mask as "no pathology".
    if P.mean() <= shape_gen_args.pathol_tol:
        return {'pathology': torch.zeros(xx2.shape)[None].to(device), 'pathology_prob': torch.zeros(xx2.shape)[None].to(device)}

    return {'pathology': P[None], 'pathology_prob': Pdef[None]}
456
+
457
+
458
def read_and_deform_registration(exist_keys, task_name, file_names, setups, deform_dict, device, mask, **kwargs):
    """Read the three MNI registration coordinate maps, deform, and stack as (3, ...)."""
    components = []
    for path in file_names:
        comp, _ = read_and_deform(path, torch.float, deform_dict, device, mask, scale = 10000)
        components.append(comp)
    regx, regy, regz = components

    if setups['flip']:
        regx = -torch.flip(regx, [0]) # NOTE: careful with switching sign
        regy = torch.flip(regy, [0])
        regz = torch.flip(regz, [0])

    return {'registration': torch.stack([regx, regy, regz], dim = 0)}
472
+
473
def read_and_deform_bias_field(exist_keys, task_name, file_name, setups, deform_dict, device, mask, **kwargs):
    """Read and deform a bias-field volume; returns {'bias_field': (1, ...)}.

    BUG FIX: `mask` and `device` were previously passed to read_and_deform in
    swapped positions (its signature is
    (file_name, dtype, deform_dict, device, mask, ...)).
    """
    Idef, _ = read_and_deform(file_name, torch.float, deform_dict, device, mask)
    if setups['flip']:
        Idef = torch.flip(Idef, [0])
    return {'bias_field': Idef[None]}
478
+
479
def read_and_deform_surface(exist_keys, task_name, file_name, setups, deform_dict, device, mask, size):
    """Load white/pial surface meshes from the companion .mat file and deform them.

    Vertices are mapped through the inverse affine A and the backward
    displacement field Fneg (both from deform_dict), working in coordinates
    centered at c2. A left-right flip mirrors the x coordinate and swaps the
    two hemispheres. NOTE: does not support left_hemis for now.

    Returns:
        dict with vertices (V*) and faces (F*) for left (l) / right (r)
        white (w) and pial (p) surfaces.
    """
    Fneg, A, c2 = deform_dict['Fneg'], deform_dict['A'], deform_dict['c2']

    mat = loadmat(file_name.split('.nii')[0] + '.mat')

    def _load(key, dtype):
        # Convert one .mat array to a tensor on the target device.
        return torch.tensor(mat[key], dtype=dtype, device=device)

    Vlw, Flw = _load('Vlw', torch.float), _load('Flw', torch.int)
    Vrw, Frw = _load('Vrw', torch.float), _load('Frw', torch.int)
    Vlp, Flp = _load('Vlp', torch.float), _load('Flp', torch.int)
    Vrp, Frp = _load('Vrp', torch.float), _load('Frp', torch.int)

    Ainv = torch.inverse(A)

    def _warp(V):
        # Inverse affine in c2-centered coordinates, then add the nonlinear
        # displacement sampled at the affinely-mapped positions.
        V = V - c2[None, :]
        V = V @ torch.transpose(Ainv, 0, 1)
        V = V + fast_3D_interp_torch(Fneg, V[:, 0] + c2[0], V[:, 1] + c2[1], V[:, 2] + c2[2])
        return V + c2[None, :]

    Vlw, Vrw, Vlp, Vrp = _warp(Vlw), _warp(Vrw), _warp(Vlp), _warp(Vrp)

    if setups['flip']:
        # Mirror x, then swap the hemispheres' vertices and faces.
        Vlw[:, 0] = size[0] - 1 - Vlw[:, 0]
        Vrw[:, 0] = size[0] - 1 - Vrw[:, 0]
        Vlp[:, 0] = size[0] - 1 - Vlp[:, 0]
        Vrp[:, 0] = size[0] - 1 - Vrp[:, 0]
        Vlw, Vrw = Vrw, Vlw
        Vlp, Vrp = Vrp, Vlp
        Flw, Frw = Frw, Flw
        Flp, Frp = Frp, Flp

    # (Removed leftover debug prints of the vertex/face array shapes.)
    return {'Vlw': Vlw, 'Flw': Flw, 'Vrw': Vrw, 'Frw': Frw, 'Vlp': Vlp, 'Flp': Flp, 'Vrp': Vrp, 'Frp': Frp}
535
+
536
+
537
+ #####################################
538
+ ######### Pathology Shape #########
539
+ #####################################
540
+
541
+
542
def augment_pathology(Pprob, pde_func, t, shape_gen_args, device):
    """Randomly advect a pathology probability map with a synthetic velocity field.

    Integrates pde_func over a random number of time points (1 to
    shape_gen_args.max_nt) using odeint; a fresh random velocity field is
    generated per call. Returns the input unchanged when only one time point
    is drawn.
    """
    Pprob = torch.squeeze(Pprob)

    nt = np.random.randint(1, shape_gen_args.max_nt+1)
    if nt <= 1:
        # A single time point means no integration step — nothing to do.
        return Pprob

    # Fresh random velocity field for this augmentation.
    pde_func.V_dict = generate_velocity_3d(Pprob.shape, shape_gen_args.perlin_res, shape_gen_args.V_multiplier, device)

    # Integrate and keep the last time point (shape: (last_t, n_batch=1, s, r, c)).
    Pprob = odeint(pde_func, Pprob[None], t[:nt],
                    shape_gen_args.dt,
                    method = shape_gen_args.integ_method)[-1, 0]

    return Pprob
561
+
562
+
563
+ #####################################
564
+ ######### Augmentation Func #########
565
+ #####################################
566
+
567
+
568
def add_gamma_transform(I, aux_dict, cfg, device, **kwargs):
    """Apply a random gamma transform to I around a fixed pivot of 300."""
    log_gamma = cfg.gamma_std * np.random.randn(1)[0]
    gamma = torch.tensor(np.exp(log_gamma), dtype=float, device=device)
    transformed = 300.0 * (I / 300.0) ** gamma
    #aux_dict.update({'gamma': gamma}) # uncomment if you want to save gamma for later use
    return transformed, aux_dict
573
+
574
def add_bias_field(I, aux_dict, cfg, input_mode, setups, size, device, **kwargs):
    """Multiply I by a smooth random bias field (skipped for CT input).

    A low-resolution log-field is sampled, upsampled to the image size with
    myzoom_torch, and exponentiated. The log-field and the biased image are
    stored in aux_dict under 'BFlog' and 'high_res'.
    """
    if input_mode == 'CT':
        # CT input gets no bias field; still record the high-res image.
        aux_dict.update({'high_res': I})
        return I, aux_dict

    # Random spatial scale of the low-resolution field.
    bf_scale = cfg.bf_scale_min + np.random.rand(1) * (cfg.bf_scale_max - cfg.bf_scale_min)
    size_BF_small = np.round(bf_scale * np.array(size)).astype(int).tolist()
    if setups['photo_mode']:
        # Photo mode: the field's resolution along axis 1 follows the slice spacing.
        size_BF_small[1] = np.round(size[1]/setups['spac']).astype(int)
    BFsmall = torch.tensor(cfg.bf_std_min + (cfg.bf_std_max - cfg.bf_std_min) * np.random.rand(1), dtype=torch.float, device=device) * \
                torch.randn(size_BF_small, dtype=torch.float, device=device)
    BFlog = myzoom_torch(BFsmall, np.array(size) / size_BF_small)
    BF = torch.exp(BFlog)
    I_bf = I * BF
    aux_dict.update({'BFlog': BFlog, 'high_res': I_bf})
    return I_bf, aux_dict
590
+
591
def resample_resolution(I, aux_dict, setups, res, size, device, **kwargs):
    """Simulate acquisition at a coarser resolution: blur, then resample.

    Blur stds scale with thickness/res (with random jitter); axes whose
    slice thickness does not exceed the native resolution are not blurred.
    The per-axis resampling factors are stored in aux_dict['factors'].
    """
    stds = (0.85 + 0.3 * np.random.rand()) * np.log(5) /np.pi * setups['thickness'] / res
    stds[setups['thickness']<=res] = 0.0 # no blur if thickness is equal to the resolution of the training data
    I_blur = gaussian_blur_3d(I, stds, device)
    new_size = (np.array(size) * res / setups['resolution']).astype(int)

    # Build a centered sampling grid for the target size.
    factors = np.array(new_size) / np.array(size)
    delta = (1.0 - factors) / (2.0 * factors)
    vx = np.arange(delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0])[:new_size[0]]
    vy = np.arange(delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1])[:new_size[1]]
    vz = np.arange(delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2])[:new_size[2]]
    II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing='ij')
    II = torch.tensor(II, dtype=torch.float, device=device)
    JJ = torch.tensor(JJ, dtype=torch.float, device=device)
    KK = torch.tensor(KK, dtype=torch.float, device=device)

    I_small = fast_3D_interp_torch(I_blur, II, JJ, KK)
    aux_dict.update({'factors': factors})
    return I_small, aux_dict
610
+
611
+
612
def resample_resolution_photo(I, aux_dict, setups, res, size, device, **kwargs):
    """Photo-mode resolution resampling.

    The body was a verbatim copy of resample_resolution; delegate to it so
    the logic is maintained in one place. Behavior is unchanged.
    """
    return resample_resolution(I, aux_dict, setups, res, size, device, **kwargs)
631
+
632
+
633
def add_noise(I, aux_dict, cfg, device, **kwargs):
    """Add Gaussian noise with a std drawn from [noise_std_min, noise_std_max];
    negative intensities are clipped to zero."""
    sampled = cfg.noise_std_min + (cfg.noise_std_max - cfg.noise_std_min) * np.random.rand(1)
    noise_std = torch.tensor(sampled, dtype=torch.float, device=device)
    noisy = I + noise_std * torch.randn(I.shape, dtype=torch.float, device=device)
    noisy = torch.clamp(noisy, min=0)
    #aux_dict.update({'noise_std': noise_std}) # uncomment if you want to save noise_std for later use
    return noisy, aux_dict
639
+
640
+
641
+ #####################################
642
+ #####################################
643
+
644
+
645
# Map SynthSeg right-hemisphere labels to their left counterparts
# (used for contrast synthesis).
right_to_left_dict = {
    41: 2,
    42: 3,
    43: 4,
    44: 5,
    46: 7,
    47: 8,
    49: 10,
    50: 11,
    51: 12,
    52: 13,
    53: 17,
    54: 18,
    58: 26,
    60: 28
}

# CT brightness groups, based on merged left & right SynthSeg labels.
ct_brightness_group = {
    'darker': [4, 5, 14, 15, 24, 31, 72], # ventricles, CSF
    'dark': [2, 7, 16, 77, 30], # white matter
    'bright': [3, 8, 17, 18, 28, 10, 11, 12, 13, 26], # grey matter (cortex, hippocampus, amygdala, ventral DC), thalamus, ganglia (nucleus (putamen, pallidus, accumbens), caudate)
    'brighter': [], # skull, pineal gland, choroid plexus
}
README.md CHANGED
@@ -1,3 +1,91 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## <p align="center">[A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging](https://arxiv.org/abs/2509.00549)</p>
3
+
4
+ **<p align="center">Peirong Liu<sup>1,2</sup>, Oula Puonti<sup>2</sup>, Xiaoling Hu<sup>2</sup>, Karthik Gopinath<sup>2</sup>, Annabel Sorby-Adams<sup>2</sup>, Daniel C. Alexander<sup>3</sup>, Juan Eugenio Iglesias<sup>2,3,4</sup></p>**
5
+
6
+
7
+ <p align="center">
8
+ <sup>1</sup>Johns Hopkins University<br />
9
+ <sup>2</sup>Harvard Medical School and Massachusetts General Hospital<br />
10
+ <sup>3</sup>University College London <br />
11
+ <sup>4</sup>Massachusetts Institute of Technology
12
+ </p>
13
+
14
+ <p align="center">
15
+ <img src="./assets/overview.png" alt="drawing", width="650"/>
16
+ </p>
17
+
18
+
19
+ This is the official repository for our preprint: A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging [[arXiv]](https://arxiv.org/abs/2509.00549)<br />
20
+ More detailed and organized instructions are coming soon...
21
+
22
+ ## Environment
23
+ Training and evaluation environment: Python 3.11.4, PyTorch 2.0.1, CUDA 12.2. Run the following command to install required packages.
24
+ ```
25
+ conda create -n pre python=3.11
26
+ conda activate pre
27
+
28
+ git clone https://github.com/jhuldr/BrainFM
29
+ cd /path/to/brainfm
30
+ pip install -r requirements.txt
31
+ ```
32
+
33
+
34
+ ## Generator
35
+ ```
36
+ cd scripts
37
+ python demo_generator.py
38
+ ```
39
+
40
+ ### Generator setups
41
+ Setups are in cfgs/generator, default setups are in default.yaml. A customized setup example can be found in train/brain_id.yaml, where several Brain-ID-specific setups are added. During Config reading/implementation, customized yaml will overwrite default.yaml if they have the same keys.
42
+
43
+ dataset_setups: information for all datasets, in Generator/constants.py<br>
44
+ augmentation_funcs: augmentation functions and steps, in Generator/constants.py<br>
45
+ processing_funcs: image processing functions for each modality/task, in Generator/constants.py<br>
46
+
47
+ dataset_names: dataset name list, paths setups in Generator/constants.py<br>
48
+ mix_synth_prob: if the input mode is synthesizing, then probability for blending synth with real images<br>
49
+ dataset_option: generator types, could be BaseGen or customized generator<br>
50
+ task: switch on/off individual training tasks
51
+
52
+ ### Base generator module
53
+ ```
54
+ cd Generator
55
+ python datasets.py
56
+ ```
57
+ The dataset paths setups are in constants.py. In datasets.py, different datasets been used are fomulated as a list of dataset names.
58
+
59
+ A customized data generator module example can be found in datasets.py -- BrainIDGen.
60
+
61
+
62
+ Refer to "__getitem__" function. Specifically, it includes: <br>
63
+ (1) read original input: could be either generation labels or real images;<br>
64
+ (2) generate augmentation setups and deformation fields; <br>
65
+ (3) read target(s) according to the assigned tasks -- here I seperate the processing functions for each item/modality, in case we want different processing steps for them; <br>
66
+ (4) augment input sample: either synthesized or real image input.
67
+
68
+
69
+
70
+ (Some of the functions are left blank for now.)
71
+
72
+
73
+
74
+ ## Trainer
75
+ ```
76
+ cd scripts
77
+ python train.py
78
+ ```
79
+
80
+ ## Downloads
81
+ The pre-trained model weight is available on [OneDrive](https://livejohnshopkins-my.sharepoint.com/:u:/g/personal/pliu53_jh_edu/EZ_BJ7K6pMJEj9hZ8SA51GYBxH_Nan4fA3a-s4udwvVRog?e=nwZ7JC).
82
+
83
+
84
+ ## Citation
85
+ ```bibtex
86
+ @article{Liu_2025_BrainFM,
87
+ author = {Liu, Peirong and Puonti, Oula and Hu, Xiaoling and Gopinath, Karthik and Sorby-Adams, Annabel and Alexander, Daniel C. and Iglesias, Juan E.},
88
+ title = {A Modality-agnostic Multi-task Foundation Model for Human Brain Imaging},
89
+ booktitle = {arXiv preprint arXiv:2509.00549},
90
+ year = {2025},
91
+ }
ShapeID/DiffEqs/FD.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ *finite_difference.py* is the main package to compute finite differences in
3
+ 1D, 2D, and 3D on numpy arrays (class FD_np) and pytorch tensors (class FD_torch).
4
+ The package supports first and second order derivatives and Neumann and linear extrapolation
5
+ boundary conditions (though the latter have not been tested extensively yet).
6
+ """
7
+ from __future__ import absolute_import
8
+
9
+ # from builtins import object
10
+ from abc import ABCMeta, abstractmethod
11
+
12
+ import torch
13
+ from torch.autograd import Variable
14
+ import numpy as np
15
+ from future.utils import with_metaclass
16
+
17
+ class FD(with_metaclass(ABCMeta, object)):
18
+ """
19
+ *FD* is the abstract class for finite differences. It includes most of the actual finite difference code,
20
+ but requires the definition (in a derived class) of the methods *get_dimension*, *create_zero_array*, and *get_size_of_array*.
21
+ In this way the numpy and pytorch versions can easily be derived. All the method expect BxXxYxZ format (i.e., they process a batch at a time)
22
+ """
23
+
24
+ def __init__(self, spacing, bcNeumannZero=True):
25
+ """
26
+ Constructor
27
+ :param spacing: 1D numpy array defining the spatial spacing, e.g., [0.1,0.1,0.1] for a 3D image
28
+ :param bcNeumannZero: Defines the boundary condition. If set to *True* (default) zero Neumann boundary conditions
29
+ are imposed. If set to *False* linear extrapolation is used (this is still experimental, but may be beneficial
30
+ for better boundary behavior)
31
+ """
32
+
33
+ self.dim = len(spacing) # In my code, data_spacing is a list # spacing.size
34
+ """spatial dimension"""
35
+ self.spacing = np.ones(self.dim)
36
+ """spacing"""
37
+ self.bcNeumannZero = bcNeumannZero # if false then linear interpolation
38
+ """should Neumann boundary conditions be used? (otherwise linear extrapolation)"""
39
+ if len(spacing) == 1: #spacing.size==1:
40
+ self.spacing[0] = spacing[0]
41
+ elif len(spacing) == 2: # spacing.size==2:
42
+ self.spacing[0] = spacing[0]
43
+ self.spacing[1] = spacing[1]
44
+ elif len(spacing) == 3: # spacing.size==3:
45
+ self.spacing[0] = spacing[0]
46
+ self.spacing[1] = spacing[1]
47
+ self.spacing[2] = spacing[2]
48
+ else:
49
+ print('Current dimension:', len(spacing))
50
+ raise ValueError('Finite differences are only supported in dimensions 1 to 3')
51
+
52
+ def dXb(self,I):
53
+ """
54
+ Backward difference in x direction:
55
+ :math:`\\frac{dI(i)}{dx}\\approx\\frac{I_i-I_{i-1}}{h_x}`
56
+ :param I: Input image
57
+ :return: Returns the first derivative in x direction using backward differences
58
+ """
59
+ return (I-self.xm(I))/self.spacing[0]
60
+
61
+ def dXf(self,I):
62
+ """
63
+ Forward difference in x direction:
64
+ :math:`\\frac{dI(i)}{dx}\\approx\\frac{I_{i+1}-I_{i}}{h_x}`
65
+
66
+ :param I: Input image
67
+ :return: Returns the first derivative in x direction using forward differences
68
+ """
69
+ return (self.xp(I)-I)/self.spacing[0]
70
+
71
+ def dXc(self,I):
72
+ """
73
+ Central difference in x direction:
74
+ :math:`\\frac{dI(i)}{dx}\\approx\\frac{I_{i+1}-I_{i-1}}{2h_x}`
75
+
76
+ :param I: Input image
77
+ :return: Returns the first derivative in x direction using central differences
78
+ """
79
+ return (self.xp(I)-self.xm(I))/(2*self.spacing[0])
80
+
81
+ def ddXc(self,I):
82
+ """
83
+ Second deriative in x direction
84
+
85
+ :param I: Input image
86
+ :return: Returns the second derivative in x direction
87
+ """
88
+ return (self.xp(I)-2*I+self.xm(I))/(self.spacing[0]**2)
89
+
90
+ def dYb(self,I):
91
+ """
92
+ Same as dXb, but for the y direction
93
+
94
+ :param I: Input image
95
+ :return: Returns the first derivative in y direction using backward differences
96
+ """
97
+ return (I-self.ym(I))/self.spacing[1]
98
+
99
+ def dYf(self,I):
100
+ """
101
+ Same as dXf, but for the y direction
102
+
103
+ :param I: Input image
104
+ :return: Returns the first derivative in y direction using forward differences
105
+ """
106
+ return (self.yp(I)-I)/self.spacing[1]
107
+
108
+ def dYc(self,I):
109
+ """
110
+ Same as dXc, but for the y direction
111
+
112
+ :param I: Input image
113
+ :return: Returns the first derivative in y direction using central differences
114
+ """
115
+ return (self.yp(I)-self.ym(I))/(2*self.spacing[1])
116
+
117
+ def ddYc(self,I):
118
+ """
119
+ Same as ddXc, but for the y direction
120
+
121
+ :param I: Input image
122
+ :return: Returns the second derivative in the y direction
123
+ """
124
+ return (self.yp(I)-2*I+self.ym(I))/(self.spacing[1]**2)
125
+
126
+ def dZb(self,I):
127
+ """
128
+ Same as dXb, but for the z direction
129
+
130
+ :param I: Input image
131
+ :return: Returns the first derivative in the z direction using backward differences
132
+ """
133
+ return (I - self.zm(I))/self.spacing[2]
134
+
135
+ def dZf(self, I):
136
+ """
137
+ Same as dXf, but for the z direction
138
+
139
+ :param I: Input image
140
+ :return: Returns the first derivative in the z direction using forward differences
141
+ """
142
+ return (self.zp(I)-I)/self.spacing[2]
143
+
144
+ def dZc(self, I):
145
+ """
146
+ Same as dXc, but for the z direction
147
+
148
+ :param I: Input image
149
+ :return: Returns the first derivative in the z direction using central differences
150
+ """
151
+ return (self.zp(I)-self.zm(I))/(2*self.spacing[2])
152
+
153
+ def ddZc(self,I):
154
+ """
155
+ Same as ddXc, but for the z direction
156
+
157
+ :param I: Input iamge
158
+ :return: Returns the second derivative in the z direction
159
+ """
160
+ return (self.zp(I)-2*I+self.zm(I))/(self.spacing[2]**2)
161
+
162
+ def lap(self, I):
163
+ """
164
+ Compute the Lapacian of an image
165
+ !!!!!!!!!!!
166
+ IMPORTANT:
167
+ ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
168
+ THIS IS FOR COMPUTATIONAL EFFICIENCY.
169
+
170
+ :param I: Input image [batch, X,Y,Z]
171
+ :return: Returns the Laplacian
172
+ """
173
+ ndim = self.getdimension(I)
174
+ if ndim == 1+1:
175
+ return self.ddXc(I)
176
+ elif ndim == 2+1:
177
+ return (self.ddXc(I) + self.ddYc(I))
178
+ elif ndim == 3+1:
179
+ return (self.ddXc(I) + self.ddYc(I) + self.ddZc(I))
180
+ else:
181
+ raise ValueError('Finite differences are only supported in dimensions 1 to 3')
182
+
183
+ def grad_norm_sqr_c(self, I):
184
+ """
185
+ Computes the gradient norm of an image
186
+ !!!!!!!!!!!
187
+ IMPORTANT:
188
+ ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
189
+ THIS IS FOR COMPUTATIONAL EFFICIENCY.
190
+ :param I: Input image [batch, X,Y,Z]
191
+ :return: returns ||grad I||^2
192
+ """
193
+ ndim = self.getdimension(I)
194
+ if ndim == 1 + 1:
195
+ return self.dXc(I)**2
196
+ elif ndim == 2 + 1:
197
+ return (self.dXc(I)**2 + self.dYc(I)**2)
198
+ elif ndim == 3 + 1:
199
+ return (self.dXc(I)**2 + self.dYc(I)**2 + self.dZc(I)**2)
200
+ else:
201
+ raise ValueError('Finite differences are only supported in dimensions 1 to 3')
202
+
203
+ def grad_norm_sqr_f(self, I):
204
+ """
205
+ Computes the gradient norm of an image
206
+ !!!!!!!!!!!
207
+ IMPORTANT:
208
+ ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
209
+ THIS IS FOR COMPUTATIONAL EFFICIENCY.
210
+ :param I: Input image [batch, X,Y,Z]
211
+ :return: returns ||grad I||^2
212
+ """
213
+ ndim = self.getdimension(I)
214
+ if ndim == 1 + 1:
215
+ return self.dXf(I)**2
216
+ elif ndim == 2 + 1:
217
+ return (self.dXf(I)**2 + self.dYf(I)**2)
218
+ elif ndim == 3 + 1:
219
+ return (self.dXf(I)**2 + self.dYf(I)**2 + self.dZf(I)**2)
220
+ else:
221
+ raise ValueError('Finite differences are only supported in dimensions 1 to 3')
222
+
223
+ def grad_norm_sqr_b(self, I):
224
+ """
225
+ Computes the gradient norm of an image
226
+ !!!!!!!!!!!
227
+ IMPORTANT:
228
+ ALL THE FOLLOWING IMPLEMENTED CODE ADD 1 ON DIMENSION, WHICH REPRESENT BATCH DIMENSION.
229
+ THIS IS FOR COMPUTATIONAL EFFICIENCY.
230
+ :param I: Input image [batch, X,Y,Z]
231
+ :return: returns ||grad I||^2
232
+ """
233
+ ndim = self.getdimension(I)
234
+ if ndim == 1 + 1:
235
+ return self.dXb(I)**2
236
+ elif ndim == 2 + 1:
237
+ return (self.dXb(I)**2 + self.dYb(I)**2)
238
+ elif ndim == 3 + 1:
239
+ return (self.dXb(I)**2 + self.dYb(I)**2 + self.dZb(I)**2)
240
+ else:
241
+ raise ValueError('Finite differences are only supported in dimensions 1 to 3')
242
+
243
    @abstractmethod
    def getdimension(self,I):
        """
        Abstract method to return the dimension of an input image I
        (including the leading batch dimension).

        :param I: Input image
        :return: Returns the dimension of the image I
        """
        pass
252
+
253
    @abstractmethod
    def create_zero_array(self, sz):
        """
        Abstract method to create a zero array of a given size, sz. E.g., sz=[10,2,5].
        Backends decide the concrete array type (numpy array, torch tensor, ...).

        :param sz: Size array
        :return: Returns a zero array of the specified size
        """
        pass
262
+
263
    @abstractmethod
    def get_size_of_array(self, A):
        """
        Abstract method to return the size of an array (as a vector).

        :param A: Input array
        :return: Returns its size (e.g., [5,10] or [3,4,6])
        """
        pass
272
+
273
    def xp(self,I):
        """
        Return the image shifted so that each position holds the value at
        x-index + 1 (x is axis 1; axis 0 is the batch dimension).

        The vacated last x-slice is filled from the boundary condition:
        replicated under zero-Neumann BC, otherwise linearly extrapolated
        from the final two slices.

        :param I: Input image [batch, X(, Y(, Z))]
        :return: Image with values at an x-index one larger
        :raises ValueError: if the spatial dimension is not 1 to 3
        """
        rxp = self.create_zero_array( self.get_size_of_array( I ) )
        ndim = self.getdimension(I)
        if ndim == 1+1:
            rxp[:,0:-1] = I[:,1:]
            if self.bcNeumannZero:
                # zero-Neumann: replicate the boundary value
                rxp[:,-1] = I[:,-1]
            else:
                # linear extrapolation from the last two slices
                rxp[:,-1] = 2*I[:,-1]-I[:,-2]
        elif ndim == 2+1:
            rxp[:,0:-1,:] = I[:,1:,:]
            if self.bcNeumannZero:
                rxp[:,-1,:] = I[:,-1,:]
            else:
                rxp[:,-1,:] = 2*I[:,-1,:]-I[:,-2,:]
        elif ndim == 3+1:
            rxp[:,0:-1,:,:] = I[:,1:,:,:]
            if self.bcNeumannZero:
                rxp[:,-1,:,:] = I[:,-1,:,:]
            else:
                rxp[:,-1,:,:] = 2*I[:,-1,:,:]-I[:,-2,:,:]
        else:
            raise ValueError('Finite differences are only supported in dimensions 1 to 3')
        return rxp
307
+
308
    def xm(self,I):
        """
        Return the image shifted so that each position holds the value at
        x-index - 1 (x is axis 1; axis 0 is the batch dimension).

        The vacated first x-slice is filled from the boundary condition:
        replicated under zero-Neumann BC, otherwise linearly extrapolated
        from the first two slices.

        :param I: Input image [batch, X(, Y(, Z))]
        :return: Image with values at an x-index one smaller
        :raises ValueError: if the spatial dimension is not 1 to 3
        """
        rxm = self.create_zero_array( self.get_size_of_array( I ) )
        ndim = self.getdimension(I)
        if ndim == 1+1:
            rxm[:,1:] = I[:,0:-1]
            if self.bcNeumannZero:
                # zero-Neumann: replicate the boundary value
                rxm[:,0] = I[:,0]
            else:
                # linear extrapolation from the first two slices
                rxm[:,0] = 2*I[:,0]-I[:,1]
        elif ndim == 2+1:
            rxm[:,1:,:] = I[:,0:-1,:]
            if self.bcNeumannZero:
                rxm[:,0,:] = I[:,0,:]
            else:
                rxm[:,0,:] = 2*I[:,0,:]-I[:,1,:]
        elif ndim == 3+1:
            rxm[:,1:,:,:] = I[:,0:-1,:,:]
            if self.bcNeumannZero:
                rxm[:,0,:,:] = I[:,0,:,:]
            else:
                rxm[:,0,:,:] = 2*I[:,0,:,:]-I[:,1,:,:]
        else:
            raise ValueError('Finite differences are only supported in dimensions 1 to 3')
        return rxm
342
+
343
    def yp(self, I):
        """
        Same as xp, but shifting along the y axis (axis 2; axis 0 is the
        batch dimension). Only defined for 2D/3D spatial inputs since a 1D
        image has no y axis.

        :param I: Input image [batch, X, Y(, Z)]
        :return: Image with values at y-index one larger
        :raises ValueError: if the spatial dimension is not 2 or 3
        """
        ryp = self.create_zero_array( self.get_size_of_array( I ) )
        ndim = self.getdimension(I)
        if ndim == 2+1:
            ryp[:,:,0:-1] = I[:,:,1:]
            if self.bcNeumannZero:
                # zero-Neumann: replicate the boundary value
                ryp[:,:,-1] = I[:,:,-1]
            else:
                # linear extrapolation from the last two slices
                ryp[:,:,-1] = 2*I[:,:,-1]-I[:,:,-2]
        elif ndim == 3+1:
            ryp[:,:,0:-1,:] = I[:,:,1:,:]
            if self.bcNeumannZero:
                ryp[:,:,-1,:] = I[:,:,-1,:]
            else:
                ryp[:,:,-1,:] = 2*I[:,:,-1,:]-I[:,:,-2,:]
        else:
            print('Current dimension:', ndim-1)
            raise ValueError('Finite differences are only supported in dimensions 1 to 3')
        return ryp
372
+
373
    def ym(self, I):
        """
        Same as xm, but shifting along the y axis (axis 2; axis 0 is the
        batch dimension). Only defined for 2D/3D spatial inputs.

        :param I: Input image [batch, X, Y(, Z)]
        :return: Image with values at y-index one smaller
        :raises ValueError: if the spatial dimension is not 2 or 3
        """
        rym = self.create_zero_array( self.get_size_of_array( I ) )
        ndim = self.getdimension(I)
        if ndim == 2+1:
            rym[:,:,1:] = I[:,:,0:-1]
            if self.bcNeumannZero:
                # zero-Neumann: replicate the boundary value
                rym[:,:,0] = I[:,:,0]
            else:
                # linear extrapolation from the first two slices
                rym[:,:,0] = 2*I[:,:,0]-I[:,:,1]
        elif ndim == 3+1:
            rym[:,:,1:,:] = I[:,:,0:-1,:]
            if self.bcNeumannZero:
                rym[:,:,0,:] = I[:,:,0,:]
            else:
                rym[:,:,0,:] = 2*I[:,:,0,:]-I[:,:,1,:]
        else:
            raise ValueError('Finite differences are only supported in dimensions 1 to 3')
        return rym
401
+
402
    def zp(self, I):
        """
        Same as xp, but shifting along the z axis (axis 3; axis 0 is the
        batch dimension). Only defined for 3D spatial inputs.

        :param I: Input image [batch, X, Y, Z]
        :return: Image with values at z-index one larger
        :raises ValueError: if the spatial dimension is not 3
        """
        rzp = self.create_zero_array( self.get_size_of_array( I ) )
        ndim = self.getdimension(I)
        if ndim == 3+1:
            rzp[:,:,:,0:-1] = I[:,:,:,1:]
            if self.bcNeumannZero:
                # zero-Neumann: replicate the boundary value
                rzp[:,:,:,-1] = I[:,:,:,-1]
            else:
                # linear extrapolation from the last two slices
                rzp[:,:,:,-1] = 2*I[:,:,:,-1]-I[:,:,:,-2]
        else:
            raise ValueError('Finite differences are only supported in dimensions 1 to 3')
        return rzp
425
+
426
    def zm(self, I):
        """
        Same as xm, but shifting along the z axis (axis 3; axis 0 is the
        batch dimension). Only defined for 3D spatial inputs.

        :param I: Input image [batch, X, Y, Z]
        :return: Image with values at z-index one smaller
        :raises ValueError: if the spatial dimension is not 3
        """
        rzm = self.create_zero_array( self.get_size_of_array( I ) )
        ndim = self.getdimension(I)
        if ndim == 3+1:
            rzm[:,:,:,1:] = I[:,:,:,0:-1]
            if self.bcNeumannZero:
                # zero-Neumann: replicate the boundary value
                rzm[:,:,:,0] = I[:,:,:,0]
            else:
                # linear extrapolation from the first two slices
                rzm[:,:,:,0] = 2*I[:,:,:,0]-I[:,:,:,1]
        else:
            raise ValueError('Finite differences are only supported in dimensions 1 to 3')
        return rzm
449
+
450
+
451
class FD_np(FD):
    """
    Numpy backend: concrete definitions of the abstract FD helper methods.
    """

    def __init__(self,spacing,bcNeumannZero=True):
        """
        Constructor for numpy finite differences.

        :param spacing: spatial spacing (array with as many entries as there are spatial dimensions)
        :param bcNeumannZero: Specifies if zero Neumann conditions should be used (if not, uses linear extrapolation)
        """
        super(FD_np, self).__init__(spacing,bcNeumannZero)

    def getdimension(self,I):
        """
        Return the number of dimensions (including batch) of image I.

        :param I: input image
        :return: dimension of the input image
        """
        return I.ndim

    def create_zero_array(self, sz):
        """
        Create a zero numpy array.

        :param sz: size of the zero array, e.g., [3,4,2]
        :return: the zero array
        """
        return np.zeros( sz )

    def get_size_of_array(self, A):
        """
        Return the size (shape in numpy) of an array.

        :param A: input array
        :return: shape/size
        """
        return A.shape
487
+
488
+
489
class FD_torch(FD):
    """
    Torch backend: concrete definitions of the abstract FD helper methods.
    """

    def __init__(self,spacing,device,bcNeumannZero=True):
        """
        Constructor for torch finite differences.

        :param spacing: spatial spacing (array with as many entries as there are spatial dimensions)
        :param device: torch device on which zero tensors are allocated
        :param bcNeumannZero: Specifies if zero Neumann conditions should be used (if not, uses linear extrapolation)
        """
        super(FD_torch, self).__init__(spacing,bcNeumannZero)
        self.device = device

    def getdimension(self,I):
        """
        Return the number of dimensions (including batch) of tensor I.

        :param I: input image
        :return: dimension of the input image
        """
        return I.dim()

    def create_zero_array(self, sz):
        """
        Create a float32 zero tensor on the configured device.

        :param sz: size of the array, e.g., [3,4,2]
        :return: the zero array
        """
        return torch.zeros(sz).float().to(self.device)

    def get_size_of_array(self, A):
        """
        Return the size (size()) of a tensor.

        :param A: input array
        :return: shape/size
        """
        return A.size()
ShapeID/DiffEqs/adams.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import torch
3
+ from ShapeID.DiffEqs.solvers import AdaptiveStepsizeODESolver
4
+ from ShapeID.DiffEqs.misc import (
5
+ _handle_unused_kwargs, _select_initial_step, _convert_to_tensor, _scaled_dot_product, _is_iterable,
6
+ _optimal_step_size, _compute_error_ratio
7
+ )
8
+
9
# Allowed order range for the variable-coefficient Adams-Bashforth-Moulton solver.
_MIN_ORDER = 1
_MAX_ORDER = 12

# gamma* coefficients used for the higher-order local error estimate of the
# Adams-Moulton corrector (cf. Hairer, Norsett & Wanner, "Solving Ordinary
# Differential Equations I", section III.5).
gamma_star = [
    1, -1 / 2, -1 / 12, -1 / 24, -19 / 720, -3 / 160, -863 / 60480, -275 / 24192, -33953 / 3628800, -0.00789255,
    -0.00678585, -0.00592406, -0.00523669, -0.0046775, -0.00421495, -0.0038269
]
16
+
17
+
18
class _VCABMState(collections.namedtuple('_VCABMState', 'y_n, prev_f, prev_t, next_t, phi, order')):
    """Saved state of the variable step size Adams-Bashforth-Moulton solver as described in

    Solving Ordinary Differential Equations I - Nonstiff Problems III.5
    by Ernst Hairer, Gerhard Wanner, and Syvert P Norsett.

    Fields:
        y_n: current solution (tuple of tensors)
        prev_f: deque of previous derivative evaluations (most recent first)
        prev_t: deque of previous time points (most recent first)
        next_t: proposed next time point
        phi: deque of phi (divided-difference) values
        order: current integration order
    """
24
+
25
+
26
def g_and_explicit_phi(prev_t, next_t, implicit_phi, k):
    """Compute the g coefficients and explicit phi values for one VCABM step.

    :param prev_t: previous time points, most recent first
    :param next_t: proposed next time point
    :param implicit_phi: phi values from the previous (implicit) step
    :param k: current order
    :return: tuple (g, explicit_phi) with the k+1 integration coefficients
        and the phi values propagated to next_t
    """
    t0 = prev_t[0]
    dt = next_t - t0

    g = torch.empty(k + 1).to(t0)
    explicit_phi = collections.deque(maxlen=k)
    beta = torch.tensor(1).to(t0)

    g[0] = 1
    c = 1 / torch.arange(1, k + 2).to(t0)
    explicit_phi.append(implicit_phi[0])

    for j in range(1, k):
        beta = (next_t - prev_t[j - 1]) / (t0 - prev_t[j]) * beta
        beta_cast = beta.to(implicit_phi[j][0])
        explicit_phi.append(tuple(phi_elem * beta_cast for phi_elem in implicit_phi[j]))

        if j == 1:
            c = c[:-1] - c[1:]
        else:
            c = c[:-1] - c[1:] * dt / (next_t - prev_t[j - 1])
        g[j] = c[0]

    c = c[:-1] - c[1:] * dt / (next_t - prev_t[k - 1])
    g[k] = c[0]

    return g, explicit_phi
50
+
51
+
52
def compute_implicit_phi(explicit_phi, f_n, k):
    """Update the phi values with the new derivative evaluation f_n.

    :param explicit_phi: phi values propagated to the new time point
    :param f_n: derivative evaluated at the new time point (tuple of tensors)
    :param k: requested depth; capped at len(explicit_phi) + 1
    :return: deque of implicit phi values of length min(len(explicit_phi)+1, k)
    """
    depth = min(len(explicit_phi) + 1, k)
    implicit_phi = collections.deque(maxlen=depth)
    implicit_phi.append(f_n)
    for j in range(1, depth):
        prev_level = implicit_phi[j - 1]
        explicit_level = explicit_phi[j - 1]
        implicit_phi.append(tuple(p - e for p, e in zip(prev_level, explicit_level)))
    return implicit_phi
59
+
60
+
61
class VariableCoefficientAdamsBashforth(AdaptiveStepsizeODESolver):
    """Variable-order, variable-step Adams-Bashforth-Moulton (VCABM) solver,
    following Hairer, Norsett & Wanner, "Solving ODEs I", section III.5.
    """

    def __init__(
        self, func, y0, rtol, atol, implicit=True, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0, dfactor=0.2,
        **unused_kwargs
    ):
        """
        :param func: callable(t, y) -> dy/dt with y a tuple of tensors
        :param y0: tuple of initial-state tensors
        :param rtol: relative tolerance (scalar or per-state iterable)
        :param atol: absolute tolerance (scalar or per-state iterable)
        :param implicit: stored flag (the corrector step below is always applied)
        :param max_order: maximum order, clamped to [_MIN_ORDER, _MAX_ORDER]
        :param safety, ifactor, dfactor: step-size controller parameters
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs

        self.func = func
        self.y0 = y0
        self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
        self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
        self.implicit = implicit
        self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER)))
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)

    def before_integrate(self, t):
        """Initialize the history deques and pick the first step size."""
        prev_f = collections.deque(maxlen=self.max_order + 1)
        prev_t = collections.deque(maxlen=self.max_order + 1)
        phi = collections.deque(maxlen=self.max_order)

        t0 = t[0]
        f0 = self.func(t0.type_as(self.y0[0]), self.y0)
        prev_t.appendleft(t0)
        prev_f.appendleft(f0)
        phi.appendleft(f0)
        first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t)

        # Start at order 1 (effectively forward Euler) and build up the history.
        self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1)

    def advance(self, final_t):
        """Step the solver until it reaches final_t exactly; return y(final_t)."""
        final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0])
        while final_t > self.vcabm_state.prev_t[0]:
            self.vcabm_state = self._adaptive_adams_step(self.vcabm_state, final_t)
        assert final_t == self.vcabm_state.prev_t[0]
        return self.vcabm_state.y_n

    def _adaptive_adams_step(self, vcabm_state, final_t):
        """Attempt one predictor-corrector step; adapt the step size and order."""
        y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state
        # Never step past the requested final time.
        if next_t > final_t:
            next_t = final_t
        dt = (next_t - prev_t[0])
        dt_cast = dt.to(y0[0])

        # Explicit predictor step.
        g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order)
        g = g.to(y0[0])
        p_next = tuple(
            y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)])
            for y0_, phi_ in zip(y0, tuple(zip(*phi)))
        )

        # Update phi to implicit.
        next_f0 = self.func(next_t.to(p_next[0]), p_next)
        implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1)

        # Implicit corrector step.
        y_next = tuple(
            p_next_ + dt_cast * g[order - 1] * iphi_ for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1])
        )

        # Error estimation.
        tolerance = tuple(
            atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_))
            for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next)
        )
        local_error = tuple(dt_cast * (g[order] - g[order - 1]) * iphi_ for iphi_ in implicit_phi_p[order])
        error_k = _compute_error_ratio(local_error, tolerance)
        accept_step = (torch.tensor(error_k) <= 1).all()

        if not accept_step:
            # Retry with adjusted step size if step is rejected.
            dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order)
            return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order)

        # We accept the step. Evaluate f and update phi.
        next_f0 = self.func(next_t.to(p_next[0]), y_next)
        implicit_phi = compute_implicit_phi(phi, next_f0, order + 2)

        next_order = order

        # Order selection: compare error estimates at orders k-1, k and k+1.
        if len(prev_t) <= 4 or order < 3:
            next_order = min(order + 1, 3, self.max_order)
        else:
            error_km1 = _compute_error_ratio(
                tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_ for iphi_ in implicit_phi_p[order - 1]), tolerance
            )
            error_km2 = _compute_error_ratio(
                tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_ for iphi_ in implicit_phi_p[order - 2]), tolerance
            )
            if min(error_km1 + error_km2) < max(error_k):
                next_order = order - 1
            elif order < self.max_order:
                error_kp1 = _compute_error_ratio(
                    tuple(dt_cast * gamma_star[order] * iphi_ for iphi_ in implicit_phi_p[order]), tolerance
                )
                if max(error_kp1) < max(error_k):
                    next_order = order + 1

        # Keep step size constant if increasing order. Else use adaptive step size.
        dt_next = dt if next_order > order else _optimal_step_size(
            dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1
        )

        prev_f.appendleft(next_f0)
        prev_t.appendleft(next_t)
        # NOTE(review): the accepted state stores the predictor p_next rather than
        # the corrected y_next as y_n - this matches some upstream versions of this
        # solver, but verify it is intended.
        return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order)
ShapeID/DiffEqs/adjoint.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ShapeID.DiffEqs.odeint import odeint
4
+ from ShapeID.DiffEqs.misc import _flatten, _flatten_convert_none_to_zeros
5
+
6
+
7
class OdeintAdjointMethod(torch.autograd.Function):
    """autograd.Function computing ODE solutions whose backward pass uses the
    adjoint method: gradients are obtained by integrating an augmented system
    backwards in time instead of backpropagating through the solver steps.
    """

    @staticmethod
    def forward(ctx, *args):
        # Trailing 8 positional args are fixed; everything before them is y0.
        assert len(args) >= 8, 'Internal error: all arguments required.'
        y0, func, t, dt, flat_params, rtol, atol, method, options = \
            args[:-8], args[-8], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]

        ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options

        # Solve without building a graph; backward() reconstructs gradients itself.
        with torch.no_grad():
            ans = odeint(func, y0, t, dt, rtol=rtol, atol=atol, method=method, options=options)
        ctx.save_for_backward(t, flat_params, *ans)
        return ans

    @staticmethod
    def backward(ctx, *grad_output):

        t, flat_params, *ans = ctx.saved_tensors
        ans = tuple(ans)
        func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
        n_tensors = len(ans)
        f_params = tuple(func.parameters())

        # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.
            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with torch.set_grad_enabled(True):
                t = t.to(y[0].device).detach().requires_grad_(True)
                y = tuple(y_.detach().requires_grad_(True) for y_ in y)
                func_eval = func(t, y)
                # Vector-Jacobian products of -adj_y against t, y and the parameters.
                vjp_t, *vjp_y_and_params = torch.autograd.grad(
                    func_eval, (t,) + y + f_params,
                    tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
                )
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # autograd.grad returns None if no gradient, set to zero.
            vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
            vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if len(f_params) == 0:
                vjp_params = torch.tensor(0.).to(vjp_y[0])
            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        with torch.no_grad():
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
            adj_params = torch.zeros_like(flat_params)
            adj_time = torch.tensor(0.).to(t)
            time_vjps = []
            # Walk the saved time points backwards, accumulating adjoints.
            for i in range(T - 1, 0, -1):

                ans_i = tuple(ans_[i] for ans_ in ans)
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
                func_i = func(t[i], ans_i)

                # Compute the effect of moving the current time measurement point.
                dLd_cur_t = sum(
                    torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1)
                    for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
                )
                adj_time = adj_time - dLd_cur_t
                time_vjps.append(dLd_cur_t)

                # Run the augmented system backwards in time.
                # NOTE(review): this odeint call does not forward the `dt` argument
                # used in forward(); confirm odeint's signature makes that valid.
                if adj_params.numel() == 0:
                    adj_params = torch.tensor(0.).to(adj_y[0])
                aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
                aug_ans = odeint(
                    augmented_dynamics, aug_y0,
                    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
                )

                # Unpack aug_ans.
                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_time = aug_ans[2 * n_tensors]
                adj_params = aug_ans[2 * n_tensors + 1]

                # Keep only the value at t[i - 1] (index 1 of the two-point solve).
                adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
                if len(adj_time) > 0: adj_time = adj_time[1]
                if len(adj_params) > 0: adj_params = adj_params[1]

                # Add the incoming gradient at the earlier measurement point.
                adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

                del aug_y0, aug_ans

            time_vjps.append(adj_time)
            time_vjps = torch.cat(time_vjps[::-1])

            # Gradients line up with forward()'s (*y0, func, t, dt, flat_params,
            # rtol, atol, method, options) argument order.
            return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None, None)  # Add a None (TODO, further check)
103
+
104
+
105
def odeint_adjoint(func, y0, t, dt, rtol=1e-6, atol=1e-12, method=None, options=None):
    """Solve an ODE with gradients computed via the adjoint method.

    Wraps OdeintAdjointMethod.apply so that gradients w.r.t. y0, t and the
    parameters of `func` are obtained by integrating an augmented system
    backwards in time instead of backpropagating through solver internals.

    :param func: an nn.Module mapping (t, y) -> dy/dt
    :param y0: initial state; a tensor or a tuple of tensors
    :param t: 1-D tensor of evaluation times
    :param dt: step-size argument forwarded to odeint
    :return: solution at the times in t (same tensor/tuple structure as y0)
    :raises ValueError: if func is not an nn.Module
    """

    # We need this in order to access the variables inside this module,
    # since we have no other way of getting variables along the execution path.
    if not isinstance(func, nn.Module):
        raise ValueError('func is required to be an instance of nn.Module.')

    tensor_input = False
    if torch.is_tensor(y0):

        # Wrap a tensor-valued func so the solver always works on tuples.
        class TupleFunc(nn.Module):

            def __init__(self, base_func):
                super(TupleFunc, self).__init__()
                self.base_func = base_func

            def forward(self, t, y):
                return (self.base_func(t, y[0]),)

        tensor_input = True
        y0 = (y0,)
        func = TupleFunc(func)

    flat_params = _flatten(func.parameters())
    ys = OdeintAdjointMethod.apply(*y0, func, t, dt, flat_params, rtol, atol, method, options)

    if tensor_input:
        ys = ys[0]
    return ys
ShapeID/DiffEqs/dopri5.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .misc import (
3
+ _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable,
4
+ _optimal_step_size, _compute_error_ratio
5
+ )
6
+ from .solvers import AdaptiveStepsizeODESolver, set_BC_2D, set_BC_3D, add_dBC_2D, add_dBC_3D
7
+ from .interp import _interp_fit, _interp_evaluate
8
+ from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step
9
+
10
+
11
# Butcher tableau of the Dormand-Prince 5(4) embedded Runge-Kutta method
# ("dopri5"): c_sol gives the 5th-order solution weights and c_error the
# difference to the embedded 4th-order solution used for error estimation.
_DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau(
    alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.],
    beta=[
        [1 / 5],
        [3 / 40, 9 / 40],
        [44 / 45, -56 / 15, 32 / 9],
        [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
        [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
        [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
    ],
    c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0],
    c_error=[
        35 / 384 - 1951 / 21600,
        0,
        500 / 1113 - 22642 / 50085,
        125 / 192 - 451 / 720,
        -2187 / 6784 - -12231 / 42400,
        11 / 84 - 649 / 6300,
        -1. / 60.,
    ],
)
+ )
32
+
33
# Shampine's coefficients for evaluating the dense-output interpolant at the
# step midpoint; consumed by _interp_fit_dopri5 below.
DPS_C_MID = [
    6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2,
    187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2
]
37
+
38
+
39
def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU):
    """Fit an interpolating polynomial to the results of a Runge-Kutta step.

    :param y0: state tuple at the start of the step
    :param y1: state tuple at the end of the step
    :param k: per-state tuples of RK stage derivatives
    :param dt: step size taken
    :return: interpolation coefficients from _interp_fit
    """
    dt = dt.type_as(y0[0])
    y_mid = tuple(
        start + _scaled_dot_product(dt, DPS_C_MID, stages)
        for start, stages in zip(y0, k)
    )
    f0 = tuple(stages[0] for stages in k)
    f1 = tuple(stages[-1] for stages in k)
    return _interp_fit(y0, y1, y_mid, f0, f1, dt)
46
+
47
+
48
+ def _abs_square(x):
49
+ return torch.mul(x, x)
50
+
51
+
52
+ def _ta_append(list_of_tensors, value):
53
+ """Append a value to the end of a list of PyTorch tensors."""
54
+ list_of_tensors.append(value)
55
+ return list_of_tensors
56
+
57
+
58
class Dopri5Solver(AdaptiveStepsizeODESolver):
    """Adaptive Dormand-Prince 5(4) ("dopri5") Runge-Kutta solver.

    The step size is adapted from the embedded 4th-order error estimate, and
    additionally clamped to a [tol_min_dt, 0.1] window in
    _adaptive_dopri5_step (see the notes there).
    """

    def __init__(
        self, func, y0, rtol, atol, dt, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
        options = None
        #**unused_kwargs
    ):
        """
        :param func: callable(t, y) -> dy/dt with y a tuple of tensors
        :param y0: tuple of initial-state tensors
        :param rtol: relative tolerance (scalar or per-state iterable)
        :param atol: absolute tolerance (scalar or per-state iterable)
        :param dt: nominal step, used to derive the minimum allowed step size
        :param first_step: if not None, the initial step is forced to 0.01
            (NOTE(review): the actual value of first_step is ignored - confirm)
        :param safety, ifactor, dfactor: step-size controller parameters
        :param max_num_steps: hard cap on internal steps per advance() call
        :param options: currently unused here (see commented-out BC handling)
        """
        #_handle_unused_kwargs(self, unused_kwargs)
        #del unused_kwargs

        self.func = func
        self.y0 = y0

        self.dt = dt #options.dt
        '''if 'dirichlet' in options.BC or 'cauchy' in options.BC and options.contours is not None:
            self.contours = options.contours # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape)
            self.BC_size = self.contours.size(3)
            self.set_BC = set_BC_2D if self.contours.size(2) == 4 else set_BC_3D
        else:
            self.contours = None
        if 'source' in options.BC and options.dcontours is not None:
            self.dcontours = options.dcontours # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape)
            self.BC_size = self.dcontours.size(3)
            self.add_dBC = add_dBC_2D if self.dcontours.size(2) == 4 else add_dBC_3D
        else:
            self.dcontours = None'''

        #self.adjoint = options.adjoint

        self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
        self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
        self.first_step = first_step
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
        self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)
        #self.n_step_record=[]

    def before_integrate(self, t):
        """Evaluate f at t[0] and set up the initial Runge-Kutta state."""
        f0 = self.func(t[0].type_as(self.y0[0]), self.y0)
        #print("first_step is {}".format(self.first_step))
        if self.first_step is None:
            first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0).to(t)
        else:
            # NOTE(review): any user-supplied first_step value is replaced by a
            # fixed 0.01 here - confirm this is intended.
            first_step = _convert_to_tensor(0.01, dtype=t.dtype, device=t.device)
        # if first_step>0.2:
        #     print("warning the first step of dopri5 {} is too big, set to 0.2".format(first_step))
        #     first_step = _convert_to_tensor(0.2, dtype=torch.float64, device=self.y0[0].device)

        self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5)

    def advance(self, next_t):
        """Interpolate through the next time point, integrating as necessary."""
        n_steps = 0
        while next_t > self.rk_state.t1:
            assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
            self.rk_state = self._adaptive_dopri5_step(self.rk_state)
            n_steps += 1
        # if len(self.n_step_record)==100:
        #     print("this dopri5 step info will print every 100 calls, the current average step is {}".format(sum(self.n_step_record)/100))
        #     self.n_step_record=[]
        # else:
        #     self.n_step_record.append(n_steps)

        return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t)

    def _adaptive_dopri5_step(self, rk_state):
        """Take an adaptive Runge-Kutta step to integrate the DiffEqs."""
        y0, f0, _, t0, dt, interp_coeff = rk_state
        ########################################################
        #                      Assertions                      #
        ########################################################
        assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
        # for y0_ in y0:
        #     #assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
        #     is_finite= _is_finite(torch.abs(y0_))
        #     if not is_finite:
        #         print(" non-finite elements exist, try to fix")
        #         y0_[y0_ != y0_] = 0.
        #         y0_[y0_ == float("Inf")] = 0.

        y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU)

        ########################################################
        #                     Error Ratio                      #
        ########################################################
        mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
        accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all()

        ########################################################
        #                   Update RK State                    #
        ########################################################
        dt_next = _optimal_step_size(
            dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5)
        # Minimum allowed next step, derived from the nominal dt.
        # NOTE(review): the condition checks 0.1*self.dt while the value used is
        # 0.2*self.dt - confirm this asymmetry is intended.
        tol_min_dt = 0.2 * self.dt if 0.1 * self.dt >= 0.01 else 0.01
        #print('tol min', tol_min_dt)
        if not (dt_next< tol_min_dt or dt_next>0.1): #(dt_next<0.01 or dt_next>0.1): #(dt_next<0.02): #not (dt_next<0.02 or dt_next>0.1):
            # Proposed step is inside the allowed window: honor the error test.
            y_next = y1 if accept_step else y0
            f_next = f1 if accept_step else f0
            t_next = t0 + dt if accept_step else t0
            interp_coeff = _interp_fit_dopri5(y0, y_next, k, dt) if accept_step else interp_coeff
        else:
            # Proposed step must be clamped; the step is then taken regardless
            # of accept_step.
            # NOTE(review): this accepts y1 even when the error test failed -
            # confirm this forced-accept behavior is intended.
            if dt_next< tol_min_dt: #dt_next<0.01: # 0.01
                #print("Dopri5 step %.3f too small, set to %.3f" % (dt_next, 0.2 * self.dt))
                dt_next = _convert_to_tensor(tol_min_dt, dtype=torch.float64, device=y0[0].device)
            if dt_next>0.1:
                #print("Dopri5 step %.8f is too big, set to 0.1" % (dt_next))
                dt_next = _convert_to_tensor(0.1, dtype=torch.float64, device=y0[0].device)
            y_next = y1
            f_next = f1
            t_next = t0 + dt
            interp_coeff = _interp_fit_dopri5(y0, y1, k, dt)
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
        #print('dt_next', dt_next)
        return rk_state
ShapeID/DiffEqs/fixed_adams.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import collections
3
+ from ShapeID.DiffEqs.solvers import FixedGridODESolver
4
+ from ShapeID.DiffEqs.misc import _scaled_dot_product, _has_converged
5
+ import ShapeID.DiffEqs.rk_common
6
+
7
+ _BASHFORTH_COEFFICIENTS = [
8
+ [], # order 0
9
+ [11],
10
+ [3, -1],
11
+ [23, -16, 5],
12
+ [55, -59, 37, -9],
13
+ [1901, -2774, 2616, -1274, 251],
14
+ [4277, -7923, 9982, -7298, 2877, -475],
15
+ [198721, -447288, 705549, -688256, 407139, -134472, 19087],
16
+ [434241, -1152169, 2183877, -2664477, 2102243, -1041723, 295767, -36799],
17
+ [14097247, -43125206, 95476786, -139855262, 137968480, -91172642, 38833486, -9664106, 1070017],
18
+ [30277247, -104995189, 265932680, -454661776, 538363838, -444772162, 252618224, -94307320, 20884811, -2082753],
19
+ [
20
+ 2132509567, -8271795124, 23591063805, -46113029016, 63716378958, -63176201472, 44857168434, -22329634920,
21
+ 7417904451, -1479574348, 134211265
22
+ ],
23
+ [
24
+ 4527766399, -19433810163, 61633227185, -135579356757, 214139355366, -247741639374, 211103573298, -131365867290,
25
+ 58189107627, -17410248271, 3158642445, -262747265
26
+ ],
27
+ [
28
+ 13064406523627, -61497552797274, 214696591002612, -524924579905150, 932884546055895, -1233589244941764,
29
+ 1226443086129408, -915883387152444, 507140369728425, -202322913738370, 55060974662412, -9160551085734,
30
+ 703604254357
31
+ ],
32
+ [
33
+ 27511554976875, -140970750679621, 537247052515662, -1445313351681906, 2854429571790805, -4246767353305755,
34
+ 4825671323488452, -4204551925534524, 2793869602879077, -1393306307155755, 505586141196430, -126174972681906,
35
+ 19382853593787, -1382741929621
36
+ ],
37
+ [
38
+ 173233498598849, -960122866404112, 3966421670215481, -11643637530577472, 25298910337081429, -41825269932507728,
39
+ 53471026659940509, -53246738660646912, 41280216336284259, -24704503655607728, 11205849753515179,
40
+ -3728807256577472, 859236476684231, -122594813904112, 8164168737599
41
+ ],
42
+ [
43
+ 362555126427073, -2161567671248849, 9622096909515337, -30607373860520569, 72558117072259733,
44
+ -131963191940828581, 187463140112902893, -210020588912321949, 186087544263596643, -129930094104237331,
45
+ 70724351582843483, -29417910911251819, 9038571752734087, -1934443196892599, 257650275915823, -16088129229375
46
+ ],
47
+ [
48
+ 192996103681340479, -1231887339593444974, 5878428128276811750, -20141834622844109630, 51733880057282977010,
49
+ -102651404730855807942, 160414858999474733422, -199694296833704562550, 199061418623907202560,
50
+ -158848144481581407370, 100878076849144434322, -50353311405771659322, 19338911944324897550,
51
+ -5518639984393844930, 1102560345141059610, -137692773163513234, 8092989203533249
52
+ ],
53
+ [
54
+ 401972381695456831, -2735437642844079789, 13930159965811142228, -51150187791975812900, 141500575026572531760,
55
+ -304188128232928718008, 518600355541383671092, -710171024091234303204, 786600875277595877750,
56
+ -706174326992944287370, 512538584122114046748, -298477260353977522892, 137563142659866897224,
57
+ -49070094880794267600, 13071639236569712860, -2448689255584545196, 287848942064256339, -15980174332775873
58
+ ],
59
+ [
60
+ 333374427829017307697, -2409687649238345289684, 13044139139831833251471, -51099831122607588046344,
61
+ 151474888613495715415020, -350702929608291455167896, 647758157491921902292692, -967713746544629658690408,
62
+ 1179078743786280451953222, -1176161829956768365219840, 960377035444205950813626, -639182123082298748001432,
63
+ 343690461612471516746028, -147118738993288163742312, 48988597853073465932820, -12236035290567356418552,
64
+ 2157574942881818312049, -239560589366324764716, 12600467236042756559
65
+ ],
66
+ [
67
+ 691668239157222107697, -5292843584961252933125, 30349492858024727686755, -126346544855927856134295,
68
+ 399537307669842150996468, -991168450545135070835076, 1971629028083798845750380, -3191065388846318679544380,
69
+ 4241614331208149947151790, -4654326468801478894406214, 4222756879776354065593786, -3161821089800186539248210,
70
+ 1943018818982002395655620, -970350191086531368649620, 387739787034699092364924, -121059601023985433003532,
71
+ 28462032496476316665705, -4740335757093710713245, 498669220956647866875, -24919383499187492303
72
+ ],
73
+ ]
74
+
75
+ _MOULTON_COEFFICIENTS = [
76
+ [], # order 0
77
+ [1],
78
+ [1, 1],
79
+ [5, 8, -1],
80
+ [9, 19, -5, 1],
81
+ [251, 646, -264, 106, -19],
82
+ [475, 1427, -798, 482, -173, 27],
83
+ [19087, 65112, -46461, 37504, -20211, 6312, -863],
84
+ [36799, 139849, -121797, 123133, -88547, 41499, -11351, 1375],
85
+ [1070017, 4467094, -4604594, 5595358, -5033120, 3146338, -1291214, 312874, -33953],
86
+ [2082753, 9449717, -11271304, 16002320, -17283646, 13510082, -7394032, 2687864, -583435, 57281],
87
+ [
88
+ 134211265, 656185652, -890175549, 1446205080, -1823311566, 1710774528, -1170597042, 567450984, -184776195,
89
+ 36284876, -3250433
90
+ ],
91
+ [
92
+ 262747265, 1374799219, -2092490673, 3828828885, -5519460582, 6043521486, -4963166514, 3007739418, -1305971115,
93
+ 384709327, -68928781, 5675265
94
+ ],
95
+ [
96
+ 703604254357, 3917551216986, -6616420957428, 13465774256510, -21847538039895, 27345870698436, -26204344465152,
97
+ 19058185652796, -10344711794985, 4063327863170, -1092096992268, 179842822566, -13695779093
98
+ ],
99
+ [
100
+ 1382741929621, 8153167962181, -15141235084110, 33928990133618, -61188680131285, 86180228689563, -94393338653892,
101
+ 80101021029180, -52177910882661, 25620259777835, -9181635605134, 2268078814386, -345457086395, 24466579093
102
+ ],
103
+ [
104
+ 8164168737599, 50770967534864, -102885148956217, 251724894607936, -499547203754837, 781911618071632,
105
+ -963605400824733, 934600833490944, -710312834197347, 418551804601264, -187504936597931, 61759426692544,
106
+ -14110480969927, 1998759236336, -132282840127
107
+ ],
108
+ [
109
+ 16088129229375, 105145058757073, -230992163723849, 612744541065337, -1326978663058069, 2285168598349733,
110
+ -3129453071993581, 3414941728852893, -2966365730265699, 2039345879546643, -1096355235402331, 451403108933483,
111
+ -137515713789319, 29219384284087, -3867689367599, 240208245823
112
+ ],
113
+ [
114
+ 8092989203533249, 55415287221275246, -131240807912923110, 375195469874202430, -880520318434977010,
115
+ 1654462865819232198, -2492570347928318318, 3022404969160106870, -2953729295811279360, 2320851086013919370,
116
+ -1455690451266780818, 719242466216944698, -273894214307914510, 77597639915764930, -15407325991235610,
117
+ 1913813460537746, -111956703448001
118
+ ],
119
+ [
120
+ 15980174332775873, 114329243705491117, -290470969929371220, 890337710266029860, -2250854333681641520,
121
+ 4582441343348851896, -7532171919277411636, 10047287575124288740, -10910555637627652470, 9644799218032932490,
122
+ -6913858539337636636, 3985516155854664396, -1821304040326216520, 645008976643217360, -170761422500096220,
123
+ 31816981024600492, -3722582669836627, 205804074290625
124
+ ],
125
+ [
126
+ 12600467236042756559, 93965550344204933076, -255007751875033918095, 834286388106402145800,
127
+ -2260420115705863623660, 4956655592790542146968, -8827052559979384209108, 12845814402199484797800,
128
+ -15345231910046032448070, 15072781455122686545920, -12155867625610599812538, 8008520809622324571288,
129
+ -4269779992576330506540, 1814584564159445787240, -600505972582990474260, 149186846171741510136,
130
+ -26182538841925312881, 2895045518506940460, -151711881512390095
131
+ ],
132
+ [
133
+ 24919383499187492303, 193280569173472261637, -558160720115629395555, 1941395668950986461335,
134
+ -5612131802364455926260, 13187185898439270330756, -25293146116627869170796, 39878419226784442421820,
135
+ -51970649453670274135470, 56154678684618739939910, -50320851025594566473146, 37297227252822858381906,
136
+ -22726350407538133839300, 11268210124987992327060, -4474886658024166985340, 1389665263296211699212,
137
+ -325187970422032795497, 53935307402575440285, -5652892248087175675, 281550972898020815
138
+ ],
139
+ ]
140
+
141
+ _DIVISOR = [
142
+ None, 11, 2, 12, 24, 720, 1440, 60480, 120960, 3628800, 7257600, 479001600, 958003200, 2615348736000, 5230697472000,
143
+ 31384184832000, 62768369664000, 32011868528640000, 64023737057280000, 51090942171709440000, 102181884343418880000
144
+ ]
145
+
146
+ _MIN_ORDER = 4
147
+ _MAX_ORDER = 12
148
+ _MAX_ITERS = 4
149
+
150
+
151
+ class AdamsBashforthMoulton(FixedGridODESolver):
152
+
153
+ def __init__(
154
+ self, func, y0, rtol=1e-3, atol=1e-4, implicit=True, max_iters=_MAX_ITERS, max_order=_MAX_ORDER, **kwargs
155
+ ):
156
+ super(AdamsBashforthMoulton, self).__init__(func, y0, **kwargs)
157
+
158
+ self.rtol = rtol
159
+ self.atol = atol
160
+ self.implicit = implicit
161
+ self.max_iters = max_iters
162
+ self.max_order = int(min(max_order, _MAX_ORDER))
163
+ self.prev_f = collections.deque(maxlen=self.max_order - 1)
164
+ self.prev_t = None
165
+
166
+ def _update_history(self, t, f):
167
+ if self.prev_t is None or self.prev_t != t:
168
+ self.prev_f.appendleft(f)
169
+ self.prev_t = t
170
+
171
+ def step_func(self, func, t, dt, y):
172
+ self._update_history(t, func(t, y))
173
+ order = min(len(self.prev_f), self.max_order - 1)
174
+ if order < _MIN_ORDER - 1:
175
+ # Compute using RK4.
176
+ dy = rk_common.rk4_alt_step_func(func, t, dt, y, k1=self.prev_f[0])
177
+ return dy
178
+ else:
179
+ # Adams-Bashforth predictor.
180
+ bashforth_coeffs = _BASHFORTH_COEFFICIENTS[order]
181
+ ab_div = _DIVISOR[order]
182
+ dy = tuple(dt * _scaled_dot_product(1 / ab_div, bashforth_coeffs, f_) for f_ in zip(*self.prev_f))
183
+
184
+ # Adams-Moulton corrector.
185
+ if self.implicit:
186
+ moulton_coeffs = _MOULTON_COEFFICIENTS[order + 1]
187
+ am_div = _DIVISOR[order + 1]
188
+ delta = tuple(dt * _scaled_dot_product(1 / am_div, moulton_coeffs[1:], f_) for f_ in zip(*self.prev_f))
189
+ converged = False
190
+ for _ in range(self.max_iters):
191
+ dy_old = dy
192
+ f = func(t + dt, tuple(y_ + dy_ for y_, dy_ in zip(y, dy)))
193
+ dy = tuple(dt * (moulton_coeffs[0] / am_div) * f_ + delta_ for f_, delta_ in zip(f, delta))
194
+ converged = _has_converged(dy_old, dy, self.rtol, self.atol)
195
+ if converged:
196
+ break
197
+ if not converged:
198
+ print('Warning: Functional iteration did not converge. Solution may be incorrect.', file=sys.stderr)
199
+ self.prev_f.pop()
200
+ self._update_history(t, f)
201
+ return dy
202
+
203
+ @property
204
+ def order(self):
205
+ return 4
206
+
207
+
208
+ class AdamsBashforth(AdamsBashforthMoulton):
209
+
210
+ def __init__(self, func, y0, **kwargs):
211
+ super(AdamsBashforth, self).__init__(func, y0, implicit=False, **kwargs)
ShapeID/DiffEqs/fixed_grid.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ShapeID.DiffEqs.solvers import FixedGridODESolver
2
+ import ShapeID.DiffEqs.rk_common as rk_common
3
+
4
+
5
+ class Euler(FixedGridODESolver):
6
+
7
+ def step_func(self, func, t, dt, y):
8
+ return tuple(dt * f_ for f_ in func(t, y))
9
+
10
+ @property
11
+ def order(self):
12
+ return 1
13
+
14
+
15
+ class Midpoint(FixedGridODESolver):
16
+
17
+ def step_func(self, func, t, dt, y):
18
+ y_mid = tuple(y_ + f_ * dt / 2 for y_, f_ in zip(y, func(t, y)))
19
+ return tuple(dt * f_ for f_ in func(t + dt / 2, y_mid))
20
+
21
+ @property
22
+ def order(self):
23
+ return 2
24
+
25
+
26
+ class RK4(FixedGridODESolver):
27
+
28
+ def step_func(self, func, t, dt, y):
29
+ return rk_common.rk4_alt_step_func(func, t, dt, y)
30
+
31
+ @property
32
+ def order(self):
33
+ return 4
ShapeID/DiffEqs/interp.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ShapeID.DiffEqs.misc import _convert_to_tensor, _dot_product
3
+
4
+
5
+ def _interp_fit(y0, y1, y_mid, f0, f1, dt):
6
+ """Fit coefficients for 4th order polynomial interpolation.
7
+
8
+ Args:
9
+ y0: function value at the start of the interval.
10
+ y1: function value at the end of the interval.
11
+ y_mid: function value at the mid-point of the interval.
12
+ f0: derivative value at the start of the interval.
13
+ f1: derivative value at the end of the interval.
14
+ dt: width of the interval.
15
+
16
+ Returns:
17
+ List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
18
+ `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
19
+ between 0 (start of interval) and 1 (end of interval).
20
+ """
21
+ a = tuple(
22
+ _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0_, f1_, y0_, y1_, y_mid_])
23
+ for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
24
+ )
25
+ b = tuple(
26
+ _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0_, f1_, y0_, y1_, y_mid_])
27
+ for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
28
+ )
29
+ c = tuple(
30
+ _dot_product([-4 * dt, dt, -11, -5, 16], [f0_, f1_, y0_, y1_, y_mid_])
31
+ for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
32
+ )
33
+ d = tuple(dt * f0_ for f0_ in f0)
34
+ e = y0
35
+ return [a, b, c, d, e]
36
+
37
+
38
+ def _interp_evaluate(coefficients, t0, t1, t):
39
+ """Evaluate polynomial interpolation at the given time point.
40
+
41
+ Args:
42
+ coefficients: list of Tensor coefficients as created by `interp_fit`.
43
+ t0: scalar float64 Tensor giving the start of the interval.
44
+ t1: scalar float64 Tensor giving the end of the interval.
45
+ t: scalar float64 Tensor giving the desired interpolation point.
46
+
47
+ Returns:
48
+ Polynomial interpolation of the coefficients at time `t`.
49
+ """
50
+
51
+ dtype = coefficients[0][0].dtype
52
+ device = coefficients[0][0].device
53
+
54
+ t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
55
+ t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
56
+ t = _convert_to_tensor(t, dtype=dtype, device=device)
57
+
58
+ assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)
59
+ x = ((t - t0) / (t1 - t0)).type(dtype).to(device)
60
+
61
+ xs = [torch.tensor(1).type(dtype).to(device), x]
62
+ for _ in range(2, len(coefficients)):
63
+ xs.append(xs[-1] * x)
64
+
65
+ return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients))
ShapeID/DiffEqs/misc.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import torch
3
+
4
+
5
+ def _flatten(sequence):
6
+ flat = [p.contiguous().view(-1) for p in sequence]
7
+ return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
8
+
9
+
10
+ def _flatten_convert_none_to_zeros(sequence, like_sequence):
11
+ flat = [
12
+ p.contiguous().view(-1) if p is not None else torch.zeros_like(q).view(-1)
13
+ for p, q in zip(sequence, like_sequence)
14
+ ]
15
+ return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
16
+
17
+
18
+ def _possibly_nonzero(x):
19
+ return isinstance(x, torch.Tensor) or x != 0
20
+
21
+
22
+ def _scaled_dot_product(scale, xs, ys):
23
+ """Calculate a scaled, vector inner product between lists of Tensors."""
24
+ # Using _possibly_nonzero lets us avoid wasted computation.
25
+ return sum([(scale * x) * y for x, y in zip(xs, ys) if _possibly_nonzero(x) or _possibly_nonzero(y)])
26
+
27
+
28
+ def _dot_product(xs, ys):
29
+ """Calculate the vector inner product between two lists of Tensors."""
30
+ return sum([x * y for x, y in zip(xs, ys)])
31
+
32
+
33
+ def _has_converged(y0, y1, rtol, atol):
34
+ """Checks that each element is within the error tolerance."""
35
+ error_tol = tuple(atol + rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
36
+ error = tuple(torch.abs(y0_ - y1_) for y0_, y1_ in zip(y0, y1))
37
+ return all((error_ < error_tol_).all() for error_, error_tol_ in zip(error, error_tol))
38
+
39
+
40
+ def _convert_to_tensor(a, dtype=None, device=None):
41
+ if not isinstance(a, torch.Tensor):
42
+ a = torch.tensor(a)
43
+ if dtype is not None:
44
+ a = a.type(dtype)
45
+ if device is not None:
46
+ a = a.to(device)
47
+ return a
48
+
49
+
50
+ def _is_finite(tensor):
51
+ _check = (tensor == float('inf')) + (tensor == float('-inf')) + torch.isnan(tensor)
52
+ return not _check.any()
53
+
54
+
55
+ def _decreasing(t):
56
+ return (t[1:] < t[:-1]).all()
57
+
58
+
59
+ def _assert_increasing(t):
60
+ assert (t[1:] > t[:-1]).all(), 't must be strictly increasing or decrasing'
61
+
62
+
63
+ def _is_iterable(inputs):
64
+ try:
65
+ iter(inputs)
66
+ return True
67
+ except TypeError:
68
+ return False
69
+
70
+
71
+ def _norm(x):
72
+ """Compute RMS norm."""
73
+ if torch.is_tensor(x):
74
+ return x.norm() / (x.numel()**0.5)
75
+ else:
76
+ return torch.sqrt(sum(x_.norm()**2 for x_ in x) / sum(x_.numel() for x_ in x))
77
+
78
+
79
+ def _handle_unused_kwargs(solver, unused_kwargs):
80
+ if len(unused_kwargs) > 0:
81
+ warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs))
82
+
83
+
84
+ def _select_initial_step(fun, t0, y0, order, rtol, atol, f0=None):
85
+ """Empirically select a good initial step.
86
+
87
+ The algorithm is described in [1]_.
88
+
89
+ Parameters
90
+ ----------
91
+ fun : callable
92
+ Right-hand side of the system.
93
+ t0 : float
94
+ Initial value of the independent variable.
95
+ y0 : ndarray, shape (n,)
96
+ Initial value of the dependent variable.
97
+ direction : float
98
+ Integration direction.
99
+ order : float
100
+ Method order.
101
+ rtol : float
102
+ Desired relative tolerance.
103
+ atol : float
104
+ Desired absolute tolerance.
105
+
106
+ Returns
107
+ -------
108
+ h_abs : float
109
+ Absolute value of the suggested initial step.
110
+
111
+ References
112
+ ----------
113
+ .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
114
+ Equations I: Nonstiff Problems", Sec. II.4.
115
+ """
116
+ t0 = t0.to(y0[0])
117
+ if f0 is None:
118
+ f0 = fun(t0, y0)
119
+
120
+ rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
121
+ atol = atol if _is_iterable(atol) else [atol] * len(y0)
122
+
123
+ scale = tuple(atol_ + torch.abs(y0_) * rtol_ for y0_, atol_, rtol_ in zip(y0, atol, rtol))
124
+
125
+ d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale))
126
+ d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale))
127
+
128
+ if max(d0).item() < 1e-5 or max(d1).item() < 1e-5:
129
+ h0 = torch.tensor(1e-6).to(t0)
130
+ else:
131
+ h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1))
132
+
133
+ y1 = tuple(y0_ + h0 * f0_ for y0_, f0_ in zip(y0, f0))
134
+ f1 = fun(t0 + h0, y1)
135
+
136
+ d2 = tuple(_norm((f1_ - f0_) / scale_) / h0 for f1_, f0_, scale_ in zip(f1, f0, scale))
137
+
138
+ if max(d1).item() <= 1e-15 and max(d2).item() <= 1e-15:
139
+ h1 = torch.max(torch.tensor(1e-6).to(h0), h0 * 1e-3)
140
+ else:
141
+ h1 = (0.01 / max(d1 + d2))**(1. / float(order + 1))
142
+
143
+ return torch.min(100 * h0, h1)
144
+
145
+
146
+ def _compute_error_ratio(error_estimate, error_tol=None, rtol=None, atol=None, y0=None, y1=None):
147
+ if error_tol is None:
148
+ assert rtol is not None and atol is not None and y0 is not None and y1 is not None
149
+ rtol if _is_iterable(rtol) else [rtol] * len(y0)
150
+ atol if _is_iterable(atol) else [atol] * len(y0)
151
+ error_tol = tuple(
152
+ atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_))
153
+ for atol_, rtol_, y0_, y1_ in zip(atol, rtol, y0, y1)
154
+ )
155
+ error_ratio = tuple(error_estimate_ / error_tol_ for error_estimate_, error_tol_ in zip(error_estimate, error_tol))
156
+ mean_sq_error_ratio = tuple(torch.mean(error_ratio_ * error_ratio_) for error_ratio_ in error_ratio)
157
+ return mean_sq_error_ratio
158
+
159
+
160
+ def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
161
+ """Calculate the optimal size for the next step."""
162
+ mean_error_ratio = max(mean_error_ratio) # Compute step size based on highest ratio.
163
+ if mean_error_ratio == 0:
164
+ return last_step * ifactor
165
+ if mean_error_ratio < 1:
166
+ dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
167
+ error_ratio = torch.sqrt(mean_error_ratio).to(last_step)
168
+ exponent = torch.tensor(1 / order).to(last_step)
169
+ factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
170
+ return last_step / factor
171
+
172
+
173
+ def _check_inputs(func, y0, t):
174
+ tensor_input = False
175
+ if torch.is_tensor(y0):
176
+ tensor_input = True
177
+ y0 = (y0,)
178
+ _base_nontuple_func_ = func
179
+ func = lambda t, y: (_base_nontuple_func_(t, y[0]),)
180
+ assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
181
+ for y0_ in y0:
182
+ assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(type(y0_))
183
+
184
+ if _decreasing(t):
185
+ t = -t
186
+ _base_reverse_func = func
187
+ func = lambda t, y: tuple(-f_ for f_ in _base_reverse_func(-t, y))
188
+
189
+ for y0_ in y0:
190
+ if not torch.is_floating_point(y0_):
191
+ raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type()))
192
+ if not torch.is_floating_point(t):
193
+ raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))
194
+
195
+ return tensor_input, func, y0, t
ShapeID/DiffEqs/odeint.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ShapeID.DiffEqs.tsit5 import Tsit5Solver
2
+ from ShapeID.DiffEqs.dopri5 import Dopri5Solver
3
+ from ShapeID.DiffEqs.fixed_grid import Euler, Midpoint, RK4
4
+ from ShapeID.DiffEqs.fixed_adams import AdamsBashforth, AdamsBashforthMoulton
5
+ from ShapeID.DiffEqs.adams import VariableCoefficientAdamsBashforth
6
+ from ShapeID.DiffEqs.misc import _check_inputs
7
+
8
+ SOLVERS = {
9
+ 'explicit_adams': AdamsBashforth,
10
+ 'fixed_adams': AdamsBashforthMoulton,
11
+ 'adams': VariableCoefficientAdamsBashforth,
12
+ 'tsit5': Tsit5Solver,
13
+ 'dopri5': Dopri5Solver,
14
+ 'euler': Euler,
15
+ 'midpoint': Midpoint,
16
+ 'rk4': RK4,
17
+ }
18
+
19
+
20
+ def odeint(func, y0, t, dt, step_size = None, rtol = 1e-7, atol = 1e-9, method = None, options = None):
21
+ """Integrate a system of ordinary differential equations.
22
+
23
+ Solves the initial value problem for a non-stiff system of first order ODEs:
24
+ ```
25
+ dy/dt = func(t, y), y(t[0]) = y0
26
+ ```
27
+ where y is a Tensor of any shape.
28
+
29
+ Output dtypes and numerical precision are based on the dtypes of the inputs `y0`.
30
+
31
+ Args:
32
+ func: Function that maps a Tensor holding the state `y` and a scalar Tensor
33
+ `t` into a Tensor of state derivatives with respect to time.
34
+ y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
35
+ have any floating point or complex dtype.
36
+ t: 1-D Tensor holding a sequence of time points for which to solve for
37
+ `y`. The initial time point should be the first element of this sequence,
38
+ and each time must be larger than the previous time. May have any floating
39
+ point dtype. Converted to a Tensor with float64 dtype.
40
+ rtol: optional float64 Tensor specifying an upper bound on relative error,
41
+ per element of `y`.
42
+ atol: optional float64 Tensor specifying an upper bound on absolute error,
43
+ per element of `y`.
44
+ method: optional string indicating the integration method to use.
45
+ options: optional dict of configuring options for the indicated integration
46
+ method. Can only be provided if a `method` is explicitly set.
47
+ name: Optional name for this operation.
48
+
49
+ Returns:
50
+ y: Tensor, where the first dimension corresponds to different
51
+ time points. Contains the solved value of y for each desired time point in
52
+ `t`, with the initial value `y0` being the first element along the first
53
+ dimension.
54
+
55
+ Raises:
56
+ ValueError: if an invalid `method` is provided.
57
+ TypeError: if `options` is supplied without `method`, or if `t` or `y0` has
58
+ an invalid dtype.
59
+ """
60
+
61
+ tensor_input, func, y0, t = _check_inputs(func, y0, t)
62
+
63
+ if options and method is None:
64
+ raise ValueError('cannot supply `options` without specifying `method`')
65
+
66
+ if method is None:
67
+ method = 'dopri5'
68
+
69
+ #solver = SOLVERS[method](func, y0, rtol = rtol, atol = atol, **options)
70
+ solver = SOLVERS[method](func, y0, rtol = rtol, atol = atol, dt = dt, options = options)
71
+ solution = solver.integrate(t)
72
+
73
+ if tensor_input:
74
+ solution = solution[0]
75
+ return solution
ShapeID/DiffEqs/pde.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ported from https://github.com/pvigier/perlin-numpy
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+
12
+
13
+ def gradient_f(X, batched = False, delta_lst = [1., 1., 1.]):
14
+ '''
15
+ Compute gradient of a torch tensor "X" in each direction
16
+ Upper-boundaries: Backward Difference
17
+ Non-boundaries & Upper-boundaries: Forward Difference
18
+ if X is batched: (n_batch, ...);
19
+ else: (...)
20
+ '''
21
+ device = X.device
22
+ dim = len(X.size()) - 1 if batched else len(X.size())
23
+ #print(batched)
24
+ #print(dim)
25
+ if dim == 1:
26
+ #print('dim = 1')
27
+ dX = torch.zeros(X.size(), dtype = torch.float, device = device)
28
+ X = X.permute(1, 0) if batched else X
29
+ dX = dX.permute(1, 0) if batched else dX
30
+ dX[-1] = X[-1] - X[-2] # Backward Difference
31
+ dX[:-1] = X[1:] - X[:-1] # Forward Difference
32
+
33
+ dX = dX.permute(1, 0) if batched else dX
34
+ dX /= delta_lst[0]
35
+ elif dim == 2:
36
+ #print('dim = 2')
37
+ dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
38
+ X = X.permute(1, 2, 0) if batched else X
39
+ dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
40
+ dX[-1, :, 0] = X[-1, :] - X[-2, :] # Backward Difference
41
+ dX[:-1, :, 0] = X[1:] - X[:-1] # Forward Difference
42
+
43
+ dX[:, -1, 1] = X[:, -1] - X[:, -2] # Backward Difference
44
+ dX[:, :-1, 1] = X[:, 1:] - X[:, :-1] # Forward Difference
45
+
46
+ dX = dX.permute(3, 0, 1, 2) if batched else dX
47
+ dX[..., 0] /= delta_lst[0]
48
+ dX[..., 1] /= delta_lst[1]
49
+ elif dim == 3:
50
+ #print('dim = 3')
51
+ dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
52
+ X = X.permute(1, 2, 3, 0) if batched else X
53
+ dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
54
+ dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :] # Backward Difference
55
+ dX[:-1, :, :, 0] = X[1:] - X[:-1] # Forward Difference
56
+
57
+ dX[:, -1, :, 1] = X[:, -1] - X[:, -2] # Backward Difference
58
+ dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1] # Forward Difference
59
+
60
+ dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2] # Backward Difference
61
+ dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1] # Forward Difference
62
+
63
+ dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
64
+ dX[..., 0] /= delta_lst[0]
65
+ dX[..., 1] /= delta_lst[1]
66
+ dX[..., 2] /= delta_lst[2]
67
+ return dX
68
+
69
+
70
+ def gradient_b(X, batched = False, delta_lst = [1., 1., 1.]):
71
+ '''
72
+ Compute gradient of a torch tensor "X" in each direction
73
+ Non-boundaries & Upper-boundaries: Backward Difference
74
+ Lower-boundaries: Forward Difference
75
+ if X is batched: (n_batch, ...);
76
+ else: (...)
77
+ '''
78
+ device = X.device
79
+ dim = len(X.size()) - 1 if batched else len(X.size())
80
+ #print(batched)
81
+ #print(dim)
82
+ if dim == 1:
83
+ #print('dim = 1')
84
+ dX = torch.zeros(X.size(), dtype = torch.float, device = device)
85
+ X = X.permute(1, 0) if batched else X
86
+ dX = dX.permute(1, 0) if batched else dX
87
+ dX[1:] = X[1:] - X[:-1] # Backward Difference
88
+ dX[0] = X[1] - X[0] # Forward Difference
89
+
90
+ dX = dX.permute(1, 0) if batched else dX
91
+ dX /= delta_lst[0]
92
+ elif dim == 2:
93
+ #print('dim = 2')
94
+ dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
95
+ X = X.permute(1, 2, 0) if batched else X
96
+ dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
97
+ dX[1:, :, 0] = X[1:, :] - X[:-1, :] # Backward Difference
98
+ dX[0, :, 0] = X[1] - X[0] # Forward Difference
99
+
100
+ dX[:, 1:, 1] = X[:, 1:] - X[:, :-1] # Backward Difference
101
+ dX[:, 0, 1] = X[:, 1] - X[:, 0] # Forward Difference
102
+
103
+ dX = dX.permute(3, 0, 1, 2) if batched else dX
104
+ dX[..., 0] /= delta_lst[0]
105
+ dX[..., 1] /= delta_lst[1]
106
+ elif dim == 3:
107
+ #print('dim = 3')
108
+ dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
109
+ X = X.permute(1, 2, 3, 0) if batched else X
110
+ dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
111
+ dX[1:, :, :, 0] = X[1:, :, :] - X[:-1, :, :] # Backward Difference
112
+ dX[0, :, :, 0] = X[1] - X[0] # Forward Difference
113
+
114
+ dX[:, 1:, :, 1] = X[:, 1:] - X[:, :-1] # Backward Difference
115
+ dX[:, 0, :, 1] = X[:, 1] - X[:, 0] # Forward Difference
116
+
117
+ dX[:, :, 1:, 2] = X[:, :, 1:] - X[:, :, :-1] # Backward Difference
118
+ dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0] # Forward Difference
119
+
120
+ dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
121
+ dX[..., 0] /= delta_lst[0]
122
+ dX[..., 1] /= delta_lst[1]
123
+ dX[..., 2] /= delta_lst[2]
124
+ return dX
125
+
126
+
127
def gradient_c(X, batched = False, delta_lst = (1., 1., 1.)):
    '''
    Compute the gradient of a torch tensor "X" along each spatial direction.
    Non-boundaries: Central Difference
    Lower boundary (index 0): Forward Difference
    Upper boundary (index -1): Backward Difference

    Args:
        X: input tensor; if batched: (n_batch, ...); else: (...)
        batched: whether the first dimension of X is a batch dimension
        delta_lst: grid spacing along each spatial direction
            (changed from a mutable list default to a tuple; same values)
    Returns:
        dX: float tensor with the same (batched) spatial shape as X plus a
            trailing dimension of size `dim` holding the per-direction
            derivatives (no trailing dimension when dim == 1)
    Raises:
        NotImplementedError: if the spatial dimensionality is not 1, 2 or 3
            (previously this fell through and raised NameError on `dX`).
    '''
    device = X.device
    dim = len(X.size()) - 1 if batched else len(X.size())
    if dim == 1:
        dX = torch.zeros(X.size(), dtype = torch.float, device = device)
        # Move the batch axis last so the spatial axis is axis 0.
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[1:-1] = (X[2:] - X[:-2]) / 2  # central difference in the interior
        dX[0] = X[1] - X[0]              # forward difference at the lower boundary
        dX[-1] = X[-1] - X[-2]           # backward difference at the upper boundary

        dX = dX.permute(1, 0) if batched else dX  # restore batch axis
        dX /= delta_lst[0]
    elif dim == 2:
        dX = torch.zeros(X.size() + (2,), dtype = torch.float, device = device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX  # put batch to last dim
        dX[1:-1, :, 0] = (X[2:, :] - X[:-2, :]) / 2
        dX[0, :, 0] = X[1] - X[0]
        dX[-1, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, 1] = (X[:, 2:] - X[:, :-2]) / 2
        dX[:, 0, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, 1] = X[:, -1] - X[:, -2]

        dX = dX.permute(3, 0, 1, 2) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = torch.zeros(X.size() + (3,), dtype = torch.float, device = device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX
        dX[1:-1, :, :, 0] = (X[2:, :, :] - X[:-2, :, :]) / 2
        dX[0, :, :, 0] = X[1] - X[0]
        dX[-1, :, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, :, 1] = (X[:, 2:, :] - X[:, :-2, :]) / 2
        dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
        dX[:, :, 1:-1, 2] = (X[:, :, 2:] - X[:, :, :-2]) / 2
        dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]

        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    else:
        raise NotImplementedError('gradient_c supports 1D/2D/3D inputs only')
    return dX
187
+
188
+
189
def gradient_c_numpy(X, batched = False, delta_lst = (1., 1., 1.)):
    '''
    Compute the gradient of a NumPy array "X" along each spatial direction.
    Non-boundaries: Central Difference
    Lower boundary (index 0): Forward Difference
    Upper boundary (index -1): Backward Difference

    Args:
        X: input array; if batched: (n_batch, ...); else: (...)
        batched: whether the first axis of X is a batch axis
        delta_lst: grid spacing along each spatial direction
    Returns:
        dX: float array with the same (batched) spatial shape as X plus a
            trailing axis of size `dim` holding the per-direction derivatives
            (no trailing axis when dim == 1)
    '''
    dim = len(X.shape) - 1 if batched else len(X.shape)
    if dim == 1:
        # Move the batch axis last so the spatial axis is axis 0.
        X = np.transpose(X, (1, 0)) if batched else X
        dX = np.zeros(X.shape).astype(float)  # BUGFIX: was `X.shapee` (AttributeError)
        dX[1:-1] = (X[2:] - X[:-2]) / 2  # Central Difference
        dX[0] = X[1] - X[0]              # Forward Difference
        dX[-1] = X[-1] - X[-2]           # Backward Difference

        # BUGFIX: was `np.transpose(X, (1, 0))`, which returned the transposed
        # input instead of the computed gradient in the batched case.
        dX = np.transpose(dX, (1, 0)) if batched else dX
        dX /= delta_lst[0]
    elif dim == 2:
        dX = np.zeros(X.shape + (2,)).astype(float)
        X = np.transpose(X, (1, 2, 0)) if batched else X
        dX = np.transpose(dX, (1, 2, 3, 0)) if batched else dX  # put batch to last axis
        dX[1:-1, :, 0] = (X[2:, :] - X[:-2, :]) / 2
        dX[0, :, 0] = X[1] - X[0]
        dX[-1, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, 1] = (X[:, 2:] - X[:, :-2]) / 2
        dX[:, 0, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, 1] = X[:, -1] - X[:, -2]

        dX = np.transpose(dX, (3, 0, 1, 2)) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = np.zeros(X.shape + (3,)).astype(float)
        X = np.transpose(X, (1, 2, 3, 0)) if batched else X
        dX = np.transpose(dX, (1, 2, 3, 4, 0)) if batched else dX  # put batch to last axis
        dX[1:-1, :, :, 0] = (X[2:, :, :] - X[:-2, :, :]) / 2
        dX[0, :, :, 0] = X[1] - X[0]
        dX[-1, :, :, 0] = X[-1] - X[-2]
        dX[:, 1:-1, :, 1] = (X[:, 2:, :] - X[:, :-2, :]) / 2
        dX[:, 0, :, 1] = X[:, 1] - X[:, 0]
        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]
        dX[:, :, 1:-1, 2] = (X[:, :, 2:] - X[:, :, :-2]) / 2
        dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0]
        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]

        dX = np.transpose(dX, (4, 0, 1, 2, 3)) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    else:
        raise NotImplementedError('gradient_c_numpy supports 1D/2D/3D inputs only')
    return dX
245
+
246
+
247
def gradient_f_numpy(X, batched = False, delta_lst = (1., 1., 1.)):
    '''
    Compute the gradient of a NumPy array "X" along each spatial direction.
    Upper boundary (index -1): Backward Difference
    Everywhere else: Forward Difference

    Args:
        X: input array; if batched: (n_batch, ...); else: (...)
        batched: whether the first axis of X is a batch axis
        delta_lst: grid spacing along each spatial direction
    Returns:
        dX: float array with the same (batched) spatial shape as X plus a
            trailing axis of size `dim` holding the per-direction derivatives
            (no trailing axis when dim == 1)
    '''
    dim = len(X.shape) - 1 if batched else len(X.shape)
    if dim == 1:
        # Move the batch axis last so the spatial axis is axis 0.
        X = np.transpose(X, (1, 0)) if batched else X
        dX = np.zeros(X.shape).astype(float)  # BUGFIX: was `X.shapee` (AttributeError)
        dX[-1] = X[-1] - X[-2]  # Backward Difference
        dX[:-1] = X[1:] - X[:-1]  # Forward Difference

        # BUGFIX: was `np.transpose(X, (1, 0))`, which returned the transposed
        # input instead of the computed gradient in the batched case.
        dX = np.transpose(dX, (1, 0)) if batched else dX
        dX /= delta_lst[0]
    elif dim == 2:
        dX = np.zeros(X.shape + (2,)).astype(float)
        X = np.transpose(X, (1, 2, 0)) if batched else X
        dX = np.transpose(dX, (1, 2, 3, 0)) if batched else dX  # put batch to last axis
        dX[-1, :, 0] = X[-1, :] - X[-2, :]  # Backward Difference
        dX[:-1, :, 0] = X[1:] - X[:-1]      # Forward Difference

        dX[:, -1, 1] = X[:, -1] - X[:, -2]  # Backward Difference
        dX[:, :-1, 1] = X[:, 1:] - X[:, :-1]  # Forward Difference

        dX = np.transpose(dX, (3, 0, 1, 2)) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = np.zeros(X.shape + (3,)).astype(float)
        X = np.transpose(X, (1, 2, 3, 0)) if batched else X
        dX = np.transpose(dX, (1, 2, 3, 4, 0)) if batched else dX  # put batch to last axis
        dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :]  # Backward Difference
        dX[:-1, :, :, 0] = X[1:] - X[:-1]            # Forward Difference

        dX[:, -1, :, 1] = X[:, -1] - X[:, -2]        # Backward Difference
        dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1]      # Forward Difference

        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2]  # Backward Difference
        dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1]  # Forward Difference

        dX = np.transpose(dX, (4, 0, 1, 2, 3)) if batched else dX
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    else:
        raise NotImplementedError('gradient_f_numpy supports 1D/2D/3D inputs only')
    return dX
299
+
300
+
301
class Upwind(object):
    '''
    First-order upwind differencing helper for a field U.
    For each axis, the derivative uses the Backward difference where the
    driving term is > 0 and the Forward difference where it is <= 0.
    Relies on the module-level gradient_f / gradient_b helpers.
    '''
    def __init__(self, U, data_spacing = [1., 1, 1.], batched = True):
        # NOTE(review): mutable default argument (and a stray int 1); harmless
        # here since it is never mutated, but a tuple would be safer.
        self.U = U  # field tensor, e.g. (s, r, c); batch axis first when batched
        self.batched = batched
        self.data_spacing = data_spacing  # grid spacing per spatial axis
        # Number of spatial dimensions (batch axis excluded when batched).
        self.dim = len(self.U.size()) - 1 if batched else len(self.U.size())
        # All-ones tensor used to build the (1 - flag) masks below.
        self.I = torch.ones(self.U.size(), dtype = torch.float, device = U.device)

    def dX(self, FGx):
        # Upwinded d/dx: backward difference where FGx > 0, forward elsewhere.
        dXf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 0]
        dXb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 0]
        Xflag = (FGx > 0).float()
        return dXf * (self.I - Xflag) + dXb * Xflag

    def dY(self, FGy):
        # Upwinded d/dy: backward difference where FGy > 0, forward elsewhere.
        dYf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 1]
        dYb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 1]
        Yflag = (FGy > 0).float()
        return dYf * (self.I - Yflag) + dYb * Yflag

    def dZ(self, FGz):
        # Upwinded d/dz: backward difference where FGz > 0, forward elsewhere.
        dZf = gradient_f(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 2]
        dZb = gradient_b(self.U, batched = self.batched, delta_lst = self.data_spacing)[..., 2]
        Zflag = (FGz > 0).float()
        return dZf * (self.I - Zflag) + dZb * Zflag
329
+
330
+
331
class AdvDiffPartial(nn.Module):
    '''
    Finite-difference spatial operators for the advection-diffusion equation.

    Grad_*D methods return the diffusion contribution (divergence of
    D * grad C) for the various diffusivity parameterizations; Grad_*V
    methods return the advection contribution for the various velocity
    parameterizations. All operate on batched fields C of shape
    (n_batch, (slc,) row, col), using the module-level gradient_b /
    gradient_f / gradient_c helpers and the Upwind class.
    '''
    def __init__(self, data_spacing, device):
        super(AdvDiffPartial, self).__init__()
        self.dimension = len(data_spacing) # (slc, row, col)
        self.device = device
        self.data_spacing = data_spacing  # grid spacing per spatial axis

    @property
    def Grad_Ds(self):
        # Dispatch table: diffusivity type -> diffusion-term operator.
        # All "full_*" variants share the same spatial operator.
        return {
            'constant': self.Grad_constantD,
            'scalar': self.Grad_scalarD,
            'diag': self.Grad_diagD,
            'full': self.Grad_fullD,
            'full_dual': self.Grad_fullD,
            'full_spectral':self.Grad_fullD,
            'full_cholesky': self.Grad_fullD,
            'full_symmetric': self.Grad_fullD
        }
    @property
    def Grad_Vs(self):
        # Dispatch table: velocity type -> advection-term operator.
        return {
            'constant': self.Grad_constantV,
            'scalar': self.Grad_scalarV,
            'vector': self.Grad_vectorV, # For general V w/o div-free TODO self.Grad_vectorV
            'vector_div_free': self.Grad_div_free_vectorV,
            'vector_div_free_clebsch': self.Grad_div_free_vectorV,
            'vector_div_free_stream': self.Grad_div_free_vectorV,
            'vector_div_free_stream_gauge': self.Grad_div_free_vectorV,
        }

    def Grad_constantD(self, C, Dlst):
        # Constant scalar diffusivity: D * Laplacian(C)
        # (Laplacian built from nested forward/backward differences, see ddXc).
        if self.dimension == 1:
            return Dlst['D'] * (self.ddXc(C))
        elif self.dimension == 2:
            return Dlst['D'] * (self.ddXc(C) + self.ddYc(C))
        elif self.dimension == 3:
            return Dlst['D'] * (self.ddXc(C) + self.ddYc(C) + self.ddZc(C))

    def Grad_constant_tensorD(self, C, Dlst):
        # Spatially-uniform tensor diffusivity: sum_ij Dij * d_i d_j C.
        # NOTE(review): not listed in the Grad_Ds dispatch table.
        if self.dimension == 1:
            raise NotImplementedError
        elif self.dimension == 2:
            dC_c = self.dc(C)  # NOTE(review): computed but unused in this branch
            dC_f = self.df(C)
            return Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
                Dlst['Dxy'] * self.dXb(dC_f[..., 1]) + Dlst['Dxy'] * self.dYb(dC_f[..., 0]) +\
                Dlst['Dyy'] * self.dYb(dC_f[..., 1])
        elif self.dimension == 3:
            dC_c = self.dc(C)  # NOTE(review): computed but unused in this branch
            dC_f = self.df(C)
            return Dlst['Dxx'] * self.dXb(dC_f[..., 0]) + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) + Dlst['Dzz'] * self.dZb(dC_f[..., 2]) + \
                Dlst['Dxy'] * (self.dXb(dC_f[..., 1]) + self.dYb(dC_f[..., 0])) + \
                Dlst['Dyz'] * (self.dYb(dC_f[..., 2]) + self.dZb(dC_f[..., 1])) + \
                Dlst['Dxz'] * (self.dZb(dC_f[..., 0]) + self.dXb(dC_f[..., 2]))

    def Grad_scalarD(self, C, Dlst): # batch_C: (batch_size, (slc), row, col)
        # Expanded version: \nabla (D \nabla C) => \nabla D \cdot \nabla C (part (a)) + D \Delta C (part (b)) #
        # NOTE: Work better than Central Differences !!! #
        # Nested Forward-Backward Difference Scheme in part (b)#
        if self.dimension == 1:
            dC = gradient_c(C, batched = True, delta_lst = self.data_spacing)
            return gradient_c(Dlst['D'], batched = True, delta_lst = self.data_spacing) * dC + \
                Dlst['D'] * gradient_c(dC, batched = True, delta_lst = self.data_spacing)
        else: # (dimension = 2 or 3)
            dC_c = gradient_c(C, batched = True, delta_lst = self.data_spacing)
            dC_f = gradient_f(C, batched = True, delta_lst = self.data_spacing)
            dD_c = gradient_c(Dlst['D'], batched = True, delta_lst = self.data_spacing)
            # part (a): grad D . grad C (central differences)
            out = (dD_c * dC_c).sum(-1)
            # part (b): D * Laplacian C via nested forward-then-backward differences
            for dim in range(dC_f.size(-1)):
                out += Dlst['D'] * gradient_b(dC_f[..., dim], batched = True, delta_lst = self.data_spacing)[..., dim]
            return out

    def Grad_diagD(self, C, Dlst):
        # Diagonal tensor diffusivity, expanded product rule per axis.
        if self.dimension == 1:
            raise NotImplementedError('diag_D is not supported for 1D version of diffusivity')
        elif self.dimension == 2:
            dC_c = self.dc(C)
            dC_f = self.df(C)
            return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
                self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1])
        elif self.dimension == 3:
            dC_c = self.dc(C)
            dC_f = self.df(C)
            return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
                self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) +\
                self.dZc(Dlst['Dzz']) * dC_c[..., 2] + Dlst['Dzz'] * self.dZb(dC_f[..., 2])

    def Grad_fullD(self, C, Dlst):
        # Full symmetric tensor diffusivity, expanded product rule.
        '''https://github.com/uncbiag/PIANOinD/blob/master/Doc/PIANOinD.pdf'''
        if self.dimension == 1:
            raise NotImplementedError('full_D is not supported for 1D version of diffusivity')
        elif self.dimension == 2:
            dC_c = self.dc(C)
            dC_f = self.df(C)
            return self.dXc(Dlst['Dxx']) * dC_c[..., 0] + Dlst['Dxx'] * self.dXb(dC_f[..., 0]) +\
                self.dXc(Dlst['Dxy']) * dC_c[..., 1] + Dlst['Dxy'] * self.dXb(dC_f[..., 1]) +\
                self.dYc(Dlst['Dxy']) * dC_c[..., 0] + Dlst['Dxy'] * self.dYb(dC_f[..., 0]) +\
                self.dYc(Dlst['Dyy']) * dC_c[..., 1] + Dlst['Dyy'] * self.dYb(dC_f[..., 1])
        elif self.dimension == 3:
            dC_c = self.dc(C)
            dC_f = self.df(C)
            return (self.dXc(Dlst['Dxx']) + self.dYc(Dlst['Dxy']) + self.dZc(Dlst['Dxz'])) * dC_c[..., 0] + \
                (self.dXc(Dlst['Dxy']) + self.dYc(Dlst['Dyy']) + self.dZc(Dlst['Dyz'])) * dC_c[..., 1] + \
                (self.dXc(Dlst['Dxz']) + self.dYc(Dlst['Dyz']) + self.dZc(Dlst['Dzz'])) * dC_c[..., 2] + \
                Dlst['Dxx'] * self.dXb(dC_f[..., 0]) + Dlst['Dyy'] * self.dYb(dC_f[..., 1]) + Dlst['Dzz'] * self.dZb(dC_f[..., 2]) + \
                Dlst['Dxy'] * (self.dXb(dC_f[..., 1]) + self.dYb(dC_f[..., 0])) + \
                Dlst['Dyz'] * (self.dYb(dC_f[..., 2]) + self.dZb(dC_f[..., 1])) + \
                Dlst['Dxz'] * (self.dZb(dC_f[..., 0]) + self.dXb(dC_f[..., 2]))

    def Grad_constantV(self, C, Vlst):
        # Constant scalar velocity; upwinding chosen from the sign of V
        # (first element of V when V is a field rather than a 1-element tensor).
        if len(Vlst['V'].size()) == 1:
            if self.dimension == 1:
                return - Vlst['V'] * self.dXb(C) if Vlst['V'] > 0 else - Vlst['V'] * self.dXf(C)
            elif self.dimension == 2:
                return - Vlst['V'] * (self.dXb(C) + self.dYb(C)) if Vlst['V'] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C))
            elif self.dimension == 3:
                return - Vlst['V'] * (self.dXb(C) + self.dYb(C) + self.dZb(C)) if Vlst['V'] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C) + self.dZf(C))
        else:
            if self.dimension == 1:
                return - Vlst['V'] * self.dXb(C) if Vlst['V'][0, 0] > 0 else - Vlst['V'] * self.dXf(C)
            elif self.dimension == 2:
                return - Vlst['V'] * (self.dXb(C) + self.dYb(C)) if Vlst['V'][0, 0, 0] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C))
            elif self.dimension == 3:
                return - Vlst['V'] * (self.dXb(C) + self.dYb(C) + self.dZb(C)) if Vlst['V'][0, 0, 0, 0] > 0 else - Vlst['V'] * (self.dXf(C) + self.dYf(C) + self.dZf(C))

    def Grad_constant_vectorV(self, C, Vlst):
        # Constant vector velocity; sign of each component selects upwinding.
        # NOTE(review): not listed in the Grad_Vs dispatch table.
        if self.dimension == 1:
            raise NotImplementedError
        elif self.dimension == 2:
            out_x = - Vlst['Vx'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vx'][0, 0, 0] > 0 else - Vlst['Vx'] * (self.dXf(C) + self.dYf(C))
            out_y = - Vlst['Vy'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vy'][0, 0, 0] > 0 else - Vlst['Vy'] * (self.dXf(C) + self.dYf(C))
            return out_x + out_y
        elif self.dimension == 3:
            # NOTE(review): the 3D branch uses only dXb/dYb (no dZb/dZf) for
            # all three components — looks like copy-paste from 2D; confirm
            # before relying on this method.
            out_x = - Vlst['Vx'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vx'][0, 0, 0] > 0 else - Vlst['Vx'] * (self.dXf(C) + self.dYf(C))
            out_y = - Vlst['Vy'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vy'][0, 0, 0] > 0 else - Vlst['Vy'] * (self.dXf(C) + self.dYf(C))
            out_z = - Vlst['Vz'] * (self.dXb(C) + self.dYb(C)) if Vlst['Vz'][0, 0, 0] > 0 else - Vlst['Vz'] * (self.dXf(C) + self.dYf(C))
            return out_x + out_y + out_z

    def Grad_SimscalarV(self, C, Vlst):
        # Scalar velocity advection WITHOUT the divergence term: -V . grad C.
        # NOTE(review): not listed in the Grad_Vs dispatch table.
        V = Vlst['V']
        Upwind_C = Upwind(C, self.data_spacing)
        if self.dimension == 1:
            C_x = Upwind_C.dX(V)
            return - V * C_x
        if self.dimension == 2:
            C_x, C_y = Upwind_C.dX(V), Upwind_C.dY(V)
            return - V * (C_x + C_y)
        if self.dimension == 3:
            C_x, C_y, C_z = Upwind_C.dX(V), Upwind_C.dY(V), Upwind_C.dZ(V)
            return - V * (C_x + C_y + C_z)

    def Grad_scalarV(self, C, Vlst):
        # Scalar velocity advection in conservation form:
        # -V . grad C (upwinded) - C * div V (central).
        V = Vlst['V']
        Upwind_C = Upwind(C, self.data_spacing)
        dV = gradient_c(V, batched = True, delta_lst = self.data_spacing)
        if self.dimension == 1:
            C_x = Upwind_C.dX(V)
            return - V * C_x - C * dV
        elif self.dimension == 2:
            C_x, C_y = Upwind_C.dX(V), Upwind_C.dY(V)
            return - V * (C_x + C_y) - C * dV.sum(-1)
        elif self.dimension == 3:
            C_x, C_y, C_z = Upwind_C.dX(V), Upwind_C.dY(V), Upwind_C.dZ(V)
            return - V * (C_x + C_y + C_z) - C * dV.sum(-1)

    def Grad_div_free_vectorV(self, C, Vlst):
        ''' For divergence-free-by-definition velocity'''
        # div V == 0 by construction, so only the -V . grad C term is needed.
        if self.dimension == 1:
            raise NotImplementedError('clebschVector is not supported for 1D version of velocity')
        Upwind_C = Upwind(C, self.data_spacing)
        C_x, C_y = Upwind_C.dX(Vlst['Vx']), Upwind_C.dY(Vlst['Vy'])
        if self.dimension == 2:
            return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y)
        elif self.dimension == 3:
            C_z = Upwind_C.dZ(Vlst['Vz'])
            return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y + Vlst['Vz'] * C_z)

    def Grad_vectorV(self, C, Vlst):
        ''' For general velocity'''
        # -V . grad C (upwinded) - C * div V (central differences).
        if self.dimension == 1:
            raise NotImplementedError('vector is not supported for 1D version of velocity')
        Upwind_C = Upwind(C, self.data_spacing)
        C_x, C_y = Upwind_C.dX(Vlst['Vx']), Upwind_C.dY(Vlst['Vy'])
        Vx_x = self.dXc(Vlst['Vx'])
        Vy_y = self.dYc(Vlst['Vy'])
        if self.dimension == 2:
            return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y) - C * (Vx_x + Vy_y)
        if self.dimension == 3:
            C_z = Upwind_C.dZ(Vlst['Vz'])
            Vz_z = self.dZc(Vlst['Vz'])
            return - (Vlst['Vx'] * C_x + Vlst['Vy'] * C_y + Vlst['Vz'] * C_z) - C * (Vx_x + Vy_y + Vz_z)

    ################# Utilities #################
    # Thin wrappers over the module-level finite-difference helpers:
    # b = backward, f = forward, c = central; X/Y/Z select the axis component.
    def db(self, X):
        return gradient_b(X, batched = True, delta_lst = self.data_spacing)
    def df(self, X):
        return gradient_f(X, batched = True, delta_lst = self.data_spacing)
    def dc(self, X):
        return gradient_c(X, batched = True, delta_lst = self.data_spacing)
    def dXb(self, X):
        return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 0]
    def dXf(self, X):
        return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 0]
    def dXc(self, X):
        return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 0]
    def dYb(self, X):
        return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 1]
    def dYf(self, X):
        return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 1]
    def dYc(self, X):
        return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 1]
    def dZb(self, X):
        return gradient_b(X, batched = True, delta_lst = self.data_spacing)[..., 2]
    def dZf(self, X):
        return gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 2]
    def dZc(self, X):
        return gradient_c(X, batched = True, delta_lst = self.data_spacing)[..., 2]
    # Second derivatives via nested forward-then-backward differences.
    def ddXc(self, X):
        return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 0],
            batched = True, delta_lst = self.data_spacing)[..., 0]
    def ddYc(self, X):
        return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 1],
            batched = True, delta_lst = self.data_spacing)[..., 1]
    def ddZc(self, X):
        return gradient_b(gradient_f(X, batched = True, delta_lst = self.data_spacing)[..., 2],
            batched = True, delta_lst = self.data_spacing)[..., 2]
560
+
561
+
562
+
563
class AdvDiffPDE(nn.Module):
    '''
    Plain advection-diffusion PDE solver for pre-set V_lst and D_lst
    (1D, 2D, 3D) for forward time-series simulation. `forward(t, C)` returns
    dC/dt, so the module can be handed to an ODE integrator.

    Args:
        data_spacing: grid spacing per spatial axis; its length sets the dimension.
        perf_pattern: which terms to apply ('adv' only, 'diff' only, or both).
        D_type / V_type: keys into AdvDiffPartial.Grad_Ds / Grad_Vs.
        BC: boundary-condition mode (None, 'neumann', 'cauchy',
            'dirichlet_neumann', 'source_neumann').
        dt: time-step size (used only to scale the stochastic noise).
        V_dict / D_dict: pre-set velocity / diffusivity fields.
        stochastic: if True, add Sigma * sqrt(dt) * N(0, 1) noise to the RHS.
    '''
    def __init__(self, data_spacing, perf_pattern, D_type='scalar', V_type='vector', BC=None, dt=0.1, V_dict=None, D_dict=None, stochastic=False, device='cpu'):
        super(AdvDiffPDE, self).__init__()
        self.BC = BC
        self.dt = dt
        self.dimension = len(data_spacing)
        self.perf_pattern = perf_pattern
        self.partials = AdvDiffPartial(data_spacing, device)
        self.D_type, self.V_type = D_type, V_type
        self.stochastic = stochastic
        # FIX: avoid mutable {} default arguments (shared across instances).
        self.V_dict = {} if V_dict is None else V_dict
        self.D_dict = {} if D_dict is None else D_dict
        self.Sigma, self.Sigma_V, self.Sigma_D = 0., 0., 0.  # Only for initialization #
        # Replication padding re-imposes the boundary values after each update.
        if self.dimension == 1:
            self.neumann_BC = torch.nn.ReplicationPad1d(1)
        elif self.dimension == 2:
            self.neumann_BC = torch.nn.ReplicationPad2d(1)
        elif self.dimension == 3:
            self.neumann_BC = torch.nn.ReplicationPad3d(1)
        else:
            raise ValueError('Unsupported dimension: %d' % self.dimension)

    @property
    def set_BC(self):
        # NOTE For boundary condition of mass concentration #
        '''Return a callable X -> X with the boundary condition re-imposed.
        X: (n_batch, spatial_shape)'''
        if self.BC == 'neumann' or self.BC == 'cauchy':
            if self.dimension == 1:
                return lambda X: self.neumann_BC(X[:, 1:-1].unsqueeze(dim=1))[:,0]
            elif self.dimension == 2:
                return lambda X: self.neumann_BC(X[:, 1:-1, 1:-1].unsqueeze(dim=1))[:,0]
            elif self.dimension == 3:
                return lambda X: self.neumann_BC(X[:, 1:-1, 1:-1, 1:-1].unsqueeze(dim=1))[:,0]
            else:
                raise NotImplementedError('Unsupported B.C.!')
        elif self.BC == 'dirichlet_neumann' or self.BC == 'source_neumann':
            ctrl_wdth = 1
            if self.dimension == 1:
                self.dirichlet_BC = torch.nn.ReplicationPad1d(ctrl_wdth)
                return lambda X: self.dirichlet_BC(X[:, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0]
            elif self.dimension == 2:
                self.dirichlet_BC = torch.nn.ReplicationPad2d(ctrl_wdth)
                return lambda X: self.dirichlet_BC(X[:, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0]
            elif self.dimension == 3:
                self.dirichlet_BC = torch.nn.ReplicationPad3d(ctrl_wdth)
                # BUGFIX: was `self.neumann_dirichlet_BCBC(...)` — an undefined
                # attribute that raised AttributeError for 3D dirichlet/source BCs.
                return lambda X: self.dirichlet_BC(X[:, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth, ctrl_wdth : -ctrl_wdth].unsqueeze(dim=1))[:,0]
            else:
                raise NotImplementedError('Unsupported B.C.!')
        else:
            # No (or unrecognized) BC mode: identity mapping.
            return lambda X: X

    def forward(self, t, batch_C):
        '''
        Compute dC/dt at time t.
        t: (batch_size,)
        batch_C: (batch_size, (slc,) row, col)
        '''
        batch_C = self.set_BC(batch_C)
        if 'diff' not in self.perf_pattern:  # advection only
            out = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict)
            if self.stochastic:
                out = out + self.Sigma * math.sqrt(self.dt) * torch.randn_like(batch_C).to(batch_C)
        elif 'adv' not in self.perf_pattern:  # diffusion only
            out = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict)
            if self.stochastic:
                out = out + self.Sigma * math.sqrt(self.dt) * torch.randn_like(batch_C).to(batch_C)
        else:  # advection + diffusion
            if self.stochastic:
                out_D = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict)
                out_V = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict)
                out = out_D + out_V + self.Sigma * math.sqrt(self.dt) * torch.randn_like(batch_C).to(batch_C)
            else:
                out_V = self.partials.Grad_Vs[self.V_type](batch_C, self.V_dict)
                out_D = self.partials.Grad_Ds[self.D_type](batch_C, self.D_dict)
                out = out_V + out_D
        return out
641
+
642
+
643
+
ShapeID/DiffEqs/rk_common.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate
2
+ import collections
3
+ from ShapeID.DiffEqs.misc import _scaled_dot_product, _convert_to_tensor
4
+
5
# Coefficients of a Runge-Kutta scheme: stage nodes (alpha), stage weights
# (beta), solution weights (c_sol), and error-estimate weights (c_error).
_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha beta c_sol c_error')
6
+
7
+
8
class _RungeKuttaState(collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
    """Saved state of the Runge Kutta solver.

    Attributes:
        y1: Tensor giving the function value at the end of the last time step.
        f1: Tensor giving derivative at the end of the last time step.
        t0: scalar float64 Tensor giving start of the last time step.
        t1: scalar float64 Tensor giving end of the last time step.
        dt: scalar float64 Tensor giving the size for the next time step.
        interp_coeff: list of Tensors giving coefficients for polynomial
            interpolation between `t0` and `t1`.
    """
20
+
21
+
22
def _runge_kutta_step(func, y0, f0, t0, dt, tableau):
    """Take an arbitrary Runge-Kutta step and estimate error.

    Args:
        func: Function to evaluate like `func(t, y)` to compute the time derivative
            of `y`. `y` is a tuple of tensors (one per state component).
        y0: Tensor initial value for the state.
        f0: Tensor initial value for the derivative, computed from `func(t0, y0)`.
        t0: float64 scalar Tensor giving the initial time.
        dt: float64 scalar Tensor giving the size of the desired time step.
        tableau: optional _ButcherTableau describing how to take the Runge-Kutta
            step.

    Returns:
        Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
        the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
        estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
        calculating these terms.
    """
    dtype = y0[0].dtype
    device = y0[0].device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    dt = _convert_to_tensor(dt, dtype=dtype, device=device)

    # k[j] accumulates the stage derivatives for state component j, seeded with f0.
    k = tuple(map(lambda x: [x], f0))
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
        ti = t0 + alpha_i * dt
        # Candidate state at stage node ti from the stages computed so far.
        yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
        # Evaluate the new stage derivative and append it per component
        # (the tuple(...) only forces evaluation of the generator).
        tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi)))

    if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
        # This property (true for Dormand-Prince) lets us save a few FLOPs.
        yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))

    y1 = yi
    f1 = tuple(k_[-1] for k_ in k)  # last stage derivative = derivative at t1 (FSAL)
    y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
    return (y1, f1, y1_error, k)
62
+
63
+
64
def rk4_step_func(func, t, dt, y, k1=None):
    """Classic fourth-order Runge-Kutta increment for one step of size `dt`.

    `y` is a tuple of state components; the return value is the tuple of
    increments, so the caller forms y_next = y + result. If `k1`
    (= func(t, y)) is already available it can be passed in to avoid one
    re-evaluation.
    """
    if k1 is None:
        k1 = func(t, y)
    y2 = tuple(y_ + dt * k_ / 2 for y_, k_ in zip(y, k1))
    k2 = func(t + dt / 2, y2)
    y3 = tuple(y_ + dt * k_ / 2 for y_, k_ in zip(y, k2))
    k3 = func(t + dt / 2, y3)
    y4 = tuple(y_ + dt * k_ for y_, k_ in zip(y, k3))
    k4 = func(t + dt, y4)
    # Weighted average of the four stage slopes, scaled by dt.
    return tuple(
        (a + 2 * b + 2 * c + d) * (dt / 6)
        for a, b, c, d in zip(k1, k2, k3, k4)
    )
70
+
71
+
72
def rk4_alt_step_func(func, t, dt, y, k1=None):
    """Fourth-order Runge-Kutta increment using the 3/8 rule.

    Smaller error with slightly more compute than the classic tableau.
    Same calling convention as `rk4_step_func`: returns the tuple of
    increments to add to `y`.
    """
    if k1 is None:
        k1 = func(t, y)
    y2 = tuple(y_ + dt * a / 3 for y_, a in zip(y, k1))
    k2 = func(t + dt / 3, y2)
    y3 = tuple(y_ + dt * (a / -3 + b) for y_, a, b in zip(y, k1, k2))
    k3 = func(t + dt * 2 / 3, y3)
    y4 = tuple(y_ + dt * (a - b + c) for y_, a, b, c in zip(y, k1, k2, k3))
    k4 = func(t + dt, y4)
    # 3/8-rule weights: (1, 3, 3, 1) / 8, scaled by dt.
    return tuple(
        (a + 3 * b + 3 * c + d) * (dt / 8)
        for a, b, c, d in zip(k1, k2, k3, k4)
    )
ShapeID/DiffEqs/solvers.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import torch
3
+ from ShapeID.DiffEqs.misc import _assert_increasing, _handle_unused_kwargs
4
+
5
def set_BC_2D(X, BCs):
    """Overwrite the boundary strips of a batch of 2D fields in place.

    X:   (n_batch, rows, cols) field, modified in place and returned.
    BCs: (n_batch, 4, BC_size, data_dim) boundary values, ordered as
         [top rows, bottom rows, left cols, right cols].
    Later assignments win, so the left/right columns overwrite the corners.
    """
    width = BCs.size(2)
    X[:, : width] = BCs[:, 0]    # first `width` rows
    X[:, - width :] = BCs[:, 1]  # last `width` rows
    # (batch, width, rows) -> (batch, rows, width) before writing columns
    X[:, :, : width] = BCs[:, 2].permute(0, 2, 1)
    X[:, :, - width :] = BCs[:, 3].permute(0, 2, 1)
    return X
13
def set_BC_3D(X, BCs):
    """Overwrite the boundary slabs of a batch of 3D fields in place.

    X:   (n_batch, s, r, c) field, modified in place and returned.
    BCs: (n_batch, 6, BC_size, dim, dim) boundary values, one entry per face.
    Later assignments win, so faces written last dominate shared edges/corners.
    """
    width = BCs.size(2)
    X[:, : width] = BCs[:, 0]
    X[:, - width :] = BCs[:, 1]
    # (batch, width, s, c) -> (batch, s, width, c)
    X[:, :, : width] = BCs[:, 2].permute(0, 2, 1, 3)
    X[:, :, - width :] = BCs[:, 3].permute(0, 2, 1, 3)
    # (batch, width, s, r) -> (batch, s, r, width)
    X[:, :, :, : width] = BCs[:, 4].permute(0, 2, 3, 1)
    X[:, :, :, - width :] = BCs[:, 5].permute(0, 2, 3, 1)
    return X
23
+
24
+ ''' X[t] = X[t] + dBC[t] (dBC[t] = BC[t+1] - BC[t]) '''
25
def add_dBC_2D(X, dBCs):
    """Accumulate boundary increments onto a batch of 2D fields in place.

    Implements X[t] = X[t] + dBC[t] where dBC[t] = BC[t+1] - BC[t].
    X:    (n_batch, rows, cols), returned after modification.
    dBCs: (n_batch, 4, BC_size, data_dim), ordered as
          [top rows, bottom rows, left cols, right cols]; corner cells
          receive both a row and a column contribution.
    """
    width = dBCs.size(2)
    X[:, : width] += dBCs[:, 0]
    X[:, - width :] += dBCs[:, 1]
    # (batch, width, rows) -> (batch, rows, width) before adding to columns
    X[:, :, : width] += dBCs[:, 2].permute(0, 2, 1)
    X[:, :, - width :] += dBCs[:, 3].permute(0, 2, 1)
    return X
33
def add_dBC_3D(X, dBCs):
    """Accumulate boundary increments onto a batch of 3D fields in place.

    Implements X[t] = X[t] + dBC[t] where dBC[t] = BC[t+1] - BC[t].
    X:    (n_batch, s, r, c), returned after modification.
    dBCs: (n_batch, 6, BC_size, dim, dim), one entry per face; edge and
          corner cells receive a contribution from every touching face.
    """
    width = dBCs.size(2)
    X[:, : width] += dBCs[:, 0]
    X[:, - width :] += dBCs[:, 1]
    # (batch, width, s, c) -> (batch, s, width, c)
    X[:, :, : width] += dBCs[:, 2].permute(0, 2, 1, 3)
    X[:, :, - width :] += dBCs[:, 3].permute(0, 2, 1, 3)
    # (batch, width, s, r) -> (batch, s, r, width)
    X[:, :, :, : width] += dBCs[:, 4].permute(0, 2, 3, 1)
    X[:, :, :, - width :] += dBCs[:, 5].permute(0, 2, 3, 1)
    return X
43
+
44
class AdaptiveStepsizeODESolver(object):
    """Abstract base class for adaptive-step-size ODE solvers.

    Subclasses implement `advance(next_t)`, which steps the internal state
    forward and returns the solution tuple at `next_t`; `integrate` then
    collects the solution at every requested output time.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, func, y0, atol, rtol, options= None):
        # `options` is accepted for interface compatibility but currently unused.
        self.func = func  # RHS callable: func(t, y) -> dy/dt (tuple of tensors)
        self.y0 = y0      # tuple of initial-state tensors
        self.atol = atol  # absolute error tolerance
        self.rtol = rtol  # relative error tolerance

    def before_integrate(self, t):
        """Hook called once with the full time grid before stepping begins."""
        pass

    @abc.abstractmethod
    def advance(self, next_t):
        """Advance the solution to `next_t` and return the state tuple."""
        raise NotImplementedError

    def integrate(self, t):
        """Integrate from t[0] through each subsequent time in `t`.

        `t` must be increasing. Returns a tuple of stacked tensors, one per
        state component, each with a leading axis of length len(t).
        (A large block of commented-out boundary-condition handling was
        removed here as dead code.)
        """
        _assert_increasing(t)
        solution = [self.y0]
        t = t.to(self.y0[0].device, torch.float64)
        self.before_integrate(t)
        for i in range(1, len(t)):
            solution.append(self.advance(t[i]))
        return tuple(map(torch.stack, tuple(zip(*solution))))
101
+
102
+
103
class FixedGridODESolver(object):
    """Base class for explicit fixed-grid ODE solvers.

    Integrates ``dy/dt = func(t, y)`` (``y`` a tuple of tensors) on a
    precomputed time grid, linearly interpolating the solution at the
    requested output times. Subclasses supply the per-step update rule via
    ``step_func`` and its convergence ``order``.
    """
    __metaclass__ = abc.ABCMeta  # NOTE: Python-2 style; a harmless no-op under Python 3.

    def __init__(self, func, y0, step_size=None, grid_constructor=None, atol=None, rtol=None, dt=None, options=None):
        """
        Args:
            func: callable(t, y) -> dy/dt, with ``y`` a tuple of tensors.
            y0: tuple of tensors, the initial state.
            step_size: fixed step size; exclusive with ``grid_constructor``.
            grid_constructor: callable(func, y0, t) -> 1-D tensor of grid
                times covering [t[0], t[-1]].
            atol, rtol, dt, options: accepted for interface compatibility
                with the adaptive solvers; unused here.

        Raises:
            ValueError: if both ``step_size`` and ``grid_constructor`` are given.
        """
        self.func = func
        self.y0 = y0

        if step_size is not None:
            if grid_constructor is not None:
                raise ValueError("step_size and grid_constructor are exclusive arguments.")
            self.grid_constructor = self._grid_constructor_from_step_size(step_size)
        elif grid_constructor is not None:
            # Bug fix: a user-supplied grid_constructor (with step_size=None)
            # used to be rejected with ValueError; honor it instead.
            self.grid_constructor = grid_constructor
        else:
            self.grid_constructor = lambda f, y0, t: t  # step exactly at the output times

    def _grid_constructor_from_step_size(self, step_size):
        """Build a grid constructor producing uniform steps of ``step_size``."""

        def _grid_constructor(func, y0, t):
            start_time = t[0]
            end_time = t[-1]

            niters = torch.ceil((end_time - start_time) / step_size + 1).item()
            t_infer = torch.arange(0, niters).to(t) * step_size + start_time
            if t_infer[-1] > t[-1]:
                # Clamp the last grid point so the grid ends exactly at t[-1].
                t_infer[-1] = t[-1]
            return t_infer

        return _grid_constructor

    @property
    @abc.abstractmethod
    def order(self):
        """Convergence order of the scheme (int)."""
        pass

    @abc.abstractmethod
    def step_func(self, func, t, dt, y):
        """Return the increment dy for one step from ``t`` to ``t + dt``."""
        pass

    def integrate(self, t):
        """Integrate on the internal grid; return states at each time in ``t``.

        Returns a tuple of tensors (one per state component), each stacked
        along a leading time dimension of length ``len(t)``.
        """
        _assert_increasing(t)
        t = t.type_as(self.y0[0])  # (n_time,)
        time_grid = self.grid_constructor(self.func, self.y0, t)
        assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
        time_grid = time_grid.to(self.y0[0])

        solution = [self.y0]

        j = 1
        y0 = self.y0
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self.step_func(self.func, t0, t1 - t0, y0)
            y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))
            # Emit outputs for every requested time inside (t0, t1].
            while j < len(t) and t1 >= t[j]:
                solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                j += 1
            # Bug fix: advance the running state only AFTER interpolating;
            # previously y0 was overwritten with y1 before the while loop, so
            # _linear_interp always degenerated to the step endpoint y1.
            y0 = y1
        return tuple(map(torch.stack, tuple(zip(*solution))))  # per component: (time, ...)

    def _linear_interp(self, t0, t1, y0, y1, t):
        """Linearly interpolate between (t0, y0) and (t1, y1) at time ``t``."""
        if t == t0:
            return y0
        if t == t1:
            return y1
        t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0])
        slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1))
        return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope))
ShapeID/DiffEqs/tsit5.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ShapeID.DiffEqs.misc import _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs
3
+ from ShapeID.DiffEqs.solvers import AdaptiveStepsizeODESolver
4
+ from ShapeID.DiffEqs.rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step
5
+
6
# Parameters from Tsitouras (2011), "Runge-Kutta pairs of order 5(4)".
# Butcher tableau for the Tsit5 pair:
#   alpha   - stage nodes c_i,
#   beta    - lower-triangular stage weights a_ij,
#   c_sol   - 5th-order solution weights b_i,
#   c_error - b_i minus the embedded 4th-order weights (error estimator).
_TSITOURAS_TABLEAU = _ButcherTableau(
    alpha=[0.161, 0.327, 0.9, 0.9800255409045097, 1., 1.],
    beta=[
        [0.161],
        [-0.008480655492357, 0.3354806554923570],
        [2.897153057105494, -6.359448489975075, 4.362295432869581],
        [5.32586482843925895, -11.74888356406283, 7.495539342889836, -0.09249506636175525],
        [5.86145544294642038, -12.92096931784711, 8.159367898576159, -0.071584973281401006, -0.02826905039406838],
        [0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774],
    ],
    c_sol=[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774, 0],
    c_error=[
        0.09646076681806523 - 0.001780011052226,
        0.01 - 0.000816434459657,
        0.4798896504144996 - -0.007880878010262,
        1.379008574103742 - 0.144711007173263,
        -3.290069515436081 - -0.582357165452555,
        2.324710524099774 - 0.458082105929187,
        -1 / 66,
    ],
)
28
+
29
+
30
+ def _interp_coeff_tsit5(t0, dt, eval_t):
31
+ t = float((eval_t - t0) / dt)
32
+ b1 = -1.0530884977290216 * t * (t - 1.3299890189751412) * (t**2 - 1.4364028541716351 * t + 0.7139816917074209)
33
+ b2 = 0.1017 * t**2 * (t**2 - 2.1966568338249754 * t + 1.2949852507374631)
34
+ b3 = 2.490627285651252793 * t**2 * (t**2 - 2.38535645472061657 * t + 1.57803468208092486)
35
+ b4 = -16.54810288924490272 * (t - 1.21712927295533244) * (t - 0.61620406037800089) * t**2
36
+ b5 = 47.37952196281928122 * (t - 1.203071208372362603) * (t - 0.658047292653547382) * t**2
37
+ b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2
38
+ b7 = 2.5 * (t - 1) * (t - 0.6) * t**2
39
+ return [b1, b2, b3, b4, b5, b6, b7]
40
+
41
+
42
def _interp_eval_tsit5(t0, t1, k, eval_t):
    """Evaluate the Tsit5 dense-output interpolant at ``eval_t``.

    ``k`` holds, per state component, the 7 cached stage entries of the step
    from ``t0`` to ``t1``; ``k_[0]`` is used as the base state of the
    interpolation (see ``before_integrate``, which seeds it with y0).
    """
    step = t1 - t0
    coeffs = _interp_coeff_tsit5(t0, step, eval_t)
    return tuple(k_[0] + _scaled_dot_product(step, coeffs, k_) for k_ in k)
48
+
49
+
50
+ def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
51
+ """Calculate the optimal size for the next Runge-Kutta step."""
52
+ if mean_error_ratio == 0:
53
+ return last_step * ifactor
54
+ if mean_error_ratio < 1:
55
+ dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
56
+ error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step)
57
+ exponent = torch.tensor(1 / order).type_as(last_step)
58
+ factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
59
+ return last_step / factor
60
+
61
+
62
+ def _abs_square(x):
63
+ return torch.mul(x, x)
64
+
65
+
66
class Tsit5Solver(AdaptiveStepsizeODESolver):
    """Adaptive-step Tsitouras 5(4) Runge-Kutta solver."""

    def __init__(
        self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
        **unused_kwargs
    ):
        """
        Args:
            func: callable(t, y) -> dy/dt, with ``y`` a tuple of tensors.
            y0: tuple of tensors, initial state.
            rtol, atol: relative / absolute error tolerances.
            first_step: optional initial step size; estimated automatically
                when None.
            safety, ifactor, dfactor: step-size controller parameters.
            max_num_steps: hard cap on internal steps per ``advance`` call.
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs

        self.func = func
        self.y0 = y0
        self.rtol = rtol
        self.atol = atol
        self.first_step = first_step
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
        self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)

    def before_integrate(self, t):
        """Initialize the Runge-Kutta state at the first output time."""
        if self.first_step is None:
            first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol).to(t)
        else:
            # Bug fix: the user-supplied first_step was previously ignored
            # (a hard-coded 0.01 was used instead).
            first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
        self.rk_state = _RungeKuttaState(
            self.y0,
            self.func(t[0].type_as(self.y0[0]), self.y0), t[0], t[0], first_step,
            # Seed the interpolation cache with 7 copies of y0 so the dense
            # output is well-defined before any step has been taken.
            tuple(map(lambda x: [x] * 7, self.y0))
        )

    def advance(self, next_t):
        """Interpolate through the next time point, integrating as necessary."""
        n_steps = 0
        while next_t > self.rk_state.t1:
            assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
            self.rk_state = self._adaptive_tsit5_step(self.rk_state)
            n_steps += 1
        return _interp_eval_tsit5(self.rk_state.t0, self.rk_state.t1, self.rk_state.interp_coeff, next_t)

    def _adaptive_tsit5_step(self, rk_state):
        """Take an adaptive Runge-Kutta step to integrate the DiffEqs."""
        y0, f0, _, t0, dt, _ = rk_state
        ########################################################
        #                      Assertions                      #
        ########################################################
        assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
        for y0_ in y0:
            assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
        y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU)

        ########################################################
        #                     Error Ratio                      #
        ########################################################
        # Mean squared ratio of the error estimate to the mixed tolerance.
        error_tol = tuple(self.atol + self.rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
        tensor_error_ratio = tuple(y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol))
        sq_error_ratio = tuple(
            torch.mul(tensor_error_ratio_, tensor_error_ratio_) for tensor_error_ratio_ in tensor_error_ratio
        )
        mean_error_ratio = (
            sum(torch.sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio) /
            sum(sq_error_ratio_.numel() for sq_error_ratio_ in sq_error_ratio)
        )
        accept_step = mean_error_ratio <= 1

        ########################################################
        #                   Update RK State                    #
        ########################################################
        # On rejection, keep the old state and only shrink the step size.
        y_next = y1 if accept_step else y0
        f_next = f1 if accept_step else f0
        t_next = t0 + dt if accept_step else t0
        dt_next = _optimal_step_size(dt, mean_error_ratio, self.safety, self.ifactor, self.dfactor)
        k_next = k if accept_step else self.rk_state.interp_coeff
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, k_next)
        return rk_state
ShapeID/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from utils import *
ShapeID/demo2d.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ported from https://github.com/pvigier/perlin-numpy
2
+
3
+ import os, sys
4
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5
+
6
+ import time, datetime
7
+
8
+ import torch
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+
12
+ from misc import stream_2D, V_plot
13
+ from utils.misc import viewVolume, make_dir
14
+
15
+ from perlin2d import *
16
+
17
+
18
+ #from ShapeID.DiffEqs.odeint import odeint
19
+ from ShapeID.DiffEqs.adjoint import odeint_adjoint as odeint
20
+ from ShapeID.DiffEqs.pde import AdvDiffPDE
21
+
22
+
23
+
24
+
25
if __name__ == '__main__':

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Sample a 2-D Perlin-noise field and a binary mask thresholded at the
    # given percentile (presumably the top 20% for percentile=80 — confirm in
    # perlin2d), then save both as images.
    image, mask_image = generate_perlin_noise_2d([256, 256], [2, 2], percentile = 80)
    plt.imshow(image, cmap='gray') #, interpolation='lanczos')
    plt.axis('off')
    plt.savefig('out/2d/image.png')
    plt.imshow(mask_image, cmap='gray') #, interpolation='lanczos')
    plt.axis('off')
    plt.savefig('out/2d/mask_image.png')

    # Second Perlin field used as the stream function for a velocity field.
    curl, mask_curl = generate_perlin_noise_2d([256, 256], [2, 2], percentile = 80)
    plt.imshow(curl, cmap='gray') #, interpolation='lanczos')
    plt.axis('off')
    plt.savefig('out/2d/curl.png')
    plt.imshow(mask_curl, cmap='gray') #, interpolation='lanczos')
    plt.axis('off')
    plt.savefig('out/2d/mask_curl.png')

    # Curl of the stream function -> divergence-free velocity (dx, dy).
    dx, dy = stream_2D(torch.from_numpy(curl))
    V_plot(dx.numpy(), dy.numpy(), 'out/2d/V.png')

    plt.imshow(mask_image, cmap='gray') #, interpolation='lanczos')
    plt.axis('off')
    plt.savefig('out/2d/image_with_v.png')
    #plt.close()

    # Time grid and integration settings for the advection PDE.
    dt = 0.15
    nt = 21
    integ_method = 'dopri5' # choices=['dopri5', 'adams', 'rk4', 'euler']
    t = torch.from_numpy(np.arange(nt) * dt).to(device)
    thres = 0.9

    # NOTE(review): `initial` stays on CPU while `t`/the PDE use `device`;
    # confirm odeint/AdvDiffPDE handle the device transfer internally.
    initial = torch.from_numpy(mask_image)
    Vx, Vy = dx * 1000, dy * 1000  # scale the velocity field

    forward_pde = AdvDiffPDE(data_spacing=[1., 1.],
                             perf_pattern='adv',
                             V_type='vector_div_free',
                             V_dict={'Vx': Vx, 'Vy': Vy},
                             BC='neumann',
                             dt=dt,
                             device=device
                             )

    # Advect the mask through the velocity field and time the integration.
    start_time = time.time()
    noise_progression = odeint(forward_pde,
                               initial.unsqueeze(0),
                               t, dt, method = integ_method
                               )[:, 0]
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Time {}'.format(total_time_str))
    noise_progression = noise_progression[::2]  # keep every other frame

    # NOTE(review): .numpy() requires a CPU tensor — fails if the solution
    # lives on CUDA; confirm this script is intended to run on CPU.
    noise_progression = noise_progression.numpy()
    make_dir('out/2d/progression')

    # Re-binarize each frame at `thres` and save the progression.
    for i, noise_t in enumerate(noise_progression):
        print(i, noise_t.mean())

        noise_t[noise_t > thres] = 1
        noise_t[noise_t <= thres] = 0

        #fig = plt.figure()
        plt.imshow(noise_t, cmap='gray') #, interpolation='lanczos')
        plt.savefig('out/2d/progression/%d.png' % i)
        #plt.close()

    viewVolume(noise_progression, names = ['noise_progression'], save_dir = 'out/2d/progression')
ShapeID/demo3d.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ported from https://github.com/pvigier/perlin-numpy
2
+
3
+ import os, sys
4
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5
+
6
+ import time, datetime
7
+
8
+ import torch
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+
12
+ from misc import stream_3D, V_plot, center_crop
13
+ from utils.misc import viewVolume, make_dir, read_image
14
+
15
+
16
+ #from ShapeID.DiffEqs.odeint import odeint
17
+ from ShapeID.DiffEqs.adjoint import odeint_adjoint as odeint
18
+ from ShapeID.DiffEqs.pde import AdvDiffPDE
19
+
20
+ from perlin3d import *
21
+
22
+
23
+
24
+
25
if __name__ == '__main__':

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Threshold percentile used by the Perlin-noise generators below.
    percentile = 80

    #image, mask_image = generate_perlin_noise_3d([128, 128, 128], [2, 2, 2], tileable=(True, False, False), percentile = percentile)
    #viewVolume(image, names = ['image'], save_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d')
    #viewVolume(mask_image, names = ['mask_image'], save_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d')

    # Load a pathology-probability volume and center-crop it to 128^3.
    #mask_image, aff = read_image('/autofs/space/yogurt_001/users/pl629/data/adni/pathology_probability/subject_193441.nii.gz')
    mask_image, aff = read_image('/autofs/space/yogurt_001/users/pl629/data/isles2022/pathology_probability/sub-strokecase0127.nii.gz')
    mask_image, _, _ = center_crop(torch.from_numpy(mask_image), win_size = [128, 128, 128])
    mask_image = mask_image[0, 0].numpy()

    shape = mask_image.shape

    # Three Perlin fields act as the vector potential; their curl (via
    # stream_3D) yields a divergence-free velocity field (dx, dy, dz).
    curl_a, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile = percentile)
    curl_b, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile = percentile)
    curl_c, _ = generate_perlin_noise_3d(shape, [2, 2, 2], tileable=(True, False, False), percentile = percentile)
    dx, dy, dz = stream_3D(torch.from_numpy(curl_a), torch.from_numpy(curl_b), torch.from_numpy(curl_c))

    # Time grid and integration settings for the advection PDE.
    dt = 0.1
    nt = 10
    integ_method = 'dopri5' # choices=['dopri5', 'adams', 'rk4', 'euler']
    t = torch.from_numpy(np.arange(nt) * dt).to(device)
    thres = 0.5

    initial = torch.from_numpy(mask_image)[None] # (batch=1, h, w)
    Vx, Vy, Vz = dx * 500, dy * 500, dz * 500  # scale the velocity field
    print(abs(Vx).mean(), abs(Vy).mean(), abs(Vz).mean())

    forward_pde = AdvDiffPDE(data_spacing=[1., 1., 1.],
                             perf_pattern='adv',
                             V_type='vector_div_free',
                             V_dict={'Vx': Vx, 'Vy': Vy, 'Vz': Vz},
                             BC='neumann',
                             dt=dt,
                             device=device
                             )

    # Advect the pathology mask through the velocity field; time it.
    start_time = time.time()
    noise_progression = odeint(forward_pde,
                               initial,
                               t, dt, method = integ_method
                               )[:, 0] # (nt, n_batch, h, w)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Time {}'.format(total_time_str))

    # NOTE(review): .numpy() requires a CPU tensor — fails if the solution
    # lives on CUDA; confirm this script is intended to run on CPU.
    noise_progression = noise_progression[::2]  # keep every other frame
    noise_progression = noise_progression.numpy()
    make_dir('out/3d/progression')

    # Save each frame: clipped probabilities, then a binary mask of support.
    for i, noise_t in enumerate(noise_progression):
        noise_t[noise_t > 1] = 1
        noise_t[noise_t <= thres] = 0
        print(i, noise_t.mean())
        viewVolume(noise_t, names = ['noise_%s' % i], save_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d/progression')

        noise_t[noise_t > 0.] = 1
        viewVolume(noise_t, names = ['noise_%s_mask' % i], save_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/ShapeID/out/3d/progression')
ShapeID/misc.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ported from https://github.com/pvigier/perlin-numpy
2
+
3
+
4
+ import torch
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+
10
def center_crop(img, win_size=(220, 220, 220)):
    """Center-crop a 3-D volume to ``win_size`` along its spatial dims.

    Args:
        img: torch tensor, either (s, r, c) or channel-last (s, r, c, ch).
        win_size: target spatial size [s', r', c'], or None for no cropping.
            Axes already smaller than the window are left untouched.

    Returns:
        (crop_img, crop_start, orig_shp) where
        crop_img: (1, s', r', c', ch) for channel-last input, otherwise
            (1, 1, s', r', c');
        crop_start: per-axis start offsets of the crop in the original volume;
        orig_shp: original spatial shape (torch.Size).
    """
    if len(img.shape) == 4:
        # Channel-last volume: move channels first and add a batch dim.
        img = torch.permute(img, (3, 0, 1, 2))[None]
        permuted = True
    else:
        assert len(img.shape) == 3
        img = img[None, None]
        permuted = False

    orig_shp = img.shape[2:]  # spatial shape (s, r, c)
    if win_size is not None and any(orig_shp[i] > win_size[i] for i in range(3)):
        crop_start = [max(orig_shp[i] - win_size[i], 0) // 2 for i in range(3)]
        img = img[:, :, crop_start[0]: crop_start[0] + win_size[0],
                        crop_start[1]: crop_start[1] + win_size[1],
                        crop_start[2]: crop_start[2] + win_size[2]]
    else:
        crop_start = [0, 0, 0]

    if permuted:
        # Restore channel-last layout: (batch, s', r', c', ch).
        # Bug fix: the permuted path previously reported [0, 0, 0] even when
        # a crop was applied; the true offsets are now returned.
        return torch.permute(img, (0, 2, 3, 4, 1)), crop_start, orig_shp
    return img, crop_start, orig_shp
38
+
39
+
40
+
41
def V_plot(Vx, Vy, save_path):
    """Plot the 2-D vector field (Vx, Vy) as streamlines and save to ``save_path``."""
    # Meshgrid over the field's index space.
    # NOTE(review): X spans Vx.shape[0] and Y spans Vx.shape[1]; for
    # non-square fields matplotlib expects grids shaped (rows, cols) —
    # confirm axis order is intended (only square fields are used here).
    X,Y = np.meshgrid(np.arange(0, Vx.shape[0], 1), np.arange(0, Vx.shape[1], 1))
    # Assign vector directions
    Ex = Vx
    Ey = Vy

    # Depict illustration
    plt.figure()
    plt.streamplot(X,Y,Ex,Ey, density=1.4, linewidth=None, color='orange')
    plt.axis('off')
    plt.savefig(save_path)
    #plt.show()
55
def stream_2D(Phi, batched = False, delta_lst = [1., 1.]):
    """Return the 2-D curl of the scalar stream function ``Phi``.

    The resulting velocity field (Vx, Vy) is divergence-free by
    construction. ``Phi`` is a grid of shape (r, c), or (n_batch, r, c)
    when ``batched`` is True.
    """
    grad = gradient_c(Phi, batched = batched, delta_lst = delta_lst)
    # Perpendicular gradient: (Vx, Vy) = (-dPhi/dy, dPhi/dx).
    return -grad[..., 1], grad[..., 0]
64
+
65
+
66
def stream_3D(Phi_a, Phi_b, Phi_c, batched = False, delta_lst = [1., 1., 1.]):
    """Return a divergence-free 3-D velocity field as the curl of the
    vector potential (Phi_a, Phi_b, Phi_c).

    Each Phi_* is a grid of shape (s, r, c), or (n_batch, s, r, c) when
    ``batched`` is True. Returns (Vx, Vy, Vz) with the same spatial shape.
    """
    # (removed an unused `device` local present in the original)
    dDa = gradient_c(Phi_a, batched = batched, delta_lst = delta_lst)
    dDb = gradient_c(Phi_b, batched = batched, delta_lst = delta_lst)
    dDc = gradient_c(Phi_c, batched = batched, delta_lst = delta_lst)
    # Curl components: V = nabla x (Phi_a, Phi_b, Phi_c).
    Vx = dDc[..., 1] - dDb[..., 2]
    Vy = dDa[..., 2] - dDc[..., 0]
    Vz = dDb[..., 0] - dDa[..., 1]
    return Vx, Vy, Vz
81
+
82
+
83
+
84
def gradient_f(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction.
    Upper-boundaries: Backward Difference
    Non-boundaries & Lower-boundaries: Forward Difference
    if X is batched: (n_batch, ...);
    else: (...)
    Returns a float32 tensor; for 2-D/3-D input a trailing channel axis of
    size 2/3 holds the per-axis derivatives (1-D keeps X's shape).
    NOTE(review): dim > 3 falls through with dX unbound (UnboundLocalError)
    — confirm callers only pass 1-3 spatial dims.
    '''
    device = X.device
    # Number of spatial dims (the leading batch axis, if any, is excluded).
    dim = len(X.size()) - 1 if batched else len(X.size())
    if dim == 1:
        dX = torch.zeros(X.size(), dtype = torch.float, device = device)
        # Move the batch axis to the back so slicing addresses space directly.
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[-1] = X[-1] - X[-2] # Backward Difference
        dX[:-1] = X[1:] - X[:-1] # Forward Difference

        dX = dX.permute(1, 0) if batched else dX  # restore batch-first layout
        dX /= delta_lst[0]
    elif dim == 2:
        dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
        dX[-1, :, 0] = X[-1, :] - X[-2, :] # Backward Difference
        dX[:-1, :, 0] = X[1:] - X[:-1] # Forward Difference

        dX[:, -1, 1] = X[:, -1] - X[:, -2] # Backward Difference
        dX[:, :-1, 1] = X[:, 1:] - X[:, :-1] # Forward Difference

        dX = dX.permute(3, 0, 1, 2) if batched else dX  # restore batch-first layout
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX  # put batch to last dim
        dX[-1, :, :, 0] = X[-1, :, :] - X[-2, :, :] # Backward Difference
        dX[:-1, :, :, 0] = X[1:] - X[:-1] # Forward Difference

        dX[:, -1, :, 1] = X[:, -1] - X[:, -2] # Backward Difference
        dX[:, :-1, :, 1] = X[:, 1:] - X[:, :-1] # Forward Difference

        dX[:, :, -1, 2] = X[:, :, -1] - X[:, :, -2] # Backward Difference
        dX[:, :, :-1, 2] = X[:, :, 1:] - X[:, :, :-1] # Forward Difference

        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX  # restore batch-first layout
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX
139
+
140
+
141
def gradient_b(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction.
    Non-boundaries & Upper-boundaries: Backward Difference
    Lower-boundaries: Forward Difference
    if X is batched: (n_batch, ...);
    else: (...)
    Returns a float32 tensor; for 2-D/3-D input a trailing channel axis of
    size 2/3 holds the per-axis derivatives (1-D keeps X's shape).
    NOTE(review): dim > 3 falls through with dX unbound (UnboundLocalError)
    — confirm callers only pass 1-3 spatial dims.
    '''
    device = X.device
    # Number of spatial dims (the leading batch axis, if any, is excluded).
    dim = len(X.size()) - 1 if batched else len(X.size())
    if dim == 1:
        dX = torch.zeros(X.size(), dtype = torch.float, device = device)
        # Move the batch axis to the back so slicing addresses space directly.
        X = X.permute(1, 0) if batched else X
        dX = dX.permute(1, 0) if batched else dX
        dX[1:] = X[1:] - X[:-1] # Backward Difference
        dX[0] = X[1] - X[0] # Forward Difference

        dX = dX.permute(1, 0) if batched else dX  # restore batch-first layout
        dX /= delta_lst[0]
    elif dim == 2:
        dX = torch.zeros(X.size() + tuple([2]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 0) if batched else X
        dX = dX.permute(1, 2, 3, 0) if batched else dX # put batch to last dim
        dX[1:, :, 0] = X[1:, :] - X[:-1, :] # Backward Difference
        dX[0, :, 0] = X[1] - X[0] # Forward Difference

        dX[:, 1:, 1] = X[:, 1:] - X[:, :-1] # Backward Difference
        dX[:, 0, 1] = X[:, 1] - X[:, 0] # Forward Difference

        dX = dX.permute(3, 0, 1, 2) if batched else dX  # restore batch-first layout
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
    elif dim == 3:
        dX = torch.zeros(X.size() + tuple([3]), dtype = torch.float, device = device)
        X = X.permute(1, 2, 3, 0) if batched else X
        dX = dX.permute(1, 2, 3, 4, 0) if batched else dX  # put batch to last dim
        dX[1:, :, :, 0] = X[1:, :, :] - X[:-1, :, :] # Backward Difference
        dX[0, :, :, 0] = X[1] - X[0] # Forward Difference

        dX[:, 1:, :, 1] = X[:, 1:] - X[:, :-1] # Backward Difference
        dX[:, 0, :, 1] = X[:, 1] - X[:, 0] # Forward Difference

        dX[:, :, 1:, 2] = X[:, :, 1:] - X[:, :, :-1] # Backward Difference
        dX[:, :, 0, 2] = X[:, :, 1] - X[:, :, 0] # Forward Difference

        dX = dX.permute(4, 0, 1, 2, 3) if batched else dX  # restore batch-first layout
        dX[..., 0] /= delta_lst[0]
        dX[..., 1] /= delta_lst[1]
        dX[..., 2] /= delta_lst[2]
    return dX
196
+
197
+
198
def gradient_c(X, batched = False, delta_lst = [1., 1., 1.]):
    """Central-difference gradient of ``X`` along each spatial axis.

    Interior points use central differences; the lower/upper boundaries use
    one-sided forward/backward first differences.

    Args:
        X: tensor of spatial shape (...), or (n_batch, ...) if ``batched``.
            Every spatial axis must have size >= 2.
        batched: whether axis 0 of ``X`` is a batch axis.
        delta_lst: grid spacing per spatial axis.

    Returns:
        float32 tensor. For a single spatial axis: same shape as ``X``.
        For D >= 2 spatial axes: shape X.shape + (D,), where channel d holds
        dX/dx_d.

    This generalizes the original 1/2/3-D-only implementation to any number
    of spatial axes (identical results for 1-3 dims), and removes the
    UnboundLocalError fall-through for unsupported ranks.
    """
    device = X.device
    n_spatial = X.dim() - 1 if batched else X.dim()
    derivs = []
    for d in range(n_spatial):
        axis = d + 1 if batched else d
        # Bring the differentiated axis to the front for simple slicing.
        Xa = X.transpose(axis, 0)
        g = torch.zeros(Xa.size(), dtype = torch.float, device = device)
        g[1:-1] = (Xa[2:] - Xa[:-2]) / 2   # central difference (interior)
        g[0] = Xa[1] - Xa[0]               # forward difference (lower boundary)
        g[-1] = Xa[-1] - Xa[-2]            # backward difference (upper boundary)
        derivs.append(g.transpose(axis, 0) / delta_lst[d])
    if n_spatial == 1:
        # Match the original contract: no trailing channel axis in 1-D.
        return derivs[0]
    return torch.stack(derivs, dim = -1)
260
+
261
+
ShapeID/out/2d/V.png ADDED

Git LFS Details

  • SHA256: ddfa181038abcb687f69cb20d9aea1caf3e9d474895de4dca01ee39a6afd8b82
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
ShapeID/out/2d/curl.png ADDED
ShapeID/out/2d/image.png ADDED
ShapeID/out/2d/image_with_v.png ADDED

Git LFS Details

  • SHA256: 6f9309254daadba376969a414084601ee466c9d8e24e9e975d6ca3b93243dbf0
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
ShapeID/out/2d/mask_curl.png ADDED
ShapeID/out/2d/mask_image.png ADDED
ShapeID/out/2d/progression/New Folder With Items/0.png ADDED

Git LFS Details

  • SHA256: 6f9309254daadba376969a414084601ee466c9d8e24e9e975d6ca3b93243dbf0
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB