import torch


def initialize_collection(first_data):
    """Create empty containers for patch-based volume reconstruction.

    Args:
        first_data: A collated batch (mapping) from the patch data loader.
            ``first_data['original_spatial_shape']`` is iterated with one
            tensor per spatial dimension; presumably each holds that
            dimension's size for every item in the batch (TODO confirm
            against the loader). ``first_data['img']`` is only used for its
            dtype.

    Returns:
        ``(collected_patches, collected_coords, reconstructed_volume,
        count_volume)``: two empty lists for the caller to append patch
        batches and their coordinates to, a zero volume whose shape is the
        per-dimension maximum of ``original_spatial_shape`` (dtype taken
        from ``first_data['img']``), and a zero ``torch.int`` volume of the
        same shape.
    """
    collected_patches = []
    collected_coords = []
    original_spatial_shape = first_data['original_spatial_shape']
    data_patch_0 = first_data['img']
    # Largest size seen for each spatial dimension, as plain Python ints.
    volume_shape = tuple(
        torch.max(dim_shape).item() for dim_shape in original_spatial_shape
    )
    reconstructed_volume = torch.zeros(volume_shape, dtype=data_patch_0.dtype)
    print('empty volume_shape:', volume_shape)
    # Coverage volume: reconstruct_volume sets every voxel touched by a
    # successfully pasted patch to 1 (it marks coverage, it does not count).
    count_volume = torch.zeros(volume_shape, dtype=torch.int)
    return collected_patches, collected_coords, reconstructed_volume, count_volume


def _region(patch_coords_idx):
    """Return the (x, y, z) slices described by one patch's coordinates."""
    # Entry 0 is the channel (start, end) pair, which is not used for
    # placement; entries 1-3 are the spatial (start, end) pairs.
    return tuple(
        slice(int(start), int(end))
        for start, end in (patch_coords_idx[1], patch_coords_idx[2], patch_coords_idx[3])
    )


def reconstruct_volume(collected_patches, collected_coords, reconstructed_volume,
                       count_volume=None):
    """Paste collected patch batches back into a full volume, in place.

    Each patch overwrites the region given by its coordinates, so where
    patches overlap the one pasted last wins; no averaging is performed.

    Args:
        collected_patches: List of patch batches. Each batch is iterated
            item by item, and ``item[0]`` (presumably the first channel) is
            what gets pasted. Batches may differ in size, e.g. a smaller
            final batch.
        collected_coords: List parallel to ``collected_patches``;
            ``collected_coords[b][i]`` holds four ``(start, end)`` pairs.
            Pair 0 (channel) is ignored; pairs 1-3 give the x, y, z ranges.
        reconstructed_volume: 3-D tensor, written in place.
        count_volume: Optional tensor of the same shape. Voxels covered by a
            successfully pasted patch are set to 1.

    Returns:
        ``(reconstructed_volume, count_volume)``: the same objects passed in.

    A patch that does not fit its target region is skipped after printing
    diagnostics instead of aborting the whole reconstruction.
    """
    batch_num = len(collected_patches)
    print('batch_num:', batch_num)
    for data_idx, data in enumerate(collected_patches):
        patch_coords = collected_coords[data_idx]
        # Iterate this batch's own items: the last batch from a loader is
        # often smaller than the first, so a fixed batch size would overrun.
        for batch_idx, data_patch_idx in enumerate(data):
            patch_coords_idx = patch_coords[batch_idx]
            region = _region(patch_coords_idx)
            try:
                reconstructed_volume[region] = data_patch_idx[0]
                if count_volume is not None:
                    count_volume[region] = 1
            # torch signals a patch/region shape mismatch in slice assignment
            # with RuntimeError, so catch it alongside IndexError.
            except (IndexError, RuntimeError) as e:
                print(f"{type(e).__name__}: {e} - check patch coordinates and dimensions")
                print('patch_coords_idx:', patch_coords_idx)
                print('data shape:', data_patch_idx.shape)
                print('to fill shape:', reconstructed_volume[region].shape)
                print('check the div_size and patch_size, they should be at least the same')
    return reconstructed_volume, count_volume