Hakureirm commited on
Commit
65d5672
·
verified ·
1 Parent(s): 689811f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -23
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # filename: app.py
2
 
 
3
  import os
4
  import cv2
5
  import numpy as np
@@ -7,30 +8,31 @@ import torch
7
  from ultralytics import YOLO # pip install ultralytics
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
- import spaces # ZeroGPU 装饰器
11
 
12
- # === GPU 可用性检查 & 日志输出 ===
13
  use_cuda = torch.cuda.is_available()
14
  print(f"CUDA available: {use_cuda}")
15
  if use_cuda:
16
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
- # 1. 加载并迁移模型到 GPU/CPU
19
- model = YOLO("fst-v1.2-n.onnx")
 
 
 
20
  if use_cuda:
21
- # ONNX 推理时可通过 .model(PyTorch)迁移,或在推理调用时指定 device="cuda"
22
  try:
23
  model.model.to("cuda")
24
  except Exception:
25
  pass
26
 
27
- @spaces.GPU(duration=600) # 调用时分配 GPU600s 超时 :contentReference[oaicite:4]{index=4}
28
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
29
  """
30
- 对视频进行分割 → 跟踪 → '挣扎强度' 计算
31
- 返回:标注后的视频路径 & 挣扎强度曲线 (matplotlib Figure)
32
  """
33
- # 2. 视频读取 & 输出初始化
34
  cap = cv2.VideoCapture(video_path)
35
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
36
  out_path = "output.mp4"
@@ -38,7 +40,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
38
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
39
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
40
 
41
- # 3. 跟踪 & 记录结构
42
  prev_centroids = [None] * num_mice
43
  prev_masks = [None] * num_mice
44
  struggle_records = [[] for _ in range(num_mice)]
@@ -46,15 +48,16 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
46
 
47
  while True:
48
  ret, frame = cap.read()
49
- if not ret: break
 
50
 
51
- # 4. 分割推理 (stream=True 加速),指定 device
52
  device = "cuda" if use_cuda else "cpu"
53
  results = model(frame, stream=True, device=device)
54
  res = next(results)
55
  masks = res.masks.data.cpu().numpy() # [N, H, W]
56
 
57
- # 5. 计算质心 & ID 分配 (nearest-centroid)
58
  curr_centroids = []
59
  for m in masks:
60
  ys, xs = np.where(m > 0)
@@ -76,20 +79,20 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
76
  assignments[i] = best_j
77
  unused_ids.remove(best_j)
78
  for i in range(len(curr_centroids)):
79
- if assignments[i]<0 and unused_ids:
80
  assignments[i] = unused_ids.pop()
81
 
82
- # 6. 计算挣扎强度 & 渲染
83
  for i, m in enumerate(masks):
84
  mid = assignments[i]
85
- if mid is None or mid<0: continue
86
  prev_m = prev_masks[mid]
87
  if prev_m is None:
88
  struggle_records[mid].append(None)
89
  else:
90
  diff = int(np.logical_xor(prev_m, m).sum())
91
  struggle_records[mid].append(diff)
92
- # 可视化 mask & ID
93
  mask_rgb = np.stack([m*255 if c==1 else 0 for c in range(3)], axis=-1).astype(np.uint8)
94
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
95
  if curr_centroids[i]:
@@ -104,7 +107,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
104
  cap.release()
105
  out.release()
106
 
107
- # 7. 汇总 & 绘制挣扎曲线
108
  win = int(window_size_sec * fps)
109
  fig, ax = plt.subplots(figsize=(8,4))
110
  times = np.arange(0, frame_idx, win) / fps
@@ -113,7 +116,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
113
  for i in range(len(times))]
114
  ax.plot(times, sums, label=f"Mouse {mid}")
115
  first = next((i for i,v in enumerate(rec) if v is not None), None)
116
- if first:
117
  ax.axvspan(0, first/fps, alpha=0.3, color='gray')
118
  ax.set_xlabel("Time (s)")
119
  ax.set_ylabel("Struggle Intensity")
@@ -122,7 +125,7 @@ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
122
 
123
  return out_path, fig
124
 
125
- # 8. Gradio 前端
126
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
127
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
128
  with gr.Row():
@@ -136,6 +139,5 @@ with gr.Blocks(title="Mice Struggle Analysis") as demo:
136
  outputs=[output_video, output_plot])
137
 
138
  if __name__ == "__main__":
139
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False,
140
- inbrowser=False,
141
- api_config={"timeout":600}) # 保持 600s 超时 :contentReference[oaicite:5]{index=5}
 
1
  # filename: app.py
2
 
3
+ import spaces # 必须最先 import
4
  import os
5
  import cv2
6
  import numpy as np
 
8
  from ultralytics import YOLO # pip install ultralytics
9
  import gradio as gr
10
  import matplotlib.pyplot as plt
 
11
 
12
+ # === 1. GPU 可用性检查 & 日志 ===
13
  use_cuda = torch.cuda.is_available()
14
  print(f"CUDA available: {use_cuda}")
15
  if use_cuda:
16
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
+ # === 2. 加载模型并显式指定 task ===
19
+ # 避免无法猜测任务类型的警告,明确使用分割 (segment) 模式
20
+ model = YOLO("fst-v1.2-n.onnx", task="segment") # ONNX 模型需上传至空间
21
+
22
+ # 若 CUDA 可用,迁移模型至 GPU
23
  if use_cuda:
 
24
  try:
25
  model.model.to("cuda")
26
  except Exception:
27
  pass
28
 
29
+ @spaces.GPU(duration=600) # 调用时分配 GPU,超时 600s :contentReference[oaicite:3]{index=3}
30
  def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
31
  """
32
+ 核心分析:分割 → 跟踪 → 计算“挣扎强度”
33
+ 返回:标注后的视频路径 & 绘制好的挣扎曲线 (matplotlib Figure)
34
  """
35
+ # 视频读写设置
36
  cap = cv2.VideoCapture(video_path)
37
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
38
  out_path = "output.mp4"
 
40
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
41
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
42
 
43
+ # 跟踪初始化
44
  prev_centroids = [None] * num_mice
45
  prev_masks = [None] * num_mice
46
  struggle_records = [[] for _ in range(num_mice)]
 
48
 
49
  while True:
50
  ret, frame = cap.read()
51
+ if not ret:
52
+ break
53
 
54
+ # 分割推理(stream=True 加速),指定 device
55
  device = "cuda" if use_cuda else "cpu"
56
  results = model(frame, stream=True, device=device)
57
  res = next(results)
58
  masks = res.masks.data.cpu().numpy() # [N, H, W]
59
 
60
+ # 计算质心 & 分配 ID(nearest-centroid
61
  curr_centroids = []
62
  for m in masks:
63
  ys, xs = np.where(m > 0)
 
79
  assignments[i] = best_j
80
  unused_ids.remove(best_j)
81
  for i in range(len(curr_centroids)):
82
+ if assignments[i] < 0 and unused_ids:
83
  assignments[i] = unused_ids.pop()
84
 
85
+ # 计算挣扎强度 & 可视化
86
  for i, m in enumerate(masks):
87
  mid = assignments[i]
88
+ if mid < 0: continue
89
  prev_m = prev_masks[mid]
90
  if prev_m is None:
91
  struggle_records[mid].append(None)
92
  else:
93
  diff = int(np.logical_xor(prev_m, m).sum())
94
  struggle_records[mid].append(diff)
95
+ # 叠加掩膜 & ID
96
  mask_rgb = np.stack([m*255 if c==1 else 0 for c in range(3)], axis=-1).astype(np.uint8)
97
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
98
  if curr_centroids[i]:
 
107
  cap.release()
108
  out.release()
109
 
110
+ # 汇总 & 绘图
111
  win = int(window_size_sec * fps)
112
  fig, ax = plt.subplots(figsize=(8,4))
113
  times = np.arange(0, frame_idx, win) / fps
 
116
  for i in range(len(times))]
117
  ax.plot(times, sums, label=f"Mouse {mid}")
118
  first = next((i for i,v in enumerate(rec) if v is not None), None)
119
+ if first is not None:
120
  ax.axvspan(0, first/fps, alpha=0.3, color='gray')
121
  ax.set_xlabel("Time (s)")
122
  ax.set_ylabel("Struggle Intensity")
 
125
 
126
  return out_path, fig
127
 
128
+ # Gradio 前端
129
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
130
  gr.Markdown("上传视频,输入鼠标数量,点击 Run。")
131
  with gr.Row():
 
139
  outputs=[output_video, output_plot])
140
 
141
  if __name__ == "__main__":
142
+ # 去除不支持的 api_config 参数
143
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False) # ⚠️ 删除 api_config :contentReference[oaicite:4]{index=4}