Stable-ControlNet-GPU / controlnet_module.py
Astridkraft's picture
Update controlnet_module.py
3f045c5 verified
raw
history blame
17.6 kB
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import OpenposeDetector
from PIL import Image
import random
import cv2
import numpy as np
import gradio as gr
class ControlNetProgressCallback:
    """Step-end callback that reports ControlNet denoising progress.

    An instance is passed to the pipeline as ``callback_on_step_end`` and is
    called as ``callback(pipe, step_index, timestep, callback_kwargs)`` after
    each denoising step.

    Args:
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            (presumably a ``gr.Progress`` tracker from the Gradio app);
            ``None`` disables UI progress reporting.
        total_steps: Number of inference steps the caller requested.
    """

    def __init__(self, progress, total_steps):
        self.progress = progress
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        """Record a finished step, report progress and return ``callback_kwargs``.

        ``callback_kwargs`` is handed back unchanged, so the pipeline's tensors
        are not modified by this callback.
        """
        self.current_step = step_index + 1
        # Avoid ZeroDivisionError for a non-positive step count and clamp to
        # 1.0 in case the pipeline runs more steps than were requested.
        if self.total_steps > 0:
            progress_percentage = min(self.current_step / self.total_steps, 1.0)
        else:
            progress_percentage = 1.0
        # Update the UI progress bar, if one was supplied.
        if self.progress is not None:
            self.progress(
                progress_percentage,
                desc=f"ControlNet: Schritt {self.current_step}/{self.total_steps}",
            )
        print(f"ControlNet Fortschritt: {self.current_step}/{self.total_steps} ({progress_percentage:.1%})")
        return callback_kwargs
class ControlNetProcessor:
    """Lazily loads ControlNet models/pipelines and generates images with them.

    Every heavy resource (OpenPose detector, ControlNet weights, Stable
    Diffusion pipelines) is created on first use and cached on the instance,
    so later calls reuse what is already loaded.

    Args:
        device: Torch device the pipelines are moved to (e.g. ``"cuda"``).
        torch_dtype: dtype used when loading ControlNet and pipeline weights.
    """

    # Base Stable Diffusion checkpoint used by every pipeline.
    BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"

    # Hub repo of each single ControlNet. The key also names the instance
    # attribute ``controlnet_<key>`` that caches the loaded model.
    CONTROLNET_MODEL_IDS = {
        "openpose": "lllyasviel/sd-controlnet-openpose",
        "canny": "lllyasviel/sd-controlnet-canny",
        "depth": "lllyasviel/sd-controlnet-depth",
    }

    # Pipeline type -> ControlNets (in conditioning-image order), the attribute
    # caching the pipeline, and the log messages for loading it.
    _PIPELINE_SPECS = {
        "openpose": {
            "controlnets": ("openpose",),
            "attr": "pipe_openpose",
            "loading": "Loading OpenPose ControlNet pipeline...",
            "loaded": "✅ OpenPose ControlNet pipeline loaded successfully!",
            "error": "Fehler beim Laden von OpenPose ControlNet",
        },
        "canny": {
            "controlnets": ("canny",),
            "attr": "pipe_canny",
            "loading": "Loading Canny ControlNet pipeline...",
            "loaded": "✅ Canny ControlNet pipeline loaded successfully!",
            "error": "Fehler beim Laden von Canny ControlNet",
        },
        "depth": {
            "controlnets": ("depth",),
            "attr": "pipe_depth",
            "loading": "Loading Depth ControlNet pipeline...",
            "loaded": "✅ Depth ControlNet pipeline loaded successfully!",
            "error": "Fehler beim Laden von Depth ControlNet",
        },
        # OpenPose + Canny for the inside of the box.
        "multi_inside": {
            "controlnets": ("openpose", "canny"),
            "attr": "pipe_multi_inside",
            "loading": "Loading Multi-ControlNet pipeline für Inside-Box...",
            "loaded": "✅ Multi-ControlNet (Inside) pipeline loaded successfully!",
            "error": "Fehler beim Laden von Multi-ControlNet Inside",
        },
        # Depth + Canny for the outside of the box.
        "multi_outside": {
            "controlnets": ("depth", "canny"),
            "attr": "pipe_multi_outside",
            "loading": "Loading Multi-ControlNet pipeline für Outside-Box...",
            "loaded": "✅ Multi-ControlNet (Outside) pipeline loaded successfully!",
            "error": "Fehler beim Laden von Multi-ControlNet Outside",
        },
    }

    # Fractions of ``controlnet_strength`` given to each ControlNet.
    OUTSIDE_WEIGHTS = (0.6, 0.4)  # Depth (spatial depth), Canny (structures)
    INSIDE_WEIGHTS = (0.7, 0.3)  # OpenPose (person), Canny (contours)

    def __init__(self, device="cuda", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype
        self.pose_detector = None
        self.controlnet_openpose = None
        self.controlnet_canny = None
        self.controlnet_depth = None
        self.pipe_openpose = None
        self.pipe_canny = None
        self.pipe_depth = None
        self.pipe_multi_inside = None  # OpenPose + Canny for the inside box
        self.pipe_multi_outside = None  # Depth + Canny for the outside box

    @staticmethod
    def _fallback_image(image, size=512):
        """Return ``image`` as a ``size`` x ``size`` RGB image (used when extraction fails)."""
        return image.convert("RGB").resize((size, size))

    def load_pose_detector(self):
        """Load (once) and return the OpenPose detector, or ``None`` if loading failed.

        A failed load is only logged; the next call tries again.
        """
        if self.pose_detector is None:
            print("Loading Pose Detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
            except Exception as e:
                print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
        return self.pose_detector

    def extract_pose_simple(self, image):
        """Approximate a pose map with Canny edges when OpenPose is unavailable.

        Returns the edge map as an RGB image, or a 512x512 RGB copy of
        ``image`` if edge detection fails.
        """
        try:
            edges = cv2.Canny(np.array(image.convert("RGB")), 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
        except Exception as e:
            print(f"Fehler bei einfacher Pose-Extraktion: {e}")
            return self._fallback_image(image)
        print("⚠️ Verwende Kanten-basierte Pose-Approximation")
        return pose_image

    def extract_pose(self, image):
        """Extract an OpenPose map (including hands and face).

        Falls back to ``extract_pose_simple`` if the detector is missing or
        raises an error.
        """
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)
            return detector(image, hand_and_face=True)
        except Exception as e:
            print(f"Fehler bei Pose-Extraktion: {e}")
            return self.extract_pose_simple(image)

    def extract_canny_edges(self, image):
        """Return a 3-channel Canny edge map that preserves the scene structure.

        Falls back to a 512x512 RGB copy of ``image`` on error.
        """
        try:
            gray = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            # The ControlNet input is RGB, so replicate the single edge channel.
            edges_image = Image.fromarray(cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB))
        except Exception as e:
            print(f"Fehler bei Canny Edge Extraction: {e}")
            return self._fallback_image(image)
        print("✅ Canny Edge Map erstellt")
        return edges_image

    def extract_depth_map(self, image):
        """Return a pseudo depth map for spatial consistency.

        This is not a real depth estimate. It is the blurred grayscale image,
        treating bright areas as near and dark areas as far. A real depth map
        would need a depth-estimation model. Falls back to a 512x512 RGB copy
        of ``image`` on error.
        """
        try:
            gray = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2GRAY)
            depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
            depth_image = Image.fromarray(cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB))
        except Exception as e:
            print(f"Fehler bei Depth Map Extraction: {e}")
            return self._fallback_image(image)
        print("✅ Depth Map erstellt (Grayscale Approximation)")
        return depth_image

    def _get_controlnet(self, name):
        """Return the cached ControlNet model ``name``, loading it on first use."""
        attr = f"controlnet_{name}"
        model = getattr(self, attr)
        if model is None:
            model = ControlNetModel.from_pretrained(
                self.CONTROLNET_MODEL_IDS[name],
                torch_dtype=self.torch_dtype,
            )
            setattr(self, attr, model)
        return model

    def _build_pipeline(self, controlnet):
        """Create an SD 1.5 ControlNet pipeline on ``self.device``.

        The pipeline uses an Euler-Ancestral scheduler and attention slicing.
        ``controlnet`` is a single model or a list of models (multi-ControlNet).
        """
        from diffusers import EulerAncestralDiscreteScheduler

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.BASE_MODEL_ID,
            controlnet=controlnet,
            torch_dtype=self.torch_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.enable_attention_slicing()
        return pipe

    def load_controlnet_pipeline(self, controlnet_type="openpose"):
        """Return the (cached) pipeline for ``controlnet_type``.

        Args:
            controlnet_type: One of ``"openpose"``, ``"canny"``, ``"depth"``,
                ``"multi_inside"`` (OpenPose + Canny) or ``"multi_outside"``
                (Depth + Canny).

        Raises:
            ValueError: If ``controlnet_type`` is unknown.
            Exception: Any loading error is logged and re-raised.
        """
        spec = self._PIPELINE_SPECS.get(controlnet_type)
        if spec is None:
            raise ValueError(
                f"Unbekannter ControlNet-Typ: {controlnet_type!r} "
                f"(erlaubt: {', '.join(self._PIPELINE_SPECS)})"
            )
        cached = getattr(self, spec["attr"])
        if cached is not None:
            return cached

        print(spec["loading"])
        try:
            # ControlNet models are shared between pipelines, so each one is
            # downloaded and loaded at most once per processor.
            models = [self._get_controlnet(name) for name in spec["controlnets"]]
            # A single model is passed directly. A list makes diffusers build a
            # multi-ControlNet that expects one conditioning image per model.
            controlnet = models[0] if len(models) == 1 else models
            pipe = self._build_pipeline(controlnet)
        except Exception as e:
            print(f"{spec['error']}: {e}")
            raise
        # Cache only a fully configured pipeline.
        setattr(self, spec["attr"], pipe)
        print(spec["loaded"])
        return pipe

    def _prepare_conditioning(self, image, keep_environment, controlnet_strength):
        """Build the conditioning images, pipeline type and per-ControlNet scales."""
        if keep_environment:
            # Modes "Umgebung ändern" / "Ausschließlich Gesicht" -> Depth + Canny.
            print("🎯 ControlNet: Depth + Canny (keep_environment=True)")
            conditioning_images = [self.extract_depth_map(image), self.extract_canny_edges(image)]
            print("✅ Depth + Canny Maps für Outside/Inside-Box erstellt")
            controlnet_type = "multi_outside"
            weights = self.OUTSIDE_WEIGHTS
        else:
            # Mode "Focus verändern" -> OpenPose + Canny.
            print("🎯 ControlNet: OpenPose + Canny (keep_environment=False)")
            conditioning_images = [self.extract_pose(image), self.extract_canny_edges(image)]
            print("✅ OpenPose + Canny Maps für Inside-Box erstellt")
            controlnet_type = "multi_inside"
            weights = self.INSIDE_WEIGHTS
        scales = [controlnet_strength * weight for weight in weights]
        return conditioning_images, controlnet_type, scales

    def generate_with_controlnet(
        self, image, prompt, negative_prompt,
        steps, guidance_scale, controlnet_strength,
        progress=None, keep_environment=False
    ):
        """Generate a 512x512 image with a multi-ControlNet pipeline.

        Called from app.py. ``keep_environment`` selects the conditioning:
        - True ("Umgebung ändern" and "Ausschließlich Gesicht"): Depth + Canny
        - False ("Focus verändern"): OpenPose + Canny
        The actual mask logic lives in app.py (``create_face_mask``).

        Returns:
            ``(generated_image, image)``. The original image is returned for
            the following inpaint step. On any error, both elements are a
            512x512 RGB copy of ``image``.
        """
        try:
            conditioning_images, controlnet_type, conditioning_scale = self._prepare_conditioning(
                image, keep_environment, controlnet_strength
            )

            # A fresh random seed per call, logged for reproducibility.
            seed = random.randint(0, 2**32 - 1)
            generator = torch.Generator(device=self.device).manual_seed(seed)
            print(f"ControlNet Seed: {seed}")

            pipe = self.load_controlnet_pipeline(controlnet_type)
            num_steps = int(steps)
            callback = ControlNetProgressCallback(progress, num_steps) if progress is not None else None

            print("🔄 ControlNet: Starte Pipeline...")
            result = pipe(
                prompt=prompt,
                image=conditioning_images,
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                controlnet_conditioning_scale=conditioning_scale,
                height=512,
                width=512,
                output_type="pil",
                callback_on_step_end=callback,
                # The callback reads no tensors, so none are requested.
                callback_on_step_end_tensor_inputs=[],
            )
            print("✅ ControlNet abgeschlossen!")
            return result.images[0], image
        except Exception as e:
            print(f"❌ Fehler in ControlNet: {e}")
            import traceback
            traceback.print_exc()
            error_image = self._fallback_image(image)
            return error_image, error_image

    def prepare_inpaint_input(self, image, keep_environment=False):
        """Prepare the input image for inpainting.

        Not used directly by app.py; the equivalent logic lives in
        ``generate_with_controlnet``.

        Returns:
            ``(image_for_inpaint, conditioning_info)``. When
            ``keep_environment`` is set, this is a 50/50 blend of the depth and
            Canny maps. Otherwise it is the original image.
        """
        if not keep_environment:
            # Changing the inside of the box: pass the original image through.
            print("🎯 Inpaint: Übergebe Originalbild (Inside-Box ändern)")
            return image, {"type": "original", "image": image}

        # Changing the outside of the box: pass Depth + Canny info for the environment.
        print("🎯 Inpaint: Übergebe Depth+Canny Info (Outside-Box ändern)")
        depth_rgb = self.extract_depth_map(image).convert("RGB")
        canny_rgb = self.extract_canny_edges(image).convert("RGB")
        # Image.blend requires equal sizes. If only one extraction fell back to
        # its 512x512 image, the sizes differ, so align Canny to the depth map.
        if canny_rgb.size != depth_rgb.size:
            canny_rgb = canny_rgb.resize(depth_rgb.size)
        combined_map = Image.blend(depth_rgb, canny_rgb, alpha=0.5)
        return combined_map, {"type": "depth_canny", "image": combined_map}
# Module-level shared processor instance. Uses the GPU with half precision
# when CUDA is available, otherwise the CPU with full precision.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)