# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
SAM2-style post-processing utilities for mask segmentation.

This module provides shared post-processing functions used by both the
MaskLanguageLitModule (validation/testing) and the demo script.
"""

from typing import Tuple, Optional, Dict
import time

import numpy as np
import torch

try:
    from cuml.cluster import DBSCAN
except ImportError:
    DBSCAN = None


def calculate_stability_score(
    masks: torch.Tensor,
    mask_threshold: float = 0.0,
    threshold_offset: float = 1.0,
) -> torch.Tensor:
    """
    Computes the stability score for a set of masks.

    The stability score is the IoU between the binary masks obtained by
    thresholding at (mask_threshold + threshold_offset) and
    (mask_threshold - threshold_offset). For a non-negative offset the
    high-threshold mask is a subset of the low-threshold mask, so the IoU
    reduces to the ratio of their point counts. High stability means sharp
    mask boundaries.

    Args:
        masks: [Q, N] mask logits
        mask_threshold: Base threshold (usually 0.0 for logits)
        threshold_offset: Offset to apply for high/low thresholds

    Returns:
        stability_score: [Q] stability score per mask
    """
    high_count = (masks > (mask_threshold + threshold_offset)).float().sum(-1)
    low_count = (masks > (mask_threshold - threshold_offset)).float().sum(-1)
    # intersection == high_count and union == low_count because high ⊆ low.
    return high_count / (low_count + 1e-6)


def apply_nms(
    masks_binary: torch.Tensor,
    scores: torch.Tensor,
    nms_thresh: float = 0.7,
) -> torch.Tensor:
    """
    Applies greedy NMS on masks using pairwise IoU.

    Masks are visited in descending score order; each kept mask suppresses
    every remaining mask whose IoU with it is >= ``nms_thresh``.

    Args:
        masks_binary: [Q, N] binary masks (booleans or 0/1 floats)
        scores: [Q] mask scores for ranking
        nms_thresh: IoU threshold for suppression

    Returns:
        keep_indices: long tensor of indices (into the inputs) to keep,
            ordered by descending score.
    """
    masks_binary = masks_binary.bool()
    remaining = torch.argsort(scores, descending=True)
    keep = []

    while remaining.numel() > 0:
        current = remaining[0]
        keep.append(current.item())
        rest = remaining[1:]
        if rest.numel() == 0:
            break

        # IoU between the current (highest-scoring) mask and all survivors.
        current_mask = masks_binary[current].unsqueeze(0)  # [1, N]
        rest_masks = masks_binary[rest]  # [K, N]
        intersection = (current_mask & rest_masks).float().sum(dim=1)
        union = (current_mask | rest_masks).float().sum(dim=1)
        iou = intersection / (union + 1e-6)

        # Only masks that overlap less than the threshold survive this round.
        remaining = rest[iou < nms_thresh]

    return torch.tensor(keep, device=masks_binary.device, dtype=torch.long)


def _empty_outputs(
    num_points: int,
    device: torch.device,
    score_dtype: torch.dtype,
    class_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Builds the (masks, scores, classes, indices) tuple for zero surviving masks."""
    return (
        torch.zeros((0, num_points), device=device, dtype=torch.bool),
        torch.zeros((0,), device=device, dtype=score_dtype),
        torch.zeros((0,), device=device, dtype=class_dtype),
        torch.zeros((0,), device=device, dtype=torch.long),
    )


def _resolve_dbscan_backend(backend: str) -> bool:
    """Returns True if the cuML (GPU) DBSCAN should be used for ``backend``."""
    if backend == "auto":
        return DBSCAN is not None
    if backend == "cuml":
        if DBSCAN is None:
            print("Warning: backend='cuml' requested but cuML not found. Falling back to CPU.")
            return False
        return True
    # "cpu" and any unrecognised value use scikit-learn.
    return False


def _should_skip_dbscan(num_points: int, min_samples: int, query_idx: int, tag: str) -> bool:
    """Prints a notice and returns True when a mask is too large or too small to cluster."""
    if num_points > 100000:
        print(
            f"DBSCAN ({tag}): Skipping mask {query_idx} due to large point cloud ({num_points} points > 100k)"
        )
        return True
    if num_points < min_samples:
        print(
            f"DBSCAN ({tag}): Skipping mask {query_idx} due to small point cloud ({num_points} points < {min_samples})"
        )
        return True
    return False


def _labels_to_torch(labels, device: torch.device) -> torch.Tensor:
    """Converts cuML ``fit_predict`` output (torch / DLPack / CUDA array / host array) to a torch tensor."""
    if isinstance(labels, torch.Tensor):
        return labels.to(device)
    if hasattr(labels, "to_dlpack"):
        from torch.utils.dlpack import from_dlpack

        return from_dlpack(labels.to_dlpack()).to(device)
    if hasattr(labels, "__cuda_array_interface__"):
        return torch.as_tensor(labels, device=device)
    # Host-side output (e.g. numpy when cuML's output_type is configured that way).
    return torch.as_tensor(np.asarray(labels), device=device)


def _sigmoid_mask_quality(mask_logits: torch.Tensor, masks_binary: torch.Tensor) -> torch.Tensor:
    """Hand-coded IoU proxy: mean sigmoid probability over each mask's foreground points."""
    binary = masks_binary.float()
    return (mask_logits.sigmoid() * binary).sum(1) / (binary.sum(1) + 1e-6)


def apply_dbscan_clustering(
    current_masks: torch.Tensor,
    point_coords: torch.Tensor,
    current_scores: torch.Tensor,
    current_classes: torch.Tensor,
    eps: float = 0.95,
    min_samples: int = 1,
    backend: str = "auto",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies DBSCAN to each mask to split spatially disconnected components.

    Each non-empty mask is clustered independently; every non-noise cluster
    (label != -1) becomes its own output mask that inherits the parent
    query's score and class. Empty masks and masks whose points are all
    labelled noise produce no output. Masks with more than 100k points or
    fewer than ``min_samples`` points, and masks whose clustering raises, are
    kept unchanged.

    Args:
        current_masks: [Q, N] boolean masks
        point_coords: [N, 3] point coordinates
        current_scores: [Q] scores
        current_classes: [Q] classes
        eps: DBSCAN eps parameter
        min_samples: DBSCAN min_samples parameter
        backend: "auto", "cuml", or "cpu"

    Returns:
        new_masks: [Q', N] expanded boolean masks
        new_scores: [Q'] expanded scores
        new_classes: [Q'] expanded classes
        new_indices: [Q'] indices mapping to original queries
    """
    device = current_masks.device
    num_queries = current_masks.shape[0]
    # Normalise to bool so kept originals and newly built splits share a dtype
    # (required by torch.stack below) and can be used as index masks.
    current_masks = current_masks.bool()
    use_cuml = _resolve_dbscan_backend(backend)

    new_masks_list = []
    # Indices into the input queries; scores/classes are gathered once at the end.
    new_indices_list = []

    if use_cuml:
        # --- cuML (GPU) path: one DBSCAN fit per query (DBSCAN is not batched). ---
        for i in range(num_queries):
            mask = current_masks[i]
            if not mask.any():
                continue

            points = point_coords[mask]
            if _should_skip_dbscan(points.shape[0], min_samples, i, "cuML"):
                new_masks_list.append(mask)
                new_indices_list.append(i)
                continue

            try:
                clusterer = DBSCAN(eps=eps, min_samples=min_samples)
                labels = _labels_to_torch(clusterer.fit_predict(points), device)
                # Global point indices of this mask, to map cluster-local indices back.
                mask_indices = torch.nonzero(mask, as_tuple=True)[0]
                splits = []
                for label in torch.unique(labels):
                    if label == -1:  # noise points are discarded
                        continue
                    new_mask = torch.zeros_like(mask)
                    new_mask[mask_indices[labels == label]] = True
                    splits.append(new_mask)
            except Exception as e:
                print(f"DBSCAN (cuML) Error Query {i}: {e}")
                # Fallback: keep only the original (no partially built splits).
                splits = [mask]

            new_masks_list.extend(splits)
            new_indices_list.extend([i] * len(splits))
    else:
        # --- CPU path (scikit-learn) ---
        try:
            from sklearn.cluster import DBSCAN as SklearnDBSCAN
        except ImportError:
            print("Scikit-learn not found. Returning original masks.")
            return (
                current_masks,
                current_scores,
                current_classes,
                torch.arange(num_queries, device=device),
            )

        masks_cpu = current_masks.detach().cpu().numpy()
        coords_cpu = point_coords.detach().cpu().numpy()

        for i in range(num_queries):
            mask = masks_cpu[i]
            if not mask.any():
                continue

            points = coords_cpu[mask]
            if _should_skip_dbscan(points.shape[0], min_samples, i, "CPU"):
                new_masks_list.append(current_masks[i])
                new_indices_list.append(i)
                continue

            try:
                start_time = time.time()
                clusterer = SklearnDBSCAN(eps=eps, min_samples=min_samples)
                labels = clusterer.fit_predict(points.astype(np.float32))
                db_time = time.time() - start_time
                unique_labels = np.unique(labels)
                print(
                    f"DBSCAN (CPU): Processing {points.shape[0]} points took {db_time:.4f} seconds, found {len(unique_labels)} clusters"
                )

                mask_indices_cpu = np.nonzero(mask)[0]
                splits = []
                for label in unique_labels:
                    if label == -1:  # noise points are discarded
                        continue
                    new_mask_cpu = np.zeros_like(mask)
                    new_mask_cpu[mask_indices_cpu[labels == label]] = True
                    splits.append(torch.from_numpy(new_mask_cpu).to(device, dtype=torch.bool))
            except Exception as e:
                print(f"DBSCAN (CPU) Error Query {i}: {e}")
                # Fallback: keep only the original (no partially built splits).
                splits = [current_masks[i]]

            new_masks_list.extend(splits)
            new_indices_list.extend([i] * len(splits))

    if not new_masks_list:
        return _empty_outputs(
            current_masks.shape[1], device, current_scores.dtype, current_classes.dtype
        )

    final_masks = torch.stack(new_masks_list)
    indices_tensor = torch.tensor(new_indices_list, device=device, dtype=torch.long)
    return (
        final_masks,
        current_scores[indices_tensor],
        current_classes[indices_tensor],
        indices_tensor,
    )


def apply_post_processing(
    pred_masks: torch.Tensor,
    pred_logits: torch.Tensor,
    mask_threshold: float = 0.0,
    point_coords: Optional[torch.Tensor] = None,
    pp_cfg: Optional[Dict] = None,
    pred_iou: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies configured post-processing filters.

    Pipeline order: min-point filter -> optional DBSCAN expansion ->
    objectness filter -> stability-score filter -> NMS. The score of each
    mask is ``objectness * mask_quality``.

    Args:
        pred_masks: [Q, N] mask logits
        pred_logits: [Q, 2] class logits (objectness is class 0)
        mask_threshold: Threshold for mask binarization (usually 0.0 for logits)
        point_coords: Optional [N, 3] point coordinates; required for DBSCAN.
        pp_cfg: Post-processing configuration dict with keys:
            - objectness_thresh: float (default 0.0, disabled)
            - min_mask_points: int (default 0, disabled)
            - use_stability_score: bool (default False)
            - stability_score_thresh: float (default 0.9)
            - stability_score_offset: float (default 1.0)
            - use_nms: bool (default False)
            - nms_thresh: float (default 0.7)
            - use_dbscan: bool (default False)
            - dbscan_eps: float (default 0.95)
            - dbscan_min_points: int (default 1)
            - dbscan_backend: str (default "auto")
        pred_iou: Optional [Q] learned IoU logits from SpaceFormer's IoU head.
            When provided, `sigmoid(pred_iou)` replaces the hand-coded
            `mask_quality = (sigmoid(masks) * binary).sum / binary.sum` proxy
            in the score = obj * quality formula. DBSCAN expansion copies the
            same scalar to every component of an expanded query.

    Returns:
        final_masks: [Q', N] final binary masks
        final_scores: [Q'] final scores
        final_classes: [Q'] final classes
        final_indices: [Q'] indices mapping to original queries
    """
    if pp_cfg is None:
        pp_cfg = {}

    device = pred_masks.device
    num_points = pred_masks.shape[1]
    masks_binary = pred_masks > mask_threshold

    # 0. Min point count filtering (cheap early rejection before DBSCAN etc.).
    keep = torch.arange(pred_masks.shape[0], device=device)
    min_points = pp_cfg.get("min_mask_points", 0)
    if min_points > 0:
        keep = keep[masks_binary.sum(1) >= min_points]
        if keep.numel() == 0:
            return _empty_outputs(num_points, device, pred_masks.dtype, torch.long)
        masks_binary = masks_binary[keep]
        pred_masks = pred_masks[keep]
        pred_logits = pred_logits[keep]
        if pred_iou is not None:
            pred_iou = pred_iou[keep]

    current_masks = masks_binary
    current_logits = pred_masks
    # Indices back into the ORIGINAL (unfiltered) query set.
    current_indices = keep.clone()

    obj_probs = pred_logits.softmax(dim=-1)[:, 0]
    # Learned IoU (P3-SAM-style head) if available, else the sigmoid-mean proxy.
    if pred_iou is not None:
        mask_quality = pred_iou.sigmoid()
    else:
        mask_quality = _sigmoid_mask_quality(pred_masks, masks_binary)
    scores = obj_probs * mask_quality
    classes = torch.zeros_like(scores, dtype=torch.long)  # class 0

    # 1. DBSCAN expansion: split each mask into spatially connected components.
    if pp_cfg.get("use_dbscan", False) and point_coords is not None:
        current_masks, scores, classes, dbscan_indices = apply_dbscan_clustering(
            current_masks,
            point_coords,
            scores,
            classes,
            eps=pp_cfg.get("dbscan_eps", 0.95),
            min_samples=pp_cfg.get("dbscan_min_points", 1),
            backend=pp_cfg.get("dbscan_backend", "auto"),
        )
        # dbscan_indices are relative to the (possibly filtered) current set.
        current_indices = keep[dbscan_indices]
        obj_probs = obj_probs[dbscan_indices]

        # Restrict each component's logits to its own points so the stability
        # score and quality proxy describe the component, not the parent mask.
        current_logits = torch.where(current_masks, current_logits[dbscan_indices], -100.0)

        # No per-component learned IoU exists, so the parent's value is copied.
        if pred_iou is not None:
            mask_quality = pred_iou[dbscan_indices].sigmoid()
        else:
            mask_quality = _sigmoid_mask_quality(current_logits, current_masks)
        scores = obj_probs * mask_quality

    # 2. Objectness filtering.
    keep = torch.arange(current_masks.shape[0], device=current_masks.device)
    obj_thresh = pp_cfg.get("objectness_thresh", 0.0)
    if obj_thresh > 0:
        keep = keep[obj_probs > obj_thresh]
        if keep.numel() == 0:
            return _empty_outputs(num_points, device, scores.dtype, classes.dtype)

    # 3. Stability score filtering.
    if pp_cfg.get("use_stability_score", False):
        stability = calculate_stability_score(
            current_logits[keep],
            mask_threshold,
            pp_cfg.get("stability_score_offset", 1.0),
        )
        keep = keep[stability >= pp_cfg.get("stability_score_thresh", 0.9)]
        if keep.numel() == 0:
            return _empty_outputs(num_points, device, scores.dtype, classes.dtype)

    # 4. NMS (indices returned by apply_nms are relative to `keep`).
    if pp_cfg.get("use_nms", False):
        keep_nms = apply_nms(current_masks[keep], scores[keep], pp_cfg.get("nms_thresh", 0.7))
        keep = keep[keep_nms]

    return current_masks[keep], scores[keep], classes[keep], current_indices[keep]