Hakureirm commited on
Commit
6dff84c
·
verified ·
1 Parent(s): 4996978

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -36
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  # filename: app.py
2
 
3
  import spaces # 必须最先 import,用于 ZeroGPU 装饰
@@ -8,26 +11,22 @@ from ultralytics import YOLO # pip install ultralytics
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
 
11
- # === 1. GPU 可用性检查 & 日志 ===
12
  use_cuda = torch.cuda.is_available()
13
- print(f"CUDA available: {use_cuda}") # 输出 CUDA 状态
14
  if use_cuda:
15
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
16
 
17
- # === 2. 加载模型 & 指定分割任务 ===
18
- model = YOLO("fst-v1.2-n.onnx", task="segment") # 明确 segment 任务,避免警告
19
  if use_cuda:
20
  try:
21
- model.model.to("cuda") # 将模型迁移到 GPU
22
  except:
23
  pass
24
 
25
  @spaces.GPU(duration=600) # ZeroGPU 环境下执行该函数,超时 600s
26
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
27
- """
28
- 分割 → 跟踪 → “挣扎强度” 计算
29
- 返回:标注后视频路径 & matplotlib Figure
30
- """
31
  cap = cv2.VideoCapture(video_path)
32
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
33
  out_path = "output.mp4"
@@ -45,12 +44,12 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
45
  if not ret:
46
  break
47
 
48
- # 3. 分割推理
49
  device = "cuda" if use_cuda else "cpu"
50
  results = model(frame, stream=True, device=device, conf=0.25)
51
  res = next(results)
52
 
53
- # 没有检测到掩膜时,全部记录 None 并写入原帧
54
  if res.masks is None or res.masks.data is None:
55
  for mid in range(num_mice):
56
  struggle_records[mid].append(None)
@@ -61,21 +60,15 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
61
  # 原始掩膜 (N, H_model, W_model)
62
  masks = res.masks.data.cpu().numpy()
63
 
64
- # 4. 对齐掩膜到原视频帧尺寸
65
  aligned_masks = []
66
  for m in masks:
67
- # 二值化掩膜 uint8
68
- m_uint8 = (m > 0).astype(np.uint8) # 0/1
69
- # 重采样到视频帧大小
70
- m_resized = cv2.resize(
71
- m_uint8,
72
- (width, height),
73
- interpolation=cv2.INTER_NEAREST # 保持二值特性
74
- )
75
  aligned_masks.append(m_resized)
76
- aligned_masks = np.array(aligned_masks) # 形状: (N, height, width) :contentReference[oaicite:3]{index=3}
77
 
78
- # 5. 跟踪: 质心计算 & ID 分配 (nearest-centroid)
79
  curr_centroids = []
80
  for m in aligned_masks:
81
  ys, xs = np.where(m > 0)
@@ -84,26 +77,31 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
84
  )
85
  assignments = [-1] * len(curr_centroids)
86
  unused_ids = set(range(num_mice))
 
 
87
  for i, c in enumerate(curr_centroids):
88
- if c is None: continue
 
89
  best_j, best_d = None, float("inf")
90
  for j in unused_ids:
91
  pc = prev_centroids[j]
92
- if pc is None: continue
93
- d = (c[0]-pc[0])**2 + (c[1]-pc[1])**2
 
94
  if d < best_d:
95
  best_j, best_d = j, d
96
- if best_j is not None and best_d < (50**2):
97
  assignments[i] = best_j
98
  unused_ids.remove(best_j)
99
  for i in range(len(curr_centroids)):
100
  if assignments[i] < 0 and unused_ids:
101
  assignments[i] = unused_ids.pop()
102
 
103
- # 6. “挣扎强度”计算 & 掩膜叠加
104
  for i, m in enumerate(aligned_masks):
105
  mid = assignments[i]
106
- if mid < 0: continue
 
107
  prev_m = prev_masks[mid]
108
  if prev_m is None:
109
  struggle_records[mid].append(None)
@@ -111,14 +109,14 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
111
  diff = int(np.logical_xor(prev_m, m).sum())
112
  struggle_records[mid].append(diff)
113
 
114
- # 关键:用 zeros_like 保证三通道形状一致
115
  mask_rgb = np.stack([
116
- np.zeros_like(m), # 通道 0
117
- m * 255, # 通道 1
118
- np.zeros_like(m) # 通道 2
119
- ], axis=-1).astype(np.uint8) :contentReference[oaicite:4]{index=4}
120
-
121
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
 
122
  if curr_centroids[i]:
123
  cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
124
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
@@ -132,7 +130,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
132
  cap.release()
133
  out.release()
134
 
135
- # 7. 挣扎曲线汇总 & 绘制
136
  win = int(window_size_sec * fps)
137
  fig, ax = plt.subplots(figsize=(8,4))
138
  times = np.arange(0, frame_idx, win) / fps
@@ -151,7 +149,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
151
 
152
  return out_path, fig
153
 
154
- # 8. Gradio 前端不变
155
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
156
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
157
  with gr.Row():
@@ -166,3 +164,9 @@ with gr.Blocks(title="Mice Struggle Analysis") as demo:
166
 
167
  if __name__ == "__main__":
168
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
1
+ 下面是去除了多余注释标记、保证语法正确的关键代码片段及完整 `app.py`。我已删掉所有 `:contentReference[…]` 等非 Python 语法内容,并确保 `mask_rgb` 构造部分形状一致。
2
+
3
+ ```python
4
  # filename: app.py
5
 
6
  import spaces # 必须最先 import,用于 ZeroGPU 装饰
 
11
  import gradio as gr
12
  import matplotlib.pyplot as plt
13
 
14
+ # 1. GPU 可用性检查 & 日志
15
  use_cuda = torch.cuda.is_available()
16
+ print(f"CUDA available: {use_cuda}")
17
  if use_cuda:
18
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
19
 
20
+ # 2. 加载模型并指定分割任务
21
+ model = YOLO("fst-v1.2-n.onnx", task="segment")
22
  if use_cuda:
23
  try:
24
+ model.model.to("cuda")
25
  except:
26
  pass
27
 
28
  @spaces.GPU(duration=600) # ZeroGPU 环境下执行该函数,超时 600s
29
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
 
 
 
 
30
  cap = cv2.VideoCapture(video_path)
31
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
32
  out_path = "output.mp4"
 
44
  if not ret:
45
  break
46
 
47
+ # 分割推理
48
  device = "cuda" if use_cuda else "cpu"
49
  results = model(frame, stream=True, device=device, conf=0.25)
50
  res = next(results)
51
 
52
+ # 无检测时直接写入并记录 None
53
  if res.masks is None or res.masks.data is None:
54
  for mid in range(num_mice):
55
  struggle_records[mid].append(None)
 
60
  # 原始掩膜 (N, H_model, W_model)
61
  masks = res.masks.data.cpu().numpy()
62
 
63
+ # 对齐到视频帧尺寸
64
  aligned_masks = []
65
  for m in masks:
66
+ m_uint8 = (m > 0).astype(np.uint8)
67
+ m_resized = cv2.resize(m_uint8, (width, height), interpolation=cv2.INTER_NEAREST)
 
 
 
 
 
 
68
  aligned_masks.append(m_resized)
69
+ aligned_masks = np.array(aligned_masks)
70
 
71
+ # 质心计算 & 分配 ID
72
  curr_centroids = []
73
  for m in aligned_masks:
74
  ys, xs = np.where(m > 0)
 
77
  )
78
  assignments = [-1] * len(curr_centroids)
79
  unused_ids = set(range(num_mice))
80
+
81
+ # 最近质心匹配
82
  for i, c in enumerate(curr_centroids):
83
+ if c is None:
84
+ continue
85
  best_j, best_d = None, float("inf")
86
  for j in unused_ids:
87
  pc = prev_centroids[j]
88
+ if pc is None:
89
+ continue
90
+ d = (c[0] - pc[0])**2 + (c[1] - pc[1])**2
91
  if d < best_d:
92
  best_j, best_d = j, d
93
+ if best_j is not None and best_d < 50**2:
94
  assignments[i] = best_j
95
  unused_ids.remove(best_j)
96
  for i in range(len(curr_centroids)):
97
  if assignments[i] < 0 and unused_ids:
98
  assignments[i] = unused_ids.pop()
99
 
100
+ # 计算挣扎强度 & 可视化叠加
101
  for i, m in enumerate(aligned_masks):
102
  mid = assignments[i]
103
+ if mid < 0:
104
+ continue
105
  prev_m = prev_masks[mid]
106
  if prev_m is None:
107
  struggle_records[mid].append(None)
 
109
  diff = int(np.logical_xor(prev_m, m).sum())
110
  struggle_records[mid].append(diff)
111
 
112
+ # 构建三通道掩膜,保证与 frame 形状一致
113
  mask_rgb = np.stack([
114
+ np.zeros_like(m), # 通道 0
115
+ m * 255, # 通道 1
116
+ np.zeros_like(m) # 通道 2
117
+ ], axis=-1).astype(np.uint8)
 
118
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
119
+
120
  if curr_centroids[i]:
121
  cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
122
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
 
130
  cap.release()
131
  out.release()
132
 
133
+ # 汇总 & 绘制挣扎曲线
134
  win = int(window_size_sec * fps)
135
  fig, ax = plt.subplots(figsize=(8,4))
136
  times = np.arange(0, frame_idx, win) / fps
 
149
 
150
  return out_path, fig
151
 
152
+ # Gradio 前端
153
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
154
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
155
  with gr.Row():
 
164
 
165
  if __name__ == "__main__":
166
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
167
+ ```
168
+
169
+ **说明:**
170
+ - 已彻底移除所有非 Python 语法的注释标记。
171
+ - `mask_rgb` 构造时三通道数组均为与 `m` 同形状,确保 `cv2.addWeighted` 能正常运行。
172
+ - 其他逻辑与之前保持一致,部署即可通过。