"""
Model loading and initialization for Pixagram AI Pixel Art Generator.

HYBRID VERSION - supports both checkpoint files hosted in a HuggingFace repo
(loaded via ``from_single_file``) and diffusers-format HuggingFace repos
(loaded via ``from_pretrained``), with fallbacks to known public checkpoints.
"""
import torch
import time
import os
from diffusers import (
    ControlNetModel,
    AutoencoderKL,
    LCMScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPVisionModelWithProjection
from insightface.app import FaceAnalysis
from controlnet_aux import LeresDetector
from controlnet_aux.processor import Processor
from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType

# Import the custom pipeline that has load_ip_adapter_instantid method
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline

# NOTE(review): AutoencoderKL, AttnProcessor2_0, Processor and
# FACE_DETECTION_CONFIG are not referenced in this module; they may be
# re-imported from here elsewhere - verify before removing.
from config import (
    device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
    FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)


def download_model_with_retry(repo_id, filename, max_retries=None):
    """Download a file from the HuggingFace Hub, retrying on failure.

    Args:
        repo_id: HuggingFace model repository id.
        filename: Path of the file inside the repository.
        max_retries: Number of download attempts. Defaults to
            ``DOWNLOAD_CONFIG['max_retries']``. Values below 1 are treated
            as 1, so at least one attempt is always made.

    Returns:
        Local filesystem path returned by ``hf_hub_download``.

    Raises:
        Exception: whatever the final failed attempt raised, re-raised as-is.
    """
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']
    # A non-positive count would skip the loop and silently return None,
    # which callers would then hand on as a "path". Always try at least once.
    max_retries = max(1, max_retries)

    # Only send a token when one is configured, so public repos still work
    # anonymously.
    kwargs = {"repo_type": "model"}
    if HUGGINGFACE_TOKEN:
        kwargs["token"] = HUGGINGFACE_TOKEN

    for attempt in range(max_retries):
        try:
            print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
            path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **kwargs
            )
            print(f" [OK] Downloaded: {filename}")
            return path
        except Exception as e:
            print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
                time.sleep(DOWNLOAD_CONFIG['retry_delay'])
            else:
                print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
                raise


def load_face_analysis():
    """Load the InsightFace ``antelopev2`` face analysis model.

    Returns:
        ``(face_app, True)`` on success, or ``(None, False)`` if loading or
        preparing the model fails (the error is printed, not raised).
    """
    print("Loading face analysis model...")
    try:
        # NOTE(review): model files are resolved under '/data'; presumably a
        # persistent volume in the deployment - confirm.
        face_app = FaceAnalysis(
            name='antelopev2',
            root='/data',
            providers=['CPUExecutionProvider']
        )
        # NOTE(review): ctx_id=0 is passed while only the CPU execution
        # provider is enabled - confirm this is intended.
        face_app.prepare(
            ctx_id=0,
            det_size=(640, 640)
        )
        print(" [OK] Face analysis model loaded successfully")
        return face_app, True
    except Exception as e:
        print(f" [WARNING] Face detection not available: {e}")
        return None, False


def load_depth_detector():
    """Load the depth-map annotator from ``lllyasviel/Annotators``.

    NOTE(review): the log messages say "Zoe Depth", but the class actually
    loaded is ``LeresDetector`` (LeReS), while ``load_controlnets`` loads the
    ``controlnet-zoe-depth-sdxl-1.0`` ControlNet. Confirm which depth
    estimator is intended (``controlnet_aux`` also provides ``ZoeDetector``).

    Returns:
        ``(detector, True)`` on success, or ``(None, False)`` on any error.
    """
    print("Loading Zoe Depth detector...")
    try:
        depth_detector = LeresDetector.from_pretrained(
            "lllyasviel/Annotators"
        )
        depth_detector.to(device)
        print(" [OK] Zoe Depth loaded successfully")
        return depth_detector, True
    except Exception as e:
        print(f" [WARNING] Zoe Depth not available: {e}")
        return None, False


def load_controlnets():
    """Load the depth ControlNet and, if available, the InstantID ControlNet.

    The depth ControlNet is required: a failure to load it propagates to the
    caller. The InstantID ControlNet is optional.

    Returns:
        ``(controlnet_depth, controlnet_instantid, True)`` on full success, or
        ``(controlnet_depth, None, False)`` if InstantID could not be loaded.
    """
    print("Loading ControlNet Zoe Depth model...")
    controlnet_depth = ControlNetModel.from_pretrained(
        "diffusers/controlnet-zoe-depth-sdxl-1.0",
        torch_dtype=dtype
    ).to(device)
    print(" [OK] ControlNet Depth loaded")

    print("Loading InstantID ControlNet...")
    try:
        controlnet_instantid = ControlNetModel.from_pretrained(
            "InstantX/InstantID",
            subfolder="ControlNetModel",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] InstantID ControlNet loaded successfully")
        return controlnet_depth, controlnet_instantid, True
    except Exception as e:
        print(f" [WARNING] InstantID ControlNet not available: {e}")
        return controlnet_depth, None, False


def load_image_encoder():
    """Load the CLIP vision encoder from ``h94/IP-Adapter`` for IP-Adapter.

    Returns:
        The encoder moved to ``device``, or ``None`` if loading fails.
    """
    print("Loading CLIP Image Encoder for IP-Adapter...")
    try:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter",
            subfolder="models/image_encoder",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] CLIP Image Encoder loaded successfully")
        return image_encoder
    except Exception as e:
        print(f" [ERROR] Could not load image encoder: {e}")
        return None


def load_sdxl_pipeline(controlnets):
    """Load the SDXL InstantID img2img pipeline - HYBRID APPROACH.

    Tries in order:
      1. The checkpoint file ``MODEL_FILES['checkpoint']`` downloaded from
         ``MODEL_REPO`` and loaded via ``from_single_file`` (only when that
         key is configured and the file is a ``.safetensors`` file).
      2. ``MODEL_REPO`` as a diffusers-format repo via ``from_pretrained``.
      3. Fallback checkpoint ``frankjoshua/albedobaseXL_v21``.
      4. ``stabilityai/stable-diffusion-xl-base-1.0`` (not guarded; errors
         from this last attempt propagate to the caller).

    Args:
        controlnets: ControlNet model(s) passed through as ``controlnet=``.

    Returns:
        ``(pipe, True)`` when the configured model (attempt 1 or 2) loaded,
        ``(pipe, False)`` when a fallback (attempt 3 or 4) was used.
    """
    print("Loading SDXL checkpoint (hybrid approach)...")

    # ATTEMPT 1: single checkpoint file via from_single_file
    # This is the examplemodels.py approach
    if MODEL_FILES.get('checkpoint'):
        try:
            print(" [Attempt 1] Loading from local file via from_single_file...")
            model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])

            # Only safetensors files are handed to from_single_file here.
            if model_path and os.path.exists(model_path) and model_path.endswith('.safetensors'):
                pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
                    model_path,
                    controlnet=controlnets,
                    torch_dtype=dtype,
                    use_safetensors=True
                ).to(device)
                print(f" [OK] Checkpoint loaded from local file: {model_path}")
                return pipe, True
            else:
                print(" [INFO] Local file not found or invalid, trying next method...")
        except Exception as e:
            print(f" [WARNING] from_single_file failed: {e}")
            print(" [INFO] Trying from_pretrained approach...")

    # ATTEMPT 2: diffusers-format HuggingFace repo via from_pretrained
    # This is the exampleapp.py approach
    try:
        print(" [Attempt 2] Loading from HuggingFace repo via from_pretrained...")
        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
            MODEL_REPO,
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
        ).to(device)
        print(f" [OK] Checkpoint loaded from HuggingFace repo: {MODEL_REPO}")
        return pipe, True
    except Exception as e:
        print(f" [WARNING] from_pretrained failed: {e}")
        print(" [INFO] Trying fallback checkpoint...")

    # ATTEMPT 3: fallback to a known working checkpoint
    try:
        print(" [Attempt 3] Loading fallback: frankjoshua/albedobaseXL_v21...")
        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
            "frankjoshua/albedobaseXL_v21",
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
        ).to(device)
        print(" [OK] Fallback checkpoint loaded successfully")
        return pipe, False
    except Exception as e:
        print(f" [WARNING] Fallback also failed: {e}")
        print(" [INFO] Trying SDXL base model...")

    # ATTEMPT 4: last resort - SDXL base. Deliberately unguarded so a total
    # failure surfaces to the caller.
    print(" [Attempt 4] Loading base SDXL model...")
    pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnets,
        torch_dtype=dtype,
        use_safetensors=True
    ).to(device)
    print(" [OK] Base SDXL model loaded")
    return pipe, False


def load_lora(pipe):
    """Download ``MODEL_FILES['lora']`` and load it into ``pipe``.

    The LoRA is registered under the adapter name ``"retroart"``.

    Returns:
        True on success, False on any error (the error is printed).
    """
    print("Loading LORA (retroart) from HuggingFace Hub...")
    try:
        lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        pipe.load_lora_weights(lora_path, adapter_name="retroart")
        print(" [OK] LORA loaded successfully")
        return True
    except Exception as e:
        print(f" [WARNING] Could not load LORA: {e}")
        return False


def setup_ip_adapter(pipe):
    """Set up the InstantID IP-Adapter on ``pipe``.

    Downloads ``ip-adapter.bin`` from ``InstantX/InstantID`` and loads it via
    the custom pipeline's ``load_ip_adapter_instantid`` method, then sets an
    initial adapter scale of 0.8.

    Returns:
        True on success, False on any error (traceback is printed).
    """
    print("Setting up IP-Adapter for InstantID face embeddings...")
    try:
        face_adapter_path = download_model_with_retry(
            "InstantX/InstantID",
            "ip-adapter.bin"
        )

        # The custom pipeline method performs the adapter/Resampler wiring.
        pipe.load_ip_adapter_instantid(face_adapter_path)

        # Initial scale; callers may adjust it later during generation.
        pipe.set_ip_adapter_scale(0.8)

        print(" [OK] IP-Adapter loaded successfully with built-in method")
        print(" - Pipeline handles Resampler and attention processors automatically")
        print(" - Face embeddings will be properly integrated during generation")
        return True
    except Exception as e:
        print(f" [ERROR] Could not setup IP-Adapter: {e}")
        import traceback
        traceback.print_exc()
        return False


def setup_compel(pipe):
    """Create a Compel prompt encoder over both SDXL tokenizers/text encoders.

    Uses non-normalized penultimate hidden states, and requests pooled
    embeddings only from the second text encoder.

    Returns:
        ``(compel, True)`` on success, or ``(None, False)`` on any error.
    """
    print("Setting up Compel for enhanced prompt processing...")
    try:
        compel = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True]
        )
        print(" [OK] Compel loaded successfully")
        return compel, True
    except Exception as e:
        print(f" [WARNING] Compel not available: {e}")
        return None, False


def setup_scheduler(pipe):
    """Replace ``pipe.scheduler`` with an LCMScheduler built from its config."""
    print("Setting up LCM scheduler...")
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    print(" [OK] LCM scheduler configured")


def optimize_pipeline(pipe):
    """Apply best-effort runtime optimizations to ``pipe``.

    Enables xformers memory-efficient attention when running on a CUDA
    device. A failure is reported and otherwise ignored.
    """
    # ``device`` may be "cuda", an indexed string like "cuda:0", or a
    # torch.device; a plain ``== "cuda"`` comparison misses the latter two.
    if str(device).startswith("cuda"):
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print(" [OK] xformers enabled")
        except Exception as e:
            print(f" [INFO] xformers not available: {e}")


def load_caption_model():
    """Load an image-captioning model, trying GIT-Large then BLIP base.

    Returns:
        ``(processor, model, True, 'git')`` or ``(processor, model, True,
        'blip')`` for whichever model loaded first, or
        ``(None, None, False, 'none')`` if neither could be loaded.
    """
    print("Loading caption model...")

    # Try GIT-Large first (good balance of quality and compatibility)
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM

        print(" Attempting GIT-Large (recommended)...")
        caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        caption_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] GIT-Large model loaded (produces detailed captions)")
        return caption_processor, caption_model, True, 'git'
    except Exception as e1:
        print(f" [INFO] GIT-Large not available: {e1}")

    # Try BLIP base as fallback
    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration

        print(" Attempting BLIP base (fallback)...")
        caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] BLIP base model loaded (standard captions)")
        return caption_processor, caption_model, True, 'blip'
    except Exception as e2:
        print(f" [WARNING] Caption models not available: {e2}")
        print(" Caption generation will be disabled")
        return None, None, False, 'none'


def set_clip_skip(pipe):
    """Report the configured ``CLIP_SKIP`` value.

    NOTE(review): this function does not modify ``pipe``; it only prints
    ``CLIP_SKIP`` when the pipeline has a ``text_encoder`` attribute. In
    diffusers SDXL pipelines clip skip is normally passed per call
    (``pipe(..., clip_skip=...)``); confirm the generation code does that,
    otherwise this setting has no effect.
    """
    if hasattr(pipe, 'text_encoder'):
        print(f" [OK] CLIP skip set to {CLIP_SKIP}")


print("[OK] Model loading functions ready (HYBRID VERSION)")