Hakureirm committed on
Commit
448a1a6
·
0 Parent(s):

Add initial code

Browse files
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ .env
5
+ .venv
6
+ env/
7
+ venv/
8
+ ENV/
9
+ results/
10
+ *.mp4
11
+ *.avi
12
+ *.mov
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ultralytics
2
+ opencv-python
3
+ numpy
4
+ pandas
5
+ matplotlib
6
+ seaborn
7
+ scikit-learn
8
+ fastapi
9
+ python-multipart
10
+ uvicorn
src/api_server.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, Form, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import json
4
+ import tempfile
5
+ import os
6
+ from pathlib import Path
7
+ import uvicorn
8
+ from typing import Optional
9
+ from pydantic import BaseModel
10
+ import shutil
11
+
12
+ from gait_analyze import GaitAnalyzer
13
+
14
app = FastAPI()

# Configure CORS so browser clients on other origins can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): restrict to specific domains in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
24
+
25
class AnalysisResponse(BaseModel):
    """Response envelope returned by the analysis endpoint."""
    status: str  # "success" or "error"
    message: str  # human-readable description of the outcome
    data: Optional[dict] = None  # footprint analysis payload, present on success
30
+
31
+ @app.post("/api/v1/analysisFootVideo")
32
+ async def analysis_foot_video(
33
+ video: UploadFile = File(...),
34
+ params: str = Form(...)
35
+ ):
36
+ """处理足印视频分析请求"""
37
+ try:
38
+ # 解析参数
39
+ analysis_params = json.loads(params)
40
+
41
+ # 创建临时目录保存视频
42
+ with tempfile.TemporaryDirectory() as temp_dir:
43
+ # 保存上传的视频
44
+ video_path = os.path.join(temp_dir, "input_video.mp4")
45
+ with open(video_path, "wb") as buffer:
46
+ shutil.copyfileobj(video.file, buffer)
47
+
48
+ # 创建分析器实例
49
+ analyzer = GaitAnalyzer()
50
+
51
+ # 自动检测时间范围
52
+ start_time, end_time = analyzer._detect_mouse_time_range(video_path)
53
+
54
+ # 处理视频
55
+ analyzer.process_video(
56
+ video_path,
57
+ start_time=start_time,
58
+ end_time=end_time,
59
+ conf_thres=0.7,
60
+ iou_thres=0.5
61
+ )
62
+
63
+ # 获取结果目录中的数据
64
+ result_dir = Path(analyzer.result_dir)
65
+
66
+ # 读取JSON结果
67
+ json_path = result_dir / "data" / "footprint_data.json"
68
+ with open(json_path, 'r') as f:
69
+ footprint_data = json.load(f)
70
+
71
+ # 添加视频参数到结果中
72
+ footprint_data.update({
73
+ "video_info": {
74
+ "fps": analysis_params.get("video_fps"),
75
+ "width": analysis_params.get("video_width"),
76
+ "height": analysis_params.get("video_height"),
77
+ "scale_length": analysis_params.get("scale_length"),
78
+ "actual_length": analysis_params.get("actual_length"),
79
+ }
80
+ })
81
+
82
+ return AnalysisResponse(
83
+ status="success",
84
+ message="分析完成",
85
+ data=footprint_data
86
+ )
87
+
88
+ except Exception as e:
89
+ return AnalysisResponse(
90
+ status="error",
91
+ message=f"分析失败: {str(e)}"
92
+ )
93
+
94
def main():
    """Start the API server (dev mode: auto-reload enabled)."""
    uvicorn.run(
        "api_server:app",
        host="0.0.0.0",
        port=12345,
        reload=True,
    )


if __name__ == "__main__":
    main()
src/gait_analyze.py ADDED
@@ -0,0 +1,1385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Rectangle
from ultralytics import YOLO
12
+
13
@dataclass
class GaitPrint:
    """A single footprint detection from one video frame."""
    frame_id: int                                # frame index in the source video
    x: float                                     # bounding-box center x (pixels)
    y: float                                     # bounding-box center y (pixels)
    w: float                                     # bounding-box width (pixels)
    h: float                                     # bounding-box height (pixels)
    conf: float                                  # detector confidence
    paw_type: Optional[str] = None               # paw label: 'LF', 'RF', 'LH' or 'RH'
    timestamp: Optional[float] = None            # frame time in seconds
    stance_phase: bool = True                    # stance-phase flag
    gait_cycle: int = 0                          # gait-cycle number this print belongs to
    image_patch: Optional[np.ndarray] = None     # cropped footprint image (green mask)
    image_features: Optional[np.ndarray] = None  # feature vector of the patch
    cluster_id: int = -1                         # spatio-temporal cluster id (-1 = noise)
29
+
30
+ class GaitAnalyzer:
31
    def __init__(self):
        """Initialize the gait analyzer: load both models and prepare state."""
        # Footprint detector and bottom-view pose (keypoint) model.
        self.gait_model = YOLO('models/mice-gait-v1.0.mlpackage', task='detect')
        self.pose_model = YOLO('models/mice-pose-bottomview-v1.0.mlpackage', task='pose')
        self.gait_prints: List[GaitPrint] = []       # all footprint detections
        self.mice_positions: List[Dict] = []         # per-frame pose keypoint records
        self.params = {}                             # computed gait parameters
        self.time_window = 0.2                       # pairing window, seconds
        self.distance_threshold = 30                 # spatial threshold, pixels
        self.gait_pattern = None                     # classified pattern name (walk/trot/...)
        self.result_dir = self._create_result_dir()  # timestamped output directory
42
+
43
+ def _detect_mouse_time_range(self, video_path: str, margin_ratio: float = 0.05) -> Tuple[float, float]:
44
+ """使用pose模型的鼻子和尾巴点来检测老鼠"""
45
+ print("\n[0/6] 预处理视频以确定分析时间范围...")
46
+
47
+ cap = cv2.VideoCapture(video_path)
48
+ if not cap.isOpened():
49
+ raise ValueError("无法打开视频文件")
50
+
51
+ # 获取视频基本信息
52
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
53
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
54
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
55
+
56
+ margin = int(width * margin_ratio)
57
+ left_boundary = margin
58
+ right_boundary = width - margin
59
+
60
+ start_frame = None
61
+ end_frame = None
62
+ frame_id = 0
63
+
64
+ while cap.isOpened():
65
+ ret, frame = cap.read()
66
+ if not ret:
67
+ break
68
+
69
+ if frame_id % 30 == 0: # 每30帧打印一次进度
70
+ print(f"预处理进度: {frame_id}/{total_frames} ({frame_id/total_frames*100:.1f}%)")
71
+
72
+ # 使用pose模型检测
73
+ results = self.pose_model(frame, conf=0.5, verbose=False)
74
+
75
+ for r in results:
76
+ keypoints = r.keypoints.data[0] # 获取关键点
77
+ if len(keypoints) >= 7: # 确保检测到所有关键点
78
+ nose_x = keypoints[0][0].item() # 鼻子x坐标
79
+ tail_x = keypoints[6][0].item() # 尾巴x坐标
80
+
81
+ # 使用鼻子位置判断开始
82
+ if start_frame is None and nose_x > left_boundary + margin:
83
+ start_frame = frame_id
84
+
85
+ # 使用尾巴位置判断结束
86
+ if start_frame is not None and tail_x > right_boundary - margin:
87
+ end_frame = frame_id
88
+ break
89
+
90
+ if end_frame is not None:
91
+ break
92
+
93
+ frame_id += 1
94
+
95
+ cap.release()
96
+
97
+ # 转换为时间
98
+ start_time = start_frame / fps if start_frame is not None else 0
99
+ end_time = end_frame / fps if end_frame is not None else total_frames / fps
100
+
101
+ print(f"检测到有效时间范围: {start_time:.2f}s - {end_time:.2f}s")
102
+ return start_time, end_time
103
+
104
+ def _create_result_dir(self) -> str:
105
+ """创建结果目录结构"""
106
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
107
+ result_dir = f"results/{timestamp}"
108
+
109
+ # 创建目录结构
110
+ for subdir in ['data', 'plots', 'videos']:
111
+ os.makedirs(f"{result_dir}/{subdir}", exist_ok=True)
112
+
113
+ return result_dir
114
+
115
+ def _filter_footprints(self):
116
+ """对足印进行时空聚类,将同一个足印的多次检测归为一组"""
117
+ if not self.gait_prints:
118
+ print("警告:没有足印数据可供聚类")
119
+ return
120
+
121
+ try:
122
+ from sklearn.preprocessing import StandardScaler
123
+ from sklearn.cluster import DBSCAN
124
+
125
+ # 1. 准备数据
126
+ features = np.array([[p.x, p.y, p.timestamp * 30] for p in self.gait_prints])
127
+ print(f"开始聚类,原始足印数量: {len(features)}")
128
+
129
+ # 2. 标准化特征
130
+ scaler = StandardScaler()
131
+ features_scaled = scaler.fit_transform(features)
132
+
133
+ # 3. DBSCAN聚类
134
+ eps = 0.3 # 可以根据实际情况调整
135
+ min_samples = 1 # 设为1以保留所有检测
136
+ dbscan = DBSCAN(eps=eps, min_samples=min_samples)
137
+ cluster_labels = dbscan.fit_predict(features_scaled)
138
+
139
+ # 4. 将聚类结果添加到足印对象中
140
+ for print_obj, label in zip(self.gait_prints, cluster_labels):
141
+ print_obj.cluster_id = label
142
+
143
+ # 打印聚类统计信息
144
+ n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
145
+ print(f"聚类完成! 共识别出 {n_clusters} 个独立足印")
146
+
147
+ # 按cluster_id分组并打印每组的大小
148
+ from collections import Counter
149
+ cluster_sizes = Counter(cluster_labels)
150
+ print("\n各组足印检测数量:")
151
+ for cluster_id, size in sorted(cluster_sizes.items()):
152
+ if cluster_id != -1:
153
+ print(f"足印 #{cluster_id}: {size}个检测")
154
+
155
+ except Exception as e:
156
+ print(f"聚类过程出错: {str(e)}")
157
+
158
+ def _post_process_footprints(self):
159
+ """后处理足迹数据:聚类、过滤和分类"""
160
+ if not self.gait_prints:
161
+ print("警告:没有足印数据可供处理")
162
+ return
163
+
164
+ # 1. 时空聚类
165
+ self._filter_footprints()
166
+
167
+ if not self.gait_prints:
168
+ print("错误:后处理后没有足印数据")
169
+ return
170
+
171
+ # 2. 全局分析
172
+ mouse_path = pd.DataFrame(self.mice_positions)
173
+ total_dx = mouse_path['x'].iloc[-1] - mouse_path['x'].iloc[0]
174
+ moving_right = total_dx > 0
175
+
176
+ # 3. 使用机器学习进行足迹分类
177
+ self._classify_footprints(moving_right)
178
+
179
+ # 4. 确定步态周期
180
+ self._determine_gait_cycles()
181
+
182
    def _classify_footprints(self, moving_right: bool):
        """Classify each footprint cluster as LF/RF/LH/RH using pose keypoints.

        For every cluster, the pose sample nearest in time supplies the body
        keypoints; distance to head vs. hindquarter keypoints decides
        front/hind, and y position relative to the ear or hind-leg midpoint
        decides left/right.

        Note: ``moving_right`` is currently unused by the classification —
        presumably reserved for direction-aware logic; confirm with callers.
        """
        if len(self.gait_prints) < 4 or not self.mice_positions:
            print("警告:足迹或姿态数据不足")
            return

        # Sort prints chronologically.
        sorted_prints = sorted(self.gait_prints, key=lambda p: p.timestamp)

        # Group prints by cluster id.
        cluster_groups = {}
        for p in sorted_prints:
            if p.cluster_id not in cluster_groups:
                cluster_groups[p.cluster_id] = []
            cluster_groups[p.cluster_id].append(p)

        # For each cluster, use the pose sample closest to its mean time.
        for cluster_id, prints in cluster_groups.items():
            mid_time = np.mean([p.timestamp for p in prints])
            closest_pose = min(self.mice_positions,
                               key=lambda m: abs(m['timestamp'] - mid_time))

            if 'keypoints' not in closest_pose:
                print(f"警告:时间戳 {mid_time:.2f}s 处的姿态数据缺少关键点信息")
                continue

            # Cluster center position.
            center_x = np.mean([p.x for p in prints])
            center_y = np.mean([p.y for p in prints])

            # Body keypoint positions.
            nose = closest_pose['keypoints']['nose']
            re = closest_pose['keypoints']['right_ear']
            le = closest_pose['keypoints']['left_ear']
            mid = closest_pose['keypoints']['mid']
            rl = closest_pose['keypoints']['right_leg']
            ll = closest_pose['keypoints']['left_leg']
            tail = closest_pose['keypoints']['tail_base']

            # Distance to the nearest head keypoint (nose / ears).
            dist_to_front = min(
                np.sqrt((center_x - nose[0])**2 + (center_y - nose[1])**2),
                np.sqrt((center_x - re[0])**2 + (center_y - re[1])**2),
                np.sqrt((center_x - le[0])**2 + (center_y - le[1])**2)
            )

            # Distance to the nearest hindquarter keypoint (legs / tail base).
            dist_to_back = min(
                np.sqrt((center_x - rl[0])**2 + (center_y - rl[1])**2),
                np.sqrt((center_x - ll[0])**2 + (center_y - ll[1])**2),
                np.sqrt((center_x - tail[0])**2 + (center_y - tail[1])**2)
            )

            # Front/hind: whichever keypoint group is closer.
            is_front = dist_to_front < dist_to_back

            # Left/right: y position relative to the relevant body midpoint.
            if is_front:
                is_left = center_y > (re[1] + le[1])/2  # vs. midpoint of the ears
            else:
                is_left = center_y > (rl[1] + ll[1])/2  # vs. midpoint of the hind legs

            # Final paw label.
            paw_type = None
            if is_front:
                paw_type = 'LF' if is_left else 'RF'
            else:
                paw_type = 'LH' if is_left else 'RH'

            # Apply the label to every print in the cluster.
            for p in prints:
                p.paw_type = paw_type

        # Organize prints by frame for the temporal-consistency pass.
        frame_prints = {}
        for p in sorted_prints:
            if p.frame_id not in frame_prints:
                frame_prints[p.frame_id] = []
            frame_prints[p.frame_id].append(p)

        # Enforce that one frame holds at most one print per paw type.
        self._enforce_temporal_consistency(frame_prints)
263
+
264
+ def _enforce_temporal_consistency(self, frame_prints):
265
+ """确保时序一致性:同一时间同一类型的足印只能有一个"""
266
+ for frame_id, prints in frame_prints.items():
267
+ # 按类型分组
268
+ type_groups = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
269
+ for p in prints:
270
+ if p.paw_type:
271
+ type_groups[p.paw_type].append(p)
272
+
273
+ # 处理每个有多个足印的类型
274
+ for paw_type, group in type_groups.items():
275
+ if len(group) > 1:
276
+ # 保留cluster_id较大的足印(通常是较新的足印)
277
+ newest_print = max(group, key=lambda p: p.cluster_id)
278
+ for p in group:
279
+ if p != newest_print:
280
+ # 将重复的足印重新分类为对角的另一只脚
281
+ if paw_type == 'LF':
282
+ p.paw_type = 'RH'
283
+ elif paw_type == 'RF':
284
+ p.paw_type = 'LH'
285
+ elif paw_type == 'LH':
286
+ p.paw_type = 'RF'
287
+ else: # RH
288
+ p.paw_type = 'LF'
289
+
290
+ def _smooth_classifications(self):
291
+ """使用时序信息平滑分类结果"""
292
+ # 1. 按时间排序足迹
293
+ sorted_prints = sorted(self.gait_prints, key=lambda x: x.timestamp)
294
+
295
+ # 2. 为每种爪子类型建立时间序列
296
+ paw_sequences = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
297
+
298
+ # 3. 检测和修正异常分类
299
+ window_size = 0.2 # 200ms时间窗口
300
+ for i, print in enumerate(sorted_prints):
301
+ # 获取时间窗口内的同类足迹
302
+ nearby_prints = [
303
+ p for p in sorted_prints
304
+ if abs(p.timestamp - print.timestamp) < window_size
305
+ and p.paw_type == print.paw_type
306
+ and p != print
307
+ ]
308
+
309
+ if nearby_prints:
310
+ # 检查空间一致性
311
+ avg_x = np.mean([p.x for p in nearby_prints])
312
+ avg_y = np.mean([p.y for p in nearby_prints])
313
+
314
+ # 如果当前足迹位置偏离太远,考虑重新分类
315
+ dist = np.sqrt((print.x - avg_x)**2 + (print.y - avg_y)**2)
316
+ if dist > 50: # 像素距离阈值
317
+ # 尝试重新分类
318
+ self._reclassify_print(print, sorted_prints, i)
319
+
320
+ # 更新序列
321
+ paw_sequences[print.paw_type].append(print)
322
+
323
+ # 4. 验证步态模式的合理性
324
+ self._validate_gait_pattern(paw_sequences)
325
+
326
+ def _reclassify_print(self, print, all_prints, current_idx):
327
+ """重新分类可能错误的足迹"""
328
+ # 获取临近的其他足迹
329
+ window_size = 0.2
330
+ nearby_prints = [
331
+ p for p in all_prints
332
+ if abs(p.timestamp - print.timestamp) < window_size
333
+ and p != print
334
+ ]
335
+
336
+ if not nearby_prints:
337
+ return
338
+
339
+ # 计算到每种类型足迹的平均距离
340
+ type_distances = {}
341
+ for paw_type in ['LF', 'RF', 'LH', 'RH']:
342
+ same_type = [p for p in nearby_prints if p.paw_type == paw_type]
343
+ if same_type:
344
+ avg_dist = np.mean([
345
+ np.sqrt((print.x - p.x)**2 + (print.y - p.y)**2)
346
+ for p in same_type
347
+ ])
348
+ type_distances[paw_type] = avg_dist
349
+
350
+ # 选择距离最小的类型
351
+ if type_distances:
352
+ new_type = min(type_distances.items(), key=lambda x: x[1])[0]
353
+ print.paw_type = new_type
354
+
355
+ def _validate_gait_pattern(self, paw_sequences):
356
+ """验证步态模式的合理性"""
357
+ # 1. 检查每个爪子的步频是否合理
358
+ for paw_type, sequence in paw_sequences.items():
359
+ if len(sequence) >= 2:
360
+ time_diffs = [
361
+ sequence[i+1].timestamp - sequence[i].timestamp
362
+ for i in range(len(sequence)-1)
363
+ ]
364
+ avg_freq = 1 / np.mean(time_diffs) if time_diffs else 0
365
+
366
+ # 步频应该在合理范围内 (通常2-5Hz)
367
+ if not (2 <= avg_freq <= 5):
368
+ print(f"警告:{paw_type}的步频 ({avg_freq:.2f}Hz) 可能不合理")
369
+
370
+ # 2. ���查对角步态模式
371
+ diagonal_pairs = [('LF', 'RH'), ('RF', 'LH')]
372
+ for pair in diagonal_pairs:
373
+ seq1 = paw_sequences[pair[0]]
374
+ seq2 = paw_sequences[pair[1]]
375
+ if seq1 and seq2:
376
+ # 检查对角步态的同步性
377
+ for p1 in seq1:
378
+ nearest = min(seq2, key=lambda p: abs(p.timestamp - p1.timestamp))
379
+ if abs(nearest.timestamp - p1.timestamp) > 0.1: # 100ms阈值
380
+ print(f"警告:对角步态 {pair} 可能不同步")
381
+
382
+ def _is_position_reasonable(self, print_obj, new_type):
383
+ """检查新的分类是否在合理的空间位置范围内"""
384
+ # 获取同类型足迹的平均位置
385
+ same_type_prints = [p for p in self.gait_prints if p.paw_type == new_type]
386
+ if not same_type_prints:
387
+ return True
388
+
389
+ # 计算平均位置
390
+ avg_x = np.mean([p.x for p in same_type_prints])
391
+ avg_y = np.mean([p.y for p in same_type_prints])
392
+
393
+ # 计算标准差
394
+ std_x = np.std([p.x for p in same_type_prints])
395
+ std_y = np.std([p.y for p in same_type_prints])
396
+
397
+ # 检查是否在3个标准差范围内
398
+ x_reasonable = abs(print_obj.x - avg_x) < 3 * std_x
399
+ y_reasonable = abs(print_obj.y - avg_y) < 3 * std_y
400
+
401
+ return x_reasonable and y_reasonable
402
+
403
+ def _determine_gait_cycles(self):
404
+ """确定每个足印所属的步态周期"""
405
+ # 按爪子类型分组
406
+ paw_groups = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
407
+ for print in self.gait_prints:
408
+ if print.paw_type:
409
+ paw_groups[print.paw_type].append(print)
410
+
411
+ # 为每种爪子类型确定步态周期
412
+ for paw_type, prints in paw_groups.items():
413
+ if len(prints) < 2:
414
+ continue
415
+
416
+ # 按时间排序
417
+ prints.sort(key=lambda x: x.timestamp)
418
+
419
+ # 初始化第一个周期
420
+ cycle_id = 1
421
+ prints[0].gait_cycle = cycle_id
422
+
423
+ # 基于时空距离确定周期
424
+ for i in range(1, len(prints)):
425
+ prev_print = prints[i-1]
426
+ curr_print = prints[i]
427
+
428
+ # 计算与前一个足印的时空距离
429
+ time_diff = curr_print.timestamp - prev_print.timestamp
430
+ space_diff = np.sqrt((curr_print.x - prev_print.x)**2 +
431
+ (curr_print.y - prev_print.y)**2)
432
+
433
+ # 如果时空距离超过阈值,开始新的周期
434
+ if time_diff > self.time_window * 2 or space_diff > self.distance_threshold * 3:
435
+ cycle_id += 1
436
+
437
+ curr_print.gait_cycle = cycle_id
438
+
439
+ def _calculate_gait_parameters(self):
440
+ """计算步态参数"""
441
+ # 按类型分组足印
442
+ paw_groups = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
443
+ for print in self.gait_prints:
444
+ if print.paw_type:
445
+ paw_groups[print.paw_type].append(print)
446
+
447
+ # 初始化参数字典
448
+ self.params = {
449
+ 'stride_length': {}, # 步幅
450
+ 'step_width': {}, # 步宽
451
+ 'step_frequency': {}, # 步频
452
+ 'stance_time': {}, # 支撑时间
453
+ 'swing_time': {}, # 摆动时间
454
+ 'duty_factor': {}, # 支撑占空比
455
+ 'symmetry_index': {}, # 对称性指数
456
+ 'base_of_support': {}, # 支撑基底
457
+ 'stride_time': {} # 步态周期时间
458
+ }
459
+
460
+ # 计算每种爪子的参数
461
+ for paw_type, prints in paw_groups.items():
462
+ if len(prints) < 2:
463
+ continue
464
+
465
+ # 按时间排序
466
+ prints.sort(key=lambda x: x.timestamp)
467
+
468
+ # 1. 计算步幅 (相邻足印间距)
469
+ stride_lengths = []
470
+ for i in range(1, len(prints)):
471
+ if prints[i].gait_cycle == prints[i-1].gait_cycle:
472
+ dist = np.sqrt((prints[i].x - prints[i-1].x)**2 +
473
+ (prints[i].y - prints[i-1].y)**2)
474
+ stride_lengths.append(dist)
475
+ self.params['stride_length'][paw_type] = np.mean(stride_lengths) if stride_lengths else 0
476
+
477
+ # 2. 计算步频
478
+ if len(prints) > 1:
479
+ time_diff = prints[-1].timestamp - prints[0].timestamp
480
+ self.params['step_frequency'][paw_type] = (len(prints) - 1) / time_diff if time_diff > 0 else 0
481
+
482
+ # 3. 计算步态周期时间
483
+ cycle_times = []
484
+ for i in range(1, len(prints)):
485
+ if prints[i].gait_cycle == prints[i-1].gait_cycle:
486
+ cycle_times.append(prints[i].timestamp - prints[i-1].timestamp)
487
+ self.params['stride_time'][paw_type] = np.mean(cycle_times) if cycle_times else 0
488
+
489
+ # 4. 计算支撑时间和摆动时间
490
+ stance_times = []
491
+ swing_times = []
492
+ for i in range(len(prints)-1):
493
+ if prints[i].gait_cycle == prints[i+1].gait_cycle:
494
+ stance_time = 0.1 # 假设支撑相持续100ms
495
+ swing_time = prints[i+1].timestamp - prints[i].timestamp - stance_time
496
+ stance_times.append(stance_time)
497
+ swing_times.append(swing_time)
498
+
499
+ self.params['stance_time'][paw_type] = np.mean(stance_times) if stance_times else 0
500
+ self.params['swing_time'][paw_type] = np.mean(swing_times) if swing_times else 0
501
+
502
+ # 5. 计算支撑占空比
503
+ total_time = self.params['stance_time'][paw_type] + self.params['swing_time'][paw_type]
504
+ self.params['duty_factor'][paw_type] = (self.params['stance_time'][paw_type] / total_time
505
+ if total_time > 0 else 0)
506
+
507
+ # 6. 计算左右对称性指数
508
+ for side in ['F', 'H']: # 前爪和后爪
509
+ left_paw = f'L{side}'
510
+ right_paw = f'R{side}'
511
+ if (left_paw in self.params['stride_length'] and
512
+ right_paw in self.params['stride_length']):
513
+ left_stride = self.params['stride_length'][left_paw]
514
+ right_stride = self.params['stride_length'][right_paw]
515
+ symmetry = abs(left_stride - right_stride) / ((left_stride + right_stride) / 2)
516
+ self.params['symmetry_index'][side] = symmetry
517
+
518
+ # 7. 计算支撑基底
519
+ for side in ['F', 'H']: # 前爪和后爪
520
+ left_prints = paw_groups[f'L{side}']
521
+ right_prints = paw_groups[f'R{side}']
522
+ if left_prints and right_prints:
523
+ # 计算同一时间窗口内左右爪的横向距离
524
+ base_distances = []
525
+ for left_print in left_prints:
526
+ # 找到最近时间的右爪足印
527
+ nearest_right = min(right_prints,
528
+ key=lambda p: abs(p.timestamp - left_print.timestamp))
529
+ if abs(nearest_right.timestamp - left_print.timestamp) < self.time_window:
530
+ dist = abs(left_print.y - nearest_right.y)
531
+ base_distances.append(dist)
532
+
533
+ self.params['base_of_support'][side] = np.mean(base_distances) if base_distances else 0
534
+
535
+ # 8. 确定步态模式
536
+ self._determine_gait_pattern()
537
+
538
+ def _determine_gait_pattern(self):
539
+ """确定步态模式(walk, trot, gallop, bound)"""
540
+ if len(self.gait_prints) < 4:
541
+ return
542
+
543
+ # 分析相邻足印的时间关系
544
+ prints_by_time = sorted(self.gait_prints, key=lambda p: p.timestamp)
545
+ time_diffs = []
546
+ for i in range(1, len(prints_by_time)):
547
+ if prints_by_time[i].paw_type and prints_by_time[i-1].paw_type:
548
+ time_diff = prints_by_time[i].timestamp - prints_by_time[i-1].timestamp
549
+ time_diffs.append(time_diff)
550
+
551
+ if not time_diffs:
552
+ return
553
+
554
+ # 计算时间差的变异系数
555
+ mean_diff = np.mean(time_diffs)
556
+ std_diff = np.std(time_diffs)
557
+ cv = std_diff / mean_diff if mean_diff > 0 else 0
558
+
559
+ # 分析对角步态模式
560
+ diagonal_pairs = [('LF', 'RH'), ('RF', 'LH')]
561
+ diagonal_sync = []
562
+ for pair in diagonal_pairs:
563
+ prints1 = [p for p in self.gait_prints if p.paw_type == pair[0]]
564
+ prints2 = [p for p in self.gait_prints if p.paw_type == pair[1]]
565
+
566
+ if prints1 and prints2:
567
+ # 计算对角足印的同步性
568
+ min_time_diff = float('inf')
569
+ for p1 in prints1:
570
+ for p2 in prints2:
571
+ time_diff = abs(p1.timestamp - p2.timestamp)
572
+ min_time_diff = min(min_time_diff, time_diff)
573
+ diagonal_sync.append(min_time_diff)
574
+
575
+ # 基于时间特征确定步态模式
576
+ if cv < 0.2 and all(sync < 0.1 for sync in diagonal_sync):
577
+ self.gait_pattern = 'trot' # 对角步态同步性高
578
+ elif cv > 0.4:
579
+ self.gait_pattern = 'gallop' # 时间间隔变异大
580
+ elif len(set(p.gait_cycle for p in self.gait_prints)) > len(self.gait_prints) / 2:
581
+ self.gait_pattern = 'bound' # 步态周期数量多
582
+ else:
583
+ self.gait_pattern = 'walk' # 默认为行走
584
+
585
    def process_video(self, video_path: str, start_time=0.0, end_time=None, conf_thres=0.7, iou_thres=0.5):
        """Run the full gait-analysis pipeline over a time window of a video.

        Detects pose keypoints and footprints frame by frame, then
        post-processes, computes parameters, saves results/plots, and renders
        the trajectory video.

        Args:
            video_path (str): path to the video file.
            start_time (float): start time in seconds.
            end_time (float): end time in seconds; None means process to the
                end of the video.
            conf_thres (float): detection confidence threshold.
            iou_thres (float): IOU threshold for footprint detection NMS.

        Raises:
            ValueError: if the video cannot be opened or the requested time
                range is outside the video duration.
        """
        print(f"\n[1/6] 开始处理视频: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"无法打开视频文件: {video_path}")

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_duration = total_frames / fps

        # Default the end of the window to the end of the video.
        if end_time is None:
            end_time = video_duration

        # Validate the requested time range.
        if start_time < 0 or start_time >= video_duration:
            raise ValueError(f"起始时间 {start_time:.2f}s 超出视频范围 [0, {video_duration:.2f}s]")
        if end_time <= start_time or end_time > video_duration:
            raise ValueError(f"结束时间 {end_time:.2f}s 无效,应在区间 ({start_time:.2f}, {video_duration:.2f}s]")

        # Convert the time range to frame indices.
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)
        process_frames = end_frame - start_frame
        print(f"视频信息: 总时长 {video_duration:.2f}s ({total_frames}帧, {fps}FPS)")
        print(f"处理时间段: {start_time:.2f}s - {end_time:.2f}s ({process_frames}帧)")

        # Seek to the first frame of the window.
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frame_id = start_frame
        print("\n[2/6] 开始检测足迹...")
        while cap.isOpened() and frame_id < end_frame:
            success, frame = cap.read()
            if not success:
                break

            if (frame_id - start_frame) % 30 == 0:  # progress report every 30 frames
                progress = (frame_id - start_frame) / process_frames
                print(f"处理进度: {frame_id-start_frame}/{process_frames} ({progress*100:.1f}%)")

            timestamp = frame_id / fps

            # Pose detection: record the 7 body keypoints for this frame.
            pose_results = self.pose_model(frame, conf=conf_thres, verbose=False)
            for r in pose_results:
                if r.keypoints is not None and len(r.keypoints.data) > 0:
                    kpts = r.keypoints.data[0]  # keypoints of the first detected mouse
                    if len(kpts) >= 7:  # all body keypoints present
                        self.mice_positions.append({
                            'frame_id': frame_id,
                            'timestamp': timestamp,
                            'keypoints': {
                                'nose': (kpts[0][0].item(), kpts[0][1].item()),
                                'right_ear': (kpts[1][0].item(), kpts[1][1].item()),
                                'left_ear': (kpts[2][0].item(), kpts[2][1].item()),
                                'mid': (kpts[3][0].item(), kpts[3][1].item()),
                                'right_leg': (kpts[4][0].item(), kpts[4][1].item()),
                                'left_leg': (kpts[5][0].item(), kpts[5][1].item()),
                                'tail_base': (kpts[6][0].item(), kpts[6][1].item())
                            },
                            'x': kpts[3][0].item(),  # body midpoint used as the mouse position
                            'y': kpts[3][1].item()
                        })

            # Footprint detection.
            gait_results = self.gait_model(frame, conf=conf_thres, iou=iou_thres, verbose=False)

            # Convert detections into GaitPrint / position records.
            for r in gait_results:
                boxes = r.boxes
                for box in boxes:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                    conf = float(box.conf[0])
                    cls = int(box.cls[0])

                    if cls == 0:  # footprint ("gait") class
                        # Crop the footprint patch from the frame.
                        patch = frame[y1:y2, x1:x2].copy()
                        # Keep only the green-channel content.
                        green_mask = self._extract_green_channel(patch)
                        # Feature vector of the masked patch.
                        image_features = self._extract_image_features(green_mask)

                        self.gait_prints.append(GaitPrint(
                            frame_id=frame_id,
                            x=(x1 + x2) / 2,
                            y=(y1 + y2) / 2,
                            w=x2 - x1,
                            h=y2 - y1,
                            conf=conf,
                            timestamp=timestamp,
                            image_patch=green_mask,
                            image_features=image_features
                        ))
                    else:  # mouse class
                        self.mice_positions.append({
                            'frame_id': frame_id,
                            'x': (x1 + x2) / 2,
                            'y': (y1 + y2) / 2,
                            'w': x2 - x1,
                            'h': y2 - y1,
                            'conf': conf,
                            'timestamp': timestamp
                        })

            frame_id += 1

        cap.release()
        print(f"检测完成! 共检测到 {len(self.gait_prints)} 个足迹")
        print("\n[3/6] 开始后处理...")
        self._post_process_footprints()
        print(f"后处理完成! 剩余 {len(self.gait_prints)} 个有效足迹")

        print("\n[4/6] 计算步态参数...")
        self._calculate_gait_parameters()
        print("参数计算完成!")

        print("\n[5/6] 保存分析结果...")
        self._save_results()
        self.visualize_results()
        print("分析结果已保存!")

        print("\n[6/6] 生成轨迹视频...")
        self.generate_trajectory_video(video_path)
        print("视频生成完成!")
722
+
723
+ def _get_paw_color(self, paw_type: str) -> Tuple[int, int, int]:
724
+ """获取不同爪子类型的颜色"""
725
+ colors = {
726
+ 'LF': (0, 255, 0), # 绿色
727
+ 'RF': (0, 255, 255), # 黄色
728
+ 'LH': (255, 0, 255), # 紫色
729
+ 'RH': (0, 165, 255), # 橙色
730
+ None: (0, 255, 0) # 未分类时为绿色
731
+ }
732
+ return colors.get(paw_type, (0, 255, 0))
733
+
734
+ def _save_results(self):
735
+ """保存分析结果"""
736
+ # 保存原有的CSV数据
737
+ df = pd.DataFrame([{
738
+ 'frame_id': p.frame_id,
739
+ 'x': p.x,
740
+ 'y': p.y,
741
+ 'width': p.w,
742
+ 'height': p.h,
743
+ 'confidence': p.conf,
744
+ 'paw_type': p.paw_type,
745
+ 'timestamp': p.timestamp,
746
+ 'gait_cycle': p.gait_cycle,
747
+ 'stance_phase': p.stance_phase
748
+ } for p in self.gait_prints])
749
+
750
+ df.to_csv(f'{self.result_dir}/data/gait_analysis.csv', index=False)
751
+
752
+ # 保存参数数据
753
+ params_df = pd.DataFrame(self.params)
754
+ params_df.to_csv(f'{self.result_dir}/data/gait_parameters.csv')
755
+
756
+ # 保存JSON格式的足印数据
757
+ self._save_footprint_json()
758
+
759
def visualize_results(self):
    """Render every analysis chart to PNG files under `<result_dir>/plots`.

    Produces: footprint trajectory, gait timeline, stride length, step
    frequency, stance/swing time, duty factor, symmetry index, base of
    support, and the gait-pattern heatmap. Each chart is saved at 300 dpi.
    """
    plots_dir = f'{self.result_dir}/plots'
    # Fix: matplotlib's savefig does not create missing directories, so
    # ensure the output directory exists before the first save.
    os.makedirs(plots_dir, exist_ok=True)

    # (output filename, plotting callback, figure size) for every chart —
    # same figures and sizes as before, just table-driven.
    plot_specs = [
        ('trajectory.png', self._plot_footprint_trajectory, (10, 8)),
        ('timeline.png', self._plot_gait_timeline, (12, 6)),
        ('stride_length.png', self._plot_stride_length, (8, 6)),
        ('step_frequency.png', self._plot_step_frequency, (8, 6)),
        ('stance_swing_time.png', self._plot_stance_swing_time, (10, 6)),
        ('duty_factor.png', self._plot_duty_factor, (8, 6)),
        ('symmetry_index.png', self._plot_symmetry_index, (8, 6)),
        ('base_of_support.png', self._plot_base_of_support, (8, 6)),
        ('gait_pattern.png', self._plot_gait_pattern, (12, 6)),
    ]
    for filename, plot_fn, figsize in plot_specs:
        fig, ax = plt.subplots(figsize=figsize)
        plot_fn(ax)
        plt.savefig(f'{plots_dir}/{filename}', dpi=300, bbox_inches='tight')
        plt.close()
def generate_trajectory_video(self, video_path: str):
    """Render three result videos into `<result_dir>/videos/`:

    - `<name>_yolo.mp4`: original frames annotated with detection boxes,
      paw labels and the mouse heading arrow
    - `<name>_trajectory.mp4`: footprints and heading drawn on a white canvas
    - `<name>_combined.mp4`: the two views stacked vertically

    Only the time span covered by `self.gait_prints` is rendered.

    NOTE(review): the capture and writers are not released if an exception
    is raised mid-loop (the except block re-raises before release), and
    `frame_duration` divides by fps, which some containers report as 0 —
    consider a try/finally and an fps guard. Confirm before relying on this
    in long-running services.
    """
    try:
        # Output paths derived from the input file name.
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        yolo_path = f"{self.result_dir}/videos/{video_name}_yolo.mp4"
        trajectory_path = f"{self.result_dir}/videos/{video_name}_trajectory.mp4"
        combined_path = f"{self.result_dir}/videos/{video_name}_combined.mp4"

        # Open the source video.
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("无法打开视频文件")

        # Source geometry and timing.
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        frame_duration = 1.0 / fps

        # Render only the window that actually contains footprints.
        valid_times = [p.timestamp for p in self.gait_prints]
        start_time = min(valid_times) if valid_times else 0
        end_time = max(valid_times) if valid_times else 0

        # Seek to the first frame of that window.
        start_frame = int(start_time * fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        # Writers for the three outputs (combined view is double height).
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        yolo_writer = cv2.VideoWriter(yolo_path, fourcc, fps, (width, height))
        traj_writer = cv2.VideoWriter(trajectory_path, fourcc, fps, (width, height))
        combined_writer = cv2.VideoWriter(combined_path, fourcc, fps, (width, height*2))

        # Persistent white canvas the trajectory view is drawn on.
        trajectory_img = np.ones((height, width, 3), dtype=np.uint8) * 255

        # Footprints and mouse positions in chronological order.
        sorted_prints = sorted(self.gait_prints, key=lambda x: x.timestamp)
        sorted_mice = sorted(self.mice_positions, key=lambda x: x['timestamp'])
        current_time = start_time

        while cap.isOpened() and current_time <= end_time:
            ret, frame = cap.read()
            if not ret:
                break

            # 1. Frame carrying the detection-style annotations.
            yolo_frame = frame.copy()

            # 2. Trajectory frame (copy so per-frame marks don't accumulate).
            traj_frame = trajectory_img.copy()

            # Mouse pose closest to the current timestamp (within one frame).
            current_mouse = next((m for m in sorted_mice if abs(m['timestamp'] - current_time) < frame_duration), None)
            if current_mouse and 'keypoints' in current_mouse:  # needs keypoint data
                keypoints = current_mouse['keypoints']
                nose = keypoints['nose']
                tail = keypoints['tail_base']

                # Body centre: midpoint between nose and tail base.
                center_x = int((nose[0] + tail[0]) / 2)
                center_y = int((nose[1] + tail[1]) / 2)

                # Dashed arrow from tail base to nose showing heading.
                dx = nose[0] - tail[0]
                dy = nose[1] - tail[1]
                angle = np.arctan2(dy, dx)

                pt1 = (int(tail[0]), int(tail[1]))
                pt2 = (int(nose[0]), int(nose[1]))

                # Main body line drawn as dashes.
                dash_length = 5
                total_length = np.sqrt(dx**2 + dy**2)
                num_segments = int(total_length / (dash_length * 2))

                for i in range(num_segments):
                    start_ratio = i / num_segments
                    end_ratio = (i + 0.5) / num_segments
                    start_x = int(tail[0] + dx * start_ratio)
                    start_y = int(tail[1] + dy * start_ratio)
                    end_x = int(tail[0] + dx * end_ratio)
                    end_y = int(tail[1] + dy * end_ratio)
                    cv2.line(yolo_frame, (start_x, start_y), (end_x, end_y), (255, 0, 0), 1)
                    cv2.line(traj_frame, (start_x, start_y), (end_x, end_y), (255, 0, 0), 1)

                # Arrow head at the nose end.
                arrow_length = 15
                arrow_angle = np.pi/6  # 30 degrees

                # The two flank points of the arrow head.
                arrow1_x = int(nose[0] - arrow_length * np.cos(angle + arrow_angle))
                arrow1_y = int(nose[1] - arrow_length * np.sin(angle + arrow_angle))
                arrow2_x = int(nose[0] - arrow_length * np.cos(angle - arrow_angle))
                arrow2_y = int(nose[1] - arrow_length * np.sin(angle - arrow_angle))

                cv2.line(yolo_frame, (int(nose[0]), int(nose[1])), (arrow1_x, arrow1_y), (255, 0, 0), 1)
                cv2.line(yolo_frame, (int(nose[0]), int(nose[1])), (arrow2_x, arrow2_y), (255, 0, 0), 1)
                cv2.line(traj_frame, (int(nose[0]), int(nose[1])), (arrow1_x, arrow1_y), (255, 0, 0), 1)
                cv2.line(traj_frame, (int(nose[0]), int(nose[1])), (arrow2_x, arrow2_y), (255, 0, 0), 1)

                # Body centre dot.
                cv2.circle(yolo_frame, (center_x, center_y), 3, (255, 0, 0), -1)
                cv2.circle(traj_frame, (center_x, center_y), 3, (255, 0, 0), -1)

            # Draw every footprint observed up to the current time.
            for foot_print in sorted_prints:
                if foot_print.timestamp <= current_time:
                    color = self._get_paw_color(foot_print.paw_type)
                    x1 = int(foot_print.x - foot_print.w/2)
                    y1 = int(foot_print.y - foot_print.h/2)
                    x2 = int(foot_print.x + foot_print.w/2)
                    y2 = int(foot_print.y + foot_print.h/2)

                    # Detection box and paw label on the YOLO frame.
                    cv2.rectangle(yolo_frame, (x1, y1), (x2, y2), color, 2)
                    if foot_print.paw_type:
                        cv2.putText(yolo_frame, foot_print.paw_type, (x1, y1-5),
                                  cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

                    # Just a dot on the trajectory frame.
                    cv2.circle(traj_frame, (int(foot_print.x), int(foot_print.y)), 3, color, -1)

            # Timestamp overlay on both frames.
            cv2.putText(yolo_frame, f"Time: {current_time:.2f}s", (10, 30),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
            cv2.putText(traj_frame, f"Time: {current_time:.2f}s", (10, 30),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

            # Stack the annotated frame over the trajectory frame.
            combined_frame = np.vstack((yolo_frame, traj_frame))

            # Emit one frame to each output video.
            yolo_writer.write(yolo_frame)
            traj_writer.write(traj_frame)
            combined_writer.write(combined_frame)

            current_time += frame_duration

        # Release capture and writers (flushes the output files).
        cap.release()
        yolo_writer.release()
        traj_writer.release()
        combined_writer.release()

        print(f"\nYOLO视频生成完成: {yolo_path}")
        print(f"轨迹视频生成完成: {trajectory_path}")
        print(f"并列视频生成完成: {combined_path}")

    except Exception as e:
        print(f"生成视频时出错: {str(e)}")
        raise
def _plot_footprint_trajectory(self, ax):
    """Scatter every classified footprint and overlay the mouse body-centre path."""
    ax.set_title('Footprint Trajectory')

    # One scatter series per paw type, using the shared colour palette.
    palette = self._get_paw_color_dict()
    for paw in ('LF', 'RF', 'LH', 'RH'):
        paw_prints = [fp for fp in self.gait_prints if fp.paw_type == paw]
        if not paw_prints:
            continue
        xs = [fp.x for fp in paw_prints]
        ys = [fp.y for fp in paw_prints]
        ax.scatter(xs, ys, c=[palette[paw]], label=paw, alpha=0.6)

    # Mouse path: midpoint between nose and tail base for each tracked frame.
    midpoints = [
        ((pos['keypoints']['nose'][0] + pos['keypoints']['tail_base'][0]) / 2,
         (pos['keypoints']['nose'][1] + pos['keypoints']['tail_base'][1]) / 2)
        for pos in self.mice_positions
        if 'keypoints' in pos
    ]
    if midpoints:
        path_x, path_y = zip(*midpoints)
        ax.plot(path_x, path_y, 'b--', alpha=0.1, label='Mouse path')
        # Mark the final position with a solid dot.
        ax.scatter(path_x[-1], path_y[-1], c='blue', s=50, alpha=0.6)

    ax.legend()
    ax.set_xlabel('X position (pixels)')
    ax.set_ylabel('Y position (pixels)')
    ax.invert_yaxis()  # image coordinates: y grows downward
def _plot_gait_timeline(self, ax):
    """Plot every footprint as a square marker on a per-paw timeline."""
    ax.set_title('Gait Timeline')

    palette = self._get_paw_color_dict()
    # Row position of each paw on the y axis, top to bottom.
    rows = {'LF': 4, 'RF': 3, 'LH': 2, 'RH': 1}

    for paw, row in rows.items():
        stamps = [fp.timestamp for fp in self.gait_prints if fp.paw_type == paw]
        ax.scatter(stamps, [row] * len(stamps), c=[palette[paw]], label=paw, marker='s')

    ax.set_yticks(list(rows.values()))
    ax.set_yticklabels(list(rows.keys()))
    ax.set_xlabel('Time (seconds)')
    ax.grid(True, axis='x', alpha=0.3)
def _plot_gait_parameters(self, ax):
    """Grouped bar chart of stride length and step frequency per paw."""
    ax.set_title('Gait Parameters')

    paw_types = list(self.params['stride_length'].keys())
    strides = [self.params['stride_length'][p] for p in paw_types]
    freqs = [self.params['step_frequency'][p] for p in paw_types]

    positions = np.arange(len(paw_types))
    bar_w = 0.35
    ax.bar(positions - bar_w / 2, strides, bar_w, label='Stride Length')
    ax.bar(positions + bar_w / 2, freqs, bar_w, label='Step Frequency')

    ax.set_xticks(positions)
    ax.set_xticklabels(paw_types)
    ax.legend()
    ax.set_ylabel('Value')

    # Annotate the classified gait pattern in the top-left corner, if known.
    if self.gait_pattern:
        ax.text(0.02, 0.98, f'Gait Pattern: {self.gait_pattern}',
                transform=ax.transAxes, verticalalignment='top')
def _plot_gait_pattern(self, ax):
    """Heatmap of footprint counts per paw across 20 equal time bins.

    Fix: bail out gracefully when no footprints were detected instead of
    raising ValueError from max() over an empty sequence.
    """
    ax.set_title('Gait Pattern')

    if not self.gait_prints:
        # Nothing to bin — leave an explanatory placeholder on the axes.
        ax.text(0.5, 0.5, 'No footprints detected', ha='center', va='center',
                transform=ax.transAxes)
        return

    # 20 bin edges spanning [0, last footprint timestamp].
    time_bins = np.linspace(0, max(p.timestamp for p in self.gait_prints), 20)
    paw_types = ['LF', 'RF', 'LH', 'RH']
    pattern_matrix = np.zeros((len(paw_types), len(time_bins) - 1))

    # Count footprints per (paw, time-bin) cell.
    for i, paw_type in enumerate(paw_types):
        for p in (fp for fp in self.gait_prints if fp.paw_type == paw_type):
            bin_idx = np.digitize(p.timestamp, time_bins) - 1
            if 0 <= bin_idx < len(time_bins) - 1:
                pattern_matrix[i, bin_idx] += 1

    sns.heatmap(pattern_matrix, ax=ax, cmap='YlOrRd',
                xticklabels=np.round(time_bins[:-1], 1),
                yticklabels=paw_types)
    ax.set_xlabel('Time (seconds)')
def _plot_stride_length(self, ax):
    """Bar chart of per-paw stride length, coloured by paw."""
    ax.set_title('Stride Length')
    palette = self._get_paw_color_dict()
    paw_types = list(self.params['stride_length'].keys())
    values = [self.params['stride_length'][p] for p in paw_types]

    # Colour each bar with its paw's palette entry.
    for bar, paw in zip(ax.bar(paw_types, values), paw_types):
        bar.set_color(palette[paw])

    ax.set_ylabel('Length (pixels)')
    ax.grid(True, alpha=0.3)
def _plot_step_frequency(self, ax):
    """Bar chart of per-paw step frequency, coloured by paw."""
    ax.set_title('Step Frequency')
    palette = self._get_paw_color_dict()
    paw_types = list(self.params['step_frequency'].keys())
    values = [self.params['step_frequency'][p] for p in paw_types]

    # Colour each bar with its paw's palette entry.
    for bar, paw in zip(ax.bar(paw_types, values), paw_types):
        bar.set_color(palette[paw])

    ax.set_ylabel('Frequency (steps/second)')
    ax.grid(True, alpha=0.3)
def _plot_stance_swing_time(self, ax):
    """Grouped bars comparing stance vs. swing duration for each paw."""
    ax.set_title('Stance and Swing Time')
    paw_types = list(self.params['stance_time'].keys())
    stance = [self.params['stance_time'][p] for p in paw_types]
    swing = [self.params['swing_time'][p] for p in paw_types]

    centers = np.arange(len(paw_types))
    bar_w = 0.35
    ax.bar(centers - bar_w / 2, stance, bar_w, label='Stance Time')
    ax.bar(centers + bar_w / 2, swing, bar_w, label='Swing Time')

    ax.set_xticks(centers)
    ax.set_xticklabels(paw_types)
    ax.set_ylabel('Time (seconds)')
    ax.legend()
    ax.grid(True, alpha=0.3)
def _plot_duty_factor(self, ax):
    """Bar chart of per-paw duty factor, expressed as a percentage."""
    ax.set_title('Duty Factor')
    palette = self._get_paw_color_dict()
    paw_types = list(self.params['duty_factor'].keys())
    # Stored as a 0-1 fraction; plotted as percent.
    values = [self.params['duty_factor'][p] * 100 for p in paw_types]

    # Colour each bar with its paw's palette entry.
    for bar, paw in zip(ax.bar(paw_types, values), paw_types):
        bar.set_color(palette[paw])

    ax.set_ylabel('Duty Factor (%)')
    ax.grid(True, alpha=0.3)
def _plot_symmetry_index(self, ax):
    """Bar chart of the symmetry index per side pair, expressed as percent."""
    ax.set_title('Symmetry Index')
    sides = list(self.params['symmetry_index'].keys())
    # Stored as a 0-1 fraction; plotted as percent.
    percentages = [self.params['symmetry_index'][s] * 100 for s in sides]

    ax.bar(sides, percentages)
    ax.set_ylabel('Symmetry Index (%)')
    ax.grid(True, alpha=0.3)
def _plot_base_of_support(self, ax):
    """Bar chart of base-of-support width per side pair."""
    ax.set_title('Base of Support')
    sides = list(self.params['base_of_support'].keys())
    widths = [self.params['base_of_support'][s] for s in sides]

    ax.bar(sides, widths)
    ax.set_ylabel('Width (pixels)')
    ax.grid(True, alpha=0.3)
+ def _get_paw_color_dict(self):
1156
+ """获取足印颜色字典(matplotlib格式)"""
1157
+ return {
1158
+ 'LF': 'green',
1159
+ 'RF': 'yellow',
1160
+ 'LH': 'purple',
1161
+ 'RH': 'orange'
1162
+ }
1163
+
1164
def _extract_green_channel(self, patch: np.ndarray) -> np.ndarray:
    """Isolate the green footprint signal in `patch`; return a 32x32 grayscale map."""
    # Segment green pixels in HSV space (hue 40-80 covers green).
    hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
    green_lo = np.array([40, 40, 40])
    green_hi = np.array([80, 255, 255])
    mask = cv2.inRange(hsv, green_lo, green_hi)

    # Keep only the masked pixels, collapse to grayscale, normalise size.
    masked = cv2.bitwise_and(patch, patch, mask=mask)
    gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
    return cv2.resize(gray, (32, 32))
def _extract_image_features(self, green_mask: np.ndarray) -> np.ndarray:
    """Build a 10-dim shape descriptor from a binary footprint mask.

    Features: area, perimeter, circularity, aspect ratio, ellipse angle,
    and the first five Hu moments of the largest contour. Returns a zero
    vector when no contour is found.
    """
    contours, _ = cv2.findContours(green_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return np.zeros(10)  # default descriptor for an empty mask

    contour = max(contours, key=cv2.contourArea)

    # Basic shape measurements.
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)
    # Circularity: 1.0 for a perfect circle, smaller for elongated shapes.
    circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0

    # Orientation via ellipse fit; cv2.fitEllipse needs >= 5 points.
    if len(contour) >= 5:
        (x, y), (MA, ma), angle = cv2.fitEllipse(contour)
    else:
        MA, ma, angle = 0, 0, 0

    # Hu moments are invariant to translation/scale/rotation.
    hu = cv2.HuMoments(cv2.moments(contour)).flatten()

    return np.array([
        area,
        perimeter,
        circularity,
        MA / ma if ma > 0 else 0,  # aspect ratio
        angle,
        *hu[:5],  # first five Hu moments
    ])
+ def _generate_footprint_data(self) -> dict:
1226
+ """生成足印数据的JSON格式"""
1227
+ # 初始化数据结构
1228
+ data = {
1229
+ "frames": [],
1230
+ "footprintArea": []
1231
+ }
1232
+
1233
+ # 按帧ID组织足印数据
1234
+ frame_prints = {}
1235
+ for print in self.gait_prints:
1236
+ if print.frame_id not in frame_prints:
1237
+ frame_prints[print.frame_id] = []
1238
+ frame_prints[print.frame_id].append(print)
1239
+
1240
+ # 生成footprintArea的ID映射
1241
+ area_id_map = {} # cluster_id -> footprintAreaId
1242
+ for print in self.gait_prints:
1243
+ if print.cluster_id not in area_id_map:
1244
+ area_id_map[print.cluster_id] = f"footprintArea_{print.cluster_id}"
1245
+
1246
+ # 转换爪子类型
1247
+ type_map = {
1248
+ 'LF': 'leftFront',
1249
+ 'RF': 'rightFront',
1250
+ 'LH': 'leftHind',
1251
+ 'RH': 'rightHind'
1252
+ }
1253
+
1254
+ # 生成frames数据
1255
+ for frame_id, prints in frame_prints.items():
1256
+ frame_data = {
1257
+ "frameId": frame_id,
1258
+ "footprints": []
1259
+ }
1260
+
1261
+ for i, print in enumerate(prints):
1262
+ # 转换中心点坐标为左上角坐标
1263
+ x = int(print.x - print.w/2)
1264
+ y = int(print.y - print.h/2)
1265
+ w = int(print.w)
1266
+ h = int(print.h)
1267
+
1268
+ footprint_data = {
1269
+ "id": f"footprint_{frame_id}_{i}",
1270
+ "type": type_map.get(print.paw_type, 'unknown'),
1271
+ "isKeyFootprint": False,
1272
+ "position": {
1273
+ "x": x,
1274
+ "y": y,
1275
+ "width": w,
1276
+ "height": h
1277
+ },
1278
+ "footprintAreaId": area_id_map[print.cluster_id]
1279
+ }
1280
+ frame_data["footprints"].append(footprint_data)
1281
+
1282
+ data["frames"].append(frame_data)
1283
+
1284
+ # 生成footprintArea数据
1285
+ cluster_groups = {}
1286
+ for print in self.gait_prints:
1287
+ if print.cluster_id not in cluster_groups:
1288
+ cluster_groups[print.cluster_id] = []
1289
+ cluster_groups[print.cluster_id].append(print)
1290
+
1291
+ for cluster_id, prints in cluster_groups.items():
1292
+ # 计算整个簇的边界框(考虑每个足印的完整区域)
1293
+ x_coords = []
1294
+ y_coords = []
1295
+ for p in prints:
1296
+ # 添加每个足印框的四个角点
1297
+ x_coords.extend([
1298
+ p.x - p.w/2, # 左边界
1299
+ p.x + p.w/2 # 右边界
1300
+ ])
1301
+ y_coords.extend([
1302
+ p.y - p.h/2, # 上边界
1303
+ p.y + p.h/2 # 下边界
1304
+ ])
1305
+
1306
+ # 计算能包含所有足印的最小矩形
1307
+ x_min = min(x_coords)
1308
+ y_min = min(y_coords)
1309
+ x_max = max(x_coords)
1310
+ y_max = max(y_coords)
1311
+
1312
+ # 选择关键帧(这里选择置信度最高的)
1313
+ key_print = max(prints, key=lambda p: p.conf)
1314
+
1315
+ # 获取图像数据
1316
+ if key_print.image_patch is not None:
1317
+ import base64
1318
+ import cv2
1319
+ success, buffer = cv2.imencode('.png', key_print.image_patch)
1320
+ if success:
1321
+ img_base64 = base64.b64encode(buffer).decode('utf-8')
1322
+ else:
1323
+ img_base64 = ""
1324
+ else:
1325
+ img_base64 = ""
1326
+
1327
+ area_data = {
1328
+ "footprintAreaId": area_id_map[cluster_id],
1329
+ "type": type_map.get(key_print.paw_type, 'unknown'),
1330
+ "areaPosition": {
1331
+ "x": int(x_min),
1332
+ "y": int(y_min),
1333
+ "width": int(x_max - x_min),
1334
+ "height": int(y_max - y_min)
1335
+ },
1336
+ "keyFootprintFrame": {
1337
+ "frameId": key_print.frame_id,
1338
+ "footprintId": f"footprint_{key_print.frame_id}_{cluster_id}",
1339
+ "footPosition": {
1340
+ "x": int(key_print.x - key_print.w/2),
1341
+ "y": int(key_print.y - key_print.h/2),
1342
+ "width": int(key_print.w),
1343
+ "height": int(key_print.h)
1344
+ },
1345
+ "base64Image": img_base64
1346
+ },
1347
+ "startFrame": min(p.frame_id for p in prints),
1348
+ "endFrame": max(p.frame_id for p in prints)
1349
+ }
1350
+
1351
+ data["footprintArea"].append(area_data)
1352
+
1353
+ return data
1354
+
1355
+ def _save_footprint_json(self):
1356
+ """保存足印数据为JSON文件"""
1357
+ import json
1358
+
1359
+ data = self._generate_footprint_data()
1360
+ json_path = f'{self.result_dir}/data/footprint_data.json'
1361
+
1362
+ with open(json_path, 'w', encoding='utf-8') as f:
1363
+ json.dump(data, f, indent=2, ensure_ascii=False)
1364
+
1365
+ print(f"足印数据已保存至: {json_path}")
1366
+
1367
+
1368
def main(video_path: str = "/Users/hakureirm/codespace/Work/Algorithm/gait/exp_videos/Exp8.mp4"):
    """Run the full gait-analysis pipeline on one video.

    Args:
        video_path: input video file. The previously hard-coded path is kept
            as the default so existing callers are unaffected, but the
            function can now be reused on any video.
    """
    analyzer = GaitAnalyzer()

    # Auto-detect the time window in which the mouse is actually visible.
    start_time, end_time = analyzer._detect_mouse_time_range(video_path)

    # Analyse only the detected window.
    analyzer.process_video(
        video_path,
        start_time=start_time,
        end_time=end_time,
        conf_thres=0.7,
        iou_thres=0.5
    )


if __name__ == "__main__":
    main()
src/visualize_footprint_json.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import cv2
3
+ import numpy as np
4
+ from pathlib import Path
5
+
6
def visualize_footprint_json(json_path, output_dir):
    """Render the footprint JSON document as PNG images.

    Writes one overview image of all footprint areas plus one image per
    frame into `output_dir` (created if missing).

    Fix: colour lookups use `.get` with the grey 'unknown' colour as a
    fallback, so an unexpected type string no longer raises KeyError.
    """
    with open(json_path, 'r') as f:
        data = json.load(f)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Canvas size: maximum extent of any footprint or area box.
    max_x = max_y = 0
    for frame in data['frames']:
        for footprint in frame['footprints']:
            pos = footprint['position']
            max_x = max(max_x, pos['x'] + pos['width'])
            max_y = max(max_y, pos['y'] + pos['height'])
    for area in data['footprintArea']:
        pos = area['areaPosition']
        max_x = max(max_x, pos['x'] + pos['width'])
        max_y = max(max_y, pos['y'] + pos['height'])

    # Margin so boxes and labels near the edge stay visible.
    canvas_width = max_x + 50
    canvas_height = max_y + 50

    # Paw type -> BGR colour.
    color_map = {
        'leftFront': (0, 255, 0),     # green
        'rightFront': (0, 255, 255),  # yellow
        'leftHind': (255, 0, 255),    # purple
        'rightHind': (0, 165, 255),   # orange
        'unknown': (128, 128, 128)    # grey
    }
    default_color = color_map['unknown']

    # 1. Overview image of every footprint area.
    area_canvas = np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255
    for area in data['footprintArea']:
        pos = area['areaPosition']
        color = color_map.get(area['type'], default_color)

        # Area bounding box.
        cv2.rectangle(area_canvas,
                     (pos['x'], pos['y']),
                     (pos['x'] + pos['width'], pos['y'] + pos['height']),
                     color, 2)

        # Type + numeric id label above the box.
        label = f"{area['type']}#{area['footprintAreaId'].split('_')[-1]}"
        cv2.putText(area_canvas, label,
                   (pos['x'], pos['y'] - 5),
                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    cv2.imwrite(str(output_dir / 'footprint_areas.png'), area_canvas)

    # 2. One visualisation per frame.
    frame_ids = sorted(set(frame['frameId'] for frame in data['frames']))
    for frame_id in frame_ids:
        frame_canvas = np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255

        # Areas active in this frame, drawn as translucent filled rectangles.
        overlay = frame_canvas.copy()
        for area in data['footprintArea']:
            if area['startFrame'] <= frame_id <= area['endFrame']:
                pos = area['areaPosition']
                color = color_map.get(area['type'], default_color)
                cv2.rectangle(overlay,
                            (pos['x'], pos['y']),
                            (pos['x'] + pos['width'], pos['y'] + pos['height']),
                            color, -1)  # filled rectangle

        # Blend the overlay for the translucency effect.
        alpha = 0.3
        cv2.addWeighted(overlay, alpha, frame_canvas, 1 - alpha, 0, frame_canvas)

        # Footprints belonging to this frame, drawn on top.
        for frame_data in data['frames']:
            if frame_data['frameId'] == frame_id:
                for footprint in frame_data['footprints']:
                    pos = footprint['position']
                    color = color_map.get(footprint['type'], default_color)

                    # Footprint bounding box.
                    cv2.rectangle(frame_canvas,
                                (pos['x'], pos['y']),
                                (pos['x'] + pos['width'], pos['y'] + pos['height']),
                                color, 2)

                    # Type + index label above the box.
                    label = f"{footprint['type']}#{footprint['id'].split('_')[-1]}"
                    cv2.putText(frame_canvas, label,
                              (pos['x'], pos['y'] - 5),
                              cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        # Frame number stamp.
        cv2.putText(frame_canvas, f"Frame: {frame_id}",
                   (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

        cv2.imwrite(str(output_dir / f'frame_{frame_id:04d}.png'), frame_canvas)

    print(f"可视化结果已保存至: {output_dir}")
def main():
    """Visualise the footprint JSON from the most recent results directory."""
    results_dir = Path("results")
    # Fix: only consider directories (the old bare glob could pick a stray
    # file) and fail with a clear message when there are no results yet
    # instead of raising ValueError from max() on an empty sequence.
    run_dirs = [p for p in results_dir.glob("*") if p.is_dir()]
    if not run_dirs:
        raise FileNotFoundError("no result directories found under 'results/'")
    latest_dir = max(run_dirs, key=lambda p: p.stat().st_mtime)
    json_path = latest_dir / "data" / "footprint_data.json"

    # Visualisation output lives next to the data it was rendered from.
    output_dir = latest_dir / "visualization"

    visualize_footprint_json(json_path, output_dir)


if __name__ == "__main__":
    main()