Spaces:
Sleeping
Sleeping
| from typing import Sequence | |
| import torch | |
| import torch.nn.functional as F | |
| def crop_and_pad_into_buffer(target_tensor: torch.Tensor, | |
| bbox: Sequence[Sequence[int]], | |
| source_tensor: torch.Tensor) -> None: | |
| """ | |
| Copies a sub-region from source_tensor into target_tensor based on a bounding box. | |
| Args: | |
| target_tensor (torch.Tensor): A preallocated tensor that will be updated. | |
| bbox (sequence of [int, int]): A bounding box for each dimension of the source tensor | |
| that is covered by the bbox. The bbox is defined as [start, end) (half-open interval) | |
| and may extend outside the source tensor. If source_tensor has more dimensions than | |
| len(bbox), the leading dimensions will be fully included. | |
| source_tensor (torch.Tensor): The tensor to copy data from. | |
| Behavior: | |
| For each dimension that the bbox covers (i.e. the last len(bbox) dims of source_tensor): | |
| - Compute the overlapping region between the bbox and the source tensor. | |
| - Determine the corresponding indices in the target tensor where the data will be copied. | |
| For any extra leading dimensions (i.e. source_tensor.ndim > len(bbox)): | |
| - Use slice(None) to include the entire dimension. | |
| If source_tensor and target_tensor are on different devices, only the overlapping subregion | |
| is transferred to the device of target_tensor. | |
| """ | |
| total_dims = source_tensor.ndim | |
| bbox_dims = len(bbox) | |
| # Compute the number of leading dims that are not covered by bbox. | |
| leading_dims = total_dims - bbox_dims | |
| source_slices = [] | |
| target_slices = [] | |
| # For the leading dimensions, include the entire dimension. | |
| for _ in range(leading_dims): | |
| source_slices.append(slice(None)) | |
| target_slices.append(slice(None)) | |
| # Process the dimensions covered by the bbox. | |
| for d in range(bbox_dims): | |
| box_start, box_end = bbox[d] | |
| d_source = leading_dims + d | |
| source_size = source_tensor.shape[d_source] | |
| # Compute the overlapping region in source coordinates. | |
| copy_start_source = max(box_start, 0) | |
| copy_end_source = min(box_end, source_size) | |
| copy_size = copy_end_source - copy_start_source | |
| # Compute the corresponding indices in the target tensor. | |
| copy_start_target = max(0, -box_start) | |
| copy_end_target = copy_start_target + copy_size | |
| source_slices.append(slice(copy_start_source, copy_end_source)) | |
| target_slices.append(slice(copy_start_target, copy_end_target)) | |
| # Extract the overlapping region from the source. | |
| sub_source = source_tensor[tuple(source_slices)] | |
| # Transfer only this subregion to the target tensor's device. | |
| sub_source = sub_source.to(target_tensor.device) if isinstance(target_tensor, torch.Tensor) else sub_source.cpu() | |
| # Write the data into the preallocated target_tensor. | |
| target_tensor[tuple(target_slices)] = sub_source | |
| def paste_tensor(target: torch.Tensor, source: torch.Tensor, bbox): | |
| """ | |
| Paste a source tensor into a target tensor using a given bounding box. | |
| Both tensors are assumed to be 3D. | |
| The bounding box is specified in the coordinate system of the target as: | |
| [[x1, x2], [y1, y2], [z1, z2]] | |
| and its size is assumed to be equal to the shape of the source tensor. | |
| The bbox may exceed the boundaries of the target tensor. | |
| The function computes the valid overlapping region between the bbox and the target, | |
| and then adjusts the corresponding region in the source tensor so that only the valid | |
| parts are pasted. | |
| Args: | |
| target (torch.Tensor): The target tensor of shape (T0, T1, T2). | |
| source (torch.Tensor): The source tensor of shape (S0, S1, S2). It must be the same size as | |
| the bbox, i.e. S0 = x2 - x1, etc. | |
| bbox (list or tuple): List of intervals for each dimension: [[x1, x2], [y1, y2], [z1, z2]]. | |
| Returns: | |
| torch.Tensor: The target tensor after pasting in the source. | |
| """ | |
| target_shape = target.shape # (T0, T1, T2) | |
| # For each dimension compute: | |
| # - The valid region in the target: [t_start, t_end) | |
| # - The corresponding region in the source: [s_start, s_end) | |
| target_indices = [] | |
| source_indices = [] | |
| for i, (b0, b1) in enumerate(bbox): | |
| # Determine valid region in target tensor: | |
| t_start = max(b0, 0) | |
| t_end = min(b1, target_shape[i]) | |
| # If there's no overlap in any dimension, nothing gets pasted. | |
| if t_start >= t_end: | |
| return target | |
| # Determine corresponding indices in the source tensor. | |
| # The source's coordinate 0 corresponds to b0 in the target. | |
| s_start = t_start - b0 | |
| s_end = s_start + (t_end - t_start) | |
| target_indices.append((t_start, t_end)) | |
| source_indices.append((s_start, s_end)) | |
| # Paste the corresponding region from source into target. | |
| if isinstance(target, torch.Tensor): | |
| target[target_indices[0][0]:target_indices[0][1], | |
| target_indices[1][0]:target_indices[1][1], | |
| target_indices[2][0]:target_indices[2][1]] = \ | |
| source[source_indices[0][0]:source_indices[0][1], | |
| source_indices[1][0]:source_indices[1][1], | |
| source_indices[2][0]:source_indices[2][1]].to(target.device) | |
| else: | |
| target[target_indices[0][0]:target_indices[0][1], | |
| target_indices[1][0]:target_indices[1][1], | |
| target_indices[2][0]:target_indices[2][1]] = \ | |
| source[source_indices[0][0]:source_indices[0][1], | |
| source_indices[1][0]:source_indices[1][1], | |
| source_indices[2][0]:source_indices[2][1]].cpu() | |
| return target | |
| def crop_to_valid(img: torch.Tensor, bbox): | |
| """ | |
| Crops the image to the part of the bounding box that lies within the image. | |
| Supports a 4D tensor of shape (C, X, Y, Z). The bounding box is specified as | |
| [[x1, x2], [y1, y2], [z1, z2]] with half-open intervals. | |
| Args: | |
| img (torch.Tensor): Input tensor of shape (C, X, Y, Z). | |
| bbox (list or tuple): Bounding box as a list of three intervals for spatial dims: | |
| [[x1, x2], [y1, y2], [z1, z2]]. | |
| Returns: | |
| cropped (torch.Tensor): Cropped tensor of shape (C, cropped_x, cropped_y, cropped_z). | |
| pad (list of tuples): A list [(pad_x_left, pad_x_right), | |
| (pad_y_left, pad_y_right), | |
| (pad_z_left, pad_z_right)] | |
| indicating how much padding needs to be applied on each side. | |
| """ | |
| # Only spatial dimensions (X, Y, Z) are cropped; channels are preserved. | |
| spatial_dims = img.shape[1:] # (X, Y, Z) | |
| crop_indices = [] | |
| pad = [] # for each spatial dimension | |
| for i, (start, end) in enumerate(bbox): | |
| dim_size = spatial_dims[i] | |
| # Clamp the indices to the valid range for cropping. | |
| crop_start = max(start, 0) | |
| crop_end = min(end, dim_size) | |
| crop_indices.append((crop_start, crop_end)) | |
| # Calculate padding if the bbox goes out-of-bound. | |
| pad_left = -start if start < 0 else 0 | |
| pad_right = end - dim_size if end > dim_size else 0 | |
| pad.append((pad_left, pad_right)) | |
| # Crop the image on spatial dimensions, leaving the channel dimension intact. | |
| cropped = img[:, | |
| crop_indices[0][0]:crop_indices[0][1], | |
| crop_indices[1][0]:crop_indices[1][1], | |
| crop_indices[2][0]:crop_indices[2][1]] | |
| return cropped, pad | |
| def pad_cropped(cropped: torch.Tensor, pad): | |
| """ | |
| Pads the cropped image using the given pad amounts. | |
| Supports a 4D tensor of shape (C, X, Y, Z) and applies padding only on the spatial dimensions. | |
| For 3D (volumetric) padding, F.pad expects a 5D tensor with shape (N, C, X, Y, Z). | |
| Hence, we temporarily add a dummy batch dimension. | |
| Args: | |
| cropped (torch.Tensor): Cropped tensor of shape (C, X, Y, Z). | |
| pad (list of tuples): List of padding for each spatial dimension, in order (x, y, z): | |
| [(pad_x_left, pad_x_right), | |
| (pad_y_left, pad_y_right), | |
| (pad_z_left, pad_z_right)]. | |
| Returns: | |
| padded (torch.Tensor): Padded tensor of shape (C, desired_x, desired_y, desired_z), | |
| where the spatial dimensions match the bbox size. | |
| """ | |
| # F.pad for 3D data expects a 5D input (N, C, X, Y, Z) and a pad tuple of length 6: | |
| # (pad_z_left, pad_z_right, pad_y_left, pad_y_right, pad_x_left, pad_x_right) | |
| need_unsqueeze = (cropped.dim() == 4) | |
| if need_unsqueeze: | |
| cropped = cropped.unsqueeze(0) # Now shape is (1, C, X, Y, Z) | |
| # Reverse the pad list (currently in order x, y, z) to match F.pad's expected order: z, y, x. | |
| pad_rev = pad[::-1] | |
| pad_flat = [p for pair in pad_rev for p in pair] | |
| padded = F.pad(cropped, pad_flat) | |
| if need_unsqueeze: | |
| padded = padded.squeeze(0) | |
| return padded |