Astridkraft committed on
Commit
d18cc2e
·
verified ·
1 Parent(s): 3744e15

Update controlnet_module.py

Browse files
Files changed (1) hide show
  1. controlnet_module.py +110 -85
controlnet_module.py CHANGED
@@ -1,9 +1,5 @@
1
  import torch
2
- from diffusers import (
3
- StableDiffusionControlNetPipeline,
4
- ControlNetModel,
5
- MultiControlNetModel,
6
- )
7
  from controlnet_aux import OpenposeDetector
8
  from PIL import Image
9
  import random
@@ -12,9 +8,6 @@ import numpy as np
12
  import gradio as gr
13
 
14
 
15
- # ============================================================
16
- # PROGRESS CALLBACK
17
- # ============================================================
18
  class ControlNetProgressCallback:
19
  def __init__(self, progress, total_steps):
20
  self.progress = progress
@@ -33,36 +26,33 @@ class ControlNetProgressCallback:
33
  return callback_kwargs
34
 
35
 
36
- # ============================================================
37
- # CONTROLNET PROZESSOR
38
- # ============================================================
39
  class ControlNetProcessor:
40
  def __init__(self, device="cuda", torch_dtype=torch.float32):
41
  self.device = device
42
  self.torch_dtype = torch_dtype
43
  self.pose_detector = None
44
- self.pipe_multi = None # Multi-ControlNet (OpenPose + Canny)
 
 
 
45
 
46
- # ------------------------------------------------------------
47
- # POSE DETECTOR
48
- # ------------------------------------------------------------
49
  def load_pose_detector(self):
50
  """Lädt nur den Pose-Detector"""
51
  if self.pose_detector is None:
52
- print("🧠 Lade Pose Detector...")
53
  try:
54
  self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
55
  except Exception as e:
56
- print(f"⚠️ Warnung: Pose-Detector konnte nicht geladen werden: {e}")
57
  return self.pose_detector
58
 
59
  def extract_pose_simple(self, image):
60
- """Fallback: Kantenbasierte Pose"""
61
  try:
62
  img_array = np.array(image.convert("RGB"))
63
  edges = cv2.Canny(img_array, 100, 200)
64
  pose_image = Image.fromarray(edges).convert("RGB")
65
- print("⚠️ Verwende einfache Kanten-Pose")
66
  return pose_image
67
  except Exception as e:
68
  print(f"Fehler bei einfacher Pose-Extraktion: {e}")
@@ -74,99 +64,135 @@ class ControlNetProcessor:
74
  detector = self.load_pose_detector()
75
  if detector is None:
76
  return self.extract_pose_simple(image)
 
77
  pose_image = detector(image, hand_and_face=True)
78
- print("✅ Pose-Map erfolgreich extrahiert")
79
  return pose_image
80
  except Exception as e:
81
  print(f"Fehler bei Pose-Extraktion: {e}")
82
  return self.extract_pose_simple(image)
83
 
84
- # ------------------------------------------------------------
85
- # CANNY EDGE
86
- # ------------------------------------------------------------
87
  def extract_canny_edges(self, image):
88
- """Extrahiert Canny-Kantenbild zur Umgebungserhaltung"""
89
  try:
90
  img_array = np.array(image.convert("RGB"))
 
 
91
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
92
  edges = cv2.Canny(gray, 100, 200)
 
 
93
  edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
94
  edges_image = Image.fromarray(edges_rgb)
95
- print("✅ Canny Edge Map erstellt")
 
96
  return edges_image
97
  except Exception as e:
98
- print(f"Fehler bei Canny-Extraktion: {e}")
99
  return image.convert("RGB").resize((512, 512))
100
 
101
- # ------------------------------------------------------------
102
- # PIPELINE-LADER
103
- # ------------------------------------------------------------
104
- def load_controlnet_pipeline(self):
105
- """Lädt kombinierte Multi-ControlNet Pipeline (OpenPose + Canny)"""
106
- if self.pipe_multi is None:
107
- print("🧩 Lade Multi-ControlNet Pipeline (OpenPose + Canny)...")
108
- try:
109
- controlnet_openpose = ControlNetModel.from_pretrained(
110
- "lllyasviel/sd-controlnet-openpose",
111
- torch_dtype=self.torch_dtype
112
- )
113
- controlnet_canny = ControlNetModel.from_pretrained(
114
- "lllyasviel/sd-controlnet-canny",
115
- torch_dtype=self.torch_dtype
116
- )
117
-
118
- multi_controlnet = MultiControlNetModel([controlnet_openpose, controlnet_canny])
119
-
120
- self.pipe_multi = StableDiffusionControlNetPipeline.from_pretrained(
121
- "runwayml/stable-diffusion-v1-5",
122
- controlnet=multi_controlnet,
123
- torch_dtype=self.torch_dtype,
124
- safety_checker=None,
125
- requires_safety_checker=False
126
- ).to(self.device)
127
-
128
- from diffusers import EulerAncestralDiscreteScheduler
129
- self.pipe_multi.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe_multi.scheduler.config)
130
- self.pipe_multi.enable_attention_slicing()
131
-
132
- print("✅ Multi-ControlNet Pipeline erfolgreich geladen!")
133
- except Exception as e:
134
- print(f"❌ Fehler beim Laden von Multi-ControlNet: {e}")
135
- raise
136
- return self.pipe_multi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- # ------------------------------------------------------------
139
- # GENERIERUNG
140
- # ------------------------------------------------------------
141
  def generate_with_controlnet(
142
  self, image, prompt, negative_prompt,
143
  steps, guidance_scale, controlnet_strength,
144
- progress=None
145
  ):
146
- """Generiert Bild mit OpenPose + Canny Kombination"""
147
  try:
148
- print("🎯 Modus: Kombiniert (OpenPose + Canny + Inpaint)")
149
-
150
- pose_map = self.extract_pose(image)
151
- canny_map = self.extract_canny_edges(image)
152
- pipe = self.load_controlnet_pipeline()
153
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  seed = random.randint(0, 2**32 - 1)
155
  generator = torch.Generator(device=self.device).manual_seed(seed)
156
  print(f"ControlNet Seed: {seed}")
157
 
 
158
  callback = ControlNetProgressCallback(progress, int(steps)) if progress is not None else None
159
 
160
- print("🚀 Starte kombinierte ControlNet-Pipeline...")
161
 
 
162
  result = pipe(
163
  prompt=prompt,
164
- image=[pose_map, canny_map], # ← Beide Steuerbilder!
165
  negative_prompt=negative_prompt,
166
  num_inference_steps=int(steps),
167
  guidance_scale=guidance_scale,
168
- controlnet_conditioning_scale=[controlnet_strength, controlnet_strength * 0.7],
169
  generator=generator,
 
170
  height=512,
171
  width=512,
172
  output_type="pil",
@@ -174,36 +200,35 @@ class ControlNetProcessor:
174
  callback_on_step_end_tensor_inputs=[],
175
  )
176
 
177
- print("✅ Multi-ControlNet abgeschlossen!")
178
- return result.images[0], image # Für Inpaint weitergeben
 
 
179
 
180
  except Exception as e:
181
- print(f"❌ Fehler in Multi-ControlNet: {e}")
182
  import traceback
183
  traceback.print_exc()
184
  error_image = image.convert("RGB").resize((512, 512))
185
  return error_image, error_image
186
 
187
- # ------------------------------------------------------------
188
- # INPAINT-VORBEREITUNG
189
- # ------------------------------------------------------------
190
  def prepare_inpaint_input(self, image, keep_environment=False):
191
  """
192
  Bereitet das Input-Bild für Inpaint vor
193
  Rückgabe: (image_für_inpaint, conditioning_info)
194
  """
195
  if keep_environment:
 
196
  print("🎯 Inpaint: Übergebe Originalbild (Person ändern)")
197
  return image, {"type": "original", "image": image}
198
  else:
 
199
  print("🎯 Inpaint: Übergebe Pose-Map (Umgebung ändern)")
200
  pose_image = self.extract_pose(image)
201
  return pose_image, {"type": "pose", "image": pose_image}
202
 
203
 
204
- # ============================================================
205
- # GLOBALE INSTANZ
206
- # ============================================================
207
  device = "cuda" if torch.cuda.is_available() else "cpu"
208
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
209
  controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)
 
1
  import torch
2
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 
 
 
 
3
  from controlnet_aux import OpenposeDetector
4
  from PIL import Image
5
  import random
 
8
  import gradio as gr
9
 
10
 
 
 
 
11
  class ControlNetProgressCallback:
12
  def __init__(self, progress, total_steps):
13
  self.progress = progress
 
26
  return callback_kwargs
27
 
28
 
 
 
 
29
  class ControlNetProcessor:
30
  def __init__(self, device="cuda", torch_dtype=torch.float32):
31
  self.device = device
32
  self.torch_dtype = torch_dtype
33
  self.pose_detector = None
34
+ self.controlnet_openpose = None
35
+ self.controlnet_canny = None
36
+ self.pipe_openpose = None
37
+ self.pipe_canny = None
38
 
 
 
 
39
  def load_pose_detector(self):
40
  """Lädt nur den Pose-Detector"""
41
  if self.pose_detector is None:
42
+ print("Loading Pose Detector...")
43
  try:
44
  self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
45
  except Exception as e:
46
+ print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
47
  return self.pose_detector
48
 
49
  def extract_pose_simple(self, image):
50
+ """Einfache Pose-Extraktion ohne komplexe Abhängigkeiten"""
51
  try:
52
  img_array = np.array(image.convert("RGB"))
53
  edges = cv2.Canny(img_array, 100, 200)
54
  pose_image = Image.fromarray(edges).convert("RGB")
55
+ print("⚠️ Verwende Kanten-basierte Pose-Approximation")
56
  return pose_image
57
  except Exception as e:
58
  print(f"Fehler bei einfacher Pose-Extraktion: {e}")
 
64
  detector = self.load_pose_detector()
65
  if detector is None:
66
  return self.extract_pose_simple(image)
67
+
68
  pose_image = detector(image, hand_and_face=True)
 
69
  return pose_image
70
  except Exception as e:
71
  print(f"Fehler bei Pose-Extraktion: {e}")
72
  return self.extract_pose_simple(image)
73
 
 
 
 
74
  def extract_canny_edges(self, image):
75
+ """Extrahiert Canny Edges für Umgebungserhaltung"""
76
  try:
77
  img_array = np.array(image.convert("RGB"))
78
+
79
+ # Canny Edge Detection
80
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
81
  edges = cv2.Canny(gray, 100, 200)
82
+
83
+ # Zu 3-Kanal Bild konvertieren
84
  edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
85
  edges_image = Image.fromarray(edges_rgb)
86
+
87
+ print("✅ Canny Edge für Umgebungserhaltung erstellt")
88
  return edges_image
89
  except Exception as e:
90
+ print(f"Fehler bei Canny Edge Extraction: {e}")
91
  return image.convert("RGB").resize((512, 512))
92
 
93
+ def load_controlnet_pipeline(self, controlnet_type="openpose"):
94
+ """Lädt die passende ControlNet Pipeline"""
95
+ if controlnet_type == "openpose":
96
+ if self.pipe_openpose is None:
97
+ print("Loading OpenPose ControlNet pipeline...")
98
+ try:
99
+ self.controlnet_openpose = ControlNetModel.from_pretrained(
100
+ "lllyasviel/sd-controlnet-openpose",
101
+ torch_dtype=self.torch_dtype
102
+ )
103
+ self.pipe_openpose = StableDiffusionControlNetPipeline.from_pretrained(
104
+ "runwayml/stable-diffusion-v1-5",
105
+ controlnet=self.controlnet_openpose,
106
+ torch_dtype=self.torch_dtype,
107
+ safety_checker=None,
108
+ requires_safety_checker=False
109
+ ).to(self.device)
110
+
111
+ from diffusers import EulerAncestralDiscreteScheduler
112
+ self.pipe_openpose.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe_openpose.scheduler.config)
113
+ self.pipe_openpose.enable_attention_slicing()
114
+ print("✅ OpenPose ControlNet pipeline loaded successfully!")
115
+ except Exception as e:
116
+ print(f"Fehler beim Laden von OpenPose ControlNet: {e}")
117
+ raise
118
+ return self.pipe_openpose
119
+
120
+ elif controlnet_type == "canny":
121
+ if self.pipe_canny is None:
122
+ print("Loading Canny ControlNet pipeline...")
123
+ try:
124
+ self.controlnet_canny = ControlNetModel.from_pretrained(
125
+ "lllyasviel/sd-controlnet-canny",
126
+ torch_dtype=self.torch_dtype
127
+ )
128
+ self.pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
129
+ "runwayml/stable-diffusion-v1-5",
130
+ controlnet=self.controlnet_canny,
131
+ torch_dtype=self.torch_dtype,
132
+ safety_checker=None,
133
+ requires_safety_checker=False
134
+ ).to(self.device)
135
+
136
+ from diffusers import EulerAncestralDiscreteScheduler
137
+ self.pipe_canny.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe_canny.scheduler.config)
138
+ self.pipe_canny.enable_attention_slicing()
139
+ print("✅ Canny ControlNet pipeline loaded successfully!")
140
+ except Exception as e:
141
+ print(f"Fehler beim Laden von Canny ControlNet: {e}")
142
+ raise
143
+ return self.pipe_canny
144
 
 
 
 
145
  def generate_with_controlnet(
146
  self, image, prompt, negative_prompt,
147
  steps, guidance_scale, controlnet_strength,
148
+ progress=None, keep_environment=False
149
  ):
150
+ """Generiert Bild mit ControlNet und Fortschrittsanzeige"""
151
  try:
152
+ # --- KORREKTE LOGIK ---
153
+ if keep_environment:
154
+ # UMGEBUNG BEIBEHALTEN, PERSON ÄNDERN → KOMBINIERTE STRATEGIE
155
+ print("🎯 ControlNet Modus: Umgebung beibehalten (OpenPose + Canny Kombination)")
156
+
157
+ # Schritt 1: OpenPose für Grundpose
158
+ pose_image = self.extract_pose(image)
159
+ print("✅ OpenPose für Grundpose erstellt")
160
+
161
+ # Schritt 2: Canny für Silhouette + Umgebung
162
+ canny_image = self.extract_canny_edges(image)
163
+ print("✅ Canny für Silhouette + Umgebung erstellt")
164
+
165
+ # Kombinierte Conditioning - zuerst mit OpenPose arbeiten
166
+ conditioning_image = pose_image
167
+ controlnet_type = "openpose"
168
+
169
+ else:
170
+ # PERSON BEIBEHALTEN, UMGEBUNG ÄNDERN → NUR OPENPOSE (wie bisher)
171
+ controlnet_type = "openpose"
172
+ print("🎯 ControlNet Modus: Person beibehalten (OpenPose)")
173
+ conditioning_image = self.extract_pose(image)
174
+
175
+ pipe = self.load_controlnet_pipeline(controlnet_type)
176
+
177
+ # Zufälliger Seed
178
  seed = random.randint(0, 2**32 - 1)
179
  generator = torch.Generator(device=self.device).manual_seed(seed)
180
  print(f"ControlNet Seed: {seed}")
181
 
182
+ # Fortschritt-Callback
183
  callback = ControlNetProgressCallback(progress, int(steps)) if progress is not None else None
184
 
185
+ print("🔄 ControlNet: Starte Pipeline...")
186
 
187
+ # ControlNet Generierung
188
  result = pipe(
189
  prompt=prompt,
190
+ image=conditioning_image,
191
  negative_prompt=negative_prompt,
192
  num_inference_steps=int(steps),
193
  guidance_scale=guidance_scale,
 
194
  generator=generator,
195
+ controlnet_conditioning_scale=controlnet_strength,
196
  height=512,
197
  width=512,
198
  output_type="pil",
 
200
  callback_on_step_end_tensor_inputs=[],
201
  )
202
 
203
+ print("✅ ControlNet abgeschlossen!")
204
+
205
+ # ZWEI Werte zurückgeben: ControlNet-Output + ORIGINALBILD für Inpaint
206
+ return result.images[0], image # ← IMMER Originalbild für Inpaint!
207
 
208
  except Exception as e:
209
+ print(f"❌ Fehler in ControlNet: {e}")
210
  import traceback
211
  traceback.print_exc()
212
  error_image = image.convert("RGB").resize((512, 512))
213
  return error_image, error_image
214
 
 
 
 
215
  def prepare_inpaint_input(self, image, keep_environment=False):
216
  """
217
  Bereitet das Input-Bild für Inpaint vor
218
  Rückgabe: (image_für_inpaint, conditioning_info)
219
  """
220
  if keep_environment:
221
+ # PERSON ÄNDERN: Originalbild an Inpaint übergeben
222
  print("🎯 Inpaint: Übergebe Originalbild (Person ändern)")
223
  return image, {"type": "original", "image": image}
224
  else:
225
+ # UMGEBUNG ÄNDERN: Pose-Map an Inpaint übergeben
226
  print("🎯 Inpaint: Übergebe Pose-Map (Umgebung ändern)")
227
  pose_image = self.extract_pose(image)
228
  return pose_image, {"type": "pose", "image": pose_image}
229
 
230
 
231
# Module-level processor instance: CUDA with float16 when a GPU is
# available, otherwise CPU with float32.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)