EnginDev committed on
Commit
6f96e62
·
verified ·
1 Parent(s): d4de6a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -101
app.py CHANGED
@@ -1,154 +1,165 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  from PIL import Image
4
- import torch
5
- from transformers import pipeline
6
- from groundingdino.util.inference import load_model, load_image, predict
7
  from segment_anything import sam_model_registry, SamPredictor
 
8
  import supervision as sv
9
- import cv2
10
  import os
 
11
 
12
- # Download models on startup
13
- print("Loading models...")
14
-
15
- # Load Grounding DINO model from Hugging Face
16
- # Using a different approach that doesn't require local config files
17
- from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
18
 
19
- dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
20
- dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")
 
 
21
 
22
- # Load SAM model
23
- sam_checkpoint = "sam_vit_h_4b8939.pth"
24
- model_type = "vit_h"
25
 
26
- # Download SAM weights if not present
27
- if not os.path.exists(sam_checkpoint):
28
- os.system(f"wget https://dl.fbaipublicfiles.com/segment_anything/{sam_checkpoint}")
 
 
29
 
30
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
 
 
31
  sam_predictor = SamPredictor(sam)
32
 
33
- print("Models loaded successfully!")
34
-
35
- def detect_and_segment(image, text_prompt="fish", quality="Medium (512px)"):
36
  """
37
- Detect objects using Grounding DINO and segment using SAM
38
  """
39
  try:
40
- # Resize image based on quality setting
41
- quality_map = {
42
- "Low (256px)": 256,
43
- "Medium (512px)": 512,
44
- "High (1024px)": 1024
45
- }
46
- target_size = quality_map.get(quality, 512)
47
-
48
- # Convert PIL to numpy
49
- image_np = np.array(image)
50
- h, w = image_np.shape[:2]
51
-
52
- # Resize maintaining aspect ratio
53
- scale = min(target_size / w, target_size / h)
54
- new_w, new_h = int(w * scale), int(h * scale)
55
- image_resized = cv2.resize(image_np, (new_w, new_h))
56
-
57
- # Prepare image for Grounding DINO
58
- inputs = dino_processor(images=image_resized, text=text_prompt, return_tensors="pt")
 
59
 
60
  with torch.no_grad():
61
- outputs = dino_model(**inputs)
62
 
63
  # Post-process results
64
- results = dino_processor.post_process_grounded_object_detection(
65
  outputs,
66
  inputs.input_ids,
67
- box_threshold=0.25,
68
- text_threshold=0.25,
69
- target_sizes=[(new_h, new_w)]
70
- )
 
 
 
 
71
 
72
- if len(results) == 0 or len(results[0]["boxes"]) == 0:
73
- return image, {"error": "No fish detected", "detections": 0}
74
 
75
- # Get boxes and scores
76
- boxes = results[0]["boxes"].cpu().numpy()
77
- scores = results[0]["scores"].cpu().numpy()
78
 
79
- # Use SAM to segment
80
- sam_predictor.set_image(image_resized)
81
 
82
- # Convert boxes to SAM format
83
  masks = []
84
- for box in boxes:
85
- box_sam = np.array([box[0], box[1], box[2], box[3]])
86
- mask, _, _ = sam_predictor.predict(box=box_sam, multimask_output=False)
 
 
87
  masks.append(mask[0])
88
 
89
- # Create visualization
90
- annotated_image = image_resized.copy()
91
 
92
  # Draw masks
93
- for mask in masks:
94
- color_mask = np.zeros_like(annotated_image)
95
- color_mask[mask] = [0, 255, 0] # Green mask
96
- annotated_image = cv2.addWeighted(annotated_image, 1, color_mask, 0.5, 0)
97
 
98
- # Draw bounding boxes
99
- for box in boxes:
100
  x1, y1, x2, y2 = map(int, box)
101
- cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
102
-
103
- # Calculate metadata
104
- total_pixels = new_w * new_h
105
- mask_pixels = sum(np.sum(mask) for mask in masks)
106
- mask_percentage = (mask_pixels / total_pixels) * 100
107
-
108
- metadata = {
109
- "detections": len(boxes),
110
- "avg_confidence": float(np.mean(scores)),
111
- "image_size": f"{new_w}x{new_h}",
112
- "mask_percentage": f"{mask_percentage:.2f}%"
113
- }
114
 
115
- return Image.fromarray(annotated_image), metadata
 
116
 
117
  except Exception as e:
118
- return image, {"error": str(e)}
119
 
120
- # Create Gradio interface
121
- with gr.Blocks(title="Grounded SAM - Fish Detection") as demo:
122
- gr.Markdown("# 🐟 Grounded SAM: Fish Detection & Segmentation")
123
- gr.Markdown("Upload an image and detect fish using Grounding DINO + Segment Anything Model")
124
 
125
  with gr.Row():
126
  with gr.Column():
127
- input_image = gr.Image(type="pil", label="Upload Image")
128
- text_prompt = gr.Textbox(value="fish", label="Detection Prompt")
129
- quality = gr.Radio(
130
- choices=["Low (256px)", "Medium (512px)", "High (1024px)"],
131
- value="Medium (512px)",
132
- label="Processing Quality"
133
  )
134
- submit_btn = gr.Button("Process Image", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  with gr.Column():
137
- output_image = gr.Image(label="Detection Result")
138
- output_metadata = gr.JSON(label="Detection Metadata")
139
 
140
  submit_btn.click(
141
- fn=detect_and_segment,
142
- inputs=[input_image, text_prompt, quality],
143
  outputs=[output_image, output_metadata]
144
  )
145
 
146
  gr.Examples(
147
  examples=[
148
- ["fish_angler.jpg", "fish", "High (1024px)"],
 
149
  ],
150
- inputs=[input_image, text_prompt, quality]
151
  )
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
 
# ---------------------------------------------------------------------------
# Module-level setup: imports, SAM checkpoint download, and model loading.
# Runs once at import time so the Gradio handler can reuse the loaded models.
# ---------------------------------------------------------------------------
import os
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

import supervision as sv
from segment_anything import sam_model_registry, SamPredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# SAM ViT-H checkpoint; downloaded once and cached on local disk.
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

if not os.path.exists(SAM_CHECKPOINT):
    # Plain strings here: the originals were f-strings with no placeholders.
    print("Downloading SAM checkpoint...")
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, SAM_CHECKPOINT)
    print("SAM checkpoint downloaded!")

# Prefer GPU when available; both models are placed on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Grounding DINO: zero-shot, text-conditioned object detector from the Hub.
grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny"
).to(device)

# Segment Anything (ViT-H backbone) wrapped in a predictor for box prompts.
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
def process_image(image, text_prompt, box_threshold, text_threshold, quality):
    """
    Detect objects described by ``text_prompt`` with Grounding DINO, then
    segment each detection with SAM and draw masks and labeled boxes.

    Parameters
    ----------
    image : np.ndarray
        RGB uint8 image as delivered by ``gr.Image(type="numpy")``.
    text_prompt : str
        Free-text description of the objects to detect (e.g. "fish").
    box_threshold : float
        Minimum detection confidence for a box to be kept.
    text_threshold : float
        Minimum text-match confidence for a box/phrase pairing.
    quality : str
        One of "Low" / "Medium" / "High"; caps the longest image side.

    Returns
    -------
    tuple[np.ndarray, str]
        Annotated image and a status string. On failure the original
        image is returned together with an error message.
    """
    try:
        # Cap the longest side according to the requested quality tier
        # (anything other than Low/Medium falls through to High = 1920).
        max_side = {"Low": 800, "Medium": 1024}.get(quality, 1920)

        h, w = image.shape[:2]
        if max(h, w) > max_side:
            scale = max_side / max(h, w)
            image = cv2.resize(image, (int(w * scale), int(h * scale)))

        # BUG FIX: gr.Image(type="numpy") already yields RGB arrays, so the
        # previous cv2.cvtColor(..., COLOR_BGR2RGB) call was swapping the
        # channels fed to both models. Pass the array through unchanged.
        pil_image = Image.fromarray(image)

        # Grounding DINO: text-conditioned zero-shot detection.
        inputs = grounding_dino_processor(
            images=pil_image, text=text_prompt, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = grounding_dino_model(**inputs)

        # Map raw logits back to absolute xyxy boxes at the resized scale.
        results = grounding_dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[pil_image.size[::-1]],  # PIL size is (w, h); model wants (h, w)
        )[0]

        boxes_xyxy = results["boxes"].cpu().numpy()
        labels = results["labels"]

        if len(boxes_xyxy) == 0:
            return image, "No objects detected. Try adjusting the thresholds or text prompt."

        # SAM expects an RGB image; prompt it with one box per detection.
        sam_predictor.set_image(image)
        masks = []
        for box in boxes_xyxy:
            mask, _, _ = sam_predictor.predict(box=box, multimask_output=False)
            masks.append(mask[0])

        # Blend a random color over each mask region (50/50 mix).
        result_image = image.copy()
        for mask in masks:
            color = np.random.randint(0, 255, 3).tolist()
            result_image[mask] = result_image[mask] * 0.5 + np.array(color) * 0.5

        # Draw bounding boxes with their matched phrase labels.
        for box, label in zip(boxes_xyxy, labels):
            x1, y1, x2, y2 = map(int, box)
            color = np.random.randint(0, 255, 3).tolist()
            cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
            cv2.putText(result_image, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        metadata = f"✅ Detected {len(boxes_xyxy)} objects: {', '.join(labels)}"
        return result_image, metadata

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return image, f"❌ Error: {str(e)}"
 
# ---------------------------------------------------------------------------
# Gradio UI: image + prompt inputs, advanced threshold controls, and the
# annotated-image / metadata outputs wired to process_image().
# ---------------------------------------------------------------------------
with gr.Blocks(title="Grounded SAM") as demo:
    gr.Markdown("# 🎯 Grounded SAM - Object Detection & Segmentation")
    gr.Markdown("Upload an image and describe what you want to detect (e.g., 'fish', 'all fish', 'person').")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'fish', 'person', 'car'",
                value="fish"
            )

            with gr.Accordion("Advanced Settings", open=False):
                box_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.35, step=0.05,
                    label="Box Threshold (detection confidence)"
                )
                text_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                    label="Text Threshold (text matching confidence)"
                )
                quality = gr.Radio(
                    choices=["Low", "Medium", "High"],
                    value="Medium",
                    label="Processing Quality"
                )

            submit_btn = gr.Button("🚀 Process Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Output with Masks & Boxes", type="numpy")
            output_metadata = gr.Textbox(label="Detection Metadata", lines=3)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
        outputs=[output_image, output_metadata]
    )

    # NOTE(review): these example files must exist in the repo's examples/
    # directory, or Gradio will warn at startup — confirm they are committed.
    gr.Examples(
        examples=[
            ["examples/fish1.jpg", "fish", 0.35, 0.25, "Medium"],
            ["examples/fish2.jpg", "all fish", 0.35, 0.25, "Medium"],
        ],
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
    )

# Restore the __main__ guard the previous revision had: importing this
# module (e.g. from tests or tooling) should not start a server, while
# `python app.py` still launches the demo exactly as before.
if __name__ == "__main__":
    demo.launch()