|
|
| import itertools
|
| import logging
|
| import numpy as np
|
| from collections import OrderedDict
|
| from collections.abc import Mapping
|
| from typing import Dict, List, Optional, Tuple, Union
|
| import torch
|
| from omegaconf import DictConfig, OmegaConf
|
| from torch import Tensor, nn
|
|
|
| from detectron2.layers import ShapeSpec
|
| from detectron2.structures import BitMasks, Boxes, ImageList, Instances
|
| from detectron2.utils.events import get_event_storage
|
|
|
| from .backbone import Backbone
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
| def _to_container(cfg):
|
| """
|
| mmdet will assert the type of dict/list.
|
| So convert omegaconf objects to dict/list.
|
| """
|
| if isinstance(cfg, DictConfig):
|
| cfg = OmegaConf.to_container(cfg, resolve=True)
|
| from mmcv.utils import ConfigDict
|
|
|
| return ConfigDict(cfg)
|
|
|
|
|
| class MMDetBackbone(Backbone):
|
| """
|
| Wrapper of mmdetection backbones to use in detectron2.
|
|
|
| mmdet backbones produce list/tuple of tensors, while detectron2 backbones
|
| produce a dict of tensors. This class wraps the given backbone to produce
|
| output in detectron2's convention, so it can be used in place of detectron2
|
| backbones.
|
| """
|
|
|
| def __init__(
|
| self,
|
| backbone: Union[nn.Module, Mapping],
|
| neck: Union[nn.Module, Mapping, None] = None,
|
| *,
|
| output_shapes: List[ShapeSpec],
|
| output_names: Optional[List[str]] = None,
|
| ):
|
| """
|
| Args:
|
| backbone: either a backbone module or a mmdet config dict that defines a
|
| backbone. The backbone takes a 4D image tensor and returns a
|
| sequence of tensors.
|
| neck: either a backbone module or a mmdet config dict that defines a
|
| neck. The neck takes outputs of backbone and returns a
|
| sequence of tensors. If None, no neck is used.
|
| output_shapes: shape for every output of the backbone (or neck, if given).
|
| stride and channels are often needed.
|
| output_names: names for every output of the backbone (or neck, if given).
|
| By default, will use "out0", "out1", ...
|
| """
|
| super().__init__()
|
| if isinstance(backbone, Mapping):
|
| from mmdet.models import build_backbone
|
|
|
| backbone = build_backbone(_to_container(backbone))
|
| self.backbone = backbone
|
|
|
| if isinstance(neck, Mapping):
|
| from mmdet.models import build_neck
|
|
|
| neck = build_neck(_to_container(neck))
|
| self.neck = neck
|
|
|
|
|
|
|
|
|
| logger.info("Initializing mmdet backbone weights...")
|
| self.backbone.init_weights()
|
|
|
|
|
|
|
| self.backbone.train()
|
| if self.neck is not None:
|
| logger.info("Initializing mmdet neck weights ...")
|
| if isinstance(self.neck, nn.Sequential):
|
| for m in self.neck:
|
| m.init_weights()
|
| else:
|
| self.neck.init_weights()
|
| self.neck.train()
|
|
|
| self._output_shapes = output_shapes
|
| if not output_names:
|
| output_names = [f"out{i}" for i in range(len(output_shapes))]
|
| self._output_names = output_names
|
|
|
| def forward(self, x) -> Dict[str, Tensor]:
|
| outs = self.backbone(x)
|
| if self.neck is not None:
|
| outs = self.neck(outs)
|
| assert isinstance(
|
| outs, (list, tuple)
|
| ), "mmdet backbone should return a list/tuple of tensors!"
|
| if len(outs) != len(self._output_shapes):
|
| raise ValueError(
|
| "Length of output_shapes does not match outputs from the mmdet backbone: "
|
| f"{len(outs)} != {len(self._output_shapes)}"
|
| )
|
| return {k: v for k, v in zip(self._output_names, outs)}
|
|
|
| def output_shape(self) -> Dict[str, ShapeSpec]:
|
| return {k: v for k, v in zip(self._output_names, self._output_shapes)}
|
|
|
|
|
| class MMDetDetector(nn.Module):
|
| """
|
| Wrapper of a mmdetection detector model, for detection and instance segmentation.
|
| Input/output formats of this class follow detectron2's convention, so a
|
| mmdetection model can be trained and evaluated in detectron2.
|
| """
|
|
|
| def __init__(
|
| self,
|
| detector: Union[nn.Module, Mapping],
|
| *,
|
|
|
|
|
| size_divisibility=32,
|
| pixel_mean: Tuple[float],
|
| pixel_std: Tuple[float],
|
| ):
|
| """
|
| Args:
|
| detector: a mmdet detector, or a mmdet config dict that defines a detector.
|
| size_divisibility: pad input images to multiple of this number
|
| pixel_mean: per-channel mean to normalize input image
|
| pixel_std: per-channel stddev to normalize input image
|
| """
|
| super().__init__()
|
| if isinstance(detector, Mapping):
|
| from mmdet.models import build_detector
|
|
|
| detector = build_detector(_to_container(detector))
|
| self.detector = detector
|
| self.detector.init_weights()
|
| self.size_divisibility = size_divisibility
|
|
|
| self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
|
| self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
|
| assert (
|
| self.pixel_mean.shape == self.pixel_std.shape
|
| ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"
|
|
|
| def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
|
| images = [x["image"].to(self.device) for x in batched_inputs]
|
| images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
| images = ImageList.from_tensors(images, size_divisibility=self.size_divisibility).tensor
|
| metas = []
|
| rescale = {"height" in x for x in batched_inputs}
|
| if len(rescale) != 1:
|
| raise ValueError("Some inputs have original height/width, but some don't!")
|
| rescale = list(rescale)[0]
|
| output_shapes = []
|
| for input in batched_inputs:
|
| meta = {}
|
| c, h, w = input["image"].shape
|
| meta["img_shape"] = meta["ori_shape"] = (h, w, c)
|
| if rescale:
|
| scale_factor = np.array(
|
| [w / input["width"], h / input["height"]] * 2, dtype="float32"
|
| )
|
| ori_shape = (input["height"], input["width"])
|
| output_shapes.append(ori_shape)
|
| meta["ori_shape"] = ori_shape + (c,)
|
| else:
|
| scale_factor = 1.0
|
| output_shapes.append((h, w))
|
| meta["scale_factor"] = scale_factor
|
| meta["flip"] = False
|
| padh, padw = images.shape[-2:]
|
| meta["pad_shape"] = (padh, padw, c)
|
| metas.append(meta)
|
|
|
| if self.training:
|
| gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
|
| if gt_instances[0].has("gt_masks"):
|
| from mmdet.core import PolygonMasks as mm_PolygonMasks, BitmapMasks as mm_BitMasks
|
|
|
| def convert_mask(m, shape):
|
|
|
| if isinstance(m, BitMasks):
|
| return mm_BitMasks(m.tensor.cpu().numpy(), shape[0], shape[1])
|
| else:
|
| return mm_PolygonMasks(m.polygons, shape[0], shape[1])
|
|
|
| gt_masks = [convert_mask(x.gt_masks, x.image_size) for x in gt_instances]
|
| losses_and_metrics = self.detector.forward_train(
|
| images,
|
| metas,
|
| [x.gt_boxes.tensor for x in gt_instances],
|
| [x.gt_classes for x in gt_instances],
|
| gt_masks=gt_masks,
|
| )
|
| else:
|
| losses_and_metrics = self.detector.forward_train(
|
| images,
|
| metas,
|
| [x.gt_boxes.tensor for x in gt_instances],
|
| [x.gt_classes for x in gt_instances],
|
| )
|
| return _parse_losses(losses_and_metrics)
|
| else:
|
| results = self.detector.simple_test(images, metas, rescale=rescale)
|
| results = [
|
| {"instances": _convert_mmdet_result(r, shape)}
|
| for r, shape in zip(results, output_shapes)
|
| ]
|
| return results
|
|
|
| @property
|
| def device(self):
|
| return self.pixel_mean.device
|
|
|
|
|
|
|
|
|
| def _convert_mmdet_result(result, shape: Tuple[int, int]) -> Instances:
|
| if isinstance(result, tuple):
|
| bbox_result, segm_result = result
|
| if isinstance(segm_result, tuple):
|
| segm_result = segm_result[0]
|
| else:
|
| bbox_result, segm_result = result, None
|
|
|
| bboxes = torch.from_numpy(np.vstack(bbox_result))
|
| bboxes, scores = bboxes[:, :4], bboxes[:, -1]
|
| labels = [
|
| torch.full((bbox.shape[0],), i, dtype=torch.int32) for i, bbox in enumerate(bbox_result)
|
| ]
|
| labels = torch.cat(labels)
|
| inst = Instances(shape)
|
| inst.pred_boxes = Boxes(bboxes)
|
| inst.scores = scores
|
| inst.pred_classes = labels
|
|
|
| if segm_result is not None and len(labels) > 0:
|
| segm_result = list(itertools.chain(*segm_result))
|
| segm_result = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in segm_result]
|
| segm_result = torch.stack(segm_result, dim=0)
|
| inst.pred_masks = segm_result
|
| return inst
|
|
|
|
|
|
|
| def _parse_losses(losses: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
| log_vars = OrderedDict()
|
| for loss_name, loss_value in losses.items():
|
| if isinstance(loss_value, torch.Tensor):
|
| log_vars[loss_name] = loss_value.mean()
|
| elif isinstance(loss_value, list):
|
| log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
|
| else:
|
| raise TypeError(f"{loss_name} is not a tensor or list of tensors")
|
|
|
| if "loss" not in loss_name:
|
|
|
| storage = get_event_storage()
|
| value = log_vars.pop(loss_name).cpu().item()
|
| storage.put_scalar(loss_name, value)
|
| return log_vars
|
|
|