import torch
from torch import nn
from typing import Optional, Tuple

import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import (
    META_ARCH_REGISTRY,
    build_backbone,
    build_model,
    build_proposal_generator,
    build_roi_heads,
    detector_postprocess,
)
from detectron2.modeling.backbone.backbone import Backbone
from detectron2.structures import ImageList
from detectron2.utils.events import get_event_storage  # required by forward() for visualization


@META_ARCH_REGISTRY.register()
class GeneralizedRCNN_with_Rate(nn.Module):
    """
    Generalized R-CNN. Any model that contains the following three components:

    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        proposal_generator: nn.Module,
        roi_heads: nn.Module,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        input_format: Optional[str] = None,
        vis_period: int = 0,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone: a backbone module; must follow detectron2's backbone interface
            proposal_generator: a module that generates proposals using backbone features
            roi_heads: a ROI head that performs per-region computation
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std to be used to normalize the input image
            input_format: describes the meaning of the input channels; needed by visualization
            vis_period: the period to run visualization. Set to 0 to disable.
        """
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.roi_heads = roi_heads

        self.input_format = input_format
        self.vis_period = vis_period
        if vis_period > 0:
            assert input_format is not None, "input_format is required for visualization!"

        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1))
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "roi_heads": build_roi_heads(cfg, backbone.output_shape()),
            "input_format": cfg.INPUT.FORMAT,
            "vis_period": cfg.VIS_PERIOD,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs, trand_y_tilde):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in
                  inference. See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image. The dict contains one key
                "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs, trand_y_tilde=trand_y_tilde)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # In training, the rate-aware backbone returns the feature dict together
        # with the distortion and rate terms of its compression model.
        features, distortion, rate = self.backbone(images.tensor)

        if self.proposal_generator:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                # NOTE: `visualize_training` is not defined in this module; it is
                # expected to behave like detectron2's GeneralizedRCNN.visualize_training.
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses, distortion, rate

    def inference(
        self, batched_inputs, detected_instances=None, do_postprocess=True, trand_y_tilde=None
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            same as in :meth:`forward`.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        # At inference the backbone consumes the decoded latent `trand_y_tilde`
        # instead of the image tensor and returns the feature dict alone.
        features = self.backbone(trand_y_tilde)

        if detected_instances is None:
            if self.proposal_generator:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            return self._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images

    @staticmethod
    def _postprocess(instances, batched_inputs, image_sizes):
        """
        Rescale the output instances to the target size.
        """
        # note: private function; subject to changes
        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            instances, batched_inputs, image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results
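
# --- Illustrative sketch (not part of the original code) -------------------
# GeneralizedRCNN_with_Rate assumes a backbone with the following contract:
# in training mode it is called with the padded image tensor and must return
# (features, distortion, rate); at inference it is called with the decoded
# latent `trand_y_tilde` and must return the feature dict alone. The class
# below is a hypothetical stand-in that only demonstrates that interface;
# a real implementation would compute distortion and rate from its
# compression model.
class _RateAwareBackboneSketch(Backbone):
    def __init__(self, base: Backbone):
        super().__init__()
        self.base = base  # any detectron2 backbone, e.g. a ResNet-FPN

    def forward(self, x):
        # `x` is the image tensor in training, the decoded latent at inference.
        features = self.base(x)
        if self.training:
            distortion = torch.zeros((), device=x.device)  # placeholder distortion term
            rate = torch.zeros((), device=x.device)        # placeholder rate estimate
            return features, distortion, rate
        return features

    def output_shape(self):
        return self.base.output_shape()

    @property
    def size_divisibility(self):
        return self.base.size_divisibility
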
""" # note: private function; subject to changes processed_results = [] for results_per_image, input_per_image, image_size in zip( instances, batched_inputs, image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", image_size[1]) r = detector_postprocess(results_per_image, height, width) processed_results.append({"instances": r}) return processed_results class ModPredictor: def __init__(self, cfg): self.cfg = cfg.clone() # cfg can be modified by model self.model = build_model(self.cfg) self.model.eval() if len(cfg.DATASETS.TEST): self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0]) checkpointer = DetectionCheckpointer(self.model) checkpointer.load(cfg.MODEL.WEIGHTS) self.aug = T.ResizeShortestEdge( [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST ) self.input_format = cfg.INPUT.FORMAT assert self.input_format in ["RGB", "BGR"], self.input_format def __call__(self, original_image, trand_y_tilde): with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 # Apply pre-processing to image. if self.input_format == "RGB": # whether the model expects BGR inputs or RGB original_image = original_image[:, :, ::-1] height, width = original_image.shape[:2] image = self.aug.get_transform(original_image).apply_image(original_image) image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) inputs = {"image": image[0], "height": height, "width": width} predictions = self.model([inputs], trand_y_tilde)[0] return predictions