# Stable-ControlNet-GPU / controlnet_module.py
# Author: Astridkraft
# Last commit: 8559945 "Update controlnet_module.py" (verified)
# File size: 8.87 kB
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetMode
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from controlnet_aux import OpenposeDetector
from PIL import Image
import random
import cv2
import numpy as np
import gradio as gr
class ControlNetProgressCallback:
    """Per-step progress reporter for a diffusers ControlNet pipeline.

    The call signature ``(pipe, step_index, timestep, callback_kwargs)`` matches
    the diffusers ``callback_on_step_end`` hook. Each call stores the number of
    completed steps, forwards the fraction done to the optional ``progress``
    callable, prints a status line and returns ``callback_kwargs`` unchanged.
    """

    def __init__(self, progress, total_steps):
        """
        Args:
            progress: Optional callable invoked as ``progress(fraction, desc=...)``
                (for example a Gradio progress tracker), or ``None`` to only print.
            total_steps: Expected number of denoising steps; used as the
                denominator of the reported fraction.
        """
        self.progress = progress
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        """Record that 0-based step ``step_index`` finished and report progress.

        Args:
            pipe: The calling pipeline (unused).
            step_index: 0-based index of the step that just finished.
            timestep: Scheduler timestep of that step (unused).
            callback_kwargs: Dict handed in by the pipeline.

        Returns:
            ``callback_kwargs`` unchanged.
        """
        self.current_step = step_index + 1
        if self.total_steps > 0:
            # Clamp so the reported fraction stays within [0, 1] even if more
            # steps run than were announced.
            progress_percentage = max(0.0, min(1.0, self.current_step / self.total_steps))
        else:
            # A zero/negative total would otherwise raise ZeroDivisionError;
            # treat it as "already complete".
            progress_percentage = 1.0
        # Update the external progress reporter, if one was supplied.
        if self.progress is not None:
            self.progress(progress_percentage, desc=f"ControlNet: Schritt {self.current_step}/{self.total_steps}")
        print(f"ControlNet Fortschritt: {self.current_step}/{self.total_steps} ({progress_percentage:.1%})")
        return callback_kwargs
class ControlNetProcessor:
    """Create ControlNet conditioning maps (pose, Canny edges, depth) from PIL images.

    The OpenPose detector and the MiDaS depth model are loaded lazily on first
    use. Each extractor catches its own errors and falls back to a simpler
    method, so it returns an image instead of raising. The one exception is
    when the final fallback on the input image itself fails.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float32):
        """Store configuration; no models are loaded here.

        Args:
            device: Torch device string that the MiDaS model and its input are
                moved to, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.
            torch_dtype: Stored on the instance; not used by the methods of
                this class.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.pose_detector = None
        # ControlNet models and pipelines. Nothing in this class populates
        # them; they stay None unless they are set from outside.
        self.controlnet_openpose = None
        self.controlnet_canny = None
        self.controlnet_depth = None
        self.pipe_openpose = None
        self.pipe_canny = None
        self.pipe_depth = None
        self.pipe_multi_inside = None  # OpenPose + Canny for the inside-box case
        self.pipe_multi_outside = None  # Depth + Canny for the outside-box case
        # Lazily populated by _load_midas() on the first extract_depth_map() call.
        self.midas_model = None
        self.midas_transform = None

    def load_pose_detector(self):
        """Lazily load the OpenPose detector and return it.

        Returns:
            The cached ``OpenposeDetector``, or ``None`` if loading failed.
            A failed load is attempted again on the next call.
        """
        if self.pose_detector is None:
            print("Loading Pose Detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
            except Exception as e:
                print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
        return self.pose_detector

    def extract_pose_simple(self, image):
        """Approximate a pose map with a Canny edge map (no pose model needed).

        Returns:
            RGB ``PIL.Image`` of the input's edges, or a 512x512 RGB copy of the
            input if edge detection fails.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            # Canny (low/high thresholds 100/200) runs on the 3-channel array
            # directly here; extract_canny_edges converts to grayscale first.
            edges = cv2.Canny(img_array, 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
            print("⚠️ Verwende Kanten-basierte Pose-Approximation")
            return pose_image
        except Exception as e:
            print(f"Fehler bei einfacher Pose-Extraktion: {e}")
            return image.convert("RGB").resize((512, 512))

    def extract_pose(self, image):
        """Extract an OpenPose map (hand and face keypoints requested).

        Falls back to extract_pose_simple() when the detector is unavailable or
        raises.
        """
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)
            pose_image = detector(image, hand_and_face=True)
            return pose_image
        except Exception as e:
            print(f"Fehler bei Pose-Extraktion: {e}")
            return self.extract_pose_simple(image)

    def extract_canny_edges(self, image):
        """Extract a Canny edge map used to preserve the surrounding structure.

        Returns:
            RGB ``PIL.Image`` of the grayscale Canny edges (thresholds 100/200),
            or a 512x512 RGB copy of the input if extraction fails.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            # Canny edge detection on the grayscale image.
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            # Expand back to a 3-channel image.
            edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
            edges_image = Image.fromarray(edges_rgb)
            print("✅ Canny Edge Map erstellt")
            return edges_image
        except Exception as e:
            print(f"Fehler bei Canny Edge Extraction: {e}")
            return image.convert("RGB").resize((512, 512))

    def _load_midas(self):
        """Import and initialise the MiDaS model and its input transform.

        Both attributes are assigned only after every step succeeded. A failed
        load therefore leaves the processor in its "not loaded" state, and the
        load is attempted again on the next call.
        """
        from torchvision.transforms import Compose, Resize, ToTensor, Normalize
        # NOTE(review): the `midas` module API used here (`midas.MiDaS`,
        # `midas.utils.interpolation`) is not verified. If the import or the
        # constructor fails, extract_depth_map falls back to the blur filter.
        import midas

        transform = Compose([
            Resize(384, interpolation=midas.utils.interpolation),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean/std
        ])
        # Intended to load the small MiDaS variant (lower VRAM) — TODO confirm
        # against the installed `midas` package.
        model = midas.MiDaS()
        model.eval()
        # Always move to the configured device so "cuda:N" strings also work;
        # moving to "cpu" is a no-op for a freshly built model.
        model.to(self.device)
        self.midas_transform = transform
        self.midas_model = model
        if str(self.device).startswith("cuda"):
            print("✅ MiDaS Small Modell geladen (GPU)")
        else:
            print("✅ MiDaS Small Modell geladen (CPU)")

    @staticmethod
    def _depth_to_uint8(depth_np):
        """Min-max rescale a float depth array to uint8 values in [0, 255].

        A constant (flat) array maps to all zeros. Scaling the raw values by 255
        instead would wrap around in the uint8 cast.
        """
        depth_min, depth_max = depth_np.min(), depth_np.max()
        if depth_max > depth_min:
            scaled = (depth_np - depth_min) / (depth_max - depth_min)
        else:
            scaled = np.zeros_like(depth_np, dtype=np.float64)
        return (scaled * 255).astype(np.uint8)

    def extract_depth_map(self, image):
        """Extract a depth map with MiDaS, falling back to a blurred grayscale image.

        Returns:
            An RGB ``PIL.Image``:
            * MiDaS path: predicted depth, resized to the input size and min-max
              rescaled to 0-255.
            * Fallback: Gaussian-blurred grayscale of the input. This is a
              brightness heuristic, not a real depth estimate.
            * If the fallback also fails: a 512x512 RGB copy of the input.
        """
        try:
            print("🔄 Versuche MiDaS Small für Depth Map...")
            # Load MiDaS only on first use so it holds no memory otherwise.
            if getattr(self, "midas_model", None) is None:
                self._load_midas()
            # The transform's Normalize has 3 channel statistics, so feed RGB
            # (an RGBA or L input would otherwise fail here).
            rgb_image = image.convert("RGB")
            img_input = self.midas_transform(rgb_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                prediction = self.midas_model(img_input)
                # assumes the model output is (batch, H, W) — TODO confirm
                prediction = torch.nn.functional.interpolate(
                    prediction.unsqueeze(1),
                    size=rgb_image.size[::-1],  # PIL gives (width, height); torch wants (height, width)
                    mode="bicubic",
                    align_corners=False,
                ).squeeze()
            depth_np = self._depth_to_uint8(prediction.cpu().numpy())
            depth_image = Image.fromarray(depth_np).convert("RGB")
            print("✅ MiDaS Depth Map erfolgreich erstellt")
            return depth_image
        except Exception as e:
            print(f"⚠️ MiDaS Fehler: {e}. Verwende Fallback (Grayscale Filter)...")
            try:
                img_array = np.array(image.convert("RGB"))
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
                # Blurred brightness as a rough depth proxy; the intent is
                # bright = near, dark = far.
                depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
                depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
                depth_image = Image.fromarray(depth_rgb)
                print("✅ Fallback Depth Map erstellt")
                return depth_image
            except Exception as fallback_error:
                print(f"❌ Auch Fallback fehlgeschlagen: {fallback_error}")
                return image.convert("RGB").resize((512, 512))

    def prepare_controlnet_maps(self, image, keep_environment=False):
        """Build the conditioning maps only; no image is generated.

        Args:
            image: Input ``PIL.Image``.
            keep_environment: If True, return ``[depth, canny]``; otherwise
                return ``[pose, canny]``.

        Returns:
            List of two conditioning images, in the order above.
        """
        print("🎯 ControlNet: Erstelle Conditioning-Maps...")
        if keep_environment:
            # Depth + Canny
            print(" Modus: Depth + Canny")
            conditioning_images = [
                self.extract_depth_map(image),
                self.extract_canny_edges(image),
            ]
        else:
            # OpenPose + Canny
            print(" Modus: OpenPose + Canny")
            conditioning_images = [
                self.extract_pose(image),
                self.extract_canny_edges(image),
            ]
        print(f"✅ {len(conditioning_images)} Conditioning-Maps erstellt.")
        return conditioning_images

    def prepare_inpaint_input(self, image, keep_environment=False):
        """Prepare the input image for inpainting.

        Note from the original author: app.py does not call this directly; the
        equivalent logic lives in generate_with_controlnet.

        Args:
            image: Input ``PIL.Image``.
            keep_environment: If True (change the outside of the box), return a
                50/50 blend of the depth and Canny maps. Otherwise (change the
                inside of the box), return ``image`` itself.

        Returns:
            ``(image_for_inpaint, info)``, where ``info`` is a dict with a
            ``"type"`` key (``"depth_canny"`` or ``"original"``) and an
            ``"image"`` key holding the same image.
        """
        if keep_environment:
            # Outside-box change: pass depth + Canny information about the surroundings.
            print("🎯 Inpaint: Übergebe Depth+Canny Info (Outside-Box ändern)")
            depth_image = self.extract_depth_map(image)
            canny_image = self.extract_canny_edges(image)
            # Image.blend needs matching size and mode. Both are RGB here, and the
            # sizes match unless one extractor hit its 512x512 fallback.
            combined_map = Image.blend(depth_image.convert("RGB"), canny_image.convert("RGB"), alpha=0.5)
            return combined_map, {"type": "depth_canny", "image": combined_map}
        # Inside-box change: hand the original image to inpainting.
        print("🎯 Inpaint: Übergebe Originalbild (Inside-Box ändern)")
        return image, {"type": "original", "image": image}
# Shared module-level processor: GPU with float16 when CUDA is available,
# otherwise CPU with float32.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)