Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,7 @@ from mouse_tracker import MouseTrackerAnalyzer
|
|
| 8 |
# 全局变量
|
| 9 |
analyzer = None
|
| 10 |
video_file_path = None
|
| 11 |
-
model_file_path =
|
| 12 |
total_frames = 0
|
| 13 |
output_path = None
|
| 14 |
|
|
@@ -41,14 +41,14 @@ def select_video(video_file):
|
|
| 41 |
global video_file_path, total_frames
|
| 42 |
|
| 43 |
if not video_file:
|
| 44 |
-
return None, "
|
| 45 |
|
| 46 |
video_file_path = video_file
|
| 47 |
|
| 48 |
# 获取视频总帧数
|
| 49 |
cap = cv2.VideoCapture(video_file_path)
|
| 50 |
if not cap.isOpened():
|
| 51 |
-
return None, "
|
| 52 |
|
| 53 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 54 |
|
|
@@ -57,7 +57,7 @@ def select_video(video_file):
|
|
| 57 |
cap.release()
|
| 58 |
|
| 59 |
if not ret:
|
| 60 |
-
return None, "
|
| 61 |
|
| 62 |
# 转为RGB
|
| 63 |
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
|
@@ -66,39 +66,30 @@ def select_video(video_file):
|
|
| 66 |
start_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
|
| 67 |
end_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
# 选择模型文件
|
| 72 |
-
def select_model(model_file):
|
| 73 |
-
global model_file_path
|
| 74 |
-
|
| 75 |
-
if model_file is None:
|
| 76 |
-
return "Please select a model file"
|
| 77 |
-
|
| 78 |
-
model_file_path = model_file
|
| 79 |
-
return f"Model selected: {os.path.basename(model_file_path)}"
|
| 80 |
|
| 81 |
# 预览帧
|
| 82 |
def preview_frame(video_file, frame_num):
|
| 83 |
if not video_file:
|
| 84 |
-
return None, "
|
| 85 |
|
| 86 |
# 从视频提取帧
|
| 87 |
frame = extract_frame(video_file, frame_num)
|
| 88 |
if frame is None:
|
| 89 |
-
return None, "
|
| 90 |
|
| 91 |
-
return frame, f"
|
| 92 |
|
| 93 |
# 开始分析
|
| 94 |
-
def start_analysis(video,
|
| 95 |
-
global analyzer, output_path
|
| 96 |
|
| 97 |
-
if not video
|
| 98 |
-
return None, None, "
|
| 99 |
|
| 100 |
if start_frame >= end_frame:
|
| 101 |
-
return None, None, "
|
| 102 |
|
| 103 |
# 创建输出路径
|
| 104 |
video_name = os.path.splitext(os.path.basename(video))[0]
|
|
@@ -108,7 +99,7 @@ def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, thr
|
|
| 108 |
try:
|
| 109 |
# 创建分析器
|
| 110 |
analyzer = MouseTrackerAnalyzer(
|
| 111 |
-
model_path=
|
| 112 |
conf=conf,
|
| 113 |
iou=iou,
|
| 114 |
max_det=max_det,
|
|
@@ -118,11 +109,12 @@ def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, thr
|
|
| 118 |
|
| 119 |
# 处理视频的进度回调
|
| 120 |
def progress_update(progress, frame, results):
|
| 121 |
-
print(f"
|
| 122 |
|
| 123 |
-
print(f"
|
| 124 |
-
print(f"
|
| 125 |
-
print(f"
|
|
|
|
| 126 |
|
| 127 |
# 提取视频帧数范围并分析
|
| 128 |
results = analyzer.process_video(
|
|
@@ -134,123 +126,124 @@ def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, thr
|
|
| 134 |
)
|
| 135 |
|
| 136 |
# 保存结果到CSV
|
| 137 |
-
print(f"
|
| 138 |
analyzer.save_results(csv_path)
|
| 139 |
-
print(f"
|
| 140 |
|
| 141 |
# 生成分析图表
|
| 142 |
-
print("
|
| 143 |
if len(analyzer.results) == 0:
|
| 144 |
-
print("
|
| 145 |
plot_path = None
|
| 146 |
else:
|
| 147 |
plot_path = analyzer.generate_time_series_plot()
|
| 148 |
if plot_path and os.path.exists(plot_path):
|
| 149 |
-
print(f"
|
| 150 |
else:
|
| 151 |
-
print(f"
|
| 152 |
plot_path = None
|
| 153 |
|
| 154 |
# 检查输出文件是否存在
|
| 155 |
if os.path.exists(output_path):
|
| 156 |
file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
|
| 157 |
-
print(f"
|
| 158 |
|
| 159 |
# 处理debug帧
|
| 160 |
debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
|
| 161 |
if os.path.exists(debug_frame_path):
|
| 162 |
-
print(f"
|
| 163 |
|
| 164 |
if plot_path and os.path.exists(plot_path):
|
| 165 |
-
print(f"
|
| 166 |
|
| 167 |
# 确保返回正确的文件路径
|
| 168 |
-
status_message = "
|
| 169 |
|
| 170 |
if os.path.exists(output_path):
|
| 171 |
-
status_message += f"
|
| 172 |
else:
|
| 173 |
-
status_message += "
|
| 174 |
|
| 175 |
if plot_path and os.path.exists(plot_path):
|
| 176 |
-
status_message += f"
|
| 177 |
else:
|
| 178 |
-
status_message += "
|
| 179 |
|
| 180 |
-
status_message += f"
|
| 181 |
|
| 182 |
return output_path, plot_path, status_message
|
| 183 |
except Exception as e:
|
| 184 |
import traceback
|
| 185 |
traceback.print_exc()
|
| 186 |
-
return None, None, f"
|
| 187 |
|
| 188 |
# 创建Gradio界面
|
| 189 |
def create_interface():
|
| 190 |
-
with gr.Blocks(title="
|
| 191 |
-
gr.Markdown("#
|
| 192 |
|
| 193 |
with gr.Row():
|
| 194 |
with gr.Column(scale=1):
|
| 195 |
-
#
|
| 196 |
-
video_input = gr.Video(label="
|
| 197 |
-
|
|
|
|
|
|
|
| 198 |
|
| 199 |
# 参数设置
|
| 200 |
with gr.Row():
|
| 201 |
-
conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="
|
| 202 |
-
iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU
|
| 203 |
|
| 204 |
with gr.Row():
|
| 205 |
-
max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="
|
| 206 |
-
threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="
|
| 207 |
|
| 208 |
# 帧选择
|
| 209 |
-
start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="
|
| 210 |
-
end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="
|
| 211 |
|
| 212 |
# 预览按钮
|
| 213 |
-
preview_btn = gr.Button("
|
| 214 |
|
| 215 |
# 开始分析
|
| 216 |
-
start_btn = gr.Button("
|
| 217 |
|
| 218 |
with gr.Column(scale=2):
|
| 219 |
# 显示区域
|
| 220 |
-
with gr.Tab("
|
| 221 |
# 图像预览
|
| 222 |
-
preview_image = gr.Image(label="
|
| 223 |
-
status_text = gr.Textbox(label="
|
| 224 |
gr.Markdown("""
|
| 225 |
-
###
|
| 226 |
-
1.
|
| 227 |
-
2.
|
| 228 |
-
-
|
| 229 |
-
- IoU
|
| 230 |
-
-
|
| 231 |
-
-
|
| 232 |
-
3.
|
| 233 |
-
4.
|
| 234 |
|
| 235 |
-
|
| 236 |
""")
|
| 237 |
|
| 238 |
-
with gr.Tab("
|
| 239 |
with gr.Row():
|
| 240 |
-
output_video = gr.Video(label="
|
| 241 |
-
result_plot = gr.Image(label="
|
| 242 |
|
| 243 |
-
result_status = gr.Textbox(label="
|
| 244 |
|
| 245 |
# 绑定事件
|
| 246 |
video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
|
| 247 |
-
model_input.change(select_model, inputs=[model_input], outputs=[status_text])
|
| 248 |
|
| 249 |
preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
|
| 250 |
|
| 251 |
start_btn.click(
|
| 252 |
start_analysis,
|
| 253 |
-
inputs=[video_input,
|
| 254 |
outputs=[output_video, result_plot, result_status]
|
| 255 |
)
|
| 256 |
|
|
@@ -258,6 +251,19 @@ def create_interface():
|
|
| 258 |
|
| 259 |
# 启动应用
|
| 260 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
app = create_interface()
|
| 263 |
# 使用简化的启动配置
|
|
|
|
| 8 |
# 全局变量
|
| 9 |
analyzer = None
|
| 10 |
video_file_path = None
|
| 11 |
+
model_file_path = "weights/fst-v1.2-n.onnx" # 直接指定模型文件路径
|
| 12 |
total_frames = 0
|
| 13 |
output_path = None
|
| 14 |
|
|
|
|
| 41 |
global video_file_path, total_frames
|
| 42 |
|
| 43 |
if not video_file:
|
| 44 |
+
return None, "请选择视频文件", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
| 45 |
|
| 46 |
video_file_path = video_file
|
| 47 |
|
| 48 |
# 获取视频总帧数
|
| 49 |
cap = cv2.VideoCapture(video_file_path)
|
| 50 |
if not cap.isOpened():
|
| 51 |
+
return None, "无法打开视频文件", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
| 52 |
|
| 53 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 54 |
|
|
|
|
| 57 |
cap.release()
|
| 58 |
|
| 59 |
if not ret:
|
| 60 |
+
return None, "无法读取视频帧", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
| 61 |
|
| 62 |
# 转为RGB
|
| 63 |
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
|
|
|
| 66 |
start_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
|
| 67 |
end_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
|
| 68 |
|
| 69 |
+
model_status = f"使用模型: {os.path.basename(model_file_path)}"
|
| 70 |
+
return first_frame_rgb, f"视频加载成功,总帧数: {total_frames}. {model_status}", start_slider, end_slider
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# 预览帧
|
| 73 |
def preview_frame(video_file, frame_num):
|
| 74 |
if not video_file:
|
| 75 |
+
return None, "请先选择视频文件"
|
| 76 |
|
| 77 |
# 从视频提取帧
|
| 78 |
frame = extract_frame(video_file, frame_num)
|
| 79 |
if frame is None:
|
| 80 |
+
return None, "无法读取指定帧"
|
| 81 |
|
| 82 |
+
return frame, f"帧 {frame_num}"
|
| 83 |
|
| 84 |
# 开始分析
|
| 85 |
+
def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
|
| 86 |
+
global analyzer, output_path, model_file_path
|
| 87 |
|
| 88 |
+
if not video:
|
| 89 |
+
return None, None, "请选择视频文件"
|
| 90 |
|
| 91 |
if start_frame >= end_frame:
|
| 92 |
+
return None, None, "起始帧必须小于结束帧"
|
| 93 |
|
| 94 |
# 创建输出路径
|
| 95 |
video_name = os.path.splitext(os.path.basename(video))[0]
|
|
|
|
| 99 |
try:
|
| 100 |
# 创建分析器
|
| 101 |
analyzer = MouseTrackerAnalyzer(
|
| 102 |
+
model_path=model_file_path,
|
| 103 |
conf=conf,
|
| 104 |
iou=iou,
|
| 105 |
max_det=max_det,
|
|
|
|
| 109 |
|
| 110 |
# 处理视频的进度回调
|
| 111 |
def progress_update(progress, frame, results):
|
| 112 |
+
print(f"处理进度: {progress}%, 检测到对象数: {len(results)}")
|
| 113 |
|
| 114 |
+
print(f"处理视频: {video}")
|
| 115 |
+
print(f"输出路径: {output_path}")
|
| 116 |
+
print(f"参数: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}")
|
| 117 |
+
print(f"使用模型: {model_file_path}")
|
| 118 |
|
| 119 |
# 提取视频帧数范围并分析
|
| 120 |
results = analyzer.process_video(
|
|
|
|
| 126 |
)
|
| 127 |
|
| 128 |
# 保存结果到CSV
|
| 129 |
+
print(f"保存结果到CSV: {csv_path}")
|
| 130 |
analyzer.save_results(csv_path)
|
| 131 |
+
print(f"结果已保存到CSV,共 {len(analyzer.results)} 帧数据")
|
| 132 |
|
| 133 |
# 生成分析图表
|
| 134 |
+
print("生成时间序列图...")
|
| 135 |
if len(analyzer.results) == 0:
|
| 136 |
+
print("警告: 没有可用于绘图的结果!")
|
| 137 |
plot_path = None
|
| 138 |
else:
|
| 139 |
plot_path = analyzer.generate_time_series_plot()
|
| 140 |
if plot_path and os.path.exists(plot_path):
|
| 141 |
+
print(f"图表已生成并保存到: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB")
|
| 142 |
else:
|
| 143 |
+
print(f"生成图表失败或图表文件不存在!")
|
| 144 |
plot_path = None
|
| 145 |
|
| 146 |
# 检查输出文件是否存在
|
| 147 |
if os.path.exists(output_path):
|
| 148 |
file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
|
| 149 |
+
print(f"输出视频大小: {file_size:.2f}MB")
|
| 150 |
|
| 151 |
# 处理debug帧
|
| 152 |
debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
|
| 153 |
if os.path.exists(debug_frame_path):
|
| 154 |
+
print(f"调试帧保存在: {debug_frame_path}")
|
| 155 |
|
| 156 |
if plot_path and os.path.exists(plot_path):
|
| 157 |
+
print(f"图表文件存在于: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB")
|
| 158 |
|
| 159 |
# 确保返回正确的文件路径
|
| 160 |
+
status_message = "分析完成。"
|
| 161 |
|
| 162 |
if os.path.exists(output_path):
|
| 163 |
+
status_message += f"视频已保存。"
|
| 164 |
else:
|
| 165 |
+
status_message += "警告: 未找到输出视频。"
|
| 166 |
|
| 167 |
if plot_path and os.path.exists(plot_path):
|
| 168 |
+
status_message += f" 时间序列图已生成。"
|
| 169 |
else:
|
| 170 |
+
status_message += " 警告: 生成时间序列图失败。"
|
| 171 |
|
| 172 |
+
status_message += f" 结果已保存到: {csv_path}"
|
| 173 |
|
| 174 |
return output_path, plot_path, status_message
|
| 175 |
except Exception as e:
|
| 176 |
import traceback
|
| 177 |
traceback.print_exc()
|
| 178 |
+
return None, None, f"处理错误: {str(e)}"
|
| 179 |
|
| 180 |
# 创建Gradio界面
|
| 181 |
def create_interface():
|
| 182 |
+
with gr.Blocks(title="鼠强迫游泳挣扎度分析 - 对象跟踪") as app:
|
| 183 |
+
gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
|
| 184 |
|
| 185 |
with gr.Row():
|
| 186 |
with gr.Column(scale=1):
|
| 187 |
+
# 只保留视频选择,移除模型选择
|
| 188 |
+
video_input = gr.Video(label="输入视频")
|
| 189 |
+
|
| 190 |
+
# 显示当前使用的模型
|
| 191 |
+
model_info = gr.Textbox(label="模型信息", value=f"使用模型: {os.path.basename(model_file_path)}", interactive=False)
|
| 192 |
|
| 193 |
# 参数设置
|
| 194 |
with gr.Row():
|
| 195 |
+
conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="置信度阈值")
|
| 196 |
+
iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU阈值")
|
| 197 |
|
| 198 |
with gr.Row():
|
| 199 |
+
max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="最大检测数")
|
| 200 |
+
threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="挣扎阈值")
|
| 201 |
|
| 202 |
# 帧选择
|
| 203 |
+
start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="起始帧")
|
| 204 |
+
end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="结束帧")
|
| 205 |
|
| 206 |
# 预览按钮
|
| 207 |
+
preview_btn = gr.Button("预览帧")
|
| 208 |
|
| 209 |
# 开始分析
|
| 210 |
+
start_btn = gr.Button("开始分析", variant="primary")
|
| 211 |
|
| 212 |
with gr.Column(scale=2):
|
| 213 |
# 显示区域
|
| 214 |
+
with gr.Tab("预览"):
|
| 215 |
# 图像预览
|
| 216 |
+
preview_image = gr.Image(label="预览图像", type="numpy", height=400)
|
| 217 |
+
status_text = gr.Textbox(label="状态", interactive=False)
|
| 218 |
gr.Markdown("""
|
| 219 |
+
### 使用说明:
|
| 220 |
+
1. 选择一个视频文件
|
| 221 |
+
2. 调整参数
|
| 222 |
+
- 置信度阈值: 对象检测的最低置信度,较低的值会检测更多潜在对象
|
| 223 |
+
- IoU阈值: 用于过滤重叠检测
|
| 224 |
+
- 最大检测数: 每帧检测的最大对象数
|
| 225 |
+
- 挣扎阈值: 分类为挣扎状态的最低分数
|
| 226 |
+
3. 设置帧范围
|
| 227 |
+
4. 点击"开始分析"按钮
|
| 228 |
|
| 229 |
+
系统将自动跟踪小鼠并分析其挣扎行为,无需手动定义区域
|
| 230 |
""")
|
| 231 |
|
| 232 |
+
with gr.Tab("结果"):
|
| 233 |
with gr.Row():
|
| 234 |
+
output_video = gr.Video(label="分析结果视频")
|
| 235 |
+
result_plot = gr.Image(label="挣扎分数时间序列")
|
| 236 |
|
| 237 |
+
result_status = gr.Textbox(label="分析状态", interactive=False)
|
| 238 |
|
| 239 |
# 绑定事件
|
| 240 |
video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
|
|
|
|
| 241 |
|
| 242 |
preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
|
| 243 |
|
| 244 |
start_btn.click(
|
| 245 |
start_analysis,
|
| 246 |
+
inputs=[video_input, conf, iou, max_det, start_frame, end_frame, threshold],
|
| 247 |
outputs=[output_video, result_plot, result_status]
|
| 248 |
)
|
| 249 |
|
|
|
|
| 251 |
|
| 252 |
# 启动应用
|
| 253 |
if __name__ == "__main__":
|
| 254 |
+
# 清除可能干扰的代理设置
|
| 255 |
+
if 'http_proxy' in os.environ:
|
| 256 |
+
del os.environ['http_proxy']
|
| 257 |
+
if 'https_proxy' in os.environ:
|
| 258 |
+
del os.environ['https_proxy']
|
| 259 |
+
if 'all_proxy' in os.environ:
|
| 260 |
+
del os.environ['all_proxy']
|
| 261 |
+
|
| 262 |
+
# 检查模型文件是否存在
|
| 263 |
+
if not os.path.exists(model_file_path):
|
| 264 |
+
print(f"警告: 模型文件 {model_file_path} 不存在!")
|
| 265 |
+
else:
|
| 266 |
+
print(f"使用模型: {model_file_path}")
|
| 267 |
|
| 268 |
app = create_interface()
|
| 269 |
# 使用简化的启动配置
|