File size: 14,989 Bytes
71d0872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
"""
步态聚类调试工具
用法:python gait_cluster_debugger.py --json 足印数据.json
"""

import argparse
import base64
import itertools
import json
import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import wraps
from io import BytesIO
from pathlib import Path
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from gait_analyze import GaitAnalyzer  # 从原文件导入类

@dataclass
class GaitPrint:
    """One detected footprint: a bounding box observed in a single frame.

    Instances are built by ``load_debug_data`` and consumed by
    ``ClusterDebugger``, which clusters on ``x``, ``y`` and ``timestamp`` and
    draws boxes from ``x``, ``y``, ``w`` and ``h``.
    """
    # Identifier of the frame the footprint was detected in.
    frame_id: int
    # Box centre; the loader computes it as top-left corner + half the size.
    x: float
    y: float
    # Box width and height, in the same units as x / y.
    w: float
    h: float
    # Detection confidence as stored in the input data.
    conf: float
    # Time of the frame; the loader derives it from the frame id.
    timestamp: float
    # Paw label taken from the input data, or None when not provided.
    paw_type: Optional[str] = None
    # Cluster assignment; -1 means "not assigned".
    # NOTE(review): nothing in this file ever writes this field.
    cluster_id: int = -1
    # Add further fields here as needed.

def _process_combination_wrapper(args):
    """修复导入问题的处理函数"""
    i, combination, gait_prints, params_template = args
    # 改为绝对导入
    from gait_cluster_debugger import ClusterDebugger  # 移除相对导入
    
    try:
        debugger = ClusterDebugger(gait_prints)
        debugger.params = params_template.copy()
        debugger.params.update(combination)
        
        # 执行聚类
        labels = debugger.run_clustering()
        
        # 生成可视化
        fig = debugger._generate_plot(labels)
        
        # 转换为Base64图片
        from io import BytesIO
        import base64
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        plt.close(fig)
        img_data = base64.b64encode(buf.getvalue()).decode('utf-8')
        
        # 生成参数描述
        param_desc = "<br>".join([f"{k}: {v}" for k, v in combination.items()])
        
        # 构建HTML片段
        html_snippet = f"""
        <div class="item">
            <div class="param-header">Combination #{i+1}</div>
            <div class="param-desc">{param_desc}</div>
            <img src="data:image/png;base64,{img_data}">
        </div>
        """
        return html_snippet
    except Exception as e:
        print(f"Error processing combination {i}: {str(e)}")
        return ""

class ClusterDebugger:
    """Debugging harness for DBSCAN clustering of gait prints.

    Holds a list of gait prints (objects exposing ``x``, ``y``, ``w``, ``h``
    and ``timestamp``) and a ``params`` dict, and offers clustering,
    visualisation, an interactive parameter prompt and a multi-process
    parameter sweep that writes an HTML report.
    """

    def __init__(self, gait_prints):
        """Store the gait prints and initialise the default parameters."""
        self.gait_prints = gait_prints
        self.params = {
            'time_weight': 0.8,     # weight of the standardised time axis
            'spatial_weight': 0.2,  # weight of the standardised x / y axes
            'eps_factor': 5.5,      # eps = eps_factor * mean neighbour distance
            'min_samples': 2,       # DBSCAN min_samples
            'merge_threshold': 0.2  # max time gap (seconds) to merge clusters
        }

    def run_clustering(self):
        """Cluster the gait prints and return one label per print.

        Features are ``(x, y, timestamp)``. They are standardised first and
        weighted afterwards: standardising already-weighted columns would
        divide the weights straight back out and make them no-ops.

        Returns:
            A NumPy integer array of cluster labels aligned with
            ``self.gait_prints``; ``-1`` marks noise. With fewer than two
            prints no neighbour distance can be estimated, so every print is
            labelled ``-1``.
        """
        n_points = len(self.gait_prints)
        if n_points < 2:
            return np.full(n_points, -1, dtype=int)

        # Build and standardise the feature matrix.
        features = np.array(
            [[p.x, p.y, p.timestamp] for p in self.gait_prints],
            dtype=float,
        )
        features_scaled = StandardScaler().fit_transform(features)

        # Apply the axis weights after standardisation so they take effect.
        features_scaled = features_scaled * np.array([
            self.params['spatial_weight'],
            self.params['spatial_weight'],
            self.params['time_weight'],
        ])

        # Derive eps from the mean distance to the nearest neighbours.
        # Column 0 is each point's distance to itself, hence the [:, 1:].
        k = min(n_points, 5)
        nbrs = NearestNeighbors(n_neighbors=k).fit(features_scaled)
        distances, _ = nbrs.kneighbors(features_scaled)
        mean_dist = np.mean(distances[:, 1:])
        eps = mean_dist * self.params['eps_factor']

        # DBSCAN clustering.
        dbscan = DBSCAN(eps=eps, min_samples=self.params['min_samples'])
        labels = dbscan.fit_predict(features_scaled)

        # Merge clusters that are contiguous in time and return the result.
        return self._merge_temporal_clusters(labels)

    def _merge_temporal_clusters(self, labels):
        """Merge clusters whose time spans are closer than ``merge_threshold``.

        Clusters are visited in order of their earliest timestamp. A cluster
        is folded into the running (possibly already merged) cluster when its
        first timestamp is less than ``params['merge_threshold']`` after the
        running cluster's last timestamp; otherwise it starts a new run.
        Noise points (label ``-1``) are never merged or used as a merge target.

        Args:
            labels: Sequence of integer labels aligned with
                ``self.gait_prints``.

        Returns:
            A new NumPy array of labels; the input is not modified.
        """
        merged = np.array(labels)  # np.array copies, so the input is untouched

        # Group point indices by cluster label, skipping noise.
        members = {}
        for idx, label in enumerate(merged):
            if label != -1:
                members.setdefault(label, []).append(idx)
        if not members:
            return merged

        # (first timestamp, last timestamp, label, indices), ordered by start.
        spans = []
        for label, indices in members.items():
            times = [self.gait_prints[idx].timestamp for idx in indices]
            spans.append((min(times), max(times), label, indices))
        spans.sort(key=lambda span: span[0])

        threshold = self.params['merge_threshold']
        _, run_end, run_label, _ = spans[0]
        for start, end, label, indices in spans[1:]:
            if start - run_end < threshold:
                # Index by point position, and keep extending the same run so
                # chains of adjacent clusters collapse into one label.
                merged[indices] = run_label
                run_end = max(run_end, end)
            else:
                run_label, run_end = label, end

        return merged

    def visualize(self, labels):
        """Show a 3-D scatter of (x, y, timestamp) coloured by cluster label.

        The current parameters are printed in a text box on the figure. The
        method ends by calling ``plt.show()``.
        """
        features = np.array([
            [p.x, p.y, p.timestamp]
            for p in self.gait_prints
        ])

        plt.figure(figsize=(15, 8))
        ax = plt.axes(projection='3d')

        # Plot the clustering result.
        ax.scatter3D(
            features[:, 0], features[:, 1], features[:, 2],
            c=labels, cmap='tab20', alpha=0.8, s=50
        )

        # Annotate the figure with the parameters in use.
        param_text = (
            f"Time Weight: {self.params['time_weight']}\n"
            f"Spatial Weight: {self.params['spatial_weight']}\n"
            f"EPS Factor: {self.params['eps_factor']}\n"
            f"Min Samples: {self.params['min_samples']}\n"
            f"Merge Threshold: {self.params['merge_threshold']}s"
        )
        plt.figtext(0.8, 0.8, param_text, bbox=dict(facecolor='white', alpha=0.5))

        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.set_zlabel('Time (s)')
        plt.title("Gait Print Clustering Debug View")
        plt.show()

    def interactive_adjust(self):
        """Prompt for ``<param> <value>`` pairs, re-clustering after each one.

        The loop ends when the user enters ``q`` or when standard input is
        exhausted. Malformed input leaves the parameters unchanged; a failure
        during clustering or plotting is reported and the loop continues.
        """
        # Conversion applied to the textual value of each known parameter.
        casters = {
            'time_weight': float,
            'spatial_weight': float,
            'eps_factor': float,
            'merge_threshold': float,
            'min_samples': int,
        }

        while True:
            print("\n当前参数:")
            for k, v in self.params.items():
                print(f"{k}: {v}")

            try:
                cmd = input("输入参数名和值 (如 'time_weight 0.6') 或 q退出: ").strip()
            except EOFError:
                # No more input (e.g. piped stdin): stop instead of re-prompting
                # forever.
                break
            if cmd.lower() == 'q':
                break

            # Parse and convert; every failure mode here is a ValueError.
            try:
                param, value = cmd.split()
                if param not in self.params or param not in casters:
                    raise ValueError(param)
                self.params[param] = casters[param](value)
            except ValueError:
                print("输入无效,请按格式输入 (参数名 数值)")
                continue

            # Re-run and visualise; keep the session alive if this fails, but
            # report the actual cause rather than blaming the input format.
            try:
                labels = self.run_clustering()
                self.visualize(labels)
            except Exception as e:
                print(f"聚类或可视化失败: {e}")

    def batch_parameter_search(self, output_dir="param_search", max_workers=None,
                               param_grid=None):
        """Sweep a parameter grid in worker processes and write an HTML report.

        Args:
            output_dir: Directory that receives ``report.html``; created if
                missing.
            max_workers: Number of worker processes. ``None`` or ``0`` selects
                the CPU count minus one, but never fewer than one.
            param_grid: Optional mapping of parameter name to a list of
                candidate values. Entries with several values are swept;
                entries with exactly one value are applied as fixed overrides
                of ``self.params``. Defaults to a sweep over
                ``spatial_weight`` with the other parameters fixed.

        Returns:
            The ``Path`` of the written report.
        """
        # Create the output directory.
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Parameter grid configuration.
        if param_grid is None:
            param_grid = {
                # 46 evenly spaced values from 0.1 to 5.0 (step ~0.109),
                # rounded to two decimals.
                'spatial_weight': np.round(np.linspace(0.1, 5.0, 46), 2).tolist(),
                'time_weight': [0.8],       # fixed time weight
                'merge_threshold': [0.2],   # fixed merge threshold
                'eps_factor': [5.5],        # fixed eps factor
                'min_samples': [2]          # fixed min_samples
            }

        # Split the grid into swept parameters and fixed overrides.
        active_params = {k: v for k, v in param_grid.items() if len(v) > 1}
        fixed_params = {k: v[0] for k, v in param_grid.items() if len(v) == 1}
        params_template = {**self.params, **fixed_params}

        keys = list(active_params.keys())
        combinations = [
            dict(zip(keys, vals))
            for vals in itertools.product(*active_params.values())
        ]

        # ProcessPoolExecutor rejects max_workers <= 0, so clamp to at least 1
        # (cpu_count() may be None or 1).
        if not max_workers:
            max_workers = max(1, (os.cpu_count() or 1) - 1)

        tasks = (
            (i, comb, self.gait_prints, params_template)
            for i, comb in enumerate(combinations)
        )

        # Run the sweep; executor.map yields results in submission order.
        html_parts = []
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            with tqdm(total=len(combinations), desc="参数搜索进度") as pbar:
                for snippet in executor.map(_process_combination_wrapper, tasks,
                                            chunksize=5):
                    html_parts.append(snippet)
                    pbar.update(1)

        # Write the report as UTF-8 explicitly: the markup contains non-ASCII
        # text and the platform default encoding may not be able to encode it.
        full_html = self._build_html_report(html_parts)
        report_path = output_path / "report.html"
        report_path.write_text(full_html, encoding="utf-8")
        print(f"参数搜索完成!结果保存在 {output_path.resolve()}")
        return report_path

    def _build_html_report(self, html_parts):
        """Wrap the per-combination HTML fragments in a complete HTML page.

        Args:
            html_parts: Iterable of HTML fragment strings; they are
                concatenated in order inside a ``<div class="grid">``.

        Returns:
            The full HTML document as a string. It declares UTF-8 so browsers
            decode the non-ASCII text in it correctly.
        """
        return f"""
        <html>
        <head>
            <meta charset="utf-8">
            <style>
                /* 保持之前的样式不变 */
                body {{ margin: 10px; background: #f5f5f5; }}
                .grid {{ /* 样式细节 */ }}
                /* 其他样式规则 */
            </style>
        </head>
        <body>
            <div class="grid">
                {''.join(html_parts)}
            </div>
        </body>
        </html>
        """

    def _generate_plot(self, labels):
        """Build a two-panel figure for a clustering result.

        The left panel draws every clustered print's bounding box and a
        ``C<id>`` tag on a white canvas; the right panel is a scatter of
        timestamp against x position per cluster. Noise (label ``-1``) is not
        drawn.

        Args:
            labels: Cluster labels aligned with ``self.gait_prints``.

        Returns:
            The matplotlib figure; the caller is responsible for closing it.

        Raises:
            ValueError: If there are no gait prints to plot.
        """
        if not self.gait_prints:
            raise ValueError("no gait prints to plot")

        # Canvas size: furthest box edge plus a 50-unit margin.
        max_x = max(p.x + p.w/2 for p in self.gait_prints) + 50
        max_y = max(p.y + p.h/2 for p in self.gait_prints) + 50
        canvas_size = (int(max_y), int(max_x), 3)

        # Group the prints by cluster once instead of rescanning per cluster.
        by_cluster = {}
        for p, lbl in zip(self.gait_prints, labels):
            if lbl != -1:
                by_cluster.setdefault(int(lbl), []).append(p)
        cluster_ids = sorted(by_cluster)
        cmap = plt.get_cmap('tab20')

        # Layout with a 3:2 width ratio.
        fig = plt.figure(figsize=(24, 10))
        try:
            gs = fig.add_gridspec(1, 2, width_ratios=[3, 2])

            # Panel 1: spatial distribution (wider).
            ax1 = fig.add_subplot(gs[0])
            spatial_canvas = np.ones(canvas_size, dtype=np.uint8) * 255

            for cluster_id in cluster_ids:
                # tab20 has 20 entries, so colours repeat every 20 clusters.
                color = np.array(cmap(cluster_id % 20)) * 255

                # Draw the box of every print in this cluster.
                for p in by_cluster[cluster_id]:
                    x = int(p.x - p.w/2)
                    y = int(p.y - p.h/2)
                    cv2.rectangle(spatial_canvas,
                                  (x, y),
                                  (x + int(p.w), y + int(p.h)),
                                  color.tolist(), 2)

                    # Tag the box with its cluster id.
                    label_pos = (x, y - 5)
                    cv2.putText(spatial_canvas, f"C{cluster_id}",
                                label_pos, cv2.FONT_HERSHEY_SIMPLEX,
                                0.5, color.tolist(), 1)

            ax1.imshow(spatial_canvas)
            ax1.set_title("Spatial Distribution (Clustered)", fontsize=14, pad=20)
            ax1.axis('off')

            # Panel 2: temporal distribution.
            ax2 = fig.add_subplot(gs[1])

            for cluster_id in cluster_ids:
                # Timestamps and x coordinates of this cluster.
                times = [p.timestamp for p in by_cluster[cluster_id]]
                x_coords = [p.x for p in by_cluster[cluster_id]]
                color = cmap(cluster_id % 20)

                ax2.scatter(times, x_coords, color=color, s=40,
                            label=f'Cluster {cluster_id}')

            ax2.set_xlabel('Time (s)', fontsize=12)
            ax2.set_ylabel('X Position', fontsize=12)
            ax2.set_title('Temporal Distribution', fontsize=14, pad=20)
            ax2.grid(True, linestyle='--', alpha=0.6)

            # Colour legend, placed outside the axes on the right.
            from matplotlib.patches import Patch
            legend_elements = [
                Patch(facecolor=cmap(i % 20), label=f'Cluster {i}')
                for i in cluster_ids
            ]
            ax2.legend(handles=legend_elements,
                       bbox_to_anchor=(1.05, 1),
                       loc='upper left',
                       title="Cluster ID")

            plt.tight_layout()
        except Exception:
            # Do not leave a half-built figure registered with pyplot.
            plt.close(fig)
            raise
        return fig

    def visualize_cluster_distribution(self, labels, output_path="cluster_distribution.png"):
        """Save a bar chart of the number of prints per cluster label.

        Args:
            labels: Cluster labels; every distinct value (noise included)
                gets one bar.
            output_path: Where the PNG is written.
        """
        unique, counts = np.unique(labels, return_counts=True)
        plt.figure(figsize=(10, 6))
        plt.bar(unique, counts)
        plt.xlabel('簇ID')
        plt.ylabel('足印数量')
        plt.title('簇分布直方图')
        plt.grid(axis='y')
        plt.savefig(output_path)
        plt.close()

def load_debug_data(json_path, fps=120.0):
    """Load pre-processed footprint data from a JSON file.

    The file must contain a top-level ``frames`` list. Each frame has a
    ``frameId`` and a ``footprints`` list; each footprint has ``confidence``,
    an optional ``type`` and a ``position`` with ``x``, ``y``, ``width`` and
    ``height``.

    Args:
        json_path: Path of the JSON file. It is read as UTF-8 rather than
            with the platform default encoding.
        fps: Frame rate used to convert a frame id into a timestamp
            (``frameId / fps``).

    Returns:
        A list of ``GaitPrint`` objects, one per footprint, in file order.
        ``x`` / ``y`` hold the box centre (``position`` x / y plus half the
        width / height) and ``paw_type`` falls back to ``'unknown'``.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)

    gait_prints = []
    for frame in data['frames']:
        frame_id = frame['frameId']
        timestamp = frame_id / fps
        for fp in frame['footprints']:
            box = fp['position']
            gait_prints.append(GaitPrint(
                frame_id=frame_id,
                # Shift from the box's x / y to its centre.
                x=box['x'] + box['width']/2,
                y=box['y'] + box['height']/2,
                w=box['width'],
                h=box['height'],
                conf=fp['confidence'],
                timestamp=timestamp,
                paw_type=fp.get('type', 'unknown')  # type label from the data
            ))
    return gait_prints

if __name__ == "__main__":
    # Command-line entry point: load the data, run an initial clustering and
    # then either sweep parameters or open the interactive prompt.

    # Put the parent of this file's directory on sys.path.
    # NOTE(review): this runs after the module-level imports at the top of the
    # file have already executed, so it cannot help those imports resolve.
    import sys
    from pathlib import Path
    root_dir = str(Path(__file__).parent.parent.resolve())
    if root_dir not in sys.path:
        sys.path.insert(0, root_dir)

    parser = argparse.ArgumentParser()
    parser.add_argument("--json", required=True, help="足印数据JSON文件路径")
    parser.add_argument("--batch-search", action="store_true",
                        help="执行批量参数搜索")
    parser.add_argument("--max-workers", type=int, default=0,
                        help="最大并行进程数(0=自动)")
    args = parser.parse_args()
    if args.max_workers < 0:
        parser.error("--max-workers 不能为负数")

    print("加载数据...")
    gait_prints = load_debug_data(args.json)
    if not gait_prints:
        # Nothing to cluster or plot; stop with a clear message.
        sys.exit(f"未在 {args.json} 中找到任何足印数据")

    debugger = ClusterDebugger(gait_prints)
    print("初始聚类...")
    labels = debugger.run_clustering()
    debugger.visualize(labels)

    if args.batch_search:
        # Forward the requested worker count; 0 means "choose automatically".
        debugger.batch_parameter_search(max_workers=args.max_workers or None)
    else:
        print("进入交互调试模式")
        debugger.interactive_adjust()