"""Gait-print clustering debugger.

Loads pre-processed footprint detections from a JSON file, clusters them in
(x, y, time) space with DBSCAN and offers three ways to inspect the result:
an interactive 3-D scatter plot, an interactive parameter prompt, and a
multi-process parameter sweep that renders an HTML report.

Usage:
    python gait_cluster_debugger.py --json footprints.json
    python gait_cluster_debugger.py --json footprints.json --batch-search
"""
import argparse
import base64
import itertools
import json
import multiprocessing  # noqa: F401  (kept: other parts of the project may rely on it)
import os
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import wraps  # noqa: F401
from io import BytesIO
from pathlib import Path
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection)
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from gait_analyze import GaitAnalyzer  # noqa: F401  (class imported from the original file)


@dataclass
class GaitPrint:
    """One detected footprint bounding box in a single video frame."""

    frame_id: int
    x: float  # box centre, x
    y: float  # box centre, y
    w: float  # box width
    h: float  # box height
    conf: float  # detector confidence
    timestamp: float  # seconds, derived from frame_id by the loader
    paw_type: Optional[str] = None
    cluster_id: int = -1
    # Add further fields as needed.


def _process_combination_wrapper(args):
    """Cluster with one parameter combination and render it as an HTML snippet.

    Runs inside a worker process of ``batch_parameter_search``.

    Args:
        args: Tuple ``(i, combination, gait_prints, params_template)`` where
            ``i`` is the zero-based combination index, ``combination`` maps
            parameter names to the values under test, ``gait_prints`` is the
            list of ``GaitPrint`` objects and ``params_template`` is the base
            parameter dict that ``combination`` overrides.

    Returns:
        An HTML fragment containing the parameter description and the plot as
        an inline base64 PNG, or ``""`` if anything failed (the failure is
        printed so the rest of the sweep can continue).
    """
    i, combination, gait_prints, params_template = args

    try:
        # ClusterDebugger is a module-level name of this very file, so no
        # self-import is needed (a self-import breaks when the file is renamed
        # or is not on sys.path inside the worker).
        debugger = ClusterDebugger(gait_prints)
        debugger.params = params_template.copy()
        debugger.params.update(combination)

        labels = debugger.run_clustering()
        fig = debugger._generate_plot(labels)

        buf = BytesIO()
        try:
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        finally:
            # Always release the figure; a sweep creates many of them.
            plt.close(fig)
        img_data = base64.b64encode(buf.getvalue()).decode('utf-8')

        param_desc = "<br>".join(f"{k}: {v}" for k, v in combination.items())

        return (
            '<div class="combination">\n'
            f"<h3>Combination #{i + 1}</h3>\n"
            f"<p>{param_desc}</p>\n"
            f'<img src="data:image/png;base64,{img_data}" '
            f'alt="Combination #{i + 1}">\n'
            "</div>\n"
        )
    except Exception as e:  # best effort: one bad combination must not abort the sweep
        print(f"Error processing combination {i}: {str(e)}")
        return ""


class ClusterDebugger:
    """Clusters gait prints with DBSCAN and visualises the outcome."""

    # Converter used by the interactive prompt for each tunable parameter.
    _PARAM_TYPES = {
        'time_weight': float,
        'spatial_weight': float,
        'eps_factor': float,
        'merge_threshold': float,
        'min_samples': int,
    }

    def __init__(self, gait_prints):
        """Store the prints and initialise the default clustering parameters."""
        self.gait_prints = gait_prints
        self.params = {
            'time_weight': 0.8,      # weight of the time dimension
            'spatial_weight': 0.2,   # weight of the x/y dimensions
            'eps_factor': 5.5,       # eps = mean k-NN distance * eps_factor
            'min_samples': 2,        # DBSCAN min_samples
            'merge_threshold': 0.2,  # temporal merge threshold (seconds)
        }

    def run_clustering(self):
        """Run the full clustering pipeline.

        Steps: standardise (x, y, timestamp), apply the per-dimension weights,
        derive ``eps`` from the mean k-nearest-neighbour distance, run DBSCAN,
        then merge clusters that follow each other closely in time.

        Returns:
            Integer numpy array with one label per gait print; ``-1`` marks
            noise. With fewer than two prints every print is labelled noise,
            because no neighbour distance exists to derive ``eps`` from.
        """
        n_points = len(self.gait_prints)
        if n_points < 2:
            return np.full(n_points, -1, dtype=int)

        raw = np.array(
            [[p.x, p.y, p.timestamp] for p in self.gait_prints], dtype=float
        )

        # Weights must be applied AFTER standardisation: StandardScaler
        # rescales every column to unit variance, so weighting beforehand is
        # cancelled out and the weights would have no effect at all.
        scaled = StandardScaler().fit_transform(raw)
        weights = np.array([
            self.params['spatial_weight'],
            self.params['spatial_weight'],
            self.params['time_weight'],
        ])
        features = scaled * weights

        # eps from the mean distance to the k nearest neighbours; column 0 is
        # each point's distance to itself and is skipped.
        k = min(n_points, 5)
        nbrs = NearestNeighbors(n_neighbors=k).fit(features)
        distances, _ = nbrs.kneighbors(features)
        eps = float(np.mean(distances[:, 1:])) * self.params['eps_factor']

        dbscan = DBSCAN(eps=eps, min_samples=self.params['min_samples'])
        labels = dbscan.fit_predict(features)

        return self._merge_temporal_clusters(labels)

    def _merge_temporal_clusters(self, labels):
        """Merge clusters that are contiguous in time.

        Clusters are ordered by their earliest timestamp. A cluster is folded
        into the preceding (possibly already merged) cluster when it starts
        less than ``merge_threshold`` seconds after that cluster ends; clusters
        that overlap in time therefore always merge. The test is purely
        temporal and ignores position. Noise (``-1``) never takes part.

        Args:
            labels: One cluster label per entry of ``self.gait_prints``.

        Returns:
            A new label array; the input is left untouched.
        """
        labels = np.asarray(labels)
        merged = labels.copy()

        # Point indices per cluster, noise excluded.
        members = {}
        for idx, label in enumerate(labels):
            if label == -1:
                continue
            members.setdefault(int(label), []).append(idx)
        if len(members) < 2:
            return merged

        timestamps = np.array(
            [p.timestamp for p in self.gait_prints], dtype=float
        )
        spans = sorted(
            (
                (timestamps[idxs].min(), timestamps[idxs].max(), label, idxs)
                for label, idxs in members.items()
            ),
            key=lambda span: span[0],
        )

        threshold = self.params['merge_threshold']
        _, run_end, run_label, _ = spans[0]
        for start, end, label, idxs in spans[1:]:
            if start - run_end < threshold:
                # Relabel by point index and extend the running cluster so
                # that chains of close clusters collapse into one label.
                merged[idxs] = run_label
                run_end = max(run_end, end)
            else:
                run_label, run_end = label, end
        return merged

    def visualize(self, labels):
        """Show an interactive 3-D scatter plot (x, y, time) coloured by label.

        Blocks until the plot window is closed.
        """
        features = np.array(
            [[p.x, p.y, p.timestamp] for p in self.gait_prints]
        )

        plt.figure(figsize=(15, 8))
        ax = plt.axes(projection='3d')

        ax.scatter3D(
            features[:, 0], features[:, 1], features[:, 2],
            c=labels,
            cmap='tab20',
            alpha=0.8,
            s=50,
        )

        # Show the parameters that produced this view next to the plot.
        param_text = (
            f"Time Weight: {self.params['time_weight']}\n"
            f"Spatial Weight: {self.params['spatial_weight']}\n"
            f"EPS Factor: {self.params['eps_factor']}\n"
            f"Min Samples: {self.params['min_samples']}\n"
            f"Merge Threshold: {self.params['merge_threshold']}s"
        )
        plt.figtext(0.8, 0.8, param_text,
                    bbox=dict(facecolor='white', alpha=0.5))

        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.set_zlabel('Time (s)')
        plt.title("Gait Print Clustering Debug View")
        plt.show()

    def interactive_adjust(self):
        """Prompt for parameter changes, re-cluster and re-plot after each one.

        Input format is ``<param> <value>``; ``q`` (or end of input) leaves
        the loop. Malformed input and clustering failures are reported
        separately so a bad parameter value is not mistaken for a typo.
        """
        while True:
            print("\n当前参数:")
            for k, v in self.params.items():
                print(f"{k}: {v}")

            try:
                cmd = input("输入参数名和值 (如 'time_weight 0.6') 或 q退出: ").strip()
            except EOFError:
                break
            if cmd.lower() == 'q':
                break

            try:
                param, value = cmd.split()
                if param not in self.params:
                    raise ValueError(param)
                self.params[param] = self._PARAM_TYPES[param](value)
            except (ValueError, KeyError):
                print("输入无效,请按格式输入 (参数名 数值)")
                continue

            try:
                labels = self.run_clustering()
                self.visualize(labels)
            except Exception as exc:  # keep the session alive on bad parameter values
                print(f"聚类失败: {exc}")

    def batch_parameter_search(self, output_dir="param_search", max_workers=None):
        """Sweep the parameter grid in worker processes and write an HTML report.

        Args:
            output_dir: Directory that receives ``report.html`` (created if
                missing).
            max_workers: Number of worker processes. ``None``, ``0`` or a
                negative value selects ``cpu_count - 1`` (at least 1).
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        param_grid = {
            # 46 evenly spaced values from 0.1 to 5.0
            'spatial_weight': np.round(np.linspace(0.1, 5.0, 46), 2).tolist(),
            'time_weight': [0.8],       # fixed
            'merge_threshold': [0.2],   # fixed
            'eps_factor': [5.5],        # fixed
            'min_samples': [2],         # fixed
        }

        # Only parameters with more than one candidate are swept; the fixed
        # ones are taken from self.params inside the worker.
        active_params = {k: v for k, v in param_grid.items() if len(v) > 1}
        keys = list(active_params)
        combinations = [
            dict(zip(keys, vals))
            for vals in itertools.product(*active_params.values())
        ]

        if not max_workers or max_workers < 1:
            # ProcessPoolExecutor rejects 0 workers, which cpu_count() - 1
            # would yield on a single-core machine.
            max_workers = max(1, (os.cpu_count() or 1) - 1)

        tasks = [
            (i, comb, self.gait_prints, self.params)
            for i, comb in enumerate(combinations)
        ]

        html_parts = []
        with ProcessPoolExecutor(max_workers=max_workers) as executor, \
                tqdm(total=len(tasks), desc="参数搜索进度") as pbar:
            for snippet in executor.map(
                    _process_combination_wrapper, tasks, chunksize=5):
                html_parts.append(snippet)
                pbar.update(1)

        full_html = self._build_html_report(html_parts)
        (output_path / "report.html").write_text(full_html, encoding="utf-8")
        print(f"参数搜索完成!结果保存在 {output_path.resolve()}")

    def _build_html_report(self, html_parts):
        """Wrap the per-combination snippets in a complete HTML document."""
        head = (
            "<!DOCTYPE html>\n"
            "<html>\n"
            "<head>\n"
            '<meta charset="utf-8">\n'
            "<title>Gait Cluster Parameter Search</title>\n"
            "<style>\n"
            "body { font-family: sans-serif; margin: 20px; }\n"
            ".combination { margin-bottom: 40px; padding-bottom: 20px; "
            "border-bottom: 1px solid #ccc; }\n"
            ".combination img { max-width: 100%; }\n"
            "</style>\n"
            "</head>\n"
            "<body>\n"
        )
        return f"{head}{''.join(html_parts)}\n</body>\n</html>\n"

    def _generate_plot(self, labels):
        """Build a two-panel figure: spatial boxes (left), time vs. x (right).

        Args:
            labels: One cluster label per entry of ``self.gait_prints``.

        Returns:
            The matplotlib figure; the caller is responsible for closing it.

        Raises:
            ValueError: If there are no gait prints to draw.
        """
        if not self.gait_prints:
            raise ValueError("no gait prints to plot")

        # Canvas just large enough for every box plus a 50 px margin.
        max_x = max(p.x + p.w / 2 for p in self.gait_prints) + 50
        max_y = max(p.y + p.h / 2 for p in self.gait_prints) + 50
        canvas_size = (int(max_y), int(max_x), 3)

        # Layout ratio 3:2 (the spatial panel is wider).
        fig = plt.figure(figsize=(24, 10))
        gs = fig.add_gridspec(1, 2, width_ratios=[3, 2])

        unique_clusters = [int(c) for c in np.unique(labels) if c != -1]
        cmap = plt.get_cmap('tab20')

        # Panel 1: spatial distribution.
        ax1 = fig.add_subplot(gs[0])
        spatial_canvas = np.ones(canvas_size, dtype=np.uint8) * 255

        for cluster_id in unique_clusters:
            cluster_points = [
                p for p, lbl in zip(self.gait_prints, labels)
                if lbl == cluster_id
            ]
            # tab20 has 20 entries; OpenCV wants 0-255 channel values. The
            # canvas is displayed with imshow, so the order stays RGB.
            color = tuple(int(c * 255) for c in cmap(cluster_id % 20)[:3])

            for p in cluster_points:
                x = int(p.x - p.w / 2)
                y = int(p.y - p.h / 2)
                cv2.rectangle(spatial_canvas, (x, y),
                              (x + int(p.w), y + int(p.h)), color, 2)
                # Tag the box with its cluster id just above the top edge.
                cv2.putText(spatial_canvas, f"C{cluster_id}", (x, y - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        ax1.imshow(spatial_canvas)
        ax1.set_title("Spatial Distribution (Clustered)", fontsize=14, pad=20)
        ax1.axis('off')

        # Panel 2: timeline.
        ax2 = fig.add_subplot(gs[1])
        for cluster_id in unique_clusters:
            members = [
                p for p, lbl in zip(self.gait_prints, labels)
                if lbl == cluster_id
            ]
            ax2.scatter(
                [p.timestamp for p in members],
                [p.x for p in members],
                color=cmap(cluster_id % 20),
                s=40,
                label=f'Cluster {cluster_id}',
            )

        ax2.set_xlabel('Time (s)', fontsize=12)
        ax2.set_ylabel('X Position', fontsize=12)
        ax2.set_title('Temporal Distribution', fontsize=14, pad=20)
        ax2.grid(True, linestyle='--', alpha=0.6)

        legend_elements = [
            Patch(facecolor=cmap(i % 20), label=f'Cluster {i}')
            for i in unique_clusters
        ]
        ax2.legend(handles=legend_elements,
                   bbox_to_anchor=(1.05, 1),
                   loc='upper left',
                   title="Cluster ID")

        plt.tight_layout()
        return fig

    def visualize_cluster_distribution(self, labels):
        """Save a bar chart of cluster sizes to ``cluster_distribution.png``."""
        unique, counts = np.unique(labels, return_counts=True)

        plt.figure(figsize=(10, 6))
        plt.bar(unique, counts)
        plt.xlabel('簇ID')
        plt.ylabel('足印数量')
        plt.title('簇分布直方图')
        plt.grid(axis='y')
        plt.savefig("cluster_distribution.png")
        plt.close()


def load_debug_data(json_path, fps=120.0):
    """Load pre-processed footprint data from a JSON file.

    The file is read as ``{"frames": [{"frameId": ..., "footprints":
    [{"position": {"x", "y", "width", "height"}, "confidence": ...,
    "type": ...}, ...]}, ...]}``. ``position.x``/``position.y`` are treated as
    the top-left corner and converted to the box centre.

    Args:
        json_path: Path of the JSON file.
        fps: Frame rate used to convert ``frameId`` to a timestamp in seconds.

    Returns:
        List of ``GaitPrint`` objects in file order.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError("fps must be positive")

    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)

    gait_prints = []
    for frame in data['frames']:
        frame_id = frame['frameId']
        for fp in frame['footprints']:
            pos = fp['position']
            gait_prints.append(GaitPrint(
                frame_id=frame_id,
                x=pos['x'] + pos['width'] / 2,
                y=pos['y'] + pos['height'] / 2,
                w=pos['width'],
                h=pos['height'],
                conf=fp['confidence'],
                timestamp=frame_id / fps,
                paw_type=fp.get('type', 'unknown'),
            ))
    return gait_prints


def main():
    """Command-line entry point: load data, cluster once, then debug."""
    import sys

    # Make the project root importable.
    root_dir = str(Path(__file__).parent.parent.resolve())
    if root_dir not in sys.path:
        sys.path.insert(0, root_dir)

    parser = argparse.ArgumentParser()
    parser.add_argument("--json", required=True, help="足印数据JSON文件路径")
    parser.add_argument("--batch-search", action="store_true",
                        help="执行批量参数搜索")
    parser.add_argument("--max-workers", type=int, default=0,
                        help="最大并行进程数(0=自动)")
    args = parser.parse_args()

    print("加载数据...")
    gait_prints = load_debug_data(args.json)

    debugger = ClusterDebugger(gait_prints)

    print("初始聚类...")
    labels = debugger.run_clustering()
    debugger.visualize(labels)

    if args.batch_search:
        # 0 means "auto"; forward the user's choice instead of dropping it.
        debugger.batch_parameter_search(max_workers=args.max_workers or None)
    else:
        print("进入交互调试模式")
        debugger.interactive_adjust()


if __name__ == "__main__":
    main()