YYYYYYUUU's picture
Upload folder using huggingface_hub
bc6b9b1 verified
Raw
History Blame Contribute Delete
21.3 kB
import os
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_scatter
import torch_cluster
from pointcept.models.losses import build_criteria
from pointcept.models.utils.structure import Point
from pointcept.models.utils import offset2batch
from .builder import MODELS, build_model
def _segmentor_debug_enabled():
return os.environ.get("POINTCEPT_DEBUG_SEGMENTOR", "").strip().lower() in (
"1",
"true",
"yes",
)
def superpoint_semantic_mixed_per_point(superpoint_ids, segment):
    """Per-point mask: True iff the point's superpoint id spans multiple GT labels.

    Points are grouped by the value in ``superpoint_ids``. A group is "mixed"
    when the smallest and the largest label found in it differ; every point
    inherits the flag of its group.

    NOTE: labels are compared as raw integers, so a sentinel such as ``-1``
    counts as a label value of its own here.

    Args:
        superpoint_ids: integer tensor with one superpoint id per point
            (any shape; it is flattened).
        segment: integer tensor with one label per point (any shape; it is
            flattened). Must hold as many elements as ``superpoint_ids``.

    Returns:
        1-D ``torch.bool`` tensor with one entry per point.

    Raises:
        ValueError: if the two inputs do not hold the same number of elements.
    """
    # reshape (rather than view) also accepts non-contiguous inputs.
    sp = superpoint_ids.long().reshape(-1)
    seg = segment.long().reshape(-1)
    if sp.numel() != seg.numel():
        raise ValueError(
            "superpoint_ids and segment must have the same number of elements, "
            f"got {sp.numel()} and {seg.numel()}"
        )

    # `inverse` maps every point to a dense group index in [0, num_groups).
    unique_ids, inverse = torch.unique(sp, return_inverse=True)
    num_groups = unique_ids.numel()

    # Reduce on the integer labels directly: casting to float32 would merge
    # distinct ids above 2**24. `include_self=False` makes the initial zeros
    # irrelevant; every group receives at least one value because the group
    # indices come from `torch.unique` of the very same points.
    label_min = seg.new_zeros(num_groups).scatter_reduce(
        0, inverse, seg, reduce="amin", include_self=False
    )
    label_max = seg.new_zeros(num_groups).scatter_reduce(
        0, inverse, seg, reduce="amax", include_self=False
    )
    mixed = label_min != label_max
    return mixed[inverse]
def _get_superpoint_ids(input_dict, num_points):
for key in ("superpoint", "spt"):
value = input_dict.get(key)
if value is not None and value.numel() == num_points:
return value.long().view(-1)
return None
def _get_point_batch_ids(input_dict, num_points, device):
offset = input_dict.get("offset")
if offset is None:
return torch.zeros(num_points, device=device, dtype=torch.long)
batch = offset2batch(offset)
if batch.numel() != num_points:
return torch.zeros(num_points, device=device, dtype=torch.long)
return batch.to(device=device, dtype=torch.long)
@MODELS.register_module()
class DefaultSegmentor(nn.Module):
    """Segmentor whose backbone directly produces the segmentation logits.

    Args:
        backbone: config handed to ``build_model``.
        criteria: config handed to ``build_criteria``.
    """

    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        """Run the backbone and, when labels are available, the criteria.

        Returns:
            * training: ``{"loss": ...}``
            * eval with ``"segment"`` in ``input_dict``:
              ``{"loss": ..., "seg_logits": ...}``
            * otherwise (test): ``{"seg_logits": ...}``
        """
        if "condition" in input_dict:
            # PPT (https://arxiv.org/abs/2308.09718)
            # currently, only support one batch one condition
            # (note: this overwrites the entry in the caller's dict)
            input_dict["condition"] = input_dict["condition"][0]
        seg_logits = self.backbone(input_dict)

        # test: not training and no ground truth -> logits only
        if not self.training and "segment" not in input_dict:
            return {"seg_logits": seg_logits}

        # train / eval share the same debug output and loss computation
        if _segmentor_debug_enabled():
            target_shape = input_dict["segment"].shape
            print(f"[DEBUG DefaultSegmentor] Batch size (from target): {target_shape[0]}")
            print(f"[DEBUG DefaultSegmentor] seg_logits shape: {seg_logits.shape}")
            print(f"[DEBUG DefaultSegmentor] target (input_dict['segment']) shape: {target_shape}")
        loss = self.criteria(seg_logits, input_dict["segment"])

        if self.training:
            return {"loss": loss}
        return {"loss": loss, "seg_logits": seg_logits}
@MODELS.register_module()
class DefaultSegmentorV2(nn.Module):
    """Backbone + linear segmentation head with optional superpoint auxiliary losses.

    The backbone output (a ``Point`` or a raw feature tensor) is mapped to
    per-point logits by ``seg_head``. On top of the configured criteria two
    optional terms can be enabled through their weights:

    * an "edge" term that adds extra cross-entropy on points whose superpoint
      contains more than one label value, and
    * a contrastive term computed on features of points sampled from
      superpoints, where points of the same superpoint are positives and the
      remaining sampled points of the same scene form the denominator.

    Both terms evaluate to zero when their weight is ``<= 0`` or when
    ``input_dict`` carries no usable ``"superpoint"`` / ``"spt"`` entry.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
        superpoint_edge_aux_weight=0.0,
        superpoint_edge_boost=2.0,
        superpoint_contrastive_weight=0.0,
        superpoint_contrastive_temperature=0.1,
        superpoint_contrastive_max_samples=2048,
        superpoint_contrastive_max_points_per_superpoint=4,
        superpoint_contrastive_min_points_per_superpoint=2,
    ):
        """Build head, backbone and criteria and store the auxiliary-loss settings.

        Args:
            num_classes: output size of the linear head; ``<= 0`` replaces the
                head with ``nn.Identity``.
            backbone_out_channels: input size of the linear head.
            backbone: config handed to ``build_model``.
            criteria: config handed to ``build_criteria``.
            freeze_backbone: if true, ``requires_grad`` is cleared on all
                backbone parameters.
            superpoint_edge_aux_weight: weight of the edge term (``<= 0`` disables it).
            superpoint_edge_boost: per-point CE weight used for mixed
                superpoints; only the part above 1 enters the edge term.
            superpoint_contrastive_weight: weight of the contrastive term
                (``<= 0`` disables it).
            superpoint_contrastive_temperature: softmax temperature (floored at 1e-6).
            superpoint_contrastive_max_samples: cap on the number of sampled points.
            superpoint_contrastive_max_points_per_superpoint: cap per superpoint.
            superpoint_contrastive_min_points_per_superpoint: superpoints with
                fewer valid points are not sampled.
        """
        super().__init__()
        # Creation order (head, backbone, criteria) is kept as-is so module
        # registration order stays unchanged.
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        self.superpoint_edge_aux_weight = superpoint_edge_aux_weight
        self.superpoint_edge_boost = superpoint_edge_boost
        self.superpoint_contrastive_weight = superpoint_contrastive_weight
        self.superpoint_contrastive_temperature = superpoint_contrastive_temperature
        self.superpoint_contrastive_max_samples = int(superpoint_contrastive_max_samples)
        self.superpoint_contrastive_max_points_per_superpoint = int(
            superpoint_contrastive_max_points_per_superpoint
        )
        self.superpoint_contrastive_min_points_per_superpoint = int(
            superpoint_contrastive_min_points_per_superpoint
        )
        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def _maybe_superpoint_edge_aux_loss(self, seg_logits, target, input_dict):
        """Extra cross-entropy on points that belong to label-mixed superpoints.

        Superpoints are identified by the pair (batch index, superpoint id),
        the same grouping the contrastive loss uses, so equal ids coming from
        different scenes of one batch are not merged. When ids are already
        unique across the batch (or no ``"offset"`` is given) the grouping is
        identical to grouping by id alone.

        NOTE: whether a superpoint is mixed is decided on the raw ``target``
        values, so ``-1`` counts as a value of its own; points labelled ``-1``
        themselves never contribute to the sum.

        Returns:
            Scalar tensor (unweighted); zero if the term is disabled or no
            superpoint ids of matching length are available.
        """
        if self.superpoint_edge_aux_weight <= 0:
            return seg_logits.new_tensor(0.0)
        num_points = seg_logits.shape[0]
        sp = _get_superpoint_ids(input_dict, num_points)
        if sp is None:
            return seg_logits.new_tensor(0.0)

        # Collapse (batch, superpoint) pairs into one dense group id per point.
        batch = _get_point_batch_ids(input_dict, num_points, sp.device)
        group_ids = torch.unique(
            torch.stack([batch, sp], dim=1), dim=0, return_inverse=True
        )[1]
        mixed = superpoint_semantic_mixed_per_point(group_ids, target)

        ce = F.cross_entropy(seg_logits, target, reduction="none", ignore_index=-1)
        valid = target != -1
        # w is 1 for ordinary points and `superpoint_edge_boost` for mixed ones;
        # only the surplus (w - 1) is returned because the base CE is assumed
        # to be covered by `self.criteria` already - TODO confirm with config.
        w = 1.0 + (self.superpoint_edge_boost - 1.0) * mixed.float()
        denom = valid.float().sum().clamp_min(1.0)
        return (ce * (w - 1.0) * valid.float()).sum() / denom

    def _maybe_superpoint_contrastive_loss(self, feat, target, input_dict):
        """Contrastive loss over points sampled per (batch, superpoint) group.

        Sampling: points with label ``-1`` are dropped, groups with fewer than
        ``min_points_per_superpoint`` remaining points are skipped, and at most
        ``max_points_per_superpoint`` points are taken from each group until
        ``max_samples`` is reached. Group order and the points inside a group
        are shuffled only in training mode; otherwise the first points are used.

        Loss: for every sampled anchor with at least one positive (another
        sample of the same group) the mean log-probability of its positives is
        computed against all other samples of the same scene; the negative
        mean over anchors is returned.

        Returns:
            Scalar tensor (unweighted); zero whenever the term is disabled or
            fewer than two groups can be sampled.
        """
        max_samples = self.superpoint_contrastive_max_samples
        max_per_group = self.superpoint_contrastive_max_points_per_superpoint
        min_per_group = self.superpoint_contrastive_min_points_per_superpoint

        if self.superpoint_contrastive_weight <= 0:
            return feat.new_tensor(0.0)
        if max_samples < 2:
            return feat.new_tensor(0.0)
        if max_per_group < 2:
            return feat.new_tensor(0.0)
        num_points = feat.shape[0]
        sp = _get_superpoint_ids(input_dict, num_points)
        if sp is None:
            return feat.new_tensor(0.0)
        batch = _get_point_batch_ids(input_dict, num_points, feat.device)

        valid_idx = torch.nonzero(target != -1, as_tuple=False).flatten()
        if valid_idx.numel() < min_per_group * 2:
            return feat.new_tensor(0.0)

        # Group the valid points by (batch, superpoint).
        valid_batch = batch[valid_idx]
        valid_sp = sp[valid_idx]
        group_keys = torch.stack([valid_batch, valid_sp], dim=1)
        _, inverse, counts = torch.unique(
            group_keys, dim=0, sorted=False, return_inverse=True, return_counts=True
        )
        candidate_groups = torch.nonzero(
            counts >= min_per_group,
            as_tuple=False,
        ).flatten()
        if candidate_groups.numel() < 2:
            return feat.new_tensor(0.0)

        # CSR-style layout: after sorting by group, group g occupies
        # ordered_idx[idx_ptr[g]:idx_ptr[g + 1]].
        order = torch.argsort(inverse)
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        ordered_idx = valid_idx[order]
        if self.training:
            candidate_groups = candidate_groups[
                torch.randperm(candidate_groups.numel(), device=feat.device)
            ]

        # One host transfer for all group boundaries instead of two `.item()`
        # calls per loop iteration.
        bounds = idx_ptr.tolist()
        selected = []
        budget = 0
        for group_id in candidate_groups.tolist():
            group_point_idx = ordered_idx[bounds[group_id]:bounds[group_id + 1]]
            remaining = max_samples - budget
            if remaining < min_per_group:
                break
            take = min(group_point_idx.numel(), max_per_group, remaining)
            if take < min_per_group:
                continue
            if self.training and group_point_idx.numel() > take:
                choice = torch.randperm(group_point_idx.numel(), device=feat.device)[:take]
                group_point_idx = group_point_idx[choice]
            else:
                group_point_idx = group_point_idx[:take]
            selected.append(group_point_idx)
            budget += group_point_idx.numel()
            if budget >= max_samples:
                break
        if len(selected) < 2:
            return feat.new_tensor(0.0)

        sample_idx = torch.cat(selected, dim=0)
        sample_batch = batch[sample_idx]
        sample_sp = sp[sample_idx]
        sample_group = torch.unique(
            torch.stack([sample_batch, sample_sp], dim=1),
            dim=0,
            sorted=False,
            return_inverse=True,
        )[1]

        # Cosine-similarity logits, temperature-scaled and shifted by the row
        # maximum (detached) before exponentiation.
        z = F.normalize(feat[sample_idx], p=2, dim=1)
        logits = torch.matmul(z, z.transpose(0, 1))
        logits = logits / max(self.superpoint_contrastive_temperature, 1e-6)
        logits = logits - logits.max(dim=1, keepdim=True).values.detach()

        self_mask = torch.eye(sample_idx.numel(), device=feat.device, dtype=torch.bool)
        same_scene = sample_batch[:, None].eq(sample_batch[None, :])
        positive_mask = sample_group[:, None].eq(sample_group[None, :]) & (~self_mask)
        valid_pair_mask = same_scene & (~self_mask)
        positive_count = positive_mask.sum(dim=1)
        anchor_mask = positive_count > 0
        if not anchor_mask.any():
            return feat.new_tensor(0.0)

        exp_logits = torch.exp(logits) * valid_pair_mask.float()
        denom = exp_logits.sum(dim=1, keepdim=True).clamp_min(1e-12)
        valid_anchor_mask = anchor_mask & (valid_pair_mask.sum(dim=1) > 0)
        if not valid_anchor_mask.any():
            return feat.new_tensor(0.0)
        log_prob = logits - torch.log(denom)
        mean_log_prob_pos = (log_prob * positive_mask.float()).sum(dim=1) / positive_count.clamp_min(1)
        return -mean_log_prob_pos[valid_anchor_mask].mean()

    @staticmethod
    def _debug_print_input(point):
        """Print batch size and per-sample point counts (or the keys) of ``point``."""
        if hasattr(point, "offset") and point.offset is not None:
            batch_size = len(point.offset)
            print(f"[DEBUG DefaultSegmentorV2] Batch Size (from offset): {batch_size}")
            for i in range(batch_size):
                start_idx = 0 if i == 0 else point.offset[i - 1].item()
                end_idx = point.offset[i].item()
                num_points = end_idx - start_idx
                print(f"[DEBUG DefaultSegmentorV2] Sample {i}: Number of points = {num_points}")
                if num_points <= 0:
                    print(
                        f"[ERROR] Sample {i} has {num_points} points! This will cause arange error."
                    )
        else:
            print(f"[DEBUG DefaultSegmentorV2] Point keys: {point.keys()}")
            if "coord" in point.keys():
                print(f"[DEBUG DefaultSegmentorV2] Point cloud shape: {point.coord.shape}")

    def _compute_losses(self, seg_logits, feat, input_dict, debug):
        """Compute the criteria loss plus both weighted superpoint terms.

        Shared by the train and eval paths of ``forward``. Emits a
        ``UserWarning`` when every label equals ``-1``.

        Returns:
            dict with keys ``loss``, ``loss_seg``, ``loss_superpoint_edge`` and
            ``loss_superpoint_contrast`` (in that order).
        """
        target = input_dict["segment"]
        if debug:
            print(f"[DEBUG DefaultSegmentorV2] Batch size (from target): {target.shape[0]}")
            print(f"[DEBUG DefaultSegmentorV2] seg_logits shape: {seg_logits.shape}")
            print(f"[DEBUG DefaultSegmentorV2] target (input_dict['segment']) shape: {target.shape}")
        num_valid_points = (target != -1).sum().item()
        if debug:
            print(f"[DEBUG] Number of valid points (not ignored): {num_valid_points}")
        if num_valid_points == 0:
            # stacklevel=3 attributes the warning to the caller of `forward`,
            # as before this code was moved into a helper.
            warnings.warn(
                "DefaultSegmentorV2: batch has zero valid labels (all ignore_index).",
                UserWarning,
                stacklevel=3,
            )
        loss_seg = self.criteria(seg_logits, target)
        loss_sp_edge = self.superpoint_edge_aux_weight * self._maybe_superpoint_edge_aux_loss(
            seg_logits, target, input_dict
        )
        loss_sp_contrast = self.superpoint_contrastive_weight * self._maybe_superpoint_contrastive_loss(
            feat, target, input_dict
        )
        return dict(
            loss=loss_seg + loss_sp_edge + loss_sp_contrast,
            loss_seg=loss_seg,
            loss_superpoint_edge=loss_sp_edge,
            loss_superpoint_contrast=loss_sp_contrast,
        )

    def forward(self, input_dict, return_point=False):
        """Run backbone and head; add losses when labels are available.

        Args:
            input_dict: batch dict; wrapped in ``Point`` for the backbone.
                ``"segment"`` is required in training mode.
            return_point: if true, the (un-pooled) backbone output is returned
                under ``"point"``.

        Returns:
            dict that contains, depending on the mode:
            * training: the four loss entries;
            * eval with ``"segment"``: the four loss entries and ``seg_logits``;
            * otherwise (test): ``seg_logits`` only.
        """
        debug = _segmentor_debug_enabled()
        point = Point(input_dict)
        if debug:
            self._debug_print_input(point)
        point = self.backbone(point)
        # Backbone added after v1.5.0 return Point instead of feat and use DefaultSegmentorV2
        # TODO: remove this part after make all backbone return Point only.
        if isinstance(point, Point):
            # Walk back up the pooling hierarchy, concatenating the coarse
            # features onto each parent level.
            while "pooling_parent" in point:
                assert "pooling_inverse" in point
                parent = point.pop("pooling_parent")
                inverse = point.pop("pooling_inverse")
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat = point.feat
        else:
            feat = point
        seg_logits = self.seg_head(feat)

        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point

        if self.training:
            # train
            return_dict.update(self._compute_losses(seg_logits, feat, input_dict, debug))
        elif "segment" in input_dict:
            # eval
            return_dict.update(self._compute_losses(seg_logits, feat, input_dict, debug))
            return_dict["seg_logits"] = seg_logits
        else:
            # test
            return_dict["seg_logits"] = seg_logits
        return return_dict
@MODELS.register_module()
class DINOEnhancedSegmentor(nn.Module):
    """Segmentor that appends nearest-neighbour DINO features before the head.

    For every point the feature of its single nearest DINO sample (same batch
    element, found with ``torch_cluster.knn``) is concatenated to the backbone
    feature, or used on its own when no backbone is configured, and the
    result is fed to ``seg_head``.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
    ):
        """Build head, optional backbone and criteria.

        Args:
            num_classes: output size of the linear head; ``<= 0`` replaces the
                head with ``nn.Identity``.
            backbone_out_channels: input size of the linear head.
            backbone: config handed to ``build_model``; ``None`` disables the
                backbone entirely.
            criteria: config handed to ``build_criteria``.
            freeze_backbone: if true, backbone parameters get
                ``requires_grad = False`` and the backbone runs under
                ``torch.no_grad()``.
        """
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.backbone = build_model(backbone) if backbone is not None else None
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        if self.backbone is not None and self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def forward(self, input_dict, return_point=False):
        """Compute logits from backbone + DINO features; add the loss if possible.

        Args:
            input_dict: batch dict; must provide ``"dino_coord"``,
                ``"dino_feat"`` and ``"dino_offset"``. ``"segment"`` is
                required in training mode.
            return_point: if true, the ``Point`` is returned under ``"point"``.

        Returns:
            dict with ``loss`` (training), ``loss`` and ``seg_logits`` (eval
            with ``"segment"``), or ``seg_logits`` only (test).
        """
        point = Point(input_dict)
        feat = []
        if self.backbone is not None:
            if self.freeze_backbone:
                with torch.no_grad():
                    point = self.backbone(point)
            else:
                point = self.backbone(point)

            # Collect the chain of "unpooling_parent" levels, then fold their
            # features back towards the first element, deepest level first.
            point_list = [point]
            while "unpooling_parent" in point_list[-1]:
                point_list.append(point_list[-1].pop("unpooling_parent"))
            for i in reversed(range(1, len(point_list))):
                point = point_list[i]
                parent = point_list[i - 1]
                assert "pooling_inverse" in point
                inverse = point.pooling_inverse
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = point_list[0]

            # Then walk up the remaining "pooling_parent" hierarchy.
            while "pooling_parent" in point:
                assert "pooling_inverse" in point
                parent = point.pop("pooling_parent")
                inverse = point.pooling_inverse
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat.append(point.feat)

        # Index of the nearest DINO sample for every point (k=1), searched
        # within the same batch element.
        idx = torch_cluster.knn(
            x=input_dict["dino_coord"],
            y=point.origin_coord,
            batch_x=offset2batch(input_dict["dino_offset"]),
            batch_y=offset2batch(point.origin_offset),
            k=1,
        )[1]
        feat.append(input_dict["dino_feat"][idx])
        # torch.cat (not the torch>=2.0 alias `concatenate`) to match the
        # rest of this file.
        feat = torch.cat(feat, dim=-1)
        seg_logits = self.seg_head(feat)

        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point

        # train / eval: labels available (mandatory when training)
        if self.training or "segment" in input_dict:
            if _segmentor_debug_enabled():
                target_shape = input_dict["segment"].shape
                print(f"[DEBUG DINOEnhancedSegmentor] Batch size (from target): {target_shape[0]}")
                print(f"[DEBUG DINOEnhancedSegmentor] seg_logits shape: {seg_logits.shape}")
                print(f"[DEBUG DINOEnhancedSegmentor] target (input_dict['segment']) shape: {target_shape}")
            return_dict["loss"] = self.criteria(seg_logits, input_dict["segment"])
        # eval / test: expose the logits
        if not self.training:
            return_dict["seg_logits"] = seg_logits
        return return_dict
@MODELS.register_module()
class DefaultClassifier(nn.Module):
    """Backbone followed by an MLP classification head.

    Args:
        backbone: config handed to ``build_model``.
        criteria: config handed to ``build_criteria``.
        num_classes: size of the final linear layer.
        backbone_embed_dim: input size of the first linear layer.
    """

    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim

        # Two hidden stages (Linear -> BatchNorm -> ReLU -> Dropout) of width
        # 256 and 128, followed by the output projection. Layers are created
        # in the same order as a literal nn.Sequential(...) listing would.
        layers = []
        in_dim = backbone_embed_dim
        for hidden_dim in (256, 128):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Dropout(p=0.5))
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, num_classes))
        self.cls_head = nn.Sequential(*layers)

    def forward(self, input_dict):
        """Classify each sample of the batch.

        Returns:
            ``{"loss"}`` in training, ``{"loss", "cls_logits"}`` in eval when
            ``"category"`` is present, otherwise ``{"cls_logits"}``.
        """
        point = self.backbone(Point(input_dict))
        # Backbone added after v1.5.0 return Point instead of feat
        # And after v1.5.0 feature aggregation for classification operated in classifier
        # TODO: remove this part after make all backbone return Point only.
        if isinstance(point, Point):
            # Mean over the segments delimited by `offset` (a zero is
            # prepended to form the CSR pointer).
            indptr = F.pad(point.offset, (1, 0))
            point.feat = torch_scatter.segment_csr(
                src=point.feat, indptr=indptr, reduce="mean"
            )
            feat = point.feat
        else:
            feat = point
        cls_logits = self.cls_head(feat)

        # test: no labels and not training -> logits only
        if not self.training and "category" not in input_dict:
            return {"cls_logits": cls_logits}

        loss = self.criteria(cls_logits, input_dict["category"])
        if self.training:
            return {"loss": loss}
        return {"loss": loss, "cls_logits": cls_logits}