import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import OpenposeDetector
from PIL import Image, ImageDraw, ImageFilter
import random
import cv2
import numpy as np
import gradio as gr

# SAM 2 is imported lazily inside ControlNetProcessor._lazy_load_sam,
# so the module still imports cleanly when the sam2 package is missing.


class ControlNetProgressCallback:
    """Step-end callback that forwards denoising progress to a gradio progress bar."""

    def __init__(self, progress, total_steps):
        self.progress = progress
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        self.current_step = step_index + 1
        progress_percentage = self.current_step / self.total_steps
        if self.progress is not None:
            self.progress(
                progress_percentage,
                desc=f"ControlNet: step {self.current_step}/{self.total_steps}",
            )
        print(f"ControlNet progress: {self.current_step}/{self.total_steps} ({progress_percentage:.1%})")
        return callback_kwargs
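# How this callback is meant to be wired up: recent diffusers pipelines accept a
# `callback_on_step_end` callable with exactly the (pipe, step_index, timestep,
# callback_kwargs) signature used above, and expect callback_kwargs back. A
# minimal sketch, assuming a pipeline object `pipe` built elsewhere and an
# optional gradio progress bar; the helper name and prompt are illustrative,
# not part of this module's public API.
def _example_progress_usage(pipe, conditioning_image, prompt, progress=None, steps=30):
    """Hypothetical helper: run a ControlNet pipeline and report per-step progress."""
    callback = ControlNetProgressCallback(progress, total_steps=steps)
    result = pipe(
        prompt=prompt,
        image=conditioning_image,
        num_inference_steps=steps,
        callback_on_step_end=callback,  # invoked once after every denoising step
    )
    return result.images[0]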
class ControlNetProcessor:
    def __init__(self, device="cuda", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype
        self.pose_detector = None
        self.midas_model = None
        self.midas_transform = None
        self.sam_predictor = None
        self.sam_initialized = False

    def _lazy_load_sam(self):
        """Lazily load SAM 2 Tiny - optimized for Hugging Face Spaces."""
        if self.sam_initialized:
            return self.sam_predictor is not None
        try:
            print("🔄 Loading SAM 2 Tiny from the Hugging Face Hub...")
            # Requires Meta's SAM 2 package (https://github.com/facebookresearch/sam2);
            # imported lazily so the rest of the module works without it.
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            model_id = "facebook/sam2-hiera-tiny"
            self.sam_predictor = SAM2ImagePredictor.from_pretrained(model_id, device=self.device)
            self.sam_initialized = True
            print(f"✅ SAM 2 ({model_id}) loaded successfully")
            return True
        except Exception as e:
            print(f"❌ SAM 2 could not be loaded: {str(e)[:100]}")
            print("ℹ️ Using rectangular masks as a fallback")
            self.sam_predictor = None
            self.sam_initialized = True  # prevents further load attempts
            return False

    def _validate_bbox(self, image, bbox_coords):
        """Validate and clamp bounding-box coordinates."""
        width, height = image.size
        x1, y1, x2, y2 = bbox_coords

        # Ensure x1 <= x2 and y1 <= y2
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        # Clamp to the image bounds
        x1 = max(0, min(x1, width - 1))
        y1 = max(0, min(y1, height - 1))
        x2 = max(0, min(x2, width - 1))
        y2 = max(0, min(y2, height - 1))

        # If the box degenerates, fall back to a sensible size
        if x2 - x1 < 10 or y2 - y1 < 10:
            size = min(width, height) * 0.3
            x1 = max(0, width / 2 - size / 2)
            y1 = max(0, height / 2 - size / 2)
            x2 = min(width, width / 2 + size / 2)
            y2 = min(height, height / 2 + size / 2)

        return int(x1), int(y1), int(x2), int(y2)

    def _smooth_mask(self, mask_array, blur_radius=3):
        """Smooth the mask for softer transitions (~5-pixel border)."""
        try:
            # Gaussian blur for soft edges - only the border region is affected
            if blur_radius > 0:
                mask_array = cv2.GaussianBlur(mask_array, (blur_radius * 2 + 1, blur_radius * 2 + 1), 0)
            return mask_array
        except Exception:
            return mask_array

    def create_sam_mask(self, image, bbox_coords, mode):
        """
        Create a precise mask with SAM 2 (transparent to the user).
        Returns a PIL image in L mode (0 = black = keep, 255 = white = change).
        """
        try:
            # Load SAM on demand (automatic on Hugging Face Spaces)
            if not self.sam_initialized:
                self._lazy_load_sam()

            # Fall back if SAM is unavailable
            if self.sam_predictor is None:
                return self._create_rectangular_mask(image, bbox_coords, mode)

            # Validate the bbox
            x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)

            # Convert to a numpy array (RGB)
            image_np = np.array(image.convert("RGB"))

            # Prepare SAM
            try:
                self.sam_predictor.set_image(image_np)
            except Exception as e:
                print(f"⚠️ SAM set_image error: {e}")
                return self._create_rectangular_mask(image, bbox_coords, mode)

            # Format the bbox for SAM
            input_box = np.array([x1, y1, x2, y2])
            print(f"🎯 SAM 2: segmenting region {x1},{y1}-{x2},{y2}")

            # SAM prediction
            masks, scores, _ = self.sam_predictor.predict(
                point_coords=None,
                point_labels=None,
                box=input_box[None, :],
                multimask_output=False,
                return_logits=False,
            )

            # Extract the best mask and smooth it (~5-pixel transition)
            mask_array = (masks[0] > 0).astype(np.uint8) * 255
            mask_array = self._smooth_mask(mask_array, blur_radius=2)  # ~5-pixel border

            # Convert to a PIL image
            mask = Image.fromarray(mask_array).convert("L")

            # Mode-specific adjustment
            if mode == "environment_change":
                # MODE 1: change the environment
                # object black (0) = KEEP, environment white (255) = CHANGE
                mask = Image.eval(mask, lambda x: 255 - x)
                print("   SAM mode: change environment (keep object)")
            else:
                # MODES 2 & 3: change focus or face
                # object white (255) = CHANGE, environment black (0) = KEEP
                print("   SAM mode: change focus/face (change object)")

            print(f"✅ SAM 2: precise mask created ({mask.size})")
            return mask

        except Exception as e:
            print(f"⚠️ SAM 2 error: {str(e)[:100]}")
            print("ℹ️ Falling back to a rectangular mask")
            return self._create_rectangular_mask(image, bbox_coords, mode)

    def _create_rectangular_mask(self, image, bbox_coords, mode):
        """Fallback: create a rectangular mask."""
        mask = Image.new("L", image.size, 0)

        if bbox_coords and all(coord is not None for coord in bbox_coords):
            x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)
            draw = ImageDraw.Draw(mask)

            if mode == "environment_change":
                # MODE 1: change everything except the box
                draw.rectangle([0, 0, image.size[0], image.size[1]], fill=255)
                draw.rectangle([x1, y1, x2, y2], fill=0)
            else:
                # MODES 2 & 3: change only the box
                draw.rectangle([x1, y1, x2, y2], fill=255)

        print("ℹ️ Rectangular mask (SAM fallback)")
        return mask

    def load_pose_detector(self):
        """Load only the pose detector."""
        if self.pose_detector is None:
            print("Loading pose detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
                print("✅ Pose detector loaded")
            except Exception as e:
                print(f"⚠️ Pose detector could not be loaded: {e}")
        return self.pose_detector

    def load_midas_model(self):
        """Load the MiDaS model for depth maps."""
        if self.midas_model is None:
            print("🔄 Loading the MiDaS model for depth maps...")
            try:
                import torchvision.transforms as T

                self.midas_model = torch.hub.load(
                    "intel-isl/MiDaS", "DPT_Hybrid", trust_repo=True
                )
                self.midas_model.to(self.device)
                self.midas_model.eval()
                # DPT needs input sides divisible by 32; a fixed 384x384 square is
                # safe because the prediction is resized back to the input size later.
                self.midas_transform = T.Compose([
                    T.Resize((384, 384)),
                    T.ToTensor(),
                    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ])
                print("✅ MiDaS model loaded successfully")
            except Exception as e:
                print(f"❌ MiDaS could not be loaded: {e}")
                print("ℹ️ Using fallback method")
                self.midas_model = None
        return self.midas_model

    def extract_pose_simple(self, image):
        """Simple pose extraction without heavy dependencies."""
        try:
            # Canny expects a single-channel 8-bit image
            gray = np.array(image.convert("L"))
            edges = cv2.Canny(gray, 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
            print("⚠️ Using edge-based pose approximation")
            return pose_image
        except Exception as e:
            print(f"Error during simple pose extraction: {e}")
            return image.convert("RGB").resize((512, 512))

    def extract_pose(self, image):
        """Extract a pose map from the image, with fallback."""
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)
            pose_image = detector(image, hand_and_face=True)
            return pose_image
        except Exception as e:
            print(f"Error during pose extraction: {e}")
            return self.extract_pose_simple(image)

    def extract_canny_edges(self, image):
        """Extract Canny edges to preserve the environment."""
        try:
            img_array = np.array(image.convert("RGB"))
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
            edges_image = Image.fromarray(edges_rgb)
            print("✅ Canny edge map created")
            return edges_image
        except Exception as e:
            print(f"Error during Canny edge extraction: {e}")
            return image.convert("RGB").resize((512, 512))

    def extract_depth_map(self, image):
        """Extract a depth map with MiDaS (falls back to a blur filter)."""
        try:
            midas = self.load_midas_model()
            if midas is None:
                raise RuntimeError("MiDaS not loaded")

            print("🎯 Using MiDaS for the depth map...")
            img_transformed = self.midas_transform(image.convert("RGB")).unsqueeze(0).to(self.device)

            with torch.no_grad():
                prediction = midas(img_transformed)
                prediction = torch.nn.functional.interpolate(
                    prediction.unsqueeze(1),
                    size=image.size[::-1],  # PIL size is (W, H); interpolate wants (H, W)
                    mode="bicubic",
                    align_corners=False,
                ).squeeze()

            # Normalize the raw depth to 0-255
            depth_np = prediction.cpu().numpy()
            depth_min, depth_max = depth_np.min(), depth_np.max()
            if depth_max > depth_min:
                depth_np = (depth_np - depth_min) / (depth_max - depth_min)
            depth_np = (depth_np * 255).astype(np.uint8)

            depth_image = Image.fromarray(depth_np).convert("RGB")
            print("✅ MiDaS depth map created successfully")
            return depth_image

        except Exception as e:
            print(f"⚠️ MiDaS error: {e}. Using fallback...")
            try:
                img_array = np.array(image.convert("RGB"))
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
                depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
                depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
                depth_image = Image.fromarray(depth_rgb)
                print("✅ Fallback depth map created")
                return depth_image
            except Exception as fallback_error:
                print(f"❌ Fallback also failed: {fallback_error}")
                return image.convert("RGB").resize((512, 512))

    def prepare_controlnet_maps(self, image, keep_environment=False):
        """
        Creates ONLY conditioning maps; does NOT generate an image.
        """
        print("🎯 ControlNet: creating conditioning maps...")
        if keep_environment:
            print("   Mode: depth + Canny")
            conditioning_images = [
                self.extract_depth_map(image),
                self.extract_canny_edges(image),
            ]
        else:
            print("   Mode: OpenPose + Canny")
            conditioning_images = [
                self.extract_pose(image),
                self.extract_canny_edges(image),
            ]
        print(f"✅ {len(conditioning_images)} conditioning maps created.")
        return conditioning_images
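# How the outputs above are meant to be consumed, as a minimal sketch: the
# L-mode mask from `create_sam_mask` (0 = keep, 255 = change) matches the
# `mask_image` argument of diffusers inpainting pipelines. `inpaint_pipe`, the
# helper name, and the prompt are illustrative assumptions, not part of this
# module.
def _example_background_swap(inpaint_pipe, processor, image, bbox, prompt):
    """Hypothetical helper: repaint the background, keep the boxed subject."""
    # "environment_change" inverts the SAM mask, so the detected object
    # stays black (preserved) and everything else is repainted.
    mask = processor.create_sam_mask(image, bbox, mode="environment_change")
    return inpaint_pipe(prompt=prompt, image=image, mask_image=mask).images[0]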
# Global instance
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)
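# A small smoke test for the global instance, assuming a local test image
# exists; file names are purely illustrative.
if __name__ == "__main__":
    test_image = Image.open("input.jpg").convert("RGB")  # hypothetical input file
    maps = controlnet_processor.prepare_controlnet_maps(test_image, keep_environment=True)
    for i, cond_map in enumerate(maps):
        cond_map.save(f"conditioning_map_{i}.png")  # depth map, then Canny edges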