Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| from ..core import register | |
| __all__ = ['PostProcessor'] | |
| def mod(a, b): | |
| out = a - a // b * b | |
| return out | |
| class PostProcessor(nn.Module): | |
| __share__ = [ | |
| 'num_classes', | |
| 'use_focal_loss', | |
| 'num_top_queries', | |
| 'remap_mscoco_category' | |
| ] | |
| def __init__( | |
| self, | |
| num_classes=80, | |
| use_focal_loss=True, | |
| num_top_queries=300, | |
| remap_mscoco_category=False | |
| ) -> None: | |
| super().__init__() | |
| self.use_focal_loss = use_focal_loss | |
| self.num_top_queries = num_top_queries | |
| self.num_classes = int(num_classes) | |
| self.remap_mscoco_category = remap_mscoco_category | |
| self.deploy_mode = False | |
| def extra_repr(self) -> str: | |
| return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}' | |
| # def forward(self, outputs, orig_target_sizes): | |
| def forward(self, outputs, orig_target_sizes: torch.Tensor): | |
| logits, boxes = outputs['pred_logits'], outputs['pred_boxes'] | |
| # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) | |
| bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy') | |
| bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1) | |
| if self.use_focal_loss: | |
| scores = F.sigmoid(logits) | |
| scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1) | |
| # labels = index % self.num_classes | |
| labels = mod(index, self.num_classes) | |
| index = index // self.num_classes | |
| boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])) | |
| else: | |
| scores = F.softmax(logits)[:, :, :-1] | |
| scores, labels = scores.max(dim=-1) | |
| if scores.shape[1] > self.num_top_queries: | |
| scores, index = torch.topk(scores, self.num_top_queries, dim=-1) | |
| labels = torch.gather(labels, dim=1, index=index) | |
| boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) | |
| if self.deploy_mode: | |
| return labels, boxes, scores | |
| if self.remap_mscoco_category: | |
| from ..data.dataset import mscoco_label2category | |
| labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\ | |
| .to(boxes.device).reshape(labels.shape) | |
| results = [] | |
| for lab, box, sco in zip(labels, boxes, scores): | |
| result = dict(labels=lab, boxes=box, scores=sco) | |
| results.append(result) | |
| return results | |
| def deploy(self, ): | |
| self.eval() | |
| self.deploy_mode = True | |
| return self | |