File size: 3,644 Bytes
0dbe5a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from dataclasses import dataclass
import numpy as np
import time

def is_segment_collision_free(env, p1, p2, step=0.1):
    """Return True if the straight segment p1 -> p2 is free.

    The segment is sampled at ``ceil(|p2 - p1| / step)`` equal intervals, with
    both end points included. It counts as free only if no sample is outside
    the environment or in collision.

    Args:
        env: FlightEnvironment instance providing ``is_outside(p)`` and
            ``is_collide(p)`` for a 3-tuple ``p``.
        p1, p2: segment end points; any sequence of three numbers (x, y, z).
        step: maximum distance between consecutive samples (meters, > 0).

    Returns:
        True if every sample is inside the environment and collision-free,
        otherwise False.

    Raises:
        ValueError: if ``step`` is not positive.
    """
    if step <= 0:
        # would otherwise divide by zero and fail in int(inf)
        raise ValueError(f"step must be positive, got {step!r}")
    # accept plain sequences as documented; the arithmetic below needs arrays
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    dist = np.linalg.norm(p2 - p1)
    if dist == 0:
        # degenerate segment: apply the same outside/collision test that
        # every sample of a non-degenerate segment gets
        q = tuple(p1)
        return not (env.is_outside(q) or env.is_collide(q))
    n = max(1, int(np.ceil(dist / step)))
    for i in range(n + 1):
        t = i / n
        p = p1 * (1 - t) + p2 * t
        if env.is_outside(tuple(p)) or env.is_collide(tuple(p)):
            return False
    return True

def _trace_path(points, parents, idx):
    """Return the (N x 3) array of tree points from the root down to node ``idx``."""
    chain = []
    while idx != -1:
        chain.append(points[idx])
        idx = parents[idx]
    return np.array(chain[::-1])


def rrt_planner(env, start, goal, max_iters=10000, step_size=1.0, goal_tolerance=0.8,
                goal_sample_rate=0.1, collision_step=None):
    """
    Basic RRT planner in continuous 3D space.

    How the tree grows:
    - Each iteration draws a sample point. With probability ``goal_sample_rate``
      it is the goal itself; otherwise it is uniform in the box
      ``[0, Xmax] x [0, Ymax] x [0, Zmax]`` built from ``env.space_size``.
    - The nearest tree node is extended towards the sample by at most
      ``step_size``.
    - The planner stops once a new node lies within ``goal_tolerance`` of
      ``goal`` and the straight segment to the goal is collision-free.
      On success a one-line summary is printed.

    Args:
        env: FlightEnvironment instance providing ``is_outside(p)``,
            ``is_collide(p)`` (``p`` a 3-tuple) and ``space_size``.
        start, goal: 3D points, any sequence of three numbers.
        max_iters: number of sampling iterations before giving up.
        step_size: maximum extension length per iteration (must be > 0).
        goal_tolerance: distance at which a node may connect to the goal.
        goal_sample_rate: probability of sampling the goal itself.
        collision_step: spacing of the collision samples taken along each
            segment; defaults to ``step_size / 5``.

    Returns:
        A path as an (N x 3) numpy array from start to goal (inclusive).

    Raises:
        ValueError: if ``step_size`` or ``collision_step`` is not positive.
        RuntimeError: if start or goal is outside / in collision, or if no
            path is found within ``max_iters`` iterations.
    """
    start = np.array(start, dtype=float)
    goal = np.array(goal, dtype=float)

    if step_size <= 0:
        raise ValueError(f"step_size must be positive, got {step_size!r}")
    if collision_step is None:
        collision_step = step_size / 5.0
    elif collision_step <= 0:
        raise ValueError(f"collision_step must be positive, got {collision_step!r}")

    # quick checks
    if env.is_outside(tuple(start)) or env.is_collide(tuple(start)):
        raise RuntimeError("Start is invalid (outside or in collision).")
    if env.is_outside(tuple(goal)) or env.is_collide(tuple(goal)):
        raise RuntimeError("Goal is invalid (outside or in collision).")

    # counters reported in the success message and in the failure exception
    stats = dict(collisions=0)
    tstart = time.perf_counter()

    # Sampling box. Coordinates are drawn from 0 upwards, which assumes the
    # workspace starts at the origin on every axis.
    Xmax, Ymax, Zmax = env.space_size
    upper = np.array([Xmax, Ymax, Zmax], dtype=float)

    # Tree storage:
    # - node coordinates live in a growable (capacity x 3) buffer, so the
    #   nearest-neighbour query is a single vectorised operation;
    # - parents[i] is the index of node i's parent (-1 for the root).
    points = np.empty((64, 3))
    points[0] = start
    parents = [-1]

    def add_node(point, parent_idx):
        """Append a node (doubling the buffer when full) and return its index."""
        nonlocal points
        idx = len(parents)
        if idx == len(points):
            points = np.concatenate([points, np.empty_like(points)])
        points[idx] = point
        parents.append(parent_idx)
        return idx

    def try_finish(idx):
        """Return the final path if node ``idx`` can reach the goal, else None."""
        point = points[idx]
        if np.linalg.norm(point - goal) > goal_tolerance:
            return None
        # A goal-biased step may land exactly on the goal. In that case the
        # node already is the goal and must not be appended a second time.
        if not np.array_equal(point, goal):
            if not is_segment_collision_free(env, point, goal, step=collision_step):
                return None
            idx = add_node(goal, idx)
        path = _trace_path(points, parents, idx)
        print("reached goal", len(parents), "nodes", len(path), "waypoints, avoided",
              stats['collisions'], "collisions", f'{time.perf_counter()-tstart:.2f}', "seconds")
        return path

    for _ in range(max_iters):
        # With probability goal_sample_rate aim straight at the goal (goal
        # bias); otherwise sample uniformly in the workspace box.
        if np.random.rand() < goal_sample_rate:
            sample = goal.copy()
        else:
            sample = np.random.uniform(0.0, upper)

        # Nearest node. The squared distance gives the same ordering as the
        # Euclidean norm and avoids the sqrt.
        count = len(parents)
        nearest_idx = int(np.argmin(np.sum((points[:count] - sample) ** 2, axis=1)))
        nearest = points[nearest_idx]

        direction = sample - nearest
        dist = np.linalg.norm(direction)
        if dist == 0:
            continue
        # step towards the sample; land on it exactly when it is within reach
        if dist <= step_size:
            new_point = sample
        else:
            new_point = nearest + direction * (step_size / dist)

        # ensure within bounds
        if env.is_outside(tuple(new_point)):
            continue

        # check collision along segment
        if not is_segment_collision_free(env, nearest, new_point, step=collision_step):
            stats["collisions"] += 1
            continue

        path = try_finish(add_node(new_point, nearest_idx))
        if path is not None:
            return path

    raise RuntimeError(f"RRT failed to find a path within the given iterations. {stats=}")