# Z-Image Turbo — text-to-image demo (Hugging Face Space, running on ZeroGPU)
import base64
import os
import random
import re
from io import BytesIO

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer
# Preferred device; the loading code below places everything on CUDA directly.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "Tongyi-MAI/Z-Image-Turbo"

# Largest seed the UI slider can produce (fits in a signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Load Z-Image model components
print(f"Loading models from {model_repo_id}...")
# VAE: decodes the transformer's latents into RGB pixels (bf16, on GPU).
vae = AutoencoderKL.from_pretrained(
    model_repo_id,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# Text encoder: a causal LM the pipeline uses to embed prompts; eval() to
# disable dropout etc. for inference.
text_encoder = AutoModelForCausalLM.from_pretrained(
    model_repo_id,
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")
# Left-pad prompts (typical for causal-LM text encoders in batched use).
tokenizer.padding_side = "left"

# Build the pipeline shell first: the scheduler is created fresh per request
# in infer(), and the transformer is attached immediately below.
pipe = ZImagePipeline(
    scheduler=None,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    transformer=None
)
transformer = ZImageTransformer2DModel.from_pretrained(
    model_repo_id,
    subfolder="transformer"
).to("cuda", torch.bfloat16)
pipe.transformer = transformer
pipe.to("cuda", torch.bfloat16)
print("Model loaded successfully!")
# Vision-Language model for prompt enhancement
VL_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct"

# System prompt sent to the VL model; this is a runtime string consumed by the
# Inference API, so its wording is behavior — do not edit casually.
PROMPT_ENHANCEMENT_SYSTEM = """You are an expert prompt engineer for text-to-image generation models.
Your task is to enhance user prompts to create more detailed, vivid descriptions that will produce high-quality images.
RULES:
1. If an image is provided, analyze it and incorporate relevant visual details into the enhanced prompt
2. Maintain the user's original intent and core concept
3. Add details about: composition, lighting, style, mood, colors, textures, and quality descriptors
4. Keep the enhanced prompt concise but descriptive (under 150 words)
5. Output ONLY the enhanced prompt text - no explanations, no quotes, no prefixes like "Enhanced prompt:"
6. Do not include meta-commentary or thinking process
7. Write in a natural, flowing style suitable for image generation
EXAMPLE INPUT: "a cat sitting"
EXAMPLE OUTPUT: A fluffy orange tabby cat sitting gracefully on a sunlit windowsill, soft natural lighting streaming through sheer curtains, shallow depth of field, warm golden hour tones, detailed fur texture, peaceful domestic scene, professional photography, 8k resolution"""
| def image_to_base64(image) -> str: | |
| """Convert PIL Image to base64 string.""" | |
| if image is None: | |
| return None | |
| # Resize large images to reduce payload size | |
| max_size = 1024 | |
| if max(image.size) > max_size: | |
| image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS) | |
| buffered = BytesIO() | |
| image.save(buffered, format="JPEG", quality=85) | |
| return base64.b64encode(buffered.getvalue()).decode("utf-8") | |
def enhance_prompt(prompt: str, reference_image=None, oauth_token: str = None) -> str:
    """Rewrite *prompt* with a vision-language model via the HF Inference API.

    A reference image, when supplied, is sent alongside the text so the model
    can fold visual details into the enhanced prompt. On any failure (missing
    token, API error) the original prompt is returned unchanged.
    """
    if not oauth_token:
        print("[Prompt Enhancement] No auth token provided")
        return prompt

    try:
        vl_client = InferenceClient(token=oauth_token)

        # Plain text when there is no image; multimodal content otherwise.
        if reference_image is None:
            user_content = f"Enhance this prompt for image generation: {prompt}"
        else:
            encoded = image_to_base64(reference_image)
            user_content = [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
                },
                {
                    "type": "text",
                    "text": f"Analyze this reference image and enhance this prompt for image generation: {prompt}",
                },
            ]

        response = vl_client.chat_completion(
            messages=[
                {"role": "system", "content": PROMPT_ENHANCEMENT_SYSTEM},
                {"role": "user", "content": user_content},
            ],
            model=VL_MODEL,
            max_tokens=250,
        )

        enhanced = response.choices[0].message.content.strip()
        # Strip stray surrounding quotes the model sometimes adds.
        enhanced = enhanced.strip('"').strip("'").strip()
        # Drop any leaked chain-of-thought despite the system prompt's rules.
        if "<think>" in enhanced:
            enhanced = re.sub(r'<think>.*?</think>', '', enhanced, flags=re.DOTALL).strip()

        print(f"[Prompt Enhancement] Model: {VL_MODEL}")
        print(f"[Prompt Enhancement] Original: {prompt}")
        print(f"[Prompt Enhancement] Enhanced: {enhanced}")
        return enhanced
    except Exception as e:
        print(f"Error enhancing prompt: {e}")
        return prompt  # Return original if enhancement fails
@spaces.GPU
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    use_prompt_enhancement,
    reference_image,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with Z-Image Turbo; returns (image, seed).

    The seed is returned so the UI can display the value actually used when
    "Random seed" is enabled. `negative_prompt` is accepted to match the UI's
    inputs list but is not forwarded to the pipeline (the default
    guidance_scale of 0.0 would make it a no-op anyway).

    The @spaces.GPU decorator requests a GPU from ZeroGPU for the duration of
    this call — without it the CUDA generator/pipeline below has no device on
    a ZeroGPU Space.
    """
    # Optionally rewrite the prompt with the VL model (needs the user's token).
    if use_prompt_enhancement:
        token = oauth_token.token if oauth_token else None
        prompt = enhance_prompt(prompt, reference_image, token)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator("cuda").manual_seed(seed)

    # Fresh flow-matching scheduler per request; shift=3.0 is the value this
    # app ships with for the Turbo checkpoint.
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
    pipe.scheduler = scheduler

    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        max_sequence_length=512,
    ).images[0]
    return image, seed
# Example prompts shown in the "Examples" accordion.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Aspect ratio presets with proper resolutions
# Values are (width, height). NOTE(review): 720 and 576 do not land on the
# width/height sliders' 32-pixel step grid (512 + 32k) — the values are forced
# via gr.update, but verify the pipeline accepts them.
ASPECT_RATIOS = {
    "1:1 (Square)": (1024, 1024),
    "16:9 (Landscape)": (1280, 720),
    "9:16 (Portrait)": (720, 1280),
    "4:3": (1152, 864),
    "3:4": (864, 1152),
    "3:2": (1248, 832),
    "2:3": (832, 1248),
    "21:9 (Ultrawide)": (1344, 576),
}
def update_dimensions(preset):
    """Sync the width/height sliders with the aspect-ratio preset.

    For a named preset the sliders are set to the preset's resolution and
    locked. For "Custom" they are merely unlocked — the user's current values
    are preserved instead of being reset to the 1024x1024 fallback.
    """
    if preset == "Custom":
        # Enable editing without overwriting whatever the user has set.
        return gr.update(interactive=True), gr.update(interactive=True)
    w, h = ASPECT_RATIOS.get(preset, (1024, 1024))
    return gr.update(value=w, interactive=False), gr.update(value=h, interactive=False)
# Custom stylesheet for the flat, square-cornered look: suppresses Gradio's
# rounded corners, uppercases buttons/accordion labels, hides the footer.
css = """
#col-container {
    margin: 0 auto;
    max-width: 700px;
}
.title {
    text-align: center;
    font-weight: 500 !important;
    letter-spacing: -0.02em;
    margin-bottom: 0 !important;
}
.prompt-container textarea {
    border-radius: 0 !important;
    border: 1px solid var(--border-color-primary) !important;
}
.prompt-container textarea:focus {
    border-color: var(--body-text-color) !important;
    box-shadow: none !important;
}
.generate-btn {
    min-height: 42px !important;
    border-radius: 0 !important;
    font-weight: 500 !important;
    letter-spacing: 0.02em;
    text-transform: uppercase;
    font-size: 0.85em !important;
}
.radio-group label {
    border-radius: 0 !important;
    font-size: 0.8em !important;
    padding: 6px 12px !important;
}
.radio-group label span {
    font-weight: 400 !important;
}
.accordion {
    border-radius: 0 !important;
    border: none !important;
    border-top: 1px solid var(--border-color-primary) !important;
    border-bottom: 1px solid var(--border-color-primary) !important;
}
.accordion > .label-wrap {
    padding: 14px 0 !important;
    font-size: 0.8em !important;
    text-transform: uppercase;
    letter-spacing: 0.05em;
    font-weight: 500 !important;
}
.result-image {
    border-radius: 0 !important;
}
.result-image img {
    border-radius: 0 !important;
}
.ref-image {
    border-radius: 0 !important;
}
.info-text {
    font-size: 0.75em;
    opacity: 0.5;
    margin-top: 8px !important;
}
.section-label {
    font-size: 0.7em;
    text-transform: uppercase;
    letter-spacing: 0.1em;
    opacity: 0.5;
    margin-bottom: 8px !important;
}
input[type="range"] {
    border-radius: 0 !important;
}
.examples-section button {
    border-radius: 0 !important;
    font-size: 0.8em !important;
}
footer { display: none !important; }
.gradio-container { background: var(--background-fill-primary) !important; }
"""
# ---------------------------------------------------------------------------
# UI definition. Component creation order matters: the objects created here
# are referenced by name in the event wiring at the bottom.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Z-Image", elem_classes="title")
        # Login button for HF authentication (token feeds prompt enhancement).
        login_btn = gr.LoginButton(value="Sign in with Hugging Face")
        # Prompt
        prompt = gr.Textbox(
            label="Prompt",
            show_label=False,
            max_lines=3,
            placeholder="Describe what you want to generate",
            elem_classes="prompt-container",
        )
        run_button = gr.Button("Generate", variant="primary", elem_classes="generate-btn")
        # Aspect Ratio
        gr.Markdown("Aspect Ratio", elem_classes="section-label")
        aspect_ratio = gr.Radio(
            label="Aspect Ratio",
            show_label=False,
            choices=list(ASPECT_RATIOS.keys()) + ["Custom"],
            value="1:1 (Square)",
            interactive=True,
            elem_classes="radio-group",
        )
        # Result
        result = gr.Image(label="Output", show_label=False, height=480, elem_classes="result-image")
        # Prompt Enhancement
        with gr.Accordion("Prompt Enhancement", open=False, elem_classes="accordion"):
            use_prompt_enhancement = gr.Checkbox(
                label="Enable AI enhancement",
                value=False,
            )
            reference_image = gr.Image(
                label="Reference image",
                type="pil",
                sources=["upload", "clipboard"],
                height=160,
                elem_classes="ref-image",
            )
            gr.Markdown("Optional reference to guide enhancement style", elem_classes="info-text")
        # Settings
        with gr.Accordion("Settings", open=False, elem_classes="accordion"):
            # Hidden component: kept so infer()'s parameter list matches the
            # inputs wiring below even though the pipeline does not use it.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="",
                visible=False,
            )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Random seed", value=True)
            with gr.Row():
                # Locked unless the "Custom" aspect preset is selected
                # (see update_dimensions).
                width = gr.Slider(label="Width", minimum=512, maximum=1536, step=32, value=1024, interactive=False)
                height = gr.Slider(label="Height", minimum=512, maximum=1536, step=32, value=1024, interactive=False)
            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance", minimum=0.0, maximum=5.0, step=0.5, value=0.0)
                num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=20, step=1, value=8)
        # Examples
        with gr.Accordion("Examples", open=False, elem_classes="accordion"):
            # NOTE(review): the stylesheet targets the CLASS ".examples-section"
            # but this sets elem_id — confirm whether elem_classes was intended.
            gr.Examples(examples=examples, inputs=[prompt], elem_id="examples-section")

    # Preset selection drives the width/height sliders.
    aspect_ratio.change(fn=update_dimensions, inputs=[aspect_ratio], outputs=[width, height])
    # oauth_token and progress are injected by Gradio and are deliberately
    # absent from the inputs list.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            use_prompt_enhancement,
            reference_image,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()