# Z-Image-Turbo text-to-image demo (Gradio Space)
import gradio as gr
import numpy as np
import random
import os
import re
import base64
from io import BytesIO
from PIL import Image
import spaces
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers import ZImagePipeline
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import InferenceClient
# NOTE(review): this variable is effectively informational — every loading call
# below hard-codes "cuda", so the Space requires a GPU regardless of `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "Tongyi-MAI/Z-Image-Turbo"
# Upper bound for the seed slider (int32 max, safe for torch.Generator).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# Load Z-Image model components
print(f"Loading models from {model_repo_id}...")
# VAE (latent <-> pixel codec), loaded in bfloat16 directly onto the GPU.
vae = AutoencoderKL.from_pretrained(
    model_repo_id,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# Text encoder is a causal-LM checkpoint; .eval() disables dropout for inference.
text_encoder = AutoModelForCausalLM.from_pretrained(
    model_repo_id,
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")
# Left padding — the usual choice when batching prompts through a causal LM.
tokenizer.padding_side = "left"
# Assemble the pipeline without scheduler/transformer: the scheduler is created
# per request inside infer(), and the transformer is attached right below.
pipe = ZImagePipeline(
    scheduler=None,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    transformer=None
)
transformer = ZImageTransformer2DModel.from_pretrained(
    model_repo_id,
    subfolder="transformer"
).to("cuda", torch.bfloat16)
pipe.transformer = transformer
pipe.to("cuda", torch.bfloat16)
print("Model loaded successfully!")
# Vision-Language model for prompt enhancement (called through the HF
# Inference API with the signed-in user's OAuth token — see enhance_prompt).
VL_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
# System prompt steering the VL model to emit ONLY the rewritten prompt text.
PROMPT_ENHANCEMENT_SYSTEM = """You are an expert prompt engineer for text-to-image generation models.
Your task is to enhance user prompts to create more detailed, vivid descriptions that will produce high-quality images.
RULES:
1. If an image is provided, analyze it and incorporate relevant visual details into the enhanced prompt
2. Maintain the user's original intent and core concept
3. Add details about: composition, lighting, style, mood, colors, textures, and quality descriptors
4. Keep the enhanced prompt concise but descriptive (under 150 words)
5. Output ONLY the enhanced prompt text - no explanations, no quotes, no prefixes like "Enhanced prompt:"
6. Do not include meta-commentary or thinking process
7. Write in a natural, flowing style suitable for image generation
EXAMPLE INPUT: "a cat sitting"
EXAMPLE OUTPUT: A fluffy orange tabby cat sitting gracefully on a sunlit windowsill, soft natural lighting streaming through sheer curtains, shallow depth of field, warm golden hour tones, detailed fur texture, peaceful domestic scene, professional photography, 8k resolution"""
def image_to_base64(image) -> str:
"""Convert PIL Image to base64 string."""
if image is None:
return None
# Resize large images to reduce payload size
max_size = 1024
if max(image.size) > max_size:
image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
buffered = BytesIO()
image.save(buffered, format="JPEG", quality=85)
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def enhance_prompt(prompt: str, reference_image=None, oauth_token: str = None) -> str:
"""Enhance the prompt using a VL model, optionally with a reference image."""
if not oauth_token:
print("[Prompt Enhancement] No auth token provided")
return prompt
try:
# Create client with user's token
client = InferenceClient(token=oauth_token)
# Build user content based on whether image is provided
if reference_image is not None:
# Convert image to base64 for the API
img_base64 = image_to_base64(reference_image)
user_content = [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
},
{
"type": "text",
"text": f"Analyze this reference image and enhance this prompt for image generation: {prompt}"
}
]
else:
user_content = f"Enhance this prompt for image generation: {prompt}"
messages = [
{"role": "system", "content": PROMPT_ENHANCEMENT_SYSTEM},
{"role": "user", "content": user_content}
]
response = client.chat_completion(
messages=messages,
model=VL_MODEL,
max_tokens=250,
)
enhanced = response.choices[0].message.content.strip()
# Clean up any potential formatting artifacts
enhanced = enhanced.strip('"').strip("'").strip()
# Remove any thinking tags if present
if "<think>" in enhanced:
enhanced = re.sub(r'<think>.*?</think>', '', enhanced, flags=re.DOTALL).strip()
print(f"[Prompt Enhancement] Model: {VL_MODEL}")
print(f"[Prompt Enhancement] Original: {prompt}")
print(f"[Prompt Enhancement] Enhanced: {enhanced}")
return enhanced
except Exception as e:
print(f"Error enhancing prompt: {e}")
return prompt # Return original if enhancement fails
@spaces.GPU
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    use_prompt_enhancement,
    reference_image,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the Z-Image pipeline.

    Returns:
        (image, seed): the generated PIL image and the seed actually used
        (which differs from the input when randomize_seed is set).

    NOTE(review): `negative_prompt` is accepted but never forwarded to the
    pipeline call below — with the default guidance_scale of 0.0 it likely has
    no effect anyway, but confirm before wiring it through.
    """
    # Optionally rewrite the prompt with the VL model; requires the user's
    # HF OAuth token (auto-injected by Gradio via the typed parameter).
    if use_prompt_enhancement:
        token = oauth_token.token if oauth_token else None
        prompt = enhance_prompt(prompt, reference_image, token)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator("cuda").manual_seed(seed)
    # Create scheduler with shift parameter — a fresh instance per request so
    # concurrent/ZeroGPU calls don't share scheduler state.
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
    pipe.scheduler = scheduler
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        max_sequence_length=512,
    ).images[0]
    return image, seed
# Example prompts surfaced in the "Examples" accordion.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
# Aspect ratio presets with proper resolutions.
# NOTE(review): 1280x720 / 720x1280 don't align with the sliders' 32-px step
# (720 is not a multiple of 32) — confirm the pipeline accepts these sizes.
ASPECT_RATIOS = {
    "1:1 (Square)": (1024, 1024),
    "16:9 (Landscape)": (1280, 720),
    "9:16 (Portrait)": (720, 1280),
    "4:3": (1152, 864),
    "3:4": (864, 1152),
    "3:2": (1248, 832),
    "2:3": (832, 1248),
    "21:9 (Ultrawide)": (1344, 576),
}
def update_dimensions(preset):
    """Sync the width/height sliders with the chosen aspect-ratio preset.

    Named presets lock both sliders to their canonical resolution; the
    "Custom" preset resets them to 1024x1024 and makes them editable.
    """
    editable = (preset == "Custom")
    width_px, height_px = ASPECT_RATIOS.get(preset, (1024, 1024))
    return (
        gr.update(value=width_px, interactive=editable),
        gr.update(value=height_px, interactive=editable),
    )
# Minimalist flat theme. The examples rule uses an id selector because the
# gr.Examples component below is tagged with elem_id="examples-section",
# which renders as an HTML id (the original ".examples-section" class
# selector matched nothing).
css = """
#col-container {
margin: 0 auto;
max-width: 700px;
}
.title {
text-align: center;
font-weight: 500 !important;
letter-spacing: -0.02em;
margin-bottom: 0 !important;
}
.prompt-container textarea {
border-radius: 0 !important;
border: 1px solid var(--border-color-primary) !important;
}
.prompt-container textarea:focus {
border-color: var(--body-text-color) !important;
box-shadow: none !important;
}
.generate-btn {
min-height: 42px !important;
border-radius: 0 !important;
font-weight: 500 !important;
letter-spacing: 0.02em;
text-transform: uppercase;
font-size: 0.85em !important;
}
.radio-group label {
border-radius: 0 !important;
font-size: 0.8em !important;
padding: 6px 12px !important;
}
.radio-group label span {
font-weight: 400 !important;
}
.accordion {
border-radius: 0 !important;
border: none !important;
border-top: 1px solid var(--border-color-primary) !important;
border-bottom: 1px solid var(--border-color-primary) !important;
}
.accordion > .label-wrap {
padding: 14px 0 !important;
font-size: 0.8em !important;
text-transform: uppercase;
letter-spacing: 0.05em;
font-weight: 500 !important;
}
.result-image {
border-radius: 0 !important;
}
.result-image img {
border-radius: 0 !important;
}
.ref-image {
border-radius: 0 !important;
}
.info-text {
font-size: 0.75em;
opacity: 0.5;
margin-top: 8px !important;
}
.section-label {
font-size: 0.7em;
text-transform: uppercase;
letter-spacing: 0.1em;
opacity: 0.5;
margin-bottom: 8px !important;
}
input[type="range"] {
border-radius: 0 !important;
}
#examples-section button {
border-radius: 0 !important;
font-size: 0.8em !important;
}
footer { display: none !important; }
.gradio-container { background: var(--background-fill-primary) !important; }
"""
# UI layout: component creation order inside the context managers defines the
# rendered order on the page.
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Z-Image", elem_classes="title")
        # Login button for HF authentication — supplies the OAuth token that
        # prompt enhancement uses for the Inference API call.
        login_btn = gr.LoginButton(value="Sign in with Hugging Face")
        # Prompt
        prompt = gr.Textbox(
            label="Prompt",
            show_label=False,
            max_lines=3,
            placeholder="Describe what you want to generate",
            elem_classes="prompt-container",
        )
        run_button = gr.Button("Generate", variant="primary", elem_classes="generate-btn")
        # Aspect Ratio — radio choices drive the width/height sliders below
        # via the .change() handler wired after the layout.
        gr.Markdown("Aspect Ratio", elem_classes="section-label")
        aspect_ratio = gr.Radio(
            label="Aspect Ratio",
            show_label=False,
            choices=list(ASPECT_RATIOS.keys()) + ["Custom"],
            value="1:1 (Square)",
            interactive=True,
            elem_classes="radio-group",
        )
        # Result
        result = gr.Image(label="Output", show_label=False, height=480, elem_classes="result-image")
        # Prompt Enhancement
        with gr.Accordion("Prompt Enhancement", open=False, elem_classes="accordion"):
            use_prompt_enhancement = gr.Checkbox(
                label="Enable AI enhancement",
                value=False,
            )
            reference_image = gr.Image(
                label="Reference image",
                type="pil",
                sources=["upload", "clipboard"],
                height=160,
                elem_classes="ref-image",
            )
            gr.Markdown("Optional reference to guide enhancement style", elem_classes="info-text")
        # Settings
        with gr.Accordion("Settings", open=False, elem_classes="accordion"):
            # Hidden negative-prompt field; kept so infer's signature is satisfied.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="",
                visible=False,
            )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Random seed", value=True)
            with gr.Row():
                # Non-interactive by default; update_dimensions() unlocks them
                # only for the "Custom" preset.
                width = gr.Slider(label="Width", minimum=512, maximum=1536, step=32, value=1024, interactive=False)
                height = gr.Slider(label="Height", minimum=512, maximum=1536, step=32, value=1024, interactive=False)
            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance", minimum=0.0, maximum=5.0, step=0.5, value=0.0)
                num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=20, step=1, value=8)
        # Examples
        with gr.Accordion("Examples", open=False, elem_classes="accordion"):
            gr.Examples(examples=examples, inputs=[prompt], elem_id="examples-section")
    aspect_ratio.change(fn=update_dimensions, inputs=[aspect_ratio], outputs=[width, height])
    # NOTE: infer's gr.OAuthToken parameter is intentionally absent from
    # `inputs` — Gradio injects typed OAuthToken parameters automatically
    # when a LoginButton is present.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            use_prompt_enhancement,
            reference_image,
        ],
        outputs=[result, seed],
    )
if __name__ == "__main__":
    demo.launch()