snair94 commited on
Commit
629cffb
·
verified ·
1 Parent(s): 7554540

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ import gradio as gr
6
+ from segment_anything import SamPredictor, sam_model_registry
7
+ from groundingdino.util.inference import load_model, predict, annotate
8
+
9
+ grounding_dino_config = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
10
+ grounding_dino_weights = 'groundingdino_swift_ogc.path'
11
+
12
+ dino_model = load_model(grounding_dino_config, grounding_dino_weights)
13
+
14
+ sam_checkpoint = 'sam_vit_h_4b8939.pth'
15
+ sam = sam_model_registry['vit_h'](checkpoint = sam_checkpoint)
16
+ sam.to('cuda' if torch.cuda.is_available() else 'cpu')
17
+ predictor = SamPredictor(sam)
18
+
19
+ def grounded_sam_segment(image: Image.Image, prompt: str) -> Image.Image:
20
+ image_np = np.array(image.convert('RGB'))
21
+
22
+ boxes, logits, phrases = predict(
23
+ model = dino_model,
24
+ image = image_np,
25
+ caption = prompt,
26
+ box_threshold = 0.3,
27
+ text_threshold = 0.25
28
+ )
29
+
30
+ if len(boxes) == 0:
31
+ return image
32
+
33
+ predictor.set_image(image_np)
34
+ transformed_boxes = predictor.transform.apply_boxes_torch(boxes, image_np.shape[:2])
35
+ masks, _, = predictor.predict_torch(boxes = transformed_boxes, multimask_output=False)
36
+
37
+ mask = masks[0][0].cpu().numpy()
38
+ mask = np.stack([mask * 255] * 3, axis =-1).astype(np.units)
39
+ overlay = cv2.addweighted(image_np, 1, mask, 0.4, 0)
40
+ return Image.fromarray(overlay)
41
+
42
+ gr.Interface(
43
+ fn=grounded_sam_segment,
44
+ inputs=[
45
+ gr.Image(type='pil', label='Upload Image'),
46
+ gr.Textbox(label='Prompt', placeholder='e.g., cup handle, bottle')
47
+ ],
48
+ outputs=gr.Image(label='Segmented Output'),
49
+ title='Grounded-SAM Image Segmentation',
50
+ description="Accurate image segmentation using GroundingDINO + SAM. Prompt: 'cup handle', 'helmet', 'etc.'")
51
+ ]
52
+ ).launch()