import torch
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
)
from controlnet_aux import OpenposeDetector
from PIL import Image
import random
import cv2
import numpy as np
import gradio as gr
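

# Multi-ControlNet helpers: OpenPose conditioning preserves the person's pose
# while Canny conditioning preserves the surrounding scene; denoising progress
# is reported to Gradio through ControlNetProgressCallback.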


class ControlNetProgressCallback:
    """Step-end callback that forwards denoising progress to a Gradio progress bar."""

    def __init__(self, progress, total_steps):
        self.progress = progress  # gr.Progress instance, or None
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        self.current_step = step_index + 1
        progress_percentage = self.current_step / self.total_steps

        if self.progress is not None:
            self.progress(progress_percentage, desc=f"ControlNet: step {self.current_step}/{self.total_steps}")

        print(f"ControlNet progress: {self.current_step}/{self.total_steps} ({progress_percentage:.1%})")
        return callback_kwargs
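
# Note on the callback protocol: diffusers invokes `callback_on_step_end` once
# per denoising step as callback(pipeline, step_index, timestep, callback_kwargs)
# and expects the (possibly modified) callback_kwargs dict to be returned.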


class ControlNetProcessor:
    """Lazily loads and runs the multi-ControlNet (OpenPose + Canny) pipeline."""

    def __init__(self, device="cuda", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype
        self.pose_detector = None  # loaded on first use
        self.pipe_multi = None     # loaded on first use

    def load_pose_detector(self):
        """Loads only the pose detector (once)."""
        if self.pose_detector is None:
            print("🧠 Loading pose detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
            except Exception as e:
                print(f"⚠️ Warning: pose detector could not be loaded: {e}")
        return self.pose_detector
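
    # If the detector fails to load, extract_pose() falls back to the
    # edge-based pseudo-pose produced by extract_pose_simple().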
    def extract_pose_simple(self, image):
        """Fallback: edge-based pseudo-pose map."""
        try:
            img_array = np.array(image.convert("RGB"))
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
            print("⚠️ Using simple edge-based pose map")
            return pose_image
        except Exception as e:
            print(f"Error during simple pose extraction: {e}")
            return image.convert("RGB").resize((512, 512))

    def extract_pose(self, image):
        """Extracts a pose map from the image, falling back to edges on failure."""
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)
            pose_image = detector(image, hand_and_face=True)
            print("✅ Pose map extracted successfully")
            return pose_image
        except Exception as e:
            print(f"Error during pose extraction: {e}")
            return self.extract_pose_simple(image)

    def extract_canny_edges(self, image):
        """Extracts a Canny edge map used to preserve the environment."""
        try:
            img_array = np.array(image.convert("RGB"))
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
            edges_image = Image.fromarray(edges_rgb)
            print("✅ Canny edge map created")
            return edges_image
        except Exception as e:
            print(f"Error during Canny extraction: {e}")
            return image.convert("RGB").resize((512, 512))
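
    # The two conditioning maps are complementary: the pose map constrains the
    # person's posture while the Canny map anchors the scene layout.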

    def load_controlnet_pipeline(self):
        """Loads the combined multi-ControlNet pipeline (OpenPose + Canny) once."""
        if self.pipe_multi is None:
            print("🧩 Loading multi-ControlNet pipeline (OpenPose + Canny)...")
            try:
                controlnet_openpose = ControlNetModel.from_pretrained(
                    "lllyasviel/sd-controlnet-openpose",
                    torch_dtype=self.torch_dtype,
                )
                controlnet_canny = ControlNetModel.from_pretrained(
                    "lllyasviel/sd-controlnet-canny",
                    torch_dtype=self.torch_dtype,
                )

                # Passing a list of ControlNetModels makes diffusers wrap them
                # in a MultiControlNetModel internally.
                self.pipe_multi = StableDiffusionControlNetPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5",
                    controlnet=[controlnet_openpose, controlnet_canny],
                    torch_dtype=self.torch_dtype,
                    safety_checker=None,
                    requires_safety_checker=False,
                ).to(self.device)

                self.pipe_multi.scheduler = EulerAncestralDiscreteScheduler.from_config(
                    self.pipe_multi.scheduler.config
                )
                self.pipe_multi.enable_attention_slicing()  # reduces VRAM usage

                print("✅ Multi-ControlNet pipeline loaded successfully!")
            except Exception as e:
                print(f"❌ Error loading multi-ControlNet pipeline: {e}")
                raise
        return self.pipe_multi

    def generate_with_controlnet(
        self, image, prompt, negative_prompt,
        steps, guidance_scale, controlnet_strength,
        progress=None,
    ):
        """Generates an image conditioned on the OpenPose + Canny combination."""
        try:
            print("🎯 Mode: combined (OpenPose + Canny)")

            pose_map = self.extract_pose(image)
            canny_map = self.extract_canny_edges(image)
            pipe = self.load_controlnet_pipeline()

            seed = random.randint(0, 2**32 - 1)
            generator = torch.Generator(device=self.device).manual_seed(seed)
            print(f"ControlNet seed: {seed}")

            callback = ControlNetProgressCallback(progress, int(steps)) if progress is not None else None

            print("🚀 Starting combined ControlNet pipeline...")

            result = pipe(
                prompt=prompt,
                image=[pose_map, canny_map],  # one conditioning image per ControlNet
                negative_prompt=negative_prompt,
                num_inference_steps=int(steps),
                guidance_scale=guidance_scale,
                # Weight Canny lower than OpenPose so the pose dominates.
                controlnet_conditioning_scale=[controlnet_strength, controlnet_strength * 0.7],
                generator=generator,
                height=512,
                width=512,
                output_type="pil",
                callback_on_step_end=callback,
                callback_on_step_end_tensor_inputs=[],
            )

            print("✅ Multi-ControlNet run finished!")
            return result.images[0], image

        except Exception as e:
            print(f"❌ Error in multi-ControlNet generation: {e}")
            import traceback
            traceback.print_exc()
            error_image = image.convert("RGB").resize((512, 512))
            return error_image, error_image

    def prepare_inpaint_input(self, image, keep_environment=False):
        """
        Prepares the input image for the inpainting stage.

        Returns: (image_for_inpaint, conditioning_info)
        """
        if keep_environment:
            print("🎯 Inpaint: passing the original image (change the person)")
            return image, {"type": "original", "image": image}
        else:
            print("🎯 Inpaint: passing the pose map (change the environment)")
            pose_image = self.extract_pose(image)
            return pose_image, {"type": "pose", "image": pose_image}


# Module-level defaults: prefer GPU with fp16 when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)
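

if __name__ == "__main__":
    # Minimal usage sketch: "input.jpg", the prompts, and the output filename
    # below are illustrative placeholders, not part of the original module.
    demo_image = Image.open("input.jpg").convert("RGB")
    result_image, source_image = controlnet_processor.generate_with_controlnet(
        image=demo_image,
        prompt="a person standing in a park, photorealistic",
        negative_prompt="blurry, low quality, deformed",
        steps=30,
        guidance_scale=7.5,
        controlnet_strength=1.0,
    )
    result_image.save("controlnet_output.png")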