# Stable-ControlNet-GPU / controlnet_module.py
# Last commit: f82ebb1 by Astridkraft ("Update controlnet_module.py"), 6.72 kB
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import OpenposeDetector
from PIL import Image
import random
import cv2
import numpy as np
import gradio as gr
class ControlNetProgressCallback:
    """Per-step hook intended for diffusers' ``callback_on_step_end``.

    After each denoising step it reports the completed fraction to an optional
    progress reporter, called as ``progress(fraction, desc=...)``, and always
    prints a log line to stdout.
    """

    def __init__(self, progress, total_steps):
        """
        Args:
            progress: Callable accepting ``(fraction, desc=...)``, or ``None``
                to only print to stdout.
            total_steps: Step count used as the denominator of the fraction.
        """
        self.progress = progress
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        """Mark the 0-based step ``step_index`` as finished and report it.

        ``pipe`` and ``timestep`` are accepted for signature compatibility but
        unused. Returns ``callback_kwargs`` unchanged (the same object).
        """
        done = step_index + 1
        self.current_step = done
        fraction = done / self.total_steps
        label = f"ControlNet: Schritt {done}/{self.total_steps}"
        if self.progress is not None:
            self.progress(fraction, desc=label)
        print(f"ControlNet Fortschritt: {done}/{self.total_steps} ({fraction:.1%})")
        return callback_kwargs
class ControlNetProcessor:
    """Pose-conditioned image generation with Stable Diffusion + ControlNet.

    The OpenPose detector and the ControlNet pipeline are loaded lazily on
    first use and cached on the instance.
    """

    # Model repositories and working resolution. They are class attributes so
    # a subclass or instance can override them without editing the methods.
    # NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo has
    # reportedly been taken down; if loading fails, a mirror such as
    # "stable-diffusion-v1-5/stable-diffusion-v1-5" may be needed — confirm.
    POSE_DETECTOR_REPO = "lllyasviel/ControlNet"
    CONTROLNET_REPO = "lllyasviel/sd-controlnet-openpose"
    BASE_MODEL_REPO = "runwayml/stable-diffusion-v1-5"
    RESOLUTION = 512  # square output / conditioning size in pixels

    def __init__(self, device="cuda", torch_dtype=torch.float32):
        """Store device and dtype; no model is loaded here.

        Args:
            device: Torch device string used for the pipeline and the RNG.
            torch_dtype: dtype passed to ``from_pretrained`` for the models.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.pose_detector = None  # OpenposeDetector, set by load_pose_detector
        self.controlnet = None  # ControlNetModel, set by load_controlnet_pipeline
        self.pipe = None  # pipeline, set by load_controlnet_pipeline

    def load_pose_detector(self):
        """Load (once) and return the OpenPose detector.

        Returns ``None`` when loading fails; the failure is printed, and the
        load is attempted again on the next call.
        """
        if self.pose_detector is None:
            print("Loading Pose Detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained(
                    self.POSE_DETECTOR_REPO,
                )
            except Exception as e:
                print(f"Warnung: Pose-Detector konnte nicht geladen werden: {e}")
        return self.pose_detector

    def extract_pose_simple(self, image):
        """Fallback conditioning map: a Canny edge map of *image* as RGB.

        If even that fails, returns *image* converted to RGB and resized to
        ``RESOLUTION`` x ``RESOLUTION``.
        """
        try:
            img_array = np.array(image.convert("RGB"))
            edges = cv2.Canny(img_array, 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
            print("⚠️ Verwende Kanten-basierte Pose-Approximation")
            return pose_image
        except Exception as e:
            print(f"Fehler bei einfacher Pose-Extraktion: {e}")
            return image.convert("RGB").resize((self.RESOLUTION, self.RESOLUTION))

    def extract_pose(self, image):
        """Return an OpenPose map for *image*.

        Falls back to :meth:`extract_pose_simple` when the detector is
        unavailable or raises.
        """
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)
            # OpenposeDetector is used through __call__; it has no ``detect``
            # method, so calling ``detector.detect(...)`` always failed and
            # silently fell back to Canny edges.
            return detector(
                image,
                hand_and_face=True,
                detect_resolution=self.RESOLUTION,
            )
        except Exception as e:
            print(f"Fehler bei Pose-Extraktion: {e}")
            return self.extract_pose_simple(image)

    def generate_with_controlnet(self, image, prompt, negative_prompt,
                                 steps, guidance_scale, controlnet_strength, progress=None):
        """Generate an image conditioned on the pose extracted from *image*.

        Args:
            image: Input PIL image to extract the pose from.
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            steps: Number of inference steps (passed through ``int``).
            guidance_scale: Classifier-free guidance scale.
            controlnet_strength: Value for ``controlnet_conditioning_scale``.
            progress: Optional callable ``progress(fraction, desc=...)`` such
                as a Gradio progress tracker, or ``None``.

        Returns:
            The first generated image. On any error, the traceback is printed
            and *image* converted to RGB and resized to ``RESOLUTION`` square
            is returned instead (best effort, never raises for pipeline errors).
        """
        try:
            pipe = self.load_controlnet_pipeline()

            print("🔄 ControlNet: Extrahiere Pose...")
            # Explicit None check: truthiness would call __len__ on trackers
            # that define it (gradio's gr.Progress does), which can raise.
            if progress is not None:
                progress(0.05, desc="ControlNet: Extrahiere Pose...")
            pose_map = self.extract_pose(image)

            # New random seed for every call; printed for reference.
            seed = random.randint(0, 2**32 - 1)
            generator = torch.Generator(device=self.device).manual_seed(seed)
            print(f"ControlNet Seed: {seed}")

            callback = None
            if progress is not None:
                callback = ControlNetProgressCallback(progress, int(steps))

            print("🔄 ControlNet: Wende Pose-Kontrolle an...")
            result = pipe(
                prompt=prompt,
                image=pose_map,
                negative_prompt=negative_prompt,
                num_inference_steps=int(steps),
                guidance_scale=guidance_scale,
                generator=generator,
                controlnet_conditioning_scale=controlnet_strength,
                height=self.RESOLUTION,
                width=self.RESOLUTION,
                output_type="pil",
                callback_on_step_end=callback,
                # The callback only reports progress and needs no tensors.
                callback_on_step_end_tensor_inputs=[],
            )

            self._log_actual_steps(pipe, steps)
            print("✅ ControlNet abgeschlossen!")
            return result.images[0]
        except Exception as e:
            print(f"❌ Fehler in ControlNet: {e}")
            import traceback
            traceback.print_exc()
            return image.convert("RGB").resize((self.RESOLUTION, self.RESOLUTION))

    @staticmethod
    def _log_actual_steps(pipe, steps):
        """Debug print of the scheduler's timestep count vs. requested steps."""
        try:
            scheduler = pipe.scheduler
            if hasattr(scheduler, 'timesteps'):
                actual_steps = len(scheduler.timesteps)
                print(f"🎯 CONTROLNET TATSÄCHLICHE STEPS: {actual_steps} (von {steps} angefordert)")
        except Exception as e:
            print(f"⚠️ Konnte ControlNet Scheduler-Info nicht auslesen: {e}")

    def load_controlnet_pipeline(self):
        """Load (once) and return the Stable Diffusion ControlNet pipeline.

        The pipeline is moved to ``self.device``, its scheduler is replaced by
        ``DPMSolverMultistepScheduler`` built from the original config, and
        attention slicing is enabled. Loading errors are printed and re-raised.
        """
        if self.pipe is None:
            print("Loading ControlNet pipeline...")
            try:
                self.controlnet = ControlNetModel.from_pretrained(
                    self.CONTROLNET_REPO,
                    torch_dtype=self.torch_dtype
                )
                self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
                    self.BASE_MODEL_REPO,
                    controlnet=self.controlnet,
                    torch_dtype=self.torch_dtype,
                    safety_checker=None,
                    requires_safety_checker=False
                ).to(self.device)
                from diffusers import DPMSolverMultistepScheduler
                self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                    self.pipe.scheduler.config
                )
                # Lower peak memory at some speed cost.
                self.pipe.enable_attention_slicing()
                print("ControlNet pipeline loaded successfully!")
            except Exception as e:
                print(f"Fehler beim Laden von ControlNet: {e}")
                raise
        return self.pipe
# Shared module-level instance: CUDA with float16 when a GPU is available,
# otherwise CPU with float32.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)