Spaces:
Runtime error
Runtime error
| """ | |
| Generation logic for Pixagram AI Pixel Art Generator | |
| UPDATED VERSION with InstantID pipeline integration | |
| """ | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import gc | |
| from config import ( | |
| device, dtype, TRIGGER_WORD, | |
| ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG | |
| ) | |
| from utils import ( | |
| sanitize_text, enhanced_color_match, color_match, | |
| get_demographic_description, calculate_optimal_size, safe_image_size | |
| ) | |
| from models import ( | |
| load_face_analysis, load_depth_detector, load_controlnets, | |
| load_sdxl_pipeline, load_lora, setup_compel, | |
| setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip | |
| ) | |
class RetroArtConverter:
    """Main class for retro art generation with InstantID.

    All models are loaded once in ``__init__``; ``generate_retro_art`` is the
    per-request entry point.
    """

    def __init__(self):
        """Load every model component and record which ones loaded."""
        self.device = device
        self.dtype = dtype
        # Load status per component, printed by _print_status().
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'zoe_depth': False,
        }

        # Load face analysis
        self.face_app, self.face_detection_enabled = load_face_analysis()

        # Load depth detector
        self.zoe_depth, zoe_success = load_depth_detector()
        self.models_loaded['zoe_depth'] = zoe_success

        # Load ControlNets AS LIST. Order is [identity, depth]; the
        # control_image / controlnet_conditioning_scale lists built in
        # generate_retro_art() must follow the same order.
        controlnet_instantid, controlnet_depth = load_controlnets()
        controlnets = [controlnet_instantid, controlnet_depth]
        # Report what actually loaded instead of unconditionally claiming success.
        self.models_loaded['instantid'] = controlnet_instantid is not None
        print("Initializing InstantID pipeline with Face + Depth ControlNets")

        # Load SDXL pipeline with InstantID (handles IP-Adapter internally)
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success

        # Load LORA
        self.models_loaded['lora'] = load_lora(self.pipe)

        # Setup Compel (weighted prompt encoding)
        self.compel, self.use_compel = setup_compel(self.pipe)

        # Scheduler and pipeline optimizations
        setup_scheduler(self.pipe)
        optimize_pipeline(self.pipe)

        # Load caption model
        (self.caption_processor, self.caption_model,
         self.caption_enabled, self.caption_model_type) = load_caption_model()

        # Set CLIP skip
        set_clip_skip(self.pipe)

        self._print_status()
        print(" [OK] RetroArtConverter initialized with InstantID!")

    def _print_status(self):
        """Print the load status of each tracked model component."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
            print(f"{model}: {status}")
        print("InstantID Pipeline: [OK] ACTIVE")
        print("IP-Adapter: [OK] Built into pipeline")
        print("===================\n")

    def get_depth_map(self, image):
        """Generate a depth map for ``image`` using Zoe Depth.

        Returns:
            ``(depth_image, raw_output)``: ``depth_image`` is an RGB PIL image
            the same size as ``image``; ``raw_output`` is whatever the detector
            returned. If the detector is unavailable or fails, a grayscale copy
            of the image (as RGB) is returned together with ``None``.
        """
        if self.zoe_depth is None:
            print("[DEPTH] Detector not available, using grayscale")
            return image.convert('L').convert('RGB'), None

        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Use safe size helper to avoid numpy.int64 issues
            orig_width, orig_height = safe_image_size(image)
            # Snap down to multiples of 64, never below 64
            target_width = max(64, int(orig_width // 64) * 64)
            target_height = max(64, int(orig_height // 64) * 64)
            image_for_depth = image.resize((target_width, target_height), Image.LANCZOS)

            raw_depth = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
            # The detector may hand back a PIL image (controlnet_aux-style
            # detectors do for PIL input) or an array; Image.fromarray() on a
            # PIL image is not reliable, so accept both explicitly.
            if isinstance(raw_depth, Image.Image):
                depth_image = raw_depth
            else:
                depth_image = Image.fromarray(np.asarray(raw_depth))
            # Match the RGB mode of the grayscale fallbacks.
            if depth_image.mode != 'RGB':
                depth_image = depth_image.convert('RGB')

            if depth_image.size != image.size:
                depth_image = depth_image.resize(image.size, Image.LANCZOS)
            print(f"[DEPTH] Generated depth map: {depth_image.size}")
            return depth_image, raw_depth
        except Exception as e:
            # Best effort: a grayscale stand-in keeps generation running.
            print(f"[DEPTH] Generation failed: {e}, using grayscale")
            return image.convert('L').convert('RGB'), None

    def add_trigger_word(self, prompt):
        """Prepend TRIGGER_WORD to ``prompt`` unless already present.

        The containment check is case-insensitive. ``None`` or blank prompts
        become just the trigger word.
        """
        if prompt is None or not prompt.strip():
            return TRIGGER_WORD
        if TRIGGER_WORD.lower() in prompt.lower():
            return prompt
        return f"{TRIGGER_WORD}, {prompt}"

    def detect_face_quality(self, face):
        """Choose adaptive generation parameters from face size, score and pose.

        Checks, in priority order: small face, low detection confidence,
        profile view (|yaw| above threshold).

        Returns:
            A copy of the matching ``ADAPTIVE_PARAMS`` entry, or ``None`` when
            no adjustment applies or the analysis fails.
        """
        try:
            bbox = face.bbox
            face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

            # Some Face implementations report missing attributes as None
            # instead of raising, so hasattr() alone is not a sufficient guard.
            raw_score = getattr(face, 'det_score', None)
            det_score = 1.0 if raw_score is None else float(raw_score)

            # Small face -> boost preservation
            if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
                return ADAPTIVE_PARAMS['small_face'].copy()
            # Low confidence -> boost preservation
            if det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
                return ADAPTIVE_PARAMS['low_confidence'].copy()
            # Profile view: pose[1] is read as yaw
            pose = getattr(face, 'pose', None)
            if pose is not None and len(pose) > 1:
                try:
                    yaw = float(pose[1])
                except (ValueError, TypeError, IndexError):
                    return None
                if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
                    return ADAPTIVE_PARAMS['profile_view'].copy()
            return None
        except Exception as e:
            print(f"[ADAPTIVE] Quality detection failed: {e}")
            return None

    def generate_caption(self, image):
        """Caption ``image`` with the loaded GIT or BLIP model.

        Returns the sanitized caption, or ``None`` if captioning is disabled,
        the model type is unknown, or generation fails.
        """
        if not self.caption_enabled or self.caption_model is None:
            return None
        try:
            if self.caption_model_type == 'git':
                inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
                generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
                caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            elif self.caption_model_type == 'blip':
                inputs = self.caption_processor(image, return_tensors="pt").to(self.device)
                generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
                caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
            else:
                return None
            return sanitize_text(caption)
        except Exception as e:
            print(f"[CAPTION] Generation failed: {e}")
            return None

    @staticmethod
    def _select_primary_face(faces):
        """Return the face with the largest bounding-box area.

        Follows the InstantID reference usage of conditioning on the largest
        face rather than whichever face the detector listed first.
        """
        def bbox_area(face):
            x1, y1, x2, y2 = face.bbox[:4]
            return (x2 - x1) * (y2 - y1)

        return max(faces, key=bbox_area)

    def _analyze_face(self, image):
        """Detect the primary face in ``image`` and build InstantID inputs.

        Returns:
            A dict with ``embedding``, ``kps_image``, ``bbox`` and
            ``adaptive_params``, or ``None`` if detection is disabled, finds
            no face, or fails part-way. Callers therefore never see a
            half-populated face state.
        """
        if not self.face_detection_enabled or self.face_app is None:
            return None
        try:
            image_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            faces = self.face_app.get(image_array)
            if not faces:
                print("[FACE] No faces detected")
                return None

            face = self._select_primary_face(faces)
            from pipeline_stable_diffusion_xl_instantid_img2img import draw_kps
            face_info = {
                # Face embeddings (512D array)
                'embedding': face.normed_embedding,
                'kps_image': draw_kps(image, face.kps),
                # bbox is used later for face-aware color matching
                'bbox': face.bbox,
                'adaptive_params': self.detect_face_quality(face),
            }
            print(f"[FACE] Detected face with {float(face.det_score):.2f} confidence")
            if face_info['embedding'] is not None:
                print(f"[FACE] Embeddings shape: {face_info['embedding'].shape}")
            return face_info
        except Exception as e:
            print(f"[FACE] Detection failed: {e}")
            return None

    def _make_generator(self, seed):
        """Build a ``torch.Generator`` on ``self.device``.

        ``seed == -1`` (or ``None``) draws a fresh random seed. Any other
        value is coerced to ``int``, so float seeds from UI number widgets
        work with ``manual_seed``.

        Returns:
            ``(generator, seed_used)``.
        """
        if seed is None or int(seed) == -1:
            generator = torch.Generator(device=self.device)
            actual_seed = generator.seed()
            print(f"[SEED] Random: {actual_seed}")
        else:
            actual_seed = int(seed)
            generator = torch.Generator(device=self.device).manual_seed(actual_seed)
            print(f"[SEED] Fixed: {actual_seed}")
        return generator, actual_seed

    def _encode_prompts(self, prompt, negative_prompt):
        """Return the prompt-related pipeline kwargs.

        Uses Compel embeddings when enabled. Otherwise, or if Compel fails at
        any step, it falls back to raw prompt strings. The result never mixes
        ``prompt`` with ``prompt_embeds``, which the pipeline rejects.
        """
        if self.use_compel and self.compel is not None:
            try:
                conditioning = self.compel(prompt)
                negative_conditioning = self.compel(negative_prompt)
                embeds, pooled = conditioning[0], conditioning[1]
                neg_embeds, neg_pooled = negative_conditioning[0], negative_conditioning[1]
                # Untruncated long prompts can give positive/negative embeds
                # different sequence lengths; diffusers requires equal shapes.
                pad = getattr(self.compel, 'pad_conditioning_tensors_to_same_length', None)
                if pad is not None:
                    embeds, neg_embeds = pad([embeds, neg_embeds])
                encoded = {
                    "prompt_embeds": embeds,
                    "pooled_prompt_embeds": pooled,
                    "negative_prompt_embeds": neg_embeds,
                    "negative_pooled_prompt_embeds": neg_pooled,
                }
                print("[OK] Using Compel-encoded prompts")
                return encoded
            except Exception as e:
                print(f"[COMPEL] Failed, using standard prompts: {e}")
        return {"prompt": prompt, "negative_prompt": negative_prompt}

    def _apply_color_matching(self, generated_image, reference_image, face_bbox=None):
        """Color-match ``generated_image`` to ``reference_image`` (best effort).

        Uses the face-aware matcher when ``face_bbox`` is given, otherwise the
        standard MKL transfer. On failure the error is logged and the
        unmatched image is returned.
        """
        if face_bbox is not None:
            print("Applying enhanced face-aware color matching...")
        else:
            print("Applying standard color matching...")
        try:
            if face_bbox is not None:
                matched = enhanced_color_match(generated_image, reference_image, face_bbox=face_bbox)
                print("[OK] Enhanced color matching applied")
            else:
                matched = color_match(generated_image, reference_image, mode='mkl')
                print("[OK] Standard color matching applied")
            return matched
        except Exception as e:
            print(f"[COLOR] Matching failed: {e}")
            return generated_image

    def generate_retro_art(
        self,
        input_image,
        prompt=" ",
        negative_prompt=" ",
        num_inference_steps=12,
        guidance_scale=1.3,
        depth_control_scale=0.75,
        identity_control_scale=0.85,
        lora_scale=1.0,
        identity_preservation=1.2,
        strength=0.50,
        enable_color_matching=False,
        consistency_mode=True,
        seed=-1
    ):
        """Generate retro art with InstantID face preservation.

        Args:
            input_image: PIL image to stylize; converted to RGB if needed.
            prompt: positive prompt; the LoRA trigger word is prepended if
                missing.
            negative_prompt: negative prompt.
            num_inference_steps, guidance_scale, strength: forwarded to the
                img2img pipeline.
            depth_control_scale: conditioning scale of the depth ControlNet.
            identity_control_scale: conditioning scale of the InstantID
                keypoint ControlNet (forced to 0 when no face is found).
            lora_scale: weight of the "retroart" LoRA adapter.
            identity_preservation: IP-Adapter scale for the face embedding.
            enable_color_matching: color-match the output to the input.
            consistency_mode: currently unused; kept for API compatibility.
            seed: ``-1`` (or ``None``) for random, otherwise a fixed seed.

        ``identity_preservation``, ``identity_control_scale``,
        ``guidance_scale`` and ``lora_scale`` may be overridden by
        :meth:`detect_face_quality` when a face is detected.

        Returns:
            The generated PIL image.
        """
        try:
            prompt = sanitize_text(self.add_trigger_word(prompt))
            negative_prompt = sanitize_text(negative_prompt)
            print(f"[PROMPT] {prompt}")

            # Face detection converts RGB->BGR and needs 3 channels; RGBA/L/P
            # uploads would otherwise make detection fail silently.
            if input_image.mode != 'RGB':
                input_image = input_image.convert('RGB')

            # Calculate optimal size and resize
            orig_width, orig_height = safe_image_size(input_image)
            optimal_width, optimal_height = calculate_optimal_size(orig_width, orig_height)
            resized_image = input_image.resize((optimal_width, optimal_height), Image.LANCZOS)
            print(f"[SIZE] Resized to {optimal_width}x{optimal_height}")

            depth_image, _ = self.get_depth_map(resized_image)

            face_info = self._analyze_face(resized_image)

            # Adaptive parameter adjustment
            adaptive_params = face_info['adaptive_params'] if face_info is not None else None
            if adaptive_params:
                print(f"[ADAPTIVE] {adaptive_params.get('reason', 'adjusting parameters')}")
                identity_preservation = adaptive_params.get('identity_preservation', identity_preservation)
                identity_control_scale = adaptive_params.get('identity_control_scale', identity_control_scale)
                guidance_scale = adaptive_params.get('guidance_scale', guidance_scale)
                lora_scale = adaptive_params.get('lora_scale', lora_scale)

            # Set LORA scale
            if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
                try:
                    self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
                    print(f"[LORA] Scale: {lora_scale}")
                except Exception as e:
                    print(f"[LORA] Could not set scale: {e}")

            pipe_kwargs = {
                "image": resized_image,
                "strength": strength,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
            }
            pipe_kwargs["generator"], _ = self._make_generator(seed)
            pipe_kwargs.update(self._encode_prompts(prompt, negative_prompt))

            # Configure ControlNets (order: [identity, depth]) + IP-Adapter
            if face_info is not None and face_info['kps_image'] is not None:
                print("Using InstantID (keypoints + embeddings) + Depth ControlNets")
                pipe_kwargs["control_image"] = [face_info['kps_image'], depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [
                    identity_control_scale,
                    depth_control_scale,
                ]

                face_embeddings = face_info['embedding']
                if face_embeddings is not None:
                    print("Adding face embeddings for IP-Adapter...")
                    pipe_kwargs["image_embeds"] = face_embeddings
                    pipe_kwargs["ip_adapter_scale"] = identity_preservation
                    # If the pipeline only collects ip_adapter_scale via
                    # **kwargs it is ignored; set the scale explicitly too so
                    # identity_preservation takes effect either way.
                    if hasattr(self.pipe, 'set_ip_adapter_scale'):
                        try:
                            self.pipe.set_ip_adapter_scale(identity_preservation)
                        except Exception as e:
                            print(f"[IP-ADAPTER] Could not set scale: {e}")
                    print(f" - Face embeddings shape: {face_embeddings.shape}")
                    print(f" - IP-Adapter scale: {identity_preservation}")
                    print(" [OK] Face embeddings configured")
                else:
                    print(" [WARNING] No face embeddings - using keypoints only")
            else:
                print("No faces detected - using Depth ControlNet only")
                # Depth fills both ControlNet slots; identity slot weight is 0.
                # NOTE(review): stock InstantID pipelines encode image_embeds
                # unconditionally; confirm the local pipeline tolerates it
                # being omitted here.
                pipe_kwargs["control_image"] = [depth_image, depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]

            print(f"Generating: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
            result = self.pipe(**pipe_kwargs)
            generated_image = result.images[0]

            # Post-processing: color matching
            if enable_color_matching:
                face_bbox = face_info['bbox'] if face_info is not None else None
                generated_image = self._apply_color_matching(generated_image, resized_image, face_bbox)

            return generated_image
        finally:
            # Collect first so tensors freed by GC are returned to the CUDA
            # caching allocator before it is emptied.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
# Module-level side effect: this message prints once, when the module is imported.
print("[OK] Generator class ready with InstantID support")