ZIT / app.py
Alexander Bagus
22
2352f1a
raw
history blame
4.01 kB
import gradio as gr
import numpy as np
import random, json, spaces, torch
from diffusers import DiffusionPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280
# Load the pipeline once at startup
print("Loading Z-Image-Turbo pipeline...")
pipe = DiffusionPipeline.from_pretrained(
MODEL_REPO,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=False,
)
pipe.to("cuda")
# ======== AoTI compilation + FA3 ========
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
@spaces.GPU
def inference(
prompt,
seed=42,
randomize_seed=True,
width=1024,
height=1024,
guidance_scale=5.0,
num_inference_steps=8,
progress=gr.Progress(track_tqdm=True),
):
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator().manual_seed(seed)
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
pipe.scheduler = scheduler
image = pipe(
prompt=prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=generator,
).images[0]
return image, seed
def read_file(path: str) -> str:
with open(path, 'r', encoding='utf-8') as f:
content = f.read()
return content
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
with open('static/data.json', 'r') as file: data = json.load(file)
examples = data['examples']
with gr.Blocks() as demo:
with gr.Column():
gr.HTML(read_file("static/header.html"))
with gr.Column(elem_id="col-container"):
with gr.Row():
prompt = gr.Textbox(
label="Prompt",
show_label=False,
lines=2,
placeholder="Enter your prompt",
container=False,
scale=2
)
run_button = gr.Button("Run", variant="primary", scale=1)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=512,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=512,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=5.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=30,
step=1,
value=8,
)
output_image = gr.Image(label="Result", show_label=False)
gr.Examples(examples=examples, inputs=[prompt])
gr.on(
triggers=[run_button.click, prompt.submit],
fn=inference,
inputs=[
prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
],
outputs=[output_image, seed],
)
if __name__ == "__main__":
demo.launch(mcp_server=True, css=css)