EnginDev committed on
Commit
851e42c
Β·
verified Β·
1 Parent(s): 84e5715

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -232
app.py CHANGED
@@ -1,270 +1,154 @@
1
  import gradio as gr
2
- import torch
3
  import numpy as np
4
  from PIL import Image
5
- import cv2
6
- from groundingdino.util.inference import Model as GroundingDINOModel
 
7
  from segment_anything import sam_model_registry, SamPredictor
8
  import supervision as sv
 
 
9
 
10
- print("πŸš€ Starting Grounded SAM FishBoost Edition v5.0...")
 
11
 
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- print(f"πŸ“± Using device: {device}")
 
14
 
15
- grounding_dino_model = None
16
- sam_predictor = None
17
 
18
- def load_models():
19
- """Load Grounding DINO + SAM models"""
20
- global grounding_dino_model, sam_predictor
21
-
22
- if grounding_dino_model is None:
23
- print("πŸ“¦ Loading Grounding DINO model...")
24
- grounding_dino_model = GroundingDINOModel(
25
- model_config_path="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
26
- model_checkpoint_path="weights/groundingdino_swint_ogc.pth",
27
- device=device
28
- )
29
- print("βœ… Grounding DINO loaded!")
30
-
31
- if sam_predictor is None:
32
- print("πŸ“¦ Loading SAM model...")
33
- sam_checkpoint = "weights/sam_vit_h_4b8939.pth"
34
- model_type = "vit_h"
35
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
36
- sam.to(device=device)
37
- sam_predictor = SamPredictor(sam)
38
- print("βœ… SAM loaded!")
39
 
40
- def detect_fish_with_grounded_sam(image_pil, text_prompt="fish", box_threshold=0.25, text_threshold=0.25):
 
 
 
 
 
 
 
 
 
41
  """
42
- Detect and segment fish using Grounding DINO + SAM
43
-
44
- Args:
45
- image_pil: PIL Image
46
- text_prompt: Text prompt for detection (default: "fish")
47
- box_threshold: Confidence threshold for boxes
48
- text_threshold: Confidence threshold for text matching
49
-
50
- Returns:
51
- mask: Binary mask of detected fish
52
- metadata: Detection metadata
53
  """
54
- load_models()
55
-
56
- # Convert PIL to numpy
57
- image_np = np.array(image_pil)
58
-
59
- # 1. Grounding DINO: Detect fish boxes
60
- print(f"πŸ” Detecting '{text_prompt}' with Grounding DINO...")
61
- detections = grounding_dino_model.predict_with_classes(
62
- image=image_np,
63
- classes=[text_prompt],
64
- box_threshold=box_threshold,
65
- text_threshold=text_threshold
66
- )
67
-
68
- print(f"πŸ“¦ Found {len(detections.xyxy)} boxes")
69
-
70
- if len(detections.xyxy) == 0:
71
- print("❌ No fish detected!")
72
- return None, {
73
- "success": False,
74
- "mode": "grounded_sam",
75
- "detection_method": "grounding_dino",
76
- "fish_detected": False,
77
- "reason": "No fish found in image"
78
- }
79
-
80
- # Select best detection (highest confidence)
81
- best_idx = np.argmax(detections.confidence)
82
- best_box = detections.xyxy[best_idx]
83
- best_conf = float(detections.confidence[best_idx])
84
-
85
- print(f"🎯 Best detection: Confidence={best_conf:.2f}, Box={best_box}")
86
-
87
- # 2. SAM: Segment the detected fish
88
- print("βœ‚οΈ Segmenting with SAM...")
89
- sam_predictor.set_image(image_np)
90
-
91
- # Convert box to SAM format
92
- box_np = best_box.reshape(1, 4)
93
-
94
- masks, scores, _ = sam_predictor.predict(
95
- box=box_np,
96
- multimask_output=False
97
- )
98
-
99
- mask = masks[0] # Get best mask
100
-
101
- # Calculate statistics
102
- mask_area = int(np.sum(mask))
103
- total_pixels = mask.shape[0] * mask.shape[1]
104
- mask_percentage = (mask_area / total_pixels) * 100
105
-
106
- # Get contours
107
- contours, _ = cv2.findContours(
108
- mask.astype(np.uint8),
109
- cv2.RETR_EXTERNAL,
110
- cv2.CHAIN_APPROX_SIMPLE
111
- )
112
-
113
- # Get fish center
114
- if len(contours) > 0:
115
- largest_contour = max(contours, key=cv2.contourArea)
116
- M = cv2.moments(largest_contour)
117
- if M["m00"] != 0:
118
- cx = int(M["m10"] / M["m00"])
119
- cy = int(M["m01"] / M["m00"])
120
- else:
121
- cx, cy = int(best_box[0] + best_box[2]) // 2, int(best_box[1] + best_box[3]) // 2
122
- else:
123
- cx, cy = int(best_box[0] + best_box[2]) // 2, int(best_box[1] + best_box[3]) // 2
124
-
125
- # Convert contours to list format
126
- contour_points = []
127
- if len(contours) > 0:
128
- for point in contours[0][:100]: # Limit to 100 points
129
- contour_points.append({
130
- "x": int(point[0][0]),
131
- "y": int(point[0][1])
132
- })
133
-
134
- metadata = {
135
- "success": True,
136
- "mode": "grounded_sam",
137
- "detection_method": "grounding_dino_sam",
138
- "fish_detected": True,
139
- "grounding_dino": {
140
- "confidence": best_conf,
141
- "bounding_box": [int(x) for x in best_box],
142
- "text_prompt": text_prompt,
143
- "total_detections": len(detections.xyxy)
144
- },
145
- "mask_area": mask_area,
146
- "mask_percentage": mask_percentage,
147
- "num_contours": len(contours),
148
- "fish_center": [cx, cy],
149
- "image_size": list(mask.shape),
150
- "device": device,
151
- "contours": contour_points
152
- }
153
-
154
- print(f"βœ… Segmentation complete! Mask: {mask_percentage:.2f}%")
155
-
156
- return mask, metadata
157
-
158
- def process_image(image, quality="high"):
159
- """Main processing function for Gradio interface"""
160
-
161
- if image is None:
162
- return None, "❌ No image provided"
163
-
164
  try:
165
- # Convert to PIL if needed
166
- if isinstance(image, np.ndarray):
167
- image_pil = Image.fromarray(image)
168
- else:
169
- image_pil = image
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- # Resize for faster processing on CPU
172
- max_size = 1024 if quality == "high" else 768
173
- image_pil.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
174
 
175
- # Detect and segment fish
176
- mask, metadata = detect_fish_with_grounded_sam(image_pil, text_prompt="fish")
177
 
178
- if mask is None:
179
- return None, f"❌ No fish detected!\n\n{metadata}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  # Create visualization
182
- image_np = np.array(image_pil)
183
 
184
- # Apply green overlay on fish
185
- overlay = image_np.copy()
186
- overlay[mask] = [0, 255, 0] # Green
187
- result = cv2.addWeighted(image_np, 0.7, overlay, 0.3, 0)
 
188
 
189
- # Draw bounding box
190
- box = metadata["grounding_dino"]["bounding_box"]
191
- cv2.rectangle(result, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
 
192
 
193
- # Add confidence text
194
- conf_text = f"Fish: {metadata['grounding_dino']['confidence']:.2f}"
195
- cv2.putText(result, conf_text, (box[0], box[1] - 10),
196
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
197
 
198
- # Format metadata for display
199
- meta_str = f"""βœ… Fish detected successfully!
200
-
201
- 🎯 Grounding DINO
202
- Confidence: {metadata['grounding_dino']['confidence']:.2%}
203
- Bounding Box: {metadata['grounding_dino']['bounding_box']}
204
- Detections: {metadata['grounding_dino']['total_detections']}
205
-
206
- βœ‚οΈ SAM Segmentation
207
- Mask Area: {metadata['mask_percentage']:.2f}%
208
- Fish Center: {metadata['fish_center']}
209
- Contours: {metadata['num_contours']}
210
-
211
- βš™οΈ System
212
- Device: {metadata['device']}
213
- Image Size: {metadata['image_size']}
214
- """
215
 
216
- return result, meta_str
217
 
218
  except Exception as e:
219
- print(f"❌ Error: {str(e)}")
220
- import traceback
221
- traceback.print_exc()
222
- return None, f"❌ Error: {str(e)}"
223
 
224
- # Gradio Interface
225
- with gr.Blocks(title="🎣 FishBoost - Grounded SAM Edition") as demo:
226
- gr.Markdown("""
227
- # 🎣 FishBoost - Grounded SAM Fish Detector
228
- ### Powered by Grounding DINO + SAM
229
-
230
- Upload an image with a fish and watch the AI detect and segment it!
231
-
232
- ⚠️ **CPU Mode**: First run downloads ~680MB models (2-3 min). Processing: ~30-60 sec per image.
233
- """)
234
 
235
  with gr.Row():
236
  with gr.Column():
237
- input_image = gr.Image(type="pil", label="πŸ“€ Upload Fish Image")
 
238
  quality = gr.Radio(
239
- choices=["high", "medium"],
240
- value="high",
241
- label="🎨 Quality",
242
- info="High = 1024px, Medium = 768px (faster)"
243
  )
244
- process_btn = gr.Button("πŸš€ Detect Fish", variant="primary")
245
 
246
  with gr.Column():
247
- output_image = gr.Image(label="🎯 Detected Fish (Green = Mask, Blue = Box)")
248
- output_meta = gr.Textbox(label="πŸ“Š Detection Metadata", lines=15)
249
 
250
- process_btn.click(
251
- fn=process_image,
252
- inputs=[input_image, quality],
253
- outputs=[output_image, output_meta]
254
  )
255
 
256
- gr.Markdown("""
257
- ---
258
- ### πŸ”§ How it works
259
- 1. **Grounding DINO** finds fish bounding boxes using text prompt "fish"
260
- 2. **SAM** segments the exact fish shape within the box
261
- 3. **Result**: Precise fish mask ignoring angler/background
262
-
263
- ### πŸ“ Model Info
264
- - Grounding DINO: Text-prompted object detection
265
- - SAM (ViT-H): High-quality segmentation
266
- - Total Model Size: ~680MB
267
- """)
268
 
269
  if __name__ == "__main__":
270
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
import os
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import sam_model_registry, SamPredictor
from transformers import pipeline
import supervision as sv
 
12
+ # Download models on startup
13
+ print("Loading models...")
14
 
15
+ # Load Grounding DINO model from Hugging Face
16
+ # Using a different approach that doesn't require local config files
17
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
18
 
19
+ dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
20
+ dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")
21
 
22
+ # Load SAM model
23
+ sam_checkpoint = "sam_vit_h_4b8939.pth"
24
+ model_type = "vit_h"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # Download SAM weights if not present
27
+ if not os.path.exists(sam_checkpoint):
28
+ os.system(f"wget https://dl.fbaipublicfiles.com/segment_anything/{sam_checkpoint}")
29
+
30
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
31
+ sam_predictor = SamPredictor(sam)
32
+
33
+ print("Models loaded successfully!")
34
+
35
+ def detect_and_segment(image, text_prompt="fish", quality="Medium (512px)"):
36
  """
37
+ Detect objects using Grounding DINO and segment using SAM
 
 
 
 
 
 
 
 
 
 
38
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  try:
40
+ # Resize image based on quality setting
41
+ quality_map = {
42
+ "Low (256px)": 256,
43
+ "Medium (512px)": 512,
44
+ "High (1024px)": 1024
45
+ }
46
+ target_size = quality_map.get(quality, 512)
47
+
48
+ # Convert PIL to numpy
49
+ image_np = np.array(image)
50
+ h, w = image_np.shape[:2]
51
+
52
+ # Resize maintaining aspect ratio
53
+ scale = min(target_size / w, target_size / h)
54
+ new_w, new_h = int(w * scale), int(h * scale)
55
+ image_resized = cv2.resize(image_np, (new_w, new_h))
56
 
57
+ # Prepare image for Grounding DINO
58
+ inputs = dino_processor(images=image_resized, text=text_prompt, return_tensors="pt")
 
59
 
60
+ with torch.no_grad():
61
+ outputs = dino_model(**inputs)
62
 
63
+ # Post-process results
64
+ results = dino_processor.post_process_grounded_object_detection(
65
+ outputs,
66
+ inputs.input_ids,
67
+ box_threshold=0.25,
68
+ text_threshold=0.25,
69
+ target_sizes=[(new_h, new_w)]
70
+ )
71
+
72
+ if len(results) == 0 or len(results[0]["boxes"]) == 0:
73
+ return image, {"error": "No fish detected", "detections": 0}
74
+
75
+ # Get boxes and scores
76
+ boxes = results[0]["boxes"].cpu().numpy()
77
+ scores = results[0]["scores"].cpu().numpy()
78
+
79
+ # Use SAM to segment
80
+ sam_predictor.set_image(image_resized)
81
+
82
+ # Convert boxes to SAM format
83
+ masks = []
84
+ for box in boxes:
85
+ box_sam = np.array([box[0], box[1], box[2], box[3]])
86
+ mask, _, _ = sam_predictor.predict(box=box_sam, multimask_output=False)
87
+ masks.append(mask[0])
88
 
89
  # Create visualization
90
+ annotated_image = image_resized.copy()
91
 
92
+ # Draw masks
93
+ for mask in masks:
94
+ color_mask = np.zeros_like(annotated_image)
95
+ color_mask[mask] = [0, 255, 0] # Green mask
96
+ annotated_image = cv2.addWeighted(annotated_image, 1, color_mask, 0.5, 0)
97
 
98
+ # Draw bounding boxes
99
+ for box in boxes:
100
+ x1, y1, x2, y2 = map(int, box)
101
+ cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
102
 
103
+ # Calculate metadata
104
+ total_pixels = new_w * new_h
105
+ mask_pixels = sum(np.sum(mask) for mask in masks)
106
+ mask_percentage = (mask_pixels / total_pixels) * 100
107
 
108
+ metadata = {
109
+ "detections": len(boxes),
110
+ "avg_confidence": float(np.mean(scores)),
111
+ "image_size": f"{new_w}x{new_h}",
112
+ "mask_percentage": f"{mask_percentage:.2f}%"
113
+ }
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ return Image.fromarray(annotated_image), metadata
116
 
117
  except Exception as e:
118
+ return image, {"error": str(e)}
 
 
 
119
 
120
+ # Create Gradio interface
121
+ with gr.Blocks(title="Grounded SAM - Fish Detection") as demo:
122
+ gr.Markdown("# 🐟 Grounded SAM: Fish Detection & Segmentation")
123
+ gr.Markdown("Upload an image and detect fish using Grounding DINO + Segment Anything Model")
 
 
 
 
 
 
124
 
125
  with gr.Row():
126
  with gr.Column():
127
+ input_image = gr.Image(type="pil", label="Upload Image")
128
+ text_prompt = gr.Textbox(value="fish", label="Detection Prompt")
129
  quality = gr.Radio(
130
+ choices=["Low (256px)", "Medium (512px)", "High (1024px)"],
131
+ value="Medium (512px)",
132
+ label="Processing Quality"
 
133
  )
134
+ submit_btn = gr.Button("Process Image", variant="primary")
135
 
136
  with gr.Column():
137
+ output_image = gr.Image(label="Detection Result")
138
+ output_metadata = gr.JSON(label="Detection Metadata")
139
 
140
+ submit_btn.click(
141
+ fn=detect_and_segment,
142
+ inputs=[input_image, text_prompt, quality],
143
+ outputs=[output_image, output_metadata]
144
  )
145
 
146
+ gr.Examples(
147
+ examples=[
148
+ ["fish_angler.jpg", "fish", "High (1024px)"],
149
+ ],
150
+ inputs=[input_image, text_prompt, quality]
151
+ )
 
 
 
 
 
 
152
 
153
  if __name__ == "__main__":
154
+ demo.launch()