""" Model loading and initialization for Pixagram AI Pixel Art Generator FIXED VERSION - Uses correct InstantID pipeline and Compel encoder """ import torch import time import os from diffusers import ( ControlNetModel, AutoencoderKL, LCMScheduler ) from transformers import ( CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection ) from insightface.app import FaceAnalysis from controlnet_aux import ZoeDetector, OpenposeDetector, LeresDetector, MidasDetector, MediapipeFaceDetector from huggingface_hub import hf_hub_download, snapshot_download # --- START FIX: Import correct pipeline and Compel --- from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline from compel import Compel, ReturnedEmbeddingsType # --- END FIX --- from config import ( device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN, FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG ) # (We keep download_model_with_retry, load_face_analysis, load_depth_detector, # load_openpose_detector, and load_mediapipe_face_detector as they were) # ... (Keep all original functions from line 25 down to line 180) ... def download_model_with_retry(repo_id, filename, max_retries=None, **kwargs): """Download model with retry logic and proper token handling.""" if max_retries is None: max_retries = DOWNLOAD_CONFIG['max_retries'] # Ensure token is passed if available if HUGGINGFACE_TOKEN and "token" not in kwargs: kwargs["token"] = HUGGINGFACE_TOKEN for attempt in range(max_retries): try: print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...") return hf_hub_download( repo_id=repo_id, filename=filename, **kwargs ) except Exception as e: print(f" [WARNING] Download attempt {attempt + 1} failed: {e}") if attempt < max_retries - 1: print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...") time.sleep(DOWNLOAD_CONFIG['retry_delay']) else: print(f" [ERROR] Failed to download {filename} after {max_retries} attempts") raise return None def load_face_analysis(): """ Load face analysis model with proper model downloading from HuggingFace. Downloads from DIAMONIK7777/antelopev2 which has the correct model structure. """ print("Loading face analysis model...") try: antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2") # --- FIX: Load InsightFace on CPU to save VRAM --- face_app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider']) face_app.prepare(ctx_id=0, det_size=(640, 640)) print(" [OK] Face analysis loaded (on CPU)") return face_app, True except Exception as e: print(f" [ERROR] Face detection not available: {e}") import traceback traceback.print_exc() return None, False def load_depth_detector(): """ Load depth detector with fallback hierarchy: Leres → Zoe → Midas. Returns (detector, detector_type, success). 
""" print("Loading depth detector with fallback hierarchy...") # Try LeresDetector first (best quality) try: print(" Attempting LeresDetector (highest quality)...") # --- FIX: Load on CPU --- leres_depth = LeresDetector.from_pretrained("lllyasviel/Annotators") # leres_depth.to(device) # Removed print(" [OK] LeresDetector loaded successfully (on CPU)") return leres_depth, 'leres', True except Exception as e: print(f" [INFO] LeresDetector not available: {e}") # Fallback to ZoeDetector try: print(" Attempting ZoeDetector (fallback #1)...") # --- FIX: Load on CPU --- zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators") # zoe_depth.to(device) # Removed print(" [OK] ZoeDetector loaded successfully (on CPU)") return zoe_depth, 'zoe', True except Exception as e: print(f" [INFO] ZoeDetector not available: {e}") # Final fallback to MidasDetector try: print(" Attempting MidasDetector (fallback #2)...") # --- FIX: Load on CPU --- midas_depth = MidasDetector.from_pretrained("lllyasviel/Annotators") # midas_depth.to(device) # Removed print(" [OK] MidasDetector loaded successfully (on CPU)") return midas_depth, 'midas', True except Exception as e: print(f" [WARNING] MidasDetector not available: {e}") print(" [ERROR] No depth detector available") return None, None, False # --- NEW FUNCTION --- def load_openpose_detector(): """Load OpenPose detector.""" print("Loading OpenPose detector...") try: # --- FIX: Load on CPU --- openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators") # openpose.to(device) # Removed print(" [OK] OpenPose loaded successfully (on CPU)") return openpose, True except Exception as e: print(f" [WARNING] OpenPose not available: {e}") return None, False # --- END NEW FUNCTION --- # --- NEW FUNCTION --- def load_mediapipe_face_detector(): """Load MediapipeFaceDetector for advanced face detection.""" print("Loading MediapipeFaceDetector...") try: face_detector = MediapipeFaceDetector() print(" [OK] MediapipeFaceDetector loaded successfully") return face_detector, True except Exception as e: print(f" [WARNING] MediapipeFaceDetector not available: {e}") return None, False # --- END NEW FUNCTION --- def load_controlnets(): """Load ControlNet models.""" print("Loading ControlNet Zoe Depth model...") # --- FIX: Load core models on GPU --- controlnet_depth = ControlNetModel.from_pretrained( "xinsir/controlnet-depth-sdxl-1.0", torch_dtype=dtype ).to(device) print(" [OK] ControlNet Depth loaded (on GPU)") # --- NEW: Load OpenPose ControlNet --- print("Loading ControlNet OpenPose model...") try: # --- FIX: Load core models on GPU --- controlnet_openpose = ControlNetModel.from_pretrained( "xinsir/controlnet-openpose-sdxl-1.0", torch_dtype=dtype ).to(device) print(" [OK] ControlNet OpenPose loaded (on GPU)") except Exception as e: print(f" [WARNING] ControlNet OpenPose not available: {e}") controlnet_openpose = None # --- END NEW --- print("Loading InstantID ControlNet...") try: # --- FIX: Load core models on GPU --- controlnet_instantid = ControlNetModel.from_pretrained( "InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype ).to(device) print(" [OK] InstantID ControlNet loaded successfully (on GPU)") # Return all three models return controlnet_depth, controlnet_instantid, controlnet_openpose, True except Exception as e: print(f" [WARNING] InstantID ControlNet not available: {e}") # Return models, indicating InstantID failure return controlnet_depth, None, controlnet_openpose, False # --- START: REMOVED load_image_encoder --- # (The new pipeline handles this 

# --- START: REMOVED load_image_encoder ---
# (The new pipeline handles the image encoder internally)
# --- END: REMOVED load_image_encoder ---


def load_sdxl_pipeline(controlnets):
    """Load the SDXL checkpoint from the HuggingFace Hub."""
    print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")

    # --- START FIX: Load base text models for Compel (from previous fix) ---
    print("  Loading base tokenizers and text encoders...")
    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
    tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer_2")
    text_encoder = CLIPTextModel.from_pretrained(
        BASE_MODEL, subfolder="text_encoder", torch_dtype=dtype
    ).to(device)
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        BASE_MODEL, subfolder="text_encoder_2", torch_dtype=dtype
    ).to(device)
    print("  [OK] Base text/token models loaded")
    # --- END FIX ---

    try:
        model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")

        # --- START FIX: Load the CORRECT pipeline ---
        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
            model_path,
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True,
            # Pass the pre-loaded text components
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
        ).to(device)
        # --- END FIX ---
        print("  [OK] Custom checkpoint loaded successfully (VAE bundled)")
        return pipe, True
    except Exception as e:
        print(f"  [WARNING] Could not load custom checkpoint: {e}")
        print("  Using default SDXL base model")
        # --- START FIX: Fall back to the CORRECT pipeline ---
        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True,
            # Pass the pre-loaded text components
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
        ).to(device)
        # --- END FIX ---
        return pipe, False


def load_loras(pipe):
    """Load all LoRAs from the HuggingFace Hub."""
    print("Loading all LORAs from HuggingFace Hub...")
    loaded_loras = {}
    lora_files = {
        "retroart": MODEL_FILES.get("lora_retroart"),
        "vga": MODEL_FILES.get("lora_vga"),
        "lucasart": MODEL_FILES.get("lora_lucasart")
    }

    for adapter_name, filename in lora_files.items():
        if not filename:
            print(f"  [INFO] No file specified for LORA '{adapter_name}', skipping.")
            loaded_loras[adapter_name] = False
            continue
        try:
            lora_path = download_model_with_retry(MODEL_REPO, filename, repo_type="model")
            pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
            print(f"  [OK] LORA loaded successfully: {filename} as '{adapter_name}'")
            loaded_loras[adapter_name] = True
        except Exception as e:
            print(f"  [WARNING] Could not load LORA {filename}: {e}")
            loaded_loras[adapter_name] = False

    success = any(loaded_loras.values())
    if not success:
        print("  [WARNING] No LORAs were loaded successfully.")
    return loaded_loras, success
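
# --- Illustrative sketch (added, not part of the original module) ---
# Once load_loras() reports which adapters loaded, a caller can blend them
# with diffusers' standard PEFT adapter API (requires the peft package).
# The helper name and the 0.8 default weight are arbitrary example choices.
def activate_loras(pipe, loaded_loras, weight=0.8):
    """Enable every successfully loaded LoRA adapter at the given weight."""
    active = [name for name, ok in loaded_loras.items() if ok]
    if active:
        pipe.set_adapters(active, adapter_weights=[weight] * len(active))
    return active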
""" print("Setting up IP-Adapter for InstantID face embeddings...") try: # Download InstantID weights ip_adapter_path = download_model_with_retry( "InstantX/InstantID", "ip-adapter.bin", repo_type="model" ) # Use the pipeline's built-in loader pipe.load_ip_adapter_instantid(ip_adapter_path) print(" [OK] IP-Adapter fully loaded via pipeline") return None, True # We don't need to return a model except Exception as e: print(f" [ERROR] Could not setup IP-Adapter: {e}") import traceback traceback.print_exc() return None, False # --- END FIX --- # --- START FIX: Replace setup_cappella with setup_compel --- def setup_compel(pipe): """Setup Compel for robust prompt encoding.""" print("Setting up Compel (prompt encoder)...") try: compel = Compel( tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True] ) print(" [OK] Compel loaded successfully.") return compel, True except Exception as e: print(f" [WARNING] Compel not available: {e}") return None, False # --- END FIX --- def setup_scheduler(pipe): """Setup LCM scheduler.""" print("Setting up LCM scheduler...") pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) print(" [OK] LCM scheduler configured") def optimize_pipeline(pipe): """Apply optimizations to pipeline.""" if device == "cuda": try: pipe.enable_xformers_memory_efficient_attention() print(" [OK] xformers enabled") except Exception as e: print(f" [INFO] xformers not available: {e}") def load_caption_model(): """ Load caption model with proper error handling. Tries multiple models in order of quality. """ print("Loading caption model...") # Try GIT-Large first (good balance of quality and compatibility) try: from transformers import AutoProcessor, AutoModelForCausalLM print(" Attempting GIT-Large (recommended)...") caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco") caption_model = AutoModelForCausalLM.from_pretrained( "microsoft/git-large-coco", torch_dtype=dtype ) print(" [OK] GIT-Large model loaded (produces detailed captions, on CPU)") return caption_processor, caption_model, True, 'git' except Exception as e1: print(f" [INFO] GIT-Large not available: {e1}") # Try BLIP base as fallback try: from transformers import BlipProcessor, BlipForConditionalGeneration print(" Attempting BLIP base (fallback)...") caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") caption_model = BlipForConditionalGeneration.from_pretrained( "Salesforce/blip-image-captioning-base", torch_dtype=dtype ) print(" [OK] BLIP base model loaded (standard captions, on CPU)") return caption_processor, caption_model, True, 'blip' except Exception as e2: print(f" [WARNING] Caption models not available: {e2}") print(" Caption generation will be disabled") return None, None, False, 'none' def set_clip_skip(pipe): """Set CLIP skip value.""" if hasattr(pipe, 'text_encoder'): print(f" [OK] CLIP skip set to {CLIP_SKIP}") print("[OK] Model loading functions ready")