SpaCeFormer / demo /postprocessing.py
chrischoy's picture
Merge SpaceFormer demo (viser + CLI + Gradio) under demo/
a8e8155 verified
Raw
History Blame Contribute Delete
22.5 kB
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
SAM2-style post-processing utilities for mask segmentation.
This module provides shared post-processing functions used by both the
MaskLanguageLitModule (validation/testing) and the demo script.
"""
from typing import Tuple, Optional, Dict
import time
import numpy as np
import torch
try:
from cuml.cluster import DBSCAN
except ImportError:
DBSCAN = None
def calculate_stability_score(
    masks: torch.Tensor,
    mask_threshold: float = 0.0,
    threshold_offset: float = 1.0,
) -> torch.Tensor:
    """
    Score how sensitive each mask is to moving its binarization threshold.

    Each mask is binarized twice: once at the stricter threshold
    ``mask_threshold + threshold_offset`` and once at the looser threshold
    ``mask_threshold - threshold_offset``. The score is the ratio of the two
    foreground counts. The strict foreground is a subset of the loose one, so
    this equals their IoU. Values near 1 mean the logits change sharply at the
    mask boundary.

    Args:
        masks: [Q, N] mask logits
        mask_threshold: Base threshold (usually 0.0 for logits)
        threshold_offset: Distance of the strict/loose thresholds from the base

    Returns:
        [Q] stability score per mask (0 when the loose mask is empty)
    """
    upper = mask_threshold + threshold_offset
    lower = mask_threshold - threshold_offset
    # Foreground counts at the strict and loose thresholds.
    strict_count = (masks > upper).float().sum(-1)
    loose_count = (masks > lower).float().sum(-1)
    # The epsilon keeps empty masks from dividing by zero.
    return strict_count / (loose_count + 1e-6)
def apply_nms(
    masks_binary: torch.Tensor,
    scores: torch.Tensor,
    nms_thresh: float = 0.7,
) -> torch.Tensor:
    """
    Greedy non-maximum suppression over masks, using mask IoU as the overlap.

    Repeatedly takes the highest-scoring remaining mask, keeps it, and discards
    every remaining mask whose IoU with it is >= ``nms_thresh``.

    Args:
        masks_binary: [Q, N] binary masks (booleans or 0/1 floats)
        scores: [Q] mask scores used for ranking
        nms_thresh: IoU at or above which a lower-scored mask is suppressed

    Returns:
        1-D long tensor of kept row indices into ``masks_binary``, in
        descending score order.
    """
    masks_bool = masks_binary.bool()
    remaining = torch.argsort(scores, descending=True)
    kept = []
    while remaining.numel() > 0:
        best, others = remaining[0], remaining[1:]
        kept.append(int(best))
        if others.numel() == 0:
            break
        ref = masks_bool[best].unsqueeze(0)  # [1, N]
        candidates = masks_bool[others]  # [K, N]
        inter = (ref & candidates).float().sum(dim=1)
        uni = (ref | candidates).float().sum(dim=1)
        overlap = inter / (uni + 1e-6)
        # Survivors are the masks that do not overlap the kept one too much.
        remaining = others[overlap < nms_thresh]
    return torch.tensor(kept, device=masks_bool.device, dtype=torch.long)
def apply_dbscan_clustering(
current_masks: torch.Tensor,
point_coords: torch.Tensor,
current_scores: torch.Tensor,
current_classes: torch.Tensor,
eps: float = 0.95,
min_samples: int = 1,
backend: str = "auto",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies DBSCAN to each mask to split spatially disconnected components.
Args:
current_masks: [Q, N] boolean masks
point_coords: [N, 3] point coordinates
current_scores: [Q] scores
current_classes: [Q] classes
eps: DBSCAN eps parameter
min_samples: DBSCAN min_samples parameter
backend: "auto", "cuml", or "cpu"
Returns:
new_masks: [Q', N] expanded boolean masks
new_scores: [Q'] expanded scores
new_classes: [Q'] expanded classes
new_indices: [Q'] indices mapping to original queries
"""
# 0. Size check (Performance optimization) - REMOVED GLOBAL CHECK
# if point_coords.shape[0] > 100000:
# print(f"DBSCAN: Skipping due to large point cloud ({point_coords.shape[0]} points > 100k)")
# return current_masks, current_scores, current_classes
# 1. Determine Backend
use_cuml = False
if backend == "auto":
use_cuml = DBSCAN is not None
elif backend == "cuml":
if DBSCAN is None:
print("Warning: backend='cuml' requested but cuML not found. Falling back to CPU.")
use_cuml = False
else:
use_cuml = True
elif backend == "cpu":
use_cuml = False
device = current_masks.device
num_queries = current_masks.shape[0]
# Initialize lists to hold the new split masks
new_masks_list = []
# We'll store indices pointing to original scores/classes to avoid duplicating them early
new_indices_list = []
# 2. Execution Path
if use_cuml:
# --- cuML (GPU) Path ---
# print(f"DBSCAN (cuML): Processing {point_coords.shape[0]} points")
# Ensure data is on GPU and valid types
# cuML DBSCAN expects input of shape (n_samples, n_features)
# We process each mask independently.
# Optimization: To avoid loop overhead, we could try to batch, but DBSCAN isn't batched.
# We iterate over queries.
for i in range(num_queries):
mask = current_masks[i]
# Skip empty masks
if not mask.any():
continue
# Filter points for this mask
# mask is [N], point_coords is [N, 3]
# Slicing creates a new tensor on GPU
points = point_coords[mask]
# Check per-mask size limit
if points.shape[0] > 100000:
# Skip DBSCAN for this mask, keep original
print(
f"DBSCAN (cuML): Skipping mask {i} due to large point cloud ({points.shape[0]} points > 100k)"
)
new_masks_list.append(mask)
new_indices_list.append(i)
continue
if points.shape[0] < min_samples:
# Keep original
print(
f"DBSCAN (cuML): Skipping mask {i} due to small point cloud ({points.shape[0]} points < {min_samples})"
)
new_masks_list.append(mask)
new_indices_list.append(i)
continue
try:
# Run cuML DBSCAN
# dbscan = DBSCAN(eps=eps, min_samples=min_samples)
# labels = dbscan.fit_predict(points)
# fit_predict returns a cudf Series or cupy array depending on input?
# If input is torch tensor, cuML >= 23.04 supports __cuda_array_interface__
# It usually returns a cupy array or similar.
# Check if we need to convert to cupy explicitly if torch support is iffy in installed version
# But modern cuML supports torch tensors.
start_time = time.time()
clusterer = DBSCAN(eps=eps, min_samples=min_samples)
labels = clusterer.fit_predict(points)
db_time = time.time() - start_time
# Labels is likely a cupy array or similar on GPU
# Convert to torch for easier handling
if hasattr(labels, "to_dlpack"):
from torch.utils.dlpack import from_dlpack
labels = from_dlpack(labels.to_dlpack())
elif hasattr(labels, "__cuda_array_interface__"):
labels = torch.as_tensor(labels, device=device)
unique_labels = torch.unique(labels)
# Count valid clusters (excluding noise -1)
valid_clusters = unique_labels[unique_labels != -1]
if len(valid_clusters) == 0:
# All noise? Or just one noise cluster?
# If essentially no structure found, maybe keep original or drop?
# Standard behavior: if it was a mask, and now it's all noise...
# we probably shouldn't discard the *entire* mask content if it was a valid object.
# But DBSCAN says it's noise.
# Let's keep original if nothing valid found, similar to CPU path logic.
pass
found_cluster = False
# Reconstruct masks
# We need global indices of the points
mask_indices = torch.nonzero(mask, as_tuple=True)[0]
for label in valid_clusters:
found_cluster = True
# Create new boolean mask
# 1. Start with zeros
new_mask = torch.zeros_like(mask)
# 2. Get local indices where label matches
local_indices = (labels == label).nonzero(as_tuple=True)[0]
# 3. Map to global indices
global_indices = mask_indices[local_indices]
# 4. Set True
new_mask[global_indices] = True
new_masks_list.append(new_mask)
new_indices_list.append(i)
if not found_cluster:
# Treat as noise/failure to cluster, keep original?
if len(new_masks_list) == 0 or new_indices_list[-1] != i:
# If we haven't added anything for this query `i`
# (Logic check: strictly speaking we might have added splits from previous masks
# so checking new_indices_list[-1] is valid only if list not empty)
pass
except Exception as e:
print(f"DBSCAN (cuML) Error Query {i}: {e}")
# Fallback: keep original
new_masks_list.append(mask)
new_indices_list.append(i)
else:
# --- CPU Path ---
# print(f"DBSCAN (CPU): Processing {point_coords.shape[0]} points")
# Move inputs to CPU
masks_cpu = current_masks.detach().cpu().numpy()
coords_cpu = point_coords.detach().cpu().numpy()
try:
from sklearn.cluster import DBSCAN as SklearnDBSCAN
except ImportError:
print("Scikit-learn not found. Returning original masks.")
print("Scikit-learn not found. Returning original masks.")
return (
current_masks,
current_scores,
current_classes,
torch.arange(num_queries, device=device),
)
for i in range(num_queries):
mask = masks_cpu[i]
if not mask.any():
continue
points = coords_cpu[mask]
# Check per-mask size limit
if points.shape[0] > 100000:
# Skip DBSCAN for this mask, keep original
print(
f"DBSCAN (CPU): Skipping mask {i} due to large point cloud ({points.shape[0]} points > 100k)"
)
new_masks_list.append(current_masks[i])
new_indices_list.append(i)
continue
if points.shape[0] < min_samples:
# Keep original
print(
f"DBSCAN (CPU): Skipping mask {i} due to small point cloud ({points.shape[0]} points < {min_samples})"
)
new_masks_list.append(current_masks[i])
new_indices_list.append(i)
continue
try:
# Ensure float32 for sklearn
start_time = time.time()
clusterer = SklearnDBSCAN(eps=eps, min_samples=min_samples)
labels = clusterer.fit_predict(points.astype(np.float32))
db_time = time.time() - start_time
unique_labels = np.unique(labels)
print(
f"DBSCAN (CPU): Processing {points.shape[0]} points took {db_time:.4f} seconds, found {len(unique_labels)} clusters"
)
found_cluster = False
# We need indices to reconstruct mask on GPU/CPU
# Since we are returning torch tensors on `device`, let's construct list of tensors
# It is faster to construct on CPU then move or construct on GPU?
# Constructing on GPU inside loop might be slow due to kernel launches.
# Let's construct on GPU to match the list type of cuML path
mask_indices_cpu = np.nonzero(mask)[0]
for label in unique_labels:
if label == -1:
continue
found_cluster = True
# Construct new mask
# It's easier to create on CPU then convert
new_mask_cpu = np.zeros_like(mask) # bool/uint8
local_mask = labels == label
active_indices = mask_indices_cpu[local_mask]
new_mask_cpu[active_indices] = 1 # True
# Convert to tensor on device
new_masks_list.append(
torch.from_numpy(new_mask_cpu).to(device, dtype=torch.bool)
)
new_indices_list.append(i)
if not found_cluster:
# Keep original? Currently explicitly dropped in previous code pass?
# "if not found_cluster: # Treated as noise, currently dropped."
# But we should probably keep it if it was a valid object that just didn't cluster well?
# The original code did `pass`.
pass
except Exception as e:
print(f"DBSCAN (CPU) Error Query {i}: {e}")
new_masks_list.append(current_masks[i])
new_indices_list.append(i)
# 3. Assemble Results
if len(new_masks_list) == 0:
return (
torch.zeros((0, current_masks.shape[1]), device=device, dtype=torch.bool),
torch.zeros((0,), device=device, dtype=current_scores.dtype),
torch.zeros((0,), device=device, dtype=current_classes.dtype),
torch.zeros((0,), device=device, dtype=torch.long),
)
final_masks = torch.stack(new_masks_list)
# Gather scores and classes using indices
indices_tensor = torch.tensor(new_indices_list, device=device, dtype=torch.long)
final_scores = current_scores[indices_tensor]
final_classes = current_classes[indices_tensor]
return final_masks, final_scores, final_classes, indices_tensor
def _empty_post_processing_result(
    num_points: int,
    device: torch.device,
    score_dtype: torch.dtype,
    class_dtype: torch.dtype = torch.long,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the (masks, scores, classes, indices) tuple for "nothing survived"."""
    return (
        torch.zeros((0, num_points), device=device, dtype=torch.bool),
        torch.zeros((0,), device=device, dtype=score_dtype),
        torch.zeros((0,), device=device, dtype=class_dtype),
        torch.zeros((0,), device=device, dtype=torch.long),
    )


def _sigmoid_mean_quality(mask_logits: torch.Tensor, masks_binary: torch.Tensor) -> torch.Tensor:
    """Mean sigmoid probability inside each binary mask (0 for empty masks)."""
    binary = masks_binary.float()
    return (mask_logits.sigmoid() * binary).sum(1) / (binary.sum(1) + 1e-6)


def apply_post_processing(
    pred_masks: torch.Tensor,
    pred_logits: torch.Tensor,
    mask_threshold: float = 0.0,
    point_coords: Optional[torch.Tensor] = None,
    pp_cfg: Optional[Dict] = None,
    pred_iou: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies configured post-processing filters.

    Pipeline:
    1. Min-point filtering.
    2. Scoring: score = objectness * quality.
    3. Optional DBSCAN splitting.
    4. Objectness filtering.
    5. Stability-score filtering.
    6. NMS.

    Args:
        pred_masks: [Q, N] mask logits
        pred_logits: [Q, C] class logits (objectness is class 0)
        mask_threshold: Threshold for mask binarization (usually 0.0 for logits)
        point_coords: [N, 3] point coordinates, required for DBSCAN
        pp_cfg: Post-processing configuration dict with keys:
            - objectness_thresh: float (default 0.0, disabled)
            - min_mask_points: int (default 0, disabled)
            - use_stability_score: bool (default False)
            - stability_score_thresh: float (default 0.9)
            - stability_score_offset: float (default 1.0)
            - use_nms: bool (default False)
            - nms_thresh: float (default 0.7)
            - use_dbscan: bool (default False; needs `point_coords`)
            - dbscan_eps: float (default 0.95)
            - dbscan_min_points: int (default 1)
            - dbscan_backend: str (default "auto")
        pred_iou: Optional [Q] learned IoU logits from SpaceFormer's IoU head.
            When provided, `sigmoid(pred_iou)` replaces the hand-coded
            `mask_quality = (sigmoid(masks) * binary).sum / binary.sum` proxy
            in the score. DBSCAN expansion copies the same scalar to every
            component of an expanded query.

    Returns:
        final_masks: [Q', N] final binary masks
        final_scores: [Q'] final scores
        final_classes: [Q'] final classes (all 0)
        final_indices: [Q'] indices mapping to original queries
    """
    if pp_cfg is None:
        pp_cfg = {}
    device = pred_masks.device
    num_points = pred_masks.shape[1]
    masks_binary = pred_masks > mask_threshold

    # 0. Min point count filtering: drop small masks before expensive steps.
    keep = torch.arange(pred_masks.shape[0], device=device)
    min_points = pp_cfg.get("min_mask_points", 0)
    if min_points > 0:
        keep = keep[masks_binary.float().sum(1) >= min_points]
        if len(keep) == 0:
            return _empty_post_processing_result(num_points, device, pred_masks.dtype)
        masks_binary = masks_binary[keep]
        pred_masks = pred_masks[keep]
        pred_logits = pred_logits[keep]
        if pred_iou is not None:
            pred_iou = pred_iou[keep]

    current_masks = masks_binary
    current_logits = pred_masks
    # Maps each current row back to its original query index.
    current_indices = keep.clone()

    # Score = objectness * mask quality. Quality is the learned IoU head when
    # available, otherwise the sigmoid-mean proxy.
    obj_probs = pred_logits.softmax(dim=-1)[:, 0]
    if pred_iou is not None:
        mask_quality = pred_iou.sigmoid()
    else:
        mask_quality = _sigmoid_mean_quality(pred_masks, masks_binary)
    scores = obj_probs * mask_quality
    classes = torch.zeros_like(scores, dtype=torch.long)  # class 0

    # 1. DBSCAN expansion: split each mask into spatially connected components.
    if pp_cfg.get("use_dbscan", False):
        if point_coords is None:
            print("Warning: use_dbscan=True but no point_coords given; skipping DBSCAN.")
        else:
            current_masks, scores, classes, dbscan_indices = apply_dbscan_clustering(
                current_masks,
                point_coords,
                scores,
                classes,
                eps=pp_cfg.get("dbscan_eps", 0.95),
                min_samples=pp_cfg.get("dbscan_min_points", 1),
                backend=pp_cfg.get("dbscan_backend", "auto"),
            )
            # dbscan_indices are relative to the filtered set; map them back.
            current_indices = keep[dbscan_indices]
            obj_probs = obj_probs[dbscan_indices]
            # Restrict each component's logits to the component itself, so the
            # stability score and quality describe the component rather than
            # the whole parent mask.
            current_logits = torch.where(current_masks, current_logits[dbscan_indices], -100.0)
            if pred_iou is not None:
                # No per-component IoU exists; reuse the parent query's value.
                mask_quality = pred_iou[dbscan_indices].sigmoid()
            else:
                mask_quality = _sigmoid_mean_quality(current_logits, current_masks)
            scores = obj_probs * mask_quality

    # 2. Objectness filtering
    keep = torch.arange(current_masks.shape[0], device=current_masks.device)
    if pp_cfg.get("objectness_thresh", 0.0) > 0:
        keep = keep[(obj_probs > pp_cfg["objectness_thresh"])[keep]]
        if len(keep) == 0:
            return _empty_post_processing_result(num_points, device, scores.dtype, classes.dtype)

    # 3. Stability score
    if pp_cfg.get("use_stability_score", False):
        stability = calculate_stability_score(
            current_logits[keep],
            mask_threshold,
            pp_cfg.get("stability_score_offset", 1.0),
        )
        keep = keep[stability >= pp_cfg.get("stability_score_thresh", 0.9)]
        if len(keep) == 0:
            return _empty_post_processing_result(num_points, device, scores.dtype, classes.dtype)

    # 4. NMS
    if pp_cfg.get("use_nms", False):
        keep_nms = apply_nms(current_masks[keep], scores[keep], pp_cfg.get("nms_thresh", 0.7))
        keep = keep[keep_nms]

    return current_masks[keep], scores[keep], classes[keep], current_indices[keep]