| """ | |
| Model loading and initialization for Pixagram AI Pixel Art Generator | |
| Torch 2.1.1 optimized with Depth Anything V2 | |
| """ | |
| import torch | |
| import time | |
| from diffusers import ( | |
| StableDiffusionXLControlNetImg2ImgPipeline, | |
| ControlNetModel, | |
| AutoencoderKL, | |
| LCMScheduler | |
| ) | |
| from diffusers.models.attention_processor import AttnProcessor2_0 | |
| from transformers import CLIPVisionModelWithProjection | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
| from insightface.app import FaceAnalysis | |
| from controlnet_aux import ZoeDetector | |
| from huggingface_hub import hf_hub_download | |
| from compel import Compel, ReturnedEmbeddingsType | |
| from ip_attention_processor_compatible import IPAttnProcessorCompatible as IPAttnProcessor2_0 | |
| from resampler_compatible import create_compatible_resampler | |
| from config import ( | |
| device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN, | |
| FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG | |
| ) | |
def download_model_with_retry(repo_id, filename, max_retries=None):
    """Download a model file from the Hub with retry logic and proper token handling."""
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']
    for attempt in range(max_retries):
        try:
            print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
            kwargs = {"repo_type": "model"}
            if HUGGINGFACE_TOKEN:
                kwargs["token"] = HUGGINGFACE_TOKEN
            path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **kwargs
            )
            print(f" [OK] Downloaded: {filename}")
            return path
        except Exception as e:
            print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
                time.sleep(DOWNLOAD_CONFIG['retry_delay'])
            else:
                # The final failed attempt re-raises, so the loop never falls through
                print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
                raise

def load_face_analysis():
    """
    Load face analysis with GPU/CPU fallback.

    InsightFace initialization often fails on GPU (e.g. missing CUDA provider),
    so the CPU fallback is essential.
    """
    print("Loading face analysis model...")
    # Try GPU first
    try:
        face_app = FaceAnalysis(
            name=FACE_DETECTION_CONFIG['model_name'],
            root='./models/insightface',
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        face_app.prepare(
            ctx_id=FACE_DETECTION_CONFIG['ctx_id'],
            det_size=FACE_DETECTION_CONFIG['det_size']
        )
        print(" [OK] Face analysis loaded (GPU)")
        return face_app, True
    except Exception as e:
        print(f" [WARNING] GPU face detection failed: {e}")

    # Fall back to CPU
    try:
        print(" [INFO] Trying CPU fallback...")
        face_app = FaceAnalysis(
            name=FACE_DETECTION_CONFIG['model_name'],
            root='./models/insightface',
            providers=['CPUExecutionProvider']
        )
        face_app.prepare(
            ctx_id=-1,  # CPU context
            det_size=FACE_DETECTION_CONFIG['det_size']
        )
        print(" [OK] Face analysis loaded (CPU fallback)")
        return face_app, True
    except Exception as e:
        print(f" [ERROR] Face detection not available: {e}")
        import traceback
        traceback.print_exc()
        return None, False

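# Usage sketch (illustrative only; example_extract_face_embedding is a
# hypothetical helper, not part of the app's pipeline). InsightFace expects a
# BGR numpy array; each detected face exposes a 512-d normed_embedding (the
# identity vector InstantID conditions on) plus bbox and kps keypoints.
def example_extract_face_embedding(face_app, bgr_image):
    """Return the embedding of the largest detected face, or None."""
    faces = face_app.get(bgr_image)
    if not faces:
        return None
    # Prefer the largest face by bounding-box area
    largest = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
    return largest.normed_embedding
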
def load_depth_anything_v2():
    """
    Load Depth Anything V2: roughly 3-5x faster than ZoeDepth with sharper
    details; the Small variant is Apache 2.0 licensed.
    """
    print("Loading Depth Anything V2 (3-5x faster than Zoe)...")
    try:
        from transformers import pipeline
        depth_pipe = pipeline(
            task="depth-estimation",
            # The "-hf" repo is the transformers-compatible packaging of the
            # Small checkpoint; the plain repo hosts the original .pth weights
            model="depth-anything/Depth-Anything-V2-Small-hf",
            device=0 if device == "cuda" else -1
        )
        print(" [OK] Depth Anything V2 loaded (state-of-the-art quality)")
        return depth_pipe, True
    except Exception as e:
        print(f" [WARNING] Depth Anything V2 not available: {e}")
        return None, False

def load_depth_detector():
    """
    Load a depth detector with a fallback chain:
    1. Depth Anything V2 (fastest, best quality)
    2. Zoe Depth (fallback)
    3. Grayscale (emergency fallback, handled by the caller)
    """
    # Try Depth Anything V2 first
    depth_anything, success = load_depth_anything_v2()
    if success:
        return depth_anything, True, "depth_anything_v2"

    # Fall back to Zoe
    print("Loading Zoe Depth detector (fallback)...")
    try:
        zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        zoe_depth.to(device)
        print(" [OK] Zoe Depth loaded")
        return zoe_depth, True, "zoe"
    except Exception as e:
        print(f" [WARNING] Zoe Depth not available: {e}")
        return None, False, "grayscale"

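# Usage sketch (illustrative only; example_compute_depth_map is a hypothetical
# helper). The transformers depth-estimation pipeline returns a dict whose
# "depth" entry is a PIL image, ZoeDetector is directly callable on a PIL
# image, and "grayscale" signals that no detector is available, so the caller
# falls back to a luminance map.
def example_compute_depth_map(detector, detector_type, pil_image):
    """Return a depth map as a PIL image for any of the three backends."""
    if detector_type == "depth_anything_v2":
        return detector(pil_image)["depth"]
    if detector_type == "zoe":
        return detector(pil_image)
    return pil_image.convert("L")  # emergency grayscale fallback
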
def load_controlnets():
    """Load ControlNet models."""
    print("Loading ControlNet Zoe Depth model...")
    controlnet_depth = ControlNetModel.from_pretrained(
        "diffusers/controlnet-zoe-depth-sdxl-1.0",
        torch_dtype=dtype
    ).to(device)
    print(" [OK] ControlNet Depth loaded")

    print("Loading InstantID ControlNet...")
    try:
        controlnet_instantid = ControlNetModel.from_pretrained(
            "InstantX/InstantID",
            subfolder="ControlNetModel",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] InstantID ControlNet loaded")
        return controlnet_depth, controlnet_instantid, True
    except Exception as e:
        print(f" [WARNING] InstantID ControlNet not available: {e}")
        return controlnet_depth, None, False

def load_image_encoder():
    """Load the CLIP image encoder for the IP-Adapter."""
    print("Loading CLIP Image Encoder...")
    try:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter",
            subfolder="models/image_encoder",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] CLIP Image Encoder loaded")
        return image_encoder
    except Exception as e:
        print(f" [ERROR] Could not load image encoder: {e}")
        return None

def load_sdxl_pipeline(controlnets):
    """Load the SDXL checkpoint, falling back to the base model."""
    print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
    try:
        model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
            model_path,
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
        ).to(device)
        print(" [OK] Custom checkpoint loaded")
        return pipe, True
    except Exception as e:
        print(f" [WARNING] Could not load custom checkpoint: {e}")
        print(" Using default SDXL base")
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
        ).to(device)
        return pipe, False

def load_lora(pipe):
    """Load the retroart LoRA."""
    print("Loading LoRA (retroart) from HuggingFace Hub...")
    try:
        lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        pipe.load_lora_weights(lora_path)
        print(" [OK] LoRA loaded")
        return True
    except Exception as e:
        print(f" [WARNING] Could not load LoRA: {e}")
        return False

def setup_ip_adapter(pipe, image_encoder):
    """Set up the IP-Adapter with a compatible architecture."""
    if image_encoder is None:
        return None, False
    print("Setting up IP-Adapter...")
    try:
        ip_adapter_path = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
        ip_adapter_state_dict = torch.load(ip_adapter_path, map_location="cpu")

        # Split the checkpoint into image-projection and attention weights
        image_proj_state_dict = {}
        ip_state_dict = {}
        for key, value in ip_adapter_state_dict.items():
            if key.startswith("image_proj."):
                image_proj_state_dict[key.replace("image_proj.", "")] = value
            elif key.startswith("ip_adapter."):
                ip_state_dict[key.replace("ip_adapter.", "")] = value

        print("Creating Compatible Perceiver Resampler...")
        # Create the resampler with a compatible architecture
        image_proj_model = create_compatible_resampler(
            num_queries=4,
            embedding_dim=512,
            output_dim=pipe.unet.config.cross_attention_dim,
            device=device,
            dtype=dtype
        )

        # Load pretrained weights when present
        try:
            if 'latents' in image_proj_state_dict:
                image_proj_model.load_state_dict(image_proj_state_dict, strict=False)
                print(" [OK] Resampler loaded with pretrained weights")
            else:
                print(" [INFO] Using randomly initialized Resampler")
        except Exception as e:
            print(f" [INFO] Resampler weights: {e}")

        # Set up attention processors: self-attention layers keep the standard
        # SDPA processor, cross-attention layers route the extra image tokens
        # through the IP-Adapter processor
        attn_procs = {}
        for name in pipe.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = pipe.unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor2_0()
            else:
                attn_procs[name] = IPAttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=4
                ).to(device, dtype=dtype)
        pipe.unet.set_attn_processor(attn_procs)

        # Load the InstantID attention weights into the new processors
        ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(ip_state_dict, strict=False)
        print(" [OK] IP-Adapter loaded with InstantID weights")

        pipe.image_encoder = image_encoder
        return image_proj_model, True
    except Exception as e:
        print(f" [ERROR] Could not load IP-Adapter: {e}")
        import traceback
        traceback.print_exc()
        return None, False

def setup_compel(pipe):
    """Set up Compel for weighted-prompt parsing over both SDXL text encoders."""
    print("Setting up Compel...")
    try:
        compel = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True]
        )
        print(" [OK] Compel loaded")
        return compel, True
    except Exception as e:
        print(f" [WARNING] Compel not available: {e}")
        return None, False

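# Usage sketch (illustrative only; example_encode_prompt is a hypothetical
# helper). With requires_pooled=[False, True], calling the Compel object
# returns both the token embeddings and the pooled embedding that SDXL
# pipelines expect, and weighting syntax such as "pixel art++" is parsed
# into attention emphasis.
def example_encode_prompt(compel, prompt):
    """Return (prompt_embeds, pooled_prompt_embeds) for an SDXL pipeline call."""
    conditioning, pooled = compel(prompt)
    # Passed on as: pipe(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, ...)
    return conditioning, pooled
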
def setup_scheduler(pipe):
    """Swap in the LCM scheduler."""
    print("Setting up LCM scheduler...")
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    print(" [OK] LCM scheduler configured")

def optimize_pipeline(pipe):
    """Apply torch 2.1.1 optimizations."""
    # Enable SDPA attention. Note: set_attn_processor replaces ALL attention
    # processors, including any IP-Adapter processors installed by
    # setup_ip_adapter(), so call order matters when both are used.
    pipe.unet.set_attn_processor(AttnProcessor2_0())

    # xformers memory-efficient attention (overrides SDPA when available)
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print(" [OK] xformers enabled")
        except Exception as e:
            print(f" [INFO] xformers not available: {e}")

    # torch 2.1.1: compile the UNet for a 50-100% speedup on repeated inference
    if hasattr(torch, 'compile') and device == "cuda":
        try:
            print(" [TORCH 2.1] Compiling UNet (first run +30s, then 50-100% faster)...")
            pipe.unet = torch.compile(
                pipe.unet,
                mode="reduce-overhead",  # faster for repeated inference
                fullgraph=False  # more stable with ControlNet
            )
            print(" [OK] UNet compiled")
        except Exception as e:
            print(f" [INFO] torch.compile not available: {e}")

def load_caption_model():
    """Load the BLIP captioning model."""
    print("Loading BLIP model...")
    try:
        caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] BLIP model loaded")
        return caption_processor, caption_model, True
    except Exception as e:
        print(f" [WARNING] BLIP not available: {e}")
        return None, None, False

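# Usage sketch (illustrative only; example_caption_image and max_new_tokens=30
# are assumptions, not values from this repo).
def example_caption_image(caption_processor, caption_model, pil_image):
    """Return a short text caption for a PIL image."""
    inputs = caption_processor(pil_image, return_tensors="pt").to(device)
    inputs["pixel_values"] = inputs["pixel_values"].to(dtype)  # match model dtype
    output_ids = caption_model.generate(**inputs, max_new_tokens=30)
    return caption_processor.decode(output_ids[0], skip_special_tokens=True)
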
def set_clip_skip(pipe):
    """Log the configured CLIP skip value."""
    # Note: this only logs the value; diffusers SDXL pipelines apply CLIP skip
    # via the clip_skip argument at call time rather than at load time.
    if hasattr(pipe, 'text_encoder'):
        print(f" [OK] CLIP skip set to {CLIP_SKIP}")


print("[OK] Model loading functions ready (Torch 2.1.1 + Depth Anything V2)")