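# NOTE: the next three lines look like hosting-page residue captured with the
# file; they are kept verbatim as comments so the module can be parsed.
# unknownuser6666's picture
# Upload folder using huggingface_hub
# 663494c verified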
import copy
import os
from typing import Any, Dict, List, Tuple
import carla
import numpy as np
import pygame
import torch
import utils.map.rasterization as map_rasterize_utils
import utils.vis.color as color_utils
import mmdet3d_plugin.datasets.carla as map_rasterizer
class RasterizedMap(CarlaModule):
    """Rasterize the static CARLA road layout once and serve ego-centred crops.

    ``runtime_init`` draws the drivable area and the lane markings of the HD
    map onto two pygame surfaces and packs them into ``full_map``, a
    ``1 x 15 x H x W`` float32 tensor on the CUDA device (channel 0 is the
    road, channel 1 the lanes; the other channels are left at zero here).
    Calling the instance returns
    ``renderer.get_local_map(full_map, ego_pos, ego_yaw)`` for the ego vehicle.

    NOTE(review): ``CarlaModule``, ``map_utils`` and ``map_render`` are not
    bound by the imports visible at the top of this file (those define
    ``map_rasterize_utils`` and ``map_rasterizer``) -- confirm the intended
    aliases.  ``self._actors`` and ``self._ego_vehicle`` are never assigned in
    this class; presumably the base class provides them -- TODO confirm.
    """

    # Layout of ``full_map`` along its channel axis.
    _NUM_MAP_CHANNELS = 15
    _ROAD_CHANNEL = 0
    _LANE_CHANNEL = 1
    # A direction arrow is drawn on every ARROW_STRIDE-th waypoint of a road.
    # TODO: why 400
    _ARROW_STRIDE = 400

    def __init__(self):
        # NOTE(review): as in the original, the base-class initialiser is not
        # called; confirm whether CarlaModule requires super().__init__().
        # These values are the fallbacks used by ``_setting`` whenever
        # ``self.cfg`` does not provide the same-named option.
        self.pixels_per_meter = 5
        self.scale = 1.0
        # empty area around the boundary of the map, units are in meters
        self.margin = 50

    # ------------------------------------------------------------------
    # configuration helpers
    # ------------------------------------------------------------------
    def _setting(self, name: str) -> Any:
        """Return option ``name`` from ``self.cfg`` if it has it, else from ``self``.

        The drawing code historically read ``self.cfg.<name>`` while
        ``__init__`` stored the defaults directly on the instance; this lookup
        keeps ``self.cfg`` authoritative and makes the defaults reachable.
        """
        cfg = getattr(self, "cfg", None)
        if cfg is not None and hasattr(cfg, name):
            return getattr(cfg, name)
        return getattr(self, name)

    def _pixel_scale(self) -> float:
        """Return the world-unit -> pixel factor (``scale * pixels_per_meter``)."""
        return self._setting("scale") * self._setting("pixels_per_meter")

    @staticmethod
    def _surface_to_gray(surface) -> np.ndarray:
        """Return ``surface`` as an ``H x W`` array holding the mean of its RGB values."""
        # TODO: Xinshuo: do we still need to swap H/W
        # array3d is indexed (x, y, channel); swapping the first two axes
        # yields (row, column, channel) before the channel mean is taken.
        return np.swapaxes(pygame.surfarray.array3d(surface), 0, 1).mean(axis=-1)

    # ------------------------------------------------------------------
    # one-off rasterization
    # ------------------------------------------------------------------
    def runtime_init(self, data: Dict[str, Any]) -> None:
        """Rasterize the map once at the beginning of the simulation.

        Args:
            data: runtime dictionary; ``data["hd_map"]`` must hold the
                ``carla.Map`` to rasterize.  It is also forwarded to the
                base-class ``runtime_init``.
        """
        super().runtime_init(data)
        self.hd_map: carla.Map = data["hd_map"]

        # TODO: check if the display works
        # Use SDL's dummy video driver and create a display surface:
        # ``Surface.convert()`` (used in ``create_empty_grid``) converts to the
        # display's pixel format, so a video mode has to be set beforehand.
        os.environ["SDL_VIDEODRIVER"] = "dummy"
        pygame.init()
        pygame.display.set_mode((320, 320), 0, 32)
        pygame.display.flip()

        # create the grid to fill in rasterized information
        self.create_empty_grid(self.hd_map)

        # draw static elements onto the empty surfaces
        self.draw_road_map()

        # convert the road and lane 3-channel to 1-channel
        road = self._surface_to_gray(self.grid_road)  # H x W
        lane = self._surface_to_gray(self.grid_lane)  # H x W

        # TODO: Xinshuo, clean up renderer
        full_map = np.zeros((1, self._NUM_MAP_CHANNELS) + road.shape)  # 1 x 15 x H x W
        full_map[:, self._ROAD_CHANNEL, ...] = road / 255.0
        full_map[:, self._LANE_CHANNEL, ...] = lane / 255.0
        self.full_map = torch.tensor(full_map, device="cuda", dtype=torch.float32)

        world_offset = torch.tensor(
            self._world_offset, device="cuda", dtype=torch.float32
        )
        map_dims = self.full_map.shape[2:4]
        self.renderer = map_render.MapRenderer(
            world_offset, map_dims, data_generation=True
        )

    def create_empty_grid(self, hd_map: carla.Map) -> None:
        """Define the size of grid by checking all waypoints on the map.

        Sets ``self._world_offset`` (the minimum x/y, margin included) and
        creates the square surfaces ``self.grid_road`` and ``self.grid_lane``.

        Args:
            hd_map: map whose waypoints determine the extent of the grid.

        Raises:
            ValueError: if the map yields no waypoints.
        """
        # retrieve all waypoints from the map
        waypoints: List[carla.Waypoint] = hd_map.generate_waypoints(2)
        if not waypoints:
            raise ValueError(
                "hd_map.generate_waypoints(2) returned no waypoints; "
                "cannot determine the extent of the map"
            )

        # find the min/max world coordinate for the entire city
        margin = self._setting("margin")
        xs = [wp.transform.location.x for wp in waypoints]
        ys = [wp.transform.location.y for wp in waypoints]
        min_x: float = min(xs) - margin
        max_x: float = max(xs) + margin
        min_y: float = min(ys) - margin
        max_y: float = max(ys) + margin

        # the grid is square: take the larger of the two extents
        width: float = max(max_x - min_x, max_y - min_y)

        # use min as the offset to shift the map to positive pixel locations
        self._world_offset: Tuple[float, float] = (min_x, min_y)

        # create empty surface to draw
        # NOTE(review): ``scale`` is applied in ``world_to_pixel`` but not
        # here, so the two only agree while scale == 1.0 -- confirm intent.
        width_in_pixels = int(self._setting("pixels_per_meter") * width)
        size = (width_in_pixels, width_in_pixels)
        self.grid_road = pygame.Surface(size).convert()
        self.grid_lane = pygame.Surface(size).convert()

    # ------------------------------------------------------------------
    # coordinate conversion
    # ------------------------------------------------------------------
    def world_to_pixel(self, location: carla.Vector3D) -> List[int]:
        """Convert the waypoint location in world coordinate to rasterized image coordinate.

        Args:
            location: object exposing world ``x`` and ``y`` attributes.

        Returns:
            ``[px, py]``: offset-corrected, scaled coordinates truncated by ``int``.
        """
        factor = self._pixel_scale()
        offset_x, offset_y = self._world_offset
        return [
            int(factor * (location.x - offset_x)),
            int(factor * (location.y - offset_y)),
        ]

    def world_to_pixel_width(self, width: float) -> int:
        """Convert a length in world units to a pixel count (truncated by ``int``)."""
        return int(self._pixel_scale() * width)

    # ------------------------------------------------------------------
    # drawing primitives
    # ------------------------------------------------------------------
    def rasterize_drivable_area(self, lane_marking: Dict[str, List[List[int]]]) -> None:
        """Rasterize the drivable area by finding the road polygon.

        Note: create road polygon by using waypoints to enclose the road
              since we loop through all waypoints starting of the road,
              this drawing plots all the drivable area

        Args:
            lane_marking: pixel coordinates of the ``"left"`` and ``"right"``
                boundaries of one road.
        """
        # walk down the left boundary and back up the right one
        polygon: List[List[int]] = lane_marking["left"] + list(
            reversed(lane_marking["right"])
        )
        map_utils.draw_drivable_area(self.grid_road, polygon)

    def draw_arrow(self, surface, transform, color=color_utils.ALUMINIUM_2):
        """Draw a small arrow on ``surface`` at ``transform``.

        The two direction vectors are obtained by temporarily rotating
        ``transform``'s yaw (by 180 and then a further 90 degrees); the
        original yaw is restored before returning so the caller's transform is
        left unchanged.

        Args:
            surface: pygame surface to draw on.
            transform: transform giving the arrow position and heading.
            color: line colour.
        """
        # TODO: Xinshuo: why not showing up,
        original_yaw = transform.rotation.yaw
        try:
            transform.rotation.yaw += 180
            forward = transform.get_forward_vector()
            transform.rotation.yaw += 90
            right_dir = transform.get_forward_vector()
        finally:
            # undo the in-place rotation performed above
            transform.rotation.yaw = original_yaw

        start = transform.location
        end = start + 2.0 * forward
        right = start + 0.8 * forward + 0.4 * right_dir
        left = start + 0.8 * forward - 0.4 * right_dir

        shaft = [self.world_to_pixel(point) for point in (start, end)]
        head = [self.world_to_pixel(point) for point in (left, start, right)]
        pygame.draw.lines(surface, color, False, shaft, 4)
        pygame.draw.lines(surface, color, False, head, 4)

    def draw_stop(
        self, surface, font_surface, transform, color=color_utils.ALUMINIUM_2
    ):
        """Blit the rotated stop text at the waypoint nearest ``transform`` and draw a line ahead of it.

        Args:
            surface: pygame surface to draw on.
            font_surface: pre-rendered text surface to rotate and blit.
            transform: transform of the stop actor.
            color: colour of the line drawn in front of the text.
        """
        # TODO: Xinshuo, clean up
        waypoint = self.hd_map.get_waypoint(transform.location)

        # rotate the text according to the waypoint yaw and centre it there
        angle = -waypoint.transform.rotation.yaw - 90.0
        rotated_text = pygame.transform.rotate(font_surface, angle)
        pixel_pos = self.world_to_pixel(waypoint.transform.location)
        text_rect = rotated_text.get_rect(center=(pixel_pos[0], pixel_pos[1]))
        surface.blit(rotated_text, text_rect)

        # Draw line in front of stop
        forward_vector = carla.Location(waypoint.transform.get_forward_vector())
        # forward vector with x/y swapped and x negated, scaled to 70% of half the lane width
        left_vector = (
            carla.Location(-forward_vector.y, forward_vector.x, forward_vector.z)
            * waypoint.lane_width
            / 2
            * 0.7
        )
        line = [
            (waypoint.transform.location + (forward_vector * 1.5) + (left_vector)),
            (waypoint.transform.location + (forward_vector * 1.5) - (left_vector)),
        ]
        line_pixel = [self.world_to_pixel(point) for point in line]
        pygame.draw.lines(surface, color, True, line_pixel, 2)

    def rasterize_lane(
        self, road_wps: List[carla.Waypoint], lane_marking: Dict[str, List[List[int]]]
    ) -> None:
        """Rasterize the lane with solid/dashed line indicating cross-able.

        Args:
            road_wps: waypoints of one road; nothing is drawn when empty.
            lane_marking: pixel coordinates of the markings, keyed by side.
        """
        if not road_wps:
            return

        # use the middle waypoint of the road to determine
        # if the lane at one side of the road is crossable or not
        center_wp: carla.Waypoint = road_wps[len(road_wps) // 2]

        # left and right lanes are separately rasterized
        # because they might be crossable or non-crossable
        if lane_marking:
            # the lookup depends only on the road, so query it once
            crossable = map_utils.road_crossable(self.hd_map, center_wp)
            for side, marking in lane_marking.items():
                map_utils.draw_lane_marking(self.grid_lane, marking, crossable[side])

        # draw road arrow
        for wp in road_wps[:: self._ARROW_STRIDE]:
            self.draw_arrow(self.grid_road, wp.transform)

    def draw_road_map(self) -> None:
        """Rasterize the static scene elements: drivable area, lane markings and stop markings."""
        # define the color of the background
        self.grid_road.fill(color_utils.BLACK)

        # loop through each road-starting waypoint
        road_wp_start: carla.Waypoint
        for road_wp_start in map_utils.get_topology_waypoints(self.hd_map):
            road_wps: List[carla.Waypoint] = map_utils.get_road_waypoints(road_wp_start)
            world_marking: Dict = map_utils.get_lane_marking(road_wps)

            # convert location from world coordinate to rasterized image coordinate
            pixel_marking: Dict[str, List[List[int]]] = {
                side: [self.world_to_pixel(point) for point in points]
                for side, points in world_marking.items()
            }

            # plot drivable area
            self.rasterize_drivable_area(pixel_marking)

            # plot lane, intersections included
            # TODO: intersection check necessary?
            self.rasterize_lane(road_wps, pixel_marking)

        stops_transform = [
            actor.get_transform() for actor in self._actors if "stop" in actor.type_id
        ]

        # render the text once, one world unit tall, then stretch it vertically x2
        font_size = self.world_to_pixel_width(1)
        font = pygame.font.SysFont("Arial", font_size, True)
        font_surface = font.render("STOP", False, color_utils.ALUMINIUM_2)
        font_surface = pygame.transform.scale(
            font_surface, (font_surface.get_width(), font_surface.get_height() * 2)
        )

        # NOTE(review): an earlier comment here read "Dian: do not draw stop
        # sign", yet the loop below does draw them -- confirm the intent.
        for stop in stops_transform:
            self.draw_stop(self.grid_road, font_surface, stop)

    # ------------------------------------------------------------------
    # per-frame query
    # ------------------------------------------------------------------
    def __call__(self, data: Dict) -> torch.Tensor:
        """Return the local map around the ego vehicle.

        Args:
            data: per-frame dictionary; not read by this method.

        Returns:
            The result of ``self.renderer.get_local_map`` for the current ego
            position (x, y) and yaw in radians.
        """
        # TODO: many repetitive process, need to speed up
        # Query the transform once so position and yaw belong to the same pose.
        ego_transform = self._ego_vehicle.get_transform()

        # fetch local birdview per agent
        ego_pos = torch.tensor(
            [ego_transform.location.x, ego_transform.location.y],
            device="cuda",
            dtype=torch.float32,
        )
        ego_yaw = torch.tensor(
            [ego_transform.rotation.yaw / 180 * np.pi],  # degrees -> radians
            device="cuda",
            dtype=torch.float32,
        )
        birdview = self.renderer.get_local_map(self.full_map, ego_pos, ego_yaw)

        # TODO: traffic-light rendering is disabled.  The removed sketch
        # skipped lights whose trigger volume was more than 15.0 away from the
        # ego vehicle and rendered a 4 x 4 template per remaining light through
        # ``self.renderer.render_agent_bv_batched`` into channel 2 (red),
        # 3 (yellow) or 4 (green) of ``birdview``.
        return birdview