| |
|
|
| from dataclasses import dataclass |
| from typing import Any, Callable, Dict, List, Optional |
|
|
| from detectron2.structures import Instances |
|
|
# A single model output record: string keys mapped to arbitrary values.
ModelOutput = Dict[str, Any]
# A single sampled-data record; it has the same dictionary shape.
SampledData = Dict[str, Any]
|
|
|
|
| @dataclass |
| class _Sampler: |
| """ |
| Sampler registry entry that contains: |
| - src (str): source field to sample from (deleted after sampling) |
| - dst (Optional[str]): destination field to sample to, if not None |
| - func (Optional[Callable: Any -> Any]): function that performs sampling, |
| if None, reference copy is performed |
| """ |
|
|
| src: str |
| dst: Optional[str] |
| func: Optional[Callable[[Any], Any]] |
|
|
|
|
class PredictionToGroundTruthSampler:
    """
    Sampler implementation that converts predictions to GT using registered
    samplers for different fields of `Instances`.

    Samplers are kept in a dict keyed by the `(prediction_attr, gt_attr)`
    pair, so registering the same pair again replaces the earlier entry.
    """

    def __init__(self, dataset_name: str = ""):
        """
        Create a sampler with the default field mappings registered

        Args:
            dataset_name (str): value stored under the "dataset" key of every
                processed model output entry
        """
        self.dataset_name = dataset_name
        # maps (prediction_attr, gt_attr) -> _Sampler
        self._samplers = {}
        self.register_sampler("pred_boxes", "gt_boxes", None)
        self.register_sampler("pred_classes", "gt_classes", None)
        # no destination field: "scores" is only removed from the instances
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling

        The entries are modified in place: the `Instances` object stored under
        the "instances" key of each entry gets the destination fields set and
        the source fields removed, and the "dataset" key of the entry is set
        to `self.dataset_name`.

        Args:
            model_output (List[Dict[str, Any]]): model outputs, each entry
                must contain an "instances" key
        Returns:
            List[Dict[str, Any]]: the same list object, holding sampled data
        """
        for model_output_i in model_output:
            instances: Instances = model_output_i["instances"]
            # transform data in each field; samplers without a destination
            # or whose source field is absent have nothing to produce
            for sampler in self._samplers.values():
                if not instances.has(sampler.src) or sampler.dst is None:
                    continue
                if sampler.func is None:
                    # reference copy of the source value
                    instances.set(sampler.dst, instances.get(sampler.src))
                else:
                    # the sampling function receives the whole `Instances`
                    instances.set(sampler.dst, sampler.func(instances))
            # delete source fields in a separate pass, so that every sampler
            # above could still read them; a field sampled onto itself stays
            for sampler in self._samplers.values():
                if sampler.src != sampler.dst and instances.has(sampler.src):
                    instances.remove(sampler.src)
            model_output_i["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        """
        Register sampler for a field

        An existing sampler registered for the same
        `(prediction_attr, gt_attr)` pair is overwritten.

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
            func (Optional[Callable: Any -> Any]): sampler function, if None,
                the value of `prediction_attr` is copied by reference
        """
        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
            src=prediction_attr, dst=gt_attr, func=func
        )

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ) -> None:
        """
        Remove sampler for a field

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
        Raises:
            AssertionError: if no sampler is registered for the
                `(prediction_attr, gt_attr)` pair
        """
        key = (prediction_attr, gt_attr)
        assert key in self._samplers, f"No sampler registered for {key}"
        del self._samplers[key]
|
|