IdlecloudX committed on
Commit
694670c
·
verified ·
1 Parent(s): 48145c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -0
app.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Z-Image 图像生成演示
3
+ 简化版 UI,移除了提示词优化逻辑,合并了分辨率选择。
4
+ """
5
+
6
+ import spaces
7
+ import random
8
+ import re
9
+ import torch
10
+ import gradio as gr
11
+ from diffusers import ZImagePipeline
12
+
13
+ # ==================== 配置信息 ====================
14
# Hugging Face repository id the pipeline is loaded from.
MODEL_PATH = "Tongyi-MAI/Z-Image"

# ==================== Merged resolution list ====================
# Every preset as (width, height, aspect-ratio label), grouped by size tier.
_RESOLUTION_PRESETS = (
    # 720 tier
    (720, 720, "1:1"),
    (896, 512, "16:9"),
    (512, 896, "9:16"),
    (832, 544, "3:2"),
    (544, 832, "2:3"),
    (800, 576, "4:3"),
    (576, 800, "3:4"),
    # 1024 tier
    (1024, 1024, "1:1"),
    (1152, 896, "9:7"),
    (896, 1152, "7:9"),
    (1152, 864, "4:3"),
    (864, 1152, "3:4"),
    (1248, 832, "3:2"),
    (832, 1248, "2:3"),
    (1280, 720, "16:9"),
    (720, 1280, "9:16"),
    (1344, 576, "21:9"),
    (576, 1344, "9:21"),
    # 1280 tier
    (1280, 1280, "1:1"),
    (1440, 1120, "9:7"),
    (1120, 1440, "7:9"),
    (1472, 1104, "4:3"),
    (1104, 1472, "3:4"),
    (1536, 1024, "3:2"),
    (1024, 1536, "2:3"),
    (1536, 864, "16:9"),
    (864, 1536, "9:16"),
    (1680, 720, "21:9"),
    (720, 1680, "9:21"),
)

# Dropdown labels in the form "<width>x<height> ( <ratio> )".
ALL_RESOLUTIONS = [
    f"{width}x{height} ( {ratio} )" for width, height, ratio in _RESOLUTION_PRESETS
]

# Sample prompts; each one is wrapped in its own single-element list.
_EXAMPLE_PROMPT_TEXTS = (
    "一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。",
    "极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影。",
    "Young Chinese woman in red Hanfu, intricate embroidery, golden phoenix headdress, soft-lit outdoor night background.",
    "A serene mountain landscape at sunset with golden light reflecting off a calm lake.",
    "A futuristic cityscape with flying cars and neon holographic advertisements, cyberpunk style.",
)

EXAMPLE_PROMPTS = [[text] for text in _EXAMPLE_PROMPT_TEXTS]
59
+
60
+ # ==================== 辅助函数 ====================
61
def get_resolution(resolution: str) -> tuple[int, int]:
    """Parse a resolution label into a ``(width, height)`` pair.

    The first ``<digits>x<digits>`` pair found anywhere in the text is used
    (separator ``x`` or ``×``, optional whitespace around it), so trailing
    text such as ``"( 16:9 )"`` is ignored.

    Args:
        resolution: Label such as ``"1280x720 ( 16:9 )"``.

    Returns:
        ``(width, height)`` as ints, or ``(1024, 1024)`` when the input is
        not a string or contains no such pair.
    """
    # re.search raises TypeError on non-string input (e.g. None); treat it
    # the same way as text that does not contain a resolution.
    if not isinstance(resolution, str):
        return 1024, 1024

    match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if match is None:
        return 1024, 1024

    width, height = (int(group) for group in match.groups())
    return width, height
67
+
68
+ # ==================== 模型加载 ====================
69
print(f"正在从 {MODEL_PATH} 加载 Z-Image 流水线...")

# Keyword arguments forwarded unchanged to ZImagePipeline.from_pretrained.
_pipeline_load_options = {
    "torch_dtype": torch.bfloat16,
    "low_cpu_mem_usage": False,
}
pipe = ZImagePipeline.from_pretrained(MODEL_PATH, **_pipeline_load_options)

# Move the loaded pipeline to the CUDA device at import time.
pipe.to("cuda")

print("流水线加载成功!")
77
+
78
+ # ==================== 生成核心逻辑 ====================
79
def _pick_seed(seed, random_seed):
    """Return the integer seed for this run, drawing a random one when needed."""
    # A missing seed (None) is handled like the -1 sentinel: pick a random seed.
    if random_seed or seed is None or seed == -1:
        return random.randint(1, 1000000)
    # torch.Generator.manual_seed needs an int; coerce float-valued input.
    return int(seed)


@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    resolution: str = "1024x1024 ( 1:1 )",
    seed: int = 42,
    num_inference_steps: int = 30,
    guidance_scale: float = 4.0,
    cfg_normalization: bool = False,
    random_seed: bool = True,
    gallery_images: list = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the Z-Image pipeline and prepend it to the gallery.

    Args:
        prompt: Text description of the image; must contain non-whitespace text.
        negative_prompt: Content to steer away from. Blank or ``None`` is
            passed to the pipeline as ``None``.
        resolution: Label parsed by ``get_resolution`` into width and height.
        seed: Seed used when ``random_seed`` is false. ``-1`` or ``None``
            requests a random seed instead.
        num_inference_steps: Step count forwarded to the pipeline (as an int).
        guidance_scale: Guidance scale forwarded to the pipeline.
        cfg_normalization: Flag forwarded to the pipeline unchanged.
        random_seed: When true, ignore ``seed`` and draw one from
            ``random.randint(1, 1000000)``.
        gallery_images: Items already in the gallery; ``None`` means empty.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A 3-tuple: the gallery items with the new image first, the seed that
        was used as a string, and the same seed as an int.

    Raises:
        gr.Error: If ``prompt`` is ``None``, empty, or only whitespace.
    """
    # Guard against None as well as blank text so the user gets the
    # friendly error instead of an AttributeError on .strip().
    if not prompt or not prompt.strip():
        raise gr.Error("请输入提示词。")

    new_seed = _pick_seed(seed, random_seed)
    width, height = get_resolution(resolution)
    generator = torch.Generator("cuda").manual_seed(new_seed)

    # A blank (or missing) negative prompt is sent as None rather than "".
    has_negative = bool(negative_prompt and negative_prompt.strip())

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt if has_negative else None,
        height=height,
        width=width,
        cfg_normalization=cfg_normalization,
        # Coerce in case the step count arrives as a float such as 30.0.
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    # Newest image goes first; earlier gallery items follow in their old order.
    previous_items = [] if gallery_images is None else list(gallery_images)
    gallery_images = [image, *previous_items]

    return gallery_images, str(new_seed), int(new_seed)
130
+
131
+ # ==================== Gradio 界面设计 ====================
132
with gr.Blocks(title="Z-Image 核心生成器") as demo:
    # Centered page header, rendered through a Markdown component.
    _header_html = "\n".join(
        [
            '<div align="center">',
            "<h1>Z-Image 核心生成器</h1>",
            "<p>基于单流扩散 Transformer 的高效图像生成模型</p>",
            "</div>",
        ]
    )
    gr.Markdown(_header_html)

    with gr.Row():
        # Left column: all generation controls.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                lines=4,
                label="提示词 (Prompt)",
                placeholder="在此输入你想生成的画面描述(支持中英文)...",
            )
            negative_prompt_input = gr.Textbox(
                lines=2,
                label="反向提示词 (可选)",
                placeholder="输入你不想在图像中出现的内容...",
            )

            # One dropdown holding every resolution preset.
            resolution = gr.Dropdown(
                choices=ALL_RESOLUTIONS,
                value="1024x1024 ( 1:1 )",
                label="分辨率选择 (宽 x 高)",
            )

            with gr.Row():
                seed = gr.Number(label="种子", value=42, precision=0)
                random_seed = gr.Checkbox(label="使用随机种子", value=True)

            with gr.Row():
                num_inference_steps = gr.Slider(
                    minimum=10, maximum=100, value=30, step=1,
                    label="推理步数 (Steps)",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=4.0, step=0.5,
                    label="引导比例 (CFG Scale)",
                )

            cfg_normalization = gr.Checkbox(label="启用 CFG 归一化", value=False)

            generate_btn = gr.Button("开始生成", variant="primary")

            # Example prompts that fill the prompt textbox.
            gr.Markdown("### 📝 示例")
            gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)

        # Right column: results.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="生成结果",
                columns=1,
                rows=1,
                height=512,
                object_fit="contain",
                format="png",
                interactive=False,
            )
            used_seed = gr.Textbox(label="本次使用的种子", interactive=False)

    # Inputs are listed in the positional order of generate()'s parameters;
    # the gallery is both an input (existing items) and an output.
    _generate_inputs = [
        prompt_input,
        negative_prompt_input,
        resolution,
        seed,
        num_inference_steps,
        guidance_scale,
        cfg_normalization,
        random_seed,
        output_gallery,
    ]
    _generate_outputs = [output_gallery, used_seed, seed]

    generate_btn.click(
        fn=generate,
        inputs=_generate_inputs,
        outputs=_generate_outputs,
        api_name="generate",
    )
224
+
225
+ # ==================== 启动 ====================
226
# Stylesheet passed to launch(); it sets a max-width on the ".fillable" class.
css = ".fillable{max-width: 1230px !important}"

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 and apply the stylesheet above.
    _launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "css": css,
    }
    demo.launch(**_launch_options)