# gait_cluster_debugger.py (extraction metadata removed: file size / commit hash / line index)
"""
步态聚类调试工具
用法:python gait_cluster_debugger.py --json 足印数据.json
"""
import argparse
import base64
import itertools
import json
import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import wraps
from io import BytesIO
from pathlib import Path
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from gait_analyze import GaitAnalyzer  # class imported from the original module
@dataclass
class GaitPrint:
frame_id: int
x: float
y: float
w: float
h: float
conf: float
timestamp: float
paw_type: str = None
cluster_id: int = -1
# 其他字段按需添加...
def _process_combination_wrapper(args):
"""修复导入问题的处理函数"""
i, combination, gait_prints, params_template = args
# 改为绝对导入
from gait_cluster_debugger import ClusterDebugger # 移除相对导入
try:
debugger = ClusterDebugger(gait_prints)
debugger.params = params_template.copy()
debugger.params.update(combination)
# 执行聚类
labels = debugger.run_clustering()
# 生成可视化
fig = debugger._generate_plot(labels)
# 转换为Base64图片
from io import BytesIO
import base64
buf = BytesIO()
fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
plt.close(fig)
img_data = base64.b64encode(buf.getvalue()).decode('utf-8')
# 生成参数描述
param_desc = "<br>".join([f"{k}: {v}" for k, v in combination.items()])
# 构建HTML片段
html_snippet = f"""
<div class="item">
<div class="param-header">Combination #{i+1}</div>
<div class="param-desc">{param_desc}</div>
<img src="data:image/png;base64,{img_data}">
</div>
"""
return html_snippet
except Exception as e:
print(f"Error processing combination {i}: {str(e)}")
return ""
class ClusterDebugger:
def __init__(self, gait_prints):
self.gait_prints = gait_prints
self.params = {
'time_weight': 0.8, # 时间维度权重
'spatial_weight': 0.2, # 空间维度权重
'eps_factor': 5.5, # eps系数
'min_samples': 2, # 最小样本数
'merge_threshold': 0.2 # 合并阈值(秒)
}
def run_clustering(self):
"""执行聚类流程"""
# 准备特征
features = np.array([
[p.x * self.params['spatial_weight'],
p.y * self.params['spatial_weight'],
p.timestamp * self.params['time_weight']]
for p in self.gait_prints
])
# 标准化
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
# 计算eps
k = min(len(features), 5)
nbrs = NearestNeighbors(n_neighbors=k).fit(features_scaled)
distances, _ = nbrs.kneighbors(features_scaled)
mean_dist = np.mean(distances[:, 1:])
eps = mean_dist * self.params['eps_factor']
# DBSCAN聚类
dbscan = DBSCAN(eps=eps, min_samples=self.params['min_samples'])
labels = dbscan.fit_predict(features_scaled)
# 合并时间连续的簇
self._merge_temporal_clusters(labels)
return labels
def _merge_temporal_clusters(self, labels):
"""简单的时间连续性合并"""
clusters = {}
for i, label in enumerate(labels):
if label not in clusters:
clusters[label] = []
clusters[label].append(self.gait_prints[i])
# 按时间排序并合并
new_labels = labels.copy()
current_label = max(labels) + 1
sorted_clusters = sorted(clusters.items(), key=lambda x: min(p.timestamp for p in x[1]))
for i in range(1, len(sorted_clusters)):
prev_label, prev_points = sorted_clusters[i-1]
curr_label, curr_points = sorted_clusters[i]
last_time = max(p.timestamp for p in prev_points)
first_time = min(p.timestamp for p in curr_points)
if first_time - last_time < self.params['merge_threshold']:
for p in curr_points:
new_labels[p.cluster_id] = prev_label
return new_labels
def visualize(self, labels):
"""交互式三维可视化"""
features = np.array([
[p.x, p.y, p.timestamp]
for p in self.gait_prints
])
plt.figure(figsize=(15, 8))
ax = plt.axes(projection='3d')
# 绘制聚类结果
scatter = ax.scatter3D(
features[:,0], features[:,1], features[:,2],
c=labels, cmap='tab20', alpha=0.8, s=50
)
# 标注参数
param_text = (
f"Time Weight: {self.params['time_weight']}\n"
f"Spatial Weight: {self.params['spatial_weight']}\n"
f"EPS Factor: {self.params['eps_factor']}\n"
f"Min Samples: {self.params['min_samples']}\n"
f"Merge Threshold: {self.params['merge_threshold']}s"
)
plt.figtext(0.8, 0.8, param_text, bbox=dict(facecolor='white', alpha=0.5))
ax.set_xlabel('X Position')
ax.set_ylabel('Y Position')
ax.set_zlabel('Time (s)')
plt.title("Gait Print Clustering Debug View")
plt.show()
def interactive_adjust(self):
"""交互式参数调整"""
while True:
print("\n当前参数:")
for k, v in self.params.items():
print(f"{k}: {v}")
try:
cmd = input("输入参数名和值 (如 'time_weight 0.6') 或 q退出: ").strip()
if cmd.lower() == 'q':
break
param, value = cmd.split()
if param not in self.params:
raise ValueError
# 类型转换
if param in ['time_weight', 'spatial_weight', 'eps_factor', 'merge_threshold']:
self.params[param] = float(value)
elif param == 'min_samples':
self.params[param] = int(value)
else:
raise ValueError
# 重新运行并可视化
labels = self.run_clustering()
self.visualize(labels)
except Exception as e:
print("输入无效,请按格式输入 (参数名 数值)")
def batch_parameter_search(self, output_dir="param_search", max_workers=None):
"""完整的多进程参数搜索实现"""
from pathlib import Path
import itertools
from concurrent.futures import ProcessPoolExecutor
# 创建输出目录
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 参数网格配置
param_grid = {
'spatial_weight': np.round(np.linspace(0.1, 5.0, 46), 2).tolist(), # 0.1-1.0步长0.02
'time_weight': [0.8], # 固定时间权重
'merge_threshold': [0.2], # 固定合并阈值
'eps_factor': [5.5], # 固定空间密度系数
'min_samples': [2] # 固定最小样本数
}
# 生成有效参数组合
active_params = {k: v for k, v in param_grid.items() if len(v) > 1}
keys = list(active_params.keys())
combinations = [dict(zip(keys, vals)) for vals in itertools.product(*active_params.values())]
# 配置多进程
max_workers = max_workers or (os.cpu_count() - 1 if os.cpu_count() else 1)
total = len(combinations)
# 进度条包装器
def process_wrapper():
with ProcessPoolExecutor(max_workers=max_workers) as executor:
args = ((i, comb, self.gait_prints, self.params)
for i, comb in enumerate(combinations))
yield from executor.map(_process_combination_wrapper, args, chunksize=5)
# 执行并收集结果
html_parts = []
with tqdm(total=total, desc="参数搜索进度") as pbar:
for result in process_wrapper():
html_parts.append(result)
pbar.update(1)
# 生成完整HTML报告
full_html = self._build_html_report(html_parts)
(output_path / "report.html").write_text(full_html)
print(f"参数搜索完成!结果保存在 {output_path.resolve()}")
def _build_html_report(self, html_parts):
"""构建完整的HTML报告结构"""
return f"""
<html>
<head>
<style>
/* 保持之前的样式不变 */
body {{ margin: 10px; background: #f5f5f5; }}
.grid {{ /* 样式细节 */ }}
/* 其他样式规则 */
</style>
</head>
<body>
<div class="grid">
{''.join(html_parts)}
</div>
</body>
</html>
"""
def _generate_plot(self, labels):
"""生成改进后的可视化布局"""
# 计算画布尺寸
max_x = max(p.x + p.w/2 for p in self.gait_prints) + 50
max_y = max(p.y + p.h/2 for p in self.gait_prints) + 50
canvas_size = (int(max_y), int(max_x), 3)
# 调整布局比例 (3:2)
fig = plt.figure(figsize=(24, 10))
gs = fig.add_gridspec(1, 2, width_ratios=[3, 2])
# 子图1:空间分布(加宽)
ax1 = fig.add_subplot(gs[0])
spatial_canvas = np.ones(canvas_size, dtype=np.uint8) * 255
unique_clusters = np.unique(labels)
cmap = plt.get_cmap('tab20')
for cluster_id in unique_clusters:
if cluster_id == -1: # 噪声点
continue
# 获取该簇的所有足印
cluster_points = [p for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
# 随机选择一个颜色
color = np.array(cmap(cluster_id % 20)) * 255
# 绘制每个足印的框
for p in cluster_points:
x = int(p.x - p.w/2)
y = int(p.y - p.h/2)
cv2.rectangle(spatial_canvas,
(x, y),
(x + int(p.w), y + int(p.h)),
color.tolist(), 2)
# 添加簇标签
label_pos = (x, y - 5)
cv2.putText(spatial_canvas, f"C{cluster_id}",
label_pos, cv2.FONT_HERSHEY_SIMPLEX,
0.5, color.tolist(), 1)
ax1.imshow(spatial_canvas)
ax1.set_title("Spatial Distribution (Clustered)", fontsize=14, pad=20)
ax1.axis('off')
# 子图2:时间轴(调整布局)
ax2 = fig.add_subplot(gs[1])
# 绘制时间轴...
for cluster_id in unique_clusters:
if cluster_id == -1:
continue
# 获取该簇的时间戳和X坐标
times = [p.timestamp for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
x_coords = [p.x for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
color = cmap(cluster_id % 20)
ax2.scatter(times, x_coords, color=color, s=40,
label=f'Cluster {cluster_id}')
ax2.set_xlabel('Time (s)', fontsize=12)
ax2.set_ylabel('X Position', fontsize=12)
ax2.set_title('Temporal Distribution', fontsize=14, pad=20)
ax2.grid(True, linestyle='--', alpha=0.6)
# 添加颜色图例
from matplotlib.patches import Patch
legend_elements = [
Patch(facecolor=cmap(i%20), label=f'Cluster {i}')
for i in unique_clusters if i != -1
]
ax2.legend(handles=legend_elements,
bbox_to_anchor=(1.05, 1),
loc='upper left',
title="Cluster ID")
plt.tight_layout()
return fig
def visualize_cluster_distribution(self, labels):
"""新增簇分布直方图"""
unique, counts = np.unique(labels, return_counts=True)
plt.figure(figsize=(10, 6))
plt.bar(unique, counts)
plt.xlabel('簇ID')
plt.ylabel('足印数量')
plt.title('簇分布直方图')
plt.grid(axis='y')
plt.savefig("cluster_distribution.png")
plt.close()
def load_debug_data(json_path):
"""加载预处理好的足印数据"""
with open(json_path) as f:
data = json.load(f)
gait_prints = []
for frame in data['frames']:
for fp in frame['footprints']:
gait_prints.append(GaitPrint(
frame_id=frame['frameId'],
x=fp['position']['x'] + fp['position']['width']/2,
y=fp['position']['y'] + fp['position']['height']/2,
w=fp['position']['width'],
h=fp['position']['height'],
conf=fp['confidence'],
timestamp=frame['frameId']/120.0,
paw_type=fp.get('type', 'unknown') # 从数据中获取类型
))
return gait_prints
if __name__ == "__main__":
# 添加项目根目录到PATH
import sys
from pathlib import Path
root_dir = str(Path(__file__).parent.parent.resolve())
if root_dir not in sys.path:
sys.path.insert(0, root_dir)
parser = argparse.ArgumentParser()
parser.add_argument("--json", required=True, help="足印数据JSON文件路径")
parser.add_argument("--batch-search", action="store_true",
help="执行批量参数搜索")
parser.add_argument("--max-workers", type=int, default=0,
help="最大并行进程数(0=自动)")
args = parser.parse_args()
print("加载数据...")
gait_prints = load_debug_data(args.json)
debugger = ClusterDebugger(gait_prints)
print("初始聚类...")
labels = debugger.run_clustering()
debugger.visualize(labels)
if args.batch_search:
debugger.batch_parameter_search()
else:
print("进入交互调试模式")
debugger.interactive_adjust() |