rahul7star committed on
Commit
11a45c8
·
verified ·
1 Parent(s): 4c1d5ac

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +85 -0
app_quant_latent.py CHANGED
@@ -579,6 +579,7 @@ def upload_latents_to_hf(latent_dict, filename="latents.pt"):
579
  os.remove(local_path)
580
  raise e
581
 
 
582
  @spaces.GPU
583
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
584
  LOGS = []
@@ -589,6 +590,90 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
589
  latent_gallery = []
590
  final_gallery = []
591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
  # --- Try generating latent previews ---
593
  try:
594
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
 
579
  os.remove(local_path)
580
  raise e
581
 
582
+
583
  @spaces.GPU
584
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
585
  LOGS = []
 
590
  latent_gallery = []
591
  final_gallery = []
592
 
593
+ # --- Try generating latent previews ---
594
+ try:
595
+ latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
596
+
597
+ # Decode latent tensor to PIL for preview with robust fallbacks
598
+ latent_img = placeholder
599
+ try:
600
+ with torch.no_grad():
601
+ # 1️⃣ Try normal VAE decode if available
602
+ if hasattr(pipe, "vae") and hasattr(pipe.vae, "decode"):
603
+ try:
604
+ latent_img_tensor = pipe.vae.decode(latents).sample # [1,3,H,W]
605
+ latent_img_tensor = (latent_img_tensor / 2 + 0.5).clamp(0, 1)
606
+ latent_img_tensor = latent_img_tensor.cpu().permute(0, 2, 3, 1)[0]
607
+ latent_img = Image.fromarray((latent_img_tensor.numpy() * 255).astype('uint8'))
608
+ except Exception as e1:
609
+ LOGS.append(f"⚠️ VAE decode failed: {e1}")
610
+
611
+ # 2️⃣ Collapse first 3 channels if decode failed
612
+ if latent_img is placeholder and latents.shape[1] >= 3:
613
+ ch = latents[0, :3, :, :]
614
+ ch = (ch - ch.min()) / (ch.max() - ch.min() + 1e-8)
615
+ latent_img = Image.fromarray((ch.permute(1, 2, 0).cpu().numpy() * 255).astype('uint8'))
616
+
617
+ # 3️⃣ Collapse all channels to mean -> replicate to RGB
618
+ if latent_img is placeholder:
619
+ mean_ch = latents[0].mean(dim=0, keepdim=True) # [1,H,W]
620
+ mean_ch = (mean_ch - mean_ch.min()) / (mean_ch.max() - mean_ch.min() + 1e-8)
621
+ latent_img = Image.fromarray(
622
+ torch.cat([mean_ch]*3, dim=0).permute(1,2,0).cpu().numpy().astype('uint8')
623
+ )
624
+
625
+ except Exception as e:
626
+ LOGS.append(f"⚠️ Latent to image conversion failed: {e}")
627
+ latent_img = placeholder
628
+
629
+ latent_gallery.append(latent_img)
630
+ yield None, latent_gallery, LOGS # show preview immediately
631
+
632
+ # Save latents to HF for later testing
633
+ latent_dict = {"latents": latents.cpu(), "prompt": prompt, "seed": seed}
634
+ try:
635
+ hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_{seed}.pt")
636
+ LOGS.append(f"🔹 Latents uploaded: {hf_url}")
637
+ except Exception as e:
638
+ LOGS.append(f"⚠️ Failed to upload latents: {e}")
639
+
640
+ except Exception as e:
641
+ LOGS.append(f"⚠️ Latent generation failed: {e}")
642
+ latent_gallery.append(placeholder)
643
+ yield None, latent_gallery, LOGS
644
+
645
+ # --- Final image: completely untouched, uses standard pipeline ---
646
+ try:
647
+ output = pipe(
648
+ prompt=prompt,
649
+ height=height,
650
+ width=width,
651
+ num_inference_steps=steps,
652
+ guidance_scale=guidance_scale,
653
+ generator=generator,
654
+ )
655
+ final_img = output.images[0]
656
+ final_gallery.append(final_img)
657
+ latent_gallery.append(final_img) # fallback preview if needed
658
+ LOGS.append("✅ Standard pipeline succeeded.")
659
+ yield final_img, latent_gallery, LOGS
660
+
661
+ except Exception as e2:
662
+ LOGS.append(f"❌ Standard pipeline failed: {e2}")
663
+ final_gallery.append(placeholder)
664
+ latent_gallery.append(placeholder)
665
+ yield placeholder, latent_gallery, LOGS
666
+ # This version generates the final image well and gives a tensor back for the latent
667
+ @spaces.GPU
668
+ def generate_image_workswell(prompt, height, width, steps, seed, guidance_scale=0.0):
669
+ LOGS = []
670
+ device = "cuda"
671
+ generator = torch.Generator(device).manual_seed(int(seed))
672
+
673
+ placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
674
+ latent_gallery = []
675
+ final_gallery = []
676
+
677
  # --- Try generating latent previews ---
678
  try:
679
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)