File size: 3,611 Bytes
9de5dbc
 
 
 
b42faf9
9de5dbc
 
 
581454b
9de5dbc
581454b
9de5dbc
 
b42faf9
9de5dbc
 
 
581454b
e74c661
581454b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e74c661
581454b
e74c661
581454b
e74c661
9de5dbc
 
 
e74c661
 
9de5dbc
 
b42faf9
 
e74c661
 
9de5dbc
b42faf9
9de5dbc
 
b42faf9
9de5dbc
 
 
 
b42faf9
 
 
9de5dbc
 
 
 
 
 
 
e74c661
9de5dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e74c661
 
b42faf9
9de5dbc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import cv2
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionXLPipeline
from insightface.app import FaceAnalysis
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import os
import urllib.request

# Force offline mode for Hugging Face Hub (but allow InsightFace download).
# NOTE(review): huggingface_hub appears to read HF_HUB_OFFLINE when it is
# first imported, and diffusers is already imported above — confirm this
# assignment still takes effect, or move it ahead of the imports.
os.environ["HF_HUB_OFFLINE"] = "1"

# Everything runs on CPU in full precision.
device = "cpu"
dtype = torch.float32

# Set up InsightFace model directory.
# InsightFace resolves a model pack as <root>/models/<name>, so the root handed
# to FaceAnalysis must be the parent of "models", not the "models" directory
# itself; otherwise the files extracted below are never picked up.
insightface_root = "/home/user/.insightface"
insightface_model_dir = os.path.join(insightface_root, "models", "buffalo_l")
os.makedirs(insightface_model_dir, exist_ok=True)
buffalo_l_zip = "./buffalo_l.zip"

# Download buffalo_l.zip if not present
if not os.path.exists(buffalo_l_zip):
    # Download to a temporary name first so an interrupted transfer never
    # leaves a truncated buffalo_l.zip that a later run would treat as valid.
    partial_zip = buffalo_l_zip + ".part"
    try:
        print("Downloading buffalo_l.zip for InsightFace...")
        urllib.request.urlretrieve(
            "https://github.com/deepinsight/insightface/releases/download/v0.7/buffalo_l.zip",
            partial_zip,
        )
        os.replace(partial_zip, buffalo_l_zip)
        print("Download completed.")
    except Exception as e:
        print(f"Failed to download buffalo_l.zip: {e}")
        if os.path.exists(partial_zip):
            os.remove(partial_zip)
        # No lighter fallback model is available here, so abort start-up and
        # keep the original failure attached as the cause.
        raise RuntimeError(
            "Cannot download buffalo_l.zip. Please ensure network access or preload the file."
        ) from e

# Extract buffalo_l.zip
if os.path.exists(buffalo_l_zip):
    import zipfile
    with zipfile.ZipFile(buffalo_l_zip, "r") as zip_ref:
        zip_ref.extractall(insightface_model_dir)
    print("Extracted buffalo_l.zip.")

# Load face encoder (default model pack name is used; CPU provider only)
face_app = FaceAnalysis(providers=["CPUExecutionProvider"], root=insightface_root)
face_app.prepare(ctx_id=0, det_size=(480, 480))

# Define paths for preloaded weights.
# Both paths point at the current working directory and are used as plain
# string prefixes in the existence checks below (hence the trailing slash).
model_path = "./"  # Kolors base model weights
ip_adapter_path = "./"

# Check if files exist; fail fast at start-up when either weight file is missing.
if not os.path.exists(model_path + "diffusion_pytorch_model.safetensors"):
    raise FileNotFoundError(f"Kolors model weights not found at {model_path}")
if not os.path.exists(ip_adapter_path + "ip-adapter.bin"):
    raise FileNotFoundError(f"IP-Adapter weights not found at {ip_adapter_path}")

# Initialize model with empty weights.
# NOTE(review): from_pretrained is called with a hub repo id, not model_path,
# so the local file checked above is not what this call names — confirm the
# repo is available in the local cache when offline mode is active.
# NOTE(review): wrapping from_pretrained in init_empty_weights() is unusual;
# confirm the pipeline's weights are actually usable afterwards.
# NOTE(review): the checkpoint id refers to Kolors while the class is the SDXL
# pipeline — confirm this class can load that checkpoint.
with init_empty_weights():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "Kwai-Kolors/Kolors-diffusers",
        torch_dtype=dtype,
        safety_checker=None,
    )

# Load and dispatch model with accelerate, keeping everything on CPU.
# NOTE(review): load_checkpoint_and_dispatch is documented for torch modules;
# `pipe` is a pipeline object and model_path is a directory — confirm this
# call is supported.
pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
# Attach the IP-Adapter weights to the pipeline.
# NOTE(review): this names a hub repo rather than ip_adapter_path; confirm the
# repo id and weight_name exist and that subfolder=None is accepted.
pipe.load_ip_adapter("h94/IP-Adapter-FaceID-Plus-SDXL", subfolder=None, weight_name="ip-adapter.bin")

def _face_area(face):
    """Return the bounding-box area of one detected face."""
    x1, y1, x2, y2 = face["bbox"][:4]
    return (x2 - x1) * (y2 - y1)


def generate_image(uploaded_image, prompt):
    """Generate an image from *prompt* conditioned on a face in *uploaded_image*.

    Args:
        uploaded_image: Reference picture in RGB channel order (anything
            ``np.array`` can convert), or ``None`` when nothing was uploaded.
        prompt: Text prompt forwarded to the diffusion pipeline.

    Returns:
        A ``(status, image)`` tuple. ``image`` is ``None`` whenever no picture
        was supplied, no face was detected, or the pipeline raised.
    """
    # The UI can submit without an upload; converting None would crash below.
    if uploaded_image is None:
        return "Please upload a reference image!", None

    # The face detector is fed BGR data, so swap the channel order first.
    img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
    faces = face_app.get(img)
    if not faces:
        return "No face detected!", None

    # With several faces in the picture, use the most prominent (largest) one
    # instead of whichever detection happens to be listed last.
    face_info = max(faces, key=_face_area)
    face_emb = face_info["embedding"]

    # Any pipeline failure is reported to the UI as a status string rather
    # than surfacing as an unhandled error.
    try:
        # NOTE(review): confirm that the pipeline accepts an `image_embeds`
        # keyword and a raw embedding array in this form; diffusers' IP-Adapter
        # examples pass `ip_adapter_image_embeds` with batched tensors.
        image = pipe(
            prompt=prompt,
            image_embeds=face_emb,
            num_inference_steps=20,
            guidance_scale=7.5,
            height=512,
            width=512,
        ).images[0]
        return "Image generated successfully!", image
    except Exception as e:
        return f"Generation failed: {e}", None

# Gradio front end: one reference image and one text prompt go in, a status
# line and the generated picture come out.
_reference_input = gr.Image(type="pil", label="Upload Reference Image")
_prompt_input = gr.Textbox(
    label="Enter Prompt",
    placeholder="e.g., A photorealistic astronaut in space",
)
_status_output = gr.Textbox(label="Status")
_image_output = gr.Image(label="Generated Image")

interface = gr.Interface(
    fn=generate_image,
    inputs=[_reference_input, _prompt_input],
    outputs=[_status_output, _image_output],
    title="Face Reference Image Generator (Kolors with IP-Adapter)",
    description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face.",
)

# Start the web server for the app.
interface.launch()