import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import collections
import tempfile
from ultralytics import YOLO
import math


class MouseTrackerAnalyzer:
    """Struggle-level analyzer for the mouse forced-swim test, based on
    Ultralytics object tracking.

    Per frame, each tracked mouse gets a "struggle score" defined as
    1 - IoU between its segmentation mask on this frame and the previous
    frame: a mouse that moves a lot overlaps little with itself, yielding
    a high score.
    """

    def __init__(self, model_path, history_size=5, conf=0.25, iou=0.45,
                 max_det=20, verbose=False):
        """Initialize the model and tracking parameters.

        Args:
            model_path: path to a YOLO segmentation model file.
            history_size: length of the per-ID score history deque.
            conf: detection confidence threshold.
            iou: NMS IoU threshold for detection.
            max_det: maximum number of detections per frame.
            verbose: enable detailed progress logging.
        """
        self.model = YOLO(model_path, task="segment", verbose=False)
        self.history_size = history_size
        self.verbose = verbose          # controls logging verbosity
        self.struggle_threshold = 0.3   # score >= threshold => "struggling"

        # Tracking parameters forwarded to model.track()
        self.conf = conf
        self.iou = iou
        self.max_det = max_det

        # 16 fixed colors (BGR order) cycled by track ID
        self.colors = [
            (255, 0, 0),     # red
            (0, 255, 0),     # green
            (0, 0, 255),     # blue
            (255, 255, 0),   # cyan
            (255, 0, 255),   # magenta
            (0, 255, 255),   # yellow
            (128, 0, 0),     # dark red
            (128, 0, 128),   # purple
            (0, 128, 128),   # teal
            (192, 192, 192), # silver
            (128, 128, 128), # gray
            (255, 128, 0),   # orange
            (255, 0, 128),   # pink
            (0, 128, 255),   # light blue
            (128, 255, 0),   # yellow-green
            (0, 255, 128),   # light green
        ]

        # Tracking state
        self.prev_masks = {}    # previous-frame binary mask per track ID
        self.histories = {}     # per-ID deque of recent struggle scores
        self.track_ids = set()  # all track IDs ever seen

        # Video-processing state
        self.cap = None
        self.writer = None
        self.frame_id = 0
        self.results = []       # per-frame result lists
        self.start_frame = 0
        self.end_frame = 0

    def init_video(self, video_path, output_path=None, start_frame=0,
                   end_frame=None):
        """Open the input video, clamp the frame range, and (optionally)
        create the output VideoWriter.

        Returns:
            (total_frames, start_frame, end_frame) after clamping.

        Raises:
            IOError: if the input video cannot be opened.
        """
        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            raise IOError(f"无法打开视频 {video_path}")

        # Probe video properties
        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 30
        self.fps = max(fps, 1.0)  # keep fps on the instance, at least 1
        total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if self.verbose:
            print(f"视频尺寸: {width}x{height}, 帧率: {fps}, 总帧数: {total_frames}")

        # Resolve and clamp the requested frame range
        self.start_frame = start_frame
        self.end_frame = end_frame if end_frame is not None else total_frames - 1
        if self.start_frame < 0:
            self.start_frame = 0
        if self.end_frame >= total_frames:
            self.end_frame = total_frames - 1
        if self.start_frame > self.end_frame:
            # Swap rather than fail if the caller inverted the range
            self.start_frame, self.end_frame = self.end_frame, self.start_frame

        # Seek to the starting frame
        if self.start_frame > 0:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.start_frame)

        # Create the VideoWriter only for recognized video extensions
        if output_path and output_path.lower().endswith(('.mp4', '.avi')):
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            self.writer = cv2.VideoWriter(output_path, fourcc, fps,
                                          (width, height))
            if self.writer.isOpened():
                print(f"成功创建输出视频: {output_path}, 尺寸: {width}x{height}")
            else:
                print(f"警告: 无法创建输出视频 {output_path}")

        # Reset all per-run state
        self.frame_id = self.start_frame
        self.results = []
        self.prev_masks.clear()
        self.histories.clear()
        self.track_ids.clear()

        if self.verbose:
            print(f"视频初始化完成: 总帧数 {total_frames}, 分析范围 {self.start_frame}-{self.end_frame}")
        return total_frames, self.start_frame, self.end_frame

    def process_frame(self, frame, frame_id):
        """Track objects in one frame and compute struggle scores.

        Returns:
            (annotated_frame, frame_results) where frame_results is a list
            of dicts with keys 'id', 'score', 'centroid', 'is_struggling'.
            On any failure the original frame and an empty list are returned.
        """
        if self.verbose and frame_id % 10 == 0:
            print(f"process_frame: 处理帧 {frame_id}")
        try:
            # Run the YOLO tracker; persist=True keeps IDs stable across frames
            results = self.model.track(
                frame,
                persist=True,
                conf=self.conf,
                iou=self.iou,
                max_det=self.max_det,
                verbose=False
            )

            frame_results = []
            if results[0].boxes is None or len(results[0].boxes) == 0:
                if self.verbose and frame_id % 50 == 0:
                    print("没有检测到任何对象")
                return frame.copy(), []

            if hasattr(results[0], 'masks') and results[0].masks is not None:
                # Pull masks and track IDs over to the CPU as numpy arrays
                masks = results[0].masks.data.cpu().numpy()
                track_ids = results[0].boxes.id
                if track_ids is None:
                    if self.verbose and frame_id % 50 == 0:
                        print("没有获取到跟踪ID")
                    return frame.copy(), []
                track_ids = track_ids.int().cpu().numpy()
                if self.verbose and frame_id % 50 == 0:
                    print(f"检测到 {len(masks)} 个掩码,{len(track_ids)} 个跟踪ID")

                # Remember every ID we have seen
                for track_id in track_ids:
                    self.track_ids.add(int(track_id))

                # Loop-invariant structuring element (was rebuilt per object)
                kernel = np.ones((5, 5), np.uint8)

                for i, (mask, track_id) in enumerate(zip(masks, track_ids)):
                    track_id = int(track_id)
                    # Binarize and clean up the mask
                    bin_mask = (mask > 0.2).astype(np.uint8)
                    bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE,
                                                kernel)
                    # Resize the mask to the original frame size if needed
                    if bin_mask.shape != (frame.shape[0], frame.shape[1]):
                        bin_mask = cv2.resize(
                            bin_mask, (frame.shape[1], frame.shape[0]),
                            interpolation=cv2.INTER_NEAREST)

                    # Struggle score = 1 - IoU with the previous frame's mask
                    if track_id in self.prev_masks:
                        prev_mask = self.prev_masks[track_id]
                        # Ensure both masks have identical shapes
                        if prev_mask.shape != bin_mask.shape:
                            prev_mask = cv2.resize(
                                prev_mask,
                                (bin_mask.shape[1], bin_mask.shape[0]),
                                interpolation=cv2.INTER_NEAREST)
                        inter = np.logical_and(prev_mask > 0, bin_mask > 0).sum()
                        union = np.logical_or(prev_mask > 0, bin_mask > 0).sum()
                        iou = inter / union if union > 0 else 0
                        score = 1 - iou
                        if self.verbose and frame_id % 50 == 0:
                            print(f"跟踪ID {track_id} 挣扎分数: {score:.4f} (IoU: {iou:.4f})")
                    else:
                        # First sighting of this ID: no motion baseline yet
                        score = 0.0
                        if self.verbose and frame_id % 50 == 0:
                            print(f"跟踪ID {track_id} 初始帧,分数为0")

                    # Store the current mask and append to the score history
                    self.prev_masks[track_id] = bin_mask
                    if track_id not in self.histories:
                        self.histories[track_id] = collections.deque(
                            maxlen=self.history_size)
                    self.histories[track_id].append(score)

                    is_struggling = score >= self.struggle_threshold

                    # Centroid from the mask pixels; fall back to the
                    # bounding-box center when the mask is empty
                    ys, xs = np.where(bin_mask > 0)
                    if len(xs) > 0:
                        centroid = (int(xs.mean()), int(ys.mean()))
                    else:
                        box = results[0].boxes[i].xyxy.cpu().numpy()[0]
                        centroid = (int((box[0] + box[2]) / 2),
                                    int((box[1] + box[3]) / 2))

                    frame_results.append({
                        'id': track_id,
                        'score': float(score),
                        'centroid': centroid,
                        'is_struggling': is_struggling
                    })
            else:
                if self.verbose and frame_id % 50 == 0:
                    print("没有检测到任何掩码")
                return frame.copy(), []

            # ---- Visualization: draw masks, IDs and a summary banner ----
            annotated = frame.copy()
            for result in frame_results:
                track_id = result['id']
                color = self.colors[track_id % len(self.colors)]
                if track_id in self.prev_masks:
                    mask = self.prev_masks[track_id]
                    # Keep mask size consistent with the frame
                    if mask.shape != (frame.shape[0], frame.shape[1]):
                        mask = cv2.resize(
                            mask, (frame.shape[1], frame.shape[0]),
                            interpolation=cv2.INTER_NEAREST)
                    mask_overlay = np.zeros_like(frame)
                    mask_overlay[mask > 0] = color
                    # Outline the mask, then alpha-blend the fill
                    contours, _ = cv2.findContours(
                        mask.astype(np.uint8), cv2.RETR_EXTERNAL,
                        cv2.CHAIN_APPROX_SIMPLE)
                    cv2.drawContours(annotated, contours, -1, color, 2)
                    cv2.addWeighted(annotated, 1.0, mask_overlay, 0.4, 0,
                                    annotated)

                # Label ID and struggle state at the centroid
                centroid = result['centroid']
                status_text = "Struggle" if result['is_struggling'] else "Static"
                cv2.putText(annotated, f"ID:{track_id} {status_text}",
                            (centroid[0], centroid[1]),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

            # Black banner with summary statistics at the top
            cv2.rectangle(annotated, (0, 0), (frame.shape[1], 40),
                          (0, 0, 0), -1)
            struggling_count = sum(1 for r in frame_results
                                   if r['is_struggling'])
            total_count = len(frame_results)
            cv2.putText(annotated,
                        f"Total: {total_count} Struggling: {struggling_count}",
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                        (255, 255, 255), 2)

            # Returned frame stays BGR (what VideoWriter expects)
            if annotated.dtype != np.uint8:
                annotated = annotated.astype(np.uint8)
            return annotated, frame_results

        except Exception as e:
            import traceback
            if self.verbose:
                print(f"处理帧时出错: {str(e)}")
                traceback.print_exc()
            # Best-effort: fall back to the raw frame with no detections
            return frame.copy(), []

    def process_video(self, video_path, output_path=None, start_frame=0,
                      end_frame=None, callback=None):
        """Process a full video; an optional callback reports progress.

        Args:
            callback: called as callback(progress_pct, annotated, frame_res)
                each time the integer progress percentage changes.

        Returns:
            list of per-frame result lists (also stored in self.results).
        """
        total_frames, start, end = self.init_video(video_path, output_path,
                                                   start_frame, end_frame)
        self.results = []  # make sure the results list starts empty
        frame_id = start
        processed_frames = 0
        frames_to_process = end - start + 1
        last_progress = -1
        debug_frame_saved = False  # save one annotated frame for debugging

        while frame_id <= end:
            ret, frame = self.cap.read()
            if not ret:
                break

            annotated, frame_res = self.process_frame(frame, frame_id)
            self.results.append(frame_res)

            # Save the first frame that has detections, for debugging.
            # BUGFIX: output_path may be None (it defaults to None), in
            # which case os.path.dirname(None) raised TypeError.
            if not debug_frame_saved and len(frame_res) > 0:
                if output_path:
                    debug_frame_path = os.path.join(
                        os.path.dirname(output_path), "debug_frame.jpg")
                    cv2.imwrite(debug_frame_path, annotated)
                    print(f"调试: 保存了标注帧到 {debug_frame_path}")
                debug_frame_saved = True

            # Write the annotated frame to the output video
            if self.writer:
                # VideoWriter expects 3-channel BGR
                if len(annotated.shape) == 3 and annotated.shape[2] == 3:
                    if frame_id == start:
                        print(f"调试: 写入标注帧到视频,形状: {annotated.shape}")
                    try:
                        self.writer.write(annotated)
                    except Exception as e:
                        print(f"调试: 写入帧到视频时出错: {str(e)}")
                        import traceback
                        traceback.print_exc()

            # Progress reporting; fire the callback only on percentage change
            processed_frames += 1
            progress = int(100 * processed_frames / frames_to_process)
            if progress != last_progress and callback:
                callback(progress, annotated, frame_res)
                last_progress = progress
            frame_id += 1

        # Release resources
        self.cap.release()
        if self.writer:
            self.writer.release()
            print(f"调试: 视频写入完成,保存到: {output_path}")
        return self.results

    def save_results(self, csv_path):
        """Export the per-frame analysis results to a CSV file."""
        import csv
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['frame_id', 'mouse_id', 'score', 'is_struggling'])
            for fid, frs in enumerate(self.results):
                for fr in frs:
                    writer.writerow([
                        fid + self.start_frame,  # absolute frame number
                        fr['id'],
                        f"{fr['score']:.4f}",
                        1 if fr.get('is_struggling', False) else 0
                    ])

    def generate_time_series_plot(self, threshold=None):
        """Generate a per-mouse time-series plot of struggle scores.

        Args:
            threshold: struggle threshold to draw; defaults to
                self.struggle_threshold.

        Returns:
            path of the saved PNG temp file, or None on failure / too
            little data (< 10 frames).
        """
        try:
            print(f"Starting to generate time series plot with {len(self.results)} frames of data")
            if not self.results or len(self.results) < 10:
                print("Not enough data for time series plot (need at least 10 frames)")
                return None

            if threshold is None:
                threshold = self.struggle_threshold

            # Use the stored frame rate; never divide by zero
            fps = getattr(self, 'fps', None)
            if fps is None or fps <= 0:
                fps = 30  # default frame rate fallback
                print(f"Warning: Invalid frame rate detected, using default: {fps} fps")
            else:
                print(f"Using frame rate: {fps} fps")

            # Collect per-mouse time series
            frames = []
            mouse_data = {}
            mouse_positions = {}  # mean X coordinate per mouse (for ordering)
            for frame_id, frame_results in enumerate(self.results):
                frames.append(frame_id + self.start_frame)  # absolute frame no.
                for result in frame_results:
                    mouse_id = result['id']
                    if mouse_id not in mouse_data:
                        mouse_data[mouse_id] = {'frames': [], 'seconds': [],
                                                'scores': [], 'struggling': []}
                        mouse_positions[mouse_id] = []
                    frame_num = frame_id + self.start_frame
                    second = frame_num / fps  # convert to seconds
                    mouse_data[mouse_id]['frames'].append(frame_num)
                    mouse_data[mouse_id]['seconds'].append(second)
                    mouse_data[mouse_id]['scores'].append(result['score'])
                    mouse_data[mouse_id]['struggling'].append(
                        1 if result.get('is_struggling', False) else 0)
                    # Record the centroid's X coordinate
                    if 'centroid' in result:
                        mouse_positions[mouse_id].append(result['centroid'][0])

            print(f"Processed data for {len(mouse_data)} mice")
            if not mouse_data:
                print("No valid mouse data to plot")
                return None

            # Order mice left-to-right by mean X position
            avg_positions = {}
            for mouse_id, positions in mouse_positions.items():
                if positions:
                    avg_positions[mouse_id] = sum(positions) / len(positions)
                else:
                    avg_positions[mouse_id] = float('inf')  # unknown -> last
            sorted_mice = sorted(
                mouse_data.keys(),
                key=lambda mid: avg_positions.get(mid, float('inf')))
            print(f"Mice sorted from left to right: {sorted_mice}")

            def smooth_data(data, window_size=5):
                """Smooth data with a centered moving average."""
                if len(data) < window_size:
                    return data
                smoothed = []
                for i in range(len(data)):
                    start = max(0, i - window_size // 2)
                    end = min(len(data), i + window_size // 2 + 1)
                    window = data[start:end]
                    smoothed.append(sum(window) / len(window))
                return smoothed

            # One subplot per mouse, shared X axis
            num_mice = len(mouse_data)
            fig, axes = plt.subplots(num_mice, 1, figsize=(12, 4 * num_mice),
                                     sharex=True)
            if num_mice == 1:
                axes = [axes]  # make axes indexable for a single mouse

            # Plot each mouse's curve in left-to-right order
            for idx, mouse_id in enumerate(sorted_mice):
                data = mouse_data[mouse_id]
                ax = axes[idx]
                smoothed_scores = smooth_data(data['scores'], window_size=5)
                ax.plot(data['seconds'], smoothed_scores, label=f"Smoothed",
                        color='blue', linewidth=2)
                ax.plot(data['seconds'], data['scores'], label=f"Raw",
                        color='lightblue', alpha=0.5, linewidth=1)

                # Shade struggling intervals in red
                for i, is_struggling in enumerate(data['struggling']):
                    if is_struggling:
                        ax.axvspan(data['seconds'][i] - 0.5 / fps,
                                   data['seconds'][i] + 0.5 / fps,
                                   alpha=0.1, color='red')

                # Threshold line
                ax.axhline(y=threshold, color='r', linestyle='--',
                           label=f"Threshold ({threshold:.2f})")

                ax.set_ylabel('Struggle Score')
                position_text = (f"(Position: Left #{sorted_mice.index(mouse_id)+1})"
                                 if mouse_id in avg_positions else "")
                ax.set_title(f'Mouse {mouse_id} Struggle Score {position_text}')
                ax.legend(loc='upper right')
                ax.grid(True)
                ax.set_ylim(-0.05, 1.05)  # scores live in [0, 1]

            axes[-1].set_xlabel('Time (seconds)')

            # X-axis range with 0.1 s minor ticks and integer-second majors
            if frames:
                start_time = self.start_frame / fps
                end_time = max(frames) / fps
                axes[-1].set_xlim(start_time, end_time)

                tick_interval = 0.1  # fine grid every 0.1 s
                minor_ticks = np.arange(start_time,
                                        end_time + tick_interval,
                                        tick_interval)
                axes[-1].set_xticks(minor_ticks, minor=True)

                major_start = math.ceil(start_time)
                major_end = math.floor(end_time)
                major_ticks = np.arange(major_start, major_end + 1, 1.0)
                axes[-1].set_xticks(major_ticks)
                axes[-1].set_xticklabels([f"{int(t)}" for t in major_ticks])

                axes[-1].grid(True, which='both')
                axes[-1].grid(which='minor', alpha=0.2)
                axes[-1].grid(which='major', alpha=0.5)

            plt.tight_layout()

            # Save to a temp file and hand back its path
            temp_file = tempfile.NamedTemporaryFile(suffix='.png',
                                                    delete=False)
            plt.savefig(temp_file.name, dpi=150, bbox_inches='tight')
            plt.close()
            print(f"Time series plot saved to: {temp_file.name}")
            return temp_file.name

        except Exception as e:
            import traceback
            print(f"Error generating time series plot: {str(e)}")
            traceback.print_exc()
            return None


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="鼠强迫游泳实验挣扎度分析")
    parser.add_argument('--video', type=str, required=True, help='输入视频路径')
    parser.add_argument('--model', type=str, required=True, help='模型文件路径')
    parser.add_argument('--output', type=str, help='输出视频路径')
    parser.add_argument('--csv', type=str, help='输出CSV结果路径')
    parser.add_argument('--conf', type=float, default=0.25, help='置信度阈值')
    parser.add_argument('--iou', type=float, default=0.45, help='IOU阈值')
    parser.add_argument('--max-det', type=int, default=20, help='最大检测数量')
    parser.add_argument('--threshold', type=float, default=0.3, help='挣扎阈值')
    parser.add_argument('--start', type=int, default=0, help='起始帧')
    parser.add_argument('--end', type=int, default=None, help='结束帧')
    parser.add_argument('--verbose', action='store_true', help='详细输出')
    args = parser.parse_args()

    # Derive default output paths next to the input video
    if not args.output:
        video_name = os.path.splitext(os.path.basename(args.video))[0]
        args.output = os.path.join(os.path.dirname(args.video),
                                   f"{video_name}_out.mp4")
    if not args.csv:
        video_name = os.path.splitext(os.path.basename(args.video))[0]
        args.csv = os.path.join(os.path.dirname(args.video),
                                f"{video_name}_results.csv")

    # Build the analyzer and run the pipeline
    analyzer = MouseTrackerAnalyzer(
        model_path=args.model,
        conf=args.conf,
        iou=args.iou,
        max_det=args.max_det,
        verbose=args.verbose
    )
    analyzer.struggle_threshold = args.threshold

    def progress_callback(progress, frame, results):
        print(f"处理进度: {progress}%, 检测到 {len(results)} 个对象")

    analyzer.process_video(
        video_path=args.video,
        output_path=args.output,
        start_frame=args.start,
        end_frame=args.end,
        callback=progress_callback
    )

    analyzer.save_results(args.csv)

    plot_path = analyzer.generate_time_series_plot()
    if plot_path:
        print(f"挣扎度时序分析图已保存到: {plot_path}")
    print(f"分析完成,视频已保存到: {args.output}")
    print(f"结果数据已保存到: {args.csv}")