"""Morphology, smoothing and non-maxima-suppression helpers built on PyTorch."""

from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

# Pillow and torchvision are not referenced by any function defined below.
# They are imported tolerantly so these torch-only helpers remain usable in
# environments where the optional packages are not installed.
try:
    import PIL.Image as Image
except ImportError:
    Image = None
try:
    import torchvision.transforms as transforms
except ImportError:
    transforms = None


def _binary_dilate(mask: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Dilate a {0, 1} float mask with a square structuring element."""
    return F.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)


def _binary_erode(mask: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Erode a {0, 1} float mask with a square structuring element.

    Erosion is the dual of dilation: erode(m) == 1 - dilate(1 - m).  Because
    max pooling ignores its padding, pixels outside the image never cause
    erosion at the border.
    """
    return 1.0 - _binary_dilate(1.0 - mask, kernel_size)


def _as_batched(tensor: torch.Tensor) -> torch.Tensor:
    """Return *tensor* viewed as (B, C, H, W); accepts 2-D, 3-D or 4-D input."""
    ndim = tensor.dim()
    if ndim == 2:  # (H, W)
        return tensor[None, None]
    if ndim == 3:  # (C, H, W)
        return tensor[None]
    if ndim == 4:  # (B, C, H, W)
        return tensor
    raise ValueError(f"Unsupported tensor shape: {tensor.shape}")


def _restore_rank(tensor: torch.Tensor, ndim: int) -> torch.Tensor:
    """Undo ``_as_batched`` for a tensor whose original rank was *ndim*."""
    if ndim == 2:
        return tensor[0, 0]
    if ndim == 3:
        return tensor[0]
    return tensor


def morphological_open(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological opening (erosion followed by dilation) on a binary image.

    Opening removes foreground structures smaller than the structuring element
    while leaving larger structures unchanged.

    Args:
        image (torch.Tensor): image to open; values > 0 are treated as foreground
        kernel_size (int): side length of the square structuring element - roughly
            the size of the speck to be removed. Use an odd value to keep the
            spatial size unchanged.

    Returns:
        torch.Tensor: The opened image as a float32 {0, 1} tensor with the shape of
        ``image.unsqueeze(0)``.
    """
    mask = (image > 0).to(torch.float32).unsqueeze(0)
    eroded = _binary_erode(mask, kernel_size)
    return _binary_dilate(eroded, kernel_size)


def morphological_close(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Perform morphological closing (dilation followed by erosion) on a binary image.

    Closing fills holes and gaps smaller than the structuring element while
    leaving the outline of larger structures unchanged.

    Args:
        image (torch.Tensor): image to close; values > 0 are treated as foreground
        kernel_size (int): side length of the square structuring element - roughly
            the size of hole to be closed. Use an odd value to keep the spatial
            size unchanged.

    Returns:
        torch.Tensor: The closed image as a float32 {0, 1} tensor with the shape of
        ``image.unsqueeze(0)``.
    """
    mask = (image > 0).to(torch.float32).unsqueeze(0)
    dilated = _binary_dilate(mask, kernel_size)
    return _binary_erode(dilated, kernel_size)


def gaussian_convolve(image: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    """
    Smooth an image by convolving it with a normalized Gaussian kernel.

    Args:
        image (torch.Tensor): image to convolve
        kernel_size (int): side length of the square Gaussian kernel (>= 1)
        sigma (float): standard deviation of the Gaussian distribution (> 0)

    Returns:
        torch.Tensor: The convolved image, with the shape of ``image.unsqueeze(0)``
        for odd kernel sizes. Floating-point inputs keep their dtype; other
        dtypes are converted to float32.

    Raises:
        ValueError: If ``kernel_size`` < 1 or ``sigma`` <= 0.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")

    dtype = image.dtype if image.is_floating_point() else torch.float32

    # Sample positions centred on zero (e.g. -2..2 for kernel_size=5).
    offsets = torch.arange(kernel_size, dtype=torch.float32, device=image.device)
    offsets = offsets - (kernel_size - 1) // 2
    # The 2-D Gaussian is separable, so build it as an outer product.
    profile = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    kernel = profile[:, None] * profile[None, :]
    kernel = (kernel / kernel.sum()).to(dtype)

    # Build the kernel on the image's device/dtype so conv2d accepts both.
    return F.conv2d(
        image.to(dtype).unsqueeze(0),
        kernel[None, None],
        stride=1,
        padding=kernel_size // 2,
    )


def hysteresis_filter(image: torch.Tensor, low_threshold: float, high_threshold: float) -> torch.Tensor:
    """
    Hysteresis thresholding - the edge-linking step of Canny edge detection.

    Pixels above ``high_threshold`` are strong edges. Pixels above
    ``low_threshold`` are kept only if they are 8-connected, possibly through
    other weak pixels, to a strong edge.

    Args:
        image (torch.Tensor): image to process; the last two dimensions are
            treated as (H, W) and any leading dimensions as independent images
        low_threshold (float): low threshold for hysteresis
        high_threshold (float): high threshold for hysteresis

    Returns:
        edge (torch.Tensor): float32 {0, 1} tensor of the same shape as ``image``
        marking the detected edges.
    """
    weak = image > low_threshold
    strong = (image > high_threshold) & weak

    if image.numel() == 0:
        return strong.float()

    if image.dim() >= 2:
        height, width = image.shape[-2:]
    else:
        height, width = 1, image.numel()

    weak_4d = weak.reshape(-1, 1, height, width).float()
    edges = strong.reshape(-1, 1, height, width).float()

    # Grow the strong set into neighbouring weak pixels until it stops changing.
    # The set only ever grows and is bounded by `weak`, so this terminates.
    while True:
        grown = F.max_pool2d(edges, kernel_size=3, stride=1, padding=1) * weak_4d
        if torch.equal(grown, edges):
            break
        edges = grown

    return edges.reshape(image.shape)


def non_maxima_suppression_2d(
    image: torch.Tensor,
    kernel_size: int = 3,
    threshold: Optional[float] = None,
    return_mask: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Perform non-maxima suppression on a 2D torch tensor (image).

    A pixel survives when it is strictly positive and equals the maximum of
    its ``kernel_size`` x ``kernel_size`` neighbourhood.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W) or (B, C, H, W) or (C, H, W)
        kernel_size (int): Odd size of the local neighborhood for maxima detection (default: 3)
        threshold (float, optional): Minimum value threshold for considering pixels
        return_mask (bool): If True, return both suppressed image and binary mask

    Returns:
        torch.Tensor: Image with non-maxima suppressed
        torch.Tensor (optional): Binary mask of local maxima if return_mask=True

    Raises:
        ValueError: If the input is not 2-D, 3-D or 4-D, or if ``kernel_size``
            is not a positive odd integer.
    """
    ndim = image.dim()
    image = _as_batched(image)

    # An even window changes the pooled output size, so it cannot be compared
    # element-wise with the input.
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")

    if threshold is not None:
        image = torch.where(image >= threshold, image, torch.zeros_like(image))

    neighbourhood_max = F.max_pool2d(
        image, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
    )

    # Local maxima: the pixel equals its neighbourhood maximum and is positive.
    mask = (image == neighbourhood_max) & (image > 0)
    suppressed = image * mask.float()

    suppressed = _restore_rank(suppressed, ndim)
    mask = _restore_rank(mask, ndim)

    if return_mask:
        return suppressed, mask
    return suppressed


def non_maxima_suppression_with_orientation(
    magnitude: torch.Tensor,
    orientation: torch.Tensor,
    threshold: Optional[float] = None
) -> torch.Tensor:
    """
    Perform oriented non-maxima suppression (commonly used in edge detection).

    Each pixel is compared with its two neighbours along the quantized
    orientation (0, 45, 90 or 135 degrees) and kept only if its magnitude is
    non-zero and at least as large as both. Pixels outside the image count
    as zero.

    Args:
        magnitude (torch.Tensor): Gradient magnitude tensor of shape (H, W), (C, H, W) or (B, C, H, W)
        orientation (torch.Tensor): Gradient orientation tensor (in radians) of same shape
        threshold (float, optional): Minimum magnitude threshold

    Returns:
        torch.Tensor: Non-maxima suppressed magnitude

    Raises:
        ValueError: If the input rank is unsupported or the two shapes differ.
    """
    if orientation.shape != magnitude.shape:
        raise ValueError(
            f"orientation shape {tuple(orientation.shape)} does not match "
            f"magnitude shape {tuple(magnitude.shape)}"
        )

    ndim = magnitude.dim()
    magnitude = _as_batched(magnitude)
    orientation = _as_batched(orientation)
    height, width = magnitude.shape[-2:]

    if threshold is not None:
        magnitude = torch.where(magnitude >= threshold, magnitude, torch.zeros_like(magnitude))

    # Convert orientation to degrees and normalize to [0, 180).
    angle = torch.rad2deg(orientation) % 180

    # Zero-pad so that border pixels compare against zero outside the image.
    padded = F.pad(magnitude, (1, 1, 1, 1), mode='constant', value=0)

    def shifted(dy: int, dx: int) -> torch.Tensor:
        """Return the neighbour at offset (dy, dx) for every pixel."""
        return padded[..., 1 + dy:1 + dy + height, 1 + dx:1 + dx + width]

    # `>= 157.5` has no upper bound on purpose: a value that rounds to exactly
    # 180.0 after the modulo still belongs to the horizontal bin.
    horizontal = (angle < 22.5) | (angle >= 157.5)
    diagonal_45 = (angle >= 22.5) & (angle < 67.5)
    vertical = (angle >= 67.5) & (angle < 112.5)
    # Everything else (112.5 <= angle < 157.5) uses the 135-degree diagonal.

    neighbor1 = torch.where(
        horizontal, shifted(0, -1),
        torch.where(diagonal_45, shifted(-1, 1),
                    torch.where(vertical, shifted(-1, 0), shifted(-1, -1))),
    )
    neighbor2 = torch.where(
        horizontal, shifted(0, 1),
        torch.where(diagonal_45, shifted(1, -1),
                    torch.where(vertical, shifted(1, 0), shifted(1, 1))),
    )

    keep = (magnitude != 0) & (magnitude >= neighbor1) & (magnitude >= neighbor2)
    suppressed = torch.where(keep, magnitude, torch.zeros_like(magnitude))

    return _restore_rank(suppressed, ndim)


def adaptive_non_maxima_suppression(
    image: torch.Tensor,
    num_points: int,
    min_distance: int = 5,
    threshold: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Adaptive non-maxima suppression that selects a fixed number of strongest points
    while maintaining minimum distance between them.

    Local maxima of a 3x3 neighbourhood are visited from strongest to weakest
    (ties in row-major order). A candidate is accepted when its Euclidean
    distance to every point already accepted is at least ``min_distance``.

    Args:
        image (torch.Tensor): Input tensor of shape (H, W)
        num_points (int): Maximum number of points to select
        min_distance (int): Minimum distance between selected points
        threshold (float, optional): Minimum value threshold

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: float32 coordinates (y, x) of shape
        (N, 2) and float32 values of shape (N,) of the selected points

    Raises:
        ValueError: If ``image`` is not 2-D.
    """
    if image.dim() != 2:
        raise ValueError("Input must be a 2D tensor")

    device = image.device

    # Thresholding is delegated to the NMS helper.
    peaks = non_maxima_suppression_2d(image, kernel_size=3, threshold=threshold)

    y_coords, x_coords = torch.nonzero(peaks > 0, as_tuple=True)
    values = peaks[y_coords, x_coords]

    if values.numel() == 0 or num_points <= 0:
        return torch.empty((0, 2), device=device), torch.empty(0, device=device)

    # A stable sort keeps equal-valued peaks in row-major order.
    values, order = torch.sort(values, descending=True, stable=True)

    # Move the candidates to Python lists once instead of calling .item() per
    # element inside the selection loop.
    candidates = zip(y_coords[order].tolist(), x_coords[order].tolist(), values.tolist())

    # Compare squared distances to avoid a sqrt per pair; a non-positive
    # min_distance imposes no constraint.
    min_dist_sq = min_distance ** 2 if min_distance > 0 else 0

    selected_coords = []
    selected_values = []
    for cand_y, cand_x, cand_val in candidates:
        if len(selected_coords) >= num_points:
            break
        too_close = any(
            (cand_y - sel_y) ** 2 + (cand_x - sel_x) ** 2 < min_dist_sq
            for sel_y, sel_x in selected_coords
        )
        if not too_close:
            selected_coords.append((cand_y, cand_x))
            selected_values.append(cand_val)

    coords_tensor = torch.tensor(selected_coords, device=device, dtype=torch.float32)
    values_tensor = torch.tensor(selected_values, device=device, dtype=torch.float32)
    return coords_tensor, values_tensor