| |
| |
| """ |
| SAM2-style post-processing utilities for mask segmentation. |
| |
| This module provides shared post-processing functions used by both the |
| MaskLanguageLitModule (validation/testing) and the demo script. |
| """ |
|
|
| from typing import Tuple, Optional, Dict |
| import time |
| import numpy as np |
|
|
| import torch |
|
|
| try: |
| from cuml.cluster import DBSCAN |
| except ImportError: |
| DBSCAN = None |
|
|
|
|
def calculate_stability_score(
    masks: torch.Tensor,
    mask_threshold: float = 0.0,
    threshold_offset: float = 1.0,
) -> torch.Tensor:
    """
    Score how insensitive each mask is to shifting its binarization threshold.

    The score is the number of entries above ``mask_threshold + threshold_offset``
    divided by the number of entries above ``mask_threshold - threshold_offset``.
    For a non-negative offset the first set is contained in the second, so the
    ratio is their IoU. A value close to 1 means few logits fall inside the band
    (threshold - offset, threshold + offset], i.e. the mask boundary is sharp.

    Args:
        masks: [Q, N] mask logits; counts are taken over the last dimension.
        mask_threshold: Base threshold (usually 0.0 for logits).
        threshold_offset: Half-width of the band around ``mask_threshold``.

    Returns:
        [Q] stability score per mask. A mask with nothing above the lower
        threshold scores 0 because of the 1e-6 term in the denominator.
    """
    upper = mask_threshold + threshold_offset
    lower = mask_threshold - threshold_offset

    strict_count = (masks > upper).float().sum(-1)
    loose_count = (masks > lower).float().sum(-1)

    return strict_count / (loose_count + 1e-6)
|
|
|
|
def apply_nms(
    masks_binary: torch.Tensor,
    scores: torch.Tensor,
    nms_thresh: float = 0.7,
) -> torch.Tensor:
    """
    Greedy non-maximum suppression over masks using pairwise mask IoU.

    Candidates are visited in descending score order. Each surviving candidate
    is kept, and every remaining candidate whose IoU with it is >= ``nms_thresh``
    is discarded.

    Args:
        masks_binary: [Q, N] binary masks (booleans or 0/1 values; cast to bool).
        scores: [Q] mask scores used for ranking.
        nms_thresh: IoU threshold for suppression.

    Returns:
        1-D long tensor of kept query indices, in descending score order, on
        the masks' device. Empty when Q == 0.
    """
    masks_bool = masks_binary.bool()
    remaining = torch.argsort(scores, descending=True)
    kept = []

    while remaining.numel() > 0:
        best = remaining[0]
        kept.append(best.item())

        others = remaining[1:]
        if others.numel() == 0:
            break

        anchor = masks_bool[best].unsqueeze(0)
        candidates = masks_bool[others]
        overlap_area = (anchor & candidates).float().sum(dim=1)
        combined_area = (anchor | candidates).float().sum(dim=1)
        overlap = overlap_area / (combined_area + 1e-6)

        # Only candidates that overlap the kept mask less than the threshold survive.
        remaining = others[overlap < nms_thresh]

    return torch.tensor(kept, device=masks_bool.device, dtype=torch.long)
|
|
|
|
def _resolve_dbscan_backend(backend: str) -> bool:
    """Return True if the cuML DBSCAN should be used for ``backend``.

    "auto" uses cuML whenever it was importable, "cuml" requests it (falling
    back to CPU with a warning when unavailable), and "cpu" always uses
    scikit-learn. Unrecognised values fall back to CPU with a warning.
    """
    if backend == "cpu":
        return False
    if backend == "auto":
        return DBSCAN is not None
    if backend == "cuml":
        if DBSCAN is None:
            print("Warning: backend='cuml' requested but cuML not found. Falling back to CPU.")
            return False
        return True
    print(f"Warning: unknown DBSCAN backend {backend!r}. Falling back to CPU.")
    return False


def _cuml_labels_to_tensor(labels, device: torch.device) -> torch.Tensor:
    """Convert the labels returned by cuML ``fit_predict`` into a torch tensor on ``device``."""
    if hasattr(labels, "to_dlpack"):
        from torch.utils.dlpack import from_dlpack

        labels = from_dlpack(labels.to_dlpack())
    elif hasattr(labels, "__cuda_array_interface__"):
        labels = torch.as_tensor(labels, device=device)
    else:
        # Host-side output (e.g. a NumPy array); previously this crashed in torch.unique.
        labels = torch.as_tensor(np.asarray(labels))
    return labels.to(device)


def _split_mask_by_labels(mask: torch.Tensor, labels: torch.Tensor) -> list:
    """Split ``mask`` into one boolean mask per non-noise DBSCAN label.

    ``labels`` holds one label per True entry of ``mask``, in ``nonzero`` order
    (the same order boolean indexing uses). Label -1 marks noise points, which
    belong to no component. Components are returned in ascending label order.
    """
    labels = labels.to(mask.device)
    mask_indices = torch.nonzero(mask, as_tuple=True)[0]

    components = []
    for label in torch.unique(labels):
        if label == -1:
            continue
        component = torch.zeros(mask.shape, dtype=torch.bool, device=mask.device)
        component[mask_indices[labels == label]] = True
        components.append(component)
    return components


def apply_dbscan_clustering(
    current_masks: torch.Tensor,
    point_coords: torch.Tensor,
    current_scores: torch.Tensor,
    current_classes: torch.Tensor,
    eps: float = 0.95,
    min_samples: int = 1,
    backend: str = "auto",
    max_points: int = 100000,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies DBSCAN to each mask to split spatially disconnected components.

    Every non-empty mask is clustered on the coordinates of its points. Each
    non-noise cluster becomes its own output mask and inherits the score and
    class of the query it came from.

    Special cases:
      - Masks with no True entries are dropped.
      - Masks with more than ``max_points`` or fewer than ``min_samples`` points
        are passed through unchanged.
      - Masks whose points are all labelled noise (-1) are dropped.
      - If clustering a mask raises, the original mask is kept (and none of its
        partial components are emitted).
      - On the CPU backend, if scikit-learn is not installed, the inputs are
        returned unchanged with indices ``arange(Q)``.

    Args:
        current_masks: [Q, N] boolean masks
        point_coords: [N, 3] point coordinates
        current_scores: [Q] scores
        current_classes: [Q] classes
        eps: DBSCAN eps parameter
        min_samples: DBSCAN min_samples parameter
        backend: "auto", "cuml", or "cpu" (other values fall back to CPU)
        max_points: Masks with more points than this are not clustered
            (DBSCAN cost grows quickly with point count).

    Returns:
        new_masks: [Q', N] expanded boolean masks
        new_scores: [Q'] expanded scores
        new_classes: [Q'] expanded classes
        new_indices: [Q'] indices mapping to original queries
    """
    use_cuml = _resolve_dbscan_backend(backend)
    device = current_masks.device
    num_queries = current_masks.shape[0]

    if use_cuml:
        tag = "cuML"
        # cuML receives the torch tensors directly.
        masks_src, coords_src = current_masks, point_coords

        def cluster(points):
            labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(points)
            return _cuml_labels_to_tensor(labels, device)

    else:
        try:
            from sklearn.cluster import DBSCAN as SklearnDBSCAN
        except ImportError:
            print("Scikit-learn not found. Returning original masks.")
            return (
                current_masks,
                current_scores,
                current_classes,
                torch.arange(num_queries, device=device),
            )
        tag = "CPU"
        # Transfer to host once rather than once per query.
        masks_src = current_masks.detach().cpu().numpy()
        coords_src = point_coords.detach().cpu().numpy()

        def cluster(points):
            clusterer = SklearnDBSCAN(eps=eps, min_samples=min_samples)
            return torch.from_numpy(clusterer.fit_predict(points.astype(np.float32)))

    out_masks = []
    out_indices = []

    for i in range(num_queries):
        query_mask = masks_src[i]
        if not query_mask.any():
            continue

        points = coords_src[query_mask]
        num_pts = points.shape[0]

        if num_pts > max_points:
            print(
                f"DBSCAN ({tag}): Skipping mask {i} due to large point cloud ({num_pts} points > {max_points})"
            )
            components = [current_masks[i]]
        elif num_pts < min_samples:
            print(
                f"DBSCAN ({tag}): Skipping mask {i} due to small point cloud ({num_pts} points < {min_samples})"
            )
            components = [current_masks[i]]
        else:
            try:
                start_time = time.time()
                labels = cluster(points)
                db_time = time.time() - start_time
                # Build all components before touching the outputs, so a failure
                # cannot leave partial components plus the fallback mask behind.
                components = _split_mask_by_labels(current_masks[i], labels)
            except Exception as e:
                print(f"DBSCAN ({tag}) Error Query {i}: {e}")
                components = [current_masks[i]]
            else:
                print(
                    f"DBSCAN ({tag}): Processing {num_pts} points took {db_time:.4f} seconds, found {len(components)} clusters"
                )

        out_masks.extend(components)
        out_indices.extend([i] * len(components))

    if not out_masks:
        return (
            torch.zeros((0, current_masks.shape[1]), device=device, dtype=torch.bool),
            torch.zeros((0,), device=device, dtype=current_scores.dtype),
            torch.zeros((0,), device=device, dtype=current_classes.dtype),
            torch.zeros((0,), device=device, dtype=torch.long),
        )

    indices_tensor = torch.tensor(out_indices, device=device, dtype=torch.long)
    return (
        torch.stack(out_masks),
        current_scores[indices_tensor],
        current_classes[indices_tensor],
        indices_tensor,
    )
|
|
|
|
def _empty_post_processing_result(
    num_points: int,
    device: torch.device,
    score_dtype: torch.dtype,
    class_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Zero-query (masks, scores, classes, indices) tuple returned when every query is filtered out."""
    return (
        torch.zeros((0, num_points), device=device, dtype=torch.bool),
        torch.zeros((0,), device=device, dtype=score_dtype),
        torch.zeros((0,), device=device, dtype=class_dtype),
        torch.zeros((0,), device=device, dtype=torch.long),
    )


def _mean_sigmoid_in_mask(logits: torch.Tensor, binary: torch.Tensor) -> torch.Tensor:
    """Mean sigmoid probability over each mask's foreground entries ([Q, N] -> [Q])."""
    binary_f = binary.float()
    return (logits.sigmoid() * binary_f).sum(1) / (binary_f.sum(1) + 1e-6)


def apply_post_processing(
    pred_masks: torch.Tensor,
    pred_logits: torch.Tensor,
    mask_threshold: float = 0.0,
    point_coords: Optional[torch.Tensor] = None,
    pp_cfg: Optional[Dict] = None,
    pred_iou: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies configured post-processing filters.

    Pipeline, in order:
      1. Binarize ``pred_masks > mask_threshold`` and drop masks with fewer than
         ``min_mask_points`` foreground points.
      2. Score each query as objectness * mask quality.
      3. Optionally split masks into spatial components with DBSCAN. Each
         component's logits are set to -100 outside the component, and its
         score is recomputed.
      4. Optionally filter by objectness (strict ``>``) and by stability score
         (``>=``), then run mask NMS.

    Args:
        pred_masks: [Q, N] mask logits
        pred_logits: [Q, 2] class logits (objectness is class 0)
        mask_threshold: Threshold for mask binarization (usually 0.0 for logits)
        point_coords: [N, 3] point coordinates; required for DBSCAN (DBSCAN is
            skipped when this is None).
        pp_cfg: Post-processing configuration dict with keys:
            - objectness_thresh: float (default 0.0, disabled)
            - min_mask_points: int (default 0, disabled)
            - use_stability_score: bool (default False)
            - stability_score_thresh: float (default 0.9)
            - stability_score_offset: float (default 1.0)
            - use_nms: bool (default False)
            - nms_thresh: float (default 0.7)
            - use_dbscan: bool (default False)
            - dbscan_eps: float (default 0.95)
            - dbscan_min_points: int (default 1)
            - dbscan_backend: str (default "auto")
        pred_iou: Optional learned IoU logits from SpaceFormer's IoU head, one
            per query ([Q]; a [Q, 1] tensor is flattened). When provided,
            `sigmoid(pred_iou)` replaces the hand-coded
            `mask_quality = (sigmoid(masks) * binary).sum / binary.sum` proxy in
            the score = obj * quality formula. DBSCAN expansion copies the same
            scalar to every component of an expanded query.

    Returns:
        final_masks: [Q', N] final binary masks
        final_scores: [Q'] final scores
        final_classes: [Q'] final classes (all 0)
        final_indices: [Q'] indices mapping to original queries
    """
    if pp_cfg is None:
        pp_cfg = {}

    num_points = pred_masks.shape[1]
    masks_binary = pred_masks > mask_threshold

    # --- Stage 1: size filter on the raw binarized masks. ---
    keep = torch.arange(pred_masks.shape[0], device=pred_masks.device)
    min_points = pp_cfg.get("min_mask_points", 0)
    if min_points > 0:
        keep = keep[masks_binary.float().sum(1) >= min_points]
    if len(keep) == 0:
        return _empty_post_processing_result(
            num_points, pred_masks.device, pred_masks.dtype, torch.long
        )

    masks_binary = masks_binary[keep]
    pred_masks = pred_masks[keep]
    pred_logits = pred_logits[keep]
    if pred_iou is not None:
        # Flatten so a [Q, 1] IoU head output cannot broadcast obj * quality to [Q, Q].
        pred_iou = pred_iou.reshape(-1)[keep]

    current_masks = masks_binary
    current_logits = pred_masks
    # Maps every current row back to its index in the original query set.
    current_indices = keep.clone()

    # --- Stage 2: scoring. ---
    obj_probs = pred_logits.softmax(dim=-1)[:, 0]
    if pred_iou is not None:
        mask_quality = pred_iou.sigmoid()
    else:
        mask_quality = _mean_sigmoid_in_mask(pred_masks, masks_binary)
    scores = obj_probs * mask_quality
    classes = torch.zeros_like(scores, dtype=torch.long)

    # --- Stage 3: optional DBSCAN split into spatial components. ---
    if pp_cfg.get("use_dbscan", False) and point_coords is not None:
        current_masks, scores, classes, dbscan_indices = apply_dbscan_clustering(
            current_masks,
            point_coords,
            scores,
            classes,
            eps=pp_cfg.get("dbscan_eps", 0.95),
            min_samples=pp_cfg.get("dbscan_min_points", 1),
            backend=pp_cfg.get("dbscan_backend", "auto"),
        )
        # dbscan_indices index the size-filtered set; compose with `keep`.
        current_indices = keep[dbscan_indices]
        obj_probs = obj_probs[dbscan_indices]

        # Restrict each component's logits to the component itself so stability
        # and quality are computed on the component, not on the parent mask.
        current_logits = torch.where(current_masks, current_logits[dbscan_indices], -100.0)

        if pred_iou is not None:
            mask_quality = pred_iou[dbscan_indices].sigmoid()
        else:
            mask_quality = _mean_sigmoid_in_mask(current_logits, current_masks)
        scores = obj_probs * mask_quality

    # --- Stage 4: filters on the (possibly expanded) set. ---
    keep = torch.arange(current_masks.shape[0], device=current_masks.device)

    obj_thresh = pp_cfg.get("objectness_thresh", 0.0)
    if obj_thresh > 0:
        keep = keep[obj_probs[keep] > obj_thresh]
        if len(keep) == 0:
            return _empty_post_processing_result(
                num_points, pred_masks.device, scores.dtype, classes.dtype
            )

    if pp_cfg.get("use_stability_score", False):
        stability = calculate_stability_score(
            current_logits[keep],
            mask_threshold,
            pp_cfg.get("stability_score_offset", 1.0),
        )
        keep = keep[stability >= pp_cfg.get("stability_score_thresh", 0.9)]
        if len(keep) == 0:
            return _empty_post_processing_result(
                num_points, pred_masks.device, scores.dtype, classes.dtype
            )

    if pp_cfg.get("use_nms", False):
        keep_nms = apply_nms(current_masks[keep], scores[keep], pp_cfg.get("nms_thresh", 0.7))
        keep = keep[keep_nms]

    return current_masks[keep], scores[keep], classes[keep], current_indices[keep]
|
|