Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import torch | |
| import os | |
| import json | |
| import random | |
| import sys | |
| import logging | |
| import warnings | |
| import re | |
| import spaces | |
| from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler | |
| from transformers import AutoModel, AutoTokenizer | |
| from dataclasses import dataclass | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| from diffusers import ZImagePipeline | |
| from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel | |
| from pe import prompt_template | |
# ==================== Environment Variables ==================================
def _env_flag(name, default):
    """Return True only if env var *name* (or *default* when unset) equals "true", ignoring case."""
    return os.environ.get(name, default).lower() == "true"


# Hub repo id or local directory holding the checkpoint (consumed by load_models).
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
ENABLE_COMPILE = _env_flag("ENABLE_COMPILE", "true")
ENABLE_WARMUP = _env_flag("ENABLE_WARMUP", "true")
# Name handed to transformer.set_attention_backend().
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
# Optional credentials; None when the variable is not set.
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
# =============================================================================

# Process-wide settings applied when this module is executed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")  # NOTE: silences every Python warning in the process.
logging.getLogger("transformers").setLevel(logging.ERROR)
# Flat list of the 1024-class resolution labels ("<dims> (<ratio> )").
# Ratios are right-aligned to width 4, which reproduces the original padding
# ("( 1:1 )" but "(16:9 )"); note this differs from the "( 16:9 )" spacing used
# in RES_CHOICES. Not referenced by the other code visible in this module.
RESOLUTION_SET = [
    f"{dims} ({ratio:>4} )"
    for dims, ratio in (
        ("1024x1024", "1:1"),
        ("1152x896", "9:7"),
        ("896x1152", "7:9"),
        ("1152x864", "4:3"),
        ("864x1152", "3:4"),
        ("1248x832", "3:2"),
        ("832x1248", "2:3"),
        ("1280x720", "16:9"),
        ("720x1280", "9:16"),
        ("1344x576", "21:9"),
        ("576x1344", "9:21"),
    )
]
# Resolution presets grouped by base size. Keys are the category values offered
# by the "Resolution Category" dropdown (stored as strings); each label has the
# form "<dims> ( <ratio> )", and only the "<dims>" part is used for generation.
RES_CHOICES = {
    "1024": [
        f"{dims} ( {ratio} )"
        for dims, ratio in (
            ("1024x1024", "1:1"),
            ("1152x896", "9:7"),
            ("896x1152", "7:9"),
            ("1152x864", "4:3"),
            ("864x1152", "3:4"),
            ("1248x832", "3:2"),
            ("832x1248", "2:3"),
            ("1280x720", "16:9"),
            ("720x1280", "9:16"),
            ("1344x576", "21:9"),
            ("576x1344", "9:21"),
        )
    ],
    "1280": [
        f"{dims} ( {ratio} )"
        for dims, ratio in (
            ("1280x1280", "1:1"),
            ("1440x1120", "9:7"),
            ("1120x1440", "7:9"),
            ("1472x1104", "4:3"),
            ("1104x1472", "3:4"),
            ("1536x1024", "3:2"),
            ("1024x1536", "2:3"),
            ("1600x900", "16:9"),
            ("900x1600", "9:16"),
            ("1680x720", "21:9"),
            ("720x1680", "9:21"),
        )
    ],
}
# Example prompts offered in the UI via gr.Examples. Each entry is a
# one-element list because the Examples widget is bound to a single input
# component (prompt_input).
EXAMPLE_PROMPTS = [
    # English: Shanshui-style landscape poster with overlaid text.
    ['''A vertical digital illustration depicting a serene and majestic Chinese landscape, rendered in a style reminiscent of traditional Shanshui painting but with a modern, clean aesthetic. The scene is dominated by towering, steep cliffs in various shades of blue and teal, which frame a central valley. In the distance, layers of mountains fade into a light blue and white mist, creating a strong sense of atmospheric perspective and depth. A calm, turquoise river flows through the center of the composition, with a small, traditional Chinese boat, possibly a sampan, navigating its waters. The boat has a bright yellow canopy and a red hull, and it leaves a gentle wake behind it. It carries several indistinct figures of people. Sparse vegetation, including green trees and some bare-branched trees, clings to the rocky ledges and peaks. The overall lighting is soft and diffused, casting a tranquil glow over the entire scene. Centered in the image is overlaid text. At the top of the text block is a small, red, circular seal-like logo containing stylized characters. Below it, in a smaller, black, sans-serif font, are the words 'Zao-Xiang * East Beauty & West Fashion * Z-Image'. Directly beneath this, in a larger, elegant black serif font, is the word 'SHOW & SHARE CREATIVITY WITH THE WORLD'. Among them, there are "SHOW & SHARE", "CREATIVITY", and "WITH THE WORLD"'''],
    # Chinese: low-key portrait lit by a lightning-shaped beam across one eye.
    ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"],
    # Chinese: phone mirror selfie in a brightly lit elevator.
    ["一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。她穿着一件带有白色花朵图案的黑色露肩短上衣和深色牛仔裤。她的头微微倾斜,嘴唇嘟起做亲吻状,非常可爱俏皮。她右手拿着一部深灰色智能手机,遮住了部分脸,后置摄像头镜头对着镜子"],
    # Chinese: text-heavy poster for a fictional film, "The Taste of Memory".
    ['''一张虚构的英语电影《回忆之味》(The Taste of Memory)的电影海报。场景设置在一个质朴的19世纪风格厨房里。画面中央,一位红棕色头发、留着小胡子的中年男子(演员阿瑟·彭哈利根饰)站在一张木桌后,他身穿白色衬衫、黑色马甲和米色围裙,正看着一位女士,手中拿着一大块生红肉,下方是一个木制切菜板。在他的右边,一位梳着高髻的黑发女子(演员埃莉诺·万斯饰)倚靠在桌子上,温柔地对他微笑。她穿着浅色衬衫和一条上白下蓝的长裙。桌上除了放有切碎的葱和卷心菜丝的切菜板外,还有一个白色陶瓷盘、新鲜香草,左侧一个木箱上放着一串深色葡萄。背景是一面粗糙的灰白色抹灰墙,墙上挂着一幅风景画。最右边的一个台面上放着一盏复古油灯。海报上有大量的文字信息。左上角是白色的无衬线字体"ARTISAN FILMS PRESENTS",其下方是"ELEANOR VANCE"和"ACADEMY AWARD® WINNER"。右上角写着"ARTHUR PENHALIGON"和"GOLDEN GLOBE® AWARD WINNER"。顶部中央是圣丹斯电影节的桂冠标志,下方写着"SUNDANCE FILM FESTIVAL GRAND JURY PRIZE 2024"。主标题"THE TASTE OF MEMORY"以白色的大号衬线字体醒目地显示在下半部分。标题下方注明了"A FILM BY Tongyi Interaction Lab"。底部区域用白色小字列出了完整的演职员名单,包括"SCREENPLAY BY ANNA REID"、"CULINARY DIRECTION BY JAMES CARTER"以及Artisan Films、Riverstone Pictures和Heritage Media等众多出品公司标志。整体风格是写实主义,采用温暖柔和的灯光方案,营造出一种亲密的氛围。色调以棕色、米色和柔和的绿色等大地色系为主。两位演员的身体都在腰部被截断。'''],
    # Chinese: close-up of a glossy green leaf with magazine-cover text overlays.
    ['''一张方形构图的特写照片,主体是一片巨大的、鲜绿色的植物叶片,并叠加了文字,使其具有海报或杂志封面的外观。主要拍摄对象是一片厚实、有蜡质感的叶子,从左下角到右上角呈对角线弯曲穿过画面。其表面反光性很强,捕捉到一个明亮的直射光源,形成了一道突出的高光,亮面下显露出平行的精细叶脉。背景由其他深绿色的叶子组成,这些叶子轻微失焦,营造出浅景深效果,突出了前景的主叶片。整体风格是写实摄影,明亮的叶片与黑暗的阴影背景之间形成高对比度。图像上有多处渲染文字。左上角是白色的衬线字体文字"PIXEL-PEEPERS GUILD Presents"。右上角同样是白色衬线字体的文字"[Instant Noodle] 泡面调料包"。左侧垂直排列着标题"Render Distance: Max",为白色衬线字体。左下角是五个硕大的白色宋体汉字"显卡在...燃烧"。右下角是较小的白色衬线字体文字"Leica Glow™ Unobtanium X-1",其正上方是用白色宋体字书写的名字"蔡几"。识别出的核心实体包括品牌像素偷窥者协会、其产品线泡面调料包、相机型号买不到™ X-1以及摄影师名字造相。'''],
    # Chinese: a man and his poodle in matching outfits at a dog show.
    ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
]
def get_resolution(resolution):
    """Parse the two integers from a resolution label such as "1280x720 ( 16:9 )".

    Either "x" or "×" may separate the numbers, with optional surrounding
    whitespace. The numbers are returned in the order they appear in the label.
    The presets pair "1280x720" with "16:9", so that order is presumably
    (width, height).

    Returns:
        tuple[int, int]: the parsed pair, or (1024, 1024) when no "AxB" pattern
        is present.
    """
    found = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if found is None:
        return 1024, 1024
    first, second = found.groups()
    return int(first), int(second)
def load_models(model_path, enable_compile=False, attention_backend="native"):
    """Load the Z-Image components and assemble a ZImagePipeline on CUDA.

    If ``model_path`` exists on disk, each component is read from a
    sub-directory of it (``vae``, ``text_encoder``, ``tokenizer``,
    ``transformer``). Otherwise ``model_path`` is treated as a Hugging Face Hub
    repo id, and each component is fetched with ``subfolder=<component>``
    using ``HF_TOKEN`` as the token.

    Args:
        model_path: Local checkpoint directory or Hub repo id.
        enable_compile: When True, set inductor tuning flags, disable VAE
            tiling and wrap the transformer in ``torch.compile``.
        attention_backend: Name passed to ``transformer.set_attention_backend``.

    Returns:
        The assembled pipeline, moved to ``cuda`` in bfloat16.
    """
    print(f"Loading models from {model_path}...")
    is_local = os.path.exists(model_path)

    def _component_source(component):
        # Return (path, extra from_pretrained kwargs) for one checkpoint component.
        if is_local:
            return os.path.join(model_path, component), {}
        # `token=` replaces the deprecated `use_auth_token=`. Passing HF_TOKEN
        # (None when unset) lets huggingface_hub use its default credentials,
        # whereas the old `True` fallback demands a locally stored token and can
        # fail for public repos on machines without one.
        return model_path, {"subfolder": component, "token": HF_TOKEN}

    vae_src, vae_kwargs = _component_source("vae")
    vae = AutoencoderKL.from_pretrained(
        vae_src,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        **vae_kwargs,
    )

    text_src, text_kwargs = _component_source("text_encoder")
    text_encoder = AutoModel.from_pretrained(
        text_src,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        **text_kwargs,
    ).eval()

    tok_src, tok_kwargs = _component_source("tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tok_src, **tok_kwargs)
    # Left padding is kept from the original setup; the pipeline presumably
    # relies on it for prompt encoding — verify before changing.
    tokenizer.padding_side = "left"

    if enable_compile:
        print("Enabling torch.compile optimizations...")
        # Import the config module explicitly so it is loaded even if
        # torch._inductor has not been imported yet.
        from torch._inductor import config as inductor_config

        inductor_config.conv_1x1_as_mm = True
        inductor_config.coordinate_descent_tuning = True
        inductor_config.epilogue_fusion = False
        inductor_config.coordinate_descent_check_all_directions = True
        inductor_config.max_autotune_gemm = True
        inductor_config.max_autotune_gemm_backends = "TRITON,ATEN"
        inductor_config.triton.cudagraphs = False

    # Scheduler is assigned per request by generate_image; transformer is attached below.
    pipe = ZImagePipeline(
        scheduler=None,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=None,
    )
    if enable_compile:
        pipe.vae.disable_tiling()

    transformer_src, transformer_kwargs = _component_source("transformer")
    # Load directly in bf16 to avoid materializing an fp32 copy first; the
    # final `.to(...)` still guarantees every module ends up in bf16 on CUDA.
    transformer = ZImageTransformer2DModel.from_pretrained(
        transformer_src,
        torch_dtype=torch.bfloat16,
        **transformer_kwargs,
    ).to("cuda", torch.bfloat16)

    pipe.transformer = transformer
    pipe.transformer.set_attention_backend(attention_backend)

    if enable_compile:
        print("Compiling transformer...")
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False
        )

    pipe.to("cuda", torch.bfloat16)
    return pipe
def generate_image(
    pipe,
    prompt,
    resolution="1024x1024",
    seed=-1,
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    max_sequence_length=512,
):
    """Generate a single image with a fresh flow-matching Euler scheduler.

    Args:
        pipe: Pipeline returned by ``load_models``. Its ``scheduler`` attribute
            is replaced on every call (not safe for concurrent calls sharing
            one pipeline).
        prompt: Text prompt.
        resolution: Label such as "1280x720"; the first number is used as the
            width and the second as the height, matching the aspect-ratio
            labels in the presets (e.g. "1280x720 ( 16:9 )").
        seed: RNG seed; -1 draws a random seed in [0, 1000000).
        guidance_scale: Passed through to the pipeline.
        num_inference_steps: Number of denoising steps.
        shift: ``shift`` of the FlowMatchEulerDiscreteScheduler.
        max_sequence_length: Passed through to the pipeline.

    Returns:
        The first image of ``pipe(...).images``.
    """
    # The labels are written WIDTHxHEIGHT; unpacking them as (height, width)
    # produced transposed images (a 16:9 label yielded a 9:16 image).
    width, height = get_resolution(resolution)

    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
        print(f"Using seed: {seed}")
    # manual_seed requires an int; UI widgets may deliver e.g. 42.0.
    generator = torch.Generator("cuda").manual_seed(int(seed))

    # A new scheduler per call so the requested `shift` takes effect.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=shift,
    )

    output = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        max_sequence_length=max_sequence_length,
    )
    return output.images[0]
def warmup_model(pipe, resolutions):
    """Run throwaway generations for every resolution before serving requests.

    For each resolution label, three 9-step generations are run with
    ``guidance_scale=0.0`` and seeds 42, 43 and 44. This is presumably so that
    compilation/autotuning cost is paid up front. A failure for one resolution
    is printed and the loop moves on to the next one.

    Args:
        pipe: Pipeline passed through to ``generate_image``.
        resolutions: Iterable of resolution labels.
    """
    print("Starting warmup phase...")
    for label in resolutions:
        print(f"Warming up for resolution: {label}")
        try:
            for offset in range(3):
                generate_image(
                    pipe,
                    prompt="warmup",
                    resolution=label,
                    num_inference_steps=9,
                    guidance_scale=0.0,
                    seed=42 + offset,
                )
        except Exception as err:
            print(f"Warmup failed for {label}: {err}")
    print("Warmup completed.")
# ==================== Prompt Expander ====================
@dataclass
class PromptOutput:
    """Result of one prompt-expansion request.

    Instances are created both positionally
    ``(status, prompt, seed, system_prompt, message)`` and by keyword, so the
    class must be a dataclass. Without ``@dataclass`` it has no ``__init__``
    accepting arguments, and every construction raises TypeError.
    """

    status: bool  # True when the expansion request succeeded
    prompt: str  # expanded prompt; "" on failure
    seed: int  # seed supplied by the caller, echoed back
    system_prompt: str  # system prompt used; may be None if it was never resolved
    message: str  # raw model reply on success, error text on failure
class PromptExpander:
    """Base class for prompt expanders.

    Stores the backend name and supplies the default system prompt template.
    """

    def __init__(self, backend="api", **_unused):
        # Extra keyword arguments are accepted for subclass convenience and ignored.
        self.backend = backend

    def decide_system_prompt(self, template_name=None):
        """Return the system prompt template (``template_name`` is currently unused)."""
        template = prompt_template
        return template
class APIPromptExpander(PromptExpander):
    """Prompt expander that rewrites prompts through an OpenAI-compatible chat API.

    The client targets DashScope's OpenAI-compatible endpoint unless
    ``api_config["base_url"]`` overrides it. Failures are reported through the
    returned ``PromptOutput`` (``status=False`` plus an error message) rather
    than raised.
    """

    # Markdown fence that introduces a JSON payload in the model's reply.
    _JSON_FENCE = "```json"

    def __init__(self, api_config=None, **kwargs):
        """Create the expander and its API client.

        Args:
            api_config: Optional dict. Keys read are ``api_key`` and ``base_url``
                (client construction) and ``model`` (each request).
            **kwargs: Forwarded to ``PromptExpander.__init__``.
        """
        super().__init__(backend="api", **kwargs)
        self.api_config = api_config or {}
        # None when the client could not be created; extend() checks for it.
        self.client = self._init_api_client()

    def _init_api_client(self):
        """Return an ``openai.OpenAI`` client, or None if it cannot be created.

        The key comes from ``api_config["api_key"]`` or, failing that, from
        DASHSCOPE_API_KEY. Problems are printed, not raised.
        """
        try:
            # Imported lazily so the rest of the app works without `openai`.
            from openai import OpenAI

            api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
            base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
            if not api_key:
                print("Warning: DASHSCOPE_API_KEY not found.")
                return None
            return OpenAI(api_key=api_key, base_url=base_url)
        except ImportError:
            print("Please install openai: pip install openai")
            return None
        except Exception as e:
            print(f"Failed to initialize API client: {e}")
            return None

    def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
        """Alias for :meth:`extend`."""
        return self.extend(prompt, system_prompt, seed, **kwargs)

    @staticmethod
    def _extract_revised_prompt(content):
        """Return ``revised_prompt`` from a fenced JSON block in *content*.

        Falls back to the full *content* when there is no ```json fence, when
        the fenced text is not valid JSON, or when it is not a JSON object.
        """
        fence = APIPromptExpander._JSON_FENCE
        start = content.find(fence)
        if start == -1:
            return content
        body_start = start + len(fence)
        end = content.find("```", body_start)
        if end == -1:
            # Unterminated fence: take everything after the opening marker.
            # Slicing with end=-1 would silently drop the last character.
            end = len(content)
        try:
            data = json.loads(content[body_start:end].strip())
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return content
        if not isinstance(data, dict):
            return content
        return data.get("revised_prompt", content)

    def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
        """Ask the chat model to rewrite *prompt*.

        Args:
            prompt: User prompt to expand.
            system_prompt: System prompt; defaults to ``decide_system_prompt()``.
                If it contains ``{prompt}``, the user prompt is substituted via
                ``str.format`` and a single space is sent as the user message.
            seed: Echoed back in the result; not sent to the API.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            PromptOutput: on success ``status=True``, ``prompt`` is the extracted
            ``revised_prompt`` (or the raw reply) and ``message`` is the raw
            reply. On any failure ``status=False``, ``prompt`` is "" and
            ``message`` holds the error text.
        """
        if self.client is None:
            return PromptOutput(False, "", seed, system_prompt, "API client not initialized")
        if system_prompt is None:
            system_prompt = self.decide_system_prompt()
        try:
            if "{prompt}" in system_prompt:
                # Inside the try: a template with other unescaped braces makes
                # format() raise, which should be reported, not propagated.
                system_prompt = system_prompt.format(prompt=prompt)
                prompt = " "
            model = self.api_config.get("model", "qwen3-max-preview")
            response = self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                top_p=0.8,
            )
            content = response.choices[0].message.content
            return PromptOutput(
                status=True,
                prompt=self._extract_revised_prompt(content),
                seed=seed,
                system_prompt=system_prompt,
                message=content,
            )
        except Exception as e:
            return PromptOutput(False, "", seed, system_prompt, str(e))
def create_prompt_expander(backend="api", **kwargs):
    """Build a prompt expander for *backend*.

    Only the "api" backend is implemented; *kwargs* are forwarded to
    ``APIPromptExpander``.

    Raises:
        ValueError: for any other backend name.
    """
    if backend != "api":
        raise ValueError("Only 'api' backend is supported.")
    return APIPromptExpander(**kwargs)
# Module-level state filled in by init_app(). None means "not available":
# generate() refuses to run without a pipeline, and prompt_enhance() skips
# enhancement without an expander.
pipe, prompt_expander = None, None
def init_app():
    """Populate the module-level ``pipe`` and ``prompt_expander`` globals.

    The model is loaded from MODEL_PATH. If ENABLE_WARMUP is set, it is then
    warmed up on every resolution in every RES_CHOICES category. Any exception
    during load/warmup is printed and leaves ``pipe`` as None. The prompt
    expander is created independently; if that fails, ``prompt_expander`` is
    left as None.
    """
    global pipe, prompt_expander

    try:
        pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
        print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
        if ENABLE_WARMUP:
            every_resolution = [label for group in RES_CHOICES.values() for label in group]
            warmup_model(pipe, every_resolution)
    except Exception as err:
        print(f"Error loading model: {err}")
        pipe = None

    try:
        prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
        print("Prompt expander initialized.")
    except Exception as err:
        print(f"Error initializing prompt expander: {err}")
        prompt_expander = None
def prompt_enhance(prompt, enable_enhance):
    """Optionally expand *prompt* with the global prompt expander.

    Returns:
        tuple[str, str]: ``(prompt_to_use, status_message)``.
        - Enhancement is disabled or no expander is available: the prompt is
          returned unchanged.
        - The prompt is blank: ``""`` is returned with a reminder.
        - The expansion succeeded: the expanded prompt and the raw reply.
        - The expansion failed or raised: the original prompt and the error.
    """
    if not (enable_enhance and prompt_expander):
        return prompt, "Enhancement disabled or not available."
    if prompt.strip() == "":
        return "", "Please enter a prompt."
    try:
        expanded = prompt_expander(prompt)
        outcome = (
            (expanded.prompt, expanded.message)
            if expanded.status
            else (prompt, f"Enhancement failed: {expanded.message}")
        )
    except Exception as exc:
        return prompt, f"Error: {str(exc)}"
    return outcome
# On ZeroGPU a GPU is attached only while a @spaces.GPU function runs, and
# startup requires at least one such function. `spaces` was imported but never
# used; outside ZeroGPU the decorator is documented to have no effect.
@spaces.GPU
def generate(prompt, resolution, seed, steps, shift, enhance, gallery_images):
    """Gradio click handler: generate one image and append it to the gallery.

    Args:
        prompt: Text prompt from the textbox.
        resolution: Dropdown label such as "1024x1024 ( 1:1 )"; only the text
            before the first space is used.
        seed: Seed from the number box. -1 picks a random seed; an empty box
            (None) is treated the same way.
        steps: Number of inference steps.
        shift: Flow-matching scheduler shift.
        enhance: When truthy, run the prompt through ``prompt_enhance`` first.
        gallery_images: Current gallery contents; may be None.

    Returns:
        tuple: ``(gallery items with the new image appended, seed used as str)``.

    Raises:
        gr.Error: if the model is not loaded.
    """
    if pipe is None:
        raise gr.Error("Model not loaded.")

    final_prompt = prompt
    if enhance:
        final_prompt, _ = prompt_enhance(prompt, True)
        print(f"Enhanced prompt: {final_prompt}")

    # A cleared gr.Number can deliver None, which previously crashed in
    # manual_seed; treat it like the "-1 = random" sentinel.
    seed = -1 if seed is None else int(seed)
    if seed == -1:
        seed = random.randint(0, 1000000)

    # Only the "WxH" part of the label matters; fall back when no label was sent.
    if isinstance(resolution, str):
        resolution_str = resolution.split(" ")[0]
    else:
        resolution_str = "1024x1024"

    image = generate_image(
        pipe=pipe,
        prompt=final_prompt,
        resolution=resolution_str,
        seed=seed,
        guidance_scale=0.0,
        num_inference_steps=steps,
        shift=shift,
    )

    # Build a new list rather than mutating the component value passed in.
    return [*(gallery_images or []), image], str(seed)
# Populate the `pipe` and `prompt_expander` globals when this module executes,
# before the UI below is built.
init_app()
# Gradio UI. Component layout follows statement order inside the `with`
# blocks: inputs/controls in the left column, outputs in the right column.
with gr.Blocks(title="Z-Image Demo") as demo:
    gr.Markdown("# Z-Image Generation Demo")
    with gr.Row():
        # Left column: prompt, resolution, sampling controls and examples.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
            # PE components (Temporarily disabled)
            # with gr.Row():
            #     enable_enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=False)
            #     enhance_btn = gr.Button("Enhance Only")
            with gr.Row():
                # Category choices are the RES_CHOICES keys as ints (1024, 1280).
                choices = [int(k) for k in RES_CHOICES.keys()]
                res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
                initial_res_choices = RES_CHOICES["1024"]
                resolution = gr.Dropdown(
                    value=initial_res_choices[0],
                    choices=initial_res_choices,
                    label="Resolution"
                )
            seed = gr.Number(label="Seed", value=-1, precision=0, info="-1 for random")
            with gr.Row():
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=9, step=1)
                shift = gr.Slider(label="Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
            generate_btn = gr.Button("Generate", variant="primary")
            # Example prompts
            gr.Markdown("### 📝 Example Prompts")
            gr.Examples(
                examples=EXAMPLE_PROMPTS,
                inputs=prompt_input,
                label=None
            )
        # Right column: accumulated results and the seed of the latest run.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2, height=600, object_fit="contain", format="png")
            used_seed = gr.Textbox(label="Seed Used", interactive=False)

    def update_res_choices(_res_cat):
        """Swap the resolution dropdown to the presets of the chosen category.

        Unknown categories fall back to the "1024" presets; the first preset
        becomes the selected value.
        """
        if str(_res_cat) in RES_CHOICES:
            res_choices = RES_CHOICES[str(_res_cat)]
        else:
            res_choices = RES_CHOICES["1024"]
        return gr.update(value=res_choices[0], choices=res_choices)

    res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution)
    # PE enhancement button (Temporarily disabled)
    # NOTE(review): `final_prompt_output` is not defined anywhere in this file;
    # re-enabling this handler also requires creating that output component.
    # enhance_btn.click(
    #     prompt_enhance,
    #     inputs=[prompt_input, enable_enhance],
    #     outputs=[prompt_input, final_prompt_output]
    # )
    # With the enhancement checkbox disabled, a constant False is fed into
    # generate()'s `enhance` argument instead.
    enable_enhance = gr.State(value=False)
    # The gallery is both an input and an output, so each run appends to the
    # images already shown.
    generate_btn.click(
        generate,
        inputs=[
            prompt_input,
            resolution,
            seed,
            steps,
            shift,
            enable_enhance,
            output_gallery
        ],
        outputs=[output_gallery, used_seed]
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()