# Provenance (Hugging Face page header, preserved as comments so the file parses):
# pixagram-neo-backup / generator.py
# primerz's picture
# Update generator.py
# 5307fa2 verified
# raw
# history blame
# 18.1 kB
"""
Generation logic for Pixagram AI Pixel Art Generator
FIXED VERSION - Proper embedding integration following exampleapp.py pattern
"""
import torch
import numpy as np
import cv2
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from config import (
device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS,
ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER
)
from utils import (
sanitize_text, enhanced_color_match, color_match, create_face_mask,
draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop
)
from models import ( # Use the hybrid version (supports both loading methods)
load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
)
class RetroArtConverter:
    """Main class for retro art generation - FIXED VERSION

    Orchestrates the full image-to-image pipeline:

    - InsightFace face analysis (detection, keypoints, 512-D embeddings)
    - Zoe depth estimation for a depth ControlNet
    - An SDXL pipeline driven by InstantID + Depth ControlNets
    - Optional LORA, IP-Adapter, Compel prompt weighting, LCM scheduler,
      and a caption model

    Every optional component records its load success in ``self.models_loaded``
    so later code (and ``_print_status``) can branch on what is actually active.
    """

    def __init__(self) -> None:
        """Load all sub-models and assemble the SDXL pipeline.

        Heavy side effects: downloads/loads model weights via the ``models``
        module loaders and prints progress to stdout. Load order matters:
        ControlNets must exist before the pipeline, the pipeline before
        LORA / IP-Adapter / Compel / scheduler setup.
        """
        # Target device / dtype come from config and are shared by all sub-models.
        self.device = device
        self.dtype = dtype
        # Success flags for optional components; False means fallback/disabled.
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'zoe_depth': False,
            'ip_adapter': False
        }
        # Initialize face analysis
        self.face_app, self.face_detection_enabled = load_face_analysis()
        # Load Zoe Depth detector
        self.zoe_depth, zoe_success = load_depth_detector()
        self.models_loaded['zoe_depth'] = zoe_success
        # Load ControlNets
        controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
        self.controlnet_depth = controlnet_depth
        self.instantid_enabled = instantid_success
        self.models_loaded['instantid'] = instantid_success
        # Load image encoder (still needed for some pipeline functions)
        if self.instantid_enabled:
            self.image_encoder = load_image_encoder()
        else:
            self.image_encoder = None
        # Determine which controlnets to use. Order matters downstream:
        # generate() assumes index 0 = InstantID, index 1 = Depth when a list.
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, controlnet_depth]
            print(f"Initializing with multiple ControlNets: InstantID + Depth")
        else:
            controlnets = controlnet_depth
            print(f"Initializing with single ControlNet: Depth only")
        # CRITICAL FIX: Load SDXL pipeline with from_pretrained()
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success
        # Load LORA
        lora_success = load_lora(self.pipe)
        self.models_loaded['lora'] = lora_success
        # CRITICAL FIX: Setup IP-Adapter using pipeline's built-in method
        if self.instantid_enabled and self.image_encoder is not None:
            ip_adapter_success = setup_ip_adapter(self.pipe)
            self.models_loaded['ip_adapter'] = ip_adapter_success
            # The pipeline now has these attributes after load_ip_adapter_instantid:
            # - self.pipe.image_proj_model (the Resampler)
            # - self.pipe.ip_adapter_scale (current scale)
            # We don't need to manually manage these anymore!
        else:
            print("[INFO] Face preservation: InstantID ControlNet keypoints only")
            self.models_loaded['ip_adapter'] = False
        # Setup Compel (prompt-weighting encoder); use_compel gates generate().
        self.compel, self.use_compel = setup_compel(self.pipe)
        # Setup LCM scheduler
        setup_scheduler(self.pipe)
        # Optimize pipeline
        optimize_pipeline(self.pipe)
        # Load caption model
        self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
        # Report caption model status
        if self.caption_enabled and self.caption_model is not None:
            if self.caption_model_type == "git":
                print(" [OK] Using GIT for detailed captions")
            elif self.caption_model_type == "blip":
                print(" [OK] Using BLIP for standard captions")
            else:
                print(" [OK] Caption model loaded")
        # Set CLIP skip
        set_clip_skip(self.pipe)
        # Track controlnet configuration (list => InstantID + Depth pair).
        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
        # Print model status
        self._print_status()
        print(" [OK] Model initialization complete!")

    def _print_status(self) -> None:
        """Print model loading status (one line per models_loaded entry),
        followed by a detailed IP-Adapter availability report."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
            print(f"{model}: {status}")
        print("===================\n")
        print("=== IP-ADAPTER STATUS ===")
        if self.models_loaded.get('ip_adapter', False):
            # Resampler presence distinguishes a full load from a partial one.
            if hasattr(self.pipe, 'image_proj_model'):
                print("[OK] IP-Adapter fully loaded via pipeline method")
                print(" - Resampler: Available at pipe.image_proj_model")
                print(" - Scale control: Available via pipe.set_ip_adapter_scale()")
                print(" - Expected improvement: High face similarity")
            else:
                print("[WARNING] IP-Adapter loaded but Resampler not accessible")
        else:
            print("[INFO] IP-Adapter not active (using keypoints only)")
        print("=========================\n")

    def get_depth_map(self, image):
        """Generate depth map using Zoe Depth.

        The image is snapped down to multiples of 64 (minimum 64x64) before
        detection, and the resulting depth map is resized back to the input
        size. Returns None when no detector is loaded or detection fails.

        NOTE(review): assumes the zoe_depth callable returns a PIL-like image
        (it exposes .size and .resize here) — confirm against load_depth_detector.
        """
        if self.zoe_depth is not None:
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                orig_width, orig_height = image.size
                orig_width = int(orig_width)
                orig_height = int(orig_height)
                # Use multiples of 64
                target_width = int((orig_width // 64) * 64)
                target_height = int((orig_height // 64) * 64)
                target_width = int(max(64, target_width))
                target_height = int(max(64, target_height))
                size_for_depth = (int(target_width), int(target_height))
                image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
                depth_map = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=512)
                # Restore the caller's original resolution if detection changed it.
                if depth_map.size != image.size:
                    depth_map = depth_map.resize(image.size, Image.LANCZOS)
                return depth_map
            except Exception as e:
                # Best-effort: callers treat None as "no depth available".
                print(f"Depth generation failed: {e}")
                return None
        return None

    def generate(
        self,
        image,
        prompt: str = "a person",
        negative_prompt: str = "",
        num_inference_steps: int = 4,
        guidance_scale: float = 0.0,
        strength: float = 0.75,
        lora_scale: float = 1.0,
        identity_control_scale: float = 0.8,
        depth_control_scale: float = 0.8,
        identity_preservation: float = 1.0,
        enable_color_matching: bool = True,
        seed: int = -1,
        **kwargs
    ):
        """
        Generate retro art with InstantID face preservation.
        FIXED: Proper IP-Adapter integration following exampleapp.py pattern.

        Args:
            image: PIL input image (converted to RGB and resized internally).
            prompt / negative_prompt: text conditioning; encoded via Compel
                when available, otherwise passed to the pipeline as-is.
            num_inference_steps, guidance_scale, strength: standard img2img
                pipeline parameters (LCM defaults: few steps, CFG 0).
            lora_scale: weight applied to the "retroart" LORA adapter.
            identity_control_scale: InstantID ControlNet conditioning scale.
            depth_control_scale: Depth ControlNet conditioning scale.
            identity_preservation: multiplied by 0.8 to set the IP-Adapter scale.
            enable_color_matching: post-process output to match input colors.
            seed: -1 for a random seed, otherwise a fixed manual seed.
            **kwargs: accepted but currently ignored.

        Returns:
            The generated image (first element of the pipeline's .images).

        Raises:
            RuntimeError: if the depth map cannot be generated.
        """
        print(f"\n{'='*60}")
        print(f"Starting generation with:")
        print(f" Prompt: {prompt}")
        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
        print(f" Identity scale: {identity_control_scale}, Depth scale: {depth_control_scale}")
        print(f" Face preservation: {identity_preservation}")
        print(f"{'='*60}\n")
        # Prepare input image
        if image.mode != 'RGB':
            image = image.convert('RGB')
        optimal_width, optimal_height = calculate_optimal_size(image.size)
        resized_image = image.resize((optimal_width, optimal_height), Image.LANCZOS)
        print(f"Image resized: {image.size}{resized_image.size}")
        # Generate depth map (required: both ControlNet configurations use it).
        print("Generating depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image is None:
            raise RuntimeError("Could not generate depth map")
        # Face detection and processing
        has_detected_faces = False
        face_kps_image = None
        face_embeddings = None
        face_crop = None
        # NOTE(review): face_crop_enhanced is computed below but never used in
        # this method — possibly consumed via a removed code path; confirm.
        face_crop_enhanced = None
        face_bbox_original = None
        if self.face_app is not None:
            print("Detecting faces...")
            try:
                # InsightFace expects BGR numpy arrays (OpenCV convention).
                image_np = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
                faces = self.face_app.get(image_np)
                if len(faces) > 0:
                    # Only the first detected face is used for identity.
                    has_detected_faces = True
                    face = faces[0]
                    print(f" [OK] Face detected (score: {face.det_score:.3f})")
                    # Get face keypoints image
                    face_kps_image = draw_kps(resized_image, face.kps)
                    # Get face embeddings (512D from InsightFace); prefer the
                    # pre-normalized embedding, else normalize manually.
                    if hasattr(face, 'normed_embedding'):
                        face_embeddings = face.normed_embedding
                        print(f" Face embedding shape: {face_embeddings.shape}")
                    elif hasattr(face, 'embedding'):
                        face_embeddings = face.embedding / np.linalg.norm(face.embedding)
                        print(f" Face embedding shape: {face_embeddings.shape}")
                    # Store face bbox for color matching
                    if hasattr(face, 'bbox'):
                        face_bbox_original = face.bbox
                        # Get face crop for enhanced processing
                        bbox = face.bbox.astype(int)
                        face_crop = resized_image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
                        face_crop_enhanced = enhance_face_crop(face_crop)
                    # Debug info
                    if hasattr(face, 'age') and hasattr(face, 'gender'):
                        facial_attrs = {
                            'age': face.age,
                            'gender': face.gender,
                            'quality': face.det_score
                        }
                        age = facial_attrs['age']
                        gender_code = facial_attrs['gender']
                        det_score = facial_attrs['quality']
                        # NOTE(review): gender codes 1/0 assumed to mean M/F
                        # per InsightFace convention — confirm. Also, truthiness
                        # check below reports age 0 as 'N/A'.
                        gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
                        print(f" Face info: age={age if age else 'N/A'}, gender={gender_str}, quality={det_score:.3f}")
                else:
                    print(" [INFO] No faces detected")
            except Exception as e:
                # Detection failure degrades to the depth-only path below.
                print(f" [WARNING] Face detection failed: {e}")
        # CRITICAL FIX: Set IP-Adapter scale dynamically
        # The pipeline's built-in method allows runtime adjustment
        if self.models_loaded.get('ip_adapter', False) and has_detected_faces:
            try:
                # Scale based on identity_preservation parameter
                adjusted_scale = 0.8 * identity_preservation
                self.pipe.set_ip_adapter_scale(adjusted_scale)
                print(f" IP-Adapter scale adjusted to: {adjusted_scale:.2f}")
            except Exception as e:
                print(f" [WARNING] Could not adjust IP-Adapter scale: {e}")
        # Set LORA scale
        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            try:
                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
                print(f" LORA scale: {lora_scale}")
            except Exception as e:
                print(f" [WARNING] Could not set LORA scale: {e}")
        # Prepare generation kwargs
        pipe_kwargs = {
            "image": resized_image,
            "strength": strength,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }
        # Setup generator with seed control
        if seed == -1:
            generator = torch.Generator(device=self.device)
            # seed() draws and applies a fresh random seed, returned for logging.
            actual_seed = generator.seed()
            print(f"[SEED] Using random seed: {actual_seed}")
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            actual_seed = seed
            print(f"[SEED] Using fixed seed: {actual_seed}")
        pipe_kwargs["generator"] = generator
        # Use Compel for prompt encoding if available
        if self.use_compel and self.compel is not None:
            try:
                print("Encoding prompts with Compel...")
                # Compel returns (prompt_embeds, pooled_prompt_embeds) for SDXL.
                conditioning = self.compel(prompt)
                negative_conditioning = self.compel(negative_prompt)
                pipe_kwargs["prompt_embeds"] = conditioning[0]
                pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
                pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
                pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
                print(" [OK] Using Compel-encoded prompts")
            except Exception as e:
                # Fall back to plain string prompts on any Compel failure.
                print(f" Compel encoding failed, using standard prompts: {e}")
                pipe_kwargs["prompt"] = prompt
                pipe_kwargs["negative_prompt"] = negative_prompt
        else:
            pipe_kwargs["prompt"] = prompt
            pipe_kwargs["negative_prompt"] = negative_prompt
        # Add CLIP skip
        # NOTE(review): diffusers applies clip_skip only when it encodes the
        # prompt itself; when prompt_embeds are supplied (Compel path above)
        # this value has no effect — confirm intent.
        if hasattr(self.pipe, 'text_encoder'):
            pipe_kwargs["clip_skip"] = 2
        # Configure ControlNet inputs
        using_multiple_controlnets = self.using_multiple_controlnets
        if using_multiple_controlnets and has_detected_faces and face_kps_image is not None:
            print("Using InstantID (keypoints + embeddings) + Depth ControlNets")
            # Order must match __init__: [InstantID, Depth].
            control_images = [face_kps_image, depth_image]
            conditioning_scales = [identity_control_scale, depth_control_scale]
            pipe_kwargs["control_image"] = control_images
            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
            # CRITICAL FIX: The pipeline handles face embeddings automatically!
            # When load_ip_adapter_instantid() was called, the pipeline was configured
            # to automatically process face embeddings through the Resampler and
            # integrate them with text embeddings during generation.
            #
            # We just need to provide the face image via control_image and the
            # pipeline does the rest. No manual concatenation needed!
            if face_embeddings is not None and self.models_loaded.get('ip_adapter', False):
                print(" [OK] Face embeddings will be processed by pipeline")
                print(" - Pipeline automatically handles Resampler projection")
                print(" - Face features integrated via IP-Adapter attention")
        elif using_multiple_controlnets and not has_detected_faces:
            print("Multiple ControlNets available but no faces detected, using depth only")
            # The pipeline still expects two control images; neutralize the
            # InstantID slot by feeding the depth map at scale 0.0.
            control_images = [depth_image, depth_image]
            conditioning_scales = [0.0, depth_control_scale]
            pipe_kwargs["control_image"] = control_images
            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
        else:
            print("Using Depth ControlNet only")
            pipe_kwargs["control_image"] = depth_image
            pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale
        # Generate
        print(f"\nGenerating with LCM:")
        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
        print(f" ControlNet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}")
        result = self.pipe(**pipe_kwargs)
        generated_image = result.images[0]
        # Post-processing: color-match output against the (resized) input.
        if enable_color_matching and has_detected_faces:
            print("\nApplying enhanced face-aware color matching...")
            try:
                if face_bbox_original is not None:
                    generated_image = enhanced_color_match(
                        generated_image,
                        resized_image,
                        face_bbox=face_bbox_original
                    )
                    print(" [OK] Enhanced color matching applied (face-aware)")
                else:
                    generated_image = color_match(generated_image, resized_image, mode='mkl')
                    print(" [OK] Standard color matching applied")
            except Exception as e:
                # Color matching is cosmetic; keep the raw result on failure.
                print(f" [WARNING] Color matching failed: {e}")
        elif enable_color_matching:
            print("\nApplying standard color matching...")
            try:
                generated_image = color_match(generated_image, resized_image, mode='mkl')
                print(" [OK] Standard color matching applied")
            except Exception as e:
                print(f" [WARNING] Color matching failed: {e}")
        print(f"\n{'='*60}")
        print("Generation complete!")
        print(f"{'='*60}\n")
        return generated_image
print("[OK] Generator class ready (FIXED VERSION)")