# Spaces: Starting on T4
| import torch | |
| from diffusers import StableDiffusionControlNetPipeline, ControlNetModel | |
| from controlnet_aux import OpenposeDetector | |
| from PIL import Image, ImageFilter | |
| import random | |
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
| # WICHTIG: Importiere die neuen SAM2-Klassen aus Transformers | |
| from transformers import Sam2Model, Sam2Processor | |
class ControlNetProgressCallback:
    """Per-step progress reporter for a diffusion pipeline run.

    The ``__call__`` signature (pipe, step_index, timestep, callback_kwargs) matches
    the diffusers ``callback_on_step_end`` hook. On every step the completed fraction
    is forwarded to an optional progress callable and printed to stdout.
    """

    def __init__(self, progress, total_steps):
        """
        Args:
            progress: Callable invoked as ``progress(fraction, desc=...)``, or None
                to only print (presumably a ``gr.Progress`` tracker — confirm at call site).
            total_steps: Number of steps used as the denominator of the fraction.
        """
        self.progress = progress
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        """Record step ``step_index`` as finished, report it, and return ``callback_kwargs`` unchanged."""
        done = step_index + 1
        self.current_step = done
        total = self.total_steps
        fraction = done / total

        reporter = self.progress
        if reporter is not None:
            reporter(fraction, desc=f"ControlNet: Schritt {done}/{total}")

        print(f"ControlNet Fortschritt: {done}/{total} ({fraction:.1%})")
        return callback_kwargs
class ControlNetProcessor:
    """Builds ControlNet conditioning maps and inpainting masks for an input image.

    SAM 2, the OpenPose detector and MiDaS are all loaded lazily on first use, so
    constructing the processor performs no downloads.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float32):
        """
        Args:
            device: Torch device the models are moved to (e.g. "cuda" or "cpu").
            torch_dtype: dtype SAM 2 is loaded in. MiDaS is loaded with its default
                dtype and only moved to ``device``.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.pose_detector = None
        self.midas_model = None
        self.midas_transform = None
        # SAM 2 components (🤗 Transformers API), populated by _lazy_load_sam().
        self.sam_processor = None
        self.sam_model = None
        # Set after the first load attempt (successful or not) so that a failed
        # load is not retried on every call.
        self.sam_initialized = False

    def _lazy_load_sam(self):
        """Load SAM 2 (hiera-tiny) via the 🤗 Transformers API on first use.

        Returns:
            True if the SAM 2 processor and model are available, otherwise False.
            Only the first call attempts a load; later calls report the cached state.
        """
        if self.sam_initialized:
            # Report actual availability instead of unconditionally True, so a
            # previously failed load is not mistaken for success.
            return self.sam_model is not None and self.sam_processor is not None

        try:
            print("🔄 Lade SAM 2 über 🤗 Transformers...")
            model_id = "facebook/sam2-hiera-tiny"
            self.sam_processor = Sam2Processor.from_pretrained(model_id)
            self.sam_model = Sam2Model.from_pretrained(model_id, torch_dtype=self.torch_dtype).to(self.device)
            self.sam_model.eval()  # inference only
            self.sam_initialized = True
            print("✅ SAM 2 erfolgreich geladen (via Transformers)")
            return True
        except Exception as e:
            print(f"❌ Fehler beim Laden von SAM 2: {str(e)[:200]}")
            self.sam_initialized = True  # prevent further load attempts
            return False

    def _validate_bbox(self, image, bbox_coords):
        """Normalise a box to integer pixel coordinates inside ``image``.

        Args:
            image: Object exposing ``size`` as (width, height), e.g. a PIL image.
            bbox_coords: Four values (x1, y1, x2, y2) in any corner order.

        Returns:
            (x1, y1, x2, y2) as ints with x1 <= x2 and y1 <= y2, clamped to
            [0, width - 1] x [0, height - 1]. If the clamped box is narrower or shorter
            than 10 px, a centred square spanning 30% of the shorter image side is
            returned instead.
        """
        width, height = image.size
        x1, y1, x2, y2 = bbox_coords

        # Order the corners.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        # Clamp to the image bounds.
        x1 = max(0, min(x1, width - 1))
        y1 = max(0, min(y1, height - 1))
        x2 = max(0, min(x2, width - 1))
        y2 = max(0, min(y2, height - 1))

        # Degenerate box: fall back to a reasonably sized centred square.
        if x2 - x1 < 10 or y2 - y1 < 10:
            size = min(width, height) * 0.3
            x1 = max(0, width/2 - size/2)
            y1 = max(0, height/2 - size/2)
            x2 = min(width, width/2 + size/2)
            y2 = min(height, height/2 + size/2)

        return int(x1), int(y1), int(x2), int(y2)

    def _smooth_mask(self, mask_array, blur_radius=3):
        """Median-blur a uint8 mask with kernel size ``2 * blur_radius + 1``.

        A median filter is used rather than a Gaussian because it preserves edges
        better. Returns the input unchanged if ``blur_radius <= 0`` or on error.
        """
        try:
            if blur_radius > 0:
                mask_array = cv2.medianBlur(mask_array, blur_radius*2+1)
            return mask_array
        except Exception as e:
            print(f"⚠️ Fehler beim Glätten der Maske: {e}")
            return mask_array

    def create_sam_mask(self, image, bbox_coords, mode):
        """Create a precise segmentation mask with SAM 2 (🤗 Transformers API).

        Args:
            image: PIL image to segment.
            bbox_coords: (x1, y1, x2, y2) box prompt in pixels; normalised by _validate_bbox().
            mode: "environment_change" inverts the mask so the object is kept and its
                surroundings are changed. Any other value marks the object itself for change.

        Returns:
            PIL image in mode "L" (0 = black = keep, 255 = white = change). Falls back to
            _create_rectangular_mask() if SAM 2 is unavailable or inference fails.
        """
        try:
            # 1. Load SAM 2 on first use.
            if not self.sam_initialized:
                self._lazy_load_sam()

            if self.sam_model is None or self.sam_processor is None:
                print("⚠️ SAM 2 Model nicht verfügbar, verwende Fallback")
                return self._create_rectangular_mask(image, bbox_coords, mode)

            # 2. Validate the box and convert the image.
            x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)
            width, height = image.size
            image_np = np.array(image.convert("RGB"))

            # 3. Build the model inputs. Boxes are nested three levels deep:
            #    [image][box][x_min, y_min, x_max, y_max].
            input_boxes = [[[x1, y1, x2, y2]]]
            # Cast floating-point inputs to the model's dtype as well as moving them.
            # With float16 weights on CUDA, float32 pixel_values would make the
            # forward pass fail with a dtype mismatch. BatchFeature.to leaves integer
            # tensors such as original_sizes uncast.
            inputs = self.sam_processor(
                images=image_np,
                input_boxes=input_boxes,
                return_tensors="pt",
            ).to(self.device, dtype=self.sam_model.dtype)

            # 4. Predict. A single box prompt is unambiguous, so request one mask
            #    instead of three candidates.
            print(f"🎯 SAM 2: Segmentiere Bereich {x1},{y1}-{x2},{y2}")
            with torch.no_grad():
                outputs = self.sam_model(**inputs, multimask_output=False)

            # 5. Upscale the low-resolution logits to the original image size.
            #    Sam2Processor.post_process_masks takes only the original sizes (there
            #    are no reshaped_input_sizes) and binarises at logit 0 by default.
            #    Cast to float32 first: half-precision interpolation on CPU is not
            #    supported by every torch build.
            masks = self.sam_processor.post_process_masks(
                outputs.pred_masks.float().cpu(),
                inputs["original_sizes"],
            )[0]
            mask_tensor = masks
            # Drop object/mask dimensions down to a single (H, W) map.
            while mask_tensor.dim() > 2:
                mask_tensor = mask_tensor[0]
            if mask_tensor.dtype != torch.bool:
                # Non-binarised logits: logit > 0 is the same as sigmoid(logit) > 0.5.
                mask_tensor = mask_tensor > 0
            mask_array = mask_tensor.numpy().astype(np.uint8) * 255

            # 6. Convert to PIL; the resize is a safety net should sizes differ.
            mask = Image.fromarray(mask_array).convert("L")
            mask = mask.resize((width, height), Image.Resampling.NEAREST)

            # 7. Smooth the edges for more natural transitions.
            mask_array = self._smooth_mask(np.array(mask), blur_radius=2)
            mask = Image.fromarray(mask_array).convert("L")

            # 8. Mode-specific inversion.
            if mode == "environment_change":
                # MODE 1: change the surroundings; the object turns black (kept).
                mask = Image.eval(mask, lambda x: 255 - x)
                print(" SAM-Modus: Umgebung ändern (Objekt erhalten)")
            else:
                # MODES 2 & 3: change focus/face; the object stays white (changed).
                print(" SAM-Modus: Focus/Gesicht ändern (Objekt verändern)")

            print(f"✅ SAM 2: Präzise Maske erstellt ({mask.size})")
            return mask

        except Exception as e:
            print(f"⚠️ SAM 2 Fehler (Transformers API): {str(e)[:200]}")
            import traceback
            traceback.print_exc()
            print("ℹ️ Fallback auf rechteckige Maske")
            return self._create_rectangular_mask(image, bbox_coords, mode)

    def _create_rectangular_mask(self, image, bbox_coords, mode):
        """Fallback mask: a plain rectangle from the (validated) bounding box.

        Returns an "L" image the size of ``image``. Without a complete box (None, or
        any coordinate None), the mask is all black, i.e. nothing is changed.
        """
        from PIL import ImageDraw
        mask = Image.new("L", image.size, 0)

        if bbox_coords and all(coord is not None for coord in bbox_coords):
            x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)
            draw = ImageDraw.Draw(mask)

            if mode == "environment_change":
                # MODE 1: change everything except the box.
                draw.rectangle([0, 0, image.size[0], image.size[1]], fill=255)
                draw.rectangle([x1, y1, x2, y2], fill=0)
                print("ℹ️ Rechteckige Maske: Umgebung ändern")
            else:
                # MODES 2 & 3: change only the box.
                draw.rectangle([x1, y1, x2, y2], fill=255)
                print("ℹ️ Rechteckige Maske: Focus/Gesicht ändern")

        return mask

    def load_pose_detector(self):
        """Lazily load the OpenPose detector; returns it, or None if loading failed.

        A failed load leaves ``pose_detector`` as None, so the next call retries.
        """
        if self.pose_detector is None:
            print("Loading Pose Detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
                print("✅ Pose-Detector geladen")
            except Exception as e:
                print(f"⚠️ Pose-Detector konnte nicht geladen werden: {e}")
        return self.pose_detector

    def load_midas_model(self):
        """Lazily load MiDaS DPT_Hybrid from torch.hub plus its preprocessing transform.

        The transform resizes the shorter side to 384 px, converts to a tensor, and
        normalises each channel to [-1, 1]. Returns the model, or None if loading
        failed; a failed load is retried on the next call.
        """
        if self.midas_model is None:
            print("🔄 Lade MiDaS Modell für Depth Maps...")
            try:
                import torchvision.transforms as T

                self.midas_model = torch.hub.load(
                    "intel-isl/MiDaS",
                    "DPT_Hybrid",
                    trust_repo=True
                )
                self.midas_model.to(self.device)
                self.midas_model.eval()

                self.midas_transform = T.Compose([
                    T.Resize(384),
                    T.ToTensor(),
                    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ])
                print("✅ MiDaS Modell erfolgreich geladen")
            except Exception as e:
                print(f"❌ MiDaS konnte nicht geladen werden: {e}")
                print("ℹ️ Verwende Fallback-Methode")
                self.midas_model = None
        return self.midas_model

    def extract_pose_simple(self, image):
        """Approximate a pose map with Canny edges (no pose model required).

        On error, returns the image converted to RGB and resized to 512x512.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            edges = cv2.Canny(img_array, 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
            print("⚠️ Verwende Kanten-basierte Pose-Approximation")
            return pose_image
        except Exception as e:
            print(f"Fehler bei einfacher Pose-Extraktion: {e}")
            return image.convert("RGB").resize((512, 512))

    def extract_pose(self, image):
        """Extract an OpenPose map (body, hands and face), falling back to extract_pose_simple()."""
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)
            pose_image = detector(image, hand_and_face=True)
            return pose_image
        except Exception as e:
            print(f"Fehler bei Pose-Extraktion: {e}")
            return self.extract_pose_simple(image)

    def extract_canny_edges(self, image):
        """Extract a 3-channel Canny edge map (thresholds 100/200) from the grayscale image.

        On error, returns the image converted to RGB and resized to 512x512.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
            edges_image = Image.fromarray(edges_rgb)
            print("✅ Canny Edge Map erstellt")
            return edges_image
        except Exception as e:
            print(f"Fehler bei Canny Edge Extraction: {e}")
            return image.convert("RGB").resize((512, 512))

    def extract_depth_map(self, image):
        """Extract a depth map with MiDaS, falling back to a blurred grayscale image.

        The MiDaS prediction is resized back to the input size and min-max scaled to
        0..255 as an RGB image. If MiDaS is unavailable or fails, a 5x5 Gaussian-blurred
        grayscale copy is returned. If that also fails, the RGB image resized to
        512x512 is returned.
        """
        try:
            midas = self.load_midas_model()
            if midas is not None:
                print("🎯 Verwende MiDaS für Depth Map...")
                # The transform normalises exactly three channels, so RGBA/L/P input
                # must be converted first.
                img_transformed = self.midas_transform(image.convert("RGB")).unsqueeze(0).to(self.device)

                with torch.no_grad():
                    prediction = midas(img_transformed)
                    prediction = torch.nn.functional.interpolate(
                        prediction.unsqueeze(1),
                        size=image.size[::-1],  # PIL (W, H) -> torch (H, W)
                        mode="bicubic",
                        align_corners=False,
                    ).squeeze()

                depth_np = prediction.cpu().numpy()
                depth_min, depth_max = depth_np.min(), depth_np.max()
                if depth_max > depth_min:
                    depth_np = (depth_np - depth_min) / (depth_max - depth_min)
                else:
                    # Constant prediction: without normalisation the *255 below would
                    # overflow uint8, so emit a flat map instead.
                    depth_np = np.zeros_like(depth_np)
                depth_np = (depth_np * 255).astype(np.uint8)

                depth_image = Image.fromarray(depth_np).convert("RGB")
                print("✅ MiDaS Depth Map erfolgreich erstellt")
                return depth_image
            else:
                raise Exception("MiDaS nicht geladen")

        except Exception as e:
            print(f"⚠️ MiDaS Fehler: {e}. Verwende Fallback...")
            try:
                img_array = np.array(image.convert("RGB"))
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
                depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
                depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
                depth_image = Image.fromarray(depth_rgb)
                print("✅ Fallback Depth Map erstellt")
                return depth_image
            except Exception as fallback_error:
                print(f"❌ Auch Fallback fehlgeschlagen: {fallback_error}")
                return image.convert("RGB").resize((512, 512))

    def prepare_controlnet_maps(self, image, keep_environment=False):
        """Build the ControlNet conditioning maps only; no image is generated.

        Returns:
            [depth, canny] if ``keep_environment`` else [pose, canny]. The order
            presumably has to match the ControlNet models this is fed to — confirm
            at the call site.
        """
        print("🎯 ControlNet: Erstelle Conditioning-Maps...")

        if keep_environment:
            print(" Modus: Depth + Canny")
            conditioning_images = [
                self.extract_depth_map(image),
                self.extract_canny_edges(image)
            ]
        else:
            print(" Modus: OpenPose + Canny")
            conditioning_images = [
                self.extract_pose(image),
                self.extract_canny_edges(image)
            ]

        print(f"✅ {len(conditioning_images)} Conditioning-Maps erstellt.")
        return conditioning_images
# Global instance shared by the module: CUDA with float16 when a GPU is available,
# otherwise CPU with float32.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)