rahul7star commited on
Commit
62e2776
·
verified ·
1 Parent(s): 8e89616

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +47 -69
app_quant_latent.py CHANGED
@@ -558,93 +558,71 @@ def safe_get_latents(pipe, height, width, generator, device, LOGS):
558
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
559
  LOGS = []
560
  device = "cuda"
561
- generator = torch.Generator(device).manual_seed(int(seed))
562
 
563
- placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
564
  latent_gallery = []
565
- final_gallery = []
566
-
567
- try:
568
- # ==========================================================
569
- # ADVANCED LATENT MODE
570
- # ==========================================================
571
- try:
572
- latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
573
 
574
- for i, t in enumerate(pipe.scheduler.timesteps):
575
- with torch.no_grad():
576
- noise_pred = pipe.unet(
577
- latents, t,
578
- encoder_hidden_states=pipe.get_text_embeddings(prompt)
579
- )["sample"]
580
 
581
- latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]
 
 
 
 
 
 
 
582
 
583
- # convert latent → image
584
- try:
585
- latent_img = latent_to_image(latents)
586
- except Exception:
587
- latent_img = placeholder
588
 
589
- latent_gallery.append(latent_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
- # 🔥 STREAM update
 
 
 
592
  yield (
593
- None, # final_image
594
- latent_gallery, # latent gallery list
595
- "\n".join(LOGS), # logs
596
  )
597
 
598
- # ---------------------
599
- # FINAL decode
600
- # ---------------------
601
- final_img = pipe.decode_latents(latents)[0]
602
- final_gallery.append(final_img)
603
- LOGS.append("✅ Advanced latent pipeline succeeded.")
604
-
605
- yield (
606
- final_img,
607
- latent_gallery,
608
- "\n".join(LOGS),
609
- )
610
-
611
- # ==========================================================
612
- # FALLBACK STANDARD PIPELINE
613
- # ==========================================================
614
- except Exception as e:
615
- LOGS.append(f"⚠️ Advanced latent mode failed: {e}")
616
- LOGS.append("🔁 Switching to standard pipeline...")
617
-
618
- output = pipe(
619
- prompt=prompt,
620
- height=height,
621
- width=width,
622
- num_inference_steps=steps,
623
- guidance_scale=guidance_scale,
624
- generator=generator,
625
- )
626
-
627
- final_img = output.images[0]
628
- latent_gallery.append(final_img)
629
- LOGS.append("✅ Standard pipeline succeeded.")
630
 
631
- yield (
632
- final_img,
633
- latent_gallery,
634
- "\n".join(LOGS),
635
- )
636
 
637
  except Exception as e:
638
- LOGS.append(f"❌ Total failure: {e}")
639
- placeholder_img = placeholder
640
  yield (
641
- placeholder_img,
642
- [placeholder_img],
643
  "\n".join(LOGS),
644
  )
645
 
646
 
647
-
648
  @spaces.GPU
649
  def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
650
  """
 
558
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
559
  LOGS = []
560
  device = "cuda"
 
561
 
562
+ generator = torch.Generator(device).manual_seed(int(seed))
563
  latent_gallery = []
 
 
 
 
 
 
 
 
564
 
565
+ # Gradio-safe streaming buffer
566
+ stream_buffer = {"updated": False}
 
 
 
 
567
 
568
+ # --------------------------
569
+ # CALLBACK CALLED EACH STEP
570
+ # --------------------------
571
+ def latent_callback(step, timestep, latents):
572
+ try:
573
+ img = latent_to_image(latents) # convert latent → PIL
574
+ except Exception:
575
+ img = Image.new("RGB", (width, height), "white")
576
 
577
+ latent_gallery.append(img)
578
+ stream_buffer["updated"] = True
 
 
 
579
 
580
+ # --------------------------
581
+ # MAIN GENERATOR LOOP
582
+ # --------------------------
583
+ try:
584
+ # start generation
585
+ gen = pipe.generate(
586
+ prompt=prompt,
587
+ height=height,
588
+ width=width,
589
+ num_inference_steps=steps,
590
+ guidance_scale=guidance_scale,
591
+ generator=generator,
592
+ callback=latent_callback,
593
+ callback_steps=1,
594
+ stream=True, # << IMPORTANT for manual iteration
595
+ )
596
 
597
+ # stream steps as they happen
598
+ for out in gen:
599
+ if stream_buffer["updated"]:
600
+ stream_buffer["updated"] = False
601
  yield (
602
+ None, # final image not ready
603
+ latent_gallery, # growing thumbnails
604
+ "\n".join(LOGS), # logs
605
  )
606
 
607
+ # final result
608
+ final_img = out.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
 
610
+ yield (
611
+ final_img,
612
+ latent_gallery,
613
+ "\n".join(LOGS),
614
+ )
615
 
616
  except Exception as e:
617
+ LOGS.append(str(e))
618
+ ph = Image.new("RGB", (width, height), "white")
619
  yield (
620
+ ph,
621
+ [ph],
622
  "\n".join(LOGS),
623
  )
624
 
625
 
 
626
  @spaces.GPU
627
  def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
628
  """