Spaces:
Sleeping
Sleeping
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import os, shutil | |
| import tempfile | |
| from pulid import attention_processor as attention | |
| from pulid.pipeline import PuLIDPipeline | |
| from pulid.utils import resize_numpy_image_long, seed_everything | |
| import boto3 | |
| from botocore.exceptions import ClientError | |
| from PIL import Image | |
| from google import genai | |
| from google.genai import types | |
| from io import BytesIO | |
# Inference-only service: switch autograd off process-wide before any model
# weights are created, then build the single PuLID pipeline that `run` reuses
# for every request.
torch.autograd.set_grad_enabled(False)
pipeline = PuLIDPipeline()
# Generic PuLID negative prompt. Note: `run` in this file passes its own
# headshot-specific negative prompt, so this default is not used there.
DEFAULT_NEGATIVE_PROMPT = (
    'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, '
    'low quality, worst quality,artifacts noise, text, watermark, glitch, '
    'deformed, mutated, ugly, disfigured, hands, low resolution, '
    'partially rendered objects, deformed or partially rendered eyes, '
    'deformed, deformed eyeballs, cross-eyed,blurry'
)
# Image file types picked up from the downloaded input folder.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

# Negative prompt used for every PuLID sample generated by `run`.
_HEADSHOT_NEGATIVE_PROMPT = (
    'flaws in the eyes, wide chin, flaws in the face, flaws, lowres, non-HDRi, '
    'low quality, worst quality,artifacts noise, glitch, deformed, disfigured, '
    'partially rendered objects, deformed or partially rendered eyes'
)


def _build_prompt(person_name, gender):
    """Return the headshot text prompt; attire and subject depend on ``gender``."""
    # `gender or ''` so a missing value (None) falls back to the male prompt
    # instead of crashing on `.lower()`.
    if (gender or '').lower() == 'female':
        subject, attire = 'woman', 'a smart casual'
    else:
        subject, attire = 'man', 'a formal suit and white shirt'
    return (
        f'Professional LinkedIn-style realistic headshot for {subject} named {person_name}, '
        'symmetrical full face and upper body visible including shoulders and chest, '
        'centered composition with a small space above the head, '
        f'wearing {attire}, neutral expression, captured from a short distance, '
        'realistic skin texture, exact face preserved, plain white or gray background, '
        'sharp focus, studio lighting, high-resolution, suitable for CV or resume'
    )


def _configure_attention(mode):
    """Set the PuLID attention-processor globals for ``mode``.

    Raises:
        ValueError: if ``mode`` is neither 'fidelity' nor 'extremely style'.
    """
    if mode == 'fidelity':
        attention.NUM_ZERO = 8
        attention.ORTHO = False
        attention.ORTHO_v2 = True
    elif mode == 'extremely style':
        attention.NUM_ZERO = 16
        attention.ORTHO = True
        attention.ORTHO_v2 = False
    else:
        raise ValueError(f'unknown mode: {mode!r}')


def _output_prefix(bucket_folder):
    """Return the S3 prefix of the 'output' folder next to ``bucket_folder``.

    Both 'user_id/request_id/input/' and 'user_id/request_id/input' map to
    'user_id/request_id/output'. S3 keys always use '/', so os.path is avoided.
    """
    parent = bucket_folder.rstrip('/').rpartition('/')[0]
    return f'{parent}/output' if parent else 'output'


def _download_inputs(s3_client, bucket, prefix, local_dir):
    """Download every object under ``prefix`` into ``local_dir`` (flattened).

    Returns the sorted local paths whose extension is in ``_IMAGE_EXTENSIONS``;
    sorting makes the choice of the primary ID image deterministic.
    """
    downloaded = False
    # Paginate: a single list_objects_v2 call returns at most 1000 keys.
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('/'):  # "folder" placeholder object
                continue
            s3_client.download_file(bucket, key, os.path.join(local_dir, os.path.basename(key)))
            downloaded = True
    if not downloaded:
        print("No files found in that folder.")
    return sorted(
        os.path.join(local_dir, name)
        for name in os.listdir(local_dir)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )


def _load_rgb(path):
    """Read an image file as an RGB numpy array, closing the file afterwards.

    Converting to RGB normalizes palette (GIF), grayscale and RGBA inputs to
    three channels.
    """
    with Image.open(path) as im:
        return np.array(im.convert('RGB'))


def _gemini_headshot(prompt, id_image_path):
    """Best-effort Gemini generation; return the PIL images it produced (maybe [])."""
    results = []
    try:
        client = genai.Client(api_key=os.getenv('GEMINI_SECRET_KEY'))
        with Image.open(id_image_path) as source:
            reply = client.models.generate_content(
                model="gemini-2.5-flash-image-preview",
                contents=[prompt, source],
            )
        for part in reply.candidates[0].content.parts:
            if part.text is not None:
                print(part.text)
            elif part.inline_data is not None:
                results.append(Image.open(BytesIO(part.inline_data.data)))
    except Exception as exc:  # Gemini is optional; never fail the request on it
        print('Error while calling GEMINI:', repr(exc))
    return results


def _id_embeddings(image_paths, id_mix):
    """Build PuLID ID embeddings from the first image plus any supplementary ones."""
    embeddings = pipeline.get_id_embedding(resize_numpy_image_long(_load_rgb(image_paths[0]), 1024))
    for path in image_paths[1:]:
        supp = pipeline.get_id_embedding(resize_numpy_image_long(_load_rgb(path), 1024))
        # Without id_mix only the first 5 tokens of each extra image are kept.
        embeddings = torch.cat((embeddings, supp if id_mix else supp[:, :5]), dim=1)
    return embeddings


def _to_pil(img):
    """Convert a PIL image, numpy array or torch tensor to PIL; None if unsupported."""
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    if isinstance(img, torch.Tensor):
        arr = img.detach().cpu()
        if arr.is_floating_point():
            arr = (arr * 255).clamp(0, 255)  # assumes float tensors are in [0, 1]
        arr = arr.to(torch.uint8).numpy()
        if arr.ndim == 3:
            arr = np.transpose(arr, (1, 2, 0))  # assumes CHW layout -> HWC
            if arr.shape[2] == 1:
                arr = arr[:, :, 0]
        return Image.fromarray(arr)
    return None


def _save_pngs(images, out_dir):
    """Write each convertible image to a uniquely named PNG in ``out_dir``."""
    paths = []
    for img in images:
        pil = _to_pil(img)
        if pil is None:
            continue  # skip unknown formats
        with tempfile.NamedTemporaryFile(dir=out_dir, suffix='.png', delete=False) as handle:
            pil.save(handle, format='PNG')
            paths.append(handle.name)
    return paths


def run(bucket_folder, person_name, req_id, gender='male'):
    """Generate headshots for one request and upload them to S3.

    Downloads the input images under ``bucket_folder`` and asks Gemini for one
    headshot (best-effort). It then generates ``n_samples`` PuLID headshots
    from the ID embeddings and uploads every result as PNG to the sibling
    'output' prefix.

    Args:
        bucket_folder: S3 prefix holding the input images, e.g.
            'user_id/request_id/input/'.
        person_name: name inserted into the text prompt.
        req_id: request identifier. Only used for logging; local files now
            live in a private temporary directory, because this value comes
            from user input and was previously passed to ``shutil.rmtree``.
        gender: 'female' selects the female prompt; anything else uses the
            male prompt.

    Returns:
        ``(images, pipeline.debug_img_list)`` where ``images`` holds the
        Gemini results (PIL) followed by the PuLID samples (numpy arrays), or
        ``False`` if the S3 upload fails.

    Raises:
        gr.Error: if no input image is found under ``bucket_folder``.
    """
    BUCKET = 'syntheticai-headshots'
    # NOTE(review): the hard-coded key ID is kept only as a fallback so existing
    # deployments keep working; set AMAZON_ACCESS_KEY_ID, drop the fallback and
    # rotate the key, since it has been committed to source.
    s3_client = boto3.client(
        's3',
        aws_access_key_id=os.getenv('AMAZON_ACCESS_KEY_ID', 'AKIA2NMAMYX4K55CZ7HR'),
        aws_secret_access_key=os.getenv('AMAZON_SECRET_KEY'),
    )

    # Generation settings.
    scale = 1.2
    n_samples = 20
    seed = 0
    steps = 4
    H = 1024
    W = 768
    id_scale = 0.8
    mode = 'fidelity'
    id_mix = False

    print('Processing request', req_id)
    # The temporary directory (inputs and rendered PNGs) is removed on every
    # exit path, including upload failure and exceptions.
    with tempfile.TemporaryDirectory(prefix='headshots_') as work_dir:
        image_paths = _download_inputs(s3_client, BUCKET, bucket_folder, work_dir)
        if not image_paths:
            raise gr.Error(f'No input images found under {bucket_folder!r}')

        prompt = _build_prompt(person_name, gender)
        pipeline.debug_img_list = []
        _configure_attention(mode)

        ims = _gemini_headshot(prompt, image_paths[0])

        id_embeddings = _id_embeddings(image_paths, id_mix)
        seed_everything(seed)
        for _ in range(n_samples):
            img = pipeline.inference(
                prompt, (1, H, W), _HEADSHOT_NEGATIVE_PROMPT, id_embeddings, id_scale, scale, steps
            )[0]
            ims.append(np.array(img))

        file_paths = _save_pngs(ims, work_dir)
        output_prefix = _output_prefix(bucket_folder)  # e.g. 'user_id/request_id/output'
        try:
            for local_file in file_paths:
                s3_key = f'{output_prefix}/{os.path.basename(local_file)}'
                print('S3 key is ', s3_key)
                s3_client.upload_file(local_file, BUCKET, s3_key, ExtraArgs={
                    'ContentType': 'image/png'  # files are written as PNG above
                })
        except ClientError as e:
            print('e', e)
            print('ERROR OCCURRED while uploading data to S3')
            return False
    return ims, pipeline.debug_img_list
# Markdown/HTML intro text for the demo page (not referenced by the UI code
# in this file). Blank lines matter: an HTML block such as <h2> only ends at a
# blank line, so without one the following paragraph would be swallowed.
_HEADER_ = '''
<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>

**PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.

Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Technical report: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.

❗️❗️❗️**Tips:**
- We provide some examples at the bottom; you can try these example prompts first.
- A single ID image is usually sufficient; you can also supplement it with additional auxiliary images.
- We offer two modes: fidelity mode and extremely style mode. In most cases, the default fidelity mode should suffice. If you find that the generated results are not stylized enough, you can choose the extremely style mode.
'''  # noqa: E501
# Markdown footer (star request, sharing, citation, license, contact); not
# referenced by the UI code in this file. Paragraph breaks are required: a
# line directly above '---' would otherwise render as a setext heading.
_CITE_ = r"""
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)

---

🚀 **Share**

If you have generated satisfying or interesting images with PuLID, please share them with us or your friends!

📝 **Citation**

If you find our work useful for your research or applications, please cite using this bibtex:

```bibtex
@article{guo2024pulid,
  title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
  author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
  journal={arXiv preprint arXiv:2404.16022},
  year={2024}
}
```

📋 **License**

Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://github.com/ToTheBeginning/PuLID/blob/main/LICENSE) for details.

📧 **Contact**

If you have any questions, feel free to open a discussion or contact us at <b>wuyanze123@gmail.com</b> or <b>guozinan.1@bytedance.com</b>.
"""  # noqa: E501
# Gradio UI: one labelled textbox per `run` argument, in call order. The
# second gallery receives `pipeline.debug_img_list` and stays hidden.
# Previously a gr.Blocks named `demo` was built and then shadowed by a
# gr.Interface created inside it. That Interface reused the Blocks' galleries
# (which shared one elem_id) and dropped the custom css. A single Interface
# with its own components is what actually gets launched.
demo = gr.Interface(
    run,
    inputs=[
        gr.Textbox(label='bucket_folder (e.g. user_id/request_id/input/)'),
        gr.Textbox(label='person_name'),
        gr.Textbox(label='req_id'),
        gr.Textbox(label='gender (male/female)'),
    ],
    outputs=[
        gr.Gallery(label='Output', elem_id="gallery"),
        gr.Gallery(label='DebugImage', elem_id="debug_gallery", visible=False),
    ],
    title="AI Headshot Diffusor",
    css=".gr-box {border-color: #8136e2}",
)
demo.launch()