File size: 1,521 Bytes
663494c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from typing import Dict, List, Tuple
from nuscenes.eval.detection.data_classes import DetectionConfig
# from nuscenes.eval.detection.constants import ATTRIBUTE_NAMES, TP_METRICS
# from mmdet3d_plugin.eval.detection.constants import DETECTION_NAMES_CARLA
class CustomizedDetectionConfig(DetectionConfig):
    """nuScenes ``DetectionConfig`` variant that accepts arbitrary class names.

    Subclasses ``nuscenes.eval.detection.data_classes.DetectionConfig`` and
    stores the same attributes, but deliberately does NOT call the parent
    ``__init__``. The upstream constructor (presumably — confirm against the
    installed nuscenes-devkit) asserts that ``class_range`` covers exactly the
    fixed nuScenes detection class list. Skipping it lets this config be used
    with other label sets. The class names are taken from the keys of
    ``class_range`` instead.

    Args:
        class_range: Mapping from class name to that class's evaluation range.
            Its keys become ``class_names``. (Assumed to carry the same
            meaning as in nuScenes, i.e. a distance limit — TODO confirm.)
        dist_fcn: Identifier of the distance function used for matching, as
            interpreted by the nuScenes evaluation code.
        dist_ths: Matching distance thresholds.
        dist_th_tp: Threshold used for the true-positive metrics; must be one
            of ``dist_ths``.
        min_recall: Stored unchanged for the evaluation code.
        min_precision: Stored unchanged for the evaluation code.
        max_boxes_per_sample: Stored unchanged for the evaluation code.
        mean_ap_weight: Stored unchanged for the evaluation code.

    Raises:
        AssertionError: If ``dist_th_tp`` is not contained in ``dist_ths``.
    """

    def __init__(
        self,
        class_range: Dict[str, int],
        dist_fcn: str,
        dist_ths: List[float],
        dist_th_tp: float,
        min_recall: float,
        min_precision: float,
        max_boxes_per_sample: int,
        mean_ap_weight: int,
    ) -> None:
        # Explicit raise instead of a bare ``assert`` so the check still runs
        # under ``python -O``. AssertionError is kept so callers see the same
        # exception type as before.
        if dist_th_tp not in dist_ths:
            raise AssertionError("dist_th_tp must be in set of dist_ths.")

        # super().__init__() is intentionally not called; see class docstring.
        self.class_range = class_range
        self.dist_fcn = dist_fcn
        self.dist_ths = dist_ths
        self.dist_th_tp = dist_th_tp
        self.min_recall = min_recall
        self.min_precision = min_precision
        self.max_boxes_per_sample = max_boxes_per_sample
        self.mean_ap_weight = mean_ap_weight
        # A live ``dict_keys`` view: it reflects later changes to class_range.
        self.class_names = self.class_range.keys()
|