import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import OpenposeDetector
from PIL import Image, ImageFilter
import random
import cv2
import numpy as np
import gradio as gr

from transformers import Sam2Model, Sam2Processor


class ControlNetProgressCallback:
    """Reports per-step diffusion progress (diffusers `callback_on_step_end` protocol)."""

    def __init__(self, progress, total_steps):
        self.progress = progress
        self.total_steps = total_steps
        self.current_step = 0

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        self.current_step = step_index + 1
        progress_percentage = self.current_step / self.total_steps

        if self.progress is not None:
            self.progress(progress_percentage, desc=f"ControlNet: step {self.current_step}/{self.total_steps}")

        print(f"ControlNet progress: {self.current_step}/{self.total_steps} ({progress_percentage:.1%})")
        return callback_kwargs
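
# Usage sketch (illustrative, not executed here): the callback matches the
# diffusers `callback_on_step_end` signature, so it can be passed straight to
# a pipeline call. `pipe` and the prompt below are assumptions, not part of
# this module:
#
#   callback = ControlNetProgressCallback(progress=None, total_steps=30)
#   result = pipe(
#       prompt="a portrait photo",
#       image=conditioning_images,
#       num_inference_steps=30,
#       callback_on_step_end=callback,
#   )
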
class ControlNetProcessor:
    """Builds ControlNet conditioning maps (pose, canny, depth) and SAM 2 masks."""

    def __init__(self, device="cuda", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype
        self.pose_detector = None
        self.midas_model = None
        self.midas_transform = None

        self.sam_processor = None
        self.sam_model = None
        self.sam_initialized = False

    def _lazy_load_sam(self):
        """Lazily loads SAM 2 via the 🤗 Transformers API."""
        if self.sam_initialized:
            return True

        try:
            print("🔄 Loading SAM 2 via 🤗 Transformers...")

            model_id = "facebook/sam2-hiera-tiny"

            self.sam_processor = Sam2Processor.from_pretrained(model_id)
            self.sam_model = Sam2Model.from_pretrained(model_id, torch_dtype=self.torch_dtype).to(self.device)
            self.sam_model.eval()

            self.sam_initialized = True
            print("✅ SAM 2 loaded successfully (via Transformers)")
            return True

        except Exception as e:
            print(f"❌ Failed to load SAM 2: {str(e)[:200]}")
            # Mark as initialized anyway so we don't retry on every call.
            self.sam_initialized = True
            return False

    def _validate_bbox(self, image, bbox_coords):
        """Validates and corrects bbox coordinates."""
        width, height = image.size

        if isinstance(bbox_coords, (list, tuple)) and len(bbox_coords) == 4:
            x1, y1, x2, y2 = bbox_coords
        else:
            # Invalid input: fall back to a centered default box.
            x1, y1, x2, y2 = width * 0.25, height * 0.25, width * 0.75, height * 0.75

        # Ensure x1 <= x2 and y1 <= y2.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        # Clamp to image bounds.
        x1 = max(0, min(x1, width - 1))
        y1 = max(0, min(y1, height - 1))
        x2 = max(0, min(x2, width - 1))
        y2 = max(0, min(y2, height - 1))

        # Degenerate box: fall back to a centered box covering ~30% of the image.
        if x2 - x1 < 10 or y2 - y1 < 10:
            size = min(width, height) * 0.3
            x1 = max(0, width / 2 - size / 2)
            y1 = max(0, height / 2 - size / 2)
            x2 = min(width, width / 2 + size / 2)
            y2 = min(height, height / 2 + size / 2)

        return int(x1), int(y1), int(x2), int(y2)
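
    # Example (hypothetical numbers): for a 512x512 image,
    #   _validate_bbox(img, (-20, 300, 600, 40))  ->  (0, 40, 511, 300)
    # i.e. the corners are reordered and clipped to the image bounds; boxes
    # smaller than 10 px fall back to the centered ~30% default box.
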
    def _smooth_mask(self, mask_array, blur_radius=3):
        """Smooths the mask for softer transitions."""
        try:
            if blur_radius > 0:
                # Median blur keeps edges crisp; the kernel size must be odd.
                mask_array = cv2.medianBlur(mask_array, blur_radius * 2 + 1)
            return mask_array
        except Exception as e:
            print(f"⚠️ Failed to smooth mask: {e}")
            return mask_array

    def create_sam_mask(self, image, bbox_coords, mode):
        """
        Creates a precise mask with SAM 2 (via the 🤗 Transformers API).
        Returns a PIL Image in L mode (0=black=keep, 255=white=change).
        """
        try:
            if not self.sam_initialized:
                self._lazy_load_sam()

            if self.sam_model is None or self.sam_processor is None:
                print("⚠️ SAM 2 model not available, using fallback")
                return self._create_rectangular_mask(image, bbox_coords, mode)

            x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)
            width, height = image.size

            image_np = np.array(image.convert("RGB"))

            # Nested list: one image, one box prompt.
            input_boxes = [[[x1, y1, x2, y2]]]

            inputs = self.sam_processor(
                image_np,
                input_boxes=input_boxes,
                return_tensors="pt"
            ).to(self.device)

            print(f"🎯 SAM 2: segmenting region {x1},{y1}-{x2},{y2}")
            with torch.no_grad():
                outputs = self.sam_model(**inputs)

            # Rescale the predicted masks back to the original resolution.
            mask = self.sam_processor.post_process_masks(
                outputs.pred_masks,
                inputs.original_sizes,
                inputs.reshaped_input_sizes
            )[0][0]

            # Convert to a binary 0/255 mask.
            mask = mask.sigmoid().cpu().numpy()
            mask_array = (mask > 0.5).astype(np.uint8) * 255

            mask = Image.fromarray(mask_array.squeeze()).convert("L")
            mask = mask.resize((width, height), Image.Resampling.NEAREST)

            # Slightly smooth the mask edges.
            mask_array = np.array(mask)
            mask_array = self._smooth_mask(mask_array, blur_radius=2)
            mask = Image.fromarray(mask_array).convert("L")

            if mode == "environment_change":
                # Invert: keep the object, repaint the surroundings.
                mask = Image.eval(mask, lambda x: 255 - x)
                print("   SAM mode: change environment (keep object)")
            else:
                print("   SAM mode: change focus/face (modify object)")

            print(f"✅ SAM 2: precise mask created ({mask.size})")
            return mask

        except Exception as e:
            print(f"⚠️ SAM 2 error (Transformers API): {str(e)[:200]}")
            import traceback
            traceback.print_exc()
            print("ℹ️ Falling back to rectangular mask")
            return self._create_rectangular_mask(image, bbox_coords, mode)
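
    # Usage sketch (illustrative; `inpaint_pipe` is an assumed, separately
    # loaded StableDiffusionInpaintPipeline, not part of this module). Any
    # mode other than "environment_change" leaves the box contents white,
    # i.e. the object itself gets repainted:
    #
    #   mask = controlnet_processor.create_sam_mask(image, (120, 80, 380, 420), "focus_change")
    #   result = inpaint_pipe(prompt="...", image=image, mask_image=mask).images[0]
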
    def _create_rectangular_mask(self, image, bbox_coords, mode):
        """Fallback: creates a rectangular mask."""
        from PIL import ImageDraw

        mask = Image.new("L", image.size, 0)

        if bbox_coords and all(coord is not None for coord in bbox_coords):
            x1, y1, x2, y2 = self._validate_bbox(image, bbox_coords)
            draw = ImageDraw.Draw(mask)

            if mode == "environment_change":
                # White everywhere except the box: repaint the surroundings.
                draw.rectangle([0, 0, image.size[0], image.size[1]], fill=255)
                draw.rectangle([x1, y1, x2, y2], fill=0)
                print("ℹ️ Rectangular mask: change environment")
            else:
                # White box: repaint the object itself.
                draw.rectangle([x1, y1, x2, y2], fill=255)
                print("ℹ️ Rectangular mask: change focus/face")

        return mask

    def load_pose_detector(self):
        """Loads only the pose detector."""
        if self.pose_detector is None:
            print("Loading Pose Detector...")
            try:
                self.pose_detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
                print("✅ Pose detector loaded")
            except Exception as e:
                print(f"⚠️ Pose detector could not be loaded: {e}")
        return self.pose_detector

    def load_midas_model(self):
        """Loads the MiDaS model for depth maps."""
        if self.midas_model is None:
            print("🔄 Loading MiDaS model for depth maps...")
            try:
                import torchvision.transforms as T

                self.midas_model = torch.hub.load(
                    "intel-isl/MiDaS",
                    "DPT_Hybrid",
                    trust_repo=True
                )

                self.midas_model.to(self.device)
                self.midas_model.eval()

                # Resize to a fixed 384x384 so both dimensions are multiples
                # of 32, as the DPT backbone requires (a bare Resize(384) only
                # fixes the shorter side). The prediction is interpolated back
                # to the original resolution afterwards.
                self.midas_transform = T.Compose([
                    T.Resize((384, 384)),
                    T.ToTensor(),
                    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ])

                print("✅ MiDaS model loaded successfully")
            except Exception as e:
                print(f"❌ MiDaS could not be loaded: {e}")
                print("ℹ️ Using fallback method")
                self.midas_model = None

        return self.midas_model
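
    # Note: the MiDaS hub repo also ships its own preprocessing pipeline,
    # which handles aspect-ratio-aware resizing to multiples of 32. A sketch
    # of that alternative (same torch.hub repo, "transforms" entry point; the
    # transform expects an RGB numpy array rather than a PIL image):
    #
    #   midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
    #   transform = midas_transforms.dpt_transform
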
    def extract_pose_simple(self, image):
        """Simple pose extraction without complex dependencies."""
        try:
            img_array = np.array(image.convert("RGB"))
            # Canny expects a single-channel 8-bit image.
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            pose_image = Image.fromarray(edges).convert("RGB")
            print("⚠️ Using edge-based pose approximation")
            return pose_image
        except Exception as e:
            print(f"Error in simple pose extraction: {e}")
            return image.convert("RGB").resize((512, 512))

    def extract_pose(self, image):
        """Extracts a pose map from the image, with fallback."""
        try:
            detector = self.load_pose_detector()
            if detector is None:
                return self.extract_pose_simple(image)

            pose_image = detector(image, hand_and_face=True)
            return pose_image
        except Exception as e:
            print(f"Error in pose extraction: {e}")
            return self.extract_pose_simple(image)

    def extract_canny_edges(self, image):
        """Extracts Canny edges for environment preservation."""
        try:
            img_array = np.array(image.convert("RGB"))

            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)

            edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
            edges_image = Image.fromarray(edges_rgb)

            print("✅ Canny edge map created")
            return edges_image
        except Exception as e:
            print(f"Error in Canny edge extraction: {e}")
            return image.convert("RGB").resize((512, 512))

    def extract_depth_map(self, image):
        """
        Extracts a depth map with MiDaS (falls back to a blur-based filter).
        """
        try:
            midas = self.load_midas_model()
            if midas is not None:
                print("🎯 Using MiDaS for depth map...")

                img_transformed = self.midas_transform(image.convert("RGB")).unsqueeze(0).to(self.device)

                with torch.no_grad():
                    prediction = midas(img_transformed)
                    # Interpolate back to the original (height, width).
                    prediction = torch.nn.functional.interpolate(
                        prediction.unsqueeze(1),
                        size=image.size[::-1],
                        mode="bicubic",
                        align_corners=False,
                    ).squeeze()

                # Normalize to 0-255 grayscale.
                depth_np = prediction.cpu().numpy()
                depth_min, depth_max = depth_np.min(), depth_np.max()

                if depth_max > depth_min:
                    depth_np = (depth_np - depth_min) / (depth_max - depth_min)

                depth_np = (depth_np * 255).astype(np.uint8)
                depth_image = Image.fromarray(depth_np).convert("RGB")

                print("✅ MiDaS depth map created successfully")
                return depth_image

            else:
                raise Exception("MiDaS not loaded")

        except Exception as e:
            print(f"⚠️ MiDaS error: {e}. Using fallback...")
            try:
                img_array = np.array(image.convert("RGB"))
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

                # Crude approximation: blurred luminance as pseudo-depth.
                depth_map = cv2.GaussianBlur(gray, (5, 5), 0)
                depth_rgb = cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB)
                depth_image = Image.fromarray(depth_rgb)

                print("✅ Fallback depth map created")
                return depth_image
            except Exception as fallback_error:
                print(f"❌ Fallback also failed: {fallback_error}")
                return image.convert("RGB").resize((512, 512))

    def prepare_controlnet_maps(self, image, keep_environment=False):
        """
        ONLY creates conditioning maps; does NOT generate an image.
        """
        print("🎯 ControlNet: creating conditioning maps...")

        if keep_environment:
            print("   Mode: Depth + Canny")
            conditioning_images = [
                self.extract_depth_map(image),
                self.extract_canny_edges(image)
            ]
        else:
            print("   Mode: OpenPose + Canny")
            conditioning_images = [
                self.extract_pose(image),
                self.extract_canny_edges(image)
            ]

        print(f"✅ Created {len(conditioning_images)} conditioning maps.")
        return conditioning_images


# Global device/dtype selection and a shared processor instance.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
controlnet_processor = ControlNetProcessor(device=device, torch_dtype=torch_dtype)
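
# End-to-end sketch (illustrative, commented out so importing this module has
# no side effects). The checkpoint names are common SD 1.5 ControlNet weights
# and are assumptions, not something this file prescribes; the maps returned
# by prepare_controlnet_maps line up with the list of ControlNets:
#
#   controlnets = [
#       ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype),
#       ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype),
#   ]
#   pipe = StableDiffusionControlNetPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", controlnet=controlnets, torch_dtype=torch_dtype
#   ).to(device)
#
#   maps = controlnet_processor.prepare_controlnet_maps(input_image, keep_environment=True)
#   result = pipe(
#       prompt="a photo of a person in a forest",
#       image=maps,
#       num_inference_steps=30,
#       callback_on_step_end=ControlNetProgressCallback(progress=None, total_steps=30),
#   ).images[0]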