File size: 4,527 Bytes
28f30c9
d565c01
356ed08
d565c01
 
28f30c9
d565c01
 
 
28f30c9
d565c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef36217
 
d565c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0dc0c0
356ed08
d565c01
ad015ab
d565c01
08e0a1e
d565c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef36217
d565c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad015ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
import numpy as np
import random, json, spaces, torch
from diffusers import DiffusionPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler

# Model id handed to DiffusionPipeline.from_pretrained (a Hub repo id).
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
# Largest seed that can be chosen or drawn: the int32 maximum (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width and height sliders.
MAX_IMAGE_SIZE = 1280

# Load the pipeline once at startup (module import time) and move it to the GPU;
# every request reuses this single shared `pipe` object.
print("Loading Z-Image-Turbo pipeline...")
# Weights are loaded in bfloat16. NOTE(review): low_cpu_mem_usage=False is set
# explicitly; the reason is not stated in this file — confirm it is still needed.
pipe = DiffusionPipeline.from_pretrained(
    MODEL_REPO,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
)
pipe.to("cuda")

# ======== AoTI compilation + FA3 ========
# Declare ZImageTransformerBlock as the repeated block class inside
# pipe.transformer.layers, then hand that module to spaces.aoti_blocks_load.
# Presumably the loader swaps those blocks for ahead-of-time-compiled (AoTI)
# artifacts fetched from "zerogpu-aoti/Z-Image", picking the "fa3"
# (FlashAttention-3, per the header above) variant — confirm against the
# `spaces` ZeroGPU AoTI documentation. Runs after pipe.to("cuda").
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")

@spaces.GPU
def inference(
    prompt,
    seed=42,
    randomize_seed=True,
    width=1024,
    height=1024,
    guidance_scale=5.0,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a text prompt with the shared Z-Image-Turbo pipeline.

    Gradio passes the event's input components to this function positionally,
    so the parameter order here must match the event wiring:
    prompt, seed, randomize_seed, width, height, guidance_scale,
    num_inference_steps.

    Args:
        prompt: Text prompt describing the image to generate.
        seed: Seed for the random generator; ignored when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed uniformly from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline call.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        progress: Gradio progress tracker that mirrors tqdm progress bars.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed that
        was actually used, so the UI can show it and the run can be reproduced.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider components may deliver floats; torch.Generator.manual_seed needs an
    # int, and returning the int keeps the displayed seed exact.
    seed = int(seed)

    generator = torch.Generator().manual_seed(seed)

    # Install a flow-matching Euler scheduler (shift=3) before each run.
    # Note: this mutates the module-level `pipe` shared by all requests.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)

    image = pipe(
        prompt=prompt,
        guidance_scale=float(guidance_scale),
        # Image sizes and step counts must be integers for the pipeline.
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]

    return image, seed


def read_file(path: str) -> str:
    """Return the entire contents of the text file at *path*, decoded as UTF-8."""
    with open(path, encoding="utf-8") as handle:
        return handle.read()

# Stylesheet for the app: centres the element with id "col-container" and caps
# its width at 640px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Load example prompts for the gr.Examples widget. The encoding is given
# explicitly: without it, open() falls back to the platform's locale encoding,
# which can fail (or mis-decode) on non-ASCII text in a UTF-8 JSON file.
with open("static/data.json", "r", encoding="utf-8") as file:
    data = json.load(file)
examples = data["examples"]

with gr.Blocks() as demo:
    # Page header markup, read from a static HTML file.
    with gr.Column():
        gr.HTML(read_file("static/header.html"))
    # Main content column; styled by the "#col-container" rule in `css`.
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
                value="A high-resolution photographic image with sharp focus, balanced exposure, clean composition, accurate colour rendering, realistic materials and textures, soft and natural lighting, smooth tonal gradients, minimal noise, high dynamic range, detailed shadows and highlights, precise depth of field, lifelike detail, crisp edges, and visually clear separation between foreground, midground, and background.",
            )

            # This UI has no image input; generation is driven by the text
            # prompt alone, so the button must not advertise "Img2Img".
            run_button = gr.Button("Generate", scale=0, variant="primary")

        output_image = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):

            # `seed` is also an event output below, so after each run this
            # slider shows the seed that was actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    # Run on button click or on Enter in the prompt box. Gradio passes `inputs`
    # to `inference` positionally, so this list must line up one-to-one with
    # its leading parameters; its returned (image, seed) fills `outputs`.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=inference,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )

if __name__ == "__main__":
    # Start the web app with the custom stylesheet and Gradio's MCP server
    # enabled. NOTE(review): passing `css` to launch() assumes a Gradio version
    # whose launch() accepts it (Gradio 6+) — confirm the pinned version.
    demo.launch(
        css=css,
        mcp_server=True,
    )