Astridkraft commited on
Commit
421d3c4
·
verified ·
1 Parent(s): ae7e31f

Update controlnet_module.py

Browse files
Files changed (1) hide show
  1. controlnet_module.py +85 -99
controlnet_module.py CHANGED
@@ -1,5 +1,9 @@
1
  import torch
2
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 
 
 
 
3
  from controlnet_aux import OpenposeDetector
4
  from PIL import Image
5
  import random
@@ -8,6 +12,9 @@ import numpy as np
8
  import gradio as gr
9
 
10
 
 
 
 
11
  class ControlNetProgressCallback:
12
  def __init__(self, progress, total_steps):
13
  self.progress = progress
@@ -26,33 +33,36 @@ class ControlNetProgressCallback:
26
  return callback_kwargs
27
 
28
 
 
 
 
29
  class ControlNetProcessor:
30
  def __init__(self, device="cuda", torch_dtype=torch.float32):
31
  self.device = device
32
  self.torch_dtype = torch_dtype
33
  self.pose_detector = None
34
- self.controlnet_openpose = None
35
- self.controlnet_canny = None
36
- self.pipe_openpose = None
37
- self.pipe_canny = None
38
 
 
 
 
39
  def load_pose_detector(self):
40
  """Lädt nur den Pose-Detector"""
41
  if self.pose_detector is None:
42
- print("Loading Pose Detector...")
43
  try:
44
  self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
45
  except Exception as e:
46
- print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
47
  return self.pose_detector
48
 
49
  def extract_pose_simple(self, image):
50
- """Einfache Pose-Extraktion ohne komplexe Abhängigkeiten"""
51
  try:
52
  img_array = np.array(image.convert("RGB"))
53
  edges = cv2.Canny(img_array, 100, 200)
54
  pose_image = Image.fromarray(edges).convert("RGB")
55
- print("⚠️ Verwende Kanten-basierte Pose-Approximation")
56
  return pose_image
57
  except Exception as e:
58
  print(f"Fehler bei einfacher Pose-Extraktion: {e}")
@@ -64,124 +74,99 @@ class ControlNetProcessor:
64
  detector = self.load_pose_detector()
65
  if detector is None:
66
  return self.extract_pose_simple(image)
67
-
68
  pose_image = detector(image, hand_and_face=True)
 
69
  return pose_image
70
  except Exception as e:
71
  print(f"Fehler bei Pose-Extraktion: {e}")
72
  return self.extract_pose_simple(image)
73
 
 
 
 
74
  def extract_canny_edges(self, image):
75
- """Extrahiert Canny Edges für Umgebungserhaltung"""
76
  try:
77
  img_array = np.array(image.convert("RGB"))
78
-
79
- # Canny Edge Detection
80
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
81
  edges = cv2.Canny(gray, 100, 200)
82
-
83
- # Zu 3-Kanal Bild konvertieren
84
  edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
85
  edges_image = Image.fromarray(edges_rgb)
86
-
87
- print("✅ Canny Edge für Umgebungserhaltung erstellt")
88
  return edges_image
89
  except Exception as e:
90
- print(f"Fehler bei Canny Edge Extraction: {e}")
91
  return image.convert("RGB").resize((512, 512))
92
 
93
- def load_controlnet_pipeline(self, controlnet_type="openpose"):
94
- """Lädt die passende ControlNet Pipeline"""
95
- if controlnet_type == "openpose":
96
- if self.pipe_openpose is None:
97
- print("Loading OpenPose ControlNet pipeline...")
98
- try:
99
- self.controlnet_openpose = ControlNetModel.from_pretrained(
100
- "lllyasviel/sd-controlnet-openpose",
101
- torch_dtype=self.torch_dtype
102
- )
103
- self.pipe_openpose = StableDiffusionControlNetPipeline.from_pretrained(
104
- "runwayml/stable-diffusion-v1-5",
105
- controlnet=self.controlnet_openpose,
106
- torch_dtype=self.torch_dtype,
107
- safety_checker=None,
108
- requires_safety_checker=False
109
- ).to(self.device)
110
-
111
- from diffusers import EulerAncestralDiscreteScheduler
112
- self.pipe_openpose.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe_openpose.scheduler.config)
113
- self.pipe_openpose.enable_attention_slicing()
114
- print("✅ OpenPose ControlNet pipeline loaded successfully!")
115
- except Exception as e:
116
- print(f"Fehler beim Laden von OpenPose ControlNet: {e}")
117
- raise
118
- return self.pipe_openpose
119
-
120
- elif controlnet_type == "canny":
121
- if self.pipe_canny is None:
122
- print("Loading Canny ControlNet pipeline...")
123
- try:
124
- self.controlnet_canny = ControlNetModel.from_pretrained(
125
- "lllyasviel/sd-controlnet-canny",
126
- torch_dtype=self.torch_dtype
127
- )
128
- self.pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
129
- "runwayml/stable-diffusion-v1-5",
130
- controlnet=self.controlnet_canny,
131
- torch_dtype=self.torch_dtype,
132
- safety_checker=None,
133
- requires_safety_checker=False
134
- ).to(self.device)
135
-
136
- from diffusers import EulerAncestralDiscreteScheduler
137
- self.pipe_canny.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe_canny.scheduler.config)
138
- self.pipe_canny.enable_attention_slicing()
139
- print("✅ Canny ControlNet pipeline loaded successfully!")
140
- except Exception as e:
141
- print(f"Fehler beim Laden von Canny ControlNet: {e}")
142
- raise
143
- return self.pipe_canny
144
 
 
 
 
145
  def generate_with_controlnet(
146
  self, image, prompt, negative_prompt,
147
  steps, guidance_scale, controlnet_strength,
148
- progress=None, keep_environment=False
149
  ):
150
- """Generiert Bild mit ControlNet und Fortschrittsanzeige"""
151
  try:
152
- # --- KORREKTE LOGIK ---
153
- if keep_environment:
154
- # UMGEBUNG BEIBEHALTEN, PERSON ÄNDERN
155
- controlnet_type = "canny" # ✅ Canny behält Umgebung
156
- print("🎯 ControlNet Modus: Umgebung beibehalten (Canny Edge)")
157
- conditioning_image = self.extract_canny_edges(image)
158
- else:
159
- # PERSON BEIBEHALTEN, UMGEBUNG ÄNDERN
160
- controlnet_type = "openpose" # ✅ OpenPose behält Person
161
- print("🎯 ControlNet Modus: Person beibehalten (OpenPose)")
162
- conditioning_image = self.extract_pose(image)
163
-
164
- pipe = self.load_controlnet_pipeline(controlnet_type)
165
-
166
- # Zufälliger Seed
167
  seed = random.randint(0, 2**32 - 1)
168
  generator = torch.Generator(device=self.device).manual_seed(seed)
169
  print(f"ControlNet Seed: {seed}")
170
 
171
- # Fortschritt-Callback
172
  callback = ControlNetProgressCallback(progress, int(steps)) if progress is not None else None
173
 
174
- print("🔄 ControlNet: Starte Pipeline...")
175
 
176
- # ControlNet Generierung
177
  result = pipe(
178
  prompt=prompt,
179
- image=conditioning_image,
180
  negative_prompt=negative_prompt,
181
  num_inference_steps=int(steps),
182
  guidance_scale=guidance_scale,
 
183
  generator=generator,
184
- controlnet_conditioning_scale=controlnet_strength,
185
  height=512,
186
  width=512,
187
  output_type="pil",
@@ -189,35 +174,36 @@ class ControlNetProcessor:
189
  callback_on_step_end_tensor_inputs=[],
190
  )
191
 
192
- print("✅ ControlNet abgeschlossen!")
193
-
194
- # ZWEI Werte zurückgeben: ControlNet-Output + ORIGINALBILD für Inpaint
195
- return result.images[0], image # ← IMMER Originalbild für Inpaint!
196
 
197
  except Exception as e:
198
- print(f"❌ Fehler in ControlNet: {e}")
199
  import traceback
200
  traceback.print_exc()
201
  error_image = image.convert("RGB").resize((512, 512))
202
  return error_image, error_image
203
 
 
 
 
204
  def prepare_inpaint_input(self, image, keep_environment=False):
205
  """
206
  Bereitet das Input-Bild für Inpaint vor
207
  Rückgabe: (image_für_inpaint, conditioning_info)
208
  """
209
  if keep_environment:
210
- # PERSON ÄNDERN: Originalbild an Inpaint übergeben
211
  print("🎯 Inpaint: Übergebe Originalbild (Person ändern)")
212
  return image, {"type": "original", "image": image}
213
  else:
214
- # UMGEBUNG ÄNDERN: Pose-Map an Inpaint übergeben
215
  print("🎯 Inpaint: Übergebe Pose-Map (Umgebung ändern)")
216
  pose_image = self.extract_pose(image)
217
  return pose_image, {"type": "pose", "image": pose_image}
218
 
219
 
220
- # Globale Instanz
 
 
221
  device = "cuda" if torch.cuda.is_available() else "cpu"
222
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
223
  controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)
 
1
  import torch
2
+ from diffusers import (
3
+ StableDiffusionControlNetPipeline,
4
+ ControlNetModel,
5
+ MultiControlNetModel,
6
+ )
7
  from controlnet_aux import OpenposeDetector
8
  from PIL import Image
9
  import random
 
12
  import gradio as gr
13
 
14
 
15
+ # ============================================================
16
+ # PROGRESS CALLBACK
17
+ # ============================================================
18
  class ControlNetProgressCallback:
19
  def __init__(self, progress, total_steps):
20
  self.progress = progress
 
33
  return callback_kwargs
34
 
35
 
36
+ # ============================================================
37
+ # CONTROLNET PROZESSOR
38
+ # ============================================================
39
  class ControlNetProcessor:
40
  def __init__(self, device="cuda", torch_dtype=torch.float32):
41
  self.device = device
42
  self.torch_dtype = torch_dtype
43
  self.pose_detector = None
44
+ self.pipe_multi = None # Multi-ControlNet (OpenPose + Canny)
 
 
 
45
 
46
+ # ------------------------------------------------------------
47
+ # POSE DETECTOR
48
+ # ------------------------------------------------------------
49
  def load_pose_detector(self):
50
  """Lädt nur den Pose-Detector"""
51
  if self.pose_detector is None:
52
+ print("🧠 Lade Pose Detector...")
53
  try:
54
  self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
55
  except Exception as e:
56
+ print(f"⚠️ Warnung: Pose-Detector konnte nicht geladen werden: {e}")
57
  return self.pose_detector
58
 
59
  def extract_pose_simple(self, image):
60
+ """Fallback: Kantenbasierte Pose"""
61
  try:
62
  img_array = np.array(image.convert("RGB"))
63
  edges = cv2.Canny(img_array, 100, 200)
64
  pose_image = Image.fromarray(edges).convert("RGB")
65
+ print("⚠️ Verwende einfache Kanten-Pose")
66
  return pose_image
67
  except Exception as e:
68
  print(f"Fehler bei einfacher Pose-Extraktion: {e}")
 
74
  detector = self.load_pose_detector()
75
  if detector is None:
76
  return self.extract_pose_simple(image)
 
77
  pose_image = detector(image, hand_and_face=True)
78
+ print("✅ Pose-Map erfolgreich extrahiert")
79
  return pose_image
80
  except Exception as e:
81
  print(f"Fehler bei Pose-Extraktion: {e}")
82
  return self.extract_pose_simple(image)
83
 
84
+ # ------------------------------------------------------------
85
+ # CANNY EDGE
86
+ # ------------------------------------------------------------
87
  def extract_canny_edges(self, image):
88
+ """Extrahiert Canny-Kantenbild zur Umgebungserhaltung"""
89
  try:
90
  img_array = np.array(image.convert("RGB"))
 
 
91
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
92
  edges = cv2.Canny(gray, 100, 200)
 
 
93
  edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
94
  edges_image = Image.fromarray(edges_rgb)
95
+ print("✅ Canny Edge Map erstellt")
 
96
  return edges_image
97
  except Exception as e:
98
+ print(f"Fehler bei Canny-Extraktion: {e}")
99
  return image.convert("RGB").resize((512, 512))
100
 
101
+ # ------------------------------------------------------------
102
+ # PIPELINE-LADER
103
+ # ------------------------------------------------------------
104
+ def load_controlnet_pipeline(self):
105
+ """Lädt kombinierte Multi-ControlNet Pipeline (OpenPose + Canny)"""
106
+ if self.pipe_multi is None:
107
+ print("🧩 Lade Multi-ControlNet Pipeline (OpenPose + Canny)...")
108
+ try:
109
+ controlnet_openpose = ControlNetModel.from_pretrained(
110
+ "lllyasviel/sd-controlnet-openpose",
111
+ torch_dtype=self.torch_dtype
112
+ )
113
+ controlnet_canny = ControlNetModel.from_pretrained(
114
+ "lllyasviel/sd-controlnet-canny",
115
+ torch_dtype=self.torch_dtype
116
+ )
117
+
118
+ multi_controlnet = MultiControlNetModel([controlnet_openpose, controlnet_canny])
119
+
120
+ self.pipe_multi = StableDiffusionControlNetPipeline.from_pretrained(
121
+ "runwayml/stable-diffusion-v1-5",
122
+ controlnet=multi_controlnet,
123
+ torch_dtype=self.torch_dtype,
124
+ safety_checker=None,
125
+ requires_safety_checker=False
126
+ ).to(self.device)
127
+
128
+ from diffusers import EulerAncestralDiscreteScheduler
129
+ self.pipe_multi.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe_multi.scheduler.config)
130
+ self.pipe_multi.enable_attention_slicing()
131
+
132
+ print("✅ Multi-ControlNet Pipeline erfolgreich geladen!")
133
+ except Exception as e:
134
+ print(f"❌ Fehler beim Laden von Multi-ControlNet: {e}")
135
+ raise
136
+ return self.pipe_multi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # ------------------------------------------------------------
139
+ # GENERIERUNG
140
+ # ------------------------------------------------------------
141
  def generate_with_controlnet(
142
  self, image, prompt, negative_prompt,
143
  steps, guidance_scale, controlnet_strength,
144
+ progress=None
145
  ):
146
+ """Generiert Bild mit OpenPose + Canny Kombination"""
147
  try:
148
+ print("🎯 Modus: Kombiniert (OpenPose + Canny + Inpaint)")
149
+
150
+ pose_map = self.extract_pose(image)
151
+ canny_map = self.extract_canny_edges(image)
152
+ pipe = self.load_controlnet_pipeline()
153
+
 
 
 
 
 
 
 
 
 
154
  seed = random.randint(0, 2**32 - 1)
155
  generator = torch.Generator(device=self.device).manual_seed(seed)
156
  print(f"ControlNet Seed: {seed}")
157
 
 
158
  callback = ControlNetProgressCallback(progress, int(steps)) if progress is not None else None
159
 
160
+ print("🚀 Starte kombinierte ControlNet-Pipeline...")
161
 
 
162
  result = pipe(
163
  prompt=prompt,
164
+ image=[pose_map, canny_map], # ← Beide Steuerbilder!
165
  negative_prompt=negative_prompt,
166
  num_inference_steps=int(steps),
167
  guidance_scale=guidance_scale,
168
+ controlnet_conditioning_scale=[controlnet_strength, controlnet_strength * 0.7],
169
  generator=generator,
 
170
  height=512,
171
  width=512,
172
  output_type="pil",
 
174
  callback_on_step_end_tensor_inputs=[],
175
  )
176
 
177
+ print("✅ Multi-ControlNet abgeschlossen!")
178
+ return result.images[0], image # Für Inpaint weitergeben
 
 
179
 
180
  except Exception as e:
181
+ print(f"❌ Fehler in Multi-ControlNet: {e}")
182
  import traceback
183
  traceback.print_exc()
184
  error_image = image.convert("RGB").resize((512, 512))
185
  return error_image, error_image
186
 
187
+ # ------------------------------------------------------------
188
+ # INPAINT-VORBEREITUNG
189
+ # ------------------------------------------------------------
190
  def prepare_inpaint_input(self, image, keep_environment=False):
191
  """
192
  Bereitet das Input-Bild für Inpaint vor
193
  Rückgabe: (image_für_inpaint, conditioning_info)
194
  """
195
  if keep_environment:
 
196
  print("🎯 Inpaint: Übergebe Originalbild (Person ändern)")
197
  return image, {"type": "original", "image": image}
198
  else:
 
199
  print("🎯 Inpaint: Übergebe Pose-Map (Umgebung ändern)")
200
  pose_image = self.extract_pose(image)
201
  return pose_image, {"type": "pose", "image": pose_image}
202
 
203
 
204
+ # ============================================================
205
+ # GLOBALE INSTANZ
206
+ # ============================================================
207
  device = "cuda" if torch.cuda.is_available() else "cpu"
208
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
209
  controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)