EnginDev commited on
Commit
86db70f
Β·
verified Β·
1 Parent(s): 9292061

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -392
app.py CHANGED
@@ -3,445 +3,268 @@ import torch
3
  import numpy as np
4
  from PIL import Image
5
  import cv2
 
 
 
6
 
7
- print("πŸš€ Starting SAM2 FishBoost Edition v4.0 - ULTRA OPTIMIZED...")
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  print(f"πŸ“± Using device: {device}")
11
 
12
- model = None
13
- processor = None
14
 
15
- def load_model():
16
- global model, processor
17
- if model is None:
 
 
 
 
 
 
 
 
 
 
 
18
  print("πŸ“¦ Loading SAM model...")
19
- try:
20
- from transformers import SamModel, SamProcessor
21
-
22
- model_name = "facebook/sam-vit-large"
23
-
24
- processor = SamProcessor.from_pretrained(model_name)
25
- model = SamModel.from_pretrained(model_name)
26
- model.to(device)
27
- print(f"βœ… Model loaded: {model_name}")
28
- except Exception as e:
29
- print(f"❌ Error: {e}, falling back to base model")
30
- model_name = "facebook/sam-vit-base"
31
- processor = SamProcessor.from_pretrained(model_name)
32
- model = SamModel.from_pretrained(model_name)
33
- model.to(device)
34
- return model, processor
35
 
36
- def prepare_image(image, max_size=1024):
37
- if isinstance(image, np.ndarray):
38
- image_pil = Image.fromarray(image)
39
- else:
40
- image_pil = image
41
 
42
- if image_pil.mode != 'RGB':
43
- image_pil = image_pil.convert('RGB')
 
 
 
44
 
45
- image_np = np.array(image_pil)
46
- h, w = image_np.shape[:2]
 
 
 
47
 
48
- if max(h, w) > max_size:
49
- scale = max_size / max(h, w)
50
- new_h, new_w = int(h * scale), int(w * scale)
51
- image_pil = image_pil.resize((new_w, new_h), Image.Resampling.LANCZOS)
52
- image_np = np.array(image_pil)
53
 
54
- return image_pil, image_np
55
-
56
- def refine_mask(mask, kernel_size=5):
57
- """GlΓ€ttet Maskenkanten"""
58
- mask_uint8 = (mask > 0).astype(np.uint8) * 255
59
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
60
- mask_closed = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel)
61
- mask_refined = cv2.morphologyEx(mask_closed, cv2.MORPH_OPEN, kernel)
62
- return mask_refined > 0
63
-
64
- def calculate_mask_center(mask):
65
- """Berechnet Schwerpunkt der Maske"""
66
- y_coords, x_coords = np.where(mask)
67
- if len(x_coords) == 0:
68
- return None, None
69
- return np.mean(x_coords), np.mean(y_coords)
70
-
71
- def extract_contours_from_mask(mask):
72
- """Extrahiert Konturen als [{x, y}, ...] Format"""
73
- contours, _ = cv2.findContours(
74
- mask.astype(np.uint8),
75
- cv2.RETR_EXTERNAL,
76
- cv2.CHAIN_APPROX_SIMPLE
77
  )
78
 
79
- if not contours:
80
- return []
 
 
 
 
 
 
 
 
 
81
 
82
- # Grâßte Kontur wÀhlen
83
- largest_contour = max(contours, key=cv2.contourArea)
 
 
84
 
85
- # Format konvertieren: [[x, y]] -> [{x: int, y: int}]
86
- points = []
87
- for point in largest_contour:
88
- x, y = point[0]
89
- points.append({"x": int(x), "y": int(y)})
90
 
91
- return points
92
-
93
- def generate_grid_points(w, h, grid_size=3):
94
- """Generiert Grid-Punkte ΓΌber das Bild verteilt"""
95
- points = []
96
- for i in range(1, grid_size + 1):
97
- for j in range(1, grid_size + 1):
98
- x = int(w * i / (grid_size + 1))
99
- y = int(h * j / (grid_size + 1))
100
- points.append([x, y])
101
- return points
102
-
103
- def select_best_fish_mask(all_masks, all_scores, image_shape):
104
- """
105
- 🎣 ULTRA-INTELLIGENTE FISCH-AUSWAHL
106
 
107
- Strategie:
108
- 1. Filtere sehr große Masken (>15% = Hintergrund/Angler)
109
- 2. Filtere sehr kleine Masken (<2% = Noise)
110
- 3. WΓ€hle KLEINSTE verbleibende Maske (= Fisch)
111
- """
112
- h, w = image_shape
113
- image_center_x, image_center_y = w / 2, h / 2
114
- total_pixels = h * w
115
 
116
- valid_masks = []
 
 
 
117
 
118
- print(f"\nπŸ” Analyzing {len(all_masks)} candidate masks...")
119
 
120
- for mask_data in all_masks:
121
- mask = mask_data['mask']
122
- score = mask_data['score']
123
- point = mask_data['point']
124
-
125
- # Coverage berechnen
126
- mask_area = np.sum(mask)
127
- coverage = mask_area / total_pixels
128
-
129
- # 🚫 FILTER 1: Zu groß (Hintergrund/Angler)
130
- if coverage > 0.15: # 15% Threshold (vorher 60%)
131
- print(f" ❌ Rejected: Coverage {coverage*100:.1f}% > 15% (Background)")
132
- continue
133
-
134
- # 🚫 FILTER 2: Zu klein (Noise)
135
- if coverage < 0.02: # 2% Minimum
136
- print(f" ❌ Rejected: Coverage {coverage*100:.1f}% < 2% (Noise)")
137
- continue
138
-
139
- # 🚫 FILTER 3: Schlechter Score
140
- if score < 0.7:
141
- print(f" ❌ Rejected: Score {score:.3f} < 0.7")
142
- continue
143
-
144
- # Center Distance berechnen
145
- center_x, center_y = calculate_mask_center(mask)
146
- if center_x is None:
147
- continue
148
-
149
- distance_to_center = np.sqrt(
150
- (center_x - image_center_x)**2 +
151
- (center_y - image_center_y)**2
152
- )
153
-
154
- valid_masks.append({
155
- 'mask': mask,
156
- 'score': score,
157
- 'area': mask_area,
158
- 'coverage': coverage,
159
- 'center': (center_x, center_y),
160
- 'distance_to_center': distance_to_center,
161
- 'point': point
162
- })
163
-
164
- print(f" βœ… Valid: coverage={coverage*100:.1f}%, score={score:.3f}, dist={distance_to_center:.0f}px")
165
 
166
- if not valid_masks:
167
- print(" ❌ No valid fish masks found!")
168
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- # 🎯 STRATEGIE: WÀhle KLEINSTE Maske (= Fisch, nicht Angler)
171
- valid_masks.sort(key=lambda m: m['coverage'])
172
- best_mask = valid_masks[0]
 
 
 
 
 
173
 
174
- print(f"\n πŸ† SELECTED: Smallest mask (coverage: {best_mask['coverage']*100:.1f}%)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- return best_mask
 
 
177
 
178
- def segment_automatic(image, quality="high", mode="fish"):
179
- """
180
- 🎣 ULTRA-OPTIMIZED Fish Detection
181
- - Multi-Point Grid (9 Punkte statt nur Mitte)
182
- - 15% Coverage Filter (statt 60%)
183
- - Kleinste Maske = Fisch
184
- """
185
  if image is None:
186
- return None, {"error": "Kein Bild hochgeladen"}
187
 
188
  try:
189
- print(f"\n{'='*60}")
190
- print(f"πŸ”„ Starting ULTRA segmentation (mode: {mode}, quality: {quality})")
191
- print(f"{'='*60}")
192
-
193
- model, processor = load_model()
194
-
195
- image_pil, image_np = prepare_image(image)
196
- h, w = image_np.shape[:2]
197
-
198
- if mode == "fish":
199
- # πŸ†• MULTI-POINT GRID (statt nur Bildmitte)
200
- grid_points = generate_grid_points(w, h, grid_size=3)
201
- print(f"πŸ“ Using {len(grid_points)} grid points for detection")
202
  else:
203
- # Fallback: nur Bildmitte
204
- grid_points = [[w // 2, h // 2]]
205
-
206
- all_masks = []
207
-
208
- # FΓΌr jeden Grid-Punkt: Maske generieren
209
- for idx, point in enumerate(grid_points):
210
- inputs = processor(
211
- image_pil,
212
- input_points=[[point]],
213
- input_labels=[[1]],
214
- return_tensors="pt"
215
- ).to(device)
216
-
217
- with torch.no_grad():
218
- outputs = model(**inputs, multimask_output=True)
219
-
220
- masks = processor.image_processor.post_process_masks(
221
- outputs.pred_masks.cpu(),
222
- inputs["original_sizes"].cpu(),
223
- inputs["reshaped_input_sizes"].cpu()
224
- )[0]
225
-
226
- scores = outputs.iou_scores.cpu().numpy().flatten()
227
-
228
- # Beste Maske dieses Punktes
229
- best_idx = np.argmax(scores)
230
- if masks.ndim == 4:
231
- mask = masks[0, best_idx].numpy()
232
- else:
233
- mask = masks[best_idx].numpy()
234
-
235
- all_masks.append({
236
- 'mask': mask > 0,
237
- 'score': scores[best_idx],
238
- 'point': point
239
- })
240
-
241
- print(f"βœ… Generated {len(all_masks)} masks from grid points")
242
-
243
- # 🎣 BESTE FISCH-MASKE WΓ„HLEN
244
- best_fish = select_best_fish_mask(all_masks, None, (h, w))
245
 
246
- if best_fish is None:
247
- return None, {
248
- "error": "No fish detected. Image might contain only background/angler.",
249
- "suggestion": "Try 'Multi-Object' mode or use a different image."
250
- }
251
 
252
- final_mask = best_fish['mask']
 
253
 
254
- # Refinement
255
- if quality == "high":
256
- print("🎨 Refining mask edges...")
257
- final_mask = refine_mask(final_mask, kernel_size=7)
258
 
259
- # πŸ†• KONTUREN EXTRAHIEREN
260
- contours_list = extract_contours_from_mask(final_mask)
261
 
262
- # Overlay erstellen
263
  overlay = image_np.copy()
264
- color = np.array([0, 255, 100]) # GrΓΌn fΓΌr Fisch
265
-
266
- mask_float = final_mask.astype(float)
267
- if quality == "high":
268
- mask_float = cv2.GaussianBlur(mask_float, (5, 5), 0)
269
-
270
- for c in range(3):
271
- overlay[:, :, c] = (
272
- overlay[:, :, c] * (1 - mask_float * 0.65) +
273
- color[c] * mask_float * 0.65
274
- )
275
-
276
- # Kontur zeichnen
277
- contours_cv, _ = cv2.findContours(
278
- final_mask.astype(np.uint8),
279
- cv2.RETR_EXTERNAL,
280
- cv2.CHAIN_APPROX_SIMPLE
281
- )
282
- cv2.drawContours(overlay, contours_cv, -1, (255, 255, 0), 3)
283
 
284
- # Metadata
285
- mask_area = int(np.sum(final_mask))
286
- mask_percentage = float(mask_area / (h * w) * 100)
287
 
288
- metadata = {
289
- "success": True,
290
- "mode": "automatic_fish_ultra_optimized",
291
- "quality": quality,
292
- "detection_method": "multi_point_grid" if mode == "fish" else "center_point",
293
- "grid_points_used": len(grid_points),
294
- "image_size": [w, h],
295
- "mask_area": mask_area,
296
- "mask_percentage": mask_percentage,
297
- "num_contours": len(contours_cv),
298
- "fish_score": float(best_fish['score']),
299
- "fish_center": [float(best_fish['center'][0]), float(best_fish['center'][1])],
300
- "device": device,
301
- "contours": contours_list
302
- }
303
 
304
- print(f"\n{'='*60}")
305
- print(f"βœ… SEGMENTATION COMPLETE!")
306
- print(f" Fish coverage: {mask_percentage:.1f}%")
307
- print(f" Confidence: {best_fish['score']*100:.1f}%")
308
- print(f" Contour points: {len(contours_list)}")
309
- print(f"{'='*60}\n")
 
 
 
 
 
 
 
 
 
 
 
310
 
311
- return Image.fromarray(overlay.astype(np.uint8)), metadata
312
 
313
  except Exception as e:
 
314
  import traceback
315
- print(f"❌ ERROR:\n{traceback.format_exc()}")
316
- return image, {"error": str(e)}
317
 
318
  # Gradio Interface
319
- demo = gr.Blocks(title="SAM2 FishBoost ULTRA", theme=gr.themes.Soft())
320
-
321
- with demo:
322
- gr.Markdown("# 🎣 SAM2 FishBoost ULTRA v4.0")
323
- gr.Markdown("### Multi-Point Grid Detection + 15% Coverage Filter")
324
 
325
- with gr.Tab("🎣 Fish Detection (ULTRA)"):
326
- gr.Markdown("""
327
- **πŸš€ NEUE FEATURES:**
328
- - βœ… 9-Punkt Grid Detection (nicht nur Bildmitte!)
329
- - βœ… 15% Coverage Filter (filtert Angler/Hintergrund)
330
- - βœ… Kleinste Maske = Fisch
331
- """)
332
-
333
- with gr.Row():
334
- with gr.Column():
335
- input_fish = gr.Image(type="pil", label="πŸ“Έ Bild hochladen")
336
-
337
- quality_radio = gr.Radio(
338
- choices=["high", "fast"],
339
- value="high",
340
- label="βš™οΈ QualitΓ€t"
341
- )
342
-
343
- mode_radio = gr.Radio(
344
- choices=["fish", "multi"],
345
- value="fish",
346
- label="🎯 Modus",
347
- info="Fish = Multi-Point Grid, Multi = Center Only"
348
- )
349
-
350
- btn_fish = gr.Button("🎣 Fisch segmentieren", variant="primary", size="lg")
351
-
352
- gr.Markdown("""
353
- **πŸ’‘ Wie es funktioniert:**
354
-
355
- **Fish Mode (ULTRA):**
356
- 1. Scannt Bild mit 9 Punkten (3x3 Grid)
357
- 2. Ignoriert große Objekte (>15% = Angler)
358
- 3. Ignoriert kleine Objekte (<2% = Noise)
359
- 4. WΓ€hlt kleinste Maske (= Fisch!)
360
-
361
- **Multi Mode:**
362
- - Alte Methode (nur Bildmitte)
363
- - FΓΌr allgemeine Objekte
364
- """)
365
-
366
- with gr.Column():
367
- output_fish = gr.Image(label="✨ Segmentierter Fisch")
368
- json_fish = gr.JSON(label="πŸ“Š Metadata")
369
-
370
- btn_fish.click(
371
- fn=segment_automatic,
372
- inputs=[input_fish, quality_radio, mode_radio],
373
- outputs=[output_fish, json_fish]
374
- )
375
-
376
- gr.Examples(
377
- examples=[],
378
- inputs=input_fish,
379
- label="πŸ’‘ Upload dein Angelfoto!"
380
- )
381
 
382
- with gr.Tab("πŸ“‘ API Integration (Lovable)"):
383
- gr.Markdown("### πŸ”— API Endpoint")
384
- gr.Code("https://EnginDev-Boostly.hf.space/api/predict", label="Base URL")
385
-
386
- gr.Markdown("### πŸ“ JavaScript Code")
387
- gr.Code('''
388
- // Fish Detection ULTRA
389
- const response = await fetch('https://EnginDev-Boostly.hf.space/api/predict', {
390
- method: 'POST',
391
- headers: {'Content-Type': 'application/json'},
392
- body: JSON.stringify({
393
- data: [
394
- base64Image, // Base64 image
395
- "high", // quality: "high" | "fast"
396
- "fish" // mode: "fish" (ULTRA) | "multi"
397
- ],
398
- fn_index: 0
399
- })
400
- });
401
-
402
- const result = await response.json();
403
-
404
- // Expected Response:
405
- {
406
- "data": [
407
- "data:image/png;base64,iVBORw...", // Segmented overlay
408
- {
409
- "success": true,
410
- "mode": "automatic_fish_ultra_optimized",
411
- "detection_method": "multi_point_grid",
412
- "grid_points_used": 9,
413
- "mask_percentage": 8.2, // Nur der Fisch! (nicht 86%)
414
- "fish_score": 0.98,
415
- "fish_center": [385, 520],
416
- "contours": [
417
- {"x": 350, "y": 450},
418
- {"x": 351, "y": 451},
419
- // ... prΓ€zise Fisch-Kontur
420
- ]
421
- }
422
- ]
423
- }
424
- ''', language="javascript")
425
-
426
- gr.Markdown("""
427
- ### βš™οΈ Parameter ErklΓ€rung
428
-
429
- **mode: "fish"** (ULTRA - EMPFOHLEN fΓΌr Angelfotos)
430
- - Multi-Point Grid (9 Erkennungspunkte)
431
- - 15% Coverage Filter
432
- - Kleinste Maske = Fisch
433
- - βœ… Perfekt fΓΌr: Angler mit Fisch im Bild
434
-
435
- **mode: "multi"**
436
- - Center-Point Only (alte Methode)
437
- - 60% Coverage Filter
438
- - βœ… FΓΌr allgemeine Objekte
439
 
440
- **quality:**
441
- - `"high"` = PrΓ€zise Kanten, Gaussian Blur (~20s)
442
- - `"fast"` = Schneller, weniger Nachbearbeitung (~10s)
443
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
  if __name__ == "__main__":
446
- print("🌐 Launching FishBoost SAM2 ULTRA v4.0...")
447
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
3
  import numpy as np
4
  from PIL import Image
5
  import cv2
6
+ from groundingdino.util.inference import Model as GroundingDINOModel
7
+ from segment_anything import sam_model_registry, SamPredictor
8
+ import supervision as sv
9
 
10
+ print("πŸš€ Starting Grounded SAM FishBoost Edition v5.0...")
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  print(f"πŸ“± Using device: {device}")
14
 
15
+ grounding_dino_model = None
16
+ sam_predictor = None
17
 
18
+ def load_models():
19
+ """Load Grounding DINO + SAM models"""
20
+ global grounding_dino_model, sam_predictor
21
+
22
+ if grounding_dino_model is None:
23
+ print("πŸ“¦ Loading Grounding DINO model...")
24
+ grounding_dino_model = GroundingDINOModel(
25
+ model_config_path="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
26
+ model_checkpoint_path="weights/groundingdino_swint_ogc.pth",
27
+ device=device
28
+ )
29
+ print("βœ… Grounding DINO loaded!")
30
+
31
+ if sam_predictor is None:
32
  print("πŸ“¦ Loading SAM model...")
33
+ sam_checkpoint = "weights/sam_vit_h_4b8939.pth"
34
+ model_type = "vit_h"
35
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
36
+ sam.to(device=device)
37
+ sam_predictor = SamPredictor(sam)
38
+ print("βœ… SAM loaded!")
 
 
 
 
 
 
 
 
 
 
39
 
40
+ def detect_fish_with_grounded_sam(image_pil, text_prompt="fish", box_threshold=0.25, text_threshold=0.25):
41
+ """
42
+ Detect and segment fish using Grounding DINO + SAM
 
 
43
 
44
+ Args:
45
+ image_pil: PIL Image
46
+ text_prompt: Text prompt for detection (default: "fish")
47
+ box_threshold: Confidence threshold for boxes
48
+ text_threshold: Confidence threshold for text matching
49
 
50
+ Returns:
51
+ mask: Binary mask of detected fish
52
+ metadata: Detection metadata
53
+ """
54
+ load_models()
55
 
56
+ # Convert PIL to numpy
57
+ image_np = np.array(image_pil)
 
 
 
58
 
59
+ # 1. Grounding DINO: Detect fish boxes
60
+ print(f"πŸ” Detecting '{text_prompt}' with Grounding DINO...")
61
+ detections = grounding_dino_model.predict_with_classes(
62
+ image=image_np,
63
+ classes=[text_prompt],
64
+ box_threshold=box_threshold,
65
+ text_threshold=text_threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  )
67
 
68
+ print(f"πŸ“¦ Found {len(detections.xyxy)} boxes")
69
+
70
+ if len(detections.xyxy) == 0:
71
+ print("❌ No fish detected!")
72
+ return None, {
73
+ "success": False,
74
+ "mode": "grounded_sam",
75
+ "detection_method": "grounding_dino",
76
+ "fish_detected": False,
77
+ "reason": "No fish found in image"
78
+ }
79
 
80
+ # Select best detection (highest confidence)
81
+ best_idx = np.argmax(detections.confidence)
82
+ best_box = detections.xyxy[best_idx]
83
+ best_conf = float(detections.confidence[best_idx])
84
 
85
+ print(f"🎯 Best detection: Confidence={best_conf:.2f}, Box={best_box}")
 
 
 
 
86
 
87
+ # 2. SAM: Segment the detected fish
88
+ print("βœ‚οΈ Segmenting with SAM...")
89
+ sam_predictor.set_image(image_np)
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ # Convert box to SAM format
92
+ box_np = best_box.reshape(1, 4)
 
 
 
 
 
 
93
 
94
+ masks, scores, _ = sam_predictor.predict(
95
+ box=box_np,
96
+ multimask_output=False
97
+ )
98
 
99
+ mask = masks[0] # Get best mask
100
 
101
+ # Calculate statistics
102
+ mask_area = int(np.sum(mask))
103
+ total_pixels = mask.shape[0] * mask.shape[1]
104
+ mask_percentage = (mask_area / total_pixels) * 100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ # Get contours
107
+ contours, _ = cv2.findContours(
108
+ mask.astype(np.uint8),
109
+ cv2.RETR_EXTERNAL,
110
+ cv2.CHAIN_APPROX_SIMPLE
111
+ )
112
+
113
+ # Get fish center
114
+ if len(contours) > 0:
115
+ largest_contour = max(contours, key=cv2.contourArea)
116
+ M = cv2.moments(largest_contour)
117
+ if M["m00"] != 0:
118
+ cx = int(M["m10"] / M["m00"])
119
+ cy = int(M["m01"] / M["m00"])
120
+ else:
121
+ cx, cy = int(best_box[0] + best_box[2]) // 2, int(best_box[1] + best_box[3]) // 2
122
+ else:
123
+ cx, cy = int(best_box[0] + best_box[2]) // 2, int(best_box[1] + best_box[3]) // 2
124
 
125
+ # Convert contours to list format
126
+ contour_points = []
127
+ if len(contours) > 0:
128
+ for point in contours[0][:100]: # Limit to 100 points
129
+ contour_points.append({
130
+ "x": int(point[0][0]),
131
+ "y": int(point[0][1])
132
+ })
133
 
134
+ metadata = {
135
+ "success": True,
136
+ "mode": "grounded_sam",
137
+ "detection_method": "grounding_dino_sam",
138
+ "fish_detected": True,
139
+ "grounding_dino": {
140
+ "confidence": best_conf,
141
+ "bounding_box": [int(x) for x in best_box],
142
+ "text_prompt": text_prompt,
143
+ "total_detections": len(detections.xyxy)
144
+ },
145
+ "mask_area": mask_area,
146
+ "mask_percentage": mask_percentage,
147
+ "num_contours": len(contours),
148
+ "fish_center": [cx, cy],
149
+ "image_size": list(mask.shape),
150
+ "device": device,
151
+ "contours": contour_points
152
+ }
153
 
154
+ print(f"βœ… Segmentation complete! Mask: {mask_percentage:.2f}%")
155
+
156
+ return mask, metadata
157
 
158
+ def process_image(image, quality="high"):
159
+ """Main processing function for Gradio interface"""
160
+
 
 
 
 
161
  if image is None:
162
+ return None, "❌ No image provided"
163
 
164
  try:
165
+ # Convert to PIL if needed
166
+ if isinstance(image, np.ndarray):
167
+ image_pil = Image.fromarray(image)
 
 
 
 
 
 
 
 
 
 
168
  else:
169
+ image_pil = image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
+ # Resize for faster processing on CPU
172
+ max_size = 1024 if quality == "high" else 768
173
+ image_pil.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
 
 
174
 
175
+ # Detect and segment fish
176
+ mask, metadata = detect_fish_with_grounded_sam(image_pil, text_prompt="fish")
177
 
178
+ if mask is None:
179
+ return None, f"❌ No fish detected!\n\n{metadata}"
 
 
180
 
181
+ # Create visualization
182
+ image_np = np.array(image_pil)
183
 
184
+ # Apply green overlay on fish
185
  overlay = image_np.copy()
186
+ overlay[mask] = [0, 255, 0] # Green
187
+ result = cv2.addWeighted(image_np, 0.7, overlay, 0.3, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
+ # Draw bounding box
190
+ box = metadata["grounding_dino"]["bounding_box"]
191
+ cv2.rectangle(result, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
192
 
193
+ # Add confidence text
194
+ conf_text = f"Fish: {metadata['grounding_dino']['confidence']:.2f}"
195
+ cv2.putText(result, conf_text, (box[0], box[1] - 10),
196
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
 
 
 
 
 
 
 
 
 
 
 
197
 
198
+ # Format metadata for display
199
+ meta_str = f"""βœ… Fish detected successfully!
200
+
201
+ 🎯 Grounding DINO
202
+ Confidence: {metadata['grounding_dino']['confidence']:.2%}
203
+ Bounding Box: {metadata['grounding_dino']['bounding_box']}
204
+ Detections: {metadata['grounding_dino']['total_detections']}
205
+
206
+ βœ‚οΈ SAM Segmentation
207
+ Mask Area: {metadata['mask_percentage']:.2f}%
208
+ Fish Center: {metadata['fish_center']}
209
+ Contours: {metadata['num_contours']}
210
+
211
+ βš™οΈ System
212
+ Device: {metadata['device']}
213
+ Image Size: {metadata['image_size']}
214
+ """
215
 
216
+ return result, meta_str
217
 
218
  except Exception as e:
219
+ print(f"❌ Error: {str(e)}")
220
  import traceback
221
+ traceback.print_exc()
222
+ return None, f"❌ Error: {str(e)}"
223
 
224
  # Gradio Interface
225
+ with gr.Blocks(title="🎣 FishBoost - Grounded SAM Edition") as demo:
226
+ gr.Markdown("""
227
+ # 🎣 FishBoost - Grounded SAM Fish Detector
228
+ ### Powered by Grounding DINO + SAM
 
229
 
230
+ Upload an image with a fish and watch the AI detect and segment it!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
+ ⚠️ **CPU Mode**: First run downloads ~680MB models (2-3 min). Processing: ~30-60 sec per image.
233
+ """)
234
+
235
+ with gr.Row():
236
+ with gr.Column():
237
+ input_image = gr.Image(type="pil", label="πŸ“€ Upload Fish Image")
238
+ quality = gr.Radio(
239
+ choices=["high", "medium"],
240
+ value="high",
241
+ label="🎨 Quality",
242
+ info="High = 1024px, Medium = 768px (faster)"
243
+ )
244
+ process_btn = gr.Button("πŸš€ Detect Fish", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
+ with gr.Column():
247
+ output_image = gr.Image(label="🎯 Detected Fish (Green = Mask, Blue = Box)")
248
+ output_meta = gr.Textbox(label="πŸ“Š Detection Metadata", lines=15)
249
+
250
+ process_btn.click(
251
+ fn=process_image,
252
+ inputs=[input_image, quality],
253
+ outputs=[output_image, output_meta]
254
+ )
255
+
256
+ gr.Markdown("""
257
+ ---
258
+ ### πŸ”§ How it works
259
+ 1. **Grounding DINO** finds fish bounding boxes using text prompt "fish"
260
+ 2. **SAM** segments the exact fish shape within the box
261
+ 3. **Result**: Precise fish mask ignoring angler/background
262
+
263
+ ### πŸ“ Model Info
264
+ - Grounding DINO: Text-prompted object detection
265
+ - SAM (ViT-H): High-quality segmentation
266
+ - Total Model Size: ~680MB
267
+ """)
268
 
269
  if __name__ == "__main__":
270
+ demo.launch(server_name="0.0.0.0", server_port=7860)