# pixagram-neo-backup / generator.py
# primerz's picture
# Update generator.py
# 050255c verified
# raw
# history blame
# 43.7 kB
"""
Generation logic for Pixagram AI Pixel Art Generator
"""
import gc
import torch
import numpy as np
import cv2
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
import traceback
from config import (
device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS,
ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER
)
from utils import (
sanitize_text, enhanced_color_match, color_match, create_face_mask,
draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop
)
from models import (
load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
load_sdxl_pipeline, load_loras, setup_ip_adapter,
# --- START FIX: Import setup_compel ---
setup_compel,
# --- END FIX ---
setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
load_openpose_detector, load_mediapipe_face_detector
)
class RetroArtConverter:
    """Main class for retro art generation.

    Constructing an instance loads every model through the ``models`` helpers
    (face analysis, depth / OpenPose detectors, ControlNets, the SDXL
    pipeline, LoRAs, IP-Adapter, Compel, caption model) and records what
    succeeded in ``self.models_loaded``.  ``generate_retro_art`` is the main
    entry point afterwards.

    The three ``*_active`` flags (``instantid_active``, ``depth_active``,
    ``openpose_active``) describe which ControlNets were handed to the
    pipeline.  Control images and scales are always appended in the order
    InstantID -> Depth -> OpenPose, matching the order used in ``__init__``.
    """

    def __init__(self):
        """Load all models and assemble the generation pipeline."""
        self.device = device
        self.dtype = dtype
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'depth_detector': False,
            'depth_type': None,
            'ip_adapter': False,
            'openpose': False,
            'mediapipe_face': False
        }
        self.loaded_loras = {}  # Store status of each LORA

        # Initialize face analysis (InsightFace)
        self.face_app, self.face_detection_enabled = load_face_analysis()

        # Load MediapipeFaceDetector (alternative face detection)
        self.mediapipe_face, mediapipe_success = load_mediapipe_face_detector()
        self.models_loaded['mediapipe_face'] = mediapipe_success

        # Load Depth detector with fallback hierarchy (Leres -> Zoe -> Midas)
        self.depth_detector, self.depth_type, depth_success = load_depth_detector()
        self.models_loaded['depth_detector'] = depth_success
        self.models_loaded['depth_type'] = self.depth_type

        # Load OpenPose detector
        self.openpose_detector, openpose_success = load_openpose_detector()
        self.models_loaded['openpose'] = openpose_success

        # Load ControlNets (3 models + InstantID success boolean)
        controlnet_depth, self.controlnet_instantid, self.controlnet_openpose, instantid_success = load_controlnets()
        self.controlnet_depth = controlnet_depth
        self.instantid_enabled = instantid_success
        self.models_loaded['instantid'] = instantid_success

        # The image encoder is only needed for InstantID
        if self.instantid_enabled:
            self.image_encoder = load_image_encoder()
        else:
            self.image_encoder = None

        # Record which ControlNets are actually usable
        self.instantid_active = self.instantid_enabled and self.controlnet_instantid is not None
        self.depth_active = self.controlnet_depth is not None
        self.openpose_active = self.controlnet_openpose is not None

        # Build the list of *active* controlnet models.  The order here
        # (InstantID, Depth, OpenPose) must match the order in which
        # generate_retro_art appends control images.
        controlnets = []
        if self.instantid_active:
            controlnets.append(self.controlnet_instantid)
            print(" [CN] InstantID (Identity) active")
        else:
            print(" [CN] InstantID (Identity) DISABLED")
        if self.depth_active:
            controlnets.append(self.controlnet_depth)
            print(" [CN] Depth active")
        else:
            print(" [CN] Depth DISABLED")
        if self.openpose_active:
            controlnets.append(self.controlnet_openpose)
            print(" [CN] OpenPose (Expression) active")
        else:
            print(" [CN] OpenPose (Expression) DISABLED")
        if not controlnets:
            print("[WARNING] No ControlNets loaded!")
        print(f"Initializing with {len(controlnets)} active ControlNet(s)")

        # Load SDXL pipeline with the filtered list (or None if empty)
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets if controlnets else None)
        self.models_loaded['custom_checkpoint'] = checkpoint_success

        # Load LORAs
        self.loaded_loras, lora_success = load_loras(self.pipe)
        self.models_loaded['lora'] = lora_success

        # Setup IP-Adapter (requires an active InstantID and an image encoder)
        if self.instantid_active and self.image_encoder is not None:
            self.image_proj_model, ip_adapter_success = setup_ip_adapter(self.pipe, self.image_encoder)
            self.models_loaded['ip_adapter'] = ip_adapter_success
        else:
            print("[INFO] Face preservation: IP-Adapter disabled (InstantID model failed or encoder failed)")
            self.models_loaded['ip_adapter'] = False
            self.image_proj_model = None

        # Setup Compel prompt encoder
        self.compel, self.use_compel = setup_compel(self.pipe)

        # Setup LCM scheduler
        setup_scheduler(self.pipe)

        # Optimize pipeline
        optimize_pipeline(self.pipe)

        # Load caption model
        self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()

        # Report caption model status
        if self.caption_enabled and self.caption_model is not None:
            if self.caption_model_type == "git":
                print(" [OK] Using GIT for detailed captions")
            elif self.caption_model_type == "blip":
                print(" [OK] Using BLIP for standard captions")
            else:
                print(" [OK] Caption model loaded")

        # Set CLIP skip
        set_clip_skip(self.pipe)

        # Track controlnet configuration.
        # NOTE(review): `controlnets` is always a list at this point, so this
        # flag is always True; kept unchanged in case other modules read it.
        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")

        # Print model status
        self._print_status()
        print(" [OK] Model initialization complete!")

    def _print_status(self):
        """Print model loading status and the IP-Adapter upgrade check."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            if model == 'lora':
                lora_status = 'DISABLED'
                if loaded:
                    loaded_count = sum(1 for status in self.loaded_loras.values() if status)
                    lora_status = f"[OK] LOADED ({loaded_count}/3)"
                print(f"loras: {lora_status}")
            else:
                status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
                print(f"{model}: {status}")
        print("===================\n")

        print("=== UPGRADE VERIFICATION ===")
        try:
            # Imported lazily so a missing module only skips this report
            from resampler import Resampler
            from attention_processor import IPAttnProcessor2_0
            resampler_check = isinstance(self.image_proj_model, Resampler) if hasattr(self, 'image_proj_model') and self.image_proj_model is not None else False
            custom_attn_check = any(isinstance(p, IPAttnProcessor2_0) for p in self.pipe.unet.attn_processors.values()) if hasattr(self, 'pipe') else False
            print(f"Enhanced Perceiver Resampler: {'[OK] ACTIVE' if resampler_check else '[INFO] Not active'}")
            print(f"Enhanced IP-Adapter Attention: {'[OK] ACTIVE' if custom_attn_check else '[INFO] Not active'}")
            if resampler_check and custom_attn_check:
                print("[SUCCESS] Face preservation upgrade fully active")
                print(" Expected improvement: +10-15% face similarity")
            elif resampler_check or custom_attn_check:
                print("[PARTIAL] Some upgrades active")
            else:
                print("[INFO] Using standard components")
        except Exception as e:
            print(f"[INFO] Verification skipped: {e}")
        print("============================\n")

    @staticmethod
    def _grayscale_depth_fallback(image):
        """Return a 3-channel grayscale copy of *image* as a stand-in depth map."""
        # COLOR_RGB2GRAY needs a 3-channel array, so normalise the mode first
        if image.mode != 'RGB':
            image = image.convert('RGB')
        gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
        depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(depth_colored)

    def get_depth_map(self, image):
        """
        Generate depth map using available depth detector.
        Supports: LeresDetector, ZoeDetector, or MidasDetector.

        The image is resized to multiples of 64 (minimum 64) for the
        detector and the result is resized back to the input size.  If no
        detector is loaded, or the detector raises, a grayscale version of
        the input is returned instead.

        Args:
            image: PIL image; converted to RGB if it is in another mode.

        Returns:
            PIL image with the same size as *image*.
        """
        # Convert before branching so the fallback path also gets RGB input
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if self.depth_detector is None:
            print("[DEPTH] No depth detector available, using grayscale fallback")
            return self._grayscale_depth_fallback(image)

        # str() keeps the log lines usable even if depth_type is None
        depth_label = str(self.depth_type).upper()
        try:
            orig_width, orig_height = image.size
            orig_width = int(orig_width)
            orig_height = int(orig_height)
            target_width = max(64, (orig_width // 64) * 64)
            target_height = max(64, (orig_height // 64) * 64)
            image_for_depth = image.resize((target_width, target_height), Image.LANCZOS)
            if target_width != orig_width or target_height != orig_height:
                print(f"[DEPTH] Resized for {depth_label}Detector: {orig_width}x{orig_height} -> {target_width}x{target_height}")

            with torch.no_grad():
                # Move model to GPU for inference; always move it back, even
                # when inference raises, so it does not stay resident on GPU.
                try:
                    self.depth_detector.to(self.device)
                    depth_image = self.depth_detector(image_for_depth)
                finally:
                    self.depth_detector.to("cpu")

            # Clear GPU cache after depth detection
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            depth_width, depth_height = depth_image.size
            if depth_width != orig_width or depth_height != orig_height:
                depth_image = depth_image.resize((orig_width, orig_height), Image.LANCZOS)
            print(f"[DEPTH] {depth_label} depth map generated: {orig_width}x{orig_height}")
            return depth_image
        except Exception as e:
            print(f"[DEPTH] {depth_label}Detector failed ({e}), falling back to grayscale depth")
            # Clear cache on error
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return self._grayscale_depth_fallback(image)

    def add_trigger_word(self, prompt, lora_choice="RetroArt"):
        """Add the LoRA trigger word to the prompt if not present.

        Args:
            prompt: user prompt; may be empty or None.
            lora_choice: key into ``TRIGGER_WORD``; unknown keys fall back to
                the "RetroArt" entry.

        Returns:
            The prompt with the trigger word prepended, the trigger word alone
            when the prompt is empty, or the prompt unchanged when the trigger
            is empty or already present (case-insensitive).
        """
        # Get the correct trigger word from the config dictionary
        trigger = TRIGGER_WORD.get(lora_choice, TRIGGER_WORD["RetroArt"])
        if not trigger:
            return prompt
        # Check emptiness first so a None prompt never reaches .lower()
        if not prompt or not prompt.strip():
            return trigger
        if trigger.lower() in prompt.lower():
            return prompt
        # Prepend the trigger word
        return f"{trigger}, {prompt}"

    def extract_multi_scale_face(self, face_crop, face):
        """
        Extract face features at multiple scales for better detail.

        For each factor in ``MULTI_SCALE_FACTORS`` the crop is resized by that
        factor and then back to its original size before being passed to
        ``self.face_app``.  The embeddings of the first face found at each
        scale are averaged and re-normalised.

        Args:
            face_crop: PIL image of the (padded) face region.
            face: the originally detected face; its ``normed_embedding`` is
                the fallback result.

        Returns:
            The averaged, unit-length embedding, or ``face.normed_embedding``
            when no scale yields a face or anything fails.
        """
        try:
            w, h = face_crop.size
            multi_scale_embeds = []
            for scale in MULTI_SCALE_FACTORS:
                # Clamp to 1px so a tiny crop does not abort every scale
                scaled_size = (max(1, int(w * scale)), max(1, int(h * scale)))
                scaled_crop = face_crop.resize(scaled_size, Image.LANCZOS)
                # Resize back to the original crop size
                scaled_crop = scaled_crop.resize((w, h), Image.LANCZOS)
                # Extract features
                scaled_array = cv2.cvtColor(np.array(scaled_crop), cv2.COLOR_RGB2BGR)
                scaled_faces = self.face_app.get(scaled_array)
                if len(scaled_faces) > 0:
                    multi_scale_embeds.append(scaled_faces[0].normed_embedding)

            if multi_scale_embeds:
                averaged = np.mean(multi_scale_embeds, axis=0)
                norm = np.linalg.norm(averaged)
                # A zero norm would turn the embedding into NaNs
                if norm > 0:
                    print(f"[MULTI-SCALE] Combined {len(multi_scale_embeds)} scales")
                    return averaged / norm
            return face.normed_embedding
        except Exception as e:
            print(f"[MULTI-SCALE] Failed: {e}, using single scale")
            return face.normed_embedding

    def detect_face_quality(self, face):
        """
        Detect face quality and pick adaptive parameter overrides.

        Checks, in order: bounding-box area against
        ``ADAPTIVE_THRESHOLDS['small_face_size']``, detection score against
        ``ADAPTIVE_THRESHOLDS['low_confidence']``, and (when ``face.pose`` is
        available) the absolute value of ``face.pose[1]`` against
        ``ADAPTIVE_THRESHOLDS['profile_angle']``.

        Returns:
            A copy of the matching ``ADAPTIVE_PARAMS`` entry, or None when no
            rule applies or the check fails.
        """
        try:
            bbox = face.bbox
            face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            det_score = float(face.det_score) if hasattr(face, 'det_score') else 1.0

            # Small face -> boost identity preservation
            if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
                return ADAPTIVE_PARAMS['small_face'].copy()

            # Low confidence -> boost preservation
            if det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
                return ADAPTIVE_PARAMS['low_confidence'].copy()

            # Check for profile/side view (if pose available)
            if hasattr(face, 'pose') and len(face.pose) > 1:
                try:
                    yaw = float(face.pose[1])
                    if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
                        return ADAPTIVE_PARAMS['profile_view'].copy()
                except (ValueError, TypeError, IndexError):
                    pass

            # Good quality face - use provided parameters
            return None
        except Exception as e:
            print(f"[ADAPTIVE] Quality detection failed: {e}")
            return None

    def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale,
                                       identity_preservation, identity_control_scale,
                                       depth_control_scale, consistency_mode=True,
                                       expression_control_scale=0.6):
        """
        Enhanced parameter validation with stricter rules for consistency.

        When *consistency_mode* is false the inputs are returned untouched.
        Otherwise five rules are applied in order (identity/LORA trade-off,
        strength-based profiles, CFG/LORA relationship, CFG clamp to
        [1.0, 1.5], and rescaling of the active ControlNet scales so their
        sum does not exceed 2.0) and every adjustment is printed.

        Returns:
            Tuple ``(strength, guidance_scale, lora_scale,
            identity_preservation, identity_control_scale,
            depth_control_scale, expression_control_scale)``.
        """
        if consistency_mode:
            print("[CONSISTENCY] Applying strict parameter validation...")
            adjustments = []

            # Rule 1: Strong inverse relationship between identity and LORA
            if identity_preservation > 1.2:
                original_lora = lora_scale
                lora_scale = min(lora_scale, 1.0)
                if abs(lora_scale - original_lora) > 0.01:
                    adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (high identity)")

            # Rule 2: Strength-based profile activation
            if strength < 0.5:
                # Maximum preservation mode
                if identity_preservation < 1.3:
                    original_identity = identity_preservation
                    identity_preservation = 1.3
                    adjustments.append(f"Identity: {original_identity:.2f}->{identity_preservation:.2f} (max preservation)")
                if lora_scale > 0.9:
                    original_lora = lora_scale
                    lora_scale = 0.9
                    adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (max preservation)")
                if guidance_scale > 1.3:
                    original_cfg = guidance_scale
                    guidance_scale = 1.3
                    adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (max preservation)")
            elif strength > 0.7:
                # Artistic transformation mode
                if identity_preservation > 1.0:
                    original_identity = identity_preservation
                    identity_preservation = 1.0
                    adjustments.append(f"Identity: {original_identity:.2f}->{identity_preservation:.2f} (artistic mode)")
                if lora_scale < 1.2:
                    original_lora = lora_scale
                    lora_scale = 1.2
                    adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (artistic mode)")

            # Rule 3: CFG-LORA relationship
            if guidance_scale > 1.4 and lora_scale > 1.2:
                original_lora = lora_scale
                lora_scale = 1.1
                adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (high CFG detected)")

            # Rule 4: LCM sweet spot enforcement
            original_cfg = guidance_scale
            guidance_scale = max(1.0, min(guidance_scale, 1.5))
            if abs(guidance_scale - original_cfg) > 0.01:
                adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (LCM optimal)")

            # Rule 5: ControlNet balance - only *active* controlnets count
            total_control = 0
            if self.instantid_active:
                total_control += identity_control_scale
            if self.depth_active:
                total_control += depth_control_scale
            if self.openpose_active:
                total_control += expression_control_scale

            if total_control > 2.0:
                scale_factor = 2.0 / total_control
                original_id_ctrl = identity_control_scale
                original_depth_ctrl = depth_control_scale
                original_expr_ctrl = expression_control_scale
                # Only scale active controlnets
                if self.instantid_active:
                    identity_control_scale *= scale_factor
                if self.depth_active:
                    depth_control_scale *= scale_factor
                if self.openpose_active:
                    expression_control_scale *= scale_factor
                adjustments.append(f"ControlNets balanced: ID {original_id_ctrl:.2f}->{identity_control_scale:.2f}, Depth {original_depth_ctrl:.2f}->{depth_control_scale:.2f}, Expr {original_expr_ctrl:.2f}->{expression_control_scale:.2f}")

            # Report adjustments
            if adjustments:
                print(" [OK] Applied adjustments:")
                for adj in adjustments:
                    print(f" - {adj}")
            else:
                print(" [OK] Parameters already optimal")

        return strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale, expression_control_scale

    def generate_caption(self, image, max_length=None, num_beams=None):
        """Generate a descriptive caption for the image (supports BLIP-2, GIT, BLIP).

        Args:
            image: image handed to the caption processor.
            max_length: generation length cap; defaults to 50 for "blip2",
                40 for "git", otherwise ``CAPTION_CONFIG['max_length']``.
            num_beams: beam count; defaults to ``CAPTION_CONFIG['num_beams']``.

        Returns:
            The stripped caption, or None when captioning is disabled or fails.
        """
        if not self.caption_enabled or self.caption_model is None:
            return None

        # Set defaults based on model type
        if max_length is None:
            if self.caption_model_type == "blip2":
                max_length = 50
            elif self.caption_model_type == "git":
                max_length = 40
            else:
                max_length = CAPTION_CONFIG['max_length']
        if num_beams is None:
            num_beams = CAPTION_CONFIG['num_beams']

        generation_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "early_stopping": True,
        }
        if self.caption_model_type in ("blip2", "git"):
            # Encourage longer, less repetitive captions on the larger models
            generation_kwargs.update(
                min_length=10,
                length_penalty=1.0,
                repetition_penalty=1.5,
            )

        try:
            # Move model to GPU for inference; `finally` moves it back
            self.caption_model.to(self.device)
            if self.caption_model_type == "git":
                # GIT takes pixel_values explicitly and decodes as a batch
                inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device, self.dtype)
                with torch.no_grad():
                    output = self.caption_model.generate(
                        pixel_values=inputs.pixel_values,
                        **generation_kwargs
                    )
                caption = self.caption_processor.batch_decode(output, skip_special_tokens=True)[0]
            else:
                # BLIP-2 and BLIP base share the same call shape
                inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
                with torch.no_grad():
                    output = self.caption_model.generate(**inputs, **generation_kwargs)
                caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
            return caption.strip()
        except Exception as e:
            print(f"Caption generation failed: {e}")
            return None
        finally:
            self.caption_model.to("cpu")

    def _generate_openpose_map(self, image, width, height):
        """Run the OpenPose detector on *image*, or return a black map on failure."""
        if self.openpose_detector is None:
            # The ControlNet can be loaded while the detector is not; the
            # pipeline still expects one control image per ControlNet.
            print("OpenPose failed, using blank map: detector not loaded")
            return Image.new("RGB", (width, height), (0, 0, 0))
        try:
            # Move model to GPU for inference and always back to CPU
            try:
                self.openpose_detector.to(self.device)
                return self.openpose_detector(image, face_only=True)
            finally:
                self.openpose_detector.to("cpu")
        except Exception as e:
            print(f"OpenPose failed, using blank map: {e}")
            return Image.new("RGB", (width, height), (0, 0, 0))

    def generate_retro_art(
        self,
        input_image,
        prompt="retro game character, vibrant colors, detailed",
        negative_prompt="blurry, low quality, ugly, distorted",
        num_inference_steps=12,
        guidance_scale=1.0,
        depth_control_scale=0.8,
        identity_control_scale=0.85,
        expression_control_scale=0.6,
        lora_choice="RetroArt",
        lora_scale=1.0,
        identity_preservation=0.8,
        strength=0.75,
        enable_color_matching=False,
        consistency_mode=True,
        seed=-1
    ):
        """Generate retro art with img2img pipeline and enhanced InstantID.

        Steps: sanitise prompts, optionally validate parameters, add the LoRA
        trigger word, resize the input, build control images for each active
        ControlNet (face keypoints, depth, OpenPose), select the LoRA adapter,
        encode prompts (Compel when available), run the pipeline and
        optionally colour-match the result to the resized input.

        Args:
            input_image: PIL image; converted to RGB if needed.
            prompt / negative_prompt: text prompts (sanitised before use).
            num_inference_steps, guidance_scale, strength: forwarded to the
                pipeline (possibly adjusted by validation / adaptive rules).
            depth_control_scale, identity_control_scale,
            expression_control_scale: ControlNet conditioning scales.
            lora_choice: adapter name; lower-cased to select the adapter,
                "none" disables LoRAs.
            lora_scale: adapter weight.
            identity_preservation: multiplied by ``IDENTITY_BOOST_MULTIPLIER``
                to obtain the IP-Adapter scale.
            enable_color_matching: colour-match the output to the input.
            consistency_mode: run ``validate_and_adjust_parameters`` first.
            seed: -1 (or None) for a random seed, otherwise a fixed seed.

        Returns:
            The first image of the pipeline result, after post-processing.
        """
        # Sanitize text inputs
        prompt = sanitize_text(prompt)
        negative_prompt = sanitize_text(negative_prompt)
        if not negative_prompt or not negative_prompt.strip():
            negative_prompt = ""

        # Apply parameter validation
        if consistency_mode:
            print("\n[CONSISTENCY] Validating and adjusting parameters...")
            (strength, guidance_scale, lora_scale, identity_preservation,
             identity_control_scale, depth_control_scale,
             expression_control_scale) = self.validate_and_adjust_parameters(
                strength, guidance_scale, lora_scale, identity_preservation,
                identity_control_scale, depth_control_scale, consistency_mode,
                expression_control_scale
            )

        prompt = self.add_trigger_word(prompt, lora_choice)

        # Everything downstream (cv2 colour conversions, control images)
        # assumes a 3-channel RGB image.
        if input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')

        # Calculate optimal size with flexible aspect ratio support
        original_width, original_height = input_image.size
        target_width, target_height = calculate_optimal_size(original_width, original_height)
        target_width, target_height = int(target_width), int(target_height)
        print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
        print(f"Prompt: {prompt}")
        print(f"Img2Img Strength: {strength}")

        # Resize with high quality
        resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)

        # Generate control images only for active ControlNets
        depth_image = None
        if self.depth_active:
            print("Generating Zoe depth map...")
            depth_image = self.get_depth_map(resized_image)
            if depth_image.size != (target_width, target_height):
                depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS)

        openpose_image = None
        if self.openpose_active:
            print("Generating OpenPose map...")
            openpose_image = self._generate_openpose_map(resized_image, target_width, target_height)

        # Handle face detection
        face_kps_image = None
        face_embeddings = None
        face_crop_enhanced = None
        has_detected_faces = False
        face_bbox_original = None

        if self.instantid_active:
            # Try InsightFace first (if available)
            insightface_tried = False
            insightface_success = False
            if self.face_app is not None:
                print("Detecting faces with InsightFace...")
                insightface_tried = True
                try:
                    img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
                    faces = self.face_app.get(img_array)
                    if len(faces) > 0:
                        print(f"✓ InsightFace detected {len(faces)} face(s)")
                        # Get largest face
                        face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]

                        # Extract face crop with 30% padding, clamped to the image
                        bbox = face.bbox.astype(int)
                        x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                        detected_bbox = [x1, y1, x2, y2]
                        padding_x = int((x2 - x1) * 0.3)
                        padding_y = int((y2 - y1) * 0.3)
                        x1 = max(0, x1 - padding_x)
                        y1 = max(0, y1 - padding_y)
                        x2 = min(resized_image.width, x2 + padding_x)
                        y2 = min(resized_image.height, y2 + padding_y)
                        face_crop = resized_image.crop((x1, y1, x2, y2))

                        # MULTI-SCALE PROCESSING
                        detected_embeddings = self.extract_multi_scale_face(face_crop, face)
                        # Enhance face crop
                        face_crop_enhanced = enhance_face_crop(face_crop)
                        # Draw keypoints
                        detected_kps_image = draw_kps(resized_image, face.kps)

                        # Commit results only once embeddings AND keypoints
                        # exist, so a mid-way failure still lets the
                        # Mediapipe fallback run instead of leaving a
                        # "detected" face with no keypoint image.
                        face_embeddings = detected_embeddings
                        face_kps_image = detected_kps_image
                        face_bbox_original = detected_bbox
                        insightface_success = True
                        has_detected_faces = True

                        # ADAPTIVE PARAMETERS
                        adaptive_params = self.detect_face_quality(face)
                        if adaptive_params is not None:
                            print(f"[ADAPTIVE] {adaptive_params['reason']}")
                            identity_preservation = adaptive_params['identity_preservation']
                            identity_control_scale = adaptive_params['identity_control_scale']
                            guidance_scale = adaptive_params['guidance_scale']
                            lora_scale = adaptive_params['lora_scale']

                        # Prompt enrichment is optional: a failure here must
                        # not discard the detection above.
                        try:
                            from utils import get_facial_attributes, build_enhanced_prompt
                            facial_attrs = get_facial_attributes(face)
                            # Same lookup-with-fallback as add_trigger_word
                            trigger = TRIGGER_WORD.get(lora_choice, TRIGGER_WORD["RetroArt"])
                            prompt = build_enhanced_prompt(prompt, facial_attrs, trigger)
                            # Legacy output for compatibility
                            age = facial_attrs['age']
                            gender_code = facial_attrs['gender']
                            gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
                            print(f"Face info: bbox={face.bbox}, age={age if age else 'N/A'}, gender={gender_str}")
                            print(f"Face crop size: {face_crop.size}, enhanced: {face_crop_enhanced.size if face_crop_enhanced else 'N/A'}")
                        except Exception as e:
                            print(f"[WARNING] Facial attribute prompt enhancement failed: {e}")
                            traceback.print_exc()
                    else:
                        print("✗ InsightFace found no faces")
                except Exception as e:
                    print(f"[ERROR] InsightFace detection failed: {e}")
                    traceback.print_exc()
            else:
                print("[INFO] InsightFace not available (face_app is None)")

            # If InsightFace didn't succeed, try MediapipeFace
            if not insightface_success:
                if self.mediapipe_face is not None:
                    print("Trying MediapipeFaceDetector as fallback...")
                    try:
                        # MediapipeFace returns an annotated image with keypoints
                        mediapipe_result = self.mediapipe_face(resized_image)
                        # Check if face was detected (result is not blank/black)
                        mediapipe_array = np.array(mediapipe_result)
                        if mediapipe_array.sum() > 1000:
                            has_detected_faces = True
                            face_kps_image = mediapipe_result
                            print(f"✓ MediapipeFace detected face(s)")
                            print(f"[INFO] Using MediapipeFace keypoints (no embeddings available)")
                            # No embeddings on this path: face_embeddings
                            # stays None and InstantID runs keypoints-only.
                        else:
                            print("✗ MediapipeFace found no faces")
                    except Exception as e:
                        print(f"[ERROR] MediapipeFace detection failed: {e}")
                        traceback.print_exc()
                else:
                    print("[INFO] MediapipeFaceDetector not available")

            # Final summary
            if not has_detected_faces:
                print("\n[SUMMARY] No faces detected by any detector")
                if insightface_tried:
                    print(" - InsightFace: tried, found nothing")
                else:
                    print(" - InsightFace: not available")
                if self.mediapipe_face is not None:
                    print(" - MediapipeFace: tried, found nothing")
                else:
                    print(" - MediapipeFace: not available")
                print()

        # Set LORA
        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            adapter_name = lora_choice.lower()  # "retroart", "vga", "lucasart", or "none"
            if adapter_name != "none" and self.loaded_loras.get(adapter_name, False):
                try:
                    self.pipe.set_adapters([adapter_name], adapter_weights=[lora_scale])
                    print(f"LORA: Set adapter '{adapter_name}' with scale: {lora_scale}")
                except Exception as e:
                    print(f"Could not set LORA adapter '{adapter_name}': {e}")
                    self.pipe.set_adapters([])  # Disable LORAs if setting failed
            else:
                if adapter_name == "none":
                    print("LORAs disabled by user choice.")
                else:
                    print(f"LORA '{adapter_name}' not loaded or available, disabling LORAs.")
                self.pipe.set_adapters([])  # Disable all LORAs

        # Prepare generation kwargs
        pipe_kwargs = {
            "image": resized_image,
            "strength": strength,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }

        # Setup generator with seed control (None is treated like -1)
        if seed is None or seed == -1:
            generator = torch.Generator(device=self.device)
            actual_seed = generator.seed()
            print(f"[SEED] Using random seed: {actual_seed}")
        else:
            # manual_seed needs an int; UI widgets may hand over a float
            actual_seed = int(seed)
            generator = torch.Generator(device=self.device).manual_seed(actual_seed)
            print(f"[SEED] Using fixed seed: {actual_seed}")
        pipe_kwargs["generator"] = generator

        # Encode prompts with Compel when available
        if self.use_compel and self.compel is not None:
            try:
                print("Encoding prompts with Compel...")
                conditioning = self.compel(prompt)
                negative_conditioning = self.compel(negative_prompt)
                pipe_kwargs["prompt_embeds"] = conditioning[0]
                pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
                pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
                pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
                print(f"[OK] Compel encoded - Prompt: {pipe_kwargs['prompt_embeds'].shape}, Negative: {pipe_kwargs['negative_prompt_embeds'].shape}")
            except Exception as e:
                print(f"Compel encoding failed, using standard prompts: {e}")
                traceback.print_exc()
                # Drop any embeds already stored so the pipeline never
                # receives both text prompts and embeddings.
                for embed_key in ("prompt_embeds", "pooled_prompt_embeds",
                                  "negative_prompt_embeds", "negative_pooled_prompt_embeds"):
                    pipe_kwargs.pop(embed_key, None)
                pipe_kwargs["prompt"] = prompt
                pipe_kwargs["negative_prompt"] = negative_prompt
        else:
            print("[WARNING] Compel not found, using standard prompt encoding.")
            pipe_kwargs["prompt"] = prompt
            pipe_kwargs["negative_prompt"] = negative_prompt

        # Add CLIP skip
        if hasattr(self.pipe, 'text_encoder'):
            pipe_kwargs["clip_skip"] = 2

        control_images = []
        conditioning_scales = []
        scale_debug_str = []

        # Helper function to ensure control image has correct dimensions
        def ensure_correct_size(img, target_w, target_h, name="control"):
            """Ensure image matches target dimensions exactly"""
            if img is None:
                return Image.new("RGB", (target_w, target_h), (0, 0, 0))
            if img.size != (target_w, target_h):
                print(f" [RESIZE] {name}: {img.size} -> ({target_w}, {target_h})")
                img = img.resize((target_w, target_h), Image.LANCZOS)
            return img

        # 1. InstantID (Identity)
        if self.instantid_active:
            if has_detected_faces and face_kps_image is not None and face_embeddings is not None:
                # Case 1: Face + Embeddings found
                # A. Set the IP-Adapter (face) strength
                boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
                self.pipe.set_ip_adapter_scale(boosted_scale)
                # B. Pass the face embeddings to the pipeline
                pipe_kwargs["image_embeds"] = face_embeddings
                # C. Add the face keypoints (ControlNet) image
                face_kps_image = ensure_correct_size(face_kps_image, target_width, target_height, "InstantID")
                control_images.append(face_kps_image)
                conditioning_scales.append(identity_control_scale)
                scale_debug_str.append(f"Identity (IP): {boosted_scale:.2f}")
                scale_debug_str.append(f"Identity (CN): {identity_control_scale:.2f}")
                print(f"[OK] InstantID active: IP-Adapter scale set to {boosted_scale:.2f}, ControlNet scale set to {identity_control_scale:.2f}")
            elif has_detected_faces:
                # Case 2: Face detected (e.g., Mediapipe) but no embeddings available
                print("[INSTANTID] Using keypoints only (no face embeddings for IP-Adapter).")
                # A. Turn off IP-Adapter
                self.pipe.set_ip_adapter_scale(0.0)
                # B. Pass dummy embeddings to prevent crash
                pipe_kwargs["image_embeds"] = np.zeros(512)
                # C. Add face keypoints (ControlNet)
                face_kps_image = ensure_correct_size(face_kps_image, target_width, target_height, "InstantID")
                control_images.append(face_kps_image)
                conditioning_scales.append(identity_control_scale)
                scale_debug_str.append("Identity (IP): 0.00")
                scale_debug_str.append(f"Identity (CN): {identity_control_scale:.2f}")
            else:
                # Case 3: No face detected at all
                print("[INSTANTID] No face detected. Disabling face identity.")
                # A. Turn off IP-Adapter
                self.pipe.set_ip_adapter_scale(0.0)
                # B. Pass dummy embeddings to prevent crash
                pipe_kwargs["image_embeds"] = np.zeros(512)
                # C. Add blank image for ControlNet (to keep list order)
                control_images.append(Image.new("RGB", (target_width, target_height), (0, 0, 0)))
                conditioning_scales.append(0.0)
                scale_debug_str.append("Identity (IP): 0.00")
                scale_debug_str.append("Identity (CN): 0.00")

        # 2. Depth
        if self.depth_active:
            depth_image = ensure_correct_size(depth_image, target_width, target_height, "Depth")
            control_images.append(depth_image)
            conditioning_scales.append(depth_control_scale)
            scale_debug_str.append(f"Depth: {depth_control_scale:.2f}")

        # 3. OpenPose (Expression)
        if self.openpose_active:
            openpose_image = ensure_correct_size(openpose_image, target_width, target_height, "OpenPose")
            control_images.append(openpose_image)
            conditioning_scales.append(expression_control_scale)
            scale_debug_str.append(f"Expression: {expression_control_scale:.2f}")

        # Final validation: ensure all control images have identical dimensions
        if control_images:
            expected_size = (target_width, target_height)
            for idx, img in enumerate(control_images):
                if img.size != expected_size:
                    print(f" [WARNING] Control image {idx} size mismatch: {img.size} vs expected {expected_size}")
                    control_images[idx] = img.resize(expected_size, Image.LANCZOS)
            pipe_kwargs["control_image"] = control_images
            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
            print(f"Active ControlNets: {len(control_images)} (all {target_width}x{target_height})")
        else:
            print("No active ControlNets, running standard Img2Img")

        # Generate
        print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
        print(f"Controlnet scales - {' | '.join(scale_debug_str)}")
        result = self.pipe(**pipe_kwargs)
        generated_image = result.images[0]

        # Post-processing
        if enable_color_matching and has_detected_faces:
            print("Applying enhanced face-aware color matching...")
            try:
                if face_bbox_original is not None:
                    generated_image = enhanced_color_match(
                        generated_image,
                        resized_image,
                        face_bbox=face_bbox_original
                    )
                    print("[OK] Enhanced color matching applied (face-aware)")
                else:
                    generated_image = color_match(generated_image, resized_image, mode='mkl')
                    print("[OK] Standard color matching applied")
            except Exception as e:
                print(f"Color matching failed: {e}")
        elif enable_color_matching:
            print("Applying standard color matching...")
            try:
                generated_image = color_match(generated_image, resized_image, mode='mkl')
                print("[OK] Standard color matching applied")
            except Exception as e:
                print(f"Color matching failed: {e}")

        return generated_image
# Top-level statement: runs once when this module is imported/executed,
# after the class definition above has been processed.
print("[OK] Generator class ready")