# mimc_rl/util/predictor.py
# wangyanhui666's picture
# fine tune decoder with mask
# 9cf79cf
import torch
from detectron2.data import MetadataCatalog
from torch import nn
import detectron2.data.transforms as T
from detectron2.config import configurable
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling.backbone.backbone import Backbone
from detectron2.structures import ImageList
from detectron2.modeling import (build_backbone, META_ARCH_REGISTRY,
build_proposal_generator, build_roi_heads,
detector_postprocess)
from typing import Optional, Tuple
@META_ARCH_REGISTRY.register()
class GeneralizedRCNN_with_Rate(nn.Module):
    """
    Generalized R-CNN variant whose backbone also reports distortion and rate.

    Any model that contains the following three components:
    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction

    Differences from a plain Generalized R-CNN, as used in this class:
    * in training, the backbone is called on the normalized image batch and
      must return a ``(features, distortion, rate)`` triple;
    * in inference, the backbone is called on the externally supplied
      ``trand_y_tilde`` instead of the image batch; the images are only
      normalized/batched so that their sizes can be handed to the proposal
      generator, the ROI heads and the post-processing step.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        proposal_generator: nn.Module,
        roi_heads: nn.Module,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        input_format: Optional[str] = None,
        vis_period: int = 0,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            proposal_generator: a module that generates proposals using backbone features
            roi_heads: a ROI head that performs per-region computation
            pixel_mean, pixel_std: list or tuple with #channels element,
                representing the per-channel mean and std to be used to normalize
                the input image
            input_format: describe the meaning of channels of input. Needed by visualization
            vis_period: the period to run visualization. Set to 0 to disable.
        """
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.roi_heads = roi_heads

        self.input_format = input_format
        self.vis_period = vis_period
        if vis_period > 0:
            assert input_format is not None, "input_format is required for visualization!"

        # Stored as (C, 1, 1) buffers so they broadcast over (C, H, W) images
        # and follow the module across devices.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1))
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    @classmethod
    def from_config(cls, cfg):
        """Translate a detectron2 config into the keyword arguments of ``__init__``."""
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "roi_heads": build_roi_heads(cfg, backbone.output_shape()),
            "input_format": cfg.INPUT.FORMAT,
            "vis_period": cfg.VIS_PERIOD,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        """Device on which the model's buffers (and hence its inputs) live."""
        return self.pixel_mean.device

    def visualize_training(self, batched_inputs, proposals):
        """
        Log one training image to the event storage for visual inspection.

        The logged image shows the ground-truth boxes on the left and up to
        20 of the first proposal boxes on the right. Only the first image of
        the batch is visualized.

        Args:
            batched_inputs (list): a list that contains input to the model.
                Each item needs an "image" tensor in (C, H, W) format and
                "instances" with a ``gt_boxes`` field.
            proposals (list): a list that contains predicted proposals. Both
                batched_inputs and proposals should have the same length.
        """
        # Imported lazily: these are only needed when visualization is enabled.
        import numpy as np
        from detectron2.data.detection_utils import convert_image_to_rgb
        from detectron2.utils.events import get_event_storage
        from detectron2.utils.visualizer import Visualizer

        storage = get_event_storage()
        max_vis_prop = 20

        for per_image_input, prop in zip(batched_inputs, proposals):
            img = per_image_input["image"]
            img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
            v_gt = Visualizer(img, None)
            v_gt = v_gt.overlay_instances(boxes=per_image_input["instances"].gt_boxes)
            anno_img = v_gt.get_image()
            box_size = min(len(prop.proposal_boxes), max_vis_prop)
            v_pred = Visualizer(img, None)
            v_pred = v_pred.overlay_instances(
                boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy()
            )
            prop_img = v_pred.get_image()
            # Side-by-side (GT | proposals), then HWC -> CHW for the storage.
            vis_img = np.concatenate((anno_img, prop_img), axis=1)
            vis_img = vis_img.transpose(2, 0, 1)
            vis_name = "Left: GT bounding boxes;  Right: Predicted proposals"
            storage.put_image(vis_name, vis_img)
            break  # only visualize one image in a batch

    def forward(self, batched_inputs, trand_y_tilde):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.
            trand_y_tilde: input handed to the backbone in inference mode
                (see :meth:`inference`). It is not used in training mode.

        Returns:
            In training mode, a tuple ``(losses, distortion, rate)`` where
            ``losses`` is a dict merging the ROI-head and proposal losses and
            ``distortion`` / ``rate`` are returned by the backbone unchanged.

            In inference mode, list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs, trand_y_tilde=trand_y_tilde)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # In training the backbone is expected to return three values.
        features, distortion, rate = self.backbone(images.tensor)

        if self.proposal_generator:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            # Local import: the name is not imported at module level.
            from detectron2.utils.events import get_event_storage

            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses, distortion, rate

    def inference(self, batched_inputs, detected_instances=None, do_postprocess=True, trand_y_tilde=None):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.
            trand_y_tilde: passed to the backbone as its only argument; the
                backbone's return value is used as the feature maps. Its
                expected type and shape are defined by the backbone.

        Returns:
            same as the inference-mode output of :meth:`forward` when
            ``do_postprocess`` is true, otherwise the raw ROI-head results.
        """
        assert not self.training

        # The image batch is not fed to the backbone here; it is only passed
        # on to the proposal generator / ROI heads and used for its sizes.
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(trand_y_tilde)

        if detected_instances is None:
            if self.proposal_generator:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            return self._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images

    @staticmethod
    def _postprocess(instances, batched_inputs, image_sizes):
        """
        Rescale the output instances to the target size.

        The target size of each image is taken from its "height" / "width"
        entries when present, otherwise from the corresponding image size.
        """
        # note: private function; subject to changes
        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            instances, batched_inputs, image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results
class ModPredictor:
    """
    Single-image inference wrapper around a model built from a config.

    On construction it builds the model from ``cfg``, switches it to eval
    mode and loads the weights named by ``cfg.MODEL.WEIGHTS``. Calling the
    instance resizes one image, converts it to a (C, H, W) float tensor and
    runs the model on it together with a caller-supplied ``trand_y_tilde``.

    Attributes:
        cfg: a private clone of the config passed in.
        model: the built model, in eval mode.
        metadata: metadata of the first test dataset; only set when
            ``cfg.DATASETS.TEST`` is non-empty.
        aug: the shortest-edge resize applied to every input image.
        input_format: ``cfg.INPUT.FORMAT``, either "RGB" or "BGR".
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: a detectron2 config; it is cloned before the model is built.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()
        if len(cfg.DATASETS.TEST):
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image, trand_y_tilde):
        """
        Run the model on one image.

        Args:
            original_image (np.ndarray): an image indexed as (H, W, C). Its
                channel axis is reversed when ``input_format`` is "RGB"
                (presumably the caller supplies BGR, as OpenCV does).
            trand_y_tilde: forwarded unchanged to the model as its second
                positional argument.

        Returns:
            The model's output for this image, i.e. the first element of the
            list returned by the model.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            # Original size is recorded before resizing so that the model can
            # report results at the caller's resolution.
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            # Pass the full (C, H, W) tensor. Indexing ``image[0]`` would hand
            # the model a single (H, W) channel plane, which only appears to
            # work because it broadcasts against the per-channel mean/std.
            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs], trand_y_tilde)[0]
            return predictions