# app.py — Hugging Face Space (author: victor, HF Staff; commit "Update app.py", rev 2e5f8c5, verified)
import spaces
from dataclasses import dataclass
import json
import logging
import os
import random
import re
import sys
import warnings
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from diffusers import ZImagePipeline
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
from pe import prompt_template
# ==================== Environment Variables ==================================
# Hub repo id or local directory containing the Z-Image checkpoint.
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
# When true, tune torch inductor settings and torch.compile the transformer.
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
# When true, run throwaway generations at every resolution on startup.
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
# Attention implementation passed to the transformer (e.g. "flash_3", "native").
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
# API key for the DashScope OpenAI-compatible endpoint (prompt enhancement).
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
# Hugging Face token for gated/private model downloads.
HF_TOKEN = os.environ.get("HF_TOKEN")
# =============================================================================
# Silence tokenizer fork warnings and noisy library logging.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
# Supported output resolutions, formatted as "WIDTHxHEIGHT ( RATIO )".
# get_resolution() parses the leading WIDTHxHEIGHT token.
RESOLUTIONS = [
    # Square
    "1024x1024 ( 1:1 )",
    # Landscape (wide to narrow)
    "1344x576 ( 21:9 )",
    "1280x720 ( 16:9 )",
    "1248x832 ( 3:2 )",
    "1152x864 ( 4:3 )",
    # Portrait (tall to short)
    "576x1344 ( 9:21 )",
    "720x1280 ( 9:16 )",
    "832x1248 ( 2:3 )",
    "864x1152 ( 3:4 )",
]
# Example prompts fed to gr.Examples (one-element rows for the prompt textbox).
# NOTE(review): the non-ASCII text below appears mis-encoded (mojibake of
# UTF-8 Chinese) — kept byte-for-byte; confirm how it renders in the UI.
EXAMPLE_PROMPTS = [
    ["ไธ€ไฝ็”ทๅฃซๅ’Œไป–็š„่ดตๅฎพ็Šฌ็ฉฟ็€้…ๅฅ—็š„ๆœ่ฃ…ๅ‚ๅŠ ็‹—็‹—็ง€๏ผŒๅฎคๅ†…็ฏๅ…‰๏ผŒ่ƒŒๆ™ฏไธญๆœ‰่ง‚ไผ—ใ€‚"],
    [
        "ๆžๅ…ทๆฐ›ๅ›ดๆ„Ÿ็š„ๆš—่ฐƒไบบๅƒ๏ผŒไธ€ไฝไผ˜้›…็š„ไธญๅ›ฝ็พŽๅฅณๅœจ้ป‘ๆš—็š„ๆˆฟ้—ด้‡Œใ€‚ไธ€ๆŸๅผบๅ…‰้€š่ฟ‡้ฎๅ…‰ๆฟ๏ผŒๅœจๅฅน็š„่„ธไธŠๆŠ•ๅฐ„ๅ‡บไธ€ไธชๆธ…ๆ™ฐ็š„้—ช็”ตๅฝข็Šถ็š„ๅ…‰ๅฝฑ๏ผŒๆญฃๅฅฝ็…งไบฎไธ€ๅช็œผ็›ใ€‚้ซ˜ๅฏนๆฏ”ๅบฆ๏ผŒๆ˜Žๆš—ไบค็•Œๆธ…ๆ™ฐ๏ผŒ็ฅž็ง˜ๆ„Ÿ๏ผŒ่Žฑๅก็›ธๆœบ่‰ฒ่ฐƒใ€‚"
    ],
    [
        "ไธ€ๅผ ไธญๆ™ฏๆ‰‹ๆœบ่‡ชๆ‹็…ง็‰‡ๆ‹ๆ‘„ไบ†ไธ€ไฝ็•™็€้•ฟ้ป‘ๅ‘็š„ๅนด่ฝปไธœไบšๅฅณๅญๅœจ็ฏๅ…‰ๆ˜Žไบฎ็š„็”ตๆขฏๅ†…ๅฏน็€้•œๅญ่‡ชๆ‹ใ€‚ๅฅน็ฉฟ็€ไธ€ไปถๅธฆๆœ‰็™ฝ่‰ฒ่Šฑๆœตๅ›พๆกˆ็š„้ป‘่‰ฒ้œฒ่‚ฉ็ŸญไธŠ่กฃๅ’Œๆทฑ่‰ฒ็‰›ไป”่ฃคใ€‚ๅฅน็š„ๅคดๅพฎๅพฎๅ€พๆ–œ๏ผŒๅ˜ดๅ”‡ๅ˜Ÿ่ตทๅšไบฒๅป็Šถ๏ผŒ้žๅธธๅฏ็ˆฑไฟ็šฎใ€‚ๅฅนๅณๆ‰‹ๆ‹ฟ็€ไธ€้ƒจๆทฑ็ฐ่‰ฒๆ™บ่ƒฝๆ‰‹ๆœบ๏ผŒ้ฎไฝไบ†้ƒจๅˆ†่„ธ๏ผŒๅŽ็ฝฎๆ‘„ๅƒๅคด้•œๅคดๅฏน็€้•œๅญ"
    ],
    [
        "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (โšก๏ธ), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (่ฅฟๅฎ‰ๅคง้›ๅก”), blurred colorful distant lights."
    ],
    [
        '''A vertical digital illustration depicting a serene and majestic Chinese landscape, rendered in a style reminiscent of traditional Shanshui painting but with a modern, clean aesthetic. The scene is dominated by towering, steep cliffs in various shades of blue and teal, which frame a central valley. In the distance, layers of mountains fade into a light blue and white mist, creating a strong sense of atmospheric perspective and depth. A calm, turquoise river flows through the center of the composition, with a small, traditional Chinese boat, possibly a sampan, navigating its waters. The boat has a bright yellow canopy and a red hull, and it leaves a gentle wake behind it. It carries several indistinct figures of people. Sparse vegetation, including green trees and some bare-branched trees, clings to the rocky ledges and peaks. The overall lighting is soft and diffused, casting a tranquil glow over the entire scene. Centered in the image is overlaid text. At the top of the text block is a small, red, circular seal-like logo containing stylized characters. Below it, in a smaller, black, sans-serif font, are the words 'Zao-Xiang * East Beauty & West Fashion * Z-Image'. Directly beneath this, in a larger, elegant black serif font, is the word 'SHOW & SHARE CREATIVITY WITH THE WORLD'. Among them, there are "SHOW & SHARE", "CREATIVITY", and "WITH THE WORLD"'''
    ],
    [
        """ไธ€ๅผ ่™šๆž„็š„่‹ฑ่ฏญ็”ตๅฝฑใ€Šๅ›žๅฟ†ไน‹ๅ‘ณใ€‹๏ผˆThe Taste of Memory๏ผ‰็š„็”ตๅฝฑๆตทๆŠฅใ€‚ๅœบๆ™ฏ่ฎพ็ฝฎๅœจไธ€ไธช่ดจๆœด็š„19ไธ–็บช้ฃŽๆ ผๅŽจๆˆฟ้‡Œใ€‚็”ป้ขไธญๅคฎ๏ผŒไธ€ไฝ็บขๆฃ•่‰ฒๅคดๅ‘ใ€็•™็€ๅฐ่ƒกๅญ็š„ไธญๅนด็”ทๅญ๏ผˆๆผ”ๅ‘˜้˜ฟ็‘Ÿยทๅฝญๅ“ˆๅˆฉๆ น้ฅฐ๏ผ‰็ซ™ๅœจไธ€ๅผ ๆœจๆกŒๅŽ๏ผŒไป–่บซ็ฉฟ็™ฝ่‰ฒ่กฌ่กซใ€้ป‘่‰ฒ้ฉฌ็”ฒๅ’Œ็ฑณ่‰ฒๅ›ด่ฃ™๏ผŒๆญฃ็œ‹็€ไธ€ไฝๅฅณๅฃซ๏ผŒๆ‰‹ไธญๆ‹ฟ็€ไธ€ๅคงๅ—็”Ÿ็บข่‚‰๏ผŒไธ‹ๆ–นๆ˜ฏไธ€ไธชๆœจๅˆถๅˆ‡่œๆฟใ€‚ๅœจไป–็š„ๅณ่พน๏ผŒไธ€ไฝๆขณ็€้ซ˜้ซป็š„้ป‘ๅ‘ๅฅณๅญ๏ผˆๆผ”ๅ‘˜ๅŸƒ่މ่ฏบยทไธ‡ๆ–ฏ้ฅฐ๏ผ‰ๅ€š้ ๅœจๆกŒๅญไธŠ๏ผŒๆธฉๆŸ”ๅœฐๅฏนไป–ๅพฎ็ฌ‘ใ€‚ๅฅน็ฉฟ็€ๆต…่‰ฒ่กฌ่กซๅ’Œไธ€ๆกไธŠ็™ฝไธ‹่“็š„้•ฟ่ฃ™ใ€‚ๆกŒไธŠ้™คไบ†ๆ”พๆœ‰ๅˆ‡็ขŽ็š„่‘ฑๅ’Œๅทๅฟƒ่œไธ็š„ๅˆ‡่œๆฟๅค–๏ผŒ่ฟ˜ๆœ‰ไธ€ไธช็™ฝ่‰ฒ้™ถ็“ท็›˜ใ€ๆ–ฐ้ฒœ้ฆ™่‰๏ผŒๅทฆไพงไธ€ไธชๆœจ็ฎฑไธŠๆ”พ็€ไธ€ไธฒๆทฑ่‰ฒ่‘ก่„ใ€‚่ƒŒๆ™ฏๆ˜ฏไธ€้ข็ฒ—็ณ™็š„็ฐ็™ฝ่‰ฒๆŠน็ฐๅข™๏ผŒๅข™ไธŠๆŒ‚็€ไธ€ๅน…้ฃŽๆ™ฏ็”ปใ€‚ๆœ€ๅณ่พน็š„ไธ€ไธชๅฐ้ขไธŠๆ”พ็€ไธ€็›ๅคๅคๆฒน็ฏใ€‚ๆตทๆŠฅไธŠๆœ‰ๅคง้‡็š„ๆ–‡ๅญ—ไฟกๆฏใ€‚ๅทฆไธŠ่ง’ๆ˜ฏ็™ฝ่‰ฒ็š„ๆ— ่กฌ็บฟๅญ—ไฝ“"ARTISAN FILMS PRESENTS"๏ผŒๅ…ถไธ‹ๆ–นๆ˜ฏ"ELEANOR VANCE"ๅ’Œ"ACADEMY AWARDยฎ WINNER"ใ€‚ๅณไธŠ่ง’ๅ†™็€"ARTHUR PENHALIGON"ๅ’Œ"GOLDEN GLOBEยฎ AWARD WINNER"ใ€‚้กถ้ƒจไธญๅคฎๆ˜ฏๅœฃไธนๆ–ฏ็”ตๅฝฑ่Š‚็š„ๆก‚ๅ† ๆ ‡ๅฟ—๏ผŒไธ‹ๆ–นๅ†™็€"SUNDANCE FILM FESTIVAL GRAND JURY PRIZE 2024"ใ€‚ไธปๆ ‡้ข˜"THE TASTE OF MEMORY"ไปฅ็™ฝ่‰ฒ็š„ๅคงๅท่กฌ็บฟๅญ—ไฝ“้†’็›ฎๅœฐๆ˜พ็คบๅœจไธ‹ๅŠ้ƒจๅˆ†ใ€‚ๆ ‡้ข˜ไธ‹ๆ–นๆณจๆ˜Žไบ†"A FILM BY Tongyi Interaction Lab"ใ€‚ๅบ•้ƒจๅŒบๅŸŸ็”จ็™ฝ่‰ฒๅฐๅญ—ๅˆ—ๅ‡บไบ†ๅฎŒๆ•ด็š„ๆผ”่Œๅ‘˜ๅๅ•๏ผŒๅŒ…ๆ‹ฌ"SCREENPLAY BY ANNA REID"ใ€"CULINARY DIRECTION BY JAMES CARTER"ไปฅๅŠArtisan Filmsใ€Riverstone Picturesๅ’ŒHeritage Media็ญ‰ไผ—ๅคšๅ‡บๅ“ๅ…ฌๅธๆ ‡ๅฟ—ใ€‚ๆ•ดไฝ“้ฃŽๆ ผๆ˜ฏๅ†™ๅฎžไธปไน‰๏ผŒ้‡‡็”จๆธฉๆš–ๆŸ”ๅ’Œ็š„็ฏๅ…‰ๆ–นๆกˆ๏ผŒ่ฅ้€ ๅ‡บไธ€็งไบฒๅฏ†็š„ๆฐ›ๅ›ดใ€‚่‰ฒ่ฐƒไปฅๆฃ•่‰ฒใ€็ฑณ่‰ฒๅ’ŒๆŸ”ๅ’Œ็š„็ปฟ่‰ฒ็ญ‰ๅคงๅœฐ่‰ฒ็ณปไธบไธปใ€‚ไธคไฝๆผ”ๅ‘˜็š„่บซไฝ“้ƒฝๅœจ่…ฐ้ƒจ่ขซๆˆชๆ–ญใ€‚"""
    ],
    [
        """ไธ€ๅผ ๆ–นๅฝขๆž„ๅ›พ็š„็‰นๅ†™็…ง็‰‡๏ผŒไธปไฝ“ๆ˜ฏไธ€็‰‡ๅทจๅคง็š„ใ€้ฒœ็ปฟ่‰ฒ็š„ๆค็‰ฉๅถ็‰‡๏ผŒๅนถๅ ๅŠ ไบ†ๆ–‡ๅญ—๏ผŒไฝฟๅ…ถๅ…ทๆœ‰ๆตทๆŠฅๆˆ–ๆ‚ๅฟ—ๅฐ้ข็š„ๅค–่ง‚ใ€‚ไธป่ฆๆ‹ๆ‘„ๅฏน่ฑกๆ˜ฏไธ€็‰‡ๅŽšๅฎžใ€ๆœ‰่œก่ดจๆ„Ÿ็š„ๅถๅญ๏ผŒไปŽๅทฆไธ‹่ง’ๅˆฐๅณไธŠ่ง’ๅ‘ˆๅฏน่ง’็บฟๅผฏๆ›ฒ็ฉฟ่ฟ‡็”ป้ขใ€‚ๅ…ถ่กจ้ขๅๅ…‰ๆ€งๅพˆๅผบ๏ผŒๆ•ๆ‰ๅˆฐไธ€ไธชๆ˜Žไบฎ็š„็›ดๅฐ„ๅ…‰ๆบ๏ผŒๅฝขๆˆไบ†ไธ€้“็ชๅ‡บ็š„้ซ˜ๅ…‰๏ผŒไบฎ้ขไธ‹ๆ˜พ้œฒๅ‡บๅนณ่กŒ็š„็ฒพ็ป†ๅถ่„‰ใ€‚่ƒŒๆ™ฏ็”ฑๅ…ถไป–ๆทฑ็ปฟ่‰ฒ็š„ๅถๅญ็ป„ๆˆ๏ผŒ่ฟ™ไบ›ๅถๅญ่ฝปๅพฎๅคฑ็„ฆ๏ผŒ่ฅ้€ ๅ‡บๆต…ๆ™ฏๆทฑๆ•ˆๆžœ๏ผŒ็ชๅ‡บไบ†ๅ‰ๆ™ฏ็š„ไธปๅถ็‰‡ใ€‚ๆ•ดไฝ“้ฃŽๆ ผๆ˜ฏๅ†™ๅฎžๆ‘„ๅฝฑ๏ผŒๆ˜Žไบฎ็š„ๅถ็‰‡ไธŽ้ป‘ๆš—็š„้˜ดๅฝฑ่ƒŒๆ™ฏไน‹้—ดๅฝขๆˆ้ซ˜ๅฏนๆฏ”ๅบฆใ€‚ๅ›พๅƒไธŠๆœ‰ๅคšๅค„ๆธฒๆŸ“ๆ–‡ๅญ—ใ€‚ๅทฆไธŠ่ง’ๆ˜ฏ็™ฝ่‰ฒ็š„่กฌ็บฟๅญ—ไฝ“ๆ–‡ๅญ—"PIXEL-PEEPERS GUILD Presents"ใ€‚ๅณไธŠ่ง’ๅŒๆ ทๆ˜ฏ็™ฝ่‰ฒ่กฌ็บฟๅญ—ไฝ“็š„ๆ–‡ๅญ—"[Instant Noodle] ๆณก้ข่ฐƒๆ–™ๅŒ…"ใ€‚ๅทฆไพงๅž‚็›ดๆŽ’ๅˆ—็€ๆ ‡้ข˜"Render Distance: Max"๏ผŒไธบ็™ฝ่‰ฒ่กฌ็บฟๅญ—ไฝ“ใ€‚ๅทฆไธ‹่ง’ๆ˜ฏไบ”ไธช็ก•ๅคง็š„็™ฝ่‰ฒๅฎ‹ไฝ“ๆฑ‰ๅญ—"ๆ˜พๅกๅœจ...็‡ƒ็ƒง"ใ€‚ๅณไธ‹่ง’ๆ˜ฏ่พƒๅฐ็š„็™ฝ่‰ฒ่กฌ็บฟๅญ—ไฝ“ๆ–‡ๅญ—"Leica Glowโ„ข Unobtanium X-1"๏ผŒๅ…ถๆญฃไธŠๆ–นๆ˜ฏ็”จ็™ฝ่‰ฒๅฎ‹ไฝ“ๅญ—ไนฆๅ†™็š„ๅๅญ—"่”กๅ‡ "ใ€‚่ฏ†ๅˆซๅ‡บ็š„ๆ ธๅฟƒๅฎžไฝ“ๅŒ…ๆ‹ฌๅ“็‰Œๅƒ็ด ๅท็ชฅ่€…ๅไผšใ€ๅ…ถไบงๅ“็บฟๆณก้ข่ฐƒๆ–™ๅŒ…ใ€็›ธๆœบๅž‹ๅทไนฐไธๅˆฐโ„ข X-1ไปฅๅŠๆ‘„ๅฝฑๅธˆๅๅญ—้€ ็›ธใ€‚"""
    ],
]
def get_resolution(resolution):
    """Parse a "WIDTHxHEIGHT ..." resolution string into (width, height).

    Accepts either an ASCII "x" or the Unicode multiplication sign "×" as
    the separator, with optional surrounding whitespace. Falls back to
    (1024, 1024) when the string cannot be parsed.

    Args:
        resolution: A string such as "1024x1024 ( 1:1 )" or "1280 × 720".

    Returns:
        A (width, height) tuple of ints.
    """
    # Fix: the original character class contained mojibake ("ร—") where the
    # Unicode multiplication sign "×" was intended.
    match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if match:
        return int(match.group(1)), int(match.group(2))
    return 1024, 1024
def load_models(model_path, enable_compile=False, attention_backend="native"):
    """Load VAE, text encoder, tokenizer, and transformer into a ZImagePipeline.

    Args:
        model_path: Hugging Face Hub repo id (used when the path does not
            exist locally) or a local checkpoint directory.
        enable_compile: When True, tune torch inductor settings and wrap the
            transformer in torch.compile.
        attention_backend: Attention implementation name passed to the
            transformer (e.g. "native", "flash_3").

    Returns:
        A ZImagePipeline with all components on CUDA in bfloat16. The
        scheduler slot is left None; generate_image() installs one per call.
    """
    print(f"Loading models from {model_path}...")
    # NOTE(review): `use_auth_token` is the legacy spelling of the `token`
    # kwarg in from_pretrained — confirm the pinned diffusers/transformers
    # versions still accept it.
    use_auth_token = HF_TOKEN if HF_TOKEN else True
    if not os.path.exists(model_path):
        # Path is not a local directory: treat it as a Hub repo id and load
        # each component from its subfolder.
        vae = AutoencoderKL.from_pretrained(
            f"{model_path}",
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            use_auth_token=use_auth_token,
        )
        text_encoder = AutoModel.from_pretrained(
            f"{model_path}",
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            use_auth_token=use_auth_token,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            f"{model_path}", subfolder="tokenizer", use_auth_token=use_auth_token
        )
    else:
        # Local checkpoint directory: load components from subdirectories.
        vae = AutoencoderKL.from_pretrained(
            os.path.join(model_path, "vae"),
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        text_encoder = AutoModel.from_pretrained(
            os.path.join(model_path, "text_encoder"),
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
    # Pad prompt batches on the left.
    tokenizer.padding_side = "left"
    if enable_compile:
        # Aggressive inductor autotuning; CUDA graphs are disabled explicitly
        # (the compile mode below is also the no-cudagraphs variant).
        print("Enabling torch.compile optimizations...")
        torch._inductor.config.conv_1x1_as_mm = True
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.epilogue_fusion = False
        torch._inductor.config.coordinate_descent_check_all_directions = True
        torch._inductor.config.max_autotune_gemm = True
        torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
        torch._inductor.config.triton.cudagraphs = False
    # Transformer is attached after pipeline construction (loaded below);
    # scheduler stays None and is created per-request in generate_image().
    pipe = ZImagePipeline(
        scheduler=None,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=None,
    )
    if enable_compile:
        # presumably avoids tiling-induced graph breaks under compile — TODO confirm
        pipe.vae.disable_tiling()
    if not os.path.exists(model_path):
        transformer = ZImageTransformer2DModel.from_pretrained(
            f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token
        ).to("cuda", torch.bfloat16)
    else:
        transformer = ZImageTransformer2DModel.from_pretrained(
            os.path.join(model_path, "transformer")
        ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer
    pipe.transformer.set_attention_backend(attention_backend)
    if enable_compile:
        print("Compiling transformer...")
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False
        )
    pipe.to("cuda", torch.bfloat16)
    return pipe
def generate_image(
    pipe,
    prompt,
    resolution="1024x1024",
    seed=42,
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    max_sequence_length=512,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one text-to-image pass through the pipeline and return the image.

    A fresh flow-matching scheduler is installed on the pipeline for every
    call so the requested time ``shift`` always takes effect.
    """
    target_w, target_h = get_resolution(resolution)
    # Seed the CUDA generator for reproducible sampling.
    rng = torch.Generator("cuda").manual_seed(seed)
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000, shift=shift
    )
    result = pipe(
        prompt=prompt,
        height=target_h,
        width=target_w,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=rng,
        max_sequence_length=max_sequence_length,
    )
    return result.images[0]
def warmup_model(pipe, resolutions):
    """Prime compiled kernels and caches with throwaway generations.

    Runs three short generations at every supported resolution; a failure at
    one resolution is logged and skipped so it cannot abort startup.
    """
    print("Starting warmup phase...")
    for res_str in resolutions:
        print(f"Warming up for resolution: {res_str}")
        try:
            for attempt in range(3):
                generate_image(
                    pipe,
                    prompt="warmup",
                    resolution=res_str,
                    num_inference_steps=9,
                    guidance_scale=0.0,
                    seed=42 + attempt,
                )
        except Exception as e:
            print(f"Warmup failed for {res_str}: {e}")
    print("Warmup completed.")
# ==================== Prompt Expander ====================
@dataclass
class PromptOutput:
    """Result of a prompt-expansion request."""

    status: bool  # True when expansion succeeded
    prompt: str  # expanded prompt ("" on failure)
    seed: int  # seed passed through from the caller
    system_prompt: str  # system prompt actually sent to the LLM
    message: str  # raw model response, or an error description on failure
class PromptExpander:
    """Base class for prompt-expansion backends."""

    def __init__(self, backend="api", **kwargs):
        # Record which backend implementation this expander represents;
        # extra keyword arguments are accepted and ignored.
        self.backend = backend

    def decide_system_prompt(self, template_name=None):
        """Return the system prompt template (`template_name` is ignored)."""
        return prompt_template
class APIPromptExpander(PromptExpander):
    """Prompt expander backed by an OpenAI-compatible chat-completions API
    (DashScope by default)."""

    def __init__(self, api_config=None, **kwargs):
        super().__init__(backend="api", **kwargs)
        self.api_config = api_config or {}
        # May be None when the `openai` package or an API key is unavailable;
        # extend() then reports a failed PromptOutput instead of raising.
        self.client = self._init_api_client()

    def _init_api_client(self):
        """Build the OpenAI-compatible client, or return None on any failure."""
        try:
            from openai import OpenAI

            api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
            base_url = self.api_config.get(
                "base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1"
            )
            if not api_key:
                print("Warning: DASHSCOPE_API_KEY not found.")
                return None
            return OpenAI(api_key=api_key, base_url=base_url)
        except ImportError:
            print("Please install openai: pip install openai")
            return None
        except Exception as e:
            print(f"Failed to initialize API client: {e}")
            return None

    def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
        return self.extend(prompt, system_prompt, seed, **kwargs)

    def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
        """Ask the LLM to rewrite `prompt`.

        Returns:
            PromptOutput: status=True with the expanded prompt on success;
            status=False with an error message on any failure (no client,
            API error, etc.). Never raises.
        """
        if self.client is None:
            return PromptOutput(
                False, "", seed, system_prompt, "API client not initialized"
            )
        if system_prompt is None:
            system_prompt = self.decide_system_prompt()
        if "{prompt}" in system_prompt:
            # Template embeds the user prompt directly; send a blank user turn.
            system_prompt = system_prompt.format(prompt=prompt)
            prompt = " "
        try:
            model = self.api_config.get("model", "qwen3-max-preview")
            response = self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                top_p=0.8,
            )
            content = response.choices[0].message.content
            expanded_prompt = self._extract_revised_prompt(content)
            return PromptOutput(
                status=True,
                prompt=expanded_prompt,
                seed=seed,
                system_prompt=system_prompt,
                message=content,
            )
        except Exception as e:
            return PromptOutput(False, "", seed, system_prompt, str(e))

    @staticmethod
    def _extract_revised_prompt(content):
        """Pull `revised_prompt` out of a ```json fenced block, if present.

        Falls back to the raw content when there is no fenced JSON block, the
        closing fence is missing, the JSON is malformed, or it is not an
        object.
        """
        json_start = content.find("```json")
        if json_start == -1:
            return content
        json_end = content.find("```", json_start + 7)
        # Fix: the original sliced with json_end == -1 (silently dropping the
        # last character) and caught everything with a bare `except:`.
        if json_end == -1:
            return content
        try:
            data = json.loads(content[json_start + 7 : json_end].strip())
            return data.get("revised_prompt", content)
        except (json.JSONDecodeError, AttributeError, TypeError):
            return content
def create_prompt_expander(backend="api", **kwargs):
    """Factory for prompt expanders; only the "api" backend exists today.

    Raises:
        ValueError: for any backend other than "api".
    """
    if backend != "api":
        raise ValueError("Only 'api' backend is supported.")
    return APIPromptExpander(**kwargs)
# Module-level singletons populated by init_app(); None while unavailable.
pipe = None
prompt_expander = None
def _init_pipeline():
    """Load the generation pipeline (plus optional warmup); None on failure."""
    try:
        loaded = load_models(
            MODEL_PATH,
            enable_compile=ENABLE_COMPILE,
            attention_backend=ATTENTION_BACKEND,
        )
        print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
        if ENABLE_WARMUP:
            warmup_model(loaded, RESOLUTIONS)
        return loaded
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


def _init_expander():
    """Create the DashScope-backed prompt expander; None on failure."""
    try:
        expander = create_prompt_expander(
            backend="api", api_config={"model": "qwen3-max-preview"}
        )
        print("Prompt expander initialized.")
        return expander
    except Exception as e:
        print(f"Error initializing prompt expander: {e}")
        return None


def init_app():
    """Populate the module-level `pipe` and `prompt_expander` singletons.

    Each piece is initialized independently; a failure leaves the
    corresponding global as None so the app can still start degraded.
    """
    global pipe, prompt_expander
    pipe = _init_pipeline()
    prompt_expander = _init_expander()
def prompt_enhance(prompt, enable_enhance):
    """Optionally expand `prompt` via the global prompt expander.

    Returns:
        (prompt_to_use, status_message). The original prompt is returned
        unchanged whenever enhancement is off, unavailable, or fails.
    """
    if not enable_enhance or not prompt_expander:
        return prompt, "Enhancement disabled or not available."
    if not prompt.strip():
        return "", "Please enter a prompt."
    try:
        result = prompt_expander(prompt)
    except Exception as e:
        return prompt, f"Error: {str(e)}"
    if result.status:
        return result.prompt, result.message
    return prompt, f"Enhancement failed: {result.message}"
@spaces.GPU
def generate(
    prompt: str,
    resolution: str = "1024x1024 ( 1:1 )",
    seed: int = 42,
    steps: int = 9,
    shift: float = 3.0,
    enhance: bool = False,
    random_seed: bool = True,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """
    Generate a single image using the Z-Image model based on the provided
    prompt and settings.

    This function is exposed as a Gradio/MCP tool via the main 'Generate'
    button. It optionally enhances the prompt, resolves the seed, parses the
    resolution dropdown value, and delegates to generate_image().

    Args:
        prompt: Text prompt describing the desired image content.
        resolution: Output resolution in format "WIDTHxHEIGHT ( RATIO )".
        seed: Seed for reproducible generation. Ignored if random_seed is True.
        steps: Number of inference steps for the diffusion process.
        shift: Time shift parameter for the flow matching scheduler.
        enhance: (Currently disabled in the UI) Whether to enhance the prompt.
        random_seed: If True, a new random seed will be sampled.
        progress: Gradio progress tracker (automatically provided by Gradio).

    Returns:
        tuple[object, str, int]: (image, seed_str, seed_int)
    """
    if pipe is None:
        raise gr.Error("Model not loaded.")

    final_prompt = prompt
    if enhance:
        final_prompt, _ = prompt_enhance(prompt, True)
        print(f"Enhanced prompt: {final_prompt}")

    # Resolve the seed: sample a fresh one when requested, or when the caller
    # passed the -1 sentinel.
    if random_seed or seed == -1:
        chosen_seed = random.randint(1, 1000000)
    else:
        chosen_seed = seed

    # The dropdown value looks like "1024x1024 ( 1:1 )"; keep the first token.
    try:
        res_token = resolution.split(" ")[0]
    except Exception:
        res_token = "1024x1024"

    image = generate_image(
        pipe=pipe,
        prompt=final_prompt,
        resolution=res_token,
        seed=chosen_seed,
        guidance_scale=0.0,
        num_inference_steps=int(steps + 1),  # slider value + 1 inference steps
        shift=shift,
    )
    return image, str(chosen_seed), int(chosen_seed)
init_app()

# ==================== AoTI (Ahead of Time Inductor compilation) ====================
# Fix: init_app() leaves `pipe` as None when model loading fails; the original
# code then crashed here with AttributeError. Guard so the UI can still start
# (generate() already reports "Model not loaded." in that case).
if pipe is not None:
    pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
    spaces.aoti_blocks_load(
        pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3"
    )
# Gradio UI definition; `demo` is launched from the __main__ guard below.
with gr.Blocks(title="Z-Image Generation MCP") as demo:
    # Header with the "Use via MCP" install badge.
    gr.Markdown(
        """<div align="center">
# Z-Image Generation MCP
<a href="https://huggingface.co/settings/mcp?add=victor/Z-Image-Turbo-MCP" target="_blank"
style="display: inline-block; padding: 8px 20px; background: #22c55e;
color: white; text-decoration: none; border-radius: 9999px; font-weight: 600;">
Use via MCP
</a>
*An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*
</div>"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt", lines=3, placeholder="Enter your prompt here..."
            )
            # PE components (Temporarily disabled)
            # with gr.Row():
            #     enable_enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=False)
            #     enhance_btn = gr.Button("Enhance Only")
            resolution = gr.Dropdown(
                value=RESOLUTIONS[0],
                choices=RESOLUTIONS,
                label="Resolution",
            )
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                random_seed = gr.Checkbox(label="Random Seed", value=True)
            with gr.Row():
                # Steps slider is fixed (non-interactive); generate() passes
                # steps + 1 to the pipeline.
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=100,
                    value=8,
                    step=1,
                    interactive=False,
                )
                shift = gr.Slider(
                    label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1
                )
            generate_btn = gr.Button("Generate", variant="primary")
            # Example prompts
            gr.Markdown("### ๐Ÿ“ Example Prompts")
            gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
        with gr.Column(scale=1):
            # Switched from Gallery -> single Image for MCP-friendly output
            output_image = gr.Image(
                label="Generated Image",
                format="png",
                height=600,
                interactive=False,
            )
            used_seed = gr.Textbox(label="Seed Used", interactive=False)
    # Dummy enable_enhance variable set to False
    enable_enhance = gr.State(value=False)
    # Wire the button to generate(); the seed Number is both an input and an
    # output so the sampled seed is written back into the UI.
    generate_btn.click(
        generate,
        inputs=[
            prompt_input,
            resolution,
            seed,
            steps,
            shift,
            enable_enhance,
            random_seed,
        ],
        outputs=[output_image, used_seed, seed],
        api_visibility="public",  # exposed as MCP tool
        api_name="generate_image",  # nice, stable name for tool clients
    )
# CSS injected into the app at launch time.
css = """
.fillable{max-width: 1230px !important}
"""
if __name__ == "__main__":
    # NOTE(review): css is passed to launch() rather than gr.Blocks(css=...);
    # confirm the pinned Gradio version's launch() accepts a css kwarg.
    demo.launch(css=css, mcp_server=True)