# NOTE: captured from a Hugging Face Space; the Space's status page reported
# "Runtime error" at capture time.
| """ | |
| Generation logic for Pixagram AI Pixel Art Generator | |
| FIXED VERSION - Proper embedding integration following exampleapp.py pattern | |
| """ | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from config import ( | |
| device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS, | |
| ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER | |
| ) | |
| from utils import ( | |
| sanitize_text, enhanced_color_match, color_match, create_face_mask, | |
| draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop | |
| ) | |
| from models import ( # Use the hybrid version (supports both loading methods) | |
| load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder, | |
| load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel, | |
| setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip | |
| ) | |
class RetroArtConverter:
    """Main class for retro art generation - FIXED VERSION.

    Orchestrates the components built by the loaders in ``models.py``:

    * InsightFace face analysis (detection, keypoints, identity embedding)
    * Zoe depth estimation feeding the depth ControlNet
    * an SDXL img2img pipeline driven by InstantID + depth ControlNets
    * a "retroart" LORA, IP-Adapter face conditioning, Compel prompt
      weighting, an LCM scheduler and an optional caption model

    ``self.models_loaded`` records which components actually loaded so
    that :meth:`generate` can degrade gracefully when something failed.
    """

    def __init__(self):
        self.device = device
        self.dtype = dtype
        # Per-component load status; reported by _print_status() and
        # consulted by generate() to pick fallback paths.
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'zoe_depth': False,
            'ip_adapter': False
        }

        # Initialize face analysis (InsightFace); may come back disabled.
        self.face_app, self.face_detection_enabled = load_face_analysis()

        # Load Zoe Depth detector
        self.zoe_depth, zoe_success = load_depth_detector()
        self.models_loaded['zoe_depth'] = zoe_success

        # Load ControlNets (depth always; InstantID only when available).
        controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
        self.controlnet_depth = controlnet_depth
        self.instantid_enabled = instantid_success
        self.models_loaded['instantid'] = instantid_success

        # Image encoder is only needed when InstantID is active.
        if self.instantid_enabled:
            self.image_encoder = load_image_encoder()
        else:
            self.image_encoder = None

        # Determine which controlnets to use: a two-element list selects
        # the InstantID + depth configuration, a bare controlnet is
        # depth-only. generate() relies on this list/scalar distinction.
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, controlnet_depth]
            print("Initializing with multiple ControlNets: InstantID + Depth")
        else:
            controlnets = controlnet_depth
            print("Initializing with single ControlNet: Depth only")

        # CRITICAL FIX: Load SDXL pipeline with from_pretrained()
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success

        # Load LORA
        self.models_loaded['lora'] = load_lora(self.pipe)

        # CRITICAL FIX: Setup IP-Adapter using the pipeline's built-in
        # method. After load_ip_adapter_instantid() the pipeline itself
        # owns image_proj_model (the Resampler) and ip_adapter_scale, so
        # no manual management is needed here.
        if self.instantid_enabled and self.image_encoder is not None:
            self.models_loaded['ip_adapter'] = setup_ip_adapter(self.pipe)
        else:
            print("[INFO] Face preservation: InstantID ControlNet keypoints only")
            self.models_loaded['ip_adapter'] = False

        # Setup Compel (weighted-prompt encoding); optional.
        self.compel, self.use_compel = setup_compel(self.pipe)

        # Setup LCM scheduler and apply pipeline optimizations.
        setup_scheduler(self.pipe)
        optimize_pipeline(self.pipe)

        # Load caption model (GIT or BLIP) and report which one we got.
        (self.caption_processor, self.caption_model,
         self.caption_enabled, self.caption_model_type) = load_caption_model()
        if self.caption_enabled and self.caption_model is not None:
            if self.caption_model_type == "git":
                print(" [OK] Using GIT for detailed captions")
            elif self.caption_model_type == "blip":
                print(" [OK] Using BLIP for standard captions")
            else:
                print(" [OK] Caption model loaded")

        # Set CLIP skip
        set_clip_skip(self.pipe)

        # Track controlnet configuration for generate().
        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")

        # Print model status
        self._print_status()
        print(" [OK] Model initialization complete!")

    def _print_status(self):
        """Print model loading status and IP-Adapter availability."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
            print(f"{model}: {status}")
        print("===================\n")
        print("=== IP-ADAPTER STATUS ===")
        if self.models_loaded.get('ip_adapter', False):
            # The Resampler should be exposed by the pipeline after
            # load_ip_adapter_instantid(); warn when it is missing.
            if hasattr(self.pipe, 'image_proj_model'):
                print("[OK] IP-Adapter fully loaded via pipeline method")
                print(" - Resampler: Available at pipe.image_proj_model")
                print(" - Scale control: Available via pipe.set_ip_adapter_scale()")
                print(" - Expected improvement: High face similarity")
            else:
                print("[WARNING] IP-Adapter loaded but Resampler not accessible")
        else:
            print("[INFO] IP-Adapter not active (using keypoints only)")
        print("=========================\n")

    def get_depth_map(self, image):
        """Generate a depth map for *image* using Zoe Depth.

        The image is resized to multiples of 64 (minimum 64 per side)
        before estimation, and the resulting map is resized back to the
        input size so it can be used directly as a ControlNet input.

        Returns the depth image, or ``None`` when no detector is loaded
        or estimation fails (best-effort, never raises).
        """
        if self.zoe_depth is None:
            return None
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            orig_width, orig_height = image.size
            # Depth preprocessing expects dimensions that are multiples
            # of 64; clamp to at least 64 on each side.
            target_width = max(64, (int(orig_width) // 64) * 64)
            target_height = max(64, (int(orig_height) // 64) * 64)
            image_for_depth = image.resize((target_width, target_height), Image.LANCZOS)
            depth_map = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=512)
            # Match the caller's image size again.
            if depth_map.size != image.size:
                depth_map = depth_map.resize(image.size, Image.LANCZOS)
            return depth_map
        except Exception as e:
            print(f"Depth generation failed: {e}")
            return None

    def _detect_face(self, resized_image):
        """Detect the primary face in *resized_image*.

        Returns ``(has_face, kps_image, embedding, bbox)`` where
        ``kps_image`` is the keypoint drawing for the InstantID
        ControlNet, ``embedding`` is the (normalized) InsightFace
        identity vector and ``bbox`` is the face box used later for
        face-aware color matching. All-None/False on failure.
        """
        if self.face_app is None:
            return False, None, None, None
        print("Detecting faces...")
        try:
            # InsightFace expects BGR numpy input.
            image_np = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
            faces = self.face_app.get(image_np)
            if not faces:
                print(" [INFO] No faces detected")
                return False, None, None, None
            face = faces[0]
            print(f" [OK] Face detected (score: {face.det_score:.3f})")
            # Keypoints image drives the InstantID ControlNet.
            face_kps_image = draw_kps(resized_image, face.kps)
            # Identity embedding: prefer the pre-normalized attribute,
            # otherwise L2-normalize the raw embedding ourselves.
            face_embeddings = None
            if hasattr(face, 'normed_embedding'):
                face_embeddings = face.normed_embedding
                print(f" Face embedding shape: {face_embeddings.shape}")
            elif hasattr(face, 'embedding'):
                face_embeddings = face.embedding / np.linalg.norm(face.embedding)
                print(f" Face embedding shape: {face_embeddings.shape}")
            face_bbox_original = None
            if hasattr(face, 'bbox'):
                face_bbox_original = face.bbox
                # Enhanced face crop kept for parity with the original
                # flow; its result is currently unused downstream.
                bbox = face.bbox.astype(int)
                face_crop = resized_image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
                enhance_face_crop(face_crop)
            # Debug info
            if hasattr(face, 'age') and hasattr(face, 'gender'):
                gender_str = 'M' if face.gender == 1 else ('F' if face.gender == 0 else 'N/A')
                # FIX: explicit None check so a reported age of 0 is not
                # silently displayed as N/A.
                age_str = face.age if face.age is not None else 'N/A'
                print(f" Face info: age={age_str}, gender={gender_str}, quality={face.det_score:.3f}")
            return True, face_kps_image, face_embeddings, face_bbox_original
        except Exception as e:
            # Treat any mid-detection failure as "no usable face" so the
            # caller takes a consistent fallback path.
            print(f" [WARNING] Face detection failed: {e}")
            return False, None, None, None

    def _encode_prompts(self, prompt, negative_prompt, pipe_kwargs):
        """Fill *pipe_kwargs* with Compel-encoded embeddings when Compel
        is available, falling back to plain prompt strings on failure."""
        if self.use_compel and self.compel is not None:
            try:
                print("Encoding prompts with Compel...")
                conditioning = self.compel(prompt)
                negative_conditioning = self.compel(negative_prompt)
                pipe_kwargs["prompt_embeds"] = conditioning[0]
                pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
                pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
                pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
                print(" [OK] Using Compel-encoded prompts")
                return
            except Exception as e:
                print(f" Compel encoding failed, using standard prompts: {e}")
        pipe_kwargs["prompt"] = prompt
        pipe_kwargs["negative_prompt"] = negative_prompt

    def _apply_color_matching(self, generated_image, reference_image, face_bbox):
        """Color-match *generated_image* against *reference_image*.

        Uses the face-aware matcher when *face_bbox* is known, otherwise
        the standard MKL matcher. Best-effort: returns the input image
        unchanged when matching fails.
        """
        try:
            if face_bbox is not None:
                result = enhanced_color_match(
                    generated_image,
                    reference_image,
                    face_bbox=face_bbox
                )
                print(" [OK] Enhanced color matching applied (face-aware)")
            else:
                result = color_match(generated_image, reference_image, mode='mkl')
                print(" [OK] Standard color matching applied")
            return result
        except Exception as e:
            print(f" [WARNING] Color matching failed: {e}")
            return generated_image

    def generate(
        self,
        image,
        prompt="a person",
        negative_prompt="",
        num_inference_steps=4,
        guidance_scale=0.0,
        strength=0.75,
        lora_scale=1.0,
        identity_control_scale=0.8,
        depth_control_scale=0.8,
        identity_preservation=1.0,
        enable_color_matching=True,
        seed=-1,
        **kwargs
    ):
        """Generate retro art with InstantID face preservation.

        FIXED: Proper IP-Adapter integration following the
        exampleapp.py pattern, plus a corrected multi-ControlNet
        fallback (a two-ControlNet pipeline now always receives a list
        of two control images).

        Parameters
        ----------
        image : PIL.Image.Image
            Source photo; converted to RGB and resized internally.
        prompt, negative_prompt : str
            Text conditioning, encoded with Compel when available.
        num_inference_steps, guidance_scale, strength :
            LCM sampling parameters passed straight to the pipeline.
        lora_scale : float
            Weight applied to the "retroart" LORA adapter.
        identity_control_scale, depth_control_scale : float
            Conditioning scales for the InstantID and depth ControlNets.
        identity_preservation : float
            Multiplied by 0.8 to set the runtime IP-Adapter scale.
        enable_color_matching : bool
            Post-process the result to match the source palette.
        seed : int
            -1 selects a random seed (reported); otherwise used as-is.
        **kwargs
            Ignored; accepted for forward compatibility.

        Returns
        -------
        PIL.Image.Image
            The generated image.

        Raises
        ------
        RuntimeError
            If a depth map cannot be produced.
        """
        print(f"\n{'='*60}")
        print("Starting generation with:")
        print(f" Prompt: {prompt}")
        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
        print(f" Identity scale: {identity_control_scale}, Depth scale: {depth_control_scale}")
        print(f" Face preservation: {identity_preservation}")
        print(f"{'='*60}\n")

        # Prepare input image at a pipeline-friendly resolution.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        optimal_width, optimal_height = calculate_optimal_size(image.size)
        resized_image = image.resize((optimal_width, optimal_height), Image.LANCZOS)
        print(f"Image resized: {image.size} → {resized_image.size}")

        # Generate depth map (mandatory ControlNet input).
        print("Generating depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image is None:
            raise RuntimeError("Could not generate depth map")

        # Face detection and processing.
        (has_detected_faces, face_kps_image,
         face_embeddings, face_bbox_original) = self._detect_face(resized_image)

        # CRITICAL FIX: adjust IP-Adapter scale at runtime through the
        # pipeline's built-in method.
        if self.models_loaded.get('ip_adapter', False) and has_detected_faces:
            try:
                # Scale based on the identity_preservation parameter.
                adjusted_scale = 0.8 * identity_preservation
                self.pipe.set_ip_adapter_scale(adjusted_scale)
                print(f" IP-Adapter scale adjusted to: {adjusted_scale:.2f}")
            except Exception as e:
                print(f" [WARNING] Could not adjust IP-Adapter scale: {e}")

        # Set LORA scale
        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            try:
                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
                print(f" LORA scale: {lora_scale}")
            except Exception as e:
                print(f" [WARNING] Could not set LORA scale: {e}")

        # Base generation kwargs.
        pipe_kwargs = {
            "image": resized_image,
            "strength": strength,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }

        # Seed control: -1 means "random, but report what was used".
        if seed == -1:
            generator = torch.Generator(device=self.device)
            actual_seed = generator.seed()
            print(f"[SEED] Using random seed: {actual_seed}")
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            actual_seed = seed
            print(f"[SEED] Using fixed seed: {actual_seed}")
        pipe_kwargs["generator"] = generator

        # Prompt conditioning (Compel with plain-text fallback).
        self._encode_prompts(prompt, negative_prompt, pipe_kwargs)

        # Add CLIP skip
        if hasattr(self.pipe, 'text_encoder'):
            pipe_kwargs["clip_skip"] = 2

        # Configure ControlNet inputs.
        if self.using_multiple_controlnets:
            if has_detected_faces and face_kps_image is not None:
                print("Using InstantID (keypoints + embeddings) + Depth ControlNets")
                pipe_kwargs["control_image"] = [face_kps_image, depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [identity_control_scale, depth_control_scale]
                # The pipeline handles face embeddings automatically:
                # after load_ip_adapter_instantid() it projects them
                # through the Resampler and merges them with the text
                # embeddings, so no manual concatenation is needed.
                if face_embeddings is not None and self.models_loaded.get('ip_adapter', False):
                    print(" [OK] Face embeddings will be processed by pipeline")
                    print(" - Pipeline automatically handles Resampler projection")
                    print(" - Face features integrated via IP-Adapter attention")
            else:
                # BUGFIX: the original condition was
                # `elif using_multiple_controlnets and not has_detected_faces`,
                # so "face detected but keypoint image unavailable" fell
                # through to the single-ControlNet branch and handed ONE
                # control image to a TWO-ControlNet pipeline. A multi-
                # ControlNet pipeline must always receive a list; zero
                # out the identity branch instead.
                print("Multiple ControlNets available but no faces detected, using depth only")
                pipe_kwargs["control_image"] = [depth_image, depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]
        else:
            print("Using Depth ControlNet only")
            pipe_kwargs["control_image"] = depth_image
            pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale

        # Generate
        print("\nGenerating with LCM:")
        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
        print(f" ControlNet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}")
        result = self.pipe(**pipe_kwargs)
        generated_image = result.images[0]

        # Post-processing: optional color matching against the source.
        if enable_color_matching:
            if has_detected_faces:
                print("\nApplying enhanced face-aware color matching...")
                # face_bbox_original may still be None here; the helper
                # then falls back to standard matching, as before.
                generated_image = self._apply_color_matching(
                    generated_image, resized_image, face_bbox_original)
            else:
                print("\nApplying standard color matching...")
                generated_image = self._apply_color_matching(
                    generated_image, resized_image, None)

        print(f"\n{'='*60}")
        print("Generation complete!")
        print(f"{'='*60}\n")
        return generated_image
| print("[OK] Generator class ready (FIXED VERSION)") |