Hakureirm commited on
Commit
899c703
·
verified ·
1 Parent(s): 65d5672

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -24
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # filename: app.py
2
 
3
- import spaces # 必须最先 import
4
  import os
5
  import cv2
6
  import numpy as np
@@ -15,24 +15,21 @@ print(f"CUDA available: {use_cuda}")
15
  if use_cuda:
16
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
- # === 2. 加载模型并显式指定 task ===
19
- # 避免无法猜测任务类型的警告,明确使用分割 (segment) 模式
20
- model = YOLO("fst-v1.2-n.onnx", task="segment") # ONNX 模型需上传至空间
21
-
22
- # 若 CUDA 可用,迁移模型至 GPU
23
  if use_cuda:
24
  try:
25
  model.model.to("cuda")
26
- except Exception:
27
  pass
28
 
29
- @spaces.GPU(duration=600) # 调用时分配 GPU,超时 600s :contentReference[oaicite:3]{index=3}
30
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
31
  """
32
- 核心分析:分割 → 跟踪 → 计算“挣扎强度”
33
- 返回:标注后的视频路径 & 绘制好的挣扎曲线 (matplotlib Figure)
34
  """
35
- # 视频读写设置
36
  cap = cv2.VideoCapture(video_path)
37
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
38
  out_path = "output.mp4"
@@ -40,7 +37,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
40
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
41
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
42
 
43
- # 跟踪初始化
44
  prev_centroids = [None] * num_mice
45
  prev_masks = [None] * num_mice
46
  struggle_records = [[] for _ in range(num_mice)]
@@ -51,21 +48,33 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
51
  if not ret:
52
  break
53
 
54
- # 分割推理(stream=True 加速),指定 device
55
  device = "cuda" if use_cuda else "cpu"
56
- results = model(frame, stream=True, device=device)
57
  res = next(results)
58
- masks = res.masks.data.cpu().numpy() # [N, H, W]
59
 
60
- # 计算质心 & 分配 ID(nearest-centroid)
 
 
 
 
 
 
 
 
 
 
 
61
  curr_centroids = []
62
  for m in masks:
63
  ys, xs = np.where(m > 0)
64
  curr_centroids.append(
65
- (int(np.mean(xs)), int(np.mean(ys))) if len(xs)>0 else None
66
  )
 
67
  assignments = [-1] * len(curr_centroids)
68
  unused_ids = set(range(num_mice))
 
69
  for i, c in enumerate(curr_centroids):
70
  if c is None: continue
71
  best_j, best_d = None, float("inf")
@@ -78,26 +87,31 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
78
  if best_j is not None and best_d < (50**2):
79
  assignments[i] = best_j
80
  unused_ids.remove(best_j)
 
81
  for i in range(len(curr_centroids)):
82
  if assignments[i] < 0 and unused_ids:
83
  assignments[i] = unused_ids.pop()
84
 
85
- # 计算挣扎强度 & 可视化
86
  for i, m in enumerate(masks):
87
  mid = assignments[i]
88
- if mid < 0: continue
 
89
  prev_m = prev_masks[mid]
90
  if prev_m is None:
91
  struggle_records[mid].append(None)
92
  else:
 
93
  diff = int(np.logical_xor(prev_m, m).sum())
94
  struggle_records[mid].append(diff)
95
- # 叠加掩膜 & ID
 
96
  mask_rgb = np.stack([m*255 if c==1 else 0 for c in range(3)], axis=-1).astype(np.uint8)
97
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
98
  if curr_centroids[i]:
99
  cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
100
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
 
101
  prev_centroids[mid] = curr_centroids[i]
102
  prev_masks[mid] = m.copy()
103
 
@@ -107,7 +121,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
107
  cap.release()
108
  out.release()
109
 
110
- # 汇总 & 绘图
111
  win = int(window_size_sec * fps)
112
  fig, ax = plt.subplots(figsize=(8,4))
113
  times = np.arange(0, frame_idx, win) / fps
@@ -118,6 +132,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
118
  first = next((i for i,v in enumerate(rec) if v is not None), None)
119
  if first is not None:
120
  ax.axvspan(0, first/fps, alpha=0.3, color='gray')
 
121
  ax.set_xlabel("Time (s)")
122
  ax.set_ylabel("Struggle Intensity")
123
  ax.set_title("Mouse Struggle Over Time")
@@ -125,7 +140,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
125
 
126
  return out_path, fig
127
 
128
- # Gradio 前端
129
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
130
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
131
  with gr.Row():
@@ -139,5 +154,5 @@ with gr.Blocks(title="Mice Struggle Analysis") as demo:
139
  outputs=[output_video, output_plot])
140
 
141
  if __name__ == "__main__":
142
- # 去除不支持的 api_config 参数
143
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False) # ⚠️ 删除 api_config :contentReference[oaicite:4]{index=4}
 
1
  # filename: app.py
2
 
3
+ import spaces # 必须最先 import,用于 ZeroGPU 装饰
4
  import os
5
  import cv2
6
  import numpy as np
 
15
  if use_cuda:
16
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
+ # === 2. 加载模型 (显式指定 segmentation 任务 & 默认置信度阈值) ===
19
+ # conf 设置为 0.25,可根据实际降低到 0.1-0.2
20
+ model = YOLO("fst-v1.2-n.onnx", task="segment") # 避免自动猜测任务模式的警告
 
 
21
  if use_cuda:
22
  try:
23
  model.model.to("cuda")
24
+ except:
25
  pass
26
 
27
+ @spaces.GPU(duration=600) # ZeroGPU 上执行该函数,超时 600s :contentReference[oaicite:2]{index=2}
28
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
29
  """
30
+ 分割 → 跟踪 → “挣扎强度” 计算
31
+ 返回:标注后视频文件路径 & matplotlib Figure
32
  """
 
33
  cap = cv2.VideoCapture(video_path)
34
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
35
  out_path = "output.mp4"
 
37
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
38
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
39
 
40
+ # 跟踪状态
41
  prev_centroids = [None] * num_mice
42
  prev_masks = [None] * num_mice
43
  struggle_records = [[] for _ in range(num_mice)]
 
48
  if not ret:
49
  break
50
 
51
+ # === 3. 分割推理 ===
52
  device = "cuda" if use_cuda else "cpu"
53
+ results = model(frame, stream=True, device=device, conf=0.25) # 指定置信度阈值 :contentReference[oaicite:3]{index=3}
54
  res = next(results)
 
55
 
56
+ # 空检测帧处理:res.masks 可能为 None
57
+ if res.masks is None or res.masks.data is None:
58
+ # 为每只鼠标补充 None 并直接写入原帧
59
+ for mid in range(num_mice):
60
+ struggle_records[mid].append(None)
61
+ out.write(frame)
62
+ frame_idx += 1
63
+ continue
64
+
65
+ masks = res.masks.data.cpu().numpy() # [N, H, W]
66
+
67
+ # === 4. 质心计算 & ID 分配 (nearest-centroid) ===
68
  curr_centroids = []
69
  for m in masks:
70
  ys, xs = np.where(m > 0)
71
  curr_centroids.append(
72
+ (int(np.mean(xs)), int(np.mean(ys))) if xs.size else None
73
  )
74
+
75
  assignments = [-1] * len(curr_centroids)
76
  unused_ids = set(range(num_mice))
77
+ # 匹配已有 ID
78
  for i, c in enumerate(curr_centroids):
79
  if c is None: continue
80
  best_j, best_d = None, float("inf")
 
87
  if best_j is not None and best_d < (50**2):
88
  assignments[i] = best_j
89
  unused_ids.remove(best_j)
90
+ # 分配新 ID
91
  for i in range(len(curr_centroids)):
92
  if assignments[i] < 0 and unused_ids:
93
  assignments[i] = unused_ids.pop()
94
 
95
+ # === 5. 计算挣扎强度 & 可视化叠加 ===
96
  for i, m in enumerate(masks):
97
  mid = assignments[i]
98
+ if mid < 0:
99
+ continue
100
  prev_m = prev_masks[mid]
101
  if prev_m is None:
102
  struggle_records[mid].append(None)
103
  else:
104
+ # 异或统计像素差异
105
  diff = int(np.logical_xor(prev_m, m).sum())
106
  struggle_records[mid].append(diff)
107
+
108
+ # 掩膜叠加
109
  mask_rgb = np.stack([m*255 if c==1 else 0 for c in range(3)], axis=-1).astype(np.uint8)
110
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
111
  if curr_centroids[i]:
112
  cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
113
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
114
+
115
  prev_centroids[mid] = curr_centroids[i]
116
  prev_masks[mid] = m.copy()
117
 
 
121
  cap.release()
122
  out.release()
123
 
124
+ # === 6. 汇总 & 绘制挣扎曲线 ===
125
  win = int(window_size_sec * fps)
126
  fig, ax = plt.subplots(figsize=(8,4))
127
  times = np.arange(0, frame_idx, win) / fps
 
132
  first = next((i for i,v in enumerate(rec) if v is not None), None)
133
  if first is not None:
134
  ax.axvspan(0, first/fps, alpha=0.3, color='gray')
135
+
136
  ax.set_xlabel("Time (s)")
137
  ax.set_ylabel("Struggle Intensity")
138
  ax.set_title("Mouse Struggle Over Time")
 
140
 
141
  return out_path, fig
142
 
143
+ # === 7. Gradio 前端 ===
144
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
145
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
146
  with gr.Row():
 
154
  outputs=[output_video, output_plot])
155
 
156
  if __name__ == "__main__":
157
+ # 不再使用 api_config,保持默认超时
158
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)