Astridkraft commited on
Commit
61a2a11
·
verified ·
1 Parent(s): acfa8ed

Create controlnet_module.py

Browse files
Files changed (1) hide show
  1. controlnet_module.py +128 -0
controlnet_module.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
3
+ from controlnet_aux import OpenposeDetector
4
+ from PIL import Image
5
+ import random
6
+ import cv2 #generiert Pose-Maske, geht auch mit matlibplot
7
+ import numpy as np
8
+
9
+ class ControlNetProcessor:
10
+ def __init__(self, device="cuda", torch_dtype=torch.float32):
11
+ self.device = device
12
+ self.torch_dtype = torch_dtype
13
+ self.pose_detector = None
14
+ self.controlnet = None
15
+ self.pipe = None
16
+
17
+ def load_pose_detector(self):
18
+ """Lädt nur den Pose-Detector"""
19
+ if self.pose_detector is None:
20
+ print("Loading Pose Detector...")
21
+ try:
22
+ # OpenposeDetector ohne matplotlib Abhängigkeit
23
+ self.pose_detector = OpenposeDetector.from_pretrained(
24
+ "lllyasviel/ControlNet",
25
+ #torch_dtype=self.torch_dtype
26
+ )
27
+ except Exception as e:
28
+ print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
29
+ return self.pose_detector
30
+
31
+ def extract_pose_simple(self, image):
32
+ """Einfache Pose-Extraktion ohne komplexe Abhängigkeiten"""
33
+ try:
34
+ # Fallback: Einfache Kantenerkennung als Pose-Approximation
35
+ img_array = np.array(image.convert("RGB"))
36
+ edges = cv2.Canny(img_array, 100, 200)
37
+ pose_image = Image.fromarray(edges).convert("RGB")
38
+ print("⚠️ Verwende Kanten-basierte Pose-Approximation")
39
+ return pose_image
40
+ except Exception as e:
41
+ print(f"Fehler bei einfacher Pose-Extraktion: {e}")
42
+ return image.convert("RGB").resize((512, 512))
43
+
44
+ def extract_pose(self, image):
45
+ """Extrahiert Pose-Map aus Bild mit Fallback"""
46
+ try:
47
+ detector = self.load_pose_detector()
48
+ if detector is None:
49
+ return self.extract_pose_simple(image)
50
+
51
+ pose_image = detector(image, hand_and_face=True, detect_resolution=512)
52
+ return pose_image
53
+ except Exception as e:
54
+ print(f"Fehler bei Pose-Extraktion: {e}")
55
+ return self.extract_pose_simple(image)
56
+
57
+ def generate_with_controlnet(self, image, prompt, negative_prompt,
58
+ steps, guidance_scale, controlnet_strength):
59
+ """Generiert Bild mit ControlNet"""
60
+ try:
61
+ # Zuerst Pipeline laden um Fehler früh zu erkennen
62
+ pipe = self.load_controlnet_pipeline()
63
+
64
+ # Pose extrahieren
65
+ print("🔄 ControlNet: Extrahiere Pose...")
66
+ pose_map = self.extract_pose(image)
67
+
68
+ # Zufälliger Seed
69
+ seed = random.randint(0, 2**32 - 1)
70
+ generator = torch.Generator(device=self.device).manual_seed(seed)
71
+ print(f"ControlNet Seed: {seed}")
72
+
73
+ # ControlNet anwenden
74
+ print("🔄 ControlNet: Wende Pose-Kontrolle an...")
75
+ result = pipe(
76
+ prompt=prompt,
77
+ image=pose_map,
78
+ negative_prompt=negative_prompt,
79
+ num_inference_steps=int(steps),
80
+ guidance_scale=guidance_scale,
81
+ generator=generator,
82
+ controlnet_conditioning_scale=controlnet_strength,
83
+ height=512,
84
+ width=512,
85
+ output_type="pil"
86
+ )
87
+
88
+ print("✅ ControlNet abgeschlossen!")
89
+ return result.images[0]
90
+
91
+ except Exception as e:
92
+ print(f"❌ Fehler in ControlNet: {e}")
93
+ # Fallback: Originalbild zurückgeben
94
+ return image.convert("RGB").resize((512, 512))
95
+
96
+ def load_controlnet_pipeline(self):
97
+ """Lädt die ControlNet Pipeline"""
98
+ if self.pipe is None:
99
+ print("Loading ControlNet pipeline...")
100
+ try:
101
+ self.controlnet = ControlNetModel.from_pretrained(
102
+ "lllyasviel/sd-controlnet-openpose",
103
+ torch_dtype=self.torch_dtype
104
+ )
105
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
106
+ "runwayml/stable-diffusion-v1-5",
107
+ controlnet=self.controlnet,
108
+ torch_dtype=self.torch_dtype,
109
+ safety_checker=None,
110
+ requires_safety_checker=False
111
+ ).to(self.device)
112
+
113
+ from diffusers import DPMSolverMultistepScheduler
114
+ self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
115
+ self.pipe.scheduler.config
116
+ )
117
+
118
+ self.pipe.enable_attention_slicing()
119
+ print("ControlNet pipeline loaded successfully!")
120
+ except Exception as e:
121
+ print(f"Fehler beim Laden von ControlNet: {e}")
122
+ raise
123
+ return self.pipe
124
+
125
+ # Globale Instanz
126
+ device = "cuda" if torch.cuda.is_available() else "cpu"
127
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
128
+ controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)