from functools import lru_cache
from typing import Tuple, Optional

import numpy as np
import torch
from batchgeneratorsv2.helpers.scalar_type import sample_scalar, RandomScalar
from scipy.ndimage import distance_transform_edt
from skimage.morphology import disk, ball


@lru_cache(maxsize=5)
def build_point(radii, use_distance_transform, binarize):
    """
    Builds the structuring element (a disk in 2D, a ball in 3D) that represents a single point.

    The element is first created isotropically with the largest radius and then, if needed,
    interpolated to the per-axis size ``round(2 * r + 1)``.

    Results are memoized with ``lru_cache``: all arguments must be hashable (pass ``radii`` as a
    tuple) and the returned tensor is shared between calls with identical arguments, so callers
    should not modify it in place.

    Parameters:
        radii (tuple): One radius per spatial dimension. Its length (2 or 3) selects disk vs. ball.
        use_distance_transform (bool): If True, the element is replaced by the Euclidean distance
            transform of its thresholded (>= 0.5) version, divided by its maximum.
        binarize (bool): If True (and use_distance_transform is False), the element is thresholded
            at 0.5 so that it only contains 0 and 1.

    Returns:
        torch.Tensor: The structuring element. float32 on the non-distance-transform paths; on the
            distance-transform path the dtype is whatever ``distance_transform_edt`` returns
            (NOTE(review): float64 according to the SciPy documentation - confirm if it matters).

    Raises:
        ValueError: If ``radii`` does not have 2 or 3 entries.
    """
    max_radius = max(radii)
    ndim = len(radii)

    # Create a spherical (or circular) structuring element with max_radius
    if ndim == 2:
        structuring_element = disk(max_radius)
    elif ndim == 3:
        structuring_element = ball(max_radius)
    else:
        raise ValueError("Unsupported number of dimensions. Only 2D and 3D are supported.")

    # Convert the structuring element to a tensor
    structuring_element = torch.from_numpy(structuring_element.astype(np.float32))

    # Create the target shape based on the sampled radii
    target_shape = [round(2 * r + 1) for r in radii]

    if any(i != j for i, j in zip(target_shape, structuring_element.shape)):
        structuring_element_resized = torch.nn.functional.interpolate(
            structuring_element.unsqueeze(0).unsqueeze(0),  # Add batch and channel dimensions for interpolation
            size=target_shape,
            mode='trilinear' if ndim == 3 else 'bilinear',
            align_corners=False
        )[0, 0]  # Remove batch and channel dimensions after interpolation
    else:
        structuring_element_resized = structuring_element

    if use_distance_transform:
        # Convert the structuring element to a binary mask for distance transform computation
        binary_structuring_element = (structuring_element_resized >= 0.5).numpy()

        # Compute the Euclidean distance transform of the binary structuring element
        distance_map = distance_transform_edt(binary_structuring_element)

        # Normalize the distance transform to have values between 0 and 1. Skip the division when
        # the maximum is 0: 0 / 0 would fill the element with NaN, which torch.maximum would then
        # propagate into the interaction map.
        peak = distance_map.max()
        if peak > 0:
            distance_map /= peak
        structuring_element_resized = torch.from_numpy(distance_map)

    if binarize and not use_distance_transform:
        # Normalize the resized structuring element to binary (values near 1 are treated as the point region)
        structuring_element_resized = (structuring_element_resized >= 0.5).float()

    return structuring_element_resized


class PointInteraction_stub():
    interaction_type = 'point'

    def __init__(self, point_radius: RandomScalar, use_distance_transform: bool = False):
        """
        Initializes the PointInteraction object.

        Parameters:
            point_radius (RandomScalar): Specifies the radius for the interaction points.
            use_distance_transform (bool): Determines whether to use a distance transform for smooth interactions.
        """
        super().__init__()
        self.point_radius = point_radius
        self.use_distance_transform = use_distance_transform

    def place_point(self, position: Tuple[int, ...], interaction_map: torch.Tensor,
                    binarize: bool = False) -> torch.Tensor:
        """
        Places a point on the interaction map around the specified position.

        The point is merged with the existing content via an element-wise maximum. The map is
        modified in place and the same tensor object is returned. Parts of the point that fall
        outside the map are cropped; if the point lies completely outside, a message is printed
        and the map is returned unchanged.

        Parameters:
            position (Tuple[int, ...]): The (x, y, z) coordinates where the point should be placed.
            interaction_map (torch.Tensor): A tensor representing the interaction map where the point should be placed.
                                             The shape should match the volume dimensions.
            binarize (bool): If True, inserts a binary mask. If False, may insert smooth values based on distance.

        Returns:
            torch.Tensor: Updated interaction map with the point added.
        """
        ndim = interaction_map.ndim

        # Determine the radius for each dimension
        radius = tuple(sample_scalar(self.point_radius, d, interaction_map.shape) for d in range(ndim))

        strel = build_point(radius, self.use_distance_transform, binarize)

        # Bounding box of the point per dimension as [start, stop) - the upper bound is exclusive.
        # size // 2 voxels go below the position, the remaining size // 2 + size % 2 at and above it.
        bbox = [[position[i] - strel.shape[i] // 2, position[i] + strel.shape[i] // 2 + strel.shape[i] % 2]
                for i in range(ndim)]

        # detect if bbox is completely outside interaction_map. Because the upper bound is
        # exclusive, a bbox ending at 0 or starting at the map size has no overlap with the map.
        if any(b[1] <= 0 for b in bbox) or any(b[0] >= s for b, s in zip(bbox, interaction_map.shape)):
            print('Point is outside the interaction map! Ignoring')
            print(f'Position: {position}')
            print(f'Interaction map shape: {interaction_map.shape}')
            print(f'Point bbox would have been {bbox}')
            return interaction_map

        # Clip the bounding box to the bounds of the interaction map
        slices = tuple(slice(max(0, bbox[i][0]), min(interaction_map.shape[i], bbox[i][1])) for i in range(ndim))

        # Calculate where the resized structuring element should be placed within the slices:
        # skip whatever was cut off at the lower border and take as many voxels as the clipped slice holds
        structuring_slices = tuple(
            slice(max(0, -bbox[i][0]), slices[i].stop - slices[i].start + max(0, -bbox[i][0]))
            for i in range(ndim)
        )

        # Place the resized structuring element into the interaction map (writes into a view of the map)
        torch.maximum(interaction_map[slices], strel[structuring_slices].to(interaction_map.device),
                      out=interaction_map[slices])
        return interaction_map