fffiloni commited on
Commit
a721627
·
verified ·
1 Parent(s): 4e2d24c

fix(app): align input tensor dtypes with model dtypes during inference

Browse files
Files changed (1) hide show
  1. app.py +152 -103
app.py CHANGED
@@ -28,11 +28,21 @@ from gdf.schedulers import CosineSchedule
28
  from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
29
  from gdf.targets import EpsilonTarget
30
  import PIL
 
31
 
32
  # Device configuration
33
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
  print(device)
35
 
 
 
 
 
 
 
 
 
 
36
  # Flag for low VRAM usage
37
  # low_vram = False
38
 
@@ -53,10 +63,11 @@ def models_to(model, device="cpu", excepts=None):
53
  continue
54
  print(f"Change device of '{attr_name}' to {device}")
55
  attr_value.to(device)
56
-
57
  torch.cuda.empty_cache()
58
  gc.collect()
59
 
 
60
  # Stage C model configuration
61
  config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
62
  with open(config_file, "r", encoding="utf-8") as file:
@@ -68,7 +79,7 @@ core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
68
  config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
69
  with open(config_file_b, "r", encoding="utf-8") as file:
70
  config_file_b = yaml.safe_load(file)
71
-
72
  core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
73
 
74
  # Setup extras and models for Stage C
@@ -129,20 +140,20 @@ models_rbm = core.Models(
129
  models_rbm.generator.eval().requires_grad_(False)
130
 
131
 
132
-
133
  def infer(ref_style_file, style_description, caption, use_low_vram, progress):
134
  global models_rbm, models_b, device
135
-
136
  models_to(models_rbm, device=device)
137
-
138
  try:
139
-
140
  caption = f"{caption} in {style_description}"
141
- height=1024
142
- width=1024
143
- batch_size=1
144
-
145
- stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
 
 
146
 
147
  extras.sampling_configs['cfg'] = 4
148
  extras.sampling_configs['shift'] = 2
@@ -155,26 +166,46 @@ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
155
  extras_b.sampling_configs['t_start'] = 1.0
156
 
157
  progress(0.1, "Loading style reference image")
158
- ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
 
 
 
 
 
159
 
160
  batch = {'captions': [caption] * batch_size}
161
- batch['style'] = ref_style
162
 
163
  progress(0.2, "Processing style reference image")
164
- x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
 
 
165
 
166
  progress(0.3, "Generating conditions")
167
- conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
168
- unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
169
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
170
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  if use_low_vram:
173
  # The sampling process uses more vram, so we offload everything except two modules to the cpu.
174
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
175
 
176
  progress(0.4, "Starting Stage C reverse process")
177
- # Stage C reverse process.
178
  sampling_c = extras.gdf.sample(
179
  models_rbm.generator, conditions, stage_c_latent_shape,
180
  unconditions, device=device,
@@ -186,74 +217,73 @@ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
186
  lam_style=1, lam_txt_alignment=1.0,
187
  use_ddim_sampler=True,
188
  )
189
- for (sampled_c, _, _) in progress.tqdm(tqdm(sampling_c, total=extras.sampling_configs['timesteps']), desc="Stage C reverse process"):
190
- #for i, (sampled_c, _, _) in enumerate(sampling_c, 1):
191
- # if i % 5 == 0: # Update progress every 5 steps
192
- # progress(0.4 + 0.3 * (i / extras.sampling_configs['timesteps']), f"Stage C reverse process: step {i}/{extras.sampling_configs['timesteps']}")
 
193
  sampled_c = sampled_c
194
 
195
  progress(0.7, "Starting Stage B reverse process")
196
- # Stage B reverse process.
197
- with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
198
  conditions_b['effnet'] = sampled_c
199
  unconditions_b['effnet'] = torch.zeros_like(sampled_c)
200
-
201
  sampling_b = extras_b.gdf.sample(
202
  models_b.generator, conditions_b, stage_b_latent_shape,
203
  unconditions_b, device=device, **extras_b.sampling_configs,
204
  )
205
- for sampled_b, _, _ in progress.tqdm(tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']), desc="Stage B reverse process"):
206
- #for i, (sampled_b, _, _) in enumerate(sampling_b, 1):
207
- # if i % 1 == 0: # Update progress every 1 step
208
- # progress(0.7 + 0.2 * (i / extras_b.sampling_configs['timesteps']), f"Stage B reverse process: step {i}/{extras_b.sampling_configs['timesteps']}")
209
  sampled_b = sampled_b
210
  sampled = models_b.stage_a.decode(sampled_b).float()
211
 
212
  torch.cuda.empty_cache()
213
  gc.collect()
214
-
215
  progress(0.9, "Finalizing the output image")
216
  sampled = torch.cat([
217
- torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
218
- sampled.cpu(),
219
  ], dim=0)
220
 
221
- # Remove the batch dimension and keep only the generated image
222
- sampled = sampled[1] # This selects the generated image, discarding the reference style image
223
-
224
- # Ensure the tensor values are in the correct range
225
  sampled = torch.clamp(sampled, 0, 1)
226
 
227
- # Ensure the tensor is in [C, H, W] format
228
  if sampled.dim() == 3 and sampled.shape[0] == 3:
229
- sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
230
  else:
231
  raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
232
 
233
  progress(1.0, "Inference complete")
234
- return sampled_image # Return the sampled_image PIL image
235
 
236
  finally:
237
  if use_low_vram:
238
  models_to(models_rbm, device=device)
239
- # Clear CUDA cache
240
  torch.cuda.empty_cache()
241
  gc.collect()
242
 
 
243
  def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
244
  global models_rbm, models_b, device
245
  sam_model = LangSAM()
246
  models_to(models_rbm, device=device)
247
  models_to(sam_model, device=device)
248
  models_to(sam_model.sam, device=device)
 
249
  try:
250
  caption = f"{caption} in {style_description}"
251
  sam_prompt = f"{caption}"
252
  use_sam_mask = False
253
-
254
  batch_size = 1
255
  height, width = 1024, 1024
256
- stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
 
 
257
 
258
  extras.sampling_configs['cfg'] = 4
259
  extras.sampling_configs['shift'] = 2
@@ -265,31 +295,58 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_lo
265
  extras_b.sampling_configs['t_start'] = 1.0
266
 
267
  progress(0.1, "Loading style and subject reference images")
268
- ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
269
- ref_images = resize_image(PIL.Image.open(ref_sub_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
270
-
 
 
 
 
 
 
 
 
 
 
 
271
  batch = {'captions': [caption] * batch_size}
272
- batch['style'] = ref_style
273
- batch['images'] = ref_images
274
 
275
  progress(0.2, "Processing reference images")
276
- x0_forward = models_rbm.effnet(extras.effnet_preprocess(ref_images.to(device)))
277
- x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
278
-
279
- ## SAM Mask for sub
 
 
 
280
  use_sam_mask = False
281
  x0_preview = models_rbm.previewer(x0_forward)
282
-
283
- x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
284
  sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
285
- # sam_mask, boxes, phrases, logits = sam_model.predict(transform(x0_preview[0]), sam_prompt)
286
  sam_mask = sam_mask.detach().unsqueeze(dim=0).to(device)
287
 
288
  progress(0.3, "Generating conditions")
289
- conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_subject_style=True, eval_csd=False)
290
- unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False, eval_subject_style=True)
291
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
292
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  if use_low_vram:
295
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
@@ -297,15 +354,14 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_lo
297
  models_to(sam_model.sam, device="cpu")
298
 
299
  progress(0.4, "Starting Stage C reverse process")
300
- # Stage C reverse process.
301
  sampling_c = extras.gdf.sample(
302
  models_rbm.generator, conditions, stage_c_latent_shape,
303
  unconditions, device=device,
304
  **extras.sampling_configs,
305
  x0_style_forward=x0_style_forward, x0_forward=x0_forward,
306
- apply_pushforward=False, tau_pushforward=5, tau_pushforward_csd=10,
307
  num_iter=3, eta=1e-1, tau=20, eval_sub_csd=True,
308
- extras=extras, models=models_rbm,
309
  use_attn_mask=use_sam_mask,
310
  save_attn_mask=False,
311
  lam_content=1, lam_style=1,
@@ -313,63 +369,58 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_lo
313
  sam_prompt=sam_prompt
314
  )
315
 
316
- for sampled_c, _, _ in progress.tqdm(tqdm(sampling_c, total=extras.sampling_configs['timesteps']), desc="Stage C reverse process"):
317
- #for i, (sampled_c, _, _) in enumerate(sampling_c, 1):
318
- # if i % 5 == 0: # Update progress every 5 steps
319
- # progress(0.4 + 0.3 * (i / extras.sampling_configs['timesteps']), f"Stage C reverse process: step {i}/{extras.sampling_configs['timesteps']}")
320
  sampled_c = sampled_c
321
 
322
  progress(0.7, "Starting Stage B reverse process")
323
- # Stage B reverse process.
324
- with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
325
  conditions_b['effnet'] = sampled_c
326
  unconditions_b['effnet'] = torch.zeros_like(sampled_c)
327
-
328
  sampling_b = extras_b.gdf.sample(
329
  models_b.generator, conditions_b, stage_b_latent_shape,
330
  unconditions_b, device=device, **extras_b.sampling_configs,
331
  )
332
- for sampled_b, _, _ in progress.tqdm(tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']), desc="Stage B reverse process"):
333
- #for i, (sampled_b, _, _) in enumerate(sampling_b, 1):
334
- # if i % 5 == 0: # Update progress every 5 steps
335
- # progress(0.7 + 0.2 * (i / extras_b.sampling_configs['timesteps']), f"Stage B reverse process: step {i}/{extras_b.sampling_configs['timesteps']}")
336
  sampled_b = sampled_b
337
  sampled = models_b.stage_a.decode(sampled_b).float()
338
 
339
  torch.cuda.empty_cache()
340
  gc.collect()
341
-
342
  progress(0.9, "Finalizing the output image")
343
  sampled = torch.cat([
344
- torch.nn.functional.interpolate(ref_images.cpu(), size=(height, width)),
345
- torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
346
- sampled.cpu(),
347
  ], dim=0)
348
 
349
- # Remove the batch dimension and keep only the generated image
350
- sampled = sampled[2] # This selects the generated image, discarding the reference images
351
-
352
- # Ensure the tensor values are in the correct range
353
  sampled = torch.clamp(sampled, 0, 1)
354
 
355
- # Ensure the tensor is in [C, H, W] format
356
  if sampled.dim() == 3 and sampled.shape[0] == 3:
357
- sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
358
  else:
359
  raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
360
 
361
  progress(1.0, "Inference complete")
362
- return sampled_image # Return the sampled_image PIL image
363
 
364
  finally:
365
  if use_low_vram:
366
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
367
  models_to(sam_model, device=device)
368
  models_to(sam_model.sam, device=device)
369
- # Clear CUDA cache
370
  torch.cuda.empty_cache()
371
  gc.collect()
372
 
 
373
  def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram):
374
  result = None
375
  progress = gr.Progress(track_tqdm=True)
@@ -379,13 +430,13 @@ def run(style_reference_image, style_description, subject_prompt, subject_refere
379
  result = infer(style_reference_image, style_description, subject_prompt, use_low_vram, progress)
380
  return result
381
 
 
382
  def show_hide_subject_image_component(use_subject_ref):
383
  if use_subject_ref is True:
384
  return gr.update(open=True)
385
  else:
386
  return gr.update(open=False)
387
 
388
- import gradio as gr
389
 
390
  with gr.Blocks(analytics_enabled=False) as demo:
391
  with gr.Column():
@@ -404,29 +455,28 @@ with gr.Blocks(analytics_enabled=False) as demo:
404
  with gr.Row():
405
  with gr.Column():
406
  style_reference_image = gr.Image(
407
- label = "Style Reference Image",
408
- type = "filepath"
409
  )
410
  style_description = gr.Textbox(
411
- label ="Style Description"
412
  )
413
  subject_prompt = gr.Textbox(
414
- label = "Subject Prompt"
415
  )
416
  with gr.Row():
417
  use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
418
  use_low_vram = gr.Checkbox(label="Use Low-VRAM", value=False)
419
-
420
  with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
421
  subject_reference = gr.Image(label="Subject Reference", type="filepath")
422
-
423
  submit_btn = gr.Button("Submit")
424
 
425
-
426
  with gr.Column():
427
  output_image = gr.Image(label="Output Image")
428
  gr.Examples(
429
- examples = [
430
  ["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False, False],
431
  ["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False, False],
432
  ["./data/glowing.png", "glowing style", "a dwarf", None, False, False],
@@ -436,21 +486,20 @@ with gr.Blocks(analytics_enabled=False) as demo:
436
  inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
437
  outputs=[output_image],
438
  cache_examples=False
439
-
440
  )
441
 
442
  use_subject_ref.input(
443
- fn = show_hide_subject_image_component,
444
- inputs = [use_subject_ref],
445
- outputs = [sub_img_panel],
446
- queue = False,
447
  api_visibility="private"
448
  )
449
-
450
  submit_btn.click(
451
- fn = run,
452
- inputs = [style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
453
- outputs = [output_image],
454
  api_visibility="private"
455
  )
456
 
 
28
  from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
29
  from gdf.targets import EpsilonTarget
30
  import PIL
31
+ import gradio as gr
32
 
33
  # Device configuration
34
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35
  print(device)
36
 
37
+
38
+ def module_dtype(module):
39
+ return next(module.parameters()).dtype
40
+
41
+
42
+ def to_module_device_dtype(tensor, module, device=device):
43
+ return tensor.to(device=device, dtype=module_dtype(module))
44
+
45
+
46
  # Flag for low VRAM usage
47
  # low_vram = False
48
 
 
63
  continue
64
  print(f"Change device of '{attr_name}' to {device}")
65
  attr_value.to(device)
66
+
67
  torch.cuda.empty_cache()
68
  gc.collect()
69
 
70
+
71
  # Stage C model configuration
72
  config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
73
  with open(config_file, "r", encoding="utf-8") as file:
 
79
  config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
80
  with open(config_file_b, "r", encoding="utf-8") as file:
81
  config_file_b = yaml.safe_load(file)
82
+
83
  core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
84
 
85
  # Setup extras and models for Stage C
 
140
  models_rbm.generator.eval().requires_grad_(False)
141
 
142
 
 
143
  def infer(ref_style_file, style_description, caption, use_low_vram, progress):
144
  global models_rbm, models_b, device
145
+
146
  models_to(models_rbm, device=device)
147
+
148
  try:
 
149
  caption = f"{caption} in {style_description}"
150
+ height = 1024
151
+ width = 1024
152
+ batch_size = 1
153
+
154
+ stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(
155
+ height, width, batch_size=batch_size
156
+ )
157
 
158
  extras.sampling_configs['cfg'] = 4
159
  extras.sampling_configs['shift'] = 2
 
166
  extras_b.sampling_configs['t_start'] = 1.0
167
 
168
  progress(0.1, "Loading style reference image")
169
+ ref_style = resize_image(
170
+ PIL.Image.open(ref_style_file).convert("RGB")
171
+ ).unsqueeze(0).expand(batch_size, -1, -1, -1)
172
+
173
+ ref_style_for_clip = to_module_device_dtype(ref_style, models_rbm.image_model)
174
+ ref_style_for_effnet = to_module_device_dtype(ref_style, models_rbm.effnet)
175
 
176
  batch = {'captions': [caption] * batch_size}
177
+ batch['style'] = ref_style_for_clip
178
 
179
  progress(0.2, "Processing style reference image")
180
+ x0_style_forward = models_rbm.effnet(
181
+ extras.effnet_preprocess(ref_style_for_effnet)
182
+ )
183
 
184
  progress(0.3, "Generating conditions")
185
+ conditions = core.get_conditions(
186
+ batch, models_rbm, extras,
187
+ is_eval=True, is_unconditional=False,
188
+ eval_image_embeds=True, eval_style=True, eval_csd=False
189
+ )
190
+ unconditions = core.get_conditions(
191
+ batch, models_rbm, extras,
192
+ is_eval=True, is_unconditional=True,
193
+ eval_image_embeds=False
194
+ )
195
+ conditions_b = core_b.get_conditions(
196
+ batch, models_b, extras_b,
197
+ is_eval=True, is_unconditional=False
198
+ )
199
+ unconditions_b = core_b.get_conditions(
200
+ batch, models_b, extras_b,
201
+ is_eval=True, is_unconditional=True
202
+ )
203
 
204
  if use_low_vram:
205
  # The sampling process uses more vram, so we offload everything except two modules to the cpu.
206
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
207
 
208
  progress(0.4, "Starting Stage C reverse process")
 
209
  sampling_c = extras.gdf.sample(
210
  models_rbm.generator, conditions, stage_c_latent_shape,
211
  unconditions, device=device,
 
217
  lam_style=1, lam_txt_alignment=1.0,
218
  use_ddim_sampler=True,
219
  )
220
+
221
+ for (sampled_c, _, _) in progress.tqdm(
222
+ tqdm(sampling_c, total=extras.sampling_configs['timesteps']),
223
+ desc="Stage C reverse process"
224
+ ):
225
  sampled_c = sampled_c
226
 
227
  progress(0.7, "Starting Stage B reverse process")
228
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
 
229
  conditions_b['effnet'] = sampled_c
230
  unconditions_b['effnet'] = torch.zeros_like(sampled_c)
231
+
232
  sampling_b = extras_b.gdf.sample(
233
  models_b.generator, conditions_b, stage_b_latent_shape,
234
  unconditions_b, device=device, **extras_b.sampling_configs,
235
  )
236
+ for sampled_b, _, _ in progress.tqdm(
237
+ tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']),
238
+ desc="Stage B reverse process"
239
+ ):
240
  sampled_b = sampled_b
241
  sampled = models_b.stage_a.decode(sampled_b).float()
242
 
243
  torch.cuda.empty_cache()
244
  gc.collect()
245
+
246
  progress(0.9, "Finalizing the output image")
247
  sampled = torch.cat([
248
+ torch.nn.functional.interpolate(ref_style.float().cpu(), size=(height, width)),
249
+ sampled.float().cpu(),
250
  ], dim=0)
251
 
252
+ sampled = sampled[1]
 
 
 
253
  sampled = torch.clamp(sampled, 0, 1)
254
 
 
255
  if sampled.dim() == 3 and sampled.shape[0] == 3:
256
+ sampled_image = T.ToPILImage()(sampled)
257
  else:
258
  raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
259
 
260
  progress(1.0, "Inference complete")
261
+ return sampled_image
262
 
263
  finally:
264
  if use_low_vram:
265
  models_to(models_rbm, device=device)
 
266
  torch.cuda.empty_cache()
267
  gc.collect()
268
 
269
+
270
  def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
271
  global models_rbm, models_b, device
272
  sam_model = LangSAM()
273
  models_to(models_rbm, device=device)
274
  models_to(sam_model, device=device)
275
  models_to(sam_model.sam, device=device)
276
+
277
  try:
278
  caption = f"{caption} in {style_description}"
279
  sam_prompt = f"{caption}"
280
  use_sam_mask = False
281
+
282
  batch_size = 1
283
  height, width = 1024, 1024
284
+ stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(
285
+ height, width, batch_size=batch_size
286
+ )
287
 
288
  extras.sampling_configs['cfg'] = 4
289
  extras.sampling_configs['shift'] = 2
 
295
  extras_b.sampling_configs['t_start'] = 1.0
296
 
297
  progress(0.1, "Loading style and subject reference images")
298
+ ref_style = resize_image(
299
+ PIL.Image.open(ref_style_file).convert("RGB")
300
+ ).unsqueeze(0).expand(batch_size, -1, -1, -1)
301
+
302
+ ref_images = resize_image(
303
+ PIL.Image.open(ref_sub_file).convert("RGB")
304
+ ).unsqueeze(0).expand(batch_size, -1, -1, -1)
305
+
306
+ ref_style_for_clip = to_module_device_dtype(ref_style, models_rbm.image_model)
307
+ ref_images_for_clip = to_module_device_dtype(ref_images, models_rbm.image_model)
308
+
309
+ ref_style_for_effnet = to_module_device_dtype(ref_style, models_rbm.effnet)
310
+ ref_images_for_effnet = to_module_device_dtype(ref_images, models_rbm.effnet)
311
+
312
  batch = {'captions': [caption] * batch_size}
313
+ batch['style'] = ref_style_for_clip
314
+ batch['images'] = ref_images_for_clip
315
 
316
  progress(0.2, "Processing reference images")
317
+ x0_forward = models_rbm.effnet(
318
+ extras.effnet_preprocess(ref_images_for_effnet)
319
+ )
320
+ x0_style_forward = models_rbm.effnet(
321
+ extras.effnet_preprocess(ref_style_for_effnet)
322
+ )
323
+
324
  use_sam_mask = False
325
  x0_preview = models_rbm.previewer(x0_forward)
326
+
327
+ x0_preview_pil = T.ToPILImage()(x0_preview[0].float().cpu())
328
  sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
 
329
  sam_mask = sam_mask.detach().unsqueeze(dim=0).to(device)
330
 
331
  progress(0.3, "Generating conditions")
332
+ conditions = core.get_conditions(
333
+ batch, models_rbm, extras,
334
+ is_eval=True, is_unconditional=False,
335
+ eval_image_embeds=True, eval_subject_style=True, eval_csd=False
336
+ )
337
+ unconditions = core.get_conditions(
338
+ batch, models_rbm, extras,
339
+ is_eval=True, is_unconditional=True,
340
+ eval_image_embeds=False, eval_subject_style=True
341
+ )
342
+ conditions_b = core_b.get_conditions(
343
+ batch, models_b, extras_b,
344
+ is_eval=True, is_unconditional=False
345
+ )
346
+ unconditions_b = core_b.get_conditions(
347
+ batch, models_b, extras_b,
348
+ is_eval=True, is_unconditional=True
349
+ )
350
 
351
  if use_low_vram:
352
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
354
  models_to(sam_model.sam, device="cpu")
355
 
356
  progress(0.4, "Starting Stage C reverse process")
 
357
  sampling_c = extras.gdf.sample(
358
  models_rbm.generator, conditions, stage_c_latent_shape,
359
  unconditions, device=device,
360
  **extras.sampling_configs,
361
  x0_style_forward=x0_style_forward, x0_forward=x0_forward,
362
+ apply_pushforward=False, tau_pushforward=5, tau_pushforward_csd=10,
363
  num_iter=3, eta=1e-1, tau=20, eval_sub_csd=True,
364
+ extras=extras, models=models_rbm,
365
  use_attn_mask=use_sam_mask,
366
  save_attn_mask=False,
367
  lam_content=1, lam_style=1,
 
369
  sam_prompt=sam_prompt
370
  )
371
 
372
+ for sampled_c, _, _ in progress.tqdm(
373
+ tqdm(sampling_c, total=extras.sampling_configs['timesteps']),
374
+ desc="Stage C reverse process"
375
+ ):
376
  sampled_c = sampled_c
377
 
378
  progress(0.7, "Starting Stage B reverse process")
379
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
 
380
  conditions_b['effnet'] = sampled_c
381
  unconditions_b['effnet'] = torch.zeros_like(sampled_c)
382
+
383
  sampling_b = extras_b.gdf.sample(
384
  models_b.generator, conditions_b, stage_b_latent_shape,
385
  unconditions_b, device=device, **extras_b.sampling_configs,
386
  )
387
+ for sampled_b, _, _ in progress.tqdm(
388
+ tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']),
389
+ desc="Stage B reverse process"
390
+ ):
391
  sampled_b = sampled_b
392
  sampled = models_b.stage_a.decode(sampled_b).float()
393
 
394
  torch.cuda.empty_cache()
395
  gc.collect()
396
+
397
  progress(0.9, "Finalizing the output image")
398
  sampled = torch.cat([
399
+ torch.nn.functional.interpolate(ref_images.float().cpu(), size=(height, width)),
400
+ torch.nn.functional.interpolate(ref_style.float().cpu(), size=(height, width)),
401
+ sampled.float().cpu(),
402
  ], dim=0)
403
 
404
+ sampled = sampled[2]
 
 
 
405
  sampled = torch.clamp(sampled, 0, 1)
406
 
 
407
  if sampled.dim() == 3 and sampled.shape[0] == 3:
408
+ sampled_image = T.ToPILImage()(sampled)
409
  else:
410
  raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
411
 
412
  progress(1.0, "Inference complete")
413
+ return sampled_image
414
 
415
  finally:
416
  if use_low_vram:
417
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
418
  models_to(sam_model, device=device)
419
  models_to(sam_model.sam, device=device)
 
420
  torch.cuda.empty_cache()
421
  gc.collect()
422
 
423
+
424
  def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram):
425
  result = None
426
  progress = gr.Progress(track_tqdm=True)
 
430
  result = infer(style_reference_image, style_description, subject_prompt, use_low_vram, progress)
431
  return result
432
 
433
+
434
  def show_hide_subject_image_component(use_subject_ref):
435
  if use_subject_ref is True:
436
  return gr.update(open=True)
437
  else:
438
  return gr.update(open=False)
439
 
 
440
 
441
  with gr.Blocks(analytics_enabled=False) as demo:
442
  with gr.Column():
 
455
  with gr.Row():
456
  with gr.Column():
457
  style_reference_image = gr.Image(
458
+ label="Style Reference Image",
459
+ type="filepath"
460
  )
461
  style_description = gr.Textbox(
462
+ label="Style Description"
463
  )
464
  subject_prompt = gr.Textbox(
465
+ label="Subject Prompt"
466
  )
467
  with gr.Row():
468
  use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
469
  use_low_vram = gr.Checkbox(label="Use Low-VRAM", value=False)
470
+
471
  with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
472
  subject_reference = gr.Image(label="Subject Reference", type="filepath")
473
+
474
  submit_btn = gr.Button("Submit")
475
 
 
476
  with gr.Column():
477
  output_image = gr.Image(label="Output Image")
478
  gr.Examples(
479
+ examples=[
480
  ["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False, False],
481
  ["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False, False],
482
  ["./data/glowing.png", "glowing style", "a dwarf", None, False, False],
 
486
  inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
487
  outputs=[output_image],
488
  cache_examples=False
 
489
  )
490
 
491
  use_subject_ref.input(
492
+ fn=show_hide_subject_image_component,
493
+ inputs=[use_subject_ref],
494
+ outputs=[sub_img_panel],
495
+ queue=False,
496
  api_visibility="private"
497
  )
498
+
499
  submit_btn.click(
500
+ fn=run,
501
+ inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
502
+ outputs=[output_image],
503
  api_visibility="private"
504
  )
505