import torch
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from segment_anything import SamPredictor, sam_model_registry
from groundingdino.util.inference import load_model, predict
from groundingdino.util import box_ops
import groundingdino.datasets.transforms as T

# Download the SAM ViT-H checkpoint programmatically from the Hugging Face Hub
sam_checkpoint = hf_hub_download(
    repo_id="HCMUE-Research/SAM-vit-h",
    filename="sam_vit_h_4b8939.pth"
)

grounding_dino_config = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
grounding_dino_weights = "groundingdino_swint_ogc.pth"
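# The SwinT weights are assumed to be present locally. If they are not, they
# can be fetched the same way as the SAM checkpoint above; the repo id below
# is an assumption (a Hub mirror of the GroundingDINO release), verify before use:
#
#   grounding_dino_weights = hf_hub_download(
#       repo_id="ShilongLiu/GroundingDINO",
#       filename="groundingdino_swint_ogc.pth",
#   )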

dino_model = load_model(grounding_dino_config, grounding_dino_weights)

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# GroundingDINO expects a normalized image tensor, not a raw numpy array;
# this is the same transform used by groundingdino.util.inference.load_image.
dino_transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def grounded_sam_segment(image: Image.Image, prompt: str) -> Image.Image:
    image_pil = image.convert("RGB")
    image_np = np.array(image_pil)
    image_tensor, _ = dino_transform(image_pil, None)

    # Detect boxes matching the text prompt
    boxes, logits, phrases = predict(
        model=dino_model,
        image=image_tensor,
        caption=prompt,
        box_threshold=0.3,
        text_threshold=0.25,
        device=device
    )

    if len(boxes) == 0:
        return image

    # GroundingDINO returns normalized cxcywh boxes; SAM expects xyxy pixel coords
    h, w = image_np.shape[:2]
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.tensor([w, h, w, h])

    predictor.set_image(image_np)
    transformed_boxes = predictor.transform.apply_boxes_torch(
        boxes_xyxy, image_np.shape[:2]
    ).to(device)
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False
    )

    # Overlay the first mask on the original image
    mask = masks[0][0].cpu().numpy()
    mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
    overlay = cv2.addWeighted(image_np, 1.0, mask_rgb, 0.4, 0)
    return Image.fromarray(overlay)

gr.Interface(
    fn=grounded_sam_segment,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Prompt", placeholder="e.g., cup handle, bottle")
    ],
    outputs=gr.Image(label="Segmented Output"),
    title="Grounded-SAM Image Segmentation",
    description="Text-prompted image segmentation using GroundingDINO + SAM. Example prompts: 'cup handle', 'helmet'."
).launch()
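
# Minimal sketch of calling the segmentation step directly, without the Gradio
# UI (the image path "example.jpg" is illustrative, not part of this repo):
#
#   result = grounded_sam_segment(Image.open("example.jpg"), "cup handle")
#   result.save("segmented.png")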