Shengxiao0709 commited on
Commit
47ef157
·
verified ·
1 Parent(s): 82ee5dc

Update inference_track.py

Browse files
Files changed (1) hide show
  1. inference_track.py +201 -0
inference_track.py CHANGED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference_track.py
2
+ # 视频跟踪模型推理模块
3
+
4
+ import torch
5
+ import numpy as np
6
+ import os
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+ from huggingface_hub import hf_hub_download
10
+ from tracking_model import TrackingModule
11
+ from models.tra_post_model.trackastra.tracking import graph_to_ctc
12
+
13
+ MODEL = None
14
+ DEVICE = torch.device("cpu")
15
+
16
+ def load_model(use_box=False):
17
+ """
18
+ 加载跟踪模型
19
+
20
+ Args:
21
+ use_box: 是否使用边界框
22
+
23
+ Returns:
24
+ model: 加载的模型
25
+ device: 设备
26
+ """
27
+ global MODEL, DEVICE
28
+
29
+ try:
30
+ print("🔄 Loading tracking model...")
31
+
32
+ # 初始化模型
33
+ MODEL = TrackingModule(use_box=use_box)
34
+
35
+ # 从 Hugging Face Hub 下载权重
36
+ ckpt_path = hf_hub_download(
37
+ repo_id="Shengxiao0709/cellsegmodel",
38
+ filename="microscopy_matching_tra.pth",
39
+ token=None,
40
+ force_download=False
41
+ )
42
+
43
+ print(f"✅ Checkpoint downloaded: {ckpt_path}")
44
+
45
+ # 加载权重
46
+ MODEL.load_state_dict(
47
+ torch.load(ckpt_path, map_location="cpu"),
48
+ strict=True
49
+ )
50
+ MODEL.eval()
51
+
52
+ # 设置设备
53
+ if torch.cuda.is_available():
54
+ DEVICE = torch.device("cuda")
55
+ MODEL.move_to_device(DEVICE)
56
+ print("✅ Model moved to CUDA")
57
+ else:
58
+ DEVICE = torch.device("cpu")
59
+ MODEL.move_to_device(DEVICE)
60
+ print("✅ Model on CPU")
61
+
62
+ print("✅ Tracking model loaded successfully")
63
+ return MODEL, DEVICE
64
+
65
+ except Exception as e:
66
+ print(f"❌ Error loading tracking model: {e}")
67
+ import traceback
68
+ traceback.print_exc()
69
+ return None, torch.device("cpu")
70
+
71
+
72
+ @torch.no_grad()
73
+ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
74
+ """
75
+ 运行视频跟踪推理
76
+
77
+ Args:
78
+ model: 跟踪模型
79
+ video_dir: 视频帧序列目录 (包含连续的图像文件)
80
+ box: 边界框 (可选)
81
+ device: 设备
82
+ output_dir: 输出目录
83
+
84
+ Returns:
85
+ result_dict: {
86
+ 'track_graph': TrackGraph对象,
87
+ 'masks': 分割掩码数组 (T, H, W),
88
+ 'output_dir': 输出目录路径,
89
+ 'num_tracks': 跟踪轨迹数量
90
+ }
91
+ """
92
+ if model is None:
93
+ return {
94
+ 'track_graph': None,
95
+ 'masks': None,
96
+ 'output_dir': None,
97
+ 'num_tracks': 0,
98
+ 'error': 'Model not loaded'
99
+ }
100
+
101
+ try:
102
+ print(f"🔄 Running tracking inference on {video_dir}")
103
+
104
+ # 运行跟踪
105
+ track_graph, masks = model.track(
106
+ file_dir=video_dir,
107
+ boxes=box,
108
+ mode="greedy", # 可选: "greedy", "greedy_nodiv", "ilp"
109
+ dataname="tracking_result"
110
+ )
111
+
112
+ # 创建输出目录
113
+ if not os.path.exists(output_dir):
114
+ os.makedirs(output_dir)
115
+
116
+ # 转换为CTC格式并保存
117
+ print("🔄 Converting to CTC format...")
118
+ ctc_tracks, masks_tracked = graph_to_ctc(
119
+ track_graph,
120
+ masks,
121
+ outdir=output_dir,
122
+ )
123
+
124
+ num_tracks = len(track_graph.tracks())
125
+
126
+ print(f"✅ Tracking completed: {num_tracks} tracks found")
127
+
128
+ result = {
129
+ 'track_graph': track_graph,
130
+ 'masks': masks,
131
+ 'masks_tracked': masks_tracked,
132
+ 'output_dir': output_dir,
133
+ 'num_tracks': num_tracks
134
+ }
135
+
136
+ return result
137
+
138
+ except Exception as e:
139
+ print(f"❌ Tracking inference error: {e}")
140
+ import traceback
141
+ traceback.print_exc()
142
+ return {
143
+ 'track_graph': None,
144
+ 'masks': None,
145
+ 'output_dir': None,
146
+ 'num_tracks': 0,
147
+ 'error': str(e)
148
+ }
149
+
150
+
151
+ def visualize_tracking_result(masks_tracked, output_path):
152
+ """
153
+ 可视化跟踪结果 (可选)
154
+
155
+ Args:
156
+ masks_tracked: 跟踪后的掩码 (T, H, W)
157
+ output_path: 输出视频路径
158
+
159
+ Returns:
160
+ output_path: 视频文件路径
161
+ """
162
+ try:
163
+ import cv2
164
+ import matplotlib.pyplot as plt
165
+ from matplotlib import cm
166
+
167
+ # 获取时间帧数
168
+ T, H, W = masks_tracked.shape
169
+
170
+ # 创建颜色映射
171
+ unique_ids = np.unique(masks_tracked)
172
+ num_colors = len(unique_ids)
173
+ cmap = cm.get_cmap('tab20', num_colors)
174
+
175
+ # 创建视频写入器
176
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
177
+ out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H))
178
+
179
+ for t in range(T):
180
+ frame = masks_tracked[t]
181
+
182
+ # 创建彩色图像
183
+ colored_frame = np.zeros((H, W, 3), dtype=np.uint8)
184
+ for i, obj_id in enumerate(unique_ids):
185
+ if obj_id == 0:
186
+ continue
187
+ mask = (frame == obj_id)
188
+ color = np.array(cmap(i % num_colors)[:3]) * 255
189
+ colored_frame[mask] = color
190
+
191
+ # 转换为BGR (OpenCV格式)
192
+ colored_frame_bgr = cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
193
+ out.write(colored_frame_bgr)
194
+
195
+ out.release()
196
+ print(f"✅ Visualization saved to {output_path}")
197
+ return output_path
198
+
199
+ except Exception as e:
200
+ print(f"❌ Visualization error: {e}")
201
+ return None