File size: 9,326 Bytes
4eca20b
9eb3654
 
 
a8982fd
1768e52
9eb3654
 
 
 
 
5ffb4e8
 
f4d8e9b
5ffb4e8
0814d79
 
 
 
84183f6
1d0cad2
9eb3654
 
 
 
 
 
 
 
 
 
 
4eca20b
8b9cf54
a22d1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce2da6f
 
 
 
 
401d1e2
a22d1ce
0b078c0
a22d1ce
5c61b5f
a22d1ce
 
 
 
 
9eb3654
 
 
 
 
 
 
 
 
 
 
 
 
a22d1ce
 
0814d79
 
 
 
5e59917
0814d79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a22d1ce
f4d8e9b
9eb3654
a22d1ce
f4d8e9b
a22d1ce
 
 
 
9eb3654
 
 
 
 
 
 
eeaeb90
a22d1ce
 
 
 
 
e67c39b
 
a22d1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
c549417
a22d1ce
c549417
 
a22d1ce
 
 
 
 
 
 
 
9eb3654
 
 
 
757482b
9eb3654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a22d1ce
 
 
 
5ffb4e8
 
 
 
 
8b9cf54
5ffb4e8
 
 
9eb3654
c437b94
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# `spaces` stays first: ZeroGPU requires it to be imported before CUDA users like torch.
import spaces
import gradio as gr
import numpy as np
import torch
import os, shutil
import posixpath
import tempfile

from pulid import attention_processor as attention
from pulid.pipeline import PuLIDPipeline
from pulid.utils import resize_numpy_image_long, seed_everything

import boto3
from botocore.exceptions import ClientError
from PIL import Image

from google import genai
from google.genai import types
from io import BytesIO

# The app only ever runs inference, so autograd is switched off process-wide.
torch.set_grad_enabled(False)

# One PuLID pipeline, constructed at import time and shared by every call to run().
pipeline = PuLIDPipeline()

# Generic negative prompt.
# NOTE(review): not referenced in this file -- run() builds its own neg_prompt.
DEFAULT_NEGATIVE_PROMPT = ''.join([
    'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,',
    'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, ',
    'low resolution, partially rendered objects,  deformed or partially rendered eyes, ',
    'deformed, deformed eyeballs, cross-eyed,blurry',
])


# S3 bucket that holds both the uploaded input photos and the generated outputs.
_S3_BUCKET = 'syntheticai-headshots'
# Lower-case file extensions treated as input photos.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
# Gemini model used for the single extra headshot generated before the PuLID samples.
_GEMINI_MODEL = "gemini-2.5-flash-image-preview"


def _download_inputs(s3_client, bucket, prefix, local_dir):
    """Download every object under ``prefix`` in ``bucket`` into ``local_dir``.

    Folder placeholder keys (ending in '/') are skipped. Objects are stored under
    their basename, so identically named keys in nested "subfolders" overwrite
    each other. NOTE(review): a single list_objects_v2 call returns at most one
    page of keys; fine for a handful of photos, use a paginator if that grows.
    """
    response = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
    contents = response.get('Contents')
    if not contents:
        print("No files found in that folder.")
        return
    for obj in contents:
        s3_key = obj['Key']
        if s3_key.endswith('/'):
            continue
        local_path = os.path.join(local_dir, os.path.basename(s3_key))
        s3_client.download_file(bucket, s3_key, local_path)


def _list_input_images(local_dir):
    """Return paths of image files in ``local_dir``, sorted for a deterministic ID image."""
    return sorted(
        os.path.join(local_dir, name)
        for name in os.listdir(local_dir)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )


def _build_prompt(person_name, gender):
    """Return the positive prompt; any gender other than 'female' gets the male prompt."""
    if (gender or '').lower() == 'female':
        return 'Professional LinkedIn-style realistic headshot for woman named ' + person_name + ', symmetrical full face and upper body visible including shoulders and chest, centered composition with a small space above the head, wearing a smart casual, neutral expression, captured from a short distance, realistic skin texture, exact face preserved, plain white or gray background, sharp focus, studio lighting, high-resolution, suitable for CV or resume'
    return 'Professional LinkedIn-style realistic headshot for man named ' + person_name + ', symmetrical full face and upper body visible including shoulders and chest, centered composition with a small space above the head, wearing a formal suit and white shirt, neutral expression, captured from a short distance, realistic skin texture, exact face preserved, plain white or gray background, sharp focus, studio lighting, high-resolution, suitable for CV or resume'


def _configure_attention(mode):
    """Set PuLID's module-level attention switches for ``mode``.

    Raises:
        ValueError: for a mode other than 'fidelity' or 'extremely style'.
    """
    if mode == 'fidelity':
        attention.NUM_ZERO = 8
        attention.ORTHO = False
        attention.ORTHO_v2 = True
    elif mode == 'extremely style':
        attention.NUM_ZERO = 16
        attention.ORTHO = True
        attention.ORTHO_v2 = False
    else:
        raise ValueError(f'Unknown mode: {mode!r}')


def _generate_with_gemini(prompt, id_image_path):
    """Ask Gemini for one headshot of the person in ``id_image_path``.

    Best-effort: any failure (missing key, API error, no image in the reply) is
    printed and ``None`` is returned so the PuLID samples are still produced.
    The input photo itself is never returned as a "generated" image.
    """
    try:
        client = genai.Client(api_key=os.getenv('GEMINI_SECRET_KEY'))
        with Image.open(id_image_path) as source:
            source.load()
            response = client.models.generate_content(
                model=_GEMINI_MODEL,
                contents=[prompt, source],
            )
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []
        generated = None
        for part in parts:
            if part.text is not None:
                print(part.text)
            elif part.inline_data is not None:
                # If several image parts come back, the last one wins.
                generated = Image.open(BytesIO(part.inline_data.data))
                generated.load()
    except Exception as exc:  # deliberate best-effort boundary around the remote call
        print('Error while calling GEMINI:', exc)
        return None
    if generated is None:
        print('GEMINI returned no image')
    return generated


def _compute_id_embeddings(image_paths, id_mix):
    """Concatenate PuLID ID embeddings of all input photos along dim 1.

    The first photo contributes its full embedding. Each supplementary photo
    contributes only its first 5 entries along dim 1 unless ``id_mix`` is true.
    """
    id_embeddings = None
    for path in image_paths:
        with Image.open(path) as photo:
            # Force 3-channel RGB: PNG/GIF/BMP inputs may be RGBA, grayscale or palette mode.
            rgb = np.array(photo.convert('RGB'))
        embedding = pipeline.get_id_embedding(resize_numpy_image_long(rgb, 1024))
        if id_embeddings is None:
            id_embeddings = embedding
        else:
            id_embeddings = torch.cat(
                (id_embeddings, embedding if id_mix else embedding[:, :5]), dim=1
            )
    return id_embeddings


def _tensor_to_pil(tensor):
    """Convert an image tensor to PIL without torchvision.

    Approximates torchvision's ToPILImage for the common cases: float tensors are
    assumed to be in [0, 1] and scaled to uint8; a 3-D tensor is assumed to be
    channel-first (C, H, W) -- TODO confirm if tensors ever reach this path.
    """
    data = tensor.detach().cpu()
    if data.is_floating_point():
        data = data.mul(255).byte()
    arr = data.numpy()
    if arr.ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))
        if arr.shape[2] == 1:
            arr = arr[:, :, 0]
    return Image.fromarray(arr)


def _save_pngs(images, out_dir):
    """Write each image to a uniquely named PNG in ``out_dir`` and return the paths.

    Accepts PIL images, numpy arrays and torch tensors; anything else is skipped.
    """
    paths = []
    for img in images:
        if isinstance(img, torch.Tensor):
            img = _tensor_to_pil(img)
        elif isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif not isinstance(img, Image.Image):
            continue  # skip unknown formats
        fd, path = tempfile.mkstemp(suffix='.png', dir=out_dir)
        with os.fdopen(fd, 'wb') as handle:
            img.save(handle, format='PNG')
        paths.append(path)
    return paths


def _output_prefix(bucket_folder):
    """Map ``'user_id/request_id/input/'`` to the sibling ``'user_id/request_id/output'``.

    A missing trailing slash on ``bucket_folder`` is tolerated.
    """
    return posixpath.join(posixpath.dirname(bucket_folder.rstrip('/')), 'output')


def _upload_outputs(s3_client, bucket, output_prefix, file_paths):
    """Upload ``file_paths`` under ``output_prefix``; return False if an upload fails."""
    try:
        for local_file in file_paths:
            # S3 keys always use '/', independent of the local OS path separator.
            s3_key = posixpath.join(output_prefix, os.path.basename(local_file))
            print('S3 key is ', s3_key)
            s3_client.upload_file(local_file, bucket, s3_key, ExtraArgs={
                'ContentType': 'image/png',  # files are written by _save_pngs as PNG
            })
    # upload_file wraps service errors in S3UploadFailedError rather than ClientError.
    except (ClientError, boto3.exceptions.S3UploadFailedError) as e:
        print('e', e)
        print('ERROR OCCURRED while uploading data to S3')
        return False
    return True


@spaces.GPU
def run(bucket_folder, person_name, req_id, gender='male'):
    """Generate headshots for one request and upload them to S3.

    Args:
        bucket_folder: S3 prefix holding the input photos, expected to look like
            ``'user_id/request_id/input/'``. Results go to the sibling ``output`` prefix.
        person_name: name interpolated into the prompt.
        req_id: request identifier; only used to label the local scratch directory.
        gender: 'female' selects the female prompt; anything else the male one.

    Returns:
        ``(images, debug_images)`` for the two Gradio galleries, or ``False`` if
        uploading the results to S3 failed.

    Raises:
        ValueError: if no image files were found under ``bucket_folder``.
    """
    s3_client = boto3.client(
        's3',
        # NOTE(review): the fallback key id is committed to source; set
        # AMAZON_ACCESS_KEY_ID in the environment and rotate this key.
        aws_access_key_id=os.getenv('AMAZON_ACCESS_KEY_ID', 'AKIA2NMAMYX4K55CZ7HR'),
        aws_secret_access_key=os.getenv('AMAZON_SECRET_KEY'),
    )

    # Scratch space goes under the system temp dir instead of a path built from
    # the untrusted req_id, and is always removed in the finally below.
    safe_req_id = ''.join(c for c in str(req_id) if c.isalnum() or c in '-_')[:64]
    work_dir = tempfile.mkdtemp(prefix=f'headshots-{safe_req_id}-')
    input_dir = os.path.join(work_dir, 'input')
    output_dir = os.path.join(work_dir, 'output')
    os.makedirs(input_dir)
    os.makedirs(output_dir)

    try:
        _download_inputs(s3_client, _S3_BUCKET, bucket_folder, input_dir)
        image_paths = _list_input_images(input_dir)
        if not image_paths:
            raise ValueError(f'No input images found under s3://{_S3_BUCKET}/{bucket_folder}')

        prompt = _build_prompt(person_name, gender)
        neg_prompt = 'flaws in the eyes, wide chin, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,artifacts noise, glitch, deformed, disfigured, partially rendered objects,  deformed or partially rendered eyes'
        scale = 1.2
        n_samples = 20
        seed = 0
        steps = 4
        H = 1024
        W = 768
        id_scale = 0.8
        mode = 'fidelity'
        id_mix = False

        pipeline.debug_img_list = []
        _configure_attention(mode)

        ims = []
        gemini_image = _generate_with_gemini(prompt, image_paths[0])
        if gemini_image is not None:
            ims.append(gemini_image)

        id_embeddings = _compute_id_embeddings(image_paths, id_mix)

        seed_everything(seed)
        for _ in range(n_samples):
            img = pipeline.inference(prompt, (1, H, W), neg_prompt, id_embeddings, id_scale, scale, steps)[0]
            ims.append(np.array(img))

        file_paths = _save_pngs(ims, output_dir)
        if not _upload_outputs(s3_client, _S3_BUCKET, _output_prefix(bucket_folder), file_paths):
            return False
        return ims, pipeline.debug_img_list
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


_HEADER_ = '''
<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>

**PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.

Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.

❗️❗️❗️**Tips:**
- we provide some examples in the bottom, you can try these example prompts first
- a single ID image is usually sufficient, you can also supplement with additional auxiliary images
- We offer two modes: fidelity mode and extremely style mode. In most cases, the default fidelity mode should suffice. If you find that the generated results are not stylized enough, you can choose the extremely style mode.

'''  # noqa E501

# Citation / license / contact footer (Markdown) for the PuLID demo UI.
# NOTE(review): not referenced anywhere in this file. The "[LICENSE file](placeholder)"
# link below is still a literal placeholder -- TODO point it at the real LICENSE URL.
_CITE_ = r"""
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)
---
🚀 **Share**
If you have generated satisfying or interesting images with PuLID, please share them with us or your friends!

📝 **Citation**
If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@article{guo2024pulid,
  title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
  author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
  journal={arXiv preprint arXiv:2404.16022},
  year={2024}
}
```

📋 **License**
Apache-2.0 LICENSE. Please refer to the [LICENSE file](placeholder) for details.

📧 **Contact**
If you have any questions, feel free to open a discussion or contact us at <b>wuyanze123@gmail.com</b> or <b>guozinan.1@bytedance.com</b>.
"""  # noqa E501


# Result galleries for the two values returned by run(). They are created
# standalone and rendered by gr.Interface below; the previous gr.Blocks
# wrapper was discarded as soon as `demo` was reassigned.
output = gr.Gallery(label='Output', elem_id="gallery")
# Hidden gallery for pipeline.debug_img_list; it needs its own DOM id.
intermediate_output = gr.Gallery(label='DebugImage', elem_id="debug_gallery", visible=False)

# The four textboxes map positionally onto run(bucket_folder, person_name, req_id, gender).
demo = gr.Interface(
    run,
    inputs=['textbox', 'textbox', 'textbox', 'textbox'],
    outputs=[output, intermediate_output],
    title="AI Headshot Diffusor",
)

if __name__ == "__main__":
    demo.launch()