Hakureirm commited on
Commit
18e189f
·
verified ·
1 Parent(s): 53a4d4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -103
app.py CHANGED
@@ -2,37 +2,29 @@ import os
2
  import cv2
3
  import numpy as np
4
  import gradio as gr
5
- import tempfile
6
  import torch
7
  from mouse_tracker import MouseTrackerAnalyzer
8
- import huggingface_hub
9
  from huggingface_hub import hf_hub_download
10
 
11
  # 检查是否在Hugging Face Spaces环境中
12
  try:
13
  import spaces
14
  is_spaces = True
15
- print("检测到Hugging Face Spaces环境")
16
  except ImportError:
17
  is_spaces = False
18
  print("在本地环境运行")
19
 
20
- # 全局变量
21
- analyzer = None
22
- video_file_path = None
23
- model_suffix = ".onnx" # 默认使用 TensorRT 格式
24
  model_base_name = "fst-v1.2-n" # 模型基础名称,无后缀
25
  total_frames = 0
26
- output_path = None
27
 
28
- # 构造模型文件路径
29
- def get_model_file_path():
30
- """根据用户选择的后缀返回完整模型文件路径"""
31
  return f"./{model_base_name}{model_suffix}"
32
 
33
  # 从视频中提取特定帧
34
  def extract_frame(video_path, frame_num):
35
- """从视频中提取指定帧号并返回 RGB 图像"""
36
  if not video_path:
37
  return None
38
  cap = cv2.VideoCapture(video_path)
@@ -46,26 +38,23 @@ def extract_frame(video_path, frame_num):
46
  return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
47
 
48
  # 选择视频文件
49
- def select_video(video_file):
50
- global video_file_path, total_frames
51
  if not video_file:
52
- # 无视频时重置滑块和提示
53
- return None, "请选择视频文件", gr.Slider(0, 0, 0), gr.Slider(0, 0, 0)
54
- video_file_path = video_file
55
- cap = cv2.VideoCapture(video_file_path)
56
- if not cap.isOpened():
57
- return None, "无法打开视频文件", gr.Slider(0,0,0), gr.Slider(0,0,0)
58
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
59
- ret, first_frame = cap.read()
60
  cap.release()
61
  if not ret:
62
  return None, "无法读取视频帧", gr.Slider(0,0,0), gr.Slider(0,0,0)
63
- first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
64
- # 更新滑块范围
65
  start = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
66
  end = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
67
- status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {os.path.basename(get_model_file_path())}"
68
- return first_frame_rgb, status, start, end
69
 
70
  # 预览帧
71
  def preview_frame(video_file, frame_num):
@@ -76,94 +65,87 @@ def preview_frame(video_file, frame_num):
76
  return None, "无法读取指定帧"
77
  return frame, f"帧 {frame_num}"
78
 
79
- # 分析逻辑实现
80
- def _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold):
81
- global analyzer, output_path
82
  if not video:
83
  return None, None, "请选择视频文件"
84
  if start_frame >= end_frame:
85
  return None, None, "起始帧必须小于结束帧"
 
86
  video_name = os.path.splitext(os.path.basename(video))[0]
87
  output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
88
  csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
89
- try:
90
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
91
- model_file_path = get_model_file_path()
92
- if not os.path.exists(model_file_path):
93
- if is_spaces:
94
- try:
95
- model_file_path = hf_hub_download(
96
- repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME",
97
- filename=f"weights/{model_base_name}{model_suffix}"
98
- )
99
- except Exception as e:
100
- print(f"下载模型失败: {e}")
101
- else:
102
- print(f"警告: 未找到模型文件 {model_file_path}")
103
- analyzer = MouseTrackerAnalyzer(
104
- model_path=model_file_path,
105
- conf=conf,
106
- iou=iou,
107
- max_det=max_det,
108
- verbose=True
109
- )
110
- analyzer.struggle_threshold = threshold
111
- def progress_update(progress, frame, results):
112
- print(f"进度: {progress}%,检测: {len(results)} 个对象")
113
- results = analyzer.process_video(
114
- video_path=video,
115
- output_path=output_path,
116
- start_frame=start_frame,
117
- end_frame=end_frame,
118
- callback=progress_update
119
- )
120
- analyzer.save_results(csv_path)
121
- plot_path = None
122
- if analyzer.results:
123
- plot_path = analyzer.generate_time_series_plot()
124
- status = f"分析完成。视频: {output_path}, CSV: {csv_path}"
125
- if plot_path:
126
- status += f", 图表: {plot_path}"
127
- return output_path, plot_path, status
128
- except Exception as e:
129
- import traceback
130
- traceback.print_exc()
131
- return None, None, f"处理错误: {e}"
132
 
133
- # HF Spaces 装饰器
134
  if is_spaces:
135
  @spaces.GPU(duration=120)
136
- def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
137
- return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
138
  else:
139
- def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
140
- return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
141
 
142
  # 创建 Gradio 界面
143
  def create_interface():
144
- with gr.Blocks(title="鼠强迫游泳挣扎度分析 - 对象跟踪") as app:
145
  gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
146
  with gr.Row():
147
  with gr.Column(scale=1):
148
- video_input = gr.Video(label="输入视频")
149
- # 新增模型格式选择下拉框
150
  model_format = gr.Dropdown(
151
  label="模型格式",
152
  choices=[".onnx", ".engine", ".pt", ".mlpackage"],
153
- value=model_suffix,
154
  interactive=True
155
  )
156
  device_info = gr.Textbox(
157
  label="系统信息",
158
- value=f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}",
159
  interactive=False
160
  )
161
- with gr.Row():
162
- conf = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="置信度阈值")
163
- iou = gr.Slider(0.1, 0.9, value=0.45, step=0.05, label="IoU阈值")
164
- with gr.Row():
165
- max_det = gr.Slider(1, 50, value=20, step=1, label="最大检测数")
166
- threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="挣扎阈值")
167
  start_frame = gr.Slider(0, 999999, value=0, step=1, label="起始帧")
168
  end_frame = gr.Slider(0, 999999, value=999999, step=1, label="结束帧")
169
  preview_btn = gr.Button("预览帧")
@@ -173,30 +155,25 @@ def create_interface():
173
  preview_image = gr.Image(label="预览图像", type="numpy", height=400)
174
  status_text = gr.Textbox(label="状态", interactive=False)
175
  with gr.Tab("结果"):
176
- with gr.Row():
177
- output_video = gr.Video(label="分析结果视频")
178
- result_plot = gr.Image(label="挣扎分数时间序列")
179
  result_status = gr.Textbox(label="分析状态", interactive=False)
180
- # 事件绑定
181
- video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
182
  preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
183
- # 传递模型格式到全局
184
- model_format.change(lambda fmt: setattr(globals(), 'model_suffix', fmt) or fmt, inputs=[model_format], outputs=[])
185
  start_btn.click(
186
  start_analysis,
187
- inputs=[video_input, conf, iou, max_det, start_frame, end_frame, threshold],
188
  outputs=[output_video, result_plot, result_status]
189
  )
190
  return app
191
 
192
  if __name__ == "__main__":
193
- # 清除代理环境
194
- for p in ['http_proxy', 'https_proxy', 'all_proxy']:
195
- os.environ.pop(p, None)
196
- # 日志设备和模型信息
197
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
198
- print(f"使用设备: {device}")
199
- print(f"默认模型路径: {get_model_file_path()}")
200
  app = create_interface()
201
  if is_spaces:
202
  app.launch()
 
2
  import cv2
3
  import numpy as np
4
  import gradio as gr
 
5
  import torch
6
  from mouse_tracker import MouseTrackerAnalyzer
 
7
  from huggingface_hub import hf_hub_download
8
 
9
  # 检查是否在Hugging Face Spaces环境中
10
  try:
11
  import spaces
12
  is_spaces = True
13
+ print("检测到 Hugging Face Spaces 环境")
14
  except ImportError:
15
  is_spaces = False
16
  print("在本地环境运行")
17
 
18
+ # 全局配置
 
 
 
19
  model_base_name = "fst-v1.2-n" # 模型基础名称,无后缀
20
  total_frames = 0
 
21
 
22
+ # 根据后缀构造模型路径
23
+ def get_model_file_path(model_suffix):
 
24
  return f"./{model_base_name}{model_suffix}"
25
 
26
  # 从视频中提取特定帧
27
  def extract_frame(video_path, frame_num):
 
28
  if not video_path:
29
  return None
30
  cap = cv2.VideoCapture(video_path)
 
38
  return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
39
 
40
  # 选择视频文件
41
+ def select_video(video_file, model_suffix):
42
+ global total_frames
43
  if not video_file:
44
+ return None, "请选择视频文件", gr.Slider(0,0,0), gr.Slider(0,0,0)
45
+ total_frames = int(cv2.VideoCapture(video_file).get(cv2.CAP_PROP_FRAME_COUNT))
46
+ # 读取首帧
47
+ cap = cv2.VideoCapture(video_file)
48
+ ret, frame = cap.read()
 
 
 
49
  cap.release()
50
  if not ret:
51
  return None, "无法读取视频帧", gr.Slider(0,0,0), gr.Slider(0,0,0)
52
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
53
+ # 更新滑块
54
  start = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
55
  end = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
56
+ status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {os.path.basename(get_model_file_path(model_suffix))}"
57
+ return frame_rgb, status, start, end
58
 
59
  # 预览帧
60
  def preview_frame(video_file, frame_num):
 
65
  return None, "无法读取指定帧"
66
  return frame, f"帧 {frame_num}"
67
 
68
+ # 分析实现
69
+ def _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
 
70
  if not video:
71
  return None, None, "请选择视频文件"
72
  if start_frame >= end_frame:
73
  return None, None, "起始帧必须小于结束帧"
74
+ # 构造路径
75
  video_name = os.path.splitext(os.path.basename(video))[0]
76
  output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
77
  csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
78
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
79
+ model_path = get_model_file_path(model_suffix)
80
+ if not os.path.exists(model_path):
81
+ if is_spaces:
82
+ try:
83
+ model_path = hf_hub_download(
84
+ repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME",
85
+ filename=f"weights/{model_base_name}{model_suffix}"
86
+ )
87
+ except Exception:
88
+ print(f"下载模型失败: {model_path}")
89
+ else:
90
+ print(f"警告: 本地未找到模型文件 {model_path}")
91
+ # 初始化分析器
92
+ analyzer = MouseTrackerAnalyzer(
93
+ model_path=model_path,
94
+ conf=conf,
95
+ iou=iou,
96
+ max_det=max_det,
97
+ verbose=True
98
+ )
99
+ analyzer.struggle_threshold = threshold
100
+ # 运行分析
101
+ analyzer.process_video(
102
+ video_path=video,
103
+ output_path=output_path,
104
+ start_frame=start_frame,
105
+ end_frame=end_frame,
106
+ callback=lambda prog, frm, res: print(f"进度: {prog}% 检测: {len(res)} 项")
107
+ )
108
+ analyzer.save_results(csv_path)
109
+ # 生成图表
110
+ plot_path = None
111
+ if analyzer.results:
112
+ plot_path = analyzer.generate_time_series_plot()
113
+ status = f"分析完成。视频: {output_path}, CSV: {csv_path}"
114
+ if plot_path:
115
+ status += f", 图表: {plot_path}"
116
+ return output_path, plot_path, status
 
 
 
 
117
 
118
+ # HF Spaces GPU 装饰
119
  if is_spaces:
120
  @spaces.GPU(duration=120)
121
+ def start_analysis(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
122
+ return _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold)
123
  else:
124
+ def start_analysis(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
125
+ return _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold)
126
 
127
  # 创建 Gradio 界面
128
  def create_interface():
129
+ with gr.Blocks(title="鼠强迫游泳挣扎度分析") as app:
130
  gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
131
  with gr.Row():
132
  with gr.Column(scale=1):
133
+ video_input = gr.Video(label="输入视频")
 
134
  model_format = gr.Dropdown(
135
  label="模型格式",
136
  choices=[".onnx", ".engine", ".pt", ".mlpackage"],
137
+ value=".onnx",
138
  interactive=True
139
  )
140
  device_info = gr.Textbox(
141
  label="系统信息",
142
+ value=f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}",
143
  interactive=False
144
  )
145
+ conf = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="置信度阈值")
146
+ iou = gr.Slider(0.1, 0.9, value=0.45, step=0.05, label="IoU阈值")
147
+ max_det = gr.Slider(1, 50, value=20, step=1, label="最大检测数")
148
+ threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="挣扎阈值")
 
 
149
  start_frame = gr.Slider(0, 999999, value=0, step=1, label="起始帧")
150
  end_frame = gr.Slider(0, 999999, value=999999, step=1, label="结束帧")
151
  preview_btn = gr.Button("预览帧")
 
155
  preview_image = gr.Image(label="预览图像", type="numpy", height=400)
156
  status_text = gr.Textbox(label="状态", interactive=False)
157
  with gr.Tab("结果"):
158
+ output_video = gr.Video(label="分析结果视频")
159
+ result_plot = gr.Image(label="挣扎分数时间序列")
 
160
  result_status = gr.Textbox(label="分析状态", interactive=False)
161
+ # 事件绑定,包含模型格式参数
162
+ video_input.change(select_video, inputs=[video_input, model_format], outputs=[preview_image, status_text, start_frame, end_frame])
163
  preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
 
 
164
  start_btn.click(
165
  start_analysis,
166
+ inputs=[video_input, model_format, conf, iou, max_det, start_frame, end_frame, threshold],
167
  outputs=[output_video, result_plot, result_status]
168
  )
169
  return app
170
 
171
  if __name__ == "__main__":
172
+ # 清理代理
173
+ for key in ['http_proxy', 'https_proxy', 'all_proxy']:
174
+ os.environ.pop(key, None)
175
+ print(f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}")
176
+ print(f"默认模型路径: {get_model_file_path('.onnx')}")
 
 
177
  app = create_interface()
178
  if is_spaces:
179
  app.launch()