import glob
import gzip
import json
import math
import os
import pickle
import typing
from dataclasses import dataclass
from enum import IntEnum, auto

import cv2
import jaxtyping as jt
import numpy as np
import timm
import torch
import torch.nn.functional as F
from beartype import beartype
from omegaconf import OmegaConf
from torch import nn


class CameraPointCloudIndex(IntEnum):
    """Index to access the point cloud array of a camera."""

    X = 0
    Y = 1
    Z = 2
    UNREAL_SEMANTICS_ID = 3
    UNREAL_INSTANCE_ID = 4


class TargetDataset(IntEnum):
    """Dataset we target for data collection, training, and evaluation."""

    UNKNOWN = 0
    CARLA_LEADERBOARD2_3CAMERAS = auto()
    CARLA_LEADERBOARD2_6CAMERAS = auto()
    NAVSIM_4CAMERAS = auto()
    WAYMO_E2E_2025_3CAMERAS = auto()


class SourceDataset(IntEnum):
    """Dataset a sample originates from."""

    UNKNOWN = 0
    CARLA = auto()
    NAVSIM = auto()
    WAYMO_E2E_2025 = auto()


SOURCE_DATASET_NAME_MAP = {
    SourceDataset.CARLA: "carla",
    SourceDataset.NAVSIM: "navsim",
    SourceDataset.WAYMO_E2E_2025: "waymo_e2e_2025",
}


class RadarLabels(IntEnum):
    """Index to access the radar label array."""

    X = 0
    Y = 1
    V = 2
    VALID = 3


class RadarDataIndex(IntEnum):
    """Index to access the radar data array."""

    X = 0
    Y = 1
    Z = 2
    V = 3
    SENSOR_ID = 4


class TransfuserBoundingBoxIndex(IntEnum):
    """Index to access the bounding box array of TransFuser."""

    X = 0
    Y = 1
    W = 2
    H = 3
    YAW = 4
    VELOCITY = 5
    BRAKE = 6
    CLASS = 7
    SCORE = 8
    VISIBLE_PIXELS = 8  # NOTE: same value as SCORE; IntEnum treats this as an alias of SCORE.
    NUM_POINTS = 9
    ID = 10
    NUM_RADAR_POINTS = 11


class TransfuserBoundingBoxClass(IntEnum):
    """Bounding box classes used in TransFuser."""

    VEHICLE = 0
    WALKER = 1
    TRAFFIC_LIGHT = 2
    STOP_SIGN = 3
    SPECIAL = 4
    OBSTACLE = 5
    PARKING = 6
    BIKER = 7


class CarlaNavigationCommand(IntEnum):
    """Source: https://carla.org/Doxygen/html/d0/db7/namespacecarla_1_1traffic__manager.html#a5734807dba08623eeca046a963ade360"""

    UNKNOWN = 0
    LEFT = 1
    RIGHT = 2
    STRAIGHT = 3
    LANEFOLLOW = 4
    CHANGELANELEFT = 5
    CHANGELANERIGHT = 6


class ChaffeurNetBEVSemanticClass(IntEnum):
    """Indices to access the BEV semantic map produced by ChaffeurNet."""

    UNLABELED = 0
    ROAD = 1
    SIDEWALK = 2
    LANE_MARKERS = 3
    LANE_MARKERS_BROKEN = 4
    STOP_SIGNS = 5
    TRAFFIC_GREEN = 6
    TRAFFIC_YELLOW = 7
    TRAFFIC_RED = 8


class TransfuserBEVSemanticClass(IntEnum):
    """Indices to access the BEV semantic map produced by TransFuser."""

    UNLABELED = 0
    ROAD = 1
    LANE_MARKERS = 2
    STOP_SIGNS = 3
    VEHICLE = 4
    WALKER = 5
    OBSTACLE = 6
    PARKING_VEHICLE = 7
    SPECIAL_VEHICLE = 8
    BIKER = 9
    TRAFFIC_GREEN = 10
    TRAFFIC_RED_NORMAL = 11
    TRAFFIC_RED_NOT_NORMAL = 12


class TransfuserBEVOccupancyClass(IntEnum):
    """Indices to access the BEV occupancy map produced by TransFuser."""

    UNLABELED = 0
    VEHICLE = 1
    WALKER = 2
    OBSTACLE = 3
    PARKING_VEHICLE = 4
    SPECIAL_VEHICLE = 5
    BIKER = 6
    TRAFFIC_GREEN = 7
    TRAFFIC_RED_NORMAL = 8
    TRAFFIC_RED_NOT_NORMAL = 9


class WeatherVisibility(IntEnum):
    """Visibility conditions in the CARLA simulator, as classified by TransFuser."""

    CLEAR = 0
    OK = 1
    LIMITED = 2
    VERY_LIMITED = 3


class CarlaSemanticSegmentationClass(IntEnum):
    """https://carla.readthedocs.io/en/latest/ref_sensors/#semantic-segmentation-camera"""

    Unlabeled = 0
    Roads = 1
    SideWalks = 2
    Building = 3
    Wall = 4
    Fence = 5
    Pole = 6
    TrafficLight = 7
    TrafficSign = 8
    Vegetation = 9
    Terrain = 10
    Sky = 11
    Pedestrian = 12
    Rider = 13
    Car = 14
    Truck = 15
    Bus = 16
    Train = 17
    Motorcycle = 18
    Bicycle = 19
    Static = 20
    Dynamic = 21
    Other = 22
    Water = 23
    RoadLine = 24
    Ground = 25
    Bridge = 26
    RailTrack = 27
    GuardRail = 28
    ConeAndTrafficWarning = 29
    SpecialVehicles = 30
    StopSign = 31


class TransfuserSemanticSegmentationClass(IntEnum):
    """Semantic segmentation classes used in TransFuser."""

    UNLABELED = 0
    VEHICLE = 1
    ROAD = 2
    TRAFFIC_LIGHT = 3
    PEDESTRIAN = 4
    ROAD_LINE = 5
    OBSTACLE = 6
    SPECIAL_VEHICLE = 7
    STOP_SIGN = 8
    BIKER = 9


def rgb(r, g, b):
    """Dummy function to help with visualizing RGB colors via a VSCode extension."""
    return (r, g, b)


LIDAR_COLOR = rgb(90, 107, 249)
EGO_BB_COLOR = rgb(151, 15, 48)
TP_DEFAULT_COLOR = rgb(255, 10, 10)
RADAR_COLOR = rgb(24, 237, 3)
RADAR_DETECTION_COLOR = rgb(255, 24, 0)

GROUNDTRUTH_BB_WP_COLOR = rgb(15, 60, 255)
GROUNDTRUTH_FUTURE_WAYPOINT_COLOR = rgb(0, 0, 0)
GROUND_TRUTH_PAST_WAYPOINT_COLOR = rgb(0, 0, 0)
PREDICTION_WAYPOINT_COLOR = rgb(255, 0, 0)
PREDICTION_ROUTE_COLOR = rgb(0, 0, 255)
PREDICTION_WAYPOINT_RADIUS = 10
PREDICTION_ROUTE_RADIUS = 6

CARLA_TRANSFUSER_BEV_SEMANTIC_COLOR_CONVERTER = {
    TransfuserBEVSemanticClass.UNLABELED: rgb(0, 0, 0),
    TransfuserBEVSemanticClass.ROAD: rgb(250, 250, 250),
    TransfuserBEVSemanticClass.LANE_MARKERS: rgb(255, 255, 0),
    TransfuserBEVSemanticClass.STOP_SIGNS: rgb(160, 160, 0),
    TransfuserBEVSemanticClass.VEHICLE: rgb(15, 60, 255),
    TransfuserBEVSemanticClass.WALKER: rgb(0, 255, 0),
    TransfuserBEVSemanticClass.OBSTACLE: rgb(255, 0, 0),
    TransfuserBEVSemanticClass.PARKING_VEHICLE: rgb(116, 150, 65),
    TransfuserBEVSemanticClass.SPECIAL_VEHICLE: rgb(255, 0, 255),
    TransfuserBEVSemanticClass.BIKER: rgb(255, 0, 0),
    TransfuserBEVSemanticClass.TRAFFIC_GREEN: rgb(0, 255, 0),
    TransfuserBEVSemanticClass.TRAFFIC_RED_NORMAL: rgb(255, 0, 0),
    TransfuserBEVSemanticClass.TRAFFIC_RED_NOT_NORMAL: rgb(0, 0, 255),
}
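

# Hypothetical visualization sketch (assumed name, not part of the original API):
# map a BEV semantic class map (H x W array of integer class ids) to an RGB image
# via the color table above, using a palette lookup instead of per-pixel dict access.
def _example_colorize_bev_semantics(bev_semantics: np.ndarray) -> np.ndarray:
    palette = np.zeros((len(TransfuserBEVSemanticClass), 3), dtype=np.uint8)
    for bev_class, color in CARLA_TRANSFUSER_BEV_SEMANTIC_COLOR_CONVERTER.items():
        palette[bev_class] = color
    # Fancy indexing broadcasts the palette over every pixel.
    return palette[bev_semantics]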


TRANSFUSER_SEMANTIC_COLORS = {
    TransfuserSemanticSegmentationClass.UNLABELED: rgb(0, 0, 0),
    TransfuserSemanticSegmentationClass.VEHICLE: rgb(31, 119, 180),
    TransfuserSemanticSegmentationClass.ROAD: rgb(128, 64, 128),
    TransfuserSemanticSegmentationClass.TRAFFIC_LIGHT: rgb(250, 170, 30),
    TransfuserSemanticSegmentationClass.PEDESTRIAN: rgb(0, 255, 60),
    TransfuserSemanticSegmentationClass.ROAD_LINE: rgb(157, 234, 50),
    TransfuserSemanticSegmentationClass.OBSTACLE: rgb(255, 0, 0),
    TransfuserSemanticSegmentationClass.SPECIAL_VEHICLE: rgb(255, 255, 0),
    TransfuserSemanticSegmentationClass.STOP_SIGN: rgb(125, 0, 0),
    TransfuserSemanticSegmentationClass.BIKER: rgb(220, 20, 60),
}

TRANSFUSER_BOUNDING_BOX_COLORS = {
    TransfuserBoundingBoxClass.VEHICLE: rgb(0, 0, 255),
    TransfuserBoundingBoxClass.WALKER: rgb(0, 255, 0),
    TransfuserBoundingBoxClass.TRAFFIC_LIGHT: rgb(255, 0, 0),
    TransfuserBoundingBoxClass.STOP_SIGN: rgb(250, 160, 160),
    TransfuserBoundingBoxClass.SPECIAL: rgb(0, 0, 255),
    TransfuserBoundingBoxClass.OBSTACLE: rgb(0, 255, 13),
    TransfuserBoundingBoxClass.PARKING: rgb(116, 150, 65),
    TransfuserBoundingBoxClass.BIKER: rgb(255, 0, 0),
}

SEMANTIC_SEGMENTATION_CONVERTER = {
    CarlaSemanticSegmentationClass.Unlabeled: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Roads: TransfuserSemanticSegmentationClass.ROAD,
    CarlaSemanticSegmentationClass.SideWalks: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Building: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Wall: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Fence: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Pole: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.TrafficLight: TransfuserSemanticSegmentationClass.TRAFFIC_LIGHT,
    CarlaSemanticSegmentationClass.TrafficSign: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Vegetation: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Terrain: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Sky: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Pedestrian: TransfuserSemanticSegmentationClass.PEDESTRIAN,
    CarlaSemanticSegmentationClass.Rider: TransfuserSemanticSegmentationClass.BIKER,
    CarlaSemanticSegmentationClass.Car: TransfuserSemanticSegmentationClass.VEHICLE,
    CarlaSemanticSegmentationClass.Truck: TransfuserSemanticSegmentationClass.VEHICLE,
    CarlaSemanticSegmentationClass.Bus: TransfuserSemanticSegmentationClass.VEHICLE,
    CarlaSemanticSegmentationClass.Train: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Motorcycle: TransfuserSemanticSegmentationClass.VEHICLE,
    CarlaSemanticSegmentationClass.Bicycle: TransfuserSemanticSegmentationClass.BIKER,
    CarlaSemanticSegmentationClass.Static: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Dynamic: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Other: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Water: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.RoadLine: TransfuserSemanticSegmentationClass.ROAD_LINE,
    CarlaSemanticSegmentationClass.Ground: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.Bridge: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.RailTrack: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.GuardRail: TransfuserSemanticSegmentationClass.UNLABELED,
    CarlaSemanticSegmentationClass.ConeAndTrafficWarning: TransfuserSemanticSegmentationClass.OBSTACLE,
    CarlaSemanticSegmentationClass.SpecialVehicles: TransfuserSemanticSegmentationClass.SPECIAL_VEHICLE,
    CarlaSemanticSegmentationClass.StopSign: TransfuserSemanticSegmentationClass.STOP_SIGN,
}
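

# Hypothetical helper (assumed name, not part of the original API): remap a CARLA
# semantic segmentation image to TransFuser classes with a vectorized lookup table
# instead of a per-pixel dict access.
def _example_convert_carla_semantics(carla_semantics: np.ndarray) -> np.ndarray:
    lookup = np.zeros(max(SEMANTIC_SEGMENTATION_CONVERTER) + 1, dtype=np.uint8)
    for carla_class, transfuser_class in SEMANTIC_SEGMENTATION_CONVERTER.items():
        lookup[carla_class] = transfuser_class
    # `carla_semantics` is assumed to be an integer array of CARLA class ids.
    return lookup[carla_semantics]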


CHAFFEURNET_TO_TRANSFUSER_BEV_SEMANTIC_CONVERTER = {
    ChaffeurNetBEVSemanticClass.UNLABELED: TransfuserBEVSemanticClass.UNLABELED,
    ChaffeurNetBEVSemanticClass.ROAD: TransfuserBEVSemanticClass.ROAD,
    ChaffeurNetBEVSemanticClass.SIDEWALK: TransfuserBEVSemanticClass.UNLABELED,
    ChaffeurNetBEVSemanticClass.LANE_MARKERS: TransfuserBEVSemanticClass.LANE_MARKERS,
    ChaffeurNetBEVSemanticClass.LANE_MARKERS_BROKEN: TransfuserBEVSemanticClass.LANE_MARKERS,
    ChaffeurNetBEVSemanticClass.STOP_SIGNS: TransfuserBEVSemanticClass.STOP_SIGNS,
    ChaffeurNetBEVSemanticClass.TRAFFIC_GREEN: TransfuserBEVSemanticClass.UNLABELED,
    ChaffeurNetBEVSemanticClass.TRAFFIC_YELLOW: TransfuserBEVSemanticClass.UNLABELED,
    ChaffeurNetBEVSemanticClass.TRAFFIC_RED: TransfuserBEVSemanticClass.UNLABELED,
}

SCENARIO_TYPES = [
    "Accident",
    "AccidentTwoWays",
    "BlockedIntersection",
    "ConstructionObstacle",
    "ConstructionObstacleTwoWays",
    "ControlLoss",
    "CrossJunctionDefectTrafficLight",
    "CrossingBicycleFlow",
    "DynamicObjectCrossing",
    "EnterActorFlow",
    "EnterActorFlowV2",
    "HardBreakRoute",
    "HazardAtSideLane",
    "HazardAtSideLaneTwoWays",
    "HighwayCutIn",
    "HighwayExit",
    "InterurbanActorFlow",
    "InterurbanAdvancedActorFlow",
    "InvadingTurn",
    "MergerIntoSlowTraffic",
    "MergerIntoSlowTrafficV2",
    "NonSignalizedJunctionLeftTurn",
    "NonSignalizedJunctionLeftTurnEnterFlow",
    "NonSignalizedJunctionRightTurn",
    "noScenarios",
    "OppositeVehicleRunningRedLight",
    "OppositeVehicleTakingPriority",
    "ParkedObstacle",
    "ParkedObstacleTwoWays",
    "ParkingCrossingPedestrian",
    "ParkingCutIn",
    "ParkingExit",
    "PedestrianCrossing",
    "PriorityAtJunction",
    "RedLightWithoutLeadVehicle",
    "SequentialLaneChange",
    "SignalizedJunctionLeftTurn",
    "SignalizedJunctionLeftTurnEnterFlow",
    "SignalizedJunctionRightTurn",
    "StaticCutIn",
    "T_Junction",
    "VanillaNonSignalizedTurn",
    "VanillaNonSignalizedTurnEncounterStopsign",
    "VanillaSignalizedTurnEncounterGreenLight",
    "VanillaSignalizedTurnEncounterRedLight",
    "VehicleOpensDoorTwoWays",
    "VehicleTurningRoute",
    "VehicleTurningRoutePedestrian",
    "YieldToEmergencyVehicle",
    "NA",
]

EMERGENCY_MESHES = {
    "vehicle.dodge.charger_police_2020",
    "vehicle.dodge.charger_police",
    "vehicle.ford.ambulance",
    "vehicle.carlamotors.firetruck",
}

CONSTRUCTION_MESHES = {"static.prop.constructioncone", "static.prop.trafficwarning"}

BIKER_MESHES = {
    "vehicle.diamondback.century",
    "vehicle.gazelle.omafiets",
    "vehicle.bh.crossbike",
    "vehicle.harley-davidson.low_rider",
    "vehicle.kawasaki.ninja",
    "vehicle.vespa.zx125",
    "vehicle.yamaha.yzf",
}

URBAN_MAX_SPEED_LIMIT = 15
SUBURBAN_MAX_SPEED_LIMIT = 25
HIGHWAY_MAX_SPEED_LIMIT = 35

# Mesh path -> bounding box extent [x, y, z]; the values appear to be half sizes
# in meters, matching CARLA actor extents.
LOOKUP_TABLE = {
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Lincoln/SM_LincolnParked.SM_LincolnParked": [
        2.44619083404541,
        1.115301489830017,
        0.7606233954429626,
    ],
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Charger/SM_ChargerParked.SM_ChargerParked": [
        2.5039126873016357,
        1.0485419034957886,
        0.7673624753952026,
    ],
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/VolkswagenT2/SM_VolkswagenT2_2021_Parked.SM_VolkswagenT2_2021_Parked": [
        2.2210919857025146,
        0.9388753771781921,
        0.9936029314994812,
    ],
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/FordCrown/SM_FordCrown_parked.SM_FordCrown_parked": [
        2.6828393936157227,
        0.9732309579849243,
        0.7874829173088074,
    ],
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/NissanPatrol2021/SM_NissanPatrol2021_parked.SM_NissanPatrol2021_parked": [
        2.782914400100708,
        1.217571496963501,
        1.022573471069336,
    ],
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/MercedesCCC/SM_MercedesCCC_Parked.SM_MercedesCCC_Parked": [
        2.3368194103240967,
        1.0011461973190308,
        0.7259762287139893,
    ],
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/TeslaM3/SM_TeslaM3_parked.SM_TeslaM3_parked": [
        2.3958897590637207,
        1.081725001335144,
        0.7438300848007202,
    ],
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Mini2021/SM_Mini2021_parked.SM_Mini2021_parked": [
        2.2763495445251465,
        1.0926425457000732,
        0.8835831880569458,
    ],
    "/Game/Carla/Static/Dynamic/Garden/SM_PlasticTable.SM_PlasticTable": [
        1.241101622581482,
        1.241101622581482,
        1.239898920059204,
    ],
    "/Game/Carla/Static/Dynamic/Garden/SM_PlasticChair.SM_PlasticChair": [
        0.36523768305778503,
        0.37522444128990173,
        0.6356779336929321,
    ],
    "/Game/Carla/Static/Dynamic/Construction/SM_ConstructionCone.SM_ConstructionCone": [
        0.1720348298549652,
        0.1720348298549652,
        0.2928849756717682,
    ],
}

CONSTRUCTION_CONE_BB_SIZE = [0.1720348298549652, 0.1720348298549652]

TRAFFIC_WARNING_BB_SIZE = [1.186714768409729, 1.4352929592132568]

OLD_TOWNS = {
    "Town01",
    "Town02",
    "Town03",
    "Town04",
    "Town05",
    "Town06",
    "Town07",
    "Town10HD",
}

ALL_TOWNS = [
    "Town01",
    "Town02",
    "Town03",
    "Town04",
    "Town05",
    "Town06",
    "Town07",
    "Town10HD",
    "Town11",
    "Town12",
    "Town13",
    "Town15",
]

TOWN_NAME_TO_INDEX = {
    "Town01": "01",
    "Town02": "02",
    "Town03": "03",
    "Town04": "04",
    "Town05": "05",
    "Town06": "06",
    "Town07": "07",
    "Town10HD": "10HD",
    "Town11": "11",
    "Town12": "12",
    "Town13": "13",
    "Town15": "15",
}

NUPLAN_CAMERA_CALIBRATION = {
    "CAM_L0": {
        "pos": [0.2567803445203347, -0.14912709068475835, 1.9611907818710856],
        "rot": [-1.7908743283844297, -0.18030832979657968, -56.0098600797478],
        "fov": 63.71,
        "width": 1920 // 4,
        "height": 1120 // 4,
        "cropped_height": 1080 // 4,
    },
    "CAM_F0": {
        "pos": [0.3435966588946209, 0.00981503465349912, 1.9520959734648988],
        "rot": [0.004047525205852703, -2.344563417746492, -1.0253844128360994],
        "fov": 63.71,
        "width": 1920 // 4,
        "height": 1120 // 4,
        "cropped_height": 1080 // 4,
    },
    "CAM_R0": {
        "pos": [0.362775037177945, 0.16380984892156114, 1.9540305698009064],
        "rot": [0.8908895493227222, -0.5177262293066659, 54.26239302855094],
        "fov": 63.71,
        "width": 1920 // 4,
        "height": 1120 // 4,
        "cropped_height": 1080 // 4,
    },
    "CAM_B0": {
        "pos": [-1.8371898892894447, 0.023124645489646514, 1.9105230244574516],
        "rot": [0.004047525205852703, 1.9819092884563787, 180],
        "fov": 63.71,
        "width": 1920 // 4,
        "height": 1120 // 4,
        "cropped_height": 1080 // 4,
    },
}
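
# Worked example for the calibration above: the `// 4` downsampling maps the
# 1920 x 1120 source resolution to 480 x 280 pixels, of which 1080 // 4 = 270 rows
# remain after cropping, i.e. the bottom 10 rows are removed.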


class NavSimBEVSemanticClass(IntEnum):
    """Indices to access the BEV semantic map produced by NavSim.

    See: https://github.com/autonomousvision/navsim/blob/main/navsim/agents/transfuser/transfuser_config.py#L83
    """

    UNLABELED = 0
    ROAD = auto()
    WALKWAYS = auto()
    CENTERLINE = auto()
    STATIC_OBJECTS = auto()
    VEHICLES = auto()
    PEDESTRIANS = auto()


class NavSimBoundingBoxIndex(IntEnum):
    """Index to access the NavSim bounding box attribute array.

    See: https://github.com/autonomousvision/navsim/blob/main/navsim/agents/transfuser/transfuser_features.py#L174
    """

    X = 0
    Y = 1
    HEADING = 2
    LENGTH = 3
    WIDTH = 4


class NavSimBBClass(IntEnum):
    """Bounding box classes used in NavSim."""

    GENERIC_CLASS = 0


class NavSimStatusFeature(IntEnum):
    """Status feature indices used in NavSim."""

    DRIVING_COMMAND_LEFT = 0
    DRIVING_COMMAND_RIGHT = 1
    DRIVING_COMMAND_STRAIGHT = 2
    DRIVING_COMMAND_UNKNOWN = 3
    EGO_VELOCITY_X = 4
    EGO_VELOCITY_Y = 5
    ACCELERATION_X = 6
    ACCELERATION_Y = 7


SIM2REAL_SEMANTIC_SEGMENTATION_CONVERTER = {
    TransfuserSemanticSegmentationClass.UNLABELED: TransfuserSemanticSegmentationClass.UNLABELED,
    TransfuserSemanticSegmentationClass.VEHICLE: TransfuserSemanticSegmentationClass.VEHICLE,
    TransfuserSemanticSegmentationClass.ROAD: TransfuserSemanticSegmentationClass.ROAD,
    TransfuserSemanticSegmentationClass.TRAFFIC_LIGHT: TransfuserSemanticSegmentationClass.TRAFFIC_LIGHT,
    TransfuserSemanticSegmentationClass.PEDESTRIAN: TransfuserSemanticSegmentationClass.PEDESTRIAN,
    TransfuserSemanticSegmentationClass.ROAD_LINE: TransfuserSemanticSegmentationClass.ROAD_LINE,
    TransfuserSemanticSegmentationClass.OBSTACLE: TransfuserSemanticSegmentationClass.OBSTACLE,
    TransfuserSemanticSegmentationClass.SPECIAL_VEHICLE: TransfuserSemanticSegmentationClass.VEHICLE,
    TransfuserSemanticSegmentationClass.STOP_SIGN: TransfuserSemanticSegmentationClass.STOP_SIGN,
    TransfuserSemanticSegmentationClass.BIKER: TransfuserSemanticSegmentationClass.BIKER,
}

SIM2REAL_BEV_SEMANTIC_SEGMENTATION_CONVERTER = {
    TransfuserBEVSemanticClass.UNLABELED: TransfuserBEVSemanticClass.UNLABELED,
    TransfuserBEVSemanticClass.ROAD: TransfuserBEVSemanticClass.ROAD,
    TransfuserBEVSemanticClass.LANE_MARKERS: TransfuserBEVSemanticClass.LANE_MARKERS,
    TransfuserBEVSemanticClass.STOP_SIGNS: TransfuserBEVSemanticClass.UNLABELED,
    TransfuserBEVSemanticClass.VEHICLE: TransfuserBEVSemanticClass.VEHICLE,
    TransfuserBEVSemanticClass.WALKER: TransfuserBEVSemanticClass.WALKER,
    TransfuserBEVSemanticClass.OBSTACLE: TransfuserBEVSemanticClass.OBSTACLE,
    TransfuserBEVSemanticClass.PARKING_VEHICLE: TransfuserBEVSemanticClass.VEHICLE,
    TransfuserBEVSemanticClass.SPECIAL_VEHICLE: TransfuserBEVSemanticClass.SPECIAL_VEHICLE,
    TransfuserBEVSemanticClass.BIKER: TransfuserBEVSemanticClass.BIKER,
    TransfuserBEVSemanticClass.TRAFFIC_GREEN: TransfuserBEVSemanticClass.TRAFFIC_GREEN,
    TransfuserBEVSemanticClass.TRAFFIC_RED_NORMAL: TransfuserBEVSemanticClass.TRAFFIC_RED_NORMAL,
    TransfuserBEVSemanticClass.TRAFFIC_RED_NOT_NORMAL: TransfuserBEVSemanticClass.TRAFFIC_RED_NORMAL,
}

SIM2REAL_BOUNDING_BOX_CLASS_CONVERTER = {
    TransfuserBoundingBoxClass.VEHICLE: TransfuserBoundingBoxClass.VEHICLE,
    TransfuserBoundingBoxClass.WALKER: TransfuserBoundingBoxClass.WALKER,
    TransfuserBoundingBoxClass.TRAFFIC_LIGHT: TransfuserBoundingBoxClass.TRAFFIC_LIGHT,
    TransfuserBoundingBoxClass.STOP_SIGN: TransfuserBoundingBoxClass.STOP_SIGN,
    TransfuserBoundingBoxClass.SPECIAL: TransfuserBoundingBoxClass.VEHICLE,
    TransfuserBoundingBoxClass.OBSTACLE: TransfuserBoundingBoxClass.OBSTACLE,
    TransfuserBoundingBoxClass.PARKING: TransfuserBoundingBoxClass.VEHICLE,
    TransfuserBoundingBoxClass.BIKER: TransfuserBoundingBoxClass.BIKER,
}

SIM2REAL_BEV_OCCUPANCY_CLASS_CONVERTER = {
    TransfuserBEVOccupancyClass.UNLABELED: TransfuserBEVOccupancyClass.UNLABELED,
    TransfuserBEVOccupancyClass.VEHICLE: TransfuserBEVOccupancyClass.VEHICLE,
    TransfuserBEVOccupancyClass.WALKER: TransfuserBEVOccupancyClass.WALKER,
    TransfuserBEVOccupancyClass.OBSTACLE: TransfuserBEVOccupancyClass.UNLABELED,
    TransfuserBEVOccupancyClass.PARKING_VEHICLE: TransfuserBEVOccupancyClass.VEHICLE,
    TransfuserBEVOccupancyClass.SPECIAL_VEHICLE: TransfuserBEVOccupancyClass.VEHICLE,
    TransfuserBEVOccupancyClass.BIKER: TransfuserBEVOccupancyClass.BIKER,
    TransfuserBEVOccupancyClass.TRAFFIC_GREEN: TransfuserBEVOccupancyClass.TRAFFIC_GREEN,
    TransfuserBEVOccupancyClass.TRAFFIC_RED_NORMAL: TransfuserBEVOccupancyClass.TRAFFIC_RED_NORMAL,
    TransfuserBEVOccupancyClass.TRAFFIC_RED_NOT_NORMAL: TransfuserBEVOccupancyClass.TRAFFIC_RED_NORMAL,
}

CARLA_REAR_AXLE = [-1.389, 0.0, 0.360]

# Flat intrinsics: f_u, f_v, c_u, c_v, k1, k2, p1, p2, k3 (see WaymoE2EIntrinsicIndex).
WAYMO_E2E_INTRINSIC = [
    1112.17806333,
    1112.04114501,
    488.128479,
    719.15600586,
    -0.073584,
    -0.036582,
    0.0,
    0.0,
    0.0,
]

WAYMO_E2E_2025_CAMERA_SETTING = {
    "FRONT_LEFT": {
        "extrinsic": [
            [0.70549636, -0.70869903, -0.00026687, 1.44509995],
            [0.70868651, 0.70547806, 0.00679447, 0.15270001],
            [-0.00462506, -0.00498361, 0.99996268, 1.80649996],
            [0.0, 0.0, 0.0, 1.0],
        ],
        "width": 972,
        "height": 1440,
        "cropped_height": 1080,
    },
    "FRONT": {
        "extrinsic": [
            [0.99998026, -0.00215741, 0.00378466, 1.51909995],
            [0.00214115, 0.99997577, 0.0045841, 0.0258],
            [-0.00379113, -0.00457491, 0.9999674, 1.80649996],
            [0.0, 0.0, 0.0, 1.0],
        ],
        "width": 972,
        "height": 1440,
        "cropped_height": 1080,
    },
    "FRONT_RIGHT": {
        "extrinsic": [
            [0.70908528, 0.70508644, 0.00506919, 1.48150003],
            [-0.7050955, 0.70909768, -0.00009787, -0.1163],
            [-0.00366141, -0.00350169, 0.99996981, 1.80649996],
            [0.0, 0.0, 0.0, 1.0],
        ],
        "width": 972,
        "height": 1440,
        "cropped_height": 1080,
    },
}


class WaymoE2EIntrinsicIndex(IntEnum):
    """Index to access the flat WAYMO_E2E_INTRINSIC camera intrinsics array."""

    F_U = 0
    F_V = 1
    C_U = 2
    C_V = 3
    K_1 = 4
    K_2 = 5
    P_1 = 6
    P_2 = 7
    K_3 = 8
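

# Hypothetical usage sketch (assumed name, not part of the original API): assemble
# an OpenCV-style camera matrix and distortion vector from WAYMO_E2E_INTRINSIC.
def _example_waymo_intrinsics_to_opencv() -> typing.Tuple[np.ndarray, np.ndarray]:
    i = WAYMO_E2E_INTRINSIC
    camera_matrix = np.array(
        [
            [i[WaymoE2EIntrinsicIndex.F_U], 0.0, i[WaymoE2EIntrinsicIndex.C_U]],
            [0.0, i[WaymoE2EIntrinsicIndex.F_V], i[WaymoE2EIntrinsicIndex.C_V]],
            [0.0, 0.0, 1.0],
        ]
    )
    # OpenCV distortion coefficient order: k1, k2, p1, p2, k3.
    dist_coeffs = np.array(
        [
            i[WaymoE2EIntrinsicIndex.K_1],
            i[WaymoE2EIntrinsicIndex.K_2],
            i[WaymoE2EIntrinsicIndex.P_1],
            i[WaymoE2EIntrinsicIndex.P_2],
            i[WaymoE2EIntrinsicIndex.K_3],
        ]
    )
    return camera_matrix, dist_coeffs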


class CarlaImageCroppingType(IntEnum):
    """Which image border(s) are cropped from CARLA camera frames."""

    TOP = 0
    BOTTOM = auto()
    BOTH = auto()
    NONE = auto()


WAYMO_DOWN_SAMPLE_FACTOR = 3

WAYMO_E2E_REAL_DATA_JPEG_LEVEL = 50


@beartype
def normalize_imagenet(x: jt.Float[torch.Tensor, "B 3 H W"]) -> jt.Float[torch.Tensor, "B 3 H W"]:
    """Normalize input images according to ImageNet statistics.

    Args:
        x: Batch of input images with RGB values in [0, 255].

    Returns:
        Normalized batch of images.
    """
    x = x.clone()
    x[:, 0] = ((x[:, 0] / 255.0) - 0.485) / 0.229
    x[:, 1] = ((x[:, 1] / 255.0) - 0.456) / 0.224
    x[:, 2] = ((x[:, 2] / 255.0) - 0.406) / 0.225
    return x
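

# Minimal usage sketch (assumed name, not part of the original API): normalize a
# batch of uint8-range images before feeding them to an ImageNet-pretrained backbone.
def _example_normalize_batch() -> torch.Tensor:
    images = torch.randint(0, 256, (2, 3, 224, 224), dtype=torch.uint8).float()
    return normalize_imagenet(images)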


class BaseConfig:
    @property
    def target_dataset(self):
        raise NotImplementedError("Subclasses must implement the target_dataset property.")

    bicycle_frame_rate = 20

    num_route_points_saved = 50

    points_per_meter = 10

    pixels_per_meter_collection = 2.0

    longitudinal_max_accelerations = 1.89

    time_step = 1.0 / 20.0

    front_wheel_base = -0.090769015

    rear_wheel_base = 1.4178275

    steering_gain = 0.36848336

    brake_acceleration = -4.952399

    throttle_acceleration = 0.5633837

    throttle_threshold_during_forecasting = 0.3

    data_save_freq = 5

    augment = False

    default_safety_translation_augmentation_penalty = 0.25

    urban_safety_translation_augmentation_penalty = 0.4

    camera_translation_augmentation_min = 0.1

    camera_translation_augmentation_max = 1.0

    camera_rotation_augmentation_min = 5.0

    camera_rotation_augmentation_max = 12.5

    camera_rotation_epsilon = 0.5

    extreme_augment = False

    data_collection_extreme_augment_prob = 1.0

    camera_extreme_rotation_augmentation_min = 12.5

    camera_extreme_rotation_augmentation_max = 40

    camera_extreme_translation_augmentation_min = 0.1

    camera_extreme_translation_augmentation_max = 1.0

    camera_extreme_pitch_augmentation_min = -1.0

    camera_extreme_pitch_augmentation_max = 0.5

    camera_extreme_roll_augmentation_max = 1.0

    point_format = 0

    point_precision_x = point_precision_y = point_precision_z = 0.1

    max_height_lidar = 10.0

    min_height_lidar = -4.0

    use_two_lidars = True

    @property
    def lidar_pos_1(self):
        return [0.0, 0.0, 2.5]

    @property
    def lidar_rot_1(self):
        return [0.0, 0.0, -90.0]

    @property
    def lidar_pos_2(self):
        return [0.0, 0.0, 2.5]

    @property
    def lidar_rot_2(self):
        return [0.0, 0.0, -270.0]

    @property
    def lidar_accumulation(self):
        return True

    @property
    def num_cameras(self):
        """Number of cameras based on the target dataset."""
        return {
            TargetDataset.CARLA_LEADERBOARD2_6CAMERAS: 6,
            TargetDataset.CARLA_LEADERBOARD2_3CAMERAS: 3,
            TargetDataset.NAVSIM_4CAMERAS: 4,
            TargetDataset.WAYMO_E2E_2025_3CAMERAS: 3,
        }[self.target_dataset]

    @property
    def camera_calibration(self):
        """Camera calibration configuration with positions, rotations, and sensor parameters."""
        if self.target_dataset == TargetDataset.CARLA_LEADERBOARD2_6CAMERAS:
            return {
                1: {
                    "pos": [0.0, -0.3, 2.25],
                    "rot": [0.0, 0.0, -57.5],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
                2: {
                    "pos": [0.25, 0.0, 2.25],
                    "rot": [0.0, 0.0, 0.0],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
                3: {
                    "pos": [0.0, 0.3, 2.25],
                    "rot": [0.0, 0.0, 57.5],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
                4: {
                    "pos": [-0.30, 0.3, 2.25],
                    "rot": [0.0, 0.0, 180 - 57.5],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
                5: {
                    "pos": [-0.55, 0.0, 2.25],
                    "rot": [0.0, 0.0, 180.0],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
                6: {
                    "pos": [-0.30, -0.3, 2.25],
                    "rot": [0.0, 0.0, -180 + 57.5],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
            }
        elif self.target_dataset == TargetDataset.CARLA_LEADERBOARD2_3CAMERAS:
            return {
                1: {
                    "pos": [0.1, -0.35, 2.25],
                    "rot": [0.0, 0.0, -54.5],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
                2: {
                    "pos": [0.35, 0.0, 2.25],
                    "rot": [0.0, 0.0, 0.0],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
                3: {
                    "pos": [0.1, 0.35, 2.25],
                    "rot": [0.0, 0.0, 54.5],
                    "width": 1152 // 3,
                    "height": 384,
                    "cropped_height": 384,
                    "fov": 60,
                },
            }
        elif self.target_dataset == TargetDataset.NAVSIM_4CAMERAS:
            return {
                1: NUPLAN_CAMERA_CALIBRATION["CAM_L0"],
                2: NUPLAN_CAMERA_CALIBRATION["CAM_F0"],
                3: NUPLAN_CAMERA_CALIBRATION["CAM_R0"],
                4: NUPLAN_CAMERA_CALIBRATION["CAM_B0"],
            }
        elif self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS:
            from pdm_lite_utils_0916 import waymo_e2e_camera_setting_to_carla

            return {
                1: waymo_e2e_camera_setting_to_carla(
                    WAYMO_E2E_INTRINSIC,
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT_RIGHT"]["extrinsic"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT_RIGHT"]["width"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT_RIGHT"]["height"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT_RIGHT"]["cropped_height"],
                ),
                2: waymo_e2e_camera_setting_to_carla(
                    WAYMO_E2E_INTRINSIC,
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT"]["extrinsic"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT"]["width"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT"]["height"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT"]["cropped_height"],
                ),
                3: waymo_e2e_camera_setting_to_carla(
                    WAYMO_E2E_INTRINSIC,
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT_LEFT"]["extrinsic"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT_LEFT"]["width"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT_LEFT"]["height"],
                    WAYMO_E2E_2025_CAMERA_SETTING["FRONT_LEFT"]["cropped_height"],
                ),
            }
        raise ValueError(f"Unsupported target dataset: {self.target_dataset}")

    num_radar_sensors = 4

    @property
    def radar_calibration(self):
        """Radar sensor calibration configuration with positions and orientations.

        Only a setup with 4 radar sensors is provided.
        """
        return {
            "1": {
                "pos": [2.6, 0, 0.60],
                "rot": [0.0, 0.0, -45.0],
                "horz_fov": 90,
                "vert_fov": 0.1,
            },
            "2": {
                "pos": [2.6, 0, 0.60],
                "rot": [0.0, 0.0, 45.0],
                "horz_fov": 90,
                "vert_fov": 0.1,
            },
            "3": {
                "pos": [-2.6, 0, 0.60],
                "rot": [0.0, 0.0, 135],
                "horz_fov": 90,
                "vert_fov": 0.1,
            },
            "4": {
                "pos": [-2.6, 0, 0.60],
                "rot": [0.0, 0.0, 225],
                "horz_fov": 90,
                "vert_fov": 0.1,
            },
        }

    use_radars = True

    save_radar_pc_as_lidar = True

    save_lidar_only_inside_bev = True

    duplicate_radar_near_ego = True

    duplicate_radar_radius = 32

    duplicate_radar_factor = 5

    save_depth_lower_resolution = True

    @property
    def save_depth_resolution_ratio(self):
        """Resolution reduction ratio for depth image storage."""
        return 4

    save_depth_bits = 8

    save_only_non_ground_lidar = True

    save_grouped_semantic = True

    ego_num_temporal_data_points_saved = 200

    other_vehicles_num_temporal_data_points_saved = 40

    carla_fps = 20

    carla_frame_rate = 1.0 / carla_fps

    iou_treshold_nms = 0.2

    route_planner_min_distance = 7.5

    route_planner_max_distance = 50.0

    dense_route_planner_min_distance = 2.4

    inital_frames_delay = 1

    # Candidate target point distances: 3.0 m to 10.0 m in 0.25 m steps.
    tp_distances = [3.0 + 0.25 * i for i in range(29)]

    ego_extent_x = 2.4508416652679443

    ego_extent_y = 1.0641621351242065

    ego_extent_z = 0.7553732395172119

    safety_box_z_min = 0.5

    safety_box_z_max = 1.5

    @property
    def safety_box_y_min(self):
        """Minimum y coordinate of the safety box relative to the ego vehicle."""
        return -self.ego_extent_y * 0.8

    @property
    def safety_box_y_max(self):
        """Maximum y coordinate of the safety box relative to the ego vehicle."""
        return self.ego_extent_y * 0.8

    @property
    def safety_box_x_min(self):
        """Minimum x coordinate of the safety box relative to the ego vehicle."""
        return self.ego_extent_x

    @property
    def safety_box_x_max(self):
        """Maximum x coordinate of the safety box relative to the ego vehicle."""
        return self.ego_extent_x + 2.5
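
    # Worked example with the default ego extents above: the safety box spans from
    # about 2.45 m to 4.95 m ahead of the ego center and roughly +/- 0.85 m laterally
    # (0.8 * 1.064 m).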

    @property
    def is_on_slurm(self):
        """Check if running in a SLURM cluster environment."""
        return os.getenv("SLURM_JOB_ID") is not None

    @property
    def is_on_tcml(self):
        """Check if running on the Training Center for Machine Learning of Tübingen."""
        return os.getenv("TCML") is not None

    def parse(self, loaded_config, env_key: str, raise_error_on_missing_key: bool):
        # Overrides from the environment are given as an OmegaConf dotlist,
        # e.g. "key1=value1 key2=value2".
        env_params = os.getenv(env_key, "").strip()
        if not env_params:
            env_params = {}
        else:
            parsed = OmegaConf.create(OmegaConf.from_dotlist(env_params.split()))
            env_params = OmegaConf.to_container(parsed, resolve=True)

        if loaded_config is None:
            loaded_config = {}

        # Environment overrides take precedence over the loaded config.
        for key, value in env_params.items():
            loaded_config[key] = value

        for key, value in loaded_config.items():
            if hasattr(self, key) and not callable(getattr(self, key)):
                try:
                    setattr(self, key, value)
                except Exception:
                    # Properties without a setter are handled through
                    # `_loaded_config` (see `overridable_property`).
                    pass
            elif raise_error_on_missing_key:
                raise AttributeError(f"Unknown configuration key: {key}")

        self._loaded_config = loaded_config
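
    # Example of the env override format (hypothetical variable name): setting
    # `MY_CONFIG="lr=1e-4 batch_size=32"` and calling
    # `parse(..., env_key="MY_CONFIG", ...)` merges {"lr": 0.0001, "batch_size": 32}
    # into the loaded config, with env values taking precedence over file values.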

    def __setattr__(self, name, value):
        # Only allow setting attributes that already exist somewhere in the class
        # hierarchy (or private names), to catch typos and renamed settings early.
        allowed = set()
        for cls in self.__class__.__mro__:
            allowed.update(getattr(cls, "__dict__", {}).keys())
        allowed.update(self.__dict__.keys())

        if name not in allowed and not name.startswith("_"):
            raise AttributeError(
                f"Can't set unknown attribute '{name}'. Please check if this variable might have been renamed."
            )
        super().__setattr__(name, value)

    def base_dict(self):
        """Return all public, non-callable attributes as a dictionary."""
        out = {}
        for k in dir(self):
            if k.startswith("_"):
                continue
            try:
                v = getattr(self, k)
                if not callable(v):
                    out[k] = v
            except Exception:
                pass
        return out


def overridable_property(fn):
    """Property whose computed default can be overridden via the parsed config."""
    attr_name = fn.__name__

    @property
    def wrapper(self):
        try:
            if attr_name in self._loaded_config:
                # Cast the override to the type of the computed default.
                return type(fn(self))(self._loaded_config[attr_name])
        except Exception:
            pass
        return fn(self)

    return wrapper
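

# Usage sketch (hypothetical subclass): values placed in `_loaded_config` by
# `BaseConfig.parse` override the computed default and are cast to its type.
#
#   class MyConfig(BaseConfig):
#       @overridable_property
#       def epochs(self):
#           return 31
#
#   cfg = MyConfig()
#   cfg.parse({"epochs": 10}, env_key="UNSET_ENV_KEY", raise_error_on_missing_key=False)
#   assert cfg.epochs == 10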


class TrainingConfig(BaseConfig):
    @property
    def target_dataset(self):
        if self.use_waymo_e2e_data and not self.mixed_data_training:
            return TargetDataset.WAYMO_E2E_2025_3CAMERAS
        elif self.use_navsim_data and not self.use_carla_data:
            return TargetDataset.NAVSIM_4CAMERAS
        elif "carla_leaderboad2_v3" in self.carla_root:
            return TargetDataset.CARLA_LEADERBOARD2_3CAMERAS
        elif "carla_leaderboad2_v8" in self.carla_root:
            return TargetDataset.CARLA_LEADERBOARD2_3CAMERAS
        elif "carla_leaderboad2_v10" in self.carla_root:
            return TargetDataset.CARLA_LEADERBOARD2_6CAMERAS
        elif "carla_leaderboad2_v12" in self.carla_root:
            return TargetDataset.NAVSIM_4CAMERAS
        elif "carla_leaderboad2_v14" in self.carla_root:
            return TargetDataset.WAYMO_E2E_2025_3CAMERAS
        elif "data/carla_today/results/data/garage_v2_2025_06_23" in self.carla_root:
            return TargetDataset.NAVSIM_4CAMERAS
        raise ValueError(f"Unknown CARLA root path: {self.carla_root}")

    @property
    def num_available_cameras(self):
        """Number of available cameras based on the target dataset."""
        return {
            TargetDataset.CARLA_LEADERBOARD2_3CAMERAS: 3,
            TargetDataset.CARLA_LEADERBOARD2_6CAMERAS: 6,
            TargetDataset.NAVSIM_4CAMERAS: 4,
            TargetDataset.WAYMO_E2E_2025_3CAMERAS: 3,
        }[self.target_dataset]

    @overridable_property
    def used_cameras(self):
        """List indicating which cameras are used based on the target dataset.

        Can be overridden; if a camera is False it is ignored during training.
        """
        return [True] * self.num_available_cameras

    @property
    def num_used_cameras(self):
        """Number of cameras used during training."""
        return sum(int(use) for use in self.used_cameras)

    pixels_per_meter = 4.0

    @property
    def min_x_meter(self):
        """Back boundary of the planning area in meters."""
        if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS:
            return 0
        return -32

    @property
    def max_x_meter(self):
        """Front boundary of the planning area in meters."""
        if self.carla_leaderboard_mode:
            return 64
        if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS:
            return 64
        return 32

    @property
    def min_y_meter(self):
        """Left boundary of the planning area in meters."""
        if self.carla_leaderboard_mode:
            return -40
        return -32

    @property
    def max_y_meter(self):
        """Right boundary of the planning area in meters."""
        if self.carla_leaderboard_mode:
            return 40
        return 32

    @property
    def lidar_width_pixel(self):
        """Width resolution of the LiDAR BEV representation in pixels."""
        return int((self.max_x_meter - self.min_x_meter) * self.pixels_per_meter)

    @property
    def lidar_height_pixel(self):
        """Height resolution of the LiDAR BEV representation in pixels."""
        return int((self.max_y_meter - self.min_y_meter) * self.pixels_per_meter)

    @property
    def lidar_width_meter(self):
        """Width of the LiDAR coverage area in meters."""
        return int(self.max_x_meter - self.min_x_meter)

    @property
    def lidar_height_meter(self):
        """Height of the LiDAR coverage area in meters."""
        return int(self.max_y_meter - self.min_y_meter)
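
    # Worked example with the non-leaderboard defaults: x and y both span
    # [-32, 32] m at 4 pixels per meter, giving a 256 x 256 pixel BEV grid that
    # covers 64 x 64 m.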

    visualize_dataset = False

    visualize_failed_scenarios = False

    load_bev_3rd_person_images = False

    seed = 0

    wandb_id: typing.Optional[str] = None

    wandb_resume = "never"

    description = "My Experiment."

    visualize_training = True

    id = "Experiment 1."

    load_file: typing.Optional[str] = None

    continue_epoch = False

    @property
    def epoch_checkpoints_keep(self):
        """Number of checkpoints to keep during training."""
        if self.use_carla_data and not self.mixed_data_training:
            return []
        # Keeps the checkpoints at epochs 2**n - 1 for n in [3, 10):
        # [7, 15, 31, 63, 127, 255, 511].
        return [sum([1 * 2**i for i in range(n)]) for n in range(3, 10)]

    use_training_session_cache = True

    use_persistent_cache = False

    force_rebuild_data_cache = False

    @property
    def carla_cache_path(self):
        """Tuple of cache characteristics used to identify cached data compatibility."""
        return (
            str(self.image_width_before_camera_subselection),
            str(self.final_image_height),
            str(self.min_x_meter),
            str(self.max_x_meter),
            str(self.min_y_meter),
            str(self.max_y_meter),
            str(self.detect_boxes),
            str(self.use_depth),
            str(self.use_semantic),
            str(self.use_bev_semantic),
            str(self.load_bev_3rd_person_images),
            str(self.training_used_lidar_steps),
            str(self.num_radar_forecast_steps),
            str(self.radar_waypoints_spacing),
        )

    @property
    def ssd_cache_path(self):
        """Path to the SSD cache directory."""
        tmp_folder = "/scratch/" + str(os.environ.get("SLURM_JOB_ID"))
        if not self.is_on_tcml:
            tmp_folder = str(os.environ.get("SCRATCH", "/tmp"))
        return tmp_folder

    carla_root = "data/carla_leaderboad2_v12/results/data/sensor_data/"

    @property
    def carla_data(self):
        """Path to the CARLA data directory."""
        return os.path.join(self.carla_root, "data")

    logdir = None

    training_png_compression_level = 6

    vehicle_min_num_lidar_points = 1

    vehicle_min_num_visible_pixels = 1

    pedestrian_min_num_lidar_points = 5

    pedestrian_min_num_visible_pixels = 15

    parking_vehicle_min_num_lidar_points = 3

    parking_vehicle_min_num_visible_pixels = 5

    grad_scaler_init_scale = 1024

    grad_scaler_growth_factor = 2

    grad_scaler_backoff_factor = 0.5

    grad_scaler_growth_interval = 256

    grad_scaler_max_grad_scale = 2**16

    @property
    def sync_batchnorm(self) -> bool:
        """If true, synchronize batch normalization across distributed processes."""
        return False

    freeze_batchnorm = False

    @overridable_property
    def epochs(self):
        """Total number of training epochs."""
        if self.carla_leaderboard_mode:
            return 31
        if self.target_dataset == TargetDataset.NAVSIM_4CAMERAS:
            return 61
        if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS:
            return 20
        raise ValueError("Unknown target dataset")

    @overridable_property
    def batch_size(self):
        """Batch size for training."""
        if not self.is_on_slurm:
            return 4
        return 64

    @property
    def torch_float_type(self):
        """PyTorch float precision type for training."""
        if self.use_mixed_precision_training and self.gpu_name in ["a100", "l40s"]:
            return torch.bfloat16
        return torch.float32

    @property
    def use_mixed_precision_training(self):
        """If true, use mixed precision training."""
        return self.gpu_name in ["a100", "l40s"]

    @property
    def need_grad_scaler(self):
        """If true, gradient scaling is needed for mixed precision training."""
        return self.use_mixed_precision_training and self.torch_float_type == torch.float16
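
    # Note: as written here, mixed precision on the supported GPUs always selects
    # bfloat16, so `torch_float_type` never equals torch.float16 and this property
    # evaluates to False.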

    use_zero_redundancy = False

    @property
    def save_model_checkpoint(self):
        """If true, save model checkpoints during training."""
        return True

    @property
    def is_pretraining(self):
        """If true, indicates the pretraining phase."""
        return not self.use_planning_decoder

    prefetch_factor = 8

    @property
    def compile(self):
        """If true, compile the model for optimization."""
        return True

    @property
    def channel_last(self):
        """If true, use channels-last memory format for input tensors."""
        return True

    lr = 3e-4

    @property
    def skip_first(self):
        """Number of frames to skip at the beginning of sequences."""
        if self.is_pretraining and not self.mixed_data_training:
            return 1
        return self.num_way_points_prediction

    @property
    def skip_last(self):
        """Number of frames to skip at the end of sequences."""
        if self.is_pretraining:
            return 1
        if self.carla_leaderboard_mode:
            return self.num_way_points_prediction
        if self.target_dataset in [
            TargetDataset.NAVSIM_4CAMERAS,
        ]:
            return self.num_way_points_prediction * 2
        raise ValueError("Unknown target dataset")

    @property
    def radar_detection(self):
        """If true, use radar points as additional input to the model."""
        # Same computation as `use_radar_detection` below.
        return self.carla_leaderboard_mode and not self.use_tfpp_planning_decoder

    @overridable_property
    def use_radar_detection(self):
        """If true, use radar points as additional input to the model."""
        return self.carla_leaderboard_mode and not self.use_tfpp_planning_decoder

    num_radar_points_per_sensor = 75

    num_radar_queries = 20

    radar_hidden_dim_tokenizer = 1024

    radar_token_dim = 256

    radar_tf_dim_ff = 1024

    radar_dropout = 0.1

    radar_num_heads = 8

    radar_num_layers = 4

    radar_hidden_dim_decoder = 1024

    radar_classification_loss_weight = 1.0

    radar_forecast_loss_weight = 0.1

    radar_regression_loss_weight = 5.0

    radar_sine_features = 128

    num_radar_sensors = 4

    forecast_radar_detections = False

    num_radar_forecast_steps = 8

    radar_waypoints_spacing = 5

    max_distance_future_waypoint = 10.0

    force_rebuild_bucket = False

    randomize_route_order = False

    subsample_data = False

    hold_out_town13_routes = False

    @property
    def carla_bucket_collection(self):
        """Name of the bucket collection to use for training data."""
        if (
            self.use_carla_data
            and self.use_waymo_e2e_data
            and not self.force_rebuild_bucket
            and not self.force_rebuild_data_cache
        ):
            return "navsim"
        if (
            self.use_carla_data
            and self.use_navsim_data
            and not self.force_rebuild_bucket
            and not self.force_rebuild_data_cache
        ):
            return "navsim"
        if self.visualize_failed_scenarios:
            return "failed"
        if self.subsample_data:
            return "subsampled_posttrain"
        if self.is_pretraining and self.hold_out_town13_routes:
            return "town13_heldout_pretrain"
        if self.is_pretraining or self.visualize_dataset or self.force_rebuild_data_cache:
            return "full_pretrain"
        if not self.is_pretraining and self.hold_out_town13_routes:
            return "town13_heldout_posttrain"
        return "full_posttrain"

    @property
    def bucket_collection_path(self):
        """Path to the bucket collection directory."""
        return os.path.join(self.carla_root, "buckets")

    use_sensor_perburtation = True

    @overridable_property
    def use_sensor_perburtation_prob(self):
        """Probability of applying sensor perturbation to a sample."""
        if not self.carla_leaderboard_mode:
            return 0.8
        return 0.5

    @overridable_property
    def use_extreme_sensor_perburtation_prob(self):
        """Probability of applying extreme sensor perturbation given that we sampled a sensor-perturbed sample."""
        if self.is_pretraining and not self.carla_leaderboard_mode:
            return 0.5
        return 0.0

    @property
    def use_color_aug(self):
        """If true, apply image color based augmentations."""
        return not self.visualize_dataset

    @property
    def use_color_aug_prob(self):
        """Probability to apply the different image color augmentations."""
        if self.carla_leaderboard_mode:
            return 0.2
        return 0.1

    weight_decay = 0.01

    @property
    def use_grad_clip(self):
        """If true, use gradient clipping during training."""
        return False

    grad_quantile = 0.99

    grad_history_length = 1000

    use_optimizer_groups = False

    @overridable_property
    def use_cosine_annealing_with_restarts(self):
        if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS:
            return False
        return True

    @overridable_property
    def use_depth(self):
        """If true, use depth prediction as an auxiliary task."""
        return self.carla_leaderboard_mode

    @property
    def training_used_lidar_steps(self):
        """We stack LiDAR frames for motion cues. Number of past frames we stack for the model input."""
        return 10

    min_z = -4

    max_z = 4

    hist_max_per_pixel = 5

    @property
    def lidar_vert_anchors(self):
        """Number of vertical anchors for LiDAR feature maps."""
        return self.lidar_height_pixel // 32

    @property
    def lidar_horz_anchors(self):
        """Number of horizontal anchors for LiDAR feature maps."""
        return self.lidar_width_pixel // 32

    detect_boxes = True

    debug_boxes_visualization = True

    center_net_global_avg_factor = False

    data_bb_static_types_white_list = ["static.prop.constructioncone", "static.prop.trafficwarning"]

    bb_confidence_threshold = 0.3

    max_num_bbs = 90

    num_dir_bins = 12

    top_k_center_keypoints = 100

    center_net_max_pooling_kernel = 3

    bb_input_channel = 64

    car_open_door_extra_width = 1.2

    num_bb_classes = len(TransfuserBoundingBoxClass)

    @overridable_property
    def use_discrete_command(self):
        """If true, use the discrete command input to the network."""
        if self.target_dataset == TargetDataset.WAYMO_E2E_2025_3CAMERAS:
            return False
        return True

    @property
    def discrete_command_dim(self):
        """Dimension of the discrete command input."""
        if self.carla_leaderboard_mode:
            return 6
        elif self.target_dataset in [
            TargetDataset.NAVSIM_4CAMERAS,
        ]:
            return 4
        elif self.target_dataset in [
            TargetDataset.WAYMO_E2E_2025_3CAMERAS,
        ]:
            return 4
        raise ValueError("Unknown target dataset")

    use_noisy_tp = False

    use_velocity = True

    @property
    def max_speed(self):
        """Maximum speed limit for the vehicle in m/s."""
        if self.carla_leaderboard_mode:
            return 25.0
        if self.target_dataset in [
            TargetDataset.NAVSIM_4CAMERAS,
        ]:
            return 15.0
        if self.target_dataset in [
            TargetDataset.WAYMO_E2E_2025_3CAMERAS,
        ]:
            return 33.33
        raise ValueError("Unknown target dataset")

    @property
    def use_acceleration(self):
        """If true, use the acceleration as input to the network."""
        return not self.carla_leaderboard_mode and self.target_dataset not in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]

    @property
    def max_acceleration(self):
        """Maximum acceleration for normalization."""
        if self.carla_leaderboard_mode:
            return 10.0
        if self.target_dataset in [
            TargetDataset.NAVSIM_4CAMERAS,
        ]:
            return 4.0
        raise ValueError("Unknown target dataset")

    @property
    def use_previous_tp(self):
        """If true, use the previous/visited target point as input to the network."""
        if self.carla_leaderboard_mode and not self.use_tfpp_planning_decoder:
            return True
        return False

    @property
    def use_next_tp(self):
        """If true, use the next/subsequent target point as input to the network."""
        if self.carla_leaderboard_mode and not self.use_tfpp_planning_decoder:
            return True
        return False

    @property
    def use_tp(self):
        """If true, use the current target point as input to the network."""
        if self.carla_leaderboard_mode:
            return True
        return False

    @property
    def target_points_normalization_constants(self):
        """Normalization constants for target points [x_norm, y_norm]."""
        return [[200.0, 50.0]]

    @property
    def tp_pop_distance(self):
        """Distance threshold for popping target points from the route."""
        return 3.25

    @overridable_property
    def use_past_positions(self):
        """If true, use past positions as input to the network."""
        return not self.carla_leaderboard_mode and self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]

    @overridable_property
    def use_past_speeds(self):
        """If true, use past speeds as input to the network."""
        return not self.carla_leaderboard_mode and self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]

    @overridable_property
    def num_past_samples_used(self):
        """Number of past samples to use as input to the network."""
        if self.use_past_positions or self.use_past_speeds:
            return 6
        return 0

    transfuser_num_bev_cross_attention_layers = 6

    transfuser_num_bev_cross_attention_heads = 8

    transfuser_token_dim = 256

    @property
    def predict_target_speed(self):
        """If true, predict the target speed."""
        return self.carla_leaderboard_mode

    @property
    def predict_spatial_path(self):
        """If true, predict the spatial path."""
        return self.carla_leaderboard_mode

    predict_temporal_spatial_waypoints = True

    use_planning_decoder = False

    use_tfpp_planning_decoder = False

    gru_hidden_size = 64

    @property
    def target_speed_classes(self):
        """CARLA target speed prediction classes in m/s."""
        return [
            0.0,
            4.0,
            8.0,
            10.0,
            13.88888888,
            16.0,
            17.77777777,
            20.0,
        ]
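
    # For reference: 13.888... m/s is 50 km/h and 17.777... m/s is 64 km/h.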
|
|
|
|
|
@property |
|
|
def target_speeds(self): |
|
|
return self.target_speed_classes |
|
|
|
|
|
|
|
|
smooth_route = True |
|
|
|
|
|
num_route_points_smoothing = 20 |
|
|
|
|
|
num_route_points_prediction = 10 |
|
|
|
|
|
@property |
|
|
def num_way_points_prediction(self): |
|
|
"""Number of waypoints to predict.""" |
|
|
if self.carla_leaderboard_mode: |
|
|
return 8 |
|
|
elif self.target_dataset in [ |
|
|
TargetDataset.NAVSIM_4CAMERAS, |
|
|
]: |
|
|
return 8 |
|
|
elif self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]: |
|
|
return 10 |
|
|
raise ValueError("Unknown target dataset") |
|
|
|
|
|
|
|
|
@property |
|
|
def waypoints_spacing(self): |
|
|
if self.carla_leaderboard_mode: |
|
|
return 5 |
|
|
elif self.target_dataset in [ |
|
|
TargetDataset.NAVSIM_4CAMERAS, |
|
|
]: |
|
|
return 10 |
|
|
elif self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]: |
|
|
return 10 |
|
|
raise ValueError("Unknown target dataset") |
|
|
|
|
|
|
|
|

    @property
    def crop_height(self):
        """The number of pixels cropped from the bottom of the image."""
        return self.camera_calibration[1]["height"] - self.camera_calibration[1]["cropped_height"]

    @property
    def carla_crop_height_type(self):
        """Type of cropping applied to CARLA images."""
        if self.carla_leaderboard_mode:
            return CarlaImageCroppingType.NONE
        elif self.target_dataset in [
            TargetDataset.NAVSIM_4CAMERAS,
        ]:
            return CarlaImageCroppingType.BOTTOM
        elif self.target_dataset in [TargetDataset.WAYMO_E2E_2025_3CAMERAS]:
            return CarlaImageCroppingType.NONE
        raise ValueError("Unknown target dataset")

    @property
    def image_width_before_camera_subselection(self):
        """Width of images after loading from disk but before camera sub-selection."""
        return self.num_available_cameras * self.camera_calibration[1]["width"]

    @property
    def final_image_width(self):
        """Final width of images after cropping and camera sub-selection."""
        return self.num_used_cameras * self.camera_calibration[1]["width"]

    @property
    def final_image_height(self):
        """Final height of images after cropping."""
        return self.camera_calibration[1]["cropped_height"]

    @property
    def img_vert_anchors(self):
        """Number of vertical anchors for image feature maps."""
        return self.final_image_height // 32

    @property
    def img_horz_anchors(self):
        """Number of horizontal anchors for image feature maps."""
        return self.num_used_cameras * self.camera_calibration[1]["width"] // 32
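
    # Example (sketch, hypothetical calibration values): with 3 used cameras, a
    # per-camera width of 256 px and a cropped height of 160 px, the feature
    # anchor grid is img_vert_anchors = 160 // 32 = 5 and
    # img_horz_anchors = 3 * 256 // 32 = 24.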

    # Backbone configuration.
    freeze_backbone = False
    image_architecture = "resnet34"
    lidar_architecture = "resnet34"
    # LTF: replace the LiDAR pseudo-image with a fixed coordinate grid (see TransfuserBackbone.forward).
    LTF = False
    num_fusion_stages = 4
    only_ego_status_input = False

    # GPT fusion transformer hyperparameters.
    block_exp = 4
    n_layer = 2
    n_head = 4
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1
    gpt_linear_layer_init_mean = 0.0
    gpt_linear_layer_init_std = 0.02
    gpt_layer_norm_init_weight = 1.0

    # Perspective semantic segmentation head.
    use_semantic = True
    num_semantic_classes = len(TransfuserSemanticSegmentationClass)
    perspective_downsample_factor = 1
    deconv_channel_num_0 = 128
    deconv_channel_num_1 = 64
    deconv_channel_num_2 = 32
    deconv_scale_factor_0 = 4
    deconv_scale_factor_1 = 8

    # BEV semantic segmentation head.
    use_bev_semantic = True
    num_bev_semantic_classes = len(TransfuserBEVSemanticClass)
    scale_pedestrian_bev_semantic_size = 2.5
    pedestrian_bev_min_extent = 0.4
    bev_features_chanels = 64  # BEV decoder feature channels (attribute spelling kept for config compatibility)
    bev_down_sample_factor = 4
    bev_upsample_factor = 2

    # Training data sources.
    use_carla_data = True
    carla_num_samples = -1
    use_navsim_data = False
    navsim_data_root = "data/navsim_training_cache/trainval"
    navsim_num_samples = -1
    schedule_carla_num_samples = False
    use_waymo_e2e_data = False
    waymo_e2e_num_training_samples = -1
    waymo_e2e_training_data_root = "data/waymo_open_dataset_end_to_end_camera_v_1_0_0_training"
    waymo_e2e_val_data_root = "data/waymo_open_dataset_end_to_end_camera_v_1_0_0_val_rfm"
    waymo_e2e_test_data_root = "data/waymo_open_dataset_end_to_end_camera_v_1_0_0_test_submission"
    waymo_e2e_subsample_factor = 5

    @property
    def navsim_num_bev_semantic_classes(self):
        """Number of BEV semantic classes in NavSim data."""
        return len(NavSimBEVSemanticClass)

    @property
    def navsim_num_bb_classes(self):
        """Number of bounding box classes in NavSim data."""
        return len(NavSimBBClass)

    @property
    def mixed_data_training(self):
        """If true use mixed data for training."""
        return int(self.use_navsim_data) + int(self.use_carla_data) + int(self.use_waymo_e2e_data) > 1

    @property
    def carla_leaderboard_mode(self):
        """If true use CARLA leaderboard mode settings."""
        return (
            self.target_dataset
            in [
                TargetDataset.CARLA_LEADERBOARD2_3CAMERAS,
                TargetDataset.CARLA_LEADERBOARD2_6CAMERAS,
            ]
            and not self.mixed_data_training
        )
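
    # Example (sketch): leaderboard mode requires a CARLA target dataset *and*
    # single-source training. With target_dataset = CARLA_LEADERBOARD2_3CAMERAS
    # and only use_carla_data = True, carla_leaderboard_mode is True; enabling
    # use_navsim_data as well makes mixed_data_training True and turns
    # carla_leaderboard_mode off.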

    @beartype
    def detailed_loss_weights(self, source_dataset: int, epoch: int) -> dict[str, float]:
        """Computed loss weights for all auxiliary tasks with normalization."""
        # Note: "epoch" is currently unused.
        weights = {
            "loss_semantic": 1.0,
            "loss_depth": 0.00001,
            "loss_bev_semantic": 1.0,
            "loss_center_net_heatmap": 1.0,
            "loss_center_net_wh": 1.0,
            "loss_center_net_offset": 1.0,
            "loss_center_net_yaw_class": 1.0,
            "loss_center_net_yaw_res": 1.0,
            "loss_center_net_velocity": 1.0,
            "radar_loss": 1.0,
        }

        # Auxiliary supervision that is only available for CARLA data.
        if source_dataset != SourceDataset.CARLA:
            weights["radar_loss"] = 0.0
            weights["loss_semantic"] = 0.0
            weights["loss_depth"] = 0.0
            weights["loss_center_net_velocity"] = 0.0

        if self.LTF:
            weights["loss_center_net_velocity"] = 0.0

        if not self.use_semantic:
            weights["loss_semantic"] = 0.0

        if not self.use_depth:
            weights["loss_depth"] = 0.0

        if not self.use_bev_semantic:
            weights["loss_bev_semantic"] = 0.0

        if not self.detect_boxes:
            weights["loss_center_net_heatmap"] = 0.0
            weights["loss_center_net_wh"] = 0.0
            weights["loss_center_net_offset"] = 0.0
            weights["loss_center_net_yaw_class"] = 0.0
            weights["loss_center_net_yaw_res"] = 0.0
            weights["loss_center_net_velocity"] = 0.0

        if self.training_used_lidar_steps <= 1:
            weights["loss_center_net_velocity"] = 0.0

        if not self.radar_detection:
            weights["radar_loss"] = 0.0

        # Per-dataset auxiliary losses get a dataset prefix; CARLA keeps the unprefixed names.
        prefix = f"{SOURCE_DATASET_NAME_MAP[source_dataset]}_"
        if source_dataset == SourceDataset.CARLA:
            prefix = ""
        weights = {f"{prefix}{k}": v for k, v in weights.items()}

        # Planning losses are shared across datasets and stay unprefixed.
        weights.update(
            {
                "loss_spatio_temporal_waypoints": 1.0,
                "loss_target_speed": 1.0,
                "loss_spatial_route": 1.0,
            }
        )

        if not self.use_planning_decoder:
            weights["loss_spatio_temporal_waypoints"] = 0.0
            weights["loss_spatial_route"] = 0.0
            weights["loss_target_speed"] = 0.0

        return weights
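
    # Example (sketch): for a NavSim batch the auxiliary losses get a "navsim_"
    # prefix and CARLA-only terms are zeroed, while planning losses stay unprefixed:
    #
    #   cfg = TrainingConfig()
    #   w = cfg.detailed_loss_weights(int(SourceDataset.NAVSIM), epoch=0)
    #   w["navsim_loss_semantic"]  # 0.0 (no semantic labels outside CARLA)
    #   w["loss_target_speed"]     # 1.0 if cfg.use_planning_decoder else 0.0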

    @property
    def log_scalars_frequency(self):
        """How often to log scalar values during training."""
        if not self.is_on_slurm:
            return 1
        try:
            with open("scripts/slurm/configs/wandb_log_frequency_training_scalar.txt") as f:
                return int(f.readline().strip())
        except Exception as e:
            print(f"Error reading log frequency file: {e}.")
            return 1

    @property
    def log_images_frequency(self):
        """How often to log images during training."""
        if not self.is_on_slurm:
            return 1
        try:
            with open("scripts/slurm/configs/wandb_log_frequency_training_images.txt") as f:
                return int(f.readline().strip())
        except Exception as e:
            print(f"Error reading log frequency file: {e}.")
            return 1000

    @property
    def slurm_job_id(self):
        """Current SLURM job ID if running on a SLURM cluster, else None."""
        if self.is_on_slurm:
            return os.environ.get("SLURM_JOB_ID", "0")
        return None

    @property
    def log_wandb(self):
        """If true log metrics to Weights & Biases."""
        return self.is_on_slurm

    @property
    def gpu_name(self):
        """Normalized GPU name for hardware-specific configurations."""
        try:
            name = torch.cuda.get_device_name().lower()
            if "rtx 2080 ti" in name:
                return "rtx2080ti"
            elif "gtx 1080 ti" in name:
                return "gtx1080ti"
            elif "a100" in name:
                return "a100"
            elif "l40s" in name:
                return "l40s"
            elif "a4000" in name:
                return "a4000"
            else:
                raise Exception(f"Unknown GPU name: {name}")
        except RuntimeError:
            # torch.cuda.get_device_name raises RuntimeError when no GPU is available.
            return ""

    @property
    def rank(self):
        """Current process rank in distributed training."""
        return int(os.environ.get("RANK", "0"))

    @property
    def world_size(self):
        """Total number of processes in distributed training."""
        return int(os.environ.get("WORLD_SIZE", "1"))

    @property
    def local_rank(self):
        """Local rank of the current process on its node."""
        return int(os.environ.get("LOCAL_RANK", "0"))

    @property
    def device(self):
        """PyTorch device to use for training."""
        return torch.device(f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu")

    @property
    def assigned_cpu_cores(self):
        """Number of CPU cores assigned to this job."""
        if "SLURM_JOB_ID" in os.environ:
            cpus_per_task = os.environ.get("SLURM_CPUS_PER_TASK")
            if cpus_per_task:
                return int(cpus_per_task)
        return 8

    @property
    def workers_per_cpu_cores(self):
        """Number of data loader workers per CPU core."""
        return 1
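
    # Example (sketch): a DataLoader worker count derived from these properties;
    # the training loop's actual wiring is not shown here and "dataset" is a
    # placeholder:
    #
    #   cfg = TrainingConfig()
    #   num_workers = cfg.assigned_cpu_cores * cfg.workers_per_cpu_cores
    #   loader = torch.utils.data.DataLoader(dataset, num_workers=num_workers)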

    def training_dict(self):
        """Collect all plain attributes and evaluable properties into a dictionary."""
        out = {}
        cls = self.__class__
        for k, v in cls.__dict__.items():
            if isinstance(v, property):
                try:
                    out[k] = getattr(self, k)
                except Exception:
                    # Skip properties that cannot be evaluated in this environment.
                    pass
            elif not k.startswith("__") and not callable(v):
                out[k] = v
        # Instance attributes override class-level defaults.
        for k, v in self.__dict__.items():
            if not k.startswith("__") and not callable(v):
                out[k] = v
        return out
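
    # Example (sketch): persisting the resolved config next to a checkpoint so
    # load_tf (below) can rebuild it; "out_dir" is hypothetical:
    #
    #   cfg = TrainingConfig()
    #   with open(os.path.join(out_dir, "config.json"), "w") as f:
    #       json.dump(cfg.training_dict(), f, indent=2, default=str)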

    def __init__(self, loaded_config: dict | None = None, raise_error_on_missing_key: bool = False):
        """Constructor for the training config."""
        super().__init__()
        self.parse(
            loaded_config=loaded_config,
            env_key="TRAINING_CONFIG",
            raise_error_on_missing_key=raise_error_on_missing_key,
        )


class TransfuserBackbone(nn.Module):
    """TransFuser image + LiDAR fusion backbone."""

    @beartype
    def __init__(self, device: torch.device, config: TrainingConfig):
        super().__init__()
        self.device = device
        self.config = config

        # Image branch: TIMM feature extractor pooled to a fixed anchor grid.
        self.image_encoder = timm.create_model(config.image_architecture, pretrained=True, features_only=True)
        self.avgpool_img = nn.AdaptiveAvgPool2d((self.config.img_vert_anchors, self.config.img_horz_anchors))
        image_start_index = 0
        # Some TIMM models return a stem feature in addition to the four stages.
        if len(self.image_encoder.return_layers) > 4:
            image_start_index += 1
        self.num_image_features = self.image_encoder.feature_info.info[image_start_index + 3]["num_chs"]

        if self.config.num_fusion_stages > 1:
            # LiDAR branch, trained from scratch (2 input channels in LTF mode, 1 otherwise).
            self.lidar_encoder = timm.create_model(
                config.lidar_architecture, pretrained=False, in_chans=2 if config.LTF else 1, features_only=True
            )
            lidar_start_index = 0
            if len(self.lidar_encoder.return_layers) > 4:
                lidar_start_index += 1
            self.num_lidar_features = self.lidar_encoder.feature_info.info[lidar_start_index + 3]["num_chs"]
            # 1x1 convolutions that translate channel counts between the branches at each fused stage.
            self.lidar_channel_to_img = nn.ModuleList(
                [
                    nn.Conv2d(
                        self.lidar_encoder.feature_info.info[lidar_start_index + i]["num_chs"],
                        self.image_encoder.feature_info.info[image_start_index + i]["num_chs"],
                        kernel_size=1,
                    )
                    for i in range(4 - config.num_fusion_stages, 4)
                ]
            )
            self.img_channel_to_lidar = nn.ModuleList(
                [
                    nn.Conv2d(
                        self.image_encoder.feature_info.info[image_start_index + i]["num_chs"],
                        self.lidar_encoder.feature_info.info[lidar_start_index + i]["num_chs"],
                        kernel_size=1,
                    )
                    for i in range(4 - config.num_fusion_stages, 4)
                ]
            )
        else:
            # Single-stage fusion: no LiDAR encoder; learned BEV queries stand in for
            # the LiDAR features at the last stage.
            self.num_lidar_features = self.num_image_features
            self.bev_queries = nn.Parameter(
                torch.randn(
                    1,
                    self.num_lidar_features,
                    self.config.lidar_vert_anchors,
                    self.config.lidar_horz_anchors,
                )
            )
            self.lidar_channel_to_img = nn.ModuleList(
                [
                    nn.Conv2d(
                        self.num_lidar_features,
                        self.image_encoder.feature_info.info[image_start_index + 3]["num_chs"],
                        kernel_size=1,
                    )
                ]
            )
            self.img_channel_to_lidar = nn.ModuleList(
                [
                    nn.Conv2d(
                        self.image_encoder.feature_info.info[image_start_index + 3]["num_chs"],
                        self.num_lidar_features,
                        kernel_size=1,
                    )
                ]
            )
        self.avgpool_lidar = nn.AdaptiveAvgPool2d((self.config.lidar_vert_anchors, self.config.lidar_horz_anchors))

        # One GPT fusion transformer per fused stage.
        self.transformers = nn.ModuleList(
            [
                GPT(n_embd=self.image_encoder.feature_info.info[image_start_index + i]["num_chs"], config=config)
                for i in range(4 - config.num_fusion_stages, 4)
            ]
        )

        self.perspective_upsample_factor = (
            self.image_encoder.feature_info.info[image_start_index + 3]["reduction"]
            // self.config.perspective_downsample_factor
        )

        # Top-down BEV decoder layers.
        self.upsample = nn.Upsample(scale_factor=self.config.bev_upsample_factor, mode="bilinear", align_corners=False)
        self.upsample2 = nn.Upsample(
            size=(
                self.config.lidar_height_pixel // self.config.bev_down_sample_factor,
                self.config.lidar_width_pixel // self.config.bev_down_sample_factor,
            ),
            mode="bilinear",
            align_corners=False,
        )
        self.up_conv5 = nn.Conv2d(self.config.bev_features_chanels, self.config.bev_features_chanels, (3, 3), padding=1)
        self.up_conv4 = nn.Conv2d(self.config.bev_features_chanels, self.config.bev_features_chanels, (3, 3), padding=1)
        self.c5_conv = nn.Conv2d(self.num_lidar_features, self.config.bev_features_chanels, (1, 1))

    def top_down(self, x):
        """Decode the coarse fused BEV features into a higher-resolution BEV map."""
        p5 = F.relu(self.c5_conv(x), inplace=True)
        p4 = F.relu(self.up_conv5(self.upsample(p5)), inplace=True)
        p3 = F.relu(self.up_conv4(self.upsample2(p4)), inplace=True)
        return p3
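
    # Shape sketch (hypothetical lidar_height_pixel = lidar_width_pixel = 256 and
    # an 8x8 input grid; defaults bev_down_sample_factor = 4, bev_upsample_factor = 2):
    # (B, num_lidar_features, 8, 8) -> c5_conv -> (B, 64, 8, 8) -> upsample x2 ->
    # (B, 64, 16, 16) -> upsample2 to the fixed target -> (B, 64, 64, 64).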

    def forward(self, data: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        rgb = data["rgb"].to(self.device, dtype=self.config.torch_float_type, non_blocking=True)
        if self.config.LTF:
            if self.config.num_fusion_stages > 1:
                # LTF replaces the LiDAR pseudo-image with a fixed 2-channel normalized coordinate grid.
                x = torch.linspace(0, 1, self.config.lidar_width_pixel)
                y = torch.linspace(0, 1, self.config.lidar_height_pixel)
                y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")

                lidar = torch.zeros(
                    (rgb.shape[0], 2, self.config.lidar_height_pixel, self.config.lidar_width_pixel),
                    device=rgb.device,
                )
                lidar[:, 0] = y_grid.unsqueeze(0)
                lidar[:, 1] = x_grid.unsqueeze(0)
            else:
                lidar = None
        else:
            lidar = data["lidar"].to(self.device, dtype=self.config.torch_float_type, non_blocking=True)
        return self._forward(rgb, lidar)

    @jt.jaxtyped(typechecker=beartype)
    def _forward(
        self,
        image: jt.Float[torch.Tensor, "B 3 img_h img_w"],
        lidar: jt.Float[torch.Tensor, "B 1 bev_h bev_w"] | jt.Float[torch.Tensor, "B 2 bev_h bev_w"] | None,
    ) -> tuple[jt.Float[torch.Tensor, "B D1 H1 W1"], jt.Float[torch.Tensor, "B D2 H2 W2"]]:
        """
        Image + LiDAR feature fusion using transformers.

        Args:
            image: RGB image.
            lidar: Pseudo-image LiDAR.

        Returns:
            lidar_features: BEV feature map for planning.
            image_features: Image feature map for perception.
        """
        image_features = normalize_imagenet(image)
        lidar_features = lidar

        if self.config.channel_last:
            # Convert the tensors that are actually fed to the encoders.
            image_features = image_features.to(memory_format=torch.channels_last)
            if lidar_features is not None:
                lidar_features = lidar_features.to(memory_format=torch.channels_last)

        image_layers = iter(self.image_encoder.items())
        if self.config.num_fusion_stages > 1:
            lidar_layers = iter(self.lidar_encoder.items())

        # Consume the stem block if the encoder returns more than four stages.
        if len(self.image_encoder.return_layers) > 4:
            image_features = self.forward_layer_block(image_layers, self.image_encoder.return_layers, image_features)
        if self.config.num_fusion_stages > 1 and len(self.lidar_encoder.return_layers) > 4:
            lidar_features = self.forward_layer_block(lidar_layers, self.lidar_encoder.return_layers, lidar_features)

        # Advance both encoders stage by stage, fusing at the configured stages.
        for i in range(4):
            image_features = self.forward_layer_block(image_layers, self.image_encoder.return_layers, image_features)
            if self.config.num_fusion_stages > 1:
                lidar_features = self.forward_layer_block(lidar_layers, self.lidar_encoder.return_layers, lidar_features)

            if i >= 4 - self.config.num_fusion_stages:
                if self.config.num_fusion_stages == 1:
                    lidar_features = self.bev_queries.expand(image_features.shape[0], -1, -1, -1)
                image_features, lidar_features = self.fuse_features(
                    image_features, lidar_features, i - (4 - self.config.num_fusion_stages)
                )

        return lidar_features, image_features

    @beartype
    def forward_layer_block(self, layers, return_layers: dict[str, str], features: torch.Tensor) -> torch.Tensor:
        """Run one forward pass through a block of layers from a TIMM neural network and return the result.

        Advances the whole network by just one block.

        Args:
            layers: Iterator starting at the current layer block of the target network.
            return_layers: TIMM dictionary describing at which intermediate layers features are returned.
            features: Input features.

        Returns:
            torch.Tensor: Processed features.
        """
        for name, module in layers:
            features = module(features)
            if name in return_layers:
                break
        return features

    def fuse_features(
        self, image_features: torch.Tensor, lidar_features: torch.Tensor, layer_idx: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Perform a TransFuser feature fusion block using a Transformer module.

        Args:
            image_features: Features from the image branch.
            lidar_features: Features from the LiDAR branch.
            layer_idx: Transformer layer index.

        Returns:
            image_features and lidar_features with added features from the other branch.
        """
        # Pool both branches to fixed anchor grids and align channel counts.
        image_embd_layer = self.avgpool_img(image_features)
        lidar_embd_layer = self.avgpool_lidar(lidar_features)
        lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer)

        image_features_layer, lidar_features_layer = self.transformers[layer_idx](image_embd_layer, lidar_embd_layer)

        # Map back to LiDAR channels and upsample to the original resolutions.
        lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer)
        image_features_layer = F.interpolate(
            image_features_layer, size=(image_features.shape[2], image_features.shape[3]), mode="bilinear", align_corners=False
        )
        lidar_features_layer = F.interpolate(
            lidar_features_layer, size=(lidar_features.shape[2], lidar_features.shape[3]), mode="bilinear", align_corners=False
        )

        # Residual connections into both branches.
        image_features = image_features + image_features_layer
        lidar_features = lidar_features + lidar_features_layer

        return image_features, lidar_features


class GPT(nn.Module):
    """Minimal GPT-style transformer that jointly attends over image and BEV tokens."""

    def __init__(self, n_embd, config):
        super().__init__()
        self.n_embd = n_embd
        self.config = config

        # One learned positional embedding per image and LiDAR anchor token.
        self.pos_emb = nn.Parameter(
            torch.zeros(
                1,
                self.config.img_vert_anchors * self.config.img_horz_anchors
                + self.config.lidar_vert_anchors * self.config.lidar_horz_anchors,
                self.n_embd,
            )
        )
        self.drop = nn.Dropout(config.embd_pdrop)

        self.blocks = nn.Sequential(
            *[
                Block(n_embd, config.n_head, config.block_exp, config.attn_pdrop, config.resid_pdrop)
                for _ in range(config.n_layer)
            ]
        )

        self.ln_f = nn.LayerNorm(n_embd)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=self.config.gpt_linear_layer_init_mean, std=self.config.gpt_linear_layer_init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(self.config.gpt_layer_norm_init_weight)

    def forward(self, image_tensor, lidar_tensor):
        """
        Args:
            image_tensor (tensor): B, C, H, W
            lidar_tensor (tensor): B, C, H, W
        """
        bz = lidar_tensor.shape[0]
        lidar_h, lidar_w = lidar_tensor.shape[2:4]
        img_h, img_w = image_tensor.shape[2:4]

        # Flatten both feature maps into token sequences.
        image_tensor = image_tensor.permute(0, 2, 3, 1).contiguous().view(bz, -1, self.n_embd)
        lidar_tensor = lidar_tensor.permute(0, 2, 3, 1).contiguous().view(bz, -1, self.n_embd)

        token_embeddings = torch.cat((image_tensor, lidar_tensor), dim=1)

        x = self.drop(self.pos_emb + token_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)

        # Split the sequence back into image and LiDAR tokens and restore B, C, H, W.
        image_tensor_out = (
            x[:, : self.config.img_vert_anchors * self.config.img_horz_anchors, :]
            .view(bz, img_h, img_w, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        lidar_tensor_out = (
            x[:, self.config.img_vert_anchors * self.config.img_horz_anchors :, :]
            .view(bz, lidar_h, lidar_w, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )

        return image_tensor_out, lidar_tensor_out


class Block(nn.Module):
    """Pre-norm transformer block: self-attention followed by an MLP, both with residuals."""

    def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, block_exp * n_embd),
            nn.ReLU(True),
            nn.Linear(block_exp * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class SelfAttention(nn.Module):
    """Multi-head self-attention using PyTorch's fused scaled_dot_product_attention."""

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        assert n_embd % n_head == 0

        # Key, query, value projections for all heads.
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)

        self.dropout = attn_pdrop
        self.resid_drop = nn.Dropout(resid_pdrop)

        # Output projection.
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x):
        b, t, c = x.size()

        # Split heads: (B, T, C) -> (B, n_head, T, C // n_head).
        k = self.key(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
        q = self.query(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
        v = self.value(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)

        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False
        )
        y = y.transpose(1, 2).contiguous().view(b, t, c)

        y = self.resid_drop(self.proj(y))
        return y
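
# Minimal sketch: one attention layer on random tokens (shapes only).
#
#   attn = SelfAttention(n_embd=64, n_head=4, attn_pdrop=0.1, resid_pdrop=0.1)
#   y = attn(torch.randn(2, 16, 64))  # -> (2, 16, 64)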


class PlanningDecoder(nn.Module):
    """Cross-attention decoder that turns BEV context tokens into waypoints."""

    @beartype
    def __init__(self, input_bev_channels: int, config: TrainingConfig, device: torch.device):
        super().__init__()
        self.device = device
        self.config = config
        self.planning_context_encoder = PlanningContextEncoder(
            config=self.config, input_bev_channels=input_bev_channels, device=self.device
        )

        # One learned query per predicted waypoint.
        num_queries = 0
        if self.config.predict_temporal_spatial_waypoints:
            num_queries += self.config.num_way_points_prediction

        self.query = nn.Parameter(
            torch.zeros(
                1,
                num_queries,
                self.config.transfuser_token_dim,
            )
        )

        self.transformer_decoder = torch.nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(
                self.config.transfuser_token_dim,
                self.config.transfuser_num_bev_cross_attention_heads,
                activation=nn.GELU(),
                batch_first=True,
            ),
            num_layers=self.config.transfuser_num_bev_cross_attention_layers,
            norm=nn.LayerNorm(self.config.transfuser_token_dim),
        )

        # Output heads.
        if self.config.predict_temporal_spatial_waypoints:
            self.wp_decoder = nn.Linear(config.transfuser_token_dim, 2)
            if self.config.use_navsim_data:
                self.heading_decoder = nn.Linear(config.transfuser_token_dim, 1)

    @beartype
    def forward(
        self,
        bev_features: jt.Float[torch.Tensor, "bs bev_dim height_bev width_bev"],
        data: dict,
        log: dict,
    ) -> tuple[
        jt.Float[torch.Tensor, "B n_waypoints 2"],
        jt.Float[torch.Tensor, "B n_waypoints"] | None,
    ]:
        """
        Args:
            bev_features: BEV features.
            data: Raw input batch.
            log: Dictionary collecting values to log.

        Returns:
            waypoints: Spatial and temporal path.
            headings: Heading predictions (if using NavSim data).
        """
        self.kv = context_tokens = self.planning_context_encoder(bev_features=bev_features, data=data)

        bs = context_tokens.shape[0]

        queries = self.transformer_decoder(self.query.repeat(bs, 1, 1), context_tokens)

        query_idx = 0
        waypoints = None
        headings = None

        if self.config.predict_temporal_spatial_waypoints:
            waypoints_queries = queries[:, query_idx : query_idx + self.config.num_way_points_prediction]
            # Predict per-step offsets and integrate them into absolute waypoints.
            waypoints = torch.cumsum(self.wp_decoder(waypoints_queries), 1)
            if self.config.use_navsim_data:
                headings = torch.cumsum(self.heading_decoder(waypoints_queries), 1)
            query_idx += self.config.num_way_points_prediction

        return (waypoints, headings.squeeze(-1) if headings is not None else None)


class PlanningContextEncoder(nn.Module):
    """Builds the token sequence (BEV tokens + ego-status tokens) for the planning decoder."""

    @beartype
    def __init__(self, config: TrainingConfig, input_bev_channels: int, device: torch.device):
        super().__init__()
        self.device = device
        self.config: TrainingConfig = config

        self.num_status_tokens = 0

        if self.config.use_velocity:
            self.num_status_tokens += 1
            self.velocity_encoder = nn.Sequential(
                nn.Linear(1, self.config.transfuser_token_dim),
            )

        if self.config.use_acceleration:
            self.num_status_tokens += 1
            self.acceleration_encoder = nn.Sequential(
                nn.Linear(1, self.config.transfuser_token_dim),
            )

        if self.config.use_discrete_command:
            self.num_status_tokens += 1
            self.command_encoder = nn.Sequential(nn.Linear(self.config.discrete_command_dim, self.config.transfuser_token_dim))

        self.cosine_pos_embedding = PositionEmbeddingSine(config, self.config.transfuser_token_dim // 2, normalize=True)
        self.status_pos_embedding = nn.Parameter(torch.zeros(1, self.num_status_tokens, self.config.transfuser_token_dim))

        # 1x1 convolution mapping BEV channels to the token dimension.
        self.dimension_adapter = nn.Conv2d(input_bev_channels, self.config.transfuser_token_dim, kernel_size=1)
        self.reset_parameters()

        self.target_points_normalization_constants = torch.tensor(
            self.config.target_points_normalization_constants, device=self.device, dtype=self.config.torch_float_type
        )

    def reset_parameters(self):
        nn.init.uniform_(self.status_pos_embedding)

    @beartype
    def forward(
        self,
        bev_features: jt.Float[torch.Tensor, "B C H W"],
        data: dict,
    ) -> jt.Float[torch.Tensor, "B N D"]:
        """
        Args:
            bev_features: Raw BEV features.
            data: Raw input batch.

        Returns:
            context_tokens: Output tokens for the planning transformer decoder.
        """
        if self.config.use_velocity:
            velocity = data["speed"].reshape(-1, 1).to(self.device, dtype=self.config.torch_float_type)
        if self.config.use_discrete_command:
            command = data["command"].to(self.device, dtype=self.config.torch_float_type)

        status_tokens = []

        # Encode each enabled ego-status signal as one token.
        if self.config.use_velocity:
            velocity_token = self.velocity_encoder(velocity / self.config.max_speed).reshape(
                -1, 1, self.config.transfuser_token_dim
            )
            status_tokens.append(velocity_token)

        if self.config.use_acceleration:
            acceleration = data["acceleration"].reshape(-1, 1).to(self.device, dtype=self.config.torch_float_type)
            acceleration_token = self.acceleration_encoder(acceleration / self.config.max_acceleration).reshape(
                -1, 1, self.config.transfuser_token_dim
            )
            status_tokens.append(acceleration_token)

        if self.config.use_discrete_command:
            command_token = self.command_encoder(command).reshape(-1, 1, self.config.transfuser_token_dim)
            status_tokens.append(command_token)

        has_statuses = False
        if len(status_tokens) > 0:
            status_tokens = torch.cat(status_tokens, dim=1)
            has_statuses = True

        # Map BEV channels to the token dimension.
        context_tokens = self.dimension_adapter(bev_features)

        if has_statuses:
            context_tokens = context_tokens + self.cosine_pos_embedding(context_tokens)
        # Flatten the BEV grid into a token sequence: (B, D, H, W) -> (B, H*W, D).
        context_tokens = torch.flatten(context_tokens, start_dim=2)
        context_tokens = torch.permute(context_tokens, (0, 2, 1))

        if has_statuses:
            status_tokens = status_tokens + self.status_pos_embedding
            context_tokens = torch.cat([context_tokens, status_tokens], dim=1)

        return context_tokens


class PositionEmbeddingSine(nn.Module):
    """Standard 2D sine/cosine positional embedding (as in DETR)."""

    def __init__(self, config: TrainingConfig, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.config = config
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor: torch.Tensor):
        x = tensor
        bs, _, h, w = x.shape
        not_mask = torch.ones((bs, h, w), device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sine and cosine over the feature dimension.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos.to(self.config.torch_float_type).contiguous()
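
# Shape sketch: for input (B, C, H, W) the embedding is (B, 2 * num_pos_feats, H, W),
# so passing num_pos_feats = transfuser_token_dim // 2 (as PlanningContextEncoder
# does above) yields a positional embedding matching the token dimension.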


@jt.jaxtyped(typechecker=beartype)
@dataclass
class Prediction:
    """Raw output predictions from the model."""

    pred_future_waypoints: jt.Float[torch.Tensor, "bs n_waypoints 2"] | None
    pred_headings: jt.Float[torch.Tensor, "bs n_waypoints"] | None


class Model(nn.Module):
    """Full model: TransFuser backbone and planning decoder."""

    @beartype
    def __init__(
        self,
        device: torch.device,
        config: TrainingConfig,
    ):
        super().__init__()
        self.device = device
        self.config = config
        self.log = {}

        self.backbone = TransfuserBackbone(self.device, self.config)

        self.planning_decoder = PlanningDecoder(
            input_bev_channels=self.backbone.num_lidar_features, config=self.config, device=self.device
        ).to(self.device)

    @beartype
    def forward(self, data: dict[str, typing.Any]) -> Prediction:
        self.log = {}
        pred_future_waypoints = pred_headings = None

        bev_features, image_features = self.backbone(data)

        if self.config.use_planning_decoder:
            (pred_future_waypoints, pred_headings) = self.planning_decoder(bev_features, data, log=self.log)

        return Prediction(
            pred_future_waypoints=pred_future_waypoints,
            pred_headings=pred_headings,
        )


def load_tf(model_path: str, device: torch.device):
    """Load a trained model and its config from a checkpoint directory."""
    base_dir = os.path.dirname(model_path)
    with open(os.path.join(base_dir, "config.json")) as f:
        loaded_config = json.load(f)
    config_training = TrainingConfig(loaded_config)

    model = Model(device=device, config=config_training)
    # map_location keeps loading working when the checkpoint was saved on a different device.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
    model.to(device)
    model.eval()
    return model
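
# Example usage (sketch; the checkpoint path is hypothetical and assumes a
# sibling config.json, e.g. one written via TrainingConfig.training_dict()):
#
#   model = load_tf("checkpoints/run_0/model.pth", torch.device("cuda:0"))
#   with torch.no_grad():
#       prediction = model(batch)  # batch: dict with "rgb", "lidar", status keys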


class NavsimData(torch.utils.data.Dataset):
    """Data loader for cached NavSim training data."""

    def __init__(self, root, config: TrainingConfig):
        self.root = root
        self.config = config

        # Feature/target pairs cached as gzipped pickles.
        self.feature = glob.glob(os.path.join(self.root, "**/transfuser_feature.gz"), recursive=True)
        self.target = glob.glob(os.path.join(self.root, "**/transfuser_target.gz"), recursive=True)

        self.size = len(self.feature)

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        feature_path = self.feature[index]
        with gzip.open(feature_path, "rb") as f:
            feature = pickle.load(f)

        # Decode the compressed camera image buffer to C, H, W.
        rgb = cv2.imdecode(np.frombuffer(feature["camera_feature"], np.uint8), cv2.IMREAD_COLOR)
        rgb = np.transpose(rgb, (2, 0, 1))

        data = {
            "rgb": rgb,
            "command": feature["status_feature"][:4],
            "speed": np.linalg.norm(feature["status_feature"][4:6]),
            "acceleration": np.linalg.norm(feature["status_feature"][6:8]),
        }
        return data
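
# Example usage (sketch; the cache root comes from the config defaults above):
#
#   cfg = TrainingConfig()
#   dataset = NavsimData(cfg.navsim_data_root, cfg)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=cfg.assigned_cpu_cores)
#   batch = next(iter(loader))  # dict with "rgb", "command", "speed", "acceleration"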