""" Generation logic for Pixagram AI Pixel Art Generator FIXED VERSION - Following exampleapp.py pattern more closely """ import torch import numpy as np import cv2 from PIL import Image import torch.nn.functional as F from torchvision import transforms from config import ( device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS, ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER, MODEL_REPO, MODEL_FILES ) from utils import ( sanitize_text, enhanced_color_match, color_match, create_face_mask, draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop ) from models import ( load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder, load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel, setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip ) class RetroArtConverter: """Main class for retro art generation - FIXED VERSION""" def __init__(self): self.device = device self.dtype = dtype self.models_loaded = { 'custom_checkpoint': False, 'lora': False, 'lora_path': None, 'instantid': False, 'zoe_depth': False, 'ip_adapter': False } # Initialize face analysis self.face_app, self.face_detection_enabled = load_face_analysis() # Load Zoe Depth detector self.zoe_depth, zoe_success = load_depth_detector() self.models_loaded['zoe_depth'] = zoe_success # Load ControlNets controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets() self.controlnet_depth = controlnet_depth self.instantid_enabled = instantid_success self.models_loaded['instantid'] = instantid_success # Load image encoder if self.instantid_enabled: self.image_encoder = load_image_encoder() else: self.image_encoder = None # Determine which controlnets to use if self.instantid_enabled and self.controlnet_instantid is not None: controlnets = [self.controlnet_instantid, controlnet_depth] print(f"Initializing with multiple ControlNets: InstantID + Depth") else: controlnets = controlnet_depth print(f"Initializing with single ControlNet: Depth only") # Load SDXL pipeline self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets) self.models_loaded['custom_checkpoint'] = checkpoint_success # Load LORA and store path lora_success = load_lora(self.pipe) self.models_loaded['lora'] = lora_success if lora_success: # Store LORA path for later reloading from huggingface_hub import hf_hub_download try: lora_path = hf_hub_download(MODEL_REPO, MODEL_FILES['lora']) self.models_loaded['lora_path'] = lora_path except: self.models_loaded['lora_path'] = None # Setup IP-Adapter using pipeline's built-in method if self.instantid_enabled and self.image_encoder is not None: ip_adapter_success = setup_ip_adapter(self.pipe) self.models_loaded['ip_adapter'] = ip_adapter_success else: print("[INFO] Face preservation: InstantID ControlNet keypoints only") self.models_loaded['ip_adapter'] = False # Setup Compel self.compel, self.use_compel = setup_compel(self.pipe) # Setup LCM scheduler setup_scheduler(self.pipe) # Optimize pipeline optimize_pipeline(self.pipe) # Load caption model self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model() # Report caption model status if self.caption_enabled and self.caption_model is not None: if self.caption_model_type == "git": print(" [OK] Using GIT for detailed captions") elif self.caption_model_type == "blip": print(" [OK] Using BLIP for standard captions") else: print(" [OK] Caption model loaded") # Set CLIP skip set_clip_skip(self.pipe) # Track controlnet configuration self.using_multiple_controlnets = isinstance(controlnets, list) print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)") # Print model status self._print_status() print(" [OK] Model initialization complete!") def _print_status(self): """Print model loading status""" print("\n=== MODEL STATUS ===") for model, loaded in self.models_loaded.items(): if model == 'lora_path': continue status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]" print(f"{model}: {status}") print("===================\n") print("=== IP-ADAPTER STATUS ===") if self.models_loaded.get('ip_adapter', False): if hasattr(self.pipe, 'image_proj_model'): print("[OK] IP-Adapter fully loaded via pipeline method") print(" - Resampler: Available at pipe.image_proj_model") print(" - Scale control: Available via pipe.set_ip_adapter_scale()") print(" - Expected improvement: High face similarity") else: print("[WARNING] IP-Adapter loaded but Resampler not accessible") else: print("[INFO] IP-Adapter not active (using keypoints only)") print("=========================\n") def get_depth_map(self, image): """Generate depth map using Zoe Depth""" if self.zoe_depth is not None: try: if image.mode != 'RGB': image = image.convert('RGB') orig_width, orig_height = image.size orig_width = int(orig_width) orig_height = int(orig_height) # Use multiples of 64 target_width = int((orig_width // 64) * 64) target_height = int((orig_height // 64) * 64) target_width = int(max(64, target_width)) target_height = int(max(64, target_height)) size_for_depth = (int(target_width), int(target_height)) image_for_depth = image.resize(size_for_depth, Image.LANCZOS) depth_map = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=512) if depth_map.size != image.size: depth_map = depth_map.resize(image.size, Image.LANCZOS) return depth_map except Exception as e: print(f"Depth generation failed: {e}") return None return None def generate( self, image, prompt="a person", negative_prompt="", num_inference_steps=12, guidance_scale=0.0, strength=0.75, lora_scale=1.0, identity_control_scale=0.8, depth_control_scale=0.8, identity_preservation=1.0, enable_color_matching=True, consistency_mode=True, seed=-1 ): """ Generate retro art with InstantID face preservation. FIXED: Following exampleapp.py pattern more closely. """ print(f"\n{'='*60}") print(f"Starting generation with:") print(f" Prompt: {prompt}") print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}") print(f" Identity scale: {identity_control_scale}, Depth scale: {depth_control_scale}") print(f" Face preservation: {identity_preservation}") print(f" Consistency mode: {'ON' if consistency_mode else 'OFF'}") print(f"{'='*60}\n") # Apply consistency mode adjustments if consistency_mode: print("[CONSISTENCY] Validating and adjusting parameters...") # Validate guidance scale for LCM if guidance_scale > 2.0: print(f" [ADJUST] CFG too high ({guidance_scale:.2f}), capping at 2.0") guidance_scale = 2.0 elif guidance_scale < 0.5: print(f" [ADJUST] CFG too low ({guidance_scale:.2f}), raising to 0.5") guidance_scale = 0.5 # Balance identity preservation and LORA scale if identity_preservation > 1.5 and lora_scale > 1.5: print(f" [ADJUST] High identity + high LORA conflict detected") print(f" Reducing LORA scale: {lora_scale:.2f} → {lora_scale * 0.8:.2f}") lora_scale = lora_scale * 0.8 # Ensure ControlNet scales are reasonable if depth_control_scale > 1.2: print(f" [ADJUST] Depth scale too high ({depth_control_scale:.2f}), capping at 1.2") depth_control_scale = 1.2 if identity_control_scale > 1.5: print(f" [ADJUST] Identity control too high ({identity_control_scale:.2f}), capping at 1.5") identity_control_scale = 1.5 # Validate strength range if strength < 0.3: print(f" [ADJUST] Strength too low ({strength:.2f}), raising to 0.3") strength = 0.3 elif strength > 0.9: print(f" [ADJUST] Strength too high ({strength:.2f}), capping at 0.9") strength = 0.9 print("[CONSISTENCY] Parameter validation complete\n") # Prepare input image if image.mode != 'RGB': image = image.convert('RGB') optimal_width, optimal_height = calculate_optimal_size(image.size[0], image.size[1]) resized_image = image.resize((optimal_width, optimal_height), Image.LANCZOS) print(f"Image resized: {image.size} → {resized_image.size}") # Generate depth map print("Generating depth map...") depth_image = self.get_depth_map(resized_image) if depth_image is None: raise RuntimeError("Could not generate depth map") # Face detection and processing has_detected_faces = False face_kps_image = None face_embeddings = None face_crop = None face_crop_enhanced = None face_bbox_original = None if self.face_app is not None: print("Detecting faces...") try: image_np = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR) faces = self.face_app.get(image_np) if len(faces) > 0: has_detected_faces = True face = faces[0] print(f" [OK] Face detected (score: {face.det_score:.3f})") # Get face keypoints image face_kps_image = draw_kps(resized_image, face.kps) # Get face embeddings (512D from InsightFace) if hasattr(face, 'normed_embedding') and face.normed_embedding is not None: face_embeddings = face.normed_embedding print(f" Face embedding extracted (normed_embedding): shape {face_embeddings.shape}") elif hasattr(face, 'embedding') and face.embedding is not None: face_embeddings = face.embedding / np.linalg.norm(face.embedding) print(f" Face embedding extracted (embedding, normalized): shape {face_embeddings.shape}") elif isinstance(face, dict) and 'embedding' in face: face_embeddings = face['embedding'] print(f" Face embedding extracted (dict['embedding']): shape {face_embeddings.shape}") else: face_embeddings = None print(f" [WARNING] Face detected but embeddings not available") # Store face bbox for color matching if hasattr(face, 'bbox'): face_bbox_original = face.bbox # Get face crop for enhanced processing bbox = face.bbox.astype(int) face_crop = resized_image.crop((bbox[0], bbox[1], bbox[2], bbox[3])) face_crop_enhanced = enhance_face_crop(face_crop) # Debug info if hasattr(face, 'age') and hasattr(face, 'gender'): age = face.age gender_code = face.gender det_score = face.det_score gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A') print(f" Face info: age={age if age else 'N/A'}, gender={gender_str}, quality={det_score:.3f}") else: print(" [INFO] No faces detected") except Exception as e: print(f" [WARNING] Face detection failed: {e}") # Unfuse and reload LORA with new scale (like exampleapp.py) #if hasattr(self.pipe, 'unfuse_lora'): # try: # self.pipe.unfuse_lora() # self.pipe.unload_lora_weights() # print(" [OK] Unfused previous LORA") # except Exception as e: # print(f" [INFO] No previous LORA to unfuse: {e}") # Load and fuse LORA at the requested scale #if self.models_loaded['lora'] and self.models_loaded.get('lora_path'): # try: # self.pipe.load_lora_weights(self.models_loaded['lora_path']) # self.pipe.fuse_lora(lora_scale=lora_scale) # print(f" [OK] LORA fused at scale: {lora_scale}") # except Exception as e: # print(f" [WARNING] Could not fuse LORA: {e}") # --- CORRECTED BLOCK --- # Set LORA scale using set_adapters (matches models.py) if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']: try: self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale]) print(f"LORA scale: {lora_scale}") except Exception as e: print(f"Could not set LORA scale: {e}") # --- END OF BLOCK --- # Setup generator with seed control if seed == -1: generator = torch.Generator(device=self.device) actual_seed = generator.seed() print(f"[SEED] Using random seed: {actual_seed}") else: generator = torch.Generator(device=self.device).manual_seed(seed) actual_seed = seed print(f"[SEED] Using fixed seed: {actual_seed}") # Use Compel for prompt encoding (like exampleapp.py - simpler) if self.use_compel and self.compel is not None: print("Encoding prompts with Compel...") # --- FIX: Add the LORA trigger word --- # Ensure trigger word is present and avoid duplicates if TRIGGER_WORD not in prompt: # Prepend the trigger word for highest impact prompt = f"{TRIGGER_WORD}, {prompt}" print(f" Using final prompt: {prompt}") # --- End Fix --- conditioning, pooled = self.compel(prompt) negative_conditioning, negative_pooled = self.compel(negative_prompt) print(" [OK] Prompts encoded") else: # Fallback to standard prompts conditioning = None pooled = None negative_conditioning = None negative_pooled = None # Set CLIP skip clip_skip = 2 if hasattr(self.pipe, 'text_encoder') else None # Configure ControlNet inputs using_multiple_controlnets = self.using_multiple_controlnets if using_multiple_controlnets and has_detected_faces and face_kps_image is not None: print("Using InstantID (keypoints + embeddings) + Depth ControlNets") control_image = [face_kps_image, depth_image] conditioning_scales = [identity_control_scale, depth_control_scale] # Set IP-Adapter scale if embeddings available if face_embeddings is not None: adjusted_scale = 0.8 * identity_preservation self.pipe.set_ip_adapter_scale(adjusted_scale) print(f" IP-Adapter scale: {adjusted_scale:.2f}") print(f" Face embeddings shape: {face_embeddings.shape}") print(" [OK] Face embeddings ready for InstantID pipeline") else: # No embeddings, pass None face_embeddings = None print(" [INFO] No face embeddings, passing None to pipeline") elif using_multiple_controlnets and not has_detected_faces: print("Multiple ControlNets available but no faces detected, using depth only") # The InstantID controlnet (index 0) still needs an image input. # We provide a blank image and set its scale to 0.0 to disable it. blank_image = Image.new("RGB", depth_image.size, (0, 0, 0)) control_image = [blank_image, depth_image] conditioning_scales = [0.0, depth_control_scale] face_embeddings = None else: print("Using Depth ControlNet only") control_image = depth_image conditioning_scales = depth_control_scale face_embeddings = None # Generate (like exampleapp.py - direct call) print(f"\nGenerating with LCM:") print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}") print(f" ControlNet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}") try: generated_image = self.pipe( prompt_embeds=conditioning, pooled_prompt_embeds=pooled, negative_prompt_embeds=negative_conditioning, negative_pooled_prompt_embeds=negative_pooled, width=optimal_width, height=optimal_height, image_embeds=face_embeddings, image=resized_image, strength=strength, control_image=control_image, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, clip_skip=clip_skip, generator=generator, controlnet_conditioning_scale=conditioning_scales ).images[0] except Exception as e: print(f"[ERROR] Generation failed: {e}") import traceback traceback.print_exc() raise # Post-processing if enable_color_matching and has_detected_faces: print("\nApplying enhanced face-aware color matching...") try: if face_bbox_original is not None: generated_image = enhanced_color_match( generated_image, resized_image, face_bbox=face_bbox_original ) print(" [OK] Enhanced color matching applied (face-aware)") else: generated_image = color_match(generated_image, resized_image, mode='mkl') print(" [OK] Standard color matching applied") except Exception as e: print(f" [WARNING] Color matching failed: {e}") elif enable_color_matching: print("\nApplying standard color matching...") try: generated_image = color_match(generated_image, resized_image, mode='mkl') print(" [OK] Standard color matching applied") except Exception as e: print(f" [WARNING] Color matching failed: {e}") print(f"\n{'='*60}") print("Generation complete!") print(f"{'='*60}\n") return generated_image def generate_caption(self, image): """ Generate a caption for an image. Returns None if caption generation is disabled. """ if not self.caption_enabled or self.caption_model is None: return None try: # Ensure image is PIL Image if not isinstance(image, Image.Image): image = Image.fromarray(image) # Convert to RGB if needed if image.mode != 'RGB': image = image.convert('RGB') print("Generating caption...") with torch.no_grad(): if self.caption_model_type == 'git': # GIT model inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device) generated_ids = self.caption_model.generate( pixel_values=inputs.pixel_values, max_length=50 ) caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] elif self.caption_model_type == 'blip': # BLIP model inputs = self.caption_processor(image, return_tensors="pt").to(self.device) generated_ids = self.caption_model.generate(**inputs, max_length=50) caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True) else: return None print(f" [OK] Caption: {caption}") return caption except Exception as e: print(f" [WARNING] Caption generation failed: {e}") return None print("[OK] Generator class ready (FIXED VERSION - exampleapp.py style)")