srijaydeshpande's picture
Update app.py
5e59917 verified
import spaces
import gradio as gr
import numpy as np
import torch
import os, shutil
import tempfile
from pulid import attention_processor as attention
from pulid.pipeline import PuLIDPipeline
from pulid.utils import resize_numpy_image_long, seed_everything
import boto3
from botocore.exceptions import ClientError
from PIL import Image
from google import genai
from google.genai import types
from io import BytesIO
# Inference only: disable autograd globally so no gradient graphs are recorded.
torch.set_grad_enabled(False)
# Single module-level PuLID pipeline, constructed once at import time and
# shared by every call to `run` (which also resets its `debug_img_list`).
pipeline = PuLIDPipeline()
# Module-level default negative prompt (comma-separated terms to steer the
# diffusion model away from). The two comma-without-space joins
# ("worst quality,artifacts noise" and "cross-eyed,blurry") are kept verbatim.
DEFAULT_NEGATIVE_PROMPT = (
    'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, '
    'low quality, worst quality,artifacts noise, text, watermark, glitch, '
    'deformed, mutated, ugly, disfigured, hands, low resolution, '
    'partially rendered objects, deformed or partially rendered eyes, '
    'deformed, deformed eyeballs, cross-eyed,blurry'
)
# S3 bucket holding both the input photos and the generated headshots.
_BUCKET = 'syntheticai-headshots'
# Lower-case file extensions treated as input photos after download.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')


def _download_inputs(s3_client, prefix, local_dir):
    """Download every object under ``prefix`` in ``_BUCKET`` into ``local_dir``.

    Returns the sorted list of downloaded image paths. Sorting makes the
    choice of the primary ID image (the first entry) deterministic.
    """
    response = s3_client.list_objects_v2(Bucket=_BUCKET, Prefix=prefix)
    contents = response.get('Contents', [])
    if not contents:
        print("No files found in that folder.")
    for obj in contents:
        s3_key = obj['Key']
        if s3_key.endswith('/'):  # S3 "folder" placeholder objects
            continue
        # basename() keeps downloads inside local_dir whatever the key looks like.
        local_path = os.path.join(local_dir, os.path.basename(s3_key))
        s3_client.download_file(_BUCKET, s3_key, local_path)
    return sorted(
        os.path.join(local_dir, name)
        for name in os.listdir(local_dir)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )


def _build_prompt(person_name, gender):
    """Return the headshot prompt; ``'female'`` (case-insensitive) selects the female wording."""
    if (gender or '').strip().lower() == 'female':
        subject, attire = 'woman', 'a smart casual'
    else:
        subject, attire = 'man', 'a formal suit and white shirt'
    return (
        'Professional LinkedIn-style realistic headshot for ' + subject + ' named ' + person_name
        + ', symmetrical full face and upper body visible including shoulders and chest, '
        'centered composition with a small space above the head, wearing ' + attire
        + ', neutral expression, captured from a short distance, realistic skin texture, '
        'exact face preserved, plain white or gray background, sharp focus, studio lighting, '
        'high-resolution, suitable for CV or resume'
    )


def _configure_attention(mode):
    """Set the PuLID attention-processor module globals for ``mode``.

    Raises:
        ValueError: if ``mode`` is neither ``'fidelity'`` nor ``'extremely style'``.
    """
    if mode == 'fidelity':
        attention.NUM_ZERO = 8
        attention.ORTHO = False
        attention.ORTHO_v2 = True
    elif mode == 'extremely style':
        attention.NUM_ZERO = 16
        attention.ORTHO = True
        attention.ORTHO_v2 = False
    else:
        raise ValueError(f'Unknown mode: {mode!r}')


def _gemini_headshots(prompt, id_image_path):
    """Ask Gemini for headshot(s) of the person in ``id_image_path``.

    Best-effort: any failure (missing key, API error, empty response) is
    printed and swallowed so the PuLID samples are still produced. Returns a
    possibly empty list of PIL images. Text parts of the response are printed.
    """
    results = []
    try:
        client = genai.Client(api_key=os.getenv('GEMINI_SECRET_KEY'))
        with Image.open(id_image_path) as id_image:
            response = client.models.generate_content(
                model="gemini-2.5-flash-image-preview",
                contents=[prompt, id_image],
            )
        for part in response.candidates[0].content.parts:
            if part.text is not None:
                print(part.text)
            elif part.inline_data is not None:
                results.append(Image.open(BytesIO(part.inline_data.data)))
    except Exception as exc:  # deliberate best-effort boundary around an external service
        print('Error while calling GEMINI:', exc)
    return results


def _compute_id_embeddings(image_paths, id_mix):
    """Return PuLID ID embeddings for ``image_paths`` (non-empty).

    The first image's embedding is concatenated along dim 1 with each
    supplementary image's embedding. When ``id_mix`` is False, only the first
    5 entries along dim 1 of each supplementary embedding are kept.
    """
    def embed(path):
        with Image.open(path) as im:
            array = np.array(im)
        return pipeline.get_id_embedding(resize_numpy_image_long(array, 1024))

    id_embeddings = embed(image_paths[0])
    for path in image_paths[1:]:
        supp = embed(path)
        id_embeddings = torch.cat((id_embeddings, supp if id_mix else supp[:, :5]), dim=1)
    return id_embeddings


def _to_pil(img):
    """Convert a generated image to ``PIL.Image.Image``; return None for unsupported types.

    Tensors are assumed to be CHW (or HW), as torchvision's ``ToPILImage``
    expects. Float tensors are assumed to lie in [0, 1].
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    if isinstance(img, torch.Tensor):
        tensor = img.detach().cpu()
        if tensor.is_floating_point():
            tensor = tensor.mul(255).clamp(0, 255)
        array = tensor.to(torch.uint8).numpy()
        if array.ndim == 3:
            array = np.ascontiguousarray(np.transpose(array, (1, 2, 0)))  # CHW -> HWC
            if array.shape[2] == 1:
                array = array[:, :, 0]
        return Image.fromarray(array)
    return None


def _save_images(images, output_dir):
    """Save each convertible image as a uniquely named PNG in ``output_dir``; return the paths."""
    paths = []
    for img in images:
        pil_img = _to_pil(img)
        if pil_img is None:
            continue  # skip unknown formats
        fd, path = tempfile.mkstemp(suffix='.png', dir=output_dir)
        os.close(fd)  # only the unique name is needed; PIL reopens the file itself
        pil_img.save(path)
        paths.append(path)
    return paths


def _output_prefix(bucket_folder):
    """Map an input prefix like ``'user_id/request_id/input/'`` to ``'user_id/request_id/output'``.

    A trailing slash on ``bucket_folder`` is optional.
    """
    parent = bucket_folder.rstrip('/').rpartition('/')[0]
    return f'{parent}/output' if parent else 'output'


def _upload_outputs(s3_client, prefix, file_paths):
    """Upload ``file_paths`` to ``_BUCKET`` under ``prefix``; return False on failure."""
    try:
        for local_file in file_paths:
            s3_key = f'{prefix}/{os.path.basename(local_file)}'
            print('S3 key is ', s3_key)
            s3_client.upload_file(local_file, _BUCKET, s3_key, ExtraArgs={
                'ContentType': 'image/png',  # files are written as PNG by _save_images
            })
    # upload_file wraps transfer errors in S3UploadFailedError, not ClientError.
    except (ClientError, boto3.exceptions.S3UploadFailedError) as e:
        print('e', e)
        print('ERROR OCCURRED while uploading data to S3')
        return False
    return True


@spaces.GPU
def run(bucket_folder, person_name, req_id, gender='male'):
    """Generate headshots for the photos under an S3 prefix and upload them.

    Args:
        bucket_folder: S3 key prefix holding the input photos, e.g.
            ``'user_id/request_id/input/'``. Results are uploaded under the
            sibling ``'user_id/request_id/output/'`` prefix.
        person_name: Name inserted into the generation prompt.
        req_id: Request identifier. Kept for API compatibility. It is not used
            as a local path because it is caller-supplied.
        gender: ``'female'`` (case-insensitive) selects the female prompt;
            anything else selects the male prompt.

    Returns:
        ``(images, pipeline.debug_img_list)`` for the two Gradio galleries on
        success, or ``False`` if uploading the results to S3 failed.

    Raises:
        gr.Error: if no input images were found under ``bucket_folder``.
    """
    # Sampling configuration.
    neg_prompt = (
        'flaws in the eyes, wide chin, flaws in the face, flaws, lowres, non-HDRi, '
        'low quality, worst quality,artifacts noise, glitch, deformed, disfigured, '
        'partially rendered objects, deformed or partially rendered eyes'
    )
    scale = 1.2
    n_samples = 20
    seed = 0
    steps = 4
    H = 1024
    W = 768
    id_scale = 0.8
    mode = 'fidelity'
    id_mix = False

    s3_client = boto3.client(
        's3',
        # TODO: set AMAZON_ACCESS_KEY_ID in the Space secrets and drop this
        # hard-coded fallback so no credential material lives in source.
        aws_access_key_id=os.getenv('AMAZON_ACCESS_KEY_ID', 'AKIA2NMAMYX4K55CZ7HR'),
        aws_secret_access_key=os.getenv('AMAZON_SECRET_KEY'),
    )

    # Private scratch space: removed automatically even if anything below raises.
    with tempfile.TemporaryDirectory(prefix='headshots_') as work_dir:
        input_dir = os.path.join(work_dir, 'input')
        output_dir = os.path.join(work_dir, 'output')
        os.makedirs(input_dir)
        os.makedirs(output_dir)

        image_paths = _download_inputs(s3_client, bucket_folder, input_dir)
        if not image_paths:
            raise gr.Error(f'No input images found under S3 prefix {bucket_folder!r}.')

        prompt = _build_prompt(person_name, gender)
        pipeline.debug_img_list = []
        _configure_attention(mode)

        # Gemini results (if any) come first in the gallery, followed by PuLID samples.
        ims = _gemini_headshots(prompt, image_paths[0])

        id_embeddings = _compute_id_embeddings(image_paths, id_mix)
        seed_everything(seed)
        for _ in range(n_samples):
            img = pipeline.inference(prompt, (1, H, W), neg_prompt, id_embeddings, id_scale, scale, steps)[0]
            ims.append(np.array(img))

        file_paths = _save_images(ims, output_dir)
        # NOTE(review): returning False does not match the two Gradio outputs;
        # kept for backward compatibility with existing callers.
        if not _upload_outputs(s3_client, _output_prefix(bucket_folder), file_paths):
            return False

    return ims, pipeline.debug_img_list
_HEADER_ = '''
<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>
**PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.
Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.
❗️❗️❗️**Tips:**
- we provide some examples in the bottom, you can try these example prompts first
- a single ID image is usually sufficient, you can also supplement with additional auxiliary images
- We offer two modes: fidelity mode and extremely style mode. In most cases, the default fidelity mode should suffice. If you find that the generated results are not stylized enough, you can choose the extremely style mode.
''' # noqa E501
_CITE_ = r"""
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)
---
🚀 **Share**
If you have generated satisfying or interesting images with PuLID, please share them with us or your friends!
📝 **Citation**
If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@article{guo2024pulid,
title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
journal={arXiv preprint arXiv:2404.16022},
year={2024}
}
```
📋 **License**
Apache-2.0 LICENSE. Please refer to the [LICENSE file](placeholder) for details.
📧 **Contact**
If you have any questions, feel free to open a discussion or contact us at <b>wuyanze123@gmail.com</b> or <b>guozinan.1@bytedance.com</b>.
""" # noqa E501
with gr.Blocks(title="AI headshot Generator", css=".gr-box {border-color: #8136e2}") as demo:
output = gr.Gallery(label='Output', elem_id="gallery")
intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
# submit = gr.Button("Generate")
# submit.click(fn=run, inputs=['textbox', 'textbox', 'textbox'], outputs=[output, intermediate_output])
demo = gr.Interface(
run,
inputs=['textbox', 'textbox', 'textbox', 'textbox'],
outputs=[output, intermediate_output],
title="AI Headshot Diffusor"
)
demo.launch()