# frankenstein/dataprocesser/reconstruct_patch_to_volume.py
# branch: zy7_oldserver (rev 1, commit fd601de)
import torch
def initialize_collection(first_data):
    """Allocate the buffers needed to reassemble patches into one volume.

    Args:
        first_data: a batch dict providing 'original_spatial_shape' (an
            iterable of per-dimension size tensors — presumably one tensor
            of sizes per spatial axis; TODO confirm against the loader) and
            'img' (a patch tensor whose dtype the empty volume should match).

    Returns:
        Tuple of (collected_patches, collected_coords, reconstructed_volume,
        count_volume): two fresh empty lists, a zero-filled volume of the
        sample patch's dtype, and an int zero-filled counter of the same shape.
    """
    sample_patch = first_data['img']
    # The target volume spans the largest extent reported for each dimension.
    volume_shape = tuple(
        torch.max(dim_sizes).item()
        for dim_sizes in first_data['original_spatial_shape']
    )
    reconstructed_volume = torch.zeros(volume_shape, dtype=sample_patch.dtype)
    print('empty volume_shape:', volume_shape)
    # Tracks which voxels have received at least one patch.
    count_volume = torch.zeros(volume_shape, dtype=torch.int)
    return [], [], reconstructed_volume, count_volume
def reconstruct_volume(collected_patches, collected_coords, reconstructed_volume, count_volume=None):
    """Write collected patches back into the full volume in place.

    Args:
        collected_patches: list of patch batches; each is indexable as
            [batch][channel][x][y][z]. Only channel 0 is written back
            (presumably single-channel data — TODO confirm).
        collected_coords: list parallel to collected_patches; each entry is
            indexable as [batch][axis] -> (start, end) pairs ordered
            (channel, x, y, z). The channel span is recorded but unused.
        reconstructed_volume: pre-allocated (X, Y, Z) tensor to fill.
        count_volume: optional (X, Y, Z) int tensor; voxels covered by at
            least one patch are set to 1 (coverage mask, not an overlap
            count — overlapping patches overwrite, they are not averaged).

    Returns:
        (reconstructed_volume, count_volume), both modified in place.
    """
    print('batch_num:', len(collected_patches))
    for data, patch_coords in zip(collected_patches, collected_coords):
        # Use this batch's own size: the final batch of a DataLoader is often
        # smaller than the rest; indexing with the first batch's size would
        # raise IndexError on the short last batch.
        for batch_idx in range(data.shape[0]):
            data_patch_idx = data[batch_idx]
            patch_coords_idx = patch_coords[batch_idx]
            channel_start, channel_end = patch_coords_idx[0]
            x_start, x_end = patch_coords_idx[1]
            y_start, y_end = patch_coords_idx[2]
            z_start, z_end = patch_coords_idx[3]
            # Place the patch in the reconstructed volume. A shape mismatch in
            # torch slice assignment raises RuntimeError (not IndexError), so
            # both are caught to make the diagnostics below actually reachable.
            try:
                reconstructed_volume[x_start:x_end, y_start:y_end, z_start:z_end] = data_patch_idx[0]
                if count_volume is not None:
                    count_volume[x_start:x_end, y_start:y_end, z_start:z_end] = 1
            except (IndexError, RuntimeError) as e:
                print(f"IndexError: {e} - check patch coordinates and dimensions")
                print('patch_coords_idx:', patch_coords_idx)
                print('data shape:', data_patch_idx.shape)
                print('to fill shape:', reconstructed_volume[x_start:x_end, y_start:y_end, z_start:z_end].shape)
                print('check the div_size and patch_size, they should be at least the same')
    return reconstructed_volume, count_volume