|
|
""" |
|
|
步态聚类调试工具 |
|
|
用法:python gait_cluster_debugger.py --json 足印数据.json |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from mpl_toolkits.mplot3d import Axes3D |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
from sklearn.cluster import DBSCAN |
|
|
from sklearn.neighbors import NearestNeighbors |
|
|
from gait_analyze import GaitAnalyzer |
|
|
import json |
|
|
import os |
|
|
from tqdm import tqdm |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
import itertools |
|
|
import cv2 |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from concurrent.futures import ProcessPoolExecutor |
|
|
import multiprocessing |
|
|
from functools import wraps |
|
|
|
|
|
@dataclass
class GaitPrint:
    """A single detected footprint: one bounding box in one video frame."""

    frame_id: int    # index of the video frame this print was detected in
    x: float         # box center x in pixels (see load_debug_data: corner + width/2)
    y: float         # box center y in pixels
    w: float         # box width in pixels
    h: float         # box height in pixels
    conf: float      # detector confidence score
    timestamp: float  # seconds, derived as frame_id / fps in load_debug_data
    # Paw label, e.g. from the JSON 'type' field ('unknown' when missing).
    # NOTE(review): annotation corrected — the default is None, not a str.
    paw_type: "str | None" = None
    # Assigned cluster label; -1 means unclustered / noise.
    cluster_id: int = -1
|
|
|
|
|
|
|
|
def _process_combination_wrapper(args):
    """Module-level worker for the parallel parameter search.

    Runs one parameter combination through clustering and plotting, then
    returns an HTML snippet embedding the resulting figure as a base64 PNG.
    Defined at module level (with a deferred self-import) so it can be
    pickled and executed by ProcessPoolExecutor workers.

    Args:
        args: tuple of (combination index, parameter-override dict,
              list of GaitPrint objects, base parameter dict).

    Returns:
        str: an HTML ``<div class="item">`` snippet, or "" on failure.
    """
    i, combination, gait_prints, params_template = args

    # Deferred import: each worker process re-imports this module to obtain
    # the class definition (avoids pickling the ClusterDebugger instance).
    from gait_cluster_debugger import ClusterDebugger

    try:
        debugger = ClusterDebugger(gait_prints)
        # Start from the caller's parameters, then apply this combination.
        debugger.params = params_template.copy()
        debugger.params.update(combination)

        labels = debugger.run_clustering()

        fig = debugger._generate_plot(labels)

        # Render the figure to an in-memory PNG and base64-encode it so the
        # final report is one self-contained HTML file.
        from io import BytesIO
        import base64
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        plt.close(fig)  # free the figure; a worker may render many plots
        img_data = base64.b64encode(buf.getvalue()).decode('utf-8')

        param_desc = "<br>".join([f"{k}: {v}" for k, v in combination.items()])

        html_snippet = f"""
        <div class="item">
            <div class="param-header">Combination #{i+1}</div>
            <div class="param-desc">{param_desc}</div>
            <img src="data:image/png;base64,{img_data}">
        </div>
        """
        return html_snippet
    except Exception as e:
        # Best-effort: a failed combination is logged and skipped, not fatal
        # to the whole search.
        print(f"Error processing combination {i}: {str(e)}")
        return ""
|
|
|
|
|
class ClusterDebugger:
    """Debugger for DBSCAN-based clustering of gait footprints.

    Footprints are clustered in a weighted (x, y, time) feature space; the
    resulting clusters are then merged when they are adjacent in time.
    Provides interactive 3-D visualization, manual parameter tuning, and a
    parallel grid search that renders an HTML report.
    """

    def __init__(self, gait_prints):
        """gait_prints: sequence of GaitPrint-like objects (x, y, w, h, timestamp)."""
        self.gait_prints = gait_prints
        # Tunable clustering parameters (see run_clustering and
        # _merge_temporal_clusters for how each one is used).
        self.params = {
            'time_weight': 0.8,      # scale applied to the timestamp axis
            'spatial_weight': 0.2,   # scale applied to the x/y axes
            'eps_factor': 5.5,       # eps = mean k-NN distance * eps_factor
            'min_samples': 2,        # DBSCAN min_samples
            'merge_threshold': 0.2   # max time gap (s) for merging clusters
        }

    def run_clustering(self):
        """Cluster all footprints; return one label per print (-1 = noise)."""
        # Weighted (x, y, t) feature vector per footprint.
        features = np.array([
            [p.x * self.params['spatial_weight'],
             p.y * self.params['spatial_weight'],
             p.timestamp * self.params['time_weight']]
            for p in self.gait_prints
        ])

        features_scaled = StandardScaler().fit_transform(features)

        # Adaptive eps: mean distance to the nearest neighbors, scaled by
        # eps_factor, so the threshold follows the data density.
        k = min(len(features), 5)
        nbrs = NearestNeighbors(n_neighbors=k).fit(features_scaled)
        distances, _ = nbrs.kneighbors(features_scaled)
        # Column 0 is each point's zero distance to itself; use columns 1..k-1.
        # BUG FIX: with a single footprint (k == 1) the slice is empty and the
        # mean is NaN, which crashes DBSCAN — guard that case.
        mean_dist = float(np.mean(distances[:, 1:])) if k > 1 else 0.0
        # DBSCAN requires eps > 0; degenerate data (all points identical)
        # would otherwise yield eps == 0.
        eps = max(mean_dist * self.params['eps_factor'], 1e-9)

        dbscan = DBSCAN(eps=eps, min_samples=self.params['min_samples'])
        labels = dbscan.fit_predict(features_scaled)

        # BUG FIX: the merged labels were previously computed and discarded,
        # making the temporal merge a no-op.
        return self._merge_temporal_clusters(labels)

    def _merge_temporal_clusters(self, labels):
        """Merge clusters that are adjacent in time.

        Clusters are ordered by their earliest timestamp; consecutive
        clusters whose time gap is below ``merge_threshold`` seconds receive
        the same label (merges chain forward). Noise points (label -1) are
        never merged.

        Returns a new label array; ``labels`` itself is not modified.
        """
        # Group point *indices* by cluster label, skipping DBSCAN noise.
        # BUG FIX: the old code indexed new_labels[p.cluster_id], but
        # cluster_id is never assigned (always -1), so the merge corrupted
        # the last element instead of relabeling the cluster's points.
        cluster_indices = {}
        for idx, label in enumerate(labels):
            if label == -1:
                continue
            cluster_indices.setdefault(label, []).append(idx)

        new_labels = labels.copy()
        if not cluster_indices:
            return new_labels

        # Clusters ordered by their earliest timestamp.
        sorted_clusters = sorted(
            cluster_indices.items(),
            key=lambda item: min(self.gait_prints[i].timestamp for i in item[1])
        )

        # resolved[label] -> the label it was merged into (follows chains, so
        # A<-B<-C all collapse onto A).
        resolved = {label: label for label, _ in sorted_clusters}
        for (prev_label, prev_idx), (curr_label, curr_idx) in zip(sorted_clusters, sorted_clusters[1:]):
            last_time = max(self.gait_prints[i].timestamp for i in prev_idx)
            first_time = min(self.gait_prints[i].timestamp for i in curr_idx)
            if first_time - last_time < self.params['merge_threshold']:
                resolved[curr_label] = resolved[prev_label]

        for label, indices in sorted_clusters:
            target = resolved[label]
            if target != label:
                for i in indices:
                    new_labels[i] = target
        return new_labels

    def visualize(self, labels):
        """Show an interactive 3-D (x, y, time) scatter of the clustering."""
        features = np.array([[p.x, p.y, p.timestamp] for p in self.gait_prints])

        plt.figure(figsize=(15, 8))
        ax = plt.axes(projection='3d')

        ax.scatter3D(
            features[:, 0], features[:, 1], features[:, 2],
            c=labels, cmap='tab20', alpha=0.8, s=50
        )

        # Overlay the current parameters so screenshots are self-describing.
        param_text = (
            f"Time Weight: {self.params['time_weight']}\n"
            f"Spatial Weight: {self.params['spatial_weight']}\n"
            f"EPS Factor: {self.params['eps_factor']}\n"
            f"Min Samples: {self.params['min_samples']}\n"
            f"Merge Threshold: {self.params['merge_threshold']}s"
        )
        plt.figtext(0.8, 0.8, param_text, bbox=dict(facecolor='white', alpha=0.5))

        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.set_zlabel('Time (s)')
        plt.title("Gait Print Clustering Debug View")
        plt.show()

    def interactive_adjust(self):
        """Interactive loop: tweak one parameter at a time, re-cluster, re-plot."""
        float_params = ('time_weight', 'spatial_weight', 'eps_factor', 'merge_threshold')
        while True:
            print("\n当前参数:")
            for k, v in self.params.items():
                print(f"{k}: {v}")

            try:
                cmd = input("输入参数名和值 (如 'time_weight 0.6') 或 q退出: ").strip()
                if cmd.lower() == 'q':
                    break

                param, value = cmd.split()
                if param not in self.params:
                    raise ValueError

                # min_samples is the only integer parameter; everything else
                # is parsed as a float.
                if param in float_params:
                    self.params[param] = float(value)
                elif param == 'min_samples':
                    self.params[param] = int(value)
                else:
                    raise ValueError

                labels = self.run_clustering()
                self.visualize(labels)

            except Exception:
                print("输入无效,请按格式输入 (参数名 数值)")

    def batch_parameter_search(self, output_dir="param_search", max_workers=None):
        """Grid-search clustering parameters in parallel; write an HTML report.

        Only parameters with more than one candidate value are varied; the
        rest stay fixed at their single grid value. Each combination is
        rendered to a figure in a worker process and embedded in report.html.

        Args:
            output_dir: directory for the generated report.
            max_workers: process count; None selects cpu_count() - 1 (min 1).
        """
        from pathlib import Path
        import itertools
        from concurrent.futures import ProcessPoolExecutor

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Candidate values; length-1 lists act as fixed parameters.
        param_grid = {
            'spatial_weight': np.round(np.linspace(0.1, 5.0, 46), 2).tolist(),
            'time_weight': [0.8],
            'merge_threshold': [0.2],
            'eps_factor': [5.5],
            'min_samples': [2]
        }

        # Cartesian product over the parameters that actually vary.
        active_params = {k: v for k, v in param_grid.items() if len(v) > 1}
        keys = list(active_params.keys())
        combinations = [dict(zip(keys, vals)) for vals in itertools.product(*active_params.values())]

        # BUG FIX: on a single-core machine cpu_count() - 1 evaluated to 0,
        # which ProcessPoolExecutor rejects; clamp to at least one worker.
        max_workers = max_workers or max(1, (os.cpu_count() or 2) - 1)
        total = len(combinations)

        def process_wrapper():
            # Stream results as they complete so the progress bar stays live.
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                args = ((i, comb, self.gait_prints, self.params)
                        for i, comb in enumerate(combinations))
                yield from executor.map(_process_combination_wrapper, args, chunksize=5)

        html_parts = []
        with tqdm(total=total, desc="参数搜索进度") as pbar:
            for result in process_wrapper():
                html_parts.append(result)
                pbar.update(1)

        full_html = self._build_html_report(html_parts)
        # BUG FIX: force UTF-8 so the report renders identically on platforms
        # whose default locale encoding is not UTF-8.
        (output_path / "report.html").write_text(full_html, encoding='utf-8')
        print(f"参数搜索完成!结果保存在 {output_path.resolve()}")

    def _build_html_report(self, html_parts):
        """Wrap the per-combination snippets in a complete HTML document."""
        return f"""
        <html>
        <head>
            <style>
                /* 保持之前的样式不变 */
                body {{ margin: 10px; background: #f5f5f5; }}
                .grid {{ /* 样式细节 */ }}
                /* 其他样式规则 */
            </style>
        </head>
        <body>
            <div class="grid">
                {''.join(html_parts)}
            </div>
        </body>
        </html>
        """

    def _generate_plot(self, labels):
        """Render a two-panel figure: spatial boxes (left), x vs. time (right)."""
        # Canvas sized to enclose every bounding box, plus a 50 px margin.
        max_x = max(p.x + p.w / 2 for p in self.gait_prints) + 50
        max_y = max(p.y + p.h / 2 for p in self.gait_prints) + 50
        canvas_size = (int(max_y), int(max_x), 3)

        fig = plt.figure(figsize=(24, 10))
        gs = fig.add_gridspec(1, 2, width_ratios=[3, 2])

        # --- Left panel: bounding boxes drawn on a white canvas ------------
        ax1 = fig.add_subplot(gs[0])
        spatial_canvas = np.ones(canvas_size, dtype=np.uint8) * 255
        unique_clusters = np.unique(labels)
        cmap = plt.get_cmap('tab20')

        for cluster_id in unique_clusters:
            if cluster_id == -1:  # skip DBSCAN noise
                continue

            cluster_points = [p for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
            # tab20 provides 20 colors; wrap the id. RGBA scaled to 0..255
            # because cv2 expects integer-range color components.
            color = np.array(cmap(cluster_id % 20)) * 255

            for p in cluster_points:
                # (p.x, p.y) is the box center; recover the top-left corner.
                x = int(p.x - p.w / 2)
                y = int(p.y - p.h / 2)
                cv2.rectangle(spatial_canvas,
                              (x, y),
                              (x + int(p.w), y + int(p.h)),
                              color.tolist(), 2)

                # Cluster tag just above the box.
                label_pos = (x, y - 5)
                cv2.putText(spatial_canvas, f"C{cluster_id}",
                            label_pos, cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, color.tolist(), 1)

        ax1.imshow(spatial_canvas)
        ax1.set_title("Spatial Distribution (Clustered)", fontsize=14, pad=20)
        ax1.axis('off')

        # --- Right panel: x position over time ------------------------------
        ax2 = fig.add_subplot(gs[1])

        for cluster_id in unique_clusters:
            if cluster_id == -1:
                continue

            times = [p.timestamp for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
            x_coords = [p.x for p, lbl in zip(self.gait_prints, labels) if lbl == cluster_id]
            color = cmap(cluster_id % 20)

            ax2.scatter(times, x_coords, color=color, s=40,
                        label=f'Cluster {cluster_id}')

        ax2.set_xlabel('Time (s)', fontsize=12)
        ax2.set_ylabel('X Position', fontsize=12)
        ax2.set_title('Temporal Distribution', fontsize=14, pad=20)
        ax2.grid(True, linestyle='--', alpha=0.6)

        # Explicit legend patches keep colors consistent with the left panel.
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=cmap(i % 20), label=f'Cluster {i}')
            for i in unique_clusters if i != -1
        ]
        ax2.legend(handles=legend_elements,
                   bbox_to_anchor=(1.05, 1),
                   loc='upper left',
                   title="Cluster ID")

        plt.tight_layout()
        return fig

    def visualize_cluster_distribution(self, labels):
        """Save a histogram of footprints-per-cluster to cluster_distribution.png."""
        unique, counts = np.unique(labels, return_counts=True)
        plt.figure(figsize=(10, 6))
        plt.bar(unique, counts)
        plt.xlabel('簇ID')
        plt.ylabel('足印数量')
        plt.title('簇分布直方图')
        plt.grid(axis='y')
        plt.savefig("cluster_distribution.png")
        plt.close()
|
|
|
|
|
def load_debug_data(json_path, fps=120.0):
    """Load preprocessed footprint data from a JSON export.

    Args:
        json_path: path to a JSON file with a top-level ``frames`` list; each
            frame has ``frameId`` and a ``footprints`` list whose entries
            carry ``position`` (x/y/width/height), ``confidence`` and an
            optional ``type``.
        fps: frame rate used to convert frame indices into seconds
            (generalized from the previously hard-coded 120.0).

    Returns:
        list[GaitPrint]: one entry per footprint, with (x, y) converted from
        the box's top-left corner to its center.
    """
    # JSON is UTF-8 by specification; don't rely on the platform locale.
    with open(json_path, encoding='utf-8') as f:
        data = json.load(f)

    gait_prints = []
    for frame in data['frames']:
        frame_id = frame['frameId']
        # Timestamp derives from the frame index at the given frame rate.
        timestamp = frame_id / fps
        for fp in frame['footprints']:
            pos = fp['position']
            gait_prints.append(GaitPrint(
                frame_id=frame_id,
                # Shift the top-left corner to the box center.
                x=pos['x'] + pos['width'] / 2,
                y=pos['y'] + pos['height'] / 2,
                w=pos['width'],
                h=pos['height'],
                conf=fp['confidence'],
                timestamp=timestamp,
                paw_type=fp.get('type', 'unknown')
            ))
    return gait_prints
|
|
|
|
|
if __name__ == "__main__":

    # Make the project root importable when run as a script; the
    # multiprocessing workers re-import this module and need the same path.
    import sys
    from pathlib import Path
    root_dir = str(Path(__file__).parent.parent.resolve())
    if root_dir not in sys.path:
        sys.path.insert(0, root_dir)

    parser = argparse.ArgumentParser()
    parser.add_argument("--json", required=True, help="足印数据JSON文件路径")
    parser.add_argument("--batch-search", action="store_true",
                        help="执行批量参数搜索")
    parser.add_argument("--max-workers", type=int, default=0,
                        help="最大并行进程数(0=自动)")
    args = parser.parse_args()

    print("加载数据...")
    gait_prints = load_debug_data(args.json)

    # Run one clustering pass with default parameters and show it first.
    debugger = ClusterDebugger(gait_prints)
    print("初始聚类...")
    labels = debugger.run_clustering()
    debugger.visualize(labels)

    if args.batch_search:
        # BUG FIX: --max-workers was parsed but never forwarded; 0 means
        # "auto", which batch_parameter_search expresses as None.
        debugger.batch_parameter_search(max_workers=args.max_workers or None)
    else:
        print("进入交互调试模式")
        debugger.interactive_adjust()