File size: 5,503 Bytes
9de5dbc
 
 
 
b42faf9
9de5dbc
 
 
278fdd9
7a5105f
9de5dbc
7a5105f
 
9de5dbc
b42faf9
9de5dbc
 
 
b4fbfab
68c6e14
 
 
 
 
 
9de5dbc
7a5105f
 
e74c661
9de5dbc
7a5105f
b4fbfab
 
 
596bd62
 
24da2db
596bd62
 
 
 
 
 
 
 
 
 
278fdd9
 
 
 
 
 
596bd62
 
 
 
 
 
 
 
 
278fdd9
596bd62
 
 
 
278fdd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de5dbc
b42faf9
9de5dbc
 
596bd62
9de5dbc
 
 
 
b42faf9
 
596bd62
9de5dbc
 
 
 
 
 
 
e74c661
9de5dbc
 
 
 
 
 
 
 
a383ea5
 
9e558cc
a383ea5
 
 
 
 
 
 
 
596bd62
a383ea5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import cv2
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionXLPipeline
from insightface.app import FaceAnalysis
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import os
from huggingface_hub import hf_hub_download
import time

# Allow network access for runtime download.
# NOTE(review): huggingface_hub (imported above, and also pulled in by
# diffusers) presumably reads HF_HUB_OFFLINE once at import time, so setting
# it here may not affect the already-imported library — confirm, or move this
# assignment above the imports if offline mode must be overridden.
os.environ["HF_HUB_OFFLINE"] = "0"

# Everything in this script runs on the CPU in full precision.
device = "cpu"
dtype = torch.float32

# Load the face encoder (InsightFace handles its own model download).
# Any failure is fatal: the app cannot extract face embeddings without it.
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    face_app.prepare(ctx_id=0, det_size=(480, 480))
except Exception as e:
    # Chain the original error so the real cause stays in the traceback.
    raise RuntimeError(f"Failed to load InsightFace model: {e}. Ensure network access for initial download.") from e
print("InsightFace model loaded successfully.")

# Define paths for preloaded or downloaded weights
model_path = "./"  # Start with root; switched to ./unet/ if the weights live there
ip_adapter_path = "./"

# Debug: List files to confirm preloading or download
print("Files in root directory:", os.listdir("."))
print("Files in ./unet/ directory:", os.listdir("./unet") if os.path.exists("./unet") else "No ./unet/ directory")

# Locate the Kolors base weights. Search order: two candidate filenames in the
# root directory, then ./unet/. As a last resort, download them from the Hub.
kolors_weights = model_path + "diffusers_weights.safetensors"
if not os.path.exists(kolors_weights):
    kolors_weights = model_path + "diffusion_pytorch_model.fp16.safetensors"
    if not os.path.exists(kolors_weights):
        kolors_weights_unet = "./unet/diffusion_pytorch_model.fp16.safetensors"
        if not os.path.exists(kolors_weights_unet):
            print("Preloading failed. Attempting runtime download with retry...")
            os.makedirs("./unet", exist_ok=True)
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    print(f"Download attempt {attempt + 1} of {max_retries}")
                    # hf_hub_download replicates the repo's folder layout under
                    # local_dir. Downloading "unet/<file>" into "." therefore
                    # yields ./unet/<file>. The previous local_dir="./unet" put
                    # it at ./unet/unet/<file>, which nothing below looks at.
                    downloaded_path = hf_hub_download(
                        repo_id="Kwai-Kolors/Kolors",
                        filename="unet/diffusion_pytorch_model.fp16.safetensors",
                        local_dir=".",
                        local_files_only=False,
                    )
                except Exception as e:
                    print(f"Download attempt {attempt + 1} failed: {e}")
                    if attempt < max_retries - 1:
                        time.sleep(5)
                    else:
                        raise FileNotFoundError(f"Failed to download Kolors base weights after {max_retries} attempts: {e}. Verify the repo or contact support.") from e
                else:
                    # Trust the path the library reports, not a hand-built one.
                    print("Kolors base weights downloaded to", downloaded_path)
                    model_path = "./unet/"
                    kolors_weights = downloaded_path
                    break
        else:
            model_path = "./unet/"
            kolors_weights = kolors_weights_unet

# Check if IP-Adapter weights exist locally. If not, download them into the
# root directory, retrying a few times before giving up.
ip_adapter_weights = ip_adapter_path + "ipa-faceid-plus.bin"
if not os.path.exists(ip_adapter_weights):
    print("IP-Adapter preloading failed. Attempting runtime download with retry...")
    max_retries = 3
    for attempt in range(max_retries):
        try:
            print(f"IP-Adapter download attempt {attempt + 1} of {max_retries}")
            # The file sits at the repo root, so with local_dir="./" it lands
            # at ./ipa-faceid-plus.bin.
            hf_hub_download(
                repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
                filename="ipa-faceid-plus.bin",
                local_dir="./",
                local_files_only=False,
            )
        except Exception as e:
            print(f"IP-Adapter download attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                time.sleep(5)
            else:
                # Chain the last download error so its traceback is preserved.
                raise FileNotFoundError(f"Failed to download IP-Adapter weights after {max_retries} attempts: {e}. Verify the repo or contact support.") from e
        else:
            print("IP-Adapter weights downloaded to", ip_adapter_weights)
            break

# Initialize model with empty weights
# NOTE(review): "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus" is the repo the
# IP-Adapter .bin is downloaded from earlier in this file. It is not obviously
# a complete diffusers pipeline repo. Also, the Kolors base model may not be
# loadable as a StableDiffusionXLPipeline at all. Confirm this call succeeds.
# NOTE(review): SDXL pipelines have no `safety_checker` component, so this
# keyword is presumably ignored with a warning. Confirm, and drop it if so.
with init_empty_weights():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
        torch_dtype=dtype,
        safety_checker=None,
    )

# Load and dispatch model with accelerate
# NOTE(review): accelerate documents load_checkpoint_and_dispatch as taking a
# torch.nn.Module. `pipe` is a pipeline object, and `model_path` is a
# directory picked by the weight lookup earlier in this file. Verify that this
# actually loads the intended weights into the pipeline's modules.
pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
# NOTE(review): the adapter is referenced by Hub repo id here, so it is
# presumably resolved through the Hub cache rather than the local
# ipa-faceid-plus.bin checked/downloaded above. Confirm this is intended.
pipe.load_ip_adapter("Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus", subfolder=None, weight_name="ipa-faceid-plus.bin")

def generate_image(uploaded_image, prompt):
    """Generate an image from *prompt* that reuses the face in *uploaded_image*.

    Args:
        uploaded_image: PIL image from the Gradio ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without uploading one.
        prompt: Text prompt forwarded to the diffusion pipeline.

    Returns:
        A ``(status_message, image)`` tuple matching the two Gradio outputs.
        ``image`` is ``None`` whenever no image was generated.
    """
    # Gradio passes None for an empty image input. Without this guard,
    # np.array(None) would reach cv2.cvtColor and raise an unhandled error.
    if uploaded_image is None:
        return "Please upload a reference image.", None

    # PIL supplies RGB, while the InsightFace pipeline is fed BGR here.
    img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
    faces = face_app.get(img)
    if not faces:
        return "No face detected!", None

    # Use the largest detected face (by bounding-box area) as the identity
    # reference. Detection order is not a reliable way to pick the subject.
    def _bbox_area(face):
        x1, y1, x2, y2 = face["bbox"][:4]
        return (x2 - x1) * (y2 - y1)

    face_info = max(faces, key=_bbox_area)
    face_emb = face_info["embedding"]

    # `image_embeds` is not a StableDiffusionXLPipeline argument. It was
    # absorbed by the pipeline's **kwargs, so the face embedding was silently
    # ignored. diffusers takes precomputed IP-Adapter embeddings via
    # `ip_adapter_image_embeds`: a list with one tensor per loaded adapter.
    # With classifier-free guidance (guidance_scale > 1), each tensor holds the
    # unconditional (zero) half first, then the conditional half.
    # NOTE(review): the (2, 1, 1, emb_dim) layout follows the diffusers FaceID
    # example. The Kolors "Plus" adapter may additionally need CLIP features of
    # the face crop — confirm against the adapter's expected inputs.
    id_embeds = torch.as_tensor(face_emb, dtype=torch.float32).reshape(1, 1, 1, -1)
    id_embeds = torch.cat([torch.zeros_like(id_embeds), id_embeds]).to(device=device, dtype=dtype)

    try:
        image = pipe(
            prompt=prompt,
            ip_adapter_image_embeds=[id_embeds],
            num_inference_steps=20,
            guidance_scale=7.5,
            height=512,
            width=512,
        ).images[0]
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        return f"Generation failed: {e}", None
    return "Image generated successfully!", image

# Build the UI: a reference image and a prompt go in; a status message and
# the generated image come out, in the order generate_image returns them.
reference_input = gr.Image(type="pil", label="Upload Reference Image")
prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space")
status_output = gr.Textbox(label="Status")
image_output = gr.Image(label="Generated Image")

interface = gr.Interface(
    fn=generate_image,
    inputs=[reference_input, prompt_input],
    outputs=[status_output, image_output],
    title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
    description=(
        "Upload an image with a face, enter a prompt, and generate a new image "
        "preserving the reference face."
    ),
)

# Start the web server.
interface.launch()