"""
步态聚类调试工具
用法:python gait_cluster_debugger.py --json 足印数据.json
"""
import argparse
import base64
import itertools
import json
import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import wraps
from io import BytesIO
from pathlib import Path
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from gait_analyze import GaitAnalyzer  # class imported from the original module
@dataclass
class GaitPrint:
    """One detected footprint (bounding box) in a single video frame."""
    frame_id: int               # index of the source video frame
    x: float                    # box center x coordinate (see load_debug_data)
    y: float                    # box center y coordinate
    w: float                    # box width
    h: float                    # box height
    conf: float                 # detector confidence score
    timestamp: float            # seconds, derived from frame_id / fps
    # BUG FIX: default was None with a plain `str` annotation; mark optional.
    paw_type: Optional[str] = None   # paw type label; None when unknown
    cluster_id: int = -1             # assigned cluster label; -1 = unassigned/noise
    # Other fields added as needed...
def _process_combination_wrapper(args):
    """Process-pool worker: render one parameter combination to HTML.

    Parameters
    ----------
    args : tuple
        ``(index, combination, gait_prints, params_template)`` where
        ``combination`` overrides entries of ``params_template``.

    Returns
    -------
    str
        An HTML snippet with the parameter description and the rendered
        clustering figure inlined as a base64 PNG, or ``""`` on failure.
    """
    i, combination, gait_prints, params_template = args
    # Absolute import so the class resolves inside spawned worker processes.
    from gait_cluster_debugger import ClusterDebugger
    try:
        debugger = ClusterDebugger(gait_prints)
        debugger.params = params_template.copy()
        debugger.params.update(combination)
        # Run clustering with this parameter combination.
        labels = debugger.run_clustering()
        # Render the diagnostic figure.
        fig = debugger._generate_plot(labels)
        # Encode the figure as a base64 PNG for inline embedding.
        from io import BytesIO
        import base64
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        plt.close(fig)
        img_data = base64.b64encode(buf.getvalue()).decode('utf-8')
        # BUG FIX: the original built the description with a string literal
        # split across two physical lines (a syntax error); join with "\n".
        param_desc = "\n".join(f"{k}: {v}" for k, v in combination.items())
        # BUG FIX: img_data was computed but never used; embed the image so
        # the report actually shows the figure for each combination.
        html_snippet = f"""
<div class="combination">
<pre>{param_desc}</pre>
<img src="data:image/png;base64,{img_data}"/>
</div>
"""
        return html_snippet
    except Exception as e:
        print(f"Error processing combination {i}: {str(e)}")
        return ""
class ClusterDebugger:
    """Interactive debugger for spatio-temporal footprint clustering.

    Wraps a DBSCAN pipeline over weighted (x, y, timestamp) footprint
    features, plus visualization helpers and a multiprocess parameter
    grid search that renders an HTML report.
    """

    def __init__(self, gait_prints):
        # Sequence of GaitPrint records to cluster.
        self.gait_prints = gait_prints
        # Tunable clustering parameters.
        self.params = {
            'time_weight': 0.8,      # weight of the time dimension
            'spatial_weight': 0.2,   # weight of the spatial dimensions
            'eps_factor': 5.5,       # eps = factor * mean kNN distance
            'min_samples': 2,        # DBSCAN min_samples
            'merge_threshold': 0.2   # temporal merge threshold (seconds)
        }

    def run_clustering(self):
        """Run the full pipeline; return one cluster label per footprint."""
        # Build weighted (x, y, t) features.
        features = np.array([
            [p.x * self.params['spatial_weight'],
             p.y * self.params['spatial_weight'],
             p.timestamp * self.params['time_weight']]
            for p in self.gait_prints
        ])
        # Standardize features.
        scaler = StandardScaler()
        features_scaled = scaler.fit_transform(features)
        # Derive eps from the mean k-nearest-neighbour distance.
        k = min(len(features), 5)
        nbrs = NearestNeighbors(n_neighbors=k).fit(features_scaled)
        distances, _ = nbrs.kneighbors(features_scaled)
        mean_dist = np.mean(distances[:, 1:])  # skip the zero self-distance
        eps = mean_dist * self.params['eps_factor']
        # DBSCAN clustering.
        dbscan = DBSCAN(eps=eps, min_samples=self.params['min_samples'])
        labels = dbscan.fit_predict(features_scaled)
        # BUG FIX: the merge result was previously computed and discarded,
        # so the temporal merge step had no effect.
        labels = self._merge_temporal_clusters(labels)
        return labels

    def _merge_temporal_clusters(self, labels):
        """Merge clusters that are adjacent in time.

        Two time-ordered clusters are merged when the gap between the
        last timestamp of the earlier one and the first timestamp of the
        later one is below ``merge_threshold``.  Returns a new label array.
        """
        # Map label -> indices of its member footprints.
        # BUG FIX: the original indexed new_labels with p.cluster_id, which
        # is never assigned (always -1) and silently corrupted the last
        # entry; index by footprint position instead.
        clusters = {}
        for idx, label in enumerate(labels):
            clusters.setdefault(label, []).append(idx)
        new_labels = labels.copy()
        # Sort clusters by their earliest timestamp.
        sorted_clusters = sorted(
            clusters.items(),
            key=lambda kv: min(self.gait_prints[i].timestamp for i in kv[1])
        )
        for (prev_label, prev_idx), (_curr_label, curr_idx) in zip(
                sorted_clusters, sorted_clusters[1:]):
            last_time = max(self.gait_prints[i].timestamp for i in prev_idx)
            first_time = min(self.gait_prints[i].timestamp for i in curr_idx)
            if first_time - last_time < self.params['merge_threshold']:
                # NOTE(review): merges are pairwise only (no chaining),
                # matching the original intent.
                for i in curr_idx:
                    new_labels[i] = prev_label
        return new_labels

    def visualize(self, labels):
        """Interactive 3-D visualization of the clustering result."""
        features = np.array([
            [p.x, p.y, p.timestamp]
            for p in self.gait_prints
        ])
        plt.figure(figsize=(15, 8))
        ax = plt.axes(projection='3d')
        # Scatter the clustered footprints in (x, y, t) space.
        scatter = ax.scatter3D(
            features[:, 0], features[:, 1], features[:, 2],
            c=labels, cmap='tab20', alpha=0.8, s=50
        )
        # Annotate the current parameter values.
        param_text = (
            f"Time Weight: {self.params['time_weight']}\n"
            f"Spatial Weight: {self.params['spatial_weight']}\n"
            f"EPS Factor: {self.params['eps_factor']}\n"
            f"Min Samples: {self.params['min_samples']}\n"
            f"Merge Threshold: {self.params['merge_threshold']}s"
        )
        plt.figtext(0.8, 0.8, param_text, bbox=dict(facecolor='white', alpha=0.5))
        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.set_zlabel('Time (s)')
        plt.title("Gait Print Clustering Debug View")
        plt.show()

    def interactive_adjust(self):
        """Interactive parameter-tuning REPL; 'q' exits."""
        while True:
            print("\n当前参数:")
            for k, v in self.params.items():
                print(f"{k}: {v}")
            try:
                cmd = input("输入参数名和值 (如 'time_weight 0.6') 或 q退出: ").strip()
                if cmd.lower() == 'q':
                    break
                param, value = cmd.split()
                if param not in self.params:
                    raise ValueError
                # Convert the value to the parameter's expected type.
                if param in ['time_weight', 'spatial_weight', 'eps_factor', 'merge_threshold']:
                    self.params[param] = float(value)
                elif param == 'min_samples':
                    self.params[param] = int(value)
                else:
                    raise ValueError
                # Re-run clustering and show the result.
                labels = self.run_clustering()
                self.visualize(labels)
            except Exception as e:
                # NOTE(review): this also catches clustering errors, not just
                # malformed input, so the message can be misleading.
                print("输入无效,请按格式输入 (参数名 数值)")

    def batch_parameter_search(self, output_dir="param_search", max_workers=None):
        """Multiprocess grid search over the clustering parameters.

        Renders one figure per parameter combination and writes a single
        HTML report into ``output_dir``.

        Parameters
        ----------
        output_dir : str
            Directory for the generated report (created if missing).
        max_workers : int | None
            Worker-process count; None/0 selects cpu_count()-1 (min 1).
        """
        from pathlib import Path
        import itertools
        from concurrent.futures import ProcessPoolExecutor
        # Create the output directory.
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        # Parameter grid: only spatial_weight is swept here.
        param_grid = {
            'spatial_weight': np.round(np.linspace(0.1, 5.0, 46), 2).tolist(),  # 0.1-5.0 in 46 steps
            'time_weight': [0.8],        # fixed time weight
            'merge_threshold': [0.2],    # fixed merge threshold
            'eps_factor': [5.5],         # fixed density factor
            'min_samples': [2]           # fixed minimum sample count
        }
        # Only parameters with more than one candidate value are swept.
        active_params = {k: v for k, v in param_grid.items() if len(v) > 1}
        keys = list(active_params.keys())
        combinations = [dict(zip(keys, vals))
                        for vals in itertools.product(*active_params.values())]
        # BUG FIX: on a single-core machine cpu_count()-1 evaluated to 0,
        # which ProcessPoolExecutor rejects; clamp to at least one worker.
        max_workers = max_workers or max(1, (os.cpu_count() or 2) - 1)
        total = len(combinations)

        def process_wrapper():
            # Stream results from the pool so the progress bar advances live.
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                args = ((i, comb, self.gait_prints, self.params)
                        for i, comb in enumerate(combinations))
                yield from executor.map(_process_combination_wrapper, args, chunksize=5)

        # Collect the per-combination HTML snippets.
        html_parts = []
        with tqdm(total=total, desc="参数搜索进度") as pbar:
            for result in process_wrapper():
                html_parts.append(result)
                pbar.update(1)
        # Assemble and write the full HTML report.
        full_html = self._build_html_report(html_parts)
        (output_path / "report.html").write_text(full_html)
        print(f"参数搜索完成!结果保存在 {output_path.resolve()}")

    def _build_html_report(self, html_parts):
        """Concatenate the per-combination snippets into the report body."""
        return f"""
{''.join(html_parts)}
"""

    def _generate_plot(self, labels):
        """Render the two-panel (spatial + temporal) diagnostic figure."""
        # Canvas size derived from the footprint extents plus a margin.
        max_x = max(p.x + p.w / 2 for p in self.gait_prints) + 50
        max_y = max(p.y + p.h / 2 for p in self.gait_prints) + 50
        canvas_size = (int(max_y), int(max_x), 3)
        # 3:2 two-column layout.
        fig = plt.figure(figsize=(24, 10))
        gs = fig.add_gridspec(1, 2, width_ratios=[3, 2])
        # Panel 1: spatial distribution drawn on a white canvas.
        ax1 = fig.add_subplot(gs[0])
        spatial_canvas = np.ones(canvas_size, dtype=np.uint8) * 255
        unique_clusters = np.unique(labels)
        cmap = plt.get_cmap('tab20')
        for cluster_id in unique_clusters:
            if cluster_id == -1:  # skip noise points
                continue
            # All footprints belonging to this cluster.
            cluster_points = [p for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
            # Deterministic per-cluster color from the colormap.
            color = np.array(cmap(cluster_id % 20)) * 255
            # Draw each footprint's bounding box.
            for p in cluster_points:
                x = int(p.x - p.w / 2)
                y = int(p.y - p.h / 2)
                cv2.rectangle(spatial_canvas,
                              (x, y),
                              (x + int(p.w), y + int(p.h)),
                              color.tolist(), 2)
                # Cluster label next to the box.
                label_pos = (x, y - 5)
                cv2.putText(spatial_canvas, f"C{cluster_id}",
                            label_pos, cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, color.tolist(), 1)
        ax1.imshow(spatial_canvas)
        ax1.set_title("Spatial Distribution (Clustered)", fontsize=14, pad=20)
        ax1.axis('off')
        # Panel 2: x-position over time, colored by cluster.
        ax2 = fig.add_subplot(gs[1])
        for cluster_id in unique_clusters:
            if cluster_id == -1:
                continue
            times = [p.timestamp for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
            x_coords = [p.x for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
            color = cmap(cluster_id % 20)
            ax2.scatter(times, x_coords, color=color, s=40,
                        label=f'Cluster {cluster_id}')
        ax2.set_xlabel('Time (s)', fontsize=12)
        ax2.set_ylabel('X Position', fontsize=12)
        ax2.set_title('Temporal Distribution', fontsize=14, pad=20)
        ax2.grid(True, linestyle='--', alpha=0.6)
        # Cluster-color legend.
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=cmap(i % 20), label=f'Cluster {i}')
            for i in unique_clusters if i != -1
        ]
        ax2.legend(handles=legend_elements,
                   bbox_to_anchor=(1.05, 1),
                   loc='upper left',
                   title="Cluster ID")
        plt.tight_layout()
        return fig

    def visualize_cluster_distribution(self, labels):
        """Save a histogram of footprint counts per cluster to PNG."""
        unique, counts = np.unique(labels, return_counts=True)
        plt.figure(figsize=(10, 6))
        plt.bar(unique, counts)
        plt.xlabel('簇ID')
        plt.ylabel('足印数量')
        plt.title('簇分布直方图')
        plt.grid(axis='y')
        plt.savefig("cluster_distribution.png")
        plt.close()
def load_debug_data(json_path, fps=120.0):
    """Load preprocessed footprint data from a JSON file.

    Parameters
    ----------
    json_path : str | Path
        Path to the footprint JSON (expects ``frames`` -> ``footprints``).
    fps : float
        Frame rate used to convert frame ids to timestamps.  Defaults to
        120.0, matching the previously hard-coded value.

    Returns
    -------
    list[GaitPrint]
        One record per footprint, with box coordinates converted from
        the top-left origin to the box center.
    """
    with open(json_path) as f:
        data = json.load(f)
    gait_prints = []
    for frame in data['frames']:
        for fp in frame['footprints']:
            pos = fp['position']
            gait_prints.append(GaitPrint(
                frame_id=frame['frameId'],
                # Convert the top-left corner to the box center.
                x=pos['x'] + pos['width'] / 2,
                y=pos['y'] + pos['height'] / 2,
                w=pos['width'],
                h=pos['height'],
                conf=fp['confidence'],
                timestamp=frame['frameId'] / fps,
                paw_type=fp.get('type', 'unknown')  # paw type from data when present
            ))
    return gait_prints
if __name__ == "__main__":
    # Make the project root importable for the worker processes.
    import sys
    from pathlib import Path
    root_dir = str(Path(__file__).parent.parent.resolve())
    if root_dir not in sys.path:
        sys.path.insert(0, root_dir)

    parser = argparse.ArgumentParser()
    parser.add_argument("--json", required=True, help="足印数据JSON文件路径")
    parser.add_argument("--batch-search", action="store_true",
                        help="执行批量参数搜索")
    parser.add_argument("--max-workers", type=int, default=0,
                        help="最大并行进程数(0=自动)")
    args = parser.parse_args()

    print("加载数据...")
    gait_prints = load_debug_data(args.json)
    debugger = ClusterDebugger(gait_prints)

    print("初始聚类...")
    labels = debugger.run_clustering()
    debugger.visualize(labels)

    if args.batch_search:
        # BUG FIX: --max-workers was parsed but never forwarded; 0 means
        # auto-select (None) inside batch_parameter_search.
        debugger.batch_parameter_search(max_workers=args.max_workers or None)
    else:
        print("进入交互调试模式")
        debugger.interactive_adjust()