Hakureirm commited on
Commit
d739c25
·
verified ·
1 Parent(s): e5f4569

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -0
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # filename: app.py
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from ultralytics import YOLO # pip install ultralytics :contentReference[oaicite:2]{index=2}
6
+ import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+
9
+ # 1. 加载已训练好的分割模型
10
+ model = YOLO("yolo11n-seg.pt") # 模型文件需手动上传至 Space :contentReference[oaicite:3]{index=3}
11
+
12
+ def analyze_video(video_path, num_mice, window_size_sec=1, fps=30):
13
+ """
14
+ 核心分析函数:对上传视频进行分割、跟踪与挣扎强度计算
15
+ 返回:标注后的视频路径 & 挣扎强度曲线图(matplotlib Figure)
16
+ """
17
+ # 视频读取与输出配置
18
+ cap = cv2.VideoCapture(video_path)
19
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
20
+ out_path = "output.mp4"
21
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
22
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
23
+ out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
24
+
25
+ # 跟踪数据结构:每只鼠标保留上帧质心、掩膜
26
+ prev_centroids = [None]*num_mice
27
+ prev_masks = [None]*num_mice
28
+ # 时间序列数据:每只鼠标每帧的“挣扎程度”
29
+ struggle_records = [[] for _ in range(num_mice)]
30
+
31
+ frame_idx = 0
32
+ while True:
33
+ ret, frame = cap.read()
34
+ if not ret:
35
+ break
36
+
37
+ # 2. 分割推理(stream=True 可加速):
38
+ results = model(frame, stream=True, device='cpu')
39
+ # 取第一张结果
40
+ res = next(results)
41
+ masks = res.masks.data.cpu().numpy() # shape: [N, H, W]
42
+ # 只保留 tag="mice" 的结果(假设模型只检测 mice 类)
43
+ # masks 已经是二值化
44
+
45
+ # 计算当前帧每个实例的质心
46
+ curr_centroids = []
47
+ for m in masks:
48
+ ys, xs = np.where(m > 0)
49
+ if len(xs)==0:
50
+ curr_centroids.append(None)
51
+ else:
52
+ curr_centroids.append((int(np.mean(xs)), int(np.mean(ys))))
53
+
54
+ # 3. 质心匹配分配 ID
55
+ assignments = [-1]*len(curr_centroids)
56
+ unused_prev = set(range(num_mice))
57
+ for i, c in enumerate(curr_centroids):
58
+ if c is None:
59
+ continue
60
+ # 找到距离最近的上一帧质心
61
+ best_j, best_dist = None, float('inf')
62
+ for j in unused_prev:
63
+ pc = prev_centroids[j]
64
+ if pc is None: continue
65
+ d = (c[0]-pc[0])**2 + (c[1]-pc[1])**2
66
+ if d < best_dist:
67
+ best_j, best_dist = j, d
68
+ if best_j is not None and best_dist < (50**2): # 距离阈值 50
69
+ assignments[i] = best_j
70
+ unused_prev.remove(best_j)
71
+ # 未匹配的实例新分配 ID
72
+ for i in range(len(curr_centroids)):
73
+ if assignments[i] == -1 and unused_prev:
74
+ assignments[i] = unused_prev.pop()
75
+
76
+ # 4. 计算“挣扎强度” & 叠加绘制
77
+ for i, m in enumerate(masks):
78
+ id_ = assignments[i]
79
+ if id_ is None or id_<0:
80
+ continue
81
+ prev_m = prev_masks[id_]
82
+ if prev_m is None:
83
+ # 未检测到前,标记为 None
84
+ struggle_records[id_].append(None)
85
+ else:
86
+ # XOR 统计像素差异
87
+ diff = np.logical_xor(prev_m, m).sum()
88
+ struggle_records[id_].append(int(diff))
89
+ # 叠加掩膜 & ID
90
+ color = (0,255,0)
91
+ mask_rgb = np.stack([m*color[c] for c in range(3)], axis=-1).astype(np.uint8)
92
+ frame = cv2.addWeighted(frame,1,mask_rgb,0.5,0)
93
+ if curr_centroids[i]:
94
+ cv2.putText(frame, f"ID:{id_}", curr_centroids[i],
95
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
96
+
97
+ # 更新上一帧数据
98
+ prev_centroids[id_] = curr_centroids[i]
99
+ prev_masks[id_] = m.copy()
100
+
101
+ out.write(frame)
102
+ frame_idx += 1
103
+
104
+ cap.release()
105
+ out.release()
106
+
107
+ # 5. 按时间窗口汇总并绘制
108
+ win_size = int(window_size_sec * fps)
109
+ fig, ax = plt.subplots(figsize=(8,4))
110
+ times = np.arange(0, frame_idx, win_size) / fps
111
+ for id_, records in enumerate(struggle_records):
112
+ # 将记录按窗口求和,None视为 0 或保持空白
113
+ sums = []
114
+ for w in range(len(times)):
115
+ segment = records[w*win_size:(w+1)*win_size]
116
+ # 把 None 当作 0,但在绘图时保留空白
117
+ vals = [v if v is not None else 0 for v in segment]
118
+ sums.append(sum(vals))
119
+ ax.plot(times, sums, label=f"Mouse {id_}")
120
+ # 标记 None 区间
121
+ first_detect = next((i for i,v in enumerate(records) if v is not None), None)
122
+ if first_detect:
123
+ ax.axvspan(0, first_detect/fps, color='grey', alpha=0.3)
124
+
125
+ ax.set_xlabel("Time (s)")
126
+ ax.set_ylabel("Struggle Intensity")
127
+ ax.legend()
128
+ ax.set_title("Mouse Struggle Over Time")
129
+
130
+ return out_path, fig
131
+
132
+ # 6. Gradio 接口
133
+ with gr.Blocks(title="Mice Struggle Analysis") as demo:
134
+ gr.Markdown("上传实验视频,输入鼠标数量,点击 Run 开始分析。")
135
+ with gr.Row():
136
+ video_in = gr.Video(label="Input Video")
137
+ num_in = gr.Number(value=1, precision=0, label="Number of Mice")
138
+ run_btn = gr.Button("Run")
139
+ output_video = gr.Video(label="Annotated Video")
140
+ output_plot = gr.Plot(label="Struggle Plot")
141
+ run_btn.click(fn=analyze_video,
142
+ inputs=[video_in, num_in],
143
+ outputs=[output_video, output_plot])
144
+
145
+ if __name__ == "__main__":
146
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False,
147
+ inbrowser=False,
148
+ # Zero GPU 环境下设置 600s 超时
149
+ api_config={"timeout":600})