""" Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ import torch import torch.nn as nn import torch.nn.functional as F import torchvision from ..core import register __all__ = ['PostProcessor'] def mod(a, b): out = a - a // b * b return out @register() class PostProcessor(nn.Module): __share__ = [ 'num_classes', 'use_focal_loss', 'num_top_queries', 'remap_mscoco_category' ] def __init__( self, num_classes=80, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False ) -> None: super().__init__() self.use_focal_loss = use_focal_loss self.num_top_queries = num_top_queries self.num_classes = int(num_classes) self.remap_mscoco_category = remap_mscoco_category self.deploy_mode = False def extra_repr(self) -> str: return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}' # def forward(self, outputs, orig_target_sizes): def forward(self, outputs, orig_target_sizes: torch.Tensor): logits, boxes = outputs['pred_logits'], outputs['pred_boxes'] # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy') bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1) if self.use_focal_loss: scores = F.sigmoid(logits) scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1) # labels = index % self.num_classes labels = mod(index, self.num_classes) index = index // self.num_classes boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])) else: scores = F.softmax(logits)[:, :, :-1] scores, labels = scores.max(dim=-1) if scores.shape[1] > self.num_top_queries: scores, index = torch.topk(scores, self.num_top_queries, dim=-1) labels = torch.gather(labels, dim=1, index=index) boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) if self.deploy_mode: return labels, boxes, scores if self.remap_mscoco_category: from ..data.dataset import mscoco_label2category labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\ .to(boxes.device).reshape(labels.shape) results = [] for lab, box, sco in zip(labels, boxes, scores): result = dict(labels=lab, boxes=box, scores=sco) results.append(result) return results def deploy(self, ): self.eval() self.deploy_mode = True return self