Spaces:
Runtime error
Runtime error
| import argparse | |
| import functools | |
| import json | |
| import os | |
| import sys | |
| import tempfile | |
| import cv2 | |
| import numpy as np | |
| import supervision as sv | |
| from groundingdino.util.inference import Model as DinoModel | |
| from imutils import paths | |
| from PIL import Image | |
| from segment_anything import sam_model_registry | |
| from segment_anything import SamAutomaticMaskGenerator | |
| from segment_anything import SamPredictor | |
| from supervision.detection.utils import xywh_to_xyxy | |
| from tqdm import tqdm | |
| sys.path.append("tag2text") | |
| from tag2text.models import tag2text | |
| from config import * | |
| from utils import detect, download_file_hf, segment, generate_tags, show_anns_sv | |
def process(
    tag2text_model,
    grounding_dino_model,
    sam_predictor,
    sam_automask_generator,
    image_path,
    task,
    prompt,
    box_threshold,
    text_threshold,
    iou_threshold,
    kernel_size=2,
    expand_mask=False,
    device="cuda",
    output_dir=None,
    save_ann=True,
    save_mask=False,
):
    """Tag, detect and/or segment a single image and write its metadata as JSON.

    Steps, depending on ``task``:

    * ``"auto"`` / ``"detect"`` / ``"detection"`` with an empty ``prompt``:
      Tag2Text generates tags and a caption; the tags joined by ``" . "``
      become the detection prompt.
    * Whenever the (given or generated) prompt is non-empty, Grounding DINO
      detects boxes for it.
    * ``"auto"`` / ``"segment"``: SAM segments the detected boxes; when there
      are no boxes, SAM's automatic mask generator segments the whole image.

    Args:
        tag2text_model: Tag2Text model forwarded to ``generate_tags``; only
            required when a prompt has to be generated.
        grounding_dino_model: Grounding DINO model forwarded to ``detect``.
        sam_predictor: SAM predictor forwarded to ``segment``.
        sam_automask_generator: SAM automatic mask generator, used when no
            boxes are available.
        image_path: Path of the image to process.
        task: ``"auto"``, ``"detect"`` (or ``"detection"``) or ``"segment"``.
        prompt: Detection prompt; ``""`` requests one from Tag2Text for the
            auto/detection tasks.
        box_threshold: Forwarded to ``detect``.
        text_threshold: Forwarded to ``detect``.
        iou_threshold: Forwarded to ``detect``.
        kernel_size: Half-size of the elliptical kernel (side length
            ``2 * kernel_size + 1``) used to post-process masks.
        expand_mask: When True, masks are dilated. When False, box-prompted
            masks get a morphological close and automatically generated masks
            are kept as-is.
        device: Device forwarded to ``generate_tags``.
        output_dir: Directory receiving the image assets and
            ``<basename>_meta.json``. When falsy, no image assets are saved
            and the metadata is written to a temporary ``.json`` file.
        save_ann: Save the detection / mask / cutout / annotated images
            (requires ``output_dir``).
        save_mask: Save every individual mask as ``<basename>_mask_<id>.png``
            (requires ``output_dir``).

    Returns:
        Path of the JSON metadata file.

    Raises:
        ValueError: Wrapping any exception raised while processing the image.
    """
    # "detect" is the CLI spelling; "detection" is kept for existing callers.
    detection_tasks = ("auto", "detect", "detection")
    detections = None
    metadata = {"image": {}, "annotations": [], "assets": {}}
    if save_mask:
        metadata["assets"]["intermediate_mask"] = []
    try:
        # Image.open is lazy and keeps the file open; close it once converted.
        with Image.open(image_path) as raw_image:
            image_pil = raw_image.convert("RGB")
        image = np.array(image_pil)
        # Untouched copy for SAM and the cutout, because `image` is handed to
        # the annotators below (which may draw on it -- TODO confirm).
        orig_image = image.copy()

        # Image metadata
        filename = os.path.basename(image_path)
        basename = os.path.splitext(filename)[0]
        h, w = image.shape[:2]
        metadata["image"]["file_name"] = filename
        metadata["image"]["width"] = w
        metadata["image"]["height"] = h

        # Generate tags, and a prompt built from them, when none was given.
        if task in detection_tasks and prompt == "":
            if tag2text_model is None:
                raise ValueError(
                    "a Tag2Text model is required to generate a detection prompt"
                )
            tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
            prompt = " . ".join(tags)
            metadata["image"]["caption"] = caption
            metadata["image"]["tags"] = tags
        if prompt:
            metadata["prompt"] = prompt

        # Detect boxes for the prompt.
        if prompt != "":
            detections, phrases, _ = detect(
                grounding_dino_model,
                image,
                caption=prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                iou_threshold=iou_threshold,
                post_process=True,
            )
            if output_dir and save_ann:
                labels = [
                    f"{phrase} {confidence:0.2f}"
                    for phrase, confidence in zip(phrases, detections.confidence)
                ]
                box_image = sv.BoxAnnotator().annotate(
                    scene=image, detections=detections, labels=labels
                )
                box_image_path = os.path.join(output_dir, basename + "_detect.png")
                metadata["assets"]["detection"] = box_image_path
                Image.fromarray(box_image).save(box_image_path)

        # Segmentation
        if task in ("auto", "segment"):
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
            )
            if detections is not None and len(detections) > 0:
                # Box-prompted segmentation of the detected objects.
                masks, scores = segment(
                    sam_predictor, image=orig_image, boxes=detections.xyxy
                )
                if expand_mask:
                    masks = [cv2.dilate(m.astype(np.uint8), kernel) for m in masks]
                else:
                    masks = [
                        cv2.morphologyEx(m.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
                        for m in masks
                    ]
                # sv.Detections documents `mask` as an (N, H, W) array.
                detections.mask = np.array(masks)
                # Union of all object masks. OR them rather than sum: `np.bool`
                # was removed in NumPy 1.24 and a uint8 sum wraps around at 256.
                binary_mask = np.logical_or.reduce(detections.mask, axis=0)
            else:
                # No boxes to prompt with: segment everything, largest first.
                generated = sorted(
                    sam_automask_generator.generate(orig_image),
                    key=lambda m: m["area"],
                    reverse=True,
                )
                xywh = np.array([m["bbox"] for m in generated])
                scores = np.array([m["predicted_iou"] for m in generated])
                if expand_mask:
                    masks = np.array(
                        [
                            cv2.dilate(m["segmentation"].astype(np.uint8), kernel)
                            for m in generated
                        ]
                    )
                else:
                    masks = np.array([m["segmentation"] for m in generated])
                detections = sv.Detections(
                    xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=masks
                )
                binary_mask = None

            # Save annotated images
            if output_dir and save_ann:
                mask_image, res = show_anns_sv(detections)
                annotated_image = sv.MaskAnnotator().annotate(
                    image, detections=detections
                )
                mask_image_path = os.path.join(output_dir, basename + "_mask.png")
                metadata["assets"]["mask"] = mask_image_path
                Image.fromarray(mask_image).save(mask_image_path)
                # Save annotation encoding from https://github.com/LUSSeg/ImageNet-S
                mask_enc_path = os.path.join(output_dir, basename + "_mask_enc.npy")
                np.save(mask_enc_path, res)
                metadata["assets"]["mask_enc"] = mask_enc_path
                if binary_mask is not None:
                    cutout_image = np.expand_dims(binary_mask, axis=-1) * orig_image
                    cutout_image_path = os.path.join(
                        output_dir, basename + "_cutout.png"
                    )
                    metadata["assets"]["cutout"] = cutout_image_path
                    Image.fromarray(cutout_image).save(cutout_image_path)
                annotated_image_path = os.path.join(
                    output_dir, basename + "_annotate.png"
                )
                metadata["assets"]["annotate"] = annotated_image_path
                Image.fromarray(annotated_image).save(annotated_image_path)

        # Per-object annotations. Read the Detections attributes by index
        # instead of unpacking its iterator, whose tuple arity differs between
        # supervision versions.
        if detections is not None and len(detections) > 0:
            for index, (xyxy, area, box_area) in enumerate(
                zip(detections.xyxy, detections.area, detections.box_area)
            ):
                mask = detections.mask[index] if detections.mask is not None else None
                confidence = (
                    detections.confidence[index]
                    if detections.confidence is not None
                    else None
                )
                annotation = {
                    "id": index + 1,
                    "bbox": [int(x) for x in xyxy],
                    "box_area": float(box_area),
                }
                # Only Grounding DINO detections carry a confidence; they are
                # also the only ones with matching `phrases`.
                if confidence is not None:
                    annotation["confidence"] = float(confidence)
                    annotation["label"] = phrases[index]
                if mask is not None:
                    # annotation["segmentation"] = mask_to_polygons(mask)
                    annotation["area"] = int(area)
                    annotation["predicted_iou"] = float(scores[index])
                metadata["annotations"].append(annotation)
                if output_dir and save_mask and mask is not None:
                    mask_image_path = os.path.join(
                        output_dir, f"{basename}_mask_{annotation['id']}.png"
                    )
                    metadata["assets"]["intermediate_mask"].append(mask_image_path)
                    # Boolean or 0/1 masks -> 8-bit 0/255 image PIL can encode.
                    mask_png = (np.asarray(mask) > 0).astype(np.uint8) * 255
                    Image.fromarray(mask_png).save(mask_image_path)

        # Persist the metadata.
        if output_dir:
            meta_file_path = os.path.join(output_dir, basename + "_meta.json")
            with open(meta_file_path, "w") as fp:
                json.dump(metadata, fp)
        else:
            # delete=False keeps the file after closing so the returned path
            # stays valid; the caller owns its cleanup.
            with tempfile.NamedTemporaryFile(
                "w", delete=False, suffix=".json"
            ) as fp:
                json.dump(metadata, fp)
                meta_file_path = fp.name
        return meta_file_path
    except Exception as error:
        raise ValueError(f"global exception: {error}") from error
def _download_with_wget(url, destination):
    """Download ``url`` to ``destination`` with the ``wget`` command.

    A failed download can leave an empty or truncated file behind, which the
    ``os.path.exists`` checks in ``main`` would later accept as a valid
    checkpoint, so any partial output is removed before re-raising.

    Raises:
        OSError: If ``wget`` cannot be executed.
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero status.
    """
    import subprocess

    try:
        # Argument list (no shell) so paths with spaces are passed intact.
        subprocess.run(["wget", url, "-O", destination], check=True)
    except (OSError, subprocess.CalledProcessError):
        if os.path.exists(destination):
            os.remove(destination)
        raise


def main(args: argparse.Namespace) -> None:
    """Load the models needed for ``args.task`` and run ``process`` on each image.

    Tag2Text is loaded for the auto/detection tasks when no prompt is given,
    Grounding DINO for the auto/detection tasks or whenever a prompt is given,
    and SAM for the auto/segment tasks. Missing weights are downloaded first.

    Args:
        args: Parsed command-line arguments (see the parser under ``__main__``).

    Raises:
        ValueError: If ``args.input`` does not exist, or if processing an
            image fails.
    """
    # Validate the input before spending time downloading/loading models.
    if not os.path.exists(args.input):
        raise ValueError(f"The input path doesn't exist: {args.input}")

    device = args.device
    prompt = args.prompt
    # The CLI offers "detect", while the detection branches here and in
    # process() test for "detection"; map the CLI spelling onto it.
    task = "detection" if args.task == "detect" else args.task
    detection_tasks = ("auto", "detection")
    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    iou_threshold = args.iou_threshold
    save_ann = not args.no_save_ann
    save_mask = args.save_mask
    # getattr keeps programmatic callers with older Namespaces working; the
    # fallbacks match process()'s defaults.
    kernel_size = getattr(args, "kernel_size", 2)
    expand_mask = getattr(args, "expand_mask", False)

    tag2text_model = None
    grounding_dino_model = None
    sam_predictor = None
    sam_automask_generator = None

    # Tag2Text only generates a prompt, so it is needed only when none is given.
    if task in detection_tasks and prompt == "":
        print("Loading Tag2Text model...")
        tag2text_type = args.tag2text_type
        tag2text_checkpoint = os.path.join(
            abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
        )
        if not os.path.exists(tag2text_checkpoint):
            print(f"Downloading weights for Tag2Text {tag2text_type} model")
            _download_with_wget(
                tag2text_dict[tag2text_type]["checkpoint_url"], tag2text_checkpoint
            )
        tag2text_model = tag2text.tag2text_caption(
            pretrained=tag2text_checkpoint,
            image_size=384,
            vit="swin_b",
            delete_tag_index=delete_tag_index,
        )
        # Tagging threshold, lowered to obtain more tags.
        tag2text_model.threshold = 0.64
        tag2text_model.to(device)
        tag2text_model.eval()

    if task in detection_tasks or prompt != "":
        print("Loading Grounding Dino model...")
        dino_type = args.dino_type
        dino_files = dino_dict[dino_type]
        dino_checkpoint = os.path.join(abs_weight_dir, dino_files["checkpoint_file"])
        dino_config_file = os.path.join(abs_weight_dir, dino_files["config_file"])
        # The checkpoint and the config are fetched independently; either one
        # may be missing locally.
        if not os.path.exists(dino_checkpoint):
            print(f"Downloading weights for Grounding Dino {dino_type} model")
            download_file_hf(
                repo_id=dino_files["repo_id"],
                filename=dino_files["checkpoint_file"],
                cache_dir=weight_dir,
            )
        if not os.path.exists(dino_config_file):
            print(f"Downloading config for Grounding Dino {dino_type} model")
            download_file_hf(
                repo_id=dino_files["repo_id"],
                filename=dino_files["config_file"],
                cache_dir=weight_dir,
            )
        grounding_dino_model = DinoModel(
            model_config_path=dino_config_file,
            model_checkpoint_path=dino_checkpoint,
            device=device,
        )

    if task in ("auto", "segment"):
        print("Loading SAM...")
        sam_type = args.sam_type
        sam_checkpoint = os.path.join(
            abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
        )
        if not os.path.exists(sam_checkpoint):
            print(f"Downloading weights for SAM {sam_type}")
            _download_with_wget(sam_dict[sam_type]["checkpoint_url"], sam_checkpoint)
        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
        sam_automask_generator = SamAutomaticMaskGenerator(sam)

    # A list (not a generator) so tqdm can show a total.
    if os.path.isdir(args.input):
        image_paths = list(paths.list_images(args.input))
    else:
        image_paths = [args.input]
    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")
            process(
                tag2text_model=tag2text_model,
                grounding_dino_model=grounding_dino_model,
                sam_predictor=sam_predictor,
                sam_automask_generator=sam_automask_generator,
                image_path=image_path,
                task=task,
                prompt=prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                iou_threshold=iou_threshold,
                kernel_size=kernel_size,
                expand_mask=expand_mask,
                device=device,
                output_dir=args.output,
                save_ann=save_ann,
                save_mask=save_mask,
            )
if __name__ == "__main__":
    # Command-line entry point: parse the options, make sure the weight
    # directory exists, then hand over to main().
    parser = argparse.ArgumentParser(
        description=(
            "Runs automatic detection and mask generation on an input image or directory of images"
        )
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the directory where masks will be output.",
    )
    parser.add_argument(
        "--sam-type",
        type=str,
        default=default_sam,
        choices=sam_dict.keys(),
        help="The type of SAM model to use for segmentation.",
    )
    parser.add_argument(
        "--tag2text-type",
        type=str,
        default=default_tag2text,
        choices=tag2text_dict.keys(),
        help="The type of Tag2Text model to use for tag and caption generation.",
    )
    parser.add_argument(
        "--dino-type",
        type=str,
        default=default_dino,
        choices=dino_dict.keys(),
        help="The type of Grounding Dino model to use for promptable object detection.",
    )
    parser.add_argument(
        "--task",
        help="Task to run",
        default="auto",
        # "detection" is the spelling main()/process() test for; "detect"
        # stays accepted for existing command lines.
        choices=["auto", "detect", "detection", "segment"],
        type=str,
    )
    parser.add_argument(
        "--prompt",
        help="Detection prompt",
        default="",
        type=str,
    )
    parser.add_argument(
        "--box-threshold", type=float, default=0.25, help="box threshold"
    )
    parser.add_argument(
        "--text-threshold", type=float, default=0.2, help="text threshold"
    )
    parser.add_argument(
        "--iou-threshold", type=float, default=0.5, help="iou threshold"
    )
    parser.add_argument(
        "--kernel-size",
        type=int,
        default=2,
        choices=range(1, 6),
        help="Kernel size used for smoothing/expanding segment masks.",
    )
    parser.add_argument(
        "--expand-mask",
        action="store_true",
        default=False,
        help="If set, expand segment masks for smoother output.",
    )
    parser.add_argument(
        "--no-save-ann",
        action="store_true",
        default=False,
        help="If set, do not save the original image with blended masks and detection boxes.",
    )
    parser.add_argument(
        "--save-mask",
        action="store_true",
        default=False,
        help="If set, save every intermediate mask.",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run generation on."
    )
    args = parser.parse_args()
    # Create the weight directory only after parsing succeeded, so `--help`
    # or a usage error leaves the filesystem untouched.
    os.makedirs(abs_weight_dir, exist_ok=True)
    main(args)