| | |
| | import logging |
| | from typing import List, Optional, Sequence, Tuple |
| | import torch |
| |
|
| | from detectron2.layers.nms import batched_nms |
| | from detectron2.structures.instances import Instances |
| |
|
| | from densepose.converters import ToChartResultConverterWithConfidences |
| | from densepose.structures import ( |
| | DensePoseChartResultWithConfidences, |
| | DensePoseEmbeddingPredictorOutput, |
| | ) |
| | from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer |
| | from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer |
| | from densepose.vis.densepose_results import DensePoseResultsVisualizer |
| |
|
| | from .base import CompoundVisualizer |
| |
|
# Alias for a sequence of float scores.
Scores = Sequence[float]
# Alias for a list of DensePose chart results with confidences; used as the
# type of the first element returned by DensePoseResultExtractor.__call__.
DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences]
| |
|
| |
|
def extract_scores_from_instances(instances: Instances, select=None):
    """
    Fetch the "scores" field of the given instances

    Args:
        instances (Instances): container queried for a "scores" field
        select: optional index applied as `scores[select]`;
            `None` keeps all the scores
    Return:
        the (optionally indexed) scores, or `None` if the instances
        have no "scores" field
    """
    if not instances.has("scores"):
        return None
    all_scores = instances.scores
    if select is None:
        return all_scores
    return all_scores[select]
| |
|
| |
|
def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """
    Fetch predicted boxes of the given instances, converted to XYWH

    Args:
        instances (Instances): container queried for a "pred_boxes" field
        select: optional index applied as `boxes[select]` after the
            conversion; `None` keeps all the boxes
    Return:
        a new tensor (the stored boxes are left untouched) in which
        columns 2 and 3 hold column 2 minus column 0 and column 3 minus
        column 1 respectively, or `None` if the instances have no
        "pred_boxes" field
    """
    if not instances.has("pred_boxes"):
        return None
    # work on a copy so that the boxes stored in `instances` are not modified
    xywh = instances.pred_boxes.tensor.clone()
    # (x1, y1, x2, y2) -> (x1, y1, x2 - x1, y2 - y1)
    xywh[:, 2:4] -= xywh[:, 0:2]
    if select is None:
        return xywh
    return xywh[select]
| |
|
| |
|
def create_extractor(visualizer: object):
    """
    Build the extractor that matches the given visualizer

    Args:
        visualizer (object): visualizer instance for which a data
            extractor is needed
    Return:
        a callable extractor; for a CompoundVisualizer this is a
        CompoundExtractor over the extractors of its nested visualizers.
        If the visualizer type is not recognized, an error is logged and
        `None` is returned.
    """
    if isinstance(visualizer, CompoundVisualizer):
        nested = [create_extractor(child) for child in visualizer.visualizers]
        return CompoundExtractor(nested)
    if isinstance(visualizer, DensePoseResultsVisualizer):
        return DensePoseResultExtractor()
    if isinstance(visualizer, ScoredBoundingBoxVisualizer):
        return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances])
    if isinstance(visualizer, BoundingBoxVisualizer):
        return extract_boxes_xywh_from_instances
    if isinstance(visualizer, DensePoseOutputsVertexVisualizer):
        return DensePoseOutputsExtractor()
    # no known extractor for this visualizer type
    logging.getLogger(__name__).error(f"Could not create extractor for {visualizer}")
    return None
| |
|
| |
|
class BoundingBoxExtractor:
    """
    Extracts bounding boxes from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances (Instances): container with predicted boxes
            select: optional index forwarded to
                `extract_boxes_xywh_from_instances`; `None` keeps all boxes.
                Accepting it lets this extractor be called the same way as
                the other extractors in this module, i.e. as
                `extractor(instances, select)`.
        Return:
            boxes in XYWH format, or `None` if the instances have no
            "pred_boxes" field
        """
        return extract_boxes_xywh_from_instances(instances, select)
| |
|
| |
|
class ScoredBoundingBoxExtractor:
    """
    Extracts bounding boxes together with their scores from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances (Instances): container with predicted boxes and scores
            select: optional index applied to both boxes and scores;
                only used when both of them are available
        Return:
            tuple `(boxes_xywh, scores)`; an entry is `None` if the
            corresponding field is missing, in which case the other
            entry is returned without applying `select`
        """
        scores = extract_scores_from_instances(instances)
        boxes_xywh = extract_boxes_xywh_from_instances(instances)
        both_present = scores is not None and boxes_xywh is not None
        if both_present and select is not None:
            scores, boxes_xywh = scores[select], boxes_xywh[select]
        return (boxes_xywh, scores)
| |
|
| |
|
class DensePoseResultExtractor:
    """
    Extracts DensePose chart result with confidences from instances
    """

    def __call__(
        self, instances: Instances, select=None
    ) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]:
        """
        Args:
            instances (Instances): container with "pred_densepose" and
                "pred_boxes" fields
            select: optional index applied to the DensePose outputs and to
                the boxes; `None` keeps all the instances
        Return:
            tuple `(results, boxes_xywh)` where `results[i]` is the converted
            chart result for the i-th selected instance and `boxes_xywh[i]`
            is its box in XYWH format, or `(None, None)` if either of the
            required fields is missing
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None
        dpout = instances.pred_densepose
        boxes_xyxy = instances.pred_boxes
        # apply the selection to the XYWH boxes as well, so that they stay
        # index-aligned with the results computed from the selected outputs
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        if select is not None:
            dpout = dpout[select]
            boxes_xyxy = boxes_xyxy[select]
        converter = ToChartResultConverterWithConfidences()
        # `[[i]]` keeps a container with a single box instead of a bare box
        results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))]
        return results, boxes_xywh
| |
|
| |
|
class DensePoseOutputsExtractor:
    """
    Extracts DensePose result from instances
    """

    def __call__(
        self,
        instances: Instances,
        select=None,
    ) -> Tuple[
        Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]]
    ]:
        """
        Args:
            instances (Instances): container with "pred_densepose" and
                "pred_boxes" fields and, optionally, "pred_classes"
            select: optional index applied to the DensePose outputs, the
                boxes and the classes; `None` keeps all the instances
        Return:
            tuple `(dpout, boxes_xywh, classes)` restricted to the selected
            instances; `classes` is `None` if there is no "pred_classes"
            field. `(None, None, None)` is returned if either of the
            required fields is missing.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None, None

        dpout = instances.pred_densepose
        # apply the selection to the boxes too, so that all the returned
        # values refer to the same instances
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        pred_classes = instances.pred_classes if instances.has("pred_classes") else None

        if select is not None:
            dpout = dpout[select]
            if pred_classes is not None:
                # index before `tolist()`: a plain Python list can not be
                # indexed with a mask or an index tensor
                pred_classes = pred_classes[select]

        classes = None if pred_classes is None else pred_classes.tolist()
        return dpout, boxes_xywh, classes
| |
|
| |
|
class CompoundExtractor:
    """
    Extracts data for CompoundVisualizer
    """

    def __init__(self, extractors):
        """
        Args:
            extractors: callables invoked as `extractor(instances, select)`
        """
        self.extractors = extractors

    def __call__(self, instances: Instances, select=None):
        """
        Run every nested extractor on the same instances and selection

        Return:
            list with one entry per nested extractor, in the order of
            `self.extractors`
        """
        return [extract(instances, select) for extract in self.extractors]
| |
|
| |
|
class NmsFilteredExtractor:
    """
    Extracts data in the format accepted by NmsFilteredVisualizer
    """

    def __init__(self, extractor, iou_threshold):
        """
        Args:
            extractor: wrapped extractor, called as
                `extractor(instances, select=...)`
            iou_threshold: IoU threshold passed to `batched_nms`
        """
        self.extractor = extractor
        self.iou_threshold = iou_threshold

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances (Instances): container with "pred_boxes" and "scores"
            select: optional boolean selection that is combined (logical
                and) with the NMS result
        Return:
            output of the wrapped extractor for the instances kept by NMS,
            or `None` if boxes or scores are missing
        """
        scores = extract_scores_from_instances(instances)
        # NMS needs both boxes and scores
        if scores is None or not instances.has("pred_boxes"):
            return None
        # NMS is run on the stored boxes directly rather than on their
        # XYWH conversion: `batched_nms` expects (x1, y1, x2, y2) boxes
        boxes_xyxy = instances.pred_boxes.tensor
        # a single category for all the boxes, allocated on the boxes' device
        category_idxs = torch.zeros(len(scores), dtype=torch.int32, device=boxes_xyxy.device)
        keep_idx = batched_nms(
            boxes_xyxy,
            scores,
            category_idxs,
            iou_threshold=self.iou_threshold,
        )
        # turn the kept indices into a boolean mask over all the instances
        select_local = torch.zeros(len(boxes_xyxy), dtype=torch.bool, device=boxes_xyxy.device)
        select_local[keep_idx] = True
        select = select_local if select is None else (select & select_local)
        return self.extractor(instances, select=select)
| |
|
| |
|
class ScoreThresholdedExtractor:
    """
    Extracts data in the format accepted by ScoreThresholdedVisualizer
    """

    def __init__(self, extractor, min_score):
        """
        Args:
            extractor: wrapped extractor, called as
                `extractor(instances, select=...)`
            min_score: instances are kept only if their score is strictly
                greater than this value
        """
        self.extractor = extractor
        self.min_score = min_score

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances (Instances): container with a "scores" field
            select: optional selection that is combined (via `&`) with
                the score mask
        Return:
            output of the wrapped extractor for the selected instances,
            or `None` if the instances have no "scores" field
        """
        scores = extract_scores_from_instances(instances)
        if scores is None:
            return None
        above_threshold = scores > self.min_score
        if select is not None:
            above_threshold = select & above_threshold
        return self.extractor(instances, select=above_threshold)
| |
|