Stable-ControlNet-GPU / controlnet_module.py
Astridkraft's picture
Update controlnet_module.py
555cf3f verified
raw
history blame
14.8 kB
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import OpenposeDetector
from PIL import Image, ImageFilter
import random
import cv2
import numpy as np
import gradio as gr
# WICHTIG: Importiere die neuen SAM2-Klassen aus Transformers
from transformers import Sam2Model, Sam2Processor
class ControlNetProgressCallback:
def __init__(self, progress, total_steps):
self.progress = progress
self.total_steps = total_steps
self.current_step = 0
def __call__(self, pipe, step_index, timestep, callback_kwargs):
self.current_step = step_index + 1
progress_percentage = self.current_step / self.total_steps
if self.progress is not None:
self.progress(progress_percentage, desc=f"ControlNet: Schritt {self.current_step}/{self.total_steps}")
print(f"ControlNet Fortschritt: {self.current_step}/{self.total_steps} ({progress_percentage:.1%})")
return callback_kwargs
class ControlNetProcessor:
def __init__(self, device="cuda", torch_dtype=torch.float32):
self.device = device
self.torch_dtype = torch_dtype
self.pose_detector = None
self.midas_model = None
self.midas_transform = None
# Ändere die Variablennamen für die neue API
self.sam_processor = None
self.sam_model = None
self.sam_initialized = False
def _lazy_load_sam(self):
"""Lazy Loading von SAM 2 über 🤗 Transformers API"""
if self.sam_initialized:
return True
try:
print("🔄 Lade SAM 2 über 🤗 Transformers...")
# Die korrekte Modell-ID für SAM 2 Tiny
model_id = "facebook/sam2-hiera-tiny"
# Lade Processor und Modell mit der neuen API
self.sam_processor = Sam2Processor.from_pretrained(model_id)
self.sam_model = Sam2Model.from_pretrained(model_id, torch_dtype=self.torch_dtype).to(self.device)
self.sam_model.eval() # Setze Modell in Evaluierungsmodus
self.sam_initialized = True
print("✅ SAM 2 erfolgreich geladen (via Transformers)")
return True
except Exception as e:
print(f"❌ Fehler beim Laden von SAM 2: {str(e)[:200]}")
self.sam_initialized = True # Verhindert weitere Ladeversuche
return False
def _validate_bbox(self, image, bbox_coords):
"""Validiert und korrigiert BBox-Koordinaten"""
width, height = image.size
# Extrahiere Koordinaten - unterstützt beide Formate
if isinstance(bbox_coords, (list, tuple)) and len(bbox_coords) == 4:
x1, y1, x2, y2 = bbox_coords
else:
# Für den Fall, dass Koordinaten einzeln übergeben werden
x1, y1, x2, y2 = bbox_coords
# Stelle sicher, dass x1 <= x2 und y1 <= y2
x1, x2 = min(x1, x2), max(x1, x2)
y1, y2 = min(y1, y2), max(y1, y2)
# Begrenze auf Bildgrenzen
x1 = max(0, min(x1, width - 1))
y1 = max(0, min(y1, height - 1))
x2 = max(0, min(x2, width - 1))
y2 = max(0, min(y2, height - 1))
# Stelle sicher, dass BBox gültig ist
if x2 - x1 < 10 or y2 - y1 < 10:
# Fallback auf sinnvolle Größe
size = min(width, height) * 0.3
x1 = max(0, width/2 - size/2)
y1 = max(0, height/2 - size/2)
x2 = min(width, width/2 + size/2)
y2 = min(height, height/2 + size/2)
return int(x1), int(y1), int(x2), int(y2)
def _smooth_mask(self, mask_array, blur_radius=3):
"""Glättet die Maske für bessere Übergänge"""
try:
if blur_radius > 0:
# Verwende median blur für bessere Kantenerhaltung als Gaussian
mask_array = cv2.medianBlur(mask_array, blur_radius*2+1)
return mask_array
except Exception as e:
print(f"⚠️ Fehler beim Glätten der Maske: {e}")
return mask_array
def create_sam_mask(self, image, bbox_coords, mode):
"""
Erstellt präzise Maske mit SAM 2 (via 🤗 Transformers API)
Gibt PIL Image in L-Modus zurück (0=schwarz=erhalten, 255=weiß=verändern)
"""
try:
# 1. SAM2 laden (falls noch nicht geschehen)
if not self.sam_initialized:
self._lazy_load_sam()
if self.sam_model is None or self.sam_processor is None:
print("⚠️ SAM 2 Model nicht verfügbar, verwende Fallback")
return self._create_rectangular_mask(image, bbox_coords, mode)
# 2. Validiere BBox und konvertiere Bild
x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)
width, height = image.size
# Konvertiere zu numpy array (RGB) - für SAM2 Processor
image_np = np.array(image.convert("RGB"))
# 3. Vorbereiten der Eingabe für SAM2
# BBox im Format [x_min, y_min, x_max, y_max] erstellen
# ACHTUNG: SAM2 erwartet Boxen in diesem Format
# Zeilen in der Funktion anpassen:
input_boxes = [[[x1, y1, x2, y2]]] #Dreifach verschachtelt
# Bild mit dem Processor vorverarbeiten
inputs = self.sam_processor(
image_np,
input_boxes=input_boxes,
return_tensors="pt"
).to(self.device)
# 4. Vorhersage mit dem Modell
print(f"🎯 SAM 2: Segmentiere Bereich {x1},{y1}-{x2},{y2}")
with torch.no_grad():
outputs = self.sam_model(**inputs)
# 5. Maske extrahieren und verarbeiten
# outputs.pred_masks enthält die Masken-Logits
# post_process_masks stellt die Originalgröße wieder her
mask = self.sam_processor.post_process_masks(
outputs.pred_masks,
inputs.original_sizes,
inputs.reshaped_input_sizes
)[0][0] # [batch_index][mask_index]
# Sigmoid für Wahrscheinlichkeiten, dann Schwellenwert
mask = mask.sigmoid().cpu().numpy()
mask_array = (mask > 0.5).astype(np.uint8) * 255
# 6. Zu PIL Image konvertieren und auf Originalgröße bringen
mask = Image.fromarray(mask_array.squeeze()).convert("L")
mask = mask.resize((width, height), Image.Resampling.NEAREST)
# 7. Kanten glätten für natürlichere Übergänge
mask_array = np.array(mask)
mask_array = self._smooth_mask(mask_array, blur_radius=2)
mask = Image.fromarray(mask_array).convert("L")
# 8. Modus-spezifische Anpassung (Invertierung)
if mode == "environment_change":
# MODUS 1: Umgebung ändern - Objekt schwarz (erhalten)
mask = Image.eval(mask, lambda x: 255 - x)
print(" SAM-Modus: Umgebung ändern (Objekt erhalten)")
else:
# MODUS 2 & 3: Focus/Gesicht ändern - Objekt weiß (verändern)
print(" SAM-Modus: Focus/Gesicht ändern (Objekt verändern)")
print(f"✅ SAM 2: Präzise Maske erstellt ({mask.size})")
return mask
except Exception as e:
print(f"⚠️ SAM 2 Fehler (Transformers API): {str(e)[:200]}")
import traceback
traceback.print_exc()
print("ℹ️ Fallback auf rechteckige Maske")
return self._create_rectangular_mask(image, bbox_coords, mode)
def _create_rectangular_mask(self, image, bbox_coords, mode):
"""Fallback: Erstellt rechteckige Maske"""
from PIL import ImageDraw
mask = Image.new("L", image.size, 0)
if bbox_coords and all(coord is not None for coord in bbox_coords):
x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)
draw = ImageDraw.Draw(mask)
if mode == "environment_change":
# MODUS 1: Alles außer Box verändern
draw.rectangle([0, 0, image.size[0], image.size[1]], fill=255)
draw.rectangle([x1, y1, x2, y2], fill=0)
print("ℹ️ Rechteckige Maske: Umgebung ändern")
else:
# MODUS 2 & 3: Nur Box verändern
draw.rectangle([x1, y1, x2, y2], fill=255)
print("ℹ️ Rechteckige Maske: Focus/Gesicht ändern")
return mask
def load_pose_detector(self):
"""Lädt nur den Pose-Detector"""
if self.pose_detector is None:
print("Loading Pose Detector...")
try:
self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
print("✅ Pose-Detector geladen")
except Exception as e:
print(f"⚠️ Pose-Detector konnte nicht geladen werden: {e}")
return self.pose_detector
def load_midas_model(self):
"""Lädt MiDaS Model für Depth Maps"""
if self.midas_model is None:
print("🔄 Lade MiDaS Modell für Depth Maps...")
try:
import torchvision.transforms as T
self.midas_model = torch.hub.load(
"intel-isl/MiDaS",
"DPT_Hybrid",
trust_repo=True
)
self.midas_model.to(self.device)
self.midas_model.eval()
self.midas_transform = T.Compose([
T.Resize(384),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
print("✅ MiDaS Modell erfolgreich geladen")
except Exception as e:
print(f"❌ MiDaS konnte nicht geladen werden: {e}")
print("ℹ️ Verwende Fallback-Methode")
self.midas_model = None
return self.midas_model
def extract_pose_simple(self, image):
"""Einfache Pose-Extraktion ohne komplexe Abhängigkeiten"""
try:
img_array = np.array(image.convert("RGB"))
edges = cv2.Canny(img_array, 100, 200)
pose_image = Image.fromarray(edges).convert("RGB")
print("⚠️ Verwende Kanten-basierte Pose-Approximation")
return pose_image
except Exception as e:
print(f"Fehler bei einfacher Pose-Extraktion: {e}")
return image.convert("RGB").resize((512, 512))
def extract_pose(self, image):
"""Extrahiert Pose-Map aus Bild mit Fallback"""
try:
detector = self.load_pose_detector()
if detector is None:
return self.extract_pose_simple(image)
pose_image = detector(image, hand_and_face=True)
return pose_image
except Exception as e:
print(f"Fehler bei Pose-Extraktion: {e}")
return self.extract_pose_simple(image)
def extract_canny_edges(self, image):
"""Extrahiert Canny Edges für Umgebungserhaltung"""
try:
img_array = np.array(image.convert("RGB"))
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 100, 200)
edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
edges_image = Image.fromarray(edges_rgb)
print("✅ Canny Edge Map erstellt")
return edges_image
except Exception as e:
print(f"Fehler bei Canny Edge Extraction: {e}")
return image.convert("RGB").resize((512, 512))
def extract_depth_map(self, image):
"""
Extrahiert Depth Map mit MiDaS (Fallback auf Filter)
"""
try:
midas = self.load_midas_model()
if midas is not None:
print("🎯 Verwende MiDaS für Depth Map...")
import torchvision.transforms as T
img_transformed = self.midas_transform(image).unsqueeze(0).to(self.device)
with torch.no_grad():
prediction = midas(img_transformed)
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=image.size[::-1],
mode="bicubic",
align_corners=False,
).squeeze()
depth_np = prediction.cpu().numpy()
depth_min, depth_max = depth_np.min(), depth_np.max()
if depth_max > depth_min:
depth_np = (depth_np - depth_min) / (depth_max - depth_min)
depth_np = (depth_np * 255).astype(np.uint8)
depth_image = Image.fromarray(depth_np).convert("RGB")
print("✅ MiDaS Depth Map erfolgreich erstellt")
return depth_image
else:
raise Exception("MiDaS nicht geladen")
except Exception as e:
print(f"⚠️ MiDaS Fehler: {e}. Verwende Fallback...")
try:
img_array = np.array(image.convert("RGB"))
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
depth_image = Image.fromarray(depth_rgb)
print("✅ Fallback Depth Map erstellt")
return depth_image
except Exception as fallback_error:
print(f"❌ Auch Fallback fehlgeschlagen: {fallback_error}")
return image.convert("RGB").resize((512, 512))
def prepare_controlnet_maps(self, image, keep_environment=False):
"""
ERSTELLT NUR CONDITIONING-MAPS, generiert KEIN Bild.
"""
print("🎯 ControlNet: Erstelle Conditioning-Maps...")
if keep_environment:
print(" Modus: Depth + Canny")
conditioning_images = [
self.extract_depth_map(image),
self.extract_canny_edges(image)
]
else:
print(" Modus: OpenPose + Canny")
conditioning_images = [
self.extract_pose(image),
self.extract_canny_edges(image)
]
print(f"✅ {len(conditioning_images)} Conditioning-Maps erstellt.")
return conditioning_images
# Globale Instanz
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)