# FinalVision / inference_track.py
# Author: phoebehxf — commit "update model" (06244eb)
# inference_track.py
# 视频跟踪模型推理模块
import torch
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from tracking_one import TrackingModule
from models.tra_post_model.trackastra.tracking import graph_to_ctc
MODEL = None
DEVICE = torch.device("cpu")
def load_model(use_box=False):
    """Load the tracking model and its pretrained weights.

    Constructs a TrackingModule, downloads the checkpoint from the
    Hugging Face Hub (cached locally on repeat calls), loads the state
    dict, and places the model on CUDA when available, CPU otherwise.

    Args:
        use_box: whether the model should use bounding boxes.

    Returns:
        (model, device) on success, or (None, torch.device("cpu")) when
        anything goes wrong (the error is printed with a traceback).
    """
    global MODEL, DEVICE
    try:
        print("🔄 Loading tracking model...")

        # Build the network, then fetch weights from the Hub.
        MODEL = TrackingModule(use_box=use_box)
        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_tra.pth",
            token=None,
            force_download=False
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        # Load weights on CPU first; the model is moved to its device below.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        MODEL.load_state_dict(state_dict, strict=True)
        MODEL.eval()

        # Prefer CUDA when the runtime exposes it.
        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        MODEL.move_to_device(DEVICE)
        print("✅ Model moved to CUDA" if DEVICE.type == "cuda" else "✅ Model on CPU")

        print("✅ Tracking model loaded successfully")
        return MODEL, DEVICE
    except Exception as e:
        print(f"❌ Error loading tracking model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
@torch.no_grad()
def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
    """Run video tracking inference on a directory of sequential frames.

    Args:
        model: tracking model returned by load_model(), or None.
        video_dir: directory containing the consecutive frame images.
        box: optional bounding box(es) forwarded to model.track.
        device: device hint (currently unused here — the model was already
            placed on its device by load_model; kept for API compatibility).
        output_dir: directory where CTC-format results are written.

    Returns:
        dict with keys:
            'track_graph':   TrackGraph object (None on failure)
            'masks':         segmentation masks, shape (T, H, W) (None on failure)
            'masks_tracked': masks relabelled with track ids (success only)
            'output_dir':    output directory path (None on failure)
            'num_tracks':    number of tracks (0 on failure)
            'error':         error message (failure only)
    """
    if model is None:
        return {
            'track_graph': None,
            'masks': None,
            'output_dir': None,
            'num_tracks': 0,
            'error': 'Model not loaded'
        }
    try:
        print(f"🔄 Running tracking inference on {video_dir}")
        # Run the tracker over the frame sequence.
        track_graph, masks = model.track(
            file_dir=video_dir,
            boxes=box,
            mode="greedy",  # options: "greedy", "greedy_nodiv", "ilp"
            dataname="tracking_result"
        )

        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(output_dir, exist_ok=True)

        # Convert to CTC (Cell Tracking Challenge) format and save.
        print("🔄 Converting to CTC format...")
        ctc_tracks, masks_tracked = graph_to_ctc(
            track_graph,
            masks,
            outdir=output_dir,
        )
        print(f"✅ CTC results saved to {output_dir}")

        # graph_to_ctc presumably yields one entry per track — TODO confirm
        # against the trackastra API; fall back to 0 if it is not sized.
        try:
            num_tracks = len(ctc_tracks)
        except TypeError:
            num_tracks = 0
        print(f"✅ Tracking completed")

        # Success result now includes 'num_tracks', matching the documented
        # contract and the failure-path dicts.
        return {
            'track_graph': track_graph,
            'masks': masks,
            'masks_tracked': masks_tracked,
            'output_dir': output_dir,
            'num_tracks': num_tracks,
        }
    except Exception as e:
        print(f"❌ Tracking inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'track_graph': None,
            'masks': None,
            'output_dir': None,
            'num_tracks': 0,
            'error': str(e)
        }
def visualize_tracking_result(masks_tracked, output_path):
    """Render tracked masks as a color-coded mp4 video (optional helper).

    Each track id is painted in a distinct 'tab20' colormap color;
    background (id 0) stays black.

    Args:
        masks_tracked: labelled masks, shape (T, H, W); values are track ids.
        output_path: path of the mp4 file to write.

    Returns:
        output_path on success, None on any error (the error is printed).
    """
    try:
        import cv2
        import matplotlib.pyplot as plt

        masks_tracked = np.asarray(masks_tracked)
        T, H, W = masks_tracked.shape

        # One color per track id. matplotlib.cm.get_cmap was removed in
        # matplotlib 3.9; plt.get_cmap works across versions.
        unique_ids = np.unique(masks_tracked)
        num_colors = len(unique_ids)
        cmap = plt.get_cmap('tab20', num_colors)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H))
        try:
            for t in range(T):
                frame = masks_tracked[t]
                colored_frame = np.zeros((H, W, 3), dtype=np.uint8)
                for i, obj_id in enumerate(unique_ids):
                    if obj_id == 0:
                        continue  # background stays black
                    mask = (frame == obj_id)
                    # Explicit uint8 cast instead of relying on implicit
                    # float -> uint8 conversion on assignment.
                    color = (np.array(cmap(i % num_colors)[:3]) * 255).astype(np.uint8)
                    colored_frame[mask] = color
                # OpenCV expects BGR channel order.
                out.write(cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR))
        finally:
            # Release the writer even if a frame fails, so the partially
            # written file is flushed and the handle is not leaked.
            out.release()

        print(f"✅ Visualization saved to {output_path}")
        return output_path
    except Exception as e:
        print(f"❌ Visualization error: {e}")
        return None