srijaydeshpande commited on
Commit
a22d1ce
·
verified ·
1 Parent(s): 055d134

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -133
app.py CHANGED
@@ -22,9 +22,57 @@ DEFAULT_NEGATIVE_PROMPT = (
22
 
23
  @spaces.GPU
24
  def run(*args):
25
- id_image = args[0]
26
- supp_images = args[1:4]
27
- prompt, neg_prompt, scale, n_samples, seed, steps, H, W, id_scale, mode, id_mix = args[4:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  pipeline.debug_img_list = []
30
  if mode == 'fidelity':
@@ -38,16 +86,17 @@ def run(*args):
38
  else:
39
  raise ValueError
40
 
41
- if id_image is not None:
42
- id_image = resize_numpy_image_long(id_image, 1024)
 
 
43
  id_embeddings = pipeline.get_id_embedding(id_image)
44
- for supp_id_image in supp_images:
45
- if supp_id_image is not None:
46
- supp_id_image = resize_numpy_image_long(supp_id_image, 1024)
47
- supp_id_embeddings = pipeline.get_id_embedding(supp_id_image)
48
- id_embeddings = torch.cat(
49
- (id_embeddings, supp_id_embeddings if id_mix else supp_id_embeddings[:, :5]), dim=1
50
- )
51
  else:
52
  id_embeddings = None
53
 
@@ -57,6 +106,36 @@ def run(*args):
57
  img = pipeline.inference(prompt, (1, H, W), neg_prompt, id_embeddings, id_scale, scale, steps)[0]
58
  ims.append(np.array(img))
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  return ims, pipeline.debug_img_list
61
 
62
 
@@ -99,127 +178,11 @@ If you have any questions, feel free to open a discussion or contact us at <b>wu
99
  """ # noqa E501
100
 
101
 
102
- with gr.Blocks(title="PuLID", css=".gr-box {border-color: #8136e2}") as demo:
103
- gr.Markdown(_HEADER_)
104
- with gr.Row():
105
- with gr.Column():
106
- with gr.Row():
107
- face_image = gr.Image(label="ID image (main)", sources="upload", type="numpy", height=256)
108
- supp_image1 = gr.Image(
109
- label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
110
- )
111
- supp_image2 = gr.Image(
112
- label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
113
- )
114
- supp_image3 = gr.Image(
115
- label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
116
- )
117
- prompt = gr.Textbox(label="Prompt", value='portrait,cinematic,wolf ears,white hair')
118
- submit = gr.Button("Generate")
119
- neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
120
- scale = gr.Slider(
121
- label="CFG, recommend value range [1, 1.5], 1 will be faster ",
122
- value=1.2,
123
- minimum=1,
124
- maximum=1.5,
125
- step=0.1,
126
- )
127
- n_samples = gr.Slider(label="Num samples", value=4, minimum=1, maximum=4, step=1)
128
- seed = gr.Slider(
129
- label="Seed", value=42, minimum=np.iinfo(np.uint32).min, maximum=np.iinfo(np.uint32).max, step=1
130
- )
131
- steps = gr.Slider(label="Steps", value=4, minimum=1, maximum=8, step=1)
132
- with gr.Row():
133
- H = gr.Slider(label="Height", value=1024, minimum=512, maximum=1280, step=64)
134
- W = gr.Slider(label="Width", value=768, minimum=512, maximum=1280, step=64)
135
- with gr.Row():
136
- id_scale = gr.Slider(label="ID scale", minimum=0, maximum=5, step=0.05, value=0.8, interactive=True)
137
- mode = gr.Dropdown(label="mode", choices=['fidelity', 'extremely style'], value='fidelity')
138
- id_mix = gr.Checkbox(
139
- label="ID Mix (if you want to mix two ID image, please turn this on, otherwise, turn this off)",
140
- value=False,
141
- )
142
-
143
- gr.Markdown("## Examples")
144
- example_inps = [
145
- [
146
- 'portrait,cinematic,wolf ears,white hair',
147
- 'example_inputs/liuyifei.png',
148
- 'fidelity',
149
- ]
150
- ]
151
- gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='realistic')
152
-
153
- example_inps = [
154
- [
155
- 'portrait, impressionist painting, loose brushwork, vibrant color, light and shadow play',
156
- 'example_inputs/zcy.webp',
157
- 'fidelity',
158
- ]
159
- ]
160
- gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='painting style')
161
-
162
- example_inps = [
163
- [
164
- 'portrait, flat papercut style, silhouette, clean cuts, paper, sharp edges, minimalist,color block,man',
165
- 'example_inputs/lecun.jpg',
166
- 'fidelity',
167
- ]
168
- ]
169
- gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='papercut style')
170
-
171
- example_inps = [
172
- [
173
- 'woman,cartoon,solo,Popmart Blind Box, Super Mario, 3d',
174
- 'example_inputs/rihanna.webp',
175
- 'fidelity',
176
- ]
177
- ]
178
- gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='3d style')
179
-
180
- example_inps = [
181
- [
182
- 'portrait, the legend of zelda, anime',
183
- 'example_inputs/liuyifei.png',
184
- 'extremely style',
185
- ]
186
- ]
187
- gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='anime style')
188
-
189
- example_inps = [
190
- [
191
- 'portrait, superman',
192
- 'example_inputs/lecun.jpg',
193
- 'example_inputs/lifeifei.jpg',
194
- 'fidelity',
195
- True,
196
- ]
197
- ]
198
- gr.Examples(examples=example_inps, inputs=[prompt, face_image, supp_image1, mode, id_mix], label='id mix')
199
-
200
- with gr.Column():
201
- output = gr.Gallery(label='Output', elem_id="gallery")
202
- intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
203
- gr.Markdown(_CITE_)
204
-
205
- inps = [
206
- face_image,
207
- supp_image1,
208
- supp_image2,
209
- supp_image3,
210
- prompt,
211
- neg_prompt,
212
- scale,
213
- n_samples,
214
- seed,
215
- steps,
216
- H,
217
- W,
218
- id_scale,
219
- mode,
220
- id_mix,
221
- ]
222
- submit.click(fn=run, inputs=inps, outputs=[output, intermediate_output])
223
 
 
224
 
225
  demo.launch()
 
22
 
23
  @spaces.GPU
24
  def run(*args):
25
+
26
+ aws_access_key_id = 'AKIA2NMAMYX4K55CZ7HR'
27
+ BUCKET = 'syntheticai-headshots'
28
+ s3_client = boto3.client(
29
+ 's3',
30
+ aws_access_key_id=aws_access_key_id,
31
+ aws_secret_access_key=os.getenv('AMAZON_SECRET_KEY')
32
+ )
33
+
34
+ INPUT_BUCKET_FOLDER = bucket_folder #user_id/request_id/input/'
35
+ local_dir = req_id
36
+ os.makedirs(req_id, exist_ok=True)
37
+
38
+ # try:
39
+ response = s3_client.list_objects_v2(Bucket=BUCKET, Prefix=INPUT_BUCKET_FOLDER)
40
+
41
+ if 'Contents' in response:
42
+ for obj in response['Contents']:
43
+ s3_key = obj['Key']
44
+
45
+ if s3_key.endswith('/'):
46
+ continue
47
+
48
+ file_name = os.path.basename(s3_key)
49
+ local_path = os.path.join(local_dir, file_name)
50
+
51
+ s3_client.download_file(BUCKET, s3_key, local_path)
52
+
53
+ else:
54
+ print("No files found in that folder.")
55
+
56
+ # Get a list of image file extensions you want to include
57
+ image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
58
+ # Read image paths into a list
59
+ image_paths = [
60
+ os.path.join(local_dir, file)
61
+ for file in os.listdir(local_dir)
62
+ if file.lower().endswith(image_extensions)
63
+ ]
64
+
65
+ prompt = 'Professional LinkedIn-style headshot, symmetrical full face and upper body visible including shoulders and chest, centered composition with a small space above the head, wearing a formal suit and white shirt, neutral expression, captured from a short distance, realistic skin texture, exact face preserved, plain white or gray background, sharp focus, studio lighting, high-resolution, suitable for CV or resume'
66
+ neg_prompt = 'Wrong face, flaws in the eyes, flaws in the face, lowres, artifacts noise, text, deformed, partially rendered objects, deformed or partially rendered eyes, deformed eyeballs, cross-eyed, blurry'
67
+ scale = 1.2
68
+ n_samples = 5
69
+ seed = 0
70
+ steps = 1
71
+ H = 1024
72
+ W = 768
73
+ id_scale = 0.8
74
+ mode = 'fidelity'
75
+ id_mix = False
76
 
77
  pipeline.debug_img_list = []
78
  if mode == 'fidelity':
 
86
  else:
87
  raise ValueError
88
 
89
+ id_image = image_paths[0]
90
+
91
+ if image_paths is not None:
92
+ id_image = resize_numpy_image_long(image_paths[0], 1024)
93
  id_embeddings = pipeline.get_id_embedding(id_image)
94
+ for i in range(1,len(image_paths)):
95
+ supp_id_image = resize_numpy_image_long(image_paths[i], 1024)
96
+ supp_id_embeddings = pipeline.get_id_embedding(supp_id_image)
97
+ id_embeddings = torch.cat(
98
+ (id_embeddings, supp_id_embeddings if id_mix else supp_id_embeddings[:, :5]), dim=1
99
+ )
 
100
  else:
101
  id_embeddings = None
102
 
 
106
  img = pipeline.inference(prompt, (1, H, W), neg_prompt, id_embeddings, id_scale, scale, steps)[0]
107
  ims.append(np.array(img))
108
 
109
+ file_paths = []
110
+ for i, img in enumerate(ims):
111
+ if isinstance(img, torch.Tensor):
112
+ img = img.detach().cpu()
113
+ img = transforms.ToPILImage()(img)
114
+ elif not isinstance(img, Image.Image):
115
+ continue # skip unknown formats
116
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
117
+ img.save(temp_file.name)
118
+ file_paths.append(temp_file.name)
119
+
120
+ # Upload images to S3
121
+ OUTPUT_BUCKET_FOLDER = os.path.join(os.path.dirname(os.path.dirname(bucket_folder)), 'output') #'user_id/request_id/output/'
122
+ os.makedirs(OUTPUT_BUCKET_FOLDER, exist_ok=True)
123
+
124
+ try:
125
+ for local_file in file_paths:
126
+ file_name = os.path.basename(local_file)
127
+ s3_key = os.path.join(OUTPUT_BUCKET_FOLDER, file_name)
128
+ s3_client.upload_file(local_file, BUCKET, s3_key, ExtraArgs={
129
+ 'ContentType': 'image/jpeg'
130
+ })
131
+ except ClientError as e:
132
+ print('e', e)
133
+ print('ERROR OCCURRED while uploading data to S3')
134
+ return False
135
+
136
+ shutil.rmtree(req_id)
137
+ shutil.rmtree(OUTPUT_BUCKET_FOLDER)
138
+
139
  return ims, pipeline.debug_img_list
140
 
141
 
 
178
  """ # noqa E501
179
 
180
 
181
+ with gr.Blocks(title="AI headshot Generator", css=".gr-box {border-color: #8136e2}") as demo:
182
+
183
+ output = gr.Gallery(label='Output', elem_id="gallery")
184
+ intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ submit.click(fn=run, inputs=['textbox', 'textbox', 'textbox'], outputs=[output, intermediate_output])
187
 
188
  demo.launch()