"""Atlas evaluation metrics.""" import re import numpy as np from typing import List, Dict, Tuple, Optional # scipy only affects match_lanes() / calculate_lane_detection_metrics(), # which are NOT used in the main eval path (eval_atlas.py). # Main eval uses: greedy matching for detection, OpenLane-V2 LaneEval.bench() for lanes. try: from scipy.optimize import linear_sum_assignment SCIPY_AVAILABLE = True except ImportError: SCIPY_AVAILABLE = False NUSCENES_CLASS_MAP = { # Base class names 'car': 'car', 'truck': 'truck', 'construction_vehicle': 'construction_vehicle', 'bus': 'bus', 'trailer': 'trailer', 'barrier': 'barrier', 'motorcycle': 'motorcycle', 'bicycle': 'bicycle', 'pedestrian': 'pedestrian', 'traffic_cone': 'traffic_cone', # Full nuScenes category names - vehicles 'vehicle.car': 'car', 'vehicle.truck': 'truck', 'vehicle.construction': 'construction_vehicle', 'vehicle.bus.bendy': 'bus', 'vehicle.bus.rigid': 'bus', 'vehicle.trailer': 'trailer', 'vehicle.motorcycle': 'motorcycle', 'vehicle.bicycle': 'bicycle', # Full nuScenes category names - pedestrians (all subtypes) 'human.pedestrian.adult': 'pedestrian', 'human.pedestrian.child': 'pedestrian', 'human.pedestrian.construction_worker': 'pedestrian', 'human.pedestrian.police_officer': 'pedestrian', 'human.pedestrian.wheelchair': 'pedestrian', 'human.pedestrian.stroller': 'pedestrian', 'human.pedestrian.personal_mobility': 'pedestrian', # Full nuScenes category names - movable objects 'movable_object.barrier': 'barrier', 'movable_object.trafficcone': 'traffic_cone', 'movable_object.traffic_cone': 'traffic_cone', } def normalize_category(category: str) -> str: """Normalize nuScenes category names to base class names.""" cat_lower = category.lower().strip() if cat_lower in NUSCENES_CLASS_MAP: return NUSCENES_CLASS_MAP[cat_lower] for key, val in NUSCENES_CLASS_MAP.items(): if key in cat_lower or cat_lower in key: return val return cat_lower def normalize_ground_truths(ground_truths: List[Dict]) -> List[Dict]: """Normalize category names and ensure world_coords in ground truth list. Handles multiple GT formats: - {"translation": [x, y, z], "category_name": ...} (from regenerate_atlas_with_gt.py) - {"box": [x, y, z, w, l, h, yaw], "category_name": ...} (from gen_atlas_full_data.py) - {"world_coords": [x, y, z], "category": ...} (already normalized) """ normalized = [] for gt in ground_truths: gt_copy = dict(gt) # Normalize category if 'category' in gt_copy: gt_copy['category_raw'] = gt_copy['category'] gt_copy['category'] = normalize_category(gt_copy['category']) elif 'category_name' in gt_copy: gt_copy['category_raw'] = gt_copy['category_name'] gt_copy['category'] = normalize_category(gt_copy['category_name']) # Ensure world_coords exists if 'world_coords' not in gt_copy: if 'translation' in gt_copy: gt_copy['world_coords'] = list(gt_copy['translation'][:3]) elif 'box' in gt_copy: gt_copy['world_coords'] = list(gt_copy['box'][:3]) normalized.append(gt_copy) return normalized def bin_to_meters(bin_val: int, bin_range: Tuple[float, float] = (-51.2, 51.2), num_bins: int = 1000) -> float: min_val, max_val = bin_range normalized = bin_val / (num_bins - 1) meters = min_val + normalized * (max_val - min_val) return meters def meters_to_bin(meters: float, bin_range: Tuple[float, float] = (-51.2, 51.2), num_bins: int = 1000) -> int: min_val, max_val = bin_range meters = np.clip(meters, min_val, max_val) normalized = (meters - min_val) / (max_val - min_val) bin_val = round(normalized * (num_bins - 1)) bin_val = int(np.clip(bin_val, 0, num_bins - 1)) return bin_val def _parse_lane_points(points_str: str) -> List[Dict]: """Parse a sequence of [x, y, z] bins into lane point dicts.""" point_pattern = r'\[(\d+),\s*(\d+),\s*(\d+)\]' points = re.findall(point_pattern, points_str) lane_points = [] for x_bin, y_bin, z_bin in points: x_bin, y_bin, z_bin = int(x_bin), int(y_bin), int(z_bin) x_meters = bin_to_meters(x_bin, bin_range=(-51.2, 51.2)) y_meters = bin_to_meters(y_bin, bin_range=(-51.2, 51.2)) z_meters = bin_to_meters(z_bin, bin_range=(-5.0, 3.0)) lane_points.append({ 'bin_coords': [x_bin, y_bin, z_bin], 'world_coords': [x_meters, y_meters, z_meters] }) return lane_points def parse_atlas_output(text: str) -> List[Dict]: """ Parse Atlas model output. Supports two canonical formats (checked in order): 1. Paper lane: Lane: [x, y, z], [x, y, z]; [x, y, z], [x, y, z]; ... 2. Detection: category: [x, y, z], [x, y, z]; category: [x, y, z]. """ results = [] # --- 1. Paper lane format: "Lane: [pts], [pts]; [pts], [pts]; ..." --- paper_lane_match = re.search(r'Lane:\s*(.*)', text, re.DOTALL) if paper_lane_match: content = paper_lane_match.group(1).rstrip('. \t\n') lane_strs = content.split(';') for lane_idx, lane_str in enumerate(lane_strs): lane_str = lane_str.strip() if not lane_str: continue lane_points = _parse_lane_points(lane_str) if lane_points: results.append({ 'type': 'lane', 'lane_id': str(lane_idx), 'points': lane_points, }) if results: return results # --- 2. Detection grouped format --- # Canonical: "car: [pt1], [pt2]; truck: [pt3]." def _make_det(category: str, x_b: int, y_b: int, z_b: int) -> Dict: return { 'type': 'detection', 'category': normalize_category(category), 'category_raw': category, 'bin_coords': [x_b, y_b, z_b], 'world_coords': [ bin_to_meters(x_b, bin_range=(-51.2, 51.2)), bin_to_meters(y_b, bin_range=(-51.2, 51.2)), bin_to_meters(z_b, bin_range=(-5.0, 3.0)), ], } point_re = re.compile(r'\[(\d+),\s*(\d+),\s*(\d+)\]') group_re = re.compile(r'(\S+)\s*:\s*((?:\[\d+,\s*\d+,\s*\d+\][\s,]*)+)') stripped = text.strip().rstrip('.') if stripped.startswith('lane_centerline('): return [] if ';' in stripped: for seg in stripped.split(';'): seg = seg.strip() if not seg: continue gm = group_re.match(seg) if gm: for x_b, y_b, z_b in point_re.findall(gm.group(2)): results.append(_make_det(gm.group(1), int(x_b), int(y_b), int(z_b))) if not results: gm = group_re.match(stripped) if gm: pts_in_group = point_re.findall(gm.group(2)) pts_in_text = point_re.findall(stripped) if len(pts_in_group) == len(pts_in_text): for x_b, y_b, z_b in pts_in_group: results.append(_make_det(gm.group(1), int(x_b), int(y_b), int(z_b))) return results def calculate_distance( pred_coord: List[float], gt_coord: List[float], use_2d: bool = False, ) -> float: """ 计算预测坐标和真实坐标之间的距离 Args: pred_coord: 预测坐标 [x, y, z] gt_coord: 真实坐标 [x, y, z] use_2d: 如果为 True,只使用 XY 平面距离(BEV 距离),忽略 Z 轴 这是 BEV 3D 检测中更常用的匹配方式 """ pred = np.array(pred_coord) gt = np.array(gt_coord) if use_2d: # 只使用 XY 平面距离(BEV 距离) distance = np.linalg.norm(pred[:2] - gt[:2]) else: # 3D 欧式距离 distance = np.linalg.norm(pred - gt) return float(distance) def match_detections( predictions: List[Dict], ground_truths: List[Dict], threshold: float = 2.0, use_2d_distance: bool = True, use_hungarian: bool = False, ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]: """ 匹配预测和真实检测框 Args: predictions: 预测检测结果列表 ground_truths: 真实检测结果列表 threshold: 匹配距离阈值(米) use_2d_distance: 如果为 True,使用 2D BEV 距离(XY 平面),这是 BEV 检测的标准做法 use_hungarian: 如果为 True,使用匈牙利算法进行最优匹配(需要 scipy); 默认 False,使用贪婪匹配(nuScenes 标准) """ if len(predictions) == 0: return [], [], list(range(len(ground_truths))) if len(ground_truths) == 0: return [], list(range(len(predictions))), [] # 按类别分组进行匹配 all_categories = set(p['category'] for p in predictions) | set(g['category'] for g in ground_truths) matched_preds = set() matched_gts = set() matches = [] for category in all_categories: cat_preds = [(i, p) for i, p in enumerate(predictions) if p['category'] == category] cat_gts = [(i, g) for i, g in enumerate(ground_truths) if g['category'] == category] if not cat_preds or not cat_gts: continue # 构建距离矩阵 n_preds = len(cat_preds) n_gts = len(cat_gts) cost_matrix = np.full((n_preds, n_gts), float('inf')) for pi, (pred_idx, pred) in enumerate(cat_preds): for gi, (gt_idx, gt) in enumerate(cat_gts): dist = calculate_distance(pred['world_coords'], gt['world_coords'], use_2d=use_2d_distance) if dist < threshold: cost_matrix[pi, gi] = dist # 使用匈牙利算法或贪婪匹配 if use_hungarian and SCIPY_AVAILABLE and n_preds > 0 and n_gts > 0: # 匈牙利算法最优匹配 row_ind, col_ind = linear_sum_assignment(cost_matrix) for pi, gi in zip(row_ind, col_ind): if cost_matrix[pi, gi] < threshold: pred_idx = cat_preds[pi][0] gt_idx = cat_gts[gi][0] matches.append((pred_idx, gt_idx)) matched_preds.add(pred_idx) matched_gts.add(gt_idx) else: # 贪婪匹配(按距离排序) distances = [] for pi, (pred_idx, pred) in enumerate(cat_preds): for gi, (gt_idx, gt) in enumerate(cat_gts): dist = cost_matrix[pi, gi] if dist < threshold: distances.append((dist, pred_idx, gt_idx)) distances.sort(key=lambda x: x[0]) for dist, pred_idx, gt_idx in distances: if pred_idx not in matched_preds and gt_idx not in matched_gts: matches.append((pred_idx, gt_idx)) matched_preds.add(pred_idx) matched_gts.add(gt_idx) false_positives = [i for i in range(len(predictions)) if i not in matched_preds] false_negatives = [i for i in range(len(ground_truths)) if i not in matched_gts] return matches, false_positives, false_negatives def calculate_detection_f1( predictions: List[Dict], ground_truths: List[Dict], threshold: float = 2.0, ) -> Dict[str, float]: matches, false_positives, false_negatives = match_detections( predictions, ground_truths, threshold ) tp = len(matches) fp = len(false_positives) fn = len(false_negatives) precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 metrics = { 'precision': precision, 'recall': recall, 'f1': f1, 'tp': tp, 'fp': fp, 'fn': fn, 'num_predictions': len(predictions), 'num_ground_truths': len(ground_truths), } return metrics def denormalize_ref_points_01( ref_points_01: np.ndarray, pc_range: Tuple[float, float, float, float, float, float] = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0), ) -> np.ndarray: """Convert normalized ref points in [0,1] back to meters. Args: ref_points_01: array-like [..., 3] in [0, 1] pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) Returns: np.ndarray [..., 3] in meters """ ref = np.asarray(ref_points_01, dtype=np.float64) pc_min = np.array(pc_range[:3], dtype=np.float64) pc_max = np.array(pc_range[3:], dtype=np.float64) denom = np.clip(pc_max - pc_min, 1e-6, None) ref01 = np.clip(ref, 0.0, 1.0) return pc_min + ref01 * denom def snap_detections_to_ref_points( predictions: List[Dict], ref_points_01: np.ndarray, pc_range: Tuple[float, float, float, float, float, float] = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0), keep_z: bool = True, ) -> List[Dict]: """Snap predicted detection centers to nearest reference points (BEV XY). This is a post-processing step that constrains predictions to lie on the StreamPETR proposal set (ref points). It can significantly reduce small metric thresholds (0.5m/1m) sensitivity to free-form numeric drift. Args: predictions: list of detection dicts with 'world_coords' in meters ref_points_01: [Q,3] or [B,Q,3] normalized ref points in [0,1] pc_range: point cloud range for denormalization keep_z: if True, keep each prediction's original z; else use ref z Returns: New list of predictions (deep-copied dicts) with snapped 'world_coords' """ if not predictions: return [] ref = np.asarray(ref_points_01, dtype=np.float64) if ref.ndim == 3: ref = ref[0] if ref.ndim != 2 or ref.shape[1] != 3 or ref.shape[0] == 0: return list(predictions) ref_m = denormalize_ref_points_01(ref, pc_range=pc_range) ref_xy = ref_m[:, :2] pred_xy = np.array([p.get("world_coords", [0.0, 0.0, 0.0])[:2] for p in predictions], dtype=np.float64) if pred_xy.ndim != 2 or pred_xy.shape[0] == 0: return list(predictions) d = ((pred_xy[:, None, :] - ref_xy[None, :, :]) ** 2).sum(-1) nn = d.argmin(axis=1) snapped = [] for i, p in enumerate(predictions): p2 = dict(p) wc = list(p2.get("world_coords", [0.0, 0.0, 0.0])) j = int(nn[i]) new_xyz = ref_m[j].tolist() if keep_z and len(wc) >= 3: new_xyz[2] = float(wc[2]) p2["world_coords"] = [float(new_xyz[0]), float(new_xyz[1]), float(new_xyz[2])] snapped.append(p2) return snapped def calculate_per_class_metrics( predictions: List[Dict], ground_truths: List[Dict], threshold: float = 2.0, ) -> Dict[str, Dict[str, float]]: pred_categories = set(pred['category'] for pred in predictions) gt_categories = set(gt['category'] for gt in ground_truths) all_categories = pred_categories | gt_categories per_class_metrics = {} for category in all_categories: cat_preds = [pred for pred in predictions if pred['category'] == category] cat_gts = [gt for gt in ground_truths if gt['category'] == category] metrics = calculate_detection_f1(cat_preds, cat_gts, threshold) per_class_metrics[category] = metrics return per_class_metrics def parse_planning_output(text: str, require_full_vap: bool = False) -> Optional[Dict]: result = {} vel_pattern = r'ego car speed value:\s*\[(\d+),\s*(\d+)\]\.?' acc_pattern = r'ego car acceleration value:\s*\[(\d+),\s*(\d+)\]\.?' wp_pattern = ( r'(?:based on the ego car speed and acceleration you predicted,\s*)?' r'(?:requeset|request)\s+the ego car planning waypoint(?:s)? in 3-seconds:\s*' r'((?:\[\d+,\s*\d+\](?:,\s*)?)+)\.?' ) vel_m = re.search(vel_pattern, text, flags=re.IGNORECASE) if vel_m: result['velocity_bins'] = [int(vel_m.group(1)), int(vel_m.group(2))] acc_m = re.search(acc_pattern, text, flags=re.IGNORECASE) if acc_m: result['acceleration_bins'] = [int(acc_m.group(1)), int(acc_m.group(2))] wp_m = re.search(wp_pattern, text, flags=re.IGNORECASE) if wp_m: point_pattern = r'\[(\d+),\s*(\d+)\]' points = re.findall(point_pattern, wp_m.group(1)) wps = [] for xb, yb in points: x = bin_to_meters(int(xb), bin_range=(-51.2, 51.2)) y = bin_to_meters(int(yb), bin_range=(-51.2, 51.2)) wps.append([x, y]) result['waypoints'] = wps if 'waypoints' not in result or len(result['waypoints']) == 0: return None # Planning answers use a Figure 5-style chained speed + acceleration + # waypoint protocol. The main evaluation path can require all three fields. if require_full_vap and ( 'velocity_bins' not in result or 'acceleration_bins' not in result ): return None return result def _pad_waypoints(waypoints: List[List[float]], target_n: int = 6) -> List[List[float]]: """Pad waypoint list to target_n by repeating last waypoint. This prevents short model outputs from gaming the L2 / collision metrics. """ if len(waypoints) >= target_n: return waypoints[:target_n] if len(waypoints) == 0: return [[0.0, 0.0]] * target_n last = list(waypoints[-1]) return list(waypoints) + [list(last)] * (target_n - len(waypoints)) def calculate_planning_l2( pred_waypoints: List[List[float]], gt_waypoints: List[List[float]], timestamps: List[float] = None, ) -> Dict[str, float]: n_gt = len(gt_waypoints) if timestamps is None: timestamps = [0.5 * (i + 1) for i in range(n_gt)] # Pad predictions to match GT length to prevent short-output bias pred_padded = _pad_waypoints(pred_waypoints, target_n=n_gt) errors = {} all_l2 = [] for i in range(n_gt): pred = np.array(pred_padded[i][:2]) gt = np.array(gt_waypoints[i][:2]) l2 = float(np.linalg.norm(pred - gt)) all_l2.append(l2) t = timestamps[i] if i < len(timestamps) else 0.5 * (i + 1) if abs(t - 1.0) < 0.01: errors['L2_1s'] = l2 if abs(t - 2.0) < 0.01: errors['L2_2s'] = l2 if abs(t - 3.0) < 0.01: errors['L2_3s'] = l2 key_steps = [v for k, v in errors.items() if k in ('L2_1s', 'L2_2s', 'L2_3s')] errors['L2_avg'] = float(np.mean(key_steps)) if key_steps else (float(np.mean(all_l2)) if all_l2 else 0.0) return errors def _box_corners_2d(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray: """Build oriented box corners for yaw-from-x headings. In planning eval JSON, yaw is measured from +X (right) axis: - yaw = 0 -> vehicle length points to +X - yaw = +pi/2 -> vehicle length points to +Y This matches the qualitative visualization helper. """ c = np.cos(yaw) s = np.sin(yaw) center = np.array([cx, cy], dtype=np.float64) # Heading axis follows the vehicle length, with width perpendicular to it. d_len = np.array([c, s], dtype=np.float64) * (l / 2.0) d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0) corners = np.stack([ center + d_len + d_wid, center + d_len - d_wid, center - d_len - d_wid, center - d_len + d_wid, ], axis=0) return corners def _boxes_overlap(box1_corners: np.ndarray, box2_corners: np.ndarray) -> bool: for box in [box1_corners, box2_corners]: for i in range(4): j = (i + 1) % 4 edge = box[j] - box[i] normal = np.array([-edge[1], edge[0]]) proj1 = box1_corners @ normal proj2 = box2_corners @ normal if proj1.max() < proj2.min() or proj2.max() < proj1.min(): return False return True def _check_collision_at_waypoints( waypoints: List[List[float]], gt_boxes: List[Dict], ego_w: float, ego_l: float, gt_boxes_per_timestep: Optional[List[List[Dict]]] = None, ) -> List[bool]: """Check collision between ego at each waypoint and GT boxes. When *gt_boxes_per_timestep* is provided (ST-P3 aligned), each waypoint is checked against the boxes at the corresponding future timestep. Otherwise falls back to using the same static *gt_boxes* for all waypoints. """ collisions = [] for i, wp in enumerate(waypoints): if i + 1 < len(waypoints): dx = waypoints[i + 1][0] - wp[0] dy = waypoints[i + 1][1] - wp[1] ego_yaw = float(np.arctan2(dy, dx)) if (abs(dx) + abs(dy)) > 1e-4 else 0.0 elif i > 0: dx = wp[0] - waypoints[i - 1][0] dy = wp[1] - waypoints[i - 1][1] ego_yaw = float(np.arctan2(dy, dx)) if (abs(dx) + abs(dy)) > 1e-4 else 0.0 else: ego_yaw = 0.0 ego_corners = _box_corners_2d(wp[0], wp[1], ego_w, ego_l, ego_yaw) boxes_at_t = gt_boxes if gt_boxes_per_timestep is not None and i < len(gt_boxes_per_timestep): boxes_at_t = gt_boxes_per_timestep[i] collided = False for box in boxes_at_t: if 'world_coords' not in box: continue bx, by = box['world_coords'][0], box['world_coords'][1] bw = box.get('w', 2.0) bl = box.get('l', 4.0) byaw = box.get('yaw', 0.0) obj_corners = _box_corners_2d(bx, by, bw, bl, byaw) if _boxes_overlap(ego_corners, obj_corners): collided = True break collisions.append(collided) return collisions def calculate_collision_rate( pred_waypoints: List[List[float]], gt_boxes: List[Dict], ego_w: float = 1.85, ego_l: float = 4.084, timestamps: List[float] = None, num_waypoints: int = 6, gt_waypoints: Optional[List[List[float]]] = None, gt_boxes_per_timestep: Optional[List[List[Dict]]] = None, ) -> Dict[str, float]: pred_padded = _pad_waypoints(pred_waypoints, target_n=num_waypoints) if timestamps is None: timestamps = [0.5 * (i + 1) for i in range(num_waypoints)] # ST-P3 aligned: exclude timesteps where the GT trajectory itself collides gt_collides = [False] * num_waypoints if gt_waypoints is not None: gt_padded = _pad_waypoints(gt_waypoints, target_n=num_waypoints) gt_collides = _check_collision_at_waypoints( gt_padded, gt_boxes, ego_w, ego_l, gt_boxes_per_timestep=gt_boxes_per_timestep, ) pred_collides = _check_collision_at_waypoints( pred_padded, gt_boxes, ego_w, ego_l, gt_boxes_per_timestep=gt_boxes_per_timestep, ) collisions_at_t = {} for i in range(num_waypoints): t = timestamps[i] if i < len(timestamps) else 0.5 * (i + 1) if gt_collides[i]: collisions_at_t[t] = False else: collisions_at_t[t] = pred_collides[i] results = {} for target_t, key in [(1.0, 'collision_1s'), (2.0, 'collision_2s'), (3.0, 'collision_3s')]: matched = [v for t, v in collisions_at_t.items() if abs(t - target_t) < 0.01] if matched: results[key] = float(matched[0]) key_cols = [v for k, v in results.items() if k in ('collision_1s', 'collision_2s', 'collision_3s')] results['collision_avg'] = float(np.mean(key_cols)) if key_cols else 0.0 return results def calculate_planning_metrics( predictions: List[Dict], ground_truths: List[Dict], ) -> Dict[str, float]: all_l2 = {'L2_1s': [], 'L2_2s': [], 'L2_3s': [], 'L2_avg': []} all_col = {'collision_1s': [], 'collision_2s': [], 'collision_3s': [], 'collision_avg': []} for pred, gt in zip(predictions, ground_truths): pred_wps = pred.get('waypoints', []) gt_wps = gt.get('waypoints', []) if pred_wps and gt_wps: l2 = calculate_planning_l2(pred_wps, gt_wps) for k, v in l2.items(): if k in all_l2: all_l2[k].append(v) gt_boxes = gt.get('gt_boxes', []) gt_boxes_per_ts = gt.get('gt_boxes_per_timestep', None) if pred_wps and (gt_boxes or gt_boxes_per_ts): col = calculate_collision_rate( pred_wps, gt_boxes, gt_waypoints=gt_wps, gt_boxes_per_timestep=gt_boxes_per_ts, ) for k, v in col.items(): if k in all_col: all_col[k].append(v) results = {} for k, vals in all_l2.items(): results[k] = float(np.mean(vals)) if vals else 0.0 for k, vals in all_col.items(): results[k] = float(np.mean(vals)) if vals else 0.0 return results VEL_ACC_RANGE = (-50.0, 50.0) def vel_acc_bin_to_meters(bin_val: int, num_bins: int = 1000) -> float: return bin_to_meters(bin_val, bin_range=VEL_ACC_RANGE, num_bins=num_bins) def chamfer_distance_polyline( pred_pts: np.ndarray, gt_pts: np.ndarray, ) -> float: if len(pred_pts) == 0 or len(gt_pts) == 0: return float('inf') pred_pts = np.asarray(pred_pts, dtype=np.float64) gt_pts = np.asarray(gt_pts, dtype=np.float64) d_p2g = 0.0 for p in pred_pts: d_p2g += np.linalg.norm(gt_pts - p[None, :], axis=1).min() d_p2g /= len(pred_pts) d_g2p = 0.0 for g in gt_pts: d_g2p += np.linalg.norm(pred_pts - g[None, :], axis=1).min() d_g2p /= len(gt_pts) return 0.5 * (d_p2g + d_g2p) def _lane_points_array(lane) -> np.ndarray: pts = lane.get('points', []) if not pts: return np.zeros((0, 3)) rows = [] for pt in pts: if isinstance(pt, dict): rows.append(pt.get('world_coords', [0, 0, 0])[:3]) else: rows.append(list(pt)[:3]) return np.array(rows, dtype=np.float64) def match_lanes( pred_lanes: List[Dict], gt_lanes: List[Dict], threshold: float = 1.5, ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]: if not pred_lanes: return [], [], list(range(len(gt_lanes))) if not gt_lanes: return [], list(range(len(pred_lanes))), [] n_p = len(pred_lanes) n_g = len(gt_lanes) cost = np.full((n_p, n_g), float('inf')) for i, pl in enumerate(pred_lanes): p_pts = _lane_points_array(pl) if len(p_pts) == 0: continue for j, gl in enumerate(gt_lanes): g_pts = _lane_points_array(gl) if len(g_pts) == 0: continue cd = chamfer_distance_polyline(p_pts, g_pts) if cd < threshold: cost[i, j] = cd matches = [] matched_p = set() matched_g = set() if SCIPY_AVAILABLE and n_p > 0 and n_g > 0 and np.isfinite(cost).any(): try: row_ind, col_ind = linear_sum_assignment(cost) except ValueError: row_ind, col_ind = [], [] for pi, gi in zip(row_ind, col_ind): if cost[pi, gi] < threshold: matches.append((pi, gi)) matched_p.add(pi) matched_g.add(gi) else: pairs = [] for i in range(n_p): for j in range(n_g): if cost[i, j] < threshold: pairs.append((cost[i, j], i, j)) pairs.sort() for _, i, j in pairs: if i not in matched_p and j not in matched_g: matches.append((i, j)) matched_p.add(i) matched_g.add(j) fp = [i for i in range(n_p) if i not in matched_p] fn = [j for j in range(n_g) if j not in matched_g] return matches, fp, fn def calculate_lane_detection_metrics( pred_lanes: List[Dict], gt_lanes: List[Dict], threshold: float = 1.5, ) -> Dict[str, float]: matches, fp_list, fn_list = match_lanes(pred_lanes, gt_lanes, threshold) tp = len(matches) fp = len(fp_list) fn = len(fn_list) precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 return { 'lane_precision': precision, 'lane_recall': recall, 'lane_f1': f1, 'lane_tp': tp, 'lane_fp': fp, 'lane_fn': fn, } def calculate_multi_threshold_detection_f1( predictions: List[Dict], ground_truths: List[Dict], thresholds: Tuple[float, ...] = (0.5, 1.0, 2.0, 4.0), ) -> Dict[str, float]: results = {} f1_vals = [] for t in thresholds: m = calculate_detection_f1(predictions, ground_truths, threshold=t) results[f'P@{t}m'] = m['precision'] results[f'R@{t}m'] = m['recall'] results[f'F1@{t}m'] = m['f1'] f1_vals.append(m['f1']) results['F1_avg'] = float(np.mean(f1_vals)) if f1_vals else 0.0 return results def evaluate_all( task_predictions: Dict[str, List], task_ground_truths: Dict[str, List], ) -> Dict[str, Dict[str, float]]: results = {} if 'detection' in task_predictions and 'detection' in task_ground_truths: results['detection'] = calculate_multi_threshold_detection_f1( task_predictions['detection'], task_ground_truths['detection'], ) if 'lane' in task_predictions and 'lane' in task_ground_truths: agg = {'lane_precision': [], 'lane_recall': [], 'lane_f1': []} for pred_set, gt_set in zip(task_predictions['lane'], task_ground_truths['lane']): p_list = pred_set if isinstance(pred_set, list) else [pred_set] g_list = gt_set if isinstance(gt_set, list) else [gt_set] m = calculate_lane_detection_metrics(p_list, g_list) for k in agg: agg[k].append(m[k]) results['lane'] = {k: float(np.mean(v)) for k, v in agg.items() if v} if 'planning' in task_predictions and 'planning' in task_ground_truths: results['planning'] = calculate_planning_metrics( task_predictions['planning'], task_ground_truths['planning'], ) return results