import spaces  # keep first: ZeroGPU expects `spaces` before CUDA-related imports
import json
import random
import re
import torch
import gradio as gr
from diffusers import ZImagePipeline

# ==================== Configuration ====================
MODEL_PATH = "Tongyi-MAI/Z-Image"
MAX_SEED = 1_000_000  # inclusive upper bound for randomly drawn seeds

# ==================== Model Loading (Global Context) ====================
print(f"Loading Z-Image pipeline from {MODEL_PATH}...")
pipe = ZImagePipeline.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
)
pipe.to("cuda")
print("Pipeline loaded successfully!")

# Ahead-of-time compiled transformer blocks; currently left disabled.
# pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
# spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")


# ==================== Generation Function ====================
def _resolve_seed(seed, random_seed: bool) -> int:
    """Return the integer seed to use for one generation.

    A fresh seed in [1, MAX_SEED] is drawn when ``random_seed`` is set, when
    ``seed`` is missing (``None``, e.g. a cleared number field) or when it is
    the sentinel ``-1``. Otherwise ``seed`` is used, coerced to ``int`` so that
    ``torch.Generator.manual_seed`` accepts it.
    """
    if random_seed or seed is None or seed == -1:
        return random.randint(1, MAX_SEED)
    return int(seed)


@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    width: int = 1024,
    height: int = 1024,
    seed: int = 42,
    num_inference_steps: int = 50,
    guidance_scale: float = 4.0,
    cfg_normalization: bool = False,
    random_seed: bool = True,
    gallery_images: list | None = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with Z-Image and prepend it to the gallery.

    Args:
        prompt: Text description of the image to generate. Must not be blank.
        negative_prompt: Optional text describing what to avoid. Blank means none.
        width: Output image width in pixels.
        height: Output image height in pixels.
        seed: Seed used when random_seed is False (-1 also means "pick randomly").
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        cfg_normalization: Whether to enable the pipeline's CFG normalization.
        random_seed: If True, ignore `seed` and draw a random one.
        gallery_images: Images already in the gallery; the new image is put first.

    Returns:
        A tuple of (updated gallery list, seed actually used).

    Raises:
        gr.Error: If the prompt is empty or whitespace only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    print("prompt: ", prompt)

    new_seed = _resolve_seed(seed, random_seed)

    # A missing or all-whitespace negative prompt is passed as None so the
    # pipeline receives no negative prompt at all.
    negative = negative_prompt if negative_prompt and negative_prompt.strip() else None

    generator = torch.Generator("cuda").manual_seed(new_seed)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative,
        # Slider/API values may arrive as floats; the pipeline needs integer sizes/steps.
        height=int(height),
        width=int(width),
        cfg_normalization=cfg_normalization,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    # Newest image first; build a new list instead of mutating the caller's.
    return [image, *(gallery_images or [])], new_seed


def read_file(path: str) -> str:
    """Return the full text content of the UTF-8 file at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


# ==================== Gradio Interface ====================
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
h3 {
    text-align: center;
    display: block;
}
"""

with open("examples/0_examples.json", "r", encoding="utf-8") as file:
    examples = json.load(file)

# Created outside the Blocks context and placed later via .render().
output_gallery = gr.Gallery(
    label="Generated Images",
    columns=2,
    rows=2,
    height=600,
    object_fit="contain",
    format="png",
    interactive=False,
)

with gr.Blocks(title="Z-Image Demo") as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row():
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    placeholder="Enter your prompt here...",
                )
                negative_prompt_input = gr.Textbox(
                    label="Negative Prompt (optional)",
                    lines=2,
                    placeholder="Enter what you want to avoid...",
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=512,
                        maximum=2048,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=512,
                        maximum=2048,
                        step=32,
                        value=1024,
                    )
                with gr.Row():
                    seed = gr.Number(label="Seed", value=42, precision=0)
                    random_seed = gr.Checkbox(label="Random Seed", value=True)
                with gr.Row():
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=12,
                        maximum=50,
                        value=28,
                        step=1,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale (CFG)",
                        minimum=1.0,
                        maximum=10.0,
                        value=4.0,
                        step=0.1,
                    )
                cfg_normalization = gr.Checkbox(label="CFG Normalization", value=False)
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=1):
                output_gallery.render()
        gr.Examples(examples=examples, inputs=prompt_input)
        gr.Markdown(read_file("static/footer.md"))

    # The gallery is both an input (existing images) and an output (updated list);
    # the seed box is overwritten with the seed actually used.
    generate_btn.click(
        generate,
        inputs=[
            prompt_input,
            negative_prompt_input,
            width,
            height,
            seed,
            num_inference_steps,
            guidance_scale,
            cfg_normalization,
            random_seed,
            output_gallery,
        ],
        outputs=[output_gallery, seed],
        api_name="generate",
    )

# ==================== Launch ====================
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        css=css,
    )