telcom committed on
Commit
72ae055
·
verified ·
1 Parent(s): 60e88d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +306 -516
app.py CHANGED
@@ -1,568 +1,358 @@
1
- import spaces
2
- from dataclasses import dataclass
3
- import json
4
- import logging
5
  import os
 
6
  import random
7
- import re
8
- import sys
9
  import warnings
 
 
 
 
 
 
 
 
10
 
11
- from PIL import Image
12
- from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
13
  import gradio as gr
 
 
 
14
  import torch
 
 
 
 
 
 
 
 
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
16
 
17
- from prompt_check import is_unsafe_prompt
18
-
19
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
20
-
21
- from diffusers import ZImagePipeline
22
- from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
23
 
24
- from pe import prompt_template
25
-
26
- # ==================== Environment Variables ==================================
27
- MODEL_PATH = os.environ.get("MODEL_PATH", "telcom/dee-z-image")
28
- ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
29
- ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
30
- ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
31
- UNSAFE_MAX_NEW_TOKEN = int(os.environ.get("UNSAFE_MAX_NEW_TOKEN", "10"))
32
- DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
33
- HF_TOKEN = os.environ.get("HF_TOKEN")
34
- UNSAFE_PROMPT_CHECK = os.environ.get("UNSAFE_PROMPT_CHECK")
35
- # =============================================================================
36
 
 
 
 
37
 
38
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
39
  warnings.filterwarnings("ignore")
40
  logging.getLogger("transformers").setLevel(logging.ERROR)
41
 
42
- RES_CHOICES = {
43
- "1024": [
44
- "1024x1024 ( 1:1 )",
45
- "1152x896 ( 9:7 )",
46
- "896x1152 ( 7:9 )",
47
- "1152x864 ( 4:3 )",
48
- "864x1152 ( 3:4 )",
49
- "1248x832 ( 3:2 )",
50
- "832x1248 ( 2:3 )",
51
- "1280x720 ( 16:9 )",
52
- "720x1280 ( 9:16 )",
53
- "1344x576 ( 21:9 )",
54
- "576x1344 ( 9:21 )",
55
- ],
56
- "1280": [
57
- "1280x1280 ( 1:1 )",
58
- "1440x1120 ( 9:7 )",
59
- "1120x1440 ( 7:9 )",
60
- "1472x1104 ( 4:3 )",
61
- "1104x1472 ( 3:4 )",
62
- "1536x1024 ( 3:2 )",
63
- "1024x1536 ( 2:3 )",
64
- "1536x864 ( 16:9 )",
65
- "864x1536 ( 9:16 )",
66
- "1680x720 ( 21:9 )",
67
- "720x1680 ( 9:21 )",
68
- ],
69
- "1536": [
70
- "1536x1536 ( 1:1 )",
71
- "1728x1344 ( 9:7 )",
72
- "1344x1728 ( 7:9 )",
73
- "1728x1296 ( 4:3 )",
74
- "1296x1728 ( 3:4 )",
75
- "1872x1248 ( 3:2 )",
76
- "1248x1872 ( 2:3 )",
77
- "2048x1152 ( 16:9 )",
78
- "1152x2048 ( 9:16 )",
79
- "2016x864 ( 21:9 )",
80
- "864x2016 ( 9:21 )",
81
- ],
82
- }
83
 
84
- RESOLUTION_SET = []
85
- for resolutions in RES_CHOICES.values():
86
- RESOLUTION_SET.extend(resolutions)
87
 
88
- EXAMPLE_PROMPTS = [
89
- [
90
- "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
91
- ],
92
- [
93
- '''A vertical digital illustration depicting a serene and majestic Chinese landscape, rendered in a style reminiscent of traditional Shanshui painting but with a modern, clean aesthetic. The scene is dominated by towering, steep cliffs in various shades of blue and teal, which frame a central valley. In the distance, layers of mountains fade into a light blue and white mist, creating a strong sense of atmospheric perspective and depth. A calm, turquoise river flows through the center of the composition, with a small, traditional Chinese boat, possibly a sampan, navigating its waters. The boat has a bright yellow canopy and a red hull, and it leaves a gentle wake behind it. It carries several indistinct figures of people. Sparse vegetation, including green trees and some bare-branched trees, clings to the rocky ledges and peaks. The overall lighting is soft and diffused, casting a tranquil glow over the entire scene. Centered in the image is overlaid text. At the top of the text block is a small, red, circular seal-like logo containing stylized characters. Below it, in a smaller, black, sans-serif font, are the words 'Zao-Xiang * East Beauty & West Fashion * Z-Image'. Directly beneath this, in a larger, elegant black serif font, is the word 'SHOW & SHARE CREATIVITY WITH THE WORLD'. Among them, there are "SHOW & SHARE", "CREATIVITY", and "WITH THE WORLD"'''
94
- ],
95
-
96
- ]
97
 
 
 
 
 
 
 
98
 
99
- def get_resolution(resolution):
100
- match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
101
- if match:
102
- return int(match.group(1)), int(match.group(2))
103
- return 1024, 1024
104
 
 
 
 
105
 
106
- def load_models(model_path, enable_compile=False, attention_backend="native"):
107
- print(f"Loading models from {model_path}...")
 
108
 
109
- use_auth_token = HF_TOKEN if HF_TOKEN else True
 
 
 
110
 
111
- if not os.path.exists(model_path):
112
- vae = AutoencoderKL.from_pretrained(
113
- f"{model_path}",
114
- subfolder="vae",
115
- torch_dtype=torch.bfloat16,
116
- device_map="cuda",
117
- use_auth_token=use_auth_token,
118
- )
119
-
120
- text_encoder = AutoModelForCausalLM.from_pretrained(
121
- f"{model_path}",
122
- subfolder="text_encoder",
123
- torch_dtype=torch.bfloat16,
124
- device_map="cuda",
125
- use_auth_token=use_auth_token,
126
- ).eval()
127
-
128
- tokenizer = AutoTokenizer.from_pretrained(f"{model_path}", subfolder="tokenizer", use_auth_token=use_auth_token)
129
- else:
130
- vae = AutoencoderKL.from_pretrained(
131
- os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda"
132
- )
133
-
134
- text_encoder = AutoModelForCausalLM.from_pretrained(
135
- os.path.join(model_path, "text_encoder"),
136
- torch_dtype=torch.bfloat16,
137
- device_map="cuda",
138
- ).eval()
139
-
140
- tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  tokenizer.padding_side = "left"
143
 
144
- if enable_compile:
145
- print("Enabling torch.compile optimizations...")
146
- torch._inductor.config.conv_1x1_as_mm = True
147
- torch._inductor.config.coordinate_descent_tuning = True
148
- torch._inductor.config.epilogue_fusion = False
149
- torch._inductor.config.coordinate_descent_check_all_directions = True
150
- torch._inductor.config.max_autotune_gemm = True
151
- torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
152
- torch._inductor.config.triton.cudagraphs = False
153
-
154
- pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
155
-
156
- if enable_compile:
157
- pipe.vae.disable_tiling()
158
-
159
- if not os.path.exists(model_path):
160
- transformer = ZImageTransformer2DModel.from_pretrained(
161
- f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token
162
- ).to("cuda", torch.bfloat16)
163
- else:
164
- transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to(
165
- "cuda", torch.bfloat16
166
- )
167
-
168
- pipe.transformer = transformer
169
- pipe.transformer.set_attention_backend(attention_backend)
170
-
171
- if enable_compile:
172
- print("Compiling transformer...")
173
- pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
174
-
175
- pipe.to("cuda", torch.bfloat16)
176
-
177
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
178
- from transformers import CLIPImageProcessor
179
-
180
- safety_model_id = "CompVis/stable-diffusion-safety-checker"
181
- safety_feature_extractor = CLIPImageProcessor.from_pretrained(safety_model_id)
182
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, torch_dtype=torch.float16).to("cuda")
183
-
184
- pipe.safety_feature_extractor = safety_feature_extractor
185
- pipe.safety_checker = safety_checker
186
- return pipe
187
-
188
-
189
- def generate_image(
190
- pipe,
191
- prompt,
192
- resolution="1024x1024",
193
- seed=42,
194
- guidance_scale=5.0,
195
- num_inference_steps=50,
196
- shift=3.0,
197
- max_sequence_length=512,
198
- progress=gr.Progress(track_tqdm=True),
199
- ):
200
- width, height = get_resolution(resolution)
201
-
202
- generator = torch.Generator("cuda").manual_seed(seed)
203
-
204
- scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
205
- pipe.scheduler = scheduler
206
-
207
- image = pipe(
208
- prompt=prompt,
209
- height=height,
210
- width=width,
211
- guidance_scale=guidance_scale,
212
- num_inference_steps=num_inference_steps,
213
- generator=generator,
214
- max_sequence_length=max_sequence_length,
215
- ).images[0]
216
-
217
- return image
218
-
219
-
220
- def warmup_model(pipe, resolutions):
221
- print("Starting warmup phase...")
222
-
223
- dummy_prompt = "warmup"
224
-
225
- for res_str in resolutions:
226
- print(f"Warming up for resolution: {res_str}")
227
- try:
228
- for i in range(3):
229
- generate_image(
230
- pipe,
231
- prompt=dummy_prompt,
232
- resolution=res_str,
233
- num_inference_steps=9,
234
- guidance_scale=0.0,
235
- seed=42 + i,
236
- )
237
- except Exception as e:
238
- print(f"Warmup failed for {res_str}: {e}")
239
-
240
- print("Warmup completed.")
241
-
242
-
243
- # ==================== Prompt Expander ====================
244
- @dataclass
245
- class PromptOutput:
246
- status: bool
247
- prompt: str
248
- seed: int
249
- system_prompt: str
250
- message: str
251
-
252
-
253
- class PromptExpander:
254
- def __init__(self, backend="api", **kwargs):
255
- self.backend = backend
256
-
257
- def decide_system_prompt(self, template_name=None):
258
- return prompt_template
259
-
260
-
261
- class APIPromptExpander(PromptExpander):
262
- def __init__(self, api_config=None, **kwargs):
263
- super().__init__(backend="api", **kwargs)
264
- self.api_config = api_config or {}
265
- self.client = self._init_api_client()
266
-
267
- def _init_api_client(self):
268
- try:
269
- from openai import OpenAI
270
-
271
- api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
272
- base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
273
-
274
- if not api_key:
275
- print("Warning: DASHSCOPE_API_KEY not found.")
276
- return None
277
-
278
- return OpenAI(api_key=api_key, base_url=base_url)
279
- except ImportError:
280
- print("Please install openai: pip install openai")
281
- return None
282
- except Exception as e:
283
- print(f"Failed to initialize API client: {e}")
284
- return None
285
-
286
- def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
287
- return self.extend(prompt, system_prompt, seed, **kwargs)
288
-
289
- def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
290
- if self.client is None:
291
- return PromptOutput(False, "", seed, system_prompt, "API client not initialized")
292
-
293
- if system_prompt is None:
294
- system_prompt = self.decide_system_prompt()
295
-
296
- if "{prompt}" in system_prompt:
297
- system_prompt = system_prompt.format(prompt=prompt)
298
- prompt = " "
299
-
300
- try:
301
- model = self.api_config.get("model", "qwen3-max-preview")
302
- response = self.client.chat.completions.create(
303
- model=model,
304
- messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
305
- temperature=0.7,
306
- top_p=0.8,
307
- )
308
-
309
- content = response.choices[0].message.content
310
- json_start = content.find("```json")
311
- if json_start != -1:
312
- json_end = content.find("```", json_start + 7)
313
- try:
314
- json_str = content[json_start + 7 : json_end].strip()
315
- data = json.loads(json_str)
316
- expanded_prompt = data.get("revised_prompt", content)
317
- except:
318
- expanded_prompt = content
319
- else:
320
- expanded_prompt = content
321
-
322
- return PromptOutput(
323
- status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
324
- )
325
- except Exception as e:
326
- return PromptOutput(False, "", seed, system_prompt, str(e))
327
-
328
-
329
- def create_prompt_expander(backend="api", **kwargs):
330
- if backend == "api":
331
- return APIPromptExpander(**kwargs)
332
- raise ValueError("Only 'api' backend is supported.")
333
-
334
-
335
- pipe = None
336
- prompt_expander = None
337
 
 
 
 
 
 
 
338
 
339
- def init_app():
340
- global pipe, prompt_expander
 
341
 
 
 
 
 
342
  try:
343
- pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
344
- print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
345
-
346
- if ENABLE_WARMUP:
347
- all_resolutions = []
348
- for cat in RES_CHOICES.values():
349
- all_resolutions.extend(cat)
350
- warmup_model(pipe, all_resolutions)
351
 
352
- except Exception as e:
353
- print(f"Error loading model: {e}")
354
- pipe = None
355
 
 
356
  try:
357
- prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
358
- print("Prompt expander initialized.")
359
- except Exception as e:
360
- print(f"Error initializing prompt expander: {e}")
361
- prompt_expander = None
362
-
363
-
364
- def prompt_enhance(prompt, enable_enhance):
365
- if not enable_enhance or not prompt_expander:
366
- return prompt, "Enhancement disabled or not available."
367
-
368
- if not prompt.strip():
369
- return "", "Please enter a prompt."
370
-
371
- try:
372
- result = prompt_expander(prompt)
373
- if result.status:
374
- return result.prompt, result.message
375
- else:
376
- return prompt, f"Enhancement failed: {result.message}"
377
- except Exception as e:
378
- return prompt, f"Error: {str(e)}"
379
-
380
-
381
- @spaces.GPU
382
- def generate(
383
- prompt,
384
- resolution="1024x1024 ( 1:1 )",
385
- seed=42,
386
- steps=9,
387
- shift=3.0,
388
- random_seed=True,
389
- gallery_images=None,
390
- enhance=False,
391
- progress=gr.Progress(track_tqdm=True),
392
- ):
393
- """
394
- Generate an image using the Z-Image model based on the provided prompt and settings.
395
-
396
- This function is triggered when the user clicks the "Generate" button. It processes
397
- the input prompt (optionally enhancing it), configures generation parameters, and
398
- produces an image using the Z-Image diffusion transformer pipeline.
399
-
400
- Args:
401
- prompt (str): Text prompt describing the desired image content
402
- resolution (str): Output resolution in format "WIDTHxHEIGHT ( RATIO )" (e.g., "1024x1024 ( 1:1 )")
403
- seed (int): Seed for reproducible generation
404
- steps (int): Number of inference steps for the diffusion process
405
- shift (float): Time shift parameter for the flow matching scheduler
406
- random_seed (bool): Whether to generate a new random seed, if True will ignore the seed input
407
- gallery_images (list): List of previously generated images to append to (only needed for the Gradio UI)
408
- enhance (bool): This was Whether to enhance the prompt (DISABLED! Do not use)
409
- progress (gr.Progress): Gradio progress tracker for displaying generation progress (only needed for the Gradio UI)
410
-
411
- Returns:
412
- tuple: (gallery_images, seed_str, seed_int)
413
- - gallery_images: Updated list of generated images including the new image
414
- - seed_str: String representation of the seed used for generation
415
- - seed_int: Integer representation of the seed used for generation
416
- """
417
-
418
- if random_seed:
419
- new_seed = random.randint(1, 1000000)
420
- else:
421
- new_seed = seed if seed != -1 else random.randint(1, 1000000)
422
-
423
- class UnsafeContentError(Exception):
424
  pass
425
 
426
- try:
427
- if pipe is None:
428
- raise gr.Error("Model not loaded.")
429
-
430
- has_unsafe_concept = is_unsafe_prompt(
431
- pipe.text_encoder,
432
- pipe.tokenizer,
433
- system_prompt=UNSAFE_PROMPT_CHECK,
434
- user_prompt=prompt,
435
- max_new_token=UNSAFE_MAX_NEW_TOKEN,
436
- )
437
- if has_unsafe_concept:
438
- raise UnsafeContentError("Input unsafe")
439
-
440
- final_prompt = prompt
441
-
442
- if enhance:
443
- final_prompt, _ = prompt_enhance(prompt, True)
444
- print(f"Enhanced prompt: {final_prompt}")
445
-
446
  try:
447
- resolution_str = resolution.split(" ")[0]
448
- except:
449
- resolution_str = "1024x1024"
450
-
451
- image = generate_image(
452
- pipe=pipe,
453
- prompt=final_prompt,
454
- resolution=resolution_str,
455
- seed=new_seed,
456
- guidance_scale=0.0,
457
- num_inference_steps=int(steps + 1),
458
- shift=shift,
459
- )
460
-
461
- safety_checker_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.cuda()
462
- _, has_nsfw_concept = pipe.safety_checker(images=[torch.zeros(1)], clip_input=safety_checker_input)
463
- has_nsfw_concept = has_nsfw_concept[0]
464
- if has_nsfw_concept:
465
- print("input unsafe")
466
 
467
- except UnsafeContentError:
468
- image = Image.open("nsfw.png")
 
 
 
 
469
 
470
- if gallery_images is None:
471
- gallery_images = []
472
- # gallery_images.append(image)
473
- gallery_images = [image] + gallery_images # latest output to be at the top of the list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
 
475
- return gallery_images, str(new_seed), int(new_seed)
 
476
 
 
 
 
477
 
478
- init_app()
 
479
 
480
- # ==================== AoTI (Ahead of Time Inductor compilation) ====================
481
 
482
- pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
483
- spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
484
 
485
- with gr.Blocks(title="Z-Image Demo") as demo:
486
- gr.Markdown(
487
- """<div align="center">
488
 
489
- # Z-Image Generation Demo
 
 
 
490
 
491
- [![GitHub](https://img.shields.io/badge/GitHub-Z--Image-181717?logo=github&logoColor=white)](https://github.com/Tongyi-MAI/Z-Image)
 
 
 
 
 
 
 
 
 
 
492
 
493
- *An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
 
495
- </div>"""
496
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
- with gr.Row():
499
- with gr.Column(scale=1):
500
- prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
501
- # PE components (Temporarily disabled)
502
- # with gr.Row():
503
- # enable_enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=False)
504
- # enhance_btn = gr.Button("Enhance Only")
505
-
506
- with gr.Row():
507
- choices = [int(k) for k in RES_CHOICES.keys()]
508
- res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
509
-
510
- initial_res_choices = RES_CHOICES["1024"]
511
- resolution = gr.Dropdown(
512
- value=initial_res_choices[0], choices=RESOLUTION_SET, label="Width x Height (Ratio)"
513
- )
514
-
515
- with gr.Row():
516
- seed = gr.Number(label="Seed", value=42, precision=0)
517
- random_seed = gr.Checkbox(label="Random Seed", value=True)
518
-
519
- with gr.Row():
520
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=False)
521
- shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
522
-
523
- generate_btn = gr.Button("Generate", variant="primary")
524
-
525
- # Example prompts
526
- gr.Markdown("### 📝 Example Prompts")
527
- gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
528
-
529
- with gr.Column(scale=1):
530
- output_gallery = gr.Gallery(
531
- label="Generated Images",
532
- columns=2,
533
- rows=2,
534
- height=600,
535
- object_fit="contain",
536
- format="png",
537
- interactive=False,
538
- )
539
- used_seed = gr.Textbox(label="Seed Used", interactive=False)
540
-
541
- def update_res_choices(_res_cat):
542
- if str(_res_cat) in RES_CHOICES:
543
- res_choices = RES_CHOICES[str(_res_cat)]
544
- else:
545
- res_choices = RES_CHOICES["1024"]
546
- return gr.update(value=res_choices[0], choices=res_choices)
547
-
548
- res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution, api_visibility="private")
549
-
550
- # PE enhancement button (Temporarily disabled)
551
- # enhance_btn.click(
552
- # prompt_enhance,
553
- # inputs=[prompt_input, enable_enhance],
554
- # outputs=[prompt_input, final_prompt_output]
555
- # )
556
-
557
- generate_btn.click(
558
- generate,
559
- inputs=[prompt_input, resolution, seed, steps, shift, random_seed, output_gallery],
560
- outputs=[output_gallery, used_seed, seed],
561
- api_visibility="public",
562
  )
563
 
564
- css = """
565
- .fillable{max-width: 1230px !important}
566
- """
567
  if __name__ == "__main__":
568
- demo.launch(css=css, mcp_server=True)
 
 
1
+ # ============================================================
2
+ # IMPORTANT: imports order matters for Hugging Face Spaces
3
+ # ============================================================
4
+
5
  import os
6
+ import gc
7
  import random
 
 
8
  import warnings
9
+ import logging
10
+
11
+ # ---- Spaces GPU decorator (must be imported early) ----------
12
+ try:
13
+ import spaces # noqa: F401
14
+ SPACES_AVAILABLE = True
15
+ except Exception:
16
+ SPACES_AVAILABLE = False
17
 
 
 
18
  import gradio as gr
19
+ import numpy as np
20
+ from PIL import Image
21
+
22
  import torch
23
+ from huggingface_hub import login
24
+
25
+ from diffusers import (
26
+ ZImagePipeline,
27
+ ZImageImg2ImgPipeline,
28
+ AutoencoderKL,
29
+ FlowMatchEulerDiscreteScheduler,
30
+ )
31
  from transformers import AutoModelForCausalLM, AutoTokenizer
32
 
33
+ # ============================================================
34
+ # Config
35
+ # ============================================================
 
 
 
36
 
37
+ MODEL_PATH = os.environ.get("MODEL_PATH", "telcom/dee-z-image").strip()
38
+ ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3").strip() # try: flash_3, flash, sdpa
39
+ ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "false").lower() == "true"
 
 
 
 
 
 
 
 
 
40
 
41
+ HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
42
+ if HF_TOKEN:
43
+ login(token=HF_TOKEN)
44
 
45
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
46
  warnings.filterwarnings("ignore")
47
  logging.getLogger("transformers").setLevel(logging.ERROR)
48
 
49
+ MAX_SEED = np.iinfo(np.int32).max
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # ============================================================
52
+ # Device & dtype
53
+ # ============================================================
54
 
55
+ cuda_available = torch.cuda.is_available()
56
+ device = torch.device("cuda" if cuda_available else "cpu")
 
 
 
 
 
 
 
57
 
58
+ if cuda_available and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
59
+ dtype = torch.bfloat16
60
+ elif cuda_available:
61
+ dtype = torch.float16
62
+ else:
63
+ dtype = torch.float32
64
 
65
+ # A conservative max for most Spaces GPUs. Increase if you know you have headroom.
66
+ MAX_IMAGE_SIZE = 1536 if cuda_available else 768
 
 
 
67
 
68
+ fallback_msg = ""
69
+ if not cuda_available:
70
+ fallback_msg = "GPU unavailable. Running in CPU fallback mode (slow)."
71
 
72
+ # ============================================================
73
+ # Load pipelines
74
+ # ============================================================
75
 
76
+ pipe_txt2img = None
77
+ pipe_img2img = None
78
+ model_loaded = False
79
+ load_error = None
80
 
81
def _try_load_with_from_pretrained():
    """
    Preferred loading path: build the text-to-image pipeline with a single
    Diffusers ``from_pretrained`` call on MODEL_PATH, then construct the
    image-to-image pipeline from the components of the first one.

    Returns:
        tuple: (txt2img pipeline, img2img pipeline).
    """
    load_options = dict(torch_dtype=dtype, use_safetensors=True)
    # Only forward a token when one was configured through the environment.
    if HF_TOKEN:
        load_options["token"] = HF_TOKEN

    txt2img = ZImagePipeline.from_pretrained(MODEL_PATH, **load_options)
    # The img2img pipeline is assembled from the txt2img pipeline's components.
    img2img = ZImageImg2ImgPipeline(**txt2img.components)
    return txt2img, img2img
96
+
97
def _fallback_manual_load():
    """
    Fallback loading path: load each component from its own subfolder and
    assemble both pipelines by hand.

    Expects MODEL_PATH to point to a repo with the subfolders
    vae/, transformer/, text_encoder/ and tokenizer/.

    Returns:
        tuple: (txt2img pipeline, img2img pipeline). Both are built from the
        same component objects and are created with ``scheduler=None``, so a
        scheduler must be assigned before they are run.
    """
    # Forward the token only when one is configured. ``token`` is the current
    # keyword (the same one _try_load_with_from_pretrained uses), replacing the
    # deprecated ``use_auth_token``. The previous code passed True when no token
    # was set, which demands a stored login instead of allowing anonymous access.
    auth_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}

    vae = AutoencoderKL.from_pretrained(
        MODEL_PATH,
        subfolder="vae",
        torch_dtype=dtype,
        **auth_kwargs,
    )
    text_encoder = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        subfolder="text_encoder",
        torch_dtype=dtype,
        **auth_kwargs,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        subfolder="tokenizer",
        **auth_kwargs,
    )
    tokenizer.padding_side = "left"

    # ZImageTransformer2DModel lives inside diffusers; importing lazily avoids
    # import issues on older versions.
    from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel

    transformer = ZImageTransformer2DModel.from_pretrained(
        MODEL_PATH,
        subfolder="transformer",
        torch_dtype=dtype,
        **auth_kwargs,
    )

    # Both pipelines receive the very same component objects.
    components = dict(
        scheduler=None,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
    )
    p_txt = ZImagePipeline(**components)
    p_img = ZImageImg2ImgPipeline(**components)
    return p_txt, p_img
137
 
138
+ try:
139
+ pipe_txt2img, pipe_img2img = _try_load_with_from_pretrained()
140
+ model_loaded = True
141
+ except Exception as e1:
142
  try:
143
+ pipe_txt2img, pipe_img2img = _fallback_manual_load()
144
+ model_loaded = True
145
+ except Exception as e2:
146
+ load_error = f"from_pretrained error: {repr(e1)}\nmanual_load error: {repr(e2)}"
147
+ model_loaded = False
 
 
 
148
 
149
# Post-load configuration. Every optional step stays best-effort (a failure
# never aborts start-up) but is now logged instead of being silently swallowed,
# so a misconfigured backend or a failed compile is visible in the logs.
if model_loaded:
    _setup_log = logging.getLogger(__name__)

    pipe_txt2img = pipe_txt2img.to(device)
    pipe_img2img = pipe_img2img.to(device)

    # Try attention backend (best-effort)
    try:
        if hasattr(pipe_txt2img, "transformer") and hasattr(pipe_txt2img.transformer, "set_attention_backend"):
            pipe_txt2img.transformer.set_attention_backend(ATTENTION_BACKEND)
            pipe_img2img.transformer.set_attention_backend(ATTENTION_BACKEND)
    except Exception as exc:
        _setup_log.warning("Could not set attention backend %r: %r", ATTENTION_BACKEND, exc)

    # Optional compile (best-effort, can break on some setups)
    if ENABLE_COMPILE and device.type == "cuda":
        try:
            pipe_txt2img.transformer = torch.compile(
                pipe_txt2img.transformer,
                mode="max-autotune-no-cudagraphs",
                fullgraph=False,
            )
            # Point the img2img pipeline at the same compiled module.
            pipe_img2img.transformer = pipe_txt2img.transformer
        except Exception as exc:
            _setup_log.warning("torch.compile failed; continuing without compilation: %r", exc)

    # Disable diffusers progress bars
    try:
        pipe_txt2img.set_progress_bar_config(disable=True)
        pipe_img2img.set_progress_bar_config(disable=True)
    except Exception as exc:
        _setup_log.warning("Could not disable pipeline progress bars: %r", exc)
175
 
176
+ # ============================================================
177
+ # Utility: error image
178
+ # ============================================================
179
+
180
def make_error_image(w, h):
    """Return a solid dark RGB placeholder image of size ``w`` x ``h``."""
    size = (w, h)
    dark_fill = (18, 18, 22)
    return Image.new("RGB", size, dark_fill)
182
+
183
+ def _prep_init_image(init_image, width, height):
184
+ if init_image is None:
185
+ return None
186
+ if not isinstance(init_image, Image.Image):
187
+ return None
188
+ init_image = init_image.convert("RGB")
189
+ if init_image.size != (width, height):
190
+ init_image = init_image.resize((width, height), Image.LANCZOS)
191
+ return init_image
192
+
193
+ # ============================================================
194
+ # Inference
195
+ # ============================================================
196
+
197
def _infer_impl(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    shift: float,
    max_sequence_length: int,
    init_image,
    strength: float,
):
    """
    Run one generation and return ``(image, status_text)``.

    Uses the img2img pipeline when a usable init image is supplied and the
    txt2img pipeline otherwise. Never raises for generation errors: when the
    model failed to load, the prompt is empty, or the pipeline call throws, a
    placeholder image is returned together with an explanatory message.

    Args:
        prompt: Text prompt; surrounding whitespace is stripped.
        negative_prompt: Only forwarded when guidance_scale is greater than 1.
        seed: Seed for the generator; replaced by a random one if randomize_seed.
        randomize_seed: Draw a fresh seed in [0, MAX_SEED] when true.
        width, height: Output size in pixels (coerced to int).
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps.
        shift: Shift value for the per-run FlowMatchEulerDiscreteScheduler.
        max_sequence_length: Forwarded to the pipeline as max_sequence_length.
        init_image: Optional PIL image; enables img2img when usable.
        strength: img2img strength; only used with an init image.
    """
    # Local import: the file's import header does not provide contextlib.
    from contextlib import nullcontext

    width = int(width)
    height = int(height)
    seed = int(seed)

    if not model_loaded:
        return make_error_image(width, height), f"Model load failed:\n\n{load_error}"

    prompt = (prompt or "").strip()
    if not prompt:
        return make_error_image(width, height), "Error: Prompt is empty."

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    init_image = _prep_init_image(init_image, width, height)

    generator = torch.Generator(device=device)
    generator = generator.manual_seed(seed)

    status = f"Seed: {seed}"
    if fallback_msg:
        status += f" | {fallback_msg}"

    # Set scheduler per-run because shift can change
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
    pipe_txt2img.scheduler = scheduler
    pipe_img2img.scheduler = scheduler

    try:
        use_negative = bool(guidance_scale) and float(guidance_scale) > 1.0
        common_kwargs = dict(
            prompt=prompt,
            negative_prompt=(negative_prompt or "").strip() if use_negative else None,
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            height=height,
            width=width,
            max_sequence_length=int(max_sequence_length),
        )

        # Choose the pipeline once instead of repeating the call per device branch.
        if init_image is not None:
            pipe = pipe_img2img
            call_kwargs = dict(image=init_image, strength=float(strength), **common_kwargs)
        else:
            pipe = pipe_txt2img
            call_kwargs = common_kwargs

        # Autocast is only entered on CUDA; on CPU a no-op context is used.
        if device.type == "cuda":
            autocast_ctx = torch.autocast("cuda", dtype=dtype)
        else:
            autocast_ctx = nullcontext()

        with torch.inference_mode(), autocast_ctx:
            out = pipe(**call_kwargs)

        image = out.images[0]
        return image, status

    except Exception as e:
        return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"

    finally:
        # Release memory after every attempt, successful or not.
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()
282
+
283
+ # IMPORTANT: decorator must be explicit
284
+ if SPACES_AVAILABLE:
285
+ @spaces.GPU
286
+ def infer(*args, **kwargs):
287
+ return _infer_impl(*args, **kwargs)
288
+ else:
289
+ def infer(*args, **kwargs):
290
+ return _infer_impl(*args, **kwargs)
291
+
292
+ # ============================================================
293
+ # UI
294
+ # ============================================================
295
+
296
+ CSS = """
297
+ body {
298
+ background: #000;
299
+ color: #fff;
300
+ }
301
+ """
302
 
303
+ with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
304
+ gr.HTML(f"<style>{CSS}</style>")
305
+
306
+ if fallback_msg:
307
+ gr.Markdown(f"**{fallback_msg}**")
308
+
309
+ if not model_loaded:
310
+ gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
311
+
312
+ gr.Markdown("## Z-Image Generator (txt2img + img2img)")
313
+
314
+ prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Describe what you want...")
315
+ init_image = gr.Image(label="Initial image (optional)", type="pil")
316
+
317
+ run_button = gr.Button("Generate")
318
+ result = gr.Image(label="Result")
319
+ status = gr.Markdown("")
320
+
321
+ with gr.Accordion("Advanced Settings", open=False):
322
+ negative_prompt = gr.Textbox(label="Negative prompt (only used if Guidance > 1)")
323
+ seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
324
+ randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
325
+
326
+ width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Width")
327
+ height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Height")
328
+
329
+ guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=0.0, label="Guidance scale")
330
+ num_inference_steps = gr.Slider(1, 50, step=1, value=8, label="Steps")
331
+ shift = gr.Slider(1.0, 10.0, step=0.1, value=3.0, label="Time shift")
332
+
333
+ max_sequence_length = gr.Slider(64, 512, step=64, value=512, label="Max sequence length")
334
+
335
+ strength = gr.Slider(0.0, 1.0, step=0.05, value=0.6, label="Image strength (img2img)")
336
+
337
+ run_button.click(
338
+ fn=infer,
339
+ inputs=[
340
+ prompt,
341
+ negative_prompt,
342
+ seed,
343
+ randomize_seed,
344
+ width,
345
+ height,
346
+ guidance_scale,
347
+ num_inference_steps,
348
+ shift,
349
+ max_sequence_length,
350
+ init_image,
351
+ strength,
352
+ ],
353
+ outputs=[result, status],
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  )
355
 
 
 
 
356
  if __name__ == "__main__":
357
+ # Keep the same launch feel as your first script
358
+ demo.queue().launch(ssr_mode=False)