Divyanshu Tak
Add BrainIAC IDH Classification app with Vision Transformer model
65bee5d
import torch
import numpy as np
def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
    """Zero-pad a 3D volume at its upper end so each axis is a multiple of a divisor.

    Args:
        patient: 3D array (anything with ``.shape`` and ``.dtype`` accepted by
            ``reshape_by_padding_upper_coords``).
        shape_must_be_divisible_by: a single divisor applied to all three
            axes, or a per-axis sequence/array of three divisors.
        min_size: optional lower bound on the padded size, per axis or scalar.

    Returns:
        tuple: ``(padded, original_shape)``. In ``padded``, every axis is
        rounded up to a multiple of its divisor and then raised to
        ``min_size``. Cells beyond the original data are filled with 0.
        ``original_shape`` is ``patient.shape``.
    """
    # np.ndim treats lists, tuples and numpy arrays alike. Only a true scalar
    # is broadcast to all three axes; previously a numpy array of divisors was
    # wrongly wrapped as a single element.
    if np.ndim(shape_must_be_divisible_by) == 0:
        shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
    shp = patient.shape
    # Round each axis up to the next multiple of its divisor.
    # Sizes that are already multiples are left unchanged.
    new_shp = [
        -(-shp[i] // shape_must_be_divisible_by[i]) * shape_must_be_divisible_by[i]
        for i in range(3)
    ]
    if min_size is not None:
        # Elementwise max; also accepts a scalar min_size.
        new_shp = np.maximum(new_shp, min_size)
    return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
    """Pad ``image`` at the upper end of every axis up to ``new_shape``.

    The original data occupies the low-index corner of the result, and the
    remaining cells are filled with ``pad_value``. Axes are never cropped:
    the output size of axis ``i`` is ``max(image.shape[i], new_shape[i])``.
    Works for arrays of any rank.

    Args:
        image: array exposing ``.shape`` and ``.dtype`` (e.g. ``np.ndarray``).
        new_shape: target size per axis, with one entry per image axis.
        pad_value: fill value for the padded region. Defaults to the image's
            first (corner) element ``image[0, ..., 0]``.

    Returns:
        np.ndarray: the padded array.

    Raises:
        ValueError: if ``new_shape`` does not have one entry per image axis.
    """
    shape = tuple(image.shape)
    target = np.asarray(new_shape).ravel()
    if target.size != len(shape):
        raise ValueError(
            f"new_shape has {target.size} entries but image has {len(shape)} dimensions"
        )
    # Never shrink an axis: take the larger of current and requested size.
    new_shape = tuple(np.maximum(shape, target))
    if pad_value is None:
        # The corner voxel is used as the fill value.
        pad_value = image[(0,) * len(shape)]
    # Multiplying, rather than using np.full, keeps the original NumPy type
    # promotion, e.g. a float pad_value on an integer image yields floats.
    res = np.ones(new_shape, dtype=image.dtype) * pad_value
    # Copy the image into the low-index corner, for any number of axes.
    # Previously, ranks other than 2/3 with an explicit pad_value silently
    # returned pure padding.
    res[tuple(slice(0, n) for n in shape)] = image
    return res
def _flip_axes_for_pass(pass_idx, mirror_axes):
    """Return the axes flipped by mirroring pass ``pass_idx``, or None to skip it.

    Passes 0-7 enumerate every subset of the spatial axes (2, 3, 4). Bit 0
    selects axis 4, bit 1 selects axis 3 and bit 2 selects axis 2, so pass 0
    is the un-mirrored input and pass 7 flips all three axes. A pass that
    would flip an axis not listed in ``mirror_axes`` is skipped.
    """
    axes = tuple(axis for bit, axis in ((4, 2), (2, 3), (1, 4)) if pass_idx & bit)
    if all(axis in mirror_axes for axis in axes):
        return axes
    return None


def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
                        new_shape_must_be_divisible_by=16, min_size=None, main_device=0,
                        mirror_axes=(2, 3, 4)):
    """Run ``net`` on a padded multi-channel 3D case and aggregate the predictions.

    Each channel of ``patient_data`` is zero-padded at its upper end so its
    spatial shape is a multiple of ``new_shape_must_be_divisible_by`` (and at
    least ``min_size``). The padded case is then evaluated ``num_repeats``
    times. When ``do_mirroring`` is set, each repeat evaluates every allowed
    mirrored variant. Every prediction is flipped back, cropped to the
    unpadded shape and aggregated.

    Args:
        net: callable mapping a float32 tensor shaped like the padded batch to
            a tensor. Assumed to return (B, C_out, X, Y, Z) — TODO confirm.
        patient_data: array indexable by channel; each channel is a 3D volume.
        do_mirroring: if true, evaluate up to 8 flipped variants per repeat.
        num_repeats: how many times the whole set of passes is repeated.
            Inputs are identical across repeats, so results only differ if
            ``net`` is stochastic.
        BATCH_SIZE: if given, the case is replicated this many times along the
            batch axis. All copies contribute to the aggregated outputs.
        new_shape_must_be_divisible_by: scalar or per-axis padding divisor.
        min_size: optional per-axis minimum padded size.
        main_device: ``'cpu'`` to stay on CPU; otherwise passed to
            ``Tensor.cuda``.
        mirror_axes: input axes that may be flipped (subset of 2, 3, 4).

    Returns:
        tuple: ``(predicted_segmentation, bayesian_predictions, softmax_pred,
        uncertainty)``, where:

        - ``bayesian_predictions`` stacks every individual cropped prediction
          along axis 0.
        - ``softmax_pred`` is their mean along axis 0.
        - ``predicted_segmentation`` is the argmax of that mean along its
          first axis (presumably the class axis).
        - ``uncertainty`` is the variance along axis 0.

    Raises:
        ValueError: if ``num_repeats < 1`` or ``patient_data`` has no channels.
    """
    if num_repeats < 1:
        raise ValueError("num_repeats must be at least 1")
    with torch.no_grad():
        # Pad every channel. All channels share one spatial shape, so the
        # unpadded shape of the last channel is used for the final crop.
        padded_channels = []
        for channel in patient_data:
            padded, old_shape = pad_patient_3D(channel, new_shape_must_be_divisible_by, min_size)
            padded_channels.append(padded[None])
        if not padded_channels:
            raise ValueError("patient_data must contain at least one channel")
        # (1, C, *spatial) float32 copy of the padded case.
        data = np.vstack(padded_channels)[None].astype(np.float32)
        if BATCH_SIZE is not None:
            data = np.vstack([data] * BATCH_SIZE)

        # Reusable input buffer; its contents are overwritten before every
        # forward pass. torch.rand (rather than torch.empty) is kept so the
        # global torch RNG is advanced exactly as before.
        net_input = torch.rand(data.shape).float()
        if main_device != 'cpu':
            net_input = net_input.cuda(main_device)

        num_passes = 8 if do_mirroring else 1
        all_preds = []
        for _ in range(num_repeats):
            for pass_idx in range(num_passes):
                axes = _flip_axes_for_pass(pass_idx, mirror_axes)
                if axes is None:
                    continue  # this variant flips an axis that is not allowed
                flipped = np.flip(data, axes) if axes else data
                # torch.from_numpy cannot handle negative strides, so flipped
                # views are materialised once here (no-op for pass 0).
                net_input.copy_(torch.from_numpy(np.ascontiguousarray(flipped)))
                pred = net(net_input).data.cpu().numpy()
                if axes:
                    # Undo the input flip so all predictions share one orientation.
                    pred = np.flip(pred, axes)
                all_preds.append(pred)

        # Stack every prediction along axis 0 and crop away the padding.
        stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
        softmax_pred = stacked.mean(0)
        predicted_segmentation = softmax_pred.argmax(0)
        uncertainty = stacked.var(0)
    return predicted_segmentation, stacked, softmax_pred, uncertainty