Spaces:
Running
Running
| import cv2 | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from diffusers import StableDiffusionXLPipeline | |
| from insightface.app import FaceAnalysis | |
| from accelerate import init_empty_weights, load_checkpoint_and_dispatch | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| import time | |
# Runtime configuration for a CPU-only deployment.
# Clear the Hugging Face Hub offline flag so weights may be fetched at runtime.
# NOTE(review): this runs after the diffusers/huggingface_hub imports above;
# confirm the library still honours the variable when it is set this late.
os.environ.update(HF_HUB_OFFLINE="0")

# Inference runs on the CPU in full (float32) precision.
device, dtype = "cpu", torch.float32
# Load the InsightFace face detector/encoder on the CPU execution provider.
# Per the original note, InsightFace handles its own model download, so any
# failure here (including a failed download) aborts start-up with a RuntimeError.
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    face_app.prepare(ctx_id=0, det_size=(480, 480))
    print("InsightFace model loaded successfully.")
except Exception as e:
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError(
        f"Failed to load InsightFace model: {e}. Ensure network access for initial download."
    ) from e
# Directories searched for preloaded (or later downloaded) weights.
model_path = "./"  # Start with root
ip_adapter_path = "./"

# Debug aid: log what is on disk so preloading/download problems are visible.
print("Files in root directory:", os.listdir("."))
if os.path.exists("./unet"):
    unet_listing = os.listdir("./unet")
else:
    unet_listing = "No ./unet/ directory"
print("Files in ./unet/ directory:", unet_listing)
# Resolve the Kolors base weights, preferring files already on disk.
# Candidates are probed in this order:
#   1. <model_path>diffusers_weights.safetensors
#   2. <model_path>diffusion_pytorch_model.fp16.safetensors
#   3. ./unet/diffusion_pytorch_model.fp16.safetensors
# If none exists, candidate 3 is downloaded from the Hub (up to 3 attempts,
# 5 s apart). When candidate 3 is used, model_path is switched to "./unet/".
# Raises FileNotFoundError if every download attempt fails.
kolors_weights_unet = "./unet/diffusion_pytorch_model.fp16.safetensors"

kolors_weights = model_path + "diffusers_weights.safetensors"
if not os.path.exists(kolors_weights):
    kolors_weights = model_path + "diffusion_pytorch_model.fp16.safetensors"

if not os.path.exists(kolors_weights):
    if not os.path.exists(kolors_weights_unet):
        print("Preloading failed. Attempting runtime download with retry...")
        os.makedirs("./unet", exist_ok=True)
        max_retries = 3
        for attempt in range(max_retries):
            try:
                print(f"Download attempt {attempt + 1} of {max_retries}")
                hf_hub_download(
                    repo_id="Kwai-Kolors/Kolors",
                    filename="unet/diffusion_pytorch_model.fp16.safetensors",
                    # hf_hub_download recreates the repo-relative path below
                    # local_dir. Downloading into the root therefore puts the
                    # file at kolors_weights_unet; local_dir="./unet" would
                    # have produced ./unet/unet/... instead.
                    local_dir="./",
                    local_files_only=False,
                )
            except Exception as e:
                print(f"Download attempt {attempt + 1} failed: {e}")
                if attempt == max_retries - 1:
                    raise FileNotFoundError(
                        f"Failed to download Kolors base weights after {max_retries} attempts: {e}. Verify the repo or contact support."
                    ) from e
                # Short pause before the next attempt.
                time.sleep(5)
            else:
                print("Kolors base weights downloaded to", kolors_weights_unet)
                break
    # Either the file was already present or it has just been downloaded.
    model_path = "./unet/"
    kolors_weights = kolors_weights_unet
# Make sure the IP-Adapter FaceID-Plus weights are available locally.
# If <ip_adapter_path>ipa-faceid-plus.bin is missing it is downloaded from the
# Hub into the root directory (up to 3 attempts, 5 s apart).
# Raises FileNotFoundError if every download attempt fails.
ip_adapter_weights = ip_adapter_path + "ipa-faceid-plus.bin"
if not os.path.exists(ip_adapter_weights):
    print("IP-Adapter preloading failed. Attempting runtime download with retry...")
    max_retries = 3
    for attempt in range(max_retries):
        try:
            print(f"IP-Adapter download attempt {attempt + 1} of {max_retries}")
            hf_hub_download(
                repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
                filename="ipa-faceid-plus.bin",
                local_dir="./",
                local_files_only=False,
            )
        except Exception as e:
            print(f"IP-Adapter download attempt {attempt + 1} failed: {e}")
            if attempt == max_retries - 1:
                # Chain the last error so the root cause stays in the traceback.
                raise FileNotFoundError(
                    f"Failed to download IP-Adapter weights after {max_retries} attempts: {e}. Verify the repo or contact support."
                ) from e
            # Short pause before the next attempt.
            time.sleep(5)
        else:
            print("IP-Adapter weights downloaded to", ip_adapter_weights)
            break
# Build the pipeline object inside accelerate's empty-weights context.
# NOTE(review): load_checkpoint_and_dispatch is documented for torch.nn.Module
# models and for device_map strings such as "auto"; here it receives a
# diffusers pipeline and device_map="cpu" — confirm this sequence works.
with init_empty_weights():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
        safety_checker=None,
        torch_dtype=dtype,
    )

# Load the checkpoint located at model_path and dispatch it to the CPU.
pipe = load_checkpoint_and_dispatch(
    pipe,
    model_path,
    offload_folder=None,
    device_map="cpu",
)

# Attach the FaceID-Plus IP-Adapter weights, addressed by Hub repository id.
# NOTE(review): the ip_adapter_weights path resolved earlier is not used here.
pipe.load_ip_adapter(
    "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
    weight_name="ipa-faceid-plus.bin",
    subfolder=None,
)
def generate_image(uploaded_image, prompt):
    """Generate an image from ``prompt`` using the face in ``uploaded_image``.

    Args:
        uploaded_image: Reference picture from the UI (the Gradio input is
            configured with ``type="pil"``), or ``None`` if nothing was uploaded.
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        A ``(status, image)`` tuple. ``status`` is a human-readable message;
        ``image`` is the first image produced by the pipeline, or ``None``
        when no image was uploaded, no face was detected, or generation failed.
    """
    # Gradio passes None when the user submits without an image; bail out
    # before the array conversion, which would otherwise raise.
    if uploaded_image is None:
        return "No image uploaded!", None

    # Convert the RGB upload to BGR channel order before face detection.
    bgr_image = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
    faces = face_app.get(bgr_image)
    if not faces:
        return "No face detected!", None

    # The last detection in the returned list supplies the face embedding.
    # NOTE(review): if the largest face is intended, the list needs sorting first.
    face_emb = faces[-1]["embedding"]

    try:
        # NOTE(review): confirm the pipeline accepts `image_embeds`; any
        # failure here is reported to the user instead of raised.
        image = pipe(
            prompt=prompt,
            image_embeds=face_emb,
            num_inference_steps=20,
            guidance_scale=7.5,
            height=512,
            width=512,
        ).images[0]
    except Exception as e:
        return f"Generation failed: {e}", None
    return "Image generated successfully!", image
# Gradio UI: a reference image plus a text prompt go in; a status line and the
# generated image come out.
reference_input = gr.Image(type="pil", label="Upload Reference Image")
prompt_input = gr.Textbox(
    label="Enter Prompt",
    placeholder="e.g., A photorealistic astronaut in space",
)
status_output = gr.Textbox(label="Status")
image_output = gr.Image(label="Generated Image")

interface = gr.Interface(
    fn=generate_image,
    inputs=[reference_input, prompt_input],
    outputs=[status_output, image_output],
    title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
    description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face.",
)

# Start the web server for the app.
interface.launch()