Spaces:
Runtime error
Runtime error
| """ | |
| Models.py - Following examplewithface.py EXACTLY | |
| NO MultiControlNetModel wrapper! | |
| Using Kohya-style LoRA from lora.py (examplewithface.py lines 223-235) | |
| """ | |
| import torch | |
| import time | |
| import os | |
| from diffusers import ControlNetModel, AutoencoderKL, LCMScheduler | |
| from insightface.app import FaceAnalysis | |
| from controlnet_aux import ZoeDetector | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| from safetensors.torch import load_file | |
| from compel import Compel, ReturnedEmbeddingsType | |
| from pipeline_stable_diffusion_xl_instantid_img2img import ( | |
| StableDiffusionXLInstantIDImg2ImgPipeline, | |
| draw_kps | |
| ) | |
| from config import ( | |
| device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN, | |
| FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG | |
| ) | |
def download_model_with_retry(repo_id, filename, max_retries=None):
    """Download a single file from the Hugging Face Hub, retrying on failure.

    Args:
        repo_id: Hub repository id, e.g. "InstantX/InstantID".
        filename: Path of the file within the repository.
        max_retries: Number of attempts; defaults to
            DOWNLOAD_CONFIG['max_retries'].

    Returns:
        Local filesystem path of the downloaded file.

    Raises:
        Exception: re-raises the last download error once all attempts
            are exhausted.
    """
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']
    # Build the kwargs once, outside the retry loop (invariant work).
    kwargs = {"repo_type": "model"}
    if HUGGINGFACE_TOKEN:
        kwargs["token"] = HUGGINGFACE_TOKEN
    for attempt in range(max_retries):
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
        except Exception as e:
            # Surface the failure instead of silently retrying — the
            # original discarded the exception between attempts, making
            # transient network errors invisible in the logs.
            print(f" [RETRY {attempt + 1}/{max_retries}] {filename}: {e}")
            if attempt < max_retries - 1:
                time.sleep(DOWNLOAD_CONFIG['retry_delay'])
            else:
                raise
    # Note: the original ended with an unreachable `return None` — every
    # path either returns the downloaded path or re-raises.
def load_face_analysis():
    """Download antelopev2 and build an InsightFace FaceAnalysis app.

    Returns a ``(app, ok)`` tuple; ``app`` is None when loading fails.
    Mirrors examplewithface.py line 113.
    """
    print("Loading face analysis...")
    try:
        # Fetch the antelopev2 face-model snapshot from the Hub first.
        snapshot_download(
            repo_id="DIAMONIK7777/antelopev2",
            local_dir="/data/models/antelopev2",
        )
        # Same construction pattern as examplewithface.py line 113.
        face_app = FaceAnalysis(
            name='antelopev2',
            root='/data',
            providers=['CPUExecutionProvider'],
        )
        face_app.prepare(ctx_id=0, det_size=(640, 640))
        print(" [OK] Face analysis loaded")
        return face_app, True
    except Exception as e:
        print(f" [ERROR] Face analysis failed: {e}")
        import traceback
        traceback.print_exc()
        return None, False
def load_depth_detector():
    """Instantiate the ZoeDepth annotator and move it onto the device.

    Returns ``(detector, ok)``; ``detector`` is None when unavailable.
    Mirrors examplewithface.py lines 151-155.
    """
    print("Loading Zoe Depth...")
    try:
        detector = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        detector.to(device)  # examplewithface.py line 155
        print(" [OK] Zoe Depth loaded")
        return detector, True
    except Exception as e:
        print(f" [WARNING] Zoe unavailable: {e}")
        return None, False
def load_controlnets():
    """Build the two ControlNets used by the InstantID pipeline.

    Returns ``(identitynet, zoedepthnet)``. Neither model is moved to
    the device here — ``pipe.to(device)`` later handles placement.
    Mirrors examplewithface.py lines 122-126.
    """
    print("Loading ControlNets...")

    def _controlnet(repo, **extra):
        # Shared loader: both nets use the configured dtype.
        return ControlNetModel.from_pretrained(repo, torch_dtype=dtype, **extra)

    identitynet = _controlnet("InstantX/InstantID", subfolder="ControlNetModel")
    print(" [OK] InstantID ControlNet")

    zoedepthnet = _controlnet("diffusers/controlnet-zoe-depth-sdxl-1.0")
    print(" [OK] Zoe Depth ControlNet")

    return identitynet, zoedepthnet
def load_sdxl_pipeline(controlnets):
    """Assemble the InstantID SDXL img2img pipeline.

    ``controlnets`` must be a plain list ``[identitynet, zoedepthnet]``;
    it is handed to the pipeline directly, with no MultiControlNetModel
    wrapper (examplewithface.py lines 128-145). Returns ``(pipe, True)``.
    """
    print("Loading pipeline...")

    # fp16-safe VAE (examplewithface.py line 128).
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=dtype,
    )
    print(" [OK] VAE loaded")

    # Pipeline construction (line 134) — controlnets passed as a LIST.
    pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
        "frankjoshua/albedobaseXL_v21",
        vae=vae,
        controlnet=controlnets,  # LIST [identitynet, zoedepthnet] — no wrapper
        torch_dtype=dtype,
    )
    print(" [OK] Pipeline created with direct controlnet list")

    # Swap in the LCM scheduler, preserving the existing config.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    print(" [OK] LCM scheduler")

    # IP-Adapter weights (line 139).
    ip_adapter_path = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
    pipe.load_ip_adapter_instantid(ip_adapter_path)
    pipe.set_ip_adapter_scale(0.8)
    print(" [OK] IP-Adapter loaded")

    # Move everything (UNet, VAE, controlnets) to the target device last.
    pipe = pipe.to(device)
    print(" [OK] Pipeline ready (following examplewithface.py EXACTLY)")
    return pipe, True
# Global LoRA state
lora_path_cached = None

def load_lora(pipe):
    """Resolve the LoRA file path and cache it for fuse_lora_with_scale().

    The weights themselves are merged later, once the desired scale is
    known. Returns True when the path was resolved, False on failure.
    """
    print("Loading LoRA...")
    global lora_path_cached
    try:
        lora_path_cached = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        print(f" [OK] LoRA path stored")
        return True
    except Exception as e:
        print(f" [WARNING] LoRA failed: {e}")
        return False
def fuse_lora_with_scale(pipe, lora_scale):
    """Merge the cached Kohya-style LoRA into the pipeline at ``lora_scale``.

    Follows examplewithface.py lines 223-235: the local ``lora.py``
    loader is used rather than the diffusers built-in. Returns True on
    success, False when no LoRA path is cached or the merge fails.
    """
    global lora_path_cached
    if lora_path_cached is None:
        return False
    try:
        import lora  # local Kohya-style loader module

        print(f" [LORA] Creating network from weights...")
        # SDXL carries two text encoders; the loader expects both.
        encoders = [pipe.text_encoder, pipe.text_encoder_2]
        # examplewithface.py lines 223-229
        network, state_dict = lora.create_network_from_weights(
            lora_scale,        # multiplier
            lora_path_cached,  # file path
            pipe.vae,
            encoders,
            pipe.unet,
            for_inference=True,
        )
        # examplewithface.py lines 231-233
        print(f" [LORA] Merging to model with scale {lora_scale}...")
        network.merge_to(encoders, pipe.unet, state_dict, torch.float16, "cuda")
        # Drop the weight copies now that they are baked into the model.
        del state_dict, network
        print(f" [OK] LoRA merged into model using Kohya loader")
        return True
    except Exception as e:
        print(f" [ERROR] LoRA merge failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def setup_compel(pipe):
    """Build a Compel prompt-weighting helper for SDXL's dual encoders.

    Returns ``(compel, ok)``; ``compel`` is None when setup fails.
    Mirrors examplewithface.py line 145.
    """
    print("Setting up Compel...")
    try:
        helper = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )
        print(" [OK] Compel ready")
        return helper, True
    except Exception as e:
        print(f" [WARNING] Compel unavailable: {e}")
        return None, False
def setup_scheduler(pipe):
    """Intentional no-op.

    The LCM scheduler is already installed inside load_sdxl_pipeline();
    this stub is kept so existing callers that invoke setup_scheduler()
    continue to work unchanged.
    """
    pass
def optimize_pipeline(pipe):
    """Apply best-effort memory/speed optimizations to the pipeline.

    Enables xformers attention on CUDA (if installed) plus VAE slicing
    and tiling when the pipeline supports them. Every step is optional;
    failures are tolerated and the pipeline keeps its defaults.
    """
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print(" [OK] xformers enabled")
        except Exception:
            # xformers is optional — fall back to default attention.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit; narrowed to Exception.)
            pass
    if hasattr(pipe, 'enable_vae_slicing'):
        pipe.enable_vae_slicing()
    if hasattr(pipe, 'enable_vae_tiling'):
        pipe.enable_vae_tiling()
def load_caption_model():
    """Load an image-captioning model, preferring GIT-Large over BLIP.

    Returns ``(processor, model, ok, kind)`` where ``kind`` is 'git',
    'blip', or 'none'. Models are loaded onto CPU; the caller moves them
    to the inference device as needed.
    """
    print("Loading caption model...")
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM
        processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco", torch_dtype=dtype
        ).to("cpu")
        print(" [OK] GIT-Large")
        return processor, model, True, 'git'
    except Exception:
        # Was a bare `except:`; narrowed so Ctrl-C still interrupts.
        # Fall through to the BLIP fallback below.
        pass
    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base", torch_dtype=dtype
        ).to("cpu")
        print(" [OK] BLIP")
        return processor, model, True, 'blip'
    except Exception:
        # No captioning backend available — callers check the ok flag.
        return None, None, False, 'none'
def set_clip_skip(pipe):
    """Report the configured CLIP skip value.

    NOTE(review): this only *prints* CLIP_SKIP when the pipeline has a
    text encoder — it does not modify the encoder or the pipeline.
    Confirm whether clip-skip is actually applied elsewhere (e.g. at
    prompt-encoding time) or is still a TODO.
    """
    if hasattr(pipe, 'text_encoder'):
        print(f" [OK] CLIP skip {CLIP_SKIP}")
def load_image_encoder():
    """Load the CLIP vision encoder used by the IP-Adapter.

    Returns the encoder moved onto ``device``, or None when loading
    fails for any reason.
    """
    print("Loading CLIP Image Encoder for IP-Adapter...")
    try:
        # Imported locally (matching load_caption_model's style) — the
        # name was previously used without any import in this module,
        # which made this function raise NameError unconditionally.
        from transformers import CLIPVisionModelWithProjection
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter",
            subfolder="models/image_encoder",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] CLIP Image Encoder loaded successfully")
        return image_encoder
    except Exception as e:
        print(f" [ERROR] Could not load image encoder: {e}")
        return None
def setup_ip_adapter(pipe, image_encoder):
    """Setup IP-Adapter for InstantID face embeddings - PROPER IMPLEMENTATION.

    Loads the InstantID ``ip-adapter.bin`` weights, builds the Resampler
    projection model (512-d InsightFace embeddings -> 16 tokens of the
    UNet's cross-attention width), installs IP attention processors on
    the UNet, and attaches the CLIP image encoder to the pipeline.

    Args:
        pipe: InstantID SDXL pipeline (provides ``unet`` and its config).
        image_encoder: encoder from load_image_encoder(), or None to skip.

    Returns:
        ``(image_proj_model, ok)`` — the Resampler and a success flag;
        ``(None, False)`` when ``image_encoder`` is None or setup fails.
    """
    if image_encoder is None:
        return None, False
    print("Setting up IP-Adapter for InstantID face embeddings (proper implementation)...")
    try:
        # These three names were previously used with no import anywhere
        # in this module, so this function raised NameError.
        # NOTE(review): module paths follow the InstantID reference repo
        # layout ("ip_adapter" package) — confirm they match this
        # project's vendored copies. AttnProcessor2_0 comes from diffusers.
        from diffusers.models.attention_processor import AttnProcessor2_0
        from ip_adapter.resampler import Resampler
        from ip_adapter.attention_processor import IPAttnProcessor2_0

        # Download InstantID weights.
        ip_adapter_path = download_model_with_retry(
            "InstantX/InstantID",
            "ip-adapter.bin"
        )
        # Full checkpoint: {"image_proj": ..., "ip_adapter": ...} flattened
        # into dotted keys. NOTE(review): torch.load unpickles the file —
        # acceptable for this trusted repo, but do not reuse on untrusted
        # checkpoints.
        state_dict = torch.load(ip_adapter_path, map_location="cpu")

        # Split into the Resampler weights and the attention weights.
        image_proj_state_dict = {}
        ip_adapter_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("image_proj."):
                image_proj_state_dict[key.replace("image_proj.", "")] = value
            elif key.startswith("ip_adapter."):
                ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value

        # Create Resampler (image projection model) with the parameters
        # from the InstantID reference implementation.
        print("Creating Resampler (Perceiver architecture)...")
        image_proj_model = Resampler(
            dim=1280,            # hidden dimension
            depth=4,             # IMPORTANT: 4 layers (not 8!)
            dim_head=64,         # dimension per head
            heads=20,            # number of heads
            num_queries=16,      # number of output tokens
            embedding_dim=512,   # InsightFace embedding dim
            output_dim=pipe.unet.config.cross_attention_dim,  # SDXL: 2048
            ff_mult=4            # feedforward multiplier
        )
        image_proj_model.eval()
        image_proj_model = image_proj_model.to(device, dtype=dtype)

        # Load pretrained projection weights when present.
        if image_proj_state_dict:
            try:
                image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
                print(" [OK] Resampler loaded with pretrained weights")
            except Exception as e:
                print(f" [WARNING] Could not load Resampler weights: {e}")
                print(" Using randomly initialized Resampler")
        else:
            print(" [WARNING] No image_proj weights found, using random initialization")

        # Install IP attention processors: self-attention (attn1) keeps
        # the plain processor; cross-attention gets the IP variant.
        print("Setting up IP-Adapter attention processors...")
        attn_procs = {}
        num_tokens = 16  # must match Resampler num_queries
        for name in pipe.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
            # Derive the hidden size from the block the processor sits in.
            if name.startswith("mid_block"):
                hidden_size = pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = pipe.unet.config.block_out_channels[block_id]
            else:
                hidden_size = pipe.unet.config.block_out_channels[-1]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor2_0()
            else:
                attn_procs[name] = IPAttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=num_tokens
                ).to(device, dtype=dtype)
        pipe.unet.set_attn_processor(attn_procs)

        # Load the IP-Adapter weights into the installed processors.
        if ip_adapter_state_dict:
            try:
                ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
                ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
                print(" [OK] IP-Adapter attention weights loaded")
            except Exception as e:
                print(f" [WARNING] Could not load IP-Adapter weights: {e}")
        else:
            print(" [WARNING] No ip_adapter weights found")

        # Keep the image encoder on the pipeline for later embedding calls.
        pipe.image_encoder = image_encoder
        print(" [OK] IP-Adapter fully loaded with InstantID architecture")
        print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
        print(f" - Face embeddings: 512D → 16x2048D")
        return image_proj_model, True
    except Exception as e:
        print(f" [ERROR] Could not setup IP-Adapter: {e}")
        import traceback
        traceback.print_exc()
        return None, False
# Public API. The original listed only four names, hiding the other
# loader/setup helpers from `from models import *`; adding names is
# backward-compatible.
__all__ = [
    'draw_kps',
    'download_model_with_retry',
    'load_face_analysis',
    'load_depth_detector',
    'load_controlnets',
    'load_sdxl_pipeline',
    'load_lora',
    'fuse_lora_with_scale',
    'setup_compel',
    'setup_scheduler',
    'optimize_pipeline',
    'load_caption_model',
    'set_clip_skip',
    'load_image_encoder',
    'setup_ip_adapter',
]
print("[OK] models.py ready - NO MultiControlNetModel, following examplewithface.py")