Astridkraft commited on
Commit
073e1ec
·
verified ·
1 Parent(s): d7f8818

Update controlnet_module.py

Browse files
Files changed (1) hide show
  1. controlnet_module.py +46 -10
controlnet_module.py CHANGED
@@ -3,8 +3,26 @@ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
3
  from controlnet_aux import OpenposeDetector
4
  from PIL import Image
5
  import random
6
- import cv2 #generiert Pose-Maske, geht auch mit matlibplot
7
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class ControlNetProcessor:
10
  def __init__(self, device="cuda", torch_dtype=torch.float32):
@@ -19,10 +37,8 @@ class ControlNetProcessor:
19
  if self.pose_detector is None:
20
  print("Loading Pose Detector...")
21
  try:
22
- # OpenposeDetector ohne matplotlib Abhängigkeit
23
  self.pose_detector = OpenposeDetector.from_pretrained(
24
  "lllyasviel/ControlNet",
25
- #torch_dtype=self.torch_dtype
26
  )
27
  except Exception as e:
28
  print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
@@ -31,7 +47,6 @@ class ControlNetProcessor:
31
  def extract_pose_simple(self, image):
32
  """Einfache Pose-Extraktion ohne komplexe Abhängigkeiten"""
33
  try:
34
- # Fallback: Einfache Kantenerkennung als Pose-Approximation
35
  img_array = np.array(image.convert("RGB"))
36
  edges = cv2.Canny(img_array, 100, 200)
37
  pose_image = Image.fromarray(edges).convert("RGB")
@@ -55,14 +70,17 @@ class ControlNetProcessor:
55
  return self.extract_pose_simple(image)
56
 
57
  def generate_with_controlnet(self, image, prompt, negative_prompt,
58
- steps, guidance_scale, controlnet_strength):
59
- """Generiert Bild mit ControlNet"""
60
  try:
61
- # Zuerst Pipeline laden um Fehler früh zu erkennen
62
  pipe = self.load_controlnet_pipeline()
63
 
64
  # Pose extrahieren
65
  print("🔄 ControlNet: Extrahiere Pose...")
 
 
 
66
  pose_map = self.extract_pose(image)
67
 
68
  # Zufälliger Seed
@@ -70,8 +88,14 @@ class ControlNetProcessor:
70
  generator = torch.Generator(device=self.device).manual_seed(seed)
71
  print(f"ControlNet Seed: {seed}")
72
 
73
- # ControlNet anwenden
 
 
 
 
74
  print("🔄 ControlNet: Wende Pose-Kontrolle an...")
 
 
75
  result = pipe(
76
  prompt=prompt,
77
  image=pose_map,
@@ -82,15 +106,27 @@ class ControlNetProcessor:
82
  controlnet_conditioning_scale=controlnet_strength,
83
  height=512,
84
  width=512,
85
- output_type="pil"
 
 
86
  )
87
 
 
 
 
 
 
 
 
 
 
88
  print("✅ ControlNet abgeschlossen!")
89
  return result.images[0]
90
 
91
  except Exception as e:
92
  print(f"❌ Fehler in ControlNet: {e}")
93
- # Fallback: Originalbild zurückgeben
 
94
  return image.convert("RGB").resize((512, 512))
95
 
96
  def load_controlnet_pipeline(self):
 
3
  from controlnet_aux import OpenposeDetector
4
  from PIL import Image
5
  import random
6
+ import cv2
7
  import numpy as np
8
+ import gradio as gr
9
+
10
+ class ControlNetProgressCallback:
11
+ def __init__(self, progress, total_steps):
12
+ self.progress = progress
13
+ self.total_steps = total_steps
14
+ self.current_step = 0
15
+
16
+ def __call__(self, pipe, step_index, timestep, callback_kwargs):
17
+ self.current_step = step_index + 1
18
+ progress_percentage = self.current_step / self.total_steps
19
+
20
+ # Fortschritt aktualisieren
21
+ if self.progress is not None:
22
+ self.progress(progress_percentage, desc=f"ControlNet: Schritt {self.current_step}/{self.total_steps}")
23
+
24
+ print(f"ControlNet Fortschritt: {self.current_step}/{self.total_steps} ({progress_percentage:.1%})")
25
+ return callback_kwargs
26
 
27
  class ControlNetProcessor:
28
  def __init__(self, device="cuda", torch_dtype=torch.float32):
 
37
  if self.pose_detector is None:
38
  print("Loading Pose Detector...")
39
  try:
 
40
  self.pose_detector = OpenposeDetector.from_pretrained(
41
  "lllyasviel/ControlNet",
 
42
  )
43
  except Exception as e:
44
  print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
 
47
  def extract_pose_simple(self, image):
48
  """Einfache Pose-Extraktion ohne komplexe Abhängigkeiten"""
49
  try:
 
50
  img_array = np.array(image.convert("RGB"))
51
  edges = cv2.Canny(img_array, 100, 200)
52
  pose_image = Image.fromarray(edges).convert("RGB")
 
70
  return self.extract_pose_simple(image)
71
 
72
  def generate_with_controlnet(self, image, prompt, negative_prompt,
73
+ steps, guidance_scale, controlnet_strength, progress=None):
74
+ """Generiert Bild mit ControlNet und Fortschrittsanzeige"""
75
  try:
76
+ # Pipeline laden
77
  pipe = self.load_controlnet_pipeline()
78
 
79
  # Pose extrahieren
80
  print("🔄 ControlNet: Extrahiere Pose...")
81
+ if progress:
82
+ progress(0.05, desc="ControlNet: Extrahiere Pose...")
83
+
84
  pose_map = self.extract_pose(image)
85
 
86
  # Zufälliger Seed
 
88
  generator = torch.Generator(device=self.device).manual_seed(seed)
89
  print(f"ControlNet Seed: {seed}")
90
 
91
+ # Progress Callback erstellen
92
+ callback = None
93
+ if progress is not None:
94
+ callback = ControlNetProgressCallback(progress, int(steps))
95
+
96
  print("🔄 ControlNet: Wende Pose-Kontrolle an...")
97
+
98
+ # ControlNet anwenden mit Callback
99
  result = pipe(
100
  prompt=prompt,
101
  image=pose_map,
 
106
  controlnet_conditioning_scale=controlnet_strength,
107
  height=512,
108
  width=512,
109
+ output_type="pil",
110
+ callback_on_step_end=callback,
111
+ callback_on_step_end_tensor_inputs=[],
112
  )
113
 
114
+ # Debug-Ausgabe der tatsächlichen Steps
115
+ try:
116
+ scheduler = pipe.scheduler
117
+ if hasattr(scheduler, 'timesteps'):
118
+ actual_steps = len(scheduler.timesteps)
119
+ print(f"🎯 CONTROLNET TATSÄCHLICHE STEPS: {actual_steps} (von {steps} angefordert)")
120
+ except Exception as e:
121
+ print(f"⚠️ Konnte ControlNet Scheduler-Info nicht auslesen: {e}")
122
+
123
  print("✅ ControlNet abgeschlossen!")
124
  return result.images[0]
125
 
126
  except Exception as e:
127
  print(f"❌ Fehler in ControlNet: {e}")
128
+ import traceback
129
+ traceback.print_exc()
130
  return image.convert("RGB").resize((512, 512))
131
 
132
  def load_controlnet_pipeline(self):