Update app.py
Browse files
app.py
CHANGED
|
@@ -1,185 +1,271 @@
|
|
| 1 |
-
import
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
from ultralytics import YOLO # pip install ultralytics
|
| 6 |
import gradio as gr
|
| 7 |
-
import
|
|
|
|
| 8 |
|
| 9 |
-
#
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
pass
|
| 22 |
-
|
| 23 |
-
@spaces.GPU(duration=600) # ZeroGPU 环境下执行该函数,超时 600s
|
| 24 |
-
def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
|
| 25 |
-
"""
|
| 26 |
-
分割 → 跟踪 → 计算挣扎强度,仅分析指定时间区间
|
| 27 |
-
返回:标注后视频 & 绘制的挣扎强度曲线 (matplotlib Figure)
|
| 28 |
-
"""
|
| 29 |
-
# 打开视频并获取基本信息
|
| 30 |
cap = cv2.VideoCapture(video_path)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# 跳转到指定起始帧
|
| 40 |
-
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
|
| 41 |
-
|
| 42 |
-
# 输出视频初始化
|
| 43 |
-
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 44 |
-
out_path = "output.mp4"
|
| 45 |
-
out = cv2.VideoWriter(out_path, fourcc, vid_fps, (width, height))
|
| 46 |
-
|
| 47 |
-
prev_centroids = [None] * num_mice
|
| 48 |
-
prev_masks = [None] * num_mice
|
| 49 |
-
struggle_records = [[] for _ in range(num_mice)]
|
| 50 |
-
frame_idx = start_frame
|
| 51 |
-
|
| 52 |
-
while frame_idx <= end_frame:
|
| 53 |
-
ret, frame = cap.read()
|
| 54 |
-
if not ret:
|
| 55 |
-
break
|
| 56 |
-
|
| 57 |
-
# 分割推理
|
| 58 |
-
device = "cuda" if use_cuda else "cpu"
|
| 59 |
-
results = model(frame, stream=True, device=device, conf=0.25)
|
| 60 |
-
res = next(results)
|
| 61 |
-
|
| 62 |
-
# 无检测帧处理
|
| 63 |
-
if res.masks is None or res.masks.data is None:
|
| 64 |
-
for mid in range(num_mice):
|
| 65 |
-
struggle_records[mid].append(None)
|
| 66 |
-
out.write(frame)
|
| 67 |
-
frame_idx += 1
|
| 68 |
-
continue
|
| 69 |
-
|
| 70 |
-
# 获取并对齐掩膜至帧尺寸
|
| 71 |
-
masks = res.masks.data.cpu().numpy() # (N, H_model, W_model)
|
| 72 |
-
aligned_masks = []
|
| 73 |
-
for m in masks:
|
| 74 |
-
m_bin = (m > 0).astype(np.uint8)
|
| 75 |
-
m_res = cv2.resize(m_bin, (width, height), interpolation=cv2.INTER_NEAREST)
|
| 76 |
-
aligned_masks.append(m_res)
|
| 77 |
-
aligned_masks = np.array(aligned_masks)
|
| 78 |
-
|
| 79 |
-
# 计算质心 & ID 分配 (nearest-centroid)
|
| 80 |
-
curr_centroids = []
|
| 81 |
-
for m in aligned_masks:
|
| 82 |
-
ys, xs = np.where(m > 0)
|
| 83 |
-
curr_centroids.append((int(xs.mean()), int(ys.mean())) if xs.size else None)
|
| 84 |
-
assignments = [-1] * len(curr_centroids)
|
| 85 |
-
unused_ids = set(range(num_mice))
|
| 86 |
-
for i, c in enumerate(curr_centroids):
|
| 87 |
-
if c is None:
|
| 88 |
-
continue
|
| 89 |
-
best_j, best_d = None, float("inf")
|
| 90 |
-
for j in unused_ids:
|
| 91 |
-
pc = prev_centroids[j]
|
| 92 |
-
if pc is None:
|
| 93 |
-
continue
|
| 94 |
-
d = (c[0] - pc[0])**2 + (c[1] - pc[1])**2
|
| 95 |
-
if d < best_d:
|
| 96 |
-
best_j, best_d = j, d
|
| 97 |
-
if best_j is not None and best_d < 50**2:
|
| 98 |
-
assignments[i] = best_j
|
| 99 |
-
unused_ids.remove(best_j)
|
| 100 |
-
for i in range(len(curr_centroids)):
|
| 101 |
-
if assignments[i] < 0 and unused_ids:
|
| 102 |
-
assignments[i] = unused_ids.pop()
|
| 103 |
-
|
| 104 |
-
# 计算挣扎强度 & 可视化叠加
|
| 105 |
-
for i, m in enumerate(aligned_masks):
|
| 106 |
-
mid = assignments[i]
|
| 107 |
-
if mid < 0:
|
| 108 |
-
continue
|
| 109 |
-
prev_m = prev_masks[mid]
|
| 110 |
-
if prev_m is None:
|
| 111 |
-
struggle_records[mid].append(None)
|
| 112 |
-
else:
|
| 113 |
-
struggle = int(np.logical_xor(prev_m, m).sum())
|
| 114 |
-
struggle_records[mid].append(struggle)
|
| 115 |
-
|
| 116 |
-
# 构建三通道掩膜
|
| 117 |
-
mask_rgb = np.stack([
|
| 118 |
-
np.zeros_like(m),
|
| 119 |
-
m * 255,
|
| 120 |
-
np.zeros_like(m)
|
| 121 |
-
], axis=-1).astype(np.uint8)
|
| 122 |
-
frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
|
| 123 |
-
|
| 124 |
-
centroid = curr_centroids[i]
|
| 125 |
-
if centroid:
|
| 126 |
-
cv2.putText(frame, f"ID:{mid}", centroid,
|
| 127 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
|
| 128 |
-
|
| 129 |
-
prev_centroids[mid] = curr_centroids[i]
|
| 130 |
-
prev_masks[mid] = m.copy()
|
| 131 |
-
|
| 132 |
-
out.write(frame)
|
| 133 |
-
frame_idx += 1
|
| 134 |
-
|
| 135 |
cap.release()
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
#
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
|
|
|
| 184 |
if __name__ == "__main__":
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
|
|
|
|
|
|
| 4 |
import gradio as gr
|
| 5 |
+
import tempfile
|
| 6 |
+
from mouse_tracker import MouseTrackerAnalyzer
|
| 7 |
|
| 8 |
+
# 全局变量
|
| 9 |
+
analyzer = None
|
| 10 |
+
video_file_path = None
|
| 11 |
+
model_file_path = None
|
| 12 |
+
total_frames = 0
|
| 13 |
+
output_path = None
|
| 14 |
|
| 15 |
+
# 从视频中提取特定帧
|
| 16 |
+
def extract_frame(video_path, frame_num):
|
| 17 |
+
"""从视频中提取特定帧"""
|
| 18 |
+
if not video_path:
|
| 19 |
+
return None
|
| 20 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
cap = cv2.VideoCapture(video_path)
|
| 22 |
+
if not cap.isOpened():
|
| 23 |
+
return None
|
| 24 |
+
|
| 25 |
+
# 设置帧位置
|
| 26 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
|
| 27 |
+
|
| 28 |
+
# 读取帧
|
| 29 |
+
ret, frame = cap.read()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
cap.release()
|
| 31 |
+
|
| 32 |
+
if not ret:
|
| 33 |
+
return None
|
| 34 |
+
|
| 35 |
+
# 转换为RGB格式
|
| 36 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 37 |
+
return frame_rgb
|
| 38 |
|
| 39 |
+
# 选择视频文件
|
| 40 |
+
def select_video(video_file):
|
| 41 |
+
global video_file_path, total_frames
|
| 42 |
+
|
| 43 |
+
if not video_file:
|
| 44 |
+
return None, "Please select a video file", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
| 45 |
+
|
| 46 |
+
video_file_path = video_file
|
| 47 |
+
|
| 48 |
+
# 获取视频总帧数
|
| 49 |
+
cap = cv2.VideoCapture(video_file_path)
|
| 50 |
+
if not cap.isOpened():
|
| 51 |
+
return None, "Cannot open video file", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
| 52 |
+
|
| 53 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 54 |
+
|
| 55 |
+
# 提取第一帧
|
| 56 |
+
ret, first_frame = cap.read()
|
| 57 |
+
cap.release()
|
| 58 |
+
|
| 59 |
+
if not ret:
|
| 60 |
+
return None, "Cannot read video frame", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
| 61 |
+
|
| 62 |
+
# 转为RGB
|
| 63 |
+
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
| 64 |
+
|
| 65 |
+
# 更新帧滑块
|
| 66 |
+
start_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
|
| 67 |
+
end_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
|
| 68 |
+
|
| 69 |
+
return first_frame_rgb, f"Video loaded successfully, total frames: {total_frames}", start_slider, end_slider
|
| 70 |
|
| 71 |
+
# 选择模型文件
|
| 72 |
+
def select_model(model_file):
|
| 73 |
+
global model_file_path
|
| 74 |
+
|
| 75 |
+
if model_file is None:
|
| 76 |
+
return "Please select a model file"
|
| 77 |
+
|
| 78 |
+
model_file_path = model_file
|
| 79 |
+
return f"Model selected: {os.path.basename(model_file_path)}"
|
| 80 |
|
| 81 |
+
# 预览帧
|
| 82 |
+
def preview_frame(video_file, frame_num):
|
| 83 |
+
if not video_file:
|
| 84 |
+
return None, "Please select a video first"
|
| 85 |
+
|
| 86 |
+
# 从视频提取帧
|
| 87 |
+
frame = extract_frame(video_file, frame_num)
|
| 88 |
+
if frame is None:
|
| 89 |
+
return None, "Cannot read specified frame"
|
| 90 |
+
|
| 91 |
+
return frame, f"Frame {frame_num}"
|
| 92 |
|
| 93 |
+
# 开始分析
|
| 94 |
+
def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, threshold):
|
| 95 |
+
global analyzer, output_path
|
| 96 |
+
|
| 97 |
+
if not video or not model:
|
| 98 |
+
return None, None, "Please select a video and model file"
|
| 99 |
+
|
| 100 |
+
if start_frame >= end_frame:
|
| 101 |
+
return None, None, "Start frame must be less than end frame"
|
| 102 |
+
|
| 103 |
+
# 创建输出路径
|
| 104 |
+
video_name = os.path.splitext(os.path.basename(video))[0]
|
| 105 |
+
output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
|
| 106 |
+
csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
|
| 107 |
+
|
| 108 |
+
try:
|
| 109 |
+
# 创建分析器
|
| 110 |
+
analyzer = MouseTrackerAnalyzer(
|
| 111 |
+
model_path=model,
|
| 112 |
+
conf=conf,
|
| 113 |
+
iou=iou,
|
| 114 |
+
max_det=max_det,
|
| 115 |
+
verbose=True # 开启详细日志
|
| 116 |
+
)
|
| 117 |
+
analyzer.struggle_threshold = threshold
|
| 118 |
+
|
| 119 |
+
# 处理视频的进度回调
|
| 120 |
+
def progress_update(progress, frame, results):
|
| 121 |
+
print(f"Processing: {progress}%, Objects detected: {len(results)}")
|
| 122 |
+
|
| 123 |
+
print(f"Processing video: {video}")
|
| 124 |
+
print(f"Output path: {output_path}")
|
| 125 |
+
print(f"Parameters: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}")
|
| 126 |
+
|
| 127 |
+
# 提取视频帧数范围并分析
|
| 128 |
+
results = analyzer.process_video(
|
| 129 |
+
video_path=video,
|
| 130 |
+
output_path=output_path,
|
| 131 |
+
start_frame=start_frame,
|
| 132 |
+
end_frame=end_frame,
|
| 133 |
+
callback=progress_update
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# 保存结果到CSV
|
| 137 |
+
print(f"Saving results to CSV: {csv_path}")
|
| 138 |
+
analyzer.save_results(csv_path)
|
| 139 |
+
print(f"Results saved to CSV with {len(analyzer.results)} frames of data")
|
| 140 |
+
|
| 141 |
+
# 生成分析图表
|
| 142 |
+
print("Generating time series plot...")
|
| 143 |
+
if len(analyzer.results) == 0:
|
| 144 |
+
print("WARNING: No results available for plotting!")
|
| 145 |
+
plot_path = None
|
| 146 |
+
else:
|
| 147 |
+
plot_path = analyzer.generate_time_series_plot()
|
| 148 |
+
if plot_path and os.path.exists(plot_path):
|
| 149 |
+
print(f"Plot generated and saved to: {plot_path}, size: {os.path.getsize(plot_path)/1024:.2f}KB")
|
| 150 |
+
else:
|
| 151 |
+
print(f"Failed to generate plot or plot file does not exist!")
|
| 152 |
+
plot_path = None
|
| 153 |
+
|
| 154 |
+
# 检查输出文件是否存在
|
| 155 |
+
if os.path.exists(output_path):
|
| 156 |
+
file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
|
| 157 |
+
print(f"Output video size: {file_size:.2f}MB")
|
| 158 |
+
|
| 159 |
+
# 处理debug帧
|
| 160 |
+
debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
|
| 161 |
+
if os.path.exists(debug_frame_path):
|
| 162 |
+
print(f"Debug frame saved at: {debug_frame_path}")
|
| 163 |
+
|
| 164 |
+
if plot_path and os.path.exists(plot_path):
|
| 165 |
+
print(f"Plot file exists at: {plot_path}, size: {os.path.getsize(plot_path)/1024:.2f}KB")
|
| 166 |
+
|
| 167 |
+
# 确保返回正确的文件路径
|
| 168 |
+
status_message = "Analysis complete. "
|
| 169 |
+
|
| 170 |
+
if os.path.exists(output_path):
|
| 171 |
+
status_message += f"Video saved."
|
| 172 |
+
else:
|
| 173 |
+
status_message += "WARNING: Output video not found. "
|
| 174 |
+
|
| 175 |
+
if plot_path and os.path.exists(plot_path):
|
| 176 |
+
status_message += f" Time series plot generated."
|
| 177 |
+
else:
|
| 178 |
+
status_message += " WARNING: Failed to generate time series plot."
|
| 179 |
+
|
| 180 |
+
status_message += f" Results saved to: {csv_path}"
|
| 181 |
+
|
| 182 |
+
return output_path, plot_path, status_message
|
| 183 |
+
except Exception as e:
|
| 184 |
+
import traceback
|
| 185 |
+
traceback.print_exc()
|
| 186 |
+
return None, None, f"Processing error: {str(e)}"
|
| 187 |
|
| 188 |
+
# 创建Gradio界面
|
| 189 |
+
def create_interface():
|
| 190 |
+
with gr.Blocks(title="Mouse Struggle Analysis - Object Tracking") as app:
|
| 191 |
+
gr.Markdown("# Mouse Forced Swim Test Struggle Analysis (Object Tracking)")
|
| 192 |
+
|
| 193 |
+
with gr.Row():
|
| 194 |
+
with gr.Column(scale=1):
|
| 195 |
+
# 视频和模型选择
|
| 196 |
+
video_input = gr.Video(label="Input Video")
|
| 197 |
+
model_input = gr.File(label="Model File (.pt format recommended)")
|
| 198 |
+
|
| 199 |
+
# 参数设置
|
| 200 |
+
with gr.Row():
|
| 201 |
+
conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="Confidence Threshold")
|
| 202 |
+
iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU Threshold")
|
| 203 |
+
|
| 204 |
+
with gr.Row():
|
| 205 |
+
max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Max Detections")
|
| 206 |
+
threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="Struggle Threshold")
|
| 207 |
+
|
| 208 |
+
# 帧选择
|
| 209 |
+
start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="Start Frame")
|
| 210 |
+
end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="End Frame")
|
| 211 |
+
|
| 212 |
+
# 预览按钮
|
| 213 |
+
preview_btn = gr.Button("Preview Frame")
|
| 214 |
+
|
| 215 |
+
# 开始分析
|
| 216 |
+
start_btn = gr.Button("Start Analysis", variant="primary")
|
| 217 |
+
|
| 218 |
+
with gr.Column(scale=2):
|
| 219 |
+
# 显示区域
|
| 220 |
+
with gr.Tab("Preview"):
|
| 221 |
+
# 图像预览
|
| 222 |
+
preview_image = gr.Image(label="Preview Image", type="numpy", height=400)
|
| 223 |
+
status_text = gr.Textbox(label="Status", interactive=False)
|
| 224 |
+
gr.Markdown("""
|
| 225 |
+
### Instructions:
|
| 226 |
+
1. Select a video and model file (.pt format segmentation model like yolov8n-seg.pt recommended)
|
| 227 |
+
2. Adjust parameters
|
| 228 |
+
- Confidence Threshold: Minimum confidence for object detection, lower values detect more potential objects
|
| 229 |
+
- IoU Threshold: For filtering overlapping detections
|
| 230 |
+
- Max Detections: Maximum number of objects to detect per frame
|
| 231 |
+
- Struggle Threshold: Minimum score to classify as struggle state
|
| 232 |
+
3. Set frame range
|
| 233 |
+
4. Click "Start Analysis" button
|
| 234 |
+
|
| 235 |
+
The system will automatically track mice and analyze their struggle behavior, no need to manually define regions
|
| 236 |
+
""")
|
| 237 |
+
|
| 238 |
+
with gr.Tab("Results"):
|
| 239 |
+
with gr.Row():
|
| 240 |
+
output_video = gr.Video(label="Analysis Result Video")
|
| 241 |
+
result_plot = gr.Image(label="Struggle Score Time Series")
|
| 242 |
+
|
| 243 |
+
result_status = gr.Textbox(label="Analysis Status", interactive=False)
|
| 244 |
+
|
| 245 |
+
# 绑定事件
|
| 246 |
+
video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
|
| 247 |
+
model_input.change(select_model, inputs=[model_input], outputs=[status_text])
|
| 248 |
+
|
| 249 |
+
preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
|
| 250 |
+
|
| 251 |
+
start_btn.click(
|
| 252 |
+
start_analysis,
|
| 253 |
+
inputs=[video_input, model_input, conf, iou, max_det, start_frame, end_frame, threshold],
|
| 254 |
+
outputs=[output_video, result_plot, result_status]
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
return app
|
| 258 |
|
| 259 |
+
# 启动应用
|
| 260 |
if __name__ == "__main__":
|
| 261 |
+
# 清除可能干扰的代理设置
|
| 262 |
+
if 'http_proxy' in os.environ:
|
| 263 |
+
del os.environ['http_proxy']
|
| 264 |
+
if 'https_proxy' in os.environ:
|
| 265 |
+
del os.environ['https_proxy']
|
| 266 |
+
if 'all_proxy' in os.environ:
|
| 267 |
+
del os.environ['all_proxy']
|
| 268 |
+
|
| 269 |
+
app = create_interface()
|
| 270 |
+
# 使用简化的启动配置
|
| 271 |
+
app.launch(server_name="127.0.0.1", server_port=7860, share=False)
|