EnginDev committed on
Commit
7a701fd
Β·
verified Β·
1 Parent(s): 4d6a7f4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -0
app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Third-party dependencies: Gradio (UI), PyTorch, Pillow/OpenCV (imaging),
# and the Grounding DINO + SAM detection/segmentation stacks.
import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from groundingdino.util.inference import Model as GroundingDINOModel
from segment_anything import sam_model_registry, SamPredictor
import supervision as sv

print("🚀 Starting Grounded SAM FishBoost Edition v5.0...")

# Prefer GPU when available; every model below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"📱 Using device: {device}")

# Lazily-populated model caches — filled on first request by load_models()
# so the app starts fast and only pays the ~680MB load cost when used.
grounding_dino_model = None
sam_predictor = None
18
def load_models():
    """Lazily initialize the Grounding DINO detector and the SAM predictor.

    Populates the module-level caches ``grounding_dino_model`` and
    ``sam_predictor`` on first call; later calls are cheap no-ops.
    """
    global grounding_dino_model, sam_predictor

    if grounding_dino_model is None:
        print("📦 Loading Grounding DINO model...")
        dino_kwargs = {
            "model_config_path": "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
            "model_checkpoint_path": "weights/groundingdino_swint_ogc.pth",
            "device": device,
        }
        grounding_dino_model = GroundingDINOModel(**dino_kwargs)
        print("✅ Grounding DINO loaded!")

    if sam_predictor is None:
        print("📦 Loading SAM model...")
        # ViT-H is the largest, highest-quality SAM backbone.
        sam_model = sam_model_registry["vit_h"](
            checkpoint="weights/sam_vit_h_4b8939.pth"
        )
        sam_model.to(device=device)
        sam_predictor = SamPredictor(sam_model)
        print("✅ SAM loaded!")
40
def detect_fish_with_grounded_sam(image_pil, text_prompt="fish", box_threshold=0.25, text_threshold=0.25):
    """
    Detect and segment fish using Grounding DINO + SAM.

    Args:
        image_pil: PIL Image to analyze.
        text_prompt: Text prompt for detection (default: "fish").
        box_threshold: Confidence threshold for boxes.
        text_threshold: Confidence threshold for text matching.

    Returns:
        (mask, metadata): a boolean fish mask and a metadata dict, or
        (None, metadata) with ``success=False`` when nothing is detected.
    """
    load_models()

    # Convert PIL to numpy (HxWx3, RGB) — both models consume arrays.
    image_np = np.array(image_pil)

    # 1. Grounding DINO: text-prompted detection of candidate fish boxes.
    print(f"🔍 Detecting '{text_prompt}' with Grounding DINO...")
    detections = grounding_dino_model.predict_with_classes(
        image=image_np,
        classes=[text_prompt],
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )

    print(f"📦 Found {len(detections.xyxy)} boxes")

    if len(detections.xyxy) == 0:
        print("❌ No fish detected!")
        return None, {
            "success": False,
            "mode": "grounded_sam",
            "detection_method": "grounding_dino",
            "fish_detected": False,
            "reason": "No fish found in image"
        }

    # Keep only the highest-confidence detection.
    best_idx = np.argmax(detections.confidence)
    best_box = detections.xyxy[best_idx]
    best_conf = float(detections.confidence[best_idx])

    print(f"🎯 Best detection: Confidence={best_conf:.2f}, Box={best_box}")

    # 2. SAM: segment the exact fish shape inside the best box.
    print("✂️ Segmenting with SAM...")
    sam_predictor.set_image(image_np)

    # SAM expects the box prompt as a (1, 4) xyxy array.
    box_np = best_box.reshape(1, 4)

    masks, _, _ = sam_predictor.predict(
        box=box_np,
        multimask_output=False
    )

    mask = masks[0]  # single-mask output → take the only mask

    # Mask statistics for the metadata payload.
    mask_area = int(np.sum(mask))
    total_pixels = mask.shape[0] * mask.shape[1]
    mask_percentage = (mask_area / total_pixels) * 100

    contours, _ = cv2.findContours(
        mask.astype(np.uint8),
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE
    )

    # Fallback center: midpoint of the detection box (computed once instead
    # of being duplicated across branches).
    box_cx = int(best_box[0] + best_box[2]) // 2
    box_cy = int(best_box[1] + best_box[3]) // 2

    # Fish center from the centroid of the largest contour, when available.
    largest_contour = None
    if len(contours) > 0:
        largest_contour = max(contours, key=cv2.contourArea)
        M = cv2.moments(largest_contour)
        if M["m00"] != 0:
            cx = int(M["m10"] / M["m00"])
            cy = int(M["m01"] / M["m00"])
        else:
            # Degenerate (zero-area) contour — fall back to the box center.
            cx, cy = box_cx, box_cy
    else:
        cx, cy = box_cx, box_cy

    # BUGFIX: sample outline points from the LARGEST contour (previously
    # contours[0], an arbitrary first contour that could be a tiny stray
    # blob rather than the fish outline the centroid was computed from).
    contour_points = []
    if largest_contour is not None:
        for point in largest_contour[:100]:  # Limit to 100 points
            contour_points.append({
                "x": int(point[0][0]),
                "y": int(point[0][1])
            })

    metadata = {
        "success": True,
        "mode": "grounded_sam",
        "detection_method": "grounding_dino_sam",
        "fish_detected": True,
        "grounding_dino": {
            "confidence": best_conf,
            "bounding_box": [int(x) for x in best_box],
            "text_prompt": text_prompt,
            "total_detections": len(detections.xyxy)
        },
        "mask_area": mask_area,
        "mask_percentage": mask_percentage,
        "num_contours": len(contours),
        "fish_center": [cx, cy],
        "image_size": list(mask.shape),
        "device": device,
        "contours": contour_points
    }

    print(f"✅ Segmentation complete! Mask: {mask_percentage:.2f}%")

    return mask, metadata
157
+
158
def process_image(image, quality="high"):
    """Gradio entry point: detect and segment a fish, then render an overlay.

    Args:
        image: PIL Image or numpy array from the Gradio widget (may be None).
        quality: "high" (1024px max side) or "medium" (768px, faster on CPU).

    Returns:
        (annotated numpy image, metadata string) on success, or
        (None, error message) on failure. Exceptions are caught and reported
        as a string so the UI never crashes.
    """

    if image is None:
        return None, "❌ No image provided"

    try:
        # Normalize input to a PIL image.
        if isinstance(image, np.ndarray):
            image_pil = Image.fromarray(image)
        else:
            image_pil = image
        # BUGFIX: convert("RGB") returns a fresh 3-channel copy, so
        # (a) thumbnail() below no longer mutates the caller's image, and
        # (b) RGBA/grayscale uploads no longer break the 3-value mask
        # assignment in the overlay step.
        image_pil = image_pil.convert("RGB")

        # Downscale for faster processing on CPU.
        max_size = 1024 if quality == "high" else 768
        image_pil.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        # Detect and segment the fish.
        mask, metadata = detect_fish_with_grounded_sam(image_pil, text_prompt="fish")

        if mask is None:
            return None, f"❌ No fish detected!\n\n{metadata}"

        image_np = np.array(image_pil)

        # Blend a green tint over masked pixels (mask is a boolean array).
        overlay = image_np.copy()
        overlay[mask] = [0, 255, 0]  # Green
        result = cv2.addWeighted(image_np, 0.7, overlay, 0.3, 0)

        # Draw the Grounding DINO bounding box.
        box = metadata["grounding_dino"]["bounding_box"]
        cv2.rectangle(result, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)

        # Confidence label just above the box.
        conf_text = f"Fish: {metadata['grounding_dino']['confidence']:.2f}"
        cv2.putText(result, conf_text, (box[0], box[1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)

        # Human-readable summary for the metadata textbox.
        meta_str = f"""✅ Fish detected successfully!

🎯 Grounding DINO
Confidence: {metadata['grounding_dino']['confidence']:.2%}
Bounding Box: {metadata['grounding_dino']['bounding_box']}
Detections: {metadata['grounding_dino']['total_detections']}

✂️ SAM Segmentation
Mask Area: {metadata['mask_percentage']:.2f}%
Fish Center: {metadata['fish_center']}
Contours: {metadata['num_contours']}

⚙️ System
Device: {metadata['device']}
Image Size: {metadata['image_size']}
"""

        return result, meta_str

    except Exception as e:
        print(f"❌ Error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
223
+
224
# ---------------------------------------------------------------------------
# Gradio Interface: two-column layout — upload + controls on the left,
# annotated result and metadata on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(title="🎣 FishBoost - Grounded SAM Edition") as demo:
    gr.Markdown("""
    # 🎣 FishBoost - Grounded SAM Fish Detector
    ### Powered by Grounding DINO + SAM

    Upload an image with a fish and watch the AI detect and segment it!

    ⚠️ **CPU Mode**: First run downloads ~680MB models (2-3 min). Processing: ~30-60 sec per image.
    """)

    with gr.Row():
        with gr.Column():
            # Input side: image upload, quality selector, and trigger button.
            input_image = gr.Image(type="pil", label="📤 Upload Fish Image")
            quality = gr.Radio(
                choices=["high", "medium"],
                value="high",
                label="🎨 Quality",
                info="High = 1024px, Medium = 768px (faster)"
            )
            process_btn = gr.Button("🚀 Detect Fish", variant="primary")

        with gr.Column():
            # Output side: annotated image and the metadata summary.
            output_image = gr.Image(label="🎯 Detected Fish (Green = Mask, Blue = Box)")
            output_meta = gr.Textbox(label="📊 Detection Metadata", lines=15)

    # Wire the button to the detection pipeline.
    process_btn.click(
        fn=process_image,
        inputs=[input_image, quality],
        outputs=[output_image, output_meta]
    )

    gr.Markdown("""
    ---
    ### 🔧 How it works
    1. **Grounding DINO** finds fish bounding boxes using text prompt "fish"
    2. **SAM** segments the exact fish shape within the box
    3. **Result**: Precise fish mask ignoring angler/background

    ### 📏 Model Info
    - Grounding DINO: Text-prompted object detection
    - SAM (ViT-H): High-quality segmentation
    - Total Model Size: ~680MB
    """)

if __name__ == "__main__":
    # Bind to all interfaces on the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)