Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| from typing import Dict | |
| import torch | |
| import torch.distributed | |
| import torch.nn.functional as F | |
| import torchvision | |
| from torch import Tensor | |
| from ...core import register | |
# Names exported by `from <module> import *`.
__all__ = ["DetNMSPostProcessor"]
class DetNMSPostProcessor(torch.nn.Module):
    """Turn raw detection-head outputs into per-image, NMS-filtered detections.

    For every query the best class is selected and its score computed. Boxes are
    converted to ``xyxy`` and scaled by ``orig_target_sizes``. Then, per image:
    queries scoring at or below ``score_threshold`` are dropped, class-aware NMS
    is applied, and at most ``keep_topk`` detections are kept.

    Args:
        iou_threshold: IoU threshold passed to ``torchvision.ops.batched_nms``.
        score_threshold: queries whose score is ``<=`` this value are discarded.
        keep_topk: maximum number of detections returned per image.
        box_fmt: input box format understood by ``torchvision.ops.box_convert``
            (case-insensitive, e.g. ``"cxcywh"``).
        logit_fmt: name of a ``torch.nn.functional`` activation used to turn
            logits into scores (case-insensitive). If ``torch.nn.functional``
            has no attribute with that name, raw logits are used as scores.
    """

    # Activations that normalize across the class dimension. They must be applied
    # to the full logits along dim=-1, not to the per-query maxima (doing the latter
    # would normalize across queries instead of classes).
    _CLASS_NORMALIZED_FMTS = ("softmax", "log_softmax")

    def __init__(
        self,
        iou_threshold=0.7,
        score_threshold=0.01,
        keep_topk=300,
        box_fmt="cxcywh",
        logit_fmt="sigmoid",
    ) -> None:
        super().__init__()
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        self.keep_topk = keep_topk
        self.box_fmt = box_fmt.lower()
        self.logit_fmt = logit_fmt.lower()
        # None when torch.nn.functional has no such attribute -> raw logits used.
        self.logit_func = getattr(F, self.logit_fmt, None)
        # True when the activation must see the whole class dimension.
        self.normalize_over_classes = (
            self.logit_func is not None and self.logit_fmt in self._CLASS_NORMALIZED_FMTS
        )
        self.deploy_mode = False

    def _scores_and_labels(self, logits: Tensor):
        """Return ``(scores, labels)`` of the best class for every query (max over dim -1)."""
        if self.normalize_over_classes:
            # softmax/log_softmax are monotonic along the class dim, so the argmax
            # (label) is unchanged; the score is now a proper per-class probability.
            scores, labels = torch.max(self.logit_func(logits, dim=-1), dim=-1)
            return scores, labels
        values, labels = torch.max(logits, dim=-1)
        # Element-wise activations (e.g. sigmoid) are applied to the maxima only.
        scores = self.logit_func(values) if self.logit_func else values
        return scores, labels

    def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):
        """Post-process a batch of raw predictions.

        Args:
            outputs: dict with ``"pred_logits"`` (class logits; the last dim is
                treated as the class dim) and ``"pred_boxes"`` (boxes in
                ``self.box_fmt``; presumably normalized to [0, 1] given the
                scaling below -- TODO confirm).
            orig_target_sizes: per-image sizes used to scale the boxes. Assumed to
                be (batch, 2) in (w, h) order, since it is tiled to (w, h, w, h)
                -- TODO confirm against the caller.

        Returns:
            In deploy mode: a dict with the unfiltered ``"pred_labels"``,
            ``"pred_boxes"`` and ``"pred_scores"`` tensors. Otherwise: a list with
            one dict per image holding ``"labels"``, ``"boxes"`` and ``"scores"``
            after score thresholding, NMS and top-k truncation.
        """
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]

        pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
        pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        pred_scores, pred_labels = self._scores_and_labels(logits)

        # TODO for onnx export: in deploy mode NMS is skipped and raw tensors are returned.
        if self.deploy_mode:
            return {
                "pred_labels": pred_labels,
                "pred_boxes": pred_boxes,
                "pred_scores": pred_scores,
            }

        results = []
        for img_boxes, img_labels, img_scores in zip(pred_boxes, pred_labels, pred_scores):
            score_keep = img_scores > self.score_threshold
            img_boxes = img_boxes[score_keep]
            img_labels = img_labels[score_keep]
            img_scores = img_scores[score_keep]

            # Labels are passed as NMS group ids, so only same-class boxes suppress
            # each other. batched_nms returns indices sorted by decreasing score,
            # so the slice keeps the top-scoring detections.
            keep = torchvision.ops.batched_nms(img_boxes, img_scores, img_labels, self.iou_threshold)
            keep = keep[: self.keep_topk]

            results.append(
                {
                    "labels": img_labels[keep],
                    "boxes": img_boxes[keep],
                    "scores": img_scores[keep],
                }
            )
        return results

    def deploy(
        self,
    ):
        """Switch to eval mode and make ``forward`` return raw tensors (no NMS); returns ``self``."""
        self.eval()
        self.deploy_mode = True
        return self