Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch_scatter | |
| from pointcept.models.losses import build_criteria | |
| from pointcept.models.utils.structure import Point | |
| from .builder import MODELS, build_model | |
class DefaultSegmentor(nn.Module):
    """Segmentor whose backbone output is used directly as segmentation logits.

    ``forward`` returns a dict whose contents depend on the mode:

    * training (``self.training`` is True): ``{"loss": ...}``, with the loss
      computed against ``input_dict["segment"]``;
    * evaluation with labels (``"segment"`` present):
      ``{"loss": ..., "seg_logits": ...}``;
    * inference (no ``"segment"`` key): ``{"seg_logits": ...}``.

    Args:
        backbone: config passed to ``build_model``.
        criteria: config passed to ``build_criteria``.
    """

    # NOTE(review): MODELS is imported at the top of this file but no
    # @MODELS.register_module() decorator is visible on this class; confirm it
    # is registered if it is meant to be built from a config.

    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        if "condition" in input_dict:
            # PPT (https://arxiv.org/abs/2308.09718): only one condition per
            # batch is supported, so keep just the first element. This
            # overwrites the entry in the caller's dict in place.
            input_dict["condition"] = input_dict["condition"][0]

        logits = self.backbone(input_dict)

        if self.training:
            return dict(loss=self.criteria(logits, input_dict["segment"]))
        if "segment" not in input_dict:
            # Pure inference: no labels to score against.
            return dict(seg_logits=logits)
        loss = self.criteria(logits, input_dict["segment"])
        return dict(loss=loss, seg_logits=logits)
class DefaultSegmentorV2(nn.Module):
    """Segmentor that applies a linear head to backbone features.

    The input dict is wrapped in a ``Point`` before it reaches the backbone.
    If the backbone returns a ``Point``, its ``feat`` is fed to ``seg_head``.
    Otherwise the backbone output itself is used (legacy path).

    Args:
        num_classes: output width of the linear head. If it is not positive,
            the head is ``nn.Identity`` and features pass through unchanged.
        backbone_out_channels: input width of the linear head.
        backbone: config passed to ``build_model``.
        criteria: config passed to ``build_criteria``.

    ``forward`` returns ``{"loss"}`` in training mode. In eval mode it returns
    ``{"loss", "seg_logits"}`` when ``"segment"`` is present, and otherwise
    ``{"seg_logits"}``.
    """

    # NOTE(review): MODELS is imported at the top of this file but no
    # @MODELS.register_module() decorator is visible on this class; confirm it
    # is registered if it is meant to be built from a config.

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
    ):
        super().__init__()
        # Keep the head -> backbone -> criteria order. Registration order
        # determines the order of self.parameters(), and nn.Linear draws from
        # the RNG when it is constructed.
        if num_classes > 0:
            head = nn.Linear(backbone_out_channels, num_classes)
        else:
            head = nn.Identity()
        self.seg_head = head
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        out = self.backbone(Point(input_dict))
        # Backbones added after v1.5.0 return a Point; older ones return the
        # features directly.
        # TODO: drop the non-Point branch once every backbone returns Point.
        feat = out.feat if isinstance(out, Point) else out
        seg_logits = self.seg_head(feat)

        if self.training:
            return dict(loss=self.criteria(seg_logits, input_dict["segment"]))
        if "segment" not in input_dict:
            # Pure inference: no labels to score against.
            return dict(seg_logits=seg_logits)
        return dict(
            loss=self.criteria(seg_logits, input_dict["segment"]),
            seg_logits=seg_logits,
        )
class DefaultClassifier(nn.Module):
    """Classifier: backbone followed by an MLP head over pooled features.

    Args:
        backbone: config passed to ``build_model``.
        criteria: config passed to ``build_criteria``.
        num_classes: output width of the MLP head.
        backbone_embed_dim: input width of the MLP head.

    ``forward`` returns ``{"loss"}`` in training mode. In eval mode it returns
    ``{"loss", "cls_logits"}`` when ``"category"`` is present, and otherwise
    ``{"cls_logits"}``.
    """

    # NOTE(review): MODELS is imported at the top of this file but no
    # @MODELS.register_module() decorator is visible on this class; confirm it
    # is registered if it is meant to be built from a config.

    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim

        # Head structure:
        # - two hidden stages (256 then 128), each Linear / BatchNorm1d /
        #   ReLU / Dropout(0.5);
        # - a final Linear to num_classes.
        # Layers are created in this exact order, so their Sequential indices
        # (state_dict keys cls_head.0 ... cls_head.8) must stay stable for
        # checkpoint compatibility.
        layers = []
        width = backbone_embed_dim
        for hidden in (256, 128):
            layers.extend(
                (
                    nn.Linear(width, hidden),
                    nn.BatchNorm1d(hidden),
                    nn.ReLU(inplace=True),
                    nn.Dropout(p=0.5),
                )
            )
            width = hidden
        layers.append(nn.Linear(width, num_classes))
        self.cls_head = nn.Sequential(*layers)

    def forward(self, input_dict):
        out = self.backbone(Point(input_dict))
        # Backbones added after v1.5.0 return a Point and leave feature
        # aggregation to the classifier; older ones return features directly.
        # TODO: drop the non-Point branch once every backbone returns Point.
        if isinstance(out, Point):
            # Mean-pool features over the segments described by offset.
            # offset, left-padded with 0, serves as the CSR index pointer
            # (assumed: cumulative per-sample end indices; confirm vs Point).
            indptr = nn.functional.pad(out.offset, (1, 0))
            out.feat = torch_scatter.segment_csr(
                src=out.feat,
                indptr=indptr,
                reduce="mean",
            )
            feat = out.feat
        else:
            feat = out
        cls_logits = self.cls_head(feat)

        if self.training:
            return dict(loss=self.criteria(cls_logits, input_dict["category"]))
        if "category" not in input_dict:
            # Pure inference: no labels to score against.
            return dict(cls_logits=cls_logits)
        loss = self.criteria(cls_logits, input_dict["category"])
        return dict(loss=loss, cls_logits=cls_logits)