# Hugging Face Space: GroundingDINO + SAM + BLIP demo.
# NOTE(review): this Space was failing at startup with a runtime error.
| import os | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import ( | |
| AutoProcessor, | |
| AutoModelForZeroShotObjectDetection, | |
| BlipProcessor, | |
| BlipForConditionalGeneration | |
| ) | |
| from segment_anything import sam_model_registry, SamPredictor | |
# Run on GPU when one is visible; every model below shares this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --------------------------------------------------
# MODEL CONFIGURATION
# --------------------------------------------------
DINO_MODEL = "IDEA-Research/grounding-dino-base"      # zero-shot detector
BLIP_MODEL = "Salesforce/blip-image-captioning-base"  # image captioner
SAM_TYPE = "vit_b"                                    # SAM backbone variant
SAM_CHECKPOINT = "sam_vit_b.pth"                      # local checkpoint path
SAM_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
BOX_THRESHOLD = 0.3                                   # min detection score kept
# --------------------------------------------------
# DOWNLOAD SAM CHECKPOINT (only on first start)
# --------------------------------------------------
if not os.path.exists(SAM_CHECKPOINT):
    # Imported lazily: only needed on the very first run of the Space.
    import urllib.request

    print("Downloading SAM checkpoint...")
    urllib.request.urlretrieve(SAM_URL, SAM_CHECKPOINT)
# --------------------------------------------------
# LOAD MODELS (module level: each model is loaded exactly once per process)
# --------------------------------------------------
print("Loading GroundingDINO...")
processor = AutoProcessor.from_pretrained(DINO_MODEL)
dino = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_MODEL).to(DEVICE)

print("Loading SAM...")
sam = sam_model_registry[SAM_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)  # shared by segment(); set_image is called per request

print("Loading BLIP...")
blip_processor = BlipProcessor.from_pretrained(BLIP_MODEL)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(DEVICE)
# --------------------------------------------------
# BLIP CAPTION
# --------------------------------------------------
def generate_caption(image):
    """Caption a PIL image with BLIP and return the plain-text caption."""
    blip_inputs = blip_processor(image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated_ids = blip_model.generate(**blip_inputs)
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
# --------------------------------------------------
# DETECT OBJECTS
# --------------------------------------------------
def detect(image, prompt):
    """Run GroundingDINO on *image* for *prompt*.

    Returns a (K, 4) tensor of xyxy boxes whose score exceeds
    BOX_THRESHOLD (possibly empty).
    """
    # GroundingDINO text queries must be lower-case and period-terminated;
    # un-normalized prompts (e.g. a raw BLIP caption) degrade detection.
    text = prompt.strip().lower()
    if not text.endswith("."):
        text += "."
    inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = dino(**inputs)
    # BUG FIX: post-processing needs the tokenized input ids to map logits
    # back to phrases; calling it without them raises at runtime on the
    # transformers versions where the argument is required.
    results = processor.post_process_grounded_object_detection(
        outputs,
        input_ids=inputs.input_ids,
        target_sizes=[image.size[::-1]],
    )[0]
    boxes = results["boxes"]
    scores = results["scores"]
    keep = scores > BOX_THRESHOLD
    return boxes[keep]
# --------------------------------------------------
# DRAW BOXES
# --------------------------------------------------
def draw_boxes(image, boxes):
    """Return an RGB copy of *image* with a 3px red rectangle per xyxy box.

    BUG FIX: the image is converted to RGB first — the original assigned
    RGB triples into the raw array and crashed on L/P/RGBA inputs.
    numpy slice assignment clamps out-of-range indices, so boxes touching
    the image border are safe.
    """
    result = np.array(image.convert("RGB"))
    red = [255, 0, 0]
    for box in boxes:
        x1, y1, x2, y2 = box.cpu().numpy().astype(int)
        result[y1:y1 + 3, x1:x2] = red  # top edge
        result[y2:y2 + 3, x1:x2] = red  # bottom edge
        result[y1:y2, x1:x1 + 3] = red  # left edge
        result[y1:y2, x2:x2 + 3] = red  # right edge
    return Image.fromarray(result)
# --------------------------------------------------
# SEGMENT
# --------------------------------------------------
def segment(image, prompt):
    """Detect *prompt* in *image* with GroundingDINO, then overlay the SAM
    mask of every detection as a 50% green blend. Returns a PIL image."""
    image = image.convert("RGB")
    pixels = np.array(image)

    detected_boxes = detect(image, prompt)
    if len(detected_boxes) == 0:
        # Nothing matched the prompt — hand back the untouched input.
        return image

    predictor.set_image(pixels)
    sam_boxes = predictor.transform.apply_boxes_torch(
        detected_boxes.to(DEVICE), pixels.shape[:2]
    )
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )

    overlay = pixels.copy()
    green = np.array([0, 255, 0])
    for mask in masks:
        region = mask[0].cpu().numpy() > 0
        # 50/50 alpha blend of the original pixels with solid green.
        overlay[region] = (overlay[region] * 0.5 + green * 0.5).astype(np.uint8)
    return Image.fromarray(overlay)
# --------------------------------------------------
# PIPELINE
# --------------------------------------------------
def run_pipeline(image, prompt, mode):
    """Dispatch one Gradio request.

    Modes: "seg" (prompt-driven segmentation), "det" (draw detection
    boxes), "automatic" (BLIP captions the image, caption drives
    segmentation).

    BUG FIXES: the image is normalized to RGB once here so the "det"
    path no longer crashes on L/P/RGBA uploads, and a missing image or
    unknown mode returns a value instead of falling through to None
    (which errored the Gradio output component).
    """
    if image is None:
        return None
    image = image.convert("RGB")
    if mode == "seg":
        return segment(image, prompt)
    if mode == "det":
        return draw_boxes(image, detect(image, prompt))
    if mode == "automatic":
        caption = generate_caption(image)
        print("BLIP caption:", caption)
        return segment(image, caption)
    # Unknown mode: echo the input rather than returning None.
    return image
# --------------------------------------------------
# UI
# --------------------------------------------------
demo = gr.Interface(
    fn=run_pipeline,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Prompt", value="person"),
        gr.Dropdown(["seg", "det", "automatic"], value="seg", label="Mode"),
    ],
    outputs=gr.Image(),
    title="GroundingDINO + SAM + BLIP (CPU version)",
)

if __name__ == "__main__":
    demo.launch()