Hakureirm commited on
Commit
4996978
·
verified ·
1 Parent(s): 548f599

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -23
app.py CHANGED
@@ -1,29 +1,28 @@
1
  # filename: app.py
2
 
3
- import spaces # 必须首 import,用于 ZeroGPU 装饰
4
- import os
5
  import cv2
6
  import numpy as np
7
  import torch
8
- from ultralytics import YOLO # pip install ultralytics
9
  import gradio as gr
10
  import matplotlib.pyplot as plt
11
 
12
  # === 1. GPU 可用性检查 & 日志 ===
13
  use_cuda = torch.cuda.is_available()
14
- print(f"CUDA available: {use_cuda}")
15
  if use_cuda:
16
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
- # === 2. 加载模型并显式指定 segmentation 任务 ===
19
- model = YOLO("fst-v1.2-n.onnx", task="segment") # ONNX 模型需上传至 Space
20
  if use_cuda:
21
  try:
22
- model.model.to("cuda")
23
  except:
24
  pass
25
 
26
- @spaces.GPU(duration=600) # ZeroGPU 上执行本函数,超时 600s
27
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
28
  """
29
  分割 → 跟踪 → “挣扎强度” 计算
@@ -43,13 +42,15 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
43
 
44
  while True:
45
  ret, frame = cap.read()
46
- if not ret: break
 
47
 
 
48
  device = "cuda" if use_cuda else "cpu"
49
  results = model(frame, stream=True, device=device, conf=0.25)
50
  res = next(results)
51
 
52
- # 若无分割结果,则记录 None 并写入原帧
53
  if res.masks is None or res.masks.data is None:
54
  for mid in range(num_mice):
55
  struggle_records[mid].append(None)
@@ -57,11 +58,26 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
57
  frame_idx += 1
58
  continue
59
 
60
- masks = res.masks.data.cpu().numpy() # [N, H, W]
 
61
 
62
- # 质心计算与 ID 分配
63
- curr_centroids = []
64
  for m in masks:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  ys, xs = np.where(m > 0)
66
  curr_centroids.append(
67
  (int(np.mean(xs)), int(np.mean(ys))) if xs.size else None
@@ -84,8 +100,8 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
84
  if assignments[i] < 0 and unused_ids:
85
  assignments[i] = unused_ids.pop()
86
 
87
- # 计算挣扎强度 & 可视化叠加
88
- for i, m in enumerate(masks):
89
  mid = assignments[i]
90
  if mid < 0: continue
91
  prev_m = prev_masks[mid]
@@ -95,14 +111,14 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
95
  diff = int(np.logical_xor(prev_m, m).sum())
96
  struggle_records[mid].append(diff)
97
 
98
- # —— 关键修改:用 zeros_like 保证三通道形状一致 ——
99
  mask_rgb = np.stack([
100
- np.zeros_like(m), # 通道 0
101
- m.astype(np.uint8)*255, # 通道 1
102
- np.zeros_like(m) # 通道 2
103
- ], axis=-1).astype(np.uint8) # 避免标量与数组形状不匹配 :contentReference[oaicite:2]{index=2}
104
- frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
105
 
 
106
  if curr_centroids[i]:
107
  cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
108
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
@@ -116,7 +132,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
116
  cap.release()
117
  out.release()
118
 
119
- # 汇总 & 绘制挣扎曲线
120
  win = int(window_size_sec * fps)
121
  fig, ax = plt.subplots(figsize=(8,4))
122
  times = np.arange(0, frame_idx, win) / fps
@@ -135,7 +151,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
135
 
136
  return out_path, fig
137
 
138
- # Gradio 前端部分保持不变
139
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
140
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
141
  with gr.Row():
 
1
  # filename: app.py
2
 
3
+ import spaces # 必须最先 import,用于 ZeroGPU 装饰
 
4
  import cv2
5
  import numpy as np
6
  import torch
7
+ from ultralytics import YOLO # pip install ultralytics
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
 
11
  # === 1. GPU 可用性检查 & 日志 ===
12
  use_cuda = torch.cuda.is_available()
13
+ print(f"CUDA available: {use_cuda}") # 输出 CUDA 状态
14
  if use_cuda:
15
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
16
 
17
+ # === 2. 加载模型 & 指定分割任务 ===
18
+ model = YOLO("fst-v1.2-n.onnx", task="segment") # 明确 segment 任务,避免警告
19
  if use_cuda:
20
  try:
21
+ model.model.to("cuda") # 将模型迁移到 GPU
22
  except:
23
  pass
24
 
25
+ @spaces.GPU(duration=600) # ZeroGPU 环境下执行该函数,超时 600s
26
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
27
  """
28
  分割 → 跟踪 → “挣扎强度” 计算
 
42
 
43
  while True:
44
  ret, frame = cap.read()
45
+ if not ret:
46
+ break
47
 
48
+ # 3. 分割推理
49
  device = "cuda" if use_cuda else "cpu"
50
  results = model(frame, stream=True, device=device, conf=0.25)
51
  res = next(results)
52
 
53
+ # 没有检测到掩膜时,全部记录 None 并写入原帧
54
  if res.masks is None or res.masks.data is None:
55
  for mid in range(num_mice):
56
  struggle_records[mid].append(None)
 
58
  frame_idx += 1
59
  continue
60
 
61
+ # 原始掩膜 (N, H_model, W_model)
62
+ masks = res.masks.data.cpu().numpy()
63
 
64
+ # 4. 对齐掩膜到原视频帧尺寸
65
+ aligned_masks = []
66
  for m in masks:
67
+ # 二值化掩膜 → uint8
68
+ m_uint8 = (m > 0).astype(np.uint8) # 0/1
69
+ # 重采样到视频帧大小
70
+ m_resized = cv2.resize(
71
+ m_uint8,
72
+ (width, height),
73
+ interpolation=cv2.INTER_NEAREST # 保持二值特性
74
+ )
75
+ aligned_masks.append(m_resized)
76
+ aligned_masks = np.array(aligned_masks) # 形状: (N, height, width) :contentReference[oaicite:3]{index=3}
77
+
78
+ # 5. 跟踪: 质心计算 & ID 分配 (nearest-centroid)
79
+ curr_centroids = []
80
+ for m in aligned_masks:
81
  ys, xs = np.where(m > 0)
82
  curr_centroids.append(
83
  (int(np.mean(xs)), int(np.mean(ys))) if xs.size else None
 
100
  if assignments[i] < 0 and unused_ids:
101
  assignments[i] = unused_ids.pop()
102
 
103
+ # 6. “挣扎强度”计算 & 掩膜叠加
104
+ for i, m in enumerate(aligned_masks):
105
  mid = assignments[i]
106
  if mid < 0: continue
107
  prev_m = prev_masks[mid]
 
111
  diff = int(np.logical_xor(prev_m, m).sum())
112
  struggle_records[mid].append(diff)
113
 
114
+ # 关键:用 zeros_like 保证三通道形状一致
115
  mask_rgb = np.stack([
116
+ np.zeros_like(m), # 通道 0
117
+ m * 255, # 通道 1
118
+ np.zeros_like(m) # 通道 2
119
+ ], axis=-1).astype(np.uint8) :contentReference[oaicite:4]{index=4}
 
120
 
121
+ frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
122
  if curr_centroids[i]:
123
  cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
124
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
 
132
  cap.release()
133
  out.release()
134
 
135
+ # 7. 挣扎曲线汇总 & 绘制
136
  win = int(window_size_sec * fps)
137
  fig, ax = plt.subplots(figsize=(8,4))
138
  times = np.arange(0, frame_idx, win) / fps
 
151
 
152
  return out_path, fig
153
 
154
+ # 8. Gradio 前端不变
155
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
156
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
157
  with gr.Row():