Spaces:
Sleeping
Sleeping
File size: 9,326 Bytes
4eca20b 9eb3654 a8982fd 1768e52 9eb3654 5ffb4e8 f4d8e9b 5ffb4e8 0814d79 84183f6 1d0cad2 9eb3654 4eca20b 8b9cf54 a22d1ce ce2da6f 401d1e2 a22d1ce 0b078c0 a22d1ce 5c61b5f a22d1ce 9eb3654 a22d1ce 0814d79 5e59917 0814d79 a22d1ce f4d8e9b 9eb3654 a22d1ce f4d8e9b a22d1ce 9eb3654 eeaeb90 a22d1ce e67c39b a22d1ce c549417 a22d1ce c549417 a22d1ce 9eb3654 757482b 9eb3654 a22d1ce 5ffb4e8 8b9cf54 5ffb4e8 9eb3654 c437b94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
import spaces
import gradio as gr
import numpy as np
import torch
import os, shutil
import tempfile
from pulid import attention_processor as attention
from pulid.pipeline import PuLIDPipeline
from pulid.utils import resize_numpy_image_long, seed_everything
import boto3
from botocore.exceptions import ClientError
from PIL import Image
from google import genai
from google.genai import types
from io import BytesIO
# Inference-only service: switch autograd off for the whole process.
torch.set_grad_enabled(mode=False)
# One PuLID pipeline, constructed at import time and reused by run().
pipeline = PuLIDPipeline()
# other params
DEFAULT_NEGATIVE_PROMPT = (
'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,'
'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, '
'low resolution, partially rendered objects, deformed or partially rendered eyes, '
'deformed, deformed eyeballs, cross-eyed,blurry'
)
# S3 bucket holding both the uploaded reference photos and the generated headshots.
_BUCKET = 'syntheticai-headshots'
# NOTE(review): an AWS access key ID is committed to source. The secret half is
# read from the environment; set AMAZON_ACCESS_KEY_ID to override this fallback
# and then drop the literal.
_DEFAULT_AWS_ACCESS_KEY_ID = 'AKIA2NMAMYX4K55CZ7HR'
# Lower-case file extensions accepted as reference photos.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')


def _build_prompt(person_name, gender):
    """Return the text-to-image prompt for a headshot of *person_name*.

    *gender* equal to 'female' (case-insensitive, surrounding whitespace
    ignored) selects the female wording; anything else, including None or '',
    selects the male wording.
    """
    if (gender or '').strip().lower() == 'female':
        subject, attire = 'woman', 'a smart casual'
    else:
        subject, attire = 'man', 'a formal suit and white shirt'
    return (
        f'Professional LinkedIn-style realistic headshot for {subject} named {person_name}, '
        'symmetrical full face and upper body visible including shoulders and chest, '
        'centered composition with a small space above the head, '
        f'wearing {attire}, neutral expression, captured from a short distance, '
        'realistic skin texture, exact face preserved, plain white or gray background, '
        'sharp focus, studio lighting, high-resolution, suitable for CV or resume'
    )


def _configure_attention(mode):
    """Set PuLID's module-level attention switches for *mode*.

    Raises:
        ValueError: if *mode* is not 'fidelity' or 'extremely style'.
    """
    if mode == 'fidelity':
        attention.NUM_ZERO, attention.ORTHO, attention.ORTHO_v2 = 8, False, True
    elif mode == 'extremely style':
        attention.NUM_ZERO, attention.ORTHO, attention.ORTHO_v2 = 16, True, False
    else:
        raise ValueError(f'unknown mode: {mode!r}')


def _download_inputs(s3_client, prefix, local_dir):
    """Download the reference photos stored under S3 *prefix* into *local_dir*.

    Only objects whose base name ends with one of ``_IMAGE_EXTENSIONS``
    (case-insensitive) are fetched. "Folder" placeholder keys ending in '/'
    have an empty base name and are skipped. Objects that share a base name
    overwrite each other locally.

    Returns:
        Local file paths sorted by file name, so the primary ID image (the
        first path) is deterministic.
    """
    import posixpath  # S3 keys always use '/', whatever the local OS

    # NOTE(review): one list_objects_v2 call returns at most 1000 keys; switch
    # to a paginator if a request can ever hold more objects than that.
    response = s3_client.list_objects_v2(Bucket=_BUCKET, Prefix=prefix)
    for obj in response.get('Contents', []):
        file_name = posixpath.basename(obj['Key'])
        if not file_name.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        s3_client.download_file(_BUCKET, obj['Key'], os.path.join(local_dir, file_name))
    return sorted(os.path.join(local_dir, name) for name in os.listdir(local_dir))


def _load_rgb_array(path):
    """Load *path* as an RGB numpy array passed through resize_numpy_image_long(…, 1024).

    ``convert('RGB')`` matters for the accepted PNG/GIF/BMP inputs. Without it,
    an RGBA PNG yields 4 channels and a palette-mode GIF yields a 2-D array of
    palette indices. The file handle is closed before returning.
    """
    with Image.open(path) as im:
        return resize_numpy_image_long(np.array(im.convert('RGB')), 1024)


def _gemini_headshots(prompt, id_image_path):
    """Best-effort: ask Gemini to turn the reference photo into a headshot.

    Text parts of the response are printed. Image parts are returned as PIL images.
    Any exception (client construction, the API call, an unexpected response
    shape) is logged, and whatever images were collected so far are returned.
    This keeps the PuLID samples produced even when Gemini is unavailable.
    """
    results = []
    try:
        client = genai.Client(api_key=os.getenv('GEMINI_SECRET_KEY'))
        with Image.open(id_image_path) as reference:
            response = client.models.generate_content(
                model="gemini-2.5-flash-image-preview",
                contents=[prompt, reference],
            )
        for part in response.candidates[0].content.parts:
            if part.text is not None:
                print(part.text)
            elif part.inline_data is not None:
                results.append(Image.open(BytesIO(part.inline_data.data)))
    except Exception as e:  # deliberately broad: this step is optional
        print('Error while calling GEMINI:', e)
    return results


def _compute_id_embeddings(image_paths, id_mix):
    """Return PuLID ID embeddings for *image_paths*, concatenated along dim 1.

    The first path is the primary ID image. For each supplementary image, the
    full embedding is appended when *id_mix* is true. Otherwise only the slice
    ``[:, :5]`` (first 5 entries along dim 1) is appended.
    """
    id_embeddings = pipeline.get_id_embedding(_load_rgb_array(image_paths[0]))
    for path in image_paths[1:]:
        supp = pipeline.get_id_embedding(_load_rgb_array(path))
        id_embeddings = torch.cat((id_embeddings, supp if id_mix else supp[:, :5]), dim=1)
    return id_embeddings


def _save_as_png(images, out_dir):
    """Write each image to a uniquely named ``tmp*.png`` file inside *out_dir*.

    PIL images and numpy arrays are written. Anything else is skipped. The
    written paths are returned in input order. Each file handle is closed after
    writing, and *out_dir* itself is owned and removed by the caller.
    """
    paths = []
    for img in images:
        if isinstance(img, torch.Tensor):
            # torchvision is not imported in this file, so tensors go through
            # numpy. Nothing in run() produces tensors today; this assumes an
            # HWC uint8 layout. TODO: confirm if tensors are ever added.
            img = img.detach().cpu().numpy()
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif not isinstance(img, Image.Image):
            continue  # skip unknown formats
        # delete=False: the file must outlive the handle so it can be uploaded.
        with tempfile.NamedTemporaryFile(dir=out_dir, suffix='.png', delete=False) as handle:
            img.save(handle, format='PNG')
        paths.append(handle.name)
    return paths


def _upload_outputs(s3_client, paths, output_prefix):
    """Upload each PNG in *paths* to ``<output_prefix>/<file name>`` in the bucket."""
    import posixpath  # S3 keys always use '/', whatever the local OS

    for local_file in paths:
        s3_key = posixpath.join(output_prefix, os.path.basename(local_file))
        print('S3 key is ', s3_key)
        s3_client.upload_file(
            local_file, _BUCKET, s3_key, ExtraArgs={'ContentType': 'image/png'}
        )


@spaces.GPU
def run(bucket_folder, person_name, req_id, gender='male'):
    """Generate AI headshots for one request and upload them to S3.

    The photos under *bucket_folder* are downloaded. One or more Gemini headshots
    are requested from the first photo, then 20 PuLID samples are generated from
    ID embeddings of all the photos. Every image is uploaded as PNG to the sibling
    ``output`` prefix.

    NOTE(review): all arguments come straight from public text boxes, so any
    caller can read from and write to arbitrary prefixes in the bucket.

    Args:
        bucket_folder: S3 prefix of the reference photos, e.g.
            'user_id/request_id/input/' (trailing slash optional). Outputs go
            to 'user_id/request_id/output'.
        person_name: Name interpolated into the prompt.
        req_id: Request identifier. It is caller-supplied text, so it is
            deliberately not used as a local path (a value such as '.' would
            make directory cleanup delete the working directory).
        gender: 'female' (case-insensitive) selects the female prompt; anything
            else selects the male prompt.

    Returns:
        ``(ims, pipeline.debug_img_list)`` on success, where *ims* holds the
        Gemini image(s) first and then the PuLID samples. ``False`` if
        uploading to S3 fails.

    Raises:
        ValueError: if *bucket_folder* is blank, or contains no images with a
            supported extension.
    """
    import posixpath  # S3 keys always use '/', whatever the local OS
    # upload_file wraps botocore's ClientError in this exception type.
    from boto3.exceptions import S3UploadFailedError

    if not bucket_folder or not bucket_folder.strip('/ '):
        # An empty prefix would list (and download) the entire bucket.
        raise ValueError('bucket_folder must be a non-empty S3 prefix')

    s3_client = boto3.client(
        's3',
        aws_access_key_id=os.getenv('AMAZON_ACCESS_KEY_ID', _DEFAULT_AWS_ACCESS_KEY_ID),
        aws_secret_access_key=os.getenv('AMAZON_SECRET_KEY'),
    )
    # 'user_id/request_id/input/' -> 'user_id/request_id/output'. Stripping the
    # trailing slash first gives the same result with or without it.
    output_prefix = posixpath.join(posixpath.dirname(bucket_folder.rstrip('/')), 'output')

    prompt = _build_prompt(person_name, gender)
    neg_prompt = 'flaws in the eyes, wide chin, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,artifacts noise, glitch, deformed, disfigured, partially rendered objects, deformed or partially rendered eyes'
    scale = 1.2
    n_samples = 20
    seed = 0
    steps = 4
    H = 1024
    W = 768
    id_scale = 0.8
    mode = 'fidelity'
    id_mix = False

    # Scratch directories are private to this call and are removed on every
    # exit path, including exceptions and the upload-failure return.
    with tempfile.TemporaryDirectory() as input_dir, tempfile.TemporaryDirectory() as output_dir:
        image_paths = _download_inputs(s3_client, bucket_folder, input_dir)
        if not image_paths:
            raise ValueError(f'No images found under s3://{_BUCKET}/{bucket_folder}')

        pipeline.debug_img_list = []
        _configure_attention(mode)

        # Gemini result(s) first, then the PuLID samples.
        ims = _gemini_headshots(prompt, image_paths[0])
        id_embeddings = _compute_id_embeddings(image_paths, id_mix)
        seed_everything(seed)
        for _ in range(n_samples):
            img = pipeline.inference(prompt, (1, H, W), neg_prompt, id_embeddings, id_scale, scale, steps)[0]
            ims.append(np.array(img))

        file_paths = _save_as_png(ims, output_dir)
        try:
            _upload_outputs(s3_client, file_paths, output_prefix)
        except (ClientError, S3UploadFailedError) as e:
            print('e', e)
            print('ERROR OCCURRED while uploading data to S3')
            # NOTE(review): run() is wired to two Gradio outputs; a bare False
            # does not fill them. Consider raising instead.
            return False
    return ims, pipeline.debug_img_list
_HEADER_ = '''
<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>
**PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.
Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.
❗️❗️❗️**Tips:**
- we provide some examples in the bottom, you can try these example prompts first
- a single ID image is usually sufficient, you can also supplement with additional auxiliary images
- We offer two modes: fidelity mode and extremely style mode. In most cases, the default fidelity mode should suffice. If you find that the generated results are not stylized enough, you can choose the extremely style mode.
''' # noqa E501
# Citation / license footer text (Markdown mixed with HTML), apparently taken
# from the upstream PuLID demo.
# NOTE(review): not referenced by the Gradio UI in this file. Two issues
# should be fixed before it is ever displayed:
# - the LICENSE link target is the literal "placeholder";
# - the GitHub link on the first line has empty link text.
# TODO: restore the intended URL/badge.
_CITE_ = r"""
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [](https://github.com/ToTheBeginning/PuLID)
---
🚀 **Share**
If you have generated satisfying or interesting images with PuLID, please share them with us or your friends!
📝 **Citation**
If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@article{guo2024pulid,
title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
journal={arXiv preprint arXiv:2404.16022},
year={2024}
}
```
📋 **License**
Apache-2.0 LICENSE. Please refer to the [LICENSE file](placeholder) for details.
📧 **Contact**
If you have any questions, feel free to open a discussion or contact us at <b>wuyanze123@gmail.com</b> or <b>guozinan.1@bytedance.com</b>.
""" # noqa E501
# Gradio front end for run(). The four text boxes map positionally onto
# run(bucket_folder, person_name, req_id, gender). The labels only make the form
# and the auto-generated API page self-describing; the argument order is unchanged.
# The only application built and launched is this Interface, so no enclosing
# gr.Blocks is needed: its title/CSS would never be used.
output = gr.Gallery(label='Output', elem_id="gallery")
# Receives pipeline.debug_img_list. Hidden, and given its own elem_id so the two
# galleries no longer share one DOM id.
intermediate_output = gr.Gallery(label='DebugImage', elem_id="debug_gallery", visible=False)
demo = gr.Interface(
    run,
    inputs=[
        gr.Textbox(label='bucket_folder'),
        gr.Textbox(label='person_name'),
        gr.Textbox(label='req_id'),
        gr.Textbox(label='gender', value='male'),
    ],
    outputs=[output, intermediate_output],
    title="AI Headshot Diffusor",
)
demo.launch()
|