Spaces:
Running
on
T4
Running
on
T4
| import torch | |
| from diffusers import StableDiffusionControlNetPipeline, ControlNetModel | |
| from controlnet_aux import OpenposeDetector | |
| from PIL import Image | |
| import random | |
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
class ControlNetProgressCallback:
    """Report diffusion progress from a pipeline's step-end callback.

    An instance is passed as ``callback_on_step_end`` to a diffusers pipeline
    call. Each invocation logs the step to stdout and, if a progress callable
    was given, forwards the completed fraction to it.
    """

    def __init__(self, progress, total_steps):
        """
        Args:
            progress: callable invoked as ``progress(fraction, desc=...)``, or
                None to only log to stdout.
            total_steps: expected number of steps; used as the denominator of
                the reported fraction.
        """
        self.progress = progress
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        """Mark 0-based step ``step_index`` as finished and report progress.

        Returns ``callback_kwargs`` unchanged, which is what the pipeline expects
        back from a step-end callback.
        """
        self.current_step = step_index + 1
        # The pipeline decides how often this is called, which may not match
        # total_steps. Avoid dividing by zero and never report more than 100 %.
        if self.total_steps > 0:
            progress_percentage = min(self.current_step / self.total_steps, 1.0)
        else:
            progress_percentage = 1.0
        if self.progress is not None:
            self.progress(
                progress_percentage,
                desc=f"ControlNet: Schritt {self.current_step}/{self.total_steps}",
            )
        print(f"ControlNet Fortschritt: {self.current_step}/{self.total_steps} ({progress_percentage:.1%})")
        return callback_kwargs
class ControlNetProcessor:
    """Lazily load ControlNet pipelines and generate images with them.

    Every heavy object (pose detector, ControlNet models, pipelines) starts as
    ``None``. Each is created on first use and cached on the instance, so later
    calls reuse it.
    """

    # Base Stable Diffusion checkpoint used by every pipeline.
    _BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
    # ControlNet kind -> Hugging Face repository id.
    _CONTROLNET_IDS = {
        "openpose": "lllyasviel/sd-controlnet-openpose",
        "canny": "lllyasviel/sd-controlnet-canny",
    }
    # controlnet_type -> (cache attribute, log label, ControlNet kinds in the
    # order their conditioning images are passed to the pipeline).
    _PIPELINE_SPECS = {
        "openpose": ("pipe_openpose", "OpenPose ControlNet", ("openpose",)),
        "canny": ("pipe_canny", "Canny ControlNet", ("canny",)),
        "multi": ("pipe_multi", "Multi-ControlNet", ("openpose", "canny")),
    }
    # Generation width/height and fallback-image size in pixels.
    _RESOLUTION = 512

    def __init__(self, device="cuda", torch_dtype=torch.float32):
        """
        Args:
            device: device string the pipelines are moved to with ``.to()``.
            torch_dtype: dtype used when loading ControlNet and pipeline weights.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.pose_detector = None
        self.controlnet_openpose = None
        self.controlnet_canny = None
        self.pipe_openpose = None
        self.pipe_canny = None
        self.pipe_multi = None

    def _fallback_image(self, image):
        """Return ``image`` converted to RGB and resized to the working resolution."""
        return image.convert("RGB").resize((self._RESOLUTION, self._RESOLUTION))

    def load_pose_detector(self):
        """Load only the OpenPose detector, once.

        Returns the cached detector. Returns None if loading failed; in that
        case a warning is printed and the next call tries again.
        """
        if self.pose_detector is None:
            print("Loading Pose Detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
            except Exception as e:
                print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
        return self.pose_detector

    def extract_pose_simple(self, image):
        """Approximate a pose map with Canny edges, without the pose detector.

        On failure, returns ``image`` as RGB resized to 512x512.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            # Canny is run on the RGB array directly here (no grayscale step).
            edges = cv2.Canny(img_array, 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
            print("⚠️ Verwende Kanten-basierte Pose-Approximation")
            return pose_image
        except Exception as e:
            print(f"Fehler bei einfacher Pose-Extraktion: {e}")
            return self._fallback_image(image)

    def extract_pose(self, image):
        """Extract a pose map (body, hands and face) from ``image``.

        Falls back to ``extract_pose_simple`` when the detector is unavailable
        or raises.
        """
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)
            return detector(image, hand_and_face=True)
        except Exception as e:
            print(f"Fehler bei Pose-Extraktion: {e}")
            return self.extract_pose_simple(image)

    def extract_canny_edges(self, image):
        """Extract a 3-channel Canny edge map used to preserve the surroundings.

        On failure, returns ``image`` as RGB resized to 512x512.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            # Convert the single-channel edge map back to a 3-channel image.
            edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
            edges_image = Image.fromarray(edges_rgb)
            print("✅ Canny Edge für Umgebungserhaltung erstellt")
            return edges_image
        except Exception as e:
            print(f"Fehler bei Canny Edge Extraction: {e}")
            return self._fallback_image(image)

    def _load_controlnet_model(self, kind):
        """Return the cached ControlNetModel for ``kind`` ("openpose"/"canny"), loading it once."""
        attr = f"controlnet_{kind}"
        model = getattr(self, attr)
        if model is None:
            model = ControlNetModel.from_pretrained(
                self._CONTROLNET_IDS[kind],
                torch_dtype=self.torch_dtype,
            )
            setattr(self, attr, model)
        return model

    def _build_pipeline(self, controlnet):
        """Build a ControlNet pipeline on the base checkpoint.

        ``controlnet`` is a single model or a list of models. The pipeline is
        moved to ``self.device``, uses the Euler-Ancestral scheduler and has
        attention slicing enabled.
        """
        from diffusers import EulerAncestralDiscreteScheduler

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self._BASE_MODEL_ID,
            controlnet=controlnet,
            torch_dtype=self.torch_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.enable_attention_slicing()
        return pipe

    def load_controlnet_pipeline(self, controlnet_type="openpose"):
        """Return the (cached) pipeline for ``controlnet_type``.

        Args:
            controlnet_type: "openpose", "canny" or "multi" (OpenPose + Canny).

        Raises:
            ValueError: for an unknown ``controlnet_type``.
            Exception: any loading error is printed and re-raised. The pipeline
                is only cached after it has been fully configured.
        """
        try:
            attr, label, kinds = self._PIPELINE_SPECS[controlnet_type]
        except KeyError:
            raise ValueError(f"Unbekannter ControlNet-Typ: {controlnet_type!r}") from None

        pipe = getattr(self, attr)
        if pipe is not None:
            return pipe

        print(f"Loading {label} pipeline...")
        try:
            # Reuse ControlNet models already loaded for another pipeline.
            models = [self._load_controlnet_model(kind) for kind in kinds]
            controlnet = models if len(models) > 1 else models[0]
            pipe = self._build_pipeline(controlnet)
        except Exception as e:
            print(f"Fehler beim Laden von {label}: {e}")
            raise
        setattr(self, attr, pipe)
        print(f"✅ {label} pipeline loaded successfully!")
        return pipe

    def generate_with_controlnet(
        self, image, prompt, negative_prompt,
        steps, guidance_scale, controlnet_strength,
        progress=None, keep_environment=False, seed=None,
    ):
        """Generate an image guided by ControlNet maps extracted from ``image``.

        Args:
            image: source PIL image the conditioning maps are extracted from.
            prompt: positive text prompt.
            negative_prompt: negative text prompt.
            steps: number of inference steps (converted with ``int``).
            guidance_scale: classifier-free guidance scale.
            controlnet_strength: overall ControlNet conditioning scale. In
                multi mode it is split 60/40 between OpenPose and Canny.
            progress: optional callable ``progress(fraction, desc=...)`` that
                receives per-step progress; None disables it.
            keep_environment: True keeps the surroundings and changes the
                person (OpenPose + Canny). False keeps the person and changes
                the surroundings (OpenPose only).
            seed: optional generator seed for reproducible output. A random
                seed is drawn when None.

        Returns:
            ``(generated_image, image)`` on success. On any error the traceback
            is printed and ``(fallback, fallback)`` is returned, where fallback
            is ``image`` as RGB resized to 512x512.
        """
        try:
            if keep_environment:
                # Keep the surroundings, change the person -> Multi-ControlNet.
                print("🎯 ControlNet Modus: Umgebung beibehalten (Multi-ControlNet: OpenPose + Canny)")
                pose_image = self.extract_pose(image)
                canny_image = self.extract_canny_edges(image)
                print("✅ OpenPose + Canny Maps erstellt")
                controlnet_type = "multi"
                # Order must match _PIPELINE_SPECS["multi"]: OpenPose, then Canny.
                conditioning_images = [pose_image, canny_image]
                # 60 % of the strength for the pose (person), 40 % for the edges (surroundings).
                controlnet_conditioning_scale = [
                    controlnet_strength * 0.6,
                    controlnet_strength * 0.4,
                ]
            else:
                # Keep the person, change the surroundings -> OpenPose only.
                controlnet_type = "openpose"
                print("🎯 ControlNet Modus: Person beibehalten (OpenPose)")
                conditioning_images = self.extract_pose(image)
                controlnet_conditioning_scale = controlnet_strength

            if seed is None:
                seed = random.randint(0, 2**32 - 1)
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
            print(f"ControlNet Seed: {seed}")

            pipe = self.load_controlnet_pipeline(controlnet_type)
            num_steps = int(steps)
            callback = ControlNetProgressCallback(progress, num_steps) if progress is not None else None
            print("🔄 ControlNet: Starte Pipeline...")
            result = pipe(
                prompt=prompt,
                image=conditioning_images,
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                height=self._RESOLUTION,
                width=self._RESOLUTION,
                output_type="pil",
                callback_on_step_end=callback,
                # The callback only reports progress; it needs no tensors.
                callback_on_step_end_tensor_inputs=[],
            )
            print("✅ ControlNet abgeschlossen!")
            return result.images[0], image  # ControlNet output + original image
        except Exception as e:
            print(f"❌ Fehler in ControlNet: {e}")
            import traceback
            traceback.print_exc()
            error_image = self._fallback_image(image)
            return error_image, error_image

    def prepare_inpaint_input(self, image, keep_environment=False):
        """Prepare the input image for the inpaint step.

        Returns:
            ``(image_for_inpaint, conditioning_info)``:
            - ``keep_environment=True`` (change the person): the original image
              and ``{"type": "original", "image": image}``;
            - otherwise (change the surroundings): the extracted pose map and
              ``{"type": "pose", "image": pose_map}``.
        """
        if keep_environment:
            print("🎯 Inpaint: Übergebe Originalbild (Person ändern)")
            return image, {"type": "original", "image": image}
        print("🎯 Inpaint: Übergebe Pose-Map (Umgebung ändern)")
        pose_image = self.extract_pose(image)
        return pose_image, {"type": "pose", "image": pose_image}
# Shared module-level processor: GPU with half precision when CUDA is
# available, otherwise CPU with full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)