Hakureirm commited on
Commit
548f599
·
verified ·
1 Parent(s): 899c703

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -26
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # filename: app.py
2
 
3
- import spaces # 必须最先 import,用于 ZeroGPU 装饰
4
  import os
5
  import cv2
6
  import numpy as np
@@ -15,20 +15,19 @@ print(f"CUDA available: {use_cuda}")
15
  if use_cuda:
16
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
- # === 2. 加载模型 (显式指定 segmentation 任务 & 默认置信度阈值) ===
19
- # conf 设置为 0.25,可根据实际降低到 0.1-0.2
20
- model = YOLO("fst-v1.2-n.onnx", task="segment") # 避免自动猜测任务模式的警告
21
  if use_cuda:
22
  try:
23
  model.model.to("cuda")
24
  except:
25
  pass
26
 
27
- @spaces.GPU(duration=600) # 在 ZeroGPU 上执行该函数,超时 600s :contentReference[oaicite:2]{index=2}
28
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
29
  """
30
  分割 → 跟踪 → “挣扎强度” 计算
31
- 返回:标注后视频文件路径 & matplotlib Figure
32
  """
33
  cap = cv2.VideoCapture(video_path)
34
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -37,7 +36,6 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
37
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
38
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
39
 
40
- # 跟踪状态
41
  prev_centroids = [None] * num_mice
42
  prev_masks = [None] * num_mice
43
  struggle_records = [[] for _ in range(num_mice)]
@@ -45,17 +43,14 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
45
 
46
  while True:
47
  ret, frame = cap.read()
48
- if not ret:
49
- break
50
 
51
- # === 3. 分割推理 ===
52
  device = "cuda" if use_cuda else "cpu"
53
- results = model(frame, stream=True, device=device, conf=0.25) # 指定置信度阈值 :contentReference[oaicite:3]{index=3}
54
  res = next(results)
55
 
56
- # 空检测帧处理:res.masks 可能为 None
57
  if res.masks is None or res.masks.data is None:
58
- # 为每只鼠标补充 None 并直接写入原帧
59
  for mid in range(num_mice):
60
  struggle_records[mid].append(None)
61
  out.write(frame)
@@ -64,17 +59,15 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
64
 
65
  masks = res.masks.data.cpu().numpy() # [N, H, W]
66
 
67
- # === 4. 质心计算 & ID 分配 (nearest-centroid) ===
68
  curr_centroids = []
69
  for m in masks:
70
  ys, xs = np.where(m > 0)
71
  curr_centroids.append(
72
  (int(np.mean(xs)), int(np.mean(ys))) if xs.size else None
73
  )
74
-
75
  assignments = [-1] * len(curr_centroids)
76
  unused_ids = set(range(num_mice))
77
- # 匹配已有 ID
78
  for i, c in enumerate(curr_centroids):
79
  if c is None: continue
80
  best_j, best_d = None, float("inf")
@@ -87,27 +80,29 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
87
  if best_j is not None and best_d < (50**2):
88
  assignments[i] = best_j
89
  unused_ids.remove(best_j)
90
- # 分配新 ID
91
  for i in range(len(curr_centroids)):
92
  if assignments[i] < 0 and unused_ids:
93
  assignments[i] = unused_ids.pop()
94
 
95
- # === 5. 计算挣扎强度 & 可视化叠加 ===
96
  for i, m in enumerate(masks):
97
  mid = assignments[i]
98
- if mid < 0:
99
- continue
100
  prev_m = prev_masks[mid]
101
  if prev_m is None:
102
  struggle_records[mid].append(None)
103
  else:
104
- # 异或统计像素差异
105
  diff = int(np.logical_xor(prev_m, m).sum())
106
  struggle_records[mid].append(diff)
107
 
108
- # 掩膜叠加
109
- mask_rgb = np.stack([m*255 if c==1 else 0 for c in range(3)], axis=-1).astype(np.uint8)
 
 
 
 
110
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
 
111
  if curr_centroids[i]:
112
  cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
113
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
@@ -121,7 +116,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
121
  cap.release()
122
  out.release()
123
 
124
- # === 6. 汇总 & 绘制挣扎曲线 ===
125
  win = int(window_size_sec * fps)
126
  fig, ax = plt.subplots(figsize=(8,4))
127
  times = np.arange(0, frame_idx, win) / fps
@@ -140,7 +135,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
140
 
141
  return out_path, fig
142
 
143
- # === 7. Gradio 前端 ===
144
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
145
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
146
  with gr.Row():
@@ -154,5 +149,4 @@ with gr.Blocks(title="Mice Struggle Analysis") as demo:
154
  outputs=[output_video, output_plot])
155
 
156
  if __name__ == "__main__":
157
- # 不再使用 api_config,保持默认超时
158
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
1
  # filename: app.py
2
 
3
+ import spaces # 必须首 import,用于 ZeroGPU 装饰
4
  import os
5
  import cv2
6
  import numpy as np
 
15
  if use_cuda:
16
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
+ # === 2. 加载模型并显式指定 segmentation 任务 ===
19
+ model = YOLO("fst-v1.2-n.onnx", task="segment") # ONNX 模型需上传至 Space
 
20
  if use_cuda:
21
  try:
22
  model.model.to("cuda")
23
  except:
24
  pass
25
 
26
+ @spaces.GPU(duration=600) # 在 ZeroGPU 上执行本函数,超时 600s
27
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
28
  """
29
  分割 → 跟踪 → “挣扎强度” 计算
30
+ 返回:标注后视频路径 & matplotlib Figure
31
  """
32
  cap = cv2.VideoCapture(video_path)
33
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
 
36
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
37
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
38
 
 
39
  prev_centroids = [None] * num_mice
40
  prev_masks = [None] * num_mice
41
  struggle_records = [[] for _ in range(num_mice)]
 
43
 
44
  while True:
45
  ret, frame = cap.read()
46
+ if not ret: break
 
47
 
 
48
  device = "cuda" if use_cuda else "cpu"
49
+ results = model(frame, stream=True, device=device, conf=0.25)
50
  res = next(results)
51
 
52
+ # 若无分割结果,则记录 None 并写入原帧
53
  if res.masks is None or res.masks.data is None:
 
54
  for mid in range(num_mice):
55
  struggle_records[mid].append(None)
56
  out.write(frame)
 
59
 
60
  masks = res.masks.data.cpu().numpy() # [N, H, W]
61
 
62
+ # 质心计算与 ID 分配
63
  curr_centroids = []
64
  for m in masks:
65
  ys, xs = np.where(m > 0)
66
  curr_centroids.append(
67
  (int(np.mean(xs)), int(np.mean(ys))) if xs.size else None
68
  )
 
69
  assignments = [-1] * len(curr_centroids)
70
  unused_ids = set(range(num_mice))
 
71
  for i, c in enumerate(curr_centroids):
72
  if c is None: continue
73
  best_j, best_d = None, float("inf")
 
80
  if best_j is not None and best_d < (50**2):
81
  assignments[i] = best_j
82
  unused_ids.remove(best_j)
 
83
  for i in range(len(curr_centroids)):
84
  if assignments[i] < 0 and unused_ids:
85
  assignments[i] = unused_ids.pop()
86
 
87
+ # 计算挣扎强度 & 可视化叠加
88
  for i, m in enumerate(masks):
89
  mid = assignments[i]
90
+ if mid < 0: continue
 
91
  prev_m = prev_masks[mid]
92
  if prev_m is None:
93
  struggle_records[mid].append(None)
94
  else:
 
95
  diff = int(np.logical_xor(prev_m, m).sum())
96
  struggle_records[mid].append(diff)
97
 
98
+ # —— 关键修改:用 zeros_like 保证三通道形状一致 ——
99
+ mask_rgb = np.stack([
100
+ np.zeros_like(m), # 通道 0
101
+ m.astype(np.uint8)*255, # 通道 1
102
+ np.zeros_like(m) # 通道 2
103
+ ], axis=-1).astype(np.uint8) # 避免标量与数组形状不匹配 :contentReference[oaicite:2]{index=2}
104
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
105
+
106
  if curr_centroids[i]:
107
  cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
108
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
 
116
  cap.release()
117
  out.release()
118
 
119
+ # 汇总 & 绘制挣扎曲线
120
  win = int(window_size_sec * fps)
121
  fig, ax = plt.subplots(figsize=(8,4))
122
  times = np.arange(0, frame_idx, win) / fps
 
135
 
136
  return out_path, fig
137
 
138
+ # Gradio 前端部分保持不变
139
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
140
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
141
  with gr.Row():
 
149
  outputs=[output_video, output_plot])
150
 
151
  if __name__ == "__main__":
 
152
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)