import torch
import spaces
import gradio as gr
import time
import re
import random
import os
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
import warnings

# Silence library warnings.
warnings.filterwarnings("ignore")

# ==================== 1. Resolution configuration ====================
# Preset "WIDTHxHEIGHT (ratio)" labels, grouped by base resolution.
RES_CHOICES = {
    "1024": [
        "720x1280 (9:16)",
        "1024x1024 (1:1)",
        "1152x896 (9:7)",
        "896x1152 (7:9)",
        "1152x864 (4:3)",
        "864x1152 (3:4)",
        "1248x832 (3:2)",
        "832x1248 (2:3)",
        "1280x720 (16:9)",
        "1344x576 (21:9)",
        "576x1344 (9:21)",
    ],
    "1280": [
        "864x1536 (9:16)",
        "1280x1280 (1:1)",
        "1440x1120 (9:7)",
        "1120x1440 (7:9)",
        "1472x1104 (4:3)",
        "1104x1472 (3:4)",
        "1536x1024 (3:2)",
        "1024x1536 (2:3)",
        "1536x864 (16:9)",
        "1680x720 (21:9)",
        "720x1680 (9:21)",
    ],
    "1536": [
        "1152x2048 (9:16)",
        "1536x1536 (1:1)",
        "1728x1344 (9:7)",
        "1344x1728 (7:9)",
        "1728x1296 (4:3)",
        "1296x1728 (3:4)",
        "1872x1248 (3:2)",
        "1248x1872 (2:3)",
        "2048x1152 (16:9)",
        "2016x864 (21:9)",
        "864x2016 (9:21)",
    ],
}

# Matches "<width> x <height>"; accepts "x", "X" or "×" as the separator.
_RESOLUTION_RE = re.compile(r"(\d+)\s*[×xX]\s*(\d+)")
_DEFAULT_RESOLUTION = (1024, 1024)


def get_resolution(resolution_str):
    """Extract ``(width, height)`` from a label such as ``"720x1280 (9:16)"``.

    Both values are rounded down to a multiple of 8.

    Args:
        resolution_str: Label containing ``<width>x<height>``; may be empty
            or ``None``.

    Returns:
        A ``(width, height)`` tuple of ints. ``(1024, 1024)`` is returned when
        the input is empty, contains no size, or a dimension would round down
        to zero.
    """
    if not resolution_str:
        return _DEFAULT_RESOLUTION

    match = _RESOLUTION_RE.search(resolution_str)
    if match is None:
        return _DEFAULT_RESOLUTION

    width, height = (int(group) for group in match.groups())
    width -= width % 8
    height -= height % 8
    # A dimension below 8 rounds down to 0, which is not a usable image size.
    if width == 0 or height == 0:
        return _DEFAULT_RESOLUTION
    return width, height

# ==================== 2. Model loading and core optimizations ====================
print("🚀 Loading Z-Image-Turbo pipeline...")

# trust_remote_code must be True so that Z-Image's custom Pipeline and
# Transformer classes are loaded; otherwise set_attention_backend cannot be
# called (per the original author's note).
pipe = DiffusionPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    trust_remote_code=True,
)

# Use FlowMatchEulerDiscreteScheduler with shift=3.0.
try:
    scheduler_config = dict(pipe.scheduler.config)
    scheduler_config.pop("algorithm_type", None)
    pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
        scheduler_config, shift=3.0
    )
    print("✅ Scheduler optimized with shift=3.0")
except Exception as e:
    # Best effort: keep the pipeline's own scheduler if the swap fails.
    print(f"⚠️ Scheduler config warning: {e}")

# Move to GPU.
pipe.to("cuda")

print("Enabling torch.compile optimizations...")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.max_autotune_gemm = True
torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
torch._inductor.config.triton.cudagraphs = False


def enable_best_attention_backend(pipeline):
    """Activate the first attention backend that the transformer accepts.

    Backends are tried in the listed order by calling
    ``pipeline.transformer.set_attention_backend(name)``; the first call that
    does not raise wins. If the method is missing or every backend raises, a
    warning is printed and
    ``pipeline.enable_xformers_memory_efficient_attention()`` is attempted as
    a fallback.

    Args:
        pipeline: Pipeline object exposing a ``transformer`` attribute.

    Returns:
        None. Progress is reported via ``print``.
    """
    # Ordered by the original author's preference (fastest first).
    backends = [
        # ===== S Tier: current best =====
        "flash_varlen",  # FA v2 varlen, stable + high performance
        "_flash_3_varlen_hub",  # FA v3 varlen (hub), very strong on SM90
        "_flash_varlen_3",  # FA v3 varlen (local)
        "_flash_3",  # FA v3 non-varlen
        "flash",  # FA v2 non-varlen
        # ===== A Tier: acceptable / high-performance alternatives =====
        "flash_varlen_hub",
        "flash_hub",
        "xformers",  # mature, slightly slower than FA
        "_native_flash",
        # ===== B Tier: framework-native / compatibility first =====
        "native",
        "_native_efficient",
        "_native_cudnn",
        # ===== C Tier: backend-specific / limited scenarios =====
        "flex",
        "_native_xla",
        "_native_npu",
        "aiter",
        # ===== D Tier: Sage / experimental quantized implementations =====
        "sage",
        "sage_hub",
        "sage_varlen",
        "_sage_qk_int8_pv_fp16_cuda",
        "_sage_qk_int8_pv_fp16_triton",
        "_sage_qk_int8_pv_fp8_cuda",
        "_sage_qk_int8_pv_fp8_cuda_sm90",
        # ===== Fallback =====
        "_native_math",
    ]

    # set_attention_backend is specific to the custom Z-Image classes, so it
    # may be absent when the remote code was not loaded.
    transformer = getattr(pipeline, "transformer", None)
    set_backend = getattr(transformer, "set_attention_backend", None)

    last_error = None
    if callable(set_backend):
        for backend in backends:
            try:
                set_backend(backend)
            except Exception as e:
                # Expected for backends that are not installed / supported.
                last_error = e
                continue
            print(f"✅ Attention backend set to: {backend}")
            return

    print("⚠️ Warning: Transformer model does not support 'set_attention_backend'. Custom code might not be loaded.")
    if last_error is not None:
        print(f"⚠️ Last attention backend error: {last_error}")

    # If nothing could be set, try the standard xFormers path.
    try:
        pipeline.enable_xformers_memory_efficient_attention()
        print("✅ Standard xFormers enabled as fallback")
    except Exception as e:
        print(f"⚠️ xFormers fallback unavailable: {e}")


# Apply the backend selection.
enable_best_attention_backend(pipe)

# VAE memory optimization (best effort).
try:
    pipe.vae.enable_slicing()
except Exception as e:
    print(f"⚠️ VAE slicing not enabled: {e}")

# Optional transformer compilation, left disabled by the original author:
# print("Compiling transformer...")
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)


# ==================== 3. Generation logic ====================
@spaces.GPU
def generate_image(
    prompt,
    resolution_choice,
    use_custom_res,
    custom_width,
    custom_height,
    num_inference_steps,
    seed,
    randomize_seed,
    negative_prompt,
    gallery_history,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image and prepend it to the gallery history.

    Args:
        prompt: Text prompt; must contain at least 2 non-space characters.
        resolution_choice: Preset label parsed by ``get_resolution``.
        use_custom_res: When true, ``custom_width``/``custom_height`` are used
            instead of the preset.
        custom_width: Custom width, rounded down to a multiple of 8.
        custom_height: Custom height, rounded down to a multiple of 8.
        num_inference_steps: Number of sampling steps (cast to int).
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true (or when ``seed`` is None) a random 32-bit
            seed is drawn.
        negative_prompt: Optional negative prompt; blank becomes ``None``.
        gallery_history: Existing gallery items, or ``None``.
        progress: Gradio progress tracker (tracks tqdm).

    Returns:
        ``(history, seed)`` where ``history`` is a new list with the fresh
        ``(image, label)`` entry first and ``seed`` is the seed actually used.

    Raises:
        gr.Error: If the prompt is too short or generation fails.
    """
    # Work on a copy so the caller's list is never mutated.
    history = list(gallery_history) if gallery_history else []

    # Validate outside the try block so this message is not re-wrapped as a
    # generation error.
    if not prompt or len(prompt.strip()) < 2:
        raise gr.Error("请输入提示词 (Prompt)")

    try:
        prompt = prompt.strip()
        neg_prompt = negative_prompt.strip() if negative_prompt else None

        if use_custom_res:
            width = int(custom_width) - int(custom_width) % 8
            height = int(custom_height) - int(custom_height) % 8
        else:
            width, height = get_resolution(resolution_choice)

        # An empty seed field arrives as None; treat it like "randomize".
        if randomize_seed or seed is None:
            seed = random.randint(0, 2**32 - 1)
        seed = int(seed)
        steps = int(num_inference_steps)

        start_time = time.time()
        generator = torch.Generator("cuda").manual_seed(seed)

        # Free cached GPU memory before the run.
        torch.cuda.empty_cache()

        # torch.autocast("cuda", ...) replaces the deprecated
        # torch.cuda.amp.autocast(...).
        with torch.autocast("cuda", dtype=torch.bfloat16):
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                guidance_scale=0.0,
                generator=generator,
                negative_prompt=neg_prompt,
                max_sequence_length=512,
            ).images[0]

        gen_time = time.time() - start_time

        # Format the history entry; newest image goes first.
        info_label = f"{width}x{height} | Steps: {steps} | Seed: {seed} | {gen_time:.2f}s"
        history.insert(0, (image, info_label))

        return history, seed

    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"生成错误: {str(e)}") from e


# ==================== 4. UI styles ====================
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }
.header-container { text-align: center; margin-bottom: 20px; }
.header-title { font-size: 2.5rem; font-weight: 800; margin: 0; background: linear-gradient(135deg, #f59e0b, #ea580c); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
.header-subtitle { font-size: 1rem; color: #6b7280; font-weight: 500; }
.primary-btn { background: linear-gradient(90deg, #f59e0b 0%, #d97706 100%) !important; border: none !important; color: white !important; font-weight: 600 !important; font-size: 1.1rem !important; box-shadow: 0 4px 6px -1px rgba(245, 158, 11, 0.2) !important; }
.primary-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 15px -3px rgba(245, 158, 11, 0.3) !important; }
.panel-container { background: #ffffff; border: 1px solid #e5e7eb; border-radius: 12px; padding: 15px; }
.dark .panel-container { background: #1f2937; border-color: #374151; }
"""

# ==================== 5. Gradio interface ====================
# NOTE(review): the container nesting below was reconstructed from a
# whitespace-collapsed source — confirm the layout against the running app.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css=css, title="Z-Image-Turbo") as demo:
    # NOTE(review): the header markup had lost its tags; they are rebuilt here
    # from the .header-* classes defined in `css` — confirm.
    gr.HTML("""
    <div class="header-container">
        <h1 class="header-title">⚡ Z-Image-Turbo</h1>
        <p class="header-subtitle">Optimized Backend • 8 Steps • Gallery History</p>
    </div>
    """)

    with gr.Row():
        # --- Control panel ---
        with gr.Column(scale=4, min_width=320):
            with gr.Group(elem_classes="panel-container"):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Low quality, blurry...",
                    lines=1
                )
                generate_btn = gr.Button("🚀 Generate", elem_classes="primary-btn")

            with gr.Group(elem_classes="panel-container"):
                gr.Markdown("### 📐 Resolution")
                res_category = gr.Radio(
                    choices=["1024", "1280", "1536"],
                    value="1024",
                    label="Resolution Base",
                    container=False
                )
                resolution_dropdown = gr.Dropdown(
                    choices=RES_CHOICES["1024"],
                    value=RES_CHOICES["1024"][0],
                    label="Select Ratio",
                    show_label=False
                )

                with gr.Accordion("Custom Size", open=False):
                    use_custom_res = gr.Checkbox(label="Enable Custom", value=False)
                    with gr.Row(visible=False) as custom_res_row:
                        width_slider = gr.Slider(512, 1536, value=1024, step=64, label="W")
                        height_slider = gr.Slider(512, 1536, value=1024, step=64, label="H")

            with gr.Accordion("⚙️ Settings", open=False):
                with gr.Group(elem_classes="panel-container"):
                    steps_slider = gr.Slider(4, 20, value=8, step=1, label="Steps")
                    with gr.Row():
                        random_seed = gr.Checkbox(label="Random Seed", value=True)
                        seed_input = gr.Number(label="Seed", value=42, visible=False, precision=0)

        # --- Gallery ---
        with gr.Column(scale=6, min_width=500):
            output_gallery = gr.Gallery(
                label="History",
                value=[],
                columns=[2],
                rows=[2],
                object_fit="contain",
                height="auto",
                show_share_button=True,
                show_download_button=True,
                interactive=False
            )
            with gr.Row():
                last_seed_display = gr.Textbox(label="Last Seed", interactive=False, scale=3)
                clear_btn = gr.Button("🗑️ Clear", scale=1, variant="secondary")

    # Interaction logic
    def update_resolution_list(category):
        """Return a Dropdown update listing the presets of ``category``."""
        return gr.Dropdown(choices=RES_CHOICES[category], value=RES_CHOICES[category][0])

    res_category.change(update_resolution_list, inputs=res_category, outputs=resolution_dropdown)

    # Show the custom-size sliders and lock the preset dropdown while custom
    # resolution is enabled.
    use_custom_res.change(
        lambda x: (gr.Row(visible=x), gr.Dropdown(interactive=not x)),
        inputs=use_custom_res,
        outputs=[custom_res_row, resolution_dropdown]
    )

    # The manual seed field is only visible when random seeding is off.
    random_seed.change(lambda x: gr.Number(visible=not x), inputs=random_seed, outputs=seed_input)

    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt,
            resolution_dropdown,
            use_custom_res,
            width_slider,
            height_slider,
            steps_slider,
            seed_input,
            random_seed,
            negative_prompt,
            output_gallery,
        ],
        outputs=[output_gallery, last_seed_display]
    )

    clear_btn.click(lambda: ([], ""), outputs=[output_gallery, last_seed_display])

if __name__ == "__main__":
    demo.launch()