Commit
·
448a1a6
0
Parent(s):
Add initial code
Browse files- .gitignore +12 -0
- requirements.txt +10 -0
- src/api_server.py +104 -0
- src/gait_analyze.py +1385 -0
- src/visualize_footprint_json.py +123 -0
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*$py.class
|
| 4 |
+
.env
|
| 5 |
+
.venv
|
| 6 |
+
env/
|
| 7 |
+
venv/
|
| 8 |
+
ENV/
|
| 9 |
+
results/
|
| 10 |
+
*.mp4
|
| 11 |
+
*.avi
|
| 12 |
+
*.mov
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ultralytics
|
| 2 |
+
opencv-python
|
| 3 |
+
numpy
|
| 4 |
+
pandas
|
| 5 |
+
matplotlib
|
| 6 |
+
seaborn
|
| 7 |
+
scikit-learn
|
| 8 |
+
fastapi
|
| 9 |
+
python-multipart
|
| 10 |
+
uvicorn
|
src/api_server.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, UploadFile, Form, File
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
import json
|
| 4 |
+
import tempfile
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import uvicorn
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
import shutil
|
| 11 |
+
|
| 12 |
+
from gait_analyze import GaitAnalyzer
|
| 13 |
+
|
| 14 |
+
app = FastAPI()
|
| 15 |
+
|
| 16 |
+
# 配置CORS
|
| 17 |
+
app.add_middleware(
|
| 18 |
+
CORSMiddleware,
|
| 19 |
+
allow_origins=["*"], # 在生产环境中应该设置具体的域名
|
| 20 |
+
allow_credentials=True,
|
| 21 |
+
allow_methods=["*"],
|
| 22 |
+
allow_headers=["*"],
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
class AnalysisResponse(BaseModel):
|
| 26 |
+
"""分析响应模型"""
|
| 27 |
+
status: str
|
| 28 |
+
message: str
|
| 29 |
+
data: Optional[dict] = None
|
| 30 |
+
|
| 31 |
+
@app.post("/api/v1/analysisFootVideo")
|
| 32 |
+
async def analysis_foot_video(
|
| 33 |
+
video: UploadFile = File(...),
|
| 34 |
+
params: str = Form(...)
|
| 35 |
+
):
|
| 36 |
+
"""处理足印视频分析请求"""
|
| 37 |
+
try:
|
| 38 |
+
# 解析参数
|
| 39 |
+
analysis_params = json.loads(params)
|
| 40 |
+
|
| 41 |
+
# 创建临时目录保存视频
|
| 42 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 43 |
+
# 保存上传的视频
|
| 44 |
+
video_path = os.path.join(temp_dir, "input_video.mp4")
|
| 45 |
+
with open(video_path, "wb") as buffer:
|
| 46 |
+
shutil.copyfileobj(video.file, buffer)
|
| 47 |
+
|
| 48 |
+
# 创建分析器实例
|
| 49 |
+
analyzer = GaitAnalyzer()
|
| 50 |
+
|
| 51 |
+
# 自动检测时间范围
|
| 52 |
+
start_time, end_time = analyzer._detect_mouse_time_range(video_path)
|
| 53 |
+
|
| 54 |
+
# 处理视频
|
| 55 |
+
analyzer.process_video(
|
| 56 |
+
video_path,
|
| 57 |
+
start_time=start_time,
|
| 58 |
+
end_time=end_time,
|
| 59 |
+
conf_thres=0.7,
|
| 60 |
+
iou_thres=0.5
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# 获取结果目录中的数据
|
| 64 |
+
result_dir = Path(analyzer.result_dir)
|
| 65 |
+
|
| 66 |
+
# 读取JSON结果
|
| 67 |
+
json_path = result_dir / "data" / "footprint_data.json"
|
| 68 |
+
with open(json_path, 'r') as f:
|
| 69 |
+
footprint_data = json.load(f)
|
| 70 |
+
|
| 71 |
+
# 添加视频参数到结果中
|
| 72 |
+
footprint_data.update({
|
| 73 |
+
"video_info": {
|
| 74 |
+
"fps": analysis_params.get("video_fps"),
|
| 75 |
+
"width": analysis_params.get("video_width"),
|
| 76 |
+
"height": analysis_params.get("video_height"),
|
| 77 |
+
"scale_length": analysis_params.get("scale_length"),
|
| 78 |
+
"actual_length": analysis_params.get("actual_length"),
|
| 79 |
+
}
|
| 80 |
+
})
|
| 81 |
+
|
| 82 |
+
return AnalysisResponse(
|
| 83 |
+
status="success",
|
| 84 |
+
message="分析完成",
|
| 85 |
+
data=footprint_data
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
except Exception as e:
|
| 89 |
+
return AnalysisResponse(
|
| 90 |
+
status="error",
|
| 91 |
+
message=f"分析失败: {str(e)}"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
def main():
|
| 95 |
+
"""启动服务器"""
|
| 96 |
+
uvicorn.run(
|
| 97 |
+
"api_server:app",
|
| 98 |
+
host="0.0.0.0",
|
| 99 |
+
port=12345,
|
| 100 |
+
reload=True
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
main()
|
src/gait_analyze.py
ADDED
|
@@ -0,0 +1,1385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ultralytics import YOLO
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Dict, Tuple
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from matplotlib.patches import Rectangle
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class GaitPrint:
|
| 15 |
+
"""足印数据类"""
|
| 16 |
+
frame_id: int # 帧号
|
| 17 |
+
x: float # 中心x坐标
|
| 18 |
+
y: float # 中心y坐标
|
| 19 |
+
w: float # 宽度
|
| 20 |
+
h: float # 高度
|
| 21 |
+
conf: float # 置信度
|
| 22 |
+
paw_type: str = None # 爪子类型 (LF, RF, LH, RH)
|
| 23 |
+
timestamp: float = None # 时间戳
|
| 24 |
+
stance_phase: bool = True # 支撑相标志
|
| 25 |
+
gait_cycle: int = 0 # 步态周期编号
|
| 26 |
+
image_patch: np.ndarray = None # 足印图像块
|
| 27 |
+
image_features: np.ndarray = None # 图像特征向量
|
| 28 |
+
cluster_id: int = -1 # 聚类ID
|
| 29 |
+
|
| 30 |
+
class GaitAnalyzer:
|
| 31 |
+
def __init__(self):
|
| 32 |
+
"""初始化步态分析器"""
|
| 33 |
+
self.gait_model = YOLO('models/mice-gait-v1.0.mlpackage', task='detect')
|
| 34 |
+
self.pose_model = YOLO('models/mice-pose-bottomview-v1.0.mlpackage', task='pose') # 新增pose模型
|
| 35 |
+
self.gait_prints: List[GaitPrint] = []
|
| 36 |
+
self.mice_positions: List[Dict] = [] # 现在存储pose关键点信息
|
| 37 |
+
self.params = {}
|
| 38 |
+
self.time_window = 0.2
|
| 39 |
+
self.distance_threshold = 30
|
| 40 |
+
self.gait_pattern = None
|
| 41 |
+
self.result_dir = self._create_result_dir()
|
| 42 |
+
|
| 43 |
+
def _detect_mouse_time_range(self, video_path: str, margin_ratio: float = 0.05) -> Tuple[float, float]:
|
| 44 |
+
"""使用pose模型的鼻子和尾巴点来检测老鼠"""
|
| 45 |
+
print("\n[0/6] 预处理视频以确定分析时间范围...")
|
| 46 |
+
|
| 47 |
+
cap = cv2.VideoCapture(video_path)
|
| 48 |
+
if not cap.isOpened():
|
| 49 |
+
raise ValueError("无法打开视频文件")
|
| 50 |
+
|
| 51 |
+
# 获取视频基本信息
|
| 52 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 53 |
+
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 54 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 55 |
+
|
| 56 |
+
margin = int(width * margin_ratio)
|
| 57 |
+
left_boundary = margin
|
| 58 |
+
right_boundary = width - margin
|
| 59 |
+
|
| 60 |
+
start_frame = None
|
| 61 |
+
end_frame = None
|
| 62 |
+
frame_id = 0
|
| 63 |
+
|
| 64 |
+
while cap.isOpened():
|
| 65 |
+
ret, frame = cap.read()
|
| 66 |
+
if not ret:
|
| 67 |
+
break
|
| 68 |
+
|
| 69 |
+
if frame_id % 30 == 0: # 每30帧打印一次进度
|
| 70 |
+
print(f"预处理进度: {frame_id}/{total_frames} ({frame_id/total_frames*100:.1f}%)")
|
| 71 |
+
|
| 72 |
+
# 使用pose模型检测
|
| 73 |
+
results = self.pose_model(frame, conf=0.5, verbose=False)
|
| 74 |
+
|
| 75 |
+
for r in results:
|
| 76 |
+
keypoints = r.keypoints.data[0] # 获取关键点
|
| 77 |
+
if len(keypoints) >= 7: # 确保检测到所有关键点
|
| 78 |
+
nose_x = keypoints[0][0].item() # 鼻子x坐标
|
| 79 |
+
tail_x = keypoints[6][0].item() # 尾巴x坐标
|
| 80 |
+
|
| 81 |
+
# 使用鼻子位置判断开始
|
| 82 |
+
if start_frame is None and nose_x > left_boundary + margin:
|
| 83 |
+
start_frame = frame_id
|
| 84 |
+
|
| 85 |
+
# 使用尾巴位置判断结束
|
| 86 |
+
if start_frame is not None and tail_x > right_boundary - margin:
|
| 87 |
+
end_frame = frame_id
|
| 88 |
+
break
|
| 89 |
+
|
| 90 |
+
if end_frame is not None:
|
| 91 |
+
break
|
| 92 |
+
|
| 93 |
+
frame_id += 1
|
| 94 |
+
|
| 95 |
+
cap.release()
|
| 96 |
+
|
| 97 |
+
# 转换为时间
|
| 98 |
+
start_time = start_frame / fps if start_frame is not None else 0
|
| 99 |
+
end_time = end_frame / fps if end_frame is not None else total_frames / fps
|
| 100 |
+
|
| 101 |
+
print(f"检测到有效时间范围: {start_time:.2f}s - {end_time:.2f}s")
|
| 102 |
+
return start_time, end_time
|
| 103 |
+
|
| 104 |
+
def _create_result_dir(self) -> str:
|
| 105 |
+
"""创建结果目录结构"""
|
| 106 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 107 |
+
result_dir = f"results/{timestamp}"
|
| 108 |
+
|
| 109 |
+
# 创建目录结构
|
| 110 |
+
for subdir in ['data', 'plots', 'videos']:
|
| 111 |
+
os.makedirs(f"{result_dir}/{subdir}", exist_ok=True)
|
| 112 |
+
|
| 113 |
+
return result_dir
|
| 114 |
+
|
| 115 |
+
def _filter_footprints(self):
|
| 116 |
+
"""对足印进行时空聚类,将同一个足印的多次检测归为一组"""
|
| 117 |
+
if not self.gait_prints:
|
| 118 |
+
print("警告:没有足印数据可供聚类")
|
| 119 |
+
return
|
| 120 |
+
|
| 121 |
+
try:
|
| 122 |
+
from sklearn.preprocessing import StandardScaler
|
| 123 |
+
from sklearn.cluster import DBSCAN
|
| 124 |
+
|
| 125 |
+
# 1. 准备数据
|
| 126 |
+
features = np.array([[p.x, p.y, p.timestamp * 30] for p in self.gait_prints])
|
| 127 |
+
print(f"开始聚类,原始足印数量: {len(features)}")
|
| 128 |
+
|
| 129 |
+
# 2. 标准化特征
|
| 130 |
+
scaler = StandardScaler()
|
| 131 |
+
features_scaled = scaler.fit_transform(features)
|
| 132 |
+
|
| 133 |
+
# 3. DBSCAN聚类
|
| 134 |
+
eps = 0.3 # 可以根据实际情况调整
|
| 135 |
+
min_samples = 1 # 设为1以保留所有检测
|
| 136 |
+
dbscan = DBSCAN(eps=eps, min_samples=min_samples)
|
| 137 |
+
cluster_labels = dbscan.fit_predict(features_scaled)
|
| 138 |
+
|
| 139 |
+
# 4. 将聚类结果添加到足印对象中
|
| 140 |
+
for print_obj, label in zip(self.gait_prints, cluster_labels):
|
| 141 |
+
print_obj.cluster_id = label
|
| 142 |
+
|
| 143 |
+
# 打印聚类统计信息
|
| 144 |
+
n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
|
| 145 |
+
print(f"聚类完成! 共识别出 {n_clusters} 个独立足印")
|
| 146 |
+
|
| 147 |
+
# 按cluster_id分组并打印每组的大小
|
| 148 |
+
from collections import Counter
|
| 149 |
+
cluster_sizes = Counter(cluster_labels)
|
| 150 |
+
print("\n各组足印检测数量:")
|
| 151 |
+
for cluster_id, size in sorted(cluster_sizes.items()):
|
| 152 |
+
if cluster_id != -1:
|
| 153 |
+
print(f"足印 #{cluster_id}: {size}个检测")
|
| 154 |
+
|
| 155 |
+
except Exception as e:
|
| 156 |
+
print(f"聚类过程出错: {str(e)}")
|
| 157 |
+
|
| 158 |
+
def _post_process_footprints(self):
|
| 159 |
+
"""后处理足迹数据:聚类、过滤和分类"""
|
| 160 |
+
if not self.gait_prints:
|
| 161 |
+
print("警告:没有足印数据可供处理")
|
| 162 |
+
return
|
| 163 |
+
|
| 164 |
+
# 1. 时空聚类
|
| 165 |
+
self._filter_footprints()
|
| 166 |
+
|
| 167 |
+
if not self.gait_prints:
|
| 168 |
+
print("错误:后处理后没有足印数据")
|
| 169 |
+
return
|
| 170 |
+
|
| 171 |
+
# 2. 全局分析
|
| 172 |
+
mouse_path = pd.DataFrame(self.mice_positions)
|
| 173 |
+
total_dx = mouse_path['x'].iloc[-1] - mouse_path['x'].iloc[0]
|
| 174 |
+
moving_right = total_dx > 0
|
| 175 |
+
|
| 176 |
+
# 3. 使用机器学习进行足迹分类
|
| 177 |
+
self._classify_footprints(moving_right)
|
| 178 |
+
|
| 179 |
+
# 4. 确定步态周期
|
| 180 |
+
self._determine_gait_cycles()
|
| 181 |
+
|
| 182 |
+
def _classify_footprints(self, moving_right: bool):
|
| 183 |
+
"""使用pose关键点来分类足迹"""
|
| 184 |
+
if len(self.gait_prints) < 4 or not self.mice_positions:
|
| 185 |
+
print("警告:足迹或姿态数据不足")
|
| 186 |
+
return
|
| 187 |
+
|
| 188 |
+
# 按时间排序足印
|
| 189 |
+
sorted_prints = sorted(self.gait_prints, key=lambda p: p.timestamp)
|
| 190 |
+
|
| 191 |
+
# 对每个cluster进行分类
|
| 192 |
+
cluster_groups = {}
|
| 193 |
+
for p in sorted_prints:
|
| 194 |
+
if p.cluster_id not in cluster_groups:
|
| 195 |
+
cluster_groups[p.cluster_id] = []
|
| 196 |
+
cluster_groups[p.cluster_id].append(p)
|
| 197 |
+
|
| 198 |
+
# 对每个cluster,找到最近时间的pose数据
|
| 199 |
+
for cluster_id, prints in cluster_groups.items():
|
| 200 |
+
mid_time = np.mean([p.timestamp for p in prints])
|
| 201 |
+
closest_pose = min(self.mice_positions,
|
| 202 |
+
key=lambda m: abs(m['timestamp'] - mid_time))
|
| 203 |
+
|
| 204 |
+
if 'keypoints' not in closest_pose:
|
| 205 |
+
print(f"警告:时间戳 {mid_time:.2f}s 处的姿态数据缺少关键点信息")
|
| 206 |
+
continue
|
| 207 |
+
|
| 208 |
+
# 计算cluster中心位置
|
| 209 |
+
center_x = np.mean([p.x for p in prints])
|
| 210 |
+
center_y = np.mean([p.y for p in prints])
|
| 211 |
+
|
| 212 |
+
# 获取关键点位置
|
| 213 |
+
nose = closest_pose['keypoints']['nose']
|
| 214 |
+
re = closest_pose['keypoints']['right_ear']
|
| 215 |
+
le = closest_pose['keypoints']['left_ear']
|
| 216 |
+
mid = closest_pose['keypoints']['mid']
|
| 217 |
+
rl = closest_pose['keypoints']['right_leg']
|
| 218 |
+
ll = closest_pose['keypoints']['left_leg']
|
| 219 |
+
tail = closest_pose['keypoints']['tail_base']
|
| 220 |
+
|
| 221 |
+
# 计算到各关键点的距离
|
| 222 |
+
dist_to_front = min(
|
| 223 |
+
np.sqrt((center_x - nose[0])**2 + (center_y - nose[1])**2),
|
| 224 |
+
np.sqrt((center_x - re[0])**2 + (center_y - re[1])**2),
|
| 225 |
+
np.sqrt((center_x - le[0])**2 + (center_y - le[1])**2)
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
dist_to_back = min(
|
| 229 |
+
np.sqrt((center_x - rl[0])**2 + (center_y - rl[1])**2),
|
| 230 |
+
np.sqrt((center_x - ll[0])**2 + (center_y - ll[1])**2),
|
| 231 |
+
np.sqrt((center_x - tail[0])**2 + (center_y - tail[1])**2)
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# 前后判断:比较到前后关键点的距离
|
| 235 |
+
is_front = dist_to_front < dist_to_back
|
| 236 |
+
|
| 237 |
+
# 左右判断:根据y坐标相对位置
|
| 238 |
+
if is_front:
|
| 239 |
+
is_left = center_y > (re[1] + le[1])/2 # 比较与耳朵中点的位置
|
| 240 |
+
else:
|
| 241 |
+
is_left = center_y > (rl[1] + ll[1])/2 # 比较与后腿中点的位置
|
| 242 |
+
|
| 243 |
+
# 确定爪子类型
|
| 244 |
+
paw_type = None
|
| 245 |
+
if is_front:
|
| 246 |
+
paw_type = 'LF' if is_left else 'RF'
|
| 247 |
+
else:
|
| 248 |
+
paw_type = 'LH' if is_left else 'RH'
|
| 249 |
+
|
| 250 |
+
# 应用分类结果
|
| 251 |
+
for p in prints:
|
| 252 |
+
p.paw_type = paw_type
|
| 253 |
+
|
| 254 |
+
# 按帧组织足印数据,用于时序一致性检查
|
| 255 |
+
frame_prints = {}
|
| 256 |
+
for p in sorted_prints:
|
| 257 |
+
if p.frame_id not in frame_prints:
|
| 258 |
+
frame_prints[p.frame_id] = []
|
| 259 |
+
frame_prints[p.frame_id].append(p)
|
| 260 |
+
|
| 261 |
+
# 进行时序一致性检查
|
| 262 |
+
self._enforce_temporal_consistency(frame_prints)
|
| 263 |
+
|
| 264 |
+
def _enforce_temporal_consistency(self, frame_prints):
|
| 265 |
+
"""确保时序一致性:同一时间同一类型的足印只能有一个"""
|
| 266 |
+
for frame_id, prints in frame_prints.items():
|
| 267 |
+
# 按类型分组
|
| 268 |
+
type_groups = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
|
| 269 |
+
for p in prints:
|
| 270 |
+
if p.paw_type:
|
| 271 |
+
type_groups[p.paw_type].append(p)
|
| 272 |
+
|
| 273 |
+
# 处理每个有多个足印的类型
|
| 274 |
+
for paw_type, group in type_groups.items():
|
| 275 |
+
if len(group) > 1:
|
| 276 |
+
# 保留cluster_id较大的足印(通常是较新的足印)
|
| 277 |
+
newest_print = max(group, key=lambda p: p.cluster_id)
|
| 278 |
+
for p in group:
|
| 279 |
+
if p != newest_print:
|
| 280 |
+
# 将重复的足印重新分类为对角的另一只脚
|
| 281 |
+
if paw_type == 'LF':
|
| 282 |
+
p.paw_type = 'RH'
|
| 283 |
+
elif paw_type == 'RF':
|
| 284 |
+
p.paw_type = 'LH'
|
| 285 |
+
elif paw_type == 'LH':
|
| 286 |
+
p.paw_type = 'RF'
|
| 287 |
+
else: # RH
|
| 288 |
+
p.paw_type = 'LF'
|
| 289 |
+
|
| 290 |
+
def _smooth_classifications(self):
|
| 291 |
+
"""使用时序信息平滑分类结果"""
|
| 292 |
+
# 1. 按时间排序足迹
|
| 293 |
+
sorted_prints = sorted(self.gait_prints, key=lambda x: x.timestamp)
|
| 294 |
+
|
| 295 |
+
# 2. 为每种爪子类型建立时间序列
|
| 296 |
+
paw_sequences = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
|
| 297 |
+
|
| 298 |
+
# 3. 检测和修正异常分类
|
| 299 |
+
window_size = 0.2 # 200ms时间窗口
|
| 300 |
+
for i, print in enumerate(sorted_prints):
|
| 301 |
+
# 获取时间窗口内的同类足迹
|
| 302 |
+
nearby_prints = [
|
| 303 |
+
p for p in sorted_prints
|
| 304 |
+
if abs(p.timestamp - print.timestamp) < window_size
|
| 305 |
+
and p.paw_type == print.paw_type
|
| 306 |
+
and p != print
|
| 307 |
+
]
|
| 308 |
+
|
| 309 |
+
if nearby_prints:
|
| 310 |
+
# 检查空间一致性
|
| 311 |
+
avg_x = np.mean([p.x for p in nearby_prints])
|
| 312 |
+
avg_y = np.mean([p.y for p in nearby_prints])
|
| 313 |
+
|
| 314 |
+
# 如果当前足迹位置偏离太远,考虑重新分类
|
| 315 |
+
dist = np.sqrt((print.x - avg_x)**2 + (print.y - avg_y)**2)
|
| 316 |
+
if dist > 50: # 像素距离阈值
|
| 317 |
+
# 尝试重新分类
|
| 318 |
+
self._reclassify_print(print, sorted_prints, i)
|
| 319 |
+
|
| 320 |
+
# 更新序列
|
| 321 |
+
paw_sequences[print.paw_type].append(print)
|
| 322 |
+
|
| 323 |
+
# 4. 验证步态模式的合理性
|
| 324 |
+
self._validate_gait_pattern(paw_sequences)
|
| 325 |
+
|
| 326 |
+
def _reclassify_print(self, print, all_prints, current_idx):
|
| 327 |
+
"""重新分类可能错误的足迹"""
|
| 328 |
+
# 获取临近的其他足迹
|
| 329 |
+
window_size = 0.2
|
| 330 |
+
nearby_prints = [
|
| 331 |
+
p for p in all_prints
|
| 332 |
+
if abs(p.timestamp - print.timestamp) < window_size
|
| 333 |
+
and p != print
|
| 334 |
+
]
|
| 335 |
+
|
| 336 |
+
if not nearby_prints:
|
| 337 |
+
return
|
| 338 |
+
|
| 339 |
+
# 计算到每种类型足迹的平均距离
|
| 340 |
+
type_distances = {}
|
| 341 |
+
for paw_type in ['LF', 'RF', 'LH', 'RH']:
|
| 342 |
+
same_type = [p for p in nearby_prints if p.paw_type == paw_type]
|
| 343 |
+
if same_type:
|
| 344 |
+
avg_dist = np.mean([
|
| 345 |
+
np.sqrt((print.x - p.x)**2 + (print.y - p.y)**2)
|
| 346 |
+
for p in same_type
|
| 347 |
+
])
|
| 348 |
+
type_distances[paw_type] = avg_dist
|
| 349 |
+
|
| 350 |
+
# 选择距离最小的类型
|
| 351 |
+
if type_distances:
|
| 352 |
+
new_type = min(type_distances.items(), key=lambda x: x[1])[0]
|
| 353 |
+
print.paw_type = new_type
|
| 354 |
+
|
| 355 |
+
def _validate_gait_pattern(self, paw_sequences):
|
| 356 |
+
"""验证步态模式的合理性"""
|
| 357 |
+
# 1. 检查每个爪子的步频是否合理
|
| 358 |
+
for paw_type, sequence in paw_sequences.items():
|
| 359 |
+
if len(sequence) >= 2:
|
| 360 |
+
time_diffs = [
|
| 361 |
+
sequence[i+1].timestamp - sequence[i].timestamp
|
| 362 |
+
for i in range(len(sequence)-1)
|
| 363 |
+
]
|
| 364 |
+
avg_freq = 1 / np.mean(time_diffs) if time_diffs else 0
|
| 365 |
+
|
| 366 |
+
# 步频应该在合理范围内 (通常2-5Hz)
|
| 367 |
+
if not (2 <= avg_freq <= 5):
|
| 368 |
+
print(f"警告:{paw_type}的步频 ({avg_freq:.2f}Hz) 可能不合理")
|
| 369 |
+
|
| 370 |
+
# 2. ���查对角步态模式
|
| 371 |
+
diagonal_pairs = [('LF', 'RH'), ('RF', 'LH')]
|
| 372 |
+
for pair in diagonal_pairs:
|
| 373 |
+
seq1 = paw_sequences[pair[0]]
|
| 374 |
+
seq2 = paw_sequences[pair[1]]
|
| 375 |
+
if seq1 and seq2:
|
| 376 |
+
# 检查对角步态的同步性
|
| 377 |
+
for p1 in seq1:
|
| 378 |
+
nearest = min(seq2, key=lambda p: abs(p.timestamp - p1.timestamp))
|
| 379 |
+
if abs(nearest.timestamp - p1.timestamp) > 0.1: # 100ms阈值
|
| 380 |
+
print(f"警告:对角步态 {pair} 可能不同步")
|
| 381 |
+
|
| 382 |
+
def _is_position_reasonable(self, print_obj, new_type):
|
| 383 |
+
"""检查新的分类是否在合理的空间位置范围内"""
|
| 384 |
+
# 获取同类型足迹的平均位置
|
| 385 |
+
same_type_prints = [p for p in self.gait_prints if p.paw_type == new_type]
|
| 386 |
+
if not same_type_prints:
|
| 387 |
+
return True
|
| 388 |
+
|
| 389 |
+
# 计算平均位置
|
| 390 |
+
avg_x = np.mean([p.x for p in same_type_prints])
|
| 391 |
+
avg_y = np.mean([p.y for p in same_type_prints])
|
| 392 |
+
|
| 393 |
+
# 计算标准差
|
| 394 |
+
std_x = np.std([p.x for p in same_type_prints])
|
| 395 |
+
std_y = np.std([p.y for p in same_type_prints])
|
| 396 |
+
|
| 397 |
+
# 检查是否在3个标准差范围内
|
| 398 |
+
x_reasonable = abs(print_obj.x - avg_x) < 3 * std_x
|
| 399 |
+
y_reasonable = abs(print_obj.y - avg_y) < 3 * std_y
|
| 400 |
+
|
| 401 |
+
return x_reasonable and y_reasonable
|
| 402 |
+
|
| 403 |
+
def _determine_gait_cycles(self):
|
| 404 |
+
"""确定每个足印所属的步态周期"""
|
| 405 |
+
# 按爪子类型分组
|
| 406 |
+
paw_groups = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
|
| 407 |
+
for print in self.gait_prints:
|
| 408 |
+
if print.paw_type:
|
| 409 |
+
paw_groups[print.paw_type].append(print)
|
| 410 |
+
|
| 411 |
+
# 为每种爪子类型确定步态周期
|
| 412 |
+
for paw_type, prints in paw_groups.items():
|
| 413 |
+
if len(prints) < 2:
|
| 414 |
+
continue
|
| 415 |
+
|
| 416 |
+
# 按时间排序
|
| 417 |
+
prints.sort(key=lambda x: x.timestamp)
|
| 418 |
+
|
| 419 |
+
# 初始化第一个周期
|
| 420 |
+
cycle_id = 1
|
| 421 |
+
prints[0].gait_cycle = cycle_id
|
| 422 |
+
|
| 423 |
+
# 基于时空距离确定周期
|
| 424 |
+
for i in range(1, len(prints)):
|
| 425 |
+
prev_print = prints[i-1]
|
| 426 |
+
curr_print = prints[i]
|
| 427 |
+
|
| 428 |
+
# 计算与前一个足印的时空距离
|
| 429 |
+
time_diff = curr_print.timestamp - prev_print.timestamp
|
| 430 |
+
space_diff = np.sqrt((curr_print.x - prev_print.x)**2 +
|
| 431 |
+
(curr_print.y - prev_print.y)**2)
|
| 432 |
+
|
| 433 |
+
# 如果时空距离超过阈值,开始新的周期
|
| 434 |
+
if time_diff > self.time_window * 2 or space_diff > self.distance_threshold * 3:
|
| 435 |
+
cycle_id += 1
|
| 436 |
+
|
| 437 |
+
curr_print.gait_cycle = cycle_id
|
| 438 |
+
|
| 439 |
+
def _calculate_gait_parameters(self):
|
| 440 |
+
"""计算步态参数"""
|
| 441 |
+
# 按类型分组足印
|
| 442 |
+
paw_groups = {'LF': [], 'RF': [], 'LH': [], 'RH': []}
|
| 443 |
+
for print in self.gait_prints:
|
| 444 |
+
if print.paw_type:
|
| 445 |
+
paw_groups[print.paw_type].append(print)
|
| 446 |
+
|
| 447 |
+
# 初始化参数字典
|
| 448 |
+
self.params = {
|
| 449 |
+
'stride_length': {}, # 步幅
|
| 450 |
+
'step_width': {}, # 步宽
|
| 451 |
+
'step_frequency': {}, # 步频
|
| 452 |
+
'stance_time': {}, # 支撑时间
|
| 453 |
+
'swing_time': {}, # 摆动时间
|
| 454 |
+
'duty_factor': {}, # 支撑占空比
|
| 455 |
+
'symmetry_index': {}, # 对称性指数
|
| 456 |
+
'base_of_support': {}, # 支撑基底
|
| 457 |
+
'stride_time': {} # 步态周期时间
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
# 计算每种爪子的参数
|
| 461 |
+
for paw_type, prints in paw_groups.items():
|
| 462 |
+
if len(prints) < 2:
|
| 463 |
+
continue
|
| 464 |
+
|
| 465 |
+
# 按时间排序
|
| 466 |
+
prints.sort(key=lambda x: x.timestamp)
|
| 467 |
+
|
| 468 |
+
# 1. 计算步幅 (相邻足印间距)
|
| 469 |
+
stride_lengths = []
|
| 470 |
+
for i in range(1, len(prints)):
|
| 471 |
+
if prints[i].gait_cycle == prints[i-1].gait_cycle:
|
| 472 |
+
dist = np.sqrt((prints[i].x - prints[i-1].x)**2 +
|
| 473 |
+
(prints[i].y - prints[i-1].y)**2)
|
| 474 |
+
stride_lengths.append(dist)
|
| 475 |
+
self.params['stride_length'][paw_type] = np.mean(stride_lengths) if stride_lengths else 0
|
| 476 |
+
|
| 477 |
+
# 2. 计算步频
|
| 478 |
+
if len(prints) > 1:
|
| 479 |
+
time_diff = prints[-1].timestamp - prints[0].timestamp
|
| 480 |
+
self.params['step_frequency'][paw_type] = (len(prints) - 1) / time_diff if time_diff > 0 else 0
|
| 481 |
+
|
| 482 |
+
# 3. 计算步态周期时间
|
| 483 |
+
cycle_times = []
|
| 484 |
+
for i in range(1, len(prints)):
|
| 485 |
+
if prints[i].gait_cycle == prints[i-1].gait_cycle:
|
| 486 |
+
cycle_times.append(prints[i].timestamp - prints[i-1].timestamp)
|
| 487 |
+
self.params['stride_time'][paw_type] = np.mean(cycle_times) if cycle_times else 0
|
| 488 |
+
|
| 489 |
+
# 4. 计算支撑时间和摆动时间
|
| 490 |
+
stance_times = []
|
| 491 |
+
swing_times = []
|
| 492 |
+
for i in range(len(prints)-1):
|
| 493 |
+
if prints[i].gait_cycle == prints[i+1].gait_cycle:
|
| 494 |
+
stance_time = 0.1 # 假设支撑相持续100ms
|
| 495 |
+
swing_time = prints[i+1].timestamp - prints[i].timestamp - stance_time
|
| 496 |
+
stance_times.append(stance_time)
|
| 497 |
+
swing_times.append(swing_time)
|
| 498 |
+
|
| 499 |
+
self.params['stance_time'][paw_type] = np.mean(stance_times) if stance_times else 0
|
| 500 |
+
self.params['swing_time'][paw_type] = np.mean(swing_times) if swing_times else 0
|
| 501 |
+
|
| 502 |
+
# 5. 计算支撑占空比
|
| 503 |
+
total_time = self.params['stance_time'][paw_type] + self.params['swing_time'][paw_type]
|
| 504 |
+
self.params['duty_factor'][paw_type] = (self.params['stance_time'][paw_type] / total_time
|
| 505 |
+
if total_time > 0 else 0)
|
| 506 |
+
|
| 507 |
+
# 6. 计算左右对称性指数
|
| 508 |
+
for side in ['F', 'H']: # 前爪和后爪
|
| 509 |
+
left_paw = f'L{side}'
|
| 510 |
+
right_paw = f'R{side}'
|
| 511 |
+
if (left_paw in self.params['stride_length'] and
|
| 512 |
+
right_paw in self.params['stride_length']):
|
| 513 |
+
left_stride = self.params['stride_length'][left_paw]
|
| 514 |
+
right_stride = self.params['stride_length'][right_paw]
|
| 515 |
+
symmetry = abs(left_stride - right_stride) / ((left_stride + right_stride) / 2)
|
| 516 |
+
self.params['symmetry_index'][side] = symmetry
|
| 517 |
+
|
| 518 |
+
# 7. 计算支撑基底
|
| 519 |
+
for side in ['F', 'H']: # 前爪和后爪
|
| 520 |
+
left_prints = paw_groups[f'L{side}']
|
| 521 |
+
right_prints = paw_groups[f'R{side}']
|
| 522 |
+
if left_prints and right_prints:
|
| 523 |
+
# 计算同一时间窗口内左右爪的横向距离
|
| 524 |
+
base_distances = []
|
| 525 |
+
for left_print in left_prints:
|
| 526 |
+
# 找到最近时间的右爪足印
|
| 527 |
+
nearest_right = min(right_prints,
|
| 528 |
+
key=lambda p: abs(p.timestamp - left_print.timestamp))
|
| 529 |
+
if abs(nearest_right.timestamp - left_print.timestamp) < self.time_window:
|
| 530 |
+
dist = abs(left_print.y - nearest_right.y)
|
| 531 |
+
base_distances.append(dist)
|
| 532 |
+
|
| 533 |
+
self.params['base_of_support'][side] = np.mean(base_distances) if base_distances else 0
|
| 534 |
+
|
| 535 |
+
# 8. 确定步态模式
|
| 536 |
+
self._determine_gait_pattern()
|
| 537 |
+
|
| 538 |
+
def _determine_gait_pattern(self):
|
| 539 |
+
"""确定步态模式(walk, trot, gallop, bound)"""
|
| 540 |
+
if len(self.gait_prints) < 4:
|
| 541 |
+
return
|
| 542 |
+
|
| 543 |
+
# 分析相邻足印的时间关系
|
| 544 |
+
prints_by_time = sorted(self.gait_prints, key=lambda p: p.timestamp)
|
| 545 |
+
time_diffs = []
|
| 546 |
+
for i in range(1, len(prints_by_time)):
|
| 547 |
+
if prints_by_time[i].paw_type and prints_by_time[i-1].paw_type:
|
| 548 |
+
time_diff = prints_by_time[i].timestamp - prints_by_time[i-1].timestamp
|
| 549 |
+
time_diffs.append(time_diff)
|
| 550 |
+
|
| 551 |
+
if not time_diffs:
|
| 552 |
+
return
|
| 553 |
+
|
| 554 |
+
# 计算时间差的变异系数
|
| 555 |
+
mean_diff = np.mean(time_diffs)
|
| 556 |
+
std_diff = np.std(time_diffs)
|
| 557 |
+
cv = std_diff / mean_diff if mean_diff > 0 else 0
|
| 558 |
+
|
| 559 |
+
# 分析对角步态模式
|
| 560 |
+
diagonal_pairs = [('LF', 'RH'), ('RF', 'LH')]
|
| 561 |
+
diagonal_sync = []
|
| 562 |
+
for pair in diagonal_pairs:
|
| 563 |
+
prints1 = [p for p in self.gait_prints if p.paw_type == pair[0]]
|
| 564 |
+
prints2 = [p for p in self.gait_prints if p.paw_type == pair[1]]
|
| 565 |
+
|
| 566 |
+
if prints1 and prints2:
|
| 567 |
+
# 计算对角足印的同步性
|
| 568 |
+
min_time_diff = float('inf')
|
| 569 |
+
for p1 in prints1:
|
| 570 |
+
for p2 in prints2:
|
| 571 |
+
time_diff = abs(p1.timestamp - p2.timestamp)
|
| 572 |
+
min_time_diff = min(min_time_diff, time_diff)
|
| 573 |
+
diagonal_sync.append(min_time_diff)
|
| 574 |
+
|
| 575 |
+
# 基于时间特征确定步态模式
|
| 576 |
+
if cv < 0.2 and all(sync < 0.1 for sync in diagonal_sync):
|
| 577 |
+
self.gait_pattern = 'trot' # 对角步态同步性高
|
| 578 |
+
elif cv > 0.4:
|
| 579 |
+
self.gait_pattern = 'gallop' # 时间间隔变异大
|
| 580 |
+
elif len(set(p.gait_cycle for p in self.gait_prints)) > len(self.gait_prints) / 2:
|
| 581 |
+
self.gait_pattern = 'bound' # 步态周期数量多
|
| 582 |
+
else:
|
| 583 |
+
self.gait_pattern = 'walk' # 默认为行走
|
| 584 |
+
|
| 585 |
+
def process_video(self, video_path: str, start_time=0.0, end_time=None, conf_thres=0.7, iou_thres=0.5):
|
| 586 |
+
"""处理视频并分析步态
|
| 587 |
+
|
| 588 |
+
Args:
|
| 589 |
+
video_path (str): 视频文件路径
|
| 590 |
+
start_time (float): 开始时间(秒),精确到0.01秒
|
| 591 |
+
end_time (float): 结束时间(秒),精确到0.01秒,None表示处理到视频结束
|
| 592 |
+
conf_thres (float): 置信度阈值
|
| 593 |
+
iou_thres (float): IOU阈值
|
| 594 |
+
"""
|
| 595 |
+
print(f"\n[1/6] 开始处理视频: {video_path}")
|
| 596 |
+
|
| 597 |
+
cap = cv2.VideoCapture(video_path)
|
| 598 |
+
if not cap.isOpened():
|
| 599 |
+
raise ValueError(f"无法打开视频文件: {video_path}")
|
| 600 |
+
|
| 601 |
+
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 602 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 603 |
+
video_duration = total_frames / fps
|
| 604 |
+
|
| 605 |
+
# 处理时间范围
|
| 606 |
+
if end_time is None:
|
| 607 |
+
end_time = video_duration
|
| 608 |
+
|
| 609 |
+
# 确保时间范围有效
|
| 610 |
+
if start_time < 0 or start_time >= video_duration:
|
| 611 |
+
raise ValueError(f"起始时间 {start_time:.2f}s 超出视频范围 [0, {video_duration:.2f}s]")
|
| 612 |
+
if end_time <= start_time or end_time > video_duration:
|
| 613 |
+
raise ValueError(f"结束时间 {end_time:.2f}s 无效,应在区间 ({start_time:.2f}, {video_duration:.2f}s]")
|
| 614 |
+
|
| 615 |
+
# 计算帧范围
|
| 616 |
+
start_frame = int(start_time * fps)
|
| 617 |
+
end_frame = int(end_time * fps)
|
| 618 |
+
process_frames = end_frame - start_frame
|
| 619 |
+
print(f"视频信息: 总时长 {video_duration:.2f}s ({total_frames}帧, {fps}FPS)")
|
| 620 |
+
print(f"处理时间段: {start_time:.2f}s - {end_time:.2f}s ({process_frames}帧)")
|
| 621 |
+
|
| 622 |
+
# 跳转到起始帧
|
| 623 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
|
| 624 |
+
|
| 625 |
+
frame_id = start_frame
|
| 626 |
+
print("\n[2/6] 开始检测足迹...")
|
| 627 |
+
while cap.isOpened() and frame_id < end_frame:
|
| 628 |
+
success, frame = cap.read()
|
| 629 |
+
if not success:
|
| 630 |
+
break
|
| 631 |
+
|
| 632 |
+
if (frame_id - start_frame) % 30 == 0: # 每30帧打印一次进度
|
| 633 |
+
progress = (frame_id - start_frame) / process_frames
|
| 634 |
+
print(f"处理进度: {frame_id-start_frame}/{process_frames} ({progress*100:.1f}%)")
|
| 635 |
+
|
| 636 |
+
timestamp = frame_id / fps
|
| 637 |
+
|
| 638 |
+
# 处理pose检测结果
|
| 639 |
+
pose_results = self.pose_model(frame, conf=conf_thres, verbose=False)
|
| 640 |
+
for r in pose_results:
|
| 641 |
+
if r.keypoints is not None and len(r.keypoints.data) > 0:
|
| 642 |
+
kpts = r.keypoints.data[0] # 获取第一个检测到的老鼠的关键点
|
| 643 |
+
if len(kpts) >= 7: # 确保检测到所有关键点
|
| 644 |
+
self.mice_positions.append({
|
| 645 |
+
'frame_id': frame_id,
|
| 646 |
+
'timestamp': timestamp,
|
| 647 |
+
'keypoints': {
|
| 648 |
+
'nose': (kpts[0][0].item(), kpts[0][1].item()),
|
| 649 |
+
'right_ear': (kpts[1][0].item(), kpts[1][1].item()),
|
| 650 |
+
'left_ear': (kpts[2][0].item(), kpts[2][1].item()),
|
| 651 |
+
'mid': (kpts[3][0].item(), kpts[3][1].item()),
|
| 652 |
+
'right_leg': (kpts[4][0].item(), kpts[4][1].item()),
|
| 653 |
+
'left_leg': (kpts[5][0].item(), kpts[5][1].item()),
|
| 654 |
+
'tail_base': (kpts[6][0].item(), kpts[6][1].item())
|
| 655 |
+
},
|
| 656 |
+
'x': kpts[3][0].item(), # 使用中点作为老鼠位置
|
| 657 |
+
'y': kpts[3][1].item()
|
| 658 |
+
})
|
| 659 |
+
|
| 660 |
+
# 处理gait检测结果
|
| 661 |
+
gait_results = self.gait_model(frame, conf=conf_thres, iou=iou_thres, verbose=False)
|
| 662 |
+
|
| 663 |
+
# 处理检测结果
|
| 664 |
+
for r in gait_results:
|
| 665 |
+
boxes = r.boxes
|
| 666 |
+
for box in boxes:
|
| 667 |
+
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
| 668 |
+
conf = float(box.conf[0])
|
| 669 |
+
cls = int(box.cls[0])
|
| 670 |
+
|
| 671 |
+
if cls == 0: # gait
|
| 672 |
+
# 提取足印图像块
|
| 673 |
+
patch = frame[y1:y2, x1:x2].copy()
|
| 674 |
+
# 只保留绿色通道
|
| 675 |
+
green_mask = self._extract_green_channel(patch)
|
| 676 |
+
# 提取图像特征
|
| 677 |
+
image_features = self._extract_image_features(green_mask)
|
| 678 |
+
|
| 679 |
+
self.gait_prints.append(GaitPrint(
|
| 680 |
+
frame_id=frame_id,
|
| 681 |
+
x=(x1 + x2) / 2,
|
| 682 |
+
y=(y1 + y2) / 2,
|
| 683 |
+
w=x2 - x1,
|
| 684 |
+
h=y2 - y1,
|
| 685 |
+
conf=conf,
|
| 686 |
+
timestamp=timestamp,
|
| 687 |
+
image_patch=green_mask,
|
| 688 |
+
image_features=image_features
|
| 689 |
+
|
| 690 |
+
))
|
| 691 |
+
else: # mice类
|
| 692 |
+
self.mice_positions.append({
|
| 693 |
+
'frame_id': frame_id,
|
| 694 |
+
'x': (x1 + x2) / 2,
|
| 695 |
+
'y': (y1 + y2) / 2,
|
| 696 |
+
'w': x2 - x1,
|
| 697 |
+
'h': y2 - y1,
|
| 698 |
+
'conf': conf,
|
| 699 |
+
'timestamp': timestamp
|
| 700 |
+
})
|
| 701 |
+
|
| 702 |
+
frame_id += 1
|
| 703 |
+
|
| 704 |
+
cap.release()
|
| 705 |
+
print(f"检测完成! 共检测到 {len(self.gait_prints)} 个足迹")
|
| 706 |
+
print("\n[3/6] 开始后处理...")
|
| 707 |
+
self._post_process_footprints()
|
| 708 |
+
print(f"后处理完成! 剩余 {len(self.gait_prints)} 个有效足迹")
|
| 709 |
+
|
| 710 |
+
print("\n[4/6] 计算步态参数...")
|
| 711 |
+
self._calculate_gait_parameters()
|
| 712 |
+
print("参数计算完成!")
|
| 713 |
+
|
| 714 |
+
print("\n[5/6] 保存分析结果...")
|
| 715 |
+
self._save_results()
|
| 716 |
+
self.visualize_results()
|
| 717 |
+
print("分析结果已保存!")
|
| 718 |
+
|
| 719 |
+
print("\n[6/6] 生成轨迹视频...")
|
| 720 |
+
self.generate_trajectory_video(video_path)
|
| 721 |
+
print("视频生成完成!")
|
| 722 |
+
|
| 723 |
+
def _get_paw_color(self, paw_type: str) -> Tuple[int, int, int]:
|
| 724 |
+
"""获取不同爪子类型的颜色"""
|
| 725 |
+
colors = {
|
| 726 |
+
'LF': (0, 255, 0), # 绿色
|
| 727 |
+
'RF': (0, 255, 255), # 黄色
|
| 728 |
+
'LH': (255, 0, 255), # 紫色
|
| 729 |
+
'RH': (0, 165, 255), # 橙色
|
| 730 |
+
None: (0, 255, 0) # 未分类时为绿色
|
| 731 |
+
}
|
| 732 |
+
return colors.get(paw_type, (0, 255, 0))
|
| 733 |
+
|
| 734 |
+
def _save_results(self):
|
| 735 |
+
"""保存分析结果"""
|
| 736 |
+
# 保存原有的CSV数据
|
| 737 |
+
df = pd.DataFrame([{
|
| 738 |
+
'frame_id': p.frame_id,
|
| 739 |
+
'x': p.x,
|
| 740 |
+
'y': p.y,
|
| 741 |
+
'width': p.w,
|
| 742 |
+
'height': p.h,
|
| 743 |
+
'confidence': p.conf,
|
| 744 |
+
'paw_type': p.paw_type,
|
| 745 |
+
'timestamp': p.timestamp,
|
| 746 |
+
'gait_cycle': p.gait_cycle,
|
| 747 |
+
'stance_phase': p.stance_phase
|
| 748 |
+
} for p in self.gait_prints])
|
| 749 |
+
|
| 750 |
+
df.to_csv(f'{self.result_dir}/data/gait_analysis.csv', index=False)
|
| 751 |
+
|
| 752 |
+
# 保存参数数据
|
| 753 |
+
params_df = pd.DataFrame(self.params)
|
| 754 |
+
params_df.to_csv(f'{self.result_dir}/data/gait_parameters.csv')
|
| 755 |
+
|
| 756 |
+
# 保存JSON格式的足印数据
|
| 757 |
+
self._save_footprint_json()
|
| 758 |
+
|
| 759 |
+
def visualize_results(self):
|
| 760 |
+
"""可视化分析结果"""
|
| 761 |
+
# 创建保存路径
|
| 762 |
+
plots_dir = f'{self.result_dir}/plots'
|
| 763 |
+
|
| 764 |
+
# 1. 足印轨迹图
|
| 765 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
| 766 |
+
self._plot_footprint_trajectory(ax)
|
| 767 |
+
plt.savefig(f'{plots_dir}/trajectory.png', dpi=300, bbox_inches='tight')
|
| 768 |
+
plt.close()
|
| 769 |
+
|
| 770 |
+
# 2. 步态时序图
|
| 771 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 772 |
+
self._plot_gait_timeline(ax)
|
| 773 |
+
plt.savefig(f'{plots_dir}/timeline.png', dpi=300, bbox_inches='tight')
|
| 774 |
+
plt.close()
|
| 775 |
+
|
| 776 |
+
# 3. 步态参数图
|
| 777 |
+
# 3.1 步幅
|
| 778 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 779 |
+
self._plot_stride_length(ax)
|
| 780 |
+
plt.savefig(f'{plots_dir}/stride_length.png', dpi=300, bbox_inches='tight')
|
| 781 |
+
plt.close()
|
| 782 |
+
|
| 783 |
+
# 3.2 步频
|
| 784 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 785 |
+
self._plot_step_frequency(ax)
|
| 786 |
+
plt.savefig(f'{plots_dir}/step_frequency.png', dpi=300, bbox_inches='tight')
|
| 787 |
+
plt.close()
|
| 788 |
+
|
| 789 |
+
# 3.3 支撑和摆动时间
|
| 790 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 791 |
+
self._plot_stance_swing_time(ax)
|
| 792 |
+
plt.savefig(f'{plots_dir}/stance_swing_time.png', dpi=300, bbox_inches='tight')
|
| 793 |
+
plt.close()
|
| 794 |
+
|
| 795 |
+
# 3.4 支撑占空比
|
| 796 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 797 |
+
self._plot_duty_factor(ax)
|
| 798 |
+
plt.savefig(f'{plots_dir}/duty_factor.png', dpi=300, bbox_inches='tight')
|
| 799 |
+
plt.close()
|
| 800 |
+
|
| 801 |
+
# 3.5 对称性指数
|
| 802 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 803 |
+
self._plot_symmetry_index(ax)
|
| 804 |
+
plt.savefig(f'{plots_dir}/symmetry_index.png', dpi=300, bbox_inches='tight')
|
| 805 |
+
plt.close()
|
| 806 |
+
|
| 807 |
+
# 3.6 支撑基底
|
| 808 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 809 |
+
self._plot_base_of_support(ax)
|
| 810 |
+
plt.savefig(f'{plots_dir}/base_of_support.png', dpi=300, bbox_inches='tight')
|
| 811 |
+
plt.close()
|
| 812 |
+
|
| 813 |
+
# 4. 步态模式热图
|
| 814 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 815 |
+
self._plot_gait_pattern(ax)
|
| 816 |
+
plt.savefig(f'{plots_dir}/gait_pattern.png', dpi=300, bbox_inches='tight')
|
| 817 |
+
plt.close()
|
| 818 |
+
|
| 819 |
+
def generate_trajectory_video(self, video_path: str):
|
| 820 |
+
"""生成包含足迹轨迹的视频"""
|
| 821 |
+
try:
|
| 822 |
+
# 准备输出路径
|
| 823 |
+
video_name = os.path.splitext(os.path.basename(video_path))[0]
|
| 824 |
+
yolo_path = f"{self.result_dir}/videos/{video_name}_yolo.mp4"
|
| 825 |
+
trajectory_path = f"{self.result_dir}/videos/{video_name}_trajectory.mp4"
|
| 826 |
+
combined_path = f"{self.result_dir}/videos/{video_name}_combined.mp4"
|
| 827 |
+
|
| 828 |
+
# 打开视频
|
| 829 |
+
cap = cv2.VideoCapture(video_path)
|
| 830 |
+
if not cap.isOpened():
|
| 831 |
+
raise ValueError("无法打开视频文件")
|
| 832 |
+
|
| 833 |
+
# 获取视频信息
|
| 834 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 835 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 836 |
+
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 837 |
+
frame_duration = 1.0 / fps
|
| 838 |
+
|
| 839 |
+
# 确定有效时间范围
|
| 840 |
+
valid_times = [p.timestamp for p in self.gait_prints]
|
| 841 |
+
start_time = min(valid_times) if valid_times else 0
|
| 842 |
+
end_time = max(valid_times) if valid_times else 0
|
| 843 |
+
|
| 844 |
+
# 跳转到起始帧
|
| 845 |
+
start_frame = int(start_time * fps)
|
| 846 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
|
| 847 |
+
|
| 848 |
+
# 创建视频写入器
|
| 849 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 850 |
+
yolo_writer = cv2.VideoWriter(yolo_path, fourcc, fps, (width, height))
|
| 851 |
+
traj_writer = cv2.VideoWriter(trajectory_path, fourcc, fps, (width, height))
|
| 852 |
+
combined_writer = cv2.VideoWriter(combined_path, fourcc, fps, (width, height*2))
|
| 853 |
+
|
| 854 |
+
# 准备轨迹图像
|
| 855 |
+
trajectory_img = np.ones((height, width, 3), dtype=np.uint8) * 255
|
| 856 |
+
|
| 857 |
+
# 按时间排序足迹和老鼠位置
|
| 858 |
+
sorted_prints = sorted(self.gait_prints, key=lambda x: x.timestamp)
|
| 859 |
+
sorted_mice = sorted(self.mice_positions, key=lambda x: x['timestamp'])
|
| 860 |
+
current_time = start_time
|
| 861 |
+
|
| 862 |
+
while cap.isOpened() and current_time <= end_time:
|
| 863 |
+
ret, frame = cap.read()
|
| 864 |
+
if not ret:
|
| 865 |
+
break
|
| 866 |
+
|
| 867 |
+
# 1. 创建YOLO检测帧
|
| 868 |
+
yolo_frame = frame.copy()
|
| 869 |
+
|
| 870 |
+
# 2. 创建轨迹帧
|
| 871 |
+
traj_frame = trajectory_img.copy()
|
| 872 |
+
|
| 873 |
+
# 绘制当前时间的老鼠位置
|
| 874 |
+
current_mouse = next((m for m in sorted_mice if abs(m['timestamp'] - current_time) < frame_duration), None)
|
| 875 |
+
if current_mouse and 'keypoints' in current_mouse: # 确保有关键点数据
|
| 876 |
+
keypoints = current_mouse['keypoints']
|
| 877 |
+
nose = keypoints['nose']
|
| 878 |
+
tail = keypoints['tail_base']
|
| 879 |
+
|
| 880 |
+
# 计算中点
|
| 881 |
+
center_x = int((nose[0] + tail[0]) / 2)
|
| 882 |
+
center_y = int((nose[1] + tail[1]) / 2)
|
| 883 |
+
|
| 884 |
+
# 绘制鼻子到尾巴的虚线箭头
|
| 885 |
+
# 计算箭头方向
|
| 886 |
+
dx = nose[0] - tail[0]
|
| 887 |
+
dy = nose[1] - tail[1]
|
| 888 |
+
angle = np.arctan2(dy, dx)
|
| 889 |
+
|
| 890 |
+
# 绘制虚线
|
| 891 |
+
pt1 = (int(tail[0]), int(tail[1]))
|
| 892 |
+
pt2 = (int(nose[0]), int(nose[1]))
|
| 893 |
+
|
| 894 |
+
# 使用虚线绘制主线
|
| 895 |
+
dash_length = 5
|
| 896 |
+
total_length = np.sqrt(dx**2 + dy**2)
|
| 897 |
+
num_segments = int(total_length / (dash_length * 2))
|
| 898 |
+
|
| 899 |
+
for i in range(num_segments):
|
| 900 |
+
start_ratio = i / num_segments
|
| 901 |
+
end_ratio = (i + 0.5) / num_segments
|
| 902 |
+
start_x = int(tail[0] + dx * start_ratio)
|
| 903 |
+
start_y = int(tail[1] + dy * start_ratio)
|
| 904 |
+
end_x = int(tail[0] + dx * end_ratio)
|
| 905 |
+
end_y = int(tail[1] + dy * end_ratio)
|
| 906 |
+
cv2.line(yolo_frame, (start_x, start_y), (end_x, end_y), (255, 0, 0), 1)
|
| 907 |
+
cv2.line(traj_frame, (start_x, start_y), (end_x, end_y), (255, 0, 0), 1)
|
| 908 |
+
|
| 909 |
+
# 绘制箭头
|
| 910 |
+
arrow_length = 15
|
| 911 |
+
arrow_angle = np.pi/6 # 30度
|
| 912 |
+
|
| 913 |
+
# 计算箭头两边的点
|
| 914 |
+
arrow1_x = int(nose[0] - arrow_length * np.cos(angle + arrow_angle))
|
| 915 |
+
arrow1_y = int(nose[1] - arrow_length * np.sin(angle + arrow_angle))
|
| 916 |
+
arrow2_x = int(nose[0] - arrow_length * np.cos(angle - arrow_angle))
|
| 917 |
+
arrow2_y = int(nose[1] - arrow_length * np.sin(angle - arrow_angle))
|
| 918 |
+
|
| 919 |
+
# 绘制箭头
|
| 920 |
+
cv2.line(yolo_frame, (int(nose[0]), int(nose[1])), (arrow1_x, arrow1_y), (255, 0, 0), 1)
|
| 921 |
+
cv2.line(yolo_frame, (int(nose[0]), int(nose[1])), (arrow2_x, arrow2_y), (255, 0, 0), 1)
|
| 922 |
+
cv2.line(traj_frame, (int(nose[0]), int(nose[1])), (arrow1_x, arrow1_y), (255, 0, 0), 1)
|
| 923 |
+
cv2.line(traj_frame, (int(nose[0]), int(nose[1])), (arrow2_x, arrow2_y), (255, 0, 0), 1)
|
| 924 |
+
|
| 925 |
+
# 绘制中点
|
| 926 |
+
cv2.circle(yolo_frame, (center_x, center_y), 3, (255, 0, 0), -1)
|
| 927 |
+
cv2.circle(traj_frame, (center_x, center_y), 3, (255, 0, 0), -1)
|
| 928 |
+
|
| 929 |
+
# 绘制当前时间之前的所有足迹
|
| 930 |
+
for foot_print in sorted_prints:
|
| 931 |
+
if foot_print.timestamp <= current_time:
|
| 932 |
+
color = self._get_paw_color(foot_print.paw_type)
|
| 933 |
+
x1 = int(foot_print.x - foot_print.w/2)
|
| 934 |
+
y1 = int(foot_print.y - foot_print.h/2)
|
| 935 |
+
x2 = int(foot_print.x + foot_print.w/2)
|
| 936 |
+
y2 = int(foot_print.y + foot_print.h/2)
|
| 937 |
+
|
| 938 |
+
# 在YOLO帧上绘制检测框和标签
|
| 939 |
+
cv2.rectangle(yolo_frame, (x1, y1), (x2, y2), color, 2)
|
| 940 |
+
if foot_print.paw_type:
|
| 941 |
+
cv2.putText(yolo_frame, foot_print.paw_type, (x1, y1-5),
|
| 942 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
| 943 |
+
|
| 944 |
+
# 在轨迹帧上绘制足迹点
|
| 945 |
+
cv2.circle(traj_frame, (int(foot_print.x), int(foot_print.y)), 3, color, -1)
|
| 946 |
+
|
| 947 |
+
# 添加时间戳
|
| 948 |
+
cv2.putText(yolo_frame, f"Time: {current_time:.2f}s", (10, 30),
|
| 949 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
|
| 950 |
+
cv2.putText(traj_frame, f"Time: {current_time:.2f}s", (10, 30),
|
| 951 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
|
| 952 |
+
|
| 953 |
+
# 合并视频
|
| 954 |
+
combined_frame = np.vstack((yolo_frame, traj_frame))
|
| 955 |
+
|
| 956 |
+
# 写入帧
|
| 957 |
+
yolo_writer.write(yolo_frame)
|
| 958 |
+
traj_writer.write(traj_frame)
|
| 959 |
+
combined_writer.write(combined_frame)
|
| 960 |
+
|
| 961 |
+
current_time += frame_duration
|
| 962 |
+
|
| 963 |
+
# 释放资源
|
| 964 |
+
cap.release()
|
| 965 |
+
yolo_writer.release()
|
| 966 |
+
traj_writer.release()
|
| 967 |
+
combined_writer.release()
|
| 968 |
+
|
| 969 |
+
print(f"\nYOLO视频生成完成: {yolo_path}")
|
| 970 |
+
print(f"轨迹视频生成完成: {trajectory_path}")
|
| 971 |
+
print(f"并列视频生成完成: {combined_path}")
|
| 972 |
+
|
| 973 |
+
except Exception as e:
|
| 974 |
+
print(f"生成视频时出错: {str(e)}")
|
| 975 |
+
raise
|
| 976 |
+
|
| 977 |
+
def _plot_footprint_trajectory(self, ax):
|
| 978 |
+
"""绘制足印轨迹"""
|
| 979 |
+
ax.set_title('Footprint Trajectory')
|
| 980 |
+
|
| 981 |
+
# 为每种足印类型绘制散点图
|
| 982 |
+
colors = self._get_paw_color_dict()
|
| 983 |
+
for paw_type in ['LF', 'RF', 'LH', 'RH']:
|
| 984 |
+
prints = [p for p in self.gait_prints if p.paw_type == paw_type]
|
| 985 |
+
if prints:
|
| 986 |
+
x = [p.x for p in prints]
|
| 987 |
+
y = [p.y for p in prints]
|
| 988 |
+
ax.scatter(x, y, c=[colors[paw_type]], label=paw_type, alpha=0.6)
|
| 989 |
+
|
| 990 |
+
# 绘制老鼠运动轨迹,使用鼻子到尾巴的中点
|
| 991 |
+
mouse_positions = []
|
| 992 |
+
for pos in self.mice_positions:
|
| 993 |
+
if 'keypoints' in pos:
|
| 994 |
+
nose = pos['keypoints']['nose']
|
| 995 |
+
tail = pos['keypoints']['tail_base']
|
| 996 |
+
mid_x = (nose[0] + tail[0]) / 2
|
| 997 |
+
mid_y = (nose[1] + tail[1]) / 2
|
| 998 |
+
mouse_positions.append((mid_x, mid_y))
|
| 999 |
+
|
| 1000 |
+
if mouse_positions:
|
| 1001 |
+
mouse_x, mouse_y = zip(*mouse_positions)
|
| 1002 |
+
ax.plot(mouse_x, mouse_y, 'b--', alpha=0.1, label='Mouse path')
|
| 1003 |
+
# 在最后一个位置绘制一个点
|
| 1004 |
+
ax.scatter(mouse_x[-1], mouse_y[-1], c='blue', s=50, alpha=0.6)
|
| 1005 |
+
|
| 1006 |
+
ax.legend()
|
| 1007 |
+
ax.set_xlabel('X position (pixels)')
|
| 1008 |
+
ax.set_ylabel('Y position (pixels)')
|
| 1009 |
+
ax.invert_yaxis() # 图像坐标系y轴向下
|
| 1010 |
+
|
| 1011 |
+
def _plot_gait_timeline(self, ax):
|
| 1012 |
+
"""绘制步态时序图"""
|
| 1013 |
+
ax.set_title('Gait Timeline')
|
| 1014 |
+
|
| 1015 |
+
colors = self._get_paw_color_dict()
|
| 1016 |
+
y_positions = {'LF': 4, 'RF': 3, 'LH': 2, 'RH': 1}
|
| 1017 |
+
|
| 1018 |
+
for paw_type in ['LF', 'RF', 'LH', 'RH']:
|
| 1019 |
+
prints = [p for p in self.gait_prints if p.paw_type == paw_type]
|
| 1020 |
+
times = [p.timestamp for p in prints]
|
| 1021 |
+
y = [y_positions[paw_type]] * len(times)
|
| 1022 |
+
ax.scatter(times, y, c=[colors[paw_type]], label=paw_type, marker='s')
|
| 1023 |
+
|
| 1024 |
+
ax.set_yticks(list(y_positions.values()))
|
| 1025 |
+
ax.set_yticklabels(list(y_positions.keys()))
|
| 1026 |
+
ax.set_xlabel('Time (seconds)')
|
| 1027 |
+
ax.grid(True, axis='x', alpha=0.3)
|
| 1028 |
+
|
| 1029 |
+
def _plot_gait_parameters(self, ax):
|
| 1030 |
+
"""绘制步态参数柱状图"""
|
| 1031 |
+
ax.set_title('Gait Parameters')
|
| 1032 |
+
|
| 1033 |
+
# 提取步幅和步频数据
|
| 1034 |
+
paw_types = list(self.params['stride_length'].keys())
|
| 1035 |
+
stride_lengths = [self.params['stride_length'][p] for p in paw_types]
|
| 1036 |
+
frequencies = [self.params['step_frequency'][p] for p in paw_types]
|
| 1037 |
+
|
| 1038 |
+
x = np.arange(len(paw_types))
|
| 1039 |
+
width = 0.35
|
| 1040 |
+
|
| 1041 |
+
ax.bar(x - width/2, stride_lengths, width, label='Stride Length')
|
| 1042 |
+
ax.bar(x + width/2, frequencies, width, label='Step Frequency')
|
| 1043 |
+
|
| 1044 |
+
ax.set_xticks(x)
|
| 1045 |
+
ax.set_xticklabels(paw_types)
|
| 1046 |
+
ax.legend()
|
| 1047 |
+
ax.set_ylabel('Value')
|
| 1048 |
+
|
| 1049 |
+
if self.gait_pattern:
|
| 1050 |
+
ax.text(0.02, 0.98, f'Gait Pattern: {self.gait_pattern}',
|
| 1051 |
+
transform=ax.transAxes, verticalalignment='top')
|
| 1052 |
+
|
| 1053 |
+
def _plot_gait_pattern(self, ax):
|
| 1054 |
+
"""绘制步态模式热图"""
|
| 1055 |
+
ax.set_title('Gait Pattern')
|
| 1056 |
+
|
| 1057 |
+
# 创建时间窗口内的步态模式矩阵
|
| 1058 |
+
time_bins = np.linspace(0, max(p.timestamp for p in self.gait_prints), 20)
|
| 1059 |
+
paw_types = ['LF', 'RF', 'LH', 'RH']
|
| 1060 |
+
pattern_matrix = np.zeros((len(paw_types), len(time_bins)-1))
|
| 1061 |
+
|
| 1062 |
+
for i, paw_type in enumerate(paw_types):
|
| 1063 |
+
prints = [p for p in self.gait_prints if p.paw_type == paw_type]
|
| 1064 |
+
for p in prints:
|
| 1065 |
+
bin_idx = np.digitize(p.timestamp, time_bins) - 1
|
| 1066 |
+
if 0 <= bin_idx < len(time_bins)-1:
|
| 1067 |
+
pattern_matrix[i, bin_idx] += 1
|
| 1068 |
+
|
| 1069 |
+
sns.heatmap(pattern_matrix, ax=ax, cmap='YlOrRd',
|
| 1070 |
+
xticklabels=np.round(time_bins[:-1], 1),
|
| 1071 |
+
yticklabels=paw_types)
|
| 1072 |
+
ax.set_xlabel('Time (seconds)')
|
| 1073 |
+
|
| 1074 |
+
def _plot_stride_length(self, ax):
|
| 1075 |
+
"""绘制步幅图"""
|
| 1076 |
+
ax.set_title('Stride Length')
|
| 1077 |
+
paw_types = list(self.params['stride_length'].keys())
|
| 1078 |
+
values = [self.params['stride_length'][p] for p in paw_types]
|
| 1079 |
+
colors = [self._get_paw_color_dict()[p] for p in paw_types]
|
| 1080 |
+
|
| 1081 |
+
bars = ax.bar(paw_types, values)
|
| 1082 |
+
for bar, color in zip(bars, colors):
|
| 1083 |
+
bar.set_color(color)
|
| 1084 |
+
|
| 1085 |
+
ax.set_ylabel('Length (pixels)')
|
| 1086 |
+
ax.grid(True, alpha=0.3)
|
| 1087 |
+
|
| 1088 |
+
def _plot_step_frequency(self, ax):
|
| 1089 |
+
"""绘制步频图"""
|
| 1090 |
+
ax.set_title('Step Frequency')
|
| 1091 |
+
paw_types = list(self.params['step_frequency'].keys())
|
| 1092 |
+
values = [self.params['step_frequency'][p] for p in paw_types]
|
| 1093 |
+
colors = [self._get_paw_color_dict()[p] for p in paw_types]
|
| 1094 |
+
|
| 1095 |
+
bars = ax.bar(paw_types, values)
|
| 1096 |
+
for bar, color in zip(bars, colors):
|
| 1097 |
+
bar.set_color(color)
|
| 1098 |
+
|
| 1099 |
+
ax.set_ylabel('Frequency (steps/second)')
|
| 1100 |
+
ax.grid(True, alpha=0.3)
|
| 1101 |
+
|
| 1102 |
+
def _plot_stance_swing_time(self, ax):
|
| 1103 |
+
"""绘制支撑和摆动时间图"""
|
| 1104 |
+
ax.set_title('Stance and Swing Time')
|
| 1105 |
+
paw_types = list(self.params['stance_time'].keys())
|
| 1106 |
+
stance_times = [self.params['stance_time'][p] for p in paw_types]
|
| 1107 |
+
swing_times = [self.params['swing_time'][p] for p in paw_types]
|
| 1108 |
+
|
| 1109 |
+
x = np.arange(len(paw_types))
|
| 1110 |
+
width = 0.35
|
| 1111 |
+
|
| 1112 |
+
ax.bar(x - width/2, stance_times, width, label='Stance Time')
|
| 1113 |
+
ax.bar(x + width/2, swing_times, width, label='Swing Time')
|
| 1114 |
+
|
| 1115 |
+
ax.set_xticks(x)
|
| 1116 |
+
ax.set_xticklabels(paw_types)
|
| 1117 |
+
ax.set_ylabel('Time (seconds)')
|
| 1118 |
+
ax.legend()
|
| 1119 |
+
ax.grid(True, alpha=0.3)
|
| 1120 |
+
|
| 1121 |
+
def _plot_duty_factor(self, ax):
|
| 1122 |
+
"""绘制支撑占空比图"""
|
| 1123 |
+
ax.set_title('Duty Factor')
|
| 1124 |
+
paw_types = list(self.params['duty_factor'].keys())
|
| 1125 |
+
values = [self.params['duty_factor'][p] * 100 for p in paw_types] # 转换为百分比
|
| 1126 |
+
colors = [self._get_paw_color_dict()[p] for p in paw_types]
|
| 1127 |
+
|
| 1128 |
+
bars = ax.bar(paw_types, values)
|
| 1129 |
+
for bar, color in zip(bars, colors):
|
| 1130 |
+
bar.set_color(color)
|
| 1131 |
+
|
| 1132 |
+
ax.set_ylabel('Duty Factor (%)')
|
| 1133 |
+
ax.grid(True, alpha=0.3)
|
| 1134 |
+
|
| 1135 |
+
def _plot_symmetry_index(self, ax):
|
| 1136 |
+
"""绘制对称性指数图"""
|
| 1137 |
+
ax.set_title('Symmetry Index')
|
| 1138 |
+
sides = list(self.params['symmetry_index'].keys())
|
| 1139 |
+
values = [self.params['symmetry_index'][s] * 100 for s in sides] # 转换为百分比
|
| 1140 |
+
|
| 1141 |
+
bars = ax.bar(sides, values)
|
| 1142 |
+
ax.set_ylabel('Symmetry Index (%)')
|
| 1143 |
+
ax.grid(True, alpha=0.3)
|
| 1144 |
+
|
| 1145 |
+
def _plot_base_of_support(self, ax):
|
| 1146 |
+
"""绘制支撑基底图"""
|
| 1147 |
+
ax.set_title('Base of Support')
|
| 1148 |
+
sides = list(self.params['base_of_support'].keys())
|
| 1149 |
+
values = [self.params['base_of_support'][s] for s in sides]
|
| 1150 |
+
|
| 1151 |
+
bars = ax.bar(sides, values)
|
| 1152 |
+
ax.set_ylabel('Width (pixels)')
|
| 1153 |
+
ax.grid(True, alpha=0.3)
|
| 1154 |
+
|
| 1155 |
+
def _get_paw_color_dict(self):
|
| 1156 |
+
"""获取足印颜色字典(matplotlib格式)"""
|
| 1157 |
+
return {
|
| 1158 |
+
'LF': 'green',
|
| 1159 |
+
'RF': 'yellow',
|
| 1160 |
+
'LH': 'purple',
|
| 1161 |
+
'RH': 'orange'
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
def _extract_green_channel(self, patch: np.ndarray) -> np.ndarray:
|
| 1165 |
+
"""提取并增强绿色通道"""
|
| 1166 |
+
# 转换为HSV颜色空间
|
| 1167 |
+
hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
|
| 1168 |
+
|
| 1169 |
+
# 定义绿色的HSV范围
|
| 1170 |
+
lower_green = np.array([40, 40, 40])
|
| 1171 |
+
upper_green = np.array([80, 255, 255])
|
| 1172 |
+
|
| 1173 |
+
# 创建掩码
|
| 1174 |
+
mask = cv2.inRange(hsv, lower_green, upper_green)
|
| 1175 |
+
|
| 1176 |
+
# 应用掩码
|
| 1177 |
+
green_only = cv2.bitwise_and(patch, patch, mask=mask)
|
| 1178 |
+
|
| 1179 |
+
# 转换为灰度图
|
| 1180 |
+
gray = cv2.cvtColor(green_only, cv2.COLOR_BGR2GRAY)
|
| 1181 |
+
|
| 1182 |
+
# 标准化尺寸
|
| 1183 |
+
resized = cv2.resize(gray, (32, 32))
|
| 1184 |
+
|
| 1185 |
+
return resized
|
| 1186 |
+
|
| 1187 |
+
def _extract_image_features(self, green_mask: np.ndarray) -> np.ndarray:
|
| 1188 |
+
"""从足印图像提取特征"""
|
| 1189 |
+
# 1. 计算形状特征
|
| 1190 |
+
contours, _ = cv2.findContours(green_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 1191 |
+
if not contours:
|
| 1192 |
+
return np.zeros(10) # 返回默认特征向量
|
| 1193 |
+
|
| 1194 |
+
largest_contour = max(contours, key=cv2.contourArea)
|
| 1195 |
+
|
| 1196 |
+
# 面积
|
| 1197 |
+
area = cv2.contourArea(largest_contour)
|
| 1198 |
+
# 周长
|
| 1199 |
+
perimeter = cv2.arcLength(largest_contour, True)
|
| 1200 |
+
# 圆度
|
| 1201 |
+
circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0
|
| 1202 |
+
|
| 1203 |
+
# 2. 计算方向
|
| 1204 |
+
if len(largest_contour) >= 5:
|
| 1205 |
+
(x, y), (MA, ma), angle = cv2.fitEllipse(largest_contour)
|
| 1206 |
+
else:
|
| 1207 |
+
MA, ma, angle = 0, 0, 0
|
| 1208 |
+
|
| 1209 |
+
# 3. 计算Hu矩
|
| 1210 |
+
moments = cv2.moments(largest_contour)
|
| 1211 |
+
hu_moments = cv2.HuMoments(moments).flatten()
|
| 1212 |
+
|
| 1213 |
+
# 4. 组合特征
|
| 1214 |
+
features = np.array([
|
| 1215 |
+
area,
|
| 1216 |
+
perimeter,
|
| 1217 |
+
circularity,
|
| 1218 |
+
MA/ma if ma > 0 else 0, # 长宽比
|
| 1219 |
+
angle,
|
| 1220 |
+
*hu_moments[:5] # 取前5个Hu矩
|
| 1221 |
+
])
|
| 1222 |
+
|
| 1223 |
+
return features
|
| 1224 |
+
|
| 1225 |
+
def _generate_footprint_data(self) -> dict:
|
| 1226 |
+
"""生成足印数据的JSON格式"""
|
| 1227 |
+
# 初始化数据结构
|
| 1228 |
+
data = {
|
| 1229 |
+
"frames": [],
|
| 1230 |
+
"footprintArea": []
|
| 1231 |
+
}
|
| 1232 |
+
|
| 1233 |
+
# 按帧ID组织足印数据
|
| 1234 |
+
frame_prints = {}
|
| 1235 |
+
for print in self.gait_prints:
|
| 1236 |
+
if print.frame_id not in frame_prints:
|
| 1237 |
+
frame_prints[print.frame_id] = []
|
| 1238 |
+
frame_prints[print.frame_id].append(print)
|
| 1239 |
+
|
| 1240 |
+
# 生成footprintArea的ID映射
|
| 1241 |
+
area_id_map = {} # cluster_id -> footprintAreaId
|
| 1242 |
+
for print in self.gait_prints:
|
| 1243 |
+
if print.cluster_id not in area_id_map:
|
| 1244 |
+
area_id_map[print.cluster_id] = f"footprintArea_{print.cluster_id}"
|
| 1245 |
+
|
| 1246 |
+
# 转换爪子类型
|
| 1247 |
+
type_map = {
|
| 1248 |
+
'LF': 'leftFront',
|
| 1249 |
+
'RF': 'rightFront',
|
| 1250 |
+
'LH': 'leftHind',
|
| 1251 |
+
'RH': 'rightHind'
|
| 1252 |
+
}
|
| 1253 |
+
|
| 1254 |
+
# 生成frames数据
|
| 1255 |
+
for frame_id, prints in frame_prints.items():
|
| 1256 |
+
frame_data = {
|
| 1257 |
+
"frameId": frame_id,
|
| 1258 |
+
"footprints": []
|
| 1259 |
+
}
|
| 1260 |
+
|
| 1261 |
+
for i, print in enumerate(prints):
|
| 1262 |
+
# 转换中心点坐标为左上角坐标
|
| 1263 |
+
x = int(print.x - print.w/2)
|
| 1264 |
+
y = int(print.y - print.h/2)
|
| 1265 |
+
w = int(print.w)
|
| 1266 |
+
h = int(print.h)
|
| 1267 |
+
|
| 1268 |
+
footprint_data = {
|
| 1269 |
+
"id": f"footprint_{frame_id}_{i}",
|
| 1270 |
+
"type": type_map.get(print.paw_type, 'unknown'),
|
| 1271 |
+
"isKeyFootprint": False,
|
| 1272 |
+
"position": {
|
| 1273 |
+
"x": x,
|
| 1274 |
+
"y": y,
|
| 1275 |
+
"width": w,
|
| 1276 |
+
"height": h
|
| 1277 |
+
},
|
| 1278 |
+
"footprintAreaId": area_id_map[print.cluster_id]
|
| 1279 |
+
}
|
| 1280 |
+
frame_data["footprints"].append(footprint_data)
|
| 1281 |
+
|
| 1282 |
+
data["frames"].append(frame_data)
|
| 1283 |
+
|
| 1284 |
+
# 生成footprintArea数据
|
| 1285 |
+
cluster_groups = {}
|
| 1286 |
+
for print in self.gait_prints:
|
| 1287 |
+
if print.cluster_id not in cluster_groups:
|
| 1288 |
+
cluster_groups[print.cluster_id] = []
|
| 1289 |
+
cluster_groups[print.cluster_id].append(print)
|
| 1290 |
+
|
| 1291 |
+
for cluster_id, prints in cluster_groups.items():
|
| 1292 |
+
# 计算整个簇的边界框(考虑每个足印的完整区域)
|
| 1293 |
+
x_coords = []
|
| 1294 |
+
y_coords = []
|
| 1295 |
+
for p in prints:
|
| 1296 |
+
# 添加每个足印框的四个角点
|
| 1297 |
+
x_coords.extend([
|
| 1298 |
+
p.x - p.w/2, # 左边界
|
| 1299 |
+
p.x + p.w/2 # 右边界
|
| 1300 |
+
])
|
| 1301 |
+
y_coords.extend([
|
| 1302 |
+
p.y - p.h/2, # 上边界
|
| 1303 |
+
p.y + p.h/2 # 下边界
|
| 1304 |
+
])
|
| 1305 |
+
|
| 1306 |
+
# 计算能包含所有足印的最小矩形
|
| 1307 |
+
x_min = min(x_coords)
|
| 1308 |
+
y_min = min(y_coords)
|
| 1309 |
+
x_max = max(x_coords)
|
| 1310 |
+
y_max = max(y_coords)
|
| 1311 |
+
|
| 1312 |
+
# 选择关键帧(这里选择置信度最高的)
|
| 1313 |
+
key_print = max(prints, key=lambda p: p.conf)
|
| 1314 |
+
|
| 1315 |
+
# 获取图像数据
|
| 1316 |
+
if key_print.image_patch is not None:
|
| 1317 |
+
import base64
|
| 1318 |
+
import cv2
|
| 1319 |
+
success, buffer = cv2.imencode('.png', key_print.image_patch)
|
| 1320 |
+
if success:
|
| 1321 |
+
img_base64 = base64.b64encode(buffer).decode('utf-8')
|
| 1322 |
+
else:
|
| 1323 |
+
img_base64 = ""
|
| 1324 |
+
else:
|
| 1325 |
+
img_base64 = ""
|
| 1326 |
+
|
| 1327 |
+
area_data = {
|
| 1328 |
+
"footprintAreaId": area_id_map[cluster_id],
|
| 1329 |
+
"type": type_map.get(key_print.paw_type, 'unknown'),
|
| 1330 |
+
"areaPosition": {
|
| 1331 |
+
"x": int(x_min),
|
| 1332 |
+
"y": int(y_min),
|
| 1333 |
+
"width": int(x_max - x_min),
|
| 1334 |
+
"height": int(y_max - y_min)
|
| 1335 |
+
},
|
| 1336 |
+
"keyFootprintFrame": {
|
| 1337 |
+
"frameId": key_print.frame_id,
|
| 1338 |
+
"footprintId": f"footprint_{key_print.frame_id}_{cluster_id}",
|
| 1339 |
+
"footPosition": {
|
| 1340 |
+
"x": int(key_print.x - key_print.w/2),
|
| 1341 |
+
"y": int(key_print.y - key_print.h/2),
|
| 1342 |
+
"width": int(key_print.w),
|
| 1343 |
+
"height": int(key_print.h)
|
| 1344 |
+
},
|
| 1345 |
+
"base64Image": img_base64
|
| 1346 |
+
},
|
| 1347 |
+
"startFrame": min(p.frame_id for p in prints),
|
| 1348 |
+
"endFrame": max(p.frame_id for p in prints)
|
| 1349 |
+
}
|
| 1350 |
+
|
| 1351 |
+
data["footprintArea"].append(area_data)
|
| 1352 |
+
|
| 1353 |
+
return data
|
| 1354 |
+
|
| 1355 |
+
def _save_footprint_json(self):
|
| 1356 |
+
"""保存足印数据为JSON文件"""
|
| 1357 |
+
import json
|
| 1358 |
+
|
| 1359 |
+
data = self._generate_footprint_data()
|
| 1360 |
+
json_path = f'{self.result_dir}/data/footprint_data.json'
|
| 1361 |
+
|
| 1362 |
+
with open(json_path, 'w', encoding='utf-8') as f:
|
| 1363 |
+
json.dump(data, f, indent=2, ensure_ascii=False)
|
| 1364 |
+
|
| 1365 |
+
print(f"足印数据已保存至: {json_path}")
|
| 1366 |
+
|
| 1367 |
+
|
| 1368 |
+
def main():
|
| 1369 |
+
analyzer = GaitAnalyzer()
|
| 1370 |
+
video_path = "/Users/hakureirm/codespace/Work/Algorithm/gait/exp_videos/Exp8.mp4"
|
| 1371 |
+
|
| 1372 |
+
# 自动检测时间范围
|
| 1373 |
+
start_time, end_time = analyzer._detect_mouse_time_range(video_path)
|
| 1374 |
+
|
| 1375 |
+
# 处理检测到的时间段
|
| 1376 |
+
analyzer.process_video(
|
| 1377 |
+
video_path,
|
| 1378 |
+
start_time=start_time,
|
| 1379 |
+
end_time=end_time,
|
| 1380 |
+
conf_thres=0.7,
|
| 1381 |
+
iou_thres=0.5
|
| 1382 |
+
)
|
| 1383 |
+
|
| 1384 |
+
if __name__ == "__main__":
|
| 1385 |
+
main()
|
src/visualize_footprint_json.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
def visualize_footprint_json(json_path, output_dir):
|
| 7 |
+
"""可视化足印JSON数据"""
|
| 8 |
+
# 读取JSON数据
|
| 9 |
+
with open(json_path, 'r') as f:
|
| 10 |
+
data = json.load(f)
|
| 11 |
+
|
| 12 |
+
# 创建输出目录
|
| 13 |
+
output_dir = Path(output_dir)
|
| 14 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 15 |
+
|
| 16 |
+
# 确定画布大小(从数据中获取最大坐标)
|
| 17 |
+
max_x = max_y = 0
|
| 18 |
+
for frame in data['frames']:
|
| 19 |
+
for footprint in frame['footprints']:
|
| 20 |
+
pos = footprint['position']
|
| 21 |
+
max_x = max(max_x, pos['x'] + pos['width'])
|
| 22 |
+
max_y = max(max_y, pos['y'] + pos['height'])
|
| 23 |
+
|
| 24 |
+
for area in data['footprintArea']:
|
| 25 |
+
pos = area['areaPosition']
|
| 26 |
+
max_x = max(max_x, pos['x'] + pos['width'])
|
| 27 |
+
max_y = max(max_y, pos['y'] + pos['height'])
|
| 28 |
+
|
| 29 |
+
# 添加一些边距
|
| 30 |
+
canvas_width = max_x + 50
|
| 31 |
+
canvas_height = max_y + 50
|
| 32 |
+
|
| 33 |
+
# 颜色映射
|
| 34 |
+
color_map = {
|
| 35 |
+
'leftFront': (0, 255, 0), # 绿色
|
| 36 |
+
'rightFront': (0, 255, 255), # 黄色
|
| 37 |
+
'leftHind': (255, 0, 255), # 紫色
|
| 38 |
+
'rightHind': (0, 165, 255), # 橙色
|
| 39 |
+
'unknown': (128, 128, 128) # 灰色
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
# 1. 绘制所有足印区域的覆盖范围
|
| 43 |
+
area_canvas = np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255
|
| 44 |
+
for area in data['footprintArea']:
|
| 45 |
+
pos = area['areaPosition']
|
| 46 |
+
color = color_map[area['type']]
|
| 47 |
+
|
| 48 |
+
# 绘制区域边界框
|
| 49 |
+
cv2.rectangle(area_canvas,
|
| 50 |
+
(pos['x'], pos['y']),
|
| 51 |
+
(pos['x'] + pos['width'], pos['y'] + pos['height']),
|
| 52 |
+
color, 2)
|
| 53 |
+
|
| 54 |
+
# 添加标签
|
| 55 |
+
label = f"{area['type']}#{area['footprintAreaId'].split('_')[-1]}"
|
| 56 |
+
cv2.putText(area_canvas, label,
|
| 57 |
+
(pos['x'], pos['y'] - 5),
|
| 58 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
| 59 |
+
|
| 60 |
+
# 保存足印区域图
|
| 61 |
+
cv2.imwrite(str(output_dir / 'footprint_areas.png'), area_canvas)
|
| 62 |
+
|
| 63 |
+
# 2. 为每一帧创建可视化
|
| 64 |
+
frame_ids = sorted(list(set(frame['frameId'] for frame in data['frames'])))
|
| 65 |
+
for frame_id in frame_ids:
|
| 66 |
+
frame_canvas = np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255
|
| 67 |
+
|
| 68 |
+
# 首先绘制所有区域的轮廓(半透明)
|
| 69 |
+
overlay = frame_canvas.copy()
|
| 70 |
+
for area in data['footprintArea']:
|
| 71 |
+
if area['startFrame'] <= frame_id <= area['endFrame']:
|
| 72 |
+
pos = area['areaPosition']
|
| 73 |
+
color = color_map[area['type']]
|
| 74 |
+
cv2.rectangle(overlay,
|
| 75 |
+
(pos['x'], pos['y']),
|
| 76 |
+
(pos['x'] + pos['width'], pos['y'] + pos['height']),
|
| 77 |
+
color, -1) # 填充矩形
|
| 78 |
+
|
| 79 |
+
# 应用半透明效果
|
| 80 |
+
alpha = 0.3
|
| 81 |
+
cv2.addWeighted(overlay, alpha, frame_canvas, 1 - alpha, 0, frame_canvas)
|
| 82 |
+
|
| 83 |
+
# 然后绘制当前帧的足印
|
| 84 |
+
for frame_data in data['frames']:
|
| 85 |
+
if frame_data['frameId'] == frame_id:
|
| 86 |
+
for footprint in frame_data['footprints']:
|
| 87 |
+
pos = footprint['position']
|
| 88 |
+
color = color_map[footprint['type']]
|
| 89 |
+
|
| 90 |
+
# 绘制足印边界框
|
| 91 |
+
cv2.rectangle(frame_canvas,
|
| 92 |
+
(pos['x'], pos['y']),
|
| 93 |
+
(pos['x'] + pos['width'], pos['y'] + pos['height']),
|
| 94 |
+
color, 2)
|
| 95 |
+
|
| 96 |
+
# 添加标签
|
| 97 |
+
label = f"{footprint['type']}#{footprint['id'].split('_')[-1]}"
|
| 98 |
+
cv2.putText(frame_canvas, label,
|
| 99 |
+
(pos['x'], pos['y'] - 5),
|
| 100 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
| 101 |
+
|
| 102 |
+
# 添加帧号
|
| 103 |
+
cv2.putText(frame_canvas, f"Frame: {frame_id}",
|
| 104 |
+
(10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
|
| 105 |
+
|
| 106 |
+
# 保存帧图像
|
| 107 |
+
cv2.imwrite(str(output_dir / f'frame_{frame_id:04d}.png'), frame_canvas)
|
| 108 |
+
|
| 109 |
+
print(f"可视化结果已保存至: {output_dir}")
|
| 110 |
+
|
| 111 |
+
def main():
|
| 112 |
+
# 假设JSON文件在results目录下的最新时间戳文件夹中
|
| 113 |
+
results_dir = Path("results")
|
| 114 |
+
latest_dir = max(results_dir.glob("*"), key=lambda p: p.stat().st_mtime)
|
| 115 |
+
json_path = latest_dir / "data" / "footprint_data.json"
|
| 116 |
+
|
| 117 |
+
# 创建可视化输出目录
|
| 118 |
+
output_dir = latest_dir / "visualization"
|
| 119 |
+
|
| 120 |
+
visualize_footprint_json(json_path, output_dir)
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
|
| 123 |
+
main()
|