| import glob |
| import gzip |
| import json |
| import math |
| import os |
| import pickle |
| import typing |
| from dataclasses import dataclass |
| from enum import IntEnum, auto |
|
|
| import cv2 |
| import jaxtyping as jt |
| import numpy as np |
| import timm |
| import torch |
| import torch.nn.functional as F |
| from beartype import beartype |
| from omegaconf import OmegaConf |
| from torch import nn |
|
|
|
|
| class CameraPointCloudIndex(IntEnum): |
| """Index to access point cloud array of camera.""" |
|
|
| X = 0 |
| Y = 1 |
| Z = 2 |
| UNREAL_SEMANTICS_ID = 3 |
| UNREAL_INSTANCE_ID = 4 |
|
|
|
|
| class TargetDataset(IntEnum): |
| """Dataset we target to collect data/train/evaluate on.""" |
|
|
| UNKNOWN = 0 |
| CARLA_LEADERBOARD2_3CAMERAS = auto() |
| CARLA_LEADERBOARD2_6CAMERAS = auto() |
| NAVSIM_4CAMERAS = auto() |
| WAYMO_E2E_2025_3CAMERAS = auto() |
|
|
|
|
| class SourceDataset(IntEnum): |
| """Dataset intenum a sample originates from.""" |
|
|
| UNKNOWN = 0 |
| CARLA = auto() |
| NAVSIM = auto() |
| WAYMO_E2E_2025 = auto() |
|
|
|
|
| SOURCE_DATASET_NAME_MAP = { |
| SourceDataset.CARLA: "carla", |
| SourceDataset.NAVSIM: "navsim", |
| SourceDataset.WAYMO_E2E_2025: "waymo_e2e_2025", |
| } |
|
|
|
|
| class RadarLabels(IntEnum): |
| """Index to access radar label array.""" |
|
|
| X = 0 |
| Y = 1 |
| V = 2 |
| VALID = 3 |
|
|
|
|
| class RadarDataIndex(IntEnum): |
| """Index to access radar data array.""" |
|
|
| X = 0 |
| Y = 1 |
| Z = 2 |
| V = 3 |
| SENSOR_ID = 4 |
|
|
|
|
| class TransfuserBoundingBoxIndex(IntEnum): |
| """Index to access bounding array of TransFuser.""" |
|
|
| X = 0 |
| Y = 1 |
| W = 2 |
| H = 3 |
| YAW = 4 |
| VELOCITY = 5 |
| BRAKE = 6 |
| CLASS = 7 |
| SCORE = 8 |
| VISIBLE_PIXELS = 8 |
| NUM_POINTS = 9 |
| ID = 10 |
| NUM_RADAR_POINTS = 11 |
|
|
|
|
| class TransfuserBoundingBoxClass(IntEnum): |
| """Bounding box classes used in TransFuser.""" |
|
|
| VEHICLE = 0 |
| WALKER = 1 |
| TRAFFIC_LIGHT = 2 |
| STOP_SIGN = 3 |
| SPECIAL = 4 |
| OBSTACLE = 5 |
| PARKING = 6 |
| BIKER = 7 |
|
|
|
|
| class CarlaNavigationCommand(IntEnum): |
| """Source: https://carla.org/Doxygen/html/d0/db7/namespacecarla_1_1traffic__manager.html#a5734807dba08623eeca046a963ade360""" |
|
|
| UNKNOWN = 0 |
| LEFT = 1 |
| RIGHT = 2 |
| STRAIGHT = 3 |
| LANEFOLLOW = 4 |
| CHANGELANELEFT = 5 |
| CHANGELANERIGHT = 6 |
|
|
|
|
| class ChaffeurNetBEVSemanticClass(IntEnum): |
| """Indicies to access BEV semantic map produced by ChaffeurNet.""" |
|
|
| UNLABELED = 0 |
| ROAD = 1 |
| SIDEWALK = 2 |
| LANE_MARKERS = 3 |
| LANE_MARKERS_BROKEN = 4 |
| STOP_SIGNS = 5 |
| TRAFFIC_GREEN = 6 |
| TRAFFIC_YELLOW = 7 |
| TRAFFIC_RED = 8 |
|
|
|
|
| class TransfuserBEVSemanticClass(IntEnum): |
| """Indicies to access BEV semantic map produced by TransFuser.""" |
|
|
| UNLABELED = 0 |
| ROAD = 1 |
| LANE_MARKERS = 2 |
| STOP_SIGNS = 3 |
| VEHICLE = 4 |
| WALKER = 5 |
| OBSTACLE = 6 |
| PARKING_VEHICLE = 7 |
| SPECIAL_VEHICLE = 8 |
| BIKER = 9 |
| TRAFFIC_GREEN = 10 |
| TRAFFIC_RED_NORMAL = 11 |
| TRAFFIC_RED_NOT_NORMAL = 12 |
|
|
|
|
| class TransfuserBEVOccupancyClass(IntEnum): |
| """Indicies to access BEV occupancy map produced by TransFuser.""" |
|
|
| UNLABELED = 0 |
| VEHICLE = 1 |
| WALKER = 2 |
| OBSTACLE = 3 |
| PARKING_VEHICLE = 4 |
| SPECIAL_VEHICLE = 5 |
| BIKER = 6 |
| TRAFFIC_GREEN = 7 |
| TRAFFIC_RED_NORMAL = 8 |
| TRAFFIC_RED_NOT_NORMAL = 9 |
|
|
|
|
| class WeatherVisibility(IntEnum): |
| """Visibility conditions in CARLA simulator, classified by TransFuser.""" |
|
|
| CLEAR = 0 |
| OK = 1 |
| LIMITED = 2 |
| VERY_LIMITED = 3 |
|
|
|
|
| class CarlaSemanticSegmentationClass(IntEnum): |
| """https://carla.readthedocs.io/en/latest/ref_sensors/#semantic-segmentation-camera""" |
|
|
| Unlabeled = 0 |
| Roads = 1 |
| SideWalks = 2 |
| Building = 3 |
| Wall = 4 |
| Fence = 5 |
| Pole = 6 |
| TrafficLight = 7 |
| TrafficSign = 8 |
| Vegetation = 9 |
| Terrain = 10 |
| Sky = 11 |
| Pedestrian = 12 |
| Rider = 13 |
| Car = 14 |
| Truck = 15 |
| Bus = 16 |
| Train = 17 |
| Motorcycle = 18 |
| Bicycle = 19 |
| Static = 20 |
| Dynamic = 21 |
| Other = 22 |
| Water = 23 |
| RoadLine = 24 |
| Ground = 25 |
| Bridge = 26 |
| RailTrack = 27 |
| GuardRail = 28 |
| |
| ConeAndTrafficWarning = 29 |
| SpecialVehicles = 30 |
| StopSign = 31 |
|
|
|
|
| class TransfuserSemanticSegmentationClass(IntEnum): |
| """Semantic segmentation classes used in TransFuser.""" |
|
|
| UNLABELED = 0 |
| VEHICLE = 1 |
| ROAD = 2 |
| TRAFFIC_LIGHT = 3 |
| PEDESTRIAN = 4 |
| ROAD_LINE = 5 |
| OBSTACLE = 6 |
| SPECIAL_VEHICLE = 7 |
| STOP_SIGN = 8 |
| BIKER = 9 |
|
|
|
|
| def rgb(r, g, b): |
| """Dummy function to help with visualizing RGB colors by using a VSCode extension.""" |
| return (r, g, b) |
|
|
|
|
| |
| LIDAR_COLOR = rgb(90, 107, 249) |
| EGO_BB_COLOR = rgb(151, 15, 48) |
| TP_DEFAULT_COLOR = rgb(255, 10, 10) |
| RADAR_COLOR = rgb(24, 237, 3) |
| RADAR_DETECTION_COLOR = rgb(255, 24, 0) |
|
|
| |
| GROUNDTRUTH_BB_WP_COLOR = rgb(15, 60, 255) |
| GROUNDTRUTH_FUTURE_WAYPOINT_COLOR = rgb(0, 0, 0) |
| GROUND_TRUTH_PAST_WAYPOINT_COLOR = rgb(0, 0, 0) |
| PREDICTION_WAYPOINT_COLOR = rgb(255, 0, 0) |
| PREDICTION_ROUTE_COLOR = rgb(0, 0, 255) |
| PREDICTION_WAYPOINT_RADIUS = 10 |
| PREDICTION_ROUTE_RADIUS = 6 |
|
|
| |
| CARLA_TRANSFUSER_BEV_SEMANTIC_COLOR_CONVERTER = { |
| TransfuserBEVSemanticClass.UNLABELED: rgb(0, 0, 0), |
| TransfuserBEVSemanticClass.ROAD: rgb(250, 250, 250), |
| TransfuserBEVSemanticClass.LANE_MARKERS: rgb(255, 255, 0), |
| TransfuserBEVSemanticClass.STOP_SIGNS: rgb(160, 160, 0), |
| TransfuserBEVSemanticClass.VEHICLE: rgb(15, 60, 255), |
| TransfuserBEVSemanticClass.WALKER: rgb(0, 255, 0), |
| TransfuserBEVSemanticClass.OBSTACLE: rgb(255, 0, 0), |
| TransfuserBEVSemanticClass.PARKING_VEHICLE: rgb(116, 150, 65), |
| TransfuserBEVSemanticClass.SPECIAL_VEHICLE: rgb(255, 0, 255), |
| TransfuserBEVSemanticClass.BIKER: rgb(255, 0, 0), |
| TransfuserBEVSemanticClass.TRAFFIC_GREEN: rgb(0, 255, 0), |
| TransfuserBEVSemanticClass.TRAFFIC_RED_NORMAL: rgb(255, 0, 0), |
| TransfuserBEVSemanticClass.TRAFFIC_RED_NOT_NORMAL: rgb(0, 0, 255), |
| } |
|
|
| |
| TRANSFUSER_SEMANTIC_COLORS = { |
| TransfuserSemanticSegmentationClass.UNLABELED: rgb(0, 0, 0), |
| TransfuserSemanticSegmentationClass.VEHICLE: rgb(31, 119, 180), |
| TransfuserSemanticSegmentationClass.ROAD: rgb(128, 64, 128), |
| TransfuserSemanticSegmentationClass.TRAFFIC_LIGHT: rgb(250, 170, 30), |
| TransfuserSemanticSegmentationClass.PEDESTRIAN: rgb(0, 255, 60), |
| TransfuserSemanticSegmentationClass.ROAD_LINE: rgb(157, 234, 50), |
| TransfuserSemanticSegmentationClass.OBSTACLE: rgb(255, 0, 0), |
| TransfuserSemanticSegmentationClass.SPECIAL_VEHICLE: rgb(255, 255, 0), |
| TransfuserSemanticSegmentationClass.STOP_SIGN: rgb(125, 0, 0), |
| TransfuserSemanticSegmentationClass.BIKER: rgb(220, 20, 60), |
| } |
|
|
| |
| TRANSFUSER_BOUNDING_BOX_COLORS = { |
| TransfuserBoundingBoxClass.VEHICLE: rgb(0, 0, 255), |
| TransfuserBoundingBoxClass.WALKER: rgb(0, 255, 0), |
| TransfuserBoundingBoxClass.TRAFFIC_LIGHT: rgb(255, 0, 0), |
| TransfuserBoundingBoxClass.STOP_SIGN: rgb(250, 160, 160), |
| TransfuserBoundingBoxClass.SPECIAL: rgb(0, 0, 255), |
| TransfuserBoundingBoxClass.OBSTACLE: rgb(0, 255, 13), |
| TransfuserBoundingBoxClass.PARKING: rgb(116, 150, 65), |
| TransfuserBoundingBoxClass.BIKER: rgb(255, 0, 0), |
| } |
|
|
| |
| SEMANTIC_SEGMENTATION_CONVERTER = { |
| CarlaSemanticSegmentationClass.Unlabeled: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Roads: TransfuserSemanticSegmentationClass.ROAD, |
| CarlaSemanticSegmentationClass.SideWalks: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Building: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Wall: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Fence: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Pole: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.TrafficLight: TransfuserSemanticSegmentationClass.TRAFFIC_LIGHT, |
| CarlaSemanticSegmentationClass.TrafficSign: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Vegetation: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Terrain: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Sky: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Pedestrian: TransfuserSemanticSegmentationClass.PEDESTRIAN, |
| CarlaSemanticSegmentationClass.Rider: TransfuserSemanticSegmentationClass.BIKER, |
| CarlaSemanticSegmentationClass.Car: TransfuserSemanticSegmentationClass.VEHICLE, |
| CarlaSemanticSegmentationClass.Truck: TransfuserSemanticSegmentationClass.VEHICLE, |
| CarlaSemanticSegmentationClass.Bus: TransfuserSemanticSegmentationClass.VEHICLE, |
| CarlaSemanticSegmentationClass.Train: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Motorcycle: TransfuserSemanticSegmentationClass.VEHICLE, |
| CarlaSemanticSegmentationClass.Bicycle: TransfuserSemanticSegmentationClass.BIKER, |
| CarlaSemanticSegmentationClass.Static: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Dynamic: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Other: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Water: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.RoadLine: TransfuserSemanticSegmentationClass.ROAD_LINE, |
| CarlaSemanticSegmentationClass.Ground: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.Bridge: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.RailTrack: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.GuardRail: TransfuserSemanticSegmentationClass.UNLABELED, |
| CarlaSemanticSegmentationClass.ConeAndTrafficWarning: TransfuserSemanticSegmentationClass.OBSTACLE, |
| CarlaSemanticSegmentationClass.SpecialVehicles: TransfuserSemanticSegmentationClass.SPECIAL_VEHICLE, |
| CarlaSemanticSegmentationClass.StopSign: TransfuserSemanticSegmentationClass.STOP_SIGN, |
| } |
|
|
| |
| CHAFFEURNET_TO_TRANSFUSER_BEV_SEMANTIC_CONVERTER = { |
| ChaffeurNetBEVSemanticClass.UNLABELED: TransfuserBEVSemanticClass.UNLABELED, |
| ChaffeurNetBEVSemanticClass.ROAD: TransfuserBEVSemanticClass.ROAD, |
| ChaffeurNetBEVSemanticClass.SIDEWALK: TransfuserBEVSemanticClass.UNLABELED, |
| ChaffeurNetBEVSemanticClass.LANE_MARKERS: TransfuserBEVSemanticClass.LANE_MARKERS, |
| ChaffeurNetBEVSemanticClass.LANE_MARKERS_BROKEN: TransfuserBEVSemanticClass.LANE_MARKERS, |
| ChaffeurNetBEVSemanticClass.STOP_SIGNS: TransfuserBEVSemanticClass.STOP_SIGNS, |
| ChaffeurNetBEVSemanticClass.TRAFFIC_GREEN: TransfuserBEVSemanticClass.UNLABELED, |
| ChaffeurNetBEVSemanticClass.TRAFFIC_YELLOW: TransfuserBEVSemanticClass.UNLABELED, |
| ChaffeurNetBEVSemanticClass.TRAFFIC_RED: TransfuserBEVSemanticClass.UNLABELED, |
| } |
|
|
| SCENARIO_TYPES = [ |
| "Accident", |
| "AccidentTwoWays", |
| "BlockedIntersection", |
| "ConstructionObstacle", |
| "ConstructionObstacleTwoWays", |
| "ControlLoss", |
| "CrossJunctionDefectTrafficLight", |
| "CrossingBicycleFlow", |
| "DynamicObjectCrossing", |
| "EnterActorFlow", |
| "EnterActorFlowV2", |
| "HardBreakRoute", |
| "HazardAtSideLane", |
| "HazardAtSideLaneTwoWays", |
| "HighwayCutIn", |
| "HighwayExit", |
| "InterurbanActorFlow", |
| "InterurbanAdvancedActorFlow", |
| "InvadingTurn", |
| "MergerIntoSlowTraffic", |
| "MergerIntoSlowTrafficV2", |
| "NonSignalizedJunctionLeftTurn", |
| "NonSignalizedJunctionLeftTurnEnterFlow", |
| "NonSignalizedJunctionRightTurn", |
| "noScenarios", |
| "OppositeVehicleRunningRedLight", |
| "OppositeVehicleTakingPriority", |
| "ParkedObstacle", |
| "ParkedObstacleTwoWays", |
| "ParkingCrossingPedestrian", |
| "ParkingCutIn", |
| "ParkingExit", |
| "PedestrianCrossing", |
| "PriorityAtJunction", |
| "RedLightWithoutLeadVehicle", |
| "SequentialLaneChange", |
| "SignalizedJunctionLeftTurn", |
| "SignalizedJunctionLeftTurnEnterFlow", |
| "SignalizedJunctionRightTurn", |
| "StaticCutIn", |
| "T_Junction", |
| "VanillaNonSignalizedTurn", |
| "VanillaNonSignalizedTurnEncounterStopsign", |
| "VanillaSignalizedTurnEncounterGreenLight", |
| "VanillaSignalizedTurnEncounterRedLight", |
| "VehicleOpensDoorTwoWays", |
| "VehicleTurningRoute", |
| "VehicleTurningRoutePedestrian", |
| "YieldToEmergencyVehicle", |
| "NA", |
| ] |
|
|
| EMERGENCY_MESHES = { |
| "vehicle.dodge.charger_police_2020", |
| "vehicle.dodge.charger_police", |
| "vehicle.ford.ambulance", |
| "vehicle.carlamotors.firetruck", |
| } |
|
|
| CONSTRUCTION_MESHES = {"static.prop.constructioncone", "static.prop.trafficwarning"} |
|
|
| BIKER_MESHES = { |
| "vehicle.diamondback.century", |
| "vehicle.gazelle.omafiets", |
| "vehicle.bh.crossbike", |
| "vehicle.harley-davidson.low_rider", |
| "vehicle.kawasaki.ninja", |
| "vehicle.vespa.zx125", |
| "vehicle.yamaha.yzf", |
| } |
|
|
| URBAN_MAX_SPEED_LIMIT = 15 |
| SUBURBAN_MAX_SPEED_LIMIT = 25 |
| HIGHWAY_MAX_SPEED_LIMIT = 35 |
|
|
| LOOKUP_TABLE = { |
| "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Lincoln/SM_LincolnParked.SM_LincolnParked": [ |
| 2.44619083404541, |
| 1.115301489830017, |
| 0.7606233954429626, |
| ], |
| "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Charger/SM_ChargerParked.SM_ChargerParked": [ |
| 2.5039126873016357, |
| 1.0485419034957886, |
| 0.7673624753952026, |
| ], |
| "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/VolkswagenT2/SM_VolkswagenT2_2021_Parked.SM_VolkswagenT2_2021_Parked": [ |
| 2.2210919857025146, |
| 0.9388753771781921, |
| 0.9936029314994812, |
| ], |
| "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/FordCrown/SM_FordCrown_parked.SM_FordCrown_parked": [ |
| 2.6828393936157227, |
| 0.9732309579849243, |
| 0.7874829173088074, |
| ], |
| "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/NissanPatrol2021/SM_NissanPatrol2021_parked.SM_NissanPatrol2021_parked": [ |
| 2.782914400100708, |
| 1.217571496963501, |
| 1.022573471069336, |
| ], |
| "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/MercedesCCC/SM_MercedesCCC_Parked.SM_MercedesCCC_Parked": [ |
| 2.3368194103240967, |
| 1.0011461973190308, |
| 0.7259762287139893, |
| ], |
| "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/TeslaM3/SM_TeslaM3_parked.SM_TeslaM3_parked": [ |
| 2.3958897590637207, |
| 1.081725001335144, |
| 0.7438300848007202, |
| ], |
| "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Mini2021/SM_Mini2021_parked.SM_Mini2021_parked": [ |
| 2.2763495445251465, |
| 1.0926425457000732, |
| 0.8835831880569458, |
| ], |
| "/Game/Carla/Static/Dynamic/Garden/SM_PlasticTable.SM_PlasticTable": [ |
| 1.241101622581482, |
| 1.241101622581482, |
| 1.239898920059204, |
| ], |
| "/Game/Carla/Static/Dynamic/Garden/SM_PlasticChair.SM_PlasticChair": [ |
| 0.36523768305778503, |
| 0.37522444128990173, |
| 0.6356779336929321, |
| ], |
| "/Game/Carla/Static/Dynamic/Construction/SM_ConstructionCone.SM_ConstructionCone": [ |
| 0.1720348298549652, |
| 0.1720348298549652, |
| 0.2928849756717682, |
| ], |
| } |
|
|
| CONSTRUCTION_CONE_BB_SIZE = [0.1720348298549652, 0.1720348298549652] |
|
|
| TRAFFIC_WARNING_BB_SIZE = [1.186714768409729, 1.4352929592132568] |
|
|
| OLD_TOWNS = { |
| "Town01", |
| "Town02", |
| "Town03", |
| "Town04", |
| "Town05", |
| "Town06", |
| "Town07", |
| "Town10HD", |
| } |
|
|
| |
| ALL_TOWNS = [ |
| "Town01", |
| "Town02", |
| "Town03", |
| "Town04", |
| "Town05", |
| "Town06", |
| "Town07", |
| "Town10HD", |
| "Town11", |
| "Town12", |
| "Town13", |
| "Town15", |
| ] |
|
|
| |
| TOWN_NAME_TO_INDEX = { |
| "Town01": "01", |
| "Town02": "02", |
| "Town03": "03", |
| "Town04": "04", |
| "Town05": "05", |
| "Town06": "06", |
| "Town07": "07", |
| "Town10HD": "10HD", |
| "Town11": "11", |
| "Town12": "12", |
| "Town13": "13", |
| "Town15": "15", |
| } |
|
|
| |
| NUPLAN_CAMERA_CALIBRATION = { |
| "CAM_L0": { |
| "pos": [0.2567803445203347, -0.14912709068475835, 1.9611907818710856], |
| "rot": [-1.7908743283844297, -0.18030832979657968, -56.0098600797478], |
| "fov": 63.71, |
| "width": 1920 // 4, |
| "height": 1120 // 4, |
| "cropped_height": 1080 // 4, |
| }, |
| "CAM_F0": { |
| "pos": [0.3435966588946209, 0.00981503465349912, 1.9520959734648988], |
| "rot": [0.004047525205852703, -2.344563417746492, -1.0253844128360994], |
| "fov": 63.71, |
| "width": 1920 // 4, |
| "height": 1120 // 4, |
| "cropped_height": 1080 // 4, |
| }, |
| "CAM_R0": { |
| "pos": [0.362775037177945, 0.16380984892156114, 1.9540305698009064], |
| "rot": [0.8908895493227222, -0.5177262293066659, 54.26239302855094], |
| "fov": 63.71, |
| "width": 1920 // 4, |
| "height": 1120 // 4, |
| "cropped_height": 1080 // 4, |
| }, |
| "CAM_B0": { |
| "pos": [-1.8371898892894447, 0.023124645489646514, 1.9105230244574516], |
| "rot": [0.004047525205852703, 1.9819092884563787, 180], |
| "fov": 63.71, |
| "width": 1920 // 4, |
| "height": 1120 // 4, |
| "cropped_height": 1080 // 4, |
| }, |
| } |
|
|
|
|
| class NavSimBEVSemanticClass(IntEnum): |
| """Indicies to access BEV semantic map produced by NavSim. |
| |
| See: https://github.com/autonomousvision/navsim/blob/main/navsim/agents/transfuser/transfuser_config.py#L83 |
| """ |
|
|
| UNLABELED = 0 |
| ROAD = auto() |
| WALKWAYS = auto() |
| CENTERLINE = auto() |
| STATIC_OBJECTS = auto() |
| VEHICLES = auto() |
| PEDESTRIANS = auto() |
|
|
|
|
| class NavSimBoundingBoxIndex(IntEnum): |
| """Index to access NavSim bounding box attribute array. |
| |
| See: https://github.com/autonomousvision/navsim/blob/main/navsim/agents/transfuser/transfuser_features.py#L174 |
| """ |
|
|
| X = 0 |
| Y = 1 |
| HEADING = 2 |
| LENGTH = 3 |
| WIDTH = 4 |
|
|
|
|
| class NavSimBBClass(IntEnum): |
| """Bounding box classes used in NavSim.""" |
|
|
| GENERIC_CLASS = 0 |
|
|
|
|
| class NavSimStatusFeature(IntEnum): |
| """Status feature indices used in NavSim.""" |
|
|
| DRIVING_COMMAND_LEFT = 0 |
| DRIVING_COMMAND_RIGHT = 1 |
| DRIVING_COMMAND_STRAIGHT = 2 |
| DRIVING_COMMAND_UNKNOWN = 3 |
| EGO_VELOCITY_X = 4 |
| EGO_VELOCITY_Y = 5 |
| ACCELERATION_X = 6 |
| ACCELERATION_Y = 7 |
|
|
|
|
| SIM2REAL_SEMANTIC_SEGMENTATION_CONVERTER = { |
| TransfuserSemanticSegmentationClass.UNLABELED: TransfuserSemanticSegmentationClass.UNLABELED, |
| TransfuserSemanticSegmentationClass.VEHICLE: TransfuserSemanticSegmentationClass.VEHICLE, |
| TransfuserSemanticSegmentationClass.ROAD: TransfuserSemanticSegmentationClass.ROAD, |
| TransfuserSemanticSegmentationClass.TRAFFIC_LIGHT: TransfuserSemanticSegmentationClass.TRAFFIC_LIGHT, |
| TransfuserSemanticSegmentationClass.PEDESTRIAN: TransfuserSemanticSegmentationClass.PEDESTRIAN, |
| TransfuserSemanticSegmentationClass.ROAD_LINE: TransfuserSemanticSegmentationClass.ROAD_LINE, |
| TransfuserSemanticSegmentationClass.OBSTACLE: TransfuserSemanticSegmentationClass.OBSTACLE, |
| TransfuserSemanticSegmentationClass.SPECIAL_VEHICLE: TransfuserSemanticSegmentationClass.VEHICLE, |
| TransfuserSemanticSegmentationClass.STOP_SIGN: TransfuserSemanticSegmentationClass.STOP_SIGN, |
| TransfuserSemanticSegmentationClass.BIKER: TransfuserSemanticSegmentationClass.BIKER, |
| } |
|
|
|
|
| SIM2REAL_BEV_SEMANTIC_SEGMENTATION_CONVERTER = { |
| TransfuserBEVSemanticClass.UNLABELED: TransfuserBEVSemanticClass.UNLABELED, |
| TransfuserBEVSemanticClass.ROAD: TransfuserBEVSemanticClass.ROAD, |
| TransfuserBEVSemanticClass.LANE_MARKERS: TransfuserBEVSemanticClass.LANE_MARKERS, |
| TransfuserBEVSemanticClass.STOP_SIGNS: TransfuserBEVSemanticClass.UNLABELED, |
| TransfuserBEVSemanticClass.VEHICLE: TransfuserBEVSemanticClass.VEHICLE, |
| TransfuserBEVSemanticClass.WALKER: TransfuserBEVSemanticClass.WALKER, |
| TransfuserBEVSemanticClass.OBSTACLE: TransfuserBEVSemanticClass.OBSTACLE, |
| TransfuserBEVSemanticClass.PARKING_VEHICLE: TransfuserBEVSemanticClass.VEHICLE, |
| TransfuserBEVSemanticClass.SPECIAL_VEHICLE: TransfuserBEVSemanticClass.SPECIAL_VEHICLE, |
| TransfuserBEVSemanticClass.BIKER: TransfuserBEVSemanticClass.BIKER, |
| TransfuserBEVSemanticClass.TRAFFIC_GREEN: TransfuserBEVSemanticClass.TRAFFIC_GREEN, |
| TransfuserBEVSemanticClass.TRAFFIC_RED_NORMAL: TransfuserBEVSemanticClass.TRAFFIC_RED_NORMAL, |
| TransfuserBEVSemanticClass.TRAFFIC_RED_NOT_NORMAL: TransfuserBEVSemanticClass.TRAFFIC_RED_NORMAL, |
| } |
|
|
| SIM2REAL_BOUNDING_BOX_CLASS_CONVERTER = { |
| TransfuserBoundingBoxClass.VEHICLE: TransfuserBoundingBoxClass.VEHICLE, |
| TransfuserBoundingBoxClass.WALKER: TransfuserBoundingBoxClass.WALKER, |
| TransfuserBoundingBoxClass.TRAFFIC_LIGHT: TransfuserBoundingBoxClass.TRAFFIC_LIGHT, |
| TransfuserBoundingBoxClass.STOP_SIGN: TransfuserBoundingBoxClass.STOP_SIGN, |
| TransfuserBoundingBoxClass.SPECIAL: TransfuserBoundingBoxClass.VEHICLE, |
| TransfuserBoundingBoxClass.OBSTACLE: TransfuserBoundingBoxClass.OBSTACLE, |
| TransfuserBoundingBoxClass.PARKING: TransfuserBoundingBoxClass.VEHICLE, |
| TransfuserBoundingBoxClass.BIKER: TransfuserBoundingBoxClass.BIKER, |
| } |
|
|
| SIM2REAL_BEV_OCCUPANCY_CLASS_CONVERTER = { |
| TransfuserBEVOccupancyClass.UNLABELED: TransfuserBEVOccupancyClass.UNLABELED, |
| TransfuserBEVOccupancyClass.VEHICLE: TransfuserBEVOccupancyClass.VEHICLE, |
| TransfuserBEVOccupancyClass.WALKER: TransfuserBEVOccupancyClass.WALKER, |
| TransfuserBEVOccupancyClass.OBSTACLE: TransfuserBEVOccupancyClass.UNLABELED, |
| TransfuserBEVOccupancyClass.PARKING_VEHICLE: TransfuserBEVOccupancyClass.VEHICLE, |
| TransfuserBEVOccupancyClass.SPECIAL_VEHICLE: TransfuserBEVOccupancyClass.VEHICLE, |
| TransfuserBEVOccupancyClass.BIKER: TransfuserBEVOccupancyClass.BIKER, |
| TransfuserBEVOccupancyClass.TRAFFIC_GREEN: TransfuserBEVOccupancyClass.TRAFFIC_GREEN, |
| TransfuserBEVOccupancyClass.TRAFFIC_RED_NORMAL: TransfuserBEVOccupancyClass.TRAFFIC_RED_NORMAL, |
| TransfuserBEVOccupancyClass.TRAFFIC_RED_NOT_NORMAL: TransfuserBEVOccupancyClass.TRAFFIC_RED_NORMAL, |
| } |
|
|
| CARLA_REAR_AXLE = [-1.389, 0.0, 0.360] |
|
|
| |
| WAYMO_E2E_INTRINSIC = [ |
| 1112.17806333, |
| 1112.04114501, |
| 488.128479, |
| 719.15600586, |
| -0.073584, |
| -0.036582, |
| 0.0, |
| 0.0, |
| 0.0, |
| ] |
|
|
| |
| WAYMO_E2E_2025_CAMERA_SETTING = { |
| "FRONT_LEFT": { |
| "extrinsic": [ |
| [0.70549636, -0.70869903, -0.00026687, 1.44509995], |
| [0.70868651, 0.70547806, 0.00679447, 0.15270001], |
| [-0.00462506, -0.00498361, 0.99996268, 1.80649996], |
| [0.0, 0.0, 0.0, 1.0], |
| ], |
| "width": 972, |
| "height": 1440, |
| "cropped_height": 1080, |
| }, |
| "FRONT": { |
| "extrinsic": [ |
| [0.99998026, -0.00215741, 0.00378466, 1.51909995], |
| [0.00214115, 0.99997577, 0.0045841, 0.0258], |
| [-0.00379113, -0.00457491, 0.9999674, 1.80649996], |
| [0.0, 0.0, 0.0, 1.0], |
| ], |
| "width": 972, |
| "height": 1440, |
| "cropped_height": 1080, |
| }, |
| "FRONT_RIGHT": { |
| "extrinsic": [ |
| [0.70908528, 0.70508644, 0.00506919, 1.48150003], |
| [-0.7050955, 0.70909768, -0.00009787, -0.1163], |
| [-0.00366141, -0.00350169, 0.99996981, 1.80649996], |
| [0.0, 0.0, 0.0, 1.0], |
| ], |
| "width": 972, |
| "height": 1440, |
| "cropped_height": 1080, |
| }, |
| } |
|
|
|
|
| class WaymoE2EIntrinsicIndex(IntEnum): |
| |
| F_U = 0 |
| F_V = 1 |
| C_U = 2 |
| C_V = 3 |
| K_1 = 4 |
| K_2 = 5 |
| P_1 = 6 |
| P_2 = 7 |
| K_3 = 8 |
|
|
|
|
| class CarlaImageCroppingType(IntEnum): |
| TOP = 0 |
| BOTTOM = auto() |
| BOTH = auto() |
| NONE = auto() |
|
|
|
|
| WAYMO_DOWN_SAMPLE_FACTOR = 3 |
| WAYMO_E2E_REAL_DATA_JPEG_LEVEL = 50 |
|
|
|
|
| @beartype |
| def normalize_imagenet(x: jt.Float[torch.Tensor, "B 3 H W"]) -> jt.Float[torch.Tensor, "B 3 H W"]: |
| """Normalize input images according to ImageNet standards. |
| Args: |
| x: Input images batch. |
| |
| Returns: |
| Normalized images batch. |
| """ |
| x = x.clone() |
| x[:, 0] = ((x[:, 0] / 255.0) - 0.485) / 0.229 |
| x[:, 1] = ((x[:, 1] / 255.0) - 0.456) / 0.224 |
| x[:, 2] = ((x[:, 2] / 255.0) - 0.406) / 0.225 |
| return x |
|
|
|
|
| class BaseConfig: |
| @property |
| def target_dataset(self): |
| raise NotImplementedError("Subclasses must implement the target_dataset property.") |
|
|
| |
| |
| bicycle_frame_rate = 20 |
| |
| num_route_points_saved = 50 |
| |
| points_per_meter = 10 |
| |
| |
| pixels_per_meter_collection = 2.0 |
| |
| longitudinal_max_accelerations = 1.89 |
|
|
| |
| |
| time_step = 1.0 / 20.0 |
| |
| front_wheel_base = -0.090769015 |
| |
| rear_wheel_base = 1.4178275 |
| |
| steering_gain = 0.36848336 |
| |
| brake_acceleration = -4.952399 |
| |
| throttle_acceleration = 0.5633837 |
| |
| throttle_threshold_during_forecasting = 0.3 |
|
|
| |
| |
| data_save_freq = 5 |
| |
| augment = False |
| |
| default_safety_translation_augmentation_penalty = 0.25 |
| |
| urban_safety_translation_augmentation_penalty = 0.4 |
|
|
| |
| camera_translation_augmentation_min = 0.1 |
| |
| camera_translation_augmentation_max = 1.0 |
|
|
| |
| camera_rotation_augmentation_min = 5.0 |
| |
| camera_rotation_augmentation_max = 12.5 |
| |
| camera_rotation_epsilon = 0.5 |
| |
| extreme_augment = False |
|
|
| |
| data_collection_extreme_augment_prob = 1.0 |
| |
| camera_extreme_rotation_augmentation_min = 12.5 |
| |
| camera_extreme_rotation_augmentation_max = 40 |
| |
| camera_extreme_translation_augmentation_min = 0.1 |
| |
| camera_extreme_translation_augmentation_max = 1.0 |
| |
| camera_extreme_pitch_augmentation_min = -1.0 |
| |
| camera_extreme_pitch_augmentation_max = 0.5 |
| |
| camera_extreme_roll_augmentation_max = 1.0 |
|
|
| |
| |
| point_format = 0 |
| |
| point_precision_x = point_precision_y = point_precision_z = 0.1 |
| |
| max_height_lidar = 10.0 |
| |
| min_height_lidar = -4.0 |
|
|
| |
| |
| use_two_lidars = True |
|
|
| |
| @property |
| def lidar_pos_1(self): |
| return [0.0, 0.0, 2.5] |
|
|
| |
| @property |
| def lidar_rot_1(self): |
| return [0.0, 0.0, -90.0] |
|
|
| |
| @property |
| def lidar_pos_2(self): |
| return [0.0, 0.0, 2.5] |
|
|
| |
| @property |
| def lidar_rot_2(self): |
| return [0.0, 0.0, -270.0] |
|
|
| |
| @property |
| def lidar_accumulation(self): |
| return True |
|
|
| |
| @property |
| def num_cameras(self): |
| """Number of cameras based on the target dataset.""" |
| return { |
| TargetDataset.CARLA_LEADERBOARD2_6CAMERAS: 6, |
| TargetDataset.CARLA_LEADERBOARD2_3CAMERAS: 3, |
| TargetDataset.NAVSIM_4CAMERAS: 4, |
| TargetDataset.WAYMO_E2E_2025_3CAMERAS: 3, |
| }[self.target_dataset] |
|
|
| @property |
| def camera_calibration(self): |
| """Camera calibration configuration with positions, rotations, and sensor parameters""" |
| if self.target_dataset == TargetDataset.CARLA_LEADERBOARD2_6CAMERAS: |
| return { |
| 1: { |
| "pos": [0.0, -0.3, 2.25], |
| "rot": [0.0, 0.0, -57.5], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| 2: { |
| "pos": [0.25, 0.0, 2.25], |
| "rot": [0.0, 0.0, 0.0], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| 3: { |
| "pos": [0.0, 0.3, 2.25], |
| "rot": [0.0, 0.0, 57.5], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| 4: { |
| "pos": [-0.30, 0.3, 2.25], |
| "rot": [0.0, 0.0, 180 - 57.5], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| 5: { |
| "pos": [-0.55, 0.0, 2.25], |
| "rot": [0.0, 0.0, 180.0], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| 6: { |
| "pos": [-0.30, -0.3, 2.25], |
| "rot": [0.0, 0.0, -180 + 57.5], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| } |
| elif self.target_dataset == TargetDataset.CARLA_LEADERBOARD2_3CAMERAS: |
| return { |
| 1: { |
| "pos": [0.1, -0.35, 2.25], |
| "rot": [0.0, 0.0, -54.5], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| 2: { |
| "pos": [0.35, 0.0, 2.25], |
| "rot": [0.0, 0.0, 0.0], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| 3: { |
| "pos": [0.1, 0.35, 2.25], |
| "rot": [0.0, 0.0, 54.5], |
| "width": 1152 // 3, |
| "height": 384, |
| "cropped_height": 384, |
| "fov": 60, |
| }, |
| } |
| elif self.target_dataset == TargetDataset.NAVSIM_4CAMERAS: |
| return { |
| 1: NUPLAN_CAMERA_CALIBRATION["CAM_L0"], |
| 2: NUPLAN_CAMERA_CALIBRATION["CAM_F0"], |
| 3: NUPLAN_CAMERA_CALIBRATION["CAM_R0"], |
| 4: NUPLAN_CAMERA_CALIBRATION["CAM_B0"], |
| } |
| elif self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS: |
| from pdm_lite_utils_0916 import waymo_e2e_camera_setting_to_carla |
|
|
| return { |
| 1: waymo_e2e_camera_setting_to_carla( |
| WAYMO_E2E_INTRINSIC, |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT_RIGHT"]["extrinsic"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT_RIGHT"]["width"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT_RIGHT"]["height"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT_RIGHT"]["cropped_height"], |
| ), |
| 2: waymo_e2e_camera_setting_to_carla( |
| WAYMO_E2E_INTRINSIC, |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT"]["extrinsic"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT"]["width"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT"]["height"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT"]["cropped_height"], |
| ), |
| 3: waymo_e2e_camera_setting_to_carla( |
| WAYMO_E2E_INTRINSIC, |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT_LEFT"]["extrinsic"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT_LEFT"]["width"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT_LEFT"]["height"], |
| WAYMO_E2E_2025_CAMERA_SETTING["FRONT_LEFT"]["cropped_height"], |
| ), |
| } |
| raise ValueError(f"Unsupported target dataset: {self.target_dataset}") |
|
|
| |
| num_radar_sensors = 4 |
|
|
| @property |
| def radar_calibration(self): |
| """Radar sensor calibration configuration with positions and orientations. |
| |
| Currently supports only 4 radar sensors. For other numbers, raises an error. |
| """ |
| return { |
| "1": { |
| "pos": [2.6, 0, 0.60], |
| "rot": [0.0, 0.0, -45.0], |
| "horz_fov": 90, |
| "vert_fov": 0.1, |
| }, |
| "2": { |
| "pos": [2.6, 0, 0.60], |
| "rot": [0.0, 0.0, 45.0], |
| "horz_fov": 90, |
| "vert_fov": 0.1, |
| }, |
| "3": { |
| "pos": [-2.6, 0, 0.60], |
| "rot": [0.0, 0.0, 135], |
| "horz_fov": 90, |
| "vert_fov": 0.1, |
| }, |
| "4": { |
| "pos": [-2.6, 0, 0.60], |
| "rot": [0.0, 0.0, 225], |
| "horz_fov": 90, |
| "vert_fov": 0.1, |
| }, |
| } |
|
|
| |
| use_radars = True |
| |
| save_radar_pc_as_lidar = True |
| |
| save_lidar_only_inside_bev = True |
| |
| duplicate_radar_near_ego = True |
| |
| duplicate_radar_radius = 32 |
| |
| duplicate_radar_factor = 5 |
|
|
| |
| |
| save_depth_lower_resolution = True |
|
|
| @property |
| def save_depth_resolution_ratio(self): |
| """Resolution reduction ratio for depth image storage.""" |
| if self.is_on_slurm: |
| return 4 |
| return 4 |
|
|
| |
| save_depth_bits = 8 |
| |
| save_only_non_ground_lidar = True |
| |
| save_grouped_semantic = True |
|
|
| |
| |
| ego_num_temporal_data_points_saved = 200 |
| |
| other_vehicles_num_temporal_data_points_saved = 40 |
|
|
| |
| |
| carla_fps = 20 |
| |
| carla_frame_rate = 1.0 / carla_fps |
| |
| iou_treshold_nms = 0.2 |
| |
| route_planner_min_distance = 7.5 |
| |
| route_planner_max_distance = 50.0 |
| |
| dense_route_planner_min_distance = 2.4 |
| |
| inital_frames_delay = 1 |
|
|
| |
| tp_distances = [ |
| 3.0, |
| 3.25, |
| 3.5, |
| 3.75, |
| 4.0, |
| 4.25, |
| 4.5, |
| 4.75, |
| 5.0, |
| 5.25, |
| 5.5, |
| 5.75, |
| 6.0, |
| 6.25, |
| 6.5, |
| 6.75, |
| 7.0, |
| 7.25, |
| 7.5, |
| 7.75, |
| 8.0, |
| 8.25, |
| 8.5, |
| 8.75, |
| 9.0, |
| 9.25, |
| 9.5, |
| 9.75, |
| 10.0, |
| ] |
|
|
| |
| ego_extent_x = 2.4508416652679443 |
| |
| ego_extent_y = 1.0641621351242065 |
| |
| ego_extent_z = 0.7553732395172119 |
|
|
| |
| safety_box_z_min = 0.5 |
| |
| safety_box_z_max = 1.5 |
|
|
| |
| @property |
| def safety_box_y_min(self): |
| """Minimum y coordinate of the safety box relative to ego vehicle.""" |
| return -self.ego_extent_y * 0.8 |
|
|
| @property |
| def safety_box_y_max(self): |
| """Maximum y coordinate of the safety box relative to ego vehicle.""" |
| return self.ego_extent_y * 0.8 |
|
|
| @property |
| def safety_box_x_min(self): |
| """Minimum x coordinate of the safety box relative to ego vehicle.""" |
| return self.ego_extent_x |
|
|
| @property |
| def safety_box_x_max(self): |
| """Maximum x coordinate of the safety box relative to ego vehicle.""" |
| return self.ego_extent_x + 2.5 |
|
|
| @property |
| def is_on_slurm(self): |
| """Check if running on SLURM cluster environment.""" |
| return os.getenv("SLURM_JOB_ID") is not None |
|
|
| @property |
| def is_on_tcml(self): |
| """Check if running on Training Center for Machine Learning of Tübingen.""" |
| return os.getenv("TCML") is not None |
|
|
| |
|
|
| def parse(self, loaded_config, env_key: str, raise_error_on_missing_key: bool): |
| |
| env_params = os.getenv(env_key, "").strip() |
| if not env_params: |
| env_params = {} |
| else: |
| parsed = OmegaConf.create(OmegaConf.from_dotlist(env_params.split())) |
| env_params = OmegaConf.to_container(parsed, resolve=True) |
|
|
| |
| if loaded_config is None: |
| loaded_config = {} |
|
|
| |
| for key, value in env_params.items(): |
| loaded_config[key] = value |
|
|
| |
| for key, value in loaded_config.items(): |
| if hasattr(self, key) and not callable(getattr(self, key)): |
| try: |
| setattr(self, key, value) |
| except Exception as _: |
| pass |
| elif raise_error_on_missing_key: |
| raise AttributeError(f"Unknown configuration key: {key}") |
|
|
| self._loaded_config = loaded_config |
|
|
| def __setattr__(self, name, value): |
| |
| |
| |
| allowed = set() |
| for cls in self.__class__.__mro__: |
| allowed.update(getattr(cls, "__dict__", {}).keys()) |
| allowed.update(self.__dict__.keys()) |
|
|
| if name not in allowed and not name.startswith("_"): |
| raise AttributeError( |
| f"Can't set unknown attribute '{name}'. Please check if this variable might have been renamed." |
| ) |
| super().__setattr__(name, value) |
|
|
| def base_dict(self): |
| out = {} |
| for k in dir(self): |
| if k.startswith("_"): |
| continue |
| try: |
| v = getattr(self, k) |
| if not callable(v): |
| out[k] = v |
| except Exception: |
| pass |
| return out |
|
|
|
|
| def overridable_property(fn): |
| attr_name = fn.__name__ |
|
|
| @property |
| def wrapper(self): |
| try: |
| if attr_name in self._loaded_config: |
| return type(fn(self))(self._loaded_config[attr_name]) |
| except: |
| pass |
| return fn(self) |
|
|
| return wrapper |
|
|
|
|
| class TrainingConfig(BaseConfig): |
| @property |
| def target_dataset(self): |
| if self.use_waymo_e2e_data and not self.mixed_data_training: |
| return TargetDataset.WAYMO_E2E_2025_3CAMERAS |
| elif self.use_navsim_data and not self.use_carla_data: |
| return TargetDataset.NAVSIM_4CAMERAS |
| elif "carla_leaderboad2_v3" in self.carla_root: |
| return TargetDataset.CARLA_LEADERBOARD2_3CAMERAS |
| elif "carla_leaderboad2_v8" in self.carla_root: |
| return TargetDataset.CARLA_LEADERBOARD2_3CAMERAS |
| elif "carla_leaderboad2_v10" in self.carla_root: |
| return TargetDataset.CARLA_LEADERBOARD2_6CAMERAS |
| elif "carla_leaderboad2_v12" in self.carla_root: |
| return TargetDataset.NAVSIM_4CAMERAS |
| elif "carla_leaderboad2_v14" in self.carla_root: |
| return TargetDataset.WAYMO_E2E_2025_3CAMERAS |
| elif "data/carla_today/results/data/garage_v2_2025_06_23" in self.carla_root: |
| return TargetDataset.NAVSIM_4CAMERAS |
| raise ValueError(f"Unknown CARLA root path: {self.carla_root}") |
|
|
| @property |
| def num_available_cameras(self): |
| """Number of available cameras based on the target dataset.""" |
| return { |
| TargetDataset.CARLA_LEADERBOARD2_3CAMERAS: 3, |
| TargetDataset.CARLA_LEADERBOARD2_6CAMERAS: 6, |
| TargetDataset.CARLA_LEADERBOARD2_3CAMERAS: 3, |
| TargetDataset.NAVSIM_4CAMERAS: 4, |
| TargetDataset.WAYMO_E2E_2025_3CAMERAS: 3, |
| }[self.target_dataset] |
|
|
| @overridable_property |
| def used_cameras(self): |
| """List indicating which cameras are used based on the target dataset. |
| Can be overriden, if a camera is false it will be ignored during training.""" |
| return [True] * self.num_available_cameras |
|
|
| @property |
| def num_used_cameras(self): |
| """Number of cameras used during training.""" |
| return sum(int(use) for use in self.used_cameras) |
|
|
| |
| |
| |
| pixels_per_meter = 4.0 |
|
|
| @property |
| def min_x_meter(self): |
| """Back boundary of the planning area in meters.""" |
| if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS: |
| return 0 |
| return -32 |
|
|
| @property |
| def max_x_meter(self): |
| """Front boundary of the planning area in meters.""" |
| if self.carla_leaderboard_mode: |
| return 64 |
| if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS: |
| return 64 |
| return 32 |
|
|
| @property |
| def min_y_meter(self): |
| """Left boundary of the planning area in meters.""" |
| if self.carla_leaderboard_mode: |
| return -40 |
| return -32 |
|
|
| @property |
| def max_y_meter(self): |
| """Right boundary of the planning area in meters.""" |
| if self.carla_leaderboard_mode: |
| return 40 |
| return 32 |
|
|
| @property |
| def lidar_width_pixel(self): |
| """Width resolution of LiDAR BEV representation in pixels.""" |
| return int((self.max_x_meter - self.min_x_meter) * self.pixels_per_meter) |
|
|
| @property |
| def lidar_height_pixel(self): |
| """Height resolution of LiDAR BEV representation in pixels.""" |
| return int((self.max_y_meter - self.min_y_meter) * self.pixels_per_meter) |
|
|
| @property |
| def lidar_width_meter(self): |
| """Width of LiDAR coverage area in meters.""" |
| return int(self.max_x_meter - self.min_x_meter) |
|
|
| @property |
| def lidar_height_meter(self): |
| """Height of LiDAR coverage area in meters.""" |
| return int(self.max_y_meter - self.min_y_meter) |
|
|
| |
| visualize_dataset = False |
| |
| visualize_failed_scenarios = False |
| |
| load_bev_3rd_person_images = False |
|
|
| |
| |
| seed = 0 |
| |
| wandb_id: str = None |
| |
| wandb_resume = "never" |
| |
| description = "My Experiment." |
| |
| visualize_training = True |
| |
| id = "Experiment 1." |
| |
| load_file: str = None |
| |
| continue_epoch = False |
|
|
| @property |
| def epoch_checkpoints_keep(self): |
| """Number of checkpoints to keep during training.""" |
| if self.use_carla_data and not self.mixed_data_training: |
| return [] |
| return [sum([1 * 2**i for i in range(n)]) for n in range(3, 10)] |
|
|
| |
| |
| use_training_session_cache = True |
| |
| use_persistent_cache = False |
| |
| force_rebuild_data_cache = False |
|
|
| @property |
| def carla_cache_path(self): |
| """Tuple of cache characteristics used to identify cached data compatibility.""" |
| return ( |
| str(self.image_width_before_camera_subselection), |
| str(self.final_image_height), |
| str(self.min_x_meter), |
| str(self.max_x_meter), |
| str(self.min_y_meter), |
| str(self.max_y_meter), |
| str(self.detect_boxes), |
| str(self.use_depth), |
| str(self.use_semantic), |
| str(self.use_bev_semantic), |
| str(self.load_bev_3rd_person_images), |
| str(self.training_used_lidar_steps), |
| str(self.num_radar_forecast_steps), |
| str(self.radar_waypoints_spacing), |
| ) |
|
|
| @property |
| def ssd_cache_path(self): |
| """Path to SSD cache directory.""" |
| tmp_folder = "/scratch/" + str(os.environ.get("SLURM_JOB_ID")) |
| if not self.is_on_tcml: |
| tmp_folder = str(os.environ.get("SCRATCH", "/tmp")) |
| return tmp_folder |
|
|
| |
| carla_root = "data/carla_leaderboad2_v12/results/data/sensor_data/" |
|
|
| @property |
| def carla_data(self): |
| """Path to CARLA data directory.""" |
| return os.path.join(self.carla_root, "data") |
|
|
| |
| |
| logdir = None |
| |
| training_png_compression_level = 6 |
| |
| vehicle_min_num_lidar_points = 1 |
| |
| vehicle_min_num_visible_pixels = 1 |
| |
| pedestrian_min_num_lidar_points = 5 |
| |
| pedestrian_min_num_visible_pixels = 15 |
| |
| parking_vehicle_min_num_lidar_points = 3 |
| |
| parking_vehicle_min_num_visible_pixels = 5 |
| |
| grad_scaler_init_scale = 1024 |
| |
| grad_scaler_growth_factor = 2 |
| |
| grad_scaler_backoff_factor = 0.5 |
| |
| grad_scaler_growth_interval = 256 |
| |
| grad_scaler_max_grad_scale = 2**16 |
|
|
| @property |
| def sync_batchnorm(self) -> bool: |
| """If true synchronize batch normalization across distributed processes.""" |
| return False |
|
|
| |
| freeze_batchnorm = False |
|
|
| @overridable_property |
| def epochs(self): |
| """Total number of training epochs.""" |
| if self.carla_leaderboard_mode: |
| return 31 |
| if self.target_dataset == TargetDataset.NAVSIM_4CAMERAS: |
| return 61 |
| if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS: |
| return 20 |
| raise ValueError("Unknown target dataset") |
|
|
| @overridable_property |
| def batch_size(self): |
| """Batch size for training.""" |
| if not self.is_on_slurm: |
| return 4 |
| return 64 |
|
|
| @property |
| def torch_float_type(self): |
| """PyTorch float precision type for training.""" |
| if self.use_mixed_precision_training and self.gpu_name in ["a100", "l40s"]: |
| return torch.bfloat16 |
| return torch.float32 |
|
|
| @property |
| def use_mixed_precision_training(self): |
| """If true use mixed precision training.""" |
| return self.gpu_name in ["a100", "l40s"] |
|
|
| @property |
| def need_grad_scaler(self): |
| """If true gradient scaling is needed for mixed precision training.""" |
| return self.use_mixed_precision_training and self.torch_float_type == torch.float16 |
|
|
| |
| use_zero_redundancy = False |
|
|
| @property |
| def save_model_checkpoint(self): |
| """If true save model checkpoints during training.""" |
| if self.is_on_slurm: |
| return True |
| return True |
|
|
| @property |
| def is_pretraining(self): |
| """If true indicates pretraining phase.""" |
| return not self.use_planning_decoder |
|
|
| |
| |
| prefetch_factor = 8 |
|
|
| @property |
| def compile(self): |
| """If true compile the model for optimization.""" |
| return True |
|
|
| @property |
| def channel_last(self): |
| """If true use channel last memory format for input tensors.""" |
| return True |
|
|
| |
| |
| lr = 3e-4 |
|
|
| |
| @property |
| def skip_first(self): |
| """Number of frames to skip at the beginning of sequences.""" |
| if self.is_pretraining and not self.mixed_data_training: |
| return 1 |
| return self.num_way_points_prediction |
|
|
| @property |
| def skip_last(self): |
| """Number of frames to skip at the end of sequences.""" |
| if self.is_pretraining: |
| return 1 |
| if self.carla_leaderboard_mode: |
| return self.num_way_points_prediction |
| if self.target_dataset in [ |
| TargetDataset.NAVSIM_4CAMERAS, |
| ]: |
| return self.num_way_points_prediction * 2 |
| raise ValueError("Unknown target dataset") |
|
|
| |
| @property |
| def radar_detection(self): |
| """If true use radar points as additional input to the model.""" |
| return self.carla_leaderboard_mode and not self.use_tfpp_planning_decoder |
|
|
| @overridable_property |
| def use_radar_detection(self): |
| """If true use radar points as additional input to the model.""" |
| return self.carla_leaderboard_mode and not self.use_tfpp_planning_decoder |
|
|
| |
| num_radar_points_per_sensor = 75 |
| |
| num_radar_queries = 20 |
| |
| radar_hidden_dim_tokenizer = 1024 |
| |
| radar_token_dim = 256 |
| |
| radar_tf_dim_ff = 1024 |
| |
| radar_dropout = 0.1 |
| |
| radar_num_heads = 8 |
| |
| radar_num_layers = 4 |
| |
| radar_hidden_dim_decoder = 1024 |
| |
| radar_classification_loss_weight = 1.0 |
| |
| radar_forecast_loss_weight = 0.1 |
| |
| radar_regression_loss_weight = 5.0 |
| |
| radar_sine_features = 128 |
| |
| num_radar_sensors = 4 |
| |
| forecast_radar_detections = False |
| |
| num_radar_forecast_steps = 8 |
| |
| radar_waypoints_spacing = 5 |
| |
| max_distance_future_waypoint = 10.0 |
|
|
| |
| |
| force_rebuild_bucket = False |
| |
| randomize_route_order = False |
| |
| subsample_data = False |
| |
| hold_out_town13_routes = False |
|
|
| @property |
| def carla_bucket_collection(self): |
| """Name of the bucket collection to use for training data.""" |
| if ( |
| self.use_carla_data |
| and self.use_waymo_e2e_data |
| and not self.force_rebuild_bucket |
| and not self.force_rebuild_data_cache |
| ): |
| return "navsim" |
| if self.use_carla_data and self.use_navsim_data and not self.force_rebuild_bucket and not self.force_rebuild_data_cache: |
| return "navsim" |
| if self.visualize_failed_scenarios: |
| return "failed" |
| if self.subsample_data: |
| return "subsampled_posttrain" |
| if self.is_pretraining and self.hold_out_town13_routes: |
| return "town13_heldout_pretrain" |
| if self.is_pretraining or self.visualize_dataset or self.force_rebuild_data_cache: |
| return "full_pretrain" |
| if not self.is_pretraining and self.hold_out_town13_routes: |
| return "town13_heldout_posttrain" |
| return "full_posttrain" |
|
|
| @property |
| def bucket_collection_path(self): |
| """Path to bucket collection directory.""" |
| return os.path.join(self.carla_root, "buckets") |
|
|
| |
| |
| use_sensor_perburtation = True |
|
|
| |
| @overridable_property |
| def use_sensor_perburtation_prob(self): |
| if not self.carla_leaderboard_mode: |
| return 0.8 |
| return 0.5 |
|
|
| @overridable_property |
| def use_extreme_sensor_perburtation_prob(self): |
| """Probability of extreme sensor perburtation being applied if we sampled an sensor perburtated sample.""" |
| if self.is_pretraining and not self.carla_leaderboard_mode: |
| return 0.5 |
| return 0.0 |
|
|
| |
| @property |
| def use_color_aug(self): |
| """If true apply image color based augmentations.""" |
| |
| return not self.visualize_dataset |
|
|
| @property |
| def use_color_aug_prob(self): |
| """Probability to apply the different image color augmentations.""" |
| if self.carla_leaderboard_mode: |
| return 0.2 |
| return 0.1 |
|
|
| |
| weight_decay = 0.01 |
|
|
| |
| @property |
| def use_grad_clip(self): |
| """If true use gradient clipping during training.""" |
| return False |
|
|
| |
| grad_quantile = 0.99 |
| |
| grad_history_length = 1000 |
| |
| use_optimizer_groups = False |
|
|
| |
| @overridable_property |
| def use_cosine_annealing_with_restarts(self): |
| if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS: |
| return False |
| return True |
|
|
| |
| @overridable_property |
| def use_depth(self): |
| """If true use depth prediction as auxiliary task.""" |
| return self.carla_leaderboard_mode |
|
|
| |
| @property |
| def training_used_lidar_steps(self): |
| """We stack lidar frames for motion cues. Number of past frames we stack for the model input.""" |
| return 10 |
|
|
| |
| min_z = -4 |
| |
| max_z = 4 |
| |
| hist_max_per_pixel = 5 |
|
|
| @property |
| def lidar_vert_anchors(self): |
| """Number of vertical anchors for LiDAR feature maps.""" |
| return self.lidar_height_pixel // 32 |
|
|
| @property |
| def lidar_horz_anchors(self): |
| """Number of horizontal anchors for LiDAR feature maps.""" |
| return self.lidar_width_pixel // 32 |
|
|
| |
| |
| detect_boxes = True |
| |
| debug_boxes_visualization = True |
| |
| center_net_global_avg_factor = False |
| |
| data_bb_static_types_white_list = ["static.prop.constructioncone", "static.prop.trafficwarning"] |
| |
| bb_confidence_threshold = 0.3 |
| |
| max_num_bbs = 90 |
| |
| |
| num_dir_bins = 12 |
| |
| top_k_center_keypoints = 100 |
| |
| center_net_max_pooling_kernel = 3 |
| |
| bb_input_channel = 64 |
| |
| car_open_door_extra_width = 1.2 |
| |
| num_bb_classes = len(TransfuserBoundingBoxClass) |
|
|
| |
| @overridable_property |
| def use_discrete_command(self): |
| """If true use discrete command input to the network.""" |
| if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS: |
| return False |
| return True |
|
|
| @property |
| def discrete_command_dim(self): |
| """Dimension of discrete command input.""" |
| if self.carla_leaderboard_mode: |
| return 6 |
| elif self.target_dataset in [ |
| TargetDataset.NAVSIM_4CAMERAS, |
| ]: |
| return 4 |
| elif self.target_dataset in [ |
| TargetDataset.WAYMO_E2E_2025_3CAMERAS, |
| ]: |
| return 4 |
| raise ValueError("Unknown target dataset") |
|
|
| |
| use_noisy_tp = False |
| |
| use_velocity = True |
|
|
| @property |
| def max_speed(self): |
| """Maximum speed limit for the vehicle in m/s.""" |
| if self.carla_leaderboard_mode: |
| return 25.0 |
| if self.target_dataset in [ |
| TargetDataset.NAVSIM_4CAMERAS, |
| ]: |
| return 15.0 |
| if self.target_dataset in [ |
| TargetDataset.WAYMO_E2E_2025_3CAMERAS, |
| ]: |
| return 33.33 |
| raise ValueError("Unknown target dataset") |
|
|
| @property |
| def use_acceleration(self): |
| """If true use the acceleration as input to the network.""" |
| return not self.carla_leaderboard_mode and self.target_dataset not in [TargetDataset.WAYMO_E2E_2025_3CAMERAS] |
|
|
| @property |
| def max_acceleration(self): |
| """Maximum acceleration for normalization.""" |
| if self.carla_leaderboard_mode: |
| return 10.0 |
| if self.target_dataset in [ |
| TargetDataset.NAVSIM_4CAMERAS, |
| ]: |
| return 4.0 |
|
|
| @property |
| def use_previous_tp(self): |
| """If true use the previous/visited target point as input to the network.""" |
| if self.carla_leaderboard_mode and not self.use_tfpp_planning_decoder: |
| return True |
| return False |
|
|
| @property |
| def use_next_tp(self): |
| """If true use the next/subsequent target point as input to the network.""" |
| if self.carla_leaderboard_mode and not self.use_tfpp_planning_decoder: |
| return True |
| return False |
|
|
| @property |
| def use_tp(self): |
| """If true use the current target point as input to the network.""" |
| if self.carla_leaderboard_mode: |
| return True |
| return False |
|
|
| @property |
| def target_points_normalization_constants(self): |
| """Normalization constants for target points [x_norm, y_norm].""" |
| return [[200.0, 50.0]] |
|
|
| @property |
| def tp_pop_distance(self): |
| """Distance threshold for popping target points from route.""" |
| return 3.25 |
|
|
| @overridable_property |
| def use_past_positions(self): |
| """If true use past positions as input to the network.""" |
| return not self.carla_leaderboard_mode and self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS] |
|
|
| @overridable_property |
| def use_past_speeds(self): |
| """If true use past speeds as input to the network.""" |
| return not self.carla_leaderboard_mode and self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS] |
|
|
| @overridable_property |
| def num_past_samples_used(self): |
| """Number of past samples to use as input to the network.""" |
| if self.use_past_positions or self.use_past_speeds: |
| return 6 |
| return 0 |
|
|
| |
| |
| transfuser_num_bev_cross_attention_layers = 6 |
| |
| transfuser_num_bev_cross_attention_heads = 8 |
| |
| transfuser_token_dim = 256 |
|
|
| @property |
| def predict_target_speed(self): |
| """If true predict target speed.""" |
| return self.carla_leaderboard_mode |
|
|
| @property |
| def predict_spatial_path(self): |
| """If true predict spatial path.""" |
| return self.carla_leaderboard_mode |
|
|
| |
| predict_temporal_spatial_waypoints = True |
|
|
| |
| use_planning_decoder = False |
|
|
| |
| use_tfpp_planning_decoder = False |
|
|
| |
| gru_hidden_size = 64 |
|
|
| @property |
| def target_speed_classes(self): |
| """Carla target speed prediction classes in m/s.""" |
| return [ |
| 0.0, |
| 4.0, |
| 8.0, |
| 10.0, |
| 13.88888888, |
| 16.0, |
| 17.77777777, |
| 20.0, |
| ] |
|
|
| @property |
| def target_speeds(self): |
| return self.target_speed_classes |
|
|
| |
| smooth_route = True |
| |
| num_route_points_smoothing = 20 |
| |
| num_route_points_prediction = 10 |
|
|
| @property |
| def num_way_points_prediction(self): |
| """Number of waypoints to predict.""" |
| if self.carla_leaderboard_mode: |
| return 8 |
| elif self.target_dataset in [ |
| TargetDataset.NAVSIM_4CAMERAS, |
| ]: |
| return 8 |
| elif self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]: |
| return 10 |
| raise ValueError("Unknown target dataset") |
|
|
| |
| @property |
| def waypoints_spacing(self): |
| if self.carla_leaderboard_mode: |
| return 5 |
| elif self.target_dataset in [ |
| TargetDataset.NAVSIM_4CAMERAS, |
| ]: |
| return 10 |
| elif self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]: |
| return 10 |
| raise ValueError("Unknown target dataset") |
|
|
| |
| @property |
| def crop_height(self): |
| """The amount of pixels cropped from the bottom of the image.""" |
| return self.camera_calibration[1]["height"] - self.camera_calibration[1]["cropped_height"] |
|
|
| @property |
| def carla_crop_height_type(self): |
| """Type of cropping applied to CARLA images.""" |
| if self.carla_leaderboard_mode: |
| return CarlaImageCroppingType.NONE |
| elif self.target_dataset in [ |
| TargetDataset.NAVSIM_4CAMERAS, |
| ]: |
| return CarlaImageCroppingType.BOTTOM |
| elif self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]: |
| return CarlaImageCroppingType.NONE |
| raise ValueError("Unknown target dataset") |
|
|
| @property |
| def image_width_before_camera_subselection(self): |
| """Final width of images after loading from disk but before camera sub-selection.""" |
| return self.num_available_cameras * self.camera_calibration[1]["width"] |
|
|
| @property |
| def final_image_width(self): |
| """Final width of images after cropping and camera sub-selection.""" |
| return self.num_used_cameras * self.camera_calibration[1]["width"] |
|
|
| @property |
| def final_image_height(self): |
| """Final height of images after cropping.""" |
| return self.camera_calibration[1]["cropped_height"] |
|
|
| @property |
| def img_vert_anchors(self): |
| """Number of vertical anchors for image feature maps.""" |
| return self.final_image_height // 32 |
|
|
| @property |
| def img_horz_anchors(self): |
| """Number of horizontal anchors for image feature maps.""" |
| return self.num_used_cameras * self.camera_calibration[1]["width"] // 32 |
|
|
| |
| |
| freeze_backbone = False |
| |
| image_architecture = "resnet34" |
| |
| lidar_architecture = "resnet34" |
| |
| LTF = False |
| |
| num_fusion_stages = 4 |
| |
| only_ego_status_input = False |
|
|
| |
| |
| block_exp = 4 |
| |
| n_layer = 2 |
| |
| n_head = 4 |
| |
| embd_pdrop = 0.1 |
| |
| resid_pdrop = 0.1 |
| |
| attn_pdrop = 0.1 |
| |
| gpt_linear_layer_init_mean = 0.0 |
| |
| gpt_linear_layer_init_std = 0.02 |
| |
| gpt_layer_norm_init_weight = 1.0 |
|
|
| |
| |
| use_semantic = True |
| |
| num_semantic_classes = len(TransfuserSemanticSegmentationClass) |
| |
| perspective_downsample_factor = 1 |
| |
| deconv_channel_num_0 = 128 |
| |
| deconv_channel_num_1 = 64 |
| |
| deconv_channel_num_2 = 32 |
| |
| deconv_scale_factor_0 = 4 |
| |
| deconv_scale_factor_1 = 8 |
|
|
| |
| |
| use_bev_semantic = True |
| |
| num_bev_semantic_classes = len(TransfuserBEVSemanticClass) |
| |
| scale_pedestrian_bev_semantic_size = 2.5 |
| |
| pedestrian_bev_min_extent = 0.4 |
| |
| bev_features_chanels = 64 |
| |
| bev_down_sample_factor = 4 |
| |
| bev_upsample_factor = 2 |
|
|
| |
| |
| use_carla_data = True |
| |
| carla_num_samples = -1 |
| |
| use_navsim_data = False |
| |
| navsim_data_root = "data/navsim_training_cache/trainval" |
| |
| navsim_num_samples = -1 |
| |
| schedule_carla_num_samples = False |
| |
| use_waymo_e2e_data = False |
| |
| waymo_e2e_num_training_samples = -1 |
| |
| waymo_e2e_training_data_root = "data/waymo_open_dataset_end_to_end_camera_v_1_0_0_training" |
| |
| waymo_e2e_val_data_root = "data/waymo_open_dataset_end_to_end_camera_v_1_0_0_val_rfm" |
| |
| waymo_e2e_test_data_root = "data/waymo_open_dataset_end_to_end_camera_v_1_0_0_test_submission" |
| |
| waymo_e2e_subsample_factor = 5 |
|
|
| @property |
| def navsim_num_bev_semantic_classes(self): |
| """Number of BEV semantic classes in NavSim data.""" |
| return len(NavSimBEVSemanticClass) |
|
|
| @property |
| def navsim_num_bb_classes(self): |
| """Number of bb classes in NavSim data.""" |
| return len(NavSimBBClass) |
|
|
| @property |
| def mixed_data_training(self): |
| """If true use mixed data for training.""" |
| return int(self.use_navsim_data) + int(self.use_carla_data) + int(self.use_waymo_e2e_data) > 1 |
|
|
| @property |
| def carla_leaderboard_mode(self): |
| """If true use CARLA leaderboard mode settings.""" |
| return ( |
| self.target_dataset |
| in [ |
| TargetDataset.CARLA_LEADERBOARD2_3CAMERAS, |
| TargetDataset.CARLA_LEADERBOARD2_6CAMERAS, |
| ] |
| and not self.mixed_data_training |
| ) |
|
|
| @beartype |
| def detailed_loss_weights(self, source_dataset: int, epoch: int) -> dict[str, float]: |
| """Computed loss weights for all auxiliary tasks with normalization.""" |
|
|
| weights = { |
| "loss_semantic": 1.0, |
| "loss_depth": 0.00001, |
| "loss_bev_semantic": 1.0, |
| "loss_center_net_heatmap": 1.0, |
| "loss_center_net_wh": 1.0, |
| "loss_center_net_offset": 1.0, |
| "loss_center_net_yaw_class": 1.0, |
| "loss_center_net_yaw_res": 1.0, |
| "loss_center_net_velocity": 1.0, |
| "radar_loss": 1.0, |
| } |
|
|
| if source_dataset != SourceDataset.CARLA: |
| weights["radar_loss"] = 0.0 |
| weights["loss_semantic"] = 0.0 |
| weights["loss_depth"] = 0.0 |
| weights["loss_center_net_velocity"] = 0.0 |
|
|
| if self.LTF: |
| weights["loss_center_net_velocity"] = 0.0 |
|
|
| if not self.use_semantic: |
| weights["loss_semantic"] = 0.0 |
|
|
| if not self.use_depth: |
| weights["loss_depth"] = 0.0 |
|
|
| if not self.use_bev_semantic: |
| weights["loss_bev_semantic"] = 0.0 |
|
|
| if not self.detect_boxes: |
| weights["loss_center_net_heatmap"] = 0.0 |
| weights["loss_center_net_wh"] = 0.0 |
| weights["loss_center_net_offset"] = 0.0 |
| weights["loss_center_net_yaw_class"] = 0.0 |
| weights["loss_center_net_yaw_res"] = 0.0 |
| weights["loss_center_net_velocity"] = 0.0 |
|
|
| if self.training_used_lidar_steps <= 1: |
| weights["loss_center_net_velocity"] = 0.0 |
|
|
| if not self.radar_detection: |
| weights["radar_loss"] = 0.0 |
|
|
| |
| prefix = f"{SOURCE_DATASET_NAME_MAP[source_dataset]}_" |
| if source_dataset == SourceDataset.CARLA: |
| prefix = "" |
| weights = {f"{prefix}{k}": v for k, v in weights.items()} |
|
|
| |
| weights.update( |
| { |
| "loss_spatio_temporal_waypoints": 1.0, |
| "loss_target_speed": 1.0, |
| "loss_spatial_route": 1.0, |
| } |
| ) |
|
|
| |
| if not self.use_planning_decoder: |
| weights["loss_spatio_temporal_waypoints"] = 0.0 |
| weights["loss_spatial_route"] = 0.0 |
| weights["loss_target_speed"] = 0.0 |
|
|
| return weights |
|
|
| @property |
| def log_scalars_frequency(self): |
| """How often to log scalar values during training.""" |
| if not self.is_on_slurm: |
| return 1 |
| try: |
| with open("scripts/slurm/configs/wandb_log_frequency_training_scalar.txt") as f: |
| return int(f.readline().strip()) |
| except Exception as e: |
| print(f"Error reading log frequency file: {e}.") |
| return 1 |
|
|
| @property |
| def log_images_frequency(self): |
| """How often to log images during training.""" |
| if not self.is_on_slurm: |
| return 1 |
| try: |
| with open("scripts/slurm/configs/wandb_log_frequency_training_images.txt") as f: |
| return int(f.readline().strip()) |
| except Exception as e: |
| print(f"Error reading log frequency file: {e}.") |
| return 1000 |
|
|
| @property |
| def slurm_job_id(self): |
| """Current SLURM job ID if running on SLURM cluster.""" |
| if self.is_on_slurm: |
| return os.environ.get("SLURM_JOB_ID", "0") |
| return None |
|
|
| @property |
| def log_wandb(self): |
| """If true log metrics to Weights & Biases.""" |
| if self.is_on_slurm: |
| return True |
| return False |
|
|
| |
| @property |
| def gpu_name(self): |
| """Normalized GPU name for hardware-specific configurations.""" |
| try: |
| name = torch.cuda.get_device_name().lower() |
| if "rtx 2080 ti" in name: |
| return "rtx2080ti" |
| elif "gtx 1080 ti" in name: |
| return "gtx1080ti" |
| elif "a100" in name: |
| return "a100" |
| elif "l40s" in name: |
| return "l40s" |
| elif "a4000" in name: |
| return "a4000" |
| else: |
| raise Exception(f"Unknown GPU name: {name}") |
| except RuntimeError: |
| return "" |
|
|
| @property |
| def rank(self): |
| """Current process rank in distributed training.""" |
| return int(os.environ.get("RANK", "0")) |
|
|
| @property |
| def world_size(self): |
| """Total number of processes in distributed training.""" |
| return int(os.environ.get("WORLD_SIZE", "1")) |
|
|
| @property |
| def local_rank(self): |
| """Local rank of current process on the node.""" |
| return int(os.environ.get("LOCAL_RANK", "0")) |
|
|
| @property |
| def device(self): |
| """PyTorch device to use for training.""" |
| return torch.device(f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu") |
|
|
| @property |
| def assigned_cpu_cores(self): |
| """Number of CPU cores assigned to this job.""" |
| if "SLURM_JOB_ID" in os.environ: |
| cpus_per_task = os.environ.get("SLURM_CPUS_PER_TASK") |
| if cpus_per_task: |
| return int(cpus_per_task) |
| return 8 |
|
|
| @property |
| def workers_per_cpu_cores(self): |
| """Number of data loader workers per CPU core.""" |
| if not self.mixed_data_training and not self.use_carla_data: |
| return 1 |
| return 1 |
|
|
| def training_dict(self): |
| out = {} |
| cls = self.__class__ |
| for k, v in cls.__dict__.items(): |
| if isinstance(v, property): |
| try: |
| out[k] = getattr(self, k) |
| except Exception: |
| pass |
| else: |
| if not k.startswith("__") and not callable(v): |
| out[k] = v |
| for k, v in self.__dict__.items(): |
| if isinstance(v, property): |
| try: |
| out[k] = getattr(self, k) |
| except Exception: |
| pass |
| else: |
| if not k.startswith("__") and not callable(v): |
| out[k] = v |
| return out |
|
|
| def __init__(self, loaded_config: dict = None, raise_error_on_missing_key: bool = False): |
| """Constructor for training config.""" |
| super().__init__() |
| self.parse( |
| loaded_config=loaded_config, |
| env_key="TRAINING_CONFIG", |
| raise_error_on_missing_key=raise_error_on_missing_key, |
| ) |
|
|
|
|
| class TransfuserBackbone(nn.Module): |
| @beartype |
| def __init__(self, device: torch.device, config: TrainingConfig): |
| super().__init__() |
| self.device = device |
| self.config = config |
|
|
| |
| self.image_encoder = timm.create_model(config.image_architecture, pretrained=True, features_only=True) |
| self.avgpool_img = nn.AdaptiveAvgPool2d((self.config.img_vert_anchors, self.config.img_horz_anchors)) |
| image_start_index = 0 |
| if len(self.image_encoder.return_layers) > 4: |
| image_start_index += 1 |
| self.num_image_features = self.image_encoder.feature_info.info[image_start_index + 3]["num_chs"] |
|
|
| |
| if self.config.num_fusion_stages > 1: |
| self.lidar_encoder = timm.create_model( |
| config.lidar_architecture, pretrained=False, in_chans=2 if config.LTF else 1, features_only=True |
| ) |
| lidar_start_index = 0 |
| if len(self.lidar_encoder.return_layers) > 4: |
| lidar_start_index += 1 |
| self.num_lidar_features = self.lidar_encoder.feature_info.info[lidar_start_index + 3]["num_chs"] |
| self.lidar_channel_to_img = nn.ModuleList( |
| [ |
| nn.Conv2d( |
| self.lidar_encoder.feature_info.info[lidar_start_index + i]["num_chs"], |
| self.image_encoder.feature_info.info[image_start_index + i]["num_chs"], |
| kernel_size=1, |
| ) |
| for i in range(4 - config.num_fusion_stages, 4) |
| ] |
| ) |
| self.img_channel_to_lidar = nn.ModuleList( |
| [ |
| nn.Conv2d( |
| self.image_encoder.feature_info.info[image_start_index + i]["num_chs"], |
| self.lidar_encoder.feature_info.info[lidar_start_index + i]["num_chs"], |
| kernel_size=1, |
| ) |
| for i in range(4 - config.num_fusion_stages, 4) |
| ] |
| ) |
| else: |
| self.num_lidar_features = self.num_image_features |
| self.bev_queries = nn.Parameter( |
| torch.randn( |
| 1, |
| self.num_lidar_features, |
| self.config.lidar_vert_anchors, |
| self.config.lidar_horz_anchors, |
| ) |
| ) |
| self.lidar_channel_to_img = nn.ModuleList( |
| [ |
| nn.Conv2d( |
| self.num_lidar_features, |
| self.image_encoder.feature_info.info[image_start_index + 3]["num_chs"], |
| kernel_size=1, |
| ) |
| ] |
| ) |
| self.img_channel_to_lidar = nn.ModuleList( |
| [ |
| nn.Conv2d( |
| self.image_encoder.feature_info.info[image_start_index + 3]["num_chs"], |
| self.num_lidar_features, |
| kernel_size=1, |
| ) |
| ] |
| ) |
| self.avgpool_lidar = nn.AdaptiveAvgPool2d((self.config.lidar_vert_anchors, self.config.lidar_horz_anchors)) |
|
|
| |
| self.transformers = nn.ModuleList( |
| [ |
| GPT(n_embd=self.image_encoder.feature_info.info[image_start_index + i]["num_chs"], config=config) |
| for i in range(4 - config.num_fusion_stages, 4) |
| ] |
| ) |
|
|
| |
| self.perspective_upsample_factor = ( |
| self.image_encoder.feature_info.info[image_start_index + 3]["reduction"] |
| // self.config.perspective_downsample_factor |
| ) |
|
|
| self.upsample = nn.Upsample(scale_factor=self.config.bev_upsample_factor, mode="bilinear", align_corners=False) |
| self.upsample2 = nn.Upsample( |
| size=( |
| self.config.lidar_height_pixel // self.config.bev_down_sample_factor, |
| self.config.lidar_width_pixel // self.config.bev_down_sample_factor, |
| ), |
| mode="bilinear", |
| align_corners=False, |
| ) |
| self.up_conv5 = nn.Conv2d(self.config.bev_features_chanels, self.config.bev_features_chanels, (3, 3), padding=1) |
| self.up_conv4 = nn.Conv2d(self.config.bev_features_chanels, self.config.bev_features_chanels, (3, 3), padding=1) |
| self.c5_conv = nn.Conv2d(self.num_lidar_features, self.config.bev_features_chanels, (1, 1)) |
|
|
| def top_down(self, x): |
| p5 = F.relu(self.c5_conv(x), inplace=True) |
| p4 = F.relu(self.up_conv5(self.upsample(p5)), inplace=True) |
| p3 = F.relu(self.up_conv4(self.upsample2(p4)), inplace=True) |
| return p3 |
|
|
| def forward(self, data: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: |
| rgb = data["rgb"].to(self.device, dtype=self.config.torch_float_type, non_blocking=True) |
| if self.config.LTF: |
| if self.config.num_fusion_stages > 1: |
| x = torch.linspace(0, 1, self.config.lidar_width_pixel) |
| y = torch.linspace(0, 1, self.config.lidar_height_pixel) |
| y_grid, x_grid = torch.meshgrid(y, x, indexing="ij") |
|
|
| lidar = torch.zeros( |
| (rgb.shape[0], 2, self.config.lidar_height_pixel, self.config.lidar_width_pixel), |
| device=rgb.device, |
| ) |
| lidar[:, 0] = y_grid.unsqueeze(0) |
| lidar[:, 1] = x_grid.unsqueeze(0) |
| else: |
| lidar = None |
| else: |
| lidar = data["lidar"].to(self.device, dtype=self.config.torch_float_type, non_blocking=True) |
| return self._forward(rgb, lidar) |
|
|
| @jt.jaxtyped(typechecker=beartype) |
| def _forward( |
| self, |
| image: jt.Float[torch.Tensor, "B 3 img_h img_w"], |
| lidar: jt.Float[torch.Tensor, "B 1 bev_h bev_w"] | jt.Float[torch.Tensor, "B 2 bev_h bev_w"] | None, |
| ) -> tuple[jt.Float[torch.Tensor, "B D1 H1 W1"], jt.Float[torch.Tensor, "B D2 H2 W2"]]: |
| """ |
| Image + LiDAR feature fusion using transformers |
| Args: |
| image: RGB image. |
| lidar: Pseudo-image LiDAR. |
| Returns: |
| lidar_features: BEV feature map for planning. |
| image_features: Image feature map for perception. |
| """ |
| image_features = normalize_imagenet(image) |
| lidar_features = lidar |
|
|
| if self.config.channel_last: |
| image = image.to(memory_format=torch.channels_last) |
| if lidar is not None: |
| lidar = lidar.to(memory_format=torch.channels_last) |
|
|
| |
| image_layers = iter(self.image_encoder.items()) |
| if self.config.num_fusion_stages > 1: |
| lidar_layers = iter(self.lidar_encoder.items()) |
|
|
| |
| if len(self.image_encoder.return_layers) > 4: |
| image_features = self.forward_layer_block(image_layers, self.image_encoder.return_layers, image_features) |
| if self.config.num_fusion_stages > 1 and len(self.lidar_encoder.return_layers) > 4: |
| lidar_features = self.forward_layer_block(lidar_layers, self.lidar_encoder.return_layers, lidar_features) |
|
|
| |
| for i in range(4): |
| |
| image_features = self.forward_layer_block(image_layers, self.image_encoder.return_layers, image_features) |
| if self.config.num_fusion_stages > 1: |
| lidar_features = self.forward_layer_block(lidar_layers, self.lidar_encoder.return_layers, lidar_features) |
|
|
| |
| if i >= 4 - self.config.num_fusion_stages: |
| if self.config.num_fusion_stages == 1: |
| lidar_features = self.bev_queries.expand(image_features.shape[0], -1, -1, -1) |
| image_features, lidar_features = self.fuse_features( |
| image_features, lidar_features, i - (4 - self.config.num_fusion_stages) |
| ) |
|
|
| return lidar_features, image_features |
|
|
| @beartype |
| def forward_layer_block(self, layers, return_layers: dict[str, str], features: torch.Tensor) -> torch.Tensor: |
| """Run one forward pass to a block of layers from a TIMM neural network and returns the result. |
| Advances the whole network by just one block. |
| |
| Args: |
| layers: Iterator starting at the current layer block of the target network. |
| return_layers: TIMM dictionary describing at which intermediate layers features are returned. |
| features: Input features. |
| |
| Return: |
| torch.Tensor: Processed features |
| """ |
| for name, module in layers: |
| features = module(features) |
| if name in return_layers: |
| break |
| return features |
|
|
| def fuse_features( |
| self, image_features: torch.Tensor, lidar_features: torch.Tensor, layer_idx: int |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Perform a TransFuser feature fusion block using a Transformer module. |
| Args: |
| image_features: Features from the image branch |
| lidar_features: Features from the LiDAR branch |
| layer_idx: Transformer layer index. |
| Returns: |
| image_features and lidar_features with added features from the other branch. |
| """ |
| image_embd_layer = self.avgpool_img(image_features) |
| lidar_embd_layer = self.avgpool_lidar(lidar_features) |
| lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer) |
|
|
| image_features_layer, lidar_features_layer = self.transformers[layer_idx](image_embd_layer, lidar_embd_layer) |
|
|
| lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer) |
| image_features_layer = F.interpolate( |
| image_features_layer, size=(image_features.shape[2], image_features.shape[3]), mode="bilinear", align_corners=False |
| ) |
| lidar_features_layer = F.interpolate( |
| lidar_features_layer, size=(lidar_features.shape[2], lidar_features.shape[3]), mode="bilinear", align_corners=False |
| ) |
|
|
| image_features = image_features + image_features_layer |
| lidar_features = lidar_features + lidar_features_layer |
|
|
| return image_features, lidar_features |
|
|
|
|
| class GPT(nn.Module): |
| def __init__(self, n_embd, config): |
| super().__init__() |
| self.n_embd = n_embd |
| self.config = config |
| |
| self.pos_emb = nn.Parameter( |
| torch.zeros( |
| 1, |
| self.config.img_vert_anchors * self.config.img_horz_anchors |
| + self.config.lidar_vert_anchors * self.config.lidar_horz_anchors, |
| self.n_embd, |
| ) |
| ) |
| self.drop = nn.Dropout(config.embd_pdrop) |
| |
| self.blocks = nn.Sequential( |
| *[ |
| Block(n_embd, config.n_head, config.block_exp, config.attn_pdrop, config.resid_pdrop) |
| for layer in range(config.n_layer) |
| ] |
| ) |
| |
| self.ln_f = nn.LayerNorm(n_embd) |
| self.apply(self._init_weights) |
|
|
| def _init_weights(self, module): |
| if isinstance(module, nn.Linear): |
| module.weight.data.normal_(mean=self.config.gpt_linear_layer_init_mean, std=self.config.gpt_linear_layer_init_std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(self.config.gpt_layer_norm_init_weight) |
|
|
| def forward(self, image_tensor, lidar_tensor): |
| """ |
| Args: |
| image_tensor (tensor): B, C, H, W |
| lidar_tensor (tensor): B, C, H, W |
| """ |
| bz = lidar_tensor.shape[0] |
| lidar_h, lidar_w = lidar_tensor.shape[2:4] |
| img_h, img_w = image_tensor.shape[2:4] |
|
|
| image_tensor = image_tensor.permute(0, 2, 3, 1).contiguous().view(bz, -1, self.n_embd) |
| lidar_tensor = lidar_tensor.permute(0, 2, 3, 1).contiguous().view(bz, -1, self.n_embd) |
|
|
| token_embeddings = torch.cat((image_tensor, lidar_tensor), dim=1) |
|
|
| x = self.drop(self.pos_emb + token_embeddings) |
| x = self.blocks(x) |
| x = self.ln_f(x) |
|
|
| image_tensor_out = ( |
| x[:, : self.config.img_vert_anchors * self.config.img_horz_anchors, :] |
| .view(bz, img_h, img_w, -1) |
| .permute(0, 3, 1, 2) |
| .contiguous() |
| ) |
| lidar_tensor_out = ( |
| x[:, self.config.img_vert_anchors * self.config.img_horz_anchors :, :] |
| .view(bz, lidar_h, lidar_w, -1) |
| .permute(0, 3, 1, 2) |
| .contiguous() |
| ) |
|
|
| return image_tensor_out, lidar_tensor_out |
|
|
|
|
| class Block(nn.Module): |
| def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): |
| super().__init__() |
| self.ln1 = nn.LayerNorm(n_embd) |
| self.ln2 = nn.LayerNorm(n_embd) |
| self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) |
| self.mlp = nn.Sequential( |
| nn.Linear(n_embd, block_exp * n_embd), |
| nn.ReLU(True), |
| nn.Linear(block_exp * n_embd, n_embd), |
| nn.Dropout(resid_pdrop), |
| ) |
|
|
| def forward(self, x): |
| x = x + self.attn(self.ln1(x)) |
| x = x + self.mlp(self.ln2(x)) |
| return x |
|
|
|
|
| class SelfAttention(nn.Module): |
| def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): |
| super().__init__() |
| assert n_embd % n_head == 0 |
| |
| self.key = nn.Linear(n_embd, n_embd) |
| self.query = nn.Linear(n_embd, n_embd) |
| self.value = nn.Linear(n_embd, n_embd) |
| |
| self.dropout = attn_pdrop |
| self.resid_drop = nn.Dropout(resid_pdrop) |
| |
| self.proj = nn.Linear(n_embd, n_embd) |
| self.n_head = n_head |
|
|
| def forward(self, x): |
| b, t, c = x.size() |
| |
| |
| k = self.key(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2) |
| q = self.query(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2) |
| v = self.value(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2) |
|
|
| |
| y = torch.nn.functional.scaled_dot_product_attention( |
| q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False |
| ) |
| y = y.transpose(1, 2).contiguous().view(b, t, c) |
|
|
| |
| y = self.resid_drop(self.proj(y)) |
| return y |
|
|
|
|
| class PlanningDecoder(nn.Module): |
| @beartype |
| def __init__(self, input_bev_channels: int, config: TrainingConfig, device: torch.device): |
| super().__init__() |
| self.device = device |
| self.config = config |
| self.planning_context_encoder = PlanningContextEncoder( |
| config=self.config, input_bev_channels=input_bev_channels, device=self.device |
| ) |
|
|
| num_queries = 0 |
| if self.config.predict_temporal_spatial_waypoints: |
| num_queries += self.config.num_way_points_prediction |
|
|
| self.query = nn.Parameter( |
| torch.zeros( |
| 1, |
| num_queries, |
| self.config.transfuser_token_dim, |
| ) |
| ) |
|
|
| self.transformer_decoder = torch.nn.TransformerDecoder( |
| decoder_layer=nn.TransformerDecoderLayer( |
| self.config.transfuser_token_dim, |
| self.config.transfuser_num_bev_cross_attention_heads, |
| activation=nn.GELU(), |
| batch_first=True, |
| ), |
| num_layers=self.config.transfuser_num_bev_cross_attention_layers, |
| norm=nn.LayerNorm(self.config.transfuser_token_dim), |
| ) |
|
|
| |
| if self.config.predict_temporal_spatial_waypoints: |
| self.wp_decoder = nn.Linear(config.transfuser_token_dim, 2) |
| if self.config.use_navsim_data: |
| self.heading_decoder = nn.Linear(config.transfuser_token_dim, 1) |
|
|
| @beartype |
| def forward( |
| self, |
| bev_features: jt.Float[torch.Tensor, "bs bev_dim height_bev width_bev"], |
| data: dict, |
| log: dict, |
| ) -> tuple[ |
| jt.Float[torch.Tensor, "B n_waypoints 2"], |
| jt.Float[torch.Tensor, "B n_waypoints"] | None, |
| ]: |
| """ |
| Args: |
| bev_features: BEV features. |
| radar_features: Radar features. |
| radar_predictions: Radar predictions. |
| data: dict |
| log: dict |
| Returns: |
| waypoints: Spatial and temporal path. |
| headings: Heading predictions (if using NavSim data). |
| """ |
| self.kv = context_tokens = self.planning_context_encoder(bev_features=bev_features, data=data) |
|
|
| bs = context_tokens.shape[0] |
|
|
| queries = self.transformer_decoder(self.query.repeat(bs, 1, 1), context_tokens) |
|
|
| |
| query_idx = 0 |
| waypoints = None |
| headings = None |
|
|
| if self.config.predict_temporal_spatial_waypoints: |
| waypoints_queries = queries[:, query_idx : query_idx + self.config.num_way_points_prediction] |
| waypoints = torch.cumsum(self.wp_decoder(waypoints_queries), 1) |
| if self.config.use_navsim_data: |
| headings = torch.cumsum(self.heading_decoder(waypoints_queries), 1) |
| query_idx += self.config.num_way_points_prediction |
|
|
| return (waypoints, headings.squeeze(-1) if headings is not None else None) |
|
|
|
|
| class PlanningContextEncoder(nn.Module): |
| @beartype |
| def __init__(self, config: TrainingConfig, input_bev_channels: int, device: torch.device): |
| super().__init__() |
| self.device = device |
| self.config: TrainingConfig = config |
|
|
| self.num_status_tokens = 0 |
|
|
| if self.config.use_velocity: |
| self.num_status_tokens += 1 |
| self.velocity_encoder = nn.Sequential( |
| nn.Linear(1, self.config.transfuser_token_dim), |
| ) |
|
|
| if self.config.use_acceleration: |
| self.num_status_tokens += 1 |
| self.acceleration_encoder = nn.Sequential( |
| nn.Linear(1, self.config.transfuser_token_dim), |
| ) |
|
|
| if self.config.use_discrete_command: |
| self.num_status_tokens += 1 |
| self.command_encoder = nn.Sequential(nn.Linear(self.config.discrete_command_dim, self.config.transfuser_token_dim)) |
|
|
| self.cosine_pos_embeding = PositionEmbeddingSine(config, self.config.transfuser_token_dim // 2, normalize=True) |
| self.status_pos_embedding = nn.Parameter(torch.zeros(1, self.num_status_tokens, self.config.transfuser_token_dim)) |
|
|
| self.dimension_adapter = nn.Conv2d(input_bev_channels, self.config.transfuser_token_dim, kernel_size=1) |
| self.reset_parameters() |
|
|
| self.target_points_normalization_constants = torch.tensor( |
| self.config.target_points_normalization_constants, device=self.device, dtype=self.config.torch_float_type |
| ) |
|
|
| def reset_parameters(self): |
| nn.init.uniform_(self.status_pos_embedding) |
|
|
| @beartype |
| def forward( |
| self, |
| bev_features: jt.Float[torch.Tensor, "B C H W"], |
| data: dict, |
| ) -> jt.Float[torch.Tensor, "B N D"]: |
| """ |
| Args: |
| bev_features: Raw BEV features. |
| radar_logits: Radar logits. |
| radar_predictions: Radar predictions. |
| data: dict |
| log: dict |
| Returns: |
| context_tokens: Output tokens for planning transformer decoder. |
| """ |
| |
| if self.config.use_velocity: |
| velocity = data["speed"].reshape(-1, 1).to(self.device, dtype=self.config.torch_float_type) |
| if self.config.use_discrete_command: |
| command = data["command"].to(self.device, dtype=self.config.torch_float_type) |
|
|
| status_tokens = [] |
|
|
| |
| if self.config.use_velocity: |
| velocity_token = self.velocity_encoder(velocity / self.config.max_speed).reshape( |
| -1, 1, self.config.transfuser_token_dim |
| ) |
| status_tokens.append(velocity_token) |
|
|
| |
| if self.config.use_acceleration: |
| acceleration = data["acceleration"].reshape(-1, 1).to(self.device, dtype=self.config.torch_float_type) |
| acceleration_token = self.acceleration_encoder(acceleration / self.config.max_acceleration).reshape( |
| -1, 1, self.config.transfuser_token_dim |
| ) |
| status_tokens.append(acceleration_token) |
|
|
| |
| if self.config.use_discrete_command: |
| command_token = self.command_encoder(command).reshape( |
| -1, 1, self.config.transfuser_token_dim |
| ) |
| status_tokens.append(command_token) |
|
|
| |
| has_statuses = False |
| if len(status_tokens) > 0: |
| status_tokens = torch.cat(status_tokens, dim=1) |
| has_statuses = True |
|
|
| |
| context_tokens = self.dimension_adapter(bev_features) |
|
|
| |
| if has_statuses: |
| context_tokens = context_tokens + self.cosine_pos_embeding( |
| context_tokens |
| ) |
| context_tokens = torch.flatten(context_tokens, start_dim=2) |
| context_tokens = torch.permute(context_tokens, (0, 2, 1)) |
|
|
| status_tokens = status_tokens + self.status_pos_embedding |
| context_tokens = torch.cat( |
| [context_tokens, status_tokens], dim=1 |
| ) |
|
|
| return context_tokens |
|
|
|
|
| class PositionEmbeddingSine(nn.Module): |
| def __init__(self, config: TrainingConfig, num_pos_feats=64, temperature=10000, normalize=False, scale=None): |
| super().__init__() |
| self.config = config |
| self.num_pos_feats = num_pos_feats |
| self.temperature = temperature |
| self.normalize = normalize |
| if scale is not None and normalize is False: |
| raise ValueError("normalize should be True if scale is passed") |
| if scale is None: |
| scale = 2 * math.pi |
| self.scale = scale |
|
|
| def forward(self, tensor: torch.Tensor): |
| x = tensor |
| bs, _, h, w = x.shape |
| not_mask = torch.ones((bs, h, w), device=x.device) |
| y_embed = not_mask.cumsum(1, dtype=torch.float32) |
| x_embed = not_mask.cumsum(2, dtype=torch.float32) |
| if self.normalize: |
| eps = 1e-6 |
| y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale |
| x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale |
|
|
| dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) |
| dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / self.num_pos_feats) |
|
|
| pos_x = x_embed[:, :, :, None] / dim_t |
| pos_y = y_embed[:, :, :, None] / dim_t |
| pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) |
| pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) |
| pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) |
| return pos.to(self.config.torch_float_type).contiguous() |
|
|
|
|
| @jt.jaxtyped(typechecker=beartype) |
| @dataclass |
| class Prediction: |
| """Raw output predictions from the model.""" |
|
|
| |
| pred_future_waypoints: jt.Float[torch.Tensor, "bs n_waypoints 2"] | None |
| pred_headings: jt.Float[torch.Tensor, "bs n_waypoints"] | None |
|
|
|
|
| class Model(nn.Module): |
| @beartype |
| def __init__( |
| self, |
| device: torch.device, |
| config: TrainingConfig, |
| ): |
| super().__init__() |
| self.device = device |
| self.config = config |
| self.log = {} |
|
|
| self.backbone = TransfuserBackbone(self.device, self.config) |
|
|
| self.planning_decoder = PlanningDecoder( |
| input_bev_channels=self.backbone.num_lidar_features, config=self.config, device=self.device |
| ).to(self.device) |
|
|
| @beartype |
| def forward(self, data: dict[str, typing.Any]) -> Prediction: |
| self.log = {} |
| pred_future_waypoints = pred_headings = None |
|
|
| bev_features, image_features = self.backbone(data) |
|
|
| |
| if self.config.use_planning_decoder: |
| (pred_future_waypoints, pred_headings) = self.planning_decoder(bev_features, data, log=self.log) |
|
|
| |
| return Prediction( |
| |
| pred_future_waypoints=pred_future_waypoints, |
| pred_headings=pred_headings, |
| ) |
|
|
|
|
| def load_tf(model_path: str, device: torch.device): |
| base_dir = os.path.dirname(model_path) |
| with open(os.path.join(base_dir, "config.json")) as f: |
| loaded_config = json.load(f) |
| config_training = TrainingConfig(loaded_config) |
| |
| model = Model(device=device, config=config_training) |
| model.load_state_dict(torch.load(model_path), strict=False) |
| model.to(device) |
| model.eval() |
| return model |
|
|
|
|
| class NavsimData(torch.utils.data.Dataset): |
| """Data loader for NavSim data""" |
|
|
| def __init__(self, root, config: TrainingConfig): |
| self.root = root |
| self.config = config |
|
|
| self.feature = glob.glob(os.path.join(self.root, "**/transfuser_feature.gz"), recursive=True) |
| self.target = glob.glob(os.path.join(self.root, "**/transfuser_target.gz"), recursive=True) |
|
|
| self.size = len(self.feature) |
|
|
| def __len__(self): |
| return self.size |
|
|
| def __getitem__(self, index): |
| |
| feature_path = self.feature[index] |
| with gzip.open(feature_path, "rb") as f: |
| feature = pickle.load(f) |
|
|
| rgb = cv2.imdecode(np.frombuffer(feature["camera_feature"], np.uint8), cv2.IMREAD_COLOR) |
| rgb = np.transpose(rgb, (2, 0, 1)) |
|
|
| data = { |
| "rgb": rgb, |
| "command": feature["status_feature"][:4], |
| "speed": np.linalg.norm(feature["status_feature"][4:6]), |
| "acceleration": np.linalg.norm(feature["status_feature"][6:8]), |
| } |
| return data |
|
|