import os
import gc

import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from transformers import OwlViTProcessor, OwlViTForObjectDetection

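# Local SAM checkpoint paths, keyed by ViT backbone size.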
models = {
    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
    'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
}

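# Gradio examples: [image path, example index, initial selected points].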
image_examples = [
    [os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"), 0, []],
    [os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"), 1, []],
    [os.path.join(os.path.dirname(__file__), "./images/1.jpg"), 2, []],
    [os.path.join(os.path.dirname(__file__), "./images/2.jpg"), 3, []],
    [os.path.join(os.path.dirname(__file__), "./images/3.jpg"), 4, []],
    [os.path.join(os.path.dirname(__file__), "./images/4.jpg"), 5, []],
    [os.path.join(os.path.dirname(__file__), "./images/5.jpg"), 6, []],
    [os.path.join(os.path.dirname(__file__), "./images/6.jpg"), 7, []],
    [os.path.join(os.path.dirname(__file__), "./images/7.jpg"), 8, []],
    [os.path.join(os.path.dirname(__file__), "./images/8.jpg"), 9, []]
]


def plot_boxes(img, boxes):
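    """Draw one randomly colored rectangle per box onto a float RGB image in [0, 1]."""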
    img_pil = Image.fromarray(np.uint8(img * 255)).convert('RGB')
    draw = ImageDraw.Draw(img_pil)
    for box in boxes:
        # randint's upper bound is exclusive, so use 256 to allow full-intensity channels.
        color = tuple(np.random.randint(0, 256, size=3).tolist())
        x0, y0, x1, y1 = (int(v) for v in box)
        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
    return img_pil


def segment_one(img, mask_generator, seed=None):
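    """Run automatic mask generation on one RGB image; return the blended overlay and the color mask."""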
    if seed is not None:
        np.random.seed(seed)
    masks = mask_generator.generate(img)
    # Paint larger masks first so smaller masks stay visible on top.
    sorted_anns = sorted(masks, key=lambda x: x['area'], reverse=True)
    mask_all = np.ones((img.shape[0], img.shape[1], 3))
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.random.random(3)
        mask_all[m] = color_mask
    # Blend the original image (30%) with the colored masks (70%).
    result = img / 255 * 0.3 + mask_all * 0.7
    return result, mask_all


def generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
                        min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh,
                        input_x, progress=gr.Progress()):
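    """Segment a whole image, or every frame of a video, with SAM's automatic mask generator."""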
    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        stability_score_offset=stability_score_offset,
        box_nms_thresh=box_nms_thresh,
        crop_n_layers=crop_n_layers,
        crop_nms_thresh=crop_nms_thresh,
        crop_overlap_ratio=512 / 1500,
        crop_n_points_downscale_factor=1,
        point_grids=None,
        min_mask_region_area=min_mask_region_area,
        output_mode='binary_mask'
    )

    if isinstance(input_x, np.ndarray):  # single image
        result, mask_all = segment_one(input_x, mask_generator)
        return result, mask_all
    elif isinstance(input_x, str):  # path to a video file
        cap = cv2.VideoCapture(input_x)
        frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        W, H = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc('x', '2', '6', '4'), fps, (W, H), isColor=True)
        for _ in progress.tqdm(range(frames_num),
                               desc='Processing video ({} frames, size {}x{})'.format(frames_num, W, H)):
            ret, frame = cap.read()
            if not ret:  # stop early if the reported frame count overshoots
                break
            # OpenCV decodes frames as BGR; SAM expects RGB, and VideoWriter expects BGR back.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result, mask_all = segment_one(frame, mask_generator, seed=2023)
            result = cv2.cvtColor((result * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
            out.write(result)
        out.release()
        cap.release()
        return 'output.mp4'


def predictor_inference(device, model_type, input_x, input_text, selected_points, owl_vit_threshold=0.1):
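    """Promptable SAM inference: OWL-ViT turns text queries into boxes; user clicks provide point prompts."""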
    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
    predictor = SamPredictor(sam)
    predictor.set_image(input_x)  # compute the image embedding once; all prompts reuse it

    if input_text != '':
        # Turn the comma-separated text queries into boxes with OWL-ViT.
        text_queries = [input_text.split(',')]
        print(text_queries)

        processor = OwlViTProcessor.from_pretrained('./checkpoints/models--google--owlvit-base-patch32')
        owlvit_model = OwlViTForObjectDetection.from_pretrained(
            './checkpoints/models--google--owlvit-base-patch32').to(device)

        inputs = processor(text=text_queries, images=input_x, return_tensors="pt").to(device)
        outputs = owlvit_model(**inputs)
        target_size = torch.Tensor([input_x.shape[:2]]).to(device)
        results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_size,
                                                          threshold=owl_vit_threshold)

        # Keep the boxes for the first (and only) image and map them into SAM's input frame.
        boxes_tensor = results[0]["boxes"]
        boxes = boxes_tensor.cpu().detach().numpy()
        transformed_boxes = predictor.transform.apply_boxes_torch(torch.Tensor(boxes).to(device),
                                                                  input_x.shape[:2])
        print(transformed_boxes.size(), boxes.shape)
    else:
        transformed_boxes = None

    if len(selected_points) != 0:
        # Each clicked point becomes its own prompt; label 1 marks foreground, 0 background.
        points = torch.Tensor([p for p, _ in selected_points]).to(device).unsqueeze(1)
        labels = torch.Tensor([int(l) for _, l in selected_points]).to(device).unsqueeze(1)
        transformed_points = predictor.transform.apply_coords_torch(points, input_x.shape[:2])
        print(points.size(), transformed_points.size(), labels.size(), input_x.shape, points)
    else:
        transformed_points, labels = None, None

    # Predict masks from the box and/or point prompts.
    masks, scores, logits = predictor.predict_torch(
        point_coords=transformed_points,
        point_labels=labels,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    masks = masks.cpu().detach().numpy()
    mask_all = np.ones((input_x.shape[0], input_x.shape[1], 3))
    for ann in masks:
        color_mask = np.random.random(3)
        mask_all[ann[0]] = color_mask
    img = input_x / 255 * 0.3 + mask_all * 0.7

    if input_text != '':
        img = plot_boxes(img, boxes_tensor)
        # Free the OWL-ViT model and its inputs before returning.
        owlvit_model.cpu()
        del owlvit_model
        del inputs
        gc.collect()
        torch.cuda.empty_cache()

    return img, mask_all


def run_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
                  stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh, owl_vit_threshold, input_x,
                  input_text, selected_points):
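    """Dispatch to promptable or automatic SAM inference based on the prompts supplied."""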
    # An integer input_x is the index of a bundled example image.
    if isinstance(input_x, int):
        input_x = cv2.imread(image_examples[input_x][0])
        input_x = cv2.cvtColor(input_x, cv2.COLOR_BGR2RGB)
    # Text or point prompts route to the promptable predictor; otherwise run automatic mask generation.
    if (input_text != '' and not isinstance(input_x, str)) or len(selected_points) != 0:
        print('use predictor_inference')
        print('prompt text: ', input_text)
        print('prompt points length: ', len(selected_points))
        return predictor_inference(device, model_type, input_x, input_text, selected_points, owl_vit_threshold)
    else:
        print('use generator_inference')
        return generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
                                   min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
                                   crop_nms_thresh, input_x)
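

# A minimal smoke-test sketch, not part of the original app: it segments the first
# bundled example image with one hypothetical foreground click at pixel (100, 100).
# The SAM hyperparameters below are the upstream SamAutomaticMaskGenerator defaults.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    overlay, mask = run_inference(
        device=device, model_type='vit_b', points_per_side=32, pred_iou_thresh=0.88,
        stability_score_thresh=0.95, min_mask_region_area=0, stability_score_offset=1.0,
        box_nms_thresh=0.7, crop_n_layers=0, crop_nms_thresh=0.7, owl_vit_threshold=0.1,
        input_x=0, input_text='', selected_points=[((100, 100), 1)])
    print('overlay shape:', np.asarray(overlay).shape)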