hyzhou404's picture
init
7accb91
import heapq
from typing import Dict, List, Optional, Tuple

import numpy as np
from nuplan.common.maps.abstract_map_objects import LaneGraphEdgeMapObject, RoadBlockGraphEdgeMapObject
class Dijkstra:
    """
    A class that performs dijkstra's shortest path. The class operates on lane level graph search.
    The goal condition is specified to be if the lane can be found at the target roadblock or roadblock connector.
    """

    def __init__(self, start_edge: LaneGraphEdgeMapObject, candidate_lane_edge_ids: List[str]):
        """
        Constructor for the Dijkstra class.
        :param start_edge: The starting edge for the search
        :param candidate_lane_edge_ids: The candidates lane ids that can be included in the search.
        """
        self._start_edge = start_edge
        # Maps an edge id to the edge it was reached from (None for the start edge).
        self._parent: Dict[str, Optional[LaneGraphEdgeMapObject]] = dict()
        # Stored as a frozenset because membership is tested once per outgoing edge.
        self._candidate_lane_edge_ids = frozenset(candidate_lane_edge_ids)

    def search(self, target_roadblock: RoadBlockGraphEdgeMapObject) -> Tuple[List[LaneGraphEdgeMapObject], bool]:
        """
        Performs dijkstra's shortest path to find a route to the target roadblock.
        All search state is rebuilt on every call, so the same instance can be searched repeatedly.
        :param target_roadblock: The target roadblock the path should end at.
        :return:
            - A route starting from the given start edge
            - A bool indicating if the route is successfully found. Successful means that there exists a path
              from the start edge to an edge contained in the end roadblock.
              If unsuccessful the shortest deepest path is returned.
        """
        start_edge = self._start_edge
        self._parent = {start_edge.id: None}

        # Frontier bookkeeping: id -> (cost, insertion order, depth). The start edge is seeded
        # with cost 1 and depth 1.
        pending: Dict[str, Tuple[float, int, int]] = {start_edge.id: (1, 0, 1)}
        edge_by_id: Dict[str, LaneGraphEdgeMapObject] = {start_edge.id: start_edge}
        # Heap entries are (cost, insertion order, id): cheapest first, ties broken by the
        # order in which edges entered the frontier.
        heap: List[Tuple[float, int, str]] = [(1, 0, start_edge.id)]
        next_order = 1
        expanded_ids = set()

        # Best fallback seen so far: deepest expanded edge, cheapest among equally deep ones,
        # earliest expanded on a full tie.
        fallback_edge = start_edge
        fallback_depth = 0
        fallback_dist = 0

        while heap:
            dist, _, edge_id = heapq.heappop(heap)
            entry = pending.get(edge_id)
            if entry is None or entry[0] < dist:
                # Stale heap entry: the edge was already expanded or has a cheaper entry queued.
                continue
            del pending[edge_id]
            depth = entry[2]
            current_edge = edge_by_id[edge_id]

            if self._check_goal_condition(current_edge, target_roadblock):
                return self._construct_path(current_edge), True

            expanded_ids.add(edge_id)
            if depth > fallback_depth or (depth == fallback_depth and dist < fallback_dist):
                fallback_edge, fallback_depth, fallback_dist = current_edge, depth, dist

            # Populate queue
            for next_edge in current_edge.outgoing_edges:
                next_id = next_edge.id
                if next_id not in self._candidate_lane_edge_ids or next_id in expanded_ids:
                    continue

                alt = dist + self._edge_cost(next_edge)
                known = pending.get(next_id)
                if known is None:
                    order = next_order
                    next_order += 1
                    edge_by_id[next_id] = next_edge
                elif alt < known[0]:
                    # Cheaper route to a queued edge: keep its original insertion order so
                    # tie-breaking is unaffected; the old heap entry becomes stale.
                    order = known[1]
                else:
                    continue

                pending[next_id] = (alt, order, depth + 1)
                self._parent[next_id] = current_edge
                heapq.heappush(heap, (alt, order, next_id))

        # Target not reachable: return the path to the cheapest edge at maximum depth.
        return self._construct_path(fallback_edge), False

    @staticmethod
    def _edge_cost(lane: LaneGraphEdgeMapObject) -> float:
        """
        Edge cost of given lane.
        :param lane: lane class
        :return: length of lane
        """
        return lane.baseline_path.length

    @staticmethod
    def _check_end_condition(depth: int, target_depth: int) -> bool:
        """
        Check if the search should end regardless if the goal condition is met.
        :param depth: The current depth to check.
        :param target_depth: The target depth to check against.
        :return: True if:
            - The current depth exceeds the target depth.
        """
        return depth > target_depth

    @staticmethod
    def _check_goal_condition(
        current_edge: LaneGraphEdgeMapObject,
        target_roadblock: RoadBlockGraphEdgeMapObject,
    ) -> bool:
        """
        Check if the current edge is contained in the target roadblock.
        :param current_edge: The edge to check.
        :param target_roadblock: The target roadblock the edge should be contained in.
        :return: whether the current edge is in the target roadblock
        """
        return current_edge.get_roadblock_id() == target_roadblock.id

    def _construct_path(self, end_edge: LaneGraphEdgeMapObject) -> List[LaneGraphEdgeMapObject]:
        """
        Follow the recorded parent links from the end edge back to the start edge.
        :param end_edge: The end edge to start back propagating back to the start edge.
        :return: The constructed path as a list of LaneGraphEdgeMapObject, ordered from start edge to end edge
        """
        path = [end_edge]
        parent = self._parent[end_edge.id]
        while parent is not None:
            path.append(parent)
            parent = self._parent[parent.id]
        path.reverse()
        return path