rahul7star commited on
Commit
a038035
·
verified ·
1 Parent(s): 41e7378

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +17 -46
app_quant_latent.py CHANGED
@@ -474,50 +474,21 @@ import torch
474
  # Helper: Safe latent extractor
475
  # --------------------------
476
  def safe_get_latents(pipe, height, width, generator, device, LOGS):
477
- """
478
- Attempts multiple ways to get latents.
479
- Returns a valid tensor even if pipeline hides UNet.
480
- """
481
- # Try official prepare_latents
482
- try:
483
- if hasattr(pipe, "unet") and hasattr(pipe.unet, "in_channels"):
484
- num_channels = pipe.unet.in_channels
485
- latents = pipe.prepare_latents(
486
- batch_size=1,
487
- num_channels=num_channels,
488
- height=height,
489
- width=width,
490
- dtype=torch.float32,
491
- device=device,
492
- generator=generator
493
- )
494
- LOGS.append("✅ Latents extracted using official prepare_latents.")
495
- return latents
496
- except Exception as e:
497
- LOGS.append(f"⚠️ Official latent extraction failed: {e}")
498
-
499
- # Try hidden internal attribute
500
  try:
501
- if hasattr(pipe, "_default_latents"):
502
- LOGS.append("⚠️ Using hidden _default_latents.")
503
- return pipe._default_latents
504
- except:
505
- pass
506
-
507
- # Fallback: raw Gaussian tensor
508
- try:
509
- LOGS.append("⚠️ Using raw Gaussian latents fallback.")
510
- return torch.randn(
511
- (1, 4, height // 8, width // 8),
512
- generator=generator,
513
  device=device,
514
- dtype=torch.float32
515
  )
 
 
516
  except Exception as e:
517
- LOGS.append(f"⚠️ Gaussian fallback failed: {e}")
518
-
519
- LOGS.append("❗ Using CPU hard fallback latents.")
520
- return torch.randn((1, 4, height // 8, width // 8))
521
 
522
 
523
  # --------------------------
@@ -706,11 +677,11 @@ def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0
706
  # outputs=[final_image, latent_gallery, logs_box]
707
  # )
708
 
709
- with gr.Blocks(title="Z-Image- experiment - dont run")as demo:
710
 
711
- gr.Markdown("## 🔥 Z-Image Turbo — Output & Logs Viewer")
712
 
713
- with gr.Tabs():
714
  with gr.TabItem("Output"):
715
  with gr.Row():
716
  with gr.Column(scale=1):
@@ -728,16 +699,16 @@ def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0
728
  with gr.TabItem("Logs"):
729
  logs_box = gr.HTML(value=f"<pre'>{LOGS}</pre>", label="Full Logs")
730
 
731
- def run_and_update_logs(prompt, height, width, steps, seed):
732
  img, gallery, logs = generate_image(prompt, height, width, steps, seed)
733
  combined_logs = "\n".join(logs) + "\n\n" + LOGS # append global logs
734
  return img, gallery, f"<pre id='logbox'>{combined_logs}</pre>"
735
 
736
- generate_btn.click(
737
  fn=run_and_update_logs,
738
  inputs=[prompt, height, width, steps, seed],
739
  outputs=[final_image, latent_gallery, logs_box]
740
- )
741
 
742
 
743
 
 
474
  # Helper: Safe latent extractor
475
  # --------------------------
476
  def safe_get_latents(pipe, height, width, generator, device, LOGS):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  try:
478
+ latents = pipe.prepare_latents(
479
+ batch_size=1,
480
+ num_channels=getattr(pipe.unet, "in_channels", 4),
481
+ height=height,
482
+ width=width,
483
+ dtype=torch.float32,
 
 
 
 
 
 
484
  device=device,
485
+ generator=generator
486
  )
487
+ LOGS.append(f"🔹 Latents shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")
488
+ return latents
489
  except Exception as e:
490
+ LOGS.append(f"⚠️ Latent extraction failed: {e}")
491
+ return torch.randn((1, 4, height // 8, width // 8), generator=generator, device=device)
 
 
492
 
493
 
494
  # --------------------------
 
677
  # outputs=[final_image, latent_gallery, logs_box]
678
  # )
679
 
680
+ with gr.Blocks(title="Z-Image- experiment - dont run")as demo:
681
 
682
+ gr.Markdown("## 🔥 Z-Image Turbo — Output & Logs Viewer")
683
 
684
+ with gr.Tabs():
685
  with gr.TabItem("Output"):
686
  with gr.Row():
687
  with gr.Column(scale=1):
 
699
  with gr.TabItem("Logs"):
700
  logs_box = gr.HTML(value=f"<pre'>{LOGS}</pre>", label="Full Logs")
701
 
702
+ def run_and_update_logs(prompt, height, width, steps, seed):
703
  img, gallery, logs = generate_image(prompt, height, width, steps, seed)
704
  combined_logs = "\n".join(logs) + "\n\n" + LOGS # append global logs
705
  return img, gallery, f"<pre id='logbox'>{combined_logs}</pre>"
706
 
707
+ generate_btn.click(
708
  fn=run_and_update_logs,
709
  inputs=[prompt, height, width, steps, seed],
710
  outputs=[final_image, latent_gallery, logs_box]
711
+ )
712
 
713
 
714