File size: 5,449 Bytes
9de5dbc
 
 
 
f2f8f12
9de5dbc
278fdd9
f52c5b7
9b003e1
a3c77b4
9de5dbc
a1bd508
9b003e1
 
 
 
7a5105f
9de5dbc
9b003e1
9de5dbc
9b003e1
 
 
 
 
9de5dbc
9b003e1
a1bd508
68c6e14
 
 
9b003e1
68c6e14
a1bd508
 
9de5dbc
a3c77b4
 
9b003e1
 
a3c77b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b003e1
 
 
 
 
e74c661
9b003e1
9de5dbc
f4836de
a1bd508
9b003e1
 
 
 
f52c5b7
 
f2f8f12
 
9b003e1
a3c77b4
 
 
f2f8f12
 
 
a3c77b4
 
 
 
 
 
 
f2f8f12
a3c77b4
 
 
 
 
 
 
f2f8f12
9b003e1
f2f8f12
9b003e1
 
f52c5b7
9b003e1
 
 
a1bd508
9b003e1
 
 
9de5dbc
f52c5b7
a1bd508
f52c5b7
a1bd508
9de5dbc
 
a1bd508
9b003e1
 
 
 
a1bd508
9b003e1
9de5dbc
9b003e1
 
9de5dbc
9b003e1
9de5dbc
 
 
c701a24
9de5dbc
c701a24
a1bd508
9e558cc
a1bd508
a383ea5
 
9b003e1
a383ea5
 
f52c5b7
a383ea5
 
f52c5b7
 
 
 
 
 
 
 
9b003e1
 
a383ea5
 
a1bd508
f52c5b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import cv2
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionPipeline  # Use SD 2.1 instead of SDXL
from insightface.app import FaceAnalysis
from huggingface_hub import hf_hub_download
import os
import logging
import time

# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Allow network access (Hugging Face hub offline flag set to "0").
os.environ["HF_HUB_OFFLINE"] = "0"

# Everything in this script runs on the CPU with float32 tensors.
device, dtype = "cpu", torch.float32

# Directory handed to the download/loading calls as their cache location;
# created here if it does not exist yet.
cache_dir = "./cache"
os.makedirs(cache_dir, exist_ok=True)

# Face encoder: InsightFace restricted to the CPU execution provider.
# Any initialization failure is logged and then re-raised unchanged.
logger.info("Starting InsightFace initialization...")
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    face_app.prepare(ctx_id=0, det_size=(480, 480))
except Exception as e:
    logger.error(f"Failed to load InsightFace model: {e}")
    raise
else:
    logger.info("InsightFace model loaded successfully.")

# Download function with retry logic
def download_file(repo_id: str, filename: str, local_dir: str, max_retries: int = 3) -> str:
    """Return a local path for ``filename``, downloading it if it is missing.

    If ``local_dir/filename`` already exists, that path is returned without
    touching the network. Otherwise ``hf_hub_download`` is attempted up to
    ``max_retries`` times, sleeping 5 seconds between failed attempts.

    Args:
        repo_id: Hugging Face repository to download from.
        filename: Name of the file inside the repository.
        local_dir: Directory the file is looked up in and downloaded to.
        max_retries: Maximum number of download attempts; must be >= 1 when
            a download is actually required.

    Returns:
        The existing ``local_dir/filename`` path, or the path returned by
        ``hf_hub_download``.

    Raises:
        ValueError: If the file is missing and ``max_retries`` is < 1
            (previously this case silently returned ``None``).
        RuntimeError: If every attempt fails; the last error is chained as
            the cause.
    """
    file_path = os.path.join(local_dir, filename)
    if os.path.exists(file_path):
        logger.info(f"Using cached file at {file_path}")
        return file_path

    # Without at least one attempt the loop below would never run and the
    # caller would receive None instead of a path.
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(1, max_retries + 1):
        logger.info(f"Attempt {attempt}/{max_retries}: Downloading {filename} from {repo_id} to {local_dir}...")
        try:
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=local_dir,
                cache_dir=cache_dir,
                local_files_only=False
            )
        except Exception as e:
            logger.error(f"Download attempt {attempt} failed: {e}")
            if attempt == max_retries:
                # Out of attempts: surface the failure with its cause attached.
                raise RuntimeError(f"Failed to download {filename} after {max_retries} attempts: {e}") from e
            logger.info("Retrying in 5 seconds...")
            time.sleep(5)
        else:
            logger.info(f"Downloaded to {downloaded_path}")
            return downloaded_path

# Define paths
ip_adapter_path = "./"
os.makedirs(ip_adapter_path, exist_ok=True)

# Download IP-Adapter weights with retries
# NOTE(review): these weights come from a Kolors repository while the base
# model loaded below is SD 2.1 — confirm the two are actually compatible.
logger.info("Starting weights download...")
ip_adapter_weights = download_file(
    "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
    "ipa-faceid-plus.bin",
    ip_adapter_path
)

# Load the pipeline with SD 2.1, retrying up to max_retries times with a
# 5 second pause between failed attempts.
logger.info("Loading Stable Diffusion 2.1 base model...")
try:
    max_retries = 3
    for attempt in range(max_retries):
        try:
            logger.info(f"Attempt {attempt + 1}/{max_retries}: Loading SD 2.1 model...")
            # NOTE(review): the fp16 weight variant is requested while
            # torch_dtype is float32 — confirm this combination is intended.
            pipe = StableDiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-1",
                torch_dtype=dtype,
                safety_checker=None,
                local_files_only=False,
                cache_dir=cache_dir,
                variant="fp16",
                use_safetensors=True
            )
            logger.info("SD 2.1 base model loaded successfully.")
            break
        except Exception as e:
            logger.error(f"Load attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                logger.info("Retrying in 5 seconds...")
                time.sleep(5)
            else:
                # Chain the last failure explicitly so the traceback shows it
                # as the direct cause of the RuntimeError.
                raise RuntimeError(f"Failed to load SD 2.1 model after {max_retries} attempts: {e}") from e
except Exception as e:
    # Final summary log once all retries are exhausted; the error still
    # propagates and aborts start-up.
    logger.error(f"Failed to load SD 2.1 base model: {e}")
    raise

# Load IP-Adapter from the directory the weights were downloaded into.
logger.info(f"Loading IP-Adapter from {ip_adapter_weights}...")
try:
    pipe.load_ip_adapter(ip_adapter_path, subfolder=None, weight_name="ipa-faceid-plus.bin")
    logger.info("IP-Adapter loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load IP-Adapter: {e}")
    raise

# Move pipeline to CPU
logger.info("Moving pipeline to CPU...")
pipe.to(device)
logger.info("Pipeline moved to CPU.")

def generate_image(uploaded_image, prompt):
    """Generate an image conditioned on a face found in ``uploaded_image``.

    The upload is converted to a BGR array, passed to the InsightFace
    detector, and the embedding of the last detected face is handed to the
    diffusion pipeline together with ``prompt``.

    Args:
        uploaded_image: Reference picture from the image input, interpreted
            as RGB. ``None`` (nothing uploaded) is reported as a status
            message instead of failing inside the conversion.
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        A ``(status_message, image)`` tuple. ``image`` is ``None`` when no
        image was supplied, no face was detected, or generation failed.
    """
    logger.info("Starting image generation...")

    # Nothing was uploaded: answer with a clear message rather than letting
    # the array conversion below fail with an opaque error.
    if uploaded_image is None:
        logger.warning("No image uploaded.")
        return "Please upload an image first!", None

    try:
        # The detector is fed a BGR array, so convert from RGB first.
        img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
        faces = face_app.get(img)
        if not faces:
            logger.warning("No face detected in uploaded image.")
            return "No face detected!", None

        # When several faces are found, the last one returned is used.
        face_info = faces[-1]
        face_emb = face_info["embedding"]

        logger.info(f"Generating image with prompt: {prompt}")
        # NOTE(review): diffusers documents IP-Adapter embeddings under the
        # `ip_adapter_image_embeds` argument — confirm that `image_embeds`
        # is actually consumed by this pipeline and not silently ignored.
        image = pipe(
            prompt=prompt,
            image_embeds=face_emb,
            num_inference_steps=10,
            guidance_scale=7.5,
            height=256,
            width=256
        ).images[0]
        logger.info("Image generated successfully.")
        return "Image generated successfully!", image
    except Exception as e:
        # UI boundary: report the failure as a status string, but keep the
        # full traceback in the log for debugging.
        logger.exception(f"Generation failed: {e}")
        return f"Generation failed: {e}", None

# Gradio interface: a reference image and a text prompt go in; a status
# message and the generated image come out. Components are created in the
# same order as they appear in the UI definition (inputs first, then outputs).
_inputs = [
    gr.Image(type="pil", label="Upload Reference Image"),
    gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space"),
]
_outputs = [
    gr.Textbox(label="Status"),
    gr.Image(label="Generated Image"),
]

interface = gr.Interface(
    fn=generate_image,
    inputs=_inputs,
    outputs=_outputs,
    title="Face Reference Image Generator",
    description="Upload an image with a face and generate a new image.",
)

# Start serving the UI.
logger.info("Launching Gradio interface...")
interface.launch()