fantaxy committed on
Commit
1bf58f3
·
verified ·
1 Parent(s): 862fe5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -58
app.py CHANGED
@@ -6,7 +6,6 @@ import gradio as gr
6
  import torch
7
  from einops import rearrange
8
  from PIL import Image
9
- from transformers import pipeline
10
 
11
  from flux.cli import SamplingOptions
12
  from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
@@ -16,6 +15,7 @@ from pulid.utils import resize_numpy_image_long
16
 
17
  NSFW_THRESHOLD = 0.85
18
 
 
19
  def get_models(name: str, device: torch.device, offload: bool):
20
  t5 = load_t5(device, max_length=128)
21
  clip = load_clip(device)
@@ -27,15 +27,17 @@ def get_models(name: str, device: torch.device, offload: bool):
27
 
28
  class FluxGenerator:
29
  def __init__(self):
30
- self.device = torch.device('cuda')
31
  self.offload = False
32
- self.model_name = 'flux-dev'
33
  self.model, self.ae, self.t5, self.clip = get_models(
34
  self.model_name,
35
  device=self.device,
36
  offload=self.offload,
37
  )
38
- self.pulid_model = PuLIDPipeline(self.model, 'cuda', weight_dtype=torch.bfloat16)
 
 
39
  self.pulid_model.load_pretrain()
40
 
41
 
@@ -45,19 +47,19 @@ flux_generator = FluxGenerator()
45
  @spaces.GPU
46
  @torch.inference_mode()
47
  def generate_image(
48
- width,
49
- height,
50
- num_steps,
51
- start_step,
52
- guidance,
53
- seed,
54
- prompt,
55
- id_image=None,
56
- id_weight=1.0,
57
- neg_prompt="",
58
- true_cfg=1.0,
59
- timestep_to_start_cfg=1,
60
- max_sequence_length=128,
61
  ):
62
  flux_generator.t5.max_length = max_sequence_length
63
 
@@ -83,7 +85,9 @@ def generate_image(
83
 
84
  if id_image is not None:
85
  id_image = resize_numpy_image_long(id_image, 1024)
86
- id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
 
 
87
  else:
88
  id_embeddings = None
89
  uncond_id_embeddings = None
@@ -96,7 +100,7 @@ def generate_image(
96
  opts.height,
97
  opts.width,
98
  device=flux_generator.device,
99
- dtype=torch.bfloat16,
100
  seed=opts.seed,
101
  )
102
  print(x)
@@ -107,7 +111,10 @@ def generate_image(
107
  )
108
 
109
  if flux_generator.offload:
110
- flux_generator.t5, flux_generator.clip = flux_generator.t5.to(flux_generator.device), flux_generator.clip.to(flux_generator.device)
 
 
 
111
  inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
112
  inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
113
 
@@ -119,8 +126,15 @@ def generate_image(
119
 
120
  # denoise initial noise
121
  x = denoise(
122
- flux_generator.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
123
- start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
 
 
 
 
 
 
 
124
  timestep_to_start_cfg=timestep_to_start_cfg,
125
  neg_txt=inp_neg["txt"] if use_true_cfg else None,
126
  neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
@@ -135,7 +149,10 @@ def generate_image(
135
 
136
  # decode latents to pixel space
137
  x = unpack(x.float(), opts.height, opts.width)
138
- with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
 
 
 
139
  x = flux_generator.ae.decode(x)
140
 
141
  if flux_generator.offload:
@@ -147,15 +164,13 @@ def generate_image(
147
  print(f"Done in {t1 - t0:.1f}s.")
148
  # bring into PIL format
149
  x = x.clamp(-1, 1)
150
- # x = embed_watermark(x.float())
151
  x = rearrange(x[0], "c h w -> h w c")
152
 
153
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
154
  return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
155
 
156
- def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
157
- offload: bool = False):
158
-
159
  with gr.Blocks(theme="soft") as demo:
160
  gr.HTML(
161
  """
@@ -163,14 +178,14 @@ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_ava
163
  <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
164
  <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
165
  </a>
166
-
167
  <a href="https://discord.gg/openfreeai" target="_blank">
168
  <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
169
  </a>
170
  </div>
171
  """
172
  )
173
-
174
  with gr.Row():
175
  with gr.Column():
176
  prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
@@ -183,75 +198,102 @@ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_ava
183
  start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
184
  guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
185
  seed = gr.Textbox(-1, label="Seed (-1 for random)")
186
- max_sequence_length = gr.Slider(128, 512, 128, step=128,
187
- label="max_sequence_length for prompt (T5), small will be faster")
188
 
189
- with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG, if using true CFG, we recommend set the guidance scale to 1)", open=False): # noqa E501
 
 
 
190
  neg_prompt = gr.Textbox(
191
  label="Negative Prompt",
192
- value="bad quality, worst quality, text, signature, watermark, extra limbs")
 
193
  true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
194
  timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
195
 
196
  generate_btn = gr.Button("Generate")
197
-
198
  with gr.Column():
199
  output_image = gr.Image(label="Generated Image")
200
  seed_output = gr.Textbox(label="Used Seed")
201
- intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)
202
 
203
  with gr.Row(), gr.Column():
204
  gr.Markdown("## Examples")
205
  example_inps = [
206
  [
207
- 'a woman holding sign with glowing green text \"PuLID for FLUX\"',
208
- 'example_inputs/qw1.webp',
209
- 4, 4, 2680261499100305976, 1
 
 
 
210
  ],
211
  [
212
- 'portrait, pixar',
213
- 'example_inputs/qw2.webp',
214
- 1, 4, 9445036702517583939, 1
 
 
 
215
  ],
216
  ]
217
- gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
218
- label='fake CFG')
219
 
220
  example_inps = [
221
  [
222
- 'portrait, made of ice sculpture',
223
- 'example_inputs/qw3.webp',
224
- 1, 1, 3811899118709451814, 5
 
 
 
225
  ],
226
  ]
227
- gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
228
- label='true CFG')
229
 
230
  generate_btn.click(
231
  fn=generate_image,
232
- inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt,
233
- true_cfg, timestep_to_start_cfg, max_sequence_length],
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  outputs=[output_image, seed_output, intermediate_output],
235
  )
236
 
237
  return demo
238
 
 
239
  if __name__ == "__main__":
240
  import argparse
241
 
242
  parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
243
- parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'),
244
- help="currently only support flux-dev")
245
- parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
246
- help="Device to use")
247
  parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
248
  parser.add_argument("--port", type=int, default=8080, help="Port to use")
249
- parser.add_argument("--dev", action='store_true', help="Development mode")
250
- parser.add_argument("--pretrained_model", type=str, help='for development')
251
  args = parser.parse_args()
252
 
253
  import huggingface_hub
254
- huggingface_hub.login(os.getenv('HF_TOKEN'))
 
 
 
255
 
256
  demo = create_demo(args, args.name, args.device, args.offload)
257
- demo.launch()
 
6
  import torch
7
  from einops import rearrange
8
  from PIL import Image
 
9
 
10
  from flux.cli import SamplingOptions
11
  from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
 
15
 
16
  NSFW_THRESHOLD = 0.85
17
 
18
+
19
  def get_models(name: str, device: torch.device, offload: bool):
20
  t5 = load_t5(device, max_length=128)
21
  clip = load_clip(device)
 
27
 
28
  class FluxGenerator:
29
  def __init__(self):
30
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  self.offload = False
32
+ self.model_name = "flux-dev"
33
  self.model, self.ae, self.t5, self.clip = get_models(
34
  self.model_name,
35
  device=self.device,
36
  offload=self.offload,
37
  )
38
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
39
+ weight_dtype = torch.bfloat16 if device_str == "cuda" else torch.float32
40
+ self.pulid_model = PuLIDPipeline(self.model, device_str, weight_dtype=weight_dtype)
41
  self.pulid_model.load_pretrain()
42
 
43
 
 
47
  @spaces.GPU
48
  @torch.inference_mode()
49
  def generate_image(
50
+ width,
51
+ height,
52
+ num_steps,
53
+ start_step,
54
+ guidance,
55
+ seed,
56
+ prompt,
57
+ id_image=None,
58
+ id_weight=1.0,
59
+ neg_prompt="",
60
+ true_cfg=1.0,
61
+ timestep_to_start_cfg=1,
62
+ max_sequence_length=128,
63
  ):
64
  flux_generator.t5.max_length = max_sequence_length
65
 
 
85
 
86
  if id_image is not None:
87
  id_image = resize_numpy_image_long(id_image, 1024)
88
+ id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(
89
+ id_image, cal_uncond=use_true_cfg
90
+ )
91
  else:
92
  id_embeddings = None
93
  uncond_id_embeddings = None
 
100
  opts.height,
101
  opts.width,
102
  device=flux_generator.device,
103
+ dtype=torch.bfloat16 if flux_generator.device.type == "cuda" else torch.float32,
104
  seed=opts.seed,
105
  )
106
  print(x)
 
111
  )
112
 
113
  if flux_generator.offload:
114
+ flux_generator.t5, flux_generator.clip = (
115
+ flux_generator.t5.to(flux_generator.device),
116
+ flux_generator.clip.to(flux_generator.device),
117
+ )
118
  inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
119
  inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
120
 
 
126
 
127
  # denoise initial noise
128
  x = denoise(
129
+ flux_generator.model,
130
+ **inp,
131
+ timesteps=timesteps,
132
+ guidance=opts.guidance,
133
+ id=id_embeddings,
134
+ id_weight=id_weight,
135
+ start_step=start_step,
136
+ uncond_id=uncond_id_embeddings,
137
+ true_cfg=true_cfg,
138
  timestep_to_start_cfg=timestep_to_start_cfg,
139
  neg_txt=inp_neg["txt"] if use_true_cfg else None,
140
  neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
 
149
 
150
  # decode latents to pixel space
151
  x = unpack(x.float(), opts.height, opts.width)
152
+ with torch.autocast(
153
+ device_type=flux_generator.device.type,
154
+ dtype=torch.bfloat16 if flux_generator.device.type == "cuda" else torch.float32,
155
+ ):
156
  x = flux_generator.ae.decode(x)
157
 
158
  if flux_generator.offload:
 
164
  print(f"Done in {t1 - t0:.1f}s.")
165
  # bring into PIL format
166
  x = x.clamp(-1, 1)
 
167
  x = rearrange(x[0], "c h w -> h w c")
168
 
169
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
170
  return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
171
 
172
+
173
+ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
 
174
  with gr.Blocks(theme="soft") as demo:
175
  gr.HTML(
176
  """
 
178
  <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
179
  <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
180
  </a>
181
+
182
  <a href="https://discord.gg/openfreeai" target="_blank">
183
  <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
184
  </a>
185
  </div>
186
  """
187
  )
188
+
189
  with gr.Row():
190
  with gr.Column():
191
  prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
 
198
  start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
199
  guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
200
  seed = gr.Textbox(-1, label="Seed (-1 for random)")
201
+ max_sequence_length = gr.Slider(128, 512, 128, step=128, label="max_sequence_length for prompt (T5), small will be faster")
 
202
 
203
+ with gr.Accordion(
204
+ "Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG, if using true CFG, we recommend set the guidance scale to 1)",
205
+ open=False,
206
+ ):
207
  neg_prompt = gr.Textbox(
208
  label="Negative Prompt",
209
+ value="bad quality, worst quality, text, signature, watermark, extra limbs",
210
+ )
211
  true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
212
  timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
213
 
214
  generate_btn = gr.Button("Generate")
215
+
216
  with gr.Column():
217
  output_image = gr.Image(label="Generated Image")
218
  seed_output = gr.Textbox(label="Used Seed")
219
+ intermediate_output = gr.Gallery(label="Output", elem_id="gallery", visible=args.dev)
220
 
221
  with gr.Row(), gr.Column():
222
  gr.Markdown("## Examples")
223
  example_inps = [
224
  [
225
+ 'a woman holding sign with glowing green text "PuLID for FLUX"',
226
+ "example_inputs/qw1.webp",
227
+ 4,
228
+ 4,
229
+ 2680261499100305976,
230
+ 1,
231
  ],
232
  [
233
+ "portrait, pixar",
234
+ "example_inputs/qw2.webp",
235
+ 1,
236
+ 4,
237
+ 9445036702517583939,
238
+ 1,
239
  ],
240
  ]
241
+ gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg], label="fake CFG")
 
242
 
243
  example_inps = [
244
  [
245
+ "portrait, made of ice sculpture",
246
+ "example_inputs/qw3.webp",
247
+ 1,
248
+ 1,
249
+ 3811899118709451814,
250
+ 5,
251
  ],
252
  ]
253
+ gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg], label="true CFG")
 
254
 
255
  generate_btn.click(
256
  fn=generate_image,
257
+ inputs=[
258
+ width,
259
+ height,
260
+ num_steps,
261
+ start_step,
262
+ guidance,
263
+ seed,
264
+ prompt,
265
+ id_image,
266
+ id_weight,
267
+ neg_prompt,
268
+ true_cfg,
269
+ timestep_to_start_cfg,
270
+ max_sequence_length,
271
+ ],
272
  outputs=[output_image, seed_output, intermediate_output],
273
  )
274
 
275
  return demo
276
 
277
+
278
  if __name__ == "__main__":
279
  import argparse
280
 
281
  parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
282
+ parser.add_argument("--name", type=str, default="flux-dev", choices=["flux-dev"], help="currently only support flux-dev")
283
+ parser.add_argument(
284
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use"
285
+ )
286
  parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
287
  parser.add_argument("--port", type=int, default=8080, help="Port to use")
288
+ parser.add_argument("--dev", action="store_true", help="Development mode")
289
+ parser.add_argument("--pretrained_model", type=str, help="for development")
290
  args = parser.parse_args()
291
 
292
  import huggingface_hub
293
+
294
+ hf_token = os.getenv("HF_TOKEN")
295
+ if hf_token:
296
+ huggingface_hub.login(hf_token)
297
 
298
  demo = create_demo(args, args.name, args.device, args.offload)
299
+ demo.launch()