Spaces:
Sleeping
Sleeping
| # inference_track.py | |
| # 视频跟踪模型推理模块 | |
| import torch | |
| import numpy as np | |
| import os | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| from huggingface_hub import hf_hub_download | |
| from tracking_one import TrackingModule | |
| from models.tra_post_model.trackastra.tracking import graph_to_ctc | |
# Module-level cache written by load_model(): the tracking model (None until
# load_model() assigns it) and the torch device the model was placed on.
MODEL, DEVICE = None, torch.device("cpu")
def load_model(use_box=False, repo_id="phoebe777777/111",
               filename="microscopy_matching_tra.pth"):
    """Load the tracking model and its pretrained weights.

    Builds a ``TrackingModule``, downloads the checkpoint from the Hugging
    Face Hub, loads it with ``strict=True``, switches the model to eval mode
    and moves it to CUDA when available (CPU otherwise). The module-level
    ``MODEL`` / ``DEVICE`` globals are updated only once every step succeeded.

    Args:
        use_box: Whether to use bounding boxes (forwarded to ``TrackingModule``).
        repo_id: Hugging Face Hub repository that holds the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.

    Returns:
        ``(model, device)`` on success. On any error the traceback is printed,
        ``MODEL`` is reset to ``None`` and ``(None, torch.device("cpu"))`` is
        returned, so no half-initialised model is left in the global cache.
    """
    global MODEL, DEVICE
    try:
        print("🔄 Loading tracking model...")
        # Build into a local first; the globals are only touched on success.
        model = TrackingModule(use_box=use_box)

        # Download the weights from the Hugging Face Hub (cached locally).
        ckpt_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            token=None,
            force_download=False,
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        # NOTE(review): torch.load unpickles the file and the checkpoint comes
        # from a remote repository. Consider weights_only=True (the default
        # from torch 2.6) once the minimum supported torch version allows it.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        model.move_to_device(device)
        print("✅ Model moved to CUDA" if use_cuda else "✅ Model on CPU")

        # Publish to the module-level cache only after everything succeeded.
        MODEL, DEVICE = model, device
        print("✅ Tracking model loaded successfully")
        return MODEL, DEVICE
    except Exception as e:
        # Deliberate best-effort boundary: callers (e.g. run) check for None.
        print(f"❌ Error loading tracking model: {e}")
        import traceback
        traceback.print_exc()
        # Keep the global consistent with the value we return.
        MODEL = None
        return None, torch.device("cpu")
def _tracking_error_result(message):
    """Build the failure payload of ``run`` (same keys as the success dict plus ``error``)."""
    return {
        'track_graph': None,
        'masks': None,
        'masks_tracked': None,
        'output_dir': None,
        'num_tracks': 0,
        'error': message,
    }


def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
    """Run video tracking inference and export the result in CTC format.

    Args:
        model: Tracking model (e.g. returned by ``load_model``). ``None``
            yields an error result without running anything.
        video_dir: Directory containing the consecutive image frames.
        box: Optional bounding boxes, forwarded to ``model.track``.
        device: Currently unused; kept for interface compatibility.
        output_dir: Directory the CTC output is written to (created if
            missing).

    Returns:
        dict with keys ``track_graph`` (TrackGraph object), ``masks``
        (segmentation masks, (T, H, W)), ``masks_tracked``, ``output_dir``
        and ``num_tracks``. On failure the same keys are present with
        ``None``/0 values, plus an ``error`` key holding the message.
    """
    if model is None:
        return _tracking_error_result('Model not loaded')
    try:
        print(f"🔄 Running tracking inference on {video_dir}")
        track_graph, masks = model.track(
            file_dir=video_dir,
            boxes=box,
            mode="greedy",  # alternatives: "greedy_nodiv", "ilp"
            dataname="tracking_result",
        )

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)

        print("🔄 Converting to CTC format...")
        ctc_tracks, masks_tracked = graph_to_ctc(
            track_graph,
            masks,
            outdir=output_dir,
        )
        print(f"✅ CTC results saved to {output_dir}")

        # Assumes graph_to_ctc's first return value is the CTC track table
        # with one row per track (as in upstream trackastra) — TODO confirm.
        # Falls back to None rather than failing the whole run.
        num_tracks = len(ctc_tracks) if hasattr(ctc_tracks, "__len__") else None
        print("✅ Tracking completed")

        return {
            'track_graph': track_graph,
            'masks': masks,
            'masks_tracked': masks_tracked,
            'output_dir': output_dir,
            'num_tracks': num_tracks,
        }
    except Exception as e:
        # Best-effort boundary: report the error in the result dict.
        print(f"❌ Tracking inference error: {e}")
        import traceback
        traceback.print_exc()
        return _tracking_error_result(str(e))
def visualize_tracking_result(masks_tracked, output_path, fps=5.0):
    """Render tracked label masks as a colour-coded video (optional helper).

    Every distinct non-zero label gets a fixed colour from matplotlib's
    ``tab20`` colormap resampled to the number of distinct labels; label 0
    is treated as background and drawn black.

    Args:
        masks_tracked: Tracked label masks indexed as (T, H, W).
        output_path: Destination video path (written with the ``mp4v`` codec).
        fps: Frame rate of the output video (defaults to the previous
            hard-coded 5.0).

    Returns:
        ``output_path`` on success, or ``None`` if anything fails (missing
        cv2/matplotlib, input not 3-D, video writer could not be opened, ...).
    """
    try:
        import cv2
        import matplotlib.pyplot as plt

        labels = np.asarray(masks_tracked)
        T, H, W = labels.shape

        # Sorted distinct label values across the whole sequence.
        unique_ids = np.unique(labels)
        num_ids = len(unique_ids)
        # pyplot.get_cmap still exists in matplotlib >= 3.9, where
        # matplotlib.cm.get_cmap was removed.
        cmap = plt.get_cmap('tab20', max(num_ids, 1))
        # One RGB row per distinct label. Integer inputs index the colormap
        # LUT directly, so row i equals cmap(i)[:3] * 255 truncated to uint8.
        palette = (cmap(np.arange(num_ids))[:, :3] * 255).astype(np.uint8)
        palette[unique_ids == 0] = 0  # background stays black

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
        try:
            if not writer.isOpened():
                raise IOError(f"could not open video writer for {output_path}")
            for frame in labels:
                # Every pixel value occurs in the sorted unique_ids, so
                # searchsorted returns its exact palette row.
                colored_frame = palette[np.searchsorted(unique_ids, frame)]
                # Convert to BGR (OpenCV channel order) before writing.
                writer.write(cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR))
        finally:
            # Always release the writer, even if a frame fails to encode.
            writer.release()

        print(f"✅ Visualization saved to {output_path}")
        return output_path
    except Exception as e:
        # Best-effort: visualization is optional, so report and return None.
        print(f"❌ Visualization error: {e}")
        return None