Astridkraft committed on
Commit
6db455d
·
verified ·
1 Parent(s): c4ecf92

Update controlnet_module.py

Browse files
Files changed (1) hide show
  1. controlnet_module.py +161 -35
controlnet_module.py CHANGED
@@ -1,11 +1,12 @@
1
  import torch
2
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel # <- KORREKT!
3
  from controlnet_aux import OpenposeDetector
4
- from PIL import Image
5
  import random
6
  import cv2
7
  import numpy as np
8
  import gradio as gr
 
9
 
10
 
11
  class ControlNetProgressCallback:
@@ -32,7 +33,163 @@ class ControlNetProcessor:
32
  self.pose_detector = None
33
  self.midas_model = None
34
  self.midas_transform = None
 
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def load_pose_detector(self):
37
  """Lädt nur den Pose-Detector"""
38
  if self.pose_detector is None:
@@ -49,10 +206,8 @@ class ControlNetProcessor:
49
  if self.midas_model is None:
50
  print("🔄 Lade MiDaS Modell für Depth Maps...")
51
  try:
52
- # WICHTIG: torchvision 0.20.0 hat MiDaS integriert
53
  import torchvision.transforms as T
54
 
55
- # MiDaS Small (weniger VRAM)
56
  self.midas_model = torch.hub.load(
57
  "intel-isl/MiDaS",
58
  "DPT_Hybrid",
@@ -62,7 +217,6 @@ class ControlNetProcessor:
62
  self.midas_model.to(self.device)
63
  self.midas_model.eval()
64
 
65
- # Transform für MiDaS
66
  self.midas_transform = T.Compose([
67
  T.Resize(384),
68
  T.ToTensor(),
@@ -107,11 +261,9 @@ class ControlNetProcessor:
107
  try:
108
  img_array = np.array(image.convert("RGB"))
109
 
110
- # Canny Edge Detection
111
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
112
  edges = cv2.Canny(gray, 100, 200)
113
 
114
- # Zu 3-Kanal Bild konvertieren
115
  edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
116
  edges_image = Image.fromarray(edges_rgb)
117
 
@@ -126,28 +278,23 @@ class ControlNetProcessor:
126
  Extrahiert Depth Map mit MiDaS (Fallback auf Filter)
127
  """
128
  try:
129
- # Versuche MiDaS
130
  midas = self.load_midas_model()
131
  if midas is not None:
132
  print("🎯 Verwende MiDaS für Depth Map...")
133
 
134
  import torchvision.transforms as T
135
- from PIL import Image
136
 
137
- # Bild vorbereiten
138
  img_transformed = self.midas_transform(image).unsqueeze(0).to(self.device)
139
 
140
- # Depth Map berechnen
141
  with torch.no_grad():
142
  prediction = midas(img_transformed)
143
  prediction = torch.nn.functional.interpolate(
144
  prediction.unsqueeze(1),
145
- size=image.size[::-1], # (height, width)
146
  mode="bicubic",
147
  align_corners=False,
148
  ).squeeze()
149
 
150
- # Normalisieren für Ausgabe
151
  depth_np = prediction.cpu().numpy()
152
  depth_min, depth_max = depth_np.min(), depth_np.max()
153
 
@@ -161,18 +308,14 @@ class ControlNetProcessor:
161
  return depth_image
162
 
163
  else:
164
- # Fallback auf einfache Methode
165
- print("⚠️ MiDaS nicht verfügbar, verwende Fallback...")
166
  raise Exception("MiDaS nicht geladen")
167
 
168
  except Exception as e:
169
  print(f"⚠️ MiDaS Fehler: {e}. Verwende Fallback...")
170
- # Fallback auf einfache Depth Map
171
  try:
172
  img_array = np.array(image.convert("RGB"))
173
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
174
 
175
- # Depth-ähnliche Map erstellen
176
  depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
177
  depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
178
  depth_image = Image.fromarray(depth_rgb)
@@ -190,14 +333,12 @@ class ControlNetProcessor:
190
  print("🎯 ControlNet: Erstelle Conditioning-Maps...")
191
 
192
  if keep_environment:
193
- # Depth + Canny
194
  print(" Modus: Depth + Canny")
195
  conditioning_images = [
196
  self.extract_depth_map(image),
197
  self.extract_canny_edges(image)
198
  ]
199
  else:
200
- # OpenPose + Canny
201
  print(" Modus: OpenPose + Canny")
202
  conditioning_images = [
203
  self.extract_pose(image),
@@ -205,22 +346,7 @@ class ControlNetProcessor:
205
  ]
206
 
207
  print(f"✅ {len(conditioning_images)} Conditioning-Maps erstellt.")
208
- return conditioning_images # Rückgabe: Liste der PIL Images
209
-
210
-
211
def prepare_inpaint_input(self, image, keep_environment=False):
    """Choose the picture that is handed to the inpaint step.

    Args:
        image: Source picture. It is forwarded to the depth and Canny
            extractors when ``keep_environment`` is truthy, otherwise it
            is returned untouched.
        keep_environment: When truthy, a 50/50 blend of the depth map and
            the Canny edge map replaces the input picture.

    Returns:
        A ``(picture, info)`` tuple. ``info`` is a dict holding a
        ``"type"`` tag (``"depth_canny"`` or ``"original"``) and the same
        picture again under ``"image"``.
    """
    # Default path: nothing to compute, hand the input straight back.
    if not keep_environment:
        print("🎯 Inpaint: Originalbild (Inside-Box ändern)")
        return image, {"type": "original", "image": image}

    print("🎯 Inpaint: Depth+Canny Info (Outside-Box ändern)")
    depth_map = self.extract_depth_map(image)
    edge_map = self.extract_canny_edges(image)

    # Both maps are converted to RGB before being blended with equal weight.
    blended = Image.blend(
        depth_map.convert("RGB"),
        edge_map.convert("RGB"),
        alpha=0.5,
    )
    info = {"type": "depth_canny", "image": blended}
    return blended, info
224
 
225
 
226
  # Globale Instanz
 
1
  import torch
2
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
3
  from controlnet_aux import OpenposeDetector
4
+ from PIL import Image, ImageFilter
5
  import random
6
  import cv2
7
  import numpy as np
8
  import gradio as gr
9
+ from segment_anything import sam_model_registry, SamPredictor
10
 
11
 
12
  class ControlNetProgressCallback:
 
33
  self.pose_detector = None
34
  self.midas_model = None
35
  self.midas_transform = None
36
+ self.sam_predictor = None
37
+ self.sam_initialized = False
38
 
39
+ def _lazy_load_sam(self):
40
+ """Lazy Loading von SAM 2 Tiny - Optimiert für Hugging Face Spaces"""
41
+ if self.sam_initialized:
42
+ return True
43
+
44
+ try:
45
+ print("🔄 Lade SAM 2 Tiny von Hugging Face Hub...")
46
+
47
+ # KORRIGIERT: Nur der Hugging Face Model-ID Pfad
48
+ model_id = "facebook/sam2-hiera-tiny"
49
+
50
+ # SAM 2 Modell direkt von Hugging Face laden
51
+ sam = sam_model_registry["sam2_hiera_tiny"](checkpoint=model_id)
52
+ sam.to(self.device)
53
+ self.sam_predictor = SamPredictor(sam)
54
+
55
+ self.sam_initialized = True
56
+ print(f"✅ SAM 2 ({model_id}) erfolgreich geladen")
57
+ return True
58
+
59
+ except Exception as e:
60
+ print(f"❌ SAM 2 konnte nicht geladen werden: {str(e)[:100]}")
61
+ print("ℹ️ Verwende rechteckige Masken als Fallback")
62
+ self.sam_predictor = None
63
+ self.sam_initialized = True # Verhindert weitere Ladeversuche
64
+ return False
65
+
66
+ def _validate_bbox(self, image, bbox_coords):
67
+ """Validiert und korrigiert BBox-Koordinaten"""
68
+ width, height = image.size
69
+ x1, y1, x2, y2 = bbox_coords
70
+
71
+ # Stelle sicher, dass x1 <= x2 und y1 <= y2
72
+ x1, x2 = min(x1, x2), max(x1, x2)
73
+ y1, y2 = min(y1, y2), max(y1, y2)
74
+
75
+ # Begrenze auf Bildgrenzen
76
+ x1 = max(0, min(x1, width - 1))
77
+ y1 = max(0, min(y1, height - 1))
78
+ x2 = max(0, min(x2, width - 1))
79
+ y2 = max(0, min(y2, height - 1))
80
+
81
+ # Stelle sicher, dass BBox gültig ist
82
+ if x2 - x1 < 10 or y2 - y1 < 10:
83
+ # Fallback auf sinnvolle Größe
84
+ size = min(width, height) * 0.3
85
+ x1 = max(0, width/2 - size/2)
86
+ y1 = max(0, height/2 - size/2)
87
+ x2 = min(width, width/2 + size/2)
88
+ y2 = min(height, height/2 + size/2)
89
+
90
+ return int(x1), int(y1), int(x2), int(y2)
91
+
92
+ def _smooth_mask(self, mask_array, blur_radius=3):
93
+ """Glättet die Maske für bessere Übergänge (5-Pixel Randbereich)"""
94
+ try:
95
+ # Gaussian Blur für weiche Kanten - nur der Randbereich wird beeinflusst
96
+ if blur_radius > 0:
97
+ mask_array = cv2.GaussianBlur(mask_array,
98
+ (blur_radius*2+1, blur_radius*2+1),
99
+ 0)
100
+
101
+ return mask_array
102
+ except:
103
+ return mask_array
104
+
105
def create_sam_mask(self, image, bbox_coords, mode):
    """Build an inpainting mask from a SAM box prompt.

    The predictor is loaded on first use. Whenever it is unavailable, or
    any step below raises, the rectangular fallback mask is returned
    instead, so the caller always receives a mask.

    Args:
        image: PIL image to segment.
        bbox_coords: ``(x1, y1, x2, y2)`` box used as the SAM prompt.
        mode: ``"environment_change"`` inverts the mask so the segmented
            object is kept and its surroundings are changed; any other
            value leaves the object as the region to change.

    Returns:
        PIL image in mode ``"L"`` where 0 (black) marks pixels to keep and
        255 (white) marks pixels to change.
    """
    try:
        # Load SAM on demand.
        if not self.sam_initialized:
            self._lazy_load_sam()

        predictor = self.sam_predictor
        if predictor is None:
            # SAM is unavailable - use the plain rectangle.
            return self._create_rectangular_mask(image, bbox_coords, mode)

        left, top, right, bottom = self._validate_bbox(image, bbox_coords)
        rgb_pixels = np.array(image.convert("RGB"))

        try:
            predictor.set_image(rgb_pixels)
        except Exception as e:
            print(f"⚠️ SAM set_image Fehler: {e}")
            return self._create_rectangular_mask(image, bbox_coords, mode)

        box_prompt = np.array([left, top, right, bottom])
        print(f"🎯 SAM 2: Segmentiere Bereich {left},{top}-{right},{bottom}")

        # Box-only prompt, single mask requested; the scores are not used.
        masks, _scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_prompt[np.newaxis, :],
            multimask_output=False,
            return_logits=False,
        )

        # Scale the first mask to 0/255, then soften its border
        # (radius 2 -> 5x5 blur kernel).
        binary_mask = masks[0].astype(np.uint8) * 255
        softened = self._smooth_mask(binary_mask, blur_radius=2)
        mask = Image.fromarray(softened).convert("L")

        if mode == "environment_change":
            # Object black (keep), surroundings white (change).
            mask = Image.eval(mask, lambda px: 255 - px)
            print(" SAM-Modus: Umgebung ändern (Objekt erhalten)")
        else:
            # Object white (change), surroundings black (keep).
            print(" SAM-Modus: Focus/Gesicht ändern (Objekt verändern)")

        print(f"✅ SAM 2: Präzise Maske erstellt ({mask.size})")
        return mask

    except Exception as e:
        print(f"⚠️ SAM 2 Fehler: {str(e)[:100]}")
        print("ℹ️ Fallback auf rechteckige Maske")
        return self._create_rectangular_mask(image, bbox_coords, mode)
171
+
172
def _create_rectangular_mask(self, image, bbox_coords, mode):
    """Fallback: build a mask that treats the bounding box as a rectangle.

    Args:
        image: PIL image; only its ``size`` is used for the mask canvas.
        bbox_coords: ``(x1, y1, x2, y2)`` box. If it is empty or contains
            ``None``, nothing is drawn and the mask stays fully black.
        mode: ``"environment_change"`` paints everything white except the
            box; any other value paints only the box white.

    Returns:
        PIL image in mode ``"L"`` (0 = keep, 255 = change).
    """
    from PIL import ImageDraw

    # Start from an all-black canvas.
    mask = Image.new("L", image.size, 0)

    if bbox_coords and not any(c is None for c in bbox_coords):
        left, top, right, bottom = self._validate_bbox(image, bbox_coords)
        canvas = ImageDraw.Draw(mask)
        box = [left, top, right, bottom]

        if mode == "environment_change":
            # Flood the whole picture white, then cut the box out in black.
            img_w, img_h = image.size
            canvas.rectangle([0, 0, img_w, img_h], fill=255)
            canvas.rectangle(box, fill=0)
        else:
            # Only the box itself is marked for change.
            canvas.rectangle(box, fill=255)

    print("ℹ️ Rechteckige Maske (SAM Fallback)")
    return mask
192
+
193
  def load_pose_detector(self):
194
  """Lädt nur den Pose-Detector"""
195
  if self.pose_detector is None:
 
206
  if self.midas_model is None:
207
  print("🔄 Lade MiDaS Modell für Depth Maps...")
208
  try:
 
209
  import torchvision.transforms as T
210
 
 
211
  self.midas_model = torch.hub.load(
212
  "intel-isl/MiDaS",
213
  "DPT_Hybrid",
 
217
  self.midas_model.to(self.device)
218
  self.midas_model.eval()
219
 
 
220
  self.midas_transform = T.Compose([
221
  T.Resize(384),
222
  T.ToTensor(),
 
261
  try:
262
  img_array = np.array(image.convert("RGB"))
263
 
 
264
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
265
  edges = cv2.Canny(gray, 100, 200)
266
 
 
267
  edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
268
  edges_image = Image.fromarray(edges_rgb)
269
 
 
278
  Extrahiert Depth Map mit MiDaS (Fallback auf Filter)
279
  """
280
  try:
 
281
  midas = self.load_midas_model()
282
  if midas is not None:
283
  print("🎯 Verwende MiDaS für Depth Map...")
284
 
285
  import torchvision.transforms as T
 
286
 
 
287
  img_transformed = self.midas_transform(image).unsqueeze(0).to(self.device)
288
 
 
289
  with torch.no_grad():
290
  prediction = midas(img_transformed)
291
  prediction = torch.nn.functional.interpolate(
292
  prediction.unsqueeze(1),
293
+ size=image.size[::-1],
294
  mode="bicubic",
295
  align_corners=False,
296
  ).squeeze()
297
 
 
298
  depth_np = prediction.cpu().numpy()
299
  depth_min, depth_max = depth_np.min(), depth_np.max()
300
 
 
308
  return depth_image
309
 
310
  else:
 
 
311
  raise Exception("MiDaS nicht geladen")
312
 
313
  except Exception as e:
314
  print(f"⚠️ MiDaS Fehler: {e}. Verwende Fallback...")
 
315
  try:
316
  img_array = np.array(image.convert("RGB"))
317
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
318
 
 
319
  depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
320
  depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
321
  depth_image = Image.fromarray(depth_rgb)
 
333
  print("🎯 ControlNet: Erstelle Conditioning-Maps...")
334
 
335
  if keep_environment:
 
336
  print(" Modus: Depth + Canny")
337
  conditioning_images = [
338
  self.extract_depth_map(image),
339
  self.extract_canny_edges(image)
340
  ]
341
  else:
 
342
  print(" Modus: OpenPose + Canny")
343
  conditioning_images = [
344
  self.extract_pose(image),
 
346
  ]
347
 
348
  print(f"✅ {len(conditioning_images)} Conditioning-Maps erstellt.")
349
+ return conditioning_images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
 
352
  # Globale Instanz