EnginDev committed on
Commit eeb2177 · verified · 1 Parent(s): 4ceec40

Update app.py

Files changed (1): app.py +196 -40
app.py CHANGED
@@ -1,52 +1,208 @@
  import gradio as gr
- import numpy as np
  import torch
- import cv2
  from PIL import Image
- import os
- import urllib.request
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
-
- # Load the model, downloading it first if necessary
- MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
- MODEL_PATH = "sam_vit_b_01ec64.pth"
-
- if not os.path.exists(MODEL_PATH):
-     print("Downloading model...")
-     urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
-     print("Model downloaded.")

- # Model type
- model_type = "vit_b"

  device = "cuda" if torch.cuda.is_available() else "cpu"
- sam = sam_model_registry[model_type](checkpoint=MODEL_PATH)
- sam.to(device=device)

- mask_generator = SamAutomaticMaskGenerator(sam)

- def segment_all_objects(image):
      image_np = np.array(image)
-     masks = mask_generator.generate(image_np)
      overlay = image_np.copy()

-     for i, mask in enumerate(masks):
-         m = mask["segmentation"]
-         color = np.random.randint(0, 255, size=(3,))
-         overlay[m] = overlay[m] * 0.3 + color * 0.7
-         y, x = np.where(m)
-         if len(x) > 0 and len(y) > 0:
-             cx, cy = int(np.mean(x)), int(np.mean(y))
-             cv2.putText(overlay, f"Obj {i+1}", (cx, cy),
-                         cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
-
-     return Image.fromarray(overlay.astype(np.uint8))

- demo = gr.Interface(
-     fn=segment_all_objects,
-     inputs=gr.Image(type="pil", label="Upload image"),
-     outputs=gr.Image(type="pil", label="Segmented result"),
-     title="FishBoost SAM (Meta Original)",
-     description="Automatically segments all objects in the image using Meta's official SAM model."
- )

- demo.launch()

  import gradio as gr
  import torch
+ import numpy as np
  from PIL import Image
+ import cv2
+ from transformers import SamModel, SamProcessor
+ import json

+ # Load the SAM model
+ print("Loading SAM model...")
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+ model = SamModel.from_pretrained("facebook/sam-vit-huge")
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
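+ # Note: sam-vit-huge is a multi-gigabyte checkpoint; "facebook/sam-vit-base" is a
+ # lighter drop-in alternative if download size or memory is a concern.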
 
+ def generate_colors(n):
+     """Generate visually distinct colors for the masks via evenly spaced HSV hues"""
+     colors = []
+     for i in range(n):
+         hue = int(180 * i / max(n, 1))  # OpenCV hue values range over 0-179
+         color = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2RGB)[0][0]
+         colors.append(color.tolist())
+     return colors
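+ # e.g. generate_colors(3) -> [[255, 0, 0], [0, 255, 0], [0, 0, 255]]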
 
+ def segment_automatic(image):
+     """Automatic segmentation - main object at the image center"""
+     if image is None:
+         return None, {"error": "No image uploaded"}
+
+     # Prepare the image
      image_np = np.array(image)
+     h, w = image_np.shape[:2]
+
+     # Use the center point as the prompt (the object is assumed to be roughly centered)
+     input_points = [[[w // 2, h // 2]]]
+
+     # SAM processing
+     inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Extract the masks
+     masks = processor.image_processor.post_process_masks(
+         outputs.pred_masks.cpu(),
+         inputs["original_sizes"].cpu(),
+         inputs["reshaped_input_sizes"].cpu()
+     )[0]
+
+     # Take the best of the candidate masks
+     # (iou_scores: (batch, point_batch, 3), masks: (point_batch, 3, H, W))
+     scores = outputs.iou_scores.cpu().numpy()[0, 0]
+     best_mask_idx = int(np.argmax(scores))
+     best_mask = masks[0, best_mask_idx].numpy()
+
+     # Build a colored overlay
      overlay = image_np.copy()
+     color = [255, 0, 100]  # Pink
+     overlay[best_mask] = overlay[best_mask] * 0.5 + np.array(color) * 0.5
+
+     # Metadata
+     metadata = {
+         "mode": "automatic",
+         "num_masks": 1,
+         "score": float(scores[best_mask_idx]),
+         "mask_shape": list(best_mask.shape),
+         "object_detected": True
+     }
+
+     return Image.fromarray(overlay.astype(np.uint8)), metadata
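+ # Quick sanity check outside Gradio (hypothetical file name):
+ #   img, meta = segment_automatic(Image.open("example.jpg"))
+ #   print(meta["score"])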
 
+ def segment_all_objects(image):
+     """Segment multiple objects - for manual selection"""
+     if image is None:
+         return None, {"error": "No image uploaded"}
+
+     image_np = np.array(image)
+     h, w = image_np.shape[:2]
+
+     # Grid of prompt points across the image
+     grid_points = []
+     step = max(h, w) // 8  # roughly an 8x8 grid on a square image
+     for y in range(step, h, step):
+         for x in range(step, w, step):
+             grid_points.append([x, y])
+
+     all_masks = []
+     all_scores = []
+
+     # Segment at each point prompt
+     for point in grid_points[:10]:  # Limit to 10 points for performance
+         inputs = processor(image, input_points=[[point]], return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         masks = processor.image_processor.post_process_masks(
+             outputs.pred_masks.cpu(),
+             inputs["original_sizes"].cpu(),
+             inputs["reshaped_input_sizes"].cpu()
+         )[0]
+
+         scores = outputs.iou_scores.cpu().numpy()[0, 0]
+         best_idx = int(np.argmax(scores))
+
+         all_masks.append(masks[0, best_idx].numpy())
+         all_scores.append(scores[best_idx])
+
+     # Combine all masks using distinct colors
+     overlay = image_np.copy()
+     colors = generate_colors(len(all_masks))
+
+     for mask, color in zip(all_masks, colors):
+         overlay[mask] = overlay[mask] * 0.6 + np.array(color) * 0.4
+
+     metadata = {
+         "mode": "multi_object",
+         "num_masks": len(all_masks),
+         "avg_score": float(np.mean(all_scores)),
+         "masks_data": [
+             {
+                 "id": i,
+                 "score": float(score),
+                 "area": int(mask.sum())
+             } for i, (mask, score) in enumerate(zip(all_masks, all_scores))
+         ]
+     }
+
+     return Image.fromarray(overlay.astype(np.uint8)), metadata
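+ # Note: each grid point above triggers a full forward pass, including re-encoding the
+ # image. If this loop becomes a bottleneck, the transformers SAM model can compute the
+ # image embedding once (model.get_image_embeddings) and reuse it across point prompts.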
 
+ def segment_with_points(image, points_json):
+     """Segmentation with user-defined points"""
+     if image is None:
+         return None, {"error": "No image uploaded"}
+
+     image_np = np.array(image)
+
+     try:
+         # Parse the JSON input; fall back to the image center if no points are given
+         points_data = json.loads(points_json) if isinstance(points_json, str) else points_json
+         input_points = [points_data.get("points", [[image_np.shape[1] // 2, image_np.shape[0] // 2]])]
+
+         inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         masks = processor.image_processor.post_process_masks(
+             outputs.pred_masks.cpu(),
+             inputs["original_sizes"].cpu(),
+             inputs["reshaped_input_sizes"].cpu()
+         )[0]
+
+         scores = outputs.iou_scores.cpu().numpy()[0, 0]
+         best_idx = int(np.argmax(scores))
+         best_mask = masks[0, best_idx].numpy()
+
+         overlay = image_np.copy()
+         color = [0, 255, 100]  # Green
+         overlay[best_mask] = overlay[best_mask] * 0.5 + np.array(color) * 0.5
+
+         metadata = {
+             "mode": "custom_points",
+             "points": input_points[0],
+             "score": float(scores[best_idx]),
+             "success": True
+         }
+
+         return Image.fromarray(overlay.astype(np.uint8)), metadata
+
+     except Exception as e:
+         return image, {"error": str(e)}
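+ # Example value for the points textbox (hypothetical coordinates):
+ #   {"points": [[120, 80], [200, 150]]}
+ # Each [x, y] pair is a pixel location that should lie on the target object.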
 
+ # Gradio interface
+ with gr.Blocks(title="SAM Segmentation API") as demo:
+     gr.Markdown("""
+     # 🎨 SAM Image Segmentation
+     ### No training needed - zero-shot object segmentation!
+     """)
+
+     with gr.Tab("🤖 Automatic (main object)"):
+         with gr.Row():
+             with gr.Column():
+                 input_auto = gr.Image(type="pil", label="Upload image")
+                 btn_auto = gr.Button("Detect object", variant="primary")
+             with gr.Column():
+                 output_auto = gr.Image(label="Segmented image")
+                 json_auto = gr.JSON(label="Metadata")
+
+         btn_auto.click(segment_automatic, inputs=input_auto, outputs=[output_auto, json_auto])
+
+     with gr.Tab("🎯 Multiple objects"):
+         with gr.Row():
+             with gr.Column():
+                 input_multi = gr.Image(type="pil", label="Upload image")
+                 btn_multi = gr.Button("Detect all objects", variant="primary")
+             with gr.Column():
+                 output_multi = gr.Image(label="Segmented regions")
+                 json_multi = gr.JSON(label="Metadata")
+
+         btn_multi.click(segment_all_objects, inputs=input_multi, outputs=[output_multi, json_multi])
+
+     with gr.Tab("✋ Custom (with points)"):
+         with gr.Row():
+             with gr.Column():
+                 input_custom = gr.Image(type="pil", label="Upload image")
+                 points_input = gr.Textbox(
+                     label="Points (JSON)",
+