|
|
| import logging
|
| from typing import List, Optional, Sequence, Tuple
|
| import torch
|
|
|
| from detectron2.layers.nms import batched_nms
|
| from detectron2.structures.instances import Instances
|
|
|
| from densepose.converters import ToChartResultConverterWithConfidences
|
| from densepose.structures import (
|
| DensePoseChartResultWithConfidences,
|
| DensePoseEmbeddingPredictorOutput,
|
| )
|
| from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer
|
| from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer
|
| from densepose.vis.densepose_results import DensePoseResultsVisualizer
|
|
|
| from .base import CompoundVisualizer
|
|
|
# Type alias: a sequence of float scores.
Scores = Sequence[float]

# Type alias: a list of DensePose chart results with confidences; this is the
# first element of the tuple annotated as DensePoseResultExtractor's return.
DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences]
|
|
|
|
|
def extract_scores_from_instances(instances: Instances, select=None):
    """
    Return the `scores` field of the instances, optionally indexed by `select`

    Args:
        instances: object queried through `has("scores")` and `.scores`
        select: optional selector applied to the scores with `[]` indexing;
            when None the scores are returned unchanged
    Return:
        the (optionally selected) scores, or None if there is no `scores` field
    """
    if not instances.has("scores"):
        return None
    if select is None:
        return instances.scores
    return instances.scores[select]
|
|
|
|
|
def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """
    Return a copy of the predicted boxes with columns 2 and 3 turned into
    differences against columns 0 and 1 (i.e. XYXY corners -> XYWH)

    Args:
        instances: object queried through `has("pred_boxes")` and
            `.pred_boxes.tensor`
        select: optional selector applied to the converted boxes with `[]`
            indexing; when None all boxes are returned
    Return:
        the converted (and optionally selected) boxes, or None if there is no
        `pred_boxes` field
    """
    if not instances.has("pred_boxes"):
        return None
    # Work on a clone so the tensor stored in the instances is left untouched
    xywh = instances.pred_boxes.tensor.clone()
    # Columns 2:4 (x2, y2) minus columns 0:2 (x1, y1) gives (w, h)
    xywh[:, 2:4] -= xywh[:, 0:2]
    if select is None:
        return xywh
    return xywh[select]
|
|
|
|
|
def create_extractor(visualizer: object):
    """
    Build the extractor that matches the given visualizer

    The visualizer type is tested in the order below and the first match wins:
      - CompoundVisualizer: a CompoundExtractor holding one extractor per
        nested visualizer (created recursively)
      - DensePoseResultsVisualizer: a DensePoseResultExtractor
      - ScoredBoundingBoxVisualizer: a CompoundExtractor over the box and
        score extraction functions
      - BoundingBoxVisualizer: the box extraction function itself
      - DensePoseOutputsVertexVisualizer: a DensePoseOutputsExtractor

    Return:
        the extractor, or None (after logging an error) for any other
        visualizer
    """
    if isinstance(visualizer, CompoundVisualizer):
        nested = [create_extractor(child) for child in visualizer.visualizers]
        return CompoundExtractor(nested)
    if isinstance(visualizer, DensePoseResultsVisualizer):
        return DensePoseResultExtractor()
    if isinstance(visualizer, ScoredBoundingBoxVisualizer):
        box_and_score_extractors = [
            extract_boxes_xywh_from_instances,
            extract_scores_from_instances,
        ]
        return CompoundExtractor(box_and_score_extractors)
    if isinstance(visualizer, BoundingBoxVisualizer):
        return extract_boxes_xywh_from_instances
    if isinstance(visualizer, DensePoseOutputsVertexVisualizer):
        return DensePoseOutputsExtractor()

    # No known visualizer type matched: report it and hand back nothing
    logging.getLogger(__name__).error(f"Could not create extractor for {visualizer}")
    return None
|
|
|
|
|
class BoundingBoxExtractor:
    """
    Extracts bounding boxes from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Extract XYWH boxes from the instances

        Args:
            instances: instances to read `pred_boxes` from
            select: optional selector applied to the boxes with `[]` indexing.
                Accepted so that this extractor can be invoked the same way as
                the other extractors in this module (the wrapping extractors
                call `extractor(instances, select)`); defaults to None, which
                returns all boxes
        Return:
            the (optionally selected) XYWH boxes, or None if the instances
            have no `pred_boxes` field
        """
        return extract_boxes_xywh_from_instances(instances, select)
|
|
|
|
|
class ScoredBoundingBoxExtractor:
    """
    Extracts bounding boxes together with their scores from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: instances to read `scores` and `pred_boxes` from
            select: optional selector applied with `[]` indexing to both the
                scores and the boxes; it is only applied when both are present
        Return:
            a tuple (boxes_xywh, scores); either element is None when the
            corresponding field is missing
        """
        scores = extract_scores_from_instances(instances)
        boxes_xywh = extract_boxes_xywh_from_instances(instances)
        have_both = scores is not None and boxes_xywh is not None
        if have_both and select is not None:
            scores, boxes_xywh = scores[select], boxes_xywh[select]
        return (boxes_xywh, scores)
|
|
|
|
|
class DensePoseResultExtractor:
    """
    Extracts DensePose chart result with confidences from instances
    """

    def __call__(
        self, instances: Instances, select=None
    ) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]:
        """
        Convert the DensePose predictions of the instances into chart results

        Args:
            instances: instances providing `pred_densepose` and `pred_boxes`
            select: optional selector applied with `[]` indexing to the
                DensePose outputs and to the boxes; None keeps everything
        Return:
            a tuple (results, boxes_xywh) where `results[i]` is the converted
            chart result for the i-th kept instance and `boxes_xywh[i]` is the
            box of that same instance; (None, None) if either field is missing
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None

        dpout = instances.pred_densepose
        boxes_xyxy = instances.pred_boxes
        # Apply the same selection to the XYWH boxes as to the predictions so
        # that the returned boxes stay aligned index-by-index with the results
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        if select is not None:
            dpout = dpout[select]
            boxes_xyxy = boxes_xyxy[select]

        converter = ToChartResultConverterWithConfidences()
        # boxes_xyxy[[i]] keeps the box container type (a one-box selection)
        results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))]
        return results, boxes_xywh
|
|
|
|
|
class DensePoseOutputsExtractor:
    """
    Extracts DensePose result from instances
    """

    def __call__(
        self,
        instances: Instances,
        select=None,
    ) -> Tuple[
        Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]]
    ]:
        """
        Extract raw DensePose outputs, XYWH boxes and predicted classes

        Args:
            instances: instances providing `pred_densepose` and `pred_boxes`,
                and optionally `pred_classes`
            select: optional selector applied with `[]` indexing to the
                DensePose outputs, the boxes and the classes alike; None keeps
                everything
        Return:
            a tuple (dpout, boxes_xywh, classes); `classes` is a Python list
            (from `pred_classes.tolist()`) or None when the instances have no
            `pred_classes` field. (None, None, None) is returned if either
            `pred_densepose` or `pred_boxes` is missing
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None, None

        dpout = instances.pred_densepose
        # Select the XYWH boxes together with the outputs so all returned
        # values describe the same set of instances
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        classes = instances.pred_classes if instances.has("pred_classes") else None

        if select is not None:
            dpout = dpout[select]
            if classes is not None:
                # Index while this is still `pred_classes` itself: a plain
                # Python list cannot be indexed with a mask / index tensor
                classes = classes[select]

        if classes is not None:
            classes = classes.tolist()

        return dpout, boxes_xywh, classes
|
|
|
|
|
class CompoundExtractor:
    """
    Extracts data for CompoundVisualizer by running a list of extractors
    """

    def __init__(self, extractors):
        """
        Args:
            extractors: callables invoked as `extractor(instances, select)`
        """
        self.extractors = extractors

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: instances handed to every extractor
            select: optional selector forwarded unchanged to every extractor
        Return:
            a list with one entry per extractor, in the order of
            `self.extractors`
        """
        return [extract(instances, select) for extract in self.extractors]
|
|
|
|
|
class NmsFilteredExtractor:
    """
    Extracts data in the format accepted by NmsFilteredVisualizer
    """

    def __init__(self, extractor, iou_threshold):
        """
        Args:
            extractor: wrapped extractor, invoked as
                `extractor(instances, select=mask)`
            iou_threshold: IoU threshold forwarded to `batched_nms`
        """
        self.extractor = extractor
        self.iou_threshold = iou_threshold

    def __call__(self, instances: Instances, select=None):
        """
        Run NMS over the predicted boxes and extract data for the survivors

        Args:
            instances: instances providing `scores` and `pred_boxes`
            select: optional mask that is combined (`&`) with the NMS mask
        Return:
            the wrapped extractor's output for the kept instances, or None if
            the instances lack either `scores` or `pred_boxes`
        """
        scores = extract_scores_from_instances(instances)
        # NMS needs both boxes and scores; without either there is nothing
        # meaningful to filter
        if scores is None or not instances.has("pred_boxes"):
            return None

        # batched_nms computes IoU on corner-format (x1, y1, x2, y2) boxes, so
        # use the predicted boxes as stored rather than their XYWH conversion
        boxes_xyxy = instances.pred_boxes.tensor
        # A single category id (0) for every box: suppression is not split
        # per class
        category_idxs = torch.zeros(len(scores), dtype=torch.int32, device=boxes_xyxy.device)
        keep_idx = batched_nms(
            boxes_xyxy,
            scores,
            category_idxs,
            iou_threshold=self.iou_threshold,
        )

        # Turn the kept indices into a boolean mask over all instances
        select_local = torch.zeros(len(boxes_xyxy), dtype=torch.bool, device=boxes_xyxy.device)
        select_local[keep_idx] = True
        select = select_local if select is None else (select & select_local)
        return self.extractor(instances, select=select)
|
|
|
|
|
class ScoreThresholdedExtractor:
    """
    Extracts data in the format accepted by ScoreThresholdedVisualizer
    """

    def __init__(self, extractor, min_score):
        """
        Args:
            extractor: wrapped extractor, invoked as
                `extractor(instances, select=mask)`
            min_score: scores must be strictly greater than this to be kept
        """
        self.extractor = extractor
        self.min_score = min_score

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: instances providing `scores`
            select: optional mask that is combined (`&`) with the score mask
        Return:
            the wrapped extractor's output for instances whose score exceeds
            `min_score`, or None if the instances have no `scores` field
        """
        scores = extract_scores_from_instances(instances)
        if scores is None:
            return None
        above_threshold = scores > self.min_score
        if select is not None:
            above_threshold = select & above_threshold
        return self.extractor(instances, select=above_threshold)
|
|
|