File size: 5,257 Bytes
212e7d6
 
f076923
212e7d6
480e744
212e7d6
 
 
 
 
 
 
 
 
 
 
586af6b
212e7d6
432ea72
212e7d6
 
 
 
 
4145cb7
f076923
212e7d6
 
 
 
f076923
d7d90ba
212e7d6
 
 
 
 
 
 
 
 
 
bac7c86
 
 
 
 
 
 
 
 
 
d7d90ba
 
 
bac7c86
 
d7d90ba
 
 
 
 
 
 
 
bac7c86
 
 
 
 
 
 
 
d7d90ba
bac7c86
 
 
 
 
212e7d6
 
 
bac7c86
d7d90ba
212e7d6
 
 
 
 
 
 
 
 
 
bac7c86
212e7d6
 
 
 
bac7c86
212e7d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bac7c86
 
 
212e7d6
 
bac7c86
212e7d6
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import torch
import spaces
import gradio as gr
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

# ---------------------------------------------------------------------
# Model setup (maps roughly to UNETLoader + VAELoader + CLIPLoader)
# ---------------------------------------------------------------------

# Change this to your preferred SD3 model or a local path.
# For example, you can replace with a local snapshot inside the Space repo.
# MODEL_ID can be overridden via the environment for easy swapping in a Space.
MODEL_ID = os.getenv("MODEL_ID", "Tongyi-MAI/Z-Image-Turbo")

# Pick the best available device at import time; generate_images() re-checks
# at call time because ZeroGPU attaches the GPU only inside @spaces.GPU calls.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the full pipeline (transformer + VAE + text encoder) in bf16.
pipe = DiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)

# KSampler → choose a scheduler (Euler is close to your Comfy euler/simple)
# NOTE(review): EulerDiscreteScheduler is imported but never assigned to
# pipe.scheduler — the pipeline's default scheduler is used. Confirm whether
# an explicit scheduler swap was intended here.
pipe.to(device)
# Mark the repeated transformer blocks so the ahead-of-time compiled (AOTI)
# kernels below know which sub-module pattern to replace.
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
# Load pre-compiled ZeroGPU AOTI blocks (FlashAttention-3 variant) for speed.
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")

# ---------------------------------------------------------------------
# Inference function (maps to CLIPTextEncode + EmptySD3LatentImage + KSampler + VAEDecode)
# ---------------------------------------------------------------------
@spaces.GPU  # fix: decorator was applied twice, double-wrapping GPU allocation
def generate_images(
    positive: str,
    negative: str,
    width: int,
    height: int,
    steps: int,
    cfg: float,
    seed: int,
    num_images: int,
):
    """Run the text-to-image pipeline and return a list of PIL images.

    Args:
        positive: Positive text prompt.
        negative: Negative prompt; an empty string is passed as None so the
            pipeline skips negative conditioning entirely.
        width, height: Output resolution in pixels.
        steps: Number of denoising steps.
        cfg: Classifier-free guidance scale.
        seed: Base seed. ``seed >= 0`` yields the deterministic series
            seed, seed+1, ...; ``seed < 0`` draws a fresh random seed per image.
        num_images: How many images to generate (one pipeline call each so
            every image gets its own generator/seed).

    Returns:
        list of ``num_images`` PIL.Image objects.
    """
    # ZeroGPU attaches the GPU only inside this call, so re-detect the device
    # here rather than trusting the import-time value.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(run_device)

    # Gradio sliders can deliver floats; coerce everything the pipeline
    # expects as an int.
    num_images = int(num_images)
    width = int(width)
    height = int(height)
    steps = int(steps)

    images = []

    # seed >= 0  -> deterministic series: seed, seed+1, ...
    # seed < 0   -> fully random seeds per image
    fixed_base_seed = int(seed) if seed >= 0 else None

    for i in range(num_images):
        if fixed_base_seed is None:
            # Draw the random seed on CPU: there is no benefit to a GPU
            # randint here and .item() on a CUDA tensor forces a device sync.
            this_seed = torch.randint(0, 2**63 - 1, (1,)).item()
        else:
            # Deterministic offset from the user-supplied base seed.
            this_seed = fixed_base_seed + i

        generator = torch.Generator(device=run_device).manual_seed(int(this_seed))

        out = pipe(
            prompt=positive,
            negative_prompt=negative or None,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=float(cfg),
            num_images_per_prompt=1,
            generator=generator,
        ).images[0]

        images.append(out)

    return images




# ---------------------------------------------------------------------
# Gradio UI (inputs correspond to Comfy node widgets_values)
# ---------------------------------------------------------------------
# Layout: left column holds all generation controls, right column the output
# gallery. Widget defaults mirror the original ComfyUI workflow's node values.
with gr.Blocks() as demo:
    gr.Markdown("# SD3 Text-to-Image – ComfyUI Workflow Port")

    with gr.Row():
        with gr.Column():
            positive = gr.Textbox(
                label="Positive Prompt",
                value="masterpiece, best quality, extremely detailed, high resolution.",  # from CLIP Text Encode (Positive Prompt)
                lines=5,
            )
            negative = gr.Textbox(
                label="Negative Prompt",
                value="watermark, blurry, ugly, bad anatomy",  # from CLIP Text Encode (Negative Prompt)
                lines=4,
            )

            # Resolution sliders step by 64 px, the usual latent-grid multiple
            # for diffusion models.
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=1536,
                step=64,
                value=512,  # EmptySD3LatentImage width
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=1536,
                step=64,
                value=768,  # EmptySD3LatentImage height
            )

            steps = gr.Slider(
                label="Steps (KSampler)",
                minimum=1,
                maximum=50,
                step=1,
                value=12,  # KSampler steps
            )
            cfg = gr.Slider(
                label="CFG (Guidance Scale)",
                minimum=1.0,
                maximum=20.0,
                step=0.1,
                value=1.5,  # KSampler cfg in your graph
            )
            num_images = gr.Slider(
                label="Batch Size",
                minimum=1,
                maximum=8,
                step=1,
                value=6,  # EmptySD3LatentImage batch_size
            )
            # precision=0 forces an integer seed; negative means "randomize".
            seed = gr.Number(
                label="Seed (negative for random)",
                value=-1,  # "randomize" in Comfy
                precision=0,
            )

            run_btn = gr.Button("Generate")

        with gr.Column():
            gallery = gr.Gallery(
                label="Output Images",
                show_label=True,
                columns=3,
                height=768,
                object_fit="contain",  # keep full image visible in cell
                preview=False,         # do not start in zoomed preview mode
                allow_preview=True,    # still allow zoom when clicked
            )


    # Wire the button to the inference function. NOTE: input order here must
    # match generate_images' parameter order — seed comes before num_images.
    run_btn.click(
        fn=generate_images,
        inputs=[positive, negative, width, height, steps, cfg, seed, num_images],
        outputs=[gallery],
    )

if __name__ == "__main__":
    demo.launch()