Hakureirm committed on
Commit
d4c9e7d
·
1 Parent(s): 346f515

修改聚类参数

Browse files
Files changed (2) hide show
  1. models/.DS_Store +3 -0
  2. src/gait_analyze.py +217 -117
models/.DS_Store ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49618902cea3b65197edab1b87a9814144c76a88ed7f3e79b31c153418e861a4
3
+ size 6148
src/gait_analyze.py CHANGED
@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
9
  from matplotlib.patches import Rectangle
10
  import seaborn as sns
11
  import os
 
12
 
13
  @dataclass
14
  class GaitPrint:
@@ -36,9 +37,14 @@ class GaitAnalyzer:
36
  self.mice_positions: List[Dict] = [] # 现在存储pose关键点信息
37
  self.params = {}
38
  self.time_window = 0.2
39
- self.distance_threshold = 30
40
  self.gait_pattern = None
41
  self.result_dir = self._create_result_dir()
 
 
 
 
 
42
 
43
  def _detect_mouse_time_range(self, video_path: str, margin_ratio: float = 0.05) -> Tuple[float, float]:
44
  """使用pose模型的鼻子和尾巴点来检测老鼠"""
@@ -122,38 +128,103 @@ class GaitAnalyzer:
122
  from sklearn.preprocessing import StandardScaler
123
  from sklearn.cluster import DBSCAN
124
 
125
- # 1. 准备数据
126
- features = np.array([[p.x, p.y, p.timestamp * 30] for p in self.gait_prints])
 
 
 
 
127
  print(f"开始聚类,原始足印数量: {len(features)}")
128
 
129
  # 2. 标准化特征
130
  scaler = StandardScaler()
131
  features_scaled = scaler.fit_transform(features)
132
 
133
- # 3. DBSCAN聚类
134
- eps = 0.3 # 可以根据实际情况调整
135
- min_samples = 1 # 设为1以保留所有检测
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  dbscan = DBSCAN(eps=eps, min_samples=min_samples)
137
  cluster_labels = dbscan.fit_predict(features_scaled)
138
 
139
- # 4. 将聚类结果添加到足印对象中
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  for print_obj, label in zip(self.gait_prints, cluster_labels):
141
  print_obj.cluster_id = label
142
 
143
- # 打印聚类统计信息
144
- n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
145
- print(f"聚类完成! 共识别出 {n_clusters} 个独立足印")
146
-
147
- # 按cluster_id分组并打印每组的大小
148
  from collections import Counter
149
  cluster_sizes = Counter(cluster_labels)
150
  print("\n各组足印检测数量:")
151
  for cluster_id, size in sorted(cluster_sizes.items()):
152
  if cluster_id != -1:
153
- print(f"足印 #{cluster_id}: {size}个检测")
 
 
154
 
155
  except Exception as e:
156
  print(f"聚类过程出错: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  def _post_process_footprints(self):
159
  """后处理足迹数据:聚类、过滤和分类"""
@@ -179,113 +250,141 @@ class GaitAnalyzer:
179
  # 4. 确定步态周期
180
  self._determine_gait_cycles()
181
 
182
- def _classify_footprints(self, moving_right: bool):
183
- """使用pose关键点来分类足迹"""
184
- if len(self.gait_prints) < 4 or not self.mice_positions:
185
- print("警告:足迹或姿态数据不足")
186
- return
187
-
188
- # 按时间排序足印
189
- sorted_prints = sorted(self.gait_prints, key=lambda p: p.timestamp)
190
 
191
- # 对每个cluster进行分类
192
- cluster_groups = {}
193
- for p in sorted_prints:
194
- if p.cluster_id not in cluster_groups:
195
- cluster_groups[p.cluster_id] = []
196
- cluster_groups[p.cluster_id].append(p)
197
-
198
- # 对每个cluster,找到最近时间的pose数据
199
- for cluster_id, prints in cluster_groups.items():
200
- mid_time = np.mean([p.timestamp for p in prints])
201
- closest_pose = min(self.mice_positions,
202
- key=lambda m: abs(m['timestamp'] - mid_time))
 
 
 
 
 
 
 
 
 
203
 
204
- if 'keypoints' not in closest_pose:
205
- print(f"警告:时间戳 {mid_time:.2f}s 处的姿态数据缺少关键点信息")
 
 
 
 
 
 
 
 
 
 
 
206
  continue
207
 
208
- # 计算cluster中心位置
209
- center_x = np.mean([p.x for p in prints])
210
- center_y = np.mean([p.y for p in prints])
211
-
212
- # 获取关键点位置
213
- nose = closest_pose['keypoints']['nose']
214
- re = closest_pose['keypoints']['right_ear']
215
- le = closest_pose['keypoints']['left_ear']
216
- mid = closest_pose['keypoints']['mid']
217
- rl = closest_pose['keypoints']['right_leg']
218
- ll = closest_pose['keypoints']['left_leg']
219
- tail = closest_pose['keypoints']['tail_base']
220
-
221
- # 计算到各关键点的距离
222
- dist_to_front = min(
223
- np.sqrt((center_x - nose[0])**2 + (center_y - nose[1])**2),
224
- np.sqrt((center_x - re[0])**2 + (center_y - re[1])**2),
225
- np.sqrt((center_x - le[0])**2 + (center_y - le[1])**2)
226
- )
227
-
228
- dist_to_back = min(
229
- np.sqrt((center_x - rl[0])**2 + (center_y - rl[1])**2),
230
- np.sqrt((center_x - ll[0])**2 + (center_y - ll[1])**2),
231
- np.sqrt((center_x - tail[0])**2 + (center_y - tail[1])**2)
232
- )
233
-
234
- # 前后判断:比较到前后关键点的距离
235
- is_front = dist_to_front < dist_to_back
236
-
237
- # 左右判断:根据y坐标相对位置
238
- if is_front:
239
- is_left = center_y > (re[1] + le[1])/2 # 比较与耳朵中点的位置
 
 
 
 
 
240
  else:
241
- is_left = center_y > (rl[1] + ll[1])/2 # 比较与后腿中点的位置
242
 
243
- # 确定爪子类型
244
- paw_type = None
245
  if is_front:
246
- paw_type = 'LF' if is_left else 'RF'
 
 
 
247
  else:
248
- paw_type = 'LH' if is_left else 'RH'
 
 
 
249
 
250
- # 应用分类结果
251
- for p in prints:
252
- p.paw_type = paw_type
253
-
254
- # 按帧组织足印数据,用于时序一致性检查
255
- frame_prints = {}
256
- for p in sorted_prints:
257
- if p.frame_id not in frame_prints:
258
- frame_prints[p.frame_id] = []
259
- frame_prints[p.frame_id].append(p)
260
-
261
- # 进行时序一致性检查
262
- self._enforce_temporal_consistency(frame_prints)
 
 
263
 
264
- def _enforce_temporal_consistency(self, frame_prints):
265
- """确保时序一致性:同一时间同一类型的足印只能有一个"""
266
- for frame_id, prints in frame_prints.items():
267
- # 按类型分组
268
- type_groups = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
269
- for p in prints:
270
- if p.paw_type:
271
- type_groups[p.paw_type].append(p)
272
-
273
- # 处理每个有多个足印的类型
274
- for paw_type, group in type_groups.items():
275
- if len(group) > 1:
276
- # 保留cluster_id较大的足印(通常是较新的足印)
277
- newest_print = max(group, key=lambda p: p.cluster_id)
278
- for p in group:
279
- if p != newest_print:
280
- # 将重复的足印重新分类为对角的另一只脚
281
- if paw_type == 'LF':
282
- p.paw_type = 'RH'
283
- elif paw_type == 'RF':
284
- p.paw_type = 'LH'
285
- elif paw_type == 'LH':
286
- p.paw_type = 'RF'
287
- else: # RH
288
- p.paw_type = 'LF'
 
289
 
290
  def _smooth_classifications(self):
291
  """使用时序信息平滑分类结果"""
@@ -313,7 +412,7 @@ class GaitAnalyzer:
313
 
314
  # 如果当前足迹位置偏离太远,考虑重新分类
315
  dist = np.sqrt((print.x - avg_x)**2 + (print.y - avg_y)**2)
316
- if dist > 50: # 像素距离阈值
317
  # 尝试重新分类
318
  self._reclassify_print(print, sorted_prints, i)
319
 
@@ -719,6 +818,7 @@ class GaitAnalyzer:
719
  print("\n[6/6] 生成轨迹视频...")
720
  self.generate_trajectory_video(video_path)
721
  print("视频生成完成!")
 
722
 
723
  def _get_paw_color(self, paw_type: str) -> Tuple[int, int, int]:
724
  """获取不同爪子类型的颜色"""
@@ -1245,10 +1345,10 @@ class GaitAnalyzer:
1245
 
1246
  # 转换爪子类型
1247
  type_map = {
1248
- 'LF': 'leftFront',
1249
- 'RF': 'rightFront',
1250
- 'LH': 'leftHind',
1251
- 'RH': 'rightHind'
1252
  }
1253
 
1254
  # 生成frames数据
@@ -1367,7 +1467,7 @@ class GaitAnalyzer:
1367
 
1368
  def main():
1369
  analyzer = GaitAnalyzer()
1370
- video_path = "/Users/hakureirm/codespace/Work/Algorithm/gait/exp_videos/Exp8.mp4"
1371
 
1372
  # 自动检测时间范围
1373
  start_time, end_time = analyzer._detect_mouse_time_range(video_path)
@@ -1377,7 +1477,7 @@ def main():
1377
  video_path,
1378
  start_time=start_time,
1379
  end_time=end_time,
1380
- conf_thres=0.7,
1381
  iou_thres=0.5
1382
  )
1383
 
 
9
  from matplotlib.patches import Rectangle
10
  import seaborn as sns
11
  import os
12
+ import logging
13
 
14
  @dataclass
15
  class GaitPrint:
 
37
  self.mice_positions: List[Dict] = [] # 现在存储pose关键点信息
38
  self.params = {}
39
  self.time_window = 0.2
40
+ self.distance_threshold = 5
41
  self.gait_pattern = None
42
  self.result_dir = self._create_result_dir()
43
+
44
+ # 添加 logger
45
+ self.logger = logging.getLogger(__name__)
46
+
47
+ self.fps = 120 # 添加 fps 属性,默认值为 120
48
 
49
  def _detect_mouse_time_range(self, video_path: str, margin_ratio: float = 0.05) -> Tuple[float, float]:
50
  """使用pose模型的鼻子和尾巴点来检测老鼠"""
 
128
  from sklearn.preprocessing import StandardScaler
129
  from sklearn.cluster import DBSCAN
130
 
131
+ # 1. 准备数据 - 调整时间权重
132
+ time_weight = 0.2 # 减小时间维度的权重
133
+ features = np.array([
134
+ [p.x, p.y, p.timestamp * self.fps * time_weight]
135
+ for p in self.gait_prints
136
+ ])
137
  print(f"开始聚类,原始足印数量: {len(features)}")
138
 
139
  # 2. 标准化特征
140
  scaler = StandardScaler()
141
  features_scaled = scaler.fit_transform(features)
142
 
143
+ # 3. 计算合适的eps
144
+ from sklearn.neighbors import NearestNeighbors
145
+ k = min(len(features), 5)
146
+ nbrs = NearestNeighbors(n_neighbors=k).fit(features_scaled)
147
+ distances, _ = nbrs.kneighbors(features_scaled)
148
+ mean_dist = np.mean(distances[:, 1:])
149
+
150
+ # 设置更大的eps以获得更合适的聚类
151
+ eps = mean_dist * 3.0 # 增大eps
152
+
153
+ # 放宽最小样本数要求
154
+ min_samples = max(int(0.02 * self.fps), 2) # 降低持续时间要求到0.02秒
155
+
156
+ print(f"聚类参数: eps={eps:.3f}, min_samples={min_samples}")
157
+
158
+ # 4. DBSCAN聚类
159
  dbscan = DBSCAN(eps=eps, min_samples=min_samples)
160
  cluster_labels = dbscan.fit_predict(features_scaled)
161
 
162
+ # 5. 评估聚类结果
163
+ n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
164
+ n_noise = list(cluster_labels).count(-1)
165
+
166
+ print(f"\n聚类结果:")
167
+ print(f"- 识别出的独立足印数: {n_clusters}")
168
+ print(f"- 噪声点数量: {n_noise} ({n_noise/len(features)*100:.1f}%)")
169
+
170
+ # 6. 如果聚类数量不合理,调整参数重试
171
+ if n_clusters < 12 or n_clusters > 24: # 期望12-24个足印
172
+ print("\n尝试调整参数...")
173
+ for eps_factor in [2.0, 2.5, 3.0]: # 尝试更大的eps值
174
+ new_eps = mean_dist * eps_factor
175
+ dbscan = DBSCAN(eps=new_eps, min_samples=min_samples)
176
+ new_labels = dbscan.fit_predict(features_scaled)
177
+ new_n_clusters = len(set(new_labels)) - (1 if -1 in new_labels else 0)
178
+ new_n_noise = list(new_labels).count(-1)
179
+
180
+ print(f"eps={new_eps:.3f}: {new_n_clusters} 簇, {new_n_noise} 噪声点")
181
+
182
+ if 12 <= new_n_clusters <= 24:
183
+ print(f"使用新参数: eps={new_eps:.3f}")
184
+ cluster_labels = new_labels
185
+ break
186
+
187
+ # 7. 更新足印对象
188
  for print_obj, label in zip(self.gait_prints, cluster_labels):
189
  print_obj.cluster_id = label
190
 
191
+ # 8. 打印统计信息
 
 
 
 
192
  from collections import Counter
193
  cluster_sizes = Counter(cluster_labels)
194
  print("\n各组足印检测数量:")
195
  for cluster_id, size in sorted(cluster_sizes.items()):
196
  if cluster_id != -1:
197
+ cluster_prints = [p for p in self.gait_prints if p.cluster_id == cluster_id]
198
+ time_span = max(p.timestamp for p in cluster_prints) - min(p.timestamp for p in cluster_prints)
199
+ print(f"足印 #{cluster_id}: {size}个检测, 持续时间: {time_span:.3f}s")
200
 
201
  except Exception as e:
202
  print(f"聚类过程出错: {str(e)}")
203
+ raise
204
+
205
def _visualize_clusters(self, features, labels):
    """Save a scatter plot of the clustered footprints, coloured by label.

    The figure is written to ``<result_dir>/plots/cluster_visualization.png``.
    Plotting is best-effort: any failure is reported on stdout and swallowed
    so that a broken plot never aborts the analysis.

    Args:
        features: 2-D array indexed as ``features[:, 0]`` (x) and
            ``features[:, 1]`` (y); further columns are ignored.
        labels: one cluster label per row of ``features``, used as the
            colour value of each point.
    """
    try:
        import matplotlib.pyplot as plt

        # Make sure the target directory exists so savefig cannot fail on a
        # missing path.
        plots_dir = os.path.join(self.result_dir, 'plots')
        os.makedirs(plots_dir, exist_ok=True)

        fig = plt.figure(figsize=(12, 8))
        try:
            # Scatter plot of the x/y positions, one colour per label.
            scatter = plt.scatter(
                features[:, 0],
                features[:, 1],
                c=labels,
                cmap='rainbow',
                alpha=0.6,
            )

            plt.colorbar(scatter)
            plt.title('足印聚类结果')
            plt.xlabel('X坐标')
            plt.ylabel('Y坐标')

            # Persist the figure to disk.
            plt.savefig(os.path.join(plots_dir, 'cluster_visualization.png'))
        finally:
            # Always release the figure, even when plotting/saving raised;
            # otherwise pyplot keeps it alive for the rest of the process.
            plt.close(fig)

    except Exception as e:
        print(f"可视化过程出错: {str(e)}")
228
 
229
  def _post_process_footprints(self):
230
  """后处理足迹数据:聚类、过滤和分类"""
 
250
  # 4. 确定步态周期
251
  self._determine_gait_cycles()
252
 
253
def _classify_footprints(self, moving_right=True):
    """Assign a paw type (LF/RF/LH/RH) to every footprint from pose keypoints.

    Footprints are grouped by ``cluster_id``. For each group the mean
    position is compared against the keypoints of the pose frame whose
    timestamp is closest to the group's first detection:

    * front vs. hind: mean distance to nose/ears versus legs/tail base;
    * left vs. right: mean distance to right ear/leg versus left ear/leg,
      with the decision flipped when the animal is not moving right.

    Detections carrying the DBSCAN noise label ``-1`` do not form a real
    cluster, so each of them is classified on its own instead of being
    averaged together with unrelated noise detections.

    Groups for which no pose frame with keypoints exists are left
    unchanged and a warning is logged. ``_validate_classification`` is
    called once at the end.

    Args:
        moving_right: bool, whether the mouse is moving to the right;
            affects the left/right decision.
    """
    self.logger.info("开始基于姿态关键点对足印簇进行分类...")

    keypoint_names = (
        'nose', 'right_ear', 'left_ear', 'mid',
        'right_leg', 'left_leg', 'tail_base',
    )

    # 1. Collect the footprint groups. Real clusters keep first-seen order;
    #    every noise detection (-1) becomes its own single-element group.
    clustered = {}
    noise_groups = []
    for footprint in self.gait_prints:
        if footprint.cluster_id == -1:
            noise_groups.append((footprint.cluster_id, [footprint]))
        else:
            clustered.setdefault(footprint.cluster_id, []).append(footprint)
    groups = list(clustered.items()) + noise_groups

    # Only pose frames that actually carry keypoints can be used; filter
    # once instead of once per cluster.
    poses = [pose for pose in self.mice_positions if 'keypoints' in pose]

    # 2. Classify each group.
    for cluster_id, members in groups:
        if not poses:
            self.logger.warning(f"簇 {cluster_id} 未找到对应的姿态关键点")
            continue

        # Centre of the group.
        cluster_x = np.mean([fp.x for fp in members])
        cluster_y = np.mean([fp.y for fp in members])
        cluster_pos = np.array([cluster_x, cluster_y])

        # Pose frame closest in time to the first detection of the group
        # (min() keeps the earliest frame on ties, like a strict '<' scan).
        cluster_time = members[0].timestamp
        closest_pose = min(
            poses,
            key=lambda pose: abs(pose['timestamp'] - cluster_time),
        )

        # 3./4. Distance from the group centre to every keypoint; only the
        #       first two components (x, y) of a keypoint are used.
        kpts = closest_pose['keypoints']
        distances = {
            name: np.linalg.norm(
                cluster_pos - np.array([kpts[name][0], kpts[name][1]])
            )
            for name in keypoint_names
        }

        # 5. Front/hind scores: smaller means closer to that body end.
        front_score = (
            distances['nose'] + distances['right_ear'] + distances['left_ear']
        ) / 3
        back_score = (
            distances['right_leg'] + distances['left_leg'] + distances['tail_base']
        ) / 3

        # 6. Left/right scores. Per the original author's note (bottom-up
        #    view): right-side (upper) keypoints are the right ear and
        #    right leg, left-side (lower) ones the left ear and left leg.
        right_score = (distances['right_ear'] + distances['right_leg']) / 2
        left_score = (distances['left_ear'] + distances['left_leg']) / 2

        # 7. Decision; the left/right test is flipped when not moving right.
        is_front = front_score < back_score
        if moving_right:
            is_right = right_score < left_score
        else:
            is_right = left_score < right_score

        # 8. Compose the paw label: side letter + F(ront)/H(ind).
        paw_type = ('R' if is_right else 'L') + ('F' if is_front else 'H')

        # 9. Apply the label to every footprint of the group.
        for fp in members:
            fp.paw_type = paw_type

        self.logger.info(f"簇 {cluster_id} 被分类为 {paw_type}")

        # 10. Debug details.
        self.logger.debug(f"簇 {cluster_id} 分类详情:")
        self.logger.debug(f"位置: ({cluster_x:.2f}, {cluster_y:.2f})")
        self.logger.debug(f"前后分数: 前={front_score:.2f}, 后={back_score:.2f}")
        self.logger.debug(f"左右分数: 右={right_score:.2f}, 左={left_score:.2f}")
        self.logger.debug(f"各点距离: {distances}")

    # 11. Sanity-check the resulting label distribution.
    self._validate_classification()
361
 
362
def _validate_classification(self):
    """Check that the assigned paw types look plausible and log warnings.

    Two checks are run over ``self.gait_prints``; neither modifies any
    footprint:

    * balance: each of LF/RF/LH/RH is expected to hold a quarter of the
      labelled footprints; a deviation of more than 50% of that share is
      reported;
    * timing: for every paw type with at least two footprints, a warning
      is logged when the standard deviation of the gaps between
      consecutive timestamps exceeds their mean.
    """
    paw_types = ('LF', 'RF', 'LH', 'RH')

    # Tally the labelled footprints per paw type (unlabelled ones are skipped).
    type_counts = dict.fromkeys(paw_types, 0)
    for footprint in self.gait_prints:
        label = footprint.paw_type
        if label:
            type_counts[label] += 1

    # Balance check: allow a 50% deviation from an even four-way split.
    expected = sum(type_counts.values()) / 4
    threshold = expected * 0.5
    for paw_type, count in type_counts.items():
        if abs(count - expected) > threshold:
            self.logger.warning(f"{paw_type}的数量({count})与预期({expected:.1f})相差较大")

    # Timing check: irregular gaps between footprints of the same paw.
    for paw_type in paw_types:
        timestamps = sorted(
            fp.timestamp for fp in self.gait_prints if fp.paw_type == paw_type
        )
        if len(timestamps) < 2:
            continue
        intervals = np.diff(timestamps)
        if np.std(intervals) > np.mean(intervals):
            self.logger.warning(f"{paw_type}的时间间隔变异性较大")
388
 
389
  def _smooth_classifications(self):
390
  """使用时序信息平滑分类结果"""
 
412
 
413
  # 如果当前足迹位置偏离太远,考虑重新分类
414
  dist = np.sqrt((print.x - avg_x)**2 + (print.y - avg_y)**2)
415
+ if dist > 10: # 像素距离阈值
416
  # 尝试重新分类
417
  self._reclassify_print(print, sorted_prints, i)
418
 
 
818
  print("\n[6/6] 生成轨迹视频...")
819
  self.generate_trajectory_video(video_path)
820
  print("视频生成完成!")
821
+ # print("api版本取消生成视频!")
822
 
823
  def _get_paw_color(self, paw_type: str) -> Tuple[int, int, int]:
824
  """获取不同爪子类型的颜色"""
 
1345
 
1346
  # 转换爪子类型
1347
  type_map = {
1348
+ 'LF': 'LF',
1349
+ 'RF': 'RF',
1350
+ 'LH': 'LH',
1351
+ 'RH': 'RH'
1352
  }
1353
 
1354
  # 生成frames数据
 
1467
 
1468
  def main():
1469
  analyzer = GaitAnalyzer()
1470
+ video_path = "/Users/hakureirm/codespace/Work/Algorithm/gait/exp_videos/Exp7.mp4"
1471
 
1472
  # 自动检测时间范围
1473
  start_time, end_time = analyzer._detect_mouse_time_range(video_path)
 
1477
  video_path,
1478
  start_time=start_time,
1479
  end_time=end_time,
1480
+ conf_thres=0.8,
1481
  iou_thres=0.5
1482
  )
1483