# Z-Image-Turbo demo — app.py (Hugging Face Space by rahul7star)
import torch
import spaces
import gradio as gr
from diffusers import DiffusionPipeline
import diffusers
import io
import sys
import logging
# ------------------------
# GLOBAL LOG BUFFER
# ------------------------
log_buffer = io.StringIO()


def log(msg):
    """Echo *msg* to stdout and append it (newline-terminated) to the shared buffer."""
    print(msg)
    log_buffer.write(f"{msg}\n")
# Enable diffusers debug logs
# Raise the diffusers library's verbosity to INFO so model-loading and
# scheduler details show up in the captured log output.
diffusers.utils.logging.set_verbosity_info()
log("Loading Z-Image-Turbo pipeline...")
# Load the Z-Image-Turbo weights once at import time so the first request
# does not pay the model-download/initialization cost.
pipe = DiffusionPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,  # bf16 halves memory vs fp32 on Ampere+ GPUs
    low_cpu_mem_usage=False,
    # Custom flash-attention-3 kernel backend shipped via the HF kernels hub.
    attn_implementation="kernels-community/vllm-flash-attn3",
)
pipe.to("cuda")
# NOTE(review): AOTI block-compilation path left disabled by the author.
#pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"] #spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
# ------------------------
# ATTENTION + PIPE INFO
# ------------------------
def pipeline_debug_info(pipe=None):
    """Collect human-readable attention/processor diagnostics for a pipeline.

    Parameters
    ----------
    pipe : DiffusionPipeline, optional
        Pipeline to inspect.  Defaults to the module-level ``pipe`` so the
        function also works when called with no argument (the UI's request
        handler historically called it that way).

    Returns
    -------
    str
        Newline-joined diagnostic lines; lookup failures are reported inline
        rather than raised.
    """
    if pipe is None:
        # Backward-compatible fallback to the globally loaded pipeline.
        pipe = globals()["pipe"]
    info = []
    # Which attention backend the transformer was loaded with, if recorded.
    try:
        info.append(f"Transformer attention backend: {pipe.transformer.config.attn_implementation}")
    except AttributeError:
        # Narrowed from a bare `except:` — only a missing attribute is expected here.
        info.append("No transformer.attn_implementation")
    # Attention-processor class of the first transformer block, if exposed.
    try:
        proc = pipe.transformer.blocks[0].attn.processor
        info.append(f"Processor type: {type(proc)}")
    except Exception as e:
        info.append(f"Processor error: {e}")
    return "\n".join(info)
# ------------------------
# IMAGE GENERATOR
# ------------------------
@spaces.GPU
def generate_image(prompt, height, width, num_inference_steps, seed, randomize_seed, num_images):
    """Run Z-Image-Turbo for one request and capture a full debug log.

    Parameters mirror the Gradio inputs: prompt text, output height/width in
    pixels, denoising step count, RNG seed, whether to draw a fresh random
    seed, and how many images to produce (clamped to 1..3).

    Returns
    -------
    tuple
        (list of PIL images, seed actually used, captured log text) — the
        three Gradio outputs wired to this handler.
    """
    # Reset the shared log buffer so each request starts with a clean log.
    log_buffer.seek(0)
    log_buffer.truncate(0)
    log("=== NEW GENERATION REQUEST ===")
    log(f"Prompt: {prompt}")
    log(f"Height: {height}, Width: {width}")
    log(f"Inference Steps: {num_inference_steps}")
    log(f"Num Images: {num_images}")
    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
        log(f"Randomized Seed → {seed}")
    else:
        log(f"Seed: {seed}")
    # Clamp the image count to the range the UI supports.
    num_images = min(max(1, int(num_images)), 3)
    # BUGFIX: the original called pipeline_debug_info() with no argument,
    # which raised TypeError on every request — pass the pipeline explicitly.
    log(pipeline_debug_info(pipe))
    generator = torch.Generator("cuda").manual_seed(int(seed))
    log("Running pipeline forward()...")
    result = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=0.0,  # Turbo models are distilled for CFG-free sampling
        generator=generator,
        max_sequence_length=1024,
        num_images_per_prompt=num_images,
        output_type="pil",
    )
    # Tensor diagnostics (shapes only).  Z-Image is transformer-based, so a
    # `unet` attribute may not exist at all — treat that as non-fatal but
    # don't hide unrelated errors behind a bare except.
    try:
        latent_shape = pipe.unet.config.sample_size
        log(f"UNet latent resolution (approx): {latent_shape}")
    except AttributeError:
        pass
    log("Pipeline finished.")
    log("Returning images...")
    return result.images, seed, log_buffer.getvalue()
# ------------------------
# GRADIO UI
# ------------------------
_example_prompts = (
    "Young Chinese woman in red Hanfu, intricate embroidery...",
    "A majestic dragon soaring through clouds at sunset...",
    "Cozy coffee shop interior, warm lighting, rain on windows...",
    "Astronaut riding a horse on Mars, cinematic lighting...",
    "Portrait of a wise old wizard...",
)
# gr.Examples expects one row (list) per example; each row has a single prompt cell.
examples = [[p] for p in _example_prompts]
# Declarative Gradio layout: controls on the left, outputs on the right.
# NOTE(review): indentation reconstructed from a whitespace-stripped paste —
# nesting of the inner Rows is the most plausible reading; confirm against
# the deployed Space.
with gr.Blocks(title="Z-Image-Turbo Debug Demo") as demo:
    gr.Markdown("# 🎨 Z-Image-Turbo — Multi Image + Full Debug Logs")
    with gr.Row():
        # Left column: all generation controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", lines=4)
            with gr.Row():
                height = gr.Slider(512, 2048, 1024, step=64, label="Height")
                width = gr.Slider(512, 2048, 1024, step=64, label="Width")
            num_images = gr.Slider(1, 3, 2, step=1, label="Number of Images")
            num_inference_steps = gr.Slider(
                1, 20, 9, step=1, label="Inference Steps",
                info="9 steps = 8 DiT forward passes",
            )
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
            generate_btn = gr.Button("🚀 Generate", variant="primary")
        # Right column: outputs — gallery, seed echo, and the captured log.
        with gr.Column(scale=1):
            output_images = gr.Gallery(label="Generated Images")
            used_seed = gr.Number(label="Seed Used", interactive=False)
            debug_log = gr.Textbox(
                label="Debug Log Output",
                lines=25,
                interactive=False
            )
    gr.Examples(examples=examples, inputs=[prompt], cache_examples=False)
    # Wire the button to the handler; outputs map 1:1 to its return tuple.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed, num_images],
        outputs=[output_images, used_seed, debug_log],
    )

if __name__ == "__main__":
    demo.launch()