heyoujue commited on
Commit
458a605
·
1 Parent(s): df1f00b

implement rrt planner

Browse files
interactive_3d_plot.html ADDED
The diff for this file is too large to render. See raw diff
 
main.py CHANGED
@@ -1,4 +1,6 @@
1
  from flight_environment import FlightEnvironment
 
 
2
 
3
  env = FlightEnvironment(50)
4
  start = (1,2,0)
@@ -13,12 +15,13 @@ goal = (18,18,3)
13
  # - column 3 contains the z-coordinates of all path points
14
  # This `path` array will be provided to the `env` object for visualization.
15
 
16
- path = [[0,0,0],[1,1,1],[2,2,2],[3,3,3]]
 
17
 
18
  # --------------------------------------------------------------------------------------------------- #
19
 
20
 
21
- env.plot_cylinders(path)
22
 
23
 
24
  # --------------------------------------------------------------------------------------------------- #
 
1
  from flight_environment import FlightEnvironment
2
+ from planners.rrt_planner import rrt_planner
3
+ from plot_plotly import plot_cylinders
4
 
5
  env = FlightEnvironment(50)
6
  start = (1,2,0)
 
15
  # - column 3 contains the z-coordinates of all path points
16
  # This `path` array will be provided to the `env` object for visualization.
17
 
18
+ #path = [[0,0,0],[1,1,1],[2,2,2],[3,3,3]]
19
+ path = rrt_planner(env, start, goal, step_size=1)
20
 
21
  # --------------------------------------------------------------------------------------------------- #
22
 
23
 
24
+ plot_cylinders(env, path)
25
 
26
 
27
  # --------------------------------------------------------------------------------------------------- #
path_planner.py DELETED
@@ -1,24 +0,0 @@
1
- """
2
- In this file, you should implement your own path planning class or function.
3
- Within your implementation, you may call `env.is_collide()` and `env.is_outside()`
4
- to verify whether candidate path points collide with obstacles or exceed the
5
- environment boundaries.
6
-
7
- You are required to write the path planning algorithm by yourself. Copying or calling
8
- any existing path planning algorithms from others is strictly
9
- prohibited. Please avoid using external packages beyond common Python libraries
10
- such as `numpy`, `math`, or `scipy`. If you must use additional packages, you
11
- must clearly explain the reason in your report.
12
- """
13
-
14
-
15
-
16
-
17
-
18
-
19
-
20
-
21
-
22
-
23
-
24
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
planners/__pycache__/rrt_planner.cpython-313.pyc ADDED
Binary file (5.31 kB). View file
 
planners/path_planner.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ """
3
+ In this file, you should implement your own path planning class or function.
4
+ Within your implementation, you may call `env.is_collide()` and `env.is_outside()`
5
+ to verify whether candidate path points collide with obstacles or exceed the
6
+ environment boundaries.
7
+
8
+ You are required to write the path planning algorithm by yourself. Copying or calling
9
+ any existing path planning algorithms from others is strictly
10
+ prohibited. Please avoid using external packages beyond common Python libraries
11
+ such as `numpy`, `math`, or `scipy`. If you must use additional packages, you
12
+ must clearly explain the reason in your report.
13
+ """
14
+
15
+ def is_segment_collision_free(env, p1, p2, step=0.1):
16
+ """Return True if straight segment p1->p2 is free.
17
+ - env: FlightEnvironment instance
18
+ - p1,p2: iterable (x,y,z)
19
+ - step: distance between samples along the segment (meters)
20
+ """
21
+ dist = np.linalg.norm(p2 - p1)
22
+ if dist == 0:
23
+ return not env.is_collide(tuple(p1)) and not env.is_collide(tuple(p1))
24
+ n = max(1, int(np.ceil(dist / step)))
25
+ for i in range(n + 1):
26
+ t = i / n
27
+ p = p1 * (1 - t) + p2 * t
28
+ if env.is_outside(tuple(p)) or env.is_collide(tuple(p)):
29
+ return False
30
+ return True
31
+
32
+ def rrt_planner(env, start, goal, max_iters=10000, step_size=1.0, goal_tolerance=0.8, goal_sample_rate=0.1):
33
+ """
34
+ Basic RRT planner in continuous 3D space with cylinder obstacles provided by env.
35
+ Returns a path as an (N x 3) numpy array from start to goal (inclusive).
36
+ """
37
+ start = np.array(start, dtype=float)
38
+ goal = np.array(goal, dtype=float)
39
+
40
+ # quick checks
41
+ if env.is_outside(tuple(start)) or env.is_collide(tuple(start)):
42
+ raise RuntimeError("Start is invalid (outside or in collision).")
43
+ if env.is_outside(tuple(goal)) or env.is_collide(tuple(goal)):
44
+ raise RuntimeError("Goal is invalid (outside or in collision).")
45
+
46
+ # recorded statistics stored here
47
+ stats = dict(collisions=0)
48
+
49
+ @dataclass
50
+ class Node:
51
+ point: np.ndarray
52
+ parent: np.ndarray = None
53
+
54
+ # environment bounds
55
+ Xmax, Ymax, Zmax = env.space_size
56
+
57
+ nodes = [Node(start)]
58
+
59
+ for it in range(max_iters):
60
+ # sample point
61
+ if np.random.rand() < goal_sample_rate:
62
+ sample = goal.copy()
63
+ else:
64
+ sample = np.array([np.random.uniform(0, Xmax),
65
+ np.random.uniform(0, Ymax),
66
+ np.random.uniform(0, Zmax)])
67
+ # find nearest node
68
+ dists = [np.linalg.norm(node.point - sample) for node in nodes]
69
+ nearest_idx = int(np.argmin(dists))
70
+ nearest = nodes[nearest_idx]
71
+
72
+ direction = sample - nearest.point
73
+ norm = np.linalg.norm(direction)
74
+ if norm == 0:
75
+ print('norm0')
76
+ continue
77
+ direction = direction / norm
78
+
79
+ new_point = nearest.point + direction * min(step_size, norm)
80
+
81
+ # ensure within bounds
82
+ if env.is_outside(tuple(new_point)):
83
+ print('outside')
84
+ continue
85
+
86
+ # check collision along segment
87
+ if not is_segment_collision_free(env, nearest.point, new_point, step=step_size / 5.0):
88
+ stats["collisions"]+=1
89
+ continue
90
+
91
+ new_node = Node(new_point, nearest)
92
+ nodes.append(new_node)
93
+
94
+ # check if we reached goal
95
+ if np.linalg.norm(new_point - goal) <= goal_tolerance:
96
+ # try connect directly to goal
97
+ if is_segment_collision_free(env, new_point, goal, step=step_size / 5.0):
98
+ goal_node = Node(goal, new_node)
99
+ nodes.append(goal_node)
100
+ # reconstruct path
101
+ path_nodes = []
102
+ cur = goal_node
103
+ while cur is not None:
104
+ path_nodes.append(cur.point)
105
+ cur = cur.parent
106
+ path = np.array(path_nodes[::-1])
107
+ return path
108
+
109
+ raise RuntimeError(f"RRT failed to find a path within the given iterations. {stats=}")
planners/rrt_planner.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import numpy as np
3
+
4
+ def is_segment_collision_free(env, p1, p2, step=0.1):
5
+ """Return True if straight segment p1->p2 is free.
6
+ - env: FlightEnvironment instance
7
+ - p1,p2: iterable (x,y,z)
8
+ - step: distance between samples along the segment (meters)
9
+ """
10
+ dist = np.linalg.norm(p2 - p1)
11
+ if dist == 0:
12
+ return not env.is_collide(tuple(p1)) and not env.is_collide(tuple(p1))
13
+ n = max(1, int(np.ceil(dist / step)))
14
+ for i in range(n + 1):
15
+ t = i / n
16
+ p = p1 * (1 - t) + p2 * t
17
+ if env.is_outside(tuple(p)) or env.is_collide(tuple(p)):
18
+ return False
19
+ return True
20
+
21
+ def rrt_planner(env, start, goal, max_iters=10000, step_size=1.0, goal_tolerance=0.8, goal_sample_rate=0.1):
22
+ """
23
+ Basic RRT planner in continuous 3D space with cylinder obstacles provided by env.
24
+ Returns a path as an (N x 3) numpy array from start to goal (inclusive).
25
+ """
26
+ start = np.array(start, dtype=float)
27
+ goal = np.array(goal, dtype=float)
28
+
29
+ # quick checks
30
+ if env.is_outside(tuple(start)) or env.is_collide(tuple(start)):
31
+ raise RuntimeError("Start is invalid (outside or in collision).")
32
+ if env.is_outside(tuple(goal)) or env.is_collide(tuple(goal)):
33
+ raise RuntimeError("Goal is invalid (outside or in collision).")
34
+
35
+ # recorded statistics stored here
36
+ stats = dict(collisions=0)
37
+
38
+ @dataclass
39
+ class Node:
40
+ point: np.ndarray
41
+ parent: np.ndarray = None
42
+
43
+ # environment bounds
44
+ Xmax, Ymax, Zmax = env.space_size
45
+
46
+ nodes = [Node(start)]
47
+
48
+ for it in range(max_iters):
49
+ # sample point
50
+ if np.random.rand() < goal_sample_rate:
51
+ sample = goal.copy()
52
+ else:
53
+ sample = np.array([np.random.uniform(0, Xmax),
54
+ np.random.uniform(0, Ymax),
55
+ np.random.uniform(0, Zmax)])
56
+ # find nearest node
57
+ dists = [np.linalg.norm(node.point - sample) for node in nodes]
58
+ nearest_idx = int(np.argmin(dists))
59
+ nearest = nodes[nearest_idx]
60
+
61
+ direction = sample - nearest.point
62
+ norm = np.linalg.norm(direction)
63
+ if norm == 0:
64
+ print('norm0')
65
+ continue
66
+ direction = direction / norm
67
+
68
+ new_point = nearest.point + direction * min(step_size, norm)
69
+
70
+ # ensure within bounds
71
+ if env.is_outside(tuple(new_point)):
72
+ print('outside')
73
+ continue
74
+
75
+ # check collision along segment
76
+ if not is_segment_collision_free(env, nearest.point, new_point, step=step_size / 5.0):
77
+ stats["collisions"]+=1
78
+ continue
79
+
80
+ new_node = Node(new_point, nearest)
81
+ nodes.append(new_node)
82
+
83
+ # check if we reached goal
84
+ if np.linalg.norm(new_point - goal) <= goal_tolerance:
85
+ # try connect directly to goal
86
+ if is_segment_collision_free(env, new_point, goal, step=step_size / 5.0):
87
+ goal_node = Node(goal, new_node)
88
+ nodes.append(goal_node)
89
+ # reconstruct path
90
+ path_nodes = []
91
+ cur = goal_node
92
+ while cur is not None:
93
+ path_nodes.append(cur.point)
94
+ cur = cur.parent
95
+ path = np.array(path_nodes[::-1])
96
+ return path
97
+
98
+ raise RuntimeError(f"RRT failed to find a path within the given iterations. {stats=}")
plot_plotly.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import plotly.graph_objects as go
3
+ import plotly.io as pio
4
+
5
+ def plot_cylinders(env, path=None, save_html=True):
6
+ """
7
+ 使用Plotly创建高性能交互式3D图形
8
+ """
9
+ cylinders = env.cylinders
10
+ space_size = env.space_size
11
+
12
+ fig = go.Figure()
13
+
14
+ # 添加圆柱体
15
+ for i, (cx, cy, h, r) in enumerate(cylinders):
16
+ # 创建圆柱体侧面
17
+ z = np.linspace(0, h, 20)
18
+ theta = np.linspace(0, 2 * np.pi, 30)
19
+ theta_grid, z_grid = np.meshgrid(theta, z)
20
+
21
+ x = cx + r * np.cos(theta_grid)
22
+ y = cy + r * np.sin(theta_grid)
23
+
24
+ fig.add_trace(go.Surface(
25
+ x=x, y=y, z=z_grid,
26
+ colorscale=[[0, 'lightblue'], [1, 'steelblue']],
27
+ opacity=0.7,
28
+ showscale=False,
29
+ name=f'Cylinder {i}',
30
+ hovertemplate=f'<b>Cylinder {i}</b><br>' +
31
+ f'Center: ({cx:.1f}, {cy:.1f})<br>' +
32
+ f'Height: {h:.1f}<br>' +
33
+ f'Radius: {r:.1f}<extra></extra>'
34
+ ))
35
+
36
+ # 添加顶部
37
+ theta2 = np.linspace(0, 2 * np.pi, 30)
38
+ x_top = cx + r * np.cos(theta2)
39
+ y_top = cy + r * np.sin(theta2)
40
+ z_top = np.full_like(theta2, h)
41
+
42
+ # 创建顶部面
43
+ fig.add_trace(go.Mesh3d(
44
+ x=x_top, y=y_top, z=z_top,
45
+ color='steelblue',
46
+ opacity=0.7,
47
+ alphahull=0,
48
+ showscale=False
49
+ ))
50
+
51
+ # 添加路径
52
+ if path is not None:
53
+ path = np.array(path)
54
+ xs, ys, zs = path[:, 0], path[:, 1], path[:, 2]
55
+
56
+ fig.add_trace(go.Scatter3d(
57
+ x=xs, y=ys, z=zs,
58
+ mode='lines+markers',
59
+ line=dict(color='red', width=6),
60
+ marker=dict(
61
+ size=4,
62
+ color=zs,
63
+ colorscale='Viridis',
64
+ showscale=True,
65
+ colorbar=dict(title="高度")
66
+ ),
67
+ name='Path',
68
+ hovertemplate='<b>Path Point</b><br>' +
69
+ 'x: %{x:.2f}<br>' +
70
+ 'y: %{y:.2f}<br>' +
71
+ 'z: %{z:.2f}<extra></extra>'
72
+ ))
73
+
74
+ # 添加起点和终点
75
+ fig.add_trace(go.Scatter3d(
76
+ x=[xs[0]], y=[ys[0]], z=[zs[0]],
77
+ mode='markers',
78
+ marker=dict(size=10, color='green', symbol='circle'),
79
+ name='Start',
80
+ hovertemplate='<b>Start</b><br>' +
81
+ f'Position: ({xs[0]:.2f}, {ys[0]:.2f}, {zs[0]:.2f})<extra></extra>'
82
+ ))
83
+
84
+ fig.add_trace(go.Scatter3d(
85
+ x=[xs[-1]], y=[ys[-1]], z=[zs[-1]],
86
+ mode='markers',
87
+ marker=dict(size=10, color='red', symbol='circle'),
88
+ name='End',
89
+ hovertemplate='<b>End</b><br>' +
90
+ f'Position: ({xs[-1]:.2f}, {ys[-1]:.2f}, {zs[-1]:.2f})<extra></extra>'
91
+ ))
92
+
93
+ # 更新布局
94
+ fig.update_layout(
95
+ title=dict(
96
+ text='3D Environment with Path',
97
+ x=0.5,
98
+ font=dict(size=20)
99
+ ),
100
+ scene=dict(
101
+ xaxis=dict(
102
+ title='X',
103
+ range=[0, env.env_width],
104
+ backgroundcolor='rgb(240, 240, 240)',
105
+ gridcolor='white',
106
+ showbackground=True
107
+ ),
108
+ yaxis=dict(
109
+ title='Y',
110
+ range=[0, env.env_length],
111
+ backgroundcolor='rgb(240, 240, 240)',
112
+ gridcolor='white',
113
+ showbackground=True
114
+ ),
115
+ zaxis=dict(
116
+ title='Z',
117
+ range=[0, env.env_height],
118
+ backgroundcolor='rgb(240, 240, 240)',
119
+ gridcolor='white',
120
+ showbackground=True
121
+ ),
122
+ aspectmode='manual',
123
+ aspectratio=dict(x=1, y=env.env_length / env.env_width,
124
+ z=env.env_height / env.env_width)
125
+ ),
126
+ showlegend=True,
127
+ width=1200,
128
+ height=800,
129
+ hovermode='closest'
130
+ )
131
+
132
+ # 保存为HTML(可选)
133
+ if save_html:
134
+ pio.write_html(fig, file='interactive_3d_plot.html', auto_open=True)
135
+
136
+ return fig