# ZIT / app.py — author: Alexander Bagus
# Scraped Hub page metadata: 22 · commit 214db8a · raw · history blame · 4.42 kB
import gradio as gr
import numpy as np
import random, json, spaces, torch
from diffusers import DiffusionPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
# Repository id handed to DiffusionPipeline.from_pretrained below.
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
# Largest int32 value; upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) for the width/height sliders in the UI.
MAX_IMAGE_SIZE = 1280
# Load the pipeline once at startup; every request reuses this module-level `pipe`.
print("Loading Z-Image-Turbo pipeline...")
pipe = DiffusionPipeline.from_pretrained(
    MODEL_REPO,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
)
pipe.to("cuda")
# ======== AoTI compilation + FA3 ========
# Tag the transformer's layer stack with the name of its repeated block class,
# then load ahead-of-time-compiled versions of those blocks from the
# "zerogpu-aoti/Z-Image" repo. NOTE(review): "fa3" presumably selects a
# FlashAttention-3 build and `_repeated_blocks` is the hook aoti_blocks_load
# reads — confirm against the `spaces` package.
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
@spaces.GPU
def inference(
    prompt,
    seed=42,
    randomize_seed=True,
    width=1024,
    height=1024,
    guidance_scale=5.0,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image from a text prompt with Z-Image-Turbo.

    Args:
        prompt: Text description of the image to generate.
        seed: Seed for the random generator; ignored when randomize_seed is True.
        randomize_seed: If True, draw a fresh seed in [0, MAX_SEED] instead of using seed.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.

    Returns:
        A tuple (image, seed): the first generated image and the seed actually used,
        so the UI can display it and the result can be reproduced.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # API / MCP clients may send JSON numbers as floats (e.g. 42.0, 1024.0);
    # torch.Generator.manual_seed rejects floats and the pipeline expects
    # integer sizes and step counts, so normalise them here.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    generator = torch.Generator().manual_seed(seed)
    # Install a freshly constructed flow-matching Euler scheduler (shift=3) on
    # the shared pipeline for this call.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
    result = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    image = result.images[0]
    return image, seed
def read_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at *path*."""
    with open(path, encoding='utf-8') as handle:
        return handle.read()
# Styles for the main column (elem_id="col-container"): centred, at most 640px wide.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Example prompts for gr.Examples, read from the "examples" key of static/data.json.
# Read as UTF-8 explicitly (as read_file does) so non-ASCII prompts decode the
# same way regardless of the host's locale encoding.
with open('static/data.json', 'r', encoding='utf-8') as file:
    data = json.load(file)
examples = data['examples']
# Build the UI: header, prompt + run button, result image, advanced settings,
# example prompts, and the event wiring to `inference`.
with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML(read_file("static/header.html"))
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
                value="A high-resolution photographic image with sharp focus, balanced exposure, clean composition, accurate colour rendering, realistic materials and textures, soft and natural lighting, smooth tonal gradients, minimal noise, high dynamic range, detailed shadows and highlights, precise depth of field, lifelike detail, crisp edges, and visually clear separation between foreground, midground, and background.",
            )
            # The app is text-to-image (no image input), so the label says
            # "Generate" rather than the misleading "Img2Img".
            run_button = gr.Button("Generate", scale=0, variant="primary")
        output_image = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # Step of 32 keeps UI-chosen sizes on a 32-pixel grid.
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )
        gr.Examples(examples=examples, inputs=[prompt])
    # Run on button click or Enter in the prompt box; the seed slider is also an
    # output so it shows the seed that was actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=inference,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )
if __name__ == "__main__":
    # Start the app with the custom stylesheet; mcp_server=True additionally
    # turns on Gradio's MCP server.
    demo.launch(css=css, mcp_server=True)