Spaces:
Running
Running
| import cv2 | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from diffusers import StableDiffusionXLPipeline | |
| from insightface.app import FaceAnalysis | |
| from accelerate import init_empty_weights, load_checkpoint_and_dispatch | |
| import os | |
| import urllib.request | |
| import time | |
# Allow network access for runtime downloads.
# NOTE(review): huggingface_hub reads HF_HUB_OFFLINE when it is first
# imported, and it is already pulled in by the `diffusers` import at the top
# of the file, so this assignment most likely has no effect here. Move it
# above the imports if offline mode really must be overridden — confirm
# against the installed huggingface_hub version.
os.environ["HF_HUB_OFFLINE"] = "0"

# The whole app runs on the CPU in full precision.
device = "cpu"
dtype = torch.float32

# Load the InsightFace face analysis models. FaceAnalysis fetches its model
# pack itself on first use, so the first start needs network access.
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    # ctx_id=-1 is InsightFace's CPU context (0 would be the first GPU),
    # matching `device` above; det_size sets the detector's input size.
    face_app.prepare(ctx_id=-1, det_size=(480, 480))
    print("InsightFace model loaded successfully.")
except Exception as e:
    # Chain the original error so its traceback is not lost.
    raise RuntimeError(f"Failed to load InsightFace model: {e}. Ensure network access for initial download.") from e
# Both weight locations default to the repository root; `model_path` may
# later be redirected to ./unet/ once the base weights are located.
model_path = "./"
ip_adapter_path = "./"

# Debug aid: show what the Space actually contains before any model loads.
print("Files in root directory:", os.listdir("."))
if os.path.exists("./unet"):
    print("Files in ./unet/ directory:", os.listdir("./unet"))
else:
    print("Files in ./unet/ directory:", "No ./unet/ directory")
# Locate the Kolors base UNet weights. Search order:
#   1. <model_path>diffusers_weights.safetensors
#   2. <model_path>diffusion_pytorch_model.fp16.safetensors
#   3. ./unet/diffusion_pytorch_model.fp16.safetensors
#   4. download (3) from the Kwai-Kolors/Kolors Hub repo, with retries.
# Afterwards `kolors_weights` names the file that was found/downloaded and,
# when it lives under ./unet/, `model_path` is switched to "./unet/".
import urllib.error  # explicit; `import urllib.request` only exposes it implicitly

kolors_weights = model_path + "diffusers_weights.safetensors"
if not os.path.exists(kolors_weights):
    kolors_weights = model_path + "diffusion_pytorch_model.fp16.safetensors"
if not os.path.exists(kolors_weights):
    kolors_weights_unet = "./unet/diffusion_pytorch_model.fp16.safetensors"
    if not os.path.exists(kolors_weights_unet):
        print("Preloading failed. Attempting runtime download with retry...")
        os.makedirs("./unet", exist_ok=True)
        max_retries = 3
        # `/resolve/` serves the real file contents. `/raw/` returns only the
        # small Git LFS pointer text for large binaries such as .safetensors.
        correct_url = "https://huggingface.co/Kwai-Kolors/Kolors/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors"
        # Download under a temporary name and rename only after a complete
        # transfer. Otherwise an interrupted attempt would leave a truncated
        # file that the os.path.exists() checks above accept on the next start.
        partial_download = kolors_weights_unet + ".part"
        for attempt in range(max_retries):
            try:
                print(f"Download attempt {attempt + 1} of {max_retries}")
                urllib.request.urlretrieve(correct_url, partial_download)
                os.replace(partial_download, kolors_weights_unet)
                print("Kolors base weights downloaded to", kolors_weights_unet)
                model_path = "./unet/"
                kolors_weights = kolors_weights_unet
                break
            except urllib.error.HTTPError as e:
                print(f"Download attempt {attempt + 1} failed: HTTP Error {e.code} - {e.reason}")
                if attempt < max_retries - 1:
                    time.sleep(5)
                else:
                    raise FileNotFoundError(f"Failed to download Kolors base weights after {max_retries} attempts: HTTP Error {e.code} - {e.reason}. Verify the URL or contact support.") from e
            except Exception as e:
                print(f"Download attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(5)
                else:
                    raise FileNotFoundError(f"Failed to download Kolors base weights after {max_retries} attempts: {e}. Check network access or contact support.") from e
    else:
        model_path = "./unet/"
        kolors_weights = kolors_weights_unet
# Fail fast if the IP-Adapter FaceID-Plus checkpoint is missing from
# `ip_adapter_path`; unlike the base UNet there is no download fallback.
# NOTE(review): this checks a *local* file, but load_ip_adapter() at the end
# of this block is given the Hub repo id rather than `ip_adapter_path`, so the
# file verified here is not the one that gets loaded — confirm which source
# is intended.
if not os.path.exists(ip_adapter_path + "ipa-faceid-plus.bin"):
    raise FileNotFoundError(f"IP-Adapter weights not found at {ip_adapter_path}")
# Build the pipeline under accelerate's init_empty_weights() context
# (per accelerate's docs, parameters are created on the meta device without
# allocating real memory).
# NOTE(review): confirm that this Hub repo actually contains a full
# SDXL-format pipeline (model_index.json) that StableDiffusionXLPipeline can
# load; `safety_checker` is presumably ignored by SDXL — verify.
with init_empty_weights():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
        torch_dtype=dtype,
        safety_checker=None,
    )
# Load weights from `model_path` and place everything on the CPU.
# NOTE(review): accelerate documents load_checkpoint_and_dispatch() as taking
# a single torch.nn.Module (and returning it), not a diffusers pipeline — this
# call may fail or replace `pipe` with a non-pipeline object. Verify.
pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
# Attach the FaceID-Plus IP-Adapter weights (fetched from the Hub repo id).
# NOTE(review): diffusers' load_ip_adapter() combines `subfolder` with
# `image_encoder_folder` when locating the image encoder; subfolder=None may
# fail there — confirm against the installed diffusers version.
pipe.load_ip_adapter("Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus", subfolder=None, weight_name="ipa-faceid-plus.bin")
def generate_image(uploaded_image, prompt):
    """Generate an image from `prompt`, conditioned on a face in `uploaded_image`.

    Args:
        uploaded_image: PIL image from the Gradio upload widget, or None when
            the form is submitted without an upload.
        prompt: Text prompt forwarded to the diffusion pipeline.

    Returns:
        A ``(status_message, image)`` tuple matching the interface's
        ``[Status, Generated Image]`` outputs; ``image`` is None whenever no
        image could be produced.
    """
    if uploaded_image is None:
        return "No image uploaded!", None

    # Normalize to 3-channel RGB first: RGBA, palette, or grayscale uploads
    # would otherwise break the RGB->BGR conversion done for InsightFace.
    img = cv2.cvtColor(np.array(uploaded_image.convert("RGB")), cv2.COLOR_RGB2BGR)
    faces = face_app.get(img)
    if not faces:
        return "No face detected!", None

    # Use the largest detected face (bounding-box area) as the reference
    # rather than whichever face happens to come last in detection order.
    face_info = max(
        faces,
        key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]),
    )
    face_emb = face_info["embedding"]

    try:
        # NOTE(review): confirm that the loaded pipeline's __call__ accepts
        # `image_embeds`; an unknown keyword raises and is reported below.
        image = pipe(
            prompt=prompt,
            image_embeds=face_emb,
            num_inference_steps=20,
            guidance_scale=7.5,
            height=512,
            width=512,
        ).images[0]
        return "Image generated successfully!", image
    except Exception as e:
        # Report the failure in the Status box instead of crashing the UI.
        return f"Generation failed: {e}", None
# Gradio front end: a reference photo and a prompt go in; a status message and
# the generated picture come out, in the same order as the (status, image)
# tuple returned by generate_image.
_reference_input = gr.Image(type="pil", label="Upload Reference Image")
_prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space")
interface = gr.Interface(
    fn=generate_image,
    inputs=[_reference_input, _prompt_input],
    outputs=[
        gr.Textbox(label="Status"),
        gr.Image(label="Generated Image"),
    ],
    title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
    description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face.",
)
interface.launch()