"""Gradio Space: Z-Image-Turbo text-to-image generation.

Loads the Tongyi-MAI/Z-Image-Turbo diffusion pipeline on CUDA at import time
and exposes a Gradio UI with aspect-ratio presets, seed control, and optional
AI prompt enhancement via a hosted Qwen vision-language model (requires the
user to sign in with Hugging Face so their own inference token is used).
"""

import base64
import os
import random
import re
from io import BytesIO

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
from huggingface_hub import InferenceClient
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "Tongyi-MAI/Z-Image-Turbo"

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# ---------------------------------------------------------------------------
# Load Z-Image model components (happens once, at import time, on CUDA).
# ---------------------------------------------------------------------------
print(f"Loading models from {model_repo_id}...")
vae = AutoencoderKL.from_pretrained(
    model_repo_id,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
text_encoder = AutoModelForCausalLM.from_pretrained(
    model_repo_id,
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")
# Left-padding so prompt tokens line up at the sequence end for the decoder-only
# text encoder.
tokenizer.padding_side = "left"

# Assemble the pipeline without a scheduler/transformer first, then attach the
# transformer separately (loaded on its own so it can be moved to bf16/CUDA).
# The scheduler is created per-request in infer() because `shift` is part of
# its construction.
pipe = ZImagePipeline(
    scheduler=None,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    transformer=None,
)
transformer = ZImageTransformer2DModel.from_pretrained(
    model_repo_id, subfolder="transformer"
).to("cuda", torch.bfloat16)
pipe.transformer = transformer
pipe.to("cuda", torch.bfloat16)
print("Model loaded successfully!")

# Vision-Language model used for prompt enhancement (via HF Inference API).
VL_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct"

PROMPT_ENHANCEMENT_SYSTEM = """You are an expert prompt engineer for text-to-image generation models. Your task is to enhance user prompts to create more detailed, vivid descriptions that will produce high-quality images.

RULES:
1. If an image is provided, analyze it and incorporate relevant visual details into the enhanced prompt
2. Maintain the user's original intent and core concept
3. Add details about: composition, lighting, style, mood, colors, textures, and quality descriptors
4. Keep the enhanced prompt concise but descriptive (under 150 words)
5. Output ONLY the enhanced prompt text - no explanations, no quotes, no prefixes like "Enhanced prompt:"
6. Do not include meta-commentary or thinking process
7. Write in a natural, flowing style suitable for image generation

EXAMPLE INPUT: "a cat sitting"
EXAMPLE OUTPUT: A fluffy orange tabby cat sitting gracefully on a sunlit windowsill, soft natural lighting streaming through sheer curtains, shallow depth of field, warm golden hour tones, detailed fur texture, peaceful domestic scene, professional photography, 8k resolution"""


def image_to_base64(image) -> str | None:
    """Convert a PIL Image to a base64-encoded JPEG string.

    Returns None when no image is given. Large images are downscaled in place
    (longest side capped at 1024 px) to keep the API payload small.
    """
    if image is None:
        return None
    # Resize large images to reduce payload size
    max_size = 1024
    if max(image.size) > max_size:
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    buffered = BytesIO()
    image.save(buffered, format="JPEG", quality=85)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def enhance_prompt(prompt: str, reference_image=None, oauth_token: str | None = None) -> str:
    """Enhance the prompt using a VL model, optionally with a reference image.

    Args:
        prompt: The user's raw prompt.
        reference_image: Optional PIL image whose visual details should guide
            the enhancement.
        oauth_token: The signed-in user's HF token; without it we skip
            enhancement entirely (the Inference API call needs user billing).

    Returns:
        The enhanced prompt, or the original prompt unchanged if enhancement
        is unavailable or fails (best-effort by design).
    """
    if not oauth_token:
        print("[Prompt Enhancement] No auth token provided")
        return prompt

    try:
        # Create client with user's token
        client = InferenceClient(token=oauth_token)

        # Build user content based on whether an image is provided
        if reference_image is not None:
            # Convert image to base64 for the API
            img_base64 = image_to_base64(reference_image)
            user_content = [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                },
                {
                    "type": "text",
                    "text": f"Analyze this reference image and enhance this prompt for image generation: {prompt}",
                },
            ]
        else:
            user_content = f"Enhance this prompt for image generation: {prompt}"

        messages = [
            {"role": "system", "content": PROMPT_ENHANCEMENT_SYSTEM},
            {"role": "user", "content": user_content},
        ]

        response = client.chat_completion(
            messages=messages,
            model=VL_MODEL,
            max_tokens=250,
        )
        enhanced = response.choices[0].message.content.strip()

        # Clean up any potential formatting artifacts
        enhanced = enhanced.strip('"').strip("'").strip()

        # Remove any <think>...</think> reasoning blocks if present.
        # (Bug fix: the original check/regex had lost the tag literals and was
        # an always-true condition followed by a no-op substitution.)
        if "</think>" in enhanced:
            enhanced = re.sub(r"<think>.*?</think>", "", enhanced, flags=re.DOTALL).strip()

        print(f"[Prompt Enhancement] Model: {VL_MODEL}")
        print(f"[Prompt Enhancement] Original: {prompt}")
        print(f"[Prompt Enhancement] Enhanced: {enhanced}")
        return enhanced
    except Exception as e:
        # Best-effort: enhancement must never block generation.
        print(f"Error enhancing prompt: {e}")
        return prompt  # Return original if enhancement fails


@spaces.GPU
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    use_prompt_enhancement,
    reference_image,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image with the Z-Image pipeline.

    Note: ``negative_prompt`` is accepted for UI wiring but is not forwarded
    to the pipeline (the turbo model runs with guidance 0 by default, where a
    negative prompt has no effect). ``oauth_token`` is injected by Gradio for
    the signed-in user. Returns ``(image, seed)`` so the UI can display the
    seed actually used.
    """
    # Enhance prompt if requested (requires the user's own HF token)
    if use_prompt_enhancement:
        token = oauth_token.token if oauth_token else None
        prompt = enhance_prompt(prompt, reference_image, token)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator("cuda").manual_seed(seed)

    # Create a fresh flow-match scheduler with the shift parameter each call.
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
    pipe.scheduler = scheduler

    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        max_sequence_length=512,
    ).images[0]

    return image, seed


examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Aspect ratio presets with proper resolutions (width, height)
ASPECT_RATIOS = {
    "1:1 (Square)": (1024, 1024),
    "16:9 (Landscape)": (1280, 720),
    "9:16 (Portrait)": (720, 1280),
    "4:3": (1152, 864),
    "3:4": (864, 1152),
    "3:2": (1248, 832),
    "2:3": (832, 1248),
    "21:9 (Ultrawide)": (1344, 576),
}


def update_dimensions(preset):
    """Update width/height sliders based on the aspect ratio preset.

    For named presets the sliders are set to the preset resolution and locked;
    choosing "Custom" resets both to 1024 and unlocks them for manual editing.
    """
    w, h = ASPECT_RATIOS.get(preset, (1024, 1024))
    interactive = preset == "Custom"
    return gr.update(value=w, interactive=interactive), gr.update(value=h, interactive=interactive)


css = """
#col-container { margin: 0 auto; max-width: 700px; }
.title { text-align: center; font-weight: 500 !important; letter-spacing: -0.02em; margin-bottom: 0 !important; }
.prompt-container textarea { border-radius: 0 !important; border: 1px solid var(--border-color-primary) !important; }
.prompt-container textarea:focus { border-color: var(--body-text-color) !important; box-shadow: none !important; }
.generate-btn { min-height: 42px !important; border-radius: 0 !important; font-weight: 500 !important; letter-spacing: 0.02em; text-transform: uppercase; font-size: 0.85em !important; }
.radio-group label { border-radius: 0 !important; font-size: 0.8em !important; padding: 6px 12px !important; }
.radio-group label span { font-weight: 400 !important; }
.accordion { border-radius: 0 !important; border: none !important; border-top: 1px solid var(--border-color-primary) !important; border-bottom: 1px solid var(--border-color-primary) !important; }
.accordion > .label-wrap { padding: 14px 0 !important; font-size: 0.8em !important; text-transform: uppercase; letter-spacing: 0.05em; font-weight: 500 !important; }
.result-image { border-radius: 0 !important; }
.result-image img { border-radius: 0 !important; }
.ref-image { border-radius: 0 !important; }
.info-text { font-size: 0.75em; opacity: 0.5; margin-top: 8px !important; }
.section-label { font-size: 0.7em; text-transform: uppercase; letter-spacing: 0.1em; opacity: 0.5; margin-bottom: 8px !important; }
input[type="range"] { border-radius: 0 !important; }
/* Bug fix: the Examples component uses elem_id, so target the id, not a class. */
#examples-section button { border-radius: 0 !important; font-size: 0.8em !important; }
footer { display: none !important; }
.gradio-container { background: var(--background-fill-primary) !important; }
"""

with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Z-Image", elem_classes="title")

        # Login button for HF authentication (needed for prompt enhancement)
        login_btn = gr.LoginButton(value="Sign in with Hugging Face")

        # Prompt
        prompt = gr.Textbox(
            label="Prompt",
            show_label=False,
            max_lines=3,
            placeholder="Describe what you want to generate",
            elem_classes="prompt-container",
        )

        run_button = gr.Button("Generate", variant="primary", elem_classes="generate-btn")

        # Aspect Ratio
        gr.Markdown("Aspect Ratio", elem_classes="section-label")
        aspect_ratio = gr.Radio(
            label="Aspect Ratio",
            show_label=False,
            choices=list(ASPECT_RATIOS.keys()) + ["Custom"],
            value="1:1 (Square)",
            interactive=True,
            elem_classes="radio-group",
        )

        # Result
        result = gr.Image(label="Output", show_label=False, height=480, elem_classes="result-image")

        # Prompt Enhancement
        with gr.Accordion("Prompt Enhancement", open=False, elem_classes="accordion"):
            use_prompt_enhancement = gr.Checkbox(
                label="Enable AI enhancement",
                value=False,
            )
            reference_image = gr.Image(
                label="Reference image",
                type="pil",
                sources=["upload", "clipboard"],
                height=160,
                elem_classes="ref-image",
            )
            gr.Markdown("Optional reference to guide enhancement style", elem_classes="info-text")

        # Settings
        with gr.Accordion("Settings", open=False, elem_classes="accordion"):
            # Hidden: kept for interface stability; not forwarded to the pipeline.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="",
                visible=False,
            )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Random seed", value=True)
            with gr.Row():
                # Locked by default; unlocked when the "Custom" preset is chosen.
                width = gr.Slider(label="Width", minimum=512, maximum=1536, step=32, value=1024, interactive=False)
                height = gr.Slider(label="Height", minimum=512, maximum=1536, step=32, value=1024, interactive=False)
            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance", minimum=0.0, maximum=5.0, step=0.5, value=0.0)
                num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=20, step=1, value=8)

        # Examples
        with gr.Accordion("Examples", open=False, elem_classes="accordion"):
            gr.Examples(examples=examples, inputs=[prompt], elem_id="examples-section")

    aspect_ratio.change(fn=update_dimensions, inputs=[aspect_ratio], outputs=[width, height])

    # gr.OAuthToken is injected automatically by Gradio for the infer() call,
    # so it is not listed among the inputs.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            use_prompt_enhancement,
            reference_image,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()