| """Implementation of Pointnet.""" |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from torch import nn |
|
|
| from vis4d.common.ckpt import load_model_checkpoint |
| from vis4d.common.typing import LossesType, ModelOutput |
| from vis4d.data.const import CommonKeys |
| from vis4d.op.base.pointnet import PointNetSegmentation, PointNetSemanticsOut |
| from vis4d.op.loss.orthogonal_transform_loss import ( |
| OrthogonalTransformRegularizationLoss, |
| ) |
|
|
|
|
class PointnetSegmentationModel(nn.Module):
    """Semantic segmentation model that wraps a PointNet network.

    With a target the raw network output is returned (training path);
    without one the per-point class predictions are returned (test path).
    """

    def __init__(
        self,
        num_classes: int = 11,
        in_dimensions: int = 3,
        weights: str | None = None,
    ) -> None:
        """Build the PointNet segmentation network.

        Args:
            num_classes: Number of semantic classes.
            in_dimensions: Input dimension.
            weights: Path to weight file. Nothing is loaded when None.
        """
        super().__init__()
        self.model = PointNetSegmentation(
            n_classes=num_classes, in_dimensions=in_dimensions
        )
        if weights is None:
            return
        load_model_checkpoint(self, weights)

    def __call__(
        self, data: torch.Tensor, target: torch.Tensor | None = None
    ) -> PointNetSemanticsOut | ModelOutput:
        """Run the semantic model with a typed call signature.

        Args:
            data: Input Tensor Shape [N, C, n_pts]
            target: Target Classes shape [N, n_pts]

        Returns:
            Whatever ``forward`` returns for the given arguments.
        """
        # Hand over to nn.Module's call machinery, which ends up in forward().
        return self._call_impl(data, target)

    def forward(
        self, data: torch.Tensor, target: torch.Tensor | None = None
    ) -> PointNetSemanticsOut | ModelOutput:
        """Dispatch to the train or test path depending on ``target``.

        Args:
            data: Input Tensor Shape [N, C, n_pts]
            target: Target Classes shape [N, n_pts]

        Returns:
            The result of ``forward_test`` when ``target`` is None, else
            the result of ``forward_train``.
        """
        if target is None:
            return self.forward_test(data)
        return self.forward_train(data, target)

    def forward_train(
        self,
        points: torch.Tensor,
        target: torch.Tensor,
    ) -> PointNetSemanticsOut:
        """Run the training stage.

        Args:
            points: Input Tensor Shape [N, C, n_pts]
            target: Target Classes shape [N, n_pts]

        Returns:
            The unmodified output of the wrapped network.
        """
        # ``target`` is accepted for interface symmetry but is not read here.
        return self.model(points)

    def forward_test(
        self,
        points: torch.Tensor,
    ) -> ModelOutput:
        """Run the test stage.

        Args:
            points: Input Tensor Shape [N, C, n_pts]

        Returns:
            Dict mapping ``CommonKeys.semantics3d`` to the argmax of the
            class logits taken along dimension 1.
        """
        logits = self.model(points).class_logits
        predicted_classes = torch.argmax(logits, dim=1)
        return {CommonKeys.semantics3d: predicted_classes}
|
|
|
|
class PointnetSegmentationLoss(nn.Module):
    """Loss for PointNet segmentation.

    Combines a cross-entropy segmentation loss with an optional, weighted
    regularization loss on the transformations produced by the network.
    """

    def __init__(
        self,
        regularize_transform: bool = True,
        ignore_index: int = 255,
        transform_weight: float = 1e-3,
        semantic_weights: torch.Tensor | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            regularize_transform: If true add transforms to loss
            ignore_index: Semantic class that should be ignored
            transform_weight: Loss weight factor for transform
                regularization loss
            semantic_weights: Classwise weights for semantic loss
        """
        super().__init__()
        self.segmentation_loss = nn.CrossEntropyLoss(
            weight=semantic_weights, ignore_index=ignore_index
        )
        self.transformation_loss = OrthogonalTransformRegularizationLoss()
        self.regularize_transform = regularize_transform
        self.transform_weight = transform_weight

    def forward(
        self, outputs: PointNetSemanticsOut, target: torch.Tensor
    ) -> LossesType:
        """Calculates the loss.

        Args:
            outputs: Pointnet output
            target: Target Labels

        Returns:
            Dict with the key ``segmentation_loss`` and, only when
            ``regularize_transform`` is true, the key ``transform_loss``
            holding the transform regularization loss scaled by
            ``transform_weight``.
        """
        losses = {
            "segmentation_loss": self.segmentation_loss(
                outputs.class_logits, target
            )
        }
        if not self.regularize_transform:
            # Regularization disabled: return only the segmentation term.
            # The original built this dict but forgot to return it.
            return losses

        losses["transform_loss"] = (
            self.transform_weight
            * self.transformation_loss(outputs.transformations)
        )
        return losses
|
|