|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import pandas as pd |
|
|
import collections |
|
|
import tempfile |
|
|
from ultralytics import YOLO |
|
|
import math |
|
|
|
|
|
class MouseTrackerAnalyzer:
    """Struggle-score analyzer for the mouse forced-swim test, built on Ultralytics tracking.

    Each frame is run through a YOLO segmentation model in tracking mode. For every
    tracked mouse, the struggle score is ``1 - IoU`` between its mask in the current
    frame and its mask in the previous frame, so larger values mean more silhouette
    change. A mouse is flagged as struggling when the score reaches
    ``self.struggle_threshold``.
    """

    def __init__(self, model_path, history_size=5, conf=0.25, iou=0.45, max_det=20,
                 verbose=False, struggle_threshold=0.3):
        """Load the model and set analysis parameters.

        Args:
            model_path: Path to the Ultralytics segmentation weights.
            history_size: Length of the per-track score history deque.
            conf: Detection confidence threshold passed to ``model.track``.
            iou: NMS IoU threshold passed to ``model.track``.
            max_det: Maximum detections per frame passed to ``model.track``.
            verbose: Print periodic diagnostic messages.
            struggle_threshold: Score at or above which a mouse counts as struggling.
        """
        self.model = YOLO(model_path, task="segment", verbose=False)
        self.history_size = history_size
        self.verbose = verbose
        self.struggle_threshold = struggle_threshold

        # Inference parameters forwarded to model.track().
        self.conf = conf
        self.iou = iou
        self.max_det = max_det

        # Per-track drawing colours, chosen by track_id modulo the palette length.
        self.colors = [
            (255, 0, 0),
            (0, 255, 0),
            (0, 0, 255),
            (255, 255, 0),
            (255, 0, 255),
            (0, 255, 255),
            (128, 0, 0),
            (128, 0, 128),
            (0, 128, 128),
            (192, 192, 192),
            (128, 128, 128),
            (255, 128, 0),
            (255, 0, 128),
            (0, 128, 255),
            (128, 255, 0),
            (0, 255, 128),
        ]

        # Per-video tracking state, cleared by init_video().
        self.prev_masks = {}   # track_id -> binary mask from the most recent frame
        self.histories = {}    # track_id -> deque of recent scores
        self.track_ids = set()  # every track_id seen in the current video

        # Video I/O and run bookkeeping.
        self.cap = None
        self.writer = None
        self.fps = None
        self.frame_id = 0
        self.results = []  # one list of per-mouse dicts per processed frame
        self.start_frame = 0
        self.end_frame = 0

    @staticmethod
    def _resolve_frame_range(start_frame, end_frame, total_frames):
        """Clamp the requested inclusive frame range into the video and order it.

        ``end_frame=None`` means the last frame. When ``total_frames`` is not
        positive (the container did not report a length), only the lower bound
        of 0 is enforced.
        """
        def clamp(value):
            value = max(int(value), 0)
            return min(value, total_frames - 1) if total_frames > 0 else value

        start = clamp(start_frame)
        end = clamp(total_frames - 1 if end_frame is None else end_frame)
        if start > end:
            start, end = end, start
        return start, end

    def init_video(self, video_path, output_path=None, start_frame=0, end_frame=None):
        """Open the input video (and optionally an output writer) and reset per-video state.

        Args:
            video_path: Path or URL accepted by ``cv2.VideoCapture``.
            output_path: Optional annotated-video path. Only ``.mp4``/``.avi`` are written.
            start_frame: First frame to analyse. It is clamped into the video.
            end_frame: Last frame to analyse (inclusive). ``None`` means the last frame.

        Returns:
            ``(total_frames, start_frame, end_frame)``, with the range as resolved.

        Raises:
            IOError: If the input video cannot be opened.
        """
        # Release handles left over from a previous run. Releasing an
        # already-released capture/writer is harmless. Also drop the old
        # writer so it is never written to again.
        if self.cap is not None:
            self.cap.release()
        if self.writer is not None:
            self.writer.release()
        self.writer = None

        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            raise IOError(f"无法打开视频 {video_path}")

        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 30  # some containers report 0 fps
        self.fps = max(fps, 1.0)
        total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if self.verbose:
            print(f"视频尺寸: {width}x{height}, 帧率: {fps}, 总帧数: {total_frames}")

        self.start_frame, self.end_frame = self._resolve_frame_range(
            start_frame, end_frame, total_frames)

        if self.start_frame > 0:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.start_frame)

        if output_path:
            output_str = str(output_path)
            if output_str.lower().endswith(('.mp4', '.avi')):
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                writer = cv2.VideoWriter(output_str, fourcc, fps, (width, height))
                if writer.isOpened():
                    self.writer = writer
                    print(f"成功创建输出视频: {output_path}, 尺寸: {width}x{height}")
                else:
                    # Do not keep an unopened writer around: process_video would
                    # silently "write" frames into it and report success.
                    writer.release()
                    print(f"警告: 无法创建输出视频 {output_path}")
            else:
                print(f"警告: 不支持的输出视频格式 {output_path} (仅支持 .mp4/.avi)")

        self.frame_id = self.start_frame
        self.results = []
        self.prev_masks.clear()
        self.histories.clear()
        self.track_ids.clear()

        # model.track(persist=True) keeps tracker state between calls. Reset it so
        # track IDs and motion state from a previously processed video do not
        # leak into this one. getattr guards against predictor/tracker internals
        # that are absent before the first track() call.
        predictor = getattr(self.model, 'predictor', None)
        for tracker in getattr(predictor, 'trackers', None) or []:
            reset = getattr(tracker, 'reset', None)
            if callable(reset):
                reset()

        if self.verbose:
            print(f"视频初始化完成: 总帧数 {total_frames}, 分析范围 {self.start_frame}-{self.end_frame}")

        return total_frames, self.start_frame, self.end_frame

    def process_frame(self, frame, frame_id):
        """Track and segment mice in one frame and compute each mouse's struggle score.

        Args:
            frame: BGR image as returned by ``cv2.VideoCapture.read``.
            frame_id: Absolute frame index, used only for periodic verbose logging.

        Returns:
            ``(annotated, frame_results)``. ``annotated`` is an annotated copy of
            ``frame``. ``frame_results`` is a list of dicts with keys ``id``,
            ``score``, ``centroid`` and ``is_struggling``. When nothing usable is
            detected, or on any error, an unannotated copy and ``[]`` are returned.
        """
        if self.verbose and frame_id % 10 == 0:
            print(f"process_frame: 处理帧 {frame_id}")

        try:
            results = self.model.track(
                frame,
                persist=True,
                conf=self.conf,
                iou=self.iou,
                max_det=self.max_det,
                # Without retina_masks, masks are returned at the letterboxed
                # inference resolution (padding included), and a plain resize to
                # the frame size misaligns them. Native-resolution masks line up
                # with the frame directly.
                retina_masks=True,
                verbose=False,
            )
            result = results[0]
            frame_h, frame_w = frame.shape[:2]
            frame_results = []

            if result.boxes is None or len(result.boxes) == 0:
                if self.verbose and frame_id % 50 == 0:
                    print("没有检测到任何对象")
                return frame.copy(), []

            if getattr(result, 'masks', None) is None:
                if self.verbose and frame_id % 50 == 0:
                    print("没有检测到任何掩码")
                return frame.copy(), []

            masks = result.masks.data.cpu().numpy()
            track_ids = result.boxes.id
            if track_ids is None:
                # The tracker has not assigned IDs (yet), so scores cannot be tied to a mouse.
                if self.verbose and frame_id % 50 == 0:
                    print("没有获取到跟踪ID")
                return frame.copy(), []
            track_ids = track_ids.int().cpu().numpy()

            if self.verbose and frame_id % 50 == 0:
                print(f"检测到 {len(masks)} 个掩码,{len(track_ids)} 个跟踪ID")

            self.track_ids.update(int(tid) for tid in track_ids)

            # The same closing kernel is used for every mask, so build it once.
            kernel = np.ones((5, 5), np.uint8)

            for i, (mask, track_id) in enumerate(zip(masks, track_ids)):
                track_id = int(track_id)

                bin_mask = (mask > 0.2).astype(np.uint8)
                # Morphological closing fills small holes and gaps in the predicted mask.
                bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, kernel)
                if bin_mask.shape != (frame_h, frame_w):
                    bin_mask = cv2.resize(bin_mask, (frame_w, frame_h),
                                          interpolation=cv2.INTER_NEAREST)

                prev_mask = self.prev_masks.get(track_id)
                if prev_mask is not None:
                    if prev_mask.shape != bin_mask.shape:
                        prev_mask = cv2.resize(prev_mask, (bin_mask.shape[1], bin_mask.shape[0]),
                                               interpolation=cv2.INTER_NEAREST)
                    inter = np.logical_and(prev_mask > 0, bin_mask > 0).sum()
                    union = np.logical_or(prev_mask > 0, bin_mask > 0).sum()
                    iou = inter / union if union > 0 else 0
                    # Fraction of the silhouette that changed since the previous frame.
                    score = 1 - iou
                    if self.verbose and frame_id % 50 == 0:
                        print(f"跟踪ID {track_id} 挣扎分数: {score:.4f} (IoU: {iou:.4f})")
                else:
                    # First sighting of this track: there is nothing to compare against.
                    score = 0.0
                    if self.verbose and frame_id % 50 == 0:
                        print(f"跟踪ID {track_id} 初始帧,分数为0")

                self.prev_masks[track_id] = bin_mask
                if track_id not in self.histories:
                    self.histories[track_id] = collections.deque(maxlen=self.history_size)
                self.histories[track_id].append(score)

                is_struggling = score >= self.struggle_threshold

                ys, xs = np.where(bin_mask > 0)
                if len(xs) > 0:
                    centroid = (int(xs.mean()), int(ys.mean()))
                else:
                    # Empty mask after thresholding: fall back to the box centre.
                    box = result.boxes[i].xyxy.cpu().numpy()[0]
                    centroid = (int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2))

                frame_results.append({
                    'id': track_id,
                    'score': float(score),
                    'centroid': centroid,
                    'is_struggling': is_struggling,
                })

            return self._draw_annotations(frame, frame_results), frame_results

        except Exception as e:
            # Best effort: one bad frame must not abort the whole video.
            import traceback
            if self.verbose:
                print(f"处理帧时出错: {str(e)}")
                traceback.print_exc()
            return frame.copy(), []

    def _draw_annotations(self, frame, frame_results):
        """Return a copy of ``frame`` with mask tint, contours, per-mouse labels and a summary banner."""
        annotated = frame.copy()
        frame_h, frame_w = frame.shape[:2]

        for res in frame_results:
            track_id = res['id']
            color = self.colors[track_id % len(self.colors)]

            mask = self.prev_masks.get(track_id)
            if mask is not None:
                if mask.shape != (frame_h, frame_w):
                    mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)
                mask_overlay = np.zeros_like(frame)
                mask_overlay[mask > 0] = color

                contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL,
                                               cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(annotated, contours, -1, color, 2)
                # Additive blend: pixels inside the mask get 40% of the colour added.
                cv2.addWeighted(annotated, 1.0, mask_overlay, 0.4, 0, annotated)

            status_text = "Struggle" if res['is_struggling'] else "Static"
            cv2.putText(annotated, f"ID:{track_id} {status_text}",
                        (res['centroid'][0], res['centroid'][1]),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        # Black banner across the top with frame-level counts.
        cv2.rectangle(annotated, (0, 0), (frame_w, 40), (0, 0, 0), -1)
        struggling_count = sum(1 for r in frame_results if r['is_struggling'])
        total_count = len(frame_results)
        cv2.putText(annotated, f"Total: {total_count} Struggling: {struggling_count}",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        if annotated.dtype != np.uint8:
            annotated = annotated.astype(np.uint8)
        return annotated

    def process_video(self, video_path, output_path=None, start_frame=0, end_frame=None, callback=None):
        """Analyse the selected frame range of a video.

        Args:
            video_path: Input video path.
            output_path: Optional annotated-video path. When given, one annotated
                frame is also saved as ``debug_frame.jpg`` in the same directory.
            start_frame: First frame to analyse.
            end_frame: Last frame to analyse (inclusive). ``None`` means the last frame.
            callback: Optional ``callback(progress_percent, annotated_frame, frame_results)``.
                It is called whenever the integer percentage changes.

        Returns:
            ``self.results``, a list with one per-frame result list per processed frame.
        """
        _total_frames, start, end = self.init_video(video_path, output_path, start_frame, end_frame)
        self.results = []

        frames_to_process = max(end - start + 1, 1)
        processed_frames = 0
        last_progress = -1
        # The debug snapshot is written next to the output video, so it is only
        # possible when an output path was given (dirname(None) would raise).
        debug_frame_saved = not output_path

        frame_id = start
        try:
            while frame_id <= end:
                ret, frame = self.cap.read()
                if not ret:
                    break

                annotated, frame_res = self.process_frame(frame, frame_id)
                self.results.append(frame_res)

                if not debug_frame_saved and len(frame_res) > 0:
                    debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
                    cv2.imwrite(debug_frame_path, annotated)
                    print(f"调试: 保存了标注帧到 {debug_frame_path}")
                    debug_frame_saved = True

                if self.writer is not None and len(annotated.shape) == 3 and annotated.shape[2] == 3:
                    if frame_id == start:
                        print(f"调试: 写入标注帧到视频,形状: {annotated.shape}")
                    try:
                        self.writer.write(annotated)
                    except Exception as e:
                        print(f"调试: 写入帧到视频时出错: {str(e)}")
                        import traceback
                        traceback.print_exc()

                processed_frames += 1
                progress = int(100 * processed_frames / frames_to_process)
                if progress != last_progress and callback:
                    callback(progress, annotated, frame_res)
                    last_progress = progress

                frame_id += 1
        finally:
            # Always release handles, even if a callback raises, so the output
            # container is finalized and the input file is not left open.
            self.cap.release()
            if self.writer is not None:
                self.writer.release()
                print(f"调试: 视频写入完成,保存到: {output_path}")

        return self.results

    def save_results(self, csv_path):
        """Write per-frame, per-mouse results to ``csv_path``.

        Columns: ``frame_id`` (absolute frame index), ``mouse_id``, ``score``
        (4 decimals) and ``is_struggling`` (1/0).
        """
        import csv
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['frame_id', 'mouse_id', 'score', 'is_struggling'])
            for fid, frs in enumerate(self.results):
                for fr in frs:
                    writer.writerow([
                        fid + self.start_frame,
                        fr['id'],
                        f"{fr['score']:.4f}",
                        1 if fr.get('is_struggling', False) else 0,
                    ])

    @staticmethod
    def _moving_average(data, window_size=5):
        """Return a centred moving average of ``data``, with windows truncated at the edges.

        Inputs shorter than ``window_size`` are returned unchanged.
        """
        if len(data) < window_size:
            return data
        half = window_size // 2
        smoothed = []
        for i in range(len(data)):
            window = data[max(0, i - half):min(len(data), i + half + 1)]
            smoothed.append(sum(window) / len(window))
        return smoothed

    @staticmethod
    def _struggle_runs(frames, flags):
        """Group struggling frames into inclusive ``(first_frame, last_frame)`` runs.

        A run continues only while frame numbers increase by exactly 1 and every
        frame in it is flagged.
        """
        runs = []
        for frame_num, flag in zip(frames, flags):
            if not flag:
                continue
            if runs and frame_num == runs[-1][1] + 1:
                runs[-1][1] = frame_num
            else:
                runs.append([frame_num, frame_num])
        return [tuple(run) for run in runs]

    def generate_time_series_plot(self, threshold=None):
        """Plot each mouse's struggle score over time and save it to a temporary PNG.

        Mice are drawn in one subplot each, ordered left to right by mean centroid x.

        Args:
            threshold: Y value of the dashed reference line. Defaults to
                ``self.struggle_threshold``.

        Returns:
            Path of the written PNG (the caller is responsible for deleting it), or
            ``None`` when fewer than 10 frames were analysed, no mouse was
            detected, or plotting failed.
        """
        fig = None
        try:
            print(f"Starting to generate time series plot with {len(self.results)} frames of data")

            if not self.results or len(self.results) < 10:
                print("Not enough data for time series plot (need at least 10 frames)")
                return None

            if threshold is None:
                threshold = self.struggle_threshold

            fps = getattr(self, 'fps', None)
            if fps is None or fps <= 0:
                fps = 30
                print(f"Warning: Invalid frame rate detected, using default: {fps} fps")
            else:
                print(f"Using frame rate: {fps} fps")

            frames = []
            mouse_data = {}
            mouse_positions = {}  # mouse_id -> centroid x values, used for left-to-right ordering
            for offset, frame_results in enumerate(self.results):
                frame_num = offset + self.start_frame
                frames.append(frame_num)
                for result in frame_results:
                    mouse_id = result['id']
                    if mouse_id not in mouse_data:
                        mouse_data[mouse_id] = {'frames': [], 'seconds': [], 'scores': [], 'struggling': []}
                        mouse_positions[mouse_id] = []

                    data = mouse_data[mouse_id]
                    data['frames'].append(frame_num)
                    data['seconds'].append(frame_num / fps)
                    data['scores'].append(result['score'])
                    data['struggling'].append(1 if result.get('is_struggling', False) else 0)

                    if 'centroid' in result:
                        mouse_positions[mouse_id].append(result['centroid'][0])

            print(f"Processed data for {len(mouse_data)} mice")
            if not mouse_data:
                print("No valid mouse data to plot")
                return None

            # Mice with no recorded centroid sort last.
            avg_positions = {
                mid: (sum(xs) / len(xs) if xs else float('inf'))
                for mid, xs in mouse_positions.items()
            }
            sorted_mice = sorted(mouse_data, key=lambda mid: avg_positions.get(mid, float('inf')))
            print(f"Mice sorted from left to right: {sorted_mice}")

            num_mice = len(mouse_data)
            # squeeze=False always yields a 2-D array, so a single mouse needs no special case.
            fig, axes = plt.subplots(num_mice, 1, figsize=(12, 4 * num_mice), sharex=True, squeeze=False)
            axes = axes[:, 0]

            for rank, mouse_id in enumerate(sorted_mice):
                data = mouse_data[mouse_id]
                ax = axes[rank]

                smoothed_scores = self._moving_average(data['scores'], window_size=5)
                ax.plot(data['seconds'], smoothed_scores, label="Smoothed", color='blue', linewidth=2)
                ax.plot(data['seconds'], data['scores'], label="Raw", color='lightblue', alpha=0.5, linewidth=1)

                # Shade struggling periods with one span per run of consecutive
                # frames. Each frame still covers +/- half a frame period, but
                # thousands of tiny patches are avoided on long videos.
                for first, last in self._struggle_runs(data['frames'], data['struggling']):
                    ax.axvspan((first - 0.5) / fps, (last + 0.5) / fps, alpha=0.1, color='red')

                ax.axhline(y=threshold, color='r', linestyle='--', label=f"Threshold ({threshold:.2f})")

                ax.set_ylabel('Struggle Score')
                position_text = f"(Position: Left #{rank + 1})" if mouse_id in avg_positions else ""
                ax.set_title(f'Mouse {mouse_id} Struggle Score {position_text}')
                ax.legend(loc='upper right')
                ax.grid(True)
                ax.set_ylim(-0.05, 1.05)

            axes[-1].set_xlabel('Time (seconds)')

            if frames:
                start_time = self.start_frame / fps
                end_time = max(frames) / fps
                axes[-1].set_xlim(start_time, end_time)

                # Minor ticks every 0.1 s, labelled major ticks every whole second.
                tick_interval = 0.1
                minor_ticks = np.arange(start_time, end_time + tick_interval, tick_interval)
                axes[-1].set_xticks(minor_ticks, minor=True)

                major_ticks = np.arange(math.ceil(start_time), math.floor(end_time) + 1, 1.0)
                axes[-1].set_xticks(major_ticks)
                axes[-1].set_xticklabels([f"{int(t)}" for t in major_ticks])

                axes[-1].grid(True, which='both')
                axes[-1].grid(which='minor', alpha=0.2)
                axes[-1].grid(which='major', alpha=0.5)

            fig.tight_layout()

            # mkstemp + close: create the file securely without leaking an open handle.
            fd, plot_path = tempfile.mkstemp(suffix='.png')
            os.close(fd)
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')

            print(f"Time series plot saved to: {plot_path}")
            return plot_path

        except Exception as e:
            import traceback
            print(f"Error generating time series plot: {str(e)}")
            traceback.print_exc()
            return None
        finally:
            # Close the figure on both success and failure so figures do not accumulate.
            if fig is not None:
                plt.close(fig)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: analyse one video, then write the annotated
    # video, a per-frame CSV and a time-series plot.
    import argparse

    parser = argparse.ArgumentParser(description="鼠强迫游泳实验挣扎度分析")

    # (flag, add_argument keyword options), registered in this exact order.
    cli_options = [
        ('--video', dict(type=str, required=True, help='输入视频路径')),
        ('--model', dict(type=str, required=True, help='模型文件路径')),
        ('--output', dict(type=str, help='输出视频路径')),
        ('--csv', dict(type=str, help='输出CSV结果路径')),
        ('--conf', dict(type=float, default=0.25, help='置信度阈值')),
        ('--iou', dict(type=float, default=0.45, help='IOU阈值')),
        ('--max-det', dict(type=int, default=20, help='最大检测数量')),
        ('--threshold', dict(type=float, default=0.3, help='挣扎阈值')),
        ('--start', dict(type=int, default=0, help='起始帧')),
        ('--end', dict(type=int, default=None, help='结束帧')),
        ('--verbose', dict(action='store_true', help='详细输出')),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Default outputs are placed next to the input video and named after it.
    video_dir = os.path.dirname(args.video)
    video_stem = os.path.splitext(os.path.basename(args.video))[0]
    if not args.output:
        args.output = os.path.join(video_dir, f"{video_stem}_out.mp4")
    if not args.csv:
        args.csv = os.path.join(video_dir, f"{video_stem}_results.csv")

    analyzer = MouseTrackerAnalyzer(
        model_path=args.model,
        conf=args.conf,
        iou=args.iou,
        max_det=args.max_det,
        verbose=args.verbose,
    )
    analyzer.struggle_threshold = args.threshold

    def report_progress(progress, frame, results):
        """Print the progress percentage and the number of mice found in the latest frame."""
        print(f"处理进度: {progress}%, 检测到 {len(results)} 个对象")

    analyzer.process_video(
        video_path=args.video,
        output_path=args.output,
        start_frame=args.start,
        end_frame=args.end,
        callback=report_progress,
    )
    analyzer.save_results(args.csv)

    plot_path = analyzer.generate_time_series_plot()
    if plot_path:
        print(f"挣扎度时序分析图已保存到: {plot_path}")

    print(f"分析完成,视频已保存到: {args.output}")
    print(f"结果数据已保存到: {args.csv}")