Hakureirm commited on
Commit
9ee846a
·
verified ·
1 Parent(s): 5d7aeec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -202
app.py CHANGED
@@ -20,308 +20,185 @@ except ImportError:
20
  # 全局变量
21
  analyzer = None
22
  video_file_path = None
23
- model_file_path = "./fst-v1.2-n.engine" # 直接指定模型文件路径
 
24
  total_frames = 0
25
  output_path = None
26
 
 
 
 
 
 
27
  # 从视频中提取特定帧
28
  def extract_frame(video_path, frame_num):
29
- """从视频中提取特定帧"""
30
  if not video_path:
31
  return None
32
-
33
  cap = cv2.VideoCapture(video_path)
34
  if not cap.isOpened():
35
  return None
36
-
37
- # 设置帧位置
38
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
39
-
40
- # 读取帧
41
  ret, frame = cap.read()
42
  cap.release()
43
-
44
  if not ret:
45
  return None
46
-
47
- # 转换为RGB格式
48
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
49
- return frame_rgb
50
 
51
  # 选择视频文件
52
  def select_video(video_file):
53
  global video_file_path, total_frames
54
-
55
  if not video_file:
56
- return None, "请选择视频文件", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
57
-
58
  video_file_path = video_file
59
-
60
- # 获取视频总帧数
61
  cap = cv2.VideoCapture(video_file_path)
62
  if not cap.isOpened():
63
- return None, "无法打开视频文件", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
64
-
65
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
66
-
67
- # 提取第一帧
68
  ret, first_frame = cap.read()
69
  cap.release()
70
-
71
  if not ret:
72
- return None, "无法读取视频帧", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
73
-
74
- # 转为RGB
75
  first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
76
-
77
- # 更新帧滑块
78
- start_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
79
- end_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
80
-
81
- model_status = f"使用模型: {os.path.basename(model_file_path)}"
82
- return first_frame_rgb, f"视频加载成功,总帧数: {total_frames}. {model_status}", start_slider, end_slider
83
 
84
  # 预览帧
85
  def preview_frame(video_file, frame_num):
86
  if not video_file:
87
  return None, "请先选择视频文件"
88
-
89
- # 从视频提取帧
90
  frame = extract_frame(video_file, frame_num)
91
  if frame is None:
92
  return None, "无法读取指定帧"
93
-
94
  return frame, f"帧 {frame_num}"
95
 
96
- # 开始分析
97
- # 为HF Spaces环境添加GPU装饰器
98
- if is_spaces:
99
- @spaces.GPU(duration=120) # 申请GPU资源,持续120秒
100
- def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
101
- return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
102
- else:
103
- def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
104
- return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
105
-
106
- # 实际的分析实现
107
  def _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold):
108
- global analyzer, output_path, model_file_path
109
-
110
  if not video:
111
  return None, None, "请选择视频文件"
112
-
113
  if start_frame >= end_frame:
114
  return None, None, "起始帧必须小于结束帧"
115
-
116
- # 创建输出路径
117
  video_name = os.path.splitext(os.path.basename(video))[0]
118
  output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
119
  csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
120
-
121
  try:
122
- # 检查设备
123
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
124
- print(f"使用设备: {device}")
125
-
126
- # 确保模型文件存在
127
  if not os.path.exists(model_file_path):
128
- # 如果在Hugging Face Spaces环境中,尝试从Hub下载模型
129
  if is_spaces:
130
  try:
131
- print(f"尝试从Hugging Face Hub下载模型: {os.path.basename(model_file_path)}")
132
  model_file_path = hf_hub_download(
133
- repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME", # 替换为您的仓库
134
- filename="weights/fst-v1.2-n.onnx"
135
  )
136
- print(f"模型已下载到: {model_file_path}")
137
  except Exception as e:
138
- print(f"从Hub下载模型失败: {str(e)}")
139
  else:
140
- print(f"警告: 模型文件 {model_file_path} 不存在!")
141
-
142
- # 创建分析器
143
  analyzer = MouseTrackerAnalyzer(
144
- model_path=model_file_path,
145
- conf=conf,
146
- iou=iou,
147
- max_det=max_det,
148
- verbose=True # 开启详细日志
149
  )
150
  analyzer.struggle_threshold = threshold
151
-
152
- # 处理视频的进度回调
153
  def progress_update(progress, frame, results):
154
- print(f"处理进度: {progress}%, 检测到对象数: {len(results)}")
155
-
156
- print(f"处理视频: {video}")
157
- print(f"输出路径: {output_path}")
158
- print(f"参数: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}")
159
- print(f"使用模型: {model_file_path}")
160
-
161
- # 提取视频帧数范围并分析
162
  results = analyzer.process_video(
163
- video_path=video,
164
- output_path=output_path,
165
- start_frame=start_frame,
166
  end_frame=end_frame,
167
  callback=progress_update
168
  )
169
-
170
- # 保存结果到CSV
171
- print(f"保存结果到CSV: {csv_path}")
172
  analyzer.save_results(csv_path)
173
- print(f"结果已保存到CSV,共 {len(analyzer.results)} 帧数据")
174
-
175
- # 生成分析图表
176
- print("生成时间序列图...")
177
- if len(analyzer.results) == 0:
178
- print("警告: 没有可用于绘图的结果!")
179
- plot_path = None
180
- else:
181
  plot_path = analyzer.generate_time_series_plot()
182
- if plot_path and os.path.exists(plot_path):
183
- print(f"图表已生成并保存到: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB")
184
- else:
185
- print(f"生成图表失败或图表文件不存在!")
186
- plot_path = None
187
-
188
- # 检查输出文件是否存在
189
- if os.path.exists(output_path):
190
- file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
191
- print(f"输出视频大小: {file_size:.2f}MB")
192
-
193
- # 处理debug帧
194
- debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
195
- if os.path.exists(debug_frame_path):
196
- print(f"调试帧保存在: {debug_frame_path}")
197
-
198
- if plot_path and os.path.exists(plot_path):
199
- print(f"图表文件存在于: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB")
200
-
201
- # 确保返回正确的文件路径
202
- status_message = "分析完成。"
203
-
204
- if os.path.exists(output_path):
205
- status_message += f"视频已保存。"
206
- else:
207
- status_message += "警告: 未找到输出视频。"
208
-
209
- if plot_path and os.path.exists(plot_path):
210
- status_message += f" 时间序列图已生成。"
211
- else:
212
- status_message += " 警告: 生成时间序列图失败。"
213
-
214
- status_message += f" 结果已保存到: {csv_path}"
215
-
216
- return output_path, plot_path, status_message
217
  except Exception as e:
218
  import traceback
219
  traceback.print_exc()
220
- return None, None, f"处理错误: {str(e)}"
221
 
222
- # 创建Gradio界面
 
 
 
 
 
 
 
 
 
223
  def create_interface():
224
  with gr.Blocks(title="鼠强迫游泳挣扎度分析 - 对象跟踪") as app:
225
  gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
226
-
227
  with gr.Row():
228
  with gr.Column(scale=1):
229
- # 只保留视频选择,移除模型选择
230
- video_input = gr.Video(label="输入视频")
231
-
232
- # 显示当前使用的模型和设备信息
233
- device_info = "GPU" if torch.cuda.is_available() else "CPU"
234
- model_info = gr.Textbox(
235
- label="系统信息",
236
- value=f"使用模型: {os.path.basename(model_file_path)} | 计算设备: {device_info}",
 
 
 
237
  interactive=False
238
  )
239
-
240
- # 参数设置
241
  with gr.Row():
242
- conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="置信度阈值")
243
- iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU阈值")
244
-
245
  with gr.Row():
246
- max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="最大检测数")
247
- threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="挣扎阈值")
248
-
249
- # 帧选择
250
- start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="起始帧")
251
- end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="结束帧")
252
-
253
- # 预览按钮
254
  preview_btn = gr.Button("预览帧")
255
-
256
- # 开始分析
257
  start_btn = gr.Button("开始分析", variant="primary")
258
-
259
  with gr.Column(scale=2):
260
- # 显示区域
261
  with gr.Tab("预览"):
262
- # 图像预览
263
  preview_image = gr.Image(label="预览图像", type="numpy", height=400)
264
  status_text = gr.Textbox(label="状态", interactive=False)
265
- gr.Markdown("""
266
- ### 使用说明:
267
- 1. 选择一个视频文件
268
- 2. 调整参数
269
- - 置信度阈值: 对象检测的最低置信度,较低的值会检测更多潜在对象
270
- - IoU阈值: 用于过滤重叠检测
271
- - 最大检测数: 每帧检测的最大对象数
272
- - 挣扎阈值: 分类为挣扎状态的最低分数
273
- 3. 设置帧范围
274
- 4. 点击"开始分析"按钮
275
-
276
- 系统将自动跟踪小鼠并分析其挣扎行为,无需手动定义区域
277
- """)
278
-
279
  with gr.Tab("结果"):
280
  with gr.Row():
281
  output_video = gr.Video(label="分析结果视频")
282
  result_plot = gr.Image(label="挣扎分数时间序列")
283
-
284
  result_status = gr.Textbox(label="分析状态", interactive=False)
285
-
286
- # 绑定事件
287
  video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
288
-
289
  preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
290
-
 
291
  start_btn.click(
292
- start_analysis,
293
- inputs=[video_input, conf, iou, max_det, start_frame, end_frame, threshold],
294
  outputs=[output_video, result_plot, result_status]
295
  )
296
-
297
  return app
298
 
299
- # 启动应用
300
  if __name__ == "__main__":
301
- # 清除可能干扰的代理设置
302
- if 'http_proxy' in os.environ:
303
- del os.environ['http_proxy']
304
- if 'https_proxy' in os.environ:
305
- del os.environ['https_proxy']
306
- if 'all_proxy' in os.environ:
307
- del os.environ['all_proxy']
308
-
309
- # 检查设备和模型
310
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
311
  print(f"使用设备: {device}")
312
-
313
- # 检查模型文件是否存在
314
- if not os.path.exists(model_file_path):
315
- print(f"警告: 模型文件 {model_file_path} 不存在!")
316
- else:
317
- print(f"使用模型: {model_file_path}")
318
-
319
  app = create_interface()
320
-
321
- # 根据环境决定启动方式
322
  if is_spaces:
323
- # Hugging Face Spaces环境中的启动方式
324
  app.launch()
325
  else:
326
- # 本地环境的启动方式
327
- app.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
20
  # 全局变量
21
  analyzer = None
22
  video_file_path = None
23
+ model_suffix = ".engine" # 默认使用 TensorRT 格式
24
+ model_base_name = "fst-v1.2-n" # 模型基础名称,无后缀
25
  total_frames = 0
26
  output_path = None
27
 
28
+ # 构造模型文件路径
29
+ def get_model_file_path():
30
+ """根据用户选择的后缀返回完整模型文件路径"""
31
+ return f"./{model_base_name}{model_suffix}"
32
+
33
  # 从视频中提取特定帧
34
  def extract_frame(video_path, frame_num):
35
+ """从视频中提取指定帧号并返回 RGB 图像"""
36
  if not video_path:
37
  return None
 
38
  cap = cv2.VideoCapture(video_path)
39
  if not cap.isOpened():
40
  return None
 
 
41
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
 
 
42
  ret, frame = cap.read()
43
  cap.release()
 
44
  if not ret:
45
  return None
46
+ return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
 
 
47
 
48
  # 选择视频文件
49
  def select_video(video_file):
50
  global video_file_path, total_frames
 
51
  if not video_file:
52
+ # 无视频时重置滑块和提示
53
+ return None, "请选择视频文件", gr.Slider(0, 0, 0), gr.Slider(0, 0, 0)
54
  video_file_path = video_file
 
 
55
  cap = cv2.VideoCapture(video_file_path)
56
  if not cap.isOpened():
57
+ return None, "无法打开视频文件", gr.Slider(0,0,0), gr.Slider(0,0,0)
 
58
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
 
59
  ret, first_frame = cap.read()
60
  cap.release()
 
61
  if not ret:
62
+ return None, "无法读取视频帧", gr.Slider(0,0,0), gr.Slider(0,0,0)
 
 
63
  first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
64
+ # 更新滑块范围
65
+ start = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
66
+ end = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
67
+ status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {os.path.basename(get_model_file_path())}"
68
+ return first_frame_rgb, status, start, end
 
 
69
 
70
  # 预览帧
71
  def preview_frame(video_file, frame_num):
72
  if not video_file:
73
  return None, "请先选择视频文件"
 
 
74
  frame = extract_frame(video_file, frame_num)
75
  if frame is None:
76
  return None, "无法读取指定帧"
 
77
  return frame, f"帧 {frame_num}"
78
 
79
+ # 分析逻辑实现
 
 
 
 
 
 
 
 
 
 
80
  def _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold):
81
+ global analyzer, output_path
 
82
  if not video:
83
  return None, None, "请选择视频文件"
 
84
  if start_frame >= end_frame:
85
  return None, None, "起始帧必须小于结束帧"
 
 
86
  video_name = os.path.splitext(os.path.basename(video))[0]
87
  output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
88
  csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
 
89
  try:
 
90
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
91
+ model_file_path = get_model_file_path()
 
 
92
  if not os.path.exists(model_file_path):
 
93
  if is_spaces:
94
  try:
 
95
  model_file_path = hf_hub_download(
96
+ repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME",
97
+ filename=f"weights/{model_base_name}{model_suffix}"
98
  )
 
99
  except Exception as e:
100
+ print(f"下载模型失败: {e}")
101
  else:
102
+ print(f"警告: 未找到模型文件 {model_file_path}")
 
 
103
  analyzer = MouseTrackerAnalyzer(
104
+ model_path=model_file_path,
105
+ conf=conf,
106
+ iou=iou,
107
+ max_det=max_det,
108
+ verbose=True
109
  )
110
  analyzer.struggle_threshold = threshold
 
 
111
  def progress_update(progress, frame, results):
112
+ print(f"进度: {progress}%,检测: {len(results)} 个对象")
 
 
 
 
 
 
 
113
  results = analyzer.process_video(
114
+ video_path=video,
115
+ output_path=output_path,
116
+ start_frame=start_frame,
117
  end_frame=end_frame,
118
  callback=progress_update
119
  )
 
 
 
120
  analyzer.save_results(csv_path)
121
+ plot_path = None
122
+ if analyzer.results:
 
 
 
 
 
 
123
  plot_path = analyzer.generate_time_series_plot()
124
+ status = f"分析完成。视频: {output_path}, CSV: {csv_path}"
125
+ if plot_path:
126
+ status += f", 图表: {plot_path}"
127
+ return output_path, plot_path, status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  except Exception as e:
129
  import traceback
130
  traceback.print_exc()
131
+ return None, None, f"处理错误: {e}"
132
 
133
+ # HF Spaces 装饰器
134
+ if is_spaces:
135
+ @spaces.GPU(duration=120)
136
+ def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
137
+ return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
138
+ else:
139
+ def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
140
+ return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
141
+
142
+ # 创建 Gradio 界面
143
  def create_interface():
144
  with gr.Blocks(title="鼠强迫游泳挣扎度分析 - 对象跟踪") as app:
145
  gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
 
146
  with gr.Row():
147
  with gr.Column(scale=1):
148
+ video_input = gr.Video(label="输入视频")
149
+ # 新增模型格式选择下拉框
150
+ model_format = gr.Dropdown(
151
+ label="模型格式",
152
+ choices=[".onnx", ".engine", ".pt", ".mlpackage"],
153
+ value=model_suffix,
154
+ interactive=True
155
+ )
156
+ device_info = gr.Textbox(
157
+ label="系统信息",
158
+ value=f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}",
159
  interactive=False
160
  )
 
 
161
  with gr.Row():
162
+ conf = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="置信度阈值")
163
+ iou = gr.Slider(0.1, 0.9, value=0.45, step=0.05, label="IoU阈值")
 
164
  with gr.Row():
165
+ max_det = gr.Slider(1, 50, value=20, step=1, label="最大检测数")
166
+ threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="挣扎阈值")
167
+ start_frame = gr.Slider(0, 999999, value=0, step=1, label="起始帧")
168
+ end_frame = gr.Slider(0, 999999, value=999999, step=1, label="结束帧")
 
 
 
 
169
  preview_btn = gr.Button("预览帧")
 
 
170
  start_btn = gr.Button("开始分析", variant="primary")
 
171
  with gr.Column(scale=2):
 
172
  with gr.Tab("预览"):
 
173
  preview_image = gr.Image(label="预览图像", type="numpy", height=400)
174
  status_text = gr.Textbox(label="状态", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  with gr.Tab("结果"):
176
  with gr.Row():
177
  output_video = gr.Video(label="分析结果视频")
178
  result_plot = gr.Image(label="挣扎分数时间序列")
 
179
  result_status = gr.Textbox(label="分析状态", interactive=False)
180
+ # 事件绑定
 
181
  video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
 
182
  preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
183
+ # 传递模型格式到全局
184
+ model_format.change(lambda fmt: setattr(globals(), 'model_suffix', fmt) or fmt, inputs=[model_format], outputs=[])
185
  start_btn.click(
186
+ start_analysis,
187
+ inputs=[video_input, conf, iou, max_det, start_frame, end_frame, threshold],
188
  outputs=[output_video, result_plot, result_status]
189
  )
 
190
  return app
191
 
 
192
  if __name__ == "__main__":
193
+ # 清除代理环境
194
+ for p in ['http_proxy', 'https_proxy', 'all_proxy']:
195
+ os.environ.pop(p, None)
196
+ # 日志设备和模型信息
 
 
 
 
 
197
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
198
  print(f"使用设备: {device}")
199
+ print(f"默认模型路径: {get_model_file_path()}")
 
 
 
 
 
 
200
  app = create_interface()
 
 
201
  if is_spaces:
 
202
  app.launch()
203
  else:
204
+ app.launch(server_name="0.0.0.0", server_port=7860, share=False)