Spaces:
Sleeping
Sleeping
| """ | |
| Implements the Generalized R-CNN framework | |
| """ | |
| import warnings | |
| from collections import OrderedDict | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import torch | |
| from torch import nn, Tensor | |
| from torchvision.utils import _log_api_usage_once | |
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    The forward pass runs ``transform`` -> ``backbone`` -> ``rpn`` -> ``roi_heads`` and
    finally ``transform.postprocess``.

    Args:
        backbone (nn.Module): feature extractor applied to the batched image tensor
            (``images.tensors``). It may return a single Tensor, which is then exposed as
            the feature map named ``"0"``, or a mapping of named feature maps.
        rpn (nn.Module): region proposal network, called as
            ``rpn(images, features, targets)`` and returning ``(proposals, losses)``.
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it. Called as
            ``roi_heads(features, proposals, images.image_sizes, targets)`` and returning
            ``(detections, losses)``.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model. Called as ``transform(images, targets)``; it must return an object
            exposing ``.tensors`` and ``.image_sizes`` together with the transformed targets.
            Its ``postprocess`` method receives the detections, the transformed sizes and the
            original sizes (presumably to map detections back to the input resolution).
    """

    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode: makes the "always returns a tuple" warning fire once
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """Select the eager-mode output: the loss dict when training, the detections otherwise.

        Marked ``@torch.jit.unused`` so the TorchScript compiler does not compile this
        ``Union``-returning path; scripted models return ``(losses, detections)`` instead.
        """
        if self.training:
            return losses
        return detections

    def _validate_training_targets(self, targets: Optional[List[Dict[str, Tensor]]]) -> None:
        """Check that training targets exist and each ``"boxes"`` entry is an ``[N, 4]`` Tensor.

        Raises:
            AssertionError: (in eager mode, via ``torch._assert``) if ``targets`` is ``None``,
                or if any ``target["boxes"]`` is not a Tensor with 2 dimensions and 4 columns.
        """
        if targets is None:
            torch._assert(False, "targets should not be none when in training mode")
        else:
            # The loop lives in the else-branch: torch._assert(False, ...) is an ordinary call,
            # so only here is `targets` refined to a non-Optional list.
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                    )
                else:
                    torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

    def _check_for_degenerate_boxes(self, targets: List[Dict[str, Tensor]]) -> None:
        """Check that every target box has strictly positive width and height.

        Columns ``2:4`` of each ``target["boxes"]`` must be strictly greater than columns
        ``0:2`` (``x2 > x1`` and ``y2 > y1`` for ``(x1, y1, x2, y2)`` boxes). Any columns
        after the fourth are ignored by this check.

        Raises:
            AssertionError: (in eager mode, via ``torch._assert``) naming the first invalid
                box and the index of the target that contains it.
        """
        for target_idx, target in enumerate(targets):
            boxes = target["boxes"]
            degenerate_boxes = boxes[:, 2:4] <= boxes[:, :2]
            if degenerate_boxes.any():
                # report only the first degenerate box to keep the message short
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                torch._assert(
                    False,
                    "All bounding boxes should have positive height and width."
                    f" Found invalid box {degen_bb} for target at index {target_idx}.",
                )

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed; the last two dimensions of each
                tensor are taken as its height and width.
            targets (list[Dict[str, Tensor]]): ground-truth annotations per image (optional,
                but required in training mode). Each dict must provide a ``"boxes"`` entry;
                in training mode it is checked to be a Tensor of shape ``[N, 4]``.

        Returns:
            result (dict[str, Tensor] or list[dict[str, Tensor]]): the output from the model.
                In eager mode, training returns a dict holding the RPN and ROI-head losses, and
                evaluation returns one dict of post-processed detections per image (as produced
                by ``roi_heads``, e.g. ``scores``, ``labels`` and ``masks`` for Mask R-CNN models).
                Under TorchScript the model always returns a ``(losses, detections)`` tuple.
        """
        if self.training:
            self._validate_training_targets(targets)

        # Record each image's (H, W) before `transform` runs; `postprocess` receives these.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Validate the *transformed* targets, i.e. the boxes the RPN and ROI heads will see.
        if targets is not None:
            self._check_for_degenerate_boxes(targets)

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # a single feature map is exposed under the name "0"
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        # ROI-head losses are inserted first; on a key collision the RPN entry wins.
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)