Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import numpy as np | |
| from typing import Callable, Dict, List, Union | |
| import fvcore.nn.weight_init as weight_init | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from detectron2.config import configurable | |
| from detectron2.data import MetadataCatalog | |
| from detectron2.layers import Conv2d, DepthwiseSeparableConv2d, ShapeSpec, get_norm | |
| from detectron2.modeling import ( | |
| META_ARCH_REGISTRY, | |
| SEM_SEG_HEADS_REGISTRY, | |
| build_backbone, | |
| build_sem_seg_head, | |
| ) | |
| from detectron2.modeling.postprocessing import sem_seg_postprocess | |
| from detectron2.projects.deeplab import DeepLabV3PlusHead | |
| from detectron2.projects.deeplab.loss import DeepLabCE | |
| from detectron2.structures import BitMasks, ImageList, Instances | |
| from detectron2.utils.registry import Registry | |
| from .post_processing import get_panoptic_segmentation | |
__all__ = ["PanopticDeepLab", "INS_EMBED_BRANCHES_REGISTRY", "build_ins_embed_branch"]


# Registry so that instance-embedding branch implementations can be selected
# by name via the config (see `build_ins_embed_branch`).
INS_EMBED_BRANCHES_REGISTRY = Registry("INS_EMBED_BRANCHES")
INS_EMBED_BRANCHES_REGISTRY.__doc__ = """
Registry for instance embedding branches, which make instance embedding
predictions from feature maps.
"""
@META_ARCH_REGISTRY.register()
class PanopticDeepLab(nn.Module):
    """
    Main class for panoptic segmentation architectures.

    Combines a shared backbone with a semantic segmentation head and an
    instance embedding head (center heatmap + center offsets), and fuses
    their outputs into a panoptic segmentation at inference time.
    """

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape())
        self.ins_embed_head = build_ins_embed_branch(cfg, self.backbone.output_shape())
        # Non-persistent buffers (last arg False): used only for input
        # normalization, so they are excluded from checkpoints.
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False)
        self.meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
        self.stuff_area = cfg.MODEL.PANOPTIC_DEEPLAB.STUFF_AREA
        self.threshold = cfg.MODEL.PANOPTIC_DEEPLAB.CENTER_THRESHOLD
        self.nms_kernel = cfg.MODEL.PANOPTIC_DEEPLAB.NMS_KERNEL
        self.top_k = cfg.MODEL.PANOPTIC_DEEPLAB.TOP_K_INSTANCE
        self.predict_instances = cfg.MODEL.PANOPTIC_DEEPLAB.PREDICT_INSTANCES
        self.use_depthwise_separable_conv = cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
        # The semantic head and the panoptic config must agree on the conv type.
        assert (
            cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV
            == cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
        )
        self.size_divisibility = cfg.MODEL.PANOPTIC_DEEPLAB.SIZE_DIVISIBILITY
        self.benchmark_network_speed = cfg.MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED

    @property
    def device(self):
        # BUGFIX: `forward` reads `self.device` as an attribute, so this must
        # be a property; as a plain method, `.to(self.device)` would receive a
        # bound method and fail.
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "sem_seg": semantic segmentation ground truth
                * "center": center points heatmap ground truth
                * "offset": pixel offsets to center points ground truth
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model (may be different
                  from input resolution), used in inference.

        Returns:
            list[dict]:
                each dict is the results for one image. The dict contains the following keys:

                * "panoptic_seg", "sem_seg": see documentation
                  :doc:`/tutorials/models` for the standard output format
                * "instances": available if ``predict_instances is True``. see documentation
                  :doc:`/tutorials/models` for the standard output format
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        # To avoid error in ASPP layer when input has different size.
        size_divisibility = (
            self.size_divisibility
            if self.size_divisibility > 0
            else self.backbone.size_divisibility
        )
        images = ImageList.from_tensors(images, size_divisibility)

        features = self.backbone(images.tensor)

        losses = {}
        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            targets = ImageList.from_tensors(
                targets, size_divisibility, self.sem_seg_head.ignore_value
            ).tensor
            if "sem_seg_weights" in batched_inputs[0]:
                # The default D2 DatasetMapper may not contain "sem_seg_weights"
                # Avoid error in testing when default DatasetMapper is used.
                weights = [x["sem_seg_weights"].to(self.device) for x in batched_inputs]
                weights = ImageList.from_tensors(weights, size_divisibility).tensor
            else:
                weights = None
        else:
            targets = None
            weights = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(features, targets, weights)
        losses.update(sem_seg_losses)

        if "center" in batched_inputs[0] and "offset" in batched_inputs[0]:
            center_targets = [x["center"].to(self.device) for x in batched_inputs]
            center_targets = ImageList.from_tensors(
                center_targets, size_divisibility
            ).tensor.unsqueeze(1)
            center_weights = [x["center_weights"].to(self.device) for x in batched_inputs]
            center_weights = ImageList.from_tensors(center_weights, size_divisibility).tensor

            offset_targets = [x["offset"].to(self.device) for x in batched_inputs]
            offset_targets = ImageList.from_tensors(offset_targets, size_divisibility).tensor
            offset_weights = [x["offset_weights"].to(self.device) for x in batched_inputs]
            offset_weights = ImageList.from_tensors(offset_weights, size_divisibility).tensor
        else:
            center_targets = None
            center_weights = None
            offset_targets = None
            offset_weights = None

        center_results, offset_results, center_losses, offset_losses = self.ins_embed_head(
            features, center_targets, center_weights, offset_targets, offset_weights
        )
        losses.update(center_losses)
        losses.update(offset_losses)

        if self.training:
            return losses

        if self.benchmark_network_speed:
            # Skip all post-processing when only measuring network speed.
            return []

        processed_results = []
        for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
            sem_seg_results, center_results, offset_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
            c = sem_seg_postprocess(center_result, image_size, height, width)
            o = sem_seg_postprocess(offset_result, image_size, height, width)
            # Post-processing to get panoptic segmentation.
            panoptic_image, _ = get_panoptic_segmentation(
                r.argmax(dim=0, keepdim=True),
                c,
                o,
                thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
                label_divisor=self.meta.label_divisor,
                stuff_area=self.stuff_area,
                void_label=-1,
                threshold=self.threshold,
                nms_kernel=self.nms_kernel,
                top_k=self.top_k,
            )
            # For semantic segmentation evaluation.
            processed_results.append({"sem_seg": r})
            panoptic_image = panoptic_image.squeeze(0)
            semantic_prob = F.softmax(r, dim=0)
            # For panoptic segmentation evaluation.
            processed_results[-1]["panoptic_seg"] = (panoptic_image, None)
            # For instance segmentation evaluation.
            if self.predict_instances:
                instances = []
                panoptic_image_cpu = panoptic_image.cpu().numpy()
                for panoptic_label in np.unique(panoptic_image_cpu):
                    # -1 is the void label produced by get_panoptic_segmentation.
                    if panoptic_label == -1:
                        continue
                    pred_class = panoptic_label // self.meta.label_divisor
                    isthing = pred_class in list(
                        self.meta.thing_dataset_id_to_contiguous_id.values()
                    )
                    # Get instance segmentation results.
                    if isthing:
                        instance = Instances((height, width))
                        # Evaluation code takes continuous id starting from 0
                        instance.pred_classes = torch.tensor(
                            [pred_class], device=panoptic_image.device
                        )
                        mask = panoptic_image == panoptic_label
                        instance.pred_masks = mask.unsqueeze(0)
                        # Average semantic probability
                        sem_scores = semantic_prob[pred_class, ...]
                        sem_scores = torch.mean(sem_scores[mask])
                        # Center point probability, read at the mask centroid.
                        mask_indices = torch.nonzero(mask).float()
                        center_y, center_x = (
                            torch.mean(mask_indices[:, 0]),
                            torch.mean(mask_indices[:, 1]),
                        )
                        center_scores = c[0, int(center_y.item()), int(center_x.item())]
                        # Confidence score is semantic prob * center prob.
                        instance.scores = torch.tensor(
                            [sem_scores * center_scores], device=panoptic_image.device
                        )
                        # Get bounding boxes
                        instance.pred_boxes = BitMasks(instance.pred_masks).get_bounding_boxes()
                        instances.append(instance)
                if len(instances) > 0:
                    processed_results[-1]["instances"] = Instances.cat(instances)

        return processed_results
@SEM_SEG_HEADS_REGISTRY.register()
class PanopticDeepLabSemSegHead(DeepLabV3PlusHead):
    """
    A semantic segmentation head described in :paper:`Panoptic-DeepLab`.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        decoder_channels: List[int],
        norm: Union[str, Callable],
        head_channels: int,
        loss_weight: float,
        loss_type: str,
        loss_top_k: float,
        ignore_value: int,
        num_classes: int,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "input_shape"
                (each element in "input_shape" corresponds to one decoder stage).
            norm (str or callable): normalization for all conv layers.
            head_channels (int): the output channels of extra convolutions
                between decoder and predictor.
            loss_weight (float): loss weight.
            loss_top_k: (float): setting the top k% hardest pixels for
                "hard_pixel_mining" loss.
            loss_type, ignore_value, num_classes: the same as the base class.

        Raises:
            ValueError: if ``loss_type`` is not one of "cross_entropy" or
                "hard_pixel_mining".
        """
        super().__init__(
            input_shape,
            decoder_channels=decoder_channels,
            norm=norm,
            ignore_value=ignore_value,
            **kwargs,
        )
        assert self.decoder_only

        self.loss_weight = loss_weight
        # Norm layers carry their own affine bias, so conv bias is only needed
        # when no norm is used.
        use_bias = norm == ""
        # `head` is additional transform before predictor
        if self.use_depthwise_separable_conv:
            # We use a single 5x5 DepthwiseSeparableConv2d to replace
            # 2 3x3 Conv2d since they have the same receptive field.
            self.head = DepthwiseSeparableConv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=5,
                padding=2,
                norm1=norm,
                activation1=F.relu,
                norm2=norm,
                activation2=F.relu,
            )
        else:
            self.head = nn.Sequential(
                Conv2d(
                    decoder_channels[0],
                    decoder_channels[0],
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, decoder_channels[0]),
                    activation=F.relu,
                ),
                Conv2d(
                    decoder_channels[0],
                    head_channels,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, head_channels),
                    activation=F.relu,
                ),
            )
            # Only the plain-Conv2d head is index-able and needs explicit init;
            # DepthwiseSeparableConv2d initializes its own sub-layers.
            weight_init.c2_xavier_fill(self.head[0])
            weight_init.c2_xavier_fill(self.head[1])
        self.predictor = Conv2d(head_channels, num_classes, kernel_size=1)
        nn.init.normal_(self.predictor.weight, 0, 0.001)
        nn.init.constant_(self.predictor.bias, 0)

        if loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=ignore_value)
        elif loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(ignore_label=ignore_value, top_k_percent_pixels=loss_top_k)
        else:
            raise ValueError("Unexpected loss type: %s" % loss_type)

    @classmethod
    def from_config(cls, cfg, input_shape):
        # BUGFIX: must be a classmethod — it takes `cls` and delegates to the
        # base class's `from_config`.
        ret = super().from_config(cfg, input_shape)
        ret["head_channels"] = cfg.MODEL.SEM_SEG_HEAD.HEAD_CHANNELS
        ret["loss_top_k"] = cfg.MODEL.SEM_SEG_HEAD.LOSS_TOP_K
        return ret

    def forward(self, features, targets=None, weights=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        y = self.layers(features)
        if self.training:
            return None, self.losses(y, targets, weights)
        else:
            # Upsample logits back to input resolution for inference.
            y = F.interpolate(
                y, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return y, {}

    def layers(self, features):
        assert self.decoder_only
        y = super().layers(features)
        y = self.head(y)
        y = self.predictor(y)
        return y

    def losses(self, predictions, targets, weights=None):
        # Predictions are at 1/common_stride resolution; upsample to match targets.
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = self.loss(predictions, targets, weights)
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses
def build_ins_embed_branch(cfg, input_shape):
    """
    Build a instance embedding branch from `cfg.MODEL.INS_EMBED_HEAD.NAME`.
    """
    branch_cls = INS_EMBED_BRANCHES_REGISTRY.get(cfg.MODEL.INS_EMBED_HEAD.NAME)
    return branch_cls(cfg, input_shape)
@INS_EMBED_BRANCHES_REGISTRY.register()
class PanopticDeepLabInsEmbedHead(DeepLabV3PlusHead):
    """
    An instance embedding head described in :paper:`Panoptic-DeepLab`.

    Predicts a 1-channel center heatmap and a 2-channel offset map
    (per-pixel offsets to the instance center).
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        decoder_channels: List[int],
        norm: Union[str, Callable],
        head_channels: int,
        center_loss_weight: float,
        offset_loss_weight: float,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "input_shape"
                (each element in "input_shape" corresponds to one decoder stage).
            norm (str or callable): normalization for all conv layers.
            head_channels (int): the output channels of extra convolutions
                between decoder and predictor.
            center_loss_weight (float): loss weight for center point prediction.
            offset_loss_weight (float): loss weight for center offset prediction.
        """
        super().__init__(input_shape, decoder_channels=decoder_channels, norm=norm, **kwargs)
        assert self.decoder_only

        self.center_loss_weight = center_loss_weight
        self.offset_loss_weight = offset_loss_weight
        # Norm layers carry their own affine bias, so conv bias is only needed
        # when no norm is used.
        use_bias = norm == ""
        # center prediction
        # `head` is additional transform before predictor
        self.center_head = nn.Sequential(
            Conv2d(
                decoder_channels[0],
                decoder_channels[0],
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, decoder_channels[0]),
                activation=F.relu,
            ),
            Conv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, head_channels),
                activation=F.relu,
            ),
        )
        weight_init.c2_xavier_fill(self.center_head[0])
        weight_init.c2_xavier_fill(self.center_head[1])
        self.center_predictor = Conv2d(head_channels, 1, kernel_size=1)
        nn.init.normal_(self.center_predictor.weight, 0, 0.001)
        nn.init.constant_(self.center_predictor.bias, 0)

        # offset prediction
        # `head` is additional transform before predictor
        if self.use_depthwise_separable_conv:
            # We use a single 5x5 DepthwiseSeparableConv2d to replace
            # 2 3x3 Conv2d since they have the same receptive field.
            self.offset_head = DepthwiseSeparableConv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=5,
                padding=2,
                norm1=norm,
                activation1=F.relu,
                norm2=norm,
                activation2=F.relu,
            )
        else:
            self.offset_head = nn.Sequential(
                Conv2d(
                    decoder_channels[0],
                    decoder_channels[0],
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, decoder_channels[0]),
                    activation=F.relu,
                ),
                Conv2d(
                    decoder_channels[0],
                    head_channels,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, head_channels),
                    activation=F.relu,
                ),
            )
            # Only the plain-Conv2d head is index-able and needs explicit init;
            # DepthwiseSeparableConv2d initializes its own sub-layers.
            weight_init.c2_xavier_fill(self.offset_head[0])
            weight_init.c2_xavier_fill(self.offset_head[1])
        self.offset_predictor = Conv2d(head_channels, 2, kernel_size=1)
        nn.init.normal_(self.offset_predictor.weight, 0, 0.001)
        nn.init.constant_(self.offset_predictor.bias, 0)

        self.center_loss = nn.MSELoss(reduction="none")
        self.offset_loss = nn.L1Loss(reduction="none")

    @classmethod
    def from_config(cls, cfg, input_shape):
        # BUGFIX: must be a classmethod to be usable by @configurable.
        if cfg.INPUT.CROP.ENABLED:
            assert cfg.INPUT.CROP.TYPE == "absolute"
            train_size = cfg.INPUT.CROP.SIZE
        else:
            train_size = None
        decoder_channels = [cfg.MODEL.INS_EMBED_HEAD.CONVS_DIM] * (
            len(cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES) - 1
        ) + [cfg.MODEL.INS_EMBED_HEAD.ASPP_CHANNELS]
        ret = dict(
            input_shape={
                k: v for k, v in input_shape.items() if k in cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES
            },
            project_channels=cfg.MODEL.INS_EMBED_HEAD.PROJECT_CHANNELS,
            aspp_dilations=cfg.MODEL.INS_EMBED_HEAD.ASPP_DILATIONS,
            aspp_dropout=cfg.MODEL.INS_EMBED_HEAD.ASPP_DROPOUT,
            decoder_channels=decoder_channels,
            common_stride=cfg.MODEL.INS_EMBED_HEAD.COMMON_STRIDE,
            norm=cfg.MODEL.INS_EMBED_HEAD.NORM,
            train_size=train_size,
            head_channels=cfg.MODEL.INS_EMBED_HEAD.HEAD_CHANNELS,
            center_loss_weight=cfg.MODEL.INS_EMBED_HEAD.CENTER_LOSS_WEIGHT,
            offset_loss_weight=cfg.MODEL.INS_EMBED_HEAD.OFFSET_LOSS_WEIGHT,
            use_depthwise_separable_conv=cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV,
        )
        return ret

    def forward(
        self,
        features,
        center_targets=None,
        center_weights=None,
        offset_targets=None,
        offset_weights=None,
    ):
        """
        Returns:
            In training, returns (None, None, dict of center losses,
            dict of offset losses).
            In inference, returns (center predictions, offset predictions, {}, {}).
        """
        center, offset = self.layers(features)
        if self.training:
            return (
                None,
                None,
                self.center_losses(center, center_targets, center_weights),
                self.offset_losses(offset, offset_targets, offset_weights),
            )
        else:
            center = F.interpolate(
                center, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            # Offsets are expressed in prediction-resolution pixels, so they
            # must be rescaled along with the spatial upsampling.
            offset = (
                F.interpolate(
                    offset, scale_factor=self.common_stride, mode="bilinear", align_corners=False
                )
                * self.common_stride
            )
            return center, offset, {}, {}

    def layers(self, features):
        assert self.decoder_only
        y = super().layers(features)
        # center
        center = self.center_head(y)
        center = self.center_predictor(center)
        # offset
        offset = self.offset_head(y)
        offset = self.offset_predictor(offset)
        return center, offset

    def center_losses(self, predictions, targets, weights):
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = self.center_loss(predictions, targets) * weights
        # Normalize by the weight mass; guard against an all-zero weight map.
        if weights.sum() > 0:
            loss = loss.sum() / weights.sum()
        else:
            loss = loss.sum() * 0
        losses = {"loss_center": loss * self.center_loss_weight}
        return losses

    def offset_losses(self, predictions, targets, weights):
        # Upsample and rescale offsets to input resolution (see `forward`).
        predictions = (
            F.interpolate(
                predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            * self.common_stride
        )
        loss = self.offset_loss(predictions, targets) * weights
        # Normalize by the weight mass; guard against an all-zero weight map.
        if weights.sum() > 0:
            loss = loss.sum() / weights.sum()
        else:
            loss = loss.sum() * 0
        losses = {"loss_offset": loss * self.offset_loss_weight}
        return losses