|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
    """Pad a 3-D volume with zeros so every axis length is a multiple of a divisor.

    Padding is appended only at the high-index end of each axis (the work is
    delegated to ``reshape_by_padding_upper_coords`` with pad value 0).

    Args:
        patient: 3-D array to pad.
        shape_must_be_divisible_by: one divisor used for all three axes, or a
            list/tuple holding one divisor per axis.
        min_size: optional per-axis lower bound for the padded shape.

    Returns:
        Tuple ``(padded, original_shape)``; ``original_shape`` is
        ``patient.shape`` before padding, so callers can crop back later.
    """
    divisors = shape_must_be_divisible_by
    if not isinstance(divisors, (list, tuple)):
        divisors = [divisors] * 3

    original_shape = patient.shape

    # Round each axis length up to the next multiple of its divisor; lengths
    # that are already multiples get a zero increment.
    target_shape = [original_shape[axis] + (-original_shape[axis]) % divisors[axis]
                    for axis in range(3)]

    if min_size is not None:
        # Element-wise lower bound on the padded shape.
        target_shape = np.vstack((target_shape, min_size)).max(axis=0)

    padded = reshape_by_padding_upper_coords(patient, target_shape, 0)
    return padded, original_shape
|
|
|
|
|
|
|
|
def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
    """Pad ``image`` at the upper end of every axis so it is at least ``new_shape``.

    The image is placed at index 0 along every axis of the output, and the
    remaining space at the high-index end of each axis is filled with
    ``pad_value``. On axes where ``new_shape`` is smaller than the image, the
    image's own length is kept, so the image is never cropped. Any number of
    dimensions (>= 1) is supported.

    Args:
        image: numpy array with at least one dimension.
        new_shape: sequence with one requested length per axis of ``image``.
        pad_value: fill value for the padded region. If None, the image's
            first element (``image[0, ..., 0]``) is used.

    Returns:
        A new array whose shape is the element-wise maximum of ``image.shape``
        and ``new_shape``.

    Raises:
        ValueError: if ``image`` is 0-dimensional, or if ``new_shape`` does not
            have exactly one entry per axis of ``image``.
    """
    old_shape = tuple(image.shape)
    if not old_shape:
        raise ValueError("Image must have at least one dimension")

    requested = np.asarray(new_shape)
    if requested.shape != (len(old_shape),):
        raise ValueError(
            f"new_shape {new_shape!r} must have one entry per axis of an image with shape {old_shape}"
        )
    # Never shrink an axis: take the larger of current and requested length.
    target_shape = tuple(np.maximum(old_shape, requested))

    if pad_value is None:
        # Default to the value of the image's first element.
        pad_value = image[(0,) * len(old_shape)]

    # Multiplying (rather than np.full with dtype=image.dtype) lets numpy
    # promote the dtype when needed, e.g. a float pad value on an int image.
    padded = np.ones(target_shape, dtype=image.dtype) * pad_value
    # Copy the image into the low-index corner. This works for any rank; the
    # previous 2-D/3-D-only branches silently dropped the image otherwise.
    padded[tuple(slice(0, extent) for extent in old_shape)] = image
    return padded
|
|
|
|
|
|
|
|
def _pad_and_batch_case_3D(patient_data, BATCH_SIZE, new_shape_must_be_divisible_by, min_size):
    """Pad each 3-D sub-volume of ``patient_data`` and stack them into a float32 batch.

    Returns ``(data, old_shape)``. ``data`` has shape ``(B,) + padded_shape``,
    where B is 1, or BATCH_SIZE identical copies when BATCH_SIZE is given.
    ``old_shape`` is the unpadded spatial shape, used later to crop predictions.
    """
    padded = []
    old_shape = None
    for volume in patient_data:
        padded_volume, old_shape = pad_patient_3D(volume, new_shape_must_be_divisible_by, min_size)
        padded.append(padded_volume[None])
    # Add the leading batch axis and cast to the float32 dtype fed to the net.
    data = np.vstack(padded)[None].astype(np.float32)
    if BATCH_SIZE is not None:
        # Replicate the same case BATCH_SIZE times along the batch axis.
        data = np.vstack([data] * BATCH_SIZE)
    return data, old_shape


def _mirrored_forward_passes_3D(net, data, do_mirroring, num_repeats, main_device, mirror_axes):
    """Run ``net`` on ``data`` (and its mirrored copies) ``num_repeats`` times.

    Each prediction is flipped back along the axes its input was flipped on,
    so every returned array is aligned with the un-mirrored input. Returns the
    list of numpy predictions in pass order.
    """
    # All 8 flip combinations of input axes 2, 3, 4. Entry m encodes
    # bit 0 -> axis 4, bit 1 -> axis 3, bit 2 -> axis 2 (entry 0 = no flip).
    flip_combos = [(), (4,), (3,), (3, 4), (2,), (2, 4), (2, 3), (2, 3, 4)]
    if not do_mirroring:
        flip_combos = flip_combos[:1]
    # A combination is used only if every axis it flips is listed in mirror_axes.
    flip_combos = [axes for axes in flip_combos if all(axis in mirror_axes for axis in axes)]

    # One reusable float32 input tensor on the target device; each pass copies into it.
    net_input = torch.empty(data.shape, dtype=torch.float32)
    if main_device != 'cpu':
        net_input = net_input.cuda(main_device)

    preds = []
    for _ in range(num_repeats):
        for axes in flip_combos:
            # np.flip yields a negative-stride view, which torch.from_numpy rejects,
            # so make it contiguous first.
            flipped = np.ascontiguousarray(np.flip(data, axes)) if axes else data
            net_input.copy_(torch.from_numpy(flipped))
            pred = net(net_input).detach().cpu().numpy()
            # A flip is its own inverse: undo the mirroring on the prediction.
            preds.append(np.flip(pred, axes) if axes else pred)
    return preds


def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
                        new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
    """Predict a 3-D case with optional test-time mirroring and repeated passes.

    Steps:
    1. Zero-pad every sub-volume along the first axis of ``patient_data``
       (via ``pad_patient_3D``) so its spatial shape is divisible by
       ``new_shape_must_be_divisible_by`` and is at least ``min_size``.
    2. Build a float32 batch from the padded sub-volumes.
    3. Run ``net`` on the batch, and on its mirrored copies if enabled.
    4. Crop all predictions back to the unpadded spatial shape and aggregate.

    Args:
        net: callable taking a float32 torch tensor of shape (B, C, X, Y, Z) and
            returning a 5-D tensor. Its last three axes are assumed to align
            with the input's (they are un-flipped and cropped accordingly).
        patient_data: 4-D numpy array. The first axis is stacked as the channel
            axis of the net input (presumably imaging modalities; confirm
            with callers).
        do_mirroring: if True, also predict on up to 8 flip combinations of
            input axes 2, 3, 4, and un-flip the outputs.
        num_repeats: number of times the full set of (mirrored) passes is run.
            Presumably > 1 only for stochastic nets, e.g. MC dropout.
        BATCH_SIZE: if given, the case is replicated this many times along the
            batch axis of every forward pass.
        new_shape_must_be_divisible_by: divisor(s) for the padded spatial shape.
        min_size: optional minimum padded spatial shape.
        main_device: ``'cpu'``, or a CUDA device forwarded to ``Tensor.cuda``.
        mirror_axes: input axes that may be flipped. A flip combination is
            used only if all of its axes are listed here.

    Returns:
        Tuple ``(predicted_segmentation, bayesian_predictions, softmax_pred,
        uncertainty)``:
        - ``predicted_segmentation``: argmax over the output-channel axis of
          the mean prediction.
        - ``bayesian_predictions``: all cropped predictions, stacked along
          axis 0.
        - ``softmax_pred``: mean of all predictions over axis 0.
        - ``uncertainty``: variance of all predictions over axis 0.

    Raises:
        ValueError: if no forward pass was run (``num_repeats`` < 1).
    """
    with torch.no_grad():
        data, old_shape = _pad_and_batch_case_3D(patient_data, BATCH_SIZE,
                                                 new_shape_must_be_divisible_by, min_size)
        all_preds = _mirrored_forward_passes_3D(net, data, do_mirroring, num_repeats,
                                                main_device, mirror_axes)
        if not all_preds:
            raise ValueError("num_repeats must be at least 1, got %r" % (num_repeats,))

        # Stack every pass along axis 0 and crop the padding off the spatial axes.
        stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
        softmax_pred = stacked.mean(0)
        predicted_segmentation = softmax_pred.argmax(0)
        uncertainty = stacked.var(0)
        bayesian_predictions = stacked
        return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
|
|
|