Spaces:
Runtime error
Runtime error
| """ | |
| Model loading and initialization for Pixagram AI Pixel Art Generator | |
| HYBRID VERSION - Supports both local files and HuggingFace repos | |
| MODIFIED for IP-Adapter-FaceIDXL (non-plus) and LCM Scheduler | |
| """ | |
| import torch | |
| import time | |
| import os | |
| from diffusers import ( | |
| ControlNetModel, | |
| AutoencoderKL, | |
| LCMScheduler, # Changed back to LCM | |
| StableDiffusionXLControlNetImg2ImgPipeline | |
| ) | |
| from diffusers.models.attention_processor import AttnProcessor2_0 | |
| from transformers import CLIPVisionModelWithProjection, pipeline | |
| from insightface.app import FaceAnalysis | |
| from controlnet_aux import LeresDetector, CannyDetector | |
| from huggingface_hub import hf_hub_download | |
| from compel import Compel, ReturnedEmbeddingsType | |
| # Import the IP-Adapter wrapper classes | |
| try: | |
| # Import base class and the specific SDXL class | |
| from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDXL | |
| except ImportError: | |
| print("="*80) | |
| print("[FATAL ERROR] `ip_adapter` library not found.") | |
| print("Please install it: pip install ip-adapter") | |
| print("="*80) | |
| raise | |
| from config import ( | |
| device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN, | |
| FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG | |
| ) | |
def download_model_with_retry(repo_id, filename, max_retries=None):
    """Fetch a single file from a HuggingFace model repo, retrying on failure.

    Args:
        repo_id: Repository identifier passed to ``hf_hub_download``.
        filename: Name of the file inside the repository.
        max_retries: Number of attempts; when ``None`` the value of
            ``DOWNLOAD_CONFIG['max_retries']`` is used.

    Returns:
        The path returned by ``hf_hub_download``, or ``None`` when the
        attempt count is zero or negative (the loop never runs).

    Raises:
        Exception: Whatever the final attempt raised, re-raised unchanged.
    """
    total_attempts = DOWNLOAD_CONFIG['max_retries'] if max_retries is None else max_retries

    for attempt_no in range(1, total_attempts + 1):
        try:
            print(f" Attempting to download {filename} (attempt {attempt_no}/{total_attempts})...")
            hub_kwargs = {"repo_type": "model"}
            # Only forward a token when one is configured.
            if HUGGINGFACE_TOKEN:
                hub_kwargs["token"] = HUGGINGFACE_TOKEN
            local_path = hf_hub_download(repo_id=repo_id, filename=filename, **hub_kwargs)
            print(f" [OK] Downloaded: {filename}")
            return local_path
        except Exception as err:
            print(f" [WARNING] Download attempt {attempt_no} failed: {err}")
            if attempt_no >= total_attempts:
                # Out of attempts: report and propagate the last error.
                print(f" [ERROR] Failed to download {filename} after {total_attempts} attempts")
                raise
            print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
            time.sleep(DOWNLOAD_CONFIG['retry_delay'])

    return None
def load_face_analysis():
    """Create and prepare the insightface ``FaceAnalysis`` app (buffalo_l).

    The app is rooted at ``/data`` and requests the CUDA and CPU execution
    providers; it is prepared with ``ctx_id=0`` and a 640x640 detection size.

    Returns:
        tuple: ``(face_app, True)`` on success, ``(None, False)`` if any
        step raised.
    """
    print("Loading face analysis model (buffalo_l)...")
    try:
        analyzer = FaceAnalysis(
            name='buffalo_l',
            root='/data',
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        analyzer.prepare(ctx_id=0, det_size=(640, 640))
    except Exception as err:
        # Face detection is optional: report and let the caller carry on.
        print(f" [WARNING] Face detection not available: {err}")
        return None, False
    else:
        print(" [OK] Face analysis model (buffalo_l) loaded successfully")
        return analyzer, True
def load_depth_detector():
    """Load the LeReS depth annotator and move it to the configured device.

    Returns:
        tuple: ``(detector, True)`` on success, ``(None, False)`` if loading
        or the device transfer raised.
    """
    print("Loading LeReS++ detector...")
    try:
        depth_detector = LeresDetector.from_pretrained("lllyasviel/Annotators")
        depth_detector.to(device)
    except Exception as err:
        print(f" [WARNING] LeReS++ not available: {err}")
        return None, False
    else:
        print(" [OK] LeReS++ loaded successfully")
        return depth_detector, True
def load_canny_detector():
    """Instantiate the Canny edge detector.

    Returns:
        tuple: ``(detector, True)`` on success, ``(None, False)`` if the
        constructor raised.
    """
    print("Loading Canny detector...")
    try:
        edge_detector = CannyDetector()
    except Exception as err:
        print(f" [WARNING] Canny detector not available: {err}")
        return None, False
    else:
        print(" [OK] Canny loaded successfully")
        return edge_detector, True
def load_controlnets():
    """Load the SDXL Depth and Canny ControlNet models onto the device.

    The Depth model is mandatory: a failure while loading it propagates to
    the caller. The Canny model is optional and degrades to ``None``.

    Returns:
        tuple: ``(depth_controlnet, canny_controlnet, True)`` when both load,
        or ``(depth_controlnet, None, False)`` when only Canny failed.
    """

    def _load(repo_id):
        # Both models are loaded with the configured dtype, then moved to the device.
        model = ControlNetModel.from_pretrained(repo_id, torch_dtype=dtype)
        return model.to(device)

    print("Loading ControlNet Depth model...")
    depth_net = _load("diffusers/controlnet-depth-sdxl-1.0")
    print(" [OK] ControlNet Depth loaded")

    print("Loading ControlNet Canny model...")
    try:
        canny_net = _load("diffusers/controlnet-canny-sdxl-1.0")
    except Exception as err:
        print(f" [WARNING] ControlNet Canny not available: {err}")
        return depth_net, None, False
    else:
        print(" [OK] ControlNet Canny loaded successfully")
        return depth_net, canny_net, True
def load_image_encoder():
    """[DEPRECATED] Placeholder kept for components that may still call it.

    IPAdapterFaceIDXL does not need a CLIP image encoder, so nothing is
    loaded: the function only logs that the step is skipped.

    Returns:
        None: always.
    """
    skip_notice = "Loading CLIP Image Encoder [SKIPPED - Not required by IPAdapterFaceIDXL]"
    print(skip_notice)
    return None
def load_sdxl_pipeline(controlnets):
    """Build the SDXL ControlNet img2img pipeline, trying three sources in order.

    1. The single-file checkpoint named by ``MODEL_FILES['checkpoint']``,
       fetched from ``MODEL_REPO`` via ``download_model_with_retry``.
    2. ``MODEL_REPO`` loaded as a diffusers repository.
    3. The base ``stabilityai/stable-diffusion-xl-base-1.0`` model.

    No separate VAE is passed; the checkpoint's own VAE is used.

    Args:
        controlnets: Value forwarded as the pipeline's ``controlnet`` argument.

    Returns:
        tuple: ``(pipe, True)`` when the custom checkpoint/repo was loaded,
        ``(pipe, False)`` when the base SDXL fallback was used.

    Raises:
        Exception: If the final base-model fallback fails (it is not guarded).
    """
    print("Loading SDXL checkpoint (using built-in VAE)...")

    pipeline_cls = StableDiffusionXLControlNetImg2ImgPipeline
    load_options = dict(
        controlnet=controlnets,
        torch_dtype=dtype,
        use_safetensors=True,
    )

    # ATTEMPT 1: single-file checkpoint. The log line says "local file",
    # but the file is obtained through the hub download helper.
    checkpoint_name = MODEL_FILES.get('checkpoint')
    if checkpoint_name:
        try:
            print(f" [Attempt 1] Loading from local file: {MODEL_FILES['checkpoint']}...")
            ckpt_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
            if not ckpt_path or not os.path.exists(ckpt_path) or not ckpt_path.endswith('.safetensors'):
                print(f" [INFO] Local file not found or invalid...")
            else:
                single_file_pipe = pipeline_cls.from_single_file(ckpt_path, **load_options)
                single_file_pipe = single_file_pipe.to(device)
                print(f" [OK] Checkpoint loaded from local file: {ckpt_path}")
                return single_file_pipe, True
        except Exception as err:
            print(f" [WARNING] from_single_file failed: {err}")

    # ATTEMPT 2: treat MODEL_REPO as a diffusers repository.
    try:
        print(f" [Attempt 2] Loading from HuggingFace repo: {MODEL_REPO}...")
        repo_pipe = pipeline_cls.from_pretrained(MODEL_REPO, **load_options)
        repo_pipe = repo_pipe.to(device)
        print(f" [OK] Checkpoint loaded from HuggingFace repo: {MODEL_REPO}")
        return repo_pipe, True
    except Exception as err:
        print(f" [WARNING] from_pretrained failed: {err}")

    # ATTEMPT 3: base SDXL; errors here propagate to the caller.
    print(f" [Attempt 3] Loading base SDXL model...")
    base_pipe = pipeline_cls.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        **load_options
    )
    base_pipe = base_pipe.to(device)
    print(" [OK] Base SDXL model loaded")
    return base_pipe, False
def load_lora(pipe):
    """Download the LoRA named by ``MODEL_FILES['lora']`` and attach it to ``pipe``.

    The weights are registered under the adapter name ``"retroart"``.

    Args:
        pipe: Pipeline object exposing ``load_lora_weights``.

    Returns:
        bool: ``True`` when the weights were loaded, ``False`` if the
        download or the load raised.
    """
    print("Loading LORA (retroart) from HuggingFace Hub...")
    try:
        weights_file = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        pipe.load_lora_weights(weights_file, adapter_name="retroart")
    except Exception as err:
        print(f" [WARNING] Could not load LORA: {err}")
        return False
    else:
        print(f" [OK] LORA loaded successfully")
        return True
def setup_ip_adapter(pipe):
    """Set up the IP-Adapter-FaceIDXL wrapper around ``pipe``.

    Downloads the SDXL (non-plus) FaceID checkpoint from
    ``h94/IP-Adapter-FaceID`` and instantiates ``IPAdapterFaceIDXL`` without
    an image encoder path.

    The checkpoint is fetched through ``download_model_with_retry`` so this
    download behaves like the others in this module: it is retried on
    failure, and the HuggingFace token is only sent when one is configured.

    Args:
        pipe: The SDXL pipeline to wrap.

    Returns:
        tuple: ``(ip_model, True)`` on success, ``(None, False)`` if the
        download or the wrapper construction raised (traceback is printed).
    """
    print("Setting up IP-Adapter-FaceIDXL...")
    try:
        # Download the SDXL non-plus FaceID model.
        ip_ckpt_path = download_model_with_retry(
            "h94/IP-Adapter-FaceID",
            "ip-adapter-faceid_sdxl.bin",
        )
        # Instantiated without image_encoder_path.
        ip_model = IPAdapterFaceIDXL(pipe, ip_ckpt_path, device)
        print(" [OK] IPAdapterFaceIDXL wrapper initialized successfully.")
        return ip_model, True
    except Exception as e:
        print(f" [ERROR] Could not setup IP-Adapter: {e}")
        # Imported lazily: only needed on the failure path.
        import traceback
        traceback.print_exc()
        return None, False
def setup_compel(pipe):
    """Create a Compel prompt processor wired to both SDXL text encoders.

    Args:
        pipe: Pipeline exposing ``tokenizer``, ``tokenizer_2``,
            ``text_encoder`` and ``text_encoder_2``.

    Returns:
        tuple: ``(compel, True)`` on success, ``(None, False)`` if
        construction raised.
    """
    print("Setting up Compel for enhanced prompt processing...")
    try:
        tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
        encoders = [pipe.text_encoder, pipe.text_encoder_2]
        prompt_processor = Compel(
            tokenizer=tokenizers,
            text_encoder=encoders,
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            # Pooled output is requested from the second encoder only.
            requires_pooled=[False, True],
        )
    except Exception as err:
        print(f" [WARNING] Compel not available: {err}")
        return None, False
    else:
        print(" [OK] Compel loaded successfully")
        return prompt_processor, True
def setup_scheduler(pipe):
    """Replace the pipeline's scheduler with an LCM scheduler.

    The new scheduler is built from the configuration of the scheduler
    currently attached to ``pipe``; ``pipe`` is modified in place.

    Args:
        pipe: Pipeline whose ``scheduler`` attribute is replaced.
    """
    print("Setting up LCM scheduler...")
    existing_config = pipe.scheduler.config
    pipe.scheduler = LCMScheduler.from_config(existing_config)
    print(" [OK] LCM scheduler configured")
def optimize_pipeline(pipe):
    """Apply optional optimizations to the pipeline.

    Currently this only tries to enable xformers memory-efficient attention,
    and only when the configured device is ``"cuda"``. Failure is logged and
    ignored.

    Args:
        pipe: Pipeline exposing ``enable_xformers_memory_efficient_attention``.
    """
    if device != "cuda":
        return

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as err:
        print(f" [INFO] xformers not available: {err}")
    else:
        print(" [OK] xformers enabled")
def load_caption_model():
    """Load an image-captioning model, trying GIT-Large then BLIP base.

    Returns:
        tuple: ``(processor, model, enabled, kind)`` where ``kind`` is
        ``'git'``, ``'blip'`` or ``'none'``. When both candidates fail the
        result is ``(None, None, False, 'none')``.
    """
    print("Loading caption model...")

    # First choice: GIT-Large.
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM
        print(" Attempting GIT-Large (recommended)...")
        git_repo = "microsoft/git-large-coco"
        git_processor = AutoProcessor.from_pretrained(git_repo)
        git_model = AutoModelForCausalLM.from_pretrained(git_repo, torch_dtype=dtype)
        git_model = git_model.to(device)
        print(" [OK] GIT-Large model loaded (produces detailed captions)")
        return git_processor, git_model, True, 'git'
    except Exception as git_err:
        print(f" [INFO] GIT-Large not available: {git_err}")

    # Fallback: BLIP base.
    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration
        print(" Attempting BLIP base (fallback)...")
        blip_repo = "Salesforce/blip-image-captioning-base"
        blip_processor = BlipProcessor.from_pretrained(blip_repo)
        blip_model = BlipForConditionalGeneration.from_pretrained(blip_repo, torch_dtype=dtype)
        blip_model = blip_model.to(device)
        print(" [OK] BLIP base model loaded (standard captions)")
        return blip_processor, blip_model, True, 'blip'
    except Exception as blip_err:
        print(f" [WARNING] Caption models not available: {blip_err}")
        print(" Caption generation will be disabled")

    return None, None, False, 'none'
def set_clip_skip(pipe):
    """Log the configured CLIP skip value for pipelines with a text encoder.

    NOTE(review): despite the name and the log message, nothing is assigned
    on ``pipe`` here; the function only prints ``CLIP_SKIP``. Presumably the
    value is applied elsewhere (e.g. at generation time) - confirm.

    Args:
        pipe: Pipeline object; nothing happens if it has no ``text_encoder``.
    """
    if not hasattr(pipe, 'text_encoder'):
        return
    print(f" [OK] CLIP skip set to {CLIP_SKIP}")
# Import-time banner: printed once when this module has finished defining its loaders.
print("[OK] Model loading functions ready (IP-Adapter-FaceIDXL / LCM VERSION)")