"""Grounded segmentation demo: GroundingDINO (detection) + SAM (masks) + BLIP (captions).

All models are loaded once at import time; the functions below close over the
resulting module-level globals.
"""
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModelForZeroShotObjectDetection,
    BlipProcessor,
    BlipForConditionalGeneration
)
from segment_anything import sam_model_registry, SamPredictor

# Prefer GPU when present; every model and tensor below is moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --------------------------------------------------
# MODELS
# --------------------------------------------------
DINO_MODEL = "IDEA-Research/grounding-dino-base"
BLIP_MODEL = "Salesforce/blip-image-captioning-base"
SAM_TYPE = "vit_b"
SAM_CHECKPOINT = "sam_vit_b.pth"
SAM_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
# Minimum detection confidence kept by detect(); boxes below this are discarded.
BOX_THRESHOLD = 0.3

# --------------------------------------------------
# DOWNLOAD SAM
# --------------------------------------------------
# The SAM checkpoint is not on the HuggingFace hub, so fetch it manually once
# and cache it next to the script.
if not os.path.exists(SAM_CHECKPOINT):
    import urllib.request
    print("Downloading SAM checkpoint...")
    urllib.request.urlretrieve(SAM_URL, SAM_CHECKPOINT)

# --------------------------------------------------
# LOAD MODELS
# --------------------------------------------------
print("Loading GroundingDINO...")
processor = AutoProcessor.from_pretrained(DINO_MODEL)
dino = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_MODEL).to(DEVICE)

print("Loading SAM...")
sam = sam_model_registry[SAM_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
# Stateful wrapper: segment() calls set_image() before predict_torch().
predictor = SamPredictor(sam)

print("Loading BLIP...")
blip_processor = BlipProcessor.from_pretrained(BLIP_MODEL)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(DEVICE)


# --------------------------------------------------
# BLIP CAPTION
# --------------------------------------------------
def generate_caption(image):
    """Return a BLIP-generated caption (str) for a PIL image.

    Used by the "automatic" pipeline mode to produce a detection prompt
    when the user supplies none.
    """
    inputs = blip_processor(image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = blip_model.generate(**inputs)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    return caption
# --------------------------------------------------
# DETECT OBJECTS
# --------------------------------------------------
def detect(image, prompt):
    """Run GroundingDINO on a PIL image with a text prompt.

    Args:
        image: PIL image (RGB assumed by the processor).
        prompt: free-text query, e.g. "person".

    Returns:
        Tensor of boxes (N, 4) in (x1, y1, x2, y2) pixel coordinates,
        keeping only detections with score > BOX_THRESHOLD. May be empty.
    """
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = dino(**inputs)
    # NOTE(review): newer transformers releases also accept input_ids and a
    # threshold kwarg here — confirm against the installed version's API.
    results = processor.post_process_grounded_object_detection(
        outputs,
        target_sizes=[image.size[::-1]],  # PIL size is (W, H); model wants (H, W)
    )[0]
    boxes = results["boxes"]
    scores = results["scores"]
    keep = scores > BOX_THRESHOLD
    return boxes[keep]


# --------------------------------------------------
# DRAW BOXES
# --------------------------------------------------
def draw_boxes(image, boxes):
    """Return a copy of *image* (PIL) with 3px red rectangles for each box.

    Box coordinates are clamped to the image bounds: detector boxes can land
    slightly outside the frame, and a negative index would otherwise wrap
    around and paint stripes on the opposite edge.
    """
    image_np = np.array(image)
    result = image_np.copy()
    h, w = result.shape[:2]
    for box in boxes:
        x1, y1, x2, y2 = box.cpu().numpy().astype(int)
        x1, x2 = np.clip([x1, x2], 0, w - 1)
        y1, y2 = np.clip([y1, y2], 0, h - 1)
        result[y1:y1 + 3, x1:x2] = [255, 0, 0]  # top edge
        result[y2:y2 + 3, x1:x2] = [255, 0, 0]  # bottom edge
        result[y1:y2, x1:x1 + 3] = [255, 0, 0]  # left edge
        result[y1:y2, x2:x2 + 3] = [255, 0, 0]  # right edge
    return Image.fromarray(result)


# --------------------------------------------------
# SEGMENT
# --------------------------------------------------
def segment(image, prompt):
    """Detect objects matching *prompt*, then overlay SAM masks in green.

    Returns the original image unchanged when nothing is detected.
    """
    image = image.convert("RGB")
    image_np = np.array(image)

    boxes = detect(image, prompt)
    if len(boxes) == 0:
        return image

    predictor.set_image(image_np)
    boxes = boxes.to(DEVICE)
    # SAM expects boxes in its own resized coordinate frame.
    transformed = predictor.transform.apply_boxes_torch(
        boxes, image_np.shape[:2]
    )
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed,
        multimask_output=False,  # one mask per box
    )

    result = image_np.copy()
    for mask in masks:
        m = mask[0].cpu().numpy()
        # 50/50 blend of original pixels with pure green inside the mask.
        result[m > 0] = (
            result[m > 0] * 0.5 + np.array([0, 255, 0]) * 0.5
        ).astype(np.uint8)
    return Image.fromarray(result)


# --------------------------------------------------
# PIPELINE
# --------------------------------------------------
def run_pipeline(image, prompt, mode):
    """Dispatch to the selected mode.

    Modes:
        "seg"       — segment objects matching *prompt*.
        "det"       — draw detection boxes for *prompt*.
        "automatic" — caption the image with BLIP, then segment the caption.

    Raises:
        ValueError: on an unrecognized *mode* (previously fell through and
        returned None, which Gradio renders as a blank output).
    """
    if mode == "seg":
        return segment(image, prompt)
    if mode == "det":
        boxes = detect(image, prompt)
        return draw_boxes(image, boxes)
    if mode == "automatic":
        caption = generate_caption(image)
        print("BLIP caption:", caption)
        return segment(image, caption)
    raise ValueError(f"Unknown mode: {mode!r}")
# --------------------------------------------------
# UI
# --------------------------------------------------
# Build each input component separately, then assemble the interface.
image_input = gr.Image(type="pil")
prompt_input = gr.Textbox(label="Prompt", value="person")
mode_input = gr.Dropdown(
    ["seg", "det", "automatic"],
    value="seg",
    label="Mode",
)

demo = gr.Interface(
    fn=run_pipeline,
    inputs=[image_input, prompt_input, mode_input],
    outputs=gr.Image(),
    title="GroundingDINO + SAM + BLIP (CPU version)",
)

if __name__ == "__main__":
    demo.launch()