Upload calculate_trajectory_distance.py
Browse files
calculate_trajectory_distance.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
# 读取pkl文件
|
| 5 |
+
pkl_path = 'data_splits/airvln_16/test/airvln_16_3_trajectories_long.pkl'
|
| 6 |
+
with open(pkl_path, 'rb') as f:
|
| 7 |
+
trajectories_data = pickle.load(f)
|
| 8 |
+
|
| 9 |
+
# 存储每条轨迹的起点到终点距离
|
| 10 |
+
distances = []
|
| 11 |
+
|
| 12 |
+
# 遍历100条轨迹
|
| 13 |
+
for traj_id in range(100):
|
| 14 |
+
if traj_id not in trajectories_data:
|
| 15 |
+
continue
|
| 16 |
+
|
| 17 |
+
# 获取轨迹数据,形状为(3, 9, 4)
|
| 18 |
+
# 3条候选轨迹,每条9个点,每个点4个值(x, y, z, angle)
|
| 19 |
+
traj_candidates = trajectories_data[traj_id] # (3, 9, 4)
|
| 20 |
+
|
| 21 |
+
# 取第一条候选轨迹(索引0)
|
| 22 |
+
traj = traj_candidates[0] # (9, 4)
|
| 23 |
+
|
| 24 |
+
# 提取起点和终点坐标(前3个值是x, y, z坐标)
|
| 25 |
+
start_point = traj[0, :3] # 起点 (x, y, z)
|
| 26 |
+
end_point = traj[-1, :3] # 终点 (x, y, z)
|
| 27 |
+
|
| 28 |
+
# 计算欧氏距离
|
| 29 |
+
distance = np.linalg.norm(end_point - start_point)
|
| 30 |
+
distances.append(distance)
|
| 31 |
+
|
| 32 |
+
# 计算平均距离
|
| 33 |
+
if len(distances) > 0:
|
| 34 |
+
avg_distance = np.mean(distances)
|
| 35 |
+
print(f"共计算了 {len(distances)} 条轨迹")
|
| 36 |
+
print(f"起点到终点的平均距离: {avg_distance:.2f}")
|
| 37 |
+
else:
|
| 38 |
+
print("未找到有效轨迹数据")
|
| 39 |
+
|