| import os |
| import warnings |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch_scatter |
| import torch_cluster |
|
|
| from pointcept.models.losses import build_criteria |
| from pointcept.models.utils.structure import Point |
| from pointcept.models.utils import offset2batch |
| from .builder import MODELS, build_model |
|
|
|
|
| def _segmentor_debug_enabled(): |
| return os.environ.get("POINTCEPT_DEBUG_SEGMENTOR", "").strip().lower() in ( |
| "1", |
| "true", |
| "yes", |
| ) |
|
|
|
|
def superpoint_semantic_mixed_per_point(superpoint_ids, segment):
    """Per-point mask: True iff the point's superpoint id spans multiple GT labels.

    Args:
        superpoint_ids: integer tensor holding one superpoint id per point. It
            is flattened, so any shape with N elements is accepted.
        segment: integer tensor holding one label per point, with the same
            number of elements as ``superpoint_ids``. Every distinct value
            counts as a label; no value (e.g. an ignore index) is filtered out.

    Returns:
        Bool tensor of shape ``(N,)``. Entry ``i`` is True when the superpoint
        that point ``i`` belongs to contains at least two distinct label
        values.
    """
    sp = superpoint_ids.long().view(-1)
    seg = segment.long().view(-1)
    if sp.numel() == 0:
        # Nothing to group; return an empty mask explicitly.
        return torch.zeros(0, dtype=torch.bool, device=sp.device)

    # inverse[i] is the rank of sp[i] among the sorted unique superpoint ids.
    _, inverse = torch.unique(sp, sorted=True, return_inverse=True)

    # Count distinct labels per superpoint on integers only. Casting labels to
    # float would merge distinct large ids (above 2**24) into one value.
    distinct_pairs = torch.unique(torch.stack([sp, seg], dim=1), dim=0)
    # Sorted unique over the first column yields the same id order as above,
    # so the counts line up with `inverse`.
    _, labels_per_superpoint = torch.unique(
        distinct_pairs[:, 0], sorted=True, return_counts=True
    )
    mixed = labels_per_superpoint > 1
    return mixed[inverse]
|
|
|
|
| def _get_superpoint_ids(input_dict, num_points): |
| for key in ("superpoint", "spt"): |
| value = input_dict.get(key) |
| if value is not None and value.numel() == num_points: |
| return value.long().view(-1) |
| return None |
|
|
|
|
| def _get_point_batch_ids(input_dict, num_points, device): |
| offset = input_dict.get("offset") |
| if offset is None: |
| return torch.zeros(num_points, device=device, dtype=torch.long) |
| batch = offset2batch(offset) |
| if batch.numel() != num_points: |
| return torch.zeros(num_points, device=device, dtype=torch.long) |
| return batch.to(device=device, dtype=torch.long) |
|
|
|
|
@MODELS.register_module()
class DefaultSegmentor(nn.Module):
    """Segmentor whose backbone maps the input dict directly to logits.

    Args:
        backbone: config passed to ``build_model``.
        criteria: config passed to ``build_criteria``.
    """

    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        """Run the backbone and, when labels are available, the criteria.

        Args:
            input_dict: batch dict handed to the backbone. If it has a
                ``"condition"`` entry, that entry is replaced in place by its
                first element before the backbone runs. ``"segment"`` holds the
                target and is required in training mode.

        Returns:
            dict: ``{"loss"}`` in training mode; ``{"loss", "seg_logits"}`` in
            eval mode when ``"segment"`` is present; otherwise
            ``{"seg_logits"}``.
        """
        if "condition" in input_dict:
            # Collapse the batched condition entry to its first element.
            input_dict["condition"] = input_dict["condition"][0]
        seg_logits = self.backbone(input_dict)

        # Pure inference: no labels, nothing to score.
        if not self.training and "segment" not in input_dict:
            return dict(seg_logits=seg_logits)

        target = input_dict["segment"]
        if _segmentor_debug_enabled():
            print(f"[DEBUG DefaultSegmentor] Batch size (from target): {target.shape[0]}")
            print(f"[DEBUG DefaultSegmentor] seg_logits shape: {seg_logits.shape}")
            print(f"[DEBUG DefaultSegmentor] target (input_dict['segment']) shape: {target.shape}")
        loss = self.criteria(seg_logits, target)

        if self.training:
            return dict(loss=loss)
        return dict(loss=loss, seg_logits=seg_logits)
|
|
|
|
@MODELS.register_module()
class DefaultSegmentorV2(nn.Module):
    """Segmentor with a linear head on backbone features and optional
    superpoint-based auxiliary losses.

    Args:
        num_classes: output width of the head; a value <= 0 replaces the head
            with ``nn.Identity``.
        backbone_out_channels: input width of the linear head.
        backbone: config passed to ``build_model``.
        criteria: config passed to ``build_criteria``.
        freeze_backbone: when True, ``requires_grad`` is cleared on all
            backbone parameters.
        superpoint_edge_aux_weight: weight of the edge auxiliary loss; values
            <= 0 disable it.
        superpoint_edge_boost: per-point weight given to points in mixed-label
            superpoints; the auxiliary term uses ``boost - 1``.
        superpoint_contrastive_weight: weight of the contrastive loss; values
            <= 0 disable it.
        superpoint_contrastive_temperature: similarity temperature, floored at
            1e-6 when used.
        superpoint_contrastive_max_samples: cap on the total number of sampled
            points per forward pass.
        superpoint_contrastive_max_points_per_superpoint: cap on points drawn
            from one superpoint.
        superpoint_contrastive_min_points_per_superpoint: minimum number of
            labelled points a superpoint needs to be sampled.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
        superpoint_edge_aux_weight=0.0,
        superpoint_edge_boost=2.0,
        superpoint_contrastive_weight=0.0,
        superpoint_contrastive_temperature=0.1,
        superpoint_contrastive_max_samples=2048,
        superpoint_contrastive_max_points_per_superpoint=4,
        superpoint_contrastive_min_points_per_superpoint=2,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        self.superpoint_edge_aux_weight = superpoint_edge_aux_weight
        self.superpoint_edge_boost = superpoint_edge_boost
        self.superpoint_contrastive_weight = superpoint_contrastive_weight
        self.superpoint_contrastive_temperature = superpoint_contrastive_temperature
        self.superpoint_contrastive_max_samples = int(superpoint_contrastive_max_samples)
        self.superpoint_contrastive_max_points_per_superpoint = int(
            superpoint_contrastive_max_points_per_superpoint
        )
        self.superpoint_contrastive_min_points_per_superpoint = int(
            superpoint_contrastive_min_points_per_superpoint
        )
        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def _maybe_superpoint_edge_aux_loss(self, seg_logits, target, input_dict):
        """Extra cross-entropy on points whose superpoint mixes several labels.

        Returns a zero scalar when the loss is disabled or when no superpoint
        ids matching the number of logits rows are available. The result is
        unweighted; the caller applies ``superpoint_edge_aux_weight``.
        """
        if self.superpoint_edge_aux_weight <= 0:
            return seg_logits.new_tensor(0.0)
        # Same lookup as the contrastive loss, so both accept the same inputs.
        sp = _get_superpoint_ids(input_dict, seg_logits.shape[0])
        if sp is None:
            return seg_logits.new_tensor(0.0)
        mixed = superpoint_semantic_mixed_per_point(sp, target)
        ce = F.cross_entropy(seg_logits, target, reduction="none", ignore_index=-1)
        valid = (target != -1).float()
        # Only the weight above 1 is applied here; the base term comes from the
        # main criteria.
        extra_weight = (self.superpoint_edge_boost - 1.0) * mixed.float()
        denom = valid.sum().clamp_min(1.0)
        return (ce * extra_weight * valid).sum() / denom

    def _maybe_superpoint_contrastive_loss(self, feat, target, input_dict):
        """Contrastive loss that pulls together features of one superpoint.

        Points labelled -1 are excluded. Superpoints are keyed by
        (batch index, superpoint id). Up to ``max_points_per_superpoint``
        points are drawn from each eligible superpoint until ``max_samples``
        is reached. Sampling is random in training mode and takes the leading
        entries otherwise. For each anchor, the positives are the other
        samples of its superpoint, and the softmax denominator runs over all
        other samples of the same batch element.

        Returns a zero scalar whenever the loss is disabled or too few samples
        or groups are available. The result is unweighted; the caller applies
        ``superpoint_contrastive_weight``.
        """
        max_samples = self.superpoint_contrastive_max_samples
        max_per_group = self.superpoint_contrastive_max_points_per_superpoint
        min_per_group = self.superpoint_contrastive_min_points_per_superpoint

        if self.superpoint_contrastive_weight <= 0:
            return feat.new_tensor(0.0)
        if max_samples < 2:
            return feat.new_tensor(0.0)
        if max_per_group < 2:
            return feat.new_tensor(0.0)

        num_points = feat.shape[0]
        sp = _get_superpoint_ids(input_dict, num_points)
        if sp is None:
            return feat.new_tensor(0.0)

        batch = _get_point_batch_ids(input_dict, num_points, feat.device)
        valid_idx = torch.nonzero(target != -1, as_tuple=False).flatten()
        if valid_idx.numel() < min_per_group * 2:
            return feat.new_tensor(0.0)

        valid_batch = batch[valid_idx]
        valid_sp = sp[valid_idx]
        group_keys = torch.stack([valid_batch, valid_sp], dim=1)
        _, inverse, counts = torch.unique(
            group_keys, dim=0, sorted=False, return_inverse=True, return_counts=True
        )
        candidate_groups = torch.nonzero(
            counts >= min_per_group,
            as_tuple=False,
        ).flatten()
        if candidate_groups.numel() < 2:
            return feat.new_tensor(0.0)

        # Lay points out group by group; idx_ptr[g]:idx_ptr[g + 1] is group g.
        order = torch.argsort(inverse)
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        ordered_idx = valid_idx[order]
        if self.training:
            candidate_groups = candidate_groups[
                torch.randperm(candidate_groups.numel(), device=feat.device)
            ]

        # Move the group boundaries to the host once. Calling .item() inside
        # the loop would cost a device round-trip per boundary lookup.
        idx_ptr_list = idx_ptr.tolist()

        selected = []
        budget = 0
        for group_id in candidate_groups.tolist():
            start = idx_ptr_list[group_id]
            end = idx_ptr_list[group_id + 1]
            group_point_idx = ordered_idx[start:end]
            remaining = max_samples - budget
            if remaining < min_per_group:
                break
            take = min(group_point_idx.numel(), max_per_group, remaining)
            if take < min_per_group:
                continue
            if self.training and group_point_idx.numel() > take:
                choice = torch.randperm(group_point_idx.numel(), device=feat.device)[:take]
                group_point_idx = group_point_idx[choice]
            else:
                group_point_idx = group_point_idx[:take]
            selected.append(group_point_idx)
            budget += group_point_idx.numel()
            if budget >= max_samples:
                break

        if len(selected) < 2:
            return feat.new_tensor(0.0)

        sample_idx = torch.cat(selected, dim=0)
        sample_batch = batch[sample_idx]
        sample_sp = sp[sample_idx]
        sample_group = torch.unique(
            torch.stack([sample_batch, sample_sp], dim=1),
            dim=0,
            sorted=False,
            return_inverse=True,
        )[1]

        # Cosine similarities scaled by the temperature.
        z = F.normalize(feat[sample_idx], p=2, dim=1)
        logits = torch.matmul(z, z.transpose(0, 1))
        logits = logits / max(self.superpoint_contrastive_temperature, 1e-6)
        # Subtract the row maximum (no gradient) before exponentiating.
        logits = logits - logits.max(dim=1, keepdim=True).values.detach()

        self_mask = torch.eye(sample_idx.numel(), device=feat.device, dtype=torch.bool)
        same_scene = sample_batch[:, None].eq(sample_batch[None, :])
        positive_mask = sample_group[:, None].eq(sample_group[None, :]) & (~self_mask)
        valid_pair_mask = same_scene & (~self_mask)
        positive_count = positive_mask.sum(dim=1)
        anchor_mask = positive_count > 0
        if not anchor_mask.any():
            return feat.new_tensor(0.0)

        exp_logits = torch.exp(logits) * valid_pair_mask.float()
        denom = exp_logits.sum(dim=1, keepdim=True).clamp_min(1e-12)
        valid_anchor_mask = anchor_mask & (valid_pair_mask.sum(dim=1) > 0)
        if not valid_anchor_mask.any():
            return feat.new_tensor(0.0)

        log_prob = logits - torch.log(denom)
        mean_log_prob_pos = (log_prob * positive_mask.float()).sum(dim=1) / positive_count.clamp_min(1)
        return -mean_log_prob_pos[valid_anchor_mask].mean()

    def _compute_losses(self, seg_logits, feat, input_dict):
        """Build the loss entries shared by the training and evaluation paths.

        Returns a dict with keys ``loss``, ``loss_seg``,
        ``loss_superpoint_edge`` and ``loss_superpoint_contrast``, in that
        order. Warns when every label equals -1.
        """
        target = input_dict["segment"]
        debug = _segmentor_debug_enabled()
        if debug:
            print(f"[DEBUG DefaultSegmentorV2] Batch size (from target): {target.shape[0]}")
            print(f"[DEBUG DefaultSegmentorV2] seg_logits shape: {seg_logits.shape}")
            print(f"[DEBUG DefaultSegmentorV2] target (input_dict['segment']) shape: {target.shape}")
        num_valid_points = (target != -1).sum().item()
        if debug:
            print(f"[DEBUG] Number of valid points (not ignored): {num_valid_points}")
        if num_valid_points == 0:
            # stacklevel=3 attributes the warning to the caller of forward(),
            # one frame above this helper's own caller.
            warnings.warn(
                "DefaultSegmentorV2: batch has zero valid labels (all ignore_index).",
                UserWarning,
                stacklevel=3,
            )
        loss_seg = self.criteria(seg_logits, target)
        loss_sp_edge = self.superpoint_edge_aux_weight * self._maybe_superpoint_edge_aux_loss(
            seg_logits, target, input_dict
        )
        loss_sp_contrast = self.superpoint_contrastive_weight * self._maybe_superpoint_contrastive_loss(
            feat, target, input_dict
        )
        return dict(
            loss=loss_seg + loss_sp_edge + loss_sp_contrast,
            loss_seg=loss_seg,
            loss_superpoint_edge=loss_sp_edge,
            loss_superpoint_contrast=loss_sp_contrast,
        )

    def forward(self, input_dict, return_point=False):
        """Run backbone and head, and compute losses when labels are available.

        Args:
            input_dict: batch dict wrapped into a ``Point``. ``"segment"`` holds
                the target and is required in training mode.
            return_point: when True, the final backbone output is added to the
                result under ``"point"``.

        Returns:
            dict: the loss entries in training mode; the loss entries plus
            ``"seg_logits"`` in eval mode when ``"segment"`` is present;
            otherwise only ``"seg_logits"``.
        """
        point = Point(input_dict)

        if _segmentor_debug_enabled():
            if hasattr(point, "offset") and point.offset is not None:
                batch_size = len(point.offset)
                print(f"[DEBUG DefaultSegmentorV2] Batch Size (from offset): {batch_size}")
                for i in range(batch_size):
                    start_idx = 0 if i == 0 else point.offset[i - 1].item()
                    end_idx = point.offset[i].item()
                    num_points = end_idx - start_idx
                    print(f"[DEBUG DefaultSegmentorV2] Sample {i}: Number of points = {num_points}")
                    if num_points <= 0:
                        print(
                            f"[ERROR] Sample {i} has {num_points} points! This will cause arange error."
                        )
            else:
                print(f"[DEBUG DefaultSegmentorV2] Point keys: {point.keys()}")
                if "coord" in point.keys():
                    print(f"[DEBUG DefaultSegmentorV2] Point cloud shape: {point.coord.shape}")

        point = self.backbone(point)

        if isinstance(point, Point):
            # Walk back through each recorded pooling step, appending the
            # pooled features to the parent's features via the inverse index.
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                parent = point.pop("pooling_parent")
                inverse = point.pop("pooling_inverse")
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat = point.feat
        else:
            # The backbone returned something other than a Point; it is used
            # directly as the feature input to the head.
            feat = point
        seg_logits = self.seg_head(feat)

        return_dict = dict()
        if return_point:
            return_dict["point"] = point

        if self.training or "segment" in input_dict.keys():
            return_dict.update(self._compute_losses(seg_logits, feat, input_dict))
        if not self.training:
            return_dict["seg_logits"] = seg_logits
        return return_dict
|
|
|
|
@MODELS.register_module()
class DINOEnhancedSegmentor(nn.Module):
    """Segmentor that appends nearest-neighbour DINO features to each point
    before the linear head.

    Args:
        num_classes: output width of the head; a value <= 0 replaces the head
            with ``nn.Identity``.
        backbone_out_channels: input width of the linear head. It presumably
            has to equal the width of the concatenated features; confirm
            against the configs.
        backbone: config passed to ``build_model``; None means no backbone, in
            which case only DINO features feed the head.
        criteria: config passed to ``build_criteria``.
        freeze_backbone: when True and a backbone exists, ``requires_grad`` is
            cleared on its parameters and it runs under ``torch.no_grad()``.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.backbone = build_model(backbone) if backbone is not None else None
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        if self.backbone is not None and self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def forward(self, input_dict, return_point=False):
        """Run the optional backbone, attach DINO features and classify.

        Args:
            input_dict: batch dict wrapped into a ``Point``. It must provide
                ``"dino_coord"``, ``"dino_feat"`` and ``"dino_offset"``.
                ``"segment"`` holds the target and is required in training
                mode.
            return_point: when True, the final ``Point`` is added to the result
                under ``"point"``.

        Returns:
            dict: ``{"loss"}`` in training mode; ``{"loss", "seg_logits"}`` in
            eval mode when ``"segment"`` is present; otherwise
            ``{"seg_logits"}``. ``"point"`` comes first when requested.
        """
        point = Point(input_dict)
        if self.backbone is not None:
            if self.freeze_backbone:
                with torch.no_grad():
                    point = self.backbone(point)
            else:
                point = self.backbone(point)

            # Collect the chain linked through "unpooling_parent", then merge
            # from the far end of the chain back to its head, so each level
            # receives the features accumulated beyond it.
            point_list = [point]
            while "unpooling_parent" in point_list[-1].keys():
                point_list.append(point_list[-1].pop("unpooling_parent"))
            for i in reversed(range(1, len(point_list))):
                point = point_list[i]
                parent = point_list[i - 1]
                assert "pooling_inverse" in point.keys()
                inverse = point.pooling_inverse
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = point_list[0]

            # Walk back through each recorded pooling step, appending the
            # pooled features to the parent's features.
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                parent = point.pop("pooling_parent")
                inverse = point.pooling_inverse
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat_parts = [point.feat]
        else:
            feat_parts = []

        dino_coord = input_dict["dino_coord"]
        dino_feat = input_dict["dino_feat"]
        dino_offset = input_dict["dino_offset"]
        # k=1 search of the DINO coordinates for every original point,
        # restricted by the batch vectors. Row 1 of the result is used as the
        # index into the DINO set (per torch_cluster.knn's x/y convention).
        idx = torch_cluster.knn(
            x=dino_coord,
            y=point.origin_coord,
            batch_x=offset2batch(dino_offset),
            batch_y=offset2batch(point.origin_offset),
            k=1,
        )[1]

        feat_parts.append(dino_feat[idx])
        # torch.cat matches the rest of this file and is available on PyTorch
        # versions that do not ship the torch.concatenate alias.
        feat = torch.cat(feat_parts, dim=-1)
        seg_logits = self.seg_head(feat)

        return_dict = dict()
        if return_point:
            return_dict["point"] = point

        if self.training or "segment" in input_dict.keys():
            target = input_dict["segment"]
            if _segmentor_debug_enabled():
                print(f"[DEBUG DINOEnhancedSegmentor] Batch size (from target): {target.shape[0]}")
                print(f"[DEBUG DINOEnhancedSegmentor] seg_logits shape: {seg_logits.shape}")
                print(f"[DEBUG DINOEnhancedSegmentor] target (input_dict['segment']) shape: {target.shape}")
            return_dict["loss"] = self.criteria(seg_logits, target)
        if not self.training:
            return_dict["seg_logits"] = seg_logits
        return return_dict
|
|
|
|
@MODELS.register_module()
class DefaultClassifier(nn.Module):
    """Classifier: backbone, per-sample mean pooling, then an MLP head.

    Args:
        backbone: config passed to ``build_model``.
        criteria: config passed to ``build_criteria``.
        num_classes: width of the final linear layer.
        backbone_embed_dim: input width of the head's first linear layer.
    """

    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim

        # Two hidden stages (256, then 128 units), each Linear -> BatchNorm
        # -> ReLU -> Dropout(0.5), followed by the output projection.
        layers = []
        in_dim = backbone_embed_dim
        for hidden_dim in (256, 128):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Dropout(p=0.5))
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, num_classes))
        self.cls_head = nn.Sequential(*layers)

    def forward(self, input_dict):
        """Classify each sample in the batch.

        Args:
            input_dict: batch dict wrapped into a ``Point``. ``"category"``
                holds the target and is required in training mode.

        Returns:
            dict: ``{"loss"}`` in training mode; ``{"loss", "cls_logits"}`` in
            eval mode when ``"category"`` is present; otherwise
            ``{"cls_logits"}``.
        """
        point = self.backbone(Point(input_dict))

        if isinstance(point, Point):
            # Average features over each segment delimited by the offsets. A
            # leading zero turns the offsets into CSR pointers.
            indptr = F.pad(point.offset, (1, 0))
            point.feat = torch_scatter.segment_csr(
                src=point.feat,
                indptr=indptr,
                reduce="mean",
            )
            feat = point.feat
        else:
            feat = point
        cls_logits = self.cls_head(feat)

        # Pure inference: no labels, nothing to score.
        if not self.training and "category" not in input_dict:
            return dict(cls_logits=cls_logits)

        loss = self.criteria(cls_logits, input_dict["category"])
        if self.training:
            return dict(loss=loss)
        return dict(loss=loss, cls_logits=cls_logits)
|
|