"""Gradio demo: generate an image from a prompt while preserving a reference face.

A face is detected and embedded with InsightFace, and the embedding is passed
to a diffusers pipeline carrying the Kolors IP-Adapter FaceID-Plus weights.
Everything runs on CPU in float32.
"""

import os
import time

# huggingface_hub (and diffusers, which imports it) reads HF_HUB_OFFLINE when it
# is first imported, so this override must run before those imports or it is
# silently ignored.
os.environ["HF_HUB_OFFLINE"] = "0"

import cv2
import gradio as gr
import numpy as np
import torch
# NOTE(review): no longer used below; drop if nothing else relies on it.
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from insightface.app import FaceAnalysis

# Run everything on CPU in full precision.
device = "cpu"
dtype = torch.float32

MAX_DOWNLOAD_RETRIES = 3
RETRY_DELAY_SECONDS = 5

NUM_INFERENCE_STEPS = 20
GUIDANCE_SCALE = 7.5
IMAGE_SIZE = 512

KOLORS_REPO = "Kwai-Kolors/Kolors"
KOLORS_UNET_FILENAME = "unet/diffusion_pytorch_model.fp16.safetensors"
IP_ADAPTER_REPO = "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus"
IP_ADAPTER_FILENAME = "ipa-faceid-plus.bin"


def _download_with_retry(repo_id, filename, label, local_dir="."):
    """Download ``filename`` from ``repo_id`` into ``local_dir``, retrying on failure.

    ``hf_hub_download`` keeps the repo-relative path under ``local_dir``
    (``"unet/x"`` lands at ``"<local_dir>/unet/x"``), so the path it returns is
    what callers should use.

    Args:
        repo_id: Hugging Face Hub repository id.
        filename: Repo-relative path of the file to fetch.
        label: Human-readable name used in progress and error messages.
        local_dir: Directory to download into.

    Returns:
        The local path of the downloaded file.

    Raises:
        FileNotFoundError: If every attempt fails (chained to the last error).
    """
    last_error = None
    for attempt in range(1, MAX_DOWNLOAD_RETRIES + 1):
        print(f"{label} download attempt {attempt} of {MAX_DOWNLOAD_RETRIES}")
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
        except Exception as e:  # network/hub failures: retry, then re-raise below
            last_error = e
            print(f"{label} download attempt {attempt} failed: {e}")
            if attempt < MAX_DOWNLOAD_RETRIES:
                time.sleep(RETRY_DELAY_SECONDS)
    raise FileNotFoundError(
        f"Failed to download {label} weights after {MAX_DOWNLOAD_RETRIES} attempts: "
        f"{last_error}. Verify the repo or contact support."
    ) from last_error


def _resolve_kolors_weights(root):
    """Return ``(model_path, weights_file)`` for the Kolors base weights.

    Looks for, in order: ``<root>/diffusers_weights.safetensors``,
    ``<root>/diffusion_pytorch_model.fp16.safetensors`` and
    ``./unet/diffusion_pytorch_model.fp16.safetensors``. Downloads the last one
    from the Hub if none exists.
    """
    for name in ("diffusers_weights.safetensors", "diffusion_pytorch_model.fp16.safetensors"):
        candidate = os.path.join(root, name)
        if os.path.exists(candidate):
            return root, candidate

    unet_weights = os.path.join(".", KOLORS_UNET_FILENAME)
    if not os.path.exists(unet_weights):
        print("Preloading failed. Attempting runtime download with retry...")
        # Download into "." so the repo's "unet/" prefix yields ./unet/<file>.
        unet_weights = _download_with_retry(KOLORS_REPO, KOLORS_UNET_FILENAME, "Kolors base")
        print("Kolors base weights downloaded to", unet_weights)
    return "./unet/", unet_weights


def _ensure_ip_adapter_weights(root):
    """Return the local path of the IP-Adapter weights, downloading them if absent."""
    weights = os.path.join(root, IP_ADAPTER_FILENAME)
    if not os.path.exists(weights):
        print("IP-Adapter preloading failed. Attempting runtime download with retry...")
        weights = _download_with_retry(
            IP_ADAPTER_REPO, IP_ADAPTER_FILENAME, "IP-Adapter", local_dir=root
        )
        print("IP-Adapter weights downloaded to", weights)
    return weights


# Load the face encoder (InsightFace downloads its own models on first use).
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    # A negative ctx_id selects CPU in InsightFace.
    face_app.prepare(ctx_id=-1, det_size=(480, 480))
except Exception as e:
    raise RuntimeError(
        f"Failed to load InsightFace model: {e}. Ensure network access for initial download."
    ) from e
print("InsightFace model loaded successfully.")

# Roots for preloaded or downloaded weights.
model_path = "./"
ip_adapter_path = "./"

# Debug: list files to confirm preloading or download.
print("Files in root directory:", os.listdir("."))
print(
    "Files in ./unet/ directory:",
    os.listdir("./unet") if os.path.exists("./unet") else "No ./unet/ directory",
)

model_path, kolors_weights = _resolve_kolors_weights(model_path)
ip_adapter_weights = _ensure_ip_adapter_weights(ip_adapter_path)

# Load the pipeline directly: on a single CPU device nothing needs dispatching.
# Accelerate's empty-weights/dispatch path expects an nn.Module and rejects
# device_map="cpu".
# NOTE(review): the base UNet is fetched from Kwai-Kolors/Kolors above, but the
# pipeline is built from the IP-Adapter repo id and never reads
# `kolors_weights`. Kolors pairs its UNet with a ChatGLM text encoder, so
# confirm this repo/class pair actually loads. diffusers' KolorsPipeline with
# "Kwai-Kolors/Kolors-diffusers" may be what is needed.
pipe = StableDiffusionXLPipeline.from_pretrained(
    IP_ADAPTER_REPO,
    torch_dtype=dtype,
    safety_checker=None,
)
pipe = pipe.to(device)

# Use the file ensured above instead of re-fetching it into the HF cache. Face
# embeddings are supplied directly (see generate_image), so no CLIP image
# encoder is loaded with the adapter.
# NOTE(review): FaceID-*Plus* may also expect CLIP features of the face crop.
# Confirm diffusers accepts this checkpoint as-is.
pipe.load_ip_adapter(
    ip_adapter_path,
    subfolder=None,
    weight_name=IP_ADAPTER_FILENAME,
    image_encoder_folder=None,
)


def _pick_largest_face(faces):
    """Return the detection with the largest bounding-box area.

    InsightFace bboxes are ``[x1, y1, x2, y2]``. The biggest face is the most
    likely subject of the reference photo.
    """
    def area(face):
        x1, y1, x2, y2 = face["bbox"][:4]
        return (x2 - x1) * (y2 - y1)

    return max(faces, key=area)


def _face_embedding_to_ip_adapter_embeds(face_emb, do_classifier_free_guidance):
    """Shape one face embedding as ``ip_adapter_image_embeds`` for a single image.

    Returns a ``(batch, num_images, emb_dim)`` tensor. With classifier-free
    guidance, a zero negative embedding is stacked first, since diffusers
    splits the tensor as ``negative, positive``.
    """
    emb = torch.as_tensor(np.asarray(face_emb), dtype=dtype, device=device).reshape(1, 1, -1)
    if do_classifier_free_guidance:
        emb = torch.cat([torch.zeros_like(emb), emb], dim=0)
    return emb


def generate_image(uploaded_image, prompt):
    """Generate an image for ``prompt`` conditioned on the face in ``uploaded_image``.

    Args:
        uploaded_image: PIL image from the Gradio input, or None if nothing was uploaded.
        prompt: Text prompt for the diffusion pipeline.

    Returns:
        A ``(status_message, image_or_None)`` tuple matching the interface outputs.
    """
    if uploaded_image is None:
        return "Please upload a reference image.", None

    # Force 3-channel RGB first so RGBA/greyscale uploads don't break cvtColor.
    img = cv2.cvtColor(np.array(uploaded_image.convert("RGB")), cv2.COLOR_RGB2BGR)
    faces = face_app.get(img)
    if not faces:
        return "No face detected!", None

    face_emb = _pick_largest_face(faces)["embedding"]
    embeds = _face_embedding_to_ip_adapter_embeds(
        face_emb, do_classifier_free_guidance=GUIDANCE_SCALE > 1.0
    )

    try:
        image = pipe(
            prompt=prompt,
            ip_adapter_image_embeds=[embeds],
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            height=IMAGE_SIZE,
            width=IMAGE_SIZE,
        ).images[0]
    except Exception as e:  # surface pipeline errors in the UI instead of crashing
        return f"Generation failed: {e}", None
    return "Image generated successfully!", image


interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Image(type="pil", label="Upload Reference Image"),
        gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space"),
    ],
    outputs=[gr.Textbox(label="Status"), gr.Image(label="Generated Image")],
    title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
    description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face.",
)
interface.launch()