# pixagram-neo-backup / generator.py
# (Hugging Face page header: uploaded by primerz, commit c336b59 verified, 32.7 kB)
"""
Generation logic for Pixagram AI Pixel Art Generator
CORRECTED VERSION - Following examplewithface.py pattern
"""
import torch
import numpy as np
import cv2
from PIL import Image
import gc
from config import (
device, dtype, TRIGGER_WORD,
ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG
)
from utils import (
sanitize_text, enhanced_color_match, color_match,
get_demographic_description, calculate_optimal_size, safe_image_size
)
from models import (
load_face_analysis, load_depth_detector, load_controlnets,
load_sdxl_pipeline, load_lora, setup_compel,
setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
)
from pipeline_stable_diffusion_xl_instantid_img2img import draw_kps
from memory_utils import MemoryManager, ModelOffloader
class RetroArtConverter:
    """Main class for retro art generation with InstantID.

    Wires together face analysis (InsightFace), a Zoe depth detector, two
    ControlNets (face keypoints + depth), an SDXL img2img InstantID pipeline,
    a LORA, Compel prompt weighting, and an optional caption model.  Heavy
    auxiliary models (depth, caption) are kept on CPU and only moved to GPU
    for the duration of a single inference.
    """

    def __init__(self):
        """Load all models and report their status.

        Raises:
            RuntimeError: if InsightFace face detection is unavailable —
                face images are the only supported input.
        """
        self.device = device
        self.dtype = dtype
        # Track which optional components actually loaded (for the status report).
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'zoe_depth': False
        }
        # Initialize memory manager (handles cleanup between generations)
        self.memory_manager = MemoryManager(device=device, dtype=dtype, verbose=True)
        # Load face analysis (like examplewithface.py line 113); hard requirement
        self.face_app, face_detection_success = load_face_analysis()
        if not face_detection_success or self.face_app is None:
            raise RuntimeError("[ERROR] Face detection is required! Check InsightFace installation.")
        # Load depth detector (starts on CPU) - single assignment, no alias
        self.zoe_depth, zoe_success = load_depth_detector()
        self.models_loaded['zoe_depth'] = zoe_success
        # Load ControlNets AS LIST: [identity keypoints, depth] — order matters,
        # generate_retro_art passes control images/scales in the same order.
        controlnet_instantid, controlnet_depth = load_controlnets()
        controlnets = [controlnet_instantid, controlnet_depth]
        self.models_loaded['instantid'] = True
        print("Initializing InstantID pipeline with Face + Depth ControlNets")
        # Load SDXL pipeline with InstantID (handles IP-Adapter internally)
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success
        # Load LORA
        lora_success = load_lora(self.pipe)
        self.models_loaded['lora'] = lora_success
        # Setup Compel (weighted prompt encoding); use_compel may be False
        self.compel, self.use_compel = setup_compel(self.pipe)
        # Setup scheduler
        setup_scheduler(self.pipe)
        # Optimize
        optimize_pipeline(self.pipe)
        # Load caption model (starts on CPU)
        self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
        # Set CLIP skip
        set_clip_skip(self.pipe)
        # Print status
        self._print_status()
        # Initial memory cleanup
        self.memory_manager.cleanup_memory(aggressive=True)
        print(" [OK] RetroArtConverter initialized with optimized memory management!")

    def _print_status(self):
        """Print model loading status to stdout."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
            print(f"{model}: {status}")
        print("InstantID Pipeline: [OK] ACTIVE")
        print("IP-Adapter: [OK] Built into pipeline")
        print(f"Face Detection: [OK] {'READY' if self.face_app else 'UNAVAILABLE'}")
        print("===================\n")

    def get_depth_map(self, image):
        """Generate depth map using Zoe Depth with optimized GPU usage.

        The detector is moved to GPU only for the inference call, then parked
        back on CPU to free VRAM.  On any failure (or when the detector is
        unavailable) a grayscale copy of the input is returned as an RGB
        image, so callers always get a usable control image.
        """
        if self.zoe_depth is not None:
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                # Use safe size helper to avoid numpy.int64 issues
                orig_width, orig_height = safe_image_size(image)
                # Use multiples of 64, clamped to at least 64x64
                target_width = int((orig_width // 64) * 64)
                target_height = int((orig_height // 64) * 64)
                target_width = int(max(64, target_width))
                target_height = int(max(64, target_height))
                size_for_depth = (int(target_width), int(target_height))
                image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
                # Move depth model to GPU temporarily
                try:
                    if torch.cuda.is_available():
                        self.zoe_depth = self.zoe_depth.to(self.device)
                    # Generate depth map
                    depth_output = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
                    # Handle different output types (PIL / ndarray / tensor)
                    if isinstance(depth_output, Image.Image):
                        depth_image = depth_output
                    elif isinstance(depth_output, np.ndarray):
                        depth_image = Image.fromarray(depth_output.astype(np.uint8))
                    elif isinstance(depth_output, torch.Tensor):
                        depth_array = depth_output.cpu().numpy()
                        # CHW -> HWC for PIL; assumes values in [0, 1] — TODO confirm
                        if depth_array.ndim == 3 and depth_array.shape[0] == 3:
                            depth_array = depth_array.transpose(1, 2, 0)
                        depth_image = Image.fromarray((depth_array * 255).astype(np.uint8))
                    else:
                        print(f"[DEPTH] Unexpected output type: {type(depth_output)}")
                        depth_image = image_for_depth.convert('L').convert('RGB')
                    # Move back to CPU to free GPU memory
                    if torch.cuda.is_available():
                        self.zoe_depth = self.zoe_depth.to("cpu")
                        torch.cuda.empty_cache()
                except Exception as inner_e:
                    # GPU path failed (e.g. OOM): retry the whole inference on CPU
                    print(f"[DEPTH] GPU processing failed: {inner_e}, trying on CPU")
                    self.zoe_depth = self.zoe_depth.to("cpu")
                    depth_output = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
                    if isinstance(depth_output, Image.Image):
                        depth_image = depth_output
                    elif isinstance(depth_output, np.ndarray):
                        depth_image = Image.fromarray(depth_output.astype(np.uint8))
                    else:
                        depth_image = image_for_depth.convert('L').convert('RGB')
                # Ensure depth image is RGB and matches the caller's image size
                if depth_image.mode != 'RGB':
                    depth_image = depth_image.convert('RGB')
                if depth_image.size != image.size:
                    depth_image = depth_image.resize(image.size, Image.LANCZOS)
                print(f"[DEPTH] Generated depth map: {depth_image.size}")
                return depth_image
            except Exception as e:
                print(f"[DEPTH] Generation failed: {e}, using grayscale fallback")
                fallback = image.convert('L').convert('RGB')
                return fallback
        else:
            print("[DEPTH] Detector not available, using grayscale")
            fallback = image.convert('L').convert('RGB')
            return fallback

    def add_trigger_word(self, prompt):
        """Prepend TRIGGER_WORD to the prompt unless it is already present.

        An empty/whitespace-only prompt becomes just the trigger word.
        """
        if TRIGGER_WORD.lower() not in prompt.lower():
            if not prompt or not prompt.strip():
                return TRIGGER_WORD
            return f"{TRIGGER_WORD}, {prompt}"
        return prompt

    def detect_face_quality(self, face):
        """Detect face quality and adaptively adjust parameters.

        Returns a copy of one ADAPTIVE_PARAMS preset when the face is small,
        low-confidence, or in profile; returns None to keep caller defaults.
        Never raises — any error falls through to None.
        """
        try:
            bbox = face.bbox
            face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            det_score = float(face.det_score) if hasattr(face, 'det_score') else 1.0
            # Small face -> boost preservation
            if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
                return ADAPTIVE_PARAMS['small_face'].copy()
            # Low confidence -> boost preservation
            elif det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
                return ADAPTIVE_PARAMS['low_confidence'].copy()
            # Check for profile view (pose[1] treated as yaw angle — TODO confirm)
            elif hasattr(face, 'pose') and len(face.pose) > 1:
                try:
                    yaw = float(face.pose[1])
                    if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
                        return ADAPTIVE_PARAMS['profile_view'].copy()
                except (ValueError, TypeError, IndexError):
                    pass
            return None
        except Exception as e:
            print(f"[ADAPTIVE] Quality detection failed: {e}")
            return None

    def generate_caption(self, image):
        """Generate caption for image with optimized GPU usage.

        Returns a sanitized caption string, or None when captioning is
        disabled, the model type is unknown, or generation fails.  The model
        is moved to GPU for inference and parked back on CPU afterwards.
        """
        if not self.caption_enabled or self.caption_model is None:
            return None
        try:
            # Remember where the model started so it is only moved back to
            # CPU if that is where it came from.
            original_device = "cpu"
            if hasattr(self.caption_model, 'device'):
                original_device = str(self.caption_model.device)
            try:
                # Move to GPU for processing
                if torch.cuda.is_available():
                    self.caption_model = self.caption_model.to(self.device)
                if self.caption_model_type == 'git':
                    inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
                    generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
                    caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                elif self.caption_model_type == 'blip':
                    inputs = self.caption_processor(image, return_tensors="pt").to(self.device)
                    generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
                    caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
                else:
                    return None
                # Move back to CPU to free GPU memory
                if torch.cuda.is_available() and "cpu" in original_device:
                    self.caption_model = self.caption_model.to("cpu")
                    torch.cuda.empty_cache()
            except Exception as gpu_error:
                # GPU path failed: retry the full captioning on CPU
                print(f"[CAPTION] GPU processing failed: {gpu_error}, trying on CPU")
                self.caption_model = self.caption_model.to("cpu")
                if self.caption_model_type == 'git':
                    inputs = self.caption_processor(images=image, return_tensors="pt")
                    generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
                    caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                elif self.caption_model_type == 'blip':
                    inputs = self.caption_processor(image, return_tensors="pt")
                    generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
                    caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
                else:
                    return None
            return sanitize_text(caption)
        except Exception as e:
            print(f"[CAPTION] Generation failed: {e}")
            return None

    def generate_retro_art(
        self,
        input_image,
        prompt=" ",
        negative_prompt=" ",
        num_inference_steps=12,
        guidance_scale=1.3,
        depth_control_scale=0.75,
        identity_control_scale=0.85,
        lora_scale=1.0,
        identity_preservation=1.2,
        strength=0.50,
        enable_color_matching=False,
        consistency_mode=True,
        seed=-1
    ):
        """Generate retro art with InstantID face preservation.

        Args:
            input_image: PIL image containing a face (face images only).
            prompt: positive prompt; TRIGGER_WORD is prepended automatically.
            negative_prompt: negative prompt text.
            num_inference_steps, guidance_scale, strength: img2img knobs.
            depth_control_scale: ControlNet weight for the depth map.
            identity_control_scale: ControlNet weight for face keypoints.
            lora_scale: scale used when fusing the LORA.
            identity_preservation: IP-Adapter scale for face embeddings.
            enable_color_matching: post-process output to match input colors.
            consistency_mode: accepted for API compatibility; not used below.
            seed: RNG seed; -1 draws a random seed.

        Returns:
            The generated PIL image.

        Raises:
            ValueError: when no face can be detected; note the try/finally
                has no except clause, so this propagates after cleanup.
        """
        try:
            # Add trigger word and sanitize all text inputs
            prompt = self.add_trigger_word(prompt)
            prompt = sanitize_text(prompt)
            negative_prompt = sanitize_text(negative_prompt)
            print(f"[PROMPT] {prompt}")
            # Calculate optimal size
            orig_width, orig_height = safe_image_size(input_image)
            optimal_width, optimal_height = calculate_optimal_size(orig_width, orig_height)
            # Resize image
            resized_image = input_image.resize((optimal_width, optimal_height), Image.LANCZOS)
            print(f"[SIZE] Resized to {optimal_width}x{optimal_height}")
            # Generate depth map
            depth_image = self.get_depth_map(resized_image)
            # ═══════════════════════════════════════════════════════════
            # FACE DETECTION
            # ═══════════════════════════════════════════════════════════
            has_detected_faces = False
            face_kps_image = None
            face_embeddings = None
            face_bbox_original = None
            # FACE DETECTION (examplewithface.py line 321-327)
            try:
                # InsightFace expects BGR arrays (OpenCV convention)
                image_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
                faces = self.face_app.get(image_array)
                if len(faces) == 0:
                    raise ValueError("No faces detected in image")
                # Get largest face by bbox area (examplewithface.py line 322)
                face = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]
                # Get embeddings and keypoints
                face_embeddings = face['embedding']
                face_kps_image = draw_kps(resized_image, face['kps'])
                face_bbox_original = face.get('bbox', None)
                # Adaptive parameter adjustment (may override the caller's knobs)
                adaptive_params = self.detect_face_quality(face)
                if adaptive_params:
                    print(f"[ADAPTIVE] {adaptive_params['reason']}")
                    identity_preservation = adaptive_params.get('identity_preservation', identity_preservation)
                    identity_control_scale = adaptive_params.get('identity_control_scale', identity_control_scale)
                    guidance_scale = adaptive_params.get('guidance_scale', guidance_scale)
                    lora_scale = adaptive_params.get('lora_scale', lora_scale)
                print(f"[FACE] Detected face with {face.get('det_score', 1.0):.2f} confidence")
                print(f"[FACE] Embeddings shape: {face_embeddings.shape}")
                has_detected_faces = True
            except Exception as e:
                # Re-raise: faces are mandatory for this generator.
                # NOTE(review): because this raises, the "keypoints only" and
                # "depth only" modes further down are currently unreachable.
                print(f"[FACE] Face detection failed: {str(e)[:100]}")
                raise ValueError(f"No face found in image. Only face images work. Error: {str(e)}")
            # Fuse LORA with scale (following working example approach)
            if self.models_loaded['lora']:
                try:
                    from models import fuse_lora_with_scale
                    fuse_lora_with_scale(self.pipe, lora_scale)
                    print(f"[LORA] Fused with scale: {lora_scale}")
                except Exception as e:
                    print(f"[LORA] Could not fuse: {e}")
            # ═══════════════════════════════════════════════════════════
            # PIPELINE CONFIGURATION
            # ═══════════════════════════════════════════════════════════
            pipe_kwargs = {
                "image": resized_image,
                "strength": strength,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
            }
            # Setup generator with seed (-1 means random)
            if seed == -1:
                generator = torch.Generator(device=self.device)
                actual_seed = generator.seed()
                print(f"[SEED] Random: {actual_seed}")
            else:
                generator = torch.Generator(device=self.device).manual_seed(seed)
                actual_seed = seed
                print(f"[SEED] Fixed: {actual_seed}")
            pipe_kwargs["generator"] = generator
            # Use Compel for prompt encoding (falls back to plain prompts)
            if self.use_compel and self.compel is not None:
                try:
                    conditioning = self.compel(prompt)
                    negative_conditioning = self.compel(negative_prompt)
                    pipe_kwargs["prompt_embeds"] = conditioning[0]
                    pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
                    pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
                    pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
                    print("[OK] Using Compel-encoded prompts")
                except Exception as e:
                    print(f"[COMPEL] Failed, using standard prompts: {e}")
                    pipe_kwargs["prompt"] = prompt
                    pipe_kwargs["negative_prompt"] = negative_prompt
            else:
                pipe_kwargs["prompt"] = prompt
                pipe_kwargs["negative_prompt"] = negative_prompt
            # ═══════════════════════════════════════════════════════════
            # CONTROLNET + IP-ADAPTER CONFIGURATION
            # Control lists are ordered [identity keypoints, depth], matching
            # the ControlNet order set up in __init__.
            # ═══════════════════════════════════════════════════════════
            if has_detected_faces and face_kps_image is not None and face_embeddings is not None:
                print("═" * 60)
                print("MODE: InstantID (Face Keypoints + Depth + IP-Adapter)")
                print("═" * 60)
                # Set IP-Adapter scale
                self.pipe.set_ip_adapter_scale(identity_preservation)
                print(f" [IP-ADAPTER] Scale set to: {identity_preservation}")
                # Control images: [face keypoints, depth map]
                pipe_kwargs["control_image"] = [face_kps_image, depth_image]
                # ControlNet scales: [identity keypoints, depth]
                pipe_kwargs["controlnet_conditioning_scale"] = [
                    identity_control_scale,
                    depth_control_scale
                ]
                # Control guidance timing
                pipe_kwargs["control_guidance_start"] = [0.0, 0.0]
                pipe_kwargs["control_guidance_end"] = [1.0, 1.0]
                # Pass raw face embeddings - pipeline handles everything
                pipe_kwargs["image_embeds"] = face_embeddings
                print(f" [CONTROLNET] Identity scale: {identity_control_scale}")
                print(f" [CONTROLNET] Depth scale: {depth_control_scale}")
                print(f" [EMBEDDINGS] Shape: {face_embeddings.shape} (raw)")
                print(" [INFO] Pipeline will handle: Resampler → Concatenation → Attention")
                print("═" * 60)
            elif has_detected_faces and face_kps_image is not None:
                print("═" * 60)
                print("MODE: InstantID Keypoints Only (no embeddings)")
                print("═" * 60)
                # Disable IP-Adapter
                self.pipe.set_ip_adapter_scale(0.0)
                print(" [IP-ADAPTER] Disabled (no embeddings)")
                # Use keypoints + depth
                pipe_kwargs["control_image"] = [face_kps_image, depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [
                    identity_control_scale,
                    depth_control_scale
                ]
                pipe_kwargs["control_guidance_start"] = [0.0, 0.0]
                pipe_kwargs["control_guidance_end"] = [1.0, 1.0]
                # Pass zero embeddings (pipeline still requires the argument)
                zero_embeddings = np.zeros(512, dtype=np.float32)
                pipe_kwargs["image_embeds"] = zero_embeddings
                print(" [INFO] Using keypoints for structure only (zero embeddings)")
                print("═" * 60)
            else:
                print("═" * 60)
                print("MODE: Depth Only (no face detection)")
                print("═" * 60)
                # Disable IP-Adapter
                self.pipe.set_ip_adapter_scale(0.0)
                print(" [IP-ADAPTER] Disabled (no face)")
                # Use depth only (identity slot zeroed out but still fed an image)
                pipe_kwargs["control_image"] = [depth_image, depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]
                pipe_kwargs["control_guidance_start"] = [0.0, 0.0]
                pipe_kwargs["control_guidance_end"] = [1.0, 1.0]
                # Pass zero embeddings
                zero_embeddings = np.zeros(512, dtype=np.float32)
                pipe_kwargs["image_embeds"] = zero_embeddings
                print(f" [CONTROLNET] Depth scale: {depth_control_scale}")
                print(" [INFO] Generating without face preservation (zero embeddings)")
                print("═" * 60)
            # ═══════════════════════════════════════════════════════════
            # GENERATION
            # ═══════════════════════════════════════════════════════════
            print(f"\nGenerating: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
            result = self.pipe(**pipe_kwargs)
            generated_image = result.images[0]
            # ═══════════════════════════════════════════════════════════
            # POST-PROCESSING
            # ═══════════════════════════════════════════════════════════
            if enable_color_matching and has_detected_faces:
                print("Applying enhanced face-aware color matching...")
                try:
                    if face_bbox_original is not None:
                        generated_image = enhanced_color_match(
                            generated_image,
                            resized_image,
                            face_bbox=face_bbox_original
                        )
                        print("[OK] Enhanced color matching applied")
                    else:
                        generated_image = color_match(generated_image, resized_image, mode='mkl')
                        print("[OK] Standard color matching applied")
                except Exception as e:
                    # Best-effort: color matching failure never fails generation
                    print(f"[COLOR] Matching failed: {e}")
            elif enable_color_matching:
                print("Applying standard color matching...")
                try:
                    generated_image = color_match(generated_image, resized_image, mode='mkl')
                    print("[OK] Standard color matching applied")
                except Exception as e:
                    print(f"[COLOR] Matching failed: {e}")
            return generated_image
        finally:
            # Memory cleanup runs on both success and failure paths
            self.memory_manager.cleanup_memory(aggressive=True)
            # Final memory status
            if self.memory_manager.verbose:
                print("[MEMORY] Final status after generation:")
                self.memory_manager.print_memory_status()
# Module-import side effect: confirm the generator class was defined.
print("[OK] Generator class ready with cleaned code")