import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_scatter
import torch_cluster

from pointcept.models.losses import build_criteria
from pointcept.models.utils.structure import Point
from pointcept.models.utils import offset2batch

from .builder import MODELS, build_model


def _segmentor_debug_enabled():
    """Return True when the POINTCEPT_DEBUG_SEGMENTOR env var is "1", "true" or "yes".

    The value is stripped and lower-cased before comparison. The environment is
    re-read on every call, so the flag can be toggled while the process runs.
    """
    return os.environ.get("POINTCEPT_DEBUG_SEGMENTOR", "").strip().lower() in (
        "1",
        "true",
        "yes",
    )


def _debug_print_logits_and_target(tag, seg_logits, target):
    """Print the three logits/target debug lines for class ``tag`` if debugging is on."""
    if not _segmentor_debug_enabled():
        return
    print(f"[DEBUG {tag}] Batch size (from target): {target.shape[0]}")
    print(f"[DEBUG {tag}] seg_logits shape: {seg_logits.shape}")
    print(f"[DEBUG {tag}] target (input_dict['segment']) shape: {target.shape}")


def superpoint_semantic_mixed_per_point(superpoint_ids, segment, batch=None):
    """Per-point mask: True iff the point's superpoint spans multiple GT labels.

    Args:
        superpoint_ids: tensor with one superpoint id per point (flattened here).
        segment: tensor with one label per point (flattened here). Every distinct
            value counts as a label, including an ignore value such as -1.
        batch: optional tensor with one scene index per point. When given, points
            are grouped by the pair (scene, superpoint id), so equal ids that
            occur in different scenes are treated as different superpoints.
            When omitted, points are grouped by superpoint id alone.

    Returns:
        Bool tensor with one entry per point.
    """
    sp = superpoint_ids.long().view(-1)
    seg = segment.long().view(-1)
    if batch is None:
        _, inverse, counts = torch.unique(
            sp, sorted=True, return_inverse=True, return_counts=True
        )
    else:
        keys = torch.stack([batch.long().view(-1).to(sp.device), sp], dim=1)
        _, inverse, counts = torch.unique(
            keys, dim=0, return_inverse=True, return_counts=True
        )
    # Sort points by group so that every group is one contiguous CSR segment.
    order = torch.argsort(inverse)
    idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
    seg_ord = seg[order].float()
    smin = torch_scatter.segment_csr(seg_ord, idx_ptr, reduce="min")
    smax = torch_scatter.segment_csr(seg_ord, idx_ptr, reduce="max")
    # A group holds more than one label exactly when its min and max differ.
    mixed = smin != smax
    return mixed[inverse]


def _get_superpoint_ids(input_dict, num_points):
    """Return flattened long superpoint ids from ``input_dict``, or None.

    The keys "superpoint" and "spt" are tried in that order; the first value
    that is not None and has exactly ``num_points`` elements is used.
    """
    for key in ("superpoint", "spt"):
        value = input_dict.get(key)
        if value is not None and value.numel() == num_points:
            return value.long().view(-1)
    return None


def _get_point_batch_ids(input_dict, num_points, device):
    """Return one long scene index per point, derived from ``input_dict["offset"]``.

    Falls back to all zeros (a single scene) when "offset" is missing or when
    the batch vector built from it does not have ``num_points`` entries.
    """
    offset = input_dict.get("offset")
    if offset is None:
        return torch.zeros(num_points, device=device, dtype=torch.long)
    batch = offset2batch(offset)
    if batch.numel() != num_points:
        return torch.zeros(num_points, device=device, dtype=torch.long)
    return batch.to(device=device, dtype=torch.long)


@MODELS.register_module()
class DefaultSegmentor(nn.Module):
    """Segmentor whose backbone maps the input dict directly to per-point logits."""

    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        """Run the backbone and return the loss and/or the logits.

        Returns:
            ``dict(loss=...)`` in training mode; ``dict(loss=..., seg_logits=...)``
            outside training when "segment" is in ``input_dict``; otherwise
            ``dict(seg_logits=...)``.
        """
        if "condition" in input_dict:
            # PPT (https://arxiv.org/abs/2308.09718)
            # currently, only support one batch one condition
            input_dict["condition"] = input_dict["condition"][0]
        seg_logits = self.backbone(input_dict)
        # test
        if not self.training and "segment" not in input_dict:
            return dict(seg_logits=seg_logits)
        target = input_dict["segment"]
        _debug_print_logits_and_target("DefaultSegmentor", seg_logits, target)
        loss = self.criteria(seg_logits, target)
        # train
        if self.training:
            return dict(loss=loss)
        # eval
        return dict(loss=loss, seg_logits=seg_logits)


@MODELS.register_module()
class DefaultSegmentorV2(nn.Module):
    """Backbone plus linear head, with optional superpoint auxiliary losses.

    Args:
        num_classes: output width of the linear head; ``nn.Identity`` is used
            as the head when this is not positive.
        backbone_out_channels: input width of the linear head.
        backbone: config passed to ``build_model``.
        criteria: config passed to ``build_criteria``.
        freeze_backbone: when True, ``requires_grad`` is switched off for all
            backbone parameters.
        superpoint_edge_aux_weight: weight of the edge auxiliary loss; the loss
            is skipped when this is not positive.
        superpoint_edge_boost: the extra cross-entropy on points of mixed
            superpoints is scaled by ``superpoint_edge_boost - 1``.
        superpoint_contrastive_weight: weight of the contrastive loss; the loss
            is skipped when this is not positive.
        superpoint_contrastive_temperature: similarity temperature (floored at
            1e-6 when used).
        superpoint_contrastive_max_samples: upper bound on the number of points
            sampled for one contrastive loss evaluation.
        superpoint_contrastive_max_points_per_superpoint: upper bound on the
            points sampled from one superpoint.
        superpoint_contrastive_min_points_per_superpoint: superpoints with fewer
            labelled points than this are not sampled.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
        superpoint_edge_aux_weight=0.0,
        superpoint_edge_boost=2.0,
        superpoint_contrastive_weight=0.0,
        superpoint_contrastive_temperature=0.1,
        superpoint_contrastive_max_samples=2048,
        superpoint_contrastive_max_points_per_superpoint=4,
        superpoint_contrastive_min_points_per_superpoint=2,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        self.superpoint_edge_aux_weight = superpoint_edge_aux_weight
        self.superpoint_edge_boost = superpoint_edge_boost
        self.superpoint_contrastive_weight = superpoint_contrastive_weight
        self.superpoint_contrastive_temperature = superpoint_contrastive_temperature
        self.superpoint_contrastive_max_samples = int(
            superpoint_contrastive_max_samples
        )
        self.superpoint_contrastive_max_points_per_superpoint = int(
            superpoint_contrastive_max_points_per_superpoint
        )
        self.superpoint_contrastive_min_points_per_superpoint = int(
            superpoint_contrastive_min_points_per_superpoint
        )
        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def _maybe_superpoint_edge_aux_loss(self, seg_logits, target, input_dict):
        """Extra cross-entropy on points whose superpoint holds several labels.

        Returns the unweighted loss (the caller multiplies by
        ``superpoint_edge_aux_weight``): the sum of
        ``(superpoint_edge_boost - 1) * CE`` over labelled points of mixed
        superpoints, divided by the number of labelled points. Returns a zero
        scalar when the weight is not positive or no usable superpoint ids are
        found in ``input_dict``.
        """
        if self.superpoint_edge_aux_weight <= 0:
            return seg_logits.new_tensor(0.0)
        num_points = seg_logits.shape[0]
        sp = _get_superpoint_ids(input_dict, num_points)
        if sp is None:
            return seg_logits.new_tensor(0.0)
        # Group by (scene, superpoint id), matching the contrastive loss, so ids
        # that repeat across scenes of one batch are not merged into one group.
        batch = _get_point_batch_ids(input_dict, num_points, seg_logits.device)
        mixed = superpoint_semantic_mixed_per_point(sp, target, batch=batch)
        ce = F.cross_entropy(seg_logits, target, reduction="none", ignore_index=-1)
        valid = (target != -1).float()
        extra_weight = (self.superpoint_edge_boost - 1.0) * mixed.float()
        denom = valid.sum().clamp_min(1.0)
        return (ce * extra_weight * valid).sum() / denom

    def _maybe_superpoint_contrastive_loss(self, feat, target, input_dict):
        """Contrastive loss that pulls together features of one superpoint.

        Labelled points (``target != -1``) are grouped by (scene, superpoint id).
        Groups with at least ``min_points_per_superpoint`` points are visited
        (in random order when training) and up to ``max_points_per_superpoint``
        points are drawn from each until ``max_samples`` is reached. For every
        sampled point that has at least one other sample from its own group,
        the loss is the negative mean log-probability of those positives under a
        softmax over all other sampled points of the same scene.

        Returns the unweighted loss (the caller multiplies by
        ``superpoint_contrastive_weight``), or a zero scalar whenever the loss
        is disabled or there is not enough data to form it.
        """
        zero = feat.new_tensor(0.0)
        max_samples = self.superpoint_contrastive_max_samples
        max_per_sp = self.superpoint_contrastive_max_points_per_superpoint
        min_per_sp = self.superpoint_contrastive_min_points_per_superpoint
        if self.superpoint_contrastive_weight <= 0:
            return zero
        if max_samples < 2 or max_per_sp < 2:
            return zero

        num_points = feat.shape[0]
        sp = _get_superpoint_ids(input_dict, num_points)
        if sp is None:
            return zero
        batch = _get_point_batch_ids(input_dict, num_points, feat.device)

        valid_idx = torch.nonzero(target != -1, as_tuple=False).flatten()
        if valid_idx.numel() < min_per_sp * 2:
            return zero

        group_keys = torch.stack([batch[valid_idx], sp[valid_idx]], dim=1)
        _, inverse, counts = torch.unique(
            group_keys, dim=0, sorted=False, return_inverse=True, return_counts=True
        )
        candidate_groups = torch.nonzero(
            counts >= min_per_sp, as_tuple=False
        ).flatten()
        if candidate_groups.numel() < 2:
            return zero

        # Points sorted by group id; group g occupies bounds[g]:bounds[g + 1].
        ordered_idx = valid_idx[torch.argsort(inverse)]
        # One host transfer for all boundaries instead of two .item() syncs
        # per visited group.
        bounds = torch.cat(
            [counts.new_zeros(1), torch.cumsum(counts, dim=0)]
        ).tolist()

        if self.training:
            candidate_groups = candidate_groups[
                torch.randperm(candidate_groups.numel(), device=feat.device)
            ]

        selected = []
        budget = 0
        for group_id in candidate_groups.tolist():
            remaining = max_samples - budget
            if remaining < min_per_sp:
                break
            group_point_idx = ordered_idx[bounds[group_id]:bounds[group_id + 1]]
            take = min(group_point_idx.numel(), max_per_sp, remaining)
            if take < min_per_sp:
                continue
            if self.training and group_point_idx.numel() > take:
                choice = torch.randperm(
                    group_point_idx.numel(), device=feat.device
                )[:take]
                group_point_idx = group_point_idx[choice]
            else:
                group_point_idx = group_point_idx[:take]
            selected.append(group_point_idx)
            budget += take
            if budget >= max_samples:
                break
        if len(selected) < 2:
            return zero

        sample_idx = torch.cat(selected, dim=0)
        sample_batch = batch[sample_idx]
        sample_group = torch.unique(
            torch.stack([sample_batch, sp[sample_idx]], dim=1),
            dim=0,
            sorted=False,
            return_inverse=True,
        )[1]

        z = F.normalize(feat[sample_idx], p=2, dim=1)
        logits = torch.matmul(z, z.transpose(0, 1))
        if logits.dtype in (torch.float16, torch.bfloat16):
            # Do the log-softmax in fp32; half precision easily under/overflows.
            logits = logits.float()
        logits = logits / max(self.superpoint_contrastive_temperature, 1e-6)

        not_self = ~torch.eye(
            sample_idx.numel(), device=feat.device, dtype=torch.bool
        )
        valid_pair_mask = sample_batch[:, None].eq(sample_batch[None, :]) & not_self
        positive_mask = sample_group[:, None].eq(sample_group[None, :]) & not_self
        positive_count = positive_mask.sum(dim=1)
        anchor_mask = (positive_count > 0) & valid_pair_mask.any(dim=1)
        if not anchor_mask.any():
            return zero

        # Keep anchor rows only: each has at least one valid pair, so the masked
        # logsumexp below is finite and no NaN can leak into the gradients.
        anchor_logits = logits[anchor_mask]
        anchor_valid = valid_pair_mask[anchor_mask]
        anchor_pos = positive_mask[anchor_mask].to(anchor_logits.dtype)
        # logsumexp shifts by the max over the pairs that really enter the
        # denominator (self and cross-scene pairs are excluded), which is what
        # keeps the softmax denominator from underflowing.
        log_denom = torch.logsumexp(
            anchor_logits.masked_fill(~anchor_valid, float("-inf")),
            dim=1,
            keepdim=True,
        )
        log_prob = anchor_logits - log_denom
        mean_log_prob_pos = (log_prob * anchor_pos).sum(dim=1) / positive_count[
            anchor_mask
        ].to(anchor_logits.dtype)
        return -mean_log_prob_pos.mean()

    def _compute_losses(self, seg_logits, feat, input_dict):
        """Return the loss entries shared by the train and eval branches of forward.

        Keys, in insertion order: "loss", "loss_seg", "loss_superpoint_edge",
        "loss_superpoint_contrast". The two auxiliary entries are already
        multiplied by their weights.
        """
        target = input_dict["segment"]
        _debug_print_logits_and_target("DefaultSegmentorV2", seg_logits, target)
        num_valid_points = (target != -1).sum().item()
        if _segmentor_debug_enabled():
            print(f"[DEBUG] Number of valid points (not ignored): {num_valid_points}")
        if num_valid_points == 0:
            # stacklevel=3 attributes the warning to the caller of forward().
            warnings.warn(
                "DefaultSegmentorV2: batch has zero valid labels (all ignore_index).",
                UserWarning,
                stacklevel=3,
            )
        loss_seg = self.criteria(seg_logits, target)
        loss_sp_edge = (
            self.superpoint_edge_aux_weight
            * self._maybe_superpoint_edge_aux_loss(seg_logits, target, input_dict)
        )
        loss_sp_contrast = (
            self.superpoint_contrastive_weight
            * self._maybe_superpoint_contrastive_loss(feat, target, input_dict)
        )
        return dict(
            loss=loss_seg + loss_sp_edge + loss_sp_contrast,
            loss_seg=loss_seg,
            loss_superpoint_edge=loss_sp_edge,
            loss_superpoint_contrast=loss_sp_contrast,
        )

    def forward(self, input_dict, return_point=False):
        """Run backbone and head; return losses and/or logits.

        Args:
            input_dict: batch dict wrapped in ``Point`` for the backbone;
                "segment" is read as the target when present or when training.
            return_point: when True, the final ``point`` is added under "point".

        Returns:
            Dict with the loss entries in training mode, the loss entries plus
            "seg_logits" outside training when "segment" is given, and only
            "seg_logits" otherwise.
        """
        point = Point(input_dict)
        if _segmentor_debug_enabled():
            if hasattr(point, "offset") and point.offset is not None:
                batch_size = len(point.offset)
                print(f"[DEBUG DefaultSegmentorV2] Batch Size (from offset): {batch_size}")
                for i in range(batch_size):
                    start_idx = 0 if i == 0 else point.offset[i - 1].item()
                    end_idx = point.offset[i].item()
                    num_points = end_idx - start_idx
                    print(f"[DEBUG DefaultSegmentorV2] Sample {i}: Number of points = {num_points}")
                    if num_points <= 0:
                        print(
                            f"[ERROR] Sample {i} has {num_points} points! This will cause arange error."
                        )
            else:
                print(f"[DEBUG DefaultSegmentorV2] Point keys: {point.keys()}")
            if "coord" in point.keys():
                print(f"[DEBUG DefaultSegmentorV2] Point cloud shape: {point.coord.shape}")
        point = self.backbone(point)
        # Backbone added after v1.5.0 return Point instead of feat and use DefaultSegmentorV2
        # TODO: remove this part after make all backbone return Point only.
        if isinstance(point, Point):
            # Walk back up the pooling chain, concatenating each level's
            # features onto its parent's features.
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                parent = point.pop("pooling_parent")
                inverse = point.pop("pooling_inverse")
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat = point.feat
        else:
            feat = point
        seg_logits = self.seg_head(feat)
        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point
        # train / eval
        if self.training or "segment" in input_dict:
            return_dict.update(self._compute_losses(seg_logits, feat, input_dict))
        # eval / test
        if not self.training:
            return_dict["seg_logits"] = seg_logits
        return return_dict


@MODELS.register_module()
class DINOEnhancedSegmentor(nn.Module):
    """Linear head over backbone features concatenated with nearest DINO features.

    The backbone is optional; without it the head sees the DINO features only.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.backbone = build_model(backbone) if backbone is not None else None
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        if self.backbone is not None and self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def forward(self, input_dict, return_point=False):
        """Return the loss and/or logits for one batch.

        Reads "dino_coord", "dino_feat" and "dino_offset" from ``input_dict``
        and ``origin_coord`` / ``origin_offset`` from the point structure.
        Each point receives the feature of its single nearest DINO coordinate
        (``k=1``), searched with the per-sample batch vectors built from the
        two offsets.
        """
        point = Point(input_dict)
        if self.backbone is not None:
            if self.freeze_backbone:
                with torch.no_grad():
                    point = self.backbone(point)
            else:
                point = self.backbone(point)
            # Collect the chain linked through "unpooling_parent", then merge
            # features from the end of the chain back towards its head.
            point_list = [point]
            while "unpooling_parent" in point_list[-1].keys():
                point_list.append(point_list[-1].pop("unpooling_parent"))
            for i in reversed(range(1, len(point_list))):
                point = point_list[i]
                parent = point_list[i - 1]
                assert "pooling_inverse" in point.keys()
                inverse = point.pooling_inverse
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = point_list[0]
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                parent = point.pop("pooling_parent")
                inverse = point.pooling_inverse
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat = [point.feat]
        else:
            feat = []
        dino_coord = input_dict["dino_coord"]
        dino_feat = input_dict["dino_feat"]
        dino_offset = input_dict["dino_offset"]
        # Row 1 of the knn result indexes into the DINO set (x), one per query.
        idx = torch_cluster.knn(
            x=dino_coord,
            y=point.origin_coord,
            batch_x=offset2batch(dino_offset),
            batch_y=offset2batch(point.origin_offset),
            k=1,
        )[1]
        feat.append(dino_feat[idx])
        feat = torch.concatenate(feat, dim=-1)
        seg_logits = self.seg_head(feat)
        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point
        # train / eval
        if self.training or "segment" in input_dict:
            target = input_dict["segment"]
            _debug_print_logits_and_target("DINOEnhancedSegmentor", seg_logits, target)
            return_dict["loss"] = self.criteria(seg_logits, target)
        # eval / test
        if not self.training:
            return_dict["seg_logits"] = seg_logits
        return return_dict


@MODELS.register_module()
class DefaultClassifier(nn.Module):
    """Backbone followed by a three-layer MLP classification head."""

    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim
        self.cls_head = nn.Sequential(
            nn.Linear(backbone_embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, input_dict):
        """Return the loss and/or class logits.

        Returns:
            ``dict(loss=...)`` in training mode; ``dict(loss=..., cls_logits=...)``
            outside training when "category" is in ``input_dict``; otherwise
            ``dict(cls_logits=...)``.
        """
        point = Point(input_dict)
        point = self.backbone(point)
        # Backbone added after v1.5.0 return Point instead of feat
        # And after v1.5.0 feature aggregation for classification operated in classifier
        # TODO: remove this part after make all backbone return Point only.
        if isinstance(point, Point):
            # Mean-pool the features of each segment delimited by the offsets
            # (a leading 0 is prepended to form the CSR index pointer).
            point.feat = torch_scatter.segment_csr(
                src=point.feat,
                indptr=nn.functional.pad(point.offset, (1, 0)),
                reduce="mean",
            )
            feat = point.feat
        else:
            feat = point
        cls_logits = self.cls_head(feat)
        if self.training:
            loss = self.criteria(cls_logits, input_dict["category"])
            return dict(loss=loss)
        if "category" in input_dict:
            loss = self.criteria(cls_logits, input_dict["category"])
            return dict(loss=loss, cls_logits=cls_logits)
        return dict(cls_logits=cls_logits)