Spaces: Running on Zero
| import spaces | |
| import os | |
| import random | |
| import re | |
| import sys | |
| import warnings | |
| import logging | |
| from PIL import Image | |
| from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| from diffusers import ZImagePipeline | |
| from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel | |
# ==================== Environment Variables ==================================
def _env_flag(name, default="true"):
    """Return True when environment variable *name* equals "true", ignoring case.

    An unset variable is treated as *default*.
    """
    return os.environ.get(name, default).lower() == "true"


MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
ENABLE_COMPILE = _env_flag("ENABLE_COMPILE")
ENABLE_WARMUP = _env_flag("ENABLE_WARMUP")
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when unset
# =============================================================================

# Process-wide quieting: disable tokenizers parallelism, silence every Python
# warning, and limit the "transformers" logger to ERROR and above.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
# Supported output sizes grouped by base resolution. Each entry is
# (width, height, aspect-ratio label). Every category lists the same ratios:
# square first, then landscape/portrait pairs.
_RES_TABLE = {
    "1024": [
        (1024, 1024, "1:1"),
        (1152, 896, "9:7"),
        (896, 1152, "7:9"),
        (1152, 864, "4:3"),
        (864, 1152, "3:4"),
        (1248, 832, "3:2"),
        (832, 1248, "2:3"),
        (1280, 720, "16:9"),
        (720, 1280, "9:16"),
        (1344, 576, "21:9"),
        (576, 1344, "9:21"),
    ],
    "1280": [
        (1280, 1280, "1:1"),
        (1440, 1120, "9:7"),
        (1120, 1440, "7:9"),
        (1472, 1104, "4:3"),
        (1104, 1472, "3:4"),
        (1536, 1024, "3:2"),
        (1024, 1536, "2:3"),
        (1536, 864, "16:9"),
        (864, 1536, "9:16"),
        (1680, 720, "21:9"),
        (720, 1680, "9:21"),
    ],
    "1536": [
        (1536, 1536, "1:1"),
        (1728, 1344, "9:7"),
        (1344, 1728, "7:9"),
        (1728, 1296, "4:3"),
        (1296, 1728, "3:4"),
        (1872, 1248, "3:2"),
        (1248, 1872, "2:3"),
        (2048, 1152, "16:9"),
        (1152, 2048, "9:16"),
        (2016, 864, "21:9"),
        (864, 2016, "9:21"),
    ],
}

# Dropdown labels per category, e.g. "1152x896 ( 9:7 )".
RES_CHOICES = {
    category: [f"{w}x{h} ( {ratio} )" for w, h, ratio in sizes]
    for category, sizes in _RES_TABLE.items()
}

# Every label from every category, flattened in category order.
RESOLUTION_SET = [label for labels in RES_CHOICES.values() for label in labels]
# Prompt texts shown under "Example Prompts". English glosses of the Chinese ones:
#   1. A man and his poodle in matching outfits at a dog show, indoor lighting,
#      audience in the background.
#   2. Moody low-key portrait of an elegant Chinese woman in a dark room; a hard
#      light through a flag casts a lightning-bolt shape that lights one eye;
#      high contrast, mysterious, Leica color tones.
#   3. Mid-shot phone mirror selfie of a young East Asian woman with long black
#      hair in a brightly lit elevator, black off-shoulder floral top, dark
#      jeans, head tilted, lips pursed, dark-gray phone covering part of her face.
_EXAMPLE_PROMPT_TEXTS = (
    "一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。",
    "极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。",
    "一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。她穿着一件带有白色花朵图案的黑色露肩短上衣和深色牛仔裤。她的头微微倾斜,嘴唇嘟起做亲吻状,非常可爱俏皮。她右手拿着一部深灰色智能手机,遮住了部分脸,后置摄像头镜头对着镜子",
    "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.",
)

# gr.Examples feeds a single input (the prompt box), so each row holds one value.
EXAMPLE_PROMPTS = [[text] for text in _EXAMPLE_PROMPT_TEXTS]
def get_resolution(resolution, default=(1024, 1024)):
    """Parse the first ``WIDTHxHEIGHT`` pair out of a resolution label.

    The separator may be ``x``, ``X`` or ``×``, with optional whitespace
    around it, so "1024x1024", "1280 × 720" and "1152x896 ( 9:7 )" all parse.

    Args:
        resolution: Label to parse. ``None`` is treated as unparseable.
        default: ``(width, height)`` returned when no pair is found.

    Returns:
        A ``(width, height)`` tuple of ints.
    """
    if resolution is None:
        return tuple(default)
    match = re.search(r"(\d+)\s*[×xX]\s*(\d+)", resolution)
    if match is None:
        return tuple(default)
    return int(match.group(1)), int(match.group(2))
def _pretrained_source(model_path, subfolder, token):
    """Return ``(path, extra_kwargs)`` for loading one pipeline component.

    If *model_path* exists on disk, the component is read from
    ``model_path/subfolder`` with no extra kwargs. Otherwise *model_path* is
    treated as a Hub repo id and the component is addressed with ``subfolder=``
    plus the auth *token*.
    """
    if os.path.exists(model_path):
        return os.path.join(model_path, subfolder), {}
    return model_path, {"subfolder": subfolder, "token": token}


def load_models(model_path, enable_compile=False, attention_backend="native"):
    """Build a ``ZImagePipeline`` with its components in bfloat16 on CUDA.

    Args:
        model_path: Local directory containing ``vae/``, ``text_encoder/``,
            ``tokenizer/`` and ``transformer/`` subfolders, or (when no such
            path exists) a Hub repo id with the same layout.
        enable_compile: When True, set Inductor tuning flags, disable VAE
            tiling and wrap the transformer in ``torch.compile``.
        attention_backend: Name passed to ``transformer.set_attention_backend``.

    Returns:
        The assembled pipeline, moved to ``("cuda", torch.bfloat16)``. Its
        scheduler is left as None; ``generate_image`` installs one per call.
    """
    print(f"Loading models from {model_path}...")
    # ``token=None`` lets huggingface_hub use a stored login when one exists and
    # fall back to anonymous access otherwise. The previous
    # ``use_auth_token=True`` (a deprecated alias of ``token``) demands a stored
    # token and fails when none is configured, even for public repos.
    token = HF_TOKEN or None

    path, extra = _pretrained_source(model_path, "vae", token)
    vae = AutoencoderKL.from_pretrained(path, torch_dtype=torch.bfloat16, device_map="cuda", **extra)

    path, extra = _pretrained_source(model_path, "text_encoder", token)
    text_encoder = AutoModelForCausalLM.from_pretrained(
        path, torch_dtype=torch.bfloat16, device_map="cuda", **extra
    ).eval()

    path, extra = _pretrained_source(model_path, "tokenizer", token)
    tokenizer = AutoTokenizer.from_pretrained(path, **extra)
    # NOTE(review): left padding is presumably what the pipeline's prompt
    # encoding expects — confirm against ZImagePipeline.
    tokenizer.padding_side = "left"

    if enable_compile:
        print("Enabling torch.compile optimizations...")
        torch._inductor.config.conv_1x1_as_mm = True
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.epilogue_fusion = False
        torch._inductor.config.coordinate_descent_check_all_directions = True
        torch._inductor.config.max_autotune_gemm = True
        torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
        torch._inductor.config.triton.cudagraphs = False

    # Scheduler and transformer are filled in later (transformer just below,
    # scheduler per request in generate_image).
    pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
    if enable_compile:
        pipe.vae.disable_tiling()

    path, extra = _pretrained_source(model_path, "transformer", token)
    transformer = ZImageTransformer2DModel.from_pretrained(path, **extra).to("cuda", torch.bfloat16)
    pipe.transformer = transformer
    pipe.transformer.set_attention_backend(attention_backend)

    if enable_compile:
        print("Compiling transformer...")
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)

    pipe.to("cuda", torch.bfloat16)
    return pipe
def generate_image(
    pipe,
    prompt,
    resolution="1024x1024",
    seed=42,
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    max_sequence_length=512,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one text-to-image pass of *pipe* and return the first output image.

    Side effect: replaces ``pipe.scheduler`` with a fresh
    ``FlowMatchEulerDiscreteScheduler`` built from *shift* before sampling.

    Args:
        pipe: The pipeline produced by ``load_models``.
        prompt: Text prompt.
        resolution: Label containing ``WIDTHxHEIGHT``; parsed by ``get_resolution``.
        seed: Seed for a ``torch.Generator`` created on ``"cuda"``.
        guidance_scale: Passed through to the pipeline call.
        num_inference_steps: Passed through to the pipeline call.
        shift: Timestep shift for the new scheduler.
        max_sequence_length: Passed through to the pipeline call.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    width, height = get_resolution(resolution)
    rng = torch.Generator("cuda").manual_seed(seed)
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)

    call_kwargs = dict(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=rng,
        max_sequence_length=max_sequence_length,
    )
    output = pipe(**call_kwargs)
    return output.images[0]
def warmup_model(pipe, resolutions):
    """Run throwaway generations at each resolution before serving requests.

    Each resolution gets three 9-step runs with guidance 0.0 and seeds
    42, 43 and 44. If any run raises, the error is printed, the remaining runs
    for that resolution are skipped, and the loop moves on. Presumably this
    exists so compilation/autotuning for each shape happens up front — confirm.
    """
    print("Starting warmup phase...")
    for res_label in resolutions:
        print(f"Warming up for resolution: {res_label}")
        try:
            for run_seed in range(42, 45):
                generate_image(
                    pipe,
                    prompt="warmup",
                    resolution=res_label,
                    num_inference_steps=9,
                    guidance_scale=0.0,
                    seed=run_seed,
                )
        except Exception as err:
            print(f"Warmup failed for {res_label}: {err}")
    print("Warmup completed.")
# Module-level pipeline shared by the request handlers; None until init_app succeeds.
pipe = None


def init_app():
    """Load the pipeline into the module-level ``pipe`` and optionally warm it up.

    If loading fails, the error and its traceback are printed and ``pipe`` stays
    None; ``generate`` then raises ``gr.Error("Model not loaded.")``. A failure
    during warmup is reported but keeps the already-loaded pipeline.
    """
    global pipe
    try:
        loaded = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
    except Exception as e:
        import traceback

        print(f"Error loading model: {e}")
        traceback.print_exc()
        pipe = None
        return

    pipe = loaded
    print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
    if ENABLE_WARMUP:
        try:
            # RESOLUTION_SET already holds every category's labels in order.
            warmup_model(pipe, list(RESOLUTION_SET))
        except Exception as e:
            print(f"Warmup failed: {e}")
@spaces.GPU
def generate(
    prompt,
    resolution="1024x1024 ( 1:1 )",
    seed=42,
    steps=9,
    shift=3.0,
    random_seed=True,
    gallery_images=None,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate an image using the Z-Image model based on the provided prompt and settings.

    Args:
        prompt: Text description of the image to generate.
        resolution: Resolution label such as "1024x1024 ( 1:1 )"; only the leading "WIDTHxHEIGHT" token is used.
        seed: Seed to use when random_seed is False; -1 (or empty) picks a random seed.
        steps: Step count from the UI; the pipeline runs steps + 1 inference steps.
        shift: Time shift for the flow-matching scheduler.
        random_seed: When True, ignore seed and draw a random one.
        gallery_images: Previously generated images; the new image is prepended to them.

    Returns:
        The updated gallery list, the seed used as a string, and the seed used as an int.
    """
    if pipe is None:
        raise gr.Error("Model not loaded.")

    # Seeds can arrive from the API/MCP as floats or None; normalize to int.
    if random_seed or seed is None or int(seed) == -1:
        new_seed = random.randint(1, 1000000)
    else:
        new_seed = int(seed)

    # Labels look like "1024x1024 ( 1:1 )"; keep only the leading size token.
    if isinstance(resolution, str):
        resolution_str = resolution.split(" ")[0]
    else:
        resolution_str = "1024x1024"

    image = generate_image(
        pipe=pipe,
        prompt=prompt,
        resolution=resolution_str,
        seed=new_seed,
        guidance_scale=0.0,
        num_inference_steps=int(steps + 1),
        shift=shift,
    )
    gallery = [image, *(gallery_images or [])]
    return gallery, str(new_seed), int(new_seed)
init_app()

# ==================== AoTI (Ahead of Time Inductor compilation) ====================
# init_app leaves ``pipe`` as None when loading fails; dereferencing
# ``pipe.transformer`` would then raise AttributeError at import time.
if pipe is not None:
    # Mark ZImageTransformerBlock as the repeated block class, then load AoTI
    # blocks from the "zerogpu-aoti/Z-Image" repo (variant "fa3").
    # NOTE(review): variant "fa3" is fixed even though ATTENTION_BACKEND is
    # configurable; presumably it matches the default "flash_3" — confirm.
    pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
    spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
else:
    print("Skipping AoTI block loading: model not loaded.")
with gr.Blocks(title="Z-Image Demo") as demo:
    # Header. In CommonMark an HTML block such as <div> runs until a blank
    # line, so the blank lines are required for the heading, link and italics
    # inside the div to be rendered as Markdown.
    gr.Markdown(
        """<div align="center">

# Z-Image Generation Demo

[GitHub](https://github.com/Tongyi-MAI/Z-Image)

*An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*

</div>"""
    )
    with gr.Row():
        # Left column: inputs, generate button and example prompts.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
            with gr.Row():
                choices = [int(k) for k in RES_CHOICES]
                res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
                initial_res_choices = RES_CHOICES["1024"]
                # NOTE(review): choices start as every category's labels while the
                # value comes from the 1024 list; presumably so API callers can pass
                # any listed resolution — confirm.
                resolution = gr.Dropdown(
                    value=initial_res_choices[0], choices=RESOLUTION_SET, label="Width x Height (Ratio)"
                )
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                random_seed = gr.Checkbox(label="Random Seed", value=True)
            with gr.Row():
                # Fixed at 8; generate() runs steps + 1 pipeline steps.
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=False)
                shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
            generate_btn = gr.Button("Generate", variant="primary")
            gr.Markdown("### 📝 Example Prompts")
            gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
        # Right column: outputs.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=2,
                rows=2,
                height=600,
                object_fit="contain",
                format="png",
                interactive=False,
            )
            used_seed = gr.Textbox(label="Seed Used", interactive=False)

    def update_res_choices(_res_cat):
        """Narrow the resolution dropdown to the chosen category (default "1024")."""
        if str(_res_cat) in RES_CHOICES:
            res_choices = RES_CHOICES[str(_res_cat)]
        else:
            res_choices = RES_CHOICES["1024"]
        return gr.update(value=res_choices[0], choices=res_choices)

    res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution, api_visibility="private")
    # The gallery is both an input and an output so new images are prepended.
    generate_btn.click(
        generate,
        inputs=[prompt_input, resolution, seed, steps, shift, random_seed, output_gallery],
        outputs=[output_gallery, used_seed, seed],
        api_visibility="public",
    )
# Page-level CSS capping the width of the app's fillable container.
css = "\n.fillable{max-width: 1230px !important}\n"

if __name__ == "__main__":
    # Launch with the custom CSS and the MCP server enabled.
    launch_options = {"css": css, "mcp_server": True}
    demo.launch(**launch_options)