Hakureirm committed on
Commit
7ae1738
·
verified ·
1 Parent(s): 93b6f29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -84
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # filename: app.py
2
-
3
  import spaces # 必须最先 import,用于 ZeroGPU 装饰
4
  import cv2
5
  import numpy as np
@@ -8,117 +6,119 @@ from ultralytics import YOLO # pip install ultralytics
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
 
11
- # 1. GPU 可用性检查 & 日志
12
- use_cuda = torch.cuda.is_available()
13
  print(f"CUDA available: {use_cuda}")
14
  if use_cuda:
15
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
16
 
17
- # 2. 加载模型并指定分割任务
18
- model = YOLO("fst-v1.2-n.onnx", task="segment")
19
  if use_cuda:
20
  try:
21
- model.model.to("cuda")
22
  except:
23
  pass
24
 
25
- @spaces.GPU(duration=600) # ZeroGPU 环境下执行该函数,超时 600s
26
- def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
 
 
 
 
 
27
  cap = cv2.VideoCapture(video_path)
28
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
29
- out_path = "output.mp4"
30
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
31
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
32
- out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
 
 
 
 
 
 
 
 
 
 
33
 
34
- prev_centroids = [None] * num_mice
35
- prev_masks = [None] * num_mice
36
  struggle_records = [[] for _ in range(num_mice)]
37
- frame_idx = 0
38
 
39
- while True:
40
  ret, frame = cap.read()
41
  if not ret:
42
  break
43
 
44
- # 分割推理
45
- device = "cuda" if use_cuda else "cpu"
46
- results = model(frame, stream=True, device=device, conf=0.25)
47
  res = next(results)
48
 
49
- # 无检测时直接写入并记录 None
50
- if res.masks is None or res.masks.data is None:
51
  for mid in range(num_mice):
52
  struggle_records[mid].append(None)
53
  out.write(frame)
54
  frame_idx += 1
55
  continue
56
 
57
- # 原始掩膜 (N, H_model, W_model)
58
- masks = res.masks.data.cpu().numpy()
59
 
60
- # 对齐到视频帧尺寸
61
- aligned_masks = []
62
  for m in masks:
63
- m_uint8 = (m > 0).astype(np.uint8)
64
- m_resized = cv2.resize(m_uint8, (width, height), interpolation=cv2.INTER_NEAREST)
65
- aligned_masks.append(m_resized)
66
- aligned_masks = np.array(aligned_masks)
67
-
68
- # 质心计算 & 分配 ID
69
- curr_centroids = []
70
- for m in aligned_masks:
71
  ys, xs = np.where(m > 0)
72
- curr_centroids.append(
73
- (int(np.mean(xs)), int(np.mean(ys))) if xs.size else None
74
- )
75
- assignments = [-1] * len(curr_centroids)
76
- unused_ids = set(range(num_mice))
77
-
78
- # 最近质心匹配
79
- for i, c in enumerate(curr_centroids):
80
- if c is None:
81
- continue
82
- best_j, best_d = None, float("inf")
83
- for j in unused_ids:
84
  pc = prev_centroids[j]
85
- if pc is None:
86
- continue
87
- d = (c[0] - pc[0])**2 + (c[1] - pc[1])**2
88
  if d < best_d:
89
  best_j, best_d = j, d
90
  if best_j is not None and best_d < 50**2:
91
- assignments[i] = best_j
92
- unused_ids.remove(best_j)
93
- for i in range(len(curr_centroids)):
94
- if assignments[i] < 0 and unused_ids:
95
- assignments[i] = unused_ids.pop()
96
-
97
- # 计算挣扎强度 & 可视化叠加
98
- for i, m in enumerate(aligned_masks):
99
- mid = assignments[i]
100
- if mid < 0:
101
- continue
102
- prev_m = prev_masks[mid]
103
- if prev_m is None:
104
  struggle_records[mid].append(None)
105
  else:
106
- diff = int(np.logical_xor(prev_m, m).sum())
107
  struggle_records[mid].append(diff)
108
 
109
- # 构建三通道掩膜,保证与 frame 形状一致
110
  mask_rgb = np.stack([
111
- np.zeros_like(m), # 通道 0
112
- m * 255, # 通道 1
113
- np.zeros_like(m) # 通道 2
114
  ], axis=-1).astype(np.uint8)
115
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
 
 
116
 
117
- if curr_centroids[i]:
118
- cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
119
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
120
-
121
- prev_centroids[mid] = curr_centroids[i]
122
  prev_masks[mid] = m.copy()
123
 
124
  out.write(frame)
@@ -127,18 +127,16 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
127
  cap.release()
128
  out.release()
129
 
130
- # 汇总 & 绘制挣扎曲线
131
- win = int(window_size_sec * fps)
132
  fig, ax = plt.subplots(figsize=(8,4))
133
- times = np.arange(0, frame_idx, win) / fps
134
  for mid, rec in enumerate(struggle_records):
135
- sums = [sum(v if v is not None else 0 for v in rec[i*win:(i+1)*win])
136
- for i in range(len(times))]
137
  ax.plot(times, sums, label=f"Mouse {mid}")
138
  first = next((i for i,v in enumerate(rec) if v is not None), None)
139
  if first is not None:
140
- ax.axvspan(0, first/fps, alpha=0.3, color='gray')
141
-
142
  ax.set_xlabel("Time (s)")
143
  ax.set_ylabel("Struggle Intensity")
144
  ax.set_title("Mouse Struggle Over Time")
@@ -148,16 +146,27 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
148
 
149
  # Gradio 前端
150
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
151
- gr.Markdown("上传视频,输入鼠标数量,点击 Run")
152
  with gr.Row():
153
- video_in = gr.Video(label="Input Video")
154
- num_in = gr.Number(value=1, precision=0, label="Number of Mice")
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  run_btn = gr.Button("Run")
156
  output_video = gr.Video(label="Annotated Video")
157
  output_plot = gr.Plot(label="Struggle Plot")
158
- run_btn.click(fn=analyze_video,
159
- inputs=[video_in, num_in],
160
- outputs=[output_video, output_plot])
161
 
162
  if __name__ == "__main__":
163
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
1
  import spaces # 必须最先 import,用于 ZeroGPU 装饰
2
  import cv2
3
  import numpy as np
 
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
 
9
# GPU availability check & logging.
# Bug fix: the variable was assigned as `evice_is_cuda` (typo) while every
# subsequent line reads `use_cuda`, which raised NameError at import time.
use_cuda = torch.cuda.is_available()
print(f"CUDA available: {use_cuda}")
if use_cuda:
    print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
14
 
15
# Load the segmentation model. The task must be given explicitly because it
# cannot be inferred from an .onnx weights file.
yolo_model = YOLO("fst-v1.2-n.onnx", task="segment")
if use_cuda:
    try:
        # Best-effort: ONNX-backed YOLO models may not expose a movable
        # torch module, so failure here is tolerated and inference stays on CPU.
        yolo_model.model.to("cuda")
    except Exception as e:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not swallowed; log instead of failing silently.
        print(f"Could not move model to CUDA, continuing on CPU: {e}")
22
 
23
+ @spaces.GPU(duration=600) # ZeroGPU 上运行,超时 600s
24
+ def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
25
+ """
26
+ 分割 → 跟踪 → 计算“挣扎强度”,只分析指定时间范围
27
+ 返回:标注后视频 & 挣扎曲线 Figure
28
+ """
29
+ # 打开视频,获取基本信息
30
  cap = cv2.VideoCapture(video_path)
31
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
32
+ vid_fps = cap.get(cv2.CAP_PROP_FPS) or fps
33
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
+ start_s, end_s = time_range
36
+ start_frame = int(start_s * vid_fps)
37
+ end_frame = int(end_s * vid_fps)
38
+
39
+ # 跳转到起始帧
40
+ cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
41
+
42
+ # 输出视频初始化
43
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
44
+ out_path = "output.mp4"
45
+ out = cv2.VideoWriter(out_path, fourcc, vid_fps, (width, height))
46
 
47
+ prev_centroids = [None] * num_mice
48
+ prev_masks = [None] * num_mice
49
  struggle_records = [[] for _ in range(num_mice)]
50
+ frame_idx = start_frame
51
 
52
+ while frame_idx <= end_frame:
53
  ret, frame = cap.read()
54
  if not ret:
55
  break
56
 
57
+ # 分割推理\ device = "cuda" if use_cuda else "cpu"
58
+ results = yolo_model(frame, stream=True, device=device, conf=0.25)
 
59
  res = next(results)
60
 
61
+ # 处理无检测帧\ if res.masks is None or res.masks.data is None:
 
62
  for mid in range(num_mice):
63
  struggle_records[mid].append(None)
64
  out.write(frame)
65
  frame_idx += 1
66
  continue
67
 
68
+ masks = res.masks.data.cpu().numpy() # (N, H_model, W_model)
 
69
 
70
+ # 对齐掩膜至帧尺寸
71
+ aligned = []
72
  for m in masks:
73
+ m_bin = (m > 0).astype(np.uint8)
74
+ m_res = cv2.resize(m_bin, (width, height), interpolation=cv2.INTER_NEAREST)
75
+ aligned.append(m_res)
76
+ aligned = np.array(aligned)
77
+
78
+ # 计算质心 & 分配 ID
79
+ curr_cent = []
80
+ for m in aligned:
81
  ys, xs = np.where(m > 0)
82
+ curr_cent.append((int(xs.mean()), int(ys.mean())) if xs.size else None)
83
+ assign = [-1] * len(curr_cent)
84
+ unused = set(range(num_mice))
85
+ for i, c in enumerate(curr_cent):
86
+ if c is None: continue
87
+ best_j, best_d = None, float('inf')
88
+ for j in unused:
 
 
 
 
 
89
  pc = prev_centroids[j]
90
+ if pc is None: continue
91
+ d = (c[0]-pc[0])**2 + (c[1]-pc[1])**2
 
92
  if d < best_d:
93
  best_j, best_d = j, d
94
  if best_j is not None and best_d < 50**2:
95
+ assign[i] = best_j
96
+ unused.remove(best_j)
97
+ for i in range(len(curr_cent)):
98
+ if assign[i] < 0 and unused:
99
+ assign[i] = unused.pop()
100
+
101
+ # 计算挣扎强度 & 叠加
102
+ for i, m in enumerate(aligned):
103
+ mid = assign[i]
104
+ if mid < 0: continue
105
+ pm = prev_masks[mid]
106
+ if pm is None:
 
107
  struggle_records[mid].append(None)
108
  else:
109
+ diff = int(np.logical_xor(pm, m).sum())
110
  struggle_records[mid].append(diff)
111
 
 
112
  mask_rgb = np.stack([
113
+ np.zeros_like(m),
114
+ m * 255,
115
+ np.zeros_like(m)
116
  ], axis=-1).astype(np.uint8)
117
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
118
+ if curr_cent[i]:
119
+ cv2.putText(frame, f"ID:{mid}", curr_cent[i], cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
120
 
121
+ prev_centroids[mid] = curr_cent[i]
 
 
 
 
122
  prev_masks[mid] = m.copy()
123
 
124
  out.write(frame)
 
127
  cap.release()
128
  out.release()
129
 
130
+ # 绘制挣扎曲线
131
+ win = int(window_size_sec * vid_fps)
132
  fig, ax = plt.subplots(figsize=(8,4))
133
+ times = np.arange(start_s, end_s, win/vid_fps)
134
  for mid, rec in enumerate(struggle_records):
135
+ sums = [sum(v if v is not None else 0 for v in rec[i*win:(i+1)*win]) for i in range(len(times))]
 
136
  ax.plot(times, sums, label=f"Mouse {mid}")
137
  first = next((i for i,v in enumerate(rec) if v is not None), None)
138
  if first is not None:
139
+ ax.axvspan(start_s, start_s+first/vid_fps, alpha=0.3, color='gray')
 
140
  ax.set_xlabel("Time (s)")
141
  ax.set_ylabel("Struggle Intensity")
142
  ax.set_title("Mouse Struggle Over Time")
 
146
 
147
# Gradio front-end
with gr.Blocks(title="Mice Struggle Analysis") as demo:
    gr.Markdown("上传视频,输入鼠标数量,选择分析时间范围,点击 Run")
    with gr.Row():
        video_in = gr.Video(label="Input Video")
        num_in = gr.Number(value=1, precision=0, label="Number of Mice")
        # NOTE(review): gr.RangeSlider is not part of core Gradio — it comes
        # from the `gradio_rangeslider` custom component. Confirm it is listed
        # in requirements.txt, otherwise this line raises AttributeError.
        time_range = gr.RangeSlider(
            label="Analysis Time Range (s)",
            minimum=0, maximum=1, value=(0, 1), step=1, disabled=True,
        )

    def get_video_duration(path):
        """Probe an uploaded video and enable the time-range slider.

        Returns a gr.update() that sets the slider maximum to the video
        duration in whole seconds and selects the full range.
        """
        cap = cv2.VideoCapture(path)
        # Bug fix: the original `fps = cap.get(...) or fps` read the local
        # name `fps` before assignment, raising UnboundLocalError whenever
        # the FPS property comes back as 0. Fall back to 30 fps instead.
        vid_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        dur = int(frames / vid_fps)
        cap.release()
        return gr.update(maximum=dur, value=(0, dur), disabled=False)

    # Activate the slider (and set its bounds) once a video is uploaded.
    video_in.change(fn=get_video_duration, inputs=video_in, outputs=time_range)

    run_btn = gr.Button("Run")
    output_video = gr.Video(label="Annotated Video")
    output_plot = gr.Plot(label="Struggle Plot")
    run_btn.click(fn=analyze_video,
                  inputs=[video_in, num_in, time_range],
                  outputs=[output_video, output_plot])

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable inside the Space container.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)