| import argparse |
| import os |
| import time |
|
|
| import cv2 |
| import numpy as np |
| import requests |
| import torch |
| import wget |
| import yolov7 |
| from mobile_sam import SamPredictor, sam_model_registry |
| from PIL import Image |
| from tqdm import tqdm |
| from transformers import YolosForObjectDetection, YolosImageProcessor |
|
|
| from images_to_video import VideoCreator |
| from video_to_images import ImageCreator |
|
|
|
|
def download_mobile_sam_weight(path):
    """Download the MobileSAM checkpoint to *path* if it is not already there.

    Args:
        path: destination file path, e.g. "./models/mobile_sam.pt". Parent
            directories are created as needed.

    Raises:
        NameError: if the filename in *path* is not the known MobileSAM
            checkpoint name, since no download URL is available for it.
    """
    if os.path.exists(path):
        return
    sam_weights = "https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt"
    # Create all missing parent directories in one call instead of the
    # manual split("/") + os.mkdir loop (which also broke on edge cases
    # like a bare filename with no directory component).
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    model_name = os.path.basename(path)
    # Only the canonical checkpoint name ("mobile_sam.pt") appears in the
    # download URL; anything else has no known source.
    if model_name in sam_weights:
        wget.download(sam_weights, path)
    else:
        raise NameError(
            "There is no pretrained weight to download for %s, you need to provide a path to MobileSAM weights."
            % model_name
        )
|
|
|
|
def get_closest_bbox(bbox_list, bbox_target):
    """Return the bounding box in *bbox_list* closest to *bbox_target*.

    Closeness is the Euclidean (L2) norm of the element-wise difference of
    the box coordinates; ties go to the earliest box in the list.

    Args:
        bbox_list: sequence of bounding boxes (array-likes of coordinates).
        bbox_target: target bounding box (array-like of coordinates).

    Returns:
        The element of *bbox_list* with minimal distance to *bbox_target*.

    Raises:
        ValueError: if *bbox_list* is empty (the original code raised a
            bare IndexError; an explicit error is clearer).
    """
    if len(bbox_list) == 0:
        raise ValueError("bbox_list must contain at least one bounding box")
    target = np.asarray(bbox_target, dtype=float)
    # Using argmin avoids the arbitrary 1e8 sentinel of the original loop,
    # which a genuine distance could exceed for large coordinates.
    # np.argmin returns the first minimum, preserving the original
    # strict-'<' tie-breaking behavior.
    distances = [
        np.linalg.norm(np.asarray(bbox, dtype=float) - target) for bbox in bbox_list
    ]
    return bbox_list[int(np.argmin(distances))]
|
|
|
|
def get_bboxes(image_file, image, model, image_processor, threshold=0.9):
    """Run a detector and return its predicted boxes as an (N, 4) ndarray.

    Two detector interfaces are supported:
    - image_processor is None: yolov7-style — the model is called on the
      image file path and exposes predictions whose first four columns are
      box coordinates.
    - otherwise: HuggingFace YOLOS — preprocess the PIL image, run the
      model, and post-process logits back to absolute pixel boxes kept
      above *threshold*.
    """
    if image_processor is None:
        # yolov7 pip-package interface: call on the path, read .pred[0].
        detections = model(image_file).pred[0]
        return detections[:, :4].detach().numpy()

    # HuggingFace YOLOS interface.
    inputs = image_processor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    # PIL's .size is (width, height); post-processing wants (height, width).
    sizes = torch.tensor([image.size[::-1]])
    detections = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=sizes
    )[0]
    return detections["boxes"].detach().numpy()
|
|
|
|
def segment_video(
    video_filename,
    dir_frames,
    image_start,
    image_end,
    bbox_file,
    skip_vid2im,
    mobile_sam_weights,
    auto_detect=False,
    tracker_name="yolov7",
    background_color="#009000",
    output_dir="output_frames",
    output_video="output.mp4",
    pbar=False,
    reverse_mask=False,
):
    """Segment the tracked object in a video and write a recolored video.

    Pipeline: optionally split the video into frames, detect a bounding box
    per frame (unless auto_detect), prompt MobileSAM with that box, replace
    either the background or the object with a solid color, and re-encode
    the processed frames into a video at the source FPS.

    Args:
        video_filename: path to the input video.
        dir_frames: directory holding (or receiving) the extracted frames.
        image_start: index of the first frame to process.
        image_end: index one past the last frame; 0 means "to the end".
        bbox_file: text file with one bounding box as space-separated ints;
            per-frame detections closest to it are selected (crude tracking).
        skip_vid2im: if True, assume frames were already extracted.
        mobile_sam_weights: path to the MobileSAM checkpoint (downloaded
            automatically when missing).
        auto_detect: if True, skip the detector and prompt SAM with the
            whole-frame box instead.
        tracker_name: "yolov7" uses the yolov7 pip package; any other value
            uses HuggingFace YOLOS.
        background_color: hex color string (e.g. "#009000") for the
            replaced region.
        output_dir: directory where processed frames are written as PNGs.
        output_video: path of the re-encoded output video.
        pbar: if True show a tqdm bar, otherwise print an ETA every 10 frames.
        reverse_mask: if True color the object and keep the background.
    """
    # Step 1: optionally split the input video into frame images on disk.
    if not skip_vid2im:
        vid_to_im = ImageCreator(
            video_filename,
            dir_frames,
            image_start=image_start,
            image_end=image_end,
            pbar=pbar,
        )
        vid_to_im.get_images()

    # Read the source FPS so the output video plays at the original speed.
    vid = cv2.VideoCapture(video_filename)
    fps = vid.get(cv2.CAP_PROP_FPS)
    vid.release()
    # Convert "#RRGGBB" into a float triplet in [0, 1].
    background_color = background_color.lstrip("#")
    background_color = (
        np.array([int(background_color[i : i + 2], 16) for i in (0, 2, 4)]) / 255.0
    )

    # The user-supplied initial box anchors the tracking heuristic below.
    with open(bbox_file, "r") as f:
        bbox_orig = [int(coord) for coord in f.read().split(" ")]
    download_mobile_sam_weight(mobile_sam_weights)
    if image_end == 0:
        frames = sorted(os.listdir(dir_frames))[image_start:]
    else:
        frames = sorted(os.listdir(dir_frames))[image_start:image_end]

    # MobileSAM registry key for the tiny ViT backbone.
    model_type = "vit_t"

    # Device preference: Apple Metal, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    sam = sam_model_registry[model_type](checkpoint=mobile_sam_weights)
    sam.to(device=device)
    sam.eval()

    predictor = SamPredictor(sam)

    # When not auto-detecting, load an object detector that supplies
    # per-frame bounding-box prompts for SAM.
    if not auto_detect:
        if tracker_name == "yolov7":
            model = yolov7.load("kadirnar/yolov7-tiny-v0.1", hf_model=True)
            model.conf = 0.25  # confidence threshold
            model.iou = 0.45  # NMS IoU threshold
            model.classes = None  # keep all classes
            image_processor = None
        else:
            model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
            image_processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")

    output_frames = []

    if pbar:
        pb = tqdm(frames)
    else:
        pb = frames

    processed_frames = 0
    init_time = time.time()
    for frame in pb:
        processed_frames += 1
        image_file = dir_frames + "/" + frame
        image_pil = Image.open(image_file)
        image_np = np.array(image_pil)
        if not auto_detect:
            # Crude tracking: of all detections in this frame, prompt SAM
            # with the box closest to the user-supplied initial box.
            bboxes = get_bboxes(image_file, image_pil, model, image_processor)
            closest_bbox = get_closest_bbox(bboxes, bbox_orig)
            input_box = np.array(closest_bbox)
        else:
            # No detector: prompt SAM with the whole frame.
            input_box = np.array([0, 0, image_np.shape[1], image_np.shape[0]])
        predictor.set_image(image_np)
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=True,
        )
        if reverse_mask:
            # Color the object itself; keep the original background.
            mask = masks[0]
            h, w = mask.shape[-2:]
            mask_image = (
                (mask).reshape(h, w, 1) * background_color.reshape(1, 1, -1)
            ) * 255
            masked_image = image_np * (1 - mask).reshape(h, w, 1)
            masked_image = masked_image + mask_image
            output_frames.append(masked_image)
        else:
            # Keep the object; replace the background with the solid color.
            mask = masks[0]
            h, w = mask.shape[-2:]
            mask_image = (
                (1 - mask).reshape(h, w, 1) * background_color.reshape(1, 1, -1)
            ) * 255
            masked_image = image_np * mask.reshape(h, w, 1)
            masked_image = masked_image + mask_image
            output_frames.append(masked_image)

        # Without a progress bar, print a rough ETA every 10 frames based on
        # the average per-frame time so far.
        if not pbar and processed_frames % 10 == 0:
            remaining_time = (
                (time.time() - init_time)
                / processed_frames
                * (len(frames) - processed_frames)
            )
            remaining_time = int(remaining_time)
            remaining_time_str = f"{remaining_time//60}m {remaining_time%60}s"
            print(
                f"Processed frame {processed_frames}/{len(frames)} - Remaining time: {remaining_time_str}"
            )
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # NOTE(review): frames are read with PIL (RGB) but written with
    # cv2.imwrite, which interprets arrays as BGR — output colors appear
    # channel-swapped relative to the input; confirm whether this is intended.
    # Zero-pad frame indices so the files sort lexicographically.
    zfill_max = len(str(len(output_frames)))
    for idx, frame in enumerate(output_frames):
        cv2.imwrite(
            f"{output_dir}/frame_{str(idx).zfill(zfill_max)}.png",
            frame,
        )
    vid_creator = VideoCreator(output_dir, output_video, pbar=pbar)
    vid_creator.create_video(fps=int(fps))
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_filename",
        default="assets/example.mp4",
        type=str,
        help="path to the video",
    )
    parser.add_argument(
        "--dir_frames",
        type=str,
        default="frames",
        help="path to the directory in which all input frames will be stored",
    )
    parser.add_argument(
        "--image_start", type=int, default=0, help="first image to be stored"
    )
    parser.add_argument(
        "--image_end",
        type=int,
        default=0,
        help="last image to be stored, last one if 0",
    )
    parser.add_argument(
        "--bbox_file",
        type=str,
        default="bbox.txt",
        help="path to the bounding box text file",
    )
    parser.add_argument(
        "--skip_vid2im",
        action="store_true",
        help="whether to write the video frames as images",
    )
    parser.add_argument(
        "--mobile_sam_weights",
        type=str,
        default="./models/mobile_sam.pt",
        help="path to MobileSAM weights",
    )
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="yolov7",
        help="tracker name",
        choices=["yolov7", "yoloS"],
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output_frames",
        help="directory to store the output frames",
    )
    parser.add_argument(
        "--output_video",
        type=str,
        default="output.mp4",
        help="path to store the output video",
    )
    parser.add_argument(
        "--auto_detect",
        action="store_true",
        help="skip the bounding-box detector and prompt SAM with the whole frame",
    )
    parser.add_argument(
        "--background_color",
        type=str,
        default="#009000",
        help="background color for the output (hex)",
    )
    # New (backward-compatible) flags: these segment_video parameters were
    # previously not exposed on the CLI at all.
    parser.add_argument(
        "--pbar",
        action="store_true",
        help="show a tqdm progress bar instead of periodic ETA prints",
    )
    parser.add_argument(
        "--reverse_mask",
        action="store_true",
        help="color the detected object instead of the background",
    )
    args = parser.parse_args()

    # Pass optional settings by keyword. The original positional call fed
    # args.output_dir/args.output_video into tracker_name/background_color
    # (and args.tracker_name/args.background_color into output_dir/
    # output_video) because the argument order did not match the signature.
    segment_video(
        args.video_filename,
        args.dir_frames,
        args.image_start,
        args.image_end,
        args.bbox_file,
        args.skip_vid2im,
        args.mobile_sam_weights,
        auto_detect=args.auto_detect,
        tracker_name=args.tracker_name,
        background_color=args.background_color,
        output_dir=args.output_dir,
        output_video=args.output_video,
        pbar=args.pbar,
        reverse_mask=args.reverse_mask,
    )
|
|