Alexander Bagus
committed on
Commit
·
d565c01
1
Parent(s):
28c106c
22
Browse files
app.py
CHANGED
|
@@ -1,7 +1,178 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import spaces
|
| 5 |
+
from diffusers import DiffusionPipeline
|
| 6 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
| 7 |
+
import torch
|
| 8 |
|
| 9 |
+
# --- Model configuration -----------------------------------------------------
# Hugging Face repo id of the distilled ("Turbo") Z-Image model.
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
# Largest 32-bit signed integer; upper bound for user-visible seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) for the width/height sliders in the UI.
MAX_IMAGE_SIZE = 1280

# Load the pipeline once at startup
print("Loading Z-Image-Turbo pipeline...")
pipe = DiffusionPipeline.from_pretrained(
    MODEL_REPO,
    torch_dtype=torch.bfloat16,  # bfloat16 weights to reduce GPU memory
    low_cpu_mem_usage=False,
)
# NOTE(review): assumes a CUDA device is always present (ZeroGPU space) —
# this raises at import time on CPU-only hosts.
pipe.to("cuda")

# ======== AoTI compilation + FA3 ========
# Load ahead-of-time-compiled transformer blocks (FlashAttention-3 variant)
# from the "zerogpu-aoti/Z-Image" repo; presumably a startup/runtime speed
# optimisation for ZeroGPU spaces — verify against the `spaces` package docs.
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
| 26 |
+
@spaces.GPU
def inference(
    input_image,
    prompt,
    seed=42,
    randomize_seed=True,
    width=1024,
    height=1024,
    guidance_scale=5.0,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image with the Z-Image-Turbo pipeline.

    Parameters
    ----------
    input_image : PIL.Image.Image | None
        Uploaded image. Currently only used as a presence guard — the
        img2img/latents path is not wired up yet, so generation is
        effectively text-to-image (see NOTE below).
    prompt : str
        Text prompt for the diffusion model.
    seed : int
        RNG seed; ignored when ``randomize_seed`` is True.
    randomize_seed : bool
        When True, draw a fresh seed in [0, MAX_SEED].
    width, height : int
        Requested output size. NOTE(review): currently unused — the
        width/height kwargs to the pipeline are disabled.
    guidance_scale : float
        Classifier-free guidance scale.
    num_inference_steps : int
        Number of denoising steps (Turbo models need few, default 8).
    progress : gr.Progress
        Gradio progress tracker (tqdm-linked).

    Returns
    -------
    tuple
        (image, seed) — always a 2-tuple so it matches the handler's two
        declared outputs.
    """
    if input_image is None:
        print("Error: input_image is empty.")
        # BUG FIX: the gr.on handler declares two outputs (image, seed);
        # returning a bare None made Gradio fail with a "not enough output
        # values" error. Return a 2-tuple instead.
        return None, seed

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # CPU generator is sufficient: diffusers moves noise to the right device.
    generator = torch.Generator().manual_seed(seed)

    # Install a fresh flow-matching scheduler for every request so state from
    # a previous run cannot leak into this one.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)

    # NOTE(review): the uploaded image used to be VAE-preprocessed here, but
    # the resulting tensor was never passed to the pipeline (the latents /
    # img2img path is still disabled), so the dead computation was removed.
    image = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]

    return image, seed
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def read_file(path: str) -> str:
    """Return the entire contents of a UTF-8 text file at *path*."""
    with open(path, 'r', encoding='utf-8') as handle:
        return handle.read()
|
| 81 |
+
|
| 82 |
+
# Inline CSS: center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Sample prompts surfaced beneath the prompt box via gr.Examples.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
|
| 94 |
+
|
| 95 |
+
# Build the Gradio UI. Layout: a header column, then a centered column
# (#col-container) with the image upload, prompt row, run button, result
# image, and an advanced-settings accordion.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        # Static page header loaded from an HTML fragment on disk.
        gr.HTML(read_file("html/header.html"))
    with gr.Column(elem_id="col-container"):
        # Source image (PIL) from upload or clipboard.
        input_image = gr.Image(height=800,sources=['upload','clipboard'],image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
                # Long default prompt tuned for photorealistic output.
                value="A high-resolution photographic image with sharp focus, balanced exposure, clean composition, accurate colour rendering, realistic materials and textures, soft and natural lighting, smooth tonal gradients, minimal noise, high dynamic range, detailed shadows and highlights, precise depth of field, lifelike detail, crisp edges, and visually clear separation between foreground, midground, and background.",
            )

        run_button = gr.Button("Img2Img", scale=0, variant="primary")

        output_image = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                # NOTE(review): width/height are wired into the handler's
                # inputs but the pipeline call currently ignores them.
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                # UI default 0.0 differs from the function default 5.0;
                # presumably intentional for a Turbo (distilled) model —
                # verify.
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=0.0,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )

    # Clickable example prompts that populate the prompt box.
    gr.Examples(examples=examples, inputs=[prompt])
    # Run inference on button click or prompt submit (Enter).
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=inference,
        inputs=[
            input_image,
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
    # mcp_server=True additionally exposes the app's functions over the
    # Model Context Protocol alongside the normal web UI.
    demo.launch(mcp_server=True)
|