Hakureirm commited on
Commit
b748ffa
·
verified ·
1 Parent(s): b4516e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -76
app.py CHANGED
@@ -8,7 +8,7 @@ from mouse_tracker import MouseTrackerAnalyzer
8
  # 全局变量
9
  analyzer = None
10
  video_file_path = None
11
- model_file_path = None
12
  total_frames = 0
13
  output_path = None
14
 
@@ -41,14 +41,14 @@ def select_video(video_file):
41
  global video_file_path, total_frames
42
 
43
  if not video_file:
44
- return None, "Please select a video file", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
45
 
46
  video_file_path = video_file
47
 
48
  # 获取视频总帧数
49
  cap = cv2.VideoCapture(video_file_path)
50
  if not cap.isOpened():
51
- return None, "Cannot open video file", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
52
 
53
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
54
 
@@ -57,7 +57,7 @@ def select_video(video_file):
57
  cap.release()
58
 
59
  if not ret:
60
- return None, "Cannot read video frame", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
61
 
62
  # 转为RGB
63
  first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
@@ -66,39 +66,30 @@ def select_video(video_file):
66
  start_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
67
  end_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
68
 
69
- return first_frame_rgb, f"Video loaded successfully, total frames: {total_frames}", start_slider, end_slider
70
-
71
- # 选择模型文件
72
- def select_model(model_file):
73
- global model_file_path
74
-
75
- if model_file is None:
76
- return "Please select a model file"
77
-
78
- model_file_path = model_file
79
- return f"Model selected: {os.path.basename(model_file_path)}"
80
 
81
  # 预览帧
82
  def preview_frame(video_file, frame_num):
83
  if not video_file:
84
- return None, "Please select a video first"
85
 
86
  # 从视频提取帧
87
  frame = extract_frame(video_file, frame_num)
88
  if frame is None:
89
- return None, "Cannot read specified frame"
90
 
91
- return frame, f"Frame {frame_num}"
92
 
93
  # 开始分析
94
- def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, threshold):
95
- global analyzer, output_path
96
 
97
- if not video or not model:
98
- return None, None, "Please select a video and model file"
99
 
100
  if start_frame >= end_frame:
101
- return None, None, "Start frame must be less than end frame"
102
 
103
  # 创建输出路径
104
  video_name = os.path.splitext(os.path.basename(video))[0]
@@ -108,7 +99,7 @@ def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, thr
108
  try:
109
  # 创建分析器
110
  analyzer = MouseTrackerAnalyzer(
111
- model_path=model,
112
  conf=conf,
113
  iou=iou,
114
  max_det=max_det,
@@ -118,11 +109,12 @@ def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, thr
118
 
119
  # 处理视频的进度回调
120
  def progress_update(progress, frame, results):
121
- print(f"Processing: {progress}%, Objects detected: {len(results)}")
122
 
123
- print(f"Processing video: {video}")
124
- print(f"Output path: {output_path}")
125
- print(f"Parameters: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}")
 
126
 
127
  # 提取视频帧数范围并分析
128
  results = analyzer.process_video(
@@ -134,123 +126,124 @@ def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, thr
134
  )
135
 
136
  # 保存结果到CSV
137
- print(f"Saving results to CSV: {csv_path}")
138
  analyzer.save_results(csv_path)
139
- print(f"Results saved to CSV with {len(analyzer.results)} frames of data")
140
 
141
  # 生成分析图表
142
- print("Generating time series plot...")
143
  if len(analyzer.results) == 0:
144
- print("WARNING: No results available for plotting!")
145
  plot_path = None
146
  else:
147
  plot_path = analyzer.generate_time_series_plot()
148
  if plot_path and os.path.exists(plot_path):
149
- print(f"Plot generated and saved to: {plot_path}, size: {os.path.getsize(plot_path)/1024:.2f}KB")
150
  else:
151
- print(f"Failed to generate plot or plot file does not exist!")
152
  plot_path = None
153
 
154
  # 检查输出文件是否存在
155
  if os.path.exists(output_path):
156
  file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
157
- print(f"Output video size: {file_size:.2f}MB")
158
 
159
  # 处理debug帧
160
  debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
161
  if os.path.exists(debug_frame_path):
162
- print(f"Debug frame saved at: {debug_frame_path}")
163
 
164
  if plot_path and os.path.exists(plot_path):
165
- print(f"Plot file exists at: {plot_path}, size: {os.path.getsize(plot_path)/1024:.2f}KB")
166
 
167
  # 确保返回正确的文件路径
168
- status_message = "Analysis complete. "
169
 
170
  if os.path.exists(output_path):
171
- status_message += f"Video saved."
172
  else:
173
- status_message += "WARNING: Output video not found. "
174
 
175
  if plot_path and os.path.exists(plot_path):
176
- status_message += f" Time series plot generated."
177
  else:
178
- status_message += " WARNING: Failed to generate time series plot."
179
 
180
- status_message += f" Results saved to: {csv_path}"
181
 
182
  return output_path, plot_path, status_message
183
  except Exception as e:
184
  import traceback
185
  traceback.print_exc()
186
- return None, None, f"Processing error: {str(e)}"
187
 
188
  # 创建Gradio界面
189
  def create_interface():
190
- with gr.Blocks(title="Mouse Struggle Analysis - Object Tracking") as app:
191
- gr.Markdown("# Mouse Forced Swim Test Struggle Analysis (Object Tracking)")
192
 
193
  with gr.Row():
194
  with gr.Column(scale=1):
195
- # 视频和模型选择
196
- video_input = gr.Video(label="Input Video")
197
- model_input = gr.File(label="Model File (.pt format recommended)")
 
 
198
 
199
  # 参数设置
200
  with gr.Row():
201
- conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="Confidence Threshold")
202
- iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU Threshold")
203
 
204
  with gr.Row():
205
- max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Max Detections")
206
- threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="Struggle Threshold")
207
 
208
  # 帧选择
209
- start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="Start Frame")
210
- end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="End Frame")
211
 
212
  # 预览按钮
213
- preview_btn = gr.Button("Preview Frame")
214
 
215
  # 开始分析
216
- start_btn = gr.Button("Start Analysis", variant="primary")
217
 
218
  with gr.Column(scale=2):
219
  # 显示区域
220
- with gr.Tab("Preview"):
221
  # 图像预览
222
- preview_image = gr.Image(label="Preview Image", type="numpy", height=400)
223
- status_text = gr.Textbox(label="Status", interactive=False)
224
  gr.Markdown("""
225
- ### Instructions:
226
- 1. Select a video and model file (.pt format segmentation model like yolov8n-seg.pt recommended)
227
- 2. Adjust parameters
228
- - Confidence Threshold: Minimum confidence for object detection, lower values detect more potential objects
229
- - IoU Threshold: For filtering overlapping detections
230
- - Max Detections: Maximum number of objects to detect per frame
231
- - Struggle Threshold: Minimum score to classify as struggle state
232
- 3. Set frame range
233
- 4. Click "Start Analysis" button
234
 
235
- The system will automatically track mice and analyze their struggle behavior, no need to manually define regions
236
  """)
237
 
238
- with gr.Tab("Results"):
239
  with gr.Row():
240
- output_video = gr.Video(label="Analysis Result Video")
241
- result_plot = gr.Image(label="Struggle Score Time Series")
242
 
243
- result_status = gr.Textbox(label="Analysis Status", interactive=False)
244
 
245
  # 绑定事件
246
  video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
247
- model_input.change(select_model, inputs=[model_input], outputs=[status_text])
248
 
249
  preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
250
 
251
  start_btn.click(
252
  start_analysis,
253
- inputs=[video_input, model_input, conf, iou, max_det, start_frame, end_frame, threshold],
254
  outputs=[output_video, result_plot, result_status]
255
  )
256
 
@@ -258,6 +251,19 @@ def create_interface():
258
 
259
  # 启动应用
260
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  app = create_interface()
263
  # 使用简化的启动配置
 
8
  # 全局变量
9
  analyzer = None
10
  video_file_path = None
11
+ model_file_path = "weights/fst-v1.2-n.onnx" # 直接指定模型文件路径
12
  total_frames = 0
13
  output_path = None
14
 
 
41
  global video_file_path, total_frames
42
 
43
  if not video_file:
44
+ return None, "请选择视频文件", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
45
 
46
  video_file_path = video_file
47
 
48
  # 获取视频总帧数
49
  cap = cv2.VideoCapture(video_file_path)
50
  if not cap.isOpened():
51
+ return None, "无法打开视频文件", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
52
 
53
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
54
 
 
57
  cap.release()
58
 
59
  if not ret:
60
+ return None, "无法读取视频帧", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
61
 
62
  # 转为RGB
63
  first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
 
66
  start_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
67
  end_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
68
 
69
+ model_status = f"使用模型: {os.path.basename(model_file_path)}"
70
+ return first_frame_rgb, f"视频加载成功,总帧数: {total_frames}. {model_status}", start_slider, end_slider
 
 
 
 
 
 
 
 
 
71
 
72
  # 预览帧
73
  def preview_frame(video_file, frame_num):
74
  if not video_file:
75
+ return None, "请先选择视频文件"
76
 
77
  # 从视频提取帧
78
  frame = extract_frame(video_file, frame_num)
79
  if frame is None:
80
+ return None, "无法读取指定帧"
81
 
82
+ return frame, f" {frame_num}"
83
 
84
  # 开始分析
85
+ def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
86
+ global analyzer, output_path, model_file_path
87
 
88
+ if not video:
89
+ return None, None, "请选择视频文件"
90
 
91
  if start_frame >= end_frame:
92
+ return None, None, "起始帧必须小于结束帧"
93
 
94
  # 创建输出路径
95
  video_name = os.path.splitext(os.path.basename(video))[0]
 
99
  try:
100
  # 创建分析器
101
  analyzer = MouseTrackerAnalyzer(
102
+ model_path=model_file_path,
103
  conf=conf,
104
  iou=iou,
105
  max_det=max_det,
 
109
 
110
  # 处理视频的进度回调
111
  def progress_update(progress, frame, results):
112
+ print(f"处理进度: {progress}%, 检测到对象数: {len(results)}")
113
 
114
+ print(f"处理视频: {video}")
115
+ print(f"输出路径: {output_path}")
116
+ print(f"参数: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}")
117
+ print(f"使用模型: {model_file_path}")
118
 
119
  # 提取视频帧数范围并分析
120
  results = analyzer.process_video(
 
126
  )
127
 
128
  # 保存结果到CSV
129
+ print(f"保存结果到CSV: {csv_path}")
130
  analyzer.save_results(csv_path)
131
+ print(f"结果已保存到CSV,共 {len(analyzer.results)} 帧数据")
132
 
133
  # 生成分析图表
134
+ print("生成时间序列图...")
135
  if len(analyzer.results) == 0:
136
+ print("警告: 没有可用于绘图的结果!")
137
  plot_path = None
138
  else:
139
  plot_path = analyzer.generate_time_series_plot()
140
  if plot_path and os.path.exists(plot_path):
141
+ print(f"图表已生成并保存到: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB")
142
  else:
143
+ print(f"生成图表失败或图表文件不存在!")
144
  plot_path = None
145
 
146
  # 检查输出文件是否存在
147
  if os.path.exists(output_path):
148
  file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
149
+ print(f"输出视频大小: {file_size:.2f}MB")
150
 
151
  # 处理debug帧
152
  debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
153
  if os.path.exists(debug_frame_path):
154
+ print(f"调试帧保存在: {debug_frame_path}")
155
 
156
  if plot_path and os.path.exists(plot_path):
157
+ print(f"图表文件存在于: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB")
158
 
159
  # 确保返回正确的文件路径
160
+ status_message = "分析完成。"
161
 
162
  if os.path.exists(output_path):
163
+ status_message += f"视频已保存。"
164
  else:
165
+ status_message += "警告: 未找到输出视频。"
166
 
167
  if plot_path and os.path.exists(plot_path):
168
+ status_message += f" 时间序列图已生成。"
169
  else:
170
+ status_message += " 警告: 生成时间序列图失败。"
171
 
172
+ status_message += f" 结果已保存到: {csv_path}"
173
 
174
  return output_path, plot_path, status_message
175
  except Exception as e:
176
  import traceback
177
  traceback.print_exc()
178
+ return None, None, f"处理错误: {str(e)}"
179
 
180
  # 创建Gradio界面
181
  def create_interface():
182
+ with gr.Blocks(title="鼠强迫游泳挣扎度分析 - 对象跟踪") as app:
183
+ gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
184
 
185
  with gr.Row():
186
  with gr.Column(scale=1):
187
+ # 只保留视频选择,移除模型选择
188
+ video_input = gr.Video(label="输入视频")
189
+
190
+ # 显示当前使用的模型
191
+ model_info = gr.Textbox(label="模型信息", value=f"使用模型: {os.path.basename(model_file_path)}", interactive=False)
192
 
193
  # 参数设置
194
  with gr.Row():
195
+ conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="置信度阈值")
196
+ iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU阈值")
197
 
198
  with gr.Row():
199
+ max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="最大检测数")
200
+ threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="挣扎阈值")
201
 
202
  # 帧选择
203
+ start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="起始帧")
204
+ end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="结束帧")
205
 
206
  # 预览按钮
207
+ preview_btn = gr.Button("预览帧")
208
 
209
  # 开始分析
210
+ start_btn = gr.Button("开始分析", variant="primary")
211
 
212
  with gr.Column(scale=2):
213
  # 显示区域
214
+ with gr.Tab("预览"):
215
  # 图像预览
216
+ preview_image = gr.Image(label="预览图像", type="numpy", height=400)
217
+ status_text = gr.Textbox(label="状态", interactive=False)
218
  gr.Markdown("""
219
+ ### 使用说明:
220
+ 1. 选择一个视频文件
221
+ 2. 调整参数
222
+ - 置信度阈值: 对象检测的最低置信度,较低的值会检测更多潜在对象
223
+ - IoU阈值: 用于过滤重叠检测
224
+ - 最大检测数: 每帧检测的最大对象数
225
+ - 挣扎阈值: 分类为挣扎状态的最低分数
226
+ 3. 设置帧范围
227
+ 4. 点击"开始分析"按钮
228
 
229
+ 系统将自动跟踪小鼠并分析其挣扎行为,无需手动定义区域
230
  """)
231
 
232
+ with gr.Tab("结果"):
233
  with gr.Row():
234
+ output_video = gr.Video(label="分析结果视频")
235
+ result_plot = gr.Image(label="挣扎分数时间序列")
236
 
237
+ result_status = gr.Textbox(label="分析状态", interactive=False)
238
 
239
  # 绑定事件
240
  video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
 
241
 
242
  preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
243
 
244
  start_btn.click(
245
  start_analysis,
246
+ inputs=[video_input, conf, iou, max_det, start_frame, end_frame, threshold],
247
  outputs=[output_video, result_plot, result_status]
248
  )
249
 
 
251
 
252
  # 启动应用
253
  if __name__ == "__main__":
254
+ # 清除可能干扰的代理设置
255
+ if 'http_proxy' in os.environ:
256
+ del os.environ['http_proxy']
257
+ if 'https_proxy' in os.environ:
258
+ del os.environ['https_proxy']
259
+ if 'all_proxy' in os.environ:
260
+ del os.environ['all_proxy']
261
+
262
+ # 检查模型文件是否存在
263
+ if not os.path.exists(model_file_path):
264
+ print(f"警告: 模型文件 {model_file_path} 不存在!")
265
+ else:
266
+ print(f"使用模型: {model_file_path}")
267
 
268
  app = create_interface()
269
  # 使用简化的启动配置