# Source: pixagram-neo-backup / generator.py
# Last commit: "Update generator.py" by primerz (3480bce, verified)
"""
Generation logic for the Pixagram AI pixel-art generator.

Fixed version: follows the exampleapp.py pattern more closely.
"""
import torch
import numpy as np
import cv2
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from config import (
device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS,
ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER,
MODEL_REPO, MODEL_FILES
)
from utils import (
sanitize_text, enhanced_color_match, color_match, create_face_mask,
draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop
)
from models import (
load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
)
class RetroArtConverter:
    """Retro-art generator built on an SDXL ControlNet pipeline.

    ``__init__`` loads every model once: face analysis, Zoe depth, ControlNets,
    the optional InstantID image encoder / IP-Adapter, LoRA, Compel, the
    scheduler and a caption model. ``generate`` turns an input image into a
    stylised image, and ``generate_caption`` describes an image in text.
    """

    def __init__(self):
        """Load all models and configure the generation pipeline.

        Each loader from ``models`` reports success. The flags are stored in
        ``self.models_loaded`` and printed once initialization finishes.
        """
        self.device = device
        self.dtype = dtype
        # Per-component success flags; 'lora_path' holds the downloaded LoRA file path.
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'lora_path': None,
            'instantid': False,
            'zoe_depth': False,
            'ip_adapter': False,
        }

        # Face analysis (keypoints / embeddings for InstantID in generate()).
        self.face_app, self.face_detection_enabled = load_face_analysis()

        # Zoe depth detector.
        self.zoe_depth, zoe_success = load_depth_detector()
        self.models_loaded['zoe_depth'] = zoe_success

        # ControlNets: depth is always used; InstantID only if it loaded.
        controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
        self.controlnet_depth = controlnet_depth
        self.instantid_enabled = instantid_success
        self.models_loaded['instantid'] = instantid_success

        # The image encoder is only needed on the InstantID path.
        self.image_encoder = load_image_encoder() if self.instantid_enabled else None

        # Order matters: generate() supplies control images as [keypoints, depth].
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, controlnet_depth]
            print("Initializing with multiple ControlNets: InstantID + Depth")
        else:
            controlnets = controlnet_depth
            print("Initializing with single ControlNet: Depth only")

        # SDXL pipeline.
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success

        # LoRA; also remember its file path so it can be reloaded later.
        lora_success = load_lora(self.pipe)
        self.models_loaded['lora'] = lora_success
        if lora_success:
            from huggingface_hub import hf_hub_download
            try:
                self.models_loaded['lora_path'] = hf_hub_download(MODEL_REPO, MODEL_FILES['lora'])
            except Exception as e:
                print(f"[WARNING] Could not resolve LORA path: {e}")
                self.models_loaded['lora_path'] = None

        # IP-Adapter is set up through the pipeline's own method, and only
        # when the InstantID path and its image encoder are available.
        if self.instantid_enabled and self.image_encoder is not None:
            self.models_loaded['ip_adapter'] = setup_ip_adapter(self.pipe)
        else:
            print("[INFO] Face preservation: InstantID ControlNet keypoints only")
            self.models_loaded['ip_adapter'] = False

        # Compel prompt encoder; use_compel reports whether it is usable.
        self.compel, self.use_compel = setup_compel(self.pipe)
        setup_scheduler(self.pipe)  # LCM scheduler
        optimize_pipeline(self.pipe)

        # Caption model.
        (self.caption_processor, self.caption_model,
         self.caption_enabled, self.caption_model_type) = load_caption_model()
        if self.caption_enabled and self.caption_model is not None:
            if self.caption_model_type == "git":
                print(" [OK] Using GIT for detailed captions")
            elif self.caption_model_type == "blip":
                print(" [OK] Using BLIP for standard captions")
            else:
                print(" [OK] Caption model loaded")

        set_clip_skip(self.pipe)

        # generate() uses this to decide how many control images to pass.
        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")

        self._print_status()
        print(" [OK] Model initialization complete!")

    def _print_status(self):
        """Print the load status of every model and of the IP-Adapter."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            if model == 'lora_path':
                continue  # a path, not a status flag
            status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
            print(f"{model}: {status}")
        print("===================\n")
        print("=== IP-ADAPTER STATUS ===")
        if self.models_loaded.get('ip_adapter', False):
            if hasattr(self.pipe, 'image_proj_model'):
                print("[OK] IP-Adapter fully loaded via pipeline method")
                print(" - Resampler: Available at pipe.image_proj_model")
                print(" - Scale control: Available via pipe.set_ip_adapter_scale()")
                print(" - Expected improvement: High face similarity")
            else:
                print("[WARNING] IP-Adapter loaded but Resampler not accessible")
        else:
            print("[INFO] IP-Adapter not active (using keypoints only)")
        print("=========================\n")

    def get_depth_map(self, image):
        """Return a Zoe depth map for ``image``, resized back to ``image.size``.

        Returns ``None`` when the detector is unavailable or raises; the error
        is printed.
        """
        if self.zoe_depth is None:
            return None
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            width, height = (int(side) for side in image.size)
            # Round each side down to a multiple of 64, never below 64.
            size_for_depth = (max(64, (width // 64) * 64), max(64, (height // 64) * 64))
            image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
            depth_map = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=512)
            # Assumes the detector returns a PIL image (it exposes .size/.resize) - TODO confirm.
            if depth_map.size != image.size:
                depth_map = depth_map.resize(image.size, Image.LANCZOS)
            return depth_map
        except Exception as e:
            print(f"Depth generation failed: {e}")
            return None

    @staticmethod
    def _apply_consistency_adjustments(guidance_scale, strength, lora_scale,
                                       identity_control_scale, depth_control_scale,
                                       identity_preservation):
        """Clamp generation parameters into the ranges used by consistency mode.

        Returns:
            ``(guidance_scale, strength, lora_scale, identity_control_scale,
            depth_control_scale)`` after adjustment. Each change is printed.
        """
        print("[CONSISTENCY] Validating and adjusting parameters...")
        # Guidance scale is kept in [0.5, 2.0] for LCM.
        if guidance_scale > 2.0:
            print(f" [ADJUST] CFG too high ({guidance_scale:.2f}), capping at 2.0")
            guidance_scale = 2.0
        elif guidance_scale < 0.5:
            print(f" [ADJUST] CFG too low ({guidance_scale:.2f}), raising to 0.5")
            guidance_scale = 0.5
        # Strong identity and strong LoRA pull against each other; soften the LoRA.
        if identity_preservation > 1.5 and lora_scale > 1.5:
            print(" [ADJUST] High identity + high LORA conflict detected")
            print(f" Reducing LORA scale: {lora_scale:.2f} -> {lora_scale * 0.8:.2f}")
            lora_scale = lora_scale * 0.8
        if depth_control_scale > 1.2:
            print(f" [ADJUST] Depth scale too high ({depth_control_scale:.2f}), capping at 1.2")
            depth_control_scale = 1.2
        if identity_control_scale > 1.5:
            print(f" [ADJUST] Identity control too high ({identity_control_scale:.2f}), capping at 1.5")
            identity_control_scale = 1.5
        # img2img strength is kept in [0.3, 0.9].
        if strength < 0.3:
            print(f" [ADJUST] Strength too low ({strength:.2f}), raising to 0.3")
            strength = 0.3
        elif strength > 0.9:
            print(f" [ADJUST] Strength too high ({strength:.2f}), capping at 0.9")
            strength = 0.9
        print("[CONSISTENCY] Parameter validation complete\n")
        return guidance_scale, strength, lora_scale, identity_control_scale, depth_control_scale

    @staticmethod
    def _pick_largest_face(faces):
        """Return the face with the largest bounding-box area (first one on ties).

        When several faces are found, the most prominent one drives the
        identity conditioning. Faces without a ``bbox`` count as area 0.
        """
        def bbox_area(face):
            bbox = getattr(face, 'bbox', None)
            if bbox is None:
                return 0.0
            return float((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))

        return max(faces, key=bbox_area)

    @staticmethod
    def _extract_face_embedding(face):
        """Return ``(unit_embedding, source_label)`` for ``face``, or ``(None, None)``.

        Sources are tried in order: ``normed_embedding`` (used as-is), then
        the ``embedding`` attribute, then a plain dict's ``'embedding'`` key.
        The last two are L2-normalized. A zero-norm vector counts as unavailable.
        """
        normed = getattr(face, 'normed_embedding', None)
        if normed is not None:
            return normed, "normed_embedding"

        raw = getattr(face, 'embedding', None)
        source = "embedding, normalized"
        if raw is None and isinstance(face, dict):
            raw = face.get('embedding')
            source = "dict['embedding'], normalized"
        if raw is None:
            return None, None

        raw = np.asarray(raw)
        norm = np.linalg.norm(raw)
        if norm == 0:
            return None, None
        return raw / norm, source

    def _analyze_face(self, image):
        """Detect a face in ``image`` and gather what InstantID needs.

        Returns:
            ``(has_detected_faces, face_kps_image, face_embeddings, face_bbox)``.
            Values stay ``None`` when unavailable. Detection errors are printed
            and treated as best effort: results gathered before the error are kept.
        """
        has_detected_faces = False
        face_kps_image = None
        face_embeddings = None
        face_bbox = None
        if self.face_app is None:
            return has_detected_faces, face_kps_image, face_embeddings, face_bbox

        print("Detecting faces...")
        try:
            # face_app.get receives a BGR array converted from the RGB image.
            image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            faces = self.face_app.get(image_np)
            if len(faces) == 0:
                print(" [INFO] No faces detected")
            else:
                has_detected_faces = True
                face = self._pick_largest_face(faces)
                print(f" [OK] Face detected (score: {face.det_score:.3f})")

                face_kps_image = draw_kps(image, face.kps)

                face_embeddings, source = self._extract_face_embedding(face)
                if face_embeddings is not None:
                    print(f" Face embedding extracted ({source}): shape {face_embeddings.shape}")
                else:
                    print(" [WARNING] Face detected but embeddings not available")

                # Used later for face-aware color matching.
                face_bbox = getattr(face, 'bbox', None)

                if hasattr(face, 'age') and hasattr(face, 'gender'):
                    age = face.age
                    gender_code = face.gender
                    gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
                    print(f" Face info: age={age if age else 'N/A'}, gender={gender_str}, quality={face.det_score:.3f}")
        except Exception as e:
            print(f" [WARNING] Face detection failed: {e}")
        return has_detected_faces, face_kps_image, face_embeddings, face_bbox

    def _set_lora_scale(self, lora_scale):
        """Set the weight of the "retroart" LoRA adapter if one is loaded (best effort)."""
        if not (hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']):
            return
        try:
            self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
            print(f"LORA scale: {lora_scale}")
        except Exception as e:
            print(f"Could not set LORA scale: {e}")

    def _make_generator(self, seed):
        """Create a ``torch.Generator`` on ``self.device`` from ``seed``.

        ``-1`` or ``None`` draws a fresh random seed. Any other value goes
        through ``int()``, so floats such as ``42.0`` are accepted.
        """
        generator = torch.Generator(device=self.device)
        if seed is None or seed == -1:
            actual_seed = generator.seed()
            print(f"[SEED] Using random seed: {actual_seed}")
        else:
            actual_seed = int(seed)
            generator.manual_seed(actual_seed)
            print(f"[SEED] Using fixed seed: {actual_seed}")
        return generator

    def _build_prompt_kwargs(self, prompt, negative_prompt):
        """Return the prompt-related keyword arguments for the pipeline call.

        ``TRIGGER_WORD`` (the LoRA trigger word) is prepended once if missing.
        With Compel available, the prompts are pre-encoded into embeddings.
        Otherwise the raw text is passed so the pipeline encodes it itself.
        """
        if TRIGGER_WORD not in prompt:
            # Prepended for the strongest effect on the LoRA.
            prompt = f"{TRIGGER_WORD}, {prompt}"

        if self.use_compel and self.compel is not None:
            print("Encoding prompts with Compel...")
            print(f" Using final prompt: {prompt}")
            conditioning, pooled = self.compel(prompt)
            negative_conditioning, negative_pooled = self.compel(negative_prompt)
            print(" [OK] Prompts encoded")
            return {
                "prompt_embeds": conditioning,
                "pooled_prompt_embeds": pooled,
                "negative_prompt_embeds": negative_conditioning,
                "negative_pooled_prompt_embeds": negative_pooled,
            }

        # Compel unavailable: hand the text to the pipeline so it is still conditioned on it.
        print(f" Using final prompt: {prompt}")
        return {"prompt": prompt, "negative_prompt": negative_prompt}

    def _build_control_inputs(self, depth_image, face_kps_image, face_embeddings,
                              identity_control_scale, depth_control_scale,
                              identity_preservation):
        """Choose ControlNet images and scales for the loaded ControlNet layout.

        Returns:
            ``(control_image, conditioning_scales, face_embeddings)``. The
            embeddings are ``None`` whenever the InstantID branch is not used.
        """
        if not self.using_multiple_controlnets:
            print("Using Depth ControlNet only")
            return depth_image, depth_control_scale, None

        if face_kps_image is not None:
            print("Using InstantID (keypoints + embeddings) + Depth ControlNets")
            if face_embeddings is not None:
                adjusted_scale = 0.8 * identity_preservation
                if hasattr(self.pipe, 'set_ip_adapter_scale'):
                    self.pipe.set_ip_adapter_scale(adjusted_scale)
                    print(f" IP-Adapter scale: {adjusted_scale:.2f}")
                else:
                    print(" [INFO] Pipeline has no set_ip_adapter_scale(); IP-Adapter scale unchanged")
                print(f" Face embeddings shape: {face_embeddings.shape}")
                print(" [OK] Face embeddings ready for InstantID pipeline")
            else:
                print(" [INFO] No face embeddings, passing None to pipeline")
            return ([face_kps_image, depth_image],
                    [identity_control_scale, depth_control_scale],
                    face_embeddings)

        # The InstantID ControlNet (index 0) still needs an image input, so it
        # gets a black placeholder at scale 0.0. This also covers a detected
        # face whose keypoints could not be drawn: a bare depth image would not
        # match the two-ControlNet layout.
        print("Multiple ControlNets available but no face keypoints, using depth only")
        blank_image = Image.new("RGB", depth_image.size, (0, 0, 0))
        return [blank_image, depth_image], [0.0, depth_control_scale], None

    @staticmethod
    def _apply_color_matching(generated_image, reference_image, has_detected_faces, face_bbox):
        """Color-match ``generated_image`` to ``reference_image`` (best effort).

        Uses face-aware ``enhanced_color_match`` when a face bbox is known,
        otherwise ``color_match`` in 'mkl' mode. On failure a warning is
        printed and ``generated_image`` is returned unchanged.
        """
        if has_detected_faces:
            print("\nApplying enhanced face-aware color matching...")
        else:
            print("\nApplying standard color matching...")
        try:
            if has_detected_faces and face_bbox is not None:
                matched = enhanced_color_match(generated_image, reference_image, face_bbox=face_bbox)
                print(" [OK] Enhanced color matching applied (face-aware)")
            else:
                matched = color_match(generated_image, reference_image, mode='mkl')
                print(" [OK] Standard color matching applied")
        except Exception as e:
            print(f" [WARNING] Color matching failed: {e}")
            return generated_image
        return matched

    def generate(
        self,
        image,
        prompt="a person",
        negative_prompt="",
        num_inference_steps=12,
        guidance_scale=0.0,
        strength=0.75,
        lora_scale=1.0,
        identity_control_scale=0.8,
        depth_control_scale=0.8,
        identity_preservation=1.0,
        enable_color_matching=True,
        consistency_mode=True,
        seed=-1,
    ):
        """Generate retro art from ``image`` with InstantID face preservation.

        Args:
            image: Input PIL image; converted to RGB if needed.
            prompt: Text prompt; ``TRIGGER_WORD`` is prepended when missing.
            negative_prompt: Negative text prompt.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance scale passed to the pipeline.
            strength: img2img strength passed to the pipeline.
            lora_scale: Weight of the "retroart" LoRA adapter.
            identity_control_scale: ControlNet scale for the InstantID keypoints.
            depth_control_scale: ControlNet scale for the depth map.
            identity_preservation: IP-Adapter scale multiplier (scale = 0.8 * value).
            enable_color_matching: Color-match the output to the resized input.
            consistency_mode: Clamp the parameters above into safer ranges first.
            seed: Integer seed; ``-1`` (or ``None``) picks a random seed.

        Returns:
            The first entry of the pipeline's ``images`` output, optionally
            color-matched.

        Raises:
            RuntimeError: If the depth map cannot be produced.
            Exception: Any pipeline error is printed with a traceback and re-raised.
        """
        print(f"\n{'='*60}")
        print("Starting generation with:")
        print(f" Prompt: {prompt}")
        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
        print(f" Identity scale: {identity_control_scale}, Depth scale: {depth_control_scale}")
        print(f" Face preservation: {identity_preservation}")
        print(f" Consistency mode: {'ON' if consistency_mode else 'OFF'}")
        print(f"{'='*60}\n")

        if consistency_mode:
            (guidance_scale, strength, lora_scale,
             identity_control_scale, depth_control_scale) = self._apply_consistency_adjustments(
                guidance_scale, strength, lora_scale,
                identity_control_scale, depth_control_scale, identity_preservation,
            )

        # Prepare the input image at the size chosen by calculate_optimal_size.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        optimal_width, optimal_height = calculate_optimal_size(image.size[0], image.size[1])
        resized_image = image.resize((optimal_width, optimal_height), Image.LANCZOS)
        print(f"Image resized: {image.size} -> {resized_image.size}")

        print("Generating depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image is None:
            raise RuntimeError("Could not generate depth map")

        has_detected_faces, face_kps_image, face_embeddings, face_bbox = self._analyze_face(resized_image)

        self._set_lora_scale(lora_scale)
        generator = self._make_generator(seed)
        prompt_kwargs = self._build_prompt_kwargs(prompt, negative_prompt)

        # CLIP skip of 2 whenever the pipeline exposes a text encoder.
        clip_skip = 2 if hasattr(self.pipe, 'text_encoder') else None

        control_image, conditioning_scales, face_embeddings = self._build_control_inputs(
            depth_image, face_kps_image, face_embeddings,
            identity_control_scale, depth_control_scale, identity_preservation,
        )

        print("\nGenerating with LCM:")
        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
        print(f" ControlNet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}")
        try:
            generated_image = self.pipe(
                **prompt_kwargs,
                width=optimal_width,
                height=optimal_height,
                image_embeds=face_embeddings,
                image=resized_image,
                strength=strength,
                control_image=control_image,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                clip_skip=clip_skip,
                generator=generator,
                controlnet_conditioning_scale=conditioning_scales,
            ).images[0]
        except Exception as e:
            print(f"[ERROR] Generation failed: {e}")
            import traceback
            traceback.print_exc()
            raise

        if enable_color_matching:
            generated_image = self._apply_color_matching(
                generated_image, resized_image, has_detected_faces, face_bbox
            )

        print(f"\n{'='*60}")
        print("Generation complete!")
        print(f"{'='*60}\n")
        return generated_image

    def generate_caption(self, image):
        """Describe ``image`` in text with the loaded caption model.

        Args:
            image: A PIL image, or an array accepted by ``Image.fromarray``.

        Returns:
            The decoded caption string. Returns ``None`` when captioning is
            disabled, the model type is neither 'git' nor 'blip', or
            generation raises (the error is printed).
        """
        if not self.caption_enabled or self.caption_model is None:
            return None
        try:
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            print("Generating caption...")
            with torch.no_grad():
                if self.caption_model_type == 'git':
                    inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
                    generated_ids = self.caption_model.generate(
                        pixel_values=inputs.pixel_values,
                        max_length=50,
                    )
                    caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                elif self.caption_model_type == 'blip':
                    inputs = self.caption_processor(image, return_tensors="pt").to(self.device)
                    generated_ids = self.caption_model.generate(**inputs, max_length=50)
                    caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
                else:
                    return None
            print(f" [OK] Caption: {caption}")
            return caption
        except Exception as e:
            print(f" [WARNING] Caption generation failed: {e}")
            return None
# Import-time banner: announces that this module finished loading.
_READY_BANNER = "[OK] Generator class ready (FIXED VERSION - exampleapp.py style)"
print(_READY_BANNER)