Spaces:
Running
Running
File size: 4,990 Bytes
9de5dbc b42faf9 9de5dbc 7a5105f 9de5dbc 7a5105f 9de5dbc b42faf9 9de5dbc b4fbfab 68c6e14 9de5dbc 7a5105f e74c661 9de5dbc 7a5105f b4fbfab 596bd62 24da2db 596bd62 e74c661 9de5dbc b42faf9 9de5dbc 596bd62 9de5dbc b42faf9 596bd62 9de5dbc e74c661 9de5dbc a383ea5 9e558cc a383ea5 596bd62 a383ea5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import cv2
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionXLPipeline
from insightface.app import FaceAnalysis
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import os
import urllib.request
import time
# Allow network access for runtime downloads.
# NOTE(review): huggingface_hub typically reads HF_HUB_OFFLINE when it is first
# imported, which already happened via the diffusers import above, so this
# assignment likely has no effect; move it above the imports if it must matter.
os.environ["HF_HUB_OFFLINE"] = "0"
# Run everything on CPU in full precision.
device = "cpu"
dtype = torch.float32
# Load the InsightFace face analysis models (InsightFace handles its own download).
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    # ctx_id=-1 selects CPU in InsightFace; ctx_id=0 would request GPU 0,
    # which contradicts the CPU-only provider list above.
    face_app.prepare(ctx_id=-1, det_size=(480, 480))
    print("InsightFace model loaded successfully.")
except Exception as e:
    # Chain the original error so the real cause stays in the traceback.
    raise RuntimeError(
        f"Failed to load InsightFace model: {e}. Ensure network access for initial download."
    ) from e
# Define paths for preloaded or downloaded weights.
model_path = "./"  # Start with the repository root.
ip_adapter_path = "./"
# Debug: list files to confirm preloading or download.
print("Files in root directory:", os.listdir("."))
print("Files in ./unet/ directory:", os.listdir("./unet") if os.path.exists("./unet") else "No ./unet/ directory")


def _download_with_retry(url, dest, max_retries=3, delay=5, label="Kolors base weights"):
    """Download *url* to *dest*, retrying up to *max_retries* times.

    The file is first written to ``dest + ".part"`` and only moved onto *dest*
    after a complete download. An interrupted transfer therefore never leaves a
    truncated file at *dest* that a later startup would accept as valid.

    Args:
        url: Source URL.
        dest: Final destination path.
        max_retries: Total number of attempts.
        delay: Seconds to sleep between failed attempts.
        label: Human-readable name used in the final error message.

    Raises:
        FileNotFoundError: If every attempt fails (chained to the last error).
    """
    tmp_path = dest + ".part"
    for attempt in range(1, max_retries + 1):
        try:
            print(f"Download attempt {attempt} of {max_retries}")
            urllib.request.urlretrieve(url, tmp_path)
        except Exception as e:
            if isinstance(e, urllib.error.HTTPError):
                reason = f"HTTP Error {e.code} - {e.reason}"
                advice = "Verify the URL or contact support."
            else:
                reason = str(e)
                advice = "Check network access or contact support."
            print(f"Download attempt {attempt} failed: {reason}")
            # Discard any partial data before retrying or giving up.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            if attempt == max_retries:
                raise FileNotFoundError(
                    f"Failed to download {label} after {max_retries} attempts: {reason}. {advice}"
                ) from e
            time.sleep(delay)
        else:
            # Atomic on the same filesystem: dest is either absent or complete.
            os.replace(tmp_path, dest)
            return


# Resolve the base model weights: prefer files in the root, then ./unet/,
# and download into ./unet/ as a last resort.
kolors_weights = model_path + "diffusers_weights.safetensors"
if not os.path.exists(kolors_weights):
    kolors_weights = model_path + "diffusion_pytorch_model.fp16.safetensors"
if not os.path.exists(kolors_weights):
    kolors_weights_unet = "./unet/diffusion_pytorch_model.fp16.safetensors"
    if not os.path.exists(kolors_weights_unet):
        print("Preloading failed. Attempting runtime download with retry...")
        os.makedirs("./unet", exist_ok=True)
        max_retries = 3
        # /resolve/ serves the real file content; /raw/ returns only the
        # Git LFS pointer text for large files such as safetensors weights.
        correct_url = "https://huggingface.co/Kwai-Kolors/Kolors/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors"
        _download_with_retry(correct_url, kolors_weights_unet, max_retries=max_retries)
        print("Kolors base weights downloaded to", kolors_weights_unet)
    model_path = "./unet/"
    kolors_weights = kolors_weights_unet
# Check that the IP-Adapter weights were preloaded into ip_adapter_path.
# NOTE(review): this verifies a *local* ipa-faceid-plus.bin, but load_ip_adapter
# at the bottom of this block fetches the same weight name from the Hub repo
# instead of ip_adapter_path — confirm which source is intended.
if not os.path.exists(ip_adapter_path + "ipa-faceid-plus.bin"):
    raise FileNotFoundError(f"IP-Adapter weights not found at {ip_adapter_path}")
# Build the pipeline under accelerate's init_empty_weights context
# (weights are expected to be filled in by the dispatch call below).
# NOTE(review): Kolors appears to use a non-CLIP (ChatGLM) text encoder and is
# loaded in recent diffusers via KolorsPipeline, not StableDiffusionXLPipeline;
# the FaceID-Plus repo may also not contain a full pipeline layout, and
# `safety_checker` is not an SDXL pipeline component. Verify this loads at all.
with init_empty_weights():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
        torch_dtype=dtype,
        safety_checker=None,
    )
# Load and dispatch the checkpoint found in model_path onto CPU with accelerate.
# NOTE(review): load_checkpoint_and_dispatch is documented for torch.nn.Module
# models; a diffusers pipeline object is likely not accepted here — confirm.
# Also note that the specific file resolved earlier (kolors_weights) is never
# referenced; only its directory (model_path) is passed.
pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
# Attach the FaceID-Plus IP-Adapter weights to the pipeline.
pipe.load_ip_adapter("Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus", subfolder=None, weight_name="ipa-faceid-plus.bin")
def _largest_face(faces):
    """Return the face with the largest bounding-box area, or None if *faces* is empty."""
    if not faces:
        return None

    def _area(face):
        # Assumes InsightFace's (x1, y1, x2, y2) bbox layout.
        x1, y1, x2, y2 = face["bbox"][:4]
        return (x2 - x1) * (y2 - y1)

    return max(faces, key=_area)


def generate_image(uploaded_image, prompt):
    """Generate an image from *prompt*, conditioned on the face in *uploaded_image*.

    Args:
        uploaded_image: Reference image from the ``gr.Image(type="pil")`` input;
            ``None`` when the user submits without uploading anything.
        prompt: Text prompt passed to the diffusion pipeline.

    Returns:
        A ``(status_message, image)`` tuple; ``image`` is ``None`` whenever no
        image could be generated.
    """
    if uploaded_image is None:
        return "No image uploaded!", None
    # Convert the PIL (RGB) image to the BGR channel order used by OpenCV-style inputs.
    img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
    faces = face_app.get(img)
    if not faces:
        return "No face detected!", None
    # With several faces, condition on the most prominent (largest) one rather
    # than whichever the detector happened to list last.
    face_info = _largest_face(faces)
    face_emb = face_info["embedding"]
    try:
        image = pipe(
            prompt=prompt,
            # NOTE(review): diffusers SDXL pipelines accept IP-Adapter embeddings
            # as `ip_adapter_image_embeds` (a list of tensors), not `image_embeds`;
            # as written this likely raises TypeError and is reported via the
            # except branch below — confirm against the pipeline actually used.
            image_embeds=face_emb,
            num_inference_steps=20,
            guidance_scale=7.5,
            height=512,
            width=512,
        ).images[0]
        return "Image generated successfully!", image
    except Exception as e:
        # UI boundary: surface the failure as a status message instead of crashing.
        return f"Generation failed: {e}", None
# Gradio UI: a reference image and a prompt go in; a status line and the
# generated image come out, in the order generate_image returns them.
reference_input = gr.Image(type="pil", label="Upload Reference Image")
prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space")
status_output = gr.Textbox(label="Status")
image_output = gr.Image(label="Generated Image")

interface = gr.Interface(
    fn=generate_image,
    inputs=[reference_input, prompt_input],
    outputs=[status_output, image_output],
    title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
    description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face.",
)
interface.launch()