|
|
from typing import Optional, Tuple

import numpy as np
import torch
from torch import nn

import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.modeling import build_model
from detectron2.modeling import (
    META_ARCH_REGISTRY,
    build_backbone,
    build_proposal_generator,
    build_roi_heads,
    detector_postprocess,
)
from detectron2.modeling.backbone.backbone import Backbone
from detectron2.structures import ImageList
from detectron2.utils.events import get_event_storage
|
|
|
|
|
@META_ARCH_REGISTRY.register()
class GeneralizedRCNN_with_Rate(nn.Module):
    """
    Generalized R-CNN whose backbone additionally reports distortion and rate.

    Any model that contains the following three components:

    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction

    In training mode the backbone is called on the normalized image tensor and
    must return a ``(features, distortion, rate)`` triple. In inference mode the
    backbone is called on ``trand_y_tilde`` instead of the image and must return
    the features only.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        proposal_generator: nn.Module,
        roi_heads: nn.Module,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        input_format: Optional[str] = None,
        vis_period: int = 0,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            proposal_generator: a module that generates proposals using backbone features
            roi_heads: a ROI head that performs per-region computation
            pixel_mean, pixel_std: list or tuple with #channels element,
                representing the per-channel mean and std to be used to normalize
                the input image
            input_format: describe the meaning of channels of input. Needed by visualization
            vis_period: the period to run visualization. Set to 0 to disable.
        """
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.roi_heads = roi_heads

        self.input_format = input_format
        self.vis_period = vis_period
        if vis_period > 0:
            assert input_format is not None, "input_format is required for visualization!"

        # Stored as (C, 1, 1) buffers so they broadcast over (C, H, W) images
        # and move with the module across devices.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1))
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    @classmethod
    def from_config(cls, cfg):
        """Translate a config node into the keyword arguments of ``__init__``."""
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "roi_heads": build_roi_heads(cfg, backbone.output_shape()),
            "input_format": cfg.INPUT.FORMAT,
            "vis_period": cfg.VIS_PERIOD,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        """Device the model lives on, taken from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def visualize_training(self, batched_inputs, proposals):
        """
        Log one side-by-side image (ground-truth boxes | top proposals) to the
        current event storage.

        Mirrors ``GeneralizedRCNN.visualize_training`` from detectron2; it was
        missing here although :meth:`forward` calls it when ``vis_period > 0``.

        Args:
            batched_inputs (list): same as in :meth:`forward`; each item must
                contain "image" and "instances".
            proposals (list): one proposal :class:`Instances` per image, with a
                ``proposal_boxes`` field.
        """
        # Imported lazily, as upstream detectron2 does, so the visualizer module
        # is only loaded when visualization is actually enabled.
        from detectron2.utils.visualizer import Visualizer

        storage = get_event_storage()
        max_vis_prop = 20

        for per_image, prop in zip(batched_inputs, proposals):
            img = per_image["image"]
            img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
            v_gt = Visualizer(img, None)
            v_gt = v_gt.overlay_instances(boxes=per_image["instances"].gt_boxes)
            anno_img = v_gt.get_image()
            box_size = min(len(prop.proposal_boxes), max_vis_prop)
            v_pred = Visualizer(img, None)
            v_pred = v_pred.overlay_instances(
                boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy()
            )
            prop_img = v_pred.get_image()
            vis_img = np.concatenate((anno_img, prop_img), axis=1)
            vis_img = vis_img.transpose(2, 0, 1)
            vis_name = "Left: GT bounding boxes;  Right: Predicted proposals"
            storage.put_image(vis_name, vis_img)
            break  # only visualize one image in a batch

    def forward(self, batched_inputs, trand_y_tilde):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.
            trand_y_tilde: forwarded to :meth:`inference`, where it is fed to the
                backbone in place of the image tensor. It is not used in training
                mode. NOTE(review): presumably a decoded/quantized latent --
                confirm against the backbone implementation.

        Returns:
            In training mode, a tuple ``(losses, distortion, rate)`` where
            ``losses`` is a dict merging the ROI-head and proposal losses and
            ``distortion`` / ``rate`` are returned by the backbone unchanged.

            In inference mode, list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs, trand_y_tilde=trand_y_tilde)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # In training mode the backbone returns the features together with the
        # distortion and rate terms, which are handed back to the caller.
        features, distortion, rate = self.backbone(images.tensor)

        if self.proposal_generator:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses, distortion, rate

    def inference(self, batched_inputs, detected_instances=None, do_postprocess=True, trand_y_tilde=None):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.
            trand_y_tilde: the backbone input. It is passed to the backbone
                as-is (including the default ``None``), so callers are expected
                to supply it.

        Returns:
            same as the inference-mode return of :meth:`forward` when
            ``do_postprocess`` is true; otherwise the raw ROI-head results.
        """
        assert not self.training

        # The preprocessed images are not fed to the backbone here; they are
        # passed to the proposal generator / ROI heads and supply image_sizes.
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(trand_y_tilde)

        if detected_instances is None:
            if self.proposal_generator:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            return self._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images

    @staticmethod
    def _postprocess(instances, batched_inputs, image_sizes):
        """
        Rescale the output instances to the target size.

        The target size is the "height"/"width" stored in each input dict,
        falling back to the corresponding entry of ``image_sizes``.
        """
        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            instances, batched_inputs, image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results
|
|
|
|
|
|
|
|
class ModPredictor:
    """
    Single-image predictor that forwards an extra ``trand_y_tilde`` input.

    Builds the model described by ``cfg`` in eval mode, loads
    ``cfg.MODEL.WEIGHTS`` into it and prepares the test-time
    shortest-edge resize.

    Attributes set by ``__init__``:
        cfg: a private clone of the given config.
        model: the built model, in eval mode with weights loaded.
        metadata: metadata of ``cfg.DATASETS.TEST[0]``; only set when
            ``cfg.DATASETS.TEST`` is non-empty.
        aug: the ``ResizeShortestEdge`` test-time augmentation.
        input_format: ``"RGB"`` or ``"BGR"``, taken from ``cfg.INPUT.FORMAT``.
    """

    def __init__(self, cfg):
        # Work on a clone so building the model cannot touch the caller's cfg.
        self.cfg = cfg.clone()
        self.model = build_model(self.cfg)
        self.model.eval()

        test_sets = cfg.DATASETS.TEST
        if len(test_sets):
            self.metadata = MetadataCatalog.get(test_sets[0])

        DetectionCheckpointer(self.model).load(cfg.MODEL.WEIGHTS)

        min_edge = cfg.INPUT.MIN_SIZE_TEST
        self.aug = T.ResizeShortestEdge([min_edge, min_edge], cfg.INPUT.MAX_SIZE_TEST)

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image, trand_y_tilde):
        """
        Run the model on one image.

        Args:
            original_image: array indexed as (H, W, C). The channel axis is
                reversed when the model's input format is "RGB".
                NOTE(review): this implies the caller supplies BGR images,
                as with detectron2's DefaultPredictor -- confirm.
            trand_y_tilde: passed unchanged as the second argument of the model.

        Returns:
            The first (and only) element of the model's output for this image.
        """
        with torch.no_grad():
            frame = original_image
            if self.input_format == "RGB":
                frame = frame[:, :, ::-1]
            orig_h, orig_w = frame.shape[:2]

            resized = self.aug.get_transform(frame).apply_image(frame)
            chw = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))

            # NOTE(review): only the first channel plane is forwarded as
            # "image"; presumably the model needs just its spatial size
            # because the backbone consumes trand_y_tilde -- confirm.
            sample = {"image": chw[0], "height": orig_h, "width": orig_w}
            return self.model([sample], trand_y_tilde)[0]
|
|
|