de99 commited on
Commit
85b09d9
·
verified ·
1 Parent(s): c210daf

Upload calculate_trajectory_distance.py

Browse files
Files changed (1) hide show
  1. calculate_trajectory_distance.py +39 -0
calculate_trajectory_distance.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import numpy as np
3
+
4
+ # 读取pkl文件
5
+ pkl_path = 'data_splits/airvln_16/test/airvln_16_3_trajectories_long.pkl'
6
+ with open(pkl_path, 'rb') as f:
7
+ trajectories_data = pickle.load(f)
8
+
9
+ # 存储每条轨迹的起点到终点距离
10
+ distances = []
11
+
12
+ # 遍历100条轨迹
13
+ for traj_id in range(100):
14
+ if traj_id not in trajectories_data:
15
+ continue
16
+
17
+ # 获取轨迹数据,形状为(3, 9, 4)
18
+ # 3条候选轨迹,每条9个点,每个点4个值(x, y, z, angle)
19
+ traj_candidates = trajectories_data[traj_id] # (3, 9, 4)
20
+
21
+ # 取第一条候选轨迹(索引0)
22
+ traj = traj_candidates[0] # (9, 4)
23
+
24
+ # 提取起点和终点坐标(前3个值是x, y, z坐标)
25
+ start_point = traj[0, :3] # 起点 (x, y, z)
26
+ end_point = traj[-1, :3] # 终点 (x, y, z)
27
+
28
+ # 计算欧氏距离
29
+ distance = np.linalg.norm(end_point - start_point)
30
+ distances.append(distance)
31
+
32
+ # 计算平均距离
33
+ if len(distances) > 0:
34
+ avg_distance = np.mean(distances)
35
+ print(f"共计算了 {len(distances)} 条轨迹")
36
+ print(f"起点到终点的平均距离: {avg_distance:.2f}")
37
+ else:
38
+ print("未找到有效轨迹数据")
39
+