heyoujue's picture
show both plots at the same time
0dbe5a4
from dataclasses import dataclass
import numpy as np
import time
def is_segment_collision_free(env, p1, p2, step=0.1):
    """Return True if straight segment p1->p2 is free.

    The segment is sampled at evenly spaced points no more than ``step``
    apart, both endpoints included. Every sample must lie inside the
    environment (``env.is_outside`` is False) and must not collide
    (``env.is_collide`` is False).

    - env: FlightEnvironment instance; used via ``is_outside(point)`` and
      ``is_collide(point)``, each called with a tuple of coordinates.
    - p1,p2: iterable (x,y,z); converted to float numpy arrays.
    - step: maximum distance between samples along the segment (meters).
      Must be positive.

    Raises ValueError if ``step`` is not positive.
    """
    if step <= 0:
        # A negative step would silently check only the endpoints, and a
        # zero step would overflow in the sample count below.
        raise ValueError(f"step must be positive, got {step}")
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    dist = np.linalg.norm(p2 - p1)
    # For a degenerate segment (dist == 0) this gives n == 1. The single
    # point p1 then gets the same outside + collision test as any other
    # sample.
    n = max(1, int(np.ceil(dist / step)))
    for i in range(n + 1):
        t = i / n
        p = p1 * (1 - t) + p2 * t
        if env.is_outside(tuple(p)) or env.is_collide(tuple(p)):
            return False
    return True
def rrt_planner(env, start, goal, max_iters=10000, step_size=1.0, goal_tolerance=0.8, goal_sample_rate=0.1):
    """
    Basic RRT planner in continuous 3D space.

    Grows a tree from ``start``. Each iteration samples a point, finds the
    nearest tree node and extends from it by at most ``step_size`` toward
    the sample. As soon as a new node lies within ``goal_tolerance`` of the
    goal and the straight segment to the goal is collision-free, the path is
    reconstructed and returned.

    Parameters:
    - env: FlightEnvironment instance; used via ``is_outside``,
      ``is_collide`` and ``space_size``, which is unpacked as
      (Xmax, Ymax, Zmax). Random samples are drawn uniformly from
      [0, Xmax) x [0, Ymax) x [0, Zmax), which assumes the lower bounds
      are 0 (TODO confirm).
    - start, goal: (x, y, z) coordinates.
    - max_iters: number of sampling iterations before giving up.
    - step_size: maximum extension length per iteration. Segments are
      collision-checked every ``step_size / 5``.
    - goal_tolerance: distance to the goal at which a direct connection is
      attempted.
    - goal_sample_rate: probability of sampling the goal itself instead of a
      random point (goal bias).

    Returns a path as an (N x 3) numpy array from start to goal (inclusive),
    without a repeated final goal waypoint.

    Raises RuntimeError if start or goal is outside or in collision, or if
    no path is found within ``max_iters`` iterations.
    """
    start = np.array(start, dtype=float)
    goal = np.array(goal, dtype=float)
    # quick checks
    if env.is_outside(tuple(start)) or env.is_collide(tuple(start)):
        raise RuntimeError("Start is invalid (outside or in collision).")
    if env.is_outside(tuple(goal)) or env.is_collide(tuple(goal)):
        raise RuntimeError("Goal is invalid (outside or in collision).")
    # recorded statistics stored here, can be inspected later
    stats = dict(collisions=0)
    tstart = time.time()
    # environment bounds
    Xmax, Ymax, Zmax = env.space_size
    check_step = step_size / 5.0

    # Tree storage: row i of `points` is node i, and parents[i] is the index
    # of its parent (-1 for the root). Keeping all points in one array lets
    # the nearest-node query run as one vectorized numpy expression instead
    # of a Python loop over every node.
    points = np.empty((64, start.size))
    points[0] = start
    parents = [-1]

    for _ in range(max_iters):
        # With probability goal_sample_rate take the goal (goal bias);
        # otherwise sample a uniformly random point within the bounds.
        if np.random.rand() < goal_sample_rate:
            sample = goal.copy()
        else:
            sample = np.array([np.random.uniform(0, Xmax),
                               np.random.uniform(0, Ymax),
                               np.random.uniform(0, Zmax)])
        n_nodes = len(parents)
        # find nearest node (vectorized over all current nodes)
        dists = np.linalg.norm(points[:n_nodes] - sample, axis=1)
        nearest_idx = int(np.argmin(dists))
        nearest_point = points[nearest_idx]
        norm = dists[nearest_idx]
        if norm == 0:
            continue
        if norm <= step_size:
            # The sample is within reach, so snap exactly onto it. This
            # avoids float drift and lets a goal sample land exactly on the
            # goal, so the goal is not appended a second time below.
            new_point = sample
        else:
            direction = (sample - nearest_point) / norm
            new_point = nearest_point + direction * step_size
        # ensure within bounds
        if env.is_outside(tuple(new_point)):
            continue
        # check collision along segment
        if not is_segment_collision_free(env, nearest_point, new_point, step=check_step):
            stats["collisions"] += 1
            continue
        if n_nodes == len(points):
            # Grow storage geometrically (amortized O(1) per insertion).
            points = np.concatenate((points, np.empty_like(points)))
        points[n_nodes] = new_point
        parents.append(nearest_idx)
        new_idx = n_nodes

        # check if we reached goal
        goal_dist = np.linalg.norm(new_point - goal)
        if goal_dist > goal_tolerance:
            continue
        # Try to connect directly to the goal, unless the new node already
        # is the goal.
        if goal_dist > 0 and not is_segment_collision_free(env, new_point, goal, step=check_step):
            continue
        # reconstruct path by walking parent links back to the root
        path_points = []
        idx = new_idx
        while idx != -1:
            path_points.append(points[idx])
            idx = parents[idx]
        path_points.reverse()
        n_tree = len(parents)
        if goal_dist > 0:
            path_points.append(goal)
            n_tree += 1  # count the goal as a tree node, as before
        path = np.array(path_points)
        print("reached goal", n_tree, "nodes", len(path), "waypoints, avoided", stats['collisions'], "collisions", f'{time.time()-tstart:.2f}', "seconds")
        return path
    raise RuntimeError(f"RRT failed to find a path within the given iterations. {stats=}")