# inference_track.py
# Video tracking model inference module.

import torch
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from tracking_one import TrackingModule
from models.tra_post_model.trackastra.tracking import graph_to_ctc

# Module-level singletons populated by load_model().
MODEL = None
DEVICE = torch.device("cpu")


def load_model(use_box=False):
    """Load the tracking model and its pretrained weights.

    Downloads the checkpoint from the Hugging Face Hub (cached locally
    after the first download), loads it into a TrackingModule and moves
    the model to CUDA when available.

    Args:
        use_box: Whether the model should consume bounding boxes.

    Returns:
        (model, device): the loaded model and the torch.device it lives on,
        or (None, torch.device("cpu")) when loading fails.
    """
    global MODEL, DEVICE
    try:
        print("🔄 Loading tracking model...")

        # Initialize the model architecture.
        MODEL = TrackingModule(use_box=use_box)

        # Fetch weights from the Hugging Face Hub.
        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_tra.pth",
            token=None,
            force_download=False
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        # Load weights on CPU first; device placement happens below.
        MODEL.load_state_dict(
            torch.load(ckpt_path, map_location="cpu"),
            strict=True
        )
        MODEL.eval()

        # Prefer CUDA when available.
        if torch.cuda.is_available():
            DEVICE = torch.device("cuda")
            MODEL.move_to_device(DEVICE)
            print("✅ Model moved to CUDA")
        else:
            DEVICE = torch.device("cpu")
            MODEL.move_to_device(DEVICE)
            print("✅ Model on CPU")

        print("✅ Tracking model loaded successfully")
        return MODEL, DEVICE

    except Exception as e:
        print(f"❌ Error loading tracking model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")


@torch.no_grad()
def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
    """Run video tracking inference on a frame sequence.

    Args:
        model: Tracking model returned by load_model().
        video_dir: Directory with the consecutive frame images.
        box: Optional bounding boxes forwarded to model.track().
        device: Unused here; kept for interface compatibility.
        output_dir: Directory where CTC-format results are written.

    Returns:
        dict with keys:
            'track_graph': TrackGraph object (None on failure)
            'masks': segmentation mask array, shape (T, H, W)
            'masks_tracked': masks relabelled by track assignment
            'output_dir': output directory path
            'num_tracks': number of tracks in the CTC table
        On failure the dict additionally carries an 'error' message.
    """
    if model is None:
        return {
            'track_graph': None,
            'masks': None,
            'output_dir': None,
            'num_tracks': 0,
            'error': 'Model not loaded'
        }

    try:
        print(f"🔄 Running tracking inference on {video_dir}")

        # Run tracking. Linking mode options: "greedy", "greedy_nodiv", "ilp".
        track_graph, masks = model.track(
            file_dir=video_dir,
            boxes=box,
            mode="greedy",
            dataname="tracking_result"
        )

        # exist_ok avoids the check-then-create race of the original
        # os.path.exists() guard.
        os.makedirs(output_dir, exist_ok=True)

        # Convert to CTC (Cell Tracking Challenge) format and save.
        print("🔄 Converting to CTC format...")
        ctc_tracks, masks_tracked = graph_to_ctc(
            track_graph,
            masks,
            outdir=output_dir,
        )
        print(f"✅ CTC results saved to {output_dir}")
        print(f"✅ Tracking completed")

        return {
            'track_graph': track_graph,
            'masks': masks,
            'masks_tracked': masks_tracked,
            'output_dir': output_dir,
            # Restore the documented 'num_tracks' key (the error paths
            # already return it): one CTC table row per track.
            'num_tracks': len(ctc_tracks),
        }

    except Exception as e:
        print(f"❌ Tracking inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'track_graph': None,
            'masks': None,
            'output_dir': None,
            'num_tracks': 0,
            'error': str(e)
        }


def visualize_tracking_result(masks_tracked, output_path):
    """Render tracked masks as an mp4, one fixed color per track id.

    Args:
        masks_tracked: Integer label masks, shape (T, H, W); 0 = background.
        output_path: Destination video file path.

    Returns:
        output_path on success, None on failure.
    """
    try:
        import cv2
        import matplotlib.pyplot as plt

        # Number of frames / spatial extent.
        T, H, W = masks_tracked.shape

        # One color per unique label. plt.get_cmap replaces
        # matplotlib.cm.get_cmap, which was removed in matplotlib 3.9.
        unique_ids = np.unique(masks_tracked)
        num_colors = len(unique_ids)
        cmap = plt.get_cmap('tab20', num_colors)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H))
        try:
            for t in range(T):
                frame = masks_tracked[t]
                colored_frame = np.zeros((H, W, 3), dtype=np.uint8)
                for i, obj_id in enumerate(unique_ids):
                    if obj_id == 0:  # background stays black
                        continue
                    mask = (frame == obj_id)
                    # Explicit uint8 cast instead of relying on implicit
                    # float->uint8 assignment casting.
                    color = (np.array(cmap(i % num_colors)[:3]) * 255).astype(np.uint8)
                    colored_frame[mask] = color

                # OpenCV expects BGR channel order.
                out.write(cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR))
        finally:
            # Release the writer even if encoding fails mid-way
            # (the original leaked it on exception).
            out.release()

        print(f"✅ Visualization saved to {output_path}")
        return output_path

    except Exception as e:
        print(f"❌ Visualization error: {e}")
        return None