Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| __all__ = ["DetDETRPostProcessor"] | |
| from .box_revert import BoxProcessFormat, box_revert | |
def mod(a, b):
    """Return the remainder of ``a`` divided by ``b``.

    The result is computed as ``a - (a // b) * b``, i.e. using only
    subtraction, floor division and multiplication instead of the ``%``
    operator. It works for anything supporting those three operators
    (plain numbers as well as tensors).
    """
    quotient = a // b
    return a - quotient * b
class DetDETRPostProcessor(nn.Module):
    """Convert raw detection-head outputs into labels, boxes and scores.

    The module picks the best-scoring predictions from ``pred_logits``,
    gathers the matching entries of ``pred_boxes`` and passes the boxes
    through ``box_revert`` (``cxcywh`` -> ``xyxy``).

    Args:
        num_classes: Number of classes; used to split a flattened
            (query, class) index back into a query index and a class label.
        use_focal_loss: If true, scores are per-class sigmoids and the top-k
            is taken over all (query, class) pairs. Otherwise scores are a
            softmax over classes with the last class column dropped.
        num_top_queries: Number of predictions kept per image.
        box_process_format: Forwarded to ``box_revert`` as ``process_fmt``.
    """

    def __init__(
        self,
        num_classes=80,
        use_focal_loss=True,
        num_top_queries=300,
        box_process_format=BoxProcessFormat.RESIZE,
    ) -> None:
        super().__init__()
        self.use_focal_loss = use_focal_loss
        self.num_top_queries = num_top_queries
        self.num_classes = int(num_classes)
        self.box_process_format = box_process_format
        # Switched on by ``deploy()``; makes ``forward`` return plain tensors.
        self.deploy_mode = False

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's repr."""
        return (
            f"use_focal_loss={self.use_focal_loss}, "
            f"num_classes={self.num_classes}, "
            f"num_top_queries={self.num_top_queries}"
        )

    @staticmethod
    def _gather_boxes(boxes, index):
        """Select ``boxes[b, index[b, k], :]`` for every batch item and k."""
        expanded = index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])
        return boxes.gather(dim=1, index=expanded)

    def forward(self, outputs, **kwargs):
        """Post-process one batch of model outputs.

        Args:
            outputs: Mapping with ``"pred_logits"`` and ``"pred_boxes"``.
                Assumed to be shaped (batch, queries, classes) and
                (batch, queries, 4) respectively -- TODO confirm with the
                model head.
            **kwargs: Forwarded unchanged to ``box_revert``.

        Returns:
            In deploy mode, the tuple ``(labels, boxes, scores)`` of batched
            tensors. Otherwise a list with one
            ``{"labels", "boxes", "scores"}`` dict per batch item.
        """
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]

        if self.use_focal_loss:
            scores = torch.sigmoid(logits)
            # Top-k over the flattened (query, class) axis; the flat index is
            # then split into a class label and a query index.
            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
            # NOTE: for TensorRT, ``mod(index, self.num_classes)`` can be used
            # in place of the ``%`` operator.
            labels = index % self.num_classes
            index = index // self.num_classes
            boxes = self._gather_boxes(boxes, index)
        else:
            # ``dim`` must be explicit: without it PyTorch's legacy implicit
            # choice for a 3-D input is dim 0, which would normalize across
            # the batch instead of across classes.
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = self._gather_boxes(boxes, index)

        # ``**kwargs`` is always a dict (never None), so the conversion runs
        # on every call.
        boxes = box_revert(
            boxes,
            in_fmt="cxcywh",
            out_fmt="xyxy",
            process_fmt=self.box_process_format,
            normalized=True,
            **kwargs,
        )

        # TODO for onnx export
        if self.deploy_mode:
            return labels, boxes, scores

        return [
            dict(labels=lab, boxes=box, scores=sco)
            for lab, box, sco in zip(labels, boxes, scores)
        ]

    def deploy(
        self,
    ):
        """Switch to eval mode and tensor-tuple output; returns ``self``."""
        self.eval()
        self.deploy_mode = True
        return self