Hakureirm commited on
Commit
689811f
·
verified ·
1 Parent(s): 38e81fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -88
app.py CHANGED
@@ -1,20 +1,36 @@
1
  # filename: app.py
2
 
 
3
  import cv2
4
  import numpy as np
5
- from ultralytics import YOLO # pip install ultralytics :contentReference[oaicite:2]{index=2}
 
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
-
9
- # 1. 加载已训练好的分割模型
10
- model = YOLO("fst-v1.2-n.onnx") # 模型文件需手动上传至 Space :contentReference[oaicite:3]{index=3}
11
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
13
  """
14
- 核心分析函数:对上传视频进行分割、跟踪与挣扎强度计算
15
- 返回:标注后的视频路径 & 挣扎强度曲线图(matplotlib Figure
16
  """
17
- # 视频读取与输出配置
18
  cap = cv2.VideoCapture(video_path)
19
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
20
  out_path = "output.mp4"
@@ -22,81 +38,65 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
22
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
23
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
24
 
25
- # 跟踪数据结构:每只鼠标保留上帧质心、掩膜
26
- prev_centroids = [None]*num_mice
27
- prev_masks = [None]*num_mice
28
- # 时间序列数据:每只鼠标每帧的“挣扎程度”
29
  struggle_records = [[] for _ in range(num_mice)]
30
-
31
  frame_idx = 0
 
32
  while True:
33
  ret, frame = cap.read()
34
- if not ret:
35
- break
36
 
37
- # 2. 分割推理(stream=True 可加速):
38
- results = model(frame, stream=True, device='cpu')
39
- # 取第一张结果
40
- res = next(results)
41
- masks = res.masks.data.cpu().numpy() # shape: [N, H, W]
42
- # 只保留 tag="mice" 的结果(假设模型只检测 mice 类)
43
- # masks 已经是二值化
44
 
45
- # 计算当前帧每个实例的质心
46
  curr_centroids = []
47
  for m in masks:
48
  ys, xs = np.where(m > 0)
49
- if len(xs)==0:
50
- curr_centroids.append(None)
51
- else:
52
- curr_centroids.append((int(np.mean(xs)), int(np.mean(ys))))
53
-
54
- # 3. 质心匹配分配 ID
55
- assignments = [-1]*len(curr_centroids)
56
- unused_prev = set(range(num_mice))
57
  for i, c in enumerate(curr_centroids):
58
- if c is None:
59
- continue
60
- # 找到距离最近的上一帧质心
61
- best_j, best_dist = None, float('inf')
62
- for j in unused_prev:
63
  pc = prev_centroids[j]
64
  if pc is None: continue
65
  d = (c[0]-pc[0])**2 + (c[1]-pc[1])**2
66
- if d < best_dist:
67
- best_j, best_dist = j, d
68
- if best_j is not None and best_dist < (50**2): # 距离阈值 50
69
  assignments[i] = best_j
70
- unused_prev.remove(best_j)
71
- # 未匹配的实例新分配 ID
72
  for i in range(len(curr_centroids)):
73
- if assignments[i] == -1 and unused_prev:
74
- assignments[i] = unused_prev.pop()
75
 
76
- # 4. 计算“挣扎强度” & 叠加绘制
77
  for i, m in enumerate(masks):
78
- id_ = assignments[i]
79
- if id_ is None or id_<0:
80
- continue
81
- prev_m = prev_masks[id_]
82
  if prev_m is None:
83
- # 未检测到前,标记为 None
84
- struggle_records[id_].append(None)
85
  else:
86
- # XOR 统计像素差异
87
- diff = np.logical_xor(prev_m, m).sum()
88
- struggle_records[id_].append(int(diff))
89
- # 叠加掩膜 & ID
90
- color = (0,255,0)
91
- mask_rgb = np.stack([m*color[c] for c in range(3)], axis=-1).astype(np.uint8)
92
- frame = cv2.addWeighted(frame,1,mask_rgb,0.5,0)
93
  if curr_centroids[i]:
94
- cv2.putText(frame, f"ID:{id_}", curr_centroids[i],
95
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
96
-
97
- # 更新上一帧数据
98
- prev_centroids[id_] = curr_centroids[i]
99
- prev_masks[id_] = m.copy()
100
 
101
  out.write(frame)
102
  frame_idx += 1
@@ -104,38 +104,31 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
104
  cap.release()
105
  out.release()
106
 
107
- # 5. 按时间窗口汇总并绘制
108
- win_size = int(window_size_sec * fps)
109
  fig, ax = plt.subplots(figsize=(8,4))
110
- times = np.arange(0, frame_idx, win_size) / fps
111
- for id_, records in enumerate(struggle_records):
112
- # 将记录按窗口求和,None视为 0 或保持空白
113
- sums = []
114
- for w in range(len(times)):
115
- segment = records[w*win_size:(w+1)*win_size]
116
- # 把 None 当作 0,但在绘图时保留空白
117
- vals = [v if v is not None else 0 for v in segment]
118
- sums.append(sum(vals))
119
- ax.plot(times, sums, label=f"Mouse {id_}")
120
- # 标记 None 区间
121
- first_detect = next((i for i,v in enumerate(records) if v is not None), None)
122
- if first_detect:
123
- ax.axvspan(0, first_detect/fps, color='grey', alpha=0.3)
124
-
125
  ax.set_xlabel("Time (s)")
126
  ax.set_ylabel("Struggle Intensity")
127
- ax.legend()
128
  ax.set_title("Mouse Struggle Over Time")
 
129
 
130
  return out_path, fig
131
 
132
- # 6. Gradio 接口
133
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
134
- gr.Markdown("上传实验视频,输入鼠标数量,点击 Run 开始分析。")
135
  with gr.Row():
136
  video_in = gr.Video(label="Input Video")
137
  num_in = gr.Number(value=1, precision=0, label="Number of Mice")
138
- run_btn = gr.Button("Run")
139
  output_video = gr.Video(label="Annotated Video")
140
  output_plot = gr.Plot(label="Struggle Plot")
141
  run_btn.click(fn=analyze_video,
@@ -143,7 +136,6 @@ with gr.Blocks(title="Mice Struggle Analysis") as demo:
143
  outputs=[output_video, output_plot])
144
 
145
  if __name__ == "__main__":
146
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False,
147
- inbrowser=False,
148
- # Zero GPU 环境下设置 600s 超时
149
- api_config={"timeout":600})
 
1
  # filename: app.py
2
 
3
+ import os
4
  import cv2
5
  import numpy as np
6
+ import torch
7
+ from ultralytics import YOLO # pip install ultralytics
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
+ import spaces # ZeroGPU 装饰器
11
+
12
+ # === GPU 可用性检查 & 日志输出 ===
13
+ use_cuda = torch.cuda.is_available()
14
+ print(f"CUDA available: {use_cuda}")
15
+ if use_cuda:
16
+ print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
+
18
+ # 1. 加载并迁移模型到 GPU/CPU
19
+ model = YOLO("fst-v1.2-n.onnx")
20
+ if use_cuda:
21
+ # ONNX 推理时可通过 .model(PyTorch)迁移,或在推理调用时指定 device="cuda"
22
+ try:
23
+ model.model.to("cuda")
24
+ except Exception:
25
+ pass
26
+
27
+ @spaces.GPU(duration=600) # 调用时分配 GPU,600s 超时 :contentReference[oaicite:4]{index=4}
28
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
29
  """
30
+ 对视频进行分割 → 跟踪 → '挣扎强度' 计算
31
+ 返回:标注后的视频路径 & 挣扎强度曲线 (matplotlib Figure)
32
  """
33
+ # 2. 视频读取 & 输出初始化
34
  cap = cv2.VideoCapture(video_path)
35
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
36
  out_path = "output.mp4"
 
38
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
39
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
40
 
41
+ # 3. 跟踪 & 记录结构
42
+ prev_centroids = [None] * num_mice
43
+ prev_masks = [None] * num_mice
 
44
  struggle_records = [[] for _ in range(num_mice)]
 
45
  frame_idx = 0
46
+
47
  while True:
48
  ret, frame = cap.read()
49
+ if not ret: break
 
50
 
51
+ # 4. 分割推理 (stream=True 加速),指定 device
52
+ device = "cuda" if use_cuda else "cpu"
53
+ results = model(frame, stream=True, device=device)
54
+ res = next(results)
55
+ masks = res.masks.data.cpu().numpy() # [N, H, W]
 
 
56
 
57
+ # 5. 计算质心 & ID 分配 (nearest-centroid)
58
  curr_centroids = []
59
  for m in masks:
60
  ys, xs = np.where(m > 0)
61
+ curr_centroids.append(
62
+ (int(np.mean(xs)), int(np.mean(ys))) if len(xs)>0 else None
63
+ )
64
+ assignments = [-1] * len(curr_centroids)
65
+ unused_ids = set(range(num_mice))
 
 
 
66
  for i, c in enumerate(curr_centroids):
67
+ if c is None: continue
68
+ best_j, best_d = None, float("inf")
69
+ for j in unused_ids:
 
 
70
  pc = prev_centroids[j]
71
  if pc is None: continue
72
  d = (c[0]-pc[0])**2 + (c[1]-pc[1])**2
73
+ if d < best_d:
74
+ best_j, best_d = j, d
75
+ if best_j is not None and best_d < (50**2):
76
  assignments[i] = best_j
77
+ unused_ids.remove(best_j)
 
78
  for i in range(len(curr_centroids)):
79
+ if assignments[i]<0 and unused_ids:
80
+ assignments[i] = unused_ids.pop()
81
 
82
+ # 6. 计算挣扎强度 & 渲染
83
  for i, m in enumerate(masks):
84
+ mid = assignments[i]
85
+ if mid is None or mid<0: continue
86
+ prev_m = prev_masks[mid]
 
87
  if prev_m is None:
88
+ struggle_records[mid].append(None)
 
89
  else:
90
+ diff = int(np.logical_xor(prev_m, m).sum())
91
+ struggle_records[mid].append(diff)
92
+ # 可视化 mask & ID
93
+ mask_rgb = np.stack([m*255 if c==1 else 0 for c in range(3)], axis=-1).astype(np.uint8)
94
+ frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
 
 
95
  if curr_centroids[i]:
96
+ cv2.putText(frame, f"ID:{mid}", curr_centroids[i],
97
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
98
+ prev_centroids[mid] = curr_centroids[i]
99
+ prev_masks[mid] = m.copy()
 
 
100
 
101
  out.write(frame)
102
  frame_idx += 1
 
104
  cap.release()
105
  out.release()
106
 
107
+ # 7. 汇总 & 绘制挣扎曲线
108
+ win = int(window_size_sec * fps)
109
  fig, ax = plt.subplots(figsize=(8,4))
110
+ times = np.arange(0, frame_idx, win) / fps
111
+ for mid, rec in enumerate(struggle_records):
112
+ sums = [sum(v if v is not None else 0 for v in rec[i*win:(i+1)*win])
113
+ for i in range(len(times))]
114
+ ax.plot(times, sums, label=f"Mouse {mid}")
115
+ first = next((i for i,v in enumerate(rec) if v is not None), None)
116
+ if first:
117
+ ax.axvspan(0, first/fps, alpha=0.3, color='gray')
 
 
 
 
 
 
 
118
  ax.set_xlabel("Time (s)")
119
  ax.set_ylabel("Struggle Intensity")
 
120
  ax.set_title("Mouse Struggle Over Time")
121
+ ax.legend()
122
 
123
  return out_path, fig
124
 
125
+ # 8. Gradio 前端
126
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
127
+ gr.Markdown("上传视频,输入鼠标数量,点击 Run")
128
  with gr.Row():
129
  video_in = gr.Video(label="Input Video")
130
  num_in = gr.Number(value=1, precision=0, label="Number of Mice")
131
+ run_btn = gr.Button("Run")
132
  output_video = gr.Video(label="Annotated Video")
133
  output_plot = gr.Plot(label="Struggle Plot")
134
  run_btn.click(fn=analyze_video,
 
136
  outputs=[output_video, output_plot])
137
 
138
  if __name__ == "__main__":
139
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False,
140
+ inbrowser=False,
141
+ api_config={"timeout":600}) # 保持 600s 超时 :contentReference[oaicite:5]{index=5}