File size: 25,406 Bytes
eddf713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import collections
import tempfile
from ultralytics import YOLO
import math

class MouseTrackerAnalyzer:
    """基于Ultralytics对象跟踪的鼠强迫游泳实验挣扎度分析器"""
    def __init__(self, model_path, history_size=5, conf=0.25, iou=0.45, max_det=20, verbose=False):
        # 初始化模型和参数
        self.model = YOLO(model_path, task="segment", verbose=False)
        self.history_size = history_size
        self.verbose = verbose  # 控制日志输出级别
        self.struggle_threshold = 0.3  # 挣扎阈值
        
        # 跟踪相关参数
        self.conf = conf  # 置信度阈值
        self.iou = iou    # IOU阈值
        self.max_det = max_det  # 最大检测数量
        
        # 预设16种固定颜色 (BGR顺序)
        self.colors = [
            (255, 0, 0),    # 红
            (0, 255, 0),    # 绿
            (0, 0, 255),    # 蓝
            (255, 255, 0),  # 青
            (255, 0, 255),  # 洋红
            (0, 255, 255),  # 黄
            (128, 0, 0),    # 深红
            (128, 0, 128),  # 紫
            (0, 128, 128),  # 青绿
            (192, 192, 192),# 银
            (128, 128, 128),# 灰
            (255, 128, 0),  # 橙
            (255, 0, 128),  # 粉
            (0, 128, 255),  # 浅蓝
            (128, 255, 0),  # 黄绿
            (0, 255, 128)   # 浅绿
        ]
        # 追踪相关
        self.prev_masks = {}      # 上一帧各 ID 二值掩码
        self.histories = {}       # 各 ID 分数历史队列
        self.track_ids = set()    # 所有被跟踪的ID
        
        # 视频处理状态
        self.cap = None
        self.writer = None
        self.frame_id = 0
        self.results = []  # 存储每帧结果
        self.start_frame = 0
        self.end_frame = 0

    def init_video(self, video_path, output_path=None, start_frame=0, end_frame=None):
        """初始化视频处理"""
        # 打开视频并初始化写出器
        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            raise IOError(f"无法打开视频 {video_path}")
        
        # 获取视频属性
        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 30
        self.fps = max(fps, 1.0)  # 保存帧率到实例变量,确保至少为1
        total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        
        if self.verbose:
            print(f"视频尺寸: {width}x{height}, 帧率: {fps}, 总帧数: {total_frames}")
        
        # 设置帧范围
        self.start_frame = start_frame
        self.end_frame = end_frame if end_frame is not None else total_frames - 1
        
        # 确保帧范围有效
        if self.start_frame < 0:
            self.start_frame = 0
        if self.end_frame >= total_frames:
            self.end_frame = total_frames - 1
        if self.start_frame > self.end_frame:
            self.start_frame, self.end_frame = self.end_frame, self.start_frame
            
        # 将视频定位到起始帧
        if self.start_frame > 0:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.start_frame)
            
        # 如果输出为视频则初始化 VideoWriter
        if output_path and output_path.lower().endswith(('.mp4', '.avi')):
            # 使用标准编码器
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            # 创建VideoWriter
            self.writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            if self.writer.isOpened():
                print(f"成功创建输出视频: {output_path}, 尺寸: {width}x{height}")
            else:
                print(f"警告: 无法创建输出视频 {output_path}")
            
        # 重置状态
        self.frame_id = self.start_frame
        self.results = []
        self.prev_masks.clear()
        self.histories.clear()
        self.track_ids.clear()
        
        if self.verbose:
            print(f"视频初始化完成: 总帧数 {total_frames}, 分析范围 {self.start_frame}-{self.end_frame}")
        
        return total_frames, self.start_frame, self.end_frame

    def process_frame(self, frame, frame_id):
        """处理单帧,返回可视化帧和本帧结果列表"""
        if self.verbose and frame_id % 10 == 0:
            print(f"process_frame: 处理帧 {frame_id}")
            
        try:
            # 使用YOLO模型跟踪对象
            results = self.model.track(
                frame, 
                persist=True,  # 保持跟踪ID的持久性
                conf=self.conf, 
                iou=self.iou,
                max_det=self.max_det,
                verbose=False
            )
            
            # 检查是否有检测结果
            frame_results = []
            
            if results[0].boxes is None or len(results[0].boxes) == 0:
                if self.verbose and frame_id % 50 == 0:
                    print("没有检测到任何对象")
                return frame.copy(), []
                
            # 处理检测结果
            if hasattr(results[0], 'masks') and results[0].masks is not None:
                # 获取掩码和跟踪ID
                masks = results[0].masks.data.cpu().numpy()
                track_ids = results[0].boxes.id
                
                if track_ids is None:
                    if self.verbose and frame_id % 50 == 0:
                        print("没有获取到跟踪ID")
                    return frame.copy(), []
                    
                track_ids = track_ids.int().cpu().numpy()
                
                if self.verbose and frame_id % 50 == 0:
                    print(f"检测到 {len(masks)} 个掩码,{len(track_ids)} 个跟踪ID")
                
                # 更新跟踪ID集合
                for track_id in track_ids:
                    self.track_ids.add(int(track_id))
                    
                # 处理每个跟踪对象
                for i, (mask, track_id) in enumerate(zip(masks, track_ids)):
                    track_id = int(track_id)
                    
                    # 二值化掩码
                    bin_mask = (mask > 0.2).astype(np.uint8)
                    
                    # 应用形态学操作清理掩码
                    kernel = np.ones((5,5), np.uint8)
                    bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, kernel)
                    
                    # 调整掩码尺寸到与原始帧相同
                    if bin_mask.shape != (frame.shape[0], frame.shape[1]):
                        bin_mask = cv2.resize(bin_mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
                    
                    # 计算挣扎度
                    if track_id in self.prev_masks:
                        prev_mask = self.prev_masks[track_id]
                        # 确保比较的掩码尺寸一致
                        if prev_mask.shape != bin_mask.shape:
                            prev_mask = cv2.resize(prev_mask, (bin_mask.shape[1], bin_mask.shape[0]), interpolation=cv2.INTER_NEAREST)
                        inter = np.logical_and(prev_mask > 0, bin_mask > 0).sum()
                        union = np.logical_or(prev_mask > 0, bin_mask > 0).sum()
                        iou = inter / union if union > 0 else 0
                        score = 1 - iou
                        if self.verbose and frame_id % 50 == 0:
                            print(f"跟踪ID {track_id} 挣扎分数: {score:.4f} (IoU: {iou:.4f})")
                    else:
                        score = 0.0
                        if self.verbose and frame_id % 50 == 0:
                            print(f"跟踪ID {track_id} 初始帧,分数为0")
                            
                    # 保存当前掩码和历史
                    self.prev_masks[track_id] = bin_mask
                    
                    if track_id not in self.histories:
                        self.histories[track_id] = collections.deque(maxlen=self.history_size)
                    self.histories[track_id].append(score)
                    
                    # 计算挣扎状态
                    is_struggling = score >= self.struggle_threshold
                    
                    # 计算质心
                    ys, xs = np.where(bin_mask > 0)
                    if len(xs) > 0:
                        centroid = (int(xs.mean()), int(ys.mean()))
                    else:
                        # 如果掩码为空,使用边界框中心点
                        box = results[0].boxes[i].xyxy.cpu().numpy()[0]
                        centroid = (int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2))
                    
                    # 添加到帧结果
                    frame_results.append({
                        'id': track_id,
                        'score': float(score),
                        'centroid': centroid,
                        'is_struggling': is_struggling
                    })
            else:
                if self.verbose and frame_id % 50 == 0:
                    print("没有检测到任何掩码")
                return frame.copy(), []
                
            # 可视化 - 在这里创建最终的标注帧
            annotated = frame.copy()
            
            # 绘制掩码和ID
            for result in frame_results:
                track_id = result['id']
                color = self.colors[track_id % len(self.colors)]
                
                # 绘制掩码
                if track_id in self.prev_masks:
                    mask = self.prev_masks[track_id]
                    # 确保掩码与帧大小一致
                    if mask.shape != (frame.shape[0], frame.shape[1]):
                        mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
                    mask_overlay = np.zeros_like(frame)
                    mask_overlay[mask > 0] = color
                    
                    # 使用更精确的掩码边缘
                    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                    cv2.drawContours(annotated, contours, -1, color, 2)
                    
                    # 使用addWeighted进行混合
                    cv2.addWeighted(annotated, 1.0, mask_overlay, 0.4, 0, annotated)
                
                # 在质心位置绘制ID和挣扎状态
                centroid = result['centroid']
                status_text = "Struggle" if result['is_struggling'] else "Static"
                cv2.putText(annotated, f"ID:{track_id} {status_text}", 
                           (centroid[0], centroid[1]), 
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
            
            # 在顶部创建黑色半透明条,显示总结信息
            cv2.rectangle(annotated, (0, 0), (frame.shape[1], 40), (0, 0, 0), -1)
            
            # 计算挣扎中的老鼠数量
            struggling_count = sum(1 for r in frame_results if r['is_struggling'])
            total_count = len(frame_results)
            
            # 显示统计信息
            cv2.putText(annotated, f"Total: {total_count} Struggling: {struggling_count}", 
                       (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            
            # 最后,由于OpenCV以BGR格式工作,但可能需要RGB格式,
            # 确保返回的图像是BGR格式(视频写入用BGR,显示用RGB)
            if annotated.dtype != np.uint8:
                annotated = annotated.astype(np.uint8)
                
            return annotated, frame_results
            
        except Exception as e:
            import traceback
            if self.verbose:
                print(f"处理帧时出错: {str(e)}")
                traceback.print_exc()
            # 返回原始帧和空结果
            return frame.copy(), []

    def process_video(self, video_path, output_path=None, start_frame=0, end_frame=None, callback=None):
        """处理整段视频,可选的回调函数用于更新进度"""
        # 初始化视频
        total_frames, start, end = self.init_video(video_path, output_path, start_frame, end_frame)
        self.results = []  # 确保结果列表被清空
        
        frame_id = start
        processed_frames = 0
        frames_to_process = end - start + 1
        last_progress = -1
        
        # 临时保存一帧,用于调试
        debug_frame_saved = False
        
        while frame_id <= end:
            ret, frame = self.cap.read()
            if not ret:
                break
                
            # 处理当前帧
            annotated, frame_res = self.process_frame(frame, frame_id)
            self.results.append(frame_res)  # 将当前帧结果存入results列表
            
            # 保存第一帧用于调试
            if not debug_frame_saved and len(frame_res) > 0:
                debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
                cv2.imwrite(debug_frame_path, annotated)
                print(f"调试: 保存了标注帧到 {debug_frame_path}")
                debug_frame_saved = True
            
            # 写入输出视频
            if self.writer:
                # 确保帧是BGR格式
                if len(annotated.shape) == 3 and annotated.shape[2] == 3:
                    # 如果需要,将RGB转换回BGR (OpenCV使用BGR)
                    # 默认应该已经是BGR,但为了确保
                    if frame_id == start:
                        print(f"调试: 写入标注帧到视频,形状: {annotated.shape}")
                    
                    try:
                        self.writer.write(annotated)
                    except Exception as e:
                        print(f"调试: 写入帧到视频时出错: {str(e)}")
                        import traceback
                        traceback.print_exc()
            
            # 更新进度和回调
            processed_frames += 1
            progress = int(100 * processed_frames / frames_to_process)
            
            if progress != last_progress and callback:
                callback(progress, annotated, frame_res)
                last_progress = progress
                
            frame_id += 1
            
        # 释放资源
        self.cap.release()
        if self.writer:
            self.writer.release()
            print(f"调试: 视频写入完成,保存到: {output_path}")
            
        return self.results
    
    def save_results(self, csv_path):
        """导出分析结果到 CSV"""
        import csv
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['frame_id', 'mouse_id', 'score', 'is_struggling'])
            for fid, frs in enumerate(self.results):
                for fr in frs:
                    writer.writerow([
                        fid + self.start_frame, 
                        fr['id'], 
                        f"{fr['score']:.4f}", 
                        1 if fr.get('is_struggling', False) else 0
                    ])
    
    def generate_time_series_plot(self, threshold=None):
        """生成时序图分析"""
        try:
            print(f"Starting to generate time series plot with {len(self.results)} frames of data")
            
            if not self.results or len(self.results) < 10:
                print("Not enough data for time series plot (need at least 10 frames)")
                return None
                
            # 使用传入的阈值或默认阈值
            if threshold is None:
                threshold = self.struggle_threshold
            
            # 使用保存的帧率,确保不会出现除以零的情况
            fps = getattr(self, 'fps', None)
            if fps is None or fps <= 0:
                fps = 30  # 使用默认帧率
                print(f"Warning: Invalid frame rate detected, using default: {fps} fps")
            else:
                print(f"Using frame rate: {fps} fps")
                
            # 处理数据
            frames = []
            mouse_data = {}
            mouse_positions = {}  # 用于存储每只老鼠的平均X坐标
            
            for frame_id, frame_results in enumerate(self.results):
                frames.append(frame_id + self.start_frame)  # 使用真实帧号
                for result in frame_results:
                    mouse_id = result['id']
                    if mouse_id not in mouse_data:
                        mouse_data[mouse_id] = {'frames': [], 'seconds': [], 'scores': [], 'struggling': []}
                        mouse_positions[mouse_id] = []  # 初始化X坐标列表
                        
                    frame_num = frame_id + self.start_frame
                    second = frame_num / fps  # 转换为秒
                    
                    mouse_data[mouse_id]['frames'].append(frame_num)
                    mouse_data[mouse_id]['seconds'].append(second)
                    mouse_data[mouse_id]['scores'].append(result['score'])
                    mouse_data[mouse_id]['struggling'].append(1 if result.get('is_struggling', False) else 0)
                    
                    # 记录质心的X坐标
                    if 'centroid' in result:
                        mouse_positions[mouse_id].append(result['centroid'][0])
            
            print(f"Processed data for {len(mouse_data)} mice")
            if not mouse_data:
                print("No valid mouse data to plot")
                return None
                
            # 计算每只老鼠的平均X坐标并按从左到右排序
            avg_positions = {}
            for mouse_id, positions in mouse_positions.items():
                if positions:
                    avg_positions[mouse_id] = sum(positions) / len(positions)
                else:
                    avg_positions[mouse_id] = float('inf')  # 如果没有位置数据,放到最后
                    
            # 按从左到右排序老鼠ID
            sorted_mice = sorted(mouse_data.keys(), key=lambda mid: avg_positions.get(mid, float('inf')))
            print(f"Mice sorted from left to right: {sorted_mice}")
            
            # 对数据进行平滑处理
            def smooth_data(data, window_size=5):
                """使用移动平均平滑数据"""
                if len(data) < window_size:
                    return data
                smoothed = []
                for i in range(len(data)):
                    start = max(0, i - window_size // 2)
                    end = min(len(data), i + window_size // 2 + 1)
                    window = data[start:end]
                    smoothed.append(sum(window) / len(window))
                return smoothed
            
            # 创建子图
            num_mice = len(mouse_data)
            fig, axes = plt.subplots(num_mice, 1, figsize=(12, 4*num_mice), sharex=True)
            
            # 如果只有一只鼠,确保axes是列表
            if num_mice == 1:
                axes = [axes]
            
            # 绘制每只老鼠的挣扎得分曲线,按从左到右的顺序
            for idx, mouse_id in enumerate(sorted_mice):
                data = mouse_data[mouse_id]
                ax = axes[idx]
                
                # 平滑数据
                smoothed_scores = smooth_data(data['scores'], window_size=5)
                
                # 绘制曲线
                ax.plot(data['seconds'], smoothed_scores, label=f"Smoothed", color='blue', linewidth=2)
                ax.plot(data['seconds'], data['scores'], label=f"Raw", color='lightblue', alpha=0.5, linewidth=1)
                
                # 标记挣扎区域
                for i, is_struggling in enumerate(data['struggling']):
                    if is_struggling:
                        ax.axvspan(data['seconds'][i]-0.5/fps, data['seconds'][i]+0.5/fps, alpha=0.1, color='red')
                
                # 绘制阈值线
                ax.axhline(y=threshold, color='r', linestyle='--', label=f"Threshold ({threshold:.2f})")
                
                # 设置图表
                ax.set_ylabel('Struggle Score')
                position_text = f"(Position: Left #{sorted_mice.index(mouse_id)+1})" if mouse_id in avg_positions else ""
                ax.set_title(f'Mouse {mouse_id} Struggle Score {position_text}')
                ax.legend(loc='upper right')
                ax.grid(True)
                
                # 设置Y轴范围0-1
                ax.set_ylim(-0.05, 1.05)
            
            # 设置共享的X轴标签
            axes[-1].set_xlabel('Time (seconds)')
            
            # 动态调整x轴范围,精确到0.1秒
            if frames:
                start_time = self.start_frame / fps
                end_time = max(frames) / fps
                # 扩展一点范围以便更好地显示
                axes[-1].set_xlim(start_time, end_time)
                
                # 设置次要刻度(细网格线)
                tick_interval = 0.1  # 保持0.1秒的细网格
                minor_ticks = np.arange(start_time, end_time + tick_interval, tick_interval)
                axes[-1].set_xticks(minor_ticks, minor=True)
                
                # 设置主要刻度(标签和粗网格线)- 整秒
                major_start = math.ceil(start_time)
                major_end = math.floor(end_time)
                major_ticks = np.arange(major_start, major_end + 1, 1.0)  # 整秒刻度
                axes[-1].set_xticks(major_ticks)
                axes[-1].set_xticklabels([f"{int(t)}" for t in major_ticks])  # 整数秒标签
                
                # 设置网格
                axes[-1].grid(True, which='both')
                axes[-1].grid(which='minor', alpha=0.2)
                axes[-1].grid(which='major', alpha=0.5)
            
            plt.tight_layout()
            
            # 保存图表到临时文件并返回路径
            temp_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
            plt.savefig(temp_file.name, dpi=150, bbox_inches='tight')
            plt.close()
            
            print(f"Time series plot saved to: {temp_file.name}")
            return temp_file.name
            
        except Exception as e:
            import traceback
            print(f"Error generating time series plot: {str(e)}")
            traceback.print_exc()
            return None

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description="鼠强迫游泳实验挣扎度分析")
    parser.add_argument('--video', type=str, required=True, help='输入视频路径')
    parser.add_argument('--model', type=str, required=True, help='模型文件路径')
    parser.add_argument('--output', type=str, help='输出视频路径')
    parser.add_argument('--csv', type=str, help='输出CSV结果路径')
    parser.add_argument('--conf', type=float, default=0.25, help='置信度阈值')
    parser.add_argument('--iou', type=float, default=0.45, help='IOU阈值')
    parser.add_argument('--max-det', type=int, default=20, help='最大检测数量')
    parser.add_argument('--threshold', type=float, default=0.3, help='挣扎阈值')
    parser.add_argument('--start', type=int, default=0, help='起始帧')
    parser.add_argument('--end', type=int, default=None, help='结束帧')
    parser.add_argument('--verbose', action='store_true', help='详细输出')
    
    args = parser.parse_args()
    
    # 设置输出路径
    if not args.output:
        video_name = os.path.splitext(os.path.basename(args.video))[0]
        args.output = os.path.join(os.path.dirname(args.video), f"{video_name}_out.mp4")
    
    if not args.csv:
        video_name = os.path.splitext(os.path.basename(args.video))[0]
        args.csv = os.path.join(os.path.dirname(args.video), f"{video_name}_results.csv")
    
    # 创建分析器并处理
    analyzer = MouseTrackerAnalyzer(
        model_path=args.model,
        conf=args.conf,
        iou=args.iou,
        max_det=args.max_det,
        verbose=args.verbose
    )
    analyzer.struggle_threshold = args.threshold
    
    # 进度回调函数
    def progress_callback(progress, frame, results):
        print(f"处理进度: {progress}%, 检测到 {len(results)} 个对象")
    
    # 处理视频
    analyzer.process_video(
        video_path=args.video,
        output_path=args.output,
        start_frame=args.start,
        end_frame=args.end,
        callback=progress_callback
    )
    
    # 保存结果
    analyzer.save_results(args.csv)
    
    # 生成分析图表
    plot_path = analyzer.generate_time_series_plot()
    if plot_path:
        print(f"挣扎度时序分析图已保存到: {plot_path}")
    
    print(f"分析完成,视频已保存到: {args.output}")
    print(f"结果数据已保存到: {args.csv}")