Spaces:
Sleeping
Sleeping
| """ PyTorch EfficientDet support benches | |
| Hacked together by Ross Wightman | |
| """ | |
| from typing import Optional, Dict, List | |
| import torch | |
| import torch.nn as nn | |
| from timm.utils import ModelEma | |
| from .anchors import Anchors, AnchorLabeler, generate_detections, MAX_DETECTION_POINTS | |
| from .loss import DetectionLoss | |
def _post_process(
        cls_outputs: List[torch.Tensor],
        box_outputs: List[torch.Tensor],
        num_levels: int,
        num_classes: int,
        max_detection_points: int = MAX_DETECTION_POINTS,
):
    """Select the top-k scoring (anchor, class) predictions across all levels.

    Post-proc code adapted from Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet
    and optimized for PyTorch.

    Args:
        cls_outputs: list indexed by feature level; each entry is a 4-D tensor of class
            logits whose dim 1 is moved last and folded into `num_classes` columns
            (i.e. presumably [batch_size, num_anchors * num_classes, height, width]).
        box_outputs: list indexed by feature level; each entry is a 4-D tensor of box
            regression outputs whose dim 1 is moved last and folded into 4 columns
            (i.e. presumably [batch_size, num_anchors * 4, height, width]).
        num_levels (int): number of feature levels to read from both lists
        num_classes (int): number of output classes
        max_detection_points (int): upper bound on the number of predictions kept per
            image. If fewer candidates exist, all of them are kept.

    Returns:
        A tuple `(cls_out, box_out, indices, classes)` where, with k the number of kept
        predictions:
            cls_out: [batch_size, k, 1] logit of each selected (anchor, class) pair
            box_out: [batch_size, k, 4] box regression output of each selected anchor
            indices: [batch_size, k] anchor index (over all levels) of each selection
            classes: [batch_size, k] class index of each selection
    """
    batch_size = cls_outputs[0].shape[0]

    # Move channels last and flatten every level to [batch, anchors_at_level, C], then
    # concatenate the levels along the anchor dimension.
    cls_outputs_all = torch.cat([
        cls_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, num_classes])
        for level in range(num_levels)], 1)

    box_outputs_all = torch.cat([
        box_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, 4])
        for level in range(num_levels)], 1)

    # Rank all (anchor, class) pairs together. torch.topk raises if k exceeds the size of
    # the reduced dimension, so never ask for more candidates than actually exist.
    cls_outputs_flat = cls_outputs_all.reshape(batch_size, -1)
    k = min(max_detection_points, cls_outputs_flat.shape[1])
    _, cls_topk_indices_all = torch.topk(cls_outputs_flat, dim=1, k=k)

    # A flat index encodes anchor * num_classes + class; split it back apart.
    indices_all = cls_topk_indices_all // num_classes
    classes_all = cls_topk_indices_all % num_classes

    box_outputs_all_after_topk = torch.gather(
        box_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, 4))

    # First pick the selected anchors' full logit rows, then the selected class column.
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, num_classes))
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all_after_topk, 2, classes_all.unsqueeze(2))

    return cls_outputs_all_after_topk, box_outputs_all_after_topk, indices_all, classes_all
def _batch_detection(
        batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
        img_scale: Optional[torch.Tensor] = None, img_size: Optional[torch.Tensor] = None):
    """Run `generate_detections` on each sample and stack the results along dim 0.

    The per-image scale / size entries are forwarded as None when the corresponding
    batch-level argument is None. `anchor_boxes` is shared by every sample.
    """
    per_image = []
    # FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
    for sample_idx in range(batch_size):
        scale = img_scale[sample_idx] if img_scale is not None else None
        size = img_size[sample_idx] if img_size is not None else None
        per_image.append(generate_detections(
            class_out[sample_idx],
            box_out[sample_idx],
            anchor_boxes,
            indices[sample_idx],
            classes[sample_idx],
            scale,
            size,
        ))
    return torch.stack(per_image, dim=0)
class DetBenchPredict(nn.Module):
    """Prediction bench: wraps a detection model and turns its raw outputs into detections.

    The wrapped model must expose a `config` with `num_levels` and `num_classes`; the
    same config is used to build the anchors.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        cfg = model.config
        self.config = cfg  # FIXME remove this when we can use @property (torchscript limitation)
        self.num_levels = cfg.num_levels
        self.num_classes = cfg.num_classes
        self.anchors = Anchors.from_config(cfg)

    def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
        """Return stacked per-image detections for the batch `x`.

        When `img_info` is given, its 'img_scale' and 'img_size' entries are passed on
        to detection generation; otherwise None is passed for both.
        """
        raw_cls, raw_box = self.model(x)
        cls_topk, box_topk, anchor_idx, class_idx = _post_process(
            raw_cls, raw_box, num_levels=self.num_levels, num_classes=self.num_classes)

        if img_info is not None:
            img_scale, img_size = img_info['img_scale'], img_info['img_size']
        else:
            img_scale, img_size = None, None

        return _batch_detection(
            x.shape[0], cls_topk, box_topk, self.anchors.boxes, anchor_idx, class_idx, img_scale, img_size)
class DetBenchTrain(nn.Module):
    """Training bench: wraps a detection model with target labeling and loss computation.

    With `create_labeler=True` anchors are labeled inside `forward` from the raw
    'bbox' / 'cls' targets; otherwise `forward` expects pre-computed anchor labels in
    `target`. In eval mode the output additionally carries detections.
    """

    def __init__(self, model, create_labeler=True):
        super().__init__()
        self.model = model
        cfg = model.config
        self.config = cfg  # FIXME remove this when we can use @property (torchscript limitation)
        self.num_levels = cfg.num_levels
        self.num_classes = cfg.num_classes
        self.anchors = Anchors.from_config(cfg)
        self.anchor_labeler = None
        if create_labeler:
            self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(cfg)

    def forward(self, x, target: Dict[str, torch.Tensor]):
        """Compute losses for batch `x` and, when not training, detections as well.

        Returns a dict with keys 'loss', 'class_loss' and 'box_loss'; in eval mode a
        'detections' entry is added, built using target['img_scale'] and
        target['img_size'].
        """
        class_out, box_out = self.model(x)

        if self.anchor_labeler is not None:
            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
                target['bbox'], target['cls'])
        else:
            # target should contain pre-computed anchor labels if labeler not present in bench
            assert 'label_num_positives' in target
            cls_targets = [target[f'label_cls_{level}'] for level in range(self.num_levels)]
            box_targets = [target[f'label_bbox_{level}'] for level in range(self.num_levels)]
            num_positives = target['label_num_positives']

        loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
        output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
        if self.training:
            return output

        # if eval mode, output detections for evaluation
        cls_topk, box_topk, anchor_idx, class_idx = _post_process(
            class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
        output['detections'] = _batch_detection(
            x.shape[0], cls_topk, box_topk, self.anchors.boxes, anchor_idx, class_idx,
            target['img_scale'], target['img_size'])
        return output
def unwrap_bench(model):
    """Peel wrapper layers off `model` and return the innermost underlying model.

    Lets other functions access the weights and attributes of the underlying model
    directly. At each step the first matching wrapper is removed: a ModelEma (via
    `.ema`), then a DDP-style wrapper (via `.module`), then a bench (via `.model`).
    """
    while True:
        if isinstance(model, ModelEma):  # unwrap ModelEma
            model = model.ema
        elif hasattr(model, 'module'):  # unwrap DDP
            model = model.module
        elif hasattr(model, 'model'):  # unwrap Bench -> model
            model = model.model
        else:
            return model