Astridkraft commited on
Commit
98758f6
·
verified ·
1 Parent(s): 23e99aa

Create sam_module.py

Browse files
Files changed (1) hide show
  1. sam_module.py +310 -0
sam_module.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def create_sam_mask(self, image, bbox_coords, mode):
2
+ """
3
+ ERWEITERTE Funktion: Erstellt präzise Maske mit SAM 2
4
+ Restrukturierte Version mit klaren Blöcken pro Modus
5
+ """
6
+ try:
7
+ print("#" * 80)
8
+ print("# 🎯 STARTE SAM 2 SEGMENTIERUNG")
9
+ print("#" * 80)
10
+ print(f"📐 Eingabebild-Größe: {image.size}")
11
+ print(f"🎛️ Ausgewählter Modus: {mode}")
12
+
13
+ # ============================================================
14
+ # VORBEREITUNG FÜR ALLE MODI
15
+ # ============================================================
16
+ original_image = image
17
+
18
+ # 1. SAM2 laden
19
+ if not self.sam_initialized:
20
+ print("📥 SAM 2 ist noch nicht geladen, starte Lazy Loading...")
21
+ self._lazy_load_sam()
22
+
23
+ if self.sam_model is None or self.sam_processor is None:
24
+ print("⚠️ SAM 2 Model nicht verfügbar, verwende Fallback")
25
+ return self._create_rectangular_mask(image, bbox_coords, mode)
26
+
27
+ # 2. Validiere BBox
28
+ x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)
29
+ original_bbox = (x1, y1, x2, y2)
30
+ print(f"📏 Original-BBox Größe: {x2-x1} × {y2-y1} px")
31
+
32
+ # ============================================================
33
+ # BLOCK 1: ENVIRONMENT_CHANGE
34
+ # ============================================================
35
+ if mode == "environment_change":
36
+ print("-" * 60)
37
+ print("🌳 MODUS: ENVIRONMENT_CHANGE")
38
+ print("-" * 60)
39
+
40
+ # ... existierende environment_change Logik hier komplett ...
41
+ # (wird aus dem Original übernommen, nicht verändert)
42
+
43
+ # WICHTIG: Du musst den environment_change Code hier einfügen
44
+ # von Zeile ~175 bis ~250 aus dem Original
45
+
46
+ # Beispiel-Struktur (vereinfacht):
47
+ image_np = np.array(image.convert("RGB"))
48
+ input_boxes = [[[x1, y1, x2, y2]]]
49
+
50
+ # KEINE Punkte für environment_change
51
+ inputs = self.sam_processor(
52
+ image_np,
53
+ input_boxes=input_boxes,
54
+ return_tensors="pt"
55
+ ).to(self.device)
56
+
57
+ with torch.no_grad():
58
+ outputs = self.sam_model(**inputs)
59
+
60
+ # Nur beste Maske verwenden und auf 512x512 skalieren
61
+ best_mask = outputs.pred_masks[:, :, 0, :, :] # Erste Maske nehmen
62
+ resized_mask = F.interpolate(
63
+ best_mask,
64
+ size=(512, 512), # Direkt auf ControlNet-Zielgröße
65
+ mode='bilinear',
66
+ align_corners=False
67
+ ).squeeze()
68
+
69
+ mask_np = resized_mask.sigmoid().cpu().numpy()
70
+
71
+ # Invertieren für environment_change
72
+ threshold = 0.5
73
+ mask_array = (mask_np > threshold).astype(np.uint8) * 255
74
+ mask_array = 255 - mask_array # Invertieren
75
+
76
+ # Auf Originalgröße für Rückgabe
77
+ mask = Image.fromarray(mask_array).convert("L")
78
+ mask = mask.resize(original_image.size, Image.Resampling.NEAREST)
79
+
80
+ return mask, mask # raw_mask gleiche wie finale Maske
81
+
82
+ # ============================================================
83
+ # BLOCK 2: FOCUS_CHANGE (KORRIGIERTE VERSION)
84
+ # ============================================================
85
+ elif mode == "focus_change":
86
+ print("-" * 60)
87
+ print("🎯 MODUS: FOCUS_CHANGE (OPTIMIERT)")
88
+ print("-" * 60)
89
+
90
+ # Bild für SAM vorbereiten
91
+ image_np = np.array(image.convert("RGB"))
92
+
93
+ # NUR EINE BBOX UND NUR MITTELPUNKT (kein Gesichtspunkt)
94
+ input_boxes = [[[x1, y1, x2, y2]]]
95
+
96
+ # Nur Mittelpunkt als positiver Prompt
97
+ center_x = (x1 + x2) // 2
98
+ center_y = (y1 + y2) // 2
99
+ input_points = [[[[center_x, center_y]]]] # NUR EIN PUNKT
100
+ input_labels = [[[1]]] # Positiver Prompt
101
+
102
+ print(f" 🎯 SAM-Prompt: BBox [{x1},{y1},{x2},{y2}]")
103
+ print(f" 👁️ Punkt: Nur Mitte ({center_x},{center_y})")
104
+
105
+ # SAM Inputs vorbereiten
106
+ inputs = self.sam_processor(
107
+ image_np,
108
+ input_boxes=input_boxes,
109
+ input_points=input_points,
110
+ input_labels=input_labels,
111
+ return_tensors="pt"
112
+ ).to(self.device)
113
+
114
+ # SAM Vorhersage (alle 3 Masken)
115
+ print("🧠 SAM 2 INFERENZ (3 Masken-Varianten)")
116
+ with torch.no_grad():
117
+ outputs = self.sam_model(**inputs)
118
+
119
+ # BBox-Information für Heuristik
120
+ bbox_center = ((x1 + x2) // 2, (y1 + y2) // 2)
121
+ bbox_area = (x2 - x1) * (y2 - y1)
122
+
123
+ print("🤔 HEURISTIK: Beste Maske auswählen")
124
+ best_mask_idx = 0
125
+ best_score = -1
126
+
127
+ # Alle 3 Masken analysieren (OHNE sie alle zu skalieren!)
128
+ for i in range(3):
129
+ # Maske in Original-SAM-Größe (256x256) analysieren
130
+ mask_256 = outputs.pred_masks[:, :, i, :, :]
131
+ mask_np_256 = mask_256.sigmoid().squeeze().cpu().numpy()
132
+
133
+ # Für Heuristik: Temporär auf Bildgröße skalieren
134
+ temp_mask = F.interpolate(
135
+ mask_256,
136
+ size=(image.height, image.width),
137
+ mode='bilinear',
138
+ align_corners=False
139
+ ).squeeze()
140
+ mask_np_temp = temp_mask.sigmoid().cpu().numpy()
141
+
142
+ # Adaptive Vor-Filterung
143
+ mask_max = mask_np_temp.max()
144
+ if mask_max < 0.3:
145
+ continue # Maske überspringen
146
+
147
+ adaptive_threshold = max(0.3, mask_max * 0.7)
148
+ mask_binary = (mask_np_temp > adaptive_threshold).astype(np.uint8)
149
+
150
+ if np.sum(mask_binary) == 0:
151
+ continue
152
+
153
+ # Heuristik-Berechnung (wie bisher)
154
+ mask_area_pixels = np.sum(mask_binary)
155
+
156
+ # BBox-Überlappung
157
+ bbox_mask = np.zeros((image.height, image.width), dtype=np.uint8)
158
+ bbox_mask[y1:y2, x1:x2] = 1
159
+ overlap = np.sum(mask_binary & bbox_mask)
160
+ bbox_overlap_ratio = overlap / np.sum(bbox_mask) if np.sum(bbox_mask) > 0 else 0
161
+
162
+ # Schwerpunkt
163
+ y_coords, x_coords = np.where(mask_binary > 0)
164
+ if len(y_coords) > 0:
165
+ centroid_y = np.mean(y_coords)
166
+ centroid_x = np.mean(x_coords)
167
+ centroid_distance = np.sqrt((centroid_x - bbox_center[0])**2 +
168
+ (centroid_y - bbox_center[1])**2)
169
+ normalized_distance = centroid_distance / max(image.width, image.height)
170
+ else:
171
+ normalized_distance = 1.0
172
+
173
+ # Flächen-Ratio
174
+ area_ratio = mask_area_pixels / bbox_area
175
+ area_score = 1.0 - min(abs(area_ratio - 1.0), 1.0)
176
+
177
+ # FOCUS_CHANGE spezifischer Score
178
+ score = (
179
+ bbox_overlap_ratio * 0.4 + # 40% BBox-Überlappung
180
+ (1.0 - normalized_distance) * 0.25 + # 25% Zentrumsnähe
181
+ area_score * 0.25 + # 25% Flächenpassung
182
+ mask_max * 0.1 # 10% SAM-Konfidenz
183
+ )
184
+
185
+ print(f" Maske {i+1}: Score={score:.3f}, "
186
+ f"Überlappung={bbox_overlap_ratio:.3f}, "
187
+ f"Fläche={mask_area_pixels:,}px")
188
+
189
+ if score > best_score:
190
+ best_score = score
191
+ best_mask_idx = i
192
+
193
+ print(f"✅ Beste Maske: Nr. {best_mask_idx+1} mit Score {best_score:.3f}")
194
+
195
+ # NUR DIE BESTE MASKE AUF 512x512 SKALIEREN
196
+ best_mask_256 = outputs.pred_masks[:, :, best_mask_idx, :, :]
197
+ resized_mask = F.interpolate(
198
+ best_mask_256,
199
+ size=(512, 512), # DIREKT AUF CONTROLNET-ZIELGRÖßE
200
+ mode='bilinear',
201
+ align_corners=False
202
+ ).squeeze()
203
+
204
+ mask_np = resized_mask.sigmoid().cpu().numpy()
205
+ print(f" 🔄 Beste Maske skaliert auf 512×512 für ControlNet")
206
+
207
+ # Dynamischer Threshold für focus_change
208
+ mask_max = mask_np.max()
209
+ if best_score < 0.7: # Schlechte Maskenqualität
210
+ dynamic_threshold = 0.05 # SEHR NIEDRIG für maximale Abdeckung
211
+ print(f" ⚠️ Masken-Score niedrig ({best_score:.3f}). "
212
+ f"Threshold=0.05 für maximale Abdeckung")
213
+ else:
214
+ dynamic_threshold = max(0.15, mask_max * 0.3) # Moderater Threshold
215
+ print(f" ✅ Gute Maske. Threshold={dynamic_threshold:.3f}")
216
+
217
+ # Binärmaske erstellen
218
+ mask_array = (mask_np > dynamic_threshold).astype(np.uint8) * 255
219
+
220
+ # Fallback bei leerer Maske
221
+ if mask_array.max() == 0:
222
+ print(" ⚠️ Maske leer, erstelle rechteckige Fallback-Maske")
223
+ mask_array = np.zeros((512, 512), dtype=np.uint8)
224
+ # BBox auf 512x512 skalieren für Fallback
225
+ scale_x = 512 / image.width
226
+ scale_y = 512 / image.height
227
+ fb_x1 = int(x1 * scale_x)
228
+ fb_y1 = int(y1 * scale_y)
229
+ fb_x2 = int(x2 * scale_x)
230
+ fb_y2 = int(y2 * scale_y)
231
+ cv2.rectangle(mask_array, (fb_x1, fb_y1), (fb_x2, fb_y2), 255, -1)
232
+
233
+ # FOCUS_CHANGE POSTPROCESSING (angepasst für 512x512)
234
+ print("🔧 FOCUS_CHANGE POSTPROCESSING (auf 512×512)")
235
+
236
+ # 1. Größte Komponente behalten
237
+ labeled_array, num_features = ndimage.label(mask_array)
238
+ if num_features > 1:
239
+ sizes = ndimage.sum(mask_array, labeled_array, range(1, num_features + 1))
240
+ largest_component = np.argmax(sizes) + 1
241
+ mask_array = np.where(labeled_array == largest_component, mask_array, 0)
242
+ print(f" ✅ Größte Komponente behalten ({num_features}→1)")
243
+
244
+ # 2. Morphologische Operationen
245
+ kernel_close = np.ones((5, 5), np.uint8)
246
+ mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_CLOSE, kernel_close, iterations=2)
247
+
248
+ kernel_dilate = np.ones((15, 15), np.uint8)
249
+ mask_array = cv2.dilate(mask_array, kernel_dilate, iterations=1)
250
+
251
+ # 3. Weiche Übergänge
252
+ mask_array = cv2.GaussianBlur(mask_array, (9, 9), 2.0)
253
+
254
+ # 4. Gamma-Korrektur
255
+ mask_array_float = mask_array.astype(np.float32) / 255.0
256
+ mask_array_float = np.clip(mask_array_float, 0.0, 1.0)
257
+ mask_array_float = mask_array_float ** 0.85
258
+ mask_array = (mask_array_float * 255).astype(np.uint8)
259
+
260
+ # 5. Auf Originalgröße für Rückgabe (falls benötigt)
261
+ mask_512 = Image.fromarray(mask_array).convert("L")
262
+ raw_mask = mask_512.copy() # Rohmaske = finale Maske bei focus_change
263
+
264
+ # Finale Maske für ControlNet ist 512x512
265
+ mask = mask_512
266
+
267
+ print(f"✅ FOCUS_CHANGE Maske erstellt: {mask.size}")
268
+ return mask, raw_mask
269
+
270
+ # ============================================================
271
+ # BLOCK 3: FACE_ONLY_CHANGE
272
+ # ============================================================
273
+ elif mode == "face_only_change":
274
+ print("-" * 60)
275
+ print("👤 MODUS: FACE_ONLY_CHANGE")
276
+ print("-" * 60)
277
+
278
+ # ... existierende face_only_change Logik hier komplett ...
279
+ # (wird aus dem Original übernommen, nicht verändert)
280
+
281
+ # WICHTIG: Du musst den face_only_change Code hier einfügen
282
+ # von Zeile ~252 bis ~650 aus dem Original
283
+
284
+ # Beispiel-Struktur (vereinfacht):
285
+ # Crop, Punkte setzen, spezielle Gesichtsheuristik etc.
286
+
287
+ # Am Ende:
288
+ mask = Image.new("L", (512, 512), 128) # Platzhalter
289
+ raw_mask = mask.copy()
290
+ return mask, raw_mask
291
+
292
+ # ============================================================
293
+ # UNBEKANNTER MODUS
294
+ # ============================================================
295
+ else:
296
+ print(f"❌ Unbekannter Modus: {mode}")
297
+ return self._create_rectangular_mask(image, bbox_coords, "focus_change")
298
+
299
+ except Exception as e:
300
+ print("❌ FEHLER IN SAM 2 SEGMENTIERUNG")
301
+ print(f"Fehler: {str(e)[:200]}")
302
+ import traceback
303
+ traceback.print_exc()
304
+
305
+ # Fallback
306
+ fallback_mask = self._create_rectangular_mask(original_image, original_bbox, mode)
307
+ if fallback_mask.size != original_image.size:
308
+ fallback_mask = fallback_mask.resize(original_image.size, Image.Resampling.NEAREST)
309
+
310
+ return fallback_mask, fallback_mask