rahul7star committed · Commit 8ce7829 · verified · 1 Parent(s): 62e2776

Update app_quant_latent.py

Files changed (1):
  1. app_quant_latent.py +64 -47
app_quant_latent.py CHANGED
@@ -558,67 +558,84 @@ def safe_get_latents(pipe, height, width, generator, device, LOGS):
 def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
     LOGS = []
     device = "cuda"
-
     generator = torch.Generator(device).manual_seed(int(seed))
-    latent_gallery = []

-    # Gradio-safe streaming buffer
-    stream_buffer = {"updated": False}

-    # --------------------------
-    # CALLBACK CALLED EACH STEP
-    # --------------------------
-    def latent_callback(step, timestep, latents):
         try:
-            img = latent_to_image(latents)  # convert latent → PIL
-        except Exception:
-            img = Image.new("RGB", (width, height), "white")

-        latent_gallery.append(img)
-        stream_buffer["updated"] = True

-    # --------------------------
-    # MAIN GENERATOR LOOP
-    # --------------------------
-    try:
-        # start generation
-        gen = pipe.generate(
-            prompt=prompt,
-            height=height,
-            width=width,
-            num_inference_steps=steps,
-            guidance_scale=guidance_scale,
-            generator=generator,
-            callback=latent_callback,
-            callback_steps=1,
-            stream=True,  # << IMPORTANT for manual iteration
-        )

-        # stream steps as they happen
-        for out in gen:
-            if stream_buffer["updated"]:
-                stream_buffer["updated"] = False
                 yield (
-                    None,             # final image not ready
-                    latent_gallery,   # growing thumbnails
-                    "\n".join(LOGS),  # logs
                 )

-        # final result
-        final_img = out.images[0]

-        yield (
-            final_img,
-            latent_gallery,
-            "\n".join(LOGS),
-        )

     except Exception as e:
-        LOGS.append(str(e))
-        ph = Image.new("RGB", (width, height), "white")
         yield (
-            ph,
-            [ph],
             "\n".join(LOGS),
         )
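The removed branch assumed a pipe.generate(..., callback=..., stream=True) API that stock diffusers pipelines do not expose, so the loop above could never stream intermediate steps. Recent diffusers releases do give per-step latent access through the callback_on_step_end hook; a minimal sketch, assuming a standard DiffusionPipeline bound to pipe and this file's latent_to_image helper:

# Sketch only: the supported per-step hook in recent diffusers releases.
# on_step_end receives the pipeline, step index, timestep, and a dict holding
# the tensors requested below; it must return that dict.
previews = []

def on_step_end(pipeline, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]        # current latents at this step
    previews.append(latent_to_image(latents))   # helper assumed from this file
    return callback_kwargs

result = pipe(
    prompt=prompt,
    height=height,
    width=width,
    num_inference_steps=steps,
    guidance_scale=guidance_scale,
    generator=generator,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
)
final_img = result.images[0]

The replacement version below drops streaming during denoising entirely and instead fakes a preview sequence from the initial latent tensor.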
 
 def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
     LOGS = []
     device = "cuda"
     generator = torch.Generator(device).manual_seed(int(seed))

+    placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
+    latent_gallery = []
+    final_gallery = []

+    try:
+        # ==========================================================
+        # ADVANCED LATENT MODE (hack using latent tensor)
+        # ==========================================================
         try:
+            # Get initial latent tensor
+            latents = safe_get_latents(pipe, height, width, generator, device, LOGS)

+            # Fake step-wise preview by slicing latent channels / noise
+            num_preview_steps = min(6, latents.shape[1])  # e.g. 6 slices
+            for i in range(num_preview_steps):
+                # Take a slice of latent channels to simulate an intermediate step
+                step_latent = latents[:, : (i + 1), :, :]

+                # Convert step latent to PIL
+                try:
+                    latent_img = latent_to_image(step_latent)
+                except Exception:
+                    latent_img = placeholder

+                latent_gallery.append(latent_img)
+
+                # Stream intermediate latent preview to Gradio
                 yield (
+                    None,            # final image not ready
+                    latent_gallery,  # gallery updates live
+                    "\n".join(LOGS),
                 )

+            # Decode final latent tensor into final image
+            final_img = pipe.decode_latents(latents)[0]
+            final_gallery.append(final_img)
+            LOGS.append("✅ Advanced latent pipeline succeeded.")

+            yield (
+                final_img,
+                latent_gallery,
+                "\n".join(LOGS),
+            )
+
+        # ==========================================================
+        # FALLBACK STANDARD PIPELINE
+        # ==========================================================
+        except Exception as e:
+            LOGS.append(f"⚠️ Advanced latent mode failed: {e}")
+            LOGS.append("🔁 Switching to standard pipeline...")
+
+            output = pipe(
+                prompt=prompt,
+                height=height,
+                width=width,
+                num_inference_steps=steps,
+                guidance_scale=guidance_scale,
+                generator=generator,
+            )
+
+            final_img = output.images[0]
+            latent_gallery.append(final_img)  # last step in gallery
+            LOGS.append("✅ Standard pipeline succeeded.")
+
+            yield (
+                final_img,
+                latent_gallery,
+                "\n".join(LOGS),
+            )

     except Exception as e:
+        LOGS.append(f"❌ Total failure: {e}")
+        placeholder_img = placeholder
         yield (
+            placeholder_img,
+            [placeholder_img],
             "\n".join(LOGS),
         )
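safe_get_latents, called first in the new path, is defined just above this hunk (see the @@ context line) and is not part of the change. For orientation, a plausible minimal shape for such a helper; everything in it is an assumption, not the file's actual code:

# Hypothetical sketch: draw the same initial noise tensor a diffusers pipeline
# would, logging and re-raising on failure.
import torch

def safe_get_latents(pipe, height, width, generator, device, LOGS):
    channels = pipe.unet.config.in_channels         # 4 for SD-family UNets
    shape = (1, channels, height // 8, width // 8)  # latent space is 1/8 resolution
    try:
        return torch.randn(shape, generator=generator, device=device)
    except Exception as e:
        LOGS.append(f"safe_get_latents failed: {e}")
        raise

If the latent does have 4 channels, min(6, latents.shape[1]) caps the preview loop above at 4 slices rather than 6.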
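latent_to_image is also outside this diff. The sliced step_latent tensors carry between 1 and num_preview_steps channels, so a plain VAE decode (which expects the UNet's full latent channel count) would raise and land on the placeholder every time. A hypothetical channel-agnostic preview that survives the slicing:

import torch
from PIL import Image

def latent_to_image(latents):
    # Collapse whatever channels are present into one grayscale map.
    t = latents[0].float().mean(dim=0)               # (C, h, w) -> (h, w)
    t = (t - t.min()) / (t.max() - t.min() + 1e-8)   # normalize to [0, 1]
    arr = (t * 255).round().byte().cpu().numpy()
    # SD-style latents sit at 1/8 pixel resolution; upscale for the gallery.
    img = Image.fromarray(arr, mode="L").convert("RGB")
    return img.resize((arr.shape[1] * 8, arr.shape[0] * 8), Image.BILINEAR)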
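Both versions yield (final_image, gallery, logs) tuples from a generator function, which Gradio streams into its outputs one update at a time. A hypothetical wiring consistent with that signature; all component names and ranges here are assumed:

import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    height = gr.Slider(256, 1024, value=512, step=64, label="Height")
    width = gr.Slider(256, 1024, value=512, step=64, label="Width")
    steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
    seed = gr.Number(value=0, precision=0, label="Seed")
    run = gr.Button("Generate")

    final_image = gr.Image(label="Final image")
    latent_view = gr.Gallery(label="Latent previews")
    logs_box = gr.Textbox(label="Logs", lines=8)

    # generate_image is a generator: every yield refreshes all three outputs.
    run.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[final_image, latent_view, logs_box],
    )

demo.launch()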