Spaces:
Runtime error
Runtime error
| """ | |
| Generation logic for Pixagram AI Pixel Art Generator | |
| FIXED VERSION - Following exampleapp.py pattern more closely | |
| """ | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from config import ( | |
| device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS, | |
| ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER, | |
| MODEL_REPO, MODEL_FILES | |
| ) | |
| from utils import ( | |
| sanitize_text, enhanced_color_match, color_match, create_face_mask, | |
| draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop | |
| ) | |
| from models import ( | |
| load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder, | |
| load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel, | |
| setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip | |
| ) | |
class RetroArtConverter:
    """Main class for retro art generation - FIXED VERSION.

    ``__init__`` loads every model once: face analysis, Zoe depth, ControlNets,
    the SDXL pipeline, LoRA, an optional IP-Adapter, Compel and an optional
    caption model. Public entry points are :meth:`generate` (image-to-image
    retro art with face preservation) and :meth:`generate_caption`.
    """

    def __init__(self):
        """Load all models and configure the pipeline.

        Success flags reported by the ``models`` helpers are recorded in
        ``self.models_loaded`` so that the rest of the class can fall back
        gracefully when an optional component is missing.
        """
        self.device = device
        self.dtype = dtype
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'lora_path': None,
            'instantid': False,
            'zoe_depth': False,
            'ip_adapter': False,
        }

        # Face analysis (detector providing keypoints / embeddings)
        self.face_app, self.face_detection_enabled = load_face_analysis()

        # Zoe Depth detector
        self.zoe_depth, zoe_success = load_depth_detector()
        self.models_loaded['zoe_depth'] = zoe_success

        # ControlNets
        controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
        self.controlnet_depth = controlnet_depth
        self.instantid_enabled = instantid_success
        self.models_loaded['instantid'] = instantid_success

        # The image encoder is only loaded when InstantID is available.
        self.image_encoder = load_image_encoder() if self.instantid_enabled else None

        # Order matters: generate() builds control images/scales as
        # [InstantID keypoints, depth] to match this list.
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, controlnet_depth]
            print("Initializing with multiple ControlNets: InstantID + Depth")
        else:
            controlnets = controlnet_depth
            print("Initializing with single ControlNet: Depth only")

        # SDXL pipeline
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success

        # LoRA; remember where the weights live so they can be reloaded later.
        lora_success = load_lora(self.pipe)
        self.models_loaded['lora'] = lora_success
        if lora_success:
            try:
                from huggingface_hub import hf_hub_download
                self.models_loaded['lora_path'] = hf_hub_download(MODEL_REPO, MODEL_FILES['lora'])
            except Exception as e:
                print(f"[WARNING] Could not resolve LORA path: {e}")
                self.models_loaded['lora_path'] = None

        # IP-Adapter via the pipeline's built-in method
        if self.instantid_enabled and self.image_encoder is not None:
            self.models_loaded['ip_adapter'] = setup_ip_adapter(self.pipe)
        else:
            print("[INFO] Face preservation: InstantID ControlNet keypoints only")
            self.models_loaded['ip_adapter'] = False

        # Compel prompt encoder, LCM scheduler, pipeline optimizations
        self.compel, self.use_compel = setup_compel(self.pipe)
        setup_scheduler(self.pipe)
        optimize_pipeline(self.pipe)

        # Caption model
        (self.caption_processor, self.caption_model,
         self.caption_enabled, self.caption_model_type) = load_caption_model()
        if self.caption_enabled and self.caption_model is not None:
            if self.caption_model_type == "git":
                print(" [OK] Using GIT for detailed captions")
            elif self.caption_model_type == "blip":
                print(" [OK] Using BLIP for standard captions")
            else:
                print(" [OK] Caption model loaded")

        set_clip_skip(self.pipe)

        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")

        self._print_status()
        print(" [OK] Model initialization complete!")

    def _print_status(self):
        """Print which models loaded and whether the IP-Adapter is active."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            if model == 'lora_path':
                continue  # a path, not a load flag
            status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
            print(f"{model}: {status}")
        print("===================\n")
        print("=== IP-ADAPTER STATUS ===")
        if self.models_loaded.get('ip_adapter', False):
            if hasattr(self.pipe, 'image_proj_model'):
                print("[OK] IP-Adapter fully loaded via pipeline method")
                print(" - Resampler: Available at pipe.image_proj_model")
                print(" - Scale control: Available via pipe.set_ip_adapter_scale()")
                print(" - Expected improvement: High face similarity")
            else:
                print("[WARNING] IP-Adapter loaded but Resampler not accessible")
        else:
            print("[INFO] IP-Adapter not active (using keypoints only)")
        print("=========================\n")

    def get_depth_map(self, image):
        """Generate a depth map for *image* using Zoe Depth.

        Args:
            image: PIL image; converted to RGB if needed.

        Returns:
            A depth image resized to ``image.size``, or ``None`` when the depth
            detector is unavailable or fails (the error is printed).
        """
        if self.zoe_depth is None:
            return None
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            width, height = image.size
            # Snap dimensions down to multiples of 64 (never below 64).
            target_size = (
                max(64, (int(width) // 64) * 64),
                max(64, (int(height) // 64) * 64),
            )
            image_for_depth = image.resize(target_size, Image.LANCZOS)
            depth_map = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=512)
            # In case the detector returns an array instead of a PIL image,
            # convert it so the size comparison / resize below are valid.
            if not isinstance(depth_map, Image.Image):
                depth_map = Image.fromarray(np.asarray(depth_map))
            if depth_map.size != image.size:
                depth_map = depth_map.resize(image.size, Image.LANCZOS)
            return depth_map
        except Exception as e:
            print(f"Depth generation failed: {e}")
            return None

    @staticmethod
    def _apply_consistency_mode(guidance_scale, strength, lora_scale,
                                identity_control_scale, depth_control_scale,
                                identity_preservation):
        """Clamp sampling parameters to the ranges enforced by consistency mode.

        Every adjustment is printed.

        Returns:
            ``(guidance_scale, strength, lora_scale, identity_control_scale,
            depth_control_scale)`` after adjustment.
        """
        print("[CONSISTENCY] Validating and adjusting parameters...")
        # CFG band used for LCM sampling
        if guidance_scale > 2.0:
            print(f" [ADJUST] CFG too high ({guidance_scale:.2f}), capping at 2.0")
            guidance_scale = 2.0
        elif guidance_scale < 0.5:
            print(f" [ADJUST] CFG too low ({guidance_scale:.2f}), raising to 0.5")
            guidance_scale = 0.5
        # High identity + high LoRA pull against each other; back the LoRA off by 20%.
        if identity_preservation > 1.5 and lora_scale > 1.5:
            print(" [ADJUST] High identity + high LORA conflict detected")
            print(f" Reducing LORA scale: {lora_scale:.2f} → {lora_scale * 0.8:.2f}")
            lora_scale = lora_scale * 0.8
        if depth_control_scale > 1.2:
            print(f" [ADJUST] Depth scale too high ({depth_control_scale:.2f}), capping at 1.2")
            depth_control_scale = 1.2
        if identity_control_scale > 1.5:
            print(f" [ADJUST] Identity control too high ({identity_control_scale:.2f}), capping at 1.5")
            identity_control_scale = 1.5
        if strength < 0.3:
            print(f" [ADJUST] Strength too low ({strength:.2f}), raising to 0.3")
            strength = 0.3
        elif strength > 0.9:
            print(f" [ADJUST] Strength too high ({strength:.2f}), capping at 0.9")
            strength = 0.9
        print("[CONSISTENCY] Parameter validation complete\n")
        return guidance_scale, strength, lora_scale, identity_control_scale, depth_control_scale

    @staticmethod
    def _extract_face_embedding(face):
        """Return the face embedding, or ``None`` if the face object has none.

        Preference order: ``normed_embedding``, then ``embedding`` (L2-normalized
        here), then a plain dict's ``'embedding'`` key (used as-is).
        """
        normed = getattr(face, 'normed_embedding', None)
        if normed is not None:
            print(f" Face embedding extracted (normed_embedding): shape {normed.shape}")
            return normed
        raw = getattr(face, 'embedding', None)
        if raw is not None:
            norm = np.linalg.norm(raw)
            # Guard against a zero vector producing NaNs.
            embedding = raw / norm if norm > 0 else raw
            print(f" Face embedding extracted (embedding, normalized): shape {embedding.shape}")
            return embedding
        if isinstance(face, dict) and face.get('embedding') is not None:
            embedding = face['embedding']
            print(f" Face embedding extracted (dict['embedding']): shape {embedding.shape}")
            return embedding
        print(" [WARNING] Face detected but embeddings not available")
        return None

    def _detect_face(self, resized_image):
        """Detect faces in *resized_image* and collect data for the first one.

        Returns:
            ``(has_face, kps_image, embedding, bbox)``. ``has_face`` is True as
            soon as at least one face was detected, even if a later step
            failed; the other entries are ``None`` when unavailable. Errors
            are printed, never raised.
        """
        has_face = False
        kps_image = embedding = bbox = None
        if self.face_app is None:
            return has_face, kps_image, embedding, bbox

        print("Detecting faces...")
        try:
            image_bgr = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
            faces = self.face_app.get(image_bgr)
            if len(faces) == 0:
                print(" [INFO] No faces detected")
                return has_face, kps_image, embedding, bbox

            has_face = True
            face = faces[0]
            print(f" [OK] Face detected (score: {face.det_score:.3f})")

            kps_image = draw_kps(resized_image, face.kps)
            embedding = self._extract_face_embedding(face)
            # Kept for face-aware color matching.
            bbox = getattr(face, 'bbox', None)

            if hasattr(face, 'age') and hasattr(face, 'gender'):
                age = face.age
                gender_code = face.gender
                gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
                print(f" Face info: age={age if age else 'N/A'}, gender={gender_str}, quality={face.det_score:.3f}")
        except Exception as e:
            print(f" [WARNING] Face detection failed: {e}")
        return has_face, kps_image, embedding, bbox

    def _make_generator(self, seed):
        """Create a ``torch.Generator`` on ``self.device``.

        ``seed`` of -1 (or ``None``) draws a random seed. Any other value is
        converted to ``int``, so float seeds from UI number fields work.
        """
        generator = torch.Generator(device=self.device)
        if seed is None or seed == -1:
            actual_seed = generator.seed()
            print(f"[SEED] Using random seed: {actual_seed}")
        else:
            actual_seed = int(seed)
            generator.manual_seed(actual_seed)
            print(f"[SEED] Using fixed seed: {actual_seed}")
        return generator

    def _encode_prompts(self, prompt, negative_prompt):
        """Build the prompt-related keyword arguments for the pipeline call.

        The LoRA trigger word is prepended when missing. With Compel the
        prompts are encoded into embeddings. Without Compel the raw strings
        are returned so the pipeline encodes them itself; the pipeline needs
        either ``prompt`` or ``prompt_embeds``.
        """
        negative_prompt = negative_prompt or ""
        if not prompt:
            prompt = TRIGGER_WORD
        elif TRIGGER_WORD not in prompt:
            # Prepend the trigger word for highest impact.
            prompt = f"{TRIGGER_WORD}, {prompt}"

        if self.use_compel and self.compel is not None:
            print("Encoding prompts with Compel...")
            print(f" Using final prompt: {prompt}")
            conditioning, pooled = self.compel(prompt)
            negative_conditioning, negative_pooled = self.compel(negative_prompt)
            print(" [OK] Prompts encoded")
            return {
                'prompt_embeds': conditioning,
                'pooled_prompt_embeds': pooled,
                'negative_prompt_embeds': negative_conditioning,
                'negative_pooled_prompt_embeds': negative_pooled,
            }

        print(f" Using final prompt: {prompt}")
        return {'prompt': prompt, 'negative_prompt': negative_prompt}

    def _select_control_inputs(self, has_detected_faces, face_kps_image, face_embeddings,
                               depth_image, identity_control_scale, depth_control_scale,
                               identity_preservation):
        """Choose ControlNet images/scales and the face embeddings to pass on.

        When the pipeline was built with [InstantID, Depth] ControlNets, both
        the image and the scale are always two-element lists. The InstantID
        slot gets a blank image at scale 0.0 when no keypoints are available.

        Returns:
            ``(control_image, conditioning_scales, face_embeddings)``.
        """
        if not self.using_multiple_controlnets:
            print("Using Depth ControlNet only")
            return depth_image, depth_control_scale, None

        if has_detected_faces and face_kps_image is not None:
            print("Using InstantID (keypoints + embeddings) + Depth ControlNets")
            if face_embeddings is None:
                print(" [INFO] No face embeddings, passing None to pipeline")
            elif not self.models_loaded.get('ip_adapter', False):
                # Embeddings are only usable through a loaded IP-Adapter.
                face_embeddings = None
                print(" [INFO] IP-Adapter not loaded, passing None for face embeddings")
            else:
                adjusted_scale = 0.8 * identity_preservation
                self.pipe.set_ip_adapter_scale(adjusted_scale)
                print(f" IP-Adapter scale: {adjusted_scale:.2f}")
                print(f" Face embeddings shape: {face_embeddings.shape}")
                print(" [OK] Face embeddings ready for InstantID pipeline")
            return (
                [face_kps_image, depth_image],
                [identity_control_scale, depth_control_scale],
                face_embeddings,
            )

        # No usable keypoints: either no face, or keypoint drawing failed.
        if has_detected_faces:
            print("Multiple ControlNets available but face keypoints unavailable, using depth only")
        else:
            print("Multiple ControlNets available but no faces detected, using depth only")
        # The InstantID ControlNet (index 0) still needs an image input;
        # a blank image at scale 0.0 disables it.
        blank_image = Image.new("RGB", depth_image.size, (0, 0, 0))
        return [blank_image, depth_image], [0.0, depth_control_scale], None

    @staticmethod
    def _apply_color_matching(generated_image, reference_image, has_detected_faces, face_bbox):
        """Color-match *generated_image* to *reference_image* (best effort).

        Face-aware matching is used when a face bbox is known; otherwise the
        standard 'mkl' matcher is used. On failure a warning is printed and
        the unmodified image is returned.
        """
        if has_detected_faces:
            print("\nApplying enhanced face-aware color matching...")
        else:
            print("\nApplying standard color matching...")
        try:
            if has_detected_faces and face_bbox is not None:
                generated_image = enhanced_color_match(
                    generated_image,
                    reference_image,
                    face_bbox=face_bbox
                )
                print(" [OK] Enhanced color matching applied (face-aware)")
            else:
                generated_image = color_match(generated_image, reference_image, mode='mkl')
                print(" [OK] Standard color matching applied")
        except Exception as e:
            print(f" [WARNING] Color matching failed: {e}")
        return generated_image

    def generate(
        self,
        image,
        prompt="a person",
        negative_prompt="",
        num_inference_steps=12,
        guidance_scale=0.0,
        strength=0.75,
        lora_scale=1.0,
        identity_control_scale=0.8,
        depth_control_scale=0.8,
        identity_preservation=1.0,
        enable_color_matching=True,
        consistency_mode=True,
        seed=-1
    ):
        """Generate retro art from *image* with InstantID face preservation.

        Args:
            image: Input PIL image (converted to RGB if needed).
            prompt: Text prompt; the LoRA trigger word is prepended if missing.
            negative_prompt: Negative text prompt.
            num_inference_steps: Number of sampling steps.
            guidance_scale: CFG scale (clamped to [0.5, 2.0] in consistency mode).
            strength: img2img strength (clamped to [0.3, 0.9] in consistency mode).
            lora_scale: Weight for the "retroart" LoRA adapter.
            identity_control_scale: InstantID keypoint ControlNet scale.
            depth_control_scale: Depth ControlNet scale.
            identity_preservation: Multiplier on the IP-Adapter scale (base 0.8).
            enable_color_matching: Color-match the output to the resized input.
            consistency_mode: Clamp parameters before generating.
            seed: RNG seed; -1 (or None) picks a random seed.

        Returns:
            The first image from the pipeline output (``.images[0]``).

        Raises:
            RuntimeError: If no depth map could be produced.
            Exception: Errors from the pipeline call are logged and re-raised.
        """
        rule = '=' * 60
        print(f"\n{rule}")
        print("Starting generation with:")
        print(f" Prompt: {prompt}")
        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
        print(f" Identity scale: {identity_control_scale}, Depth scale: {depth_control_scale}")
        print(f" Face preservation: {identity_preservation}")
        print(f" Consistency mode: {'ON' if consistency_mode else 'OFF'}")
        print(f"{rule}\n")

        if consistency_mode:
            (guidance_scale, strength, lora_scale,
             identity_control_scale, depth_control_scale) = self._apply_consistency_mode(
                guidance_scale, strength, lora_scale,
                identity_control_scale, depth_control_scale, identity_preservation,
            )

        # Prepare input image
        if image.mode != 'RGB':
            image = image.convert('RGB')
        optimal_width, optimal_height = calculate_optimal_size(image.size[0], image.size[1])
        resized_image = image.resize((optimal_width, optimal_height), Image.LANCZOS)
        print(f"Image resized: {image.size} → {resized_image.size}")

        print("Generating depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image is None:
            raise RuntimeError("Could not generate depth map")

        has_detected_faces, face_kps_image, face_embeddings, face_bbox = self._detect_face(resized_image)

        # LoRA strength via the named adapter (name matches models.py).
        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            try:
                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
                print(f"LORA scale: {lora_scale}")
            except Exception as e:
                print(f"Could not set LORA scale: {e}")

        generator = self._make_generator(seed)
        prompt_kwargs = self._encode_prompts(prompt, negative_prompt)

        # Passed to the pipeline call; CLIP skip is also configured once in __init__.
        clip_skip = 2 if hasattr(self.pipe, 'text_encoder') else None

        control_image, conditioning_scales, face_embeddings = self._select_control_inputs(
            has_detected_faces, face_kps_image, face_embeddings, depth_image,
            identity_control_scale, depth_control_scale, identity_preservation,
        )

        print("\nGenerating with LCM:")
        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
        print(f" ControlNet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}")
        try:
            generated_image = self.pipe(
                **prompt_kwargs,
                width=optimal_width,
                height=optimal_height,
                image_embeds=face_embeddings,
                image=resized_image,
                strength=strength,
                control_image=control_image,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                clip_skip=clip_skip,
                generator=generator,
                controlnet_conditioning_scale=conditioning_scales
            ).images[0]
        except Exception as e:
            print(f"[ERROR] Generation failed: {e}")
            import traceback
            traceback.print_exc()
            raise

        if enable_color_matching:
            generated_image = self._apply_color_matching(
                generated_image, resized_image, has_detected_faces, face_bbox
            )

        print(f"\n{rule}")
        print("Generation complete!")
        print(f"{rule}\n")
        return generated_image

    def generate_caption(self, image):
        """Generate a caption for *image*.

        Args:
            image: PIL image, or an array accepted by ``Image.fromarray``.

        Returns:
            The caption string. Returns ``None`` when captioning is disabled,
            when the caption model type is unknown, or when generation fails
            (the error is printed).
        """
        if not self.caption_enabled or self.caption_model is None:
            return None
        try:
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            if image.mode != 'RGB':
                image = image.convert('RGB')

            print("Generating caption...")
            # NOTE(review): the processor yields float32 pixels. Cast them to the
            # model's dtype in case load_caption_model loads it in half precision;
            # this is a no-op for float32 models.
            model_dtype = getattr(self.caption_model, 'dtype', None)
            with torch.no_grad():
                if self.caption_model_type == 'git':
                    inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
                    pixel_values = inputs.pixel_values
                    if model_dtype is not None:
                        pixel_values = pixel_values.to(model_dtype)
                    generated_ids = self.caption_model.generate(
                        pixel_values=pixel_values,
                        max_length=50
                    )
                    caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                elif self.caption_model_type == 'blip':
                    inputs = self.caption_processor(image, return_tensors="pt").to(self.device)
                    if model_dtype is not None and 'pixel_values' in inputs:
                        inputs['pixel_values'] = inputs['pixel_values'].to(model_dtype)
                    generated_ids = self.caption_model.generate(**inputs, max_length=50)
                    caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
                else:
                    return None

            print(f" [OK] Caption: {caption}")
            return caption
        except Exception as e:
            print(f" [WARNING] Caption generation failed: {e}")
            return None
| print("[OK] Generator class ready (FIXED VERSION - exampleapp.py style)") |