import random
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
from controlnet_aux import OpenposeDetector
from diffusers import (
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    MultiControlNetModel,
    StableDiffusionControlNetPipeline,
)
from PIL import Image

# Side length in pixels of the generated image and of the fallback images
# returned when an extraction step fails.
_OUTPUT_SIZE = 512


# ============================================================
# PROGRESS CALLBACK
# ============================================================
class ControlNetProgressCallback:
    """Per-step callback that forwards denoising progress to a progress reporter.

    An instance is passed as ``callback_on_step_end`` to the diffusers pipeline
    and is invoked once after every denoising step. It must return
    ``callback_kwargs`` so the pipeline can continue.

    Args:
        progress: Callable accepting ``(fraction, desc=...)`` (e.g. a Gradio
            progress object), or ``None`` to only print progress.
        total_steps: Number of inference steps the pipeline will run.
    """

    def __init__(self, progress, total_steps):
        self.progress = progress
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        """Record the finished step, report progress, and hand the kwargs back."""
        self.current_step = step_index + 1
        # Guard against a non-positive step count and clamp so the reported
        # fraction never exceeds 100%.
        if self.total_steps > 0:
            progress_percentage = min(self.current_step / self.total_steps, 1.0)
        else:
            progress_percentage = 1.0

        # Update the UI progress indicator, if one was supplied.
        if self.progress is not None:
            self.progress(
                progress_percentage,
                desc=f"ControlNet: Schritt {self.current_step}/{self.total_steps}",
            )
        print(
            f"ControlNet Fortschritt: {self.current_step}/{self.total_steps} "
            f"({progress_percentage:.1%})"
        )
        return callback_kwargs


# ============================================================
# CONTROLNET PROCESSOR
# ============================================================
class ControlNetProcessor:
    """Generates images with a combined OpenPose + Canny Multi-ControlNet.

    The pose detector and the Stable Diffusion pipeline are loaded lazily on
    first use and cached on the instance.

    Args:
        device: Torch device string the pipeline is moved to.
        torch_dtype: dtype used when loading the ControlNet and SD weights.
        base_model_id: Hugging Face repo of the Stable Diffusion 1.5 base model.
            The former ``runwayml/stable-diffusion-v1-5`` repo was removed from
            the Hub; the default points at its official mirror.
        openpose_controlnet_id: Hugging Face repo of the OpenPose ControlNet.
        canny_controlnet_id: Hugging Face repo of the Canny ControlNet.
    """

    def __init__(
        self,
        device="cuda",
        torch_dtype=torch.float32,
        base_model_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
        openpose_controlnet_id="lllyasviel/sd-controlnet-openpose",
        canny_controlnet_id="lllyasviel/sd-controlnet-canny",
    ):
        self.device = device
        self.torch_dtype = torch_dtype
        self.base_model_id = base_model_id
        self.openpose_controlnet_id = openpose_controlnet_id
        self.canny_controlnet_id = canny_controlnet_id
        self.pose_detector = None
        self.pipe_multi = None  # Multi-ControlNet pipeline (OpenPose + Canny)

    # ------------------------------------------------------------
    # HELPERS
    # ------------------------------------------------------------
    @staticmethod
    def _fallback_image(image):
        """Return *image* as an RGB copy resized to the output size (failure fallback)."""
        return image.convert("RGB").resize((_OUTPUT_SIZE, _OUTPUT_SIZE))

    # ------------------------------------------------------------
    # POSE DETECTOR
    # ------------------------------------------------------------
    def load_pose_detector(self):
        """Load (once) and return the OpenPose detector.

        Returns:
            The cached ``OpenposeDetector``, or ``None`` if loading failed.
            A failed load is retried on the next call.
        """
        if self.pose_detector is None:
            print("🧠 Lade Pose Detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
            except Exception as e:
                print(f"⚠️ Warnung: Pose-Detector konnte nicht geladen werden: {e}")
        return self.pose_detector

    def extract_pose_simple(self, image):
        """Fallback "pose" map: a Canny edge image of *image*.

        Returns:
            RGB PIL image of the edges, or a resized RGB copy of *image* if
            edge extraction fails.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            edges = cv2.Canny(img_array, 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
            print("⚠️ Verwende einfache Kanten-Pose")
            return pose_image
        except Exception as e:
            print(f"Fehler bei einfacher Pose-Extraktion: {e}")
            return self._fallback_image(image)

    def extract_pose(self, image):
        """Extract an OpenPose map (body, hands and face) from *image*.

        Falls back to :meth:`extract_pose_simple` when the detector is
        unavailable or detection raises.
        """
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)
            pose_image = detector(image, hand_and_face=True)
            print("✅ Pose-Map erfolgreich extrahiert")
            return pose_image
        except Exception as e:
            print(f"Fehler bei Pose-Extraktion: {e}")
            return self.extract_pose_simple(image)

    # ------------------------------------------------------------
    # CANNY EDGE
    # ------------------------------------------------------------
    def extract_canny_edges(self, image, low_threshold=100, high_threshold=200):
        """Extract a Canny edge map used to preserve the scene layout.

        Args:
            image: PIL input image.
            low_threshold: Lower hysteresis threshold for ``cv2.Canny``.
            high_threshold: Upper hysteresis threshold for ``cv2.Canny``.

        Returns:
            3-channel RGB PIL edge image, or a resized RGB copy of *image* if
            extraction fails.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, low_threshold, high_threshold)
            # The ControlNet input is RGB, so replicate the single edge channel.
            edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
            edges_image = Image.fromarray(edges_rgb)
            print("✅ Canny Edge Map erstellt")
            return edges_image
        except Exception as e:
            print(f"Fehler bei Canny-Extraktion: {e}")
            return self._fallback_image(image)

    # ------------------------------------------------------------
    # PIPELINE LOADER
    # ------------------------------------------------------------
    def load_controlnet_pipeline(self):
        """Load (once) and return the combined OpenPose + Canny pipeline.

        The pipeline uses an Euler-Ancestral scheduler and attention slicing.

        Raises:
            Exception: Any error raised while loading the models is printed
                and re-raised.
        """
        if self.pipe_multi is None:
            print("🧩 Lade Multi-ControlNet Pipeline (OpenPose + Canny)...")
            try:
                controlnet_openpose = ControlNetModel.from_pretrained(
                    self.openpose_controlnet_id, torch_dtype=self.torch_dtype
                )
                controlnet_canny = ControlNetModel.from_pretrained(
                    self.canny_controlnet_id, torch_dtype=self.torch_dtype
                )
                # Order matters: it must match the order of the control images
                # and conditioning scales passed in generate_with_controlnet.
                multi_controlnet = MultiControlNetModel([controlnet_openpose, controlnet_canny])

                pipe = StableDiffusionControlNetPipeline.from_pretrained(
                    self.base_model_id,
                    controlnet=multi_controlnet,
                    torch_dtype=self.torch_dtype,
                    safety_checker=None,
                    requires_safety_checker=False,
                ).to(self.device)
                pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
                pipe.enable_attention_slicing()
                # Only cache a fully configured pipeline.
                self.pipe_multi = pipe
                print("✅ Multi-ControlNet Pipeline erfolgreich geladen!")
            except Exception as e:
                print(f"❌ Fehler beim Laden von Multi-ControlNet: {e}")
                raise
        return self.pipe_multi

    # ------------------------------------------------------------
    # GENERATION
    # ------------------------------------------------------------
    def generate_with_controlnet(
        self,
        image,
        prompt,
        negative_prompt,
        steps,
        guidance_scale,
        controlnet_strength,
        progress=None,
        seed=None,
        canny_scale_ratio=0.7,
    ):
        """Generate a 512x512 image conditioned on a pose map and a Canny map of *image*.

        Args:
            image: PIL input image that both control maps are extracted from.
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            steps: Number of inference steps (converted with ``int``).
            guidance_scale: Classifier-free guidance scale.
            controlnet_strength: Conditioning scale of the OpenPose ControlNet.
            progress: Optional progress reporter (see ``ControlNetProgressCallback``).
            seed: RNG seed; a random 32-bit seed is drawn when ``None``.
            canny_scale_ratio: The Canny ControlNet scale is
                ``controlnet_strength * canny_scale_ratio``.

        Returns:
            ``(generated_image, image)``: the original image is passed on for
            the inpaint stage. If generation fails, both elements are a
            resized RGB copy of *image*.

        Raises:
            ValueError: If *image* is ``None``. The failure fallback needs
                the input image, so it cannot handle this case.
        """
        if image is None:
            raise ValueError("Kein Eingabebild übergeben")
        try:
            print("🎯 Modus: Kombiniert (OpenPose + Canny + Inpaint)")
            pose_map = self.extract_pose(image)
            canny_map = self.extract_canny_edges(image)
            pipe = self.load_controlnet_pipeline()

            num_steps = int(steps)
            if seed is None:
                seed = random.randint(0, 2**32 - 1)
            generator = torch.Generator(device=self.device).manual_seed(seed)
            print(f"ControlNet Seed: {seed}")

            callback = (
                ControlNetProgressCallback(progress, num_steps) if progress is not None else None
            )

            print("🚀 Starte kombinierte ControlNet-Pipeline...")
            result = pipe(
                prompt=prompt,
                image=[pose_map, canny_map],  # one control image per ControlNet
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=[
                    controlnet_strength,
                    controlnet_strength * canny_scale_ratio,
                ],
                generator=generator,
                height=_OUTPUT_SIZE,
                width=_OUTPUT_SIZE,
                output_type="pil",
                callback_on_step_end=callback,
                callback_on_step_end_tensor_inputs=[],
            )
            print("✅ Multi-ControlNet abgeschlossen!")
            return result.images[0], image
        except Exception as e:
            print(f"❌ Fehler in Multi-ControlNet: {e}")
            traceback.print_exc()
            error_image = self._fallback_image(image)
            return error_image, error_image

    # ------------------------------------------------------------
    # INPAINT PREPARATION
    # ------------------------------------------------------------
    def prepare_inpaint_input(self, image, keep_environment=False):
        """Choose the image handed to the inpaint stage.

        Args:
            image: PIL input image.
            keep_environment: If true, pass the original image through;
                otherwise pass its extracted pose map.

        Returns:
            ``(image_for_inpaint, conditioning_info)``, where
            ``conditioning_info`` is ``{"type": "original" | "pose", "image": ...}``.
        """
        if keep_environment:
            print("🎯 Inpaint: Übergebe Originalbild (Person ändern)")
            return image, {"type": "original", "image": image}
        print("🎯 Inpaint: Übergebe Pose-Map (Umgebung ändern)")
        pose_image = self.extract_pose(image)
        return pose_image, {"type": "pose", "image": pose_image}


# ============================================================
# GLOBAL INSTANCE
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU stays in float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)