EnginDev commited on
Commit
84fdbb0
·
verified ·
1 Parent(s): 86db70f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +416 -224
app.py CHANGED
@@ -3,268 +3,460 @@ import torch
3
  import numpy as np
4
  from PIL import Image
5
  import cv2
6
- from groundingdino.util.inference import Model as GroundingDINOModel
7
- from segment_anything import sam_model_registry, SamPredictor
8
- import supervision as sv
9
 
10
- print("🚀 Starting Grounded SAM FishBoost Edition v5.0...")
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  print(f"📱 Using device: {device}")
14
 
15
- grounding_dino_model = None
16
- sam_predictor = None
17
 
18
- def load_models():
19
- """Load Grounding DINO + SAM models"""
20
- global grounding_dino_model, sam_predictor
21
-
22
- if grounding_dino_model is None:
23
- print("📦 Loading Grounding DINO model...")
24
- grounding_dino_model = GroundingDINOModel(
25
- model_config_path="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
26
- model_checkpoint_path="weights/groundingdino_swint_ogc.pth",
27
- device=device
28
- )
29
- print("✅ Grounding DINO loaded!")
30
-
31
- if sam_predictor is None:
32
  print("📦 Loading SAM model...")
33
- sam_checkpoint = "weights/sam_vit_h_4b8939.pth"
34
- model_type = "vit_h"
35
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
36
- sam.to(device=device)
37
- sam_predictor = SamPredictor(sam)
38
- print("✅ SAM loaded!")
 
 
 
 
 
 
 
 
 
 
39
 
40
- def detect_fish_with_grounded_sam(image_pil, text_prompt="fish", box_threshold=0.25, text_threshold=0.25):
41
- """
42
- Detect and segment fish using Grounding DINO + SAM
 
 
43
 
44
- Args:
45
- image_pil: PIL Image
46
- text_prompt: Text prompt for detection (default: "fish")
47
- box_threshold: Confidence threshold for boxes
48
- text_threshold: Confidence threshold for text matching
49
 
50
- Returns:
51
- mask: Binary mask of detected fish
52
- metadata: Detection metadata
53
- """
54
- load_models()
55
-
56
- # Convert PIL to numpy
57
  image_np = np.array(image_pil)
 
58
 
59
- # 1. Grounding DINO: Detect fish boxes
60
- print(f"🔍 Detecting '{text_prompt}' with Grounding DINO...")
61
- detections = grounding_dino_model.predict_with_classes(
62
- image=image_np,
63
- classes=[text_prompt],
64
- box_threshold=box_threshold,
65
- text_threshold=text_threshold
66
- )
67
-
68
- print(f"📦 Found {len(detections.xyxy)} boxes")
69
-
70
- if len(detections.xyxy) == 0:
71
- print("❌ No fish detected!")
72
- return None, {
73
- "success": False,
74
- "mode": "grounded_sam",
75
- "detection_method": "grounding_dino",
76
- "fish_detected": False,
77
- "reason": "No fish found in image"
78
- }
79
-
80
- # Select best detection (highest confidence)
81
- best_idx = np.argmax(detections.confidence)
82
- best_box = detections.xyxy[best_idx]
83
- best_conf = float(detections.confidence[best_idx])
84
-
85
- print(f"🎯 Best detection: Confidence={best_conf:.2f}, Box={best_box}")
86
-
87
- # 2. SAM: Segment the detected fish
88
- print("✂️ Segmenting with SAM...")
89
- sam_predictor.set_image(image_np)
90
-
91
- # Convert box to SAM format
92
- box_np = best_box.reshape(1, 4)
93
-
94
- masks, scores, _ = sam_predictor.predict(
95
- box=box_np,
96
- multimask_output=False
97
- )
98
-
99
- mask = masks[0] # Get best mask
100
-
101
- # Calculate statistics
102
- mask_area = int(np.sum(mask))
103
- total_pixels = mask.shape[0] * mask.shape[1]
104
- mask_percentage = (mask_area / total_pixels) * 100
105
 
106
- # Get contours
107
- contours, _ = cv2.findContours(
108
- mask.astype(np.uint8),
109
- cv2.RETR_EXTERNAL,
110
- cv2.CHAIN_APPROX_SIMPLE
111
- )
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- # Get fish center
114
- if len(contours) > 0:
115
- largest_contour = max(contours, key=cv2.contourArea)
116
- M = cv2.moments(largest_contour)
117
- if M["m00"] != 0:
118
- cx = int(M["m10"] / M["m00"])
119
- cy = int(M["m01"] / M["m00"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  else:
121
- cx, cy = int(best_box[0] + best_box[2]) // 2, int(best_box[1] + best_box[3]) // 2
122
- else:
123
- cx, cy = int(best_box[0] + best_box[2]) // 2, int(best_box[1] + best_box[3]) // 2
124
-
125
- # Convert contours to list format
126
- contour_points = []
127
- if len(contours) > 0:
128
- for point in contours[0][:100]: # Limit to 100 points
129
- contour_points.append({
130
- "x": int(point[0][0]),
131
- "y": int(point[0][1])
132
- })
133
-
134
- metadata = {
135
- "success": True,
136
- "mode": "grounded_sam",
137
- "detection_method": "grounding_dino_sam",
138
- "fish_detected": True,
139
- "grounding_dino": {
140
- "confidence": best_conf,
141
- "bounding_box": [int(x) for x in best_box],
142
- "text_prompt": text_prompt,
143
- "total_detections": len(detections.xyxy)
144
- },
145
- "mask_area": mask_area,
146
- "mask_percentage": mask_percentage,
147
- "num_contours": len(contours),
148
- "fish_center": [cx, cy],
149
- "image_size": list(mask.shape),
150
- "device": device,
151
- "contours": contour_points
152
- }
153
-
154
- print(f"✅ Segmentation complete! Mask: {mask_percentage:.2f}%")
155
-
156
- return mask, metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- def process_image(image, quality="high"):
159
- """Main processing function for Gradio interface"""
160
-
161
  if image is None:
162
- return None, " No image provided"
163
 
164
  try:
165
- # Convert to PIL if needed
166
- if isinstance(image, np.ndarray):
167
- image_pil = Image.fromarray(image)
 
 
 
 
 
 
 
168
  else:
169
- image_pil = image
170
 
171
- # Resize for faster processing on CPU
172
- max_size = 1024 if quality == "high" else 768
173
- image_pil.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
 
 
 
 
174
 
175
- # Detect and segment fish
176
- mask, metadata = detect_fish_with_grounded_sam(image_pil, text_prompt="fish")
177
 
178
- if mask is None:
179
- return None, f"❌ No fish detected!\n\n{metadata}"
180
 
181
- # Create visualization
182
- image_np = np.array(image_pil)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
- # Apply green overlay on fish
185
  overlay = image_np.copy()
186
- overlay[mask] = [0, 255, 0] # Green
187
- result = cv2.addWeighted(image_np, 0.7, overlay, 0.3, 0)
188
 
189
- # Draw bounding box
190
- box = metadata["grounding_dino"]["bounding_box"]
191
- cv2.rectangle(result, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
 
 
 
 
192
 
193
- # Add confidence text
194
- conf_text = f"Fish: {metadata['grounding_dino']['confidence']:.2f}"
195
- cv2.putText(result, conf_text, (box[0], box[1] - 10),
196
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
 
 
 
 
 
 
 
 
 
 
 
197
 
198
- # Format metadata for display
199
- meta_str = f"""✅ Fish detected successfully!
200
-
201
- 🎯 Grounding DINO
202
- Confidence: {metadata['grounding_dino']['confidence']:.2%}
203
- Bounding Box: {metadata['grounding_dino']['bounding_box']}
204
- Detections: {metadata['grounding_dino']['total_detections']}
205
-
206
- ✂️ SAM Segmentation
207
- Mask Area: {metadata['mask_percentage']:.2f}%
208
- Fish Center: {metadata['fish_center']}
209
- Contours: {metadata['num_contours']}
210
-
211
- ⚙️ System
212
- Device: {metadata['device']}
213
- Image Size: {metadata['image_size']}
214
- """
215
 
216
- return result, meta_str
 
217
 
218
  except Exception as e:
219
- print(f"❌ Error: {str(e)}")
220
  import traceback
221
- traceback.print_exc()
222
- return None, f"❌ Error: {str(e)}"
223
 
224
  # Gradio Interface
225
- with gr.Blocks(title="🎣 FishBoost - Grounded SAM Edition") as demo:
226
- gr.Markdown("""
227
- # 🎣 FishBoost - Grounded SAM Fish Detector
228
- ### Powered by Grounding DINO + SAM
229
-
230
- Upload an image with a fish and watch the AI detect and segment it!
231
-
232
- ⚠️ **CPU Mode**: First run downloads ~680MB models (2-3 min). Processing: ~30-60 sec per image.
233
- """)
234
 
235
- with gr.Row():
236
- with gr.Column():
237
- input_image = gr.Image(type="pil", label="📤 Upload Fish Image")
238
- quality = gr.Radio(
239
- choices=["high", "medium"],
240
- value="high",
241
- label="🎨 Quality",
242
- info="High = 1024px, Medium = 768px (faster)"
243
- )
244
- process_btn = gr.Button("🚀 Detect Fish", variant="primary")
245
 
246
- with gr.Column():
247
- output_image = gr.Image(label="🎯 Detected Fish (Green = Mask, Blue = Box)")
248
- output_meta = gr.Textbox(label="📊 Detection Metadata", lines=15)
249
-
250
- process_btn.click(
251
- fn=process_image,
252
- inputs=[input_image, quality],
253
- outputs=[output_image, output_meta]
254
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
- gr.Markdown("""
257
- ---
258
- ### 🔧 How it works
259
- 1. **Grounding DINO** finds fish bounding boxes using text prompt "fish"
260
- 2. **SAM** segments the exact fish shape within the box
261
- 3. **Result**: Precise fish mask ignoring angler/background
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
- ### 📝 Model Info
264
- - Grounding DINO: Text-prompted object detection
265
- - SAM (ViT-H): High-quality segmentation
266
- - Total Model Size: ~680MB
267
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  if __name__ == "__main__":
270
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
3
  import numpy as np
4
  from PIL import Image
5
  import cv2
 
 
 
6
 
7
+ print("🚀 Starting SAM2 App v2.1 - OPTIMIZED...")
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  print(f"📱 Using device: {device}")
11
 
12
+ model = None
13
+ processor = None
14
 
15
+ def load_model():
16
+ global model, processor
17
+ if model is None:
 
 
 
 
 
 
 
 
 
 
 
18
  print("📦 Loading SAM model...")
19
+ try:
20
+ from transformers import SamModel, SamProcessor
21
+
22
+ model_name = "facebook/sam-vit-large"
23
+
24
+ processor = SamProcessor.from_pretrained(model_name)
25
+ model = SamModel.from_pretrained(model_name)
26
+ model.to(device)
27
+ print(f"✅ Model loaded: {model_name}")
28
+ except Exception as e:
29
+ print(f"❌ Error: {e}, falling back to base model")
30
+ model_name = "facebook/sam-vit-base"
31
+ processor = SamProcessor.from_pretrained(model_name)
32
+ model = SamModel.from_pretrained(model_name)
33
+ model.to(device)
34
+ return model, processor
35
 
36
+ def prepare_image(image, max_size=1024):
37
+ if isinstance(image, np.ndarray):
38
+ image_pil = Image.fromarray(image)
39
+ else:
40
+ image_pil = image
41
 
42
+ if image_pil.mode != 'RGB':
43
+ image_pil = image_pil.convert('RGB')
 
 
 
44
 
 
 
 
 
 
 
 
45
  image_np = np.array(image_pil)
46
+ h, w = image_np.shape[:2]
47
 
48
+ if max(h, w) > max_size:
49
+ scale = max_size / max(h, w)
50
+ new_h, new_w = int(h * scale), int(w * scale)
51
+ image_pil = image_pil.resize((new_w, new_h), Image.Resampling.LANCZOS)
52
+ image_np = np.array(image_pil)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ return image_pil, image_np
55
+
56
+ def refine_mask(mask, kernel_size=5):
57
+ """Glättet Maskenkanten"""
58
+ mask_uint8 = (mask > 0).astype(np.uint8) * 255
59
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
60
+ mask_closed = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel)
61
+ mask_refined = cv2.morphologyEx(mask_closed, cv2.MORPH_OPEN, kernel)
62
+ return mask_refined > 0
63
+
64
+ def segment_automatic(image, quality="high", merge_parts=True):
65
+ """
66
+ OPTIMIERTE Automatische Segmentierung
67
+ Schnell & präzise - kombiniert mehrere Masken
68
+ """
69
+ if image is None:
70
+ return None, {"error": "Kein Bild hochgeladen"}
71
 
72
+ try:
73
+ print(f"🔄 Starting segmentation (quality: {quality}, merge: {merge_parts})...")
74
+ model, processor = load_model()
75
+
76
+ image_pil, image_np = prepare_image(image)
77
+ h, w = image_np.shape[:2]
78
+
79
+ center_x, center_y = w // 2, h // 2
80
+
81
+ # Single point inference mit multimask_output
82
+ inputs = processor(
83
+ image_pil,
84
+ input_points=[[[center_x, center_y]]],
85
+ input_labels=[[1]],
86
+ return_tensors="pt"
87
+ ).to(device)
88
+
89
+ print("🧠 Running inference...")
90
+ with torch.no_grad():
91
+ outputs = model(**inputs, multimask_output=True)
92
+
93
+ masks = processor.image_processor.post_process_masks(
94
+ outputs.pred_masks.cpu(),
95
+ inputs["original_sizes"].cpu(),
96
+ inputs["reshaped_input_sizes"].cpu()
97
+ )[0]
98
+
99
+ scores = outputs.iou_scores.cpu().numpy()
100
+ if scores.ndim > 1:
101
+ scores = scores.flatten()
102
+
103
+ print(f"✅ Got {len(scores)} masks with scores: {scores}")
104
+
105
+ # SMART MERGING: Kombiniere alle guten Masken
106
+ if merge_parts:
107
+ combined_mask = np.zeros((h, w), dtype=bool)
108
+ masks_used = 0
109
+
110
+ for idx, score in enumerate(scores):
111
+ if score > 0.5: # Nur Masken mit gutem Score
112
+ if masks.ndim == 4:
113
+ mask = masks[0, idx].numpy()
114
+ else:
115
+ mask = masks[idx].numpy()
116
+
117
+ # OR-Kombination (super schnell!)
118
+ combined_mask = combined_mask | (mask > 0)
119
+ masks_used += 1
120
+ print(f" ✅ Added mask {idx} (score: {score:.3f})")
121
+
122
+ final_mask = combined_mask
123
+ print(f"🔗 Combined {masks_used} masks into one!")
124
  else:
125
+ # Nur beste Maske
126
+ best_idx = np.argmax(scores)
127
+ if masks.ndim == 4:
128
+ final_mask = masks[0, best_idx].numpy() > 0
129
+ else:
130
+ final_mask = masks[best_idx].numpy() > 0
131
+ masks_used = 1
132
+ print(f"✅ Using best mask (score: {scores[best_idx]:.3f})")
133
+
134
+ # Refinement für glatte Kanten
135
+ if quality == "high":
136
+ print("🎨 Refining mask...")
137
+ final_mask = refine_mask(final_mask, kernel_size=7)
138
+
139
+ # Overlay erstellen
140
+ overlay = image_np.copy()
141
+ color = np.array([255, 80, 180]) # Rosa/Pink
142
+
143
+ mask_float = final_mask.astype(float)
144
+ if quality == "high":
145
+ mask_float = cv2.GaussianBlur(mask_float, (5, 5), 0)
146
+
147
+ # Farbiges Overlay
148
+ for c in range(3):
149
+ overlay[:, :, c] = (
150
+ overlay[:, :, c] * (1 - mask_float * 0.65) +
151
+ color[c] * mask_float * 0.65
152
+ )
153
+
154
+ # Gelbe Kontur zeichnen
155
+ contours, _ = cv2.findContours(
156
+ final_mask.astype(np.uint8),
157
+ cv2.RETR_EXTERNAL,
158
+ cv2.CHAIN_APPROX_SIMPLE
159
+ )
160
+ cv2.drawContours(overlay, contours, -1, (255, 255, 0), 3)
161
+
162
+ metadata = {
163
+ "success": True,
164
+ "mode": "automatic_plus" if merge_parts else "automatic",
165
+ "quality": quality,
166
+ "masks_combined": masks_used,
167
+ "all_scores": scores.tolist(),
168
+ "image_size": [w, h],
169
+ "mask_area": int(np.sum(final_mask)),
170
+ "mask_percentage": float(np.sum(final_mask) / (h * w) * 100),
171
+ "num_contours": len(contours),
172
+ "device": device
173
+ }
174
+
175
+ print("✅ Segmentation complete!")
176
+ return Image.fromarray(overlay.astype(np.uint8)), metadata
177
+
178
+ except Exception as e:
179
+ import traceback
180
+ print(f"❌ ERROR:\n{traceback.format_exc()}")
181
+ return image, {"error": str(e)}
182
 
183
+ def segment_multi_dense(image, density="medium"):
184
+ """Multi-Object Segmentierung mit Grid"""
 
185
  if image is None:
186
+ return None, {"error": "Kein Bild"}
187
 
188
  try:
189
+ print(f"🎯 Starting multi-region segmentation (density: {density})...")
190
+ model, processor = load_model()
191
+ image_pil, image_np = prepare_image(image)
192
+ h, w = image_np.shape[:2]
193
+
194
+ # Grid-Größe basierend auf Density
195
+ if density == "high":
196
+ grid_size = 5
197
+ elif density == "medium":
198
+ grid_size = 4
199
  else:
200
+ grid_size = 3
201
 
202
+ # Grid-Punkte generieren
203
+ points = []
204
+ for i in range(1, grid_size + 1):
205
+ for j in range(1, grid_size + 1):
206
+ x = int(w * i / (grid_size + 1))
207
+ y = int(h * j / (grid_size + 1))
208
+ points.append([x, y])
209
 
210
+ print(f"📍 Using {len(points)} grid points ({grid_size}x{grid_size})...")
 
211
 
212
+ all_masks = []
213
+ all_scores = []
214
 
215
+ # Segmentiere jeden Punkt
216
+ for idx, point in enumerate(points):
217
+ inputs = processor(
218
+ image_pil,
219
+ input_points=[[point]],
220
+ input_labels=[[1]],
221
+ return_tensors="pt"
222
+ ).to(device)
223
+
224
+ with torch.no_grad():
225
+ outputs = model(**inputs, multimask_output=True)
226
+
227
+ masks = processor.image_processor.post_process_masks(
228
+ outputs.pred_masks.cpu(),
229
+ inputs["original_sizes"].cpu(),
230
+ inputs["reshaped_input_sizes"].cpu()
231
+ )[0]
232
+
233
+ scores = outputs.iou_scores.cpu().numpy().flatten()
234
+ best_idx = np.argmax(scores)
235
+
236
+ if masks.ndim == 4:
237
+ mask = masks[0, best_idx].numpy()
238
+ else:
239
+ mask = masks[best_idx].numpy()
240
+
241
+ # Nur Masken mit gutem Score
242
+ if scores[best_idx] > 0.7:
243
+ all_masks.append(refine_mask(mask))
244
+ all_scores.append(scores[best_idx])
245
+
246
+ print(f"✅ Got {len(all_masks)} quality masks")
247
 
248
+ # Overlay mit verschiedenen Farben
249
  overlay = image_np.copy()
 
 
250
 
251
+ # HSV-basierte Farbgenerierung
252
+ colors = []
253
+ for i in range(len(all_masks)):
254
+ hue = int(180 * i / max(len(all_masks), 1))
255
+ color_hsv = np.uint8([[[hue, 255, 200]]])
256
+ color_rgb = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB)[0][0]
257
+ colors.append(color_rgb)
258
 
259
+ # Masken anwenden
260
+ for mask, color, score in zip(all_masks, colors, all_scores):
261
+ alpha = 0.4 + (score - 0.7) * 0.2 # Höherer Score = stärkere Farbe
262
+ overlay[mask] = (
263
+ overlay[mask] * (1 - alpha) +
264
+ np.array(color) * alpha
265
+ ).astype(np.uint8)
266
+
267
+ # Kontur
268
+ contours, _ = cv2.findContours(
269
+ mask.astype(np.uint8),
270
+ cv2.RETR_EXTERNAL,
271
+ cv2.CHAIN_APPROX_SIMPLE
272
+ )
273
+ cv2.drawContours(overlay, contours, -1, color.tolist(), 2)
274
 
275
+ metadata = {
276
+ "success": True,
277
+ "mode": "multi_object_dense",
278
+ "density": density,
279
+ "grid_size": f"{grid_size}x{grid_size}",
280
+ "total_points": len(points),
281
+ "quality_masks": len(all_masks),
282
+ "avg_score": float(np.mean(all_scores)) if all_scores else 0,
283
+ "scores": [float(s) for s in all_scores]
284
+ }
 
 
 
 
 
 
 
285
 
286
+ print("✅ Multi-region complete!")
287
+ return Image.fromarray(overlay), metadata
288
 
289
  except Exception as e:
 
290
  import traceback
291
+ print(f"❌ ERROR:\n{traceback.format_exc()}")
292
+ return image, {"error": str(e)}
293
 
294
  # Gradio Interface
295
+ demo = gr.Blocks(title="SAM2 Boostly", theme=gr.themes.Soft())
296
+
297
+ with demo:
298
+ gr.Markdown("# 🎨 SAM2 Segmentierung - Boostly Edition")
299
+ gr.Markdown("### ⚡ Optimierte Zero-Shot Object Segmentation")
 
 
 
 
300
 
301
+ with gr.Tab("🤖 Automatisch PLUS"):
302
+ gr.Markdown("**Smart Multi-Mask Combining** - Kombiniert automatisch alle Objektteile!")
 
 
 
 
 
 
 
 
303
 
304
+ with gr.Row():
305
+ with gr.Column():
306
+ input_auto = gr.Image(type="pil", label="📸 Bild hochladen")
307
+
308
+ quality_radio = gr.Radio(
309
+ choices=["high", "fast"],
310
+ value="high",
311
+ label="⚙️ Qualität",
312
+ info="High = präzisere Kanten, Fast = schneller"
313
+ )
314
+
315
+ merge_checkbox = gr.Checkbox(
316
+ value=True,
317
+ label="🔗 Teile zusammenfügen",
318
+ info="Kombiniert alle erkannten Bereiche (Fisch + Flosse = 1 Objekt)"
319
+ )
320
+
321
+ btn_auto = gr.Button("🚀 Segmentieren", variant="primary", size="lg")
322
+
323
+ gr.Markdown("""
324
+ **✨ Funktionsweise:**
325
+ - SAM generiert 3 verschiedene Masken
326
+ - Wenn "Teile zusammenfügen" AN: Alle kombiniert → vollständiges Objekt
327
+ - Wenn AUS: Nur präziseste Maske
328
+ - ⚡ Optimiert: ~10-30 Sekunden statt 25 Minuten!
329
+ """)
330
+
331
+ with gr.Column():
332
+ output_auto = gr.Image(label="✨ Segmentiertes Bild")
333
+ json_auto = gr.JSON(label="📊 Metadata")
334
+
335
+ btn_auto.click(
336
+ fn=segment_automatic,
337
+ inputs=[input_auto, quality_radio, merge_checkbox],
338
+ outputs=[output_auto, json_auto]
339
+ )
340
+
341
+ gr.Examples(
342
+ examples=[],
343
+ inputs=input_auto,
344
+ label="💡 Tipp: Objekt sollte zentral im Bild sein"
345
+ )
346
 
347
+ with gr.Tab("🎯 Multi-Region"):
348
+ gr.Markdown("**Grid-basierte Segmentierung** - Für mehrere separate Objekte")
349
+
350
+ with gr.Row():
351
+ with gr.Column():
352
+ input_multi = gr.Image(type="pil", label="📸 Bild hochladen")
353
+
354
+ density_radio = gr.Radio(
355
+ choices=["high", "medium", "low"],
356
+ value="medium",
357
+ label="📊 Punkt-Dichte",
358
+ info="Mehr Punkte = mehr Details, aber langsamer"
359
+ )
360
+
361
+ btn_multi = gr.Button("🎯 Alle Bereiche segmentieren", variant="primary", size="lg")
362
+
363
+ gr.Markdown("""
364
+ **Grid-Größen:**
365
+ - 🔥 High: 5x5 = 25 Erkennungspunkte
366
+ - ⚡ Medium: 4x4 = 16 Punkte (empfohlen)
367
+ - 💨 Low: 3x3 = 9 Punkte
368
+
369
+ Jedes Objekt bekommt eigene Farbe!
370
+ """)
371
+
372
+ with gr.Column():
373
+ output_multi = gr.Image(label="✨ Segmentiertes Bild")
374
+ json_multi = gr.JSON(label="📊 Metadata")
375
+
376
+ btn_multi.click(
377
+ fn=segment_multi_dense,
378
+ inputs=[input_multi, density_radio],
379
+ outputs=[output_multi, json_multi]
380
+ )
381
 
382
+ with gr.Tab("📡 API Dokumentation"):
383
+ gr.Markdown("### 🔗 API Endpoint")
384
+ gr.Code(
385
+ "https://EnginDev-Boostly.hf.space/api/predict",
386
+ label="Base URL"
387
+ )
388
+
389
+ gr.Markdown("### 📝 JavaScript Integration (für Lovable)")
390
+ gr.Code('''
391
+ // Segmentation Service
392
+ const HUGGINGFACE_API = 'https://EnginDev-Boostly.hf.space';
393
+ async function segmentImage(imageFile, mode = 'automatic') {
394
+ // File zu Base64 konvertieren
395
+ const base64 = await new Promise((resolve) => {
396
+ const reader = new FileReader();
397
+ reader.onloadend = () => resolve(reader.result);
398
+ reader.readAsDataURL(imageFile);
399
+ });
400
+
401
+ // API Call
402
+ const response = await fetch(`${HUGGINGFACE_API}/api/predict`, {
403
+ method: 'POST',
404
+ headers: {'Content-Type': 'application/json'},
405
+ body: JSON.stringify({
406
+ data: [base64, "high", true], // [image, quality, merge]
407
+ fn_index: mode === 'automatic' ? 0 : 1
408
+ })
409
+ });
410
+
411
+ const result = await response.json();
412
+
413
+ return {
414
+ segmentedImage: result.data[0], // Base64 segmentiertes Bild
415
+ metadata: result.data[1] // JSON mit Details
416
+ };
417
+ }
418
+ // Verwendung:
419
+ const result = await segmentImage(myImageFile, 'automatic');
420
+ console.log('Mask covers:', result.metadata.mask_percentage + '%');
421
+ ''', language="javascript")
422
+
423
+ gr.Markdown("### ⚙️ Parameter")
424
+ gr.Markdown("""
425
+ **fn_index:**
426
+ - `0` = Automatisch PLUS (empfohlen für einzelne Objekte)
427
+ - `1` = Multi-Region (für mehrere Objekte)
428
+
429
+ **quality:**
430
+ - `"high"` = Präzise Kanten, Gaussian Blur, Refinement (~20-30s)
431
+ - `"fast"` = Schneller, weniger Nachbearbeitung (~10-15s)
432
+
433
+ **merge (nur fn_index=0):**
434
+ - `true` = Kombiniert alle Masken → vollständiges Objekt
435
+ - `false` = Nur beste Maske → nur Hauptteil
436
+
437
+ **density (nur fn_index=1):**
438
+ - `"high"` = 5x5 Grid = 25 Punkte
439
+ - `"medium"` = 4x4 Grid = 16 Punkte
440
+ - `"low"` = 3x3 Grid = 9 Punkte
441
+ """)
442
+
443
+ gr.Markdown("### 📊 Response Format")
444
+ gr.Code('''
445
+ {
446
+ "data": [
447
+ "...", // Segmentiertes Bild
448
+ {
449
+ "success": true,
450
+ "mode": "automatic_plus",
451
+ "masks_combined": 3,
452
+ "mask_percentage": 12.5,
453
+ "num_contours": 1,
454
+ "all_scores": [0.998, 0.583, 0.864]
455
+ }
456
+ ]
457
+ }
458
+ ''', language="json")
459
 
460
  if __name__ == "__main__":
461
+ print("🌐 Launching Boostly SAM2 v2.1...")
462
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)