Astridkraft commited on
Commit
7bbb34d
·
verified ·
1 Parent(s): edf1b1a

Update controlnet_module.py

Browse files
Files changed (1) hide show
  1. controlnet_module.py +122 -77
controlnet_module.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
3
  from controlnet_aux import OpenposeDetector
@@ -31,8 +32,10 @@ class ControlNetProcessor:
31
  self.device = device
32
  self.torch_dtype = torch_dtype
33
  self.pose_detector = None
34
- self.controlnet = None
35
- self.pipe = None
 
 
36
 
37
  def load_pose_detector(self):
38
  """Lädt nur den Pose-Detector"""
@@ -63,12 +66,83 @@ class ControlNetProcessor:
63
  if detector is None:
64
  return self.extract_pose_simple(image)
65
 
66
- pose_image = detector.detect(image, hand_and_face=True)
67
  return pose_image
68
  except Exception as e:
69
  print(f"Fehler bei Pose-Extraktion: {e}")
70
  return self.extract_pose_simple(image)
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def generate_with_controlnet(
73
  self, image, prompt, negative_prompt,
74
  steps, guidance_scale, controlnet_strength,
@@ -76,21 +150,19 @@ class ControlNetProcessor:
76
  ):
77
  """Generiert Bild mit ControlNet und Fortschrittsanzeige"""
78
  try:
79
- pipe = self.load_controlnet_pipeline()
80
-
81
- print("🔄 ControlNet: Extrahiere Pose...")
82
- if progress:
83
- progress(0.05, desc="ControlNet: Extrahiere Pose...")
84
-
85
- # --- Fallunterscheidung ---
86
  if keep_environment:
87
- print("🎯 Modus: Umgebung beibehalten (nutze Originalbild als Quelle)")
88
- input_image = image
89
- conditioning_image = None
 
90
  else:
91
- print("🎯 Modus: Umgebung darf sich ändern (nutze Pose-Map)")
 
 
92
  conditioning_image = self.extract_pose(image)
93
- input_image = conditioning_image
 
94
 
95
  # Zufälliger Seed
96
  seed = random.randint(0, 2**32 - 1)
@@ -102,38 +174,21 @@ class ControlNetProcessor:
102
 
103
  print("🔄 ControlNet: Starte Pipeline...")
104
 
105
- if conditioning_image is not None:
106
- # Umgebung darf sich ändern
107
- result = pipe(
108
- prompt=prompt,
109
- image=conditioning_image,
110
- negative_prompt=negative_prompt,
111
- num_inference_steps=int(steps),
112
- guidance_scale=guidance_scale,
113
- generator=generator,
114
- controlnet_conditioning_scale=controlnet_strength,
115
- height=512,
116
- width=512,
117
- output_type="pil",
118
- callback_on_step_end=callback,
119
- callback_on_step_end_tensor_inputs=[],
120
- )
121
- else:
122
- # Umgebung soll beibehalten werden
123
- result = pipe(
124
- prompt=prompt,
125
- image=input_image,
126
- negative_prompt=negative_prompt,
127
- num_inference_steps=int(steps),
128
- guidance_scale=guidance_scale,
129
- generator=generator,
130
- controlnet_conditioning_scale=controlnet_strength,
131
- height=512,
132
- width=512,
133
- output_type="pil",
134
- callback_on_step_end=callback,
135
- callback_on_step_end_tensor_inputs=[],
136
- )
137
 
138
  # Debug-Ausgabe Scheduler Steps
139
  try:
@@ -145,41 +200,31 @@ class ControlNetProcessor:
145
  print(f"⚠️ Konnte ControlNet Scheduler-Info nicht auslesen: {e}")
146
 
147
  print("✅ ControlNet abgeschlossen!")
148
- return result.images[0]
 
 
149
 
150
  except Exception as e:
151
  print(f"❌ Fehler in ControlNet: {e}")
152
  import traceback
153
  traceback.print_exc()
154
- return image.convert("RGB").resize((512, 512))
155
-
156
- def load_controlnet_pipeline(self):
157
- """Lädt die ControlNet Pipeline"""
158
- if self.pipe is None:
159
- print("Loading ControlNet pipeline...")
160
- try:
161
- self.controlnet = ControlNetModel.from_pretrained(
162
- "lllyasviel/sd-controlnet-openpose",
163
- torch_dtype=self.torch_dtype
164
- )
165
- self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
166
- "runwayml/stable-diffusion-v1-5",
167
- controlnet=self.controlnet,
168
- torch_dtype=self.torch_dtype,
169
- safety_checker=None,
170
- requires_safety_checker=False
171
- ).to(self.device)
172
-
173
- # Scheduler wechseln zu Euler Ancestral
174
- from diffusers import EulerAncestralDiscreteScheduler
175
- self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
176
-
177
- self.pipe.enable_attention_slicing()
178
- print("✅ ControlNet pipeline loaded successfully with EulerAncestralDiscreteScheduler!")
179
- except Exception as e:
180
- print(f"Fehler beim Laden von ControlNet: {e}")
181
- raise
182
- return self.pipe
183
 
184
 
185
  # Globale Instanz
 
1
+ # controlnet_processor.py
2
  import torch
3
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
4
  from controlnet_aux import OpenposeDetector
 
32
  self.device = device
33
  self.torch_dtype = torch_dtype
34
  self.pose_detector = None
35
+ self.controlnet_openpose = None
36
+ self.controlnet_canny = None
37
+ self.pipe_openpose = None
38
+ self.pipe_canny = None
39
 
40
  def load_pose_detector(self):
41
  """Lädt nur den Pose-Detector"""
 
66
  if detector is None:
67
  return self.extract_pose_simple(image)
68
 
69
+ pose_image = detector(image, hand_and_face=True)
70
  return pose_image
71
  except Exception as e:
72
  print(f"Fehler bei Pose-Extraktion: {e}")
73
  return self.extract_pose_simple(image)
74
 
75
+ def extract_canny_edges(self, image):
76
+ """Extrahiert Canny Edges für Umgebungserhaltung"""
77
+ try:
78
+ img_array = np.array(image.convert("RGB"))
79
+
80
+ # Canny Edge Detection
81
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
82
+ edges = cv2.Canny(gray, 100, 200)
83
+
84
+ # Zu 3-Kanal Bild konvertieren
85
+ edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
86
+ edges_image = Image.fromarray(edges_rgb)
87
+
88
+ print("✅ Canny Edge für Umgebungserhaltung erstellt")
89
+ return edges_image
90
+ except Exception as e:
91
+ print(f"Fehler bei Canny Edge Extraction: {e}")
92
+ return image.convert("RGB").resize((512, 512))
93
+
94
+ def load_controlnet_pipeline(self, controlnet_type="openpose"):
95
+ """Lädt die passende ControlNet Pipeline"""
96
+ if controlnet_type == "openpose":
97
+ if self.pipe_openpose is None:
98
+ print("Loading OpenPose ControlNet pipeline...")
99
+ try:
100
+ self.controlnet_openpose = ControlNetModel.from_pretrained(
101
+ "lllyasviel/sd-controlnet-openpose",
102
+ torch_dtype=self.torch_dtype
103
+ )
104
+ self.pipe_openpose = StableDiffusionControlNetPipeline.from_pretrained(
105
+ "runwayml/stable-diffusion-v1-5",
106
+ controlnet=self.controlnet_openpose,
107
+ torch_dtype=self.torch_dtype,
108
+ safety_checker=None,
109
+ requires_safety_checker=False
110
+ ).to(self.device)
111
+
112
+ from diffusers import EulerAncestralDiscreteScheduler
113
+ self.pipe_openpose.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe_openpose.scheduler.config)
114
+ self.pipe_openpose.enable_attention_slicing()
115
+ print("✅ OpenPose ControlNet pipeline loaded successfully!")
116
+ except Exception as e:
117
+ print(f"Fehler beim Laden von OpenPose ControlNet: {e}")
118
+ raise
119
+ return self.pipe_openpose
120
+
121
+ elif controlnet_type == "canny":
122
+ if self.pipe_canny is None:
123
+ print("Loading Canny ControlNet pipeline...")
124
+ try:
125
+ self.controlnet_canny = ControlNetModel.from_pretrained(
126
+ "lllyasviel/sd-controlnet-canny",
127
+ torch_dtype=self.torch_dtype
128
+ )
129
+ self.pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
130
+ "runwayml/stable-diffusion-v1-5",
131
+ controlnet=self.controlnet_canny,
132
+ torch_dtype=self.torch_dtype,
133
+ safety_checker=None,
134
+ requires_safety_checker=False
135
+ ).to(self.device)
136
+
137
+ from diffusers import EulerAncestralDiscreteScheduler
138
+ self.pipe_canny.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe_canny.scheduler.config)
139
+ self.pipe_canny.enable_attention_slicing()
140
+ print("✅ Canny ControlNet pipeline loaded successfully!")
141
+ except Exception as e:
142
+ print(f"Fehler beim Laden von Canny ControlNet: {e}")
143
+ raise
144
+ return self.pipe_canny
145
+
146
  def generate_with_controlnet(
147
  self, image, prompt, negative_prompt,
148
  steps, guidance_scale, controlnet_strength,
 
150
  ):
151
  """Generiert Bild mit ControlNet und Fortschrittsanzeige"""
152
  try:
153
+ # --- ENTSCHEIDUNG: Welches ControlNet für welche Aufgabe? ---
 
 
 
 
 
 
154
  if keep_environment:
155
+ # PERSON ÄNDERN, UMGEBUNG BEIBEHALTEN
156
+ controlnet_type = "canny"
157
+ print("🎯 ControlNet Modus: Umgebung beibehalten (Canny Edge)")
158
+ conditioning_image = self.extract_canny_edges(image)
159
  else:
160
+ # UMGEBUNG ÄNDERN, PERSON BEIBEHALTEN
161
+ controlnet_type = "openpose"
162
+ print("🎯 ControlNet Modus: Person beibehalten (OpenPose)")
163
  conditioning_image = self.extract_pose(image)
164
+
165
+ pipe = self.load_controlnet_pipeline(controlnet_type)
166
 
167
  # Zufälliger Seed
168
  seed = random.randint(0, 2**32 - 1)
 
174
 
175
  print("🔄 ControlNet: Starte Pipeline...")
176
 
177
+ # ControlNet Generierung
178
+ result = pipe(
179
+ prompt=prompt,
180
+ image=conditioning_image,
181
+ negative_prompt=negative_prompt,
182
+ num_inference_steps=int(steps),
183
+ guidance_scale=guidance_scale,
184
+ generator=generator,
185
+ controlnet_conditioning_scale=controlnet_strength,
186
+ height=512,
187
+ width=512,
188
+ output_type="pil",
189
+ callback_on_step_end=callback,
190
+ callback_on_step_end_tensor_inputs=[],
191
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  # Debug-Ausgabe Scheduler Steps
194
  try:
 
200
  print(f"⚠️ Konnte ControlNet Scheduler-Info nicht auslesen: {e}")
201
 
202
  print("✅ ControlNet abgeschlossen!")
203
+
204
+ # KORREKTUR: ZWEI Werte zurückgeben
205
+ return result.images[0], conditioning_image
206
 
207
  except Exception as e:
208
  print(f"❌ Fehler in ControlNet: {e}")
209
  import traceback
210
  traceback.print_exc()
211
+ error_image = image.convert("RGB").resize((512, 512))
212
+ return error_image, error_image
213
+
214
+ def prepare_inpaint_input(self, image, keep_environment=False):
215
+ """
216
+ Bereitet das Input-Bild für Inpaint vor
217
+ Rückgabe: (image_für_inpaint, conditioning_info)
218
+ """
219
+ if keep_environment:
220
+ # PERSON ÄNDERN: Originalbild an Inpaint übergeben
221
+ print("🎯 Inpaint: Übergebe Originalbild (Person ändern)")
222
+ return image, {"type": "original", "image": image}
223
+ else:
224
+ # UMGEBUNG ÄNDERN: Pose-Map an Inpaint übergeben
225
+ print("🎯 Inpaint: Übergebe Pose-Map (Umgebung ändern)")
226
+ pose_image = self.extract_pose(image)
227
+ return pose_image, {"type": "pose", "image": pose_image}
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
 
230
  # Globale Instanz