Astridkraft committed on
Commit
5a39fae
·
verified ·
1 Parent(s): 2a0a395

Update controlnet_module.py

Browse files
Files changed (1) hide show
  1. controlnet_module.py +99 -90
controlnet_module.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
3
  from controlnet_aux import OpenposeDetector
4
  from PIL import Image
5
  import random
@@ -8,8 +8,6 @@ import numpy as np
8
  import gradio as gr
9
 
10
 
11
-
12
-
13
  class ControlNetProgressCallback:
14
  def __init__(self, progress, total_steps):
15
  self.progress = progress
@@ -20,7 +18,6 @@ class ControlNetProgressCallback:
20
  self.current_step = step_index + 1
21
  progress_percentage = self.current_step / self.total_steps
22
 
23
- # Fortschritt aktualisieren
24
  if self.progress is not None:
25
  self.progress(progress_percentage, desc=f"ControlNet: Schritt {self.current_step}/{self.total_steps}")
26
 
@@ -33,24 +30,52 @@ class ControlNetProcessor:
33
  self.device = device
34
  self.torch_dtype = torch_dtype
35
  self.pose_detector = None
36
- self.controlnet_openpose = None
37
- self.controlnet_canny = None
38
- self.controlnet_depth = None
39
- self.pipe_openpose = None
40
- self.pipe_canny = None
41
- self.pipe_depth = None
42
- self.pipe_multi_inside = None # OpenPose + Canny für Inside-Box
43
- self.pipe_multi_outside = None # Depth + Canny für Outside-Box
44
-
45
  def load_pose_detector(self):
46
  """Lädt nur den Pose-Detector"""
47
  if self.pose_detector is None:
48
  print("Loading Pose Detector...")
49
  try:
50
  self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
 
51
  except Exception as e:
52
- print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
53
  return self.pose_detector
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def extract_pose_simple(self, image):
56
  """Einfache Pose-Extraktion ohne komplexe Abhängigkeiten"""
@@ -96,76 +121,67 @@ class ControlNetProcessor:
96
  print(f"Fehler bei Canny Edge Extraction: {e}")
97
  return image.convert("RGB").resize((512, 512))
98
 
99
-
100
  def extract_depth_map(self, image):
101
- """
102
- Extrahiert Depth Map mit MiDaS Small (Fallback auf alten Filter).
103
- """
104
- try:
105
- print("🔄 Versuche MiDaS Small für Depth Map...")
106
- # 1. MiDaS Modelle vor dem ersten Gebrauch laden (spart VRAM)
107
- if not hasattr(self, 'midas_model'):
108
- from torchvision.transforms import Compose, Resize, ToTensor, Normalize
109
- import midas
110
-
111
- self.midas_transform = Compose([
112
- Resize(384, interpolation=midas.utils.interpolation),
113
- ToTensor(),
114
- Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
115
- ])
116
-
117
- # WICHTIG: MiDaS Small lädt automatisch die 'small'-Variante (weniger VRAM)
118
- self.midas_model = midas.MiDaS()
119
- self.midas_model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- if self.device == 'cuda':
122
- self.midas_model.to(self.device)
123
- print("✅ MiDaS Small Modell geladen (GPU)")
124
  else:
125
- print("✅ MiDaS Small Modell geladen (CPU)")
126
-
127
- # 2. Bild für MiDaS vorbereiten
128
- img_input = self.midas_transform(image).unsqueeze(0).to(self.device)
129
-
130
- # 3. Depth Map berechnen
131
- with torch.no_grad():
132
- prediction = self.midas_model(img_input)
133
- prediction = torch.nn.functional.interpolate(
134
- prediction.unsqueeze(1),
135
- size=image.size[::-1], # (height, width)
136
- mode="bicubic",
137
- align_corners=False,
138
- ).squeeze()
139
-
140
- # 4. Normalisierung für sichtbare Ausgabe
141
- depth_np = prediction.cpu().numpy()
142
- depth_min, depth_max = depth_np.min(), depth_np.max()
143
- if depth_max > depth_min:
144
- depth_np = (depth_np - depth_min) / (depth_max - depth_min)
145
- depth_np = (depth_np * 255).astype(np.uint8)
146
-
147
- depth_image = Image.fromarray(depth_np).convert("RGB")
148
- print("✅ MiDaS Depth Map erfolgreich erstellt")
149
- return depth_image
150
-
151
- except Exception as e:
152
- print(f"⚠️ MiDaS Fehler: {e}. Verwende Fallback (Grayscale Filter)...")
153
- # Fallback auf Ihren bestehenden Filter-Code
154
- try:
155
- img_array = np.array(image.convert("RGB"))
156
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
157
-
158
- # Depth-ähnliche Map erstellen (helle Bereiche = nah, dunkle = fern)
159
- depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
160
- depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
161
- depth_image = Image.fromarray(depth_rgb)
162
- print("✅ Fallback Depth Map erstellt")
163
- return depth_image
164
- except Exception as fallback_error:
165
- print(f"❌ Auch Fallback fehlgeschlagen: {fallback_error}")
166
- return image.convert("RGB").resize((512, 512))
167
-
168
-
169
 
170
  def prepare_controlnet_maps(self, image, keep_environment=False):
171
  """
@@ -195,22 +211,15 @@ class ControlNetProcessor:
195
  def prepare_inpaint_input(self, image, keep_environment=False):
196
  """
197
  Bereitet das Input-Bild für Inpaint vor
198
- Rückgabe: (image_für_inpaint, conditioning_info)
199
-
200
- HINWEIS: Diese Funktion wird nicht direkt von app.py verwendet,
201
- da die Logik in generate_with_controlnet enthalten ist.
202
  """
203
  if keep_environment:
204
- # OUTSIDE-BOX ÄNDERN: Depth+Canny Info für Umgebung
205
- print("🎯 Inpaint: Übergebe Depth+Canny Info (Outside-Box ändern)")
206
  depth_image = self.extract_depth_map(image)
207
  canny_image = self.extract_canny_edges(image)
208
- # Für Inpaint kann eine kombinierte Map verwendet werden
209
  combined_map = Image.blend(depth_image.convert("RGB"), canny_image.convert("RGB"), alpha=0.5)
210
  return combined_map, {"type": "depth_canny", "image": combined_map}
211
  else:
212
- # INSIDE-BOX ÄNDERN: Originalbild an Inpaint übergeben
213
- print("🎯 Inpaint: Übergebe Originalbild (Inside-Box ändern)")
214
  return image, {"type": "original", "image": image}
215
 
216
 
 
1
  import torch
2
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel # <- KORREKT!
3
  from controlnet_aux import OpenposeDetector
4
  from PIL import Image
5
  import random
 
8
  import gradio as gr
9
 
10
 
 
 
11
  class ControlNetProgressCallback:
12
  def __init__(self, progress, total_steps):
13
  self.progress = progress
 
18
  self.current_step = step_index + 1
19
  progress_percentage = self.current_step / self.total_steps
20
 
 
21
  if self.progress is not None:
22
  self.progress(progress_percentage, desc=f"ControlNet: Schritt {self.current_step}/{self.total_steps}")
23
 
 
30
  self.device = device
31
  self.torch_dtype = torch_dtype
32
  self.pose_detector = None
33
+ self.midas_model = None
34
+ self.midas_transform = None
35
+
 
 
 
 
 
 
36
  def load_pose_detector(self):
37
  """Lädt nur den Pose-Detector"""
38
  if self.pose_detector is None:
39
  print("Loading Pose Detector...")
40
  try:
41
  self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
42
+ print("✅ Pose-Detector geladen")
43
  except Exception as e:
44
+ print(f"⚠️ Pose-Detector konnte nicht geladen werden: {e}")
45
  return self.pose_detector
46
+
47
def load_midas_model(self):
    """Return the MiDaS depth model, loading it on first use.

    Fetches the ``DPT_Hybrid`` entry point of the ``intel-isl/MiDaS`` repo
    through ``torch.hub``, moves it to ``self.device``, switches it to eval
    mode and builds the preprocessing pipeline stored in
    ``self.midas_transform``.

    Returns:
        The cached model, or ``None`` when loading failed. A failure is
        only printed, so the next call attempts the load again.
    """
    if self.midas_model is not None:
        return self.midas_model

    print("🔄 Lade MiDaS Modell für Depth Maps...")
    try:
        # torchvision only supplies the preprocessing transforms here; the
        # network itself comes from torch.hub.
        from torchvision import transforms

        network = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid", trust_repo=True)
        network.to(self.device)
        network.eval()

        # Resize the shorter side to 384, then scale pixels to [-1, 1].
        preprocess = transforms.Compose([
            transforms.Resize(384),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
    except Exception as load_error:
        print(f"❌ MiDaS konnte nicht geladen werden: {load_error}")
        print("ℹ️ Verwende Fallback-Methode")
        self.midas_model = None
    else:
        # Publish model and transform together, only after both are ready.
        self.midas_model = network
        self.midas_transform = preprocess
        print("✅ MiDaS Modell erfolgreich geladen")

    return self.midas_model
79
 
80
  def extract_pose_simple(self, image):
81
  """Einfache Pose-Extraktion ohne komplexe Abhängigkeiten"""
 
121
  print(f"Fehler bei Canny Edge Extraction: {e}")
122
  return image.convert("RGB").resize((512, 512))
123
 
 
124
  def extract_depth_map(self, image):
125
+ """
126
+ Extrahiert Depth Map mit MiDaS (Fallback auf Filter)
127
+ """
128
+ try:
129
+ # Versuche MiDaS
130
+ midas = self.load_midas_model()
131
+ if midas is not None:
132
+ print("🎯 Verwende MiDaS für Depth Map...")
133
+
134
+ import torchvision.transforms as T
135
+ from PIL import Image
136
+
137
+ # Bild vorbereiten
138
+ img_transformed = self.midas_transform(image).unsqueeze(0).to(self.device)
139
+
140
+ # Depth Map berechnen
141
+ with torch.no_grad():
142
+ prediction = midas(img_transformed)
143
+ prediction = torch.nn.functional.interpolate(
144
+ prediction.unsqueeze(1),
145
+ size=image.size[::-1], # (height, width)
146
+ mode="bicubic",
147
+ align_corners=False,
148
+ ).squeeze()
149
+
150
+ # Normalisieren für Ausgabe
151
+ depth_np = prediction.cpu().numpy()
152
+ depth_min, depth_max = depth_np.min(), depth_np.max()
153
+
154
+ if depth_max > depth_min:
155
+ depth_np = (depth_np - depth_min) / (depth_max - depth_min)
156
+
157
+ depth_np = (depth_np * 255).astype(np.uint8)
158
+ depth_image = Image.fromarray(depth_np).convert("RGB")
159
+
160
+ print("✅ MiDaS Depth Map erfolgreich erstellt")
161
+ return depth_image
162
 
 
 
 
163
  else:
164
+ # Fallback auf einfache Methode
165
+ print("⚠️ MiDaS nicht verfügbar, verwende Fallback...")
166
+ raise Exception("MiDaS nicht geladen")
167
+
168
+ except Exception as e:
169
+ print(f"⚠️ MiDaS Fehler: {e}. Verwende Fallback...")
170
+ # Fallback auf einfache Depth Map
171
+ try:
172
+ img_array = np.array(image.convert("RGB"))
173
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
174
+
175
+ # Depth-ähnliche Map erstellen
176
+ depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
177
+ depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
178
+ depth_image = Image.fromarray(depth_rgb)
179
+
180
+ print("✅ Fallback Depth Map erstellt")
181
+ return depth_image
182
+ except Exception as fallback_error:
183
+ print(f"❌ Auch Fallback fehlgeschlagen: {fallback_error}")
184
+ return image.convert("RGB").resize((512, 512))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  def prepare_controlnet_maps(self, image, keep_environment=False):
187
  """
 
211
  def prepare_inpaint_input(self, image, keep_environment=False):
212
  """
213
  Bereitet das Input-Bild für Inpaint vor
 
 
 
 
214
  """
215
  if keep_environment:
216
+ print("🎯 Inpaint: Depth+Canny Info (Outside-Box ändern)")
 
217
  depth_image = self.extract_depth_map(image)
218
  canny_image = self.extract_canny_edges(image)
 
219
  combined_map = Image.blend(depth_image.convert("RGB"), canny_image.convert("RGB"), alpha=0.5)
220
  return combined_map, {"type": "depth_canny", "image": combined_map}
221
  else:
222
+ print("🎯 Inpaint: Originalbild (Inside-Box ändern)")
 
223
  return image, {"type": "original", "image": image}
224
 
225