import spaces import os import random import re import sys import warnings import logging from PIL import Image from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer sys.path.append(os.path.dirname(os.path.abspath(__file__))) from diffusers import ZImagePipeline from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel # ==================== Environment Variables ================================== MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo") ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true" ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true" ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3") HF_TOKEN = os.environ.get("HF_TOKEN") # ============================================================================= os.environ["TOKENIZERS_PARALLELISM"] = "false" warnings.filterwarnings("ignore") logging.getLogger("transformers").setLevel(logging.ERROR) RES_CHOICES = { "1024": [ "1024x1024 ( 1:1 )", "1152x896 ( 9:7 )", "896x1152 ( 7:9 )", "1152x864 ( 4:3 )", "864x1152 ( 3:4 )", "1248x832 ( 3:2 )", "832x1248 ( 2:3 )", "1280x720 ( 16:9 )", "720x1280 ( 9:16 )", "1344x576 ( 21:9 )", "576x1344 ( 9:21 )", ], "1280": [ "1280x1280 ( 1:1 )", "1440x1120 ( 9:7 )", "1120x1440 ( 7:9 )", "1472x1104 ( 4:3 )", "1104x1472 ( 3:4 )", "1536x1024 ( 3:2 )", "1024x1536 ( 2:3 )", "1536x864 ( 16:9 )", "864x1536 ( 9:16 )", "1680x720 ( 21:9 )", "720x1680 ( 9:21 )", ], "1536": [ "1536x1536 ( 1:1 )", "1728x1344 ( 9:7 )", "1344x1728 ( 7:9 )", "1728x1296 ( 4:3 )", "1296x1728 ( 3:4 )", "1872x1248 ( 3:2 )", "1248x1872 ( 2:3 )", "2048x1152 ( 16:9 )", "1152x2048 ( 9:16 )", "2016x864 ( 21:9 )", "864x2016 ( 9:21 )", ], } RESOLUTION_SET = [] for resolutions in RES_CHOICES.values(): RESOLUTION_SET.extend(resolutions) EXAMPLE_PROMPTS = [ ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"], [ "极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。" ], [ "一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。她穿着一件带有白色花朵图案的黑色露肩短上衣和深色牛仔裤。她的头微微倾斜,嘴唇嘟起做亲吻状,非常可爱俏皮。她右手拿着一部深灰色智能手机,遮住了部分脸,后置摄像头镜头对着镜子" ], [ "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." ], ] def get_resolution(resolution): match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution) if match: return int(match.group(1)), int(match.group(2)) return 1024, 1024 def load_models(model_path, enable_compile=False, attention_backend="native"): print(f"Loading models from {model_path}...") use_auth_token = HF_TOKEN if HF_TOKEN else True if not os.path.exists(model_path): vae = AutoencoderKL.from_pretrained( f"{model_path}", subfolder="vae", torch_dtype=torch.bfloat16, device_map="cuda", use_auth_token=use_auth_token, ) text_encoder = AutoModelForCausalLM.from_pretrained( f"{model_path}", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda", use_auth_token=use_auth_token, ).eval() tokenizer = AutoTokenizer.from_pretrained(f"{model_path}", subfolder="tokenizer", use_auth_token=use_auth_token) else: vae = AutoencoderKL.from_pretrained( os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda" ) text_encoder = AutoModelForCausalLM.from_pretrained( os.path.join(model_path, "text_encoder"), torch_dtype=torch.bfloat16, device_map="cuda", ).eval() tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer")) tokenizer.padding_side = "left" if enable_compile: print("Enabling torch.compile optimizations...") torch._inductor.config.conv_1x1_as_mm = True torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.epilogue_fusion = False torch._inductor.config.coordinate_descent_check_all_directions = True torch._inductor.config.max_autotune_gemm = True torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN" torch._inductor.config.triton.cudagraphs = False pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None) if enable_compile: pipe.vae.disable_tiling() if not os.path.exists(model_path): transformer = ZImageTransformer2DModel.from_pretrained( f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token ).to("cuda", torch.bfloat16) else: transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to( "cuda", torch.bfloat16 ) pipe.transformer = transformer pipe.transformer.set_attention_backend(attention_backend) if enable_compile: print("Compiling transformer...") pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False) pipe.to("cuda", torch.bfloat16) return pipe def generate_image( pipe, prompt, resolution="1024x1024", seed=42, guidance_scale=5.0, num_inference_steps=50, shift=3.0, max_sequence_length=512, progress=gr.Progress(track_tqdm=True), ): width, height = get_resolution(resolution) generator = torch.Generator("cuda").manual_seed(seed) scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift) pipe.scheduler = scheduler image = pipe( prompt=prompt, height=height, width=width, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, max_sequence_length=max_sequence_length, ).images[0] return image def warmup_model(pipe, resolutions): print("Starting warmup phase...") dummy_prompt = "warmup" for res_str in resolutions: print(f"Warming up for resolution: {res_str}") try: for i in range(3): generate_image( pipe, prompt=dummy_prompt, resolution=res_str, num_inference_steps=9, guidance_scale=0.0, seed=42 + i, ) except Exception as e: print(f"Warmup failed for {res_str}: {e}") print("Warmup completed.") pipe = None def init_app(): global pipe try: pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND) print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}") if ENABLE_WARMUP: all_resolutions = [] for cat in RES_CHOICES.values(): all_resolutions.extend(cat) warmup_model(pipe, all_resolutions) except Exception as e: print(f"Error loading model: {e}") pipe = None @spaces.GPU def generate( prompt, resolution="1024x1024 ( 1:1 )", seed=42, steps=9, shift=3.0, random_seed=True, gallery_images=None, progress=gr.Progress(track_tqdm=True), ): """ Generate an image using the Z-Image model based on the provided prompt and settings. """ if random_seed: new_seed = random.randint(1, 1000000) else: new_seed = seed if seed != -1 else random.randint(1, 1000000) if pipe is None: raise gr.Error("Model not loaded.") try: resolution_str = resolution.split(" ")[0] except: resolution_str = "1024x1024" image = generate_image( pipe=pipe, prompt=prompt, resolution=resolution_str, seed=new_seed, guidance_scale=0.0, num_inference_steps=int(steps + 1), shift=shift, ) if gallery_images is None: gallery_images = [] gallery_images = [image] + gallery_images return gallery_images, str(new_seed), int(new_seed) init_app() # ==================== AoTI (Ahead of Time Inductor compilation) ==================== pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"] spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3") with gr.Blocks(title="Z-Image Demo") as demo: gr.Markdown( """