File size: 12,496 Bytes
5ca66c3
5902e2a
 
 
dd08023
205b0fd
00d4335
205b0fd
ddb8055
 
205b0fd
ddb8055
5ca66c3
e487bbb
dd08023
 
0591cce
dd08023
 
 
 
 
 
 
 
 
 
 
 
0591cce
dd08023
 
 
 
 
 
 
 
 
 
 
 
0591cce
dd08023
 
 
 
 
 
 
 
 
 
 
 
 
 
00d4335
 
 
dd08023
 
 
 
205b0fd
dd08023
 
e487bbb
00d4335
5ca66c3
e487bbb
 
5902e2a
00d4335
5902e2a
 
 
e487bbb
5902e2a
5ca66c3
e487bbb
 
 
 
 
 
 
 
 
 
 
205b0fd
00d4335
5902e2a
5ca66c3
7135cb4
 
 
 
 
 
 
 
 
 
e487bbb
 
7135cb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e487bbb
 
7135cb4
 
c2a4bc2
7135cb4
 
 
d8ef4d7
c2a4bc2
7135cb4
 
 
e487bbb
 
 
 
 
 
 
 
 
 
5ca66c3
e487bbb
ddb8055
205b0fd
ddb8055
 
5ca66c3
babc3bc
 
7135cb4
00d4335
5902e2a
 
 
dd08023
205b0fd
00d4335
 
205b0fd
 
 
00d4335
e487bbb
5902e2a
 
205b0fd
 
 
5902e2a
 
00d4335
5902e2a
 
00d4335
5902e2a
00d4335
 
 
dd08023
 
5902e2a
 
205b0fd
5902e2a
 
00d4335
5902e2a
 
e487bbb
00d4335
 
ddb8055
5902e2a
 
 
 
 
e487bbb
5902e2a
00d4335
 
5902e2a
205b0fd
00d4335
5902e2a
e487bbb
 
00d4335
5902e2a
00d4335
5902e2a
 
e487bbb
205b0fd
e487bbb
205b0fd
00d4335
 
 
 
 
 
e487bbb
00d4335
 
 
 
 
e487bbb
205b0fd
 
00d4335
 
e487bbb
00d4335
e487bbb
00d4335
 
e487bbb
9620421
00d4335
9620421
 
e487bbb
 
5902e2a
00d4335
 
 
e487bbb
5902e2a
00d4335
5902e2a
205b0fd
e487bbb
00d4335
 
 
e487bbb
 
 
00d4335
 
e487bbb
 
 
00d4335
e487bbb
205b0fd
00d4335
e487bbb
 
 
 
 
 
 
00d4335
 
 
e487bbb
 
00d4335
 
e487bbb
 
00d4335
e487bbb
 
205b0fd
e487bbb
00d4335
e487bbb
00d4335
e487bbb
 
205b0fd
e487bbb
205b0fd
 
e487bbb
00d4335
205b0fd
 
 
00d4335
5902e2a
205b0fd
 
ddb8055
00d4335
e487bbb
 
205b0fd
e487bbb
00d4335
e487bbb
00d4335
e487bbb
dd08023
00d4335
e487bbb
 
dd08023
 
e487bbb
dd08023
5902e2a
 
e487bbb
00d4335
5902e2a
 
e487bbb
5ca66c3
 
205b0fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import torch
import spaces
import gradio as gr
import time
import re
import random
import os
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
import warnings

# Silence library warnings to keep startup logs readable
warnings.filterwarnings("ignore")

# ==================== 1. Resolution presets ====================
# Preset "WIDTHxHEIGHT (ratio)" labels grouped by base resolution.
# The labels are shown in the UI dropdown and parsed by get_resolution().
RES_CHOICES = {
    "1024": [
        "720x1280 (9:16)",
        "1024x1024 (1:1)",
        "1152x896 (9:7)",
        "896x1152 (7:9)",
        "1152x864 (4:3)",
        "864x1152 (3:4)",
        "1248x832 (3:2)",
        "832x1248 (2:3)",
        "1280x720 (16:9)",
        "1344x576 (21:9)",
        "576x1344 (9:21)",
    ],
    "1280": [
        "864x1536 (9:16)",
        "1280x1280 (1:1)",
        "1440x1120 (9:7)",
        "1120x1440 (7:9)",
        "1472x1104 (4:3)",
        "1104x1472 (3:4)",
        "1536x1024 (3:2)",
        "1024x1536 (2:3)",
        "1536x864 (16:9)",
        "1680x720 (21:9)",
        "720x1680 (9:21)",
    ],
    "1536": [
        "1152x2048 (9:16)",
        "1536x1536 (1:1)",
        "1728x1344 (9:7)",
        "1344x1728 (7:9)",
        "1728x1296 (4:3)",
        "1296x1728 (3:4)",
        "1872x1248 (3:2)",
        "1248x1872 (2:3)",
        "2048x1152 (16:9)",
        "2016x864 (21:9)",
        "864x2016 (9:21)",
    ],
}

def get_resolution(resolution_str):
    """Parse a "WIDTHxHEIGHT (ratio)" label into a (width, height) pair.

    Both dimensions are snapped down to the nearest multiple of 8.
    Accepts either 'x' or '×' as the separator; returns (1024, 1024)
    when the string is empty or cannot be parsed.
    """
    fallback = (1024, 1024)
    if not resolution_str:
        return fallback
    found = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution_str)
    if found is None:
        return fallback
    w, h = (int(g) for g in found.groups())
    # Floor-divide then multiply: equivalent to subtracting n % 8.
    return w // 8 * 8, h // 8 * 8

# ==================== 2. Model loading & core optimizations ====================
print("🚀 Loading Z-Image-Turbo pipeline...")

# trust_remote_code=True is required to load Z-Image's custom Pipeline and
# Transformer classes; without them set_attention_backend is unavailable.
pipe = DiffusionPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    trust_remote_code=True, 
)

# Swap in FlowMatchEulerDiscreteScheduler configured with shift=3.0.
try:
    scheduler_config = dict(pipe.scheduler.config)
    # "algorithm_type" presumably belongs to a different scheduler family;
    # dropped before from_config — TODO confirm it is actually rejected otherwise.
    scheduler_config.pop("algorithm_type", None) 
    pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
        scheduler_config, 
        shift=3.0
    )
    print("✅ Scheduler optimized with shift=3.0")
except Exception as e:
    print(f"⚠️ Scheduler config warning: {e}")

# Move the whole pipeline to the GPU.
pipe.to("cuda")


print("Enabling torch.compile optimizations...")
# Inductor tuning knobs; these take effect if the transformer is compiled
# (the torch.compile call further down is currently commented out).
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.max_autotune_gemm = True
torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
torch._inductor.config.triton.cudagraphs = False

# 尝试按顺序启用最快的后端
def enable_best_attention_backend(pipeline):
    """Enable the fastest available attention backend on the pipeline's transformer.

    Tries backends in descending performance order and stops at the first one
    that set_attention_backend() accepts. If the transformer does not expose
    set_attention_backend (i.e. the custom Z-Image classes were not loaded),
    falls back to standard xFormers memory-efficient attention, best-effort.

    Args:
        pipeline: a diffusers pipeline; only `pipeline.transformer` is used.
    """
    backends = [
        # ===== S tier: current best =====
        "flash_varlen",            # FA v2 varlen, stable + high performance
        "_flash_3_varlen_hub",     # FA v3 varlen (hub), very strong on SM90
        "_flash_varlen_3",         # FA v3 varlen (local)
        "_flash_3",                # FA v3, non-varlen
        "flash",                   # FA v2, non-varlen

        # ===== A tier: acceptable / high-performance fallbacks =====
        "flash_varlen_hub",
        "flash_hub",
        "xformers",                # mature but slightly slower than FA
        "_native_flash",

        # ===== B tier: framework-native / compatibility first =====
        "native",
        "_native_efficient",
        "_native_cudnn",

        # ===== C tier: niche backends / limited scenarios =====
        "flex",
        "_native_xla",
        "_native_npu",
        "aiter",

        # ===== D tier: Sage / experimental quantized implementations =====
        "sage",
        "sage_hub",
        "sage_varlen",
        "_sage_qk_int8_pv_fp16_cuda",
        "_sage_qk_int8_pv_fp16_triton",
        "_sage_qk_int8_pv_fp8_cuda",
        "_sage_qk_int8_pv_fp8_cuda_sm90",

        # ===== Fallback =====
        "_native_math",
    ]
    # Deliberate best-effort: any failure (missing method, unsupported backend,
    # missing optional dependency) just moves on to the next candidate.
    for backend in backends:
        try:
            pipeline.transformer.set_attention_backend(backend)
            print(f"✅ Attention backend set to: {backend}")
            return
        except Exception:
            continue

    print("⚠️ Warning: Transformer model does not support 'set_attention_backend'. Custom code might not be loaded.")
    # Last resort: standard xFormers attention on the whole pipeline.
    # Catch Exception (not bare except) so SystemExit/KeyboardInterrupt propagate.
    try:
        pipeline.enable_xformers_memory_efficient_attention()
        print("✅ Standard xFormers enabled as fallback")
    except Exception:
        pass

# Apply the attention-backend selection to the loaded pipeline.
enable_best_attention_backend(pipe)

# VAE memory optimization: decode in slices to lower peak VRAM.
# except Exception (not a bare except) so Ctrl-C / SystemExit still propagate;
# older or custom VAEs may simply not expose enable_slicing().
try:
    pipe.vae.enable_slicing()
except Exception:
    pass

# print("Compiling transformer...")
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)

# ==================== 3. 生成逻辑 ====================
@spaces.GPU
def generate_image(
    prompt, 
    resolution_choice,
    use_custom_res,
    custom_width,
    custom_height,
    num_inference_steps, 
    seed, 
    randomize_seed,
    negative_prompt,
    gallery_history,
    progress=gr.Progress(track_tqdm=True)
):
    """Run one text-to-image generation and prepend the result to the gallery.

    Args:
        prompt: user prompt text (must be at least 2 characters after strip).
        resolution_choice: preset label parsed by get_resolution().
        use_custom_res: when True, use custom_width/custom_height instead.
        custom_width, custom_height: custom size; snapped down to multiples of 8.
        num_inference_steps: sampler step count.
        seed, randomize_seed: manual seed, or draw a fresh one when randomized.
        negative_prompt: optional negative prompt ('' becomes None).
        gallery_history: list of (image, label) pairs; newest first.
        progress: Gradio progress tracker (tqdm-linked).

    Returns:
        (updated gallery_history, seed actually used).

    Raises:
        gr.Error: with a user-facing message on invalid input or pipeline failure.
    """
    if gallery_history is None:
        gallery_history = []

    try:
        if not prompt or len(prompt.strip()) < 2:
            raise gr.Error("请输入提示词 (Prompt)")

        prompt = prompt.strip()
        neg_prompt = negative_prompt.strip() if negative_prompt else None

        # Resolve output size: custom sliders win over the dropdown preset.
        if use_custom_res:
            width = int(custom_width) - int(custom_width) % 8
            height = int(custom_height) - int(custom_height) % 8
        else:
            width, height = get_resolution(resolution_choice)

        if randomize_seed:
            seed = random.randint(0, 2**32 - 1)
        seed = int(seed)

        start_time = time.time()
        generator = torch.Generator("cuda").manual_seed(seed)

        # Free cached VRAM so large resolutions have maximum headroom.
        torch.cuda.empty_cache()

        # torch.cuda.amp.autocast is deprecated; torch.autocast is the current API.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=0.0,
                generator=generator,
                negative_prompt=neg_prompt,
                max_sequence_length=512,
            ).images[0]

        gen_time = time.time() - start_time

        # Newest result goes first in the gallery, labelled with its settings.
        info_label = f"{width}x{height} | Steps: {num_inference_steps} | Seed: {seed} | {gen_time:.2f}s"
        gallery_history.insert(0, (image, info_label))

        return gallery_history, seed

    except gr.Error:
        # Propagate user-facing errors unchanged instead of re-wrapping them
        # into a generic "生成错误: ..." message.
        raise
    except Exception as e:
        raise gr.Error(f"生成错误: {str(e)}")

# ==================== 4. UI styles ====================
# Custom CSS injected into gr.Blocks: header gradient title, primary button
# styling, and light/dark panel containers.
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }

.header-container { text-align: center; margin-bottom: 20px; }
.header-title { 
    font-size: 2.5rem; font-weight: 800; margin: 0;
    background: linear-gradient(135deg, #f59e0b, #ea580c);
    -webkit-background-clip: text; -webkit-text-fill-color: transparent;
}
.header-subtitle { font-size: 1rem; color: #6b7280; font-weight: 500; }

.primary-btn { 
    background: linear-gradient(90deg, #f59e0b 0%, #d97706 100%) !important; 
    border: none !important; 
    color: white !important; 
    font-weight: 600 !important;
    font-size: 1.1rem !important;
    box-shadow: 0 4px 6px -1px rgba(245, 158, 11, 0.2) !important;
}
.primary-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 15px -3px rgba(245, 158, 11, 0.3) !important; }

.panel-container { 
    background: #ffffff; border: 1px solid #e5e7eb; border-radius: 12px; padding: 15px; 
}
.dark .panel-container { background: #1f2937; border-color: #374151; }
"""

# ==================== 5. Gradio UI ====================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css=css, title="Z-Image-Turbo") as demo:
    
    gr.HTML("""
        <div class="header-container">
            <h1 class="header-title">⚡ Z-Image-Turbo</h1>
            <p class="header-subtitle">Optimized Backend • 8 Steps • Gallery History</p>
        </div>
    """)
    
    with gr.Row():
        # --- Control panel ---
        with gr.Column(scale=4, min_width=320):
            with gr.Group(elem_classes="panel-container"):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Low quality, blurry...",
                    lines=1
                )
                generate_btn = gr.Button("🚀 Generate", elem_classes="primary-btn")
            
            with gr.Group(elem_classes="panel-container"):
                gr.Markdown("### 📐 Resolution")
                res_category = gr.Radio(
                    choices=["1024", "1280", "1536"],
                    value="1024",
                    label="Resolution Base",
                    container=False
                )
                resolution_dropdown = gr.Dropdown(
                    choices=RES_CHOICES["1024"],
                    value=RES_CHOICES["1024"][0],
                    label="Select Ratio",
                    show_label=False
                )
                
                with gr.Accordion("Custom Size", open=False):
                    use_custom_res = gr.Checkbox(label="Enable Custom", value=False)
                    # Hidden until "Enable Custom" is checked (wired below).
                    with gr.Row(visible=False) as custom_res_row:
                        width_slider = gr.Slider(512, 1536, value=1024, step=64, label="W")
                        height_slider = gr.Slider(512, 1536, value=1024, step=64, label="H")

            with gr.Accordion("⚙️ Settings", open=False):
                with gr.Group(elem_classes="panel-container"):
                    steps_slider = gr.Slider(4, 20, value=8, step=1, label="Steps")
                    with gr.Row():
                        random_seed = gr.Checkbox(label="Random Seed", value=True)
                        seed_input = gr.Number(label="Seed", value=42, visible=False, precision=0)

        # --- Gallery ---
        with gr.Column(scale=6, min_width=500):
            output_gallery = gr.Gallery(
                label="History",
                value=[],
                columns=[2],
                rows=[2],
                object_fit="contain",
                height="auto",
                show_share_button=True,
                show_download_button=True,
                interactive=False
            )
            with gr.Row():
                last_seed_display = gr.Textbox(label="Last Seed", interactive=False, scale=3)
                clear_btn = gr.Button("🗑️ Clear", scale=1, variant="secondary")

    # Interaction wiring
    def update_resolution_list(category):
        """Repopulate the ratio dropdown when the resolution base changes."""
        return gr.Dropdown(choices=RES_CHOICES[category], value=RES_CHOICES[category][0])

    res_category.change(update_resolution_list, inputs=res_category, outputs=resolution_dropdown)
    
    # Toggle the custom W/H row and lock the preset dropdown while custom size is on.
    use_custom_res.change(
        lambda x: (gr.Row(visible=x), gr.Dropdown(interactive=not x)),
        inputs=use_custom_res, outputs=[custom_res_row, resolution_dropdown]
    )
    
    # Show the manual seed field only when random seeding is off.
    random_seed.change(lambda x: gr.Number(visible=not x), inputs=random_seed, outputs=seed_input)
    
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, resolution_dropdown, use_custom_res, width_slider, height_slider, steps_slider, seed_input, random_seed, negative_prompt, output_gallery],
        outputs=[output_gallery, last_seed_display]
    )
    
    # Empty the gallery and the seed readout.
    clear_btn.click(lambda: ([], ""), outputs=[output_gallery, last_seed_display])

# Launch the Gradio app when executed as a script.
if __name__ == "__main__":
    demo.launch()