Stable-ControlNet-GPU / controlnet_facefix.py
Astridkraft's picture
Update controlnet_facefix.py
5e5594f verified
raw
history blame
5.49 kB
# controlnet_facefix.py - KORRIGIERTE VERSION FÜR HF SPACES
import torch
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel
from PIL import Image
import time
# Import-time banner: marks in the console output that this module started loading.
print("=" * 60, "CONTROLNET_FACEFIX START", "=" * 60, sep="\n")
# Module-level caches, filled lazily on first use.
# IMPORTANT: nothing is moved to CUDA while loading; models stay on the CPU
# until inference explicitly moves the pipeline to the device.
_components_loaded = False  # True once all annotators/ControlNets are loaded
_depth_processor = _controlnet_depth = _controlnet_pose = None
_pipeline = None  # lazily built ControlNet inpaint pipeline
def _initialize_components():
    """Load the depth annotator and both ControlNet models, once.

    Each component is stored in a module-level global. Components that loaded
    successfully are kept even if a later one fails, so a retry only loads
    what is still missing. Models are loaded in float16 and are NOT moved to
    CUDA here; device placement is left to inference time.

    Returns:
        True if all components are available, False if any of them failed to
        load. Errors are printed, not raised.
    """
    global _components_loaded, _depth_processor, _controlnet_depth, _controlnet_pose
    if _components_loaded:
        return True

    if _depth_processor is None:
        try:
            print("1. Lade Depth Processor...")
            from controlnet_aux import ZoeDetector
            # The controlnet_aux annotator checkpoints (incl. ZoeD_M12_N.pt)
            # are hosted in "lllyasviel/Annotators", not "lllyasviel/ControlNet".
            _depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
            print(" ✅ Depth Processor OK")
        except Exception as e:
            print(f" ❌ Depth Processor Fehler: {e}")
            return False

    if _controlnet_depth is None:
        try:
            print("2. Lade ControlNet Depth...")
            # ControlNet 1.1 depth model id is "control_v11f1p_sd15_depth";
            # "f1e" is only used by the tile model.
            _controlnet_depth = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11f1p_sd15_depth",
                torch_dtype=torch.float16,  # no .to("cuda") here on purpose
            )
            print(" ✅ ControlNet Depth OK")
        except Exception as e:
            print(f" ❌ ControlNet Depth Fehler: {e}")
            return False

    if _controlnet_pose is None:
        try:
            print("3. Lade ControlNet OpenPose...")
            # The repo has no "faceonly" subfolder; load the root model.
            _controlnet_pose = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_openpose",
                torch_dtype=torch.float16,  # no .to("cuda") here on purpose
            )
            print(" ✅ ControlNet OpenPose OK")
        except Exception as e:
            print(f" ❌ ControlNet OpenPose Fehler: {e}")
            return False

    _components_loaded = True
    print("✅ ALLE KOMPONENTEN GELADEN")
    return True
_openpose_processor = None  # cached OpenposeDetector, loaded on first use


def _create_control_images(image):
    """Build the ControlNet conditioning images for ``image``.

    Uses the already-initialized module-level depth processor and a lazily
    loaded, cached OpenPose detector. The detector is no longer re-created on
    every call.

    Args:
        image: Input image passed to both annotators.

    Returns:
        A ``(pose_img, depth_img)`` tuple, or ``(None, None)`` if either
        annotator fails. The error is printed, not raised.
    """
    global _openpose_processor
    try:
        print(" Erstelle Depth Map...")
        depth_img = _depth_processor(image)
        print(" Erstelle OpenPose...")
        if _openpose_processor is None:
            from controlnet_aux import OpenposeDetector
            # Annotator weights are hosted in "lllyasviel/Annotators".
            _openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        pose_img = _openpose_processor(image)
        return pose_img, depth_img
    except Exception as e:
        print(f" ❌ Control Images Fehler: {e}")
        return None, None
_pipeline_model_id = None  # model_id the cached _pipeline was built from


def _load_pipeline(model_id):
    """Return the cached ControlNet inpaint pipeline for ``model_id``.

    The pipeline is built lazily in float16 with the OpenPose and depth
    ControlNets attached, and cached in the module-level ``_pipeline``. If
    ``model_id`` differs from the model the cache was built from, the cached
    pipeline is discarded and rebuilt. This keeps a caller from silently
    getting a different base model than the one it asked for.

    Args:
        model_id: Hugging Face repo id or local path of the base model.

    Returns:
        The pipeline, or ``None`` if loading failed. The error is printed.
    """
    global _pipeline, _pipeline_model_id
    if _pipeline is not None and _pipeline_model_id == model_id:
        return _pipeline
    # Drop this module's reference to a pipeline for another model before
    # loading the new one, so the module does not keep both alive.
    _pipeline = None
    _pipeline_model_id = None
    try:
        print("🔄 Lade Pipeline...")
        pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            model_id,
            controlnet=[_controlnet_pose, _controlnet_depth],
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False,
        )
        # Memory-saving options for the HF Spaces GPU.
        pipeline.enable_attention_slicing()
        pipeline.enable_vae_slicing()
    except Exception as e:
        print(f"❌ Pipeline Fehler: {e}")
        return None
    _pipeline = pipeline
    _pipeline_model_id = model_id
    print("✅ Pipeline geladen")
    return pipeline


def apply_facefix(image: Image.Image, prompt: str, negative_prompt: str, seed: int, model_id: str):
    """Refine faces in ``image`` with an OpenPose + depth ControlNet inpaint pass.

    The whole image is repainted (all-white mask) at ``strength=0.4``. The
    prompt is extended with face-quality terms, and OpenPose/depth maps of the
    input act as structural guidance.

    Args:
        image: Image to refine.
        prompt: Positive prompt; face-quality terms are appended.
        negative_prompt: Negative prompt; face-defect terms are appended.
        seed: Seed for the ``torch.Generator`` used during sampling.
        model_id: Base model repo id/path. Presumably an SD 1.5 checkpoint,
            since both ControlNets are sd15 models. Confirm against callers.

    Returns:
        The refined image (512x512). If any step fails, the unchanged input
        ``image`` is returned instead; failures are printed, never raised.
    """
    print("\n" + "🎭" * 50)
    print("FACE-FIX WIRD AUSGEFÜHRT")
    print(f" Model: {model_id}")
    print(f" Seed: {seed}")
    print("🎭" * 50)
    start_time = time.time()

    # 1. Load annotator and ControlNets (cached after the first success).
    if not _initialize_components():
        print("❌ Komponenten konnten nicht geladen werden")
        return image

    # 2. Derive pose and depth conditioning images from the input.
    pose_img, depth_img = _create_control_images(image)
    if pose_img is None or depth_img is None:
        print("❌ Control Images konnten nicht erstellt werden")
        return image

    # 3. Get (or build) the pipeline for the requested base model.
    pipeline = _load_pipeline(model_id)
    if pipeline is None:
        return image

    try:
        # 4. Move to the device only now, not at load time.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f" Device: {device}")
        pipeline = pipeline.to(device)

        # 5. Bias the prompts toward clean, detailed faces.
        face_prompt = f"{prompt}, perfect face, detailed skin, realistic eyes, sharp focus"
        face_negative = f"{negative_prompt}, deformed face, blurry face, bad anatomy, ugly"

        # Ensure 3-channel input (e.g. RGBA uploads) for the VAE encoder.
        init_image = image.convert("RGB")
        # The inpaint pipeline requires a mask; passing None fails in its mask
        # preprocessing. All-white (255) marks every pixel as repaintable,
        # i.e. the whole image area.
        full_mask = Image.new("L", init_image.size, 255)

        print("⚡ Führe Face-Fix Inference aus...")
        # 6. Run the face-fix pass.
        result = pipeline(
            prompt=face_prompt,
            negative_prompt=face_negative,
            image=init_image,
            mask_image=full_mask,
            control_image=[pose_img, depth_img],
            controlnet_conditioning_scale=[0.8, 0.6],  # OpenPose weighted above depth
            strength=0.4,  # moderate: keeps composition, refines detail
            num_inference_steps=20,
            guidance_scale=7.0,
            generator=torch.Generator(device).manual_seed(seed),
            # NOTE(review): output is always 512x512; non-square inputs are
            # presumably resized to fit — confirm this is intended.
            height=512,
            width=512,
        ).images[0]

        duration = time.time() - start_time
        print(f"\n✅✅✅ FACE-FIX ERFOLGREICH in {duration:.1f}s ✅✅✅")
        return result
    except Exception as e:
        print(f"\n❌❌❌ FACE-FIX FEHLGESCHLAGEN: {e} ❌❌❌")
        import traceback
        traceback.print_exc()
        return image
# Import-complete banner; the middle line reports whether apply_facefix is
# defined in this module's namespace at this point of the import.
print(
    "=" * 60,
    "CONTROLNET_FACEFIX INITIALISIERT",
    f"apply_facefix Funktion: {'apply_facefix' in globals()}",
    "=" * 60,
    sep="\n",
)