Spaces:
Running
Running
| import os | |
| from tqdm import tqdm | |
| from utils.interact_tools import SamControler | |
| from tracker.base_tracker import BaseTracker | |
| import numpy as np | |
| import argparse | |
| import cv2 | |
| from typing import Optional | |
class TrackingAnything:
    """Bundle a click-driven segmenter (``SamControler``) with a video tracker.

    The segmenter produces a mask for the first frame from user clicks; the
    tracker (``BaseTracker``) is then run over the remaining frames starting
    from that mask.
    """

    def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, xmem_checkpoint, args):
        """Instantiate the segmenter and the tracker.

        Args:
            sam_pt_checkpoint: First checkpoint argument passed to ``SamControler``.
            sam_onnx_checkpoint: Second checkpoint argument passed to ``SamControler``.
            xmem_checkpoint: Checkpoint passed to ``BaseTracker``.
            args: Object exposing ``sam_model_type`` and ``device`` attributes
                (e.g. the namespace returned by ``parse_augment``).
        """
        self.args = args
        self.sam_pt_checkpoint = sam_pt_checkpoint
        self.sam_onnx_checkpoint = sam_onnx_checkpoint
        self.xmem_checkpoint = xmem_checkpoint
        self.samcontroler = SamControler(
            self.sam_pt_checkpoint,
            self.sam_onnx_checkpoint,
            args.sam_model_type,
            args.device,
        )
        self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)

    def first_frame_click(
        self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True
    ):
        """Forward a click prompt to the segmenter.

        All arguments are passed unchanged to
        ``SamControler.first_frame_click``.

        Returns:
            The ``(mask, logit, painted_image)`` triple produced by the
            segmenter.
        """
        mask, logit, painted_image = self.samcontroler.first_frame_click(
            image, points, labels, multimask
        )
        return mask, logit, painted_image

    @staticmethod
    def _ensure_parent_dir(output_path):
        """Create the directory that will contain ``output_path`` if needed."""
        parent = os.path.dirname(output_path)
        # A bare file name has an empty dirname; os.makedirs("") would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def generator(
        self,
        images: list,
        template_mask: np.ndarray,
        write: Optional[bool] = False,
        fps: Optional[int] = 30,
        output_path: Optional[str] = "tracking.mp4",
    ):
        """Track ``template_mask`` through ``images``.

        The first frame is tracked together with ``template_mask``; every
        following frame is tracked on its own.

        Args:
            images: Sequence of frames; each frame must expose ``.shape`` whose
                first two entries are (height, width).
            template_mask: Mask supplied to the tracker with the first frame.
            write: When truthy, painted frames are written to ``output_path``
                as an ``mp4v`` video instead of being collected in memory.
            fps: Frame rate of the written video. Numeric strings are accepted
                and ``None`` falls back to 30.
            output_path: Destination of the video when ``write`` is truthy.

        Returns:
            ``(masks, logits, painted_images)``. ``painted_images`` stays empty
            when ``write`` is truthy because the frames go to the video file.
        """
        masks = []
        logits = []
        painted_images = []

        # Nothing to track: also avoids indexing images[0] below.
        if len(images) == 0:
            return masks, logits, painted_images

        writer = None
        if write:
            # shape[:2] is (height, width); VideoWriter expects (width, height).
            size = images[0].shape[:2][::-1]
            self._ensure_parent_dir(output_path)
            frame_rate = 30.0 if fps is None else float(fps)
            writer = cv2.VideoWriter(
                output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, size
            )

        try:
            for index, image in enumerate(tqdm(images, desc="Tracking image")):
                if index == 0:
                    mask, logit, painted_image = self.xmem.track(image, template_mask)
                else:
                    mask, logit, painted_image = self.xmem.track(image)
                masks.append(mask)
                logits.append(logit)
                if writer is not None:
                    # Reverse the channel axis before writing
                    # (presumably RGB -> BGR for OpenCV; confirm with tracker).
                    writer.write(painted_image[:, :, ::-1])
                else:
                    painted_images.append(painted_image)
        finally:
            # Release even when tracking fails so the file handle is not leaked.
            if writer is not None:
                writer.release()

        return masks, logits, painted_images
def parse_augment(argv=None):
    """Parse the command-line options of the application.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        The ``argparse.Namespace`` with ``device``, ``sam_model_type``,
        ``port``, ``debug`` and ``mask_save`` attributes. The namespace is
        printed when ``--debug`` is given.
    """

    def _as_bool(value):
        """Convert a command-line token to a real boolean."""
        if isinstance(value, bool):
            return value
        token = str(value).strip().lower()
        if token in {"1", "true", "t", "yes", "y", "on"}:
            return True
        if token in {"0", "false", "f", "no", "n", "off", ""}:
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}"
        )

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--sam_model_type", type=str, default="vit_t")
    parser.add_argument(
        "--port",
        type=int,
        default=6080,
        help="only useful when running gradio applications",
    )
    parser.add_argument("--debug", action="store_true")
    # Without a converter every value arrives as a non-empty (truthy) string,
    # so "--mask_save False" would switch the option on.
    parser.add_argument(
        "--mask_save",
        type=_as_bool,
        nargs="?",
        const=True,
        default=False,
    )
    args = parser.parse_args(argv)
    if args.debug:
        print(args)
    return args